-
Notifications
You must be signed in to change notification settings - Fork 65
/
Copy pathapp.py
585 lines (518 loc) · 21 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@author:quincy qiang
@license: Apache License
@file: app.py
@time: 2024/05/21
@contact: [email protected]
"""
import sys
sys.path.append(".")
import os
import shutil
import time
import gradio as gr
import loguru
import pandas as pd
from trustrag.applications.rag_openai import RagApplication, ApplicationConfig
from trustrag.modules.reranker.bge_reranker import BgeRerankerConfig
from trustrag.modules.retrieval.dense_retriever import DenseRetrieverConfig
from datetime import datetime
import pytz
# ========================== Config Start====================
# The chat model and the embedding retriever talk to the same endpoint with the
# same credential, so both values are resolved once and reused below.
# Each setting can be overridden through an environment variable; the fallbacks
# are the values this file has always used, so existing setups keep working.
# NOTE(review): the fallback API key is committed to source control. It should
# be rotated, and the literal removed once deployments set TRUSTRAG_API_KEY.
_API_BASE_URL = os.getenv("TRUSTRAG_BASE_URL", "https://www.dmxapi.com/v1")
_API_KEY = os.getenv(
    "TRUSTRAG_API_KEY",
    "sk-gDbFoQAYz9pwqBsH0aPA1H8DN9s0B9F3vWNjjPcijRBFjk7f",
)

app_config = ApplicationConfig()
# Directory that uploaded documents are copied into and listed from.
app_config.docs_path = os.getenv(
    "TRUSTRAG_DOCS_PATH", r"H:\Projects\TrustRAG\data\docs"
)
app_config.base_url = _API_BASE_URL
app_config.api_key = _API_KEY
app_config.model_name = "gpt-4o-all"
retriever_config = DenseRetrieverConfig(
    dim=3072,
    base_url=_API_BASE_URL,
    api_key=_API_KEY,
    embedding_model_name='text-embedding-3-large',
    index_path=os.getenv(
        "TRUSTRAG_INDEX_PATH",
        r'H:\Projects\TrustRAG\examples\retrievers\dense_cache',
    ),
)
# rerank_config = BgeRerankerConfig(
# model_name_or_path=r"H:\pretrained_models\mteb\bge-reranker-large"
# )
app_config.retriever_config = retriever_config
# app_config.rerank_config = rerank_config
application = RagApplication(app_config)
# Runs at import time, before the UI is built.
application.init_vector_store()
# ========================== Config End====================
# Timezone used for every timestamp shown in the UI (Beijing time).
beijing_tz = pytz.timezone("Asia/Shanghai")
# File names that are never uploaded or listed.
IGNORE_FILE_LIST = [".DS_Store"]
class CustomUploadFile:
    """Track one uploaded file together with the timing and status of its ingestion.

    The upload workflow reads these fields directly:
    ``file`` (the upload object; its ``name`` is used as the source path),
    ``file_name`` (base name of ``file.name``), ``start_time`` / ``end_time``
    (datetimes in ``beijing_tz``), ``duration`` (seconds), ``state`` and
    ``finished``.
    """

    _TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self, file):
        self.file = file  # keep the whole upload object, not just its path
        self.file_name = os.path.basename(file.name)
        self.start_time = datetime.now(beijing_tz)
        self.end_time = None
        self.duration = None
        self.state = None
        self.finished = False

    def update_process_duration(self):
        """Refresh ``end_time`` and ``duration`` unless finished; return the duration."""
        if not self.finished:
            self.end_time = datetime.now(beijing_tz)
            self.duration = (self.end_time - self.start_time).total_seconds()
        return self.duration

    def update_state(self, state):
        """Store the latest processing state."""
        self.state = state

    def is_finished(self):
        """Mark the file as finished (despite the name, this sets the flag)."""
        self.finished = True

    def __info__(self):
        """Return ``[start, end, duration, state]`` for the status table.

        ``end`` is an empty string while no end time has been recorded, which
        happens when a file fails before ``update_process_duration`` ever ran.
        Formatting ``None`` used to raise ``AttributeError`` in that case.
        """
        if self.end_time is None:
            end_text = ""
        else:
            end_text = self.end_time.strftime(self._TIME_FORMAT)
        return [
            self.start_time.strftime(self._TIME_FORMAT),
            end_text,
            self.duration,
            self.state,
        ]
def upload_files(
    upload_files,
    chunk_size,
    chunk_overlap,
    upload_index,
):
    """Validate an upload event and stream progress from ``upload_knowledge``.

    This is a generator: every yielded value is a two-element list holding an
    update for the status table and an update for the status text box.

    Args:
        upload_files: Files delivered by the upload component; may be empty or None.
        chunk_size: Forwarded to ``upload_knowledge``.
        chunk_overlap: Forwarded to ``upload_knowledge``.
        upload_index: Forwarded to ``upload_knowledge`` as ``index_name``.
    """
    if not upload_files:
        # The message has to be yielded. A `return <value>` inside a generator
        # only ends the iteration, so nothing would ever reach the UI.
        yield [
            gr.update(visible=False),
            gr.update(
                visible=True,
                value="No file selected. Please choose at least one file.",
            ),
        ]
        return
    yield from upload_knowledge(
        upload_files=upload_files,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        index_name=upload_index,
    )
def upload_knowledge(
    upload_files,
    chunk_size,
    chunk_overlap,
    index_name
):
    """Copy uploaded files into the docs directory, index them, and stream status.

    Generator used as an event handler. Every yielded value is a two-element
    list: an update for the status table and an update for the status text box.

    Args:
        upload_files: Upload objects exposing a ``name`` path. ``None`` or an
            empty list is treated as "nothing to upload".
        chunk_size: Accepted for interface compatibility; not used here.
        chunk_overlap: Accepted for interface compatibility; not used here.
        index_name: Accepted for interface compatibility; not used here.
    """
    # Wrap every upload that is not on the ignore list.
    my_upload_files = [
        CustomUploadFile(file)
        for file in (upload_files or [])
        if os.path.basename(file.name) not in IGNORE_FILE_LIST
    ]
    if not my_upload_files:
        # Nothing to process: report it instead of "Upload success:None".
        yield [
            gr.update(visible=False),
            gr.update(
                visible=True,
                value="No file selected. Please choose at least one file.",
            ),
        ]
        return

    result = {"Info": ["StartTime", "EndTime", "Duration(s)", "Status"]}
    error_msg = None
    success_msg = None
    while True:
        for file in my_upload_files:
            if file.finished:
                # Already completed or failed in an earlier round. Processing it
                # again would copy and index the same document a second time.
                continue
            try:
                cache_base_dir = app_config.docs_path
                os.makedirs(cache_base_dir, exist_ok=True)
                # The upload object's name is the temporary path of the file.
                source_path = file.file.name
                dest_path = os.path.join(cache_base_dir, file.file_name)
                # Copy rather than move: the source is a temporary upload file.
                shutil.copy2(source_path, dest_path)
                response = application.add_document(dest_path)
                file.update_state(response["status"])
                file.update_process_duration()
                result[file.file_name] = file.__info__()
                if response["status"] in ["completed", "failed"]:
                    file.is_finished()
                if response["detail"]:
                    success_msg = response["detail"]
            except Exception as api_error:
                error_msg = str(api_error)
                file.update_state("failed")
                # Record the end time before building the row: __info__()
                # formats end_time, which is still None when the failure
                # happened before the first duration update.
                file.update_process_duration()
                file.is_finished()
                result[file.file_name] = file.__info__()
            # Push progress to the UI after each processed file.
            yield [
                gr.update(visible=True, value=pd.DataFrame(result)),
                gr.update(visible=False),
            ]
        if all(file.finished for file in my_upload_files):
            break
        # NOTE(review): a file stays pending only if add_document returns a
        # status other than "completed"/"failed"; it is then submitted again
        # after this pause. Confirm that add_document is safe to repeat.
        time.sleep(2)
    upload_result = f"Upload success:{success_msg}" if not error_msg else f"Upload failed: {error_msg}"
    yield [
        gr.update(visible=True, value=pd.DataFrame(result)),
        gr.update(visible=True, value=upload_result),
    ]
def clear_files():
    """Reset the upload status widgets: hide both and empty their values.

    Generator yielding a single two-element list (status table update,
    status text update).
    """
    hidden_table = gr.update(visible=False, value=pd.DataFrame())
    hidden_status = gr.update(visible=False, value="")
    yield [hidden_table, hidden_status]
def get_file_info(file_path):
    """Collect the display metadata of one file on disk.

    Returns a dict with the keys ``Filename``, ``Size(KB)``, ``CreateTime``,
    ``UpdateTime`` and ``FileType`` (lower-cased extension including the dot).
    Timestamps are rendered in ``beijing_tz``. ``CreateTime`` is taken from
    ``st_ctime``, whose meaning depends on the platform.
    """
    file_stats = os.stat(file_path)

    def _render(timestamp):
        # Convert an epoch timestamp to a formatted local-time string.
        return datetime.fromtimestamp(timestamp, beijing_tz).strftime("%Y-%m-%d %H:%M:%S")

    _, extension = os.path.splitext(file_path)
    size_kb = round(file_stats.st_size / 1024, 2)
    created = _render(file_stats.st_ctime)
    modified = _render(file_stats.st_mtime)
    return {
        "Filename": os.path.basename(file_path),
        "Size(KB)": size_kb,
        "CreateTime": created,
        "UpdateTime": modified,
        "FileType": extension.lower(),
    }
def list_documents():
    """Build a table describing every regular file in the docs directory.

    Returns an empty DataFrame when the directory does not exist. Names in
    ``IGNORE_FILE_LIST`` and entries that are not regular files are skipped.
    """
    if not os.path.exists(app_config.docs_path):
        return pd.DataFrame()
    candidate_paths = (
        os.path.join(app_config.docs_path, entry)
        for entry in os.listdir(app_config.docs_path)
        if entry not in IGNORE_FILE_LIST
    )
    rows = [get_file_info(path) for path in candidate_paths if os.path.isfile(path)]
    return pd.DataFrame(rows)
def refresh_file_list():
    """Return an update carrying a freshly scanned document table."""
    current_documents = list_documents()
    return gr.update(value=current_documents)
def delete_selected_file(selected_file_name):
    """Delete the selected document, rebuild the vector store, and refresh the list.

    Args:
        selected_file_name: Name of a file inside the docs directory, or a
            falsy value when nothing is selected.

    Returns:
        A pair ``(file table update, delete button update)``. The table always
        carries the current directory listing and the button is always hidden.
    """
    if not selected_file_name:
        gr.Info("Please select a file first")
    else:
        try:
            file_path = os.path.join(app_config.docs_path, selected_file_name)
            if os.path.exists(file_path):
                os.remove(file_path)
                application.init_vector_store()
                # Announce success only after the deletion and the rebuild
                # have both actually succeeded.
                gr.Info(f"Successfully deleted file and rebuild faiss index: {selected_file_name}")
            else:
                gr.Warning(f"File not found: {selected_file_name}")
        except Exception as e:
            # gr.Error is an exception type and shows nothing unless raised.
            # Raising would skip the table refresh, so report the failure as a
            # warning toast and still return the refreshed listing.
            gr.Warning(f"Error deleting file: {str(e)}")
    return gr.update(value=list_documents()), gr.update(visible=False)
def get_file_chunks(file_path, chunk_size):
    """Split a UTF-8 text file into chunks and describe them in a table.

    The table has one row per chunk with the columns ``Chunk ID`` (1-based),
    ``Content`` and ``Length``. Any error while reading or chunking is printed
    and an empty DataFrame is returned instead.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as handle:
            text = handle.read()
        pieces = application.tc.get_chunks([text],chunk_size)
        rows = [
            {"Chunk ID": number, "Content": piece, "Length": len(piece)}
            for number, piece in enumerate(pieces, 1)
        ]
        return pd.DataFrame(rows)
    except Exception as e:
        print(f"Error reading file: {str(e)}")
        return pd.DataFrame()
def on_file_select(files_df, chunk_size, evt: gr.SelectData):
    """React to a click in the document table.

    Returns four values: the selected file name, a "Selected file" message,
    an update that shows the delete button, and an update showing the file's
    chunks. When the clicked cell is falsy or the row cannot be resolved, the
    result is ``(None, None, <hide delete button>, <empty DataFrame>)``.
    """

    def _nothing_selected():
        # Fresh objects on every call, exactly like the inline literals.
        return None, None, gr.update(visible=False), pd.DataFrame()

    if not evt.value:
        return _nothing_selected()
    try:
        records = files_df.to_dict('records')
        row = records[evt.index[0]]
        selected_filename = row['Filename']
        file_path = os.path.join(app_config.docs_path, selected_filename)
        # Chunk the selected file with the current slider value.
        chunks_df = get_file_chunks(file_path, chunk_size)
        info_text = f"Selected file: {selected_filename}"
        show_delete = gr.update(visible=True)
        show_chunks = gr.update(visible=True,value=chunks_df)
        return selected_filename, info_text, show_delete, show_chunks
    except (KeyError, IndexError):
        return _nothing_selected()
def clear_session():
    """Return the pair that resets the chat: an empty string and ``None``."""
    cleared_chat, cleared_state = '', None
    return cleared_chat, cleared_state
def shorten_label(text, max_length=10):
    """Abbreviate ``text`` to its first and last ``max_length`` characters.

    Text no longer than ``2 * max_length`` is returned unchanged; longer text
    becomes ``<head>...<tail>``.
    """
    if len(text) <= 2 * max_length:
        return text
    head = text[:max_length]
    tail = text[-max_length:]
    return "...".join((head, tail))
def predict(question,
            large_language_model,
            embedding_model,
            top_k,
            use_web,
            use_pattern,
            history=None):
    """Answer one question and produce every value the chat tab displays.

    Args:
        question: The user's question.
        large_language_model: Selected in the UI; not used by this function.
        embedding_model: Selected in the UI; not used by this function.
        top_k: Number of documents requested from ``application.chat`` in RAG mode.
        use_web: ``'Use'`` enables the web search; any other value skips it.
        use_pattern: ``'Only LLM'`` asks the model directly; any other value
            runs the RAG pipeline.
        history: List of ``(question, answer)`` pairs, extended in place.
            A new list is created when ``None``.

    Returns:
        A 10-tuple: ``''`` (clears the input box), the history twice (chatbot
        and state), the evidence text, the rewritten query (``''`` in
        Only-LLM mode), and exactly five checkbox components.
    """
    # The UI wires exactly five judge checkboxes as outputs, so this function
    # must always return five of them.
    judge_slots = 5

    def _hidden_checkbox():
        # Placeholder for a judge slot that has no document to show.
        return gr.Checkbox(value=False, visible=False, interactive=False)

    loguru.logger.info("User Question:" + question)
    if history is None:
        history = []

    # Optional web search; its text is shown alongside the other evidence.
    web_content = ''
    if use_web == 'Use':
        loguru.logger.info("Use Web Search")
        results = application.web_searcher.retrieve(query=question, top_k=5)
        web_content = "".join(
            search_result['title'] + " " + search_result['body'] + "\n"
            for search_result in results
        )

    search_text = ''
    if use_pattern == 'Only LLM':
        loguru.logger.info('Only LLM Mode:')
        system_prompt = "You are a helpful assistant."
        user_input = [
            {"role": "user", "content": question}
        ]
        # The second returned value (a token count) is not needed here.
        result, _ = application.llm.chat(system=system_prompt, history=user_input)
        history.append((question, result))
        search_text += web_content
        # No documents are judged in this mode: all slots stay hidden.
        checkboxes = [_hidden_checkbox() for _ in range(judge_slots)]
        return ('', history, history, search_text, '', *checkboxes)

    loguru.logger.info('RAG Mode:')
    response, _, contents, rewrite_query = application.chat(
        question=question,
        top_k=top_k,
    )
    history.append((question, response))

    # Render every retrieved passage with its score.
    for idx, source in enumerate(contents):
        sep = f'----------【搜索结果{idx + 1}:】---------------\n'
        search_text += f'{sep}\n{source["text"]}\n分数:{source["score"]:.2f}\n\n'
    if web_content:
        search_text += "----------【网络检索内容】-----------\n"
        search_text += web_content

    # One visible checkbox per retrieved document, up to the number of slots.
    checkboxes = [
        gr.Checkbox(
            value=bool(item.get('label', 0)),
            visible=True,
            label=str(idx + 1) + "." + shorten_label(item.get('text', '')),
            interactive=True,
        )
        for idx, item in enumerate(contents[:judge_slots])
    ]
    # Fewer documents than slots (for example top_k below five) used to raise
    # IndexError when the five outputs were unpacked; pad with hidden boxes.
    checkboxes.extend(
        _hidden_checkbox() for _ in range(judge_slots - len(checkboxes))
    )
    return ('', history, history, search_text, rewrite_query, *checkboxes)
# Build the whole UI: a "Corpus" tab for managing documents and a "Chat" tab
# for asking questions. Components are created in layout order, and the event
# wiring refers to the handler functions defined earlier in this file.
# NOTE(review): the nesting of the layout containers below (in particular the
# placement of the "Files" / "Directory" tabs inside the right-hand column) is
# an assumption to confirm against the intended layout.
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("""<h1><center>TrustRAG Studio</center></h1><center><font size=3></center></font>""")
    # ===== tab start
    with gr.Tab("\N{rocket} Corpus"):
        with gr.Row():
            # Left column: indexing options.
            with gr.Column(scale=2):
                upload_index = gr.Dropdown(
                    choices=["default_index"],
                    value="default_index",
                    label="\N{book} Knowledge Name",
                    elem_id="knowledge_name",
                )
                # Passed to on_file_select and to both upload handlers.
                chunk_size = gr.Slider(
                    minimum=128,
                    maximum=1024,
                    value=128,
                    step=64,
                    label="\N{GEAR} Chunk Size",
                    info="Split Document With The Chunk Size",
                    interactive=True
                )
                chunk_overlap = gr.Slider(
                    minimum=0,
                    maximum=128,
                    value=0,
                    step=1,
                    label="\N{GEAR} Chunk Overlap",
                    info="Chunk Overlap Within Chunks",
                    interactive=True
                )
                # Not wired to any handler in this block.
                enable_decontextualization = gr.Checkbox(
                    label="Yes",
                    value=True,
                    info="Process with Contextual Decontextualization",
                    elem_id="enable_Decontextualization",
                    visible=True,
                )
            # Right column: document list, chunk preview and upload widgets.
            with gr.Column(scale=8):
                # Holds the file name chosen in the table; read by the delete handler.
                selected_file = gr.State(None)
                files_df = gr.DataFrame(
                    value=list_documents(),
                    label="Current Documents",
                    interactive=False,
                )
                selected_file_info = gr.Markdown(visible=True)
                with gr.Row():
                    refresh_btn = gr.Button("Refresh Files List")
                    delete_btn = gr.Button("Delete Selected File", visible=False)  # Initially hidden delete button
                with gr.Row():
                    # Add chunks display DataFrame
                    chunks_display = gr.DataFrame(
                        label="Document Chunks",
                        value=pd.DataFrame(),
                        visible=False,
                        interactive=False
                    )
                refresh_btn.click(fn=refresh_file_list, outputs=[files_df])
                # Selecting a table cell fills the state, the info text, the
                # delete button visibility and the chunk preview.
                files_df.select(
                    fn=on_file_select,
                    inputs=[files_df, chunk_size],
                    outputs=[selected_file, selected_file_info, delete_btn, chunks_display]
                )
                # Handle file deletion
                delete_btn.click(
                    fn=delete_selected_file,
                    inputs=[selected_file],
                    outputs=[files_df, delete_btn]
                )
                with gr.Tab("Files"):
                    upload_file = gr.File(
                        label="Upload a knowledge file.", file_count="multiple"
                    )
                    upload_file_state_df = gr.DataFrame(
                        label="Upload Status Info", visible=False
                    )
                    upload_file_state = gr.Textbox(label="Upload Status", visible=False)
                with gr.Tab("Directory"):
                    upload_file_dir = gr.File(
                        label="Upload a knowledge directory.",
                        file_count="directory",
                    )
                    upload_dir_state_df = gr.DataFrame(
                        label="Upload Status Info", visible=False
                    )
                    upload_dir_state = gr.Textbox(label="Upload Status", visible=False)
                # File uploads go through upload_files, which checks for an
                # empty selection before delegating to upload_knowledge.
                upload_file.upload(
                    fn=upload_files,
                    inputs=[
                        upload_file,
                        chunk_size,
                        chunk_overlap,
                        upload_index,
                    ],
                    outputs=[upload_file_state_df, upload_file_state],
                    api_name="upload_knowledge",
                )
                upload_file.clear(
                    fn=clear_files,
                    inputs=[],
                    outputs=[upload_file_state_df, upload_file_state],
                    api_name="clear_file",
                )
                # Not referenced by any handler in this block.
                dummy_component = gr.Textbox(visible=False, value="")
                # Directory uploads call upload_knowledge directly, bypassing
                # the empty-selection check done in upload_files.
                upload_file_dir.upload(
                    fn=upload_knowledge,
                    inputs=[
                        upload_file_dir,
                        chunk_size,
                        chunk_overlap,
                        upload_index,
                    ],
                    outputs=[upload_dir_state_df, upload_dir_state],
                    api_name="upload_knowledge_dir",
                )
                upload_file_dir.clear(
                    fn=clear_files,
                    inputs=[],
                    outputs=[upload_dir_state_df, upload_dir_state],
                    api_name="clear_file_dir",
                )
    with gr.Tab("\N{fire} Chat"):
        # Conversation history shared between calls to predict.
        state = gr.State()
        with gr.Row():
            # Left column: model and retrieval settings.
            with gr.Column(scale=1):
                embedding_model = gr.Dropdown(
                    choices=[
                        "text-embedding-3-large"
                    ],
                    label="Embedding model",
                    value="text-embedding-3-large"
                )
                large_language_model = gr.Dropdown(
                    choices=[
                        "GPT-4O-ALL",
                    ],
                    label="Large Language model",
                    value="GPT-4O-ALL"
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=5,
                    step=1,
                    label="Retrieve Top-k Documents",
                    interactive=True
                )
                # predict enables the web search only for the value "Use".
                use_web = gr.Radio(
                    choices=["Use", "Not used"],
                    label="Web Search",
                    info="Do you use network search? When using it, make sure the network is normal.",
                    value="Not used",
                    interactive=True
                )
                # predict treats 'Only LLM' specially; any other value runs RAG.
                use_pattern = gr.Radio(
                    choices=[
                        'Only LLM',
                        'RAG',
                    ],
                    label="Chat Mode",
                    value='RAG',
                    interactive=True
                )
                # Not passed to any handler in this block.
                knowledge_name = gr.Dropdown(
                    choices=[
                        "default_index",
                    ],
                    label="Knowledge Name",
                    value="default_index",
                    interactive=True
                )
            # Middle column: conversation, input box and buttons.
            with gr.Column(scale=4):
                with gr.Row():
                    chatbot = gr.Chatbot(label='TrustRAG Application').style(height=650)
                with gr.Row():
                    message = gr.Textbox(label='Please enter a question')
                with gr.Row():
                    clear_history = gr.Button("🧹 Clear")
                    send = gr.Button("🚀 Send")
                with gr.Row():
                    gr.Markdown(
                        """>Remind:[TrustRAG Application](https://github.com/TrustRAG-community/TrustRAG)If you have any questions, please provide feedback in [Github Issue区](https://github.com/TrustRAG-community/TrustRAG) .""")
            # Right column: rewritten query, judge checkboxes and evidence text.
            with gr.Column(scale=2):
                with gr.Row():
                    rewrite = gr.Textbox(label='Query Reformulate')
                with gr.Row():
                    # TODO: build the judge result display using checkboxes
                    with gr.Column() as checkbox_container:
                        # gr.Markdown("Document Judge")
                        # Five slots; predict returns one value for each of them.
                        checkbox_outputs = [gr.Checkbox(visible=False, interactive=True) for _ in range(5)]
                with gr.Row():
                    search = gr.Textbox(label='Claim Attribute')
        # submit
        send.click(predict,
                   inputs=[
                       message,
                       large_language_model,
                       embedding_model,
                       top_k,
                       use_web,
                       use_pattern,
                       state
                   ],
                   outputs=[message, chatbot, state, search, rewrite] + checkbox_outputs)
        # clear
        clear_history.click(fn=clear_session,
                            inputs=[],
                            outputs=[chatbot, state],
                            queue=False)
        # enter
        # Pressing Enter in the input box triggers the same handler as "Send".
        message.submit(predict,
                       inputs=[
                           message,
                           large_language_model,
                           embedding_model,
                           top_k,
                           use_web,
                           use_pattern,
                           state
                       ],
                       outputs=[message, chatbot, state, search, rewrite] + checkbox_outputs)
# Start the app: enable the request queue with a concurrency of two, then
# launch the server with the options below.
# NOTE(review): server_name='0.0.0.0' together with share=True and debug=True
# presumably exposes the app beyond the local machine; confirm this is
# intended for the deployment.
launch_options = dict(
    server_name='0.0.0.0',
    server_port=7860,
    share=True,
    show_error=True,
    debug=True,
    enable_queue=True,
    inbrowser=False,
)
queued_demo = demo.queue(concurrency_count=2)
queued_demo.launch(**launch_options)