feat(workflow): support retrieving variables wrapped in {{}} from variable pool
This commit is contained in:
@@ -10,6 +10,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import Any, TYPE_CHECKING
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -28,7 +29,7 @@ class VariableSelector:
|
|||||||
>>> selector = VariableSelector(["node_A", "output"])
|
>>> selector = VariableSelector(["node_A", "output"])
|
||||||
>>> selector = VariableSelector.from_string("sys.message")
|
>>> selector = VariableSelector.from_string("sys.message")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, path: list[str]):
|
def __init__(self, path: list[str]):
|
||||||
"""初始化变量选择器
|
"""初始化变量选择器
|
||||||
|
|
||||||
@@ -37,11 +38,11 @@ class VariableSelector:
|
|||||||
"""
|
"""
|
||||||
if not path or len(path) < 1:
|
if not path or len(path) < 1:
|
||||||
raise ValueError("变量路径不能为空")
|
raise ValueError("变量路径不能为空")
|
||||||
|
|
||||||
self.path = path
|
self.path = path
|
||||||
self.namespace = path[0] # sys, var, 或 node_id
|
self.namespace = path[0] # sys, var, 或 node_id
|
||||||
self.key = path[1] if len(path) > 1 else None
|
self.key = path[1] if len(path) > 1 else None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_string(cls, selector_str: str) -> "VariableSelector":
|
def from_string(cls, selector_str: str) -> "VariableSelector":
|
||||||
"""从字符串创建选择器
|
"""从字符串创建选择器
|
||||||
@@ -58,10 +59,10 @@ class VariableSelector:
|
|||||||
"""
|
"""
|
||||||
path = selector_str.split(".")
|
path = selector_str.split(".")
|
||||||
return cls(path)
|
return cls(path)
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return ".".join(self.path)
|
return ".".join(self.path)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"VariableSelector({self.path})"
|
return f"VariableSelector({self.path})"
|
||||||
|
|
||||||
@@ -84,7 +85,7 @@ class VariablePool:
|
|||||||
"AI 的回答"
|
"AI 的回答"
|
||||||
>>> pool.set(["conv", "user_name"], "张三")
|
>>> pool.set(["conv", "user_name"], "张三")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, state: "WorkflowState"):
|
def __init__(self, state: "WorkflowState"):
|
||||||
"""初始化变量池
|
"""初始化变量池
|
||||||
|
|
||||||
@@ -92,7 +93,7 @@ class VariablePool:
|
|||||||
state: 工作流状态(LangGraph State)
|
state: 工作流状态(LangGraph State)
|
||||||
"""
|
"""
|
||||||
self.state = state
|
self.state = state
|
||||||
|
|
||||||
def get(self, selector: list[str] | str, default: Any = None) -> Any:
|
def get(self, selector: list[str] | str, default: Any = None) -> Any:
|
||||||
"""获取变量值
|
"""获取变量值
|
||||||
|
|
||||||
@@ -114,13 +115,15 @@ class VariablePool:
|
|||||||
"""
|
"""
|
||||||
# 转换为 VariableSelector
|
# 转换为 VariableSelector
|
||||||
if isinstance(selector, str):
|
if isinstance(selector, str):
|
||||||
selector = VariableSelector.from_string(selector).path
|
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||||
|
variable_literal = re.sub(pattern, r"\1", selector).strip()
|
||||||
|
selector = VariableSelector.from_string(variable_literal).path
|
||||||
|
|
||||||
if not selector or len(selector) < 1:
|
if not selector or len(selector) < 1:
|
||||||
raise ValueError("变量选择器不能为空")
|
raise ValueError("变量选择器不能为空")
|
||||||
|
|
||||||
namespace = selector[0]
|
namespace = selector[0]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 系统变量
|
# 系统变量
|
||||||
if namespace == "sys":
|
if namespace == "sys":
|
||||||
@@ -128,30 +131,30 @@ class VariablePool:
|
|||||||
if not key:
|
if not key:
|
||||||
return self.state.get("variables", {}).get("sys", {})
|
return self.state.get("variables", {}).get("sys", {})
|
||||||
return self.state.get("variables", {}).get("sys", {}).get(key, default)
|
return self.state.get("variables", {}).get("sys", {}).get(key, default)
|
||||||
|
|
||||||
# 会话变量
|
# 会话变量
|
||||||
elif namespace == "conv":
|
elif namespace == "conv":
|
||||||
key = selector[1] if len(selector) > 1 else None
|
key = selector[1] if len(selector) > 1 else None
|
||||||
if not key:
|
if not key:
|
||||||
return self.state.get("variables", {}).get("conv", {})
|
return self.state.get("variables", {}).get("conv", {})
|
||||||
return self.state.get("variables", {}).get("conv", {}).get(key, default)
|
return self.state.get("variables", {}).get("conv", {}).get(key, default)
|
||||||
|
|
||||||
# 节点输出(从 runtime_vars 读取)
|
# 节点输出(从 runtime_vars 读取)
|
||||||
else:
|
else:
|
||||||
node_id = namespace
|
node_id = namespace
|
||||||
runtime_vars = self.state.get("runtime_vars", {})
|
runtime_vars = self.state.get("runtime_vars", {})
|
||||||
|
|
||||||
if node_id not in runtime_vars:
|
if node_id not in runtime_vars:
|
||||||
if default is not None:
|
if default is not None:
|
||||||
return default
|
return default
|
||||||
raise KeyError(f"节点 '{node_id}' 的输出不存在")
|
raise KeyError(f"节点 '{node_id}' 的输出不存在")
|
||||||
|
|
||||||
node_var = runtime_vars[node_id]
|
node_var = runtime_vars[node_id]
|
||||||
|
|
||||||
# 如果只有节点 ID,返回整个变量
|
# 如果只有节点 ID,返回整个变量
|
||||||
if len(selector) == 1:
|
if len(selector) == 1:
|
||||||
return node_var
|
return node_var
|
||||||
|
|
||||||
# 获取特定字段
|
# 获取特定字段
|
||||||
# 支持嵌套访问,如 node_id.field.subfield
|
# 支持嵌套访问,如 node_id.field.subfield
|
||||||
result = node_var
|
result = node_var
|
||||||
@@ -166,14 +169,14 @@ class VariablePool:
|
|||||||
if default is not None:
|
if default is not None:
|
||||||
return default
|
return default
|
||||||
raise KeyError(f"无法访问 '{'.'.join(selector)}'")
|
raise KeyError(f"无法访问 '{'.'.join(selector)}'")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
if default is not None:
|
if default is not None:
|
||||||
return default
|
return default
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def set(self, selector: list[str] | str, value: Any):
|
def set(self, selector: list[str] | str, value: Any):
|
||||||
"""设置变量值
|
"""设置变量值
|
||||||
|
|
||||||
@@ -192,17 +195,17 @@ class VariablePool:
|
|||||||
# 转换为 VariableSelector
|
# 转换为 VariableSelector
|
||||||
if isinstance(selector, str):
|
if isinstance(selector, str):
|
||||||
selector = VariableSelector.from_string(selector).path
|
selector = VariableSelector.from_string(selector).path
|
||||||
|
|
||||||
if not selector or len(selector) < 2:
|
if not selector or len(selector) < 2:
|
||||||
raise ValueError("变量选择器必须包含命名空间和键名")
|
raise ValueError("变量选择器必须包含命名空间和键名")
|
||||||
|
|
||||||
namespace = selector[0]
|
namespace = selector[0]
|
||||||
|
|
||||||
if namespace != "conv" and namespace not in self.state["cycle_nodes"]:
|
if namespace != "conv" and namespace not in self.state["cycle_nodes"]:
|
||||||
raise ValueError("Only conversation or cycle variables can be assigned.")
|
raise ValueError("Only conversation or cycle variables can be assigned.")
|
||||||
|
|
||||||
key = selector[1]
|
key = selector[1]
|
||||||
|
|
||||||
# 确保 variables 结构存在
|
# 确保 variables 结构存在
|
||||||
if "variables" not in self.state:
|
if "variables" not in self.state:
|
||||||
self.state["variables"] = {"sys": {}, "conv": {}}
|
self.state["variables"] = {"sys": {}, "conv": {}}
|
||||||
@@ -214,9 +217,9 @@ class VariablePool:
|
|||||||
self.state["variables"]["conv"][key] = value
|
self.state["variables"]["conv"][key] = value
|
||||||
elif namespace in self.state["cycle_nodes"]:
|
elif namespace in self.state["cycle_nodes"]:
|
||||||
self.state["runtime_vars"][namespace][key] = value
|
self.state["runtime_vars"][namespace][key] = value
|
||||||
|
|
||||||
logger.debug(f"设置变量: {'.'.join(selector)} = {value}")
|
logger.debug(f"设置变量: {'.'.join(selector)} = {value}")
|
||||||
|
|
||||||
def has(self, selector: list[str] | str) -> bool:
|
def has(self, selector: list[str] | str) -> bool:
|
||||||
"""检查变量是否存在
|
"""检查变量是否存在
|
||||||
|
|
||||||
@@ -237,7 +240,7 @@ class VariablePool:
|
|||||||
return True
|
return True
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_all_system_vars(self) -> dict[str, Any]:
|
def get_all_system_vars(self) -> dict[str, Any]:
|
||||||
"""获取所有系统变量
|
"""获取所有系统变量
|
||||||
|
|
||||||
@@ -245,7 +248,7 @@ class VariablePool:
|
|||||||
系统变量字典
|
系统变量字典
|
||||||
"""
|
"""
|
||||||
return self.state.get("variables", {}).get("sys", {})
|
return self.state.get("variables", {}).get("sys", {})
|
||||||
|
|
||||||
def get_all_conversation_vars(self) -> dict[str, Any]:
|
def get_all_conversation_vars(self) -> dict[str, Any]:
|
||||||
"""获取所有会话变量
|
"""获取所有会话变量
|
||||||
|
|
||||||
@@ -253,7 +256,7 @@ class VariablePool:
|
|||||||
会话变量字典
|
会话变量字典
|
||||||
"""
|
"""
|
||||||
return self.state.get("variables", {}).get("conv", {})
|
return self.state.get("variables", {}).get("conv", {})
|
||||||
|
|
||||||
def get_all_node_outputs(self) -> dict[str, Any]:
|
def get_all_node_outputs(self) -> dict[str, Any]:
|
||||||
"""获取所有节点输出(运行时变量)
|
"""获取所有节点输出(运行时变量)
|
||||||
|
|
||||||
@@ -261,7 +264,7 @@ class VariablePool:
|
|||||||
节点输出字典,键为节点 ID
|
节点输出字典,键为节点 ID
|
||||||
"""
|
"""
|
||||||
return self.state.get("runtime_vars", {})
|
return self.state.get("runtime_vars", {})
|
||||||
|
|
||||||
def get_node_output(self, node_id: str) -> dict[str, Any] | None:
|
def get_node_output(self, node_id: str) -> dict[str, Any] | None:
|
||||||
"""获取指定节点的输出(运行时变量)
|
"""获取指定节点的输出(运行时变量)
|
||||||
|
|
||||||
@@ -272,7 +275,7 @@ class VariablePool:
|
|||||||
节点输出或 None
|
节点输出或 None
|
||||||
"""
|
"""
|
||||||
return self.state.get("runtime_vars", {}).get(node_id)
|
return self.state.get("runtime_vars", {}).get(node_id)
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""导出为字典
|
"""导出为字典
|
||||||
|
|
||||||
@@ -284,12 +287,12 @@ class VariablePool:
|
|||||||
"conversation": self.get_all_conversation_vars(),
|
"conversation": self.get_all_conversation_vars(),
|
||||||
"nodes": self.get_all_node_outputs() # 从 runtime_vars 读取
|
"nodes": self.get_all_node_outputs() # 从 runtime_vars 读取
|
||||||
}
|
}
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
sys_vars = self.get_all_system_vars()
|
sys_vars = self.get_all_system_vars()
|
||||||
conv_vars = self.get_all_conversation_vars()
|
conv_vars = self.get_all_conversation_vars()
|
||||||
runtime_vars = self.get_all_node_outputs()
|
runtime_vars = self.get_all_node_outputs()
|
||||||
|
|
||||||
return (
|
return (
|
||||||
f"VariablePool(\n"
|
f"VariablePool(\n"
|
||||||
f" system_vars={len(sys_vars)},\n"
|
f" system_vars={len(sys_vars)},\n"
|
||||||
|
|||||||
Reference in New Issue
Block a user