[ADD] Merge code
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from functools import wraps
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import Request, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.api_key_utils import add_rate_limit_headers
|
||||
@@ -22,21 +24,17 @@ logger = get_api_logger()
|
||||
|
||||
|
||||
def require_api_key(
|
||||
scopes: Optional[List[str]] = None,
|
||||
resource_type: Optional[str] = None
|
||||
scopes: Optional[List[str]] = None
|
||||
):
|
||||
"""
|
||||
API Key 鉴权装饰器
|
||||
|
||||
Args:
|
||||
scopes: 所需的权限范围列表["app:all",
|
||||
"rag:search", "rag:upload", "rag:delete",
|
||||
"memory:read", "memory:write", "memory:delete", "memory:search"]
|
||||
resource_type: 所需的资源类型("Agent", "Cluster", "Workflow", "Knowledge", "Memory_Engine")
|
||||
scopes: 所需的权限范围列表[“app”, "rag", "memory"]
|
||||
|
||||
Usage:
|
||||
@router.get("/app/{resource_id}/chat")
|
||||
@require_api_key(scopes=["app:all"], resource_type="Agent")
|
||||
@require_api_key(scopes=["app"])
|
||||
def chat_with_app(
|
||||
resource_id: uuid.UUID,
|
||||
api_key_auth: ApiKeyAuth = Depends(),
|
||||
@@ -113,31 +111,25 @@ def require_api_key(
|
||||
context={"required_scopes": scopes, "missing_scopes": missing_scopes}
|
||||
)
|
||||
|
||||
if resource_type:
|
||||
resource_id = kwargs.get("resource_id")
|
||||
if resource_id and not ApiKeyAuthService.check_resource(
|
||||
api_key_obj,
|
||||
resource_type,
|
||||
resource_id
|
||||
):
|
||||
logger.warning("API Key 资源访问被拒绝", extra={
|
||||
"api_key_id": str(api_key_obj.id),
|
||||
"required_resource_type": resource_type,
|
||||
resource_id = kwargs.get("resource_id")
|
||||
if resource_id and not ApiKeyAuthService.check_resource(
|
||||
api_key_obj,
|
||||
resource_id
|
||||
):
|
||||
logger.warning("API Key 资源访问被拒绝", extra={
|
||||
"api_key_id": str(api_key_obj.id),
|
||||
"required_resource_id": str(resource_id),
|
||||
"bound_resource_id": str(api_key_obj.resource_id) if api_key_obj.resource_id else None,
|
||||
"endpoint": str(request.url)
|
||||
})
|
||||
return BusinessException(
|
||||
"API Key 未授权访问该资源",
|
||||
BizCode.API_KEY_INVALID_RESOURCE,
|
||||
context={
|
||||
"required_resource_id": str(resource_id),
|
||||
"bound_resource_type": api_key_obj.resource_type,
|
||||
"bound_resource_id": str(api_key_obj.resource_id) if api_key_obj.resource_id else None,
|
||||
"endpoint": str(request.url)
|
||||
})
|
||||
return BusinessException(
|
||||
"API Key 未授权访问该资源",
|
||||
BizCode.API_KEY_INVALID_RESOURCE,
|
||||
context={
|
||||
"required_resource_type": resource_type,
|
||||
"required_resource_id": str(resource_id),
|
||||
"bound_resource_type": api_key_obj.resource_type,
|
||||
"bound_resource_id": str(api_key_obj.resource_id) if api_key_obj.resource_id else None
|
||||
}
|
||||
)
|
||||
"bound_resource_id": str(api_key_obj.resource_id)
|
||||
}
|
||||
)
|
||||
|
||||
kwargs["api_key_auth"] = ApiKeyAuth(
|
||||
api_key_id=api_key_obj.id,
|
||||
@@ -145,14 +137,17 @@ def require_api_key(
|
||||
type=api_key_obj.type,
|
||||
scopes=api_key_obj.scopes,
|
||||
resource_id=api_key_obj.resource_id,
|
||||
resource_type=api_key_obj.resource_type
|
||||
)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
response = await func(*args, **kwargs)
|
||||
end_time = time.perf_counter()
|
||||
response_time = (end_time - start_time) * 1000
|
||||
if not isinstance(response, Response):
|
||||
response = JSONResponse(content=response)
|
||||
response = add_rate_limit_headers(response, rate_headers)
|
||||
|
||||
asyncio.create_task(log_api_key_usage(
|
||||
db, api_key_obj.id, request, response
|
||||
db, api_key_obj.id, request, response, response_time
|
||||
))
|
||||
return response
|
||||
|
||||
@@ -204,7 +199,8 @@ async def log_api_key_usage(
|
||||
db: Session,
|
||||
api_key_id: uuid.UUID,
|
||||
request: Request,
|
||||
response: Response
|
||||
response: Response,
|
||||
response_time: float
|
||||
):
|
||||
"""记录 API Key 使用日志"""
|
||||
try:
|
||||
@@ -216,8 +212,8 @@ async def log_api_key_usage(
|
||||
"ip_address": request.client.host if request.client else None,
|
||||
"user_agent": request.headers.get("User-Agent"),
|
||||
"status_code": response.status_code if hasattr(response, "status_code") else None,
|
||||
"response_time": None, # 需要在 middleware 中计算
|
||||
"tokens_used": None, # 需要从响应中提取
|
||||
"response_time": round(response_time),
|
||||
"tokens_used": None,
|
||||
"created_at": datetime.now()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,33 +1,14 @@
|
||||
"""API Key 工具函数"""
|
||||
import secrets
|
||||
import hashlib
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
from datetime import datetime
|
||||
|
||||
from app.schemas.api_key_schema import ApiKeyType
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
|
||||
class ResourceType:
    """Names of the resource kinds referenced by API key bindings."""

    AGENT = "Agent"
    CLUSTER = "Cluster"
    WORKFLOW = "Workflow"
    KNOWLEDGE = "Knowledge"
    MEMORY_ENGINE = "Memory_Engine"

    @classmethod
    def get_all_types(cls) -> list[str]:
        """Return every supported resource type name, in declaration order."""
        attribute_names = ("AGENT", "CLUSTER", "WORKFLOW", "KNOWLEDGE", "MEMORY_ENGINE")
        return [getattr(cls, attribute) for attribute in attribute_names]

    @classmethod
    def is_valid_type(cls, resource_type: str) -> bool:
        """Tell whether *resource_type* equals one of the supported names."""
        return any(known == resource_type for known in cls.get_all_types())
|
||||
|
||||
|
||||
def generate_api_key(key_type: ApiKeyType) -> tuple[str, str, str]:
|
||||
def generate_api_key(key_type: ApiKeyType) -> str:
|
||||
"""
|
||||
生成 API Key
|
||||
|
||||
@@ -39,102 +20,17 @@ def generate_api_key(key_type: ApiKeyType) -> tuple[str, str, str]:
|
||||
"""
|
||||
# 前缀映射
|
||||
prefix_map = {
|
||||
ApiKeyType.APP: "sk-app-",
|
||||
ApiKeyType.RAG: "sk-rag-",
|
||||
ApiKeyType.MEMORY: "sk-mem-",
|
||||
ApiKeyType.AGENT: "sk-agent-",
|
||||
ApiKeyType.CLUSTER: "sk-cluster-",
|
||||
ApiKeyType.WORKFLOW: "sk-workflow-",
|
||||
ApiKeyType.SERVICE: "sk-service-"
|
||||
}
|
||||
|
||||
prefix = prefix_map[key_type]
|
||||
random_string = secrets.token_urlsafe(32)[:32] # 32 字符
|
||||
api_key = f"{prefix}{random_string}"
|
||||
|
||||
# 生成哈希值存储
|
||||
key_hash = hash_api_key(api_key)
|
||||
|
||||
return api_key, key_hash, prefix
|
||||
|
||||
|
||||
def hash_api_key(api_key: str) -> str:
    """Hash a plaintext API key.

    Args:
        api_key: The plaintext API key.

    Returns:
        str: Hexadecimal SHA-256 digest of the encoded key.
    """
    hasher = hashlib.sha256()
    hasher.update(api_key.encode())
    return hasher.hexdigest()
|
||||
|
||||
|
||||
def verify_api_key(api_key: str, key_hash: str) -> bool:
    """Check a plaintext API key against a stored hash.

    Args:
        api_key: The plaintext API key.
        key_hash: The stored hash value.

    Returns:
        bool: True when the hash of *api_key* matches *key_hash*.
    """
    # compare_digest is used instead of == for the comparison.
    return secrets.compare_digest(hash_api_key(api_key), key_hash)
|
||||
|
||||
|
||||
def validate_resource_binding(
    resource_type: Optional[str],
    resource_id: Optional[str]
) -> tuple[bool, str]:
    """Validate a resource binding.

    Args:
        resource_type: Resource type name.
        resource_id: Resource ID.

    Returns:
        tuple: (is_valid, error_message); the message is empty when valid.
    """
    type_given = bool(resource_type)
    id_given = bool(resource_id)

    # Exactly one of the two supplied: a binding needs both or neither.
    if type_given != id_given:
        return False, "resource_type 和 resource_id 必须同时提供或同时为空"

    # Neither supplied: no binding at all, which is allowed.
    if not type_given:
        return True, ""

    # Both supplied: the type must be a supported one.
    if ResourceType.is_valid_type(resource_type):
        return True, ""

    valid_types = ", ".join(ResourceType.get_all_types())
    return False, f"不支持的资源类型 '{resource_type}',支持的类型:{valid_types}"
|
||||
|
||||
|
||||
def get_resource_scope_mapping() -> dict[str, list[str]]:
    """Return the mapping from resource type to permission scopes.

    Returns:
        dict: Resource type name -> list of recommended scope strings.
    """
    mapping: dict[str, list[str]] = {}

    # The three application-like resource types share a single scope.
    for app_like in (ResourceType.AGENT, ResourceType.CLUSTER, ResourceType.WORKFLOW):
        mapping[app_like] = ["app:all"]

    mapping[ResourceType.KNOWLEDGE] = ["rag:search", "rag:upload", "rag:delete"]
    mapping[ResourceType.MEMORY_ENGINE] = [
        "memory:read", "memory:write", "memory:delete", "memory:search"
    ]
    return mapping
|
||||
return api_key
|
||||
|
||||
|
||||
def add_rate_limit_headers(response, headers: dict):
|
||||
@@ -151,3 +47,21 @@ def add_rate_limit_headers(response, headers: dict):
|
||||
return response
|
||||
|
||||
|
||||
def timestamp_to_datetime(timestamp: Optional[Union[int, float]]) -> Optional[datetime]:
    """Convert a timestamp to a datetime object.

    Values greater than 1e10 are treated as milliseconds and divided by
    1000 before conversion; ``None`` is passed through.
    """
    if timestamp is None:
        return None

    seconds = timestamp / 1000 if timestamp > 1e10 else timestamp
    return datetime.fromtimestamp(seconds)
|
||||
|
||||
|
||||
def datetime_to_timestamp(dt: Optional[datetime]) -> Optional[int]:
    """Convert a datetime object to a timestamp in milliseconds.

    ``None`` is passed through; the millisecond value is truncated to int.
    """
    return None if dt is None else int(dt.timestamp() * 1000)
|
||||
|
||||
@@ -59,6 +59,7 @@ class BizCode(IntEnum):
|
||||
EMBED_NOT_ALLOWED = 6009
|
||||
PERMISSION_DENIED = 6010
|
||||
INVALID_CONVERSATION = 6011
|
||||
CONFIG_MISSING = 6012
|
||||
|
||||
# 模型(7xxx)
|
||||
MODEL_CONFIG_INVALID = 7001
|
||||
@@ -96,7 +97,7 @@ HTTP_MAPPING = {
|
||||
BizCode.TOKEN_INVALID: 401,
|
||||
BizCode.TOKEN_EXPIRED: 401,
|
||||
BizCode.TOKEN_BLACKLISTED: 401,
|
||||
BizCode.FORBIDDEN: 403,
|
||||
BizCode.FORBIDDEN: 403,
|
||||
BizCode.TENANT_NOT_FOUND: 404,
|
||||
BizCode.WORKSPACE_NO_ACCESS: 403,
|
||||
BizCode.NOT_FOUND: 404,
|
||||
@@ -151,4 +152,4 @@ HTTP_MAPPING = {
|
||||
BizCode.DB_ERROR: 500,
|
||||
BizCode.SERVICE_UNAVAILABLE: 503,
|
||||
BizCode.RATE_LIMITED: 429,
|
||||
}
|
||||
}
|
||||
|
||||
436
api/app/core/workflow/executor.py
Normal file
436
api/app/core/workflow/executor.py
Normal file
@@ -0,0 +1,436 @@
|
||||
"""
|
||||
工作流执行器
|
||||
|
||||
基于 LangGraph 的工作流执行引擎。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
from app.models.workflow_model import WorkflowExecution, WorkflowNodeExecution
|
||||
from app.db import get_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowExecutor:
|
||||
"""工作流执行器
|
||||
|
||||
负责将工作流配置转换为 LangGraph 并执行。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workflow_config: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
):
|
||||
"""初始化执行器
|
||||
|
||||
Args:
|
||||
workflow_config: 工作流配置
|
||||
execution_id: 执行 ID
|
||||
workspace_id: 工作空间 ID
|
||||
user_id: 用户 ID
|
||||
"""
|
||||
self.workflow_config = workflow_config
|
||||
self.execution_id = execution_id
|
||||
self.workspace_id = workspace_id
|
||||
self.user_id = user_id
|
||||
self.nodes = workflow_config.get("nodes", [])
|
||||
self.edges = workflow_config.get("edges", [])
|
||||
self.execution_config = workflow_config.get("execution_config", {})
|
||||
|
||||
def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState:
|
||||
"""准备初始状态(注入系统变量和会话变量)
|
||||
|
||||
变量命名空间:
|
||||
- sys.xxx - 系统变量(execution_id, workspace_id, user_id, message, input_variables 等)
|
||||
- conv.xxx - 会话变量(跨多轮对话保持)
|
||||
- node_id.xxx - 节点输出(执行时动态生成)
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
|
||||
Returns:
|
||||
初始化的工作流状态
|
||||
"""
|
||||
user_message = input_data.get("message") or ""
|
||||
conversation_vars = input_data.get("conversation_vars") or {}
|
||||
input_variables = input_data.get("variables") or {} # Start 节点的自定义变量
|
||||
|
||||
# 构建分层的变量结构
|
||||
variables = {
|
||||
"sys": {
|
||||
"message": user_message, # 用户消息
|
||||
"conversation_id": input_data.get("conversation_id"), # 会话 ID
|
||||
"execution_id": self.execution_id, # 执行 ID
|
||||
"workspace_id": self.workspace_id, # 工作空间 ID
|
||||
"user_id": self.user_id, # 用户 ID
|
||||
"input_variables": input_variables, # 自定义输入变量(给 Start 节点使用)
|
||||
},
|
||||
"conv": conversation_vars # 会话级变量(跨多轮对话保持)
|
||||
}
|
||||
|
||||
return {
|
||||
"messages": [HumanMessage(content=user_message)],
|
||||
"variables": variables,
|
||||
"node_outputs": {},
|
||||
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)
|
||||
"execution_id": self.execution_id,
|
||||
"workspace_id": self.workspace_id,
|
||||
"user_id": self.user_id,
|
||||
"error": None,
|
||||
"error_node": None
|
||||
}
|
||||
|
||||
|
||||
|
||||
def build_graph(self) -> StateGraph:
|
||||
"""构建 LangGraph
|
||||
|
||||
Returns:
|
||||
编译后的状态图
|
||||
"""
|
||||
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
|
||||
|
||||
# 1. 创建状态图
|
||||
workflow = StateGraph(WorkflowState)
|
||||
|
||||
# 2. 添加所有节点(包括 start 和 end)
|
||||
start_node_id = None
|
||||
end_node_ids = []
|
||||
|
||||
for node in self.nodes:
|
||||
node_type = node.get("type")
|
||||
node_id = node.get("id")
|
||||
|
||||
# 记录 start 和 end 节点 ID
|
||||
if node_type == "start":
|
||||
start_node_id = node_id
|
||||
elif node_type == "end":
|
||||
end_node_ids.append(node_id)
|
||||
|
||||
# 创建节点实例(现在 start 和 end 也会被创建)
|
||||
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
||||
if node_instance:
|
||||
# 包装节点的 run 方法
|
||||
# 使用函数工厂避免闭包问题
|
||||
def make_node_func(inst):
|
||||
async def node_func(state: WorkflowState):
|
||||
return await inst.run(state)
|
||||
return node_func
|
||||
|
||||
workflow.add_node(node_id, make_node_func(node_instance))
|
||||
logger.debug(f"添加节点: {node_id} (type={node_type})")
|
||||
|
||||
# 3. 添加边
|
||||
# 从 START 连接到 start 节点
|
||||
if start_node_id:
|
||||
workflow.add_edge(START, start_node_id)
|
||||
logger.debug(f"添加边: START -> {start_node_id}")
|
||||
|
||||
for edge in self.edges:
|
||||
source = edge.get("source")
|
||||
target = edge.get("target")
|
||||
edge_type = edge.get("type")
|
||||
condition = edge.get("condition")
|
||||
|
||||
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start)
|
||||
if source == start_node_id:
|
||||
# 但要连接 start 到下一个节点
|
||||
workflow.add_edge(source, target)
|
||||
logger.debug(f"添加边: {source} -> {target}")
|
||||
continue
|
||||
|
||||
# 处理到 end 节点的边
|
||||
if target in end_node_ids:
|
||||
# 连接到 end 节点
|
||||
workflow.add_edge(source, target)
|
||||
logger.debug(f"添加边: {source} -> {target}")
|
||||
continue
|
||||
|
||||
# 跳过错误边(在节点内部处理)
|
||||
if edge_type == "error":
|
||||
continue
|
||||
|
||||
if condition:
|
||||
# 条件边
|
||||
def router(state: WorkflowState, cond=condition, tgt=target):
|
||||
"""条件路由函数"""
|
||||
if evaluate_condition(
|
||||
cond,
|
||||
state.get("variables", {}),
|
||||
state.get("node_outputs", {}),
|
||||
{
|
||||
"execution_id": state.get("execution_id"),
|
||||
"workspace_id": state.get("workspace_id"),
|
||||
"user_id": state.get("user_id")
|
||||
}
|
||||
):
|
||||
return tgt
|
||||
return END # 条件不满足,结束
|
||||
|
||||
workflow.add_conditional_edges(source, router)
|
||||
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
|
||||
else:
|
||||
# 普通边
|
||||
workflow.add_edge(source, target)
|
||||
logger.debug(f"添加边: {source} -> {target}")
|
||||
|
||||
# 从 end 节点连接到 END
|
||||
for end_node_id in end_node_ids:
|
||||
workflow.add_edge(end_node_id, END)
|
||||
logger.debug(f"添加边: {end_node_id} -> END")
|
||||
|
||||
# 4. 编译图
|
||||
graph = workflow.compile()
|
||||
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
|
||||
|
||||
return graph
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
input_data: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""执行工作流(非流式)
|
||||
|
||||
Args:
|
||||
input_data: 输入数据,包含 message 和 variables
|
||||
|
||||
Returns:
|
||||
执行结果,包含 status, output, node_outputs, elapsed_time, token_usage
|
||||
"""
|
||||
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
|
||||
|
||||
# 记录开始时间
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
# 1. 构建图
|
||||
graph = self.build_graph()
|
||||
|
||||
# 2. 初始化状态(自动注入系统变量)
|
||||
initial_state = self._prepare_initial_state(input_data)
|
||||
|
||||
# 3. 执行工作流
|
||||
try:
|
||||
result = await graph.ainvoke(initial_state)
|
||||
|
||||
# 计算耗时
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
# 提取节点输出(现在包含 start 和 end 节点)
|
||||
node_outputs = result.get("node_outputs", {})
|
||||
|
||||
# 提取最终输出(从最后一个非 start/end 节点)
|
||||
final_output = self._extract_final_output(node_outputs)
|
||||
|
||||
# 聚合 token 使用情况
|
||||
token_usage = self._aggregate_token_usage(node_outputs)
|
||||
|
||||
# 提取 conversation_id(从 start 节点输出)
|
||||
conversation_id = None
|
||||
for node_id, node_output in node_outputs.items():
|
||||
if node_output.get("node_type") == "start":
|
||||
conversation_id = node_output.get("output", {}).get("conversation_id")
|
||||
break
|
||||
|
||||
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"output": final_output,
|
||||
"node_outputs": node_outputs,
|
||||
"messages": result.get("messages", []),
|
||||
"conversation_id": conversation_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": token_usage,
|
||||
"error": result.get("error")
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# 计算耗时(即使失败也记录)
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True)
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"output": None,
|
||||
"node_outputs": {},
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": None
|
||||
}
|
||||
|
||||
async def execute_stream(
|
||||
self,
|
||||
input_data: dict[str, Any]
|
||||
):
|
||||
"""执行工作流(流式)
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
|
||||
Yields:
|
||||
流式事件
|
||||
"""
|
||||
logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}")
|
||||
|
||||
# 1. 构建图
|
||||
graph = self.build_graph()
|
||||
|
||||
# 2. 初始化状态(自动注入系统变量)
|
||||
initial_state = self._prepare_initial_state(input_data)
|
||||
|
||||
# 3. 流式执行工作流
|
||||
try:
|
||||
# 使用 astream 获取节点级别的更新
|
||||
async for event in graph.astream(initial_state, stream_mode="updates"):
|
||||
for node_name, state_update in event.items():
|
||||
yield {
|
||||
"type": "node_complete",
|
||||
"node": node_name,
|
||||
"data": state_update,
|
||||
"execution_id": self.execution_id
|
||||
}
|
||||
|
||||
logger.info(f"工作流执行完成(流式): execution_id={self.execution_id}")
|
||||
|
||||
# 发送完成事件
|
||||
yield {
|
||||
"type": "workflow_complete",
|
||||
"execution_id": self.execution_id
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作流执行失败(流式): execution_id={self.execution_id}, error={e}", exc_info=True)
|
||||
yield {
|
||||
"type": "workflow_error",
|
||||
"execution_id": self.execution_id,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:
|
||||
"""从节点输出中提取最终输出
|
||||
|
||||
优先级:
|
||||
1. 最后一个执行的非 start/end 节点的 output
|
||||
2. 如果没有节点输出,返回 None
|
||||
|
||||
Args:
|
||||
node_outputs: 所有节点的输出
|
||||
|
||||
Returns:
|
||||
最终输出字符串或 None
|
||||
"""
|
||||
if not node_outputs:
|
||||
return None
|
||||
|
||||
# 获取最后一个节点的输出
|
||||
last_node_output = list(node_outputs.values())[-1] if node_outputs else None
|
||||
|
||||
if last_node_output and isinstance(last_node_output, dict):
|
||||
return last_node_output.get("output")
|
||||
|
||||
return None
|
||||
|
||||
def _aggregate_token_usage(self, node_outputs: dict[str, Any]) -> dict[str, int] | None:
|
||||
"""聚合所有节点的 token 使用情况
|
||||
|
||||
Args:
|
||||
node_outputs: 所有节点的输出
|
||||
|
||||
Returns:
|
||||
聚合的 token 使用情况 {"prompt_tokens": x, "completion_tokens": y, "total_tokens": z}
|
||||
如果没有 token 使用信息,返回 None
|
||||
"""
|
||||
total_prompt_tokens = 0
|
||||
total_completion_tokens = 0
|
||||
total_tokens = 0
|
||||
has_token_info = False
|
||||
|
||||
for node_output in node_outputs.values():
|
||||
if isinstance(node_output, dict):
|
||||
token_usage = node_output.get("token_usage")
|
||||
if token_usage and isinstance(token_usage, dict):
|
||||
has_token_info = True
|
||||
total_prompt_tokens += token_usage.get("prompt_tokens", 0)
|
||||
total_completion_tokens += token_usage.get("completion_tokens", 0)
|
||||
total_tokens += token_usage.get("total_tokens", 0)
|
||||
|
||||
if not has_token_info:
|
||||
return None
|
||||
|
||||
return {
|
||||
"prompt_tokens": total_prompt_tokens,
|
||||
"completion_tokens": total_completion_tokens,
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
|
||||
|
||||
async def execute_workflow(
    workflow_config: dict[str, Any],
    input_data: dict[str, Any],
    execution_id: str,
    workspace_id: str,
    user_id: str
) -> dict[str, Any]:
    """Run a workflow (convenience function).

    Creates a ``WorkflowExecutor`` for the given configuration and
    identifiers and awaits its ``execute`` with *input_data*.

    Args:
        workflow_config: Workflow configuration.
        input_data: Input data.
        execution_id: Execution ID.
        workspace_id: Workspace ID.
        user_id: User ID.

    Returns:
        The execution result dictionary.
    """
    runner = WorkflowExecutor(
        workflow_config,
        execution_id,
        workspace_id,
        user_id,
    )
    outcome = await runner.execute(input_data)
    return outcome
|
||||
|
||||
|
||||
async def execute_workflow_stream(
    workflow_config: dict[str, Any],
    input_data: dict[str, Any],
    execution_id: str,
    workspace_id: str,
    user_id: str
):
    """Run a workflow in streaming mode (convenience function).

    Creates a ``WorkflowExecutor`` and re-yields every event produced by
    its ``execute_stream``.

    Args:
        workflow_config: Workflow configuration.
        input_data: Input data.
        execution_id: Execution ID.
        workspace_id: Workspace ID.
        user_id: User ID.

    Yields:
        Streaming events.
    """
    runner = WorkflowExecutor(
        workflow_config,
        execution_id,
        workspace_id,
        user_id,
    )
    event_stream = runner.execute_stream(input_data)
    async for item in event_stream:
        yield item
|
||||
195
api/app/core/workflow/expression_evaluator.py
Normal file
195
api/app/core/workflow/expression_evaluator.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""
|
||||
安全的表达式求值器
|
||||
|
||||
使用 simpleeval 库提供安全的表达式评估,避免代码注入攻击。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from simpleeval import simple_eval, NameNotDefined, InvalidExpression
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExpressionEvaluator:
|
||||
"""安全的表达式求值器"""
|
||||
|
||||
# 保留的命名空间
|
||||
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
|
||||
|
||||
@staticmethod
|
||||
def evaluate(
|
||||
expression: str,
|
||||
variables: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""安全地评估表达式
|
||||
|
||||
Args:
|
||||
expression: 表达式字符串,如 "{{var.score}} > 0.8"
|
||||
variables: 用户定义的变量
|
||||
node_outputs: 节点输出结果
|
||||
system_vars: 系统变量
|
||||
|
||||
Returns:
|
||||
表达式求值结果
|
||||
|
||||
Raises:
|
||||
ValueError: 表达式无效或求值失败
|
||||
|
||||
Examples:
|
||||
>>> evaluator = ExpressionEvaluator()
|
||||
>>> evaluator.evaluate(
|
||||
... "var.score > 0.8",
|
||||
... {"score": 0.9},
|
||||
... {},
|
||||
... {}
|
||||
... )
|
||||
True
|
||||
|
||||
>>> evaluator.evaluate(
|
||||
... "node.intent.output == '售前咨询'",
|
||||
... {},
|
||||
... {"intent": {"output": "售前咨询"}},
|
||||
... {}
|
||||
... )
|
||||
True
|
||||
"""
|
||||
# 移除 Jinja2 模板语法的花括号(如果存在)
|
||||
expression = expression.strip()
|
||||
if expression.startswith("{{") and expression.endswith("}}"):
|
||||
expression = expression[2:-2].strip()
|
||||
|
||||
# 构建命名空间上下文
|
||||
context = {
|
||||
"var": variables, # 用户变量
|
||||
"node": node_outputs, # 节点输出
|
||||
"sys": system_vars or {}, # 系统变量
|
||||
}
|
||||
|
||||
# 为了向后兼容,也支持直接访问(但会在日志中警告)
|
||||
context.update(variables)
|
||||
context["nodes"] = node_outputs
|
||||
|
||||
try:
|
||||
# simpleeval 只支持安全的操作:
|
||||
# - 算术运算: +, -, *, /, //, %, **
|
||||
# - 比较运算: ==, !=, <, <=, >, >=
|
||||
# - 逻辑运算: and, or, not
|
||||
# - 成员运算: in, not in
|
||||
# - 属性访问: obj.attr
|
||||
# - 字典/列表访问: obj["key"], obj[0]
|
||||
# 不支持:函数调用、导入、赋值等危险操作
|
||||
result = simple_eval(expression, names=context)
|
||||
return result
|
||||
|
||||
except NameNotDefined as e:
|
||||
logger.error(f"表达式中引用了未定义的变量: {expression}, 错误: {e}")
|
||||
raise ValueError(f"未定义的变量: {e}")
|
||||
|
||||
except InvalidExpression as e:
|
||||
logger.error(f"表达式语法无效: {expression}, 错误: {e}")
|
||||
raise ValueError(f"表达式语法无效: {e}")
|
||||
|
||||
except SyntaxError as e:
|
||||
logger.error(f"表达式语法错误: {expression}, 错误: {e}")
|
||||
raise ValueError(f"表达式语法错误: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"表达式求值异常: {expression}, 错误: {e}")
|
||||
raise ValueError(f"表达式求值失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
def evaluate_bool(
|
||||
expression: str,
|
||||
variables: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> bool:
|
||||
"""评估布尔表达式(用于条件判断)
|
||||
|
||||
Args:
|
||||
expression: 布尔表达式
|
||||
variables: 用户变量
|
||||
node_outputs: 节点输出
|
||||
system_vars: 系统变量
|
||||
|
||||
Returns:
|
||||
布尔值结果
|
||||
|
||||
Examples:
|
||||
>>> ExpressionEvaluator.evaluate_bool(
|
||||
... "var.count >= 10 and var.status == 'active'",
|
||||
... {"count": 15, "status": "active"},
|
||||
... {},
|
||||
... {}
|
||||
... )
|
||||
True
|
||||
"""
|
||||
result = ExpressionEvaluator.evaluate(
|
||||
expression, variables, node_outputs, system_vars
|
||||
)
|
||||
return bool(result)
|
||||
|
||||
@staticmethod
|
||||
def validate_variable_names(variables: list[dict]) -> list[str]:
|
||||
"""验证变量名是否合法
|
||||
|
||||
Args:
|
||||
variables: 变量定义列表
|
||||
|
||||
Returns:
|
||||
错误列表,如果为空则验证通过
|
||||
|
||||
Examples:
|
||||
>>> ExpressionEvaluator.validate_variable_names([
|
||||
... {"name": "user_input"},
|
||||
... {"name": "var"} # 保留字
|
||||
... ])
|
||||
["变量名 'var' 是保留的命名空间,请使用其他名称"]
|
||||
"""
|
||||
errors = []
|
||||
|
||||
for var in variables:
|
||||
var_name = var.get("name", "")
|
||||
|
||||
# 检查是否为保留命名空间
|
||||
if var_name in ExpressionEvaluator.RESERVED_NAMESPACES:
|
||||
errors.append(
|
||||
f"变量名 '{var_name}' 是保留的命名空间,请使用其他名称"
|
||||
)
|
||||
|
||||
# 检查是否为有效的 Python 标识符
|
||||
if not var_name.isidentifier():
|
||||
errors.append(
|
||||
f"变量名 '{var_name}' 不是有效的标识符"
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def evaluate_expression(
    expression: str,
    variables: dict[str, Any],
    node_outputs: dict[str, Any],
    system_vars: dict[str, Any] | None = None
) -> Any:
    """Evaluate an expression (convenience wrapper around ``ExpressionEvaluator.evaluate``)."""
    value = ExpressionEvaluator.evaluate(
        expression=expression,
        variables=variables,
        node_outputs=node_outputs,
        system_vars=system_vars,
    )
    return value
|
||||
|
||||
|
||||
def evaluate_condition(
    expression: str,
    variables: dict[str, Any],
    node_outputs: dict[str, Any],
    system_vars: dict[str, Any] | None = None
) -> bool:
    """Evaluate a condition (convenience wrapper around ``ExpressionEvaluator.evaluate_bool``)."""
    verdict = ExpressionEvaluator.evaluate_bool(
        expression=expression,
        variables=variables,
        node_outputs=node_outputs,
        system_vars=system_vars,
    )
    return verdict
|
||||
24
api/app/core/workflow/nodes/__init__.py
Normal file
24
api/app/core/workflow/nodes/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
工作流节点实现
|
||||
|
||||
提供各种类型的节点实现,用于工作流执行。
|
||||
"""
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.agent import AgentNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.end import EndNode
|
||||
from app.core.workflow.nodes.node_factory import NodeFactory
|
||||
|
||||
__all__ = [
|
||||
"BaseNode",
|
||||
"WorkflowState",
|
||||
"LLMNode",
|
||||
"AgentNode",
|
||||
"TransformNode",
|
||||
"StartNode",
|
||||
"EndNode",
|
||||
"NodeFactory",
|
||||
]
|
||||
6
api/app/core/workflow/nodes/agent/__init__.py
Normal file
6
api/app/core/workflow/nodes/agent/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Agent 节点"""
|
||||
|
||||
from app.core.workflow.nodes.agent.node import AgentNode
|
||||
from app.core.workflow.nodes.agent.config import AgentNodeConfig
|
||||
|
||||
__all__ = ["AgentNode", "AgentNodeConfig"]
|
||||
71
api/app/core/workflow/nodes/agent/config.py
Normal file
71
api/app/core/workflow/nodes/agent/config.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Agent 节点配置"""
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
|
||||
|
||||
class AgentNodeConfig(BaseNodeConfig):
    """Agent node configuration.

    Describes a workflow node that calls an already configured Agent.
    """

    # Required; identifies the Agent to call.
    agent_id: str = Field(
        ...,
        description="Agent 配置 ID"
    )

    # Message sent to the Agent; defaults to the template "{{ sys.message }}".
    message: str = Field(
        default="{{ sys.message }}",
        description="发送给 Agent 的消息,支持模板变量"
    )

    # Optional conversation ID for multi-turn conversations.
    conversation_id: str | None = Field(
        default=None,
        description="会话 ID,用于多轮对话"
    )

    # Optional string-to-string variables passed to the Agent.
    variables: dict[str, str] | None = Field(
        default=None,
        description="传递给 Agent 的变量"
    )

    # Timeout in seconds, constrained to 1..3600, default 300.
    timeout: int = Field(
        default=300,
        ge=1,
        le=3600,
        description="超时时间(秒)"
    )

    # Output variable definitions: output, conversation_id and token_usage.
    # default_factory builds a fresh list for each instance.
    output_variables: list[VariableDefinition] = Field(
        default_factory=lambda: [
            VariableDefinition(
                name="output",
                type=VariableType.STRING,
                description="Agent 的回复内容"
            ),
            VariableDefinition(
                name="conversation_id",
                type=VariableType.STRING,
                description="会话 ID"
            ),
            VariableDefinition(
                name="token_usage",
                type=VariableType.OBJECT,
                description="Token 使用情况"
            )
        ],
        description="输出变量定义(自动生成,通常不需要修改)"
    )

    # Example payload attached to the generated JSON schema.
    # NOTE(review): `json_schema_extra` inside a `class Config` looks like
    # pydantic v2 naming in the older configuration style — confirm against
    # the project's pydantic version.
    class Config:
        json_schema_extra = {
            "example": {
                "agent_id": "uuid-here",
                "message": "{{ sys.message }}",
                "timeout": 300,
                "description": "调用客服 Agent"
            }
        }
|
||||
152
api/app/core/workflow/nodes/agent/node.py
Normal file
152
api/app/core/workflow/nodes/agent/node.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
Agent 节点实现
|
||||
|
||||
调用已发布的 Agent 应用。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
from app.models import AppRelease
|
||||
from app.db import get_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentNode(BaseNode):
|
||||
"""Agent 节点
|
||||
|
||||
支持流式和非流式输出。
|
||||
|
||||
配置示例:
|
||||
{
|
||||
"type": "agent",
|
||||
"config": {
|
||||
"agent_id": "uuid", # Agent 的 release_id
|
||||
"message": "{{var.user_input}}"
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
def _prepare_agent(self, state: WorkflowState) -> tuple[DraftRunService, AppRelease, str]:
    """Prepare an Agent call (logic shared by both execution modes).

    Args:
        state: Workflow state.

    Returns:
        (draft_service, release, message): the service instance, the
        release record and the rendered message.

    Raises:
        ValueError: ``agent_id`` is missing from the node configuration, or
            no ``AppRelease`` row has that ID.
    """
    # 1. Render the message template against the current state.
    rendered_message = self._render_template(self.config.get("message", ""), state)

    # 2. Resolve the Agent release.
    agent_id = self.config.get("agent_id")
    if not agent_id:
        raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置")

    # NOTE(review): the session comes from next(get_db()) and is never
    # closed in this method — presumably the caller or get_db's own cleanup
    # is expected to handle it; confirm.
    session = next(get_db())
    release_query = session.query(AppRelease).filter(AppRelease.id == agent_id)
    agent_release = release_query.first()

    if not agent_release:
        raise ValueError(f"Agent 不存在: {agent_id}")

    service = DraftRunService(session)
    return service, agent_release, rendered_message
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""非流式执行
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
状态更新字典
|
||||
"""
|
||||
draft_service, release, message = self._prepare_agent(state)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
|
||||
|
||||
# 执行 Agent(非流式)
|
||||
result = await draft_service.run(
|
||||
agent_config=release.config,
|
||||
model_config=None,
|
||||
message=message,
|
||||
workspace_id=state.get("workspace_id"),
|
||||
user_id=state.get("user_id"),
|
||||
variables=state.get("variables", {})
|
||||
)
|
||||
|
||||
response = result.get("response", "")
|
||||
|
||||
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(response)}")
|
||||
|
||||
return {
|
||||
"messages": [AIMessage(content=response)],
|
||||
"node_outputs": {
|
||||
self.node_id: {
|
||||
"output": response,
|
||||
"status": "completed",
|
||||
"meta_data": result.get("meta_data", {})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
|
||||
"""流式执行
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Yields:
|
||||
流式事件字典
|
||||
"""
|
||||
draft_service, release, message = self._prepare_agent(state)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
|
||||
|
||||
# 累积完整响应
|
||||
full_response = ""
|
||||
|
||||
# 执行 Agent(流式)
|
||||
async for chunk in draft_service.run_stream(
|
||||
agent_config=release.config,
|
||||
model_config=None,
|
||||
message=message,
|
||||
workspace_id=state.get("workspace_id"),
|
||||
user_id=state.get("user_id"),
|
||||
variables=state.get("variables", {})
|
||||
):
|
||||
# 提取内容
|
||||
content = chunk.get("content", "")
|
||||
full_response += content
|
||||
|
||||
# 流式返回每个 chunk
|
||||
yield {
|
||||
"type": "chunk",
|
||||
"node_id": self.node_id,
|
||||
"content": content,
|
||||
"full_content": full_response,
|
||||
"meta_data": chunk.get("meta_data", {})
|
||||
}
|
||||
|
||||
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}")
|
||||
|
||||
# 最后返回完整结果
|
||||
yield {
|
||||
"type": "complete",
|
||||
"messages": [AIMessage(content=full_response)],
|
||||
"node_outputs": {
|
||||
self.node_id: {
|
||||
"output": full_response,
|
||||
"status": "completed"
|
||||
}
|
||||
}
|
||||
}
|
||||
109
api/app/core/workflow/nodes/base_config.py
Normal file
109
api/app/core/workflow/nodes/base_config.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""节点配置基类
|
||||
|
||||
定义所有节点配置的通用字段和数据结构。
|
||||
"""
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class VariableType(StrEnum):
|
||||
"""变量类型枚举"""
|
||||
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
BOOLEAN = "boolean"
|
||||
ARRAY = "array"
|
||||
OBJECT = "object"
|
||||
|
||||
|
||||
class VariableDefinition(BaseModel):
    """Definition of one input/output variable of a workflow or node.

    A generic data structure intended to be reused wherever variables are
    declared.
    """

    # Pydantic v2 style configuration; the class-based ``Config`` it replaces
    # is deprecated in v2. Only the JSON-schema examples are configured here.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "name": "language",
                    "type": "string",
                    "required": False,
                    "default": "zh-CN",
                    "description": "语言设置",
                },
                {
                    "name": "max_length",
                    "type": "number",
                    "required": False,
                    "default": 1000,
                    "description": "最大长度",
                },
                {
                    "name": "enable_search",
                    "type": "boolean",
                    "required": True,
                    "description": "是否启用搜索",
                },
            ]
        }
    }

    # Variable name (required).
    name: str = Field(
        ...,
        description="变量名称"
    )

    # Declared value type; defaults to string.
    type: VariableType = Field(
        default=VariableType.STRING,
        description="变量类型"
    )

    # Whether a value must be supplied.
    required: bool = Field(
        default=False,
        description="是否必需"
    )

    # Default value used when none is supplied.
    default: str | int | float | bool | list | dict | None = Field(
        default=None,
        description="默认值"
    )

    # Human-readable description of the variable.
    description: str | None = Field(
        default=None,
        description="变量描述"
    )
|
||||
|
||||
|
||||
class BaseNodeConfig(BaseModel):
    """Base class for node configuration models.

    Every node configuration should inherit from this class.

    Common fields:
        - name: display name of the node
        - description: what the node does
        - tags: labels used for categorising and searching
    """

    # Pydantic v2 style configuration (replaces the deprecated class-based
    # ``Config``). Extra fields are accepted for backward compatibility.
    model_config = {"extra": "allow"}

    name: str | None = Field(
        default=None,
        description="节点名称(显示名称),如果不设置则使用节点 ID"
    )

    description: str | None = Field(
        default=None,
        description="节点描述,说明节点的作用"
    )

    tags: list[str] = Field(
        default_factory=list,
        description="节点标签,用于分类和搜索"
    )
|
||||
556
api/app/core/workflow/nodes/base_node.py
Normal file
556
api/app/core/workflow/nodes/base_node.py
Normal file
@@ -0,0 +1,556 @@
|
||||
"""
|
||||
工作流节点基类
|
||||
|
||||
定义节点的基本接口和通用功能。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, TypedDict, Annotated
|
||||
from operator import add
|
||||
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
|
||||
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _merge_dicts(current: dict[str, Any], update: dict[str, Any]) -> dict[str, Any]:
    """Reducer: return a new dict with ``update`` merged over ``current``."""
    return {**current, **update}


class WorkflowState(TypedDict):
    """State object passed between workflow nodes.

    Carries the message list, input variables, per-node outputs, the
    execution context and error information.
    """

    # Message list; updates are appended (``operator.add`` reducer).
    messages: Annotated[list[AnyMessage], add]

    # Input variables (taken from the configured ``variables``).
    variables: dict[str, Any]

    # Result of each executed node, used for variable references.
    # Updates are merged into the existing dict rather than replacing it.
    node_outputs: Annotated[dict[str, Any], _merge_dicts]

    # Simplified per-node runtime variables (business data only) for quick
    # access by later nodes. Format: {node_id: business_result}
    runtime_vars: Annotated[dict[str, Any], _merge_dicts]

    # Execution context.
    execution_id: str
    workspace_id: str
    user_id: str

    # Error information (used to follow error edges).
    error: str | None
    error_node: str | None
|
||||
|
||||
|
||||
class BaseNode(ABC):
    """Abstract base class for workflow nodes.

    Subclasses implement ``execute`` (and optionally ``execute_stream``).
    ``run`` / ``run_stream`` wrap those with timing, a timeout, error
    handling and conversion to the standard node-output format.
    """

    def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
        """Initialise the node.

        Args:
            node_config: Node configuration; must contain ``id`` and ``type``.
            workflow_config: Configuration of the enclosing workflow.
        """
        self.node_config = node_config
        self.workflow_config = workflow_config
        self.node_id = node_config["id"]
        self.node_type = node_config["type"]
        self.node_name = node_config.get("name", self.node_id)
        # ``or {}`` also covers keys that are present but set to None.
        self.config = node_config.get("config") or {}
        self.error_handling = node_config.get("error_handling") or {}

    @abstractmethod
    async def execute(self, state: WorkflowState) -> Any:
        """Run the node's business logic (non-streaming).

        Implementations return only the business result; output format and
        timing are handled by ``run``, which wraps the result.

        Args:
            state: Workflow state.

        Returns:
            The business result (any type).

        Examples:
            >>> # LLM node
            >>> return "这是 AI 的回复"

            >>> # Transform node
            >>> return {"processed_data": [...]}

            >>> # Start/End node
            >>> return {"message": "开始", "conversation_id": "xxx"}
        """
        pass

    async def execute_stream(self, state: WorkflowState):
        """Run the node's business logic (streaming).

        Subclasses may override this to stream output. The default
        implementation runs ``execute`` and yields the result in one go.

        An override should:
            1. yield intermediate results (e.g. text fragments);
            2. finish by yielding the completion marker
               ``{"__final__": True, "result": final_result}``.

        Args:
            state: Workflow state.

        Yields:
            Business data (chunks) or the completion marker.

        Examples:
            >>> # Streaming LLM node
            >>> full_response = ""
            >>> async for chunk in llm.astream(prompt):
            ...     full_response += chunk
            ...     yield chunk  # yield a text fragment
            >>>
            >>> # finally yield the completion marker
            >>> yield {"__final__": True, "result": AIMessage(content=full_response)}
        """
        result = await self.execute(state)
        # Default implementation: yield the completion marker directly.
        yield {"__final__": True, "result": result}

    def supports_streaming(self) -> bool:
        """Report whether the node overrides ``execute_stream``.

        Returns:
            True if the node's class provides its own ``execute_stream``.
        """
        # Compare the function objects found on the classes. Looking up
        # ``BaseNode.execute_stream.__func__`` would raise AttributeError,
        # because a function accessed through the class has no ``__func__``.
        return type(self).execute_stream is not BaseNode.execute_stream

    def get_timeout(self) -> int:
        """Return the execution timeout in seconds.

        Returns:
            The timeout; currently fixed at 60 seconds.
        """
        # NOTE: a per-node override via ``error_handling["timeout"]`` was
        # disabled in the original code; the fixed value is kept here.
        return 60

    def _build_runtime_vars(self, business_result: Any) -> dict[str, Any]:
        """Build the ``runtime_vars`` state update for this node.

        A dict output is stored as-is; any other output is stored under the
        ``output`` key.
        """
        extracted_output = self._extract_output(business_result)
        if isinstance(extracted_output, dict):
            runtime_var = extracted_output
        else:
            runtime_var = {"output": extracted_output}
        return {self.node_id: runtime_var}

    async def run(self, state: WorkflowState) -> dict[str, Any]:
        """Execute the node with error handling and output wrapping.

        Called by the executor. Responsibilities:
            1. timing;
            2. calling ``execute`` under a timeout;
            3. wrapping the business result in the standard output format;
            4. error handling.

        Args:
            state: Workflow state.

        Returns:
            Standardised state-update dict.

        Raises:
            Exception: If the node fails and has no outgoing error edge
                (raised by ``_wrap_error``).
        """
        import time

        start_time = time.time()

        try:
            timeout = self.get_timeout()

            # Run the node's business logic.
            business_result = await asyncio.wait_for(
                self.execute(state),
                timeout=timeout
            )

            elapsed_time = time.time() - start_time

            # Wrap into the standard output format.
            wrapped_output = self._wrap_output(business_result, elapsed_time, state)

            # Return the wrapped output together with the runtime variables
            # that give later nodes quick access to this node's result.
            return {
                **wrapped_output,
                "runtime_vars": self._build_runtime_vars(business_result)
            }

        except (TimeoutError, asyncio.TimeoutError):
            elapsed_time = time.time() - start_time
            logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
            return self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
        except Exception as e:
            elapsed_time = time.time() - start_time
            logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
            return self._wrap_error(str(e), elapsed_time, state)

    async def run_stream(self, state: WorkflowState):
        """Execute the node in streaming mode with error handling.

        Called by the executor. Responsibilities:
            1. timing;
            2. consuming ``execute_stream`` under an overall timeout;
            3. wrapping business data into standard stream events;
            4. error handling.

        Args:
            state: Workflow state.

        Yields:
            ``{"type": "chunk", ...}`` events for streamed text, then one
            ``{"type": "complete", ...}`` event, or a ``{"type": "error", ...}``
            event when the node fails and has an error edge.

        Raises:
            Exception: If the node fails and has no outgoing error edge
                (raised by ``_wrap_error``).
        """
        import time

        start_time = time.time()

        try:
            timeout = self.get_timeout()
            loop = asyncio.get_running_loop()
            deadline = loop.time() + timeout

            # Accumulated text and the final business result.
            full_content = ""
            final_result = None

            stream = self.execute_stream(state).__aiter__()
            try:
                while True:
                    # Wait for each item with the remaining time budget, so a
                    # stream that stalls without producing items still times out.
                    remaining = deadline - loop.time()
                    if remaining <= 0:
                        raise TimeoutError()
                    try:
                        item = await asyncio.wait_for(stream.__anext__(), timeout=remaining)
                    except StopAsyncIteration:
                        break

                    # Completion marker: remember the result, emit nothing.
                    if isinstance(item, dict) and item.get("__final__"):
                        final_result = item["result"]
                        continue

                    # Anything else is a chunk; non-strings are stringified.
                    text = item if isinstance(item, str) else str(item)
                    full_content += text
                    yield {
                        "type": "chunk",
                        "node_id": self.node_id,
                        "content": text,
                        "full_content": full_content
                    }
            finally:
                # Release the node's generator on timeout, error or early exit.
                aclose = getattr(stream, "aclose", None)
                if aclose is not None:
                    await aclose()

            elapsed_time = time.time() - start_time

            # Wrap the final result; include runtime_vars like ``run`` does.
            final_output = self._wrap_output(final_result, elapsed_time, state)
            yield {
                "type": "complete",
                **final_output,
                "runtime_vars": self._build_runtime_vars(final_result)
            }

        except (TimeoutError, asyncio.TimeoutError):
            elapsed_time = time.time() - start_time
            logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
            yield {
                "type": "error",
                **self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
            }
        except Exception as e:
            elapsed_time = time.time() - start_time
            logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
            yield {
                "type": "error",
                **self._wrap_error(str(e), elapsed_time, state)
            }

    def _wrap_output(
        self,
        business_result: Any,
        elapsed_time: float,
        state: WorkflowState
    ) -> dict[str, Any]:
        """Wrap a business result in the standard output format.

        Args:
            business_result: Result returned by the node.
            elapsed_time: Execution time in seconds.
            state: Workflow state.

        Returns:
            Standardised state-update dict with a ``node_outputs`` entry.
        """
        # Input data (for the record).
        input_data = self._extract_input(state)

        # Token usage, if the node reports any.
        token_usage = self._extract_token_usage(business_result)

        # Actual output (metadata removed).
        output = self._extract_output(business_result)

        node_output = {
            "node_id": self.node_id,
            "node_type": self.node_type,
            "node_name": self.node_name,
            "status": "completed",
            "input": input_data,
            "output": output,
            "elapsed_time": elapsed_time,
            "token_usage": token_usage,
            "error": None
        }

        return {
            "node_outputs": {
                self.node_id: node_output
            }
        }

    def _wrap_error(
        self,
        error_message: str,
        elapsed_time: float,
        state: WorkflowState
    ) -> dict[str, Any]:
        """Wrap an error in the standard output format.

        Args:
            error_message: Error message.
            elapsed_time: Execution time in seconds.
            state: Workflow state.

        Returns:
            Standardised state-update dict (only when an error edge exists).

        Raises:
            Exception: If the node has no outgoing error edge, to stop the
                workflow.
        """
        error_edge = self._find_error_edge()

        input_data = self._extract_input(state)

        node_output = {
            "node_id": self.node_id,
            "node_type": self.node_type,
            "node_name": self.node_name,
            "status": "failed",
            "input": input_data,
            "output": None,
            "elapsed_time": elapsed_time,
            "token_usage": None,
            "error": error_message
        }

        if error_edge:
            # Error edge present: record the error and let the workflow continue.
            logger.warning(
                f"节点 {self.node_id} 执行失败,跳转到错误处理节点: {error_edge['target']}"
            )
            return {
                "node_outputs": {
                    self.node_id: node_output
                },
                "error": error_message,
                "error_node": self.node_id
            }
        else:
            # No error edge: raise to stop the workflow.
            logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}")
            raise Exception(f"节点 {self.node_id} 执行失败: {error_message}")

    def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
        """Extract the node's input data (for the record).

        Subclasses may override this to customise what is recorded.

        Args:
            state: Workflow state.

        Returns:
            Input data dict; by default the node configuration.
        """
        return {"config": self.config}

    def _extract_output(self, business_result: Any) -> Any:
        """Extract the actual output from a business result.

        Subclasses may override this to customise the extraction.

        Args:
            business_result: Business result.

        Returns:
            The actual output; by default the business result itself.
        """
        return business_result

    def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
        """Extract token usage from a business result.

        Subclasses may override this to report token information.

        Args:
            business_result: Business result.

        Returns:
            Token usage, or None (the default).
        """
        return None

    def _find_error_edge(self) -> dict[str, Any] | None:
        """Find this node's outgoing error edge.

        Returns:
            The first edge in the workflow config whose ``source`` is this
            node and whose ``type`` is ``"error"``, or None.
        """
        for edge in self.workflow_config.get("edges", []):
            if edge.get("source") == self.node_id and edge.get("type") == "error":
                return edge
        return None

    def _render_template(self, template: str, state: WorkflowState | None) -> str:
        """Render a template string.

        Supported variable namespaces:
            - sys.xxx: system variables (message, execution_id, workspace_id,
              user_id, conversation_id)
            - conv.xxx: conversation variables (kept across turns)
            - node_id.xxx: node outputs

        Args:
            template: Template string.
            state: Workflow state; None is treated as an empty state.

        Returns:
            The rendered string.
        """
        from app.core.workflow.template_renderer import render_template

        if state is None:
            state = {}

        # Read the variables through the variable pool.
        pool = VariablePool(state)

        return render_template(
            template=template,
            variables=pool.get_all_conversation_vars(),
            node_outputs=pool.get_all_node_outputs(),
            system_vars=pool.get_all_system_vars()
        )

    def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool:
        """Evaluate a condition expression.

        Supported variable namespaces:
            - sys.xxx: system variables
            - conv.xxx: conversation variables
            - node_id.xxx: node outputs

        Args:
            expression: Condition expression.
            state: Workflow state; None is treated as an empty state.

        Returns:
            The boolean result.
        """
        from app.core.workflow.expression_evaluator import evaluate_condition

        if state is None:
            state = {}

        # Read the variables through the variable pool.
        pool = VariablePool(state)

        return evaluate_condition(
            expression=expression,
            variables=pool.get_all_conversation_vars(),
            node_outputs=pool.get_all_node_outputs(),
            system_vars=pool.get_all_system_vars()
        )

    def get_variable_pool(self, state: WorkflowState) -> VariablePool:
        """Return a variable pool wrapping ``state``.

        Args:
            state: Workflow state.

        Returns:
            A new ``VariablePool`` instance.

        Examples:
            >>> pool = self.get_variable_pool(state)
            >>> message = pool.get("sys.message")
            >>> llm_output = pool.get("llm_qa.output")
        """
        return VariablePool(state)

    def get_variable(
        self,
        selector: list[str] | str,
        state: WorkflowState,
        default: Any = None
    ) -> Any:
        """Look up a variable value (convenience method).

        Args:
            selector: Variable selector.
            state: Workflow state.
            default: Value passed to the pool as the fallback.

        Returns:
            The variable value.

        Examples:
            >>> message = self.get_variable("sys.message", state)
            >>> output = self.get_variable(["llm_qa", "output"], state)
            >>> custom = self.get_variable("var.custom", state, default="默认值")
        """
        pool = VariablePool(state)
        return pool.get(selector, default=default)

    def has_variable(self, selector: list[str] | str, state: WorkflowState) -> bool:
        """Check whether a variable exists (convenience method).

        Args:
            selector: Variable selector.
            state: Workflow state.

        Returns:
            Whether the variable exists.

        Examples:
            >>> if self.has_variable("llm_qa.output", state):
            ...     output = self.get_variable("llm_qa.output", state)
        """
        pool = VariablePool(state)
        return pool.has(selector)
|
||||
29
api/app/core/workflow/nodes/configs.py
Normal file
29
api/app/core/workflow/nodes/configs.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""节点配置类统一导出
|
||||
|
||||
所有节点的配置类都在这里导出,方便使用。
|
||||
"""
|
||||
|
||||
from app.core.workflow.nodes.base_config import (
|
||||
BaseNodeConfig,
|
||||
VariableDefinition,
|
||||
VariableType,
|
||||
)
|
||||
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
from app.core.workflow.nodes.end.config import EndNodeConfig
|
||||
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
|
||||
from app.core.workflow.nodes.agent.config import AgentNodeConfig
|
||||
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
||||
|
||||
__all__ = [
|
||||
# 基础类
|
||||
"BaseNodeConfig",
|
||||
"VariableDefinition",
|
||||
"VariableType",
|
||||
# 节点配置
|
||||
"StartNodeConfig",
|
||||
"EndNodeConfig",
|
||||
"LLMNodeConfig",
|
||||
"MessageConfig",
|
||||
"AgentNodeConfig",
|
||||
"TransformNodeConfig",
|
||||
]
|
||||
6
api/app/core/workflow/nodes/end/__init__.py
Normal file
6
api/app/core/workflow/nodes/end/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""End 节点"""
|
||||
|
||||
from app.core.workflow.nodes.end.node import EndNode
|
||||
from app.core.workflow.nodes.end.config import EndNodeConfig
|
||||
|
||||
__all__ = ["EndNode", "EndNodeConfig"]
|
||||
37
api/app/core/workflow/nodes/end/config.py
Normal file
37
api/app/core/workflow/nodes/end/config.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""End 节点配置"""
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
|
||||
|
||||
class EndNodeConfig(BaseNodeConfig):
    """Configuration of the End node.

    The End node emits the final result of the workflow.
    """

    # Pydantic v2 style configuration (replaces the deprecated class-based
    # ``Config``); only adds a JSON-schema example.
    model_config = {
        "json_schema_extra": {
            "example": {
                "output": "{{ llm_qa.output }}",
                "description": "输出 LLM 的回答"
            }
        }
    }

    # Output template; may reference outputs of preceding nodes.
    output: str = Field(
        default="工作流已完成",
        description="输出模板,支持引用前置节点的输出,如:{{ llm_qa.output }}"
    )

    # Output variable definitions; defaults to a single string ``output``.
    output_variables: list[VariableDefinition] = Field(
        default_factory=lambda: [
            VariableDefinition(
                name="output",
                type=VariableType.STRING,
                description="工作流的最终输出"
            )
        ],
        description="输出变量定义(自动生成,通常不需要修改)"
    )
|
||||
53
api/app/core/workflow/nodes/end/node.py
Normal file
53
api/app/core/workflow/nodes/end/node.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
End 节点实现
|
||||
|
||||
工作流的结束节点,输出最终结果。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EndNode(BaseNode):
    """End node of a workflow.

    Produces the final result by rendering the configured output template.
    """

    async def execute(self, state: WorkflowState) -> str:
        """Run the End node's business logic.

        Args:
            state: Workflow state.

        Returns:
            The final output string: the rendered ``output`` template, or a
            fixed default message when no template is configured.
        """
        logger.info(f"节点 {self.node_id} (End) 开始执行")

        # Render the configured template, or fall back to the default text.
        output_template = self.config.get("output")
        if output_template:
            output = self._render_template(output_template, state)
        else:
            output = "工作流已完成"

        # Statistics, used for logging only.
        total_nodes = len(state.get("node_outputs", {}))

        logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")

        return output
|
||||
15
api/app/core/workflow/nodes/enums.py
Normal file
15
api/app/core/workflow/nodes/enums.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from enum import StrEnum
|
||||
|
||||
class NodeType(StrEnum):
|
||||
START = "start"
|
||||
END = "end"
|
||||
ANSWER = "answer"
|
||||
LLM = "llm"
|
||||
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
|
||||
IF_ELSE = "if-else"
|
||||
CODE = "code"
|
||||
TRANSFORM = "transform"
|
||||
QUESTION_CLASSIFIER = "question-classifier"
|
||||
HTTP_REQUEST = "http-request"
|
||||
TOOL = "tool"
|
||||
AGENT = "agent"
|
||||
6
api/app/core/workflow/nodes/llm/__init__.py
Normal file
6
api/app/core/workflow/nodes/llm/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""LLM 节点"""
|
||||
|
||||
from app.core.workflow.nodes.llm.node import LLMNode
|
||||
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
|
||||
|
||||
__all__ = ["LLMNode", "LLMNodeConfig", "MessageConfig"]
|
||||
141
api/app/core/workflow/nodes/llm/config.py
Normal file
141
api/app/core/workflow/nodes/llm/config.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""LLM 节点配置"""
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
|
||||
|
||||
class MessageConfig(BaseModel):
    """One chat message of an LLM node: a role plus a content template."""

    role: str = Field(..., description="消息角色:system, user, assistant")
    content: str = Field(..., description="消息内容,支持模板变量,如:{{ sys.message }}")

    @field_validator("role")
    @classmethod
    def validate_role(cls, v: str) -> str:
        """Normalise the role to lower case and reject unknown roles.

        Raises:
            ValueError: If the lower-cased role is not an accepted role.
        """
        allowed_roles = ["system", "user", "human", "assistant", "ai"]
        normalized = v.lower()
        if normalized in allowed_roles:
            return normalized
        raise ValueError(f"角色必须是以下之一: {', '.join(allowed_roles)}")
|
||||
|
||||
|
||||
class LLMNodeConfig(BaseNodeConfig):
    """Configuration of the LLM node.

    Two input modes are supported:
        1. Simple mode: the ``prompt`` field.
        2. Message mode: the ``messages`` field (recommended).
    """

    # Pydantic v2 style configuration (replaces the deprecated class-based
    # ``Config``). ``protected_namespaces`` is cleared because the
    # ``model_id`` field starts with ``model_``, a prefix that Pydantic v2
    # releases which protect it warn about.
    model_config = {
        "protected_namespaces": (),
        "json_schema_extra": {
            "examples": [
                {
                    "model_id": "uuid-here",
                    "prompt": "请回答:{{ sys.message }}",
                    "temperature": 0.7,
                    "max_tokens": 1000
                },
                {
                    "model_id": "uuid-here",
                    "messages": [
                        {
                            "role": "system",
                            "content": "你是一个专业的 AI 助手"
                        },
                        {
                            "role": "user",
                            "content": "{{ sys.message }}"
                        }
                    ],
                    "temperature": 0.7,
                    "max_tokens": 1000
                }
            ]
        },
    }

    # ID of the model configuration to use (required).
    model_id: str = Field(
        ...,
        description="模型配置 ID"
    )

    # Simple mode.
    prompt: str | None = Field(
        default=None,
        description="提示词模板(简单模式),支持变量引用"
    )

    # Message mode (recommended).
    messages: list[MessageConfig] | None = Field(
        default=None,
        description="消息列表(消息模式),支持多轮对话"
    )

    # Model parameters.
    temperature: float | None = Field(
        default=0.7,
        ge=0.0,
        le=2.0,
        description="温度参数,控制输出的随机性"
    )

    max_tokens: int | None = Field(
        default=1000,
        ge=1,
        le=32000,
        description="最大生成 token 数"
    )

    top_p: float | None = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description="Top-p 采样参数"
    )

    frequency_penalty: float | None = Field(
        default=None,
        ge=-2.0,
        le=2.0,
        description="频率惩罚"
    )

    presence_penalty: float | None = Field(
        default=None,
        ge=-2.0,
        le=2.0,
        description="存在惩罚"
    )

    # Output variable definitions: generated text plus token usage.
    output_variables: list[VariableDefinition] = Field(
        default_factory=lambda: [
            VariableDefinition(
                name="output",
                type=VariableType.STRING,
                description="LLM 生成的文本输出"
            ),
            VariableDefinition(
                name="token_usage",
                type=VariableType.OBJECT,
                description="Token 使用情况"
            )
        ],
        description="输出变量定义(自动生成,通常不需要修改)"
    )

    @field_validator("messages", "prompt")
    @classmethod
    def validate_input_mode(cls, v, info):
        """Placeholder validator: returns the value unchanged.

        It does NOT enforce that at least one of ``prompt`` / ``messages``
        is set; such a cross-field check belongs in a model-level validator.
        """
        return v
|
||||
247
api/app/core/workflow/nodes/llm/node.py
Normal file
247
api/app/core/workflow/nodes/llm/node.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
LLM 节点实现
|
||||
|
||||
调用 LLM 模型进行文本生成。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models import ModelConfig
|
||||
from app.db import get_db, get_db_context
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.services.model_service import ModelConfigService, ModelApiKeyService
|
||||
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMNode(BaseNode):
|
||||
"""LLM 节点
|
||||
|
||||
支持流式和非流式输出,使用 LangChain 标准的消息格式。
|
||||
|
||||
配置示例(支持多种消息格式):
|
||||
|
||||
1. 简单文本格式:
|
||||
{
|
||||
"type": "llm",
|
||||
"config": {
|
||||
"model_id": "uuid",
|
||||
"prompt": "请分析:{{sys.message}}",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 1000
|
||||
}
|
||||
}
|
||||
|
||||
2. LangChain 消息格式(推荐):
|
||||
{
|
||||
"type": "llm",
|
||||
"config": {
|
||||
"model_id": "uuid",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一个专业的 AI 助手。"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "{{sys.message}}"
|
||||
}
|
||||
],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 1000
|
||||
}
|
||||
}
|
||||
|
||||
支持的角色类型:
|
||||
- system: 系统消息(SystemMessage)
|
||||
- user/human: 用户消息(HumanMessage)
|
||||
- ai/assistant: AI 消息(AIMessage)
|
||||
"""
|
||||
|
||||
def _prepare_llm(self, state: WorkflowState) -> tuple[RedBearLLM, list | str]:
|
||||
"""准备 LLM 实例(公共逻辑)
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
(llm, messages_or_prompt): LLM 实例和消息列表或 prompt 字符串
|
||||
"""
|
||||
|
||||
# 1. 处理消息格式(优先使用 messages)
|
||||
messages_config = self.config.get("messages")
|
||||
|
||||
if messages_config:
|
||||
# 使用 LangChain 消息格式
|
||||
messages = []
|
||||
for msg_config in messages_config:
|
||||
role = msg_config.get("role", "user").lower()
|
||||
content_template = msg_config.get("content", "")
|
||||
content = self._render_template(content_template, state)
|
||||
|
||||
# 根据角色创建对应的消息对象
|
||||
if role == "system":
|
||||
messages.append(SystemMessage(content=content))
|
||||
elif role in ["user", "human"]:
|
||||
messages.append(HumanMessage(content=content))
|
||||
elif role in ["ai", "assistant"]:
|
||||
messages.append(AIMessage(content=content))
|
||||
else:
|
||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||
messages.append(HumanMessage(content=content))
|
||||
|
||||
prompt_or_messages = messages
|
||||
else:
|
||||
# 使用简单的 prompt 格式(向后兼容)
|
||||
prompt_template = self.config.get("prompt", "")
|
||||
prompt_or_messages = self._render_template(prompt_template, state)
|
||||
|
||||
# 2. 获取模型配置
|
||||
model_id = self.config.get("model_id")
|
||||
if not model_id:
|
||||
raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置")
|
||||
|
||||
# 3. 在 with 块内完成所有数据库操作和数据提取
|
||||
with get_db_context() as db:
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||
|
||||
if not config:
|
||||
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
|
||||
|
||||
if not config.api_keys or len(config.api_keys) == 0:
|
||||
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
||||
|
||||
# 在 Session 关闭前提取所有需要的数据
|
||||
api_config = config.api_keys[0]
|
||||
model_name = api_config.model_name
|
||||
provider = api_config.provider
|
||||
api_key = api_config.api_key
|
||||
api_base = api_config.api_base
|
||||
model_type = config.type
|
||||
|
||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||
llm = RedBearLLM(
|
||||
RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base
|
||||
),
|
||||
type=model_type
|
||||
)
|
||||
|
||||
return llm, prompt_or_messages
|
||||
|
||||
async def execute(self, state: WorkflowState) -> AIMessage:
|
||||
"""非流式执行 LLM 调用
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
LLM 响应消息
|
||||
"""
|
||||
llm, prompt_or_messages = self._prepare_llm(state)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||
|
||||
# 调用 LLM(支持字符串或消息列表)
|
||||
response = await llm.ainvoke(prompt_or_messages)
|
||||
|
||||
# 提取内容
|
||||
if hasattr(response, 'content'):
|
||||
content = response.content
|
||||
else:
|
||||
content = str(response)
|
||||
|
||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
|
||||
|
||||
# 返回 AIMessage(包含响应元数据)
|
||||
return response if isinstance(response, AIMessage) else AIMessage(content=content)
|
||||
|
||||
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
    """Build the input record that is logged for this node.

    Args:
        state: Workflow state used to render the prompt.

    Returns:
        A dict with ``prompt`` (set only for string prompts), ``messages``
        (set only for message lists, as role/content pairs) and ``config``
        (model id, temperature and max tokens read from the node config).
    """
    # Only the rendered prompt is needed here; the LLM instance is discarded.
    _, rendered = self._prepare_llm(state)

    prompt_text = None
    message_records = None
    if isinstance(rendered, str):
        prompt_text = rendered
    if isinstance(rendered, list):
        # The role is derived from the class name, e.g. HumanMessage -> "human".
        message_records = [
            {
                "role": type(item).__name__.replace("Message", "").lower(),
                "content": item.content,
            }
            for item in rendered
        ]

    settings = {
        "model_id": self.config.get("model_id"),
        "temperature": self.config.get("temperature"),
        "max_tokens": self.config.get("max_tokens"),
    }
    return {"prompt": prompt_text, "messages": message_records, "config": settings}
|
||||
|
||||
def _extract_output(self, business_result: Any) -> str:
    """Return the text of the node result.

    Args:
        business_result: Value produced by ``execute``.

    Returns:
        ``business_result.content`` for an ``AIMessage``; otherwise the
        string form of the value.
    """
    if not isinstance(business_result, AIMessage):
        return str(business_result)
    return business_result.content
|
||||
|
||||
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
    """Read token usage from the response metadata of an ``AIMessage``.

    Args:
        business_result: Value produced by ``execute``.

    Returns:
        A dict with ``prompt_tokens``, ``completion_tokens`` and
        ``total_tokens`` (each defaulting to 0), or ``None`` when the
        result is not an ``AIMessage`` or carries no ``token_usage`` entry.
    """
    has_metadata = isinstance(business_result, AIMessage) and hasattr(
        business_result, 'response_metadata'
    )
    if not has_metadata:
        return None

    usage = business_result.response_metadata.get('token_usage')
    if not usage:
        return None

    counters = ("prompt_tokens", "completion_tokens", "total_tokens")
    return {counter: usage.get(counter, 0) for counter in counters}
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
    """Run the LLM call in streaming mode.

    Args:
        state: Workflow state handed to ``_prepare_llm``.

    Yields:
        Each text fragment as it arrives, followed by one final marker
        dict ``{"__final__": True, "result": AIMessage}`` that carries
        the accumulated text.
    """
    model, payload = self._prepare_llm(state)

    logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")

    collected = ""      # full response text accumulated so far
    final_chunk = None  # last chunk seen; its metadata is reused below

    # astream accepts either a plain prompt string or a list of messages.
    async for piece in model.astream(payload):
        text = piece.content if hasattr(piece, 'content') else str(piece)

        collected += text
        final_chunk = piece

        # Forward every fragment to the consumer immediately.
        yield text

    logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(collected)}")

    # Assemble the complete message; response metadata is copied from the
    # last chunk only when that chunk is an AIMessage.
    message_kwargs = {"content": collected}
    if isinstance(final_chunk, AIMessage):
        message_kwargs["response_metadata"] = getattr(
            final_chunk, 'response_metadata', {}
        )

    # Completion marker consumed by the caller.
    yield {"__final__": True, "result": AIMessage(**message_kwargs)}
|
||||
93
api/app/core/workflow/nodes/node_factory.py
Normal file
93
api/app/core/workflow/nodes/node_factory.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
节点工厂
|
||||
|
||||
根据节点类型创建相应的节点实例。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.agent import AgentNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.end import EndNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NodeFactory:
    """Factory that maps a node type to its node class and instantiates it."""

    # Registry of known node types; extended through register_node_type().
    _node_types: dict[str, type[BaseNode]] = {
        NodeType.START: StartNode,
        NodeType.END: EndNode,
        NodeType.LLM: LLMNode,
        NodeType.AGENT: AgentNode,
        NodeType.TRANSFORM: TransformNode,
    }

    @classmethod
    def register_node_type(cls, node_type: str, node_class: type[BaseNode]):
        """Register (or replace) the class used for ``node_type``.

        Args:
            node_type: Node type name.
            node_class: Node class to instantiate for that type.

        Examples:
            >>> class CustomNode(BaseNode):
            ...     async def execute(self, state):
            ...         return {"node_outputs": {self.node_id: {"output": "custom"}}}
            >>> NodeFactory.register_node_type("custom", CustomNode)
        """
        cls._node_types[node_type] = node_class
        logger.info(f"注册节点类型: {node_type} -> {node_class.__name__}")

    @classmethod
    def create_node(
        cls,
        node_config: dict[str, Any],
        workflow_config: dict[str, Any]
    ) -> BaseNode | None:
        """Instantiate the node described by ``node_config``.

        Args:
            node_config: Node configuration; its ``type`` selects the class.
            workflow_config: Workflow configuration passed to the node.

        Returns:
            The node instance, or ``None`` for ``condition`` nodes, which
            are skipped here (the original comment says LangGraph handles
            them).

        Raises:
            ValueError: If the node type is not registered.
        """
        kind = node_config.get("type")

        # Condition nodes are not materialised by the factory.
        if kind == "condition":
            return None

        node_cls = cls._node_types.get(kind)
        if not node_cls:
            raise ValueError(f"不支持的节点类型: {kind}")

        logger.debug(f"创建节点: {node_config.get('id')} (type={kind})")
        return node_cls(node_config, workflow_config)

    @classmethod
    def get_supported_types(cls) -> list[str]:
        """Return the list of registered node types."""
        return [*cls._node_types]
|
||||
6
api/app/core/workflow/nodes/start/__init__.py
Normal file
6
api/app/core/workflow/nodes/start/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Start 节点"""
|
||||
|
||||
from app.core.workflow.nodes.start.node import StartNode
|
||||
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
|
||||
__all__ = ["StartNode", "StartNodeConfig"]
|
||||
87
api/app/core/workflow/nodes/start/config.py
Normal file
87
api/app/core/workflow/nodes/start/config.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Start 节点配置"""
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
|
||||
|
||||
class StartNodeConfig(BaseNodeConfig):
    """Configuration of a Start node.

    Declares the custom input variables of a workflow and the built-in
    output variables (message, conversation variables and system ids).
    """

    # Custom input variable definitions.
    variables: list[VariableDefinition] = Field(
        description="自定义输入变量列表,这些变量会作为 Start 节点的输出",
        default_factory=list
    )

    # Built-in output variable definitions, built from a (name, type,
    # description) table so every instance gets its own fresh list.
    output_variables: list[VariableDefinition] = Field(
        description="输出变量定义(自动生成,通常不需要修改)",
        default_factory=lambda: [
            VariableDefinition(name=var_name, type=var_type, description=var_text)
            for var_name, var_type, var_text in (
                ("message", VariableType.STRING, "用户输入的消息"),
                ("conversation_vars", VariableType.OBJECT, "会话级变量"),
                ("execution_id", VariableType.STRING, "执行 ID"),
                ("conversation_id", VariableType.STRING, "会话 ID"),
                ("workspace_id", VariableType.STRING, "工作空间 ID"),
                ("user_id", VariableType.STRING, "用户 ID"),
            )
        ]
    )

    class Config:
        # Example payloads attached to the generated JSON schema.
        json_schema_extra = {
            "examples": [
                {
                    "description": "工作流开始节点",
                    "variables": []
                },
                {
                    "description": "带自定义变量的开始节点",
                    "variables": [
                        {
                            "name": "language",
                            "type": "string",
                            "required": False,
                            "default": "zh-CN",
                            "description": "语言设置"
                        },
                        {
                            "name": "max_length",
                            "type": "number",
                            "required": False,
                            "default": 1000,
                            "description": "最大长度"
                        }
                    ]
                }
            ]
        }
|
||||
136
api/app/core/workflow/nodes/start/node.py
Normal file
136
api/app/core/workflow/nodes/start/node.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Start 节点实现
|
||||
|
||||
工作流的起始节点,定义输入变量并输出系统参数。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StartNode(BaseNode):
    """Entry node of a workflow.

    On execution it outputs the system variables read from the variable
    pool, the conversation variables, and the custom input variables
    declared in the node configuration (user-supplied values or defaults).
    """

    def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
        """Initialise the node and parse its typed configuration.

        Args:
            node_config: Node configuration.
            workflow_config: Workflow configuration.
        """
        super().__init__(node_config, workflow_config)

        # Parse and validate the raw config into a StartNodeConfig.
        self.typed_config = StartNodeConfig(**self.config)

    async def execute(self, state: WorkflowState) -> dict[str, Any]:
        """Produce the Start node output.

        Args:
            state: Workflow state.

        Returns:
            A dict with ``message``, ``conversation_vars``, ``execution_id``,
            ``conversation_id``, ``workspace_id`` and ``user_id``, plus one
            entry per resolved custom variable.

        Raises:
            ValueError: If a required custom variable is missing.
        """
        logger.info(f"节点 {self.node_id} (Start) 开始执行")

        # One pool instance is reused for every lookup in this method.
        pool = self.get_variable_pool(state)

        custom_vars = self._process_custom_variables(pool)

        result = {
            "message": pool.get("sys.message"),
            # StartNodeConfig declares conversation_vars as an output, so it
            # has to be emitted here alongside the system variables.
            "conversation_vars": pool.get_all_conversation_vars(),
            "execution_id": pool.get("sys.execution_id"),
            "conversation_id": pool.get("sys.conversation_id"),
            "workspace_id": pool.get("sys.workspace_id"),
            "user_id": pool.get("sys.user_id"),
            **custom_vars  # custom variables become part of the node output
        }

        logger.info(
            f"节点 {self.node_id} (Start) 执行完成,"
            f"输出了 {len(custom_vars)} 个自定义变量"
        )

        return result

    def _process_custom_variables(self, pool) -> dict[str, Any]:
        """Resolve the configured custom variables against the input data.

        For every configured variable the user-supplied value wins; a
        missing required variable raises; otherwise a non-``None`` default
        is applied. Variables with neither a value nor a default are omitted.

        Args:
            pool: Variable pool instance.

        Returns:
            Mapping of variable name to resolved value.

        Raises:
            ValueError: If a required variable is missing.
        """
        # ``or {}`` guards against the pool holding an explicit None, which
        # would otherwise break the membership test below.
        input_variables = pool.get("sys.input_variables", default={}) or {}

        processed = {}

        for var_def in self.typed_config.variables:
            var_name = var_def.name

            if var_name in input_variables:
                # Value supplied by the caller.
                processed[var_name] = input_variables[var_name]

            elif var_def.required:
                raise ValueError(
                    f"缺少必需的输入变量: {var_name}"
                    + (f" ({var_def.description})" if var_def.description else "")
                )

            elif var_def.default is not None:
                processed[var_name] = var_def.default
                logger.debug(
                    f"变量 '{var_name}' 使用默认值: {var_def.default}"
                )

        return processed

    def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
        """Build the input record that is logged for this node.

        Args:
            state: Workflow state.

        Returns:
            Dict with the execution id, conversation id, message and all
            conversation variables.
        """
        pool = self.get_variable_pool(state)

        return {
            "execution_id": pool.get("sys.execution_id"),
            "conversation_id": pool.get("sys.conversation_id"),
            "message": pool.get("sys.message"),
            "conversation_vars": pool.get_all_conversation_vars()
        }
|
||||
6
api/app/core/workflow/nodes/transform/__init__.py
Normal file
6
api/app/core/workflow/nodes/transform/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Transform 节点"""
|
||||
|
||||
from app.core.workflow.nodes.transform.node import TransformNode
|
||||
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
||||
|
||||
__all__ = ["TransformNode", "TransformNodeConfig"]
|
||||
80
api/app/core/workflow/nodes/transform/config.py
Normal file
80
api/app/core/workflow/nodes/transform/config.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Transform 节点配置"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
|
||||
|
||||
class TransformNodeConfig(BaseNodeConfig):
    """Configuration of a Transform node (data conversion and processing)."""

    transform_type: Literal["template", "code", "json"] = Field(
        description="转换类型:template(模板), code(代码), json(JSON处理)",
        default="template"
    )

    # Template mode
    template: str | None = Field(
        description="转换模板,支持变量引用",
        default=None
    )

    # Code mode
    code: str | None = Field(
        description="Python 代码,用于数据转换",
        default=None
    )

    # JSON mode
    json_path: str | None = Field(
        description="JSON 路径表达式",
        default=None
    )

    # Input variable mapping
    inputs: dict[str, str] | None = Field(
        description="输入变量映射,key 为变量名,value 为变量选择器",
        default=None
    )

    # Key under which the output is stored
    output_key: str = Field(
        description="输出变量的键名",
        default="result"
    )

    # Output variable definitions; the default is a single "result" string.
    output_variables: list[VariableDefinition] = Field(
        description="输出变量定义(根据 output_key 动态生成)",
        default_factory=lambda: [
            VariableDefinition(
                description="转换后的结果",
                type=VariableType.STRING,
                name="result"
            )
        ]
    )

    class Config:
        # Example payloads attached to the generated JSON schema.
        json_schema_extra = {
            "examples": [
                {
                    "transform_type": "template",
                    "template": "用户问题:{{ sys.message }}\n回答:{{ llm_qa.output }}",
                    "output_key": "formatted_result"
                },
                {
                    "transform_type": "code",
                    "code": "result = input_text.upper()",
                    "inputs": {
                        "input_text": "{{ sys.message }}"
                    },
                    "output_key": "uppercase_text"
                }
            ]
        }
|
||||
60
api/app/core/workflow/nodes/transform/node.py
Normal file
60
api/app/core/workflow/nodes/transform/node.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Transform 节点实现
|
||||
|
||||
数据转换节点,用于处理和转换数据。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TransformNode(BaseNode):
    """Data transformation node.

    Renders every template of the ``mapping`` config entry and returns
    the rendered values keyed by target field.

    Example configuration:
        {
            "type": "transform",
            "config": {
                "mapping": {
                    "output_field": "{{node.previous.output}}",
                    "processed": "{{var.input | upper}}"
                }
            }
        }
    """

    async def execute(self, state: WorkflowState) -> dict[str, Any]:
        """Render the configured mapping against the workflow state.

        Args:
            state: Workflow state.

        Returns:
            State update dict holding this node's output and a
            ``completed`` status under ``node_outputs``.
        """
        logger.info(f"节点 {self.node_id} 开始执行数据转换")

        # Each mapping value is coerced to str and rendered as a template.
        rendered = {
            field: self._render_template(str(template), state)
            for field, template in self.config.get("mapping", {}).items()
        }

        logger.info(f"节点 {self.node_id} 数据转换完成,输出字段: {list(rendered.keys())}")

        node_result = {"output": rendered, "status": "completed"}
        return {"node_outputs": {self.node_id: node_result}}
|
||||
170
api/app/core/workflow/template_loader.py
Normal file
170
api/app/core/workflow/template_loader.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
工作流模板加载器
|
||||
|
||||
从文件系统加载预定义的工作流模板
|
||||
"""
|
||||
|
||||
import os
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class TemplateLoader:
    """Load predefined workflow templates from a directory.

    Each template is a sub-directory of ``templates_dir`` containing a
    ``template.yml`` file and, optionally, a ``README.md``. The directory
    name is the template id.
    """

    def __init__(self, templates_dir: str = "app/templates/workflows"):
        """Initialise the loader.

        Args:
            templates_dir: Path of the directory holding the templates.

        Raises:
            ValueError: If ``templates_dir`` does not exist.
        """
        self.templates_dir = Path(templates_dir)
        if not self.templates_dir.exists():
            raise ValueError(f"模板目录不存在: {templates_dir}")

    @staticmethod
    def _is_safe_template_id(template_id: str) -> bool:
        """Return True when ``template_id`` is a single directory name.

        Ids containing path separators, ``.``/``..`` or nothing at all are
        rejected so a caller-supplied id cannot escape ``templates_dir``.
        """
        if not template_id or template_id in (".", ".."):
            return False
        return Path(template_id).name == template_id

    @staticmethod
    def _read_template_file(template_file: Path) -> dict:
        """Parse ``template_file`` and return its top-level mapping.

        Raises:
            ValueError: If the YAML document is empty or not a mapping.
        """
        with open(template_file, 'r', encoding='utf-8') as f:
            template_data = yaml.safe_load(f)
        if not isinstance(template_data, dict):
            # An empty file parses to None; fail with a clear message
            # instead of an AttributeError on ``.get``.
            raise ValueError("template.yml 的顶层内容必须是映射")
        return template_data

    def list_templates(self) -> list[dict]:
        """List all available templates.

        Returns:
            One dict per template with ``id``, ``name``, ``description``,
            ``category``, ``tags``, ``author`` and ``version``. Templates
            that fail to load are reported on stdout and skipped.
        """
        templates = []

        # Sorted so the listing order does not depend on the filesystem.
        for template_dir in sorted(self.templates_dir.iterdir()):
            if not template_dir.is_dir():
                continue

            # Only directories that ship a template.yml are templates.
            template_file = template_dir / "template.yml"
            if not template_file.exists():
                continue

            try:
                template_data = self._read_template_file(template_file)

                templates.append({
                    "id": template_dir.name,
                    "name": template_data.get("name", template_dir.name),
                    "description": template_data.get("description", ""),
                    "category": template_data.get("category", "general"),
                    "tags": template_data.get("tags", []),
                    "author": template_data.get("author", ""),
                    "version": template_data.get("version", "1.0.0")
                })
            except Exception as e:
                # Best effort: one broken template must not hide the others.
                print(f"加载模板 {template_dir.name} 失败: {e}")
                continue

        return templates

    def load_template(self, template_id: str) -> Optional[dict]:
        """Load the workflow configuration of one template.

        Args:
            template_id: Template id (directory name).

        Returns:
            Dict with ``name``, ``description``, ``nodes``, ``edges``,
            ``variables``, ``execution_config`` and ``triggers``; ``None``
            if the id is invalid, the template does not exist or it fails
            to load.
        """
        if not self._is_safe_template_id(template_id):
            return None

        template_file = self.templates_dir / template_id / "template.yml"

        if not template_file.exists():
            return None

        try:
            template_data = self._read_template_file(template_file)

            # Only the workflow-configuration part is returned.
            return {
                "name": template_data.get("name", template_id),
                "description": template_data.get("description", ""),
                "nodes": template_data.get("nodes", []),
                "edges": template_data.get("edges", []),
                "variables": template_data.get("variables", []),
                "execution_config": template_data.get("execution_config", {}),
                "triggers": template_data.get("triggers", [])
            }
        except Exception as e:
            print(f"加载模板 {template_id} 失败: {e}")
            return None

    def get_template_readme(self, template_id: str) -> Optional[str]:
        """Return the README of a template.

        Args:
            template_id: Template id (directory name).

        Returns:
            The README text; ``None`` if the id is invalid, the file does
            not exist or it cannot be read.
        """
        if not self._is_safe_template_id(template_id):
            return None

        readme_file = self.templates_dir / template_id / "README.md"

        if not readme_file.exists():
            return None

        try:
            with open(readme_file, 'r', encoding='utf-8') as f:
                return f.read()
        except Exception as e:
            print(f"读取模板 {template_id} 的 README 失败: {e}")
            return None
|
||||
|
||||
|
||||
# 全局模板加载器实例
|
||||
_template_loader: Optional[TemplateLoader] = None
|
||||
|
||||
|
||||
def get_template_loader() -> TemplateLoader:
    """Return the module-wide ``TemplateLoader``, creating it on first use.

    Returns:
        The shared TemplateLoader instance.
    """
    global _template_loader
    if _template_loader is not None:
        return _template_loader
    _template_loader = TemplateLoader()
    return _template_loader
|
||||
|
||||
|
||||
def list_workflow_templates() -> list[dict]:
    """List all workflow templates using the shared loader.

    Returns:
        The template list.
    """
    return get_template_loader().list_templates()
|
||||
|
||||
|
||||
def load_workflow_template(template_id: str) -> Optional[dict]:
    """Load one workflow template using the shared loader.

    Args:
        template_id: Template id.

    Returns:
        The template configuration, or ``None`` if it does not exist.
    """
    return get_template_loader().load_template(template_id)
|
||||
|
||||
|
||||
def get_workflow_template_readme(template_id: str) -> Optional[str]:
    """Return the README of a workflow template using the shared loader.

    Args:
        template_id: Template id.

    Returns:
        The README text, or ``None`` if it does not exist.
    """
    return get_template_loader().get_template_readme(template_id)
|
||||
170
api/app/core/workflow/template_renderer.py
Normal file
170
api/app/core/workflow/template_renderer.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
模板渲染器
|
||||
|
||||
使用 Jinja2 提供安全的模板渲染功能,支持变量引用和表达式。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from jinja2 import Template, TemplateSyntaxError, UndefinedError, Environment, StrictUndefined
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TemplateRenderer:
    """Render workflow templates with Jinja2.

    Templates can reach data through the ``var``, ``node`` (alias
    ``nodes``) and ``sys`` namespaces, and also directly by node id or
    variable name.
    """

    def __init__(self, strict: bool = True):
        """Create the Jinja2 environment.

        Args:
            strict: When True, referencing an undefined variable raises;
                when False, Jinja2's default undefined handling is used.
        """
        # Autoescaping stays off because the output is plain text, not HTML.
        env_options: dict[str, Any] = {"autoescape": False}
        if strict:
            env_options["undefined"] = StrictUndefined
        # In non-strict mode the option is omitted so Jinja2 falls back to
        # its default Undefined class; passing undefined=None is rejected
        # by Environment.

        # NOTE(review): this is a plain (unsandboxed) Environment. If
        # template strings can come from untrusted users, consider
        # jinja2.sandbox.SandboxedEnvironment — confirm the threat model.
        self.env = Environment(**env_options)

    def render(
        self,
        template: str,
        variables: dict[str, Any],
        node_outputs: dict[str, Any],
        system_vars: dict[str, Any] | None = None
    ) -> str:
        """Render ``template`` against the given data.

        Args:
            template: Template string.
            variables: User-defined variables.
            node_outputs: Node outputs keyed by node id.
            system_vars: System variables.

        Returns:
            The rendered string.

        Raises:
            ValueError: On template syntax errors, undefined variables or
                any other rendering failure (the original exception is
                chained as the cause).

        Examples:
            >>> renderer = TemplateRenderer()
            >>> renderer.render(
            ...     "Hello {{var.name}}!",
            ...     {"name": "World"},
            ...     {},
            ...     {}
            ... )
            'Hello World!'

            >>> renderer.render(
            ...     "分析结果: {{node.analyze.output}}",
            ...     {},
            ...     {"analyze": {"output": "正面情绪"}},
            ...     {}
            ... )
            '分析结果: 正面情绪'
        """
        context: dict[str, Any] = {}

        # Direct access by node id: {{llm_qa.output}}
        context.update(node_outputs)

        # Backward compatibility: direct access to user variables, which
        # take precedence over a node id of the same name.
        context.update(variables)

        # Reserved namespaces are applied last so that a node id or user
        # variable named "var", "node", "nodes" or "sys" cannot shadow them.
        context.update({
            "var": variables,          # {{var.user_input}}
            "node": node_outputs,      # {{node.node_1.output}}
            "nodes": node_outputs,     # legacy alias
            "sys": system_vars or {},  # {{sys.execution_id}}
        })

        try:
            tmpl = self.env.from_string(template)
            return tmpl.render(**context)

        except TemplateSyntaxError as e:
            logger.error(f"模板语法错误: {template}, 错误: {e}")
            raise ValueError(f"模板语法错误: {e}") from e

        except UndefinedError as e:
            logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}")
            raise ValueError(f"未定义的变量: {e}") from e

        except Exception as e:
            logger.error(f"模板渲染异常: {template}, 错误: {e}")
            raise ValueError(f"模板渲染失败: {e}") from e

    def validate(self, template: str) -> list[str]:
        """Check the syntax of ``template`` without rendering it.

        Args:
            template: Template string.

        Returns:
            List of error messages; empty when the template compiles.

        Examples:
            >>> renderer = TemplateRenderer()
            >>> renderer.validate("Hello {{var.name}}!")
            []

            >>> renderer.validate("Hello {{var.name")  # missing end tag
            ['模板语法错误: ...']
        """
        errors = []

        try:
            self.env.from_string(template)
        except TemplateSyntaxError as e:
            errors.append(f"模板语法错误: {e}")
        except Exception as e:
            errors.append(f"模板验证失败: {e}")

        return errors
|
||||
|
||||
|
||||
# 全局渲染器实例(严格模式)
|
||||
_default_renderer = TemplateRenderer(strict=True)
|
||||
|
||||
|
||||
def render_template(
    template: str,
    variables: dict[str, Any],
    node_outputs: dict[str, Any],
    system_vars: dict[str, Any] | None = None
) -> str:
    """Render a template with the module-level strict renderer.

    Args:
        template: Template string.
        variables: User variables.
        node_outputs: Node outputs.
        system_vars: System variables.

    Returns:
        The rendered string.

    Examples:
        >>> render_template(
        ...     "请分析: {{var.text}}",
        ...     {"text": "这是一段文本"},
        ...     {},
        ...     {}
        ... )
        '请分析: 这是一段文本'
    """
    rendered = _default_renderer.render(
        template,
        variables,
        node_outputs,
        system_vars,
    )
    return rendered
|
||||
|
||||
|
||||
def validate_template(template: str) -> list[str]:
    """Check template syntax with the module-level strict renderer.

    Args:
        template: Template string.

    Returns:
        List of error messages; empty when the template is valid.
    """
    problems = _default_renderer.validate(template)
    return problems
|
||||
277
api/app/core/workflow/validator.py
Normal file
277
api/app/core/workflow/validator.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""
|
||||
工作流配置验证器
|
||||
|
||||
验证工作流配置的有效性,确保配置符合规范。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowValidator:
    """Structural validator for workflow configurations."""

    @staticmethod
    def _extract_sections(workflow_config: Any) -> tuple[list, list, list]:
        """Return ``(nodes, edges, variables)`` from a workflow config.

        Accepts a dict or any object exposing these as attributes (e.g. a
        Pydantic model). Missing or ``None`` sections become empty lists.
        """
        if isinstance(workflow_config, dict):
            read = workflow_config.get
        else:
            def read(name):
                return getattr(workflow_config, name, None)

        nodes = read("nodes") or []
        edges = read("edges") or []
        variables = read("variables") or []
        return nodes, edges, variables

    @staticmethod
    def validate(workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
        """Validate a workflow configuration.

        Checks, in order: exactly one start node, at least one end node,
        unique node ids, every node has ``id`` and ``type``, every edge
        references existing nodes, every node is reachable from the start
        node, there is no cycle (outside ``loop`` nodes), and variable
        names are valid. Reachability and cycle checks only run when no
        earlier error was found.

        Args:
            workflow_config: Workflow configuration dict or an object with
                ``nodes`` / ``edges`` / ``variables`` attributes.

        Returns:
            (is_valid, errors): validity flag and the list of error messages.

        Examples:
            >>> config = {
            ...     "nodes": [
            ...         {"id": "start", "type": "start"},
            ...         {"id": "end", "type": "end"}
            ...     ],
            ...     "edges": [
            ...         {"source": "start", "target": "end"}
            ...     ]
            ... }
            >>> is_valid, errors = WorkflowValidator.validate(config)
            >>> is_valid
            True
        """
        from collections import Counter

        errors: list[str] = []
        nodes, edges, variables = WorkflowValidator._extract_sections(workflow_config)

        # 1. Exactly one start node.
        start_nodes = [n for n in nodes if n.get("type") == "start"]
        if len(start_nodes) == 0:
            errors.append("工作流必须有一个 start 节点")
        elif len(start_nodes) > 1:
            errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个")

        # 2. At least one end node.
        if not any(n.get("type") == "end" for n in nodes):
            errors.append("工作流必须至少有一个 end 节点")

        # 3. Node ids must be unique (counted in a single pass).
        node_ids = [n.get("id") for n in nodes]
        duplicates = {nid for nid, count in Counter(node_ids).items() if count > 1}
        if duplicates:
            errors.append(f"节点 ID 必须唯一,重复的 ID: {duplicates}")

        # 4. Every node needs an id and a type.
        for i, node in enumerate(nodes):
            if not node.get("id"):
                errors.append(f"节点 #{i} 缺少 id 字段")
            if not node.get("type"):
                errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段")

        # 5. Edges must reference existing nodes.
        node_id_set = set(node_ids)
        for i, edge in enumerate(edges):
            source = edge.get("source")
            target = edge.get("target")

            if not source:
                errors.append(f"边 #{i} 缺少 source 字段")
            elif source not in node_id_set:
                errors.append(f"边 #{i} 的 source 节点不存在: {source}")

            if not target:
                errors.append(f"边 #{i} 缺少 target 字段")
            elif target not in node_id_set:
                errors.append(f"边 #{i} 的 target 节点不存在: {target}")

        # 6. Every node must be reachable from the start node. Skipped when
        # earlier checks failed, since the graph is not trustworthy then.
        if start_nodes and not errors:
            reachable = WorkflowValidator._get_reachable_nodes(
                start_nodes[0]["id"],
                edges
            )
            unreachable = node_id_set - reachable
            if unreachable:
                errors.append(f"以下节点无法从 start 节点到达: {unreachable}")

        # 7. No cycles outside loop nodes (also only on a clean graph).
        if not errors:
            has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
            if has_cycle:
                errors.append(
                    f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
                )

        # 8. Variable names. Imported here rather than at module level
        # (presumably to avoid an import cycle — TODO confirm).
        from app.core.workflow.expression_evaluator import ExpressionEvaluator
        var_errors = ExpressionEvaluator.validate_variable_names(variables)
        errors.extend(var_errors)

        return len(errors) == 0, errors

    @staticmethod
    def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
        """Return the ids of all nodes reachable from ``start_id``.

        Args:
            start_id: Id of the node to start from.
            edges: Edge list; each edge has ``source`` and ``target``.

        Returns:
            Set of reachable node ids, including ``start_id``.
        """
        from collections import deque

        # Build the adjacency map once instead of rescanning every edge
        # for every visited node.
        adjacency: dict[Any, list[str]] = {}
        for edge in edges:
            target = edge.get("target")
            if target:
                adjacency.setdefault(edge.get("source"), []).append(target)

        reachable = {start_id}
        queue = deque([start_id])

        while queue:
            current = queue.popleft()
            for target in adjacency.get(current, ()):
                if target not in reachable:
                    reachable.add(target)
                    queue.append(target)

        return reachable

    @staticmethod
    def _has_cycle(nodes: list[dict], edges: list[dict]) -> tuple[bool, list[str]]:
        """Detect a cycle with a depth-first search.

        Edges of type ``error`` and edges touching a ``loop`` node are
        ignored.

        Args:
            nodes: Node list.
            edges: Edge list.

        Returns:
            (has_cycle, cycle_path): whether a cycle exists and, if so, the
            node ids along it with the first node repeated at the end.
        """
        loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}

        # Adjacency list without error edges and loop-node edges.
        graph: dict[str, list[str]] = {}
        for edge in edges:
            source = edge.get("source")
            target = edge.get("target")

            if edge.get("type") == "error":
                continue
            if source in loop_nodes or target in loop_nodes:
                continue
            if source and target:
                graph.setdefault(source, []).append(target)

        visited = set()
        rec_stack = set()   # nodes on the current DFS path
        path = []           # same nodes, in visiting order
        cycle_path = []

        def dfs(node: str) -> bool:
            """Visit ``node``; return True as soon as a cycle is found."""
            visited.add(node)
            rec_stack.add(node)
            path.append(node)

            for neighbor in graph.get(node, []):
                if neighbor not in visited:
                    if dfs(neighbor):
                        return True
                elif neighbor in rec_stack:
                    # Back edge: the cycle is the path suffix starting at
                    # the neighbor, closed by the neighbor itself.
                    cycle_start = path.index(neighbor)
                    cycle_path.extend([*path[cycle_start:], neighbor])
                    return True

            rec_stack.remove(node)
            path.pop()
            return False

        for node_id in graph:
            if node_id not in visited:
                if dfs(node_id):
                    return True, cycle_path

        return False, []

    @staticmethod
    def validate_for_publish(workflow_config: dict[str, Any]) -> tuple[bool, list[str]]:
        """Validate a workflow configuration for publishing (stricter).

        Runs ``validate`` first; if that passes, additionally requires a
        name and a dict ``config`` on every node other than start/end.

        Args:
            workflow_config: Workflow configuration (dict or object with
                ``nodes`` / ``variables`` attributes, like ``validate``).

        Returns:
            (is_valid, errors): validity flag and the list of error messages.
        """
        is_valid, errors = WorkflowValidator.validate(workflow_config)

        if not is_valid:
            return False, errors

        # Read the sections the same way validate() does so attribute-style
        # configs are accepted here as well.
        nodes, _, variables = WorkflowValidator._extract_sections(workflow_config)

        # 1. Every non start/end node needs a name.
        for node in nodes:
            if node.get("type") not in ["start", "end"] and not node.get("name"):
                errors.append(
                    f"节点 {node.get('id')} 缺少名称(发布时必须提供)"
                )

        # 2. Every non start/end node needs a dict config.
        for node in nodes:
            node_type = node.get("type")
            if node_type not in ["start", "end"]:
                config = node.get("config")
                if not config or not isinstance(config, dict):
                    errors.append(
                        f"节点 {node.get('id')} 缺少配置(发布时必须提供)"
                    )

        # 3. Required variables are only logged here, not validated.
        required_vars = [v for v in variables if v.get("required")]
        if required_vars:
            logger.info(
                f"工作流包含 {len(required_vars)} 个必填变量: "
                f"{[v.get('name') for v in required_vars]}"
            )

        return len(errors) == 0, errors
|
||||
|
||||
|
||||
def validate_workflow_config(
    workflow_config: dict[str, Any],
    for_publish: bool = False
) -> tuple[bool, list[str]]:
    """Validate a workflow configuration (convenience wrapper).

    Args:
        workflow_config: Workflow configuration.
        for_publish: When true, use the stricter publish validation
            (``WorkflowValidator.validate_for_publish``); otherwise use
            ``WorkflowValidator.validate``.

    Returns:
        (is_valid, errors): validity flag and the list of error messages.
    """
    # Pick the validator once, then delegate.
    checker = (
        WorkflowValidator.validate_for_publish
        if for_publish
        else WorkflowValidator.validate
    )
    return checker(workflow_config)
|
||||
293
api/app/core/workflow/variable_pool.py
Normal file
293
api/app/core/workflow/variable_pool.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""
|
||||
变量池 (Variable Pool)
|
||||
|
||||
工作流执行的数据中心,管理所有变量的存储和访问。
|
||||
|
||||
变量类型:
|
||||
1. 系统变量 (sys.*) - 系统内置变量(execution_id, workspace_id, user_id, message 等)
|
||||
2. 节点输出 (node_id.*) - 节点执行结果
|
||||
3. 会话变量 (conv.*) - 会话级变量(跨多轮对话保持)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
# Module-level logger named after this module; VariablePool.set writes debug output to it.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VariableSelector:
    """Path-style reference to a variable.

    A selector is a list of path segments; the first segment is the
    namespace and the optional second segment is the key.

    Examples:
        >>> selector = VariableSelector(["sys", "message"])
        >>> selector = VariableSelector(["node_A", "output"])
        >>> selector = VariableSelector.from_string("sys.message")
    """

    def __init__(self, path: list[str]):
        """Create a selector from a list of path segments.

        Args:
            path: Variable path, e.g. ["sys", "message"] or ["node_A", "output"].

        Raises:
            ValueError: If ``path`` is empty.
        """
        if not path:
            raise ValueError("变量路径不能为空")

        self.path = path
        # First segment is the namespace (e.g. sys, conv, or a node id).
        self.namespace = path[0]
        # Second segment, when present, is the key inside that namespace.
        if len(path) > 1:
            self.key = path[1]
        else:
            self.key = None

    @classmethod
    def from_string(cls, selector_str: str) -> "VariableSelector":
        """Build a selector by splitting a dotted string on ".".

        Args:
            selector_str: Selector string such as "sys.message" or "node_A.output".

        Returns:
            A VariableSelector instance.

        Examples:
            >>> selector = VariableSelector.from_string("sys.message")
            >>> selector = VariableSelector.from_string("llm_qa.output")
        """
        segments = selector_str.split(".")
        return cls(segments)

    def __str__(self) -> str:
        """Return the dotted form of the path."""
        separator = "."
        return separator.join(self.path)

    def __repr__(self) -> str:
        """Return a debug representation showing the path list."""
        return "VariableSelector({})".format(self.path)
|
||||
|
||||
|
||||
class VariablePool:
|
||||
"""变量池
|
||||
|
||||
管理工作流执行过程中的所有变量。
|
||||
|
||||
变量命名空间:
|
||||
- sys.*: 系统变量(message, execution_id, workspace_id, user_id, conversation_id)
|
||||
- conv.*: 会话变量(跨多轮对话保持的变量)
|
||||
- <node_id>.*: 节点输出
|
||||
|
||||
Examples:
|
||||
>>> pool = VariablePool(state)
|
||||
>>> pool.get(["sys", "message"])
|
||||
"用户的问题"
|
||||
>>> pool.get(["llm_qa", "output"])
|
||||
"AI 的回答"
|
||||
>>> pool.set(["conv", "user_name"], "张三")
|
||||
"""
|
||||
|
||||
def __init__(self, state: dict[str, Any]):
|
||||
"""初始化变量池
|
||||
|
||||
Args:
|
||||
state: 工作流状态(LangGraph State)
|
||||
"""
|
||||
self.state = state
|
||||
|
||||
def get(self, selector: list[str] | str, default: Any = None) -> Any:
|
||||
"""获取变量值
|
||||
|
||||
Args:
|
||||
selector: 变量选择器,可以是列表或字符串
|
||||
default: 默认值(变量不存在时返回)
|
||||
|
||||
Returns:
|
||||
变量值
|
||||
|
||||
Examples:
|
||||
>>> pool.get(["sys", "message"])
|
||||
>>> pool.get("sys.message")
|
||||
>>> pool.get(["llm_qa", "output"])
|
||||
>>> pool.get("llm_qa.output")
|
||||
|
||||
Raises:
|
||||
KeyError: 变量不存在且未提供默认值
|
||||
"""
|
||||
# 转换为 VariableSelector
|
||||
if isinstance(selector, str):
|
||||
selector = VariableSelector.from_string(selector).path
|
||||
|
||||
if not selector or len(selector) < 1:
|
||||
raise ValueError("变量选择器不能为空")
|
||||
|
||||
namespace = selector[0]
|
||||
|
||||
try:
|
||||
# 系统变量
|
||||
if namespace == "sys":
|
||||
key = selector[1] if len(selector) > 1 else None
|
||||
if not key:
|
||||
return self.state.get("variables", {}).get("sys", {})
|
||||
return self.state.get("variables", {}).get("sys", {}).get(key, default)
|
||||
|
||||
# 会话变量
|
||||
elif namespace == "conv":
|
||||
key = selector[1] if len(selector) > 1 else None
|
||||
if not key:
|
||||
return self.state.get("variables", {}).get("conv", {})
|
||||
return self.state.get("variables", {}).get("conv", {}).get(key, default)
|
||||
|
||||
# 节点输出(从 runtime_vars 读取)
|
||||
else:
|
||||
node_id = namespace
|
||||
runtime_vars = self.state.get("runtime_vars", {})
|
||||
|
||||
if node_id not in runtime_vars:
|
||||
if default is not None:
|
||||
return default
|
||||
raise KeyError(f"节点 '{node_id}' 的输出不存在")
|
||||
|
||||
node_var = runtime_vars[node_id]
|
||||
|
||||
# 如果只有节点 ID,返回整个变量
|
||||
if len(selector) == 1:
|
||||
return node_var
|
||||
|
||||
# 获取特定字段
|
||||
# 支持嵌套访问,如 node_id.field.subfield
|
||||
result = node_var
|
||||
for k in selector[1:]:
|
||||
if isinstance(result, dict):
|
||||
result = result.get(k)
|
||||
if result is None:
|
||||
if default is not None:
|
||||
return default
|
||||
raise KeyError(f"字段 '{'.'.join(selector)}' 不存在")
|
||||
else:
|
||||
if default is not None:
|
||||
return default
|
||||
raise KeyError(f"无法访问 '{'.'.join(selector)}'")
|
||||
|
||||
return result
|
||||
|
||||
except KeyError:
|
||||
if default is not None:
|
||||
return default
|
||||
raise
|
||||
|
||||
def set(self, selector: list[str] | str, value: Any):
|
||||
"""设置变量值
|
||||
|
||||
Args:
|
||||
selector: 变量选择器
|
||||
value: 变量值
|
||||
|
||||
Examples:
|
||||
>>> pool.set(["conv", "user_name"], "张三")
|
||||
>>> pool.set("conv.user_name", "张三")
|
||||
|
||||
Note:
|
||||
- 只能设置会话变量 (conv.*)
|
||||
- 系统变量和节点输出是只读的
|
||||
"""
|
||||
# 转换为 VariableSelector
|
||||
if isinstance(selector, str):
|
||||
selector = VariableSelector.from_string(selector).path
|
||||
|
||||
if not selector or len(selector) < 2:
|
||||
raise ValueError("变量选择器必须包含命名空间和键名")
|
||||
|
||||
namespace = selector[0]
|
||||
|
||||
if namespace != "conv":
|
||||
raise ValueError("只能设置会话变量 (conv.*)")
|
||||
|
||||
key = selector[1]
|
||||
|
||||
# 确保 variables 结构存在
|
||||
if "variables" not in self.state:
|
||||
self.state["variables"] = {"sys": {}, "conv": {}}
|
||||
if "conv" not in self.state["variables"]:
|
||||
self.state["variables"]["conv"] = {}
|
||||
|
||||
# 设置值
|
||||
self.state["variables"]["conv"][key] = value
|
||||
|
||||
logger.debug(f"设置变量: {'.'.join(selector)} = {value}")
|
||||
|
||||
def has(self, selector: list[str] | str) -> bool:
|
||||
"""检查变量是否存在
|
||||
|
||||
Args:
|
||||
selector: 变量选择器
|
||||
|
||||
Returns:
|
||||
变量是否存在
|
||||
|
||||
Examples:
|
||||
>>> pool.has(["sys", "message"])
|
||||
True
|
||||
>>> pool.has("llm_qa.output")
|
||||
False
|
||||
"""
|
||||
try:
|
||||
self.get(selector)
|
||||
return True
|
||||
except KeyError:
|
||||
return False
|
||||
|
||||
def get_all_system_vars(self) -> dict[str, Any]:
|
||||
"""获取所有系统变量
|
||||
|
||||
Returns:
|
||||
系统变量字典
|
||||
"""
|
||||
return self.state.get("variables", {}).get("sys", {})
|
||||
|
||||
def get_all_conversation_vars(self) -> dict[str, Any]:
|
||||
"""获取所有会话变量
|
||||
|
||||
Returns:
|
||||
会话变量字典
|
||||
"""
|
||||
return self.state.get("variables", {}).get("conv", {})
|
||||
|
||||
def get_all_node_outputs(self) -> dict[str, Any]:
|
||||
"""获取所有节点输出(运行时变量)
|
||||
|
||||
Returns:
|
||||
节点输出字典,键为节点 ID
|
||||
"""
|
||||
return self.state.get("runtime_vars", {})
|
||||
|
||||
def get_node_output(self, node_id: str) -> dict[str, Any] | None:
|
||||
"""获取指定节点的输出(运行时变量)
|
||||
|
||||
Args:
|
||||
node_id: 节点 ID
|
||||
|
||||
Returns:
|
||||
节点输出或 None
|
||||
"""
|
||||
return self.state.get("runtime_vars", {}).get(node_id)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""导出为字典
|
||||
|
||||
Returns:
|
||||
包含所有变量的字典
|
||||
"""
|
||||
return {
|
||||
"system": self.get_all_system_vars(),
|
||||
"conversation": self.get_all_conversation_vars(),
|
||||
"nodes": self.get_all_node_outputs() # 从 runtime_vars 读取
|
||||
}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
sys_vars = self.get_all_system_vars()
|
||||
conv_vars = self.get_all_conversation_vars()
|
||||
runtime_vars = self.get_all_node_outputs()
|
||||
|
||||
return (
|
||||
f"VariablePool(\n"
|
||||
f" system_vars={len(sys_vars)},\n"
|
||||
f" conversation_vars={len(conv_vars)},\n"
|
||||
f" runtime_vars={len(runtime_vars)}\n"
|
||||
f")"
|
||||
)
|
||||
Reference in New Issue
Block a user