Skip to content

Commit

Permalink
refactor(retrieval_service): streamline query retrieval and enhance o…
Browse files Browse the repository at this point in the history
…utput validation
  • Loading branch information
charli117 committed Feb 11, 2025
1 parent 3b1513d commit 923c6b9
Showing 1 changed file with 23 additions and 32 deletions.
55 changes: 23 additions & 32 deletions api/core/rag/datasource/retrieval_service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import concurrent.futures
import json
from typing import Optional, cast
from typing import Any, Optional, cast

from flask import Flask, current_app
from sqlalchemy.orm import load_only
Expand Down Expand Up @@ -287,28 +287,23 @@ def format_retrieval_documents(cls, documents: list[Document]) -> list[Retrieval
if not child_chunk:
continue

segment = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == child_chunk.segment_id,
result = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id == child_chunk.segment_id,
).options(
load_only(
DocumentSegment.id,
DocumentSegment.content,
DocumentSegment.answer,
DocumentSegment.doc_metadata,
)
.options(
load_only(
DocumentSegment.id,
DocumentSegment.content,
DocumentSegment.answer,
DocumentSegment.doc_metadata,
)
)
.first()
)

if not segment:
).first()
if result is None:
continue

segment: DocumentSegment = cast(DocumentSegment, result)
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
child_chunk_detail = {
Expand Down Expand Up @@ -343,20 +338,16 @@ def format_retrieval_documents(cls, documents: list[Document]) -> list[Retrieval
if not index_node_id:
continue

segment = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
).first()
)

if not segment:
result = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
).first()
if result is None:
continue

segment = cast(DocumentSegment, segment)
segment: DocumentSegment = cast(DocumentSegment, result)
include_segment_ids.add(segment.id)
record = {
"segment": segment,
Expand Down

0 comments on commit 923c6b9

Please sign in to comment.