Merge remote-tracking branch 'origin/develop' into refactor/memory-config-management
This commit is contained in:
@@ -23,11 +23,17 @@ from . import (
|
||||
memory_dashboard_controller,
|
||||
memory_storage_controller,
|
||||
memory_dashboard_controller,
|
||||
memory_reflection_controller,
|
||||
api_key_controller,
|
||||
release_share_controller,
|
||||
public_share_controller,
|
||||
multi_agent_controller,
|
||||
workflow_controller,
|
||||
emotion_controller,
|
||||
emotion_config_controller,
|
||||
prompt_optimizer_controller,
|
||||
tool_controller,
|
||||
tool_execution_controller,
|
||||
)
|
||||
|
||||
# 创建管理端 API 路由器
|
||||
@@ -58,5 +64,11 @@ manager_router.include_router(public_share_controller.router) # 公开路由(
|
||||
manager_router.include_router(memory_dashboard_controller.router)
|
||||
manager_router.include_router(multi_agent_controller.router)
|
||||
manager_router.include_router(workflow_controller.router)
|
||||
manager_router.include_router(emotion_controller.router)
|
||||
manager_router.include_router(emotion_config_controller.router)
|
||||
manager_router.include_router(prompt_optimizer_controller.router)
|
||||
manager_router.include_router(memory_reflection_controller.router)
|
||||
manager_router.include_router(tool_controller.router)
|
||||
manager_router.include_router(tool_execution_controller.router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
@@ -8,6 +8,7 @@ from sqlalchemy.orm import Session
|
||||
from app.core.error_codes import BizCode
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
from app.models import ApiKeyType
|
||||
from app.models.user_model import User
|
||||
from app.core.response_utils import success
|
||||
from app.schemas import api_key_schema
|
||||
@@ -39,6 +40,8 @@ def create_api_key(
|
||||
"""
|
||||
try:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if data.type == ApiKeyType.SERVICE.value and not data.resource_id:
|
||||
data.resource_id = workspace_id
|
||||
|
||||
# 创建 API Key
|
||||
api_key_obj = ApiKeyService.create_api_key(
|
||||
|
||||
@@ -421,8 +421,8 @@ async def draft_run(
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
|
||||
|
||||
|
||||
|
||||
async for event in draft_service.run_stream(
|
||||
agent_config=agent_cfg,
|
||||
model_config=model_config,
|
||||
@@ -574,7 +574,7 @@ async def draft_run(
|
||||
# 3. 流式返回
|
||||
if payload.stream:
|
||||
logger.debug(
|
||||
"开始多智能体流式试运行",
|
||||
"开始工作流流式试运行",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"message_length": len(payload.message),
|
||||
@@ -583,18 +583,27 @@ async def draft_run(
|
||||
)
|
||||
|
||||
async def event_generator():
|
||||
"""多智能体流式事件生成器"""
|
||||
multiservice = MultiAgentService(db)
|
||||
|
||||
# 调用多智能体服务的流式方法
|
||||
async for event in multiservice.run_stream(
|
||||
"""工作流事件生成器
|
||||
|
||||
将事件转换为标准 SSE 格式:
|
||||
event: <event_type>
|
||||
data: <json_data>
|
||||
"""
|
||||
import json
|
||||
|
||||
# 调用工作流服务的流式方法
|
||||
async for event in workflow_service.run_stream(
|
||||
app_id=app_id,
|
||||
request=multi_agent_request,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
|
||||
payload=payload,
|
||||
config=config
|
||||
):
|
||||
yield event
|
||||
# 提取事件类型和数据
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
# 转换为标准 SSE 格式(字符串)
|
||||
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
|
||||
yield sse_message
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
@@ -617,7 +626,7 @@ async def draft_run(
|
||||
)
|
||||
|
||||
result = await workflow_service.run(app_id, payload,config)
|
||||
|
||||
|
||||
logger.debug(
|
||||
"工作流试运行返回结果",
|
||||
extra={
|
||||
|
||||
207
api/app/controllers/emotion_config_controller.py
Normal file
207
api/app/controllers/emotion_config_controller.py
Normal file
@@ -0,0 +1,207 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""情绪配置控制器模块
|
||||
|
||||
本模块提供情绪引擎配置管理的API端点,包括获取和更新配置。
|
||||
|
||||
Routes:
|
||||
GET /memory/config/emotion - 获取情绪引擎配置
|
||||
POST /memory/config/emotion - 更新情绪引擎配置
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.response_utils import success
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.emotion_config_service import EmotionConfigService
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.db import get_db
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/emotion",
|
||||
tags=["Emotion Config"],
|
||||
dependencies=[Depends(get_current_user)] # 所有路由都需要认证
|
||||
)
|
||||
|
||||
class EmotionConfigQuery(BaseModel):
|
||||
"""情绪配置查询请求模型"""
|
||||
config_id: int = Field(..., description="配置ID")
|
||||
|
||||
class EmotionConfigUpdate(BaseModel):
|
||||
"""情绪配置更新请求模型"""
|
||||
config_id: int = Field(..., description="配置ID")
|
||||
emotion_enabled: bool = Field(..., description="是否启用情绪提取")
|
||||
emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID")
|
||||
emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词")
|
||||
emotion_min_intensity: float = Field(..., ge=0.0, le=1.0, description="最小情绪强度阈值(0.0-1.0)")
|
||||
emotion_enable_subject: bool = Field(..., description="是否启用主体分类")
|
||||
|
||||
@router.get("/read_config", response_model=ApiResponse)
|
||||
def get_emotion_config(
|
||||
config_id: int = Query(..., description="配置ID"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取情绪引擎配置
|
||||
|
||||
查询指定配置ID的情绪相关配置字段。
|
||||
|
||||
Args:
|
||||
config_id: 配置ID
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含情绪配置数据
|
||||
|
||||
Example Response:
|
||||
{
|
||||
"code": 2000,
|
||||
"msg": "情绪配置获取成功",
|
||||
"data": {
|
||||
"config_id": 17,
|
||||
"emotion_enabled": true,
|
||||
"emotion_model_id": "gpt-4",
|
||||
"emotion_extract_keywords": true,
|
||||
"emotion_min_intensity": 0.1,
|
||||
"emotion_enable_subject": true
|
||||
}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪配置",
|
||||
extra={"config_id": config_id}
|
||||
)
|
||||
|
||||
# 初始化服务
|
||||
config_service = EmotionConfigService(db)
|
||||
|
||||
# 调用服务层
|
||||
data = config_service.get_emotion_config(config_id)
|
||||
|
||||
api_logger.info(
|
||||
"情绪配置获取成功",
|
||||
extra={
|
||||
"config_id": config_id,
|
||||
"emotion_enabled": data.get("emotion_enabled", False)
|
||||
}
|
||||
)
|
||||
|
||||
return success(data=data, msg="情绪配置获取成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(
|
||||
f"获取情绪配置失败: {str(e)}",
|
||||
extra={"config_id": config_id}
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取情绪配置失败: {str(e)}",
|
||||
extra={"config_id": config_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取情绪配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@router.post("/updated_config", response_model=ApiResponse)
|
||||
def update_emotion_config(
|
||||
config: EmotionConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""更新情绪引擎配置
|
||||
|
||||
更新指定配置ID的情绪相关配置字段。
|
||||
|
||||
Args:
|
||||
config: 配置更新数据(包含config_id)
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含更新后的情绪配置数据
|
||||
|
||||
Example Request:
|
||||
{
|
||||
"config_id": 2,
|
||||
"emotion_enabled": true,
|
||||
"emotion_model_id": "gpt-4",
|
||||
"emotion_extract_keywords": true,
|
||||
"emotion_min_intensity": 0.1,
|
||||
"emotion_enable_subject": true
|
||||
}
|
||||
|
||||
Example Response:
|
||||
{
|
||||
"code": 2000,
|
||||
"msg": "情绪配置更新成功",
|
||||
"data": {
|
||||
"config_id": 17,
|
||||
"emotion_enabled": true,
|
||||
"emotion_model_id": "gpt-4",
|
||||
"emotion_extract_keywords": true,
|
||||
"emotion_min_intensity": 0.2,
|
||||
"emotion_enable_subject": true
|
||||
}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求更新情绪配置",
|
||||
extra={
|
||||
"config_id": config.config_id,
|
||||
"emotion_enabled": config.emotion_enabled,
|
||||
"emotion_min_intensity": config.emotion_min_intensity
|
||||
}
|
||||
)
|
||||
|
||||
# 初始化服务
|
||||
config_service = EmotionConfigService(db)
|
||||
|
||||
# 转换为字典(排除config_id,因为它作为参数传递)
|
||||
config_data = config.model_dump(exclude={'config_id'})
|
||||
|
||||
# 调用服务层
|
||||
data = config_service.update_emotion_config(config.config_id, config_data)
|
||||
|
||||
api_logger.info(
|
||||
"情绪配置更新成功",
|
||||
extra={
|
||||
"config_id": config.config_id,
|
||||
"emotion_enabled": data.get("emotion_enabled", False)
|
||||
}
|
||||
)
|
||||
|
||||
return success(data=data, msg="情绪配置更新成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(
|
||||
f"更新情绪配置失败: {str(e)}",
|
||||
extra={"config_id": config.config_id}
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"更新情绪配置失败: {str(e)}",
|
||||
extra={"config_id": config.config_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"更新情绪配置失败: {str(e)}"
|
||||
)
|
||||
255
api/app/controllers/emotion_controller.py
Normal file
255
api/app/controllers/emotion_controller.py
Normal file
@@ -0,0 +1,255 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""情绪分析控制器模块
|
||||
|
||||
本模块提供情绪分析相关的API端点,包括情绪标签、词云、健康指数和个性化建议。
|
||||
|
||||
Routes:
|
||||
POST /emotion/tags - 获取情绪标签统计
|
||||
POST /emotion/wordcloud - 获取情绪词云数据
|
||||
POST /emotion/health - 获取情绪健康指数
|
||||
POST /emotion/suggestions - 获取个性化情绪建议
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.emotion_schema import (
|
||||
EmotionTagsRequest,
|
||||
EmotionWordcloudRequest,
|
||||
EmotionHealthRequest,
|
||||
EmotionSuggestionsRequest
|
||||
)
|
||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/emotion",
|
||||
tags=["Emotion Analysis"],
|
||||
dependencies=[Depends(get_current_user)] # 所有路由都需要认证
|
||||
)
|
||||
|
||||
|
||||
# 初始化情绪分析服务uv
|
||||
emotion_service = EmotionAnalyticsService()
|
||||
|
||||
|
||||
|
||||
@router.post("/tags", response_model=ApiResponse)
|
||||
async def get_emotion_tags(
|
||||
request: EmotionTagsRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
|
||||
try:
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪标签统计",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"emotion_type": request.emotion_type,
|
||||
"start_date": request.start_date,
|
||||
"end_date": request.end_date,
|
||||
"limit": request.limit
|
||||
}
|
||||
)
|
||||
|
||||
# 调用服务层
|
||||
data = await emotion_service.get_emotion_tags(
|
||||
end_user_id=request.group_id,
|
||||
emotion_type=request.emotion_type,
|
||||
start_date=request.start_date,
|
||||
end_date=request.end_date,
|
||||
limit=request.limit
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
"情绪标签统计获取成功",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"total_count": data.get("total_count", 0),
|
||||
"tags_count": len(data.get("tags", []))
|
||||
}
|
||||
)
|
||||
|
||||
return success(data=data, msg="情绪标签获取成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取情绪标签统计失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取情绪标签统计失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@router.post("/wordcloud", response_model=ApiResponse)
|
||||
async def get_emotion_wordcloud(
|
||||
request: EmotionWordcloudRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
|
||||
try:
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪词云数据",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"emotion_type": request.emotion_type,
|
||||
"limit": request.limit
|
||||
}
|
||||
)
|
||||
|
||||
# 调用服务层
|
||||
data = await emotion_service.get_emotion_wordcloud(
|
||||
end_user_id=request.group_id,
|
||||
emotion_type=request.emotion_type,
|
||||
limit=request.limit
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
"情绪词云数据获取成功",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"total_keywords": data.get("total_keywords", 0)
|
||||
}
|
||||
)
|
||||
|
||||
return success(data=data, msg="情绪词云获取成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取情绪词云数据失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取情绪词云数据失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@router.post("/health", response_model=ApiResponse)
|
||||
async def get_emotion_health(
|
||||
request: EmotionHealthRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
|
||||
try:
|
||||
# 验证时间范围参数
|
||||
if request.time_range not in ["7d", "30d", "90d"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="时间范围参数无效,必须是 7d、30d 或 90d"
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪健康指数",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"time_range": request.time_range
|
||||
}
|
||||
)
|
||||
|
||||
# 调用服务层
|
||||
data = await emotion_service.calculate_emotion_health_index(
|
||||
end_user_id=request.group_id,
|
||||
time_range=request.time_range
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
"情绪健康指数获取成功",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"health_score": data.get("health_score", 0),
|
||||
"level": data.get("level", "未知")
|
||||
}
|
||||
)
|
||||
|
||||
return success(data=data, msg="情绪健康指数获取成功")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取情绪健康指数失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取情绪健康指数失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@router.post("/suggestions", response_model=ApiResponse)
|
||||
async def get_emotion_suggestions(
|
||||
request: EmotionSuggestionsRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取个性化情绪建议
|
||||
|
||||
Args:
|
||||
request: 包含 group_id 和可选的 config_id
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
个性化情绪建议响应
|
||||
"""
|
||||
try:
|
||||
# 验证 config_id(如果提供)
|
||||
config_id = request.config_id
|
||||
if config_id is not None:
|
||||
from app.controllers.memory_agent_controller import validate_config_id
|
||||
try:
|
||||
config_id = validate_config_id(config_id, db)
|
||||
except ValueError as e:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取个性化情绪建议",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"config_id": config_id
|
||||
}
|
||||
)
|
||||
|
||||
# 调用服务层
|
||||
data = await emotion_service.generate_emotion_suggestions(
|
||||
end_user_id=request.group_id,
|
||||
config_id=config_id
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
"个性化建议获取成功",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"suggestions_count": len(data.get("suggestions", []))
|
||||
}
|
||||
)
|
||||
|
||||
return success(data=data, msg="个性化建议获取成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取个性化建议失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取个性化建议失败: {str(e)}"
|
||||
)
|
||||
269
api/app/controllers/memory_reflection_controller.py
Normal file
269
api/app/controllers/memory_reflection_controller.py
Normal file
@@ -0,0 +1,269 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.core.memory.storage_services.reflection_engine.self_reflexion import ReflectionConfig, ReflectionEngine
|
||||
from app.dependencies import get_current_user
|
||||
from app.db import get_db
|
||||
from app.models.user_model import User
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_reflection_service import WorkspaceAppService, MemoryReflectionService
|
||||
from app.schemas.memory_reflection_schemas import Memory_Reflection
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
load_dotenv()
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory",
|
||||
tags=["Memory"],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/reflection/save")
|
||||
async def save_reflection_config(
|
||||
request: Memory_Reflection,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Save reflection configuration to data_comfig table"""
|
||||
|
||||
|
||||
|
||||
try:
|
||||
config_id = request.config_id
|
||||
if not config_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="缺少必需参数: config_id"
|
||||
)
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
||||
|
||||
update_params = {
|
||||
"enable_self_reflexion": request.reflection_enabled,
|
||||
"iteration_period": request.reflection_period_in_hours,
|
||||
"reflexion_range": request.reflexion_range,
|
||||
"baseline": request.baseline,
|
||||
"reflection_model_id": request.reflection_model_id,
|
||||
"memory_verify": request.memory_verify,
|
||||
"quality_assessment": request.quality_assessment,
|
||||
}
|
||||
|
||||
|
||||
|
||||
query, params = DataConfigRepository.build_update_reflection(config_id, **update_params)
|
||||
|
||||
result = db.execute(text(query), params)
|
||||
if result.rowcount == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"未找到config_id为 {config_id} 的配置"
|
||||
)
|
||||
|
||||
db.commit()
|
||||
|
||||
# 查询更新后的配置
|
||||
select_query, select_params = DataConfigRepository.build_select_reflection(config_id)
|
||||
result = db.execute(text(select_query), select_params).fetchone()
|
||||
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"更新后未找到config_id为 {config_id} 的配置"
|
||||
)
|
||||
|
||||
api_logger.info(f"成功保存反思配置到数据库,config_id: {config_id}")
|
||||
|
||||
reflection_result={
|
||||
"config_id": result.config_id,
|
||||
"enable_self_reflexion": result.enable_self_reflexion,
|
||||
"iteration_period": result.iteration_period,
|
||||
"reflexion_range": result.reflexion_range,
|
||||
"baseline": result.baseline,
|
||||
"reflection_model_id": result.reflection_model_id,
|
||||
"memory_verify": result.memory_verify,
|
||||
"quality_assessment": result.quality_assessment,
|
||||
"user_id": result.user_id}
|
||||
|
||||
return success(data=reflection_result, msg="反思配置成功")
|
||||
|
||||
|
||||
|
||||
except ValueError as ve:
|
||||
api_logger.error(f"参数错误: {str(ve)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"参数错误: {str(ve)}"
|
||||
)
|
||||
except Exception as e:
|
||||
api_logger.error(f"反思配置保存失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"反思配置保存失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/reflection")
|
||||
async def start_workspace_reflection(
|
||||
config_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Activate the reflection function for all matching applications in the workspace"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
reflection_service = MemoryReflectionService(db)
|
||||
|
||||
try:
|
||||
api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}")
|
||||
|
||||
service = WorkspaceAppService(db)
|
||||
result = service.get_workspace_apps_detailed(workspace_id)
|
||||
|
||||
reflection_results = []
|
||||
|
||||
for data in result['apps_detailed_info']:
|
||||
if data['data_configs'] == []:
|
||||
continue
|
||||
|
||||
releases = data['releases']
|
||||
data_configs = data['data_configs']
|
||||
end_users = data['end_users']
|
||||
|
||||
for base, config, user in zip(releases, data_configs, end_users):
|
||||
if int(base['config']) == int(config['config_id']) and base['app_id'] == user['app_id']:
|
||||
# 调用反思服务
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
|
||||
|
||||
reflection_result = await reflection_service.start_reflection_from_data(
|
||||
config_data=config,
|
||||
end_user_id=user['id']
|
||||
)
|
||||
|
||||
reflection_results.append({
|
||||
"app_id": base['app_id'],
|
||||
"config_id": config['config_id'],
|
||||
"end_user_id": user['id'],
|
||||
"reflection_result": reflection_result
|
||||
})
|
||||
|
||||
return success(data=reflection_results, msg="反思配置成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"启动workspace反思失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"启动workspace反思失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/reflection/configs")
|
||||
async def start_reflection_configs(
|
||||
config_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""通过config_id查询data_config表中的反思配置信息"""
|
||||
try:
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
|
||||
# 使用DataConfigRepository查询反思配置
|
||||
select_query, select_params = DataConfigRepository.build_select_reflection(config_id)
|
||||
result = db.execute(text(select_query), select_params).fetchone()
|
||||
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"未找到config_id为 {config_id} 的配置"
|
||||
)
|
||||
|
||||
# 构建返回数据
|
||||
reflection_config = {
|
||||
"config_id": result.config_id,
|
||||
"reflection_enabled": result.enable_self_reflexion,
|
||||
"reflection_period_in_hours": result.iteration_period,
|
||||
"reflexion_range": result.reflexion_range,
|
||||
"baseline": result.baseline,
|
||||
"reflection_model_id": result.reflection_model_id,
|
||||
"memory_verify": result.memory_verify,
|
||||
"quality_assessment": result.quality_assessment,
|
||||
"user_id": result.user_id
|
||||
}
|
||||
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
||||
return success(data=reflection_config, msg="反思配置查询成功")
|
||||
|
||||
|
||||
except HTTPException:
|
||||
# 重新抛出HTTP异常
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"查询反思配置失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"查询反思配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/reflection/run")
|
||||
async def reflection_run(
|
||||
config_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Activate the reflection function for all matching applications in the workspace"""
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
|
||||
# 使用DataConfigRepository查询反思配置
|
||||
select_query, select_params = DataConfigRepository.build_select_reflection(config_id)
|
||||
result = db.execute(text(select_query), select_params).fetchone()
|
||||
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"未找到config_id为 {config_id} 的配置"
|
||||
)
|
||||
|
||||
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
||||
|
||||
# 验证模型ID是否存在
|
||||
model_id = result.reflection_model_id
|
||||
if model_id:
|
||||
try:
|
||||
ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||
api_logger.info(f"模型ID验证成功: {model_id}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"模型ID '{model_id}' 不存在,将使用默认模型: {str(e)}")
|
||||
# 可以设置为None,让反思引擎使用默认模型
|
||||
model_id = None
|
||||
|
||||
config = ReflectionConfig(
|
||||
enabled=result.enable_self_reflexion,
|
||||
iteration_period=result.iteration_period,
|
||||
reflexion_range=result.reflexion_range,
|
||||
baseline=result.baseline,
|
||||
output_example='',
|
||||
memory_verify=result.memory_verify,
|
||||
quality_assessment=result.quality_assessment,
|
||||
violation_handling_strategy="block",
|
||||
model_id=model_id
|
||||
)
|
||||
connector = Neo4jConnector()
|
||||
engine = ReflectionEngine(
|
||||
config=config,
|
||||
neo4j_connector=connector,
|
||||
llm_client=model_id # 传入验证后的 model_id
|
||||
)
|
||||
|
||||
result=await (engine.reflection_run())
|
||||
return success(data=result, msg="反思试运行")
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
from fastapi import APIRouter, Depends, status, Query
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
|
||||
from app.core.models import RedBearLLM
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.models_model import ModelProvider, ModelType
|
||||
@@ -39,7 +35,7 @@ def get_model_providers():
|
||||
|
||||
@router.get("", response_model=ApiResponse)
|
||||
def get_model_list(
|
||||
type: Optional[List[model_schema.ModelType]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM&type=EMBEDDING)"),
|
||||
type: Optional[str] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||
@@ -54,13 +50,21 @@ def get_model_list(
|
||||
|
||||
支持多个 type 参数:
|
||||
- 单个:?type=LLM
|
||||
- 多个:?type=LLM&type=EMBEDDING
|
||||
- 多个(逗号分隔):?type=LLM,EMBEDDING
|
||||
- 多个(重复参数):?type=LLM&type=EMBEDDING
|
||||
"""
|
||||
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
# 解析 type 参数(支持逗号分隔)
|
||||
type_list = None
|
||||
if type:
|
||||
type_values = [t.strip() for t in type.split(',')]
|
||||
type_list = [model_schema.ModelType(t.lower()) for t in type_values if t]
|
||||
|
||||
api_logger.error(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQuery(
|
||||
type=type,
|
||||
type=type_list,
|
||||
provider=provider,
|
||||
is_active=is_active,
|
||||
is_public=is_public,
|
||||
|
||||
138
api/app/controllers/prompt_optimizer_controller.py
Normal file
138
api/app/controllers/prompt_optimizer_controller.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.models.prompt_optimizer_model import RoleType
|
||||
from app.schemas.prompt_optimizer_schema import PromptOptMessage, PromptOptModelSet, CreateSessionResponse, \
|
||||
OptimizePromptResponse, SessionHistoryResponse, SessionMessage
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.prompt_optimizer_service import PromptOptimizerService
|
||||
|
||||
router = APIRouter(prefix="/prompt", tags=["Prompts-Optimization"])
|
||||
logger = get_api_logger()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions",
|
||||
summary="Create a new prompt optimization session",
|
||||
response_model=ApiResponse
|
||||
)
|
||||
def create_prompt_session(
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Create a new prompt optimization session for the current user.
|
||||
|
||||
Returns:
|
||||
ApiResponse: Contains the newly generated session ID.
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
# create new session
|
||||
session = service.create_session(current_user.tenant_id, current_user.id)
|
||||
result_schema = CreateSessionResponse.model_validate(session)
|
||||
return success(data=result_schema)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}",
|
||||
summary="获取 prompt 优化历史对话",
|
||||
response_model=ApiResponse
|
||||
)
|
||||
def get_prompt_session(
|
||||
session_id: uuid.UUID = Path(..., description="Session ID"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Retrieve all messages from a specified prompt optimization session.
|
||||
|
||||
Args:
|
||||
session_id (UUID): The ID of the session to retrieve
|
||||
db (Session): Database session
|
||||
current_user: Current logged-in user
|
||||
|
||||
Returns:
|
||||
ApiResponse: Contains the session ID and the list of messages.
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
|
||||
history = service.get_session_message_history(
|
||||
session_id=session_id,
|
||||
user_id=current_user.id
|
||||
)
|
||||
|
||||
messages = [
|
||||
SessionMessage(role=role, content=content)
|
||||
for role, content in history
|
||||
]
|
||||
|
||||
result = SessionHistoryResponse(
|
||||
session_id=session_id,
|
||||
messages=messages
|
||||
)
|
||||
|
||||
return success(data=result)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{session_id}/messages",
|
||||
summary="Get prompt optimization",
|
||||
response_model=ApiResponse
|
||||
)
|
||||
async def get_prompt_opt(
|
||||
session_id: uuid.UUID = Path(..., description="Session ID"),
|
||||
data: PromptOptMessage = ...,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Send a user message in the specified session and return the optimized prompt
|
||||
along with its description and variables.
|
||||
|
||||
Args:
|
||||
session_id (UUID): The session ID
|
||||
data (PromptOptMessage): Contains the user message, model ID, and current prompt
|
||||
db (Session): Database session
|
||||
current_user: Current user information
|
||||
|
||||
Returns:
|
||||
ApiResponse: Contains the optimized prompt, description, and a list of variables.
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
service.create_message(
|
||||
tenant_id=current_user.tenant_id,
|
||||
session_id=session_id,
|
||||
user_id=current_user.id,
|
||||
role=RoleType.USER,
|
||||
content=data.message
|
||||
)
|
||||
opt_result = await service.optimize_prompt(
|
||||
tenant_id=current_user.tenant_id,
|
||||
model_id=data.model_id,
|
||||
session_id=session_id,
|
||||
user_id=current_user.id,
|
||||
current_prompt=data.current_prompt,
|
||||
user_require=data.message
|
||||
)
|
||||
service.create_message(
|
||||
tenant_id=current_user.tenant_id,
|
||||
session_id=session_id,
|
||||
user_id=current_user.id,
|
||||
role=RoleType.ASSISTANT,
|
||||
content=opt_result.desc
|
||||
)
|
||||
variables = service.parser_prompt_variables(opt_result.prompt)
|
||||
result = {
|
||||
"prompt": opt_result.prompt,
|
||||
"desc": opt_result.desc,
|
||||
"variables": variables
|
||||
}
|
||||
result_schema = OptimizePromptResponse.model_validate(result)
|
||||
return success(data=result_schema)
|
||||
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
"""Memory 服务接口 - 基于 API Key 认证"""
|
||||
from fastapi import APIRouter, Depends
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, Request, Body
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.response_utils import success
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
|
||||
router = APIRouter(prefix="/memory", tags=["V1 - Memory API"])
|
||||
logger = get_business_logger()
|
||||
@@ -14,3 +18,31 @@ logger = get_business_logger()
|
||||
async def get_memory_info():
|
||||
"""获取记忆服务信息(占位)"""
|
||||
return success(data={}, msg="Memory API - Coming Soon")
|
||||
|
||||
|
||||
# /v1/memory/{resource_id}/chat
|
||||
@router.post("/{resource_id}/chat")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def chat_with_agent_demo(
|
||||
resource_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(..., description="聊天消息内容"),
|
||||
):
|
||||
"""
|
||||
Agent 聊天接口demo
|
||||
|
||||
scopes: 所需的权限范围列表["app", "rag", "memory"]
|
||||
|
||||
Args:
|
||||
resource_id: 如果是应用的apikey传的是应用id; 如果是服务的apikey传的是工作空间id
|
||||
message: 请求参数
|
||||
request: 声明请求
|
||||
api_key_auth: 包含验证后的API Key 信息
|
||||
db: db_session
|
||||
"""
|
||||
logger.info(f"API Key Auth: {api_key_auth}")
|
||||
logger.info(f"Resource ID: {resource_id}")
|
||||
logger.info(f"Message: {message}")
|
||||
return success(data={"received": True}, msg="消息已接收")
|
||||
585
api/app/controllers/tool_controller.py
Normal file
585
api/app/controllers/tool_controller.py
Normal file
@@ -0,0 +1,585 @@
|
||||
"""工具管理API控制器"""
|
||||
import base64
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Body
|
||||
from langfuse.api.core import jsonable_encoder
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, Field, PositiveInt, field_validator
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.models.tool_model import ToolConfig, BuiltinToolConfig, ToolType, ToolStatus, CustomToolConfig, MCPToolConfig
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.config import settings
|
||||
from app.core.tools.config_manager import ConfigManager
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
router = APIRouter(prefix="/tools", tags=["工具管理"])
|
||||
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
|
||||
|
||||
def _encrypt_sensitive_params(parameters: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""加密敏感参数"""
|
||||
cipher_key = base64.urlsafe_b64encode(settings.SECRET_KEY[:32].ljust(32, '0').encode())
|
||||
cipher = Fernet(cipher_key)
|
||||
|
||||
encrypted_params = {}
|
||||
sensitive_keys = ['api_key', 'token', 'api_secret', 'password']
|
||||
|
||||
for key, value in parameters.items():
|
||||
if any(sensitive in key.lower() for sensitive in sensitive_keys) and value:
|
||||
encrypted_params[key] = cipher.encrypt(str(value).encode()).decode()
|
||||
else:
|
||||
encrypted_params[key] = value
|
||||
|
||||
return encrypted_params
|
||||
|
||||
|
||||
def _decrypt_sensitive_params(parameters: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""解密敏感参数"""
|
||||
cipher_key = base64.urlsafe_b64encode(settings.SECRET_KEY[:32].ljust(32, '0').encode())
|
||||
cipher = Fernet(cipher_key)
|
||||
|
||||
decrypted_params = {}
|
||||
sensitive_keys = ['api_key', 'token', 'secret', 'password']
|
||||
|
||||
for key, value in parameters.items():
|
||||
if any(sensitive in key.lower() for sensitive in sensitive_keys) and value:
|
||||
try:
|
||||
decrypted_params[key] = cipher.decrypt(value.encode()).decode()
|
||||
except Exception as e:
|
||||
decrypted_params[key] = value
|
||||
else:
|
||||
decrypted_params[key] = value
|
||||
|
||||
return decrypted_params
|
||||
|
||||
|
||||
def _update_tool_status(tool_config: ToolConfig, builtin_config: BuiltinToolConfig = None, tool_info: Dict = None) -> str:
|
||||
"""更新工具状态并返回新状态"""
|
||||
if tool_config.tool_type == ToolType.BUILTIN:
|
||||
if not tool_info or not tool_info.get('requires_config', False):
|
||||
new_status = ToolStatus.ACTIVE.value # 不需要配置的内置工具
|
||||
elif not builtin_config or not builtin_config.parameters:
|
||||
new_status = ToolStatus.INACTIVE.value
|
||||
else:
|
||||
# 检查是否有必要的API密钥
|
||||
has_key = bool(builtin_config.parameters.get('api_key') or builtin_config.parameters.get('token'))
|
||||
new_status = ToolStatus.ACTIVE.value if has_key else ToolStatus.INACTIVE.value
|
||||
else: # 自定义和MCP工具
|
||||
new_status = ToolStatus.ACTIVE.value if tool_config.config_data else ToolStatus.ERROR.value
|
||||
|
||||
# 更新数据库中的状态
|
||||
if tool_config.status != new_status:
|
||||
tool_config.status = new_status
|
||||
|
||||
return new_status
|
||||
|
||||
|
||||
# ==================== 请求/响应模型 ====================
|
||||
|
||||
class ToolListResponse(BaseModel):
|
||||
"""工具列表响应"""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
tool_type: str
|
||||
category: str
|
||||
version: str = "1.0.0"
|
||||
status: str # active inactive error loading
|
||||
requires_config: bool = False
|
||||
# is_configured: bool = False
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class BuiltinToolConfigRequest(BaseModel):
|
||||
"""内置工具配置请求"""
|
||||
parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数")
|
||||
|
||||
|
||||
class CustomToolCreateRequest(BaseModel):
|
||||
"""自定义工具创建请求体模型,包含参数校验规则"""
|
||||
name: str = Field(..., min_length=1, max_length=100, description="工具名称,必填")
|
||||
description: str = Field(None, description="工具描述")
|
||||
base_url: str = Field(None, description="工具基础URL")
|
||||
schema_url: str = Field(None, description="工具Schema URL")
|
||||
schema_content: Optional[Dict[str, Any]] = Field(None, description="工具Schema内容,可选")
|
||||
auth_type: str = Field("none", pattern=r"^(none|api_key|bearer_token)$", description="认证类型")
|
||||
auth_config: Optional[Dict[str, Any]] = Field(None, description="认证配置,默认空字典")
|
||||
timeout: PositiveInt = Field(30, ge=1, le=300, description="超时时间,1-300秒,默认30")
|
||||
|
||||
# 自定义校验:当auth_type为api_key时,auth_config必须包含api_key字段
|
||||
@field_validator("auth_config")
|
||||
def validate_auth_config(cls, v, values):
|
||||
auth_type = values.data.get("auth_type")
|
||||
if auth_type == "api_key" and (not v or "api_key" not in v):
|
||||
raise ValueError("认证类型为api_key时,auth_config必须包含api_key字段")
|
||||
if auth_type == "bearer_token" and (not v or "bearer_token" not in v):
|
||||
raise ValueError("认证类型为bearer_token时,auth_config必须包含bearer_token字段")
|
||||
return v
|
||||
|
||||
class MCPToolCreateRequest(BaseModel):
|
||||
"""MCP工具创建请求体模型,适配MCP业务特性"""
|
||||
# 基础必填字段(带长度/格式校验)
|
||||
name: str = Field(..., min_length=1, max_length=100,description="MCP工具名称")
|
||||
description: str = Field(None, description="MCP工具描述")
|
||||
# MCP核心字段:服务端URL(强制HTTP/HTTPS格式)
|
||||
server_url: str = Field(..., description="MCP服务端URL,仅支持http/https协议")
|
||||
# 连接配置:默认空字典,可自定义校验规则(根据实际业务调整)
|
||||
connection_config: Dict[str, Any] = Field({},description="MCP连接配置(如认证信息、超时、重试等),默认空字典")
|
||||
|
||||
@field_validator("connection_config")
|
||||
def validate_connection_config(cls, v):
|
||||
# 示例1:若包含timeout,必须是1-300的整数
|
||||
if "timeout" in v:
|
||||
timeout = v["timeout"]
|
||||
if not isinstance(timeout, int) or timeout < 1 or timeout > 300:
|
||||
raise ValueError("connection_config.timeout必须是1-300的整数")
|
||||
return v
|
||||
|
||||
# @field_validator("server_url")
|
||||
# def validate_server_url_protocol(cls, v):
|
||||
# if v.scheme != "https":
|
||||
# raise ValueError("MCP服务端URL仅支持HTTPS协议(安全要求)")
|
||||
# return v
|
||||
|
||||
|
||||
# ==================== API端点 ====================
|
||||
@router.get("", response_model=List[ToolListResponse])
|
||||
async def list_tools(
|
||||
name: Optional[str] = None,
|
||||
tool_type: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取工具列表(包含内置工具、自定义工具和MCP工具)"""
|
||||
try:
|
||||
# 初始化内置工具(如果需要)
|
||||
config_manager = ConfigManager()
|
||||
config_manager.ensure_builtin_tools_initialized(
|
||||
current_user.tenant_id, db, ToolConfig, BuiltinToolConfig, ToolType, ToolStatus
|
||||
)
|
||||
|
||||
response_tools = []
|
||||
|
||||
query = db.query(ToolConfig).filter(
|
||||
ToolConfig.tenant_id == current_user.tenant_id
|
||||
)
|
||||
if tool_type:
|
||||
query = query.filter(ToolConfig.tool_type == tool_type)
|
||||
|
||||
if name:
|
||||
query = query.filter(ToolConfig.name.ilike(f"%{name}%"))
|
||||
|
||||
tools = query.all()
|
||||
builtin_tools = config_manager.load_builtin_tools_config()
|
||||
configured_tools = {tool_info["tool_class"]: tool_info for tool_key, tool_info in builtin_tools.items()}
|
||||
|
||||
for tool_config in tools:
|
||||
if tool_config.tool_type == ToolType.BUILTIN.value:
|
||||
builtin_config = db.query(BuiltinToolConfig).filter(BuiltinToolConfig.id == tool_config.id).first()
|
||||
tool_info = configured_tools.get(builtin_config.tool_class)
|
||||
status = _update_tool_status(tool_config, builtin_config, tool_info)
|
||||
else:
|
||||
status = _update_tool_status(tool_config)
|
||||
|
||||
response_tools.append(ToolListResponse(
|
||||
id=str(tool_config.id),
|
||||
name=tool_config.name,
|
||||
description=tool_config.description,
|
||||
tool_type=tool_config.tool_type,
|
||||
category=tool_info['category'] if tool_config.tool_type == ToolType.BUILTIN.value else tool_config.tool_type,
|
||||
version="1.0.0",
|
||||
status=status,
|
||||
requires_config=tool_info['requires_config'] if tool_config.tool_type == ToolType.BUILTIN.value else False,
|
||||
))
|
||||
|
||||
return response_tools
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/builtin/{tool_id}")
|
||||
async def get_builtin_tool_detail(
|
||||
tool_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取内置工具详情"""
|
||||
try:
|
||||
config_manager = ConfigManager()
|
||||
builtin_tools = config_manager.load_builtin_tools_config()
|
||||
configured_tools = {tool_info["tool_class"]: tool_info for tool_key, tool_info in builtin_tools.items()}
|
||||
tool_config = db.query(ToolConfig).filter(
|
||||
ToolConfig.tenant_id == current_user.tenant_id,
|
||||
ToolConfig.id == tool_id
|
||||
).first()
|
||||
builtin_config = db.query(BuiltinToolConfig).filter(BuiltinToolConfig.id == tool_config.id).first()
|
||||
tool_info = configured_tools.get(builtin_config.tool_class)
|
||||
|
||||
is_configured = False
|
||||
config_parameters = {}
|
||||
|
||||
if builtin_config and builtin_config.parameters:
|
||||
is_configured = bool(builtin_config.parameters.get('api_key') or builtin_config.parameters.get('token'))
|
||||
# 不返回敏感信息,只返回非敏感配置
|
||||
config_parameters = {k: v for k, v in builtin_config.parameters.items()
|
||||
if not any(sensitive in k.lower() for sensitive in ['key', 'secret', 'token', 'password'])}
|
||||
|
||||
return {
|
||||
"id": tool_config.id,
|
||||
"name": tool_config.name,
|
||||
"description": tool_config.description,
|
||||
"category": tool_info['category'],
|
||||
"status": tool_config.tool_type,
|
||||
"requires_config": tool_info['requires_config'],
|
||||
"is_configured": is_configured,
|
||||
"config_parameters": config_parameters
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具详情失败: {tool_id}, 错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/builtin/{tool_id}/configure")
|
||||
async def configure_builtin_tool(
|
||||
tool_id: str,
|
||||
request: BuiltinToolConfigRequest = Body(...),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""配置内置工具参数(租户级别)"""
|
||||
try:
|
||||
# 查询工具配置
|
||||
tool_config = db.query(ToolConfig).filter(
|
||||
ToolConfig.tenant_id == current_user.tenant_id,
|
||||
ToolConfig.id == tool_id,
|
||||
ToolConfig.tool_type == ToolType.BUILTIN
|
||||
).first()
|
||||
|
||||
if not tool_config:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
|
||||
# 获取内置工具配置
|
||||
builtin_config = db.query(BuiltinToolConfig).filter(
|
||||
BuiltinToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if not builtin_config:
|
||||
raise HTTPException(status_code=404, detail="内置工具配置不存在")
|
||||
|
||||
# 获取全局工具信息
|
||||
config_manager = ConfigManager()
|
||||
builtin_tools_config = config_manager.load_builtin_tools_config()
|
||||
tool_info = None
|
||||
for tool_key, info in builtin_tools_config.items():
|
||||
if info['tool_class'] == builtin_config.tool_class:
|
||||
tool_info = info
|
||||
break
|
||||
|
||||
if not tool_info:
|
||||
raise HTTPException(status_code=404, detail="工具信息不存在")
|
||||
|
||||
# 加密敏感参数
|
||||
encrypted_params = _encrypt_sensitive_params(request.parameters)
|
||||
|
||||
# 更新配置
|
||||
builtin_config.parameters = encrypted_params
|
||||
|
||||
# 更新状态
|
||||
_update_tool_status(tool_config, builtin_config, tool_info)
|
||||
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"工具 {tool_config.name} 配置成功"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"配置内置工具失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/builtin/{tool_id}/config")
|
||||
async def get_builtin_tool_config(
|
||||
tool_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取内置工具配置(用于使用)"""
|
||||
try:
|
||||
# 查询工具配置
|
||||
tool_config = db.query(ToolConfig).filter(
|
||||
ToolConfig.tenant_id == current_user.tenant_id,
|
||||
ToolConfig.id == tool_id,
|
||||
ToolConfig.tool_type == ToolType.BUILTIN
|
||||
).first()
|
||||
|
||||
if not tool_config:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
|
||||
# 获取内置工具配置
|
||||
builtin_config = db.query(BuiltinToolConfig).filter(
|
||||
BuiltinToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if not builtin_config:
|
||||
raise HTTPException(status_code=404, detail="内置工具配置不存在")
|
||||
|
||||
# 解密参数
|
||||
decrypted_params = _decrypt_sensitive_params(builtin_config.parameters or {})
|
||||
|
||||
return {
|
||||
"tool_id": tool_id,
|
||||
"tool_class": builtin_config.tool_class,
|
||||
"name": tool_config.name,
|
||||
"parameters": decrypted_params,
|
||||
"status": tool_config.status
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具配置失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/custom")
|
||||
async def create_custom_tool(
|
||||
request: CustomToolCreateRequest = Body(...),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""创建自定义工具"""
|
||||
try:
|
||||
config_data = jsonable_encoder(request.model_dump())
|
||||
config_data["tool_type"] = "custom"
|
||||
|
||||
config_manager = ConfigManager()
|
||||
is_valid, error_msg = config_manager.validate_config(config_data, "custom")
|
||||
if not is_valid:
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
|
||||
# 创建数据库记录
|
||||
tool_config = ToolConfig(
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
tool_type=ToolType.CUSTOM,
|
||||
tenant_id=current_user.tenant_id,
|
||||
status=ToolStatus.ACTIVE.value,
|
||||
config_data=config_data
|
||||
)
|
||||
db.add(tool_config)
|
||||
db.flush()
|
||||
|
||||
# 创建CustomToolConfig记录
|
||||
custom_config = CustomToolConfig(
|
||||
id=tool_config.id,
|
||||
base_url=request.base_url,
|
||||
schema_url=request.schema_url,
|
||||
schema_content=request.schema_content,
|
||||
auth_type=request.auth_type,
|
||||
auth_config=request.auth_config,
|
||||
timeout=request.timeout
|
||||
)
|
||||
db.add(custom_config)
|
||||
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"自定义工具 {request.name} 创建成功",
|
||||
"tool_id": str(tool_config.id)
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"创建自定义工具失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/mcp")
|
||||
async def create_mcp_tool(
|
||||
request: MCPToolCreateRequest = Body(..., description="MCP工具创建参数"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""创建MCP工具"""
|
||||
try:
|
||||
config_data = jsonable_encoder(request.model_dump())
|
||||
config_data["tool_type"] = "mcp"
|
||||
|
||||
config_manager = ConfigManager()
|
||||
is_valid, error_msg = config_manager.validate_config(config_data, "mcp")
|
||||
if not is_valid:
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
|
||||
# 创建数据库记录
|
||||
try:
|
||||
tool_config = ToolConfig(
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
tool_type=ToolType.MCP,
|
||||
tenant_id=current_user.tenant_id,
|
||||
status=ToolStatus.ACTIVE.value,
|
||||
config_data=config_data
|
||||
)
|
||||
db.add(tool_config)
|
||||
db.flush()
|
||||
|
||||
# 创建MCPToolConfig记录
|
||||
mcp_config = MCPToolConfig(
|
||||
id=tool_config.id,
|
||||
server_url=request.server_url,
|
||||
connection_config=request.connection_config
|
||||
)
|
||||
db.add(mcp_config)
|
||||
|
||||
db.commit()
|
||||
except SQLAlchemyError as db_e:
|
||||
db.rollback()
|
||||
logger.error(f"创建MCP工具数据库操作失败(租户ID:{current_user.tenant_id},工具名:{request.name}): {str(db_e)}",
|
||||
exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"创建MCP工具数据库操作失败(租户ID:{current_user.tenant_id},"
|
||||
f"工具名:{request.name}):{str(db_e)}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"MCP工具 {request.name} 创建成功",
|
||||
"tool_id": str(tool_config.id)
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"创建MCP工具失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.delete("/{tool_id}")
|
||||
async def delete_tool(
|
||||
tool_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""删除工具(仅限自定义和MCP工具)"""
|
||||
try:
|
||||
tool = db.query(ToolConfig).filter(
|
||||
ToolConfig.id == tool_id,
|
||||
ToolConfig.tenant_id == current_user.tenant_id
|
||||
).first()
|
||||
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
|
||||
if tool.tool_type == ToolType.BUILTIN:
|
||||
raise HTTPException(status_code=403, detail="内置工具不允许删除")
|
||||
|
||||
db.delete(tool)
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"工具 {tool.name} 删除成功"
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"删除工具失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/{tool_id}")
|
||||
async def update_tool(
|
||||
tool_id: str,
|
||||
config_data: Optional[Dict[str, Any]] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""更新工具(仅限自定义和MCP工具)"""
|
||||
try:
|
||||
tool = db.query(ToolConfig).filter(
|
||||
ToolConfig.id == tool_id,
|
||||
ToolConfig.tenant_id == current_user.tenant_id
|
||||
).first()
|
||||
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
|
||||
if tool.tool_type == ToolType.BUILTIN:
|
||||
raise HTTPException(status_code=403, detail="内置工具不允许修改")
|
||||
|
||||
if config_data is not None:
|
||||
tool.config_data = config_data
|
||||
# 更新状态
|
||||
_update_tool_status(tool)
|
||||
|
||||
db.commit()
|
||||
db.refresh(tool)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"工具 {tool.name} 更新成功",
|
||||
"status": tool.status
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新工具失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/{tool_id}/toggle")
|
||||
async def toggle_tool_status(
|
||||
tool_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""切换工具活跃/非活跃状态"""
|
||||
try:
|
||||
tool = db.query(ToolConfig).filter(
|
||||
ToolConfig.id == tool_id,
|
||||
ToolConfig.tenant_id == current_user.tenant_id
|
||||
).first()
|
||||
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
|
||||
# 在active和inactive之间切换
|
||||
if tool.status == ToolStatus.ACTIVE.value:
|
||||
tool.status = ToolStatus.INACTIVE.value
|
||||
elif tool.status == ToolStatus.INACTIVE.value:
|
||||
tool.status = ToolStatus.ACTIVE.value
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="只有可用或非活跃状态的工具可以切换")
|
||||
|
||||
db.commit()
|
||||
db.refresh(tool)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"工具 {tool.name} 状态已更新为 {tool.status}",
|
||||
"status": tool.status
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"切换工具状态失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
430
api/app/controllers/tool_execution_controller.py
Normal file
430
api/app/controllers/tool_execution_controller.py
Normal file
@@ -0,0 +1,430 @@
|
||||
"""工具执行API控制器"""
|
||||
import uuid
|
||||
from typing import Dict, Any, List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.core.tools.registry import ToolRegistry
|
||||
from app.core.tools.executor import ToolExecutor
|
||||
from app.core.tools.chain_manager import ChainManager, ChainDefinition, ChainStep, ChainExecutionMode
|
||||
from app.core.tools.builtin import *
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
router = APIRouter(prefix="/tools/execution", tags=["工具执行"])
|
||||
|
||||
|
||||
# ==================== 请求/响应模型 ====================
|
||||
|
||||
class ToolExecutionRequest(BaseModel):
|
||||
"""工具执行请求"""
|
||||
tool_id: str = Field(..., description="工具ID")
|
||||
parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数")
|
||||
timeout: Optional[float] = Field(None, ge=1, le=300, description="超时时间(秒)")
|
||||
metadata: Optional[Dict[str, Any]] = Field(None, description="额外元数据")
|
||||
|
||||
|
||||
class BatchExecutionRequest(BaseModel):
|
||||
"""批量执行请求"""
|
||||
executions: List[ToolExecutionRequest] = Field(..., description="执行列表")
|
||||
max_concurrency: int = Field(5, ge=1, le=20, description="最大并发数")
|
||||
|
||||
|
||||
class ToolExecutionResponse(BaseModel):
|
||||
"""工具执行响应"""
|
||||
success: bool
|
||||
execution_id: str
|
||||
tool_id: str
|
||||
data: Any = None
|
||||
error: Optional[str] = None
|
||||
error_code: Optional[str] = None
|
||||
execution_time: float
|
||||
token_usage: Optional[Dict[str, int]] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ChainStepRequest(BaseModel):
|
||||
"""链步骤请求"""
|
||||
tool_id: str = Field(..., description="工具ID")
|
||||
parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数")
|
||||
condition: Optional[str] = Field(None, description="执行条件")
|
||||
output_mapping: Optional[Dict[str, str]] = Field(None, description="输出映射")
|
||||
error_handling: str = Field("stop", description="错误处理策略")
|
||||
|
||||
|
||||
class ChainExecutionRequest(BaseModel):
|
||||
"""链执行请求"""
|
||||
name: str = Field(..., description="链名称")
|
||||
description: str = Field("", description="链描述")
|
||||
steps: List[ChainStepRequest] = Field(..., description="执行步骤")
|
||||
execution_mode: str = Field("sequential", description="执行模式")
|
||||
initial_variables: Optional[Dict[str, Any]] = Field(None, description="初始变量")
|
||||
global_timeout: Optional[float] = Field(None, description="全局超时")
|
||||
|
||||
|
||||
class ExecutionHistoryResponse(BaseModel):
|
||||
"""执行历史响应"""
|
||||
execution_id: str
|
||||
tool_id: str
|
||||
status: str
|
||||
started_at: Optional[str]
|
||||
completed_at: Optional[str]
|
||||
execution_time: Optional[float]
|
||||
user_id: Optional[str]
|
||||
workspace_id: Optional[str]
|
||||
input_data: Optional[Dict[str, Any]]
|
||||
output_data: Optional[Any]
|
||||
error_message: Optional[str]
|
||||
token_usage: Optional[Dict[str, int]]
|
||||
|
||||
|
||||
class ToolConnectionTestResponse(BaseModel):
|
||||
"""工具连接测试响应"""
|
||||
success: bool
|
||||
message: str
|
||||
error: Optional[str] = None
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
# ==================== 依赖注入 ====================
|
||||
|
||||
def get_tool_registry(db: Session = Depends(get_db)) -> ToolRegistry:
|
||||
"""获取工具注册表"""
|
||||
registry = ToolRegistry(db)
|
||||
|
||||
# 注册内置工具类
|
||||
registry.register_tool_class(DateTimeTool)
|
||||
registry.register_tool_class(JsonTool)
|
||||
registry.register_tool_class(BaiduSearchTool)
|
||||
registry.register_tool_class(MinerUTool)
|
||||
registry.register_tool_class(TextInTool)
|
||||
|
||||
return registry
|
||||
|
||||
|
||||
def get_tool_executor(
|
||||
db: Session = Depends(get_db),
|
||||
registry: ToolRegistry = Depends(get_tool_registry)
|
||||
) -> ToolExecutor:
|
||||
"""获取工具执行器"""
|
||||
return ToolExecutor(db, registry)
|
||||
|
||||
|
||||
def get_chain_manager(executor: ToolExecutor = Depends(get_tool_executor)) -> ChainManager:
|
||||
"""获取链管理器"""
|
||||
return ChainManager(executor)
|
||||
|
||||
|
||||
# ==================== API端点 ====================
|
||||
|
||||
@router.post("/execute", response_model=ToolExecutionResponse)
|
||||
async def execute_tool(
|
||||
request: ToolExecutionRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
executor: ToolExecutor = Depends(get_tool_executor)
|
||||
):
|
||||
"""执行单个工具"""
|
||||
try:
|
||||
# 生成执行ID
|
||||
execution_id = f"exec_{uuid.uuid4().hex[:16]}"
|
||||
|
||||
# 执行工具
|
||||
result = await executor.execute_tool(
|
||||
tool_id=request.tool_id,
|
||||
parameters=request.parameters,
|
||||
user_id=current_user.id,
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
execution_id=execution_id,
|
||||
timeout=request.timeout,
|
||||
metadata=request.metadata
|
||||
)
|
||||
|
||||
return ToolExecutionResponse(
|
||||
success=result.success,
|
||||
execution_id=execution_id,
|
||||
tool_id=request.tool_id,
|
||||
data=result.data,
|
||||
error=result.error,
|
||||
error_code=result.error_code,
|
||||
execution_time=result.execution_time,
|
||||
token_usage=result.token_usage,
|
||||
metadata=result.metadata
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具执行失败: {request.tool_id}, 错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/batch", response_model=List[ToolExecutionResponse])
|
||||
async def execute_tools_batch(
|
||||
request: BatchExecutionRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
executor: ToolExecutor = Depends(get_tool_executor)
|
||||
):
|
||||
"""批量执行工具"""
|
||||
try:
|
||||
# 准备执行配置
|
||||
execution_configs = []
|
||||
execution_ids = []
|
||||
|
||||
for exec_request in request.executions:
|
||||
execution_id = f"exec_{uuid.uuid4().hex[:16]}"
|
||||
execution_ids.append(execution_id)
|
||||
|
||||
execution_configs.append({
|
||||
"tool_id": exec_request.tool_id,
|
||||
"parameters": exec_request.parameters,
|
||||
"user_id": current_user.id,
|
||||
"workspace_id": current_user.current_workspace_id,
|
||||
"execution_id": execution_id,
|
||||
"timeout": exec_request.timeout,
|
||||
"metadata": exec_request.metadata
|
||||
})
|
||||
|
||||
# 批量执行
|
||||
results = await executor.execute_tools_batch(
|
||||
execution_configs,
|
||||
max_concurrency=request.max_concurrency
|
||||
)
|
||||
|
||||
# 转换响应格式
|
||||
responses = []
|
||||
for i, result in enumerate(results):
|
||||
responses.append(ToolExecutionResponse(
|
||||
success=result.success,
|
||||
execution_id=execution_ids[i],
|
||||
tool_id=request.executions[i].tool_id,
|
||||
data=result.data,
|
||||
error=result.error,
|
||||
error_code=result.error_code,
|
||||
execution_time=result.execution_time,
|
||||
token_usage=result.token_usage,
|
||||
metadata=result.metadata
|
||||
))
|
||||
|
||||
return responses
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量执行失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/chain", response_model=Dict[str, Any])
|
||||
async def execute_tool_chain(
|
||||
request: ChainExecutionRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
chain_manager: ChainManager = Depends(get_chain_manager)
|
||||
):
|
||||
"""执行工具链"""
|
||||
try:
|
||||
# 转换步骤格式
|
||||
steps = []
|
||||
for step_request in request.steps:
|
||||
step = ChainStep(
|
||||
tool_id=step_request.tool_id,
|
||||
parameters=step_request.parameters,
|
||||
condition=step_request.condition,
|
||||
output_mapping=step_request.output_mapping,
|
||||
error_handling=step_request.error_handling
|
||||
)
|
||||
steps.append(step)
|
||||
|
||||
# 创建链定义
|
||||
chain_definition = ChainDefinition(
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
steps=steps,
|
||||
execution_mode=ChainExecutionMode(request.execution_mode),
|
||||
global_timeout=request.global_timeout
|
||||
)
|
||||
|
||||
# 注册并执行链
|
||||
chain_manager.register_chain(chain_definition)
|
||||
|
||||
result = await chain_manager.execute_chain(
|
||||
chain_name=request.name,
|
||||
initial_variables=request.initial_variables
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具链执行失败: {request.name}, 错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/running", response_model=List[Dict[str, Any]])
|
||||
async def get_running_executions(
|
||||
current_user: User = Depends(get_current_user),
|
||||
executor: ToolExecutor = Depends(get_tool_executor)
|
||||
):
|
||||
"""获取正在运行的执行"""
|
||||
try:
|
||||
running_executions = executor.get_running_executions()
|
||||
|
||||
# 过滤当前工作空间的执行
|
||||
workspace_executions = [
|
||||
exec_info for exec_info in running_executions
|
||||
if exec_info.get("workspace_id") == str(current_user.current_workspace_id)
|
||||
]
|
||||
|
||||
return workspace_executions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取运行中执行失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/cancel/{execution_id}", response_model=Dict[str, Any])
|
||||
async def cancel_execution(
|
||||
execution_id: str = Path(..., description="执行ID"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
executor: ToolExecutor = Depends(get_tool_executor)
|
||||
):
|
||||
"""取消工具执行"""
|
||||
try:
|
||||
success = await executor.cancel_execution(execution_id)
|
||||
|
||||
if success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "执行已取消"
|
||||
}
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="执行不存在或已完成")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"取消执行失败: {execution_id}, 错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/history", response_model=List[ExecutionHistoryResponse])
|
||||
async def get_execution_history(
|
||||
tool_id: Optional[str] = Query(None, description="工具ID过滤"),
|
||||
limit: int = Query(50, ge=1, le=200, description="返回数量限制"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
executor: ToolExecutor = Depends(get_tool_executor)
|
||||
):
|
||||
"""获取执行历史"""
|
||||
try:
|
||||
history = executor.get_execution_history(
|
||||
tool_id=tool_id,
|
||||
user_id=current_user.id,
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
# 转换响应格式
|
||||
responses = []
|
||||
for record in history:
|
||||
responses.append(ExecutionHistoryResponse(
|
||||
execution_id=record["execution_id"],
|
||||
tool_id=record["tool_id"],
|
||||
status=record["status"],
|
||||
started_at=record["started_at"],
|
||||
completed_at=record["completed_at"],
|
||||
execution_time=record["execution_time"],
|
||||
user_id=record["user_id"],
|
||||
workspace_id=record["workspace_id"],
|
||||
input_data=record["input_data"],
|
||||
output_data=record["output_data"],
|
||||
error_message=record["error_message"],
|
||||
token_usage=record["token_usage"]
|
||||
))
|
||||
|
||||
return responses
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取执行历史失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/statistics", response_model=Dict[str, Any])
|
||||
async def get_execution_statistics(
|
||||
days: int = Query(7, ge=1, le=90, description="统计天数"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
executor: ToolExecutor = Depends(get_tool_executor)
|
||||
):
|
||||
"""获取执行统计"""
|
||||
try:
|
||||
stats = executor.get_execution_statistics(
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
days=days
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"statistics": stats
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取执行统计失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/chains/running", response_model=List[Dict[str, Any]])
|
||||
async def get_running_chains(
|
||||
current_user: User = Depends(get_current_user),
|
||||
chain_manager: ChainManager = Depends(get_chain_manager)
|
||||
):
|
||||
"""获取正在运行的工具链"""
|
||||
try:
|
||||
running_chains = chain_manager.get_running_chains()
|
||||
return running_chains
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取运行中工具链失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/chains", response_model=List[Dict[str, Any]])
|
||||
async def list_tool_chains(
|
||||
current_user: User = Depends(get_current_user),
|
||||
chain_manager: ChainManager = Depends(get_chain_manager)
|
||||
):
|
||||
"""列出工具链"""
|
||||
try:
|
||||
chains = chain_manager.list_chains()
|
||||
return chains
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具链列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/test-connection/{tool_id}", response_model=ToolConnectionTestResponse)
|
||||
async def test_tool_connection(
|
||||
tool_id: str = Path(..., description="工具ID"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
executor: ToolExecutor = Depends(get_tool_executor)
|
||||
):
|
||||
"""测试工具连接"""
|
||||
try:
|
||||
result = await executor.test_tool_connection(
|
||||
tool_id=tool_id,
|
||||
user_id=current_user.id,
|
||||
workspace_id=current_user.current_workspace_id
|
||||
)
|
||||
|
||||
return ToolConnectionTestResponse(
|
||||
success=result.get("success", False),
|
||||
message=result.get("message", ""),
|
||||
error=result.get("error"),
|
||||
details=result.get("details")
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具连接测试失败: {tool_id}, 错误: {e}")
|
||||
return ToolConnectionTestResponse(
|
||||
success=False,
|
||||
message="连接测试失败",
|
||||
error=str(e)
|
||||
)
|
||||
@@ -471,28 +471,52 @@ async def run_workflow(
|
||||
import json
|
||||
|
||||
async def event_generator():
|
||||
"""生成 SSE 事件"""
|
||||
"""生成 SSE 事件
|
||||
|
||||
SSE 格式:
|
||||
event: <event_type>
|
||||
data: <json_data>
|
||||
|
||||
支持的事件类型:
|
||||
- workflow_start: 工作流开始
|
||||
- workflow_end: 工作流结束
|
||||
- node_start: 节点开始执行
|
||||
- node_end: 节点执行完成
|
||||
- node_chunk: 中间节点的流式输出
|
||||
- message: 最终消息的流式输出(End 节点及其相邻节点)
|
||||
"""
|
||||
try:
|
||||
async for event in service.run_workflow(
|
||||
async for event in await service.run_workflow(
|
||||
app_id=app_id,
|
||||
input_data=input_data,
|
||||
triggered_by=current_user.id,
|
||||
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||
stream=True
|
||||
):
|
||||
# 转换为 SSE 格式
|
||||
yield f"data: {json.dumps(event)}\n\n"
|
||||
# 提取事件类型和数据
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
# 转换为标准 SSE 格式(字符串)
|
||||
# event: <type>
|
||||
# data: <json>
|
||||
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
|
||||
yield sse_message
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式执行异常: {e}", exc_info=True)
|
||||
error_event = {
|
||||
"type": "error",
|
||||
"error": str(e)
|
||||
}
|
||||
yield f"data: {json.dumps(error_event)}\n\n"
|
||||
# 发送错误事件
|
||||
sse_error = f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
|
||||
yield sse_error
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream"
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no" # 禁用 nginx 缓冲
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 非流式执行
|
||||
|
||||
Reference in New Issue
Block a user