[ADD] Merge code

This commit is contained in:
Mark
2025-12-15 19:50:21 +08:00
parent ea0a445d5b
commit 7bbef35b7d
54 changed files with 6956 additions and 652 deletions

View File

@@ -27,6 +27,7 @@ from . import (
release_share_controller,
public_share_controller,
multi_agent_controller,
workflow_controller,
)
# 创建管理端 API 路由器
@@ -56,5 +57,6 @@ manager_router.include_router(release_share_controller.router)
manager_router.include_router(public_share_controller.router) # 公开路由(无需认证)
manager_router.include_router(memory_dashboard_controller.router)
manager_router.include_router(multi_agent_controller.router)
manager_router.include_router(workflow_controller.router)
__all__ = ["manager_router"]

View File

@@ -1,7 +1,6 @@
"""API Key 管理接口 - 基于 JWT 认证"""
import uuid
from typing import Optional
from datetime import datetime
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
@@ -14,6 +13,7 @@ from app.core.response_utils import success
from app.schemas import api_key_schema
from app.schemas.response_schema import ApiResponse
from app.services.api_key_service import ApiKeyService
from app.core.api_key_utils import timestamp_to_datetime
from app.core.logging_config import get_api_logger
from app.core.exceptions import (
BusinessException,
@@ -41,18 +41,14 @@ def create_api_key(
workspace_id = current_user.current_workspace_id
# 创建 API Key
api_key_obj, api_key = ApiKeyService.create_api_key(
api_key_obj = ApiKeyService.create_api_key(
db,
workspace_id=workspace_id,
user_id=current_user.id,
data=data
)
# 返回包含明文 Key 的响应
response_data = api_key_schema.ApiKeyResponse(
**api_key_obj.__dict__,
api_key=api_key
)
response_data = api_key_schema.ApiKeyResponse.model_validate(api_key_obj)
return success(data=response_data, msg="API Key 创建成功")
except BusinessException:
@@ -223,13 +219,9 @@ def regenerate_api_key(
"""
try:
workspace_id = current_user.current_workspace_id
api_key_obj, api_key = ApiKeyService.regenerate_api_key(db, api_key_id, workspace_id)
api_key_obj = ApiKeyService.regenerate_api_key(db, api_key_id, workspace_id)
# 返回包含明文 Key 的响应
response_data = api_key_schema.ApiKeyResponse(
**api_key_obj.__dict__,
api_key=api_key
)
response_data = api_key_schema.ApiKeyResponse.model_validate(api_key_obj)
logger.info("API Key 重新生成成功", extra={
"api_key_id": str(api_key_id),
@@ -283,8 +275,8 @@ def get_api_key_stats(
@cur_workspace_access_guard()
def get_api_key_logs(
api_key_id: uuid.UUID,
start_date: Optional[datetime] = Query(None, description="开始日期"),
end_date: Optional[datetime] = Query(None, description="结束日期"),
start_date: Optional[int] = Query(None, description="开始日期时间戳"),
end_date: Optional[int] = Query(None, description="结束日期时间戳"),
status_code: Optional[int] = Query(None, description="HTTP状态码过滤"),
endpoint: Optional[str] = Query(None, description="端点路径过滤"),
page: int = Query(1, ge=1, description="页码"),
@@ -302,14 +294,17 @@ def get_api_key_logs(
try:
workspace_id = current_user.current_workspace_id
start_datetime = timestamp_to_datetime(start_date) if start_date else None
end_datetime = timestamp_to_datetime(end_date) if end_date else None
# 验证日期范围
if start_date and end_date and start_date > end_date:
if start_datetime and end_datetime and start_datetime > end_datetime:
logger.warning("开始日期晚于结束日期", extra={
"api_key_id": str(api_key_id),
"workspace_id": str(workspace_id),
"user_id": str(current_user.id),
"start_date": start_date.isoformat(),
"end_date": end_date.isoformat()
"start_date": start_datetime.isoformat(),
"end_date": end_datetime.isoformat()
})
raise BusinessException("开始日期不能晚于结束日期", BizCode.INVALID_PARAMETER)
@@ -325,8 +320,8 @@ def get_api_key_logs(
# 构建过滤条件
filters = {
"start_date": start_date,
"end_date": end_date,
"start_date": start_datetime,
"end_date": end_datetime,
"status_code": status_code,
"endpoint": endpoint
}

View File

@@ -1,22 +1,26 @@
import uuid
from typing import Optional
from fastapi import APIRouter, Depends
from typing import Optional, Annotated
from fastapi import APIRouter, Depends, Path
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from app.db import get_db
from app.core.response_utils import success
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.models import User
from app.models.app_model import AppType, App
from app.repositories import knowledge_repository
from app.schemas import app_schema
from app.schemas.response_schema import PageData, PageMeta
from app.schemas.workflow_schema import WorkflowConfigUpdate
from app.services import app_service, workspace_service
from app.services.app_service import AppService
from app.services.agent_config_helper import enrich_agent_config
from app.dependencies import get_current_user, cur_workspace_access_guard, workspace_access_guard
from fastapi.responses import StreamingResponse
from app.models.app_model import AppType
from app.core.error_codes import BizCode
from app.services.app_service import AppService
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
from app.services.workflow_service import WorkflowService, get_workflow_service
router = APIRouter(prefix="/apps", tags=["Apps"])
logger = get_business_logger()
@@ -48,7 +52,7 @@ def list_apps(
current_user=Depends(get_current_user),
):
"""列出应用
- 默认包含本工作空间的应用和分享给本工作空间的应用
- 设置 include_shared=false 可以只查看本工作空间的应用
"""
@@ -63,8 +67,8 @@ def list_apps(
include_shared=include_shared,
page=page,
pagesize=pagesize,
)
)
# 使用 AppService 的转换方法来设置 is_shared 字段
service = app_service.AppService(db)
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
@@ -79,14 +83,14 @@ def get_app(
current_user=Depends(get_current_user),
):
"""获取应用详细信息
- 支持获取本工作空间的应用
- 支持获取分享给本工作空间的应用
"""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
app = service.get_app(app_id, workspace_id)
# 转换为 Schema 并设置 is_shared 字段
app_schema_obj = service._convert_to_schema(app, workspace_id)
return success(data=app_schema_obj)
@@ -113,7 +117,7 @@ def delete_app(
current_user=Depends(get_current_user),
):
"""删除应用
会级联删除:
- Agent 配置
- 发布版本
@@ -128,9 +132,9 @@ def delete_app(
"workspace_id": str(workspace_id)
}
)
app_service.delete_app(db, app_id=app_id, workspace_id=workspace_id)
return success(msg="应用删除成功")
@@ -143,7 +147,7 @@ def copy_app(
current_user=Depends(get_current_user),
):
"""复制应用(包括基础信息和配置)
- 复制应用的基础信息(名称、描述、图标等)
- 复制 Agent 配置(如果是 agent 类型)
- 新应用默认为草稿状态
@@ -159,7 +163,7 @@ def copy_app(
"new_name": new_name
}
)
service = AppService(db)
new_app = service.copy_app(
app_id=app_id,
@@ -167,7 +171,7 @@ def copy_app(
workspace_id=workspace_id,
new_name=new_name
)
return success(data=app_schema.App.model_validate(new_app), msg="应用复制成功")
@@ -209,9 +213,9 @@ def publish_app(
):
workspace_id = current_user.current_workspace_id
release = app_service.publish(
db,
app_id=app_id,
publisher_id=current_user.id,
db,
app_id=app_id,
publisher_id=current_user.id,
workspace_id=workspace_id,
version_name = payload.version_name,
release_notes=payload.release_notes
@@ -268,13 +272,13 @@ def share_app(
current_user=Depends(get_current_user),
):
"""分享应用到其他工作空间
- 只能分享自己工作空间的应用
- 不能分享到自己的工作空间
- 同一个应用不能重复分享到同一个工作空间
"""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
shares = service.share_app(
app_id=app_id,
@@ -282,7 +286,7 @@ def share_app(
user_id=current_user.id,
workspace_id=workspace_id
)
data = [app_schema.AppShare.model_validate(s) for s in shares]
return success(data=data, msg=f"应用已分享到 {len(shares)} 个工作空间")
@@ -296,18 +300,18 @@ def unshare_app(
current_user=Depends(get_current_user),
):
"""取消应用分享
- 只能取消自己工作空间应用的分享
"""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
service.unshare_app(
app_id=app_id,
target_workspace_id=target_workspace_id,
workspace_id=workspace_id
)
return success(msg="应用分享已取消")
@@ -319,17 +323,17 @@ def list_app_shares(
current_user=Depends(get_current_user),
):
"""列出应用的所有分享记录
- 只能查看自己工作空间应用的分享记录
"""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
shares = service.list_app_shares(
app_id=app_id,
workspace_id=workspace_id
)
data = [app_schema.AppShare.model_validate(s) for s in shares]
return success(data=data)
@@ -340,10 +344,11 @@ async def draft_run(
payload: app_schema.DraftRunRequest,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
workflow_service: Annotated[WorkflowService, Depends(get_workflow_service)] = None
):
"""
试运行 Agent使用当前的草稿配置未发布的配置
- 不需要发布应用即可测试
- 使用当前的 AgentConfig 配置
- 支持流式和非流式返回
@@ -367,33 +372,44 @@ async def draft_run(
)
if knowledge: user_rag_memory_id = str(knowledge.id)
# 提前验证和准备(在流式响应开始前完成)
from app.services.app_service import AppService
from app.services.multi_agent_service import MultiAgentService
from app.models import AgentConfig, ModelConfig
from sqlalchemy import select
from app.core.exceptions import BusinessException
from app.services.draft_run_service import DraftRunService
service = AppService(db)
draft_service = DraftRunService(db)
# 1. 验证应用
app = service._get_app_or_404(app_id)
if app.type != AppType.AGENT and app.type != AppType.MULTI_AGENT:
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
if app.type != AppType.AGENT and app.type != AppType.MULTI_AGENT and app.type != AppType.WORKFLOW:
raise BusinessException("只有 Agent , Workflow 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
# 只读操作,允许访问共享应用
service._validate_app_accessible(app, workspace_id)
# 处理会话ID创建或验证
conversation_id = await draft_service._ensure_conversation(
conversation_id=payload.conversation_id,
app_id=app_id,
workspace_id=workspace_id,
user_id=payload.user_id
)
payload.conversation_id = conversation_id
if app.type == AppType.AGENT:
service._check_agent_config(app_id)
# 2. 获取 Agent 配置
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
agent_cfg = db.scalars(stmt).first()
if not agent_cfg:
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
# 3. 获取模型配置
model_config = None
if agent_cfg.default_model_config_id:
@@ -401,12 +417,12 @@ async def draft_run(
if not model_config:
from app.core.exceptions import ResourceNotFoundException
raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id))
# 流式返回
if payload.stream:
async def event_generator():
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
async for event in draft_service.run_stream(
agent_config=agent_cfg,
model_config=model_config,
@@ -419,7 +435,7 @@ async def draft_run(
user_rag_memory_id=user_rag_memory_id
):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
@@ -429,7 +445,7 @@ async def draft_run(
"X-Accel-Buffering": "no"
}
)
# 非流式返回
logger.debug(
"开始非流式试运行",
@@ -440,7 +456,7 @@ async def draft_run(
"has_variables": bool(payload.variables)
}
)
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
result = await draft_service.run(
@@ -454,7 +470,7 @@ async def draft_run(
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
logger.debug(
"试运行返回结果",
extra={
@@ -462,7 +478,7 @@ async def draft_run(
"result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict"
}
)
# 验证结果
try:
validated_result = app_schema.DraftRunResponse.model_validate(result)
@@ -481,10 +497,10 @@ async def draft_run(
elif app.type == AppType.MULTI_AGENT:
# 1. 检查多智能体配置完整性
service._check_multi_agent_config(app_id)
# 2. 构建多智能体运行请求
from app.schemas.multi_agent_schema import MultiAgentRunRequest
multi_agent_request = MultiAgentRunRequest(
message=payload.message,
conversation_id=payload.conversation_id,
@@ -492,7 +508,7 @@ async def draft_run(
variables=payload.variables or {},
use_llm_routing=True # 默认启用 LLM 路由
)
# 3. 流式返回
if payload.stream:
logger.debug(
@@ -503,11 +519,11 @@ async def draft_run(
"has_conversation_id": bool(payload.conversation_id)
}
)
async def event_generator():
"""多智能体流式事件生成器"""
multiservice = MultiAgentService(db)
# 调用多智能体服务的流式方法
async for event in multiservice.run_stream(
app_id=app_id,
@@ -517,7 +533,7 @@ async def draft_run(
):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
@@ -527,7 +543,7 @@ async def draft_run(
"X-Accel-Buffering": "no"
}
)
# 4. 非流式返回
logger.debug(
"开始多智能体非流式试运行",
@@ -537,10 +553,10 @@ async def draft_run(
"has_conversation_id": bool(payload.conversation_id)
}
)
multiservice = MultiAgentService(db)
result = await multiservice.run(app_id, multi_agent_request)
logger.debug(
"多智能体试运行返回结果",
extra={
@@ -548,12 +564,71 @@ async def draft_run(
"has_response": "response" in result if isinstance(result, dict) else False
}
)
return success(
data=result,
msg="多 Agent 任务执行成功"
)
elif app.type == AppType.WORKFLOW: #工作流
config = workflow_service.check_config(app_id)
# 3. 流式返回
if payload.stream:
logger.debug(
"开始多智能体流式试运行",
extra={
"app_id": str(app_id),
"message_length": len(payload.message),
"has_conversation_id": bool(payload.conversation_id)
}
)
async def event_generator():
"""多智能体流式事件生成器"""
multiservice = MultiAgentService(db)
# 调用多智能体服务的流式方法
async for event in multiservice.run_stream(
app_id=app_id,
request=multi_agent_request,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
# 4. 非流式返回
logger.debug(
"开始非流式试运行",
extra={
"app_id": str(app_id),
"message_length": len(payload.message),
"has_conversation_id": bool(payload.conversation_id)
}
)
result = await workflow_service.run(app_id, payload,config)
logger.debug(
"工作流试运行返回结果",
extra={
"result_type": str(type(result)),
"has_response": "response" in result if isinstance(result, dict) else False
}
)
return success(
data=result,
msg="工作流任务执行成功"
)
@@ -567,21 +642,21 @@ async def draft_run_compare(
):
"""
多模型对比试运行
- 支持对比 1-5 个模型
- 可以是不同的模型,也可以是同一模型的不同参数配置
- 通过 model_parameters 覆盖默认参数
- 支持并行或串行执行(非流式)
- 支持流式返回(串行执行)
- 返回每个模型的运行结果和性能对比
使用场景:
1. 对比不同模型的效果GPT-4 vs Claude vs Gemini
2. 调优模型参数(不同 temperature 的效果对比)
3. 性能和成本分析
"""
workspace_id = current_user.current_workspace_id
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
@@ -597,7 +672,7 @@ async def draft_run_compare(
workspace_id=workspace_id
)
if knowledge: user_rag_memory_id = str(knowledge.id)
logger.info(
"多模型对比试运行",
extra={
@@ -607,13 +682,13 @@ async def draft_run_compare(
"stream": payload.stream
}
)
# 提前验证和准备(在流式响应开始前完成)
from app.services.app_service import AppService
from app.models import ModelConfig
service = AppService(db)
# 1. 验证应用和权限
app = service._get_app_or_404(app_id)
if app.type != "agent":
@@ -621,7 +696,7 @@ async def draft_run_compare(
from app.core.error_codes import BizCode
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
service._validate_app_accessible(app, workspace_id)
# 2. 获取 Agent 配置
from sqlalchemy import select
from app.models import AgentConfig
@@ -631,7 +706,7 @@ async def draft_run_compare(
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
# 3. 验证所有模型配置
model_configs = []
for model_item in payload.models:
@@ -639,12 +714,12 @@ async def draft_run_compare(
if not model_config:
from app.core.exceptions import ResourceNotFoundException
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
merged_parameters = {
**(agent_cfg.model_parameters or {}),
**(model_item.model_parameters or {})
}
model_configs.append({
"model_config": model_config,
"parameters": merged_parameters,
@@ -652,7 +727,7 @@ async def draft_run_compare(
"model_config_id": model_item.model_config_id,
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
})
# 流式返回
if payload.stream:
async def event_generator():
@@ -674,7 +749,7 @@ async def draft_run_compare(
timeout=payload.timeout or 60
):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
@@ -684,7 +759,7 @@ async def draft_run_compare(
"X-Accel-Buffering": "no"
}
)
# 非流式返回
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
@@ -703,7 +778,7 @@ async def draft_run_compare(
parallel=payload.parallel,
timeout=payload.timeout or 60
)
logger.info(
"多模型对比完成",
extra={
@@ -712,5 +787,36 @@ async def draft_run_compare(
"failed": result["failed_count"]
}
)
return success(data=app_schema.DraftRunCompareResponse(**result))
@router.get("/{app_id}/workflow")
@cur_workspace_access_guard()
async def get_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
):
"""获取工作流配置
获取应用的工作流配置详情。
"""
workspace_id = current_user.current_workspace_id
cfg = app_service.get_workflow_config(db=db, app_id=app_id, workspace_id=workspace_id)
# 配置总是存在(不存在时返回默认模板)
return success(data=WorkflowConfigSchema.model_validate(cfg))
@router.put("/{app_id}/workflow", summary="更新 Workflow 配置")
@cur_workspace_access_guard()
async def update_workflow_config(
app_id: uuid.UUID,
payload: WorkflowConfigUpdate,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
):
workspace_id = current_user.current_workspace_id
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
return success(data=WorkflowConfigSchema.model_validate(cfg))

View File

@@ -29,10 +29,10 @@ router = APIRouter(
)
@router.get("/{kb_id}/{parent_id}/documents", response_model=ApiResponse)
@router.get("/{kb_id}/documents", response_model=ApiResponse)
async def get_documents(
kb_id: uuid.UUID,
parent_id: uuid.UUID,
parent_id: Optional[uuid.UUID] = Query(None, description="parent folder id when type is Folder"),
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"),

View File

@@ -1,10 +1,13 @@
"""App 服务接口 - 基于 API Key 认证"""
from fastapi import APIRouter, Depends
import uuid
from fastapi import APIRouter, Depends, Request, Body
from sqlalchemy.orm import Session
from app.db import get_db
from app.core.response_utils import success
from app.core.logging_config import get_business_logger
from app.core.api_key_auth import require_api_key
from app.schemas.api_key_schema import ApiKeyAuth
router = APIRouter(prefix="/apps", tags=["V1 - App API"])
logger = get_business_logger()
@@ -14,3 +17,30 @@ logger = get_business_logger()
async def list_apps():
"""列出可访问的应用(占位)"""
return success(data=[], msg="App API - Coming Soon")
# /v1/apps/{resource_id}/chat
@router.post("/{resource_id}/chat")
@require_api_key(scopes=["app"])
async def chat_with_agent_demo(
resource_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
message: str = Body(..., description="聊天消息内容"),
):
"""
Agent 聊天接口demo
scopes: 所需的权限范围列表["app", "rag", "memory"]
Args:
resource_id: 如果是应用的apikey传的是应用id; 如果是服务的apikey传的是工作空间id
message: 请求参数
request: 声明请求
api_key_auth: 包含验证后的API Key 信息
db: db_session
"""
logger.info(f"API Key Auth: {api_key_auth}")
logger.info(f"Resource ID: {resource_id}")
logger.info(f"Message: {message}")
return success(data={"received": True}, msg="消息已接收")

View File

@@ -0,0 +1,587 @@
"""
工作流 API 控制器
"""
import logging
import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, Path, Query
from sqlalchemy.orm import Session
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.models.user_model import User
from app.models.app_model import App
from app.services.workflow_service import WorkflowService, get_workflow_service
from app.schemas.workflow_schema import (
WorkflowConfigCreate,
WorkflowConfigUpdate,
WorkflowConfig,
WorkflowValidationResponse,
WorkflowExecution,
WorkflowNodeExecution,
WorkflowExecutionRequest,
WorkflowExecutionResponse
)
from app.core.response_utils import success, fail
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/apps", tags=["workflow"])
# ==================== 工作流配置管理 ====================
@router.post("/{app_id}/workflow")
@cur_workspace_access_guard()
async def create_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
config: WorkflowConfigCreate,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""创建工作流配置
创建或更新应用的工作流配置。配置会进行基础验证,但允许保存不完整的配置(草稿)。
"""
try:
# 验证应用是否存在且属于当前工作空间
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="应用不存在或无权访问"
)
# 验证应用类型
if app.type != "workflow":
return fail(
code=BizCode.INVALID_PARAMETER,
msg=f"应用类型必须为 workflow当前为 {app.type}"
)
# 创建工作流配置
workflow_config = service.create_workflow_config(
app_id=app_id,
nodes=[node.model_dump() for node in config.nodes],
edges=[edge.model_dump() for edge in config.edges],
variables=[var.model_dump() for var in config.variables],
execution_config=config.execution_config.model_dump(),
triggers=[trigger.model_dump() for trigger in config.triggers],
validate=True # 进行基础验证
)
return success(
data=WorkflowConfig.model_validate(workflow_config),
msg="工作流配置创建成功"
)
except BusinessException as e:
logger.warning(f"创建工作流配置失败: {e.message}")
return fail(code=e.error_code, msg=e.message)
except Exception as e:
logger.error(f"创建工作流配置异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"创建工作流配置失败: {str(e)}"
)
#
# @router.get("/{app_id}/workflow")
# async def get_workflow_config(
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
# db: Annotated[Session, Depends(get_db)],
# current_user: Annotated[User, Depends(get_current_user)]
#
# ):
# """获取工作流配置
#
# 获取应用的工作流配置详情。
# """
# try:
# # 验证应用是否存在且属于当前工作空间
# app = db.query(App).filter(
# App.id == app_id,
# App.workspace_id == current_user.current_workspace_id,
# App.is_active == True
# ).first()
#
# if not app:
# return fail(
# code=BizCode.NOT_FOUND,
# msg="应用不存在或无权访问"
# )
#
# # 获取工作流配置
# service = WorkflowService(db)
# workflow_config = service.get_workflow_config(app_id)
#
# if not workflow_config:
# return fail(
# code=BizCode.NOT_FOUND,
# msg="工作流配置不存在"
# )
#
# return success(
# data=WorkflowConfig.model_validate(workflow_config)
# )
#
# except Exception as e:
# logger.error(f"获取工作流配置异常: {e}", exc_info=True)
# return fail(
# code=BizCode.INTERNAL_ERROR,
# msg=f"获取工作流配置失败: {str(e)}"
# )
# @router.put("/{app_id}/workflow")
# async def update_workflow_config(
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
# config: WorkflowConfigUpdate,
# db: Annotated[Session, Depends(get_db)],
# current_user: Annotated[User, Depends(get_current_user)],
# service: Annotated[WorkflowService, Depends(get_workflow_service)]
# ):
# """更新工作流配置
# 更新应用的工作流配置。可以部分更新,未提供的字段保持不变。
# """
# try:
# # 验证应用是否存在且属于当前工作空间
# app = db.query(App).filter(
# App.id == app_id,
# App.workspace_id == current_user.current_workspace_id,
# App.is_active == True
# ).first()
# if not app:
# return fail(
# code=BizCode.NOT_FOUND,
# msg="应用不存在或无权访问"
# )
# # 更新工作流配置
# workflow_config = service.update_workflow_config(
# app_id=app_id,
# nodes=[node.model_dump() for node in config.nodes] if config.nodes else None,
# edges=[edge.model_dump() for edge in config.edges] if config.edges else None,
# variables=[var.model_dump() for var in config.variables] if config.variables else None,
# execution_config=config.execution_config.model_dump() if config.execution_config else None,
# triggers=[trigger.model_dump() for trigger in config.triggers] if config.triggers else None,
# validate=True
# )
# return success(
# data=WorkflowConfig.model_validate(workflow_config),
# msg="工作流配置更新成功"
# )
# except BusinessException as e:
# logger.warning(f"更新工作流配置失败: {e.message}")
# return fail(code=e.error_code, msg=e.message)
# except Exception as e:
# logger.error(f"更新工作流配置异常: {e}", exc_info=True)
# return fail(
# code=BizCode.INTERNAL_ERROR,
# msg=f"更新工作流配置失败: {str(e)}"
# )
@router.delete("/{app_id}/workflow")
async def delete_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""删除工作流配置
删除应用的工作流配置。
"""
try:
# 验证应用是否存在且属于当前工作空间
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="应用不存在或无权访问"
)
# 删除工作流配置
deleted = service.delete_workflow_config(app_id)
if not deleted:
return fail(
code=BizCode.NOT_FOUND,
msg="工作流配置不存在"
)
return success(msg="工作流配置删除成功")
except Exception as e:
logger.error(f"删除工作流配置异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"删除工作流配置失败: {str(e)}"
)
@router.post("/{app_id}/workflow/validate")
async def validate_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)],
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
):
"""验证工作流配置
验证工作流配置是否有效。可以选择是否进行发布级别的严格验证。
"""
try:
# 验证应用是否存在且属于当前工作空间
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="应用不存在或无权访问"
)
# 验证工作流配置
if for_publish:
is_valid, errors = service.validate_workflow_config_for_publish(app_id)
else:
workflow_config = service.get_workflow_config(app_id)
if not workflow_config:
return fail(
code=BizCode.NOT_FOUND,
msg="工作流配置不存在"
)
from app.core.workflow.validator import validate_workflow_config as validate_config
config_dict = {
"nodes": workflow_config.nodes,
"edges": workflow_config.edges,
"variables": workflow_config.variables,
"execution_config": workflow_config.execution_config,
"triggers": workflow_config.triggers
}
is_valid, errors = validate_config(config_dict, for_publish=False)
return success(
data=WorkflowValidationResponse(
is_valid=is_valid,
errors=errors,
warnings=[]
)
)
except BusinessException as e:
logger.warning(f"验证工作流配置失败: {e.message}")
return fail(code=e.error_code, msg=e.message)
except Exception as e:
logger.error(f"验证工作流配置异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"验证工作流配置失败: {str(e)}"
)
# ==================== 工作流执行管理 ====================
@router.get("/{app_id}/workflow/executions")
async def get_workflow_executions(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)],
limit: Annotated[int, Query(ge=1, le=100)] = 50,
offset: Annotated[int, Query(ge=0)] = 0
):
"""获取工作流执行记录列表
获取应用的工作流执行历史记录。
"""
try:
# 验证应用是否存在且属于当前工作空间
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="应用不存在或无权访问"
)
# 获取执行记录
executions = service.get_executions_by_app(app_id, limit, offset)
# 获取统计信息
statistics = service.get_execution_statistics(app_id)
return success(
data={
"executions": [WorkflowExecution.model_validate(e) for e in executions],
"statistics": statistics,
"pagination": {
"limit": limit,
"offset": offset,
"total": statistics["total"]
}
}
)
except Exception as e:
logger.error(f"获取工作流执行记录异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"获取工作流执行记录失败: {str(e)}"
)
@router.get("/workflow/executions/{execution_id}")
async def get_workflow_execution(
execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""获取工作流执行详情
获取单个工作流执行的详细信息,包括所有节点的执行记录。
"""
try:
# 获取执行记录
execution = service.get_execution(execution_id)
if not execution:
return fail(
code=BizCode.NOT_FOUND,
msg="执行记录不存在"
)
# 验证应用是否属于当前工作空间
app = db.query(App).filter(
App.id == execution.app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="无权访问该执行记录"
)
# 获取节点执行记录
node_executions = service.node_execution_repo.get_by_execution_id(execution.id)
return success(
data={
"execution": WorkflowExecution.model_validate(execution),
"node_executions": [
WorkflowNodeExecution.model_validate(ne) for ne in node_executions
]
}
)
except Exception as e:
logger.error(f"获取工作流执行详情异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"获取工作流执行详情失败: {str(e)}"
)
# ==================== 工作流执行 ====================
@router.post("/{app_id}/workflow/run")
async def run_workflow(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
request: WorkflowExecutionRequest,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""执行工作流
执行工作流并返回结果。支持流式和非流式两种模式。
**非流式模式**:等待工作流执行完成后返回完整结果。
**流式模式**:实时返回执行过程中的事件(节点开始、节点完成、工作流完成等)。
"""
try:
# 验证应用是否存在且属于当前工作空间
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="应用不存在或无权访问"
)
# 验证应用类型
if app.type != "workflow":
return fail(
code=BizCode.INVALID_PARAMETER,
msg=f"应用类型必须为 workflow当前为 {app.type}"
)
# 准备输入数据
input_data = {
"message": request.message or "",
"variables": request.variables
}
# 执行工作流
if request.stream:
# 流式执行
from fastapi.responses import StreamingResponse
import json
async def event_generator():
"""生成 SSE 事件"""
try:
async for event in service.run_workflow(
app_id=app_id,
input_data=input_data,
triggered_by=current_user.id,
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
stream=True
):
# 转换为 SSE 格式
yield f"data: {json.dumps(event)}\n\n"
except Exception as e:
logger.error(f"流式执行异常: {e}", exc_info=True)
error_event = {
"type": "error",
"error": str(e)
}
yield f"data: {json.dumps(error_event)}\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream"
)
else:
# 非流式执行
result = await service.run_workflow(
app_id=app_id,
input_data=input_data,
triggered_by=current_user.id,
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
stream=False
)
return success(
data=WorkflowExecutionResponse(
execution_id=result["execution_id"],
status=result["status"],
output=result.get("output"),
output_data=result.get("output_data"),
error_message=result.get("error_message"),
elapsed_time=result.get("elapsed_time"),
token_usage=result.get("token_usage")
),
msg="工作流执行完成"
)
except BusinessException as e:
logger.warning(f"执行工作流失败: {e.message}")
return fail(code=e.error_code, msg=e.message)
except Exception as e:
logger.error(f"执行工作流异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"执行工作流失败: {str(e)}"
)
@router.post("/workflow/executions/{execution_id}/cancel")
async def cancel_workflow_execution(
execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""取消工作流执行
取消正在运行的工作流执行。
**注意**:当前版本仅更新状态为 cancelled实际的执行取消功能待实现。
"""
try:
# 获取执行记录
execution = service.get_execution(execution_id)
if not execution:
return fail(
code=BizCode.NOT_FOUND,
msg="执行记录不存在"
)
# 验证应用是否属于当前工作空间
app = db.query(App).filter(
App.id == execution.app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="无权访问该执行记录"
)
# 检查执行状态
if execution.status not in ["pending", "running"]:
return fail(
code=BizCode.INVALID_PARAMETER,
msg=f"无法取消状态为 {execution.status} 的执行"
)
# 更新状态为 cancelled
service.update_execution_status(execution_id, "cancelled")
return success(msg="工作流执行已取消")
except BusinessException as e:
logger.warning(f"取消工作流执行失败: {e.message}")
return fail(code=e.error_code, msg=e.message)
except Exception as e:
logger.error(f"取消工作流执行异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"取消工作流执行失败: {str(e)}"
)

View File

@@ -1,10 +1,12 @@
import asyncio
import time
import uuid
from functools import wraps
from typing import Optional, List
from datetime import datetime
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from app.core.api_key_utils import add_rate_limit_headers
@@ -22,21 +24,17 @@ logger = get_api_logger()
def require_api_key(
scopes: Optional[List[str]] = None,
resource_type: Optional[str] = None
scopes: Optional[List[str]] = None
):
"""
API Key 鉴权装饰器
Args:
scopes: 所需的权限范围列表["app:all",
"rag:search", "rag:upload", "rag:delete",
"memory:read", "memory:write", "memory:delete", "memory:search"]
resource_type: 所需的资源类型("Agent", "Cluster", "Workflow", "Knowledge", "Memory_Engine")
scopes: 所需的权限范围列表[“app”, "rag", "memory"]
Usage:
@router.get("/app/{resource_id}/chat")
@require_api_key(scopes=["app:all"], resource_type="Agent")
@require_api_key(scopes=["app"])
def chat_with_app(
resource_id: uuid.UUID,
api_key_auth: ApiKeyAuth = Depends(),
@@ -113,31 +111,25 @@ def require_api_key(
context={"required_scopes": scopes, "missing_scopes": missing_scopes}
)
if resource_type:
resource_id = kwargs.get("resource_id")
if resource_id and not ApiKeyAuthService.check_resource(
api_key_obj,
resource_type,
resource_id
):
logger.warning("API Key 资源访问被拒绝", extra={
"api_key_id": str(api_key_obj.id),
"required_resource_type": resource_type,
resource_id = kwargs.get("resource_id")
if resource_id and not ApiKeyAuthService.check_resource(
api_key_obj,
resource_id
):
logger.warning("API Key 资源访问被拒绝", extra={
"api_key_id": str(api_key_obj.id),
"required_resource_id": str(resource_id),
"bound_resource_id": str(api_key_obj.resource_id) if api_key_obj.resource_id else None,
"endpoint": str(request.url)
})
return BusinessException(
"API Key 未授权访问该资源",
BizCode.API_KEY_INVALID_RESOURCE,
context={
"required_resource_id": str(resource_id),
"bound_resource_type": api_key_obj.resource_type,
"bound_resource_id": str(api_key_obj.resource_id) if api_key_obj.resource_id else None,
"endpoint": str(request.url)
})
return BusinessException(
"API Key 未授权访问该资源",
BizCode.API_KEY_INVALID_RESOURCE,
context={
"required_resource_type": resource_type,
"required_resource_id": str(resource_id),
"bound_resource_type": api_key_obj.resource_type,
"bound_resource_id": str(api_key_obj.resource_id) if api_key_obj.resource_id else None
}
)
"bound_resource_id": str(api_key_obj.resource_id)
}
)
kwargs["api_key_auth"] = ApiKeyAuth(
api_key_id=api_key_obj.id,
@@ -145,14 +137,17 @@ def require_api_key(
type=api_key_obj.type,
scopes=api_key_obj.scopes,
resource_id=api_key_obj.resource_id,
resource_type=api_key_obj.resource_type
)
start_time = time.perf_counter()
response = await func(*args, **kwargs)
end_time = time.perf_counter()
response_time = (end_time - start_time) * 1000
if not isinstance(response, Response):
response = JSONResponse(content=response)
response = add_rate_limit_headers(response, rate_headers)
asyncio.create_task(log_api_key_usage(
db, api_key_obj.id, request, response
db, api_key_obj.id, request, response, response_time
))
return response
@@ -204,7 +199,8 @@ async def log_api_key_usage(
db: Session,
api_key_id: uuid.UUID,
request: Request,
response: Response
response: Response,
response_time: float
):
"""记录 API Key 使用日志"""
try:
@@ -216,8 +212,8 @@ async def log_api_key_usage(
"ip_address": request.client.host if request.client else None,
"user_agent": request.headers.get("User-Agent"),
"status_code": response.status_code if hasattr(response, "status_code") else None,
"response_time": None, # 需要在 middleware 中计算
"tokens_used": None, # 需要从响应中提取
"response_time": round(response_time),
"tokens_used": None,
"created_at": datetime.now()
}

View File

@@ -1,33 +1,14 @@
"""API Key 工具函数"""
import secrets
import hashlib
from typing import Optional
from typing import Optional, Union
from datetime import datetime
from app.schemas.api_key_schema import ApiKeyType
from fastapi import Response
from fastapi.responses import JSONResponse
class ResourceType:
"""资源类型常量"""
AGENT = "Agent"
CLUSTER = "Cluster"
WORKFLOW = "Workflow"
KNOWLEDGE = "Knowledge"
MEMORY_ENGINE = "Memory_Engine"
@classmethod
def get_all_types(cls) -> list[str]:
"""获取所有支持的资源类型"""
return [cls.AGENT, cls.CLUSTER, cls.WORKFLOW, cls.KNOWLEDGE, cls.MEMORY_ENGINE]
@classmethod
def is_valid_type(cls, resource_type: str) -> bool:
"""验证资源类型是否有效"""
return resource_type in cls.get_all_types()
def generate_api_key(key_type: ApiKeyType) -> tuple[str, str, str]:
def generate_api_key(key_type: ApiKeyType) -> str:
"""
生成 API Key
@@ -39,102 +20,17 @@ def generate_api_key(key_type: ApiKeyType) -> tuple[str, str, str]:
"""
# 前缀映射
prefix_map = {
ApiKeyType.APP: "sk-app-",
ApiKeyType.RAG: "sk-rag-",
ApiKeyType.MEMORY: "sk-mem-",
ApiKeyType.AGENT: "sk-agent-",
ApiKeyType.CLUSTER: "sk-cluster-",
ApiKeyType.WORKFLOW: "sk-workflow-",
ApiKeyType.SERVICE: "sk-service-"
}
prefix = prefix_map[key_type]
random_string = secrets.token_urlsafe(32)[:32] # 32 字符
api_key = f"{prefix}{random_string}"
# 生成哈希值存储
key_hash = hash_api_key(api_key)
return api_key, key_hash, prefix
def hash_api_key(api_key: str) -> str:
"""对 API Key 进行哈希
Args:
api_key: API Key 明文
Returns:
str: 哈希值
"""
return hashlib.sha256(api_key.encode()).hexdigest()
def verify_api_key(api_key: str, key_hash: str) -> bool:
"""
验证 API Key
Args:
api_key: API Key 明文
key_hash: 存储的哈希值
Returns:
bool: 是否匹配
"""
computed_hash = hash_api_key(api_key)
return secrets.compare_digest(computed_hash, key_hash)
def validate_resource_binding(
resource_type: Optional[str],
resource_id: Optional[str]
) -> tuple[bool, str]:
"""
验证资源绑定的有效性
Args:
resource_type: 资源类型
resource_id: 资源ID
Returns:
tuple: (是否有效, 错误信息)
"""
# 如果都为空,表示不绑定资源,这是有效的
if not resource_type and not resource_id:
return True, ""
# 如果只有一个为空,这是无效的
if not resource_type or not resource_id:
return False, "resource_type 和 resource_id 必须同时提供或同时为空"
# 验证资源类型是否支持
if not ResourceType.is_valid_type(resource_type):
valid_types = ", ".join(ResourceType.get_all_types())
return False, f"不支持的资源类型 '{resource_type}',支持的类型:{valid_types}"
return True, ""
def get_resource_scope_mapping() -> dict[str, list[str]]:
"""
获取资源类型与权限范围的映射关系
Returns:
dict: 资源类型到推荐权限范围的映射
"""
return {
ResourceType.AGENT: [
"app:all"
],
ResourceType.CLUSTER: [
"app:all"
],
ResourceType.WORKFLOW: [
"app:all"
],
ResourceType.KNOWLEDGE: [
"rag:search", "rag:upload", "rag:delete"
],
ResourceType.MEMORY_ENGINE: [
"memory:read", "memory:write", "memory:delete", "memory:search"
]
}
return api_key
def add_rate_limit_headers(response, headers: dict):
@@ -151,3 +47,21 @@ def add_rate_limit_headers(response, headers: dict):
return response
def timestamp_to_datetime(timestamp: Optional[Union[int, float]]) -> Optional[datetime]:
"""将时间戳转换为datetime对象"""
if timestamp is None:
return None
# 处理毫秒级时间戳
if timestamp > 1e10:
timestamp = timestamp / 1000
return datetime.fromtimestamp(timestamp)
def datetime_to_timestamp(dt: Optional[datetime]) -> Optional[int]:
"""将datetime对象转换为时间戳毫秒"""
if dt is None:
return None
return int(dt.timestamp() * 1000)

View File

@@ -59,6 +59,7 @@ class BizCode(IntEnum):
EMBED_NOT_ALLOWED = 6009
PERMISSION_DENIED = 6010
INVALID_CONVERSATION = 6011
CONFIG_MISSING = 6012
# 模型7xxx
MODEL_CONFIG_INVALID = 7001
@@ -96,7 +97,7 @@ HTTP_MAPPING = {
BizCode.TOKEN_INVALID: 401,
BizCode.TOKEN_EXPIRED: 401,
BizCode.TOKEN_BLACKLISTED: 401,
BizCode.FORBIDDEN: 403,
BizCode.FORBIDDEN: 403,
BizCode.TENANT_NOT_FOUND: 404,
BizCode.WORKSPACE_NO_ACCESS: 403,
BizCode.NOT_FOUND: 404,
@@ -151,4 +152,4 @@ HTTP_MAPPING = {
BizCode.DB_ERROR: 500,
BizCode.SERVICE_UNAVAILABLE: 503,
BizCode.RATE_LIMITED: 429,
}
}

View File

@@ -0,0 +1,436 @@
"""
工作流执行器
基于 LangGraph 的工作流执行引擎。
"""
import logging
import uuid
import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langgraph.graph import StateGraph, START, END
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.expression_evaluator import evaluate_condition
from app.models.workflow_model import WorkflowExecution, WorkflowNodeExecution
from app.db import get_db
logger = logging.getLogger(__name__)
class WorkflowExecutor:
"""工作流执行器
负责将工作流配置转换为 LangGraph 并执行。
"""
def __init__(
self,
workflow_config: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
):
"""初始化执行器
Args:
workflow_config: 工作流配置
execution_id: 执行 ID
workspace_id: 工作空间 ID
user_id: 用户 ID
"""
self.workflow_config = workflow_config
self.execution_id = execution_id
self.workspace_id = workspace_id
self.user_id = user_id
self.nodes = workflow_config.get("nodes", [])
self.edges = workflow_config.get("edges", [])
self.execution_config = workflow_config.get("execution_config", {})
def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState:
"""准备初始状态(注入系统变量和会话变量)
变量命名空间:
- sys.xxx - 系统变量execution_id, workspace_id, user_id, message, input_variables 等)
- conv.xxx - 会话变量(跨多轮对话保持)
- node_id.xxx - 节点输出(执行时动态生成)
Args:
input_data: 输入数据
Returns:
初始化的工作流状态
"""
user_message = input_data.get("message") or ""
conversation_vars = input_data.get("conversation_vars") or {}
input_variables = input_data.get("variables") or {} # Start 节点的自定义变量
# 构建分层的变量结构
variables = {
"sys": {
"message": user_message, # 用户消息
"conversation_id": input_data.get("conversation_id"), # 会话 ID
"execution_id": self.execution_id, # 执行 ID
"workspace_id": self.workspace_id, # 工作空间 ID
"user_id": self.user_id, # 用户 ID
"input_variables": input_variables, # 自定义输入变量(给 Start 节点使用)
},
"conv": conversation_vars # 会话级变量(跨多轮对话保持)
}
return {
"messages": [HumanMessage(content=user_message)],
"variables": variables,
"node_outputs": {},
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)
"execution_id": self.execution_id,
"workspace_id": self.workspace_id,
"user_id": self.user_id,
"error": None,
"error_node": None
}
def build_graph(self) -> StateGraph:
"""构建 LangGraph
Returns:
编译后的状态图
"""
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
# 1. 创建状态图
workflow = StateGraph(WorkflowState)
# 2. 添加所有节点(包括 start 和 end
start_node_id = None
end_node_ids = []
for node in self.nodes:
node_type = node.get("type")
node_id = node.get("id")
# 记录 start 和 end 节点 ID
if node_type == "start":
start_node_id = node_id
elif node_type == "end":
end_node_ids.append(node_id)
# 创建节点实例(现在 start 和 end 也会被创建)
node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_instance:
# 包装节点的 run 方法
# 使用函数工厂避免闭包问题
def make_node_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
return node_func
workflow.add_node(node_id, make_node_func(node_instance))
logger.debug(f"添加节点: {node_id} (type={node_type})")
# 3. 添加边
# 从 START 连接到 start 节点
if start_node_id:
workflow.add_edge(START, start_node_id)
logger.debug(f"添加边: START -> {start_node_id}")
for edge in self.edges:
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
condition = edge.get("condition")
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start
if source == start_node_id:
# 但要连接 start 到下一个节点
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
continue
# 处理到 end 节点的边
if target in end_node_ids:
# 连接到 end 节点
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
continue
# 跳过错误边(在节点内部处理)
if edge_type == "error":
continue
if condition:
# 条件边
def router(state: WorkflowState, cond=condition, tgt=target):
"""条件路由函数"""
if evaluate_condition(
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
):
return tgt
return END # 条件不满足,结束
workflow.add_conditional_edges(source, router)
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
else:
# 普通边
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
# 从 end 节点连接到 END
for end_node_id in end_node_ids:
workflow.add_edge(end_node_id, END)
logger.debug(f"添加边: {end_node_id} -> END")
# 4. 编译图
graph = workflow.compile()
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
return graph
async def execute(
self,
input_data: dict[str, Any]
) -> dict[str, Any]:
"""执行工作流(非流式)
Args:
input_data: 输入数据,包含 message 和 variables
Returns:
执行结果,包含 status, output, node_outputs, elapsed_time, token_usage
"""
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
# 记录开始时间
start_time = datetime.datetime.now()
# 1. 构建图
graph = self.build_graph()
# 2. 初始化状态(自动注入系统变量)
initial_state = self._prepare_initial_state(input_data)
# 3. 执行工作流
try:
result = await graph.ainvoke(initial_state)
# 计算耗时
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
# 提取节点输出(现在包含 start 和 end 节点)
node_outputs = result.get("node_outputs", {})
# 提取最终输出(从最后一个非 start/end 节点)
final_output = self._extract_final_output(node_outputs)
# 聚合 token 使用情况
token_usage = self._aggregate_token_usage(node_outputs)
# 提取 conversation_id从 start 节点输出)
conversation_id = None
for node_id, node_output in node_outputs.items():
if node_output.get("node_type") == "start":
conversation_id = node_output.get("output", {}).get("conversation_id")
break
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
return {
"status": "completed",
"output": final_output,
"node_outputs": node_outputs,
"messages": result.get("messages", []),
"conversation_id": conversation_id,
"elapsed_time": elapsed_time,
"token_usage": token_usage,
"error": result.get("error")
}
except Exception as e:
# 计算耗时(即使失败也记录)
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True)
return {
"status": "failed",
"error": str(e),
"output": None,
"node_outputs": {},
"elapsed_time": elapsed_time,
"token_usage": None
}
async def execute_stream(
self,
input_data: dict[str, Any]
):
"""执行工作流(流式)
Args:
input_data: 输入数据
Yields:
流式事件
"""
logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}")
# 1. 构建图
graph = self.build_graph()
# 2. 初始化状态(自动注入系统变量)
initial_state = self._prepare_initial_state(input_data)
# 3. 流式执行工作流
try:
# 使用 astream 获取节点级别的更新
async for event in graph.astream(initial_state, stream_mode="updates"):
for node_name, state_update in event.items():
yield {
"type": "node_complete",
"node": node_name,
"data": state_update,
"execution_id": self.execution_id
}
logger.info(f"工作流执行完成(流式): execution_id={self.execution_id}")
# 发送完成事件
yield {
"type": "workflow_complete",
"execution_id": self.execution_id
}
except Exception as e:
logger.error(f"工作流执行失败(流式): execution_id={self.execution_id}, error={e}", exc_info=True)
yield {
"type": "workflow_error",
"execution_id": self.execution_id,
"error": str(e)
}
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:
"""从节点输出中提取最终输出
优先级:
1. 最后一个执行的非 start/end 节点的 output
2. 如果没有节点输出,返回 None
Args:
node_outputs: 所有节点的输出
Returns:
最终输出字符串或 None
"""
if not node_outputs:
return None
# 获取最后一个节点的输出
last_node_output = list(node_outputs.values())[-1] if node_outputs else None
if last_node_output and isinstance(last_node_output, dict):
return last_node_output.get("output")
return None
def _aggregate_token_usage(self, node_outputs: dict[str, Any]) -> dict[str, int] | None:
"""聚合所有节点的 token 使用情况
Args:
node_outputs: 所有节点的输出
Returns:
聚合的 token 使用情况 {"prompt_tokens": x, "completion_tokens": y, "total_tokens": z}
如果没有 token 使用信息,返回 None
"""
total_prompt_tokens = 0
total_completion_tokens = 0
total_tokens = 0
has_token_info = False
for node_output in node_outputs.values():
if isinstance(node_output, dict):
token_usage = node_output.get("token_usage")
if token_usage and isinstance(token_usage, dict):
has_token_info = True
total_prompt_tokens += token_usage.get("prompt_tokens", 0)
total_completion_tokens += token_usage.get("completion_tokens", 0)
total_tokens += token_usage.get("total_tokens", 0)
if not has_token_info:
return None
return {
"prompt_tokens": total_prompt_tokens,
"completion_tokens": total_completion_tokens,
"total_tokens": total_tokens
}
async def execute_workflow(
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
) -> dict[str, Any]:
"""执行工作流(便捷函数)
Args:
workflow_config: 工作流配置
input_data: 输入数据
execution_id: 执行 ID
workspace_id: 工作空间 ID
user_id: 用户 ID
Returns:
执行结果
"""
executor = WorkflowExecutor(
workflow_config=workflow_config,
execution_id=execution_id,
workspace_id=workspace_id,
user_id=user_id
)
return await executor.execute(input_data)
async def execute_workflow_stream(
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
):
"""执行工作流(流式,便捷函数)
Args:
workflow_config: 工作流配置
input_data: 输入数据
execution_id: 执行 ID
workspace_id: 工作空间 ID
user_id: 用户 ID
Yields:
流式事件
"""
executor = WorkflowExecutor(
workflow_config=workflow_config,
execution_id=execution_id,
workspace_id=workspace_id,
user_id=user_id
)
async for event in executor.execute_stream(input_data):
yield event

View File

@@ -0,0 +1,195 @@
"""
安全的表达式求值器
使用 simpleeval 库提供安全的表达式评估,避免代码注入攻击。
"""
import logging
from typing import Any
from simpleeval import simple_eval, NameNotDefined, InvalidExpression
logger = logging.getLogger(__name__)
class ExpressionEvaluator:
"""安全的表达式求值器"""
# 保留的命名空间
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
@staticmethod
def evaluate(
expression: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> Any:
"""安全地评估表达式
Args:
expression: 表达式字符串,如 "{{var.score}} > 0.8"
variables: 用户定义的变量
node_outputs: 节点输出结果
system_vars: 系统变量
Returns:
表达式求值结果
Raises:
ValueError: 表达式无效或求值失败
Examples:
>>> evaluator = ExpressionEvaluator()
>>> evaluator.evaluate(
... "var.score > 0.8",
... {"score": 0.9},
... {},
... {}
... )
True
>>> evaluator.evaluate(
... "node.intent.output == '售前咨询'",
... {},
... {"intent": {"output": "售前咨询"}},
... {}
... )
True
"""
# 移除 Jinja2 模板语法的花括号(如果存在)
expression = expression.strip()
if expression.startswith("{{") and expression.endswith("}}"):
expression = expression[2:-2].strip()
# 构建命名空间上下文
context = {
"var": variables, # 用户变量
"node": node_outputs, # 节点输出
"sys": system_vars or {}, # 系统变量
}
# 为了向后兼容,也支持直接访问(但会在日志中警告)
context.update(variables)
context["nodes"] = node_outputs
try:
# simpleeval 只支持安全的操作:
# - 算术运算: +, -, *, /, //, %, **
# - 比较运算: ==, !=, <, <=, >, >=
# - 逻辑运算: and, or, not
# - 成员运算: in, not in
# - 属性访问: obj.attr
# - 字典/列表访问: obj["key"], obj[0]
# 不支持:函数调用、导入、赋值等危险操作
result = simple_eval(expression, names=context)
return result
except NameNotDefined as e:
logger.error(f"表达式中引用了未定义的变量: {expression}, 错误: {e}")
raise ValueError(f"未定义的变量: {e}")
except InvalidExpression as e:
logger.error(f"表达式语法无效: {expression}, 错误: {e}")
raise ValueError(f"表达式语法无效: {e}")
except SyntaxError as e:
logger.error(f"表达式语法错误: {expression}, 错误: {e}")
raise ValueError(f"表达式语法错误: {e}")
except Exception as e:
logger.error(f"表达式求值异常: {expression}, 错误: {e}")
raise ValueError(f"表达式求值失败: {e}")
@staticmethod
def evaluate_bool(
expression: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> bool:
"""评估布尔表达式(用于条件判断)
Args:
expression: 布尔表达式
variables: 用户变量
node_outputs: 节点输出
system_vars: 系统变量
Returns:
布尔值结果
Examples:
>>> ExpressionEvaluator.evaluate_bool(
... "var.count >= 10 and var.status == 'active'",
... {"count": 15, "status": "active"},
... {},
... {}
... )
True
"""
result = ExpressionEvaluator.evaluate(
expression, variables, node_outputs, system_vars
)
return bool(result)
@staticmethod
def validate_variable_names(variables: list[dict]) -> list[str]:
"""验证变量名是否合法
Args:
variables: 变量定义列表
Returns:
错误列表,如果为空则验证通过
Examples:
>>> ExpressionEvaluator.validate_variable_names([
... {"name": "user_input"},
... {"name": "var"} # 保留字
... ])
["变量名 'var' 是保留的命名空间,请使用其他名称"]
"""
errors = []
for var in variables:
var_name = var.get("name", "")
# 检查是否为保留命名空间
if var_name in ExpressionEvaluator.RESERVED_NAMESPACES:
errors.append(
f"变量名 '{var_name}' 是保留的命名空间,请使用其他名称"
)
# 检查是否为有效的 Python 标识符
if not var_name.isidentifier():
errors.append(
f"变量名 '{var_name}' 不是有效的标识符"
)
return errors
# 便捷函数
def evaluate_expression(
expression: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> Any:
"""评估表达式(便捷函数)"""
return ExpressionEvaluator.evaluate(
expression, variables, node_outputs, system_vars
)
def evaluate_condition(
expression: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> bool:
"""评估条件表达式(便捷函数)"""
return ExpressionEvaluator.evaluate_bool(
expression, variables, node_outputs, system_vars
)

View File

@@ -0,0 +1,24 @@
"""
工作流节点实现
提供各种类型的节点实现,用于工作流执行。
"""
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.end import EndNode
from app.core.workflow.nodes.node_factory import NodeFactory
__all__ = [
"BaseNode",
"WorkflowState",
"LLMNode",
"AgentNode",
"TransformNode",
"StartNode",
"EndNode",
"NodeFactory",
]

View File

@@ -0,0 +1,6 @@
"""Agent 节点"""
from app.core.workflow.nodes.agent.node import AgentNode
from app.core.workflow.nodes.agent.config import AgentNodeConfig
__all__ = ["AgentNode", "AgentNodeConfig"]

View File

@@ -0,0 +1,71 @@
"""Agent 节点配置"""
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
class AgentNodeConfig(BaseNodeConfig):
"""Agent 节点配置
调用已配置的 Agent 执行任务。
"""
agent_id: str = Field(
...,
description="Agent 配置 ID"
)
message: str = Field(
default="{{ sys.message }}",
description="发送给 Agent 的消息,支持模板变量"
)
conversation_id: str | None = Field(
default=None,
description="会话 ID用于多轮对话"
)
variables: dict[str, str] | None = Field(
default=None,
description="传递给 Agent 的变量"
)
timeout: int = Field(
default=300,
ge=1,
le=3600,
description="超时时间(秒)"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="output",
type=VariableType.STRING,
description="Agent 的回复内容"
),
VariableDefinition(
name="conversation_id",
type=VariableType.STRING,
description="会话 ID"
),
VariableDefinition(
name="token_usage",
type=VariableType.OBJECT,
description="Token 使用情况"
)
],
description="输出变量定义(自动生成,通常不需要修改)"
)
class Config:
json_schema_extra = {
"example": {
"agent_id": "uuid-here",
"message": "{{ sys.message }}",
"timeout": 300,
"description": "调用客服 Agent"
}
}

View File

@@ -0,0 +1,152 @@
"""
Agent 节点实现
调用已发布的 Agent 应用。
"""
import logging
from typing import Any
from langchain_core.messages import AIMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.services.draft_run_service import DraftRunService
from app.models import AppRelease
from app.db import get_db
logger = logging.getLogger(__name__)
class AgentNode(BaseNode):
"""Agent 节点
支持流式和非流式输出。
配置示例:
{
"type": "agent",
"config": {
"agent_id": "uuid", # Agent 的 release_id
"message": "{{var.user_input}}"
}
}
"""
def _prepare_agent(self, state: WorkflowState) -> tuple[DraftRunService, AppRelease, str]:
"""准备 Agent公共逻辑
Args:
state: 工作流状态
Returns:
(draft_service, release, message): 服务实例、发布配置、消息
"""
# 1. 渲染消息
message_template = self.config.get("message", "")
message = self._render_template(message_template, state)
# 2. 获取 Agent 配置
agent_id = self.config.get("agent_id")
if not agent_id:
raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置")
db = next(get_db())
release = db.query(AppRelease).filter(
AppRelease.id == agent_id
).first()
if not release:
raise ValueError(f"Agent 不存在: {agent_id}")
draft_service = DraftRunService(db)
return draft_service, release, message
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""非流式执行
Args:
state: 工作流状态
Returns:
状态更新字典
"""
draft_service, release, message = self._prepare_agent(state)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
# 执行 Agent非流式
result = await draft_service.run(
agent_config=release.config,
model_config=None,
message=message,
workspace_id=state.get("workspace_id"),
user_id=state.get("user_id"),
variables=state.get("variables", {})
)
response = result.get("response", "")
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(response)}")
return {
"messages": [AIMessage(content=response)],
"node_outputs": {
self.node_id: {
"output": response,
"status": "completed",
"meta_data": result.get("meta_data", {})
}
}
}
async def execute_stream(self, state: WorkflowState):
"""流式执行
Args:
state: 工作流状态
Yields:
流式事件字典
"""
draft_service, release, message = self._prepare_agent(state)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
# 累积完整响应
full_response = ""
# 执行 Agent流式
async for chunk in draft_service.run_stream(
agent_config=release.config,
model_config=None,
message=message,
workspace_id=state.get("workspace_id"),
user_id=state.get("user_id"),
variables=state.get("variables", {})
):
# 提取内容
content = chunk.get("content", "")
full_response += content
# 流式返回每个 chunk
yield {
"type": "chunk",
"node_id": self.node_id,
"content": content,
"full_content": full_response,
"meta_data": chunk.get("meta_data", {})
}
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}")
# 最后返回完整结果
yield {
"type": "complete",
"messages": [AIMessage(content=full_response)],
"node_outputs": {
self.node_id: {
"output": full_response,
"status": "completed"
}
}
}

View File

@@ -0,0 +1,109 @@
"""节点配置基类
定义所有节点配置的通用字段和数据结构。
"""
from enum import StrEnum
from pydantic import BaseModel, Field
class VariableType(StrEnum):
"""变量类型枚举"""
STRING = "string"
NUMBER = "number"
BOOLEAN = "boolean"
ARRAY = "array"
OBJECT = "object"
class VariableDefinition(BaseModel):
"""变量定义
定义工作流或节点的输入/输出变量。
这是一个通用的数据结构,可以在多个地方使用。
"""
name: str = Field(
...,
description="变量名称"
)
type: VariableType = Field(
default=VariableType.STRING,
description="变量类型"
)
required: bool = Field(
default=False,
description="是否必需"
)
default: str | int | float | bool | list | dict | None = Field(
default=None,
description="默认值"
)
description: str | None = Field(
default=None,
description="变量描述"
)
class Config:
json_schema_extra = {
"examples": [
{
"name": "language",
"type": "string",
"required": False,
"default": "zh-CN",
"description": "语言设置"
},
{
"name": "max_length",
"type": "number",
"required": False,
"default": 1000,
"description": "最大长度"
},
{
"name": "enable_search",
"type": "boolean",
"required": True,
"description": "是否启用搜索"
}
]
}
class BaseNodeConfig(BaseModel):
"""节点配置基类
所有节点配置都应该继承此基类。
通用字段:
- name: 节点名称(显示名称)
- description: 节点描述
- tags: 节点标签(用于分类和搜索)
"""
name: str | None = Field(
default=None,
description="节点名称(显示名称),如果不设置则使用节点 ID"
)
description: str | None = Field(
default=None,
description="节点描述,说明节点的作用"
)
tags: list[str] = Field(
default_factory=list,
description="节点标签,用于分类和搜索"
)
class Config:
"""Pydantic 配置"""
# 允许额外字段(向后兼容)
extra = "allow"

View File

@@ -0,0 +1,556 @@
"""
工作流节点基类
定义节点的基本接口和通用功能。
"""
import asyncio
import logging
from abc import ABC, abstractmethod
from typing import Any, TypedDict, Annotated
from operator import add
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
class WorkflowState(TypedDict):
"""工作流状态
在节点间传递的状态对象,包含消息、变量、节点输出等信息。
"""
# 消息列表(追加模式)
messages: Annotated[list[AnyMessage], add]
# 输入变量(从配置的 variables 传入)
variables: dict[str, Any]
# 节点输出(存储每个节点的执行结果,用于变量引用)
# 使用自定义合并函数,将新的节点输出合并到现有字典中
node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
# 运行时节点变量(简化版,只存储业务数据,供节点间快速访问)
# 格式:{node_id: business_result}
runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
# 执行上下文
execution_id: str
workspace_id: str
user_id: str
# 错误信息(用于错误边)
error: str | None
error_node: str | None
class BaseNode(ABC):
"""节点基类
所有节点类型都应该继承此基类,实现 execute 方法。
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
"""初始化节点
Args:
node_config: 节点配置
workflow_config: 工作流配置
"""
self.node_config = node_config
self.workflow_config = workflow_config
self.node_id = node_config["id"]
self.node_type = node_config["type"]
self.node_name = node_config.get("name", self.node_id)
# 使用 or 运算符处理 None 值
self.config = node_config.get("config") or {}
self.error_handling = node_config.get("error_handling") or {}
@abstractmethod
async def execute(self, state: WorkflowState) -> Any:
"""执行节点业务逻辑(非流式)
节点只需要返回业务结果,不需要关心输出格式、时间统计等。
BaseNode 会自动包装成标准格式。
Args:
state: 工作流状态
Returns:
业务结果(任意类型)
Examples:
>>> # LLM 节点
>>> return "这是 AI 的回复"
>>> # Transform 节点
>>> return {"processed_data": [...]}
>>> # Start/End 节点
>>> return {"message": "开始", "conversation_id": "xxx"}
"""
pass
async def execute_stream(self, state: WorkflowState):
"""执行节点业务逻辑(流式)
子类可以重写此方法以支持流式输出。
默认实现:执行非流式方法并一次性返回。
节点需要:
1. yield 中间结果(如文本片段)
2. 最后 yield 一个特殊的完成标记:{"__final__": True, "result": final_result}
Args:
state: 工作流状态
Yields:
业务数据chunk或完成标记
Examples:
>>> # 流式 LLM 节点
>>> full_response = ""
>>> async for chunk in llm.astream(prompt):
... full_response += chunk
... yield chunk # yield 文本片段
>>>
>>> # 最后 yield 完成标记
>>> yield {"__final__": True, "result": AIMessage(content=full_response)}
"""
result = await self.execute(state)
# 默认实现:直接 yield 完成标记
yield {"__final__": True, "result": result}
def supports_streaming(self) -> bool:
"""节点是否支持流式输出
Returns:
是否支持流式输出
"""
# 检查子类是否重写了 execute_stream 方法
return self.execute_stream.__func__ != BaseNode.execute_stream.__func__
def get_timeout(self) -> int:
"""获取超时时间(秒)
Returns:
超时时间
"""
return 60
# return self.error_handling.get("timeout", 60)
async def run(self, state: WorkflowState) -> dict[str, Any]:
"""执行节点(带错误处理和输出包装,非流式)
这个方法由 Executor 调用,负责:
1. 时间统计
2. 调用节点的 execute() 方法
3. 将业务结果包装成标准输出格式
4. 错误处理
Args:
state: 工作流状态
Returns:
标准化的状态更新字典
"""
import time
start_time = time.time()
try:
timeout = self.get_timeout()
# 调用节点的业务逻辑
business_result = await asyncio.wait_for(
self.execute(state),
timeout=timeout
)
elapsed_time = time.time() - start_time
# 提取处理后的输出(调用子类的 _extract_output
extracted_output = self._extract_output(business_result)
# 包装成标准输出格式
wrapped_output = self._wrap_output(business_result, elapsed_time, state)
# 将提取后的输出存储到运行时变量中(供后续节点快速访问)
# 如果提取后的输出是字典,拆包存储;否则存储为 output 字段
if isinstance(extracted_output, dict):
runtime_var = extracted_output
else:
runtime_var = {"output": extracted_output}
# 返回包装后的输出和运行时变量
return {
**wrapped_output,
"runtime_vars": {
self.node_id: runtime_var
}
}
except TimeoutError:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
return self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
return self._wrap_error(str(e), elapsed_time, state)
async def run_stream(self, state: WorkflowState):
"""执行节点(带错误处理和输出包装,流式)
这个方法由 Executor 调用,负责:
1. 时间统计
2. 调用节点的 execute_stream() 方法
3. 将业务数据包装成标准输出格式
4. 错误处理
Args:
state: 工作流状态
Yields:
标准化的流式事件
"""
import time
start_time = time.time()
try:
timeout = self.get_timeout()
# 累积完整结果(用于最后的包装)
chunks = []
final_result = None
# 使用异步生成器包装,支持超时
async def stream_with_timeout():
nonlocal final_result
loop_start = asyncio.get_event_loop().time()
async for item in self.execute_stream(state):
# 检查超时
if asyncio.get_event_loop().time() - loop_start > timeout:
raise TimeoutError()
# 检查是否是完成标记
if isinstance(item, dict) and item.get("__final__"):
final_result = item["result"]
elif isinstance(item, str):
# 字符串是 chunk
chunks.append(item)
yield {
"type": "chunk",
"node_id": self.node_id,
"content": item,
"full_content": "".join(chunks)
}
else:
# 其他类型也当作 chunk 处理
chunks.append(str(item))
yield {
"type": "chunk",
"node_id": self.node_id,
"content": str(item),
"full_content": "".join(chunks)
}
async for chunk_event in stream_with_timeout():
yield chunk_event
elapsed_time = time.time() - start_time
# 包装最终结果
final_output = self._wrap_output(final_result, elapsed_time, state)
yield {
"type": "complete",
**final_output
}
except TimeoutError:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
yield {
"type": "error",
**self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
}
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
yield {
"type": "error",
**self._wrap_error(str(e), elapsed_time, state)
}
def _wrap_output(
self,
business_result: Any,
elapsed_time: float,
state: WorkflowState
) -> dict[str, Any]:
"""将业务结果包装成标准输出格式
Args:
business_result: 节点返回的业务结果
elapsed_time: 执行耗时
state: 工作流状态
Returns:
标准化的状态更新字典
"""
# 提取输入数据(用于记录)
input_data = self._extract_input(state)
# 提取 token 使用情况(如果有)
token_usage = self._extract_token_usage(business_result)
# 提取实际输出(去除元数据)
output = self._extract_output(business_result)
# 构建标准节点输出
node_output = {
"node_id": self.node_id,
"node_type": self.node_type,
"node_name": self.node_name,
"status": "completed",
"input": input_data,
"output": output,
"elapsed_time": elapsed_time,
"token_usage": token_usage,
"error": None
}
return {
"node_outputs": {
self.node_id: node_output
}
}
def _wrap_error(
self,
error_message: str,
elapsed_time: float,
state: WorkflowState
) -> dict[str, Any]:
"""将错误包装成标准输出格式
Args:
error_message: 错误信息
elapsed_time: 执行耗时
state: 工作流状态
Returns:
标准化的状态更新字典
"""
# 查找错误边
error_edge = self._find_error_edge()
# 提取输入数据
input_data = self._extract_input(state)
# 构建错误输出
node_output = {
"node_id": self.node_id,
"node_type": self.node_type,
"node_name": self.node_name,
"status": "failed",
"input": input_data,
"output": None,
"elapsed_time": elapsed_time,
"token_usage": None,
"error": error_message
}
if error_edge:
# 有错误边:记录错误并继续
logger.warning(
f"节点 {self.node_id} 执行失败,跳转到错误处理节点: {error_edge['target']}"
)
return {
"node_outputs": {
self.node_id: node_output
},
"error": error_message,
"error_node": self.node_id
}
else:
# 无错误边:抛出异常停止工作流
logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}")
raise Exception(f"节点 {self.node_id} 执行失败: {error_message}")
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
"""提取节点输入数据(用于记录)
子类可以重写此方法来自定义输入记录。
Args:
state: 工作流状态
Returns:
输入数据字典
"""
# 默认返回配置
return {"config": self.config}
def _extract_output(self, business_result: Any) -> Any:
"""从业务结果中提取实际输出
子类可以重写此方法来自定义输出提取。
Args:
business_result: 业务结果
Returns:
实际输出
"""
# 默认直接返回业务结果
return business_result
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
"""从业务结果中提取 token 使用情况
子类可以重写此方法来提取 token 信息。
Args:
business_result: 业务结果
Returns:
token 使用情况或 None
"""
# 默认返回 None
return None
def _find_error_edge(self) -> dict[str, Any] | None:
"""查找错误边
Returns:
错误边配置或 None
"""
for edge in self.workflow_config.get("edges", []):
if edge.get("source") == self.node_id and edge.get("type") == "error":
return edge
return None
def _render_template(self, template: str, state: WorkflowState | None) -> str:
"""渲染模板
支持的变量命名空间:
- sys.xxx: 系统变量message, execution_id, workspace_id, user_id, conversation_id
- conv.xxx: 会话变量(跨多轮对话保持)
- node_id.xxx: 节点输出
Args:
template: 模板字符串
state: 工作流状态
Returns:
渲染后的字符串
"""
from app.core.workflow.template_renderer import render_template
# 处理 state 为 None 的情况
if state is None:
state = {}
# 使用变量池获取变量
pool = VariablePool(state)
return render_template(
template=template,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars()
)
def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool:
"""评估条件表达式
支持的变量命名空间:
- sys.xxx: 系统变量
- conv.xxx: 会话变量
- node_id.xxx: 节点输出
Args:
expression: 条件表达式
state: 工作流状态
Returns:
布尔值结果
"""
from app.core.workflow.expression_evaluator import evaluate_condition
# 处理 state 为 None 的情况
if state is None:
state = {}
# 使用变量池获取变量
pool = VariablePool(state)
return evaluate_condition(
expression=expression,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars()
)
def get_variable_pool(self, state: WorkflowState) -> VariablePool:
"""获取变量池实例
VariablePool 是轻量级包装器,只持有 state 的引用,创建成本极低。
Args:
state: 工作流状态
Returns:
VariablePool 实例
Examples:
>>> pool = self.get_variable_pool(state)
>>> message = pool.get("sys.message")
>>> llm_output = pool.get("llm_qa.output")
"""
return VariablePool(state)
def get_variable(
self,
selector: list[str] | str,
state: WorkflowState,
default: Any = None
) -> Any:
"""获取变量值(便捷方法)
Args:
selector: 变量选择器
state: 工作流状态
default: 默认值
Returns:
变量值
Examples:
>>> message = self.get_variable("sys.message", state)
>>> output = self.get_variable(["llm_qa", "output"], state)
>>> custom = self.get_variable("var.custom", state, default="默认值")
"""
pool = VariablePool(state)
return pool.get(selector, default=default)
def has_variable(self, selector: list[str] | str, state: WorkflowState) -> bool:
"""检查变量是否存在(便捷方法)
Args:
selector: 变量选择器
state: 工作流状态
Returns:
变量是否存在
Examples:
>>> if self.has_variable("llm_qa.output", state):
... output = self.get_variable("llm_qa.output", state)
"""
pool = VariablePool(state)
return pool.has(selector)

View File

@@ -0,0 +1,29 @@
"""节点配置类统一导出
所有节点的配置类都在这里导出,方便使用。
"""
from app.core.workflow.nodes.base_config import (
BaseNodeConfig,
VariableDefinition,
VariableType,
)
from app.core.workflow.nodes.start.config import StartNodeConfig
from app.core.workflow.nodes.end.config import EndNodeConfig
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
from app.core.workflow.nodes.agent.config import AgentNodeConfig
from app.core.workflow.nodes.transform.config import TransformNodeConfig
__all__ = [
# 基础类
"BaseNodeConfig",
"VariableDefinition",
"VariableType",
# 节点配置
"StartNodeConfig",
"EndNodeConfig",
"LLMNodeConfig",
"MessageConfig",
"AgentNodeConfig",
"TransformNodeConfig",
]

View File

@@ -0,0 +1,6 @@
"""End 节点"""
from app.core.workflow.nodes.end.node import EndNode
from app.core.workflow.nodes.end.config import EndNodeConfig
__all__ = ["EndNode", "EndNodeConfig"]

View File

@@ -0,0 +1,37 @@
"""End 节点配置"""
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
class EndNodeConfig(BaseNodeConfig):
"""End 节点配置
End 节点负责输出工作流的最终结果。
"""
output: str = Field(
default="工作流已完成",
description="输出模板,支持引用前置节点的输出,如:{{ llm_qa.output }}"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="output",
type=VariableType.STRING,
description="工作流的最终输出"
)
],
description="输出变量定义(自动生成,通常不需要修改)"
)
class Config:
json_schema_extra = {
"example": {
"output": "{{ llm_qa.output }}",
"description": "输出 LLM 的回答"
}
}

View File

@@ -0,0 +1,53 @@
"""
End 节点实现
工作流的结束节点,输出最终结果。
"""
import logging
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
logger = logging.getLogger(__name__)
class EndNode(BaseNode):
"""End 节点
工作流的结束节点,根据配置的模板输出最终结果。
"""
async def execute(self, state: WorkflowState) -> str:
"""执行 end 节点业务逻辑
Args:
state: 工作流状态
Returns:
最终输出字符串
"""
logger.info(f"节点 {self.node_id} (End) 开始执行")
# 获取配置的输出模板
output_template = self.config.get("output")
pool = self.get_variable_pool(state)
print("="*20)
print( pool.get("start.test"))
print("="*20)
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template:
output = self._render_template(output_template, state)
else:
output = "工作流已完成"
# 统计信息(用于日志)
node_outputs = state.get("node_outputs", {})
total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
print("="*20)
print(output)
print("="*20)
return output

View File

@@ -0,0 +1,15 @@
from enum import StrEnum
class NodeType(StrEnum):
START = "start"
END = "end"
ANSWER = "answer"
LLM = "llm"
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
IF_ELSE = "if-else"
CODE = "code"
TRANSFORM = "transform"
QUESTION_CLASSIFIER = "question-classifier"
HTTP_REQUEST = "http-request"
TOOL = "tool"
AGENT = "agent"

View File

@@ -0,0 +1,6 @@
"""LLM 节点"""
from app.core.workflow.nodes.llm.node import LLMNode
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
__all__ = ["LLMNode", "LLMNodeConfig", "MessageConfig"]

View File

@@ -0,0 +1,141 @@
"""LLM 节点配置"""
from pydantic import BaseModel, Field, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
class MessageConfig(BaseModel):
"""消息配置"""
role: str = Field(
...,
description="消息角色system, user, assistant"
)
content: str = Field(
...,
description="消息内容,支持模板变量,如:{{ sys.message }}"
)
@field_validator("role")
@classmethod
def validate_role(cls, v: str) -> str:
"""验证角色"""
allowed_roles = ["system", "user", "human", "assistant", "ai"]
if v.lower() not in allowed_roles:
raise ValueError(f"角色必须是以下之一: {', '.join(allowed_roles)}")
return v.lower()
class LLMNodeConfig(BaseNodeConfig):
"""LLM 节点配置
支持两种配置方式:
1. 简单模式:使用 prompt 字段
2. 消息模式:使用 messages 字段(推荐)
"""
model_id: str = Field(
...,
description="模型配置 ID"
)
# 简单模式
prompt: str | None = Field(
default=None,
description="提示词模板(简单模式),支持变量引用"
)
# 消息模式(推荐)
messages: list[MessageConfig] | None = Field(
default=None,
description="消息列表(消息模式),支持多轮对话"
)
# 模型参数
temperature: float | None = Field(
default=0.7,
ge=0.0,
le=2.0,
description="温度参数,控制输出的随机性"
)
max_tokens: int | None = Field(
default=1000,
ge=1,
le=32000,
description="最大生成 token 数"
)
top_p: float | None = Field(
default=None,
ge=0.0,
le=1.0,
description="Top-p 采样参数"
)
frequency_penalty: float | None = Field(
default=None,
ge=-2.0,
le=2.0,
description="频率惩罚"
)
presence_penalty: float | None = Field(
default=None,
ge=-2.0,
le=2.0,
description="存在惩罚"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="output",
type=VariableType.STRING,
description="LLM 生成的文本输出"
),
VariableDefinition(
name="token_usage",
type=VariableType.OBJECT,
description="Token 使用情况"
)
],
description="输出变量定义(自动生成,通常不需要修改)"
)
@field_validator("messages", "prompt")
@classmethod
def validate_input_mode(cls, v, info):
"""验证输入模式prompt 和 messages 至少有一个"""
# 这个验证在 model_validator 中更合适
return v
class Config:
json_schema_extra = {
"examples": [
{
"model_id": "uuid-here",
"prompt": "请回答:{{ sys.message }}",
"temperature": 0.7,
"max_tokens": 1000
},
{
"model_id": "uuid-here",
"messages": [
{
"role": "system",
"content": "你是一个专业的 AI 助手"
},
{
"role": "user",
"content": "{{ sys.message }}"
}
],
"temperature": 0.7,
"max_tokens": 1000
}
]
}

View File

@@ -0,0 +1,247 @@
"""
LLM 节点实现
调用 LLM 模型进行文本生成。
"""
import logging
from typing import Any
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models import ModelConfig
from app.db import get_db, get_db_context
from app.models.models_model import ModelApiKey
from app.services.model_service import ModelConfigService, ModelApiKeyService
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
logger = logging.getLogger(__name__)
class LLMNode(BaseNode):
"""LLM 节点
支持流式和非流式输出,使用 LangChain 标准的消息格式。
配置示例(支持多种消息格式):
1. 简单文本格式:
{
"type": "llm",
"config": {
"model_id": "uuid",
"prompt": "请分析:{{sys.message}}",
"temperature": 0.7,
"max_tokens": 1000
}
}
2. LangChain 消息格式(推荐):
{
"type": "llm",
"config": {
"model_id": "uuid",
"messages": [
{
"role": "system",
"content": "你是一个专业的 AI 助手。"
},
{
"role": "user",
"content": "{{sys.message}}"
}
],
"temperature": 0.7,
"max_tokens": 1000
}
}
支持的角色类型:
- system: 系统消息SystemMessage
- user/human: 用户消息HumanMessage
- ai/assistant: AI 消息AIMessage
"""
def _prepare_llm(self, state: WorkflowState) -> tuple[RedBearLLM, list | str]:
"""准备 LLM 实例(公共逻辑)
Args:
state: 工作流状态
Returns:
(llm, messages_or_prompt): LLM 实例和消息列表或 prompt 字符串
"""
# 1. 处理消息格式(优先使用 messages
messages_config = self.config.get("messages")
if messages_config:
# 使用 LangChain 消息格式
messages = []
for msg_config in messages_config:
role = msg_config.get("role", "user").lower()
content_template = msg_config.get("content", "")
content = self._render_template(content_template, state)
# 根据角色创建对应的消息对象
if role == "system":
messages.append(SystemMessage(content=content))
elif role in ["user", "human"]:
messages.append(HumanMessage(content=content))
elif role in ["ai", "assistant"]:
messages.append(AIMessage(content=content))
else:
logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append(HumanMessage(content=content))
prompt_or_messages = messages
else:
# 使用简单的 prompt 格式(向后兼容)
prompt_template = self.config.get("prompt", "")
prompt_or_messages = self._render_template(prompt_template, state)
# 2. 获取模型配置
model_id = self.config.get("model_id")
if not model_id:
raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置")
# 3. 在 with 块内完成所有数据库操作和数据提取
with get_db_context() as db:
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
if not config:
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
if not config.api_keys or len(config.api_keys) == 0:
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
# 在 Session 关闭前提取所有需要的数据
api_config = config.api_keys[0]
model_name = api_config.model_name
provider = api_config.provider
api_key = api_config.api_key
api_base = api_config.api_base
model_type = config.type
# 4. 创建 LLM 实例(使用已提取的数据)
llm = RedBearLLM(
RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base
),
type=model_type
)
return llm, prompt_or_messages
async def execute(self, state: WorkflowState) -> AIMessage:
"""非流式执行 LLM 调用
Args:
state: 工作流状态
Returns:
LLM 响应消息
"""
llm, prompt_or_messages = self._prepare_llm(state)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
# 调用 LLM支持字符串或消息列表
response = await llm.ainvoke(prompt_or_messages)
# 提取内容
if hasattr(response, 'content'):
content = response.content
else:
content = str(response)
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
# 返回 AIMessage包含响应元数据
return response if isinstance(response, AIMessage) else AIMessage(content=content)
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
"""提取输入数据(用于记录)"""
_, prompt_or_messages = self._prepare_llm(state)
return {
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
"messages": [
{"role": msg.__class__.__name__.replace("Message", "").lower(), "content": msg.content}
for msg in prompt_or_messages
] if isinstance(prompt_or_messages, list) else None,
"config": {
"model_id": self.config.get("model_id"),
"temperature": self.config.get("temperature"),
"max_tokens": self.config.get("max_tokens")
}
}
def _extract_output(self, business_result: Any) -> str:
"""从 AIMessage 中提取文本内容"""
if isinstance(business_result, AIMessage):
return business_result.content
return str(business_result)
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
"""从 AIMessage 中提取 token 使用情况"""
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'):
usage = business_result.response_metadata.get('token_usage')
if usage:
return {
"prompt_tokens": usage.get('prompt_tokens', 0),
"completion_tokens": usage.get('completion_tokens', 0),
"total_tokens": usage.get('total_tokens', 0)
}
return None
async def execute_stream(self, state: WorkflowState):
"""流式执行 LLM 调用
Args:
state: 工作流状态
Yields:
文本片段chunk或完成标记
"""
llm, prompt_or_messages = self._prepare_llm(state)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
# 累积完整响应
full_response = ""
last_chunk = None
# 调用 LLM流式支持字符串或消息列表
async for chunk in llm.astream(prompt_or_messages):
# 提取内容
if hasattr(chunk, 'content'):
content = chunk.content
else:
content = str(chunk)
full_response += content
last_chunk = chunk
# 流式返回每个文本片段
yield content
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}")
# 构建完整的 AIMessage包含元数据
if isinstance(last_chunk, AIMessage):
final_message = AIMessage(
content=full_response,
response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {}
)
else:
final_message = AIMessage(content=full_response)
# yield 完成标记
yield {"__final__": True, "result": final_message}

View File

@@ -0,0 +1,93 @@
"""
节点工厂
根据节点类型创建相应的节点实例。
"""
import logging
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.end import EndNode
logger = logging.getLogger(__name__)
class NodeFactory:
"""节点工厂
使用工厂模式创建节点实例,便于扩展和维护。
"""
# 节点类型注册表
_node_types: dict[str, type[BaseNode]] = {
NodeType.START: StartNode,
NodeType.END: EndNode,
NodeType.LLM: LLMNode,
NodeType.AGENT: AgentNode,
NodeType.TRANSFORM: TransformNode,
}
@classmethod
def register_node_type(cls, node_type: str, node_class: type[BaseNode]):
"""注册新的节点类型
Args:
node_type: 节点类型名称
node_class: 节点类
Examples:
>>> class CustomNode(BaseNode):
... async def execute(self, state):
... return {"node_outputs": {self.node_id: {"output": "custom"}}}
>>> NodeFactory.register_node_type("custom", CustomNode)
"""
cls._node_types[node_type] = node_class
logger.info(f"注册节点类型: {node_type} -> {node_class.__name__}")
@classmethod
def create_node(
cls,
node_config: dict[str, Any],
workflow_config: dict[str, Any]
) -> BaseNode | None:
"""创建节点实例
Args:
node_config: 节点配置
workflow_config: 工作流配置
Returns:
节点实例或 None对于不支持的节点类型
Raises:
ValueError: 不支持的节点类型
"""
node_type = node_config.get("type")
# 跳过条件节点(由 LangGraph 处理)
if node_type == "condition":
return None
# 获取节点类
node_class = cls._node_types.get(node_type)
if not node_class:
raise ValueError(f"不支持的节点类型: {node_type}")
# 创建节点实例
logger.debug(f"创建节点: {node_config.get('id')} (type={node_type})")
return node_class(node_config, workflow_config)
@classmethod
def get_supported_types(cls) -> list[str]:
"""获取支持的节点类型列表
Returns:
节点类型列表
"""
return list(cls._node_types.keys())

View File

@@ -0,0 +1,6 @@
"""Start 节点"""
from app.core.workflow.nodes.start.node import StartNode
from app.core.workflow.nodes.start.config import StartNodeConfig
__all__ = ["StartNode", "StartNodeConfig"]

View File

@@ -0,0 +1,87 @@
"""Start 节点配置"""
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
class StartNodeConfig(BaseNodeConfig):
"""Start 节点配置
Start 节点的作用:
1. 标记工作流的起点
2. 定义自定义输入变量(会作为节点输出,通过 start_node_id.variable_name 访问)
3. 输出系统变量和会话变量
"""
# 自定义输入变量定义
variables: list[VariableDefinition] = Field(
default_factory=list,
description="自定义输入变量列表,这些变量会作为 Start 节点的输出"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="message",
type=VariableType.STRING,
description="用户输入的消息"
),
VariableDefinition(
name="conversation_vars",
type=VariableType.OBJECT,
description="会话级变量"
),
VariableDefinition(
name="execution_id",
type=VariableType.STRING,
description="执行 ID"
),
VariableDefinition(
name="conversation_id",
type=VariableType.STRING,
description="会话 ID"
),
VariableDefinition(
name="workspace_id",
type=VariableType.STRING,
description="工作空间 ID"
),
VariableDefinition(
name="user_id",
type=VariableType.STRING,
description="用户 ID"
)
],
description="输出变量定义(自动生成,通常不需要修改)"
)
class Config:
json_schema_extra = {
"examples": [
{
"description": "工作流开始节点",
"variables": []
},
{
"description": "带自定义变量的开始节点",
"variables": [
{
"name": "language",
"type": "string",
"required": False,
"default": "zh-CN",
"description": "语言设置"
},
{
"name": "max_length",
"type": "number",
"required": False,
"default": 1000,
"description": "最大长度"
}
]
}
]
}

View File

@@ -0,0 +1,136 @@
"""
Start 节点实现
工作流的起始节点,定义输入变量并输出系统参数。
"""
import logging
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.start.config import StartNodeConfig
logger = logging.getLogger(__name__)
class StartNode(BaseNode):
"""Start 节点
工作流的起始节点,负责:
1. 定义工作流的输入变量(通过配置)
2. 输出系统变量sys.*
3. 输出会话变量conv.*
注意:变量的验证和默认值处理由 Executor 在初始化时完成。
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
"""初始化 Start 节点
Args:
node_config: 节点配置
workflow_config: 工作流配置
"""
super().__init__(node_config, workflow_config)
# 解析并验证配置
self.typed_config = StartNodeConfig(**self.config)
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""执行 start 节点业务逻辑
Start 节点输出系统变量、会话变量和自定义变量。
Args:
state: 工作流状态
Returns:
包含系统参数、会话变量和自定义变量的字典
"""
logger.info(f"节点 {self.node_id} (Start) 开始执行")
# 创建变量池实例(在方法内复用)
pool = self.get_variable_pool(state)
# 处理自定义变量(传入 pool 避免重复创建)
custom_vars = self._process_custom_variables(pool)
# 返回业务数据(包含自定义变量)
result = {
"message": pool.get("sys.message"),
"execution_id": pool.get("sys.execution_id"),
"conversation_id": pool.get("sys.conversation_id"),
"workspace_id": pool.get("sys.workspace_id"),
"user_id": pool.get("sys.user_id"),
**custom_vars # 自定义变量作为节点输出的一部分
}
logger.info(
f"节点 {self.node_id} (Start) 执行完成,"
f"输出了 {len(custom_vars)} 个自定义变量"
)
return result
def _process_custom_variables(self, pool) -> dict[str, Any]:
"""处理自定义变量
从输入数据中提取自定义变量,应用默认值和验证。
Args:
pool: 变量池实例
Returns:
处理后的自定义变量字典
Raises:
ValueError: 缺少必需变量
"""
# 获取输入数据中的自定义变量
input_variables = pool.get("sys.input_variables", default={})
processed = {}
# 遍历配置的变量定义
for var_def in self.typed_config.variables:
var_name = var_def.name
# 检查变量是否存在
if var_name in input_variables:
# 使用用户提供的值
processed[var_name] = input_variables[var_name]
elif var_def.required:
# 必需变量缺失
raise ValueError(
f"缺少必需的输入变量: {var_name}"
+ (f" ({var_def.description})" if var_def.description else "")
)
elif var_def.default is not None:
# 使用默认值
processed[var_name] = var_def.default
logger.debug(
f"变量 '{var_name}' 使用默认值: {var_def.default}"
)
return processed
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
"""提取输入数据(用于记录)
Args:
state: 工作流状态
Returns:
输入数据字典
"""
pool = self.get_variable_pool(state)
return {
"execution_id": pool.get("sys.execution_id"),
"conversation_id": pool.get("sys.conversation_id"),
"message": pool.get("sys.message"),
"conversation_vars": pool.get_all_conversation_vars()
}

View File

@@ -0,0 +1,6 @@
"""Transform 节点"""
from app.core.workflow.nodes.transform.node import TransformNode
from app.core.workflow.nodes.transform.config import TransformNodeConfig
__all__ = ["TransformNode", "TransformNodeConfig"]

View File

@@ -0,0 +1,80 @@
"""Transform 节点配置"""
from typing import Literal
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
class TransformNodeConfig(BaseNodeConfig):
"""Transform 节点配置
用于数据转换和处理。
"""
transform_type: Literal["template", "code", "json"] = Field(
default="template",
description="转换类型template(模板), code(代码), json(JSON处理)"
)
# 模板模式
template: str | None = Field(
default=None,
description="转换模板,支持变量引用"
)
# 代码模式
code: str | None = Field(
default=None,
description="Python 代码,用于数据转换"
)
# JSON 模式
json_path: str | None = Field(
default=None,
description="JSON 路径表达式"
)
# 输入变量
inputs: dict[str, str] | None = Field(
default=None,
description="输入变量映射key 为变量名value 为变量选择器"
)
# 输出变量
output_key: str = Field(
default="result",
description="输出变量的键名"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="result",
type=VariableType.STRING,
description="转换后的结果"
)
],
description="输出变量定义(根据 output_key 动态生成)"
)
class Config:
json_schema_extra = {
"examples": [
{
"transform_type": "template",
"template": "用户问题:{{ sys.message }}\n回答:{{ llm_qa.output }}",
"output_key": "formatted_result"
},
{
"transform_type": "code",
"code": "result = input_text.upper()",
"inputs": {
"input_text": "{{ sys.message }}"
},
"output_key": "uppercase_text"
}
]
}

View File

@@ -0,0 +1,60 @@
"""
Transform 节点实现
数据转换节点,用于处理和转换数据。
"""
import logging
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
logger = logging.getLogger(__name__)
class TransformNode(BaseNode):
"""数据转换节点
配置示例:
{
"type": "transform",
"config": {
"mapping": {
"output_field": "{{node.previous.output}}",
"processed": "{{var.input | upper}}"
}
}
}
"""
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""执行数据转换
Args:
state: 工作流状态
Returns:
状态更新字典
"""
logger.info(f"节点 {self.node_id} 开始执行数据转换")
# 获取映射配置
mapping = self.config.get("mapping", {})
# 执行数据转换
transformed_data = {}
for target_key, source_template in mapping.items():
# 渲染模板获取值
value = self._render_template(str(source_template), state)
transformed_data[target_key] = value
logger.info(f"节点 {self.node_id} 数据转换完成,输出字段: {list(transformed_data.keys())}")
return {
"node_outputs": {
self.node_id: {
"output": transformed_data,
"status": "completed"
}
}
}

View File

@@ -0,0 +1,170 @@
"""
工作流模板加载器
从文件系统加载预定义的工作流模板
"""
import os
import yaml
from pathlib import Path
from typing import Optional
class TemplateLoader:
"""工作流模板加载器"""
def __init__(self, templates_dir: str = "app/templates/workflows"):
"""初始化模板加载器
Args:
templates_dir: 模板目录路径
"""
self.templates_dir = Path(templates_dir)
if not self.templates_dir.exists():
raise ValueError(f"模板目录不存在: {templates_dir}")
def list_templates(self) -> list[dict]:
"""列出所有可用的模板
Returns:
模板列表,每个模板包含 id, name, description 等信息
"""
templates = []
# 遍历模板目录
for template_dir in self.templates_dir.iterdir():
if not template_dir.is_dir():
continue
# 检查是否有 template.yml 文件
template_file = template_dir / "template.yml"
if not template_file.exists():
continue
try:
# 读取模板配置
with open(template_file, 'r', encoding='utf-8') as f:
template_data = yaml.safe_load(f)
# 提取模板信息
templates.append({
"id": template_dir.name,
"name": template_data.get("name", template_dir.name),
"description": template_data.get("description", ""),
"category": template_data.get("category", "general"),
"tags": template_data.get("tags", []),
"author": template_data.get("author", ""),
"version": template_data.get("version", "1.0.0")
})
except Exception as e:
print(f"加载模板 {template_dir.name} 失败: {e}")
continue
return templates
def load_template(self, template_id: str) -> Optional[dict]:
"""加载指定的模板
Args:
template_id: 模板 ID目录名
Returns:
模板配置字典,如果模板不存在则返回 None
"""
template_dir = self.templates_dir / template_id
template_file = template_dir / "template.yml"
if not template_file.exists():
return None
try:
with open(template_file, 'r', encoding='utf-8') as f:
template_data = yaml.safe_load(f)
# 返回工作流配置部分
return {
"name": template_data.get("name", template_id),
"description": template_data.get("description", ""),
"nodes": template_data.get("nodes", []),
"edges": template_data.get("edges", []),
"variables": template_data.get("variables", []),
"execution_config": template_data.get("execution_config", {}),
"triggers": template_data.get("triggers", [])
}
except Exception as e:
print(f"加载模板 {template_id} 失败: {e}")
return None
def get_template_readme(self, template_id: str) -> Optional[str]:
"""获取模板的 README 文档
Args:
template_id: 模板 ID
Returns:
README 内容,如果不存在则返回 None
"""
template_dir = self.templates_dir / template_id
readme_file = template_dir / "README.md"
if not readme_file.exists():
return None
try:
with open(readme_file, 'r', encoding='utf-8') as f:
return f.read()
except Exception as e:
print(f"读取模板 {template_id} 的 README 失败: {e}")
return None
# 全局模板加载器实例
_template_loader: Optional[TemplateLoader] = None
def get_template_loader() -> TemplateLoader:
"""获取全局模板加载器实例
Returns:
TemplateLoader 实例
"""
global _template_loader
if _template_loader is None:
_template_loader = TemplateLoader()
return _template_loader
def list_workflow_templates() -> list[dict]:
"""列出所有工作流模板
Returns:
模板列表
"""
loader = get_template_loader()
return loader.list_templates()
def load_workflow_template(template_id: str) -> Optional[dict]:
"""加载工作流模板
Args:
template_id: 模板 ID
Returns:
模板配置,如果不存在则返回 None
"""
loader = get_template_loader()
return loader.load_template(template_id)
def get_workflow_template_readme(template_id: str) -> Optional[str]:
"""获取工作流模板的 README
Args:
template_id: 模板 ID
Returns:
README 内容,如果不存在则返回 None
"""
loader = get_template_loader()
return loader.get_template_readme(template_id)

View File

@@ -0,0 +1,170 @@
"""
模板渲染器
使用 Jinja2 提供安全的模板渲染功能,支持变量引用和表达式。
"""
import logging
from typing import Any
from jinja2 import Template, TemplateSyntaxError, UndefinedError, Environment, StrictUndefined
logger = logging.getLogger(__name__)
class TemplateRenderer:
"""模板渲染器"""
def __init__(self, strict: bool = True):
"""初始化渲染器
Args:
strict: 是否使用严格模式(未定义变量会抛出异常)
"""
self.env = Environment(
undefined=StrictUndefined if strict else None,
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
)
def render(
self,
template: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> str:
"""渲染模板
Args:
template: 模板字符串
variables: 用户定义的变量
node_outputs: 节点输出结果
system_vars: 系统变量
Returns:
渲染后的字符串
Raises:
ValueError: 模板语法错误或变量未定义
Examples:
>>> renderer = TemplateRenderer()
>>> renderer.render(
... "Hello {{var.name}}!",
... {"name": "World"},
... {},
... {}
... )
'Hello World!'
>>> renderer.render(
... "分析结果: {{node.analyze.output}}",
... {},
... {"analyze": {"output": "正面情绪"}},
... {}
... )
'分析结果: 正面情绪'
"""
# 构建命名空间上下文
context = {
"var": variables, # 用户变量:{{var.user_input}}
"node": node_outputs, # 节点输出:{{node.node_1.output}}
"sys": system_vars or {}, # 系统变量:{{sys.execution_id}}
}
# 支持直接通过节点ID访问节点输出{{llm_qa.output}}
# 将所有节点输出添加到顶层上下文
context.update(node_outputs)
# 为了向后兼容,也支持直接访问用户变量
context.update(variables)
context["nodes"] = node_outputs # 旧语法兼容
try:
tmpl = self.env.from_string(template)
return tmpl.render(**context)
except TemplateSyntaxError as e:
logger.error(f"模板语法错误: {template}, 错误: {e}")
raise ValueError(f"模板语法错误: {e}")
except UndefinedError as e:
logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}")
raise ValueError(f"未定义的变量: {e}")
except Exception as e:
logger.error(f"模板渲染异常: {template}, 错误: {e}")
raise ValueError(f"模板渲染失败: {e}")
def validate(self, template: str) -> list[str]:
"""验证模板语法
Args:
template: 模板字符串
Returns:
错误列表,如果为空则验证通过
Examples:
>>> renderer = TemplateRenderer()
>>> renderer.validate("Hello {{var.name}}!")
[]
>>> renderer.validate("Hello {{var.name") # 缺少结束标记
['模板语法错误: ...']
"""
errors = []
try:
self.env.from_string(template)
except TemplateSyntaxError as e:
errors.append(f"模板语法错误: {e}")
except Exception as e:
errors.append(f"模板验证失败: {e}")
return errors
# 全局渲染器实例(严格模式)
_default_renderer = TemplateRenderer(strict=True)
def render_template(
template: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> str:
"""渲染模板(便捷函数)
Args:
template: 模板字符串
variables: 用户变量
node_outputs: 节点输出
system_vars: 系统变量
Returns:
渲染后的字符串
Examples:
>>> render_template(
... "请分析: {{var.text}}",
... {"text": "这是一段文本"},
... {},
... {}
... )
'请分析: 这是一段文本'
"""
return _default_renderer.render(template, variables, node_outputs, system_vars)
def validate_template(template: str) -> list[str]:
"""验证模板语法(便捷函数)
Args:
template: 模板字符串
Returns:
错误列表
"""
return _default_renderer.validate(template)

View File

@@ -0,0 +1,277 @@
"""
工作流配置验证器
验证工作流配置的有效性,确保配置符合规范。
"""
import logging
from typing import Any, Union
logger = logging.getLogger(__name__)
class WorkflowValidator:
"""工作流配置验证器"""
@staticmethod
def validate(workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
"""验证工作流配置
Args:
workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型
Returns:
(is_valid, errors): 是否有效和错误列表
Examples:
>>> config = {
... "nodes": [
... {"id": "start", "type": "start"},
... {"id": "end", "type": "end"}
... ],
... "edges": [
... {"source": "start", "target": "end"}
... ]
... }
>>> is_valid, errors = WorkflowValidator.validate(config)
>>> is_valid
True
"""
errors = []
# 支持字典和 Pydantic 模型
if isinstance(workflow_config, dict):
nodes = workflow_config.get("nodes", [])
edges = workflow_config.get("edges", [])
variables = workflow_config.get("variables", [])
else:
# Pydantic 模型
nodes = getattr(workflow_config, "nodes", [])
edges = getattr(workflow_config, "edges", [])
variables = getattr(workflow_config, "variables", [])
# 1. 验证 start 节点(有且只有一个)
start_nodes = [n for n in nodes if n.get("type") == "start"]
if len(start_nodes) == 0:
errors.append("工作流必须有一个 start 节点")
elif len(start_nodes) > 1:
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)}")
# 2. 验证 end 节点(至少一个)
end_nodes = [n for n in nodes if n.get("type") == "end"]
if len(end_nodes) == 0:
errors.append("工作流必须至少有一个 end 节点")
# 3. 验证节点 ID 唯一性
node_ids = [n.get("id") for n in nodes]
if len(node_ids) != len(set(node_ids)):
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")
# 4. 验证节点必须有 id 和 type
for i, node in enumerate(nodes):
if not node.get("id"):
errors.append(f"节点 #{i} 缺少 id 字段")
if not node.get("type"):
errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段")
# 5. 验证边的有效性
node_id_set = set(node_ids)
for i, edge in enumerate(edges):
source = edge.get("source")
target = edge.get("target")
if not source:
errors.append(f"边 #{i} 缺少 source 字段")
elif source not in node_id_set:
errors.append(f"边 #{i} 的 source 节点不存在: {source}")
if not target:
errors.append(f"边 #{i} 缺少 target 字段")
elif target not in node_id_set:
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
# 6. 验证所有节点可达(从 start 节点出发)
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
reachable = WorkflowValidator._get_reachable_nodes(
start_nodes[0]["id"],
edges
)
unreachable = node_id_set - reachable
if unreachable:
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
# 7. 检测循环依赖(非 loop 节点)
if not errors: # 只有在前面验证通过时才检查循环
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
if has_cycle:
errors.append(
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
)
# 8. 验证变量名
from app.core.workflow.expression_evaluator import ExpressionEvaluator
var_errors = ExpressionEvaluator.validate_variable_names(variables)
errors.extend(var_errors)
return len(errors) == 0, errors
@staticmethod
def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
"""获取从 start 节点可达的所有节点
Args:
start_id: 起始节点 ID
edges: 边列表
Returns:
可达节点 ID 集合
"""
reachable = {start_id}
queue = [start_id]
while queue:
current = queue.pop(0)
for edge in edges:
if edge.get("source") == current:
target = edge.get("target")
if target and target not in reachable:
reachable.add(target)
queue.append(target)
return reachable
@staticmethod
def _has_cycle(nodes: list[dict], edges: list[dict]) -> tuple[bool, list[str]]:
"""检测是否存在循环依赖DFS
Args:
nodes: 节点列表
edges: 边列表
Returns:
(has_cycle, cycle_path): 是否有循环和循环路径
"""
# 排除 loop 类型的节点
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
# 构建邻接表(排除 loop 节点的边和错误边)
graph: dict[str, list[str]] = {}
for edge in edges:
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
# 跳过错误边
if edge_type == "error":
continue
# 如果涉及 loop 节点,跳过
if source in loop_nodes or target in loop_nodes:
continue
if source and target:
if source not in graph:
graph[source] = []
graph[source].append(target)
# DFS 检测环
visited = set()
rec_stack = set()
path = []
cycle_path = []
def dfs(node: str) -> bool:
"""DFS 检测环,返回是否找到环"""
visited.add(node)
rec_stack.add(node)
path.append(node)
for neighbor in graph.get(node, []):
if neighbor not in visited:
if dfs(neighbor):
return True
elif neighbor in rec_stack:
# 找到环,记录环路径
cycle_start = path.index(neighbor)
cycle_path.extend([*path[cycle_start:], neighbor])
return True
rec_stack.remove(node)
path.pop()
return False
# 检查所有节点
for node_id in graph:
if node_id not in visited:
if dfs(node_id):
return True, cycle_path
return False, []
@staticmethod
def validate_for_publish(workflow_config: dict[str, Any]) -> tuple[bool, list[str]]:
"""验证工作流配置是否可以发布(更严格的验证)
Args:
workflow_config: 工作流配置
Returns:
(is_valid, errors): 是否有效和错误列表
"""
# 先执行基础验证
is_valid, errors = WorkflowValidator.validate(workflow_config)
if not is_valid:
return False, errors
# 额外的发布验证
nodes = workflow_config.get("nodes", [])
# 1. 验证所有节点都有名称
for node in nodes:
if node.get("type") not in ["start", "end"] and not node.get("name"):
errors.append(
f"节点 {node.get('id')} 缺少名称(发布时必须提供)"
)
# 2. 验证所有非 start/end 节点都有配置
for node in nodes:
node_type = node.get("type")
if node_type not in ["start", "end"]:
config = node.get("config")
if not config or not isinstance(config, dict):
errors.append(
f"节点 {node.get('id')} 缺少配置(发布时必须提供)"
)
# 3. 验证必填变量
variables = workflow_config.get("variables", [])
required_vars = [v for v in variables if v.get("required")]
if required_vars:
# 这里只是提示,实际执行时会检查
logger.info(
f"工作流包含 {len(required_vars)} 个必填变量: "
f"{[v.get('name') for v in required_vars]}"
)
return len(errors) == 0, errors
def validate_workflow_config(
workflow_config: dict[str, Any],
for_publish: bool = False
) -> tuple[bool, list[str]]:
"""验证工作流配置(便捷函数)
Args:
workflow_config: 工作流配置
for_publish: 是否为发布验证(更严格)
Returns:
(is_valid, errors): 是否有效和错误列表
"""
if for_publish:
return WorkflowValidator.validate_for_publish(workflow_config)
else:
return WorkflowValidator.validate(workflow_config)

View File

@@ -0,0 +1,293 @@
"""
变量池 (Variable Pool)
工作流执行的数据中心,管理所有变量的存储和访问。
变量类型:
1. 系统变量 (sys.*) - 系统内置变量execution_id, workspace_id, user_id, message 等)
2. 节点输出 (node_id.*) - 节点执行结果
3. 会话变量 (conv.*) - 会话级变量(跨多轮对话保持)
"""
import logging
from typing import Any
logger = logging.getLogger(__name__)
class VariableSelector:
"""变量选择器
用于引用变量的路径表示。
Examples:
>>> selector = VariableSelector(["sys", "message"])
>>> selector = VariableSelector(["node_A", "output"])
>>> selector = VariableSelector.from_string("sys.message")
"""
def __init__(self, path: list[str]):
"""初始化变量选择器
Args:
path: 变量路径,如 ["sys", "message"] 或 ["node_A", "output"]
"""
if not path or len(path) < 1:
raise ValueError("变量路径不能为空")
self.path = path
self.namespace = path[0] # sys, var, 或 node_id
self.key = path[1] if len(path) > 1 else None
@classmethod
def from_string(cls, selector_str: str) -> "VariableSelector":
"""从字符串创建选择器
Args:
selector_str: 选择器字符串,如 "sys.message""node_A.output"
Returns:
VariableSelector 实例
Examples:
>>> selector = VariableSelector.from_string("sys.message")
>>> selector = VariableSelector.from_string("llm_qa.output")
"""
path = selector_str.split(".")
return cls(path)
def __str__(self) -> str:
return ".".join(self.path)
def __repr__(self) -> str:
return f"VariableSelector({self.path})"
class VariablePool:
"""变量池
管理工作流执行过程中的所有变量。
变量命名空间:
- sys.*: 系统变量message, execution_id, workspace_id, user_id, conversation_id
- conv.*: 会话变量(跨多轮对话保持的变量)
- <node_id>.*: 节点输出
Examples:
>>> pool = VariablePool(state)
>>> pool.get(["sys", "message"])
"用户的问题"
>>> pool.get(["llm_qa", "output"])
"AI 的回答"
>>> pool.set(["conv", "user_name"], "张三")
"""
def __init__(self, state: dict[str, Any]):
"""初始化变量池
Args:
state: 工作流状态LangGraph State
"""
self.state = state
def get(self, selector: list[str] | str, default: Any = None) -> Any:
"""获取变量值
Args:
selector: 变量选择器,可以是列表或字符串
default: 默认值(变量不存在时返回)
Returns:
变量值
Examples:
>>> pool.get(["sys", "message"])
>>> pool.get("sys.message")
>>> pool.get(["llm_qa", "output"])
>>> pool.get("llm_qa.output")
Raises:
KeyError: 变量不存在且未提供默认值
"""
# 转换为 VariableSelector
if isinstance(selector, str):
selector = VariableSelector.from_string(selector).path
if not selector or len(selector) < 1:
raise ValueError("变量选择器不能为空")
namespace = selector[0]
try:
# 系统变量
if namespace == "sys":
key = selector[1] if len(selector) > 1 else None
if not key:
return self.state.get("variables", {}).get("sys", {})
return self.state.get("variables", {}).get("sys", {}).get(key, default)
# 会话变量
elif namespace == "conv":
key = selector[1] if len(selector) > 1 else None
if not key:
return self.state.get("variables", {}).get("conv", {})
return self.state.get("variables", {}).get("conv", {}).get(key, default)
# 节点输出(从 runtime_vars 读取)
else:
node_id = namespace
runtime_vars = self.state.get("runtime_vars", {})
if node_id not in runtime_vars:
if default is not None:
return default
raise KeyError(f"节点 '{node_id}' 的输出不存在")
node_var = runtime_vars[node_id]
# 如果只有节点 ID返回整个变量
if len(selector) == 1:
return node_var
# 获取特定字段
# 支持嵌套访问,如 node_id.field.subfield
result = node_var
for k in selector[1:]:
if isinstance(result, dict):
result = result.get(k)
if result is None:
if default is not None:
return default
raise KeyError(f"字段 '{'.'.join(selector)}' 不存在")
else:
if default is not None:
return default
raise KeyError(f"无法访问 '{'.'.join(selector)}'")
return result
except KeyError:
if default is not None:
return default
raise
def set(self, selector: list[str] | str, value: Any):
"""设置变量值
Args:
selector: 变量选择器
value: 变量值
Examples:
>>> pool.set(["conv", "user_name"], "张三")
>>> pool.set("conv.user_name", "张三")
Note:
- 只能设置会话变量 (conv.*)
- 系统变量和节点输出是只读的
"""
# 转换为 VariableSelector
if isinstance(selector, str):
selector = VariableSelector.from_string(selector).path
if not selector or len(selector) < 2:
raise ValueError("变量选择器必须包含命名空间和键名")
namespace = selector[0]
if namespace != "conv":
raise ValueError("只能设置会话变量 (conv.*)")
key = selector[1]
# 确保 variables 结构存在
if "variables" not in self.state:
self.state["variables"] = {"sys": {}, "conv": {}}
if "conv" not in self.state["variables"]:
self.state["variables"]["conv"] = {}
# 设置值
self.state["variables"]["conv"][key] = value
logger.debug(f"设置变量: {'.'.join(selector)} = {value}")
def has(self, selector: list[str] | str) -> bool:
"""检查变量是否存在
Args:
selector: 变量选择器
Returns:
变量是否存在
Examples:
>>> pool.has(["sys", "message"])
True
>>> pool.has("llm_qa.output")
False
"""
try:
self.get(selector)
return True
except KeyError:
return False
def get_all_system_vars(self) -> dict[str, Any]:
"""获取所有系统变量
Returns:
系统变量字典
"""
return self.state.get("variables", {}).get("sys", {})
def get_all_conversation_vars(self) -> dict[str, Any]:
"""获取所有会话变量
Returns:
会话变量字典
"""
return self.state.get("variables", {}).get("conv", {})
def get_all_node_outputs(self) -> dict[str, Any]:
"""获取所有节点输出(运行时变量)
Returns:
节点输出字典,键为节点 ID
"""
return self.state.get("runtime_vars", {})
def get_node_output(self, node_id: str) -> dict[str, Any] | None:
"""获取指定节点的输出(运行时变量)
Args:
node_id: 节点 ID
Returns:
节点输出或 None
"""
return self.state.get("runtime_vars", {}).get(node_id)
def to_dict(self) -> dict[str, Any]:
"""导出为字典
Returns:
包含所有变量的字典
"""
return {
"system": self.get_all_system_vars(),
"conversation": self.get_all_conversation_vars(),
"nodes": self.get_all_node_outputs() # 从 runtime_vars 读取
}
def __repr__(self) -> str:
sys_vars = self.get_all_system_vars()
conv_vars = self.get_all_conversation_vars()
runtime_vars = self.get_all_node_outputs()
return (
f"VariablePool(\n"
f" system_vars={len(sys_vars)},\n"
f" conversation_vars={len(conv_vars)},\n"
f" runtime_vars={len(runtime_vars)}\n"
f")"
)

View File

@@ -1,10 +1,9 @@
import os
import subprocess
from dotenv import load_dotenv
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from app.core.config import settings
from contextlib import asynccontextmanager
from fastapi.responses import JSONResponse
from app.core.response_utils import fail
from app.core.logging_config import LoggingConfig, get_logger
@@ -38,9 +37,13 @@ router = APIRouter(prefix="/memory", tags=["Memory"])
# 管理端 API (JWT 认证)
from app.controllers import manager_router
# 服务端 API (API Key 认证)
from app.controllers.service import service_router
from app.core.config import settings
from app.core.error_codes import BizCode, HTTP_MAPPING
from app.core.exceptions import BusinessException
from app.core.logging_config import LoggingConfig, get_logger
from app.core.response_utils import fail
# Initialize logging system
LoggingConfig.setup_logging()
@@ -414,5 +417,4 @@ async def unhandled_exception_handler(request: Request, exc: Exception):
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -15,9 +15,11 @@ from .end_user_model import EndUser
from .appshare_model import AppShare
from .release_share_model import ReleaseShare
from .conversation_model import Conversation, Message
from .api_key_model import ApiKey, ApiKeyLog, ApiKeyType, ResourceType
from .api_key_model import ApiKey, ApiKeyLog, ApiKeyType
from .data_config_model import DataConfig
from .multi_agent_model import MultiAgentConfig, AgentInvocation
from .workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution
from .retrieval_info import RetrievalInfo
__all__ = [
"Tenants",
@@ -46,8 +48,11 @@ __all__ = [
"ApiKey",
"ApiKeyLog",
"ApiKeyType",
"ResourceType",
"DataConfig",
"MultiAgentConfig",
"AgentInvocation"
"AgentInvocation",
"WorkflowConfig",
"WorkflowExecution",
"WorkflowNodeExecution",
"RetrievalInfo"
]

View File

@@ -2,7 +2,7 @@
import datetime
import uuid
from sqlalchemy import Column, String, Boolean, DateTime, Integer, ForeignKey, Text, Enum
from sqlalchemy import Column, String, Boolean, DateTime, Integer, ForeignKey, Text
from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy.orm import relationship
from enum import StrEnum
@@ -12,18 +12,10 @@ from app.db import Base
class ApiKeyType(StrEnum):
"""API Key 类型"""
APP = "app" # 应用 API Key
RAG = "rag" # RAG API Key
MEMORY = "memory" # Memory API Key
class ResourceType(StrEnum):
"""资源类型枚举"""
AGENT = "Agent" # 智能体
CLUSTER = "Cluster" # 集群
WORKFLOW = "Workflow" # 工作流
KNOWLEDGE = "Knowledge" # 知识库
MEMORY_ENGINE = "Memory_Engine" # 记忆引擎
AGENT = "agent" # 智能体
CLUSTER = "cluster" # 集群
WORKFLOW = "workflow" # 工作流
SERVICE = "service" # 服务
class ApiKey(Base):
@@ -35,18 +27,16 @@ class ApiKey(Base):
# 基本信息
name = Column(String(255), nullable=False, comment="API Key 名称")
description = Column(Text, comment="描述")
key_prefix = Column(String(20), nullable=False, comment="Key 前缀")
key_hash = Column(String(255), nullable=False, unique=True, index=True, comment="Key 哈希值")
api_key = Column(String(255), nullable=False, unique=True, index=True, comment="API Key 明文")
# 类型和权限
type = Column(String(50), nullable=False, index=True, comment="API Key 类型")
scopes = Column(JSONB, nullable=False, default=list, comment="权限范围列表")
scopes = Column(JSONB, default=list, comment="权限范围列表")
# 关联资源
workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="CASCADE"), nullable=False,
index=True, comment="所属工作空间")
resource_id = Column(UUID(as_uuid=True), index=True, comment="关联资源ID")
resource_type = Column(String(50), comment="资源类型")
# 限制和配额
rate_limit = Column(Integer, default=10, comment="QPS限制请求/秒)")

View File

@@ -86,6 +86,14 @@ class App(Base):
uselist=False,
cascade="all, delete-orphan",
)
# 一对一:工作流配置(仅当 type=workflow 时有效)
workflow_config = relationship(
"WorkflowConfig",
back_populates="app",
uselist=False,
cascade="all, delete-orphan",
)
# 发布版本关联
current_release = relationship("AppRelease", foreign_keys=[current_release_id])

View File

@@ -0,0 +1,196 @@
"""
工作流相关数据模型
"""
import datetime
import uuid
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float, ForeignKey, Text
from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy.orm import relationship
from app.db import Base
class WorkflowConfig(Base):
"""工作流配置表"""
__tablename__ = "workflow_configs"
# 主键
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
# 关联应用(一对一)
app_id = Column(
UUID(as_uuid=True),
ForeignKey("apps.id", ondelete="CASCADE"),
nullable=False,
unique=True,
index=True
)
# 节点和边的定义JSON 格式)
nodes = Column(JSONB, nullable=False, default=list)
edges = Column(JSONB, nullable=False, default=list)
# 全局变量定义
variables = Column(JSONB, default=list)
# 执行配置
execution_config = Column(JSONB, nullable=False, default=dict)
# 触发器配置(可选)
triggers = Column(JSONB, default=list)
# 状态
is_active = Column(Boolean, nullable=False, default=True)
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
updated_at = Column(
DateTime,
nullable=False,
default=datetime.datetime.now,
onupdate=datetime.datetime.now
)
# 关系
app = relationship("App", back_populates="workflow_config")
executions = relationship(
"WorkflowExecution",
back_populates="workflow_config",
cascade="all, delete-orphan"
)
def __repr__(self):
return f"<WorkflowConfig(id={self.id}, app_id={self.app_id})>"
class WorkflowExecution(Base):
"""工作流执行记录表"""
__tablename__ = "workflow_executions"
# 主键
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
# 关联信息
workflow_config_id = Column(
UUID(as_uuid=True),
ForeignKey("workflow_configs.id", ondelete="CASCADE"),
nullable=False,
index=True
)
app_id = Column(
UUID(as_uuid=True),
ForeignKey("apps.id", ondelete="CASCADE"),
nullable=False,
index=True
)
conversation_id = Column(
UUID(as_uuid=True),
ForeignKey("conversations.id", ondelete="SET NULL"),
nullable=True,
index=True
)
# 执行信息
execution_id = Column(String(100), nullable=False, unique=True, index=True)
trigger_type = Column(String(20), nullable=False) # manual, schedule, webhook, event
triggered_by = Column(
UUID(as_uuid=True),
ForeignKey("users.id"),
nullable=True
)
# 输入输出
input_data = Column(JSONB)
output_data = Column(JSONB)
context = Column(JSONB, default=dict)
# 状态
status = Column(String(20), nullable=False, default="pending", index=True)
# 可选值pending, running, completed, failed, cancelled, timeout
error_message = Column(Text)
error_node_id = Column(String(100))
# 性能指标
started_at = Column(DateTime, nullable=False, default=datetime.datetime.now, index=True)
completed_at = Column(DateTime)
elapsed_time = Column(Float) # 耗时(秒)
# 资源使用
token_usage = Column(JSONB)
# 元数据(使用 meta_data 避免与 SQLAlchemy 保留字 metadata 冲突)
meta_data = Column(JSONB, default=dict)
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
# 关系
workflow_config = relationship("WorkflowConfig", back_populates="executions")
app = relationship("App")
conversation = relationship("Conversation")
triggered_by_user = relationship("User", foreign_keys=[triggered_by])
node_executions = relationship(
"WorkflowNodeExecution",
back_populates="execution",
cascade="all, delete-orphan",
order_by="WorkflowNodeExecution.execution_order"
)
def __repr__(self):
return f"<WorkflowExecution(id={self.id}, execution_id={self.execution_id}, status={self.status})>"
class WorkflowNodeExecution(Base):
"""工作流节点执行记录表"""
__tablename__ = "workflow_node_executions"
# 主键
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
# 关联执行
execution_id = Column(
UUID(as_uuid=True),
ForeignKey("workflow_executions.id", ondelete="CASCADE"),
nullable=False,
index=True
)
# 节点信息
node_id = Column(String(100), nullable=False, index=True)
node_type = Column(String(20), nullable=False)
node_name = Column(String(100))
# 执行顺序
execution_order = Column(Integer, nullable=False)
retry_count = Column(Integer, nullable=False, default=0)
# 输入输出
input_data = Column(JSONB)
output_data = Column(JSONB)
# 状态
status = Column(String(20), nullable=False, default="pending", index=True)
# 可选值pending, running, completed, failed, skipped, cached
error_message = Column(Text)
# 性能指标
started_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
completed_at = Column(DateTime)
elapsed_time = Column(Float) # 耗时(秒)
# 资源使用(针对 LLM 节点)
token_usage = Column(JSONB)
# 缓存信息
cache_hit = Column(Boolean, default=False)
cache_key = Column(String(255))
# 元数据(使用 meta_data 避免与 SQLAlchemy 保留字 metadata 冲突)
meta_data = Column(JSONB, default=dict)
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
# 关系
execution = relationship("WorkflowExecution", back_populates="node_executions")
def __repr__(self):
return f"<WorkflowNodeExecution(id={self.id}, node_id={self.node_id}, status={self.status})>"

View File

@@ -27,9 +27,9 @@ class ApiKeyRepository:
return db.get(ApiKey, api_key_id)
@staticmethod
def get_by_hash(db: Session, key_hash: str) -> Optional[ApiKey]:
"""根据哈希值获取 API Key"""
stmt = select(ApiKey).where(ApiKey.key_hash == key_hash)
def get_by_api_key(db: Session, api_key: str) -> Optional[ApiKey]:
"""根据 API Key 获取 API Key"""
stmt = select(ApiKey).where(ApiKey.api_key == api_key)
return db.scalars(stmt).first()
@staticmethod
@@ -63,11 +63,15 @@ class ApiKeyRepository:
@staticmethod
def update(db: Session, api_key_id: uuid.UUID, update_data: dict) -> ApiKey | None:
"""更新 API Key"""
allow_none_fields = {"description", "quota_limit", "expires_at"}
api_key = db.get(ApiKey, api_key_id)
if api_key:
for key, value in update_data.items():
if value is not None:
if key in allow_none_fields:
setattr(api_key, key, value)
else:
if value is not None:
setattr(api_key, key, value)
db.flush()
return api_key

View File

@@ -14,7 +14,7 @@ class AppRepository:
def __init__(self, db: Session):
self.db = db
def get_apps_by_workspace_id(self, workspace_id: uuid.UUID) -> List[App]:
def get_apps_by_workspace_id(self, workspace_id: uuid.UUID) -> list[App]:
"""根据工作空间ID查询应用"""
try:
apps = self.db.query(App).filter(App.workspace_id == workspace_id).all()
@@ -24,7 +24,19 @@ class AppRepository:
db_logger.error(f"查询工作空间 {workspace_id} 下应用时出错: {str(e)}")
raise
def get_apps_by_id(self, app_id: uuid.UUID) -> App:
try:
app = self.db.query(App).filter(App.id == app_id, App.is_active == True).first()
return app
except Exception as e:
raise
def get_apps_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> List[App]:
"""根据工作空间ID查询应用"""
repo = AppRepository(db)
return repo.get_apps_by_workspace_id(workspace_id)
def get_apps_by_id(db: Session, app_id: uuid.UUID) -> App:
"""根据工作空间ID查询应用"""
repo = AppRepository(db)
return repo.get_apps_by_id(app_id)

View File

@@ -0,0 +1,247 @@
"""
工作流数据访问层
"""
import uuid
from typing import Any, Annotated
from sqlalchemy.orm import Session
from sqlalchemy import desc
from fastapi import Depends
from app.models.workflow_model import (
WorkflowConfig,
WorkflowExecution,
WorkflowNodeExecution
)
from app.db import get_db
class WorkflowConfigRepository:
"""工作流配置仓储"""
def __init__(self, db: Session):
self.db = db
def get_by_app_id(self, app_id: uuid.UUID) -> WorkflowConfig | None:
"""根据应用 ID 获取工作流配置
Args:
app_id: 应用 ID
Returns:
工作流配置或 None
"""
return self.db.query(WorkflowConfig).filter(
WorkflowConfig.app_id == app_id,
WorkflowConfig.is_active == True
).first()
def create_or_update(
self,
app_id: uuid.UUID,
nodes: list[dict[str, Any]],
edges: list[dict[str, Any]],
variables: list[dict[str, Any]] | None = None,
execution_config: dict[str, Any] | None = None,
triggers: list[dict[str, Any]] | None = None
) -> WorkflowConfig:
"""创建或更新工作流配置
Args:
app_id: 应用 ID
nodes: 节点列表
edges: 边列表
variables: 变量列表
execution_config: 执行配置
triggers: 触发器列表
Returns:
工作流配置
"""
# 查找现有配置
existing = self.get_by_app_id(app_id)
if existing:
# 更新现有配置
existing.nodes = nodes
existing.edges = edges
if variables is not None:
existing.variables = variables
if execution_config is not None:
existing.execution_config = execution_config
if triggers is not None:
existing.triggers = triggers
self.db.commit()
self.db.refresh(existing)
return existing
else:
# 创建新配置
config = WorkflowConfig(
app_id=app_id,
nodes=nodes,
edges=edges,
variables=variables or [],
execution_config=execution_config or {},
triggers=triggers or []
)
self.db.add(config)
self.db.commit()
self.db.refresh(config)
return config
class WorkflowExecutionRepository:
"""工作流执行记录仓储"""
def __init__(self, db: Session):
self.db = db
def get_by_execution_id(self, execution_id: str) -> WorkflowExecution | None:
"""根据执行 ID 获取执行记录
Args:
execution_id: 执行 ID
Returns:
执行记录或 None
"""
return self.db.query(WorkflowExecution).filter(
WorkflowExecution.execution_id == execution_id
).first()
def get_by_app_id(
self,
app_id: uuid.UUID,
limit: int = 50,
offset: int = 0
) -> list[WorkflowExecution]:
"""根据应用 ID 获取执行记录列表
Args:
app_id: 应用 ID
limit: 返回数量限制
offset: 偏移量
Returns:
执行记录列表
"""
return self.db.query(WorkflowExecution).filter(
WorkflowExecution.app_id == app_id
).order_by(
desc(WorkflowExecution.started_at)
).limit(limit).offset(offset).all()
def get_by_conversation_id(
self,
conversation_id: uuid.UUID
) -> list[WorkflowExecution]:
"""根据会话 ID 获取执行记录列表
Args:
conversation_id: 会话 ID
Returns:
执行记录列表
"""
return self.db.query(WorkflowExecution).filter(
WorkflowExecution.conversation_id == conversation_id
).order_by(
desc(WorkflowExecution.started_at)
).all()
def count_by_app_id(self, app_id: uuid.UUID) -> int:
"""统计应用的执行次数
Args:
app_id: 应用 ID
Returns:
执行次数
"""
return self.db.query(WorkflowExecution).filter(
WorkflowExecution.app_id == app_id
).count()
def count_by_status(self, app_id: uuid.UUID, status: str) -> int:
"""统计指定状态的执行次数
Args:
app_id: 应用 ID
status: 状态
Returns:
执行次数
"""
return self.db.query(WorkflowExecution).filter(
WorkflowExecution.app_id == app_id,
WorkflowExecution.status == status
).count()
class WorkflowNodeExecutionRepository:
"""工作流节点执行记录仓储"""
def __init__(self, db: Session):
self.db = db
def get_by_execution_id(
self,
execution_id: uuid.UUID
) -> list[WorkflowNodeExecution]:
"""根据执行 ID 获取节点执行记录列表
Args:
execution_id: 执行 ID
Returns:
节点执行记录列表(按执行顺序排序)
"""
return self.db.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.execution_id == execution_id
).order_by(
WorkflowNodeExecution.execution_order
).all()
def get_by_node_id(
self,
execution_id: uuid.UUID,
node_id: str
) -> list[WorkflowNodeExecution]:
"""根据节点 ID 获取节点执行记录(可能有多次重试)
Args:
execution_id: 执行 ID
node_id: 节点 ID
Returns:
节点执行记录列表
"""
return self.db.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.execution_id == execution_id,
WorkflowNodeExecution.node_id == node_id
).order_by(
WorkflowNodeExecution.retry_count
).all()
# ==================== 依赖注入函数 ====================
def get_workflow_config_repository(
db: Annotated[Session, Depends(get_db)]
) -> WorkflowConfigRepository:
"""获取工作流配置仓储(依赖注入)"""
return WorkflowConfigRepository(db)
def get_workflow_execution_repository(
db: Annotated[Session, Depends(get_db)]
) -> WorkflowExecutionRepository:
"""获取工作流执行记录仓储(依赖注入)"""
return WorkflowExecutionRepository(db)
def get_workflow_node_execution_repository(
db: Annotated[Session, Depends(get_db)]
) -> WorkflowNodeExecutionRepository:
"""获取工作流节点执行记录仓储(依赖注入)"""
return WorkflowNodeExecutionRepository(db)

View File

@@ -1,11 +1,11 @@
"""API Key Schema"""
import datetime
import uuid
from pydantic import BaseModel, Field, ConfigDict
from pydantic.v1 import validator
from pydantic import BaseModel, Field, ConfigDict, field_validator, field_serializer, computed_field
from typing import Optional, List
from app.models.api_key_model import ApiKeyType, ResourceType
from app.models.api_key_model import ApiKeyType
from app.core.api_key_utils import timestamp_to_datetime, datetime_to_timestamp
class ApiKeyCreate(BaseModel):
@@ -15,20 +15,34 @@ class ApiKeyCreate(BaseModel):
type: ApiKeyType = Field(..., description="API Key 类型")
scopes: List[str] = Field(default_factory=list, description="权限范围列表")
resource_id: Optional[uuid.UUID] = Field(None, description="关联资源ID")
resource_type: Optional[ResourceType] = Field(None, description="资源类型")
rate_limit: Optional[int] = Field(10, ge=1, le=1000, description="QPS限制请求/秒)")
daily_request_limit: Optional[int] = Field(10000, description="日请求限制", ge=1)
quota_limit: Optional[int] = Field(None, description="配额限制(总请求数)", ge=1)
expires_at: Optional[datetime.datetime] = Field(None, description="过期时间")
@validator('scopes')
@computed_field
@property
def is_expired(self) -> bool:
"""检查API Key是否已过期"""
if not self.expires_at:
return False
return datetime.datetime.now() > self.expires_at
@field_validator('expires_at', mode='before')
@classmethod
def parse_expires_at(cls, v):
"""将时间戳转换为datetime"""
if isinstance(v, (int, float)):
return timestamp_to_datetime(v)
return v
@field_validator('scopes')
@classmethod
def validate_scopes(cls, v):
"""验证权限范围格式"""
valid_scopes = [
"app:all",
"rag:search", "rag:upload", "rag:delete",
"memory:read", "memory:write", "memory:delete", "memory:search"
]
if v is None:
return []
valid_scopes = ["app", "rag", "memory"]
for scope in v:
if scope not in valid_scopes:
raise ValueError(f"无效范围: {scope}")
@@ -46,14 +60,29 @@ class ApiKeyUpdate(BaseModel):
is_active: Optional[bool] = Field(None, description="是否激活")
expires_at: Optional[datetime.datetime] = Field(None, description="过期时间")
@validator('scopes')
@computed_field
@property
def is_expired(self) -> bool:
"""检查API Key是否已过期"""
if not self.expires_at:
return False
return datetime.datetime.now() > self.expires_at
@field_validator('expires_at', mode='before')
@classmethod
def parse_expires_at(cls, v):
"""将时间戳转换为datetime"""
if isinstance(v, (int, float)):
return timestamp_to_datetime(v)
return v
@field_validator('scopes')
@classmethod
def validate_scopes(cls, v):
"""验证权限范围格式"""
valid_scopes = {
'app:all',
'rag:search', 'rag:upload', 'rag:delete',
'memory:read', 'memory:write', 'memory:delete', 'memory:search'
}
if v is None:
return v
valid_scopes = ["app", "rag", "memory"]
for scope in v:
if scope not in valid_scopes:
raise ValueError(f"无效范围: {scope}")
@@ -67,18 +96,31 @@ class ApiKeyResponse(BaseModel):
id: uuid.UUID
name: str
description: Optional[str]
api_key: str = Field(..., description="API Key 明文(仅创建时返回)")
key_prefix: str
api_key: str
type: str
scopes: List[str]
resource_id: Optional[uuid.UUID]
resource_type: Optional[str]
rate_limit: int
daily_request_limit: int
quota_limit: Optional[int]
is_active: bool
expires_at: Optional[datetime.datetime]
created_at: datetime.datetime
@computed_field
@property
def is_expired(self) -> bool:
"""检查API Key是否已过期"""
if not self.expires_at:
return False
return datetime.datetime.now() > self.expires_at
@field_serializer('expires_at', 'created_at')
@classmethod
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
"""将datetime转换为时间戳"""
return datetime_to_timestamp(v)
class ApiKey(BaseModel):
"""API Key 信息(不包含明文 Key"""
@@ -87,11 +129,10 @@ class ApiKey(BaseModel):
id: uuid.UUID
name: str
description: Optional[str]
key_prefix: str
api_key: str
type: str
scopes: List[str]
resource_id: Optional[uuid.UUID]
resource_type: Optional[str]
rate_limit: int
daily_request_limit: int
quota_limit: Optional[int]
@@ -105,6 +146,20 @@ class ApiKey(BaseModel):
created_at: datetime.datetime
updated_at: datetime.datetime
@computed_field
@property
def is_expired(self) -> bool:
"""检查API Key是否已过期"""
if not self.expires_at:
return False
return datetime.datetime.now() > self.expires_at
@field_serializer('expires_at', 'last_used_at', 'created_at', 'updated_at')
@classmethod
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
"""将datetime转换为时间戳"""
return datetime_to_timestamp(v)
class ApiKeyStats(BaseModel):
"""API Key 使用统计"""
@@ -115,6 +170,12 @@ class ApiKeyStats(BaseModel):
last_used_at: Optional[datetime.datetime] = Field(None, description="最后使用时间")
avg_response_time: Optional[float] = Field(None, description="平均响应时间(毫秒)")
@field_serializer('last_used_at')
@classmethod
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
"""将datetime转换为时间戳"""
return datetime_to_timestamp(v)
class ApiKeyQuery(BaseModel):
"""API Key 查询参数"""
@@ -132,7 +193,6 @@ class ApiKeyAuth(BaseModel):
type: str
scopes: List[str]
resource_id: Optional[uuid.UUID]
resource_type: Optional[str]
class ApiKeyLog(BaseModel):
@@ -157,3 +217,9 @@ class ApiKeyLog(BaseModel):
# 时间信息
created_at: datetime.datetime
@field_serializer('created_at')
@classmethod
def serialize_datetime(cls, v: datetime.datetime) -> int:
"""将datetime转换为时间戳"""
return datetime_to_timestamp(v)

View File

@@ -0,0 +1,215 @@
"""
工作流相关的 Pydantic Schema
"""
import datetime
import uuid
from typing import Any
from pydantic import BaseModel, Field, ConfigDict, field_serializer
# ==================== 节点和边定义 ====================
class NodeConfig(BaseModel):
"""节点配置"""
model_config = ConfigDict(extra="allow") # 允许额外字段
class NodeDefinition(BaseModel):
"""节点定义"""
id: str = Field(..., description="节点唯一标识")
type: str = Field(..., description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code")
name: str | None = Field(None, description="节点名称")
description: str | None = Field(None, description="节点描述")
config: dict[str, Any] = Field(default_factory=dict, description="节点配置")
position: dict[str, float] | None = Field(None, description="节点位置 {x, y}")
error_handling: dict[str, Any] | None = Field(None, description="错误处理配置")
cache: dict[str, Any] | None = Field(None, description="缓存配置")
class EdgeDefinition(BaseModel):
"""边定义"""
id: str | None = Field(None, description="边唯一标识(可选)")
source: str = Field(..., description="源节点 ID")
target: str = Field(..., description="目标节点 ID")
type: str | None = Field(None, description="边类型: normal, error")
condition: str | None = Field(None, description="条件表达式(条件边)")
label: str | None = Field(None, description="边标签")
class VariableDefinition(BaseModel):
"""变量定义"""
name: str = Field(..., description="变量名称")
type: str = Field(default="string", description="变量类型: string, number, boolean, object, array")
required: bool = Field(default=False, description="是否必填")
default: Any = Field(None, description="默认值")
description: str | None = Field(None, description="变量描述")
class ExecutionConfig(BaseModel):
"""执行配置"""
max_iterations: int = Field(default=100, ge=1, le=1000, description="最大迭代次数")
timeout: int = Field(default=600, ge=10, le=3600, description="全局超时时间(秒)")
enable_cache: bool = Field(default=True, description="是否启用节点缓存")
parallel_limit: int = Field(default=5, ge=1, le=20, description="并行执行限制")
class TriggerConfig(BaseModel):
"""触发器配置"""
type: str = Field(..., description="触发器类型: schedule, webhook, event")
config: dict[str, Any] = Field(default_factory=dict, description="触发器配置")
# ==================== 工作流配置 ====================
class WorkflowConfigCreate(BaseModel):
"""创建工作流配置"""
nodes: list[NodeDefinition] = Field(default_factory=list, description="节点列表")
edges: list[EdgeDefinition] = Field(default_factory=list, description="边列表")
variables: list[VariableDefinition] = Field(default_factory=list, description="变量列表")
execution_config: ExecutionConfig = Field(default_factory=ExecutionConfig, description="执行配置")
triggers: list[TriggerConfig] = Field(default_factory=list, description="触发器列表")
class WorkflowConfigUpdate(BaseModel):
"""更新工作流配置"""
nodes: list[NodeDefinition] | None = None
edges: list[EdgeDefinition] | None = None
variables: list[VariableDefinition] | None = None
execution_config: ExecutionConfig | None = None
triggers: list[TriggerConfig] | None = None
class WorkflowConfig(BaseModel):
"""工作流配置输出"""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
app_id: uuid.UUID
nodes: list[dict[str, Any]]
edges: list[dict[str, Any]]
variables: list[dict[str, Any]]
execution_config: dict[str, Any]
triggers: list[dict[str, Any]]
is_active: bool
created_at: datetime.datetime
updated_at: datetime.datetime
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("updated_at", when_used="json")
def _serialize_updated_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
# ==================== 工作流执行 ====================
class WorkflowExecutionRequest(BaseModel):
"""工作流执行请求"""
message: str | None = Field(None, description="用户消息(可选)")
variables: dict[str, Any] = Field(default_factory=dict, description="输入变量")
conversation_id: str | None = Field(None, description="会话 ID用于关联对话")
stream: bool = Field(default=False, description="是否流式返回")
class WorkflowExecutionResponse(BaseModel):
"""工作流执行响应(非流式)"""
execution_id: str = Field(..., description="执行 ID")
status: str = Field(..., description="执行状态")
output: str | None = Field(None, description="最终输出(字符串,便于快速访问)")
output_data: dict[str, Any] | None = Field(None, description="所有节点的详细输出数据")
error_message: str | None = Field(None, description="错误信息")
elapsed_time: float | None = Field(None, description="耗时(秒)")
token_usage: dict[str, Any] | None = Field(None, description="Token 使用情况 {prompt_tokens, completion_tokens, total_tokens}")
class WorkflowExecutionStreamChunk(BaseModel):
"""工作流执行流式响应块"""
type: str = Field(..., description="事件类型: node_start, token, node_complete, error_redirect, workflow_complete")
execution_id: str = Field(..., description="执行 ID")
data: dict[str, Any] = Field(default_factory=dict, description="事件数据")
class WorkflowExecution(BaseModel):
"""工作流执行记录输出"""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
workflow_config_id: uuid.UUID
app_id: uuid.UUID
conversation_id: uuid.UUID | None
execution_id: str
trigger_type: str
triggered_by: uuid.UUID | None
input_data: dict[str, Any] | None
output_data: dict[str, Any] | None
context: dict[str, Any]
status: str
error_message: str | None
error_node_id: str | None
started_at: datetime.datetime
completed_at: datetime.datetime | None
elapsed_time: float | None
token_usage: dict[str, Any] | None
meta_data: dict[str, Any]
created_at: datetime.datetime
@field_serializer("started_at", when_used="json")
def _serialize_started_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("completed_at", when_used="json")
def _serialize_completed_at(self, dt: datetime.datetime | None):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
class WorkflowNodeExecution(BaseModel):
"""工作流节点执行记录输出"""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
execution_id: uuid.UUID
node_id: str
node_type: str
node_name: str | None
execution_order: int
retry_count: int
input_data: dict[str, Any] | None
output_data: dict[str, Any] | None
status: str
error_message: str | None
started_at: datetime.datetime
completed_at: datetime.datetime | None
elapsed_time: float | None
token_usage: dict[str, Any] | None
cache_hit: bool
cache_key: str | None
meta_data: dict[str, Any]
created_at: datetime.datetime
@field_serializer("started_at", when_used="json")
def _serialize_started_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("completed_at", when_used="json")
def _serialize_completed_at(self, dt: datetime.datetime | None):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
# ==================== 验证响应 ====================
class WorkflowValidationResponse(BaseModel):
"""工作流验证响应"""
is_valid: bool = Field(..., description="是否有效")
errors: list[str] = Field(default_factory=list, description="错误列表")
warnings: list[str] = Field(default_factory=list, description="警告列表")

View File

@@ -13,7 +13,7 @@ from app.models.api_key_model import ApiKey
from app.repositories.api_key_repository import ApiKeyRepository, ApiKeyLogRepository
from app.schemas import api_key_schema
from app.schemas.response_schema import PageData, PageMeta
from app.core.api_key_utils import generate_api_key, hash_api_key, validate_resource_binding
from app.core.api_key_utils import generate_api_key
from app.core.exceptions import (
BusinessException,
)
@@ -33,21 +33,13 @@ class ApiKeyService:
workspace_id: uuid.UUID,
user_id: uuid.UUID,
data: api_key_schema.ApiKeyCreate
) -> Tuple[ApiKey, str]:
) -> ApiKey:
"""
创建 API Key
Returns:
Tuple[ApiKey, str]: (API Key 对象, API Key 明文)
ApiKey: API Key 对象
"""
try:
# 验证资源绑定
if data.resource_type or data.resource_id:
is_valid, error_msg = validate_resource_binding(
data.resource_type, str(data.resource_id) if data.resource_id else None
)
if not is_valid:
raise BusinessException(error_msg, BizCode.API_KEY_INVALID_RESOURCE)
existing = db.scalar(
select(ApiKey).where(
ApiKey.workspace_id == workspace_id,
@@ -59,22 +51,20 @@ class ApiKeyService:
raise BusinessException(f"API Key 名称 '{data.name}' 已存在", BizCode.API_KEY_DUPLICATE_NAME)
# 生成 API Key
api_key, key_hash, key_prefix = generate_api_key(data.type)
api_key = generate_api_key(data.type)
# 创建数据
api_key_data = {
"id": uuid.uuid4(),
"name": data.name,
"description": data.description,
"key_prefix": key_prefix,
"key_hash": key_hash,
"api_key": api_key,
"type": data.type,
"scopes": data.scopes,
"workspace_id": workspace_id,
"resource_id": data.resource_id,
"resource_type": data.resource_type,
"rate_limit": data.rate_limit or 10,
"daily_request_limit": data.daily_request_limit or 10000,
"rate_limit": data.rate_limit,
"daily_request_limit": data.daily_request_limit,
"quota_limit": data.quota_limit,
"expires_at": data.expires_at,
"created_by": user_id,
@@ -90,7 +80,7 @@ class ApiKeyService:
"type": data.type
})
return api_key_obj, api_key
return api_key_obj
except Exception as e:
db.rollback()
@@ -147,6 +137,9 @@ class ApiKeyService:
"""更新 API Key配置"""
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
if not api_key:
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
# 检查名称重复
if data.name and data.name != api_key.name:
existing = db.scalar(
@@ -177,6 +170,9 @@ class ApiKeyService:
"""删除 API Key"""
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
if not api_key:
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
ApiKeyRepository.delete(db, api_key_id)
db.commit()
@@ -188,27 +184,29 @@ class ApiKeyService:
db: Session,
api_key_id: uuid.UUID,
workspace_id: uuid.UUID
) -> Tuple[ApiKey, str]:
) -> ApiKey:
"""重新生成 API Key"""
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
if not api_key:
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
# 检查 API Key 是否激活
if not api_key.is_active:
raise BusinessException("无法重新生成已停用的 API Key", BizCode.API_KEY_INACTIVE)
# 生成新的 API Key
new_api_key, key_hash, key_prefix = generate_api_key(api_key_schema.ApiKeyType(api_key.type))
new_api_key = generate_api_key(api_key.type)
# 更新
ApiKeyRepository.update(db, api_key_id, {
"key_hash": key_hash,
"key_prefix": key_prefix
"api_key": new_api_key
})
db.commit()
db.refresh(api_key)
logger.info("API Key 重新生成成功", extra={"api_key_id": str(api_key_id)})
return api_key, new_api_key
return api_key
@staticmethod
def get_stats(
@@ -219,6 +217,9 @@ class ApiKeyService:
"""获取使用统计"""
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
if not api_key:
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
stats_data = ApiKeyRepository.get_stats(db, api_key_id)
return api_key_schema.ApiKeyStats(**stats_data)
@@ -235,6 +236,9 @@ class ApiKeyService:
# 验证 API Key 权限
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
if not api_key:
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
items, total = ApiKeyLogRepository.list_by_api_key(
db, api_key_id, filters, page, pagesize
)
@@ -330,7 +334,6 @@ class RateLimiterService:
"X-RateLimit-Reset": str(qps_info["reset"])
}
# Check daily requests
daily_ok, daily_info = await self.check_daily_requests(
api_key.id,
api_key.daily_request_limit
@@ -342,7 +345,6 @@ class RateLimiterService:
"X-RateLimit-Reset": str(daily_info["reset"])
}
# All checks passed
headers = {
"X-RateLimit-Limit-QPS": str(qps_info["limit"]),
"X-RateLimit-Remaining-QPS": str(qps_info["remaining"]),
@@ -363,13 +365,12 @@ class ApiKeyAuthService:
验证API Key 有效性
检查:
1. Key hash 是否存在
1. API Key 是否存在
2. is_active 是否为true
3. expires_at 是否未过期
4. quota 是否未超限
"""
key_hash = hash_api_key(api_key)
api_key_obj = ApiKeyRepository.get_by_hash(db, key_hash)
api_key_obj = ApiKeyRepository.get_by_api_key(db, api_key)
if not api_key_obj:
return None
@@ -393,14 +394,7 @@ class ApiKeyAuthService:
@staticmethod
def check_resource(
api_key: ApiKey,
resource_type: str,
resource_id: uuid.UUID
) -> bool:
"""检查资源绑定"""
if not api_key.resource_id:
return True
return (
api_key.resource_type == resource_type and
api_key.resource_id == resource_id
)
return api_key.resource_id == resource_id

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,731 @@
"""
工作流服务层
"""
import logging
import uuid
import datetime
from typing import Any, Annotated
from sqlalchemy.orm import Session
from fastapi import Depends
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
from app.repositories.workflow_repository import (
WorkflowConfigRepository,
WorkflowExecutionRepository,
WorkflowNodeExecutionRepository,
get_workflow_config_repository,
get_workflow_execution_repository,
get_workflow_node_execution_repository
)
from app.core.workflow.validator import validate_workflow_config
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.db import get_db
from app.schemas import DraftRunRequest
logger = logging.getLogger(__name__)
class WorkflowService:
"""工作流服务"""
def __init__(self, db: Session):
self.db = db
self.config_repo = WorkflowConfigRepository(db)
self.execution_repo = WorkflowExecutionRepository(db)
self.node_execution_repo = WorkflowNodeExecutionRepository(db)
# ==================== 配置管理 ====================
def create_workflow_config(
self,
app_id: uuid.UUID,
nodes: list[dict[str, Any]],
edges: list[dict[str, Any]],
variables: list[dict[str, Any]] | None = None,
execution_config: dict[str, Any] | None = None,
triggers: list[dict[str, Any]] | None = None,
validate: bool = True
) -> WorkflowConfig:
"""创建工作流配置
Args:
app_id: 应用 ID
nodes: 节点列表
edges: 边列表
variables: 变量列表
execution_config: 执行配置
triggers: 触发器列表
validate: 是否验证配置
Returns:
工作流配置
Raises:
BusinessException: 配置无效时抛出
"""
# 构建配置字典
config_dict = {
"nodes": nodes,
"edges": edges,
"variables": variables or [],
"execution_config": execution_config or {},
"triggers": triggers or []
}
# 验证配置
if validate:
is_valid, errors = validate_workflow_config(config_dict, for_publish=False)
if not is_valid:
logger.warning(f"工作流配置验证失败: {errors}")
raise BusinessException(
error_code=BizCode.INVALID_PARAMETER,
message=f"工作流配置无效: {'; '.join(errors)}"
)
# 创建或更新配置
config = self.config_repo.create_or_update(
app_id=app_id,
nodes=nodes,
edges=edges,
variables=variables,
execution_config=execution_config,
triggers=triggers
)
logger.info(f"创建工作流配置成功: app_id={app_id}, config_id={config.id}")
return config
def get_workflow_config(self, app_id: uuid.UUID) -> WorkflowConfig | None:
"""获取工作流配置
Args:
app_id: 应用 ID
Returns:
工作流配置或 None
"""
return self.config_repo.get_by_app_id(app_id)
def update_workflow_config(
self,
app_id: uuid.UUID,
nodes: list[dict[str, Any]] | None = None,
edges: list[dict[str, Any]] | None = None,
variables: list[dict[str, Any]] | None = None,
execution_config: dict[str, Any] | None = None,
triggers: list[dict[str, Any]] | None = None,
validate: bool = True
) -> WorkflowConfig:
"""更新工作流配置
Args:
app_id: 应用 ID
nodes: 节点列表
edges: 边列表
variables: 变量列表
execution_config: 执行配置
triggers: 触发器列表
validate: 是否验证配置
Returns:
工作流配置
Raises:
BusinessException: 配置不存在或无效时抛出
"""
# 获取现有配置
config = self.get_workflow_config(app_id)
if not config:
raise BusinessException(
error_code=BizCode.RESOURCE_NOT_FOUND,
message=f"工作流配置不存在: app_id={app_id}"
)
# 合并配置
updated_nodes = nodes if nodes is not None else config.nodes
updated_edges = edges if edges is not None else config.edges
updated_variables = variables if variables is not None else config.variables
updated_execution_config = execution_config if execution_config is not None else config.execution_config
updated_triggers = triggers if triggers is not None else config.triggers
# 构建配置字典
config_dict = {
"nodes": updated_nodes,
"edges": updated_edges,
"variables": updated_variables,
"execution_config": updated_execution_config,
"triggers": updated_triggers
}
# 验证配置
if validate:
is_valid, errors = validate_workflow_config(config_dict, for_publish=False)
if not is_valid:
logger.warning(f"工作流配置验证失败: {errors}")
raise BusinessException(
error_code=BizCode.INVALID_PARAMETER,
message=f"工作流配置无效: {'; '.join(errors)}"
)
# 更新配置
config = self.config_repo.create_or_update(
app_id=app_id,
nodes=updated_nodes,
edges=updated_edges,
variables=updated_variables,
execution_config=updated_execution_config,
triggers=updated_triggers
)
logger.info(f"更新工作流配置成功: app_id={app_id}, config_id={config.id}")
return config
def delete_workflow_config(self, app_id: uuid.UUID) -> bool:
"""删除工作流配置
Args:
app_id: 应用 ID
Returns:
是否删除成功
"""
config = self.get_workflow_config(app_id)
if not config:
return False
self.config_repo.delete(config.id)
logger.info(f"删除工作流配置成功: app_id={app_id}, config_id={config.id}")
return True
def check_config(self, app_id: uuid.UUID) -> WorkflowConfig:
"""检查工作流配置的完整性
Args:
app_id: 应用 ID
Raises:
BusinessException: 配置不完整或不存在时抛出
"""
# 1. 检查多智能体配置是否存在
config = self.get_workflow_config(app_id)
if not config:
raise BusinessException(
"工作流配置不存在,无法运行",
BizCode.CONFIG_MISSING
)
# validator 现在支持直接接受 Pydantic 模型
is_valid, errors = validate_workflow_config(config, for_publish=False)
if not is_valid:
logger.warning(f"工作流配置验证失败: {errors}")
raise BusinessException(
code=BizCode.INVALID_PARAMETER,
message=f"工作流配置无效: {'; '.join(errors)}"
)
return config
def validate_workflow_config_for_publish(
self,
app_id: uuid.UUID
) -> tuple[bool, list[str]]:
"""验证工作流配置是否可以发布
Args:
app_id: 应用 ID
Returns:
(is_valid, errors): 是否有效和错误列表
Raises:
BusinessException: 配置不存在时抛出
"""
config = self.get_workflow_config(app_id)
if not config:
raise BusinessException(
error_code=BizCode.RESOURCE_NOT_FOUND,
message=f"工作流配置不存在: app_id={app_id}"
)
config_dict = {
"nodes": config.nodes,
"edges": config.edges,
"variables": config.variables,
"execution_config": config.execution_config,
"triggers": config.triggers
}
return validate_workflow_config(config_dict, for_publish=True)
# ==================== 执行管理 ====================
def create_execution(
self,
workflow_config_id: uuid.UUID,
app_id: uuid.UUID,
trigger_type: str,
triggered_by: uuid.UUID | None = None,
conversation_id: uuid.UUID | None = None,
input_data: dict[str, Any] | None = None
) -> WorkflowExecution:
"""创建工作流执行记录
Args:
workflow_config_id: 工作流配置 ID
app_id: 应用 ID
trigger_type: 触发类型
triggered_by: 触发用户 ID
conversation_id: 会话 ID
input_data: 输入数据
Returns:
执行记录
"""
# 生成执行 ID
execution_id = f"exec_{uuid.uuid4().hex[:16]}"
execution = WorkflowExecution(
workflow_config_id=workflow_config_id,
app_id=app_id,
conversation_id=conversation_id,
execution_id=execution_id,
trigger_type=trigger_type,
triggered_by=triggered_by,
input_data=input_data or {},
status="pending"
)
self.db.add(execution)
self.db.commit()
self.db.refresh(execution)
logger.info(f"创建工作流执行记录: execution_id={execution_id}")
return execution
def get_execution(self, execution_id: str) -> WorkflowExecution | None:
"""获取执行记录
Args:
execution_id: 执行 ID
Returns:
执行记录或 None
"""
return self.execution_repo.get_by_execution_id(execution_id)
def get_executions_by_app(
self,
app_id: uuid.UUID,
limit: int = 50,
offset: int = 0
) -> list[WorkflowExecution]:
"""获取应用的执行记录列表
Args:
app_id: 应用 ID
limit: 返回数量限制
offset: 偏移量
Returns:
执行记录列表
"""
return self.execution_repo.get_by_app_id(app_id, limit, offset)
def update_execution_status(
self,
execution_id: str,
status: str,
output_data: dict[str, Any] | None = None,
error_message: str | None = None,
error_node_id: str | None = None
) -> WorkflowExecution:
"""更新执行状态
Args:
execution_id: 执行 ID
status: 状态
output_data: 输出数据
error_message: 错误信息
error_node_id: 出错节点 ID
Returns:
执行记录
Raises:
BusinessException: 执行记录不存在时抛出
"""
execution = self.get_execution(execution_id)
if not execution:
raise BusinessException(
error_code=BizCode.RESOURCE_NOT_FOUND,
message=f"执行记录不存在: execution_id={execution_id}"
)
execution.status = status
if output_data is not None:
execution.output_data = output_data
if error_message is not None:
execution.error_message = error_message
if error_node_id is not None:
execution.error_node_id = error_node_id
# 如果是完成状态,计算耗时
if status in ["completed", "failed", "cancelled", "timeout"]:
if not execution.completed_at:
execution.completed_at = datetime.datetime.now()
elapsed = (execution.completed_at - execution.started_at).total_seconds()
execution.elapsed_time = elapsed
self.db.commit()
self.db.refresh(execution)
logger.info(f"更新执行状态: execution_id={execution_id}, status={status}")
return execution
def get_execution_statistics(self, app_id: uuid.UUID) -> dict[str, Any]:
"""获取执行统计信息
Args:
app_id: 应用 ID
Returns:
统计信息
"""
total = self.execution_repo.count_by_app_id(app_id)
completed = self.execution_repo.count_by_status(app_id, "completed")
failed = self.execution_repo.count_by_status(app_id, "failed")
running = self.execution_repo.count_by_status(app_id, "running")
return {
"total": total,
"completed": completed,
"failed": failed,
"running": running,
"success_rate": completed / total if total > 0 else 0
}
# ==================== 工作流执行 ====================
async def run(
self,
app_id: uuid.UUID,
payload: DraftRunRequest,
config: WorkflowConfig
):
"""运行工作流
Args:
app_id: 应用 ID
input_data: 输入数据(包含 message 和 variables
triggered_by: 触发用户 ID
conversation_id: 会话 ID可选
stream: 是否流式返回
Returns:
执行结果(非流式)或生成器(流式)
Raises:
BusinessException: 配置不存在或执行失败时抛出
"""
# 1. 获取工作流配置
if not config:
config = self.get_workflow_config(app_id)
if not config:
raise BusinessException(
code=BizCode.CONFIG_MISSING,
message=f"工作流配置不存在: app_id={app_id}"
)
input_data = {"message": payload.message, "variables": payload.variables, "conversation_id": payload.conversation_id}
# 转换 user_id 为 UUID
triggered_by_uuid = None
if payload.user_id:
try:
triggered_by_uuid = uuid.UUID(payload.user_id)
except (ValueError, AttributeError):
logger.warning(f"无效的 user_id 格式: {payload.user_id}")
# 转换 conversation_id 为 UUID
conversation_id_uuid = None
if payload.conversation_id:
try:
conversation_id_uuid = uuid.UUID(payload.conversation_id)
except (ValueError, AttributeError):
logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}")
# 2. 创建执行记录
execution = self.create_execution(
workflow_config_id=config.id,
app_id=app_id,
trigger_type="manual",
triggered_by=triggered_by_uuid,
conversation_id=conversation_id_uuid,
input_data=input_data
)
# 3. 构建工作流配置字典
workflow_config_dict = {
"nodes": config.nodes,
"edges": config.edges,
"variables": config.variables,
"execution_config": config.execution_config
}
# 4. 获取工作空间 ID从 app 获取)
from app.models import App
# 5. 执行工作流
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
try:
# 更新状态为运行中
self.update_execution_status(execution.execution_id, "running")
result = await execute_workflow(
workflow_config=workflow_config_dict,
input_data=input_data,
execution_id=execution.execution_id,
workspace_id="",
user_id=payload.user_id
)
# 更新执行结果
if result.get("status") == "completed":
self.update_execution_status(
execution.execution_id,
"completed",
output_data=result.get("node_outputs", {})
)
else:
self.update_execution_status(
execution.execution_id,
"failed",
error_message=result.get("error")
)
# 返回增强的响应结构
return {
"execution_id": execution.execution_id,
"status": result.get("status"),
"output": result.get("output"), # 最终输出(字符串)
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
"conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID
"error_message": result.get("error"),
"elapsed_time": result.get("elapsed_time"),
"token_usage": result.get("token_usage")
}
except Exception as e:
logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
self.update_execution_status(
execution.execution_id,
"failed",
error_message=str(e)
)
raise BusinessException(
code=BizCode.INTERNAL_ERROR,
message=f"工作流执行失败: {str(e)}"
)
async def run_workflow(
self,
app_id: uuid.UUID,
input_data: dict[str, Any],
triggered_by: uuid.UUID,
conversation_id: uuid.UUID | None = None,
stream: bool = False
):
"""运行工作流
Args:
app_id: 应用 ID
input_data: 输入数据(包含 message 和 variables
triggered_by: 触发用户 ID
conversation_id: 会话 ID可选
stream: 是否流式返回
Returns:
执行结果(非流式)或生成器(流式)
Raises:
BusinessException: 配置不存在或执行失败时抛出
"""
# 1. 获取工作流配置
config = self.get_workflow_config(app_id)
if not config:
raise BusinessException(
error_code=BizCode.RESOURCE_NOT_FOUND,
message=f"工作流配置不存在: app_id={app_id}"
)
# 2. 创建执行记录
execution = self.create_execution(
workflow_config_id=config.id,
app_id=app_id,
trigger_type="manual",
triggered_by=triggered_by,
conversation_id=conversation_id,
input_data=input_data
)
# 3. 构建工作流配置字典
workflow_config_dict = {
"nodes": config.nodes,
"edges": config.edges,
"variables": config.variables,
"execution_config": config.execution_config
}
# 4. 获取工作空间 ID从 app 获取)
from app.models import App
app = self.db.query(App).filter(App.id == app_id).first()
if not app:
raise BusinessException(
error_code=BizCode.RESOURCE_NOT_FOUND,
message=f"应用不存在: app_id={app_id}"
)
# 5. 执行工作流
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
try:
# 更新状态为运行中
self.update_execution_status(execution.execution_id, "running")
if stream:
# 流式执行
return self._run_workflow_stream(
workflow_config_dict,
input_data,
execution.execution_id,
str(app.workspace_id),
str(triggered_by)
)
else:
# 非流式执行
result = await execute_workflow(
workflow_config=workflow_config_dict,
input_data=input_data,
execution_id=execution.execution_id,
workspace_id=str(app.workspace_id),
user_id=str(triggered_by)
)
# 更新执行结果
if result.get("status") == "completed":
self.update_execution_status(
execution.execution_id,
"completed",
output_data=result.get("node_outputs", {})
)
else:
self.update_execution_status(
execution.execution_id,
"failed",
error_message=result.get("error")
)
# 返回增强的响应结构
return {
"execution_id": execution.execution_id,
"status": result.get("status"),
"output": result.get("output"), # 最终输出(字符串)
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
"error_message": result.get("error"),
"elapsed_time": result.get("elapsed_time"),
"token_usage": result.get("token_usage")
}
except Exception as e:
logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
self.update_execution_status(
execution.execution_id,
"failed",
error_message=str(e)
)
raise BusinessException(
error_code=BizCode.INTERNAL_ERROR,
message=f"工作流执行失败: {str(e)}"
)
async def _run_workflow_stream(
self,
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
):
"""运行工作流(流式,内部方法)
Args:
workflow_config: 工作流配置
input_data: 输入数据
execution_id: 执行 ID
workspace_id: 工作空间 ID
user_id: 用户 ID
Yields:
流式事件
"""
from app.core.workflow.executor import execute_workflow_stream
try:
output_data = {}
async for event in execute_workflow_stream(
workflow_config=workflow_config,
input_data=input_data,
execution_id=execution_id,
workspace_id=workspace_id,
user_id=user_id
):
# 转发事件
yield event
# 收集输出数据
if event.get("type") == "node_complete":
node_data = event.get("data", {})
node_outputs = node_data.get("node_outputs", {})
output_data.update(node_outputs)
# 处理完成事件
if event.get("type") == "workflow_complete":
self.update_execution_status(
execution_id,
"completed",
output_data=output_data
)
# 处理错误事件
if event.get("type") == "workflow_error":
self.update_execution_status(
execution_id,
"failed",
error_message=event.get("error")
)
except Exception as e:
logger.error(f"工作流流式执行失败: execution_id={execution_id}, error={e}", exc_info=True)
self.update_execution_status(
execution_id,
"failed",
error_message=str(e)
)
yield {
"type": "workflow_error",
"execution_id": execution_id,
"error": str(e)
}
# ==================== 依赖注入函数 ====================
def get_workflow_service(
db: Annotated[Session, Depends(get_db)]
) -> WorkflowService:
"""获取工作流服务(依赖注入)"""
return WorkflowService(db)

View File

@@ -0,0 +1,219 @@
# 智能客服工作流模板
id: customer_service_v1
name: 智能客服工作流
description: 智能客服场景,包含意图识别、知识库查询和回复生成
category: customer_service
version: "1.0.0"
author: RedBear Memory Team
tags:
- 客服
- 意图识别
- 知识库
- 多步骤
# 工作流配置
nodes:
- id: start
type: start
name: 开始
position:
x: 100
y: 200
- id: intent_recognition
type: llm
name: 意图识别
config:
prompt: |
分析用户的问题,识别意图类型。
用户问题:{{ var.user_message }}
请从以下类型中选择一个:
- product_inquiry: 产品咨询
- technical_support: 技术支持
- complaint: 投诉建议
- other: 其他
只返回类型名称,不要其他内容。
model: gpt-3.5-turbo
temperature: 0.3
max_tokens: 50
position:
x: 300
y: 200
- id: intent_router
type: condition
name: 意图路由
position:
x: 500
y: 200
- id: product_handler
type: llm
name: 产品咨询处理
config:
prompt: |
用户咨询产品相关问题。
问题:{{ var.user_message }}
意图:{{ node.intent_recognition.output }}
请提供专业、友好的产品咨询回复。
model: gpt-3.5-turbo
temperature: 0.7
max_tokens: 500
position:
x: 700
y: 100
- id: support_handler
type: llm
name: 技术支持处理
config:
prompt: |
用户需要技术支持。
问题:{{ var.user_message }}
意图:{{ node.intent_recognition.output }}
请提供详细的技术支持方案。
model: gpt-3.5-turbo
temperature: 0.5
max_tokens: 800
position:
x: 700
y: 200
- id: complaint_handler
type: llm
name: 投诉处理
config:
prompt: |
用户提出投诉或建议。
问题:{{ var.user_message }}
意图:{{ node.intent_recognition.output }}
请以同理心回应,并提供解决方案。
model: gpt-3.5-turbo
temperature: 0.8
max_tokens: 600
position:
x: 700
y: 300
- id: general_handler
type: llm
name: 通用处理
config:
prompt: |
用户的问题类型:其他
问题:{{ var.user_message }}
请提供友好的回复。
model: gpt-3.5-turbo
temperature: 0.7
max_tokens: 400
position:
x: 700
y: 400
- id: end
type: end
name: 结束
position:
x: 900
y: 200
edges:
- source: start
target: intent_recognition
label: 开始分析
- source: intent_recognition
target: intent_router
label: 识别完成
- source: intent_router
target: product_handler
condition: "'product_inquiry' in node['intent_recognition']['output']"
label: 产品咨询
- source: intent_router
target: support_handler
condition: "'technical_support' in node['intent_recognition']['output']"
label: 技术支持
- source: intent_router
target: complaint_handler
condition: "'complaint' in node['intent_recognition']['output']"
label: 投诉建议
- source: intent_router
target: general_handler
condition: "True" # 默认路径
label: 其他
- source: product_handler
target: end
label: 完成
- source: support_handler
target: end
label: 完成
- source: complaint_handler
target: end
label: 完成
- source: general_handler
target: end
label: 完成
# 变量定义
variables:
- name: user_message
type: string
required: true
description: 用户的消息
default: ""
- name: user_name
type: string
required: false
description: 用户姓名(可选)
default: "客户"
# 执行配置
execution_config:
max_execution_time: 120
max_iterations: 10
# 触发器
triggers: []
# 使用示例
examples:
- name: 产品咨询
description: 用户咨询产品功能
input:
user_message: "你们的产品支持多语言吗?"
user_name: "张三"
expected_output: "产品功能介绍"
- name: 技术支持
description: 用户遇到技术问题
input:
user_message: "我无法登录系统,一直显示密码错误"
user_name: "李四"
expected_output: "技术支持方案"
- name: 投诉处理
description: 用户提出投诉
input:
user_message: "你们的服务态度太差了,我要投诉"
user_name: "王五"
expected_output: "同理心回应和解决方案"

View File

@@ -0,0 +1,131 @@
# 数据处理工作流模板
id: data_processing_v1
name: 数据处理工作流
description: 数据提取、转换和分析的完整流程
category: data_processing
version: "1.0.0"
author: RedBear Memory Team
tags:
- 数据处理
- ETL
- 分析
- Transform
# 工作流配置
nodes:
- id: start
type: start
name: 开始
position:
x: 100
y: 200
- id: extract_data
type: transform
name: 数据提取
config:
expression: |
{
"raw_text": var['input_text'],
"length": len(var['input_text']),
"timestamp": sys['execution_id']
}
position:
x: 300
y: 200
- id: analyze_data
type: llm
name: 数据分析
config:
prompt: |
请分析以下数据:
原始文本:{{ node.extract_data.raw_text }}
文本长度:{{ node.extract_data.length }}
请提供:
1. 主题分类
2. 情感分析
3. 关键信息提取
以 JSON 格式返回结果。
model: gpt-3.5-turbo
temperature: 0.3
max_tokens: 500
position:
x: 500
y: 200
- id: transform_result
type: transform
name: 结果转换
config:
expression: |
{
"original_length": node['extract_data']['length'],
"analysis": node['analyze_data']['output'],
"processed_at": sys['execution_id'],
"status": "completed"
}
position:
x: 700
y: 200
- id: end
type: end
name: 结束
position:
x: 900
y: 200
edges:
- source: start
target: extract_data
label: 开始提取
- source: extract_data
target: analyze_data
label: 开始分析
- source: analyze_data
target: transform_result
label: 转换结果
- source: transform_result
target: end
label: 完成
# 变量定义
variables:
- name: input_text
type: string
required: true
description: 待处理的文本数据
default: ""
# 执行配置
execution_config:
max_execution_time: 180
max_iterations: 5
# 触发器
triggers: []
# 使用示例
examples:
- name: 文本分析
description: 分析一段文本
input:
input_text: "今天天气真好,心情也很愉快。我们公司推出了新产品,市场反响热烈。"
expected_output:
original_length: 35
analysis: "主题:天气和产品,情感:积极"
status: "completed"
- name: 长文本处理
description: 处理较长的文本
input:
input_text: "这是一段很长的文本..."
expected_output:
status: "completed"

View File

@@ -0,0 +1,99 @@
# 多步骤问答工作流
# 演示节点输出参数的使用
id: multi_step_qa_v1
name: 多步骤问答工作流
description: 先分析问题,再生成答案,展示节点间的数据传递
category: advanced
version: "1.0.0"
author: RedBear Memory Team
tags:
- 问答
- 多步骤
- LLM
# 工作流配置
nodes:
- id: start
type: start
name: 开始
position:
x: 100
y: 100
- id: analyze_question
type: llm
name: 分析问题
description: 分析用户问题的类型和意图
config:
model_id: gpt-3.5-turbo
temperature: 0.3
max_tokens: 500
messages:
- role: system
content: |
你是一个问题分析专家。请分析用户的问题,提取以下信息:
1. 问题类型(事实性、观点性、操作性等)
2. 问题领域(科技、历史、文化等)
3. 关键词
- role: user
content: "{{ sys.message }}"
position:
x: 300
y: 100
- id: generate_answer
type: llm
name: 生成答案
description: 根据问题分析结果生成详细答案
config:
model_id: gpt-3.5-turbo
temperature: 0.7
max_tokens: 1000
messages:
- role: system
content: |
你是一个专业的AI助手。根据问题分析结果生成准确、详细的答案。
问题分析结果:
{{ analyze_question.output }}
- role: user
content: "{{ sys.message }}"
position:
x: 500
y: 100
- id: end
type: end
name: 结束
config:
output: "{{ generate_answer.output }}"
position:
x: 700
y: 100
edges:
- source: start
target: analyze_question
label: 开始分析
- source: analyze_question
target: generate_answer
label: 生成答案
- source: generate_answer
target: end
label: 完成
# 变量定义
variables:
- name: user_question
type: string
required: true
description: 用户的问题
default: ""
# 执行配置
execution_config:
max_execution_time: 120
max_iterations: 1

View File

@@ -0,0 +1,100 @@
# 简单问答工作流模板
id: simple_qa_v1
name: 简单问答工作流
description: 最基础的问答工作流,适合快速开始
category: basic
version: "1.0.0"
author: RedBear Memory Team
tags:
- 问答
- 基础
- LLM
# 工作流配置
nodes:
- id: start
type: start
name: 开始
position:
x: 100
y: 100
- id: llm_qa
type: llm
name: LLM 问答
config:
# 使用 LangChain 标准的消息格式
messages:
- role: system
content: |
你是一个专业、友好且乐于助人的 AI 助手。
你的职责:
- 准确理解用户的问题并提供有价值的回答
- 保持回答的专业性和准确性
- 如果不确定答案,诚实地告知用户
- 使用清晰、易懂的语言进行交流
回答风格:
- 简洁明了,直击要点
- 必要时提供详细解释和示例
- 使用友好、礼貌的语气
- 适当使用格式化(如列表、段落)提高可读性
- role: user
content: "{{ sys.message }}"
model_id: gpt-3.5-turbo
temperature: 0.7
max_tokens: 1000
position:
x: 300
y: 100
- id: end
type: end
name: 结束
config:
output: "{{ llm_qa.output }}"
position:
x: 500
y: 100
edges:
- source: start
target: llm_qa
label: 开始处理
- source: llm_qa
target: end
label: 完成
# 变量定义
variables:
- name: user_question
type: string
required: true
description: 用户的问题
default: ""
# 执行配置
execution_config:
max_execution_time: 60
max_iterations: 1
# 触发器(可选)
triggers: []
# 使用示例
examples:
- name: 基础问答
description: 询问一个简单的问题
input:
user_question: "什么是人工智能?"
expected_output: "关于人工智能的解释"
- name: 技术咨询
description: 询问技术问题
input:
user_question: "如何学习 Python 编程?"
expected_output: "Python 学习建议"