diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index 9cec19d2..5e586a9c 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -1,4 +1,14 @@ from enum import StrEnum +from typing import Union + +from app.core.workflow.nodes.base_node import BaseNode +from app.core.workflow.nodes.if_else import IfElseNode +from app.core.workflow.nodes.llm import LLMNode +from app.core.workflow.nodes.agent import AgentNode +from app.core.workflow.nodes.transform import TransformNode +from app.core.workflow.nodes.start import StartNode +from app.core.workflow.nodes.end import EndNode + class NodeType(StrEnum): START = "start" @@ -13,3 +23,14 @@ class NodeType(StrEnum): HTTP_REQUEST = "http-request" TOOL = "tool" AGENT = "agent" + + +WorkflowNode = Union[ + BaseNode, + StartNode, + EndNode, + LLMNode, + IfElseNode, + AgentNode, + TransformNode, +] diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index f279d13a..e1f32308 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -8,7 +8,8 @@ import logging from typing import Any from app.core.workflow.nodes.base_node import BaseNode -from app.core.workflow.nodes.enums import NodeType +from app.core.workflow.nodes.enums import NodeType, WorkflowNode +from app.core.workflow.nodes.if_else import IfElseNode from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.agent import AgentNode from app.core.workflow.nodes.transform import TransformNode @@ -25,16 +26,17 @@ class NodeFactory: """ # 节点类型注册表 - _node_types: dict[str, type[BaseNode]] = { + _node_types: dict[str, type[WorkflowNode]] = { NodeType.START: StartNode, NodeType.END: EndNode, NodeType.LLM: LLMNode, NodeType.AGENT: AgentNode, NodeType.TRANSFORM: TransformNode, + NodeType.IF_ELSE: IfElseNode } @classmethod - def register_node_type(cls, node_type: str, node_class: type[BaseNode]): + def register_node_type(cls, node_type: str, node_class: type[WorkflowNode]): """注册新的节点类型 Args: @@ -55,7 +57,7 @@ class NodeFactory: cls, node_config: dict[str, Any], workflow_config: dict[str, Any] - ) -> BaseNode | None: + ) -> WorkflowNode | None: """创建节点实例 Args: