From 8c4d31e4d589847e2d6b2ead5777138bd308158b Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Wed, 24 Dec 2025 12:10:52 +0800 Subject: [PATCH] feat(workflow): implement a workflow node for knowledge base retrieval --- .../vdb/elasticsearch/elasticsearch_vector.py | 29 +++--- api/app/core/workflow/nodes/__init__.py | 4 +- api/app/core/workflow/nodes/configs.py | 4 +- .../core/workflow/nodes/knowledge/__init__.py | 4 + .../core/workflow/nodes/knowledge/config.py | 38 ++++++++ api/app/core/workflow/nodes/knowledge/node.py | 97 +++++++++++++++++++ api/app/core/workflow/nodes/node_factory.py | 6 +- api/app/core/workflow/nodes/start/node.py | 33 +++---- api/app/repositories/knowledge_repository.py | 2 +- api/app/services/knowledge_service.py | 2 +- 10 files changed, 179 insertions(+), 40 deletions(-) create mode 100644 api/app/core/workflow/nodes/knowledge/__init__.py create mode 100644 api/app/core/workflow/nodes/knowledge/config.py create mode 100644 api/app/core/workflow/nodes/knowledge/node.py diff --git a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py index 176f996a..198d1473 100644 --- a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py @@ -740,8 +740,9 @@ class ElasticSearchVector(BaseVector): self._client.indices.create(index=self._collection_name, body=index_mapping) -class ElasticSearchVectorFactory(ABC): - def init_vector(self, knowledge: Knowledge) -> ElasticSearchVector: +class ElasticSearchVectorFactory: + @staticmethod + def init_vector(knowledge: Knowledge) -> ElasticSearchVector: collection_name = f"Vector_index_{knowledge.id}_Node" # Use regular Elasticsearch with config values @@ -763,17 +764,17 @@ class ElasticSearchVectorFactory(ABC): } ) - if knowledge.embedding and knowledge.reranker: - return ElasticSearchVector( - index_name=collection_name, - config=ElasticSearchConfig(**config_dict), - embedding_config=knowledge.embedding.api_keys[0], - reranker_config=knowledge.reranker.api_keys[0] - ) - else: - if knowledge.embedding is None: - raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}") - if knowledge.reranker is None: - raise ValueError(f"reranker_id config error: {str(knowledge.reranker_id)}") + if knowledge.embedding is None: + raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}") + if knowledge.reranker is None: + raise ValueError(f"reranker_id config error: {str(knowledge.reranker_id)}") + + return ElasticSearchVector( + index_name=collection_name, + config=ElasticSearchConfig(**config_dict), + embedding_config=knowledge.embedding.api_keys[0], + reranker_config=knowledge.reranker.api_keys[0] + ) + diff --git a/api/app/core/workflow/nodes/__init__.py b/api/app/core/workflow/nodes/__init__.py index 1d00532e..fa5a5a2b 100644 --- a/api/app/core/workflow/nodes/__init__.py +++ b/api/app/core/workflow/nodes/__init__.py @@ -9,7 +9,7 @@ from app.core.workflow.nodes.assigner import AssignerNode from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.end import EndNode from app.core.workflow.nodes.if_else import IfElseNode -# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode +from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode from app.core.workflow.nodes.start import StartNode @@ -26,6 +26,6 @@ __all__ = [ "EndNode", "NodeFactory", "WorkflowNode", - # "KnowledgeRetrievalNode", + "KnowledgeRetrievalNode", "AssignerNode", ] diff --git a/api/app/core/workflow/nodes/configs.py b/api/app/core/workflow/nodes/configs.py index ecded070..e9f102f0 100644 --- a/api/app/core/workflow/nodes/configs.py +++ b/api/app/core/workflow/nodes/configs.py @@ -14,7 +14,7 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig from app.core.workflow.nodes.agent.config import AgentNodeConfig from app.core.workflow.nodes.transform.config import TransformNodeConfig from app.core.workflow.nodes.if_else.config import IfElseNodeConfig -# from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig +from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig from app.core.workflow.nodes.assigner.config import AssignerNodeConfig __all__ = [ @@ -30,6 +30,6 @@ __all__ = [ "AgentNodeConfig", "TransformNodeConfig", "IfElseNodeConfig", - # "KnowledgeRetrievalNodeConfig", + "KnowledgeRetrievalNodeConfig", "AssignerNodeConfig", ] diff --git a/api/app/core/workflow/nodes/knowledge/__init__.py b/api/app/core/workflow/nodes/knowledge/__init__.py new file mode 100644 index 00000000..25d0f00b --- /dev/null +++ b/api/app/core/workflow/nodes/knowledge/__init__.py @@ -0,0 +1,4 @@ +from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig +from app.core.workflow.nodes.knowledge.node import KnowledgeRetrievalNode + +__all__ = ["KnowledgeRetrievalNode", "KnowledgeRetrievalNodeConfig"] diff --git a/api/app/core/workflow/nodes/knowledge/config.py b/api/app/core/workflow/nodes/knowledge/config.py new file mode 100644 index 00000000..530116ff --- /dev/null +++ b/api/app/core/workflow/nodes/knowledge/config.py @@ -0,0 +1,38 @@ +from uuid import UUID + +from pydantic import Field + +from app.core.workflow.nodes.base_config import BaseNodeConfig +from app.schemas.chunk_schema import RetrieveType + + +class KnowledgeRetrievalNodeConfig(BaseNodeConfig): + query: str = Field( + ..., + description="Search query string" + ) + + kb_ids: list[UUID] = Field( + ..., + description="Knowledge base IDs" + ) + + similarity_threshold: float = Field( + default=0.2, + description="Knowledge base similarity threshold" + ) + + vector_similarity_weight: float = Field( + default=0.3, + description="Knowledge base vector similarity weight" + ) + + top_k: int = Field( + default=4, + description="Knowledge base top k" + ) + + retrieve_type: RetrieveType = Field( + default=RetrieveType.PARTICIPLE, + description="Retrieve type" + ) diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py new file mode 100644 index 00000000..72e8750f --- /dev/null +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -0,0 +1,97 @@ +import logging +from typing import Any + +from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory +from app.core.workflow.nodes.base_node import BaseNode, WorkflowState +from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig +from app.db import get_db +from app.models import knowledge_model, knowledgeshare_model +from app.repositories import knowledge_repository +from app.schemas.chunk_schema import RetrieveType +from app.services import knowledge_service, knowledgeshare_service + +logger = logging.getLogger(__name__) + + +class KnowledgeRetrievalNode(BaseNode): + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + super().__init__(node_config, workflow_config) + self.typed_config = KnowledgeRetrievalNodeConfig(**self.config) + + async def execute(self, state: WorkflowState) -> Any: + query = self._render_template(self.typed_config.query, state) + db_gen = get_db() + db = next(db_gen) + try: + filters = [ + knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids), + knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private, + knowledge_model.Knowledge.chunk_num > 0, + knowledge_model.Knowledge.status == 1 + ] + existing_ids = knowledge_repository.get_chunked_knowledgeids( + db=db, + filters=filters + ) + filters = [ + knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids), + knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share, + knowledge_model.Knowledge.chunk_num > 0, + knowledge_model.Knowledge.status == 1 + ] + share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids( + db=db, + filters=filters + ) + if share_ids: + filters = [ + knowledgeshare_model.KnowledgeShare.target_kb_id.in_(self.typed_config.kb_ids) + ] + items = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id( + db=db, + filters=filters + ) + existing_ids.extend(items) + + if not existing_ids: + raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.") + + kb_id = existing_ids[0] + uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids] + indices = ",".join(uuid_strs) + + db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id) + if not db_knowledge: + raise RuntimeError("The knowledge base does not exist or access is denied.") + + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + + match self.typed_config.retrieve_type: + case RetrieveType.PARTICIPLE: + rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, + indices=indices, + score_threshold=self.typed_config.similarity_threshold) + return [chunk.model_dump() for chunk in rs] + case RetrieveType.SEMANTIC: + rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, + indices=indices, + score_threshold=self.typed_config.vector_similarity_weight) + return [chunk.model_dump() for chunk in rs] + case _: + rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, + indices=indices, + score_threshold=self.typed_config.vector_similarity_weight) + rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, + indices=indices, + score_threshold=self.typed_config.similarity_threshold) + # Efficient deduplication + seen_ids = set() + unique_rs = [] + for doc in rs1 + rs2: + if doc.metadata["doc_id"] not in seen_ids: + seen_ids.add(doc.metadata["doc_id"]) + unique_rs.append(doc) + rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k) + return [chunk.model_dump() for chunk in rs] + finally: + next(db_gen) diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index 93364083..2ae31d4d 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -7,7 +7,7 @@ import logging from typing import Any, Union -# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode +from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode from app.core.workflow.nodes.agent import AgentNode from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.end import EndNode @@ -29,7 +29,7 @@ WorkflowNode = Union[ AgentNode, TransformNode, AssignerNode, - # KnowledgeRetrievalNode, + KnowledgeRetrievalNode, ] @@ -47,7 +47,7 @@ class NodeFactory: NodeType.AGENT: AgentNode, NodeType.TRANSFORM: TransformNode, NodeType.IF_ELSE: IfElseNode, - # NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, + NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, NodeType.ASSIGNER: AssignerNode, } diff --git a/api/app/core/workflow/nodes/start/node.py b/api/app/core/workflow/nodes/start/node.py index 0acf04b0..7c3a2fca 100644 --- a/api/app/core/workflow/nodes/start/node.py +++ b/api/app/core/workflow/nodes/start/node.py @@ -23,7 +23,7 @@ class StartNode(BaseNode): 注意:变量的验证和默认值处理由 Executor 在初始化时完成。 """ - + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): """初始化 Start 节点 @@ -32,10 +32,10 @@ class StartNode(BaseNode): workflow_config: 工作流配置 """ super().__init__(node_config, workflow_config) - + # 解析并验证配置 self.typed_config = StartNodeConfig(**self.config) - + async def execute(self, state: WorkflowState) -> dict[str, Any]: """执行 start 节点业务逻辑 @@ -48,13 +48,13 @@ class StartNode(BaseNode): 包含系统参数、会话变量和自定义变量的字典 """ logger.info(f"节点 {self.node_id} (Start) 开始执行") - + # 创建变量池实例(在方法内复用) pool = self.get_variable_pool(state) - + # 处理自定义变量(传入 pool 避免重复创建) custom_vars = self._process_custom_variables(pool) - + # 返回业务数据(包含自定义变量) result = { "message": pool.get("sys.message"), @@ -64,14 +64,14 @@ class StartNode(BaseNode): "user_id": pool.get("sys.user_id"), **custom_vars # 自定义变量作为节点输出的一部分 } - + logger.info( f"节点 {self.node_id} (Start) 执行完成," f"输出了 {len(custom_vars)} 个自定义变量" ) - + return result - + def _process_custom_variables(self, pool) -> dict[str, Any]: """处理自定义变量 @@ -88,34 +88,33 @@ class StartNode(BaseNode): """ # 获取输入数据中的自定义变量 input_variables = pool.get("sys.input_variables", default={}) - + processed = {} - + # 遍历配置的变量定义 for var_def in self.typed_config.variables: var_name = var_def.name - + # 检查变量是否存在 if var_name in input_variables: # 使用用户提供的值 processed[var_name] = input_variables[var_name] - + elif var_def.required: # 必需变量缺失 raise ValueError( f"缺少必需的输入变量: {var_name}" + (f" ({var_def.description})" if var_def.description else "") ) - + elif var_def.default is not None: # 使用默认值 processed[var_name] = var_def.default logger.debug( f"变量 '{var_name}' 使用默认值: {var_def.default}" ) - + return processed - def _extract_input(self, state: WorkflowState) -> dict[str, Any]: """提取输入数据(用于记录) @@ -127,7 +126,7 @@ class StartNode(BaseNode): 输入数据字典 """ pool = self.get_variable_pool(state) - + return { "execution_id": pool.get("sys.execution_id"), "conversation_id": pool.get("sys.conversation_id"), diff --git a/api/app/repositories/knowledge_repository.py b/api/app/repositories/knowledge_repository.py index 5d4946fa..b7908cb0 100644 --- a/api/app/repositories/knowledge_repository.py +++ b/api/app/repositories/knowledge_repository.py @@ -52,7 +52,7 @@ def get_knowledges_paginated( raise -def get_chunded_knowledgeids( +def get_chunked_knowledgeids( db: Session, filters: list ) -> list: diff --git a/api/app/services/knowledge_service.py b/api/app/services/knowledge_service.py index b9d97c29..cf47fd4f 100644 --- a/api/app/services/knowledge_service.py +++ b/api/app/services/knowledge_service.py @@ -45,7 +45,7 @@ def get_chunded_knowledgeids( business_logger.debug(f"Query the list of vectorized knowledge base IDs: username={current_user.username}") try: - items = knowledge_repository.get_chunded_knowledgeids( + items = knowledge_repository.get_chunked_knowledgeids( db=db, filters=filters )