feat(workflow): implement a workflow node for knowledge base retrieval

This commit is contained in:
mengyonghao
2025-12-24 12:10:52 +08:00
parent d423e80ddb
commit 8c4d31e4d5
10 changed files with 179 additions and 40 deletions

View File

@@ -740,8 +740,9 @@ class ElasticSearchVector(BaseVector):
self._client.indices.create(index=self._collection_name, body=index_mapping)
class ElasticSearchVectorFactory(ABC):
def init_vector(self, knowledge: Knowledge) -> ElasticSearchVector:
class ElasticSearchVectorFactory:
@staticmethod
def init_vector(knowledge: Knowledge) -> ElasticSearchVector:
collection_name = f"Vector_index_{knowledge.id}_Node"
# Use regular Elasticsearch with config values
@@ -763,17 +764,17 @@ class ElasticSearchVectorFactory(ABC):
}
)
if knowledge.embedding and knowledge.reranker:
return ElasticSearchVector(
index_name=collection_name,
config=ElasticSearchConfig(**config_dict),
embedding_config=knowledge.embedding.api_keys[0],
reranker_config=knowledge.reranker.api_keys[0]
)
else:
if knowledge.embedding is None:
raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}")
if knowledge.reranker is None:
raise ValueError(f"reranker_id config error: {str(knowledge.reranker_id)}")
if knowledge.embedding is None:
raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}")
if knowledge.reranker is None:
raise ValueError(f"reranker_id config error: {str(knowledge.reranker_id)}")
return ElasticSearchVector(
index_name=collection_name,
config=ElasticSearchConfig(**config_dict),
embedding_config=knowledge.embedding.api_keys[0],
reranker_config=knowledge.reranker.api_keys[0]
)

View File

@@ -9,7 +9,7 @@ from app.core.workflow.nodes.assigner import AssignerNode
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.end import EndNode
from app.core.workflow.nodes.if_else import IfElseNode
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
from app.core.workflow.nodes.start import StartNode
@@ -26,6 +26,6 @@ __all__ = [
"EndNode",
"NodeFactory",
"WorkflowNode",
# "KnowledgeRetrievalNode",
"KnowledgeRetrievalNode",
"AssignerNode",
]

View File

@@ -14,7 +14,7 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
from app.core.workflow.nodes.agent.config import AgentNodeConfig
from app.core.workflow.nodes.transform.config import TransformNodeConfig
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
# from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
__all__ = [
@@ -30,6 +30,6 @@ __all__ = [
"AgentNodeConfig",
"TransformNodeConfig",
"IfElseNodeConfig",
# "KnowledgeRetrievalNodeConfig",
"KnowledgeRetrievalNodeConfig",
"AssignerNodeConfig",
]

View File

@@ -0,0 +1,4 @@
from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
from app.core.workflow.nodes.knowledge.node import KnowledgeRetrievalNode
__all__ = ["KnowledgeRetrievalNode", "KnowledgeRetrievalNodeConfig"]

View File

@@ -0,0 +1,38 @@
from uuid import UUID
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.schemas.chunk_schema import RetrieveType
class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
    """Configuration schema for the knowledge base retrieval workflow node.

    ``query`` and ``kb_ids`` are required; every other field falls back to the
    default declared below when it is omitted from the node configuration.
    """

    # --- Required inputs -------------------------------------------------
    query: str = Field(..., description="Search query string")
    kb_ids: list[UUID] = Field(..., description="Knowledge base IDs")

    # --- Optional retrieval settings --------------------------------------
    similarity_threshold: float = Field(
        default=0.2, description="Knowledge base similarity threshold"
    )
    vector_similarity_weight: float = Field(
        default=0.3, description="Knowledge base vector similarity weight"
    )
    top_k: int = Field(default=4, description="Knowledge base top k")
    retrieve_type: RetrieveType = Field(
        default=RetrieveType.PARTICIPLE, description="Retrieve type"
    )

View File

@@ -0,0 +1,97 @@
import logging
from typing import Any
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.db import get_db
from app.models import knowledge_model, knowledgeshare_model
from app.repositories import knowledge_repository
from app.schemas.chunk_schema import RetrieveType
from app.services import knowledge_service, knowledgeshare_service
logger = logging.getLogger(__name__)
class KnowledgeRetrievalNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
async def execute(self, state: WorkflowState) -> Any:
query = self._render_template(self.typed_config.query, state)
db_gen = get_db()
db = next(db_gen)
try:
filters = [
knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids),
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private,
knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1
]
existing_ids = knowledge_repository.get_chunked_knowledgeids(
db=db,
filters=filters
)
filters = [
knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids),
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share,
knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1
]
share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids(
db=db,
filters=filters
)
if share_ids:
filters = [
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(self.typed_config.kb_ids)
]
items = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
db=db,
filters=filters
)
existing_ids.extend(items)
if not existing_ids:
raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.")
kb_id = existing_ids[0]
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids]
indices = ",".join(uuid_strs)
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id)
if not db_knowledge:
raise RuntimeError("The knowledge base does not exist or access is denied.")
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
match self.typed_config.retrieve_type:
case RetrieveType.PARTICIPLE:
rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
indices=indices,
score_threshold=self.typed_config.similarity_threshold)
return [chunk.model_dump() for chunk in rs]
case RetrieveType.SEMANTIC:
rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
indices=indices,
score_threshold=self.typed_config.vector_similarity_weight)
return [chunk.model_dump() for chunk in rs]
case _:
rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
indices=indices,
score_threshold=self.typed_config.vector_similarity_weight)
rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
indices=indices,
score_threshold=self.typed_config.similarity_threshold)
# Efficient deduplication
seen_ids = set()
unique_rs = []
for doc in rs1 + rs2:
if doc.metadata["doc_id"] not in seen_ids:
seen_ids.add(doc.metadata["doc_id"])
unique_rs.append(doc)
rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k)
return [chunk.model_dump() for chunk in rs]
finally:
next(db_gen)

View File

@@ -7,7 +7,7 @@
import logging
from typing import Any, Union
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.end import EndNode
@@ -29,7 +29,7 @@ WorkflowNode = Union[
AgentNode,
TransformNode,
AssignerNode,
# KnowledgeRetrievalNode,
KnowledgeRetrievalNode,
]
@@ -47,7 +47,7 @@ class NodeFactory:
NodeType.AGENT: AgentNode,
NodeType.TRANSFORM: TransformNode,
NodeType.IF_ELSE: IfElseNode,
# NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.ASSIGNER: AssignerNode,
}

View File

@@ -23,7 +23,7 @@ class StartNode(BaseNode):
注意:变量的验证和默认值处理由 Executor 在初始化时完成。
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
"""初始化 Start 节点
@@ -32,10 +32,10 @@ class StartNode(BaseNode):
workflow_config: 工作流配置
"""
super().__init__(node_config, workflow_config)
# 解析并验证配置
self.typed_config = StartNodeConfig(**self.config)
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""执行 start 节点业务逻辑
@@ -48,13 +48,13 @@ class StartNode(BaseNode):
包含系统参数、会话变量和自定义变量的字典
"""
logger.info(f"节点 {self.node_id} (Start) 开始执行")
# 创建变量池实例(在方法内复用)
pool = self.get_variable_pool(state)
# 处理自定义变量(传入 pool 避免重复创建)
custom_vars = self._process_custom_variables(pool)
# 返回业务数据(包含自定义变量)
result = {
"message": pool.get("sys.message"),
@@ -64,14 +64,14 @@ class StartNode(BaseNode):
"user_id": pool.get("sys.user_id"),
**custom_vars # 自定义变量作为节点输出的一部分
}
logger.info(
f"节点 {self.node_id} (Start) 执行完成,"
f"输出了 {len(custom_vars)} 个自定义变量"
)
return result
def _process_custom_variables(self, pool) -> dict[str, Any]:
"""处理自定义变量
@@ -88,34 +88,33 @@ class StartNode(BaseNode):
"""
# 获取输入数据中的自定义变量
input_variables = pool.get("sys.input_variables", default={})
processed = {}
# 遍历配置的变量定义
for var_def in self.typed_config.variables:
var_name = var_def.name
# 检查变量是否存在
if var_name in input_variables:
# 使用用户提供的值
processed[var_name] = input_variables[var_name]
elif var_def.required:
# 必需变量缺失
raise ValueError(
f"缺少必需的输入变量: {var_name}"
+ (f" ({var_def.description})" if var_def.description else "")
)
elif var_def.default is not None:
# 使用默认值
processed[var_name] = var_def.default
logger.debug(
f"变量 '{var_name}' 使用默认值: {var_def.default}"
)
return processed
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
"""提取输入数据(用于记录)
@@ -127,7 +126,7 @@ class StartNode(BaseNode):
输入数据字典
"""
pool = self.get_variable_pool(state)
return {
"execution_id": pool.get("sys.execution_id"),
"conversation_id": pool.get("sys.conversation_id"),

View File

@@ -52,7 +52,7 @@ def get_knowledges_paginated(
raise
def get_chunded_knowledgeids(
def get_chunked_knowledgeids(
db: Session,
filters: list
) -> list:

View File

@@ -45,7 +45,7 @@ def get_chunded_knowledgeids(
business_logger.debug(f"Query the list of vectorized knowledge base IDs: username={current_user.username}")
try:
items = knowledge_repository.get_chunded_knowledgeids(
items = knowledge_repository.get_chunked_knowledgeids(
db=db,
filters=filters
)