diff --git a/chains/llm.go b/chains/llm.go
index e280d197d..9263ae066 100644
--- a/chains/llm.go
+++ b/chains/llm.go
@@ -19,6 +19,8 @@ type LLMChain struct {
 	Memory           schema.Memory
 	CallbacksHandler callbacks.Handler
 	OutputParser     schema.OutputParser[any]
+	// When enabled, UseMultiPrompt will not 'flatten' the prompt into a single message.
+	UseMultiPrompt bool
 
 	OutputKey string
 }
@@ -41,6 +43,7 @@ func NewLLMChain(llm llms.Model, prompt prompts.FormatPrompter, opts ...ChainCal
 		Memory:           memory.NewSimple(),
 		OutputKey:        _llmChainDefaultOutputKey,
 		CallbacksHandler: opt.CallbackHandler,
+		UseMultiPrompt:   false,
 	}
 
 	return chain
@@ -56,12 +59,17 @@ func (c LLMChain) Call(ctx context.Context, values map[string]any, options ...Ch
 		return nil, err
 	}
 
-	result, err := llms.GenerateFromMultiPrompt(ctx, c.LLM, chatMessagesToLLmMessageContent(promptValue.Messages()), getLLMCallOptions(options...)...)
+	var output string
+	if c.UseMultiPrompt {
+		output, err = llms.GenerateFromMultiPrompt(ctx, c.LLM, chatMessagesToLLmMessageContent(promptValue.Messages()), getLLMCallOptions(options...)...)
+	} else {
+		output, err = llms.GenerateFromSinglePrompt(ctx, c.LLM, promptValue.String(), getLLMCallOptions(options...)...)
+	}
 	if err != nil {
 		return nil, err
 	}
 
-	finalOutput, err := c.OutputParser.ParseWithPrompt(result, promptValue)
+	finalOutput, err := c.OutputParser.ParseWithPrompt(output, promptValue)
 	if err != nil {
 		return nil, err
 	}
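
For context, a minimal sketch of how calling code might opt into the new flag, assuming the langchaingo OpenAI provider and a chat-style prompt template; the model setup, template text, and input variable are illustrative only and not part of this change:

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/chains"
	"github.com/tmc/langchaingo/llms/openai"
	"github.com/tmc/langchaingo/prompts"
)

func main() {
	ctx := context.Background()

	llm, err := openai.New()
	if err != nil {
		log.Fatal(err)
	}

	// A chat prompt with distinct system and human messages.
	prompt := prompts.NewChatPromptTemplate([]prompts.MessageFormatter{
		prompts.NewSystemMessagePromptTemplate("You are a naming consultant.", nil),
		prompts.NewHumanMessagePromptTemplate("Suggest a name for a company that makes {{.product}}.", []string{"product"}),
	})

	chain := chains.NewLLMChain(llm, prompt)
	// With UseMultiPrompt set, Call keeps the messages separate and sends them
	// through GenerateFromMultiPrompt instead of flattening them into one string.
	chain.UseMultiPrompt = true

	out, err := chains.Call(ctx, chain, map[string]any{"product": "colorful socks"})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(out[chain.OutputKey])
}

The default remains the single-prompt path, so existing chains keep their current behavior unless the flag is set explicitly.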