From d9f08860bcb631da521316096c198d368563ad71 Mon Sep 17 00:00:00 2001 From: wxy Date: Thu, 7 May 2026 11:43:24 +0800 Subject: [PATCH] feat(LLM node): integrate exception handling and enable branch routing - Integrate exception handling configuration into LLM nodes, supporting three strategies: throw exception, return default value, or trigger exception branch. - Modify execution logic to return a result structure containing a branch signal, enabling routing to designated branches upon failure. - Update graph_builder to support LLM node branch routing logic using the branch_signal field for conditional judgment. - Implement backward compatibility to support both legacy and new result formats. --- api/app/core/workflow/engine/graph_builder.py | 14 +- api/app/core/workflow/nodes/enums.py | 2 +- api/app/core/workflow/nodes/llm/config.py | 20 ++ api/app/core/workflow/nodes/llm/node.py | 209 ++++++++++++------ 4 files changed, 169 insertions(+), 76 deletions(-) diff --git a/api/app/core/workflow/engine/graph_builder.py b/api/app/core/workflow/engine/graph_builder.py index 5ecf41d2..fca0e2fe 100644 --- a/api/app/core/workflow/engine/graph_builder.py +++ b/api/app/core/workflow/engine/graph_builder.py @@ -141,9 +141,10 @@ class GraphBuilder: for node_info in source_nodes: if self.get_node_type(node_info["id"]) in BRANCH_NODES: - branch_nodes.append( - (node_info["id"], node_info["branch"]) - ) + if node_info.get("branch") is not None: + branch_nodes.append( + (node_info["id"], node_info["branch"]) + ) else: if self.get_node_type(node_info["id"]) in (NodeType.END, NodeType.OUTPUT): output_nodes.append(node_info["id"]) @@ -314,9 +315,10 @@ class GraphBuilder: for idx in range(len(related_edge)): # Generate a condition expression for each edge # Used later to determine which branch to take based on the node's output - # Assumes node output `node..output` matches the edge's label - # For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1' - related_edge[idx]['condition'] = f"node['{node_id}']['output'] == '{related_edge[idx]['label']}'" + # For LLM nodes, use branch_signal field for routing (output is dynamic text) + # For other branch nodes (e.g. HTTP), use output field + route_field = "branch_signal" if node_type == NodeType.LLM else "output" + related_edge[idx]['condition'] = f"node['{node_id}']['{route_field}'] == '{related_edge[idx]['label']}'" if node_instance: # Wrap node's run method to avoid closure issues diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index 0c0e8fb8..60c69dac 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -31,7 +31,7 @@ class NodeType(StrEnum): NOTES = "notes" -BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER}) +BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER, NodeType.LLM}) class ComparisonOperator(StrEnum): diff --git a/api/app/core/workflow/nodes/llm/config.py b/api/app/core/workflow/nodes/llm/config.py index b815c80f..d51a3575 100644 --- a/api/app/core/workflow/nodes/llm/config.py +++ b/api/app/core/workflow/nodes/llm/config.py @@ -6,6 +6,7 @@ import uuid from pydantic import BaseModel, Field, field_validator from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition +from app.core.workflow.nodes.enums import HttpErrorHandle from app.core.workflow.variable.base_variable import VariableType @@ -49,6 +50,20 @@ class MemoryWindowSetting(BaseModel): ) +class LLMErrorHandleConfig(BaseModel): + """LLM 异常处理配置""" + + method: HttpErrorHandle = Field( + default=HttpErrorHandle.NONE, + description="异常处理策略:'none' 抛出异常, 'default' 返回默认值, 'branch' 走异常分支", + ) + + output: str = Field( + default="", + description="LLM 异常时返回的默认输出文本(method=default 时生效)", + ) + + class LLMNodeConfig(BaseNodeConfig): """LLM 节点配置 @@ -152,6 +167,11 @@ class LLMNodeConfig(BaseNodeConfig): description="输出变量定义(自动生成,通常不需要修改)" ) + error_handle: LLMErrorHandleConfig = Field( + default_factory=LLMErrorHandleConfig, + description="LLM 异常处理配置", + ) + @field_validator("messages", "prompt") @classmethod def validate_input_mode(cls, v): diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index 352e735d..65fc624e 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -15,6 +15,7 @@ from app.core.models import RedBearLLM, RedBearModelConfig from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.base_node import BaseNode +from app.core.workflow.nodes.enums import HttpErrorHandle from app.core.workflow.nodes.llm.config import LLMNodeConfig from app.core.workflow.variable.base_variable import VariableType from app.db import get_db_context @@ -76,7 +77,7 @@ class LLMNode(BaseNode): self.messages = [] def _output_types(self) -> dict[str, VariableType]: - return {"output": VariableType.STRING} + return {"output": VariableType.STRING, "branch_signal": VariableType.STRING} def _render_context(self, message: str, variable_pool: VariablePool): context = f"{self._render_template(self.typed_config.context, variable_pool)}" @@ -239,7 +240,7 @@ class LLMNode(BaseNode): return llm - async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage: + async def execute(self, state: WorkflowState, variable_pool: VariablePool): """非流式执行 LLM 调用 Args: @@ -247,28 +248,36 @@ class LLMNode(BaseNode): variable_pool: 变量池 Returns: - LLM 响应消息 + dict: {"llm_result": AIMessage, "branch_signal": "SUCCESS"} on success, + {"llm_result": None, "branch_signal": "ERROR"} on branch error """ - # self.typed_config = LLMNodeConfig(**self.config) - llm = await self._prepare_llm(state, variable_pool, False) + try: + # self.typed_config = LLMNodeConfig(**self.config) + llm = await self._prepare_llm(state, variable_pool, False) - logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)") + logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)") - # 调用 LLM(支持字符串或消息列表) - response = await llm.ainvoke(self.messages) - # 提取内容 - if hasattr(response, 'content'): - content = self.process_model_output(response.content) - else: - content = str(response) + # 调用 LLM(支持字符串或消息列表) + response = await llm.ainvoke(self.messages) + # 提取内容 + if hasattr(response, 'content'): + content = self.process_model_output(response.content) + else: + content = str(response) - logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}") + logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}") - # 返回 AIMessage(包含响应元数据) - return AIMessage(content=content, response_metadata={ - **response.response_metadata, - "token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage') - }) + # 返回 AIMessage(包含响应元数据) + return { + "llm_result": AIMessage(content=content, response_metadata={ + **response.response_metadata, + "token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage') + }), + "branch_signal": "SUCCESS", + } + except Exception as e: + logger.error(f"节点 {self.node_id} LLM 调用失败: {e}") + return self._handle_llm_error(e) def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: """提取输入数据(用于记录)""" @@ -286,16 +295,36 @@ class LLMNode(BaseNode): } } - def _extract_output(self, business_result: Any) -> str: - """从 AIMessage 中提取文本内容""" + def _extract_output(self, business_result: Any) -> dict: + """从业务结果中提取输出变量 + + 支持新旧两种格式: + - 新格式:{"llm_result": AIMessage, "branch_signal": "SUCCESS"} + - 旧格式:AIMessage(向后兼容) + """ + if isinstance(business_result, dict) and "branch_signal" in business_result: + llm_result = business_result.get("llm_result") + if isinstance(llm_result, AIMessage): + return { + "output": llm_result.content, + "branch_signal": business_result["branch_signal"], + } + return { + "output": str(llm_result) if llm_result else "", + "branch_signal": business_result["branch_signal"], + } + # 旧格式向后兼容 if isinstance(business_result, AIMessage): - return business_result.content - return str(business_result) + return {"output": business_result.content, "branch_signal": "SUCCESS"} + return {"output": str(business_result), "branch_signal": "SUCCESS"} def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: - """从 AIMessage 中提取 token 使用情况""" - if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'): - usage = business_result.response_metadata.get('token_usage') + """从业务结果中提取 token 使用情况""" + llm_result = business_result + if isinstance(business_result, dict): + llm_result = business_result.get("llm_result", business_result) + if isinstance(llm_result, AIMessage) and hasattr(llm_result, 'response_metadata'): + usage = llm_result.response_metadata.get('token_usage') if usage: return { "prompt_tokens": usage.get('input_tokens', 0), @@ -304,6 +333,44 @@ class LLMNode(BaseNode): } return None + def _handle_llm_error(self, error: Exception) -> dict: + """处理 LLM 调用异常,根据 error_handle 配置决定行为 + + Args: + error: LLM 调用中捕获的异常 + + Returns: + dict: {"llm_result": None, "branch_signal": "ERROR"} for branch mode, + or default output for default mode + + Raises: + 原异常(当 error_handle.method 为 NONE 时) + """ + if self.typed_config is None: + raise error + + match self.typed_config.error_handle.method: + case HttpErrorHandle.NONE: + raise error + case HttpErrorHandle.DEFAULT: + logger.warning( + f"节点 {self.node_id}: LLM 调用失败,返回默认输出" + ) + default_output = self.typed_config.error_handle.output or "" + return { + "llm_result": AIMessage(content=default_output, response_metadata={}), + "branch_signal": "SUCCESS", + } + case HttpErrorHandle.BRANCH: + logger.warning( + f"节点 {self.node_id}: LLM 调用失败,切换到异常处理分支" + ) + return { + "llm_result": None, + "branch_signal": "ERROR", + } + raise error + async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool): """流式执行 LLM 调用 @@ -316,54 +383,58 @@ class LLMNode(BaseNode): """ self.typed_config = LLMNodeConfig(**self.config) - llm = await self._prepare_llm(state, variable_pool, True) + try: + llm = await self._prepare_llm(state, variable_pool, True) - logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") - # logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}") + logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") - # 累积完整响应 - full_response = "" - chunk_count = 0 + # 累积完整响应 + full_response = "" + chunk_count = 0 - # 调用 LLM(流式,支持字符串或消息列表) - last_meta_data = {} - last_usage_metadata = {} - async for chunk in llm.astream(self.messages): - if hasattr(chunk, 'content'): - content = self.process_model_output(chunk.content) - else: - content = str(chunk) - if hasattr(chunk, 'response_metadata') and chunk.response_metadata: - last_meta_data = chunk.response_metadata - if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata: - last_usage_metadata = chunk.usage_metadata + # 调用 LLM(流式,支持字符串或消息列表) + last_meta_data = {} + last_usage_metadata = {} + async for chunk in llm.astream(self.messages): + if hasattr(chunk, 'content'): + content = self.process_model_output(chunk.content) + else: + content = str(chunk) + if hasattr(chunk, 'response_metadata') and chunk.response_metadata: + last_meta_data = chunk.response_metadata + if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata: + last_usage_metadata = chunk.usage_metadata - # 只有当内容不为空时才处理 - if content: - full_response += content - chunk_count += 1 + # 只有当内容不为空时才处理 + if content: + full_response += content + chunk_count += 1 - # 流式返回每个文本片段 - yield { - "__final__": False, - "chunk": content - } + # 流式返回每个文本片段 + yield { + "__final__": False, + "chunk": content + } - yield { - "__final__": False, - "chunk": "", - "done": True - } - logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}") - - # 构建完整的 AIMessage(包含元数据) - final_message = AIMessage( - content=full_response, - response_metadata={ - **last_meta_data, - "token_usage": last_usage_metadata or last_meta_data.get('token_usage') + yield { + "__final__": False, + "chunk": "", + "done": True } - ) + logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}") - # yield 完成标记 - yield {"__final__": True, "result": final_message} + # 构建完整的 AIMessage(包含元数据) + final_message = AIMessage( + content=full_response, + response_metadata={ + **last_meta_data, + "token_usage": last_usage_metadata or last_meta_data.get('token_usage') + } + ) + + # yield 完成标记 + yield {"__final__": True, "result": {"llm_result": final_message, "branch_signal": "SUCCESS"}} + except Exception as e: + logger.error(f"节点 {self.node_id} LLM 流式调用失败: {e}") + error_result = self._handle_llm_error(e) + yield {"__final__": True, "result": error_result}