feat(workflow): support retrieving variables wrapped in {{}} from variable pool

This commit is contained in:
mengyonghao
2026-01-05 10:52:46 +08:00
parent bf6ede64bd
commit fc831e04c1

View File

@@ -10,6 +10,7 @@
"""
import logging
import re
from typing import Any, TYPE_CHECKING
if TYPE_CHECKING:
@@ -28,7 +29,7 @@ class VariableSelector:
>>> selector = VariableSelector(["node_A", "output"])
>>> selector = VariableSelector.from_string("sys.message")
"""
def __init__(self, path: list[str]):
"""初始化变量选择器
@@ -37,11 +38,11 @@ class VariableSelector:
"""
if not path or len(path) < 1:
raise ValueError("变量路径不能为空")
self.path = path
self.namespace = path[0] # sys, var, 或 node_id
self.key = path[1] if len(path) > 1 else None
@classmethod
def from_string(cls, selector_str: str) -> "VariableSelector":
"""从字符串创建选择器
@@ -58,10 +59,10 @@ class VariableSelector:
"""
path = selector_str.split(".")
return cls(path)
def __str__(self) -> str:
return ".".join(self.path)
def __repr__(self) -> str:
return f"VariableSelector({self.path})"
@@ -84,7 +85,7 @@ class VariablePool:
"AI 的回答"
>>> pool.set(["conv", "user_name"], "张三")
"""
def __init__(self, state: "WorkflowState"):
"""初始化变量池
@@ -92,7 +93,7 @@ class VariablePool:
state: 工作流状态LangGraph State
"""
self.state = state
def get(self, selector: list[str] | str, default: Any = None) -> Any:
"""获取变量值
@@ -114,13 +115,15 @@ class VariablePool:
"""
# 转换为 VariableSelector
if isinstance(selector, str):
selector = VariableSelector.from_string(selector).path
pattern = r"\{\{\s*(.*?)\s*\}\}"
variable_literal = re.sub(pattern, r"\1", selector).strip()
selector = VariableSelector.from_string(variable_literal).path
if not selector or len(selector) < 1:
raise ValueError("变量选择器不能为空")
namespace = selector[0]
try:
# 系统变量
if namespace == "sys":
@@ -128,30 +131,30 @@ class VariablePool:
if not key:
return self.state.get("variables", {}).get("sys", {})
return self.state.get("variables", {}).get("sys", {}).get(key, default)
# 会话变量
elif namespace == "conv":
key = selector[1] if len(selector) > 1 else None
if not key:
return self.state.get("variables", {}).get("conv", {})
return self.state.get("variables", {}).get("conv", {}).get(key, default)
# 节点输出(从 runtime_vars 读取)
else:
node_id = namespace
runtime_vars = self.state.get("runtime_vars", {})
if node_id not in runtime_vars:
if default is not None:
return default
raise KeyError(f"节点 '{node_id}' 的输出不存在")
node_var = runtime_vars[node_id]
# 如果只有节点 ID返回整个变量
if len(selector) == 1:
return node_var
# 获取特定字段
# 支持嵌套访问,如 node_id.field.subfield
result = node_var
@@ -166,14 +169,14 @@ class VariablePool:
if default is not None:
return default
raise KeyError(f"无法访问 '{'.'.join(selector)}'")
return result
except KeyError:
if default is not None:
return default
raise
def set(self, selector: list[str] | str, value: Any):
"""设置变量值
@@ -192,17 +195,17 @@ class VariablePool:
# 转换为 VariableSelector
if isinstance(selector, str):
selector = VariableSelector.from_string(selector).path
if not selector or len(selector) < 2:
raise ValueError("变量选择器必须包含命名空间和键名")
namespace = selector[0]
if namespace != "conv" and namespace not in self.state["cycle_nodes"]:
raise ValueError("Only conversation or cycle variables can be assigned.")
key = selector[1]
# 确保 variables 结构存在
if "variables" not in self.state:
self.state["variables"] = {"sys": {}, "conv": {}}
@@ -214,9 +217,9 @@ class VariablePool:
self.state["variables"]["conv"][key] = value
elif namespace in self.state["cycle_nodes"]:
self.state["runtime_vars"][namespace][key] = value
logger.debug(f"设置变量: {'.'.join(selector)} = {value}")
def has(self, selector: list[str] | str) -> bool:
"""检查变量是否存在
@@ -237,7 +240,7 @@ class VariablePool:
return True
except KeyError:
return False
def get_all_system_vars(self) -> dict[str, Any]:
"""获取所有系统变量
@@ -245,7 +248,7 @@ class VariablePool:
系统变量字典
"""
return self.state.get("variables", {}).get("sys", {})
def get_all_conversation_vars(self) -> dict[str, Any]:
"""获取所有会话变量
@@ -253,7 +256,7 @@ class VariablePool:
会话变量字典
"""
return self.state.get("variables", {}).get("conv", {})
def get_all_node_outputs(self) -> dict[str, Any]:
"""获取所有节点输出(运行时变量)
@@ -261,7 +264,7 @@ class VariablePool:
节点输出字典,键为节点 ID
"""
return self.state.get("runtime_vars", {})
def get_node_output(self, node_id: str) -> dict[str, Any] | None:
"""获取指定节点的输出(运行时变量)
@@ -272,7 +275,7 @@ class VariablePool:
节点输出或 None
"""
return self.state.get("runtime_vars", {}).get(node_id)
def to_dict(self) -> dict[str, Any]:
"""导出为字典
@@ -284,12 +287,12 @@ class VariablePool:
"conversation": self.get_all_conversation_vars(),
"nodes": self.get_all_node_outputs() # 从 runtime_vars 读取
}
def __repr__(self) -> str:
sys_vars = self.get_all_system_vars()
conv_vars = self.get_all_conversation_vars()
runtime_vars = self.get_all_node_outputs()
return (
f"VariablePool(\n"
f" system_vars={len(sys_vars)},\n"