-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate.py
33 lines (26 loc) · 1 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from typing import Any, Callable, Dict

from langchain_core.language_models.chat_models import BaseChatModel

from agent.graph.chains.generation import GenerationChain
from agent.graph.nodes.node import Node
from agent.graph.state import GraphState
class GenerateNode(Node):
    """Graph node that generates an answer from the retrieved documents.

    The chain produced by ``GenerationChain.get_chain()`` is built once in
    ``__init__`` and reused on every ``action`` call. It is invoked with the
    current question and documents taken from the graph state.
    """

    def __init__(
        self,
        model: BaseChatModel,
        get_chat_history: Callable[..., Any],
        config: dict,
    ) -> None:
        """Build the generation chain once so every ``action`` call reuses it.

        Args:
            model: Chat model handed to ``GenerationChain``.
            get_chat_history: Callable handed to ``GenerationChain``. Its
                expected signature is defined by ``GenerationChain`` (not
                visible here; confirm there).
            config: Passed unchanged as ``config=`` to every ``invoke`` call
                on the generation chain.
        """
        self.generator = GenerationChain(model, get_chat_history)
        self.generation_chain = self.generator.get_chain()
        self.config = config

    def action(self, state: GraphState) -> Dict[str, Any]:
        """Run the generation chain on the current graph state.

        Args:
            state: Graph state; must provide ``"question"`` and
                ``"documents"``. Any existing ``"generation"`` entry is not
                read, because it is replaced by the new result.

        Returns:
            A dict with ``"documents"`` and ``"question"`` passed through
            unchanged and ``"generation"`` set to the chain's output.

        Raises:
            KeyError: If ``"question"`` or ``"documents"`` is missing
                (assuming ``state`` is a plain dict / TypedDict).
        """
        print("---GENERATE---")
        question = state["question"]
        documents = state["documents"]
        # The previous "generation" value is intentionally not read: it was
        # always overwritten here, and reading it raised KeyError when no
        # generation had been produced yet.
        generation = self.generation_chain.invoke(
            # The chain's input key for the retrieved documents is "context".
            {"context": documents, "question": question},
            config=self.config,
        )
        return {
            "documents": documents,
            "question": question,
            "generation": generation,
        }