feat(tool system): add workflow tool nodes

This commit is contained in:
谢俊男
2025-12-30 21:08:05 +08:00
parent e6c35e5f5a
commit 0475d80472
6 changed files with 96 additions and 2 deletions

View File

@@ -17,6 +17,8 @@ from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
from app.core.workflow.nodes.tool import ToolNode
__all__ = [
"BaseNode",
@@ -33,5 +35,7 @@ __all__ = [
"AssignerNode",
"HttpRequestNode",
"JinjaRenderNode",
"ParameterExtractorNode"
"ParameterExtractorNode",
"QuestionClassifierNode",
"ToolNode"
]

View File

@@ -21,6 +21,7 @@ from app.core.workflow.nodes.transform.config import TransformNodeConfig
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
__all__ = [
@@ -45,4 +46,5 @@ __all__ = [
"LoopNodeConfig",
"IterationNodeConfig",
"QuestionClassifierNodeConfig"
"ToolNodeConfig"
]

View File

@@ -24,6 +24,7 @@ from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
from app.core.workflow.nodes.breaker import BreakNode
from app.core.workflow.nodes.tool import ToolNode
logger = logging.getLogger(__name__)
@@ -44,7 +45,8 @@ WorkflowNode = Union[
CycleGraphNode,
BreakNode,
ParameterExtractorNode,
QuestionClassifierNode
QuestionClassifierNode,
ToolNode
]
@@ -72,6 +74,7 @@ class NodeFactory:
NodeType.LOOP: CycleGraphNode,
NodeType.ITERATION: CycleGraphNode,
NodeType.BREAK: BreakNode,
NodeType.TOOL: ToolNode,
}
@classmethod

View File

@@ -0,0 +1,4 @@
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.tool.node import ToolNode
__all__ = ["ToolNode", "ToolNodeConfig"]

View File

@@ -0,0 +1,9 @@
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig
class ToolNodeConfig(BaseNodeConfig):
"""工具节点配置"""
tool_id: str = Field(..., description="工具ID")
tool_parameters: dict[str, str] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")

View File

@@ -0,0 +1,72 @@
import logging
import uuid
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.services.tool_service import ToolService
from app.db import get_db_read
logger = logging.getLogger(__name__)
class ToolNode(BaseNode):
"""工具节点"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = ToolNodeConfig(**self.config)
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""执行工具"""
# 获取租户ID和用户ID
tenant_id = self.get_variable("sys.tenant_id", state)
user_id = self.get_variable("sys.user_id", state)
# 如果没有租户ID尝试从工作流ID获取
if not tenant_id:
workflow_id = self.get_variable("sys.workflow_id", state)
if workflow_id:
from app.repositories.tool_repository import ToolRepository
with get_db_read() as db:
tenant_id = ToolRepository.get_tenant_id_by_workflow_id(db, workflow_id)
if not tenant_id:
tenant_id = uuid.UUID("6c2c91b0-3f49-4489-9157-2208aa56a097")
# logger.error(f"节点 {self.node_id} 缺少租户ID")
# return {"error": "缺少租户ID"}
# 渲染工具参数
rendered_parameters = {}
for param_name, param_template in self.typed_config.tool_parameters.items():
rendered_value = self._render_template(param_template, state)
rendered_parameters[param_name] = rendered_value
logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}")
print(self.typed_config.tool_id)
# 执行工具
with get_db_read() as db:
tool_service = ToolService(db)
result = await tool_service.execute_tool(
tool_id=self.typed_config.tool_id,
parameters=rendered_parameters,
tenant_id=tenant_id,
user_id=user_id
)
print(result)
if result.success:
logger.info(f"节点 {self.node_id} 工具执行成功")
return {
"success": True,
"data": result.data,
"execution_time": result.execution_time
}
else:
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
return {
"success": False,
"error": result.error,
"error_code": result.error_code,
"execution_time": result.execution_time
}