perf(types): add Union type declaration for workflow nodes
- Introduce a `Nodes` type as a Union of all workflow node classes. - Improves type checking and IDE autocompletion.
This commit is contained in:
@@ -1,4 +1,14 @@
|
||||
from enum import StrEnum
|
||||
from typing import Union
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.if_else import IfElseNode
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.agent import AgentNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.end import EndNode
|
||||
|
||||
|
||||
class NodeType(StrEnum):
|
||||
START = "start"
|
||||
@@ -13,3 +23,14 @@ class NodeType(StrEnum):
|
||||
HTTP_REQUEST = "http-request"
|
||||
TOOL = "tool"
|
||||
AGENT = "agent"
|
||||
|
||||
|
||||
WorkflowNode = Union[
|
||||
BaseNode,
|
||||
StartNode,
|
||||
EndNode,
|
||||
LLMNode,
|
||||
IfElseNode,
|
||||
AgentNode,
|
||||
TransformNode,
|
||||
]
|
||||
|
||||
@@ -8,7 +8,8 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.nodes.enums import NodeType, WorkflowNode
|
||||
from app.core.workflow.nodes.if_else import IfElseNode
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.agent import AgentNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
@@ -25,16 +26,17 @@ class NodeFactory:
|
||||
"""
|
||||
|
||||
# 节点类型注册表
|
||||
_node_types: dict[str, type[BaseNode]] = {
|
||||
_node_types: dict[str, type[WorkflowNode]] = {
|
||||
NodeType.START: StartNode,
|
||||
NodeType.END: EndNode,
|
||||
NodeType.LLM: LLMNode,
|
||||
NodeType.AGENT: AgentNode,
|
||||
NodeType.TRANSFORM: TransformNode,
|
||||
NodeType.IF_ELSE: IfElseNode
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register_node_type(cls, node_type: str, node_class: type[BaseNode]):
|
||||
def register_node_type(cls, node_type: str, node_class: type[WorkflowNode]):
|
||||
"""注册新的节点类型
|
||||
|
||||
Args:
|
||||
@@ -55,7 +57,7 @@ class NodeFactory:
|
||||
cls,
|
||||
node_config: dict[str, Any],
|
||||
workflow_config: dict[str, Any]
|
||||
) -> BaseNode | None:
|
||||
) -> WorkflowNode | None:
|
||||
"""创建节点实例
|
||||
|
||||
Args:
|
||||
|
||||
Reference in New Issue
Block a user