Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add anthropic and openrouter providers #91

Merged
merged 1 commit into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,21 +61,21 @@ The `codai-config` file should be like following example base on your `AI provid
**codai-config.yml**
```yml
ai_provider_config:
chat_provider_name: "openai" # openai | ollama | azure-openai
chat_base_url: "https://api.openai.com" # "http://localhost:11434" | "https://test,openai.azure.com"
chat_provider_name: "openai" # openai | ollama | azure-openai | anthropic | openrouter
chat_base_url: "https://api.openai.com" # "http://localhost:11434" | "https://test.openai.azure.com" | "https://api.anthropic.com" | "https://openrouter.ai"
chat_model: "gpt-4o"
chat_api_version: "2024-04-01-preview" #(Optional, If your AI provider like AzureOpenai has chat api version.)
chat_api_version: "2024-04-01-preview" #(Optional, If your AI provider like 'AzureOpenai' or 'Anthropic' has chat api version.)
embeddings_provider_name: "openai" # openai | ollama | azure-openai
embeddings_base_url: "https://api.openai.com" # "http://localhost:11434" | "https://test.openai.azure.com"
embeddings_model: "text-embedding-3-small" #(Optional, If you want use RAG.)
embeddings_api_version: "2024-01-01-preview" #(Optional, If your AI provider like AzureOpenai has embeddings api version.)
embeddings_model: "text-embedding-3-small" #(Optional, If you want to use 'RAG'.)
embeddings_api_version: "2024-01-01-preview" #(Optional, If your AI provider like 'AzureOpenai' has embeddings api version.)
temperature: 0.2
threshold: 0.2 #(Optional, If you want use RAG.)
threshold: 0.2 #(Optional, If you want to use 'RAG'.)
theme: "dracula"
rag: true #(Optional, If you want use RAG.)
rag: true #(Optional, If you want to use 'RAG'.)
```

> Note: We used the standard integration of [OpenAI APIs](https://platform.openai.com/docs/api-reference/introduction), [Ollama APIs](https://github.com/ollama/ollama/blob/main/docs/api.md) and [Azure Openai](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) and you can find more details in documentation of each APIs.
> Note: We used the standard integration of [OpenAI APIs](https://platform.openai.com/docs/api-reference/introduction), [Ollama APIs](https://github.com/ollama/ollama/blob/main/docs/api.md), [Azure Openai](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference), [Anthropic](https://docs.anthropic.com/en/api/getting-started), [OpenRouter](https://openrouter.ai/docs/quick-start), and you can find more details in the documentation of each AI provider's API.

If you wish to customize your configuration, you can create your own `codai-config.yml` file and place it in the `root directory` of `each project` you want to analyze with codai. If `no configuration` file is provided, codai will use the `default settings`.

Expand Down
30 changes: 28 additions & 2 deletions providers/ai_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package providers

import (
"errors"
"github.com/meysamhadeli/codai/providers/anthropic"
azure_openai "github.com/meysamhadeli/codai/providers/azure-openai"
"github.com/meysamhadeli/codai/providers/contracts"
"github.com/meysamhadeli/codai/providers/ollama"
"github.com/meysamhadeli/codai/providers/openai"
"github.com/meysamhadeli/codai/providers/openrouter"
contracts2 "github.com/meysamhadeli/codai/token_management/contracts"
)

Expand Down Expand Up @@ -56,7 +58,7 @@ func ChatProviderFactory(config *AIProviderConfig, tokenManagement contracts2.IT
ChatApiVersion: config.ChatApiVersion,
EmbeddingsApiVersion: config.EmbeddingsApiVersion,
}), nil
case "azure-openai", "azure_openai":
case "azure-openai":
return azure_openai.NewAzureOpenAIChatProvider(&azure_openai.AzureOpenAIConfig{
Temperature: config.Temperature,
EncodingFormat: config.EncodingFormat,
Expand All @@ -71,6 +73,30 @@ func ChatProviderFactory(config *AIProviderConfig, tokenManagement contracts2.IT
ChatApiVersion: config.ChatApiVersion,
EmbeddingsApiVersion: config.EmbeddingsApiVersion,
}), nil

case "openrouter":
return openrouter.NewOpenRouterChatProvider(&openrouter.OpenRouterConfig{
Temperature: config.Temperature,
EncodingFormat: config.EncodingFormat,
ChatModel: config.ChatModel,
ChatApiKey: config.ChatApiKey,
MaxTokens: config.MaxTokens,
Threshold: config.Threshold,
TokenManagement: tokenManagement,
ChatApiVersion: config.ChatApiVersion,
}), nil

case "anthropic":
return anthropic.NewAnthropicMessageProvider(&anthropic.AnthropicConfig{
Temperature: config.Temperature,
EncodingFormat: config.EncodingFormat,
MessageModel: config.ChatModel,
MessageApiKey: config.ChatApiKey,
MaxTokens: config.MaxTokens,
Threshold: config.Threshold,
TokenManagement: tokenManagement,
MessageApiVersion: config.ChatApiVersion,
}), nil
default:

return nil, errors.New("unsupported provider")
Expand Down Expand Up @@ -106,7 +132,7 @@ func EmbeddingsProviderFactory(config *AIProviderConfig, tokenManagement contrac
ChatApiVersion: config.ChatApiVersion,
EmbeddingsApiVersion: config.EmbeddingsApiVersion,
}), nil
case "azure-openai", "azure_openai":
case "azure-openai":
return azure_openai.NewAzureOpenAIEmbeddingsProvider(&azure_openai.AzureOpenAIConfig{
Temperature: config.Temperature,
EncodingFormat: config.EncodingFormat,
Expand Down
167 changes: 167 additions & 0 deletions providers/anthropic/anthropic_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package anthropic

import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"github.com/meysamhadeli/codai/providers/anthropic/models"
"github.com/meysamhadeli/codai/providers/contracts"
general_models "github.com/meysamhadeli/codai/providers/models"
contracts2 "github.com/meysamhadeli/codai/token_management/contracts"
"io"
"io/ioutil"
"net/http"
"strings"
)

// AnthropicConfig holds the configuration for the Anthropic Messages API
// chat provider and implements the chat provider interface.
type AnthropicConfig struct {
	MessageBaseURL    string                      // Base URL of the Anthropic API, e.g. "https://api.anthropic.com"
	MessageModel      string                      // Model ID to request, e.g. "claude-3-5-sonnet-latest"
	Temperature       float32                     // Sampling temperature sent with each request
	EncodingFormat    string                      // Encoding format carried over from the shared provider config
	MessageApiKey     string                      // API key sent in the "x-api-key" header
	MaxTokens         int                         // Token budget (NOTE(review): not currently sent in the request body — confirm intent)
	Threshold         float64                     // Similarity threshold carried over from the shared provider config
	TokenManagement   contracts2.ITokenManagement // Sink for recording input/output token usage
	MessageApiVersion string                      // Value for the required "anthropic-version" header
}

// NewAnthropicMessageProvider constructs a chat provider backed by the
// Anthropic Messages API from the given configuration.
func NewAnthropicMessageProvider(config *AnthropicConfig) contracts.IChatAIProvider {
	// Take a private copy of the configuration so the returned provider is
	// not aliased to the caller's struct.
	cfg := *config
	return &cfg
}

// ChatCompletionRequest streams a chat completion from the Anthropic
// Messages API (POST {MessageBaseURL}/v1/messages with "stream": true).
// It returns a channel of StreamResponse values: buffered content is
// flushed whenever a chunk contains a newline, a final response with
// Done=true is sent on the "message_stop" event, and any failure is
// delivered as a single response with Err set. The channel is closed
// when the goroutine finishes.
func (anthropicProvider *AnthropicConfig) ChatCompletionRequest(ctx context.Context, userInput string, prompt string) <-chan general_models.StreamResponse {
	responseChan := make(chan general_models.StreamResponse)
	var markdownBuffer strings.Builder // Accumulate content for streaming responses
	var usage models.Usage             // To track token usage

	go func() {
		defer close(responseChan)

		// Prepare the request body.
		// NOTE(review): the Anthropic Messages API takes the system prompt as
		// a top-level "system" field, not as a message with role "system",
		// and it requires "max_tokens" in the body — confirm this payload is
		// accepted by the live API.
		reqBody := models.AnthropicMessageRequest{
			Messages: []models.Message{
				{Role: "system", Content: prompt},
				{Role: "user", Content: userInput},
			},
			Model:       anthropicProvider.MessageModel,
			Temperature: &anthropicProvider.Temperature,
			Stream:      true,
		}

		jsonData, err := json.Marshal(reqBody)
		if err != nil {
			responseChan <- general_models.StreamResponse{Err: fmt.Errorf("error marshalling request body: %v", err)}
			return
		}

		// Create the HTTP request; ctx carries cancellation into the stream.
		req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("%s/v1/messages", anthropicProvider.MessageBaseURL), bytes.NewBuffer(jsonData))
		if err != nil {
			responseChan <- general_models.StreamResponse{Err: fmt.Errorf("error creating HTTP request: %v", err)}
			return
		}

		// Set required headers
		req.Header.Set("content-type", "application/json")                      // Required content type
		req.Header.Set("anthropic-version", anthropicProvider.MessageApiVersion) // Required API version
		req.Header.Set("x-api-key", anthropicProvider.MessageApiKey)             // API key for authentication

		// Send the HTTP request. No client timeout is set on purpose: this is
		// a long-lived streaming response; cancellation comes from ctx.
		client := &http.Client{}
		resp, err := client.Do(req)
		if err != nil {
			// Distinguish a user/context cancellation from a transport error.
			if errors.Is(ctx.Err(), context.Canceled) {
				responseChan <- general_models.StreamResponse{Err: fmt.Errorf("request canceled: %v", err)}
			} else {
				responseChan <- general_models.StreamResponse{Err: fmt.Errorf("error sending request: %v", err)}
			}
			return
		}
		defer resp.Body.Close()

		// Handle non-200 status codes by decoding Anthropic's error envelope.
		// NOTE(review): ioutil.ReadAll is deprecated since Go 1.16; io.ReadAll
		// (io is already imported) is the modern equivalent.
		if resp.StatusCode != http.StatusOK {
			body, _ := ioutil.ReadAll(resp.Body)
			var apiError models.AnthropicError
			if err := json.Unmarshal(body, &apiError); err != nil {
				responseChan <- general_models.StreamResponse{Err: fmt.Errorf("error parsing error response: %v", err)}
			} else {
				responseChan <- general_models.StreamResponse{Err: fmt.Errorf("API request failed: %s", apiError.Error.Message)}
			}
			return
		}

		// Process the streaming response line by line (server-sent events).
		reader := bufio.NewReader(resp.Body)
		for {
			line, err := reader.ReadString('\n')
			if err != nil {
				if err == io.EOF {
					break // stream ended without an explicit message_stop
				}
				responseChan <- general_models.StreamResponse{Err: fmt.Errorf("error reading stream: %v", err)}
				return
			}

			// Skip ping events or irrelevant data
			if strings.HasPrefix(line, "event: ping") || strings.TrimSpace(line) == "" {
				continue
			}

			// Parse response events ("data: {...}" SSE payload lines).
			if strings.HasPrefix(line, "data: ") {
				data := strings.TrimPrefix(line, "data: ")
				var response models.AnthropicMessageResponse
				if err := json.Unmarshal([]byte(data), &response); err != nil {
					responseChan <- general_models.StreamResponse{Err: fmt.Errorf("error unmarshalling chunk: %v", err)}
					return
				}

				// Handle content and final message updates
				switch response.Type {
				case "content_block_delta":
					// Text fragment: accumulate, flush on newline so the
					// consumer receives whole markdown lines.
					if response.Delta.Type == "text_delta" {
						markdownBuffer.WriteString(response.Delta.Text)
						if strings.Contains(response.Delta.Text, "\n") {
							responseChan <- general_models.StreamResponse{Content: markdownBuffer.String()}
							markdownBuffer.Reset()
						}
					}
				case "message_delta":
					if response.Usage != nil {
						usage = *response.Usage // Capture usage details
					}
				case "message_stop":
					// Final event: flush remaining buffer and signal Done.
					responseChan <- general_models.StreamResponse{Content: markdownBuffer.String(), Done: true}
					// NOTE(review): Anthropic usage payloads document
					// input_tokens/output_tokens but no "total_tokens", so
					// this guard may never fire — confirm against the API.
					if usage.TotalTokens > 0 {
						anthropicProvider.TokenManagement.UsedTokens(usage.InputTokens, usage.OutputTokens)
					}
					return
				}
			}
		}

		// Send any remaining content in the buffer (EOF without message_stop).
		if markdownBuffer.Len() > 0 {
			responseChan <- general_models.StreamResponse{Content: markdownBuffer.String()}
		}
	}()

	return responseChan
}
15 changes: 15 additions & 0 deletions providers/anthropic/models/anthropic_message_request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package models

// AnthropicMessageRequest represents the request body for Anthropic message.
type AnthropicMessageRequest struct {
Model string `json:"model"` // Model ID, e.g., "claude-3-5-sonnet-latest"
Messages []Message `json:"messages"` // Array of message history
Temperature *float32 `json:"temperature,omitempty"` // Sampling temperature (0.0-1.0)
Stream bool `json:"stream,omitempty"` // Enable/disable streaming
}

// Message Define the request body structure
type Message struct {
Role string `json:"role"` // Valid roles: "system", "user", "assistant"
Content string `json:"content"` // The text content for this message
}
40 changes: 40 additions & 0 deletions providers/anthropic/models/anthropic_message_response.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package models

// AnthropicMessageResponse represents one server-sent event chunk from
// Anthropic's streaming Messages API, discriminated by Type
// ("message_start", "content_block_delta", "message_delta", "message_stop", ...).
type AnthropicMessageResponse struct {
	Type    string   `json:"type"`              // Event type of this chunk
	Choices []Choice `json:"choices,omitempty"` // NOTE(review): "choices" is OpenAI-style and not part of Anthropic's documented stream events — confirm whether this is ever populated
	Usage   *Usage   `json:"usage,omitempty"`   // Token usage details (present on some chunks, e.g. message_delta)
	Delta   *Delta   `json:"delta,omitempty"`   // Incremental content or status update
}

// Choice represents an individual choice in the response.
type Choice struct {
	Delta Delta `json:"delta"` // Streamed content delta
}

// Delta represents the streamed content or updates.
type Delta struct {
	Type       string `json:"type,omitempty"`        // Delta kind, e.g., "text_delta"
	Text       string `json:"text,omitempty"`        // Text content streamed in chunks
	StopReason string `json:"stop_reason,omitempty"` // Reason for stopping (e.g., "end_turn")
}

// Usage represents token usage details for Anthropic responses.
type Usage struct {
	InputTokens  int `json:"input_tokens"`  // Number of tokens in the input
	OutputTokens int `json:"output_tokens"` // Number of tokens in the output
	TotalTokens  int `json:"total_tokens"`  // NOTE(review): not a documented Anthropic field, so this is likely always zero — confirm
}

// AnthropicError represents the error response structure from Anthropic's API.
type AnthropicError struct {
	Type  string `json:"type"`  // Envelope type, e.g., "error"
	Error Error  `json:"error"` // Error details
}

// Error represents detailed information about the error.
type Error struct {
	Type    string `json:"type"`    // Error category, e.g., "invalid_request_error"
	Message string `json:"message"` // Human-readable error message
}
15 changes: 15 additions & 0 deletions providers/openrouter/models/openrouter_chat_completion_request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package models

// OpenRouterChatCompletionRequest is the request body for OpenRouter's
// OpenAI-compatible chat completions endpoint.
type OpenRouterChatCompletionRequest struct {
	Model       string    `json:"model"`                 // Model ID to route the request to
	Messages    []Message `json:"messages"`              // Conversation history ("system"/"user"/"assistant" turns)
	Temperature *float32  `json:"temperature,omitempty"` // Optional sampling temperature (pointer so unset is omitted)
	Stream      bool      `json:"stream"`                // Enable/disable streaming (always serialized, even when false)
}

// Message is a single conversational turn in the request body.
type Message struct {
	Role    string `json:"role"`    // Role of the author, e.g. "system", "user", "assistant"
	Content string `json:"content"` // The text content for this message
}
25 changes: 25 additions & 0 deletions providers/openrouter/models/openrouter_chat_completion_response.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package models

// OpenRouterChatCompletionResponse represents a streamed response chunk from
// OpenRouter's OpenAI-compatible chat completion API.
type OpenRouterChatCompletionResponse struct {
	Choices []Choice `json:"choices"` // Array of choice completions
	Usage   Usage    `json:"usage"`   // Token usage details
}

// Choice represents an individual choice in the response.
type Choice struct {
	Delta        Delta  `json:"delta"`         // Streamed content delta
	FinishReason string `json:"finish_reason"` // Final chunk: reason for stopping (e.g., "stop").
}

// Delta represents the delta object in each choice containing the content.
type Delta struct {
	Content string `json:"content"` // Incremental text fragment
}

// Usage defines the token usage information for the response.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`     // Number of tokens in the prompt
	CompletionTokens int `json:"completion_tokens"` // Number of tokens in the completion
	TotalTokens      int `json:"total_tokens"`      // Total tokens used
}
Loading