feat(tool system): add workflow tool nodes
This commit is contained in:
@@ -17,6 +17,8 @@ from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
|
|||||||
from app.core.workflow.nodes.start import StartNode
|
from app.core.workflow.nodes.start import StartNode
|
||||||
from app.core.workflow.nodes.transform import TransformNode
|
from app.core.workflow.nodes.transform import TransformNode
|
||||||
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||||
|
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||||
|
from app.core.workflow.nodes.tool import ToolNode
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseNode",
|
"BaseNode",
|
||||||
@@ -33,5 +35,7 @@ __all__ = [
|
|||||||
"AssignerNode",
|
"AssignerNode",
|
||||||
"HttpRequestNode",
|
"HttpRequestNode",
|
||||||
"JinjaRenderNode",
|
"JinjaRenderNode",
|
||||||
"ParameterExtractorNode"
|
"ParameterExtractorNode",
|
||||||
|
"QuestionClassifierNode",
|
||||||
|
"ToolNode"
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
|||||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||||
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
|
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
|
||||||
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
|
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
|
||||||
|
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||||
|
|
||||||
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
|
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -45,4 +46,5 @@ __all__ = [
|
|||||||
"LoopNodeConfig",
|
"LoopNodeConfig",
|
||||||
"IterationNodeConfig",
|
"IterationNodeConfig",
|
||||||
"QuestionClassifierNodeConfig"
|
"QuestionClassifierNodeConfig"
|
||||||
|
"ToolNodeConfig"
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from app.core.workflow.nodes.transform import TransformNode
|
|||||||
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||||
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||||
from app.core.workflow.nodes.breaker import BreakNode
|
from app.core.workflow.nodes.breaker import BreakNode
|
||||||
|
from app.core.workflow.nodes.tool import ToolNode
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -44,7 +45,8 @@ WorkflowNode = Union[
|
|||||||
CycleGraphNode,
|
CycleGraphNode,
|
||||||
BreakNode,
|
BreakNode,
|
||||||
ParameterExtractorNode,
|
ParameterExtractorNode,
|
||||||
QuestionClassifierNode
|
QuestionClassifierNode,
|
||||||
|
ToolNode
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -72,6 +74,7 @@ class NodeFactory:
|
|||||||
NodeType.LOOP: CycleGraphNode,
|
NodeType.LOOP: CycleGraphNode,
|
||||||
NodeType.ITERATION: CycleGraphNode,
|
NodeType.ITERATION: CycleGraphNode,
|
||||||
NodeType.BREAK: BreakNode,
|
NodeType.BREAK: BreakNode,
|
||||||
|
NodeType.TOOL: ToolNode,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
4
api/app/core/workflow/nodes/tool/__init__.py
Normal file
4
api/app/core/workflow/nodes/tool/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||||
|
from app.core.workflow.nodes.tool.node import ToolNode
|
||||||
|
|
||||||
|
__all__ = ["ToolNode", "ToolNodeConfig"]
|
||||||
9
api/app/core/workflow/nodes/tool/config.py
Normal file
9
api/app/core/workflow/nodes/tool/config.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
from pydantic import Field
|
||||||
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
|
|
||||||
|
|
||||||
|
class ToolNodeConfig(BaseNodeConfig):
|
||||||
|
"""工具节点配置"""
|
||||||
|
|
||||||
|
tool_id: str = Field(..., description="工具ID")
|
||||||
|
tool_parameters: dict[str, str] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")
|
||||||
72
api/app/core/workflow/nodes/tool/node.py
Normal file
72
api/app/core/workflow/nodes/tool/node.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
|
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||||
|
from app.services.tool_service import ToolService
|
||||||
|
from app.db import get_db_read
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolNode(BaseNode):
|
||||||
|
"""工具节点"""
|
||||||
|
|
||||||
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
|
super().__init__(node_config, workflow_config)
|
||||||
|
self.typed_config = ToolNodeConfig(**self.config)
|
||||||
|
|
||||||
|
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||||
|
"""执行工具"""
|
||||||
|
# 获取租户ID和用户ID
|
||||||
|
tenant_id = self.get_variable("sys.tenant_id", state)
|
||||||
|
user_id = self.get_variable("sys.user_id", state)
|
||||||
|
|
||||||
|
# 如果没有租户ID,尝试从工作流ID获取
|
||||||
|
if not tenant_id:
|
||||||
|
workflow_id = self.get_variable("sys.workflow_id", state)
|
||||||
|
if workflow_id:
|
||||||
|
from app.repositories.tool_repository import ToolRepository
|
||||||
|
with get_db_read() as db:
|
||||||
|
tenant_id = ToolRepository.get_tenant_id_by_workflow_id(db, workflow_id)
|
||||||
|
|
||||||
|
if not tenant_id:
|
||||||
|
tenant_id = uuid.UUID("6c2c91b0-3f49-4489-9157-2208aa56a097")
|
||||||
|
# logger.error(f"节点 {self.node_id} 缺少租户ID")
|
||||||
|
# return {"error": "缺少租户ID"}
|
||||||
|
|
||||||
|
# 渲染工具参数
|
||||||
|
rendered_parameters = {}
|
||||||
|
for param_name, param_template in self.typed_config.tool_parameters.items():
|
||||||
|
rendered_value = self._render_template(param_template, state)
|
||||||
|
rendered_parameters[param_name] = rendered_value
|
||||||
|
|
||||||
|
logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}")
|
||||||
|
print(self.typed_config.tool_id)
|
||||||
|
|
||||||
|
# 执行工具
|
||||||
|
with get_db_read() as db:
|
||||||
|
tool_service = ToolService(db)
|
||||||
|
result = await tool_service.execute_tool(
|
||||||
|
tool_id=self.typed_config.tool_id,
|
||||||
|
parameters=rendered_parameters,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user_id
|
||||||
|
)
|
||||||
|
print(result)
|
||||||
|
if result.success:
|
||||||
|
logger.info(f"节点 {self.node_id} 工具执行成功")
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": result.data,
|
||||||
|
"execution_time": result.execution_time
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": result.error,
|
||||||
|
"error_code": result.error_code,
|
||||||
|
"execution_time": result.execution_time
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user