fix(workflow): use loose rendering for end-node variables

This commit is contained in:
mengyonghao
2026-01-13 15:04:44 +08:00
parent c4addc7e54
commit fe4a53563e
4 changed files with 126 additions and 96 deletions

View File

@@ -259,7 +259,8 @@ class BaseNode(ABC):
# Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others # Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others
chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk" chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk"
logger.debug(f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})") logger.debug(
f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})")
# Accumulate complete result (for final wrapping) # Accumulate complete result (for final wrapping)
chunks = [] chunks = []
@@ -386,10 +387,10 @@ class BaseNode(ABC):
yield error_output yield error_output
def _wrap_output( def _wrap_output(
self, self,
business_result: Any, business_result: Any,
elapsed_time: float, elapsed_time: float,
state: WorkflowState state: WorkflowState
) -> dict[str, Any]: ) -> dict[str, Any]:
"""将业务结果包装成标准输出格式 """将业务结果包装成标准输出格式
@@ -430,10 +431,10 @@ class BaseNode(ABC):
} }
def _wrap_error( def _wrap_error(
self, self,
error_message: str, error_message: str,
elapsed_time: float, elapsed_time: float,
state: WorkflowState state: WorkflowState
) -> dict[str, Any]: ) -> dict[str, Any]:
"""将错误包装成标准输出格式 """将错误包装成标准输出格式
@@ -534,7 +535,7 @@ class BaseNode(ABC):
return edge return edge
return None return None
def _render_template(self, template: str, state: WorkflowState | None, struct: bool = True) -> str: def _render_template(self, template: str, state: WorkflowState | None, strict: bool = True) -> str:
"""渲染模板 """渲染模板
支持的变量命名空间: 支持的变量命名空间:
@@ -569,7 +570,7 @@ class BaseNode(ABC):
variables=variables, variables=variables,
node_outputs=pool.get_all_node_outputs(), node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(), system_vars=pool.get_all_system_vars(),
struct=struct strict=strict
) )
def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool: def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool:
@@ -628,10 +629,10 @@ class BaseNode(ABC):
return VariablePool(state) return VariablePool(state)
def get_variable( def get_variable(
self, self,
selector: list[str] | str, selector: list[str] | str,
state: WorkflowState, state: WorkflowState,
default: Any = None default: Any = None
) -> Any: ) -> Any:
"""获取变量值(便捷方法) """获取变量值(便捷方法)

View File

@@ -37,7 +37,7 @@ class EndNode(BaseNode):
# 如果配置了输出模板,使用模板渲染;否则使用默认输出 # 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template: if output_template:
output = self._render_template(output_template, state, struct=False) output = self._render_template(output_template, state, strict=False)
else: else:
output = "工作流已完成" output = "工作流已完成"
@@ -156,6 +156,16 @@ class EndNode(BaseNode):
if not output_template: if not output_template:
output = "工作流已完成" output = "工作流已完成"
from langgraph.config import get_stream_writer
writer = get_stream_writer()
writer({
"type": "message", # End node output uses message type
"node_id": self.node_id,
"chunk": "",
"full_content": output,
"chunk_index": 1,
"is_suffix": False
})
yield {"__final__": True, "result": output} yield {"__final__": True, "result": output}
return return
@@ -190,7 +200,7 @@ class EndNode(BaseNode):
if upstream_llm_ref_index is None: if upstream_llm_ref_index is None:
# No reference to direct upstream LLM node, output complete template content # No reference to direct upstream LLM node, output complete template content
output = self._render_template(output_template, state) output = self._render_template(output_template, state, strict=False)
logger.info(f"节点 {self.node_id} 没有引用直接上游 LLM 节点,输出完整内容: '{output[:50]}...'") logger.info(f"节点 {self.node_id} 没有引用直接上游 LLM 节点,输出完整内容: '{output[:50]}...'")
# Send complete content via writer (as a single message chunk) # Send complete content via writer (as a single message chunk)
@@ -246,7 +256,7 @@ class EndNode(BaseNode):
suffix = "".join(suffix_parts) suffix = "".join(suffix_parts)
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀) # 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
full_output = self._render_template(output_template, state) full_output = self._render_template(output_template, state, strict=False)
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}") logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
logger.info(f"[后缀调试] 后缀内容: '{suffix}'") logger.info(f"[后缀调试] 后缀内容: '{suffix}'")

View File

@@ -5,6 +5,7 @@
""" """
import logging import logging
from collections import defaultdict
from typing import Any from typing import Any
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
@@ -12,6 +13,18 @@ from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndef
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SafeUndefined(Undefined):
"""访问未定义属性不会报错,返回空字符串"""
__slots__ = ()
def _fail_with_undefined_error(self, *args, **kwargs):
return ""
__add__ = __radd__ = __mul__ = __rmul__ = __div__ = __rdiv__ = __truediv__ = __rtruediv__ = _fail_with_undefined_error
__getitem__ = __getattr__ = _fail_with_undefined_error
__str__ = __repr__ = lambda self: ""
class TemplateRenderer: class TemplateRenderer:
"""模板渲染器""" """模板渲染器"""
@@ -21,8 +34,9 @@ class TemplateRenderer:
Args: Args:
strict: 是否使用严格模式(未定义变量会抛出异常) strict: 是否使用严格模式(未定义变量会抛出异常)
""" """
self.strict = strict
self.env = Environment( self.env = Environment(
undefined=StrictUndefined if strict else Undefined, undefined=StrictUndefined if strict else SafeUndefined,
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
) )
@@ -69,12 +83,17 @@ class TemplateRenderer:
# variables 的结构:{"sys": {...}, "conv": {...}} # variables 的结构:{"sys": {...}, "conv": {...}}
sys_vars = variables.get("sys", {}) if isinstance(variables, dict) else {} sys_vars = variables.get("sys", {}) if isinstance(variables, dict) else {}
conv_vars = variables.get("conv", {}) if isinstance(variables, dict) else {} conv_vars = variables.get("conv", {}) if isinstance(variables, dict) else {}
if self.strict:
context = { context = defaultdict(dict)
"conv": conv_vars, # 会话变量:{{conv.user_name}} context["conv"] = conv_vars
"node": node_outputs, # 节点输出:{{node.node_1.output}} context["nodes"] = node_outputs
"sys": {**(system_vars or {}), **sys_vars}, # 系统变量:{{sys.execution_id}}(合并两个来源) context["sys"] = {**(system_vars or {}), **sys_vars}
} else:
context = {
"conv": conv_vars, # 会话变量:{{conv.user_name}}
"node": node_outputs, # 节点输出:{{node.node_1.output}}
"sys": {**(system_vars or {}), **sys_vars}, # 系统变量:{{sys.execution_id}}(合并两个来源)
}
# 支持直接通过节点ID访问节点输出{{llm_qa.output}} # 支持直接通过节点ID访问节点输出{{llm_qa.output}}
# 将所有节点输出添加到顶层上下文 # 将所有节点输出添加到顶层上下文
@@ -141,12 +160,12 @@ def render_template(
variables: dict[str, Any], variables: dict[str, Any],
node_outputs: dict[str, Any], node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None, system_vars: dict[str, Any] | None = None,
struct: bool = True strict: bool = True
) -> str: ) -> str:
"""渲染模板(便捷函数) """渲染模板(便捷函数)
Args: Args:
struct: 渲染模式 strict: 严格模式
template: 模板字符串 template: 模板字符串
variables: 用户变量 variables: 用户变量
node_outputs: 节点输出 node_outputs: 节点输出
@@ -164,7 +183,7 @@ def render_template(
... ) ... )
'请分析: 这是一段文本' '请分析: 这是一段文本'
""" """
renderer = TemplateRenderer(strict=struct) renderer = TemplateRenderer(strict=strict)
return renderer.render(template, variables, node_outputs, system_vars) return renderer.render(template, variables, node_outputs, system_vars)

View File

@@ -53,7 +53,7 @@ nodes:
type: end type: end
name: 结束 name: 结束
config: config:
output: "{{llm_qa.output}}" output: "{{ llm_qa.output }}"
position: position:
x: 900 x: 900
y: 100 y: 100