[ADD] Merge code

This commit is contained in:
Mark
2025-12-15 19:50:21 +08:00
parent ea0a445d5b
commit 7bbef35b7d
54 changed files with 6956 additions and 652 deletions

View File

@@ -0,0 +1,436 @@
"""
工作流执行器
基于 LangGraph 的工作流执行引擎。
"""
import logging
import uuid
import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langgraph.graph import StateGraph, START, END
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.expression_evaluator import evaluate_condition
from app.models.workflow_model import WorkflowExecution, WorkflowNodeExecution
from app.db import get_db
logger = logging.getLogger(__name__)
class WorkflowExecutor:
"""工作流执行器
负责将工作流配置转换为 LangGraph 并执行。
"""
def __init__(
self,
workflow_config: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
):
"""初始化执行器
Args:
workflow_config: 工作流配置
execution_id: 执行 ID
workspace_id: 工作空间 ID
user_id: 用户 ID
"""
self.workflow_config = workflow_config
self.execution_id = execution_id
self.workspace_id = workspace_id
self.user_id = user_id
self.nodes = workflow_config.get("nodes", [])
self.edges = workflow_config.get("edges", [])
self.execution_config = workflow_config.get("execution_config", {})
def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState:
"""准备初始状态(注入系统变量和会话变量)
变量命名空间:
- sys.xxx - 系统变量execution_id, workspace_id, user_id, message, input_variables 等)
- conv.xxx - 会话变量(跨多轮对话保持)
- node_id.xxx - 节点输出(执行时动态生成)
Args:
input_data: 输入数据
Returns:
初始化的工作流状态
"""
user_message = input_data.get("message") or ""
conversation_vars = input_data.get("conversation_vars") or {}
input_variables = input_data.get("variables") or {} # Start 节点的自定义变量
# 构建分层的变量结构
variables = {
"sys": {
"message": user_message, # 用户消息
"conversation_id": input_data.get("conversation_id"), # 会话 ID
"execution_id": self.execution_id, # 执行 ID
"workspace_id": self.workspace_id, # 工作空间 ID
"user_id": self.user_id, # 用户 ID
"input_variables": input_variables, # 自定义输入变量(给 Start 节点使用)
},
"conv": conversation_vars # 会话级变量(跨多轮对话保持)
}
return {
"messages": [HumanMessage(content=user_message)],
"variables": variables,
"node_outputs": {},
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)
"execution_id": self.execution_id,
"workspace_id": self.workspace_id,
"user_id": self.user_id,
"error": None,
"error_node": None
}
def build_graph(self) -> StateGraph:
"""构建 LangGraph
Returns:
编译后的状态图
"""
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
# 1. 创建状态图
workflow = StateGraph(WorkflowState)
# 2. 添加所有节点(包括 start 和 end
start_node_id = None
end_node_ids = []
for node in self.nodes:
node_type = node.get("type")
node_id = node.get("id")
# 记录 start 和 end 节点 ID
if node_type == "start":
start_node_id = node_id
elif node_type == "end":
end_node_ids.append(node_id)
# 创建节点实例(现在 start 和 end 也会被创建)
node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_instance:
# 包装节点的 run 方法
# 使用函数工厂避免闭包问题
def make_node_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
return node_func
workflow.add_node(node_id, make_node_func(node_instance))
logger.debug(f"添加节点: {node_id} (type={node_type})")
# 3. 添加边
# 从 START 连接到 start 节点
if start_node_id:
workflow.add_edge(START, start_node_id)
logger.debug(f"添加边: START -> {start_node_id}")
for edge in self.edges:
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
condition = edge.get("condition")
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start
if source == start_node_id:
# 但要连接 start 到下一个节点
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
continue
# 处理到 end 节点的边
if target in end_node_ids:
# 连接到 end 节点
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
continue
# 跳过错误边(在节点内部处理)
if edge_type == "error":
continue
if condition:
# 条件边
def router(state: WorkflowState, cond=condition, tgt=target):
"""条件路由函数"""
if evaluate_condition(
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
):
return tgt
return END # 条件不满足,结束
workflow.add_conditional_edges(source, router)
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
else:
# 普通边
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
# 从 end 节点连接到 END
for end_node_id in end_node_ids:
workflow.add_edge(end_node_id, END)
logger.debug(f"添加边: {end_node_id} -> END")
# 4. 编译图
graph = workflow.compile()
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
return graph
async def execute(
self,
input_data: dict[str, Any]
) -> dict[str, Any]:
"""执行工作流(非流式)
Args:
input_data: 输入数据,包含 message 和 variables
Returns:
执行结果,包含 status, output, node_outputs, elapsed_time, token_usage
"""
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
# 记录开始时间
start_time = datetime.datetime.now()
# 1. 构建图
graph = self.build_graph()
# 2. 初始化状态(自动注入系统变量)
initial_state = self._prepare_initial_state(input_data)
# 3. 执行工作流
try:
result = await graph.ainvoke(initial_state)
# 计算耗时
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
# 提取节点输出(现在包含 start 和 end 节点)
node_outputs = result.get("node_outputs", {})
# 提取最终输出(从最后一个非 start/end 节点)
final_output = self._extract_final_output(node_outputs)
# 聚合 token 使用情况
token_usage = self._aggregate_token_usage(node_outputs)
# 提取 conversation_id从 start 节点输出)
conversation_id = None
for node_id, node_output in node_outputs.items():
if node_output.get("node_type") == "start":
conversation_id = node_output.get("output", {}).get("conversation_id")
break
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
return {
"status": "completed",
"output": final_output,
"node_outputs": node_outputs,
"messages": result.get("messages", []),
"conversation_id": conversation_id,
"elapsed_time": elapsed_time,
"token_usage": token_usage,
"error": result.get("error")
}
except Exception as e:
# 计算耗时(即使失败也记录)
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True)
return {
"status": "failed",
"error": str(e),
"output": None,
"node_outputs": {},
"elapsed_time": elapsed_time,
"token_usage": None
}
async def execute_stream(
self,
input_data: dict[str, Any]
):
"""执行工作流(流式)
Args:
input_data: 输入数据
Yields:
流式事件
"""
logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}")
# 1. 构建图
graph = self.build_graph()
# 2. 初始化状态(自动注入系统变量)
initial_state = self._prepare_initial_state(input_data)
# 3. 流式执行工作流
try:
# 使用 astream 获取节点级别的更新
async for event in graph.astream(initial_state, stream_mode="updates"):
for node_name, state_update in event.items():
yield {
"type": "node_complete",
"node": node_name,
"data": state_update,
"execution_id": self.execution_id
}
logger.info(f"工作流执行完成(流式): execution_id={self.execution_id}")
# 发送完成事件
yield {
"type": "workflow_complete",
"execution_id": self.execution_id
}
except Exception as e:
logger.error(f"工作流执行失败(流式): execution_id={self.execution_id}, error={e}", exc_info=True)
yield {
"type": "workflow_error",
"execution_id": self.execution_id,
"error": str(e)
}
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:
"""从节点输出中提取最终输出
优先级:
1. 最后一个执行的非 start/end 节点的 output
2. 如果没有节点输出,返回 None
Args:
node_outputs: 所有节点的输出
Returns:
最终输出字符串或 None
"""
if not node_outputs:
return None
# 获取最后一个节点的输出
last_node_output = list(node_outputs.values())[-1] if node_outputs else None
if last_node_output and isinstance(last_node_output, dict):
return last_node_output.get("output")
return None
def _aggregate_token_usage(self, node_outputs: dict[str, Any]) -> dict[str, int] | None:
"""聚合所有节点的 token 使用情况
Args:
node_outputs: 所有节点的输出
Returns:
聚合的 token 使用情况 {"prompt_tokens": x, "completion_tokens": y, "total_tokens": z}
如果没有 token 使用信息,返回 None
"""
total_prompt_tokens = 0
total_completion_tokens = 0
total_tokens = 0
has_token_info = False
for node_output in node_outputs.values():
if isinstance(node_output, dict):
token_usage = node_output.get("token_usage")
if token_usage and isinstance(token_usage, dict):
has_token_info = True
total_prompt_tokens += token_usage.get("prompt_tokens", 0)
total_completion_tokens += token_usage.get("completion_tokens", 0)
total_tokens += token_usage.get("total_tokens", 0)
if not has_token_info:
return None
return {
"prompt_tokens": total_prompt_tokens,
"completion_tokens": total_completion_tokens,
"total_tokens": total_tokens
}
async def execute_workflow(
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
) -> dict[str, Any]:
"""执行工作流(便捷函数)
Args:
workflow_config: 工作流配置
input_data: 输入数据
execution_id: 执行 ID
workspace_id: 工作空间 ID
user_id: 用户 ID
Returns:
执行结果
"""
executor = WorkflowExecutor(
workflow_config=workflow_config,
execution_id=execution_id,
workspace_id=workspace_id,
user_id=user_id
)
return await executor.execute(input_data)
async def execute_workflow_stream(
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
):
"""执行工作流(流式,便捷函数)
Args:
workflow_config: 工作流配置
input_data: 输入数据
execution_id: 执行 ID
workspace_id: 工作空间 ID
user_id: 用户 ID
Yields:
流式事件
"""
executor = WorkflowExecutor(
workflow_config=workflow_config,
execution_id=execution_id,
workspace_id=workspace_id,
user_id=user_id
)
async for event in executor.execute_stream(input_data):
yield event

View File

@@ -0,0 +1,195 @@
"""
安全的表达式求值器
使用 simpleeval 库提供安全的表达式评估,避免代码注入攻击。
"""
import logging
from typing import Any
from simpleeval import simple_eval, NameNotDefined, InvalidExpression
logger = logging.getLogger(__name__)
class ExpressionEvaluator:
"""安全的表达式求值器"""
# 保留的命名空间
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
@staticmethod
def evaluate(
expression: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> Any:
"""安全地评估表达式
Args:
expression: 表达式字符串,如 "{{var.score}} > 0.8"
variables: 用户定义的变量
node_outputs: 节点输出结果
system_vars: 系统变量
Returns:
表达式求值结果
Raises:
ValueError: 表达式无效或求值失败
Examples:
>>> evaluator = ExpressionEvaluator()
>>> evaluator.evaluate(
... "var.score > 0.8",
... {"score": 0.9},
... {},
... {}
... )
True
>>> evaluator.evaluate(
... "node.intent.output == '售前咨询'",
... {},
... {"intent": {"output": "售前咨询"}},
... {}
... )
True
"""
# 移除 Jinja2 模板语法的花括号(如果存在)
expression = expression.strip()
if expression.startswith("{{") and expression.endswith("}}"):
expression = expression[2:-2].strip()
# 构建命名空间上下文
context = {
"var": variables, # 用户变量
"node": node_outputs, # 节点输出
"sys": system_vars or {}, # 系统变量
}
# 为了向后兼容,也支持直接访问(但会在日志中警告)
context.update(variables)
context["nodes"] = node_outputs
try:
# simpleeval 只支持安全的操作:
# - 算术运算: +, -, *, /, //, %, **
# - 比较运算: ==, !=, <, <=, >, >=
# - 逻辑运算: and, or, not
# - 成员运算: in, not in
# - 属性访问: obj.attr
# - 字典/列表访问: obj["key"], obj[0]
# 不支持:函数调用、导入、赋值等危险操作
result = simple_eval(expression, names=context)
return result
except NameNotDefined as e:
logger.error(f"表达式中引用了未定义的变量: {expression}, 错误: {e}")
raise ValueError(f"未定义的变量: {e}")
except InvalidExpression as e:
logger.error(f"表达式语法无效: {expression}, 错误: {e}")
raise ValueError(f"表达式语法无效: {e}")
except SyntaxError as e:
logger.error(f"表达式语法错误: {expression}, 错误: {e}")
raise ValueError(f"表达式语法错误: {e}")
except Exception as e:
logger.error(f"表达式求值异常: {expression}, 错误: {e}")
raise ValueError(f"表达式求值失败: {e}")
@staticmethod
def evaluate_bool(
expression: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> bool:
"""评估布尔表达式(用于条件判断)
Args:
expression: 布尔表达式
variables: 用户变量
node_outputs: 节点输出
system_vars: 系统变量
Returns:
布尔值结果
Examples:
>>> ExpressionEvaluator.evaluate_bool(
... "var.count >= 10 and var.status == 'active'",
... {"count": 15, "status": "active"},
... {},
... {}
... )
True
"""
result = ExpressionEvaluator.evaluate(
expression, variables, node_outputs, system_vars
)
return bool(result)
@staticmethod
def validate_variable_names(variables: list[dict]) -> list[str]:
"""验证变量名是否合法
Args:
variables: 变量定义列表
Returns:
错误列表,如果为空则验证通过
Examples:
>>> ExpressionEvaluator.validate_variable_names([
... {"name": "user_input"},
... {"name": "var"} # 保留字
... ])
["变量名 'var' 是保留的命名空间,请使用其他名称"]
"""
errors = []
for var in variables:
var_name = var.get("name", "")
# 检查是否为保留命名空间
if var_name in ExpressionEvaluator.RESERVED_NAMESPACES:
errors.append(
f"变量名 '{var_name}' 是保留的命名空间,请使用其他名称"
)
# 检查是否为有效的 Python 标识符
if not var_name.isidentifier():
errors.append(
f"变量名 '{var_name}' 不是有效的标识符"
)
return errors
# 便捷函数
def evaluate_expression(
expression: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> Any:
"""评估表达式(便捷函数)"""
return ExpressionEvaluator.evaluate(
expression, variables, node_outputs, system_vars
)
def evaluate_condition(
expression: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> bool:
"""评估条件表达式(便捷函数)"""
return ExpressionEvaluator.evaluate_bool(
expression, variables, node_outputs, system_vars
)

View File

@@ -0,0 +1,24 @@
"""
工作流节点实现
提供各种类型的节点实现,用于工作流执行。
"""
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.end import EndNode
from app.core.workflow.nodes.node_factory import NodeFactory
__all__ = [
"BaseNode",
"WorkflowState",
"LLMNode",
"AgentNode",
"TransformNode",
"StartNode",
"EndNode",
"NodeFactory",
]

View File

@@ -0,0 +1,6 @@
"""Agent 节点"""
from app.core.workflow.nodes.agent.node import AgentNode
from app.core.workflow.nodes.agent.config import AgentNodeConfig
__all__ = ["AgentNode", "AgentNodeConfig"]

View File

@@ -0,0 +1,71 @@
"""Agent 节点配置"""
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
class AgentNodeConfig(BaseNodeConfig):
"""Agent 节点配置
调用已配置的 Agent 执行任务。
"""
agent_id: str = Field(
...,
description="Agent 配置 ID"
)
message: str = Field(
default="{{ sys.message }}",
description="发送给 Agent 的消息,支持模板变量"
)
conversation_id: str | None = Field(
default=None,
description="会话 ID用于多轮对话"
)
variables: dict[str, str] | None = Field(
default=None,
description="传递给 Agent 的变量"
)
timeout: int = Field(
default=300,
ge=1,
le=3600,
description="超时时间(秒)"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="output",
type=VariableType.STRING,
description="Agent 的回复内容"
),
VariableDefinition(
name="conversation_id",
type=VariableType.STRING,
description="会话 ID"
),
VariableDefinition(
name="token_usage",
type=VariableType.OBJECT,
description="Token 使用情况"
)
],
description="输出变量定义(自动生成,通常不需要修改)"
)
class Config:
json_schema_extra = {
"example": {
"agent_id": "uuid-here",
"message": "{{ sys.message }}",
"timeout": 300,
"description": "调用客服 Agent"
}
}

View File

@@ -0,0 +1,152 @@
"""
Agent 节点实现
调用已发布的 Agent 应用。
"""
import logging
from typing import Any
from langchain_core.messages import AIMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.services.draft_run_service import DraftRunService
from app.models import AppRelease
from app.db import get_db
logger = logging.getLogger(__name__)
class AgentNode(BaseNode):
"""Agent 节点
支持流式和非流式输出。
配置示例:
{
"type": "agent",
"config": {
"agent_id": "uuid", # Agent 的 release_id
"message": "{{var.user_input}}"
}
}
"""
def _prepare_agent(self, state: WorkflowState) -> tuple[DraftRunService, AppRelease, str]:
"""准备 Agent公共逻辑
Args:
state: 工作流状态
Returns:
(draft_service, release, message): 服务实例、发布配置、消息
"""
# 1. 渲染消息
message_template = self.config.get("message", "")
message = self._render_template(message_template, state)
# 2. 获取 Agent 配置
agent_id = self.config.get("agent_id")
if not agent_id:
raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置")
db = next(get_db())
release = db.query(AppRelease).filter(
AppRelease.id == agent_id
).first()
if not release:
raise ValueError(f"Agent 不存在: {agent_id}")
draft_service = DraftRunService(db)
return draft_service, release, message
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""非流式执行
Args:
state: 工作流状态
Returns:
状态更新字典
"""
draft_service, release, message = self._prepare_agent(state)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
# 执行 Agent非流式
result = await draft_service.run(
agent_config=release.config,
model_config=None,
message=message,
workspace_id=state.get("workspace_id"),
user_id=state.get("user_id"),
variables=state.get("variables", {})
)
response = result.get("response", "")
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(response)}")
return {
"messages": [AIMessage(content=response)],
"node_outputs": {
self.node_id: {
"output": response,
"status": "completed",
"meta_data": result.get("meta_data", {})
}
}
}
async def execute_stream(self, state: WorkflowState):
"""流式执行
Args:
state: 工作流状态
Yields:
流式事件字典
"""
draft_service, release, message = self._prepare_agent(state)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
# 累积完整响应
full_response = ""
# 执行 Agent流式
async for chunk in draft_service.run_stream(
agent_config=release.config,
model_config=None,
message=message,
workspace_id=state.get("workspace_id"),
user_id=state.get("user_id"),
variables=state.get("variables", {})
):
# 提取内容
content = chunk.get("content", "")
full_response += content
# 流式返回每个 chunk
yield {
"type": "chunk",
"node_id": self.node_id,
"content": content,
"full_content": full_response,
"meta_data": chunk.get("meta_data", {})
}
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}")
# 最后返回完整结果
yield {
"type": "complete",
"messages": [AIMessage(content=full_response)],
"node_outputs": {
self.node_id: {
"output": full_response,
"status": "completed"
}
}
}

View File

@@ -0,0 +1,109 @@
"""节点配置基类
定义所有节点配置的通用字段和数据结构。
"""
from enum import StrEnum
from pydantic import BaseModel, Field
class VariableType(StrEnum):
"""变量类型枚举"""
STRING = "string"
NUMBER = "number"
BOOLEAN = "boolean"
ARRAY = "array"
OBJECT = "object"
class VariableDefinition(BaseModel):
"""变量定义
定义工作流或节点的输入/输出变量。
这是一个通用的数据结构,可以在多个地方使用。
"""
name: str = Field(
...,
description="变量名称"
)
type: VariableType = Field(
default=VariableType.STRING,
description="变量类型"
)
required: bool = Field(
default=False,
description="是否必需"
)
default: str | int | float | bool | list | dict | None = Field(
default=None,
description="默认值"
)
description: str | None = Field(
default=None,
description="变量描述"
)
class Config:
json_schema_extra = {
"examples": [
{
"name": "language",
"type": "string",
"required": False,
"default": "zh-CN",
"description": "语言设置"
},
{
"name": "max_length",
"type": "number",
"required": False,
"default": 1000,
"description": "最大长度"
},
{
"name": "enable_search",
"type": "boolean",
"required": True,
"description": "是否启用搜索"
}
]
}
class BaseNodeConfig(BaseModel):
"""节点配置基类
所有节点配置都应该继承此基类。
通用字段:
- name: 节点名称(显示名称)
- description: 节点描述
- tags: 节点标签(用于分类和搜索)
"""
name: str | None = Field(
default=None,
description="节点名称(显示名称),如果不设置则使用节点 ID"
)
description: str | None = Field(
default=None,
description="节点描述,说明节点的作用"
)
tags: list[str] = Field(
default_factory=list,
description="节点标签,用于分类和搜索"
)
class Config:
"""Pydantic 配置"""
# 允许额外字段(向后兼容)
extra = "allow"

View File

@@ -0,0 +1,556 @@
"""
工作流节点基类
定义节点的基本接口和通用功能。
"""
import asyncio
import logging
from abc import ABC, abstractmethod
from typing import Any, TypedDict, Annotated
from operator import add
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
class WorkflowState(TypedDict):
"""工作流状态
在节点间传递的状态对象,包含消息、变量、节点输出等信息。
"""
# 消息列表(追加模式)
messages: Annotated[list[AnyMessage], add]
# 输入变量(从配置的 variables 传入)
variables: dict[str, Any]
# 节点输出(存储每个节点的执行结果,用于变量引用)
# 使用自定义合并函数,将新的节点输出合并到现有字典中
node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
# 运行时节点变量(简化版,只存储业务数据,供节点间快速访问)
# 格式:{node_id: business_result}
runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
# 执行上下文
execution_id: str
workspace_id: str
user_id: str
# 错误信息(用于错误边)
error: str | None
error_node: str | None
class BaseNode(ABC):
"""节点基类
所有节点类型都应该继承此基类,实现 execute 方法。
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
"""初始化节点
Args:
node_config: 节点配置
workflow_config: 工作流配置
"""
self.node_config = node_config
self.workflow_config = workflow_config
self.node_id = node_config["id"]
self.node_type = node_config["type"]
self.node_name = node_config.get("name", self.node_id)
# 使用 or 运算符处理 None 值
self.config = node_config.get("config") or {}
self.error_handling = node_config.get("error_handling") or {}
@abstractmethod
async def execute(self, state: WorkflowState) -> Any:
"""执行节点业务逻辑(非流式)
节点只需要返回业务结果,不需要关心输出格式、时间统计等。
BaseNode 会自动包装成标准格式。
Args:
state: 工作流状态
Returns:
业务结果(任意类型)
Examples:
>>> # LLM 节点
>>> return "这是 AI 的回复"
>>> # Transform 节点
>>> return {"processed_data": [...]}
>>> # Start/End 节点
>>> return {"message": "开始", "conversation_id": "xxx"}
"""
pass
async def execute_stream(self, state: WorkflowState):
"""执行节点业务逻辑(流式)
子类可以重写此方法以支持流式输出。
默认实现:执行非流式方法并一次性返回。
节点需要:
1. yield 中间结果(如文本片段)
2. 最后 yield 一个特殊的完成标记:{"__final__": True, "result": final_result}
Args:
state: 工作流状态
Yields:
业务数据chunk或完成标记
Examples:
>>> # 流式 LLM 节点
>>> full_response = ""
>>> async for chunk in llm.astream(prompt):
... full_response += chunk
... yield chunk # yield 文本片段
>>>
>>> # 最后 yield 完成标记
>>> yield {"__final__": True, "result": AIMessage(content=full_response)}
"""
result = await self.execute(state)
# 默认实现:直接 yield 完成标记
yield {"__final__": True, "result": result}
def supports_streaming(self) -> bool:
"""节点是否支持流式输出
Returns:
是否支持流式输出
"""
# 检查子类是否重写了 execute_stream 方法
return self.execute_stream.__func__ != BaseNode.execute_stream.__func__
def get_timeout(self) -> int:
"""获取超时时间(秒)
Returns:
超时时间
"""
return 60
# return self.error_handling.get("timeout", 60)
async def run(self, state: WorkflowState) -> dict[str, Any]:
"""执行节点(带错误处理和输出包装,非流式)
这个方法由 Executor 调用,负责:
1. 时间统计
2. 调用节点的 execute() 方法
3. 将业务结果包装成标准输出格式
4. 错误处理
Args:
state: 工作流状态
Returns:
标准化的状态更新字典
"""
import time
start_time = time.time()
try:
timeout = self.get_timeout()
# 调用节点的业务逻辑
business_result = await asyncio.wait_for(
self.execute(state),
timeout=timeout
)
elapsed_time = time.time() - start_time
# 提取处理后的输出(调用子类的 _extract_output
extracted_output = self._extract_output(business_result)
# 包装成标准输出格式
wrapped_output = self._wrap_output(business_result, elapsed_time, state)
# 将提取后的输出存储到运行时变量中(供后续节点快速访问)
# 如果提取后的输出是字典,拆包存储;否则存储为 output 字段
if isinstance(extracted_output, dict):
runtime_var = extracted_output
else:
runtime_var = {"output": extracted_output}
# 返回包装后的输出和运行时变量
return {
**wrapped_output,
"runtime_vars": {
self.node_id: runtime_var
}
}
except TimeoutError:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
return self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
return self._wrap_error(str(e), elapsed_time, state)
async def run_stream(self, state: WorkflowState):
"""执行节点(带错误处理和输出包装,流式)
这个方法由 Executor 调用,负责:
1. 时间统计
2. 调用节点的 execute_stream() 方法
3. 将业务数据包装成标准输出格式
4. 错误处理
Args:
state: 工作流状态
Yields:
标准化的流式事件
"""
import time
start_time = time.time()
try:
timeout = self.get_timeout()
# 累积完整结果(用于最后的包装)
chunks = []
final_result = None
# 使用异步生成器包装,支持超时
async def stream_with_timeout():
nonlocal final_result
loop_start = asyncio.get_event_loop().time()
async for item in self.execute_stream(state):
# 检查超时
if asyncio.get_event_loop().time() - loop_start > timeout:
raise TimeoutError()
# 检查是否是完成标记
if isinstance(item, dict) and item.get("__final__"):
final_result = item["result"]
elif isinstance(item, str):
# 字符串是 chunk
chunks.append(item)
yield {
"type": "chunk",
"node_id": self.node_id,
"content": item,
"full_content": "".join(chunks)
}
else:
# 其他类型也当作 chunk 处理
chunks.append(str(item))
yield {
"type": "chunk",
"node_id": self.node_id,
"content": str(item),
"full_content": "".join(chunks)
}
async for chunk_event in stream_with_timeout():
yield chunk_event
elapsed_time = time.time() - start_time
# 包装最终结果
final_output = self._wrap_output(final_result, elapsed_time, state)
yield {
"type": "complete",
**final_output
}
except TimeoutError:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
yield {
"type": "error",
**self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
}
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
yield {
"type": "error",
**self._wrap_error(str(e), elapsed_time, state)
}
def _wrap_output(
self,
business_result: Any,
elapsed_time: float,
state: WorkflowState
) -> dict[str, Any]:
"""将业务结果包装成标准输出格式
Args:
business_result: 节点返回的业务结果
elapsed_time: 执行耗时
state: 工作流状态
Returns:
标准化的状态更新字典
"""
# 提取输入数据(用于记录)
input_data = self._extract_input(state)
# 提取 token 使用情况(如果有)
token_usage = self._extract_token_usage(business_result)
# 提取实际输出(去除元数据)
output = self._extract_output(business_result)
# 构建标准节点输出
node_output = {
"node_id": self.node_id,
"node_type": self.node_type,
"node_name": self.node_name,
"status": "completed",
"input": input_data,
"output": output,
"elapsed_time": elapsed_time,
"token_usage": token_usage,
"error": None
}
return {
"node_outputs": {
self.node_id: node_output
}
}
def _wrap_error(
self,
error_message: str,
elapsed_time: float,
state: WorkflowState
) -> dict[str, Any]:
"""将错误包装成标准输出格式
Args:
error_message: 错误信息
elapsed_time: 执行耗时
state: 工作流状态
Returns:
标准化的状态更新字典
"""
# 查找错误边
error_edge = self._find_error_edge()
# 提取输入数据
input_data = self._extract_input(state)
# 构建错误输出
node_output = {
"node_id": self.node_id,
"node_type": self.node_type,
"node_name": self.node_name,
"status": "failed",
"input": input_data,
"output": None,
"elapsed_time": elapsed_time,
"token_usage": None,
"error": error_message
}
if error_edge:
# 有错误边:记录错误并继续
logger.warning(
f"节点 {self.node_id} 执行失败,跳转到错误处理节点: {error_edge['target']}"
)
return {
"node_outputs": {
self.node_id: node_output
},
"error": error_message,
"error_node": self.node_id
}
else:
# 无错误边:抛出异常停止工作流
logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}")
raise Exception(f"节点 {self.node_id} 执行失败: {error_message}")
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
"""提取节点输入数据(用于记录)
子类可以重写此方法来自定义输入记录。
Args:
state: 工作流状态
Returns:
输入数据字典
"""
# 默认返回配置
return {"config": self.config}
def _extract_output(self, business_result: Any) -> Any:
"""从业务结果中提取实际输出
子类可以重写此方法来自定义输出提取。
Args:
business_result: 业务结果
Returns:
实际输出
"""
# 默认直接返回业务结果
return business_result
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
"""从业务结果中提取 token 使用情况
子类可以重写此方法来提取 token 信息。
Args:
business_result: 业务结果
Returns:
token 使用情况或 None
"""
# 默认返回 None
return None
def _find_error_edge(self) -> dict[str, Any] | None:
"""查找错误边
Returns:
错误边配置或 None
"""
for edge in self.workflow_config.get("edges", []):
if edge.get("source") == self.node_id and edge.get("type") == "error":
return edge
return None
def _render_template(self, template: str, state: WorkflowState | None) -> str:
"""渲染模板
支持的变量命名空间:
- sys.xxx: 系统变量message, execution_id, workspace_id, user_id, conversation_id
- conv.xxx: 会话变量(跨多轮对话保持)
- node_id.xxx: 节点输出
Args:
template: 模板字符串
state: 工作流状态
Returns:
渲染后的字符串
"""
from app.core.workflow.template_renderer import render_template
# 处理 state 为 None 的情况
if state is None:
state = {}
# 使用变量池获取变量
pool = VariablePool(state)
return render_template(
template=template,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars()
)
def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool:
"""评估条件表达式
支持的变量命名空间:
- sys.xxx: 系统变量
- conv.xxx: 会话变量
- node_id.xxx: 节点输出
Args:
expression: 条件表达式
state: 工作流状态
Returns:
布尔值结果
"""
from app.core.workflow.expression_evaluator import evaluate_condition
# 处理 state 为 None 的情况
if state is None:
state = {}
# 使用变量池获取变量
pool = VariablePool(state)
return evaluate_condition(
expression=expression,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars()
)
def get_variable_pool(self, state: WorkflowState) -> VariablePool:
"""获取变量池实例
VariablePool 是轻量级包装器,只持有 state 的引用,创建成本极低。
Args:
state: 工作流状态
Returns:
VariablePool 实例
Examples:
>>> pool = self.get_variable_pool(state)
>>> message = pool.get("sys.message")
>>> llm_output = pool.get("llm_qa.output")
"""
return VariablePool(state)
def get_variable(
self,
selector: list[str] | str,
state: WorkflowState,
default: Any = None
) -> Any:
"""获取变量值(便捷方法)
Args:
selector: 变量选择器
state: 工作流状态
default: 默认值
Returns:
变量值
Examples:
>>> message = self.get_variable("sys.message", state)
>>> output = self.get_variable(["llm_qa", "output"], state)
>>> custom = self.get_variable("var.custom", state, default="默认值")
"""
pool = VariablePool(state)
return pool.get(selector, default=default)
def has_variable(self, selector: list[str] | str, state: WorkflowState) -> bool:
"""检查变量是否存在(便捷方法)
Args:
selector: 变量选择器
state: 工作流状态
Returns:
变量是否存在
Examples:
>>> if self.has_variable("llm_qa.output", state):
... output = self.get_variable("llm_qa.output", state)
"""
pool = VariablePool(state)
return pool.has(selector)

View File

@@ -0,0 +1,29 @@
"""节点配置类统一导出
所有节点的配置类都在这里导出,方便使用。
"""
from app.core.workflow.nodes.base_config import (
BaseNodeConfig,
VariableDefinition,
VariableType,
)
from app.core.workflow.nodes.start.config import StartNodeConfig
from app.core.workflow.nodes.end.config import EndNodeConfig
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
from app.core.workflow.nodes.agent.config import AgentNodeConfig
from app.core.workflow.nodes.transform.config import TransformNodeConfig
__all__ = [
# 基础类
"BaseNodeConfig",
"VariableDefinition",
"VariableType",
# 节点配置
"StartNodeConfig",
"EndNodeConfig",
"LLMNodeConfig",
"MessageConfig",
"AgentNodeConfig",
"TransformNodeConfig",
]

View File

@@ -0,0 +1,6 @@
"""End 节点"""
from app.core.workflow.nodes.end.node import EndNode
from app.core.workflow.nodes.end.config import EndNodeConfig
__all__ = ["EndNode", "EndNodeConfig"]

View File

@@ -0,0 +1,37 @@
"""End 节点配置"""
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
class EndNodeConfig(BaseNodeConfig):
"""End 节点配置
End 节点负责输出工作流的最终结果。
"""
output: str = Field(
default="工作流已完成",
description="输出模板,支持引用前置节点的输出,如:{{ llm_qa.output }}"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="output",
type=VariableType.STRING,
description="工作流的最终输出"
)
],
description="输出变量定义(自动生成,通常不需要修改)"
)
class Config:
json_schema_extra = {
"example": {
"output": "{{ llm_qa.output }}",
"description": "输出 LLM 的回答"
}
}

View File

@@ -0,0 +1,53 @@
"""
End 节点实现
工作流的结束节点,输出最终结果。
"""
import logging
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
logger = logging.getLogger(__name__)
class EndNode(BaseNode):
"""End 节点
工作流的结束节点,根据配置的模板输出最终结果。
"""
async def execute(self, state: WorkflowState) -> str:
"""执行 end 节点业务逻辑
Args:
state: 工作流状态
Returns:
最终输出字符串
"""
logger.info(f"节点 {self.node_id} (End) 开始执行")
# 获取配置的输出模板
output_template = self.config.get("output")
pool = self.get_variable_pool(state)
print("="*20)
print( pool.get("start.test"))
print("="*20)
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template:
output = self._render_template(output_template, state)
else:
output = "工作流已完成"
# 统计信息(用于日志)
node_outputs = state.get("node_outputs", {})
total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
print("="*20)
print(output)
print("="*20)
return output

View File

@@ -0,0 +1,15 @@
from enum import StrEnum
class NodeType(StrEnum):
START = "start"
END = "end"
ANSWER = "answer"
LLM = "llm"
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
IF_ELSE = "if-else"
CODE = "code"
TRANSFORM = "transform"
QUESTION_CLASSIFIER = "question-classifier"
HTTP_REQUEST = "http-request"
TOOL = "tool"
AGENT = "agent"

View File

@@ -0,0 +1,6 @@
"""LLM 节点"""
from app.core.workflow.nodes.llm.node import LLMNode
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
__all__ = ["LLMNode", "LLMNodeConfig", "MessageConfig"]

View File

@@ -0,0 +1,141 @@
"""LLM 节点配置"""
from pydantic import BaseModel, Field, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
class MessageConfig(BaseModel):
"""消息配置"""
role: str = Field(
...,
description="消息角色system, user, assistant"
)
content: str = Field(
...,
description="消息内容,支持模板变量,如:{{ sys.message }}"
)
@field_validator("role")
@classmethod
def validate_role(cls, v: str) -> str:
"""验证角色"""
allowed_roles = ["system", "user", "human", "assistant", "ai"]
if v.lower() not in allowed_roles:
raise ValueError(f"角色必须是以下之一: {', '.join(allowed_roles)}")
return v.lower()
class LLMNodeConfig(BaseNodeConfig):
"""LLM 节点配置
支持两种配置方式:
1. 简单模式:使用 prompt 字段
2. 消息模式:使用 messages 字段(推荐)
"""
model_id: str = Field(
...,
description="模型配置 ID"
)
# 简单模式
prompt: str | None = Field(
default=None,
description="提示词模板(简单模式),支持变量引用"
)
# 消息模式(推荐)
messages: list[MessageConfig] | None = Field(
default=None,
description="消息列表(消息模式),支持多轮对话"
)
# 模型参数
temperature: float | None = Field(
default=0.7,
ge=0.0,
le=2.0,
description="温度参数,控制输出的随机性"
)
max_tokens: int | None = Field(
default=1000,
ge=1,
le=32000,
description="最大生成 token 数"
)
top_p: float | None = Field(
default=None,
ge=0.0,
le=1.0,
description="Top-p 采样参数"
)
frequency_penalty: float | None = Field(
default=None,
ge=-2.0,
le=2.0,
description="频率惩罚"
)
presence_penalty: float | None = Field(
default=None,
ge=-2.0,
le=2.0,
description="存在惩罚"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="output",
type=VariableType.STRING,
description="LLM 生成的文本输出"
),
VariableDefinition(
name="token_usage",
type=VariableType.OBJECT,
description="Token 使用情况"
)
],
description="输出变量定义(自动生成,通常不需要修改)"
)
@field_validator("messages", "prompt")
@classmethod
def validate_input_mode(cls, v, info):
"""验证输入模式prompt 和 messages 至少有一个"""
# 这个验证在 model_validator 中更合适
return v
class Config:
json_schema_extra = {
"examples": [
{
"model_id": "uuid-here",
"prompt": "请回答:{{ sys.message }}",
"temperature": 0.7,
"max_tokens": 1000
},
{
"model_id": "uuid-here",
"messages": [
{
"role": "system",
"content": "你是一个专业的 AI 助手"
},
{
"role": "user",
"content": "{{ sys.message }}"
}
],
"temperature": 0.7,
"max_tokens": 1000
}
]
}

View File

@@ -0,0 +1,247 @@
"""
LLM 节点实现
调用 LLM 模型进行文本生成。
"""
import logging
from typing import Any
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models import ModelConfig
from app.db import get_db, get_db_context
from app.models.models_model import ModelApiKey
from app.services.model_service import ModelConfigService, ModelApiKeyService
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
logger = logging.getLogger(__name__)
class LLMNode(BaseNode):
"""LLM 节点
支持流式和非流式输出,使用 LangChain 标准的消息格式。
配置示例(支持多种消息格式):
1. 简单文本格式:
{
"type": "llm",
"config": {
"model_id": "uuid",
"prompt": "请分析:{{sys.message}}",
"temperature": 0.7,
"max_tokens": 1000
}
}
2. LangChain 消息格式(推荐):
{
"type": "llm",
"config": {
"model_id": "uuid",
"messages": [
{
"role": "system",
"content": "你是一个专业的 AI 助手。"
},
{
"role": "user",
"content": "{{sys.message}}"
}
],
"temperature": 0.7,
"max_tokens": 1000
}
}
支持的角色类型:
- system: 系统消息SystemMessage
- user/human: 用户消息HumanMessage
- ai/assistant: AI 消息AIMessage
"""
def _prepare_llm(self, state: WorkflowState) -> tuple[RedBearLLM, list | str]:
"""准备 LLM 实例(公共逻辑)
Args:
state: 工作流状态
Returns:
(llm, messages_or_prompt): LLM 实例和消息列表或 prompt 字符串
"""
# 1. 处理消息格式(优先使用 messages
messages_config = self.config.get("messages")
if messages_config:
# 使用 LangChain 消息格式
messages = []
for msg_config in messages_config:
role = msg_config.get("role", "user").lower()
content_template = msg_config.get("content", "")
content = self._render_template(content_template, state)
# 根据角色创建对应的消息对象
if role == "system":
messages.append(SystemMessage(content=content))
elif role in ["user", "human"]:
messages.append(HumanMessage(content=content))
elif role in ["ai", "assistant"]:
messages.append(AIMessage(content=content))
else:
logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append(HumanMessage(content=content))
prompt_or_messages = messages
else:
# 使用简单的 prompt 格式(向后兼容)
prompt_template = self.config.get("prompt", "")
prompt_or_messages = self._render_template(prompt_template, state)
# 2. 获取模型配置
model_id = self.config.get("model_id")
if not model_id:
raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置")
# 3. 在 with 块内完成所有数据库操作和数据提取
with get_db_context() as db:
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
if not config:
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
if not config.api_keys or len(config.api_keys) == 0:
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
# 在 Session 关闭前提取所有需要的数据
api_config = config.api_keys[0]
model_name = api_config.model_name
provider = api_config.provider
api_key = api_config.api_key
api_base = api_config.api_base
model_type = config.type
# 4. 创建 LLM 实例(使用已提取的数据)
llm = RedBearLLM(
RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base
),
type=model_type
)
return llm, prompt_or_messages
async def execute(self, state: WorkflowState) -> AIMessage:
"""非流式执行 LLM 调用
Args:
state: 工作流状态
Returns:
LLM 响应消息
"""
llm, prompt_or_messages = self._prepare_llm(state)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
# 调用 LLM支持字符串或消息列表
response = await llm.ainvoke(prompt_or_messages)
# 提取内容
if hasattr(response, 'content'):
content = response.content
else:
content = str(response)
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
# 返回 AIMessage包含响应元数据
return response if isinstance(response, AIMessage) else AIMessage(content=content)
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
"""提取输入数据(用于记录)"""
_, prompt_or_messages = self._prepare_llm(state)
return {
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
"messages": [
{"role": msg.__class__.__name__.replace("Message", "").lower(), "content": msg.content}
for msg in prompt_or_messages
] if isinstance(prompt_or_messages, list) else None,
"config": {
"model_id": self.config.get("model_id"),
"temperature": self.config.get("temperature"),
"max_tokens": self.config.get("max_tokens")
}
}
def _extract_output(self, business_result: Any) -> str:
"""从 AIMessage 中提取文本内容"""
if isinstance(business_result, AIMessage):
return business_result.content
return str(business_result)
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
"""从 AIMessage 中提取 token 使用情况"""
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'):
usage = business_result.response_metadata.get('token_usage')
if usage:
return {
"prompt_tokens": usage.get('prompt_tokens', 0),
"completion_tokens": usage.get('completion_tokens', 0),
"total_tokens": usage.get('total_tokens', 0)
}
return None
async def execute_stream(self, state: WorkflowState):
"""流式执行 LLM 调用
Args:
state: 工作流状态
Yields:
文本片段chunk或完成标记
"""
llm, prompt_or_messages = self._prepare_llm(state)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
# 累积完整响应
full_response = ""
last_chunk = None
# 调用 LLM流式支持字符串或消息列表
async for chunk in llm.astream(prompt_or_messages):
# 提取内容
if hasattr(chunk, 'content'):
content = chunk.content
else:
content = str(chunk)
full_response += content
last_chunk = chunk
# 流式返回每个文本片段
yield content
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}")
# 构建完整的 AIMessage包含元数据
if isinstance(last_chunk, AIMessage):
final_message = AIMessage(
content=full_response,
response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {}
)
else:
final_message = AIMessage(content=full_response)
# yield 完成标记
yield {"__final__": True, "result": final_message}

View File

@@ -0,0 +1,93 @@
"""
节点工厂
根据节点类型创建相应的节点实例。
"""
import logging
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.end import EndNode
logger = logging.getLogger(__name__)
class NodeFactory:
"""节点工厂
使用工厂模式创建节点实例,便于扩展和维护。
"""
# 节点类型注册表
_node_types: dict[str, type[BaseNode]] = {
NodeType.START: StartNode,
NodeType.END: EndNode,
NodeType.LLM: LLMNode,
NodeType.AGENT: AgentNode,
NodeType.TRANSFORM: TransformNode,
}
@classmethod
def register_node_type(cls, node_type: str, node_class: type[BaseNode]):
"""注册新的节点类型
Args:
node_type: 节点类型名称
node_class: 节点类
Examples:
>>> class CustomNode(BaseNode):
... async def execute(self, state):
... return {"node_outputs": {self.node_id: {"output": "custom"}}}
>>> NodeFactory.register_node_type("custom", CustomNode)
"""
cls._node_types[node_type] = node_class
logger.info(f"注册节点类型: {node_type} -> {node_class.__name__}")
@classmethod
def create_node(
cls,
node_config: dict[str, Any],
workflow_config: dict[str, Any]
) -> BaseNode | None:
"""创建节点实例
Args:
node_config: 节点配置
workflow_config: 工作流配置
Returns:
节点实例或 None对于不支持的节点类型
Raises:
ValueError: 不支持的节点类型
"""
node_type = node_config.get("type")
# 跳过条件节点(由 LangGraph 处理)
if node_type == "condition":
return None
# 获取节点类
node_class = cls._node_types.get(node_type)
if not node_class:
raise ValueError(f"不支持的节点类型: {node_type}")
# 创建节点实例
logger.debug(f"创建节点: {node_config.get('id')} (type={node_type})")
return node_class(node_config, workflow_config)
@classmethod
def get_supported_types(cls) -> list[str]:
"""获取支持的节点类型列表
Returns:
节点类型列表
"""
return list(cls._node_types.keys())

View File

@@ -0,0 +1,6 @@
"""Start 节点"""
from app.core.workflow.nodes.start.node import StartNode
from app.core.workflow.nodes.start.config import StartNodeConfig
__all__ = ["StartNode", "StartNodeConfig"]

View File

@@ -0,0 +1,87 @@
"""Start 节点配置"""
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
class StartNodeConfig(BaseNodeConfig):
"""Start 节点配置
Start 节点的作用:
1. 标记工作流的起点
2. 定义自定义输入变量(会作为节点输出,通过 start_node_id.variable_name 访问)
3. 输出系统变量和会话变量
"""
# 自定义输入变量定义
variables: list[VariableDefinition] = Field(
default_factory=list,
description="自定义输入变量列表,这些变量会作为 Start 节点的输出"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="message",
type=VariableType.STRING,
description="用户输入的消息"
),
VariableDefinition(
name="conversation_vars",
type=VariableType.OBJECT,
description="会话级变量"
),
VariableDefinition(
name="execution_id",
type=VariableType.STRING,
description="执行 ID"
),
VariableDefinition(
name="conversation_id",
type=VariableType.STRING,
description="会话 ID"
),
VariableDefinition(
name="workspace_id",
type=VariableType.STRING,
description="工作空间 ID"
),
VariableDefinition(
name="user_id",
type=VariableType.STRING,
description="用户 ID"
)
],
description="输出变量定义(自动生成,通常不需要修改)"
)
class Config:
json_schema_extra = {
"examples": [
{
"description": "工作流开始节点",
"variables": []
},
{
"description": "带自定义变量的开始节点",
"variables": [
{
"name": "language",
"type": "string",
"required": False,
"default": "zh-CN",
"description": "语言设置"
},
{
"name": "max_length",
"type": "number",
"required": False,
"default": 1000,
"description": "最大长度"
}
]
}
]
}

View File

@@ -0,0 +1,136 @@
"""
Start 节点实现
工作流的起始节点,定义输入变量并输出系统参数。
"""
import logging
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.start.config import StartNodeConfig
logger = logging.getLogger(__name__)
class StartNode(BaseNode):
"""Start 节点
工作流的起始节点,负责:
1. 定义工作流的输入变量(通过配置)
2. 输出系统变量sys.*
3. 输出会话变量conv.*
注意:变量的验证和默认值处理由 Executor 在初始化时完成。
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
"""初始化 Start 节点
Args:
node_config: 节点配置
workflow_config: 工作流配置
"""
super().__init__(node_config, workflow_config)
# 解析并验证配置
self.typed_config = StartNodeConfig(**self.config)
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""执行 start 节点业务逻辑
Start 节点输出系统变量、会话变量和自定义变量。
Args:
state: 工作流状态
Returns:
包含系统参数、会话变量和自定义变量的字典
"""
logger.info(f"节点 {self.node_id} (Start) 开始执行")
# 创建变量池实例(在方法内复用)
pool = self.get_variable_pool(state)
# 处理自定义变量(传入 pool 避免重复创建)
custom_vars = self._process_custom_variables(pool)
# 返回业务数据(包含自定义变量)
result = {
"message": pool.get("sys.message"),
"execution_id": pool.get("sys.execution_id"),
"conversation_id": pool.get("sys.conversation_id"),
"workspace_id": pool.get("sys.workspace_id"),
"user_id": pool.get("sys.user_id"),
**custom_vars # 自定义变量作为节点输出的一部分
}
logger.info(
f"节点 {self.node_id} (Start) 执行完成,"
f"输出了 {len(custom_vars)} 个自定义变量"
)
return result
def _process_custom_variables(self, pool) -> dict[str, Any]:
"""处理自定义变量
从输入数据中提取自定义变量,应用默认值和验证。
Args:
pool: 变量池实例
Returns:
处理后的自定义变量字典
Raises:
ValueError: 缺少必需变量
"""
# 获取输入数据中的自定义变量
input_variables = pool.get("sys.input_variables", default={})
processed = {}
# 遍历配置的变量定义
for var_def in self.typed_config.variables:
var_name = var_def.name
# 检查变量是否存在
if var_name in input_variables:
# 使用用户提供的值
processed[var_name] = input_variables[var_name]
elif var_def.required:
# 必需变量缺失
raise ValueError(
f"缺少必需的输入变量: {var_name}"
+ (f" ({var_def.description})" if var_def.description else "")
)
elif var_def.default is not None:
# 使用默认值
processed[var_name] = var_def.default
logger.debug(
f"变量 '{var_name}' 使用默认值: {var_def.default}"
)
return processed
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
"""提取输入数据(用于记录)
Args:
state: 工作流状态
Returns:
输入数据字典
"""
pool = self.get_variable_pool(state)
return {
"execution_id": pool.get("sys.execution_id"),
"conversation_id": pool.get("sys.conversation_id"),
"message": pool.get("sys.message"),
"conversation_vars": pool.get_all_conversation_vars()
}

View File

@@ -0,0 +1,6 @@
"""Transform 节点"""
from app.core.workflow.nodes.transform.node import TransformNode
from app.core.workflow.nodes.transform.config import TransformNodeConfig
__all__ = ["TransformNode", "TransformNodeConfig"]

View File

@@ -0,0 +1,80 @@
"""Transform 节点配置"""
from typing import Literal
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
class TransformNodeConfig(BaseNodeConfig):
"""Transform 节点配置
用于数据转换和处理。
"""
transform_type: Literal["template", "code", "json"] = Field(
default="template",
description="转换类型template(模板), code(代码), json(JSON处理)"
)
# 模板模式
template: str | None = Field(
default=None,
description="转换模板,支持变量引用"
)
# 代码模式
code: str | None = Field(
default=None,
description="Python 代码,用于数据转换"
)
# JSON 模式
json_path: str | None = Field(
default=None,
description="JSON 路径表达式"
)
# 输入变量
inputs: dict[str, str] | None = Field(
default=None,
description="输入变量映射key 为变量名value 为变量选择器"
)
# 输出变量
output_key: str = Field(
default="result",
description="输出变量的键名"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="result",
type=VariableType.STRING,
description="转换后的结果"
)
],
description="输出变量定义(根据 output_key 动态生成)"
)
class Config:
json_schema_extra = {
"examples": [
{
"transform_type": "template",
"template": "用户问题:{{ sys.message }}\n回答:{{ llm_qa.output }}",
"output_key": "formatted_result"
},
{
"transform_type": "code",
"code": "result = input_text.upper()",
"inputs": {
"input_text": "{{ sys.message }}"
},
"output_key": "uppercase_text"
}
]
}

View File

@@ -0,0 +1,60 @@
"""
Transform 节点实现
数据转换节点,用于处理和转换数据。
"""
import logging
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
logger = logging.getLogger(__name__)
class TransformNode(BaseNode):
"""数据转换节点
配置示例:
{
"type": "transform",
"config": {
"mapping": {
"output_field": "{{node.previous.output}}",
"processed": "{{var.input | upper}}"
}
}
}
"""
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""执行数据转换
Args:
state: 工作流状态
Returns:
状态更新字典
"""
logger.info(f"节点 {self.node_id} 开始执行数据转换")
# 获取映射配置
mapping = self.config.get("mapping", {})
# 执行数据转换
transformed_data = {}
for target_key, source_template in mapping.items():
# 渲染模板获取值
value = self._render_template(str(source_template), state)
transformed_data[target_key] = value
logger.info(f"节点 {self.node_id} 数据转换完成,输出字段: {list(transformed_data.keys())}")
return {
"node_outputs": {
self.node_id: {
"output": transformed_data,
"status": "completed"
}
}
}

View File

@@ -0,0 +1,170 @@
"""
工作流模板加载器
从文件系统加载预定义的工作流模板
"""
import os
import yaml
from pathlib import Path
from typing import Optional
class TemplateLoader:
"""工作流模板加载器"""
def __init__(self, templates_dir: str = "app/templates/workflows"):
"""初始化模板加载器
Args:
templates_dir: 模板目录路径
"""
self.templates_dir = Path(templates_dir)
if not self.templates_dir.exists():
raise ValueError(f"模板目录不存在: {templates_dir}")
def list_templates(self) -> list[dict]:
"""列出所有可用的模板
Returns:
模板列表,每个模板包含 id, name, description 等信息
"""
templates = []
# 遍历模板目录
for template_dir in self.templates_dir.iterdir():
if not template_dir.is_dir():
continue
# 检查是否有 template.yml 文件
template_file = template_dir / "template.yml"
if not template_file.exists():
continue
try:
# 读取模板配置
with open(template_file, 'r', encoding='utf-8') as f:
template_data = yaml.safe_load(f)
# 提取模板信息
templates.append({
"id": template_dir.name,
"name": template_data.get("name", template_dir.name),
"description": template_data.get("description", ""),
"category": template_data.get("category", "general"),
"tags": template_data.get("tags", []),
"author": template_data.get("author", ""),
"version": template_data.get("version", "1.0.0")
})
except Exception as e:
print(f"加载模板 {template_dir.name} 失败: {e}")
continue
return templates
def load_template(self, template_id: str) -> Optional[dict]:
"""加载指定的模板
Args:
template_id: 模板 ID目录名
Returns:
模板配置字典,如果模板不存在则返回 None
"""
template_dir = self.templates_dir / template_id
template_file = template_dir / "template.yml"
if not template_file.exists():
return None
try:
with open(template_file, 'r', encoding='utf-8') as f:
template_data = yaml.safe_load(f)
# 返回工作流配置部分
return {
"name": template_data.get("name", template_id),
"description": template_data.get("description", ""),
"nodes": template_data.get("nodes", []),
"edges": template_data.get("edges", []),
"variables": template_data.get("variables", []),
"execution_config": template_data.get("execution_config", {}),
"triggers": template_data.get("triggers", [])
}
except Exception as e:
print(f"加载模板 {template_id} 失败: {e}")
return None
def get_template_readme(self, template_id: str) -> Optional[str]:
"""获取模板的 README 文档
Args:
template_id: 模板 ID
Returns:
README 内容,如果不存在则返回 None
"""
template_dir = self.templates_dir / template_id
readme_file = template_dir / "README.md"
if not readme_file.exists():
return None
try:
with open(readme_file, 'r', encoding='utf-8') as f:
return f.read()
except Exception as e:
print(f"读取模板 {template_id} 的 README 失败: {e}")
return None
# 全局模板加载器实例
_template_loader: Optional[TemplateLoader] = None
def get_template_loader() -> TemplateLoader:
"""获取全局模板加载器实例
Returns:
TemplateLoader 实例
"""
global _template_loader
if _template_loader is None:
_template_loader = TemplateLoader()
return _template_loader
def list_workflow_templates() -> list[dict]:
"""列出所有工作流模板
Returns:
模板列表
"""
loader = get_template_loader()
return loader.list_templates()
def load_workflow_template(template_id: str) -> Optional[dict]:
"""加载工作流模板
Args:
template_id: 模板 ID
Returns:
模板配置,如果不存在则返回 None
"""
loader = get_template_loader()
return loader.load_template(template_id)
def get_workflow_template_readme(template_id: str) -> Optional[str]:
"""获取工作流模板的 README
Args:
template_id: 模板 ID
Returns:
README 内容,如果不存在则返回 None
"""
loader = get_template_loader()
return loader.get_template_readme(template_id)

View File

@@ -0,0 +1,170 @@
"""
模板渲染器
使用 Jinja2 提供安全的模板渲染功能,支持变量引用和表达式。
"""
import logging
from typing import Any
from jinja2 import Template, TemplateSyntaxError, UndefinedError, Environment, StrictUndefined
logger = logging.getLogger(__name__)
class TemplateRenderer:
"""模板渲染器"""
def __init__(self, strict: bool = True):
"""初始化渲染器
Args:
strict: 是否使用严格模式(未定义变量会抛出异常)
"""
self.env = Environment(
undefined=StrictUndefined if strict else None,
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
)
def render(
self,
template: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> str:
"""渲染模板
Args:
template: 模板字符串
variables: 用户定义的变量
node_outputs: 节点输出结果
system_vars: 系统变量
Returns:
渲染后的字符串
Raises:
ValueError: 模板语法错误或变量未定义
Examples:
>>> renderer = TemplateRenderer()
>>> renderer.render(
... "Hello {{var.name}}!",
... {"name": "World"},
... {},
... {}
... )
'Hello World!'
>>> renderer.render(
... "分析结果: {{node.analyze.output}}",
... {},
... {"analyze": {"output": "正面情绪"}},
... {}
... )
'分析结果: 正面情绪'
"""
# 构建命名空间上下文
context = {
"var": variables, # 用户变量:{{var.user_input}}
"node": node_outputs, # 节点输出:{{node.node_1.output}}
"sys": system_vars or {}, # 系统变量:{{sys.execution_id}}
}
# 支持直接通过节点ID访问节点输出{{llm_qa.output}}
# 将所有节点输出添加到顶层上下文
context.update(node_outputs)
# 为了向后兼容,也支持直接访问用户变量
context.update(variables)
context["nodes"] = node_outputs # 旧语法兼容
try:
tmpl = self.env.from_string(template)
return tmpl.render(**context)
except TemplateSyntaxError as e:
logger.error(f"模板语法错误: {template}, 错误: {e}")
raise ValueError(f"模板语法错误: {e}")
except UndefinedError as e:
logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}")
raise ValueError(f"未定义的变量: {e}")
except Exception as e:
logger.error(f"模板渲染异常: {template}, 错误: {e}")
raise ValueError(f"模板渲染失败: {e}")
def validate(self, template: str) -> list[str]:
"""验证模板语法
Args:
template: 模板字符串
Returns:
错误列表,如果为空则验证通过
Examples:
>>> renderer = TemplateRenderer()
>>> renderer.validate("Hello {{var.name}}!")
[]
>>> renderer.validate("Hello {{var.name") # 缺少结束标记
['模板语法错误: ...']
"""
errors = []
try:
self.env.from_string(template)
except TemplateSyntaxError as e:
errors.append(f"模板语法错误: {e}")
except Exception as e:
errors.append(f"模板验证失败: {e}")
return errors
# 全局渲染器实例(严格模式)
_default_renderer = TemplateRenderer(strict=True)
def render_template(
template: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> str:
"""渲染模板(便捷函数)
Args:
template: 模板字符串
variables: 用户变量
node_outputs: 节点输出
system_vars: 系统变量
Returns:
渲染后的字符串
Examples:
>>> render_template(
... "请分析: {{var.text}}",
... {"text": "这是一段文本"},
... {},
... {}
... )
'请分析: 这是一段文本'
"""
return _default_renderer.render(template, variables, node_outputs, system_vars)
def validate_template(template: str) -> list[str]:
"""验证模板语法(便捷函数)
Args:
template: 模板字符串
Returns:
错误列表
"""
return _default_renderer.validate(template)

View File

@@ -0,0 +1,277 @@
"""
工作流配置验证器
验证工作流配置的有效性,确保配置符合规范。
"""
import logging
from typing import Any, Union
logger = logging.getLogger(__name__)
class WorkflowValidator:
"""工作流配置验证器"""
@staticmethod
def validate(workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
"""验证工作流配置
Args:
workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型
Returns:
(is_valid, errors): 是否有效和错误列表
Examples:
>>> config = {
... "nodes": [
... {"id": "start", "type": "start"},
... {"id": "end", "type": "end"}
... ],
... "edges": [
... {"source": "start", "target": "end"}
... ]
... }
>>> is_valid, errors = WorkflowValidator.validate(config)
>>> is_valid
True
"""
errors = []
# 支持字典和 Pydantic 模型
if isinstance(workflow_config, dict):
nodes = workflow_config.get("nodes", [])
edges = workflow_config.get("edges", [])
variables = workflow_config.get("variables", [])
else:
# Pydantic 模型
nodes = getattr(workflow_config, "nodes", [])
edges = getattr(workflow_config, "edges", [])
variables = getattr(workflow_config, "variables", [])
# 1. 验证 start 节点(有且只有一个)
start_nodes = [n for n in nodes if n.get("type") == "start"]
if len(start_nodes) == 0:
errors.append("工作流必须有一个 start 节点")
elif len(start_nodes) > 1:
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)}")
# 2. 验证 end 节点(至少一个)
end_nodes = [n for n in nodes if n.get("type") == "end"]
if len(end_nodes) == 0:
errors.append("工作流必须至少有一个 end 节点")
# 3. 验证节点 ID 唯一性
node_ids = [n.get("id") for n in nodes]
if len(node_ids) != len(set(node_ids)):
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")
# 4. 验证节点必须有 id 和 type
for i, node in enumerate(nodes):
if not node.get("id"):
errors.append(f"节点 #{i} 缺少 id 字段")
if not node.get("type"):
errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段")
# 5. 验证边的有效性
node_id_set = set(node_ids)
for i, edge in enumerate(edges):
source = edge.get("source")
target = edge.get("target")
if not source:
errors.append(f"边 #{i} 缺少 source 字段")
elif source not in node_id_set:
errors.append(f"边 #{i} 的 source 节点不存在: {source}")
if not target:
errors.append(f"边 #{i} 缺少 target 字段")
elif target not in node_id_set:
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
# 6. 验证所有节点可达(从 start 节点出发)
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
reachable = WorkflowValidator._get_reachable_nodes(
start_nodes[0]["id"],
edges
)
unreachable = node_id_set - reachable
if unreachable:
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
# 7. 检测循环依赖(非 loop 节点)
if not errors: # 只有在前面验证通过时才检查循环
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
if has_cycle:
errors.append(
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
)
# 8. 验证变量名
from app.core.workflow.expression_evaluator import ExpressionEvaluator
var_errors = ExpressionEvaluator.validate_variable_names(variables)
errors.extend(var_errors)
return len(errors) == 0, errors
@staticmethod
def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
"""获取从 start 节点可达的所有节点
Args:
start_id: 起始节点 ID
edges: 边列表
Returns:
可达节点 ID 集合
"""
reachable = {start_id}
queue = [start_id]
while queue:
current = queue.pop(0)
for edge in edges:
if edge.get("source") == current:
target = edge.get("target")
if target and target not in reachable:
reachable.add(target)
queue.append(target)
return reachable
@staticmethod
def _has_cycle(nodes: list[dict], edges: list[dict]) -> tuple[bool, list[str]]:
"""检测是否存在循环依赖DFS
Args:
nodes: 节点列表
edges: 边列表
Returns:
(has_cycle, cycle_path): 是否有循环和循环路径
"""
# 排除 loop 类型的节点
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
# 构建邻接表(排除 loop 节点的边和错误边)
graph: dict[str, list[str]] = {}
for edge in edges:
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
# 跳过错误边
if edge_type == "error":
continue
# 如果涉及 loop 节点,跳过
if source in loop_nodes or target in loop_nodes:
continue
if source and target:
if source not in graph:
graph[source] = []
graph[source].append(target)
# DFS 检测环
visited = set()
rec_stack = set()
path = []
cycle_path = []
def dfs(node: str) -> bool:
"""DFS 检测环,返回是否找到环"""
visited.add(node)
rec_stack.add(node)
path.append(node)
for neighbor in graph.get(node, []):
if neighbor not in visited:
if dfs(neighbor):
return True
elif neighbor in rec_stack:
# 找到环,记录环路径
cycle_start = path.index(neighbor)
cycle_path.extend([*path[cycle_start:], neighbor])
return True
rec_stack.remove(node)
path.pop()
return False
# 检查所有节点
for node_id in graph:
if node_id not in visited:
if dfs(node_id):
return True, cycle_path
return False, []
@staticmethod
def validate_for_publish(workflow_config: dict[str, Any]) -> tuple[bool, list[str]]:
"""验证工作流配置是否可以发布(更严格的验证)
Args:
workflow_config: 工作流配置
Returns:
(is_valid, errors): 是否有效和错误列表
"""
# 先执行基础验证
is_valid, errors = WorkflowValidator.validate(workflow_config)
if not is_valid:
return False, errors
# 额外的发布验证
nodes = workflow_config.get("nodes", [])
# 1. 验证所有节点都有名称
for node in nodes:
if node.get("type") not in ["start", "end"] and not node.get("name"):
errors.append(
f"节点 {node.get('id')} 缺少名称(发布时必须提供)"
)
# 2. 验证所有非 start/end 节点都有配置
for node in nodes:
node_type = node.get("type")
if node_type not in ["start", "end"]:
config = node.get("config")
if not config or not isinstance(config, dict):
errors.append(
f"节点 {node.get('id')} 缺少配置(发布时必须提供)"
)
# 3. 验证必填变量
variables = workflow_config.get("variables", [])
required_vars = [v for v in variables if v.get("required")]
if required_vars:
# 这里只是提示,实际执行时会检查
logger.info(
f"工作流包含 {len(required_vars)} 个必填变量: "
f"{[v.get('name') for v in required_vars]}"
)
return len(errors) == 0, errors
def validate_workflow_config(
workflow_config: dict[str, Any],
for_publish: bool = False
) -> tuple[bool, list[str]]:
"""验证工作流配置(便捷函数)
Args:
workflow_config: 工作流配置
for_publish: 是否为发布验证(更严格)
Returns:
(is_valid, errors): 是否有效和错误列表
"""
if for_publish:
return WorkflowValidator.validate_for_publish(workflow_config)
else:
return WorkflowValidator.validate(workflow_config)

View File

@@ -0,0 +1,293 @@
"""
变量池 (Variable Pool)
工作流执行的数据中心,管理所有变量的存储和访问。
变量类型:
1. 系统变量 (sys.*) - 系统内置变量execution_id, workspace_id, user_id, message 等)
2. 节点输出 (node_id.*) - 节点执行结果
3. 会话变量 (conv.*) - 会话级变量(跨多轮对话保持)
"""
import logging
from typing import Any
logger = logging.getLogger(__name__)
class VariableSelector:
"""变量选择器
用于引用变量的路径表示。
Examples:
>>> selector = VariableSelector(["sys", "message"])
>>> selector = VariableSelector(["node_A", "output"])
>>> selector = VariableSelector.from_string("sys.message")
"""
def __init__(self, path: list[str]):
"""初始化变量选择器
Args:
path: 变量路径,如 ["sys", "message"] 或 ["node_A", "output"]
"""
if not path or len(path) < 1:
raise ValueError("变量路径不能为空")
self.path = path
self.namespace = path[0] # sys, var, 或 node_id
self.key = path[1] if len(path) > 1 else None
@classmethod
def from_string(cls, selector_str: str) -> "VariableSelector":
"""从字符串创建选择器
Args:
selector_str: 选择器字符串,如 "sys.message""node_A.output"
Returns:
VariableSelector 实例
Examples:
>>> selector = VariableSelector.from_string("sys.message")
>>> selector = VariableSelector.from_string("llm_qa.output")
"""
path = selector_str.split(".")
return cls(path)
def __str__(self) -> str:
return ".".join(self.path)
def __repr__(self) -> str:
return f"VariableSelector({self.path})"
class VariablePool:
"""变量池
管理工作流执行过程中的所有变量。
变量命名空间:
- sys.*: 系统变量message, execution_id, workspace_id, user_id, conversation_id
- conv.*: 会话变量(跨多轮对话保持的变量)
- <node_id>.*: 节点输出
Examples:
>>> pool = VariablePool(state)
>>> pool.get(["sys", "message"])
"用户的问题"
>>> pool.get(["llm_qa", "output"])
"AI 的回答"
>>> pool.set(["conv", "user_name"], "张三")
"""
def __init__(self, state: dict[str, Any]):
"""初始化变量池
Args:
state: 工作流状态LangGraph State
"""
self.state = state
def get(self, selector: list[str] | str, default: Any = None) -> Any:
"""获取变量值
Args:
selector: 变量选择器,可以是列表或字符串
default: 默认值(变量不存在时返回)
Returns:
变量值
Examples:
>>> pool.get(["sys", "message"])
>>> pool.get("sys.message")
>>> pool.get(["llm_qa", "output"])
>>> pool.get("llm_qa.output")
Raises:
KeyError: 变量不存在且未提供默认值
"""
# 转换为 VariableSelector
if isinstance(selector, str):
selector = VariableSelector.from_string(selector).path
if not selector or len(selector) < 1:
raise ValueError("变量选择器不能为空")
namespace = selector[0]
try:
# 系统变量
if namespace == "sys":
key = selector[1] if len(selector) > 1 else None
if not key:
return self.state.get("variables", {}).get("sys", {})
return self.state.get("variables", {}).get("sys", {}).get(key, default)
# 会话变量
elif namespace == "conv":
key = selector[1] if len(selector) > 1 else None
if not key:
return self.state.get("variables", {}).get("conv", {})
return self.state.get("variables", {}).get("conv", {}).get(key, default)
# 节点输出(从 runtime_vars 读取)
else:
node_id = namespace
runtime_vars = self.state.get("runtime_vars", {})
if node_id not in runtime_vars:
if default is not None:
return default
raise KeyError(f"节点 '{node_id}' 的输出不存在")
node_var = runtime_vars[node_id]
# 如果只有节点 ID返回整个变量
if len(selector) == 1:
return node_var
# 获取特定字段
# 支持嵌套访问,如 node_id.field.subfield
result = node_var
for k in selector[1:]:
if isinstance(result, dict):
result = result.get(k)
if result is None:
if default is not None:
return default
raise KeyError(f"字段 '{'.'.join(selector)}' 不存在")
else:
if default is not None:
return default
raise KeyError(f"无法访问 '{'.'.join(selector)}'")
return result
except KeyError:
if default is not None:
return default
raise
def set(self, selector: list[str] | str, value: Any):
"""设置变量值
Args:
selector: 变量选择器
value: 变量值
Examples:
>>> pool.set(["conv", "user_name"], "张三")
>>> pool.set("conv.user_name", "张三")
Note:
- 只能设置会话变量 (conv.*)
- 系统变量和节点输出是只读的
"""
# 转换为 VariableSelector
if isinstance(selector, str):
selector = VariableSelector.from_string(selector).path
if not selector or len(selector) < 2:
raise ValueError("变量选择器必须包含命名空间和键名")
namespace = selector[0]
if namespace != "conv":
raise ValueError("只能设置会话变量 (conv.*)")
key = selector[1]
# 确保 variables 结构存在
if "variables" not in self.state:
self.state["variables"] = {"sys": {}, "conv": {}}
if "conv" not in self.state["variables"]:
self.state["variables"]["conv"] = {}
# 设置值
self.state["variables"]["conv"][key] = value
logger.debug(f"设置变量: {'.'.join(selector)} = {value}")
def has(self, selector: list[str] | str) -> bool:
"""检查变量是否存在
Args:
selector: 变量选择器
Returns:
变量是否存在
Examples:
>>> pool.has(["sys", "message"])
True
>>> pool.has("llm_qa.output")
False
"""
try:
self.get(selector)
return True
except KeyError:
return False
def get_all_system_vars(self) -> dict[str, Any]:
"""获取所有系统变量
Returns:
系统变量字典
"""
return self.state.get("variables", {}).get("sys", {})
def get_all_conversation_vars(self) -> dict[str, Any]:
"""获取所有会话变量
Returns:
会话变量字典
"""
return self.state.get("variables", {}).get("conv", {})
def get_all_node_outputs(self) -> dict[str, Any]:
"""获取所有节点输出(运行时变量)
Returns:
节点输出字典,键为节点 ID
"""
return self.state.get("runtime_vars", {})
def get_node_output(self, node_id: str) -> dict[str, Any] | None:
"""获取指定节点的输出(运行时变量)
Args:
node_id: 节点 ID
Returns:
节点输出或 None
"""
return self.state.get("runtime_vars", {}).get(node_id)
def to_dict(self) -> dict[str, Any]:
"""导出为字典
Returns:
包含所有变量的字典
"""
return {
"system": self.get_all_system_vars(),
"conversation": self.get_all_conversation_vars(),
"nodes": self.get_all_node_outputs() # 从 runtime_vars 读取
}
def __repr__(self) -> str:
sys_vars = self.get_all_system_vars()
conv_vars = self.get_all_conversation_vars()
runtime_vars = self.get_all_node_outputs()
return (
f"VariablePool(\n"
f" system_vars={len(sys_vars)},\n"
f" conversation_vars={len(conv_vars)},\n"
f" runtime_vars={len(runtime_vars)}\n"
f")"
)