Merge #41 into develop from feature/20251219_myh
feat(workflow): implement a workflow node for knowledge base retrieval * feature/20251219_myh: (3 commits) feat(workflow): support multi-variable assignment in assigner node feat(workflow): implement a workflow node for knowledge base retrieval fix(template): remove default initial model in templates Signed-off-by: Eternity <1533512157@qq.com> Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com> Merged-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com> CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/41
This commit is contained in:
@@ -740,8 +740,9 @@ class ElasticSearchVector(BaseVector):
|
|||||||
self._client.indices.create(index=self._collection_name, body=index_mapping)
|
self._client.indices.create(index=self._collection_name, body=index_mapping)
|
||||||
|
|
||||||
|
|
||||||
class ElasticSearchVectorFactory(ABC):
|
class ElasticSearchVectorFactory:
|
||||||
def init_vector(self, knowledge: Knowledge) -> ElasticSearchVector:
|
@staticmethod
|
||||||
|
def init_vector(knowledge: Knowledge) -> ElasticSearchVector:
|
||||||
collection_name = f"Vector_index_{knowledge.id}_Node"
|
collection_name = f"Vector_index_{knowledge.id}_Node"
|
||||||
|
|
||||||
# Use regular Elasticsearch with config values
|
# Use regular Elasticsearch with config values
|
||||||
@@ -763,17 +764,17 @@ class ElasticSearchVectorFactory(ABC):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
if knowledge.embedding and knowledge.reranker:
|
if knowledge.embedding is None:
|
||||||
return ElasticSearchVector(
|
raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}")
|
||||||
index_name=collection_name,
|
if knowledge.reranker is None:
|
||||||
config=ElasticSearchConfig(**config_dict),
|
raise ValueError(f"reranker_id config error: {str(knowledge.reranker_id)}")
|
||||||
embedding_config=knowledge.embedding.api_keys[0],
|
|
||||||
reranker_config=knowledge.reranker.api_keys[0]
|
return ElasticSearchVector(
|
||||||
)
|
index_name=collection_name,
|
||||||
else:
|
config=ElasticSearchConfig(**config_dict),
|
||||||
if knowledge.embedding is None:
|
embedding_config=knowledge.embedding.api_keys[0],
|
||||||
raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}")
|
reranker_config=knowledge.reranker.api_keys[0]
|
||||||
if knowledge.reranker is None:
|
)
|
||||||
raise ValueError(f"reranker_id config error: {str(knowledge.reranker_id)}")
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from app.core.workflow.nodes.assigner import AssignerNode
|
|||||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
from app.core.workflow.nodes.end import EndNode
|
from app.core.workflow.nodes.end import EndNode
|
||||||
from app.core.workflow.nodes.if_else import IfElseNode
|
from app.core.workflow.nodes.if_else import IfElseNode
|
||||||
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||||
from app.core.workflow.nodes.llm import LLMNode
|
from app.core.workflow.nodes.llm import LLMNode
|
||||||
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
|
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
|
||||||
from app.core.workflow.nodes.start import StartNode
|
from app.core.workflow.nodes.start import StartNode
|
||||||
@@ -26,6 +26,6 @@ __all__ = [
|
|||||||
"EndNode",
|
"EndNode",
|
||||||
"NodeFactory",
|
"NodeFactory",
|
||||||
"WorkflowNode",
|
"WorkflowNode",
|
||||||
# "KnowledgeRetrievalNode",
|
"KnowledgeRetrievalNode",
|
||||||
"AssignerNode",
|
"AssignerNode",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,21 +1,32 @@
|
|||||||
from pydantic import Field
|
from pydantic import Field, BaseModel
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
from app.core.workflow.nodes.enums import AssignmentOperator
|
from app.core.workflow.nodes.enums import AssignmentOperator
|
||||||
|
|
||||||
|
|
||||||
class AssignerNodeConfig(BaseNodeConfig):
|
class AssignmentItem(BaseModel):
|
||||||
|
"""
|
||||||
|
Single assignment definition.
|
||||||
|
"""
|
||||||
|
|
||||||
variable_selector: str | list[str] = Field(
|
variable_selector: str | list[str] = Field(
|
||||||
...,
|
...,
|
||||||
description="Variables to be assigned",
|
description="Target variable name(s) to assign",
|
||||||
)
|
)
|
||||||
|
|
||||||
operation: AssignmentOperator = Field(
|
operation: AssignmentOperator = Field(
|
||||||
...,
|
...,
|
||||||
description="Operator to assign",
|
description="Assignment operator",
|
||||||
)
|
)
|
||||||
|
|
||||||
value: str | list[str] = Field(
|
value: str | list[str] = Field(
|
||||||
...,
|
...,
|
||||||
description="Values to assign",
|
description="Value(s) to assign to the variable(s)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AssignerNodeConfig(BaseNodeConfig):
|
||||||
|
assignments: list[AssignmentItem] = Field(
|
||||||
|
...,
|
||||||
|
description="List of variable assignment definitions",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -29,52 +29,52 @@ class AssignerNode(BaseNode):
|
|||||||
"""
|
"""
|
||||||
# Initialize a variable pool for accessing conversation, node, and system variables
|
# Initialize a variable pool for accessing conversation, node, and system variables
|
||||||
pool = VariablePool(state)
|
pool = VariablePool(state)
|
||||||
|
for assignment in self.typed_config.assignments:
|
||||||
|
# Get the target variable selector (e.g., "conv.test")
|
||||||
|
variable_selector = assignment.variable_selector
|
||||||
|
if isinstance(variable_selector, str):
|
||||||
|
# Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"]
|
||||||
|
variable_selector = variable_selector.split('.')
|
||||||
|
|
||||||
# Get the target variable selector (e.g., "conv.test")
|
# Only conversation variables ('conv') are allowed
|
||||||
variable_selector = self.typed_config.variable_selector
|
if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature)
|
||||||
if isinstance(variable_selector, str):
|
raise ValueError("Only conversation variables can be assigned.")
|
||||||
# Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"]
|
|
||||||
variable_selector = variable_selector.split('.')
|
|
||||||
|
|
||||||
# Only conversation variables ('conv') are allowed
|
# Get the value or expression to assign
|
||||||
if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature)
|
value = assignment.value
|
||||||
raise ValueError("Only conversation variables can be assigned.")
|
if isinstance(value, list):
|
||||||
|
value = '.'.join(value)
|
||||||
|
value = ExpressionEvaluator.evaluate(
|
||||||
|
expression=value,
|
||||||
|
variables=pool.get_all_conversation_vars(),
|
||||||
|
node_outputs=pool.get_all_node_outputs(),
|
||||||
|
system_vars=pool.get_all_system_vars(),
|
||||||
|
)
|
||||||
|
|
||||||
# Get the value or expression to assign
|
# Select the appropriate assignment operator instance based on the target variable type
|
||||||
value = self.typed_config.value
|
operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))(
|
||||||
if isinstance(value, list):
|
pool, variable_selector, value
|
||||||
value = '.'.join(value)
|
)
|
||||||
value = ExpressionEvaluator.evaluate(
|
|
||||||
expression=value,
|
|
||||||
variables=pool.get_all_conversation_vars(),
|
|
||||||
node_outputs=pool.get_all_node_outputs(),
|
|
||||||
system_vars=pool.get_all_system_vars(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Select the appropriate assignment operator instance based on the target variable type
|
# Execute the configured assignment operation
|
||||||
operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))(
|
match assignment.operation:
|
||||||
pool, variable_selector, value
|
case AssignmentOperator.ASSIGN:
|
||||||
)
|
operator.assign()
|
||||||
|
case AssignmentOperator.CLEAR:
|
||||||
# Execute the configured assignment operation
|
operator.clear()
|
||||||
match self.typed_config.operation:
|
case AssignmentOperator.ADD:
|
||||||
case AssignmentOperator.ASSIGN:
|
operator.add()
|
||||||
operator.assign()
|
case AssignmentOperator.SUBTRACT:
|
||||||
case AssignmentOperator.CLEAR:
|
operator.subtract()
|
||||||
operator.clear()
|
case AssignmentOperator.MULTIPLY:
|
||||||
case AssignmentOperator.ADD:
|
operator.multiply()
|
||||||
operator.add()
|
case AssignmentOperator.DIVIDE:
|
||||||
case AssignmentOperator.SUBTRACT:
|
operator.divide()
|
||||||
operator.subtract()
|
case AssignmentOperator.APPEND:
|
||||||
case AssignmentOperator.MULTIPLY:
|
operator.append()
|
||||||
operator.multiply()
|
case AssignmentOperator.REMOVE_FIRST:
|
||||||
case AssignmentOperator.DIVIDE:
|
operator.remove_first()
|
||||||
operator.divide()
|
case AssignmentOperator.REMOVE_LAST:
|
||||||
case AssignmentOperator.APPEND:
|
operator.remove_last()
|
||||||
operator.append()
|
case _:
|
||||||
case AssignmentOperator.REMOVE_FIRST:
|
raise ValueError(f"Invalid Operator: {assignment.operation}")
|
||||||
operator.remove_first()
|
|
||||||
case AssignmentOperator.REMOVE_LAST:
|
|
||||||
operator.remove_last()
|
|
||||||
case _:
|
|
||||||
raise ValueError(f"Invalid Operator: {self.typed_config.operation}")
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
|
|||||||
from app.core.workflow.nodes.agent.config import AgentNodeConfig
|
from app.core.workflow.nodes.agent.config import AgentNodeConfig
|
||||||
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
||||||
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
|
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
|
||||||
# from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
|
from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
|
||||||
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -30,6 +30,6 @@ __all__ = [
|
|||||||
"AgentNodeConfig",
|
"AgentNodeConfig",
|
||||||
"TransformNodeConfig",
|
"TransformNodeConfig",
|
||||||
"IfElseNodeConfig",
|
"IfElseNodeConfig",
|
||||||
# "KnowledgeRetrievalNodeConfig",
|
"KnowledgeRetrievalNodeConfig",
|
||||||
"AssignerNodeConfig",
|
"AssignerNodeConfig",
|
||||||
]
|
]
|
||||||
|
|||||||
4
api/app/core/workflow/nodes/knowledge/__init__.py
Normal file
4
api/app/core/workflow/nodes/knowledge/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
|
||||||
|
from app.core.workflow.nodes.knowledge.node import KnowledgeRetrievalNode
|
||||||
|
|
||||||
|
__all__ = ["KnowledgeRetrievalNode", "KnowledgeRetrievalNodeConfig"]
|
||||||
38
api/app/core/workflow/nodes/knowledge/config.py
Normal file
38
api/app/core/workflow/nodes/knowledge/config.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
|
from app.schemas.chunk_schema import RetrieveType
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
|
||||||
|
query: str = Field(
|
||||||
|
...,
|
||||||
|
description="Search query string"
|
||||||
|
)
|
||||||
|
|
||||||
|
kb_ids: list[UUID] = Field(
|
||||||
|
...,
|
||||||
|
description="Knowledge base IDs"
|
||||||
|
)
|
||||||
|
|
||||||
|
similarity_threshold: float = Field(
|
||||||
|
default=0.2,
|
||||||
|
description="Knowledge base similarity threshold"
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_similarity_weight: float = Field(
|
||||||
|
default=0.3,
|
||||||
|
description="Knowledge base vector similarity weight"
|
||||||
|
)
|
||||||
|
|
||||||
|
top_k: int = Field(
|
||||||
|
default=4,
|
||||||
|
description="Knowledge base top k"
|
||||||
|
)
|
||||||
|
|
||||||
|
retrieve_type: RetrieveType = Field(
|
||||||
|
default=RetrieveType.PARTICIPLE,
|
||||||
|
description="Retrieve type"
|
||||||
|
)
|
||||||
97
api/app/core/workflow/nodes/knowledge/node.py
Normal file
97
api/app/core/workflow/nodes/knowledge/node.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||||
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
|
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
||||||
|
from app.db import get_db
|
||||||
|
from app.models import knowledge_model, knowledgeshare_model
|
||||||
|
from app.repositories import knowledge_repository
|
||||||
|
from app.schemas.chunk_schema import RetrieveType
|
||||||
|
from app.services import knowledge_service, knowledgeshare_service
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeRetrievalNode(BaseNode):
|
||||||
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
|
super().__init__(node_config, workflow_config)
|
||||||
|
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
|
||||||
|
|
||||||
|
async def execute(self, state: WorkflowState) -> Any:
|
||||||
|
query = self._render_template(self.typed_config.query, state)
|
||||||
|
db_gen = get_db()
|
||||||
|
db = next(db_gen)
|
||||||
|
try:
|
||||||
|
filters = [
|
||||||
|
knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids),
|
||||||
|
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private,
|
||||||
|
knowledge_model.Knowledge.chunk_num > 0,
|
||||||
|
knowledge_model.Knowledge.status == 1
|
||||||
|
]
|
||||||
|
existing_ids = knowledge_repository.get_chunked_knowledgeids(
|
||||||
|
db=db,
|
||||||
|
filters=filters
|
||||||
|
)
|
||||||
|
filters = [
|
||||||
|
knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids),
|
||||||
|
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share,
|
||||||
|
knowledge_model.Knowledge.chunk_num > 0,
|
||||||
|
knowledge_model.Knowledge.status == 1
|
||||||
|
]
|
||||||
|
share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids(
|
||||||
|
db=db,
|
||||||
|
filters=filters
|
||||||
|
)
|
||||||
|
if share_ids:
|
||||||
|
filters = [
|
||||||
|
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(self.typed_config.kb_ids)
|
||||||
|
]
|
||||||
|
items = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
|
||||||
|
db=db,
|
||||||
|
filters=filters
|
||||||
|
)
|
||||||
|
existing_ids.extend(items)
|
||||||
|
|
||||||
|
if not existing_ids:
|
||||||
|
raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.")
|
||||||
|
|
||||||
|
kb_id = existing_ids[0]
|
||||||
|
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids]
|
||||||
|
indices = ",".join(uuid_strs)
|
||||||
|
|
||||||
|
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id)
|
||||||
|
if not db_knowledge:
|
||||||
|
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
||||||
|
|
||||||
|
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||||
|
|
||||||
|
match self.typed_config.retrieve_type:
|
||||||
|
case RetrieveType.PARTICIPLE:
|
||||||
|
rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=self.typed_config.similarity_threshold)
|
||||||
|
return [chunk.model_dump() for chunk in rs]
|
||||||
|
case RetrieveType.SEMANTIC:
|
||||||
|
rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=self.typed_config.vector_similarity_weight)
|
||||||
|
return [chunk.model_dump() for chunk in rs]
|
||||||
|
case _:
|
||||||
|
rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=self.typed_config.vector_similarity_weight)
|
||||||
|
rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=self.typed_config.similarity_threshold)
|
||||||
|
# Efficient deduplication
|
||||||
|
seen_ids = set()
|
||||||
|
unique_rs = []
|
||||||
|
for doc in rs1 + rs2:
|
||||||
|
if doc.metadata["doc_id"] not in seen_ids:
|
||||||
|
seen_ids.add(doc.metadata["doc_id"])
|
||||||
|
unique_rs.append(doc)
|
||||||
|
rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k)
|
||||||
|
return [chunk.model_dump() for chunk in rs]
|
||||||
|
finally:
|
||||||
|
next(db_gen)
|
||||||
@@ -7,7 +7,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||||
from app.core.workflow.nodes.agent import AgentNode
|
from app.core.workflow.nodes.agent import AgentNode
|
||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
from app.core.workflow.nodes.end import EndNode
|
from app.core.workflow.nodes.end import EndNode
|
||||||
@@ -29,7 +29,7 @@ WorkflowNode = Union[
|
|||||||
AgentNode,
|
AgentNode,
|
||||||
TransformNode,
|
TransformNode,
|
||||||
AssignerNode,
|
AssignerNode,
|
||||||
# KnowledgeRetrievalNode,
|
KnowledgeRetrievalNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -47,7 +47,7 @@ class NodeFactory:
|
|||||||
NodeType.AGENT: AgentNode,
|
NodeType.AGENT: AgentNode,
|
||||||
NodeType.TRANSFORM: TransformNode,
|
NodeType.TRANSFORM: TransformNode,
|
||||||
NodeType.IF_ELSE: IfElseNode,
|
NodeType.IF_ELSE: IfElseNode,
|
||||||
# NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
||||||
NodeType.ASSIGNER: AssignerNode,
|
NodeType.ASSIGNER: AssignerNode,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ class StartNode(BaseNode):
|
|||||||
|
|
||||||
注意:变量的验证和默认值处理由 Executor 在初始化时完成。
|
注意:变量的验证和默认值处理由 Executor 在初始化时完成。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
"""初始化 Start 节点
|
"""初始化 Start 节点
|
||||||
|
|
||||||
@@ -32,10 +32,10 @@ class StartNode(BaseNode):
|
|||||||
workflow_config: 工作流配置
|
workflow_config: 工作流配置
|
||||||
"""
|
"""
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
|
|
||||||
# 解析并验证配置
|
# 解析并验证配置
|
||||||
self.typed_config = StartNodeConfig(**self.config)
|
self.typed_config = StartNodeConfig(**self.config)
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||||
"""执行 start 节点业务逻辑
|
"""执行 start 节点业务逻辑
|
||||||
|
|
||||||
@@ -48,13 +48,13 @@ class StartNode(BaseNode):
|
|||||||
包含系统参数、会话变量和自定义变量的字典
|
包含系统参数、会话变量和自定义变量的字典
|
||||||
"""
|
"""
|
||||||
logger.info(f"节点 {self.node_id} (Start) 开始执行")
|
logger.info(f"节点 {self.node_id} (Start) 开始执行")
|
||||||
|
|
||||||
# 创建变量池实例(在方法内复用)
|
# 创建变量池实例(在方法内复用)
|
||||||
pool = self.get_variable_pool(state)
|
pool = self.get_variable_pool(state)
|
||||||
|
|
||||||
# 处理自定义变量(传入 pool 避免重复创建)
|
# 处理自定义变量(传入 pool 避免重复创建)
|
||||||
custom_vars = self._process_custom_variables(pool)
|
custom_vars = self._process_custom_variables(pool)
|
||||||
|
|
||||||
# 返回业务数据(包含自定义变量)
|
# 返回业务数据(包含自定义变量)
|
||||||
result = {
|
result = {
|
||||||
"message": pool.get("sys.message"),
|
"message": pool.get("sys.message"),
|
||||||
@@ -64,14 +64,14 @@ class StartNode(BaseNode):
|
|||||||
"user_id": pool.get("sys.user_id"),
|
"user_id": pool.get("sys.user_id"),
|
||||||
**custom_vars # 自定义变量作为节点输出的一部分
|
**custom_vars # 自定义变量作为节点输出的一部分
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"节点 {self.node_id} (Start) 执行完成,"
|
f"节点 {self.node_id} (Start) 执行完成,"
|
||||||
f"输出了 {len(custom_vars)} 个自定义变量"
|
f"输出了 {len(custom_vars)} 个自定义变量"
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _process_custom_variables(self, pool) -> dict[str, Any]:
|
def _process_custom_variables(self, pool) -> dict[str, Any]:
|
||||||
"""处理自定义变量
|
"""处理自定义变量
|
||||||
|
|
||||||
@@ -88,34 +88,33 @@ class StartNode(BaseNode):
|
|||||||
"""
|
"""
|
||||||
# 获取输入数据中的自定义变量
|
# 获取输入数据中的自定义变量
|
||||||
input_variables = pool.get("sys.input_variables", default={})
|
input_variables = pool.get("sys.input_variables", default={})
|
||||||
|
|
||||||
processed = {}
|
processed = {}
|
||||||
|
|
||||||
# 遍历配置的变量定义
|
# 遍历配置的变量定义
|
||||||
for var_def in self.typed_config.variables:
|
for var_def in self.typed_config.variables:
|
||||||
var_name = var_def.name
|
var_name = var_def.name
|
||||||
|
|
||||||
# 检查变量是否存在
|
# 检查变量是否存在
|
||||||
if var_name in input_variables:
|
if var_name in input_variables:
|
||||||
# 使用用户提供的值
|
# 使用用户提供的值
|
||||||
processed[var_name] = input_variables[var_name]
|
processed[var_name] = input_variables[var_name]
|
||||||
|
|
||||||
elif var_def.required:
|
elif var_def.required:
|
||||||
# 必需变量缺失
|
# 必需变量缺失
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"缺少必需的输入变量: {var_name}"
|
f"缺少必需的输入变量: {var_name}"
|
||||||
+ (f" ({var_def.description})" if var_def.description else "")
|
+ (f" ({var_def.description})" if var_def.description else "")
|
||||||
)
|
)
|
||||||
|
|
||||||
elif var_def.default is not None:
|
elif var_def.default is not None:
|
||||||
# 使用默认值
|
# 使用默认值
|
||||||
processed[var_name] = var_def.default
|
processed[var_name] = var_def.default
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"变量 '{var_name}' 使用默认值: {var_def.default}"
|
f"变量 '{var_name}' 使用默认值: {var_def.default}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return processed
|
return processed
|
||||||
|
|
||||||
|
|
||||||
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
||||||
"""提取输入数据(用于记录)
|
"""提取输入数据(用于记录)
|
||||||
@@ -127,7 +126,7 @@ class StartNode(BaseNode):
|
|||||||
输入数据字典
|
输入数据字典
|
||||||
"""
|
"""
|
||||||
pool = self.get_variable_pool(state)
|
pool = self.get_variable_pool(state)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"execution_id": pool.get("sys.execution_id"),
|
"execution_id": pool.get("sys.execution_id"),
|
||||||
"conversation_id": pool.get("sys.conversation_id"),
|
"conversation_id": pool.get("sys.conversation_id"),
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ def get_knowledges_paginated(
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def get_chunded_knowledgeids(
|
def get_chunked_knowledgeids(
|
||||||
db: Session,
|
db: Session,
|
||||||
filters: list
|
filters: list
|
||||||
) -> list:
|
) -> list:
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ def get_chunded_knowledgeids(
|
|||||||
business_logger.debug(f"Query the list of vectorized knowledge base IDs: username={current_user.username}")
|
business_logger.debug(f"Query the list of vectorized knowledge base IDs: username={current_user.username}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
items = knowledge_repository.get_chunded_knowledgeids(
|
items = knowledge_repository.get_chunked_knowledgeids(
|
||||||
db=db,
|
db=db,
|
||||||
filters=filters
|
filters=filters
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -44,11 +44,11 @@ nodes:
|
|||||||
- role: user
|
- role: user
|
||||||
content: "{{ sys.message }}"
|
content: "{{ sys.message }}"
|
||||||
|
|
||||||
model_id: gpt-3.5-turbo
|
model_id: null
|
||||||
temperature: 0.7
|
temperature: 0.7
|
||||||
max_tokens: 1000
|
max_tokens: 1000
|
||||||
position:
|
position:
|
||||||
x: 300
|
x: 500
|
||||||
y: 100
|
y: 100
|
||||||
|
|
||||||
- id: end
|
- id: end
|
||||||
@@ -57,7 +57,7 @@ nodes:
|
|||||||
config:
|
config:
|
||||||
output: "{{ llm_qa.output }}"
|
output: "{{ llm_qa.output }}"
|
||||||
position:
|
position:
|
||||||
x: 500
|
x: 900
|
||||||
y: 100
|
y: 100
|
||||||
|
|
||||||
edges:
|
edges:
|
||||||
|
|||||||
Reference in New Issue
Block a user