feat(workflow): enforce strong typing for runtime variables
- Reduce exposed information in release workflows
This commit is contained in:
@@ -587,7 +587,8 @@ async def chat(
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=release.app_id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=release.id
|
||||
release_id=release.id,
|
||||
public=True
|
||||
):
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,3 @@
|
||||
"""
|
||||
安全的表达式求值器
|
||||
|
||||
使用 simpleeval 库提供安全的表达式评估,避免代码注入攻击。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
@@ -14,160 +8,119 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExpressionEvaluator:
|
||||
"""安全的表达式求值器"""
|
||||
"""Safe expression evaluator for workflow variables and node outputs."""
|
||||
|
||||
# 保留的命名空间
|
||||
# Reserved namespaces
|
||||
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
|
||||
|
||||
@staticmethod
|
||||
def evaluate(
|
||||
expression: str,
|
||||
variables: dict[str, Any],
|
||||
conv_vars: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""安全地评估表达式
|
||||
|
||||
Args:
|
||||
expression: 表达式字符串,如 "{{var.score}} > 0.8"
|
||||
variables: 用户定义的变量
|
||||
node_outputs: 节点输出结果
|
||||
system_vars: 系统变量
|
||||
|
||||
Returns:
|
||||
表达式求值结果
|
||||
|
||||
Raises:
|
||||
ValueError: 表达式无效或求值失败
|
||||
|
||||
Examples:
|
||||
>>> evaluator = ExpressionEvaluator()
|
||||
>>> evaluator.evaluate(
|
||||
... "var.score > 0.8",
|
||||
... {"score": 0.9},
|
||||
... {},
|
||||
... {}
|
||||
... )
|
||||
True
|
||||
|
||||
>>> evaluator.evaluate(
|
||||
... "node.intent.output == '售前咨询'",
|
||||
... {},
|
||||
... {"intent": {"output": "售前咨询"}},
|
||||
... {}
|
||||
... )
|
||||
True
|
||||
"""
|
||||
# 移除 Jinja2 模板语法的花括号(如果存在)
|
||||
Safely evaluate an expression using workflow variables.
|
||||
|
||||
Args:
|
||||
expression (str): The expression string, e.g., "var.score > 0.8"
|
||||
conv_vars (dict): Conversation-level variables
|
||||
node_outputs (dict): Outputs from workflow nodes
|
||||
system_vars (dict, optional): System variables
|
||||
|
||||
Returns:
|
||||
Any: Result of the evaluated expression
|
||||
|
||||
Raises:
|
||||
ValueError: If the expression is invalid or evaluation fails
|
||||
"""
|
||||
# Remove Jinja2-style brackets if present
|
||||
expression = expression.strip()
|
||||
# "{{system.message}} == {{ user.messge }}" -> "system.message == user.message"
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
expression = re.sub(pattern, r"\1", expression).strip()
|
||||
|
||||
# 构建命名空间上下文
|
||||
# Build context for evaluation
|
||||
context = {
|
||||
"var": variables, # 用户变量
|
||||
"node": node_outputs, # 节点输出
|
||||
"sys": system_vars or {}, # 系统变量
|
||||
"conv": conv_vars, # conversation variables
|
||||
"node": node_outputs, # node outputs
|
||||
"sys": system_vars or {}, # system variables
|
||||
}
|
||||
|
||||
# 为了向后兼容,也支持直接访问(但会在日志中警告)
|
||||
context.update(variables)
|
||||
|
||||
context.update(conv_vars)
|
||||
context["nodes"] = node_outputs
|
||||
context.update(node_outputs)
|
||||
|
||||
try:
|
||||
# simpleeval 只支持安全的操作:
|
||||
# - 算术运算: +, -, *, /, //, %, **
|
||||
# - 比较运算: ==, !=, <, <=, >, >=
|
||||
# - 逻辑运算: and, or, not
|
||||
# - 成员运算: in, not in
|
||||
# - 属性访问: obj.attr
|
||||
# - 字典/列表访问: obj["key"], obj[0]
|
||||
# 不支持:函数调用、导入、赋值等危险操作
|
||||
# simpleeval supports safe operations:
|
||||
# arithmetic, comparisons, logical ops, attribute/dict/list access
|
||||
result = simple_eval(expression, names=context)
|
||||
return result
|
||||
|
||||
except NameNotDefined as e:
|
||||
logger.error(f"表达式中引用了未定义的变量: {expression}, 错误: {e}")
|
||||
raise ValueError(f"未定义的变量: {e}")
|
||||
logger.error(f"Undefined variable in expression: {expression}, error: {e}")
|
||||
raise ValueError(f"Undefined variable: {e}")
|
||||
|
||||
except InvalidExpression as e:
|
||||
logger.error(f"表达式语法无效: {expression}, 错误: {e}")
|
||||
raise ValueError(f"表达式语法无效: {e}")
|
||||
logger.error(f"Invalid expression syntax: {expression}, error: {e}")
|
||||
raise ValueError(f"Invalid expression syntax: {e}")
|
||||
|
||||
except SyntaxError as e:
|
||||
logger.error(f"表达式语法错误: {expression}, 错误: {e}")
|
||||
raise ValueError(f"表达式语法错误: {e}")
|
||||
logger.error(f"Syntax error in expression: {expression}, error: {e}")
|
||||
raise ValueError(f"Syntax error: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"表达式求值异常: {expression}, 错误: {e}")
|
||||
raise ValueError(f"表达式求值失败: {e}")
|
||||
logger.error(f"Expression evaluation failed: {expression}, error: {e}")
|
||||
raise ValueError(f"Expression evaluation failed: {e}")
|
||||
|
||||
@staticmethod
|
||||
def evaluate_bool(
|
||||
expression: str,
|
||||
variables: dict[str, Any],
|
||||
conv_var: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> bool:
|
||||
"""评估布尔表达式(用于条件判断)
|
||||
|
||||
"""
|
||||
Evaluate a boolean expression (for conditions).
|
||||
|
||||
Args:
|
||||
expression: 布尔表达式
|
||||
variables: 用户变量
|
||||
node_outputs: 节点输出
|
||||
system_vars: 系统变量
|
||||
|
||||
expression (str): Boolean expression
|
||||
conv_var (dict): Conversation variables
|
||||
node_outputs (dict): Node outputs
|
||||
system_vars (dict, optional): System variables
|
||||
|
||||
Returns:
|
||||
布尔值结果
|
||||
|
||||
Examples:
|
||||
>>> ExpressionEvaluator.evaluate_bool(
|
||||
... "var.count >= 10 and var.status == 'active'",
|
||||
... {"count": 15, "status": "active"},
|
||||
... {},
|
||||
... {}
|
||||
... )
|
||||
True
|
||||
bool: Boolean result
|
||||
"""
|
||||
result = ExpressionEvaluator.evaluate(
|
||||
expression, variables, node_outputs, system_vars
|
||||
expression, conv_var, node_outputs, system_vars
|
||||
)
|
||||
return bool(result)
|
||||
|
||||
@staticmethod
|
||||
def validate_variable_names(variables: list[dict]) -> list[str]:
|
||||
"""验证变量名是否合法
|
||||
|
||||
"""
|
||||
Validate variable names for legality.
|
||||
|
||||
Args:
|
||||
variables: 变量定义列表
|
||||
|
||||
variables (list[dict]): List of variable definitions
|
||||
|
||||
Returns:
|
||||
错误列表,如果为空则验证通过
|
||||
|
||||
Examples:
|
||||
>>> ExpressionEvaluator.validate_variable_names([
|
||||
... {"name": "user_input"},
|
||||
... {"name": "var"} # 保留字
|
||||
... ])
|
||||
["变量名 'var' 是保留的命名空间,请使用其他名称"]
|
||||
list[str]: List of error messages. Empty if all names are valid.
|
||||
"""
|
||||
errors = []
|
||||
|
||||
for var in variables:
|
||||
var_name = var.get("name", "")
|
||||
|
||||
# 检查是否为保留命名空间
|
||||
|
||||
if var_name in ExpressionEvaluator.RESERVED_NAMESPACES:
|
||||
errors.append(
|
||||
f"变量名 '{var_name}' 是保留的命名空间,请使用其他名称"
|
||||
f"Variable name '{var_name}' is a reserved namespace, please use another name"
|
||||
)
|
||||
|
||||
# 检查是否为有效的 Python 标识符
|
||||
|
||||
if not var_name.isidentifier():
|
||||
errors.append(
|
||||
f"变量名 '{var_name}' 不是有效的标识符"
|
||||
f"Variable name '{var_name}' is not a valid Python identifier"
|
||||
)
|
||||
|
||||
return errors
|
||||
@@ -176,23 +129,23 @@ class ExpressionEvaluator:
|
||||
# 便捷函数
|
||||
def evaluate_expression(
|
||||
expression: str,
|
||||
variables: dict[str, Any],
|
||||
conv_var: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
system_vars: dict[str, Any]
|
||||
) -> Any:
|
||||
"""评估表达式(便捷函数)"""
|
||||
"""Evaluate an expression (convenience function)."""
|
||||
return ExpressionEvaluator.evaluate(
|
||||
expression, variables, node_outputs, system_vars
|
||||
expression, conv_var, node_outputs, system_vars
|
||||
)
|
||||
|
||||
|
||||
def evaluate_condition(
|
||||
expression: str,
|
||||
variables: dict[str, Any],
|
||||
conv_var: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> bool:
|
||||
"""评估条件表达式(便捷函数)"""
|
||||
"""Evaluate a boolean condition expression (convenience function)."""
|
||||
return ExpressionEvaluator.evaluate_bool(
|
||||
expression, variables, node_outputs, system_vars
|
||||
expression, conv_var, node_outputs, system_vars
|
||||
)
|
||||
|
||||
@@ -14,9 +14,14 @@ from pydantic import BaseModel, Field
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SCOPE_PATTERN = re.compile(
|
||||
r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\.[a-zA-Z0-9_]+\s*}}"
|
||||
)
|
||||
|
||||
|
||||
class OutputContent(BaseModel):
|
||||
"""
|
||||
@@ -53,6 +58,12 @@ class OutputContent(BaseModel):
|
||||
)
|
||||
)
|
||||
|
||||
_SCOPE: str | None = None
|
||||
|
||||
def get_scope(self) -> str:
|
||||
self._SCOPE = SCOPE_PATTERN.findall(self.literal)[0]
|
||||
return self._SCOPE
|
||||
|
||||
def depends_on_scope(self, scope: str) -> bool:
|
||||
"""
|
||||
Check if this segment depends on a given scope.
|
||||
@@ -63,8 +74,9 @@ class OutputContent(BaseModel):
|
||||
Returns:
|
||||
bool: True if this segment references the given scope.
|
||||
"""
|
||||
pattern = rf"\{{\{{\s*{re.escape(scope)}\.[a-zA-Z0-9_]+\s*\}}\}}"
|
||||
return bool(re.search(pattern, self.literal))
|
||||
if self._SCOPE:
|
||||
return self._SCOPE == scope
|
||||
return self.get_scope() == scope
|
||||
|
||||
|
||||
class StreamOutputConfig(BaseModel):
|
||||
@@ -167,6 +179,7 @@ class GraphBuilder:
|
||||
workflow_config: dict[str, Any],
|
||||
stream: bool = False,
|
||||
subgraph: bool = False,
|
||||
variable_pool: VariablePool | None = None
|
||||
):
|
||||
self.workflow_config = workflow_config
|
||||
|
||||
@@ -180,6 +193,10 @@ class GraphBuilder:
|
||||
self._find_upstream_branch_node = lru_cache(
|
||||
maxsize=len(self.nodes) * 2
|
||||
)(self._find_upstream_branch_node)
|
||||
if variable_pool:
|
||||
self.variable_pool = variable_pool
|
||||
else:
|
||||
self.variable_pool = VariablePool()
|
||||
|
||||
self.graph = StateGraph(WorkflowState)
|
||||
self.add_nodes()
|
||||
@@ -452,9 +469,9 @@ class GraphBuilder:
|
||||
if self.stream:
|
||||
# Stream mode: create an async generator function
|
||||
# LangGraph collects all yielded values; the last yielded dictionary is merged into the state
|
||||
def make_stream_func(inst):
|
||||
def make_stream_func(inst, variable_pool=self.variable_pool):
|
||||
async def node_func(state: WorkflowState):
|
||||
async for item in inst.run_stream(state):
|
||||
async for item in inst.run_stream(state, variable_pool):
|
||||
yield item
|
||||
|
||||
return node_func
|
||||
@@ -462,9 +479,9 @@ class GraphBuilder:
|
||||
self.graph.add_node(node_id, make_stream_func(node_instance))
|
||||
else:
|
||||
# Non-stream mode: create an async function
|
||||
def make_func(inst):
|
||||
def make_func(inst, variable_pool=self.variable_pool):
|
||||
async def node_func(state: WorkflowState):
|
||||
return await inst.run(state)
|
||||
return await inst.run(state, variable_pool)
|
||||
|
||||
return node_func
|
||||
|
||||
@@ -567,27 +584,28 @@ class GraphBuilder:
|
||||
for target in branch_info["target"]:
|
||||
waiting_edges[target].append(branch_info["node"]["name"])
|
||||
|
||||
def router_fn(state: WorkflowState) -> list[Send]:
|
||||
def router_fn(state: WorkflowState, variable_pool: VariablePool = self.variable_pool) -> list[Send]:
|
||||
branch_activate = []
|
||||
new_state = state.copy()
|
||||
new_state["activate"] = dict(state.get("activate", {})) # deep copy of activate
|
||||
|
||||
node_output = variable_pool.get_node_output(src, defalut=dict(), strict=False)
|
||||
for label, branch in unique_branch.items():
|
||||
if evaluate_condition(
|
||||
if node_output and evaluate_condition(
|
||||
branch["condition"],
|
||||
state.get("variables", {}),
|
||||
state.get("runtime_vars", {}),
|
||||
{
|
||||
"execution_id": state.get("execution_id"),
|
||||
"workspace_id": state.get("workspace_id"),
|
||||
"user_id": state.get("user_id")
|
||||
}
|
||||
{},
|
||||
{src: node_output},
|
||||
{}
|
||||
):
|
||||
logger.debug(f"Conditional routing {src}: selected branch {label}")
|
||||
new_state["activate"][branch["node"]["name"]] = True
|
||||
branch_activate.append(
|
||||
Send(
|
||||
branch['node']['name'],
|
||||
new_state
|
||||
)
|
||||
)
|
||||
continue
|
||||
new_state["activate"][branch["node"]["name"]] = False
|
||||
for label, branch in unique_branch.items():
|
||||
branch_activate.append(
|
||||
Send(
|
||||
branch['node']['name'],
|
||||
|
||||
@@ -15,7 +15,6 @@ from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from app.core.workflow.nodes.tool import ToolNode
|
||||
@@ -25,7 +24,6 @@ __all__ = [
|
||||
"WorkflowState",
|
||||
"LLMNode",
|
||||
"AgentNode",
|
||||
"TransformNode",
|
||||
"IfElseNode",
|
||||
"StartNode",
|
||||
"EndNode",
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
|
||||
class AgentNodeConfig(BaseNodeConfig):
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
Agent 节点实现
|
||||
|
||||
调用已发布的 Agent 应用。
|
||||
# TODO
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -9,6 +10,8 @@ from typing import Any
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
from app.models import AppRelease
|
||||
from app.db import get_db
|
||||
@@ -30,19 +33,22 @@ class AgentNode(BaseNode):
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
def _prepare_agent(self, state: WorkflowState) -> tuple[DraftRunService, AppRelease, str]:
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {"output": VariableType.STRING}
|
||||
|
||||
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[DraftRunService, AppRelease, str]:
|
||||
"""准备 Agent(公共逻辑)
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
variable_pool: 变量池
|
||||
|
||||
Returns:
|
||||
(draft_service, release, message): 服务实例、发布配置、消息
|
||||
"""
|
||||
# 1. 渲染消息
|
||||
message_template = self.config.get("message", "")
|
||||
message = self._render_template(message_template, state)
|
||||
message = self._render_template(message_template, variable_pool)
|
||||
|
||||
# 2. 获取 Agent 配置
|
||||
agent_id = self.config.get("agent_id")
|
||||
@@ -61,16 +67,17 @@ class AgentNode(BaseNode):
|
||||
|
||||
return draft_service, release, message
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""非流式执行
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
variable_pool: 变量池
|
||||
|
||||
Returns:
|
||||
状态更新字典
|
||||
"""
|
||||
draft_service, release, message = self._prepare_agent(state)
|
||||
draft_service, release, message = self._prepare_agent(variable_pool)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
|
||||
|
||||
@@ -79,9 +86,9 @@ class AgentNode(BaseNode):
|
||||
agent_config=release.config,
|
||||
model_config=None,
|
||||
message=message,
|
||||
workspace_id=state.get("workspace_id"),
|
||||
workspace_id=variable_pool.get_value("sys.workspace_id"),
|
||||
user_id=state.get("user_id"),
|
||||
variables=state.get("variables", {})
|
||||
variables=variable_pool.get_all_conversation_vars()
|
||||
)
|
||||
|
||||
response = result.get("response", "")
|
||||
@@ -99,16 +106,17 @@ class AgentNode(BaseNode):
|
||||
}
|
||||
}
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
|
||||
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
||||
"""流式执行
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
variable_pool: 变量池
|
||||
|
||||
Yields:
|
||||
流式事件字典
|
||||
"""
|
||||
draft_service, release, message = self._prepare_agent(state)
|
||||
draft_service, release, message = self._prepare_agent(variable_pool)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
|
||||
|
||||
@@ -120,9 +128,9 @@ class AgentNode(BaseNode):
|
||||
agent_config=release.config,
|
||||
model_config=None,
|
||||
message=message,
|
||||
workspace_id=state.get("workspace_id"),
|
||||
workspace_id=variable_pool.get_value("sys.workspace_id"),
|
||||
user_id=state.get("user_id"),
|
||||
variables=state.get("variables", {})
|
||||
variables=variable_pool.get_all_conversation_vars()
|
||||
):
|
||||
# 提取内容
|
||||
content = chunk.get("content", "")
|
||||
|
||||
@@ -6,6 +6,7 @@ from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.enums import AssignmentOperator
|
||||
from app.core.workflow.nodes.operators import AssignmentOperatorInstance, AssignmentOperatorResolver
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -17,13 +18,17 @@ class AssignerNode(BaseNode):
|
||||
self.variable_updater = True
|
||||
self.typed_config: AssignerNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {}
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""
|
||||
Execute the assignment operation defined by this node.
|
||||
|
||||
Args:
|
||||
state: The current workflow state, including conversation variables,
|
||||
node outputs, and system variables.
|
||||
variable_pool: variable pool
|
||||
|
||||
Returns:
|
||||
None or the result of the assignment operation.
|
||||
@@ -31,60 +36,57 @@ class AssignerNode(BaseNode):
|
||||
# Initialize a variable pool for accessing conversation, node, and system variables
|
||||
self.typed_config = AssignerNodeConfig(**self.config)
|
||||
logger.info(f"节点 {self.node_id} 开始执行")
|
||||
pool = VariablePool(state)
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
|
||||
for assignment in self.typed_config.assignments:
|
||||
# Get the target variable selector (e.g., "conv.test")
|
||||
variable_selector = assignment.variable_selector
|
||||
if isinstance(variable_selector, str):
|
||||
# Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"]
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
expression = re.sub(pattern, r"\1", variable_selector).strip()
|
||||
variable_selector = expression.split('.')
|
||||
namespace = re.sub(pattern, r"\1", variable_selector).split('.')[0]
|
||||
|
||||
# Only conversation variables ('conv') are allowed
|
||||
if variable_selector[0] != 'conv' and variable_selector[0] not in state["cycle_nodes"]:
|
||||
raise ValueError("Only conversation or cycle variables can be assigned.")
|
||||
if namespace != 'conv' and namespace not in state["cycle_nodes"]:
|
||||
raise ValueError(f"Only conversation or cycle variables can be assigned. - {variable_selector}")
|
||||
|
||||
# Get the value or expression to assign
|
||||
value = assignment.value
|
||||
logger.debug(f"left:{variable_selector}, right: {value}")
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
|
||||
if isinstance(value, str):
|
||||
expression = re.match(pattern, value)
|
||||
if expression:
|
||||
expression = expression.group(1)
|
||||
expression = re.sub(pattern, r"\1", expression).strip()
|
||||
value = self.get_variable(expression, state)
|
||||
value = self.get_variable(expression, variable_pool, default=value, strict=False)
|
||||
|
||||
# Select the appropriate assignment operator instance based on the target variable type
|
||||
operator: AssignmentOperatorInstance = AssignmentOperatorResolver.resolve_by_value(
|
||||
pool.get(variable_selector)
|
||||
variable_pool.get_value(variable_selector)
|
||||
)(
|
||||
pool, variable_selector, value
|
||||
variable_pool, variable_selector, value
|
||||
)
|
||||
|
||||
# Execute the configured assignment operation
|
||||
match assignment.operation:
|
||||
case AssignmentOperator.COVER:
|
||||
operator.assign()
|
||||
await operator.assign()
|
||||
case AssignmentOperator.ASSIGN:
|
||||
operator.assign()
|
||||
await operator.assign()
|
||||
case AssignmentOperator.CLEAR:
|
||||
operator.clear()
|
||||
await operator.clear()
|
||||
case AssignmentOperator.ADD:
|
||||
operator.add()
|
||||
await operator.add()
|
||||
case AssignmentOperator.SUBTRACT:
|
||||
operator.subtract()
|
||||
await operator.subtract()
|
||||
case AssignmentOperator.MULTIPLY:
|
||||
operator.multiply()
|
||||
await operator.multiply()
|
||||
case AssignmentOperator.DIVIDE:
|
||||
operator.divide()
|
||||
await operator.divide()
|
||||
case AssignmentOperator.APPEND:
|
||||
operator.append()
|
||||
await operator.append()
|
||||
case AssignmentOperator.REMOVE_FIRST:
|
||||
operator.remove_first()
|
||||
await operator.remove_first()
|
||||
case AssignmentOperator.REMOVE_LAST:
|
||||
operator.remove_last()
|
||||
await operator.remove_last()
|
||||
case _:
|
||||
raise ValueError(f"Invalid Operator: {assignment.operation}")
|
||||
logger.info(f"Node {self.node_id}: execution completed")
|
||||
|
||||
@@ -3,79 +3,13 @@
|
||||
定义所有节点配置的通用字段和数据结构。
|
||||
"""
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
VARIABLE_PATTERN = r"\{\{\s*(.*?)\s*\}\}"
|
||||
|
||||
|
||||
class VariableType(StrEnum):
|
||||
"""变量类型枚举"""
|
||||
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
BOOLEAN = "boolean"
|
||||
OBJECT = "object"
|
||||
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_BOOLEAN = "array[boolean]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
|
||||
|
||||
class TypedVariable(BaseModel):
|
||||
"""
|
||||
TODO: 强类型限制
|
||||
Strongly typed variable that validates value on assignment.
|
||||
"""
|
||||
|
||||
value: Any = Field(..., description="Variable value")
|
||||
type: VariableType = Field(..., description="Declared type of the variable")
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True
|
||||
)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name == "value":
|
||||
self._validate_value(value)
|
||||
if name == "type":
|
||||
raise RuntimeError("Cannot modify variable type at runtime")
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def _validate_value(self, v: Any):
|
||||
t = self.type
|
||||
match t:
|
||||
case VariableType.STRING:
|
||||
if not isinstance(v, str):
|
||||
raise TypeError("Variable value does not match type STRING")
|
||||
case VariableType.BOOLEAN:
|
||||
if not isinstance(v, bool):
|
||||
raise TypeError("Variable value does not match type BOOLEAN")
|
||||
case VariableType.NUMBER:
|
||||
if not isinstance(v, (int, float)):
|
||||
raise TypeError("Variable value does not match type NUMBER")
|
||||
case VariableType.OBJECT:
|
||||
if not isinstance(v, dict):
|
||||
raise TypeError("Variable value does not match type OBJECT")
|
||||
case VariableType.ARRAY_STRING:
|
||||
if not isinstance(v, list) or not all(isinstance(i, str) for i in v):
|
||||
raise TypeError("Variable value does not match type ARRAY_STRING")
|
||||
case VariableType.ARRAY_NUMBER:
|
||||
if not isinstance(v, list) or not all(isinstance(i, (int, float)) for i in v):
|
||||
raise TypeError("Variable value does not match type ARRAY_NUMBER")
|
||||
case VariableType.ARRAY_BOOLEAN:
|
||||
if not isinstance(v, list) or not all(isinstance(i, bool) for i in v):
|
||||
raise TypeError("Variable value does not match type ARRAY_BOOLEAN")
|
||||
case VariableType.ARRAY_OBJECT:
|
||||
if not isinstance(v, list) or not all(isinstance(i, dict) for i in v):
|
||||
raise TypeError("Variable value does not match type ARRAY_OBJECT")
|
||||
case _:
|
||||
raise TypeError(f"Unknown variable type: {t}")
|
||||
|
||||
|
||||
class VariableDefinition(BaseModel):
|
||||
"""变量定义
|
||||
|
||||
|
||||
@@ -1,12 +1,7 @@
|
||||
"""
|
||||
工作流节点基类
|
||||
|
||||
定义节点的基本接口和通用功能。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from langgraph.config import get_stream_writer
|
||||
@@ -14,6 +9,7 @@ from typing_extensions import TypedDict, Annotated
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.workflow.nodes.enums import BRANCH_NODES
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -42,22 +38,10 @@ class WorkflowState(TypedDict):
|
||||
cycle_nodes: list
|
||||
looping: Annotated[int, merge_looping_state]
|
||||
|
||||
# Input variables (passed from configured variables)
|
||||
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)
|
||||
variables: Annotated[dict[str, Any], lambda x, y: {
|
||||
**x,
|
||||
**{k: {**x.get(k, {}), **v} if isinstance(v, dict) and isinstance(x.get(k), dict) else v
|
||||
for k, v in y.items()}
|
||||
}]
|
||||
|
||||
# Node outputs (stores execution results of each node for variable references)
|
||||
# Uses a custom merge function to combine new node outputs into the existing dictionary
|
||||
node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
# Runtime node variables (simplified version, stores business data for fast access between nodes)
|
||||
# Format: {node_id: business_result}
|
||||
runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
# Execution context
|
||||
execution_id: str
|
||||
workspace_id: str
|
||||
@@ -72,17 +56,17 @@ class WorkflowState(TypedDict):
|
||||
|
||||
|
||||
class BaseNode(ABC):
|
||||
"""节点基类
|
||||
|
||||
所有节点类型都应该继承此基类,实现 execute 方法。
|
||||
"""Base class for workflow nodes.
|
||||
|
||||
All node types should inherit from this class and implement the `execute` method.
|
||||
"""
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
"""初始化节点
|
||||
|
||||
"""Initialize the node.
|
||||
|
||||
Args:
|
||||
node_config: 节点配置
|
||||
workflow_config: 工作流配置
|
||||
node_config: Configuration of the node.
|
||||
workflow_config: Configuration of the workflow.
|
||||
"""
|
||||
self.node_config = node_config
|
||||
self.workflow_config = workflow_config
|
||||
@@ -94,7 +78,27 @@ class BaseNode(ABC):
|
||||
self.config = node_config.get("config") or {}
|
||||
self.error_handling = node_config.get("error_handling") or {}
|
||||
|
||||
self.variable_updater = False
|
||||
self.variable_change_able = False
|
||||
|
||||
@cached_property
|
||||
def output_types(self) -> dict[str, VariableType]:
|
||||
"""Returns the output variable types of the node.
|
||||
|
||||
This property is cached to avoid recomputation.
|
||||
"""
|
||||
return self._output_types()
|
||||
|
||||
@abstractmethod
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
"""Defines output variable types for the node.
|
||||
|
||||
Subclasses must override this method to declare the variables
|
||||
produced by the node and their corresponding types.
|
||||
|
||||
Returns:
|
||||
A mapping from output variable names to ``VariableType``.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def check_activate(self, state: WorkflowState):
|
||||
"""Check if the current node is activated in the workflow state.
|
||||
@@ -136,92 +140,84 @@ class BaseNode(ABC):
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
"""执行节点业务逻辑(非流式)
|
||||
|
||||
节点只需要返回业务结果,不需要关心输出格式、时间统计等。
|
||||
BaseNode 会自动包装成标准格式。
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""Executes the node business logic (non-streaming).
|
||||
|
||||
The node implementation should only return the business result.
|
||||
It does not need to handle output formatting, timing, or statistics.
|
||||
The ``BaseNode`` will automatically wrap the result into a standard
|
||||
response format.
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
state: The current workflow state.
|
||||
variable_pool: The variable pool used for reading and writing variables.
|
||||
|
||||
Returns:
|
||||
业务结果(任意类型)
|
||||
|
||||
Examples:
|
||||
>>> # LLM 节点
|
||||
>>> "这是 AI 的回复"
|
||||
|
||||
>>> # Transform 节点
|
||||
>>> {"processed_data": [...]}
|
||||
|
||||
>>> # Start/End 节点
|
||||
>>> {"message": "开始", "conversation_id": "xxx"}
|
||||
The business result produced by the node. The return value can be
|
||||
of any type.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
|
||||
"""执行节点业务逻辑(流式)
|
||||
|
||||
子类可以重写此方法以支持流式输出。
|
||||
默认实现:执行非流式方法并一次性返回。
|
||||
|
||||
节点需要:
|
||||
1. yield 中间结果(如文本片段)
|
||||
2. 最后 yield 一个特殊的完成标记:{"__final__": True, "result": final_result}
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Yields:
|
||||
业务数据(chunk)或完成标记
|
||||
|
||||
Examples:
|
||||
# 流式 LLM 节点
|
||||
full_response = ""
|
||||
async for chunk in llm.astream(prompt):
|
||||
full_response += chunk
|
||||
yield chunk # yield 文本片段
|
||||
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
||||
"""Executes the node business logic in streaming mode.
|
||||
|
||||
# 最后 yield 完成标记
|
||||
yield {"__final__": True, "result": AIMessage(content=full_response)}
|
||||
Subclasses may override this method to support streaming output.
|
||||
The default implementation executes the non-streaming method and
|
||||
yields a single final result.
|
||||
|
||||
For streaming execution, a node implementation should:
|
||||
1. Yield intermediate results (e.g. text chunks).
|
||||
2. Yield a final completion marker in the following format:
|
||||
``{"__final__": True, "result": final_result}``.
|
||||
|
||||
Args:
|
||||
state: The current workflow state.
|
||||
variable_pool: The variable pool used for reading and writing variables.
|
||||
|
||||
Yields:
|
||||
Business data chunks or a final completion marker.
|
||||
"""
|
||||
result = await self.execute(state)
|
||||
# 默认实现:直接 yield 完成标记
|
||||
result = await self.execute(state, variable_pool)
|
||||
# Default implementation: yield a single final completion marker.
|
||||
yield {"__final__": True, "result": result}
|
||||
|
||||
def supports_streaming(self) -> bool:
|
||||
"""节点是否支持流式输出
|
||||
|
||||
"""Returns whether the node supports streaming output.
|
||||
|
||||
A node is considered to support streaming if its class overrides
|
||||
the ``execute_stream`` method. If the default implementation from
|
||||
``BaseNode`` is used, streaming is not supported.
|
||||
|
||||
Returns:
|
||||
是否支持流式输出
|
||||
True if the node supports streaming output, False otherwise.
|
||||
"""
|
||||
# 检查子类是否重写了 execute_stream 方法
|
||||
# Check whether the subclass overrides the execute_stream method.
|
||||
return self.__class__.execute_stream is not BaseNode.execute_stream
|
||||
|
||||
def get_timeout(self) -> int:
|
||||
"""获取超时时间(秒)
|
||||
|
||||
@staticmethod
|
||||
def get_timeout() -> int:
|
||||
"""Returns the execution timeout in seconds.
|
||||
|
||||
Returns:
|
||||
超时时间
|
||||
The timeout duration, in seconds.
|
||||
"""
|
||||
return settings.WORKFLOW_NODE_TIMEOUT
|
||||
# return self.error_handling.get("timeout", 60)
|
||||
|
||||
async def run(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""执行节点(带错误处理和输出包装,非流式)
|
||||
|
||||
这个方法由 Executor 调用,负责:
|
||||
1. 时间统计
|
||||
2. 调用节点的 execute() 方法
|
||||
3. 将业务结果包装成标准输出格式
|
||||
4. 错误处理
|
||||
|
||||
async def run(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""Runs the node with error handling and output wrapping (non-streaming).
|
||||
|
||||
This method is invoked by the Executor and is responsible for:
|
||||
1. Execution time measurement.
|
||||
2. Invoking the node's ``execute()`` method.
|
||||
3. Wrapping the business result into a standardized output format.
|
||||
4. Handling execution errors.
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
state: The current workflow state.
|
||||
variable_pool: The variable pool used for reading and writing variables.
|
||||
|
||||
Returns:
|
||||
标准化的状态更新字典
|
||||
A standardized state update dictionary.
|
||||
"""
|
||||
if not self.check_activate(state):
|
||||
return self.trans_activate(state)
|
||||
@@ -233,70 +229,78 @@ class BaseNode(ABC):
|
||||
timeout = self.get_timeout()
|
||||
|
||||
try:
|
||||
# 调用节点的业务逻辑
|
||||
# Invoke the node business logic.
|
||||
business_result = await asyncio.wait_for(
|
||||
self.execute(state),
|
||||
self.execute(state, variable_pool),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 提取处理后的输出(调用子类的 _extract_output)
|
||||
# Extract processed outputs using subclass-defined logic.
|
||||
extracted_output = self._extract_output(business_result)
|
||||
|
||||
# 包装成标准输出格式
|
||||
wrapped_output = self._wrap_output(business_result, elapsed_time, state)
|
||||
# Wrap the business result into the standard output format.
|
||||
wrapped_output = self._wrap_output(business_result, elapsed_time, state, variable_pool)
|
||||
|
||||
# 将提取后的输出存储到运行时变量中(供后续节点快速访问)
|
||||
# 如果提取后的输出是字典,拆包存储;否则存储为 output 字段
|
||||
if isinstance(extracted_output, dict):
|
||||
runtime_var = extracted_output
|
||||
else:
|
||||
runtime_var = {"output": extracted_output}
|
||||
# Store extracted outputs as runtime variables for downstream nodes.
|
||||
if extracted_output is not None:
|
||||
runtime_vars = extracted_output
|
||||
if not isinstance(extracted_output, dict):
|
||||
runtime_vars = {"output": extracted_output}
|
||||
for k, v in runtime_vars.items():
|
||||
await variable_pool.new(self.node_id, k, v, self.output_types[k], mut=self.variable_change_able)
|
||||
|
||||
# 返回包装后的输出和运行时变量
|
||||
# Return the wrapped output along with activation state updates.
|
||||
return {
|
||||
**wrapped_output,
|
||||
"messages": state["messages"],
|
||||
"runtime_vars": {
|
||||
self.node_id: runtime_var
|
||||
},
|
||||
"looping": state["looping"]
|
||||
} | self.trans_activate(state)
|
||||
|
||||
except TimeoutError:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
|
||||
return self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
|
||||
logger.error(
|
||||
f"Node {self.node_id} execution timed out ({timeout} seconds)."
|
||||
)
|
||||
return self._wrap_error(
|
||||
f"Node execution timed out ({timeout} seconds).",
|
||||
elapsed_time,
|
||||
state,
|
||||
variable_pool,
|
||||
)
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
|
||||
return self._wrap_error(str(e), elapsed_time, state)
|
||||
logger.error(
|
||||
f"Node {self.node_id} execution failed: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
return self._wrap_error(str(e), elapsed_time, state, variable_pool)
|
||||
|
||||
async def run_stream(
|
||||
self, state: WorkflowState,
|
||||
variable_pool: VariablePool
|
||||
) -> AsyncGenerator[dict[str, Any], Any]:
|
||||
"""Executes the node with error handling and output wrapping (streaming).
|
||||
|
||||
async def run_stream(self, state: WorkflowState) -> AsyncGenerator[dict[str, Any], Any]:
|
||||
"""Execute node with error handling and output wrapping (streaming)
|
||||
|
||||
This method is called by the Executor and is responsible for:
|
||||
1. Time tracking
|
||||
2. Calling the node's execute_stream() method
|
||||
3. Using LangGraph's stream writer to send chunks
|
||||
4. Updating streaming buffer in state for downstream nodes
|
||||
5. Wrapping business data into standard output format
|
||||
6. Error handling
|
||||
|
||||
Special handling for End nodes:
|
||||
- End nodes don't send chunks via writer (prefix and LLM content already sent)
|
||||
- End nodes only yield suffix for final result assembly
|
||||
|
||||
1. Tracking execution time.
|
||||
2. Calling the node's ``execute_stream()`` method.
|
||||
3. Sending streaming chunks via LangGraph's stream writer.
|
||||
4. Updating activation-related state for downstream nodes.
|
||||
5. Wrapping business data into a standardized output format.
|
||||
6. Handling execution errors.
|
||||
|
||||
Args:
|
||||
state: Workflow state
|
||||
|
||||
state: The current workflow state.
|
||||
variable_pool: The variable pool used for reading and writing variables.
|
||||
|
||||
Yields:
|
||||
State updates with streaming buffer and final result
|
||||
Incremental state updates, including activation state changes and
|
||||
the final wrapped result.
|
||||
"""
|
||||
if not self.check_activate(state):
|
||||
yield self.trans_activate(state)
|
||||
logger.info(f"jump node: {self.node_id}")
|
||||
logger.debug(f"jump node: {self.node_id}")
|
||||
return
|
||||
|
||||
import time
|
||||
@@ -317,7 +321,7 @@ class BaseNode(ABC):
|
||||
# Stream chunks in real-time
|
||||
loop_start = asyncio.get_event_loop().time()
|
||||
|
||||
async for item in self.execute_stream(state):
|
||||
async for item in self.execute_stream(state, variable_pool):
|
||||
# Check timeout
|
||||
if asyncio.get_event_loop().time() - loop_start > timeout:
|
||||
raise TimeoutError()
|
||||
@@ -332,7 +336,7 @@ class BaseNode(ABC):
|
||||
chunks.append(content)
|
||||
|
||||
# Send chunks for all nodes (including End nodes for suffix)
|
||||
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {content[:50]}...")
|
||||
logger.debug(f"Node {self.node_id} sent chunk #{chunk_count}: {content[:50]}...")
|
||||
|
||||
# 1. Send via stream writer (for real-time client updates)
|
||||
writer({
|
||||
@@ -344,27 +348,26 @@ class BaseNode(ABC):
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
|
||||
logger.info(f"Node {self.node_id} streaming execution finished, "
|
||||
f"time elapsed: {elapsed_time:.2f}s, chunks: {chunk_count}")
|
||||
|
||||
# Extract processed output (call subclass's _extract_output)
|
||||
extracted_output = self._extract_output(final_result)
|
||||
|
||||
# Wrap final result
|
||||
final_output = self._wrap_output(final_result, elapsed_time, state)
|
||||
final_output = self._wrap_output(final_result, elapsed_time, state, variable_pool)
|
||||
|
||||
# Store extracted output in runtime variables (for quick access by subsequent nodes)
|
||||
if isinstance(extracted_output, dict):
|
||||
runtime_var = extracted_output
|
||||
else:
|
||||
runtime_var = {"output": extracted_output}
|
||||
if extracted_output is not None:
|
||||
runtime_vars = extracted_output
|
||||
if not isinstance(extracted_output, dict):
|
||||
runtime_vars = {"output": extracted_output}
|
||||
for k, v in runtime_vars.items():
|
||||
await variable_pool.new(self.node_id, k, v, self.output_types[k], mut=self.variable_change_able)
|
||||
|
||||
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
|
||||
state_update = {
|
||||
**final_output,
|
||||
"messages": state["messages"],
|
||||
"runtime_vars": {
|
||||
self.node_id: runtime_var
|
||||
},
|
||||
"looping": state["looping"]
|
||||
}
|
||||
|
||||
@@ -374,41 +377,49 @@ class BaseNode(ABC):
|
||||
|
||||
except TimeoutError:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"节点 {self.node_id} 执行超时 ({timeout}s)")
|
||||
error_output = self._wrap_error(f"节点执行超时 ({timeout}s)", elapsed_time, state)
|
||||
logger.error(f"Node {self.node_id} execution timed out ({timeout}s)")
|
||||
error_output = self._wrap_error(
|
||||
f"Node execution timed out ({timeout}s)",
|
||||
elapsed_time,
|
||||
state,
|
||||
variable_pool
|
||||
)
|
||||
yield error_output
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
|
||||
error_output = self._wrap_error(str(e), elapsed_time, state)
|
||||
logger.error(f"Node {self.node_id} execution failed: {e}", exc_info=True)
|
||||
error_output = self._wrap_error(str(e), elapsed_time, state, variable_pool)
|
||||
yield error_output
|
||||
|
||||
def _wrap_output(
|
||||
self,
|
||||
business_result: Any,
|
||||
elapsed_time: float,
|
||||
state: WorkflowState
|
||||
state: WorkflowState,
|
||||
variable_pool: VariablePool
|
||||
) -> dict[str, Any]:
|
||||
"""将业务结果包装成标准输出格式
|
||||
|
||||
Args:
|
||||
business_result: 节点返回的业务结果
|
||||
elapsed_time: 执行耗时
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
标准化的状态更新字典
|
||||
"""
|
||||
# 提取输入数据(用于记录)
|
||||
input_data = self._extract_input(state)
|
||||
"""Wraps the business result into a standardized node output format.
|
||||
|
||||
# 提取 token 使用情况(如果有)
|
||||
Args:
|
||||
business_result: The result returned by the node's business logic.
|
||||
elapsed_time: Time elapsed during node execution (in seconds).
|
||||
state: The current workflow state.
|
||||
variable_pool: The variable pool used for reading and writing variables.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the standardized state update for this node,
|
||||
including node outputs, input, output, elapsed time, token usage, and status.
|
||||
"""
|
||||
# Extract input data (for logging or audit purposes)
|
||||
input_data = self._extract_input(state, variable_pool)
|
||||
|
||||
# Extract token usage information (if applicable)
|
||||
token_usage = self._extract_token_usage(business_result)
|
||||
|
||||
# 提取实际输出(去除元数据)
|
||||
# Extract actual output (strip any metadata)
|
||||
output = self._extract_output(business_result)
|
||||
|
||||
# 构建标准节点输出
|
||||
# Construct standardized node output
|
||||
node_output = {
|
||||
"node_id": self.node_id,
|
||||
"node_type": self.node_type,
|
||||
@@ -423,8 +434,6 @@ class BaseNode(ABC):
|
||||
final_output = {
|
||||
"node_outputs": {self.node_id: node_output},
|
||||
}
|
||||
if self.variable_updater:
|
||||
final_output = final_output | {"variables": state["variables"]}
|
||||
|
||||
return final_output
|
||||
|
||||
@@ -432,25 +441,33 @@ class BaseNode(ABC):
|
||||
self,
|
||||
error_message: str,
|
||||
elapsed_time: float,
|
||||
state: WorkflowState
|
||||
state: WorkflowState,
|
||||
variable_pool: VariablePool
|
||||
) -> dict[str, Any]:
|
||||
"""将错误包装成标准输出格式
|
||||
|
||||
"""Wraps an error into a standardized node output format.
|
||||
|
||||
This method handles both cases:
|
||||
- If an error edge is defined, the workflow can continue to the error handling node.
|
||||
- If no error edge exists, the workflow is stopped by raising an exception.
|
||||
|
||||
Args:
|
||||
error_message: 错误信息
|
||||
elapsed_time: 执行耗时
|
||||
state: 工作流状态
|
||||
|
||||
error_message: The error message describing the failure.
|
||||
elapsed_time: Time elapsed during node execution (in seconds).
|
||||
state: The current workflow state.
|
||||
variable_pool: The variable pool used for reading and writing variables.
|
||||
|
||||
Returns:
|
||||
标准化的状态更新字典
|
||||
A dictionary representing the standardized state update for this node
|
||||
when an error edge exists. If no error edge exists, this method
|
||||
raises an exception to stop the workflow.
|
||||
"""
|
||||
# 查找错误边
|
||||
# Check if the node has an error edge defined
|
||||
error_edge = self._find_error_edge()
|
||||
|
||||
# 提取输入数据
|
||||
input_data = self._extract_input(state)
|
||||
# Extract input data (for logging or audit purposes)
|
||||
input_data = self._extract_input(state, variable_pool)
|
||||
|
||||
# 构建错误输出
|
||||
# Construct the standardized node output for the error
|
||||
node_output = {
|
||||
"node_id": self.node_id,
|
||||
"node_type": self.node_type,
|
||||
@@ -464,9 +481,9 @@ class BaseNode(ABC):
|
||||
}
|
||||
|
||||
if error_edge:
|
||||
# 有错误边:记录错误并继续
|
||||
# If an error edge exists, log a warning and continue to error node
|
||||
logger.warning(
|
||||
f"节点 {self.node_id} 执行失败,跳转到错误处理节点: {error_edge['target']}"
|
||||
f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}"
|
||||
)
|
||||
return {
|
||||
"node_outputs": {
|
||||
@@ -476,198 +493,161 @@ class BaseNode(ABC):
|
||||
"error_node": self.node_id
|
||||
}
|
||||
else:
|
||||
# If no error edge, send the error via stream writer and stop the workflow
|
||||
writer = get_stream_writer()
|
||||
writer({
|
||||
"type": "node_error",
|
||||
**node_output
|
||||
})
|
||||
# 无错误边:抛出异常停止工作流
|
||||
logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}")
|
||||
raise Exception(f"节点 {self.node_id} 执行失败: {error_message}")
|
||||
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
|
||||
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""Extracts the input data for this node (used for logging or audit).
|
||||
|
||||
Subclasses may override this method to customize what input data
|
||||
should be recorded.
|
||||
|
||||
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""提取节点输入数据(用于记录)
|
||||
|
||||
子类可以重写此方法来自定义输入记录。
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
state: The current workflow state.
|
||||
variable_pool: The variable pool used for reading and writing variables.
|
||||
|
||||
Returns:
|
||||
输入数据字典
|
||||
A dictionary containing the node's input data.
|
||||
"""
|
||||
# 默认返回配置
|
||||
# Default implementation returns the node configuration
|
||||
return {"config": self.config}
|
||||
|
||||
def _extract_output(self, business_result: Any) -> Any:
|
||||
"""从业务结果中提取实际输出
|
||||
|
||||
子类可以重写此方法来自定义输出提取。
|
||||
|
||||
"""Extracts the actual output from the business result.
|
||||
|
||||
Subclasses may override this method to customize how the node's
|
||||
output is extracted.
|
||||
|
||||
Args:
|
||||
business_result: 业务结果
|
||||
|
||||
business_result: The result returned by the node's business logic.
|
||||
|
||||
Returns:
|
||||
实际输出
|
||||
The actual output extracted from the business result.
|
||||
"""
|
||||
# 默认直接返回业务结果
|
||||
# Default implementation returns the business result directly
|
||||
return business_result
|
||||
|
||||
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||
"""从业务结果中提取 token 使用情况
|
||||
|
||||
子类可以重写此方法来提取 token 信息。
|
||||
|
||||
"""Extracts token usage information from the business result.
|
||||
|
||||
Subclasses may override this method to extract token usage statistics
|
||||
(e.g., for LLM nodes).
|
||||
|
||||
Args:
|
||||
business_result: 业务结果
|
||||
|
||||
business_result: The result returned by the node's business logic.
|
||||
|
||||
Returns:
|
||||
token 使用情况或 None
|
||||
A dictionary mapping token types to counts, or None if not applicable.
|
||||
"""
|
||||
# 默认返回 None
|
||||
# Default implementation returns None
|
||||
return None
|
||||
|
||||
def _find_error_edge(self) -> dict[str, Any] | None:
|
||||
"""查找错误边
|
||||
|
||||
"""Finds the error edge for this node, if any.
|
||||
|
||||
An error edge is used to redirect workflow execution when this node
|
||||
fails.
|
||||
|
||||
Returns:
|
||||
错误边配置或 None
|
||||
A dictionary representing the error edge configuration if it exists,
|
||||
or None if no error edge is defined.
|
||||
"""
|
||||
for edge in self.workflow_config.get("edges", []):
|
||||
if edge.get("source") == self.node_id and edge.get("type") == "error":
|
||||
return edge
|
||||
return None
|
||||
|
||||
def _render_template(self, template: str, state: WorkflowState | None, strict: bool = True) -> str:
|
||||
"""渲染模板
|
||||
|
||||
支持的变量命名空间:
|
||||
- sys.xxx: 系统变量(message, execution_id, workspace_id, user_id, conversation_id)
|
||||
- conv.xxx: 会话变量(跨多轮对话保持)
|
||||
- node_id.xxx: 节点输出
|
||||
|
||||
@staticmethod
|
||||
def _render_template(template: str, variable_pool: VariablePool, strict: bool = True) -> str:
|
||||
"""Renders a template string using the provided variable pool.
|
||||
|
||||
Supported variable namespaces:
|
||||
- sys.xxx: System variables (e.g., message, execution_id, workspace_id,
|
||||
user_id, conversation_id)
|
||||
- conv.xxx: Conversation variables (persist across multiple turns)
|
||||
- node_id.xxx: Node outputs
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
state: 工作流状态
|
||||
|
||||
template: The template string to render.
|
||||
variable_pool: The variable pool containing system, conversation, and
|
||||
node variables.
|
||||
strict: If True, missing variables will raise an error; if False,
|
||||
missing variables are ignored.
|
||||
|
||||
Returns:
|
||||
渲染后的字符串
|
||||
The rendered string with all variables substituted.
|
||||
"""
|
||||
from app.core.workflow.template_renderer import render_template
|
||||
|
||||
# 处理 state 为 None 的情况
|
||||
if state is None:
|
||||
state = {}
|
||||
|
||||
# 使用变量池获取变量
|
||||
pool = VariablePool(state)
|
||||
|
||||
# 构建完整的 variables 结构
|
||||
variables = {
|
||||
"sys": pool.get_all_system_vars(),
|
||||
"conv": pool.get_all_conversation_vars()
|
||||
}
|
||||
|
||||
return render_template(
|
||||
template=template,
|
||||
variables=variables,
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars(),
|
||||
conv_vars=variable_pool.get_all_conversation_vars(),
|
||||
node_outputs=variable_pool.get_all_node_outputs(),
|
||||
system_vars=variable_pool.get_all_system_vars(),
|
||||
strict=strict
|
||||
)
|
||||
|
||||
def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool:
|
||||
"""评估条件表达式
|
||||
|
||||
支持的变量命名空间:
|
||||
- sys.xxx: 系统变量
|
||||
- conv.xxx: 会话变量
|
||||
- node_id.xxx: 节点输出
|
||||
|
||||
@staticmethod
|
||||
def _evaluate_condition(expression: str, variable_pool: VariablePool) -> bool:
|
||||
"""Evaluates a conditional expression using the provided variable pool.
|
||||
|
||||
Supported variable namespaces:
|
||||
- sys.xxx: System variables
|
||||
- conv.xxx: Conversation variables
|
||||
- node_id.xxx: Node outputs
|
||||
|
||||
Args:
|
||||
expression: 条件表达式
|
||||
state: 工作流状态
|
||||
|
||||
expression: The conditional expression to evaluate.
|
||||
variable_pool: The variable pool containing system, conversation, and
|
||||
node variables.
|
||||
|
||||
Returns:
|
||||
布尔值结果
|
||||
The boolean result of evaluating the expression.
|
||||
"""
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
|
||||
# 处理 state 为 None 的情况
|
||||
if state is None:
|
||||
state = {}
|
||||
|
||||
# 使用变量池获取变量
|
||||
pool = VariablePool(state)
|
||||
|
||||
# 构建完整的 variables 结构(包含 sys 和 conv)
|
||||
variables = {
|
||||
"sys": pool.get_all_system_vars(),
|
||||
"conv": pool.get_all_conversation_vars()
|
||||
}
|
||||
|
||||
return evaluate_condition(
|
||||
expression=expression,
|
||||
variables=variables,
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars()
|
||||
conv_var=variable_pool.get_all_conversation_vars(),
|
||||
node_outputs=variable_pool.get_all_node_outputs(),
|
||||
system_vars=variable_pool.get_all_system_vars()
|
||||
)
|
||||
|
||||
def get_variable_pool(self, state: WorkflowState) -> VariablePool:
|
||||
"""获取变量池实例
|
||||
|
||||
VariablePool 是轻量级包装器,只持有 state 的引用,创建成本极低。
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
VariablePool 实例
|
||||
|
||||
Examples:
|
||||
>>> pool = self.get_variable_pool(state)
|
||||
>>> message = pool.get("sys.message")
|
||||
>>> llm_output = pool.get("llm_qa.output")
|
||||
"""
|
||||
return VariablePool(state)
|
||||
|
||||
@staticmethod
|
||||
def get_variable(
|
||||
self,
|
||||
selector: list[str] | str,
|
||||
state: WorkflowState,
|
||||
default: Any = None
|
||||
selector: str,
|
||||
variable_pool: VariablePool,
|
||||
default: Any = None,
|
||||
strict: bool = True
|
||||
) -> Any:
|
||||
"""获取变量值(便捷方法)
|
||||
|
||||
Args:
|
||||
selector: 变量选择器
|
||||
state: 工作流状态
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
变量值
|
||||
|
||||
Examples:
|
||||
>>> message = self.get_variable("sys.message", state)
|
||||
>>> output = self.get_variable(["llm_qa", "output"], state)
|
||||
>>> custom = self.get_variable("var.custom", state, default="默认值")
|
||||
"""
|
||||
pool = VariablePool(state)
|
||||
return pool.get(selector, default=default)
|
||||
"""Retrieves a variable value from the variable pool (convenience method).
|
||||
|
||||
def has_variable(self, selector: list[str] | str, state: WorkflowState) -> bool:
|
||||
"""检查变量是否存在(便捷方法)
|
||||
|
||||
Args:
|
||||
selector: 变量选择器
|
||||
state: 工作流状态
|
||||
|
||||
selector: The variable selector (can be namespaced, e.g., sys.xxx, conv.xxx, node_id.xxx).
|
||||
variable_pool: The variable pool from which to fetch the value.
|
||||
default: The default value to return if the variable does not exist.
|
||||
strict: If True, raise an error when the variable is missing; if False, return the default.
|
||||
|
||||
Returns:
|
||||
变量是否存在
|
||||
|
||||
Examples:
|
||||
>>> if self.has_variable("llm_qa.output", state):
|
||||
... output = self.get_variable("llm_qa.output", state)
|
||||
The value of the selected variable, or the default if not found and strict is False.
|
||||
"""
|
||||
pool = VariablePool(state)
|
||||
return pool.has(selector)
|
||||
return variable_pool.get_value(selector, default, strict=strict)
|
||||
|
||||
@staticmethod
|
||||
def has_variable(selector: str, variable_pool: VariablePool) -> bool:
|
||||
"""Checks whether a variable exists in the variable pool (convenience method).
|
||||
|
||||
Args:
|
||||
selector: The variable selector (can be namespaced, e.g., sys.xxx, conv.xxx, node_id.xxx).
|
||||
variable_pool: The variable pool to check.
|
||||
|
||||
Returns:
|
||||
True if the variable exists in the pool, False otherwise.
|
||||
"""
|
||||
return variable_pool.has(selector)
|
||||
|
||||
@@ -2,6 +2,8 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import BaseNode, WorkflowState
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -14,15 +16,19 @@ class BreakNode(BaseNode):
|
||||
to False, signaling the outer loop runtime to terminate further iterations.
|
||||
"""
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {}
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""
|
||||
Execute the break node.
|
||||
|
||||
Args:
|
||||
state: Current workflow state, including loop control flags.
|
||||
variable_pool: Pool of variables for the workflow.
|
||||
|
||||
Effects:
|
||||
- Sets 'looping' in the state to False to stop the loop.
|
||||
- Sets 'looping' in the state too False to stop the loop.
|
||||
- Logs the action for debugging purposes.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from typing import Literal
|
||||
from pydantic import Field, BaseModel
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
|
||||
class InputVariable(BaseModel):
|
||||
@@ -44,7 +45,7 @@ class CodeNodeConfig(BaseNodeConfig):
|
||||
description="code content"
|
||||
)
|
||||
|
||||
language: Literal['python3', 'nodejs'] = Field(
|
||||
language: Literal['python3', 'javascript'] = Field(
|
||||
...,
|
||||
description="language"
|
||||
)
|
||||
|
||||
@@ -9,8 +9,9 @@ from typing import Any
|
||||
import httpx
|
||||
|
||||
from app.core.workflow.nodes import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.base_config import VariableType
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.nodes.code.config import CodeNodeConfig
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -52,6 +53,12 @@ class CodeNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: CodeNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
output_dict = {}
|
||||
for output in self.typed_config.output_variables:
|
||||
output_dict[output.name] = output.type
|
||||
return output_dict
|
||||
|
||||
def extract_result(self, content: str):
|
||||
match = re.search(r'<<RESULT>>(.*?)<<RESULT>>', content, re.DOTALL)
|
||||
if match:
|
||||
@@ -92,11 +99,11 @@ class CodeNode(BaseNode):
|
||||
else:
|
||||
raise RuntimeError("The output of main must be a dictionary")
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
self.typed_config = CodeNodeConfig(**self.config)
|
||||
input_variable_dict = {}
|
||||
for input_variable in self.typed_config.input_variables:
|
||||
input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, state)
|
||||
input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, variable_pool)
|
||||
|
||||
code = base64.b64decode(
|
||||
self.typed_config.code
|
||||
@@ -110,7 +117,7 @@ class CodeNode(BaseNode):
|
||||
code=code,
|
||||
inputs_variable=input_variable_dict,
|
||||
)
|
||||
elif self.typed_config.language == 'nodejs':
|
||||
elif self.typed_config.language == 'javascript':
|
||||
final_script = NODEJS_SCRIPT_TEMPLATE.substitute(
|
||||
code=code,
|
||||
inputs_variable=input_variable_dict,
|
||||
|
||||
@@ -8,7 +8,6 @@ from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
||||
from app.core.workflow.nodes.base_config import (
|
||||
BaseNodeConfig,
|
||||
VariableDefinition,
|
||||
VariableType,
|
||||
)
|
||||
from app.core.workflow.nodes.code.config import CodeNodeConfig
|
||||
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
|
||||
@@ -23,21 +22,18 @@ from app.core.workflow.nodes.parameter_extractor.config import ParameterExtracto
|
||||
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
|
||||
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||
|
||||
__all__ = [
|
||||
# 基础类
|
||||
"BaseNodeConfig",
|
||||
"VariableDefinition",
|
||||
"VariableType",
|
||||
# 节点配置
|
||||
"StartNodeConfig",
|
||||
"EndNodeConfig",
|
||||
"LLMNodeConfig",
|
||||
"MessageConfig",
|
||||
"AgentNodeConfig",
|
||||
"TransformNodeConfig",
|
||||
"IfElseNodeConfig",
|
||||
"KnowledgeRetrievalNodeConfig",
|
||||
"AssignerNodeConfig",
|
||||
|
||||
@@ -2,7 +2,8 @@ from typing import Any
|
||||
|
||||
from pydantic import Field, BaseModel, field_validator
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
|
||||
|
||||
|
||||
@@ -127,4 +128,9 @@ class IterationNodeConfig(BaseNodeConfig):
|
||||
description="Output of the loop iteration"
|
||||
)
|
||||
|
||||
output_type: VariableType = Field(
|
||||
...,
|
||||
description="Data type of the loop iteration output"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -28,6 +29,8 @@ class IterationRuntime:
|
||||
node_id: str,
|
||||
config: dict[str, Any],
|
||||
state: WorkflowState,
|
||||
variable_pool: VariablePool,
|
||||
child_variable_pool: VariablePool,
|
||||
):
|
||||
"""
|
||||
Initialize the iteration runtime.
|
||||
@@ -44,11 +47,13 @@ class IterationRuntime:
|
||||
self.node_id = node_id
|
||||
self.typed_config = IterationNodeConfig(**config)
|
||||
self.looping = True
|
||||
self.variable_pool = variable_pool
|
||||
self.child_variable_pool = child_variable_pool
|
||||
|
||||
self.output_value = None
|
||||
self.result: list = []
|
||||
|
||||
def _init_iteration_state(self, item, idx):
|
||||
async def _init_iteration_state(self, item, idx):
|
||||
"""
|
||||
Initialize a per-iteration copy of the workflow state.
|
||||
|
||||
@@ -62,10 +67,9 @@ class IterationRuntime:
|
||||
loopstate = WorkflowState(
|
||||
**self.state
|
||||
)
|
||||
loopstate["runtime_vars"][self.node_id] = {
|
||||
"item": item,
|
||||
"index": idx,
|
||||
}
|
||||
self.child_variable_pool.copy(self.variable_pool)
|
||||
await self.child_variable_pool.new(self.node_id, "item", item, VariableType.type_map(item), mut=True)
|
||||
await self.child_variable_pool.new(self.node_id, "index", item, VariableType.type_map(item), mut=True)
|
||||
loopstate["node_outputs"][self.node_id] = {
|
||||
"item": item,
|
||||
"index": idx,
|
||||
@@ -74,6 +78,11 @@ class IterationRuntime:
|
||||
loopstate["activate"][self.start_id] = True
|
||||
return loopstate
|
||||
|
||||
def merge_conv_vars(self):
|
||||
self.variable_pool.get_all_conversation_vars().update(
|
||||
self.child_variable_pool.get_all_conversation_vars()
|
||||
)
|
||||
|
||||
async def run_task(self, item, idx):
|
||||
"""
|
||||
Execute a single iteration asynchronously.
|
||||
@@ -82,8 +91,8 @@ class IterationRuntime:
|
||||
item: The input element for this iteration.
|
||||
idx: The index of this iteration.
|
||||
"""
|
||||
result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
|
||||
output = VariablePool(result).get(self.output_value)
|
||||
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
|
||||
output = self.child_variable_pool.get_value(self.output_value)
|
||||
if isinstance(output, list) and self.typed_config.flatten:
|
||||
self.result.extend(output)
|
||||
else:
|
||||
@@ -125,7 +134,7 @@ class IterationRuntime:
|
||||
input_expression = re.sub(pattern, r"\1", self.typed_config.input).strip()
|
||||
self.output_value = re.sub(pattern, r"\1", self.typed_config.output).strip()
|
||||
|
||||
array_obj = VariablePool(self.state).get(input_expression)
|
||||
array_obj = self.variable_pool.get_value(input_expression)
|
||||
if not isinstance(array_obj, list):
|
||||
raise RuntimeError("Cannot iterate over a non-list variable")
|
||||
child_state = []
|
||||
@@ -137,14 +146,16 @@ class IterationRuntime:
|
||||
logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}")
|
||||
idx += self.typed_config.parallel_count
|
||||
child_state.extend(await asyncio.gather(*tasks))
|
||||
self.merge_conv_vars()
|
||||
else:
|
||||
# Execute iterations sequentially
|
||||
while idx < len(array_obj) and self.looping:
|
||||
logger.info(f"Iteration node {self.node_id}: running")
|
||||
item = array_obj[idx]
|
||||
result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
|
||||
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
|
||||
child_state.append(result)
|
||||
output = VariablePool(result).get(self.output_value)
|
||||
output = self.child_variable_pool.get_value(self.output_value)
|
||||
self.merge_conv_vars()
|
||||
if isinstance(output, list) and self.typed_config.flatten:
|
||||
self.result.extend(output)
|
||||
else:
|
||||
|
||||
@@ -31,6 +31,8 @@ class LoopRuntime:
|
||||
node_id: str,
|
||||
config: dict[str, Any],
|
||||
state: WorkflowState,
|
||||
variable_pool: VariablePool,
|
||||
child_variable_pool: VariablePool
|
||||
):
|
||||
"""
|
||||
Initialize the loop runtime executor.
|
||||
@@ -40,6 +42,8 @@ class LoopRuntime:
|
||||
node_id: The unique identifier of the loop node in the workflow.
|
||||
config: Raw configuration dictionary for the loop node.
|
||||
state: The current workflow state before entering the loop.
|
||||
variable_pool: A VariablePool instance for accessing and modifying workflow variables.
|
||||
child_variable_pool: A VariablePool instance for managing child node outputs.
|
||||
"""
|
||||
self.start_id = start_id
|
||||
self.graph = graph
|
||||
@@ -47,8 +51,10 @@ class LoopRuntime:
|
||||
self.node_id = node_id
|
||||
self.typed_config = LoopNodeConfig(**config)
|
||||
self.looping = True
|
||||
self.variable_pool = variable_pool
|
||||
self.child_variable_pool = child_variable_pool
|
||||
|
||||
def _init_loop_state(self):
|
||||
async def _init_loop_state(self):
|
||||
"""
|
||||
Initialize workflow state for loop execution.
|
||||
|
||||
@@ -62,33 +68,35 @@ class LoopRuntime:
|
||||
Returns:
|
||||
WorkflowState: A prepared workflow state used for loop execution.
|
||||
"""
|
||||
pool = VariablePool(self.state)
|
||||
# 循环变量
|
||||
self.state["runtime_vars"][self.node_id] = {
|
||||
variable.name: evaluate_expression(
|
||||
expression=variable.value,
|
||||
variables=pool.get_all_conversation_vars(),
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars(),
|
||||
)
|
||||
if variable.input_type == ValueInputType.VARIABLE
|
||||
else TypeTransformer.transform(variable.value, variable.type)
|
||||
for variable in self.typed_config.cycle_vars
|
||||
}
|
||||
self.state["node_outputs"][self.node_id] = {
|
||||
variable.name: evaluate_expression(
|
||||
expression=variable.value,
|
||||
variables=pool.get_all_conversation_vars(),
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars(),
|
||||
)
|
||||
if variable.input_type == ValueInputType.VARIABLE
|
||||
else TypeTransformer.transform(variable.value, variable.type)
|
||||
for variable in self.typed_config.cycle_vars
|
||||
}
|
||||
self.child_variable_pool.copy(self.variable_pool)
|
||||
|
||||
for variable in self.typed_config.cycle_vars:
|
||||
if variable.input_type == ValueInputType.VARIABLE:
|
||||
value = evaluate_expression(
|
||||
expression=variable.value,
|
||||
conv_var=self.variable_pool.get_all_conversation_vars(),
|
||||
node_outputs=self.variable_pool.get_all_node_outputs(),
|
||||
system_vars=self.variable_pool.get_all_system_vars(),
|
||||
)
|
||||
else:
|
||||
value = TypeTransformer.transform(variable.value, variable.type)
|
||||
await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True)
|
||||
loopstate = WorkflowState(
|
||||
**self.state
|
||||
)
|
||||
loopstate["node_outputs"][self.node_id] = {
|
||||
variable.name: evaluate_expression(
|
||||
expression=variable.value,
|
||||
conv_var=self.variable_pool.get_all_conversation_vars(),
|
||||
node_outputs=self.variable_pool.get_all_node_outputs(),
|
||||
system_vars=self.variable_pool.get_all_system_vars(),
|
||||
)
|
||||
if variable.input_type == ValueInputType.VARIABLE
|
||||
else TypeTransformer.transform(variable.value, variable.type)
|
||||
for variable in self.typed_config.cycle_vars
|
||||
}
|
||||
|
||||
loopstate["looping"] = 1
|
||||
loopstate["activate"][self.start_id] = True
|
||||
return loopstate
|
||||
@@ -134,7 +142,12 @@ class LoopRuntime:
|
||||
case _:
|
||||
raise ValueError(f"Invalid condition: {operator}")
|
||||
|
||||
def evaluate_conditional(self, state) -> bool:
|
||||
def merge_conv_vars(self):
|
||||
self.variable_pool.variables["conv"].update(
|
||||
self.child_variable_pool.variables.get("conv", {})
|
||||
)
|
||||
|
||||
def evaluate_conditional(self) -> bool:
|
||||
"""
|
||||
Evaluate the loop continuation condition at runtime.
|
||||
|
||||
@@ -143,18 +156,15 @@ class LoopRuntime:
|
||||
- Evaluates each comparison expression immediately
|
||||
- Combines results using the configured logical operator (AND / OR)
|
||||
|
||||
Args:
|
||||
state: The current workflow state during loop execution.
|
||||
|
||||
Returns:
|
||||
bool: True if the loop should continue, False otherwise.
|
||||
"""
|
||||
conditions = []
|
||||
|
||||
for expression in self.typed_config.condition.expressions:
|
||||
left_value = VariablePool(state).get(expression.left)
|
||||
left_value = self.child_variable_pool.get_value(expression.left)
|
||||
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
|
||||
VariablePool(state),
|
||||
self.child_variable_pool,
|
||||
expression.left,
|
||||
expression.right,
|
||||
expression.input_type
|
||||
@@ -177,16 +187,18 @@ class LoopRuntime:
|
||||
Returns:
|
||||
dict[str, Any]: The final runtime variables of this loop node.
|
||||
"""
|
||||
loopstate = self._init_loop_state()
|
||||
loopstate = await self._init_loop_state()
|
||||
loop_time = self.typed_config.max_loop
|
||||
child_state = []
|
||||
while self.evaluate_conditional(loopstate) and self.looping and loop_time > 0:
|
||||
while not self.evaluate_conditional() and self.looping and loop_time > 0:
|
||||
logger.info(f"loop node {self.node_id}: running")
|
||||
result = await self.graph.ainvoke(loopstate)
|
||||
child_state.append(result)
|
||||
|
||||
self.merge_conv_vars()
|
||||
if result["looping"] == 2:
|
||||
self.looping = False
|
||||
loop_time -= 1
|
||||
|
||||
logger.info(f"loop node {self.node_id}: execution completed")
|
||||
return loopstate["runtime_vars"][self.node_id] | {"__child_state": child_state}
|
||||
return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state}
|
||||
|
||||
@@ -6,9 +6,12 @@ from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig
|
||||
from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime
|
||||
from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -35,9 +38,38 @@ class CycleGraphNode(BaseNode):
|
||||
self.start_node_id = None # ID of the start node within the cycle
|
||||
|
||||
self.graph: StateGraph | CompiledStateGraph | None = None
|
||||
self.child_variable_pool: VariablePool | None = None
|
||||
self.build_graph()
|
||||
self.iteration_flag = True
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
outputs = {"__child_state": VariableType.ARRAY_OBJECT}
|
||||
if self.node_type == NodeType.LOOP:
|
||||
# Loop node outputs the final state of the loop
|
||||
config = LoopNodeConfig(**self.config)
|
||||
for var_def in config.cycle_vars:
|
||||
outputs[var_def.name] = var_def.type
|
||||
return outputs
|
||||
elif self.node_type == NodeType.ITERATION:
|
||||
# Iteration node outputs the processed collection
|
||||
config = IterationNodeConfig(**self.config)
|
||||
if config.output_type in [
|
||||
VariableType.ARRAY_FILE,
|
||||
VariableType.ARRAY_STRING,
|
||||
VariableType.NUMBER,
|
||||
VariableType.ARRAY_OBJECT,
|
||||
VariableType.BOOLEAN
|
||||
]:
|
||||
if config.flatten:
|
||||
outputs['output'] = config.output_type
|
||||
else:
|
||||
outputs['output'] = VariableType.ARRAY_STRING
|
||||
else:
|
||||
outputs['output'] = VariableType(f"array[{config.output_type}]")
|
||||
return outputs
|
||||
else:
|
||||
raise KeyError(f"Valid Cycle Node Type - {self.node_type}")
|
||||
|
||||
def pure_cycle_graph(self) -> tuple[list, list]:
|
||||
"""
|
||||
Extract cycle-scoped nodes and internal edges from the workflow configuration.
|
||||
@@ -103,17 +135,20 @@ class CycleGraphNode(BaseNode):
|
||||
"""
|
||||
from app.core.workflow.graph_builder import GraphBuilder
|
||||
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
||||
self.child_variable_pool = VariablePool()
|
||||
builder = GraphBuilder(
|
||||
{
|
||||
"nodes": self.cycle_nodes,
|
||||
"edges": self.cycle_edges,
|
||||
},
|
||||
subgraph=True
|
||||
subgraph=True,
|
||||
variable_pool=self.child_variable_pool
|
||||
)
|
||||
self.start_node_id = builder.start_node_id
|
||||
self.graph = builder.build()
|
||||
self.child_variable_pool = builder.variable_pool
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""
|
||||
Execute the cycle node at runtime.
|
||||
|
||||
@@ -123,6 +158,7 @@ class CycleGraphNode(BaseNode):
|
||||
|
||||
Args:
|
||||
state: The current workflow state when entering the cycle node.
|
||||
variable_pool: Variable Pool
|
||||
|
||||
Returns:
|
||||
Any: The runtime result produced by the loop or iteration executor.
|
||||
@@ -137,6 +173,8 @@ class CycleGraphNode(BaseNode):
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
state=state,
|
||||
variable_pool=variable_pool,
|
||||
child_variable_pool=self.child_variable_pool,
|
||||
).run()
|
||||
if self.node_type == NodeType.ITERATION:
|
||||
return await IterationRuntime(
|
||||
@@ -145,5 +183,7 @@ class CycleGraphNode(BaseNode):
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
state=state,
|
||||
variable_pool=variable_pool,
|
||||
child_variable_pool=self.child_variable_pool
|
||||
).run()
|
||||
raise RuntimeError("Unknown cycle node type")
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
|
||||
class EndNodeConfig(BaseNodeConfig):
|
||||
|
||||
@@ -7,6 +7,8 @@ End 节点实现
|
||||
import logging
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -17,12 +19,18 @@ class EndNode(BaseNode):
|
||||
工作流的结束节点,根据配置的模板输出最终结果。
|
||||
支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。
|
||||
"""
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
"""声明此节点的输出类型"""
|
||||
return {
|
||||
"output": VariableType.STRING
|
||||
}
|
||||
|
||||
async def execute(self, state: WorkflowState) -> str:
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> str:
|
||||
"""执行 end 节点业务逻辑
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
variable_pool: 变量池
|
||||
|
||||
Returns:
|
||||
最终输出字符串
|
||||
@@ -34,7 +42,7 @@ class EndNode(BaseNode):
|
||||
|
||||
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||
if output_template:
|
||||
output = self._render_template(output_template, state, strict=False)
|
||||
output = self._render_template(output_template, variable_pool, strict=False)
|
||||
else:
|
||||
output = ""
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ class NodeType(StrEnum):
|
||||
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
|
||||
IF_ELSE = "if-else"
|
||||
CODE = "code"
|
||||
TRANSFORM = "transform"
|
||||
QUESTION_CLASSIFIER = "question-classifier"
|
||||
HTTP_REQUEST = "http-request"
|
||||
TOOL = "tool"
|
||||
|
||||
@@ -10,6 +10,8 @@ from httpx import AsyncClient, Response, Timeout
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
|
||||
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
@@ -34,6 +36,14 @@ class HttpRequestNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: HttpRequestNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
"body": VariableType.STRING,
|
||||
"status_code": VariableType.NUMBER,
|
||||
"headers": VariableType.OBJECT,
|
||||
"output": VariableType.STRING
|
||||
}
|
||||
|
||||
def _build_timeout(self) -> Timeout:
|
||||
"""
|
||||
Build httpx Timeout configuration.
|
||||
@@ -50,7 +60,7 @@ class HttpRequestNode(BaseNode):
|
||||
)
|
||||
return timeout
|
||||
|
||||
def _build_auth(self, state: WorkflowState) -> dict[str, str]:
|
||||
def _build_auth(self, variable_pool: VariablePool) -> dict[str, str]:
|
||||
"""
|
||||
Build authentication-related HTTP headers.
|
||||
|
||||
@@ -58,12 +68,12 @@ class HttpRequestNode(BaseNode):
|
||||
the current workflow runtime state.
|
||||
|
||||
Args:
|
||||
state: Current workflow runtime state.
|
||||
variable_pool: Variable Pool
|
||||
|
||||
Returns:
|
||||
A dictionary of HTTP headers used for authentication.
|
||||
"""
|
||||
api_key = self._render_template(self.typed_config.auth.api_key, state)
|
||||
api_key = self._render_template(self.typed_config.auth.api_key, variable_pool)
|
||||
match self.typed_config.auth.auth_type:
|
||||
case HttpAuthType.NONE:
|
||||
return {}
|
||||
@@ -82,7 +92,7 @@ class HttpRequestNode(BaseNode):
|
||||
case _:
|
||||
raise RuntimeError(f"Auth type not supported: {self.typed_config.auth.auth_type}")
|
||||
|
||||
def _build_header(self, state: WorkflowState) -> dict[str, str]:
|
||||
def _build_header(self, variable_pool: VariablePool) -> dict[str, str]:
|
||||
"""
|
||||
Build HTTP request headers.
|
||||
|
||||
@@ -90,10 +100,10 @@ class HttpRequestNode(BaseNode):
|
||||
"""
|
||||
headers = {}
|
||||
for key, value in self.typed_config.headers.items():
|
||||
headers[self._render_template(key, state)] = self._render_template(value, state)
|
||||
headers[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool)
|
||||
return headers
|
||||
|
||||
def _build_params(self, state: WorkflowState) -> dict[str, str]:
|
||||
def _build_params(self, variable_pool: VariablePool) -> dict[str, str]:
|
||||
"""
|
||||
Build URL query parameters.
|
||||
|
||||
@@ -101,10 +111,10 @@ class HttpRequestNode(BaseNode):
|
||||
"""
|
||||
params = {}
|
||||
for key, value in self.typed_config.params.items():
|
||||
params[self._render_template(key, state)] = self._render_template(value, state)
|
||||
params[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool)
|
||||
return params
|
||||
|
||||
def _build_content(self, state) -> dict[str, Any]:
|
||||
def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""
|
||||
Build HTTP request body arguments for httpx request methods.
|
||||
|
||||
@@ -120,13 +130,13 @@ class HttpRequestNode(BaseNode):
|
||||
return {}
|
||||
case HttpContentType.JSON:
|
||||
content["json"] = json.loads(self._render_template(
|
||||
self.typed_config.body.data, state
|
||||
self.typed_config.body.data, variable_pool
|
||||
))
|
||||
case HttpContentType.FROM_DATA:
|
||||
data = {}
|
||||
for item in self.typed_config.body.data:
|
||||
if item.type == "text":
|
||||
data[self._render_template(item.key, state)] = self._render_template(item.value, state)
|
||||
data[self._render_template(item.key, variable_pool)] = self._render_template(item.value, variable_pool)
|
||||
elif item.type == "file":
|
||||
# TODO: File support (Feature)
|
||||
pass
|
||||
@@ -136,11 +146,11 @@ class HttpRequestNode(BaseNode):
|
||||
pass
|
||||
case HttpContentType.WWW_FORM:
|
||||
content["data"] = json.loads(self._render_template(
|
||||
json.dumps(self.typed_config.body.data), state
|
||||
json.dumps(self.typed_config.body.data), variable_pool
|
||||
))
|
||||
|
||||
case HttpContentType.RAW:
|
||||
content["content"] = self._render_template(self.typed_config.body.data, state)
|
||||
content["content"] = self._render_template(self.typed_config.body.data, variable_pool)
|
||||
case _:
|
||||
raise RuntimeError(f"Content type not supported: {self.typed_config.body.content_type}")
|
||||
return content
|
||||
@@ -165,7 +175,7 @@ class HttpRequestNode(BaseNode):
|
||||
case _:
|
||||
raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}")
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict | str:
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict | str:
|
||||
"""
|
||||
Execute the HTTP request node.
|
||||
|
||||
@@ -176,6 +186,7 @@ class HttpRequestNode(BaseNode):
|
||||
|
||||
Args:
|
||||
state: Current workflow runtime state.
|
||||
variable_pool: Variable Pool
|
||||
|
||||
Returns:
|
||||
- dict: Serialized HttpRequestNodeOutput on success
|
||||
@@ -185,8 +196,8 @@ class HttpRequestNode(BaseNode):
|
||||
async with httpx.AsyncClient(
|
||||
verify=self.typed_config.verify_ssl,
|
||||
timeout=self._build_timeout(),
|
||||
headers=self._build_header(state) | self._build_auth(state),
|
||||
params=self._build_params(state),
|
||||
headers=self._build_header(variable_pool) | self._build_auth(variable_pool),
|
||||
params=self._build_params(variable_pool),
|
||||
follow_redirects=True
|
||||
) as client:
|
||||
retries = self.typed_config.retry.max_attempts
|
||||
@@ -194,8 +205,8 @@ class HttpRequestNode(BaseNode):
|
||||
try:
|
||||
request_func = self._get_client_method(client)
|
||||
resp = await request_func(
|
||||
url=self._render_template(self.typed_config.url, state),
|
||||
**self._build_content(state)
|
||||
url=self._render_template(self.typed_config.url, variable_pool),
|
||||
**self._build_content(variable_pool)
|
||||
)
|
||||
resp.raise_for_status()
|
||||
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
||||
|
||||
@@ -6,6 +6,8 @@ from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
|
||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -15,6 +17,11 @@ class IfElseNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: IfElseNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
"output": VariableType.STRING
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
|
||||
match operator:
|
||||
@@ -45,7 +52,7 @@ class IfElseNode(BaseNode):
|
||||
case _:
|
||||
raise ValueError(f"Invalid condition: {operator}")
|
||||
|
||||
def evaluate_conditional_edge_expressions(self, state) -> list[bool]:
|
||||
def evaluate_conditional_edge_expressions(self, variable_pool: VariablePool) -> list[bool]:
|
||||
"""
|
||||
Build conditional edge expressions for the If-Else node.
|
||||
|
||||
@@ -72,11 +79,11 @@ class IfElseNode(BaseNode):
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
left_string = re.sub(pattern, r"\1", expression.left).strip()
|
||||
try:
|
||||
left_value = self.get_variable(left_string, state)
|
||||
left_value = self.get_variable(left_string, variable_pool)
|
||||
except KeyError:
|
||||
left_value = None
|
||||
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
|
||||
self.get_variable_pool(state),
|
||||
variable_pool,
|
||||
expression.left,
|
||||
expression.right,
|
||||
expression.input_type
|
||||
@@ -95,7 +102,7 @@ class IfElseNode(BaseNode):
|
||||
|
||||
return conditions
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""
|
||||
Execute the conditional branching logic of the node.
|
||||
|
||||
@@ -105,13 +112,13 @@ class IfElseNode(BaseNode):
|
||||
|
||||
Args:
|
||||
state (WorkflowState): The current workflow state, containing variables, messages, node outputs, etc.
|
||||
variable_pool: Variable Pool
|
||||
|
||||
Returns:
|
||||
str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions.
|
||||
"""
|
||||
self.typed_config = IfElseNodeConfig(**self.config)
|
||||
expressions = self.evaluate_conditional_edge_expressions(state)
|
||||
# TODO: 变量类型及文本类型解析
|
||||
expressions = self.evaluate_conditional_edge_expressions(variable_pool)
|
||||
for i in range(len(expressions)):
|
||||
if expressions[i]:
|
||||
logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}")
|
||||
|
||||
@@ -5,6 +5,8 @@ from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig
|
||||
from app.core.workflow.template_renderer import TemplateRenderer
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -14,7 +16,12 @@ class JinjaRenderNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: JinjaRenderNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
"output": VariableType.STRING
|
||||
}
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""
|
||||
Execute the node: render the Jinja2 template with mapped variables.
|
||||
|
||||
@@ -24,6 +31,7 @@ class JinjaRenderNode(BaseNode):
|
||||
Args:
|
||||
state (WorkflowState): Current workflow state containing variables,
|
||||
node outputs, and runtime variables.
|
||||
variable_pool: Variable Pool
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: Node output dictionary containing the rendered result
|
||||
@@ -40,7 +48,7 @@ class JinjaRenderNode(BaseNode):
|
||||
context = {}
|
||||
for variable in self.typed_config.mapping:
|
||||
try:
|
||||
context[variable.name] = self.get_variable(variable.value, state)
|
||||
context[variable.name] = self.get_variable(variable.value, variable_pool)
|
||||
except Exception:
|
||||
logger.info(f"variable not found, var: {variable.value}")
|
||||
continue
|
||||
|
||||
@@ -8,6 +8,8 @@ from app.core.models import RedBearRerank, RedBearModelConfig
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_read
|
||||
from app.models import knowledge_model, knowledgeshare_model, ModelType
|
||||
from app.repositories import knowledge_repository, knowledgeshare_repository
|
||||
@@ -22,6 +24,11 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
"output": VariableType.ARRAY_STRING
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
|
||||
"""
|
||||
@@ -149,7 +156,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
)
|
||||
return reranker
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""
|
||||
Execute the knowledge retrieval workflow node.
|
||||
|
||||
@@ -163,6 +170,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
|
||||
Args:
|
||||
state (WorkflowState): Current workflow execution state.
|
||||
variable_pool: Variable Pool
|
||||
|
||||
Returns:
|
||||
Any: List of retrieved knowledge chunks (dict format).
|
||||
@@ -171,7 +179,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
RuntimeError: If no valid knowledge base is found or access is denied.
|
||||
"""
|
||||
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
|
||||
query = self._render_template(self.typed_config.query, state)
|
||||
query = self._render_template(self.typed_config.query, variable_pool)
|
||||
with get_db_read() as db:
|
||||
knowledge_bases = self.typed_config.knowledge_bases
|
||||
existing_ids = self._get_existing_kb_ids(db, [kb.kb_id for kb in knowledge_bases])
|
||||
|
||||
@@ -4,7 +4,8 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
|
||||
class MessageConfig(BaseModel):
|
||||
|
||||
@@ -15,6 +15,8 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_context
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
@@ -66,19 +68,27 @@ class LLMNode(BaseNode):
|
||||
- ai/assistant: AI 消息(AIMessage)
|
||||
"""
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {"output": VariableType.STRING}
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: LLMNodeConfig | None = None
|
||||
|
||||
def _render_context(self, message, state):
|
||||
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
|
||||
def _render_context(self, message: str, variable_pool: VariablePool):
|
||||
context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
|
||||
return re.sub(r"{{context}}", context, message)
|
||||
|
||||
def _prepare_llm(self, state: WorkflowState, stream: bool = False) -> tuple[RedBearLLM, list | str]:
|
||||
def _prepare_llm(
|
||||
self,
|
||||
state: WorkflowState,
|
||||
variable_pool: VariablePool,
|
||||
stream: bool = False
|
||||
) -> tuple[RedBearLLM, list | str]:
|
||||
"""准备 LLM 实例(公共逻辑)
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
variable_pool: 变量池
|
||||
|
||||
Returns:
|
||||
(llm, messages_or_prompt): LLM 实例和消息列表或 prompt 字符串
|
||||
@@ -94,8 +104,8 @@ class LLMNode(BaseNode):
|
||||
for msg_config in messages_config:
|
||||
role = msg_config.role.lower()
|
||||
content_template = msg_config.content
|
||||
content_template = self._render_context(content_template, state)
|
||||
content = self._render_template(content_template, state)
|
||||
content_template = self._render_context(content_template, variable_pool)
|
||||
content = self._render_template(content_template, variable_pool)
|
||||
|
||||
# 根据角色创建对应的消息对象
|
||||
if role == "system":
|
||||
@@ -115,7 +125,7 @@ class LLMNode(BaseNode):
|
||||
else:
|
||||
# 使用简单的 prompt 格式(向后兼容)
|
||||
prompt_template = self.config.get("prompt", "")
|
||||
prompt_or_messages = self._render_template(prompt_template, state)
|
||||
prompt_or_messages = self._render_template(prompt_template, variable_pool)
|
||||
|
||||
# 2. 获取模型配置
|
||||
model_id = self.config.get("model_id")
|
||||
@@ -159,17 +169,18 @@ class LLMNode(BaseNode):
|
||||
|
||||
return llm, prompt_or_messages
|
||||
|
||||
async def execute(self, state: WorkflowState) -> AIMessage:
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage:
|
||||
"""非流式执行 LLM 调用
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
variable_pool: 变量池
|
||||
|
||||
Returns:
|
||||
LLM 响应消息
|
||||
"""
|
||||
# self.typed_config = LLMNodeConfig(**self.config)
|
||||
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||
llm, prompt_or_messages = self._prepare_llm(state, variable_pool, False)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||
|
||||
@@ -186,9 +197,9 @@ class LLMNode(BaseNode):
|
||||
# 返回 AIMessage(包含响应元数据)
|
||||
return response if isinstance(response, AIMessage) else AIMessage(content=content)
|
||||
|
||||
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""提取输入数据(用于记录)"""
|
||||
_, prompt_or_messages = self._prepare_llm(state)
|
||||
_, prompt_or_messages = self._prepare_llm(state, variable_pool)
|
||||
|
||||
return {
|
||||
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
|
||||
@@ -221,18 +232,19 @@ class LLMNode(BaseNode):
|
||||
}
|
||||
return None
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
|
||||
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
||||
"""流式执行 LLM 调用
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
variable_pool: 变量池
|
||||
|
||||
Yields:
|
||||
文本片段(chunk)或完成标记
|
||||
"""
|
||||
self.typed_config = LLMNodeConfig(**self.config)
|
||||
|
||||
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||
llm, prompt_or_messages = self._prepare_llm(state, variable_pool, True)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
||||
|
||||
@@ -3,6 +3,8 @@ from typing import Any
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_read
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.tasks import write_message_task
|
||||
@@ -13,17 +15,23 @@ class MemoryReadNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: MemoryReadNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
"answer": VariableType.STRING,
|
||||
"intermediate_outputs": VariableType.ARRAY_OBJECT
|
||||
}
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
self.typed_config = MemoryReadNodeConfig(**self.config)
|
||||
with get_db_read() as db:
|
||||
end_user_id = self.get_variable("sys.user_id", state)
|
||||
end_user_id = self.get_variable("sys.user_id", variable_pool)
|
||||
|
||||
if not end_user_id:
|
||||
raise RuntimeError("End user id is required")
|
||||
|
||||
return await MemoryAgentService().read_memory(
|
||||
end_user_id=end_user_id,
|
||||
message=self._render_template(self.typed_config.message, state),
|
||||
message=self._render_template(self.typed_config.message, variable_pool),
|
||||
config_id=self.typed_config.config_id,
|
||||
search_switch=self.typed_config.search_switch,
|
||||
history=[],
|
||||
@@ -38,16 +46,19 @@ class MemoryWriteNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: MemoryWriteNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {"output": VariableType.STRING}
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
||||
end_user_id = self.get_variable("sys.user_id", state)
|
||||
end_user_id = self.get_variable("sys.user_id", variable_pool)
|
||||
|
||||
if not end_user_id:
|
||||
raise RuntimeError("End user id is required")
|
||||
|
||||
write_message_task.delay(
|
||||
end_user_id,
|
||||
self._render_template(self.typed_config.message, state),
|
||||
self._render_template(self.typed_config.message, variable_pool),
|
||||
str(self.typed_config.config_id),
|
||||
"neo4j",
|
||||
""
|
||||
|
||||
@@ -22,7 +22,6 @@ from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.memory import MemoryReadNode, MemoryWriteNode
|
||||
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from app.core.workflow.nodes.breaker import BreakNode
|
||||
@@ -37,7 +36,6 @@ WorkflowNode = Union[
|
||||
LLMNode,
|
||||
IfElseNode,
|
||||
AgentNode,
|
||||
TransformNode,
|
||||
AssignerNode,
|
||||
HttpRequestNode,
|
||||
KnowledgeRetrievalNode,
|
||||
@@ -67,7 +65,6 @@ class NodeFactory:
|
||||
NodeType.END: EndNode,
|
||||
NodeType.LLM: LLMNode,
|
||||
NodeType.AGENT: AgentNode,
|
||||
NodeType.TRANSFORM: TransformNode,
|
||||
NodeType.IF_ELSE: IfElseNode,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
||||
NodeType.ASSIGNER: AssignerNode,
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import json
|
||||
import re
|
||||
from abc import ABC
|
||||
from typing import Union, Type, NoReturn
|
||||
from typing import Union, Type, NoReturn, Any
|
||||
|
||||
from app.core.workflow.nodes.base_config import VariableType
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.nodes.enums import ValueInputType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
@@ -69,7 +69,7 @@ class TypeTransformer:
|
||||
|
||||
|
||||
class OperatorBase(ABC):
|
||||
def __init__(self, pool: VariablePool, left_selector, right):
|
||||
def __init__(self, pool: VariablePool, left_selector: str, right: Any):
|
||||
self.pool = pool
|
||||
self.left_selector = left_selector
|
||||
self.right = right
|
||||
@@ -77,7 +77,7 @@ class OperatorBase(ABC):
|
||||
self.type_limit: type[str, int, dict, list] = None
|
||||
|
||||
def check(self, no_right=False):
|
||||
left = self.pool.get(self.left_selector)
|
||||
left = self.pool.get_value(self.left_selector)
|
||||
if not isinstance(left, self.type_limit):
|
||||
raise TypeError(f"The variable to be operated on must be of {self.type_limit} type")
|
||||
|
||||
@@ -92,13 +92,13 @@ class StringOperator(OperatorBase):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = str
|
||||
|
||||
def assign(self) -> None:
|
||||
async def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
await self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
async def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, '')
|
||||
await self.pool.set(self.left_selector, '')
|
||||
|
||||
|
||||
class NumberOperator(OperatorBase):
|
||||
@@ -106,33 +106,33 @@ class NumberOperator(OperatorBase):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = (float, int)
|
||||
|
||||
def assign(self) -> None:
|
||||
async def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
await self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
async def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, 0)
|
||||
await self.pool.set(self.left_selector, 0)
|
||||
|
||||
def add(self) -> None:
|
||||
async def add(self) -> None:
|
||||
self.check()
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin + self.right)
|
||||
origin = self.pool.get_value(self.left_selector)
|
||||
await self.pool.set(self.left_selector, origin + self.right)
|
||||
|
||||
def subtract(self) -> None:
|
||||
async def subtract(self) -> None:
|
||||
self.check()
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin - self.right)
|
||||
origin = self.pool.get_value(self.left_selector)
|
||||
await self.pool.set(self.left_selector, origin - self.right)
|
||||
|
||||
def multiply(self) -> None:
|
||||
async def multiply(self) -> None:
|
||||
self.check()
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin * self.right)
|
||||
origin = self.pool.get_value(self.left_selector)
|
||||
await self.pool.set(self.left_selector, origin * self.right)
|
||||
|
||||
def divide(self) -> None:
|
||||
async def divide(self) -> None:
|
||||
self.check()
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin / self.right)
|
||||
origin = self.pool.get_value(self.left_selector)
|
||||
await self.pool.set(self.left_selector, origin / self.right)
|
||||
|
||||
|
||||
class BooleanOperator(OperatorBase):
|
||||
@@ -140,13 +140,13 @@ class BooleanOperator(OperatorBase):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = bool
|
||||
|
||||
def assign(self) -> None:
|
||||
async def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
await self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
async def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, False)
|
||||
await self.pool.set(self.left_selector, False)
|
||||
|
||||
|
||||
class ArrayOperator(OperatorBase):
|
||||
@@ -154,38 +154,37 @@ class ArrayOperator(OperatorBase):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = list
|
||||
|
||||
def assign(self) -> None:
|
||||
async def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
await self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
async def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, list())
|
||||
await self.pool.set(self.left_selector, list())
|
||||
|
||||
def append(self) -> None:
|
||||
async def append(self) -> None:
|
||||
self.check(no_right=True)
|
||||
# TODO:require type limit in list
|
||||
origin = self.pool.get(self.left_selector)
|
||||
origin = self.pool.get_value(self.left_selector)
|
||||
origin.append(self.right)
|
||||
self.pool.set(self.left_selector, origin)
|
||||
await self.pool.set(self.left_selector, origin)
|
||||
|
||||
def extend(self) -> None:
|
||||
async def extend(self) -> None:
|
||||
self.check(no_right=True)
|
||||
origin = self.pool.get(self.left_selector)
|
||||
origin = self.pool.get_value(self.left_selector)
|
||||
origin.extend(self.right)
|
||||
self.pool.set(self.left_selector, origin)
|
||||
await self.pool.set(self.left_selector, origin)
|
||||
|
||||
def remove_last(self) -> None:
|
||||
async def remove_last(self) -> None:
|
||||
self.check(no_right=True)
|
||||
origin = self.pool.get(self.left_selector)
|
||||
origin = self.pool.get_value(self.left_selector)
|
||||
origin.pop()
|
||||
self.pool.set(self.left_selector, origin)
|
||||
await self.pool.set(self.left_selector, origin)
|
||||
|
||||
def remove_first(self) -> None:
|
||||
async def remove_first(self) -> None:
|
||||
self.check(no_right=True)
|
||||
origin = self.pool.get(self.left_selector)
|
||||
origin = self.pool.get_value(self.left_selector)
|
||||
origin.pop(0)
|
||||
self.pool.set(self.left_selector, origin)
|
||||
await self.pool.set(self.left_selector, origin)
|
||||
|
||||
|
||||
class ObjectOperator(OperatorBase):
|
||||
@@ -193,13 +192,13 @@ class ObjectOperator(OperatorBase):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = dict
|
||||
|
||||
def assign(self) -> None:
|
||||
async def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
await self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
async def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, dict())
|
||||
await self.pool.set(self.left_selector, dict())
|
||||
|
||||
|
||||
class AssignmentOperatorResolver:
|
||||
@@ -245,7 +244,7 @@ class ConditionBase(ABC):
|
||||
self.right_selector = right_selector
|
||||
self.input_type = input_type
|
||||
|
||||
self.left_value = self.pool.get(self.left_selector)
|
||||
self.left_value = self.pool.get_value(self.left_selector)
|
||||
self.right_value = self.resolve_right_literal_value()
|
||||
|
||||
self.type_limit = getattr(self, "type_limit", None)
|
||||
@@ -254,7 +253,7 @@ class ConditionBase(ABC):
|
||||
if self.input_type == ValueInputType.VARIABLE:
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
right_expression = re.sub(pattern, r"\1", self.right_selector).strip()
|
||||
return self.pool.get(right_expression)
|
||||
return self.pool.get_value(right_expression)
|
||||
elif self.input_type == ValueInputType.CONSTANT:
|
||||
return self.right_selector
|
||||
raise RuntimeError("Unsupported variable type")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import uuid
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import Field, BaseModel
|
||||
from enum import StrEnum
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@ from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_read
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
@@ -24,6 +26,12 @@ class ParameterExtractorNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: ParameterExtractorNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
outputs = {}
|
||||
for param in self.typed_config.params:
|
||||
outputs[param.name] = param.type
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def _get_prompt():
|
||||
"""
|
||||
@@ -120,7 +128,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
field_type[param.name] = f'{param.type}, required:{str(param.required)}'
|
||||
return field_type
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""
|
||||
Main execution function for this node.
|
||||
|
||||
@@ -138,6 +146,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
|
||||
Args:
|
||||
state (WorkflowState): Current state of the workflow, used for template rendering.
|
||||
variable_pool (VariablePool): Used for accessing and setting variables during execution.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: Dictionary containing extracted parameters under the "output" key.
|
||||
@@ -153,7 +162,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
rendered_user_prompt = user_prompt_teplate.render(
|
||||
field_descriptions=str(self._get_field_desc()),
|
||||
field_type=str(self._get_field_type()),
|
||||
text_input=self._render_template(self.typed_config.text, state)
|
||||
text_input=self._render_template(self.typed_config.text, variable_pool)
|
||||
)
|
||||
|
||||
messages = [
|
||||
@@ -162,7 +171,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
]
|
||||
if self.typed_config.prompt:
|
||||
messages.extend([
|
||||
("user", self._render_template(self.typed_config.prompt, state)),
|
||||
("user", self._render_template(self.typed_config.prompt, variable_pool)),
|
||||
("user", rendered_user_prompt),
|
||||
])
|
||||
else:
|
||||
|
||||
@@ -6,6 +6,8 @@ from app.core.workflow.nodes.question_classifier.config import QuestionClassifie
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_read
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
@@ -24,6 +26,12 @@ class QuestionClassifierNode(BaseNode):
|
||||
self.typed_config: QuestionClassifierNodeConfig | None = None
|
||||
self.category_to_case_map = {}
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
"class_name": VariableType.STRING,
|
||||
"output": VariableType.STRING
|
||||
}
|
||||
|
||||
def _get_llm_instance(self) -> RedBearLLM:
|
||||
"""获取LLM实例"""
|
||||
with get_db_read() as db:
|
||||
@@ -65,7 +73,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
category_map[category_name] = case_tag
|
||||
return category_map
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict:
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict:
|
||||
"""执行问题分类"""
|
||||
self.typed_config = QuestionClassifierNodeConfig(**self.config)
|
||||
self.category_to_case_map = self._build_category_case_map()
|
||||
@@ -102,7 +110,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
categories=", ".join(category_names),
|
||||
supplement_prompt=supplement_prompt
|
||||
),
|
||||
state
|
||||
variable_pool
|
||||
)
|
||||
|
||||
messages = [
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
|
||||
class StartNodeConfig(BaseNodeConfig):
|
||||
|
||||
@@ -7,9 +7,10 @@ Start 节点实现
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_config import VariableType
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -36,14 +37,25 @@ class StartNode(BaseNode):
|
||||
|
||||
# 解析并验证配置
|
||||
self.typed_config: StartNodeConfig | None = None
|
||||
self.output_var_types = {}
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return self.output_var_types | {
|
||||
"message": VariableType.STRING,
|
||||
"execution_id": VariableType.STRING,
|
||||
"conversation_id": VariableType.STRING,
|
||||
"workspace_id": VariableType.STRING,
|
||||
"user_id": VariableType.STRING,
|
||||
}
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""执行 start 节点业务逻辑
|
||||
|
||||
Start 节点输出系统变量、会话变量和自定义变量。
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
variable_pool: 变量池
|
||||
|
||||
Returns:
|
||||
包含系统参数、会话变量和自定义变量的字典
|
||||
@@ -51,19 +63,16 @@ class StartNode(BaseNode):
|
||||
self.typed_config = StartNodeConfig(**self.config)
|
||||
logger.info(f"节点 {self.node_id} (Start) 开始执行")
|
||||
|
||||
# 创建变量池实例(在方法内复用)
|
||||
pool = self.get_variable_pool(state)
|
||||
|
||||
# 处理自定义变量(传入 pool 避免重复创建)
|
||||
custom_vars = self._process_custom_variables(pool)
|
||||
custom_vars = self._process_custom_variables(variable_pool)
|
||||
|
||||
# 返回业务数据(包含自定义变量)
|
||||
result = {
|
||||
"message": pool.get("sys.message"),
|
||||
"execution_id": pool.get("sys.execution_id"),
|
||||
"conversation_id": pool.get("sys.conversation_id"),
|
||||
"workspace_id": pool.get("sys.workspace_id"),
|
||||
"user_id": pool.get("sys.user_id"),
|
||||
"message": variable_pool.get_value("sys.message"),
|
||||
"execution_id": variable_pool.get_value("sys.execution_id"),
|
||||
"conversation_id": variable_pool.get_value("sys.conversation_id"),
|
||||
"workspace_id": variable_pool.get_value("sys.workspace_id"),
|
||||
"user_id": variable_pool.get_value("sys.user_id"),
|
||||
**custom_vars # 自定义变量作为节点输出的一部分
|
||||
}
|
||||
|
||||
@@ -74,7 +83,7 @@ class StartNode(BaseNode):
|
||||
|
||||
return result
|
||||
|
||||
def _process_custom_variables(self, pool) -> dict[str, Any]:
|
||||
def _process_custom_variables(self, pool: VariablePool) -> dict[str, Any]:
|
||||
"""处理自定义变量
|
||||
|
||||
从输入数据中提取自定义变量,应用默认值和验证。
|
||||
@@ -89,13 +98,14 @@ class StartNode(BaseNode):
|
||||
ValueError: 缺少必需变量
|
||||
"""
|
||||
# 获取输入数据中的自定义变量
|
||||
input_variables = pool.get("sys.input_variables", default={})
|
||||
input_variables = pool.get_value("sys.input_variables", default={}, strict=False)
|
||||
|
||||
processed = {}
|
||||
|
||||
# 遍历配置的变量定义
|
||||
for var_def in self.typed_config.variables:
|
||||
var_name = var_def.name
|
||||
var_type = var_def.type
|
||||
|
||||
# 检查变量是否存在
|
||||
if var_name in input_variables:
|
||||
@@ -116,21 +126,12 @@ class StartNode(BaseNode):
|
||||
f"变量 '{var_name}' 使用默认值: {var_def.default}"
|
||||
)
|
||||
else:
|
||||
match var_def.type:
|
||||
case VariableType.STRING:
|
||||
processed[var_name] = ""
|
||||
case VariableType.NUMBER:
|
||||
processed[var_name] = 0
|
||||
case VariableType.OBJECT:
|
||||
processed[var_name] = {}
|
||||
case VariableType.BOOLEAN:
|
||||
processed[var_name] = False
|
||||
case VariableType.ARRAY_NUMBER | VariableType.ARRAY_OBJECT | VariableType.ARRAY_BOOLEAN | VariableType.ARRAY_STRING:
|
||||
processed[var_name] = []
|
||||
processed[var_name] = DEFAULT_VALUE(var_type)
|
||||
self.output_var_types[var_name] = var_type
|
||||
|
||||
return processed
|
||||
|
||||
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""提取输入数据(用于记录)
|
||||
|
||||
Args:
|
||||
@@ -139,11 +140,9 @@ class StartNode(BaseNode):
|
||||
Returns:
|
||||
输入数据字典
|
||||
"""
|
||||
pool = self.get_variable_pool(state)
|
||||
|
||||
return {
|
||||
"execution_id": pool.get("sys.execution_id"),
|
||||
"conversation_id": pool.get("sys.conversation_id"),
|
||||
"message": pool.get("sys.message"),
|
||||
"conversation_vars": pool.get_all_conversation_vars()
|
||||
"execution_id": variable_pool.get_value("sys.execution_id"),
|
||||
"conversation_id": variable_pool.get_value("sys.conversation_id"),
|
||||
"message": variable_pool.get_value("sys.message"),
|
||||
"conversation_vars": variable_pool.get_all_conversation_vars()
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.services.tool_service import ToolService
|
||||
from app.db import get_db_read
|
||||
|
||||
@@ -21,13 +23,20 @@ class ToolNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: ToolNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
"data": VariableType.STRING,
|
||||
"error_code": VariableType.STRING,
|
||||
"execution_time": VariableType.NUMBER
|
||||
}
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""执行工具"""
|
||||
self.typed_config = ToolNodeConfig(**self.config)
|
||||
# 获取租户ID和用户ID
|
||||
tenant_id = self.get_variable("sys.tenant_id", state)
|
||||
user_id = self.get_variable("sys.user_id", state)
|
||||
workspace_id = self.get_variable("sys.workspace_id", state)
|
||||
tenant_id = self.get_variable("sys.tenant_id", variable_pool, strict=False)
|
||||
user_id = self.get_variable("sys.user_id", variable_pool)
|
||||
workspace_id = self.get_variable("sys.workspace_id", variable_pool)
|
||||
|
||||
# 如果没有租户ID,尝试从工作流ID获取
|
||||
if not tenant_id:
|
||||
@@ -48,7 +57,7 @@ class ToolNode(BaseNode):
|
||||
for param_name, param_template in self.typed_config.tool_parameters.items():
|
||||
if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template):
|
||||
try:
|
||||
rendered_value = self._render_template(param_template, state)
|
||||
rendered_value = self._render_template(param_template, variable_pool)
|
||||
except Exception as e:
|
||||
raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e
|
||||
else:
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Transform 节点"""
|
||||
|
||||
from app.core.workflow.nodes.transform.node import TransformNode
|
||||
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
||||
|
||||
__all__ = ["TransformNode", "TransformNodeConfig"]
|
||||
@@ -1,80 +0,0 @@
|
||||
"""Transform 节点配置"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
|
||||
|
||||
class TransformNodeConfig(BaseNodeConfig):
|
||||
"""Transform 节点配置
|
||||
|
||||
用于数据转换和处理。
|
||||
"""
|
||||
|
||||
transform_type: Literal["template", "code", "json"] = Field(
|
||||
default="template",
|
||||
description="转换类型:template(模板), code(代码), json(JSON处理)"
|
||||
)
|
||||
|
||||
# 模板模式
|
||||
template: str | None = Field(
|
||||
default=None,
|
||||
description="转换模板,支持变量引用"
|
||||
)
|
||||
|
||||
# 代码模式
|
||||
code: str | None = Field(
|
||||
default=None,
|
||||
description="Python 代码,用于数据转换"
|
||||
)
|
||||
|
||||
# JSON 模式
|
||||
json_path: str | None = Field(
|
||||
default=None,
|
||||
description="JSON 路径表达式"
|
||||
)
|
||||
|
||||
# 输入变量
|
||||
inputs: dict[str, str] | None = Field(
|
||||
default=None,
|
||||
description="输入变量映射,key 为变量名,value 为变量选择器"
|
||||
)
|
||||
|
||||
# 输出变量
|
||||
output_key: str = Field(
|
||||
default="result",
|
||||
description="输出变量的键名"
|
||||
)
|
||||
|
||||
# 输出变量定义
|
||||
output_variables: list[VariableDefinition] = Field(
|
||||
default_factory=lambda: [
|
||||
VariableDefinition(
|
||||
name="result",
|
||||
type=VariableType.STRING,
|
||||
description="转换后的结果"
|
||||
)
|
||||
],
|
||||
description="输出变量定义(根据 output_key 动态生成)"
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"examples": [
|
||||
{
|
||||
"transform_type": "template",
|
||||
"template": "用户问题:{{ sys.message }}\n回答:{{ llm_qa.output }}",
|
||||
"output_key": "formatted_result"
|
||||
},
|
||||
{
|
||||
"transform_type": "code",
|
||||
"code": "result = input_text.upper()",
|
||||
"inputs": {
|
||||
"input_text": "{{ sys.message }}"
|
||||
},
|
||||
"output_key": "uppercase_text"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
"""
|
||||
Transform 节点实现
|
||||
|
||||
数据转换节点,用于处理和转换数据。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TransformNode(BaseNode):
|
||||
"""数据转换节点
|
||||
|
||||
配置示例:
|
||||
{
|
||||
"type": "transform",
|
||||
"config": {
|
||||
"mapping": {
|
||||
"output_field": "{{node.previous.output}}",
|
||||
"processed": "{{var.input | upper}}"
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""执行数据转换
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
状态更新字典
|
||||
"""
|
||||
logger.info(f"节点 {self.node_id} 开始执行数据转换")
|
||||
|
||||
# 获取映射配置
|
||||
mapping = self.config.get("mapping", {})
|
||||
|
||||
# 执行数据转换
|
||||
transformed_data = {}
|
||||
for target_key, source_template in mapping.items():
|
||||
# 渲染模板获取值
|
||||
value = self._render_template(str(source_template), state)
|
||||
transformed_data[target_key] = value
|
||||
|
||||
logger.info(f"节点 {self.node_id} 数据转换完成,输出字段: {list(transformed_data.keys())}")
|
||||
|
||||
return {
|
||||
"node_outputs": {
|
||||
self.node_id: {
|
||||
"output": transformed_data,
|
||||
"status": "completed"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
|
||||
class VariableAggregatorNodeConfig(BaseNodeConfig):
|
||||
@@ -14,6 +15,11 @@ class VariableAggregatorNodeConfig(BaseNodeConfig):
|
||||
description="需要被聚合的变量"
|
||||
)
|
||||
|
||||
group_type: dict[str, VariableType] = Field(
|
||||
...,
|
||||
description="每个分组的变量类型"
|
||||
)
|
||||
|
||||
@field_validator("group_variables")
|
||||
@classmethod
|
||||
def group_variables_validator(cls, v, info):
|
||||
|
||||
@@ -5,6 +5,8 @@ from typing import Any
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -14,6 +16,13 @@ class VariableAggregatorNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: VariableAggregatorNodeConfig | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
config = VariableAggregatorNodeConfig(**self.config)
|
||||
output = {}
|
||||
for var_type in config.group_type:
|
||||
output[var_type] = config.group_type[var_type]
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def _get_express(variable_string: str) -> Any:
|
||||
"""
|
||||
@@ -29,7 +38,7 @@ class VariableAggregatorNode(BaseNode):
|
||||
expression = re.sub(pattern, r"\1", variable_string).strip()
|
||||
return expression
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""
|
||||
Execute the variable aggregation logic.
|
||||
|
||||
@@ -45,7 +54,7 @@ class VariableAggregatorNode(BaseNode):
|
||||
for variable in self.typed_config.group_variables:
|
||||
var_express = self._get_express(variable)
|
||||
try:
|
||||
value = self.get_variable(var_express, state)
|
||||
value = self.get_variable(var_express, variable_pool)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get variable '{var_express}': {e}")
|
||||
continue
|
||||
@@ -55,7 +64,7 @@ class VariableAggregatorNode(BaseNode):
|
||||
return value
|
||||
|
||||
logger.info("No variable found in non-group mode; returning empty string.")
|
||||
return ""
|
||||
return DEFAULT_VALUE(self.typed_config.group_type["output"])
|
||||
|
||||
# --------------------------
|
||||
# Group mode
|
||||
@@ -65,7 +74,7 @@ class VariableAggregatorNode(BaseNode):
|
||||
for variable in variables:
|
||||
var_express = self._get_express(variable)
|
||||
try:
|
||||
value = self.get_variable(var_express, state)
|
||||
value = self.get_variable(var_express, variable_pool)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get variable '{var_express}' in group '{group_name}': {e}")
|
||||
continue
|
||||
@@ -74,7 +83,7 @@ class VariableAggregatorNode(BaseNode):
|
||||
result[group_name] = value
|
||||
break
|
||||
else:
|
||||
result[group_name] = ""
|
||||
result[group_name] = DEFAULT_VALUE(self.typed_config.group_type[group_name])
|
||||
logger.info(f"No variable found for group '{group_name}'; set empty string.")
|
||||
logger.info(f"Node: {self.node_id} variable aggregation result: {result}")
|
||||
return result
|
||||
|
||||
@@ -43,7 +43,7 @@ class TemplateRenderer:
|
||||
def render(
|
||||
self,
|
||||
template: str,
|
||||
variables: dict[str, Any],
|
||||
conv_vars: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> str:
|
||||
@@ -51,7 +51,7 @@ class TemplateRenderer:
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
variables: 用户定义的变量
|
||||
conv_vars: 会话变量
|
||||
node_outputs: 节点输出结果
|
||||
system_vars: 系统变量
|
||||
|
||||
@@ -80,20 +80,11 @@ class TemplateRenderer:
|
||||
'分析结果: 正面情绪'
|
||||
"""
|
||||
# 构建命名空间上下文
|
||||
# variables 的结构:{"sys": {...}, "conv": {...}}
|
||||
sys_vars = variables.get("sys", {}) if isinstance(variables, dict) else {}
|
||||
conv_vars = variables.get("conv", {}) if isinstance(variables, dict) else {}
|
||||
if self.strict:
|
||||
context = defaultdict(dict)
|
||||
context["conv"] = conv_vars
|
||||
context["node"] = node_outputs
|
||||
context["sys"] = {**(system_vars or {}), **sys_vars}
|
||||
else:
|
||||
context = {
|
||||
"conv": conv_vars, # 会话变量:{{conv.user_name}}
|
||||
"node": node_outputs, # 节点输出:{{node.node_1.output}}
|
||||
"sys": {**(system_vars or {}), **sys_vars}, # 系统变量:{{sys.execution_id}}(合并两个来源)
|
||||
}
|
||||
context = {
|
||||
"conv": conv_vars, # 会话变量:{{conv.user_name}}
|
||||
"node": node_outputs, # 节点输出:{{node.node_1.output}}
|
||||
"sys": system_vars, # 系统变量:{{sys.execution_id}}
|
||||
}
|
||||
|
||||
# 支持直接通过节点ID访问节点输出:{{llm_qa.output}}
|
||||
# 将所有节点输出添加到顶层上下文
|
||||
@@ -157,9 +148,9 @@ _default_renderer = TemplateRenderer(strict=True)
|
||||
|
||||
def render_template(
|
||||
template: str,
|
||||
variables: dict[str, Any],
|
||||
conv_vars: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None,
|
||||
system_vars: dict[str, Any],
|
||||
strict: bool = True
|
||||
) -> str:
|
||||
"""渲染模板(便捷函数)
|
||||
@@ -167,7 +158,7 @@ def render_template(
|
||||
Args:
|
||||
strict: 严格模式
|
||||
template: 模板字符串
|
||||
variables: 用户变量
|
||||
conv_vars: 会话变量
|
||||
node_outputs: 节点输出
|
||||
system_vars: 系统变量
|
||||
|
||||
@@ -184,7 +175,7 @@ def render_template(
|
||||
'请分析: 这是一段文本'
|
||||
"""
|
||||
renderer = TemplateRenderer(strict=strict)
|
||||
return renderer.render(template, variables, node_outputs, system_vars)
|
||||
return renderer.render(template, conv_vars, node_outputs, system_vars)
|
||||
|
||||
|
||||
def validate_template(template: str) -> list[str]:
|
||||
|
||||
@@ -5,10 +5,13 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
from typing import Any, Union, TYPE_CHECKING
|
||||
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.workflow_schema import WorkflowConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -64,7 +67,7 @@ class WorkflowValidator:
|
||||
return cycle_nodes, cycle_edges
|
||||
|
||||
@classmethod
|
||||
def get_subgraph(cls, workflow_config: Union[dict[str, Any], Any]) -> list:
|
||||
def get_subgraph(cls, workflow_config: Union[dict[str, Any], "WorkflowConfig"]) -> list:
|
||||
if not isinstance(workflow_config, dict):
|
||||
workflow_config = {
|
||||
"nodes": workflow_config.nodes,
|
||||
@@ -331,7 +334,7 @@ class WorkflowValidator:
|
||||
|
||||
|
||||
def validate_workflow_config(
|
||||
workflow_config: dict[str, Any],
|
||||
workflow_config: Union[dict[str, Any], 'WorkflowConfig'],
|
||||
for_publish: bool = False
|
||||
) -> tuple[bool, list[str]]:
|
||||
"""验证工作流配置(便捷函数)
|
||||
|
||||
0
api/app/core/workflow/variable/__init__.py
Normal file
0
api/app/core/workflow/variable/__init__.py
Normal file
162
api/app/core/workflow/variable/base_variable.py
Normal file
162
api/app/core/workflow/variable/base_variable.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from enum import StrEnum
|
||||
from abc import abstractmethod, ABC
|
||||
from typing import Any
|
||||
|
||||
|
||||
class VariableType(StrEnum):
|
||||
"""Enumeration of supported variable types in the workflow."""
|
||||
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
BOOLEAN = "boolean"
|
||||
OBJECT = "object"
|
||||
FILE = "file"
|
||||
|
||||
ARRAY_STRING = "array[string]"
|
||||
ARRAY_NUMBER = "array[number]"
|
||||
ARRAY_BOOLEAN = "array[boolean]"
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
ARRAY_FILE = "array[file]"
|
||||
|
||||
NESTED_ARRAY = "array_nest"
|
||||
|
||||
@classmethod
|
||||
def type_map(cls, var: Any) -> "VariableType":
|
||||
"""Maps a Python value to a corresponding VariableType.
|
||||
|
||||
Args:
|
||||
var: The Python value to map.
|
||||
|
||||
Returns:
|
||||
The VariableType corresponding to the input value.
|
||||
|
||||
Raises:
|
||||
TypeError: If the type of the input value is not supported.
|
||||
"""
|
||||
var_type = type(var)
|
||||
if isinstance(var_type, str):
|
||||
return cls.STRING
|
||||
elif isinstance(var_type, (int, float)):
|
||||
return cls.NUMBER
|
||||
elif isinstance(var_type, bool):
|
||||
return cls.BOOLEAN
|
||||
elif isinstance(var_type, FileObj):
|
||||
return cls.FILE
|
||||
elif isinstance(var_type, dict):
|
||||
return cls.OBJECT
|
||||
elif isinstance(var_type, list):
|
||||
if len(var) == 0:
|
||||
return cls.ARRAY_STRING
|
||||
else:
|
||||
child_type = type(var[0])
|
||||
if child_type == str:
|
||||
return cls.ARRAY_STRING
|
||||
elif child_type == int or child_type == float:
|
||||
return cls.ARRAY_NUMBER
|
||||
elif child_type == bool:
|
||||
return cls.ARRAY_BOOLEAN
|
||||
elif child_type == dict:
|
||||
return cls.ARRAY_OBJECT
|
||||
elif child_type == list:
|
||||
return cls.NESTED_ARRAY
|
||||
else:
|
||||
raise TypeError(f"Unsupported array child type - {child_type}")
|
||||
raise TypeError(f"Unsupported type - {var_type}")
|
||||
|
||||
|
||||
def DEFAULT_VALUE(var_type: "VariableType") -> Any:
    """Return the default value for a given VariableType.

    Args:
        var_type: The variable type for which to get the default value.
            Either a VariableType member or its raw string value.

    Returns:
        The default Python value corresponding to the VariableType.

    Raises:
        TypeError: If the VariableType is invalid (including NESTED_ARRAY,
            which has no generic default).
    """
    # Keyed by the enum's string value, so both VariableType members (StrEnum
    # str() yields the value) and raw strings resolve. The literal is rebuilt
    # per call, so the mutable defaults ({} and []) are never shared.
    defaults = {
        "string": "",
        "number": 0,
        "boolean": False,
        "object": {},
        "file": None,
        "array[string]": [],
        "array[number]": [],
        "array[boolean]": [],
        "array[object]": [],
        "array[file]": [],
    }
    try:
        return defaults[str(var_type)]
    except KeyError:
        # Bug fix: the original message interpolated the builtin `type`
        # instead of the offending argument.
        raise TypeError(f"Invalid type - {var_type}") from None
|
||||
|
||||
|
||||
class FileObj:
    # Placeholder marker class for file-typed runtime values; referenced by
    # VariableType.type_map. No behavior implemented yet.
    pass
|
||||
|
||||
|
||||
class BaseVariable(ABC):
    """Abstract base class for all workflow variables.

    Subclasses must implement validation (valid_value) and serialization
    (to_literal).
    """

    # Logical type tag; overridden by concrete subclasses (e.g. 'str', 'number').
    type = None

    def __init__(self, value: Any):
        """Initialize a variable instance.

        Args:
            value: The initial value for the variable.

        Attributes:
            self.value: The validated value stored in the variable.
            self.literal: A cached string representation of the value.
        """
        self.value = self.valid_value(value)
        self.literal = self.to_literal()

    @abstractmethod
    def valid_value(self, value) -> Any:
        """Validate or convert a value to the correct type for the variable.

        Args:
            value: The value to validate.

        Returns:
            The validated or converted value.

        Raises:
            TypeError: If the value is invalid.
        """

    @abstractmethod
    def to_literal(self) -> str:
        """Convert the variable value to a string literal representation.

        Returns:
            A string representing the variable's value.
        """

    def get_value(self) -> Any:
        """Return the current value of the variable."""
        return self.value

    def set(self, value):
        """Set the variable to a new value after validation.

        Args:
            value: The new value to assign to the variable.
        """
        self.value = self.valid_value(value)
        # Bug fix: refresh the cached literal; previously `self.literal`
        # kept the string form of the value set at construction time.
        self.literal = self.to_literal()
|
||||
137
api/app/core/workflow/variable/variable_objects.py
Normal file
137
api/app/core/workflow/variable/variable_objects.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from typing import Any, TypeVar, Type, Generic
|
||||
|
||||
from app.core.workflow.variable.base_variable import BaseVariable, VariableType
|
||||
|
||||
T = TypeVar("T", bound=BaseVariable)
|
||||
|
||||
|
||||
class StringObject(BaseVariable):
    """Variable wrapper around a plain Python string."""

    type = 'str'

    def valid_value(self, value) -> str:
        # Reject anything that is not already a str; no coercion is attempted.
        if isinstance(value, str):
            return value
        raise TypeError("Value must be a string")

    def to_literal(self) -> str:
        # A string is already its own literal form.
        return self.value
|
||||
|
||||
|
||||
class NumberObject(BaseVariable):
    """Variable wrapper around an int or float value."""

    type = 'number'

    def valid_value(self, value) -> int | float:
        # NOTE(review): bool passes this check because bool subclasses int —
        # confirm whether booleans should be rejected here.
        if isinstance(value, (int, float)):
            return value
        raise TypeError("Value must be a number")

    def to_literal(self) -> str:
        return str(self.value)
|
||||
|
||||
|
||||
class BooleanObject(BaseVariable):
    """Variable wrapper around a bool value."""

    type = 'boolean'

    def valid_value(self, value) -> bool:
        if isinstance(value, bool):
            return value
        raise TypeError("Value must be a boolean")

    def to_literal(self) -> str:
        # Render as JSON-style lowercase 'true' / 'false'.
        return "true" if self.value else "false"
|
||||
|
||||
|
||||
class DictObject(BaseVariable):
    """Variable wrapper around a dict value."""

    type = 'object'

    def valid_value(self, value) -> dict:
        if isinstance(value, dict):
            return value
        raise TypeError("Value must be a dict")

    def to_literal(self) -> str:
        # Python repr of the dict (not JSON) — matches str(dict).
        return str(self.value)
|
||||
|
||||
|
||||
class FileObject(BaseVariable):
    # TODO: file variable support is not implemented; both methods are stubs
    # that return None (note: to_literal therefore violates its -> str
    # annotation until implemented).
    type = 'file'

    def valid_value(self, value) -> Any:
        pass

    def to_literal(self) -> str:
        pass
|
||||
|
||||
|
||||
class ArrayObject(BaseVariable, Generic[T]):
    """Homogeneous array variable; every element is wrapped in `child_type`."""

    type = 'array'

    def __init__(self, child_type: Type[T], value: list[Any]):
        """Initialize the array.

        Args:
            child_type: BaseVariable subclass used to wrap each element.
            value: Raw list of element values.

        Raises:
            TypeError: If child_type is not a BaseVariable subclass, or any
                element fails the child type's validation.
        """
        if not issubclass(child_type, BaseVariable):
            raise TypeError("child_type must be a subclass of BaseVariable")
        # Must be set before super().__init__(), which calls valid_value().
        self.child_type = child_type
        super().__init__(value)

    def valid_value(self, value: list[Any]) -> list[T]:
        if not isinstance(value, list):
            raise TypeError("Value must be a list")
        final_value = []
        for v in value:
            # Bug fix: the original used a bare `except:`, which also caught
            # KeyboardInterrupt/SystemExit and discarded the cause.
            try:
                final_value.append(self.child_type(v))
            except Exception as e:
                raise TypeError(
                    f"All elements must be of type {self.child_type.type}"
                ) from e
        return final_value

    def get_value(self) -> Any:
        # Consistency fix: unwrap elements to raw Python values, as
        # NestedArrayObject does; the inherited get_value() returned the
        # BaseVariable wrapper objects themselves.
        return [v.get_value() for v in self.value]

    def to_literal(self) -> str:
        return "\n".join([v.to_literal() for v in self.value])
|
||||
|
||||
|
||||
class NestedArrayObject(BaseVariable):
    """Two-level array variable: a list whose elements are ArrayObject."""

    type = 'array_nest'

    def valid_value(self, value: list[T]) -> list[T]:
        if not isinstance(value, list):
            raise TypeError("Value must be a list")
        final_value = []
        for v in value:
            if not isinstance(v, ArrayObject):
                raise TypeError("All elements must be of type list")
            final_value.append(v)
        return final_value

    def to_literal(self) -> str:
        # Bug fix: rows are ArrayObject instances, which are not iterable —
        # the original `for item in row` raised TypeError at runtime.
        # Delegate to each row's own literal rendering instead.
        return "\n".join(row.to_literal() for row in self.value)

    def get_value(self) -> Any:
        # Bug fix: unwrap via row.value (the wrapped element list) rather than
        # iterating the non-iterable ArrayObject directly.
        return [[item.get_value() for item in row.value] for row in self.value]
|
||||
|
||||
|
||||
def make_array(child_type: Type[T], value: list[Any]) -> ArrayObject[T]:
    """Convenience factory for ArrayObject that avoids repeating the element type."""
    return ArrayObject(child_type, value)
|
||||
|
||||
|
||||
def create_variable_instance(var_type: VariableType, value: Any) -> T:
|
||||
match var_type:
|
||||
case VariableType.STRING:
|
||||
return StringObject(value)
|
||||
case VariableType.NUMBER:
|
||||
return NumberObject(value)
|
||||
case VariableType.BOOLEAN:
|
||||
return BooleanObject(value)
|
||||
case VariableType.OBJECT:
|
||||
return DictObject(value)
|
||||
case VariableType.ARRAY_STRING:
|
||||
return make_array(StringObject, value)
|
||||
case VariableType.ARRAY_NUMBER:
|
||||
return make_array(NumberObject, value)
|
||||
case VariableType.ARRAY_BOOLEAN:
|
||||
return make_array(BooleanObject, value)
|
||||
case VariableType.ARRAY_OBJECT:
|
||||
return make_array(DictObject, value)
|
||||
case VariableType.ARRAY_FILE:
|
||||
return make_array(FileObject, value)
|
||||
case _:
|
||||
raise TypeError(f"Invalid type - {var_type}")
|
||||
@@ -11,10 +11,15 @@
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from asyncio import Lock
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from typing import Any, Generic
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable.variable_objects import T, create_variable_instance
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,11 +28,6 @@ class VariableSelector:
|
||||
"""变量选择器
|
||||
|
||||
用于引用变量的路径表示。
|
||||
|
||||
Examples:
|
||||
>>> selector = VariableSelector(["sys", "message"])
|
||||
>>> selector = VariableSelector(["node_A", "output"])
|
||||
>>> selector = VariableSelector.from_string("sys.message")
|
||||
"""
|
||||
|
||||
def __init__(self, path: list[str]):
|
||||
@@ -52,10 +52,6 @@ class VariableSelector:
|
||||
|
||||
Returns:
|
||||
VariableSelector 实例
|
||||
|
||||
Examples:
|
||||
>>> selector = VariableSelector.from_string("sys.message")
|
||||
>>> selector = VariableSelector.from_string("llm_qa.output")
|
||||
"""
|
||||
path = selector_str.split(".")
|
||||
return cls(path)
|
||||
@@ -67,160 +63,212 @@ class VariableSelector:
|
||||
return f"VariableSelector({self.path})"
|
||||
|
||||
|
||||
class VariableStruct(BaseModel, Generic[T]):
    """A typed variable struct.

    Represents a runtime variable with an associated logical type and
    a concrete value object.

    This class bridges the static type system (via generics) and the
    runtime type system (via ``VariableType``).

    Attributes:
        type:
            Logical variable type descriptor used for runtime validation,
            serialization, and workflow type checking.
        instance:
            The concrete variable object. The actual Python type is
            represented by the generic parameter ``T`` (e.g. StringObject,
            NumberObject, ArrayObject[StringObject]).
        mut:
            Whether the variable is mutable.
    """
    type: VariableType
    instance: T
    mut: bool

    # BaseVariable subclasses are plain Python classes, not pydantic models,
    # so pydantic must be told to accept them as field types.
    model_config = {
        "arbitrary_types_allowed": True
    }
|
||||
|
||||
|
||||
class VariablePool:
|
||||
"""变量池
|
||||
|
||||
管理工作流执行过程中的所有变量。
|
||||
|
||||
变量命名空间:
|
||||
- sys.*: 系统变量(message, execution_id, workspace_id, user_id, conversation_id)
|
||||
- conv.*: 会话变量(跨多轮对话保持的变量)
|
||||
- <node_id>.*: 节点输出
|
||||
|
||||
Examples:
|
||||
>>> pool = VariablePool(state)
|
||||
>>> pool.get(["sys", "message"])
|
||||
"用户的问题"
|
||||
>>> pool.get(["llm_qa", "output"])
|
||||
"AI 的回答"
|
||||
>>> pool.set(["conv", "user_name"], "张三")
|
||||
"""Variable pool.
|
||||
|
||||
Manages all variables during workflow execution, including storage,
|
||||
namespacing, and concurrency control.
|
||||
|
||||
Variable namespace conventions:
|
||||
- ``sys.*``:
|
||||
System variables (e.g. message, execution_id, workspace_id,
|
||||
user_id, conversation_id).
|
||||
- ``conv.*``:
|
||||
Conversation-level variables that persist across multiple turns.
|
||||
- ``<node_id>.*``:
|
||||
Variables produced by workflow nodes.
|
||||
"""
|
||||
|
||||
def __init__(self, state: "WorkflowState"):
|
||||
"""初始化变量池
|
||||
|
||||
Args:
|
||||
state: 工作流状态(LangGraph State)
|
||||
"""
|
||||
self.state = state
|
||||
def __init__(self):
|
||||
"""Initialize the variable pool.
|
||||
|
||||
Attributes:
|
||||
self.locks:
|
||||
A per-key lock table used for fine-grained concurrency control.
|
||||
|
||||
self.variables:
|
||||
Storage for all variables managed by the pool.
|
||||
"""
|
||||
self.locks = defaultdict(Lock)
|
||||
self.variables: dict[str, dict[str, VariableStruct[Any]]] = {}
|
||||
|
||||
@staticmethod
def transform_selector(selector):
    """Normalize a selector string into a two-part [namespace, name] path.

    Strips optional Jinja-style braces ("{{ sys.message }}" -> "sys.message")
    and splits on '.'.

    Raises:
        ValueError: If the selector does not have exactly two components.
    """
    brace_pattern = r"\{\{\s*(.*?)\s*\}\}"
    literal = re.sub(brace_pattern, r"\1", selector).strip()
    path = VariableSelector.from_string(literal).path
    if len(path) != 2:
        raise ValueError(f"Selector not valid - {path}")
    return path
|
||||
|
||||
def _get_variable_struct(
|
||||
self,
|
||||
selector: str
|
||||
) -> VariableStruct[T] | None:
|
||||
"""Retrieve a variable struct from the variable pool.
|
||||
|
||||
def get(self, selector: list[str] | str, default: Any = None) -> Any:
|
||||
"""获取变量值
|
||||
|
||||
Args:
|
||||
selector: 变量选择器,可以是列表或字符串
|
||||
default: 默认值(变量不存在时返回)
|
||||
|
||||
selector:
|
||||
Variable selector, either:
|
||||
- A string variable literal (e.g. "{{ sys.message }}")
|
||||
|
||||
Returns:
|
||||
变量值
|
||||
|
||||
Examples:
|
||||
>>> pool.get(["sys", "message"])
|
||||
>>> pool.get("sys.message")
|
||||
>>> pool.get(["llm_qa", "output"])
|
||||
>>> pool.get("llm_qa.output")
|
||||
|
||||
Raises:
|
||||
KeyError: 变量不存在且未提供默认值
|
||||
The variable's struct if it exists; otherwise returns None.
|
||||
"""
|
||||
# 转换为 VariableSelector
|
||||
if isinstance(selector, str):
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
variable_literal = re.sub(pattern, r"\1", selector).strip()
|
||||
selector = VariableSelector.from_string(variable_literal).path
|
||||
|
||||
if not selector or len(selector) < 1:
|
||||
raise ValueError("变量选择器不能为空")
|
||||
selector = self.transform_selector(selector)
|
||||
|
||||
namespace = selector[0]
|
||||
variable_name = selector[1]
|
||||
|
||||
try:
|
||||
# 系统变量
|
||||
if namespace == "sys":
|
||||
key = selector[1] if len(selector) > 1 else None
|
||||
if not key:
|
||||
return self.state.get("variables", {}).get("sys", {})
|
||||
return self.state.get("variables", {}).get("sys", {}).get(key, default)
|
||||
namespace_variables = self.variables.get(namespace)
|
||||
if namespace_variables is None:
|
||||
return None
|
||||
|
||||
# 会话变量
|
||||
elif namespace == "conv":
|
||||
key = selector[1] if len(selector) > 1 else None
|
||||
if not key:
|
||||
return self.state.get("variables", {}).get("conv", {})
|
||||
return self.state.get("variables", {}).get("conv", {}).get(key, default)
|
||||
var_instance = namespace_variables.get(variable_name)
|
||||
if var_instance is None:
|
||||
return None
|
||||
return var_instance
|
||||
|
||||
# 节点输出(从 runtime_vars 读取)
|
||||
else:
|
||||
node_id = namespace
|
||||
runtime_vars = self.state.get("runtime_vars", {})
|
||||
def get_value(
|
||||
self,
|
||||
selector: str,
|
||||
default: Any = None,
|
||||
strict: bool = True,
|
||||
) -> Any:
|
||||
"""Retrieve a variable value from the variable pool.
|
||||
|
||||
if node_id not in runtime_vars:
|
||||
if default is not None:
|
||||
return default
|
||||
raise KeyError(f"节点 '{node_id}' 的输出不存在")
|
||||
Args:
|
||||
selector:
|
||||
Variable selector, either:
|
||||
- A list of path components (e.g. ["sys", "message"])
|
||||
- A string variable literal (e.g. "{{ sys.message }}")
|
||||
default:
|
||||
The value to return if the variable does not exist.
|
||||
strict:
|
||||
If True, raises KeyError when the variable does not exist.
|
||||
|
||||
node_var = runtime_vars[node_id]
|
||||
Returns:
|
||||
The variable's value if it exists; otherwise returns `default`.
|
||||
|
||||
# 如果只有节点 ID,返回整个变量
|
||||
if len(selector) == 1:
|
||||
return node_var
|
||||
Raises:
|
||||
KeyError: If strict is True and the variable does not exist.
|
||||
"""
|
||||
variable_struct = self._get_variable_struct(selector)
|
||||
if variable_struct is None:
|
||||
if strict:
|
||||
raise KeyError(f"{selector} not exist")
|
||||
return default
|
||||
|
||||
# 获取特定字段
|
||||
# 支持嵌套访问,如 node_id.field.subfield
|
||||
result = node_var
|
||||
for k in selector[1:]:
|
||||
if isinstance(result, dict):
|
||||
result = result.get(k)
|
||||
if result is None:
|
||||
if default is not None:
|
||||
return default
|
||||
raise KeyError(f"字段 '{'.'.join(selector)}' 不存在")
|
||||
else:
|
||||
if default is not None:
|
||||
return default
|
||||
raise KeyError(f"无法访问 '{'.'.join(selector)}'")
|
||||
return variable_struct.instance.get_value()
|
||||
|
||||
return result
|
||||
def get_literal(
|
||||
self,
|
||||
selector: str,
|
||||
default: Any = None,
|
||||
strict: bool = True,
|
||||
) -> Any:
|
||||
"""Retrieve a variable value from the variable pool.
|
||||
|
||||
except KeyError:
|
||||
if default is not None:
|
||||
return default
|
||||
raise
|
||||
Args:
|
||||
selector:
|
||||
Variable selector, either:
|
||||
- A list of path components (e.g. ["sys", "message"])
|
||||
- A string variable literal (e.g. "{{ sys.message }}")
|
||||
default:
|
||||
The value to return if the variable does not exist.
|
||||
strict:
|
||||
If True, raises KeyError when the variable does not exist.
|
||||
|
||||
def set(self, selector: list[str] | str, value: Any):
|
||||
Returns:
|
||||
The variable's value if it exists; otherwise returns `default`.
|
||||
|
||||
Raises:
|
||||
KeyError: If strict is True and the variable does not exist.
|
||||
"""
|
||||
variable_struct = self._get_variable_struct(selector)
|
||||
if variable_struct is None:
|
||||
if strict:
|
||||
raise KeyError(f"{selector} not exist")
|
||||
return default
|
||||
|
||||
return variable_struct.instance.to_literal()
|
||||
|
||||
async def set(
|
||||
self,
|
||||
selector: str,
|
||||
value: Any
|
||||
):
|
||||
"""设置变量值
|
||||
|
||||
Args:
|
||||
selector: 变量选择器
|
||||
value: 变量值
|
||||
|
||||
Examples:
|
||||
>>> pool.set(["conv", "user_name"], "张三")
|
||||
>>> pool.set("conv.user_name", "张三")
|
||||
|
||||
|
||||
Note:
|
||||
- 只能设置会话变量 (conv.*)
|
||||
- 系统变量和节点输出是只读的
|
||||
"""
|
||||
# 转换为 VariableSelector
|
||||
if isinstance(selector, str):
|
||||
selector = VariableSelector.from_string(selector).path
|
||||
variable_struct = self._get_variable_struct(selector)
|
||||
if variable_struct is None:
|
||||
raise KeyError(f"Variable {selector} is not defined")
|
||||
if not variable_struct.mut:
|
||||
raise KeyError(f"{selector} cannot be modified")
|
||||
async with self.locks[selector]:
|
||||
variable_struct.instance.set(value)
|
||||
|
||||
if not selector or len(selector) < 2:
|
||||
raise ValueError("变量选择器必须包含命名空间和键名")
|
||||
async def new(
    self,
    namespace: str,
    key: str,
    value: Any,
    var_type: VariableType,
    mut: bool
):
    """Create a variable in the pool, or update its value if it already exists.

    Args:
        namespace: Variable namespace ("sys", "conv", or a node id).
        key: Variable name within the namespace.
        value: Initial (or new) value.
        var_type: Logical type used to build the typed wrapper.
        mut: Whether the variable may be modified later via set().
    """
    if self.has(f"{namespace}.{key}"):
        try:
            await self.set(f"{namespace}.{key}", value)
            # Bug fix: return after a successful update. The original fell
            # through and re-created the variable, discarding the set() and
            # silently overwriting the existing type/mutability flags.
            return
        except KeyError:
            # Existing variable is immutable; fall through and re-create it,
            # preserving the original's tolerant (non-raising) behavior.
            pass
    instance = create_variable_instance(var_type, value)
    variable_struct = VariableStruct(type=var_type, instance=instance, mut=mut)
    self.variables.setdefault(namespace, {})[key] = variable_struct
|
||||
|
||||
namespace = selector[0]
|
||||
|
||||
if namespace != "conv" and namespace not in self.state["cycle_nodes"]:
|
||||
raise ValueError("Only conversation or cycle variables can be assigned.")
|
||||
|
||||
key = selector[1]
|
||||
|
||||
# 确保 variables 结构存在
|
||||
if "variables" not in self.state:
|
||||
self.state["variables"] = {"sys": {}, "conv": {}}
|
||||
if namespace == "conv":
|
||||
if "conv" not in self.state["variables"]:
|
||||
self.state["variables"]["conv"] = {}
|
||||
|
||||
# 设置值
|
||||
self.state["variables"]["conv"][key] = value
|
||||
elif namespace in self.state["cycle_nodes"]:
|
||||
self.state["runtime_vars"][namespace][key] = value
|
||||
|
||||
logger.debug(f"设置变量: {'.'.join(selector)} = {value}")
|
||||
|
||||
def has(self, selector: list[str] | str) -> bool:
|
||||
def has(self, selector: str) -> bool:
|
||||
"""检查变量是否存在
|
||||
|
||||
Args:
|
||||
@@ -228,18 +276,8 @@ class VariablePool:
|
||||
|
||||
Returns:
|
||||
变量是否存在
|
||||
|
||||
Examples:
|
||||
>>> pool.has(["sys", "message"])
|
||||
True
|
||||
>>> pool.has("llm_qa.output")
|
||||
False
|
||||
"""
|
||||
try:
|
||||
self.get(selector)
|
||||
return True
|
||||
except KeyError:
|
||||
return False
|
||||
return self._get_variable_struct(selector) is not None
|
||||
|
||||
def get_all_system_vars(self) -> dict[str, Any]:
|
||||
"""获取所有系统变量
|
||||
@@ -247,7 +285,8 @@ class VariablePool:
|
||||
Returns:
|
||||
系统变量字典
|
||||
"""
|
||||
return self.state.get("variables", {}).get("sys", {})
|
||||
sys_namespace = self.variables.get("sys", {})
|
||||
return {k: v.instance.value for k, v in sys_namespace.items()}
|
||||
|
||||
def get_all_conversation_vars(self) -> dict[str, Any]:
|
||||
"""获取所有会话变量
|
||||
@@ -255,7 +294,8 @@ class VariablePool:
|
||||
Returns:
|
||||
会话变量字典
|
||||
"""
|
||||
return self.state.get("variables", {}).get("conv", {})
|
||||
conv_namespace = self.variables.get("conv", {})
|
||||
return {k: v.instance.value for k, v in conv_namespace.items()}
|
||||
|
||||
def get_all_node_outputs(self) -> dict[str, Any]:
|
||||
"""获取所有节点输出(运行时变量)
|
||||
@@ -263,18 +303,37 @@ class VariablePool:
|
||||
Returns:
|
||||
节点输出字典,键为节点 ID
|
||||
"""
|
||||
return self.state.get("runtime_vars", {})
|
||||
runtime_vars = {
|
||||
namespace: {
|
||||
k: v.instance.value
|
||||
for k, v in vars_dict.items()
|
||||
}
|
||||
for namespace, vars_dict in self.variables.items()
|
||||
if namespace not in ("sys", "conv")
|
||||
}
|
||||
return runtime_vars
|
||||
|
||||
def get_node_output(self, node_id: str, default: Any = None, strict: bool = True) -> dict[str, Any] | None:
    """Return a node's output variables as raw Python values.

    Args:
        node_id: The node's namespace id.
        default: Value returned when the node has no output and strict is
            False. (Parameter name fixed from the misspelled `defalut`
            introduced in this change.)
        strict: If True, raise instead of returning `default` when missing.

    Returns:
        Mapping of variable name to raw value, or `default`.

    Raises:
        KeyError: If strict is True and the node has produced no output.
    """
    node_namespace = self.variables.get(node_id)
    if node_namespace:
        return {k: v.instance.value for k, v in node_namespace.items()}
    if strict:
        raise KeyError(f"node {node_id} output not exist")
    return default
|
||||
|
||||
def copy(self, pool: 'VariablePool'):
|
||||
self.variables = deepcopy(pool.variables)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""导出为字典
|
||||
|
||||
@@ -618,6 +618,7 @@ class AppChatService:
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
public=False
|
||||
|
||||
) -> AsyncGenerator[dict, None]:
|
||||
"""聊天(流式)"""
|
||||
@@ -634,7 +635,8 @@ class AppChatService:
|
||||
payload=payload,
|
||||
config=config,
|
||||
workspace_id=workspace_id,
|
||||
release_id=release_id
|
||||
release_id=release_id,
|
||||
public=public
|
||||
):
|
||||
yield event
|
||||
|
||||
|
||||
@@ -4,9 +4,8 @@
|
||||
import datetime
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Annotated, AsyncGenerator, Optional
|
||||
from typing import Any, Annotated, Optional
|
||||
|
||||
from deprecated import deprecated
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -566,6 +565,41 @@ class WorkflowService:
|
||||
message=f"工作流执行失败: {str(e)}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _map_public_event(event: dict) -> dict | None:
|
||||
event_type = event.get("event")
|
||||
payload = event.get("data")
|
||||
match event_type:
|
||||
case "workflow_start":
|
||||
return {
|
||||
"event": "start",
|
||||
"data": {
|
||||
"conversation_id": payload.get("conversation_id"),
|
||||
}
|
||||
}
|
||||
case "workflow_end":
|
||||
return {
|
||||
"event": "end",
|
||||
"data": {
|
||||
"elapsed_time": payload.get("elapsed_time"),
|
||||
"message_length": len(payload.get("output", ""))
|
||||
}
|
||||
}
|
||||
case "node_start" | "node_end" | "node_error":
|
||||
return None
|
||||
case _:
|
||||
return event
|
||||
|
||||
def _emit(self, public: bool, internal_event: dict):
    """Decide which form of an event to emit to the caller.

    Args:
        public: If True, reduce/filter the event via _map_public_event
            (which may return None, meaning "suppress this event").
        internal_event: The raw internal event dict.

    Returns:
        The event dict to yield, or None when it should be suppressed.
    """
    if public:
        mapped = self._map_public_event(internal_event)
    else:
        mapped = internal_event
    return mapped
|
||||
|
||||
async def run_stream(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
@@ -663,7 +697,7 @@ class WorkflowService:
|
||||
input_data=input_data,
|
||||
execution_id=execution.execution_id,
|
||||
workspace_id=str(workspace_id),
|
||||
user_id=payload.user_id
|
||||
user_id=payload.user_id,
|
||||
):
|
||||
if event.get("event") == "workflow_end":
|
||||
|
||||
@@ -694,7 +728,9 @@ class WorkflowService:
|
||||
)
|
||||
else:
|
||||
logger.error(f"unexpect workflow run status, status: {status}")
|
||||
yield event
|
||||
event = self._emit(public, event)
|
||||
if event:
|
||||
yield event
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
|
||||
|
||||
@@ -33,7 +33,7 @@ async def run_code(request: RunCodeRequest):
|
||||
"""Execute code in sandbox"""
|
||||
if request.language == "python3":
|
||||
return await run_python_code(request.code, request.preload, request.options)
|
||||
elif request.language == "nodejs":
|
||||
elif request.language == "javascript":
|
||||
return await run_nodejs_code(request.code, request.preload, request.options)
|
||||
else:
|
||||
return error_response(-400, "unsupported language")
|
||||
|
||||
Reference in New Issue
Block a user