fix(workflow): move node config validation to runtime for proper error handling
This commit is contained in:
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class AssignerNode(BaseNode):
|
class AssignerNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config = AssignerNodeConfig(**self.config)
|
self.typed_config: AssignerNodeConfig | None = None
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> Any:
|
async def execute(self, state: WorkflowState) -> Any:
|
||||||
"""
|
"""
|
||||||
@@ -28,6 +28,7 @@ class AssignerNode(BaseNode):
|
|||||||
None or the result of the assignment operation.
|
None or the result of the assignment operation.
|
||||||
"""
|
"""
|
||||||
# Initialize a variable pool for accessing conversation, node, and system variables
|
# Initialize a variable pool for accessing conversation, node, and system variables
|
||||||
|
self.typed_config = AssignerNodeConfig(**self.config)
|
||||||
logger.info(f"节点 {self.node_id} 开始执行")
|
logger.info(f"节点 {self.node_id} 开始执行")
|
||||||
pool = VariablePool(state)
|
pool = VariablePool(state)
|
||||||
for assignment in self.typed_config.assignments:
|
for assignment in self.typed_config.assignments:
|
||||||
|
|||||||
@@ -30,7 +30,6 @@ class CycleGraphNode(BaseNode):
|
|||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None
|
|
||||||
|
|
||||||
self.cycle_nodes = list() # Nodes belonging to this cycle
|
self.cycle_nodes = list() # Nodes belonging to this cycle
|
||||||
self.cycle_edges = list() # Edges connecting nodes within the cycle
|
self.cycle_edges = list() # Edges connecting nodes within the cycle
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ class HttpRequestNode(BaseNode):
|
|||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config = HttpRequestNodeConfig(**self.config)
|
self.typed_config: HttpRequestNodeConfig | None = None
|
||||||
|
|
||||||
def _build_timeout(self) -> Timeout:
|
def _build_timeout(self) -> Timeout:
|
||||||
"""
|
"""
|
||||||
@@ -181,6 +181,7 @@ class HttpRequestNode(BaseNode):
|
|||||||
- dict: Serialized HttpRequestNodeOutput on success
|
- dict: Serialized HttpRequestNodeOutput on success
|
||||||
- str: Branch identifier (e.g. "ERROR") when branching is enabled
|
- str: Branch identifier (e.g. "ERROR") when branching is enabled
|
||||||
"""
|
"""
|
||||||
|
self.typed_config = HttpRequestNodeConfig(**self.config)
|
||||||
async with httpx.AsyncClient(
|
async with httpx.AsyncClient(
|
||||||
verify=self.typed_config.verify_ssl,
|
verify=self.typed_config.verify_ssl,
|
||||||
timeout=self._build_timeout(),
|
timeout=self._build_timeout(),
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class IfElseNode(BaseNode):
|
class IfElseNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config = IfElseNodeConfig(**self.config)
|
self.typed_config: IfElseNodeConfig | None= None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
|
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
|
||||||
@@ -109,6 +109,7 @@ class IfElseNode(BaseNode):
|
|||||||
Returns:
|
Returns:
|
||||||
str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions.
|
str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions.
|
||||||
"""
|
"""
|
||||||
|
self.typed_config = IfElseNodeConfig(**self.config)
|
||||||
expressions = self.evaluate_conditional_edge_expressions(state)
|
expressions = self.evaluate_conditional_edge_expressions(state)
|
||||||
# TODO: 变量类型及文本类型解析
|
# TODO: 变量类型及文本类型解析
|
||||||
for i in range(len(expressions)):
|
for i in range(len(expressions)):
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class JinjaRenderNode(BaseNode):
|
class JinjaRenderNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config = JinjaRenderNodeConfig(**self.config)
|
self.typed_config: JinjaRenderNodeConfig | None = None
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> Any:
|
async def execute(self, state: WorkflowState) -> Any:
|
||||||
"""
|
"""
|
||||||
@@ -34,6 +34,7 @@ class JinjaRenderNode(BaseNode):
|
|||||||
RuntimeError: If Jinja2 template rendering fails due to invalid template
|
RuntimeError: If Jinja2 template rendering fails due to invalid template
|
||||||
syntax or missing variables.
|
syntax or missing variables.
|
||||||
"""
|
"""
|
||||||
|
self.typed_config = JinjaRenderNodeConfig(**self.config)
|
||||||
render = TemplateRenderer(strict=False)
|
render = TemplateRenderer(strict=False)
|
||||||
|
|
||||||
context = {}
|
context = {}
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class KnowledgeRetrievalNode(BaseNode):
|
class KnowledgeRetrievalNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
|
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
|
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
|
||||||
@@ -171,6 +171,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If no valid knowledge base is found or access is denied.
|
RuntimeError: If no valid knowledge base is found or access is denied.
|
||||||
"""
|
"""
|
||||||
|
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
|
||||||
query = self._render_template(self.typed_config.query, state)
|
query = self._render_template(self.typed_config.query, state)
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
knowledge_bases = self.typed_config.knowledge_bases
|
knowledge_bases = self.typed_config.knowledge_bases
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ class LLMNode(BaseNode):
|
|||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config = LLMNodeConfig(**self.config)
|
self.typed_config: LLMNodeConfig | None = None
|
||||||
|
|
||||||
def _render_context(self, message, state):
|
def _render_context(self, message, state):
|
||||||
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
|
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
|
||||||
@@ -164,6 +164,7 @@ class LLMNode(BaseNode):
|
|||||||
Returns:
|
Returns:
|
||||||
LLM 响应消息
|
LLM 响应消息
|
||||||
"""
|
"""
|
||||||
|
self.typed_config = LLMNodeConfig(**self.config)
|
||||||
llm, prompt_or_messages = self._prepare_llm(state, True)
|
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||||
|
|||||||
@@ -10,9 +10,10 @@ from app.services.memory_agent_service import MemoryAgentService
|
|||||||
class MemoryReadNode(BaseNode):
|
class MemoryReadNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config = MemoryReadNodeConfig(**self.config)
|
self.typed_config: MemoryReadNodeConfig | None = None
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> Any:
|
async def execute(self, state: WorkflowState) -> Any:
|
||||||
|
self.typed_config = MemoryReadNodeConfig(**self.config)
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
workspace_id = self.get_variable('sys.workspace_id', state)
|
workspace_id = self.get_variable('sys.workspace_id', state)
|
||||||
end_user_id = self.get_variable("sys.user_id", state)
|
end_user_id = self.get_variable("sys.user_id", state)
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class ParameterExtractorNode(BaseNode):
|
class ParameterExtractorNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config = ParameterExtractorNodeConfig(**self.config)
|
self.typed_config: ParameterExtractorNodeConfig | None = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_prompt():
|
def _get_prompt():
|
||||||
@@ -145,6 +145,7 @@ class ParameterExtractorNode(BaseNode):
|
|||||||
Raises:
|
Raises:
|
||||||
BusinessException: If LLM output cannot be parsed as valid JSON.
|
BusinessException: If LLM output cannot be parsed as valid JSON.
|
||||||
"""
|
"""
|
||||||
|
self.typed_config = ParameterExtractorNodeConfig(**self.config)
|
||||||
llm = self._get_llm_instance()
|
llm = self._get_llm_instance()
|
||||||
system_prompt, user_prompt = self._get_prompt()
|
system_prompt, user_prompt = self._get_prompt()
|
||||||
|
|
||||||
|
|||||||
@@ -21,8 +21,8 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config = QuestionClassifierNodeConfig(**self.config)
|
self.typed_config: QuestionClassifierNodeConfig | None = None
|
||||||
self.category_to_case_map = self._build_category_case_map()
|
self.category_to_case_map = {}
|
||||||
|
|
||||||
def _get_llm_instance(self) -> RedBearLLM:
|
def _get_llm_instance(self) -> RedBearLLM:
|
||||||
"""获取LLM实例"""
|
"""获取LLM实例"""
|
||||||
@@ -67,6 +67,8 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> dict:
|
async def execute(self, state: WorkflowState) -> dict:
|
||||||
"""执行问题分类"""
|
"""执行问题分类"""
|
||||||
|
self.typed_config = QuestionClassifierNodeConfig(**self.config)
|
||||||
|
self.category_to_case_map = self._build_category_case_map()
|
||||||
question = self.typed_config.input_variable
|
question = self.typed_config.input_variable
|
||||||
supplement_prompt = self.typed_config.user_supplement_prompt or ""
|
supplement_prompt = self.typed_config.user_supplement_prompt or ""
|
||||||
categories = self.typed_config.categories or []
|
categories = self.typed_config.categories or []
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class StartNode(BaseNode):
|
|||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
|
|
||||||
# 解析并验证配置
|
# 解析并验证配置
|
||||||
self.typed_config = StartNodeConfig(**self.config)
|
self.typed_config: StartNodeConfig | None = None
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||||
"""执行 start 节点业务逻辑
|
"""执行 start 节点业务逻辑
|
||||||
@@ -48,6 +48,7 @@ class StartNode(BaseNode):
|
|||||||
Returns:
|
Returns:
|
||||||
包含系统参数、会话变量和自定义变量的字典
|
包含系统参数、会话变量和自定义变量的字典
|
||||||
"""
|
"""
|
||||||
|
self.typed_config = StartNodeConfig(**self.config)
|
||||||
logger.info(f"节点 {self.node_id} (Start) 开始执行")
|
logger.info(f"节点 {self.node_id} (Start) 开始执行")
|
||||||
|
|
||||||
# 创建变量池实例(在方法内复用)
|
# 创建变量池实例(在方法内复用)
|
||||||
|
|||||||
@@ -17,10 +17,11 @@ class ToolNode(BaseNode):
|
|||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config = ToolNodeConfig(**self.config)
|
self.typed_config: ToolNodeConfig | None = None
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||||
"""执行工具"""
|
"""执行工具"""
|
||||||
|
self.typed_config = ToolNodeConfig(**self.config)
|
||||||
# 获取租户ID和用户ID
|
# 获取租户ID和用户ID
|
||||||
tenant_id = self.get_variable("sys.tenant_id", state)
|
tenant_id = self.get_variable("sys.tenant_id", state)
|
||||||
user_id = self.get_variable("sys.user_id", state)
|
user_id = self.get_variable("sys.user_id", state)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class VariableAggregatorNode(BaseNode):
|
class VariableAggregatorNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config = VariableAggregatorNodeConfig(**self.config)
|
self.typed_config: VariableAggregatorNodeConfig | None = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_express(variable_string: str) -> Any:
|
def _get_express(variable_string: str) -> Any:
|
||||||
@@ -37,6 +37,7 @@ class VariableAggregatorNode(BaseNode):
|
|||||||
- str: In non-group mode, returns the first non-None variable value.
|
- str: In non-group mode, returns the first non-None variable value.
|
||||||
- dict: In group mode, returns a mapping of group_name -> first non-None variable value.
|
- dict: In group mode, returns a mapping of group_name -> first non-None variable value.
|
||||||
"""
|
"""
|
||||||
|
self.typed_config = VariableAggregatorNodeConfig(**self.config)
|
||||||
if not self.typed_config.group:
|
if not self.typed_config.group:
|
||||||
# --------------------------
|
# --------------------------
|
||||||
# Non-group mode
|
# Non-group mode
|
||||||
|
|||||||
Reference in New Issue
Block a user