llms: Add stop sequence support #1092

Open · wants to merge 6 commits into base: main
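A minimal usage sketch of the new option; the OpenAI backend, prompt, and stop string below are illustrative assumptions rather than part of this diff:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/openai"
)

func main() {
	ctx := context.Background()
	llm, err := openai.New() // any provider updated in this PR behaves the same
	if err != nil {
		log.Fatal(err)
	}
	// Stop generation at the first blank line. When both WithStopWords and
	// WithStopSequences are given, the sequences take precedence.
	completion, err := llms.GenerateFromSinglePrompt(ctx, llm, "List three fruits:",
		llms.WithStopSequences([]string{"\n\n"}),
	)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(completion)
}
```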
2 changes: 1 addition & 1 deletion go.mod
@@ -158,7 +158,6 @@ require (
gitlab.com/golang-commonmark/linkify v0.0.0-20191026162114-a0c2df6c8f82 // indirect
gitlab.com/golang-commonmark/mdurl v0.0.0-20191124015652-932350d1cb84 // indirect
gitlab.com/golang-commonmark/puny v0.0.0-20191124015043-9f83538fa04f // indirect
go.mongodb.org/mongo-driver/v2 v2.0.0-beta1 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.51.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
@@ -220,6 +219,7 @@ require (
github.com/weaviate/weaviate-go-client/v4 v4.13.1
gitlab.com/golang-commonmark/markdown v0.0.0-20211110145824-bf3e522c626a
go.mongodb.org/mongo-driver v1.14.0
go.mongodb.org/mongo-driver/v2 v2.0.0-beta1
go.starlark.net v0.0.0-20230302034142-4b1e35fe2254
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1
golang.org/x/tools v0.14.0
170 changes: 107 additions & 63 deletions llms/anthropic/anthropicllm.go
@@ -103,10 +103,15 @@ func generateCompletionsContent(ctx context.Context, o *LLM, messages []llms.Mes
}
prompt := fmt.Sprintf("\n\nHuman: %s\n\nAssistant:", partText.Text)
result, err := o.client.CreateCompletion(ctx, &anthropicclient.CompletionRequest{
Model: opts.Model,
Prompt: prompt,
MaxTokens: opts.MaxTokens,
StopWords: opts.StopWords,
Model: opts.Model,
Prompt: prompt,
MaxTokens: opts.MaxTokens,
StopWords: func() []string {
if len(opts.StopSequences) > 0 {
return opts.StopSequences
}
return opts.StopWords
}(),
Temperature: opts.Temperature,
TopP: opts.TopP,
StreamingFunc: opts.StreamingFunc,
@@ -129,18 +134,54 @@
}

func generateMessagesContent(ctx context.Context, o *LLM, messages []llms.MessageContent, opts *llms.CallOptions) (*llms.ContentResponse, error) {
// Process messages and handle errors
chatMessages, systemPrompt, err := processMessages(messages)
if err != nil {
return nil, fmt.Errorf("anthropic: failed to process messages: %w", err)
}

// Create message and handle errors
result, err := createMessage(ctx, o, chatMessages, systemPrompt, opts)
if err != nil {
return nil, err
}
if result == nil {
return nil, ErrEmptyResponse
}

// Process content choices
choices := make([]*llms.ContentChoice, len(result.Content))
for i, content := range result.Content {
choice, err := processContent(content, result)
if err != nil {
return nil, fmt.Errorf("anthropic: failed to process content: %w", err)
}
choices[i] = choice
}

return &llms.ContentResponse{
Choices: choices,
}, nil
}

// Helper function to create message.
func createMessage(ctx context.Context, o *LLM, chatMessages []*anthropicclient.ChatMessage, systemPrompt string, opts *llms.CallOptions) (*anthropicclient.MessageResponsePayload, error) {
tools := toolsToTools(opts.Tools)
messages := make([]anthropicclient.ChatMessage, len(chatMessages))
for i, msg := range chatMessages {
messages[i] = *msg
}
result, err := o.client.CreateMessage(ctx, &anthropicclient.MessageRequest{
Model: opts.Model,
Messages: chatMessages,
System: systemPrompt,
MaxTokens: opts.MaxTokens,
StopWords: opts.StopWords,
Model: opts.Model,
Messages: messages,
System: systemPrompt,
MaxTokens: opts.MaxTokens,
StopWords: func() []string {
if len(opts.StopSequences) > 0 {
return opts.StopSequences
}
return opts.StopWords
}(),
Temperature: opts.Temperature,
TopP: opts.TopP,
Tools: tools,
@@ -152,60 +193,63 @@ func generateMessagesContent(ctx context.Context, o *LLM, messages []llms.Messag
}
return nil, fmt.Errorf("anthropic: failed to create message: %w", err)
}
if result == nil {
return nil, ErrEmptyResponse
return result, nil
}

// Helper function to process content.
func processContent(content anthropicclient.Content, result *anthropicclient.MessageResponsePayload) (*llms.ContentChoice, error) {
switch content.GetType() {
case "text":
return processTextContent(content, result)
case "tool_use":
return processToolUseContent(content, result)
default:
return nil, fmt.Errorf("anthropic: %w: %v", ErrUnsupportedContentType, content.GetType())
}
}

choices := make([]*llms.ContentChoice, len(result.Content))
for i, content := range result.Content {
switch content.GetType() {
case "text":
if textContent, ok := content.(*anthropicclient.TextContent); ok {
choices[i] = &llms.ContentChoice{
Content: textContent.Text,
StopReason: result.StopReason,
GenerationInfo: map[string]any{
"InputTokens": result.Usage.InputTokens,
"OutputTokens": result.Usage.OutputTokens,
},
}
} else {
return nil, fmt.Errorf("anthropic: %w for text message", ErrInvalidContentType)
}
case "tool_use":
if toolUseContent, ok := content.(*anthropicclient.ToolUseContent); ok {
argumentsJSON, err := json.Marshal(toolUseContent.Input)
if err != nil {
return nil, fmt.Errorf("anthropic: failed to marshal tool use arguments: %w", err)
}
choices[i] = &llms.ContentChoice{
ToolCalls: []llms.ToolCall{
{
ID: toolUseContent.ID,
FunctionCall: &llms.FunctionCall{
Name: toolUseContent.Name,
Arguments: string(argumentsJSON),
},
},
},
StopReason: result.StopReason,
GenerationInfo: map[string]any{
"InputTokens": result.Usage.InputTokens,
"OutputTokens": result.Usage.OutputTokens,
},
}
} else {
return nil, fmt.Errorf("anthropic: %w for tool use message", ErrInvalidContentType)
}
default:
return nil, fmt.Errorf("anthropic: %w: %v", ErrUnsupportedContentType, content.GetType())
}
// Helper function to process text content.
func processTextContent(content anthropicclient.Content, result *anthropicclient.MessageResponsePayload) (*llms.ContentChoice, error) {
textContent, ok := content.(*anthropicclient.TextContent)
if !ok {
return nil, fmt.Errorf("anthropic: %w for text message", ErrInvalidContentType)
}
return &llms.ContentChoice{
Content: textContent.Text,
StopReason: result.StopReason,
GenerationInfo: map[string]any{
"InputTokens": result.Usage.InputTokens,
"OutputTokens": result.Usage.OutputTokens,
},
}, nil
}

resp := &llms.ContentResponse{
Choices: choices,
// Helper function to process tool use content.
func processToolUseContent(content anthropicclient.Content, result *anthropicclient.MessageResponsePayload) (*llms.ContentChoice, error) {
toolUseContent, ok := content.(*anthropicclient.ToolUseContent)
if !ok {
return nil, fmt.Errorf("anthropic: %w for tool use message", ErrInvalidContentType)
}
return resp, nil
argumentsJSON, err := json.Marshal(toolUseContent.Input)
if err != nil {
return nil, fmt.Errorf("anthropic: failed to marshal tool use arguments: %w", err)
}
return &llms.ContentChoice{
ToolCalls: []llms.ToolCall{
{
ID: toolUseContent.ID,
FunctionCall: &llms.FunctionCall{
Name: toolUseContent.Name,
Arguments: string(argumentsJSON),
},
},
},
StopReason: result.StopReason,
GenerationInfo: map[string]any{
"InputTokens": result.Usage.InputTokens,
"OutputTokens": result.Usage.OutputTokens,
},
}, nil
}

func toolsToTools(tools []llms.Tool) []anthropicclient.Tool {
@@ -220,8 +264,8 @@ func toolsToTools(tools []llms.Tool) []anthropicclient.Tool {
return toolReq
}

func processMessages(messages []llms.MessageContent) ([]anthropicclient.ChatMessage, string, error) {
chatMessages := make([]anthropicclient.ChatMessage, 0, len(messages))
func processMessages(messages []llms.MessageContent) ([]*anthropicclient.ChatMessage, string, error) {
chatMessages := make([]*anthropicclient.ChatMessage, 0, len(messages))
systemPrompt := ""
for _, msg := range messages {
switch msg.Role {
@@ -236,19 +280,19 @@ if err != nil {
if err != nil {
return nil, "", fmt.Errorf("anthropic: failed to handle human message: %w", err)
}
chatMessages = append(chatMessages, chatMessage)
chatMessages = append(chatMessages, &chatMessage)
case llms.ChatMessageTypeAI:
chatMessage, err := handleAIMessage(msg)
if err != nil {
return nil, "", fmt.Errorf("anthropic: failed to handle AI message: %w", err)
}
chatMessages = append(chatMessages, chatMessage)
chatMessages = append(chatMessages, &chatMessage)
case llms.ChatMessageTypeTool:
chatMessage, err := handleToolMessage(msg)
if err != nil {
return nil, "", fmt.Errorf("anthropic: failed to handle tool message: %w", err)
}
chatMessages = append(chatMessages, chatMessage)
chatMessages = append(chatMessages, &chatMessage)
case llms.ChatMessageTypeGeneric, llms.ChatMessageTypeFunction:
return nil, "", fmt.Errorf("anthropic: %w: %v", ErrUnsupportedMessageType, msg.Role)
default:
13 changes: 1 addition & 12 deletions llms/anthropic/internal/anthropicclient/anthropicclient.go
@@ -139,18 +139,7 @@ type MessageRequest struct {

// CreateMessage creates message for the messages api.
func (c *Client) CreateMessage(ctx context.Context, r *MessageRequest) (*MessageResponsePayload, error) {
resp, err := c.createMessage(ctx, &messagePayload{
Model: r.Model,
Messages: r.Messages,
System: r.System,
Temperature: r.Temperature,
MaxTokens: r.MaxTokens,
StopWords: r.StopWords,
TopP: r.TopP,
Tools: r.Tools,
Stream: r.Stream,
StreamingFunc: r.StreamingFunc,
})
resp, err := c.createMessage(ctx, r)
if err != nil {
return nil, err
}
12 changes: 6 additions & 6 deletions llms/anthropic/internal/anthropicclient/messages.go
@@ -31,7 +31,7 @@ type ChatMessage struct {
Content interface{} `json:"content"`
}

type messagePayload struct {
type MessagePayload struct {
Model string `json:"model"`
Messages []ChatMessage `json:"messages"`
System string `json:"system,omitempty"`
@@ -142,7 +142,7 @@ func (m *MessageResponsePayload) UnmarshalJSON(data []byte) error {
return nil
}

func (c *Client) setMessageDefaults(payload *messagePayload) {
func (c *Client) setMessageDefaults(payload *MessageRequest) {
// Set defaults
if payload.MaxTokens == 0 {
payload.MaxTokens = 2048
@@ -168,7 +168,7 @@ func (c *Client) setMessageDefaults(payload *messagePayload) {
}
}

func (c *Client) createMessage(ctx context.Context, payload *messagePayload) (*MessageResponsePayload, error) {
func (c *Client) createMessage(ctx context.Context, payload *MessageRequest) (*MessageResponsePayload, error) {
c.setMessageDefaults(payload)

payloadBytes, err := json.Marshal(payload)
@@ -203,7 +203,7 @@ type MessageEvent struct {
Err error
}

func parseStreamingMessageResponse(ctx context.Context, r *http.Response, payload *messagePayload) (*MessageResponsePayload, error) {
func parseStreamingMessageResponse(ctx context.Context, r *http.Response, payload *MessageRequest) (*MessageResponsePayload, error) {
scanner := bufio.NewScanner(r.Body)
eventChan := make(chan MessageEvent)

@@ -248,7 +248,7 @@ func parseStreamEvent(data string) (map[string]interface{}, error) {
return event, err
}

func processStreamEvent(ctx context.Context, event map[string]interface{}, payload *messagePayload, response MessageResponsePayload, eventChan chan<- MessageEvent) (MessageResponsePayload, error) {
func processStreamEvent(ctx context.Context, event map[string]interface{}, payload *MessageRequest, response MessageResponsePayload, eventChan chan<- MessageEvent) (MessageResponsePayload, error) {
eventType, ok := event["type"].(string)
if !ok {
return response, ErrInvalidEventType
@@ -322,7 +322,7 @@ func handleContentBlockStartEvent(event map[string]interface{}, response Message
return response, nil
}

func handleContentBlockDeltaEvent(ctx context.Context, event map[string]interface{}, response MessageResponsePayload, payload *messagePayload) (MessageResponsePayload, error) {
func handleContentBlockDeltaEvent(ctx context.Context, event map[string]interface{}, response MessageResponsePayload, payload *MessageRequest) (MessageResponsePayload, error) {
indexValue, ok := event["index"].(float64)
if !ok {
return response, ErrInvalidIndexField
7 changes: 6 additions & 1 deletion llms/ollama/ollamallm.go
@@ -220,7 +220,12 @@ func makeOllamaOptionsFromOptions(ollamaOptions ollamaclient.Options, opts llms.
// Load back CallOptions as ollamaOptions
ollamaOptions.NumPredict = opts.MaxTokens
ollamaOptions.Temperature = float32(opts.Temperature)
ollamaOptions.Stop = opts.StopWords
ollamaOptions.Stop = func() []string {
if len(opts.StopSequences) > 0 {
return opts.StopSequences
}
return opts.StopWords
}()
ollamaOptions.TopK = opts.TopK
ollamaOptions.TopP = float32(opts.TopP)
ollamaOptions.Seed = opts.Seed
9 changes: 7 additions & 2 deletions llms/openai/openaillm.go
@@ -96,8 +96,13 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
chatMsgs = append(chatMsgs, msg)
}
req := &openaiclient.ChatRequest{
Model: opts.Model,
StopWords: opts.StopWords,
Model: opts.Model,
StopWords: func() []string {
if len(opts.StopSequences) > 0 {
return opts.StopSequences
}
return opts.StopWords
}(),
Messages: chatMsgs,
StreamingFunc: opts.StreamingFunc,
Temperature: opts.Temperature,
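The stop-precedence closure above now appears four times across this PR (twice in the anthropic adapter, once each in ollama and openai). A shared helper along these lines could keep the rule in one place; `resolveStop` is a hypothetical name, not part of this diff:

```go
// resolveStop returns StopSequences when present and falls back to the
// deprecated StopWords otherwise (hypothetical helper, not in this PR).
func resolveStop(opts *llms.CallOptions) []string {
	if len(opts.StopSequences) > 0 {
		return opts.StopSequences
	}
	return opts.StopWords
}
```

Each call site would then reduce to `StopWords: resolveStop(opts)`.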
11 changes: 11 additions & 0 deletions llms/options.go
@@ -17,7 +17,11 @@ type CallOptions struct {
// Temperature is the temperature for sampling, between 0 and 1.
Temperature float64 `json:"temperature"`
// StopWords is a list of words to stop on.
// Deprecated: Use StopSequences instead.
StopWords []string `json:"stop_words"`
// StopSequences is a list of sequences to stop on.
// If both StopWords and StopSequences are provided, StopSequences takes precedence.
StopSequences []string `json:"stop_sequences,omitempty"`
// StreamingFunc is a function to be called for each chunk of a streaming response.
// Return an error to stop streaming early.
StreamingFunc func(ctx context.Context, chunk []byte) error `json:"-"`
@@ -148,6 +152,13 @@ func WithStopWords(stopWords []string) CallOption {
}
}

// WithStopSequences specifies a list of sequences to stop generation on.
func WithStopSequences(sequences []string) CallOption {
return func(o *CallOptions) {
o.StopSequences = sequences
}
}

// WithOptions specifies options.
func WithOptions(options CallOptions) CallOption {
return func(o *CallOptions) {
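A quick sketch of the documented precedence when both options are set on the same call, using only the exported options above (the stop strings are illustrative):

```go
package main

import (
	"fmt"

	"github.com/tmc/langchaingo/llms"
)

func main() {
	var opts llms.CallOptions
	llms.WithStopWords([]string{"END"})(&opts)      // deprecated option
	llms.WithStopSequences([]string{"STOP"})(&opts) // option added in this PR
	// Mirror the resolution each patched provider performs:
	stop := opts.StopWords
	if len(opts.StopSequences) > 0 {
		stop = opts.StopSequences
	}
	fmt.Println(stop) // prints [STOP]: StopSequences takes precedence
}
```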