feat(workflow): enforce strong typing for runtime variables

- Reduce exposed information in release workflows
This commit is contained in:
Eternity
2026-02-04 11:01:16 +08:00
parent 308e28cecc
commit bd8a451879
50 changed files with 1925 additions and 1372 deletions

View File

@@ -587,7 +587,8 @@ async def chat(
user_rag_memory_id=user_rag_memory_id,
app_id=release.app_id,
workspace_id=workspace_id,
release_id=release.id
release_id=release.id,
public=True
):
event_type = event.get("event", "message")
event_data = event.get("data", {})

File diff suppressed because it is too large Load Diff

View File

@@ -1,9 +1,3 @@
"""
安全的表达式求值器
使用 simpleeval 库提供安全的表达式评估,避免代码注入攻击。
"""
import logging
import re
from typing import Any
@@ -14,160 +8,119 @@ logger = logging.getLogger(__name__)
class ExpressionEvaluator:
"""安全的表达式求值器"""
"""Safe expression evaluator for workflow variables and node outputs."""
# 保留的命名空间
# Reserved namespaces
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
@staticmethod
def evaluate(
expression: str,
variables: dict[str, Any],
conv_vars: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> Any:
"""安全地评估表达式
Args:
expression: 表达式字符串,如 "{{var.score}} > 0.8"
variables: 用户定义的变量
node_outputs: 节点输出结果
system_vars: 系统变量
Returns:
表达式求值结果
Raises:
ValueError: 表达式无效或求值失败
Examples:
>>> evaluator = ExpressionEvaluator()
>>> evaluator.evaluate(
... "var.score > 0.8",
... {"score": 0.9},
... {},
... {}
... )
True
>>> evaluator.evaluate(
... "node.intent.output == '售前咨询'",
... {},
... {"intent": {"output": "售前咨询"}},
... {}
... )
True
"""
# 移除 Jinja2 模板语法的花括号(如果存在)
Safely evaluate an expression using workflow variables.
Args:
expression (str): The expression string, e.g., "var.score > 0.8"
conv_vars (dict): Conversation-level variables
node_outputs (dict): Outputs from workflow nodes
system_vars (dict, optional): System variables
Returns:
Any: Result of the evaluated expression
Raises:
ValueError: If the expression is invalid or evaluation fails
"""
# Remove Jinja2-style brackets if present
expression = expression.strip()
# "{{system.message}} == {{ user.message }}" -> "system.message == user.message"
pattern = r"\{\{\s*(.*?)\s*\}\}"
expression = re.sub(pattern, r"\1", expression).strip()
# 构建命名空间上下文
# Build context for evaluation
context = {
"var": variables, # 用户变量
"node": node_outputs, # 节点输出
"sys": system_vars or {}, # 系统变量
"conv": conv_vars, # conversation variables
"node": node_outputs, # node outputs
"sys": system_vars or {}, # system variables
}
# 为了向后兼容,也支持直接访问(但会在日志中警告)
context.update(variables)
context.update(conv_vars)
context["nodes"] = node_outputs
context.update(node_outputs)
try:
# simpleeval 只支持安全的操作:
# - 算术运算: +, -, *, /, //, %, **
# - 比较运算: ==, !=, <, <=, >, >=
# - 逻辑运算: and, or, not
# - 成员运算: in, not in
# - 属性访问: obj.attr
# - 字典/列表访问: obj["key"], obj[0]
# 不支持:函数调用、导入、赋值等危险操作
# simpleeval supports safe operations:
# arithmetic, comparisons, logical ops, attribute/dict/list access
result = simple_eval(expression, names=context)
return result
except NameNotDefined as e:
logger.error(f"表达式中引用了未定义的变量: {expression}, 错误: {e}")
raise ValueError(f"未定义的变量: {e}")
logger.error(f"Undefined variable in expression: {expression}, error: {e}")
raise ValueError(f"Undefined variable: {e}")
except InvalidExpression as e:
logger.error(f"表达式语法无效: {expression}, 错误: {e}")
raise ValueError(f"表达式语法无效: {e}")
logger.error(f"Invalid expression syntax: {expression}, error: {e}")
raise ValueError(f"Invalid expression syntax: {e}")
except SyntaxError as e:
logger.error(f"表达式语法错误: {expression}, 错误: {e}")
raise ValueError(f"表达式语法错误: {e}")
logger.error(f"Syntax error in expression: {expression}, error: {e}")
raise ValueError(f"Syntax error: {e}")
except Exception as e:
logger.error(f"表达式求值异常: {expression}, 错误: {e}")
raise ValueError(f"表达式求值失败: {e}")
logger.error(f"Expression evaluation failed: {expression}, error: {e}")
raise ValueError(f"Expression evaluation failed: {e}")
@staticmethod
def evaluate_bool(
expression: str,
variables: dict[str, Any],
conv_var: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> bool:
"""评估布尔表达式(用于条件判断)
"""
Evaluate a boolean expression (for conditions).
Args:
expression: 布尔表达式
variables: 用户变量
node_outputs: 节点输出
system_vars: 系统变量
expression (str): Boolean expression
conv_var (dict): Conversation variables
node_outputs (dict): Node outputs
system_vars (dict, optional): System variables
Returns:
布尔值结果
Examples:
>>> ExpressionEvaluator.evaluate_bool(
... "var.count >= 10 and var.status == 'active'",
... {"count": 15, "status": "active"},
... {},
... {}
... )
True
bool: Boolean result
"""
result = ExpressionEvaluator.evaluate(
expression, variables, node_outputs, system_vars
expression, conv_var, node_outputs, system_vars
)
return bool(result)
@staticmethod
def validate_variable_names(variables: list[dict]) -> list[str]:
"""验证变量名是否合法
"""
Validate variable names for legality.
Args:
variables: 变量定义列表
variables (list[dict]): List of variable definitions
Returns:
错误列表,如果为空则验证通过
Examples:
>>> ExpressionEvaluator.validate_variable_names([
... {"name": "user_input"},
... {"name": "var"} # 保留字
... ])
["变量名 'var' 是保留的命名空间,请使用其他名称"]
list[str]: List of error messages. Empty if all names are valid.
"""
errors = []
for var in variables:
var_name = var.get("name", "")
# 检查是否为保留命名空间
if var_name in ExpressionEvaluator.RESERVED_NAMESPACES:
errors.append(
f"变量名 '{var_name}' 是保留的命名空间,请使用其他名称"
f"Variable name '{var_name}' is a reserved namespace, please use another name"
)
# 检查是否为有效的 Python 标识符
if not var_name.isidentifier():
errors.append(
f"变量名 '{var_name}' 不是有效的标识符"
f"Variable name '{var_name}' is not a valid Python identifier"
)
return errors
@@ -176,23 +129,23 @@ class ExpressionEvaluator:
# 便捷函数
def evaluate_expression(
expression: str,
variables: dict[str, Any],
conv_var: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
system_vars: dict[str, Any]
) -> Any:
"""评估表达式(便捷函数)"""
"""Evaluate an expression (convenience function)."""
return ExpressionEvaluator.evaluate(
expression, variables, node_outputs, system_vars
expression, conv_var, node_outputs, system_vars
)
def evaluate_condition(
expression: str,
variables: dict[str, Any],
conv_var: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> bool:
"""评估条件表达式(便捷函数)"""
"""Evaluate a boolean condition expression (convenience function)."""
return ExpressionEvaluator.evaluate_bool(
expression, variables, node_outputs, system_vars
expression, conv_var, node_outputs, system_vars
)

View File

@@ -14,9 +14,14 @@ from pydantic import BaseModel, Field
from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
SCOPE_PATTERN = re.compile(
r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\.[a-zA-Z0-9_]+\s*}}"
)
class OutputContent(BaseModel):
"""
@@ -53,6 +58,12 @@ class OutputContent(BaseModel):
)
)
_SCOPE: str | None = None
def get_scope(self) -> str:
self._SCOPE = SCOPE_PATTERN.findall(self.literal)[0]
return self._SCOPE
def depends_on_scope(self, scope: str) -> bool:
"""
Check if this segment depends on a given scope.
@@ -63,8 +74,9 @@ class OutputContent(BaseModel):
Returns:
bool: True if this segment references the given scope.
"""
pattern = rf"\{{\{{\s*{re.escape(scope)}\.[a-zA-Z0-9_]+\s*\}}\}}"
return bool(re.search(pattern, self.literal))
if self._SCOPE:
return self._SCOPE == scope
return self.get_scope() == scope
class StreamOutputConfig(BaseModel):
@@ -167,6 +179,7 @@ class GraphBuilder:
workflow_config: dict[str, Any],
stream: bool = False,
subgraph: bool = False,
variable_pool: VariablePool | None = None
):
self.workflow_config = workflow_config
@@ -180,6 +193,10 @@ class GraphBuilder:
self._find_upstream_branch_node = lru_cache(
maxsize=len(self.nodes) * 2
)(self._find_upstream_branch_node)
if variable_pool:
self.variable_pool = variable_pool
else:
self.variable_pool = VariablePool()
self.graph = StateGraph(WorkflowState)
self.add_nodes()
@@ -452,9 +469,9 @@ class GraphBuilder:
if self.stream:
# Stream mode: create an async generator function
# LangGraph collects all yielded values; the last yielded dictionary is merged into the state
def make_stream_func(inst):
def make_stream_func(inst, variable_pool=self.variable_pool):
async def node_func(state: WorkflowState):
async for item in inst.run_stream(state):
async for item in inst.run_stream(state, variable_pool):
yield item
return node_func
@@ -462,9 +479,9 @@ class GraphBuilder:
self.graph.add_node(node_id, make_stream_func(node_instance))
else:
# Non-stream mode: create an async function
def make_func(inst):
def make_func(inst, variable_pool=self.variable_pool):
async def node_func(state: WorkflowState):
return await inst.run(state)
return await inst.run(state, variable_pool)
return node_func
@@ -567,27 +584,28 @@ class GraphBuilder:
for target in branch_info["target"]:
waiting_edges[target].append(branch_info["node"]["name"])
def router_fn(state: WorkflowState) -> list[Send]:
def router_fn(state: WorkflowState, variable_pool: VariablePool = self.variable_pool) -> list[Send]:
branch_activate = []
new_state = state.copy()
new_state["activate"] = dict(state.get("activate", {})) # deep copy of activate
node_output = variable_pool.get_node_output(src, defalut=dict(), strict=False)
for label, branch in unique_branch.items():
if evaluate_condition(
if node_output and evaluate_condition(
branch["condition"],
state.get("variables", {}),
state.get("runtime_vars", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
{},
{src: node_output},
{}
):
logger.debug(f"Conditional routing {src}: selected branch {label}")
new_state["activate"][branch["node"]["name"]] = True
branch_activate.append(
Send(
branch['node']['name'],
new_state
)
)
continue
new_state["activate"][branch["node"]["name"]] = False
for label, branch in unique_branch.items():
branch_activate.append(
Send(
branch['node']['name'],

View File

@@ -15,7 +15,6 @@ from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
from app.core.workflow.nodes.tool import ToolNode
@@ -25,7 +24,6 @@ __all__ = [
"WorkflowState",
"LLMNode",
"AgentNode",
"TransformNode",
"IfElseNode",
"StartNode",
"EndNode",

View File

@@ -2,7 +2,8 @@
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
from app.core.workflow.variable.base_variable import VariableType
class AgentNodeConfig(BaseNodeConfig):

View File

@@ -2,6 +2,7 @@
Agent 节点实现
调用已发布的 Agent 应用。
# TODO
"""
import logging
@@ -9,6 +10,8 @@ from typing import Any
from langchain_core.messages import AIMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
from app.services.draft_run_service import DraftRunService
from app.models import AppRelease
from app.db import get_db
@@ -30,19 +33,22 @@ class AgentNode(BaseNode):
}
}
"""
def _prepare_agent(self, state: WorkflowState) -> tuple[DraftRunService, AppRelease, str]:
def _output_types(self) -> dict[str, VariableType]:
return {"output": VariableType.STRING}
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[DraftRunService, AppRelease, str]:
"""准备 Agent公共逻辑
Args:
state: 工作流状态
variable_pool: 变量池
Returns:
(draft_service, release, message): 服务实例、发布配置、消息
"""
# 1. 渲染消息
message_template = self.config.get("message", "")
message = self._render_template(message_template, state)
message = self._render_template(message_template, variable_pool)
# 2. 获取 Agent 配置
agent_id = self.config.get("agent_id")
@@ -61,16 +67,17 @@ class AgentNode(BaseNode):
return draft_service, release, message
async def execute(self, state: WorkflowState) -> dict[str, Any]:
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""非流式执行
Args:
state: 工作流状态
variable_pool: 变量池
Returns:
状态更新字典
"""
draft_service, release, message = self._prepare_agent(state)
draft_service, release, message = self._prepare_agent(variable_pool)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
@@ -79,9 +86,9 @@ class AgentNode(BaseNode):
agent_config=release.config,
model_config=None,
message=message,
workspace_id=state.get("workspace_id"),
workspace_id=variable_pool.get_value("sys.workspace_id"),
user_id=state.get("user_id"),
variables=state.get("variables", {})
variables=variable_pool.get_all_conversation_vars()
)
response = result.get("response", "")
@@ -99,16 +106,17 @@ class AgentNode(BaseNode):
}
}
async def execute_stream(self, state: WorkflowState):
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
"""流式执行
Args:
state: 工作流状态
variable_pool: 变量池
Yields:
流式事件字典
"""
draft_service, release, message = self._prepare_agent(state)
draft_service, release, message = self._prepare_agent(variable_pool)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
@@ -120,9 +128,9 @@ class AgentNode(BaseNode):
agent_config=release.config,
model_config=None,
message=message,
workspace_id=state.get("workspace_id"),
workspace_id=variable_pool.get_value("sys.workspace_id"),
user_id=state.get("user_id"),
variables=state.get("variables", {})
variables=variable_pool.get_all_conversation_vars()
):
# 提取内容
content = chunk.get("content", "")

View File

@@ -6,6 +6,7 @@ from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import AssignmentOperator
from app.core.workflow.nodes.operators import AssignmentOperatorInstance, AssignmentOperatorResolver
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -17,13 +18,17 @@ class AssignerNode(BaseNode):
self.variable_updater = True
self.typed_config: AssignerNodeConfig | None = None
async def execute(self, state: WorkflowState) -> Any:
def _output_types(self) -> dict[str, VariableType]:
return {}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
Execute the assignment operation defined by this node.
Args:
state: The current workflow state, including conversation variables,
node outputs, and system variables.
variable_pool: variable pool
Returns:
None or the result of the assignment operation.
@@ -31,60 +36,57 @@ class AssignerNode(BaseNode):
# Initialize a variable pool for accessing conversation, node, and system variables
self.typed_config = AssignerNodeConfig(**self.config)
logger.info(f"节点 {self.node_id} 开始执行")
pool = VariablePool(state)
pattern = r"\{\{\s*(.*?)\s*\}\}"
for assignment in self.typed_config.assignments:
# Get the target variable selector (e.g., "conv.test")
variable_selector = assignment.variable_selector
if isinstance(variable_selector, str):
# Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"]
pattern = r"\{\{\s*(.*?)\s*\}\}"
expression = re.sub(pattern, r"\1", variable_selector).strip()
variable_selector = expression.split('.')
namespace = re.sub(pattern, r"\1", variable_selector).split('.')[0]
# Only conversation variables ('conv') are allowed
if variable_selector[0] != 'conv' and variable_selector[0] not in state["cycle_nodes"]:
raise ValueError("Only conversation or cycle variables can be assigned.")
if namespace != 'conv' and namespace not in state["cycle_nodes"]:
raise ValueError(f"Only conversation or cycle variables can be assigned. - {variable_selector}")
# Get the value or expression to assign
value = assignment.value
logger.debug(f"left:{variable_selector}, right: {value}")
pattern = r"\{\{\s*(.*?)\s*\}\}"
if isinstance(value, str):
expression = re.match(pattern, value)
if expression:
expression = expression.group(1)
expression = re.sub(pattern, r"\1", expression).strip()
value = self.get_variable(expression, state)
value = self.get_variable(expression, variable_pool, default=value, strict=False)
# Select the appropriate assignment operator instance based on the target variable type
operator: AssignmentOperatorInstance = AssignmentOperatorResolver.resolve_by_value(
pool.get(variable_selector)
variable_pool.get_value(variable_selector)
)(
pool, variable_selector, value
variable_pool, variable_selector, value
)
# Execute the configured assignment operation
match assignment.operation:
case AssignmentOperator.COVER:
operator.assign()
await operator.assign()
case AssignmentOperator.ASSIGN:
operator.assign()
await operator.assign()
case AssignmentOperator.CLEAR:
operator.clear()
await operator.clear()
case AssignmentOperator.ADD:
operator.add()
await operator.add()
case AssignmentOperator.SUBTRACT:
operator.subtract()
await operator.subtract()
case AssignmentOperator.MULTIPLY:
operator.multiply()
await operator.multiply()
case AssignmentOperator.DIVIDE:
operator.divide()
await operator.divide()
case AssignmentOperator.APPEND:
operator.append()
await operator.append()
case AssignmentOperator.REMOVE_FIRST:
operator.remove_first()
await operator.remove_first()
case AssignmentOperator.REMOVE_LAST:
operator.remove_last()
await operator.remove_last()
case _:
raise ValueError(f"Invalid Operator: {assignment.operation}")
logger.info(f"Node {self.node_id}: execution completed")

View File

@@ -3,79 +3,13 @@
定义所有节点配置的通用字段和数据结构。
"""
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
from app.core.workflow.variable.base_variable import VariableType
VARIABLE_PATTERN = r"\{\{\s*(.*?)\s*\}\}"
class VariableType(StrEnum):
"""变量类型枚举"""
STRING = "string"
NUMBER = "number"
BOOLEAN = "boolean"
OBJECT = "object"
ARRAY_STRING = "array[string]"
ARRAY_NUMBER = "array[number]"
ARRAY_BOOLEAN = "array[boolean]"
ARRAY_OBJECT = "array[object]"
class TypedVariable(BaseModel):
"""
TODO: 强类型限制
Strongly typed variable that validates value on assignment.
"""
value: Any = Field(..., description="Variable value")
type: VariableType = Field(..., description="Declared type of the variable")
model_config = ConfigDict(
validate_assignment=True
)
def __setattr__(self, name, value):
if name == "value":
self._validate_value(value)
if name == "type":
raise RuntimeError("Cannot modify variable type at runtime")
super().__setattr__(name, value)
def _validate_value(self, v: Any):
t = self.type
match t:
case VariableType.STRING:
if not isinstance(v, str):
raise TypeError("Variable value does not match type STRING")
case VariableType.BOOLEAN:
if not isinstance(v, bool):
raise TypeError("Variable value does not match type BOOLEAN")
case VariableType.NUMBER:
if not isinstance(v, (int, float)):
raise TypeError("Variable value does not match type NUMBER")
case VariableType.OBJECT:
if not isinstance(v, dict):
raise TypeError("Variable value does not match type OBJECT")
case VariableType.ARRAY_STRING:
if not isinstance(v, list) or not all(isinstance(i, str) for i in v):
raise TypeError("Variable value does not match type ARRAY_STRING")
case VariableType.ARRAY_NUMBER:
if not isinstance(v, list) or not all(isinstance(i, (int, float)) for i in v):
raise TypeError("Variable value does not match type ARRAY_NUMBER")
case VariableType.ARRAY_BOOLEAN:
if not isinstance(v, list) or not all(isinstance(i, bool) for i in v):
raise TypeError("Variable value does not match type ARRAY_BOOLEAN")
case VariableType.ARRAY_OBJECT:
if not isinstance(v, list) or not all(isinstance(i, dict) for i in v):
raise TypeError("Variable value does not match type ARRAY_OBJECT")
case _:
raise TypeError(f"Unknown variable type: {t}")
class VariableDefinition(BaseModel):
"""变量定义

View File

@@ -1,12 +1,7 @@
"""
工作流节点基类
定义节点的基本接口和通用功能。
"""
import asyncio
import logging
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Any, AsyncGenerator
from langgraph.config import get_stream_writer
@@ -14,6 +9,7 @@ from typing_extensions import TypedDict, Annotated
from app.core.config import settings
from app.core.workflow.nodes.enums import BRANCH_NODES
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -42,22 +38,10 @@ class WorkflowState(TypedDict):
cycle_nodes: list
looping: Annotated[int, merge_looping_state]
# Input variables (passed from configured variables)
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)
variables: Annotated[dict[str, Any], lambda x, y: {
**x,
**{k: {**x.get(k, {}), **v} if isinstance(v, dict) and isinstance(x.get(k), dict) else v
for k, v in y.items()}
}]
# Node outputs (stores execution results of each node for variable references)
# Uses a custom merge function to combine new node outputs into the existing dictionary
node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
# Runtime node variables (simplified version, stores business data for fast access between nodes)
# Format: {node_id: business_result}
runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
# Execution context
execution_id: str
workspace_id: str
@@ -72,17 +56,17 @@ class WorkflowState(TypedDict):
class BaseNode(ABC):
"""节点基类
所有节点类型都应该继承此基类,实现 execute 方法。
"""Base class for workflow nodes.
All node types should inherit from this class and implement the `execute` method.
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
"""初始化节点
"""Initialize the node.
Args:
node_config: 节点配置
workflow_config: 工作流配置
node_config: Configuration of the node.
workflow_config: Configuration of the workflow.
"""
self.node_config = node_config
self.workflow_config = workflow_config
@@ -94,7 +78,27 @@ class BaseNode(ABC):
self.config = node_config.get("config") or {}
self.error_handling = node_config.get("error_handling") or {}
self.variable_updater = False
self.variable_change_able = False
@cached_property
def output_types(self) -> dict[str, VariableType]:
"""Returns the output variable types of the node.
This property is cached to avoid recomputation.
"""
return self._output_types()
@abstractmethod
def _output_types(self) -> dict[str, VariableType]:
"""Defines output variable types for the node.
Subclasses must override this method to declare the variables
produced by the node and their corresponding types.
Returns:
A mapping from output variable names to ``VariableType``.
"""
return {}
def check_activate(self, state: WorkflowState):
"""Check if the current node is activated in the workflow state.
@@ -136,92 +140,84 @@ class BaseNode(ABC):
}
@abstractmethod
async def execute(self, state: WorkflowState) -> Any:
"""执行节点业务逻辑(非流式)
节点只需要返回业务结果,不需要关心输出格式、时间统计等。
BaseNode 会自动包装成标准格式。
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""Executes the node business logic (non-streaming).
The node implementation should only return the business result.
It does not need to handle output formatting, timing, or statistics.
The ``BaseNode`` will automatically wrap the result into a standard
response format.
Args:
state: 工作流状态
state: The current workflow state.
variable_pool: The variable pool used for reading and writing variables.
Returns:
业务结果(任意类型)
Examples:
>>> # LLM 节点
>>> "这是 AI 的回复"
>>> # Transform 节点
>>> {"processed_data": [...]}
>>> # Start/End 节点
>>> {"message": "开始", "conversation_id": "xxx"}
The business result produced by the node. The return value can be
of any type.
"""
pass
async def execute_stream(self, state: WorkflowState):
"""执行节点业务逻辑(流式)
子类可以重写此方法以支持流式输出。
默认实现:执行非流式方法并一次性返回。
节点需要:
1. yield 中间结果(如文本片段)
2. 最后 yield 一个特殊的完成标记:{"__final__": True, "result": final_result}
Args:
state: 工作流状态
Yields:
业务数据chunk或完成标记
Examples:
# 流式 LLM 节点
full_response = ""
async for chunk in llm.astream(prompt):
full_response += chunk
yield chunk # yield 文本片段
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
"""Executes the node business logic in streaming mode.
# 最后 yield 完成标记
yield {"__final__": True, "result": AIMessage(content=full_response)}
Subclasses may override this method to support streaming output.
The default implementation executes the non-streaming method and
yields a single final result.
For streaming execution, a node implementation should:
1. Yield intermediate results (e.g. text chunks).
2. Yield a final completion marker in the following format:
``{"__final__": True, "result": final_result}``.
Args:
state: The current workflow state.
variable_pool: The variable pool used for reading and writing variables.
Yields:
Business data chunks or a final completion marker.
"""
result = await self.execute(state)
# 默认实现:直接 yield 完成标记
result = await self.execute(state, variable_pool)
# Default implementation: yield a single final completion marker.
yield {"__final__": True, "result": result}
def supports_streaming(self) -> bool:
"""节点是否支持流式输出
"""Returns whether the node supports streaming output.
A node is considered to support streaming if its class overrides
the ``execute_stream`` method. If the default implementation from
``BaseNode`` is used, streaming is not supported.
Returns:
是否支持流式输出
True if the node supports streaming output, False otherwise.
"""
# 检查子类是否重写了 execute_stream 方法
# Check whether the subclass overrides the execute_stream method.
return self.__class__.execute_stream is not BaseNode.execute_stream
def get_timeout(self) -> int:
"""获取超时时间(秒)
@staticmethod
def get_timeout() -> int:
"""Returns the execution timeout in seconds.
Returns:
超时时间
The timeout duration, in seconds.
"""
return settings.WORKFLOW_NODE_TIMEOUT
# return self.error_handling.get("timeout", 60)
async def run(self, state: WorkflowState) -> dict[str, Any]:
"""执行节点(带错误处理和输出包装,非流式)
这个方法由 Executor 调用,负责:
1. 时间统计
2. 调用节点的 execute() 方法
3. 将业务结果包装成标准输出格式
4. 错误处理
async def run(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""Runs the node with error handling and output wrapping (non-streaming).
This method is invoked by the Executor and is responsible for:
1. Execution time measurement.
2. Invoking the node's ``execute()`` method.
3. Wrapping the business result into a standardized output format.
4. Handling execution errors.
Args:
state: 工作流状态
state: The current workflow state.
variable_pool: The variable pool used for reading and writing variables.
Returns:
标准化的状态更新字典
A standardized state update dictionary.
"""
if not self.check_activate(state):
return self.trans_activate(state)
@@ -233,70 +229,78 @@ class BaseNode(ABC):
timeout = self.get_timeout()
try:
# 调用节点的业务逻辑
# Invoke the node business logic.
business_result = await asyncio.wait_for(
self.execute(state),
self.execute(state, variable_pool),
timeout=timeout
)
elapsed_time = time.time() - start_time
# 提取处理后的输出(调用子类的 _extract_output
# Extract processed outputs using subclass-defined logic.
extracted_output = self._extract_output(business_result)
# 包装成标准输出格式
wrapped_output = self._wrap_output(business_result, elapsed_time, state)
# Wrap the business result into the standard output format.
wrapped_output = self._wrap_output(business_result, elapsed_time, state, variable_pool)
# 将提取后的输出存储到运行时变量中(供后续节点快速访问)
# 如果提取后的输出是字典,拆包存储;否则存储为 output 字段
if isinstance(extracted_output, dict):
runtime_var = extracted_output
else:
runtime_var = {"output": extracted_output}
# Store extracted outputs as runtime variables for downstream nodes.
if extracted_output is not None:
runtime_vars = extracted_output
if not isinstance(extracted_output, dict):
runtime_vars = {"output": extracted_output}
for k, v in runtime_vars.items():
await variable_pool.new(self.node_id, k, v, self.output_types[k], mut=self.variable_change_able)
# 返回包装后的输出和运行时变量
# Return the wrapped output along with activation state updates.
return {
**wrapped_output,
"messages": state["messages"],
"runtime_vars": {
self.node_id: runtime_var
},
"looping": state["looping"]
} | self.trans_activate(state)
except TimeoutError:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
return self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
logger.error(
f"Node {self.node_id} execution timed out ({timeout} seconds)."
)
return self._wrap_error(
f"Node execution timed out ({timeout} seconds).",
elapsed_time,
state,
variable_pool,
)
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
return self._wrap_error(str(e), elapsed_time, state)
logger.error(
f"Node {self.node_id} execution failed: {e}",
exc_info=True,
)
return self._wrap_error(str(e), elapsed_time, state, variable_pool)
async def run_stream(
self, state: WorkflowState,
variable_pool: VariablePool
) -> AsyncGenerator[dict[str, Any], Any]:
"""Executes the node with error handling and output wrapping (streaming).
async def run_stream(self, state: WorkflowState) -> AsyncGenerator[dict[str, Any], Any]:
"""Execute node with error handling and output wrapping (streaming)
This method is called by the Executor and is responsible for:
1. Time tracking
2. Calling the node's execute_stream() method
3. Using LangGraph's stream writer to send chunks
4. Updating streaming buffer in state for downstream nodes
5. Wrapping business data into standard output format
6. Error handling
Special handling for End nodes:
- End nodes don't send chunks via writer (prefix and LLM content already sent)
- End nodes only yield suffix for final result assembly
1. Tracking execution time.
2. Calling the node's ``execute_stream()`` method.
3. Sending streaming chunks via LangGraph's stream writer.
4. Updating activation-related state for downstream nodes.
5. Wrapping business data into a standardized output format.
6. Handling execution errors.
Args:
state: Workflow state
state: The current workflow state.
variable_pool: The variable pool used for reading and writing variables.
Yields:
State updates with streaming buffer and final result
Incremental state updates, including activation state changes and
the final wrapped result.
"""
if not self.check_activate(state):
yield self.trans_activate(state)
logger.info(f"jump node: {self.node_id}")
logger.debug(f"jump node: {self.node_id}")
return
import time
@@ -317,7 +321,7 @@ class BaseNode(ABC):
# Stream chunks in real-time
loop_start = asyncio.get_event_loop().time()
async for item in self.execute_stream(state):
async for item in self.execute_stream(state, variable_pool):
# Check timeout
if asyncio.get_event_loop().time() - loop_start > timeout:
raise TimeoutError()
@@ -332,7 +336,7 @@ class BaseNode(ABC):
chunks.append(content)
# Send chunks for all nodes (including End nodes for suffix)
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {content[:50]}...")
logger.debug(f"Node {self.node_id} sent chunk #{chunk_count}: {content[:50]}...")
# 1. Send via stream writer (for real-time client updates)
writer({
@@ -344,27 +348,26 @@ class BaseNode(ABC):
elapsed_time = time.time() - start_time
logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
logger.info(f"Node {self.node_id} streaming execution finished, "
f"time elapsed: {elapsed_time:.2f}s, chunks: {chunk_count}")
# Extract processed output (call subclass's _extract_output)
extracted_output = self._extract_output(final_result)
# Wrap final result
final_output = self._wrap_output(final_result, elapsed_time, state)
final_output = self._wrap_output(final_result, elapsed_time, state, variable_pool)
# Store extracted output in runtime variables (for quick access by subsequent nodes)
if isinstance(extracted_output, dict):
runtime_var = extracted_output
else:
runtime_var = {"output": extracted_output}
if extracted_output is not None:
runtime_vars = extracted_output
if not isinstance(extracted_output, dict):
runtime_vars = {"output": extracted_output}
for k, v in runtime_vars.items():
await variable_pool.new(self.node_id, k, v, self.output_types[k], mut=self.variable_change_able)
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
state_update = {
**final_output,
"messages": state["messages"],
"runtime_vars": {
self.node_id: runtime_var
},
"looping": state["looping"]
}
@@ -374,41 +377,49 @@ class BaseNode(ABC):
except TimeoutError:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行超时 ({timeout}s)")
error_output = self._wrap_error(f"节点执行超时 ({timeout}s)", elapsed_time, state)
logger.error(f"Node {self.node_id} execution timed out ({timeout}s)")
error_output = self._wrap_error(
f"Node execution timed out ({timeout}s)",
elapsed_time,
state,
variable_pool
)
yield error_output
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
error_output = self._wrap_error(str(e), elapsed_time, state)
logger.error(f"Node {self.node_id} execution failed: {e}", exc_info=True)
error_output = self._wrap_error(str(e), elapsed_time, state, variable_pool)
yield error_output
def _wrap_output(
self,
business_result: Any,
elapsed_time: float,
state: WorkflowState
state: WorkflowState,
variable_pool: VariablePool
) -> dict[str, Any]:
"""将业务结果包装成标准输出格式
Args:
business_result: 节点返回的业务结果
elapsed_time: 执行耗时
state: 工作流状态
Returns:
标准化的状态更新字典
"""
# 提取输入数据(用于记录)
input_data = self._extract_input(state)
"""Wraps the business result into a standardized node output format.
# 提取 token 使用情况(如果有)
Args:
business_result: The result returned by the node's business logic.
elapsed_time: Time elapsed during node execution (in seconds).
state: The current workflow state.
variable_pool: The variable pool used for reading and writing variables.
Returns:
A dictionary representing the standardized state update for this node,
including node outputs, input, output, elapsed time, token usage, and status.
"""
# Extract input data (for logging or audit purposes)
input_data = self._extract_input(state, variable_pool)
# Extract token usage information (if applicable)
token_usage = self._extract_token_usage(business_result)
# 提取实际输出(去除元数据)
# Extract actual output (strip any metadata)
output = self._extract_output(business_result)
# 构建标准节点输出
# Construct standardized node output
node_output = {
"node_id": self.node_id,
"node_type": self.node_type,
@@ -423,8 +434,6 @@ class BaseNode(ABC):
final_output = {
"node_outputs": {self.node_id: node_output},
}
if self.variable_updater:
final_output = final_output | {"variables": state["variables"]}
return final_output
@@ -432,25 +441,33 @@ class BaseNode(ABC):
self,
error_message: str,
elapsed_time: float,
state: WorkflowState
state: WorkflowState,
variable_pool: VariablePool
) -> dict[str, Any]:
"""将错误包装成标准输出格式
"""Wraps an error into a standardized node output format.
This method handles both cases:
- If an error edge is defined, the workflow can continue to the error handling node.
- If no error edge exists, the workflow is stopped by raising an exception.
Args:
error_message: 错误信息
elapsed_time: 执行耗时
state: 工作流状态
error_message: The error message describing the failure.
elapsed_time: Time elapsed during node execution (in seconds).
state: The current workflow state.
variable_pool: The variable pool used for reading and writing variables.
Returns:
标准化的状态更新字典
A dictionary representing the standardized state update for this node
when an error edge exists. If no error edge exists, this method
raises an exception to stop the workflow.
"""
# 查找错误边
# Check if the node has an error edge defined
error_edge = self._find_error_edge()
# 提取输入数据
input_data = self._extract_input(state)
# Extract input data (for logging or audit purposes)
input_data = self._extract_input(state, variable_pool)
# 构建错误输出
# Construct the standardized node output for the error
node_output = {
"node_id": self.node_id,
"node_type": self.node_type,
@@ -464,9 +481,9 @@ class BaseNode(ABC):
}
if error_edge:
# 有错误边:记录错误并继续
# If an error edge exists, log a warning and continue to error node
logger.warning(
f"节点 {self.node_id} 执行失败,跳转到错误处理节点: {error_edge['target']}"
f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}"
)
return {
"node_outputs": {
@@ -476,198 +493,161 @@ class BaseNode(ABC):
"error_node": self.node_id
}
else:
# If no error edge, send the error via stream writer and stop the workflow
writer = get_stream_writer()
writer({
"type": "node_error",
**node_output
})
# 无错误边:抛出异常停止工作流
logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}")
raise Exception(f"节点 {self.node_id} 执行失败: {error_message}")
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""Extracts the input data for this node (used for logging or audit).
Subclasses may override this method to customize what input data
should be recorded.
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
"""提取节点输入数据(用于记录)
子类可以重写此方法来自定义输入记录。
Args:
state: 工作流状态
state: The current workflow state.
variable_pool: The variable pool used for reading and writing variables.
Returns:
输入数据字典
A dictionary containing the node's input data.
"""
# 默认返回配置
# Default implementation returns the node configuration
return {"config": self.config}
def _extract_output(self, business_result: Any) -> Any:
"""从业务结果中提取实际输出
子类可以重写此方法来自定义输出提取。
"""Extracts the actual output from the business result.
Subclasses may override this method to customize how the node's
output is extracted.
Args:
business_result: 业务结果
business_result: The result returned by the node's business logic.
Returns:
实际输出
The actual output extracted from the business result.
"""
# 默认直接返回业务结果
# Default implementation returns the business result directly
return business_result
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
"""从业务结果中提取 token 使用情况
子类可以重写此方法来提取 token 信息。
"""Extracts token usage information from the business result.
Subclasses may override this method to extract token usage statistics
(e.g., for LLM nodes).
Args:
business_result: 业务结果
business_result: The result returned by the node's business logic.
Returns:
token 使用情况或 None
A dictionary mapping token types to counts, or None if not applicable.
"""
# 默认返回 None
# Default implementation returns None
return None
def _find_error_edge(self) -> dict[str, Any] | None:
"""查找错误边
"""Finds the error edge for this node, if any.
An error edge is used to redirect workflow execution when this node
fails.
Returns:
错误边配置或 None
A dictionary representing the error edge configuration if it exists,
or None if no error edge is defined.
"""
for edge in self.workflow_config.get("edges", []):
if edge.get("source") == self.node_id and edge.get("type") == "error":
return edge
return None
def _render_template(self, template: str, state: WorkflowState | None, strict: bool = True) -> str:
"""渲染模板
支持的变量命名空间:
- sys.xxx: 系统变量message, execution_id, workspace_id, user_id, conversation_id
- conv.xxx: 会话变量(跨多轮对话保持)
- node_id.xxx: 节点输出
@staticmethod
def _render_template(template: str, variable_pool: VariablePool, strict: bool = True) -> str:
"""Renders a template string using the provided variable pool.
Supported variable namespaces:
- sys.xxx: System variables (e.g., message, execution_id, workspace_id,
user_id, conversation_id)
- conv.xxx: Conversation variables (persist across multiple turns)
- node_id.xxx: Node outputs
Args:
template: 模板字符串
state: 工作流状态
template: The template string to render.
variable_pool: The variable pool containing system, conversation, and
node variables.
strict: If True, missing variables will raise an error; if False,
missing variables are ignored.
Returns:
渲染后的字符串
The rendered string with all variables substituted.
"""
from app.core.workflow.template_renderer import render_template
# 处理 state 为 None 的情况
if state is None:
state = {}
# 使用变量池获取变量
pool = VariablePool(state)
# 构建完整的 variables 结构
variables = {
"sys": pool.get_all_system_vars(),
"conv": pool.get_all_conversation_vars()
}
return render_template(
template=template,
variables=variables,
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
conv_vars=variable_pool.get_all_conversation_vars(),
node_outputs=variable_pool.get_all_node_outputs(),
system_vars=variable_pool.get_all_system_vars(),
strict=strict
)
def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool:
"""评估条件表达式
支持的变量命名空间:
- sys.xxx: 系统变量
- conv.xxx: 会话变量
- node_id.xxx: 节点输出
@staticmethod
def _evaluate_condition(expression: str, variable_pool: VariablePool) -> bool:
"""Evaluates a conditional expression using the provided variable pool.
Supported variable namespaces:
- sys.xxx: System variables
- conv.xxx: Conversation variables
- node_id.xxx: Node outputs
Args:
expression: 条件表达式
state: 工作流状态
expression: The conditional expression to evaluate.
variable_pool: The variable pool containing system, conversation, and
node variables.
Returns:
布尔值结果
The boolean result of evaluating the expression.
"""
from app.core.workflow.expression_evaluator import evaluate_condition
# 处理 state 为 None 的情况
if state is None:
state = {}
# 使用变量池获取变量
pool = VariablePool(state)
# 构建完整的 variables 结构(包含 sys 和 conv
variables = {
"sys": pool.get_all_system_vars(),
"conv": pool.get_all_conversation_vars()
}
return evaluate_condition(
expression=expression,
variables=variables,
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars()
conv_var=variable_pool.get_all_conversation_vars(),
node_outputs=variable_pool.get_all_node_outputs(),
system_vars=variable_pool.get_all_system_vars()
)
def get_variable_pool(self, state: WorkflowState) -> VariablePool:
"""获取变量池实例
VariablePool 是轻量级包装器,只持有 state 的引用,创建成本极低。
Args:
state: 工作流状态
Returns:
VariablePool 实例
Examples:
>>> pool = self.get_variable_pool(state)
>>> message = pool.get("sys.message")
>>> llm_output = pool.get("llm_qa.output")
"""
return VariablePool(state)
@staticmethod
def get_variable(
self,
selector: list[str] | str,
state: WorkflowState,
default: Any = None
selector: str,
variable_pool: VariablePool,
default: Any = None,
strict: bool = True
) -> Any:
"""获取变量值(便捷方法)
Args:
selector: 变量选择器
state: 工作流状态
default: 默认值
Returns:
变量值
Examples:
>>> message = self.get_variable("sys.message", state)
>>> output = self.get_variable(["llm_qa", "output"], state)
>>> custom = self.get_variable("var.custom", state, default="默认值")
"""
pool = VariablePool(state)
return pool.get(selector, default=default)
"""Retrieves a variable value from the variable pool (convenience method).
def has_variable(self, selector: list[str] | str, state: WorkflowState) -> bool:
"""检查变量是否存在(便捷方法)
Args:
selector: 变量选择器
state: 工作流状态
selector: The variable selector (can be namespaced, e.g., sys.xxx, conv.xxx, node_id.xxx).
variable_pool: The variable pool from which to fetch the value.
default: The default value to return if the variable does not exist.
strict: If True, raise an error when the variable is missing; if False, return the default.
Returns:
变量是否存在
Examples:
>>> if self.has_variable("llm_qa.output", state):
... output = self.get_variable("llm_qa.output", state)
The value of the selected variable, or the default if not found and strict is False.
"""
pool = VariablePool(state)
return pool.has(selector)
return variable_pool.get_value(selector, default, strict=strict)
@staticmethod
def has_variable(selector: str, variable_pool: VariablePool) -> bool:
"""Checks whether a variable exists in the variable pool (convenience method).
Args:
selector: The variable selector (can be namespaced, e.g., sys.xxx, conv.xxx, node_id.xxx).
variable_pool: The variable pool to check.
Returns:
True if the variable exists in the pool, False otherwise.
"""
return variable_pool.has(selector)

View File

@@ -2,6 +2,8 @@ import logging
from typing import Any
from app.core.workflow.nodes import BaseNode, WorkflowState
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -14,15 +16,19 @@ class BreakNode(BaseNode):
to False, signaling the outer loop runtime to terminate further iterations.
"""
async def execute(self, state: WorkflowState) -> Any:
def _output_types(self) -> dict[str, VariableType]:
return {}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
Execute the break node.
Args:
state: Current workflow state, including loop control flags.
variable_pool: Pool of variables for the workflow.
Effects:
- Sets 'looping' in the state to False to stop the loop.
- Sets 'looping' in the state too False to stop the loop.
- Logs the action for debugging purposes.
Returns:

View File

@@ -1,7 +1,8 @@
from typing import Literal
from pydantic import Field, BaseModel
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.variable.base_variable import VariableType
class InputVariable(BaseModel):
@@ -44,7 +45,7 @@ class CodeNodeConfig(BaseNodeConfig):
description="code content"
)
language: Literal['python3', 'nodejs'] = Field(
language: Literal['python3', 'javascript'] = Field(
...,
description="language"
)

View File

@@ -9,8 +9,9 @@ from typing import Any
import httpx
from app.core.workflow.nodes import BaseNode, WorkflowState
from app.core.workflow.nodes.base_config import VariableType
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.nodes.code.config import CodeNodeConfig
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -52,6 +53,12 @@ class CodeNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: CodeNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
output_dict = {}
for output in self.typed_config.output_variables:
output_dict[output.name] = output.type
return output_dict
def extract_result(self, content: str):
match = re.search(r'<<RESULT>>(.*?)<<RESULT>>', content, re.DOTALL)
if match:
@@ -92,11 +99,11 @@ class CodeNode(BaseNode):
else:
raise RuntimeError("The output of main must be a dictionary")
async def execute(self, state: WorkflowState) -> Any:
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
self.typed_config = CodeNodeConfig(**self.config)
input_variable_dict = {}
for input_variable in self.typed_config.input_variables:
input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, state)
input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, variable_pool)
code = base64.b64decode(
self.typed_config.code
@@ -110,7 +117,7 @@ class CodeNode(BaseNode):
code=code,
inputs_variable=input_variable_dict,
)
elif self.typed_config.language == 'nodejs':
elif self.typed_config.language == 'javascript':
final_script = NODEJS_SCRIPT_TEMPLATE.substitute(
code=code,
inputs_variable=input_variable_dict,

View File

@@ -8,7 +8,6 @@ from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
from app.core.workflow.nodes.base_config import (
BaseNodeConfig,
VariableDefinition,
VariableType,
)
from app.core.workflow.nodes.code.config import CodeNodeConfig
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
@@ -23,21 +22,18 @@ from app.core.workflow.nodes.parameter_extractor.config import ParameterExtracto
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
from app.core.workflow.nodes.start.config import StartNodeConfig
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.transform.config import TransformNodeConfig
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
__all__ = [
# 基础类
"BaseNodeConfig",
"VariableDefinition",
"VariableType",
# 节点配置
"StartNodeConfig",
"EndNodeConfig",
"LLMNodeConfig",
"MessageConfig",
"AgentNodeConfig",
"TransformNodeConfig",
"IfElseNodeConfig",
"KnowledgeRetrievalNodeConfig",
"AssignerNodeConfig",

View File

@@ -2,7 +2,8 @@ from typing import Any
from pydantic import Field, BaseModel, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
@@ -127,4 +128,9 @@ class IterationNodeConfig(BaseNodeConfig):
description="Output of the loop iteration"
)
output_type: VariableType = Field(
...,
description="Data type of the loop iteration output"
)

View File

@@ -7,6 +7,7 @@ from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -28,6 +29,8 @@ class IterationRuntime:
node_id: str,
config: dict[str, Any],
state: WorkflowState,
variable_pool: VariablePool,
child_variable_pool: VariablePool,
):
"""
Initialize the iteration runtime.
@@ -44,11 +47,13 @@ class IterationRuntime:
self.node_id = node_id
self.typed_config = IterationNodeConfig(**config)
self.looping = True
self.variable_pool = variable_pool
self.child_variable_pool = child_variable_pool
self.output_value = None
self.result: list = []
def _init_iteration_state(self, item, idx):
async def _init_iteration_state(self, item, idx):
"""
Initialize a per-iteration copy of the workflow state.
@@ -62,10 +67,9 @@ class IterationRuntime:
loopstate = WorkflowState(
**self.state
)
loopstate["runtime_vars"][self.node_id] = {
"item": item,
"index": idx,
}
self.child_variable_pool.copy(self.variable_pool)
await self.child_variable_pool.new(self.node_id, "item", item, VariableType.type_map(item), mut=True)
await self.child_variable_pool.new(self.node_id, "index", item, VariableType.type_map(item), mut=True)
loopstate["node_outputs"][self.node_id] = {
"item": item,
"index": idx,
@@ -74,6 +78,11 @@ class IterationRuntime:
loopstate["activate"][self.start_id] = True
return loopstate
def merge_conv_vars(self):
self.variable_pool.get_all_conversation_vars().update(
self.child_variable_pool.get_all_conversation_vars()
)
async def run_task(self, item, idx):
"""
Execute a single iteration asynchronously.
@@ -82,8 +91,8 @@ class IterationRuntime:
item: The input element for this iteration.
idx: The index of this iteration.
"""
result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
output = VariablePool(result).get(self.output_value)
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
output = self.child_variable_pool.get_value(self.output_value)
if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output)
else:
@@ -125,7 +134,7 @@ class IterationRuntime:
input_expression = re.sub(pattern, r"\1", self.typed_config.input).strip()
self.output_value = re.sub(pattern, r"\1", self.typed_config.output).strip()
array_obj = VariablePool(self.state).get(input_expression)
array_obj = self.variable_pool.get_value(input_expression)
if not isinstance(array_obj, list):
raise RuntimeError("Cannot iterate over a non-list variable")
child_state = []
@@ -137,14 +146,16 @@ class IterationRuntime:
logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}")
idx += self.typed_config.parallel_count
child_state.extend(await asyncio.gather(*tasks))
self.merge_conv_vars()
else:
# Execute iterations sequentially
while idx < len(array_obj) and self.looping:
logger.info(f"Iteration node {self.node_id}: running")
item = array_obj[idx]
result = await self.graph.ainvoke(self._init_iteration_state(item, idx))
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
child_state.append(result)
output = VariablePool(result).get(self.output_value)
output = self.child_variable_pool.get_value(self.output_value)
self.merge_conv_vars()
if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output)
else:

View File

@@ -31,6 +31,8 @@ class LoopRuntime:
node_id: str,
config: dict[str, Any],
state: WorkflowState,
variable_pool: VariablePool,
child_variable_pool: VariablePool
):
"""
Initialize the loop runtime executor.
@@ -40,6 +42,8 @@ class LoopRuntime:
node_id: The unique identifier of the loop node in the workflow.
config: Raw configuration dictionary for the loop node.
state: The current workflow state before entering the loop.
variable_pool: A VariablePool instance for accessing and modifying workflow variables.
child_variable_pool: A VariablePool instance for managing child node outputs.
"""
self.start_id = start_id
self.graph = graph
@@ -47,8 +51,10 @@ class LoopRuntime:
self.node_id = node_id
self.typed_config = LoopNodeConfig(**config)
self.looping = True
self.variable_pool = variable_pool
self.child_variable_pool = child_variable_pool
def _init_loop_state(self):
async def _init_loop_state(self):
"""
Initialize workflow state for loop execution.
@@ -62,33 +68,35 @@ class LoopRuntime:
Returns:
WorkflowState: A prepared workflow state used for loop execution.
"""
pool = VariablePool(self.state)
# 循环变量
self.state["runtime_vars"][self.node_id] = {
variable.name: evaluate_expression(
expression=variable.value,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
)
if variable.input_type == ValueInputType.VARIABLE
else TypeTransformer.transform(variable.value, variable.type)
for variable in self.typed_config.cycle_vars
}
self.state["node_outputs"][self.node_id] = {
variable.name: evaluate_expression(
expression=variable.value,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
)
if variable.input_type == ValueInputType.VARIABLE
else TypeTransformer.transform(variable.value, variable.type)
for variable in self.typed_config.cycle_vars
}
self.child_variable_pool.copy(self.variable_pool)
for variable in self.typed_config.cycle_vars:
if variable.input_type == ValueInputType.VARIABLE:
value = evaluate_expression(
expression=variable.value,
conv_var=self.variable_pool.get_all_conversation_vars(),
node_outputs=self.variable_pool.get_all_node_outputs(),
system_vars=self.variable_pool.get_all_system_vars(),
)
else:
value = TypeTransformer.transform(variable.value, variable.type)
await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True)
loopstate = WorkflowState(
**self.state
)
loopstate["node_outputs"][self.node_id] = {
variable.name: evaluate_expression(
expression=variable.value,
conv_var=self.variable_pool.get_all_conversation_vars(),
node_outputs=self.variable_pool.get_all_node_outputs(),
system_vars=self.variable_pool.get_all_system_vars(),
)
if variable.input_type == ValueInputType.VARIABLE
else TypeTransformer.transform(variable.value, variable.type)
for variable in self.typed_config.cycle_vars
}
loopstate["looping"] = 1
loopstate["activate"][self.start_id] = True
return loopstate
@@ -134,7 +142,12 @@ class LoopRuntime:
case _:
raise ValueError(f"Invalid condition: {operator}")
def evaluate_conditional(self, state) -> bool:
def merge_conv_vars(self):
self.variable_pool.variables["conv"].update(
self.child_variable_pool.variables.get("conv", {})
)
def evaluate_conditional(self) -> bool:
"""
Evaluate the loop continuation condition at runtime.
@@ -143,18 +156,15 @@ class LoopRuntime:
- Evaluates each comparison expression immediately
- Combines results using the configured logical operator (AND / OR)
Args:
state: The current workflow state during loop execution.
Returns:
bool: True if the loop should continue, False otherwise.
"""
conditions = []
for expression in self.typed_config.condition.expressions:
left_value = VariablePool(state).get(expression.left)
left_value = self.child_variable_pool.get_value(expression.left)
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
VariablePool(state),
self.child_variable_pool,
expression.left,
expression.right,
expression.input_type
@@ -177,16 +187,18 @@ class LoopRuntime:
Returns:
dict[str, Any]: The final runtime variables of this loop node.
"""
loopstate = self._init_loop_state()
loopstate = await self._init_loop_state()
loop_time = self.typed_config.max_loop
child_state = []
while self.evaluate_conditional(loopstate) and self.looping and loop_time > 0:
while not self.evaluate_conditional() and self.looping and loop_time > 0:
logger.info(f"loop node {self.node_id}: running")
result = await self.graph.ainvoke(loopstate)
child_state.append(result)
self.merge_conv_vars()
if result["looping"] == 2:
self.looping = False
loop_time -= 1
logger.info(f"loop node {self.node_id}: execution completed")
return loopstate["runtime_vars"][self.node_id] | {"__child_state": child_state}
return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state}

View File

@@ -6,9 +6,12 @@ from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig
from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime
from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -35,9 +38,38 @@ class CycleGraphNode(BaseNode):
self.start_node_id = None # ID of the start node within the cycle
self.graph: StateGraph | CompiledStateGraph | None = None
self.child_variable_pool: VariablePool | None = None
self.build_graph()
self.iteration_flag = True
def _output_types(self) -> dict[str, VariableType]:
outputs = {"__child_state": VariableType.ARRAY_OBJECT}
if self.node_type == NodeType.LOOP:
# Loop node outputs the final state of the loop
config = LoopNodeConfig(**self.config)
for var_def in config.cycle_vars:
outputs[var_def.name] = var_def.type
return outputs
elif self.node_type == NodeType.ITERATION:
# Iteration node outputs the processed collection
config = IterationNodeConfig(**self.config)
if config.output_type in [
VariableType.ARRAY_FILE,
VariableType.ARRAY_STRING,
VariableType.NUMBER,
VariableType.ARRAY_OBJECT,
VariableType.BOOLEAN
]:
if config.flatten:
outputs['output'] = config.output_type
else:
outputs['output'] = VariableType.ARRAY_STRING
else:
outputs['output'] = VariableType(f"array[{config.output_type}]")
return outputs
else:
raise KeyError(f"Valid Cycle Node Type - {self.node_type}")
def pure_cycle_graph(self) -> tuple[list, list]:
"""
Extract cycle-scoped nodes and internal edges from the workflow configuration.
@@ -103,17 +135,20 @@ class CycleGraphNode(BaseNode):
"""
from app.core.workflow.graph_builder import GraphBuilder
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.child_variable_pool = VariablePool()
builder = GraphBuilder(
{
"nodes": self.cycle_nodes,
"edges": self.cycle_edges,
},
subgraph=True
subgraph=True,
variable_pool=self.child_variable_pool
)
self.start_node_id = builder.start_node_id
self.graph = builder.build()
self.child_variable_pool = builder.variable_pool
async def execute(self, state: WorkflowState) -> Any:
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
Execute the cycle node at runtime.
@@ -123,6 +158,7 @@ class CycleGraphNode(BaseNode):
Args:
state: The current workflow state when entering the cycle node.
variable_pool: Variable Pool
Returns:
Any: The runtime result produced by the loop or iteration executor.
@@ -137,6 +173,8 @@ class CycleGraphNode(BaseNode):
node_id=self.node_id,
config=self.config,
state=state,
variable_pool=variable_pool,
child_variable_pool=self.child_variable_pool,
).run()
if self.node_type == NodeType.ITERATION:
return await IterationRuntime(
@@ -145,5 +183,7 @@ class CycleGraphNode(BaseNode):
node_id=self.node_id,
config=self.config,
state=state,
variable_pool=variable_pool,
child_variable_pool=self.child_variable_pool
).run()
raise RuntimeError("Unknown cycle node type")

View File

@@ -2,7 +2,8 @@
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
from app.core.workflow.variable.base_variable import VariableType
class EndNodeConfig(BaseNodeConfig):

View File

@@ -7,6 +7,8 @@ End 节点实现
import logging
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -17,12 +19,18 @@ class EndNode(BaseNode):
工作流的结束节点,根据配置的模板输出最终结果。
支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。
"""
def _output_types(self) -> dict[str, VariableType]:
"""声明此节点的输出类型"""
return {
"output": VariableType.STRING
}
async def execute(self, state: WorkflowState) -> str:
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> str:
"""执行 end 节点业务逻辑
Args:
state: 工作流状态
variable_pool: 变量池
Returns:
最终输出字符串
@@ -34,7 +42,7 @@ class EndNode(BaseNode):
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template:
output = self._render_template(output_template, state, strict=False)
output = self._render_template(output_template, variable_pool, strict=False)
else:
output = ""

View File

@@ -9,7 +9,6 @@ class NodeType(StrEnum):
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
IF_ELSE = "if-else"
CODE = "code"
TRANSFORM = "transform"
QUESTION_CLASSIFIER = "question-classifier"
HTTP_REQUEST = "http-request"
TOOL = "tool"

View File

@@ -10,6 +10,8 @@ from httpx import AsyncClient, Response, Timeout
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__file__)
@@ -34,6 +36,14 @@ class HttpRequestNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: HttpRequestNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
return {
"body": VariableType.STRING,
"status_code": VariableType.NUMBER,
"headers": VariableType.OBJECT,
"output": VariableType.STRING
}
def _build_timeout(self) -> Timeout:
"""
Build httpx Timeout configuration.
@@ -50,7 +60,7 @@ class HttpRequestNode(BaseNode):
)
return timeout
def _build_auth(self, state: WorkflowState) -> dict[str, str]:
def _build_auth(self, variable_pool: VariablePool) -> dict[str, str]:
"""
Build authentication-related HTTP headers.
@@ -58,12 +68,12 @@ class HttpRequestNode(BaseNode):
the current workflow runtime state.
Args:
state: Current workflow runtime state.
variable_pool: Variable Pool
Returns:
A dictionary of HTTP headers used for authentication.
"""
api_key = self._render_template(self.typed_config.auth.api_key, state)
api_key = self._render_template(self.typed_config.auth.api_key, variable_pool)
match self.typed_config.auth.auth_type:
case HttpAuthType.NONE:
return {}
@@ -82,7 +92,7 @@ class HttpRequestNode(BaseNode):
case _:
raise RuntimeError(f"Auth type not supported: {self.typed_config.auth.auth_type}")
def _build_header(self, state: WorkflowState) -> dict[str, str]:
def _build_header(self, variable_pool: VariablePool) -> dict[str, str]:
"""
Build HTTP request headers.
@@ -90,10 +100,10 @@ class HttpRequestNode(BaseNode):
"""
headers = {}
for key, value in self.typed_config.headers.items():
headers[self._render_template(key, state)] = self._render_template(value, state)
headers[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool)
return headers
def _build_params(self, state: WorkflowState) -> dict[str, str]:
def _build_params(self, variable_pool: VariablePool) -> dict[str, str]:
"""
Build URL query parameters.
@@ -101,10 +111,10 @@ class HttpRequestNode(BaseNode):
"""
params = {}
for key, value in self.typed_config.params.items():
params[self._render_template(key, state)] = self._render_template(value, state)
params[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool)
return params
def _build_content(self, state) -> dict[str, Any]:
def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]:
"""
Build HTTP request body arguments for httpx request methods.
@@ -120,13 +130,13 @@ class HttpRequestNode(BaseNode):
return {}
case HttpContentType.JSON:
content["json"] = json.loads(self._render_template(
self.typed_config.body.data, state
self.typed_config.body.data, variable_pool
))
case HttpContentType.FROM_DATA:
data = {}
for item in self.typed_config.body.data:
if item.type == "text":
data[self._render_template(item.key, state)] = self._render_template(item.value, state)
data[self._render_template(item.key, variable_pool)] = self._render_template(item.value, variable_pool)
elif item.type == "file":
# TODO: File support (Feature)
pass
@@ -136,11 +146,11 @@ class HttpRequestNode(BaseNode):
pass
case HttpContentType.WWW_FORM:
content["data"] = json.loads(self._render_template(
json.dumps(self.typed_config.body.data), state
json.dumps(self.typed_config.body.data), variable_pool
))
case HttpContentType.RAW:
content["content"] = self._render_template(self.typed_config.body.data, state)
content["content"] = self._render_template(self.typed_config.body.data, variable_pool)
case _:
raise RuntimeError(f"Content type not supported: {self.typed_config.body.content_type}")
return content
@@ -165,7 +175,7 @@ class HttpRequestNode(BaseNode):
case _:
raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}")
async def execute(self, state: WorkflowState) -> dict | str:
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict | str:
"""
Execute the HTTP request node.
@@ -176,6 +186,7 @@ class HttpRequestNode(BaseNode):
Args:
state: Current workflow runtime state.
variable_pool: Variable Pool
Returns:
- dict: Serialized HttpRequestNodeOutput on success
@@ -185,8 +196,8 @@ class HttpRequestNode(BaseNode):
async with httpx.AsyncClient(
verify=self.typed_config.verify_ssl,
timeout=self._build_timeout(),
headers=self._build_header(state) | self._build_auth(state),
params=self._build_params(state),
headers=self._build_header(variable_pool) | self._build_auth(variable_pool),
params=self._build_params(variable_pool),
follow_redirects=True
) as client:
retries = self.typed_config.retry.max_attempts
@@ -194,8 +205,8 @@ class HttpRequestNode(BaseNode):
try:
request_func = self._get_client_method(client)
resp = await request_func(
url=self._render_template(self.typed_config.url, state),
**self._build_content(state)
url=self._render_template(self.typed_config.url, variable_pool),
**self._build_content(variable_pool)
)
resp.raise_for_status()
logger.info(f"Node {self.node_id}: HTTP request succeeded")

View File

@@ -6,6 +6,8 @@ from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -15,6 +17,11 @@ class IfElseNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: IfElseNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
return {
"output": VariableType.STRING
}
@staticmethod
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
match operator:
@@ -45,7 +52,7 @@ class IfElseNode(BaseNode):
case _:
raise ValueError(f"Invalid condition: {operator}")
def evaluate_conditional_edge_expressions(self, state) -> list[bool]:
def evaluate_conditional_edge_expressions(self, variable_pool: VariablePool) -> list[bool]:
"""
Build conditional edge expressions for the If-Else node.
@@ -72,11 +79,11 @@ class IfElseNode(BaseNode):
pattern = r"\{\{\s*(.*?)\s*\}\}"
left_string = re.sub(pattern, r"\1", expression.left).strip()
try:
left_value = self.get_variable(left_string, state)
left_value = self.get_variable(left_string, variable_pool)
except KeyError:
left_value = None
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
self.get_variable_pool(state),
variable_pool,
expression.left,
expression.right,
expression.input_type
@@ -95,7 +102,7 @@ class IfElseNode(BaseNode):
return conditions
async def execute(self, state: WorkflowState) -> Any:
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
Execute the conditional branching logic of the node.
@@ -105,13 +112,13 @@ class IfElseNode(BaseNode):
Args:
state (WorkflowState): The current workflow state, containing variables, messages, node outputs, etc.
variable_pool: Variable Pool
Returns:
str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions.
"""
self.typed_config = IfElseNodeConfig(**self.config)
expressions = self.evaluate_conditional_edge_expressions(state)
# TODO: 变量类型及文本类型解析
expressions = self.evaluate_conditional_edge_expressions(variable_pool)
for i in range(len(expressions)):
if expressions[i]:
logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}")

View File

@@ -5,6 +5,8 @@ from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig
from app.core.workflow.template_renderer import TemplateRenderer
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -14,7 +16,12 @@ class JinjaRenderNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: JinjaRenderNodeConfig | None = None
async def execute(self, state: WorkflowState) -> Any:
def _output_types(self) -> dict[str, VariableType]:
return {
"output": VariableType.STRING
}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
Execute the node: render the Jinja2 template with mapped variables.
@@ -24,6 +31,7 @@ class JinjaRenderNode(BaseNode):
Args:
state (WorkflowState): Current workflow state containing variables,
node outputs, and runtime variables.
variable_pool: Variable Pool
Returns:
dict[str, Any]: Node output dictionary containing the rendered result
@@ -40,7 +48,7 @@ class JinjaRenderNode(BaseNode):
context = {}
for variable in self.typed_config.mapping:
try:
context[variable.name] = self.get_variable(variable.value, state)
context[variable.name] = self.get_variable(variable.value, variable_pool)
except Exception:
logger.info(f"variable not found, var: {variable.value}")
continue

View File

@@ -8,6 +8,8 @@ from app.core.models import RedBearRerank, RedBearModelConfig
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
from app.db import get_db_read
from app.models import knowledge_model, knowledgeshare_model, ModelType
from app.repositories import knowledge_repository, knowledgeshare_repository
@@ -22,6 +24,11 @@ class KnowledgeRetrievalNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
return {
"output": VariableType.ARRAY_STRING
}
@staticmethod
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
"""
@@ -149,7 +156,7 @@ class KnowledgeRetrievalNode(BaseNode):
)
return reranker
async def execute(self, state: WorkflowState) -> Any:
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
Execute the knowledge retrieval workflow node.
@@ -163,6 +170,7 @@ class KnowledgeRetrievalNode(BaseNode):
Args:
state (WorkflowState): Current workflow execution state.
variable_pool: Variable Pool
Returns:
Any: List of retrieved knowledge chunks (dict format).
@@ -171,7 +179,7 @@ class KnowledgeRetrievalNode(BaseNode):
RuntimeError: If no valid knowledge base is found or access is denied.
"""
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
query = self._render_template(self.typed_config.query, state)
query = self._render_template(self.typed_config.query, variable_pool)
with get_db_read() as db:
knowledge_bases = self.typed_config.knowledge_bases
existing_ids = self._get_existing_kb_ids(db, [kb.kb_id for kb in knowledge_bases])

View File

@@ -4,7 +4,8 @@ from typing import Any
from pydantic import BaseModel, Field, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
from app.core.workflow.variable.base_variable import VariableType
class MessageConfig(BaseModel):

View File

@@ -15,6 +15,8 @@ from app.core.exceptions import BusinessException
from app.core.models import RedBearLLM, RedBearModelConfig
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.llm.config import LLMNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
from app.db import get_db_context
from app.models import ModelType
from app.services.model_service import ModelConfigService
@@ -66,19 +68,27 @@ class LLMNode(BaseNode):
- ai/assistant: AI 消息AIMessage
"""
def _output_types(self) -> dict[str, VariableType]:
return {"output": VariableType.STRING}
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config: LLMNodeConfig | None = None
def _render_context(self, message, state):
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
def _render_context(self, message: str, variable_pool: VariablePool):
context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
return re.sub(r"{{context}}", context, message)
def _prepare_llm(self, state: WorkflowState, stream: bool = False) -> tuple[RedBearLLM, list | str]:
def _prepare_llm(
self,
state: WorkflowState,
variable_pool: VariablePool,
stream: bool = False
) -> tuple[RedBearLLM, list | str]:
"""准备 LLM 实例(公共逻辑)
Args:
state: 工作流状态
variable_pool: 变量池
Returns:
(llm, messages_or_prompt): LLM 实例和消息列表或 prompt 字符串
@@ -94,8 +104,8 @@ class LLMNode(BaseNode):
for msg_config in messages_config:
role = msg_config.role.lower()
content_template = msg_config.content
content_template = self._render_context(content_template, state)
content = self._render_template(content_template, state)
content_template = self._render_context(content_template, variable_pool)
content = self._render_template(content_template, variable_pool)
# 根据角色创建对应的消息对象
if role == "system":
@@ -115,7 +125,7 @@ class LLMNode(BaseNode):
else:
# 使用简单的 prompt 格式(向后兼容)
prompt_template = self.config.get("prompt", "")
prompt_or_messages = self._render_template(prompt_template, state)
prompt_or_messages = self._render_template(prompt_template, variable_pool)
# 2. 获取模型配置
model_id = self.config.get("model_id")
@@ -159,17 +169,18 @@ class LLMNode(BaseNode):
return llm, prompt_or_messages
async def execute(self, state: WorkflowState) -> AIMessage:
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage:
"""非流式执行 LLM 调用
Args:
state: 工作流状态
variable_pool: 变量池
Returns:
LLM 响应消息
"""
# self.typed_config = LLMNodeConfig(**self.config)
llm, prompt_or_messages = self._prepare_llm(state, True)
llm, prompt_or_messages = self._prepare_llm(state, variable_pool, False)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
@@ -186,9 +197,9 @@ class LLMNode(BaseNode):
# 返回 AIMessage包含响应元数据
return response if isinstance(response, AIMessage) else AIMessage(content=content)
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""提取输入数据(用于记录)"""
_, prompt_or_messages = self._prepare_llm(state)
_, prompt_or_messages = self._prepare_llm(state, variable_pool)
return {
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
@@ -221,18 +232,19 @@ class LLMNode(BaseNode):
}
return None
async def execute_stream(self, state: WorkflowState):
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
"""流式执行 LLM 调用
Args:
state: 工作流状态
variable_pool: 变量池
Yields:
文本片段chunk或完成标记
"""
self.typed_config = LLMNodeConfig(**self.config)
llm, prompt_or_messages = self._prepare_llm(state, True)
llm, prompt_or_messages = self._prepare_llm(state, variable_pool, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")

View File

@@ -3,6 +3,8 @@ from typing import Any
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
from app.db import get_db_read
from app.services.memory_agent_service import MemoryAgentService
from app.tasks import write_message_task
@@ -13,17 +15,23 @@ class MemoryReadNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: MemoryReadNodeConfig | None = None
async def execute(self, state: WorkflowState) -> Any:
def _output_types(self) -> dict[str, VariableType]:
return {
"answer": VariableType.STRING,
"intermediate_outputs": VariableType.ARRAY_OBJECT
}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
self.typed_config = MemoryReadNodeConfig(**self.config)
with get_db_read() as db:
end_user_id = self.get_variable("sys.user_id", state)
end_user_id = self.get_variable("sys.user_id", variable_pool)
if not end_user_id:
raise RuntimeError("End user id is required")
return await MemoryAgentService().read_memory(
end_user_id=end_user_id,
message=self._render_template(self.typed_config.message, state),
message=self._render_template(self.typed_config.message, variable_pool),
config_id=self.typed_config.config_id,
search_switch=self.typed_config.search_switch,
history=[],
@@ -38,16 +46,19 @@ class MemoryWriteNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: MemoryWriteNodeConfig | None = None
async def execute(self, state: WorkflowState) -> Any:
def _output_types(self) -> dict[str, VariableType]:
return {"output": VariableType.STRING}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
self.typed_config = MemoryWriteNodeConfig(**self.config)
end_user_id = self.get_variable("sys.user_id", state)
end_user_id = self.get_variable("sys.user_id", variable_pool)
if not end_user_id:
raise RuntimeError("End user id is required")
write_message_task.delay(
end_user_id,
self._render_template(self.typed_config.message, state),
self._render_template(self.typed_config.message, variable_pool),
str(self.typed_config.config_id),
"neo4j",
""

View File

@@ -22,7 +22,6 @@ from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.memory import MemoryReadNode, MemoryWriteNode
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
from app.core.workflow.nodes.breaker import BreakNode
@@ -37,7 +36,6 @@ WorkflowNode = Union[
LLMNode,
IfElseNode,
AgentNode,
TransformNode,
AssignerNode,
HttpRequestNode,
KnowledgeRetrievalNode,
@@ -67,7 +65,6 @@ class NodeFactory:
NodeType.END: EndNode,
NodeType.LLM: LLMNode,
NodeType.AGENT: AgentNode,
NodeType.TRANSFORM: TransformNode,
NodeType.IF_ELSE: IfElseNode,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.ASSIGNER: AssignerNode,

View File

@@ -1,9 +1,9 @@
import json
import re
from abc import ABC
from typing import Union, Type, NoReturn
from typing import Union, Type, NoReturn, Any
from app.core.workflow.nodes.base_config import VariableType
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.nodes.enums import ValueInputType
from app.core.workflow.variable_pool import VariablePool
@@ -69,7 +69,7 @@ class TypeTransformer:
class OperatorBase(ABC):
def __init__(self, pool: VariablePool, left_selector, right):
def __init__(self, pool: VariablePool, left_selector: str, right: Any):
self.pool = pool
self.left_selector = left_selector
self.right = right
@@ -77,7 +77,7 @@ class OperatorBase(ABC):
self.type_limit: type[str, int, dict, list] = None
def check(self, no_right=False):
left = self.pool.get(self.left_selector)
left = self.pool.get_value(self.left_selector)
if not isinstance(left, self.type_limit):
raise TypeError(f"The variable to be operated on must be of {self.type_limit} type")
@@ -92,13 +92,13 @@ class StringOperator(OperatorBase):
super().__init__(pool, left_selector, right)
self.type_limit = str
def assign(self) -> None:
async def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
await self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
async def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, '')
await self.pool.set(self.left_selector, '')
class NumberOperator(OperatorBase):
@@ -106,33 +106,33 @@ class NumberOperator(OperatorBase):
super().__init__(pool, left_selector, right)
self.type_limit = (float, int)
def assign(self) -> None:
async def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
await self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
async def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, 0)
await self.pool.set(self.left_selector, 0)
def add(self) -> None:
async def add(self) -> None:
self.check()
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin + self.right)
origin = self.pool.get_value(self.left_selector)
await self.pool.set(self.left_selector, origin + self.right)
def subtract(self) -> None:
async def subtract(self) -> None:
self.check()
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin - self.right)
origin = self.pool.get_value(self.left_selector)
await self.pool.set(self.left_selector, origin - self.right)
def multiply(self) -> None:
async def multiply(self) -> None:
self.check()
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin * self.right)
origin = self.pool.get_value(self.left_selector)
await self.pool.set(self.left_selector, origin * self.right)
def divide(self) -> None:
async def divide(self) -> None:
self.check()
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin / self.right)
origin = self.pool.get_value(self.left_selector)
await self.pool.set(self.left_selector, origin / self.right)
class BooleanOperator(OperatorBase):
@@ -140,13 +140,13 @@ class BooleanOperator(OperatorBase):
super().__init__(pool, left_selector, right)
self.type_limit = bool
def assign(self) -> None:
async def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
await self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
async def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, False)
await self.pool.set(self.left_selector, False)
class ArrayOperator(OperatorBase):
@@ -154,38 +154,37 @@ class ArrayOperator(OperatorBase):
super().__init__(pool, left_selector, right)
self.type_limit = list
def assign(self) -> None:
async def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
await self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
async def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, list())
await self.pool.set(self.left_selector, list())
def append(self) -> None:
async def append(self) -> None:
self.check(no_right=True)
# TODOrequire type limit in list
origin = self.pool.get(self.left_selector)
origin = self.pool.get_value(self.left_selector)
origin.append(self.right)
self.pool.set(self.left_selector, origin)
await self.pool.set(self.left_selector, origin)
def extend(self) -> None:
async def extend(self) -> None:
self.check(no_right=True)
origin = self.pool.get(self.left_selector)
origin = self.pool.get_value(self.left_selector)
origin.extend(self.right)
self.pool.set(self.left_selector, origin)
await self.pool.set(self.left_selector, origin)
def remove_last(self) -> None:
async def remove_last(self) -> None:
self.check(no_right=True)
origin = self.pool.get(self.left_selector)
origin = self.pool.get_value(self.left_selector)
origin.pop()
self.pool.set(self.left_selector, origin)
await self.pool.set(self.left_selector, origin)
def remove_first(self) -> None:
async def remove_first(self) -> None:
self.check(no_right=True)
origin = self.pool.get(self.left_selector)
origin = self.pool.get_value(self.left_selector)
origin.pop(0)
self.pool.set(self.left_selector, origin)
await self.pool.set(self.left_selector, origin)
class ObjectOperator(OperatorBase):
@@ -193,13 +192,13 @@ class ObjectOperator(OperatorBase):
super().__init__(pool, left_selector, right)
self.type_limit = dict
def assign(self) -> None:
async def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
await self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
async def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, dict())
await self.pool.set(self.left_selector, dict())
class AssignmentOperatorResolver:
@@ -245,7 +244,7 @@ class ConditionBase(ABC):
self.right_selector = right_selector
self.input_type = input_type
self.left_value = self.pool.get(self.left_selector)
self.left_value = self.pool.get_value(self.left_selector)
self.right_value = self.resolve_right_literal_value()
self.type_limit = getattr(self, "type_limit", None)
@@ -254,7 +253,7 @@ class ConditionBase(ABC):
if self.input_type == ValueInputType.VARIABLE:
pattern = r"\{\{\s*(.*?)\s*\}\}"
right_expression = re.sub(pattern, r"\1", self.right_selector).strip()
return self.pool.get(right_expression)
return self.pool.get_value(right_expression)
elif self.input_type == ValueInputType.CONSTANT:
return self.right_selector
raise RuntimeError("Unsupported variable type")

View File

@@ -1,7 +1,7 @@
import uuid
from enum import StrEnum
from pydantic import Field, BaseModel
from enum import StrEnum
from app.core.workflow.nodes.base_config import BaseNodeConfig

View File

@@ -12,6 +12,8 @@ from app.core.models import RedBearLLM, RedBearModelConfig
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
from app.db import get_db_read
from app.models import ModelType
from app.services.model_service import ModelConfigService
@@ -24,6 +26,12 @@ class ParameterExtractorNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: ParameterExtractorNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
outputs = {}
for param in self.typed_config.params:
outputs[param.name] = param.type
return outputs
@staticmethod
def _get_prompt():
"""
@@ -120,7 +128,7 @@ class ParameterExtractorNode(BaseNode):
field_type[param.name] = f'{param.type}, required:{str(param.required)}'
return field_type
async def execute(self, state: WorkflowState) -> Any:
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
Main execution function for this node.
@@ -138,6 +146,7 @@ class ParameterExtractorNode(BaseNode):
Args:
state (WorkflowState): Current state of the workflow, used for template rendering.
variable_pool (VariablePool): Used for accessing and setting variables during execution.
Returns:
dict[str, Any]: Dictionary containing extracted parameters under the "output" key.
@@ -153,7 +162,7 @@ class ParameterExtractorNode(BaseNode):
rendered_user_prompt = user_prompt_teplate.render(
field_descriptions=str(self._get_field_desc()),
field_type=str(self._get_field_type()),
text_input=self._render_template(self.typed_config.text, state)
text_input=self._render_template(self.typed_config.text, variable_pool)
)
messages = [
@@ -162,7 +171,7 @@ class ParameterExtractorNode(BaseNode):
]
if self.typed_config.prompt:
messages.extend([
("user", self._render_template(self.typed_config.prompt, state)),
("user", self._render_template(self.typed_config.prompt, variable_pool)),
("user", rendered_user_prompt),
])
else:

View File

@@ -6,6 +6,8 @@ from app.core.workflow.nodes.question_classifier.config import QuestionClassifie
from app.core.models import RedBearLLM, RedBearModelConfig
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
from app.db import get_db_read
from app.models import ModelType
from app.services.model_service import ModelConfigService
@@ -24,6 +26,12 @@ class QuestionClassifierNode(BaseNode):
self.typed_config: QuestionClassifierNodeConfig | None = None
self.category_to_case_map = {}
def _output_types(self) -> dict[str, VariableType]:
return {
"class_name": VariableType.STRING,
"output": VariableType.STRING
}
def _get_llm_instance(self) -> RedBearLLM:
"""获取LLM实例"""
with get_db_read() as db:
@@ -65,7 +73,7 @@ class QuestionClassifierNode(BaseNode):
category_map[category_name] = case_tag
return category_map
async def execute(self, state: WorkflowState) -> dict:
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict:
"""执行问题分类"""
self.typed_config = QuestionClassifierNodeConfig(**self.config)
self.category_to_case_map = self._build_category_case_map()
@@ -102,7 +110,7 @@ class QuestionClassifierNode(BaseNode):
categories=", ".join(category_names),
supplement_prompt=supplement_prompt
),
state
variable_pool
)
messages = [

View File

@@ -2,7 +2,8 @@
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
from app.core.workflow.variable.base_variable import VariableType
class StartNodeConfig(BaseNodeConfig):

View File

@@ -7,9 +7,10 @@ Start 节点实现
import logging
from typing import Any
from app.core.workflow.nodes.base_config import VariableType
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.start.config import StartNodeConfig
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -36,14 +37,25 @@ class StartNode(BaseNode):
# 解析并验证配置
self.typed_config: StartNodeConfig | None = None
self.output_var_types = {}
async def execute(self, state: WorkflowState) -> dict[str, Any]:
def _output_types(self) -> dict[str, VariableType]:
return self.output_var_types | {
"message": VariableType.STRING,
"execution_id": VariableType.STRING,
"conversation_id": VariableType.STRING,
"workspace_id": VariableType.STRING,
"user_id": VariableType.STRING,
}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""执行 start 节点业务逻辑
Start 节点输出系统变量、会话变量和自定义变量。
Args:
state: 工作流状态
variable_pool: 变量池
Returns:
包含系统参数、会话变量和自定义变量的字典
@@ -51,19 +63,16 @@ class StartNode(BaseNode):
self.typed_config = StartNodeConfig(**self.config)
logger.info(f"节点 {self.node_id} (Start) 开始执行")
# 创建变量池实例(在方法内复用)
pool = self.get_variable_pool(state)
# 处理自定义变量(传入 pool 避免重复创建)
custom_vars = self._process_custom_variables(pool)
custom_vars = self._process_custom_variables(variable_pool)
# 返回业务数据(包含自定义变量)
result = {
"message": pool.get("sys.message"),
"execution_id": pool.get("sys.execution_id"),
"conversation_id": pool.get("sys.conversation_id"),
"workspace_id": pool.get("sys.workspace_id"),
"user_id": pool.get("sys.user_id"),
"message": variable_pool.get_value("sys.message"),
"execution_id": variable_pool.get_value("sys.execution_id"),
"conversation_id": variable_pool.get_value("sys.conversation_id"),
"workspace_id": variable_pool.get_value("sys.workspace_id"),
"user_id": variable_pool.get_value("sys.user_id"),
**custom_vars # 自定义变量作为节点输出的一部分
}
@@ -74,7 +83,7 @@ class StartNode(BaseNode):
return result
def _process_custom_variables(self, pool) -> dict[str, Any]:
def _process_custom_variables(self, pool: VariablePool) -> dict[str, Any]:
"""处理自定义变量
从输入数据中提取自定义变量,应用默认值和验证。
@@ -89,13 +98,14 @@ class StartNode(BaseNode):
ValueError: 缺少必需变量
"""
# 获取输入数据中的自定义变量
input_variables = pool.get("sys.input_variables", default={})
input_variables = pool.get_value("sys.input_variables", default={}, strict=False)
processed = {}
# 遍历配置的变量定义
for var_def in self.typed_config.variables:
var_name = var_def.name
var_type = var_def.type
# 检查变量是否存在
if var_name in input_variables:
@@ -116,21 +126,12 @@ class StartNode(BaseNode):
f"变量 '{var_name}' 使用默认值: {var_def.default}"
)
else:
match var_def.type:
case VariableType.STRING:
processed[var_name] = ""
case VariableType.NUMBER:
processed[var_name] = 0
case VariableType.OBJECT:
processed[var_name] = {}
case VariableType.BOOLEAN:
processed[var_name] = False
case VariableType.ARRAY_NUMBER | VariableType.ARRAY_OBJECT | VariableType.ARRAY_BOOLEAN | VariableType.ARRAY_STRING:
processed[var_name] = []
processed[var_name] = DEFAULT_VALUE(var_type)
self.output_var_types[var_name] = var_type
return processed
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""提取输入数据(用于记录)
Args:
@@ -139,11 +140,9 @@ class StartNode(BaseNode):
Returns:
输入数据字典
"""
pool = self.get_variable_pool(state)
return {
"execution_id": pool.get("sys.execution_id"),
"conversation_id": pool.get("sys.conversation_id"),
"message": pool.get("sys.message"),
"conversation_vars": pool.get_all_conversation_vars()
"execution_id": variable_pool.get_value("sys.execution_id"),
"conversation_id": variable_pool.get_value("sys.conversation_id"),
"message": variable_pool.get_value("sys.message"),
"conversation_vars": variable_pool.get_all_conversation_vars()
}

View File

@@ -6,6 +6,8 @@ from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
from app.services.tool_service import ToolService
from app.db import get_db_read
@@ -21,13 +23,20 @@ class ToolNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: ToolNodeConfig | None = None
async def execute(self, state: WorkflowState) -> dict[str, Any]:
def _output_types(self) -> dict[str, VariableType]:
return {
"data": VariableType.STRING,
"error_code": VariableType.STRING,
"execution_time": VariableType.NUMBER
}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""执行工具"""
self.typed_config = ToolNodeConfig(**self.config)
# 获取租户ID和用户ID
tenant_id = self.get_variable("sys.tenant_id", state)
user_id = self.get_variable("sys.user_id", state)
workspace_id = self.get_variable("sys.workspace_id", state)
tenant_id = self.get_variable("sys.tenant_id", variable_pool, strict=False)
user_id = self.get_variable("sys.user_id", variable_pool)
workspace_id = self.get_variable("sys.workspace_id", variable_pool)
# 如果没有租户ID尝试从工作流ID获取
if not tenant_id:
@@ -48,7 +57,7 @@ class ToolNode(BaseNode):
for param_name, param_template in self.typed_config.tool_parameters.items():
if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template):
try:
rendered_value = self._render_template(param_template, state)
rendered_value = self._render_template(param_template, variable_pool)
except Exception as e:
raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e
else:

View File

@@ -1,6 +0,0 @@
"""Transform 节点"""
from app.core.workflow.nodes.transform.node import TransformNode
from app.core.workflow.nodes.transform.config import TransformNodeConfig
__all__ = ["TransformNode", "TransformNodeConfig"]

View File

@@ -1,80 +0,0 @@
"""Transform 节点配置"""
from typing import Literal
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
class TransformNodeConfig(BaseNodeConfig):
    """Transform node configuration.

    Configures a data transformation/processing step. Exactly one of the
    mode-specific fields (template / code / json_path) is expected to be
    used, selected by ``transform_type``.
    """
    transform_type: Literal["template", "code", "json"] = Field(
        default="template",
        description="转换类型template(模板), code(代码), json(JSON处理)"
    )
    # Template mode: a Jinja-style template with variable references.
    template: str | None = Field(
        default=None,
        description="转换模板,支持变量引用"
    )
    # Code mode: a Python snippet that performs the transformation.
    code: str | None = Field(
        default=None,
        description="Python 代码,用于数据转换"
    )
    # JSON mode: a JSON-path expression applied to the input.
    json_path: str | None = Field(
        default=None,
        description="JSON 路径表达式"
    )
    # Input variables: maps a local variable name to a variable selector.
    inputs: dict[str, str] | None = Field(
        default=None,
        description="输入变量映射key 为变量名value 为变量选择器"
    )
    # Output variable: key under which the result is stored.
    output_key: str = Field(
        default="result",
        description="输出变量的键名"
    )
    # Output variable definitions (generated from output_key).
    output_variables: list[VariableDefinition] = Field(
        default_factory=lambda: [
            VariableDefinition(
                name="result",
                type=VariableType.STRING,
                description="转换后的结果"
            )
        ],
        description="输出变量定义(根据 output_key 动态生成)"
    )

    class Config:
        json_schema_extra = {
            "examples": [
                {
                    "transform_type": "template",
                    "template": "用户问题:{{ sys.message }}\n回答:{{ llm_qa.output }}",
                    "output_key": "formatted_result"
                },
                {
                    "transform_type": "code",
                    "code": "result = input_text.upper()",
                    "inputs": {
                        "input_text": "{{ sys.message }}"
                    },
                    "output_key": "uppercase_text"
                }
            ]
        }

View File

@@ -1,60 +0,0 @@
"""
Transform 节点实现
数据转换节点,用于处理和转换数据。
"""
import logging
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
logger = logging.getLogger(__name__)
class TransformNode(BaseNode):
    """Data transformation node.

    Renders each configured mapping template against the workflow state
    and publishes the collected results as this node's output.

    Example config:
    {
        "type": "transform",
        "config": {
            "mapping": {
                "output_field": "{{node.previous.output}}",
                "processed": "{{var.input | upper}}"
            }
        }
    }
    """
    async def execute(self, state: WorkflowState) -> dict[str, Any]:
        """Execute the data transformation.

        Args:
            state: The workflow state used to resolve template variables.

        Returns:
            A state-update dict placing the transformed mapping under
            ``node_outputs[self.node_id]``.
        """
        logger.info(f"节点 {self.node_id} 开始执行数据转换")
        # Read the mapping configuration: target key -> source template.
        mapping = self.config.get("mapping", {})
        # Render each template and collect the results.
        transformed_data = {}
        for target_key, source_template in mapping.items():
            # Render the template to obtain the concrete value.
            value = self._render_template(str(source_template), state)
            transformed_data[target_key] = value
        logger.info(f"节点 {self.node_id} 数据转换完成,输出字段: {list(transformed_data.keys())}")
        return {
            "node_outputs": {
                self.node_id: {
                    "output": transformed_data,
                    "status": "completed"
                }
            }
        }

View File

@@ -1,6 +1,7 @@
from pydantic import Field, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.variable.base_variable import VariableType
class VariableAggregatorNodeConfig(BaseNodeConfig):
@@ -14,6 +15,11 @@ class VariableAggregatorNodeConfig(BaseNodeConfig):
description="需要被聚合的变量"
)
group_type: dict[str, VariableType] = Field(
...,
description="每个分组的变量类型"
)
@field_validator("group_variables")
@classmethod
def group_variables_validator(cls, v, info):

View File

@@ -5,6 +5,8 @@ from typing import Any
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -14,6 +16,13 @@ class VariableAggregatorNode(BaseNode):
super().__init__(node_config, workflow_config)
self.typed_config: VariableAggregatorNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
config = VariableAggregatorNodeConfig(**self.config)
output = {}
for var_type in config.group_type:
output[var_type] = config.group_type[var_type]
return output
@staticmethod
def _get_express(variable_string: str) -> Any:
"""
@@ -29,7 +38,7 @@ class VariableAggregatorNode(BaseNode):
expression = re.sub(pattern, r"\1", variable_string).strip()
return expression
async def execute(self, state: WorkflowState) -> Any:
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
Execute the variable aggregation logic.
@@ -45,7 +54,7 @@ class VariableAggregatorNode(BaseNode):
for variable in self.typed_config.group_variables:
var_express = self._get_express(variable)
try:
value = self.get_variable(var_express, state)
value = self.get_variable(var_express, variable_pool)
except Exception as e:
logger.warning(f"Failed to get variable '{var_express}': {e}")
continue
@@ -55,7 +64,7 @@ class VariableAggregatorNode(BaseNode):
return value
logger.info("No variable found in non-group mode; returning empty string.")
return ""
return DEFAULT_VALUE(self.typed_config.group_type["output"])
# --------------------------
# Group mode
@@ -65,7 +74,7 @@ class VariableAggregatorNode(BaseNode):
for variable in variables:
var_express = self._get_express(variable)
try:
value = self.get_variable(var_express, state)
value = self.get_variable(var_express, variable_pool)
except Exception as e:
logger.warning(f"Failed to get variable '{var_express}' in group '{group_name}': {e}")
continue
@@ -74,7 +83,7 @@ class VariableAggregatorNode(BaseNode):
result[group_name] = value
break
else:
result[group_name] = ""
result[group_name] = DEFAULT_VALUE(self.typed_config.group_type[group_name])
logger.info(f"No variable found for group '{group_name}'; set empty string.")
logger.info(f"Node: {self.node_id} variable aggregation result: {result}")
return result

View File

@@ -43,7 +43,7 @@ class TemplateRenderer:
def render(
self,
template: str,
variables: dict[str, Any],
conv_vars: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> str:
@@ -51,7 +51,7 @@ class TemplateRenderer:
Args:
template: 模板字符串
variables: 用户定义的变量
conv_vars: 会话变量
node_outputs: 节点输出结果
system_vars: 系统变量
@@ -80,20 +80,11 @@ class TemplateRenderer:
'分析结果: 正面情绪'
"""
# 构建命名空间上下文
# variables 的结构:{"sys": {...}, "conv": {...}}
sys_vars = variables.get("sys", {}) if isinstance(variables, dict) else {}
conv_vars = variables.get("conv", {}) if isinstance(variables, dict) else {}
if self.strict:
context = defaultdict(dict)
context["conv"] = conv_vars
context["node"] = node_outputs
context["sys"] = {**(system_vars or {}), **sys_vars}
else:
context = {
"conv": conv_vars, # 会话变量:{{conv.user_name}}
"node": node_outputs, # 节点输出:{{node.node_1.output}}
"sys": {**(system_vars or {}), **sys_vars}, # 系统变量:{{sys.execution_id}}(合并两个来源)
}
context = {
"conv": conv_vars, # 会话变量:{{conv.user_name}}
"node": node_outputs, # 节点输出:{{node.node_1.output}}
"sys": system_vars, # 系统变量:{{sys.execution_id}}
}
# 支持直接通过节点ID访问节点输出{{llm_qa.output}}
# 将所有节点输出添加到顶层上下文
@@ -157,9 +148,9 @@ _default_renderer = TemplateRenderer(strict=True)
def render_template(
template: str,
variables: dict[str, Any],
conv_vars: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None,
system_vars: dict[str, Any],
strict: bool = True
) -> str:
"""渲染模板(便捷函数)
@@ -167,7 +158,7 @@ def render_template(
Args:
strict: 严格模式
template: 模板字符串
variables: 用户变量
conv_vars: 会话变量
node_outputs: 节点输出
system_vars: 系统变量
@@ -184,7 +175,7 @@ def render_template(
'请分析: 这是一段文本'
"""
renderer = TemplateRenderer(strict=strict)
return renderer.render(template, variables, node_outputs, system_vars)
return renderer.render(template, conv_vars, node_outputs, system_vars)
def validate_template(template: str) -> list[str]:

View File

@@ -5,10 +5,13 @@
"""
import logging
from typing import Any, Union
from typing import Any, Union, TYPE_CHECKING
from app.core.workflow.nodes.enums import NodeType
if TYPE_CHECKING:
from app.schemas.workflow_schema import WorkflowConfig
logger = logging.getLogger(__name__)
@@ -64,7 +67,7 @@ class WorkflowValidator:
return cycle_nodes, cycle_edges
@classmethod
def get_subgraph(cls, workflow_config: Union[dict[str, Any], Any]) -> list:
def get_subgraph(cls, workflow_config: Union[dict[str, Any], "WorkflowConfig"]) -> list:
if not isinstance(workflow_config, dict):
workflow_config = {
"nodes": workflow_config.nodes,
@@ -331,7 +334,7 @@ class WorkflowValidator:
def validate_workflow_config(
workflow_config: dict[str, Any],
workflow_config: Union[dict[str, Any], 'WorkflowConfig'],
for_publish: bool = False
) -> tuple[bool, list[str]]:
"""验证工作流配置(便捷函数)

View File

@@ -0,0 +1,162 @@
from enum import StrEnum
from abc import abstractmethod, ABC
from typing import Any
class VariableType(StrEnum):
"""Enumeration of supported variable types in the workflow."""
STRING = "string"
NUMBER = "number"
BOOLEAN = "boolean"
OBJECT = "object"
FILE = "file"
ARRAY_STRING = "array[string]"
ARRAY_NUMBER = "array[number]"
ARRAY_BOOLEAN = "array[boolean]"
ARRAY_OBJECT = "array[object]"
ARRAY_FILE = "array[file]"
NESTED_ARRAY = "array_nest"
@classmethod
def type_map(cls, var: Any) -> "VariableType":
"""Maps a Python value to a corresponding VariableType.
Args:
var: The Python value to map.
Returns:
The VariableType corresponding to the input value.
Raises:
TypeError: If the type of the input value is not supported.
"""
var_type = type(var)
if isinstance(var_type, str):
return cls.STRING
elif isinstance(var_type, (int, float)):
return cls.NUMBER
elif isinstance(var_type, bool):
return cls.BOOLEAN
elif isinstance(var_type, FileObj):
return cls.FILE
elif isinstance(var_type, dict):
return cls.OBJECT
elif isinstance(var_type, list):
if len(var) == 0:
return cls.ARRAY_STRING
else:
child_type = type(var[0])
if child_type == str:
return cls.ARRAY_STRING
elif child_type == int or child_type == float:
return cls.ARRAY_NUMBER
elif child_type == bool:
return cls.ARRAY_BOOLEAN
elif child_type == dict:
return cls.ARRAY_OBJECT
elif child_type == list:
return cls.NESTED_ARRAY
else:
raise TypeError(f"Unsupported array child type - {child_type}")
raise TypeError(f"Unsupported type - {var_type}")
def DEFAULT_VALUE(var_type: VariableType) -> Any:
"""Returns the default value for a given VariableType.
Args:
var_type: The variable type for which to get the default value.
Returns:
The default Python value corresponding to the VariableType.
Raises:
TypeError: If the VariableType is invalid.
"""
match var_type:
case VariableType.STRING:
return ""
case VariableType.NUMBER:
return 0
case VariableType.BOOLEAN:
return False
case VariableType.OBJECT:
return {}
case VariableType.FILE:
return None
case VariableType.ARRAY_STRING:
return []
case VariableType.ARRAY_NUMBER:
return []
case VariableType.ARRAY_BOOLEAN:
return []
case VariableType.ARRAY_OBJECT:
return []
case VariableType.ARRAY_FILE:
return []
case _:
raise TypeError(f"Invalid type - {type}")
class FileObj:
    # Placeholder marker class for file-typed values; VariableType.type_map
    # uses it only in an isinstance check.  No behavior yet — presumably a
    # real file wrapper (path/URL/metadata) will be added later; TODO confirm.
    pass
class BaseVariable(ABC):
    """Abstract base class for all workflow variables.

    A concrete subclass defines how raw Python values are validated
    (``valid_value``) and how the stored value is rendered as a string
    literal (``to_literal``).
    """

    # Logical type tag; concrete subclasses override this with a short name.
    type = None

    def __init__(self, value: Any):
        """Validate ``value`` and store it together with its literal form.

        Args:
            value: The initial value; it is passed through ``valid_value``.

        Attributes:
            value: The validated value held by this variable.
            literal: String rendering of the value, computed once here.
        """
        self.value = self.valid_value(value)
        self.literal = self.to_literal()

    @abstractmethod
    def valid_value(self, value) -> Any:
        """Validate or convert ``value`` for this variable type.

        Args:
            value: The candidate value.

        Returns:
            The validated or converted value.

        Raises:
            TypeError: If the value cannot be accepted.
        """

    @abstractmethod
    def to_literal(self) -> str:
        """Render the stored value as a string literal."""

    def get_value(self) -> Any:
        """Return the currently stored value."""
        return self.value

    def set(self, value):
        """Replace the stored value after running it through validation.

        Args:
            value: The new value to assign.
        """
        # NOTE(review): ``literal`` is computed once in __init__ and is NOT
        # refreshed here — confirm that stale literals are intended.
        self.value = self.valid_value(value)

View File

@@ -0,0 +1,137 @@
from typing import Any, TypeVar, Type, Generic
from app.core.workflow.variable.base_variable import BaseVariable, VariableType
T = TypeVar("T", bound=BaseVariable)
class StringObject(BaseVariable):
    """Workflow variable holding a plain string."""

    type = 'str'

    def valid_value(self, value) -> str:
        # Accept strings only; anything else is a caller error.
        if isinstance(value, str):
            return value
        raise TypeError("Value must be a string")

    def to_literal(self) -> str:
        # A string is already its own literal form.
        return self.value
class NumberObject(BaseVariable):
    """Workflow variable holding an int or float.

    Note:
        ``bool`` is explicitly rejected even though it subclasses ``int``;
        booleans belong to BooleanObject (VariableType.BOOLEAN).
    """

    type = 'number'

    def valid_value(self, value) -> int | float:
        # Bug fix: bool is a subclass of int, so the original check silently
        # accepted True/False as numbers.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise TypeError("Value must be a number")
        return value

    def to_literal(self) -> str:
        return str(self.value)
class BooleanObject(BaseVariable):
    """Workflow variable holding a boolean."""

    type = 'boolean'

    def valid_value(self, value) -> bool:
        if isinstance(value, bool):
            return value
        raise TypeError("Value must be a boolean")

    def to_literal(self) -> str:
        # Render as JSON-style lowercase "true"/"false".
        return str(self.value).lower()
class DictObject(BaseVariable):
    """Workflow variable holding a dict (object)."""

    type = 'object'

    def valid_value(self, value) -> dict:
        if isinstance(value, dict):
            return value
        raise TypeError("Value must be a dict")

    def to_literal(self) -> str:
        # Python repr of the dict (not JSON) — matches existing behavior.
        return str(self.value)
class FileObject(BaseVariable):
    # Stub implementation: file variables are not supported yet.
    type = 'file'

    def valid_value(self, value) -> Any:
        # TODO(review): placeholder — returns None for any input, so the
        # stored value is always None regardless of what the caller passes.
        pass

    def to_literal(self) -> str:
        # TODO(review): placeholder — returns None, which violates the
        # declared ``-> str`` contract; confirm before relying on it.
        pass
class ArrayObject(BaseVariable, Generic[T]):
    """Workflow variable holding a homogeneous array.

    Each element is wrapped in a ``child_type`` instance, whose constructor
    validates the element value.

    Args:
        child_type: BaseVariable subclass used to wrap/validate each element.
        value: The raw list of element values.

    Raises:
        TypeError: If ``child_type`` is not a BaseVariable subclass, or any
            element fails the child type's validation.
    """

    # NOTE(review): the inherited get_value() returns the list of wrapped
    # element objects, not the raw values — confirm callers expect that.
    type = 'array'

    def __init__(self, child_type: Type[T], value: list[Any]):
        if not issubclass(child_type, BaseVariable):
            raise TypeError("child_type must be a subclass of BaseVariable")
        self.child_type = child_type
        super().__init__(value)

    def valid_value(self, value: list[Any]) -> list[T]:
        if not isinstance(value, list):
            raise TypeError("Value must be a list")
        final_value = []
        for v in value:
            try:
                final_value.append(self.child_type(v))
            except Exception as e:
                # Bug fix: the original used a bare ``except:``, which also
                # swallowed KeyboardInterrupt/SystemExit; narrow the catch and
                # chain the cause so the underlying error is not lost.
                raise TypeError(
                    f"All elements must be of type {self.child_type.type}"
                ) from e
        return final_value

    def to_literal(self) -> str:
        # One element literal per line.
        return "\n".join([v.to_literal() for v in self.value])
class NestedArrayObject(BaseVariable):
    """Workflow variable holding a two-level array: a list of ArrayObject rows."""

    type = 'array_nest'

    def valid_value(self, value: list[T]) -> list[T]:
        if not isinstance(value, list):
            raise TypeError("Value must be a list")
        final_value = []
        for v in value:
            if not isinstance(v, ArrayObject):
                raise TypeError("All elements must be of type list")
            final_value.append(v)
        return final_value

    def to_literal(self) -> str:
        # Bug fix: rows are ArrayObject instances, which define no __iter__,
        # so the original ``for item in row`` raised TypeError at runtime.
        # Delegate to each row's own literal rendering (newline-joined).
        return "\n".join([row.to_literal() for row in self.value])

    def get_value(self) -> Any:
        # Bug fix: iterate the row's underlying element list (``row.value``),
        # not the ArrayObject itself; unwrap each element to its raw value.
        return [[item.get_value() for item in row.value] for row in self.value]
def make_array(child_type: Type[T], value: list[Any]) -> ArrayObject[T]:
    """Convenience factory for ArrayObject so callers need not repeat the type.

    Args:
        child_type: BaseVariable subclass used for each element.
        value: Raw list of element values.

    Returns:
        An ArrayObject wrapping ``value`` with ``child_type`` elements.
    """
    return ArrayObject(child_type, value)
def create_variable_instance(var_type: VariableType, value: Any) -> T:
match var_type:
case VariableType.STRING:
return StringObject(value)
case VariableType.NUMBER:
return NumberObject(value)
case VariableType.BOOLEAN:
return BooleanObject(value)
case VariableType.OBJECT:
return DictObject(value)
case VariableType.ARRAY_STRING:
return make_array(StringObject, value)
case VariableType.ARRAY_NUMBER:
return make_array(NumberObject, value)
case VariableType.ARRAY_BOOLEAN:
return make_array(BooleanObject, value)
case VariableType.ARRAY_OBJECT:
return make_array(DictObject, value)
case VariableType.ARRAY_FILE:
return make_array(FileObject, value)
case _:
raise TypeError(f"Invalid type - {var_type}")

View File

@@ -11,10 +11,15 @@
import logging
import re
from typing import Any, TYPE_CHECKING
from asyncio import Lock
from collections import defaultdict
from copy import deepcopy
from typing import Any, Generic
if TYPE_CHECKING:
from app.core.workflow.nodes import WorkflowState
from pydantic import BaseModel
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable.variable_objects import T, create_variable_instance
logger = logging.getLogger(__name__)
@@ -23,11 +28,6 @@ class VariableSelector:
"""变量选择器
用于引用变量的路径表示。
Examples:
>>> selector = VariableSelector(["sys", "message"])
>>> selector = VariableSelector(["node_A", "output"])
>>> selector = VariableSelector.from_string("sys.message")
"""
def __init__(self, path: list[str]):
@@ -52,10 +52,6 @@ class VariableSelector:
Returns:
VariableSelector 实例
Examples:
>>> selector = VariableSelector.from_string("sys.message")
>>> selector = VariableSelector.from_string("llm_qa.output")
"""
path = selector_str.split(".")
return cls(path)
@@ -67,160 +63,212 @@ class VariableSelector:
return f"VariableSelector({self.path})"
class VariableStruct(BaseModel, Generic[T]):
"""A typed variable struct.
Represents a runtime variable with an associated logical type and
a concrete value object.
This class bridges the static type system (via generics) and the
runtime type system (via ``VariableType``).
Attributes:
type:
Logical variable type descriptor used for runtime validation,
serialization, and workflow type checking.
instance:
The concrete variable object. The actual Python type is
represented by the generic parameter ``T`` (e.g. StringObject,
NumberObject, ArrayObject[StringObject]).
mut:
Whether the variable is mutable.
"""
type: VariableType
instance: T
mut: bool
model_config = {
"arbitrary_types_allowed": True
}
class VariablePool:
"""变量池
管理工作流执行过程中的所有变量。
变量命名空间:
- sys.*: 系统变量message, execution_id, workspace_id, user_id, conversation_id
- conv.*: 会话变量(跨多轮对话保持的变量)
- <node_id>.*: 节点输出
Examples:
>>> pool = VariablePool(state)
>>> pool.get(["sys", "message"])
"用户的问题"
>>> pool.get(["llm_qa", "output"])
"AI 的回答"
>>> pool.set(["conv", "user_name"], "张三")
"""Variable pool.
Manages all variables during workflow execution, including storage,
namespacing, and concurrency control.
Variable namespace conventions:
- ``sys.*``:
System variables (e.g. message, execution_id, workspace_id,
user_id, conversation_id).
- ``conv.*``:
Conversation-level variables that persist across multiple turns.
- ``<node_id>.*``:
Variables produced by workflow nodes.
"""
def __init__(self, state: "WorkflowState"):
"""初始化变量池
Args:
state: 工作流状态LangGraph State
"""
self.state = state
def __init__(self):
"""Initialize the variable pool.
Attributes:
self.locks:
A per-key lock table used for fine-grained concurrency control.
self.variables:
Storage for all variables managed by the pool.
"""
self.locks = defaultdict(Lock)
self.variables: dict[str, dict[str, VariableStruct[Any]]] = {}
@staticmethod
def transform_selector(selector):
pattern = r"\{\{\s*(.*?)\s*\}\}"
variable_literal = re.sub(pattern, r"\1", selector).strip()
selector = VariableSelector.from_string(variable_literal).path
if len(selector) != 2:
raise ValueError(f"Selector not valid - {selector}")
return selector
def _get_variable_struct(
self,
selector: str
) -> VariableStruct[T] | None:
"""Retrieve a variable struct from the variable pool.
def get(self, selector: list[str] | str, default: Any = None) -> Any:
"""获取变量值
Args:
selector: 变量选择器,可以是列表或字符串
default: 默认值(变量不存在时返回)
selector:
Variable selector, either:
- A string variable literal (e.g. "{{ sys.message }}")
Returns:
变量值
Examples:
>>> pool.get(["sys", "message"])
>>> pool.get("sys.message")
>>> pool.get(["llm_qa", "output"])
>>> pool.get("llm_qa.output")
Raises:
KeyError: 变量不存在且未提供默认值
The variable's struct if it exists; otherwise returns None.
"""
# 转换为 VariableSelector
if isinstance(selector, str):
pattern = r"\{\{\s*(.*?)\s*\}\}"
variable_literal = re.sub(pattern, r"\1", selector).strip()
selector = VariableSelector.from_string(variable_literal).path
if not selector or len(selector) < 1:
raise ValueError("变量选择器不能为空")
selector = self.transform_selector(selector)
namespace = selector[0]
variable_name = selector[1]
try:
# 系统变量
if namespace == "sys":
key = selector[1] if len(selector) > 1 else None
if not key:
return self.state.get("variables", {}).get("sys", {})
return self.state.get("variables", {}).get("sys", {}).get(key, default)
namespace_variables = self.variables.get(namespace)
if namespace_variables is None:
return None
# 会话变量
elif namespace == "conv":
key = selector[1] if len(selector) > 1 else None
if not key:
return self.state.get("variables", {}).get("conv", {})
return self.state.get("variables", {}).get("conv", {}).get(key, default)
var_instance = namespace_variables.get(variable_name)
if var_instance is None:
return None
return var_instance
# 节点输出(从 runtime_vars 读取)
else:
node_id = namespace
runtime_vars = self.state.get("runtime_vars", {})
def get_value(
self,
selector: str,
default: Any = None,
strict: bool = True,
) -> Any:
"""Retrieve a variable value from the variable pool.
if node_id not in runtime_vars:
if default is not None:
return default
raise KeyError(f"节点 '{node_id}' 的输出不存在")
Args:
selector:
Variable selector, either:
- A list of path components (e.g. ["sys", "message"])
- A string variable literal (e.g. "{{ sys.message }}")
default:
The value to return if the variable does not exist.
strict:
If True, raises KeyError when the variable does not exist.
node_var = runtime_vars[node_id]
Returns:
The variable's value if it exists; otherwise returns `default`.
# 如果只有节点 ID返回整个变量
if len(selector) == 1:
return node_var
Raises:
KeyError: If strict is True and the variable does not exist.
"""
variable_struct = self._get_variable_struct(selector)
if variable_struct is None:
if strict:
raise KeyError(f"{selector} not exist")
return default
# 获取特定字段
# 支持嵌套访问,如 node_id.field.subfield
result = node_var
for k in selector[1:]:
if isinstance(result, dict):
result = result.get(k)
if result is None:
if default is not None:
return default
raise KeyError(f"字段 '{'.'.join(selector)}' 不存在")
else:
if default is not None:
return default
raise KeyError(f"无法访问 '{'.'.join(selector)}'")
return variable_struct.instance.get_value()
return result
def get_literal(
self,
selector: str,
default: Any = None,
strict: bool = True,
) -> Any:
"""Retrieve a variable value from the variable pool.
except KeyError:
if default is not None:
return default
raise
Args:
selector:
Variable selector, either:
- A list of path components (e.g. ["sys", "message"])
- A string variable literal (e.g. "{{ sys.message }}")
default:
The value to return if the variable does not exist.
strict:
If True, raises KeyError when the variable does not exist.
def set(self, selector: list[str] | str, value: Any):
Returns:
The variable's value if it exists; otherwise returns `default`.
Raises:
KeyError: If strict is True and the variable does not exist.
"""
variable_struct = self._get_variable_struct(selector)
if variable_struct is None:
if strict:
raise KeyError(f"{selector} not exist")
return default
return variable_struct.instance.to_literal()
async def set(
self,
selector: str,
value: Any
):
"""设置变量值
Args:
selector: 变量选择器
value: 变量值
Examples:
>>> pool.set(["conv", "user_name"], "张三")
>>> pool.set("conv.user_name", "张三")
Note:
- 只能设置会话变量 (conv.*)
- 系统变量和节点输出是只读的
"""
# 转换为 VariableSelector
if isinstance(selector, str):
selector = VariableSelector.from_string(selector).path
variable_struct = self._get_variable_struct(selector)
if variable_struct is None:
raise KeyError(f"Variable {selector} is not defined")
if not variable_struct.mut:
raise KeyError(f"{selector} cannot be modified")
async with self.locks[selector]:
variable_struct.instance.set(value)
if not selector or len(selector) < 2:
raise ValueError("变量选择器必须包含命名空间和键名")
async def new(
self,
namespace: str,
key: str,
value: Any,
var_type: VariableType,
mut: bool
):
if self.has(f"{namespace}.{key}"):
try:
await self.set(f"{namespace}.{key}", value)
except KeyError:
pass
instance = create_variable_instance(var_type, value)
variable_struct = VariableStruct(type=var_type, instance=instance, mut=mut)
namespace_variable = self.variables.get(namespace)
if namespace_variable is None:
self.variables[namespace] = {
key: variable_struct
}
else:
self.variables[namespace][key] = variable_struct
namespace = selector[0]
if namespace != "conv" and namespace not in self.state["cycle_nodes"]:
raise ValueError("Only conversation or cycle variables can be assigned.")
key = selector[1]
# 确保 variables 结构存在
if "variables" not in self.state:
self.state["variables"] = {"sys": {}, "conv": {}}
if namespace == "conv":
if "conv" not in self.state["variables"]:
self.state["variables"]["conv"] = {}
# 设置值
self.state["variables"]["conv"][key] = value
elif namespace in self.state["cycle_nodes"]:
self.state["runtime_vars"][namespace][key] = value
logger.debug(f"设置变量: {'.'.join(selector)} = {value}")
def has(self, selector: list[str] | str) -> bool:
def has(self, selector: str) -> bool:
"""检查变量是否存在
Args:
@@ -228,18 +276,8 @@ class VariablePool:
Returns:
变量是否存在
Examples:
>>> pool.has(["sys", "message"])
True
>>> pool.has("llm_qa.output")
False
"""
try:
self.get(selector)
return True
except KeyError:
return False
return self._get_variable_struct(selector) is not None
def get_all_system_vars(self) -> dict[str, Any]:
"""获取所有系统变量
@@ -247,7 +285,8 @@ class VariablePool:
Returns:
系统变量字典
"""
return self.state.get("variables", {}).get("sys", {})
sys_namespace = self.variables.get("sys", {})
return {k: v.instance.value for k, v in sys_namespace.items()}
def get_all_conversation_vars(self) -> dict[str, Any]:
"""获取所有会话变量
@@ -255,7 +294,8 @@ class VariablePool:
Returns:
会话变量字典
"""
return self.state.get("variables", {}).get("conv", {})
conv_namespace = self.variables.get("conv", {})
return {k: v.instance.value for k, v in conv_namespace.items()}
def get_all_node_outputs(self) -> dict[str, Any]:
"""获取所有节点输出(运行时变量)
@@ -263,18 +303,37 @@ class VariablePool:
Returns:
节点输出字典,键为节点 ID
"""
return self.state.get("runtime_vars", {})
runtime_vars = {
namespace: {
k: v.instance.value
for k, v in vars_dict.items()
}
for namespace, vars_dict in self.variables.items()
if namespace not in ("sys", "conv")
}
return runtime_vars
def get_node_output(self, node_id: str) -> dict[str, Any] | None:
def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None:
"""获取指定节点的输出(运行时变量)
Args:
node_id: 节点 ID
defalut: 默认值
strict: 是否严格模式
Returns:
节点输出或 None
"""
return self.state.get("runtime_vars", {}).get(node_id)
node_namespace = self.variables.get(node_id)
if node_namespace:
return {k: v.instance.value for k, v in node_namespace.items()}
if strict:
raise KeyError(f"node {node_id} output not exist")
else:
return defalut
def copy(self, pool: 'VariablePool'):
self.variables = deepcopy(pool.variables)
def to_dict(self) -> dict[str, Any]:
"""导出为字典

View File

@@ -618,6 +618,7 @@ class AppChatService:
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
public=False
) -> AsyncGenerator[dict, None]:
"""聊天(流式)"""
@@ -634,7 +635,8 @@ class AppChatService:
payload=payload,
config=config,
workspace_id=workspace_id,
release_id=release_id
release_id=release_id,
public=public
):
yield event

View File

@@ -4,9 +4,8 @@
import datetime
import logging
import uuid
from typing import Any, Annotated, AsyncGenerator, Optional
from typing import Any, Annotated, Optional
from deprecated import deprecated
from fastapi import Depends
from sqlalchemy.orm import Session
@@ -566,6 +565,41 @@ class WorkflowService:
message=f"工作流执行失败: {str(e)}"
)
@staticmethod
def _map_public_event(event: dict) -> dict | None:
event_type = event.get("event")
payload = event.get("data")
match event_type:
case "workflow_start":
return {
"event": "start",
"data": {
"conversation_id": payload.get("conversation_id"),
}
}
case "workflow_end":
return {
"event": "end",
"data": {
"elapsed_time": payload.get("elapsed_time"),
"message_length": len(payload.get("output", ""))
}
}
case "node_start" | "node_end" | "node_error":
return None
case _:
return event
def _emit(self, public: bool, internal_event: dict):
"""
decide
"""
if public:
mapped = self._map_public_event(internal_event)
else:
mapped = internal_event
return mapped
async def run_stream(
self,
app_id: uuid.UUID,
@@ -663,7 +697,7 @@ class WorkflowService:
input_data=input_data,
execution_id=execution.execution_id,
workspace_id=str(workspace_id),
user_id=payload.user_id
user_id=payload.user_id,
):
if event.get("event") == "workflow_end":
@@ -694,7 +728,9 @@ class WorkflowService:
)
else:
logger.error(f"unexpect workflow run status, status: {status}")
yield event
event = self._emit(public, event)
if event:
yield event
except Exception as e:
logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)

View File

@@ -33,7 +33,7 @@ async def run_code(request: RunCodeRequest):
"""Execute code in sandbox"""
if request.language == "python3":
return await run_python_code(request.code, request.preload, request.options)
elif request.language == "nodejs":
elif request.language == "javascript":
return await run_nodejs_code(request.code, request.preload, request.options)
else:
return error_response(-400, "unsupported language")