diff --git a/.gitignore b/.gitignore index b6c55867..939ac091 100644 --- a/.gitignore +++ b/.gitignore @@ -20,7 +20,8 @@ examples/ .idea # Temporary outputs -**/.DS_Store +app/core/memory/agent/.DS_Store +app/core/memory/src/utils/.DS_Store time.log celerybeat-schedule.db search_results.json diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index 951f2d73..e2295ce3 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -27,6 +27,7 @@ from . import ( release_share_controller, public_share_controller, multi_agent_controller, + workflow_controller, ) # 创建管理端 API 路由器 @@ -56,5 +57,6 @@ manager_router.include_router(release_share_controller.router) manager_router.include_router(public_share_controller.router) # 公开路由(无需认证) manager_router.include_router(memory_dashboard_controller.router) manager_router.include_router(multi_agent_controller.router) +manager_router.include_router(workflow_controller.router) __all__ = ["manager_router"] diff --git a/api/app/controllers/api_key_controller.py b/api/app/controllers/api_key_controller.py index 815d8c69..7617915b 100644 --- a/api/app/controllers/api_key_controller.py +++ b/api/app/controllers/api_key_controller.py @@ -1,7 +1,6 @@ """API Key 管理接口 - 基于 JWT 认证""" import uuid from typing import Optional -from datetime import datetime from fastapi import APIRouter, Depends, Query from sqlalchemy.orm import Session @@ -14,6 +13,7 @@ from app.core.response_utils import success from app.schemas import api_key_schema from app.schemas.response_schema import ApiResponse from app.services.api_key_service import ApiKeyService +from app.core.api_key_utils import timestamp_to_datetime from app.core.logging_config import get_api_logger from app.core.exceptions import ( BusinessException, @@ -41,18 +41,14 @@ def create_api_key( workspace_id = current_user.current_workspace_id # 创建 API Key - api_key_obj, api_key = ApiKeyService.create_api_key( + api_key_obj = ApiKeyService.create_api_key( db, workspace_id=workspace_id, user_id=current_user.id, data=data ) - # 返回包含明文 Key 的响应 - response_data = api_key_schema.ApiKeyResponse( - **api_key_obj.__dict__, - api_key=api_key - ) + response_data = api_key_schema.ApiKeyResponse.model_validate(api_key_obj) return success(data=response_data, msg="API Key 创建成功") except BusinessException: @@ -223,13 +219,9 @@ def regenerate_api_key( """ try: workspace_id = current_user.current_workspace_id - api_key_obj, api_key = ApiKeyService.regenerate_api_key(db, api_key_id, workspace_id) + api_key_obj = ApiKeyService.regenerate_api_key(db, api_key_id, workspace_id) - # 返回包含明文 Key 的响应 - response_data = api_key_schema.ApiKeyResponse( - **api_key_obj.__dict__, - api_key=api_key - ) + response_data = api_key_schema.ApiKeyResponse.model_validate(api_key_obj) logger.info("API Key 重新生成成功", extra={ "api_key_id": str(api_key_id), @@ -283,8 +275,8 @@ def get_api_key_stats( @cur_workspace_access_guard() def get_api_key_logs( api_key_id: uuid.UUID, - start_date: Optional[datetime] = Query(None, description="开始日期"), - end_date: Optional[datetime] = Query(None, description="结束日期"), + start_date: Optional[int] = Query(None, description="开始日期时间戳"), + end_date: Optional[int] = Query(None, description="结束日期时间戳"), status_code: Optional[int] = Query(None, description="HTTP状态码过滤"), endpoint: Optional[str] = Query(None, description="端点路径过滤"), page: int = Query(1, ge=1, description="页码"), @@ -302,14 +294,17 @@ def get_api_key_logs( try: workspace_id = current_user.current_workspace_id + start_datetime = timestamp_to_datetime(start_date) if start_date else None + end_datetime = timestamp_to_datetime(end_date) if end_date else None + # 验证日期范围 - if start_date and end_date and start_date > end_date: + if start_datetime and end_datetime and start_datetime > end_datetime: logger.warning("开始日期晚于结束日期", extra={ "api_key_id": str(api_key_id), "workspace_id": str(workspace_id), "user_id": str(current_user.id), - "start_date": start_date.isoformat(), - "end_date": end_date.isoformat() + "start_date": start_datetime.isoformat(), + "end_date": end_datetime.isoformat() }) raise BusinessException("开始日期不能晚于结束日期", BizCode.INVALID_PARAMETER) @@ -325,8 +320,8 @@ def get_api_key_logs( # 构建过滤条件 filters = { - "start_date": start_date, - "end_date": end_date, + "start_date": start_datetime, + "end_date": end_datetime, "status_code": status_code, "endpoint": endpoint } diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 8177916e..3d09f5fc 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -1,22 +1,26 @@ import uuid -from typing import Optional -from fastapi import APIRouter, Depends +from typing import Optional, Annotated + +from fastapi import APIRouter, Depends, Path +from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session -from app.db import get_db -from app.core.response_utils import success +from app.core.error_codes import BizCode from app.core.logging_config import get_business_logger +from app.core.response_utils import success +from app.db import get_db +from app.dependencies import get_current_user, cur_workspace_access_guard from app.models import User +from app.models.app_model import AppType, App from app.repositories import knowledge_repository from app.schemas import app_schema from app.schemas.response_schema import PageData, PageMeta +from app.schemas.workflow_schema import WorkflowConfigUpdate from app.services import app_service, workspace_service -from app.services.app_service import AppService from app.services.agent_config_helper import enrich_agent_config -from app.dependencies import get_current_user, cur_workspace_access_guard, workspace_access_guard -from fastapi.responses import StreamingResponse -from app.models.app_model import AppType -from app.core.error_codes import BizCode +from app.services.app_service import AppService +from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema +from app.services.workflow_service import WorkflowService, get_workflow_service router = APIRouter(prefix="/apps", tags=["Apps"]) logger = get_business_logger() @@ -48,7 +52,7 @@ def list_apps( current_user=Depends(get_current_user), ): """列出应用 - + - 默认包含本工作空间的应用和分享给本工作空间的应用 - 设置 include_shared=false 可以只查看本工作空间的应用 """ @@ -63,8 +67,8 @@ def list_apps( include_shared=include_shared, page=page, pagesize=pagesize, - ) - + ) + # 使用 AppService 的转换方法来设置 is_shared 字段 service = app_service.AppService(db) items = [service._convert_to_schema(app, workspace_id) for app in items_orm] @@ -79,14 +83,14 @@ def get_app( current_user=Depends(get_current_user), ): """获取应用详细信息 - + - 支持获取本工作空间的应用 - 支持获取分享给本工作空间的应用 """ workspace_id = current_user.current_workspace_id service = app_service.AppService(db) app = service.get_app(app_id, workspace_id) - + # 转换为 Schema 并设置 is_shared 字段 app_schema_obj = service._convert_to_schema(app, workspace_id) return success(data=app_schema_obj) @@ -113,7 +117,7 @@ def delete_app( current_user=Depends(get_current_user), ): """删除应用 - + 会级联删除: - Agent 配置 - 发布版本 @@ -128,9 +132,9 @@ def delete_app( "workspace_id": str(workspace_id) } ) - + app_service.delete_app(db, app_id=app_id, workspace_id=workspace_id) - + return success(msg="应用删除成功") @@ -143,7 +147,7 @@ def copy_app( current_user=Depends(get_current_user), ): """复制应用(包括基础信息和配置) - + - 复制应用的基础信息(名称、描述、图标等) - 复制 Agent 配置(如果是 agent 类型) - 新应用默认为草稿状态 @@ -159,7 +163,7 @@ def copy_app( "new_name": new_name } ) - + service = AppService(db) new_app = service.copy_app( app_id=app_id, @@ -167,7 +171,7 @@ def copy_app( workspace_id=workspace_id, new_name=new_name ) - + return success(data=app_schema.App.model_validate(new_app), msg="应用复制成功") @@ -209,9 +213,9 @@ def publish_app( ): workspace_id = current_user.current_workspace_id release = app_service.publish( - db, - app_id=app_id, - publisher_id=current_user.id, + db, + app_id=app_id, + publisher_id=current_user.id, workspace_id=workspace_id, version_name = payload.version_name, release_notes=payload.release_notes @@ -268,13 +272,13 @@ def share_app( current_user=Depends(get_current_user), ): """分享应用到其他工作空间 - + - 只能分享自己工作空间的应用 - 不能分享到自己的工作空间 - 同一个应用不能重复分享到同一个工作空间 """ workspace_id = current_user.current_workspace_id - + service = app_service.AppService(db) shares = service.share_app( app_id=app_id, @@ -282,7 +286,7 @@ def share_app( user_id=current_user.id, workspace_id=workspace_id ) - + data = [app_schema.AppShare.model_validate(s) for s in shares] return success(data=data, msg=f"应用已分享到 {len(shares)} 个工作空间") @@ -296,18 +300,18 @@ def unshare_app( current_user=Depends(get_current_user), ): """取消应用分享 - + - 只能取消自己工作空间应用的分享 """ workspace_id = current_user.current_workspace_id - + service = app_service.AppService(db) service.unshare_app( app_id=app_id, target_workspace_id=target_workspace_id, workspace_id=workspace_id ) - + return success(msg="应用分享已取消") @@ -319,17 +323,17 @@ def list_app_shares( current_user=Depends(get_current_user), ): """列出应用的所有分享记录 - + - 只能查看自己工作空间应用的分享记录 """ workspace_id = current_user.current_workspace_id - + service = app_service.AppService(db) shares = service.list_app_shares( app_id=app_id, workspace_id=workspace_id ) - + data = [app_schema.AppShare.model_validate(s) for s in shares] return success(data=data) @@ -340,10 +344,11 @@ async def draft_run( payload: app_schema.DraftRunRequest, db: Session = Depends(get_db), current_user=Depends(get_current_user), + workflow_service: Annotated[WorkflowService, Depends(get_workflow_service)] = None ): """ 试运行 Agent,使用当前的草稿配置(未发布的配置) - + - 不需要发布应用即可测试 - 使用当前的 AgentConfig 配置 - 支持流式和非流式返回 @@ -367,33 +372,44 @@ async def draft_run( ) if knowledge: user_rag_memory_id = str(knowledge.id) - + # 提前验证和准备(在流式响应开始前完成) from app.services.app_service import AppService from app.services.multi_agent_service import MultiAgentService from app.models import AgentConfig, ModelConfig from sqlalchemy import select from app.core.exceptions import BusinessException - - + from app.services.draft_run_service import DraftRunService + service = AppService(db) - + draft_service = DraftRunService(db) + # 1. 验证应用 app = service._get_app_or_404(app_id) - if app.type != AppType.AGENT and app.type != AppType.MULTI_AGENT: - raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - + if app.type != AppType.AGENT and app.type != AppType.MULTI_AGENT and app.type != AppType.WORKFLOW: + raise BusinessException("只有 Agent , Workflow 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) + # 只读操作,允许访问共享应用 service._validate_app_accessible(app, workspace_id) + + # 处理会话ID(创建或验证) + conversation_id = await draft_service._ensure_conversation( + conversation_id=payload.conversation_id, + app_id=app_id, + workspace_id=workspace_id, + user_id=payload.user_id + ) + payload.conversation_id = conversation_id + if app.type == AppType.AGENT: service._check_agent_config(app_id) - + # 2. 获取 Agent 配置 stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) agent_cfg = db.scalars(stmt).first() if not agent_cfg: raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) - + # 3. 获取模型配置 model_config = None if agent_cfg.default_model_config_id: @@ -401,12 +417,12 @@ async def draft_run( if not model_config: from app.core.exceptions import ResourceNotFoundException raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id)) - + # 流式返回 if payload.stream: async def event_generator(): - from app.services.draft_run_service import DraftRunService - draft_service = DraftRunService(db) + + async for event in draft_service.run_stream( agent_config=agent_cfg, model_config=model_config, @@ -419,7 +435,7 @@ async def draft_run( user_rag_memory_id=user_rag_memory_id ): yield event - + return StreamingResponse( event_generator(), media_type="text/event-stream", @@ -429,7 +445,7 @@ async def draft_run( "X-Accel-Buffering": "no" } ) - + # 非流式返回 logger.debug( "开始非流式试运行", @@ -440,7 +456,7 @@ async def draft_run( "has_variables": bool(payload.variables) } ) - + from app.services.draft_run_service import DraftRunService draft_service = DraftRunService(db) result = await draft_service.run( @@ -454,7 +470,7 @@ async def draft_run( storage_type=storage_type, user_rag_memory_id=user_rag_memory_id ) - + logger.debug( "试运行返回结果", extra={ @@ -462,7 +478,7 @@ async def draft_run( "result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict" } ) - + # 验证结果 try: validated_result = app_schema.DraftRunResponse.model_validate(result) @@ -481,10 +497,10 @@ async def draft_run( elif app.type == AppType.MULTI_AGENT: # 1. 检查多智能体配置完整性 service._check_multi_agent_config(app_id) - + # 2. 构建多智能体运行请求 from app.schemas.multi_agent_schema import MultiAgentRunRequest - + multi_agent_request = MultiAgentRunRequest( message=payload.message, conversation_id=payload.conversation_id, @@ -492,7 +508,7 @@ async def draft_run( variables=payload.variables or {}, use_llm_routing=True # 默认启用 LLM 路由 ) - + # 3. 流式返回 if payload.stream: logger.debug( @@ -503,11 +519,11 @@ async def draft_run( "has_conversation_id": bool(payload.conversation_id) } ) - + async def event_generator(): """多智能体流式事件生成器""" multiservice = MultiAgentService(db) - + # 调用多智能体服务的流式方法 async for event in multiservice.run_stream( app_id=app_id, @@ -517,7 +533,7 @@ async def draft_run( ): yield event - + return StreamingResponse( event_generator(), media_type="text/event-stream", @@ -527,7 +543,7 @@ async def draft_run( "X-Accel-Buffering": "no" } ) - + # 4. 非流式返回 logger.debug( "开始多智能体非流式试运行", @@ -537,10 +553,10 @@ async def draft_run( "has_conversation_id": bool(payload.conversation_id) } ) - + multiservice = MultiAgentService(db) result = await multiservice.run(app_id, multi_agent_request) - + logger.debug( "多智能体试运行返回结果", extra={ @@ -548,12 +564,71 @@ async def draft_run( "has_response": "response" in result if isinstance(result, dict) else False } ) - + return success( data=result, msg="多 Agent 任务执行成功" ) - + elif app.type == AppType.WORKFLOW: #工作流 + config = workflow_service.check_config(app_id) + # 3. 流式返回 + if payload.stream: + logger.debug( + "开始多智能体流式试运行", + extra={ + "app_id": str(app_id), + "message_length": len(payload.message), + "has_conversation_id": bool(payload.conversation_id) + } + ) + + async def event_generator(): + """多智能体流式事件生成器""" + multiservice = MultiAgentService(db) + + # 调用多智能体服务的流式方法 + async for event in multiservice.run_stream( + app_id=app_id, + request=multi_agent_request, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id + + ): + yield event + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no" + } + ) + + # 4. 非流式返回 + logger.debug( + "开始非流式试运行", + extra={ + "app_id": str(app_id), + "message_length": len(payload.message), + "has_conversation_id": bool(payload.conversation_id) + } + ) + + result = await workflow_service.run(app_id, payload,config) + + logger.debug( + "工作流试运行返回结果", + extra={ + "result_type": str(type(result)), + "has_response": "response" in result if isinstance(result, dict) else False + } + ) + return success( + data=result, + msg="工作流任务执行成功" + ) @@ -567,21 +642,21 @@ async def draft_run_compare( ): """ 多模型对比试运行 - + - 支持对比 1-5 个模型 - 可以是不同的模型,也可以是同一模型的不同参数配置 - 通过 model_parameters 覆盖默认参数 - 支持并行或串行执行(非流式) - 支持流式返回(串行执行) - 返回每个模型的运行结果和性能对比 - + 使用场景: 1. 对比不同模型的效果(GPT-4 vs Claude vs Gemini) 2. 调优模型参数(不同 temperature 的效果对比) 3. 性能和成本分析 """ workspace_id = current_user.current_workspace_id - + # 获取 storage_type,如果为 None 则使用默认值 storage_type = workspace_service.get_workspace_storage_type( db=db, @@ -597,7 +672,7 @@ async def draft_run_compare( workspace_id=workspace_id ) if knowledge: user_rag_memory_id = str(knowledge.id) - + logger.info( "多模型对比试运行", extra={ @@ -607,13 +682,13 @@ async def draft_run_compare( "stream": payload.stream } ) - + # 提前验证和准备(在流式响应开始前完成) from app.services.app_service import AppService from app.models import ModelConfig - + service = AppService(db) - + # 1. 验证应用和权限 app = service._get_app_or_404(app_id) if app.type != "agent": @@ -621,7 +696,7 @@ async def draft_run_compare( from app.core.error_codes import BizCode raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) service._validate_app_accessible(app, workspace_id) - + # 2. 获取 Agent 配置 from sqlalchemy import select from app.models import AgentConfig @@ -631,7 +706,7 @@ async def draft_run_compare( from app.core.exceptions import BusinessException from app.core.error_codes import BizCode raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) - + # 3. 验证所有模型配置 model_configs = [] for model_item in payload.models: @@ -639,12 +714,12 @@ async def draft_run_compare( if not model_config: from app.core.exceptions import ResourceNotFoundException raise ResourceNotFoundException("模型配置", str(model_item.model_config_id)) - + merged_parameters = { **(agent_cfg.model_parameters or {}), **(model_item.model_parameters or {}) } - + model_configs.append({ "model_config": model_config, "parameters": merged_parameters, @@ -652,7 +727,7 @@ async def draft_run_compare( "model_config_id": model_item.model_config_id, "conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id }) - + # 流式返回 if payload.stream: async def event_generator(): @@ -674,7 +749,7 @@ async def draft_run_compare( timeout=payload.timeout or 60 ): yield event - + return StreamingResponse( event_generator(), media_type="text/event-stream", @@ -684,7 +759,7 @@ async def draft_run_compare( "X-Accel-Buffering": "no" } ) - + # 非流式返回 from app.services.draft_run_service import DraftRunService draft_service = DraftRunService(db) @@ -703,7 +778,7 @@ async def draft_run_compare( parallel=payload.parallel, timeout=payload.timeout or 60 ) - + logger.info( "多模型对比完成", extra={ @@ -712,5 +787,36 @@ async def draft_run_compare( "failed": result["failed_count"] } ) - + return success(data=app_schema.DraftRunCompareResponse(**result)) + + +@router.get("/{app_id}/workflow") +@cur_workspace_access_guard() +async def get_workflow_config( + app_id: Annotated[uuid.UUID, Path(description="应用 ID")], + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)] + +): + """获取工作流配置 + + 获取应用的工作流配置详情。 + """ + workspace_id = current_user.current_workspace_id + cfg = app_service.get_workflow_config(db=db, app_id=app_id, workspace_id=workspace_id) + # 配置总是存在(不存在时返回默认模板) + return success(data=WorkflowConfigSchema.model_validate(cfg)) + +@router.put("/{app_id}/workflow", summary="更新 Workflow 配置") +@cur_workspace_access_guard() +async def update_workflow_config( + app_id: uuid.UUID, + payload: WorkflowConfigUpdate, + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)] +): + workspace_id = current_user.current_workspace_id + cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id) + return success(data=WorkflowConfigSchema.model_validate(cfg)) + diff --git a/api/app/controllers/document_controller.py b/api/app/controllers/document_controller.py index b6c688b2..39a690f9 100644 --- a/api/app/controllers/document_controller.py +++ b/api/app/controllers/document_controller.py @@ -29,10 +29,10 @@ router = APIRouter( ) -@router.get("/{kb_id}/{parent_id}/documents", response_model=ApiResponse) +@router.get("/{kb_id}/documents", response_model=ApiResponse) async def get_documents( kb_id: uuid.UUID, - parent_id: uuid.UUID, + parent_id: Optional[uuid.UUID] = Query(None, description="parent folder id when type is Folder"), page: int = Query(1, gt=0), # Default: 1, which must be greater than 0 pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"), diff --git a/api/app/controllers/service/app_api_controller.py b/api/app/controllers/service/app_api_controller.py index ec587510..1731405c 100644 --- a/api/app/controllers/service/app_api_controller.py +++ b/api/app/controllers/service/app_api_controller.py @@ -1,10 +1,13 @@ """App 服务接口 - 基于 API Key 认证""" -from fastapi import APIRouter, Depends +import uuid +from fastapi import APIRouter, Depends, Request, Body from sqlalchemy.orm import Session from app.db import get_db from app.core.response_utils import success from app.core.logging_config import get_business_logger +from app.core.api_key_auth import require_api_key +from app.schemas.api_key_schema import ApiKeyAuth router = APIRouter(prefix="/apps", tags=["V1 - App API"]) logger = get_business_logger() @@ -14,3 +17,30 @@ logger = get_business_logger() async def list_apps(): """列出可访问的应用(占位)""" return success(data=[], msg="App API - Coming Soon") + +# /v1/apps/{resource_id}/chat +@router.post("/{resource_id}/chat") +@require_api_key(scopes=["app"]) +async def chat_with_agent_demo( + resource_id: uuid.UUID, + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), + message: str = Body(..., description="聊天消息内容"), +): + """ + Agent 聊天接口demo + + scopes: 所需的权限范围列表["app", "rag", "memory"] + + Args: + resource_id: 如果是应用的apikey传的是应用id; 如果是服务的apikey传的是工作空间id + message: 请求参数 + request: 声明请求 + api_key_auth: 包含验证后的API Key 信息 + db: db_session + """ + logger.info(f"API Key Auth: {api_key_auth}") + logger.info(f"Resource ID: {resource_id}") + logger.info(f"Message: {message}") + return success(data={"received": True}, msg="消息已接收") diff --git a/api/app/controllers/workflow_controller.py b/api/app/controllers/workflow_controller.py new file mode 100644 index 00000000..9ccfa858 --- /dev/null +++ b/api/app/controllers/workflow_controller.py @@ -0,0 +1,587 @@ +""" +工作流 API 控制器 +""" + +import logging +import uuid +from typing import Annotated + +from fastapi import APIRouter, Depends, Path, Query +from sqlalchemy.orm import Session + +from app.db import get_db +from app.dependencies import get_current_user, cur_workspace_access_guard + +from app.models.user_model import User +from app.models.app_model import App +from app.services.workflow_service import WorkflowService, get_workflow_service +from app.schemas.workflow_schema import ( + WorkflowConfigCreate, + WorkflowConfigUpdate, + WorkflowConfig, + WorkflowValidationResponse, + WorkflowExecution, + WorkflowNodeExecution, + WorkflowExecutionRequest, + WorkflowExecutionResponse +) +from app.core.response_utils import success, fail +from app.core.exceptions import BusinessException +from app.core.error_codes import BizCode + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/apps", tags=["workflow"]) + + +# ==================== 工作流配置管理 ==================== + +@router.post("/{app_id}/workflow") +@cur_workspace_access_guard() +async def create_workflow_config( + app_id: Annotated[uuid.UUID, Path(description="应用 ID")], + config: WorkflowConfigCreate, + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)] +): + """创建工作流配置 + + 创建或更新应用的工作流配置。配置会进行基础验证,但允许保存不完整的配置(草稿)。 + """ + try: + # 验证应用是否存在且属于当前工作空间 + app = db.query(App).filter( + App.id == app_id, + App.workspace_id == current_user.current_workspace_id, + App.is_active == True + ).first() + + if not app: + return fail( + code=BizCode.NOT_FOUND, + msg="应用不存在或无权访问" + ) + + # 验证应用类型 + if app.type != "workflow": + return fail( + code=BizCode.INVALID_PARAMETER, + msg=f"应用类型必须为 workflow,当前为 {app.type}" + ) + + # 创建工作流配置 + workflow_config = service.create_workflow_config( + app_id=app_id, + nodes=[node.model_dump() for node in config.nodes], + edges=[edge.model_dump() for edge in config.edges], + variables=[var.model_dump() for var in config.variables], + execution_config=config.execution_config.model_dump(), + triggers=[trigger.model_dump() for trigger in config.triggers], + validate=True # 进行基础验证 + ) + + return success( + data=WorkflowConfig.model_validate(workflow_config), + msg="工作流配置创建成功" + ) + + except BusinessException as e: + logger.warning(f"创建工作流配置失败: {e.message}") + return fail(code=e.error_code, msg=e.message) + except Exception as e: + logger.error(f"创建工作流配置异常: {e}", exc_info=True) + return fail( + code=BizCode.INTERNAL_ERROR, + msg=f"创建工作流配置失败: {str(e)}" + ) + +# +# @router.get("/{app_id}/workflow") +# async def get_workflow_config( +# app_id: Annotated[uuid.UUID, Path(description="应用 ID")], +# db: Annotated[Session, Depends(get_db)], +# current_user: Annotated[User, Depends(get_current_user)] +# +# ): +# """获取工作流配置 +# +# 获取应用的工作流配置详情。 +# """ +# try: +# # 验证应用是否存在且属于当前工作空间 +# app = db.query(App).filter( +# App.id == app_id, +# App.workspace_id == current_user.current_workspace_id, +# App.is_active == True +# ).first() +# +# if not app: +# return fail( +# code=BizCode.NOT_FOUND, +# msg="应用不存在或无权访问" +# ) +# +# # 获取工作流配置 +# service = WorkflowService(db) +# workflow_config = service.get_workflow_config(app_id) +# +# if not workflow_config: +# return fail( +# code=BizCode.NOT_FOUND, +# msg="工作流配置不存在" +# ) +# +# return success( +# data=WorkflowConfig.model_validate(workflow_config) +# ) +# +# except Exception as e: +# logger.error(f"获取工作流配置异常: {e}", exc_info=True) +# return fail( +# code=BizCode.INTERNAL_ERROR, +# msg=f"获取工作流配置失败: {str(e)}" +# ) + + +# @router.put("/{app_id}/workflow") +# async def update_workflow_config( +# app_id: Annotated[uuid.UUID, Path(description="应用 ID")], +# config: WorkflowConfigUpdate, +# db: Annotated[Session, Depends(get_db)], +# current_user: Annotated[User, Depends(get_current_user)], +# service: Annotated[WorkflowService, Depends(get_workflow_service)] +# ): +# """更新工作流配置 + +# 更新应用的工作流配置。可以部分更新,未提供的字段保持不变。 +# """ +# try: +# # 验证应用是否存在且属于当前工作空间 +# app = db.query(App).filter( +# App.id == app_id, +# App.workspace_id == current_user.current_workspace_id, +# App.is_active == True +# ).first() + +# if not app: +# return fail( +# code=BizCode.NOT_FOUND, +# msg="应用不存在或无权访问" +# ) + +# # 更新工作流配置 +# workflow_config = service.update_workflow_config( +# app_id=app_id, +# nodes=[node.model_dump() for node in config.nodes] if config.nodes else None, +# edges=[edge.model_dump() for edge in config.edges] if config.edges else None, +# variables=[var.model_dump() for var in config.variables] if config.variables else None, +# execution_config=config.execution_config.model_dump() if config.execution_config else None, +# triggers=[trigger.model_dump() for trigger in config.triggers] if config.triggers else None, +# validate=True +# ) + +# return success( +# data=WorkflowConfig.model_validate(workflow_config), +# msg="工作流配置更新成功" +# ) + +# except BusinessException as e: +# logger.warning(f"更新工作流配置失败: {e.message}") +# return fail(code=e.error_code, msg=e.message) +# except Exception as e: +# logger.error(f"更新工作流配置异常: {e}", exc_info=True) +# return fail( +# code=BizCode.INTERNAL_ERROR, +# msg=f"更新工作流配置失败: {str(e)}" +# ) + + +@router.delete("/{app_id}/workflow") +async def delete_workflow_config( + app_id: Annotated[uuid.UUID, Path(description="应用 ID")], + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)] +): + """删除工作流配置 + + 删除应用的工作流配置。 + """ + try: + # 验证应用是否存在且属于当前工作空间 + app = db.query(App).filter( + App.id == app_id, + App.workspace_id == current_user.current_workspace_id, + App.is_active == True + ).first() + + if not app: + return fail( + code=BizCode.NOT_FOUND, + msg="应用不存在或无权访问" + ) + + # 删除工作流配置 + deleted = service.delete_workflow_config(app_id) + + if not deleted: + return fail( + code=BizCode.NOT_FOUND, + msg="工作流配置不存在" + ) + + return success(msg="工作流配置删除成功") + + except Exception as e: + logger.error(f"删除工作流配置异常: {e}", exc_info=True) + return fail( + code=BizCode.INTERNAL_ERROR, + msg=f"删除工作流配置失败: {str(e)}" + ) + + +@router.post("/{app_id}/workflow/validate") +async def validate_workflow_config( + app_id: Annotated[uuid.UUID, Path(description="应用 ID")], + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)], + for_publish: Annotated[bool, Query(description="是否为发布验证")] = False +): + """验证工作流配置 + + 验证工作流配置是否有效。可以选择是否进行发布级别的严格验证。 + """ + try: + # 验证应用是否存在且属于当前工作空间 + app = db.query(App).filter( + App.id == app_id, + App.workspace_id == current_user.current_workspace_id, + App.is_active == True + ).first() + + if not app: + return fail( + code=BizCode.NOT_FOUND, + msg="应用不存在或无权访问" + ) + + # 验证工作流配置 + + if for_publish: + is_valid, errors = service.validate_workflow_config_for_publish(app_id) + else: + workflow_config = service.get_workflow_config(app_id) + if not workflow_config: + return fail( + code=BizCode.NOT_FOUND, + msg="工作流配置不存在" + ) + + from app.core.workflow.validator import validate_workflow_config as validate_config + config_dict = { + "nodes": workflow_config.nodes, + "edges": workflow_config.edges, + "variables": workflow_config.variables, + "execution_config": workflow_config.execution_config, + "triggers": workflow_config.triggers + } + is_valid, errors = validate_config(config_dict, for_publish=False) + + return success( + data=WorkflowValidationResponse( + is_valid=is_valid, + errors=errors, + warnings=[] + ) + ) + + except BusinessException as e: + logger.warning(f"验证工作流配置失败: {e.message}") + return fail(code=e.error_code, msg=e.message) + except Exception as e: + logger.error(f"验证工作流配置异常: {e}", exc_info=True) + return fail( + code=BizCode.INTERNAL_ERROR, + msg=f"验证工作流配置失败: {str(e)}" + ) + + +# ==================== 工作流执行管理 ==================== + +@router.get("/{app_id}/workflow/executions") +async def get_workflow_executions( + app_id: Annotated[uuid.UUID, Path(description="应用 ID")], + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)], + limit: Annotated[int, Query(ge=1, le=100)] = 50, + offset: Annotated[int, Query(ge=0)] = 0 +): + """获取工作流执行记录列表 + + 获取应用的工作流执行历史记录。 + """ + try: + # 验证应用是否存在且属于当前工作空间 + app = db.query(App).filter( + App.id == app_id, + App.workspace_id == current_user.current_workspace_id, + App.is_active == True + ).first() + + if not app: + return fail( + code=BizCode.NOT_FOUND, + msg="应用不存在或无权访问" + ) + + # 获取执行记录 + executions = service.get_executions_by_app(app_id, limit, offset) + + # 获取统计信息 + statistics = service.get_execution_statistics(app_id) + + return success( + data={ + "executions": [WorkflowExecution.model_validate(e) for e in executions], + "statistics": statistics, + "pagination": { + "limit": limit, + "offset": offset, + "total": statistics["total"] + } + } + ) + + except Exception as e: + logger.error(f"获取工作流执行记录异常: {e}", exc_info=True) + return fail( + code=BizCode.INTERNAL_ERROR, + msg=f"获取工作流执行记录失败: {str(e)}" + ) + + +@router.get("/workflow/executions/{execution_id}") +async def get_workflow_execution( + execution_id: Annotated[str, Path(description="执行 ID")], + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)] +): + """获取工作流执行详情 + + 获取单个工作流执行的详细信息,包括所有节点的执行记录。 + """ + try: + # 获取执行记录 + execution = service.get_execution(execution_id) + + if not execution: + return fail( + code=BizCode.NOT_FOUND, + msg="执行记录不存在" + ) + + # 验证应用是否属于当前工作空间 + app = db.query(App).filter( + App.id == execution.app_id, + App.workspace_id == current_user.current_workspace_id, + App.is_active == True + ).first() + + if not app: + return fail( + code=BizCode.NOT_FOUND, + msg="无权访问该执行记录" + ) + + # 获取节点执行记录 + node_executions = service.node_execution_repo.get_by_execution_id(execution.id) + + return success( + data={ + "execution": WorkflowExecution.model_validate(execution), + "node_executions": [ + WorkflowNodeExecution.model_validate(ne) for ne in node_executions + ] + } + ) + + except Exception as e: + logger.error(f"获取工作流执行详情异常: {e}", exc_info=True) + return fail( + code=BizCode.INTERNAL_ERROR, + msg=f"获取工作流执行详情失败: {str(e)}" + ) + + + +# ==================== 工作流执行 ==================== + +@router.post("/{app_id}/workflow/run") +async def run_workflow( + app_id: Annotated[uuid.UUID, Path(description="应用 ID")], + request: WorkflowExecutionRequest, + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)] +): + """执行工作流 + + 执行工作流并返回结果。支持流式和非流式两种模式。 + + **非流式模式**:等待工作流执行完成后返回完整结果。 + + **流式模式**:实时返回执行过程中的事件(节点开始、节点完成、工作流完成等)。 + """ + try: + # 验证应用是否存在且属于当前工作空间 + app = db.query(App).filter( + App.id == app_id, + App.workspace_id == current_user.current_workspace_id, + App.is_active == True + ).first() + + if not app: + return fail( + code=BizCode.NOT_FOUND, + msg="应用不存在或无权访问" + ) + + # 验证应用类型 + if app.type != "workflow": + return fail( + code=BizCode.INVALID_PARAMETER, + msg=f"应用类型必须为 workflow,当前为 {app.type}" + ) + + # 准备输入数据 + input_data = { + "message": request.message or "", + "variables": request.variables + } + + # 执行工作流 + + if request.stream: + # 流式执行 + from fastapi.responses import StreamingResponse + import json + + async def event_generator(): + """生成 SSE 事件""" + try: + async for event in service.run_workflow( + app_id=app_id, + input_data=input_data, + triggered_by=current_user.id, + conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None, + stream=True + ): + # 转换为 SSE 格式 + yield f"data: {json.dumps(event)}\n\n" + except Exception as e: + logger.error(f"流式执行异常: {e}", exc_info=True) + error_event = { + "type": "error", + "error": str(e) + } + yield f"data: {json.dumps(error_event)}\n\n" + + return StreamingResponse( + event_generator(), + media_type="text/event-stream" + ) + else: + # 非流式执行 + result = await service.run_workflow( + app_id=app_id, + input_data=input_data, + triggered_by=current_user.id, + conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None, + stream=False + ) + + return success( + data=WorkflowExecutionResponse( + execution_id=result["execution_id"], + status=result["status"], + output=result.get("output"), + output_data=result.get("output_data"), + error_message=result.get("error_message"), + elapsed_time=result.get("elapsed_time"), + token_usage=result.get("token_usage") + ), + msg="工作流执行完成" + ) + + except BusinessException as e: + logger.warning(f"执行工作流失败: {e.message}") + return fail(code=e.error_code, msg=e.message) + except Exception as e: + logger.error(f"执行工作流异常: {e}", exc_info=True) + return fail( + code=BizCode.INTERNAL_ERROR, + msg=f"执行工作流失败: {str(e)}" + ) + + +@router.post("/workflow/executions/{execution_id}/cancel") +async def cancel_workflow_execution( + execution_id: Annotated[str, Path(description="执行 ID")], + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)], + service: Annotated[WorkflowService, Depends(get_workflow_service)] +): + """取消工作流执行 + + 取消正在运行的工作流执行。 + + **注意**:当前版本仅更新状态为 cancelled,实际的执行取消功能待实现。 + """ + try: + # 获取执行记录 + execution = service.get_execution(execution_id) + + if not execution: + return fail( + code=BizCode.NOT_FOUND, + msg="执行记录不存在" + ) + + # 验证应用是否属于当前工作空间 + app = db.query(App).filter( + App.id == execution.app_id, + App.workspace_id == current_user.current_workspace_id, + App.is_active == True + ).first() + + if not app: + return fail( + code=BizCode.NOT_FOUND, + msg="无权访问该执行记录" + ) + + # 检查执行状态 + if execution.status not in ["pending", "running"]: + return fail( + code=BizCode.INVALID_PARAMETER, + msg=f"无法取消状态为 {execution.status} 的执行" + ) + + # 更新状态为 cancelled + service.update_execution_status(execution_id, "cancelled") + + return success(msg="工作流执行已取消") + + except BusinessException as e: + logger.warning(f"取消工作流执行失败: {e.message}") + return fail(code=e.error_code, msg=e.message) + except Exception as e: + logger.error(f"取消工作流执行异常: {e}", exc_info=True) + return fail( + code=BizCode.INTERNAL_ERROR, + msg=f"取消工作流执行失败: {str(e)}" + ) diff --git a/api/app/core/api_key_auth.py b/api/app/core/api_key_auth.py index d02d2811..a5db49a7 100644 --- a/api/app/core/api_key_auth.py +++ b/api/app/core/api_key_auth.py @@ -1,10 +1,12 @@ import asyncio +import time import uuid from functools import wraps from typing import Optional, List from datetime import datetime from fastapi import Request, Response +from fastapi.responses import JSONResponse from sqlalchemy.orm import Session from app.core.api_key_utils import add_rate_limit_headers @@ -22,21 +24,17 @@ logger = get_api_logger() def require_api_key( - scopes: Optional[List[str]] = None, - resource_type: Optional[str] = None + scopes: Optional[List[str]] = None ): """ API Key 鉴权装饰器 Args: - scopes: 所需的权限范围列表["app:all", - "rag:search", "rag:upload", "rag:delete", - "memory:read", "memory:write", "memory:delete", "memory:search"] - resource_type: 所需的资源类型("Agent", "Cluster", "Workflow", "Knowledge", "Memory_Engine") + scopes: 所需的权限范围列表[“app”, "rag", "memory"] Usage: @router.get("/app/{resource_id}/chat") - @require_api_key(scopes=["app:all"], resource_type="Agent") + @require_api_key(scopes=["app"]) def chat_with_app( resource_id: uuid.UUID, api_key_auth: ApiKeyAuth = Depends(), @@ -113,31 +111,25 @@ def require_api_key( context={"required_scopes": scopes, "missing_scopes": missing_scopes} ) - if resource_type: - resource_id = kwargs.get("resource_id") - if resource_id and not ApiKeyAuthService.check_resource( - api_key_obj, - resource_type, - resource_id - ): - logger.warning("API Key 资源访问被拒绝", extra={ - "api_key_id": str(api_key_obj.id), - "required_resource_type": resource_type, + resource_id = kwargs.get("resource_id") + if resource_id and not ApiKeyAuthService.check_resource( + api_key_obj, + resource_id + ): + logger.warning("API Key 资源访问被拒绝", extra={ + "api_key_id": str(api_key_obj.id), + "required_resource_id": str(resource_id), + "bound_resource_id": str(api_key_obj.resource_id) if api_key_obj.resource_id else None, + "endpoint": str(request.url) + }) + return BusinessException( + "API Key 未授权访问该资源", + BizCode.API_KEY_INVALID_RESOURCE, + context={ "required_resource_id": str(resource_id), - "bound_resource_type": api_key_obj.resource_type, - "bound_resource_id": str(api_key_obj.resource_id) if api_key_obj.resource_id else None, - "endpoint": str(request.url) - }) - return BusinessException( - "API Key 未授权访问该资源", - BizCode.API_KEY_INVALID_RESOURCE, - context={ - "required_resource_type": resource_type, - "required_resource_id": str(resource_id), - "bound_resource_type": api_key_obj.resource_type, - "bound_resource_id": str(api_key_obj.resource_id) if api_key_obj.resource_id else None - } - ) + "bound_resource_id": str(api_key_obj.resource_id) + } + ) kwargs["api_key_auth"] = ApiKeyAuth( api_key_id=api_key_obj.id, @@ -145,14 +137,17 @@ def require_api_key( type=api_key_obj.type, scopes=api_key_obj.scopes, resource_id=api_key_obj.resource_id, - resource_type=api_key_obj.resource_type ) - + start_time = time.perf_counter() response = await func(*args, **kwargs) + end_time = time.perf_counter() + response_time = (end_time - start_time) * 1000 + if not isinstance(response, Response): + response = JSONResponse(content=response) response = add_rate_limit_headers(response, rate_headers) asyncio.create_task(log_api_key_usage( - db, api_key_obj.id, request, response + db, api_key_obj.id, request, response, response_time )) return response @@ -204,7 +199,8 @@ async def log_api_key_usage( db: Session, api_key_id: uuid.UUID, request: Request, - response: Response + response: Response, + response_time: float ): """记录 API Key 使用日志""" try: @@ -216,8 +212,8 @@ async def log_api_key_usage( "ip_address": request.client.host if request.client else None, "user_agent": request.headers.get("User-Agent"), "status_code": response.status_code if hasattr(response, "status_code") else None, - "response_time": None, # 需要在 middleware 中计算 - "tokens_used": None, # 需要从响应中提取 + "response_time": round(response_time), + "tokens_used": None, "created_at": datetime.now() } diff --git a/api/app/core/api_key_utils.py b/api/app/core/api_key_utils.py index 9ebd33e8..98ae0b10 100644 --- a/api/app/core/api_key_utils.py +++ b/api/app/core/api_key_utils.py @@ -1,33 +1,14 @@ """API Key 工具函数""" import secrets -import hashlib -from typing import Optional +from typing import Optional, Union +from datetime import datetime from app.schemas.api_key_schema import ApiKeyType from fastapi import Response from fastapi.responses import JSONResponse -class ResourceType: - """资源类型常量""" - AGENT = "Agent" - CLUSTER = "Cluster" - WORKFLOW = "Workflow" - KNOWLEDGE = "Knowledge" - MEMORY_ENGINE = "Memory_Engine" - - @classmethod - def get_all_types(cls) -> list[str]: - """获取所有支持的资源类型""" - return [cls.AGENT, cls.CLUSTER, cls.WORKFLOW, cls.KNOWLEDGE, cls.MEMORY_ENGINE] - - @classmethod - def is_valid_type(cls, resource_type: str) -> bool: - """验证资源类型是否有效""" - return resource_type in cls.get_all_types() - - -def generate_api_key(key_type: ApiKeyType) -> tuple[str, str, str]: +def generate_api_key(key_type: ApiKeyType) -> str: """ 生成 API Key @@ -39,102 +20,17 @@ def generate_api_key(key_type: ApiKeyType) -> tuple[str, str, str]: """ # 前缀映射 prefix_map = { - ApiKeyType.APP: "sk-app-", - ApiKeyType.RAG: "sk-rag-", - ApiKeyType.MEMORY: "sk-mem-", + ApiKeyType.AGENT: "sk-agent-", + ApiKeyType.CLUSTER: "sk-cluster-", + ApiKeyType.WORKFLOW: "sk-workflow-", + ApiKeyType.SERVICE: "sk-service-" } prefix = prefix_map[key_type] random_string = secrets.token_urlsafe(32)[:32] # 32 字符 api_key = f"{prefix}{random_string}" - # 生成哈希值存储 - key_hash = hash_api_key(api_key) - - return api_key, key_hash, prefix - - -def hash_api_key(api_key: str) -> str: - """对 API Key 进行哈希 - - Args: - api_key: API Key 明文 - - Returns: - str: 哈希值 - """ - return hashlib.sha256(api_key.encode()).hexdigest() - - -def verify_api_key(api_key: str, key_hash: str) -> bool: - """ - 验证 API Key - - Args: - api_key: API Key 明文 - key_hash: 存储的哈希值 - - Returns: - bool: 是否匹配 - """ - computed_hash = hash_api_key(api_key) - return secrets.compare_digest(computed_hash, key_hash) - - -def validate_resource_binding( - resource_type: Optional[str], - resource_id: Optional[str] -) -> tuple[bool, str]: - """ - 验证资源绑定的有效性 - - Args: - resource_type: 资源类型 - resource_id: 资源ID - - Returns: - tuple: (是否有效, 错误信息) - """ - # 如果都为空,表示不绑定资源,这是有效的 - if not resource_type and not resource_id: - return True, "" - - # 如果只有一个为空,这是无效的 - if not resource_type or not resource_id: - return False, "resource_type 和 resource_id 必须同时提供或同时为空" - - # 验证资源类型是否支持 - if not ResourceType.is_valid_type(resource_type): - valid_types = ", ".join(ResourceType.get_all_types()) - return False, f"不支持的资源类型 '{resource_type}',支持的类型:{valid_types}" - - return True, "" - - -def get_resource_scope_mapping() -> dict[str, list[str]]: - """ - 获取资源类型与权限范围的映射关系 - - Returns: - dict: 资源类型到推荐权限范围的映射 - """ - return { - ResourceType.AGENT: [ - "app:all" - ], - ResourceType.CLUSTER: [ - "app:all" - ], - ResourceType.WORKFLOW: [ - "app:all" - ], - ResourceType.KNOWLEDGE: [ - "rag:search", "rag:upload", "rag:delete" - ], - ResourceType.MEMORY_ENGINE: [ - "memory:read", "memory:write", "memory:delete", "memory:search" - ] - } + return api_key def add_rate_limit_headers(response, headers: dict): @@ -151,3 +47,21 @@ def add_rate_limit_headers(response, headers: dict): return response +def timestamp_to_datetime(timestamp: Optional[Union[int, float]]) -> Optional[datetime]: + """将时间戳转换为datetime对象""" + if timestamp is None: + return None + + # 处理毫秒级时间戳 + if timestamp > 1e10: + timestamp = timestamp / 1000 + + return datetime.fromtimestamp(timestamp) + + +def datetime_to_timestamp(dt: Optional[datetime]) -> Optional[int]: + """将datetime对象转换为时间戳(毫秒)""" + if dt is None: + return None + + return int(dt.timestamp() * 1000) diff --git a/api/app/core/error_codes.py b/api/app/core/error_codes.py index f1d0a1cf..6bb8ac29 100644 --- a/api/app/core/error_codes.py +++ b/api/app/core/error_codes.py @@ -59,6 +59,7 @@ class BizCode(IntEnum): EMBED_NOT_ALLOWED = 6009 PERMISSION_DENIED = 6010 INVALID_CONVERSATION = 6011 + CONFIG_MISSING = 6012 # 模型(7xxx) MODEL_CONFIG_INVALID = 7001 @@ -96,7 +97,7 @@ HTTP_MAPPING = { BizCode.TOKEN_INVALID: 401, BizCode.TOKEN_EXPIRED: 401, BizCode.TOKEN_BLACKLISTED: 401, - BizCode.FORBIDDEN: 403, + BizCode.FORBIDDEN: 403, BizCode.TENANT_NOT_FOUND: 404, BizCode.WORKSPACE_NO_ACCESS: 403, BizCode.NOT_FOUND: 404, @@ -151,4 +152,4 @@ HTTP_MAPPING = { BizCode.DB_ERROR: 500, BizCode.SERVICE_UNAVAILABLE: 503, BizCode.RATE_LIMITED: 429, -} \ No newline at end of file +} diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py new file mode 100644 index 00000000..a945356a --- /dev/null +++ b/api/app/core/workflow/executor.py @@ -0,0 +1,436 @@ +""" +工作流执行器 + +基于 LangGraph 的工作流执行引擎。 +""" + +import logging +import uuid +import datetime +from typing import Any + +from langchain_core.messages import HumanMessage +from langgraph.graph import StateGraph, START, END + +from app.core.workflow.nodes import WorkflowState, NodeFactory +from app.core.workflow.expression_evaluator import evaluate_condition +from app.models.workflow_model import WorkflowExecution, WorkflowNodeExecution +from app.db import get_db + +logger = logging.getLogger(__name__) + + +class WorkflowExecutor: + """工作流执行器 + + 负责将工作流配置转换为 LangGraph 并执行。 + """ + + def __init__( + self, + workflow_config: dict[str, Any], + execution_id: str, + workspace_id: str, + user_id: str + ): + """初始化执行器 + + Args: + workflow_config: 工作流配置 + execution_id: 执行 ID + workspace_id: 工作空间 ID + user_id: 用户 ID + """ + self.workflow_config = workflow_config + self.execution_id = execution_id + self.workspace_id = workspace_id + self.user_id = user_id + self.nodes = workflow_config.get("nodes", []) + self.edges = workflow_config.get("edges", []) + self.execution_config = workflow_config.get("execution_config", {}) + + def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState: + """准备初始状态(注入系统变量和会话变量) + + 变量命名空间: + - sys.xxx - 系统变量(execution_id, workspace_id, user_id, message, input_variables 等) + - conv.xxx - 会话变量(跨多轮对话保持) + - node_id.xxx - 节点输出(执行时动态生成) + + Args: + input_data: 输入数据 + + Returns: + 初始化的工作流状态 + """ + user_message = input_data.get("message") or "" + conversation_vars = input_data.get("conversation_vars") or {} + input_variables = input_data.get("variables") or {} # Start 节点的自定义变量 + + # 构建分层的变量结构 + variables = { + "sys": { + "message": user_message, # 用户消息 + "conversation_id": input_data.get("conversation_id"), # 会话 ID + "execution_id": self.execution_id, # 执行 ID + "workspace_id": self.workspace_id, # 工作空间 ID + "user_id": self.user_id, # 用户 ID + "input_variables": input_variables, # 自定义输入变量(给 Start 节点使用) + }, + "conv": conversation_vars # 会话级变量(跨多轮对话保持) + } + + return { + "messages": [HumanMessage(content=user_message)], + "variables": variables, + "node_outputs": {}, + "runtime_vars": {}, # 运行时节点变量(简化版,供快速访问) + "execution_id": self.execution_id, + "workspace_id": self.workspace_id, + "user_id": self.user_id, + "error": None, + "error_node": None + } + + + + def build_graph(self) -> StateGraph: + """构建 LangGraph + + Returns: + 编译后的状态图 + """ + logger.info(f"开始构建工作流图: execution_id={self.execution_id}") + + # 1. 创建状态图 + workflow = StateGraph(WorkflowState) + + # 2. 添加所有节点(包括 start 和 end) + start_node_id = None + end_node_ids = [] + + for node in self.nodes: + node_type = node.get("type") + node_id = node.get("id") + + # 记录 start 和 end 节点 ID + if node_type == "start": + start_node_id = node_id + elif node_type == "end": + end_node_ids.append(node_id) + + # 创建节点实例(现在 start 和 end 也会被创建) + node_instance = NodeFactory.create_node(node, self.workflow_config) + if node_instance: + # 包装节点的 run 方法 + # 使用函数工厂避免闭包问题 + def make_node_func(inst): + async def node_func(state: WorkflowState): + return await inst.run(state) + return node_func + + workflow.add_node(node_id, make_node_func(node_instance)) + logger.debug(f"添加节点: {node_id} (type={node_type})") + + # 3. 添加边 + # 从 START 连接到 start 节点 + if start_node_id: + workflow.add_edge(START, start_node_id) + logger.debug(f"添加边: START -> {start_node_id}") + + for edge in self.edges: + source = edge.get("source") + target = edge.get("target") + edge_type = edge.get("type") + condition = edge.get("condition") + + # 跳过从 start 节点出发的边(因为已经从 START 连接到 start) + if source == start_node_id: + # 但要连接 start 到下一个节点 + workflow.add_edge(source, target) + logger.debug(f"添加边: {source} -> {target}") + continue + + # 处理到 end 节点的边 + if target in end_node_ids: + # 连接到 end 节点 + workflow.add_edge(source, target) + logger.debug(f"添加边: {source} -> {target}") + continue + + # 跳过错误边(在节点内部处理) + if edge_type == "error": + continue + + if condition: + # 条件边 + def router(state: WorkflowState, cond=condition, tgt=target): + """条件路由函数""" + if evaluate_condition( + cond, + state.get("variables", {}), + state.get("node_outputs", {}), + { + "execution_id": state.get("execution_id"), + "workspace_id": state.get("workspace_id"), + "user_id": state.get("user_id") + } + ): + return tgt + return END # 条件不满足,结束 + + workflow.add_conditional_edges(source, router) + logger.debug(f"添加条件边: {source} -> {target} (condition={condition})") + else: + # 普通边 + workflow.add_edge(source, target) + logger.debug(f"添加边: {source} -> {target}") + + # 从 end 节点连接到 END + for end_node_id in end_node_ids: + workflow.add_edge(end_node_id, END) + logger.debug(f"添加边: {end_node_id} -> END") + + # 4. 编译图 + graph = workflow.compile() + logger.info(f"工作流图构建完成: execution_id={self.execution_id}") + + return graph + + async def execute( + self, + input_data: dict[str, Any] + ) -> dict[str, Any]: + """执行工作流(非流式) + + Args: + input_data: 输入数据,包含 message 和 variables + + Returns: + 执行结果,包含 status, output, node_outputs, elapsed_time, token_usage + """ + logger.info(f"开始执行工作流: execution_id={self.execution_id}") + + # 记录开始时间 + start_time = datetime.datetime.now() + + # 1. 构建图 + graph = self.build_graph() + + # 2. 初始化状态(自动注入系统变量) + initial_state = self._prepare_initial_state(input_data) + + # 3. 执行工作流 + try: + result = await graph.ainvoke(initial_state) + + # 计算耗时 + end_time = datetime.datetime.now() + elapsed_time = (end_time - start_time).total_seconds() + + # 提取节点输出(现在包含 start 和 end 节点) + node_outputs = result.get("node_outputs", {}) + + # 提取最终输出(从最后一个非 start/end 节点) + final_output = self._extract_final_output(node_outputs) + + # 聚合 token 使用情况 + token_usage = self._aggregate_token_usage(node_outputs) + + # 提取 conversation_id(从 start 节点输出) + conversation_id = None + for node_id, node_output in node_outputs.items(): + if node_output.get("node_type") == "start": + conversation_id = node_output.get("output", {}).get("conversation_id") + break + + logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s") + + return { + "status": "completed", + "output": final_output, + "node_outputs": node_outputs, + "messages": result.get("messages", []), + "conversation_id": conversation_id, + "elapsed_time": elapsed_time, + "token_usage": token_usage, + "error": result.get("error") + } + + except Exception as e: + # 计算耗时(即使失败也记录) + end_time = datetime.datetime.now() + elapsed_time = (end_time - start_time).total_seconds() + + logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True) + return { + "status": "failed", + "error": str(e), + "output": None, + "node_outputs": {}, + "elapsed_time": elapsed_time, + "token_usage": None + } + + async def execute_stream( + self, + input_data: dict[str, Any] + ): + """执行工作流(流式) + + Args: + input_data: 输入数据 + + Yields: + 流式事件 + """ + logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}") + + # 1. 构建图 + graph = self.build_graph() + + # 2. 初始化状态(自动注入系统变量) + initial_state = self._prepare_initial_state(input_data) + + # 3. 流式执行工作流 + try: + # 使用 astream 获取节点级别的更新 + async for event in graph.astream(initial_state, stream_mode="updates"): + for node_name, state_update in event.items(): + yield { + "type": "node_complete", + "node": node_name, + "data": state_update, + "execution_id": self.execution_id + } + + logger.info(f"工作流执行完成(流式): execution_id={self.execution_id}") + + # 发送完成事件 + yield { + "type": "workflow_complete", + "execution_id": self.execution_id + } + + except Exception as e: + logger.error(f"工作流执行失败(流式): execution_id={self.execution_id}, error={e}", exc_info=True) + yield { + "type": "workflow_error", + "execution_id": self.execution_id, + "error": str(e) + } + + + def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None: + """从节点输出中提取最终输出 + + 优先级: + 1. 最后一个执行的非 start/end 节点的 output + 2. 如果没有节点输出,返回 None + + Args: + node_outputs: 所有节点的输出 + + Returns: + 最终输出字符串或 None + """ + if not node_outputs: + return None + + # 获取最后一个节点的输出 + last_node_output = list(node_outputs.values())[-1] if node_outputs else None + + if last_node_output and isinstance(last_node_output, dict): + return last_node_output.get("output") + + return None + + def _aggregate_token_usage(self, node_outputs: dict[str, Any]) -> dict[str, int] | None: + """聚合所有节点的 token 使用情况 + + Args: + node_outputs: 所有节点的输出 + + Returns: + 聚合的 token 使用情况 {"prompt_tokens": x, "completion_tokens": y, "total_tokens": z} + 如果没有 token 使用信息,返回 None + """ + total_prompt_tokens = 0 + total_completion_tokens = 0 + total_tokens = 0 + has_token_info = False + + for node_output in node_outputs.values(): + if isinstance(node_output, dict): + token_usage = node_output.get("token_usage") + if token_usage and isinstance(token_usage, dict): + has_token_info = True + total_prompt_tokens += token_usage.get("prompt_tokens", 0) + total_completion_tokens += token_usage.get("completion_tokens", 0) + total_tokens += token_usage.get("total_tokens", 0) + + if not has_token_info: + return None + + return { + "prompt_tokens": total_prompt_tokens, + "completion_tokens": total_completion_tokens, + "total_tokens": total_tokens + } + + +async def execute_workflow( + workflow_config: dict[str, Any], + input_data: dict[str, Any], + execution_id: str, + workspace_id: str, + user_id: str +) -> dict[str, Any]: + """执行工作流(便捷函数) + + Args: + workflow_config: 工作流配置 + input_data: 输入数据 + execution_id: 执行 ID + workspace_id: 工作空间 ID + user_id: 用户 ID + + Returns: + 执行结果 + """ + executor = WorkflowExecutor( + workflow_config=workflow_config, + execution_id=execution_id, + workspace_id=workspace_id, + user_id=user_id + ) + return await executor.execute(input_data) + + +async def execute_workflow_stream( + workflow_config: dict[str, Any], + input_data: dict[str, Any], + execution_id: str, + workspace_id: str, + user_id: str +): + """执行工作流(流式,便捷函数) + + Args: + workflow_config: 工作流配置 + input_data: 输入数据 + execution_id: 执行 ID + workspace_id: 工作空间 ID + user_id: 用户 ID + + Yields: + 流式事件 + """ + executor = WorkflowExecutor( + workflow_config=workflow_config, + execution_id=execution_id, + workspace_id=workspace_id, + user_id=user_id + ) + async for event in executor.execute_stream(input_data): + yield event diff --git a/api/app/core/workflow/expression_evaluator.py b/api/app/core/workflow/expression_evaluator.py new file mode 100644 index 00000000..c8875d79 --- /dev/null +++ b/api/app/core/workflow/expression_evaluator.py @@ -0,0 +1,195 @@ +""" +安全的表达式求值器 + +使用 simpleeval 库提供安全的表达式评估,避免代码注入攻击。 +""" + +import logging +from typing import Any + +from simpleeval import simple_eval, NameNotDefined, InvalidExpression + +logger = logging.getLogger(__name__) + + +class ExpressionEvaluator: + """安全的表达式求值器""" + + # 保留的命名空间 + RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"} + + @staticmethod + def evaluate( + expression: str, + variables: dict[str, Any], + node_outputs: dict[str, Any], + system_vars: dict[str, Any] | None = None + ) -> Any: + """安全地评估表达式 + + Args: + expression: 表达式字符串,如 "{{var.score}} > 0.8" + variables: 用户定义的变量 + node_outputs: 节点输出结果 + system_vars: 系统变量 + + Returns: + 表达式求值结果 + + Raises: + ValueError: 表达式无效或求值失败 + + Examples: + >>> evaluator = ExpressionEvaluator() + >>> evaluator.evaluate( + ... "var.score > 0.8", + ... {"score": 0.9}, + ... {}, + ... {} + ... ) + True + + >>> evaluator.evaluate( + ... "node.intent.output == '售前咨询'", + ... {}, + ... {"intent": {"output": "售前咨询"}}, + ... {} + ... ) + True + """ + # 移除 Jinja2 模板语法的花括号(如果存在) + expression = expression.strip() + if expression.startswith("{{") and expression.endswith("}}"): + expression = expression[2:-2].strip() + + # 构建命名空间上下文 + context = { + "var": variables, # 用户变量 + "node": node_outputs, # 节点输出 + "sys": system_vars or {}, # 系统变量 + } + + # 为了向后兼容,也支持直接访问(但会在日志中警告) + context.update(variables) + context["nodes"] = node_outputs + + try: + # simpleeval 只支持安全的操作: + # - 算术运算: +, -, *, /, //, %, ** + # - 比较运算: ==, !=, <, <=, >, >= + # - 逻辑运算: and, or, not + # - 成员运算: in, not in + # - 属性访问: obj.attr + # - 字典/列表访问: obj["key"], obj[0] + # 不支持:函数调用、导入、赋值等危险操作 + result = simple_eval(expression, names=context) + return result + + except NameNotDefined as e: + logger.error(f"表达式中引用了未定义的变量: {expression}, 错误: {e}") + raise ValueError(f"未定义的变量: {e}") + + except InvalidExpression as e: + logger.error(f"表达式语法无效: {expression}, 错误: {e}") + raise ValueError(f"表达式语法无效: {e}") + + except SyntaxError as e: + logger.error(f"表达式语法错误: {expression}, 错误: {e}") + raise ValueError(f"表达式语法错误: {e}") + + except Exception as e: + logger.error(f"表达式求值异常: {expression}, 错误: {e}") + raise ValueError(f"表达式求值失败: {e}") + + @staticmethod + def evaluate_bool( + expression: str, + variables: dict[str, Any], + node_outputs: dict[str, Any], + system_vars: dict[str, Any] | None = None + ) -> bool: + """评估布尔表达式(用于条件判断) + + Args: + expression: 布尔表达式 + variables: 用户变量 + node_outputs: 节点输出 + system_vars: 系统变量 + + Returns: + 布尔值结果 + + Examples: + >>> ExpressionEvaluator.evaluate_bool( + ... "var.count >= 10 and var.status == 'active'", + ... {"count": 15, "status": "active"}, + ... {}, + ... {} + ... ) + True + """ + result = ExpressionEvaluator.evaluate( + expression, variables, node_outputs, system_vars + ) + return bool(result) + + @staticmethod + def validate_variable_names(variables: list[dict]) -> list[str]: + """验证变量名是否合法 + + Args: + variables: 变量定义列表 + + Returns: + 错误列表,如果为空则验证通过 + + Examples: + >>> ExpressionEvaluator.validate_variable_names([ + ... {"name": "user_input"}, + ... {"name": "var"} # 保留字 + ... ]) + ["变量名 'var' 是保留的命名空间,请使用其他名称"] + """ + errors = [] + + for var in variables: + var_name = var.get("name", "") + + # 检查是否为保留命名空间 + if var_name in ExpressionEvaluator.RESERVED_NAMESPACES: + errors.append( + f"变量名 '{var_name}' 是保留的命名空间,请使用其他名称" + ) + + # 检查是否为有效的 Python 标识符 + if not var_name.isidentifier(): + errors.append( + f"变量名 '{var_name}' 不是有效的标识符" + ) + + return errors + + +# 便捷函数 +def evaluate_expression( + expression: str, + variables: dict[str, Any], + node_outputs: dict[str, Any], + system_vars: dict[str, Any] | None = None +) -> Any: + """评估表达式(便捷函数)""" + return ExpressionEvaluator.evaluate( + expression, variables, node_outputs, system_vars + ) + + +def evaluate_condition( + expression: str, + variables: dict[str, Any], + node_outputs: dict[str, Any], + system_vars: dict[str, Any] | None = None +) -> bool: + """评估条件表达式(便捷函数)""" + return ExpressionEvaluator.evaluate_bool( + expression, variables, node_outputs, system_vars + ) diff --git a/api/app/core/workflow/nodes/__init__.py b/api/app/core/workflow/nodes/__init__.py new file mode 100644 index 00000000..820c9301 --- /dev/null +++ b/api/app/core/workflow/nodes/__init__.py @@ -0,0 +1,24 @@ +""" +工作流节点实现 + +提供各种类型的节点实现,用于工作流执行。 +""" + +from app.core.workflow.nodes.base_node import BaseNode, WorkflowState +from app.core.workflow.nodes.llm import LLMNode +from app.core.workflow.nodes.agent import AgentNode +from app.core.workflow.nodes.transform import TransformNode +from app.core.workflow.nodes.start import StartNode +from app.core.workflow.nodes.end import EndNode +from app.core.workflow.nodes.node_factory import NodeFactory + +__all__ = [ + "BaseNode", + "WorkflowState", + "LLMNode", + "AgentNode", + "TransformNode", + "StartNode", + "EndNode", + "NodeFactory", +] diff --git a/api/app/core/workflow/nodes/agent/__init__.py b/api/app/core/workflow/nodes/agent/__init__.py new file mode 100644 index 00000000..9839a9d4 --- /dev/null +++ b/api/app/core/workflow/nodes/agent/__init__.py @@ -0,0 +1,6 @@ +"""Agent 节点""" + +from app.core.workflow.nodes.agent.node import AgentNode +from app.core.workflow.nodes.agent.config import AgentNodeConfig + +__all__ = ["AgentNode", "AgentNodeConfig"] diff --git a/api/app/core/workflow/nodes/agent/config.py b/api/app/core/workflow/nodes/agent/config.py new file mode 100644 index 00000000..413ce606 --- /dev/null +++ b/api/app/core/workflow/nodes/agent/config.py @@ -0,0 +1,71 @@ +"""Agent 节点配置""" + +from pydantic import Field + +from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType + + +class AgentNodeConfig(BaseNodeConfig): + """Agent 节点配置 + + 调用已配置的 Agent 执行任务。 + """ + + agent_id: str = Field( + ..., + description="Agent 配置 ID" + ) + + message: str = Field( + default="{{ sys.message }}", + description="发送给 Agent 的消息,支持模板变量" + ) + + conversation_id: str | None = Field( + default=None, + description="会话 ID,用于多轮对话" + ) + + variables: dict[str, str] | None = Field( + default=None, + description="传递给 Agent 的变量" + ) + + timeout: int = Field( + default=300, + ge=1, + le=3600, + description="超时时间(秒)" + ) + + # 输出变量定义 + output_variables: list[VariableDefinition] = Field( + default_factory=lambda: [ + VariableDefinition( + name="output", + type=VariableType.STRING, + description="Agent 的回复内容" + ), + VariableDefinition( + name="conversation_id", + type=VariableType.STRING, + description="会话 ID" + ), + VariableDefinition( + name="token_usage", + type=VariableType.OBJECT, + description="Token 使用情况" + ) + ], + description="输出变量定义(自动生成,通常不需要修改)" + ) + + class Config: + json_schema_extra = { + "example": { + "agent_id": "uuid-here", + "message": "{{ sys.message }}", + "timeout": 300, + "description": "调用客服 Agent" + } + } diff --git a/api/app/core/workflow/nodes/agent/node.py b/api/app/core/workflow/nodes/agent/node.py new file mode 100644 index 00000000..e4525d88 --- /dev/null +++ b/api/app/core/workflow/nodes/agent/node.py @@ -0,0 +1,152 @@ +""" +Agent 节点实现 + +调用已发布的 Agent 应用。 +""" + +import logging +from typing import Any +from langchain_core.messages import AIMessage + +from app.core.workflow.nodes.base_node import BaseNode, WorkflowState +from app.services.draft_run_service import DraftRunService +from app.models import AppRelease +from app.db import get_db + +logger = logging.getLogger(__name__) + + +class AgentNode(BaseNode): + """Agent 节点 + + 支持流式和非流式输出。 + + 配置示例: + { + "type": "agent", + "config": { + "agent_id": "uuid", # Agent 的 release_id + "message": "{{var.user_input}}" + } + } + """ + + def _prepare_agent(self, state: WorkflowState) -> tuple[DraftRunService, AppRelease, str]: + """准备 Agent(公共逻辑) + + Args: + state: 工作流状态 + + Returns: + (draft_service, release, message): 服务实例、发布配置、消息 + """ + # 1. 渲染消息 + message_template = self.config.get("message", "") + message = self._render_template(message_template, state) + + # 2. 获取 Agent 配置 + agent_id = self.config.get("agent_id") + if not agent_id: + raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置") + + db = next(get_db()) + release = db.query(AppRelease).filter( + AppRelease.id == agent_id + ).first() + + if not release: + raise ValueError(f"Agent 不存在: {agent_id}") + + draft_service = DraftRunService(db) + + return draft_service, release, message + + async def execute(self, state: WorkflowState) -> dict[str, Any]: + """非流式执行 + + Args: + state: 工作流状态 + + Returns: + 状态更新字典 + """ + draft_service, release, message = self._prepare_agent(state) + + logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)") + + # 执行 Agent(非流式) + result = await draft_service.run( + agent_config=release.config, + model_config=None, + message=message, + workspace_id=state.get("workspace_id"), + user_id=state.get("user_id"), + variables=state.get("variables", {}) + ) + + response = result.get("response", "") + + logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(response)}") + + return { + "messages": [AIMessage(content=response)], + "node_outputs": { + self.node_id: { + "output": response, + "status": "completed", + "meta_data": result.get("meta_data", {}) + } + } + } + + async def execute_stream(self, state: WorkflowState): + """流式执行 + + Args: + state: 工作流状态 + + Yields: + 流式事件字典 + """ + draft_service, release, message = self._prepare_agent(state) + + logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)") + + # 累积完整响应 + full_response = "" + + # 执行 Agent(流式) + async for chunk in draft_service.run_stream( + agent_config=release.config, + model_config=None, + message=message, + workspace_id=state.get("workspace_id"), + user_id=state.get("user_id"), + variables=state.get("variables", {}) + ): + # 提取内容 + content = chunk.get("content", "") + full_response += content + + # 流式返回每个 chunk + yield { + "type": "chunk", + "node_id": self.node_id, + "content": content, + "full_content": full_response, + "meta_data": chunk.get("meta_data", {}) + } + + logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}") + + # 最后返回完整结果 + yield { + "type": "complete", + "messages": [AIMessage(content=full_response)], + "node_outputs": { + self.node_id: { + "output": full_response, + "status": "completed" + } + } + } diff --git a/api/app/core/workflow/nodes/base_config.py b/api/app/core/workflow/nodes/base_config.py new file mode 100644 index 00000000..8423f479 --- /dev/null +++ b/api/app/core/workflow/nodes/base_config.py @@ -0,0 +1,109 @@ +"""节点配置基类 + +定义所有节点配置的通用字段和数据结构。 +""" + +from enum import StrEnum + +from pydantic import BaseModel, Field + + +class VariableType(StrEnum): + """变量类型枚举""" + + STRING = "string" + NUMBER = "number" + BOOLEAN = "boolean" + ARRAY = "array" + OBJECT = "object" + + +class VariableDefinition(BaseModel): + """变量定义 + + 定义工作流或节点的输入/输出变量。 + 这是一个通用的数据结构,可以在多个地方使用。 + """ + + name: str = Field( + ..., + description="变量名称" + ) + + type: VariableType = Field( + default=VariableType.STRING, + description="变量类型" + ) + + required: bool = Field( + default=False, + description="是否必需" + ) + + default: str | int | float | bool | list | dict | None = Field( + default=None, + description="默认值" + ) + + description: str | None = Field( + default=None, + description="变量描述" + ) + + class Config: + json_schema_extra = { + "examples": [ + { + "name": "language", + "type": "string", + "required": False, + "default": "zh-CN", + "description": "语言设置" + }, + { + "name": "max_length", + "type": "number", + "required": False, + "default": 1000, + "description": "最大长度" + }, + { + "name": "enable_search", + "type": "boolean", + "required": True, + "description": "是否启用搜索" + } + ] + } + + +class BaseNodeConfig(BaseModel): + """节点配置基类 + + 所有节点配置都应该继承此基类。 + + 通用字段: + - name: 节点名称(显示名称) + - description: 节点描述 + - tags: 节点标签(用于分类和搜索) + """ + + name: str | None = Field( + default=None, + description="节点名称(显示名称),如果不设置则使用节点 ID" + ) + + description: str | None = Field( + default=None, + description="节点描述,说明节点的作用" + ) + + tags: list[str] = Field( + default_factory=list, + description="节点标签,用于分类和搜索" + ) + + class Config: + """Pydantic 配置""" + # 允许额外字段(向后兼容) + extra = "allow" diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py new file mode 100644 index 00000000..d17cc1fd --- /dev/null +++ b/api/app/core/workflow/nodes/base_node.py @@ -0,0 +1,556 @@ +""" +工作流节点基类 + +定义节点的基本接口和通用功能。 +""" + +import asyncio +import logging +from abc import ABC, abstractmethod +from typing import Any, TypedDict, Annotated +from operator import add +from langchain_core.messages import AnyMessage, HumanMessage, AIMessage + +from app.core.workflow.variable_pool import VariablePool + +logger = logging.getLogger(__name__) + + +class WorkflowState(TypedDict): + """工作流状态 + + 在节点间传递的状态对象,包含消息、变量、节点输出等信息。 + """ + # 消息列表(追加模式) + messages: Annotated[list[AnyMessage], add] + + # 输入变量(从配置的 variables 传入) + variables: dict[str, Any] + + # 节点输出(存储每个节点的执行结果,用于变量引用) + # 使用自定义合并函数,将新的节点输出合并到现有字典中 + node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}] + + # 运行时节点变量(简化版,只存储业务数据,供节点间快速访问) + # 格式:{node_id: business_result} + runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}] + + # 执行上下文 + execution_id: str + workspace_id: str + user_id: str + + # 错误信息(用于错误边) + error: str | None + error_node: str | None + + +class BaseNode(ABC): + """节点基类 + + 所有节点类型都应该继承此基类,实现 execute 方法。 + """ + + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + """初始化节点 + + Args: + node_config: 节点配置 + workflow_config: 工作流配置 + """ + self.node_config = node_config + self.workflow_config = workflow_config + self.node_id = node_config["id"] + self.node_type = node_config["type"] + self.node_name = node_config.get("name", self.node_id) + # 使用 or 运算符处理 None 值 + self.config = node_config.get("config") or {} + self.error_handling = node_config.get("error_handling") or {} + + @abstractmethod + async def execute(self, state: WorkflowState) -> Any: + """执行节点业务逻辑(非流式) + + 节点只需要返回业务结果,不需要关心输出格式、时间统计等。 + BaseNode 会自动包装成标准格式。 + + Args: + state: 工作流状态 + + Returns: + 业务结果(任意类型) + + Examples: + >>> # LLM 节点 + >>> return "这是 AI 的回复" + + >>> # Transform 节点 + >>> return {"processed_data": [...]} + + >>> # Start/End 节点 + >>> return {"message": "开始", "conversation_id": "xxx"} + """ + pass + + async def execute_stream(self, state: WorkflowState): + """执行节点业务逻辑(流式) + + 子类可以重写此方法以支持流式输出。 + 默认实现:执行非流式方法并一次性返回。 + + 节点需要: + 1. yield 中间结果(如文本片段) + 2. 最后 yield 一个特殊的完成标记:{"__final__": True, "result": final_result} + + Args: + state: 工作流状态 + + Yields: + 业务数据(chunk)或完成标记 + + Examples: + >>> # 流式 LLM 节点 + >>> full_response = "" + >>> async for chunk in llm.astream(prompt): + ... full_response += chunk + ... yield chunk # yield 文本片段 + >>> + >>> # 最后 yield 完成标记 + >>> yield {"__final__": True, "result": AIMessage(content=full_response)} + """ + result = await self.execute(state) + # 默认实现:直接 yield 完成标记 + yield {"__final__": True, "result": result} + + def supports_streaming(self) -> bool: + """节点是否支持流式输出 + + Returns: + 是否支持流式输出 + """ + # 检查子类是否重写了 execute_stream 方法 + return self.execute_stream.__func__ != BaseNode.execute_stream.__func__ + + def get_timeout(self) -> int: + """获取超时时间(秒) + + Returns: + 超时时间 + """ + return 60 + # return self.error_handling.get("timeout", 60) + + async def run(self, state: WorkflowState) -> dict[str, Any]: + """执行节点(带错误处理和输出包装,非流式) + + 这个方法由 Executor 调用,负责: + 1. 时间统计 + 2. 调用节点的 execute() 方法 + 3. 将业务结果包装成标准输出格式 + 4. 错误处理 + + Args: + state: 工作流状态 + + Returns: + 标准化的状态更新字典 + """ + import time + + start_time = time.time() + + try: + timeout = self.get_timeout() + + # 调用节点的业务逻辑 + business_result = await asyncio.wait_for( + self.execute(state), + timeout=timeout + ) + + elapsed_time = time.time() - start_time + + # 提取处理后的输出(调用子类的 _extract_output) + extracted_output = self._extract_output(business_result) + + # 包装成标准输出格式 + wrapped_output = self._wrap_output(business_result, elapsed_time, state) + + # 将提取后的输出存储到运行时变量中(供后续节点快速访问) + # 如果提取后的输出是字典,拆包存储;否则存储为 output 字段 + if isinstance(extracted_output, dict): + runtime_var = extracted_output + else: + runtime_var = {"output": extracted_output} + + # 返回包装后的输出和运行时变量 + return { + **wrapped_output, + "runtime_vars": { + self.node_id: runtime_var + } + } + + except TimeoutError: + elapsed_time = time.time() - start_time + logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)") + return self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state) + except Exception as e: + elapsed_time = time.time() - start_time + logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True) + return self._wrap_error(str(e), elapsed_time, state) + + async def run_stream(self, state: WorkflowState): + """执行节点(带错误处理和输出包装,流式) + + 这个方法由 Executor 调用,负责: + 1. 时间统计 + 2. 调用节点的 execute_stream() 方法 + 3. 将业务数据包装成标准输出格式 + 4. 错误处理 + + Args: + state: 工作流状态 + + Yields: + 标准化的流式事件 + """ + import time + + start_time = time.time() + + try: + timeout = self.get_timeout() + + # 累积完整结果(用于最后的包装) + chunks = [] + final_result = None + + # 使用异步生成器包装,支持超时 + async def stream_with_timeout(): + nonlocal final_result + loop_start = asyncio.get_event_loop().time() + + async for item in self.execute_stream(state): + # 检查超时 + if asyncio.get_event_loop().time() - loop_start > timeout: + raise TimeoutError() + + # 检查是否是完成标记 + if isinstance(item, dict) and item.get("__final__"): + final_result = item["result"] + elif isinstance(item, str): + # 字符串是 chunk + chunks.append(item) + yield { + "type": "chunk", + "node_id": self.node_id, + "content": item, + "full_content": "".join(chunks) + } + else: + # 其他类型也当作 chunk 处理 + chunks.append(str(item)) + yield { + "type": "chunk", + "node_id": self.node_id, + "content": str(item), + "full_content": "".join(chunks) + } + + async for chunk_event in stream_with_timeout(): + yield chunk_event + + elapsed_time = time.time() - start_time + + # 包装最终结果 + final_output = self._wrap_output(final_result, elapsed_time, state) + yield { + "type": "complete", + **final_output + } + + except TimeoutError: + elapsed_time = time.time() - start_time + logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)") + yield { + "type": "error", + **self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state) + } + except Exception as e: + elapsed_time = time.time() - start_time + logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True) + yield { + "type": "error", + **self._wrap_error(str(e), elapsed_time, state) + } + + def _wrap_output( + self, + business_result: Any, + elapsed_time: float, + state: WorkflowState + ) -> dict[str, Any]: + """将业务结果包装成标准输出格式 + + Args: + business_result: 节点返回的业务结果 + elapsed_time: 执行耗时 + state: 工作流状态 + + Returns: + 标准化的状态更新字典 + """ + # 提取输入数据(用于记录) + input_data = self._extract_input(state) + + # 提取 token 使用情况(如果有) + token_usage = self._extract_token_usage(business_result) + + # 提取实际输出(去除元数据) + output = self._extract_output(business_result) + + # 构建标准节点输出 + node_output = { + "node_id": self.node_id, + "node_type": self.node_type, + "node_name": self.node_name, + "status": "completed", + "input": input_data, + "output": output, + "elapsed_time": elapsed_time, + "token_usage": token_usage, + "error": None + } + + return { + "node_outputs": { + self.node_id: node_output + } + } + + def _wrap_error( + self, + error_message: str, + elapsed_time: float, + state: WorkflowState + ) -> dict[str, Any]: + """将错误包装成标准输出格式 + + Args: + error_message: 错误信息 + elapsed_time: 执行耗时 + state: 工作流状态 + + Returns: + 标准化的状态更新字典 + """ + # 查找错误边 + error_edge = self._find_error_edge() + + # 提取输入数据 + input_data = self._extract_input(state) + + # 构建错误输出 + node_output = { + "node_id": self.node_id, + "node_type": self.node_type, + "node_name": self.node_name, + "status": "failed", + "input": input_data, + "output": None, + "elapsed_time": elapsed_time, + "token_usage": None, + "error": error_message + } + + if error_edge: + # 有错误边:记录错误并继续 + logger.warning( + f"节点 {self.node_id} 执行失败,跳转到错误处理节点: {error_edge['target']}" + ) + return { + "node_outputs": { + self.node_id: node_output + }, + "error": error_message, + "error_node": self.node_id + } + else: + # 无错误边:抛出异常停止工作流 + logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}") + raise Exception(f"节点 {self.node_id} 执行失败: {error_message}") + + def _extract_input(self, state: WorkflowState) -> dict[str, Any]: + """提取节点输入数据(用于记录) + + 子类可以重写此方法来自定义输入记录。 + + Args: + state: 工作流状态 + + Returns: + 输入数据字典 + """ + # 默认返回配置 + return {"config": self.config} + + def _extract_output(self, business_result: Any) -> Any: + """从业务结果中提取实际输出 + + 子类可以重写此方法来自定义输出提取。 + + Args: + business_result: 业务结果 + + Returns: + 实际输出 + """ + # 默认直接返回业务结果 + return business_result + + def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: + """从业务结果中提取 token 使用情况 + + 子类可以重写此方法来提取 token 信息。 + + Args: + business_result: 业务结果 + + Returns: + token 使用情况或 None + """ + # 默认返回 None + return None + + def _find_error_edge(self) -> dict[str, Any] | None: + """查找错误边 + + Returns: + 错误边配置或 None + """ + for edge in self.workflow_config.get("edges", []): + if edge.get("source") == self.node_id and edge.get("type") == "error": + return edge + return None + + def _render_template(self, template: str, state: WorkflowState | None) -> str: + """渲染模板 + + 支持的变量命名空间: + - sys.xxx: 系统变量(message, execution_id, workspace_id, user_id, conversation_id) + - conv.xxx: 会话变量(跨多轮对话保持) + - node_id.xxx: 节点输出 + + Args: + template: 模板字符串 + state: 工作流状态 + + Returns: + 渲染后的字符串 + """ + from app.core.workflow.template_renderer import render_template + + # 处理 state 为 None 的情况 + if state is None: + state = {} + + # 使用变量池获取变量 + pool = VariablePool(state) + + return render_template( + template=template, + variables=pool.get_all_conversation_vars(), + node_outputs=pool.get_all_node_outputs(), + system_vars=pool.get_all_system_vars() + ) + + def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool: + """评估条件表达式 + + 支持的变量命名空间: + - sys.xxx: 系统变量 + - conv.xxx: 会话变量 + - node_id.xxx: 节点输出 + + Args: + expression: 条件表达式 + state: 工作流状态 + + Returns: + 布尔值结果 + """ + from app.core.workflow.expression_evaluator import evaluate_condition + + # 处理 state 为 None 的情况 + if state is None: + state = {} + + # 使用变量池获取变量 + pool = VariablePool(state) + + return evaluate_condition( + expression=expression, + variables=pool.get_all_conversation_vars(), + node_outputs=pool.get_all_node_outputs(), + system_vars=pool.get_all_system_vars() + ) + + def get_variable_pool(self, state: WorkflowState) -> VariablePool: + """获取变量池实例 + + VariablePool 是轻量级包装器,只持有 state 的引用,创建成本极低。 + + Args: + state: 工作流状态 + + Returns: + VariablePool 实例 + + Examples: + >>> pool = self.get_variable_pool(state) + >>> message = pool.get("sys.message") + >>> llm_output = pool.get("llm_qa.output") + """ + return VariablePool(state) + + def get_variable( + self, + selector: list[str] | str, + state: WorkflowState, + default: Any = None + ) -> Any: + """获取变量值(便捷方法) + + Args: + selector: 变量选择器 + state: 工作流状态 + default: 默认值 + + Returns: + 变量值 + + Examples: + >>> message = self.get_variable("sys.message", state) + >>> output = self.get_variable(["llm_qa", "output"], state) + >>> custom = self.get_variable("var.custom", state, default="默认值") + """ + pool = VariablePool(state) + return pool.get(selector, default=default) + + def has_variable(self, selector: list[str] | str, state: WorkflowState) -> bool: + """检查变量是否存在(便捷方法) + + Args: + selector: 变量选择器 + state: 工作流状态 + + Returns: + 变量是否存在 + + Examples: + >>> if self.has_variable("llm_qa.output", state): + ... output = self.get_variable("llm_qa.output", state) + """ + pool = VariablePool(state) + return pool.has(selector) diff --git a/api/app/core/workflow/nodes/configs.py b/api/app/core/workflow/nodes/configs.py new file mode 100644 index 00000000..99d06036 --- /dev/null +++ b/api/app/core/workflow/nodes/configs.py @@ -0,0 +1,29 @@ +"""节点配置类统一导出 + +所有节点的配置类都在这里导出,方便使用。 +""" + +from app.core.workflow.nodes.base_config import ( + BaseNodeConfig, + VariableDefinition, + VariableType, +) +from app.core.workflow.nodes.start.config import StartNodeConfig +from app.core.workflow.nodes.end.config import EndNodeConfig +from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig +from app.core.workflow.nodes.agent.config import AgentNodeConfig +from app.core.workflow.nodes.transform.config import TransformNodeConfig + +__all__ = [ + # 基础类 + "BaseNodeConfig", + "VariableDefinition", + "VariableType", + # 节点配置 + "StartNodeConfig", + "EndNodeConfig", + "LLMNodeConfig", + "MessageConfig", + "AgentNodeConfig", + "TransformNodeConfig", +] diff --git a/api/app/core/workflow/nodes/end/__init__.py b/api/app/core/workflow/nodes/end/__init__.py new file mode 100644 index 00000000..d7be3c5b --- /dev/null +++ b/api/app/core/workflow/nodes/end/__init__.py @@ -0,0 +1,6 @@ +"""End 节点""" + +from app.core.workflow.nodes.end.node import EndNode +from app.core.workflow.nodes.end.config import EndNodeConfig + +__all__ = ["EndNode", "EndNodeConfig"] diff --git a/api/app/core/workflow/nodes/end/config.py b/api/app/core/workflow/nodes/end/config.py new file mode 100644 index 00000000..50e84a36 --- /dev/null +++ b/api/app/core/workflow/nodes/end/config.py @@ -0,0 +1,37 @@ +"""End 节点配置""" + +from pydantic import Field + +from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType + + +class EndNodeConfig(BaseNodeConfig): + """End 节点配置 + + End 节点负责输出工作流的最终结果。 + """ + + output: str = Field( + default="工作流已完成", + description="输出模板,支持引用前置节点的输出,如:{{ llm_qa.output }}" + ) + + # 输出变量定义 + output_variables: list[VariableDefinition] = Field( + default_factory=lambda: [ + VariableDefinition( + name="output", + type=VariableType.STRING, + description="工作流的最终输出" + ) + ], + description="输出变量定义(自动生成,通常不需要修改)" + ) + + class Config: + json_schema_extra = { + "example": { + "output": "{{ llm_qa.output }}", + "description": "输出 LLM 的回答" + } + } diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py new file mode 100644 index 00000000..1c0e6747 --- /dev/null +++ b/api/app/core/workflow/nodes/end/node.py @@ -0,0 +1,53 @@ +""" +End 节点实现 + +工作流的结束节点,输出最终结果。 +""" + +import logging +from typing import Any + +from app.core.workflow.nodes.base_node import BaseNode, WorkflowState + +logger = logging.getLogger(__name__) + + +class EndNode(BaseNode): + """End 节点 + + 工作流的结束节点,根据配置的模板输出最终结果。 + """ + + async def execute(self, state: WorkflowState) -> str: + """执行 end 节点业务逻辑 + + Args: + state: 工作流状态 + + Returns: + 最终输出字符串 + """ + logger.info(f"节点 {self.node_id} (End) 开始执行") + + # 获取配置的输出模板 + output_template = self.config.get("output") + pool = self.get_variable_pool(state) + + print("="*20) + print( pool.get("start.test")) + print("="*20) + # 如果配置了输出模板,使用模板渲染;否则使用默认输出 + if output_template: + output = self._render_template(output_template, state) + else: + output = "工作流已完成" + + # 统计信息(用于日志) + node_outputs = state.get("node_outputs", {}) + total_nodes = len(node_outputs) + + logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点") + print("="*20) + print(output) + print("="*20) + return output diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py new file mode 100644 index 00000000..9cec19d2 --- /dev/null +++ b/api/app/core/workflow/nodes/enums.py @@ -0,0 +1,15 @@ +from enum import StrEnum + +class NodeType(StrEnum): + START = "start" + END = "end" + ANSWER = "answer" + LLM = "llm" + KNOWLEDGE_RETRIEVAL = "knowledge-retrieval" + IF_ELSE = "if-else" + CODE = "code" + TRANSFORM = "transform" + QUESTION_CLASSIFIER = "question-classifier" + HTTP_REQUEST = "http-request" + TOOL = "tool" + AGENT = "agent" diff --git a/api/app/core/workflow/nodes/llm/__init__.py b/api/app/core/workflow/nodes/llm/__init__.py new file mode 100644 index 00000000..99ca570e --- /dev/null +++ b/api/app/core/workflow/nodes/llm/__init__.py @@ -0,0 +1,6 @@ +"""LLM 节点""" + +from app.core.workflow.nodes.llm.node import LLMNode +from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig + +__all__ = ["LLMNode", "LLMNodeConfig", "MessageConfig"] diff --git a/api/app/core/workflow/nodes/llm/config.py b/api/app/core/workflow/nodes/llm/config.py new file mode 100644 index 00000000..da94482b --- /dev/null +++ b/api/app/core/workflow/nodes/llm/config.py @@ -0,0 +1,141 @@ +"""LLM 节点配置""" + +from pydantic import BaseModel, Field, field_validator + +from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType + + +class MessageConfig(BaseModel): + """消息配置""" + + role: str = Field( + ..., + description="消息角色:system, user, assistant" + ) + + content: str = Field( + ..., + description="消息内容,支持模板变量,如:{{ sys.message }}" + ) + + @field_validator("role") + @classmethod + def validate_role(cls, v: str) -> str: + """验证角色""" + allowed_roles = ["system", "user", "human", "assistant", "ai"] + if v.lower() not in allowed_roles: + raise ValueError(f"角色必须是以下之一: {', '.join(allowed_roles)}") + return v.lower() + + +class LLMNodeConfig(BaseNodeConfig): + """LLM 节点配置 + + 支持两种配置方式: + 1. 简单模式:使用 prompt 字段 + 2. 消息模式:使用 messages 字段(推荐) + """ + + model_id: str = Field( + ..., + description="模型配置 ID" + ) + + # 简单模式 + prompt: str | None = Field( + default=None, + description="提示词模板(简单模式),支持变量引用" + ) + + # 消息模式(推荐) + messages: list[MessageConfig] | None = Field( + default=None, + description="消息列表(消息模式),支持多轮对话" + ) + + # 模型参数 + temperature: float | None = Field( + default=0.7, + ge=0.0, + le=2.0, + description="温度参数,控制输出的随机性" + ) + + max_tokens: int | None = Field( + default=1000, + ge=1, + le=32000, + description="最大生成 token 数" + ) + + top_p: float | None = Field( + default=None, + ge=0.0, + le=1.0, + description="Top-p 采样参数" + ) + + frequency_penalty: float | None = Field( + default=None, + ge=-2.0, + le=2.0, + description="频率惩罚" + ) + + presence_penalty: float | None = Field( + default=None, + ge=-2.0, + le=2.0, + description="存在惩罚" + ) + + # 输出变量定义 + output_variables: list[VariableDefinition] = Field( + default_factory=lambda: [ + VariableDefinition( + name="output", + type=VariableType.STRING, + description="LLM 生成的文本输出" + ), + VariableDefinition( + name="token_usage", + type=VariableType.OBJECT, + description="Token 使用情况" + ) + ], + description="输出变量定义(自动生成,通常不需要修改)" + ) + + @field_validator("messages", "prompt") + @classmethod + def validate_input_mode(cls, v, info): + """验证输入模式:prompt 和 messages 至少有一个""" + # 这个验证在 model_validator 中更合适 + return v + + class Config: + json_schema_extra = { + "examples": [ + { + "model_id": "uuid-here", + "prompt": "请回答:{{ sys.message }}", + "temperature": 0.7, + "max_tokens": 1000 + }, + { + "model_id": "uuid-here", + "messages": [ + { + "role": "system", + "content": "你是一个专业的 AI 助手" + }, + { + "role": "user", + "content": "{{ sys.message }}" + } + ], + "temperature": 0.7, + "max_tokens": 1000 + } + ] + } diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py new file mode 100644 index 00000000..bfc7da58 --- /dev/null +++ b/api/app/core/workflow/nodes/llm/node.py @@ -0,0 +1,247 @@ +""" +LLM 节点实现 + +调用 LLM 模型进行文本生成。 +""" + +import logging +from typing import Any +from langchain_core.messages import AIMessage, SystemMessage, HumanMessage + +from app.core.workflow.nodes.base_node import BaseNode, WorkflowState +from app.core.models import RedBearLLM, RedBearModelConfig +from app.models import ModelConfig +from app.db import get_db, get_db_context +from app.models.models_model import ModelApiKey +from app.services.model_service import ModelConfigService, ModelApiKeyService + +from app.core.exceptions import BusinessException +from app.core.error_codes import BizCode + +logger = logging.getLogger(__name__) + + +class LLMNode(BaseNode): + """LLM 节点 + + 支持流式和非流式输出,使用 LangChain 标准的消息格式。 + + 配置示例(支持多种消息格式): + + 1. 简单文本格式: + { + "type": "llm", + "config": { + "model_id": "uuid", + "prompt": "请分析:{{sys.message}}", + "temperature": 0.7, + "max_tokens": 1000 + } + } + + 2. LangChain 消息格式(推荐): + { + "type": "llm", + "config": { + "model_id": "uuid", + "messages": [ + { + "role": "system", + "content": "你是一个专业的 AI 助手。" + }, + { + "role": "user", + "content": "{{sys.message}}" + } + ], + "temperature": 0.7, + "max_tokens": 1000 + } + } + + 支持的角色类型: + - system: 系统消息(SystemMessage) + - user/human: 用户消息(HumanMessage) + - ai/assistant: AI 消息(AIMessage) + """ + + def _prepare_llm(self, state: WorkflowState) -> tuple[RedBearLLM, list | str]: + """准备 LLM 实例(公共逻辑) + + Args: + state: 工作流状态 + + Returns: + (llm, messages_or_prompt): LLM 实例和消息列表或 prompt 字符串 + """ + + # 1. 处理消息格式(优先使用 messages) + messages_config = self.config.get("messages") + + if messages_config: + # 使用 LangChain 消息格式 + messages = [] + for msg_config in messages_config: + role = msg_config.get("role", "user").lower() + content_template = msg_config.get("content", "") + content = self._render_template(content_template, state) + + # 根据角色创建对应的消息对象 + if role == "system": + messages.append(SystemMessage(content=content)) + elif role in ["user", "human"]: + messages.append(HumanMessage(content=content)) + elif role in ["ai", "assistant"]: + messages.append(AIMessage(content=content)) + else: + logger.warning(f"未知的消息角色: {role},默认使用 user") + messages.append(HumanMessage(content=content)) + + prompt_or_messages = messages + else: + # 使用简单的 prompt 格式(向后兼容) + prompt_template = self.config.get("prompt", "") + prompt_or_messages = self._render_template(prompt_template, state) + + # 2. 获取模型配置 + model_id = self.config.get("model_id") + if not model_id: + raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置") + + # 3. 在 with 块内完成所有数据库操作和数据提取 + with get_db_context() as db: + config = ModelConfigService.get_model_by_id(db=db, model_id=model_id) + + if not config: + raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND) + + if not config.api_keys or len(config.api_keys) == 0: + raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER) + + # 在 Session 关闭前提取所有需要的数据 + api_config = config.api_keys[0] + model_name = api_config.model_name + provider = api_config.provider + api_key = api_config.api_key + api_base = api_config.api_base + model_type = config.type + + # 4. 创建 LLM 实例(使用已提取的数据) + llm = RedBearLLM( + RedBearModelConfig( + model_name=model_name, + provider=provider, + api_key=api_key, + base_url=api_base + ), + type=model_type + ) + + return llm, prompt_or_messages + + async def execute(self, state: WorkflowState) -> AIMessage: + """非流式执行 LLM 调用 + + Args: + state: 工作流状态 + + Returns: + LLM 响应消息 + """ + llm, prompt_or_messages = self._prepare_llm(state) + + logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)") + + # 调用 LLM(支持字符串或消息列表) + response = await llm.ainvoke(prompt_or_messages) + + # 提取内容 + if hasattr(response, 'content'): + content = response.content + else: + content = str(response) + + logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}") + + # 返回 AIMessage(包含响应元数据) + return response if isinstance(response, AIMessage) else AIMessage(content=content) + + def _extract_input(self, state: WorkflowState) -> dict[str, Any]: + """提取输入数据(用于记录)""" + _, prompt_or_messages = self._prepare_llm(state) + + return { + "prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None, + "messages": [ + {"role": msg.__class__.__name__.replace("Message", "").lower(), "content": msg.content} + for msg in prompt_or_messages + ] if isinstance(prompt_or_messages, list) else None, + "config": { + "model_id": self.config.get("model_id"), + "temperature": self.config.get("temperature"), + "max_tokens": self.config.get("max_tokens") + } + } + + def _extract_output(self, business_result: Any) -> str: + """从 AIMessage 中提取文本内容""" + if isinstance(business_result, AIMessage): + return business_result.content + return str(business_result) + + def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: + """从 AIMessage 中提取 token 使用情况""" + if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'): + usage = business_result.response_metadata.get('token_usage') + if usage: + return { + "prompt_tokens": usage.get('prompt_tokens', 0), + "completion_tokens": usage.get('completion_tokens', 0), + "total_tokens": usage.get('total_tokens', 0) + } + return None + + async def execute_stream(self, state: WorkflowState): + """流式执行 LLM 调用 + + Args: + state: 工作流状态 + + Yields: + 文本片段(chunk)或完成标记 + """ + llm, prompt_or_messages = self._prepare_llm(state) + + logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") + + # 累积完整响应 + full_response = "" + last_chunk = None + + # 调用 LLM(流式,支持字符串或消息列表) + async for chunk in llm.astream(prompt_or_messages): + # 提取内容 + if hasattr(chunk, 'content'): + content = chunk.content + else: + content = str(chunk) + + full_response += content + last_chunk = chunk + + # 流式返回每个文本片段 + yield content + + logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}") + + # 构建完整的 AIMessage(包含元数据) + if isinstance(last_chunk, AIMessage): + final_message = AIMessage( + content=full_response, + response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {} + ) + else: + final_message = AIMessage(content=full_response) + + # yield 完成标记 + yield {"__final__": True, "result": final_message} diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py new file mode 100644 index 00000000..f279d13a --- /dev/null +++ b/api/app/core/workflow/nodes/node_factory.py @@ -0,0 +1,93 @@ +""" +节点工厂 + +根据节点类型创建相应的节点实例。 +""" + +import logging +from typing import Any + +from app.core.workflow.nodes.base_node import BaseNode +from app.core.workflow.nodes.enums import NodeType +from app.core.workflow.nodes.llm import LLMNode +from app.core.workflow.nodes.agent import AgentNode +from app.core.workflow.nodes.transform import TransformNode +from app.core.workflow.nodes.start import StartNode +from app.core.workflow.nodes.end import EndNode + +logger = logging.getLogger(__name__) + + +class NodeFactory: + """节点工厂 + + 使用工厂模式创建节点实例,便于扩展和维护。 + """ + + # 节点类型注册表 + _node_types: dict[str, type[BaseNode]] = { + NodeType.START: StartNode, + NodeType.END: EndNode, + NodeType.LLM: LLMNode, + NodeType.AGENT: AgentNode, + NodeType.TRANSFORM: TransformNode, + } + + @classmethod + def register_node_type(cls, node_type: str, node_class: type[BaseNode]): + """注册新的节点类型 + + Args: + node_type: 节点类型名称 + node_class: 节点类 + + Examples: + >>> class CustomNode(BaseNode): + ... async def execute(self, state): + ... return {"node_outputs": {self.node_id: {"output": "custom"}}} + >>> NodeFactory.register_node_type("custom", CustomNode) + """ + cls._node_types[node_type] = node_class + logger.info(f"注册节点类型: {node_type} -> {node_class.__name__}") + + @classmethod + def create_node( + cls, + node_config: dict[str, Any], + workflow_config: dict[str, Any] + ) -> BaseNode | None: + """创建节点实例 + + Args: + node_config: 节点配置 + workflow_config: 工作流配置 + + Returns: + 节点实例或 None(对于不支持的节点类型) + + Raises: + ValueError: 不支持的节点类型 + """ + node_type = node_config.get("type") + + # 跳过条件节点(由 LangGraph 处理) + if node_type == "condition": + return None + + # 获取节点类 + node_class = cls._node_types.get(node_type) + if not node_class: + raise ValueError(f"不支持的节点类型: {node_type}") + + # 创建节点实例 + logger.debug(f"创建节点: {node_config.get('id')} (type={node_type})") + return node_class(node_config, workflow_config) + + @classmethod + def get_supported_types(cls) -> list[str]: + """获取支持的节点类型列表 + + Returns: + 节点类型列表 + """ + return list(cls._node_types.keys()) diff --git a/api/app/core/workflow/nodes/start/__init__.py b/api/app/core/workflow/nodes/start/__init__.py new file mode 100644 index 00000000..c81a1e88 --- /dev/null +++ b/api/app/core/workflow/nodes/start/__init__.py @@ -0,0 +1,6 @@ +"""Start 节点""" + +from app.core.workflow.nodes.start.node import StartNode +from app.core.workflow.nodes.start.config import StartNodeConfig + +__all__ = ["StartNode", "StartNodeConfig"] diff --git a/api/app/core/workflow/nodes/start/config.py b/api/app/core/workflow/nodes/start/config.py new file mode 100644 index 00000000..1544f89f --- /dev/null +++ b/api/app/core/workflow/nodes/start/config.py @@ -0,0 +1,87 @@ +"""Start 节点配置""" + +from pydantic import Field + +from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType + + +class StartNodeConfig(BaseNodeConfig): + """Start 节点配置 + + Start 节点的作用: + 1. 标记工作流的起点 + 2. 定义自定义输入变量(会作为节点输出,通过 start_node_id.variable_name 访问) + 3. 输出系统变量和会话变量 + """ + + # 自定义输入变量定义 + variables: list[VariableDefinition] = Field( + default_factory=list, + description="自定义输入变量列表,这些变量会作为 Start 节点的输出" + ) + + # 输出变量定义 + output_variables: list[VariableDefinition] = Field( + default_factory=lambda: [ + VariableDefinition( + name="message", + type=VariableType.STRING, + description="用户输入的消息" + ), + VariableDefinition( + name="conversation_vars", + type=VariableType.OBJECT, + description="会话级变量" + ), + VariableDefinition( + name="execution_id", + type=VariableType.STRING, + description="执行 ID" + ), + VariableDefinition( + name="conversation_id", + type=VariableType.STRING, + description="会话 ID" + ), + VariableDefinition( + name="workspace_id", + type=VariableType.STRING, + description="工作空间 ID" + ), + VariableDefinition( + name="user_id", + type=VariableType.STRING, + description="用户 ID" + ) + ], + description="输出变量定义(自动生成,通常不需要修改)" + ) + + class Config: + json_schema_extra = { + "examples": [ + { + "description": "工作流开始节点", + "variables": [] + }, + { + "description": "带自定义变量的开始节点", + "variables": [ + { + "name": "language", + "type": "string", + "required": False, + "default": "zh-CN", + "description": "语言设置" + }, + { + "name": "max_length", + "type": "number", + "required": False, + "default": 1000, + "description": "最大长度" + } + ] + } + ] + } diff --git a/api/app/core/workflow/nodes/start/node.py b/api/app/core/workflow/nodes/start/node.py new file mode 100644 index 00000000..0acf04b0 --- /dev/null +++ b/api/app/core/workflow/nodes/start/node.py @@ -0,0 +1,136 @@ +""" +Start 节点实现 + +工作流的起始节点,定义输入变量并输出系统参数。 +""" + +import logging +from typing import Any + +from app.core.workflow.nodes.base_node import BaseNode, WorkflowState +from app.core.workflow.nodes.start.config import StartNodeConfig + +logger = logging.getLogger(__name__) + + +class StartNode(BaseNode): + """Start 节点 + + 工作流的起始节点,负责: + 1. 定义工作流的输入变量(通过配置) + 2. 输出系统变量(sys.*) + 3. 输出会话变量(conv.*) + + 注意:变量的验证和默认值处理由 Executor 在初始化时完成。 + """ + + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + """初始化 Start 节点 + + Args: + node_config: 节点配置 + workflow_config: 工作流配置 + """ + super().__init__(node_config, workflow_config) + + # 解析并验证配置 + self.typed_config = StartNodeConfig(**self.config) + + async def execute(self, state: WorkflowState) -> dict[str, Any]: + """执行 start 节点业务逻辑 + + Start 节点输出系统变量、会话变量和自定义变量。 + + Args: + state: 工作流状态 + + Returns: + 包含系统参数、会话变量和自定义变量的字典 + """ + logger.info(f"节点 {self.node_id} (Start) 开始执行") + + # 创建变量池实例(在方法内复用) + pool = self.get_variable_pool(state) + + # 处理自定义变量(传入 pool 避免重复创建) + custom_vars = self._process_custom_variables(pool) + + # 返回业务数据(包含自定义变量) + result = { + "message": pool.get("sys.message"), + "execution_id": pool.get("sys.execution_id"), + "conversation_id": pool.get("sys.conversation_id"), + "workspace_id": pool.get("sys.workspace_id"), + "user_id": pool.get("sys.user_id"), + **custom_vars # 自定义变量作为节点输出的一部分 + } + + logger.info( + f"节点 {self.node_id} (Start) 执行完成," + f"输出了 {len(custom_vars)} 个自定义变量" + ) + + return result + + def _process_custom_variables(self, pool) -> dict[str, Any]: + """处理自定义变量 + + 从输入数据中提取自定义变量,应用默认值和验证。 + + Args: + pool: 变量池实例 + + Returns: + 处理后的自定义变量字典 + + Raises: + ValueError: 缺少必需变量 + """ + # 获取输入数据中的自定义变量 + input_variables = pool.get("sys.input_variables", default={}) + + processed = {} + + # 遍历配置的变量定义 + for var_def in self.typed_config.variables: + var_name = var_def.name + + # 检查变量是否存在 + if var_name in input_variables: + # 使用用户提供的值 + processed[var_name] = input_variables[var_name] + + elif var_def.required: + # 必需变量缺失 + raise ValueError( + f"缺少必需的输入变量: {var_name}" + + (f" ({var_def.description})" if var_def.description else "") + ) + + elif var_def.default is not None: + # 使用默认值 + processed[var_name] = var_def.default + logger.debug( + f"变量 '{var_name}' 使用默认值: {var_def.default}" + ) + + return processed + + + def _extract_input(self, state: WorkflowState) -> dict[str, Any]: + """提取输入数据(用于记录) + + Args: + state: 工作流状态 + + Returns: + 输入数据字典 + """ + pool = self.get_variable_pool(state) + + return { + "execution_id": pool.get("sys.execution_id"), + "conversation_id": pool.get("sys.conversation_id"), + "message": pool.get("sys.message"), + "conversation_vars": pool.get_all_conversation_vars() + } diff --git a/api/app/core/workflow/nodes/transform/__init__.py b/api/app/core/workflow/nodes/transform/__init__.py new file mode 100644 index 00000000..384b818c --- /dev/null +++ b/api/app/core/workflow/nodes/transform/__init__.py @@ -0,0 +1,6 @@ +"""Transform 节点""" + +from app.core.workflow.nodes.transform.node import TransformNode +from app.core.workflow.nodes.transform.config import TransformNodeConfig + +__all__ = ["TransformNode", "TransformNodeConfig"] diff --git a/api/app/core/workflow/nodes/transform/config.py b/api/app/core/workflow/nodes/transform/config.py new file mode 100644 index 00000000..47d2a6ac --- /dev/null +++ b/api/app/core/workflow/nodes/transform/config.py @@ -0,0 +1,80 @@ +"""Transform 节点配置""" + +from typing import Literal + +from pydantic import Field + +from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType + + +class TransformNodeConfig(BaseNodeConfig): + """Transform 节点配置 + + 用于数据转换和处理。 + """ + + transform_type: Literal["template", "code", "json"] = Field( + default="template", + description="转换类型:template(模板), code(代码), json(JSON处理)" + ) + + # 模板模式 + template: str | None = Field( + default=None, + description="转换模板,支持变量引用" + ) + + # 代码模式 + code: str | None = Field( + default=None, + description="Python 代码,用于数据转换" + ) + + # JSON 模式 + json_path: str | None = Field( + default=None, + description="JSON 路径表达式" + ) + + # 输入变量 + inputs: dict[str, str] | None = Field( + default=None, + description="输入变量映射,key 为变量名,value 为变量选择器" + ) + + # 输出变量 + output_key: str = Field( + default="result", + description="输出变量的键名" + ) + + # 输出变量定义 + output_variables: list[VariableDefinition] = Field( + default_factory=lambda: [ + VariableDefinition( + name="result", + type=VariableType.STRING, + description="转换后的结果" + ) + ], + description="输出变量定义(根据 output_key 动态生成)" + ) + + class Config: + json_schema_extra = { + "examples": [ + { + "transform_type": "template", + "template": "用户问题:{{ sys.message }}\n回答:{{ llm_qa.output }}", + "output_key": "formatted_result" + }, + { + "transform_type": "code", + "code": "result = input_text.upper()", + "inputs": { + "input_text": "{{ sys.message }}" + }, + "output_key": "uppercase_text" + } + ] + } diff --git a/api/app/core/workflow/nodes/transform/node.py b/api/app/core/workflow/nodes/transform/node.py new file mode 100644 index 00000000..4211c510 --- /dev/null +++ b/api/app/core/workflow/nodes/transform/node.py @@ -0,0 +1,60 @@ +""" +Transform 节点实现 + +数据转换节点,用于处理和转换数据。 +""" + +import logging +from typing import Any + +from app.core.workflow.nodes.base_node import BaseNode, WorkflowState + +logger = logging.getLogger(__name__) + + +class TransformNode(BaseNode): + """数据转换节点 + + 配置示例: + { + "type": "transform", + "config": { + "mapping": { + "output_field": "{{node.previous.output}}", + "processed": "{{var.input | upper}}" + } + } + } + """ + + async def execute(self, state: WorkflowState) -> dict[str, Any]: + """执行数据转换 + + Args: + state: 工作流状态 + + Returns: + 状态更新字典 + """ + logger.info(f"节点 {self.node_id} 开始执行数据转换") + + # 获取映射配置 + mapping = self.config.get("mapping", {}) + + # 执行数据转换 + transformed_data = {} + for target_key, source_template in mapping.items(): + # 渲染模板获取值 + value = self._render_template(str(source_template), state) + transformed_data[target_key] = value + + logger.info(f"节点 {self.node_id} 数据转换完成,输出字段: {list(transformed_data.keys())}") + + return { + "node_outputs": { + self.node_id: { + "output": transformed_data, + "status": "completed" + } + } + } diff --git a/api/app/core/workflow/template_loader.py b/api/app/core/workflow/template_loader.py new file mode 100644 index 00000000..ab5bd9fa --- /dev/null +++ b/api/app/core/workflow/template_loader.py @@ -0,0 +1,170 @@ +""" +工作流模板加载器 + +从文件系统加载预定义的工作流模板 +""" + +import os +import yaml +from pathlib import Path +from typing import Optional + + +class TemplateLoader: + """工作流模板加载器""" + + def __init__(self, templates_dir: str = "app/templates/workflows"): + """初始化模板加载器 + + Args: + templates_dir: 模板目录路径 + """ + self.templates_dir = Path(templates_dir) + if not self.templates_dir.exists(): + raise ValueError(f"模板目录不存在: {templates_dir}") + + def list_templates(self) -> list[dict]: + """列出所有可用的模板 + + Returns: + 模板列表,每个模板包含 id, name, description 等信息 + """ + templates = [] + + # 遍历模板目录 + for template_dir in self.templates_dir.iterdir(): + if not template_dir.is_dir(): + continue + + # 检查是否有 template.yml 文件 + template_file = template_dir / "template.yml" + if not template_file.exists(): + continue + + try: + # 读取模板配置 + with open(template_file, 'r', encoding='utf-8') as f: + template_data = yaml.safe_load(f) + + # 提取模板信息 + templates.append({ + "id": template_dir.name, + "name": template_data.get("name", template_dir.name), + "description": template_data.get("description", ""), + "category": template_data.get("category", "general"), + "tags": template_data.get("tags", []), + "author": template_data.get("author", ""), + "version": template_data.get("version", "1.0.0") + }) + except Exception as e: + print(f"加载模板 {template_dir.name} 失败: {e}") + continue + + return templates + + def load_template(self, template_id: str) -> Optional[dict]: + """加载指定的模板 + + Args: + template_id: 模板 ID(目录名) + + Returns: + 模板配置字典,如果模板不存在则返回 None + """ + template_dir = self.templates_dir / template_id + template_file = template_dir / "template.yml" + + if not template_file.exists(): + return None + + try: + with open(template_file, 'r', encoding='utf-8') as f: + template_data = yaml.safe_load(f) + + # 返回工作流配置部分 + return { + "name": template_data.get("name", template_id), + "description": template_data.get("description", ""), + "nodes": template_data.get("nodes", []), + "edges": template_data.get("edges", []), + "variables": template_data.get("variables", []), + "execution_config": template_data.get("execution_config", {}), + "triggers": template_data.get("triggers", []) + } + except Exception as e: + print(f"加载模板 {template_id} 失败: {e}") + return None + + def get_template_readme(self, template_id: str) -> Optional[str]: + """获取模板的 README 文档 + + Args: + template_id: 模板 ID + + Returns: + README 内容,如果不存在则返回 None + """ + template_dir = self.templates_dir / template_id + readme_file = template_dir / "README.md" + + if not readme_file.exists(): + return None + + try: + with open(readme_file, 'r', encoding='utf-8') as f: + return f.read() + except Exception as e: + print(f"读取模板 {template_id} 的 README 失败: {e}") + return None + + +# 全局模板加载器实例 +_template_loader: Optional[TemplateLoader] = None + + +def get_template_loader() -> TemplateLoader: + """获取全局模板加载器实例 + + Returns: + TemplateLoader 实例 + """ + global _template_loader + if _template_loader is None: + _template_loader = TemplateLoader() + return _template_loader + + +def list_workflow_templates() -> list[dict]: + """列出所有工作流模板 + + Returns: + 模板列表 + """ + loader = get_template_loader() + return loader.list_templates() + + +def load_workflow_template(template_id: str) -> Optional[dict]: + """加载工作流模板 + + Args: + template_id: 模板 ID + + Returns: + 模板配置,如果不存在则返回 None + """ + loader = get_template_loader() + return loader.load_template(template_id) + + +def get_workflow_template_readme(template_id: str) -> Optional[str]: + """获取工作流模板的 README + + Args: + template_id: 模板 ID + + Returns: + README 内容,如果不存在则返回 None + """ + loader = get_template_loader() + return loader.get_template_readme(template_id) diff --git a/api/app/core/workflow/template_renderer.py b/api/app/core/workflow/template_renderer.py new file mode 100644 index 00000000..e9efec0b --- /dev/null +++ b/api/app/core/workflow/template_renderer.py @@ -0,0 +1,170 @@ +""" +模板渲染器 + +使用 Jinja2 提供安全的模板渲染功能,支持变量引用和表达式。 +""" + +import logging +from typing import Any + +from jinja2 import Template, TemplateSyntaxError, UndefinedError, Environment, StrictUndefined + +logger = logging.getLogger(__name__) + + +class TemplateRenderer: + """模板渲染器""" + + def __init__(self, strict: bool = True): + """初始化渲染器 + + Args: + strict: 是否使用严格模式(未定义变量会抛出异常) + """ + self.env = Environment( + undefined=StrictUndefined if strict else None, + autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML + ) + + def render( + self, + template: str, + variables: dict[str, Any], + node_outputs: dict[str, Any], + system_vars: dict[str, Any] | None = None + ) -> str: + """渲染模板 + + Args: + template: 模板字符串 + variables: 用户定义的变量 + node_outputs: 节点输出结果 + system_vars: 系统变量 + + Returns: + 渲染后的字符串 + + Raises: + ValueError: 模板语法错误或变量未定义 + + Examples: + >>> renderer = TemplateRenderer() + >>> renderer.render( + ... "Hello {{var.name}}!", + ... {"name": "World"}, + ... {}, + ... {} + ... ) + 'Hello World!' + + >>> renderer.render( + ... "分析结果: {{node.analyze.output}}", + ... {}, + ... {"analyze": {"output": "正面情绪"}}, + ... {} + ... ) + '分析结果: 正面情绪' + """ + # 构建命名空间上下文 + context = { + "var": variables, # 用户变量:{{var.user_input}} + "node": node_outputs, # 节点输出:{{node.node_1.output}} + "sys": system_vars or {}, # 系统变量:{{sys.execution_id}} + } + + # 支持直接通过节点ID访问节点输出:{{llm_qa.output}} + # 将所有节点输出添加到顶层上下文 + context.update(node_outputs) + + # 为了向后兼容,也支持直接访问用户变量 + context.update(variables) + context["nodes"] = node_outputs # 旧语法兼容 + + try: + tmpl = self.env.from_string(template) + return tmpl.render(**context) + + except TemplateSyntaxError as e: + logger.error(f"模板语法错误: {template}, 错误: {e}") + raise ValueError(f"模板语法错误: {e}") + + except UndefinedError as e: + logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}") + raise ValueError(f"未定义的变量: {e}") + + except Exception as e: + logger.error(f"模板渲染异常: {template}, 错误: {e}") + raise ValueError(f"模板渲染失败: {e}") + + def validate(self, template: str) -> list[str]: + """验证模板语法 + + Args: + template: 模板字符串 + + Returns: + 错误列表,如果为空则验证通过 + + Examples: + >>> renderer = TemplateRenderer() + >>> renderer.validate("Hello {{var.name}}!") + [] + + >>> renderer.validate("Hello {{var.name") # 缺少结束标记 + ['模板语法错误: ...'] + """ + errors = [] + + try: + self.env.from_string(template) + except TemplateSyntaxError as e: + errors.append(f"模板语法错误: {e}") + except Exception as e: + errors.append(f"模板验证失败: {e}") + + return errors + + +# 全局渲染器实例(严格模式) +_default_renderer = TemplateRenderer(strict=True) + + +def render_template( + template: str, + variables: dict[str, Any], + node_outputs: dict[str, Any], + system_vars: dict[str, Any] | None = None +) -> str: + """渲染模板(便捷函数) + + Args: + template: 模板字符串 + variables: 用户变量 + node_outputs: 节点输出 + system_vars: 系统变量 + + Returns: + 渲染后的字符串 + + Examples: + >>> render_template( + ... "请分析: {{var.text}}", + ... {"text": "这是一段文本"}, + ... {}, + ... {} + ... ) + '请分析: 这是一段文本' + """ + return _default_renderer.render(template, variables, node_outputs, system_vars) + + +def validate_template(template: str) -> list[str]: + """验证模板语法(便捷函数) + + Args: + template: 模板字符串 + + Returns: + 错误列表 + """ + return _default_renderer.validate(template) diff --git a/api/app/core/workflow/validator.py b/api/app/core/workflow/validator.py new file mode 100644 index 00000000..58bc20b9 --- /dev/null +++ b/api/app/core/workflow/validator.py @@ -0,0 +1,277 @@ +""" +工作流配置验证器 + +验证工作流配置的有效性,确保配置符合规范。 +""" + +import logging +from typing import Any, Union + +logger = logging.getLogger(__name__) + + +class WorkflowValidator: + """工作流配置验证器""" + + @staticmethod + def validate(workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]: + """验证工作流配置 + + Args: + workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型 + + Returns: + (is_valid, errors): 是否有效和错误列表 + + Examples: + >>> config = { + ... "nodes": [ + ... {"id": "start", "type": "start"}, + ... {"id": "end", "type": "end"} + ... ], + ... "edges": [ + ... {"source": "start", "target": "end"} + ... ] + ... } + >>> is_valid, errors = WorkflowValidator.validate(config) + >>> is_valid + True + """ + errors = [] + + # 支持字典和 Pydantic 模型 + if isinstance(workflow_config, dict): + nodes = workflow_config.get("nodes", []) + edges = workflow_config.get("edges", []) + variables = workflow_config.get("variables", []) + else: + # Pydantic 模型 + nodes = getattr(workflow_config, "nodes", []) + edges = getattr(workflow_config, "edges", []) + variables = getattr(workflow_config, "variables", []) + + # 1. 验证 start 节点(有且只有一个) + start_nodes = [n for n in nodes if n.get("type") == "start"] + if len(start_nodes) == 0: + errors.append("工作流必须有一个 start 节点") + elif len(start_nodes) > 1: + errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个") + + # 2. 验证 end 节点(至少一个) + end_nodes = [n for n in nodes if n.get("type") == "end"] + if len(end_nodes) == 0: + errors.append("工作流必须至少有一个 end 节点") + + # 3. 验证节点 ID 唯一性 + node_ids = [n.get("id") for n in nodes] + if len(node_ids) != len(set(node_ids)): + duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1] + errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}") + + # 4. 验证节点必须有 id 和 type + for i, node in enumerate(nodes): + if not node.get("id"): + errors.append(f"节点 #{i} 缺少 id 字段") + if not node.get("type"): + errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段") + + # 5. 验证边的有效性 + node_id_set = set(node_ids) + for i, edge in enumerate(edges): + source = edge.get("source") + target = edge.get("target") + + if not source: + errors.append(f"边 #{i} 缺少 source 字段") + elif source not in node_id_set: + errors.append(f"边 #{i} 的 source 节点不存在: {source}") + + if not target: + errors.append(f"边 #{i} 缺少 target 字段") + elif target not in node_id_set: + errors.append(f"边 #{i} 的 target 节点不存在: {target}") + + # 6. 验证所有节点可达(从 start 节点出发) + if start_nodes and not errors: # 只有在前面验证通过时才检查可达性 + reachable = WorkflowValidator._get_reachable_nodes( + start_nodes[0]["id"], + edges + ) + unreachable = node_id_set - reachable + if unreachable: + errors.append(f"以下节点无法从 start 节点到达: {unreachable}") + + # 7. 检测循环依赖(非 loop 节点) + if not errors: # 只有在前面验证通过时才检查循环 + has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges) + if has_cycle: + errors.append( + f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}" + ) + + # 8. 验证变量名 + from app.core.workflow.expression_evaluator import ExpressionEvaluator + var_errors = ExpressionEvaluator.validate_variable_names(variables) + errors.extend(var_errors) + + return len(errors) == 0, errors + + @staticmethod + def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]: + """获取从 start 节点可达的所有节点 + + Args: + start_id: 起始节点 ID + edges: 边列表 + + Returns: + 可达节点 ID 集合 + """ + reachable = {start_id} + queue = [start_id] + + while queue: + current = queue.pop(0) + for edge in edges: + if edge.get("source") == current: + target = edge.get("target") + if target and target not in reachable: + reachable.add(target) + queue.append(target) + + return reachable + + @staticmethod + def _has_cycle(nodes: list[dict], edges: list[dict]) -> tuple[bool, list[str]]: + """检测是否存在循环依赖(DFS) + + Args: + nodes: 节点列表 + edges: 边列表 + + Returns: + (has_cycle, cycle_path): 是否有循环和循环路径 + """ + # 排除 loop 类型的节点 + loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"} + + # 构建邻接表(排除 loop 节点的边和错误边) + graph: dict[str, list[str]] = {} + for edge in edges: + source = edge.get("source") + target = edge.get("target") + edge_type = edge.get("type") + + # 跳过错误边 + if edge_type == "error": + continue + + # 如果涉及 loop 节点,跳过 + if source in loop_nodes or target in loop_nodes: + continue + + if source and target: + if source not in graph: + graph[source] = [] + graph[source].append(target) + + # DFS 检测环 + visited = set() + rec_stack = set() + path = [] + cycle_path = [] + + def dfs(node: str) -> bool: + """DFS 检测环,返回是否找到环""" + visited.add(node) + rec_stack.add(node) + path.append(node) + + for neighbor in graph.get(node, []): + if neighbor not in visited: + if dfs(neighbor): + return True + elif neighbor in rec_stack: + # 找到环,记录环路径 + cycle_start = path.index(neighbor) + cycle_path.extend([*path[cycle_start:], neighbor]) + return True + + rec_stack.remove(node) + path.pop() + return False + + # 检查所有节点 + for node_id in graph: + if node_id not in visited: + if dfs(node_id): + return True, cycle_path + + return False, [] + + @staticmethod + def validate_for_publish(workflow_config: dict[str, Any]) -> tuple[bool, list[str]]: + """验证工作流配置是否可以发布(更严格的验证) + + Args: + workflow_config: 工作流配置 + + Returns: + (is_valid, errors): 是否有效和错误列表 + """ + # 先执行基础验证 + is_valid, errors = WorkflowValidator.validate(workflow_config) + + if not is_valid: + return False, errors + + # 额外的发布验证 + nodes = workflow_config.get("nodes", []) + + # 1. 验证所有节点都有名称 + for node in nodes: + if node.get("type") not in ["start", "end"] and not node.get("name"): + errors.append( + f"节点 {node.get('id')} 缺少名称(发布时必须提供)" + ) + + # 2. 验证所有非 start/end 节点都有配置 + for node in nodes: + node_type = node.get("type") + if node_type not in ["start", "end"]: + config = node.get("config") + if not config or not isinstance(config, dict): + errors.append( + f"节点 {node.get('id')} 缺少配置(发布时必须提供)" + ) + + # 3. 验证必填变量 + variables = workflow_config.get("variables", []) + required_vars = [v for v in variables if v.get("required")] + if required_vars: + # 这里只是提示,实际执行时会检查 + logger.info( + f"工作流包含 {len(required_vars)} 个必填变量: " + f"{[v.get('name') for v in required_vars]}" + ) + + return len(errors) == 0, errors + + +def validate_workflow_config( + workflow_config: dict[str, Any], + for_publish: bool = False +) -> tuple[bool, list[str]]: + """验证工作流配置(便捷函数) + + Args: + workflow_config: 工作流配置 + for_publish: 是否为发布验证(更严格) + + Returns: + (is_valid, errors): 是否有效和错误列表 + """ + if for_publish: + return WorkflowValidator.validate_for_publish(workflow_config) + else: + return WorkflowValidator.validate(workflow_config) diff --git a/api/app/core/workflow/variable_pool.py b/api/app/core/workflow/variable_pool.py new file mode 100644 index 00000000..1f589dab --- /dev/null +++ b/api/app/core/workflow/variable_pool.py @@ -0,0 +1,293 @@ +""" +变量池 (Variable Pool) + +工作流执行的数据中心,管理所有变量的存储和访问。 + +变量类型: +1. 系统变量 (sys.*) - 系统内置变量(execution_id, workspace_id, user_id, message 等) +2. 节点输出 (node_id.*) - 节点执行结果 +3. 会话变量 (conv.*) - 会话级变量(跨多轮对话保持) +""" + +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +class VariableSelector: + """变量选择器 + + 用于引用变量的路径表示。 + + Examples: + >>> selector = VariableSelector(["sys", "message"]) + >>> selector = VariableSelector(["node_A", "output"]) + >>> selector = VariableSelector.from_string("sys.message") + """ + + def __init__(self, path: list[str]): + """初始化变量选择器 + + Args: + path: 变量路径,如 ["sys", "message"] 或 ["node_A", "output"] + """ + if not path or len(path) < 1: + raise ValueError("变量路径不能为空") + + self.path = path + self.namespace = path[0] # sys, var, 或 node_id + self.key = path[1] if len(path) > 1 else None + + @classmethod + def from_string(cls, selector_str: str) -> "VariableSelector": + """从字符串创建选择器 + + Args: + selector_str: 选择器字符串,如 "sys.message" 或 "node_A.output" + + Returns: + VariableSelector 实例 + + Examples: + >>> selector = VariableSelector.from_string("sys.message") + >>> selector = VariableSelector.from_string("llm_qa.output") + """ + path = selector_str.split(".") + return cls(path) + + def __str__(self) -> str: + return ".".join(self.path) + + def __repr__(self) -> str: + return f"VariableSelector({self.path})" + + +class VariablePool: + """变量池 + + 管理工作流执行过程中的所有变量。 + + 变量命名空间: + - sys.*: 系统变量(message, execution_id, workspace_id, user_id, conversation_id) + - conv.*: 会话变量(跨多轮对话保持的变量) + - .*: 节点输出 + + Examples: + >>> pool = VariablePool(state) + >>> pool.get(["sys", "message"]) + "用户的问题" + >>> pool.get(["llm_qa", "output"]) + "AI 的回答" + >>> pool.set(["conv", "user_name"], "张三") + """ + + def __init__(self, state: dict[str, Any]): + """初始化变量池 + + Args: + state: 工作流状态(LangGraph State) + """ + self.state = state + + def get(self, selector: list[str] | str, default: Any = None) -> Any: + """获取变量值 + + Args: + selector: 变量选择器,可以是列表或字符串 + default: 默认值(变量不存在时返回) + + Returns: + 变量值 + + Examples: + >>> pool.get(["sys", "message"]) + >>> pool.get("sys.message") + >>> pool.get(["llm_qa", "output"]) + >>> pool.get("llm_qa.output") + + Raises: + KeyError: 变量不存在且未提供默认值 + """ + # 转换为 VariableSelector + if isinstance(selector, str): + selector = VariableSelector.from_string(selector).path + + if not selector or len(selector) < 1: + raise ValueError("变量选择器不能为空") + + namespace = selector[0] + + try: + # 系统变量 + if namespace == "sys": + key = selector[1] if len(selector) > 1 else None + if not key: + return self.state.get("variables", {}).get("sys", {}) + return self.state.get("variables", {}).get("sys", {}).get(key, default) + + # 会话变量 + elif namespace == "conv": + key = selector[1] if len(selector) > 1 else None + if not key: + return self.state.get("variables", {}).get("conv", {}) + return self.state.get("variables", {}).get("conv", {}).get(key, default) + + # 节点输出(从 runtime_vars 读取) + else: + node_id = namespace + runtime_vars = self.state.get("runtime_vars", {}) + + if node_id not in runtime_vars: + if default is not None: + return default + raise KeyError(f"节点 '{node_id}' 的输出不存在") + + node_var = runtime_vars[node_id] + + # 如果只有节点 ID,返回整个变量 + if len(selector) == 1: + return node_var + + # 获取特定字段 + # 支持嵌套访问,如 node_id.field.subfield + result = node_var + for k in selector[1:]: + if isinstance(result, dict): + result = result.get(k) + if result is None: + if default is not None: + return default + raise KeyError(f"字段 '{'.'.join(selector)}' 不存在") + else: + if default is not None: + return default + raise KeyError(f"无法访问 '{'.'.join(selector)}'") + + return result + + except KeyError: + if default is not None: + return default + raise + + def set(self, selector: list[str] | str, value: Any): + """设置变量值 + + Args: + selector: 变量选择器 + value: 变量值 + + Examples: + >>> pool.set(["conv", "user_name"], "张三") + >>> pool.set("conv.user_name", "张三") + + Note: + - 只能设置会话变量 (conv.*) + - 系统变量和节点输出是只读的 + """ + # 转换为 VariableSelector + if isinstance(selector, str): + selector = VariableSelector.from_string(selector).path + + if not selector or len(selector) < 2: + raise ValueError("变量选择器必须包含命名空间和键名") + + namespace = selector[0] + + if namespace != "conv": + raise ValueError("只能设置会话变量 (conv.*)") + + key = selector[1] + + # 确保 variables 结构存在 + if "variables" not in self.state: + self.state["variables"] = {"sys": {}, "conv": {}} + if "conv" not in self.state["variables"]: + self.state["variables"]["conv"] = {} + + # 设置值 + self.state["variables"]["conv"][key] = value + + logger.debug(f"设置变量: {'.'.join(selector)} = {value}") + + def has(self, selector: list[str] | str) -> bool: + """检查变量是否存在 + + Args: + selector: 变量选择器 + + Returns: + 变量是否存在 + + Examples: + >>> pool.has(["sys", "message"]) + True + >>> pool.has("llm_qa.output") + False + """ + try: + self.get(selector) + return True + except KeyError: + return False + + def get_all_system_vars(self) -> dict[str, Any]: + """获取所有系统变量 + + Returns: + 系统变量字典 + """ + return self.state.get("variables", {}).get("sys", {}) + + def get_all_conversation_vars(self) -> dict[str, Any]: + """获取所有会话变量 + + Returns: + 会话变量字典 + """ + return self.state.get("variables", {}).get("conv", {}) + + def get_all_node_outputs(self) -> dict[str, Any]: + """获取所有节点输出(运行时变量) + + Returns: + 节点输出字典,键为节点 ID + """ + return self.state.get("runtime_vars", {}) + + def get_node_output(self, node_id: str) -> dict[str, Any] | None: + """获取指定节点的输出(运行时变量) + + Args: + node_id: 节点 ID + + Returns: + 节点输出或 None + """ + return self.state.get("runtime_vars", {}).get(node_id) + + def to_dict(self) -> dict[str, Any]: + """导出为字典 + + Returns: + 包含所有变量的字典 + """ + return { + "system": self.get_all_system_vars(), + "conversation": self.get_all_conversation_vars(), + "nodes": self.get_all_node_outputs() # 从 runtime_vars 读取 + } + + def __repr__(self) -> str: + sys_vars = self.get_all_system_vars() + conv_vars = self.get_all_conversation_vars() + runtime_vars = self.get_all_node_outputs() + + return ( + f"VariablePool(\n" + f" system_vars={len(sys_vars)},\n" + f" conversation_vars={len(conv_vars)},\n" + f" runtime_vars={len(runtime_vars)}\n" + f")" + ) diff --git a/api/app/main.py b/api/app/main.py index 20a2b0d9..d5efeb35 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -1,10 +1,9 @@ import os import subprocess -from dotenv import load_dotenv +from contextlib import asynccontextmanager + from fastapi import FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware -from app.core.config import settings -from contextlib import asynccontextmanager from fastapi.responses import JSONResponse from app.core.response_utils import fail from app.core.logging_config import LoggingConfig, get_logger @@ -38,9 +37,13 @@ router = APIRouter(prefix="/memory", tags=["Memory"]) # 管理端 API (JWT 认证) from app.controllers import manager_router - # 服务端 API (API Key 认证) from app.controllers.service import service_router +from app.core.config import settings +from app.core.error_codes import BizCode, HTTP_MAPPING +from app.core.exceptions import BusinessException +from app.core.logging_config import LoggingConfig, get_logger +from app.core.response_utils import fail # Initialize logging system LoggingConfig.setup_logging() @@ -414,5 +417,4 @@ async def unhandled_exception_handler(request: Request, exc: Exception): if __name__ == "__main__": import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/api/app/models/__init__.py b/api/app/models/__init__.py index 493e894b..fd0c23e2 100644 --- a/api/app/models/__init__.py +++ b/api/app/models/__init__.py @@ -15,9 +15,11 @@ from .end_user_model import EndUser from .appshare_model import AppShare from .release_share_model import ReleaseShare from .conversation_model import Conversation, Message -from .api_key_model import ApiKey, ApiKeyLog, ApiKeyType, ResourceType +from .api_key_model import ApiKey, ApiKeyLog, ApiKeyType from .data_config_model import DataConfig from .multi_agent_model import MultiAgentConfig, AgentInvocation +from .workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution +from .retrieval_info import RetrievalInfo __all__ = [ "Tenants", @@ -46,8 +48,11 @@ __all__ = [ "ApiKey", "ApiKeyLog", "ApiKeyType", - "ResourceType", "DataConfig", "MultiAgentConfig", - "AgentInvocation" + "AgentInvocation", + "WorkflowConfig", + "WorkflowExecution", + "WorkflowNodeExecution", + "RetrievalInfo" ] diff --git a/api/app/models/api_key_model.py b/api/app/models/api_key_model.py index b123a034..f7cea634 100644 --- a/api/app/models/api_key_model.py +++ b/api/app/models/api_key_model.py @@ -2,7 +2,7 @@ import datetime import uuid -from sqlalchemy import Column, String, Boolean, DateTime, Integer, ForeignKey, Text, Enum +from sqlalchemy import Column, String, Boolean, DateTime, Integer, ForeignKey, Text from sqlalchemy.dialects.postgresql import UUID, JSONB from sqlalchemy.orm import relationship from enum import StrEnum @@ -12,18 +12,10 @@ from app.db import Base class ApiKeyType(StrEnum): """API Key 类型""" - APP = "app" # 应用 API Key - RAG = "rag" # RAG API Key - MEMORY = "memory" # Memory API Key - - -class ResourceType(StrEnum): - """资源类型枚举""" - AGENT = "Agent" # 智能体 - CLUSTER = "Cluster" # 集群 - WORKFLOW = "Workflow" # 工作流 - KNOWLEDGE = "Knowledge" # 知识库 - MEMORY_ENGINE = "Memory_Engine" # 记忆引擎 + AGENT = "agent" # 智能体 + CLUSTER = "cluster" # 集群 + WORKFLOW = "workflow" # 工作流 + SERVICE = "service" # 服务 class ApiKey(Base): @@ -35,18 +27,16 @@ class ApiKey(Base): # 基本信息 name = Column(String(255), nullable=False, comment="API Key 名称") description = Column(Text, comment="描述") - key_prefix = Column(String(20), nullable=False, comment="Key 前缀") - key_hash = Column(String(255), nullable=False, unique=True, index=True, comment="Key 哈希值") + api_key = Column(String(255), nullable=False, unique=True, index=True, comment="API Key 明文") # 类型和权限 type = Column(String(50), nullable=False, index=True, comment="API Key 类型") - scopes = Column(JSONB, nullable=False, default=list, comment="权限范围列表") + scopes = Column(JSONB, default=list, comment="权限范围列表") # 关联资源 workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="CASCADE"), nullable=False, index=True, comment="所属工作空间") resource_id = Column(UUID(as_uuid=True), index=True, comment="关联资源ID") - resource_type = Column(String(50), comment="资源类型") # 限制和配额 rate_limit = Column(Integer, default=10, comment="QPS限制(请求/秒)") diff --git a/api/app/models/app_model.py b/api/app/models/app_model.py index 7897eb62..6b8da6f0 100644 --- a/api/app/models/app_model.py +++ b/api/app/models/app_model.py @@ -86,6 +86,14 @@ class App(Base): uselist=False, cascade="all, delete-orphan", ) + + # 一对一:工作流配置(仅当 type=workflow 时有效) + workflow_config = relationship( + "WorkflowConfig", + back_populates="app", + uselist=False, + cascade="all, delete-orphan", + ) # 发布版本关联 current_release = relationship("AppRelease", foreign_keys=[current_release_id]) diff --git a/api/app/models/workflow_model.py b/api/app/models/workflow_model.py new file mode 100644 index 00000000..d599f717 --- /dev/null +++ b/api/app/models/workflow_model.py @@ -0,0 +1,196 @@ +""" +工作流相关数据模型 +""" + +import datetime +import uuid +from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float, ForeignKey, Text +from sqlalchemy.dialects.postgresql import UUID, JSONB +from sqlalchemy.orm import relationship +from app.db import Base + + +class WorkflowConfig(Base): + """工作流配置表""" + __tablename__ = "workflow_configs" + + # 主键 + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) + + # 关联应用(一对一) + app_id = Column( + UUID(as_uuid=True), + ForeignKey("apps.id", ondelete="CASCADE"), + nullable=False, + unique=True, + index=True + ) + + # 节点和边的定义(JSON 格式) + nodes = Column(JSONB, nullable=False, default=list) + edges = Column(JSONB, nullable=False, default=list) + + # 全局变量定义 + variables = Column(JSONB, default=list) + + # 执行配置 + execution_config = Column(JSONB, nullable=False, default=dict) + + # 触发器配置(可选) + triggers = Column(JSONB, default=list) + + # 状态 + is_active = Column(Boolean, nullable=False, default=True) + created_at = Column(DateTime, nullable=False, default=datetime.datetime.now) + updated_at = Column( + DateTime, + nullable=False, + default=datetime.datetime.now, + onupdate=datetime.datetime.now + ) + + # 关系 + app = relationship("App", back_populates="workflow_config") + executions = relationship( + "WorkflowExecution", + back_populates="workflow_config", + cascade="all, delete-orphan" + ) + + def __repr__(self): + return f"" + + +class WorkflowExecution(Base): + """工作流执行记录表""" + __tablename__ = "workflow_executions" + + # 主键 + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) + + # 关联信息 + workflow_config_id = Column( + UUID(as_uuid=True), + ForeignKey("workflow_configs.id", ondelete="CASCADE"), + nullable=False, + index=True + ) + app_id = Column( + UUID(as_uuid=True), + ForeignKey("apps.id", ondelete="CASCADE"), + nullable=False, + index=True + ) + conversation_id = Column( + UUID(as_uuid=True), + ForeignKey("conversations.id", ondelete="SET NULL"), + nullable=True, + index=True + ) + + # 执行信息 + execution_id = Column(String(100), nullable=False, unique=True, index=True) + trigger_type = Column(String(20), nullable=False) # manual, schedule, webhook, event + triggered_by = Column( + UUID(as_uuid=True), + ForeignKey("users.id"), + nullable=True + ) + + # 输入输出 + input_data = Column(JSONB) + output_data = Column(JSONB) + context = Column(JSONB, default=dict) + + # 状态 + status = Column(String(20), nullable=False, default="pending", index=True) + # 可选值:pending, running, completed, failed, cancelled, timeout + + error_message = Column(Text) + error_node_id = Column(String(100)) + + # 性能指标 + started_at = Column(DateTime, nullable=False, default=datetime.datetime.now, index=True) + completed_at = Column(DateTime) + elapsed_time = Column(Float) # 耗时(秒) + + # 资源使用 + token_usage = Column(JSONB) + + # 元数据(使用 meta_data 避免与 SQLAlchemy 保留字 metadata 冲突) + meta_data = Column(JSONB, default=dict) + + created_at = Column(DateTime, nullable=False, default=datetime.datetime.now) + + # 关系 + workflow_config = relationship("WorkflowConfig", back_populates="executions") + app = relationship("App") + conversation = relationship("Conversation") + triggered_by_user = relationship("User", foreign_keys=[triggered_by]) + node_executions = relationship( + "WorkflowNodeExecution", + back_populates="execution", + cascade="all, delete-orphan", + order_by="WorkflowNodeExecution.execution_order" + ) + + def __repr__(self): + return f"" + + +class WorkflowNodeExecution(Base): + """工作流节点执行记录表""" + __tablename__ = "workflow_node_executions" + + # 主键 + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) + + # 关联执行 + execution_id = Column( + UUID(as_uuid=True), + ForeignKey("workflow_executions.id", ondelete="CASCADE"), + nullable=False, + index=True + ) + + # 节点信息 + node_id = Column(String(100), nullable=False, index=True) + node_type = Column(String(20), nullable=False) + node_name = Column(String(100)) + + # 执行顺序 + execution_order = Column(Integer, nullable=False) + retry_count = Column(Integer, nullable=False, default=0) + + # 输入输出 + input_data = Column(JSONB) + output_data = Column(JSONB) + + # 状态 + status = Column(String(20), nullable=False, default="pending", index=True) + # 可选值:pending, running, completed, failed, skipped, cached + + error_message = Column(Text) + + # 性能指标 + started_at = Column(DateTime, nullable=False, default=datetime.datetime.now) + completed_at = Column(DateTime) + elapsed_time = Column(Float) # 耗时(秒) + + # 资源使用(针对 LLM 节点) + token_usage = Column(JSONB) + + # 缓存信息 + cache_hit = Column(Boolean, default=False) + cache_key = Column(String(255)) + + # 元数据(使用 meta_data 避免与 SQLAlchemy 保留字 metadata 冲突) + meta_data = Column(JSONB, default=dict) + + created_at = Column(DateTime, nullable=False, default=datetime.datetime.now) + + # 关系 + execution = relationship("WorkflowExecution", back_populates="node_executions") + + def __repr__(self): + return f"" diff --git a/api/app/repositories/api_key_repository.py b/api/app/repositories/api_key_repository.py index 27ffdff0..ad94fccf 100644 --- a/api/app/repositories/api_key_repository.py +++ b/api/app/repositories/api_key_repository.py @@ -27,9 +27,9 @@ class ApiKeyRepository: return db.get(ApiKey, api_key_id) @staticmethod - def get_by_hash(db: Session, key_hash: str) -> Optional[ApiKey]: - """根据哈希值获取 API Key""" - stmt = select(ApiKey).where(ApiKey.key_hash == key_hash) + def get_by_api_key(db: Session, api_key: str) -> Optional[ApiKey]: + """根据 API Key 获取 API Key""" + stmt = select(ApiKey).where(ApiKey.api_key == api_key) return db.scalars(stmt).first() @staticmethod @@ -63,11 +63,15 @@ class ApiKeyRepository: @staticmethod def update(db: Session, api_key_id: uuid.UUID, update_data: dict) -> ApiKey | None: """更新 API Key""" + allow_none_fields = {"description", "quota_limit", "expires_at"} api_key = db.get(ApiKey, api_key_id) if api_key: for key, value in update_data.items(): - if value is not None: + if key in allow_none_fields: setattr(api_key, key, value) + else: + if value is not None: + setattr(api_key, key, value) db.flush() return api_key diff --git a/api/app/repositories/app_repository.py b/api/app/repositories/app_repository.py index 5630238d..11a2ea3e 100644 --- a/api/app/repositories/app_repository.py +++ b/api/app/repositories/app_repository.py @@ -14,7 +14,7 @@ class AppRepository: def __init__(self, db: Session): self.db = db - def get_apps_by_workspace_id(self, workspace_id: uuid.UUID) -> List[App]: + def get_apps_by_workspace_id(self, workspace_id: uuid.UUID) -> list[App]: """根据工作空间ID查询应用""" try: apps = self.db.query(App).filter(App.workspace_id == workspace_id).all() @@ -24,7 +24,19 @@ class AppRepository: db_logger.error(f"查询工作空间 {workspace_id} 下应用时出错: {str(e)}") raise + def get_apps_by_id(self, app_id: uuid.UUID) -> App: + try: + app = self.db.query(App).filter(App.id == app_id, App.is_active == True).first() + return app + except Exception as e: + raise + def get_apps_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> List[App]: """根据工作空间ID查询应用""" repo = AppRepository(db) return repo.get_apps_by_workspace_id(workspace_id) + +def get_apps_by_id(db: Session, app_id: uuid.UUID) -> App: + """根据工作空间ID查询应用""" + repo = AppRepository(db) + return repo.get_apps_by_id(app_id) diff --git a/api/app/repositories/workflow_repository.py b/api/app/repositories/workflow_repository.py new file mode 100644 index 00000000..04734640 --- /dev/null +++ b/api/app/repositories/workflow_repository.py @@ -0,0 +1,247 @@ +""" +工作流数据访问层 +""" + +import uuid +from typing import Any, Annotated +from sqlalchemy.orm import Session +from sqlalchemy import desc +from fastapi import Depends + +from app.models.workflow_model import ( + WorkflowConfig, + WorkflowExecution, + WorkflowNodeExecution +) +from app.db import get_db + + +class WorkflowConfigRepository: + """工作流配置仓储""" + + def __init__(self, db: Session): + self.db = db + + def get_by_app_id(self, app_id: uuid.UUID) -> WorkflowConfig | None: + """根据应用 ID 获取工作流配置 + + Args: + app_id: 应用 ID + + Returns: + 工作流配置或 None + """ + return self.db.query(WorkflowConfig).filter( + WorkflowConfig.app_id == app_id, + WorkflowConfig.is_active == True + ).first() + + def create_or_update( + self, + app_id: uuid.UUID, + nodes: list[dict[str, Any]], + edges: list[dict[str, Any]], + variables: list[dict[str, Any]] | None = None, + execution_config: dict[str, Any] | None = None, + triggers: list[dict[str, Any]] | None = None + ) -> WorkflowConfig: + """创建或更新工作流配置 + + Args: + app_id: 应用 ID + nodes: 节点列表 + edges: 边列表 + variables: 变量列表 + execution_config: 执行配置 + triggers: 触发器列表 + + Returns: + 工作流配置 + """ + # 查找现有配置 + existing = self.get_by_app_id(app_id) + + if existing: + # 更新现有配置 + existing.nodes = nodes + existing.edges = edges + if variables is not None: + existing.variables = variables + if execution_config is not None: + existing.execution_config = execution_config + if triggers is not None: + existing.triggers = triggers + self.db.commit() + self.db.refresh(existing) + return existing + else: + # 创建新配置 + config = WorkflowConfig( + app_id=app_id, + nodes=nodes, + edges=edges, + variables=variables or [], + execution_config=execution_config or {}, + triggers=triggers or [] + ) + self.db.add(config) + self.db.commit() + self.db.refresh(config) + return config + + +class WorkflowExecutionRepository: + """工作流执行记录仓储""" + + def __init__(self, db: Session): + self.db = db + + def get_by_execution_id(self, execution_id: str) -> WorkflowExecution | None: + """根据执行 ID 获取执行记录 + + Args: + execution_id: 执行 ID + + Returns: + 执行记录或 None + """ + return self.db.query(WorkflowExecution).filter( + WorkflowExecution.execution_id == execution_id + ).first() + + def get_by_app_id( + self, + app_id: uuid.UUID, + limit: int = 50, + offset: int = 0 + ) -> list[WorkflowExecution]: + """根据应用 ID 获取执行记录列表 + + Args: + app_id: 应用 ID + limit: 返回数量限制 + offset: 偏移量 + + Returns: + 执行记录列表 + """ + return self.db.query(WorkflowExecution).filter( + WorkflowExecution.app_id == app_id + ).order_by( + desc(WorkflowExecution.started_at) + ).limit(limit).offset(offset).all() + + def get_by_conversation_id( + self, + conversation_id: uuid.UUID + ) -> list[WorkflowExecution]: + """根据会话 ID 获取执行记录列表 + + Args: + conversation_id: 会话 ID + + Returns: + 执行记录列表 + """ + return self.db.query(WorkflowExecution).filter( + WorkflowExecution.conversation_id == conversation_id + ).order_by( + desc(WorkflowExecution.started_at) + ).all() + + def count_by_app_id(self, app_id: uuid.UUID) -> int: + """统计应用的执行次数 + + Args: + app_id: 应用 ID + + Returns: + 执行次数 + """ + return self.db.query(WorkflowExecution).filter( + WorkflowExecution.app_id == app_id + ).count() + + def count_by_status(self, app_id: uuid.UUID, status: str) -> int: + """统计指定状态的执行次数 + + Args: + app_id: 应用 ID + status: 状态 + + Returns: + 执行次数 + """ + return self.db.query(WorkflowExecution).filter( + WorkflowExecution.app_id == app_id, + WorkflowExecution.status == status + ).count() + + +class WorkflowNodeExecutionRepository: + """工作流节点执行记录仓储""" + + def __init__(self, db: Session): + self.db = db + + def get_by_execution_id( + self, + execution_id: uuid.UUID + ) -> list[WorkflowNodeExecution]: + """根据执行 ID 获取节点执行记录列表 + + Args: + execution_id: 执行 ID + + Returns: + 节点执行记录列表(按执行顺序排序) + """ + return self.db.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.execution_id == execution_id + ).order_by( + WorkflowNodeExecution.execution_order + ).all() + + def get_by_node_id( + self, + execution_id: uuid.UUID, + node_id: str + ) -> list[WorkflowNodeExecution]: + """根据节点 ID 获取节点执行记录(可能有多次重试) + + Args: + execution_id: 执行 ID + node_id: 节点 ID + + Returns: + 节点执行记录列表 + """ + return self.db.query(WorkflowNodeExecution).filter( + WorkflowNodeExecution.execution_id == execution_id, + WorkflowNodeExecution.node_id == node_id + ).order_by( + WorkflowNodeExecution.retry_count + ).all() + + +# ==================== 依赖注入函数 ==================== + +def get_workflow_config_repository( + db: Annotated[Session, Depends(get_db)] +) -> WorkflowConfigRepository: + """获取工作流配置仓储(依赖注入)""" + return WorkflowConfigRepository(db) + + +def get_workflow_execution_repository( + db: Annotated[Session, Depends(get_db)] +) -> WorkflowExecutionRepository: + """获取工作流执行记录仓储(依赖注入)""" + return WorkflowExecutionRepository(db) + + +def get_workflow_node_execution_repository( + db: Annotated[Session, Depends(get_db)] +) -> WorkflowNodeExecutionRepository: + """获取工作流节点执行记录仓储(依赖注入)""" + return WorkflowNodeExecutionRepository(db) diff --git a/api/app/schemas/api_key_schema.py b/api/app/schemas/api_key_schema.py index 8c0a1031..d19cf061 100644 --- a/api/app/schemas/api_key_schema.py +++ b/api/app/schemas/api_key_schema.py @@ -1,11 +1,11 @@ """API Key Schema""" import datetime import uuid -from pydantic import BaseModel, Field, ConfigDict -from pydantic.v1 import validator +from pydantic import BaseModel, Field, ConfigDict, field_validator, field_serializer, computed_field from typing import Optional, List -from app.models.api_key_model import ApiKeyType, ResourceType +from app.models.api_key_model import ApiKeyType +from app.core.api_key_utils import timestamp_to_datetime, datetime_to_timestamp class ApiKeyCreate(BaseModel): @@ -15,20 +15,34 @@ class ApiKeyCreate(BaseModel): type: ApiKeyType = Field(..., description="API Key 类型") scopes: List[str] = Field(default_factory=list, description="权限范围列表") resource_id: Optional[uuid.UUID] = Field(None, description="关联资源ID") - resource_type: Optional[ResourceType] = Field(None, description="资源类型") rate_limit: Optional[int] = Field(10, ge=1, le=1000, description="QPS限制(请求/秒)") daily_request_limit: Optional[int] = Field(10000, description="日请求限制", ge=1) quota_limit: Optional[int] = Field(None, description="配额限制(总请求数)", ge=1) expires_at: Optional[datetime.datetime] = Field(None, description="过期时间") - @validator('scopes') + @computed_field + @property + def is_expired(self) -> bool: + """检查API Key是否已过期""" + if not self.expires_at: + return False + return datetime.datetime.now() > self.expires_at + + @field_validator('expires_at', mode='before') + @classmethod + def parse_expires_at(cls, v): + """将时间戳转换为datetime""" + if isinstance(v, (int, float)): + return timestamp_to_datetime(v) + return v + + @field_validator('scopes') + @classmethod def validate_scopes(cls, v): """验证权限范围格式""" - valid_scopes = [ - "app:all", - "rag:search", "rag:upload", "rag:delete", - "memory:read", "memory:write", "memory:delete", "memory:search" - ] + if v is None: + return [] + valid_scopes = ["app", "rag", "memory"] for scope in v: if scope not in valid_scopes: raise ValueError(f"无效范围: {scope}") @@ -46,14 +60,29 @@ class ApiKeyUpdate(BaseModel): is_active: Optional[bool] = Field(None, description="是否激活") expires_at: Optional[datetime.datetime] = Field(None, description="过期时间") - @validator('scopes') + @computed_field + @property + def is_expired(self) -> bool: + """检查API Key是否已过期""" + if not self.expires_at: + return False + return datetime.datetime.now() > self.expires_at + + @field_validator('expires_at', mode='before') + @classmethod + def parse_expires_at(cls, v): + """将时间戳转换为datetime""" + if isinstance(v, (int, float)): + return timestamp_to_datetime(v) + return v + + @field_validator('scopes') + @classmethod def validate_scopes(cls, v): """验证权限范围格式""" - valid_scopes = { - 'app:all', - 'rag:search', 'rag:upload', 'rag:delete', - 'memory:read', 'memory:write', 'memory:delete', 'memory:search' - } + if v is None: + return v + valid_scopes = ["app", "rag", "memory"] for scope in v: if scope not in valid_scopes: raise ValueError(f"无效范围: {scope}") @@ -67,18 +96,31 @@ class ApiKeyResponse(BaseModel): id: uuid.UUID name: str description: Optional[str] - api_key: str = Field(..., description="API Key 明文(仅创建时返回)") - key_prefix: str + api_key: str type: str scopes: List[str] resource_id: Optional[uuid.UUID] - resource_type: Optional[str] rate_limit: int daily_request_limit: int quota_limit: Optional[int] + is_active: bool expires_at: Optional[datetime.datetime] created_at: datetime.datetime + @computed_field + @property + def is_expired(self) -> bool: + """检查API Key是否已过期""" + if not self.expires_at: + return False + return datetime.datetime.now() > self.expires_at + + @field_serializer('expires_at', 'created_at') + @classmethod + def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]: + """将datetime转换为时间戳""" + return datetime_to_timestamp(v) + class ApiKey(BaseModel): """API Key 信息(不包含明文 Key)""" @@ -87,11 +129,10 @@ class ApiKey(BaseModel): id: uuid.UUID name: str description: Optional[str] - key_prefix: str + api_key: str type: str scopes: List[str] resource_id: Optional[uuid.UUID] - resource_type: Optional[str] rate_limit: int daily_request_limit: int quota_limit: Optional[int] @@ -105,6 +146,20 @@ class ApiKey(BaseModel): created_at: datetime.datetime updated_at: datetime.datetime + @computed_field + @property + def is_expired(self) -> bool: + """检查API Key是否已过期""" + if not self.expires_at: + return False + return datetime.datetime.now() > self.expires_at + + @field_serializer('expires_at', 'last_used_at', 'created_at', 'updated_at') + @classmethod + def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]: + """将datetime转换为时间戳""" + return datetime_to_timestamp(v) + class ApiKeyStats(BaseModel): """API Key 使用统计""" @@ -115,6 +170,12 @@ class ApiKeyStats(BaseModel): last_used_at: Optional[datetime.datetime] = Field(None, description="最后使用时间") avg_response_time: Optional[float] = Field(None, description="平均响应时间(毫秒)") + @field_serializer('last_used_at') + @classmethod + def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]: + """将datetime转换为时间戳""" + return datetime_to_timestamp(v) + class ApiKeyQuery(BaseModel): """API Key 查询参数""" @@ -132,7 +193,6 @@ class ApiKeyAuth(BaseModel): type: str scopes: List[str] resource_id: Optional[uuid.UUID] - resource_type: Optional[str] class ApiKeyLog(BaseModel): @@ -157,3 +217,9 @@ class ApiKeyLog(BaseModel): # 时间信息 created_at: datetime.datetime + + @field_serializer('created_at') + @classmethod + def serialize_datetime(cls, v: datetime.datetime) -> int: + """将datetime转换为时间戳""" + return datetime_to_timestamp(v) diff --git a/api/app/schemas/workflow_schema.py b/api/app/schemas/workflow_schema.py new file mode 100644 index 00000000..eb337298 --- /dev/null +++ b/api/app/schemas/workflow_schema.py @@ -0,0 +1,215 @@ +""" +工作流相关的 Pydantic Schema +""" + +import datetime +import uuid +from typing import Any +from pydantic import BaseModel, Field, ConfigDict, field_serializer + + +# ==================== 节点和边定义 ==================== + +class NodeConfig(BaseModel): + """节点配置""" + model_config = ConfigDict(extra="allow") # 允许额外字段 + + +class NodeDefinition(BaseModel): + """节点定义""" + id: str = Field(..., description="节点唯一标识") + type: str = Field(..., description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code") + name: str | None = Field(None, description="节点名称") + description: str | None = Field(None, description="节点描述") + config: dict[str, Any] = Field(default_factory=dict, description="节点配置") + position: dict[str, float] | None = Field(None, description="节点位置 {x, y}") + error_handling: dict[str, Any] | None = Field(None, description="错误处理配置") + cache: dict[str, Any] | None = Field(None, description="缓存配置") + + +class EdgeDefinition(BaseModel): + """边定义""" + id: str | None = Field(None, description="边唯一标识(可选)") + source: str = Field(..., description="源节点 ID") + target: str = Field(..., description="目标节点 ID") + type: str | None = Field(None, description="边类型: normal, error") + condition: str | None = Field(None, description="条件表达式(条件边)") + label: str | None = Field(None, description="边标签") + + +class VariableDefinition(BaseModel): + """变量定义""" + name: str = Field(..., description="变量名称") + type: str = Field(default="string", description="变量类型: string, number, boolean, object, array") + required: bool = Field(default=False, description="是否必填") + default: Any = Field(None, description="默认值") + description: str | None = Field(None, description="变量描述") + + +class ExecutionConfig(BaseModel): + """执行配置""" + max_iterations: int = Field(default=100, ge=1, le=1000, description="最大迭代次数") + timeout: int = Field(default=600, ge=10, le=3600, description="全局超时时间(秒)") + enable_cache: bool = Field(default=True, description="是否启用节点缓存") + parallel_limit: int = Field(default=5, ge=1, le=20, description="并行执行限制") + + +class TriggerConfig(BaseModel): + """触发器配置""" + type: str = Field(..., description="触发器类型: schedule, webhook, event") + config: dict[str, Any] = Field(default_factory=dict, description="触发器配置") + + +# ==================== 工作流配置 ==================== + +class WorkflowConfigCreate(BaseModel): + """创建工作流配置""" + nodes: list[NodeDefinition] = Field(default_factory=list, description="节点列表") + edges: list[EdgeDefinition] = Field(default_factory=list, description="边列表") + variables: list[VariableDefinition] = Field(default_factory=list, description="变量列表") + execution_config: ExecutionConfig = Field(default_factory=ExecutionConfig, description="执行配置") + triggers: list[TriggerConfig] = Field(default_factory=list, description="触发器列表") + + +class WorkflowConfigUpdate(BaseModel): + """更新工作流配置""" + nodes: list[NodeDefinition] | None = None + edges: list[EdgeDefinition] | None = None + variables: list[VariableDefinition] | None = None + execution_config: ExecutionConfig | None = None + triggers: list[TriggerConfig] | None = None + + +class WorkflowConfig(BaseModel): + """工作流配置输出""" + model_config = ConfigDict(from_attributes=True) + + id: uuid.UUID + app_id: uuid.UUID + nodes: list[dict[str, Any]] + edges: list[dict[str, Any]] + variables: list[dict[str, Any]] + execution_config: dict[str, Any] + triggers: list[dict[str, Any]] + is_active: bool + created_at: datetime.datetime + updated_at: datetime.datetime + + @field_serializer("created_at", when_used="json") + def _serialize_created_at(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + @field_serializer("updated_at", when_used="json") + def _serialize_updated_at(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + +# ==================== 工作流执行 ==================== + +class WorkflowExecutionRequest(BaseModel): + """工作流执行请求""" + message: str | None = Field(None, description="用户消息(可选)") + variables: dict[str, Any] = Field(default_factory=dict, description="输入变量") + conversation_id: str | None = Field(None, description="会话 ID(用于关联对话)") + stream: bool = Field(default=False, description="是否流式返回") + + +class WorkflowExecutionResponse(BaseModel): + """工作流执行响应(非流式)""" + execution_id: str = Field(..., description="执行 ID") + status: str = Field(..., description="执行状态") + output: str | None = Field(None, description="最终输出(字符串,便于快速访问)") + output_data: dict[str, Any] | None = Field(None, description="所有节点的详细输出数据") + error_message: str | None = Field(None, description="错误信息") + elapsed_time: float | None = Field(None, description="耗时(秒)") + token_usage: dict[str, Any] | None = Field(None, description="Token 使用情况 {prompt_tokens, completion_tokens, total_tokens}") + + +class WorkflowExecutionStreamChunk(BaseModel): + """工作流执行流式响应块""" + type: str = Field(..., description="事件类型: node_start, token, node_complete, error_redirect, workflow_complete") + execution_id: str = Field(..., description="执行 ID") + data: dict[str, Any] = Field(default_factory=dict, description="事件数据") + + +class WorkflowExecution(BaseModel): + """工作流执行记录输出""" + model_config = ConfigDict(from_attributes=True) + + id: uuid.UUID + workflow_config_id: uuid.UUID + app_id: uuid.UUID + conversation_id: uuid.UUID | None + execution_id: str + trigger_type: str + triggered_by: uuid.UUID | None + input_data: dict[str, Any] | None + output_data: dict[str, Any] | None + context: dict[str, Any] + status: str + error_message: str | None + error_node_id: str | None + started_at: datetime.datetime + completed_at: datetime.datetime | None + elapsed_time: float | None + token_usage: dict[str, Any] | None + meta_data: dict[str, Any] + created_at: datetime.datetime + + @field_serializer("started_at", when_used="json") + def _serialize_started_at(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + @field_serializer("completed_at", when_used="json") + def _serialize_completed_at(self, dt: datetime.datetime | None): + return int(dt.timestamp() * 1000) if dt else None + + @field_serializer("created_at", when_used="json") + def _serialize_created_at(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + +class WorkflowNodeExecution(BaseModel): + """工作流节点执行记录输出""" + model_config = ConfigDict(from_attributes=True) + + id: uuid.UUID + execution_id: uuid.UUID + node_id: str + node_type: str + node_name: str | None + execution_order: int + retry_count: int + input_data: dict[str, Any] | None + output_data: dict[str, Any] | None + status: str + error_message: str | None + started_at: datetime.datetime + completed_at: datetime.datetime | None + elapsed_time: float | None + token_usage: dict[str, Any] | None + cache_hit: bool + cache_key: str | None + meta_data: dict[str, Any] + created_at: datetime.datetime + + @field_serializer("started_at", when_used="json") + def _serialize_started_at(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + @field_serializer("completed_at", when_used="json") + def _serialize_completed_at(self, dt: datetime.datetime | None): + return int(dt.timestamp() * 1000) if dt else None + + @field_serializer("created_at", when_used="json") + def _serialize_created_at(self, dt: datetime.datetime): + return int(dt.timestamp() * 1000) if dt else None + + +# ==================== 验证响应 ==================== + +class WorkflowValidationResponse(BaseModel): + """工作流验证响应""" + is_valid: bool = Field(..., description="是否有效") + errors: list[str] = Field(default_factory=list, description="错误列表") + warnings: list[str] = Field(default_factory=list, description="警告列表") diff --git a/api/app/services/api_key_service.py b/api/app/services/api_key_service.py index 53615e7e..09ba5ca1 100644 --- a/api/app/services/api_key_service.py +++ b/api/app/services/api_key_service.py @@ -13,7 +13,7 @@ from app.models.api_key_model import ApiKey from app.repositories.api_key_repository import ApiKeyRepository, ApiKeyLogRepository from app.schemas import api_key_schema from app.schemas.response_schema import PageData, PageMeta -from app.core.api_key_utils import generate_api_key, hash_api_key, validate_resource_binding +from app.core.api_key_utils import generate_api_key from app.core.exceptions import ( BusinessException, ) @@ -33,21 +33,13 @@ class ApiKeyService: workspace_id: uuid.UUID, user_id: uuid.UUID, data: api_key_schema.ApiKeyCreate - ) -> Tuple[ApiKey, str]: + ) -> ApiKey: """ 创建 API Key Returns: - Tuple[ApiKey, str]: (API Key 对象, API Key 明文) + ApiKey: API Key 对象 """ try: - # 验证资源绑定 - if data.resource_type or data.resource_id: - is_valid, error_msg = validate_resource_binding( - data.resource_type, str(data.resource_id) if data.resource_id else None - ) - if not is_valid: - raise BusinessException(error_msg, BizCode.API_KEY_INVALID_RESOURCE) - existing = db.scalar( select(ApiKey).where( ApiKey.workspace_id == workspace_id, @@ -59,22 +51,20 @@ class ApiKeyService: raise BusinessException(f"API Key 名称 '{data.name}' 已存在", BizCode.API_KEY_DUPLICATE_NAME) # 生成 API Key - api_key, key_hash, key_prefix = generate_api_key(data.type) + api_key = generate_api_key(data.type) # 创建数据 api_key_data = { "id": uuid.uuid4(), "name": data.name, "description": data.description, - "key_prefix": key_prefix, - "key_hash": key_hash, + "api_key": api_key, "type": data.type, "scopes": data.scopes, "workspace_id": workspace_id, "resource_id": data.resource_id, - "resource_type": data.resource_type, - "rate_limit": data.rate_limit or 10, - "daily_request_limit": data.daily_request_limit or 10000, + "rate_limit": data.rate_limit, + "daily_request_limit": data.daily_request_limit, "quota_limit": data.quota_limit, "expires_at": data.expires_at, "created_by": user_id, @@ -90,7 +80,7 @@ class ApiKeyService: "type": data.type }) - return api_key_obj, api_key + return api_key_obj except Exception as e: db.rollback() @@ -147,6 +137,9 @@ class ApiKeyService: """更新 API Key配置""" api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id) + if not api_key: + raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND) + # 检查名称重复 if data.name and data.name != api_key.name: existing = db.scalar( @@ -177,6 +170,9 @@ class ApiKeyService: """删除 API Key""" api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id) + if not api_key: + raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND) + ApiKeyRepository.delete(db, api_key_id) db.commit() @@ -188,27 +184,29 @@ class ApiKeyService: db: Session, api_key_id: uuid.UUID, workspace_id: uuid.UUID - ) -> Tuple[ApiKey, str]: + ) -> ApiKey: """重新生成 API Key""" api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id) + if not api_key: + raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND) + # 检查 API Key 是否激活 if not api_key.is_active: raise BusinessException("无法重新生成已停用的 API Key", BizCode.API_KEY_INACTIVE) # 生成新的 API Key - new_api_key, key_hash, key_prefix = generate_api_key(api_key_schema.ApiKeyType(api_key.type)) + new_api_key = generate_api_key(api_key.type) # 更新 ApiKeyRepository.update(db, api_key_id, { - "key_hash": key_hash, - "key_prefix": key_prefix + "api_key": new_api_key }) db.commit() db.refresh(api_key) logger.info("API Key 重新生成成功", extra={"api_key_id": str(api_key_id)}) - return api_key, new_api_key + return api_key @staticmethod def get_stats( @@ -219,6 +217,9 @@ class ApiKeyService: """获取使用统计""" api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id) + if not api_key: + raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND) + stats_data = ApiKeyRepository.get_stats(db, api_key_id) return api_key_schema.ApiKeyStats(**stats_data) @@ -235,6 +236,9 @@ class ApiKeyService: # 验证 API Key 权限 api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id) + if not api_key: + raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND) + items, total = ApiKeyLogRepository.list_by_api_key( db, api_key_id, filters, page, pagesize ) @@ -330,7 +334,6 @@ class RateLimiterService: "X-RateLimit-Reset": str(qps_info["reset"]) } - # Check daily requests daily_ok, daily_info = await self.check_daily_requests( api_key.id, api_key.daily_request_limit @@ -342,7 +345,6 @@ class RateLimiterService: "X-RateLimit-Reset": str(daily_info["reset"]) } - # All checks passed headers = { "X-RateLimit-Limit-QPS": str(qps_info["limit"]), "X-RateLimit-Remaining-QPS": str(qps_info["remaining"]), @@ -363,13 +365,12 @@ class ApiKeyAuthService: 验证API Key 有效性 检查: - 1. Key hash 是否存在 + 1. API Key 是否存在 2. is_active 是否为true 3. expires_at 是否未过期 4. quota 是否未超限 """ - key_hash = hash_api_key(api_key) - api_key_obj = ApiKeyRepository.get_by_hash(db, key_hash) + api_key_obj = ApiKeyRepository.get_by_api_key(db, api_key) if not api_key_obj: return None @@ -393,14 +394,7 @@ class ApiKeyAuthService: @staticmethod def check_resource( api_key: ApiKey, - resource_type: str, resource_id: uuid.UUID ) -> bool: """检查资源绑定""" - if not api_key.resource_id: - return True - - return ( - api_key.resource_type == resource_type and - api_key.resource_id == resource_id - ) + return api_key.resource_id == resource_id diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 06007bf3..07625fee 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -9,22 +9,24 @@ """ import datetime import uuid -from typing import Optional, List, Dict, Any, Tuple +from typing import Optional, List, Dict, Any, Tuple, Type -from sqlalchemy.orm import Session from sqlalchemy import select, func, or_, and_ +from sqlalchemy.orm import Session -from app.models import App, AgentConfig, AppRelease, MultiAgentConfig -from app.schemas import app_schema +from app.core.error_codes import BizCode from app.core.exceptions import ( ResourceNotFoundException, - ValidationException, BusinessException, ) -from app.core.error_codes import BizCode from app.core.logging_config import get_business_logger -from app.services.agent_config_converter import AgentConfigConverter +from app.models import App, AgentConfig, AppRelease, MultiAgentConfig, WorkflowConfig from app.models.app_model import AppStatus, AppType +from app.repositories.app_repository import get_apps_by_id +from app.repositories.workflow_repository import WorkflowConfigRepository +from app.schemas import app_schema +from app.schemas.workflow_schema import WorkflowConfigUpdate +from app.services.agent_config_converter import AgentConfigConverter # 获取业务日志器 logger = get_business_logger() @@ -32,27 +34,27 @@ logger = get_business_logger() class AppService: """应用服务类 - + 负责应用相关的所有业务逻辑处理,遵循单一职责原则。 """ - + def __init__(self, db: Session): """初始化应用服务 - + Args: db: 数据库会话 """ self.db = db - + # ==================== 私有辅助方法 ==================== - + def _validate_workspace_access(self, app: App, workspace_id: Optional[uuid.UUID]) -> None: """验证工作空间访问权限(严格模式,用于修改操作) - + Args: app: 应用对象 workspace_id: 工作空间ID - + Raises: BusinessException: 当应用不在指定工作空间时 """ @@ -62,42 +64,42 @@ class AppService: extra={"app_id": str(app.id), "workspace_id": str(workspace_id)} ) raise BusinessException("应用不在指定工作空间中", BizCode.WORKSPACE_NO_ACCESS) - + def _check_app_accessible(self, app: App, workspace_id: Optional[uuid.UUID]) -> bool: """检查应用是否可访问(包括共享应用) - + Args: app: 应用对象 workspace_id: 工作空间ID - + Returns: bool: 是否可访问 """ from app.models import AppShare - + if workspace_id is None: return True - + # 1. 检查是否是本工作空间的应用 if app.workspace_id == workspace_id: return True - + # 2. 检查是否是共享给本工作空间的应用 stmt = select(AppShare).where( AppShare.source_app_id == app.id, AppShare.target_workspace_id == workspace_id ) share = self.db.scalars(stmt).first() - + return share is not None - + def _validate_app_accessible(self, app: App, workspace_id: Optional[uuid.UUID]) -> None: """验证应用是否可访问(包括共享应用,用于只读操作) - + Args: app: 应用对象 workspace_id: 工作空间ID - + Raises: BusinessException: 当应用不可访问时 """ @@ -107,28 +109,44 @@ class AppService: extra={"app_id": str(app.id), "workspace_id": str(workspace_id)} ) raise BusinessException("应用不可访问", BizCode.WORKSPACE_NO_ACCESS) - + def _get_app_or_404(self, app_id: uuid.UUID) -> App: """获取应用或抛出404异常 - + Args: app_id: 应用ID - + Returns: App: 应用对象 - + Raises: ResourceNotFoundException: 当应用不存在时 """ - app = self.db.get(App, app_id) + app = get_apps_by_id(self.db,app_id) if not app: logger.warning("应用不存在", extra={"app_id": str(app_id)}) raise ResourceNotFoundException("应用", str(app_id)) return app - + + def _check_workflow_config(self, app_id: uuid.UUID): + from app.models import WorkflowConfig, ModelConfig + from sqlalchemy import select + from app.core.exceptions import BusinessException + # 2. 获取 Agent 配置 + stmt = select(WorkflowConfig).where(AgentConfig.app_id == app_id) + agent_cfg = self.db.scalars(stmt).first() + if not agent_cfg: + raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) + + # 3. 获取模型配置 + model_config = None + if agent_cfg.default_model_config_id: + model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id) + + if not model_config: + raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) + def _check_agent_config(self, app_id: uuid.UUID): - from app.models import AgentConfig, ModelConfig - from app.services.app_service import AppService from app.models import AgentConfig, ModelConfig from sqlalchemy import select from app.core.exceptions import BusinessException @@ -137,63 +155,63 @@ class AppService: agent_cfg = self.db.scalars(stmt).first() if not agent_cfg: raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - + # 3. 获取模型配置 model_config = None if agent_cfg.default_model_config_id: model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id) - + if not model_config: raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) def _check_multi_agent_config(self, app_id: uuid.UUID): """检查多智能体配置的完整性 - + 验证内容: 1. 多智能体配置是否存在 2. 主 Agent 配置是否存在 3. 子 Agent 配置是否存在 4. 所有 Agent 的模型配置是否存在 - + Args: app_id: 应用 ID - + Raises: BusinessException: 配置不完整或不存在时抛出 """ - from app.models import MultiAgentConfig, AgentConfig, ModelConfig + from app.models import ModelConfig from app.services.multi_agent_service import MultiAgentService - + # 1. 检查多智能体配置是否存在 service = MultiAgentService(self.db) multi_agent_config = service.get_config(app_id) - + if not multi_agent_config: raise BusinessException( "多智能体配置不存在,无法运行", BizCode.AGENT_CONFIG_MISSING ) - + if not multi_agent_config.is_active: raise BusinessException( "多智能体配置未激活,无法运行", BizCode.AGENT_CONFIG_MISSING ) - + # 2. 检查主 Agent 配置 if not multi_agent_config.master_agent_id: raise BusinessException( "未配置主 Agent,无法运行", BizCode.AGENT_CONFIG_MISSING ) - + master_agent_release = self.db.get(AppRelease, multi_agent_config.master_agent_id) if not master_agent_release: raise BusinessException( f"主 Agent 配置不存在: {multi_agent_config.master_agent_id}", BizCode.AGENT_CONFIG_MISSING ) - + # 检查主 Agent 的模型配置 if master_agent_release.default_model_config_id: master_model = self.db.get(ModelConfig, master_agent_release.default_model_config_id) @@ -207,14 +225,14 @@ class AppService: "主 Agent 未配置模型,无法运行", BizCode.MODEL_NOT_FOUND ) - + # 3. 检查子 Agent 配置 if not multi_agent_config.sub_agents or len(multi_agent_config.sub_agents) == 0: raise BusinessException( "未配置子 Agent,无法运行", BizCode.AGENT_CONFIG_MISSING ) - + # 4. 验证每个子 Agent 及其模型配置 for idx, sub_agent_data in enumerate(multi_agent_config.sub_agents): agent_id = sub_agent_data.get('agent_id') @@ -223,7 +241,7 @@ class AppService: f"子 Agent #{idx + 1} 缺少 agent_id", BizCode.AGENT_CONFIG_MISSING ) - + # 转换为 UUID try: from uuid import UUID @@ -233,7 +251,7 @@ class AppService: f"子 Agent #{idx + 1} 的 agent_id 格式无效: {agent_id}", BizCode.INVALID_PARAMETER ) - + # 检查子 Agent 是否存在 sub_agent_release = self.db.get(AppRelease, agent_uuid) if not sub_agent_release: @@ -241,7 +259,7 @@ class AppService: f"子 Agent 配置不存在: {agent_id} ({sub_agent_data.get('name', '未命名')})", BizCode.AGENT_CONFIG_MISSING ) - + # 检查子 Agent 的模型配置 if sub_agent_release.default_model_config_id: sub_model = self.db.get(ModelConfig, sub_agent_release.default_model_config_id) @@ -255,7 +273,7 @@ class AppService: f"子 Agent '{sub_agent_data.get('name', '未命名')}' 未配置模型,无法运行", BizCode.MODEL_NOT_FOUND ) - + logger.info( "多智能体配置检查通过", extra={ @@ -266,20 +284,20 @@ class AppService: ) def _create_agent_config( - self, - app_id: uuid.UUID, - config_data: app_schema.AgentConfigCreate, + self, + app_id: uuid.UUID, + config_data: app_schema.AgentConfigCreate, now: datetime.datetime ) -> None: """创建 Agent 配置(内部方法) - + Args: app_id: 应用ID config_data: Agent 配置数据 now: 当前时间 """ storage_data = AgentConfigConverter.to_storage_format(config_data) - + agent_cfg = AgentConfig( id=uuid.uuid4(), app_id=app_id, @@ -296,7 +314,7 @@ class AppService: ) self.db.add(agent_cfg) logger.debug("Agent 配置已创建", extra={"app_id": str(app_id)}) - + def _create_multi_agent_config( self, app_id: uuid.UUID, @@ -304,7 +322,7 @@ class AppService: now: datetime.datetime ) -> None: """创建多 Agent 配置(内部方法) - + Args: app_id: 应用ID config_data: 多 Agent 配置数据(Dict) @@ -317,18 +335,18 @@ class AppService: RoutingRule, ExecutionConfig ) - + # 转换 sub_agents sub_agents = [SubAgentConfig(**sa) for sa in config_data.get('sub_agents', [])] - + # 转换 routing_rules(如果有) routing_rules = None if config_data.get('routing_rules'): routing_rules = [RoutingRule(**rr) for rr in config_data['routing_rules']] - + # 转换 execution_config execution_config = ExecutionConfig(**config_data.get('execution_config', {})) - + # 创建 MultiAgentConfigCreate 对象 config = MultiAgentConfigCreate( master_agent_id=config_data['master_agent_id'], @@ -338,18 +356,18 @@ class AppService: execution_config=execution_config, aggregation_strategy=config_data.get('aggregation_strategy', 'merge') ) - + # 验证主 Agent 存在 master_agent = self.db.get(AgentConfig, config.master_agent_id) if not master_agent: raise ResourceNotFoundException("主 Agent", str(config.master_agent_id)) - + # 验证子 Agent 存在 for sub_agent in config.sub_agents: agent = self.db.get(AgentConfig, sub_agent.agent_id) if not agent: raise ResourceNotFoundException("子 Agent", str(sub_agent.agent_id)) - + # 创建多 Agent 配置 # 将 UUID 转换为字符串以便 JSON 序列化 sub_agents_data = [] @@ -357,7 +375,7 @@ class AppService: sa_dict = sub_agent.model_dump() sa_dict['agent_id'] = str(sa_dict['agent_id']) # UUID -> str sub_agents_data.append(sa_dict) - + routing_rules_data = None if config.routing_rules: routing_rules_data = [] @@ -365,7 +383,7 @@ class AppService: rule_dict = rule.model_dump() rule_dict['target_agent_id'] = str(rule_dict['target_agent_id']) # UUID -> str routing_rules_data.append(rule_dict) - + multi_agent_cfg = MultiAgentConfig( id=uuid.uuid4(), app_id=app_id, @@ -381,31 +399,31 @@ class AppService: ) self.db.add(multi_agent_cfg) logger.debug("多 Agent 配置已创建", extra={"app_id": str(app_id), "mode": config.orchestration_mode}) - + def _get_next_version(self, app_id: uuid.UUID) -> int: """获取下一个版本号 - + Args: app_id: 应用ID - + Returns: int: 下一个版本号 """ stmt = select(func.max(AppRelease.version)).where(AppRelease.app_id == app_id) max_ver = self.db.execute(stmt).scalar() return 1 if max_ver is None else int(max_ver) + 1 - + def _convert_to_schema( self, app: App, current_workspace_id: uuid.UUID ) -> app_schema.App: """将 App 模型转换为 Schema,并设置 is_shared 字段 - + Args: app: App 模型实例 current_workspace_id: 当前工作空间ID - + Returns: app_schema.App: 应用 Schema """ @@ -428,23 +446,23 @@ class AppService: "updated_at": app.updated_at } return app_schema.App(**app_dict) - + # ==================== 应用管理 ==================== - + def get_app( self, app_id: uuid.UUID, workspace_id: Optional[uuid.UUID] = None ) -> App: """获取应用详情 - + Args: app_id: 应用ID workspace_id: 工作空间ID(用于权限验证,支持共享应用) - + Returns: App: 应用对象 - + Raises: ResourceNotFoundException: 当应用不存在时 BusinessException: 当应用不可访问时 @@ -452,24 +470,24 @@ class AppService: app = self._get_app_or_404(app_id) self._validate_app_accessible(app, workspace_id) return app - + def create_app( - self, - *, - user_id: uuid.UUID, - workspace_id: uuid.UUID, + self, + *, + user_id: uuid.UUID, + workspace_id: uuid.UUID, data: app_schema.AppCreate ) -> App: """创建应用 - + Args: user_id: 创建者用户ID workspace_id: 工作空间ID data: 应用创建数据 - + Returns: App: 创建的应用对象 - + Raises: BusinessException: 当创建失败时 """ @@ -477,7 +495,7 @@ class AppService: "创建应用", extra={"app_name": data.name, "type": data.type, "workspace_id": str(workspace_id)} ) - + try: now = datetime.datetime.now() @@ -503,45 +521,45 @@ class AppService: # 如果是 agent 类型且提供了配置,创建 AgentConfig if app.type == "agent" and data.agent_config: self._create_agent_config(app.id, data.agent_config, now) - + # 如果是 multi_agent 类型且提供了配置,创建 MultiAgentConfig if app.type == "multi_agent" and data.multi_agent_config: self._create_multi_agent_config(app.id, data.multi_agent_config, now) self.db.commit() self.db.refresh(app) - + logger.info("应用创建成功", extra={"app_id": str(app.id), "app_name": app.name}) return app - + except Exception as e: self.db.rollback() logger.error("应用创建失败", extra={"app_name": data.name, "error": str(e)}) raise BusinessException(f"应用创建失败: {str(e)}", BizCode.INTERNAL_ERROR, cause=e) - + def update_app( - self, - *, - app_id: uuid.UUID, - data: app_schema.AppUpdate, + self, + *, + app_id: uuid.UUID, + data: app_schema.AppUpdate, workspace_id: Optional[uuid.UUID] = None ) -> App: """更新应用基本信息 - + Args: app_id: 应用ID data: 更新数据 workspace_id: 工作空间ID(用于权限验证) - + Returns: App: 更新后的应用对象 - + Raises: ResourceNotFoundException: 当应用不存在时 BusinessException: 当应用不在指定工作空间时 """ logger.info("更新应用", extra={"app_id": str(app_id)}) - + app = self._get_app_or_404(app_id) self._validate_workspace_access(app, workspace_id) @@ -551,7 +569,7 @@ class AppService: if val is not None: setattr(app, field, val) changed = True - + if changed: app.updated_at = datetime.datetime.now() self.db.commit() @@ -559,9 +577,9 @@ class AppService: logger.info("应用更新成功", extra={"app_id": str(app_id)}) else: logger.debug("应用无变更", extra={"app_id": str(app_id)}) - + return app - + def delete_app( self, *, @@ -569,24 +587,24 @@ class AppService: workspace_id: Optional[uuid.UUID] = None ) -> None: """删除应用 - + Args: app_id: 应用ID workspace_id: 工作空间ID(用于权限验证) - + Raises: ResourceNotFoundException: 当应用不存在时 BusinessException: 当应用不在指定工作空间时 """ logger.info("删除应用", extra={"app_id": str(app_id)}) - + app = self._get_app_or_404(app_id) self._validate_workspace_access(app, workspace_id) - + # 逻辑删除应用 app.is_active = False self.db.commit() - + logger.info( "应用删除成功", extra={ @@ -595,7 +613,7 @@ class AppService: "app_type": app.type } ) - + def copy_app( self, *, @@ -605,36 +623,36 @@ class AppService: new_name: Optional[str] = None ) -> App: """复制应用(包括基础信息和配置) - + Args: app_id: 源应用ID user_id: 创建者用户ID workspace_id: 目标工作空间ID(如果为None,则复制到源应用所在工作空间) new_name: 新应用名称(如果为None,则使用"源应用名称 - 副本") - + Returns: App: 复制后的新应用对象 - + Raises: ResourceNotFoundException: 当源应用不存在时 BusinessException: 当复制失败时 """ logger.info("复制应用", extra={"source_app_id": str(app_id)}) - + try: # 获取源应用 source_app = self._get_app_or_404(app_id) self._validate_app_accessible(source_app, workspace_id) - + # 确定目标工作空间 target_workspace_id = workspace_id or source_app.workspace_id - + # 确定新应用名称 if not new_name: new_name = f"{source_app.name} - 副本" - + now = datetime.datetime.now() - + # 创建新应用(复制基础信息) new_app = App( id=uuid.uuid4(), @@ -654,13 +672,13 @@ class AppService: ) self.db.add(new_app) self.db.flush() - + # 如果是 agent 类型,复制 AgentConfig if source_app.type == "agent": source_config = self.db.query(AgentConfig).filter( AgentConfig.app_id == source_app.id ).first() - + if source_config: new_config = AgentConfig( id=uuid.uuid4(), @@ -677,10 +695,10 @@ class AppService: updated_at=now, ) self.db.add(new_config) - + self.db.commit() self.db.refresh(new_app) - + logger.info( "应用复制成功", extra={ @@ -689,9 +707,9 @@ class AppService: "new_app_name": new_app.name } ) - + return new_app - + except Exception as e: self.db.rollback() logger.error( @@ -699,7 +717,7 @@ class AppService: extra={"source_app_id": str(app_id), "error": str(e)} ) raise BusinessException(f"应用复制失败: {str(e)}", BizCode.INTERNAL_ERROR, cause=e) - + def list_apps( self, *, @@ -713,11 +731,11 @@ class AppService: pagesize: int = 10, ) -> Tuple[List[App], int]: """列出工作空间中的应用(分页) - + 包括: 1. 本工作空间创建的应用 2. 其他工作空间分享给本工作空间的应用(如果 include_shared=True) - + Args: workspace_id: 工作空间ID type: 应用类型过滤 @@ -727,12 +745,12 @@ class AppService: include_shared: 是否包含分享的应用 page: 页码(从1开始) pagesize: 每页数量 - + Returns: Tuple[List[App], int]: (应用列表, 总数) """ from app.models import AppShare - + logger.debug( "查询应用列表", extra={ @@ -742,7 +760,7 @@ class AppService: "pagesize": pagesize } ) - + # 构建查询条件 filters = [] filters.append(App.is_active == True) @@ -754,18 +772,18 @@ class AppService: filters.append(App.status == status) if search: filters.append(func.lower(App.name).like(f"%{search.lower()}%")) - + # 基础查询:本工作空间的应用 if include_shared: # 查询本工作空间的应用 + 分享给本工作空间的应用 # 使用 OR 条件:workspace_id = current OR app_id IN (shared apps) - + # 获取分享给本工作空间的应用ID列表 shared_app_ids_stmt = ( select(AppShare.source_app_id) .where(AppShare.target_workspace_id == workspace_id) ) - + # 构建主查询:本工作空间的应用 OR 分享的应用 stmt = select(App).where( or_( @@ -776,7 +794,7 @@ class AppService: else: # 只查询本工作空间的应用 stmt = select(App).where(App.workspace_id == workspace_id) - + # 应用过滤条件 if filters: stmt = stmt.where(and_(*filters)) @@ -790,43 +808,43 @@ class AppService: stmt = stmt.order_by(App.created_at.desc()).offset(offset).limit(pagesize) items = list(self.db.scalars(stmt).all()) - + logger.debug( "应用列表查询完成", extra={"total": total, "returned": len(items), "include_shared": include_shared} ) return items, int(total) - + # ==================== Agent 配置管理 ==================== - + def update_agent_config( - self, - *, - app_id: uuid.UUID, - data: app_schema.AgentConfigUpdate, + self, + *, + app_id: uuid.UUID, + data: app_schema.AgentConfigUpdate, workspace_id: Optional[uuid.UUID] = None ) -> AgentConfig: """更新 Agent 配置 - + Args: app_id: 应用ID data: 配置更新数据 workspace_id: 工作空间ID(用于权限验证) - + Returns: AgentConfig: 更新后的配置对象 - + Raises: ResourceNotFoundException: 当应用不存在时 BusinessException: 当应用类型不支持或不在指定工作空间时 """ logger.info("更新 Agent 配置", extra={"app_id": str(app_id)}) - + app = self._get_app_or_404(app_id) - + if app.type != "agent": raise BusinessException("只有 Agent 类型应用支持 Agent 配置", BizCode.APP_TYPE_NOT_SUPPORTED) - + self._validate_workspace_access(app, workspace_id) stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active==True).order_by(AgentConfig.updated_at.desc()) @@ -846,7 +864,7 @@ class AppService: # 转换为存储格式 storage_data = AgentConfigConverter.to_storage_format(data) - + # 更新字段 # if data.system_prompt is not None: agent_cfg.system_prompt = data.system_prompt @@ -862,67 +880,67 @@ class AppService: agent_cfg.variables = storage_data.get("variables", []) # if data.tools is not None: agent_cfg.tools = storage_data.get("tools", {}) - + agent_cfg.updated_at = now self.db.commit() self.db.refresh(agent_cfg) - + logger.info("Agent 配置更新成功", extra={"app_id": str(app_id)}) return agent_cfg - + def get_agent_config( - self, - *, - app_id: uuid.UUID, + self, + *, + app_id: uuid.UUID, workspace_id: Optional[uuid.UUID] = None ) -> AgentConfig: """获取 Agent 配置 - + 如果配置不存在,返回默认配置模板(不保存到数据库) - + Args: app_id: 应用ID workspace_id: 工作空间ID(用于权限验证) - + Returns: AgentConfig: Agent 配置对象(存在的配置或默认模板) - + Raises: ResourceNotFoundException: 当应用不存在时 BusinessException: 当应用类型不支持或不可访问时 """ logger.debug("获取 Agent 配置", extra={"app_id": str(app_id)}) - + app = self._get_app_or_404(app_id) - + if app.type != "agent": raise BusinessException("只有 Agent 类型应用支持 Agent 配置", BizCode.APP_TYPE_NOT_SUPPORTED) - + # 只读操作,允许访问共享应用 self._validate_app_accessible(app, workspace_id) stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active == True).order_by(AgentConfig.updated_at.desc()) config = self.db.scalars(stmt).first() - + if config: return config - + # 返回默认配置模板(不保存到数据库) logger.debug("配置不存在,返回默认模板", extra={"app_id": str(app_id)}) return self._create_default_agent_config(app_id) - + def _create_default_agent_config(self, app_id: uuid.UUID) -> AgentConfig: """创建默认的 Agent 配置模板(不保存到数据库) - + Args: app_id: 应用ID - + Returns: AgentConfig: 默认配置对象 """ now = datetime.datetime.now() - + # 创建一个临时的配置对象,不添加到数据库 default_config = AgentConfig( id=uuid.uuid4(), # 临时ID @@ -953,37 +971,198 @@ class AppService: created_at=now, updated_at=now, ) - + return default_config - + + def get_workflow_config( + self, + *, + app_id: uuid.UUID, + workspace_id: Optional[uuid.UUID] = None + ) -> WorkflowConfig: + """获取 workflow 配置 + + 如果配置不存在,返回默认配置模板(不保存到数据库) + + Args: + app_id: 应用ID + workspace_id: 工作空间ID(用于权限验证) + + Returns: + WorkflowConfig: Workflow 配置对象(存在的配置或默认模板) + + Raises: + ResourceNotFoundException: 当应用不存在时 + BusinessException: 当应用类型不支持或不可访问时 + """ + logger.debug("获取 Workflow 配置", extra={"app_id": str(app_id)}) + + app = self._get_app_or_404(app_id) + + if app.type != AppType.WORKFLOW: + raise BusinessException("只有 Workflow 类型应用支持 Workflow 配置", BizCode.APP_TYPE_NOT_SUPPORTED) + + # 只读操作,允许访问共享应用 + self._validate_app_accessible(app, workspace_id) + repo = WorkflowConfigRepository(self.db) + config = repo.get_by_app_id(app_id) + if config: + return config + + # 返回默认配置模板(不保存到数据库) + logger.debug("配置不存在,返回默认模板", extra={"app_id": str(app_id)}) + return self._create_default_workflow_config(app_id) + + def update_workflow_config( + self, + *, + app_id: uuid.UUID, + data: WorkflowConfigUpdate, + workspace_id: Optional[uuid.UUID] = None + ) -> WorkflowConfig: + """更新 Workflow 配置(全量更新) + + Args: + app_id: 应用ID + data: 配置更新数据(全量数据) + workspace_id: 工作空间ID(用于权限验证) + + Returns: + WorkflowConfig: 更新后的配置对象 + + Raises: + ResourceNotFoundException: 当应用不存在时 + BusinessException: 当应用类型不支持或不在指定工作空间时 + """ + logger.info("更新 Workflow 配置", extra={"app_id": str(app_id)}) + + app = self._get_app_or_404(app_id) + + if app.type != AppType.WORKFLOW: + raise BusinessException("只有 Workflow 类型应用支持 Workflow 配置", BizCode.APP_TYPE_NOT_SUPPORTED) + + self._validate_workspace_access(app, workspace_id) + + # 获取现有配置 + repo = WorkflowConfigRepository(self.db) + workflow_cfg = repo.get_by_app_id(app_id) + now = datetime.datetime.now() + + if not workflow_cfg: + # 如果配置不存在,创建新配置 + workflow_cfg = WorkflowConfig( + id=uuid.uuid4(), + app_id=app_id, + nodes=[node.model_dump() for node in data.nodes] if data.nodes else [], + edges=[edge.model_dump() for edge in data.edges] if data.edges else [], + variables=[var.model_dump() for var in data.variables] if data.variables else [], + execution_config=data.execution_config.model_dump() if data.execution_config else {}, + triggers=[trigger.model_dump() for trigger in data.triggers] if data.triggers else [], + is_active=True, + created_at=now, + updated_at=now + ) + self.db.add(workflow_cfg) + logger.debug("创建新的 Workflow 配置", extra={"app_id": str(app_id)}) + else: + # 全量更新现有配置 + workflow_cfg.nodes = [node.model_dump() for node in data.nodes] if data.nodes else [] + workflow_cfg.edges = [edge.model_dump() for edge in data.edges] if data.edges else [] + workflow_cfg.variables = [var.model_dump() for var in data.variables] if data.variables else [] + workflow_cfg.execution_config = data.execution_config.model_dump() if data.execution_config else {} + workflow_cfg.triggers = [trigger.model_dump() for trigger in data.triggers] if data.triggers else [] + workflow_cfg.updated_at = now + + self.db.commit() + self.db.refresh(workflow_cfg) + + logger.info("Workflow 配置更新成功", extra={"app_id": str(app_id)}) + return workflow_cfg + + def _create_default_workflow_config(self, app_id: uuid.UUID) -> WorkflowConfig: + """创建默认的 workflow 配置模板(不保存到数据库) + + 使用 template_loader 加载 simple_qa 模板作为默认配置 + + Args: + app_id: 应用ID + + Returns: + WorkflowConfig: 默认配置对象 + """ + from app.core.workflow.template_loader import load_workflow_template + + now = datetime.datetime.now() + + # 使用 template_loader 加载 simple_qa 模板 + template_data = load_workflow_template('simple_qa') + + if not template_data: + # 如果模板加载失败,返回最小化配置 + logger.warning( + "无法加载默认工作流模板,使用最小化配置", + extra={"app_id": str(app_id)} + ) + template_data = { + 'nodes': [ + {'id': 'start', 'type': 'start', 'name': '开始'}, + {'id': 'end', 'type': 'end', 'name': '结束'} + ], + 'edges': [ + {'source': 'start', 'target': 'end'} + ], + 'variables': [], + 'execution_config': { + 'max_execution_time': 300, + 'max_iterations': 10 + }, + 'triggers': [] + } + + # 转换为 WorkflowConfig 格式 + default_config = WorkflowConfig( + id=uuid.uuid4(), + app_id=app_id, + nodes=template_data.get('nodes', []), + edges=template_data.get('edges', []), + variables=template_data.get('variables', []), + execution_config=template_data.get('execution_config', {}), + triggers=template_data.get('triggers', []), + is_active=True, + created_at=now, + updated_at=now + ) + + return default_config + # ==================== 应用发布管理 ==================== - + def publish( - self, - *, - app_id: uuid.UUID, - publisher_id: uuid.UUID, + self, + *, + app_id: uuid.UUID, + publisher_id: uuid.UUID, version_name: str, workspace_id: Optional[uuid.UUID] = None, release_notes: Optional[str] = None ) -> AppRelease: """发布应用(创建不可变快照) - + Args: app_id: 应用ID publisher_id: 发布者用户ID workspace_id: 工作空间ID(用于权限验证) release_notes: 版本说明 - + Returns: AppRelease: 发布版本对象 - + Raises: ResourceNotFoundException: 当应用不存在时 BusinessException: 当应用缺少配置或不在指定工作空间时 """ logger.info("发布应用", extra={"app_id": str(app_id), "publisher_id": str(publisher_id)}) - + app = self._get_app_or_404(app_id) # 检查应用归属 self._validate_workspace_access(app, workspace_id) @@ -991,13 +1170,13 @@ class AppService: # 构建快照配置 config: Dict[str, Any] = {} default_model_config_id = None - + if app.type == AppType.AGENT: stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active == True).order_by(AgentConfig.updated_at.desc()) agent_cfg = self.db.scalars(stmt).first() if not agent_cfg: raise BusinessException("Agent 应用缺少配置,无法发布", BizCode.AGENT_CONFIG_MISSING) - + config = { "system_prompt": agent_cfg.system_prompt, "model_parameters": agent_cfg.model_parameters, @@ -1021,14 +1200,14 @@ class AppService: multi_agent_cfg = self.db.scalars(stmt).first() if not multi_agent_cfg: raise BusinessException("多 Agent 应用缺少有效配置,无法发布", BizCode.AGENT_CONFIG_MISSING) - + # 2. 检查配置完整性 self._check_multi_agent_config(app_id) - + # 3. 获取主 Agent 的模型配置 ID master_agent = self.db.get(AgentConfig, multi_agent_cfg.master_agent_id) default_model_config_id = master_agent.default_model_config_id if master_agent else None - + # 4. 构建配置快照 config = { "master_agent_id": str(multi_agent_cfg.master_agent_id), @@ -1038,7 +1217,7 @@ class AppService: "execution_config": multi_agent_cfg.execution_config, "aggregation_strategy": multi_agent_cfg.aggregation_strategy, } - + logger.info( "多智能体应用发布配置准备完成", extra={ @@ -1048,10 +1227,10 @@ class AppService: "orchestration_mode": multi_agent_cfg.orchestration_mode } ) - + now = datetime.datetime.now() version = self._get_next_version(app_id) - + release = AppRelease( id=uuid.uuid4(), app_id=app_id, @@ -1082,128 +1261,128 @@ class AppService: self.db.commit() self.db.refresh(release) - + logger.info( "应用发布成功", extra={"app_id": str(app_id), "version": version, "release_id": str(release.id)} ) return release - + def get_current_release( - self, - *, - app_id: uuid.UUID, + self, + *, + app_id: uuid.UUID, workspace_id: Optional[uuid.UUID] = None ) -> Optional[AppRelease]: """获取当前发布版本 - + Args: app_id: 应用ID workspace_id: 工作空间ID(用于权限验证) - + Returns: Optional[AppRelease]: 当前发布版本,如果未发布则返回 None - + Raises: ResourceNotFoundException: 当应用不存在时 BusinessException: 当应用不可访问时 """ logger.debug("获取当前发布版本", extra={"app_id": str(app_id)}) - + app = self._get_app_or_404(app_id) # 只读操作,允许访问共享应用 self._validate_app_accessible(app, workspace_id) - + if not app.current_release_id: return None - + return self.db.get(AppRelease, app.current_release_id) - + def list_releases( - self, - *, - app_id: uuid.UUID, + self, + *, + app_id: uuid.UUID, workspace_id: Optional[uuid.UUID] = None ) -> List[AppRelease]: """列出应用的所有发布版本(倒序) - + Args: app_id: 应用ID workspace_id: 工作空间ID(用于权限验证) - + Returns: List[AppRelease]: 发布版本列表 - + Raises: ResourceNotFoundException: 当应用不存在时 BusinessException: 当应用不可访问时 """ logger.debug("列出发布版本", extra={"app_id": str(app_id)}) - + app = self._get_app_or_404(app_id) # 只读操作,允许访问共享应用 self._validate_app_accessible(app, workspace_id) - + stmt = ( select(AppRelease) .where(AppRelease.app_id == app_id, AppRelease.is_active == True) .order_by(AppRelease.version.desc()) ) return list(self.db.scalars(stmt).all()) - + def rollback( - self, - *, - app_id: uuid.UUID, - version: int, + self, + *, + app_id: uuid.UUID, + version: int, workspace_id: Optional[uuid.UUID] = None ) -> AppRelease: """回滚到指定版本 - + Args: app_id: 应用ID version: 目标版本号 workspace_id: 工作空间ID(用于权限验证) - + Returns: AppRelease: 回滚到的版本对象 - + Raises: ResourceNotFoundException: 当应用或版本不存在时 BusinessException: 当应用不在指定工作空间时 """ logger.info("回滚应用", extra={"app_id": str(app_id), "version": version}) - + app = self._get_app_or_404(app_id) self._validate_app_accessible(app, workspace_id) - + stmt = select(AppRelease).where( - AppRelease.app_id == app_id, + AppRelease.app_id == app_id, AppRelease.version == version ) release = self.db.scalars(stmt).first() - + if not release: logger.warning( "发布版本不存在", extra={"app_id": str(app_id), "version": version} ) raise ResourceNotFoundException("发布版本", f"app_id={app_id}, version={version}") - + app.current_release_id = release.id app.updated_at = datetime.datetime.now() - + self.db.commit() self.db.refresh(release) - + logger.info( "应用回滚成功", extra={"app_id": str(app_id), "version": version, "release_id": str(release.id)} ) return release - + # ==================== 应用分享功能 ==================== - + def share_app( self, *, @@ -1213,22 +1392,22 @@ class AppService: workspace_id: Optional[uuid.UUID] = None ) -> List["AppShare"]: """分享应用到其他工作空间 - + Args: app_id: 应用ID target_workspace_ids: 目标工作空间ID列表 user_id: 分享者用户ID workspace_id: 当前工作空间ID(用于权限验证) - + Returns: List[AppShare]: 创建的分享记录列表 - + Raises: ResourceNotFoundException: 当应用不存在时 BusinessException: 当应用不在指定工作空间或目标工作空间无效时 """ from app.models import AppShare, Workspace - + logger.info( "分享应用", extra={ @@ -1237,28 +1416,28 @@ class AppService: "user_id": str(user_id) } ) - + # 1. 验证应用 app = self._get_app_or_404(app_id) self._validate_workspace_access(app, workspace_id) - + # 2. 验证目标工作空间 for target_ws_id in target_workspace_ids: target_ws = self.db.get(Workspace, target_ws_id) if not target_ws: raise ResourceNotFoundException("工作空间", str(target_ws_id)) - + # 不能分享给自己的工作空间 if target_ws_id == app.workspace_id: raise BusinessException( "不能分享应用到自己的工作空间", BizCode.INVALID_PARAMETER ) - + # 3. 创建分享记录 now = datetime.datetime.now() shares = [] - + for target_ws_id in target_workspace_ids: # 检查是否已经分享过 stmt = select(AppShare).where( @@ -1266,7 +1445,7 @@ class AppService: AppShare.target_workspace_id == target_ws_id ) existing_share = self.db.scalars(stmt).first() - + if existing_share: logger.debug( "应用已分享到该工作空间,跳过", @@ -1274,7 +1453,7 @@ class AppService: ) shares.append(existing_share) continue - + # 创建新的分享记录 share = AppShare( id=uuid.uuid4(), @@ -1287,14 +1466,14 @@ class AppService: ) self.db.add(share) shares.append(share) - + logger.debug( "创建分享记录", extra={"app_id": str(app_id), "target_workspace_id": str(target_ws_id)} ) - + self.db.commit() - + logger.info( "应用分享成功", extra={ @@ -1303,9 +1482,9 @@ class AppService: "app_name": app.name } ) - + return shares - + def unshare_app( self, *, @@ -1314,18 +1493,18 @@ class AppService: workspace_id: Optional[uuid.UUID] = None ) -> None: """取消应用分享 - + Args: app_id: 应用ID target_workspace_id: 目标工作空间ID workspace_id: 当前工作空间ID(用于权限验证) - + Raises: ResourceNotFoundException: 当应用或分享记录不存在时 BusinessException: 当应用不在指定工作空间时 """ from app.models import AppShare - + logger.info( "取消应用分享", extra={ @@ -1333,18 +1512,18 @@ class AppService: "target_workspace_id": str(target_workspace_id) } ) - + # 1. 验证应用 app = self._get_app_or_404(app_id) self._validate_workspace_access(app, workspace_id) - + # 2. 查找分享记录 stmt = select(AppShare).where( AppShare.source_app_id == app_id, AppShare.target_workspace_id == target_workspace_id ) share = self.db.scalars(stmt).first() - + if not share: logger.warning( "分享记录不存在", @@ -1354,16 +1533,16 @@ class AppService: "分享记录", f"app_id={app_id}, target_workspace_id={target_workspace_id}" ) - + # 3. 删除分享记录 self.db.delete(share) self.db.commit() - + logger.info( "应用分享已取消", extra={"app_id": str(app_id), "target_workspace_id": str(target_workspace_id)} ) - + def list_app_shares( self, *, @@ -1371,42 +1550,42 @@ class AppService: workspace_id: Optional[uuid.UUID] = None ) -> List["AppShare"]: """列出应用的所有分享记录 - + Args: app_id: 应用ID workspace_id: 当前工作空间ID(用于权限验证) - + Returns: List[AppShare]: 分享记录列表 - + Raises: ResourceNotFoundException: 当应用不存在时 BusinessException: 当应用不在指定工作空间时 """ from app.models import AppShare - + logger.debug("列出应用分享记录", extra={"app_id": str(app_id)}) - + # 验证应用 app = self._get_app_or_404(app_id) self._validate_workspace_access(app, workspace_id) - + # 查询分享记录 stmt = select(AppShare).where( AppShare.source_app_id == app_id ).order_by(AppShare.created_at.desc()) - + shares = list(self.db.scalars(stmt).all()) - + logger.debug( "应用分享记录查询完成", extra={"app_id": str(app_id), "count": len(shares)} ) - + return shares - + # ==================== 试运行功能 ==================== - + async def draft_run( self, *, @@ -1418,7 +1597,7 @@ class AppService: workspace_id: Optional[uuid.UUID] = None ) -> Dict[str, Any]: """试运行 Agent(使用当前草稿配置) - + Args: app_id: 应用ID message: 用户消息 @@ -1426,43 +1605,43 @@ class AppService: user_id: 用户ID(用于会话管理) variables: 自定义变量参数值 workspace_id: 工作空间ID(用于权限验证) - + Returns: Dict: 包含 AI 回复和元数据的字典 - + Raises: ResourceNotFoundException: 当应用不存在时 BusinessException: 当应用类型不支持或配置缺失时 """ from app.services.draft_run_service import DraftRunService - + logger.info("试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]}) - + # 1. 验证应用 app = self._get_app_or_404(app_id) - + if app.type != "agent": raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - + # 只读操作,允许访问共享应用 self._validate_app_accessible(app, workspace_id) - + # 2. 获取 Agent 配置 stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) agent_cfg = self.db.scalars(stmt).first() - + if not agent_cfg: raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - + # 3. 获取模型配置 model_config = None if agent_cfg.default_model_config_id: from app.models import ModelConfig model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id) - + if not model_config: raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - + # 4. 调用试运行服务 logger.debug( "准备调用试运行服务", @@ -1473,7 +1652,7 @@ class AppService: "has_variables": bool(variables) } ) - + draft_service = DraftRunService(self.db) result = await draft_service.run( agent_config=agent_cfg, @@ -1484,7 +1663,7 @@ class AppService: user_id=user_id, variables=variables ) - + logger.debug( "试运行服务返回结果", extra={ @@ -1494,7 +1673,7 @@ class AppService: "has_conversation_id": "conversation_id" in result if isinstance(result, dict) else False } ) - + logger.info( "试运行完成", extra={ @@ -1503,9 +1682,9 @@ class AppService: "model": model_config.name } ) - + return result - + async def draft_run_stream( self, *, @@ -1517,7 +1696,7 @@ class AppService: workspace_id: Optional[uuid.UUID] = None ): """试运行 Agent(流式返回) - + Args: app_id: 应用ID message: 用户消息 @@ -1525,43 +1704,43 @@ class AppService: user_id: 用户ID(用于会话管理) variables: 自定义变量参数值 workspace_id: 工作空间ID(用于权限验证) - + Yields: str: SSE 格式的事件数据 - + Raises: ResourceNotFoundException: 当应用不存在时 BusinessException: 当应用类型不支持或配置缺失时 """ from app.services.draft_run_service import DraftRunService - + logger.info("流式试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]}) - + # 1. 验证应用 app = self._get_app_or_404(app_id) - + if app.type != "agent": raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - + # 只读操作,允许访问共享应用 self._validate_app_accessible(app, workspace_id) - + # 2. 获取 Agent 配置 stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) agent_cfg = self.db.scalars(stmt).first() - + if not agent_cfg: raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - + # 3. 获取模型配置 model_config = None if agent_cfg.default_model_config_id: from app.models import ModelConfig model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id) - + if not model_config: raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - + # 4. 调用流式试运行服务 draft_service = DraftRunService(self.db) async for event in draft_service.run_stream( @@ -1574,9 +1753,9 @@ class AppService: variables=variables ): yield event - + # ==================== 多模型对比试运行 ==================== - + async def draft_run_compare( self, *, @@ -1591,7 +1770,7 @@ class AppService: timeout: int = 60 ) -> Dict[str, Any]: """多模型对比试运行 - + Args: app_id: 应用ID message: 用户消息 @@ -1602,13 +1781,13 @@ class AppService: workspace_id: 工作空间ID parallel: 是否并行执行 timeout: 超时时间(秒) - + Returns: Dict: 对比结果 """ from app.services.draft_run_service import DraftRunService from app.models import ModelConfig - + logger.info( "多模型对比试运行", extra={ @@ -1617,41 +1796,41 @@ class AppService: "parallel": parallel } ) - + # 1. 验证应用 app = self._get_app_or_404(app_id) if app.type != "agent": raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - + # 只读操作,允许访问共享应用 self._validate_app_accessible(app, workspace_id) - + # 2. 获取 Agent 配置 stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) agent_cfg = self.db.scalars(stmt).first() if not agent_cfg: raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) - + # 3. 准备所有模型配置 model_configs = [] for model_item in models: model_config = self.db.get(ModelConfig, model_item.model_config_id) if not model_config: raise ResourceNotFoundException("模型配置", str(model_item.model_config_id)) - + # 合并参数:agent配置参数 + 请求覆盖参数 merged_parameters = { **(agent_cfg.model_parameters or {}), **(model_item.model_parameters or {}) } - + model_configs.append({ "model_config": model_config, "parameters": merged_parameters, "label": model_item.label or model_config.name, "model_config_id": model_item.model_config_id }) - + # 4. 调用 DraftRunService 的对比方法 draft_service = DraftRunService(self.db) result = await draft_service.run_compare( @@ -1665,7 +1844,7 @@ class AppService: parallel=parallel, timeout=timeout ) - + logger.info( "多模型对比完成", extra={ @@ -1674,9 +1853,9 @@ class AppService: "failed": result["failed_count"] } ) - + return result - + async def draft_run_compare_stream( self, *, @@ -1691,7 +1870,7 @@ class AppService: timeout: int = 60 ): """多模型对比试运行(流式返回) - + Args: app_id: 应用ID message: 用户消息 @@ -1701,13 +1880,13 @@ class AppService: variables: 变量参数 workspace_id: 工作空间ID timeout: 超时时间(秒) - + Yields: str: SSE 格式的事件数据 """ from app.services.draft_run_service import DraftRunService from app.models import ModelConfig - + logger.info( "多模型对比流式试运行", extra={ @@ -1715,41 +1894,41 @@ class AppService: "model_count": len(models) } ) - + # 1. 验证应用 app = self._get_app_or_404(app_id) if app.type != "agent": raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - + # 只读操作,允许访问共享应用 self._validate_app_accessible(app, workspace_id) - + # 2. 获取 Agent 配置 stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) agent_cfg = self.db.scalars(stmt).first() if not agent_cfg: raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) - + # 3. 准备所有模型配置 model_configs = [] for model_item in models: model_config = self.db.get(ModelConfig, model_item.model_config_id) if not model_config: raise ResourceNotFoundException("模型配置", str(model_item.model_config_id)) - + # 合并参数:agent配置参数 + 请求覆盖参数 merged_parameters = { **(agent_cfg.model_parameters or {}), **(model_item.model_parameters or {}) } - + model_configs.append({ "model_config": model_config, "parameters": merged_parameters, "label": model_item.label or model_config.name, "model_config_id": model_item.model_config_id }) - + # 4. 调用 DraftRunService 的流式对比方法 draft_service = DraftRunService(self.db) async for event in draft_service.run_compare_stream( @@ -1764,7 +1943,7 @@ class AppService: timeout=timeout ): yield event - + logger.info( "多模型对比流式完成", extra={"app_id": str(app_id)} @@ -1797,15 +1976,28 @@ def update_agent_config(db: Session, *, app_id: uuid.UUID, data: app_schema.Agen service = AppService(db) return service.update_agent_config(app_id=app_id, data=data, workspace_id=workspace_id) +def update_workflow_config(db: Session, *, app_id: uuid.UUID, data: WorkflowConfigUpdate, workspace_id: uuid.UUID | None = None) -> WorkflowConfig: + """更新 Agent 配置(向后兼容接口)""" + service = AppService(db) + return service.update_workflow_config(app_id=app_id, data=data, workspace_id=workspace_id) + def get_agent_config(db: Session, *, app_id: uuid.UUID, workspace_id: uuid.UUID | None = None) -> AgentConfig: """获取 Agent 配置(向后兼容接口) - + 如果配置不存在,返回默认配置模板 """ service = AppService(db) return service.get_agent_config(app_id=app_id, workspace_id=workspace_id) +def get_workflow_config(db: Session, *, app_id: uuid.UUID, workspace_id: uuid.UUID | None = None) -> WorkflowConfig: + """获取 Agent 配置(向后兼容接口) + + 如果配置不存在,返回默认配置模板 + """ + service = AppService(db) + return service.get_workflow_config(app_id=app_id, workspace_id=workspace_id) + def publish(db: Session, *, app_id: uuid.UUID, publisher_id: uuid.UUID, workspace_id: uuid.UUID | None = None,version_name:str, release_notes: Optional[str] = None) -> AppRelease: """发布应用(向后兼容接口)""" diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py new file mode 100644 index 00000000..c604697b --- /dev/null +++ b/api/app/services/workflow_service.py @@ -0,0 +1,731 @@ +""" +工作流服务层 +""" + +import logging +import uuid +import datetime +from typing import Any, Annotated + +from sqlalchemy.orm import Session +from fastapi import Depends + +from app.models.workflow_model import WorkflowConfig, WorkflowExecution +from app.repositories.workflow_repository import ( + WorkflowConfigRepository, + WorkflowExecutionRepository, + WorkflowNodeExecutionRepository, + get_workflow_config_repository, + get_workflow_execution_repository, + get_workflow_node_execution_repository +) +from app.core.workflow.validator import validate_workflow_config +from app.core.exceptions import BusinessException +from app.core.error_codes import BizCode +from app.db import get_db +from app.schemas import DraftRunRequest + +logger = logging.getLogger(__name__) + + +class WorkflowService: + """工作流服务""" + + def __init__(self, db: Session): + self.db = db + self.config_repo = WorkflowConfigRepository(db) + self.execution_repo = WorkflowExecutionRepository(db) + self.node_execution_repo = WorkflowNodeExecutionRepository(db) + + # ==================== 配置管理 ==================== + + def create_workflow_config( + self, + app_id: uuid.UUID, + nodes: list[dict[str, Any]], + edges: list[dict[str, Any]], + variables: list[dict[str, Any]] | None = None, + execution_config: dict[str, Any] | None = None, + triggers: list[dict[str, Any]] | None = None, + validate: bool = True + ) -> WorkflowConfig: + """创建工作流配置 + + Args: + app_id: 应用 ID + nodes: 节点列表 + edges: 边列表 + variables: 变量列表 + execution_config: 执行配置 + triggers: 触发器列表 + validate: 是否验证配置 + + Returns: + 工作流配置 + + Raises: + BusinessException: 配置无效时抛出 + """ + # 构建配置字典 + config_dict = { + "nodes": nodes, + "edges": edges, + "variables": variables or [], + "execution_config": execution_config or {}, + "triggers": triggers or [] + } + + # 验证配置 + if validate: + is_valid, errors = validate_workflow_config(config_dict, for_publish=False) + if not is_valid: + logger.warning(f"工作流配置验证失败: {errors}") + raise BusinessException( + error_code=BizCode.INVALID_PARAMETER, + message=f"工作流配置无效: {'; '.join(errors)}" + ) + + # 创建或更新配置 + config = self.config_repo.create_or_update( + app_id=app_id, + nodes=nodes, + edges=edges, + variables=variables, + execution_config=execution_config, + triggers=triggers + ) + + logger.info(f"创建工作流配置成功: app_id={app_id}, config_id={config.id}") + return config + + def get_workflow_config(self, app_id: uuid.UUID) -> WorkflowConfig | None: + """获取工作流配置 + + Args: + app_id: 应用 ID + + Returns: + 工作流配置或 None + """ + return self.config_repo.get_by_app_id(app_id) + + def update_workflow_config( + self, + app_id: uuid.UUID, + nodes: list[dict[str, Any]] | None = None, + edges: list[dict[str, Any]] | None = None, + variables: list[dict[str, Any]] | None = None, + execution_config: dict[str, Any] | None = None, + triggers: list[dict[str, Any]] | None = None, + validate: bool = True + ) -> WorkflowConfig: + """更新工作流配置 + + Args: + app_id: 应用 ID + nodes: 节点列表 + edges: 边列表 + variables: 变量列表 + execution_config: 执行配置 + triggers: 触发器列表 + validate: 是否验证配置 + + Returns: + 工作流配置 + + Raises: + BusinessException: 配置不存在或无效时抛出 + """ + # 获取现有配置 + config = self.get_workflow_config(app_id) + if not config: + raise BusinessException( + error_code=BizCode.RESOURCE_NOT_FOUND, + message=f"工作流配置不存在: app_id={app_id}" + ) + + # 合并配置 + updated_nodes = nodes if nodes is not None else config.nodes + updated_edges = edges if edges is not None else config.edges + updated_variables = variables if variables is not None else config.variables + updated_execution_config = execution_config if execution_config is not None else config.execution_config + updated_triggers = triggers if triggers is not None else config.triggers + + # 构建配置字典 + config_dict = { + "nodes": updated_nodes, + "edges": updated_edges, + "variables": updated_variables, + "execution_config": updated_execution_config, + "triggers": updated_triggers + } + + # 验证配置 + if validate: + is_valid, errors = validate_workflow_config(config_dict, for_publish=False) + if not is_valid: + logger.warning(f"工作流配置验证失败: {errors}") + raise BusinessException( + error_code=BizCode.INVALID_PARAMETER, + message=f"工作流配置无效: {'; '.join(errors)}" + ) + + # 更新配置 + config = self.config_repo.create_or_update( + app_id=app_id, + nodes=updated_nodes, + edges=updated_edges, + variables=updated_variables, + execution_config=updated_execution_config, + triggers=updated_triggers + ) + + logger.info(f"更新工作流配置成功: app_id={app_id}, config_id={config.id}") + return config + + def delete_workflow_config(self, app_id: uuid.UUID) -> bool: + """删除工作流配置 + + Args: + app_id: 应用 ID + + Returns: + 是否删除成功 + """ + config = self.get_workflow_config(app_id) + if not config: + return False + + self.config_repo.delete(config.id) + logger.info(f"删除工作流配置成功: app_id={app_id}, config_id={config.id}") + return True + + def check_config(self, app_id: uuid.UUID) -> WorkflowConfig: + """检查工作流配置的完整性 + + Args: + app_id: 应用 ID + + Raises: + BusinessException: 配置不完整或不存在时抛出 + """ + + # 1. 检查多智能体配置是否存在 + config = self.get_workflow_config(app_id) + if not config: + raise BusinessException( + "工作流配置不存在,无法运行", + BizCode.CONFIG_MISSING + ) + # validator 现在支持直接接受 Pydantic 模型 + is_valid, errors = validate_workflow_config(config, for_publish=False) + if not is_valid: + logger.warning(f"工作流配置验证失败: {errors}") + raise BusinessException( + code=BizCode.INVALID_PARAMETER, + message=f"工作流配置无效: {'; '.join(errors)}" + ) + return config + + def validate_workflow_config_for_publish( + self, + app_id: uuid.UUID + ) -> tuple[bool, list[str]]: + """验证工作流配置是否可以发布 + + Args: + app_id: 应用 ID + + Returns: + (is_valid, errors): 是否有效和错误列表 + + Raises: + BusinessException: 配置不存在时抛出 + """ + config = self.get_workflow_config(app_id) + if not config: + raise BusinessException( + error_code=BizCode.RESOURCE_NOT_FOUND, + message=f"工作流配置不存在: app_id={app_id}" + ) + + config_dict = { + "nodes": config.nodes, + "edges": config.edges, + "variables": config.variables, + "execution_config": config.execution_config, + "triggers": config.triggers + } + + return validate_workflow_config(config_dict, for_publish=True) + + # ==================== 执行管理 ==================== + + def create_execution( + self, + workflow_config_id: uuid.UUID, + app_id: uuid.UUID, + trigger_type: str, + triggered_by: uuid.UUID | None = None, + conversation_id: uuid.UUID | None = None, + input_data: dict[str, Any] | None = None + ) -> WorkflowExecution: + """创建工作流执行记录 + + Args: + workflow_config_id: 工作流配置 ID + app_id: 应用 ID + trigger_type: 触发类型 + triggered_by: 触发用户 ID + conversation_id: 会话 ID + input_data: 输入数据 + + Returns: + 执行记录 + """ + # 生成执行 ID + execution_id = f"exec_{uuid.uuid4().hex[:16]}" + + execution = WorkflowExecution( + workflow_config_id=workflow_config_id, + app_id=app_id, + conversation_id=conversation_id, + execution_id=execution_id, + trigger_type=trigger_type, + triggered_by=triggered_by, + input_data=input_data or {}, + status="pending" + ) + + self.db.add(execution) + self.db.commit() + self.db.refresh(execution) + + logger.info(f"创建工作流执行记录: execution_id={execution_id}") + return execution + + def get_execution(self, execution_id: str) -> WorkflowExecution | None: + """获取执行记录 + + Args: + execution_id: 执行 ID + + Returns: + 执行记录或 None + """ + return self.execution_repo.get_by_execution_id(execution_id) + + def get_executions_by_app( + self, + app_id: uuid.UUID, + limit: int = 50, + offset: int = 0 + ) -> list[WorkflowExecution]: + """获取应用的执行记录列表 + + Args: + app_id: 应用 ID + limit: 返回数量限制 + offset: 偏移量 + + Returns: + 执行记录列表 + """ + return self.execution_repo.get_by_app_id(app_id, limit, offset) + + def update_execution_status( + self, + execution_id: str, + status: str, + output_data: dict[str, Any] | None = None, + error_message: str | None = None, + error_node_id: str | None = None + ) -> WorkflowExecution: + """更新执行状态 + + Args: + execution_id: 执行 ID + status: 状态 + output_data: 输出数据 + error_message: 错误信息 + error_node_id: 出错节点 ID + + Returns: + 执行记录 + + Raises: + BusinessException: 执行记录不存在时抛出 + """ + execution = self.get_execution(execution_id) + if not execution: + raise BusinessException( + error_code=BizCode.RESOURCE_NOT_FOUND, + message=f"执行记录不存在: execution_id={execution_id}" + ) + + execution.status = status + if output_data is not None: + execution.output_data = output_data + if error_message is not None: + execution.error_message = error_message + if error_node_id is not None: + execution.error_node_id = error_node_id + + # 如果是完成状态,计算耗时 + if status in ["completed", "failed", "cancelled", "timeout"]: + if not execution.completed_at: + execution.completed_at = datetime.datetime.now() + elapsed = (execution.completed_at - execution.started_at).total_seconds() + execution.elapsed_time = elapsed + + self.db.commit() + self.db.refresh(execution) + + logger.info(f"更新执行状态: execution_id={execution_id}, status={status}") + return execution + + def get_execution_statistics(self, app_id: uuid.UUID) -> dict[str, Any]: + """获取执行统计信息 + + Args: + app_id: 应用 ID + + Returns: + 统计信息 + """ + total = self.execution_repo.count_by_app_id(app_id) + completed = self.execution_repo.count_by_status(app_id, "completed") + failed = self.execution_repo.count_by_status(app_id, "failed") + running = self.execution_repo.count_by_status(app_id, "running") + + return { + "total": total, + "completed": completed, + "failed": failed, + "running": running, + "success_rate": completed / total if total > 0 else 0 + } + + # ==================== 工作流执行 ==================== + + async def run( + self, + app_id: uuid.UUID, + payload: DraftRunRequest, + config: WorkflowConfig + ): + """运行工作流 + + Args: + app_id: 应用 ID + input_data: 输入数据(包含 message 和 variables) + triggered_by: 触发用户 ID + conversation_id: 会话 ID(可选) + stream: 是否流式返回 + + Returns: + 执行结果(非流式)或生成器(流式) + + Raises: + BusinessException: 配置不存在或执行失败时抛出 + """ + # 1. 获取工作流配置 + if not config: + config = self.get_workflow_config(app_id) + if not config: + raise BusinessException( + code=BizCode.CONFIG_MISSING, + message=f"工作流配置不存在: app_id={app_id}" + ) + input_data = {"message": payload.message, "variables": payload.variables, "conversation_id": payload.conversation_id} + + # 转换 user_id 为 UUID + triggered_by_uuid = None + if payload.user_id: + try: + triggered_by_uuid = uuid.UUID(payload.user_id) + except (ValueError, AttributeError): + logger.warning(f"无效的 user_id 格式: {payload.user_id}") + + # 转换 conversation_id 为 UUID + conversation_id_uuid = None + if payload.conversation_id: + try: + conversation_id_uuid = uuid.UUID(payload.conversation_id) + except (ValueError, AttributeError): + logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}") + + # 2. 创建执行记录 + execution = self.create_execution( + workflow_config_id=config.id, + app_id=app_id, + trigger_type="manual", + triggered_by=triggered_by_uuid, + conversation_id=conversation_id_uuid, + input_data=input_data + ) + + # 3. 构建工作流配置字典 + workflow_config_dict = { + "nodes": config.nodes, + "edges": config.edges, + "variables": config.variables, + "execution_config": config.execution_config + } + + # 4. 获取工作空间 ID(从 app 获取) + from app.models import App + + + # 5. 执行工作流 + from app.core.workflow.executor import execute_workflow, execute_workflow_stream + + try: + # 更新状态为运行中 + self.update_execution_status(execution.execution_id, "running") + + result = await execute_workflow( + workflow_config=workflow_config_dict, + input_data=input_data, + execution_id=execution.execution_id, + workspace_id="", + user_id=payload.user_id + ) + + # 更新执行结果 + if result.get("status") == "completed": + self.update_execution_status( + execution.execution_id, + "completed", + output_data=result.get("node_outputs", {}) + ) + else: + self.update_execution_status( + execution.execution_id, + "failed", + error_message=result.get("error") + ) + + # 返回增强的响应结构 + return { + "execution_id": execution.execution_id, + "status": result.get("status"), + "output": result.get("output"), # 最终输出(字符串) + "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据) + "conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID + "error_message": result.get("error"), + "elapsed_time": result.get("elapsed_time"), + "token_usage": result.get("token_usage") + } + + except Exception as e: + logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True) + self.update_execution_status( + execution.execution_id, + "failed", + error_message=str(e) + ) + raise BusinessException( + code=BizCode.INTERNAL_ERROR, + message=f"工作流执行失败: {str(e)}" + ) + + async def run_workflow( + self, + app_id: uuid.UUID, + input_data: dict[str, Any], + triggered_by: uuid.UUID, + conversation_id: uuid.UUID | None = None, + stream: bool = False + ): + """运行工作流 + + Args: + app_id: 应用 ID + input_data: 输入数据(包含 message 和 variables) + triggered_by: 触发用户 ID + conversation_id: 会话 ID(可选) + stream: 是否流式返回 + + Returns: + 执行结果(非流式)或生成器(流式) + + Raises: + BusinessException: 配置不存在或执行失败时抛出 + """ + # 1. 获取工作流配置 + config = self.get_workflow_config(app_id) + if not config: + raise BusinessException( + error_code=BizCode.RESOURCE_NOT_FOUND, + message=f"工作流配置不存在: app_id={app_id}" + ) + + # 2. 创建执行记录 + execution = self.create_execution( + workflow_config_id=config.id, + app_id=app_id, + trigger_type="manual", + triggered_by=triggered_by, + conversation_id=conversation_id, + input_data=input_data + ) + + # 3. 构建工作流配置字典 + workflow_config_dict = { + "nodes": config.nodes, + "edges": config.edges, + "variables": config.variables, + "execution_config": config.execution_config + } + + # 4. 获取工作空间 ID(从 app 获取) + from app.models import App + app = self.db.query(App).filter(App.id == app_id).first() + if not app: + raise BusinessException( + error_code=BizCode.RESOURCE_NOT_FOUND, + message=f"应用不存在: app_id={app_id}" + ) + + # 5. 执行工作流 + from app.core.workflow.executor import execute_workflow, execute_workflow_stream + + try: + # 更新状态为运行中 + self.update_execution_status(execution.execution_id, "running") + + if stream: + # 流式执行 + return self._run_workflow_stream( + workflow_config_dict, + input_data, + execution.execution_id, + str(app.workspace_id), + str(triggered_by) + ) + else: + # 非流式执行 + result = await execute_workflow( + workflow_config=workflow_config_dict, + input_data=input_data, + execution_id=execution.execution_id, + workspace_id=str(app.workspace_id), + user_id=str(triggered_by) + ) + + # 更新执行结果 + if result.get("status") == "completed": + self.update_execution_status( + execution.execution_id, + "completed", + output_data=result.get("node_outputs", {}) + ) + else: + self.update_execution_status( + execution.execution_id, + "failed", + error_message=result.get("error") + ) + + # 返回增强的响应结构 + return { + "execution_id": execution.execution_id, + "status": result.get("status"), + "output": result.get("output"), # 最终输出(字符串) + "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据) + "error_message": result.get("error"), + "elapsed_time": result.get("elapsed_time"), + "token_usage": result.get("token_usage") + } + + except Exception as e: + logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True) + self.update_execution_status( + execution.execution_id, + "failed", + error_message=str(e) + ) + raise BusinessException( + error_code=BizCode.INTERNAL_ERROR, + message=f"工作流执行失败: {str(e)}" + ) + + async def _run_workflow_stream( + self, + workflow_config: dict[str, Any], + input_data: dict[str, Any], + execution_id: str, + workspace_id: str, + user_id: str + ): + """运行工作流(流式,内部方法) + + Args: + workflow_config: 工作流配置 + input_data: 输入数据 + execution_id: 执行 ID + workspace_id: 工作空间 ID + user_id: 用户 ID + + Yields: + 流式事件 + """ + from app.core.workflow.executor import execute_workflow_stream + + try: + output_data = {} + + async for event in execute_workflow_stream( + workflow_config=workflow_config, + input_data=input_data, + execution_id=execution_id, + workspace_id=workspace_id, + user_id=user_id + ): + # 转发事件 + yield event + + # 收集输出数据 + if event.get("type") == "node_complete": + node_data = event.get("data", {}) + node_outputs = node_data.get("node_outputs", {}) + output_data.update(node_outputs) + + # 处理完成事件 + if event.get("type") == "workflow_complete": + self.update_execution_status( + execution_id, + "completed", + output_data=output_data + ) + + # 处理错误事件 + if event.get("type") == "workflow_error": + self.update_execution_status( + execution_id, + "failed", + error_message=event.get("error") + ) + + except Exception as e: + logger.error(f"工作流流式执行失败: execution_id={execution_id}, error={e}", exc_info=True) + self.update_execution_status( + execution_id, + "failed", + error_message=str(e) + ) + yield { + "type": "workflow_error", + "execution_id": execution_id, + "error": str(e) + } + + +# ==================== 依赖注入函数 ==================== + +def get_workflow_service( + db: Annotated[Session, Depends(get_db)] +) -> WorkflowService: + """获取工作流服务(依赖注入)""" + return WorkflowService(db) diff --git a/api/app/templates/workflows/customer_service/template.yml b/api/app/templates/workflows/customer_service/template.yml new file mode 100644 index 00000000..26219712 --- /dev/null +++ b/api/app/templates/workflows/customer_service/template.yml @@ -0,0 +1,219 @@ +# 智能客服工作流模板 +id: customer_service_v1 +name: 智能客服工作流 +description: 智能客服场景,包含意图识别、知识库查询和回复生成 +category: customer_service +version: "1.0.0" +author: RedBear Memory Team +tags: + - 客服 + - 意图识别 + - 知识库 + - 多步骤 + +# 工作流配置 +nodes: + - id: start + type: start + name: 开始 + position: + x: 100 + y: 200 + + - id: intent_recognition + type: llm + name: 意图识别 + config: + prompt: | + 分析用户的问题,识别意图类型。 + + 用户问题:{{ var.user_message }} + + 请从以下类型中选择一个: + - product_inquiry: 产品咨询 + - technical_support: 技术支持 + - complaint: 投诉建议 + - other: 其他 + + 只返回类型名称,不要其他内容。 + model: gpt-3.5-turbo + temperature: 0.3 + max_tokens: 50 + position: + x: 300 + y: 200 + + - id: intent_router + type: condition + name: 意图路由 + position: + x: 500 + y: 200 + + - id: product_handler + type: llm + name: 产品咨询处理 + config: + prompt: | + 用户咨询产品相关问题。 + + 问题:{{ var.user_message }} + 意图:{{ node.intent_recognition.output }} + + 请提供专业、友好的产品咨询回复。 + model: gpt-3.5-turbo + temperature: 0.7 + max_tokens: 500 + position: + x: 700 + y: 100 + + - id: support_handler + type: llm + name: 技术支持处理 + config: + prompt: | + 用户需要技术支持。 + + 问题:{{ var.user_message }} + 意图:{{ node.intent_recognition.output }} + + 请提供详细的技术支持方案。 + model: gpt-3.5-turbo + temperature: 0.5 + max_tokens: 800 + position: + x: 700 + y: 200 + + - id: complaint_handler + type: llm + name: 投诉处理 + config: + prompt: | + 用户提出投诉或建议。 + + 问题:{{ var.user_message }} + 意图:{{ node.intent_recognition.output }} + + 请以同理心回应,并提供解决方案。 + model: gpt-3.5-turbo + temperature: 0.8 + max_tokens: 600 + position: + x: 700 + y: 300 + + - id: general_handler + type: llm + name: 通用处理 + config: + prompt: | + 用户的问题类型:其他 + + 问题:{{ var.user_message }} + + 请提供友好的回复。 + model: gpt-3.5-turbo + temperature: 0.7 + max_tokens: 400 + position: + x: 700 + y: 400 + + - id: end + type: end + name: 结束 + position: + x: 900 + y: 200 + +edges: + - source: start + target: intent_recognition + label: 开始分析 + + - source: intent_recognition + target: intent_router + label: 识别完成 + + - source: intent_router + target: product_handler + condition: "'product_inquiry' in node['intent_recognition']['output']" + label: 产品咨询 + + - source: intent_router + target: support_handler + condition: "'technical_support' in node['intent_recognition']['output']" + label: 技术支持 + + - source: intent_router + target: complaint_handler + condition: "'complaint' in node['intent_recognition']['output']" + label: 投诉建议 + + - source: intent_router + target: general_handler + condition: "True" # 默认路径 + label: 其他 + + - source: product_handler + target: end + label: 完成 + + - source: support_handler + target: end + label: 完成 + + - source: complaint_handler + target: end + label: 完成 + + - source: general_handler + target: end + label: 完成 + +# 变量定义 +variables: + - name: user_message + type: string + required: true + description: 用户的消息 + default: "" + + - name: user_name + type: string + required: false + description: 用户姓名(可选) + default: "客户" + +# 执行配置 +execution_config: + max_execution_time: 120 + max_iterations: 10 + +# 触发器 +triggers: [] + +# 使用示例 +examples: + - name: 产品咨询 + description: 用户咨询产品功能 + input: + user_message: "你们的产品支持多语言吗?" + user_name: "张三" + expected_output: "产品功能介绍" + + - name: 技术支持 + description: 用户遇到技术问题 + input: + user_message: "我无法登录系统,一直显示密码错误" + user_name: "李四" + expected_output: "技术支持方案" + + - name: 投诉处理 + description: 用户提出投诉 + input: + user_message: "你们的服务态度太差了,我要投诉" + user_name: "王五" + expected_output: "同理心回应和解决方案" diff --git a/api/app/templates/workflows/data_processing/template.yml b/api/app/templates/workflows/data_processing/template.yml new file mode 100644 index 00000000..73c9ebff --- /dev/null +++ b/api/app/templates/workflows/data_processing/template.yml @@ -0,0 +1,131 @@ +# 数据处理工作流模板 +id: data_processing_v1 +name: 数据处理工作流 +description: 数据提取、转换和分析的完整流程 +category: data_processing +version: "1.0.0" +author: RedBear Memory Team +tags: + - 数据处理 + - ETL + - 分析 + - Transform + +# 工作流配置 +nodes: + - id: start + type: start + name: 开始 + position: + x: 100 + y: 200 + + - id: extract_data + type: transform + name: 数据提取 + config: + expression: | + { + "raw_text": var['input_text'], + "length": len(var['input_text']), + "timestamp": sys['execution_id'] + } + position: + x: 300 + y: 200 + + - id: analyze_data + type: llm + name: 数据分析 + config: + prompt: | + 请分析以下数据: + + 原始文本:{{ node.extract_data.raw_text }} + 文本长度:{{ node.extract_data.length }} + + 请提供: + 1. 主题分类 + 2. 情感分析 + 3. 关键信息提取 + + 以 JSON 格式返回结果。 + model: gpt-3.5-turbo + temperature: 0.3 + max_tokens: 500 + position: + x: 500 + y: 200 + + - id: transform_result + type: transform + name: 结果转换 + config: + expression: | + { + "original_length": node['extract_data']['length'], + "analysis": node['analyze_data']['output'], + "processed_at": sys['execution_id'], + "status": "completed" + } + position: + x: 700 + y: 200 + + - id: end + type: end + name: 结束 + position: + x: 900 + y: 200 + +edges: + - source: start + target: extract_data + label: 开始提取 + + - source: extract_data + target: analyze_data + label: 开始分析 + + - source: analyze_data + target: transform_result + label: 转换结果 + + - source: transform_result + target: end + label: 完成 + +# 变量定义 +variables: + - name: input_text + type: string + required: true + description: 待处理的文本数据 + default: "" + +# 执行配置 +execution_config: + max_execution_time: 180 + max_iterations: 5 + +# 触发器 +triggers: [] + +# 使用示例 +examples: + - name: 文本分析 + description: 分析一段文本 + input: + input_text: "今天天气真好,心情也很愉快。我们公司推出了新产品,市场反响热烈。" + expected_output: + original_length: 35 + analysis: "主题:天气和产品,情感:积极" + status: "completed" + + - name: 长文本处理 + description: 处理较长的文本 + input: + input_text: "这是一段很长的文本..." + expected_output: + status: "completed" diff --git a/api/app/templates/workflows/multi_step_qa/template.yml b/api/app/templates/workflows/multi_step_qa/template.yml new file mode 100644 index 00000000..ce04c162 --- /dev/null +++ b/api/app/templates/workflows/multi_step_qa/template.yml @@ -0,0 +1,99 @@ +# 多步骤问答工作流 +# 演示节点输出参数的使用 + +id: multi_step_qa_v1 +name: 多步骤问答工作流 +description: 先分析问题,再生成答案,展示节点间的数据传递 +category: advanced +version: "1.0.0" +author: RedBear Memory Team +tags: + - 问答 + - 多步骤 + - LLM + +# 工作流配置 +nodes: + - id: start + type: start + name: 开始 + position: + x: 100 + y: 100 + + - id: analyze_question + type: llm + name: 分析问题 + description: 分析用户问题的类型和意图 + config: + model_id: gpt-3.5-turbo + temperature: 0.3 + max_tokens: 500 + messages: + - role: system + content: | + 你是一个问题分析专家。请分析用户的问题,提取以下信息: + 1. 问题类型(事实性、观点性、操作性等) + 2. 问题领域(科技、历史、文化等) + 3. 关键词 + - role: user + content: "{{ sys.message }}" + position: + x: 300 + y: 100 + + - id: generate_answer + type: llm + name: 生成答案 + description: 根据问题分析结果生成详细答案 + config: + model_id: gpt-3.5-turbo + temperature: 0.7 + max_tokens: 1000 + messages: + - role: system + content: | + 你是一个专业的AI助手。根据问题分析结果,生成准确、详细的答案。 + + 问题分析结果: + {{ analyze_question.output }} + - role: user + content: "{{ sys.message }}" + position: + x: 500 + y: 100 + + - id: end + type: end + name: 结束 + config: + output: "{{ generate_answer.output }}" + position: + x: 700 + y: 100 + +edges: + - source: start + target: analyze_question + label: 开始分析 + + - source: analyze_question + target: generate_answer + label: 生成答案 + + - source: generate_answer + target: end + label: 完成 + +# 变量定义 +variables: + - name: user_question + type: string + required: true + description: 用户的问题 + default: "" + +# 执行配置 +execution_config: + max_execution_time: 120 + max_iterations: 1 diff --git a/api/app/templates/workflows/simple_qa/template.yml b/api/app/templates/workflows/simple_qa/template.yml new file mode 100644 index 00000000..1b68d55d --- /dev/null +++ b/api/app/templates/workflows/simple_qa/template.yml @@ -0,0 +1,100 @@ +# 简单问答工作流模板 +id: simple_qa_v1 +name: 简单问答工作流 +description: 最基础的问答工作流,适合快速开始 +category: basic +version: "1.0.0" +author: RedBear Memory Team +tags: + - 问答 + - 基础 + - LLM + +# 工作流配置 +nodes: + - id: start + type: start + name: 开始 + position: + x: 100 + y: 100 + + - id: llm_qa + type: llm + name: LLM 问答 + config: + # 使用 LangChain 标准的消息格式 + messages: + - role: system + content: | + 你是一个专业、友好且乐于助人的 AI 助手。 + + 你的职责: + - 准确理解用户的问题并提供有价值的回答 + - 保持回答的专业性和准确性 + - 如果不确定答案,诚实地告知用户 + - 使用清晰、易懂的语言进行交流 + + 回答风格: + - 简洁明了,直击要点 + - 必要时提供详细解释和示例 + - 使用友好、礼貌的语气 + - 适当使用格式化(如列表、段落)提高可读性 + + - role: user + content: "{{ sys.message }}" + + model_id: gpt-3.5-turbo + temperature: 0.7 + max_tokens: 1000 + position: + x: 300 + y: 100 + + - id: end + type: end + name: 结束 + config: + output: "{{ llm_qa.output }}" + position: + x: 500 + y: 100 + +edges: + - source: start + target: llm_qa + label: 开始处理 + + - source: llm_qa + target: end + label: 完成 + +# 变量定义 +variables: + - name: user_question + type: string + required: true + description: 用户的问题 + default: "" + +# 执行配置 +execution_config: + max_execution_time: 60 + max_iterations: 1 + +# 触发器(可选) +triggers: [] + +# 使用示例 +examples: + - name: 基础问答 + description: 询问一个简单的问题 + input: + user_question: "什么是人工智能?" + expected_output: "关于人工智能的解释" + + - name: 技术咨询 + description: 询问技术问题 + input: + user_question: "如何学习 Python 编程?" + expected_output: "Python 学习建议"