-
Notifications
You must be signed in to change notification settings - Fork 30
/
chatsearch.py
208 lines (160 loc) · 6.26 KB
/
chatsearch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
# from https://docs.streamlit.io/develop/tutorials/llms/build-conversational-apps
import streamlit as st
from pydantic import BaseModel, Field
from langchain_upstage import ChatUpstage as Chat
from solar_util import initialize_solar_llm
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
ChatPromptTemplate,
MessagesPlaceholder,
PromptTemplate,
)
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.messages import AIMessage, HumanMessage
# Character budget applied when slicing search context before it is fed to
# the LLM prompts.  NOTE(review): despite the name this is used as a
# *character* limit (string slicing below), not a token count — confirm
# against the model's actual context window.
MAX_TOKENS = 4000
# NOTE(review): misspelled ("SEAERCH") and never used — the DuckDuckGo tool
# below is constructed without a result limit; consider wiring this in.
MAX_SEAERCH_RESULTS = 5

# Web-search tool and LLM are module-level singletons shared by all requests.
ddg_search = DuckDuckGoSearchResults()

llm = initialize_solar_llm()

# Streamlit page chrome; set_page_config must run before other st.* calls.
st.set_page_config(page_title="Search and Chat", page_icon="🔍")
st.title("SolarLLM Search")
# Prompt for the quick one-liner shown before the detailed answer.
# Placeholders: {user_query}, {context}; chat history is injected as messages.
# (Fixes typos in the original system prompt: "sentense" -> "sentence",
# "conetxt" -> "context".)
short_answer_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """You are Solar, a smart search engine by Upstage, loved by many people. 

Write one word answer if you can say "yes", "no", or direct answer. 
Otherwise just one or two sentence short answer for the query from the given context.
Try to understand the user's intention and provide a quick answer.
If the answer is not in context, please say you don't know and ask to clarify the question.
""",
        ),
        MessagesPlaceholder("chat_history"),
        (
            "human",
            """Query: {user_query} 
----
Context: {context}""",
        ),
    ]
)
# Prompt for the detailed, cited explanation that follows the short answer.
# Placeholders: {user_query}, {short_answer}, {context}.
# (Fixes typos in the original system prompt: "origial" -> "original",
# "explnation" -> "explanation", "Thyen" -> "Then".)
search_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """You are Solar, a smart search engine by Upstage, loved by many people. 

See the original query, context, and quick answer, and then provide detailed explanation.
Try to understand the user's intention and provide the relevant information in detail.
If the answer is not in context, please say you don't know and ask to clarify the question.
Do not repeat the short answer.

When you write the explanation, please cite the source like [1], [2] if possible.
Then, put the cited references including citation number, title, and URL at the end of the answer.
Each reference should be in a new line in the markdown format like this:

[1] Title - URL
[2] Title - URL
...
""",
        ),
        MessagesPlaceholder("chat_history"),
        (
            "human",
            """Query: {user_query} 
----
Short answer: {short_answer}
----
Context: {context}""",
        ),
    ]
)
# Prompt that expands one user query into up to three short, keyword-style
# search queries.  The model must reply with a plain JSON-compatible list of
# strings because the output is fed straight into JsonOutputParser.
# (Fixes the malformed example list — the second item was missing its opening
# quote, showing the model invalid JSON — plus typos "upto", "privde",
# "Orignal".)
query_context_expansion_prompt = """
For a given query and context(if provided), expand it with related questions and search the web for answers.
Try to understand the purpose of the query and expand with up to three related questions
to provide an answer to the original query.
Note that it's for keyword-based search engines, so it should be short and concise.

Please write in Python LIST format like this:
["number of people in France?", "How many people in France?", "France population"]

---
Context: {context}
----
History: {chat_history}
---
Original query: {query}
"""
# Define your desired data structure.
class List(BaseModel):
    """Schema for the query-expansion output: a list of search-query strings.

    The original body was the bare expression ``list[str]``, which declared
    no field at all, leaving the model with an empty schema.  The class name
    shadows the builtin-style ``List`` but is kept for backward compatibility.
    """

    queries: list[str] = Field(default_factory=list, description="Expanded search queries")
def query_context_expansion(query, chat_history, context=None):
    """Expand *query* into up to three related search queries.

    Renders the expansion prompt, runs it through the LLM, and parses the
    reply as JSON.  Retries up to three times on any failure (LLM error or
    unparseable output) and returns an empty list if all attempts fail.

    Args:
        query: The user's original query string.
        chat_history: Prior conversation messages, interpolated into the prompt.
        context: Optional search-result text from an earlier pass.

    Returns:
        A list of query strings, or [] when every attempt failed.
    """
    # Parser turns the model's JSON reply into Python objects.
    parser = JsonOutputParser(pydantic_object=List)

    prompt = PromptTemplate(
        template=query_context_expansion_prompt,
        # Declare every placeholder the template actually uses; the original
        # omitted "chat_history".
        input_variables=["query", "chat_history", "context"],
    )

    chain = prompt | llm | parser

    for attempt in range(3):
        try:
            return chain.invoke(
                {"query": query, "chat_history": chat_history, "context": context}
            )
        except Exception:
            # LLM output is not guaranteed to be valid JSON; warn and retry.
            st.warning(f"Attempt {attempt + 1} failed. Retrying...")

    st.error("All attempts failed. Returning empty list.")
    return []
def get_short_search(user_query, context, chat_history):
    """Stream a terse (one-word / one-sentence) answer for *user_query*.

    Builds the short-answer chain and returns its streaming iterator so the
    caller can render tokens as they arrive.
    """
    pipeline = short_answer_prompt | llm | StrOutputParser()
    payload = {
        "context": context,
        "chat_history": chat_history,
        "user_query": user_query,
    }
    return pipeline.stream(payload)
def get_search_desc(user_query, short_answer, context, chat_history):
    """Stream the detailed, cited explanation that expands on *short_answer*.

    Returns the chain's streaming iterator; the caller renders it token by
    token alongside the already-shown short answer.
    """
    pipeline = search_prompt | llm | StrOutputParser()
    payload = {
        "context": context,
        "chat_history": chat_history,
        "user_query": user_query,
        "short_answer": short_answer,
    }
    return pipeline.stream(payload)
def search(query, chat_history, context=None):
    """Expand *query* and run one combined DuckDuckGo search.

    Args:
        query: The user's query string.
        chat_history: Conversation messages passed through to query expansion.
        context: Optional results from a previous search pass.

    Returns:
        The raw search-results string from DuckDuckGo, or "" when query
        expansion produced nothing.  (The original returned [] here while
        ddg_search.invoke returns a str, so the caller's `r1 + r2`
        concatenation raised TypeError whenever only one pass succeeded;
        an empty string keeps the return type consistent and stays falsy.)
    """
    with st.status("Extending query with context to related questions..."):
        expanded_queries = query_context_expansion(query, chat_history, context)
        st.write(expanded_queries)

    if not expanded_queries:
        return ""

    # Merge the expanded queries into a single OR query so the search
    # engine is only hit once per pass.
    or_merged_search_query = " OR ".join(expanded_queries)
    with st.spinner(f"Searching for '{or_merged_search_query}'..."):
        return ddg_search.invoke(or_merged_search_query)
# --- Chat UI -----------------------------------------------------------------
# Conversation history lives in Streamlit session state so it survives the
# script rerun Streamlit performs on every interaction.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the stored conversation on each rerun.
for message in st.session_state.messages:
    speaker = "AI" if isinstance(message, AIMessage) else "Human"
    with st.chat_message(speaker):
        st.markdown(message.content)

q = "How to use residence parking permit in palo alto?"
if prompt := st.chat_input(q):
    # Record and echo the user's turn.
    st.session_state.messages.append(HumanMessage(content=prompt))
    with st.chat_message("user"):
        st.markdown(prompt)

    # Two-pass search: results of the first pass seed the second as context.
    r1 = search(prompt, st.session_state.messages)
    r2 = search(prompt, st.session_state.messages, str(r1)[:MAX_TOKENS])

    # Combine both passes and clip to the character budget.
    context = str(r1 + r2)[:MAX_TOKENS]

    with st.status("Search Results:"):
        st.write(context)

    with st.chat_message("assistant"):
        # Stream the quick answer first, then the detailed explanation.
        short_answer = st.write_stream(
            get_short_search(prompt, context, st.session_state.messages)
        )
        desc = st.write_stream(
            get_search_desc(prompt, short_answer, context, st.session_state.messages)
        )
        st.session_state.messages.append(AIMessage(content=short_answer + desc))