fix(workflow): use loose rendering for end-node variables

This commit is contained in:
mengyonghao
2026-01-13 15:04:44 +08:00
parent c4addc7e54
commit fe4a53563e
4 changed files with 126 additions and 96 deletions

View File

@@ -35,7 +35,7 @@ class WorkflowState(TypedDict):
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx) # Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)
variables: Annotated[dict[str, Any], lambda x, y: { variables: Annotated[dict[str, Any], lambda x, y: {
**x, **x,
**{k: {**x.get(k, {}), **v} if isinstance(v, dict) and isinstance(x.get(k), dict) else v **{k: {**x.get(k, {}), **v} if isinstance(v, dict) and isinstance(x.get(k), dict) else v
for k, v in y.items()} for k, v in y.items()}
}] }]
@@ -46,12 +46,12 @@ class WorkflowState(TypedDict):
# Runtime node variables (simplified version, stores business data for fast access between nodes) # Runtime node variables (simplified version, stores business data for fast access between nodes)
# Format: {node_id: business_result} # Format: {node_id: business_result}
runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}] runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
# Execution context # Execution context
execution_id: str execution_id: str
workspace_id: str workspace_id: str
user_id: str user_id: str
# Error information (for error edges) # Error information (for error edges)
error: str | None error: str | None
error_node: str | None error_node: str | None
@@ -66,7 +66,7 @@ class BaseNode(ABC):
所有节点类型都应该继承此基类,实现 execute 方法。 所有节点类型都应该继承此基类,实现 execute 方法。
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
"""初始化节点 """初始化节点
@@ -83,7 +83,7 @@ class BaseNode(ABC):
# 使用 or 运算符处理 None 值 # 使用 or 运算符处理 None 值
self.config = node_config.get("config") or {} self.config = node_config.get("config") or {}
self.error_handling = node_config.get("error_handling") or {} self.error_handling = node_config.get("error_handling") or {}
@abstractmethod @abstractmethod
async def execute(self, state: WorkflowState) -> Any: async def execute(self, state: WorkflowState) -> Any:
"""执行节点业务逻辑(非流式) """执行节点业务逻辑(非流式)
@@ -108,7 +108,7 @@ class BaseNode(ABC):
>>> return {"message": "开始", "conversation_id": "xxx"} >>> return {"message": "开始", "conversation_id": "xxx"}
""" """
pass pass
async def execute_stream(self, state: WorkflowState): async def execute_stream(self, state: WorkflowState):
"""执行节点业务逻辑(流式) """执行节点业务逻辑(流式)
@@ -138,7 +138,7 @@ class BaseNode(ABC):
result = await self.execute(state) result = await self.execute(state)
# 默认实现:直接 yield 完成标记 # 默认实现:直接 yield 完成标记
yield {"__final__": True, "result": result} yield {"__final__": True, "result": result}
def supports_streaming(self) -> bool: def supports_streaming(self) -> bool:
"""节点是否支持流式输出 """节点是否支持流式输出
@@ -147,7 +147,7 @@ class BaseNode(ABC):
""" """
# 检查子类是否重写了 execute_stream 方法 # 检查子类是否重写了 execute_stream 方法
return self.execute_stream.__func__ != BaseNode.execute_stream.__func__ return self.execute_stream.__func__ != BaseNode.execute_stream.__func__
def get_timeout(self) -> int: def get_timeout(self) -> int:
"""获取超时时间(秒) """获取超时时间(秒)
@@ -156,7 +156,7 @@ class BaseNode(ABC):
""" """
return 60 return 60
# return self.error_handling.get("timeout", 60) # return self.error_handling.get("timeout", 60)
async def run(self, state: WorkflowState) -> dict[str, Any]: async def run(self, state: WorkflowState) -> dict[str, Any]:
"""执行节点(带错误处理和输出包装,非流式) """执行节点(带错误处理和输出包装,非流式)
@@ -173,33 +173,33 @@ class BaseNode(ABC):
标准化的状态更新字典 标准化的状态更新字典
""" """
import time import time
start_time = time.time() start_time = time.time()
timeout = self.get_timeout() timeout = self.get_timeout()
try: try:
# 调用节点的业务逻辑 # 调用节点的业务逻辑
business_result = await asyncio.wait_for( business_result = await asyncio.wait_for(
self.execute(state), self.execute(state),
timeout=timeout timeout=timeout
) )
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
# 提取处理后的输出(调用子类的 _extract_output # 提取处理后的输出(调用子类的 _extract_output
extracted_output = self._extract_output(business_result) extracted_output = self._extract_output(business_result)
# 包装成标准输出格式 # 包装成标准输出格式
wrapped_output = self._wrap_output(business_result, elapsed_time, state) wrapped_output = self._wrap_output(business_result, elapsed_time, state)
# 将提取后的输出存储到运行时变量中(供后续节点快速访问) # 将提取后的输出存储到运行时变量中(供后续节点快速访问)
# 如果提取后的输出是字典,拆包存储;否则存储为 output 字段 # 如果提取后的输出是字典,拆包存储;否则存储为 output 字段
if isinstance(extracted_output, dict): if isinstance(extracted_output, dict):
runtime_var = extracted_output runtime_var = extracted_output
else: else:
runtime_var = {"output": extracted_output} runtime_var = {"output": extracted_output}
# 返回包装后的输出和运行时变量 # 返回包装后的输出和运行时变量
return { return {
**wrapped_output, **wrapped_output,
@@ -208,7 +208,7 @@ class BaseNode(ABC):
}, },
"looping": state["looping"] "looping": state["looping"]
} }
except TimeoutError: except TimeoutError:
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)") logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
@@ -217,7 +217,7 @@ class BaseNode(ABC):
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True) logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
return self._wrap_error(str(e), elapsed_time, state) return self._wrap_error(str(e), elapsed_time, state)
async def run_stream(self, state: WorkflowState): async def run_stream(self, state: WorkflowState):
"""Execute node with error handling and output wrapping (streaming) """Execute node with error handling and output wrapping (streaming)
@@ -240,40 +240,41 @@ class BaseNode(ABC):
State updates with streaming buffer and final result State updates with streaming buffer and final result
""" """
import time import time
start_time = time.time() start_time = time.time()
timeout = self.get_timeout() timeout = self.get_timeout()
try: try:
# Get LangGraph's stream writer for sending custom data # Get LangGraph's stream writer for sending custom data
writer = get_stream_writer() writer = get_stream_writer()
# Check if this is an End node # Check if this is an End node
# End nodes CAN send chunks (for suffix), but only after LLM content # End nodes CAN send chunks (for suffix), but only after LLM content
is_end_node = self.node_type == "end" is_end_node = self.node_type == "end"
# Check if this node is adjacent to End node (for message type) # Check if this node is adjacent to End node (for message type)
is_adjacent_to_end = getattr(self, '_is_adjacent_to_end', False) is_adjacent_to_end = getattr(self, '_is_adjacent_to_end', False)
# Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others # Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others
chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk" chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk"
logger.debug(f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})") logger.debug(
f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})")
# Accumulate complete result (for final wrapping) # Accumulate complete result (for final wrapping)
chunks = [] chunks = []
final_result = None final_result = None
chunk_count = 0 chunk_count = 0
# Stream chunks in real-time # Stream chunks in real-time
loop_start = asyncio.get_event_loop().time() loop_start = asyncio.get_event_loop().time()
async for item in self.execute_stream(state): async for item in self.execute_stream(state):
# Check timeout # Check timeout
if asyncio.get_event_loop().time() - loop_start > timeout: if asyncio.get_event_loop().time() - loop_start > timeout:
raise TimeoutError() raise TimeoutError()
# Check if it's a completion marker # Check if it's a completion marker
if isinstance(item, dict) and item.get("__final__"): if isinstance(item, dict) and item.get("__final__"):
final_result = item["result"] final_result = item["result"]
@@ -282,10 +283,10 @@ class BaseNode(ABC):
chunk_count += 1 chunk_count += 1
chunks.append(item) chunks.append(item)
full_content = "".join(chunks) full_content = "".join(chunks)
# Send chunks for all nodes (including End nodes for suffix) # Send chunks for all nodes (including End nodes for suffix)
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {item[:50]}...") logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {item[:50]}...")
# 1. Send via stream writer (for real-time client updates) # 1. Send via stream writer (for real-time client updates)
writer({ writer({
"type": chunk_type, # "message" or "node_chunk" "type": chunk_type, # "message" or "node_chunk"
@@ -294,7 +295,7 @@ class BaseNode(ABC):
"full_content": full_content, "full_content": full_content,
"chunk_index": chunk_count "chunk_index": chunk_count
}) })
# 2. Update streaming buffer in state (for downstream nodes) # 2. Update streaming buffer in state (for downstream nodes)
# Only non-End nodes need streaming buffer # Only non-End nodes need streaming buffer
if not is_end_node: if not is_end_node:
@@ -313,7 +314,7 @@ class BaseNode(ABC):
chunk_str = str(item) chunk_str = str(item)
chunks.append(chunk_str) chunks.append(chunk_str)
full_content = "".join(chunks) full_content = "".join(chunks)
# Send chunks for all nodes # Send chunks for all nodes
writer({ writer({
"type": chunk_type, # "message" or "node_chunk" "type": chunk_type, # "message" or "node_chunk"
@@ -322,7 +323,7 @@ class BaseNode(ABC):
"full_content": full_content, "full_content": full_content,
"chunk_index": chunk_count "chunk_index": chunk_count
}) })
# Only non-End nodes need streaming buffer # Only non-End nodes need streaming buffer
if not is_end_node: if not is_end_node:
yield { yield {
@@ -334,23 +335,23 @@ class BaseNode(ABC):
} }
} }
} }
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}") logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
# Extract processed output (call subclass's _extract_output) # Extract processed output (call subclass's _extract_output)
extracted_output = self._extract_output(final_result) extracted_output = self._extract_output(final_result)
# Wrap final result # Wrap final result
final_output = self._wrap_output(final_result, elapsed_time, state) final_output = self._wrap_output(final_result, elapsed_time, state)
# Store extracted output in runtime variables (for quick access by subsequent nodes) # Store extracted output in runtime variables (for quick access by subsequent nodes)
if isinstance(extracted_output, dict): if isinstance(extracted_output, dict):
runtime_var = extracted_output runtime_var = extracted_output
else: else:
runtime_var = {"output": extracted_output} runtime_var = {"output": extracted_output}
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer) # Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
state_update = { state_update = {
**final_output, **final_output,
@@ -359,7 +360,7 @@ class BaseNode(ABC):
}, },
"looping": state["looping"] "looping": state["looping"]
} }
# Add streaming buffer for non-End nodes # Add streaming buffer for non-End nodes
if not is_end_node: if not is_end_node:
state_update["streaming_buffer"] = { state_update["streaming_buffer"] = {
@@ -369,11 +370,11 @@ class BaseNode(ABC):
"is_complete": True # Mark as complete "is_complete": True # Mark as complete
} }
} }
# Finally yield state update # Finally yield state update
# LangGraph will merge this into state # LangGraph will merge this into state
yield state_update yield state_update
except TimeoutError: except TimeoutError:
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行超时 ({timeout}s)") logger.error(f"节点 {self.node_id} 执行超时 ({timeout}s)")
@@ -384,12 +385,12 @@ class BaseNode(ABC):
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True) logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
error_output = self._wrap_error(str(e), elapsed_time, state) error_output = self._wrap_error(str(e), elapsed_time, state)
yield error_output yield error_output
def _wrap_output( def _wrap_output(
self, self,
business_result: Any, business_result: Any,
elapsed_time: float, elapsed_time: float,
state: WorkflowState state: WorkflowState
) -> dict[str, Any]: ) -> dict[str, Any]:
"""将业务结果包装成标准输出格式 """将业务结果包装成标准输出格式
@@ -403,13 +404,13 @@ class BaseNode(ABC):
""" """
# 提取输入数据(用于记录) # 提取输入数据(用于记录)
input_data = self._extract_input(state) input_data = self._extract_input(state)
# 提取 token 使用情况(如果有) # 提取 token 使用情况(如果有)
token_usage = self._extract_token_usage(business_result) token_usage = self._extract_token_usage(business_result)
# 提取实际输出(去除元数据) # 提取实际输出(去除元数据)
output = self._extract_output(business_result) output = self._extract_output(business_result)
# 构建标准节点输出 # 构建标准节点输出
node_output = { node_output = {
"node_id": self.node_id, "node_id": self.node_id,
@@ -422,18 +423,18 @@ class BaseNode(ABC):
"token_usage": token_usage, "token_usage": token_usage,
"error": None "error": None
} }
return { return {
"node_outputs": { "node_outputs": {
self.node_id: node_output self.node_id: node_output
} }
} }
def _wrap_error( def _wrap_error(
self, self,
error_message: str, error_message: str,
elapsed_time: float, elapsed_time: float,
state: WorkflowState state: WorkflowState
) -> dict[str, Any]: ) -> dict[str, Any]:
"""将错误包装成标准输出格式 """将错误包装成标准输出格式
@@ -447,10 +448,10 @@ class BaseNode(ABC):
""" """
# 查找错误边 # 查找错误边
error_edge = self._find_error_edge() error_edge = self._find_error_edge()
# 提取输入数据 # 提取输入数据
input_data = self._extract_input(state) input_data = self._extract_input(state)
# 构建错误输出 # 构建错误输出
node_output = { node_output = {
"node_id": self.node_id, "node_id": self.node_id,
@@ -463,7 +464,7 @@ class BaseNode(ABC):
"token_usage": None, "token_usage": None,
"error": error_message "error": error_message
} }
if error_edge: if error_edge:
# 有错误边:记录错误并继续 # 有错误边:记录错误并继续
logger.warning( logger.warning(
@@ -480,7 +481,7 @@ class BaseNode(ABC):
# 无错误边:抛出异常停止工作流 # 无错误边:抛出异常停止工作流
logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}") logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}")
raise Exception(f"节点 {self.node_id} 执行失败: {error_message}") raise Exception(f"节点 {self.node_id} 执行失败: {error_message}")
def _extract_input(self, state: WorkflowState) -> dict[str, Any]: def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
"""提取节点输入数据(用于记录) """提取节点输入数据(用于记录)
@@ -494,7 +495,7 @@ class BaseNode(ABC):
""" """
# 默认返回配置 # 默认返回配置
return {"config": self.config} return {"config": self.config}
def _extract_output(self, business_result: Any) -> Any: def _extract_output(self, business_result: Any) -> Any:
"""从业务结果中提取实际输出 """从业务结果中提取实际输出
@@ -508,7 +509,7 @@ class BaseNode(ABC):
""" """
# 默认直接返回业务结果 # 默认直接返回业务结果
return business_result return business_result
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
"""从业务结果中提取 token 使用情况 """从业务结果中提取 token 使用情况
@@ -522,7 +523,7 @@ class BaseNode(ABC):
""" """
# 默认返回 None # 默认返回 None
return None return None
def _find_error_edge(self) -> dict[str, Any] | None: def _find_error_edge(self) -> dict[str, Any] | None:
"""查找错误边 """查找错误边
@@ -533,8 +534,8 @@ class BaseNode(ABC):
if edge.get("source") == self.node_id and edge.get("type") == "error": if edge.get("source") == self.node_id and edge.get("type") == "error":
return edge return edge
return None return None
def _render_template(self, template: str, state: WorkflowState | None, struct: bool = True) -> str: def _render_template(self, template: str, state: WorkflowState | None, strict: bool = True) -> str:
"""渲染模板 """渲染模板
支持的变量命名空间: 支持的变量命名空间:
@@ -550,28 +551,28 @@ class BaseNode(ABC):
渲染后的字符串 渲染后的字符串
""" """
from app.core.workflow.template_renderer import render_template from app.core.workflow.template_renderer import render_template
# 处理 state 为 None 的情况 # 处理 state 为 None 的情况
if state is None: if state is None:
state = {} state = {}
# 使用变量池获取变量 # 使用变量池获取变量
pool = VariablePool(state) pool = VariablePool(state)
# 构建完整的 variables 结构 # 构建完整的 variables 结构
variables = { variables = {
"sys": pool.get_all_system_vars(), "sys": pool.get_all_system_vars(),
"conv": pool.get_all_conversation_vars() "conv": pool.get_all_conversation_vars()
} }
return render_template( return render_template(
template=template, template=template,
variables=variables, variables=variables,
node_outputs=pool.get_all_node_outputs(), node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(), system_vars=pool.get_all_system_vars(),
struct=struct strict=strict
) )
def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool: def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool:
"""评估条件表达式 """评估条件表达式
@@ -588,20 +589,20 @@ class BaseNode(ABC):
布尔值结果 布尔值结果
""" """
from app.core.workflow.expression_evaluator import evaluate_condition from app.core.workflow.expression_evaluator import evaluate_condition
# 处理 state 为 None 的情况 # 处理 state 为 None 的情况
if state is None: if state is None:
state = {} state = {}
# 使用变量池获取变量 # 使用变量池获取变量
pool = VariablePool(state) pool = VariablePool(state)
# 构建完整的 variables 结构(包含 sys 和 conv # 构建完整的 variables 结构(包含 sys 和 conv
variables = { variables = {
"sys": pool.get_all_system_vars(), "sys": pool.get_all_system_vars(),
"conv": pool.get_all_conversation_vars() "conv": pool.get_all_conversation_vars()
} }
return evaluate_condition( return evaluate_condition(
expression=expression, expression=expression,
variables=variables, variables=variables,
@@ -626,12 +627,12 @@ class BaseNode(ABC):
>>> llm_output = pool.get("llm_qa.output") >>> llm_output = pool.get("llm_qa.output")
""" """
return VariablePool(state) return VariablePool(state)
def get_variable( def get_variable(
self, self,
selector: list[str] | str, selector: list[str] | str,
state: WorkflowState, state: WorkflowState,
default: Any = None default: Any = None
) -> Any: ) -> Any:
"""获取变量值(便捷方法) """获取变量值(便捷方法)
@@ -650,7 +651,7 @@ class BaseNode(ABC):
""" """
pool = VariablePool(state) pool = VariablePool(state)
return pool.get(selector, default=default) return pool.get(selector, default=default)
def has_variable(self, selector: list[str] | str, state: WorkflowState) -> bool: def has_variable(self, selector: list[str] | str, state: WorkflowState) -> bool:
"""检查变量是否存在(便捷方法) """检查变量是否存在(便捷方法)

View File

@@ -37,7 +37,7 @@ class EndNode(BaseNode):
# 如果配置了输出模板,使用模板渲染;否则使用默认输出 # 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template: if output_template:
output = self._render_template(output_template, state, struct=False) output = self._render_template(output_template, state, strict=False)
else: else:
output = "工作流已完成" output = "工作流已完成"
@@ -156,6 +156,16 @@ class EndNode(BaseNode):
if not output_template: if not output_template:
output = "工作流已完成" output = "工作流已完成"
from langgraph.config import get_stream_writer
writer = get_stream_writer()
writer({
"type": "message", # End node output uses message type
"node_id": self.node_id,
"chunk": "",
"full_content": output,
"chunk_index": 1,
"is_suffix": False
})
yield {"__final__": True, "result": output} yield {"__final__": True, "result": output}
return return
@@ -190,7 +200,7 @@ class EndNode(BaseNode):
if upstream_llm_ref_index is None: if upstream_llm_ref_index is None:
# No reference to direct upstream LLM node, output complete template content # No reference to direct upstream LLM node, output complete template content
output = self._render_template(output_template, state) output = self._render_template(output_template, state, strict=False)
logger.info(f"节点 {self.node_id} 没有引用直接上游 LLM 节点,输出完整内容: '{output[:50]}...'") logger.info(f"节点 {self.node_id} 没有引用直接上游 LLM 节点,输出完整内容: '{output[:50]}...'")
# Send complete content via writer (as a single message chunk) # Send complete content via writer (as a single message chunk)
@@ -246,7 +256,7 @@ class EndNode(BaseNode):
suffix = "".join(suffix_parts) suffix = "".join(suffix_parts)
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀) # 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
full_output = self._render_template(output_template, state) full_output = self._render_template(output_template, state, strict=False)
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}") logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
logger.info(f"[后缀调试] 后缀内容: '{suffix}'") logger.info(f"[后缀调试] 后缀内容: '{suffix}'")

View File

@@ -5,6 +5,7 @@
""" """
import logging import logging
from collections import defaultdict
from typing import Any from typing import Any
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
@@ -12,6 +13,18 @@ from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndef
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SafeUndefined(Undefined):
"""访问未定义属性不会报错,返回空字符串"""
__slots__ = ()
def _fail_with_undefined_error(self, *args, **kwargs):
return ""
__add__ = __radd__ = __mul__ = __rmul__ = __div__ = __rdiv__ = __truediv__ = __rtruediv__ = _fail_with_undefined_error
__getitem__ = __getattr__ = _fail_with_undefined_error
__str__ = __repr__ = lambda self: ""
class TemplateRenderer: class TemplateRenderer:
"""模板渲染器""" """模板渲染器"""
@@ -21,8 +34,9 @@ class TemplateRenderer:
Args: Args:
strict: 是否使用严格模式(未定义变量会抛出异常) strict: 是否使用严格模式(未定义变量会抛出异常)
""" """
self.strict = strict
self.env = Environment( self.env = Environment(
undefined=StrictUndefined if strict else Undefined, undefined=StrictUndefined if strict else SafeUndefined,
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
) )
@@ -69,12 +83,17 @@ class TemplateRenderer:
# variables 的结构:{"sys": {...}, "conv": {...}} # variables 的结构:{"sys": {...}, "conv": {...}}
sys_vars = variables.get("sys", {}) if isinstance(variables, dict) else {} sys_vars = variables.get("sys", {}) if isinstance(variables, dict) else {}
conv_vars = variables.get("conv", {}) if isinstance(variables, dict) else {} conv_vars = variables.get("conv", {}) if isinstance(variables, dict) else {}
if self.strict:
context = { context = defaultdict(dict)
"conv": conv_vars, # 会话变量:{{conv.user_name}} context["conv"] = conv_vars
"node": node_outputs, # 节点输出:{{node.node_1.output}} context["nodes"] = node_outputs
"sys": {**(system_vars or {}), **sys_vars}, # 系统变量:{{sys.execution_id}}(合并两个来源) context["sys"] = {**(system_vars or {}), **sys_vars}
} else:
context = {
"conv": conv_vars, # 会话变量:{{conv.user_name}}
"node": node_outputs, # 节点输出:{{node.node_1.output}}
"sys": {**(system_vars or {}), **sys_vars}, # 系统变量:{{sys.execution_id}}(合并两个来源)
}
# 支持直接通过节点ID访问节点输出{{llm_qa.output}} # 支持直接通过节点ID访问节点输出{{llm_qa.output}}
# 将所有节点输出添加到顶层上下文 # 将所有节点输出添加到顶层上下文
@@ -141,12 +160,12 @@ def render_template(
variables: dict[str, Any], variables: dict[str, Any],
node_outputs: dict[str, Any], node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None, system_vars: dict[str, Any] | None = None,
struct: bool = True strict: bool = True
) -> str: ) -> str:
"""渲染模板(便捷函数) """渲染模板(便捷函数)
Args: Args:
struct: 渲染模式 strict: 严格模式
template: 模板字符串 template: 模板字符串
variables: 用户变量 variables: 用户变量
node_outputs: 节点输出 node_outputs: 节点输出
@@ -164,7 +183,7 @@ def render_template(
... ) ... )
'请分析: 这是一段文本' '请分析: 这是一段文本'
""" """
renderer = TemplateRenderer(strict=struct) renderer = TemplateRenderer(strict=strict)
return renderer.render(template, variables, node_outputs, system_vars) return renderer.render(template, variables, node_outputs, system_vars)

View File

@@ -53,7 +53,7 @@ nodes:
type: end type: end
name: 结束 name: 结束
config: config:
output: "{{llm_qa.output}}" output: "{{ llm_qa.output }}"
position: position:
x: 900 x: 900
y: 100 y: 100