feat(workflow): support retrieving variables wrapped in {{}} from variable pool
This commit is contained in:
@@ -10,6 +10,7 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -28,7 +29,7 @@ class VariableSelector:
|
||||
>>> selector = VariableSelector(["node_A", "output"])
|
||||
>>> selector = VariableSelector.from_string("sys.message")
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, path: list[str]):
|
||||
"""初始化变量选择器
|
||||
|
||||
@@ -37,11 +38,11 @@ class VariableSelector:
|
||||
"""
|
||||
if not path or len(path) < 1:
|
||||
raise ValueError("变量路径不能为空")
|
||||
|
||||
|
||||
self.path = path
|
||||
self.namespace = path[0] # sys, var, 或 node_id
|
||||
self.key = path[1] if len(path) > 1 else None
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, selector_str: str) -> "VariableSelector":
|
||||
"""从字符串创建选择器
|
||||
@@ -58,10 +59,10 @@ class VariableSelector:
|
||||
"""
|
||||
path = selector_str.split(".")
|
||||
return cls(path)
|
||||
|
||||
|
||||
def __str__(self) -> str:
|
||||
return ".".join(self.path)
|
||||
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"VariableSelector({self.path})"
|
||||
|
||||
@@ -84,7 +85,7 @@ class VariablePool:
|
||||
"AI 的回答"
|
||||
>>> pool.set(["conv", "user_name"], "张三")
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, state: "WorkflowState"):
|
||||
"""初始化变量池
|
||||
|
||||
@@ -92,7 +93,7 @@ class VariablePool:
|
||||
state: 工作流状态(LangGraph State)
|
||||
"""
|
||||
self.state = state
|
||||
|
||||
|
||||
def get(self, selector: list[str] | str, default: Any = None) -> Any:
|
||||
"""获取变量值
|
||||
|
||||
@@ -114,13 +115,15 @@ class VariablePool:
|
||||
"""
|
||||
# 转换为 VariableSelector
|
||||
if isinstance(selector, str):
|
||||
selector = VariableSelector.from_string(selector).path
|
||||
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
variable_literal = re.sub(pattern, r"\1", selector).strip()
|
||||
selector = VariableSelector.from_string(variable_literal).path
|
||||
|
||||
if not selector or len(selector) < 1:
|
||||
raise ValueError("变量选择器不能为空")
|
||||
|
||||
|
||||
namespace = selector[0]
|
||||
|
||||
|
||||
try:
|
||||
# 系统变量
|
||||
if namespace == "sys":
|
||||
@@ -128,30 +131,30 @@ class VariablePool:
|
||||
if not key:
|
||||
return self.state.get("variables", {}).get("sys", {})
|
||||
return self.state.get("variables", {}).get("sys", {}).get(key, default)
|
||||
|
||||
|
||||
# 会话变量
|
||||
elif namespace == "conv":
|
||||
key = selector[1] if len(selector) > 1 else None
|
||||
if not key:
|
||||
return self.state.get("variables", {}).get("conv", {})
|
||||
return self.state.get("variables", {}).get("conv", {}).get(key, default)
|
||||
|
||||
|
||||
# 节点输出(从 runtime_vars 读取)
|
||||
else:
|
||||
node_id = namespace
|
||||
runtime_vars = self.state.get("runtime_vars", {})
|
||||
|
||||
|
||||
if node_id not in runtime_vars:
|
||||
if default is not None:
|
||||
return default
|
||||
raise KeyError(f"节点 '{node_id}' 的输出不存在")
|
||||
|
||||
|
||||
node_var = runtime_vars[node_id]
|
||||
|
||||
|
||||
# 如果只有节点 ID,返回整个变量
|
||||
if len(selector) == 1:
|
||||
return node_var
|
||||
|
||||
|
||||
# 获取特定字段
|
||||
# 支持嵌套访问,如 node_id.field.subfield
|
||||
result = node_var
|
||||
@@ -166,14 +169,14 @@ class VariablePool:
|
||||
if default is not None:
|
||||
return default
|
||||
raise KeyError(f"无法访问 '{'.'.join(selector)}'")
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
except KeyError:
|
||||
if default is not None:
|
||||
return default
|
||||
raise
|
||||
|
||||
|
||||
def set(self, selector: list[str] | str, value: Any):
|
||||
"""设置变量值
|
||||
|
||||
@@ -192,17 +195,17 @@ class VariablePool:
|
||||
# 转换为 VariableSelector
|
||||
if isinstance(selector, str):
|
||||
selector = VariableSelector.from_string(selector).path
|
||||
|
||||
|
||||
if not selector or len(selector) < 2:
|
||||
raise ValueError("变量选择器必须包含命名空间和键名")
|
||||
|
||||
|
||||
namespace = selector[0]
|
||||
|
||||
|
||||
if namespace != "conv" and namespace not in self.state["cycle_nodes"]:
|
||||
raise ValueError("Only conversation or cycle variables can be assigned.")
|
||||
|
||||
|
||||
key = selector[1]
|
||||
|
||||
|
||||
# 确保 variables 结构存在
|
||||
if "variables" not in self.state:
|
||||
self.state["variables"] = {"sys": {}, "conv": {}}
|
||||
@@ -214,9 +217,9 @@ class VariablePool:
|
||||
self.state["variables"]["conv"][key] = value
|
||||
elif namespace in self.state["cycle_nodes"]:
|
||||
self.state["runtime_vars"][namespace][key] = value
|
||||
|
||||
|
||||
logger.debug(f"设置变量: {'.'.join(selector)} = {value}")
|
||||
|
||||
|
||||
def has(self, selector: list[str] | str) -> bool:
|
||||
"""检查变量是否存在
|
||||
|
||||
@@ -237,7 +240,7 @@ class VariablePool:
|
||||
return True
|
||||
except KeyError:
|
||||
return False
|
||||
|
||||
|
||||
def get_all_system_vars(self) -> dict[str, Any]:
|
||||
"""获取所有系统变量
|
||||
|
||||
@@ -245,7 +248,7 @@ class VariablePool:
|
||||
系统变量字典
|
||||
"""
|
||||
return self.state.get("variables", {}).get("sys", {})
|
||||
|
||||
|
||||
def get_all_conversation_vars(self) -> dict[str, Any]:
|
||||
"""获取所有会话变量
|
||||
|
||||
@@ -253,7 +256,7 @@ class VariablePool:
|
||||
会话变量字典
|
||||
"""
|
||||
return self.state.get("variables", {}).get("conv", {})
|
||||
|
||||
|
||||
def get_all_node_outputs(self) -> dict[str, Any]:
|
||||
"""获取所有节点输出(运行时变量)
|
||||
|
||||
@@ -261,7 +264,7 @@ class VariablePool:
|
||||
节点输出字典,键为节点 ID
|
||||
"""
|
||||
return self.state.get("runtime_vars", {})
|
||||
|
||||
|
||||
def get_node_output(self, node_id: str) -> dict[str, Any] | None:
|
||||
"""获取指定节点的输出(运行时变量)
|
||||
|
||||
@@ -272,7 +275,7 @@ class VariablePool:
|
||||
节点输出或 None
|
||||
"""
|
||||
return self.state.get("runtime_vars", {}).get(node_id)
|
||||
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""导出为字典
|
||||
|
||||
@@ -284,12 +287,12 @@ class VariablePool:
|
||||
"conversation": self.get_all_conversation_vars(),
|
||||
"nodes": self.get_all_node_outputs() # 从 runtime_vars 读取
|
||||
}
|
||||
|
||||
|
||||
def __repr__(self) -> str:
|
||||
sys_vars = self.get_all_system_vars()
|
||||
conv_vars = self.get_all_conversation_vars()
|
||||
runtime_vars = self.get_all_node_outputs()
|
||||
|
||||
|
||||
return (
|
||||
f"VariablePool(\n"
|
||||
f" system_vars={len(sys_vars)},\n"
|
||||
|
||||
Reference in New Issue
Block a user