fix(workflow): move node config validation to runtime for proper error handling
This commit is contained in:
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
|
||||
class AssignerNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = AssignerNodeConfig(**self.config)
|
||||
self.typed_config: AssignerNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
"""
|
||||
@@ -28,6 +28,7 @@ class AssignerNode(BaseNode):
|
||||
None or the result of the assignment operation.
|
||||
"""
|
||||
# Initialize a variable pool for accessing conversation, node, and system variables
|
||||
self.typed_config = AssignerNodeConfig(**self.config)
|
||||
logger.info(f"节点 {self.node_id} 开始执行")
|
||||
pool = VariablePool(state)
|
||||
for assignment in self.typed_config.assignments:
|
||||
|
||||
@@ -30,7 +30,6 @@ class CycleGraphNode(BaseNode):
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None
|
||||
|
||||
self.cycle_nodes = list() # Nodes belonging to this cycle
|
||||
self.cycle_edges = list() # Edges connecting nodes within the cycle
|
||||
|
||||
@@ -32,7 +32,7 @@ class HttpRequestNode(BaseNode):
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = HttpRequestNodeConfig(**self.config)
|
||||
self.typed_config: HttpRequestNodeConfig | None = None
|
||||
|
||||
def _build_timeout(self) -> Timeout:
|
||||
"""
|
||||
@@ -181,6 +181,7 @@ class HttpRequestNode(BaseNode):
|
||||
- dict: Serialized HttpRequestNodeOutput on success
|
||||
- str: Branch identifier (e.g. "ERROR") when branching is enabled
|
||||
"""
|
||||
self.typed_config = HttpRequestNodeConfig(**self.config)
|
||||
async with httpx.AsyncClient(
|
||||
verify=self.typed_config.verify_ssl,
|
||||
timeout=self._build_timeout(),
|
||||
|
||||
@@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
|
||||
class IfElseNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = IfElseNodeConfig(**self.config)
|
||||
self.typed_config: IfElseNodeConfig | None= None
|
||||
|
||||
@staticmethod
|
||||
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
|
||||
@@ -109,6 +109,7 @@ class IfElseNode(BaseNode):
|
||||
Returns:
|
||||
str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions.
|
||||
"""
|
||||
self.typed_config = IfElseNodeConfig(**self.config)
|
||||
expressions = self.evaluate_conditional_edge_expressions(state)
|
||||
# TODO: 变量类型及文本类型解析
|
||||
for i in range(len(expressions)):
|
||||
|
||||
@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
|
||||
class JinjaRenderNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = JinjaRenderNodeConfig(**self.config)
|
||||
self.typed_config: JinjaRenderNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
"""
|
||||
@@ -34,6 +34,7 @@ class JinjaRenderNode(BaseNode):
|
||||
RuntimeError: If Jinja2 template rendering fails due to invalid template
|
||||
syntax or missing variables.
|
||||
"""
|
||||
self.typed_config = JinjaRenderNodeConfig(**self.config)
|
||||
render = TemplateRenderer(strict=False)
|
||||
|
||||
context = {}
|
||||
|
||||
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
|
||||
class KnowledgeRetrievalNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
|
||||
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
||||
|
||||
@staticmethod
|
||||
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
|
||||
@@ -171,6 +171,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
Raises:
|
||||
RuntimeError: If no valid knowledge base is found or access is denied.
|
||||
"""
|
||||
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
|
||||
query = self._render_template(self.typed_config.query, state)
|
||||
with get_db_read() as db:
|
||||
knowledge_bases = self.typed_config.knowledge_bases
|
||||
|
||||
@@ -68,7 +68,7 @@ class LLMNode(BaseNode):
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = LLMNodeConfig(**self.config)
|
||||
self.typed_config: LLMNodeConfig | None = None
|
||||
|
||||
def _render_context(self, message, state):
|
||||
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
|
||||
@@ -164,6 +164,7 @@ class LLMNode(BaseNode):
|
||||
Returns:
|
||||
LLM 响应消息
|
||||
"""
|
||||
self.typed_config = LLMNodeConfig(**self.config)
|
||||
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||
|
||||
@@ -10,9 +10,10 @@ from app.services.memory_agent_service import MemoryAgentService
|
||||
class MemoryReadNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = MemoryReadNodeConfig(**self.config)
|
||||
self.typed_config: MemoryReadNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
self.typed_config = MemoryReadNodeConfig(**self.config)
|
||||
with get_db_read() as db:
|
||||
workspace_id = self.get_variable('sys.workspace_id', state)
|
||||
end_user_id = self.get_variable("sys.user_id", state)
|
||||
|
||||
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
|
||||
class ParameterExtractorNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = ParameterExtractorNodeConfig(**self.config)
|
||||
self.typed_config: ParameterExtractorNodeConfig | None = None
|
||||
|
||||
@staticmethod
|
||||
def _get_prompt():
|
||||
@@ -145,6 +145,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
Raises:
|
||||
BusinessException: If LLM output cannot be parsed as valid JSON.
|
||||
"""
|
||||
self.typed_config = ParameterExtractorNodeConfig(**self.config)
|
||||
llm = self._get_llm_instance()
|
||||
system_prompt, user_prompt = self._get_prompt()
|
||||
|
||||
|
||||
@@ -21,8 +21,8 @@ class QuestionClassifierNode(BaseNode):
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = QuestionClassifierNodeConfig(**self.config)
|
||||
self.category_to_case_map = self._build_category_case_map()
|
||||
self.typed_config: QuestionClassifierNodeConfig | None = None
|
||||
self.category_to_case_map = {}
|
||||
|
||||
def _get_llm_instance(self) -> RedBearLLM:
|
||||
"""获取LLM实例"""
|
||||
@@ -67,6 +67,8 @@ class QuestionClassifierNode(BaseNode):
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict:
|
||||
"""执行问题分类"""
|
||||
self.typed_config = QuestionClassifierNodeConfig(**self.config)
|
||||
self.category_to_case_map = self._build_category_case_map()
|
||||
question = self.typed_config.input_variable
|
||||
supplement_prompt = self.typed_config.user_supplement_prompt or ""
|
||||
categories = self.typed_config.categories or []
|
||||
|
||||
@@ -35,7 +35,7 @@ class StartNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
|
||||
# 解析并验证配置
|
||||
self.typed_config = StartNodeConfig(**self.config)
|
||||
self.typed_config: StartNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""执行 start 节点业务逻辑
|
||||
@@ -48,6 +48,7 @@ class StartNode(BaseNode):
|
||||
Returns:
|
||||
包含系统参数、会话变量和自定义变量的字典
|
||||
"""
|
||||
self.typed_config = StartNodeConfig(**self.config)
|
||||
logger.info(f"节点 {self.node_id} (Start) 开始执行")
|
||||
|
||||
# 创建变量池实例(在方法内复用)
|
||||
|
||||
@@ -17,10 +17,11 @@ class ToolNode(BaseNode):
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = ToolNodeConfig(**self.config)
|
||||
self.typed_config: ToolNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""执行工具"""
|
||||
self.typed_config = ToolNodeConfig(**self.config)
|
||||
# 获取租户ID和用户ID
|
||||
tenant_id = self.get_variable("sys.tenant_id", state)
|
||||
user_id = self.get_variable("sys.user_id", state)
|
||||
|
||||
@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
|
||||
class VariableAggregatorNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = VariableAggregatorNodeConfig(**self.config)
|
||||
self.typed_config: VariableAggregatorNodeConfig | None = None
|
||||
|
||||
@staticmethod
|
||||
def _get_express(variable_string: str) -> Any:
|
||||
@@ -37,6 +37,7 @@ class VariableAggregatorNode(BaseNode):
|
||||
- str: In non-group mode, returns the first non-None variable value.
|
||||
- dict: In group mode, returns a mapping of group_name -> first non-None variable value.
|
||||
"""
|
||||
self.typed_config = VariableAggregatorNodeConfig(**self.config)
|
||||
if not self.typed_config.group:
|
||||
# --------------------------
|
||||
# Non-group mode
|
||||
|
||||
Reference in New Issue
Block a user