From fc831e04c1428846e4396674dff9ee848a284506 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Mon, 5 Jan 2026 10:52:46 +0800 Subject: [PATCH] feat(workflow): support retrieving variables wrapped in {{}} from variable pool --- api/app/core/workflow/variable_pool.py | 71 ++++++++++++++------------ 1 file changed, 37 insertions(+), 34 deletions(-) diff --git a/api/app/core/workflow/variable_pool.py b/api/app/core/workflow/variable_pool.py index b7814f28..7d4b0609 100644 --- a/api/app/core/workflow/variable_pool.py +++ b/api/app/core/workflow/variable_pool.py @@ -10,6 +10,7 @@ """ import logging +import re from typing import Any, TYPE_CHECKING if TYPE_CHECKING: @@ -28,7 +29,7 @@ class VariableSelector: >>> selector = VariableSelector(["node_A", "output"]) >>> selector = VariableSelector.from_string("sys.message") """ - + def __init__(self, path: list[str]): """初始化变量选择器 @@ -37,11 +38,11 @@ class VariableSelector: """ if not path or len(path) < 1: raise ValueError("变量路径不能为空") - + self.path = path self.namespace = path[0] # sys, var, 或 node_id self.key = path[1] if len(path) > 1 else None - + @classmethod def from_string(cls, selector_str: str) -> "VariableSelector": """从字符串创建选择器 @@ -58,10 +59,10 @@ class VariableSelector: """ path = selector_str.split(".") return cls(path) - + def __str__(self) -> str: return ".".join(self.path) - + def __repr__(self) -> str: return f"VariableSelector({self.path})" @@ -84,7 +85,7 @@ class VariablePool: "AI 的回答" >>> pool.set(["conv", "user_name"], "张三") """ - + def __init__(self, state: "WorkflowState"): """初始化变量池 @@ -92,7 +93,7 @@ class VariablePool: state: 工作流状态(LangGraph State) """ self.state = state - + def get(self, selector: list[str] | str, default: Any = None) -> Any: """获取变量值 @@ -114,13 +115,15 @@ class VariablePool: """ # 转换为 VariableSelector if isinstance(selector, str): - selector = VariableSelector.from_string(selector).path - + pattern = r"\{\{\s*(.*?)\s*\}\}" + variable_literal = re.sub(pattern, r"\1", selector).strip() + selector = VariableSelector.from_string(variable_literal).path + if not selector or len(selector) < 1: raise ValueError("变量选择器不能为空") - + namespace = selector[0] - + try: # 系统变量 if namespace == "sys": @@ -128,30 +131,30 @@ class VariablePool: if not key: return self.state.get("variables", {}).get("sys", {}) return self.state.get("variables", {}).get("sys", {}).get(key, default) - + # 会话变量 elif namespace == "conv": key = selector[1] if len(selector) > 1 else None if not key: return self.state.get("variables", {}).get("conv", {}) return self.state.get("variables", {}).get("conv", {}).get(key, default) - + # 节点输出(从 runtime_vars 读取) else: node_id = namespace runtime_vars = self.state.get("runtime_vars", {}) - + if node_id not in runtime_vars: if default is not None: return default raise KeyError(f"节点 '{node_id}' 的输出不存在") - + node_var = runtime_vars[node_id] - + # 如果只有节点 ID,返回整个变量 if len(selector) == 1: return node_var - + # 获取特定字段 # 支持嵌套访问,如 node_id.field.subfield result = node_var @@ -166,14 +169,14 @@ class VariablePool: if default is not None: return default raise KeyError(f"无法访问 '{'.'.join(selector)}'") - + return result - + except KeyError: if default is not None: return default raise - + def set(self, selector: list[str] | str, value: Any): """设置变量值 @@ -192,17 +195,17 @@ class VariablePool: # 转换为 VariableSelector if isinstance(selector, str): selector = VariableSelector.from_string(selector).path - + if not selector or len(selector) < 2: raise ValueError("变量选择器必须包含命名空间和键名") - + namespace = selector[0] - + if namespace != "conv" and namespace not in self.state["cycle_nodes"]: raise ValueError("Only conversation or cycle variables can be assigned.") - + key = selector[1] - + # 确保 variables 结构存在 if "variables" not in self.state: self.state["variables"] = {"sys": {}, "conv": {}} @@ -214,9 +217,9 @@ class VariablePool: self.state["variables"]["conv"][key] = value elif namespace in self.state["cycle_nodes"]: self.state["runtime_vars"][namespace][key] = value - + logger.debug(f"设置变量: {'.'.join(selector)} = {value}") - + def has(self, selector: list[str] | str) -> bool: """检查变量是否存在 @@ -237,7 +240,7 @@ class VariablePool: return True except KeyError: return False - + def get_all_system_vars(self) -> dict[str, Any]: """获取所有系统变量 @@ -245,7 +248,7 @@ class VariablePool: 系统变量字典 """ return self.state.get("variables", {}).get("sys", {}) - + def get_all_conversation_vars(self) -> dict[str, Any]: """获取所有会话变量 @@ -253,7 +256,7 @@ class VariablePool: 会话变量字典 """ return self.state.get("variables", {}).get("conv", {}) - + def get_all_node_outputs(self) -> dict[str, Any]: """获取所有节点输出(运行时变量) @@ -261,7 +264,7 @@ class VariablePool: 节点输出字典,键为节点 ID """ return self.state.get("runtime_vars", {}) - + def get_node_output(self, node_id: str) -> dict[str, Any] | None: """获取指定节点的输出(运行时变量) @@ -272,7 +275,7 @@ class VariablePool: 节点输出或 None """ return self.state.get("runtime_vars", {}).get(node_id) - + def to_dict(self) -> dict[str, Any]: """导出为字典 @@ -284,12 +287,12 @@ class VariablePool: "conversation": self.get_all_conversation_vars(), "nodes": self.get_all_node_outputs() # 从 runtime_vars 读取 } - + def __repr__(self) -> str: sys_vars = self.get_all_system_vars() conv_vars = self.get_all_conversation_vars() runtime_vars = self.get_all_node_outputs() - + return ( f"VariablePool(\n" f" system_vars={len(sys_vars)},\n"