diff --git a/README.md b/README.md index a92a604..67c29b5 100644 --- a/README.md +++ b/README.md @@ -51,13 +51,12 @@ ai_provider_config: chat_completion_url: "http://localhost:11434/v1/chat/completions" chat_completion_model: "gpt-4o" embedding_url: "http://localhost:11434/v1/embeddings" (Optional, If you want use RAG.) - embedding_model: "text-embedding-3-small" (Optional, If you want use RAG.) + embedding_model: "text-embedding-ada-002" (Optional, If you want use RAG.) temperature: 0.2 + max_tokens: 128000 theme: "dracula" RAG: true (Optional, if you want, can disable RAG.) ``` -> Note: We used the standard integration of [OpenAI APIs](https://platform.openai.com/docs/api-reference/introduction) and [Ollama APIs](https://github.com/ollama/ollama/blob/main/docs/api.md) and you can find more details in documentation of each APIs. - If you wish to customize your configuration, you can create your own `config.yml` file and place it in the `root directory` of each project you want to analyze with codai. If no configuration file is provided, codai will use the default settings. You can also specify a configuration file from any directory by using the following CLI command: @@ -70,7 +69,7 @@ codai code --provider_name openapi --temperature 0.8 ``` This flexibility allows you to customize config of codai on the fly. -> Note: We used [Chroma](https://github.com/alecthomas/chroma) for `style` of our `text` and `code block`, and you can find more theme here in [Chroma Style Gallery](https://xyproto.github.io/splash/docs/) and use it as a `theme` in `codai`. +> Note: We use [Chroma](https://github.com/alecthomas/chroma) to style our `text` and `code blocks`, and you can find more themes in the [Chroma Style Gallery](https://xyproto.github.io/splash/docs/) and use any of them as a `theme` in `codai`. ## 🔮 LLM Models ### ⚡ Best Models @@ -80,12 +79,12 @@ The codai works well with advanced LLM models specifically designed for code gen In addition to cloud-based models, codai is compatible with local models such as `Ollama`. To achieve the best results, it is recommended to utilize models like `DeepSeek-Coder-v2`, `CodeLlama`, and `Mistral`. These models have been optimized for coding tasks, ensuring that you can maximize the efficiency and effectiveness of your coding projects. ### 🌐 OpenAI Embedding Models -The codai platform uses `OpenAI embedding models` to retrieve `relevant content` with high efficiency. Recommended models include are **text-embedding-3-large**, **text-embedding-3-small**, and **text-embedding-ada-002**, both known for their `cost-effectiveness` and `accuracy` in `capturing semantic relationships`. These models are ideal for applications needing high-quality performance in `code context retrieval`. +The codai can use `OpenAI's embedding models` to retrieve the `most relevant content`. The currently recommended model for `code context` is `text-embedding-ada-002`, known for its strong performance in capturing semantic relationships, which makes it an excellent choice for accurate and efficient embedding retrieval. ### 🦙 Ollama Embedding Models -codai also supports `Ollama embedding models` for `local`, `cost-effective`, and `efficient` embedding generation and `retrieval of relevant content`. Models such as **mxbai-embed-large**, **all-minilm**, and **nomic-embed-text** provide **effective**, **private embedding** creation optimized for high-quality performance. 
These models are well-suited for `RAG-based retrieval in code contexts`, eliminating the need for external API calls. +The codai also supports `Ollama embedding models` for `local embedding` generation and retrieval. A suitable option is `nomic-embed-text`, which generates embeddings efficiently on your machine and supports effective RAG-based retrieval of `relevant code context`. -## ▶️ How to Run +## How to Run To use `codai` as your code assistant, navigate to the directory where you want to apply codai and run the following command: ```bash @@ -125,7 +124,7 @@ Allow users to customize settings through a config file (e.g., changing AI provi 📊 **Project Context Awareness:** Maintain awareness of the entire project context to provide more accurate suggestions. -🌳 **Full Project Context Summarization:** +🌳 **Full Project Context Summarization:** Summarize the full context of your codebase using Tree-sitter for accurate and efficient code analysis. 🔍 **RAG System Implementation:** diff --git a/cmd/code.go b/cmd/code.go index 3630999..be48012 100644 --- a/cmd/code.go +++ b/cmd/code.go @@ -42,8 +42,6 @@ func handleCodeCommand(rootDependencies *RootDependencies) { signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) - loopNumber := 0 - reader := bufio.NewReader(os.Stdin) var requestedContext string @@ -67,16 +65,8 @@ func handleCodeCommand(rootDependencies *RootDependencies) { // Launch the user input handler in a goroutine go func() { - startLoop: // Label for the start loop - for { - - if loopNumber > 0 { - // Display token usage details in a boxed format after each AI request - rootDependencies.TokenManagement.DisplayTokens(rootDependencies.Config.AIProviderConfig.ProviderName, rootDependencies.Config.AIProviderConfig.ChatCompletionModel, rootDependencies.Config.AIProviderConfig.EmbeddingModel, rootDependencies.Config.RAG) - } - - loopNumber++ + for { err := utils.CleanupTempFiles(rootDependencies.Cwd) if err != nil { fmt.Println(lipgloss_color.Red.Render(fmt.Sprintf("failed to cleanup temp files: %v", err))) @@ -101,13 +91,15 @@ func handleCodeCommand(rootDependencies *RootDependencies) { go func(dataFile models.FileData) { defer wg.Done() // Decrement the counter when the Goroutine completes filesEmbeddingOperation := func() error { - fileEmbedding, err := rootDependencies.CurrentProvider.EmbeddingRequest(ctx, dataFile.TreeSitterCode) + fileEmbeddingResponse, err := rootDependencies.CurrentProvider.EmbeddingRequest(ctx, dataFile.TreeSitterCode) if err != nil { return err } + fileEmbedding := fileEmbeddingResponse.Data[0].Embedding + // Save embeddings to the embedding store - rootDependencies.Store.Save(dataFile.RelativePath, dataFile.Code, fileEmbedding[0]) + rootDependencies.Store.Save(dataFile.RelativePath, dataFile.Code, fileEmbedding) return nil } @@ -124,18 +116,20 @@ func handleCodeCommand(rootDependencies *RootDependencies) { for err = range errorChan { spinnerLoadContextEmbedding.Stop() fmt.Println(lipgloss_color.Red.Render(fmt.Sprintf("%v", err))) - continue startLoop + continue } queryEmbeddingOperation := func() error { // Step 5: Generate embedding for the user query - queryEmbedding, err := rootDependencies.CurrentProvider.EmbeddingRequest(ctx, userInput) + queryEmbeddingResponse, err := rootDependencies.CurrentProvider.EmbeddingRequest(ctx, userInput) if err != nil { return err } + queryEmbedding := queryEmbeddingResponse.Data[0].Embedding + // Ensure there's an embedding for the user query - if len(queryEmbedding[0]) == 0 { + if 
len(queryEmbedding) == 0 { return fmt.Errorf(lipgloss_color.Red.Render("no embeddings returned for user query")) } @@ -143,7 +137,7 @@ func handleCodeCommand(rootDependencies *RootDependencies) { topN := -1 // Step 6: Find relevant code chunks based on the user query embedding - fullContextCodes = rootDependencies.Store.FindRelevantChunks(queryEmbedding[0], topN, rootDependencies.Config.AIProviderConfig.EmbeddingModel, rootDependencies.Config.AIProviderConfig.Threshold) + fullContextCodes = rootDependencies.Store.FindRelevantChunks(queryEmbedding, topN, rootDependencies.Config.AIProviderConfig.EmbeddingModel, rootDependencies.Config.AIProviderConfig.Threshold) return nil } @@ -153,10 +147,9 @@ func handleCodeCommand(rootDependencies *RootDependencies) { if err != nil { spinnerLoadContextEmbedding.Stop() fmt.Println(lipgloss_color.Red.Render(fmt.Sprintf("%v", err))) - continue startLoop + continue } - fmt.Println() spinnerLoadContextEmbedding.Stop() } @@ -195,14 +188,21 @@ func handleCodeCommand(rootDependencies *RootDependencies) { if err != nil { fmt.Println(lipgloss_color.Red.Render(fmt.Sprintf("%v", err))) - continue startLoop + continue } + fmt.Print("\n\n") + if !rootDependencies.Config.RAG { // Try to get full block code if block codes is summarized and incomplete requestedContext, err = rootDependencies.Analyzer.TryGetInCompletedCodeBlocK(aiResponseBuilder.String()) - if requestedContext != "" && err == nil { + if err != nil { + fmt.Println(lipgloss_color.Red.Render(fmt.Sprintf("%v", err))) + continue + } + + if requestedContext != "" { aiResponseBuilder.Reset() fmt.Println(lipgloss_color.BlueSky.Render("Trying to send above context files for getting code suggestion fromm AI...\n")) @@ -220,7 +220,7 @@ func handleCodeCommand(rootDependencies *RootDependencies) { changes, err := rootDependencies.Analyzer.ExtractCodeChanges(aiResponseBuilder.String()) if err != nil || changes == nil { - fmt.Println(lipgloss_color.BlueSky.Render("\nno code blocks with a valid path detected to apply.")) + fmt.Println(lipgloss_color.Gray.Render("no code blocks with a valid path detected to apply.")) continue } @@ -265,6 +265,9 @@ func handleCodeCommand(rootDependencies *RootDependencies) { } + // Display token usage details in a boxed format after each AI request + rootDependencies.TokenManagement.DisplayTokens(rootDependencies.Config.AIProviderConfig.ChatCompletionModel, rootDependencies.Config.AIProviderConfig.EmbeddingModel) + // If we need Update the context after apply changes if updateContextNeeded { diff --git a/cmd/root.go b/cmd/root.go index 8d04c1a..0baac3e 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -58,7 +58,7 @@ func handleRootCommand(cmd *cobra.Command) *RootDependencies { rootDependencies.Config = config.LoadConfigs(cmd, rootDependencies.Cwd) - rootDependencies.TokenManagement = providers.NewTokenManager() + rootDependencies.TokenManagement = providers.NewTokenManager(rootDependencies.Config.AIProviderConfig.MaxTokens) rootDependencies.ChatHistory = providers.NewChatHistory() diff --git a/embed_data/embed.go b/embed_data/embed.go index 71d3576..0347d96 100644 --- a/embed_data/embed.go +++ b/embed_data/embed.go @@ -5,9 +5,6 @@ import _ "embed" //go:embed prompts/code_block_prompt.tmpl var CodeBlockTemplate []byte -//go:embed models_details/model_details.tmpl -var ModelDetails []byte - //go:embed tree-sitter/queries/csharp.scm var CSharpQuery []byte diff --git a/embed_data/models_details/model_details.tmpl b/embed_data/models_details/model_details.tmpl deleted file mode 100644 index 
b5e5c8f..0000000 --- a/embed_data/models_details/model_details.tmpl +++ /dev/null @@ -1,978 +0,0 @@ -{ - "models": { - "gpt-4": { - "max_tokens": 4096, - "max_input_tokens": 8192, - "max_output_tokens": 4096, - "input_cost_per_token": 0.00003, - "output_cost_per_token": 0.00006, - "litellm_provider": "openai", - "mode": "chat", - "supports_function_calling": true, - "supports_prompt_caching": true - }, - "gpt-4o": { - "max_tokens": 4096, - "max_input_tokens": 128000, - "max_output_tokens": 4096, - "input_cost_per_token": 0.000005, - "output_cost_per_token": 0.000015, - "cache_read_input_token_cost": 0.00000125, - "litellm_provider": "openai", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true, - "supports_prompt_caching": true - }, - "gpt-4o-audio-preview": { - "max_tokens": 16384, - "max_input_tokens": 128000, - "max_output_tokens": 16384, - "input_cost_per_token": 0.0000025, - "input_cost_per_audio_token": 0.0001, - "output_cost_per_token": 0.000010, - "output_cost_per_audio_token": 0.0002, - "litellm_provider": "openai", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_audio_input": true, - "supports_audio_output": true - }, - "gpt-4o-audio-preview-2024-10-01": { - "max_tokens": 16384, - "max_input_tokens": 128000, - "max_output_tokens": 16384, - "input_cost_per_token": 0.0000025, - "input_cost_per_audio_token": 0.0001, - "output_cost_per_token": 0.000010, - "output_cost_per_audio_token": 0.0002, - "litellm_provider": "openai", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_audio_input": true, - "supports_audio_output": true - }, - "gpt-4o-mini": { - "max_tokens": 16384, - "max_input_tokens": 128000, - "max_output_tokens": 16384, - "input_cost_per_token": 0.00000015, - "output_cost_per_token": 0.00000060, - "cache_read_input_token_cost": 0.000000075, - "litellm_provider": "openai", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true, - "supports_prompt_caching": true - }, - "gpt-4o-mini-2024-07-18": { - "max_tokens": 16384, - "max_input_tokens": 128000, - "max_output_tokens": 16384, - "input_cost_per_token": 0.00000015, - "output_cost_per_token": 0.00000060, - "cache_read_input_token_cost": 0.000000075, - "litellm_provider": "openai", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true, - "supports_prompt_caching": true - }, - "o1-mini": { - "max_tokens": 65536, - "max_input_tokens": 128000, - "max_output_tokens": 65536, - "input_cost_per_token": 0.000003, - "output_cost_per_token": 0.000012, - "cache_read_input_token_cost": 0.0000015, - "litellm_provider": "openai", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true, - "supports_prompt_caching": true - }, - "o1-mini-2024-09-12": { - "max_tokens": 65536, - "max_input_tokens": 128000, - "max_output_tokens": 65536, - "input_cost_per_token": 0.000003, - "output_cost_per_token": 0.000012, - "cache_read_input_token_cost": 0.0000015, - "litellm_provider": "openai", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true, - "supports_prompt_caching": true - }, - "o1-preview": { - "max_tokens": 32768, - "max_input_tokens": 128000, - 
"max_output_tokens": 32768, - "input_cost_per_token": 0.000015, - "output_cost_per_token": 0.000060, - "cache_read_input_token_cost": 0.0000075, - "litellm_provider": "openai", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true, - "supports_prompt_caching": true - }, - "o1-preview-2024-09-12": { - "max_tokens": 32768, - "max_input_tokens": 128000, - "max_output_tokens": 32768, - "input_cost_per_token": 0.000015, - "output_cost_per_token": 0.000060, - "cache_read_input_token_cost": 0.0000075, - "litellm_provider": "openai", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true, - "supports_prompt_caching": true - }, - "chatgpt-4o-latest": { - "max_tokens": 4096, - "max_input_tokens": 128000, - "max_output_tokens": 4096, - "input_cost_per_token": 0.000005, - "output_cost_per_token": 0.000015, - "litellm_provider": "openai", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true, - "supports_prompt_caching": true - }, - "gpt-4o-2024-05-13": { - "max_tokens": 4096, - "max_input_tokens": 128000, - "max_output_tokens": 4096, - "input_cost_per_token": 0.000005, - "output_cost_per_token": 0.000015, - "litellm_provider": "openai", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true, - "supports_prompt_caching": true - }, - "gpt-4o-2024-08-06": { - "max_tokens": 16384, - "max_input_tokens": 128000, - "max_output_tokens": 16384, - "input_cost_per_token": 0.0000025, - "output_cost_per_token": 0.000010, - "cache_read_input_token_cost": 0.00000125, - "litellm_provider": "openai", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true, - "supports_prompt_caching": true - }, - "gpt-4-turbo-preview": { - "max_tokens": 4096, - "max_input_tokens": 128000, - "max_output_tokens": 4096, - "input_cost_per_token": 0.00001, - "output_cost_per_token": 0.00003, - "litellm_provider": "openai", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_prompt_caching": true - }, - "gpt-4-0314": { - "max_tokens": 4096, - "max_input_tokens": 8192, - "max_output_tokens": 4096, - "input_cost_per_token": 0.00003, - "output_cost_per_token": 0.00006, - "litellm_provider": "openai", - "mode": "chat", - "supports_prompt_caching": true - }, - "gpt-4-0613": { - "max_tokens": 4096, - "max_input_tokens": 8192, - "max_output_tokens": 4096, - "input_cost_per_token": 0.00003, - "output_cost_per_token": 0.00006, - "litellm_provider": "openai", - "mode": "chat", - "supports_function_calling": true, - "supports_prompt_caching": true - }, - "gpt-4-32k": { - "max_tokens": 4096, - "max_input_tokens": 32768, - "max_output_tokens": 4096, - "input_cost_per_token": 0.00006, - "output_cost_per_token": 0.00012, - "litellm_provider": "openai", - "mode": "chat", - "supports_prompt_caching": true - }, - "gpt-4-32k-0314": { - "max_tokens": 4096, - "max_input_tokens": 32768, - "max_output_tokens": 4096, - "input_cost_per_token": 0.00006, - "output_cost_per_token": 0.00012, - "litellm_provider": "openai", - "mode": "chat", - "supports_prompt_caching": true - }, - "gpt-4-32k-0613": { - "max_tokens": 4096, - "max_input_tokens": 32768, - "max_output_tokens": 4096, - "input_cost_per_token": 0.00006, - "output_cost_per_token": 0.00012, - 
"litellm_provider": "openai", - "mode": "chat", - "supports_prompt_caching": true - }, - "gpt-4-turbo": { - "max_tokens": 4096, - "max_input_tokens": 128000, - "max_output_tokens": 4096, - "input_cost_per_token": 0.00001, - "output_cost_per_token": 0.00003, - "litellm_provider": "openai", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true, - "supports_prompt_caching": true - }, - "gpt-4-turbo-2024-04-09": { - "max_tokens": 4096, - "max_input_tokens": 128000, - "max_output_tokens": 4096, - "input_cost_per_token": 0.00001, - "output_cost_per_token": 0.00003, - "litellm_provider": "openai", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true, - "supports_prompt_caching": true - }, - "gpt-4-1106-preview": { - "max_tokens": 4096, - "max_input_tokens": 128000, - "max_output_tokens": 4096, - "input_cost_per_token": 0.00001, - "output_cost_per_token": 0.00003, - "litellm_provider": "openai", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_prompt_caching": true - }, - "gpt-4-0125-preview": { - "max_tokens": 4096, - "max_input_tokens": 128000, - "max_output_tokens": 4096, - "input_cost_per_token": 0.00001, - "output_cost_per_token": 0.00003, - "litellm_provider": "openai", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_prompt_caching": true - }, - "gpt-4-vision-preview": { - "max_tokens": 4096, - "max_input_tokens": 128000, - "max_output_tokens": 4096, - "input_cost_per_token": 0.00001, - "output_cost_per_token": 0.00003, - "litellm_provider": "openai", - "mode": "chat", - "supports_vision": true, - "supports_prompt_caching": true - }, - "gpt-4-1106-vision-preview": { - "max_tokens": 4096, - "max_input_tokens": 128000, - "max_output_tokens": 4096, - "input_cost_per_token": 0.00001, - "output_cost_per_token": 0.00003, - "litellm_provider": "openai", - "mode": "chat", - "supports_vision": true, - "supports_prompt_caching": true - }, - "gpt-3.5-turbo": { - "max_tokens": 4097, - "max_input_tokens": 16385, - "max_output_tokens": 4096, - "input_cost_per_token": 0.0000015, - "output_cost_per_token": 0.000002, - "litellm_provider": "openai", - "mode": "chat", - "supports_function_calling": true, - "supports_prompt_caching": true - }, - "gpt-3.5-turbo-0301": { - "max_tokens": 4097, - "max_input_tokens": 4097, - "max_output_tokens": 4096, - "input_cost_per_token": 0.0000015, - "output_cost_per_token": 0.000002, - "litellm_provider": "openai", - "mode": "chat", - "supports_prompt_caching": true - }, - "gpt-3.5-turbo-0613": { - "max_tokens": 4097, - "max_input_tokens": 4097, - "max_output_tokens": 4096, - "input_cost_per_token": 0.0000015, - "output_cost_per_token": 0.000002, - "litellm_provider": "openai", - "mode": "chat", - "supports_function_calling": true, - "supports_prompt_caching": true - }, - "gpt-3.5-turbo-1106": { - "max_tokens": 16385, - "max_input_tokens": 16385, - "max_output_tokens": 4096, - "input_cost_per_token": 0.0000010, - "output_cost_per_token": 0.0000020, - "litellm_provider": "openai", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_prompt_caching": true - }, - "gpt-3.5-turbo-0125": { - "max_tokens": 16385, - "max_input_tokens": 16385, - "max_output_tokens": 4096, - "input_cost_per_token": 0.0000005, - "output_cost_per_token": 0.0000015, - 
"litellm_provider": "openai", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_prompt_caching": true - }, - "gpt-3.5-turbo-16k": { - "max_tokens": 16385, - "max_input_tokens": 16385, - "max_output_tokens": 4096, - "input_cost_per_token": 0.000003, - "output_cost_per_token": 0.000004, - "litellm_provider": "openai", - "mode": "chat", - "supports_prompt_caching": true - }, - "gpt-3.5-turbo-16k-0613": { - "max_tokens": 16385, - "max_input_tokens": 16385, - "max_output_tokens": 4096, - "input_cost_per_token": 0.000003, - "output_cost_per_token": 0.000004, - "litellm_provider": "openai", - "mode": "chat", - "supports_prompt_caching": true - }, - "text-embedding-3-large": { - "max_tokens": 8191, - "max_input_tokens": 8191, - "output_vector_size": 3072, - "input_cost_per_token": 0.00000013, - "output_cost_per_token": 0.000000, - "litellm_provider": "openai", - "mode": "embedding" - }, - "text-embedding-3-small": { - "max_tokens": 8191, - "max_input_tokens": 8191, - "output_vector_size": 1536, - "input_cost_per_token": 0.00000002, - "output_cost_per_token": 0.000000, - "litellm_provider": "openai", - "mode": "embedding" - }, - "text-embedding-ada-002": { - "max_tokens": 8191, - "max_input_tokens": 8191, - "output_vector_size": 1536, - "input_cost_per_token": 0.0000001, - "output_cost_per_token": 0.000000, - "litellm_provider": "openai", - "mode": "embedding" - }, - "text-embedding-ada-002-v2": { - "max_tokens": 8191, - "max_input_tokens": 8191, - "input_cost_per_token": 0.0000001, - "output_cost_per_token": 0.000000, - "litellm_provider": "openai", - "mode": "embedding" - }, - "azure/o1-mini": { - "max_tokens": 65536, - "max_input_tokens": 128000, - "max_output_tokens": 65536, - "input_cost_per_token": 0.000003, - "output_cost_per_token": 0.000012, - "cache_read_input_token_cost": 0.0000015, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true, - "supports_prompt_caching": true - }, - "azure/o1-mini-2024-09-12": { - "max_tokens": 65536, - "max_input_tokens": 128000, - "max_output_tokens": 65536, - "input_cost_per_token": 0.000003, - "output_cost_per_token": 0.000012, - "cache_read_input_token_cost": 0.0000015, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true, - "supports_prompt_caching": true - }, - "azure/o1-preview": { - "max_tokens": 32768, - "max_input_tokens": 128000, - "max_output_tokens": 32768, - "input_cost_per_token": 0.000015, - "output_cost_per_token": 0.000060, - "cache_read_input_token_cost": 0.0000075, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true, - "supports_prompt_caching": true - }, - "azure/o1-preview-2024-09-12": { - "max_tokens": 32768, - "max_input_tokens": 128000, - "max_output_tokens": 32768, - "input_cost_per_token": 0.000015, - "output_cost_per_token": 0.000060, - "cache_read_input_token_cost": 0.0000075, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true, - "supports_prompt_caching": true - }, - "azure/gpt-4o": { - "max_tokens": 4096, - "max_input_tokens": 128000, - "max_output_tokens": 4096, - "input_cost_per_token": 0.000005, - "output_cost_per_token": 0.000015, - 
"cache_read_input_token_cost": 0.00000125, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true, - "supports_prompt_caching": true - }, - "azure/gpt-4o-2024-08-06": { - "max_tokens": 16384, - "max_input_tokens": 128000, - "max_output_tokens": 16384, - "input_cost_per_token": 0.00000275, - "output_cost_per_token": 0.000011, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true - }, - "azure/gpt-4o-2024-05-13": { - "max_tokens": 4096, - "max_input_tokens": 128000, - "max_output_tokens": 4096, - "input_cost_per_token": 0.000005, - "output_cost_per_token": 0.000015, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true, - "supports_prompt_caching": true - }, - "azure/global-standard/gpt-4o-2024-08-06": { - "max_tokens": 16384, - "max_input_tokens": 128000, - "max_output_tokens": 16384, - "input_cost_per_token": 0.0000025, - "output_cost_per_token": 0.000010, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true - }, - "azure/global-standard/gpt-4o-mini": { - "max_tokens": 16384, - "max_input_tokens": 128000, - "max_output_tokens": 16384, - "input_cost_per_token": 0.00000015, - "output_cost_per_token": 0.00000060, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true - }, - "azure/gpt-4o-mini": { - "max_tokens": 16384, - "max_input_tokens": 128000, - "max_output_tokens": 16384, - "input_cost_per_token": 0.000000165, - "output_cost_per_token": 0.00000066, - "cache_read_input_token_cost": 0.000000075, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true, - "supports_prompt_caching": true - }, - "azure/gpt-4o-mini-2024-07-18": { - "max_tokens": 16384, - "max_input_tokens": 128000, - "max_output_tokens": 16384, - "input_cost_per_token": 0.000000165, - "output_cost_per_token": 0.00000066, - "cache_read_input_token_cost": 0.000000075, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true, - "supports_prompt_caching": true - }, - "azure/gpt-4-turbo-2024-04-09": { - "max_tokens": 4096, - "max_input_tokens": 128000, - "max_output_tokens": 4096, - "input_cost_per_token": 0.00001, - "output_cost_per_token": 0.00003, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true - }, - "azure/gpt-4-0125-preview": { - "max_tokens": 4096, - "max_input_tokens": 128000, - "max_output_tokens": 4096, - "input_cost_per_token": 0.00001, - "output_cost_per_token": 0.00003, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true - }, - "azure/gpt-4-1106-preview": { - "max_tokens": 4096, - "max_input_tokens": 128000, - "max_output_tokens": 4096, - "input_cost_per_token": 0.00001, - "output_cost_per_token": 0.00003, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true, - 
"supports_parallel_function_calling": true - }, - "azure/gpt-4-0613": { - "max_tokens": 4096, - "max_input_tokens": 8192, - "max_output_tokens": 4096, - "input_cost_per_token": 0.00003, - "output_cost_per_token": 0.00006, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true - }, - "azure/gpt-4-32k-0613": { - "max_tokens": 4096, - "max_input_tokens": 32768, - "max_output_tokens": 4096, - "input_cost_per_token": 0.00006, - "output_cost_per_token": 0.00012, - "litellm_provider": "azure", - "mode": "chat" - }, - "azure/gpt-4-32k": { - "max_tokens": 4096, - "max_input_tokens": 32768, - "max_output_tokens": 4096, - "input_cost_per_token": 0.00006, - "output_cost_per_token": 0.00012, - "litellm_provider": "azure", - "mode": "chat" - }, - "azure/gpt-4": { - "max_tokens": 4096, - "max_input_tokens": 8192, - "max_output_tokens": 4096, - "input_cost_per_token": 0.00003, - "output_cost_per_token": 0.00006, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true - }, - "azure/gpt-4-turbo": { - "max_tokens": 4096, - "max_input_tokens": 128000, - "max_output_tokens": 4096, - "input_cost_per_token": 0.00001, - "output_cost_per_token": 0.00003, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true - }, - "azure/gpt-4-turbo-vision-preview": { - "max_tokens": 4096, - "max_input_tokens": 128000, - "max_output_tokens": 4096, - "input_cost_per_token": 0.00001, - "output_cost_per_token": 0.00003, - "litellm_provider": "azure", - "mode": "chat", - "supports_vision": true - }, - "azure/gpt-35-turbo-16k-0613": { - "max_tokens": 4096, - "max_input_tokens": 16385, - "max_output_tokens": 4096, - "input_cost_per_token": 0.000003, - "output_cost_per_token": 0.000004, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true - }, - "azure/gpt-35-turbo-1106": { - "max_tokens": 4096, - "max_input_tokens": 16384, - "max_output_tokens": 4096, - "input_cost_per_token": 0.000001, - "output_cost_per_token": 0.000002, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true - }, - "azure/gpt-35-turbo-0613": { - "max_tokens": 4097, - "max_input_tokens": 4097, - "max_output_tokens": 4096, - "input_cost_per_token": 0.0000015, - "output_cost_per_token": 0.000002, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true - }, - "azure/gpt-35-turbo-0301": { - "max_tokens": 4097, - "max_input_tokens": 4097, - "max_output_tokens": 4096, - "input_cost_per_token": 0.0000002, - "output_cost_per_token": 0.000002, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true - }, - "azure/gpt-35-turbo-0125": { - "max_tokens": 4096, - "max_input_tokens": 16384, - "max_output_tokens": 4096, - "input_cost_per_token": 0.0000005, - "output_cost_per_token": 0.0000015, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true - }, - "azure/gpt-35-turbo-16k": { - "max_tokens": 4096, - "max_input_tokens": 16385, - "max_output_tokens": 4096, - "input_cost_per_token": 0.000003, - "output_cost_per_token": 0.000004, - "litellm_provider": "azure", - "mode": "chat" - }, - "azure/gpt-35-turbo": { - "max_tokens": 4096, - "max_input_tokens": 4097, - "max_output_tokens": 4096, - "input_cost_per_token": 0.0000005, - 
"output_cost_per_token": 0.0000015, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true - }, - "azure/gpt-3.5-turbo-instruct-0914": { - "max_tokens": 4097, - "max_input_tokens": 4097, - "input_cost_per_token": 0.0000015, - "output_cost_per_token": 0.000002, - "litellm_provider": "text-completion-openai", - "mode": "completion" - }, - "azure/gpt-35-turbo-instruct": { - "max_tokens": 4097, - "max_input_tokens": 4097, - "input_cost_per_token": 0.0000015, - "output_cost_per_token": 0.000002, - "litellm_provider": "text-completion-openai", - "mode": "completion" - }, - "azure/gpt-35-turbo-instruct-0914": { - "max_tokens": 4097, - "max_input_tokens": 4097, - "input_cost_per_token": 0.0000015, - "output_cost_per_token": 0.000002, - "litellm_provider": "text-completion-openai", - "mode": "completion" - }, - "azure/mistral-large-latest": { - "max_tokens": 32000, - "max_input_tokens": 32000, - "input_cost_per_token": 0.000008, - "output_cost_per_token": 0.000024, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true - }, - "azure/mistral-large-2402": { - "max_tokens": 32000, - "max_input_tokens": 32000, - "input_cost_per_token": 0.000008, - "output_cost_per_token": 0.000024, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true - }, - "azure/command-r-plus": { - "max_tokens": 4096, - "max_input_tokens": 128000, - "max_output_tokens": 4096, - "input_cost_per_token": 0.000003, - "output_cost_per_token": 0.000015, - "litellm_provider": "azure", - "mode": "chat", - "supports_function_calling": true - }, - "azure/ada": { - "max_tokens": 8191, - "max_input_tokens": 8191, - "input_cost_per_token": 0.0000001, - "output_cost_per_token": 0.000000, - "litellm_provider": "azure", - "mode": "embedding" - }, - "azure/text-embedding-ada-002": { - "max_tokens": 8191, - "max_input_tokens": 8191, - "input_cost_per_token": 0.0000001, - "output_cost_per_token": 0.000000, - "litellm_provider": "azure", - "mode": "embedding" - }, - "azure/text-embedding-3-large": { - "max_tokens": 8191, - "max_input_tokens": 8191, - "input_cost_per_token": 0.00000013, - "output_cost_per_token": 0.000000, - "litellm_provider": "azure", - "mode": "embedding" - }, - "azure/text-embedding-3-small": { - "max_tokens": 8191, - "max_input_tokens": 8191, - "input_cost_per_token": 0.00000002, - "output_cost_per_token": 0.000000, - "litellm_provider": "azure", - "mode": "embedding" - }, - "azure/standard/1024-x-1024/dall-e-3": { - "input_cost_per_pixel": 0.0000000381469, - "output_cost_per_token": 0.0, - "litellm_provider": "azure", - "mode": "image_generation" - }, - "azure/hd/1024-x-1024/dall-e-3": { - "input_cost_per_pixel": 0.00000007629, - "output_cost_per_token": 0.0, - "litellm_provider": "azure", - "mode": "image_generation" - }, - "azure/standard/1024-x-1792/dall-e-3": { - "input_cost_per_pixel": 0.00000004359, - "output_cost_per_token": 0.0, - "litellm_provider": "azure", - "mode": "image_generation" - }, - "azure/standard/1792-x-1024/dall-e-3": { - "input_cost_per_pixel": 0.00000004359, - "output_cost_per_token": 0.0, - "litellm_provider": "azure", - "mode": "image_generation" - }, - "azure/hd/1024-x-1792/dall-e-3": { - "input_cost_per_pixel": 0.00000006539, - "output_cost_per_token": 0.0, - "litellm_provider": "azure", - "mode": "image_generation" - }, - "azure/hd/1792-x-1024/dall-e-3": { - "input_cost_per_pixel": 0.00000006539, - "output_cost_per_token": 0.0, - "litellm_provider": "azure", - "mode": "image_generation" - }, - 
"azure/standard/1024-x-1024/dall-e-2": { - "input_cost_per_pixel": 0.0, - "output_cost_per_token": 0.0, - "litellm_provider": "azure", - "mode": "image_generation" - }, - "azure_ai/jamba-instruct": { - "max_tokens": 4096, - "max_input_tokens": 70000, - "max_output_tokens": 4096, - "input_cost_per_token": 0.0000005, - "output_cost_per_token": 0.0000007, - "litellm_provider": "azure_ai", - "mode": "chat" - }, - "azure_ai/mistral-large": { - "max_tokens": 8191, - "max_input_tokens": 32000, - "max_output_tokens": 8191, - "input_cost_per_token": 0.000004, - "output_cost_per_token": 0.000012, - "litellm_provider": "azure_ai", - "mode": "chat", - "supports_function_calling": true - }, - "azure_ai/mistral-small": { - "max_tokens": 8191, - "max_input_tokens": 32000, - "max_output_tokens": 8191, - "input_cost_per_token": 0.000001, - "output_cost_per_token": 0.000003, - "litellm_provider": "azure_ai", - "supports_function_calling": true, - "mode": "chat" - }, - "azure_ai/Meta-Llama-3-70B-Instruct": { - "max_tokens": 8192, - "max_input_tokens": 8192, - "max_output_tokens": 8192, - "input_cost_per_token": 0.0000011, - "output_cost_per_token": 0.00000037, - "litellm_provider": "azure_ai", - "mode": "chat" - }, - "azure_ai/Meta-Llama-3.1-8B-Instruct": { - "max_tokens": 128000, - "max_input_tokens": 128000, - "max_output_tokens": 128000, - "input_cost_per_token": 0.0000003, - "output_cost_per_token": 0.00000061, - "litellm_provider": "azure_ai", - "mode": "chat", - "source": "https://azuremarketplace.microsoft.com/en-us/marketplace/apps/metagenai.meta-llama-3-1-8b-instruct-offer?tab=PlansAndPrice" - }, - "azure_ai/Meta-Llama-3.1-70B-Instruct": { - "max_tokens": 128000, - "max_input_tokens": 128000, - "max_output_tokens": 128000, - "input_cost_per_token": 0.00000268, - "output_cost_per_token": 0.00000354, - "litellm_provider": "azure_ai", - "mode": "chat", - "source": "https://azuremarketplace.microsoft.com/en-us/marketplace/apps/metagenai.meta-llama-3-1-70b-instruct-offer?tab=PlansAndPrice" - }, - "azure_ai/Meta-Llama-3.1-405B-Instruct": { - "max_tokens": 128000, - "max_input_tokens": 128000, - "max_output_tokens": 128000, - "input_cost_per_token": 0.00000533, - "output_cost_per_token": 0.000016, - "litellm_provider": "azure_ai", - "mode": "chat", - "source": "https://azuremarketplace.microsoft.com/en-us/marketplace/apps/metagenai.meta-llama-3-1-405b-instruct-offer?tab=PlansAndPrice" - }, - "azure_ai/cohere-rerank-v3-multilingual": { - "max_tokens": 4096, - "max_input_tokens": 4096, - "max_output_tokens": 4096, - "max_query_tokens": 2048, - "input_cost_per_token": 0.0, - "input_cost_per_query": 0.002, - "output_cost_per_token": 0.0, - "litellm_provider": "azure_ai", - "mode": "rerank" - }, - "azure_ai/cohere-rerank-v3-english": { - "max_tokens": 4096, - "max_input_tokens": 4096, - "max_output_tokens": 4096, - "max_query_tokens": 2048, - "input_cost_per_token": 0.0, - "input_cost_per_query": 0.002, - "output_cost_per_token": 0.0, - "litellm_provider": "azure_ai", - "mode": "rerank" - }, - "azure_ai/Cohere-embed-v3-english": { - "max_tokens": 512, - "max_input_tokens": 512, - "output_vector_size": 1024, - "input_cost_per_token": 0.0000001, - "output_cost_per_token": 0.0, - "litellm_provider": "azure_ai", - "mode": "embedding", - "source": "https://azuremarketplace.microsoft.com/en-us/marketplace/apps/cohere.cohere-embed-v3-english-offer?tab=PlansAndPrice" - }, - "azure_ai/Cohere-embed-v3-multilingual": { - "max_tokens": 512, - "max_input_tokens": 512, - "output_vector_size": 1024, - "input_cost_per_token": 
0.0000001, - "output_cost_per_token": 0.0, - "litellm_provider": "azure_ai", - "mode": "embedding", - "source": "https://azuremarketplace.microsoft.com/en-us/marketplace/apps/cohere.cohere-embed-v3-english-offer?tab=PlansAndPrice" - } - } -} \ No newline at end of file diff --git a/embed_data/prompts/code_block_prompt.tmpl b/embed_data/prompts/code_block_prompt.tmpl index dae407a..20968cc 100644 --- a/embed_data/prompts/code_block_prompt.tmpl +++ b/embed_data/prompts/code_block_prompt.tmpl @@ -3,7 +3,7 @@ > Your tasks are according to these steps: ## PRIORITY: Check for Specific Context in Code -1. **If I request the specific context of code, such as a method, class, or any code part that exist in my context but it has an empty body or incomplete body, you must follow these steps:** +1. **If I request the specific context of code, such as a method, class, or any code part that is an empty or incomplete block, you must follow these steps:** - **Only return the relative paths of the relevant files as a JSON array of strings in the following format:** ```json { @@ -18,17 +18,15 @@ If the request does not fall under the above condition, proceed with the followi - Read the code context carefully. - Analyze the code to identify where the requested change or feature should be added or modified. -## Explanation -- Explain any needed changes. ## Every `Code BLOCK` MUST use this format: -- First line: the `file name` with `relative path`; no extra markup, punctuation, comments, etc. **JUST** the `file name` with `relative path` and `file name` using `naming conventions` base on `language`. +- First line: the `file name` with `relative path`; no extra markup, punctuation, comments, etc. **JUST** the `file name` with `relative path`. - Second line: start of md highlighted code block with language base on body of this code block. - ... entire content of the file ... - Final line: end of md highlighted code block. - In the end `Code BLOCK` format should be exactly like below: -**File: relativePath/fileName.ext** +File: relativePath/fileName.ext here. ```language base on body of this code block func main() { greeting := "Hello, World!" @@ -39,6 +37,9 @@ If the request does not fall under the above condition, proceed with the followi } ``` +## Explanation Section +- Explain any needed changes. + # IMPORTANT: - ALWAYS ADD RELATIVE PATH AND NAME OF FILE TOP OF EACH CODE BLOCK. - DO NOTE USE SPECIAL CHARACTER IN RELATIVE PATH AND THE NAME OF FILE. 
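Below is a minimal, self-contained sketch of how a caller consumes the updated `EmbeddingRequest` contract introduced by this change, which now returns a full `*models.EmbeddingResponse` instead of `[][]float64` (see the `cmd/code.go` hunks above and `providers/contracts/ai_provider.go` below). The `fakeEmbeddingRequest` function and the trimmed `Data`/`EmbeddingResponse` types here are illustrative stand-ins, not part of the codebase.

```go
package main

import (
	"context"
	"fmt"
)

// Trimmed-down mirrors of providers/models.Data and EmbeddingResponse;
// only the fields a caller needs for retrieval are kept in this sketch.
type Data struct {
	Embedding []float64
}

type EmbeddingResponse struct {
	Data []Data
}

// fakeEmbeddingRequest stands in for contracts.IAIProvider.EmbeddingRequest,
// which now returns the whole response object rather than a bare [][]float64.
func fakeEmbeddingRequest(ctx context.Context, prompt string) (*EmbeddingResponse, error) {
	return &EmbeddingResponse{Data: []Data{{Embedding: []float64{0.1, 0.2, 0.3}}}}, nil
}

func main() {
	resp, err := fakeEmbeddingRequest(context.Background(), "func main() {}")
	if err != nil {
		fmt.Println("embedding request failed:", err)
		return
	}
	// Callers unwrap the vector themselves, as cmd/code.go now does with Data[0].Embedding.
	if len(resp.Data) == 0 || len(resp.Data[0].Embedding) == 0 {
		fmt.Println("no embeddings returned for user query")
		return
	}
	fmt.Println("embedding vector:", resp.Data[0].Embedding)
}
```

Returning the whole response lets call sites check for an empty `Data` slice explicitly and, if needed, read the `Usage` information, which is what the new checks in `cmd/code.go` rely on.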
diff --git a/providers/ai_provider.go b/providers/ai_provider.go index 6a91045..4932c36 100644 --- a/providers/ai_provider.go +++ b/providers/ai_provider.go @@ -36,7 +36,7 @@ func ProviderFactory(config *AIProviderConfig, tokenManagement contracts.ITokenM Threshold: config.Threshold, TokenManagement: tokenManagement, }), nil - case "openai", "azure-openai": + case "openai": return openai.NewOpenAIProvider(&openai.OpenAIConfig{ Temperature: config.Temperature, diff --git a/providers/contracts/ai_provider.go b/providers/contracts/ai_provider.go index a025d8f..5890eed 100644 --- a/providers/contracts/ai_provider.go +++ b/providers/contracts/ai_provider.go @@ -7,5 +7,5 @@ import ( type IAIProvider interface { ChatCompletionRequest(ctx context.Context, userInput string, prompt string) <-chan models.StreamResponse - EmbeddingRequest(ctx context.Context, prompt string) ([][]float64, error) + EmbeddingRequest(ctx context.Context, prompt string) (*models.EmbeddingResponse, error) } diff --git a/providers/contracts/token_management.go b/providers/contracts/token_management.go index b316320..eee6e6e 100644 --- a/providers/contracts/token_management.go +++ b/providers/contracts/token_management.go @@ -1,8 +1,9 @@ package contracts type ITokenManagement interface { - UsedTokens(inputToken int, outputToken int) - UsedEmbeddingTokens(inputToken int, outputToken int) - CalculateCost(providerName string, modelName string, inputToken int, outputToken int) float64 - DisplayTokens(providerName string, model string, embeddingModel string, isRag bool) + CountTokens(text string, model string) (int, error) + AvailableTokens() int + UseTokens(count int) error + UseEmbeddingTokens(count int) error + DisplayTokens(model string, embeddingModel string) } diff --git a/providers/ollama/models/ollama_chat_completion_request.go b/providers/models/chat_completion_request.go similarity index 77% rename from providers/ollama/models/ollama_chat_completion_request.go rename to providers/models/chat_completion_request.go index 4f6650e..aba0079 100644 --- a/providers/ollama/models/ollama_chat_completion_request.go +++ b/providers/models/chat_completion_request.go @@ -1,7 +1,7 @@ package models -// OllamaChatCompletionRequest Define the request body structure -type OllamaChatCompletionRequest struct { +// ChatCompletionRequest Define the request body structure +type ChatCompletionRequest struct { Model string `json:"model"` Messages []Message `json:"messages"` Temperature *float32 `json:"temperature,omitempty"` // Optional field (pointer to float32) diff --git a/providers/models/chat_completion_response.go b/providers/models/chat_completion_response.go new file mode 100644 index 0000000..33ba2cc --- /dev/null +++ b/providers/models/chat_completion_response.go @@ -0,0 +1,22 @@ +package models + +// ChatCompletionResponse represents the entire response structure from OpenAI's chat completion API. +type ChatCompletionResponse struct { + Choices []Choice `json:"choices"` +} + +// Choice represents an individual choice in the response. +type Choice struct { + Delta Delta `json:"delta"` +} + +// Delta represents the delta object in each choice containing the content. 
+type Delta struct { + Content string `json:"content"` +} + +type StreamResponse struct { + Content string // Holds content chunks + Err error // Holds error details + Done bool // Signals end of stream +} diff --git a/providers/openai/models/openai_embedding_request.go b/providers/models/embedding_request.go similarity index 69% rename from providers/openai/models/openai_embedding_request.go rename to providers/models/embedding_request.go index 3df1b1d..1df7fdb 100644 --- a/providers/openai/models/openai_embedding_request.go +++ b/providers/models/embedding_request.go @@ -1,7 +1,7 @@ package models -// OpenAIEmbeddingRequest represents the request structure for the OpenAI embedding API -type OpenAIEmbeddingRequest struct { +// EmbeddingRequest represents the request structure for the OpenAI embedding API +type EmbeddingRequest struct { Input string `json:"input"` // The input text to be embedded Model string `json:"model"` // The model used for generating embeddings EncodingFormat string `json:"encoding_format"` // The encoding format (in this case, "float") diff --git a/providers/models/embedding_response.go b/providers/models/embedding_response.go new file mode 100644 index 0000000..fc4c5d1 --- /dev/null +++ b/providers/models/embedding_response.go @@ -0,0 +1,22 @@ +package models + +// EmbeddingResponse represents the entire response from the embedding API +type EmbeddingResponse struct { + Object string `json:"object"` + Data []Data `json:"data"` + Model string `json:"model"` + Usage UsageInfo `json:"usage"` +} + +// Data represents each individual embedding +type Data struct { + Object string `json:"object"` + Embedding []float64 `json:"embedding"` // Embedding is an array of 1536 floats + Index int `json:"index"` +} + +// UsageInfo represents the token usage details in the API response +type UsageInfo struct { + PromptTokens int `json:"prompt_tokens"` + TotalTokens int `json:"total_tokens"` +} diff --git a/providers/models/general_chat_response.go b/providers/models/general_chat_response.go deleted file mode 100644 index 5ba3111..0000000 --- a/providers/models/general_chat_response.go +++ /dev/null @@ -1,17 +0,0 @@ -package models - -type StreamResponse struct { - Content string // Holds content chunks - Err error // Holds error details - Done bool // Signals end of stream -} - -type Error struct { - Message string `json:"message"` - Code string `json:"code"` -} - -// AIError represents an error response from OpenAI. -type AIError struct { - Error Error `json:"error"` -} diff --git a/providers/ollama/models/ollama_chat_completion_response.go b/providers/ollama/models/ollama_chat_completion_response.go deleted file mode 100644 index fe9d22f..0000000 --- a/providers/ollama/models/ollama_chat_completion_response.go +++ /dev/null @@ -1,19 +0,0 @@ -package models - -import "time" - -// OllamaChatCompletionResponse defines the structure of the combined Ollama API response. -type OllamaChatCompletionResponse struct { - Model string `json:"model"` // Name of the model used - CreatedAt time.Time `json:"created_at"` // Timestamp of when the response was created - Message OllamaMessage `json:"message"` // Message details - Done bool `json:"done"` // Indicates if the response is complete - PromptEvalCount int `json:"prompt_eval_count"` // Number of prompt evaluations - EvalCount int `json:"eval_count"` // Number of evaluations -} - -// OllamaMessage represents the content of the message from the assistant. 
-type OllamaMessage struct { - Role string `json:"role"` // Role of the message sender (e.g., "assistant") - Content string `json:"content"` // The content of the message -} diff --git a/providers/ollama/models/ollama_embedding_request.go b/providers/ollama/models/ollama_embedding_request.go deleted file mode 100644 index b1d1724..0000000 --- a/providers/ollama/models/ollama_embedding_request.go +++ /dev/null @@ -1,7 +0,0 @@ -package models - -// OllamaEmbeddingRequest represents the request structure for the OpenAI embedding API -type OllamaEmbeddingRequest struct { - Input string `json:"input"` // The input text to be embedded - Model string `json:"model"` // The model used for generating embeddings -} diff --git a/providers/ollama/models/ollama_embedding_response.go b/providers/ollama/models/ollama_embedding_response.go deleted file mode 100644 index 37b9746..0000000 --- a/providers/ollama/models/ollama_embedding_response.go +++ /dev/null @@ -1,8 +0,0 @@ -package models - -// OllamaEmbeddingResponse defines the structure of an embedding response from the Ollama API. -type OllamaEmbeddingResponse struct { - Model string `json:"model"` // The embedding model used (e.g., "all-minilm") - Embeddings [][]float64 `json:"embeddings"` // Embedding vectors, where each embedding is a slice of float64 values - PromptEvalCount int `json:"prompt_eval_count"` // Count of prompt evaluations -} diff --git a/providers/ollama/ollama_provider.go b/providers/ollama/ollama_provider.go index 670ab8d..920fd95 100644 --- a/providers/ollama/ollama_provider.go +++ b/providers/ollama/ollama_provider.go @@ -7,9 +7,9 @@ import ( "encoding/json" "errors" "fmt" + "github.com/meysamhadeli/codai/constants/lipgloss_color" "github.com/meysamhadeli/codai/providers/contracts" "github.com/meysamhadeli/codai/providers/models" - ollama_models "github.com/meysamhadeli/codai/providers/ollama/models" "io" "io/ioutil" "net/http" @@ -27,7 +27,6 @@ type OllamaConfig struct { MaxTokens int Threshold float64 TokenManagement contracts.ITokenManagement - Name string } // NewOllamaProvider initializes a new OllamaProvider. 
@@ -42,15 +41,27 @@ func NewOllamaProvider(config *OllamaConfig) contracts.IAIProvider { MaxTokens: config.MaxTokens, Threshold: config.Threshold, TokenManagement: config.TokenManagement, - Name: config.Name, } } -func (ollamaProvider *OllamaConfig) EmbeddingRequest(ctx context.Context, prompt string) ([][]float64, error) { +func (ollamaProvider *OllamaConfig) EmbeddingRequest(ctx context.Context, prompt string) (*models.EmbeddingResponse, error) { + + // Count tokens for the user input and prompt + totalChatTokens, err := ollamaProvider.TokenManagement.CountTokens(prompt, ollamaProvider.ChatCompletionModel) + if err != nil { + return nil, fmt.Errorf(lipgloss_color.Red.Render(fmt.Sprintf("%v", err))) + } + + // Check if enough tokens are available + if err := ollamaProvider.TokenManagement.UseTokens(totalChatTokens); err != nil { + return nil, fmt.Errorf(lipgloss_color.Red.Render(fmt.Sprintf("Error: %v", err))) + } + // Create the request payload - requestBody := ollama_models.OllamaEmbeddingRequest{ - Input: prompt, - Model: ollamaProvider.EmbeddingModel, + requestBody := models.EmbeddingRequest{ + Input: prompt, + Model: ollamaProvider.EmbeddingModel, + EncodingFormat: ollamaProvider.EncodingFormat, } // Convert the request payload to JSON @@ -88,41 +99,37 @@ func (ollamaProvider *OllamaConfig) EmbeddingRequest(ctx context.Context, prompt // Check for non-200 status codes if resp.StatusCode != http.StatusOK { - var apiError models.AIError - if err := json.Unmarshal(body, &apiError); err != nil { - return nil, fmt.Errorf("error parsing error response: %v", err) - } - - return nil, fmt.Errorf("embedding request failed with status code '%d' - %s\n", resp.StatusCode, apiError.Error.Message) + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) } // Unmarshal the response JSON into the struct - var embeddingResponse ollama_models.OllamaEmbeddingResponse + var embeddingResponse models.EmbeddingResponse err = json.Unmarshal(body, &embeddingResponse) if err != nil { return nil, fmt.Errorf("error decoding JSON response: %v", err) } - // Count total tokens usage - if embeddingResponse.PromptEvalCount > 0 { - ollamaProvider.TokenManagement.UsedEmbeddingTokens(embeddingResponse.PromptEvalCount, 0) - } - // Return the parsed response - return embeddingResponse.Embeddings, nil + return &embeddingResponse, nil } func (ollamaProvider *OllamaConfig) ChatCompletionRequest(ctx context.Context, userInput string, prompt string) <-chan models.StreamResponse { responseChan := make(chan models.StreamResponse) - var markdownBuffer strings.Builder // Buffer to accumulate content until newline go func() { defer close(responseChan) + // Count tokens for the user input and prompt + totalChatTokens, err := ollamaProvider.TokenManagement.CountTokens(fmt.Sprintf("%s%s", prompt, userInput), ollamaProvider.ChatCompletionModel) + if err != nil { + responseChan <- models.StreamResponse{Err: fmt.Errorf("error counting tokens: %v", err)} + return + } + // Prepare the request body - reqBody := ollama_models.OllamaChatCompletionRequest{ + reqBody := models.ChatCompletionRequest{ Model: ollamaProvider.ChatCompletionModel, - Messages: []ollama_models.Message{ + Messages: []models.Message{ {Role: "system", Content: prompt}, {Role: "user", Content: userInput}, }, @@ -132,7 +139,6 @@ func (ollamaProvider *OllamaConfig) ChatCompletionRequest(ctx context.Context, u jsonData, err := json.Marshal(reqBody) if err != nil { - markdownBuffer.Reset() responseChan <- models.StreamResponse{Err: 
fmt.Errorf("error marshalling request body: %v", err)} return } @@ -140,7 +146,6 @@ func (ollamaProvider *OllamaConfig) ChatCompletionRequest(ctx context.Context, u // Create a new HTTP request req, err := http.NewRequestWithContext(ctx, "POST", ollamaProvider.ChatCompletionURL, bytes.NewBuffer(jsonData)) if err != nil { - markdownBuffer.Reset() responseChan <- models.StreamResponse{Err: fmt.Errorf("error creating request: %v", err)} return } @@ -150,7 +155,6 @@ func (ollamaProvider *OllamaConfig) ChatCompletionRequest(ctx context.Context, u client := &http.Client{} resp, err := client.Do(req) if err != nil { - markdownBuffer.Reset() if errors.Is(ctx.Err(), context.Canceled) { responseChan <- models.StreamResponse{Err: fmt.Errorf("request canceled: %v", err)} return @@ -161,27 +165,17 @@ func (ollamaProvider *OllamaConfig) ChatCompletionRequest(ctx context.Context, u defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - markdownBuffer.Reset() - - body, _ := ioutil.ReadAll(resp.Body) - var apiError models.AIError - if err := json.Unmarshal(body, &apiError); err != nil { - responseChan <- models.StreamResponse{Err: fmt.Errorf("error parsing error response: %v", err)} - return - } - - responseChan <- models.StreamResponse{Err: fmt.Errorf("API request failed with status code '%d' - %s\n", resp.StatusCode, apiError.Error.Message)} + responseChan <- models.StreamResponse{Err: fmt.Errorf("API request failed with status: %d", resp.StatusCode)} return } reader := bufio.NewReader(resp.Body) + var markdownBuffer strings.Builder // Buffer to accumulate content until newline // Stream processing for { line, err := reader.ReadString('\n') if err != nil { - markdownBuffer.Reset() - if err == io.EOF { break } @@ -189,37 +183,38 @@ func (ollamaProvider *OllamaConfig) ChatCompletionRequest(ctx context.Context, u return } - var response ollama_models.OllamaChatCompletionResponse - if err := json.Unmarshal([]byte(line), &response); err != nil { - markdownBuffer.Reset() + if line == "data: [DONE]\n" { + // Signal end of stream + responseChan <- models.StreamResponse{Content: markdownBuffer.String()} + responseChan <- models.StreamResponse{Done: true} - responseChan <- models.StreamResponse{Err: fmt.Errorf("error unmarshalling chunk: %v", err)} - return - } + // Use tokens + if err := ollamaProvider.TokenManagement.UseTokens(totalChatTokens); err != nil { + responseChan <- models.StreamResponse{Err: fmt.Errorf("error using tokens: %v", err)} + return + } - if len(response.Message.Content) > 0 { - content := response.Message.Content - markdownBuffer.WriteString(content) + break + } - // Send chunk if it contains a newline, and then reset the buffer - if strings.Contains(content, "\n") { - responseChan <- models.StreamResponse{Content: markdownBuffer.String()} - markdownBuffer.Reset() + if strings.HasPrefix(line, "data: ") { + jsonPart := strings.TrimPrefix(line, "data: ") + var response models.ChatCompletionResponse + if err := json.Unmarshal([]byte(jsonPart), &response); err != nil { + responseChan <- models.StreamResponse{Err: fmt.Errorf("error unmarshalling chunk: %v", err)} + return } - } - // Check if the response is marked as done - if response.Done { - // // Signal end of stream - responseChan <- models.StreamResponse{Content: markdownBuffer.String()} - responseChan <- models.StreamResponse{Done: true} + if len(response.Choices) > 0 { + content := response.Choices[0].Delta.Content + markdownBuffer.WriteString(content) - // Count total tokens usage - if response.PromptEvalCount > 0 { - 
ollamaProvider.TokenManagement.UsedTokens(response.PromptEvalCount, response.EvalCount) + // Send chunk if it contains a newline, and then reset the buffer + if strings.Contains(content, "\n") { + responseChan <- models.StreamResponse{Content: markdownBuffer.String()} + markdownBuffer.Reset() + } } - - break } } diff --git a/providers/openai/models/openai_chat_completion_request.go b/providers/openai/models/openai_chat_completion_request.go deleted file mode 100644 index 0111188..0000000 --- a/providers/openai/models/openai_chat_completion_request.go +++ /dev/null @@ -1,21 +0,0 @@ -package models - -// OpenAIChatCompletionRequest Define the request body structure -type OpenAIChatCompletionRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - Temperature *float32 `json:"temperature,omitempty"` // Optional field (pointer to float32) - Stream bool `json:"stream"` - StreamOptions StreamOptions `json:"stream_options"` -} - -// Message Define the request body structure -type Message struct { - Role string `json:"role"` - Content string `json:"content"` -} - -// StreamOptions includes configurations for streaming behavior -type StreamOptions struct { - IncludeUsage bool `json:"include_usage"` // Requests token usage data in the response -} diff --git a/providers/openai/models/openai_chat_completion_response.go b/providers/openai/models/openai_chat_completion_response.go deleted file mode 100644 index c6c8d0b..0000000 --- a/providers/openai/models/openai_chat_completion_response.go +++ /dev/null @@ -1,24 +0,0 @@ -package models - -// OpenAIChatCompletionResponse represents the entire response structure from OpenAI's chat completion API. -type OpenAIChatCompletionResponse struct { - Choices []Choice `json:"choices"` // Array of choice completions - Usage Usage `json:"usage"` // Token usage details -} - -// Choice represents an individual choice in the response. -type Choice struct { - Delta Delta `json:"delta"` -} - -// Delta represents the delta object in each choice containing the content. -type Delta struct { - Content string `json:"content"` -} - -// Usage defines the token usage information for the response. -type Usage struct { - PromptTokens int `json:"prompt_tokens"` // Number of tokens in the prompt - CompletionTokens int `json:"completion_tokens"` // Number of tokens in the completion - TotalTokens int `json:"total_tokens"` // Total tokens used -} diff --git a/providers/openai/models/openai_embedding_response.go b/providers/openai/models/openai_embedding_response.go deleted file mode 100644 index 6579c47..0000000 --- a/providers/openai/models/openai_embedding_response.go +++ /dev/null @@ -1,22 +0,0 @@ -package models - -// OpenAIEmbeddingResponse represents the entire response from the embedding API -type OpenAIEmbeddingResponse struct { - Object string `json:"object"` // Object type, typically "list" - Data []Data `json:"data"` // Array of embedding data - Model string `json:"model"` // Model used for the embedding - UsageEmbedding UsageEmbedding `json:"usage"` // Token usage details -} - -// Data represents each individual embedding -type Data struct { - Object string `json:"object"` - Embedding []float64 `json:"embedding"` // Embedding is an array of 1536 floats - Index int `json:"index"` -} - -// UsageEmbedding defines the token usage information for the embedding request. 
-type UsageEmbedding struct { - PromptTokens int `json:"prompt_tokens"` // Number of tokens in the prompt - TotalTokens int `json:"total_tokens"` // Total number of tokens used -} diff --git a/providers/openai/openai_provider.go b/providers/openai/openai_provider.go index 8e3ace4..01de0f4 100644 --- a/providers/openai/openai_provider.go +++ b/providers/openai/openai_provider.go @@ -7,9 +7,9 @@ import ( "encoding/json" "errors" "fmt" + "github.com/meysamhadeli/codai/constants/lipgloss_color" "github.com/meysamhadeli/codai/providers/contracts" "github.com/meysamhadeli/codai/providers/models" - openai_models "github.com/meysamhadeli/codai/providers/openai/models" "io" "io/ioutil" "net/http" @@ -18,7 +18,6 @@ import ( // OpenAIConfig implements the Provider interface for OpenAPI. type OpenAIConfig struct { - Name string EmbeddingURL string ChatCompletionURL string EmbeddingModel string @@ -44,14 +43,24 @@ func NewOpenAIProvider(config *OpenAIConfig) contracts.IAIProvider { Threshold: config.Threshold, ApiKey: config.ApiKey, TokenManagement: config.TokenManagement, - Name: config.Name, } } -func (openAIProvider *OpenAIConfig) EmbeddingRequest(ctx context.Context, prompt string) ([][]float64, error) { +func (openAIProvider *OpenAIConfig) EmbeddingRequest(ctx context.Context, prompt string) (*models.EmbeddingResponse, error) { + + // Count tokens for the user input and prompt + totalChatTokens, err := openAIProvider.TokenManagement.CountTokens(prompt, openAIProvider.ChatCompletionModel) + if err != nil { + return nil, fmt.Errorf(lipgloss_color.Red.Render(fmt.Sprintf("%v", err))) + } + + // Check if enough tokens are available + if err := openAIProvider.TokenManagement.UseEmbeddingTokens(totalChatTokens); err != nil { + return nil, fmt.Errorf(lipgloss_color.Red.Render(fmt.Sprintf("Error: %v", err))) + } // Create the request payload - requestBody := openai_models.OpenAIEmbeddingRequest{ + requestBody := models.EmbeddingRequest{ Input: prompt, Model: openAIProvider.EmbeddingModel, EncodingFormat: openAIProvider.EncodingFormat, @@ -91,56 +100,48 @@ func (openAIProvider *OpenAIConfig) EmbeddingRequest(ctx context.Context, prompt return nil, fmt.Errorf("error reading response: %v", err) } - // Check for error status code + // Check for non-200 status codes if resp.StatusCode != http.StatusOK { - var apiError models.AIError - if err := json.Unmarshal(body, &apiError); err != nil { - return nil, fmt.Errorf("error parsing error response: %v", err) - } - - return nil, fmt.Errorf("embedding request failed with status code '%d' - %s\n", resp.StatusCode, apiError.Error.Message) + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) } // Unmarshal the response JSON into the struct - var embeddingResponse openai_models.OpenAIEmbeddingResponse + var embeddingResponse models.EmbeddingResponse err = json.Unmarshal(body, &embeddingResponse) if err != nil { return nil, fmt.Errorf("error decoding JSON response: %v", err) } - // Count total tokens usage - if embeddingResponse.UsageEmbedding.TotalTokens > 0 { - openAIProvider.TokenManagement.UsedEmbeddingTokens(embeddingResponse.UsageEmbedding.TotalTokens, 0) - } - - return [][]float64{embeddingResponse.Data[0].Embedding}, nil + // Return the parsed response + return &embeddingResponse, nil } func (openAIProvider *OpenAIConfig) ChatCompletionRequest(ctx context.Context, userInput string, prompt string) <-chan models.StreamResponse { responseChan := make(chan models.StreamResponse) - var markdownBuffer strings.Builder // Buffer 
to accumulate content until newline - var usage openai_models.Usage // Variable to hold usage data go func() { defer close(responseChan) + // Count tokens for the user input and prompt + totalChatTokens, err := openAIProvider.TokenManagement.CountTokens(fmt.Sprintf("%s%s", prompt, userInput), openAIProvider.ChatCompletionModel) + if err != nil { + responseChan <- models.StreamResponse{Err: fmt.Errorf("error counting tokens: %v", err)} + return + } + // Prepare the request body - reqBody := openai_models.OpenAIChatCompletionRequest{ + reqBody := models.ChatCompletionRequest{ Model: openAIProvider.ChatCompletionModel, - Messages: []openai_models.Message{ + Messages: []models.Message{ {Role: "system", Content: prompt}, {Role: "user", Content: userInput}, }, Stream: true, Temperature: &openAIProvider.Temperature, - StreamOptions: openai_models.StreamOptions{ - IncludeUsage: true, - }, } jsonData, err := json.Marshal(reqBody) if err != nil { - markdownBuffer.Reset() responseChan <- models.StreamResponse{Err: fmt.Errorf("error marshalling request body: %v", err)} return } @@ -148,7 +149,6 @@ func (openAIProvider *OpenAIConfig) ChatCompletionRequest(ctx context.Context, u // Create a new HTTP request req, err := http.NewRequestWithContext(ctx, "POST", openAIProvider.ChatCompletionURL, bytes.NewBuffer(jsonData)) if err != nil { - markdownBuffer.Reset() responseChan <- models.StreamResponse{Err: fmt.Errorf("error creating request: %v", err)} return } @@ -159,7 +159,6 @@ func (openAIProvider *OpenAIConfig) ChatCompletionRequest(ctx context.Context, u client := &http.Client{} resp, err := client.Do(req) if err != nil { - markdownBuffer.Reset() if errors.Is(ctx.Err(), context.Canceled) { responseChan <- models.StreamResponse{Err: fmt.Errorf("request canceled: %v", err)} return @@ -170,25 +169,17 @@ func (openAIProvider *OpenAIConfig) ChatCompletionRequest(ctx context.Context, u defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - markdownBuffer.Reset() - body, _ := ioutil.ReadAll(resp.Body) - var apiError models.AIError - if err := json.Unmarshal(body, &apiError); err != nil { - responseChan <- models.StreamResponse{Err: fmt.Errorf("error parsing error response: %v", err)} - return - } - - responseChan <- models.StreamResponse{Err: fmt.Errorf("API request failed with status code '%d' - %s\n", resp.StatusCode, apiError.Error.Message)} + responseChan <- models.StreamResponse{Err: fmt.Errorf("API request failed with status: %d", resp.StatusCode)} return } reader := bufio.NewReader(resp.Body) + var markdownBuffer strings.Builder // Buffer to accumulate content until newline // Stream processing for { line, err := reader.ReadString('\n') if err != nil { - markdownBuffer.Reset() if err == io.EOF { break } @@ -197,35 +188,27 @@ func (openAIProvider *OpenAIConfig) ChatCompletionRequest(ctx context.Context, u } if line == "data: [DONE]\n" { - // Send the final content + // Signal end of stream responseChan <- models.StreamResponse{Content: markdownBuffer.String()} - responseChan <- models.StreamResponse{Done: true} - // Count total tokens usage - if usage.TotalTokens > 0 { - openAIProvider.TokenManagement.UsedTokens(usage.PromptTokens, usage.CompletionTokens) + // Use tokens + if err := openAIProvider.TokenManagement.UseTokens(totalChatTokens); err != nil { + responseChan <- models.StreamResponse{Err: fmt.Errorf("error using tokens: %v", err)} + return } - + break } if strings.HasPrefix(line, "data: ") { jsonPart := strings.TrimPrefix(line, "data: ") - var response 
openai_models.OpenAIChatCompletionResponse + var response models.ChatCompletionResponse if err := json.Unmarshal([]byte(jsonPart), &response); err != nil { - markdownBuffer.Reset() - responseChan <- models.StreamResponse{Err: fmt.Errorf("error unmarshalling chunk: %v", err)} return } - // Check if the response has usage information - if response.Usage.TotalTokens > 0 { - usage = response.Usage // Capture the usage data for later use - } - - // Accumulate and send response content if len(response.Choices) > 0 { content := response.Choices[0].Delta.Content markdownBuffer.WriteString(content) diff --git a/providers/token_management.go b/providers/token_management.go index 76b2efc..f74b36e 100644 --- a/providers/token_management.go +++ b/providers/token_management.go @@ -1,136 +1,95 @@ package providers import ( - "encoding/json" "fmt" "github.com/charmbracelet/lipgloss" - "github.com/meysamhadeli/codai/embed_data" "github.com/meysamhadeli/codai/providers/contracts" + "github.com/pkoukk/tiktoken-go" "log" "strings" ) // Define styles for the box var ( - boxStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#FFFFFF")).Bold(true).Border(lipgloss.NormalBorder()).PaddingLeft(1).PaddingRight(1).Align(lipgloss.Left) + boxStyle = lipgloss.NewStyle().Border(lipgloss.NormalBorder()).PaddingLeft(1).PaddingRight(1).BorderLeft(false).BorderRight(false).Align(lipgloss.Center) ) // TokenManager implementation type tokenManager struct { - usedToken int - usedInputToken int - usedOutputToken int - - usedEmbeddingToken int - usedEmbeddingInputToken int - usedEmbeddingOutputToken int -} - -type details struct { - MaxTokens int `json:"max_tokens"` - MaxInputTokens int `json:"max_input_tokens"` - MaxOutputTokens int `json:"max_output_tokens"` - InputCostPerToken float64 `json:"input_cost_per_token,omitempty"` - OutputCostPerToken float64 `json:"output_cost_per_token,omitempty"` - CacheReadInputTokenCost float64 `json:"cache_read_input_token_cost,omitempty"` - Mode string `json:"mode"` - SupportsFunctionCalling bool `json:"supports_function_calling,omitempty"` -} - -type Models struct { - ModelDetails map[string]details `json:"models"` + maxTokens int + usedTokens int + usedEmbeddingTokens int } // NewTokenManager creates a new token manager -func NewTokenManager() contracts.ITokenManagement { +func NewTokenManager(maxTokens int) contracts.ITokenManagement { return &tokenManager{ - usedToken: 0, - usedInputToken: 0, - usedOutputToken: 0, - usedEmbeddingToken: 0, - usedEmbeddingInputToken: 0, - usedEmbeddingOutputToken: 0, + maxTokens: maxTokens, + usedTokens: 0, + usedEmbeddingTokens: 0, } } -// UsedTokens deducts the token count from the available tokens. -func (tm *tokenManager) UsedTokens(inputToken int, outputToken int) { - tm.usedInputToken = inputToken - tm.usedOutputToken = outputToken - - tm.usedToken += inputToken + outputToken -} - -// UsedEmbeddingTokens deducts the token count from the available tokens. 
-func (tm *tokenManager) UsedEmbeddingTokens(inputToken int, outputToken int) { - tm.usedEmbeddingInputToken = inputToken - tm.usedEmbeddingOutputToken = outputToken - - tm.usedEmbeddingToken += inputToken + outputToken -} - -func (tm *tokenManager) DisplayTokens(providerName string, model string, embeddingModel string, isRag bool) { - - cost := tm.CalculateCost(providerName, model, tm.usedInputToken, tm.usedOutputToken) - costEmbedding := tm.CalculateCost(providerName, embeddingModel, tm.usedEmbeddingInputToken, tm.usedEmbeddingOutputToken) - - var tokenDetails string - var embeddingTokenDetails string - - tokenDetails = fmt.Sprintf("Chat Model: '%s' - Token Used: '%s' - Cost: '%s'", model, fmt.Sprint(tm.usedToken), fmt.Sprintf("%.6f", cost)) - - if isRag { - embeddingTokenDetails = fmt.Sprintf("Embedding Model: '%s' - Token Used: '%s' - Cost: '%s'", embeddingModel, fmt.Sprint(tm.usedEmbeddingToken), fmt.Sprintf("%.6f", costEmbedding)) +// CountTokens counts the number of tokens in the input text. +func (tm *tokenManager) CountTokens(text string, model string) (int, error) { + + model = strings.ToLower(model) + + var modelName string + switch { + case strings.HasPrefix(model, "gpt-4o"): + modelName = "gpt-4o" + case strings.HasPrefix(model, "gpt-4"): + modelName = "gpt-4" + case strings.HasPrefix(model, "gpt-3"): + modelName = "gpt-3.5-turbo" + case model == "text-embedding-3-small": + modelName = "text-embedding-3-small" + case model == "text-embedding-3-large": + modelName = "text-embedding-3-large" + case model == "text-embedding-ada-002": + modelName = "text-embedding-ada-002" + default: + modelName = "gpt-4" } - tokenInfo := tokenDetails + "\n" + embeddingTokenDetails - tokenBox := boxStyle.Render(tokenInfo) - fmt.Println(tokenBox) -} - -func getModelDetails(providerName string, modelName string) (details, error) { - - providerName = strings.ToLower(providerName) - modelName = strings.ToLower(modelName) - - if strings.HasPrefix(providerName, "azure") { - modelName = "azure/" + modelName - } - - // Initialize the Models struct to hold parsed JSON data - models := Models{ - ModelDetails: make(map[string]details), - } - - // Unmarshal the JSON data from the embedded file - err := json.Unmarshal(embed_data.ModelDetails, &models) + tkm, err := tiktoken.EncodingForModel(modelName) if err != nil { - log.Fatalf("Error unmarshaling JSON: %v", err) - return details{}, err + err = fmt.Errorf("encoding for model: %v", err) + log.Println(err) + return 0, err } - // Look up the model by name - model, exists := models.ModelDetails[modelName] - if !exists { - return details{}, fmt.Errorf("model details price with name '%s' not found for provider '%s'", modelName, providerName) - } + // encode + token := tkm.Encode(text, nil, nil) - return model, nil + return len(token), nil } -func (tm *tokenManager) CalculateCost(providerName string, modelName string, inputToken int, outputToken int) float64 { - modelDetails, err := getModelDetails(providerName, modelName) - if err != nil { - return 0 +// AvailableTokens returns the number of available tokens. +func (tm *tokenManager) AvailableTokens() int { + return tm.maxTokens - tm.usedTokens +} + +// UseTokens deducts the token count from the available tokens. 
+func (tm *tokenManager) UseTokens(count int) error { + if count > tm.AvailableTokens() { + return fmt.Errorf("not enough tokens available: requested %d, available %d", count, tm.AvailableTokens()) } - // Calculate cost for input tokens - inputCost := float64(inputToken) * modelDetails.InputCostPerToken + tm.usedTokens += count + return nil +} - // Calculate cost for output tokens - outputCost := float64(outputToken) * modelDetails.OutputCostPerToken +// UseEmbeddingTokens deducts the token count from the available tokens. +func (tm *tokenManager) UseEmbeddingTokens(count int) error { + tm.usedEmbeddingTokens += count + return nil +} - // Total cost - totalCost := inputCost + outputCost +func (tm *tokenManager) DisplayTokens(model string, embeddingModel string) { + used, available, total := tm.usedTokens, tm.AvailableTokens(), tm.maxTokens + tokenInfo := fmt.Sprintf("Used Tokens: %d | Available Tokens: %d | Total Tokens (%s): %d | Used Embedding Tokens (%s): %d", used, available, model, total, embeddingModel, tm.usedEmbeddingTokens) - return totalCost + tokenBox := boxStyle.Render(tokenInfo) + fmt.Println(tokenBox) } diff --git a/utils/confirm_prompt.go b/utils/confirm_prompt.go index 3ac8225..4615d55 100644 --- a/utils/confirm_prompt.go +++ b/utils/confirm_prompt.go @@ -13,7 +13,7 @@ func ConfirmPrompt(path string) (bool, error) { reader := bufio.NewReader(os.Stdin) // Styled prompt message - fmt.Printf(lipgloss_color.BlueSky.Render(fmt.Sprintf("Do you want to accept the change for file '%v'%v", lipgloss_color.LightBlueB.Render(path), lipgloss_color.BlueSky.Render(" ? (y/n): ")))) + fmt.Printf(lipgloss_color.BlueSky.Render(fmt.Sprintf("Do you want to accept the change for file `%v`%v", lipgloss_color.LightBlueB.Render(path), lipgloss_color.BlueSky.Render(" ? (y/n): ")))) for { // Read user input
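
> Note: the token_management.go changes above replace the embedded price-table bookkeeping with a local token budget counted via [tiktoken-go](https://github.com/pkoukk/tiktoken-go). The sketch below is a minimal, self-contained illustration of that pattern; the `budget` type, its field names, and the 128000 limit are assumptions made for this example, not code from the patch — only the `tiktoken.EncodingForModel` / `Encode` calls mirror what the diff actually uses.

```go
package main

import (
	"fmt"
	"log"

	"github.com/pkoukk/tiktoken-go"
)

// budget is a simplified, hypothetical stand-in for the patch's tokenManager:
// it counts prompt tokens up front and refuses work once maxTokens is spent.
type budget struct {
	maxTokens  int
	usedTokens int
}

// countTokens encodes text with the tokenizer registered for the given model
// and returns the resulting token count.
func (b *budget) countTokens(text, model string) (int, error) {
	enc, err := tiktoken.EncodingForModel(model)
	if err != nil {
		return 0, fmt.Errorf("encoding for model: %v", err)
	}
	return len(enc.Encode(text, nil, nil)), nil
}

// useTokens deducts count from the remaining budget, failing if it would
// exceed the configured maximum.
func (b *budget) useTokens(count int) error {
	available := b.maxTokens - b.usedTokens
	if count > available {
		return fmt.Errorf("not enough tokens available: requested %d, available %d", count, available)
	}
	b.usedTokens += count
	return nil
}

func main() {
	b := &budget{maxTokens: 128000} // assumed limit, echoing the README's max_tokens example
	n, err := b.countTokens("You are a helpful coding assistant.", "gpt-4")
	if err != nil {
		log.Fatal(err)
	}
	if err := b.useTokens(n); err != nil {
		log.Fatal(err)
	}
	fmt.Printf("counted %d tokens, %d remaining\n", n, b.maxTokens-b.usedTokens)
}
```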
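
> Note: both provider rewrites above parse the OpenAI-compatible streaming format, where each chunk arrives as a `data: {...}` line and the stream is terminated by `data: [DONE]`. Below is a minimal sketch of that loop against a plain `io.Reader`; the trimmed `chunk` struct and the in-memory demo stream are assumptions for illustration, not the full `ChatCompletionResponse` model used in the patch.

```go
package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"io"
	"strings"
)

// chunk is an assumed, trimmed-down shape of one streamed completion chunk,
// keeping only the delta content needed for this sketch.
type chunk struct {
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
	} `json:"choices"`
}

// readStream consumes "data: ..." lines until "data: [DONE]" and returns the
// accumulated text.
func readStream(r io.Reader) (string, error) {
	reader := bufio.NewReader(r)
	var out strings.Builder
	for {
		line, err := reader.ReadString('\n')
		if err == io.EOF {
			break
		}
		if err != nil {
			return "", err
		}
		if line == "data: [DONE]\n" {
			break // end-of-stream marker
		}
		if !strings.HasPrefix(line, "data: ") {
			continue // skip keep-alives and blank lines
		}
		var c chunk
		if err := json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &c); err != nil {
			return "", fmt.Errorf("error unmarshalling chunk: %v", err)
		}
		if len(c.Choices) > 0 {
			out.WriteString(c.Choices[0].Delta.Content)
		}
	}
	return out.String(), nil
}

func main() {
	// A tiny in-memory stream standing in for the HTTP response body.
	body := "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n" +
		"data: {\"choices\":[{\"delta\":{\"content\":\" world\"}}]}\n" +
		"data: [DONE]\n"
	text, err := readStream(strings.NewReader(body))
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(text) // Hello world
}
```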