[ADD] Merge code
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -20,7 +20,8 @@ examples/
|
|||||||
.idea
|
.idea
|
||||||
|
|
||||||
# Temporary outputs
|
# Temporary outputs
|
||||||
**/.DS_Store
|
app/core/memory/agent/.DS_Store
|
||||||
|
app/core/memory/src/utils/.DS_Store
|
||||||
time.log
|
time.log
|
||||||
celerybeat-schedule.db
|
celerybeat-schedule.db
|
||||||
search_results.json
|
search_results.json
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from . import (
|
|||||||
release_share_controller,
|
release_share_controller,
|
||||||
public_share_controller,
|
public_share_controller,
|
||||||
multi_agent_controller,
|
multi_agent_controller,
|
||||||
|
workflow_controller,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建管理端 API 路由器
|
# 创建管理端 API 路由器
|
||||||
@@ -56,5 +57,6 @@ manager_router.include_router(release_share_controller.router)
|
|||||||
manager_router.include_router(public_share_controller.router) # 公开路由(无需认证)
|
manager_router.include_router(public_share_controller.router) # 公开路由(无需认证)
|
||||||
manager_router.include_router(memory_dashboard_controller.router)
|
manager_router.include_router(memory_dashboard_controller.router)
|
||||||
manager_router.include_router(multi_agent_controller.router)
|
manager_router.include_router(multi_agent_controller.router)
|
||||||
|
manager_router.include_router(workflow_controller.router)
|
||||||
|
|
||||||
__all__ = ["manager_router"]
|
__all__ = ["manager_router"]
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""API Key 管理接口 - 基于 JWT 认证"""
|
"""API Key 管理接口 - 基于 JWT 认证"""
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -14,6 +13,7 @@ from app.core.response_utils import success
|
|||||||
from app.schemas import api_key_schema
|
from app.schemas import api_key_schema
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services.api_key_service import ApiKeyService
|
from app.services.api_key_service import ApiKeyService
|
||||||
|
from app.core.api_key_utils import timestamp_to_datetime
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.exceptions import (
|
from app.core.exceptions import (
|
||||||
BusinessException,
|
BusinessException,
|
||||||
@@ -41,18 +41,14 @@ def create_api_key(
|
|||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
# 创建 API Key
|
# 创建 API Key
|
||||||
api_key_obj, api_key = ApiKeyService.create_api_key(
|
api_key_obj = ApiKeyService.create_api_key(
|
||||||
db,
|
db,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
data=data
|
data=data
|
||||||
)
|
)
|
||||||
|
|
||||||
# 返回包含明文 Key 的响应
|
response_data = api_key_schema.ApiKeyResponse.model_validate(api_key_obj)
|
||||||
response_data = api_key_schema.ApiKeyResponse(
|
|
||||||
**api_key_obj.__dict__,
|
|
||||||
api_key=api_key
|
|
||||||
)
|
|
||||||
|
|
||||||
return success(data=response_data, msg="API Key 创建成功")
|
return success(data=response_data, msg="API Key 创建成功")
|
||||||
except BusinessException:
|
except BusinessException:
|
||||||
@@ -223,13 +219,9 @@ def regenerate_api_key(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
api_key_obj, api_key = ApiKeyService.regenerate_api_key(db, api_key_id, workspace_id)
|
api_key_obj = ApiKeyService.regenerate_api_key(db, api_key_id, workspace_id)
|
||||||
|
|
||||||
# 返回包含明文 Key 的响应
|
response_data = api_key_schema.ApiKeyResponse.model_validate(api_key_obj)
|
||||||
response_data = api_key_schema.ApiKeyResponse(
|
|
||||||
**api_key_obj.__dict__,
|
|
||||||
api_key=api_key
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("API Key 重新生成成功", extra={
|
logger.info("API Key 重新生成成功", extra={
|
||||||
"api_key_id": str(api_key_id),
|
"api_key_id": str(api_key_id),
|
||||||
@@ -283,8 +275,8 @@ def get_api_key_stats(
|
|||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
def get_api_key_logs(
|
def get_api_key_logs(
|
||||||
api_key_id: uuid.UUID,
|
api_key_id: uuid.UUID,
|
||||||
start_date: Optional[datetime] = Query(None, description="开始日期"),
|
start_date: Optional[int] = Query(None, description="开始日期时间戳"),
|
||||||
end_date: Optional[datetime] = Query(None, description="结束日期"),
|
end_date: Optional[int] = Query(None, description="结束日期时间戳"),
|
||||||
status_code: Optional[int] = Query(None, description="HTTP状态码过滤"),
|
status_code: Optional[int] = Query(None, description="HTTP状态码过滤"),
|
||||||
endpoint: Optional[str] = Query(None, description="端点路径过滤"),
|
endpoint: Optional[str] = Query(None, description="端点路径过滤"),
|
||||||
page: int = Query(1, ge=1, description="页码"),
|
page: int = Query(1, ge=1, description="页码"),
|
||||||
@@ -302,14 +294,17 @@ def get_api_key_logs(
|
|||||||
try:
|
try:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
|
start_datetime = timestamp_to_datetime(start_date) if start_date else None
|
||||||
|
end_datetime = timestamp_to_datetime(end_date) if end_date else None
|
||||||
|
|
||||||
# 验证日期范围
|
# 验证日期范围
|
||||||
if start_date and end_date and start_date > end_date:
|
if start_datetime and end_datetime and start_datetime > end_datetime:
|
||||||
logger.warning("开始日期晚于结束日期", extra={
|
logger.warning("开始日期晚于结束日期", extra={
|
||||||
"api_key_id": str(api_key_id),
|
"api_key_id": str(api_key_id),
|
||||||
"workspace_id": str(workspace_id),
|
"workspace_id": str(workspace_id),
|
||||||
"user_id": str(current_user.id),
|
"user_id": str(current_user.id),
|
||||||
"start_date": start_date.isoformat(),
|
"start_date": start_datetime.isoformat(),
|
||||||
"end_date": end_date.isoformat()
|
"end_date": end_datetime.isoformat()
|
||||||
})
|
})
|
||||||
raise BusinessException("开始日期不能晚于结束日期", BizCode.INVALID_PARAMETER)
|
raise BusinessException("开始日期不能晚于结束日期", BizCode.INVALID_PARAMETER)
|
||||||
|
|
||||||
@@ -325,8 +320,8 @@ def get_api_key_logs(
|
|||||||
|
|
||||||
# 构建过滤条件
|
# 构建过滤条件
|
||||||
filters = {
|
filters = {
|
||||||
"start_date": start_date,
|
"start_date": start_datetime,
|
||||||
"end_date": end_date,
|
"end_date": end_datetime,
|
||||||
"status_code": status_code,
|
"status_code": status_code,
|
||||||
"endpoint": endpoint
|
"endpoint": endpoint
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,22 +1,26 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Optional
|
from typing import Optional, Annotated
|
||||||
from fastapi import APIRouter, Depends
|
|
||||||
|
from fastapi import APIRouter, Depends, Path
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.db import get_db
|
from app.core.error_codes import BizCode
|
||||||
from app.core.response_utils import success
|
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||||
from app.models import User
|
from app.models import User
|
||||||
|
from app.models.app_model import AppType, App
|
||||||
from app.repositories import knowledge_repository
|
from app.repositories import knowledge_repository
|
||||||
from app.schemas import app_schema
|
from app.schemas import app_schema
|
||||||
from app.schemas.response_schema import PageData, PageMeta
|
from app.schemas.response_schema import PageData, PageMeta
|
||||||
|
from app.schemas.workflow_schema import WorkflowConfigUpdate
|
||||||
from app.services import app_service, workspace_service
|
from app.services import app_service, workspace_service
|
||||||
from app.services.app_service import AppService
|
|
||||||
from app.services.agent_config_helper import enrich_agent_config
|
from app.services.agent_config_helper import enrich_agent_config
|
||||||
from app.dependencies import get_current_user, cur_workspace_access_guard, workspace_access_guard
|
from app.services.app_service import AppService
|
||||||
from fastapi.responses import StreamingResponse
|
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
|
||||||
from app.models.app_model import AppType
|
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||||
from app.core.error_codes import BizCode
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
@@ -340,6 +344,7 @@ async def draft_run(
|
|||||||
payload: app_schema.DraftRunRequest,
|
payload: app_schema.DraftRunRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user=Depends(get_current_user),
|
current_user=Depends(get_current_user),
|
||||||
|
workflow_service: Annotated[WorkflowService, Depends(get_workflow_service)] = None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
试运行 Agent,使用当前的草稿配置(未发布的配置)
|
试运行 Agent,使用当前的草稿配置(未发布的配置)
|
||||||
@@ -374,17 +379,28 @@ async def draft_run(
|
|||||||
from app.models import AgentConfig, ModelConfig
|
from app.models import AgentConfig, ModelConfig
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.services.draft_run_service import DraftRunService
|
||||||
|
|
||||||
service = AppService(db)
|
service = AppService(db)
|
||||||
|
draft_service = DraftRunService(db)
|
||||||
|
|
||||||
# 1. 验证应用
|
# 1. 验证应用
|
||||||
app = service._get_app_or_404(app_id)
|
app = service._get_app_or_404(app_id)
|
||||||
if app.type != AppType.AGENT and app.type != AppType.MULTI_AGENT:
|
if app.type != AppType.AGENT and app.type != AppType.MULTI_AGENT and app.type != AppType.WORKFLOW:
|
||||||
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
raise BusinessException("只有 Agent , Workflow 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||||
|
|
||||||
# 只读操作,允许访问共享应用
|
# 只读操作,允许访问共享应用
|
||||||
service._validate_app_accessible(app, workspace_id)
|
service._validate_app_accessible(app, workspace_id)
|
||||||
|
|
||||||
|
# 处理会话ID(创建或验证)
|
||||||
|
conversation_id = await draft_service._ensure_conversation(
|
||||||
|
conversation_id=payload.conversation_id,
|
||||||
|
app_id=app_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
user_id=payload.user_id
|
||||||
|
)
|
||||||
|
payload.conversation_id = conversation_id
|
||||||
|
|
||||||
if app.type == AppType.AGENT:
|
if app.type == AppType.AGENT:
|
||||||
service._check_agent_config(app_id)
|
service._check_agent_config(app_id)
|
||||||
|
|
||||||
@@ -405,8 +421,8 @@ async def draft_run(
|
|||||||
# 流式返回
|
# 流式返回
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
from app.services.draft_run_service import DraftRunService
|
|
||||||
draft_service = DraftRunService(db)
|
|
||||||
async for event in draft_service.run_stream(
|
async for event in draft_service.run_stream(
|
||||||
agent_config=agent_cfg,
|
agent_config=agent_cfg,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
@@ -553,7 +569,66 @@ async def draft_run(
|
|||||||
data=result,
|
data=result,
|
||||||
msg="多 Agent 任务执行成功"
|
msg="多 Agent 任务执行成功"
|
||||||
)
|
)
|
||||||
|
elif app.type == AppType.WORKFLOW: #工作流
|
||||||
|
config = workflow_service.check_config(app_id)
|
||||||
|
# 3. 流式返回
|
||||||
|
if payload.stream:
|
||||||
|
logger.debug(
|
||||||
|
"开始多智能体流式试运行",
|
||||||
|
extra={
|
||||||
|
"app_id": str(app_id),
|
||||||
|
"message_length": len(payload.message),
|
||||||
|
"has_conversation_id": bool(payload.conversation_id)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def event_generator():
|
||||||
|
"""多智能体流式事件生成器"""
|
||||||
|
multiservice = MultiAgentService(db)
|
||||||
|
|
||||||
|
# 调用多智能体服务的流式方法
|
||||||
|
async for event in multiservice.run_stream(
|
||||||
|
app_id=app_id,
|
||||||
|
request=multi_agent_request,
|
||||||
|
storage_type=storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id
|
||||||
|
|
||||||
|
):
|
||||||
|
yield event
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
event_generator(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. 非流式返回
|
||||||
|
logger.debug(
|
||||||
|
"开始非流式试运行",
|
||||||
|
extra={
|
||||||
|
"app_id": str(app_id),
|
||||||
|
"message_length": len(payload.message),
|
||||||
|
"has_conversation_id": bool(payload.conversation_id)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await workflow_service.run(app_id, payload,config)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"工作流试运行返回结果",
|
||||||
|
extra={
|
||||||
|
"result_type": str(type(result)),
|
||||||
|
"has_response": "response" in result if isinstance(result, dict) else False
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return success(
|
||||||
|
data=result,
|
||||||
|
msg="工作流任务执行成功"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -714,3 +789,34 @@ async def draft_run_compare(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return success(data=app_schema.DraftRunCompareResponse(**result))
|
return success(data=app_schema.DraftRunCompareResponse(**result))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{app_id}/workflow")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
async def get_workflow_config(
|
||||||
|
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)]
|
||||||
|
|
||||||
|
):
|
||||||
|
"""获取工作流配置
|
||||||
|
|
||||||
|
获取应用的工作流配置详情。
|
||||||
|
"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
cfg = app_service.get_workflow_config(db=db, app_id=app_id, workspace_id=workspace_id)
|
||||||
|
# 配置总是存在(不存在时返回默认模板)
|
||||||
|
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||||
|
|
||||||
|
@router.put("/{app_id}/workflow", summary="更新 Workflow 配置")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
async def update_workflow_config(
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
payload: WorkflowConfigUpdate,
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)]
|
||||||
|
):
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||||
|
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||||
|
|
||||||
|
|||||||
@@ -29,10 +29,10 @@ router = APIRouter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{kb_id}/{parent_id}/documents", response_model=ApiResponse)
|
@router.get("/{kb_id}/documents", response_model=ApiResponse)
|
||||||
async def get_documents(
|
async def get_documents(
|
||||||
kb_id: uuid.UUID,
|
kb_id: uuid.UUID,
|
||||||
parent_id: uuid.UUID,
|
parent_id: Optional[uuid.UUID] = Query(None, description="parent folder id when type is Folder"),
|
||||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||||
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"),
|
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"),
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
"""App 服务接口 - 基于 API Key 认证"""
|
"""App 服务接口 - 基于 API Key 认证"""
|
||||||
from fastapi import APIRouter, Depends
|
import uuid
|
||||||
|
from fastapi import APIRouter, Depends, Request, Body
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.api_key_auth import require_api_key
|
||||||
|
from app.schemas.api_key_schema import ApiKeyAuth
|
||||||
|
|
||||||
router = APIRouter(prefix="/apps", tags=["V1 - App API"])
|
router = APIRouter(prefix="/apps", tags=["V1 - App API"])
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
@@ -14,3 +17,30 @@ logger = get_business_logger()
|
|||||||
async def list_apps():
|
async def list_apps():
|
||||||
"""列出可访问的应用(占位)"""
|
"""列出可访问的应用(占位)"""
|
||||||
return success(data=[], msg="App API - Coming Soon")
|
return success(data=[], msg="App API - Coming Soon")
|
||||||
|
|
||||||
|
# /v1/apps/{resource_id}/chat
|
||||||
|
@router.post("/{resource_id}/chat")
|
||||||
|
@require_api_key(scopes=["app"])
|
||||||
|
async def chat_with_agent_demo(
|
||||||
|
resource_id: uuid.UUID,
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
message: str = Body(..., description="聊天消息内容"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Agent 聊天接口demo
|
||||||
|
|
||||||
|
scopes: 所需的权限范围列表["app", "rag", "memory"]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
resource_id: 如果是应用的apikey传的是应用id; 如果是服务的apikey传的是工作空间id
|
||||||
|
message: 请求参数
|
||||||
|
request: 声明请求
|
||||||
|
api_key_auth: 包含验证后的API Key 信息
|
||||||
|
db: db_session
|
||||||
|
"""
|
||||||
|
logger.info(f"API Key Auth: {api_key_auth}")
|
||||||
|
logger.info(f"Resource ID: {resource_id}")
|
||||||
|
logger.info(f"Message: {message}")
|
||||||
|
return success(data={"received": True}, msg="消息已接收")
|
||||||
|
|||||||
587
api/app/controllers/workflow_controller.py
Normal file
587
api/app/controllers/workflow_controller.py
Normal file
@@ -0,0 +1,587 @@
|
|||||||
|
"""
|
||||||
|
工作流 API 控制器
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Path, Query
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||||
|
|
||||||
|
from app.models.user_model import User
|
||||||
|
from app.models.app_model import App
|
||||||
|
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||||
|
from app.schemas.workflow_schema import (
|
||||||
|
WorkflowConfigCreate,
|
||||||
|
WorkflowConfigUpdate,
|
||||||
|
WorkflowConfig,
|
||||||
|
WorkflowValidationResponse,
|
||||||
|
WorkflowExecution,
|
||||||
|
WorkflowNodeExecution,
|
||||||
|
WorkflowExecutionRequest,
|
||||||
|
WorkflowExecutionResponse
|
||||||
|
)
|
||||||
|
from app.core.response_utils import success, fail
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/apps", tags=["workflow"])
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 工作流配置管理 ====================
|
||||||
|
|
||||||
|
@router.post("/{app_id}/workflow")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
async def create_workflow_config(
|
||||||
|
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
|
config: WorkflowConfigCreate,
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||||
|
):
|
||||||
|
"""创建工作流配置
|
||||||
|
|
||||||
|
创建或更新应用的工作流配置。配置会进行基础验证,但允许保存不完整的配置(草稿)。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 验证应用是否存在且属于当前工作空间
|
||||||
|
app = db.query(App).filter(
|
||||||
|
App.id == app_id,
|
||||||
|
App.workspace_id == current_user.current_workspace_id,
|
||||||
|
App.is_active == True
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not app:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.NOT_FOUND,
|
||||||
|
msg="应用不存在或无权访问"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 验证应用类型
|
||||||
|
if app.type != "workflow":
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INVALID_PARAMETER,
|
||||||
|
msg=f"应用类型必须为 workflow,当前为 {app.type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建工作流配置
|
||||||
|
workflow_config = service.create_workflow_config(
|
||||||
|
app_id=app_id,
|
||||||
|
nodes=[node.model_dump() for node in config.nodes],
|
||||||
|
edges=[edge.model_dump() for edge in config.edges],
|
||||||
|
variables=[var.model_dump() for var in config.variables],
|
||||||
|
execution_config=config.execution_config.model_dump(),
|
||||||
|
triggers=[trigger.model_dump() for trigger in config.triggers],
|
||||||
|
validate=True # 进行基础验证
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(
|
||||||
|
data=WorkflowConfig.model_validate(workflow_config),
|
||||||
|
msg="工作流配置创建成功"
|
||||||
|
)
|
||||||
|
|
||||||
|
except BusinessException as e:
|
||||||
|
logger.warning(f"创建工作流配置失败: {e.message}")
|
||||||
|
return fail(code=e.error_code, msg=e.message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建工作流配置异常: {e}", exc_info=True)
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INTERNAL_ERROR,
|
||||||
|
msg=f"创建工作流配置失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
#
|
||||||
|
# @router.get("/{app_id}/workflow")
|
||||||
|
# async def get_workflow_config(
|
||||||
|
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
|
# db: Annotated[Session, Depends(get_db)],
|
||||||
|
# current_user: Annotated[User, Depends(get_current_user)]
|
||||||
|
#
|
||||||
|
# ):
|
||||||
|
# """获取工作流配置
|
||||||
|
#
|
||||||
|
# 获取应用的工作流配置详情。
|
||||||
|
# """
|
||||||
|
# try:
|
||||||
|
# # 验证应用是否存在且属于当前工作空间
|
||||||
|
# app = db.query(App).filter(
|
||||||
|
# App.id == app_id,
|
||||||
|
# App.workspace_id == current_user.current_workspace_id,
|
||||||
|
# App.is_active == True
|
||||||
|
# ).first()
|
||||||
|
#
|
||||||
|
# if not app:
|
||||||
|
# return fail(
|
||||||
|
# code=BizCode.NOT_FOUND,
|
||||||
|
# msg="应用不存在或无权访问"
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# # 获取工作流配置
|
||||||
|
# service = WorkflowService(db)
|
||||||
|
# workflow_config = service.get_workflow_config(app_id)
|
||||||
|
#
|
||||||
|
# if not workflow_config:
|
||||||
|
# return fail(
|
||||||
|
# code=BizCode.NOT_FOUND,
|
||||||
|
# msg="工作流配置不存在"
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# return success(
|
||||||
|
# data=WorkflowConfig.model_validate(workflow_config)
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.error(f"获取工作流配置异常: {e}", exc_info=True)
|
||||||
|
# return fail(
|
||||||
|
# code=BizCode.INTERNAL_ERROR,
|
||||||
|
# msg=f"获取工作流配置失败: {str(e)}"
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
# @router.put("/{app_id}/workflow")
|
||||||
|
# async def update_workflow_config(
|
||||||
|
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
|
# config: WorkflowConfigUpdate,
|
||||||
|
# db: Annotated[Session, Depends(get_db)],
|
||||||
|
# current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
# service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||||
|
# ):
|
||||||
|
# """更新工作流配置
|
||||||
|
|
||||||
|
# 更新应用的工作流配置。可以部分更新,未提供的字段保持不变。
|
||||||
|
# """
|
||||||
|
# try:
|
||||||
|
# # 验证应用是否存在且属于当前工作空间
|
||||||
|
# app = db.query(App).filter(
|
||||||
|
# App.id == app_id,
|
||||||
|
# App.workspace_id == current_user.current_workspace_id,
|
||||||
|
# App.is_active == True
|
||||||
|
# ).first()
|
||||||
|
|
||||||
|
# if not app:
|
||||||
|
# return fail(
|
||||||
|
# code=BizCode.NOT_FOUND,
|
||||||
|
# msg="应用不存在或无权访问"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # 更新工作流配置
|
||||||
|
# workflow_config = service.update_workflow_config(
|
||||||
|
# app_id=app_id,
|
||||||
|
# nodes=[node.model_dump() for node in config.nodes] if config.nodes else None,
|
||||||
|
# edges=[edge.model_dump() for edge in config.edges] if config.edges else None,
|
||||||
|
# variables=[var.model_dump() for var in config.variables] if config.variables else None,
|
||||||
|
# execution_config=config.execution_config.model_dump() if config.execution_config else None,
|
||||||
|
# triggers=[trigger.model_dump() for trigger in config.triggers] if config.triggers else None,
|
||||||
|
# validate=True
|
||||||
|
# )
|
||||||
|
|
||||||
|
# return success(
|
||||||
|
# data=WorkflowConfig.model_validate(workflow_config),
|
||||||
|
# msg="工作流配置更新成功"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# except BusinessException as e:
|
||||||
|
# logger.warning(f"更新工作流配置失败: {e.message}")
|
||||||
|
# return fail(code=e.error_code, msg=e.message)
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.error(f"更新工作流配置异常: {e}", exc_info=True)
|
||||||
|
# return fail(
|
||||||
|
# code=BizCode.INTERNAL_ERROR,
|
||||||
|
# msg=f"更新工作流配置失败: {str(e)}"
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{app_id}/workflow")
|
||||||
|
async def delete_workflow_config(
|
||||||
|
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||||
|
):
|
||||||
|
"""删除工作流配置
|
||||||
|
|
||||||
|
删除应用的工作流配置。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 验证应用是否存在且属于当前工作空间
|
||||||
|
app = db.query(App).filter(
|
||||||
|
App.id == app_id,
|
||||||
|
App.workspace_id == current_user.current_workspace_id,
|
||||||
|
App.is_active == True
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not app:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.NOT_FOUND,
|
||||||
|
msg="应用不存在或无权访问"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 删除工作流配置
|
||||||
|
deleted = service.delete_workflow_config(app_id)
|
||||||
|
|
||||||
|
if not deleted:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.NOT_FOUND,
|
||||||
|
msg="工作流配置不存在"
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(msg="工作流配置删除成功")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除工作流配置异常: {e}", exc_info=True)
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INTERNAL_ERROR,
|
||||||
|
msg=f"删除工作流配置失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{app_id}/workflow/validate")
|
||||||
|
async def validate_workflow_config(
|
||||||
|
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||||
|
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
|
||||||
|
):
|
||||||
|
"""验证工作流配置
|
||||||
|
|
||||||
|
验证工作流配置是否有效。可以选择是否进行发布级别的严格验证。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 验证应用是否存在且属于当前工作空间
|
||||||
|
app = db.query(App).filter(
|
||||||
|
App.id == app_id,
|
||||||
|
App.workspace_id == current_user.current_workspace_id,
|
||||||
|
App.is_active == True
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not app:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.NOT_FOUND,
|
||||||
|
msg="应用不存在或无权访问"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 验证工作流配置
|
||||||
|
|
||||||
|
if for_publish:
|
||||||
|
is_valid, errors = service.validate_workflow_config_for_publish(app_id)
|
||||||
|
else:
|
||||||
|
workflow_config = service.get_workflow_config(app_id)
|
||||||
|
if not workflow_config:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.NOT_FOUND,
|
||||||
|
msg="工作流配置不存在"
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.core.workflow.validator import validate_workflow_config as validate_config
|
||||||
|
config_dict = {
|
||||||
|
"nodes": workflow_config.nodes,
|
||||||
|
"edges": workflow_config.edges,
|
||||||
|
"variables": workflow_config.variables,
|
||||||
|
"execution_config": workflow_config.execution_config,
|
||||||
|
"triggers": workflow_config.triggers
|
||||||
|
}
|
||||||
|
is_valid, errors = validate_config(config_dict, for_publish=False)
|
||||||
|
|
||||||
|
return success(
|
||||||
|
data=WorkflowValidationResponse(
|
||||||
|
is_valid=is_valid,
|
||||||
|
errors=errors,
|
||||||
|
warnings=[]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
except BusinessException as e:
|
||||||
|
logger.warning(f"验证工作流配置失败: {e.message}")
|
||||||
|
return fail(code=e.error_code, msg=e.message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"验证工作流配置异常: {e}", exc_info=True)
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INTERNAL_ERROR,
|
||||||
|
msg=f"验证工作流配置失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 工作流执行管理 ====================
|
||||||
|
|
||||||
|
@router.get("/{app_id}/workflow/executions")
|
||||||
|
async def get_workflow_executions(
|
||||||
|
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||||
|
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||||
|
offset: Annotated[int, Query(ge=0)] = 0
|
||||||
|
):
|
||||||
|
"""获取工作流执行记录列表
|
||||||
|
|
||||||
|
获取应用的工作流执行历史记录。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 验证应用是否存在且属于当前工作空间
|
||||||
|
app = db.query(App).filter(
|
||||||
|
App.id == app_id,
|
||||||
|
App.workspace_id == current_user.current_workspace_id,
|
||||||
|
App.is_active == True
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not app:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.NOT_FOUND,
|
||||||
|
msg="应用不存在或无权访问"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取执行记录
|
||||||
|
executions = service.get_executions_by_app(app_id, limit, offset)
|
||||||
|
|
||||||
|
# 获取统计信息
|
||||||
|
statistics = service.get_execution_statistics(app_id)
|
||||||
|
|
||||||
|
return success(
|
||||||
|
data={
|
||||||
|
"executions": [WorkflowExecution.model_validate(e) for e in executions],
|
||||||
|
"statistics": statistics,
|
||||||
|
"pagination": {
|
||||||
|
"limit": limit,
|
||||||
|
"offset": offset,
|
||||||
|
"total": statistics["total"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取工作流执行记录异常: {e}", exc_info=True)
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INTERNAL_ERROR,
|
||||||
|
msg=f"获取工作流执行记录失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/workflow/executions/{execution_id}")
|
||||||
|
async def get_workflow_execution(
|
||||||
|
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||||
|
):
|
||||||
|
"""获取工作流执行详情
|
||||||
|
|
||||||
|
获取单个工作流执行的详细信息,包括所有节点的执行记录。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 获取执行记录
|
||||||
|
execution = service.get_execution(execution_id)
|
||||||
|
|
||||||
|
if not execution:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.NOT_FOUND,
|
||||||
|
msg="执行记录不存在"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 验证应用是否属于当前工作空间
|
||||||
|
app = db.query(App).filter(
|
||||||
|
App.id == execution.app_id,
|
||||||
|
App.workspace_id == current_user.current_workspace_id,
|
||||||
|
App.is_active == True
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not app:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.NOT_FOUND,
|
||||||
|
msg="无权访问该执行记录"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取节点执行记录
|
||||||
|
node_executions = service.node_execution_repo.get_by_execution_id(execution.id)
|
||||||
|
|
||||||
|
return success(
|
||||||
|
data={
|
||||||
|
"execution": WorkflowExecution.model_validate(execution),
|
||||||
|
"node_executions": [
|
||||||
|
WorkflowNodeExecution.model_validate(ne) for ne in node_executions
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取工作流执行详情异常: {e}", exc_info=True)
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INTERNAL_ERROR,
|
||||||
|
msg=f"获取工作流执行详情失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 工作流执行 ====================
|
||||||
|
|
||||||
|
@router.post("/{app_id}/workflow/run")
|
||||||
|
async def run_workflow(
|
||||||
|
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||||
|
request: WorkflowExecutionRequest,
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||||
|
):
|
||||||
|
"""执行工作流
|
||||||
|
|
||||||
|
执行工作流并返回结果。支持流式和非流式两种模式。
|
||||||
|
|
||||||
|
**非流式模式**:等待工作流执行完成后返回完整结果。
|
||||||
|
|
||||||
|
**流式模式**:实时返回执行过程中的事件(节点开始、节点完成、工作流完成等)。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 验证应用是否存在且属于当前工作空间
|
||||||
|
app = db.query(App).filter(
|
||||||
|
App.id == app_id,
|
||||||
|
App.workspace_id == current_user.current_workspace_id,
|
||||||
|
App.is_active == True
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not app:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.NOT_FOUND,
|
||||||
|
msg="应用不存在或无权访问"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 验证应用类型
|
||||||
|
if app.type != "workflow":
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INVALID_PARAMETER,
|
||||||
|
msg=f"应用类型必须为 workflow,当前为 {app.type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 准备输入数据
|
||||||
|
input_data = {
|
||||||
|
"message": request.message or "",
|
||||||
|
"variables": request.variables
|
||||||
|
}
|
||||||
|
|
||||||
|
# 执行工作流
|
||||||
|
|
||||||
|
if request.stream:
|
||||||
|
# 流式执行
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
import json
|
||||||
|
|
||||||
|
async def event_generator():
|
||||||
|
"""生成 SSE 事件"""
|
||||||
|
try:
|
||||||
|
async for event in service.run_workflow(
|
||||||
|
app_id=app_id,
|
||||||
|
input_data=input_data,
|
||||||
|
triggered_by=current_user.id,
|
||||||
|
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||||
|
stream=True
|
||||||
|
):
|
||||||
|
# 转换为 SSE 格式
|
||||||
|
yield f"data: {json.dumps(event)}\n\n"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"流式执行异常: {e}", exc_info=True)
|
||||||
|
error_event = {
|
||||||
|
"type": "error",
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(error_event)}\n\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
event_generator(),
|
||||||
|
media_type="text/event-stream"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 非流式执行
|
||||||
|
result = await service.run_workflow(
|
||||||
|
app_id=app_id,
|
||||||
|
input_data=input_data,
|
||||||
|
triggered_by=current_user.id,
|
||||||
|
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||||
|
stream=False
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(
|
||||||
|
data=WorkflowExecutionResponse(
|
||||||
|
execution_id=result["execution_id"],
|
||||||
|
status=result["status"],
|
||||||
|
output=result.get("output"),
|
||||||
|
output_data=result.get("output_data"),
|
||||||
|
error_message=result.get("error_message"),
|
||||||
|
elapsed_time=result.get("elapsed_time"),
|
||||||
|
token_usage=result.get("token_usage")
|
||||||
|
),
|
||||||
|
msg="工作流执行完成"
|
||||||
|
)
|
||||||
|
|
||||||
|
except BusinessException as e:
|
||||||
|
logger.warning(f"执行工作流失败: {e.message}")
|
||||||
|
return fail(code=e.error_code, msg=e.message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"执行工作流异常: {e}", exc_info=True)
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INTERNAL_ERROR,
|
||||||
|
msg=f"执行工作流失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/workflow/executions/{execution_id}/cancel")
|
||||||
|
async def cancel_workflow_execution(
|
||||||
|
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||||
|
db: Annotated[Session, Depends(get_db)],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||||
|
):
|
||||||
|
"""取消工作流执行
|
||||||
|
|
||||||
|
取消正在运行的工作流执行。
|
||||||
|
|
||||||
|
**注意**:当前版本仅更新状态为 cancelled,实际的执行取消功能待实现。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 获取执行记录
|
||||||
|
execution = service.get_execution(execution_id)
|
||||||
|
|
||||||
|
if not execution:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.NOT_FOUND,
|
||||||
|
msg="执行记录不存在"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 验证应用是否属于当前工作空间
|
||||||
|
app = db.query(App).filter(
|
||||||
|
App.id == execution.app_id,
|
||||||
|
App.workspace_id == current_user.current_workspace_id,
|
||||||
|
App.is_active == True
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not app:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.NOT_FOUND,
|
||||||
|
msg="无权访问该执行记录"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查执行状态
|
||||||
|
if execution.status not in ["pending", "running"]:
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INVALID_PARAMETER,
|
||||||
|
msg=f"无法取消状态为 {execution.status} 的执行"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新状态为 cancelled
|
||||||
|
service.update_execution_status(execution_id, "cancelled")
|
||||||
|
|
||||||
|
return success(msg="工作流执行已取消")
|
||||||
|
|
||||||
|
except BusinessException as e:
|
||||||
|
logger.warning(f"取消工作流执行失败: {e.message}")
|
||||||
|
return fail(code=e.error_code, msg=e.message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"取消工作流执行异常: {e}", exc_info=True)
|
||||||
|
return fail(
|
||||||
|
code=BizCode.INTERNAL_ERROR,
|
||||||
|
msg=f"取消工作流执行失败: {str(e)}"
|
||||||
|
)
|
||||||
@@ -1,10 +1,12 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from fastapi import Request, Response
|
from fastapi import Request, Response
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.api_key_utils import add_rate_limit_headers
|
from app.core.api_key_utils import add_rate_limit_headers
|
||||||
@@ -22,21 +24,17 @@ logger = get_api_logger()
|
|||||||
|
|
||||||
|
|
||||||
def require_api_key(
|
def require_api_key(
|
||||||
scopes: Optional[List[str]] = None,
|
scopes: Optional[List[str]] = None
|
||||||
resource_type: Optional[str] = None
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
API Key 鉴权装饰器
|
API Key 鉴权装饰器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
scopes: 所需的权限范围列表["app:all",
|
scopes: 所需的权限范围列表[“app”, "rag", "memory"]
|
||||||
"rag:search", "rag:upload", "rag:delete",
|
|
||||||
"memory:read", "memory:write", "memory:delete", "memory:search"]
|
|
||||||
resource_type: 所需的资源类型("Agent", "Cluster", "Workflow", "Knowledge", "Memory_Engine")
|
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
@router.get("/app/{resource_id}/chat")
|
@router.get("/app/{resource_id}/chat")
|
||||||
@require_api_key(scopes=["app:all"], resource_type="Agent")
|
@require_api_key(scopes=["app"])
|
||||||
def chat_with_app(
|
def chat_with_app(
|
||||||
resource_id: uuid.UUID,
|
resource_id: uuid.UUID,
|
||||||
api_key_auth: ApiKeyAuth = Depends(),
|
api_key_auth: ApiKeyAuth = Depends(),
|
||||||
@@ -113,31 +111,25 @@ def require_api_key(
|
|||||||
context={"required_scopes": scopes, "missing_scopes": missing_scopes}
|
context={"required_scopes": scopes, "missing_scopes": missing_scopes}
|
||||||
)
|
)
|
||||||
|
|
||||||
if resource_type:
|
resource_id = kwargs.get("resource_id")
|
||||||
resource_id = kwargs.get("resource_id")
|
if resource_id and not ApiKeyAuthService.check_resource(
|
||||||
if resource_id and not ApiKeyAuthService.check_resource(
|
api_key_obj,
|
||||||
api_key_obj,
|
resource_id
|
||||||
resource_type,
|
):
|
||||||
resource_id
|
logger.warning("API Key 资源访问被拒绝", extra={
|
||||||
):
|
"api_key_id": str(api_key_obj.id),
|
||||||
logger.warning("API Key 资源访问被拒绝", extra={
|
"required_resource_id": str(resource_id),
|
||||||
"api_key_id": str(api_key_obj.id),
|
"bound_resource_id": str(api_key_obj.resource_id) if api_key_obj.resource_id else None,
|
||||||
"required_resource_type": resource_type,
|
"endpoint": str(request.url)
|
||||||
|
})
|
||||||
|
return BusinessException(
|
||||||
|
"API Key 未授权访问该资源",
|
||||||
|
BizCode.API_KEY_INVALID_RESOURCE,
|
||||||
|
context={
|
||||||
"required_resource_id": str(resource_id),
|
"required_resource_id": str(resource_id),
|
||||||
"bound_resource_type": api_key_obj.resource_type,
|
"bound_resource_id": str(api_key_obj.resource_id)
|
||||||
"bound_resource_id": str(api_key_obj.resource_id) if api_key_obj.resource_id else None,
|
}
|
||||||
"endpoint": str(request.url)
|
)
|
||||||
})
|
|
||||||
return BusinessException(
|
|
||||||
"API Key 未授权访问该资源",
|
|
||||||
BizCode.API_KEY_INVALID_RESOURCE,
|
|
||||||
context={
|
|
||||||
"required_resource_type": resource_type,
|
|
||||||
"required_resource_id": str(resource_id),
|
|
||||||
"bound_resource_type": api_key_obj.resource_type,
|
|
||||||
"bound_resource_id": str(api_key_obj.resource_id) if api_key_obj.resource_id else None
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
kwargs["api_key_auth"] = ApiKeyAuth(
|
kwargs["api_key_auth"] = ApiKeyAuth(
|
||||||
api_key_id=api_key_obj.id,
|
api_key_id=api_key_obj.id,
|
||||||
@@ -145,14 +137,17 @@ def require_api_key(
|
|||||||
type=api_key_obj.type,
|
type=api_key_obj.type,
|
||||||
scopes=api_key_obj.scopes,
|
scopes=api_key_obj.scopes,
|
||||||
resource_id=api_key_obj.resource_id,
|
resource_id=api_key_obj.resource_id,
|
||||||
resource_type=api_key_obj.resource_type
|
|
||||||
)
|
)
|
||||||
|
start_time = time.perf_counter()
|
||||||
response = await func(*args, **kwargs)
|
response = await func(*args, **kwargs)
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
response_time = (end_time - start_time) * 1000
|
||||||
|
if not isinstance(response, Response):
|
||||||
|
response = JSONResponse(content=response)
|
||||||
response = add_rate_limit_headers(response, rate_headers)
|
response = add_rate_limit_headers(response, rate_headers)
|
||||||
|
|
||||||
asyncio.create_task(log_api_key_usage(
|
asyncio.create_task(log_api_key_usage(
|
||||||
db, api_key_obj.id, request, response
|
db, api_key_obj.id, request, response, response_time
|
||||||
))
|
))
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@@ -204,7 +199,8 @@ async def log_api_key_usage(
|
|||||||
db: Session,
|
db: Session,
|
||||||
api_key_id: uuid.UUID,
|
api_key_id: uuid.UUID,
|
||||||
request: Request,
|
request: Request,
|
||||||
response: Response
|
response: Response,
|
||||||
|
response_time: float
|
||||||
):
|
):
|
||||||
"""记录 API Key 使用日志"""
|
"""记录 API Key 使用日志"""
|
||||||
try:
|
try:
|
||||||
@@ -216,8 +212,8 @@ async def log_api_key_usage(
|
|||||||
"ip_address": request.client.host if request.client else None,
|
"ip_address": request.client.host if request.client else None,
|
||||||
"user_agent": request.headers.get("User-Agent"),
|
"user_agent": request.headers.get("User-Agent"),
|
||||||
"status_code": response.status_code if hasattr(response, "status_code") else None,
|
"status_code": response.status_code if hasattr(response, "status_code") else None,
|
||||||
"response_time": None, # 需要在 middleware 中计算
|
"response_time": round(response_time),
|
||||||
"tokens_used": None, # 需要从响应中提取
|
"tokens_used": None,
|
||||||
"created_at": datetime.now()
|
"created_at": datetime.now()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,33 +1,14 @@
|
|||||||
"""API Key 工具函数"""
|
"""API Key 工具函数"""
|
||||||
import secrets
|
import secrets
|
||||||
import hashlib
|
from typing import Optional, Union
|
||||||
from typing import Optional
|
from datetime import datetime
|
||||||
|
|
||||||
from app.schemas.api_key_schema import ApiKeyType
|
from app.schemas.api_key_schema import ApiKeyType
|
||||||
from fastapi import Response
|
from fastapi import Response
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
|
||||||
class ResourceType:
|
def generate_api_key(key_type: ApiKeyType) -> str:
|
||||||
"""资源类型常量"""
|
|
||||||
AGENT = "Agent"
|
|
||||||
CLUSTER = "Cluster"
|
|
||||||
WORKFLOW = "Workflow"
|
|
||||||
KNOWLEDGE = "Knowledge"
|
|
||||||
MEMORY_ENGINE = "Memory_Engine"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_all_types(cls) -> list[str]:
|
|
||||||
"""获取所有支持的资源类型"""
|
|
||||||
return [cls.AGENT, cls.CLUSTER, cls.WORKFLOW, cls.KNOWLEDGE, cls.MEMORY_ENGINE]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def is_valid_type(cls, resource_type: str) -> bool:
|
|
||||||
"""验证资源类型是否有效"""
|
|
||||||
return resource_type in cls.get_all_types()
|
|
||||||
|
|
||||||
|
|
||||||
def generate_api_key(key_type: ApiKeyType) -> tuple[str, str, str]:
|
|
||||||
"""
|
"""
|
||||||
生成 API Key
|
生成 API Key
|
||||||
|
|
||||||
@@ -39,102 +20,17 @@ def generate_api_key(key_type: ApiKeyType) -> tuple[str, str, str]:
|
|||||||
"""
|
"""
|
||||||
# 前缀映射
|
# 前缀映射
|
||||||
prefix_map = {
|
prefix_map = {
|
||||||
ApiKeyType.APP: "sk-app-",
|
ApiKeyType.AGENT: "sk-agent-",
|
||||||
ApiKeyType.RAG: "sk-rag-",
|
ApiKeyType.CLUSTER: "sk-cluster-",
|
||||||
ApiKeyType.MEMORY: "sk-mem-",
|
ApiKeyType.WORKFLOW: "sk-workflow-",
|
||||||
|
ApiKeyType.SERVICE: "sk-service-"
|
||||||
}
|
}
|
||||||
|
|
||||||
prefix = prefix_map[key_type]
|
prefix = prefix_map[key_type]
|
||||||
random_string = secrets.token_urlsafe(32)[:32] # 32 字符
|
random_string = secrets.token_urlsafe(32)[:32] # 32 字符
|
||||||
api_key = f"{prefix}{random_string}"
|
api_key = f"{prefix}{random_string}"
|
||||||
|
|
||||||
# 生成哈希值存储
|
return api_key
|
||||||
key_hash = hash_api_key(api_key)
|
|
||||||
|
|
||||||
return api_key, key_hash, prefix
|
|
||||||
|
|
||||||
|
|
||||||
def hash_api_key(api_key: str) -> str:
|
|
||||||
"""对 API Key 进行哈希
|
|
||||||
|
|
||||||
Args:
|
|
||||||
api_key: API Key 明文
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 哈希值
|
|
||||||
"""
|
|
||||||
return hashlib.sha256(api_key.encode()).hexdigest()
|
|
||||||
|
|
||||||
|
|
||||||
def verify_api_key(api_key: str, key_hash: str) -> bool:
|
|
||||||
"""
|
|
||||||
验证 API Key
|
|
||||||
|
|
||||||
Args:
|
|
||||||
api_key: API Key 明文
|
|
||||||
key_hash: 存储的哈希值
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否匹配
|
|
||||||
"""
|
|
||||||
computed_hash = hash_api_key(api_key)
|
|
||||||
return secrets.compare_digest(computed_hash, key_hash)
|
|
||||||
|
|
||||||
|
|
||||||
def validate_resource_binding(
|
|
||||||
resource_type: Optional[str],
|
|
||||||
resource_id: Optional[str]
|
|
||||||
) -> tuple[bool, str]:
|
|
||||||
"""
|
|
||||||
验证资源绑定的有效性
|
|
||||||
|
|
||||||
Args:
|
|
||||||
resource_type: 资源类型
|
|
||||||
resource_id: 资源ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (是否有效, 错误信息)
|
|
||||||
"""
|
|
||||||
# 如果都为空,表示不绑定资源,这是有效的
|
|
||||||
if not resource_type and not resource_id:
|
|
||||||
return True, ""
|
|
||||||
|
|
||||||
# 如果只有一个为空,这是无效的
|
|
||||||
if not resource_type or not resource_id:
|
|
||||||
return False, "resource_type 和 resource_id 必须同时提供或同时为空"
|
|
||||||
|
|
||||||
# 验证资源类型是否支持
|
|
||||||
if not ResourceType.is_valid_type(resource_type):
|
|
||||||
valid_types = ", ".join(ResourceType.get_all_types())
|
|
||||||
return False, f"不支持的资源类型 '{resource_type}',支持的类型:{valid_types}"
|
|
||||||
|
|
||||||
return True, ""
|
|
||||||
|
|
||||||
|
|
||||||
def get_resource_scope_mapping() -> dict[str, list[str]]:
|
|
||||||
"""
|
|
||||||
获取资源类型与权限范围的映射关系
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: 资源类型到推荐权限范围的映射
|
|
||||||
"""
|
|
||||||
return {
|
|
||||||
ResourceType.AGENT: [
|
|
||||||
"app:all"
|
|
||||||
],
|
|
||||||
ResourceType.CLUSTER: [
|
|
||||||
"app:all"
|
|
||||||
],
|
|
||||||
ResourceType.WORKFLOW: [
|
|
||||||
"app:all"
|
|
||||||
],
|
|
||||||
ResourceType.KNOWLEDGE: [
|
|
||||||
"rag:search", "rag:upload", "rag:delete"
|
|
||||||
],
|
|
||||||
ResourceType.MEMORY_ENGINE: [
|
|
||||||
"memory:read", "memory:write", "memory:delete", "memory:search"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def add_rate_limit_headers(response, headers: dict):
|
def add_rate_limit_headers(response, headers: dict):
|
||||||
@@ -151,3 +47,21 @@ def add_rate_limit_headers(response, headers: dict):
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def timestamp_to_datetime(timestamp: Optional[Union[int, float]]) -> Optional[datetime]:
|
||||||
|
"""将时间戳转换为datetime对象"""
|
||||||
|
if timestamp is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 处理毫秒级时间戳
|
||||||
|
if timestamp > 1e10:
|
||||||
|
timestamp = timestamp / 1000
|
||||||
|
|
||||||
|
return datetime.fromtimestamp(timestamp)
|
||||||
|
|
||||||
|
|
||||||
|
def datetime_to_timestamp(dt: Optional[datetime]) -> Optional[int]:
|
||||||
|
"""将datetime对象转换为时间戳(毫秒)"""
|
||||||
|
if dt is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return int(dt.timestamp() * 1000)
|
||||||
|
|||||||
@@ -59,6 +59,7 @@ class BizCode(IntEnum):
|
|||||||
EMBED_NOT_ALLOWED = 6009
|
EMBED_NOT_ALLOWED = 6009
|
||||||
PERMISSION_DENIED = 6010
|
PERMISSION_DENIED = 6010
|
||||||
INVALID_CONVERSATION = 6011
|
INVALID_CONVERSATION = 6011
|
||||||
|
CONFIG_MISSING = 6012
|
||||||
|
|
||||||
# 模型(7xxx)
|
# 模型(7xxx)
|
||||||
MODEL_CONFIG_INVALID = 7001
|
MODEL_CONFIG_INVALID = 7001
|
||||||
|
|||||||
436
api/app/core/workflow/executor.py
Normal file
436
api/app/core/workflow/executor.py
Normal file
@@ -0,0 +1,436 @@
|
|||||||
|
"""
|
||||||
|
工作流执行器
|
||||||
|
|
||||||
|
基于 LangGraph 的工作流执行引擎。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
from langgraph.graph import StateGraph, START, END
|
||||||
|
|
||||||
|
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||||
|
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||||
|
from app.models.workflow_model import WorkflowExecution, WorkflowNodeExecution
|
||||||
|
from app.db import get_db
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowExecutor:
|
||||||
|
"""工作流执行器
|
||||||
|
|
||||||
|
负责将工作流配置转换为 LangGraph 并执行。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
workflow_config: dict[str, Any],
|
||||||
|
execution_id: str,
|
||||||
|
workspace_id: str,
|
||||||
|
user_id: str
|
||||||
|
):
|
||||||
|
"""初始化执行器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_config: 工作流配置
|
||||||
|
execution_id: 执行 ID
|
||||||
|
workspace_id: 工作空间 ID
|
||||||
|
user_id: 用户 ID
|
||||||
|
"""
|
||||||
|
self.workflow_config = workflow_config
|
||||||
|
self.execution_id = execution_id
|
||||||
|
self.workspace_id = workspace_id
|
||||||
|
self.user_id = user_id
|
||||||
|
self.nodes = workflow_config.get("nodes", [])
|
||||||
|
self.edges = workflow_config.get("edges", [])
|
||||||
|
self.execution_config = workflow_config.get("execution_config", {})
|
||||||
|
|
||||||
|
def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState:
|
||||||
|
"""准备初始状态(注入系统变量和会话变量)
|
||||||
|
|
||||||
|
变量命名空间:
|
||||||
|
- sys.xxx - 系统变量(execution_id, workspace_id, user_id, message, input_variables 等)
|
||||||
|
- conv.xxx - 会话变量(跨多轮对话保持)
|
||||||
|
- node_id.xxx - 节点输出(执行时动态生成)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_data: 输入数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
初始化的工作流状态
|
||||||
|
"""
|
||||||
|
user_message = input_data.get("message") or ""
|
||||||
|
conversation_vars = input_data.get("conversation_vars") or {}
|
||||||
|
input_variables = input_data.get("variables") or {} # Start 节点的自定义变量
|
||||||
|
|
||||||
|
# 构建分层的变量结构
|
||||||
|
variables = {
|
||||||
|
"sys": {
|
||||||
|
"message": user_message, # 用户消息
|
||||||
|
"conversation_id": input_data.get("conversation_id"), # 会话 ID
|
||||||
|
"execution_id": self.execution_id, # 执行 ID
|
||||||
|
"workspace_id": self.workspace_id, # 工作空间 ID
|
||||||
|
"user_id": self.user_id, # 用户 ID
|
||||||
|
"input_variables": input_variables, # 自定义输入变量(给 Start 节点使用)
|
||||||
|
},
|
||||||
|
"conv": conversation_vars # 会话级变量(跨多轮对话保持)
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"messages": [HumanMessage(content=user_message)],
|
||||||
|
"variables": variables,
|
||||||
|
"node_outputs": {},
|
||||||
|
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)
|
||||||
|
"execution_id": self.execution_id,
|
||||||
|
"workspace_id": self.workspace_id,
|
||||||
|
"user_id": self.user_id,
|
||||||
|
"error": None,
|
||||||
|
"error_node": None
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def build_graph(self) -> StateGraph:
|
||||||
|
"""构建 LangGraph
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
编译后的状态图
|
||||||
|
"""
|
||||||
|
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
|
||||||
|
|
||||||
|
# 1. 创建状态图
|
||||||
|
workflow = StateGraph(WorkflowState)
|
||||||
|
|
||||||
|
# 2. 添加所有节点(包括 start 和 end)
|
||||||
|
start_node_id = None
|
||||||
|
end_node_ids = []
|
||||||
|
|
||||||
|
for node in self.nodes:
|
||||||
|
node_type = node.get("type")
|
||||||
|
node_id = node.get("id")
|
||||||
|
|
||||||
|
# 记录 start 和 end 节点 ID
|
||||||
|
if node_type == "start":
|
||||||
|
start_node_id = node_id
|
||||||
|
elif node_type == "end":
|
||||||
|
end_node_ids.append(node_id)
|
||||||
|
|
||||||
|
# 创建节点实例(现在 start 和 end 也会被创建)
|
||||||
|
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
||||||
|
if node_instance:
|
||||||
|
# 包装节点的 run 方法
|
||||||
|
# 使用函数工厂避免闭包问题
|
||||||
|
def make_node_func(inst):
|
||||||
|
async def node_func(state: WorkflowState):
|
||||||
|
return await inst.run(state)
|
||||||
|
return node_func
|
||||||
|
|
||||||
|
workflow.add_node(node_id, make_node_func(node_instance))
|
||||||
|
logger.debug(f"添加节点: {node_id} (type={node_type})")
|
||||||
|
|
||||||
|
# 3. 添加边
|
||||||
|
# 从 START 连接到 start 节点
|
||||||
|
if start_node_id:
|
||||||
|
workflow.add_edge(START, start_node_id)
|
||||||
|
logger.debug(f"添加边: START -> {start_node_id}")
|
||||||
|
|
||||||
|
for edge in self.edges:
|
||||||
|
source = edge.get("source")
|
||||||
|
target = edge.get("target")
|
||||||
|
edge_type = edge.get("type")
|
||||||
|
condition = edge.get("condition")
|
||||||
|
|
||||||
|
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start)
|
||||||
|
if source == start_node_id:
|
||||||
|
# 但要连接 start 到下一个节点
|
||||||
|
workflow.add_edge(source, target)
|
||||||
|
logger.debug(f"添加边: {source} -> {target}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 处理到 end 节点的边
|
||||||
|
if target in end_node_ids:
|
||||||
|
# 连接到 end 节点
|
||||||
|
workflow.add_edge(source, target)
|
||||||
|
logger.debug(f"添加边: {source} -> {target}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 跳过错误边(在节点内部处理)
|
||||||
|
if edge_type == "error":
|
||||||
|
continue
|
||||||
|
|
||||||
|
if condition:
|
||||||
|
# 条件边
|
||||||
|
def router(state: WorkflowState, cond=condition, tgt=target):
|
||||||
|
"""条件路由函数"""
|
||||||
|
if evaluate_condition(
|
||||||
|
cond,
|
||||||
|
state.get("variables", {}),
|
||||||
|
state.get("node_outputs", {}),
|
||||||
|
{
|
||||||
|
"execution_id": state.get("execution_id"),
|
||||||
|
"workspace_id": state.get("workspace_id"),
|
||||||
|
"user_id": state.get("user_id")
|
||||||
|
}
|
||||||
|
):
|
||||||
|
return tgt
|
||||||
|
return END # 条件不满足,结束
|
||||||
|
|
||||||
|
workflow.add_conditional_edges(source, router)
|
||||||
|
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
|
||||||
|
else:
|
||||||
|
# 普通边
|
||||||
|
workflow.add_edge(source, target)
|
||||||
|
logger.debug(f"添加边: {source} -> {target}")
|
||||||
|
|
||||||
|
# 从 end 节点连接到 END
|
||||||
|
for end_node_id in end_node_ids:
|
||||||
|
workflow.add_edge(end_node_id, END)
|
||||||
|
logger.debug(f"添加边: {end_node_id} -> END")
|
||||||
|
|
||||||
|
# 4. 编译图
|
||||||
|
graph = workflow.compile()
|
||||||
|
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
|
||||||
|
|
||||||
|
return graph
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
input_data: dict[str, Any]
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""执行工作流(非流式)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_data: 输入数据,包含 message 和 variables
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行结果,包含 status, output, node_outputs, elapsed_time, token_usage
|
||||||
|
"""
|
||||||
|
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
|
||||||
|
|
||||||
|
# 记录开始时间
|
||||||
|
start_time = datetime.datetime.now()
|
||||||
|
|
||||||
|
# 1. 构建图
|
||||||
|
graph = self.build_graph()
|
||||||
|
|
||||||
|
# 2. 初始化状态(自动注入系统变量)
|
||||||
|
initial_state = self._prepare_initial_state(input_data)
|
||||||
|
|
||||||
|
# 3. 执行工作流
|
||||||
|
try:
|
||||||
|
result = await graph.ainvoke(initial_state)
|
||||||
|
|
||||||
|
# 计算耗时
|
||||||
|
end_time = datetime.datetime.now()
|
||||||
|
elapsed_time = (end_time - start_time).total_seconds()
|
||||||
|
|
||||||
|
# 提取节点输出(现在包含 start 和 end 节点)
|
||||||
|
node_outputs = result.get("node_outputs", {})
|
||||||
|
|
||||||
|
# 提取最终输出(从最后一个非 start/end 节点)
|
||||||
|
final_output = self._extract_final_output(node_outputs)
|
||||||
|
|
||||||
|
# 聚合 token 使用情况
|
||||||
|
token_usage = self._aggregate_token_usage(node_outputs)
|
||||||
|
|
||||||
|
# 提取 conversation_id(从 start 节点输出)
|
||||||
|
conversation_id = None
|
||||||
|
for node_id, node_output in node_outputs.items():
|
||||||
|
if node_output.get("node_type") == "start":
|
||||||
|
conversation_id = node_output.get("output", {}).get("conversation_id")
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "completed",
|
||||||
|
"output": final_output,
|
||||||
|
"node_outputs": node_outputs,
|
||||||
|
"messages": result.get("messages", []),
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
"elapsed_time": elapsed_time,
|
||||||
|
"token_usage": token_usage,
|
||||||
|
"error": result.get("error")
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# 计算耗时(即使失败也记录)
|
||||||
|
end_time = datetime.datetime.now()
|
||||||
|
elapsed_time = (end_time - start_time).total_seconds()
|
||||||
|
|
||||||
|
logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True)
|
||||||
|
return {
|
||||||
|
"status": "failed",
|
||||||
|
"error": str(e),
|
||||||
|
"output": None,
|
||||||
|
"node_outputs": {},
|
||||||
|
"elapsed_time": elapsed_time,
|
||||||
|
"token_usage": None
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute_stream(
|
||||||
|
self,
|
||||||
|
input_data: dict[str, Any]
|
||||||
|
):
|
||||||
|
"""执行工作流(流式)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_data: 输入数据
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
流式事件
|
||||||
|
"""
|
||||||
|
logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}")
|
||||||
|
|
||||||
|
# 1. 构建图
|
||||||
|
graph = self.build_graph()
|
||||||
|
|
||||||
|
# 2. 初始化状态(自动注入系统变量)
|
||||||
|
initial_state = self._prepare_initial_state(input_data)
|
||||||
|
|
||||||
|
# 3. 流式执行工作流
|
||||||
|
try:
|
||||||
|
# 使用 astream 获取节点级别的更新
|
||||||
|
async for event in graph.astream(initial_state, stream_mode="updates"):
|
||||||
|
for node_name, state_update in event.items():
|
||||||
|
yield {
|
||||||
|
"type": "node_complete",
|
||||||
|
"node": node_name,
|
||||||
|
"data": state_update,
|
||||||
|
"execution_id": self.execution_id
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"工作流执行完成(流式): execution_id={self.execution_id}")
|
||||||
|
|
||||||
|
# 发送完成事件
|
||||||
|
yield {
|
||||||
|
"type": "workflow_complete",
|
||||||
|
"execution_id": self.execution_id
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"工作流执行失败(流式): execution_id={self.execution_id}, error={e}", exc_info=True)
|
||||||
|
yield {
|
||||||
|
"type": "workflow_error",
|
||||||
|
"execution_id": self.execution_id,
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:
|
||||||
|
"""从节点输出中提取最终输出
|
||||||
|
|
||||||
|
优先级:
|
||||||
|
1. 最后一个执行的非 start/end 节点的 output
|
||||||
|
2. 如果没有节点输出,返回 None
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_outputs: 所有节点的输出
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
最终输出字符串或 None
|
||||||
|
"""
|
||||||
|
if not node_outputs:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 获取最后一个节点的输出
|
||||||
|
last_node_output = list(node_outputs.values())[-1] if node_outputs else None
|
||||||
|
|
||||||
|
if last_node_output and isinstance(last_node_output, dict):
|
||||||
|
return last_node_output.get("output")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _aggregate_token_usage(self, node_outputs: dict[str, Any]) -> dict[str, int] | None:
|
||||||
|
"""聚合所有节点的 token 使用情况
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_outputs: 所有节点的输出
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
聚合的 token 使用情况 {"prompt_tokens": x, "completion_tokens": y, "total_tokens": z}
|
||||||
|
如果没有 token 使用信息,返回 None
|
||||||
|
"""
|
||||||
|
total_prompt_tokens = 0
|
||||||
|
total_completion_tokens = 0
|
||||||
|
total_tokens = 0
|
||||||
|
has_token_info = False
|
||||||
|
|
||||||
|
for node_output in node_outputs.values():
|
||||||
|
if isinstance(node_output, dict):
|
||||||
|
token_usage = node_output.get("token_usage")
|
||||||
|
if token_usage and isinstance(token_usage, dict):
|
||||||
|
has_token_info = True
|
||||||
|
total_prompt_tokens += token_usage.get("prompt_tokens", 0)
|
||||||
|
total_completion_tokens += token_usage.get("completion_tokens", 0)
|
||||||
|
total_tokens += token_usage.get("total_tokens", 0)
|
||||||
|
|
||||||
|
if not has_token_info:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"prompt_tokens": total_prompt_tokens,
|
||||||
|
"completion_tokens": total_completion_tokens,
|
||||||
|
"total_tokens": total_tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_workflow(
|
||||||
|
workflow_config: dict[str, Any],
|
||||||
|
input_data: dict[str, Any],
|
||||||
|
execution_id: str,
|
||||||
|
workspace_id: str,
|
||||||
|
user_id: str
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""执行工作流(便捷函数)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_config: 工作流配置
|
||||||
|
input_data: 输入数据
|
||||||
|
execution_id: 执行 ID
|
||||||
|
workspace_id: 工作空间 ID
|
||||||
|
user_id: 用户 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行结果
|
||||||
|
"""
|
||||||
|
executor = WorkflowExecutor(
|
||||||
|
workflow_config=workflow_config,
|
||||||
|
execution_id=execution_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
user_id=user_id
|
||||||
|
)
|
||||||
|
return await executor.execute(input_data)
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_workflow_stream(
|
||||||
|
workflow_config: dict[str, Any],
|
||||||
|
input_data: dict[str, Any],
|
||||||
|
execution_id: str,
|
||||||
|
workspace_id: str,
|
||||||
|
user_id: str
|
||||||
|
):
|
||||||
|
"""执行工作流(流式,便捷函数)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_config: 工作流配置
|
||||||
|
input_data: 输入数据
|
||||||
|
execution_id: 执行 ID
|
||||||
|
workspace_id: 工作空间 ID
|
||||||
|
user_id: 用户 ID
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
流式事件
|
||||||
|
"""
|
||||||
|
executor = WorkflowExecutor(
|
||||||
|
workflow_config=workflow_config,
|
||||||
|
execution_id=execution_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
user_id=user_id
|
||||||
|
)
|
||||||
|
async for event in executor.execute_stream(input_data):
|
||||||
|
yield event
|
||||||
195
api/app/core/workflow/expression_evaluator.py
Normal file
195
api/app/core/workflow/expression_evaluator.py
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
"""
|
||||||
|
安全的表达式求值器
|
||||||
|
|
||||||
|
使用 simpleeval 库提供安全的表达式评估,避免代码注入攻击。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from simpleeval import simple_eval, NameNotDefined, InvalidExpression
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ExpressionEvaluator:
|
||||||
|
"""安全的表达式求值器"""
|
||||||
|
|
||||||
|
# 保留的命名空间
|
||||||
|
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def evaluate(
|
||||||
|
expression: str,
|
||||||
|
variables: dict[str, Any],
|
||||||
|
node_outputs: dict[str, Any],
|
||||||
|
system_vars: dict[str, Any] | None = None
|
||||||
|
) -> Any:
|
||||||
|
"""安全地评估表达式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expression: 表达式字符串,如 "{{var.score}} > 0.8"
|
||||||
|
variables: 用户定义的变量
|
||||||
|
node_outputs: 节点输出结果
|
||||||
|
system_vars: 系统变量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
表达式求值结果
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 表达式无效或求值失败
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> evaluator = ExpressionEvaluator()
|
||||||
|
>>> evaluator.evaluate(
|
||||||
|
... "var.score > 0.8",
|
||||||
|
... {"score": 0.9},
|
||||||
|
... {},
|
||||||
|
... {}
|
||||||
|
... )
|
||||||
|
True
|
||||||
|
|
||||||
|
>>> evaluator.evaluate(
|
||||||
|
... "node.intent.output == '售前咨询'",
|
||||||
|
... {},
|
||||||
|
... {"intent": {"output": "售前咨询"}},
|
||||||
|
... {}
|
||||||
|
... )
|
||||||
|
True
|
||||||
|
"""
|
||||||
|
# 移除 Jinja2 模板语法的花括号(如果存在)
|
||||||
|
expression = expression.strip()
|
||||||
|
if expression.startswith("{{") and expression.endswith("}}"):
|
||||||
|
expression = expression[2:-2].strip()
|
||||||
|
|
||||||
|
# 构建命名空间上下文
|
||||||
|
context = {
|
||||||
|
"var": variables, # 用户变量
|
||||||
|
"node": node_outputs, # 节点输出
|
||||||
|
"sys": system_vars or {}, # 系统变量
|
||||||
|
}
|
||||||
|
|
||||||
|
# 为了向后兼容,也支持直接访问(但会在日志中警告)
|
||||||
|
context.update(variables)
|
||||||
|
context["nodes"] = node_outputs
|
||||||
|
|
||||||
|
try:
|
||||||
|
# simpleeval 只支持安全的操作:
|
||||||
|
# - 算术运算: +, -, *, /, //, %, **
|
||||||
|
# - 比较运算: ==, !=, <, <=, >, >=
|
||||||
|
# - 逻辑运算: and, or, not
|
||||||
|
# - 成员运算: in, not in
|
||||||
|
# - 属性访问: obj.attr
|
||||||
|
# - 字典/列表访问: obj["key"], obj[0]
|
||||||
|
# 不支持:函数调用、导入、赋值等危险操作
|
||||||
|
result = simple_eval(expression, names=context)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except NameNotDefined as e:
|
||||||
|
logger.error(f"表达式中引用了未定义的变量: {expression}, 错误: {e}")
|
||||||
|
raise ValueError(f"未定义的变量: {e}")
|
||||||
|
|
||||||
|
except InvalidExpression as e:
|
||||||
|
logger.error(f"表达式语法无效: {expression}, 错误: {e}")
|
||||||
|
raise ValueError(f"表达式语法无效: {e}")
|
||||||
|
|
||||||
|
except SyntaxError as e:
|
||||||
|
logger.error(f"表达式语法错误: {expression}, 错误: {e}")
|
||||||
|
raise ValueError(f"表达式语法错误: {e}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"表达式求值异常: {expression}, 错误: {e}")
|
||||||
|
raise ValueError(f"表达式求值失败: {e}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def evaluate_bool(
|
||||||
|
expression: str,
|
||||||
|
variables: dict[str, Any],
|
||||||
|
node_outputs: dict[str, Any],
|
||||||
|
system_vars: dict[str, Any] | None = None
|
||||||
|
) -> bool:
|
||||||
|
"""评估布尔表达式(用于条件判断)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expression: 布尔表达式
|
||||||
|
variables: 用户变量
|
||||||
|
node_outputs: 节点输出
|
||||||
|
system_vars: 系统变量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
布尔值结果
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> ExpressionEvaluator.evaluate_bool(
|
||||||
|
... "var.count >= 10 and var.status == 'active'",
|
||||||
|
... {"count": 15, "status": "active"},
|
||||||
|
... {},
|
||||||
|
... {}
|
||||||
|
... )
|
||||||
|
True
|
||||||
|
"""
|
||||||
|
result = ExpressionEvaluator.evaluate(
|
||||||
|
expression, variables, node_outputs, system_vars
|
||||||
|
)
|
||||||
|
return bool(result)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_variable_names(variables: list[dict]) -> list[str]:
|
||||||
|
"""验证变量名是否合法
|
||||||
|
|
||||||
|
Args:
|
||||||
|
variables: 变量定义列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
错误列表,如果为空则验证通过
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> ExpressionEvaluator.validate_variable_names([
|
||||||
|
... {"name": "user_input"},
|
||||||
|
... {"name": "var"} # 保留字
|
||||||
|
... ])
|
||||||
|
["变量名 'var' 是保留的命名空间,请使用其他名称"]
|
||||||
|
"""
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
for var in variables:
|
||||||
|
var_name = var.get("name", "")
|
||||||
|
|
||||||
|
# 检查是否为保留命名空间
|
||||||
|
if var_name in ExpressionEvaluator.RESERVED_NAMESPACES:
|
||||||
|
errors.append(
|
||||||
|
f"变量名 '{var_name}' 是保留的命名空间,请使用其他名称"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查是否为有效的 Python 标识符
|
||||||
|
if not var_name.isidentifier():
|
||||||
|
errors.append(
|
||||||
|
f"变量名 '{var_name}' 不是有效的标识符"
|
||||||
|
)
|
||||||
|
|
||||||
|
return errors
|
||||||
|
|
||||||
|
|
||||||
|
# 便捷函数
|
||||||
|
def evaluate_expression(
|
||||||
|
expression: str,
|
||||||
|
variables: dict[str, Any],
|
||||||
|
node_outputs: dict[str, Any],
|
||||||
|
system_vars: dict[str, Any] | None = None
|
||||||
|
) -> Any:
|
||||||
|
"""评估表达式(便捷函数)"""
|
||||||
|
return ExpressionEvaluator.evaluate(
|
||||||
|
expression, variables, node_outputs, system_vars
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_condition(
|
||||||
|
expression: str,
|
||||||
|
variables: dict[str, Any],
|
||||||
|
node_outputs: dict[str, Any],
|
||||||
|
system_vars: dict[str, Any] | None = None
|
||||||
|
) -> bool:
|
||||||
|
"""评估条件表达式(便捷函数)"""
|
||||||
|
return ExpressionEvaluator.evaluate_bool(
|
||||||
|
expression, variables, node_outputs, system_vars
|
||||||
|
)
|
||||||
24
api/app/core/workflow/nodes/__init__.py
Normal file
24
api/app/core/workflow/nodes/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
"""
|
||||||
|
工作流节点实现
|
||||||
|
|
||||||
|
提供各种类型的节点实现,用于工作流执行。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
|
from app.core.workflow.nodes.llm import LLMNode
|
||||||
|
from app.core.workflow.nodes.agent import AgentNode
|
||||||
|
from app.core.workflow.nodes.transform import TransformNode
|
||||||
|
from app.core.workflow.nodes.start import StartNode
|
||||||
|
from app.core.workflow.nodes.end import EndNode
|
||||||
|
from app.core.workflow.nodes.node_factory import NodeFactory
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseNode",
|
||||||
|
"WorkflowState",
|
||||||
|
"LLMNode",
|
||||||
|
"AgentNode",
|
||||||
|
"TransformNode",
|
||||||
|
"StartNode",
|
||||||
|
"EndNode",
|
||||||
|
"NodeFactory",
|
||||||
|
]
|
||||||
6
api/app/core/workflow/nodes/agent/__init__.py
Normal file
6
api/app/core/workflow/nodes/agent/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
"""Agent 节点"""
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.agent.node import AgentNode
|
||||||
|
from app.core.workflow.nodes.agent.config import AgentNodeConfig
|
||||||
|
|
||||||
|
__all__ = ["AgentNode", "AgentNodeConfig"]
|
||||||
71
api/app/core/workflow/nodes/agent/config.py
Normal file
71
api/app/core/workflow/nodes/agent/config.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
"""Agent 节点配置"""
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||||
|
|
||||||
|
|
||||||
|
class AgentNodeConfig(BaseNodeConfig):
|
||||||
|
"""Agent 节点配置
|
||||||
|
|
||||||
|
调用已配置的 Agent 执行任务。
|
||||||
|
"""
|
||||||
|
|
||||||
|
agent_id: str = Field(
|
||||||
|
...,
|
||||||
|
description="Agent 配置 ID"
|
||||||
|
)
|
||||||
|
|
||||||
|
message: str = Field(
|
||||||
|
default="{{ sys.message }}",
|
||||||
|
description="发送给 Agent 的消息,支持模板变量"
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation_id: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="会话 ID,用于多轮对话"
|
||||||
|
)
|
||||||
|
|
||||||
|
variables: dict[str, str] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="传递给 Agent 的变量"
|
||||||
|
)
|
||||||
|
|
||||||
|
timeout: int = Field(
|
||||||
|
default=300,
|
||||||
|
ge=1,
|
||||||
|
le=3600,
|
||||||
|
description="超时时间(秒)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 输出变量定义
|
||||||
|
output_variables: list[VariableDefinition] = Field(
|
||||||
|
default_factory=lambda: [
|
||||||
|
VariableDefinition(
|
||||||
|
name="output",
|
||||||
|
type=VariableType.STRING,
|
||||||
|
description="Agent 的回复内容"
|
||||||
|
),
|
||||||
|
VariableDefinition(
|
||||||
|
name="conversation_id",
|
||||||
|
type=VariableType.STRING,
|
||||||
|
description="会话 ID"
|
||||||
|
),
|
||||||
|
VariableDefinition(
|
||||||
|
name="token_usage",
|
||||||
|
type=VariableType.OBJECT,
|
||||||
|
description="Token 使用情况"
|
||||||
|
)
|
||||||
|
],
|
||||||
|
description="输出变量定义(自动生成,通常不需要修改)"
|
||||||
|
)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"agent_id": "uuid-here",
|
||||||
|
"message": "{{ sys.message }}",
|
||||||
|
"timeout": 300,
|
||||||
|
"description": "调用客服 Agent"
|
||||||
|
}
|
||||||
|
}
|
||||||
152
api/app/core/workflow/nodes/agent/node.py
Normal file
152
api/app/core/workflow/nodes/agent/node.py
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
"""
|
||||||
|
Agent 节点实现
|
||||||
|
|
||||||
|
调用已发布的 Agent 应用。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
|
from app.services.draft_run_service import DraftRunService
|
||||||
|
from app.models import AppRelease
|
||||||
|
from app.db import get_db
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentNode(BaseNode):
|
||||||
|
"""Agent 节点
|
||||||
|
|
||||||
|
支持流式和非流式输出。
|
||||||
|
|
||||||
|
配置示例:
|
||||||
|
{
|
||||||
|
"type": "agent",
|
||||||
|
"config": {
|
||||||
|
"agent_id": "uuid", # Agent 的 release_id
|
||||||
|
"message": "{{var.user_input}}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _prepare_agent(self, state: WorkflowState) -> tuple[DraftRunService, AppRelease, str]:
|
||||||
|
"""准备 Agent(公共逻辑)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 工作流状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(draft_service, release, message): 服务实例、发布配置、消息
|
||||||
|
"""
|
||||||
|
# 1. 渲染消息
|
||||||
|
message_template = self.config.get("message", "")
|
||||||
|
message = self._render_template(message_template, state)
|
||||||
|
|
||||||
|
# 2. 获取 Agent 配置
|
||||||
|
agent_id = self.config.get("agent_id")
|
||||||
|
if not agent_id:
|
||||||
|
raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置")
|
||||||
|
|
||||||
|
db = next(get_db())
|
||||||
|
release = db.query(AppRelease).filter(
|
||||||
|
AppRelease.id == agent_id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not release:
|
||||||
|
raise ValueError(f"Agent 不存在: {agent_id}")
|
||||||
|
|
||||||
|
draft_service = DraftRunService(db)
|
||||||
|
|
||||||
|
return draft_service, release, message
|
||||||
|
|
||||||
|
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||||
|
"""非流式执行
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 工作流状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
状态更新字典
|
||||||
|
"""
|
||||||
|
draft_service, release, message = self._prepare_agent(state)
|
||||||
|
|
||||||
|
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
|
||||||
|
|
||||||
|
# 执行 Agent(非流式)
|
||||||
|
result = await draft_service.run(
|
||||||
|
agent_config=release.config,
|
||||||
|
model_config=None,
|
||||||
|
message=message,
|
||||||
|
workspace_id=state.get("workspace_id"),
|
||||||
|
user_id=state.get("user_id"),
|
||||||
|
variables=state.get("variables", {})
|
||||||
|
)
|
||||||
|
|
||||||
|
response = result.get("response", "")
|
||||||
|
|
||||||
|
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(response)}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"messages": [AIMessage(content=response)],
|
||||||
|
"node_outputs": {
|
||||||
|
self.node_id: {
|
||||||
|
"output": response,
|
||||||
|
"status": "completed",
|
||||||
|
"meta_data": result.get("meta_data", {})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute_stream(self, state: WorkflowState):
|
||||||
|
"""流式执行
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 工作流状态
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
流式事件字典
|
||||||
|
"""
|
||||||
|
draft_service, release, message = self._prepare_agent(state)
|
||||||
|
|
||||||
|
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
|
||||||
|
|
||||||
|
# 累积完整响应
|
||||||
|
full_response = ""
|
||||||
|
|
||||||
|
# 执行 Agent(流式)
|
||||||
|
async for chunk in draft_service.run_stream(
|
||||||
|
agent_config=release.config,
|
||||||
|
model_config=None,
|
||||||
|
message=message,
|
||||||
|
workspace_id=state.get("workspace_id"),
|
||||||
|
user_id=state.get("user_id"),
|
||||||
|
variables=state.get("variables", {})
|
||||||
|
):
|
||||||
|
# 提取内容
|
||||||
|
content = chunk.get("content", "")
|
||||||
|
full_response += content
|
||||||
|
|
||||||
|
# 流式返回每个 chunk
|
||||||
|
yield {
|
||||||
|
"type": "chunk",
|
||||||
|
"node_id": self.node_id,
|
||||||
|
"content": content,
|
||||||
|
"full_content": full_response,
|
||||||
|
"meta_data": chunk.get("meta_data", {})
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}")
|
||||||
|
|
||||||
|
# 最后返回完整结果
|
||||||
|
yield {
|
||||||
|
"type": "complete",
|
||||||
|
"messages": [AIMessage(content=full_response)],
|
||||||
|
"node_outputs": {
|
||||||
|
self.node_id: {
|
||||||
|
"output": full_response,
|
||||||
|
"status": "completed"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
109
api/app/core/workflow/nodes/base_config.py
Normal file
109
api/app/core/workflow/nodes/base_config.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
"""节点配置基类
|
||||||
|
|
||||||
|
定义所有节点配置的通用字段和数据结构。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class VariableType(StrEnum):
|
||||||
|
"""变量类型枚举"""
|
||||||
|
|
||||||
|
STRING = "string"
|
||||||
|
NUMBER = "number"
|
||||||
|
BOOLEAN = "boolean"
|
||||||
|
ARRAY = "array"
|
||||||
|
OBJECT = "object"
|
||||||
|
|
||||||
|
|
||||||
|
class VariableDefinition(BaseModel):
|
||||||
|
"""变量定义
|
||||||
|
|
||||||
|
定义工作流或节点的输入/输出变量。
|
||||||
|
这是一个通用的数据结构,可以在多个地方使用。
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = Field(
|
||||||
|
...,
|
||||||
|
description="变量名称"
|
||||||
|
)
|
||||||
|
|
||||||
|
type: VariableType = Field(
|
||||||
|
default=VariableType.STRING,
|
||||||
|
description="变量类型"
|
||||||
|
)
|
||||||
|
|
||||||
|
required: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="是否必需"
|
||||||
|
)
|
||||||
|
|
||||||
|
default: str | int | float | bool | list | dict | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="默认值"
|
||||||
|
)
|
||||||
|
|
||||||
|
description: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="变量描述"
|
||||||
|
)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"examples": [
|
||||||
|
{
|
||||||
|
"name": "language",
|
||||||
|
"type": "string",
|
||||||
|
"required": False,
|
||||||
|
"default": "zh-CN",
|
||||||
|
"description": "语言设置"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "max_length",
|
||||||
|
"type": "number",
|
||||||
|
"required": False,
|
||||||
|
"default": 1000,
|
||||||
|
"description": "最大长度"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "enable_search",
|
||||||
|
"type": "boolean",
|
||||||
|
"required": True,
|
||||||
|
"description": "是否启用搜索"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class BaseNodeConfig(BaseModel):
|
||||||
|
"""节点配置基类
|
||||||
|
|
||||||
|
所有节点配置都应该继承此基类。
|
||||||
|
|
||||||
|
通用字段:
|
||||||
|
- name: 节点名称(显示名称)
|
||||||
|
- description: 节点描述
|
||||||
|
- tags: 节点标签(用于分类和搜索)
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="节点名称(显示名称),如果不设置则使用节点 ID"
|
||||||
|
)
|
||||||
|
|
||||||
|
description: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="节点描述,说明节点的作用"
|
||||||
|
)
|
||||||
|
|
||||||
|
tags: list[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="节点标签,用于分类和搜索"
|
||||||
|
)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Pydantic 配置"""
|
||||||
|
# 允许额外字段(向后兼容)
|
||||||
|
extra = "allow"
|
||||||
556
api/app/core/workflow/nodes/base_node.py
Normal file
556
api/app/core/workflow/nodes/base_node.py
Normal file
@@ -0,0 +1,556 @@
|
|||||||
|
"""
|
||||||
|
工作流节点基类
|
||||||
|
|
||||||
|
定义节点的基本接口和通用功能。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, TypedDict, Annotated
|
||||||
|
from operator import add
|
||||||
|
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
|
||||||
|
|
||||||
|
from app.core.workflow.variable_pool import VariablePool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowState(TypedDict):
|
||||||
|
"""工作流状态
|
||||||
|
|
||||||
|
在节点间传递的状态对象,包含消息、变量、节点输出等信息。
|
||||||
|
"""
|
||||||
|
# 消息列表(追加模式)
|
||||||
|
messages: Annotated[list[AnyMessage], add]
|
||||||
|
|
||||||
|
# 输入变量(从配置的 variables 传入)
|
||||||
|
variables: dict[str, Any]
|
||||||
|
|
||||||
|
# 节点输出(存储每个节点的执行结果,用于变量引用)
|
||||||
|
# 使用自定义合并函数,将新的节点输出合并到现有字典中
|
||||||
|
node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||||
|
|
||||||
|
# 运行时节点变量(简化版,只存储业务数据,供节点间快速访问)
|
||||||
|
# 格式:{node_id: business_result}
|
||||||
|
runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||||
|
|
||||||
|
# 执行上下文
|
||||||
|
execution_id: str
|
||||||
|
workspace_id: str
|
||||||
|
user_id: str
|
||||||
|
|
||||||
|
# 错误信息(用于错误边)
|
||||||
|
error: str | None
|
||||||
|
error_node: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class BaseNode(ABC):
|
||||||
|
"""节点基类
|
||||||
|
|
||||||
|
所有节点类型都应该继承此基类,实现 execute 方法。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
|
"""初始化节点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_config: 节点配置
|
||||||
|
workflow_config: 工作流配置
|
||||||
|
"""
|
||||||
|
self.node_config = node_config
|
||||||
|
self.workflow_config = workflow_config
|
||||||
|
self.node_id = node_config["id"]
|
||||||
|
self.node_type = node_config["type"]
|
||||||
|
self.node_name = node_config.get("name", self.node_id)
|
||||||
|
# 使用 or 运算符处理 None 值
|
||||||
|
self.config = node_config.get("config") or {}
|
||||||
|
self.error_handling = node_config.get("error_handling") or {}
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def execute(self, state: WorkflowState) -> Any:
|
||||||
|
"""执行节点业务逻辑(非流式)
|
||||||
|
|
||||||
|
节点只需要返回业务结果,不需要关心输出格式、时间统计等。
|
||||||
|
BaseNode 会自动包装成标准格式。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 工作流状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
业务结果(任意类型)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # LLM 节点
|
||||||
|
>>> return "这是 AI 的回复"
|
||||||
|
|
||||||
|
>>> # Transform 节点
|
||||||
|
>>> return {"processed_data": [...]}
|
||||||
|
|
||||||
|
>>> # Start/End 节点
|
||||||
|
>>> return {"message": "开始", "conversation_id": "xxx"}
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def execute_stream(self, state: WorkflowState):
|
||||||
|
"""执行节点业务逻辑(流式)
|
||||||
|
|
||||||
|
子类可以重写此方法以支持流式输出。
|
||||||
|
默认实现:执行非流式方法并一次性返回。
|
||||||
|
|
||||||
|
节点需要:
|
||||||
|
1. yield 中间结果(如文本片段)
|
||||||
|
2. 最后 yield 一个特殊的完成标记:{"__final__": True, "result": final_result}
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 工作流状态
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
业务数据(chunk)或完成标记
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # 流式 LLM 节点
|
||||||
|
>>> full_response = ""
|
||||||
|
>>> async for chunk in llm.astream(prompt):
|
||||||
|
... full_response += chunk
|
||||||
|
... yield chunk # yield 文本片段
|
||||||
|
>>>
|
||||||
|
>>> # 最后 yield 完成标记
|
||||||
|
>>> yield {"__final__": True, "result": AIMessage(content=full_response)}
|
||||||
|
"""
|
||||||
|
result = await self.execute(state)
|
||||||
|
# 默认实现:直接 yield 完成标记
|
||||||
|
yield {"__final__": True, "result": result}
|
||||||
|
|
||||||
|
def supports_streaming(self) -> bool:
|
||||||
|
"""节点是否支持流式输出
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否支持流式输出
|
||||||
|
"""
|
||||||
|
# 检查子类是否重写了 execute_stream 方法
|
||||||
|
return self.execute_stream.__func__ != BaseNode.execute_stream.__func__
|
||||||
|
|
||||||
|
def get_timeout(self) -> int:
|
||||||
|
"""获取超时时间(秒)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
超时时间
|
||||||
|
"""
|
||||||
|
return 60
|
||||||
|
# return self.error_handling.get("timeout", 60)
|
||||||
|
|
||||||
|
async def run(self, state: WorkflowState) -> dict[str, Any]:
|
||||||
|
"""执行节点(带错误处理和输出包装,非流式)
|
||||||
|
|
||||||
|
这个方法由 Executor 调用,负责:
|
||||||
|
1. 时间统计
|
||||||
|
2. 调用节点的 execute() 方法
|
||||||
|
3. 将业务结果包装成标准输出格式
|
||||||
|
4. 错误处理
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 工作流状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
标准化的状态更新字典
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
timeout = self.get_timeout()
|
||||||
|
|
||||||
|
# 调用节点的业务逻辑
|
||||||
|
business_result = await asyncio.wait_for(
|
||||||
|
self.execute(state),
|
||||||
|
timeout=timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
|
# 提取处理后的输出(调用子类的 _extract_output)
|
||||||
|
extracted_output = self._extract_output(business_result)
|
||||||
|
|
||||||
|
# 包装成标准输出格式
|
||||||
|
wrapped_output = self._wrap_output(business_result, elapsed_time, state)
|
||||||
|
|
||||||
|
# 将提取后的输出存储到运行时变量中(供后续节点快速访问)
|
||||||
|
# 如果提取后的输出是字典,拆包存储;否则存储为 output 字段
|
||||||
|
if isinstance(extracted_output, dict):
|
||||||
|
runtime_var = extracted_output
|
||||||
|
else:
|
||||||
|
runtime_var = {"output": extracted_output}
|
||||||
|
|
||||||
|
# 返回包装后的输出和运行时变量
|
||||||
|
return {
|
||||||
|
**wrapped_output,
|
||||||
|
"runtime_vars": {
|
||||||
|
self.node_id: runtime_var
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
except TimeoutError:
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
|
||||||
|
return self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
|
||||||
|
except Exception as e:
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
|
||||||
|
return self._wrap_error(str(e), elapsed_time, state)
|
||||||
|
|
||||||
|
async def run_stream(self, state: WorkflowState):
|
||||||
|
"""执行节点(带错误处理和输出包装,流式)
|
||||||
|
|
||||||
|
这个方法由 Executor 调用,负责:
|
||||||
|
1. 时间统计
|
||||||
|
2. 调用节点的 execute_stream() 方法
|
||||||
|
3. 将业务数据包装成标准输出格式
|
||||||
|
4. 错误处理
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 工作流状态
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
标准化的流式事件
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
timeout = self.get_timeout()
|
||||||
|
|
||||||
|
# 累积完整结果(用于最后的包装)
|
||||||
|
chunks = []
|
||||||
|
final_result = None
|
||||||
|
|
||||||
|
# 使用异步生成器包装,支持超时
|
||||||
|
async def stream_with_timeout():
|
||||||
|
nonlocal final_result
|
||||||
|
loop_start = asyncio.get_event_loop().time()
|
||||||
|
|
||||||
|
async for item in self.execute_stream(state):
|
||||||
|
# 检查超时
|
||||||
|
if asyncio.get_event_loop().time() - loop_start > timeout:
|
||||||
|
raise TimeoutError()
|
||||||
|
|
||||||
|
# 检查是否是完成标记
|
||||||
|
if isinstance(item, dict) and item.get("__final__"):
|
||||||
|
final_result = item["result"]
|
||||||
|
elif isinstance(item, str):
|
||||||
|
# 字符串是 chunk
|
||||||
|
chunks.append(item)
|
||||||
|
yield {
|
||||||
|
"type": "chunk",
|
||||||
|
"node_id": self.node_id,
|
||||||
|
"content": item,
|
||||||
|
"full_content": "".join(chunks)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# 其他类型也当作 chunk 处理
|
||||||
|
chunks.append(str(item))
|
||||||
|
yield {
|
||||||
|
"type": "chunk",
|
||||||
|
"node_id": self.node_id,
|
||||||
|
"content": str(item),
|
||||||
|
"full_content": "".join(chunks)
|
||||||
|
}
|
||||||
|
|
||||||
|
async for chunk_event in stream_with_timeout():
|
||||||
|
yield chunk_event
|
||||||
|
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
|
# 包装最终结果
|
||||||
|
final_output = self._wrap_output(final_result, elapsed_time, state)
|
||||||
|
yield {
|
||||||
|
"type": "complete",
|
||||||
|
**final_output
|
||||||
|
}
|
||||||
|
|
||||||
|
except TimeoutError:
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
|
||||||
|
yield {
|
||||||
|
"type": "error",
|
||||||
|
**self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
|
||||||
|
yield {
|
||||||
|
"type": "error",
|
||||||
|
**self._wrap_error(str(e), elapsed_time, state)
|
||||||
|
}
|
||||||
|
|
||||||
|
def _wrap_output(
|
||||||
|
self,
|
||||||
|
business_result: Any,
|
||||||
|
elapsed_time: float,
|
||||||
|
state: WorkflowState
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""将业务结果包装成标准输出格式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
business_result: 节点返回的业务结果
|
||||||
|
elapsed_time: 执行耗时
|
||||||
|
state: 工作流状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
标准化的状态更新字典
|
||||||
|
"""
|
||||||
|
# 提取输入数据(用于记录)
|
||||||
|
input_data = self._extract_input(state)
|
||||||
|
|
||||||
|
# 提取 token 使用情况(如果有)
|
||||||
|
token_usage = self._extract_token_usage(business_result)
|
||||||
|
|
||||||
|
# 提取实际输出(去除元数据)
|
||||||
|
output = self._extract_output(business_result)
|
||||||
|
|
||||||
|
# 构建标准节点输出
|
||||||
|
node_output = {
|
||||||
|
"node_id": self.node_id,
|
||||||
|
"node_type": self.node_type,
|
||||||
|
"node_name": self.node_name,
|
||||||
|
"status": "completed",
|
||||||
|
"input": input_data,
|
||||||
|
"output": output,
|
||||||
|
"elapsed_time": elapsed_time,
|
||||||
|
"token_usage": token_usage,
|
||||||
|
"error": None
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"node_outputs": {
|
||||||
|
self.node_id: node_output
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def _wrap_error(
|
||||||
|
self,
|
||||||
|
error_message: str,
|
||||||
|
elapsed_time: float,
|
||||||
|
state: WorkflowState
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""将错误包装成标准输出格式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error_message: 错误信息
|
||||||
|
elapsed_time: 执行耗时
|
||||||
|
state: 工作流状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
标准化的状态更新字典
|
||||||
|
"""
|
||||||
|
# 查找错误边
|
||||||
|
error_edge = self._find_error_edge()
|
||||||
|
|
||||||
|
# 提取输入数据
|
||||||
|
input_data = self._extract_input(state)
|
||||||
|
|
||||||
|
# 构建错误输出
|
||||||
|
node_output = {
|
||||||
|
"node_id": self.node_id,
|
||||||
|
"node_type": self.node_type,
|
||||||
|
"node_name": self.node_name,
|
||||||
|
"status": "failed",
|
||||||
|
"input": input_data,
|
||||||
|
"output": None,
|
||||||
|
"elapsed_time": elapsed_time,
|
||||||
|
"token_usage": None,
|
||||||
|
"error": error_message
|
||||||
|
}
|
||||||
|
|
||||||
|
if error_edge:
|
||||||
|
# 有错误边:记录错误并继续
|
||||||
|
logger.warning(
|
||||||
|
f"节点 {self.node_id} 执行失败,跳转到错误处理节点: {error_edge['target']}"
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"node_outputs": {
|
||||||
|
self.node_id: node_output
|
||||||
|
},
|
||||||
|
"error": error_message,
|
||||||
|
"error_node": self.node_id
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# 无错误边:抛出异常停止工作流
|
||||||
|
logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}")
|
||||||
|
raise Exception(f"节点 {self.node_id} 执行失败: {error_message}")
|
||||||
|
|
||||||
|
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
||||||
|
"""提取节点输入数据(用于记录)
|
||||||
|
|
||||||
|
子类可以重写此方法来自定义输入记录。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 工作流状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
输入数据字典
|
||||||
|
"""
|
||||||
|
# 默认返回配置
|
||||||
|
return {"config": self.config}
|
||||||
|
|
||||||
|
def _extract_output(self, business_result: Any) -> Any:
|
||||||
|
"""从业务结果中提取实际输出
|
||||||
|
|
||||||
|
子类可以重写此方法来自定义输出提取。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
business_result: 业务结果
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
实际输出
|
||||||
|
"""
|
||||||
|
# 默认直接返回业务结果
|
||||||
|
return business_result
|
||||||
|
|
||||||
|
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||||
|
"""从业务结果中提取 token 使用情况
|
||||||
|
|
||||||
|
子类可以重写此方法来提取 token 信息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
business_result: 业务结果
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
token 使用情况或 None
|
||||||
|
"""
|
||||||
|
# 默认返回 None
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _find_error_edge(self) -> dict[str, Any] | None:
|
||||||
|
"""查找错误边
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
错误边配置或 None
|
||||||
|
"""
|
||||||
|
for edge in self.workflow_config.get("edges", []):
|
||||||
|
if edge.get("source") == self.node_id and edge.get("type") == "error":
|
||||||
|
return edge
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _render_template(self, template: str, state: WorkflowState | None) -> str:
|
||||||
|
"""渲染模板
|
||||||
|
|
||||||
|
支持的变量命名空间:
|
||||||
|
- sys.xxx: 系统变量(message, execution_id, workspace_id, user_id, conversation_id)
|
||||||
|
- conv.xxx: 会话变量(跨多轮对话保持)
|
||||||
|
- node_id.xxx: 节点输出
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template: 模板字符串
|
||||||
|
state: 工作流状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
渲染后的字符串
|
||||||
|
"""
|
||||||
|
from app.core.workflow.template_renderer import render_template
|
||||||
|
|
||||||
|
# 处理 state 为 None 的情况
|
||||||
|
if state is None:
|
||||||
|
state = {}
|
||||||
|
|
||||||
|
# 使用变量池获取变量
|
||||||
|
pool = VariablePool(state)
|
||||||
|
|
||||||
|
return render_template(
|
||||||
|
template=template,
|
||||||
|
variables=pool.get_all_conversation_vars(),
|
||||||
|
node_outputs=pool.get_all_node_outputs(),
|
||||||
|
system_vars=pool.get_all_system_vars()
|
||||||
|
)
|
||||||
|
|
||||||
|
def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool:
|
||||||
|
"""评估条件表达式
|
||||||
|
|
||||||
|
支持的变量命名空间:
|
||||||
|
- sys.xxx: 系统变量
|
||||||
|
- conv.xxx: 会话变量
|
||||||
|
- node_id.xxx: 节点输出
|
||||||
|
|
||||||
|
Args:
|
||||||
|
expression: 条件表达式
|
||||||
|
state: 工作流状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
布尔值结果
|
||||||
|
"""
|
||||||
|
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||||
|
|
||||||
|
# 处理 state 为 None 的情况
|
||||||
|
if state is None:
|
||||||
|
state = {}
|
||||||
|
|
||||||
|
# 使用变量池获取变量
|
||||||
|
pool = VariablePool(state)
|
||||||
|
|
||||||
|
return evaluate_condition(
|
||||||
|
expression=expression,
|
||||||
|
variables=pool.get_all_conversation_vars(),
|
||||||
|
node_outputs=pool.get_all_node_outputs(),
|
||||||
|
system_vars=pool.get_all_system_vars()
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_variable_pool(self, state: WorkflowState) -> VariablePool:
|
||||||
|
"""获取变量池实例
|
||||||
|
|
||||||
|
VariablePool 是轻量级包装器,只持有 state 的引用,创建成本极低。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 工作流状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
VariablePool 实例
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> pool = self.get_variable_pool(state)
|
||||||
|
>>> message = pool.get("sys.message")
|
||||||
|
>>> llm_output = pool.get("llm_qa.output")
|
||||||
|
"""
|
||||||
|
return VariablePool(state)
|
||||||
|
|
||||||
|
def get_variable(
|
||||||
|
self,
|
||||||
|
selector: list[str] | str,
|
||||||
|
state: WorkflowState,
|
||||||
|
default: Any = None
|
||||||
|
) -> Any:
|
||||||
|
"""获取变量值(便捷方法)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
selector: 变量选择器
|
||||||
|
state: 工作流状态
|
||||||
|
default: 默认值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
变量值
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> message = self.get_variable("sys.message", state)
|
||||||
|
>>> output = self.get_variable(["llm_qa", "output"], state)
|
||||||
|
>>> custom = self.get_variable("var.custom", state, default="默认值")
|
||||||
|
"""
|
||||||
|
pool = VariablePool(state)
|
||||||
|
return pool.get(selector, default=default)
|
||||||
|
|
||||||
|
def has_variable(self, selector: list[str] | str, state: WorkflowState) -> bool:
|
||||||
|
"""检查变量是否存在(便捷方法)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
selector: 变量选择器
|
||||||
|
state: 工作流状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
变量是否存在
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> if self.has_variable("llm_qa.output", state):
|
||||||
|
... output = self.get_variable("llm_qa.output", state)
|
||||||
|
"""
|
||||||
|
pool = VariablePool(state)
|
||||||
|
return pool.has(selector)
|
||||||
29
api/app/core/workflow/nodes/configs.py
Normal file
29
api/app/core/workflow/nodes/configs.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
"""节点配置类统一导出
|
||||||
|
|
||||||
|
所有节点的配置类都在这里导出,方便使用。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_config import (
|
||||||
|
BaseNodeConfig,
|
||||||
|
VariableDefinition,
|
||||||
|
VariableType,
|
||||||
|
)
|
||||||
|
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||||
|
from app.core.workflow.nodes.end.config import EndNodeConfig
|
||||||
|
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
|
||||||
|
from app.core.workflow.nodes.agent.config import AgentNodeConfig
|
||||||
|
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# 基础类
|
||||||
|
"BaseNodeConfig",
|
||||||
|
"VariableDefinition",
|
||||||
|
"VariableType",
|
||||||
|
# 节点配置
|
||||||
|
"StartNodeConfig",
|
||||||
|
"EndNodeConfig",
|
||||||
|
"LLMNodeConfig",
|
||||||
|
"MessageConfig",
|
||||||
|
"AgentNodeConfig",
|
||||||
|
"TransformNodeConfig",
|
||||||
|
]
|
||||||
6
api/app/core/workflow/nodes/end/__init__.py
Normal file
6
api/app/core/workflow/nodes/end/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
"""End 节点"""
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.end.node import EndNode
|
||||||
|
from app.core.workflow.nodes.end.config import EndNodeConfig
|
||||||
|
|
||||||
|
__all__ = ["EndNode", "EndNodeConfig"]
|
||||||
37
api/app/core/workflow/nodes/end/config.py
Normal file
37
api/app/core/workflow/nodes/end/config.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
"""End 节点配置"""
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||||
|
|
||||||
|
|
||||||
|
class EndNodeConfig(BaseNodeConfig):
|
||||||
|
"""End 节点配置
|
||||||
|
|
||||||
|
End 节点负责输出工作流的最终结果。
|
||||||
|
"""
|
||||||
|
|
||||||
|
output: str = Field(
|
||||||
|
default="工作流已完成",
|
||||||
|
description="输出模板,支持引用前置节点的输出,如:{{ llm_qa.output }}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 输出变量定义
|
||||||
|
output_variables: list[VariableDefinition] = Field(
|
||||||
|
default_factory=lambda: [
|
||||||
|
VariableDefinition(
|
||||||
|
name="output",
|
||||||
|
type=VariableType.STRING,
|
||||||
|
description="工作流的最终输出"
|
||||||
|
)
|
||||||
|
],
|
||||||
|
description="输出变量定义(自动生成,通常不需要修改)"
|
||||||
|
)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"output": "{{ llm_qa.output }}",
|
||||||
|
"description": "输出 LLM 的回答"
|
||||||
|
}
|
||||||
|
}
|
||||||
53
api/app/core/workflow/nodes/end/node.py
Normal file
53
api/app/core/workflow/nodes/end/node.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
"""
|
||||||
|
End 节点实现
|
||||||
|
|
||||||
|
工作流的结束节点,输出最终结果。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EndNode(BaseNode):
|
||||||
|
"""End 节点
|
||||||
|
|
||||||
|
工作流的结束节点,根据配置的模板输出最终结果。
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def execute(self, state: WorkflowState) -> str:
|
||||||
|
"""执行 end 节点业务逻辑
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 工作流状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
最终输出字符串
|
||||||
|
"""
|
||||||
|
logger.info(f"节点 {self.node_id} (End) 开始执行")
|
||||||
|
|
||||||
|
# 获取配置的输出模板
|
||||||
|
output_template = self.config.get("output")
|
||||||
|
pool = self.get_variable_pool(state)
|
||||||
|
|
||||||
|
print("="*20)
|
||||||
|
print( pool.get("start.test"))
|
||||||
|
print("="*20)
|
||||||
|
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||||
|
if output_template:
|
||||||
|
output = self._render_template(output_template, state)
|
||||||
|
else:
|
||||||
|
output = "工作流已完成"
|
||||||
|
|
||||||
|
# 统计信息(用于日志)
|
||||||
|
node_outputs = state.get("node_outputs", {})
|
||||||
|
total_nodes = len(node_outputs)
|
||||||
|
|
||||||
|
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
|
||||||
|
print("="*20)
|
||||||
|
print(output)
|
||||||
|
print("="*20)
|
||||||
|
return output
|
||||||
15
api/app/core/workflow/nodes/enums.py
Normal file
15
api/app/core/workflow/nodes/enums.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
class NodeType(StrEnum):
|
||||||
|
START = "start"
|
||||||
|
END = "end"
|
||||||
|
ANSWER = "answer"
|
||||||
|
LLM = "llm"
|
||||||
|
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
|
||||||
|
IF_ELSE = "if-else"
|
||||||
|
CODE = "code"
|
||||||
|
TRANSFORM = "transform"
|
||||||
|
QUESTION_CLASSIFIER = "question-classifier"
|
||||||
|
HTTP_REQUEST = "http-request"
|
||||||
|
TOOL = "tool"
|
||||||
|
AGENT = "agent"
|
||||||
6
api/app/core/workflow/nodes/llm/__init__.py
Normal file
6
api/app/core/workflow/nodes/llm/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
"""LLM 节点"""
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.llm.node import LLMNode
|
||||||
|
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
|
||||||
|
|
||||||
|
__all__ = ["LLMNode", "LLMNodeConfig", "MessageConfig"]
|
||||||
141
api/app/core/workflow/nodes/llm/config.py
Normal file
141
api/app/core/workflow/nodes/llm/config.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
"""LLM 节点配置"""
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||||
|
|
||||||
|
|
||||||
|
class MessageConfig(BaseModel):
|
||||||
|
"""消息配置"""
|
||||||
|
|
||||||
|
role: str = Field(
|
||||||
|
...,
|
||||||
|
description="消息角色:system, user, assistant"
|
||||||
|
)
|
||||||
|
|
||||||
|
content: str = Field(
|
||||||
|
...,
|
||||||
|
description="消息内容,支持模板变量,如:{{ sys.message }}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("role")
|
||||||
|
@classmethod
|
||||||
|
def validate_role(cls, v: str) -> str:
|
||||||
|
"""验证角色"""
|
||||||
|
allowed_roles = ["system", "user", "human", "assistant", "ai"]
|
||||||
|
if v.lower() not in allowed_roles:
|
||||||
|
raise ValueError(f"角色必须是以下之一: {', '.join(allowed_roles)}")
|
||||||
|
return v.lower()
|
||||||
|
|
||||||
|
|
||||||
|
class LLMNodeConfig(BaseNodeConfig):
|
||||||
|
"""LLM 节点配置
|
||||||
|
|
||||||
|
支持两种配置方式:
|
||||||
|
1. 简单模式:使用 prompt 字段
|
||||||
|
2. 消息模式:使用 messages 字段(推荐)
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_id: str = Field(
|
||||||
|
...,
|
||||||
|
description="模型配置 ID"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 简单模式
|
||||||
|
prompt: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="提示词模板(简单模式),支持变量引用"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 消息模式(推荐)
|
||||||
|
messages: list[MessageConfig] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="消息列表(消息模式),支持多轮对话"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 模型参数
|
||||||
|
temperature: float | None = Field(
|
||||||
|
default=0.7,
|
||||||
|
ge=0.0,
|
||||||
|
le=2.0,
|
||||||
|
description="温度参数,控制输出的随机性"
|
||||||
|
)
|
||||||
|
|
||||||
|
max_tokens: int | None = Field(
|
||||||
|
default=1000,
|
||||||
|
ge=1,
|
||||||
|
le=32000,
|
||||||
|
description="最大生成 token 数"
|
||||||
|
)
|
||||||
|
|
||||||
|
top_p: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Top-p 采样参数"
|
||||||
|
)
|
||||||
|
|
||||||
|
frequency_penalty: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
ge=-2.0,
|
||||||
|
le=2.0,
|
||||||
|
description="频率惩罚"
|
||||||
|
)
|
||||||
|
|
||||||
|
presence_penalty: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
ge=-2.0,
|
||||||
|
le=2.0,
|
||||||
|
description="存在惩罚"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 输出变量定义
|
||||||
|
output_variables: list[VariableDefinition] = Field(
|
||||||
|
default_factory=lambda: [
|
||||||
|
VariableDefinition(
|
||||||
|
name="output",
|
||||||
|
type=VariableType.STRING,
|
||||||
|
description="LLM 生成的文本输出"
|
||||||
|
),
|
||||||
|
VariableDefinition(
|
||||||
|
name="token_usage",
|
||||||
|
type=VariableType.OBJECT,
|
||||||
|
description="Token 使用情况"
|
||||||
|
)
|
||||||
|
],
|
||||||
|
description="输出变量定义(自动生成,通常不需要修改)"
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("messages", "prompt")
|
||||||
|
@classmethod
|
||||||
|
def validate_input_mode(cls, v, info):
|
||||||
|
"""验证输入模式:prompt 和 messages 至少有一个"""
|
||||||
|
# 这个验证在 model_validator 中更合适
|
||||||
|
return v
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"examples": [
|
||||||
|
{
|
||||||
|
"model_id": "uuid-here",
|
||||||
|
"prompt": "请回答:{{ sys.message }}",
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 1000
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_id": "uuid-here",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "你是一个专业的 AI 助手"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "{{ sys.message }}"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 1000
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
247
api/app/core/workflow/nodes/llm/node.py
Normal file
247
api/app/core/workflow/nodes/llm/node.py
Normal file
@@ -0,0 +1,247 @@
|
|||||||
|
"""
|
||||||
|
LLM 节点实现
|
||||||
|
|
||||||
|
调用 LLM 模型进行文本生成。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
|
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||||
|
from app.models import ModelConfig
|
||||||
|
from app.db import get_db, get_db_context
|
||||||
|
from app.models.models_model import ModelApiKey
|
||||||
|
from app.services.model_service import ModelConfigService, ModelApiKeyService
|
||||||
|
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMNode(BaseNode):
|
||||||
|
"""LLM 节点
|
||||||
|
|
||||||
|
支持流式和非流式输出,使用 LangChain 标准的消息格式。
|
||||||
|
|
||||||
|
配置示例(支持多种消息格式):
|
||||||
|
|
||||||
|
1. 简单文本格式:
|
||||||
|
{
|
||||||
|
"type": "llm",
|
||||||
|
"config": {
|
||||||
|
"model_id": "uuid",
|
||||||
|
"prompt": "请分析:{{sys.message}}",
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 1000
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
2. LangChain 消息格式(推荐):
|
||||||
|
{
|
||||||
|
"type": "llm",
|
||||||
|
"config": {
|
||||||
|
"model_id": "uuid",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "你是一个专业的 AI 助手。"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "{{sys.message}}"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 1000
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
支持的角色类型:
|
||||||
|
- system: 系统消息(SystemMessage)
|
||||||
|
- user/human: 用户消息(HumanMessage)
|
||||||
|
- ai/assistant: AI 消息(AIMessage)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _prepare_llm(self, state: WorkflowState) -> tuple[RedBearLLM, list | str]:
|
||||||
|
"""准备 LLM 实例(公共逻辑)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 工作流状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(llm, messages_or_prompt): LLM 实例和消息列表或 prompt 字符串
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 1. 处理消息格式(优先使用 messages)
|
||||||
|
messages_config = self.config.get("messages")
|
||||||
|
|
||||||
|
if messages_config:
|
||||||
|
# 使用 LangChain 消息格式
|
||||||
|
messages = []
|
||||||
|
for msg_config in messages_config:
|
||||||
|
role = msg_config.get("role", "user").lower()
|
||||||
|
content_template = msg_config.get("content", "")
|
||||||
|
content = self._render_template(content_template, state)
|
||||||
|
|
||||||
|
# 根据角色创建对应的消息对象
|
||||||
|
if role == "system":
|
||||||
|
messages.append(SystemMessage(content=content))
|
||||||
|
elif role in ["user", "human"]:
|
||||||
|
messages.append(HumanMessage(content=content))
|
||||||
|
elif role in ["ai", "assistant"]:
|
||||||
|
messages.append(AIMessage(content=content))
|
||||||
|
else:
|
||||||
|
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||||
|
messages.append(HumanMessage(content=content))
|
||||||
|
|
||||||
|
prompt_or_messages = messages
|
||||||
|
else:
|
||||||
|
# 使用简单的 prompt 格式(向后兼容)
|
||||||
|
prompt_template = self.config.get("prompt", "")
|
||||||
|
prompt_or_messages = self._render_template(prompt_template, state)
|
||||||
|
|
||||||
|
# 2. 获取模型配置
|
||||||
|
model_id = self.config.get("model_id")
|
||||||
|
if not model_id:
|
||||||
|
raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置")
|
||||||
|
|
||||||
|
# 3. 在 with 块内完成所有数据库操作和数据提取
|
||||||
|
with get_db_context() as db:
|
||||||
|
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||||
|
|
||||||
|
if not config:
|
||||||
|
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
|
||||||
|
|
||||||
|
if not config.api_keys or len(config.api_keys) == 0:
|
||||||
|
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
||||||
|
|
||||||
|
# 在 Session 关闭前提取所有需要的数据
|
||||||
|
api_config = config.api_keys[0]
|
||||||
|
model_name = api_config.model_name
|
||||||
|
provider = api_config.provider
|
||||||
|
api_key = api_config.api_key
|
||||||
|
api_base = api_config.api_base
|
||||||
|
model_type = config.type
|
||||||
|
|
||||||
|
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||||
|
llm = RedBearLLM(
|
||||||
|
RedBearModelConfig(
|
||||||
|
model_name=model_name,
|
||||||
|
provider=provider,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=api_base
|
||||||
|
),
|
||||||
|
type=model_type
|
||||||
|
)
|
||||||
|
|
||||||
|
return llm, prompt_or_messages
|
||||||
|
|
||||||
|
async def execute(self, state: WorkflowState) -> AIMessage:
|
||||||
|
"""非流式执行 LLM 调用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 工作流状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LLM 响应消息
|
||||||
|
"""
|
||||||
|
llm, prompt_or_messages = self._prepare_llm(state)
|
||||||
|
|
||||||
|
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||||
|
|
||||||
|
# 调用 LLM(支持字符串或消息列表)
|
||||||
|
response = await llm.ainvoke(prompt_or_messages)
|
||||||
|
|
||||||
|
# 提取内容
|
||||||
|
if hasattr(response, 'content'):
|
||||||
|
content = response.content
|
||||||
|
else:
|
||||||
|
content = str(response)
|
||||||
|
|
||||||
|
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
|
||||||
|
|
||||||
|
# 返回 AIMessage(包含响应元数据)
|
||||||
|
return response if isinstance(response, AIMessage) else AIMessage(content=content)
|
||||||
|
|
||||||
|
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
||||||
|
"""提取输入数据(用于记录)"""
|
||||||
|
_, prompt_or_messages = self._prepare_llm(state)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
|
||||||
|
"messages": [
|
||||||
|
{"role": msg.__class__.__name__.replace("Message", "").lower(), "content": msg.content}
|
||||||
|
for msg in prompt_or_messages
|
||||||
|
] if isinstance(prompt_or_messages, list) else None,
|
||||||
|
"config": {
|
||||||
|
"model_id": self.config.get("model_id"),
|
||||||
|
"temperature": self.config.get("temperature"),
|
||||||
|
"max_tokens": self.config.get("max_tokens")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def _extract_output(self, business_result: Any) -> str:
|
||||||
|
"""从 AIMessage 中提取文本内容"""
|
||||||
|
if isinstance(business_result, AIMessage):
|
||||||
|
return business_result.content
|
||||||
|
return str(business_result)
|
||||||
|
|
||||||
|
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||||
|
"""从 AIMessage 中提取 token 使用情况"""
|
||||||
|
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'):
|
||||||
|
usage = business_result.response_metadata.get('token_usage')
|
||||||
|
if usage:
|
||||||
|
return {
|
||||||
|
"prompt_tokens": usage.get('prompt_tokens', 0),
|
||||||
|
"completion_tokens": usage.get('completion_tokens', 0),
|
||||||
|
"total_tokens": usage.get('total_tokens', 0)
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def execute_stream(self, state: WorkflowState):
|
||||||
|
"""流式执行 LLM 调用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 工作流状态
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
文本片段(chunk)或完成标记
|
||||||
|
"""
|
||||||
|
llm, prompt_or_messages = self._prepare_llm(state)
|
||||||
|
|
||||||
|
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||||
|
|
||||||
|
# 累积完整响应
|
||||||
|
full_response = ""
|
||||||
|
last_chunk = None
|
||||||
|
|
||||||
|
# 调用 LLM(流式,支持字符串或消息列表)
|
||||||
|
async for chunk in llm.astream(prompt_or_messages):
|
||||||
|
# 提取内容
|
||||||
|
if hasattr(chunk, 'content'):
|
||||||
|
content = chunk.content
|
||||||
|
else:
|
||||||
|
content = str(chunk)
|
||||||
|
|
||||||
|
full_response += content
|
||||||
|
last_chunk = chunk
|
||||||
|
|
||||||
|
# 流式返回每个文本片段
|
||||||
|
yield content
|
||||||
|
|
||||||
|
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}")
|
||||||
|
|
||||||
|
# 构建完整的 AIMessage(包含元数据)
|
||||||
|
if isinstance(last_chunk, AIMessage):
|
||||||
|
final_message = AIMessage(
|
||||||
|
content=full_response,
|
||||||
|
response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
final_message = AIMessage(content=full_response)
|
||||||
|
|
||||||
|
# yield 完成标记
|
||||||
|
yield {"__final__": True, "result": final_message}
|
||||||
93
api/app/core/workflow/nodes/node_factory.py
Normal file
93
api/app/core/workflow/nodes/node_factory.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
"""
|
||||||
|
节点工厂
|
||||||
|
|
||||||
|
根据节点类型创建相应的节点实例。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
|
from app.core.workflow.nodes.llm import LLMNode
|
||||||
|
from app.core.workflow.nodes.agent import AgentNode
|
||||||
|
from app.core.workflow.nodes.transform import TransformNode
|
||||||
|
from app.core.workflow.nodes.start import StartNode
|
||||||
|
from app.core.workflow.nodes.end import EndNode
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class NodeFactory:
|
||||||
|
"""节点工厂
|
||||||
|
|
||||||
|
使用工厂模式创建节点实例,便于扩展和维护。
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 节点类型注册表
|
||||||
|
_node_types: dict[str, type[BaseNode]] = {
|
||||||
|
NodeType.START: StartNode,
|
||||||
|
NodeType.END: EndNode,
|
||||||
|
NodeType.LLM: LLMNode,
|
||||||
|
NodeType.AGENT: AgentNode,
|
||||||
|
NodeType.TRANSFORM: TransformNode,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_node_type(cls, node_type: str, node_class: type[BaseNode]):
|
||||||
|
"""注册新的节点类型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_type: 节点类型名称
|
||||||
|
node_class: 节点类
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> class CustomNode(BaseNode):
|
||||||
|
... async def execute(self, state):
|
||||||
|
... return {"node_outputs": {self.node_id: {"output": "custom"}}}
|
||||||
|
>>> NodeFactory.register_node_type("custom", CustomNode)
|
||||||
|
"""
|
||||||
|
cls._node_types[node_type] = node_class
|
||||||
|
logger.info(f"注册节点类型: {node_type} -> {node_class.__name__}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_node(
|
||||||
|
cls,
|
||||||
|
node_config: dict[str, Any],
|
||||||
|
workflow_config: dict[str, Any]
|
||||||
|
) -> BaseNode | None:
|
||||||
|
"""创建节点实例
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_config: 节点配置
|
||||||
|
workflow_config: 工作流配置
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
节点实例或 None(对于不支持的节点类型)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 不支持的节点类型
|
||||||
|
"""
|
||||||
|
node_type = node_config.get("type")
|
||||||
|
|
||||||
|
# 跳过条件节点(由 LangGraph 处理)
|
||||||
|
if node_type == "condition":
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 获取节点类
|
||||||
|
node_class = cls._node_types.get(node_type)
|
||||||
|
if not node_class:
|
||||||
|
raise ValueError(f"不支持的节点类型: {node_type}")
|
||||||
|
|
||||||
|
# 创建节点实例
|
||||||
|
logger.debug(f"创建节点: {node_config.get('id')} (type={node_type})")
|
||||||
|
return node_class(node_config, workflow_config)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_supported_types(cls) -> list[str]:
|
||||||
|
"""获取支持的节点类型列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
节点类型列表
|
||||||
|
"""
|
||||||
|
return list(cls._node_types.keys())
|
||||||
6
api/app/core/workflow/nodes/start/__init__.py
Normal file
6
api/app/core/workflow/nodes/start/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
"""Start 节点"""
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.start.node import StartNode
|
||||||
|
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||||
|
|
||||||
|
__all__ = ["StartNode", "StartNodeConfig"]
|
||||||
87
api/app/core/workflow/nodes/start/config.py
Normal file
87
api/app/core/workflow/nodes/start/config.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
"""Start 节点配置"""
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||||
|
|
||||||
|
|
||||||
|
class StartNodeConfig(BaseNodeConfig):
|
||||||
|
"""Start 节点配置
|
||||||
|
|
||||||
|
Start 节点的作用:
|
||||||
|
1. 标记工作流的起点
|
||||||
|
2. 定义自定义输入变量(会作为节点输出,通过 start_node_id.variable_name 访问)
|
||||||
|
3. 输出系统变量和会话变量
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 自定义输入变量定义
|
||||||
|
variables: list[VariableDefinition] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="自定义输入变量列表,这些变量会作为 Start 节点的输出"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 输出变量定义
|
||||||
|
output_variables: list[VariableDefinition] = Field(
|
||||||
|
default_factory=lambda: [
|
||||||
|
VariableDefinition(
|
||||||
|
name="message",
|
||||||
|
type=VariableType.STRING,
|
||||||
|
description="用户输入的消息"
|
||||||
|
),
|
||||||
|
VariableDefinition(
|
||||||
|
name="conversation_vars",
|
||||||
|
type=VariableType.OBJECT,
|
||||||
|
description="会话级变量"
|
||||||
|
),
|
||||||
|
VariableDefinition(
|
||||||
|
name="execution_id",
|
||||||
|
type=VariableType.STRING,
|
||||||
|
description="执行 ID"
|
||||||
|
),
|
||||||
|
VariableDefinition(
|
||||||
|
name="conversation_id",
|
||||||
|
type=VariableType.STRING,
|
||||||
|
description="会话 ID"
|
||||||
|
),
|
||||||
|
VariableDefinition(
|
||||||
|
name="workspace_id",
|
||||||
|
type=VariableType.STRING,
|
||||||
|
description="工作空间 ID"
|
||||||
|
),
|
||||||
|
VariableDefinition(
|
||||||
|
name="user_id",
|
||||||
|
type=VariableType.STRING,
|
||||||
|
description="用户 ID"
|
||||||
|
)
|
||||||
|
],
|
||||||
|
description="输出变量定义(自动生成,通常不需要修改)"
|
||||||
|
)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"examples": [
|
||||||
|
{
|
||||||
|
"description": "工作流开始节点",
|
||||||
|
"variables": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"description": "带自定义变量的开始节点",
|
||||||
|
"variables": [
|
||||||
|
{
|
||||||
|
"name": "language",
|
||||||
|
"type": "string",
|
||||||
|
"required": False,
|
||||||
|
"default": "zh-CN",
|
||||||
|
"description": "语言设置"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "max_length",
|
||||||
|
"type": "number",
|
||||||
|
"required": False,
|
||||||
|
"default": 1000,
|
||||||
|
"description": "最大长度"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
136
api/app/core/workflow/nodes/start/node.py
Normal file
136
api/app/core/workflow/nodes/start/node.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
"""
|
||||||
|
Start 节点实现
|
||||||
|
|
||||||
|
工作流的起始节点,定义输入变量并输出系统参数。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
|
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class StartNode(BaseNode):
|
||||||
|
"""Start 节点
|
||||||
|
|
||||||
|
工作流的起始节点,负责:
|
||||||
|
1. 定义工作流的输入变量(通过配置)
|
||||||
|
2. 输出系统变量(sys.*)
|
||||||
|
3. 输出会话变量(conv.*)
|
||||||
|
|
||||||
|
注意:变量的验证和默认值处理由 Executor 在初始化时完成。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
|
"""初始化 Start 节点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_config: 节点配置
|
||||||
|
workflow_config: 工作流配置
|
||||||
|
"""
|
||||||
|
super().__init__(node_config, workflow_config)
|
||||||
|
|
||||||
|
# 解析并验证配置
|
||||||
|
self.typed_config = StartNodeConfig(**self.config)
|
||||||
|
|
||||||
|
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||||
|
"""执行 start 节点业务逻辑
|
||||||
|
|
||||||
|
Start 节点输出系统变量、会话变量和自定义变量。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 工作流状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含系统参数、会话变量和自定义变量的字典
|
||||||
|
"""
|
||||||
|
logger.info(f"节点 {self.node_id} (Start) 开始执行")
|
||||||
|
|
||||||
|
# 创建变量池实例(在方法内复用)
|
||||||
|
pool = self.get_variable_pool(state)
|
||||||
|
|
||||||
|
# 处理自定义变量(传入 pool 避免重复创建)
|
||||||
|
custom_vars = self._process_custom_variables(pool)
|
||||||
|
|
||||||
|
# 返回业务数据(包含自定义变量)
|
||||||
|
result = {
|
||||||
|
"message": pool.get("sys.message"),
|
||||||
|
"execution_id": pool.get("sys.execution_id"),
|
||||||
|
"conversation_id": pool.get("sys.conversation_id"),
|
||||||
|
"workspace_id": pool.get("sys.workspace_id"),
|
||||||
|
"user_id": pool.get("sys.user_id"),
|
||||||
|
**custom_vars # 自定义变量作为节点输出的一部分
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"节点 {self.node_id} (Start) 执行完成,"
|
||||||
|
f"输出了 {len(custom_vars)} 个自定义变量"
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _process_custom_variables(self, pool) -> dict[str, Any]:
|
||||||
|
"""处理自定义变量
|
||||||
|
|
||||||
|
从输入数据中提取自定义变量,应用默认值和验证。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pool: 变量池实例
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理后的自定义变量字典
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 缺少必需变量
|
||||||
|
"""
|
||||||
|
# 获取输入数据中的自定义变量
|
||||||
|
input_variables = pool.get("sys.input_variables", default={})
|
||||||
|
|
||||||
|
processed = {}
|
||||||
|
|
||||||
|
# 遍历配置的变量定义
|
||||||
|
for var_def in self.typed_config.variables:
|
||||||
|
var_name = var_def.name
|
||||||
|
|
||||||
|
# 检查变量是否存在
|
||||||
|
if var_name in input_variables:
|
||||||
|
# 使用用户提供的值
|
||||||
|
processed[var_name] = input_variables[var_name]
|
||||||
|
|
||||||
|
elif var_def.required:
|
||||||
|
# 必需变量缺失
|
||||||
|
raise ValueError(
|
||||||
|
f"缺少必需的输入变量: {var_name}"
|
||||||
|
+ (f" ({var_def.description})" if var_def.description else "")
|
||||||
|
)
|
||||||
|
|
||||||
|
elif var_def.default is not None:
|
||||||
|
# 使用默认值
|
||||||
|
processed[var_name] = var_def.default
|
||||||
|
logger.debug(
|
||||||
|
f"变量 '{var_name}' 使用默认值: {var_def.default}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return processed
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
||||||
|
"""提取输入数据(用于记录)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 工作流状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
输入数据字典
|
||||||
|
"""
|
||||||
|
pool = self.get_variable_pool(state)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"execution_id": pool.get("sys.execution_id"),
|
||||||
|
"conversation_id": pool.get("sys.conversation_id"),
|
||||||
|
"message": pool.get("sys.message"),
|
||||||
|
"conversation_vars": pool.get_all_conversation_vars()
|
||||||
|
}
|
||||||
6
api/app/core/workflow/nodes/transform/__init__.py
Normal file
6
api/app/core/workflow/nodes/transform/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
"""Transform 节点"""
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.transform.node import TransformNode
|
||||||
|
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
||||||
|
|
||||||
|
__all__ = ["TransformNode", "TransformNodeConfig"]
|
||||||
80
api/app/core/workflow/nodes/transform/config.py
Normal file
80
api/app/core/workflow/nodes/transform/config.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
"""Transform 节点配置"""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||||
|
|
||||||
|
|
||||||
|
class TransformNodeConfig(BaseNodeConfig):
|
||||||
|
"""Transform 节点配置
|
||||||
|
|
||||||
|
用于数据转换和处理。
|
||||||
|
"""
|
||||||
|
|
||||||
|
transform_type: Literal["template", "code", "json"] = Field(
|
||||||
|
default="template",
|
||||||
|
description="转换类型:template(模板), code(代码), json(JSON处理)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 模板模式
|
||||||
|
template: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="转换模板,支持变量引用"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 代码模式
|
||||||
|
code: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Python 代码,用于数据转换"
|
||||||
|
)
|
||||||
|
|
||||||
|
# JSON 模式
|
||||||
|
json_path: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="JSON 路径表达式"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 输入变量
|
||||||
|
inputs: dict[str, str] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="输入变量映射,key 为变量名,value 为变量选择器"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 输出变量
|
||||||
|
output_key: str = Field(
|
||||||
|
default="result",
|
||||||
|
description="输出变量的键名"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 输出变量定义
|
||||||
|
output_variables: list[VariableDefinition] = Field(
|
||||||
|
default_factory=lambda: [
|
||||||
|
VariableDefinition(
|
||||||
|
name="result",
|
||||||
|
type=VariableType.STRING,
|
||||||
|
description="转换后的结果"
|
||||||
|
)
|
||||||
|
],
|
||||||
|
description="输出变量定义(根据 output_key 动态生成)"
|
||||||
|
)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"examples": [
|
||||||
|
{
|
||||||
|
"transform_type": "template",
|
||||||
|
"template": "用户问题:{{ sys.message }}\n回答:{{ llm_qa.output }}",
|
||||||
|
"output_key": "formatted_result"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"transform_type": "code",
|
||||||
|
"code": "result = input_text.upper()",
|
||||||
|
"inputs": {
|
||||||
|
"input_text": "{{ sys.message }}"
|
||||||
|
},
|
||||||
|
"output_key": "uppercase_text"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
60
api/app/core/workflow/nodes/transform/node.py
Normal file
60
api/app/core/workflow/nodes/transform/node.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
"""
|
||||||
|
Transform 节点实现
|
||||||
|
|
||||||
|
数据转换节点,用于处理和转换数据。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformNode(BaseNode):
|
||||||
|
"""数据转换节点
|
||||||
|
|
||||||
|
配置示例:
|
||||||
|
{
|
||||||
|
"type": "transform",
|
||||||
|
"config": {
|
||||||
|
"mapping": {
|
||||||
|
"output_field": "{{node.previous.output}}",
|
||||||
|
"processed": "{{var.input | upper}}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||||
|
"""执行数据转换
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 工作流状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
状态更新字典
|
||||||
|
"""
|
||||||
|
logger.info(f"节点 {self.node_id} 开始执行数据转换")
|
||||||
|
|
||||||
|
# 获取映射配置
|
||||||
|
mapping = self.config.get("mapping", {})
|
||||||
|
|
||||||
|
# 执行数据转换
|
||||||
|
transformed_data = {}
|
||||||
|
for target_key, source_template in mapping.items():
|
||||||
|
# 渲染模板获取值
|
||||||
|
value = self._render_template(str(source_template), state)
|
||||||
|
transformed_data[target_key] = value
|
||||||
|
|
||||||
|
logger.info(f"节点 {self.node_id} 数据转换完成,输出字段: {list(transformed_data.keys())}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"node_outputs": {
|
||||||
|
self.node_id: {
|
||||||
|
"output": transformed_data,
|
||||||
|
"status": "completed"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
170
api/app/core/workflow/template_loader.py
Normal file
170
api/app/core/workflow/template_loader.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
"""
|
||||||
|
工作流模板加载器
|
||||||
|
|
||||||
|
从文件系统加载预定义的工作流模板
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class TemplateLoader:
|
||||||
|
"""工作流模板加载器"""
|
||||||
|
|
||||||
|
def __init__(self, templates_dir: str = "app/templates/workflows"):
|
||||||
|
"""初始化模板加载器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
templates_dir: 模板目录路径
|
||||||
|
"""
|
||||||
|
self.templates_dir = Path(templates_dir)
|
||||||
|
if not self.templates_dir.exists():
|
||||||
|
raise ValueError(f"模板目录不存在: {templates_dir}")
|
||||||
|
|
||||||
|
def list_templates(self) -> list[dict]:
|
||||||
|
"""列出所有可用的模板
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
模板列表,每个模板包含 id, name, description 等信息
|
||||||
|
"""
|
||||||
|
templates = []
|
||||||
|
|
||||||
|
# 遍历模板目录
|
||||||
|
for template_dir in self.templates_dir.iterdir():
|
||||||
|
if not template_dir.is_dir():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查是否有 template.yml 文件
|
||||||
|
template_file = template_dir / "template.yml"
|
||||||
|
if not template_file.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 读取模板配置
|
||||||
|
with open(template_file, 'r', encoding='utf-8') as f:
|
||||||
|
template_data = yaml.safe_load(f)
|
||||||
|
|
||||||
|
# 提取模板信息
|
||||||
|
templates.append({
|
||||||
|
"id": template_dir.name,
|
||||||
|
"name": template_data.get("name", template_dir.name),
|
||||||
|
"description": template_data.get("description", ""),
|
||||||
|
"category": template_data.get("category", "general"),
|
||||||
|
"tags": template_data.get("tags", []),
|
||||||
|
"author": template_data.get("author", ""),
|
||||||
|
"version": template_data.get("version", "1.0.0")
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
print(f"加载模板 {template_dir.name} 失败: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
return templates
|
||||||
|
|
||||||
|
def load_template(self, template_id: str) -> Optional[dict]:
|
||||||
|
"""加载指定的模板
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_id: 模板 ID(目录名)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
模板配置字典,如果模板不存在则返回 None
|
||||||
|
"""
|
||||||
|
template_dir = self.templates_dir / template_id
|
||||||
|
template_file = template_dir / "template.yml"
|
||||||
|
|
||||||
|
if not template_file.exists():
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(template_file, 'r', encoding='utf-8') as f:
|
||||||
|
template_data = yaml.safe_load(f)
|
||||||
|
|
||||||
|
# 返回工作流配置部分
|
||||||
|
return {
|
||||||
|
"name": template_data.get("name", template_id),
|
||||||
|
"description": template_data.get("description", ""),
|
||||||
|
"nodes": template_data.get("nodes", []),
|
||||||
|
"edges": template_data.get("edges", []),
|
||||||
|
"variables": template_data.get("variables", []),
|
||||||
|
"execution_config": template_data.get("execution_config", {}),
|
||||||
|
"triggers": template_data.get("triggers", [])
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
print(f"加载模板 {template_id} 失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_template_readme(self, template_id: str) -> Optional[str]:
|
||||||
|
"""获取模板的 README 文档
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_id: 模板 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
README 内容,如果不存在则返回 None
|
||||||
|
"""
|
||||||
|
template_dir = self.templates_dir / template_id
|
||||||
|
readme_file = template_dir / "README.md"
|
||||||
|
|
||||||
|
if not readme_file.exists():
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(readme_file, 'r', encoding='utf-8') as f:
|
||||||
|
return f.read()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"读取模板 {template_id} 的 README 失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# 全局模板加载器实例
|
||||||
|
_template_loader: Optional[TemplateLoader] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_template_loader() -> TemplateLoader:
|
||||||
|
"""获取全局模板加载器实例
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TemplateLoader 实例
|
||||||
|
"""
|
||||||
|
global _template_loader
|
||||||
|
if _template_loader is None:
|
||||||
|
_template_loader = TemplateLoader()
|
||||||
|
return _template_loader
|
||||||
|
|
||||||
|
|
||||||
|
def list_workflow_templates() -> list[dict]:
|
||||||
|
"""列出所有工作流模板
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
模板列表
|
||||||
|
"""
|
||||||
|
loader = get_template_loader()
|
||||||
|
return loader.list_templates()
|
||||||
|
|
||||||
|
|
||||||
|
def load_workflow_template(template_id: str) -> Optional[dict]:
|
||||||
|
"""加载工作流模板
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_id: 模板 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
模板配置,如果不存在则返回 None
|
||||||
|
"""
|
||||||
|
loader = get_template_loader()
|
||||||
|
return loader.load_template(template_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_workflow_template_readme(template_id: str) -> Optional[str]:
|
||||||
|
"""获取工作流模板的 README
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_id: 模板 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
README 内容,如果不存在则返回 None
|
||||||
|
"""
|
||||||
|
loader = get_template_loader()
|
||||||
|
return loader.get_template_readme(template_id)
|
||||||
170
api/app/core/workflow/template_renderer.py
Normal file
170
api/app/core/workflow/template_renderer.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
"""
|
||||||
|
模板渲染器
|
||||||
|
|
||||||
|
使用 Jinja2 提供安全的模板渲染功能,支持变量引用和表达式。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from jinja2 import Template, TemplateSyntaxError, UndefinedError, Environment, StrictUndefined
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TemplateRenderer:
|
||||||
|
"""模板渲染器"""
|
||||||
|
|
||||||
|
def __init__(self, strict: bool = True):
|
||||||
|
"""初始化渲染器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
strict: 是否使用严格模式(未定义变量会抛出异常)
|
||||||
|
"""
|
||||||
|
self.env = Environment(
|
||||||
|
undefined=StrictUndefined if strict else None,
|
||||||
|
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
|
||||||
|
)
|
||||||
|
|
||||||
|
def render(
|
||||||
|
self,
|
||||||
|
template: str,
|
||||||
|
variables: dict[str, Any],
|
||||||
|
node_outputs: dict[str, Any],
|
||||||
|
system_vars: dict[str, Any] | None = None
|
||||||
|
) -> str:
|
||||||
|
"""渲染模板
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template: 模板字符串
|
||||||
|
variables: 用户定义的变量
|
||||||
|
node_outputs: 节点输出结果
|
||||||
|
system_vars: 系统变量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
渲染后的字符串
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 模板语法错误或变量未定义
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> renderer = TemplateRenderer()
|
||||||
|
>>> renderer.render(
|
||||||
|
... "Hello {{var.name}}!",
|
||||||
|
... {"name": "World"},
|
||||||
|
... {},
|
||||||
|
... {}
|
||||||
|
... )
|
||||||
|
'Hello World!'
|
||||||
|
|
||||||
|
>>> renderer.render(
|
||||||
|
... "分析结果: {{node.analyze.output}}",
|
||||||
|
... {},
|
||||||
|
... {"analyze": {"output": "正面情绪"}},
|
||||||
|
... {}
|
||||||
|
... )
|
||||||
|
'分析结果: 正面情绪'
|
||||||
|
"""
|
||||||
|
# 构建命名空间上下文
|
||||||
|
context = {
|
||||||
|
"var": variables, # 用户变量:{{var.user_input}}
|
||||||
|
"node": node_outputs, # 节点输出:{{node.node_1.output}}
|
||||||
|
"sys": system_vars or {}, # 系统变量:{{sys.execution_id}}
|
||||||
|
}
|
||||||
|
|
||||||
|
# 支持直接通过节点ID访问节点输出:{{llm_qa.output}}
|
||||||
|
# 将所有节点输出添加到顶层上下文
|
||||||
|
context.update(node_outputs)
|
||||||
|
|
||||||
|
# 为了向后兼容,也支持直接访问用户变量
|
||||||
|
context.update(variables)
|
||||||
|
context["nodes"] = node_outputs # 旧语法兼容
|
||||||
|
|
||||||
|
try:
|
||||||
|
tmpl = self.env.from_string(template)
|
||||||
|
return tmpl.render(**context)
|
||||||
|
|
||||||
|
except TemplateSyntaxError as e:
|
||||||
|
logger.error(f"模板语法错误: {template}, 错误: {e}")
|
||||||
|
raise ValueError(f"模板语法错误: {e}")
|
||||||
|
|
||||||
|
except UndefinedError as e:
|
||||||
|
logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}")
|
||||||
|
raise ValueError(f"未定义的变量: {e}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"模板渲染异常: {template}, 错误: {e}")
|
||||||
|
raise ValueError(f"模板渲染失败: {e}")
|
||||||
|
|
||||||
|
def validate(self, template: str) -> list[str]:
|
||||||
|
"""验证模板语法
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template: 模板字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
错误列表,如果为空则验证通过
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> renderer = TemplateRenderer()
|
||||||
|
>>> renderer.validate("Hello {{var.name}}!")
|
||||||
|
[]
|
||||||
|
|
||||||
|
>>> renderer.validate("Hello {{var.name") # 缺少结束标记
|
||||||
|
['模板语法错误: ...']
|
||||||
|
"""
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.env.from_string(template)
|
||||||
|
except TemplateSyntaxError as e:
|
||||||
|
errors.append(f"模板语法错误: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(f"模板验证失败: {e}")
|
||||||
|
|
||||||
|
return errors
|
||||||
|
|
||||||
|
|
||||||
|
# 全局渲染器实例(严格模式)
|
||||||
|
_default_renderer = TemplateRenderer(strict=True)
|
||||||
|
|
||||||
|
|
||||||
|
def render_template(
|
||||||
|
template: str,
|
||||||
|
variables: dict[str, Any],
|
||||||
|
node_outputs: dict[str, Any],
|
||||||
|
system_vars: dict[str, Any] | None = None
|
||||||
|
) -> str:
|
||||||
|
"""渲染模板(便捷函数)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template: 模板字符串
|
||||||
|
variables: 用户变量
|
||||||
|
node_outputs: 节点输出
|
||||||
|
system_vars: 系统变量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
渲染后的字符串
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> render_template(
|
||||||
|
... "请分析: {{var.text}}",
|
||||||
|
... {"text": "这是一段文本"},
|
||||||
|
... {},
|
||||||
|
... {}
|
||||||
|
... )
|
||||||
|
'请分析: 这是一段文本'
|
||||||
|
"""
|
||||||
|
return _default_renderer.render(template, variables, node_outputs, system_vars)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_template(template: str) -> list[str]:
|
||||||
|
"""验证模板语法(便捷函数)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template: 模板字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
错误列表
|
||||||
|
"""
|
||||||
|
return _default_renderer.validate(template)
|
||||||
277
api/app/core/workflow/validator.py
Normal file
277
api/app/core/workflow/validator.py
Normal file
@@ -0,0 +1,277 @@
|
|||||||
|
"""
|
||||||
|
工作流配置验证器
|
||||||
|
|
||||||
|
验证工作流配置的有效性,确保配置符合规范。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowValidator:
|
||||||
|
"""工作流配置验证器"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate(workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
|
||||||
|
"""验证工作流配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(is_valid, errors): 是否有效和错误列表
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> config = {
|
||||||
|
... "nodes": [
|
||||||
|
... {"id": "start", "type": "start"},
|
||||||
|
... {"id": "end", "type": "end"}
|
||||||
|
... ],
|
||||||
|
... "edges": [
|
||||||
|
... {"source": "start", "target": "end"}
|
||||||
|
... ]
|
||||||
|
... }
|
||||||
|
>>> is_valid, errors = WorkflowValidator.validate(config)
|
||||||
|
>>> is_valid
|
||||||
|
True
|
||||||
|
"""
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
# 支持字典和 Pydantic 模型
|
||||||
|
if isinstance(workflow_config, dict):
|
||||||
|
nodes = workflow_config.get("nodes", [])
|
||||||
|
edges = workflow_config.get("edges", [])
|
||||||
|
variables = workflow_config.get("variables", [])
|
||||||
|
else:
|
||||||
|
# Pydantic 模型
|
||||||
|
nodes = getattr(workflow_config, "nodes", [])
|
||||||
|
edges = getattr(workflow_config, "edges", [])
|
||||||
|
variables = getattr(workflow_config, "variables", [])
|
||||||
|
|
||||||
|
# 1. 验证 start 节点(有且只有一个)
|
||||||
|
start_nodes = [n for n in nodes if n.get("type") == "start"]
|
||||||
|
if len(start_nodes) == 0:
|
||||||
|
errors.append("工作流必须有一个 start 节点")
|
||||||
|
elif len(start_nodes) > 1:
|
||||||
|
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个")
|
||||||
|
|
||||||
|
# 2. 验证 end 节点(至少一个)
|
||||||
|
end_nodes = [n for n in nodes if n.get("type") == "end"]
|
||||||
|
if len(end_nodes) == 0:
|
||||||
|
errors.append("工作流必须至少有一个 end 节点")
|
||||||
|
|
||||||
|
# 3. 验证节点 ID 唯一性
|
||||||
|
node_ids = [n.get("id") for n in nodes]
|
||||||
|
if len(node_ids) != len(set(node_ids)):
|
||||||
|
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
|
||||||
|
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")
|
||||||
|
|
||||||
|
# 4. 验证节点必须有 id 和 type
|
||||||
|
for i, node in enumerate(nodes):
|
||||||
|
if not node.get("id"):
|
||||||
|
errors.append(f"节点 #{i} 缺少 id 字段")
|
||||||
|
if not node.get("type"):
|
||||||
|
errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段")
|
||||||
|
|
||||||
|
# 5. 验证边的有效性
|
||||||
|
node_id_set = set(node_ids)
|
||||||
|
for i, edge in enumerate(edges):
|
||||||
|
source = edge.get("source")
|
||||||
|
target = edge.get("target")
|
||||||
|
|
||||||
|
if not source:
|
||||||
|
errors.append(f"边 #{i} 缺少 source 字段")
|
||||||
|
elif source not in node_id_set:
|
||||||
|
errors.append(f"边 #{i} 的 source 节点不存在: {source}")
|
||||||
|
|
||||||
|
if not target:
|
||||||
|
errors.append(f"边 #{i} 缺少 target 字段")
|
||||||
|
elif target not in node_id_set:
|
||||||
|
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
|
||||||
|
|
||||||
|
# 6. 验证所有节点可达(从 start 节点出发)
|
||||||
|
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
||||||
|
reachable = WorkflowValidator._get_reachable_nodes(
|
||||||
|
start_nodes[0]["id"],
|
||||||
|
edges
|
||||||
|
)
|
||||||
|
unreachable = node_id_set - reachable
|
||||||
|
if unreachable:
|
||||||
|
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
|
||||||
|
|
||||||
|
# 7. 检测循环依赖(非 loop 节点)
|
||||||
|
if not errors: # 只有在前面验证通过时才检查循环
|
||||||
|
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
|
||||||
|
if has_cycle:
|
||||||
|
errors.append(
|
||||||
|
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 8. 验证变量名
|
||||||
|
from app.core.workflow.expression_evaluator import ExpressionEvaluator
|
||||||
|
var_errors = ExpressionEvaluator.validate_variable_names(variables)
|
||||||
|
errors.extend(var_errors)
|
||||||
|
|
||||||
|
return len(errors) == 0, errors
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
|
||||||
|
"""获取从 start 节点可达的所有节点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_id: 起始节点 ID
|
||||||
|
edges: 边列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
可达节点 ID 集合
|
||||||
|
"""
|
||||||
|
reachable = {start_id}
|
||||||
|
queue = [start_id]
|
||||||
|
|
||||||
|
while queue:
|
||||||
|
current = queue.pop(0)
|
||||||
|
for edge in edges:
|
||||||
|
if edge.get("source") == current:
|
||||||
|
target = edge.get("target")
|
||||||
|
if target and target not in reachable:
|
||||||
|
reachable.add(target)
|
||||||
|
queue.append(target)
|
||||||
|
|
||||||
|
return reachable
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _has_cycle(nodes: list[dict], edges: list[dict]) -> tuple[bool, list[str]]:
|
||||||
|
"""检测是否存在循环依赖(DFS)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
nodes: 节点列表
|
||||||
|
edges: 边列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(has_cycle, cycle_path): 是否有循环和循环路径
|
||||||
|
"""
|
||||||
|
# 排除 loop 类型的节点
|
||||||
|
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
|
||||||
|
|
||||||
|
# 构建邻接表(排除 loop 节点的边和错误边)
|
||||||
|
graph: dict[str, list[str]] = {}
|
||||||
|
for edge in edges:
|
||||||
|
source = edge.get("source")
|
||||||
|
target = edge.get("target")
|
||||||
|
edge_type = edge.get("type")
|
||||||
|
|
||||||
|
# 跳过错误边
|
||||||
|
if edge_type == "error":
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 如果涉及 loop 节点,跳过
|
||||||
|
if source in loop_nodes or target in loop_nodes:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if source and target:
|
||||||
|
if source not in graph:
|
||||||
|
graph[source] = []
|
||||||
|
graph[source].append(target)
|
||||||
|
|
||||||
|
# DFS 检测环
|
||||||
|
visited = set()
|
||||||
|
rec_stack = set()
|
||||||
|
path = []
|
||||||
|
cycle_path = []
|
||||||
|
|
||||||
|
def dfs(node: str) -> bool:
|
||||||
|
"""DFS 检测环,返回是否找到环"""
|
||||||
|
visited.add(node)
|
||||||
|
rec_stack.add(node)
|
||||||
|
path.append(node)
|
||||||
|
|
||||||
|
for neighbor in graph.get(node, []):
|
||||||
|
if neighbor not in visited:
|
||||||
|
if dfs(neighbor):
|
||||||
|
return True
|
||||||
|
elif neighbor in rec_stack:
|
||||||
|
# 找到环,记录环路径
|
||||||
|
cycle_start = path.index(neighbor)
|
||||||
|
cycle_path.extend([*path[cycle_start:], neighbor])
|
||||||
|
return True
|
||||||
|
|
||||||
|
rec_stack.remove(node)
|
||||||
|
path.pop()
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查所有节点
|
||||||
|
for node_id in graph:
|
||||||
|
if node_id not in visited:
|
||||||
|
if dfs(node_id):
|
||||||
|
return True, cycle_path
|
||||||
|
|
||||||
|
return False, []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_for_publish(workflow_config: dict[str, Any]) -> tuple[bool, list[str]]:
|
||||||
|
"""验证工作流配置是否可以发布(更严格的验证)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_config: 工作流配置
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(is_valid, errors): 是否有效和错误列表
|
||||||
|
"""
|
||||||
|
# 先执行基础验证
|
||||||
|
is_valid, errors = WorkflowValidator.validate(workflow_config)
|
||||||
|
|
||||||
|
if not is_valid:
|
||||||
|
return False, errors
|
||||||
|
|
||||||
|
# 额外的发布验证
|
||||||
|
nodes = workflow_config.get("nodes", [])
|
||||||
|
|
||||||
|
# 1. 验证所有节点都有名称
|
||||||
|
for node in nodes:
|
||||||
|
if node.get("type") not in ["start", "end"] and not node.get("name"):
|
||||||
|
errors.append(
|
||||||
|
f"节点 {node.get('id')} 缺少名称(发布时必须提供)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. 验证所有非 start/end 节点都有配置
|
||||||
|
for node in nodes:
|
||||||
|
node_type = node.get("type")
|
||||||
|
if node_type not in ["start", "end"]:
|
||||||
|
config = node.get("config")
|
||||||
|
if not config or not isinstance(config, dict):
|
||||||
|
errors.append(
|
||||||
|
f"节点 {node.get('id')} 缺少配置(发布时必须提供)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 验证必填变量
|
||||||
|
variables = workflow_config.get("variables", [])
|
||||||
|
required_vars = [v for v in variables if v.get("required")]
|
||||||
|
if required_vars:
|
||||||
|
# 这里只是提示,实际执行时会检查
|
||||||
|
logger.info(
|
||||||
|
f"工作流包含 {len(required_vars)} 个必填变量: "
|
||||||
|
f"{[v.get('name') for v in required_vars]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return len(errors) == 0, errors
|
||||||
|
|
||||||
|
|
||||||
|
def validate_workflow_config(
|
||||||
|
workflow_config: dict[str, Any],
|
||||||
|
for_publish: bool = False
|
||||||
|
) -> tuple[bool, list[str]]:
|
||||||
|
"""验证工作流配置(便捷函数)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_config: 工作流配置
|
||||||
|
for_publish: 是否为发布验证(更严格)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(is_valid, errors): 是否有效和错误列表
|
||||||
|
"""
|
||||||
|
if for_publish:
|
||||||
|
return WorkflowValidator.validate_for_publish(workflow_config)
|
||||||
|
else:
|
||||||
|
return WorkflowValidator.validate(workflow_config)
|
||||||
293
api/app/core/workflow/variable_pool.py
Normal file
293
api/app/core/workflow/variable_pool.py
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
"""
|
||||||
|
变量池 (Variable Pool)
|
||||||
|
|
||||||
|
工作流执行的数据中心,管理所有变量的存储和访问。
|
||||||
|
|
||||||
|
变量类型:
|
||||||
|
1. 系统变量 (sys.*) - 系统内置变量(execution_id, workspace_id, user_id, message 等)
|
||||||
|
2. 节点输出 (node_id.*) - 节点执行结果
|
||||||
|
3. 会话变量 (conv.*) - 会话级变量(跨多轮对话保持)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class VariableSelector:
|
||||||
|
"""变量选择器
|
||||||
|
|
||||||
|
用于引用变量的路径表示。
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> selector = VariableSelector(["sys", "message"])
|
||||||
|
>>> selector = VariableSelector(["node_A", "output"])
|
||||||
|
>>> selector = VariableSelector.from_string("sys.message")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, path: list[str]):
|
||||||
|
"""初始化变量选择器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: 变量路径,如 ["sys", "message"] 或 ["node_A", "output"]
|
||||||
|
"""
|
||||||
|
if not path or len(path) < 1:
|
||||||
|
raise ValueError("变量路径不能为空")
|
||||||
|
|
||||||
|
self.path = path
|
||||||
|
self.namespace = path[0] # sys, var, 或 node_id
|
||||||
|
self.key = path[1] if len(path) > 1 else None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_string(cls, selector_str: str) -> "VariableSelector":
|
||||||
|
"""从字符串创建选择器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
selector_str: 选择器字符串,如 "sys.message" 或 "node_A.output"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
VariableSelector 实例
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> selector = VariableSelector.from_string("sys.message")
|
||||||
|
>>> selector = VariableSelector.from_string("llm_qa.output")
|
||||||
|
"""
|
||||||
|
path = selector_str.split(".")
|
||||||
|
return cls(path)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return ".".join(self.path)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"VariableSelector({self.path})"
|
||||||
|
|
||||||
|
|
||||||
|
class VariablePool:
|
||||||
|
"""变量池
|
||||||
|
|
||||||
|
管理工作流执行过程中的所有变量。
|
||||||
|
|
||||||
|
变量命名空间:
|
||||||
|
- sys.*: 系统变量(message, execution_id, workspace_id, user_id, conversation_id)
|
||||||
|
- conv.*: 会话变量(跨多轮对话保持的变量)
|
||||||
|
- <node_id>.*: 节点输出
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> pool = VariablePool(state)
|
||||||
|
>>> pool.get(["sys", "message"])
|
||||||
|
"用户的问题"
|
||||||
|
>>> pool.get(["llm_qa", "output"])
|
||||||
|
"AI 的回答"
|
||||||
|
>>> pool.set(["conv", "user_name"], "张三")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, state: dict[str, Any]):
|
||||||
|
"""初始化变量池
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: 工作流状态(LangGraph State)
|
||||||
|
"""
|
||||||
|
self.state = state
|
||||||
|
|
||||||
|
def get(self, selector: list[str] | str, default: Any = None) -> Any:
|
||||||
|
"""获取变量值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
selector: 变量选择器,可以是列表或字符串
|
||||||
|
default: 默认值(变量不存在时返回)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
变量值
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> pool.get(["sys", "message"])
|
||||||
|
>>> pool.get("sys.message")
|
||||||
|
>>> pool.get(["llm_qa", "output"])
|
||||||
|
>>> pool.get("llm_qa.output")
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: 变量不存在且未提供默认值
|
||||||
|
"""
|
||||||
|
# 转换为 VariableSelector
|
||||||
|
if isinstance(selector, str):
|
||||||
|
selector = VariableSelector.from_string(selector).path
|
||||||
|
|
||||||
|
if not selector or len(selector) < 1:
|
||||||
|
raise ValueError("变量选择器不能为空")
|
||||||
|
|
||||||
|
namespace = selector[0]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 系统变量
|
||||||
|
if namespace == "sys":
|
||||||
|
key = selector[1] if len(selector) > 1 else None
|
||||||
|
if not key:
|
||||||
|
return self.state.get("variables", {}).get("sys", {})
|
||||||
|
return self.state.get("variables", {}).get("sys", {}).get(key, default)
|
||||||
|
|
||||||
|
# 会话变量
|
||||||
|
elif namespace == "conv":
|
||||||
|
key = selector[1] if len(selector) > 1 else None
|
||||||
|
if not key:
|
||||||
|
return self.state.get("variables", {}).get("conv", {})
|
||||||
|
return self.state.get("variables", {}).get("conv", {}).get(key, default)
|
||||||
|
|
||||||
|
# 节点输出(从 runtime_vars 读取)
|
||||||
|
else:
|
||||||
|
node_id = namespace
|
||||||
|
runtime_vars = self.state.get("runtime_vars", {})
|
||||||
|
|
||||||
|
if node_id not in runtime_vars:
|
||||||
|
if default is not None:
|
||||||
|
return default
|
||||||
|
raise KeyError(f"节点 '{node_id}' 的输出不存在")
|
||||||
|
|
||||||
|
node_var = runtime_vars[node_id]
|
||||||
|
|
||||||
|
# 如果只有节点 ID,返回整个变量
|
||||||
|
if len(selector) == 1:
|
||||||
|
return node_var
|
||||||
|
|
||||||
|
# 获取特定字段
|
||||||
|
# 支持嵌套访问,如 node_id.field.subfield
|
||||||
|
result = node_var
|
||||||
|
for k in selector[1:]:
|
||||||
|
if isinstance(result, dict):
|
||||||
|
result = result.get(k)
|
||||||
|
if result is None:
|
||||||
|
if default is not None:
|
||||||
|
return default
|
||||||
|
raise KeyError(f"字段 '{'.'.join(selector)}' 不存在")
|
||||||
|
else:
|
||||||
|
if default is not None:
|
||||||
|
return default
|
||||||
|
raise KeyError(f"无法访问 '{'.'.join(selector)}'")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except KeyError:
|
||||||
|
if default is not None:
|
||||||
|
return default
|
||||||
|
raise
|
||||||
|
|
||||||
|
def set(self, selector: list[str] | str, value: Any):
|
||||||
|
"""设置变量值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
selector: 变量选择器
|
||||||
|
value: 变量值
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> pool.set(["conv", "user_name"], "张三")
|
||||||
|
>>> pool.set("conv.user_name", "张三")
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- 只能设置会话变量 (conv.*)
|
||||||
|
- 系统变量和节点输出是只读的
|
||||||
|
"""
|
||||||
|
# 转换为 VariableSelector
|
||||||
|
if isinstance(selector, str):
|
||||||
|
selector = VariableSelector.from_string(selector).path
|
||||||
|
|
||||||
|
if not selector or len(selector) < 2:
|
||||||
|
raise ValueError("变量选择器必须包含命名空间和键名")
|
||||||
|
|
||||||
|
namespace = selector[0]
|
||||||
|
|
||||||
|
if namespace != "conv":
|
||||||
|
raise ValueError("只能设置会话变量 (conv.*)")
|
||||||
|
|
||||||
|
key = selector[1]
|
||||||
|
|
||||||
|
# 确保 variables 结构存在
|
||||||
|
if "variables" not in self.state:
|
||||||
|
self.state["variables"] = {"sys": {}, "conv": {}}
|
||||||
|
if "conv" not in self.state["variables"]:
|
||||||
|
self.state["variables"]["conv"] = {}
|
||||||
|
|
||||||
|
# 设置值
|
||||||
|
self.state["variables"]["conv"][key] = value
|
||||||
|
|
||||||
|
logger.debug(f"设置变量: {'.'.join(selector)} = {value}")
|
||||||
|
|
||||||
|
def has(self, selector: list[str] | str) -> bool:
|
||||||
|
"""检查变量是否存在
|
||||||
|
|
||||||
|
Args:
|
||||||
|
selector: 变量选择器
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
变量是否存在
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> pool.has(["sys", "message"])
|
||||||
|
True
|
||||||
|
>>> pool.has("llm_qa.output")
|
||||||
|
False
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.get(selector)
|
||||||
|
return True
|
||||||
|
except KeyError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_all_system_vars(self) -> dict[str, Any]:
|
||||||
|
"""获取所有系统变量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
系统变量字典
|
||||||
|
"""
|
||||||
|
return self.state.get("variables", {}).get("sys", {})
|
||||||
|
|
||||||
|
def get_all_conversation_vars(self) -> dict[str, Any]:
|
||||||
|
"""获取所有会话变量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
会话变量字典
|
||||||
|
"""
|
||||||
|
return self.state.get("variables", {}).get("conv", {})
|
||||||
|
|
||||||
|
def get_all_node_outputs(self) -> dict[str, Any]:
|
||||||
|
"""获取所有节点输出(运行时变量)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
节点输出字典,键为节点 ID
|
||||||
|
"""
|
||||||
|
return self.state.get("runtime_vars", {})
|
||||||
|
|
||||||
|
def get_node_output(self, node_id: str) -> dict[str, Any] | None:
|
||||||
|
"""获取指定节点的输出(运行时变量)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_id: 节点 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
节点输出或 None
|
||||||
|
"""
|
||||||
|
return self.state.get("runtime_vars", {}).get(node_id)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""导出为字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含所有变量的字典
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"system": self.get_all_system_vars(),
|
||||||
|
"conversation": self.get_all_conversation_vars(),
|
||||||
|
"nodes": self.get_all_node_outputs() # 从 runtime_vars 读取
|
||||||
|
}
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
sys_vars = self.get_all_system_vars()
|
||||||
|
conv_vars = self.get_all_conversation_vars()
|
||||||
|
runtime_vars = self.get_all_node_outputs()
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"VariablePool(\n"
|
||||||
|
f" system_vars={len(sys_vars)},\n"
|
||||||
|
f" conversation_vars={len(conv_vars)},\n"
|
||||||
|
f" runtime_vars={len(runtime_vars)}\n"
|
||||||
|
f")"
|
||||||
|
)
|
||||||
@@ -1,10 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
from dotenv import load_dotenv
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, Request
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from app.core.config import settings
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from app.core.response_utils import fail
|
from app.core.response_utils import fail
|
||||||
from app.core.logging_config import LoggingConfig, get_logger
|
from app.core.logging_config import LoggingConfig, get_logger
|
||||||
@@ -38,9 +37,13 @@ router = APIRouter(prefix="/memory", tags=["Memory"])
|
|||||||
|
|
||||||
# 管理端 API (JWT 认证)
|
# 管理端 API (JWT 认证)
|
||||||
from app.controllers import manager_router
|
from app.controllers import manager_router
|
||||||
|
|
||||||
# 服务端 API (API Key 认证)
|
# 服务端 API (API Key 认证)
|
||||||
from app.controllers.service import service_router
|
from app.controllers.service import service_router
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.error_codes import BizCode, HTTP_MAPPING
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.logging_config import LoggingConfig, get_logger
|
||||||
|
from app.core.response_utils import fail
|
||||||
|
|
||||||
# Initialize logging system
|
# Initialize logging system
|
||||||
LoggingConfig.setup_logging()
|
LoggingConfig.setup_logging()
|
||||||
@@ -414,5 +417,4 @@ async def unhandled_exception_handler(request: Request, exc: Exception):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||||
|
|||||||
@@ -15,9 +15,11 @@ from .end_user_model import EndUser
|
|||||||
from .appshare_model import AppShare
|
from .appshare_model import AppShare
|
||||||
from .release_share_model import ReleaseShare
|
from .release_share_model import ReleaseShare
|
||||||
from .conversation_model import Conversation, Message
|
from .conversation_model import Conversation, Message
|
||||||
from .api_key_model import ApiKey, ApiKeyLog, ApiKeyType, ResourceType
|
from .api_key_model import ApiKey, ApiKeyLog, ApiKeyType
|
||||||
from .data_config_model import DataConfig
|
from .data_config_model import DataConfig
|
||||||
from .multi_agent_model import MultiAgentConfig, AgentInvocation
|
from .multi_agent_model import MultiAgentConfig, AgentInvocation
|
||||||
|
from .workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution
|
||||||
|
from .retrieval_info import RetrievalInfo
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Tenants",
|
"Tenants",
|
||||||
@@ -46,8 +48,11 @@ __all__ = [
|
|||||||
"ApiKey",
|
"ApiKey",
|
||||||
"ApiKeyLog",
|
"ApiKeyLog",
|
||||||
"ApiKeyType",
|
"ApiKeyType",
|
||||||
"ResourceType",
|
|
||||||
"DataConfig",
|
"DataConfig",
|
||||||
"MultiAgentConfig",
|
"MultiAgentConfig",
|
||||||
"AgentInvocation"
|
"AgentInvocation",
|
||||||
|
"WorkflowConfig",
|
||||||
|
"WorkflowExecution",
|
||||||
|
"WorkflowNodeExecution",
|
||||||
|
"RetrievalInfo"
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from sqlalchemy import Column, String, Boolean, DateTime, Integer, ForeignKey, Text, Enum
|
from sqlalchemy import Column, String, Boolean, DateTime, Integer, ForeignKey, Text
|
||||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
@@ -12,18 +12,10 @@ from app.db import Base
|
|||||||
|
|
||||||
class ApiKeyType(StrEnum):
|
class ApiKeyType(StrEnum):
|
||||||
"""API Key 类型"""
|
"""API Key 类型"""
|
||||||
APP = "app" # 应用 API Key
|
AGENT = "agent" # 智能体
|
||||||
RAG = "rag" # RAG API Key
|
CLUSTER = "cluster" # 集群
|
||||||
MEMORY = "memory" # Memory API Key
|
WORKFLOW = "workflow" # 工作流
|
||||||
|
SERVICE = "service" # 服务
|
||||||
|
|
||||||
class ResourceType(StrEnum):
|
|
||||||
"""资源类型枚举"""
|
|
||||||
AGENT = "Agent" # 智能体
|
|
||||||
CLUSTER = "Cluster" # 集群
|
|
||||||
WORKFLOW = "Workflow" # 工作流
|
|
||||||
KNOWLEDGE = "Knowledge" # 知识库
|
|
||||||
MEMORY_ENGINE = "Memory_Engine" # 记忆引擎
|
|
||||||
|
|
||||||
|
|
||||||
class ApiKey(Base):
|
class ApiKey(Base):
|
||||||
@@ -35,18 +27,16 @@ class ApiKey(Base):
|
|||||||
# 基本信息
|
# 基本信息
|
||||||
name = Column(String(255), nullable=False, comment="API Key 名称")
|
name = Column(String(255), nullable=False, comment="API Key 名称")
|
||||||
description = Column(Text, comment="描述")
|
description = Column(Text, comment="描述")
|
||||||
key_prefix = Column(String(20), nullable=False, comment="Key 前缀")
|
api_key = Column(String(255), nullable=False, unique=True, index=True, comment="API Key 明文")
|
||||||
key_hash = Column(String(255), nullable=False, unique=True, index=True, comment="Key 哈希值")
|
|
||||||
|
|
||||||
# 类型和权限
|
# 类型和权限
|
||||||
type = Column(String(50), nullable=False, index=True, comment="API Key 类型")
|
type = Column(String(50), nullable=False, index=True, comment="API Key 类型")
|
||||||
scopes = Column(JSONB, nullable=False, default=list, comment="权限范围列表")
|
scopes = Column(JSONB, default=list, comment="权限范围列表")
|
||||||
|
|
||||||
# 关联资源
|
# 关联资源
|
||||||
workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="CASCADE"), nullable=False,
|
workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="CASCADE"), nullable=False,
|
||||||
index=True, comment="所属工作空间")
|
index=True, comment="所属工作空间")
|
||||||
resource_id = Column(UUID(as_uuid=True), index=True, comment="关联资源ID")
|
resource_id = Column(UUID(as_uuid=True), index=True, comment="关联资源ID")
|
||||||
resource_type = Column(String(50), comment="资源类型")
|
|
||||||
|
|
||||||
# 限制和配额
|
# 限制和配额
|
||||||
rate_limit = Column(Integer, default=10, comment="QPS限制(请求/秒)")
|
rate_limit = Column(Integer, default=10, comment="QPS限制(请求/秒)")
|
||||||
|
|||||||
@@ -87,6 +87,14 @@ class App(Base):
|
|||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 一对一:工作流配置(仅当 type=workflow 时有效)
|
||||||
|
workflow_config = relationship(
|
||||||
|
"WorkflowConfig",
|
||||||
|
back_populates="app",
|
||||||
|
uselist=False,
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
)
|
||||||
|
|
||||||
# 发布版本关联
|
# 发布版本关联
|
||||||
current_release = relationship("AppRelease", foreign_keys=[current_release_id])
|
current_release = relationship("AppRelease", foreign_keys=[current_release_id])
|
||||||
# 指定外键以避免与 current_release_id 造成歧义
|
# 指定外键以避免与 current_release_id 造成歧义
|
||||||
|
|||||||
196
api/app/models/workflow_model.py
Normal file
196
api/app/models/workflow_model.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
"""
|
||||||
|
工作流相关数据模型
|
||||||
|
"""
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import uuid
|
||||||
|
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float, ForeignKey, Text
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
from app.db import Base
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowConfig(Base):
|
||||||
|
"""工作流配置表"""
|
||||||
|
__tablename__ = "workflow_configs"
|
||||||
|
|
||||||
|
# 主键
|
||||||
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||||
|
|
||||||
|
# 关联应用(一对一)
|
||||||
|
app_id = Column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("apps.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
unique=True,
|
||||||
|
index=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 节点和边的定义(JSON 格式)
|
||||||
|
nodes = Column(JSONB, nullable=False, default=list)
|
||||||
|
edges = Column(JSONB, nullable=False, default=list)
|
||||||
|
|
||||||
|
# 全局变量定义
|
||||||
|
variables = Column(JSONB, default=list)
|
||||||
|
|
||||||
|
# 执行配置
|
||||||
|
execution_config = Column(JSONB, nullable=False, default=dict)
|
||||||
|
|
||||||
|
# 触发器配置(可选)
|
||||||
|
triggers = Column(JSONB, default=list)
|
||||||
|
|
||||||
|
# 状态
|
||||||
|
is_active = Column(Boolean, nullable=False, default=True)
|
||||||
|
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||||
|
updated_at = Column(
|
||||||
|
DateTime,
|
||||||
|
nullable=False,
|
||||||
|
default=datetime.datetime.now,
|
||||||
|
onupdate=datetime.datetime.now
|
||||||
|
)
|
||||||
|
|
||||||
|
# 关系
|
||||||
|
app = relationship("App", back_populates="workflow_config")
|
||||||
|
executions = relationship(
|
||||||
|
"WorkflowExecution",
|
||||||
|
back_populates="workflow_config",
|
||||||
|
cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<WorkflowConfig(id={self.id}, app_id={self.app_id})>"
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowExecution(Base):
|
||||||
|
"""工作流执行记录表"""
|
||||||
|
__tablename__ = "workflow_executions"
|
||||||
|
|
||||||
|
# 主键
|
||||||
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||||
|
|
||||||
|
# 关联信息
|
||||||
|
workflow_config_id = Column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("workflow_configs.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
index=True
|
||||||
|
)
|
||||||
|
app_id = Column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("apps.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
index=True
|
||||||
|
)
|
||||||
|
conversation_id = Column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("conversations.id", ondelete="SET NULL"),
|
||||||
|
nullable=True,
|
||||||
|
index=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 执行信息
|
||||||
|
execution_id = Column(String(100), nullable=False, unique=True, index=True)
|
||||||
|
trigger_type = Column(String(20), nullable=False) # manual, schedule, webhook, event
|
||||||
|
triggered_by = Column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("users.id"),
|
||||||
|
nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 输入输出
|
||||||
|
input_data = Column(JSONB)
|
||||||
|
output_data = Column(JSONB)
|
||||||
|
context = Column(JSONB, default=dict)
|
||||||
|
|
||||||
|
# 状态
|
||||||
|
status = Column(String(20), nullable=False, default="pending", index=True)
|
||||||
|
# 可选值:pending, running, completed, failed, cancelled, timeout
|
||||||
|
|
||||||
|
error_message = Column(Text)
|
||||||
|
error_node_id = Column(String(100))
|
||||||
|
|
||||||
|
# 性能指标
|
||||||
|
started_at = Column(DateTime, nullable=False, default=datetime.datetime.now, index=True)
|
||||||
|
completed_at = Column(DateTime)
|
||||||
|
elapsed_time = Column(Float) # 耗时(秒)
|
||||||
|
|
||||||
|
# 资源使用
|
||||||
|
token_usage = Column(JSONB)
|
||||||
|
|
||||||
|
# 元数据(使用 meta_data 避免与 SQLAlchemy 保留字 metadata 冲突)
|
||||||
|
meta_data = Column(JSONB, default=dict)
|
||||||
|
|
||||||
|
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||||
|
|
||||||
|
# 关系
|
||||||
|
workflow_config = relationship("WorkflowConfig", back_populates="executions")
|
||||||
|
app = relationship("App")
|
||||||
|
conversation = relationship("Conversation")
|
||||||
|
triggered_by_user = relationship("User", foreign_keys=[triggered_by])
|
||||||
|
node_executions = relationship(
|
||||||
|
"WorkflowNodeExecution",
|
||||||
|
back_populates="execution",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
order_by="WorkflowNodeExecution.execution_order"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<WorkflowExecution(id={self.id}, execution_id={self.execution_id}, status={self.status})>"
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowNodeExecution(Base):
|
||||||
|
"""工作流节点执行记录表"""
|
||||||
|
__tablename__ = "workflow_node_executions"
|
||||||
|
|
||||||
|
# 主键
|
||||||
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||||
|
|
||||||
|
# 关联执行
|
||||||
|
execution_id = Column(
|
||||||
|
UUID(as_uuid=True),
|
||||||
|
ForeignKey("workflow_executions.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
index=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 节点信息
|
||||||
|
node_id = Column(String(100), nullable=False, index=True)
|
||||||
|
node_type = Column(String(20), nullable=False)
|
||||||
|
node_name = Column(String(100))
|
||||||
|
|
||||||
|
# 执行顺序
|
||||||
|
execution_order = Column(Integer, nullable=False)
|
||||||
|
retry_count = Column(Integer, nullable=False, default=0)
|
||||||
|
|
||||||
|
# 输入输出
|
||||||
|
input_data = Column(JSONB)
|
||||||
|
output_data = Column(JSONB)
|
||||||
|
|
||||||
|
# 状态
|
||||||
|
status = Column(String(20), nullable=False, default="pending", index=True)
|
||||||
|
# 可选值:pending, running, completed, failed, skipped, cached
|
||||||
|
|
||||||
|
error_message = Column(Text)
|
||||||
|
|
||||||
|
# 性能指标
|
||||||
|
started_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||||
|
completed_at = Column(DateTime)
|
||||||
|
elapsed_time = Column(Float) # 耗时(秒)
|
||||||
|
|
||||||
|
# 资源使用(针对 LLM 节点)
|
||||||
|
token_usage = Column(JSONB)
|
||||||
|
|
||||||
|
# 缓存信息
|
||||||
|
cache_hit = Column(Boolean, default=False)
|
||||||
|
cache_key = Column(String(255))
|
||||||
|
|
||||||
|
# 元数据(使用 meta_data 避免与 SQLAlchemy 保留字 metadata 冲突)
|
||||||
|
meta_data = Column(JSONB, default=dict)
|
||||||
|
|
||||||
|
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||||
|
|
||||||
|
# 关系
|
||||||
|
execution = relationship("WorkflowExecution", back_populates="node_executions")
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<WorkflowNodeExecution(id={self.id}, node_id={self.node_id}, status={self.status})>"
|
||||||
@@ -27,9 +27,9 @@ class ApiKeyRepository:
|
|||||||
return db.get(ApiKey, api_key_id)
|
return db.get(ApiKey, api_key_id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_by_hash(db: Session, key_hash: str) -> Optional[ApiKey]:
|
def get_by_api_key(db: Session, api_key: str) -> Optional[ApiKey]:
|
||||||
"""根据哈希值获取 API Key"""
|
"""根据 API Key 获取 API Key"""
|
||||||
stmt = select(ApiKey).where(ApiKey.key_hash == key_hash)
|
stmt = select(ApiKey).where(ApiKey.api_key == api_key)
|
||||||
return db.scalars(stmt).first()
|
return db.scalars(stmt).first()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -63,11 +63,15 @@ class ApiKeyRepository:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def update(db: Session, api_key_id: uuid.UUID, update_data: dict) -> ApiKey | None:
|
def update(db: Session, api_key_id: uuid.UUID, update_data: dict) -> ApiKey | None:
|
||||||
"""更新 API Key"""
|
"""更新 API Key"""
|
||||||
|
allow_none_fields = {"description", "quota_limit", "expires_at"}
|
||||||
api_key = db.get(ApiKey, api_key_id)
|
api_key = db.get(ApiKey, api_key_id)
|
||||||
if api_key:
|
if api_key:
|
||||||
for key, value in update_data.items():
|
for key, value in update_data.items():
|
||||||
if value is not None:
|
if key in allow_none_fields:
|
||||||
setattr(api_key, key, value)
|
setattr(api_key, key, value)
|
||||||
|
else:
|
||||||
|
if value is not None:
|
||||||
|
setattr(api_key, key, value)
|
||||||
db.flush()
|
db.flush()
|
||||||
return api_key
|
return api_key
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ class AppRepository:
|
|||||||
def __init__(self, db: Session):
|
def __init__(self, db: Session):
|
||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
def get_apps_by_workspace_id(self, workspace_id: uuid.UUID) -> List[App]:
|
def get_apps_by_workspace_id(self, workspace_id: uuid.UUID) -> list[App]:
|
||||||
"""根据工作空间ID查询应用"""
|
"""根据工作空间ID查询应用"""
|
||||||
try:
|
try:
|
||||||
apps = self.db.query(App).filter(App.workspace_id == workspace_id).all()
|
apps = self.db.query(App).filter(App.workspace_id == workspace_id).all()
|
||||||
@@ -24,7 +24,19 @@ class AppRepository:
|
|||||||
db_logger.error(f"查询工作空间 {workspace_id} 下应用时出错: {str(e)}")
|
db_logger.error(f"查询工作空间 {workspace_id} 下应用时出错: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def get_apps_by_id(self, app_id: uuid.UUID) -> App:
|
||||||
|
try:
|
||||||
|
app = self.db.query(App).filter(App.id == app_id, App.is_active == True).first()
|
||||||
|
return app
|
||||||
|
except Exception as e:
|
||||||
|
raise
|
||||||
|
|
||||||
def get_apps_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> List[App]:
|
def get_apps_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> List[App]:
|
||||||
"""根据工作空间ID查询应用"""
|
"""根据工作空间ID查询应用"""
|
||||||
repo = AppRepository(db)
|
repo = AppRepository(db)
|
||||||
return repo.get_apps_by_workspace_id(workspace_id)
|
return repo.get_apps_by_workspace_id(workspace_id)
|
||||||
|
|
||||||
|
def get_apps_by_id(db: Session, app_id: uuid.UUID) -> App:
|
||||||
|
"""根据工作空间ID查询应用"""
|
||||||
|
repo = AppRepository(db)
|
||||||
|
return repo.get_apps_by_id(app_id)
|
||||||
|
|||||||
247
api/app/repositories/workflow_repository.py
Normal file
247
api/app/repositories/workflow_repository.py
Normal file
@@ -0,0 +1,247 @@
|
|||||||
|
"""
|
||||||
|
工作流数据访问层
|
||||||
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Annotated
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy import desc
|
||||||
|
from fastapi import Depends
|
||||||
|
|
||||||
|
from app.models.workflow_model import (
|
||||||
|
WorkflowConfig,
|
||||||
|
WorkflowExecution,
|
||||||
|
WorkflowNodeExecution
|
||||||
|
)
|
||||||
|
from app.db import get_db
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowConfigRepository:
|
||||||
|
"""工作流配置仓储"""
|
||||||
|
|
||||||
|
def __init__(self, db: Session):
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
def get_by_app_id(self, app_id: uuid.UUID) -> WorkflowConfig | None:
|
||||||
|
"""根据应用 ID 获取工作流配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工作流配置或 None
|
||||||
|
"""
|
||||||
|
return self.db.query(WorkflowConfig).filter(
|
||||||
|
WorkflowConfig.app_id == app_id,
|
||||||
|
WorkflowConfig.is_active == True
|
||||||
|
).first()
|
||||||
|
|
||||||
|
def create_or_update(
|
||||||
|
self,
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
nodes: list[dict[str, Any]],
|
||||||
|
edges: list[dict[str, Any]],
|
||||||
|
variables: list[dict[str, Any]] | None = None,
|
||||||
|
execution_config: dict[str, Any] | None = None,
|
||||||
|
triggers: list[dict[str, Any]] | None = None
|
||||||
|
) -> WorkflowConfig:
|
||||||
|
"""创建或更新工作流配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用 ID
|
||||||
|
nodes: 节点列表
|
||||||
|
edges: 边列表
|
||||||
|
variables: 变量列表
|
||||||
|
execution_config: 执行配置
|
||||||
|
triggers: 触发器列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工作流配置
|
||||||
|
"""
|
||||||
|
# 查找现有配置
|
||||||
|
existing = self.get_by_app_id(app_id)
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
# 更新现有配置
|
||||||
|
existing.nodes = nodes
|
||||||
|
existing.edges = edges
|
||||||
|
if variables is not None:
|
||||||
|
existing.variables = variables
|
||||||
|
if execution_config is not None:
|
||||||
|
existing.execution_config = execution_config
|
||||||
|
if triggers is not None:
|
||||||
|
existing.triggers = triggers
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(existing)
|
||||||
|
return existing
|
||||||
|
else:
|
||||||
|
# 创建新配置
|
||||||
|
config = WorkflowConfig(
|
||||||
|
app_id=app_id,
|
||||||
|
nodes=nodes,
|
||||||
|
edges=edges,
|
||||||
|
variables=variables or [],
|
||||||
|
execution_config=execution_config or {},
|
||||||
|
triggers=triggers or []
|
||||||
|
)
|
||||||
|
self.db.add(config)
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(config)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowExecutionRepository:
|
||||||
|
"""工作流执行记录仓储"""
|
||||||
|
|
||||||
|
def __init__(self, db: Session):
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
def get_by_execution_id(self, execution_id: str) -> WorkflowExecution | None:
|
||||||
|
"""根据执行 ID 获取执行记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
execution_id: 执行 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行记录或 None
|
||||||
|
"""
|
||||||
|
return self.db.query(WorkflowExecution).filter(
|
||||||
|
WorkflowExecution.execution_id == execution_id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
def get_by_app_id(
|
||||||
|
self,
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0
|
||||||
|
) -> list[WorkflowExecution]:
|
||||||
|
"""根据应用 ID 获取执行记录列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用 ID
|
||||||
|
limit: 返回数量限制
|
||||||
|
offset: 偏移量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行记录列表
|
||||||
|
"""
|
||||||
|
return self.db.query(WorkflowExecution).filter(
|
||||||
|
WorkflowExecution.app_id == app_id
|
||||||
|
).order_by(
|
||||||
|
desc(WorkflowExecution.started_at)
|
||||||
|
).limit(limit).offset(offset).all()
|
||||||
|
|
||||||
|
def get_by_conversation_id(
|
||||||
|
self,
|
||||||
|
conversation_id: uuid.UUID
|
||||||
|
) -> list[WorkflowExecution]:
|
||||||
|
"""根据会话 ID 获取执行记录列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation_id: 会话 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行记录列表
|
||||||
|
"""
|
||||||
|
return self.db.query(WorkflowExecution).filter(
|
||||||
|
WorkflowExecution.conversation_id == conversation_id
|
||||||
|
).order_by(
|
||||||
|
desc(WorkflowExecution.started_at)
|
||||||
|
).all()
|
||||||
|
|
||||||
|
def count_by_app_id(self, app_id: uuid.UUID) -> int:
|
||||||
|
"""统计应用的执行次数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行次数
|
||||||
|
"""
|
||||||
|
return self.db.query(WorkflowExecution).filter(
|
||||||
|
WorkflowExecution.app_id == app_id
|
||||||
|
).count()
|
||||||
|
|
||||||
|
def count_by_status(self, app_id: uuid.UUID, status: str) -> int:
|
||||||
|
"""统计指定状态的执行次数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用 ID
|
||||||
|
status: 状态
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行次数
|
||||||
|
"""
|
||||||
|
return self.db.query(WorkflowExecution).filter(
|
||||||
|
WorkflowExecution.app_id == app_id,
|
||||||
|
WorkflowExecution.status == status
|
||||||
|
).count()
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowNodeExecutionRepository:
|
||||||
|
"""工作流节点执行记录仓储"""
|
||||||
|
|
||||||
|
def __init__(self, db: Session):
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
def get_by_execution_id(
|
||||||
|
self,
|
||||||
|
execution_id: uuid.UUID
|
||||||
|
) -> list[WorkflowNodeExecution]:
|
||||||
|
"""根据执行 ID 获取节点执行记录列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
execution_id: 执行 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
节点执行记录列表(按执行顺序排序)
|
||||||
|
"""
|
||||||
|
return self.db.query(WorkflowNodeExecution).filter(
|
||||||
|
WorkflowNodeExecution.execution_id == execution_id
|
||||||
|
).order_by(
|
||||||
|
WorkflowNodeExecution.execution_order
|
||||||
|
).all()
|
||||||
|
|
||||||
|
def get_by_node_id(
|
||||||
|
self,
|
||||||
|
execution_id: uuid.UUID,
|
||||||
|
node_id: str
|
||||||
|
) -> list[WorkflowNodeExecution]:
|
||||||
|
"""根据节点 ID 获取节点执行记录(可能有多次重试)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
execution_id: 执行 ID
|
||||||
|
node_id: 节点 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
节点执行记录列表
|
||||||
|
"""
|
||||||
|
return self.db.query(WorkflowNodeExecution).filter(
|
||||||
|
WorkflowNodeExecution.execution_id == execution_id,
|
||||||
|
WorkflowNodeExecution.node_id == node_id
|
||||||
|
).order_by(
|
||||||
|
WorkflowNodeExecution.retry_count
|
||||||
|
).all()
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 依赖注入函数 ====================
|
||||||
|
|
||||||
|
def get_workflow_config_repository(
|
||||||
|
db: Annotated[Session, Depends(get_db)]
|
||||||
|
) -> WorkflowConfigRepository:
|
||||||
|
"""获取工作流配置仓储(依赖注入)"""
|
||||||
|
return WorkflowConfigRepository(db)
|
||||||
|
|
||||||
|
|
||||||
|
def get_workflow_execution_repository(
|
||||||
|
db: Annotated[Session, Depends(get_db)]
|
||||||
|
) -> WorkflowExecutionRepository:
|
||||||
|
"""获取工作流执行记录仓储(依赖注入)"""
|
||||||
|
return WorkflowExecutionRepository(db)
|
||||||
|
|
||||||
|
|
||||||
|
def get_workflow_node_execution_repository(
|
||||||
|
db: Annotated[Session, Depends(get_db)]
|
||||||
|
) -> WorkflowNodeExecutionRepository:
|
||||||
|
"""获取工作流节点执行记录仓储(依赖注入)"""
|
||||||
|
return WorkflowNodeExecutionRepository(db)
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
"""API Key Schema"""
|
"""API Key Schema"""
|
||||||
import datetime
|
import datetime
|
||||||
import uuid
|
import uuid
|
||||||
from pydantic import BaseModel, Field, ConfigDict
|
from pydantic import BaseModel, Field, ConfigDict, field_validator, field_serializer, computed_field
|
||||||
from pydantic.v1 import validator
|
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
from app.models.api_key_model import ApiKeyType, ResourceType
|
from app.models.api_key_model import ApiKeyType
|
||||||
|
from app.core.api_key_utils import timestamp_to_datetime, datetime_to_timestamp
|
||||||
|
|
||||||
|
|
||||||
class ApiKeyCreate(BaseModel):
|
class ApiKeyCreate(BaseModel):
|
||||||
@@ -15,20 +15,34 @@ class ApiKeyCreate(BaseModel):
|
|||||||
type: ApiKeyType = Field(..., description="API Key 类型")
|
type: ApiKeyType = Field(..., description="API Key 类型")
|
||||||
scopes: List[str] = Field(default_factory=list, description="权限范围列表")
|
scopes: List[str] = Field(default_factory=list, description="权限范围列表")
|
||||||
resource_id: Optional[uuid.UUID] = Field(None, description="关联资源ID")
|
resource_id: Optional[uuid.UUID] = Field(None, description="关联资源ID")
|
||||||
resource_type: Optional[ResourceType] = Field(None, description="资源类型")
|
|
||||||
rate_limit: Optional[int] = Field(10, ge=1, le=1000, description="QPS限制(请求/秒)")
|
rate_limit: Optional[int] = Field(10, ge=1, le=1000, description="QPS限制(请求/秒)")
|
||||||
daily_request_limit: Optional[int] = Field(10000, description="日请求限制", ge=1)
|
daily_request_limit: Optional[int] = Field(10000, description="日请求限制", ge=1)
|
||||||
quota_limit: Optional[int] = Field(None, description="配额限制(总请求数)", ge=1)
|
quota_limit: Optional[int] = Field(None, description="配额限制(总请求数)", ge=1)
|
||||||
expires_at: Optional[datetime.datetime] = Field(None, description="过期时间")
|
expires_at: Optional[datetime.datetime] = Field(None, description="过期时间")
|
||||||
|
|
||||||
@validator('scopes')
|
@computed_field
|
||||||
|
@property
|
||||||
|
def is_expired(self) -> bool:
|
||||||
|
"""检查API Key是否已过期"""
|
||||||
|
if not self.expires_at:
|
||||||
|
return False
|
||||||
|
return datetime.datetime.now() > self.expires_at
|
||||||
|
|
||||||
|
@field_validator('expires_at', mode='before')
|
||||||
|
@classmethod
|
||||||
|
def parse_expires_at(cls, v):
|
||||||
|
"""将时间戳转换为datetime"""
|
||||||
|
if isinstance(v, (int, float)):
|
||||||
|
return timestamp_to_datetime(v)
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator('scopes')
|
||||||
|
@classmethod
|
||||||
def validate_scopes(cls, v):
|
def validate_scopes(cls, v):
|
||||||
"""验证权限范围格式"""
|
"""验证权限范围格式"""
|
||||||
valid_scopes = [
|
if v is None:
|
||||||
"app:all",
|
return []
|
||||||
"rag:search", "rag:upload", "rag:delete",
|
valid_scopes = ["app", "rag", "memory"]
|
||||||
"memory:read", "memory:write", "memory:delete", "memory:search"
|
|
||||||
]
|
|
||||||
for scope in v:
|
for scope in v:
|
||||||
if scope not in valid_scopes:
|
if scope not in valid_scopes:
|
||||||
raise ValueError(f"无效范围: {scope}")
|
raise ValueError(f"无效范围: {scope}")
|
||||||
@@ -46,14 +60,29 @@ class ApiKeyUpdate(BaseModel):
|
|||||||
is_active: Optional[bool] = Field(None, description="是否激活")
|
is_active: Optional[bool] = Field(None, description="是否激活")
|
||||||
expires_at: Optional[datetime.datetime] = Field(None, description="过期时间")
|
expires_at: Optional[datetime.datetime] = Field(None, description="过期时间")
|
||||||
|
|
||||||
@validator('scopes')
|
@computed_field
|
||||||
|
@property
|
||||||
|
def is_expired(self) -> bool:
|
||||||
|
"""检查API Key是否已过期"""
|
||||||
|
if not self.expires_at:
|
||||||
|
return False
|
||||||
|
return datetime.datetime.now() > self.expires_at
|
||||||
|
|
||||||
|
@field_validator('expires_at', mode='before')
|
||||||
|
@classmethod
|
||||||
|
def parse_expires_at(cls, v):
|
||||||
|
"""将时间戳转换为datetime"""
|
||||||
|
if isinstance(v, (int, float)):
|
||||||
|
return timestamp_to_datetime(v)
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator('scopes')
|
||||||
|
@classmethod
|
||||||
def validate_scopes(cls, v):
|
def validate_scopes(cls, v):
|
||||||
"""验证权限范围格式"""
|
"""验证权限范围格式"""
|
||||||
valid_scopes = {
|
if v is None:
|
||||||
'app:all',
|
return v
|
||||||
'rag:search', 'rag:upload', 'rag:delete',
|
valid_scopes = ["app", "rag", "memory"]
|
||||||
'memory:read', 'memory:write', 'memory:delete', 'memory:search'
|
|
||||||
}
|
|
||||||
for scope in v:
|
for scope in v:
|
||||||
if scope not in valid_scopes:
|
if scope not in valid_scopes:
|
||||||
raise ValueError(f"无效范围: {scope}")
|
raise ValueError(f"无效范围: {scope}")
|
||||||
@@ -67,18 +96,31 @@ class ApiKeyResponse(BaseModel):
|
|||||||
id: uuid.UUID
|
id: uuid.UUID
|
||||||
name: str
|
name: str
|
||||||
description: Optional[str]
|
description: Optional[str]
|
||||||
api_key: str = Field(..., description="API Key 明文(仅创建时返回)")
|
api_key: str
|
||||||
key_prefix: str
|
|
||||||
type: str
|
type: str
|
||||||
scopes: List[str]
|
scopes: List[str]
|
||||||
resource_id: Optional[uuid.UUID]
|
resource_id: Optional[uuid.UUID]
|
||||||
resource_type: Optional[str]
|
|
||||||
rate_limit: int
|
rate_limit: int
|
||||||
daily_request_limit: int
|
daily_request_limit: int
|
||||||
quota_limit: Optional[int]
|
quota_limit: Optional[int]
|
||||||
|
is_active: bool
|
||||||
expires_at: Optional[datetime.datetime]
|
expires_at: Optional[datetime.datetime]
|
||||||
created_at: datetime.datetime
|
created_at: datetime.datetime
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def is_expired(self) -> bool:
|
||||||
|
"""检查API Key是否已过期"""
|
||||||
|
if not self.expires_at:
|
||||||
|
return False
|
||||||
|
return datetime.datetime.now() > self.expires_at
|
||||||
|
|
||||||
|
@field_serializer('expires_at', 'created_at')
|
||||||
|
@classmethod
|
||||||
|
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
|
||||||
|
"""将datetime转换为时间戳"""
|
||||||
|
return datetime_to_timestamp(v)
|
||||||
|
|
||||||
|
|
||||||
class ApiKey(BaseModel):
|
class ApiKey(BaseModel):
|
||||||
"""API Key 信息(不包含明文 Key)"""
|
"""API Key 信息(不包含明文 Key)"""
|
||||||
@@ -87,11 +129,10 @@ class ApiKey(BaseModel):
|
|||||||
id: uuid.UUID
|
id: uuid.UUID
|
||||||
name: str
|
name: str
|
||||||
description: Optional[str]
|
description: Optional[str]
|
||||||
key_prefix: str
|
api_key: str
|
||||||
type: str
|
type: str
|
||||||
scopes: List[str]
|
scopes: List[str]
|
||||||
resource_id: Optional[uuid.UUID]
|
resource_id: Optional[uuid.UUID]
|
||||||
resource_type: Optional[str]
|
|
||||||
rate_limit: int
|
rate_limit: int
|
||||||
daily_request_limit: int
|
daily_request_limit: int
|
||||||
quota_limit: Optional[int]
|
quota_limit: Optional[int]
|
||||||
@@ -105,6 +146,20 @@ class ApiKey(BaseModel):
|
|||||||
created_at: datetime.datetime
|
created_at: datetime.datetime
|
||||||
updated_at: datetime.datetime
|
updated_at: datetime.datetime
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def is_expired(self) -> bool:
|
||||||
|
"""检查API Key是否已过期"""
|
||||||
|
if not self.expires_at:
|
||||||
|
return False
|
||||||
|
return datetime.datetime.now() > self.expires_at
|
||||||
|
|
||||||
|
@field_serializer('expires_at', 'last_used_at', 'created_at', 'updated_at')
|
||||||
|
@classmethod
|
||||||
|
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
|
||||||
|
"""将datetime转换为时间戳"""
|
||||||
|
return datetime_to_timestamp(v)
|
||||||
|
|
||||||
|
|
||||||
class ApiKeyStats(BaseModel):
|
class ApiKeyStats(BaseModel):
|
||||||
"""API Key 使用统计"""
|
"""API Key 使用统计"""
|
||||||
@@ -115,6 +170,12 @@ class ApiKeyStats(BaseModel):
|
|||||||
last_used_at: Optional[datetime.datetime] = Field(None, description="最后使用时间")
|
last_used_at: Optional[datetime.datetime] = Field(None, description="最后使用时间")
|
||||||
avg_response_time: Optional[float] = Field(None, description="平均响应时间(毫秒)")
|
avg_response_time: Optional[float] = Field(None, description="平均响应时间(毫秒)")
|
||||||
|
|
||||||
|
@field_serializer('last_used_at')
|
||||||
|
@classmethod
|
||||||
|
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
|
||||||
|
"""将datetime转换为时间戳"""
|
||||||
|
return datetime_to_timestamp(v)
|
||||||
|
|
||||||
|
|
||||||
class ApiKeyQuery(BaseModel):
|
class ApiKeyQuery(BaseModel):
|
||||||
"""API Key 查询参数"""
|
"""API Key 查询参数"""
|
||||||
@@ -132,7 +193,6 @@ class ApiKeyAuth(BaseModel):
|
|||||||
type: str
|
type: str
|
||||||
scopes: List[str]
|
scopes: List[str]
|
||||||
resource_id: Optional[uuid.UUID]
|
resource_id: Optional[uuid.UUID]
|
||||||
resource_type: Optional[str]
|
|
||||||
|
|
||||||
|
|
||||||
class ApiKeyLog(BaseModel):
|
class ApiKeyLog(BaseModel):
|
||||||
@@ -157,3 +217,9 @@ class ApiKeyLog(BaseModel):
|
|||||||
|
|
||||||
# 时间信息
|
# 时间信息
|
||||||
created_at: datetime.datetime
|
created_at: datetime.datetime
|
||||||
|
|
||||||
|
@field_serializer('created_at')
|
||||||
|
@classmethod
|
||||||
|
def serialize_datetime(cls, v: datetime.datetime) -> int:
|
||||||
|
"""将datetime转换为时间戳"""
|
||||||
|
return datetime_to_timestamp(v)
|
||||||
|
|||||||
215
api/app/schemas/workflow_schema.py
Normal file
215
api/app/schemas/workflow_schema.py
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
"""
|
||||||
|
工作流相关的 Pydantic Schema
|
||||||
|
"""
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
from pydantic import BaseModel, Field, ConfigDict, field_serializer
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 节点和边定义 ====================
|
||||||
|
|
||||||
|
class NodeConfig(BaseModel):
|
||||||
|
"""节点配置"""
|
||||||
|
model_config = ConfigDict(extra="allow") # 允许额外字段
|
||||||
|
|
||||||
|
|
||||||
|
class NodeDefinition(BaseModel):
|
||||||
|
"""节点定义"""
|
||||||
|
id: str = Field(..., description="节点唯一标识")
|
||||||
|
type: str = Field(..., description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code")
|
||||||
|
name: str | None = Field(None, description="节点名称")
|
||||||
|
description: str | None = Field(None, description="节点描述")
|
||||||
|
config: dict[str, Any] = Field(default_factory=dict, description="节点配置")
|
||||||
|
position: dict[str, float] | None = Field(None, description="节点位置 {x, y}")
|
||||||
|
error_handling: dict[str, Any] | None = Field(None, description="错误处理配置")
|
||||||
|
cache: dict[str, Any] | None = Field(None, description="缓存配置")
|
||||||
|
|
||||||
|
|
||||||
|
class EdgeDefinition(BaseModel):
|
||||||
|
"""边定义"""
|
||||||
|
id: str | None = Field(None, description="边唯一标识(可选)")
|
||||||
|
source: str = Field(..., description="源节点 ID")
|
||||||
|
target: str = Field(..., description="目标节点 ID")
|
||||||
|
type: str | None = Field(None, description="边类型: normal, error")
|
||||||
|
condition: str | None = Field(None, description="条件表达式(条件边)")
|
||||||
|
label: str | None = Field(None, description="边标签")
|
||||||
|
|
||||||
|
|
||||||
|
class VariableDefinition(BaseModel):
|
||||||
|
"""变量定义"""
|
||||||
|
name: str = Field(..., description="变量名称")
|
||||||
|
type: str = Field(default="string", description="变量类型: string, number, boolean, object, array")
|
||||||
|
required: bool = Field(default=False, description="是否必填")
|
||||||
|
default: Any = Field(None, description="默认值")
|
||||||
|
description: str | None = Field(None, description="变量描述")
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutionConfig(BaseModel):
|
||||||
|
"""执行配置"""
|
||||||
|
max_iterations: int = Field(default=100, ge=1, le=1000, description="最大迭代次数")
|
||||||
|
timeout: int = Field(default=600, ge=10, le=3600, description="全局超时时间(秒)")
|
||||||
|
enable_cache: bool = Field(default=True, description="是否启用节点缓存")
|
||||||
|
parallel_limit: int = Field(default=5, ge=1, le=20, description="并行执行限制")
|
||||||
|
|
||||||
|
|
||||||
|
class TriggerConfig(BaseModel):
|
||||||
|
"""触发器配置"""
|
||||||
|
type: str = Field(..., description="触发器类型: schedule, webhook, event")
|
||||||
|
config: dict[str, Any] = Field(default_factory=dict, description="触发器配置")
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 工作流配置 ====================
|
||||||
|
|
||||||
|
class WorkflowConfigCreate(BaseModel):
|
||||||
|
"""创建工作流配置"""
|
||||||
|
nodes: list[NodeDefinition] = Field(default_factory=list, description="节点列表")
|
||||||
|
edges: list[EdgeDefinition] = Field(default_factory=list, description="边列表")
|
||||||
|
variables: list[VariableDefinition] = Field(default_factory=list, description="变量列表")
|
||||||
|
execution_config: ExecutionConfig = Field(default_factory=ExecutionConfig, description="执行配置")
|
||||||
|
triggers: list[TriggerConfig] = Field(default_factory=list, description="触发器列表")
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowConfigUpdate(BaseModel):
|
||||||
|
"""更新工作流配置"""
|
||||||
|
nodes: list[NodeDefinition] | None = None
|
||||||
|
edges: list[EdgeDefinition] | None = None
|
||||||
|
variables: list[VariableDefinition] | None = None
|
||||||
|
execution_config: ExecutionConfig | None = None
|
||||||
|
triggers: list[TriggerConfig] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowConfig(BaseModel):
|
||||||
|
"""工作流配置输出"""
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
id: uuid.UUID
|
||||||
|
app_id: uuid.UUID
|
||||||
|
nodes: list[dict[str, Any]]
|
||||||
|
edges: list[dict[str, Any]]
|
||||||
|
variables: list[dict[str, Any]]
|
||||||
|
execution_config: dict[str, Any]
|
||||||
|
triggers: list[dict[str, Any]]
|
||||||
|
is_active: bool
|
||||||
|
created_at: datetime.datetime
|
||||||
|
updated_at: datetime.datetime
|
||||||
|
|
||||||
|
@field_serializer("created_at", when_used="json")
|
||||||
|
def _serialize_created_at(self, dt: datetime.datetime):
|
||||||
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
@field_serializer("updated_at", when_used="json")
|
||||||
|
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||||
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 工作流执行 ====================
|
||||||
|
|
||||||
|
class WorkflowExecutionRequest(BaseModel):
|
||||||
|
"""工作流执行请求"""
|
||||||
|
message: str | None = Field(None, description="用户消息(可选)")
|
||||||
|
variables: dict[str, Any] = Field(default_factory=dict, description="输入变量")
|
||||||
|
conversation_id: str | None = Field(None, description="会话 ID(用于关联对话)")
|
||||||
|
stream: bool = Field(default=False, description="是否流式返回")
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowExecutionResponse(BaseModel):
|
||||||
|
"""工作流执行响应(非流式)"""
|
||||||
|
execution_id: str = Field(..., description="执行 ID")
|
||||||
|
status: str = Field(..., description="执行状态")
|
||||||
|
output: str | None = Field(None, description="最终输出(字符串,便于快速访问)")
|
||||||
|
output_data: dict[str, Any] | None = Field(None, description="所有节点的详细输出数据")
|
||||||
|
error_message: str | None = Field(None, description="错误信息")
|
||||||
|
elapsed_time: float | None = Field(None, description="耗时(秒)")
|
||||||
|
token_usage: dict[str, Any] | None = Field(None, description="Token 使用情况 {prompt_tokens, completion_tokens, total_tokens}")
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowExecutionStreamChunk(BaseModel):
|
||||||
|
"""工作流执行流式响应块"""
|
||||||
|
type: str = Field(..., description="事件类型: node_start, token, node_complete, error_redirect, workflow_complete")
|
||||||
|
execution_id: str = Field(..., description="执行 ID")
|
||||||
|
data: dict[str, Any] = Field(default_factory=dict, description="事件数据")
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowExecution(BaseModel):
|
||||||
|
"""工作流执行记录输出"""
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
id: uuid.UUID
|
||||||
|
workflow_config_id: uuid.UUID
|
||||||
|
app_id: uuid.UUID
|
||||||
|
conversation_id: uuid.UUID | None
|
||||||
|
execution_id: str
|
||||||
|
trigger_type: str
|
||||||
|
triggered_by: uuid.UUID | None
|
||||||
|
input_data: dict[str, Any] | None
|
||||||
|
output_data: dict[str, Any] | None
|
||||||
|
context: dict[str, Any]
|
||||||
|
status: str
|
||||||
|
error_message: str | None
|
||||||
|
error_node_id: str | None
|
||||||
|
started_at: datetime.datetime
|
||||||
|
completed_at: datetime.datetime | None
|
||||||
|
elapsed_time: float | None
|
||||||
|
token_usage: dict[str, Any] | None
|
||||||
|
meta_data: dict[str, Any]
|
||||||
|
created_at: datetime.datetime
|
||||||
|
|
||||||
|
@field_serializer("started_at", when_used="json")
|
||||||
|
def _serialize_started_at(self, dt: datetime.datetime):
|
||||||
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
@field_serializer("completed_at", when_used="json")
|
||||||
|
def _serialize_completed_at(self, dt: datetime.datetime | None):
|
||||||
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
@field_serializer("created_at", when_used="json")
|
||||||
|
def _serialize_created_at(self, dt: datetime.datetime):
|
||||||
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowNodeExecution(BaseModel):
|
||||||
|
"""工作流节点执行记录输出"""
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
id: uuid.UUID
|
||||||
|
execution_id: uuid.UUID
|
||||||
|
node_id: str
|
||||||
|
node_type: str
|
||||||
|
node_name: str | None
|
||||||
|
execution_order: int
|
||||||
|
retry_count: int
|
||||||
|
input_data: dict[str, Any] | None
|
||||||
|
output_data: dict[str, Any] | None
|
||||||
|
status: str
|
||||||
|
error_message: str | None
|
||||||
|
started_at: datetime.datetime
|
||||||
|
completed_at: datetime.datetime | None
|
||||||
|
elapsed_time: float | None
|
||||||
|
token_usage: dict[str, Any] | None
|
||||||
|
cache_hit: bool
|
||||||
|
cache_key: str | None
|
||||||
|
meta_data: dict[str, Any]
|
||||||
|
created_at: datetime.datetime
|
||||||
|
|
||||||
|
@field_serializer("started_at", when_used="json")
|
||||||
|
def _serialize_started_at(self, dt: datetime.datetime):
|
||||||
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
@field_serializer("completed_at", when_used="json")
|
||||||
|
def _serialize_completed_at(self, dt: datetime.datetime | None):
|
||||||
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
@field_serializer("created_at", when_used="json")
|
||||||
|
def _serialize_created_at(self, dt: datetime.datetime):
|
||||||
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 验证响应 ====================
|
||||||
|
|
||||||
|
class WorkflowValidationResponse(BaseModel):
|
||||||
|
"""工作流验证响应"""
|
||||||
|
is_valid: bool = Field(..., description="是否有效")
|
||||||
|
errors: list[str] = Field(default_factory=list, description="错误列表")
|
||||||
|
warnings: list[str] = Field(default_factory=list, description="警告列表")
|
||||||
@@ -13,7 +13,7 @@ from app.models.api_key_model import ApiKey
|
|||||||
from app.repositories.api_key_repository import ApiKeyRepository, ApiKeyLogRepository
|
from app.repositories.api_key_repository import ApiKeyRepository, ApiKeyLogRepository
|
||||||
from app.schemas import api_key_schema
|
from app.schemas import api_key_schema
|
||||||
from app.schemas.response_schema import PageData, PageMeta
|
from app.schemas.response_schema import PageData, PageMeta
|
||||||
from app.core.api_key_utils import generate_api_key, hash_api_key, validate_resource_binding
|
from app.core.api_key_utils import generate_api_key
|
||||||
from app.core.exceptions import (
|
from app.core.exceptions import (
|
||||||
BusinessException,
|
BusinessException,
|
||||||
)
|
)
|
||||||
@@ -33,21 +33,13 @@ class ApiKeyService:
|
|||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
data: api_key_schema.ApiKeyCreate
|
data: api_key_schema.ApiKeyCreate
|
||||||
) -> Tuple[ApiKey, str]:
|
) -> ApiKey:
|
||||||
"""
|
"""
|
||||||
创建 API Key
|
创建 API Key
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[ApiKey, str]: (API Key 对象, API Key 明文)
|
ApiKey: API Key 对象
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 验证资源绑定
|
|
||||||
if data.resource_type or data.resource_id:
|
|
||||||
is_valid, error_msg = validate_resource_binding(
|
|
||||||
data.resource_type, str(data.resource_id) if data.resource_id else None
|
|
||||||
)
|
|
||||||
if not is_valid:
|
|
||||||
raise BusinessException(error_msg, BizCode.API_KEY_INVALID_RESOURCE)
|
|
||||||
|
|
||||||
existing = db.scalar(
|
existing = db.scalar(
|
||||||
select(ApiKey).where(
|
select(ApiKey).where(
|
||||||
ApiKey.workspace_id == workspace_id,
|
ApiKey.workspace_id == workspace_id,
|
||||||
@@ -59,22 +51,20 @@ class ApiKeyService:
|
|||||||
raise BusinessException(f"API Key 名称 '{data.name}' 已存在", BizCode.API_KEY_DUPLICATE_NAME)
|
raise BusinessException(f"API Key 名称 '{data.name}' 已存在", BizCode.API_KEY_DUPLICATE_NAME)
|
||||||
|
|
||||||
# 生成 API Key
|
# 生成 API Key
|
||||||
api_key, key_hash, key_prefix = generate_api_key(data.type)
|
api_key = generate_api_key(data.type)
|
||||||
|
|
||||||
# 创建数据
|
# 创建数据
|
||||||
api_key_data = {
|
api_key_data = {
|
||||||
"id": uuid.uuid4(),
|
"id": uuid.uuid4(),
|
||||||
"name": data.name,
|
"name": data.name,
|
||||||
"description": data.description,
|
"description": data.description,
|
||||||
"key_prefix": key_prefix,
|
"api_key": api_key,
|
||||||
"key_hash": key_hash,
|
|
||||||
"type": data.type,
|
"type": data.type,
|
||||||
"scopes": data.scopes,
|
"scopes": data.scopes,
|
||||||
"workspace_id": workspace_id,
|
"workspace_id": workspace_id,
|
||||||
"resource_id": data.resource_id,
|
"resource_id": data.resource_id,
|
||||||
"resource_type": data.resource_type,
|
"rate_limit": data.rate_limit,
|
||||||
"rate_limit": data.rate_limit or 10,
|
"daily_request_limit": data.daily_request_limit,
|
||||||
"daily_request_limit": data.daily_request_limit or 10000,
|
|
||||||
"quota_limit": data.quota_limit,
|
"quota_limit": data.quota_limit,
|
||||||
"expires_at": data.expires_at,
|
"expires_at": data.expires_at,
|
||||||
"created_by": user_id,
|
"created_by": user_id,
|
||||||
@@ -90,7 +80,7 @@ class ApiKeyService:
|
|||||||
"type": data.type
|
"type": data.type
|
||||||
})
|
})
|
||||||
|
|
||||||
return api_key_obj, api_key
|
return api_key_obj
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
@@ -147,6 +137,9 @@ class ApiKeyService:
|
|||||||
"""更新 API Key配置"""
|
"""更新 API Key配置"""
|
||||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
|
||||||
|
|
||||||
# 检查名称重复
|
# 检查名称重复
|
||||||
if data.name and data.name != api_key.name:
|
if data.name and data.name != api_key.name:
|
||||||
existing = db.scalar(
|
existing = db.scalar(
|
||||||
@@ -177,6 +170,9 @@ class ApiKeyService:
|
|||||||
"""删除 API Key"""
|
"""删除 API Key"""
|
||||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
|
||||||
|
|
||||||
ApiKeyRepository.delete(db, api_key_id)
|
ApiKeyRepository.delete(db, api_key_id)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
@@ -188,27 +184,29 @@ class ApiKeyService:
|
|||||||
db: Session,
|
db: Session,
|
||||||
api_key_id: uuid.UUID,
|
api_key_id: uuid.UUID,
|
||||||
workspace_id: uuid.UUID
|
workspace_id: uuid.UUID
|
||||||
) -> Tuple[ApiKey, str]:
|
) -> ApiKey:
|
||||||
"""重新生成 API Key"""
|
"""重新生成 API Key"""
|
||||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
|
||||||
|
|
||||||
# 检查 API Key 是否激活
|
# 检查 API Key 是否激活
|
||||||
if not api_key.is_active:
|
if not api_key.is_active:
|
||||||
raise BusinessException("无法重新生成已停用的 API Key", BizCode.API_KEY_INACTIVE)
|
raise BusinessException("无法重新生成已停用的 API Key", BizCode.API_KEY_INACTIVE)
|
||||||
|
|
||||||
# 生成新的 API Key
|
# 生成新的 API Key
|
||||||
new_api_key, key_hash, key_prefix = generate_api_key(api_key_schema.ApiKeyType(api_key.type))
|
new_api_key = generate_api_key(api_key.type)
|
||||||
|
|
||||||
# 更新
|
# 更新
|
||||||
ApiKeyRepository.update(db, api_key_id, {
|
ApiKeyRepository.update(db, api_key_id, {
|
||||||
"key_hash": key_hash,
|
"api_key": new_api_key
|
||||||
"key_prefix": key_prefix
|
|
||||||
})
|
})
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(api_key)
|
db.refresh(api_key)
|
||||||
|
|
||||||
logger.info("API Key 重新生成成功", extra={"api_key_id": str(api_key_id)})
|
logger.info("API Key 重新生成成功", extra={"api_key_id": str(api_key_id)})
|
||||||
return api_key, new_api_key
|
return api_key
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_stats(
|
def get_stats(
|
||||||
@@ -219,6 +217,9 @@ class ApiKeyService:
|
|||||||
"""获取使用统计"""
|
"""获取使用统计"""
|
||||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
|
||||||
|
|
||||||
stats_data = ApiKeyRepository.get_stats(db, api_key_id)
|
stats_data = ApiKeyRepository.get_stats(db, api_key_id)
|
||||||
return api_key_schema.ApiKeyStats(**stats_data)
|
return api_key_schema.ApiKeyStats(**stats_data)
|
||||||
|
|
||||||
@@ -235,6 +236,9 @@ class ApiKeyService:
|
|||||||
# 验证 API Key 权限
|
# 验证 API Key 权限
|
||||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
|
||||||
|
|
||||||
items, total = ApiKeyLogRepository.list_by_api_key(
|
items, total = ApiKeyLogRepository.list_by_api_key(
|
||||||
db, api_key_id, filters, page, pagesize
|
db, api_key_id, filters, page, pagesize
|
||||||
)
|
)
|
||||||
@@ -330,7 +334,6 @@ class RateLimiterService:
|
|||||||
"X-RateLimit-Reset": str(qps_info["reset"])
|
"X-RateLimit-Reset": str(qps_info["reset"])
|
||||||
}
|
}
|
||||||
|
|
||||||
# Check daily requests
|
|
||||||
daily_ok, daily_info = await self.check_daily_requests(
|
daily_ok, daily_info = await self.check_daily_requests(
|
||||||
api_key.id,
|
api_key.id,
|
||||||
api_key.daily_request_limit
|
api_key.daily_request_limit
|
||||||
@@ -342,7 +345,6 @@ class RateLimiterService:
|
|||||||
"X-RateLimit-Reset": str(daily_info["reset"])
|
"X-RateLimit-Reset": str(daily_info["reset"])
|
||||||
}
|
}
|
||||||
|
|
||||||
# All checks passed
|
|
||||||
headers = {
|
headers = {
|
||||||
"X-RateLimit-Limit-QPS": str(qps_info["limit"]),
|
"X-RateLimit-Limit-QPS": str(qps_info["limit"]),
|
||||||
"X-RateLimit-Remaining-QPS": str(qps_info["remaining"]),
|
"X-RateLimit-Remaining-QPS": str(qps_info["remaining"]),
|
||||||
@@ -363,13 +365,12 @@ class ApiKeyAuthService:
|
|||||||
验证API Key 有效性
|
验证API Key 有效性
|
||||||
|
|
||||||
检查:
|
检查:
|
||||||
1. Key hash 是否存在
|
1. API Key 是否存在
|
||||||
2. is_active 是否为true
|
2. is_active 是否为true
|
||||||
3. expires_at 是否未过期
|
3. expires_at 是否未过期
|
||||||
4. quota 是否未超限
|
4. quota 是否未超限
|
||||||
"""
|
"""
|
||||||
key_hash = hash_api_key(api_key)
|
api_key_obj = ApiKeyRepository.get_by_api_key(db, api_key)
|
||||||
api_key_obj = ApiKeyRepository.get_by_hash(db, key_hash)
|
|
||||||
|
|
||||||
if not api_key_obj:
|
if not api_key_obj:
|
||||||
return None
|
return None
|
||||||
@@ -393,14 +394,7 @@ class ApiKeyAuthService:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def check_resource(
|
def check_resource(
|
||||||
api_key: ApiKey,
|
api_key: ApiKey,
|
||||||
resource_type: str,
|
|
||||||
resource_id: uuid.UUID
|
resource_id: uuid.UUID
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""检查资源绑定"""
|
"""检查资源绑定"""
|
||||||
if not api_key.resource_id:
|
return api_key.resource_id == resource_id
|
||||||
return True
|
|
||||||
|
|
||||||
return (
|
|
||||||
api_key.resource_type == resource_type and
|
|
||||||
api_key.resource_id == resource_id
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -9,22 +9,24 @@
|
|||||||
"""
|
"""
|
||||||
import datetime
|
import datetime
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Optional, List, Dict, Any, Tuple
|
from typing import Optional, List, Dict, Any, Tuple, Type
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from sqlalchemy import select, func, or_, and_
|
from sqlalchemy import select, func, or_, and_
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.models import App, AgentConfig, AppRelease, MultiAgentConfig
|
from app.core.error_codes import BizCode
|
||||||
from app.schemas import app_schema
|
|
||||||
from app.core.exceptions import (
|
from app.core.exceptions import (
|
||||||
ResourceNotFoundException,
|
ResourceNotFoundException,
|
||||||
ValidationException,
|
|
||||||
BusinessException,
|
BusinessException,
|
||||||
)
|
)
|
||||||
from app.core.error_codes import BizCode
|
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.services.agent_config_converter import AgentConfigConverter
|
from app.models import App, AgentConfig, AppRelease, MultiAgentConfig, WorkflowConfig
|
||||||
from app.models.app_model import AppStatus, AppType
|
from app.models.app_model import AppStatus, AppType
|
||||||
|
from app.repositories.app_repository import get_apps_by_id
|
||||||
|
from app.repositories.workflow_repository import WorkflowConfigRepository
|
||||||
|
from app.schemas import app_schema
|
||||||
|
from app.schemas.workflow_schema import WorkflowConfigUpdate
|
||||||
|
from app.services.agent_config_converter import AgentConfigConverter
|
||||||
|
|
||||||
# 获取业务日志器
|
# 获取业务日志器
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
@@ -120,15 +122,31 @@ class AppService:
|
|||||||
Raises:
|
Raises:
|
||||||
ResourceNotFoundException: 当应用不存在时
|
ResourceNotFoundException: 当应用不存在时
|
||||||
"""
|
"""
|
||||||
app = self.db.get(App, app_id)
|
app = get_apps_by_id(self.db,app_id)
|
||||||
if not app:
|
if not app:
|
||||||
logger.warning("应用不存在", extra={"app_id": str(app_id)})
|
logger.warning("应用不存在", extra={"app_id": str(app_id)})
|
||||||
raise ResourceNotFoundException("应用", str(app_id))
|
raise ResourceNotFoundException("应用", str(app_id))
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
def _check_workflow_config(self, app_id: uuid.UUID):
|
||||||
|
from app.models import WorkflowConfig, ModelConfig
|
||||||
|
from sqlalchemy import select
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
# 2. 获取 Agent 配置
|
||||||
|
stmt = select(WorkflowConfig).where(AgentConfig.app_id == app_id)
|
||||||
|
agent_cfg = self.db.scalars(stmt).first()
|
||||||
|
if not agent_cfg:
|
||||||
|
raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
|
||||||
|
# 3. 获取模型配置
|
||||||
|
model_config = None
|
||||||
|
if agent_cfg.default_model_config_id:
|
||||||
|
model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id)
|
||||||
|
|
||||||
|
if not model_config:
|
||||||
|
raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
|
||||||
def _check_agent_config(self, app_id: uuid.UUID):
|
def _check_agent_config(self, app_id: uuid.UUID):
|
||||||
from app.models import AgentConfig, ModelConfig
|
|
||||||
from app.services.app_service import AppService
|
|
||||||
from app.models import AgentConfig, ModelConfig
|
from app.models import AgentConfig, ModelConfig
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
@@ -161,7 +179,7 @@ class AppService:
|
|||||||
Raises:
|
Raises:
|
||||||
BusinessException: 配置不完整或不存在时抛出
|
BusinessException: 配置不完整或不存在时抛出
|
||||||
"""
|
"""
|
||||||
from app.models import MultiAgentConfig, AgentConfig, ModelConfig
|
from app.models import ModelConfig
|
||||||
from app.services.multi_agent_service import MultiAgentService
|
from app.services.multi_agent_service import MultiAgentService
|
||||||
|
|
||||||
# 1. 检查多智能体配置是否存在
|
# 1. 检查多智能体配置是否存在
|
||||||
@@ -956,6 +974,167 @@ class AppService:
|
|||||||
|
|
||||||
return default_config
|
return default_config
|
||||||
|
|
||||||
|
def get_workflow_config(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
workspace_id: Optional[uuid.UUID] = None
|
||||||
|
) -> WorkflowConfig:
|
||||||
|
"""获取 workflow 配置
|
||||||
|
|
||||||
|
如果配置不存在,返回默认配置模板(不保存到数据库)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用ID
|
||||||
|
workspace_id: 工作空间ID(用于权限验证)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
WorkflowConfig: Workflow 配置对象(存在的配置或默认模板)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ResourceNotFoundException: 当应用不存在时
|
||||||
|
BusinessException: 当应用类型不支持或不可访问时
|
||||||
|
"""
|
||||||
|
logger.debug("获取 Workflow 配置", extra={"app_id": str(app_id)})
|
||||||
|
|
||||||
|
app = self._get_app_or_404(app_id)
|
||||||
|
|
||||||
|
if app.type != AppType.WORKFLOW:
|
||||||
|
raise BusinessException("只有 Workflow 类型应用支持 Workflow 配置", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||||
|
|
||||||
|
# 只读操作,允许访问共享应用
|
||||||
|
self._validate_app_accessible(app, workspace_id)
|
||||||
|
repo = WorkflowConfigRepository(self.db)
|
||||||
|
config = repo.get_by_app_id(app_id)
|
||||||
|
if config:
|
||||||
|
return config
|
||||||
|
|
||||||
|
# 返回默认配置模板(不保存到数据库)
|
||||||
|
logger.debug("配置不存在,返回默认模板", extra={"app_id": str(app_id)})
|
||||||
|
return self._create_default_workflow_config(app_id)
|
||||||
|
|
||||||
|
def update_workflow_config(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
data: WorkflowConfigUpdate,
|
||||||
|
workspace_id: Optional[uuid.UUID] = None
|
||||||
|
) -> WorkflowConfig:
|
||||||
|
"""更新 Workflow 配置(全量更新)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用ID
|
||||||
|
data: 配置更新数据(全量数据)
|
||||||
|
workspace_id: 工作空间ID(用于权限验证)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
WorkflowConfig: 更新后的配置对象
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ResourceNotFoundException: 当应用不存在时
|
||||||
|
BusinessException: 当应用类型不支持或不在指定工作空间时
|
||||||
|
"""
|
||||||
|
logger.info("更新 Workflow 配置", extra={"app_id": str(app_id)})
|
||||||
|
|
||||||
|
app = self._get_app_or_404(app_id)
|
||||||
|
|
||||||
|
if app.type != AppType.WORKFLOW:
|
||||||
|
raise BusinessException("只有 Workflow 类型应用支持 Workflow 配置", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||||
|
|
||||||
|
self._validate_workspace_access(app, workspace_id)
|
||||||
|
|
||||||
|
# 获取现有配置
|
||||||
|
repo = WorkflowConfigRepository(self.db)
|
||||||
|
workflow_cfg = repo.get_by_app_id(app_id)
|
||||||
|
now = datetime.datetime.now()
|
||||||
|
|
||||||
|
if not workflow_cfg:
|
||||||
|
# 如果配置不存在,创建新配置
|
||||||
|
workflow_cfg = WorkflowConfig(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
app_id=app_id,
|
||||||
|
nodes=[node.model_dump() for node in data.nodes] if data.nodes else [],
|
||||||
|
edges=[edge.model_dump() for edge in data.edges] if data.edges else [],
|
||||||
|
variables=[var.model_dump() for var in data.variables] if data.variables else [],
|
||||||
|
execution_config=data.execution_config.model_dump() if data.execution_config else {},
|
||||||
|
triggers=[trigger.model_dump() for trigger in data.triggers] if data.triggers else [],
|
||||||
|
is_active=True,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now
|
||||||
|
)
|
||||||
|
self.db.add(workflow_cfg)
|
||||||
|
logger.debug("创建新的 Workflow 配置", extra={"app_id": str(app_id)})
|
||||||
|
else:
|
||||||
|
# 全量更新现有配置
|
||||||
|
workflow_cfg.nodes = [node.model_dump() for node in data.nodes] if data.nodes else []
|
||||||
|
workflow_cfg.edges = [edge.model_dump() for edge in data.edges] if data.edges else []
|
||||||
|
workflow_cfg.variables = [var.model_dump() for var in data.variables] if data.variables else []
|
||||||
|
workflow_cfg.execution_config = data.execution_config.model_dump() if data.execution_config else {}
|
||||||
|
workflow_cfg.triggers = [trigger.model_dump() for trigger in data.triggers] if data.triggers else []
|
||||||
|
workflow_cfg.updated_at = now
|
||||||
|
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(workflow_cfg)
|
||||||
|
|
||||||
|
logger.info("Workflow 配置更新成功", extra={"app_id": str(app_id)})
|
||||||
|
return workflow_cfg
|
||||||
|
|
||||||
|
def _create_default_workflow_config(self, app_id: uuid.UUID) -> WorkflowConfig:
|
||||||
|
"""创建默认的 workflow 配置模板(不保存到数据库)
|
||||||
|
|
||||||
|
使用 template_loader 加载 simple_qa 模板作为默认配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
WorkflowConfig: 默认配置对象
|
||||||
|
"""
|
||||||
|
from app.core.workflow.template_loader import load_workflow_template
|
||||||
|
|
||||||
|
now = datetime.datetime.now()
|
||||||
|
|
||||||
|
# 使用 template_loader 加载 simple_qa 模板
|
||||||
|
template_data = load_workflow_template('simple_qa')
|
||||||
|
|
||||||
|
if not template_data:
|
||||||
|
# 如果模板加载失败,返回最小化配置
|
||||||
|
logger.warning(
|
||||||
|
"无法加载默认工作流模板,使用最小化配置",
|
||||||
|
extra={"app_id": str(app_id)}
|
||||||
|
)
|
||||||
|
template_data = {
|
||||||
|
'nodes': [
|
||||||
|
{'id': 'start', 'type': 'start', 'name': '开始'},
|
||||||
|
{'id': 'end', 'type': 'end', 'name': '结束'}
|
||||||
|
],
|
||||||
|
'edges': [
|
||||||
|
{'source': 'start', 'target': 'end'}
|
||||||
|
],
|
||||||
|
'variables': [],
|
||||||
|
'execution_config': {
|
||||||
|
'max_execution_time': 300,
|
||||||
|
'max_iterations': 10
|
||||||
|
},
|
||||||
|
'triggers': []
|
||||||
|
}
|
||||||
|
|
||||||
|
# 转换为 WorkflowConfig 格式
|
||||||
|
default_config = WorkflowConfig(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
app_id=app_id,
|
||||||
|
nodes=template_data.get('nodes', []),
|
||||||
|
edges=template_data.get('edges', []),
|
||||||
|
variables=template_data.get('variables', []),
|
||||||
|
execution_config=template_data.get('execution_config', {}),
|
||||||
|
triggers=template_data.get('triggers', []),
|
||||||
|
is_active=True,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now
|
||||||
|
)
|
||||||
|
|
||||||
|
return default_config
|
||||||
|
|
||||||
# ==================== 应用发布管理 ====================
|
# ==================== 应用发布管理 ====================
|
||||||
|
|
||||||
def publish(
|
def publish(
|
||||||
@@ -1797,6 +1976,11 @@ def update_agent_config(db: Session, *, app_id: uuid.UUID, data: app_schema.Agen
|
|||||||
service = AppService(db)
|
service = AppService(db)
|
||||||
return service.update_agent_config(app_id=app_id, data=data, workspace_id=workspace_id)
|
return service.update_agent_config(app_id=app_id, data=data, workspace_id=workspace_id)
|
||||||
|
|
||||||
|
def update_workflow_config(db: Session, *, app_id: uuid.UUID, data: WorkflowConfigUpdate, workspace_id: uuid.UUID | None = None) -> WorkflowConfig:
|
||||||
|
"""更新 Agent 配置(向后兼容接口)"""
|
||||||
|
service = AppService(db)
|
||||||
|
return service.update_workflow_config(app_id=app_id, data=data, workspace_id=workspace_id)
|
||||||
|
|
||||||
|
|
||||||
def get_agent_config(db: Session, *, app_id: uuid.UUID, workspace_id: uuid.UUID | None = None) -> AgentConfig:
|
def get_agent_config(db: Session, *, app_id: uuid.UUID, workspace_id: uuid.UUID | None = None) -> AgentConfig:
|
||||||
"""获取 Agent 配置(向后兼容接口)
|
"""获取 Agent 配置(向后兼容接口)
|
||||||
@@ -1806,6 +1990,14 @@ def get_agent_config(db: Session, *, app_id: uuid.UUID, workspace_id: uuid.UUID
|
|||||||
service = AppService(db)
|
service = AppService(db)
|
||||||
return service.get_agent_config(app_id=app_id, workspace_id=workspace_id)
|
return service.get_agent_config(app_id=app_id, workspace_id=workspace_id)
|
||||||
|
|
||||||
|
def get_workflow_config(db: Session, *, app_id: uuid.UUID, workspace_id: uuid.UUID | None = None) -> WorkflowConfig:
|
||||||
|
"""获取 Agent 配置(向后兼容接口)
|
||||||
|
|
||||||
|
如果配置不存在,返回默认配置模板
|
||||||
|
"""
|
||||||
|
service = AppService(db)
|
||||||
|
return service.get_workflow_config(app_id=app_id, workspace_id=workspace_id)
|
||||||
|
|
||||||
|
|
||||||
def publish(db: Session, *, app_id: uuid.UUID, publisher_id: uuid.UUID, workspace_id: uuid.UUID | None = None,version_name:str, release_notes: Optional[str] = None) -> AppRelease:
|
def publish(db: Session, *, app_id: uuid.UUID, publisher_id: uuid.UUID, workspace_id: uuid.UUID | None = None,version_name:str, release_notes: Optional[str] = None) -> AppRelease:
|
||||||
"""发布应用(向后兼容接口)"""
|
"""发布应用(向后兼容接口)"""
|
||||||
|
|||||||
731
api/app/services/workflow_service.py
Normal file
731
api/app/services/workflow_service.py
Normal file
@@ -0,0 +1,731 @@
|
|||||||
|
"""
|
||||||
|
工作流服务层
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
import datetime
|
||||||
|
from typing import Any, Annotated
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from fastapi import Depends
|
||||||
|
|
||||||
|
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
||||||
|
from app.repositories.workflow_repository import (
|
||||||
|
WorkflowConfigRepository,
|
||||||
|
WorkflowExecutionRepository,
|
||||||
|
WorkflowNodeExecutionRepository,
|
||||||
|
get_workflow_config_repository,
|
||||||
|
get_workflow_execution_repository,
|
||||||
|
get_workflow_node_execution_repository
|
||||||
|
)
|
||||||
|
from app.core.workflow.validator import validate_workflow_config
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.db import get_db
|
||||||
|
from app.schemas import DraftRunRequest
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowService:
|
||||||
|
"""工作流服务"""
|
||||||
|
|
||||||
|
def __init__(self, db: Session):
|
||||||
|
self.db = db
|
||||||
|
self.config_repo = WorkflowConfigRepository(db)
|
||||||
|
self.execution_repo = WorkflowExecutionRepository(db)
|
||||||
|
self.node_execution_repo = WorkflowNodeExecutionRepository(db)
|
||||||
|
|
||||||
|
# ==================== 配置管理 ====================
|
||||||
|
|
||||||
|
def create_workflow_config(
|
||||||
|
self,
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
nodes: list[dict[str, Any]],
|
||||||
|
edges: list[dict[str, Any]],
|
||||||
|
variables: list[dict[str, Any]] | None = None,
|
||||||
|
execution_config: dict[str, Any] | None = None,
|
||||||
|
triggers: list[dict[str, Any]] | None = None,
|
||||||
|
validate: bool = True
|
||||||
|
) -> WorkflowConfig:
|
||||||
|
"""创建工作流配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用 ID
|
||||||
|
nodes: 节点列表
|
||||||
|
edges: 边列表
|
||||||
|
variables: 变量列表
|
||||||
|
execution_config: 执行配置
|
||||||
|
triggers: 触发器列表
|
||||||
|
validate: 是否验证配置
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工作流配置
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BusinessException: 配置无效时抛出
|
||||||
|
"""
|
||||||
|
# 构建配置字典
|
||||||
|
config_dict = {
|
||||||
|
"nodes": nodes,
|
||||||
|
"edges": edges,
|
||||||
|
"variables": variables or [],
|
||||||
|
"execution_config": execution_config or {},
|
||||||
|
"triggers": triggers or []
|
||||||
|
}
|
||||||
|
|
||||||
|
# 验证配置
|
||||||
|
if validate:
|
||||||
|
is_valid, errors = validate_workflow_config(config_dict, for_publish=False)
|
||||||
|
if not is_valid:
|
||||||
|
logger.warning(f"工作流配置验证失败: {errors}")
|
||||||
|
raise BusinessException(
|
||||||
|
error_code=BizCode.INVALID_PARAMETER,
|
||||||
|
message=f"工作流配置无效: {'; '.join(errors)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建或更新配置
|
||||||
|
config = self.config_repo.create_or_update(
|
||||||
|
app_id=app_id,
|
||||||
|
nodes=nodes,
|
||||||
|
edges=edges,
|
||||||
|
variables=variables,
|
||||||
|
execution_config=execution_config,
|
||||||
|
triggers=triggers
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"创建工作流配置成功: app_id={app_id}, config_id={config.id}")
|
||||||
|
return config
|
||||||
|
|
||||||
|
def get_workflow_config(self, app_id: uuid.UUID) -> WorkflowConfig | None:
|
||||||
|
"""获取工作流配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工作流配置或 None
|
||||||
|
"""
|
||||||
|
return self.config_repo.get_by_app_id(app_id)
|
||||||
|
|
||||||
|
def update_workflow_config(
|
||||||
|
self,
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
nodes: list[dict[str, Any]] | None = None,
|
||||||
|
edges: list[dict[str, Any]] | None = None,
|
||||||
|
variables: list[dict[str, Any]] | None = None,
|
||||||
|
execution_config: dict[str, Any] | None = None,
|
||||||
|
triggers: list[dict[str, Any]] | None = None,
|
||||||
|
validate: bool = True
|
||||||
|
) -> WorkflowConfig:
|
||||||
|
"""更新工作流配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用 ID
|
||||||
|
nodes: 节点列表
|
||||||
|
edges: 边列表
|
||||||
|
variables: 变量列表
|
||||||
|
execution_config: 执行配置
|
||||||
|
triggers: 触发器列表
|
||||||
|
validate: 是否验证配置
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工作流配置
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BusinessException: 配置不存在或无效时抛出
|
||||||
|
"""
|
||||||
|
# 获取现有配置
|
||||||
|
config = self.get_workflow_config(app_id)
|
||||||
|
if not config:
|
||||||
|
raise BusinessException(
|
||||||
|
error_code=BizCode.RESOURCE_NOT_FOUND,
|
||||||
|
message=f"工作流配置不存在: app_id={app_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 合并配置
|
||||||
|
updated_nodes = nodes if nodes is not None else config.nodes
|
||||||
|
updated_edges = edges if edges is not None else config.edges
|
||||||
|
updated_variables = variables if variables is not None else config.variables
|
||||||
|
updated_execution_config = execution_config if execution_config is not None else config.execution_config
|
||||||
|
updated_triggers = triggers if triggers is not None else config.triggers
|
||||||
|
|
||||||
|
# 构建配置字典
|
||||||
|
config_dict = {
|
||||||
|
"nodes": updated_nodes,
|
||||||
|
"edges": updated_edges,
|
||||||
|
"variables": updated_variables,
|
||||||
|
"execution_config": updated_execution_config,
|
||||||
|
"triggers": updated_triggers
|
||||||
|
}
|
||||||
|
|
||||||
|
# 验证配置
|
||||||
|
if validate:
|
||||||
|
is_valid, errors = validate_workflow_config(config_dict, for_publish=False)
|
||||||
|
if not is_valid:
|
||||||
|
logger.warning(f"工作流配置验证失败: {errors}")
|
||||||
|
raise BusinessException(
|
||||||
|
error_code=BizCode.INVALID_PARAMETER,
|
||||||
|
message=f"工作流配置无效: {'; '.join(errors)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新配置
|
||||||
|
config = self.config_repo.create_or_update(
|
||||||
|
app_id=app_id,
|
||||||
|
nodes=updated_nodes,
|
||||||
|
edges=updated_edges,
|
||||||
|
variables=updated_variables,
|
||||||
|
execution_config=updated_execution_config,
|
||||||
|
triggers=updated_triggers
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"更新工作流配置成功: app_id={app_id}, config_id={config.id}")
|
||||||
|
return config
|
||||||
|
|
||||||
|
def delete_workflow_config(self, app_id: uuid.UUID) -> bool:
|
||||||
|
"""删除工作流配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否删除成功
|
||||||
|
"""
|
||||||
|
config = self.get_workflow_config(app_id)
|
||||||
|
if not config:
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.config_repo.delete(config.id)
|
||||||
|
logger.info(f"删除工作流配置成功: app_id={app_id}, config_id={config.id}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def check_config(self, app_id: uuid.UUID) -> WorkflowConfig:
|
||||||
|
"""检查工作流配置的完整性
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用 ID
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BusinessException: 配置不完整或不存在时抛出
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 1. 检查多智能体配置是否存在
|
||||||
|
config = self.get_workflow_config(app_id)
|
||||||
|
if not config:
|
||||||
|
raise BusinessException(
|
||||||
|
"工作流配置不存在,无法运行",
|
||||||
|
BizCode.CONFIG_MISSING
|
||||||
|
)
|
||||||
|
# validator 现在支持直接接受 Pydantic 模型
|
||||||
|
is_valid, errors = validate_workflow_config(config, for_publish=False)
|
||||||
|
if not is_valid:
|
||||||
|
logger.warning(f"工作流配置验证失败: {errors}")
|
||||||
|
raise BusinessException(
|
||||||
|
code=BizCode.INVALID_PARAMETER,
|
||||||
|
message=f"工作流配置无效: {'; '.join(errors)}"
|
||||||
|
)
|
||||||
|
return config
|
||||||
|
|
||||||
|
def validate_workflow_config_for_publish(
|
||||||
|
self,
|
||||||
|
app_id: uuid.UUID
|
||||||
|
) -> tuple[bool, list[str]]:
|
||||||
|
"""验证工作流配置是否可以发布
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(is_valid, errors): 是否有效和错误列表
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BusinessException: 配置不存在时抛出
|
||||||
|
"""
|
||||||
|
config = self.get_workflow_config(app_id)
|
||||||
|
if not config:
|
||||||
|
raise BusinessException(
|
||||||
|
error_code=BizCode.RESOURCE_NOT_FOUND,
|
||||||
|
message=f"工作流配置不存在: app_id={app_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
config_dict = {
|
||||||
|
"nodes": config.nodes,
|
||||||
|
"edges": config.edges,
|
||||||
|
"variables": config.variables,
|
||||||
|
"execution_config": config.execution_config,
|
||||||
|
"triggers": config.triggers
|
||||||
|
}
|
||||||
|
|
||||||
|
return validate_workflow_config(config_dict, for_publish=True)
|
||||||
|
|
||||||
|
# ==================== 执行管理 ====================
|
||||||
|
|
||||||
|
def create_execution(
|
||||||
|
self,
|
||||||
|
workflow_config_id: uuid.UUID,
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
trigger_type: str,
|
||||||
|
triggered_by: uuid.UUID | None = None,
|
||||||
|
conversation_id: uuid.UUID | None = None,
|
||||||
|
input_data: dict[str, Any] | None = None
|
||||||
|
) -> WorkflowExecution:
|
||||||
|
"""创建工作流执行记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_config_id: 工作流配置 ID
|
||||||
|
app_id: 应用 ID
|
||||||
|
trigger_type: 触发类型
|
||||||
|
triggered_by: 触发用户 ID
|
||||||
|
conversation_id: 会话 ID
|
||||||
|
input_data: 输入数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行记录
|
||||||
|
"""
|
||||||
|
# 生成执行 ID
|
||||||
|
execution_id = f"exec_{uuid.uuid4().hex[:16]}"
|
||||||
|
|
||||||
|
execution = WorkflowExecution(
|
||||||
|
workflow_config_id=workflow_config_id,
|
||||||
|
app_id=app_id,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
execution_id=execution_id,
|
||||||
|
trigger_type=trigger_type,
|
||||||
|
triggered_by=triggered_by,
|
||||||
|
input_data=input_data or {},
|
||||||
|
status="pending"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.db.add(execution)
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(execution)
|
||||||
|
|
||||||
|
logger.info(f"创建工作流执行记录: execution_id={execution_id}")
|
||||||
|
return execution
|
||||||
|
|
||||||
|
def get_execution(self, execution_id: str) -> WorkflowExecution | None:
|
||||||
|
"""获取执行记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
execution_id: 执行 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行记录或 None
|
||||||
|
"""
|
||||||
|
return self.execution_repo.get_by_execution_id(execution_id)
|
||||||
|
|
||||||
|
def get_executions_by_app(
|
||||||
|
self,
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0
|
||||||
|
) -> list[WorkflowExecution]:
|
||||||
|
"""获取应用的执行记录列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用 ID
|
||||||
|
limit: 返回数量限制
|
||||||
|
offset: 偏移量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行记录列表
|
||||||
|
"""
|
||||||
|
return self.execution_repo.get_by_app_id(app_id, limit, offset)
|
||||||
|
|
||||||
|
def update_execution_status(
|
||||||
|
self,
|
||||||
|
execution_id: str,
|
||||||
|
status: str,
|
||||||
|
output_data: dict[str, Any] | None = None,
|
||||||
|
error_message: str | None = None,
|
||||||
|
error_node_id: str | None = None
|
||||||
|
) -> WorkflowExecution:
|
||||||
|
"""更新执行状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
execution_id: 执行 ID
|
||||||
|
status: 状态
|
||||||
|
output_data: 输出数据
|
||||||
|
error_message: 错误信息
|
||||||
|
error_node_id: 出错节点 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行记录
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BusinessException: 执行记录不存在时抛出
|
||||||
|
"""
|
||||||
|
execution = self.get_execution(execution_id)
|
||||||
|
if not execution:
|
||||||
|
raise BusinessException(
|
||||||
|
error_code=BizCode.RESOURCE_NOT_FOUND,
|
||||||
|
message=f"执行记录不存在: execution_id={execution_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
execution.status = status
|
||||||
|
if output_data is not None:
|
||||||
|
execution.output_data = output_data
|
||||||
|
if error_message is not None:
|
||||||
|
execution.error_message = error_message
|
||||||
|
if error_node_id is not None:
|
||||||
|
execution.error_node_id = error_node_id
|
||||||
|
|
||||||
|
# 如果是完成状态,计算耗时
|
||||||
|
if status in ["completed", "failed", "cancelled", "timeout"]:
|
||||||
|
if not execution.completed_at:
|
||||||
|
execution.completed_at = datetime.datetime.now()
|
||||||
|
elapsed = (execution.completed_at - execution.started_at).total_seconds()
|
||||||
|
execution.elapsed_time = elapsed
|
||||||
|
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(execution)
|
||||||
|
|
||||||
|
logger.info(f"更新执行状态: execution_id={execution_id}, status={status}")
|
||||||
|
return execution
|
||||||
|
|
||||||
|
def get_execution_statistics(self, app_id: uuid.UUID) -> dict[str, Any]:
|
||||||
|
"""获取执行统计信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
统计信息
|
||||||
|
"""
|
||||||
|
total = self.execution_repo.count_by_app_id(app_id)
|
||||||
|
completed = self.execution_repo.count_by_status(app_id, "completed")
|
||||||
|
failed = self.execution_repo.count_by_status(app_id, "failed")
|
||||||
|
running = self.execution_repo.count_by_status(app_id, "running")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total": total,
|
||||||
|
"completed": completed,
|
||||||
|
"failed": failed,
|
||||||
|
"running": running,
|
||||||
|
"success_rate": completed / total if total > 0 else 0
|
||||||
|
}
|
||||||
|
|
||||||
|
# ==================== 工作流执行 ====================
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
payload: DraftRunRequest,
|
||||||
|
config: WorkflowConfig
|
||||||
|
):
|
||||||
|
"""运行工作流
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用 ID
|
||||||
|
input_data: 输入数据(包含 message 和 variables)
|
||||||
|
triggered_by: 触发用户 ID
|
||||||
|
conversation_id: 会话 ID(可选)
|
||||||
|
stream: 是否流式返回
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行结果(非流式)或生成器(流式)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BusinessException: 配置不存在或执行失败时抛出
|
||||||
|
"""
|
||||||
|
# 1. 获取工作流配置
|
||||||
|
if not config:
|
||||||
|
config = self.get_workflow_config(app_id)
|
||||||
|
if not config:
|
||||||
|
raise BusinessException(
|
||||||
|
code=BizCode.CONFIG_MISSING,
|
||||||
|
message=f"工作流配置不存在: app_id={app_id}"
|
||||||
|
)
|
||||||
|
input_data = {"message": payload.message, "variables": payload.variables, "conversation_id": payload.conversation_id}
|
||||||
|
|
||||||
|
# 转换 user_id 为 UUID
|
||||||
|
triggered_by_uuid = None
|
||||||
|
if payload.user_id:
|
||||||
|
try:
|
||||||
|
triggered_by_uuid = uuid.UUID(payload.user_id)
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
logger.warning(f"无效的 user_id 格式: {payload.user_id}")
|
||||||
|
|
||||||
|
# 转换 conversation_id 为 UUID
|
||||||
|
conversation_id_uuid = None
|
||||||
|
if payload.conversation_id:
|
||||||
|
try:
|
||||||
|
conversation_id_uuid = uuid.UUID(payload.conversation_id)
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}")
|
||||||
|
|
||||||
|
# 2. 创建执行记录
|
||||||
|
execution = self.create_execution(
|
||||||
|
workflow_config_id=config.id,
|
||||||
|
app_id=app_id,
|
||||||
|
trigger_type="manual",
|
||||||
|
triggered_by=triggered_by_uuid,
|
||||||
|
conversation_id=conversation_id_uuid,
|
||||||
|
input_data=input_data
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 构建工作流配置字典
|
||||||
|
workflow_config_dict = {
|
||||||
|
"nodes": config.nodes,
|
||||||
|
"edges": config.edges,
|
||||||
|
"variables": config.variables,
|
||||||
|
"execution_config": config.execution_config
|
||||||
|
}
|
||||||
|
|
||||||
|
# 4. 获取工作空间 ID(从 app 获取)
|
||||||
|
from app.models import App
|
||||||
|
|
||||||
|
|
||||||
|
# 5. 执行工作流
|
||||||
|
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 更新状态为运行中
|
||||||
|
self.update_execution_status(execution.execution_id, "running")
|
||||||
|
|
||||||
|
result = await execute_workflow(
|
||||||
|
workflow_config=workflow_config_dict,
|
||||||
|
input_data=input_data,
|
||||||
|
execution_id=execution.execution_id,
|
||||||
|
workspace_id="",
|
||||||
|
user_id=payload.user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新执行结果
|
||||||
|
if result.get("status") == "completed":
|
||||||
|
self.update_execution_status(
|
||||||
|
execution.execution_id,
|
||||||
|
"completed",
|
||||||
|
output_data=result.get("node_outputs", {})
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.update_execution_status(
|
||||||
|
execution.execution_id,
|
||||||
|
"failed",
|
||||||
|
error_message=result.get("error")
|
||||||
|
)
|
||||||
|
|
||||||
|
# 返回增强的响应结构
|
||||||
|
return {
|
||||||
|
"execution_id": execution.execution_id,
|
||||||
|
"status": result.get("status"),
|
||||||
|
"output": result.get("output"), # 最终输出(字符串)
|
||||||
|
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
||||||
|
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
||||||
|
"error_message": result.get("error"),
|
||||||
|
"elapsed_time": result.get("elapsed_time"),
|
||||||
|
"token_usage": result.get("token_usage")
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
|
||||||
|
self.update_execution_status(
|
||||||
|
execution.execution_id,
|
||||||
|
"failed",
|
||||||
|
error_message=str(e)
|
||||||
|
)
|
||||||
|
raise BusinessException(
|
||||||
|
code=BizCode.INTERNAL_ERROR,
|
||||||
|
message=f"工作流执行失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run_workflow(
|
||||||
|
self,
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
input_data: dict[str, Any],
|
||||||
|
triggered_by: uuid.UUID,
|
||||||
|
conversation_id: uuid.UUID | None = None,
|
||||||
|
stream: bool = False
|
||||||
|
):
|
||||||
|
"""运行工作流
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用 ID
|
||||||
|
input_data: 输入数据(包含 message 和 variables)
|
||||||
|
triggered_by: 触发用户 ID
|
||||||
|
conversation_id: 会话 ID(可选)
|
||||||
|
stream: 是否流式返回
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行结果(非流式)或生成器(流式)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BusinessException: 配置不存在或执行失败时抛出
|
||||||
|
"""
|
||||||
|
# 1. 获取工作流配置
|
||||||
|
config = self.get_workflow_config(app_id)
|
||||||
|
if not config:
|
||||||
|
raise BusinessException(
|
||||||
|
error_code=BizCode.RESOURCE_NOT_FOUND,
|
||||||
|
message=f"工作流配置不存在: app_id={app_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. 创建执行记录
|
||||||
|
execution = self.create_execution(
|
||||||
|
workflow_config_id=config.id,
|
||||||
|
app_id=app_id,
|
||||||
|
trigger_type="manual",
|
||||||
|
triggered_by=triggered_by,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
input_data=input_data
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 构建工作流配置字典
|
||||||
|
workflow_config_dict = {
|
||||||
|
"nodes": config.nodes,
|
||||||
|
"edges": config.edges,
|
||||||
|
"variables": config.variables,
|
||||||
|
"execution_config": config.execution_config
|
||||||
|
}
|
||||||
|
|
||||||
|
# 4. 获取工作空间 ID(从 app 获取)
|
||||||
|
from app.models import App
|
||||||
|
app = self.db.query(App).filter(App.id == app_id).first()
|
||||||
|
if not app:
|
||||||
|
raise BusinessException(
|
||||||
|
error_code=BizCode.RESOURCE_NOT_FOUND,
|
||||||
|
message=f"应用不存在: app_id={app_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. 执行工作流
|
||||||
|
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 更新状态为运行中
|
||||||
|
self.update_execution_status(execution.execution_id, "running")
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
# 流式执行
|
||||||
|
return self._run_workflow_stream(
|
||||||
|
workflow_config_dict,
|
||||||
|
input_data,
|
||||||
|
execution.execution_id,
|
||||||
|
str(app.workspace_id),
|
||||||
|
str(triggered_by)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 非流式执行
|
||||||
|
result = await execute_workflow(
|
||||||
|
workflow_config=workflow_config_dict,
|
||||||
|
input_data=input_data,
|
||||||
|
execution_id=execution.execution_id,
|
||||||
|
workspace_id=str(app.workspace_id),
|
||||||
|
user_id=str(triggered_by)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新执行结果
|
||||||
|
if result.get("status") == "completed":
|
||||||
|
self.update_execution_status(
|
||||||
|
execution.execution_id,
|
||||||
|
"completed",
|
||||||
|
output_data=result.get("node_outputs", {})
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.update_execution_status(
|
||||||
|
execution.execution_id,
|
||||||
|
"failed",
|
||||||
|
error_message=result.get("error")
|
||||||
|
)
|
||||||
|
|
||||||
|
# 返回增强的响应结构
|
||||||
|
return {
|
||||||
|
"execution_id": execution.execution_id,
|
||||||
|
"status": result.get("status"),
|
||||||
|
"output": result.get("output"), # 最终输出(字符串)
|
||||||
|
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
||||||
|
"error_message": result.get("error"),
|
||||||
|
"elapsed_time": result.get("elapsed_time"),
|
||||||
|
"token_usage": result.get("token_usage")
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
|
||||||
|
self.update_execution_status(
|
||||||
|
execution.execution_id,
|
||||||
|
"failed",
|
||||||
|
error_message=str(e)
|
||||||
|
)
|
||||||
|
raise BusinessException(
|
||||||
|
error_code=BizCode.INTERNAL_ERROR,
|
||||||
|
message=f"工作流执行失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _run_workflow_stream(
|
||||||
|
self,
|
||||||
|
workflow_config: dict[str, Any],
|
||||||
|
input_data: dict[str, Any],
|
||||||
|
execution_id: str,
|
||||||
|
workspace_id: str,
|
||||||
|
user_id: str
|
||||||
|
):
|
||||||
|
"""运行工作流(流式,内部方法)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_config: 工作流配置
|
||||||
|
input_data: 输入数据
|
||||||
|
execution_id: 执行 ID
|
||||||
|
workspace_id: 工作空间 ID
|
||||||
|
user_id: 用户 ID
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
流式事件
|
||||||
|
"""
|
||||||
|
from app.core.workflow.executor import execute_workflow_stream
|
||||||
|
|
||||||
|
try:
|
||||||
|
output_data = {}
|
||||||
|
|
||||||
|
async for event in execute_workflow_stream(
|
||||||
|
workflow_config=workflow_config,
|
||||||
|
input_data=input_data,
|
||||||
|
execution_id=execution_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
user_id=user_id
|
||||||
|
):
|
||||||
|
# 转发事件
|
||||||
|
yield event
|
||||||
|
|
||||||
|
# 收集输出数据
|
||||||
|
if event.get("type") == "node_complete":
|
||||||
|
node_data = event.get("data", {})
|
||||||
|
node_outputs = node_data.get("node_outputs", {})
|
||||||
|
output_data.update(node_outputs)
|
||||||
|
|
||||||
|
# 处理完成事件
|
||||||
|
if event.get("type") == "workflow_complete":
|
||||||
|
self.update_execution_status(
|
||||||
|
execution_id,
|
||||||
|
"completed",
|
||||||
|
output_data=output_data
|
||||||
|
)
|
||||||
|
|
||||||
|
# 处理错误事件
|
||||||
|
if event.get("type") == "workflow_error":
|
||||||
|
self.update_execution_status(
|
||||||
|
execution_id,
|
||||||
|
"failed",
|
||||||
|
error_message=event.get("error")
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"工作流流式执行失败: execution_id={execution_id}, error={e}", exc_info=True)
|
||||||
|
self.update_execution_status(
|
||||||
|
execution_id,
|
||||||
|
"failed",
|
||||||
|
error_message=str(e)
|
||||||
|
)
|
||||||
|
yield {
|
||||||
|
"type": "workflow_error",
|
||||||
|
"execution_id": execution_id,
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 依赖注入函数 ====================
|
||||||
|
|
||||||
|
def get_workflow_service(
|
||||||
|
db: Annotated[Session, Depends(get_db)]
|
||||||
|
) -> WorkflowService:
|
||||||
|
"""获取工作流服务(依赖注入)"""
|
||||||
|
return WorkflowService(db)
|
||||||
219
api/app/templates/workflows/customer_service/template.yml
Normal file
219
api/app/templates/workflows/customer_service/template.yml
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
# 智能客服工作流模板
|
||||||
|
id: customer_service_v1
|
||||||
|
name: 智能客服工作流
|
||||||
|
description: 智能客服场景,包含意图识别、知识库查询和回复生成
|
||||||
|
category: customer_service
|
||||||
|
version: "1.0.0"
|
||||||
|
author: RedBear Memory Team
|
||||||
|
tags:
|
||||||
|
- 客服
|
||||||
|
- 意图识别
|
||||||
|
- 知识库
|
||||||
|
- 多步骤
|
||||||
|
|
||||||
|
# 工作流配置
|
||||||
|
nodes:
|
||||||
|
- id: start
|
||||||
|
type: start
|
||||||
|
name: 开始
|
||||||
|
position:
|
||||||
|
x: 100
|
||||||
|
y: 200
|
||||||
|
|
||||||
|
- id: intent_recognition
|
||||||
|
type: llm
|
||||||
|
name: 意图识别
|
||||||
|
config:
|
||||||
|
prompt: |
|
||||||
|
分析用户的问题,识别意图类型。
|
||||||
|
|
||||||
|
用户问题:{{ var.user_message }}
|
||||||
|
|
||||||
|
请从以下类型中选择一个:
|
||||||
|
- product_inquiry: 产品咨询
|
||||||
|
- technical_support: 技术支持
|
||||||
|
- complaint: 投诉建议
|
||||||
|
- other: 其他
|
||||||
|
|
||||||
|
只返回类型名称,不要其他内容。
|
||||||
|
model: gpt-3.5-turbo
|
||||||
|
temperature: 0.3
|
||||||
|
max_tokens: 50
|
||||||
|
position:
|
||||||
|
x: 300
|
||||||
|
y: 200
|
||||||
|
|
||||||
|
- id: intent_router
|
||||||
|
type: condition
|
||||||
|
name: 意图路由
|
||||||
|
position:
|
||||||
|
x: 500
|
||||||
|
y: 200
|
||||||
|
|
||||||
|
- id: product_handler
|
||||||
|
type: llm
|
||||||
|
name: 产品咨询处理
|
||||||
|
config:
|
||||||
|
prompt: |
|
||||||
|
用户咨询产品相关问题。
|
||||||
|
|
||||||
|
问题:{{ var.user_message }}
|
||||||
|
意图:{{ node.intent_recognition.output }}
|
||||||
|
|
||||||
|
请提供专业、友好的产品咨询回复。
|
||||||
|
model: gpt-3.5-turbo
|
||||||
|
temperature: 0.7
|
||||||
|
max_tokens: 500
|
||||||
|
position:
|
||||||
|
x: 700
|
||||||
|
y: 100
|
||||||
|
|
||||||
|
- id: support_handler
|
||||||
|
type: llm
|
||||||
|
name: 技术支持处理
|
||||||
|
config:
|
||||||
|
prompt: |
|
||||||
|
用户需要技术支持。
|
||||||
|
|
||||||
|
问题:{{ var.user_message }}
|
||||||
|
意图:{{ node.intent_recognition.output }}
|
||||||
|
|
||||||
|
请提供详细的技术支持方案。
|
||||||
|
model: gpt-3.5-turbo
|
||||||
|
temperature: 0.5
|
||||||
|
max_tokens: 800
|
||||||
|
position:
|
||||||
|
x: 700
|
||||||
|
y: 200
|
||||||
|
|
||||||
|
- id: complaint_handler
|
||||||
|
type: llm
|
||||||
|
name: 投诉处理
|
||||||
|
config:
|
||||||
|
prompt: |
|
||||||
|
用户提出投诉或建议。
|
||||||
|
|
||||||
|
问题:{{ var.user_message }}
|
||||||
|
意图:{{ node.intent_recognition.output }}
|
||||||
|
|
||||||
|
请以同理心回应,并提供解决方案。
|
||||||
|
model: gpt-3.5-turbo
|
||||||
|
temperature: 0.8
|
||||||
|
max_tokens: 600
|
||||||
|
position:
|
||||||
|
x: 700
|
||||||
|
y: 300
|
||||||
|
|
||||||
|
- id: general_handler
|
||||||
|
type: llm
|
||||||
|
name: 通用处理
|
||||||
|
config:
|
||||||
|
prompt: |
|
||||||
|
用户的问题类型:其他
|
||||||
|
|
||||||
|
问题:{{ var.user_message }}
|
||||||
|
|
||||||
|
请提供友好的回复。
|
||||||
|
model: gpt-3.5-turbo
|
||||||
|
temperature: 0.7
|
||||||
|
max_tokens: 400
|
||||||
|
position:
|
||||||
|
x: 700
|
||||||
|
y: 400
|
||||||
|
|
||||||
|
- id: end
|
||||||
|
type: end
|
||||||
|
name: 结束
|
||||||
|
position:
|
||||||
|
x: 900
|
||||||
|
y: 200
|
||||||
|
|
||||||
|
edges:
|
||||||
|
- source: start
|
||||||
|
target: intent_recognition
|
||||||
|
label: 开始分析
|
||||||
|
|
||||||
|
- source: intent_recognition
|
||||||
|
target: intent_router
|
||||||
|
label: 识别完成
|
||||||
|
|
||||||
|
- source: intent_router
|
||||||
|
target: product_handler
|
||||||
|
condition: "'product_inquiry' in node['intent_recognition']['output']"
|
||||||
|
label: 产品咨询
|
||||||
|
|
||||||
|
- source: intent_router
|
||||||
|
target: support_handler
|
||||||
|
condition: "'technical_support' in node['intent_recognition']['output']"
|
||||||
|
label: 技术支持
|
||||||
|
|
||||||
|
- source: intent_router
|
||||||
|
target: complaint_handler
|
||||||
|
condition: "'complaint' in node['intent_recognition']['output']"
|
||||||
|
label: 投诉建议
|
||||||
|
|
||||||
|
- source: intent_router
|
||||||
|
target: general_handler
|
||||||
|
condition: "True" # 默认路径
|
||||||
|
label: 其他
|
||||||
|
|
||||||
|
- source: product_handler
|
||||||
|
target: end
|
||||||
|
label: 完成
|
||||||
|
|
||||||
|
- source: support_handler
|
||||||
|
target: end
|
||||||
|
label: 完成
|
||||||
|
|
||||||
|
- source: complaint_handler
|
||||||
|
target: end
|
||||||
|
label: 完成
|
||||||
|
|
||||||
|
- source: general_handler
|
||||||
|
target: end
|
||||||
|
label: 完成
|
||||||
|
|
||||||
|
# 变量定义
|
||||||
|
variables:
|
||||||
|
- name: user_message
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
description: 用户的消息
|
||||||
|
default: ""
|
||||||
|
|
||||||
|
- name: user_name
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
description: 用户姓名(可选)
|
||||||
|
default: "客户"
|
||||||
|
|
||||||
|
# 执行配置
|
||||||
|
execution_config:
|
||||||
|
max_execution_time: 120
|
||||||
|
max_iterations: 10
|
||||||
|
|
||||||
|
# 触发器
|
||||||
|
triggers: []
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
|
examples:
|
||||||
|
- name: 产品咨询
|
||||||
|
description: 用户咨询产品功能
|
||||||
|
input:
|
||||||
|
user_message: "你们的产品支持多语言吗?"
|
||||||
|
user_name: "张三"
|
||||||
|
expected_output: "产品功能介绍"
|
||||||
|
|
||||||
|
- name: 技术支持
|
||||||
|
description: 用户遇到技术问题
|
||||||
|
input:
|
||||||
|
user_message: "我无法登录系统,一直显示密码错误"
|
||||||
|
user_name: "李四"
|
||||||
|
expected_output: "技术支持方案"
|
||||||
|
|
||||||
|
- name: 投诉处理
|
||||||
|
description: 用户提出投诉
|
||||||
|
input:
|
||||||
|
user_message: "你们的服务态度太差了,我要投诉"
|
||||||
|
user_name: "王五"
|
||||||
|
expected_output: "同理心回应和解决方案"
|
||||||
131
api/app/templates/workflows/data_processing/template.yml
Normal file
131
api/app/templates/workflows/data_processing/template.yml
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
# 数据处理工作流模板
|
||||||
|
id: data_processing_v1
|
||||||
|
name: 数据处理工作流
|
||||||
|
description: 数据提取、转换和分析的完整流程
|
||||||
|
category: data_processing
|
||||||
|
version: "1.0.0"
|
||||||
|
author: RedBear Memory Team
|
||||||
|
tags:
|
||||||
|
- 数据处理
|
||||||
|
- ETL
|
||||||
|
- 分析
|
||||||
|
- Transform
|
||||||
|
|
||||||
|
# 工作流配置
|
||||||
|
nodes:
|
||||||
|
- id: start
|
||||||
|
type: start
|
||||||
|
name: 开始
|
||||||
|
position:
|
||||||
|
x: 100
|
||||||
|
y: 200
|
||||||
|
|
||||||
|
- id: extract_data
|
||||||
|
type: transform
|
||||||
|
name: 数据提取
|
||||||
|
config:
|
||||||
|
expression: |
|
||||||
|
{
|
||||||
|
"raw_text": var['input_text'],
|
||||||
|
"length": len(var['input_text']),
|
||||||
|
"timestamp": sys['execution_id']
|
||||||
|
}
|
||||||
|
position:
|
||||||
|
x: 300
|
||||||
|
y: 200
|
||||||
|
|
||||||
|
- id: analyze_data
|
||||||
|
type: llm
|
||||||
|
name: 数据分析
|
||||||
|
config:
|
||||||
|
prompt: |
|
||||||
|
请分析以下数据:
|
||||||
|
|
||||||
|
原始文本:{{ node.extract_data.raw_text }}
|
||||||
|
文本长度:{{ node.extract_data.length }}
|
||||||
|
|
||||||
|
请提供:
|
||||||
|
1. 主题分类
|
||||||
|
2. 情感分析
|
||||||
|
3. 关键信息提取
|
||||||
|
|
||||||
|
以 JSON 格式返回结果。
|
||||||
|
model: gpt-3.5-turbo
|
||||||
|
temperature: 0.3
|
||||||
|
max_tokens: 500
|
||||||
|
position:
|
||||||
|
x: 500
|
||||||
|
y: 200
|
||||||
|
|
||||||
|
- id: transform_result
|
||||||
|
type: transform
|
||||||
|
name: 结果转换
|
||||||
|
config:
|
||||||
|
expression: |
|
||||||
|
{
|
||||||
|
"original_length": node['extract_data']['length'],
|
||||||
|
"analysis": node['analyze_data']['output'],
|
||||||
|
"processed_at": sys['execution_id'],
|
||||||
|
"status": "completed"
|
||||||
|
}
|
||||||
|
position:
|
||||||
|
x: 700
|
||||||
|
y: 200
|
||||||
|
|
||||||
|
- id: end
|
||||||
|
type: end
|
||||||
|
name: 结束
|
||||||
|
position:
|
||||||
|
x: 900
|
||||||
|
y: 200
|
||||||
|
|
||||||
|
edges:
|
||||||
|
- source: start
|
||||||
|
target: extract_data
|
||||||
|
label: 开始提取
|
||||||
|
|
||||||
|
- source: extract_data
|
||||||
|
target: analyze_data
|
||||||
|
label: 开始分析
|
||||||
|
|
||||||
|
- source: analyze_data
|
||||||
|
target: transform_result
|
||||||
|
label: 转换结果
|
||||||
|
|
||||||
|
- source: transform_result
|
||||||
|
target: end
|
||||||
|
label: 完成
|
||||||
|
|
||||||
|
# 变量定义
|
||||||
|
variables:
|
||||||
|
- name: input_text
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
description: 待处理的文本数据
|
||||||
|
default: ""
|
||||||
|
|
||||||
|
# 执行配置
|
||||||
|
execution_config:
|
||||||
|
max_execution_time: 180
|
||||||
|
max_iterations: 5
|
||||||
|
|
||||||
|
# 触发器
|
||||||
|
triggers: []
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
|
examples:
|
||||||
|
- name: 文本分析
|
||||||
|
description: 分析一段文本
|
||||||
|
input:
|
||||||
|
input_text: "今天天气真好,心情也很愉快。我们公司推出了新产品,市场反响热烈。"
|
||||||
|
expected_output:
|
||||||
|
original_length: 35
|
||||||
|
analysis: "主题:天气和产品,情感:积极"
|
||||||
|
status: "completed"
|
||||||
|
|
||||||
|
- name: 长文本处理
|
||||||
|
description: 处理较长的文本
|
||||||
|
input:
|
||||||
|
input_text: "这是一段很长的文本..."
|
||||||
|
expected_output:
|
||||||
|
status: "completed"
|
||||||
99
api/app/templates/workflows/multi_step_qa/template.yml
Normal file
99
api/app/templates/workflows/multi_step_qa/template.yml
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
# 多步骤问答工作流
|
||||||
|
# 演示节点输出参数的使用
|
||||||
|
|
||||||
|
id: multi_step_qa_v1
|
||||||
|
name: 多步骤问答工作流
|
||||||
|
description: 先分析问题,再生成答案,展示节点间的数据传递
|
||||||
|
category: advanced
|
||||||
|
version: "1.0.0"
|
||||||
|
author: RedBear Memory Team
|
||||||
|
tags:
|
||||||
|
- 问答
|
||||||
|
- 多步骤
|
||||||
|
- LLM
|
||||||
|
|
||||||
|
# 工作流配置
|
||||||
|
nodes:
|
||||||
|
- id: start
|
||||||
|
type: start
|
||||||
|
name: 开始
|
||||||
|
position:
|
||||||
|
x: 100
|
||||||
|
y: 100
|
||||||
|
|
||||||
|
- id: analyze_question
|
||||||
|
type: llm
|
||||||
|
name: 分析问题
|
||||||
|
description: 分析用户问题的类型和意图
|
||||||
|
config:
|
||||||
|
model_id: gpt-3.5-turbo
|
||||||
|
temperature: 0.3
|
||||||
|
max_tokens: 500
|
||||||
|
messages:
|
||||||
|
- role: system
|
||||||
|
content: |
|
||||||
|
你是一个问题分析专家。请分析用户的问题,提取以下信息:
|
||||||
|
1. 问题类型(事实性、观点性、操作性等)
|
||||||
|
2. 问题领域(科技、历史、文化等)
|
||||||
|
3. 关键词
|
||||||
|
- role: user
|
||||||
|
content: "{{ sys.message }}"
|
||||||
|
position:
|
||||||
|
x: 300
|
||||||
|
y: 100
|
||||||
|
|
||||||
|
- id: generate_answer
|
||||||
|
type: llm
|
||||||
|
name: 生成答案
|
||||||
|
description: 根据问题分析结果生成详细答案
|
||||||
|
config:
|
||||||
|
model_id: gpt-3.5-turbo
|
||||||
|
temperature: 0.7
|
||||||
|
max_tokens: 1000
|
||||||
|
messages:
|
||||||
|
- role: system
|
||||||
|
content: |
|
||||||
|
你是一个专业的AI助手。根据问题分析结果,生成准确、详细的答案。
|
||||||
|
|
||||||
|
问题分析结果:
|
||||||
|
{{ analyze_question.output }}
|
||||||
|
- role: user
|
||||||
|
content: "{{ sys.message }}"
|
||||||
|
position:
|
||||||
|
x: 500
|
||||||
|
y: 100
|
||||||
|
|
||||||
|
- id: end
|
||||||
|
type: end
|
||||||
|
name: 结束
|
||||||
|
config:
|
||||||
|
output: "{{ generate_answer.output }}"
|
||||||
|
position:
|
||||||
|
x: 700
|
||||||
|
y: 100
|
||||||
|
|
||||||
|
edges:
|
||||||
|
- source: start
|
||||||
|
target: analyze_question
|
||||||
|
label: 开始分析
|
||||||
|
|
||||||
|
- source: analyze_question
|
||||||
|
target: generate_answer
|
||||||
|
label: 生成答案
|
||||||
|
|
||||||
|
- source: generate_answer
|
||||||
|
target: end
|
||||||
|
label: 完成
|
||||||
|
|
||||||
|
# 变量定义
|
||||||
|
variables:
|
||||||
|
- name: user_question
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
description: 用户的问题
|
||||||
|
default: ""
|
||||||
|
|
||||||
|
# 执行配置
|
||||||
|
execution_config:
|
||||||
|
max_execution_time: 120
|
||||||
|
max_iterations: 1
|
||||||
100
api/app/templates/workflows/simple_qa/template.yml
Normal file
100
api/app/templates/workflows/simple_qa/template.yml
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
# 简单问答工作流模板
|
||||||
|
id: simple_qa_v1
|
||||||
|
name: 简单问答工作流
|
||||||
|
description: 最基础的问答工作流,适合快速开始
|
||||||
|
category: basic
|
||||||
|
version: "1.0.0"
|
||||||
|
author: RedBear Memory Team
|
||||||
|
tags:
|
||||||
|
- 问答
|
||||||
|
- 基础
|
||||||
|
- LLM
|
||||||
|
|
||||||
|
# 工作流配置
|
||||||
|
nodes:
|
||||||
|
- id: start
|
||||||
|
type: start
|
||||||
|
name: 开始
|
||||||
|
position:
|
||||||
|
x: 100
|
||||||
|
y: 100
|
||||||
|
|
||||||
|
- id: llm_qa
|
||||||
|
type: llm
|
||||||
|
name: LLM 问答
|
||||||
|
config:
|
||||||
|
# 使用 LangChain 标准的消息格式
|
||||||
|
messages:
|
||||||
|
- role: system
|
||||||
|
content: |
|
||||||
|
你是一个专业、友好且乐于助人的 AI 助手。
|
||||||
|
|
||||||
|
你的职责:
|
||||||
|
- 准确理解用户的问题并提供有价值的回答
|
||||||
|
- 保持回答的专业性和准确性
|
||||||
|
- 如果不确定答案,诚实地告知用户
|
||||||
|
- 使用清晰、易懂的语言进行交流
|
||||||
|
|
||||||
|
回答风格:
|
||||||
|
- 简洁明了,直击要点
|
||||||
|
- 必要时提供详细解释和示例
|
||||||
|
- 使用友好、礼貌的语气
|
||||||
|
- 适当使用格式化(如列表、段落)提高可读性
|
||||||
|
|
||||||
|
- role: user
|
||||||
|
content: "{{ sys.message }}"
|
||||||
|
|
||||||
|
model_id: gpt-3.5-turbo
|
||||||
|
temperature: 0.7
|
||||||
|
max_tokens: 1000
|
||||||
|
position:
|
||||||
|
x: 300
|
||||||
|
y: 100
|
||||||
|
|
||||||
|
- id: end
|
||||||
|
type: end
|
||||||
|
name: 结束
|
||||||
|
config:
|
||||||
|
output: "{{ llm_qa.output }}"
|
||||||
|
position:
|
||||||
|
x: 500
|
||||||
|
y: 100
|
||||||
|
|
||||||
|
edges:
|
||||||
|
- source: start
|
||||||
|
target: llm_qa
|
||||||
|
label: 开始处理
|
||||||
|
|
||||||
|
- source: llm_qa
|
||||||
|
target: end
|
||||||
|
label: 完成
|
||||||
|
|
||||||
|
# 变量定义
|
||||||
|
variables:
|
||||||
|
- name: user_question
|
||||||
|
type: string
|
||||||
|
required: true
|
||||||
|
description: 用户的问题
|
||||||
|
default: ""
|
||||||
|
|
||||||
|
# 执行配置
|
||||||
|
execution_config:
|
||||||
|
max_execution_time: 60
|
||||||
|
max_iterations: 1
|
||||||
|
|
||||||
|
# 触发器(可选)
|
||||||
|
triggers: []
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
|
examples:
|
||||||
|
- name: 基础问答
|
||||||
|
description: 询问一个简单的问题
|
||||||
|
input:
|
||||||
|
user_question: "什么是人工智能?"
|
||||||
|
expected_output: "关于人工智能的解释"
|
||||||
|
|
||||||
|
- name: 技术咨询
|
||||||
|
description: 询问技术问题
|
||||||
|
input:
|
||||||
|
user_question: "如何学习 Python 编程?"
|
||||||
|
expected_output: "Python 学习建议"
|
||||||
Reference in New Issue
Block a user