[ADD] Merge code

This commit is contained in:
Mark
2025-12-15 19:50:21 +08:00
parent ea0a445d5b
commit 7bbef35b7d
54 changed files with 6956 additions and 652 deletions

View File

@@ -1,10 +1,12 @@
import asyncio
import time
import uuid
from functools import wraps
from typing import Optional, List
from datetime import datetime
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from app.core.api_key_utils import add_rate_limit_headers
@@ -22,21 +24,17 @@ logger = get_api_logger()
def require_api_key(
scopes: Optional[List[str]] = None,
resource_type: Optional[str] = None
scopes: Optional[List[str]] = None
):
"""
API Key 鉴权装饰器
Args:
scopes: 所需的权限范围列表["app:all",
"rag:search", "rag:upload", "rag:delete",
"memory:read", "memory:write", "memory:delete", "memory:search"]
resource_type: 所需的资源类型("Agent", "Cluster", "Workflow", "Knowledge", "Memory_Engine")
scopes: 所需的权限范围列表[“app”, "rag", "memory"]
Usage:
@router.get("/app/{resource_id}/chat")
@require_api_key(scopes=["app:all"], resource_type="Agent")
@require_api_key(scopes=["app"])
def chat_with_app(
resource_id: uuid.UUID,
api_key_auth: ApiKeyAuth = Depends(),
@@ -113,31 +111,25 @@ def require_api_key(
context={"required_scopes": scopes, "missing_scopes": missing_scopes}
)
if resource_type:
resource_id = kwargs.get("resource_id")
if resource_id and not ApiKeyAuthService.check_resource(
api_key_obj,
resource_type,
resource_id
):
logger.warning("API Key 资源访问被拒绝", extra={
"api_key_id": str(api_key_obj.id),
"required_resource_type": resource_type,
resource_id = kwargs.get("resource_id")
if resource_id and not ApiKeyAuthService.check_resource(
api_key_obj,
resource_id
):
logger.warning("API Key 资源访问被拒绝", extra={
"api_key_id": str(api_key_obj.id),
"required_resource_id": str(resource_id),
"bound_resource_id": str(api_key_obj.resource_id) if api_key_obj.resource_id else None,
"endpoint": str(request.url)
})
return BusinessException(
"API Key 未授权访问该资源",
BizCode.API_KEY_INVALID_RESOURCE,
context={
"required_resource_id": str(resource_id),
"bound_resource_type": api_key_obj.resource_type,
"bound_resource_id": str(api_key_obj.resource_id) if api_key_obj.resource_id else None,
"endpoint": str(request.url)
})
return BusinessException(
"API Key 未授权访问该资源",
BizCode.API_KEY_INVALID_RESOURCE,
context={
"required_resource_type": resource_type,
"required_resource_id": str(resource_id),
"bound_resource_type": api_key_obj.resource_type,
"bound_resource_id": str(api_key_obj.resource_id) if api_key_obj.resource_id else None
}
)
"bound_resource_id": str(api_key_obj.resource_id)
}
)
kwargs["api_key_auth"] = ApiKeyAuth(
api_key_id=api_key_obj.id,
@@ -145,14 +137,17 @@ def require_api_key(
type=api_key_obj.type,
scopes=api_key_obj.scopes,
resource_id=api_key_obj.resource_id,
resource_type=api_key_obj.resource_type
)
start_time = time.perf_counter()
response = await func(*args, **kwargs)
end_time = time.perf_counter()
response_time = (end_time - start_time) * 1000
if not isinstance(response, Response):
response = JSONResponse(content=response)
response = add_rate_limit_headers(response, rate_headers)
asyncio.create_task(log_api_key_usage(
db, api_key_obj.id, request, response
db, api_key_obj.id, request, response, response_time
))
return response
@@ -204,7 +199,8 @@ async def log_api_key_usage(
db: Session,
api_key_id: uuid.UUID,
request: Request,
response: Response
response: Response,
response_time: float
):
"""记录 API Key 使用日志"""
try:
@@ -216,8 +212,8 @@ async def log_api_key_usage(
"ip_address": request.client.host if request.client else None,
"user_agent": request.headers.get("User-Agent"),
"status_code": response.status_code if hasattr(response, "status_code") else None,
"response_time": None, # 需要在 middleware 中计算
"tokens_used": None, # 需要从响应中提取
"response_time": round(response_time),
"tokens_used": None,
"created_at": datetime.now()
}

View File

@@ -1,33 +1,14 @@
"""API Key 工具函数"""
import secrets
import hashlib
from typing import Optional
from typing import Optional, Union
from datetime import datetime
from app.schemas.api_key_schema import ApiKeyType
from fastapi import Response
from fastapi.responses import JSONResponse
class ResourceType:
"""资源类型常量"""
AGENT = "Agent"
CLUSTER = "Cluster"
WORKFLOW = "Workflow"
KNOWLEDGE = "Knowledge"
MEMORY_ENGINE = "Memory_Engine"
@classmethod
def get_all_types(cls) -> list[str]:
"""获取所有支持的资源类型"""
return [cls.AGENT, cls.CLUSTER, cls.WORKFLOW, cls.KNOWLEDGE, cls.MEMORY_ENGINE]
@classmethod
def is_valid_type(cls, resource_type: str) -> bool:
"""验证资源类型是否有效"""
return resource_type in cls.get_all_types()
def generate_api_key(key_type: ApiKeyType) -> tuple[str, str, str]:
def generate_api_key(key_type: ApiKeyType) -> str:
"""
生成 API Key
@@ -39,102 +20,17 @@ def generate_api_key(key_type: ApiKeyType) -> tuple[str, str, str]:
"""
# 前缀映射
prefix_map = {
ApiKeyType.APP: "sk-app-",
ApiKeyType.RAG: "sk-rag-",
ApiKeyType.MEMORY: "sk-mem-",
ApiKeyType.AGENT: "sk-agent-",
ApiKeyType.CLUSTER: "sk-cluster-",
ApiKeyType.WORKFLOW: "sk-workflow-",
ApiKeyType.SERVICE: "sk-service-"
}
prefix = prefix_map[key_type]
random_string = secrets.token_urlsafe(32)[:32] # 32 字符
api_key = f"{prefix}{random_string}"
# 生成哈希值存储
key_hash = hash_api_key(api_key)
return api_key, key_hash, prefix
def hash_api_key(api_key: str) -> str:
"""对 API Key 进行哈希
Args:
api_key: API Key 明文
Returns:
str: 哈希值
"""
return hashlib.sha256(api_key.encode()).hexdigest()
def verify_api_key(api_key: str, key_hash: str) -> bool:
"""
验证 API Key
Args:
api_key: API Key 明文
key_hash: 存储的哈希值
Returns:
bool: 是否匹配
"""
computed_hash = hash_api_key(api_key)
return secrets.compare_digest(computed_hash, key_hash)
def validate_resource_binding(
resource_type: Optional[str],
resource_id: Optional[str]
) -> tuple[bool, str]:
"""
验证资源绑定的有效性
Args:
resource_type: 资源类型
resource_id: 资源ID
Returns:
tuple: (是否有效, 错误信息)
"""
# 如果都为空,表示不绑定资源,这是有效的
if not resource_type and not resource_id:
return True, ""
# 如果只有一个为空,这是无效的
if not resource_type or not resource_id:
return False, "resource_type 和 resource_id 必须同时提供或同时为空"
# 验证资源类型是否支持
if not ResourceType.is_valid_type(resource_type):
valid_types = ", ".join(ResourceType.get_all_types())
return False, f"不支持的资源类型 '{resource_type}',支持的类型:{valid_types}"
return True, ""
def get_resource_scope_mapping() -> dict[str, list[str]]:
"""
获取资源类型与权限范围的映射关系
Returns:
dict: 资源类型到推荐权限范围的映射
"""
return {
ResourceType.AGENT: [
"app:all"
],
ResourceType.CLUSTER: [
"app:all"
],
ResourceType.WORKFLOW: [
"app:all"
],
ResourceType.KNOWLEDGE: [
"rag:search", "rag:upload", "rag:delete"
],
ResourceType.MEMORY_ENGINE: [
"memory:read", "memory:write", "memory:delete", "memory:search"
]
}
return api_key
def add_rate_limit_headers(response, headers: dict):
@@ -151,3 +47,21 @@ def add_rate_limit_headers(response, headers: dict):
return response
def timestamp_to_datetime(timestamp: Optional[Union[int, float]]) -> Optional[datetime]:
"""将时间戳转换为datetime对象"""
if timestamp is None:
return None
# 处理毫秒级时间戳
if timestamp > 1e10:
timestamp = timestamp / 1000
return datetime.fromtimestamp(timestamp)
def datetime_to_timestamp(dt: Optional[datetime]) -> Optional[int]:
"""将datetime对象转换为时间戳毫秒"""
if dt is None:
return None
return int(dt.timestamp() * 1000)

View File

@@ -59,6 +59,7 @@ class BizCode(IntEnum):
EMBED_NOT_ALLOWED = 6009
PERMISSION_DENIED = 6010
INVALID_CONVERSATION = 6011
CONFIG_MISSING = 6012
# 模型7xxx
MODEL_CONFIG_INVALID = 7001
@@ -96,7 +97,7 @@ HTTP_MAPPING = {
BizCode.TOKEN_INVALID: 401,
BizCode.TOKEN_EXPIRED: 401,
BizCode.TOKEN_BLACKLISTED: 401,
BizCode.FORBIDDEN: 403,
BizCode.FORBIDDEN: 403,
BizCode.TENANT_NOT_FOUND: 404,
BizCode.WORKSPACE_NO_ACCESS: 403,
BizCode.NOT_FOUND: 404,
@@ -151,4 +152,4 @@ HTTP_MAPPING = {
BizCode.DB_ERROR: 500,
BizCode.SERVICE_UNAVAILABLE: 503,
BizCode.RATE_LIMITED: 429,
}
}

View File

@@ -0,0 +1,436 @@
"""
工作流执行器
基于 LangGraph 的工作流执行引擎。
"""
import logging
import uuid
import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langgraph.graph import StateGraph, START, END
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.expression_evaluator import evaluate_condition
from app.models.workflow_model import WorkflowExecution, WorkflowNodeExecution
from app.db import get_db
logger = logging.getLogger(__name__)
class WorkflowExecutor:
"""工作流执行器
负责将工作流配置转换为 LangGraph 并执行。
"""
def __init__(
self,
workflow_config: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
):
"""初始化执行器
Args:
workflow_config: 工作流配置
execution_id: 执行 ID
workspace_id: 工作空间 ID
user_id: 用户 ID
"""
self.workflow_config = workflow_config
self.execution_id = execution_id
self.workspace_id = workspace_id
self.user_id = user_id
self.nodes = workflow_config.get("nodes", [])
self.edges = workflow_config.get("edges", [])
self.execution_config = workflow_config.get("execution_config", {})
def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState:
"""准备初始状态(注入系统变量和会话变量)
变量命名空间:
- sys.xxx - 系统变量execution_id, workspace_id, user_id, message, input_variables 等)
- conv.xxx - 会话变量(跨多轮对话保持)
- node_id.xxx - 节点输出(执行时动态生成)
Args:
input_data: 输入数据
Returns:
初始化的工作流状态
"""
user_message = input_data.get("message") or ""
conversation_vars = input_data.get("conversation_vars") or {}
input_variables = input_data.get("variables") or {} # Start 节点的自定义变量
# 构建分层的变量结构
variables = {
"sys": {
"message": user_message, # 用户消息
"conversation_id": input_data.get("conversation_id"), # 会话 ID
"execution_id": self.execution_id, # 执行 ID
"workspace_id": self.workspace_id, # 工作空间 ID
"user_id": self.user_id, # 用户 ID
"input_variables": input_variables, # 自定义输入变量(给 Start 节点使用)
},
"conv": conversation_vars # 会话级变量(跨多轮对话保持)
}
return {
"messages": [HumanMessage(content=user_message)],
"variables": variables,
"node_outputs": {},
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)
"execution_id": self.execution_id,
"workspace_id": self.workspace_id,
"user_id": self.user_id,
"error": None,
"error_node": None
}
def build_graph(self) -> StateGraph:
"""构建 LangGraph
Returns:
编译后的状态图
"""
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
# 1. 创建状态图
workflow = StateGraph(WorkflowState)
# 2. 添加所有节点(包括 start 和 end
start_node_id = None
end_node_ids = []
for node in self.nodes:
node_type = node.get("type")
node_id = node.get("id")
# 记录 start 和 end 节点 ID
if node_type == "start":
start_node_id = node_id
elif node_type == "end":
end_node_ids.append(node_id)
# 创建节点实例(现在 start 和 end 也会被创建)
node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_instance:
# 包装节点的 run 方法
# 使用函数工厂避免闭包问题
def make_node_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
return node_func
workflow.add_node(node_id, make_node_func(node_instance))
logger.debug(f"添加节点: {node_id} (type={node_type})")
# 3. 添加边
# 从 START 连接到 start 节点
if start_node_id:
workflow.add_edge(START, start_node_id)
logger.debug(f"添加边: START -> {start_node_id}")
for edge in self.edges:
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
condition = edge.get("condition")
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start
if source == start_node_id:
# 但要连接 start 到下一个节点
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
continue
# 处理到 end 节点的边
if target in end_node_ids:
# 连接到 end 节点
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
continue
# 跳过错误边(在节点内部处理)
if edge_type == "error":
continue
if condition:
# 条件边
def router(state: WorkflowState, cond=condition, tgt=target):
"""条件路由函数"""
if evaluate_condition(
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
):
return tgt
return END # 条件不满足,结束
workflow.add_conditional_edges(source, router)
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
else:
# 普通边
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
# 从 end 节点连接到 END
for end_node_id in end_node_ids:
workflow.add_edge(end_node_id, END)
logger.debug(f"添加边: {end_node_id} -> END")
# 4. 编译图
graph = workflow.compile()
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
return graph
async def execute(
self,
input_data: dict[str, Any]
) -> dict[str, Any]:
"""执行工作流(非流式)
Args:
input_data: 输入数据,包含 message 和 variables
Returns:
执行结果,包含 status, output, node_outputs, elapsed_time, token_usage
"""
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
# 记录开始时间
start_time = datetime.datetime.now()
# 1. 构建图
graph = self.build_graph()
# 2. 初始化状态(自动注入系统变量)
initial_state = self._prepare_initial_state(input_data)
# 3. 执行工作流
try:
result = await graph.ainvoke(initial_state)
# 计算耗时
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
# 提取节点输出(现在包含 start 和 end 节点)
node_outputs = result.get("node_outputs", {})
# 提取最终输出(从最后一个非 start/end 节点)
final_output = self._extract_final_output(node_outputs)
# 聚合 token 使用情况
token_usage = self._aggregate_token_usage(node_outputs)
# 提取 conversation_id从 start 节点输出)
conversation_id = None
for node_id, node_output in node_outputs.items():
if node_output.get("node_type") == "start":
conversation_id = node_output.get("output", {}).get("conversation_id")
break
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
return {
"status": "completed",
"output": final_output,
"node_outputs": node_outputs,
"messages": result.get("messages", []),
"conversation_id": conversation_id,
"elapsed_time": elapsed_time,
"token_usage": token_usage,
"error": result.get("error")
}
except Exception as e:
# 计算耗时(即使失败也记录)
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True)
return {
"status": "failed",
"error": str(e),
"output": None,
"node_outputs": {},
"elapsed_time": elapsed_time,
"token_usage": None
}
async def execute_stream(
self,
input_data: dict[str, Any]
):
"""执行工作流(流式)
Args:
input_data: 输入数据
Yields:
流式事件
"""
logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}")
# 1. 构建图
graph = self.build_graph()
# 2. 初始化状态(自动注入系统变量)
initial_state = self._prepare_initial_state(input_data)
# 3. 流式执行工作流
try:
# 使用 astream 获取节点级别的更新
async for event in graph.astream(initial_state, stream_mode="updates"):
for node_name, state_update in event.items():
yield {
"type": "node_complete",
"node": node_name,
"data": state_update,
"execution_id": self.execution_id
}
logger.info(f"工作流执行完成(流式): execution_id={self.execution_id}")
# 发送完成事件
yield {
"type": "workflow_complete",
"execution_id": self.execution_id
}
except Exception as e:
logger.error(f"工作流执行失败(流式): execution_id={self.execution_id}, error={e}", exc_info=True)
yield {
"type": "workflow_error",
"execution_id": self.execution_id,
"error": str(e)
}
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:
"""从节点输出中提取最终输出
优先级:
1. 最后一个执行的非 start/end 节点的 output
2. 如果没有节点输出,返回 None
Args:
node_outputs: 所有节点的输出
Returns:
最终输出字符串或 None
"""
if not node_outputs:
return None
# 获取最后一个节点的输出
last_node_output = list(node_outputs.values())[-1] if node_outputs else None
if last_node_output and isinstance(last_node_output, dict):
return last_node_output.get("output")
return None
def _aggregate_token_usage(self, node_outputs: dict[str, Any]) -> dict[str, int] | None:
"""聚合所有节点的 token 使用情况
Args:
node_outputs: 所有节点的输出
Returns:
聚合的 token 使用情况 {"prompt_tokens": x, "completion_tokens": y, "total_tokens": z}
如果没有 token 使用信息,返回 None
"""
total_prompt_tokens = 0
total_completion_tokens = 0
total_tokens = 0
has_token_info = False
for node_output in node_outputs.values():
if isinstance(node_output, dict):
token_usage = node_output.get("token_usage")
if token_usage and isinstance(token_usage, dict):
has_token_info = True
total_prompt_tokens += token_usage.get("prompt_tokens", 0)
total_completion_tokens += token_usage.get("completion_tokens", 0)
total_tokens += token_usage.get("total_tokens", 0)
if not has_token_info:
return None
return {
"prompt_tokens": total_prompt_tokens,
"completion_tokens": total_completion_tokens,
"total_tokens": total_tokens
}
async def execute_workflow(
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
) -> dict[str, Any]:
"""执行工作流(便捷函数)
Args:
workflow_config: 工作流配置
input_data: 输入数据
execution_id: 执行 ID
workspace_id: 工作空间 ID
user_id: 用户 ID
Returns:
执行结果
"""
executor = WorkflowExecutor(
workflow_config=workflow_config,
execution_id=execution_id,
workspace_id=workspace_id,
user_id=user_id
)
return await executor.execute(input_data)
async def execute_workflow_stream(
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
):
"""执行工作流(流式,便捷函数)
Args:
workflow_config: 工作流配置
input_data: 输入数据
execution_id: 执行 ID
workspace_id: 工作空间 ID
user_id: 用户 ID
Yields:
流式事件
"""
executor = WorkflowExecutor(
workflow_config=workflow_config,
execution_id=execution_id,
workspace_id=workspace_id,
user_id=user_id
)
async for event in executor.execute_stream(input_data):
yield event

View File

@@ -0,0 +1,195 @@
"""
安全的表达式求值器
使用 simpleeval 库提供安全的表达式评估,避免代码注入攻击。
"""
import logging
from typing import Any
from simpleeval import simple_eval, NameNotDefined, InvalidExpression
logger = logging.getLogger(__name__)
class ExpressionEvaluator:
"""安全的表达式求值器"""
# 保留的命名空间
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
@staticmethod
def evaluate(
expression: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> Any:
"""安全地评估表达式
Args:
expression: 表达式字符串,如 "{{var.score}} > 0.8"
variables: 用户定义的变量
node_outputs: 节点输出结果
system_vars: 系统变量
Returns:
表达式求值结果
Raises:
ValueError: 表达式无效或求值失败
Examples:
>>> evaluator = ExpressionEvaluator()
>>> evaluator.evaluate(
... "var.score > 0.8",
... {"score": 0.9},
... {},
... {}
... )
True
>>> evaluator.evaluate(
... "node.intent.output == '售前咨询'",
... {},
... {"intent": {"output": "售前咨询"}},
... {}
... )
True
"""
# 移除 Jinja2 模板语法的花括号(如果存在)
expression = expression.strip()
if expression.startswith("{{") and expression.endswith("}}"):
expression = expression[2:-2].strip()
# 构建命名空间上下文
context = {
"var": variables, # 用户变量
"node": node_outputs, # 节点输出
"sys": system_vars or {}, # 系统变量
}
# 为了向后兼容,也支持直接访问(但会在日志中警告)
context.update(variables)
context["nodes"] = node_outputs
try:
# simpleeval 只支持安全的操作:
# - 算术运算: +, -, *, /, //, %, **
# - 比较运算: ==, !=, <, <=, >, >=
# - 逻辑运算: and, or, not
# - 成员运算: in, not in
# - 属性访问: obj.attr
# - 字典/列表访问: obj["key"], obj[0]
# 不支持:函数调用、导入、赋值等危险操作
result = simple_eval(expression, names=context)
return result
except NameNotDefined as e:
logger.error(f"表达式中引用了未定义的变量: {expression}, 错误: {e}")
raise ValueError(f"未定义的变量: {e}")
except InvalidExpression as e:
logger.error(f"表达式语法无效: {expression}, 错误: {e}")
raise ValueError(f"表达式语法无效: {e}")
except SyntaxError as e:
logger.error(f"表达式语法错误: {expression}, 错误: {e}")
raise ValueError(f"表达式语法错误: {e}")
except Exception as e:
logger.error(f"表达式求值异常: {expression}, 错误: {e}")
raise ValueError(f"表达式求值失败: {e}")
@staticmethod
def evaluate_bool(
expression: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> bool:
"""评估布尔表达式(用于条件判断)
Args:
expression: 布尔表达式
variables: 用户变量
node_outputs: 节点输出
system_vars: 系统变量
Returns:
布尔值结果
Examples:
>>> ExpressionEvaluator.evaluate_bool(
... "var.count >= 10 and var.status == 'active'",
... {"count": 15, "status": "active"},
... {},
... {}
... )
True
"""
result = ExpressionEvaluator.evaluate(
expression, variables, node_outputs, system_vars
)
return bool(result)
@staticmethod
def validate_variable_names(variables: list[dict]) -> list[str]:
"""验证变量名是否合法
Args:
variables: 变量定义列表
Returns:
错误列表,如果为空则验证通过
Examples:
>>> ExpressionEvaluator.validate_variable_names([
... {"name": "user_input"},
... {"name": "var"} # 保留字
... ])
["变量名 'var' 是保留的命名空间,请使用其他名称"]
"""
errors = []
for var in variables:
var_name = var.get("name", "")
# 检查是否为保留命名空间
if var_name in ExpressionEvaluator.RESERVED_NAMESPACES:
errors.append(
f"变量名 '{var_name}' 是保留的命名空间,请使用其他名称"
)
# 检查是否为有效的 Python 标识符
if not var_name.isidentifier():
errors.append(
f"变量名 '{var_name}' 不是有效的标识符"
)
return errors
# 便捷函数
def evaluate_expression(
expression: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> Any:
"""评估表达式(便捷函数)"""
return ExpressionEvaluator.evaluate(
expression, variables, node_outputs, system_vars
)
def evaluate_condition(
expression: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> bool:
"""评估条件表达式(便捷函数)"""
return ExpressionEvaluator.evaluate_bool(
expression, variables, node_outputs, system_vars
)

View File

@@ -0,0 +1,24 @@
"""
工作流节点实现
提供各种类型的节点实现,用于工作流执行。
"""
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.end import EndNode
from app.core.workflow.nodes.node_factory import NodeFactory
__all__ = [
"BaseNode",
"WorkflowState",
"LLMNode",
"AgentNode",
"TransformNode",
"StartNode",
"EndNode",
"NodeFactory",
]

View File

@@ -0,0 +1,6 @@
"""Agent 节点"""
from app.core.workflow.nodes.agent.node import AgentNode
from app.core.workflow.nodes.agent.config import AgentNodeConfig
__all__ = ["AgentNode", "AgentNodeConfig"]

View File

@@ -0,0 +1,71 @@
"""Agent 节点配置"""
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
class AgentNodeConfig(BaseNodeConfig):
"""Agent 节点配置
调用已配置的 Agent 执行任务。
"""
agent_id: str = Field(
...,
description="Agent 配置 ID"
)
message: str = Field(
default="{{ sys.message }}",
description="发送给 Agent 的消息,支持模板变量"
)
conversation_id: str | None = Field(
default=None,
description="会话 ID用于多轮对话"
)
variables: dict[str, str] | None = Field(
default=None,
description="传递给 Agent 的变量"
)
timeout: int = Field(
default=300,
ge=1,
le=3600,
description="超时时间(秒)"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="output",
type=VariableType.STRING,
description="Agent 的回复内容"
),
VariableDefinition(
name="conversation_id",
type=VariableType.STRING,
description="会话 ID"
),
VariableDefinition(
name="token_usage",
type=VariableType.OBJECT,
description="Token 使用情况"
)
],
description="输出变量定义(自动生成,通常不需要修改)"
)
class Config:
json_schema_extra = {
"example": {
"agent_id": "uuid-here",
"message": "{{ sys.message }}",
"timeout": 300,
"description": "调用客服 Agent"
}
}

View File

@@ -0,0 +1,152 @@
"""
Agent 节点实现
调用已发布的 Agent 应用。
"""
import logging
from typing import Any
from langchain_core.messages import AIMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.services.draft_run_service import DraftRunService
from app.models import AppRelease
from app.db import get_db
logger = logging.getLogger(__name__)
class AgentNode(BaseNode):
"""Agent 节点
支持流式和非流式输出。
配置示例:
{
"type": "agent",
"config": {
"agent_id": "uuid", # Agent 的 release_id
"message": "{{var.user_input}}"
}
}
"""
def _prepare_agent(self, state: WorkflowState) -> tuple[DraftRunService, AppRelease, str]:
"""准备 Agent公共逻辑
Args:
state: 工作流状态
Returns:
(draft_service, release, message): 服务实例、发布配置、消息
"""
# 1. 渲染消息
message_template = self.config.get("message", "")
message = self._render_template(message_template, state)
# 2. 获取 Agent 配置
agent_id = self.config.get("agent_id")
if not agent_id:
raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置")
db = next(get_db())
release = db.query(AppRelease).filter(
AppRelease.id == agent_id
).first()
if not release:
raise ValueError(f"Agent 不存在: {agent_id}")
draft_service = DraftRunService(db)
return draft_service, release, message
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""非流式执行
Args:
state: 工作流状态
Returns:
状态更新字典
"""
draft_service, release, message = self._prepare_agent(state)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
# 执行 Agent非流式
result = await draft_service.run(
agent_config=release.config,
model_config=None,
message=message,
workspace_id=state.get("workspace_id"),
user_id=state.get("user_id"),
variables=state.get("variables", {})
)
response = result.get("response", "")
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(response)}")
return {
"messages": [AIMessage(content=response)],
"node_outputs": {
self.node_id: {
"output": response,
"status": "completed",
"meta_data": result.get("meta_data", {})
}
}
}
async def execute_stream(self, state: WorkflowState):
"""流式执行
Args:
state: 工作流状态
Yields:
流式事件字典
"""
draft_service, release, message = self._prepare_agent(state)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
# 累积完整响应
full_response = ""
# 执行 Agent流式
async for chunk in draft_service.run_stream(
agent_config=release.config,
model_config=None,
message=message,
workspace_id=state.get("workspace_id"),
user_id=state.get("user_id"),
variables=state.get("variables", {})
):
# 提取内容
content = chunk.get("content", "")
full_response += content
# 流式返回每个 chunk
yield {
"type": "chunk",
"node_id": self.node_id,
"content": content,
"full_content": full_response,
"meta_data": chunk.get("meta_data", {})
}
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}")
# 最后返回完整结果
yield {
"type": "complete",
"messages": [AIMessage(content=full_response)],
"node_outputs": {
self.node_id: {
"output": full_response,
"status": "completed"
}
}
}

View File

@@ -0,0 +1,109 @@
"""节点配置基类
定义所有节点配置的通用字段和数据结构。
"""
from enum import StrEnum
from pydantic import BaseModel, Field
class VariableType(StrEnum):
"""变量类型枚举"""
STRING = "string"
NUMBER = "number"
BOOLEAN = "boolean"
ARRAY = "array"
OBJECT = "object"
class VariableDefinition(BaseModel):
"""变量定义
定义工作流或节点的输入/输出变量。
这是一个通用的数据结构,可以在多个地方使用。
"""
name: str = Field(
...,
description="变量名称"
)
type: VariableType = Field(
default=VariableType.STRING,
description="变量类型"
)
required: bool = Field(
default=False,
description="是否必需"
)
default: str | int | float | bool | list | dict | None = Field(
default=None,
description="默认值"
)
description: str | None = Field(
default=None,
description="变量描述"
)
class Config:
json_schema_extra = {
"examples": [
{
"name": "language",
"type": "string",
"required": False,
"default": "zh-CN",
"description": "语言设置"
},
{
"name": "max_length",
"type": "number",
"required": False,
"default": 1000,
"description": "最大长度"
},
{
"name": "enable_search",
"type": "boolean",
"required": True,
"description": "是否启用搜索"
}
]
}
class BaseNodeConfig(BaseModel):
"""节点配置基类
所有节点配置都应该继承此基类。
通用字段:
- name: 节点名称(显示名称)
- description: 节点描述
- tags: 节点标签(用于分类和搜索)
"""
name: str | None = Field(
default=None,
description="节点名称(显示名称),如果不设置则使用节点 ID"
)
description: str | None = Field(
default=None,
description="节点描述,说明节点的作用"
)
tags: list[str] = Field(
default_factory=list,
description="节点标签,用于分类和搜索"
)
class Config:
"""Pydantic 配置"""
# 允许额外字段(向后兼容)
extra = "allow"

View File

@@ -0,0 +1,556 @@
"""
工作流节点基类
定义节点的基本接口和通用功能。
"""
import asyncio
import logging
from abc import ABC, abstractmethod
from typing import Any, TypedDict, Annotated
from operator import add
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
class WorkflowState(TypedDict):
"""工作流状态
在节点间传递的状态对象,包含消息、变量、节点输出等信息。
"""
# 消息列表(追加模式)
messages: Annotated[list[AnyMessage], add]
# 输入变量(从配置的 variables 传入)
variables: dict[str, Any]
# 节点输出(存储每个节点的执行结果,用于变量引用)
# 使用自定义合并函数,将新的节点输出合并到现有字典中
node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
# 运行时节点变量(简化版,只存储业务数据,供节点间快速访问)
# 格式:{node_id: business_result}
runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
# 执行上下文
execution_id: str
workspace_id: str
user_id: str
# 错误信息(用于错误边)
error: str | None
error_node: str | None
class BaseNode(ABC):
"""节点基类
所有节点类型都应该继承此基类,实现 execute 方法。
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
"""初始化节点
Args:
node_config: 节点配置
workflow_config: 工作流配置
"""
self.node_config = node_config
self.workflow_config = workflow_config
self.node_id = node_config["id"]
self.node_type = node_config["type"]
self.node_name = node_config.get("name", self.node_id)
# 使用 or 运算符处理 None 值
self.config = node_config.get("config") or {}
self.error_handling = node_config.get("error_handling") or {}
@abstractmethod
async def execute(self, state: WorkflowState) -> Any:
"""执行节点业务逻辑(非流式)
节点只需要返回业务结果,不需要关心输出格式、时间统计等。
BaseNode 会自动包装成标准格式。
Args:
state: 工作流状态
Returns:
业务结果(任意类型)
Examples:
>>> # LLM 节点
>>> return "这是 AI 的回复"
>>> # Transform 节点
>>> return {"processed_data": [...]}
>>> # Start/End 节点
>>> return {"message": "开始", "conversation_id": "xxx"}
"""
pass
async def execute_stream(self, state: WorkflowState):
"""执行节点业务逻辑(流式)
子类可以重写此方法以支持流式输出。
默认实现:执行非流式方法并一次性返回。
节点需要:
1. yield 中间结果(如文本片段)
2. 最后 yield 一个特殊的完成标记:{"__final__": True, "result": final_result}
Args:
state: 工作流状态
Yields:
业务数据chunk或完成标记
Examples:
>>> # 流式 LLM 节点
>>> full_response = ""
>>> async for chunk in llm.astream(prompt):
... full_response += chunk
... yield chunk # yield 文本片段
>>>
>>> # 最后 yield 完成标记
>>> yield {"__final__": True, "result": AIMessage(content=full_response)}
"""
result = await self.execute(state)
# 默认实现:直接 yield 完成标记
yield {"__final__": True, "result": result}
def supports_streaming(self) -> bool:
"""节点是否支持流式输出
Returns:
是否支持流式输出
"""
# 检查子类是否重写了 execute_stream 方法
return self.execute_stream.__func__ != BaseNode.execute_stream.__func__
def get_timeout(self) -> int:
"""获取超时时间(秒)
Returns:
超时时间
"""
return 60
# return self.error_handling.get("timeout", 60)
async def run(self, state: WorkflowState) -> dict[str, Any]:
"""执行节点(带错误处理和输出包装,非流式)
这个方法由 Executor 调用,负责:
1. 时间统计
2. 调用节点的 execute() 方法
3. 将业务结果包装成标准输出格式
4. 错误处理
Args:
state: 工作流状态
Returns:
标准化的状态更新字典
"""
import time
start_time = time.time()
try:
timeout = self.get_timeout()
# 调用节点的业务逻辑
business_result = await asyncio.wait_for(
self.execute(state),
timeout=timeout
)
elapsed_time = time.time() - start_time
# 提取处理后的输出(调用子类的 _extract_output
extracted_output = self._extract_output(business_result)
# 包装成标准输出格式
wrapped_output = self._wrap_output(business_result, elapsed_time, state)
# 将提取后的输出存储到运行时变量中(供后续节点快速访问)
# 如果提取后的输出是字典,拆包存储;否则存储为 output 字段
if isinstance(extracted_output, dict):
runtime_var = extracted_output
else:
runtime_var = {"output": extracted_output}
# 返回包装后的输出和运行时变量
return {
**wrapped_output,
"runtime_vars": {
self.node_id: runtime_var
}
}
except TimeoutError:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
return self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
return self._wrap_error(str(e), elapsed_time, state)
async def run_stream(self, state: WorkflowState):
"""执行节点(带错误处理和输出包装,流式)
这个方法由 Executor 调用,负责:
1. 时间统计
2. 调用节点的 execute_stream() 方法
3. 将业务数据包装成标准输出格式
4. 错误处理
Args:
state: 工作流状态
Yields:
标准化的流式事件
"""
import time
start_time = time.time()
try:
timeout = self.get_timeout()
# 累积完整结果(用于最后的包装)
chunks = []
final_result = None
# 使用异步生成器包装,支持超时
async def stream_with_timeout():
nonlocal final_result
loop_start = asyncio.get_event_loop().time()
async for item in self.execute_stream(state):
# 检查超时
if asyncio.get_event_loop().time() - loop_start > timeout:
raise TimeoutError()
# 检查是否是完成标记
if isinstance(item, dict) and item.get("__final__"):
final_result = item["result"]
elif isinstance(item, str):
# 字符串是 chunk
chunks.append(item)
yield {
"type": "chunk",
"node_id": self.node_id,
"content": item,
"full_content": "".join(chunks)
}
else:
# 其他类型也当作 chunk 处理
chunks.append(str(item))
yield {
"type": "chunk",
"node_id": self.node_id,
"content": str(item),
"full_content": "".join(chunks)
}
async for chunk_event in stream_with_timeout():
yield chunk_event
elapsed_time = time.time() - start_time
# 包装最终结果
final_output = self._wrap_output(final_result, elapsed_time, state)
yield {
"type": "complete",
**final_output
}
except TimeoutError:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
yield {
"type": "error",
**self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
}
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
yield {
"type": "error",
**self._wrap_error(str(e), elapsed_time, state)
}
def _wrap_output(
self,
business_result: Any,
elapsed_time: float,
state: WorkflowState
) -> dict[str, Any]:
"""将业务结果包装成标准输出格式
Args:
business_result: 节点返回的业务结果
elapsed_time: 执行耗时
state: 工作流状态
Returns:
标准化的状态更新字典
"""
# 提取输入数据(用于记录)
input_data = self._extract_input(state)
# 提取 token 使用情况(如果有)
token_usage = self._extract_token_usage(business_result)
# 提取实际输出(去除元数据)
output = self._extract_output(business_result)
# 构建标准节点输出
node_output = {
"node_id": self.node_id,
"node_type": self.node_type,
"node_name": self.node_name,
"status": "completed",
"input": input_data,
"output": output,
"elapsed_time": elapsed_time,
"token_usage": token_usage,
"error": None
}
return {
"node_outputs": {
self.node_id: node_output
}
}
def _wrap_error(
self,
error_message: str,
elapsed_time: float,
state: WorkflowState
) -> dict[str, Any]:
"""将错误包装成标准输出格式
Args:
error_message: 错误信息
elapsed_time: 执行耗时
state: 工作流状态
Returns:
标准化的状态更新字典
"""
# 查找错误边
error_edge = self._find_error_edge()
# 提取输入数据
input_data = self._extract_input(state)
# 构建错误输出
node_output = {
"node_id": self.node_id,
"node_type": self.node_type,
"node_name": self.node_name,
"status": "failed",
"input": input_data,
"output": None,
"elapsed_time": elapsed_time,
"token_usage": None,
"error": error_message
}
if error_edge:
# 有错误边:记录错误并继续
logger.warning(
f"节点 {self.node_id} 执行失败,跳转到错误处理节点: {error_edge['target']}"
)
return {
"node_outputs": {
self.node_id: node_output
},
"error": error_message,
"error_node": self.node_id
}
else:
# 无错误边:抛出异常停止工作流
logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}")
raise Exception(f"节点 {self.node_id} 执行失败: {error_message}")
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
"""提取节点输入数据(用于记录)
子类可以重写此方法来自定义输入记录。
Args:
state: 工作流状态
Returns:
输入数据字典
"""
# 默认返回配置
return {"config": self.config}
def _extract_output(self, business_result: Any) -> Any:
"""从业务结果中提取实际输出
子类可以重写此方法来自定义输出提取。
Args:
business_result: 业务结果
Returns:
实际输出
"""
# 默认直接返回业务结果
return business_result
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
"""从业务结果中提取 token 使用情况
子类可以重写此方法来提取 token 信息。
Args:
business_result: 业务结果
Returns:
token 使用情况或 None
"""
# 默认返回 None
return None
def _find_error_edge(self) -> dict[str, Any] | None:
"""查找错误边
Returns:
错误边配置或 None
"""
for edge in self.workflow_config.get("edges", []):
if edge.get("source") == self.node_id and edge.get("type") == "error":
return edge
return None
def _render_template(self, template: str, state: WorkflowState | None) -> str:
"""渲染模板
支持的变量命名空间:
- sys.xxx: 系统变量message, execution_id, workspace_id, user_id, conversation_id
- conv.xxx: 会话变量(跨多轮对话保持)
- node_id.xxx: 节点输出
Args:
template: 模板字符串
state: 工作流状态
Returns:
渲染后的字符串
"""
from app.core.workflow.template_renderer import render_template
# 处理 state 为 None 的情况
if state is None:
state = {}
# 使用变量池获取变量
pool = VariablePool(state)
return render_template(
template=template,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars()
)
def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool:
"""评估条件表达式
支持的变量命名空间:
- sys.xxx: 系统变量
- conv.xxx: 会话变量
- node_id.xxx: 节点输出
Args:
expression: 条件表达式
state: 工作流状态
Returns:
布尔值结果
"""
from app.core.workflow.expression_evaluator import evaluate_condition
# 处理 state 为 None 的情况
if state is None:
state = {}
# 使用变量池获取变量
pool = VariablePool(state)
return evaluate_condition(
expression=expression,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars()
)
def get_variable_pool(self, state: WorkflowState) -> VariablePool:
"""获取变量池实例
VariablePool 是轻量级包装器,只持有 state 的引用,创建成本极低。
Args:
state: 工作流状态
Returns:
VariablePool 实例
Examples:
>>> pool = self.get_variable_pool(state)
>>> message = pool.get("sys.message")
>>> llm_output = pool.get("llm_qa.output")
"""
return VariablePool(state)
def get_variable(
self,
selector: list[str] | str,
state: WorkflowState,
default: Any = None
) -> Any:
"""获取变量值(便捷方法)
Args:
selector: 变量选择器
state: 工作流状态
default: 默认值
Returns:
变量值
Examples:
>>> message = self.get_variable("sys.message", state)
>>> output = self.get_variable(["llm_qa", "output"], state)
>>> custom = self.get_variable("var.custom", state, default="默认值")
"""
pool = VariablePool(state)
return pool.get(selector, default=default)
def has_variable(self, selector: list[str] | str, state: WorkflowState) -> bool:
"""检查变量是否存在(便捷方法)
Args:
selector: 变量选择器
state: 工作流状态
Returns:
变量是否存在
Examples:
>>> if self.has_variable("llm_qa.output", state):
... output = self.get_variable("llm_qa.output", state)
"""
pool = VariablePool(state)
return pool.has(selector)

View File

@@ -0,0 +1,29 @@
"""节点配置类统一导出
所有节点的配置类都在这里导出,方便使用。
"""
from app.core.workflow.nodes.base_config import (
BaseNodeConfig,
VariableDefinition,
VariableType,
)
from app.core.workflow.nodes.start.config import StartNodeConfig
from app.core.workflow.nodes.end.config import EndNodeConfig
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
from app.core.workflow.nodes.agent.config import AgentNodeConfig
from app.core.workflow.nodes.transform.config import TransformNodeConfig
__all__ = [
# 基础类
"BaseNodeConfig",
"VariableDefinition",
"VariableType",
# 节点配置
"StartNodeConfig",
"EndNodeConfig",
"LLMNodeConfig",
"MessageConfig",
"AgentNodeConfig",
"TransformNodeConfig",
]

View File

@@ -0,0 +1,6 @@
"""End 节点"""
from app.core.workflow.nodes.end.node import EndNode
from app.core.workflow.nodes.end.config import EndNodeConfig
__all__ = ["EndNode", "EndNodeConfig"]

View File

@@ -0,0 +1,37 @@
"""End 节点配置"""
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
class EndNodeConfig(BaseNodeConfig):
"""End 节点配置
End 节点负责输出工作流的最终结果。
"""
output: str = Field(
default="工作流已完成",
description="输出模板,支持引用前置节点的输出,如:{{ llm_qa.output }}"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="output",
type=VariableType.STRING,
description="工作流的最终输出"
)
],
description="输出变量定义(自动生成,通常不需要修改)"
)
class Config:
json_schema_extra = {
"example": {
"output": "{{ llm_qa.output }}",
"description": "输出 LLM 的回答"
}
}

View File

@@ -0,0 +1,53 @@
"""
End 节点实现
工作流的结束节点,输出最终结果。
"""
import logging
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
logger = logging.getLogger(__name__)
class EndNode(BaseNode):
"""End 节点
工作流的结束节点,根据配置的模板输出最终结果。
"""
async def execute(self, state: WorkflowState) -> str:
"""执行 end 节点业务逻辑
Args:
state: 工作流状态
Returns:
最终输出字符串
"""
logger.info(f"节点 {self.node_id} (End) 开始执行")
# 获取配置的输出模板
output_template = self.config.get("output")
pool = self.get_variable_pool(state)
print("="*20)
print( pool.get("start.test"))
print("="*20)
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template:
output = self._render_template(output_template, state)
else:
output = "工作流已完成"
# 统计信息(用于日志)
node_outputs = state.get("node_outputs", {})
total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
print("="*20)
print(output)
print("="*20)
return output

View File

@@ -0,0 +1,15 @@
from enum import StrEnum
class NodeType(StrEnum):
START = "start"
END = "end"
ANSWER = "answer"
LLM = "llm"
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
IF_ELSE = "if-else"
CODE = "code"
TRANSFORM = "transform"
QUESTION_CLASSIFIER = "question-classifier"
HTTP_REQUEST = "http-request"
TOOL = "tool"
AGENT = "agent"

View File

@@ -0,0 +1,6 @@
"""LLM 节点"""
from app.core.workflow.nodes.llm.node import LLMNode
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
__all__ = ["LLMNode", "LLMNodeConfig", "MessageConfig"]

View File

@@ -0,0 +1,141 @@
"""LLM 节点配置"""
from pydantic import BaseModel, Field, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
class MessageConfig(BaseModel):
"""消息配置"""
role: str = Field(
...,
description="消息角色system, user, assistant"
)
content: str = Field(
...,
description="消息内容,支持模板变量,如:{{ sys.message }}"
)
@field_validator("role")
@classmethod
def validate_role(cls, v: str) -> str:
"""验证角色"""
allowed_roles = ["system", "user", "human", "assistant", "ai"]
if v.lower() not in allowed_roles:
raise ValueError(f"角色必须是以下之一: {', '.join(allowed_roles)}")
return v.lower()
class LLMNodeConfig(BaseNodeConfig):
"""LLM 节点配置
支持两种配置方式:
1. 简单模式:使用 prompt 字段
2. 消息模式:使用 messages 字段(推荐)
"""
model_id: str = Field(
...,
description="模型配置 ID"
)
# 简单模式
prompt: str | None = Field(
default=None,
description="提示词模板(简单模式),支持变量引用"
)
# 消息模式(推荐)
messages: list[MessageConfig] | None = Field(
default=None,
description="消息列表(消息模式),支持多轮对话"
)
# 模型参数
temperature: float | None = Field(
default=0.7,
ge=0.0,
le=2.0,
description="温度参数,控制输出的随机性"
)
max_tokens: int | None = Field(
default=1000,
ge=1,
le=32000,
description="最大生成 token 数"
)
top_p: float | None = Field(
default=None,
ge=0.0,
le=1.0,
description="Top-p 采样参数"
)
frequency_penalty: float | None = Field(
default=None,
ge=-2.0,
le=2.0,
description="频率惩罚"
)
presence_penalty: float | None = Field(
default=None,
ge=-2.0,
le=2.0,
description="存在惩罚"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="output",
type=VariableType.STRING,
description="LLM 生成的文本输出"
),
VariableDefinition(
name="token_usage",
type=VariableType.OBJECT,
description="Token 使用情况"
)
],
description="输出变量定义(自动生成,通常不需要修改)"
)
@field_validator("messages", "prompt")
@classmethod
def validate_input_mode(cls, v, info):
"""验证输入模式prompt 和 messages 至少有一个"""
# 这个验证在 model_validator 中更合适
return v
class Config:
json_schema_extra = {
"examples": [
{
"model_id": "uuid-here",
"prompt": "请回答:{{ sys.message }}",
"temperature": 0.7,
"max_tokens": 1000
},
{
"model_id": "uuid-here",
"messages": [
{
"role": "system",
"content": "你是一个专业的 AI 助手"
},
{
"role": "user",
"content": "{{ sys.message }}"
}
],
"temperature": 0.7,
"max_tokens": 1000
}
]
}

View File

@@ -0,0 +1,247 @@
"""
LLM 节点实现
调用 LLM 模型进行文本生成。
"""
import logging
from typing import Any
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models import ModelConfig
from app.db import get_db, get_db_context
from app.models.models_model import ModelApiKey
from app.services.model_service import ModelConfigService, ModelApiKeyService
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
logger = logging.getLogger(__name__)
class LLMNode(BaseNode):
"""LLM 节点
支持流式和非流式输出,使用 LangChain 标准的消息格式。
配置示例(支持多种消息格式):
1. 简单文本格式:
{
"type": "llm",
"config": {
"model_id": "uuid",
"prompt": "请分析:{{sys.message}}",
"temperature": 0.7,
"max_tokens": 1000
}
}
2. LangChain 消息格式(推荐):
{
"type": "llm",
"config": {
"model_id": "uuid",
"messages": [
{
"role": "system",
"content": "你是一个专业的 AI 助手。"
},
{
"role": "user",
"content": "{{sys.message}}"
}
],
"temperature": 0.7,
"max_tokens": 1000
}
}
支持的角色类型:
- system: 系统消息SystemMessage
- user/human: 用户消息HumanMessage
- ai/assistant: AI 消息AIMessage
"""
def _prepare_llm(self, state: WorkflowState) -> tuple[RedBearLLM, list | str]:
"""准备 LLM 实例(公共逻辑)
Args:
state: 工作流状态
Returns:
(llm, messages_or_prompt): LLM 实例和消息列表或 prompt 字符串
"""
# 1. 处理消息格式(优先使用 messages
messages_config = self.config.get("messages")
if messages_config:
# 使用 LangChain 消息格式
messages = []
for msg_config in messages_config:
role = msg_config.get("role", "user").lower()
content_template = msg_config.get("content", "")
content = self._render_template(content_template, state)
# 根据角色创建对应的消息对象
if role == "system":
messages.append(SystemMessage(content=content))
elif role in ["user", "human"]:
messages.append(HumanMessage(content=content))
elif role in ["ai", "assistant"]:
messages.append(AIMessage(content=content))
else:
logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append(HumanMessage(content=content))
prompt_or_messages = messages
else:
# 使用简单的 prompt 格式(向后兼容)
prompt_template = self.config.get("prompt", "")
prompt_or_messages = self._render_template(prompt_template, state)
# 2. 获取模型配置
model_id = self.config.get("model_id")
if not model_id:
raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置")
# 3. 在 with 块内完成所有数据库操作和数据提取
with get_db_context() as db:
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
if not config:
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
if not config.api_keys or len(config.api_keys) == 0:
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
# 在 Session 关闭前提取所有需要的数据
api_config = config.api_keys[0]
model_name = api_config.model_name
provider = api_config.provider
api_key = api_config.api_key
api_base = api_config.api_base
model_type = config.type
# 4. 创建 LLM 实例(使用已提取的数据)
llm = RedBearLLM(
RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base
),
type=model_type
)
return llm, prompt_or_messages
async def execute(self, state: WorkflowState) -> AIMessage:
"""非流式执行 LLM 调用
Args:
state: 工作流状态
Returns:
LLM 响应消息
"""
llm, prompt_or_messages = self._prepare_llm(state)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
# 调用 LLM支持字符串或消息列表
response = await llm.ainvoke(prompt_or_messages)
# 提取内容
if hasattr(response, 'content'):
content = response.content
else:
content = str(response)
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
# 返回 AIMessage包含响应元数据
return response if isinstance(response, AIMessage) else AIMessage(content=content)
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
"""提取输入数据(用于记录)"""
_, prompt_or_messages = self._prepare_llm(state)
return {
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
"messages": [
{"role": msg.__class__.__name__.replace("Message", "").lower(), "content": msg.content}
for msg in prompt_or_messages
] if isinstance(prompt_or_messages, list) else None,
"config": {
"model_id": self.config.get("model_id"),
"temperature": self.config.get("temperature"),
"max_tokens": self.config.get("max_tokens")
}
}
def _extract_output(self, business_result: Any) -> str:
"""从 AIMessage 中提取文本内容"""
if isinstance(business_result, AIMessage):
return business_result.content
return str(business_result)
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
"""从 AIMessage 中提取 token 使用情况"""
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'):
usage = business_result.response_metadata.get('token_usage')
if usage:
return {
"prompt_tokens": usage.get('prompt_tokens', 0),
"completion_tokens": usage.get('completion_tokens', 0),
"total_tokens": usage.get('total_tokens', 0)
}
return None
async def execute_stream(self, state: WorkflowState):
"""流式执行 LLM 调用
Args:
state: 工作流状态
Yields:
文本片段chunk或完成标记
"""
llm, prompt_or_messages = self._prepare_llm(state)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
# 累积完整响应
full_response = ""
last_chunk = None
# 调用 LLM流式支持字符串或消息列表
async for chunk in llm.astream(prompt_or_messages):
# 提取内容
if hasattr(chunk, 'content'):
content = chunk.content
else:
content = str(chunk)
full_response += content
last_chunk = chunk
# 流式返回每个文本片段
yield content
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}")
# 构建完整的 AIMessage包含元数据
if isinstance(last_chunk, AIMessage):
final_message = AIMessage(
content=full_response,
response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {}
)
else:
final_message = AIMessage(content=full_response)
# yield 完成标记
yield {"__final__": True, "result": final_message}

View File

@@ -0,0 +1,93 @@
"""
节点工厂
根据节点类型创建相应的节点实例。
"""
import logging
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.end import EndNode
logger = logging.getLogger(__name__)
class NodeFactory:
"""节点工厂
使用工厂模式创建节点实例,便于扩展和维护。
"""
# 节点类型注册表
_node_types: dict[str, type[BaseNode]] = {
NodeType.START: StartNode,
NodeType.END: EndNode,
NodeType.LLM: LLMNode,
NodeType.AGENT: AgentNode,
NodeType.TRANSFORM: TransformNode,
}
@classmethod
def register_node_type(cls, node_type: str, node_class: type[BaseNode]):
"""注册新的节点类型
Args:
node_type: 节点类型名称
node_class: 节点类
Examples:
>>> class CustomNode(BaseNode):
... async def execute(self, state):
... return {"node_outputs": {self.node_id: {"output": "custom"}}}
>>> NodeFactory.register_node_type("custom", CustomNode)
"""
cls._node_types[node_type] = node_class
logger.info(f"注册节点类型: {node_type} -> {node_class.__name__}")
@classmethod
def create_node(
cls,
node_config: dict[str, Any],
workflow_config: dict[str, Any]
) -> BaseNode | None:
"""创建节点实例
Args:
node_config: 节点配置
workflow_config: 工作流配置
Returns:
节点实例或 None对于不支持的节点类型
Raises:
ValueError: 不支持的节点类型
"""
node_type = node_config.get("type")
# 跳过条件节点(由 LangGraph 处理)
if node_type == "condition":
return None
# 获取节点类
node_class = cls._node_types.get(node_type)
if not node_class:
raise ValueError(f"不支持的节点类型: {node_type}")
# 创建节点实例
logger.debug(f"创建节点: {node_config.get('id')} (type={node_type})")
return node_class(node_config, workflow_config)
@classmethod
def get_supported_types(cls) -> list[str]:
"""获取支持的节点类型列表
Returns:
节点类型列表
"""
return list(cls._node_types.keys())

View File

@@ -0,0 +1,6 @@
"""Start 节点"""
from app.core.workflow.nodes.start.node import StartNode
from app.core.workflow.nodes.start.config import StartNodeConfig
__all__ = ["StartNode", "StartNodeConfig"]

View File

@@ -0,0 +1,87 @@
"""Start 节点配置"""
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
class StartNodeConfig(BaseNodeConfig):
"""Start 节点配置
Start 节点的作用:
1. 标记工作流的起点
2. 定义自定义输入变量(会作为节点输出,通过 start_node_id.variable_name 访问)
3. 输出系统变量和会话变量
"""
# 自定义输入变量定义
variables: list[VariableDefinition] = Field(
default_factory=list,
description="自定义输入变量列表,这些变量会作为 Start 节点的输出"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="message",
type=VariableType.STRING,
description="用户输入的消息"
),
VariableDefinition(
name="conversation_vars",
type=VariableType.OBJECT,
description="会话级变量"
),
VariableDefinition(
name="execution_id",
type=VariableType.STRING,
description="执行 ID"
),
VariableDefinition(
name="conversation_id",
type=VariableType.STRING,
description="会话 ID"
),
VariableDefinition(
name="workspace_id",
type=VariableType.STRING,
description="工作空间 ID"
),
VariableDefinition(
name="user_id",
type=VariableType.STRING,
description="用户 ID"
)
],
description="输出变量定义(自动生成,通常不需要修改)"
)
class Config:
json_schema_extra = {
"examples": [
{
"description": "工作流开始节点",
"variables": []
},
{
"description": "带自定义变量的开始节点",
"variables": [
{
"name": "language",
"type": "string",
"required": False,
"default": "zh-CN",
"description": "语言设置"
},
{
"name": "max_length",
"type": "number",
"required": False,
"default": 1000,
"description": "最大长度"
}
]
}
]
}

View File

@@ -0,0 +1,136 @@
"""
Start 节点实现
工作流的起始节点,定义输入变量并输出系统参数。
"""
import logging
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.start.config import StartNodeConfig
logger = logging.getLogger(__name__)
class StartNode(BaseNode):
"""Start 节点
工作流的起始节点,负责:
1. 定义工作流的输入变量(通过配置)
2. 输出系统变量sys.*
3. 输出会话变量conv.*
注意:变量的验证和默认值处理由 Executor 在初始化时完成。
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
"""初始化 Start 节点
Args:
node_config: 节点配置
workflow_config: 工作流配置
"""
super().__init__(node_config, workflow_config)
# 解析并验证配置
self.typed_config = StartNodeConfig(**self.config)
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""执行 start 节点业务逻辑
Start 节点输出系统变量、会话变量和自定义变量。
Args:
state: 工作流状态
Returns:
包含系统参数、会话变量和自定义变量的字典
"""
logger.info(f"节点 {self.node_id} (Start) 开始执行")
# 创建变量池实例(在方法内复用)
pool = self.get_variable_pool(state)
# 处理自定义变量(传入 pool 避免重复创建)
custom_vars = self._process_custom_variables(pool)
# 返回业务数据(包含自定义变量)
result = {
"message": pool.get("sys.message"),
"execution_id": pool.get("sys.execution_id"),
"conversation_id": pool.get("sys.conversation_id"),
"workspace_id": pool.get("sys.workspace_id"),
"user_id": pool.get("sys.user_id"),
**custom_vars # 自定义变量作为节点输出的一部分
}
logger.info(
f"节点 {self.node_id} (Start) 执行完成,"
f"输出了 {len(custom_vars)} 个自定义变量"
)
return result
def _process_custom_variables(self, pool) -> dict[str, Any]:
"""处理自定义变量
从输入数据中提取自定义变量,应用默认值和验证。
Args:
pool: 变量池实例
Returns:
处理后的自定义变量字典
Raises:
ValueError: 缺少必需变量
"""
# 获取输入数据中的自定义变量
input_variables = pool.get("sys.input_variables", default={})
processed = {}
# 遍历配置的变量定义
for var_def in self.typed_config.variables:
var_name = var_def.name
# 检查变量是否存在
if var_name in input_variables:
# 使用用户提供的值
processed[var_name] = input_variables[var_name]
elif var_def.required:
# 必需变量缺失
raise ValueError(
f"缺少必需的输入变量: {var_name}"
+ (f" ({var_def.description})" if var_def.description else "")
)
elif var_def.default is not None:
# 使用默认值
processed[var_name] = var_def.default
logger.debug(
f"变量 '{var_name}' 使用默认值: {var_def.default}"
)
return processed
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
"""提取输入数据(用于记录)
Args:
state: 工作流状态
Returns:
输入数据字典
"""
pool = self.get_variable_pool(state)
return {
"execution_id": pool.get("sys.execution_id"),
"conversation_id": pool.get("sys.conversation_id"),
"message": pool.get("sys.message"),
"conversation_vars": pool.get_all_conversation_vars()
}

View File

@@ -0,0 +1,6 @@
"""Transform 节点"""
from app.core.workflow.nodes.transform.node import TransformNode
from app.core.workflow.nodes.transform.config import TransformNodeConfig
__all__ = ["TransformNode", "TransformNodeConfig"]

View File

@@ -0,0 +1,80 @@
"""Transform 节点配置"""
from typing import Literal
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
class TransformNodeConfig(BaseNodeConfig):
"""Transform 节点配置
用于数据转换和处理。
"""
transform_type: Literal["template", "code", "json"] = Field(
default="template",
description="转换类型template(模板), code(代码), json(JSON处理)"
)
# 模板模式
template: str | None = Field(
default=None,
description="转换模板,支持变量引用"
)
# 代码模式
code: str | None = Field(
default=None,
description="Python 代码,用于数据转换"
)
# JSON 模式
json_path: str | None = Field(
default=None,
description="JSON 路径表达式"
)
# 输入变量
inputs: dict[str, str] | None = Field(
default=None,
description="输入变量映射key 为变量名value 为变量选择器"
)
# 输出变量
output_key: str = Field(
default="result",
description="输出变量的键名"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="result",
type=VariableType.STRING,
description="转换后的结果"
)
],
description="输出变量定义(根据 output_key 动态生成)"
)
class Config:
json_schema_extra = {
"examples": [
{
"transform_type": "template",
"template": "用户问题:{{ sys.message }}\n回答:{{ llm_qa.output }}",
"output_key": "formatted_result"
},
{
"transform_type": "code",
"code": "result = input_text.upper()",
"inputs": {
"input_text": "{{ sys.message }}"
},
"output_key": "uppercase_text"
}
]
}

View File

@@ -0,0 +1,60 @@
"""
Transform 节点实现
数据转换节点,用于处理和转换数据。
"""
import logging
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
logger = logging.getLogger(__name__)
class TransformNode(BaseNode):
"""数据转换节点
配置示例:
{
"type": "transform",
"config": {
"mapping": {
"output_field": "{{node.previous.output}}",
"processed": "{{var.input | upper}}"
}
}
}
"""
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""执行数据转换
Args:
state: 工作流状态
Returns:
状态更新字典
"""
logger.info(f"节点 {self.node_id} 开始执行数据转换")
# 获取映射配置
mapping = self.config.get("mapping", {})
# 执行数据转换
transformed_data = {}
for target_key, source_template in mapping.items():
# 渲染模板获取值
value = self._render_template(str(source_template), state)
transformed_data[target_key] = value
logger.info(f"节点 {self.node_id} 数据转换完成,输出字段: {list(transformed_data.keys())}")
return {
"node_outputs": {
self.node_id: {
"output": transformed_data,
"status": "completed"
}
}
}

View File

@@ -0,0 +1,170 @@
"""
工作流模板加载器
从文件系统加载预定义的工作流模板
"""
import os
import yaml
from pathlib import Path
from typing import Optional
class TemplateLoader:
"""工作流模板加载器"""
def __init__(self, templates_dir: str = "app/templates/workflows"):
"""初始化模板加载器
Args:
templates_dir: 模板目录路径
"""
self.templates_dir = Path(templates_dir)
if not self.templates_dir.exists():
raise ValueError(f"模板目录不存在: {templates_dir}")
def list_templates(self) -> list[dict]:
"""列出所有可用的模板
Returns:
模板列表,每个模板包含 id, name, description 等信息
"""
templates = []
# 遍历模板目录
for template_dir in self.templates_dir.iterdir():
if not template_dir.is_dir():
continue
# 检查是否有 template.yml 文件
template_file = template_dir / "template.yml"
if not template_file.exists():
continue
try:
# 读取模板配置
with open(template_file, 'r', encoding='utf-8') as f:
template_data = yaml.safe_load(f)
# 提取模板信息
templates.append({
"id": template_dir.name,
"name": template_data.get("name", template_dir.name),
"description": template_data.get("description", ""),
"category": template_data.get("category", "general"),
"tags": template_data.get("tags", []),
"author": template_data.get("author", ""),
"version": template_data.get("version", "1.0.0")
})
except Exception as e:
print(f"加载模板 {template_dir.name} 失败: {e}")
continue
return templates
def load_template(self, template_id: str) -> Optional[dict]:
"""加载指定的模板
Args:
template_id: 模板 ID目录名
Returns:
模板配置字典,如果模板不存在则返回 None
"""
template_dir = self.templates_dir / template_id
template_file = template_dir / "template.yml"
if not template_file.exists():
return None
try:
with open(template_file, 'r', encoding='utf-8') as f:
template_data = yaml.safe_load(f)
# 返回工作流配置部分
return {
"name": template_data.get("name", template_id),
"description": template_data.get("description", ""),
"nodes": template_data.get("nodes", []),
"edges": template_data.get("edges", []),
"variables": template_data.get("variables", []),
"execution_config": template_data.get("execution_config", {}),
"triggers": template_data.get("triggers", [])
}
except Exception as e:
print(f"加载模板 {template_id} 失败: {e}")
return None
def get_template_readme(self, template_id: str) -> Optional[str]:
"""获取模板的 README 文档
Args:
template_id: 模板 ID
Returns:
README 内容,如果不存在则返回 None
"""
template_dir = self.templates_dir / template_id
readme_file = template_dir / "README.md"
if not readme_file.exists():
return None
try:
with open(readme_file, 'r', encoding='utf-8') as f:
return f.read()
except Exception as e:
print(f"读取模板 {template_id} 的 README 失败: {e}")
return None
# 全局模板加载器实例
_template_loader: Optional[TemplateLoader] = None
def get_template_loader() -> TemplateLoader:
"""获取全局模板加载器实例
Returns:
TemplateLoader 实例
"""
global _template_loader
if _template_loader is None:
_template_loader = TemplateLoader()
return _template_loader
def list_workflow_templates() -> list[dict]:
"""列出所有工作流模板
Returns:
模板列表
"""
loader = get_template_loader()
return loader.list_templates()
def load_workflow_template(template_id: str) -> Optional[dict]:
"""加载工作流模板
Args:
template_id: 模板 ID
Returns:
模板配置,如果不存在则返回 None
"""
loader = get_template_loader()
return loader.load_template(template_id)
def get_workflow_template_readme(template_id: str) -> Optional[str]:
"""获取工作流模板的 README
Args:
template_id: 模板 ID
Returns:
README 内容,如果不存在则返回 None
"""
loader = get_template_loader()
return loader.get_template_readme(template_id)

View File

@@ -0,0 +1,170 @@
"""
模板渲染器
使用 Jinja2 提供安全的模板渲染功能,支持变量引用和表达式。
"""
import logging
from typing import Any
from jinja2 import Template, TemplateSyntaxError, UndefinedError, Environment, StrictUndefined
logger = logging.getLogger(__name__)
class TemplateRenderer:
"""模板渲染器"""
def __init__(self, strict: bool = True):
"""初始化渲染器
Args:
strict: 是否使用严格模式(未定义变量会抛出异常)
"""
self.env = Environment(
undefined=StrictUndefined if strict else None,
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
)
def render(
self,
template: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> str:
"""渲染模板
Args:
template: 模板字符串
variables: 用户定义的变量
node_outputs: 节点输出结果
system_vars: 系统变量
Returns:
渲染后的字符串
Raises:
ValueError: 模板语法错误或变量未定义
Examples:
>>> renderer = TemplateRenderer()
>>> renderer.render(
... "Hello {{var.name}}!",
... {"name": "World"},
... {},
... {}
... )
'Hello World!'
>>> renderer.render(
... "分析结果: {{node.analyze.output}}",
... {},
... {"analyze": {"output": "正面情绪"}},
... {}
... )
'分析结果: 正面情绪'
"""
# 构建命名空间上下文
context = {
"var": variables, # 用户变量:{{var.user_input}}
"node": node_outputs, # 节点输出:{{node.node_1.output}}
"sys": system_vars or {}, # 系统变量:{{sys.execution_id}}
}
# 支持直接通过节点ID访问节点输出{{llm_qa.output}}
# 将所有节点输出添加到顶层上下文
context.update(node_outputs)
# 为了向后兼容,也支持直接访问用户变量
context.update(variables)
context["nodes"] = node_outputs # 旧语法兼容
try:
tmpl = self.env.from_string(template)
return tmpl.render(**context)
except TemplateSyntaxError as e:
logger.error(f"模板语法错误: {template}, 错误: {e}")
raise ValueError(f"模板语法错误: {e}")
except UndefinedError as e:
logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}")
raise ValueError(f"未定义的变量: {e}")
except Exception as e:
logger.error(f"模板渲染异常: {template}, 错误: {e}")
raise ValueError(f"模板渲染失败: {e}")
def validate(self, template: str) -> list[str]:
"""验证模板语法
Args:
template: 模板字符串
Returns:
错误列表,如果为空则验证通过
Examples:
>>> renderer = TemplateRenderer()
>>> renderer.validate("Hello {{var.name}}!")
[]
>>> renderer.validate("Hello {{var.name") # 缺少结束标记
['模板语法错误: ...']
"""
errors = []
try:
self.env.from_string(template)
except TemplateSyntaxError as e:
errors.append(f"模板语法错误: {e}")
except Exception as e:
errors.append(f"模板验证失败: {e}")
return errors
# 全局渲染器实例(严格模式)
_default_renderer = TemplateRenderer(strict=True)
def render_template(
template: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> str:
"""渲染模板(便捷函数)
Args:
template: 模板字符串
variables: 用户变量
node_outputs: 节点输出
system_vars: 系统变量
Returns:
渲染后的字符串
Examples:
>>> render_template(
... "请分析: {{var.text}}",
... {"text": "这是一段文本"},
... {},
... {}
... )
'请分析: 这是一段文本'
"""
return _default_renderer.render(template, variables, node_outputs, system_vars)
def validate_template(template: str) -> list[str]:
"""验证模板语法(便捷函数)
Args:
template: 模板字符串
Returns:
错误列表
"""
return _default_renderer.validate(template)

View File

@@ -0,0 +1,277 @@
"""
工作流配置验证器
验证工作流配置的有效性,确保配置符合规范。
"""
import logging
from typing import Any, Union
logger = logging.getLogger(__name__)
class WorkflowValidator:
"""工作流配置验证器"""
@staticmethod
def validate(workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
"""验证工作流配置
Args:
workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型
Returns:
(is_valid, errors): 是否有效和错误列表
Examples:
>>> config = {
... "nodes": [
... {"id": "start", "type": "start"},
... {"id": "end", "type": "end"}
... ],
... "edges": [
... {"source": "start", "target": "end"}
... ]
... }
>>> is_valid, errors = WorkflowValidator.validate(config)
>>> is_valid
True
"""
errors = []
# 支持字典和 Pydantic 模型
if isinstance(workflow_config, dict):
nodes = workflow_config.get("nodes", [])
edges = workflow_config.get("edges", [])
variables = workflow_config.get("variables", [])
else:
# Pydantic 模型
nodes = getattr(workflow_config, "nodes", [])
edges = getattr(workflow_config, "edges", [])
variables = getattr(workflow_config, "variables", [])
# 1. 验证 start 节点(有且只有一个)
start_nodes = [n for n in nodes if n.get("type") == "start"]
if len(start_nodes) == 0:
errors.append("工作流必须有一个 start 节点")
elif len(start_nodes) > 1:
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)}")
# 2. 验证 end 节点(至少一个)
end_nodes = [n for n in nodes if n.get("type") == "end"]
if len(end_nodes) == 0:
errors.append("工作流必须至少有一个 end 节点")
# 3. 验证节点 ID 唯一性
node_ids = [n.get("id") for n in nodes]
if len(node_ids) != len(set(node_ids)):
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")
# 4. 验证节点必须有 id 和 type
for i, node in enumerate(nodes):
if not node.get("id"):
errors.append(f"节点 #{i} 缺少 id 字段")
if not node.get("type"):
errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段")
# 5. 验证边的有效性
node_id_set = set(node_ids)
for i, edge in enumerate(edges):
source = edge.get("source")
target = edge.get("target")
if not source:
errors.append(f"边 #{i} 缺少 source 字段")
elif source not in node_id_set:
errors.append(f"边 #{i} 的 source 节点不存在: {source}")
if not target:
errors.append(f"边 #{i} 缺少 target 字段")
elif target not in node_id_set:
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
# 6. 验证所有节点可达(从 start 节点出发)
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
reachable = WorkflowValidator._get_reachable_nodes(
start_nodes[0]["id"],
edges
)
unreachable = node_id_set - reachable
if unreachable:
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
# 7. 检测循环依赖(非 loop 节点)
if not errors: # 只有在前面验证通过时才检查循环
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
if has_cycle:
errors.append(
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
)
# 8. 验证变量名
from app.core.workflow.expression_evaluator import ExpressionEvaluator
var_errors = ExpressionEvaluator.validate_variable_names(variables)
errors.extend(var_errors)
return len(errors) == 0, errors
@staticmethod
def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
"""获取从 start 节点可达的所有节点
Args:
start_id: 起始节点 ID
edges: 边列表
Returns:
可达节点 ID 集合
"""
reachable = {start_id}
queue = [start_id]
while queue:
current = queue.pop(0)
for edge in edges:
if edge.get("source") == current:
target = edge.get("target")
if target and target not in reachable:
reachable.add(target)
queue.append(target)
return reachable
@staticmethod
def _has_cycle(nodes: list[dict], edges: list[dict]) -> tuple[bool, list[str]]:
"""检测是否存在循环依赖DFS
Args:
nodes: 节点列表
edges: 边列表
Returns:
(has_cycle, cycle_path): 是否有循环和循环路径
"""
# 排除 loop 类型的节点
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
# 构建邻接表(排除 loop 节点的边和错误边)
graph: dict[str, list[str]] = {}
for edge in edges:
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
# 跳过错误边
if edge_type == "error":
continue
# 如果涉及 loop 节点,跳过
if source in loop_nodes or target in loop_nodes:
continue
if source and target:
if source not in graph:
graph[source] = []
graph[source].append(target)
# DFS 检测环
visited = set()
rec_stack = set()
path = []
cycle_path = []
def dfs(node: str) -> bool:
"""DFS 检测环,返回是否找到环"""
visited.add(node)
rec_stack.add(node)
path.append(node)
for neighbor in graph.get(node, []):
if neighbor not in visited:
if dfs(neighbor):
return True
elif neighbor in rec_stack:
# 找到环,记录环路径
cycle_start = path.index(neighbor)
cycle_path.extend([*path[cycle_start:], neighbor])
return True
rec_stack.remove(node)
path.pop()
return False
# 检查所有节点
for node_id in graph:
if node_id not in visited:
if dfs(node_id):
return True, cycle_path
return False, []
@staticmethod
def validate_for_publish(workflow_config: dict[str, Any]) -> tuple[bool, list[str]]:
"""验证工作流配置是否可以发布(更严格的验证)
Args:
workflow_config: 工作流配置
Returns:
(is_valid, errors): 是否有效和错误列表
"""
# 先执行基础验证
is_valid, errors = WorkflowValidator.validate(workflow_config)
if not is_valid:
return False, errors
# 额外的发布验证
nodes = workflow_config.get("nodes", [])
# 1. 验证所有节点都有名称
for node in nodes:
if node.get("type") not in ["start", "end"] and not node.get("name"):
errors.append(
f"节点 {node.get('id')} 缺少名称(发布时必须提供)"
)
# 2. 验证所有非 start/end 节点都有配置
for node in nodes:
node_type = node.get("type")
if node_type not in ["start", "end"]:
config = node.get("config")
if not config or not isinstance(config, dict):
errors.append(
f"节点 {node.get('id')} 缺少配置(发布时必须提供)"
)
# 3. 验证必填变量
variables = workflow_config.get("variables", [])
required_vars = [v for v in variables if v.get("required")]
if required_vars:
# 这里只是提示,实际执行时会检查
logger.info(
f"工作流包含 {len(required_vars)} 个必填变量: "
f"{[v.get('name') for v in required_vars]}"
)
return len(errors) == 0, errors
def validate_workflow_config(
workflow_config: dict[str, Any],
for_publish: bool = False
) -> tuple[bool, list[str]]:
"""验证工作流配置(便捷函数)
Args:
workflow_config: 工作流配置
for_publish: 是否为发布验证(更严格)
Returns:
(is_valid, errors): 是否有效和错误列表
"""
if for_publish:
return WorkflowValidator.validate_for_publish(workflow_config)
else:
return WorkflowValidator.validate(workflow_config)

View File

@@ -0,0 +1,293 @@
"""
变量池 (Variable Pool)
工作流执行的数据中心,管理所有变量的存储和访问。
变量类型:
1. 系统变量 (sys.*) - 系统内置变量execution_id, workspace_id, user_id, message 等)
2. 节点输出 (node_id.*) - 节点执行结果
3. 会话变量 (conv.*) - 会话级变量(跨多轮对话保持)
"""
import logging
from typing import Any
logger = logging.getLogger(__name__)
class VariableSelector:
"""变量选择器
用于引用变量的路径表示。
Examples:
>>> selector = VariableSelector(["sys", "message"])
>>> selector = VariableSelector(["node_A", "output"])
>>> selector = VariableSelector.from_string("sys.message")
"""
def __init__(self, path: list[str]):
"""初始化变量选择器
Args:
path: 变量路径,如 ["sys", "message"] 或 ["node_A", "output"]
"""
if not path or len(path) < 1:
raise ValueError("变量路径不能为空")
self.path = path
self.namespace = path[0] # sys, var, 或 node_id
self.key = path[1] if len(path) > 1 else None
@classmethod
def from_string(cls, selector_str: str) -> "VariableSelector":
"""从字符串创建选择器
Args:
selector_str: 选择器字符串,如 "sys.message""node_A.output"
Returns:
VariableSelector 实例
Examples:
>>> selector = VariableSelector.from_string("sys.message")
>>> selector = VariableSelector.from_string("llm_qa.output")
"""
path = selector_str.split(".")
return cls(path)
def __str__(self) -> str:
return ".".join(self.path)
def __repr__(self) -> str:
return f"VariableSelector({self.path})"
class VariablePool:
"""变量池
管理工作流执行过程中的所有变量。
变量命名空间:
- sys.*: 系统变量message, execution_id, workspace_id, user_id, conversation_id
- conv.*: 会话变量(跨多轮对话保持的变量)
- <node_id>.*: 节点输出
Examples:
>>> pool = VariablePool(state)
>>> pool.get(["sys", "message"])
"用户的问题"
>>> pool.get(["llm_qa", "output"])
"AI 的回答"
>>> pool.set(["conv", "user_name"], "张三")
"""
def __init__(self, state: dict[str, Any]):
"""初始化变量池
Args:
state: 工作流状态LangGraph State
"""
self.state = state
def get(self, selector: list[str] | str, default: Any = None) -> Any:
"""获取变量值
Args:
selector: 变量选择器,可以是列表或字符串
default: 默认值(变量不存在时返回)
Returns:
变量值
Examples:
>>> pool.get(["sys", "message"])
>>> pool.get("sys.message")
>>> pool.get(["llm_qa", "output"])
>>> pool.get("llm_qa.output")
Raises:
KeyError: 变量不存在且未提供默认值
"""
# 转换为 VariableSelector
if isinstance(selector, str):
selector = VariableSelector.from_string(selector).path
if not selector or len(selector) < 1:
raise ValueError("变量选择器不能为空")
namespace = selector[0]
try:
# 系统变量
if namespace == "sys":
key = selector[1] if len(selector) > 1 else None
if not key:
return self.state.get("variables", {}).get("sys", {})
return self.state.get("variables", {}).get("sys", {}).get(key, default)
# 会话变量
elif namespace == "conv":
key = selector[1] if len(selector) > 1 else None
if not key:
return self.state.get("variables", {}).get("conv", {})
return self.state.get("variables", {}).get("conv", {}).get(key, default)
# 节点输出(从 runtime_vars 读取)
else:
node_id = namespace
runtime_vars = self.state.get("runtime_vars", {})
if node_id not in runtime_vars:
if default is not None:
return default
raise KeyError(f"节点 '{node_id}' 的输出不存在")
node_var = runtime_vars[node_id]
# 如果只有节点 ID返回整个变量
if len(selector) == 1:
return node_var
# 获取特定字段
# 支持嵌套访问,如 node_id.field.subfield
result = node_var
for k in selector[1:]:
if isinstance(result, dict):
result = result.get(k)
if result is None:
if default is not None:
return default
raise KeyError(f"字段 '{'.'.join(selector)}' 不存在")
else:
if default is not None:
return default
raise KeyError(f"无法访问 '{'.'.join(selector)}'")
return result
except KeyError:
if default is not None:
return default
raise
def set(self, selector: list[str] | str, value: Any):
"""设置变量值
Args:
selector: 变量选择器
value: 变量值
Examples:
>>> pool.set(["conv", "user_name"], "张三")
>>> pool.set("conv.user_name", "张三")
Note:
- 只能设置会话变量 (conv.*)
- 系统变量和节点输出是只读的
"""
# 转换为 VariableSelector
if isinstance(selector, str):
selector = VariableSelector.from_string(selector).path
if not selector or len(selector) < 2:
raise ValueError("变量选择器必须包含命名空间和键名")
namespace = selector[0]
if namespace != "conv":
raise ValueError("只能设置会话变量 (conv.*)")
key = selector[1]
# 确保 variables 结构存在
if "variables" not in self.state:
self.state["variables"] = {"sys": {}, "conv": {}}
if "conv" not in self.state["variables"]:
self.state["variables"]["conv"] = {}
# 设置值
self.state["variables"]["conv"][key] = value
logger.debug(f"设置变量: {'.'.join(selector)} = {value}")
def has(self, selector: list[str] | str) -> bool:
"""检查变量是否存在
Args:
selector: 变量选择器
Returns:
变量是否存在
Examples:
>>> pool.has(["sys", "message"])
True
>>> pool.has("llm_qa.output")
False
"""
try:
self.get(selector)
return True
except KeyError:
return False
def get_all_system_vars(self) -> dict[str, Any]:
"""获取所有系统变量
Returns:
系统变量字典
"""
return self.state.get("variables", {}).get("sys", {})
def get_all_conversation_vars(self) -> dict[str, Any]:
"""获取所有会话变量
Returns:
会话变量字典
"""
return self.state.get("variables", {}).get("conv", {})
def get_all_node_outputs(self) -> dict[str, Any]:
"""获取所有节点输出(运行时变量)
Returns:
节点输出字典,键为节点 ID
"""
return self.state.get("runtime_vars", {})
def get_node_output(self, node_id: str) -> dict[str, Any] | None:
"""获取指定节点的输出(运行时变量)
Args:
node_id: 节点 ID
Returns:
节点输出或 None
"""
return self.state.get("runtime_vars", {}).get(node_id)
def to_dict(self) -> dict[str, Any]:
"""导出为字典
Returns:
包含所有变量的字典
"""
return {
"system": self.get_all_system_vars(),
"conversation": self.get_all_conversation_vars(),
"nodes": self.get_all_node_outputs() # 从 runtime_vars 读取
}
def __repr__(self) -> str:
sys_vars = self.get_all_system_vars()
conv_vars = self.get_all_conversation_vars()
runtime_vars = self.get_all_node_outputs()
return (
f"VariablePool(\n"
f" system_vars={len(sys_vars)},\n"
f" conversation_vars={len(conv_vars)},\n"
f" runtime_vars={len(runtime_vars)}\n"
f")"
)