diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index e7007884..727f7391 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -35,7 +35,7 @@ class WorkflowState(TypedDict): # Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx) variables: Annotated[dict[str, Any], lambda x, y: { **x, - **{k: {**x.get(k, {}), **v} if isinstance(v, dict) and isinstance(x.get(k), dict) else v + **{k: {**x.get(k, {}), **v} if isinstance(v, dict) and isinstance(x.get(k), dict) else v for k, v in y.items()} }] @@ -46,12 +46,12 @@ class WorkflowState(TypedDict): # Runtime node variables (simplified version, stores business data for fast access between nodes) # Format: {node_id: business_result} runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}] - + # Execution context execution_id: str workspace_id: str user_id: str - + # Error information (for error edges) error: str | None error_node: str | None @@ -66,7 +66,7 @@ class BaseNode(ABC): 所有节点类型都应该继承此基类,实现 execute 方法。 """ - + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): """初始化节点 @@ -83,7 +83,7 @@ class BaseNode(ABC): # 使用 or 运算符处理 None 值 self.config = node_config.get("config") or {} self.error_handling = node_config.get("error_handling") or {} - + @abstractmethod async def execute(self, state: WorkflowState) -> Any: """执行节点业务逻辑(非流式) @@ -108,7 +108,7 @@ class BaseNode(ABC): >>> return {"message": "开始", "conversation_id": "xxx"} """ pass - + async def execute_stream(self, state: WorkflowState): """执行节点业务逻辑(流式) @@ -138,7 +138,7 @@ class BaseNode(ABC): result = await self.execute(state) # 默认实现:直接 yield 完成标记 yield {"__final__": True, "result": result} - + def supports_streaming(self) -> bool: """节点是否支持流式输出 @@ -147,7 +147,7 @@ class BaseNode(ABC): """ # 检查子类是否重写了 execute_stream 方法 return self.execute_stream.__func__ != BaseNode.execute_stream.__func__ - + def get_timeout(self) -> int: """获取超时时间(秒) @@ -156,7 +156,7 @@ class BaseNode(ABC): """ return 60 # return self.error_handling.get("timeout", 60) - + async def run(self, state: WorkflowState) -> dict[str, Any]: """执行节点(带错误处理和输出包装,非流式) @@ -173,33 +173,33 @@ class BaseNode(ABC): 标准化的状态更新字典 """ import time - + start_time = time.time() timeout = self.get_timeout() - + try: # 调用节点的业务逻辑 business_result = await asyncio.wait_for( self.execute(state), timeout=timeout ) - + elapsed_time = time.time() - start_time - + # 提取处理后的输出(调用子类的 _extract_output) extracted_output = self._extract_output(business_result) - + # 包装成标准输出格式 wrapped_output = self._wrap_output(business_result, elapsed_time, state) - + # 将提取后的输出存储到运行时变量中(供后续节点快速访问) # 如果提取后的输出是字典,拆包存储;否则存储为 output 字段 if isinstance(extracted_output, dict): runtime_var = extracted_output else: runtime_var = {"output": extracted_output} - + # 返回包装后的输出和运行时变量 return { **wrapped_output, @@ -208,7 +208,7 @@ class BaseNode(ABC): }, "looping": state["looping"] } - + except TimeoutError: elapsed_time = time.time() - start_time logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)") @@ -217,7 +217,7 @@ class BaseNode(ABC): elapsed_time = time.time() - start_time logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True) return self._wrap_error(str(e), elapsed_time, state) - + async def run_stream(self, state: WorkflowState): """Execute node with error handling and output wrapping (streaming) @@ -240,40 +240,41 @@ class BaseNode(ABC): State updates with streaming buffer and final result """ import time - + start_time = time.time() timeout = self.get_timeout() - + try: # Get LangGraph's stream writer for sending custom data writer = get_stream_writer() - + # Check if this is an End node # End nodes CAN send chunks (for suffix), but only after LLM content is_end_node = self.node_type == "end" - + # Check if this node is adjacent to End node (for message type) is_adjacent_to_end = getattr(self, '_is_adjacent_to_end', False) - + # Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk" - - logger.debug(f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})") - + + logger.debug( + f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})") + # Accumulate complete result (for final wrapping) chunks = [] final_result = None chunk_count = 0 - + # Stream chunks in real-time loop_start = asyncio.get_event_loop().time() - + async for item in self.execute_stream(state): # Check timeout if asyncio.get_event_loop().time() - loop_start > timeout: raise TimeoutError() - + # Check if it's a completion marker if isinstance(item, dict) and item.get("__final__"): final_result = item["result"] @@ -282,10 +283,10 @@ class BaseNode(ABC): chunk_count += 1 chunks.append(item) full_content = "".join(chunks) - + # Send chunks for all nodes (including End nodes for suffix) logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {item[:50]}...") - + # 1. Send via stream writer (for real-time client updates) writer({ "type": chunk_type, # "message" or "node_chunk" @@ -294,7 +295,7 @@ class BaseNode(ABC): "full_content": full_content, "chunk_index": chunk_count }) - + # 2. Update streaming buffer in state (for downstream nodes) # Only non-End nodes need streaming buffer if not is_end_node: @@ -313,7 +314,7 @@ class BaseNode(ABC): chunk_str = str(item) chunks.append(chunk_str) full_content = "".join(chunks) - + # Send chunks for all nodes writer({ "type": chunk_type, # "message" or "node_chunk" @@ -322,7 +323,7 @@ class BaseNode(ABC): "full_content": full_content, "chunk_index": chunk_count }) - + # Only non-End nodes need streaming buffer if not is_end_node: yield { @@ -334,23 +335,23 @@ class BaseNode(ABC): } } } - + elapsed_time = time.time() - start_time - + logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}") - + # Extract processed output (call subclass's _extract_output) extracted_output = self._extract_output(final_result) - + # Wrap final result final_output = self._wrap_output(final_result, elapsed_time, state) - + # Store extracted output in runtime variables (for quick access by subsequent nodes) if isinstance(extracted_output, dict): runtime_var = extracted_output else: runtime_var = {"output": extracted_output} - + # Build complete state update (including node_outputs, runtime_vars, and final streaming buffer) state_update = { **final_output, @@ -359,7 +360,7 @@ class BaseNode(ABC): }, "looping": state["looping"] } - + # Add streaming buffer for non-End nodes if not is_end_node: state_update["streaming_buffer"] = { @@ -369,11 +370,11 @@ class BaseNode(ABC): "is_complete": True # Mark as complete } } - + # Finally yield state update # LangGraph will merge this into state yield state_update - + except TimeoutError: elapsed_time = time.time() - start_time logger.error(f"节点 {self.node_id} 执行超时 ({timeout}s)") @@ -384,12 +385,12 @@ class BaseNode(ABC): logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True) error_output = self._wrap_error(str(e), elapsed_time, state) yield error_output - + def _wrap_output( - self, - business_result: Any, - elapsed_time: float, - state: WorkflowState + self, + business_result: Any, + elapsed_time: float, + state: WorkflowState ) -> dict[str, Any]: """将业务结果包装成标准输出格式 @@ -403,13 +404,13 @@ class BaseNode(ABC): """ # 提取输入数据(用于记录) input_data = self._extract_input(state) - + # 提取 token 使用情况(如果有) token_usage = self._extract_token_usage(business_result) - + # 提取实际输出(去除元数据) output = self._extract_output(business_result) - + # 构建标准节点输出 node_output = { "node_id": self.node_id, @@ -422,18 +423,18 @@ class BaseNode(ABC): "token_usage": token_usage, "error": None } - + return { "node_outputs": { self.node_id: node_output } } - + def _wrap_error( - self, - error_message: str, - elapsed_time: float, - state: WorkflowState + self, + error_message: str, + elapsed_time: float, + state: WorkflowState ) -> dict[str, Any]: """将错误包装成标准输出格式 @@ -447,10 +448,10 @@ class BaseNode(ABC): """ # 查找错误边 error_edge = self._find_error_edge() - + # 提取输入数据 input_data = self._extract_input(state) - + # 构建错误输出 node_output = { "node_id": self.node_id, @@ -463,7 +464,7 @@ class BaseNode(ABC): "token_usage": None, "error": error_message } - + if error_edge: # 有错误边:记录错误并继续 logger.warning( @@ -480,7 +481,7 @@ class BaseNode(ABC): # 无错误边:抛出异常停止工作流 logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}") raise Exception(f"节点 {self.node_id} 执行失败: {error_message}") - + def _extract_input(self, state: WorkflowState) -> dict[str, Any]: """提取节点输入数据(用于记录) @@ -494,7 +495,7 @@ class BaseNode(ABC): """ # 默认返回配置 return {"config": self.config} - + def _extract_output(self, business_result: Any) -> Any: """从业务结果中提取实际输出 @@ -508,7 +509,7 @@ class BaseNode(ABC): """ # 默认直接返回业务结果 return business_result - + def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: """从业务结果中提取 token 使用情况 @@ -522,7 +523,7 @@ class BaseNode(ABC): """ # 默认返回 None return None - + def _find_error_edge(self) -> dict[str, Any] | None: """查找错误边 @@ -533,8 +534,8 @@ class BaseNode(ABC): if edge.get("source") == self.node_id and edge.get("type") == "error": return edge return None - - def _render_template(self, template: str, state: WorkflowState | None, struct: bool = True) -> str: + + def _render_template(self, template: str, state: WorkflowState | None, strict: bool = True) -> str: """渲染模板 支持的变量命名空间: @@ -550,28 +551,28 @@ class BaseNode(ABC): 渲染后的字符串 """ from app.core.workflow.template_renderer import render_template - + # 处理 state 为 None 的情况 if state is None: state = {} - + # 使用变量池获取变量 pool = VariablePool(state) - + # 构建完整的 variables 结构 variables = { "sys": pool.get_all_system_vars(), "conv": pool.get_all_conversation_vars() } - + return render_template( template=template, variables=variables, node_outputs=pool.get_all_node_outputs(), system_vars=pool.get_all_system_vars(), - struct=struct + strict=strict ) - + def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool: """评估条件表达式 @@ -588,20 +589,20 @@ class BaseNode(ABC): 布尔值结果 """ from app.core.workflow.expression_evaluator import evaluate_condition - + # 处理 state 为 None 的情况 if state is None: state = {} - + # 使用变量池获取变量 pool = VariablePool(state) - + # 构建完整的 variables 结构(包含 sys 和 conv) variables = { "sys": pool.get_all_system_vars(), "conv": pool.get_all_conversation_vars() } - + return evaluate_condition( expression=expression, variables=variables, @@ -626,12 +627,12 @@ class BaseNode(ABC): >>> llm_output = pool.get("llm_qa.output") """ return VariablePool(state) - + def get_variable( - self, - selector: list[str] | str, - state: WorkflowState, - default: Any = None + self, + selector: list[str] | str, + state: WorkflowState, + default: Any = None ) -> Any: """获取变量值(便捷方法) @@ -650,7 +651,7 @@ class BaseNode(ABC): """ pool = VariablePool(state) return pool.get(selector, default=default) - + def has_variable(self, selector: list[str] | str, state: WorkflowState) -> bool: """检查变量是否存在(便捷方法) diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index 6230345c..6195afbd 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -37,7 +37,7 @@ class EndNode(BaseNode): # 如果配置了输出模板,使用模板渲染;否则使用默认输出 if output_template: - output = self._render_template(output_template, state, struct=False) + output = self._render_template(output_template, state, strict=False) else: output = "工作流已完成" @@ -156,6 +156,16 @@ class EndNode(BaseNode): if not output_template: output = "工作流已完成" + from langgraph.config import get_stream_writer + writer = get_stream_writer() + writer({ + "type": "message", # End node output uses message type + "node_id": self.node_id, + "chunk": "", + "full_content": output, + "chunk_index": 1, + "is_suffix": False + }) yield {"__final__": True, "result": output} return @@ -190,7 +200,7 @@ class EndNode(BaseNode): if upstream_llm_ref_index is None: # No reference to direct upstream LLM node, output complete template content - output = self._render_template(output_template, state) + output = self._render_template(output_template, state, strict=False) logger.info(f"节点 {self.node_id} 没有引用直接上游 LLM 节点,输出完整内容: '{output[:50]}...'") # Send complete content via writer (as a single message chunk) @@ -246,7 +256,7 @@ class EndNode(BaseNode): suffix = "".join(suffix_parts) # 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀) - full_output = self._render_template(output_template, state) + full_output = self._render_template(output_template, state, strict=False) logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}") logger.info(f"[后缀调试] 后缀内容: '{suffix}'") diff --git a/api/app/core/workflow/template_renderer.py b/api/app/core/workflow/template_renderer.py index 198a3322..b6305b8c 100644 --- a/api/app/core/workflow/template_renderer.py +++ b/api/app/core/workflow/template_renderer.py @@ -5,6 +5,7 @@ """ import logging +from collections import defaultdict from typing import Any from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined @@ -12,6 +13,18 @@ from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndef logger = logging.getLogger(__name__) +class SafeUndefined(Undefined): + """访问未定义属性不会报错,返回空字符串""" + __slots__ = () + + def _fail_with_undefined_error(self, *args, **kwargs): + return "" + + __add__ = __radd__ = __mul__ = __rmul__ = __div__ = __rdiv__ = __truediv__ = __rtruediv__ = _fail_with_undefined_error + __getitem__ = __getattr__ = _fail_with_undefined_error + __str__ = __repr__ = lambda self: "" + + class TemplateRenderer: """模板渲染器""" @@ -21,8 +34,9 @@ class TemplateRenderer: Args: strict: 是否使用严格模式(未定义变量会抛出异常) """ + self.strict = strict self.env = Environment( - undefined=StrictUndefined if strict else Undefined, + undefined=StrictUndefined if strict else SafeUndefined, autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML ) @@ -69,12 +83,17 @@ class TemplateRenderer: # variables 的结构:{"sys": {...}, "conv": {...}} sys_vars = variables.get("sys", {}) if isinstance(variables, dict) else {} conv_vars = variables.get("conv", {}) if isinstance(variables, dict) else {} - - context = { - "conv": conv_vars, # 会话变量:{{conv.user_name}} - "node": node_outputs, # 节点输出:{{node.node_1.output}} - "sys": {**(system_vars or {}), **sys_vars}, # 系统变量:{{sys.execution_id}}(合并两个来源) - } + if self.strict: + context = defaultdict(dict) + context["conv"] = conv_vars + context["nodes"] = node_outputs + context["sys"] = {**(system_vars or {}), **sys_vars} + else: + context = { + "conv": conv_vars, # 会话变量:{{conv.user_name}} + "node": node_outputs, # 节点输出:{{node.node_1.output}} + "sys": {**(system_vars or {}), **sys_vars}, # 系统变量:{{sys.execution_id}}(合并两个来源) + } # 支持直接通过节点ID访问节点输出:{{llm_qa.output}} # 将所有节点输出添加到顶层上下文 @@ -141,12 +160,12 @@ def render_template( variables: dict[str, Any], node_outputs: dict[str, Any], system_vars: dict[str, Any] | None = None, - struct: bool = True + strict: bool = True ) -> str: """渲染模板(便捷函数) Args: - struct: 渲染模式 + strict: 严格模式 template: 模板字符串 variables: 用户变量 node_outputs: 节点输出 @@ -164,7 +183,7 @@ def render_template( ... ) '请分析: 这是一段文本' """ - renderer = TemplateRenderer(strict=struct) + renderer = TemplateRenderer(strict=strict) return renderer.render(template, variables, node_outputs, system_vars) diff --git a/api/app/templates/workflows/simple_qa/template.yml b/api/app/templates/workflows/simple_qa/template.yml index 2cf0f9b1..14de4a73 100644 --- a/api/app/templates/workflows/simple_qa/template.yml +++ b/api/app/templates/workflows/simple_qa/template.yml @@ -53,7 +53,7 @@ nodes: type: end name: 结束 config: - output: "{{llm_qa.output}}" + output: "{{ llm_qa.output }}" position: x: 900 y: 100