feat[chat]: vectorize extraction result for improved chat content
ArslanSaleem committed Oct 28, 2024
1 parent d876353 commit 7065f32
Showing 2 changed files with 153 additions and 2 deletions.
48 changes: 47 additions & 1 deletion backend/app/processing/process_queue.py
@@ -46,8 +46,9 @@ def process_step_task(
        # Initial DB operations (open and fetch relevant data)
        with SessionLocal() as db:
            process = process_repository.get_process(db, process_id)
            project_id = process.project_id
            process_step = process_repository.get_process_step(db, process_step_id)

            filename = process_step.asset.filename
            if process.status == ProcessStatus.STOPPED:
                return False  # Stop processing if the process is stopped

@@ -84,6 +85,15 @@ def process_step_task(
                output_references=data["context"],
            )

            # Vectorize extraction result so it is retrievable during chat
            try:
                vectorize_extraction_process_step(
                    project_id=project_id,
                    process_step_id=process_step_id,
                    filename=filename,
                    references=data["context"],
                )
            except Exception:
                # format_exc() returns the traceback as a string; print_exc()
                # would print to stderr and log "None" here
                logger.error(
                    f"Failed to vectorize extraction results for chat: {traceback.format_exc()}"
                )

            success = True

    except CreditLimitExceededException:
@@ -361,3 +371,39 @@ def update_process_step_status(
    process_repository.update_process_step_status(
        db, process_step, status, output=output, output_references=output_references
    )

def vectorize_extraction_process_step(project_id: int, process_step_id: int, filename: str, references: list) -> None:
    """Vectorize extraction references and store them in the project's extraction vector store."""
    field_references = {}

    # Concatenate the sources of every reference, grouped by field name
    for extraction_references in references:
        for extraction_reference in extraction_references:
            sources = extraction_reference.get("sources", [])
            if sources:
                concatenated_sources = "\n".join(sources)
                field_references.setdefault(extraction_reference["name"], "")
                field_references[extraction_reference["name"]] += (
                    "\n" + concatenated_sources
                    if field_references[extraction_reference["name"]]
                    else concatenated_sources
                )

    # Only proceed if there are references to add
    if not field_references:
        return

    # Initialize the vector store for this project's extractions
    vectorstore = ChromaDB(f"panda-etl-extraction-{project_id}")

    # One document per extracted field, keyed by "<filename> <field name>"
    docs = [f"{filename} {key}" for key in field_references]
    metadatas = [
        {
            "project_id": project_id,
            "process_step_id": process_step_id,
            "filename": filename,
            "reference": reference,
        }
        for reference in field_references.values()
    ]

    # Add documents to the vector store
    vectorstore.add_docs(docs=docs, metadatas=metadatas)
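
For reference, the `references` argument mirrors `data["context"]` from the extraction response: a list of reference groups, each a list of `{"name", "sources"}` dicts. A minimal standalone sketch of the flattening performed above (the file and field names are illustrative, not from the codebase):

# Illustrative input in the shape vectorize_extraction_process_step expects:
# one list per reference group, each holding {"name", "sources"} dicts.
references = [
    [{"name": "invoice_total", "sources": ["Total: $120", "Amount due: $120"]}],
    [{"name": "invoice_total", "sources": ["Grand total $120"]}],
]

# Per-field concatenation, mirroring the loop in the function above.
field_references = {}
for group in references:
    for ref in group:
        if ref.get("sources"):
            joined = "\n".join(ref["sources"])
            existing = field_references.get(ref["name"], "")
            field_references[ref["name"]] = f"{existing}\n{joined}" if existing else joined

docs = [f"sample.pdf {name}" for name in field_references]
print(docs)  # ['sample.pdf invoice_total']
print(field_references["invoice_total"])
# Total: $120
# Amount due: $120
# Grand total $120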
107 changes: 106 additions & 1 deletion backend/tests/processing/test_process_queue.py
@@ -1,11 +1,12 @@
from app.requests.schemas import ExtractFieldsResponse
import pytest
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch
from app.processing.process_queue import (
    handle_exceptions,
    extract_process,
    update_process_step_status,
    find_best_match_for_short_reference,
    vectorize_extraction_process_step,
)
from app.exceptions import CreditLimitExceededException
from app.models import ProcessStepStatus
@@ -180,3 +181,107 @@ def test_chroma_db_initialization(mock_extract_data, mock_chroma):

    mock_chroma.assert_called_with(f"panda-etl-{process.project_id}", similarity_threshold=3)
    assert mock_chroma.call_count >= 1

@patch('app.processing.process_queue.ChromaDB')
def test_vectorize_extraction_process_step_single_reference(mock_chroma_db):
    # Mock ChromaDB instance
    mock_vectorstore = MagicMock()
    mock_chroma_db.return_value = mock_vectorstore

    # Inputs
    project_id = 123
    process_step_id = 1
    filename = "sample_file"
    references = [
        [
            {"name": "field1", "sources": ["source1", "source2"]}
        ]
    ]

    # Call function
    vectorize_extraction_process_step(project_id, process_step_id, filename, references)

    # Expected docs and metadata added to ChromaDB
    expected_docs = ["sample_file field1"]
    expected_metadatas = [
        {
            "project_id": project_id,
            "process_step_id": process_step_id,
            "filename": filename,
            "reference": "source1\nsource2"
        }
    ]

    # Assertions
    mock_vectorstore.add_docs.assert_called_once_with(
        docs=expected_docs,
        metadatas=expected_metadatas
    )

@patch('app.processing.process_queue.ChromaDB')
def test_vectorize_extraction_process_step_multiple_references_concatenation(mock_chroma_db):
    # Mock ChromaDB instance
    mock_vectorstore = MagicMock()
    mock_chroma_db.return_value = mock_vectorstore

    # Inputs
    project_id = 456
    process_step_id = 2
    filename = "test_file"
    references = [
        [
            {"name": "field1", "sources": ["source1", "source2"]},
            {"name": "field1", "sources": ["source3"]}
        ],
        [
            {"name": "field2", "sources": ["source4"]}
        ]
    ]

    # Call function
    vectorize_extraction_process_step(project_id, process_step_id, filename, references)

    # Expected docs and metadata added to ChromaDB
    expected_docs = ["test_file field1", "test_file field2"]
    expected_metadatas = [
        {
            "project_id": project_id,
            "process_step_id": process_step_id,
            "filename": filename,
            "reference": "source1\nsource2\nsource3"
        },
        {
            "project_id": project_id,
            "process_step_id": process_step_id,
            "filename": filename,
            "reference": "source4"
        }
    ]

    # Assertions
    mock_vectorstore.add_docs.assert_called_once_with(
        docs=expected_docs,
        metadatas=expected_metadatas
    )

@patch('app.processing.process_queue.ChromaDB')
def test_vectorize_extraction_process_step_empty_sources(mock_chroma_db):
    # Mock ChromaDB instance
    mock_vectorstore = MagicMock()
    mock_chroma_db.return_value = mock_vectorstore

    # Inputs
    project_id = 789
    process_step_id = 3
    filename = "empty_sources_file"
    references = [
        [
            {"name": "field1", "sources": []}
        ]
    ]

    # Call function
    vectorize_extraction_process_step(project_id, process_step_id, filename, references)

    # add_docs must not be called when every reference has empty sources
    mock_vectorstore.add_docs.assert_not_called()
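
For context, the chat-side consumer of this collection is outside this commit. A hypothetical sketch of how a chat endpoint might read these references back; the import path, the `get_relevant_docs` call, and the result shape are placeholder assumptions, not the wrapper's confirmed API:

# Hypothetical sketch only: names below are assumptions, not confirmed API.
from app.vectorstore.chroma import ChromaDB  # assumed import path

def fetch_extraction_context(project_id: int, question: str) -> list[str]:
    # Same collection name that vectorize_extraction_process_step writes to
    vectorstore = ChromaDB(f"panda-etl-extraction-{project_id}")
    # Placeholder similarity query; the wrapper's real method may differ
    results = vectorstore.get_relevant_docs(question)
    # Each hit's metadata carries the concatenated sources for one field
    return [meta["reference"] for meta in results["metadatas"][0]]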
