Commit
Merge branch 'main' of github.com:langgenius/dify
deershark committed Feb 14, 2025
2 parents 0bee51b + 33a565a commit ea3e026
Showing 47 changed files with 311 additions and 126 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/expose_service_ports.sh
@@ -9,6 +9,6 @@ yq eval '.services["pgvecto-rs"].ports += ["5431:5432"]' -i docker/docker-compos
yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-compose.yaml
yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml
yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml
yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/docker-compose.yaml
yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.yaml

echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase"
14 changes: 12 additions & 2 deletions .github/workflows/vdb-tests.yml
@@ -54,7 +54,15 @@ jobs:
- name: Expose Service Ports
run: sh .github/workflows/expose_service_ports.sh

- name: Set up Vector Stores (TiDB, Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
- name: Set up Vector Store (TiDB)
uses: hoverkraft-tech/[email protected]
with:
compose-file: docker/tidb/docker-compose.yaml
services: |
tidb
tiflash
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
uses: hoverkraft-tech/[email protected]
with:
compose-file: |
@@ -70,7 +78,9 @@ jobs:
pgvector
chroma
elasticsearch
tidb
- name: Check TiDB Ready
run: poetry run -P api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py

- name: Test Vector Stores
run: poetry run -P api bash dev/pytest/pytest_vdb.sh
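Note on the new Check TiDB Ready step: it gates the test run on TiFlash availability, but the contents of check_tiflash_ready.py are not shown in this diff. A minimal sketch of such a probe, assuming the pymysql driver and TiDB's information_schema.tiflash_replica table:

```python
# Hypothetical sketch; the real check_tiflash_ready.py is not shown in this diff.
import time

import pymysql  # assumed client; TiDB speaks the MySQL wire protocol


def tiflash_ready(host: str = "127.0.0.1", port: int = 4000) -> bool:
    conn = pymysql.connect(host=host, port=port, user="root")
    try:
        with conn.cursor() as cur:
            # TiFlash replicas report readiness via information_schema.
            cur.execute(
                "SELECT COUNT(*) FROM information_schema.tiflash_replica WHERE available = 1"
            )
            (count,) = cur.fetchone()
            return count > 0
    finally:
        conn.close()


if __name__ == "__main__":
    for _ in range(30):  # poll for up to ~60 seconds
        if tiflash_ready():
            break
        time.sleep(2)
    else:
        raise SystemExit("TiFlash did not become ready in time")
```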
1 change: 1 addition & 0 deletions .gitignore
@@ -163,6 +163,7 @@ docker/volumes/db/data/*
docker/volumes/redis/data/*
docker/volumes/weaviate/*
docker/volumes/qdrant/*
docker/tidb/volumes/*
docker/volumes/etcd/*
docker/volumes/minio/*
docker/volumes/milvus/*
2 changes: 1 addition & 1 deletion api/constants/__init__.py
@@ -15,7 +15,7 @@

if dify_config.ETL_TYPE == "Unstructured":
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls"]
DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
DOCUMENT_EXTENSIONS.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
if dify_config.UNSTRUCTURED_API_URL:
DOCUMENT_EXTENSIONS.append("ppt")
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
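A side effect worth noting: the final line mirrors every entry in uppercase, so the newly added "doc" also registers as "DOC". An illustrative reduction:

```python
# Illustrative only: a reduced version of the list built above.
DOCUMENT_EXTENSIONS = ["doc", "docx"]
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
assert DOCUMENT_EXTENSIONS == ["doc", "docx", "DOC", "DOCX"]
assert "report.DOC".split(".")[-1] in DOCUMENT_EXTENSIONS
```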
4 changes: 1 addition & 3 deletions api/core/app/apps/advanced_chat/app_generator.py
@@ -140,9 +140,7 @@ def generate(
app_config=app_config,
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
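The same substitution recurs in the agent_chat and chat generators below: inputs for an existing conversation were previously taken verbatim from conversation.inputs, skipping validation; now every turn goes through _prepare_user_inputs. A reduced sketch of the behavioral change (prepare_user_inputs is a simplified stand-in for the real method):

```python
# Simplified stand-in for AppGenerator._prepare_user_inputs: validate and
# coerce the caller's inputs against the app's declared variables.
def prepare_user_inputs(raw: dict) -> dict:
    return {k: str(v) for k, v in raw.items()}  # illustrative coercion only

stored_conversation_inputs = {"name": 42}  # historical, unvalidated
raw_inputs = {"name": 42}

# Before: follow-up turns reused conversation.inputs verbatim.
inputs = stored_conversation_inputs

# After: every turn re-validates the caller's inputs, with or without
# an existing conversation.
inputs = prepare_user_inputs(raw_inputs)
assert inputs == {"name": "42"}
```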
4 changes: 1 addition & 3 deletions api/core/app/apps/agent_chat/app_generator.py
@@ -148,9 +148,7 @@ def generate(
model_conf=ModelConfigConverter.convert(app_config),
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
4 changes: 1 addition & 3 deletions api/core/app/apps/chat/app_generator.py
@@ -141,9 +141,7 @@ def generate(
model_conf=ModelConfigConverter.convert(app_config),
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs
if conversation
else self._prepare_user_inputs(
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query,
2 changes: 1 addition & 1 deletion api/core/app/task_pipeline/workflow_cycle_manage.py
@@ -842,4 +842,4 @@ def _get_workflow_node_execution(self, session: Session, node_execution_id: str)
if node_execution_id not in self._workflow_node_executions:
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
return cached_workflow_node_execution
return session.merge(cached_workflow_node_execution)
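The cached WorkflowNodeExecution may have been loaded in a session that has since closed, so returning it directly hands callers a detached ORM instance. session.merge() copies its state onto an instance owned by the live session, so lazy loads and updates behave. A runnable sketch of the pattern with a hypothetical stand-in model:

```python
from sqlalchemy import String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class NodeExecution(Base):  # hypothetical stand-in for WorkflowNodeExecution
    __tablename__ = "node_execution"
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    status: Mapped[str] = mapped_column(String(16))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(NodeExecution(id="n1", status="running"))
    session.commit()
    cached = session.get(NodeExecution, "n1")
# the session above is closed: `cached` is now detached

with Session(engine) as session:
    merged = session.merge(cached)  # reattach a copy to the live session
    merged.status = "succeeded"
    session.commit()                # flushes normally, no DetachedInstanceError
```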
2 changes: 1 addition & 1 deletion api/core/model_runtime/model_providers/tongyi/llm/llm.py
@@ -197,7 +197,7 @@ def _generate(
else:
# nothing different between chat model and completion model in tongyi
params["messages"] = self._convert_prompt_messages_to_tongyi_messages(prompt_messages)
response = Generation.call(**params, result_format="message", stream=stream, incremental_output=True)
response = Generation.call(**params, result_format="message", stream=stream, incremental_output=stream)
if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)

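Passing incremental_output=stream instead of a hard-coded True fixes the blocking path: DashScope's incremental mode emits only deltas, which is wrong when a single complete response is expected. A sketch of the resulting call shape (mirroring the diff, not a verified SDK reference):

```python
# Sketch only; assumes the dashscope SDK's Generation.call as used above.
from dashscope import Generation


def call_tongyi(params: dict, stream: bool):
    # stream=True  -> incremental_output=True: each chunk carries only new text
    # stream=False -> incremental_output=False: one complete message; the old
    #                 hard-coded True could leave blocking calls with only the
    #                 final delta
    return Generation.call(
        **params,
        result_format="message",
        stream=stream,
        incremental_output=stream,
    )
```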
8 changes: 3 additions & 5 deletions api/core/provider_manager.py
@@ -452,11 +452,9 @@ def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[L

provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
for provider_load_balancing_config in provider_load_balancing_configs:
(
provider_name_to_provider_load_balancing_model_configs_dict[
provider_load_balancing_config.provider_name
].append(provider_load_balancing_config)
)
provider_name_to_provider_load_balancing_model_configs_dict[
provider_load_balancing_config.provider_name
].append(provider_load_balancing_config)

return provider_name_to_provider_load_balancing_model_configs_dict

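The rewrite only drops a redundant pair of parentheses around the same statement; behavior is unchanged. The underlying defaultdict(list) grouping idiom:

```python
from collections import defaultdict

# Illustrative data standing in for provider load-balancing config rows.
configs = [("openai", "cfg-a"), ("anthropic", "cfg-b"), ("openai", "cfg-c")]

grouped = defaultdict(list)
for provider_name, config in configs:
    grouped[provider_name].append(config)

assert grouped == {"openai": ["cfg-a", "cfg-c"], "anthropic": ["cfg-b"]}
```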
62 changes: 40 additions & 22 deletions api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
@@ -9,6 +9,7 @@
from sqlalchemy.orm import Session, declarative_base

from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -54,14 +55,13 @@ def _table(self, dim: int) -> Table:
return Table(
self._collection_name,
self._orm_base.metadata,
Column("id", String(36), primary_key=True, nullable=False),
Column(Field.PRIMARY_KEY.value, String(36), primary_key=True, nullable=False),
Column(
"vector",
Field.VECTOR.value,
VectorType(dim),
nullable=False,
comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})",
),
Column("text", TEXT, nullable=False),
Column(Field.TEXT_KEY.value, TEXT, nullable=False),
Column("meta", JSON, nullable=False),
Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
Column(
@@ -96,6 +96,7 @@ def _create_collection(self, dimension: int):
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
tidb_dist_func = self._get_distance_func()
with Session(self._engine) as session:
session.begin()
create_statement = sql_text(f"""
@@ -104,14 +105,14 @@ def _create_collection(self, dimension: int):
text TEXT NOT NULL,
meta JSON NOT NULL,
doc_id VARCHAR(64) AS (JSON_UNQUOTE(JSON_EXTRACT(meta, '$.doc_id'))) STORED,
KEY (doc_id),
vector VECTOR<FLOAT>({dimension}) NOT NULL COMMENT "hnsw(distance={self._distance_func})",
vector VECTOR<FLOAT>({dimension}) NOT NULL,
create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
KEY (doc_id),
VECTOR INDEX idx_vector (({tidb_dist_func}(vector))) USING HNSW
);
""")
session.execute(create_statement)
# tidb vector not support 'CREATE/ADD INDEX' now
session.commit()
redis_client.set(collection_exist_cache_key, 1, ex=3600)
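Resolving the distance function up front lets the DDL declare the HNSW index inline, replacing the old comment-based hnsw(distance=...) hint. For dimension=768 with the default cosine metric, the f-string renders roughly as below; the collection name is illustrative and the leading columns collapsed in this diff stay elided:

```python
# Illustrative rendering of the statement above (collapsed lines elided).
rendered_ddl = """
CREATE TABLE IF NOT EXISTS embedding_collection (
    ...
    text TEXT NOT NULL,
    meta JSON NOT NULL,
    doc_id VARCHAR(64) AS (JSON_UNQUOTE(JSON_EXTRACT(meta, '$.doc_id'))) STORED,
    vector VECTOR<FLOAT>(768) NOT NULL,
    create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
    update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
    KEY (doc_id),
    VECTOR INDEX idx_vector ((VEC_COSINE_DISTANCE(vector))) USING HNSW
);
"""
```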

@@ -194,23 +195,30 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
)

docs = []
if self._distance_func == "l2":
tidb_func = "Vec_l2_distance"
elif self._distance_func == "cosine":
tidb_func = "Vec_Cosine_distance"
else:
tidb_func = "Vec_Cosine_distance"
tidb_dist_func = self._get_distance_func()

with Session(self._engine) as session:
select_statement = sql_text(
f"""SELECT meta, text, distance FROM (
SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
FROM {self._collection_name}
ORDER BY distance
LIMIT {top_k}
) t WHERE distance < {distance};"""
select_statement = sql_text(f"""
SELECT meta, text, distance
FROM (
SELECT
meta,
text,
{tidb_dist_func}(vector, :query_vector_str) AS distance
FROM {self._collection_name}
ORDER BY distance ASC
LIMIT :top_k
) t
WHERE distance <= :distance
""")
res = session.execute(
select_statement,
params={
"query_vector_str": query_vector_str,
"distance": distance,
"top_k": top_k,
},
)
res = session.execute(select_statement)
results = [(row[0], row[1], row[2]) for row in res]
for meta, text, distance in results:
metadata = json.loads(meta)
@@ -227,6 +235,16 @@ def delete(self) -> None:
session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
session.commit()

def _get_distance_func(self) -> str:
match self._distance_func:
case "l2":
tidb_dist_func = "VEC_L2_DISTANCE"
case "cosine":
tidb_dist_func = "VEC_COSINE_DISTANCE"
case _:
tidb_dist_func = "VEC_COSINE_DISTANCE"
return tidb_dist_func


class TiDBVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TiDBVector:
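In search_by_vector, the query values (query_vector_str, top_k, distance) now travel as bound parameters rather than being formatted into the SQL string; only the table name remains interpolated, since identifiers cannot be bound. A runnable reduction of the pattern, with SQLite standing in for TiDB:

```python
from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")  # stand-in engine for illustration
table = "scores"  # identifiers cannot be bound, so the name is interpolated

with engine.begin() as conn:
    conn.execute(text(f"CREATE TABLE {table} (distance REAL)"))
    conn.execute(text(f"INSERT INTO {table} VALUES (0.1), (0.4), (0.9)"))
    rows = conn.execute(
        # values are bound parameters, never string-formatted into the SQL
        text(f"SELECT distance FROM {table} ORDER BY distance ASC LIMIT :top_k"),
        {"top_k": 2},
    ).fetchall()

assert [row.distance for row in rows] == [0.1, 0.4]
```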
2 changes: 0 additions & 2 deletions api/core/workflow/graph_engine/entities/graph.py
@@ -590,8 +590,6 @@ def _fetch_all_node_ids_in_parallels(
start_node_id=node_id,
routes_node_ids=routes_node_ids,
)
# Exclude conditional branch nodes
and all(edge.run_condition is None for edge in reverse_edge_mapping.get(node_id, []))
):
if node_id not in merge_branch_node_ids:
merge_branch_node_ids[node_id] = []
4 changes: 2 additions & 2 deletions api/core/workflow/graph_engine/graph_engine.py
@@ -648,7 +648,7 @@ def _run_node(
retries += 1
route_node_state.node_run_result = run_result
yield NodeRunRetryEvent(
id=node_instance.id,
id=str(uuid.uuid4()),
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
@@ -663,7 +663,7 @@
start_at=retry_start_at,
)
time.sleep(retry_interval)
continue
break
route_node_state.set_finished(run_result=run_result)

if run_result.status == WorkflowNodeExecutionStatus.FAILED:
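Each NodeRunRetryEvent now carries a fresh uuid; previously every retry reused node_instance.id, so consumers could not tell attempts apart. A minimal sketch of the idea with a reduced stand-in event:

```python
import uuid
from dataclasses import dataclass


@dataclass
class NodeRunRetryEvent:  # reduced stand-in for the real event class
    id: str
    node_id: str
    retry_index: int


def emit_retries(node_id: str, max_retries: int):
    for attempt in range(1, max_retries + 1):
        # a fresh uuid per attempt keeps events distinguishable downstream
        yield NodeRunRetryEvent(
            id=str(uuid.uuid4()), node_id=node_id, retry_index=attempt
        )


events = list(emit_retries("llm-1", 3))
assert len({event.id for event in events}) == 3  # unique, unlike node_instance.id
```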
61 changes: 53 additions & 8 deletions api/core/workflow/nodes/document_extractor/node.py
@@ -107,8 +107,10 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
return _extract_text_from_plain_text(file_content)
case "application/pdf":
return _extract_text_from_pdf(file_content)
case "application/vnd.openxmlformats-officedocument.wordprocessingml.document" | "application/msword":
case "application/msword":
return _extract_text_from_doc(file_content)
case "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
return _extract_text_from_docx(file_content)
case "text/csv":
return _extract_text_from_csv(file_content)
case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" | "application/vnd.ms-excel":
@@ -142,8 +144,10 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
return _extract_text_from_yaml(file_content)
case ".pdf":
return _extract_text_from_pdf(file_content)
case ".doc" | ".docx":
case ".doc":
return _extract_text_from_doc(file_content)
case ".docx":
return _extract_text_from_docx(file_content)
case ".csv":
return _extract_text_from_csv(file_content)
case ".xls" | ".xlsx":
@@ -203,7 +207,33 @@ def _extract_text_from_pdf(file_content: bytes) -> str:

def _extract_text_from_doc(file_content: bytes) -> str:
"""
Extract text from a DOC/DOCX file.
Extract text from a DOC file.
"""
from unstructured.partition.api import partition_via_api

if not (dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY):
raise TextExtractionError("UNSTRUCTURED_API_URL and UNSTRUCTURED_API_KEY must be set")

try:
with tempfile.NamedTemporaryFile(suffix=".doc", delete=False) as temp_file:
temp_file.write(file_content)
temp_file.flush()
with open(temp_file.name, "rb") as file:
elements = partition_via_api(
file=file,
metadata_filename=temp_file.name,
api_url=dify_config.UNSTRUCTURED_API_URL,
api_key=dify_config.UNSTRUCTURED_API_KEY,
)
os.unlink(temp_file.name)
return "\n".join([getattr(element, "text", "") for element in elements])
except Exception as e:
raise TextExtractionError(f"Failed to extract text from DOC: {str(e)}") from e


def _extract_text_from_docx(file_content: bytes) -> str:
"""
Extract text from a DOCX file.
For now, supports only paragraphs and tables; extend as needed.
"""
try:
@@ -255,13 +285,13 @@ def _extract_text_from_doc(file_content: bytes) -> str:

text.append(markdown_table)
except Exception as e:
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
logger.warning(f"Failed to extract table from DOC: {e}")
continue

return "\n".join(text)

except Exception as e:
raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
raise TextExtractionError(f"Failed to extract text from DOCX: {str(e)}") from e


def _download_file_content(file: File) -> bytes:
@@ -329,14 +359,29 @@ def _extract_text_from_excel(file_content: bytes) -> str:


def _extract_text_from_ppt(file_content: bytes) -> str:
from unstructured.partition.api import partition_via_api
from unstructured.partition.ppt import partition_ppt

try:
with io.BytesIO(file_content) as file:
elements = partition_ppt(file=file)
if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
with tempfile.NamedTemporaryFile(suffix=".ppt", delete=False) as temp_file:
temp_file.write(file_content)
temp_file.flush()
with open(temp_file.name, "rb") as file:
elements = partition_via_api(
file=file,
metadata_filename=temp_file.name,
api_url=dify_config.UNSTRUCTURED_API_URL,
api_key=dify_config.UNSTRUCTURED_API_KEY,
)
os.unlink(temp_file.name)
else:
with io.BytesIO(file_content) as file:
elements = partition_ppt(file=file)
return "\n".join([getattr(element, "text", "") for element in elements])

except Exception as e:
raise TextExtractionError(f"Failed to extract text from PPT: {str(e)}") from e
raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e


def _extract_text_from_pptx(file_content: bytes) -> str:
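Both the new _extract_text_from_doc and the .ppt branch above share a pattern: write the bytes to a named temp file, call partition_via_api, then unlink. As written, the unlink is skipped when partition_via_api raises; a hedged sketch of the same flow with guaranteed cleanup (the helper name is hypothetical):

```python
import os
import tempfile

from unstructured.partition.api import partition_via_api


def partition_bytes_via_api(file_content: bytes, suffix: str, api_url: str, api_key: str):
    """Hypothetical helper mirroring the temp-file dance in the diff,
    with cleanup guaranteed even when the API call raises."""
    temp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    try:
        temp.write(file_content)
        temp.flush()
        temp.close()
        with open(temp.name, "rb") as file:
            return partition_via_api(
                file=file,
                metadata_filename=temp.name,
                api_url=api_url,
                api_key=api_key,
            )
    finally:
        os.unlink(temp.name)
```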
4 changes: 4 additions & 0 deletions api/core/workflow/utils/condition/processor.py
@@ -64,6 +64,10 @@ def process_conditions(
expected=expected_value,
)
group_results.append(result)
# Short-circuit evaluation for logical conditions
if (operator == "and" and not result) or (operator == "or" and result):
final_result = result
return input_conditions, group_results, final_result

final_result = all(group_results) if operator == "and" else any(group_results)
return input_conditions, group_results, final_result
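The early return mirrors Python's own and/or semantics: an "and" group is decided by its first False result, an "or" group by its first True, so later (potentially expensive) conditions are never evaluated. A standalone sketch of the evaluation order:

```python
from collections.abc import Callable, Iterable


def evaluate_group(conditions: Iterable[Callable[[], bool]], operator: str) -> bool:
    results = []
    for condition in conditions:
        result = condition()
        results.append(result)
        # short-circuit: the group's outcome is already decided
        if (operator == "and" and not result) or (operator == "or" and result):
            return result
    return all(results) if operator == "and" else any(results)


calls: list[str] = []


def probe(name: str, value: bool) -> Callable[[], bool]:
    def run() -> bool:
        calls.append(name)
        return value
    return run


assert evaluate_group([probe("a", False), probe("b", True)], "and") is False
assert calls == ["a"]  # "b" was never evaluated
```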