diff --git a/README.md b/README.md
index a9bf0f9..f0ddb19 100644
--- a/README.md
+++ b/README.md
@@ -61,21 +61,21 @@ The `codai-config` file should be like following example base on your `AI provid
 **codai-config.yml**
 ```yml
 ai_provider_config:
-  chat_provider_name: "openai" # openai | ollama | azure-openai
-  chat_base_url: "https://api.openai.com" # "http://localhost:11434" | "https://test,openai.azure.com"
+  chat_provider_name: "openai" # openai | ollama | azure-openai | anthropic | openrouter
+  chat_base_url: "https://api.openai.com" # "http://localhost:11434" | "https://test.openai.azure.com" | "https://api.anthropic.com" | "https://openrouter.ai"
   chat_model: "gpt-4o"
-  chat_api_version: "2024-04-01-preview" #(Optional, If your AI provider like AzureOpenai has chat api version.)
+  chat_api_version: "2024-04-01-preview" #(Optional, if your AI provider, such as 'AzureOpenai' or 'Anthropic', requires a chat API version.)
   embeddings_provider_name: "openai" # openai | ollama | azure-openai
   embeddings_base_url: "https://api.openai.com" # "http://localhost:11434" | "https://test,openai.azure.com"
-  embeddings_model: "text-embedding-3-small" #(Optional, If you want use RAG.)
-  embeddings_api_version: "2024-01-01-preview" #(Optional, If your AI provider like AzureOpenai has embeddings api version.)
+  embeddings_model: "text-embedding-3-small" #(Optional, if you want to use 'RAG'.)
+  embeddings_api_version: "2024-01-01-preview" #(Optional, if your AI provider, such as 'AzureOpenai', requires an embeddings API version.)
   temperature: 0.2
-  threshold: 0.2 #(Optional, If you want use RAG.)
+  threshold: 0.2 #(Optional, if you want to use 'RAG'.)
   theme: "dracula"
-rag: true #(Optional, If you want use RAG.)
+rag: true #(Optional, if you want to use 'RAG'.)
 ```
-> Note: We used the standard integration of [OpenAI APIs](https://platform.openai.com/docs/api-reference/introduction), [Ollama APIs](https://github.com/ollama/ollama/blob/main/docs/api.md) and [Azure Openai](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) and you can find more details in documentation of each APIs.
+> Note: We use the standard integrations of [OpenAI APIs](https://platform.openai.com/docs/api-reference/introduction), [Ollama APIs](https://github.com/ollama/ollama/blob/main/docs/api.md), [Azure Openai](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference), [Anthropic](https://docs.anthropic.com/en/api/getting-started), and [OpenRouter](https://openrouter.ai/docs/quick-start); you can find more details in the documentation of each AI provider's API.
 
 If you wish to customize your configuration, you can create your own `codai-config.yml` file and place it in the `root directory` of `each project` you want to analyze with codai. If `no configuration` file is provided, codai will use the `default settings`.
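For example, a `codai-config.yml` wired to the new `anthropic` provider could look like the sketch below. The model name and `chat_api_version` values are illustrative, not defaults shipped with codai; note that Anthropic exposes no embeddings endpoint, so an OpenAI-compatible provider is still configured for RAG:

```yml
ai_provider_config:
  chat_provider_name: "anthropic"
  chat_base_url: "https://api.anthropic.com"
  chat_model: "claude-3-5-sonnet-latest" # example Anthropic model ID
  chat_api_version: "2023-06-01" # sent as the 'anthropic-version' header, which the Messages API requires
  embeddings_provider_name: "openai" # Anthropic has no embeddings API, so RAG still needs openai | ollama | azure-openai
  embeddings_base_url: "https://api.openai.com"
  embeddings_model: "text-embedding-3-small"
  temperature: 0.2
  threshold: 0.2
  theme: "dracula"
rag: true
```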
diff --git a/providers/ai_provider.go b/providers/ai_provider.go
index ad27187..296b6fa 100644
--- a/providers/ai_provider.go
+++ b/providers/ai_provider.go
@@ -2,10 +2,12 @@ package providers
 
 import (
     "errors"
+    "github.com/meysamhadeli/codai/providers/anthropic"
     azure_openai "github.com/meysamhadeli/codai/providers/azure-openai"
     "github.com/meysamhadeli/codai/providers/contracts"
     "github.com/meysamhadeli/codai/providers/ollama"
     "github.com/meysamhadeli/codai/providers/openai"
+    "github.com/meysamhadeli/codai/providers/openrouter"
     contracts2 "github.com/meysamhadeli/codai/token_management/contracts"
 )
@@ -56,7 +58,7 @@ func ChatProviderFactory(config *AIProviderConfig, tokenManagement contracts2.IT
         ChatApiVersion:       config.ChatApiVersion,
         EmbeddingsApiVersion: config.EmbeddingsApiVersion,
     }), nil
-    case "azure-openai", "azure_openai":
+    case "azure-openai":
         return azure_openai.NewAzureOpenAIChatProvider(&azure_openai.AzureOpenAIConfig{
             Temperature:    config.Temperature,
             EncodingFormat: config.EncodingFormat,
@@ -71,6 +73,32 @@ func ChatProviderFactory(config *AIProviderConfig, tokenManagement contracts2.IT
         ChatApiVersion:       config.ChatApiVersion,
         EmbeddingsApiVersion: config.EmbeddingsApiVersion,
     }), nil
+
+    case "openrouter":
+        return openrouter.NewOpenRouterChatProvider(&openrouter.OpenRouterConfig{
+            ChatBaseURL:     config.ChatBaseURL, // plumb the configured chat_base_url through (field name assumed from the config schema)
+            Temperature:     config.Temperature,
+            EncodingFormat:  config.EncodingFormat,
+            ChatModel:       config.ChatModel,
+            ChatApiKey:      config.ChatApiKey,
+            MaxTokens:       config.MaxTokens,
+            Threshold:       config.Threshold,
+            TokenManagement: tokenManagement,
+            ChatApiVersion:  config.ChatApiVersion,
+        }), nil
+
+    case "anthropic":
+        return anthropic.NewAnthropicMessageProvider(&anthropic.AnthropicConfig{
+            MessageBaseURL:    config.ChatBaseURL, // the Anthropic provider reuses the configured chat_base_url
+            Temperature:       config.Temperature,
+            EncodingFormat:    config.EncodingFormat,
+            MessageModel:      config.ChatModel,
+            MessageApiKey:     config.ChatApiKey,
+            MaxTokens:         config.MaxTokens,
+            Threshold:         config.Threshold,
+            TokenManagement:   tokenManagement,
+            MessageApiVersion: config.ChatApiVersion,
+        }), nil
 
     default:
         return nil, errors.New("unsupported provider")
@@ -106,7 +134,7 @@ func EmbeddingsProviderFactory(config *AIProviderConfig, tokenManagement contrac
         ChatApiVersion:       config.ChatApiVersion,
         EmbeddingsApiVersion: config.EmbeddingsApiVersion,
     }), nil
-    case "azure-openai", "azure_openai":
+    case "azure-openai":
         return azure_openai.NewAzureOpenAIEmbeddingsProvider(&azure_openai.AzureOpenAIConfig{
             Temperature:    config.Temperature,
             EncodingFormat: config.EncodingFormat,
diff --git a/providers/anthropic/anthropic_provider.go b/providers/anthropic/anthropic_provider.go
new file mode 100644
index 0000000..cd5b66f
--- /dev/null
+++ b/providers/anthropic/anthropic_provider.go
@@ -0,0 +1,167 @@
+package anthropic
+
+import (
+    "bufio"
+    "bytes"
+    "context"
+    "encoding/json"
+    "errors"
+    "fmt"
+    "github.com/meysamhadeli/codai/providers/anthropic/models"
+    "github.com/meysamhadeli/codai/providers/contracts"
+    general_models "github.com/meysamhadeli/codai/providers/models"
+    contracts2 "github.com/meysamhadeli/codai/token_management/contracts"
+    "io"
+    "net/http"
+    "strings"
+)
+
+// AnthropicConfig implements the IChatAIProvider interface for Anthropic.
+type AnthropicConfig struct {
+    MessageBaseURL    string
+    MessageModel      string
+    Temperature       float32
+    EncodingFormat    string
+    MessageApiKey     string
+    MaxTokens         int
+    Threshold         float64
+    TokenManagement   contracts2.ITokenManagement
+    MessageApiVersion string
+}
+
+// NewAnthropicMessageProvider initializes a new Anthropic message provider.
+func NewAnthropicMessageProvider(config *AnthropicConfig) contracts.IChatAIProvider {
+    return &AnthropicConfig{
+        MessageBaseURL:    config.MessageBaseURL,
+        MessageModel:      config.MessageModel,
+        Temperature:       config.Temperature,
+        EncodingFormat:    config.EncodingFormat,
+        MaxTokens:         config.MaxTokens,
+        Threshold:         config.Threshold,
+        MessageApiKey:     config.MessageApiKey,
+        MessageApiVersion: config.MessageApiVersion,
+        TokenManagement:   config.TokenManagement,
+    }
+}
+
+func (anthropicProvider *AnthropicConfig) ChatCompletionRequest(ctx context.Context, userInput string, prompt string) <-chan general_models.StreamResponse {
+    responseChan := make(chan general_models.StreamResponse)
+    var markdownBuffer strings.Builder // Accumulate content for streaming responses
+    var usage models.Usage             // To track token usage
+
+    go func() {
+        defer close(responseChan)
+
+        // Prepare the request body
+        reqBody := models.AnthropicMessageRequest{
+            Model:  anthropicProvider.MessageModel,
+            System: prompt, // Anthropic takes the system prompt as a top-level field, not as a message role
+            Messages: []models.Message{
+                {Role: "user", Content: userInput},
+            },
+            Temperature: &anthropicProvider.Temperature,
+            MaxTokens:   anthropicProvider.MaxTokens, // required by the Messages API
+            Stream:      true,
+        }
+
+        jsonData, err := json.Marshal(reqBody)
+        if err != nil {
+            responseChan <- general_models.StreamResponse{Err: fmt.Errorf("error marshalling request body: %v", err)}
+            return
+        }
+
+        // Create the HTTP request
+        req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("%s/v1/messages", anthropicProvider.MessageBaseURL), bytes.NewBuffer(jsonData))
+        if err != nil {
+            responseChan <- general_models.StreamResponse{Err: fmt.Errorf("error creating HTTP request: %v", err)}
+            return
+        }
+
+        // Set required headers
+        req.Header.Set("content-type", "application/json")                       // Required content type
+        req.Header.Set("anthropic-version", anthropicProvider.MessageApiVersion) // Required API version
+        req.Header.Set("x-api-key", anthropicProvider.MessageApiKey)             // API key for authentication
+
+        // Send the HTTP request
+        client := &http.Client{}
+        resp, err := client.Do(req)
+        if err != nil {
+            if errors.Is(ctx.Err(), context.Canceled) {
+                responseChan <- general_models.StreamResponse{Err: fmt.Errorf("request canceled: %v", err)}
+            } else {
+                responseChan <- general_models.StreamResponse{Err: fmt.Errorf("error sending request: %v", err)}
+            }
+            return
+        }
+        defer resp.Body.Close()
+
+        // Handle non-200 status codes
+        if resp.StatusCode != http.StatusOK {
+            body, _ := io.ReadAll(resp.Body)
+            var apiError models.AnthropicError
+            if err := json.Unmarshal(body, &apiError); err != nil {
+                responseChan <- general_models.StreamResponse{Err: fmt.Errorf("error parsing error response: %v", err)}
+            } else {
+                responseChan <- general_models.StreamResponse{Err: fmt.Errorf("API request failed: %s", apiError.Error.Message)}
+            }
+            return
+        }
+
+        // Process the streaming response
+        reader := bufio.NewReader(resp.Body)
+        for {
+            line, err := reader.ReadString('\n')
+            if err != nil {
+                if err == io.EOF {
+                    break
+                }
+                responseChan <- general_models.StreamResponse{Err: fmt.Errorf("error reading stream: %v", err)}
+                return
+            }
+
+            // Skip ping events or irrelevant data
+            if strings.HasPrefix(line, "event: ping") || strings.TrimSpace(line) == "" {
+                continue
+            }
+
+            // Parse response events
+            if strings.HasPrefix(line, "data: ") {
+                data := strings.TrimPrefix(line, "data: ")
+                var response models.AnthropicMessageResponse
+                if err := json.Unmarshal([]byte(data), &response); err != nil {
+                    responseChan <- general_models.StreamResponse{Err: fmt.Errorf("error unmarshalling chunk: %v", err)}
+                    return
+                }
+
+                // Handle content and final message updates
+                switch response.Type {
+                case "content_block_delta":
+                    if response.Delta.Type == "text_delta" {
+                        markdownBuffer.WriteString(response.Delta.Text)
+                        if strings.Contains(response.Delta.Text, "\n") {
+                            responseChan <- general_models.StreamResponse{Content: markdownBuffer.String()}
+                            markdownBuffer.Reset()
+                        }
+                    }
+                case "message_delta":
+                    if response.Usage != nil {
+                        usage = *response.Usage // Capture usage details (output tokens are reported here)
+                    }
+                case "message_stop":
+                    responseChan <- general_models.StreamResponse{Content: markdownBuffer.String(), Done: true}
+                    if usage.InputTokens > 0 || usage.OutputTokens > 0 {
+                        anthropicProvider.TokenManagement.UsedTokens(usage.InputTokens, usage.OutputTokens)
+                    }
+                    return
+                }
+            }
+        }
+
+        // Send any remaining content in the buffer
+        if markdownBuffer.Len() > 0 {
+            responseChan <- general_models.StreamResponse{Content: markdownBuffer.String()}
+        }
+    }()
+
+    return responseChan
+}
diff --git a/providers/anthropic/models/anthropic_message_request.go b/providers/anthropic/models/anthropic_message_request.go
new file mode 100644
index 0000000..11ee360
--- /dev/null
+++ b/providers/anthropic/models/anthropic_message_request.go
@@ -0,0 +1,17 @@
+package models
+
+// AnthropicMessageRequest represents the request body for an Anthropic message.
+type AnthropicMessageRequest struct {
+    Model       string    `json:"model"`                 // Model ID, e.g., "claude-3-5-sonnet-latest"
+    System      string    `json:"system,omitempty"`      // System prompt (top-level field in the Messages API)
+    Messages    []Message `json:"messages"`              // Array of message history
+    Temperature *float32  `json:"temperature,omitempty"` // Sampling temperature (0.0-1.0)
+    MaxTokens   int       `json:"max_tokens"`            // Maximum tokens to generate (required by the Messages API)
+    Stream      bool      `json:"stream,omitempty"`      // Enable/disable streaming
+}
+
+// Message represents a single message in the conversation.
+type Message struct {
+    Role    string `json:"role"`    // Valid roles: "user", "assistant" (the system prompt goes in the top-level "system" field)
+    Content string `json:"content"` // The text content for this message
+}
diff --git a/providers/anthropic/models/anthropic_message_response.go b/providers/anthropic/models/anthropic_message_response.go
new file mode 100644
index 0000000..76e2527
--- /dev/null
+++ b/providers/anthropic/models/anthropic_message_response.go
@@ -0,0 +1,33 @@
+package models
+
+// AnthropicMessageResponse represents a streamed event from Anthropic's Messages API.
+type AnthropicMessageResponse struct {
+    Type  string `json:"type"`            // Type of the response chunk, e.g., "message_start", "content_block_delta", etc.
+    Usage *Usage `json:"usage,omitempty"` // Optional token usage details (appears in certain chunks)
+    Delta *Delta `json:"delta,omitempty"` // Optional content updates or deltas
+}
+
+// Delta represents the streamed content or updates.
+type Delta struct {
+    Type       string `json:"type,omitempty"`        // Type of delta, e.g., "text_delta"
+    Text       string `json:"text,omitempty"`        // Text content streamed in chunks
+    StopReason string `json:"stop_reason,omitempty"` // Reason for stopping (e.g., "end_turn")
+}
+
+// Usage represents token usage details for Anthropic responses.
+type Usage struct {
+    InputTokens  int `json:"input_tokens"`  // Number of tokens in the input
+    OutputTokens int `json:"output_tokens"` // Number of tokens in the output
+}
+
+// AnthropicError represents the error response structure from Anthropic's API.
+type AnthropicError struct {
+    Type  string `json:"type"`  // Error type, e.g., "error"
+    Error Error  `json:"error"` // Error details
+}
+
+// Error represents detailed information about the error.
+type Error struct {
+    Type    string `json:"type"`    // Error category, e.g., "invalid_request_error"
+    Message string `json:"message"` // Human-readable error message
+}
diff --git a/providers/openrouter/models/openrouter_chat_completion_request.go b/providers/openrouter/models/openrouter_chat_completion_request.go
new file mode 100644
index 0000000..514025b
--- /dev/null
+++ b/providers/openrouter/models/openrouter_chat_completion_request.go
@@ -0,0 +1,15 @@
+package models
+
+// OpenRouterChatCompletionRequest defines the request body for an OpenRouter chat completion.
+type OpenRouterChatCompletionRequest struct {
+    Model       string    `json:"model"`
+    Messages    []Message `json:"messages"`
+    Temperature *float32  `json:"temperature,omitempty"` // Optional field (pointer to float32)
+    Stream      bool      `json:"stream"`
+}
+
+// Message represents a single chat message in the request.
+type Message struct {
+    Role    string `json:"role"`
+    Content string `json:"content"`
+}
diff --git a/providers/openrouter/models/openrouter_chat_completion_response.go b/providers/openrouter/models/openrouter_chat_completion_response.go
new file mode 100644
index 0000000..6f1c511
--- /dev/null
+++ b/providers/openrouter/models/openrouter_chat_completion_response.go
@@ -0,0 +1,25 @@
+package models
+
+// OpenRouterChatCompletionResponse represents the entire response structure from OpenRouter's (OpenAI-compatible) chat completion API.
+type OpenRouterChatCompletionResponse struct {
+    Choices []Choice `json:"choices"` // Array of choice completions
+    Usage   Usage    `json:"usage"`   // Token usage details
+}
+
+// Choice represents an individual choice in the response.
+type Choice struct {
+    Delta        Delta  `json:"delta"`
+    FinishReason string `json:"finish_reason"` // Final chunk: reason for stopping (e.g., "stop").
+}
+
+// Delta represents the delta object in each choice containing the content.
+type Delta struct {
+    Content string `json:"content"`
+}
+
+// Usage defines the token usage information for the response.
+type Usage struct {
+    PromptTokens     int `json:"prompt_tokens"`     // Number of tokens in the prompt
+    CompletionTokens int `json:"completion_tokens"` // Number of tokens in the completion
+    TotalTokens      int `json:"total_tokens"`      // Total tokens used
+}
diff --git a/providers/openrouter/openrouter_provider.go b/providers/openrouter/openrouter_provider.go
new file mode 100644
index 0000000..7d2a6c5
--- /dev/null
+++ b/providers/openrouter/openrouter_provider.go
@@ -0,0 +1,181 @@
+package openrouter
+
+import (
+    "bufio"
+    "bytes"
+    "context"
+    "encoding/json"
+    "errors"
+    "fmt"
+    "github.com/meysamhadeli/codai/providers/contracts"
+    general_models "github.com/meysamhadeli/codai/providers/models"
+    "github.com/meysamhadeli/codai/providers/openrouter/models"
+    contracts2 "github.com/meysamhadeli/codai/token_management/contracts"
+    "io"
+    "net/http"
+    "strings"
+)
+
+// OpenRouterConfig implements the IChatAIProvider interface for OpenRouter.
+type OpenRouterConfig struct {
+    ChatBaseURL     string
+    ChatModel       string
+    Temperature     float32
+    EncodingFormat  string
+    ChatApiKey      string
+    MaxTokens       int
+    Threshold       float64
+    TokenManagement contracts2.ITokenManagement
+    ChatApiVersion  string
+}
+
+// NewOpenRouterChatProvider initializes a new OpenRouter chat provider.
+func NewOpenRouterChatProvider(config *OpenRouterConfig) contracts.IChatAIProvider {
+    return &OpenRouterConfig{
+        ChatBaseURL:     config.ChatBaseURL,
+        ChatModel:       config.ChatModel,
+        Temperature:     config.Temperature,
+        EncodingFormat:  config.EncodingFormat,
+        MaxTokens:       config.MaxTokens,
+        Threshold:       config.Threshold,
+        ChatApiKey:      config.ChatApiKey,
+        ChatApiVersion:  config.ChatApiVersion,
+        TokenManagement: config.TokenManagement,
+    }
+}
+
+func (openRouterProvider *OpenRouterConfig) ChatCompletionRequest(ctx context.Context, userInput string, prompt string) <-chan general_models.StreamResponse {
+    responseChan := make(chan general_models.StreamResponse)
+    var markdownBuffer strings.Builder // Buffer to accumulate content until newline
+    var usage models.Usage             // Variable to hold usage data
+
+    go func() {
+        defer close(responseChan)
+
+        // Prepare the request body
+        reqBody := models.OpenRouterChatCompletionRequest{
+            Model: openRouterProvider.ChatModel,
+            Messages: []models.Message{
+                {Role: "system", Content: prompt},
+                {Role: "user", Content: userInput},
+            },
+            Stream:      true,
+            Temperature: &openRouterProvider.Temperature,
+        }
+
+        jsonData, err := json.Marshal(reqBody)
+        if err != nil {
+            markdownBuffer.Reset()
+            responseChan <- general_models.StreamResponse{Err: fmt.Errorf("error marshalling request body: %v", err)}
+            return
+        }
+
+        // Create a new HTTP request
+        req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("%s/api/v1/chat/completions", openRouterProvider.ChatBaseURL), bytes.NewBuffer(jsonData))
+        if err != nil {
+            markdownBuffer.Reset()
+            responseChan <- general_models.StreamResponse{Err: fmt.Errorf("error creating request: %v", err)}
+            return
+        }
+
+        req.Header.Set("Content-Type", "application/json")
+        req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", openRouterProvider.ChatApiKey))
+
+        client := &http.Client{}
+        resp, err := client.Do(req)
+        if err != nil {
+            markdownBuffer.Reset()
+            if errors.Is(ctx.Err(), context.Canceled) {
+                responseChan <- general_models.StreamResponse{Err: fmt.Errorf("request canceled: %v", err)}
+                return
+            }
+            responseChan <- general_models.StreamResponse{Err: fmt.Errorf("error sending request: %v", err)}
+            return
+        }
+        defer resp.Body.Close()
+
+        if resp.StatusCode != http.StatusOK {
+            markdownBuffer.Reset()
+            body, _ := io.ReadAll(resp.Body)
+            var apiError general_models.AIError
+            if err := json.Unmarshal(body, &apiError); err != nil {
+                responseChan <- general_models.StreamResponse{Err: fmt.Errorf("error parsing error response: %v", err)}
+                return
+            }
+
+            responseChan <- general_models.StreamResponse{Err: fmt.Errorf("API request failed with status code '%d' - %s", resp.StatusCode, apiError.Error.Message)}
+            return
+        }
+
+        reader := bufio.NewReader(resp.Body)
+
+        // Stream processing
+        for {
+            line, err := reader.ReadString('\n')
+            if err != nil {
+                if err == io.EOF {
+                    break // Keep any buffered content so the final flush below can send it
+                }
+                markdownBuffer.Reset()
+                responseChan <- general_models.StreamResponse{Err: fmt.Errorf("error reading stream: %v", err)}
+                return
+            }
+
+            // Skip the [DONE] marker completely
+            if line == "data: [DONE]\n" {
+                continue
+            }
+
+            if strings.HasPrefix(line, "data: ") {
+                jsonPart := strings.TrimPrefix(line, "data: ")
+                var response models.OpenRouterChatCompletionResponse
+                if err := json.Unmarshal([]byte(jsonPart), &response); err != nil {
+                    markdownBuffer.Reset()
+
+                    responseChan <- general_models.StreamResponse{Err: fmt.Errorf("error unmarshalling chunk: %v", err)}
+                    return
+                }
+
+                // Check if the response has usage information
+                if response.Usage.TotalTokens > 0 {
+                    usage = response.Usage // Capture the usage data for later use
+                }
+
+                // Accumulate and send response content
+                if len(response.Choices) > 0 {
+                    choice := response.Choices[0]
+                    content := choice.Delta.Content
+                    markdownBuffer.WriteString(content)
+
+                    // Send chunk if it contains a newline, and then reset the buffer
+                    if strings.Contains(content, "\n") {
+                        responseChan <- general_models.StreamResponse{Content: markdownBuffer.String()}
+                        markdownBuffer.Reset()
+                    }
+
+                    // Check for completion using FinishReason
+                    if choice.FinishReason == "stop" {
+                        responseChan <- general_models.StreamResponse{Content: markdownBuffer.String()}
+                        markdownBuffer.Reset() // Avoid re-sending this content in the final flush below
+
+                        responseChan <- general_models.StreamResponse{Done: true}
+
+                        // Count total tokens usage
+                        if usage.TotalTokens > 0 {
+                            openRouterProvider.TokenManagement.UsedTokens(usage.PromptTokens, usage.CompletionTokens)
+                        }
+
+                        break
+                    }
+                }
+            }
+        }
+
+        // Send any remaining content in the buffer
+        if markdownBuffer.Len() > 0 {
+            responseChan <- general_models.StreamResponse{Content: markdownBuffer.String()}
+        }
+    }()
+
+    return responseChan
+}
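Both new providers satisfy the same `contracts.IChatAIProvider` contract as the existing ones: `ChatCompletionRequest` returns a receive-only channel of `general_models.StreamResponse` values, where `Content` carries roughly line-sized chunks (flushed on newline), `Done` marks the final chunk, and `Err` reports failures. Here is a minimal sketch of a caller draining that channel; the prompt strings are placeholders, and the provider construction is left to the caller (in codai it comes from `ChatProviderFactory`):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/meysamhadeli/codai/providers/contracts"
)

// drainStream consumes streamed chunks from any chat provider until the
// stream reports Done or an error. How the provider is constructed is up
// to the caller (e.g. via providers.ChatProviderFactory).
func drainStream(ctx context.Context, provider contracts.IChatAIProvider, userInput, systemPrompt string) error {
	for chunk := range provider.ChatCompletionRequest(ctx, userInput, systemPrompt) {
		if chunk.Err != nil {
			return fmt.Errorf("stream error: %w", chunk.Err)
		}
		fmt.Print(chunk.Content) // chunks arrive roughly line-by-line
		if chunk.Done {
			break // the Done chunk may still carry trailing content, printed above
		}
	}
	return nil
}

func main() {
	// Placeholder: wire up a concrete provider here, e.g. with
	// providers.ChatProviderFactory(config, tokenManagement).
	var provider contracts.IChatAIProvider
	if provider == nil {
		log.Fatal("wire up a provider before running this sketch")
	}
	if err := drainStream(context.Background(), provider, "Explain this code", "You are a helpful coding assistant."); err != nil {
		log.Fatal(err)
	}
}
```

Cancelling the passed-in `context` aborts the underlying HTTP request; both providers surface that as a "request canceled" value on the channel rather than a panic, so callers can treat cancellation like any other stream error.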