fix(workflow): move node config validation to runtime for proper error handling

This commit is contained in:
mengyonghao
2026-01-14 10:47:38 +08:00
parent 7438fedd6b
commit 84e24ede04
13 changed files with 26 additions and 14 deletions

View File

@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
class AssignerNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = AssignerNodeConfig(**self.config)
self.typed_config: AssignerNodeConfig | None = None
async def execute(self, state: WorkflowState) -> Any:
"""
@@ -28,6 +28,7 @@ class AssignerNode(BaseNode):
None or the result of the assignment operation.
"""
# Initialize a variable pool for accessing conversation, node, and system variables
self.typed_config = AssignerNodeConfig(**self.config)
logger.info(f"节点 {self.node_id} 开始执行")
pool = VariablePool(state)
for assignment in self.typed_config.assignments:

View File

@@ -30,7 +30,6 @@ class CycleGraphNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None
self.cycle_nodes = list() # Nodes belonging to this cycle
self.cycle_edges = list() # Edges connecting nodes within the cycle

View File

@@ -32,7 +32,7 @@ class HttpRequestNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = HttpRequestNodeConfig(**self.config)
self.typed_config: HttpRequestNodeConfig | None = None
def _build_timeout(self) -> Timeout:
"""
@@ -181,6 +181,7 @@ class HttpRequestNode(BaseNode):
- dict: Serialized HttpRequestNodeOutput on success
- str: Branch identifier (e.g. "ERROR") when branching is enabled
"""
self.typed_config = HttpRequestNodeConfig(**self.config)
async with httpx.AsyncClient(
verify=self.typed_config.verify_ssl,
timeout=self._build_timeout(),

View File

@@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
class IfElseNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = IfElseNodeConfig(**self.config)
self.typed_config: IfElseNodeConfig | None= None
@staticmethod
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
@@ -109,6 +109,7 @@ class IfElseNode(BaseNode):
Returns:
str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions.
"""
self.typed_config = IfElseNodeConfig(**self.config)
expressions = self.evaluate_conditional_edge_expressions(state)
# TODO: 变量类型及文本类型解析
for i in range(len(expressions)):

View File

@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
class JinjaRenderNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = JinjaRenderNodeConfig(**self.config)
self.typed_config: JinjaRenderNodeConfig | None = None
async def execute(self, state: WorkflowState) -> Any:
"""
@@ -34,6 +34,7 @@ class JinjaRenderNode(BaseNode):
RuntimeError: If Jinja2 template rendering fails due to invalid template
syntax or missing variables.
"""
self.typed_config = JinjaRenderNodeConfig(**self.config)
render = TemplateRenderer(strict=False)
context = {}

View File

@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
class KnowledgeRetrievalNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
@staticmethod
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
@@ -171,6 +171,7 @@ class KnowledgeRetrievalNode(BaseNode):
Raises:
RuntimeError: If no valid knowledge base is found or access is denied.
"""
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
query = self._render_template(self.typed_config.query, state)
with get_db_read() as db:
knowledge_bases = self.typed_config.knowledge_bases

View File

@@ -68,7 +68,7 @@ class LLMNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = LLMNodeConfig(**self.config)
self.typed_config: LLMNodeConfig | None = None
def _render_context(self, message, state):
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
@@ -164,6 +164,7 @@ class LLMNode(BaseNode):
Returns:
LLM 响应消息
"""
self.typed_config = LLMNodeConfig(**self.config)
llm, prompt_or_messages = self._prepare_llm(state, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")

View File

@@ -10,9 +10,10 @@ from app.services.memory_agent_service import MemoryAgentService
class MemoryReadNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = MemoryReadNodeConfig(**self.config)
self.typed_config: MemoryReadNodeConfig | None = None
async def execute(self, state: WorkflowState) -> Any:
self.typed_config = MemoryReadNodeConfig(**self.config)
with get_db_read() as db:
workspace_id = self.get_variable('sys.workspace_id', state)
end_user_id = self.get_variable("sys.user_id", state)

View File

@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
class ParameterExtractorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = ParameterExtractorNodeConfig(**self.config)
self.typed_config: ParameterExtractorNodeConfig | None = None
@staticmethod
def _get_prompt():
@@ -145,6 +145,7 @@ class ParameterExtractorNode(BaseNode):
Raises:
BusinessException: If LLM output cannot be parsed as valid JSON.
"""
self.typed_config = ParameterExtractorNodeConfig(**self.config)
llm = self._get_llm_instance()
system_prompt, user_prompt = self._get_prompt()

View File

@@ -21,8 +21,8 @@ class QuestionClassifierNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = QuestionClassifierNodeConfig(**self.config)
self.category_to_case_map = self._build_category_case_map()
self.typed_config: QuestionClassifierNodeConfig | None = None
self.category_to_case_map = {}
def _get_llm_instance(self) -> RedBearLLM:
"""获取LLM实例"""
@@ -67,6 +67,8 @@ class QuestionClassifierNode(BaseNode):
async def execute(self, state: WorkflowState) -> dict:
"""执行问题分类"""
self.typed_config = QuestionClassifierNodeConfig(**self.config)
self.category_to_case_map = self._build_category_case_map()
question = self.typed_config.input_variable
supplement_prompt = self.typed_config.user_supplement_prompt or ""
categories = self.typed_config.categories or []

View File

@@ -35,7 +35,7 @@ class StartNode(BaseNode):
super().__init__(node_config, workflow_config)
# 解析并验证配置
self.typed_config = StartNodeConfig(**self.config)
self.typed_config: StartNodeConfig | None = None
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""执行 start 节点业务逻辑
@@ -48,6 +48,7 @@ class StartNode(BaseNode):
Returns:
包含系统参数、会话变量和自定义变量的字典
"""
self.typed_config = StartNodeConfig(**self.config)
logger.info(f"节点 {self.node_id} (Start) 开始执行")
# 创建变量池实例(在方法内复用)

View File

@@ -17,10 +17,11 @@ class ToolNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = ToolNodeConfig(**self.config)
self.typed_config: ToolNodeConfig | None = None
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""执行工具"""
self.typed_config = ToolNodeConfig(**self.config)
# 获取租户ID和用户ID
tenant_id = self.get_variable("sys.tenant_id", state)
user_id = self.get_variable("sys.user_id", state)

View File

@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
class VariableAggregatorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = VariableAggregatorNodeConfig(**self.config)
self.typed_config: VariableAggregatorNodeConfig | None = None
@staticmethod
def _get_express(variable_string: str) -> Any:
@@ -37,6 +37,7 @@ class VariableAggregatorNode(BaseNode):
- str: In non-group mode, returns the first non-None variable value.
- dict: In group mode, returns a mapping of group_name -> first non-None variable value.
"""
self.typed_config = VariableAggregatorNodeConfig(**self.config)
if not self.typed_config.group:
# --------------------------
# Non-group mode