Merge remote-tracking branch 'origin/develop' into refactor/memory-config-management
This commit is contained in:
@@ -83,17 +83,18 @@ celery_app.autodiscover_tasks(['app'])
|
||||
reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS)
|
||||
health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS)
|
||||
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||
|
||||
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||
# 构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
"run-reflection-engine": {
|
||||
"task": "app.core.memory.agent.reflection.timer",
|
||||
"schedule": reflection_schedule,
|
||||
"args": (),
|
||||
},
|
||||
"check-read-service": {
|
||||
"task": "app.core.memory.agent.health.check_read_service",
|
||||
"schedule": health_schedule,
|
||||
|
||||
# "check-read-service": {
|
||||
# "task": "app.core.memory.agent.health.check_read_service",
|
||||
# "schedule": health_schedule,
|
||||
# "args": (),
|
||||
# },
|
||||
"run-workspace-reflection": {
|
||||
"task": "app.tasks.workspace_reflection_task",
|
||||
"schedule": workspace_reflection_schedule,
|
||||
"args": (),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -23,11 +23,17 @@ from . import (
|
||||
memory_dashboard_controller,
|
||||
memory_storage_controller,
|
||||
memory_dashboard_controller,
|
||||
memory_reflection_controller,
|
||||
api_key_controller,
|
||||
release_share_controller,
|
||||
public_share_controller,
|
||||
multi_agent_controller,
|
||||
workflow_controller,
|
||||
emotion_controller,
|
||||
emotion_config_controller,
|
||||
prompt_optimizer_controller,
|
||||
tool_controller,
|
||||
tool_execution_controller,
|
||||
)
|
||||
|
||||
# 创建管理端 API 路由器
|
||||
@@ -58,5 +64,11 @@ manager_router.include_router(public_share_controller.router) # 公开路由(
|
||||
manager_router.include_router(memory_dashboard_controller.router)
|
||||
manager_router.include_router(multi_agent_controller.router)
|
||||
manager_router.include_router(workflow_controller.router)
|
||||
manager_router.include_router(emotion_controller.router)
|
||||
manager_router.include_router(emotion_config_controller.router)
|
||||
manager_router.include_router(prompt_optimizer_controller.router)
|
||||
manager_router.include_router(memory_reflection_controller.router)
|
||||
manager_router.include_router(tool_controller.router)
|
||||
manager_router.include_router(tool_execution_controller.router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
@@ -8,6 +8,7 @@ from sqlalchemy.orm import Session
|
||||
from app.core.error_codes import BizCode
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
from app.models import ApiKeyType
|
||||
from app.models.user_model import User
|
||||
from app.core.response_utils import success
|
||||
from app.schemas import api_key_schema
|
||||
@@ -39,6 +40,8 @@ def create_api_key(
|
||||
"""
|
||||
try:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if data.type == ApiKeyType.SERVICE.value and not data.resource_id:
|
||||
data.resource_id = workspace_id
|
||||
|
||||
# 创建 API Key
|
||||
api_key_obj = ApiKeyService.create_api_key(
|
||||
|
||||
@@ -421,8 +421,8 @@ async def draft_run(
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
|
||||
|
||||
|
||||
|
||||
async for event in draft_service.run_stream(
|
||||
agent_config=agent_cfg,
|
||||
model_config=model_config,
|
||||
@@ -574,7 +574,7 @@ async def draft_run(
|
||||
# 3. 流式返回
|
||||
if payload.stream:
|
||||
logger.debug(
|
||||
"开始多智能体流式试运行",
|
||||
"开始工作流流式试运行",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"message_length": len(payload.message),
|
||||
@@ -583,18 +583,27 @@ async def draft_run(
|
||||
)
|
||||
|
||||
async def event_generator():
|
||||
"""多智能体流式事件生成器"""
|
||||
multiservice = MultiAgentService(db)
|
||||
|
||||
# 调用多智能体服务的流式方法
|
||||
async for event in multiservice.run_stream(
|
||||
"""工作流事件生成器
|
||||
|
||||
将事件转换为标准 SSE 格式:
|
||||
event: <event_type>
|
||||
data: <json_data>
|
||||
"""
|
||||
import json
|
||||
|
||||
# 调用工作流服务的流式方法
|
||||
async for event in workflow_service.run_stream(
|
||||
app_id=app_id,
|
||||
request=multi_agent_request,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
|
||||
payload=payload,
|
||||
config=config
|
||||
):
|
||||
yield event
|
||||
# 提取事件类型和数据
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
# 转换为标准 SSE 格式(字符串)
|
||||
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
|
||||
yield sse_message
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
@@ -617,7 +626,7 @@ async def draft_run(
|
||||
)
|
||||
|
||||
result = await workflow_service.run(app_id, payload,config)
|
||||
|
||||
|
||||
logger.debug(
|
||||
"工作流试运行返回结果",
|
||||
extra={
|
||||
|
||||
207
api/app/controllers/emotion_config_controller.py
Normal file
207
api/app/controllers/emotion_config_controller.py
Normal file
@@ -0,0 +1,207 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""情绪配置控制器模块
|
||||
|
||||
本模块提供情绪引擎配置管理的API端点,包括获取和更新配置。
|
||||
|
||||
Routes:
|
||||
GET /memory/emotion/read_config - 获取情绪引擎配置
|
||||
POST /memory/emotion/updated_config - 更新情绪引擎配置
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.response_utils import success
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.emotion_config_service import EmotionConfigService
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.db import get_db
|
||||
|
||||
# Logger dedicated to API-layer messages.
api_logger = get_api_logger()

# Router for the emotion-configuration endpoints; the router-level dependency
# makes every route registered on it require an authenticated user.
router = APIRouter(
    prefix="/memory/emotion",
    tags=["Emotion Config"],
    dependencies=[Depends(get_current_user)],
)
|
||||
|
||||
class EmotionConfigQuery(BaseModel):
    """Request model carrying the ID of the emotion configuration to look up."""

    # Required identifier of the configuration record.
    config_id: int = Field(..., description="配置ID")
|
||||
|
||||
class EmotionConfigUpdate(BaseModel):
    """Request model for updating the emotion-related fields of a configuration.

    Every field except ``emotion_model_id`` is required.
    """

    # Identifier of the configuration record to update.
    config_id: int = Field(..., description="配置ID")
    # Whether emotion extraction is switched on.
    emotion_enabled: bool = Field(..., description="是否启用情绪提取")
    # Optional ID of a model dedicated to emotion analysis.
    emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID")
    # Whether emotion keywords should be extracted.
    emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词")
    # Minimum emotion intensity threshold, validated to lie in [0.0, 1.0].
    emotion_min_intensity: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="最小情绪强度阈值(0.0-1.0)",
    )
    # Whether subject classification is switched on.
    emotion_enable_subject: bool = Field(..., description="是否启用主体分类")
|
||||
|
||||
@router.get("/read_config", response_model=ApiResponse)
def get_emotion_config(
    config_id: int = Query(..., description="配置ID"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the emotion-engine settings stored for one configuration.

    The lookup is delegated to ``EmotionConfigService.get_emotion_config`` and
    its result is wrapped with ``success``.

    Args:
        config_id: ID of the configuration to read (query parameter).
        db: Database session injected by FastAPI.
        current_user: Authenticated user; only used for logging here.

    Returns:
        The ``success(...)`` envelope whose ``data`` is the service result,
        e.g. ``{"config_id": 17, "emotion_enabled": true,
        "emotion_model_id": "gpt-4", "emotion_extract_keywords": true,
        "emotion_min_intensity": 0.1, "emotion_enable_subject": true}``.

    Raises:
        HTTPException: 404 when the service raises ``ValueError``; 500 for any
            other exception.
    """
    try:
        api_logger.info(
            f"用户 {current_user.username} 请求获取情绪配置",
            extra={"config_id": config_id},
        )

        emotion_config = EmotionConfigService(db).get_emotion_config(config_id)

        success_context = {
            "config_id": config_id,
            "emotion_enabled": emotion_config.get("emotion_enabled", False),
        }
        api_logger.info("情绪配置获取成功", extra=success_context)

        return success(data=emotion_config, msg="情绪配置获取成功")

    except ValueError as exc:
        # The service signals its failure with ValueError; report it as 404.
        reason = str(exc)
        api_logger.warning(
            f"获取情绪配置失败: {reason}",
            extra={"config_id": config_id},
        )
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=reason,
        )
    except Exception as exc:
        failure_message = f"获取情绪配置失败: {exc!s}"
        api_logger.error(
            failure_message,
            extra={"config_id": config_id},
            exc_info=True,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=failure_message,
        )
|
||||
|
||||
|
||||
|
||||
@router.post("/updated_config", response_model=ApiResponse)
def update_emotion_config(
    config: EmotionConfigUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Update the emotion-engine settings of one configuration.

    The request body is dumped to a dict without ``config_id`` (the ID is
    passed to the service separately) and handed to
    ``EmotionConfigService.update_emotion_config``.

    Args:
        config: Update payload, including the target ``config_id``.
        db: Database session injected by FastAPI.
        current_user: Authenticated user; only used for logging here.

    Returns:
        The ``success(...)`` envelope whose ``data`` is the service result
        (per the original example, the emotion fields of the configuration).

    Raises:
        HTTPException: 400 when the service raises ``ValueError``; 500 for any
            other exception.
    """
    try:
        request_context = {
            "config_id": config.config_id,
            "emotion_enabled": config.emotion_enabled,
            "emotion_min_intensity": config.emotion_min_intensity,
        }
        api_logger.info(
            f"用户 {current_user.username} 请求更新情绪配置",
            extra=request_context,
        )

        service = EmotionConfigService(db)
        # config_id travels as a separate argument, so keep it out of the payload.
        field_values = config.model_dump(exclude={'config_id'})
        updated = service.update_emotion_config(config.config_id, field_values)

        api_logger.info(
            "情绪配置更新成功",
            extra={
                "config_id": config.config_id,
                "emotion_enabled": updated.get("emotion_enabled", False),
            },
        )

        return success(data=updated, msg="情绪配置更新成功")

    except ValueError as exc:
        reason = str(exc)
        api_logger.warning(
            f"更新情绪配置失败: {reason}",
            extra={"config_id": config.config_id},
        )
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=reason,
        )
    except Exception as exc:
        failure_message = f"更新情绪配置失败: {exc!s}"
        api_logger.error(
            failure_message,
            extra={"config_id": config.config_id},
            exc_info=True,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=failure_message,
        )
|
||||
255
api/app/controllers/emotion_controller.py
Normal file
255
api/app/controllers/emotion_controller.py
Normal file
@@ -0,0 +1,255 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""情绪分析控制器模块
|
||||
|
||||
本模块提供情绪分析相关的API端点,包括情绪标签、词云、健康指数和个性化建议。
|
||||
|
||||
Routes:
|
||||
POST /memory/emotion/tags - 获取情绪标签统计
|
||||
POST /memory/emotion/wordcloud - 获取情绪词云数据
|
||||
POST /memory/emotion/health - 获取情绪健康指数
|
||||
POST /memory/emotion/suggestions - 获取个性化情绪建议
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.emotion_schema import (
|
||||
EmotionTagsRequest,
|
||||
EmotionWordcloudRequest,
|
||||
EmotionHealthRequest,
|
||||
EmotionSuggestionsRequest
|
||||
)
|
||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# Logger dedicated to API-layer messages.
api_logger = get_api_logger()

# Router for the emotion-analysis endpoints; the router-level dependency makes
# every route registered on it require an authenticated user.
router = APIRouter(
    prefix="/memory/emotion",
    tags=["Emotion Analysis"],
    dependencies=[Depends(get_current_user)],
)

# Module-level emotion analytics service shared by all handlers in this file.
emotion_service = EmotionAnalyticsService()
|
||||
|
||||
|
||||
|
||||
@router.post("/tags", response_model=ApiResponse)
async def get_emotion_tags(
    request: EmotionTagsRequest,
    current_user: User = Depends(get_current_user),
):
    """Return emotion-tag statistics for one end user.

    ``request.group_id`` is forwarded to the service as ``end_user_id``
    together with the emotion type, date range and limit from the request.

    Args:
        request: Filter parameters for the statistics query.
        current_user: Authenticated user; only used for logging here.

    Returns:
        The ``success(...)`` envelope wrapping the service result.

    Raises:
        HTTPException: 500 when anything in the handler fails.
    """
    try:
        api_logger.info(
            f"用户 {current_user.username} 请求获取情绪标签统计",
            extra={
                "group_id": request.group_id,
                "emotion_type": request.emotion_type,
                "start_date": request.start_date,
                "end_date": request.end_date,
                "limit": request.limit,
            },
        )

        tag_stats = await emotion_service.get_emotion_tags(
            end_user_id=request.group_id,
            emotion_type=request.emotion_type,
            start_date=request.start_date,
            end_date=request.end_date,
            limit=request.limit,
        )

        summary = {
            "group_id": request.group_id,
            "total_count": tag_stats.get("total_count", 0),
            "tags_count": len(tag_stats.get("tags", [])),
        }
        api_logger.info("情绪标签统计获取成功", extra=summary)

        return success(data=tag_stats, msg="情绪标签获取成功")

    except Exception as exc:
        failure_message = f"获取情绪标签统计失败: {exc!s}"
        api_logger.error(
            failure_message,
            extra={"group_id": request.group_id},
            exc_info=True,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=failure_message,
        )
|
||||
|
||||
|
||||
|
||||
@router.post("/wordcloud", response_model=ApiResponse)
async def get_emotion_wordcloud(
    request: EmotionWordcloudRequest,
    current_user: User = Depends(get_current_user),
):
    """Return emotion word-cloud data for one end user.

    ``request.group_id`` is forwarded to the service as ``end_user_id``
    together with the emotion type and limit from the request.

    Args:
        request: Filter parameters for the word-cloud query.
        current_user: Authenticated user; only used for logging here.

    Returns:
        The ``success(...)`` envelope wrapping the service result.

    Raises:
        HTTPException: 500 when anything in the handler fails.
    """
    try:
        api_logger.info(
            f"用户 {current_user.username} 请求获取情绪词云数据",
            extra={
                "group_id": request.group_id,
                "emotion_type": request.emotion_type,
                "limit": request.limit,
            },
        )

        cloud = await emotion_service.get_emotion_wordcloud(
            end_user_id=request.group_id,
            emotion_type=request.emotion_type,
            limit=request.limit,
        )

        summary = {
            "group_id": request.group_id,
            "total_keywords": cloud.get("total_keywords", 0),
        }
        api_logger.info("情绪词云数据获取成功", extra=summary)

        return success(data=cloud, msg="情绪词云获取成功")

    except Exception as exc:
        failure_message = f"获取情绪词云数据失败: {exc!s}"
        api_logger.error(
            failure_message,
            extra={"group_id": request.group_id},
            exc_info=True,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=failure_message,
        )
|
||||
|
||||
|
||||
|
||||
@router.post("/health", response_model=ApiResponse)
async def get_emotion_health(
    request: EmotionHealthRequest,
    current_user: User = Depends(get_current_user),
):
    """Return the emotion health index for one end user.

    ``request.time_range`` must be one of ``7d``, ``30d`` or ``90d``.
    ``request.group_id`` is forwarded to the service as ``end_user_id``.

    Args:
        request: End-user identifier and time range.
        current_user: Authenticated user; only used for logging here.

    Returns:
        The ``success(...)`` envelope wrapping the service result.

    Raises:
        HTTPException: 400 for an unsupported time range; 500 when anything
            else in the handler fails.
    """
    try:
        # Reject unsupported time ranges before touching the service.
        if request.time_range not in ("7d", "30d", "90d"):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="时间范围参数无效,必须是 7d、30d 或 90d",
            )

        api_logger.info(
            f"用户 {current_user.username} 请求获取情绪健康指数",
            extra={
                "group_id": request.group_id,
                "time_range": request.time_range,
            },
        )

        health = await emotion_service.calculate_emotion_health_index(
            end_user_id=request.group_id,
            time_range=request.time_range,
        )

        summary = {
            "group_id": request.group_id,
            "health_score": health.get("health_score", 0),
            "level": health.get("level", "未知"),
        }
        api_logger.info("情绪健康指数获取成功", extra=summary)

        return success(data=health, msg="情绪健康指数获取成功")

    except HTTPException:
        # Keep the 400 raised above instead of turning it into a 500.
        raise
    except Exception as exc:
        failure_message = f"获取情绪健康指数失败: {exc!s}"
        api_logger.error(
            failure_message,
            extra={"group_id": request.group_id},
            exc_info=True,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=failure_message,
        )
|
||||
|
||||
|
||||
|
||||
@router.post("/suggestions", response_model=ApiResponse)
async def get_emotion_suggestions(
    request: EmotionSuggestionsRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return personalised emotion suggestions for one end user.

    When the request carries a ``config_id`` it is first passed through
    ``validate_config_id``; a ``ValueError`` from that check yields a
    ``fail(...)`` response instead of an exception.

    Args:
        request: Contains ``group_id`` and an optional ``config_id``.
        db: Database session injected by FastAPI.
        current_user: Authenticated user; only used for logging here.

    Returns:
        The ``success(...)`` envelope wrapping the service result, or a
        ``fail(...)`` envelope when the configuration ID is rejected.

    Raises:
        HTTPException: 500 when anything else in the handler fails.
    """
    try:
        config_id = request.config_id
        if config_id is not None:
            # Imported here, not at module level, as in the original.
            from app.controllers.memory_agent_controller import validate_config_id

            try:
                config_id = validate_config_id(config_id, db)
            except ValueError as invalid:
                return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(invalid))

        api_logger.info(
            f"用户 {current_user.username} 请求获取个性化情绪建议",
            extra={
                "group_id": request.group_id,
                "config_id": config_id,
            },
        )

        advice = await emotion_service.generate_emotion_suggestions(
            end_user_id=request.group_id,
            config_id=config_id,
        )

        summary = {
            "group_id": request.group_id,
            "suggestions_count": len(advice.get("suggestions", [])),
        }
        api_logger.info("个性化建议获取成功", extra=summary)

        return success(data=advice, msg="个性化建议获取成功")

    except Exception as exc:
        failure_message = f"获取个性化建议失败: {exc!s}"
        api_logger.error(
            failure_message,
            extra={"group_id": request.group_id},
            exc_info=True,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=failure_message,
        )
|
||||
269
api/app/controllers/memory_reflection_controller.py
Normal file
269
api/app/controllers/memory_reflection_controller.py
Normal file
@@ -0,0 +1,269 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.core.memory.storage_services.reflection_engine.self_reflexion import ReflectionConfig, ReflectionEngine
|
||||
from app.dependencies import get_current_user
|
||||
from app.db import get_db
|
||||
from app.models.user_model import User
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_reflection_service import WorkspaceAppService, MemoryReflectionService
|
||||
from app.schemas.memory_reflection_schemas import Memory_Reflection
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
# Load variables from a .env file into the process environment at import time.
load_dotenv()

# Logger dedicated to API-layer messages.
api_logger = get_api_logger()

# Router for the memory-reflection endpoints.
router = APIRouter(prefix="/memory", tags=["Memory"])
|
||||
|
||||
|
||||
@router.post("/reflection/save")
async def save_reflection_config(
    request: Memory_Reflection,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> dict:
    """Save the reflection configuration to the data_config table.

    Builds an UPDATE statement through
    ``DataConfigRepository.build_update_reflection``, executes and commits it,
    then re-reads the row through ``build_select_reflection`` and returns it.

    Args:
        request: Reflection settings, including the target ``config_id``.
        current_user: Authenticated user; only used for logging here.
        db: Database session injected by FastAPI.

    Returns:
        The ``success(...)`` envelope whose ``data`` holds the stored
        reflection fields of the configuration.

    Raises:
        HTTPException: 400 when ``config_id`` is missing or a ``ValueError``
            occurs; 404 when no row matches ``config_id``; 500 for any other
            failure.
    """
    try:
        config_id = request.config_id
        if not config_id:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="缺少必需参数: config_id"
            )

        api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")

        update_params = {
            "enable_self_reflexion": request.reflection_enabled,
            "iteration_period": request.reflection_period_in_hours,
            "reflexion_range": request.reflexion_range,
            "baseline": request.baseline,
            "reflection_model_id": request.reflection_model_id,
            "memory_verify": request.memory_verify,
            "quality_assessment": request.quality_assessment,
        }

        query, params = DataConfigRepository.build_update_reflection(config_id, **update_params)

        result = db.execute(text(query), params)
        if result.rowcount == 0:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"未找到config_id为 {config_id} 的配置"
            )

        db.commit()

        # Re-read the row so the response reflects what is actually stored.
        select_query, select_params = DataConfigRepository.build_select_reflection(config_id)
        result = db.execute(text(select_query), select_params).fetchone()

        if not result:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"更新后未找到config_id为 {config_id} 的配置"
            )

        api_logger.info(f"成功保存反思配置到数据库,config_id: {config_id}")

        reflection_result = {
            "config_id": result.config_id,
            "enable_self_reflexion": result.enable_self_reflexion,
            "iteration_period": result.iteration_period,
            "reflexion_range": result.reflexion_range,
            "baseline": result.baseline,
            "reflection_model_id": result.reflection_model_id,
            "memory_verify": result.memory_verify,
            "quality_assessment": result.quality_assessment,
            "user_id": result.user_id,
        }

        return success(data=reflection_result, msg="反思配置成功")

    except HTTPException:
        # Without this clause the 400/404 responses raised above would be
        # caught by the generic handler below and reported as 500.
        raise
    except ValueError as ve:
        api_logger.error(f"参数错误: {str(ve)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"参数错误: {str(ve)}"
        )
    except Exception as e:
        api_logger.error(f"反思配置保存失败: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"反思配置保存失败: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.post("/reflection")
async def start_workspace_reflection(
    config_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> dict:
    """Run reflection for the matching applications of the current workspace.

    Loads the detailed app information of the user's current workspace and,
    for every positionally paired (release, data config, end user) triple
    whose config IDs and app IDs agree, starts a reflection through
    ``MemoryReflectionService.start_reflection_from_data``.

    Args:
        config_id: Accepted as a request parameter but not read by the body.
        current_user: Authenticated user; supplies the workspace ID.
        db: Database session injected by FastAPI.

    Returns:
        The ``success(...)`` envelope whose ``data`` is a list with one entry
        per reflection that was started.

    Raises:
        HTTPException: 500 when anything inside the handler fails.
    """
    workspace_id = current_user.current_workspace_id
    reflection_service = MemoryReflectionService(db)

    try:
        api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}")

        workspace_info = WorkspaceAppService(db).get_workspace_apps_detailed(workspace_id)

        reflection_results = []

        for app_info in workspace_info['apps_detailed_info']:
            # Apps without any data configuration have nothing to reflect on.
            if app_info['data_configs'] == []:
                continue

            triples = zip(
                app_info['releases'],
                app_info['data_configs'],
                app_info['end_users'],
            )
            for release, data_config, end_user in triples:
                same_config = int(release['config']) == int(data_config['config_id'])
                if not (same_config and release['app_id'] == end_user['app_id']):
                    continue

                api_logger.info(f"为用户 {end_user['id']} 启动反思,config_id: {data_config['config_id']}")

                outcome = await reflection_service.start_reflection_from_data(
                    config_data=data_config,
                    end_user_id=end_user['id'],
                )

                reflection_results.append({
                    "app_id": release['app_id'],
                    "config_id": data_config['config_id'],
                    "end_user_id": end_user['id'],
                    "reflection_result": outcome,
                })

        return success(data=reflection_results, msg="反思配置成功")

    except Exception as exc:
        failure_message = f"启动workspace反思失败: {exc!s}"
        api_logger.error(failure_message)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=failure_message,
        )
|
||||
|
||||
|
||||
@router.get("/reflection/configs")
async def start_reflection_configs(
    config_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> dict:
    """Look up the reflection settings of one data_config row by ``config_id``.

    Args:
        config_id: ID of the configuration to read.
        current_user: Authenticated user; only used for logging here.
        db: Database session injected by FastAPI.

    Returns:
        The ``success(...)`` envelope whose ``data`` maps the stored columns
        to the API field names (e.g. ``enable_self_reflexion`` is returned as
        ``reflection_enabled``).

    Raises:
        HTTPException: 404 when no row matches ``config_id``; 500 for any
            other failure.
    """
    try:
        api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")

        # The repository builds the SELECT statement and its bind parameters.
        sql, bind_params = DataConfigRepository.build_select_reflection(config_id)
        row = db.execute(text(sql), bind_params).fetchone()

        if not row:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"未找到config_id为 {config_id} 的配置",
            )

        reflection_config = {
            "config_id": row.config_id,
            "reflection_enabled": row.enable_self_reflexion,
            "reflection_period_in_hours": row.iteration_period,
            "reflexion_range": row.reflexion_range,
            "baseline": row.baseline,
            "reflection_model_id": row.reflection_model_id,
            "memory_verify": row.memory_verify,
            "quality_assessment": row.quality_assessment,
            "user_id": row.user_id,
        }
        api_logger.info(f"成功查询反思配置,config_id: {config_id}")
        return success(data=reflection_config, msg="反思配置查询成功")

    except HTTPException:
        # Let the 404 raised above pass through unchanged.
        raise
    except Exception as exc:
        failure_message = f"查询反思配置失败: {exc!s}"
        api_logger.error(failure_message)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=failure_message,
        )
|
||||
|
||||
@router.get("/reflection/run")
async def reflection_run(
    config_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> dict:
    """Run the reflection engine once with the stored settings of a config.

    Reads the reflection settings for ``config_id``, checks that the
    configured model ID can be resolved (falling back to ``None`` when the
    lookup raises), builds a ``ReflectionEngine`` and awaits its
    ``reflection_run``.

    Args:
        config_id: ID of the configuration whose settings drive the run.
        current_user: Authenticated user; only used for logging here.
        db: Database session injected by FastAPI.

    Returns:
        The ``success(...)`` envelope wrapping the engine's result.

    Raises:
        HTTPException: 404 when no row matches ``config_id``. Other errors
            propagate unchanged because this handler has no ``try`` block.
    """
    api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")

    # The repository builds the SELECT statement and its bind parameters.
    sql, bind_params = DataConfigRepository.build_select_reflection(config_id)
    row = db.execute(text(sql), bind_params).fetchone()

    if not row:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"未找到config_id为 {config_id} 的配置",
        )

    api_logger.info(f"成功查询反思配置,config_id: {config_id}")

    # Check that the configured model exists; on failure fall back to None.
    # The original comment says the engine then uses its default model.
    model_id = row.reflection_model_id
    if model_id:
        try:
            ModelConfigService.get_model_by_id(db=db, model_id=model_id)
            api_logger.info(f"模型ID验证成功: {model_id}")
        except Exception as exc:
            api_logger.warning(f"模型ID '{model_id}' 不存在,将使用默认模型: {exc!s}")
            model_id = None

    engine_config = ReflectionConfig(
        enabled=row.enable_self_reflexion,
        iteration_period=row.iteration_period,
        reflexion_range=row.reflexion_range,
        baseline=row.baseline,
        output_example='',
        memory_verify=row.memory_verify,
        quality_assessment=row.quality_assessment,
        violation_handling_strategy="block",
        model_id=model_id,
    )
    engine = ReflectionEngine(
        config=engine_config,
        neo4j_connector=Neo4jConnector(),
        # NOTE(review): the validated model ID is what gets passed as
        # llm_client; confirm the engine expects an ID here.
        llm_client=model_id,
    )

    run_output = await engine.reflection_run()
    return success(data=run_output, msg="反思试运行")
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
from fastapi import APIRouter, Depends, status, Query
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
|
||||
from app.core.models import RedBearLLM
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.models_model import ModelProvider, ModelType
|
||||
@@ -39,7 +35,7 @@ def get_model_providers():
|
||||
|
||||
@router.get("", response_model=ApiResponse)
|
||||
def get_model_list(
|
||||
type: Optional[List[model_schema.ModelType]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM&type=EMBEDDING)"),
|
||||
type: Optional[str] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||
@@ -54,13 +50,21 @@ def get_model_list(
|
||||
|
||||
支持多个 type 参数:
|
||||
- 单个:?type=LLM
|
||||
- 多个:?type=LLM&type=EMBEDDING
|
||||
- 多个(逗号分隔):?type=LLM,EMBEDDING
|
||||
- 多个(重复参数):?type=LLM&type=EMBEDDING
|
||||
"""
|
||||
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
# 解析 type 参数(支持逗号分隔)
|
||||
type_list = None
|
||||
if type:
|
||||
type_values = [t.strip() for t in type.split(',')]
|
||||
type_list = [model_schema.ModelType(t.lower()) for t in type_values if t]
|
||||
|
||||
api_logger.error(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQuery(
|
||||
type=type,
|
||||
type=type_list,
|
||||
provider=provider,
|
||||
is_active=is_active,
|
||||
is_public=is_public,
|
||||
|
||||
138
api/app/controllers/prompt_optimizer_controller.py
Normal file
138
api/app/controllers/prompt_optimizer_controller.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.models.prompt_optimizer_model import RoleType
|
||||
from app.schemas.prompt_optimizer_schema import PromptOptMessage, PromptOptModelSet, CreateSessionResponse, \
|
||||
OptimizePromptResponse, SessionHistoryResponse, SessionMessage
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.prompt_optimizer_service import PromptOptimizerService
|
||||
|
||||
# Router for the prompt-optimization endpoints.
router = APIRouter(
    prefix="/prompt",
    tags=["Prompts-Optimization"],
)
# Logger dedicated to API-layer messages.
logger = get_api_logger()
|
||||
|
||||
|
||||
@router.post(
    "/sessions",
    summary="Create a new prompt optimization session",
    response_model=ApiResponse,
)
def create_prompt_session(
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Open a new prompt optimization session for the current user.

    Args:
        db: Database session injected by FastAPI.
        current_user: Authenticated user; supplies the tenant and user IDs.

    Returns:
        The ``success(...)`` envelope whose ``data`` is the new session
        validated into ``CreateSessionResponse``.
    """
    new_session = PromptOptimizerService(db).create_session(
        current_user.tenant_id,
        current_user.id,
    )
    return success(data=CreateSessionResponse.model_validate(new_session))
|
||||
|
||||
|
||||
@router.get(
    "/sessions/{session_id}",
    summary="获取 prompt 优化历史对话",
    response_model=ApiResponse,
)
def get_prompt_session(
    session_id: uuid.UUID = Path(..., description="Session ID"),
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """Return every message of one prompt optimization session.

    Args:
        session_id: ID of the session to read.
        db: Database session injected by FastAPI.
        current_user: Authenticated user; the history is fetched for this
            user's ID.

    Returns:
        The ``success(...)`` envelope whose ``data`` is a
        ``SessionHistoryResponse`` with the session ID and its messages.
    """
    history_rows = PromptOptimizerService(db).get_session_message_history(
        session_id=session_id,
        user_id=current_user.id,
    )

    # Each history entry unpacks into a (role, content) pair.
    session_messages = [
        SessionMessage(role=msg_role, content=msg_content)
        for msg_role, msg_content in history_rows
    ]

    return success(
        data=SessionHistoryResponse(
            session_id=session_id,
            messages=session_messages,
        )
    )
|
||||
|
||||
|
||||
@router.post(
    "/sessions/{session_id}/messages",
    summary="Get prompt optimization",
    response_model=ApiResponse
)
async def get_prompt_opt(
        session_id: uuid.UUID = Path(..., description="Session ID"),
        data: PromptOptMessage = ...,
        db: Session = Depends(get_db),
        current_user=Depends(get_current_user),
):
    """
    Record a user message in a session and return the optimized prompt.

    The user message is stored first, then the optimizer is awaited, then the
    optimizer's description is stored as the assistant reply, and finally the
    variables are parsed out of the optimized prompt.

    Args:
        session_id (UUID): ID of the session the message belongs to.
        data (PromptOptMessage): Carries ``message``, ``model_id`` and
            ``current_prompt``.
        db (Session): Database session injected by FastAPI.
        current_user: Authenticated user; supplies ``tenant_id`` and ``id``.

    Returns:
        ApiResponse: Success envelope wrapping an ``OptimizePromptResponse``
            with the optimized prompt, its description and its variables.
    """
    optimizer = PromptOptimizerService(db)

    # Identity fields shared by both stored messages.
    message_scope = {
        "tenant_id": current_user.tenant_id,
        "session_id": session_id,
        "user_id": current_user.id,
    }

    optimizer.create_message(
        role=RoleType.USER,
        content=data.message,
        **message_scope
    )

    optimized = await optimizer.optimize_prompt(
        tenant_id=current_user.tenant_id,
        model_id=data.model_id,
        session_id=session_id,
        user_id=current_user.id,
        current_prompt=data.current_prompt,
        user_require=data.message
    )

    optimizer.create_message(
        role=RoleType.ASSISTANT,
        content=optimized.desc,
        **message_scope
    )

    payload = {
        "prompt": optimized.prompt,
        "desc": optimized.desc,
        "variables": optimizer.parser_prompt_variables(optimized.prompt)
    }
    return success(data=OptimizePromptResponse.model_validate(payload))
|
||||
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
"""Memory 服务接口 - 基于 API Key 认证"""
|
||||
from fastapi import APIRouter, Depends
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, Request, Body
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.response_utils import success
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
|
||||
router = APIRouter(prefix="/memory", tags=["V1 - Memory API"])
|
||||
logger = get_business_logger()
|
||||
@@ -14,3 +18,31 @@ logger = get_business_logger()
|
||||
async def get_memory_info():
|
||||
"""获取记忆服务信息(占位)"""
|
||||
return success(data={}, msg="Memory API - Coming Soon")
|
||||
|
||||
|
||||
# /v1/memory/{resource_id}/chat
|
||||
@router.post("/{resource_id}/chat")
@require_api_key(scopes=["memory"])
async def chat_with_agent_demo(
        resource_id: uuid.UUID,
        request: Request,
        api_key_auth: ApiKeyAuth = None,
        db: Session = Depends(get_db),
        message: str = Body(..., description="聊天消息内容"),
):
    """
    Demo agent chat endpoint guarded by an API key with the "memory" scope.

    The handler only logs what it received and acknowledges the message.

    Args:
        resource_id: Per the original note, the application ID for an
            application API key, or the workspace ID for a service API key.
        request: The incoming request object.
        api_key_auth: Validated API key information.
            NOTE(review): presumably populated by ``require_api_key`` - confirm.
        db: Database session injected by FastAPI.
        message: Chat message text taken from the request body.

    Returns:
        Success envelope with ``{"received": True}``.
    """
    for log_line in (
        f"API Key Auth: {api_key_auth}",
        f"Resource ID: {resource_id}",
        f"Message: {message}",
    ):
        logger.info(log_line)

    return success(data={"received": True}, msg="消息已接收")
|
||||
585
api/app/controllers/tool_controller.py
Normal file
585
api/app/controllers/tool_controller.py
Normal file
@@ -0,0 +1,585 @@
|
||||
"""工具管理API控制器"""
|
||||
import base64
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Body
|
||||
from langfuse.api.core import jsonable_encoder
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, Field, PositiveInt, field_validator
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.models.tool_model import ToolConfig, BuiltinToolConfig, ToolType, ToolStatus, CustomToolConfig, MCPToolConfig
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.config import settings
|
||||
from app.core.tools.config_manager import ConfigManager
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
router = APIRouter(prefix="/tools", tags=["工具管理"])
|
||||
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
|
||||
|
||||
# Substrings that mark a parameter name as sensitive. One shared tuple keeps
# encryption and decryption in agreement: the encrypt path used to look for
# 'api_secret' only, so a key such as 'client_secret' was stored in plaintext
# even though the decrypt path treated it as sensitive.
_SENSITIVE_KEY_MARKERS = ('api_key', 'token', 'secret', 'password')


def _build_param_cipher() -> Fernet:
    """Build the Fernet cipher whose key is derived from ``settings.SECRET_KEY``."""
    # Fernet needs 32 key bytes, urlsafe-base64 encoded: truncate or
    # right-pad the secret with '0' to exactly 32 characters first.
    key_material = settings.SECRET_KEY[:32].ljust(32, '0').encode()
    return Fernet(base64.urlsafe_b64encode(key_material))


def _is_sensitive_param(key: str) -> bool:
    """Return True when ``key`` contains one of the sensitive markers (case-insensitive)."""
    lowered = key.lower()
    return any(marker in lowered for marker in _SENSITIVE_KEY_MARKERS)


def _encrypt_sensitive_params(parameters: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``parameters`` with sensitive values encrypted.

    Args:
        parameters: Tool parameters keyed by name.

    Returns:
        A new dict. Truthy values whose key looks sensitive are converted to
        ``str`` and replaced by their Fernet token (as text); every other
        value is copied unchanged.
    """
    cipher = _build_param_cipher()

    encrypted_params = {}
    for key, value in parameters.items():
        if _is_sensitive_param(key) and value:
            encrypted_params[key] = cipher.encrypt(str(value).encode()).decode()
        else:
            encrypted_params[key] = value

    return encrypted_params


def _decrypt_sensitive_params(parameters: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``parameters`` with sensitive values decrypted.

    Args:
        parameters: Tool parameters keyed by name, as stored.

    Returns:
        A new dict. Truthy values whose key looks sensitive are decrypted;
        a value that cannot be decrypted is returned as stored.
    """
    cipher = _build_param_cipher()

    decrypted_params = {}
    for key, value in parameters.items():
        if _is_sensitive_param(key) and value:
            try:
                decrypted_params[key] = cipher.decrypt(value.encode()).decode()
            except Exception:
                # Deliberate best-effort: a value that is not a valid token
                # for this key (or not a string at all) is passed through
                # unchanged instead of failing the whole lookup.
                decrypted_params[key] = value
        else:
            decrypted_params[key] = value

    return decrypted_params
|
||||
|
||||
|
||||
def _update_tool_status(tool_config: ToolConfig, builtin_config: BuiltinToolConfig = None, tool_info: Dict = None) -> str:
    """Recompute a tool's status, store it on ``tool_config`` and return it.

    Args:
        tool_config: Tool record whose ``status`` attribute may be updated.
        builtin_config: Built-in tool record holding ``parameters``; only
            consulted for built-in tools.
        tool_info: Registry metadata for a built-in tool; only its
            ``requires_config`` entry is read here.

    Returns:
        The new status value (one of the ``ToolStatus`` values).
    """

    def _builtin_status() -> str:
        # Built-in tools that need no configuration are always active.
        if not tool_info or not tool_info.get('requires_config', False):
            return ToolStatus.ACTIVE.value
        if not builtin_config or not builtin_config.parameters:
            return ToolStatus.INACTIVE.value
        # A configured tool counts as usable once it has an API key or token.
        stored = builtin_config.parameters
        if stored.get('api_key') or stored.get('token'):
            return ToolStatus.ACTIVE.value
        return ToolStatus.INACTIVE.value

    if tool_config.tool_type == ToolType.BUILTIN:
        new_status = _builtin_status()
    elif tool_config.config_data:
        # Custom and MCP tools: active as long as they carry config data.
        new_status = ToolStatus.ACTIVE.value
    else:
        new_status = ToolStatus.ERROR.value

    # Only assign when the value actually changed.
    if tool_config.status != new_status:
        tool_config.status = new_status

    return new_status
|
||||
|
||||
|
||||
# ==================== 请求/响应模型 ====================
|
||||
|
||||
class ToolListResponse(BaseModel):
    """One entry of the tool list returned by the listing endpoint."""

    # Identity and display fields.
    id: str
    name: str
    description: str

    # Classification.
    tool_type: str
    category: str
    version: str = "1.0.0"

    # One of: active, inactive, error, loading.
    status: str

    # Whether the tool needs configuration before use.
    requires_config: bool = False
    # is_configured: bool = False

    class Config:
        from_attributes = True
|
||||
|
||||
class BuiltinToolConfigRequest(BaseModel):
    """Request body for configuring a built-in tool."""

    # Free-form parameter mapping; defaults to an empty dict.
    parameters: Dict[str, Any] = Field(
        default_factory=dict,
        description="工具参数",
    )
|
||||
|
||||
|
||||
class CustomToolCreateRequest(BaseModel):
    """Request body for creating a custom tool, with validation rules."""

    name: str = Field(..., min_length=1, max_length=100, description="工具名称,必填")
    # These fields default to None, so they are typed Optional: a plain
    # ``str`` annotation rejects an explicit JSON ``null``.
    description: Optional[str] = Field(None, description="工具描述")
    base_url: Optional[str] = Field(None, description="工具基础URL")
    schema_url: Optional[str] = Field(None, description="工具Schema URL")
    schema_content: Optional[Dict[str, Any]] = Field(None, description="工具Schema内容,可选")
    # Must be declared before ``auth_config`` so the validator below can read it.
    auth_type: str = Field("none", pattern=r"^(none|api_key|bearer_token)$", description="认证类型")
    # validate_default=True makes the validator run when auth_config is
    # omitted; otherwise a request with auth_type="api_key" and no
    # auth_config would skip the check entirely.
    auth_config: Optional[Dict[str, Any]] = Field(None, validate_default=True, description="认证配置,默认空字典")
    timeout: PositiveInt = Field(30, ge=1, le=300, description="超时时间,1-300秒,默认30")

    @field_validator("auth_config")
    @classmethod
    def validate_auth_config(cls, v, values):
        """Require the credential entry that matches ``auth_type``.

        Args:
            v: The ``auth_config`` value (``None`` when omitted).
            values: Validation info; ``values.data`` holds the fields
                validated so far, including ``auth_type`` when it was valid.

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: If ``auth_type`` is ``api_key`` or ``bearer_token``
                and ``auth_config`` lacks the entry of the same name.
        """
        auth_type = values.data.get("auth_type")
        if auth_type == "api_key" and (not v or "api_key" not in v):
            raise ValueError("认证类型为api_key时,auth_config必须包含api_key字段")
        if auth_type == "bearer_token" and (not v or "bearer_token" not in v):
            raise ValueError("认证类型为bearer_token时,auth_config必须包含bearer_token字段")
        return v
|
||||
|
||||
class MCPToolCreateRequest(BaseModel):
    """Request body for creating an MCP tool."""

    # Required name with length limits.
    name: str = Field(..., min_length=1, max_length=100,description="MCP工具名称")
    # Defaults to None, so it is typed Optional: a plain ``str`` annotation
    # rejects an explicit JSON ``null``.
    description: Optional[str] = Field(None, description="MCP工具描述")
    # NOTE(review): the description promises http/https only, but no scheme
    # check is enforced; an HTTPS-only validator existed here in disabled form.
    server_url: str = Field(..., description="MCP服务端URL,仅支持http/https协议")
    # Connection settings; defaults to an empty dict.
    connection_config: Dict[str, Any] = Field(default_factory=dict, description="MCP连接配置(如认证信息、超时、重试等),默认空字典")

    @field_validator("connection_config")
    @classmethod
    def validate_connection_config(cls, v):
        """Validate the optional ``timeout`` entry of the connection config.

        Args:
            v: The ``connection_config`` dict.

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: If ``timeout`` is present and is not an integer in
                the range 1-300.
        """
        if "timeout" in v:
            timeout = v["timeout"]
            # bool is a subclass of int, so reject it explicitly; otherwise
            # ``true`` would pass as a timeout of 1.
            is_int = isinstance(timeout, int) and not isinstance(timeout, bool)
            if not is_int or timeout < 1 or timeout > 300:
                raise ValueError("connection_config.timeout必须是1-300的整数")
        return v
|
||||
|
||||
|
||||
# ==================== API端点 ====================
|
||||
@router.get("", response_model=List[ToolListResponse])
async def list_tools(
    name: Optional[str] = None,
    tool_type: Optional[str] = None,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """List the tenant's tools (built-in, custom and MCP).

    Args:
        name: Optional case-insensitive substring filter on the tool name.
        tool_type: Optional exact filter on the tool type.
        current_user: Authenticated user; results are limited to its tenant.
        db: Database session injected by FastAPI.

    Returns:
        A list of ``ToolListResponse`` entries.

    Raises:
        HTTPException: 500 if anything fails while building the list.
    """
    try:
        # Make sure the tenant's built-in tools exist before listing.
        config_manager = ConfigManager()
        config_manager.ensure_builtin_tools_initialized(
            current_user.tenant_id, db, ToolConfig, BuiltinToolConfig, ToolType, ToolStatus
        )

        query = db.query(ToolConfig).filter(
            ToolConfig.tenant_id == current_user.tenant_id
        )
        if tool_type:
            query = query.filter(ToolConfig.tool_type == tool_type)
        if name:
            query = query.filter(ToolConfig.name.ilike(f"%{name}%"))
        tools = query.all()

        # Index the built-in registry by tool class for per-row lookup.
        builtin_tools = config_manager.load_builtin_tools_config()
        configured_tools = {info["tool_class"]: info for info in builtin_tools.values()}

        response_tools = []
        for tool_config in tools:
            tool_info = None
            if tool_config.tool_type == ToolType.BUILTIN.value:
                builtin_config = db.query(BuiltinToolConfig).filter(BuiltinToolConfig.id == tool_config.id).first()
                if builtin_config is not None:
                    tool_info = configured_tools.get(builtin_config.tool_class)
                if tool_info is not None:
                    status = _update_tool_status(tool_config, builtin_config, tool_info)
                else:
                    # A missing config row or registry entry used to crash
                    # the whole listing. Keep the stored status rather than
                    # guessing a new one for a tool that cannot be resolved.
                    logger.warning(f"内置工具缺少配置或注册信息: {tool_config.id}")
                    status = tool_config.status
            else:
                status = _update_tool_status(tool_config)

            if tool_info is not None:
                category = tool_info['category']
                requires_config = tool_info['requires_config']
            else:
                category = tool_config.tool_type
                requires_config = False

            response_tools.append(ToolListResponse(
                id=str(tool_config.id),
                name=tool_config.name,
                # Custom and MCP tools may be created without a description,
                # but the response model requires a string.
                description=tool_config.description or "",
                tool_type=tool_config.tool_type,
                category=category,
                version="1.0.0",
                status=status,
                requires_config=requires_config,
            ))

        return response_tools
    except Exception as e:
        logger.error(f"获取工具列表失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/builtin/{tool_id}")
async def get_builtin_tool_detail(
    tool_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return the details of one built-in tool, without sensitive parameters.

    Args:
        tool_id: ID of the tool to look up.
        current_user: Authenticated user; the lookup is limited to its tenant.
        db: Database session injected by FastAPI.

    Returns:
        A dict with id, name, description, category, status, requires_config,
        is_configured and the non-sensitive ``config_parameters``.

    Raises:
        HTTPException: 404 if the tool, its built-in config row or its
            registry metadata is missing; 500 on any other failure.
    """
    try:
        tool_config = db.query(ToolConfig).filter(
            ToolConfig.tenant_id == current_user.tenant_id,
            ToolConfig.id == tool_id
        ).first()
        # Missing rows used to surface as AttributeError -> 500.
        if not tool_config:
            raise HTTPException(status_code=404, detail="工具不存在")

        builtin_config = db.query(BuiltinToolConfig).filter(BuiltinToolConfig.id == tool_config.id).first()
        if not builtin_config:
            raise HTTPException(status_code=404, detail="内置工具配置不存在")

        config_manager = ConfigManager()
        builtin_tools = config_manager.load_builtin_tools_config()
        configured_tools = {info["tool_class"]: info for info in builtin_tools.values()}
        tool_info = configured_tools.get(builtin_config.tool_class)
        if not tool_info:
            raise HTTPException(status_code=404, detail="工具信息不存在")

        is_configured = False
        config_parameters = {}

        if builtin_config.parameters:
            is_configured = bool(builtin_config.parameters.get('api_key') or builtin_config.parameters.get('token'))
            # Never return sensitive entries; expose only the rest.
            config_parameters = {k: v for k, v in builtin_config.parameters.items()
                                 if not any(sensitive in k.lower() for sensitive in ['key', 'secret', 'token', 'password'])}

        return {
            "id": tool_config.id,
            "name": tool_config.name,
            "description": tool_config.description,
            "category": tool_info['category'],
            # Report the tool's status; this previously returned tool_type.
            "status": tool_config.status,
            "requires_config": tool_info['requires_config'],
            "is_configured": is_configured,
            "config_parameters": config_parameters
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取工具详情失败: {tool_id}, 错误: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/builtin/{tool_id}/configure")
async def configure_builtin_tool(
    tool_id: str,
    request: BuiltinToolConfigRequest = Body(...),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Store parameters for a built-in tool of the caller's tenant.

    Sensitive parameters are encrypted before being saved, and the tool's
    status is recomputed from the new configuration.

    Args:
        tool_id: ID of the built-in tool to configure.
        request: Body carrying the ``parameters`` mapping.
        current_user: Authenticated user; the lookup is limited to its tenant.
        db: Database session injected by FastAPI.

    Returns:
        A dict with ``success`` and a confirmation ``message``.

    Raises:
        HTTPException: 404 if the tool, its built-in config row or its
            registry metadata is missing; 500 on any other failure.
    """
    try:
        # Locate the tenant's built-in tool record.
        tool_filters = (
            ToolConfig.tenant_id == current_user.tenant_id,
            ToolConfig.id == tool_id,
            ToolConfig.tool_type == ToolType.BUILTIN,
        )
        tool_config = db.query(ToolConfig).filter(*tool_filters).first()
        if not tool_config:
            raise HTTPException(status_code=404, detail="工具不存在")

        # Its built-in configuration row shares the same ID.
        builtin_config = db.query(BuiltinToolConfig).filter(
            BuiltinToolConfig.id == tool_config.id
        ).first()
        if not builtin_config:
            raise HTTPException(status_code=404, detail="内置工具配置不存在")

        # Find the registry entry whose tool class matches this row.
        registry = ConfigManager().load_builtin_tools_config()
        tool_info = next(
            (
                info for info in registry.values()
                if info['tool_class'] == builtin_config.tool_class
            ),
            None,
        )
        if not tool_info:
            raise HTTPException(status_code=404, detail="工具信息不存在")

        # Save the encrypted parameters, refresh the status, then commit.
        builtin_config.parameters = _encrypt_sensitive_params(request.parameters)
        _update_tool_status(tool_config, builtin_config, tool_info)
        db.commit()

        return {
            "success": True,
            "message": f"工具 {tool_config.name} 配置成功"
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"配置内置工具失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/builtin/{tool_id}/config")
async def get_builtin_tool_config(
    tool_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return a built-in tool's configuration with decrypted parameters.

    Args:
        tool_id: ID of the built-in tool.
        current_user: Authenticated user; the lookup is limited to its tenant.
        db: Database session injected by FastAPI.

    Returns:
        A dict with tool_id, tool_class, name, decrypted ``parameters`` and
        the tool's status.

    Raises:
        HTTPException: 404 if the tool or its built-in config row is missing;
            500 on any other failure.
    """
    try:
        tool_filters = (
            ToolConfig.tenant_id == current_user.tenant_id,
            ToolConfig.id == tool_id,
            ToolConfig.tool_type == ToolType.BUILTIN,
        )
        tool_config = db.query(ToolConfig).filter(*tool_filters).first()
        if not tool_config:
            raise HTTPException(status_code=404, detail="工具不存在")

        builtin_config = db.query(BuiltinToolConfig).filter(
            BuiltinToolConfig.id == tool_config.id
        ).first()
        if not builtin_config:
            raise HTTPException(status_code=404, detail="内置工具配置不存在")

        # An unset parameters column is treated as an empty mapping.
        stored_params = builtin_config.parameters or {}

        return {
            "tool_id": tool_id,
            "tool_class": builtin_config.tool_class,
            "name": tool_config.name,
            "parameters": _decrypt_sensitive_params(stored_params),
            "status": tool_config.status
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取工具配置失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/custom")
async def create_custom_tool(
    request: CustomToolCreateRequest = Body(...),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Create a custom tool for the caller's tenant.

    Args:
        request: Validated creation payload.
        current_user: Authenticated user; the tool is created in its tenant.
        db: Database session injected by FastAPI.

    Returns:
        A dict with ``success``, a confirmation ``message`` and the new
        ``tool_id``.

    Raises:
        HTTPException: 400 if the config manager rejects the configuration;
            500 on any other failure.
    """
    try:
        config_data = jsonable_encoder(request.model_dump())
        config_data["tool_type"] = "custom"

        is_valid, error_msg = ConfigManager().validate_config(config_data, "custom")
        if not is_valid:
            raise HTTPException(status_code=400, detail=error_msg)

        # Parent record first; flush so its ID is available for the child row.
        tool_config = ToolConfig(
            name=request.name,
            description=request.description,
            tool_type=ToolType.CUSTOM,
            tenant_id=current_user.tenant_id,
            status=ToolStatus.ACTIVE.value,
            config_data=config_data
        )
        db.add(tool_config)
        db.flush()

        # Child record with the custom-tool specific columns.
        db.add(CustomToolConfig(
            id=tool_config.id,
            base_url=request.base_url,
            schema_url=request.schema_url,
            schema_content=request.schema_content,
            auth_type=request.auth_type,
            auth_config=request.auth_config,
            timeout=request.timeout
        ))
        db.commit()

        return {
            "success": True,
            "message": f"自定义工具 {request.name} 创建成功",
            "tool_id": str(tool_config.id)
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"创建自定义工具失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/mcp")
async def create_mcp_tool(
    request: MCPToolCreateRequest = Body(..., description="MCP工具创建参数"),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Create an MCP tool for the caller's tenant.

    Args:
        request: Validated creation payload.
        current_user: Authenticated user; the tool is created in its tenant.
        db: Database session injected by FastAPI.

    Returns:
        A dict with ``success``, a confirmation ``message`` and the new
        ``tool_id``.

    Raises:
        HTTPException: 400 if the config manager rejects the configuration;
            500 if the database write fails (after a rollback) or on any
            other failure.
    """
    try:
        config_data = jsonable_encoder(request.model_dump())
        config_data["tool_type"] = "mcp"

        is_valid, error_msg = ConfigManager().validate_config(config_data, "mcp")
        if not is_valid:
            raise HTTPException(status_code=400, detail=error_msg)

        try:
            # Parent record first; flush so its ID is available for the child row.
            tool_config = ToolConfig(
                name=request.name,
                description=request.description,
                tool_type=ToolType.MCP,
                tenant_id=current_user.tenant_id,
                status=ToolStatus.ACTIVE.value,
                config_data=config_data
            )
            db.add(tool_config)
            db.flush()

            # Child record with the MCP specific columns.
            db.add(MCPToolConfig(
                id=tool_config.id,
                server_url=request.server_url,
                connection_config=request.connection_config
            ))
            db.commit()
        except SQLAlchemyError as db_e:
            # Undo the partial write before reporting the failure.
            db.rollback()
            logger.error(
                f"创建MCP工具数据库操作失败(租户ID:{current_user.tenant_id},工具名:{request.name}): {str(db_e)}",
                exc_info=True,
            )
            raise HTTPException(
                status_code=500,
                detail=f"创建MCP工具数据库操作失败(租户ID:{current_user.tenant_id},"
                       f"工具名:{request.name}):{str(db_e)}",
            )

        return {
            "success": True,
            "message": f"MCP工具 {request.name} 创建成功",
            "tool_id": str(tool_config.id)
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"创建MCP工具失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.delete("/{tool_id}")
async def delete_tool(
    tool_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Delete a custom or MCP tool; built-in tools cannot be deleted.

    Args:
        tool_id: ID of the tool to delete.
        current_user: Authenticated user; the lookup is limited to its tenant.
        db: Database session injected by FastAPI.

    Returns:
        A dict with ``success`` and a confirmation ``message``.

    Raises:
        HTTPException: 404 if the tool does not exist, 403 if it is a
            built-in tool, 500 on any other failure.
    """
    try:
        tool = (
            db.query(ToolConfig)
            .filter(
                ToolConfig.id == tool_id,
                ToolConfig.tenant_id == current_user.tenant_id
            )
            .first()
        )

        # Guard clauses: unknown tool, then protected built-in tool.
        if not tool:
            raise HTTPException(status_code=404, detail="工具不存在")
        if tool.tool_type == ToolType.BUILTIN:
            raise HTTPException(status_code=403, detail="内置工具不允许删除")

        db.delete(tool)
        db.commit()

        return {
            "success": True,
            "message": f"工具 {tool.name} 删除成功"
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"删除工具失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.put("/{tool_id}")
async def update_tool(
    tool_id: str,
    config_data: Optional[Dict[str, Any]] = None,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Replace the config data of a custom or MCP tool.

    Args:
        tool_id: ID of the tool to update.
        config_data: New configuration; when ``None`` nothing is changed.
        current_user: Authenticated user; the lookup is limited to its tenant.
        db: Database session injected by FastAPI.

    Returns:
        A dict with ``success``, a confirmation ``message`` and the tool's
        ``status`` after the update.

    Raises:
        HTTPException: 404 if the tool does not exist, 403 if it is a
            built-in tool, 500 on any other failure.
    """
    try:
        tool = (
            db.query(ToolConfig)
            .filter(
                ToolConfig.id == tool_id,
                ToolConfig.tenant_id == current_user.tenant_id
            )
            .first()
        )

        # Guard clauses: unknown tool, then protected built-in tool.
        if not tool:
            raise HTTPException(status_code=404, detail="工具不存在")
        if tool.tool_type == ToolType.BUILTIN:
            raise HTTPException(status_code=403, detail="内置工具不允许修改")

        if config_data is not None:
            tool.config_data = config_data
            # The status depends on whether config data is present.
            _update_tool_status(tool)

        db.commit()
        db.refresh(tool)

        return {
            "success": True,
            "message": f"工具 {tool.name} 更新成功",
            "status": tool.status
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"更新工具失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/{tool_id}/toggle")
async def toggle_tool_status(
    tool_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Flip a tool between the active and inactive statuses.

    Args:
        tool_id: ID of the tool to toggle.
        current_user: Authenticated user; the lookup is limited to its tenant.
        db: Database session injected by FastAPI.

    Returns:
        A dict with ``success``, a confirmation ``message`` and the new
        ``status``.

    Raises:
        HTTPException: 404 if the tool does not exist, 400 if its current
            status is neither active nor inactive, 500 on any other failure.
    """
    try:
        tool = (
            db.query(ToolConfig)
            .filter(
                ToolConfig.id == tool_id,
                ToolConfig.tenant_id == current_user.tenant_id
            )
            .first()
        )
        if not tool:
            raise HTTPException(status_code=404, detail="工具不存在")

        # Only the active <-> inactive transition is allowed.
        current_status = tool.status
        if current_status == ToolStatus.ACTIVE.value:
            tool.status = ToolStatus.INACTIVE.value
        elif current_status == ToolStatus.INACTIVE.value:
            tool.status = ToolStatus.ACTIVE.value
        else:
            raise HTTPException(status_code=400, detail="只有可用或非活跃状态的工具可以切换")

        db.commit()
        db.refresh(tool)

        return {
            "success": True,
            "message": f"工具 {tool.name} 状态已更新为 {tool.status}",
            "status": tool.status
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"切换工具状态失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
430
api/app/controllers/tool_execution_controller.py
Normal file
430
api/app/controllers/tool_execution_controller.py
Normal file
@@ -0,0 +1,430 @@
|
||||
"""工具执行API控制器"""
|
||||
import uuid
|
||||
from typing import Dict, Any, List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.core.tools.registry import ToolRegistry
|
||||
from app.core.tools.executor import ToolExecutor
|
||||
from app.core.tools.chain_manager import ChainManager, ChainDefinition, ChainStep, ChainExecutionMode
|
||||
from app.core.tools.builtin import *
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
router = APIRouter(prefix="/tools/execution", tags=["工具执行"])
|
||||
|
||||
|
||||
# ==================== 请求/响应模型 ====================
|
||||
|
||||
class ToolExecutionRequest(BaseModel):
    """Request body for executing a single tool."""

    # ID of the tool to run (required).
    tool_id: str = Field(..., description="工具ID")
    # Arguments handed to the tool; defaults to an empty dict.
    parameters: Dict[str, Any] = Field(
        default_factory=dict,
        description="工具参数",
    )
    # Optional timeout in seconds, limited to 1-300.
    timeout: Optional[float] = Field(
        None,
        ge=1,
        le=300,
        description="超时时间(秒)",
    )
    # Optional extra metadata passed through to the executor.
    metadata: Optional[Dict[str, Any]] = Field(None, description="额外元数据")
|
||||
|
||||
|
||||
class BatchExecutionRequest(BaseModel):
    """Request body for executing several tools in one call."""

    # The individual executions to run (required).
    executions: List[ToolExecutionRequest] = Field(
        ...,
        description="执行列表",
    )
    # Upper bound on concurrent executions, limited to 1-20.
    max_concurrency: int = Field(
        5,
        ge=1,
        le=20,
        description="最大并发数",
    )
|
||||
|
||||
|
||||
class ToolExecutionResponse(BaseModel):
    """Response describing the outcome of one tool execution."""

    # Outcome flag and identifiers.
    success: bool
    execution_id: str
    tool_id: str

    # Result payload, or error details when the execution failed.
    data: Any = None
    error: Optional[str] = None
    error_code: Optional[str] = None

    # Measurements reported by the executor.
    execution_time: float
    token_usage: Optional[Dict[str, int]] = None

    # Extra metadata; defaults to an empty dict.
    metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ChainStepRequest(BaseModel):
    """One step of a tool chain request."""

    # ID of the tool this step runs (required).
    tool_id: str = Field(..., description="工具ID")
    # Arguments for the tool; defaults to an empty dict.
    parameters: Dict[str, Any] = Field(
        default_factory=dict,
        description="工具参数",
    )
    # Optional condition deciding whether the step runs.
    condition: Optional[str] = Field(None, description="执行条件")
    # Optional mapping applied to the step's output.
    output_mapping: Optional[Dict[str, str]] = Field(
        None,
        description="输出映射",
    )
    # Error handling strategy; defaults to "stop".
    error_handling: str = Field("stop", description="错误处理策略")
|
||||
|
||||
|
||||
class ChainExecutionRequest(BaseModel):
    """Request body for defining and executing a tool chain."""

    # Chain name (required) and free-text description.
    name: str = Field(..., description="链名称")
    description: str = Field("", description="链描述")
    # Steps of the chain (required).
    steps: List[ChainStepRequest] = Field(..., description="执行步骤")
    # Execution mode name; defaults to "sequential".
    execution_mode: str = Field("sequential", description="执行模式")
    # Optional variables available when the chain starts.
    initial_variables: Optional[Dict[str, Any]] = Field(
        None,
        description="初始变量",
    )
    # Optional timeout for the whole chain.
    global_timeout: Optional[float] = Field(None, description="全局超时")
|
||||
|
||||
|
||||
class ExecutionHistoryResponse(BaseModel):
    """Response describing one recorded tool execution."""

    # Identifiers and state.
    execution_id: str
    tool_id: str
    status: str

    # Timing information.
    started_at: Optional[str]
    completed_at: Optional[str]
    execution_time: Optional[float]

    # Who ran it and where.
    user_id: Optional[str]
    workspace_id: Optional[str]

    # Input, output and failure details.
    input_data: Optional[Dict[str, Any]]
    output_data: Optional[Any]
    error_message: Optional[str]
    token_usage: Optional[Dict[str, int]]
|
||||
|
||||
|
||||
class ToolConnectionTestResponse(BaseModel):
    """Response describing the result of a tool connection test."""

    # Outcome flag and human-readable message.
    success: bool
    message: str

    # Optional failure text and extra details.
    error: Optional[str] = None
    details: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
# ==================== 依赖注入 ====================
|
||||
|
||||
def get_tool_registry(db: Session = Depends(get_db)) -> ToolRegistry:
    """Dependency: build a ``ToolRegistry`` with the built-in tool classes registered.

    Args:
        db: Database session injected by FastAPI.

    Returns:
        A registry bound to ``db`` with the five built-in tool classes
        registered.
    """
    registry = ToolRegistry(db)

    # Register the built-in tool classes in a fixed order.
    builtin_tool_classes = (
        DateTimeTool,
        JsonTool,
        BaiduSearchTool,
        MinerUTool,
        TextInTool,
    )
    for tool_class in builtin_tool_classes:
        registry.register_tool_class(tool_class)

    return registry
|
||||
|
||||
|
||||
def get_tool_executor(
    db: Session = Depends(get_db),
    registry: ToolRegistry = Depends(get_tool_registry)
) -> ToolExecutor:
    """Dependency: build a ``ToolExecutor`` from the session and the registry.

    Args:
        db: Database session injected by FastAPI.
        registry: Tool registry produced by ``get_tool_registry``.

    Returns:
        A new ``ToolExecutor``.
    """
    executor = ToolExecutor(db, registry)
    return executor
|
||||
|
||||
|
||||
def get_chain_manager(executor: ToolExecutor = Depends(get_tool_executor)) -> ChainManager:
    """Dependency: build a ``ChainManager`` on top of the tool executor.

    Args:
        executor: Tool executor produced by ``get_tool_executor``.

    Returns:
        A new ``ChainManager``.
    """
    manager = ChainManager(executor)
    return manager
|
||||
|
||||
|
||||
# ==================== API端点 ====================
|
||||
|
||||
@router.post("/execute", response_model=ToolExecutionResponse)
async def execute_tool(
    request: ToolExecutionRequest,
    current_user: User = Depends(get_current_user),
    executor: ToolExecutor = Depends(get_tool_executor)
):
    """Execute a single tool on behalf of the current user.

    Args:
        request: Tool ID, parameters, optional timeout and metadata.
        current_user: Authenticated user; supplies the user ID and the
            current workspace ID.
        executor: Tool executor dependency.

    Returns:
        A ``ToolExecutionResponse`` built from the executor's result.

    Raises:
        HTTPException: 500 if the execution raises.
    """
    try:
        # Short random identifier for this run.
        execution_id = f"exec_{uuid.uuid4().hex[:16]}"

        outcome = await executor.execute_tool(
            tool_id=request.tool_id,
            parameters=request.parameters,
            user_id=current_user.id,
            workspace_id=current_user.current_workspace_id,
            execution_id=execution_id,
            timeout=request.timeout,
            metadata=request.metadata
        )

        response = ToolExecutionResponse(
            success=outcome.success,
            execution_id=execution_id,
            tool_id=request.tool_id,
            data=outcome.data,
            error=outcome.error,
            error_code=outcome.error_code,
            execution_time=outcome.execution_time,
            token_usage=outcome.token_usage,
            metadata=outcome.metadata
        )
        return response

    except Exception as e:
        logger.error(f"工具执行失败: {request.tool_id}, 错误: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/batch", response_model=List[ToolExecutionResponse])
async def execute_tools_batch(
    request: BatchExecutionRequest,
    current_user: User = Depends(get_current_user),
    executor: ToolExecutor = Depends(get_tool_executor)
):
    """Execute several tools in one call.

    Args:
        request: The executions to run and the concurrency limit.
        current_user: Authenticated user; supplies the user ID and the
            current workspace ID.
        executor: Tool executor dependency.

    Returns:
        One ``ToolExecutionResponse`` per result, matched to the request
        entry at the same position.

    Raises:
        HTTPException: 500 if the batch raises.
    """
    try:
        # One short random identifier per requested execution, in order.
        execution_ids = [
            f"exec_{uuid.uuid4().hex[:16]}" for _ in request.executions
        ]

        execution_configs = [
            {
                "tool_id": exec_request.tool_id,
                "parameters": exec_request.parameters,
                "user_id": current_user.id,
                "workspace_id": current_user.current_workspace_id,
                "execution_id": execution_id,
                "timeout": exec_request.timeout,
                "metadata": exec_request.metadata
            }
            for exec_request, execution_id in zip(request.executions, execution_ids)
        ]

        results = await executor.execute_tools_batch(
            execution_configs,
            max_concurrency=request.max_concurrency
        )

        # Results are matched back to their request by position.
        return [
            ToolExecutionResponse(
                success=result.success,
                execution_id=execution_ids[position],
                tool_id=request.executions[position].tool_id,
                data=result.data,
                error=result.error,
                error_code=result.error_code,
                execution_time=result.execution_time,
                token_usage=result.token_usage,
                metadata=result.metadata
            )
            for position, result in enumerate(results)
        ]

    except Exception as e:
        logger.error(f"批量执行失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/chain", response_model=Dict[str, Any])
async def execute_tool_chain(
    request: ChainExecutionRequest,
    current_user: User = Depends(get_current_user),
    chain_manager: ChainManager = Depends(get_chain_manager)
):
    """Build a tool chain from the request, register it, and execute it.

    The request steps are converted to ``ChainStep`` objects, wrapped in a
    ``ChainDefinition`` registered under ``request.name``, and then run with
    ``request.initial_variables``. The value returned by
    ``chain_manager.execute_chain`` is passed back unchanged.

    Raises:
        HTTPException: 500 carrying the error text when any step raises
            (this includes an ``execution_mode`` that ``ChainExecutionMode``
            rejects).
    """
    try:
        # Translate the request-level step payloads into chain steps.
        chain_steps = [
            ChainStep(
                tool_id=spec.tool_id,
                parameters=spec.parameters,
                condition=spec.condition,
                output_mapping=spec.output_mapping,
                error_handling=spec.error_handling
            )
            for spec in request.steps
        ]

        definition = ChainDefinition(
            name=request.name,
            description=request.description,
            steps=chain_steps,
            execution_mode=ChainExecutionMode(request.execution_mode),
            global_timeout=request.global_timeout
        )

        # The chain has to be registered before it can be run by name.
        chain_manager.register_chain(definition)

        return await chain_manager.execute_chain(
            chain_name=request.name,
            initial_variables=request.initial_variables
        )

    except Exception as e:
        logger.error(f"工具链执行失败: {request.name}, 错误: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/running", response_model=List[Dict[str, Any]])
async def get_running_executions(
    current_user: User = Depends(get_current_user),
    executor: ToolExecutor = Depends(get_tool_executor)
):
    """Return the running executions that belong to the caller's workspace.

    An execution is kept when its ``"workspace_id"`` entry equals the string
    form of the caller's current workspace id.

    Raises:
        HTTPException: 500 carrying the error text when the lookup raises.
    """
    try:
        visible = []
        for info in executor.get_running_executions():
            # Entries are compared against the stringified workspace id.
            if info.get("workspace_id") == str(current_user.current_workspace_id):
                visible.append(info)
        return visible

    except Exception as e:
        logger.error(f"获取运行中执行失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/cancel/{execution_id}", response_model=Dict[str, Any])
async def cancel_execution(
    execution_id: str = Path(..., description="执行ID"),
    current_user: User = Depends(get_current_user),
    executor: ToolExecutor = Depends(get_tool_executor)
):
    """Cancel a tool execution by its id.

    Returns:
        A ``{"success": True, "message": ...}`` payload when the executor
        reports the cancellation succeeded.

    Raises:
        HTTPException: 404 when the executor reports nothing was cancelled;
            500 carrying the error text for any other failure.
    """
    # NOTE(review): the cancel is issued by id alone and ``current_user`` is
    # not consulted here -- confirm the executor restricts it to the caller's
    # workspace.
    try:
        cancelled = await executor.cancel_execution(execution_id)
        if not cancelled:
            raise HTTPException(status_code=404, detail="执行不存在或已完成")
        return {
            "success": True,
            "message": "执行已取消"
        }

    except HTTPException:
        # Let the 404 above through instead of rewrapping it as a 500.
        raise
    except Exception as e:
        logger.error(f"取消执行失败: {execution_id}, 错误: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/history", response_model=List[ExecutionHistoryResponse])
async def get_execution_history(
    tool_id: Optional[str] = Query(None, description="工具ID过滤"),
    limit: int = Query(50, ge=1, le=200, description="返回数量限制"),
    current_user: User = Depends(get_current_user),
    executor: ToolExecutor = Depends(get_tool_executor)
):
    """Return execution history records for the caller.

    The executor is queried with the optional ``tool_id`` filter, the caller's
    user id and current workspace id, and ``limit``; each returned record is
    converted to an ``ExecutionHistoryResponse``.

    Raises:
        HTTPException: 500 carrying the error text when the query or the
            conversion raises (for example a record missing a field).
    """
    # Record keys copied one-to-one into the response model.
    response_fields = (
        "execution_id",
        "tool_id",
        "status",
        "started_at",
        "completed_at",
        "execution_time",
        "user_id",
        "workspace_id",
        "input_data",
        "output_data",
        "error_message",
        "token_usage",
    )
    try:
        records = executor.get_execution_history(
            tool_id=tool_id,
            user_id=current_user.id,
            workspace_id=current_user.current_workspace_id,
            limit=limit
        )

        return [
            ExecutionHistoryResponse(**{name: entry[name] for name in response_fields})
            for entry in records
        ]

    except Exception as e:
        logger.error(f"获取执行历史失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/statistics", response_model=Dict[str, Any])
async def get_execution_statistics(
    days: int = Query(7, ge=1, le=90, description="统计天数"),
    current_user: User = Depends(get_current_user),
    executor: ToolExecutor = Depends(get_tool_executor)
):
    """Return execution statistics for the caller's current workspace.

    Returns:
        ``{"success": True, "statistics": <executor result>}`` for the last
        ``days`` days.

    Raises:
        HTTPException: 500 carrying the error text when the executor raises.
    """
    try:
        workspace_stats = executor.get_execution_statistics(
            workspace_id=current_user.current_workspace_id,
            days=days
        )
        payload = {"success": True, "statistics": workspace_stats}
        return payload

    except Exception as e:
        logger.error(f"获取执行统计失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/chains/running", response_model=List[Dict[str, Any]])
async def get_running_chains(
    current_user: User = Depends(get_current_user),
    chain_manager: ChainManager = Depends(get_chain_manager)
):
    """Return whatever ``chain_manager.get_running_chains()`` reports.

    Raises:
        HTTPException: 500 carrying the error text when the manager raises.
    """
    # NOTE(review): the result is returned unfiltered; ``current_user`` only
    # gates access to the route -- confirm that is intended.
    try:
        return chain_manager.get_running_chains()

    except Exception as e:
        logger.error(f"获取运行中工具链失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/chains", response_model=List[Dict[str, Any]])
async def list_tool_chains(
    current_user: User = Depends(get_current_user),
    chain_manager: ChainManager = Depends(get_chain_manager)
):
    """Return the chains reported by ``chain_manager.list_chains()``.

    Raises:
        HTTPException: 500 carrying the error text when the manager raises.
    """
    try:
        return chain_manager.list_chains()

    except Exception as e:
        logger.error(f"获取工具链列表失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/test-connection/{tool_id}", response_model=ToolConnectionTestResponse)
async def test_tool_connection(
    tool_id: str = Path(..., description="工具ID"),
    current_user: User = Depends(get_current_user),
    executor: ToolExecutor = Depends(get_tool_executor)
):
    """Probe a tool's connection and report the outcome.

    Unlike the other routes in this module, a failure is not turned into an
    HTTP error: any exception is logged and reported as a response with
    ``success=False`` and the error text.
    """
    try:
        probe = await executor.test_tool_connection(
            tool_id=tool_id,
            user_id=current_user.id,
            workspace_id=current_user.current_workspace_id
        )
    except Exception as e:
        logger.error(f"工具连接测试失败: {tool_id}, 错误: {e}")
        return ToolConnectionTestResponse(
            success=False,
            message="连接测试失败",
            error=str(e)
        )

    try:
        # Missing keys fall back to a failed, message-less result.
        return ToolConnectionTestResponse(
            success=probe.get("success", False),
            message=probe.get("message", ""),
            error=probe.get("error"),
            details=probe.get("details")
        )
    except Exception as e:
        logger.error(f"工具连接测试失败: {tool_id}, 错误: {e}")
        return ToolConnectionTestResponse(
            success=False,
            message="连接测试失败",
            error=str(e)
        )
|
||||
@@ -471,28 +471,52 @@ async def run_workflow(
|
||||
import json
|
||||
|
||||
async def event_generator():
|
||||
"""生成 SSE 事件"""
|
||||
"""生成 SSE 事件
|
||||
|
||||
SSE 格式:
|
||||
event: <event_type>
|
||||
data: <json_data>
|
||||
|
||||
支持的事件类型:
|
||||
- workflow_start: 工作流开始
|
||||
- workflow_end: 工作流结束
|
||||
- node_start: 节点开始执行
|
||||
- node_end: 节点执行完成
|
||||
- node_chunk: 中间节点的流式输出
|
||||
- message: 最终消息的流式输出(End 节点及其相邻节点)
|
||||
"""
|
||||
try:
|
||||
async for event in service.run_workflow(
|
||||
async for event in await service.run_workflow(
|
||||
app_id=app_id,
|
||||
input_data=input_data,
|
||||
triggered_by=current_user.id,
|
||||
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||
stream=True
|
||||
):
|
||||
# 转换为 SSE 格式
|
||||
yield f"data: {json.dumps(event)}\n\n"
|
||||
# 提取事件类型和数据
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
# 转换为标准 SSE 格式(字符串)
|
||||
# event: <type>
|
||||
# data: <json>
|
||||
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
|
||||
yield sse_message
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式执行异常: {e}", exc_info=True)
|
||||
error_event = {
|
||||
"type": "error",
|
||||
"error": str(e)
|
||||
}
|
||||
yield f"data: {json.dumps(error_event)}\n\n"
|
||||
# 发送错误事件
|
||||
sse_error = f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
|
||||
yield sse_error
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream"
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no" # 禁用 nginx 缓冲
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 非流式执行
|
||||
|
||||
@@ -9,18 +9,15 @@ LangChain Agent 封装
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Optional, AsyncGenerator, Sequence
|
||||
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain.agents import create_agent
|
||||
|
||||
from app.core.memory.agent.mcp_server.services import session_service
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
|
||||
@@ -37,9 +37,10 @@ def require_api_key(
|
||||
@require_api_key(scopes=["app"])
|
||||
def chat_with_app(
|
||||
resource_id: uuid.UUID,
|
||||
api_key_auth: ApiKeyAuth = Depends(),
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str
|
||||
message: str = Query(..., description="聊天消息内容")
|
||||
):
|
||||
# api_key_auth 包含验证后的API Key 信息
|
||||
pass
|
||||
@@ -70,29 +71,6 @@ def require_api_key(
|
||||
})
|
||||
raise BusinessException("API Key 无效或已过期", BizCode.API_KEY_INVALID)
|
||||
|
||||
rate_limiter = RateLimiterService()
|
||||
is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj)
|
||||
if not is_allowed:
|
||||
logger.warning("API Key 限流触发", extra={
|
||||
"api_key_id": str(api_key_obj.id),
|
||||
"endpoint": str(request.url),
|
||||
"method": request.method,
|
||||
"error_msg": error_msg
|
||||
})
|
||||
# 根据错误消息判断限流类型
|
||||
if "QPS" in error_msg:
|
||||
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED
|
||||
elif "Daily" in error_msg:
|
||||
code = BizCode.API_KEY_DAILY_LIMIT_EXCEEDED
|
||||
else:
|
||||
code = BizCode.API_KEY_QUOTA_EXCEEDED
|
||||
|
||||
raise RateLimitException(
|
||||
error_msg,
|
||||
code,
|
||||
rate_headers=rate_headers
|
||||
)
|
||||
|
||||
if scopes:
|
||||
missing_scopes = []
|
||||
for scope in scopes:
|
||||
@@ -138,6 +116,30 @@ def require_api_key(
|
||||
scopes=api_key_obj.scopes,
|
||||
resource_id=api_key_obj.resource_id,
|
||||
)
|
||||
|
||||
rate_limiter = RateLimiterService()
|
||||
is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj)
|
||||
if not is_allowed:
|
||||
logger.warning("API Key 限流触发", extra={
|
||||
"api_key_id": str(api_key_obj.id),
|
||||
"endpoint": str(request.url),
|
||||
"method": request.method,
|
||||
"error_msg": error_msg
|
||||
})
|
||||
# 根据错误消息判断限流类型
|
||||
if "QPS" in error_msg:
|
||||
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED
|
||||
elif "Daily" in error_msg:
|
||||
code = BizCode.API_KEY_DAILY_LIMIT_EXCEEDED
|
||||
else:
|
||||
code = BizCode.API_KEY_QUOTA_EXCEEDED
|
||||
|
||||
raise RateLimitException(
|
||||
error_msg,
|
||||
code,
|
||||
rate_headers=rate_headers
|
||||
)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
response = await func(*args, **kwargs)
|
||||
end_time = time.perf_counter()
|
||||
|
||||
@@ -16,7 +16,7 @@ def generate_api_key(key_type: ApiKeyType) -> str:
|
||||
key_type: API Key 类型
|
||||
|
||||
Returns:
|
||||
tuple: (api_key, key_hash, key_prefix)
|
||||
str: api_key
|
||||
"""
|
||||
# 前缀映射
|
||||
prefix_map = {
|
||||
|
||||
@@ -148,6 +148,7 @@ class Settings:
|
||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
||||
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
|
||||
# Reflection interval read from the environment. The value is always parsed
# with int(), so it is annotated as int (it was previously Optional[str]).
# NOTE(review): unit is presumably seconds -- confirm against the scheduler.
REFLECTION_INTERVAL_TIME: int = int(os.getenv("REFLECTION_INTERVAL_TIME", "30"))
|
||||
|
||||
# Memory Module Configuration (internal)
|
||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||
@@ -156,6 +157,12 @@ class Settings:
|
||||
MEMORY_RUNTIME_FILE: str = os.getenv("MEMORY_RUNTIME_FILE", "runtime.json")
|
||||
MEMORY_DBRUN_FILE: str = os.getenv("MEMORY_DBRUN_FILE", "dbrun.json")
|
||||
|
||||
# Tool Management Configuration
|
||||
TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools")
|
||||
TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60"))
|
||||
TOOL_MAX_CONCURRENCY: int = int(os.getenv("TOOL_MAX_CONCURRENCY", "10"))
|
||||
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
|
||||
|
||||
def get_memory_output_path(self, filename: str = "") -> str:
|
||||
"""
|
||||
Get the full path for memory module output files.
|
||||
|
||||
85
api/app/core/memory/models/emotion_models.py
Normal file
85
api/app/core/memory/models/emotion_models.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Emotion extraction models for LLM structured output.
|
||||
|
||||
This module contains Pydantic models for emotion extraction from statements,
|
||||
designed to be used with LLM structured output capabilities.
|
||||
|
||||
Classes:
|
||||
EmotionExtraction: Model for emotion extraction results from statements
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class EmotionExtraction(BaseModel):
    """Emotion extraction result model for LLM structured output.

    Holds the emotion information extracted from one statement: the emotion
    type, its intensity, up to three keywords, the subject classification and
    an optional target.

    Attributes:
        emotion_type: Type of emotion (joy/sadness/anger/fear/surprise/neutral)
        emotion_intensity: Intensity of emotion (0.0-1.0)
        emotion_keywords: List of emotion keywords from the statement (max 3)
        emotion_subject: Subject of emotion (self/other/object)
        emotion_target: Optional target of emotion (person or object name)
    """

    emotion_type: str = Field(
        ...,
        description="Emotion type: joy/sadness/anger/fear/surprise/neutral"
    )
    emotion_intensity: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Emotion intensity from 0.0 to 1.0"
    )
    emotion_keywords: List[str] = Field(
        default_factory=list,
        description="Emotion keywords extracted from the statement (max 3)"
    )
    emotion_subject: str = Field(
        ...,
        description="Emotion subject: self/other/object"
    )
    emotion_target: Optional[str] = Field(
        None,
        description="Emotion target: person or object name"
    )

    @field_validator('emotion_type')
    @classmethod
    def validate_emotion_type(cls, v):
        """Validate emotion type is one of the valid values."""
        valid_types = ['joy', 'sadness', 'anger', 'fear', 'surprise', 'neutral']
        if v not in valid_types:
            raise ValueError(f"emotion_type must be one of {valid_types}, got {v}")
        return v

    @field_validator('emotion_subject')
    @classmethod
    def validate_emotion_subject(cls, v):
        """Validate emotion subject is one of the valid values."""
        valid_subjects = ['self', 'other', 'object']
        if v not in valid_subjects:
            raise ValueError(f"emotion_subject must be one of {valid_subjects}, got {v}")
        return v

    @field_validator('emotion_keywords', mode='before')
    @classmethod
    def _default_missing_keywords(cls, v):
        """Turn an explicit ``None`` into an empty keyword list.

        This runs before Pydantic's own list validation. The fallback used to
        sit in the after-mode validator below, where it could never trigger
        because the value had already been validated as a list; a ``null``
        keywords value therefore failed validation instead of becoming ``[]``.
        Every other input is passed through untouched so it is validated
        exactly as before.
        """
        return [] if v is None else v

    @field_validator('emotion_keywords')
    @classmethod
    def validate_emotion_keywords(cls, v):
        """Limit emotion keywords to at most 3 items."""
        # After-mode validator: ``v`` is already a validated list of strings,
        # so only the length cap remains to be applied.
        return v[:3]

    @field_validator('emotion_intensity')
    @classmethod
    def validate_emotion_intensity(cls, v):
        """Validate emotion intensity is within valid range."""
        # The ge/le constraints on the field already enforce this range; the
        # check is kept as an explicit, defensive restatement.
        if not (0.0 <= v <= 1.0):
            raise ValueError(f"emotion_intensity must be between 0.0 and 1.0, got {v}")
        return v
|
||||
@@ -215,24 +215,58 @@ class StatementNode(Node):
|
||||
Attributes:
|
||||
chunk_id: ID of the parent chunk this statement belongs to
|
||||
stmt_type: Type of the statement (from ontology)
|
||||
temporal_info: Temporal information extracted from the statement
|
||||
statement: The actual statement text content
|
||||
connect_strength: Classification of connection strength ('Strong' or 'Weak')
|
||||
emotion_intensity: Optional emotion intensity (0.0-1.0) - displayed on node
|
||||
emotion_target: Optional emotion target (person or object name)
|
||||
emotion_subject: Optional emotion subject (self/other/object)
|
||||
emotion_type: Optional emotion type (joy/sadness/anger/fear/surprise/neutral)
|
||||
emotion_keywords: Optional list of emotion keywords (max 3)
|
||||
temporal_info: Temporal information extracted from the statement
|
||||
valid_at: Optional start date of temporal validity
|
||||
invalid_at: Optional end date of temporal validity
|
||||
statement_embedding: Optional embedding vector for the statement
|
||||
chunk_embedding: Optional embedding vector for the parent chunk
|
||||
connect_strength: Classification of connection strength ('Strong' or 'Weak')
|
||||
config_id: Configuration ID used to process this statement
|
||||
"""
|
||||
# Core fields (ordered as requested)
|
||||
chunk_id: str = Field(..., description="ID of the parent chunk")
|
||||
stmt_type: str = Field(..., description="Type of the statement")
|
||||
temporal_info: TemporalInfo = Field(..., description="Temporal information")
|
||||
statement: str = Field(..., description="The statement text content")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
|
||||
|
||||
# Emotion fields (ordered as requested, emotion_intensity first for display)
|
||||
emotion_intensity: Optional[float] = Field(
|
||||
None,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Emotion intensity: 0.0-1.0 (displayed on node)"
|
||||
)
|
||||
emotion_target: Optional[str] = Field(
|
||||
None,
|
||||
description="Emotion target: person or object name"
|
||||
)
|
||||
emotion_subject: Optional[str] = Field(
|
||||
None,
|
||||
description="Emotion subject: self/other/object"
|
||||
)
|
||||
emotion_type: Optional[str] = Field(
|
||||
None,
|
||||
description="Emotion type: joy/sadness/anger/fear/surprise/neutral"
|
||||
)
|
||||
emotion_keywords: Optional[List[str]] = Field(
|
||||
default_factory=list,
|
||||
description="Emotion keywords list, max 3 items"
|
||||
)
|
||||
|
||||
# Temporal fields
|
||||
temporal_info: TemporalInfo = Field(..., description="Temporal information")
|
||||
valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
|
||||
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
|
||||
|
||||
# Embedding and other fields
|
||||
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
|
||||
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)")
|
||||
|
||||
@field_validator('valid_at', 'invalid_at', mode='before')
|
||||
@@ -240,6 +274,39 @@ class StatementNode(Node):
|
||||
def validate_datetime(cls, v):
|
||||
"""使用通用的历史日期解析函数"""
|
||||
return parse_historical_datetime(v)
|
||||
|
||||
@field_validator('emotion_type', mode='before')
|
||||
@classmethod
|
||||
def validate_emotion_type(cls, v):
    """Accept ``None`` or one of the known emotion types; reject anything else."""
    valid_types = ['joy', 'sadness', 'anger', 'fear', 'surprise', 'neutral']
    # ``None`` means "no emotion recorded" and is passed through untouched.
    if v is not None and v not in valid_types:
        raise ValueError(f"emotion_type must be one of {valid_types}, got {v}")
    return v
|
||||
|
||||
@field_validator('emotion_subject', mode='before')
|
||||
@classmethod
|
||||
def validate_emotion_subject(cls, v):
    """Accept ``None`` or one of the known emotion subjects; reject anything else."""
    valid_subjects = ['self', 'other', 'object']
    # ``None`` means "no subject recorded" and is passed through untouched.
    if v is not None and v not in valid_subjects:
        raise ValueError(f"emotion_subject must be one of {valid_subjects}, got {v}")
    return v
|
||||
|
||||
@field_validator('emotion_keywords', mode='before')
|
||||
@classmethod
|
||||
def validate_emotion_keywords(cls, v):
    """Return at most the first three keywords, or ``[]`` for any non-list input."""
    # ``None`` and every other non-list value collapse to an empty list.
    return v[:3] if isinstance(v, list) else []
|
||||
|
||||
|
||||
class ChunkNode(Node):
|
||||
|
||||
@@ -64,6 +64,11 @@ class Statement(BaseModel):
|
||||
connect_strength: Optional connection strength ('Strong' or 'Weak')
|
||||
temporal_validity: Optional temporal validity range
|
||||
triplet_extraction_info: Optional triplet extraction results
|
||||
emotion_type: Optional emotion type (joy/sadness/anger/fear/surprise/neutral)
|
||||
emotion_intensity: Optional emotion intensity (0.0-1.0)
|
||||
emotion_keywords: Optional list of emotion keywords
|
||||
emotion_subject: Optional emotion subject (self/other/object)
|
||||
emotion_target: Optional emotion target (person or object name)
|
||||
"""
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.")
|
||||
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
|
||||
@@ -80,6 +85,12 @@ class Statement(BaseModel):
|
||||
triplet_extraction_info: Optional[TripletExtractionResponse] = Field(
|
||||
None, description="The triplet extraction information of the statement."
|
||||
)
|
||||
# Emotion fields
|
||||
emotion_type: Optional[str] = Field(None, description="Emotion type: joy/sadness/anger/fear/surprise/neutral")
|
||||
emotion_intensity: Optional[float] = Field(None, ge=0.0, le=1.0, description="Emotion intensity: 0.0-1.0")
|
||||
emotion_keywords: Optional[List[str]] = Field(default_factory=list, description="Emotion keywords, max 3")
|
||||
emotion_subject: Optional[str] = Field(None, description="Emotion subject: self/other/object")
|
||||
emotion_target: Optional[str] = Field(None, description="Emotion target: person or object name")
|
||||
|
||||
|
||||
class ConversationContext(BaseModel):
|
||||
|
||||
@@ -480,7 +480,6 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
||||
- global_redirect: dict losing_id -> canonical_id accumulated across rounds
|
||||
- records: textual logs including per-round/per-block summaries and per-pair decisions
|
||||
"""
|
||||
import asyncio
|
||||
import random
|
||||
# 初始化全局日志和全局ID映射(存储所有轮次的结果)
|
||||
records: List[str] = []
|
||||
|
||||
@@ -36,7 +36,6 @@ from app.core.memory.models.graph_models import (
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.models.variate_config import (
|
||||
ExtractionPipelineConfig,
|
||||
StatementExtractionConfig,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import (
|
||||
dedup_layers_and_merge_and_return,
|
||||
@@ -182,11 +181,12 @@ class ExtractionOrchestrator:
|
||||
all_statements_list.extend(chunk.statements)
|
||||
total_statements = len(all_statements_list)
|
||||
|
||||
# 步骤 2: 并行执行三元组提取、时间信息提取和基础嵌入生成
|
||||
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取和嵌入生成")
|
||||
# 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成
|
||||
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成")
|
||||
(
|
||||
triplet_maps,
|
||||
temporal_maps,
|
||||
emotion_maps,
|
||||
statement_embedding_maps,
|
||||
chunk_embedding_maps,
|
||||
dialog_embeddings,
|
||||
@@ -209,78 +209,13 @@ class ExtractionOrchestrator:
|
||||
logger.info("步骤 3/6: 生成实体嵌入")
|
||||
triplet_maps = await self._generate_entity_embeddings(triplet_maps)
|
||||
|
||||
# 进度回调:按三个阶段分别输出知识抽取结果
|
||||
if self.progress_callback:
|
||||
# 第一阶段:陈述句提取结果
|
||||
for i, stmt in enumerate(all_statements_list[:10]): # 只输出前10个陈述句
|
||||
stmt_result = {
|
||||
"extraction_type": "statement",
|
||||
"statement_index": i + 1,
|
||||
"statement": stmt.statement,
|
||||
"statement_id": stmt.id
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_result", "陈述句提取完成", stmt_result)
|
||||
|
||||
# 第二阶段:三元组提取结果
|
||||
for i, triplet in enumerate(all_triplets_list[:10]): # 只输出前10个三元组
|
||||
triplet_result = {
|
||||
"extraction_type": "triplet",
|
||||
"triplet_index": i + 1,
|
||||
"subject": triplet.subject_name,
|
||||
"predicate": triplet.predicate,
|
||||
"object": triplet.object_name
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_result", "三元组提取完成", triplet_result)
|
||||
|
||||
# 第三阶段:时间提取结果
|
||||
if total_temporal > 0:
|
||||
# 收集时间信息
|
||||
temporal_results = []
|
||||
for dialog in dialog_data_list:
|
||||
for chunk in dialog.chunks:
|
||||
for statement in chunk.statements:
|
||||
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
|
||||
temporal_results.append({
|
||||
"statement_id": statement.id,
|
||||
"statement": statement.statement,
|
||||
"valid_at": statement.temporal_validity.valid_at,
|
||||
"invalid_at": statement.temporal_validity.invalid_at
|
||||
})
|
||||
|
||||
# 输出时间提取结果
|
||||
for i, temporal_result in enumerate(temporal_results[:5]): # 只输出前5个时间提取结果
|
||||
time_result = {
|
||||
"extraction_type": "temporal",
|
||||
"temporal_index": i + 1,
|
||||
"statement": temporal_result["statement"],
|
||||
"valid_at": temporal_result["valid_at"],
|
||||
"invalid_at": temporal_result["invalid_at"]
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result)
|
||||
else:
|
||||
# 如果没有时间信息,也发送一个时间提取完成的消息
|
||||
time_result = {
|
||||
"extraction_type": "temporal",
|
||||
"temporal_index": 0,
|
||||
"message": "未发现时间信息"
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result)
|
||||
|
||||
# 进度回调:知识抽取完成,传递知识抽取的统计信息
|
||||
extraction_stats = {
|
||||
"statements_count": total_statements,
|
||||
"entities_count": total_entities,
|
||||
"triplets_count": total_triplets,
|
||||
"temporal_ranges_count": total_temporal,
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_complete", "知识抽取完成", extraction_stats)
|
||||
|
||||
# 步骤 4: 将提取的数据赋值到语句
|
||||
logger.info("步骤 4/6: 数据赋值")
|
||||
dialog_data_list = await self._assign_extracted_data(
|
||||
dialog_data_list,
|
||||
temporal_maps,
|
||||
triplet_maps,
|
||||
emotion_maps,
|
||||
statement_embedding_maps,
|
||||
chunk_embedding_maps,
|
||||
dialog_embeddings,
|
||||
@@ -288,6 +223,9 @@ class ExtractionOrchestrator:
|
||||
|
||||
# 步骤 5: 创建节点和边
|
||||
logger.info("步骤 5/6: 创建节点和边")
|
||||
|
||||
# 注意:creating_nodes_edges 消息已在知识抽取完成后立即发送
|
||||
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
@@ -307,6 +245,8 @@ class ExtractionOrchestrator:
|
||||
else:
|
||||
logger.info("步骤 6/6: 两阶段去重和消歧")
|
||||
|
||||
# 注意:deduplication 消息已在创建节点和边完成后立即发送
|
||||
|
||||
result = await self._run_dedup_and_write_summary(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
@@ -331,7 +271,7 @@ class ExtractionOrchestrator:
|
||||
self, dialog_data_list: List[DialogData]
|
||||
) -> List[DialogData]:
|
||||
"""
|
||||
从对话中提取陈述句(优化版:全局分块级并行)
|
||||
从对话中提取陈述句(流式输出版本:边提取边发送进度)
|
||||
|
||||
Args:
|
||||
dialog_data_list: 对话数据列表
|
||||
@@ -339,7 +279,7 @@ class ExtractionOrchestrator:
|
||||
Returns:
|
||||
更新后的对话数据列表(包含提取的陈述句)
|
||||
"""
|
||||
logger.info("开始陈述句提取(全局分块级并行)")
|
||||
logger.info("开始陈述句提取(全局分块级并行 + 流式输出)")
|
||||
|
||||
# 收集所有分块及其元数据
|
||||
all_chunks = []
|
||||
@@ -352,17 +292,44 @@ class ExtractionOrchestrator:
|
||||
chunk_metadata.append((d_idx, c_idx))
|
||||
|
||||
logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取")
|
||||
|
||||
# 用于跟踪已完成的分块数量
|
||||
completed_chunks = 0
|
||||
total_chunks = len(all_chunks)
|
||||
|
||||
# 全局并行处理所有分块
|
||||
async def extract_for_chunk(chunk_data):
|
||||
async def extract_for_chunk(chunk_data, chunk_index):
|
||||
nonlocal completed_chunks
|
||||
chunk, group_id, dialogue_content = chunk_data
|
||||
try:
|
||||
return await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content)
|
||||
statements = await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content)
|
||||
|
||||
# 流式输出:每提取完一个分块的陈述句,立即发送进度
|
||||
# 注意:只在试运行模式下发送陈述句详情,正式模式不发送
|
||||
completed_chunks += 1
|
||||
if self.progress_callback and statements and self.is_pilot_run:
|
||||
# 发送前3个陈述句作为示例
|
||||
for idx, stmt in enumerate(statements[:3]):
|
||||
stmt_result = {
|
||||
"extraction_type": "statement",
|
||||
"statement": stmt.statement,
|
||||
"statement_id": stmt.id,
|
||||
"chunk_progress": f"{completed_chunks}/{total_chunks}",
|
||||
"statement_index_in_chunk": idx + 1
|
||||
}
|
||||
await self.progress_callback(
|
||||
"knowledge_extraction_result",
|
||||
f"陈述句提取中 ({completed_chunks}/{total_chunks})",
|
||||
stmt_result
|
||||
)
|
||||
|
||||
return statements
|
||||
except Exception as e:
|
||||
logger.error(f"分块 {chunk.id} 陈述句提取失败: {e}")
|
||||
completed_chunks += 1
|
||||
return []
|
||||
|
||||
tasks = [extract_for_chunk(chunk_data) for chunk_data in all_chunks]
|
||||
tasks = [extract_for_chunk(chunk_data, i) for i, chunk_data in enumerate(all_chunks)]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 将结果分配回对话
|
||||
@@ -394,7 +361,7 @@ class ExtractionOrchestrator:
|
||||
self, dialog_data_list: List[DialogData]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从对话中提取三元组(优化版:全局陈述句级并行)
|
||||
从对话中提取三元组(流式输出版本:边提取边发送进度)
|
||||
|
||||
Args:
|
||||
dialog_data_list: 对话数据列表
|
||||
@@ -402,7 +369,7 @@ class ExtractionOrchestrator:
|
||||
Returns:
|
||||
三元组映射列表,每个对话对应一个字典
|
||||
"""
|
||||
logger.info("开始三元组提取(全局陈述句级并行)")
|
||||
logger.info("开始三元组提取(全局陈述句级并行 + 流式输出)")
|
||||
|
||||
# 收集所有陈述句及其元数据
|
||||
all_statements = []
|
||||
@@ -415,20 +382,32 @@ class ExtractionOrchestrator:
|
||||
statement_metadata.append((d_idx, statement.id))
|
||||
|
||||
logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取三元组")
|
||||
|
||||
# 用于跟踪已完成的陈述句数量
|
||||
completed_statements = 0
|
||||
total_statements = len(all_statements)
|
||||
|
||||
# 全局并行处理所有陈述句
|
||||
async def extract_for_statement(stmt_data):
|
||||
async def extract_for_statement(stmt_data, stmt_index):
|
||||
nonlocal completed_statements
|
||||
statement, chunk_content = stmt_data
|
||||
try:
|
||||
return await self.triplet_extractor._extract_triplets(statement, chunk_content)
|
||||
triplet_info = await self.triplet_extractor._extract_triplets(statement, chunk_content)
|
||||
|
||||
# 注意:不再发送三元组提取的流式输出
|
||||
# 三元组提取在后台执行,但不向前端发送详细信息
|
||||
completed_statements += 1
|
||||
|
||||
return triplet_info
|
||||
except Exception as e:
|
||||
logger.error(f"陈述句 {statement.id} 三元组提取失败: {e}")
|
||||
completed_statements += 1
|
||||
from app.core.memory.models.triplet_models import (
|
||||
TripletExtractionResponse,
|
||||
)
|
||||
return TripletExtractionResponse(triplets=[], entities=[])
|
||||
|
||||
tasks = [extract_for_statement(stmt_data) for stmt_data in all_statements]
|
||||
tasks = [extract_for_statement(stmt_data, i) for i, stmt_data in enumerate(all_statements)]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 将结果组织成对话级别的映射
|
||||
@@ -465,7 +444,7 @@ class ExtractionOrchestrator:
|
||||
self, dialog_data_list: List[DialogData]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从对话中提取时间信息(优化版:全局陈述句级并行)
|
||||
从对话中提取时间信息(流式输出版本:边提取边发送进度)
|
||||
|
||||
Args:
|
||||
dialog_data_list: 对话数据列表
|
||||
@@ -473,7 +452,21 @@ class ExtractionOrchestrator:
|
||||
Returns:
|
||||
时间信息映射列表,每个对话对应一个字典
|
||||
"""
|
||||
logger.info("开始时间信息提取(全局陈述句级并行)")
|
||||
# 试运行模式:跳过时间提取以节省时间
|
||||
if self.is_pilot_run:
|
||||
logger.info("试运行模式:跳过时间信息提取(节省约 10-15 秒)")
|
||||
# 为所有陈述句返回空的时间范围
|
||||
from app.core.memory.models.message_models import TemporalValidityRange
|
||||
temporal_maps = []
|
||||
for dialog in dialog_data_list:
|
||||
temporal_map = {}
|
||||
for chunk in dialog.chunks:
|
||||
for statement in chunk.statements:
|
||||
temporal_map[statement.id] = TemporalValidityRange(valid_at=None, invalid_at=None)
|
||||
temporal_maps.append(temporal_map)
|
||||
return temporal_maps
|
||||
|
||||
logger.info("开始时间信息提取(全局陈述句级并行 + 流式输出)")
|
||||
|
||||
# 收集所有需要提取时间的陈述句
|
||||
all_statements = []
|
||||
@@ -501,18 +494,30 @@ class ExtractionOrchestrator:
|
||||
statement_metadata.append((d_idx, statement.id))
|
||||
|
||||
logger.info(f"收集到 {len(all_statements)} 个需要时间提取的陈述句,开始全局并行提取")
|
||||
|
||||
# 用于跟踪已完成的时间提取数量
|
||||
completed_temporal = 0
|
||||
total_temporal_statements = len(all_statements)
|
||||
|
||||
# 全局并行处理所有陈述句
|
||||
async def extract_for_statement(stmt_data):
|
||||
async def extract_for_statement(stmt_data, stmt_index):
|
||||
nonlocal completed_temporal
|
||||
statement, ref_dates = stmt_data
|
||||
try:
|
||||
return await self.temporal_extractor._extract_temporal_ranges(statement, ref_dates)
|
||||
temporal_range = await self.temporal_extractor._extract_temporal_ranges(statement, ref_dates)
|
||||
|
||||
# 注意:不再发送时间提取的流式输出
|
||||
# 时间提取在后台执行,但不向前端发送详细信息
|
||||
completed_temporal += 1
|
||||
|
||||
return temporal_range
|
||||
except Exception as e:
|
||||
logger.error(f"陈述句 {statement.id} 时间信息提取失败: {e}")
|
||||
completed_temporal += 1
|
||||
from app.core.memory.models.message_models import TemporalValidityRange
|
||||
return TemporalValidityRange(valid_at=None, invalid_at=None)
|
||||
|
||||
tasks = [extract_for_statement(stmt_data) for stmt_data in all_statements]
|
||||
tasks = [extract_for_statement(stmt_data, i) for i, stmt_data in enumerate(all_statements)]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 将结果组织成对话级别的映射
|
||||
@@ -542,9 +547,108 @@ class ExtractionOrchestrator:
|
||||
|
||||
return temporal_maps
|
||||
|
||||
async def _extract_emotions(
    self, dialog_data_list: List[DialogData]
) -> List[Dict[str, Any]]:
    """
    Extract emotion information for every statement of the given dialogs.

    All statements of all dialogs are gathered into one flat list and
    processed concurrently with ``asyncio.gather``.

    Extraction is skipped (one empty dict per dialog is returned) when:
    - no ``config_id`` can be read from the first dialog,
    - loading the DataConfig raises,
    - no DataConfig is found, or its ``emotion_enabled`` flag is falsy.

    Args:
        dialog_data_list: Dialogs whose chunks/statements are processed.
            Only the first dialog's ``config_id`` is used to load the config.

    Returns:
        One dict per dialog (same order as the input) mapping statement id
        to the value returned by the emotion service, or ``None`` when the
        extraction of that statement failed.
    """
    logger.info("开始情绪信息提取(全局陈述句级并行)")

    def _empty_maps() -> List[Dict[str, Any]]:
        # "Nothing extracted" result: one independent empty dict per dialog.
        return [{} for _ in dialog_data_list]

    # The configuration is resolved through the first dialog only.
    config_id = None
    if dialog_data_list and hasattr(dialog_data_list[0], 'config_id'):
        config_id = dialog_data_list[0].config_id

    if not config_id:
        logger.info("未找到config_id,跳过情绪提取")
        return _empty_maps()

    # Load the DataConfig; any failure here downgrades to "skip extraction".
    data_config = None
    try:
        # Imported lazily, as in the rest of this method's dependencies.
        from app.db import SessionLocal
        from app.repositories.data_config_repository import DataConfigRepository

        session = SessionLocal()
        try:
            data_config = DataConfigRepository.get_by_id(session, config_id)
        finally:
            # Always release the session, even if the lookup raised.
            session.close()

        if data_config and not data_config.emotion_enabled:
            logger.info("情绪提取已在配置中禁用,跳过情绪提取")
            return _empty_maps()
    except Exception as e:
        logger.warning(f"加载DataConfig失败: {e},将跳过情绪提取")
        return _empty_maps()

    # Covers the "config not found" case (and re-checks the enabled flag).
    if not (data_config and data_config.emotion_enabled):
        logger.info("情绪提取未启用,跳过")
        return _empty_maps()

    # Flatten all statements; owners[k] locates pending[k] as
    # (dialog index, statement id) so results can be routed back.
    pending = []
    owners = []
    for dialog_idx, dialog in enumerate(dialog_data_list):
        for chunk in dialog.chunks:
            for statement in chunk.statements:
                pending.append(statement)
                owners.append((dialog_idx, statement.id))

    logger.info(f"收集到 {len(pending)} 个陈述句,开始全局并行提取情绪")

    # Build the emotion extraction service with the configured model (if any).
    from app.services.emotion_extraction_service import EmotionExtractionService
    emotion_service = EmotionExtractionService(
        llm_id=data_config.emotion_model_id or None
    )

    async def _extract_one(statement):
        # A failure of a single statement is logged and reported as None so
        # that it does not affect the other statements.
        try:
            return await emotion_service.extract_emotion(statement.statement, data_config)
        except Exception as e:
            logger.error(f"陈述句 {statement.id} 情绪提取失败: {e}")
            return None

    # Run every statement concurrently; exceptions come back as values.
    results = await asyncio.gather(
        *[_extract_one(statement) for statement in pending],
        return_exceptions=True,
    )

    # Route each result back to the dialog that owns the statement.
    emotion_maps = _empty_maps()
    successful_extractions = 0
    for (dialog_idx, stmt_id), outcome in zip(owners, results):
        if isinstance(outcome, Exception):
            logger.error(f"陈述句处理异常: {outcome}")
            emotion_maps[dialog_idx][stmt_id] = None
            continue
        emotion_maps[dialog_idx][stmt_id] = outcome
        if outcome is not None:
            successful_extractions += 1

    # Summary of the extraction run.
    logger.info(f"情绪信息提取完成,共成功提取 {successful_extractions}/{len(pending)} 个情绪")

    return emotion_maps
|
||||
|
||||
async def _parallel_extract_and_embed(
|
||||
self, dialog_data_list: List[DialogData]
|
||||
) -> Tuple[
|
||||
List[Dict[str, Any]],
|
||||
List[Dict[str, Any]],
|
||||
List[Dict[str, Any]],
|
||||
List[Dict[str, List[float]]],
|
||||
@@ -552,35 +656,39 @@ class ExtractionOrchestrator:
|
||||
List[List[float]],
|
||||
]:
|
||||
"""
|
||||
并行执行三元组提取、时间信息提取和基础嵌入生成
|
||||
并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成
|
||||
|
||||
这三个任务都依赖陈述句提取的结果,但彼此独立,可以并行执行:
|
||||
这四个任务都依赖陈述句提取的结果,但彼此独立,可以并行执行:
|
||||
- 三元组提取:从陈述句中提取实体和关系
|
||||
- 时间信息提取:从陈述句中提取时间范围
|
||||
- 情绪提取:从陈述句中提取情绪信息
|
||||
- 嵌入生成:为陈述句、分块和对话生成向量(不依赖三元组)
|
||||
|
||||
Args:
|
||||
dialog_data_list: 对话数据列表
|
||||
|
||||
Returns:
|
||||
五个列表的元组:
|
||||
六个列表的元组:
|
||||
- 三元组映射列表
|
||||
- 时间信息映射列表
|
||||
- 情绪映射列表
|
||||
- 陈述句嵌入映射列表
|
||||
- 分块嵌入映射列表
|
||||
- 对话嵌入列表
|
||||
"""
|
||||
logger.info("并行执行:三元组提取 + 时间信息提取 + 基础嵌入生成")
|
||||
logger.info("并行执行:三元组提取 + 时间信息提取 + 情绪提取 + 基础嵌入生成")
|
||||
|
||||
# 创建三个并行任务
|
||||
# 创建四个并行任务
|
||||
triplet_task = self._extract_triplets(dialog_data_list)
|
||||
temporal_task = self._extract_temporal(dialog_data_list)
|
||||
emotion_task = self._extract_emotions(dialog_data_list)
|
||||
embedding_task = self._generate_basic_embeddings(dialog_data_list)
|
||||
|
||||
# 并行执行
|
||||
results = await asyncio.gather(
|
||||
triplet_task,
|
||||
temporal_task,
|
||||
emotion_task,
|
||||
embedding_task,
|
||||
return_exceptions=True
|
||||
)
|
||||
@@ -588,19 +696,21 @@ class ExtractionOrchestrator:
|
||||
# 解包结果
|
||||
triplet_maps = results[0] if not isinstance(results[0], Exception) else [{} for _ in dialog_data_list]
|
||||
temporal_maps = results[1] if not isinstance(results[1], Exception) else [{} for _ in dialog_data_list]
|
||||
emotion_maps = results[2] if not isinstance(results[2], Exception) else [{} for _ in dialog_data_list]
|
||||
|
||||
if isinstance(results[2], Exception):
|
||||
logger.error(f"基础嵌入生成失败: {results[2]}")
|
||||
if isinstance(results[3], Exception):
|
||||
logger.error(f"基础嵌入生成失败: {results[3]}")
|
||||
statement_embedding_maps = [{} for _ in dialog_data_list]
|
||||
chunk_embedding_maps = [{} for _ in dialog_data_list]
|
||||
dialog_embeddings = [[] for _ in dialog_data_list]
|
||||
else:
|
||||
statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = results[2]
|
||||
statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = results[3]
|
||||
|
||||
logger.info("并行任务执行完成")
|
||||
return (
|
||||
triplet_maps,
|
||||
temporal_maps,
|
||||
emotion_maps,
|
||||
statement_embedding_maps,
|
||||
chunk_embedding_maps,
|
||||
dialog_embeddings,
|
||||
@@ -711,6 +821,7 @@ class ExtractionOrchestrator:
|
||||
dialog_data_list: List[DialogData],
|
||||
temporal_maps: List[Dict[str, Any]],
|
||||
triplet_maps: List[Dict[str, Any]],
|
||||
emotion_maps: List[Dict[str, Any]],
|
||||
statement_embedding_maps: List[Dict[str, List[float]]],
|
||||
chunk_embedding_maps: List[Dict[str, List[float]]],
|
||||
dialog_embeddings: List[List[float]],
|
||||
@@ -722,6 +833,7 @@ class ExtractionOrchestrator:
|
||||
dialog_data_list: 对话数据列表
|
||||
temporal_maps: 时间信息映射列表
|
||||
triplet_maps: 三元组映射列表
|
||||
emotion_maps: 情绪信息映射列表
|
||||
statement_embedding_maps: 陈述句嵌入映射列表
|
||||
chunk_embedding_maps: 分块嵌入映射列表
|
||||
dialog_embeddings: 对话嵌入列表
|
||||
@@ -736,6 +848,7 @@ class ExtractionOrchestrator:
|
||||
if (
|
||||
len(temporal_maps) != expected_length
|
||||
or len(triplet_maps) != expected_length
|
||||
or len(emotion_maps) != expected_length
|
||||
or len(statement_embedding_maps) != expected_length
|
||||
or len(chunk_embedding_maps) != expected_length
|
||||
or len(dialog_embeddings) != expected_length
|
||||
@@ -743,6 +856,7 @@ class ExtractionOrchestrator:
|
||||
logger.warning(
|
||||
f"数据大小不匹配 - 对话: {len(dialog_data_list)}, "
|
||||
f"时间映射: {len(temporal_maps)}, 三元组映射: {len(triplet_maps)}, "
|
||||
f"情绪映射: {len(emotion_maps)}, "
|
||||
f"陈述句嵌入: {len(statement_embedding_maps)}, "
|
||||
f"分块嵌入: {len(chunk_embedding_maps)}, "
|
||||
f"对话嵌入: {len(dialog_embeddings)}"
|
||||
@@ -751,6 +865,7 @@ class ExtractionOrchestrator:
|
||||
total_statements = 0
|
||||
assigned_temporal = 0
|
||||
assigned_triplets = 0
|
||||
assigned_emotions = 0
|
||||
assigned_statement_embeddings = 0
|
||||
assigned_chunk_embeddings = 0
|
||||
assigned_dialog_embeddings = 0
|
||||
@@ -758,12 +873,13 @@ class ExtractionOrchestrator:
|
||||
# 处理每个对话
|
||||
for i, dialog_data in enumerate(dialog_data_list):
|
||||
# 检查是否有缺失的数据
|
||||
if i >= len(temporal_maps) or i >= len(triplet_maps):
|
||||
if i >= len(temporal_maps) or i >= len(triplet_maps) or i >= len(emotion_maps):
|
||||
logger.warning(f"对话 {dialog_data.id} 缺少提取数据,跳过赋值")
|
||||
continue
|
||||
|
||||
temporal_map = temporal_maps[i]
|
||||
triplet_map = triplet_maps[i]
|
||||
emotion_map = emotion_maps[i]
|
||||
statement_embedding_map = statement_embedding_maps[i] if i < len(statement_embedding_maps) else {}
|
||||
chunk_embedding_map = chunk_embedding_maps[i] if i < len(chunk_embedding_maps) else {}
|
||||
dialog_embedding = dialog_embeddings[i] if i < len(dialog_embeddings) else []
|
||||
@@ -794,6 +910,18 @@ class ExtractionOrchestrator:
|
||||
statement.triplet_extraction_info = triplet_map[statement.id]
|
||||
assigned_triplets += 1
|
||||
|
||||
# 赋值情绪信息
|
||||
if statement.id in emotion_map:
|
||||
emotion_data = emotion_map[statement.id]
|
||||
if emotion_data is not None:
|
||||
# 将EmotionExtraction对象的字段赋值到Statement
|
||||
statement.emotion_type = emotion_data.emotion_type
|
||||
statement.emotion_intensity = emotion_data.emotion_intensity
|
||||
statement.emotion_keywords = emotion_data.emotion_keywords
|
||||
statement.emotion_subject = emotion_data.emotion_subject
|
||||
statement.emotion_target = emotion_data.emotion_target
|
||||
assigned_emotions += 1
|
||||
|
||||
# 赋值陈述句嵌入
|
||||
if statement.id in statement_embedding_map:
|
||||
statement.statement_embedding = statement_embedding_map[statement.id]
|
||||
@@ -802,6 +930,7 @@ class ExtractionOrchestrator:
|
||||
logger.info(
|
||||
f"数据赋值完成 - 总陈述句: {total_statements}, "
|
||||
f"时间信息: {assigned_temporal}, 三元组: {assigned_triplets}, "
|
||||
f"情绪信息: {assigned_emotions}, "
|
||||
f"陈述句嵌入: {assigned_statement_embeddings}, "
|
||||
f"分块嵌入: {assigned_chunk_embeddings}, "
|
||||
f"对话嵌入: {assigned_dialog_embeddings}"
|
||||
@@ -833,9 +962,7 @@ class ExtractionOrchestrator:
|
||||
"""
|
||||
logger.info("开始创建节点和边")
|
||||
|
||||
# 进度回调:正在创建节点和边
|
||||
if self.progress_callback:
|
||||
await self.progress_callback("creating_nodes_edges", "正在创建节点和边...")
|
||||
# 注意:开始消息已在 run 方法中发送,这里不再重复发送
|
||||
|
||||
dialogue_nodes = []
|
||||
chunk_nodes = []
|
||||
@@ -847,8 +974,13 @@ class ExtractionOrchestrator:
|
||||
|
||||
# 用于去重的集合
|
||||
entity_id_set = set()
|
||||
|
||||
# 用于跟踪进度
|
||||
total_dialogs = len(dialog_data_list)
|
||||
processed_dialogs = 0
|
||||
|
||||
for dialog_data in dialog_data_list:
|
||||
processed_dialogs += 1
|
||||
# 创建对话节点
|
||||
dialogue_node = DialogueNode(
|
||||
id=dialog_data.id,
|
||||
@@ -908,6 +1040,12 @@ class ExtractionOrchestrator:
|
||||
created_at=dialog_data.created_at,
|
||||
expired_at=dialog_data.expired_at,
|
||||
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
|
||||
# Emotion fields
|
||||
emotion_type=getattr(statement, 'emotion_type', None),
|
||||
emotion_intensity=getattr(statement, 'emotion_intensity', None),
|
||||
emotion_keywords=getattr(statement, 'emotion_keywords', None),
|
||||
emotion_subject=getattr(statement, 'emotion_subject', None),
|
||||
emotion_target=getattr(statement, 'emotion_target', None),
|
||||
)
|
||||
statement_nodes.append(statement_node)
|
||||
|
||||
@@ -995,6 +1133,26 @@ class ExtractionOrchestrator:
|
||||
expired_at=dialog_data.expired_at,
|
||||
)
|
||||
entity_entity_edges.append(entity_entity_edge)
|
||||
|
||||
# 流式输出:每创建一个关系边,立即发送进度(限制发送数量)
|
||||
if self.progress_callback and len(entity_entity_edges) <= 10:
|
||||
# 获取实体名称
|
||||
source_name = triplet.subject_name
|
||||
target_name = triplet.object_name
|
||||
relationship_result = {
|
||||
"result_type": "relationship_creation",
|
||||
"relationship_index": len(entity_entity_edges),
|
||||
"source_entity": source_name,
|
||||
"relation_type": triplet.predicate,
|
||||
"target_entity": target_name,
|
||||
"relationship_text": f"{source_name} -[{triplet.predicate}]-> {target_name}",
|
||||
"dialog_progress": f"{processed_dialogs}/{total_dialogs}"
|
||||
}
|
||||
await self.progress_callback(
|
||||
"creating_nodes_edges_result",
|
||||
f"关系创建中 ({processed_dialogs}/{total_dialogs})",
|
||||
relationship_result
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"跳过三元组 - 无法找到实体ID: subject_id={triplet.subject_id}, "
|
||||
@@ -1009,12 +1167,9 @@ class ExtractionOrchestrator:
|
||||
f"实体-实体边: {len(entity_entity_edges)}"
|
||||
)
|
||||
|
||||
# 进度回调:只输出关系创建结果
|
||||
# 进度回调:创建节点和边完成,传递结果统计
|
||||
# 注意:具体的关系创建结果已经在创建过程中实时发送了
|
||||
if self.progress_callback:
|
||||
# 输出关系创建结果
|
||||
await self._output_relationship_creation_results(entity_entity_edges, entity_nodes)
|
||||
|
||||
# 进度回调:创建节点和边完成,传递结果统计
|
||||
nodes_edges_stats = {
|
||||
"dialogue_nodes_count": len(dialogue_nodes),
|
||||
"chunk_nodes_count": len(chunk_nodes),
|
||||
@@ -1072,7 +1227,7 @@ class ExtractionOrchestrator:
|
||||
"""
|
||||
logger.info("开始两阶段实体去重和消歧")
|
||||
|
||||
# 进度回调:正在去重消歧
|
||||
# 进度回调:发送去重消歧开始消息
|
||||
if self.progress_callback:
|
||||
await self.progress_callback("deduplication", "正在去重消歧...")
|
||||
|
||||
@@ -1157,25 +1312,26 @@ class ExtractionOrchestrator:
|
||||
f"实体-实体边减少 {len(entity_entity_edges) - len(final_entity_entity_edges)}"
|
||||
)
|
||||
|
||||
# 进度回调:输出去重消歧的具体结果
|
||||
# 流式输出:实时输出去重消歧的具体结果
|
||||
if self.progress_callback:
|
||||
# 分析实体合并情况
|
||||
# 分析实体合并情况(使用内存中的记录)
|
||||
merge_info = await self._analyze_entity_merges(entity_nodes, final_entity_nodes)
|
||||
|
||||
# 输出去重合并的实体示例
|
||||
# 逐个输出去重合并的实体示例
|
||||
for i, merge_detail in enumerate(merge_info[:5]): # 输出前5个去重结果
|
||||
dedup_result = {
|
||||
"result_type": "entity_merge",
|
||||
"merged_entity_name": merge_detail["main_entity_name"],
|
||||
"merged_count": merge_detail["merged_count"],
|
||||
"merge_progress": f"{i + 1}/{min(len(merge_info), 5)}",
|
||||
"message": f"{merge_detail['main_entity_name']}合并{merge_detail['merged_count']}个:相似实体已合并"
|
||||
}
|
||||
await self.progress_callback("dedup_disambiguation_result", "实体去重完成", dedup_result)
|
||||
await self.progress_callback("dedup_disambiguation_result", "实体去重中", dedup_result)
|
||||
|
||||
# 分析实体消歧情况
|
||||
# 分析实体消歧情况(使用内存中的记录)
|
||||
disamb_info = await self._analyze_entity_disambiguation(entity_nodes, final_entity_nodes)
|
||||
|
||||
# 输出实体消歧的结果
|
||||
# 逐个输出实体消歧的结果
|
||||
for i, disamb_detail in enumerate(disamb_info[:5]): # 输出前5个消歧结果
|
||||
disamb_result = {
|
||||
"result_type": "entity_disambiguation",
|
||||
@@ -1183,11 +1339,10 @@ class ExtractionOrchestrator:
|
||||
"disambiguation_type": disamb_detail["disamb_type"],
|
||||
"confidence": disamb_detail.get("confidence", "unknown"),
|
||||
"reason": disamb_detail.get("reason", ""),
|
||||
"disamb_progress": f"{i + 1}/{min(len(disamb_info), 5)}",
|
||||
"message": f"{disamb_detail['entity_name']}消歧完成:{disamb_detail['disamb_type']}"
|
||||
}
|
||||
await self.progress_callback("dedup_disambiguation_result", "实体消歧完成", disamb_result)
|
||||
|
||||
|
||||
await self.progress_callback("dedup_disambiguation_result", "实体消歧中", disamb_result)
|
||||
|
||||
# 进度回调:去重消歧完成,传递去重和消歧的具体效果
|
||||
await self._send_dedup_progress_callback(
|
||||
@@ -1299,7 +1454,7 @@ class ExtractionOrchestrator:
|
||||
if match:
|
||||
entity1_name = match.group(1).strip()
|
||||
entity1_type = match.group(2)
|
||||
entity2_name = match.group(3).strip()
|
||||
match.group(3).strip()
|
||||
entity2_type = match.group(4)
|
||||
|
||||
# 提取置信度和原因
|
||||
@@ -1611,7 +1766,6 @@ async def get_chunked_dialogs(
|
||||
包含分块的 DialogData 对象列表
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
# 加载测试数据
|
||||
@@ -1794,7 +1948,6 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
Returns:
|
||||
带 chunks 的 DialogData 列表
|
||||
"""
|
||||
import os
|
||||
print("\n=== 完整数据处理流程(包含预处理)===")
|
||||
|
||||
if input_data_path is None:
|
||||
|
||||
@@ -0,0 +1,210 @@
|
||||
{
|
||||
"memory_verify": {
|
||||
"source_data": [
|
||||
{
|
||||
"statement_name": "用户是2023年春天去北京工作的。",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户后来基本一直都在北京上班。",
|
||||
"statement_id": "4cba5ac08b674d7fb1e2ae634d2b8f0b",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户从2023年开始就一直在北京生活。",
|
||||
"statement_id": "e612a44da4db483993c350df7c97a1a1",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户从来没有长期离开过北京。",
|
||||
"statement_id": "b3c787a2e33c49f7981accabbbb4538a",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "由于公司调整,用户在2024年上半年被调到上海待了差不多半年。",
|
||||
"statement_id": "64cde4230cb24a4da726e7db9e7aa616",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户在被调到上海期间每天都是在上海办公室打卡。",
|
||||
"statement_id": "8b1b12e23b844b8088dfeb67da6ad669",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户在入职时使用的身份信息是之前的,身份证号为11010119950308123X。",
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户的银行卡号是6222023847595898。",
|
||||
"statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户的身份信息和银行卡信息一直没变。",
|
||||
"statement_id": "b3ca618e1e204b83bebd70e75cf2073f",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户认为在上海的那段时间更多算是远程配合。",
|
||||
"statement_id": "150af89d2c154e6eb41ff1a91e37f962",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
}
|
||||
],
|
||||
"databasets": [
|
||||
{
|
||||
"entity1_name": "Person",
|
||||
"description": "表示人类个体的通用类型",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "用户",
|
||||
"entity2": {
|
||||
"entity_idx": 0,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"connect_strength": "strong",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Person",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "用户",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "3d3896797b334572a80d57590026063d"
|
||||
}
|
||||
},
|
||||
{
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "身份信息",
|
||||
"entity2": {
|
||||
"entity_idx": 1,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"connect_strength": "Strong",
|
||||
"description": "用于个人身份识别的数据",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Information",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "身份信息",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "aa766a517e82490599a9b3af54cfd933"
|
||||
}
|
||||
},
|
||||
{
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "6222023847595898",
|
||||
"entity2": {
|
||||
"entity_idx": 1,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"connect_strength": "Strong",
|
||||
"description": "用户的银行卡号码",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Numeric",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "6222023847595898",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "610ba361918f4e68a65ce6ad06e5c7a0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "上海办公室",
|
||||
"entity2": {
|
||||
"entity_idx": 1,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"aliases": ["上海办"],
|
||||
"connect_strength": "Strong",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"description": "位于上海的工作办公场所",
|
||||
"statement_id": "8b1b12e23b844b8088dfeb67da6ad669",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Location",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "上海办公室",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "fb702ef695c14e14af3e56786bc8815b"
|
||||
}
|
||||
},
|
||||
{
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "北京",
|
||||
"entity2": {
|
||||
"entity_idx": 2,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"aliases": ["京", "京城", "北平"],
|
||||
"connect_strength": "strong",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"description": "中国的首都城市,用户主要工作和生活所在地",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Location",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "北京",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "81b2d1a571bb46a08a2d7a1e87efb945"
|
||||
}
|
||||
},
|
||||
{
|
||||
"entity1_name": "11010119950308123X",
|
||||
"description": "具体的身份证号码值",
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "身份证号",
|
||||
"entity2": {
|
||||
"entity_idx": 2,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"connect_strength": "strong",
|
||||
"description": "中华人民共和国公民的身份号码",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Identifier",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "身份证号",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "3e5f920645b2404fadb0e9ff60d1306e"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -8,17 +8,21 @@
|
||||
4. 反思结果应用 - 更新记忆库
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.response_utils import success
|
||||
from app.repositories.neo4j.cypher_queries import neo4j_query_part, neo4j_statement_part, neo4j_query_all, neo4j_statement_all
|
||||
from app.repositories.neo4j.neo4j_update import neo4j_data
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
# 配置日志
|
||||
_root_logger = logging.getLogger()
|
||||
@@ -33,14 +37,14 @@ else:
|
||||
|
||||
class ReflectionRange(str, Enum):
|
||||
"""反思范围枚举"""
|
||||
RETRIEVAL = "retrieval" # 从检索结果中反思
|
||||
DATABASE = "database" # 从整个数据库中反思
|
||||
PARTIAL = "partial" # 从检索结果中反思
|
||||
ALL = "all" # 从整个数据库中反思
|
||||
|
||||
|
||||
class ReflectionBaseline(str, Enum):
|
||||
"""反思基线枚举"""
|
||||
TIME = "TIME" # 基于时间的反思
|
||||
FACT = "FACT" # 基于事实的反思
|
||||
TIME = "TIME" # 基于时间的反思
|
||||
FACT = "FACT" # 基于事实的反思
|
||||
HYBRID = "HYBRID" # 混合反思
|
||||
|
||||
|
||||
@@ -48,9 +52,16 @@ class ReflectionConfig(BaseModel):
|
||||
"""反思引擎配置"""
|
||||
enabled: bool = False
|
||||
iteration_period: str = "3" # 反思周期
|
||||
reflexion_range: ReflectionRange = ReflectionRange.RETRIEVAL
|
||||
reflexion_range: ReflectionRange = ReflectionRange.PARTIAL
|
||||
baseline: ReflectionBaseline = ReflectionBaseline.TIME
|
||||
concurrency: int = Field(default=5, description="并发数量")
|
||||
model_id: Optional[str] = None # 模型ID
|
||||
end_user_id: Optional[str] = None
|
||||
output_example: Optional[str] = None # 输出示例
|
||||
|
||||
# 评估相关字段
|
||||
memory_verify: bool = True # 记忆验证
|
||||
quality_assessment: bool = True # 质量评估
|
||||
violation_handling_strategy: str = "warn" # 违规处理策略
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
@@ -75,16 +86,16 @@ class ReflectionEngine:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ReflectionConfig,
|
||||
neo4j_connector: Optional[Any] = None,
|
||||
llm_client: Optional[Any] = None,
|
||||
get_data_func: Optional[Any] = None,
|
||||
render_evaluate_prompt_func: Optional[Any] = None,
|
||||
render_reflexion_prompt_func: Optional[Any] = None,
|
||||
conflict_schema: Optional[Any] = None,
|
||||
reflexion_schema: Optional[Any] = None,
|
||||
update_query: Optional[str] = None
|
||||
self,
|
||||
config: ReflectionConfig,
|
||||
neo4j_connector: Optional[Any] = None,
|
||||
llm_client: Optional[Any] = None,
|
||||
get_data_func: Optional[Any] = None,
|
||||
render_evaluate_prompt_func: Optional[Any] = None,
|
||||
render_reflexion_prompt_func: Optional[Any] = None,
|
||||
conflict_schema: Optional[Any] = None,
|
||||
reflexion_schema: Optional[Any] = None,
|
||||
update_query: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
初始化反思引擎
|
||||
@@ -109,7 +120,7 @@ class ReflectionEngine:
|
||||
self.conflict_schema = conflict_schema
|
||||
self.reflexion_schema = reflexion_schema
|
||||
self.update_query = update_query
|
||||
self._semaphore = asyncio.Semaphore(config.concurrency)
|
||||
self._semaphore = asyncio.Semaphore(5) # 默认并发数为5
|
||||
|
||||
# 延迟导入以避免循环依赖
|
||||
self._lazy_init_done = False
|
||||
@@ -127,11 +138,21 @@ class ReflectionEngine:
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
self.llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
elif isinstance(self.llm_client, str):
|
||||
# 如果 llm_client 是字符串(model_id),则用它初始化客户端
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
model_id = self.llm_client
|
||||
self.llm_client = get_llm_client(model_id)
|
||||
|
||||
if self.get_data_func is None:
|
||||
from app.core.memory.utils.config.get_data import get_data
|
||||
self.get_data_func = get_data
|
||||
|
||||
# 导入get_data_statement函数
|
||||
if not hasattr(self, 'get_data_statement'):
|
||||
from app.core.memory.utils.config.get_data import get_data_statement
|
||||
self.get_data_statement = get_data_statement
|
||||
|
||||
if self.render_evaluate_prompt_func is None:
|
||||
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
|
||||
self.render_evaluate_prompt_func = render_evaluate_prompt
|
||||
@@ -154,13 +175,11 @@ class ReflectionEngine:
|
||||
|
||||
self._lazy_init_done = True
|
||||
|
||||
async def execute_reflection(self, host_id: uuid.UUID) -> ReflectionResult:
|
||||
async def execute_reflection(self, host_id) -> ReflectionResult:
|
||||
"""
|
||||
执行完整的反思流程
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
|
||||
Returns:
|
||||
ReflectionResult: 反思结果
|
||||
"""
|
||||
@@ -176,9 +195,10 @@ class ReflectionEngine:
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
logging.info("====== 自我反思流程开始 ======")
|
||||
|
||||
print(self.config.baseline, self.config.memory_verify, self.config.quality_assessment)
|
||||
try:
|
||||
# 1. 获取反思数据
|
||||
reflexion_data = await self._get_reflexion_data(host_id)
|
||||
reflexion_data, statement_databasets = await self._get_reflexion_data(host_id)
|
||||
if not reflexion_data:
|
||||
return ReflectionResult(
|
||||
success=True,
|
||||
@@ -187,22 +207,21 @@ class ReflectionEngine:
|
||||
)
|
||||
|
||||
# 2. 检测冲突(基于事实的反思)
|
||||
conflict_data = await self._detect_conflicts(reflexion_data)
|
||||
if not conflict_data:
|
||||
return ReflectionResult(
|
||||
success=True,
|
||||
message="无冲突,无需反思",
|
||||
execution_time=asyncio.get_event_loop().time() - start_time
|
||||
)
|
||||
conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets)
|
||||
print(100 * '-')
|
||||
print(conflict_data)
|
||||
print(100 * '-')
|
||||
|
||||
conflicts_found = len(conflict_data)
|
||||
logging.info(f"发现 {conflicts_found} 个冲突")
|
||||
# 检查是否真的有冲突
|
||||
has_conflict = conflict_data[0].get('conflict', False)
|
||||
conflicts_found = len(conflict_data[0]['data']) if has_conflict else 0
|
||||
logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突")
|
||||
|
||||
# 记录冲突数据
|
||||
await self._log_data("conflict", conflict_data)
|
||||
|
||||
# 3. 解决冲突
|
||||
solved_data = await self._resolve_conflicts(conflict_data)
|
||||
solved_data = await self._resolve_conflicts(conflict_data, statement_databasets)
|
||||
if not solved_data:
|
||||
return ReflectionResult(
|
||||
success=False,
|
||||
@@ -210,6 +229,9 @@ class ReflectionEngine:
|
||||
conflicts_found=conflicts_found,
|
||||
execution_time=asyncio.get_event_loop().time() - start_time
|
||||
)
|
||||
print(100 * '*')
|
||||
print(solved_data)
|
||||
print(100 * '*')
|
||||
|
||||
conflicts_resolved = len(solved_data)
|
||||
logging.info(f"解决了 {conflicts_resolved} 个冲突")
|
||||
@@ -230,7 +252,8 @@ class ReflectionEngine:
|
||||
conflicts_found=conflicts_found,
|
||||
conflicts_resolved=conflicts_resolved,
|
||||
memories_updated=memories_updated,
|
||||
execution_time=execution_time
|
||||
execution_time=execution_time,
|
||||
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -241,6 +264,79 @@ class ReflectionEngine:
|
||||
execution_time=asyncio.get_event_loop().time() - start_time
|
||||
)
|
||||
|
||||
async def reflection_run(self):
    """Run the reflection pipeline against the bundled example data.

    Loads ``source_data`` / ``databasets`` through
    ``self.extract_fields_from_json()``, runs conflict detection and conflict
    resolution on them and collects the intermediate outputs.

    Returns:
        dict: with the keys ``baseline``, ``source_data``,
        ``quality_assessments``, ``memory_verifies`` and ``reflexion_data``.
        If conflict detection yields nothing, the three list fields are
        empty. If conflicts were detected but ``_resolve_conflicts`` returns
        nothing, a failed ``ReflectionResult`` is returned instead.
    """
    self._lazy_init()
    # Use the event-loop clock for both ends of the measurement; the elapsed
    # time below is computed with the same clock.
    loop = asyncio.get_event_loop()
    start_time = loop.time()
    logging.info("====== 自我反思流程开始 ======")

    source_data, databasets = await self.extract_fields_from_json()

    result_data = {
        'baseline': self.config.baseline,
        # NOTE(review): this is a hard-coded sample sentence, not the
        # ``source_data`` loaded above -- confirm whether the loaded value
        # should be reported here instead.
        'source_data': "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合",
    }

    # 2. Detect conflicts (fact-based reflection).
    conflict_data = await self._detect_conflicts(databasets, source_data)

    # Collect the per-item assessment fields; a missing key becomes None
    # instead of aborting the whole run.
    result_data['quality_assessments'] = [
        item.get('quality_assessment') for item in conflict_data
    ]
    result_data['memory_verifies'] = [
        item.get('memory_verify') for item in conflict_data
    ]

    # Detection returns an empty list when there is nothing to compare (or on
    # failure); there is then nothing to resolve.
    if not conflict_data:
        logging.info(f"冲突状态: {False}, 发现 {0} 个冲突")
        result_data['reflexion_data'] = []
        return result_data

    # Check whether a conflict was actually reported.
    first_item = conflict_data[0]
    has_conflict = first_item.get('conflict', False)
    conflicts_found = len(first_item.get('data') or []) if has_conflict else 0
    logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突")

    # Persist the raw conflict data.
    await self._log_data("conflict", conflict_data)

    # 3. Resolve conflicts.
    solved_data = await self._resolve_conflicts(conflict_data, source_data)
    if not solved_data:
        return ReflectionResult(
            success=False,
            message="反思失败,未解决冲突",
            conflicts_found=conflicts_found,
            execution_time=loop.time() - start_time
        )

    # Flatten the ``reflexion`` field of every result of every solved item.
    result_data['reflexion_data'] = [
        result.get('reflexion')
        for item in solved_data
        if 'results' in item
        for result in item['results']
    ]
    return result_data
|
||||
|
||||
|
||||
async def extract_fields_from_json(self):
    """Load ``source_data`` and ``databasets`` from the bundled example.json.

    Reads ``example/example.json`` located next to this module and returns
    the two values stored under its ``memory_verify`` key.

    Returns:
        tuple: ``(source_data, databasets)``. Each value falls back to ``[]``
        when its key is missing; ``([], [])`` is returned when the file
        cannot be read or does not have the expected structure.
    """
    example_path = os.path.join(os.path.dirname(__file__), "example", "example.json")
    try:
        with open(example_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Both fields live under the "memory_verify" object.
        memory_verify = data.get("memory_verify", {})
        source_data = memory_verify.get("source_data", [])
        databasets = memory_verify.get("databasets", [])
        return source_data, databasets
    except Exception as e:
        # Best-effort loader: keep the empty fallback callers rely on, but
        # record why the example data could not be loaded.
        logging.warning(f"读取示例数据失败: {e}")
        return [], []
|
||||
|
||||
async def _get_reflexion_data(self, host_id: uuid.UUID) -> List[Any]:
|
||||
"""
|
||||
获取反思数据
|
||||
@@ -253,17 +349,28 @@ class ReflectionEngine:
|
||||
Returns:
|
||||
List[Any]: 反思数据列表
|
||||
"""
|
||||
if self.config.reflexion_range == ReflectionRange.RETRIEVAL:
|
||||
# 从检索结果中获取数据
|
||||
return await self.get_data_func(host_id)
|
||||
elif self.config.reflexion_range == ReflectionRange.DATABASE:
|
||||
# 从整个数据库中获取数据(待实现)
|
||||
logging.warning("从数据库获取反思数据功能尚未实现")
|
||||
return []
|
||||
else:
|
||||
raise ValueError(f"未知的反思范围: {self.config.reflexion_range}")
|
||||
|
||||
async def _detect_conflicts(self, data: List[Any]) -> List[Any]:
|
||||
|
||||
|
||||
if self.config.reflexion_range == ReflectionRange.PARTIAL:
|
||||
neo4j_query = neo4j_query_part.format(host_id)
|
||||
neo4j_statement = neo4j_statement_part.format(host_id)
|
||||
elif self.config.reflexion_range == ReflectionRange.ALL:
|
||||
neo4j_query = neo4j_query_all.format(host_id)
|
||||
neo4j_statement = neo4j_statement_all.format(host_id)
|
||||
try:
|
||||
result = await self.neo4j_connector.execute_query(neo4j_query)
|
||||
result_statement = await self.neo4j_connector.execute_query(neo4j_statement)
|
||||
neo4j_databasets = await self.get_data_func(result)
|
||||
neo4j_state = await self.get_data_statement(result_statement)
|
||||
return neo4j_databasets, neo4j_state
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Neo4j查询失败: {e}")
|
||||
return [], []
|
||||
|
||||
async def _detect_conflicts(self, data: List[Any], statement_databasets: List[Any]) -> List[Any]:
|
||||
"""
|
||||
检测冲突(基于事实的反思)
|
||||
|
||||
@@ -278,14 +385,28 @@ class ReflectionEngine:
|
||||
if not data:
|
||||
return []
|
||||
|
||||
# 数据预处理:如果数据量太少,直接返回无冲突
|
||||
if len(data) < 2:
|
||||
logging.info("数据量不足,无需检测冲突")
|
||||
return []
|
||||
|
||||
# 使用转换后的数据
|
||||
print("转换后的数据:", data[:2] if len(data) > 2 else data) # 只打印前2条避免日志过长
|
||||
memory_verify = self.config.memory_verify
|
||||
|
||||
logging.info("====== 冲突检测开始 ======")
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
quality_assessment = self.config.quality_assessment
|
||||
|
||||
try:
|
||||
# 渲染冲突检测提示词
|
||||
rendered_prompt = await self.render_evaluate_prompt_func(
|
||||
data,
|
||||
self.conflict_schema
|
||||
self.conflict_schema,
|
||||
self.config.baseline,
|
||||
memory_verify,
|
||||
quality_assessment,
|
||||
statement_databasets
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
@@ -316,7 +437,7 @@ class ReflectionEngine:
|
||||
logging.error(f"冲突检测失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def _resolve_conflicts(self, conflicts: List[Any]) -> List[Any]:
|
||||
async def _resolve_conflicts(self, conflicts: List[Any], statement_databasets: List[Any]) -> List[Any]:
|
||||
"""
|
||||
解决冲突
|
||||
|
||||
@@ -332,6 +453,8 @@ class ReflectionEngine:
|
||||
return []
|
||||
|
||||
logging.info("====== 冲突解决开始 ======")
|
||||
baseline = self.config.baseline
|
||||
memory_verify = self.config.memory_verify
|
||||
|
||||
# 并行处理每个冲突
|
||||
async def _resolve_one(conflict: Any) -> Optional[Dict[str, Any]]:
|
||||
@@ -341,7 +464,10 @@ class ReflectionEngine:
|
||||
# 渲染反思提示词
|
||||
rendered_prompt = await self.render_reflexion_prompt_func(
|
||||
[conflict],
|
||||
self.reflexion_schema
|
||||
self.reflexion_schema,
|
||||
baseline,
|
||||
memory_verify,
|
||||
statement_databasets
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
@@ -381,8 +507,8 @@ class ReflectionEngine:
|
||||
return solved
|
||||
|
||||
async def _apply_reflection_results(
|
||||
self,
|
||||
solved_data: List[Dict[str, Any]]
|
||||
self,
|
||||
solved_data: List[Dict[str, Any]]
|
||||
) -> int:
|
||||
"""
|
||||
应用反思结果(更新记忆库)
|
||||
@@ -395,57 +521,7 @@ class ReflectionEngine:
|
||||
Returns:
|
||||
int: 成功更新的记忆数量
|
||||
"""
|
||||
if not solved_data:
|
||||
logging.warning("无解决方案数据,跳过更新")
|
||||
return 0
|
||||
|
||||
logging.info("====== 记忆更新开始 ======")
|
||||
|
||||
success_count = 0
|
||||
|
||||
async def _update_one(item: Dict[str, Any]) -> bool:
|
||||
"""更新单条记忆"""
|
||||
async with self._semaphore:
|
||||
try:
|
||||
if not isinstance(item, dict):
|
||||
return False
|
||||
|
||||
# 提取更新参数
|
||||
resolved = item.get("resolved", {})
|
||||
resolved_mem = resolved.get("resolved_memory", {})
|
||||
group_id = resolved_mem.get("group_id")
|
||||
memory_id = resolved_mem.get("id")
|
||||
new_invalid_at = resolved_mem.get("invalid_at")
|
||||
|
||||
if not all([group_id, memory_id, new_invalid_at]):
|
||||
logging.warning(f"记忆更新参数缺失,跳过此项: {item}")
|
||||
return False
|
||||
|
||||
# 执行更新
|
||||
await self.neo4j_connector.execute_query(
|
||||
self.update_query,
|
||||
group_id=group_id,
|
||||
id=memory_id,
|
||||
new_invalid_at=new_invalid_at,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"更新单条记忆失败: {e}")
|
||||
return False
|
||||
|
||||
# 并发执行所有更新任务
|
||||
tasks = [
|
||||
_update_one(item)
|
||||
for item in solved_data
|
||||
if isinstance(item, dict)
|
||||
]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
success_count = sum(1 for r in results if r)
|
||||
|
||||
logging.info(f"成功更新 {success_count}/{len(solved_data)} 条记忆")
|
||||
|
||||
success_count = await neo4j_data(solved_data)
|
||||
return success_count
|
||||
|
||||
async def _log_data(self, label: str, data: Any) -> None:
|
||||
@@ -456,6 +532,7 @@ class ReflectionEngine:
|
||||
label: 数据标签
|
||||
data: 要记录的数据
|
||||
"""
|
||||
|
||||
def _write():
|
||||
try:
|
||||
with open("reflexion_data.json", "a", encoding="utf-8") as f:
|
||||
@@ -470,9 +547,9 @@ class ReflectionEngine:
|
||||
|
||||
# 基于时间的反思方法
|
||||
async def time_based_reflection(
|
||||
self,
|
||||
host_id: uuid.UUID,
|
||||
time_period: Optional[str] = None
|
||||
self,
|
||||
host_id: uuid.UUID,
|
||||
time_period: Optional[str] = None
|
||||
) -> ReflectionResult:
|
||||
"""
|
||||
基于时间的反思
|
||||
@@ -494,8 +571,8 @@ class ReflectionEngine:
|
||||
|
||||
# 基于事实的反思方法
|
||||
async def fact_based_reflection(
|
||||
self,
|
||||
host_id: uuid.UUID
|
||||
self,
|
||||
host_id: uuid.UUID
|
||||
) -> ReflectionResult:
|
||||
"""
|
||||
基于事实的反思
|
||||
@@ -515,8 +592,8 @@ class ReflectionEngine:
|
||||
|
||||
# 综合反思方法
|
||||
async def comprehensive_reflection(
|
||||
self,
|
||||
host_id: uuid.UUID
|
||||
self,
|
||||
host_id: uuid.UUID
|
||||
) -> ReflectionResult:
|
||||
"""
|
||||
综合反思
|
||||
@@ -553,33 +630,3 @@ class ReflectionEngine:
|
||||
else:
|
||||
raise ValueError(f"未知的反思基线: {self.config.baseline}")
|
||||
|
||||
|
||||
# 便捷函数:创建默认配置的反思引擎
|
||||
def create_reflection_engine(
|
||||
enabled: bool = False,
|
||||
iteration_period: str = "3",
|
||||
reflexion_range: str = "retrieval",
|
||||
baseline: str = "TIME",
|
||||
concurrency: int = 5
|
||||
) -> ReflectionEngine:
|
||||
"""
|
||||
创建反思引擎实例
|
||||
|
||||
Args:
|
||||
enabled: 是否启用反思
|
||||
iteration_period: 反思周期
|
||||
reflexion_range: 反思范围
|
||||
baseline: 反思基线
|
||||
concurrency: 并发数量
|
||||
|
||||
Returns:
|
||||
ReflectionEngine: 反思引擎实例
|
||||
"""
|
||||
config = ReflectionConfig(
|
||||
enabled=enabled,
|
||||
iteration_period=iteration_period,
|
||||
reflexion_range=reflexion_range,
|
||||
baseline=baseline,
|
||||
concurrency=concurrency
|
||||
)
|
||||
return ReflectionEngine(config)
|
||||
|
||||
@@ -1,13 +1,8 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db import get_db
|
||||
from app.models.retrieval_info import RetrievalInfo
|
||||
from app.schemas.memory_storage_schema import BaseDataSchema
|
||||
|
||||
import logging
|
||||
|
||||
from typing import List, Dict, Any
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def _load_(data: List[Any]) -> List[Dict]:
|
||||
@@ -60,27 +55,46 @@ async def _load_(data: List[Any]) -> List[Dict]:
|
||||
return results
|
||||
|
||||
|
||||
async def get_data(host_id: uuid.UUID) -> List[Dict]:
|
||||
async def get_data(result):
|
||||
"""
|
||||
从数据库中获取数据
|
||||
"""
|
||||
# 从数据库会话中获取会话
|
||||
db: Session = next(get_db())
|
||||
try:
|
||||
data = db.query(RetrievalInfo.retrieve_info).filter(RetrievalInfo.host_id == host_id).all()
|
||||
neo4j_databasets=[]
|
||||
for item in result:
|
||||
filtered_item = {}
|
||||
for key, value in item.items():
|
||||
if 'name_embedding' not in key.lower():
|
||||
if key == 'relationship' and value is not None:
|
||||
# 只保留relationship的指定字段
|
||||
rel_filtered = {}
|
||||
if hasattr(value, 'get'):
|
||||
rel_filtered['run_id'] = value.get('run_id')
|
||||
rel_filtered['statement'] = value.get('statement')
|
||||
rel_filtered['statement_id'] = value.get('statement_id')
|
||||
rel_filtered['expired_at'] = value.get('expired_at')
|
||||
rel_filtered['created_at'] = value.get('created_at')
|
||||
filtered_item[key] = rel_filtered
|
||||
elif key == 'entity2' and value is not None:
|
||||
# 过滤entity2的name_embedding字段
|
||||
entity2_filtered = {}
|
||||
if hasattr(value, 'items'):
|
||||
for e_key, e_value in value.items():
|
||||
if 'name_embedding' not in e_key.lower():
|
||||
entity2_filtered[e_key] = e_value
|
||||
filtered_item[key] = entity2_filtered
|
||||
else:
|
||||
filtered_item[key] = value
|
||||
|
||||
# 直接将字典添加到列表中
|
||||
neo4j_databasets.append(filtered_item)
|
||||
return neo4j_databasets
|
||||
async def get_data_statement(result):
    """Materialize the statement query result into a new list.

    Args:
        result: Any iterable of records.

    Returns:
        list: A new list holding the records of ``result`` in iteration order.
    """
    return list(result)
|
||||
|
||||
|
||||
# print(f"data:\n{data}")
|
||||
# 解析,提取为字典的列表
|
||||
results = await _load_(data)
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"failed to get data from database, host_id: {host_id}, error: {e}")
|
||||
raise e
|
||||
finally:
|
||||
try:
|
||||
db.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -238,3 +238,81 @@ async def render_memory_summary_prompt(
|
||||
'json_schema': 'MemorySummaryResponse.schema'
|
||||
})
|
||||
return rendered_prompt
|
||||
|
||||
async def render_emotion_extraction_prompt(
    statement: str,
    extract_keywords: bool,
    enable_subject: bool
) -> str:
    """
    Build the emotion extraction prompt from the extract_emotion.jinja2 template.

    Args:
        statement: The statement to analyze
        extract_keywords: Whether emotion keywords should be extracted
        enable_subject: Whether subject classification is enabled

    Returns:
        The rendered prompt text
    """
    template_name = "extract_emotion.jinja2"
    render_args = {
        'statement': statement,
        'extract_keywords': extract_keywords,
        'enable_subject': enable_subject,
    }
    prompt_text = prompt_env.get_template(template_name).render(**render_args)

    # Write the rendered prompt to the prompt log.
    log_prompt_rendering('emotion extraction', prompt_text)
    # Record the template inputs; the statement itself is logged as the
    # placeholder 'str' rather than its content.
    log_template_rendering(template_name, {**render_args, 'statement': 'str'})

    return prompt_text
|
||||
|
||||
async def render_emotion_suggestions_prompt(
    health_data: dict,
    patterns: dict,
    user_profile: dict
) -> str:
    """
    Build the emotion suggestions prompt from the generate_emotion_suggestions.jinja2 template.

    Args:
        health_data: Emotional health data
        patterns: Emotion pattern analysis result
        user_profile: User profile data

    Returns:
        The rendered prompt text
    """
    import json

    template_name = "generate_emotion_suggestions.jinja2"

    # Pre-serialize the emotion distribution so the template receives it as
    # a JSON string.
    distribution = health_data.get('emotion_distribution', {})
    distribution_text = json.dumps(distribution, ensure_ascii=False, indent=2)

    prompt_text = prompt_env.get_template(template_name).render(
        health_data=health_data,
        patterns=patterns,
        user_profile=user_profile,
        emotion_distribution_json=distribution_text,
    )

    # Write the rendered prompt to the prompt log.
    log_prompt_rendering('emotion suggestions', prompt_text)
    # Record a summary of the template inputs.
    logged_fields = {
        'health_score': health_data.get('health_score'),
        'health_level': health_data.get('level'),
        'user_interests': user_profile.get('interests', []),
    }
    log_template_rendering(template_name, logged_fields)

    return prompt_text
|
||||
|
||||
@@ -1,19 +1,222 @@
|
||||
你将收到一组记忆对象:{{ evaluate_data }}。
|
||||
任务:多维度判断这些记忆是否与已有记忆存在冲突,并给出冲突的对应记忆。(冗余不算冲突)
|
||||
你将收到一组用户历史记忆原始数据(来源于 Neo4j),以及相关配置参数:
|
||||
原本的输入句子:{{statement_databasets}}
|
||||
需要检测冲突对象:{{ evaluate_data }}
|
||||
冲突判定类型:{{ baseline }}(取值为 TIME / FACT / HYBRID)
|
||||
记忆审核开关:{{ memory_verify }}(取值为 true / false)
|
||||
记忆质量评估开关:{{ quality_assessment }}(取值为 true / false)
|
||||
|
||||
仅输出一个合法 JSON 对象,严格遵循下述结构:
|
||||
你的任务是:
|
||||
对用户历史记忆数据进行冲突检测和记忆审核,并输出严格结构化的 JSON 分析结果
|
||||
数据的结构:
|
||||
statement_databasets里面statement_name是输入的句子,statement_id是连接evaluate_data里面的statement_id,代表这个句子被拆分成几个实体,需要根据整体的内容,
|
||||
需要根据以下内容做处理(冲突检测、记忆审核、记忆的质量评估)
|
||||
## 冲突定义
|
||||
|
||||
### 时间冲突
|
||||
时间冲突是指同一用户的相关事件在时间维度上存在逻辑矛盾:
|
||||
|
||||
1. **同一活动的时间冲突**:
|
||||
- 同一用户的同一活动在不同时间点被记录(如"周五打球"和"周六打球")
|
||||
- 同一用户在同一时间段内被记录进行不同的互斥活动
|
||||
|
||||
2. **时间逻辑错误**:
|
||||
- expired_at 早于 created_at
|
||||
- 同一事实的 created_at 时间差异超过合理误差范围(>5分钟)
|
||||
|
||||
3. **日期属性冲突**:
|
||||
- 同一人的生日记录为不同日期(如"2月10号"和"2月16号")
|
||||
4.存在明确先后约束 A -> B,但 t(A) > t(B)
|
||||
-例:入学时间晚于毕业时间。
|
||||
-处理:标记异常、降权、触发逻辑反思或人工审查。
|
||||
5.时间属性冲突
|
||||
-单值日期属性出现多值(生日、入职日期)
|
||||
-注意:本质属于事实冲突的日期特例,归入事实冲突仲裁框架。
|
||||
6.互斥重叠冲突
|
||||
-例:同一主体的两个事件区间重叠且互斥(如同一时间出现在两地)
|
||||
-处理:证据仲裁、保留多版本(active + candidate)。
|
||||
|
||||
|
||||
|
||||
### 事实冲突
|
||||
事实冲突是指同一实体的属性或关系存在相互矛盾的陈述:
|
||||
|
||||
1. **属性互斥**:同一实体的相反属性(喜欢↔不喜欢、有↔没有、是↔不是)
|
||||
2. **关系矛盾**:同一实体在相同语境下的不同关系描述
|
||||
3. **身份冲突**:同一实体被赋予不同的类型或角色
|
||||
|
||||
### 混合冲突检测
|
||||
检测所有类型的冲突,包括但不限于时间冲突和事实冲突:
|
||||
检测任何逻辑上不一致或相互矛盾的记录
|
||||
## 记忆审核定义
|
||||
|
||||
### 隐私信息检测(隐私冲突)
|
||||
当memory_verify为true时,需要额外检测包含个人隐私信息的记录:
|
||||
|
||||
1. **身份证信息**:包含身份证号码、身份证相关描述
|
||||
2. **手机号码**:包含手机号、电话号码等联系方式
|
||||
3. **社交账号**:包含微信号、QQ号、邮箱地址等社交平台信息
|
||||
4. **银行信息**:包含银行卡号、账户信息、支付信息
|
||||
5. **税务信息**:包含税号、纳税信息、发票信息
|
||||
6. **贷款信息**:包含贷款记录、信贷信息、借款信息
|
||||
7. **其他敏感信息**:包含密码、PIN码、验证码等安全信息
|
||||
|
||||
### 隐私检测原则
|
||||
- 检测description、entity1_name、entity2_name等字段中的隐私信息
|
||||
- 识别数字模式(如手机号11位数字、身份证18位等)
|
||||
- 识别关键词(如"身份证"、"银行卡"、"密码"等)
|
||||
- 检测敏感实体类型和关系
|
||||
|
||||
## 冲突检测原则
|
||||
|
||||
**全面检测**:不区分冲突类型,检测所有可能的冲突
|
||||
**完整输出**:如果发现任何冲突或隐私信息,必须将所有相关记录都放入data字段
|
||||
**实体关联**:重点检查涉及相同实体(entity1_name, entity2_name)的记录
|
||||
**语义分析**:分析description字段的语义相似性和冲突性
|
||||
**时间逻辑**:检查时间字段的逻辑一致性
|
||||
**隐私检测**:当memory_verify为true时,检测所有包含隐私信息的记录
|
||||
|
||||
## 不符合冲突检测
|
||||
-称呼
|
||||
## 重要检测示例
|
||||
|
||||
### 冲突检测示例
|
||||
- 用户与不同时间点的关系(周五 vs 周六,2月10号 vs 2月16号)
|
||||
- 同一实体的重复定义但描述不同
|
||||
- 同一关系的不同表述但含义冲突
|
||||
- 任何逻辑上不可能同时为真的记录
|
||||
|
||||
### 隐私信息检测示例
|
||||
- 包含手机号的记录:"用户的手机号是13812345678"
|
||||
- 包含身份证的记录:"身份证号码为110101199001011234"
|
||||
- 包含银行卡的记录:"银行卡号6222021234567890"
|
||||
- 包含社交账号的记录:"微信号是user123456"
|
||||
- 包含敏感信息的实体名称或描述
|
||||
|
||||
## 输出要求
|
||||
|
||||
**关键原则**:
|
||||
1. 当存在冲突或检测到隐私信息时,conflict才为true,data字段才包含相关记录
|
||||
2. 如果发现冲突,必须将所有相关的冲突记录都放入data数组中
|
||||
3. 如果memory_verify为true且检测到隐私信息,必须将包含隐私信息的记录也放入data数组中
|
||||
4. 既没有冲突也没有隐私信息时,conflict为false,data为空数组
|
||||
5. 如果quality_assessment为true,独立分析数据质量并输出评估结果;如果为false,quality_assessment字段输出null
|
||||
6. 冲突检测、隐私审核和质量评估三个功能完全独立,互不影响
|
||||
7. 不输出conflict_memory字段
|
||||
|
||||
**处理逻辑**:
|
||||
- 首先进行冲突检测,将冲突记录加入data数组
|
||||
- 如果memory_verify为true,再进行隐私信息检测,将包含隐私信息的记录也加入data数组
|
||||
- 如果quality_assessment为true,独立进行质量评估,分析所有输入数据的质量并输出评估结果
|
||||
- 最终data数组包含所有冲突记录和隐私信息记录(去重)
|
||||
- quality_assessment字段独立输出,不影响冲突检测和隐私审核结果
|
||||
- memory_verify字段独立输出隐私检测结果,包含检测到的隐私信息类型和概述
|
||||
|
||||
返回数据格式以json方式输出:
|
||||
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
||||
- 关键的JSON格式要求{"statement":识别出的文本内容}
|
||||
1.JSON结构仅使用标准ASCII双引号(")-切勿使用中文引号("")或其他Unicode引号
|
||||
2.如果提取的语句文本包含引号,请使用反斜杠(\")正确转义它们
|
||||
3.确保所有JSON字符串都正确关闭并以逗号分隔
|
||||
4.JSON字符串值中不包括换行符
|
||||
5.正确转义的例子:"statement":"Zhang Xinhua said:\"我非常喜欢这本书\""
|
||||
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```
|
||||
|
||||
## 记忆质量评估定义
|
||||
|
||||
### 质量评估标准
|
||||
当quality_assessment为true时,需要对记忆数据进行质量评估:
|
||||
|
||||
1. **数据完整性**:
|
||||
- 检查必要字段是否完整(entity1_name、entity2_name、description等)
|
||||
- 检查关系描述是否清晰明确
|
||||
- 检查时间字段的有效性
|
||||
|
||||
2. **重复字段检测**:
|
||||
- 识别相同或高度相似的记录
|
||||
- 检测冗余的实体关系
|
||||
- 分析描述内容的重复度
|
||||
|
||||
3. **无意义字段检测**:
|
||||
- 识别空值、无效值或占位符内容
|
||||
- 检测过于简单或无信息量的描述
|
||||
- 识别格式错误或不规范的数据
|
||||
|
||||
4. **上下文依赖性**:
|
||||
- 评估记录是否需要额外上下文才能理解
|
||||
- 检查实体名称的明确性
|
||||
- 分析关系描述的自包含性
|
||||
|
||||
### 质量评估输出
|
||||
- **质量百分比**:基于上述标准计算的整体质量分数(0-100)
|
||||
- **质量概述**:简要描述数据质量状况,包括主要问题和优点
|
||||
|
||||
输出要求:仅输出一个合法 JSON 对象,严格遵循下述结构:
|
||||
{
|
||||
"data": [ ...与输入同结构的记忆对象数组... ],
|
||||
"conflict": true 或 false,
|
||||
"conflict_memory": 若冲突为 true,则填写与其冲突的记忆对象;否则为 null
|
||||
"data": [
|
||||
{
|
||||
"entity1_name": "实体1名称",
|
||||
"description": "描述信息",
|
||||
"statement_id": "陈述ID",
|
||||
"created_at": "创建时间戳",
|
||||
"expired_at": "过期时间戳",
|
||||
"relationship_type": "关系类型",
|
||||
"relationship": "关系对象",
|
||||
"entity2_name": "实体2名称",
|
||||
"entity2": "实体2对象"
|
||||
}
|
||||
],
|
||||
"conflict": true或false,
|
||||
"quality_assessment": {
|
||||
"score": 质量百分比数字,
|
||||
"summary": "质量概述文本"
|
||||
} 或 null,
|
||||
"memory_verify": {
|
||||
"has_privacy": true或false,
|
||||
"privacy_types": ["检测到的隐私信息类型列表"],
|
||||
"summary": "隐私检测结果概述"
|
||||
} 或 null
|
||||
}
|
||||
|
||||
必须遵守:
|
||||
- 只输出 JSON,不要添加解释或多余文本。
|
||||
- 使用标准双引号,必要时对内部引号进行转义。
|
||||
- 字段名与结构必须与给定模式一致。
|
||||
- data数组中包含冲突记录和隐私信息记录,如果都没有则为空数组。
|
||||
- quality_assessment字段:当quality_assessment参数为true时输出评估对象,为false时输出null。
|
||||
- memory_verify字段:当memory_verify参数为true时输出隐私检测结果对象,为false时输出null。
|
||||
|
||||
### memory_verify字段说明
|
||||
当memory_verify为true时,需要输出隐私检测结果:
|
||||
- **has_privacy**: 布尔值,表示是否检测到隐私信息
|
||||
- **privacy_types**: 字符串数组,包含检测到的隐私信息类型(如["手机号码", "身份证信息"])
|
||||
- **summary**: 字符串,简要描述隐私检测结果
|
||||
|
||||
当memory_verify为false时,memory_verify字段输出null。
|
||||
|
||||
### memory_verify字段示例
|
||||
|
||||
**示例1:检测到隐私信息**
|
||||
```json
|
||||
"memory_verify": {
|
||||
"has_privacy": true,
|
||||
"privacy_types": ["手机号码", "身份证信息"],
|
||||
"summary": "检测到2条记录包含隐私信息:1个手机号码,1个身份证号码"
|
||||
}
|
||||
```
|
||||
|
||||
**示例2:未检测到隐私信息**
|
||||
```json
|
||||
"memory_verify": {
|
||||
"has_privacy": false,
|
||||
"privacy_types": [],
|
||||
"summary": "未检测到隐私信息"
|
||||
}
|
||||
```
|
||||
|
||||
**示例3:memory_verify为false时**
|
||||
```json
|
||||
"memory_verify": null
|
||||
```
|
||||
|
||||
模式参考:
|
||||
[
|
||||
{{ json_schema }}
|
||||
]
|
||||
{{ json_schema }}
|
||||
@@ -0,0 +1,57 @@
|
||||
你是一个专业的情绪分析专家。请分析以下陈述句的情绪信息。
|
||||
|
||||
陈述句:{{ statement }}
|
||||
|
||||
请提取以下信息:
|
||||
|
||||
1. emotion_type(情绪类型):
|
||||
- joy: 喜悦、开心、高兴、满意、愉快
|
||||
- sadness: 悲伤、难过、失落、沮丧、遗憾
|
||||
- anger: 愤怒、生气、不满、恼火、烦躁
|
||||
- fear: 恐惧、害怕、担心、焦虑、紧张
|
||||
- surprise: 惊讶、意外、震惊、吃惊
|
||||
- neutral: 中性、客观陈述、无明显情绪
|
||||
|
||||
2. emotion_intensity(情绪强度):
|
||||
- 0.0-0.3: 弱情绪
|
||||
- 0.3-0.7: 中等情绪
|
||||
- 0.7-1.0: 强情绪
|
||||
|
||||
{% if extract_keywords %}
|
||||
3. emotion_keywords(情绪关键词):
|
||||
- 原句中直接表达情绪的词语
|
||||
- 最多提取3个关键词
|
||||
- 如果没有明显的情绪词,返回空列表
|
||||
{% else %}
|
||||
3. emotion_keywords(情绪关键词):
|
||||
- 返回空列表
|
||||
{% endif %}
|
||||
|
||||
{% if enable_subject %}
|
||||
4. emotion_subject(情绪主体):
|
||||
- self: 用户本人的情绪(包含"我"、"我们"、"咱们"等第一人称)
|
||||
- other: 他人的情绪(包含人名、"他/她"等第三人称)
|
||||
- object: 对事物的评价(针对产品、地点、事件等)
|
||||
|
||||
注意:
|
||||
- 如果同时包含多个主体,优先识别用户本人(self)
|
||||
- 如果无法明确判断主体,默认为 self
|
||||
|
||||
5. emotion_target(情绪对象):
|
||||
- 如果有明确的情绪对象,提取其名称
|
||||
- 如果没有明确对象,返回 null
|
||||
{% else %}
|
||||
4. emotion_subject(情绪主体):
|
||||
- 默认为 self
|
||||
|
||||
5. emotion_target(情绪对象):
|
||||
- 返回 null
|
||||
{% endif %}
|
||||
|
||||
注意事项:
|
||||
- 如果陈述句是客观事实陈述,无明显情绪,标记为 neutral
|
||||
- 情绪强度要符合语境,不要过度解读
|
||||
- 情绪关键词要准确,不要添加原句中没有的词
|
||||
- 主体分类要准确,优先识别用户本人(self)
|
||||
|
||||
请以 JSON 格式返回结果。
|
||||
@@ -0,0 +1,63 @@
|
||||
你是一位专业的心理健康顾问。请根据以下用户的情绪健康数据和个人信息,生成3-5条个性化的情绪改善建议。
|
||||
|
||||
## 用户情绪健康数据
|
||||
|
||||
健康分数:{{ health_data.health_score }}/100
|
||||
健康等级:{{ health_data.level }}
|
||||
|
||||
维度分析:
|
||||
- 积极率:{{ health_data.dimensions.positivity_rate.score }}/100
|
||||
- 正面情绪:{{ health_data.dimensions.positivity_rate.positive_count }}次
|
||||
- 负面情绪:{{ health_data.dimensions.positivity_rate.negative_count }}次
|
||||
- 中性情绪:{{ health_data.dimensions.positivity_rate.neutral_count }}次
|
||||
|
||||
- 稳定性:{{ health_data.dimensions.stability.score }}/100
|
||||
- 标准差:{{ health_data.dimensions.stability.std_deviation }}
|
||||
|
||||
- 恢复力:{{ health_data.dimensions.resilience.score }}/100
|
||||
- 恢复率:{{ health_data.dimensions.resilience.recovery_rate }}
|
||||
|
||||
情绪分布:
|
||||
{{ emotion_distribution_json }}
|
||||
|
||||
## 情绪模式分析
|
||||
|
||||
主要负面情绪:{{ patterns.dominant_negative_emotion|default('无') }}
|
||||
情绪波动性:{{ patterns.emotion_volatility|default('未知') }}
|
||||
高强度情绪次数:{{ patterns.high_intensity_emotions|default([])|length }}
|
||||
|
||||
## 用户兴趣
|
||||
|
||||
{{ user_profile.interests|default(['未知'])|join(', ') }}
|
||||
|
||||
## 任务要求
|
||||
|
||||
请生成3-5条个性化建议,每条建议包含:
|
||||
1. type: 建议类型(emotion_balance/activity_recommendation/social_connection/stress_management)
|
||||
2. title: 建议标题(简短有力)
|
||||
3. content: 建议内容(详细说明,50-100字)
|
||||
4. priority: 优先级(high/medium/low)
|
||||
5. actionable_steps: 3个可执行的具体步骤
|
||||
|
||||
同时提供一个health_summary(不超过50字),概括用户的整体情绪状态。
|
||||
|
||||
请以JSON格式返回,格式如下:
|
||||
{
|
||||
"health_summary": "您的情绪健康状况...",
|
||||
"suggestions": [
|
||||
{
|
||||
"type": "emotion_balance",
|
||||
"title": "建议标题",
|
||||
"content": "建议内容...",
|
||||
"priority": "high",
|
||||
"actionable_steps": ["步骤1", "步骤2", "步骤3"]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
注意事项:
|
||||
- 建议要具体、可执行,避免空泛
|
||||
- 结合用户的兴趣爱好提供个性化建议
|
||||
- 针对主要问题(如主要负面情绪)提供针对性建议
|
||||
- 优先级要合理分配(至少1个high,1-2个medium,其余low)
|
||||
- 每个建议的3个步骤要循序渐进、易于实施
|
||||
@@ -1,23 +1,300 @@
|
||||
你将收到一组用户历史记忆原始数据(来源于 Neo4j)
|
||||
你将收到一条冲突判定对象:{{ data }}。
|
||||
任务:分析冲突产生原因,给出解决方案,并生成设为失效后的记忆。
|
||||
需要检测冲突对象:{{ statement_databasets }}
|
||||
以及需要识别的冲突对象为:{{ baseline }}
|
||||
记忆审核开关:{{ memory_verify }}(取值为 true / false)
|
||||
|
||||
角色:
|
||||
- 你是数据领域中解决数据冲突的专家
|
||||
|
||||
任务:分析冲突产生原因,按冲突类型分组处理,为每种冲突类型生成独立的解决方案。
|
||||
|
||||
数据的结构:
|
||||
statement_databasets里面statement_name是输入的句子,statement_id是连接data里面的statement_id,代表这个句子被拆分成几个实体;判定时需要结合句子的整体内容。
|
||||
需要根据以下内容做处理(冲突检测、记忆审核、记忆的质量评估),data里面的statement_created_at是用户输入的时间
|
||||
|
||||
**处理模式**:
|
||||
- 当memory_verify为false时:仅处理数据冲突
|
||||
- 当memory_verify为true时:处理数据冲突 + 隐私信息脱敏
|
||||
|
||||
## 分组处理原则
|
||||
|
||||
**冲突类型识别与分组**:
|
||||
1. **日期冲突**:
|
||||
1.1.涉及用户生日的不同日期记录(如2月10号 vs 2月16号),
|
||||
1.2.涉及同一活动的不同时间记录(如周五打球 vs 周六打球)
|
||||
2. **事实属性冲突**:
|
||||
2.1. **属性互斥**:同一实体的相反属性(喜欢↔不喜欢、有↔没有、是↔不是)
|
||||
2.2. **关系矛盾**:同一实体在相同语境下的不同关系描述
|
||||
2.3. **身份冲突**:同一实体被赋予不同的类型或角色
|
||||
3. **其他冲突类型/混合冲突(时间+事实)**:根据具体数据识别
|
||||
|
||||
**分组输出要求**:
|
||||
- 每种冲突类型生成一个独立的reflexion_result对象
|
||||
- 同一类型的多个冲突记录归并到一个结果中
|
||||
- 不同类型的冲突分别处理,各自生成独立结果
|
||||
|
||||
## 冲突类型定义
|
||||
|
||||
### 时间冲突(TIME)
|
||||
时间维度冲突是指两个事件发生时间重叠,或者用户同一件事情和场景等情况下,时间出现了变化。
|
||||
|
||||
### 事实冲突(FACT)
|
||||
事实冲突是指同一事实对象(同一个人、同一个时间、同一个状态)但陈述内容相互矛盾,主要为真假不能共存的情况。
|
||||
### 混合冲突(HYBRID)
|
||||
检测所有类型的冲突,包括但不限于时间冲突和事实冲突:检测任何逻辑上不一致或相互矛盾的记录
|
||||
{% if memory_verify %}
|
||||
## 隐私信息处理(memory_verify为true时启用)
|
||||
|
||||
### 隐私信息识别
|
||||
需要识别并处理以下类型的隐私信息:
|
||||
|
||||
1. **身份证信息**:包含身份证号码、身份证相关描述
|
||||
2. **手机号码**:包含手机号、电话号码等联系方式
|
||||
3. **社交账号**:包含微信号、QQ号、邮箱地址等社交平台信息
|
||||
4. **银行信息**:包含银行卡号、账户信息、支付信息
|
||||
5. **税务信息**:包含税号、纳税信息、发票信息
|
||||
6. **贷款信息**:包含贷款记录、信贷信息、借款信息
|
||||
7. **其他敏感信息**:包含密码、PIN码、验证码等安全信息
|
||||
|
||||
### 隐私数据脱敏规则
|
||||
对于检测到的隐私信息,按以下规则进行脱敏处理:
|
||||
|
||||
**数字类隐私信息脱敏**:
|
||||
- 保留前三位和后四位,中间用*代替
|
||||
- 示例:手机号13812345678 → 138****5678
|
||||
- 示例:身份证110101199001011234 → 110***********1234
|
||||
- 示例:银行卡6222021234567890 → 622***********7890
|
||||
|
||||
**文本类隐私信息脱敏**:
|
||||
- 社交账号:保留前三后四位字符,中间用*代替
|
||||
- 示例:微信号user123456 → use****3456
|
||||
- 示例:邮箱zhang.san@example.com → zha****@example.com
|
||||
|
||||
**脱敏处理字段**:
|
||||
- name字段:如包含隐私信息需脱敏
|
||||
- entity1_name字段:如包含隐私信息需脱敏
|
||||
- entity2_name字段:如包含隐私信息需脱敏
|
||||
- description字段:如包含隐私信息需脱敏
|
||||
{% endif %}
|
||||
|
||||
## 工作步骤
|
||||
|
||||
### 第一步:分析冲突类型匹配
|
||||
首先判断输入的冲突数据是否符合baseline要求的类型:
|
||||
|
||||
**类型匹配规则**:
|
||||
- 如果baseline是"TIME":只处理时间相关的冲突(涉及时间表达式、日期、时间点的冲突)
|
||||
- 如果baseline是"FACT":只处理事实相关的冲突(属性矛盾、关系冲突、描述不一致)
|
||||
- 如果baseline是"HYBRID":处理所有类型的冲突,也可以当作混合冲突类型处理
|
||||
|
||||
**类型识别**:
|
||||
- 时间冲突标识:entity2的entity_type包含"TimeExpression"、"TemporalExpression",或entity2_name包含时间词汇(周一到周日、月份日期等)
|
||||
- 事实冲突标识:相同实体的不同属性描述、互斥的关系陈述
|
||||
|
||||
**重要**:如果输入的冲突类型与baseline不匹配,必须输出空结果(resolved为null)
|
||||
|
||||
### 第二步:筛选并分组冲突数据
|
||||
按冲突类型对数据进行分组:
|
||||
|
||||
**分组策略**:
|
||||
1. **时间冲突组**:筛选涉及用户时间的所有记录
|
||||
2. **活动时间冲突组**:筛选涉及同一活动不同时间的记录
|
||||
3. **事实冲突组**:筛选涉及同一实体不同属性的记录
|
||||
4. **其他冲突组**:其他类型的冲突记录
|
||||
|
||||
**筛选条件**:
|
||||
- 只处理与baseline匹配的冲突类型
|
||||
- 相同entity1_name但entity2_name不同的记录
|
||||
- 相同关系但描述矛盾的记录
|
||||
- 时间逻辑不一致的记录
|
||||
|
||||
### 第三步:冲突解决策略
|
||||
**不可以解决的冲突情况**:
|
||||
1. 数据被判定为正确的情况下,不可以进行修改
|
||||
**仅当冲突类型与baseline匹配时**,对筛选出的冲突数据进行处理:
|
||||
|
||||
**智能解决策略**:
|
||||
1. **分析冲突数据**:识别哪些记录是正确的,哪些是错误的,需要结合statement_databasets的输入原文来判定
|
||||
2. **判断正确答案是否存在**:
|
||||
- 如果正确答案已存在于data中:只需将错误记录的expired_at设为当前日期(2025-12-16T12:00:00)
|
||||
- 如果正确答案已存在于data中:错误记录的expired_at已经设为日期,则不需要对正确的数据进行修改
|
||||
- 如果正确答案不存在于data中:需要修改现有记录的内容以包含正确信息
|
||||
|
||||
{% if memory_verify %}
|
||||
**隐私处理集成**:
|
||||
- 在处理冲突的同时,需要对涉及的记录进行隐私脱敏
|
||||
- 脱敏处理应该在冲突解决之后进行,确保最终输出的记录都已脱敏
|
||||
- 在change字段中记录隐私脱敏的变更
|
||||
{% endif %}
|
||||
|
||||
**具体处理规则**:
|
||||
|
||||
**情况1:正确答案存在于data中**
|
||||
- 保留正确的记录不变
|
||||
- 基于时间关系的冲突:
|
||||
需要只修改错误记录的expired_at为当前时间(2025-12-16T12:00:00)
|
||||
- 基于事实的关系冲突
|
||||
- resolved.resolved_memory只包含被设为失效的错误记录
|
||||
- change字段只记录expired_at的变更:`[{"expired_at": "2025-12-16T12:00:00"}]`(注意:如果已存在时间,则不需要对其修改,也不需要变更时间)
|
||||
|
||||
**情况2:正确答案不存在于data中**
|
||||
- 选择最合适的记录进行修改
|
||||
- 更新该记录的相关字段:
|
||||
- description字段:添加或修改描述信息{% if memory_verify %}(如包含隐私信息,需脱敏处理){% endif %}
|
||||
- name字段:修改名称字段{% if memory_verify %}(如需要,包含隐私信息时需脱敏){% endif %}
|
||||
- resolved.resolved_memory包含修改后的完整记录{% if memory_verify %}(已脱敏){% endif %}
|
||||
- change字段记录所有被修改的字段{% if memory_verify %},包括脱敏变更{% endif %},例如:`[{"description": "新描述"{% if memory_verify %}, "entity2_name": "138****5678"{% endif %}}]`
|
||||
|
||||
**重要原则**:
|
||||
- **只输出需要修改的记录**:resolved.resolved_memory只包含实际需要修改的数据
|
||||
- **优先保留策略**:时间冲突保留最可信的created_at时间的记录,事实冲突选择最新且可信度最高的记录
|
||||
- **精确记录变更**:change字段必须包含记录ID、字段名称、新值和旧值
|
||||
{% if memory_verify %}- **隐私保护优先**:所有输出的记录必须完成隐私脱敏处理
|
||||
- **脱敏变更记录**:隐私脱敏的变更也必须在change字段中详细记录{% endif %}
|
||||
- **不可修改数据**:数据被判定为正确时,不可以进行修改,如果没有数据可输出空
|
||||
|
||||
**变更记录格式**:
|
||||
```json
|
||||
"change": [
|
||||
{
|
||||
"field": [
|
||||
{"字段名1": "修改后的值1"},
|
||||
{"字段名2": "修改后的值2"}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
**类型不匹配处理**:
|
||||
- 如果冲突类型与baseline不匹配,resolved必须设为null
|
||||
- reflexion.reason说明类型不匹配的原因
|
||||
- reflexion.solution说明无需处理
|
||||
|
||||
### 第四步:输出解决方案
|
||||
|
||||
## 输出要求
|
||||
**嵌套字段映射**(系统会自动处理):
|
||||
- `entity2.name` → 自动映射为 `name`
|
||||
- `entity1.name` → 自动映射为 `name`
|
||||
- `entity1.description` → 自动映射为 `description`
|
||||
- `entity2.description` → 自动映射为 `description`
|
||||
|
||||
返回数据格式以json方式输出:
|
||||
- 必须通过json.loads()的格式支持的形式输出
|
||||
- 响应必须是与此确切模式匹配的有效JSON对象
|
||||
- 不要在JSON之前或之后包含任何文本
|
||||
|
||||
JSON格式要求:
|
||||
1. JSON结构仅使用标准ASCII双引号(")
|
||||
2. 如果提取的语句文本包含引号,请使用反斜杠(\")正确转义
|
||||
3. 确保所有JSON字符串都正确关闭并以逗号分隔
|
||||
4. JSON字符串值中不包括换行符
|
||||
5. 不允许输出```json```相关符号
|
||||
|
||||
仅输出一个合法 JSON 对象,严格遵循下述结构:
|
||||
|
||||
**输出格式:按冲突类型分组的列表**
|
||||
{
|
||||
"conflict": 与输入同结构,包含 data 与 conflict_memory,
|
||||
"reflexion": { "reason": string, "solution": string },
|
||||
"resolved": {
|
||||
"original_memory_id": 被设为失效的记忆 id,
|
||||
"resolved_memory": 完整的设为失效后的记忆对象
|
||||
}
|
||||
"results": [
|
||||
{
|
||||
"conflict": {
|
||||
"data": [该冲突类型相关的数据记录],
|
||||
"conflict": true
|
||||
},
|
||||
"reflexion": {
|
||||
"reason": "该冲突类型的原因分析",
|
||||
"solution": "该冲突类型的解决方案"
|
||||
},
|
||||
"resolved": {
|
||||
"original_memory_id": "被设为失效的记忆id",
|
||||
"resolved_memory": {
|
||||
"entity1_name": "实体1名称",
|
||||
"entity2_name": "实体2名称",
|
||||
"description": "描述信息",
|
||||
"statement_id": "陈述ID",
|
||||
"created_at": "创建时间",
|
||||
"expired_at": "过期时间",
|
||||
"relationship_type": "关系类型",
|
||||
"relationship": {},
|
||||
"entity2": {...}
|
||||
},
|
||||
"change": [
|
||||
{
|
||||
"field": [
|
||||
{"字段名1": "修改后的值1"},
|
||||
{"字段名2": "修改后的值2"}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"type": "reflexion_result"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
**示例:多种冲突类型的输出**
|
||||
{
|
||||
"results": [
|
||||
{
|
||||
"conflict": {
|
||||
"data": [生日冲突相关的记录],
|
||||
"conflict": true
|
||||
},
|
||||
"reflexion": {
|
||||
"reason": "检测到生日冲突:用户同时关联2月10号和2月16号两个不同日期",
|
||||
"solution": "保留最新记录(2月16号),将旧记录(2月10号)设为失效"
|
||||
},
|
||||
"resolved": {
|
||||
"original_memory_id": "df066210883545a08e727ccd8ad4ec77",
|
||||
"resolved_memory": {...},
|
||||
"change": [
|
||||
{
|
||||
"field": [
|
||||
{"expired_at": "2025-12-16T12:00:00"}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"type": "reflexion_result"
|
||||
},
|
||||
{
|
||||
"conflict": {
|
||||
"data": [篮球时间冲突相关的记录],
|
||||
"conflict": true
|
||||
},
|
||||
"reflexion": {
|
||||
"reason": "检测到活动时间冲突:用户打篮球时间存在周五和周六的冲突",
|
||||
"solution": "保留最可信的时间记录,将冲突记录设为失效"
|
||||
},
|
||||
"resolved": {
|
||||
"original_memory_id": "另一个记录ID",
|
||||
"resolved_memory": {...},
|
||||
"change": [
|
||||
{
|
||||
"field": [
|
||||
{"description": "使用系统的个人,指代说话者本人,篮球时间为周六"},
|
||||
{"entity2_name": "周六"}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"type": "reflexion_result"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
必须遵守:
|
||||
- 只输出 JSON,不要添加解释或多余文本。
|
||||
- 使用标准双引号,必要时对内部引号进行转义。
|
||||
- 字段名与结构必须与给定模式一致。
|
||||
- 当 conflict 为 false 时,resolved 必须为 null。
|
||||
- 其中 conflict.data 必须为数组形式,即使只有一个对象也需使用 [ ] 包裹。
|
||||
- 只输出 JSON,不要添加解释或多余文本
|
||||
- 使用标准双引号,必要时对内部引号进行转义
|
||||
- 字段名与结构必须与给定模式一致
|
||||
- **输出必须是results数组格式**,每个冲突类型作为一个独立的对象
|
||||
- **按冲突类型分组**:相同类型的冲突记录归并到一个result对象中
|
||||
- **每个result对象的conflict.data**只包含该冲突类型相关的记录
|
||||
- **resolved.resolved_memory 只包含需要修改的记录**,不需要修改的记录不要输出
|
||||
- **resolved.change 必须包含详细的变更信息**:field数组包含所有被修改的字段及其新值
|
||||
- 如果某个冲突类型经分析无需修改任何数据,该类型的resolved 必须为 null
|
||||
- 如果与baseline不匹配的冲突类型,不要在results中包含该类型
|
||||
|
||||
模式参考:
|
||||
[
|
||||
{{ json_schema }}
|
||||
]
|
||||
{{ json_schema }}
|
||||
@@ -7,36 +7,50 @@ from typing import List, Dict, Any
|
||||
prompt_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
||||
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
|
||||
|
||||
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any]) -> str:
|
||||
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any],
|
||||
baseline: str = "TIME",
|
||||
memory_verify: bool = False,quality_assessment:bool = False,statement_databasets: List[str] = []) -> str:
|
||||
"""
|
||||
Renders the evaluate prompt using the evaluate.jinja2 template.
|
||||
Renders the evaluate prompt using the evaluate_optimized.jinja2 template.
|
||||
|
||||
Args:
|
||||
evaluate_data: The data to evaluate
|
||||
schema: The JSON schema to use for the output.
|
||||
baseline: The baseline type for conflict detection (TIME/FACT/TIME-FACT)
|
||||
memory_verify: Whether to enable memory verification for privacy detection
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
template = prompt_env.get_template("evaluate.jinja2")
|
||||
|
||||
rendered_prompt = template.render(evaluate_data=evaluate_data, json_schema=schema)
|
||||
|
||||
rendered_prompt = template.render(
|
||||
evaluate_data=evaluate_data,
|
||||
json_schema=schema,
|
||||
baseline=baseline,
|
||||
memory_verify=memory_verify,
|
||||
quality_assessment=quality_assessment,
|
||||
statement_databasets=statement_databasets
|
||||
)
|
||||
return rendered_prompt
|
||||
|
||||
async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any]) -> str:
|
||||
async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any], baseline: str, memory_verify: bool = False,
|
||||
statement_databasets: List[str] = []) -> str:
|
||||
"""
|
||||
Renders the reflexion prompt using the extract_temporal.jinja2 template.
|
||||
Renders the reflexion prompt using the reflexion_optimized.jinja2 template.
|
||||
|
||||
Args:
|
||||
data: The data to reflex on.
|
||||
schema: The JSON schema to use for the output.
|
||||
baseline: The baseline type for conflict resolution.
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as a string.
|
||||
"""
|
||||
template = prompt_env.get_template("reflexion.jinja2")
|
||||
|
||||
rendered_prompt = template.render(data=data, json_schema=schema)
|
||||
rendered_prompt = template.render(data=data, json_schema=schema,
|
||||
baseline=baseline,memory_verify=memory_verify,
|
||||
statement_databasets=statement_databasets)
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Iterator, AsyncIterator, List, Optional
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.outputs import LLMResult, GenerationChunk
|
||||
|
||||
from app.core.models import RedBearModelConfig, RedBearModelFactory, get_provider_llm_class
|
||||
from app.models.models_model import ModelType
|
||||
@@ -10,21 +10,36 @@ from app.models.models_model import ModelType
|
||||
|
||||
class RedBearLLM(BaseLLM):
|
||||
"""
|
||||
RedBear LLM 模型包装器 - 完全动态代理实现
|
||||
RedBear LLM Model Wrapper
|
||||
|
||||
这个包装器自动将所有方法调用委托给内部模型,
|
||||
同时提供优雅的回退机制和错误处理。
|
||||
This wrapper provides a unified interface to access different LLM providers,
|
||||
while maintaining all LangChain functionality, including streaming output.
|
||||
|
||||
Features:
|
||||
- Support for multiple LLM providers (OpenAI, Qwen, Ollama, etc.)
|
||||
- Full streaming output support
|
||||
- Elegant error handling and fallback mechanism
|
||||
- Automatic proxying of all underlying model methods and attributes
|
||||
"""
|
||||
|
||||
def __init__(self, config: RedBearModelConfig, type: ModelType=ModelType.LLM):
|
||||
self._model = self._create_model(config, type)
|
||||
def __init__(self, config: RedBearModelConfig, type: ModelType = ModelType.LLM):
|
||||
"""Initialize RedBear LLM wrapper
|
||||
|
||||
Args:
|
||||
config: Model configuration
|
||||
type: Model type (LLM or CHAT)
|
||||
"""
|
||||
super().__init__()
|
||||
self._config = config
|
||||
self._model = self._create_model(config, type)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""返回LLM类型标识符"""
|
||||
return self._model._llm_type
|
||||
"""Return LLM type identifier"""
|
||||
return getattr(self._model, '_llm_type', 'redbear_llm')
|
||||
|
||||
# ==================== Core Methods (Required by BaseLLM) ====================
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
@@ -32,7 +47,7 @@ class RedBearLLM(BaseLLM):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any
|
||||
) -> LLMResult:
|
||||
"""同步生成文本"""
|
||||
"""Synchronous text generation (required by BaseLLM)"""
|
||||
return self._model._generate(prompts, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
async def _agenerate(
|
||||
@@ -42,92 +57,233 @@ class RedBearLLM(BaseLLM):
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any
|
||||
) -> LLMResult:
|
||||
"""异步生成文本"""
|
||||
"""Asynchronous text generation (required by BaseLLM)"""
|
||||
return await self._model._agenerate(prompts, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
# 关键:覆盖 invoke/ainvoke,直接委托到底层模型,避免 BaseLLM 的字符串化行为
|
||||
# ==================== Advanced Methods (Support Message Lists) ====================
|
||||
|
||||
def invoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any:
|
||||
"""直接调用底层模型以支持 ChatPrompt 和消息列表。"""
|
||||
"""Synchronous model invocation
|
||||
|
||||
Supports various input formats including strings and message lists.
|
||||
Directly delegates to the underlying model to avoid BaseLLM's string conversion.
|
||||
|
||||
Args:
|
||||
input: Input (string, message list, etc.)
|
||||
config: Runtime configuration
|
||||
**kwargs: Additional arguments
|
||||
|
||||
Returns:
|
||||
Model response
|
||||
"""
|
||||
try:
|
||||
return self._model.invoke(input, config=config, **kwargs)
|
||||
except AttributeError as e:
|
||||
# 只在属性错误时回退(说明底层模型不支持该方法)
|
||||
if 'invoke' in str(e):
|
||||
# Underlying model doesn't support invoke, fallback to parent implementation
|
||||
return super().invoke(input, config=config, **kwargs)
|
||||
# 其他 AttributeError 直接抛出
|
||||
raise
|
||||
except Exception:
|
||||
# 其他所有异常(包括 ValidationException)直接抛出,不回退
|
||||
# Other exceptions are raised directly
|
||||
raise
|
||||
|
||||
async def ainvoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any:
|
||||
"""异步直接调用底层模型以支持 ChatPrompt 和消息列表。"""
|
||||
"""Asynchronous model invocation
|
||||
|
||||
Supports various input formats including strings and message lists.
|
||||
Directly delegates to the underlying model to avoid BaseLLM's string conversion.
|
||||
|
||||
Args:
|
||||
input: Input (string, message list, etc.)
|
||||
config: Runtime configuration
|
||||
**kwargs: Additional arguments
|
||||
|
||||
Returns:
|
||||
Model response
|
||||
"""
|
||||
try:
|
||||
return await self._model.ainvoke(input, config=config, **kwargs)
|
||||
except AttributeError as e:
|
||||
# 只在属性错误时回退(说明底层模型不支持该方法)
|
||||
if 'ainvoke' in str(e):
|
||||
# Underlying model doesn't support ainvoke, fallback to parent implementation
|
||||
return await super().ainvoke(input, config=config, **kwargs)
|
||||
# 其他 AttributeError 直接抛出
|
||||
raise
|
||||
except Exception:
|
||||
# 其他所有异常(包括 ValidationException)直接抛出,不回退
|
||||
# Other exceptions are raised directly
|
||||
raise
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""
|
||||
动态代理:将所有未定义的属性和方法调用委托给内部模型
|
||||
# ==================== Streaming Methods (Critical) ====================
|
||||
|
||||
def stream(
|
||||
self,
|
||||
input: Any,
|
||||
config: Optional[dict] = None,
|
||||
*,
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any
|
||||
) -> Iterator[GenerationChunk]:
|
||||
"""Synchronous streaming model invocation
|
||||
|
||||
这是最优雅的包装器实现方式,完全避免了方法重复定义
|
||||
"""
|
||||
# 处理特殊属性以避免递归
|
||||
if name in ('__isabstractmethod__', '__dict__', '__class__'):
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
|
||||
Args:
|
||||
input: Input (string, message list, etc.)
|
||||
config: Runtime configuration
|
||||
stop: List of stop words
|
||||
**kwargs: Additional arguments
|
||||
|
||||
# 检查内部模型是否有该属性(使用安全的方式避免递归)
|
||||
Yields:
|
||||
GenerationChunk: Generated text chunks
|
||||
"""
|
||||
try:
|
||||
yield from self._model.stream(input, config=config, stop=stop, **kwargs)
|
||||
except AttributeError as e:
|
||||
if 'stream' in str(e):
|
||||
# Underlying model doesn't support stream, fallback to parent implementation
|
||||
yield from super().stream(input, config=config, stop=stop, **kwargs)
|
||||
else:
|
||||
raise
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
input: Any,
|
||||
config: Optional[dict] = None,
|
||||
*,
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any
|
||||
) -> AsyncIterator[GenerationChunk]:
|
||||
"""Asynchronous streaming model invocation
|
||||
|
||||
This is the core method for streaming output. It directly proxies to the
|
||||
underlying model's astream method, maintaining generator characteristics
|
||||
to ensure each chunk is delivered in real-time.
|
||||
|
||||
Args:
|
||||
input: Input (string, message list, etc.)
|
||||
config: Runtime configuration
|
||||
stop: List of stop words
|
||||
**kwargs: Additional arguments
|
||||
|
||||
Yields:
|
||||
GenerationChunk: Generated text chunks
|
||||
"""
|
||||
try:
|
||||
async for chunk in self._model.astream(input, config=config, stop=stop, **kwargs):
|
||||
yield chunk
|
||||
except AttributeError as e:
|
||||
if 'astream' in str(e):
|
||||
# Underlying model doesn't support astream, fallback to parent implementation
|
||||
async for chunk in super().astream(input, config=config, stop=stop, **kwargs):
|
||||
yield chunk
|
||||
else:
|
||||
raise
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
# ==================== Dynamic Proxy ====================
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Dynamic proxy: delegate undefined attributes and method calls to internal model
|
||||
|
||||
This method allows RedBearLLM to transparently access all attributes and methods
|
||||
of the underlying model without explicitly defining each one.
|
||||
|
||||
Args:
|
||||
name: Attribute or method name
|
||||
|
||||
Returns:
|
||||
Attribute value or method
|
||||
|
||||
Raises:
|
||||
AttributeError: If attribute doesn't exist
|
||||
"""
|
||||
# Avoid recursion: raise error directly for special attributes
|
||||
if name in ('__isabstractmethod__', '__dict__', '__class__', '_model', '_config'):
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
|
||||
|
||||
# Try to get attribute from internal model
|
||||
try:
|
||||
# 使用 object.__getattribute__ 来安全地检查内部模型的属性
|
||||
attr = object.__getattribute__(self._model, name)
|
||||
|
||||
# 如果是方法,返回一个包装器来处理调用
|
||||
# If it's callable (a method)
|
||||
if callable(attr):
|
||||
# 流式方法直接返回,不包装(保持生成器特性)
|
||||
if name in ('_stream', '_astream', 'stream', 'astream'):
|
||||
# Streaming methods are returned directly to maintain generator characteristics
|
||||
# Note: Although we've explicitly implemented stream/astream,
|
||||
# this is kept to handle internal methods like _stream/_astream
|
||||
if name in ('_stream', '_astream'):
|
||||
return attr
|
||||
|
||||
# 非流式方法使用包装器处理异常
|
||||
# Wrap other methods for easier debugging and error handling
|
||||
def method_wrapper(*args, **kwargs):
|
||||
return attr(*args, **kwargs)
|
||||
try:
|
||||
return attr(*args, **kwargs)
|
||||
except Exception:
|
||||
# Can add logging or error handling here
|
||||
raise
|
||||
|
||||
# 保持方法的元信息
|
||||
# Preserve method metadata
|
||||
method_wrapper.__name__ = name
|
||||
method_wrapper.__doc__ = getattr(attr, '__doc__', f"Delegated method: {name}")
|
||||
return method_wrapper
|
||||
|
||||
# 如果是普通属性,直接返回
|
||||
# If it's a regular attribute, return directly
|
||||
return attr
|
||||
|
||||
except AttributeError:
|
||||
# 内部模型没有该属性,尝试回退实现
|
||||
# Internal model doesn't have this attribute either
|
||||
pass
|
||||
|
||||
# 检查是否有回退方法(使用安全的方式避免递归)
|
||||
# Check if there's a fallback method
|
||||
fallback_name = f'_fallback_{name}'
|
||||
try:
|
||||
fallback_method = object.__getattribute__(self, fallback_name)
|
||||
return fallback_method
|
||||
return object.__getattribute__(self, fallback_name)
|
||||
except AttributeError:
|
||||
# 没有回退方法,抛出适当的错误
|
||||
pass
|
||||
|
||||
# 如果都没有,抛出适当的错误
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
|
||||
# Nothing found, raise error
|
||||
raise AttributeError(
|
||||
f"'{type(self).__name__}' object has no attribute '{name}'. "
|
||||
f"The underlying model '{type(self._model).__name__}' also doesn't have this attribute."
|
||||
)
|
||||
|
||||
# ==================== Helper Methods ====================
|
||||
|
||||
def _create_model(self, config: RedBearModelConfig, type: ModelType) -> BaseLLM:
|
||||
"""创建内部模型实例"""
|
||||
"""Create internal model instance
|
||||
|
||||
Args:
|
||||
config: Model configuration
|
||||
type: Model type
|
||||
|
||||
Returns:
|
||||
Created model instance
|
||||
"""
|
||||
llm_class = get_provider_llm_class(config, type)
|
||||
model_params = RedBearModelFactory.get_model_params(config)
|
||||
return llm_class(**model_params)
|
||||
|
||||
|
||||
|
||||
|
||||
def get_config(self) -> RedBearModelConfig:
|
||||
"""Get model configuration
|
||||
|
||||
Returns:
|
||||
Model configuration object
|
||||
"""
|
||||
return self._config
|
||||
|
||||
def get_underlying_model(self) -> BaseLLM:
|
||||
"""Get underlying model instance
|
||||
|
||||
Returns:
|
||||
Underlying model instance
|
||||
"""
|
||||
return self._model
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return string representation of the object"""
|
||||
return (
|
||||
f"RedBearLLM("
|
||||
f"provider={self._config.provider}, "
|
||||
f"model={self._config.model_name}, "
|
||||
f"type={type(self._model).__name__}"
|
||||
f")"
|
||||
)
|
||||
@@ -1,12 +1,23 @@
|
||||
import xxhash
|
||||
from app.aioRedis import aio_redis_set, aio_redis_get
|
||||
import redis
|
||||
from app.core.config import settings
|
||||
|
||||
redis_client = redis.StrictRedis(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD,
|
||||
decode_responses=True,
|
||||
max_connections=30
|
||||
)
|
||||
|
||||
|
||||
def get_llm_cache(llmnm, txt, history, genconf):
|
||||
hasher = xxhash.xxh64()
|
||||
hasher.update((str(llmnm)+str(txt)+str(history)+str(genconf)).encode("utf-8"))
|
||||
hasher.update((str(llmnm) + str(txt) + str(history) + str(genconf)).encode("utf-8"))
|
||||
|
||||
k = hasher.hexdigest()
|
||||
bin = aio_redis_get(k)
|
||||
bin = redis_client.get(k)
|
||||
if not bin:
|
||||
return None
|
||||
return bin
|
||||
@@ -14,6 +25,6 @@ def get_llm_cache(llmnm, txt, history, genconf):
|
||||
|
||||
def set_llm_cache(llmnm, txt, v, history, genconf):
|
||||
hasher = xxhash.xxh64()
|
||||
hasher.update((str(llmnm)+str(txt)+str(history)+str(genconf)).encode("utf-8"))
|
||||
hasher.update((str(llmnm) + str(txt) + str(history) + str(genconf)).encode("utf-8"))
|
||||
k = hasher.hexdigest()
|
||||
aio_redis_set(k, v.encode("utf-8"), 24 * 3600)
|
||||
redis_client.set(k, v.encode("utf-8"), 24 * 3600)
|
||||
|
||||
@@ -119,7 +119,7 @@ def keyword_extraction(chat_mdl, content, topn=3):
|
||||
rendered_prompt = template.render(content=content, topn=topn)
|
||||
|
||||
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
|
||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||
_, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096))
|
||||
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
|
||||
if isinstance(kwd, tuple):
|
||||
kwd = kwd[0]
|
||||
@@ -194,7 +194,7 @@ def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
|
||||
)
|
||||
|
||||
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
|
||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||
_, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096))
|
||||
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.5})
|
||||
if isinstance(kwd, tuple):
|
||||
kwd = kwd[0]
|
||||
@@ -314,7 +314,7 @@ def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defi
|
||||
hist[-1]["content"] += user_prompt
|
||||
else:
|
||||
hist.append({"role": "user", "content": user_prompt})
|
||||
_, msg = message_fit_in(hist, chat_mdl.max_length)
|
||||
_, msg = message_fit_in(hist, getattr(chat_mdl, 'max_length', 8096))
|
||||
ans = chat_mdl.chat(msg[0]["content"], msg[1:])
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
return """
|
||||
@@ -341,7 +341,7 @@ def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defin
|
||||
params=json.dumps(params, ensure_ascii=False, indent=2),
|
||||
result=result)
|
||||
user_prompt = "→ Summary: "
|
||||
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
|
||||
_, msg = message_fit_in(form_message(system_prompt, user_prompt), getattr(chat_mdl, 'max_length', 8096))
|
||||
ans = chat_mdl.chat(msg[0]["content"], msg[1:])
|
||||
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
|
||||
@@ -350,7 +350,7 @@ def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[st
|
||||
template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
|
||||
system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
|
||||
user_prompt = " → rank: "
|
||||
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
|
||||
_, msg = message_fit_in(form_message(system_prompt, user_prompt), getattr(chat_mdl, 'max_length', 8096))
|
||||
ans = chat_mdl.chat(msg[0]["content"], msg[1:], stop="<|stop|>")
|
||||
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
|
||||
@@ -378,7 +378,7 @@ def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None):
|
||||
cached = get_llm_cache(chat_mdl.llm_name, system_prompt, user_prompt, gen_conf)
|
||||
if cached:
|
||||
return json_repair.loads(cached)
|
||||
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
|
||||
_, msg = message_fit_in(form_message(system_prompt, user_prompt), getattr(chat_mdl, 'max_length', 8096))
|
||||
ans = chat_mdl.chat(msg[0]["content"], msg[1:],gen_conf=gen_conf)
|
||||
ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
|
||||
try:
|
||||
@@ -641,7 +641,7 @@ def split_chunks(chunks, max_length: int):
|
||||
|
||||
|
||||
async def run_toc_from_text(chunks, chat_mdl, callback=None):
|
||||
input_budget = int(chat_mdl.max_length * INPUT_UTILIZATION) - num_tokens_from_string(
|
||||
input_budget = int(getattr(chat_mdl, 'max_length', 8096) * INPUT_UTILIZATION) - num_tokens_from_string(
|
||||
TOC_FROM_TEXT_USER + TOC_FROM_TEXT_SYSTEM
|
||||
)
|
||||
|
||||
|
||||
37
api/app/core/tools/__init__.py
Normal file
37
api/app/core/tools/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""工具管理核心模块"""
|
||||
|
||||
from .base import BaseTool, ToolResult, ToolParameter
|
||||
from .registry import ToolRegistry
|
||||
from .executor import ToolExecutor
|
||||
from .langchain_adapter import LangchainAdapter
|
||||
from .config_manager import ConfigManager
|
||||
from .chain_manager import ChainManager
|
||||
|
||||
# Optional tool families. If a subpackage fails to import, bind its class name
# to None so that importing this package still succeeds.
try:
    from .custom.base import CustomTool
except ImportError:
    CustomTool = None

try:
    from .mcp.base import MCPTool
except ImportError:
    MCPTool = None

# Names that are always exported.
__all__ = [
    "BaseTool",
    "ToolResult",
    "ToolParameter",
    "ToolRegistry",
    "ToolExecutor",
    "LangchainAdapter",
    "ConfigManager",
    "ChainManager"
]

# Advertise an optional class only when its import above succeeded.
__all__ += [
    optional_name
    for optional_name, optional_cls in (
        ("CustomTool", CustomTool),
        ("MCPTool", MCPTool),
    )
    if optional_cls
]
|
||||
302
api/app/core/tools/base.py
Normal file
302
api/app/core/tools/base.py
Normal file
@@ -0,0 +1,302 @@
|
||||
"""工具基础接口定义"""
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from enum import Enum
|
||||
|
||||
from app.models.tool_model import ToolType, ToolStatus
|
||||
|
||||
|
||||
class ParameterType(str, Enum):
    """Value types that a tool parameter can declare.

    Members also subclass ``str``, so each one compares equal to its raw
    string value (for example ``ParameterType.STRING == "string"``).
    """
    STRING = "string"
    INTEGER = "integer"
    NUMBER = "number"
    BOOLEAN = "boolean"
    ARRAY = "array"
    OBJECT = "object"
|
||||
|
||||
|
||||
class ToolParameter(BaseModel):
    """Declarative description of one parameter accepted by a tool.

    Only ``name`` and ``type`` are mandatory; every other field has a default.
    The constraint fields (``enum``, ``minimum``, ``maximum``, ``pattern``) are
    plain data here: this class defines no validators that apply them to a
    value.
    """
    name: str = Field(..., description="参数名称")  # parameter name (required)
    type: ParameterType = Field(..., description="参数类型")  # parameter type (required)
    description: str = Field("", description="参数描述")  # human-readable description
    required: bool = Field(False, description="是否必需")  # whether the caller must supply it
    default: Any = Field(None, description="默认值")  # default value
    enum: Optional[List[Any]] = Field(None, description="枚举值")  # allowed values, if restricted
    minimum: Optional[Union[int, float]] = Field(None, description="最小值")  # lower bound
    maximum: Optional[Union[int, float]] = Field(None, description="最大值")  # upper bound
    pattern: Optional[str] = Field(None, description="正则表达式模式")  # regular-expression pattern

    class Config:
        # Pydantic option: populate enum-typed fields (``type``) with the
        # member's value instead of the Enum instance.
        # NOTE(review): class-based ``Config`` is the pydantic v1 spelling;
        # confirm it matches the pydantic version this project pins.
        use_enum_values = True
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
    """Outcome of one tool execution.

    ``success`` and ``execution_time`` are mandatory. The two classmethods are
    convenience constructors for the success shape and the failure shape.
    """

    success: bool = Field(..., description="执行是否成功")
    data: Any = Field(None, description="返回数据")
    error: Optional[str] = Field(None, description="错误信息")
    error_code: Optional[str] = Field(None, description="错误代码")
    execution_time: float = Field(..., description="执行时间(秒)")
    token_usage: Optional[Dict[str, int]] = Field(None, description="Token使用情况")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="额外元数据")

    @classmethod
    def success_result(
        cls,
        data: Any,
        execution_time: float,
        token_usage: Optional[Dict[str, int]] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> "ToolResult":
        """Build a result with ``success=True``.

        Args:
            data: Payload returned by the tool.
            execution_time: Execution time in seconds.
            token_usage: Optional token usage counters.
            metadata: Optional extra metadata; a falsy value becomes ``{}``.

        Returns:
            A new ``ToolResult`` whose error fields keep their defaults.
        """
        fields = {
            "success": True,
            "data": data,
            "execution_time": execution_time,
            "token_usage": token_usage,
            "metadata": metadata if metadata else {},
        }
        return cls(**fields)

    @classmethod
    def error_result(
        cls,
        error: str,
        execution_time: float,
        error_code: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> "ToolResult":
        """Build a result with ``success=False``.

        Args:
            error: Error message.
            execution_time: Execution time in seconds.
            error_code: Optional error code.
            metadata: Optional extra metadata; a falsy value becomes ``{}``.

        Returns:
            A new ``ToolResult`` whose ``data`` and ``token_usage`` keep their
            defaults.
        """
        fields = {
            "success": False,
            "error": error,
            "error_code": error_code,
            "execution_time": execution_time,
            "metadata": metadata if metadata else {},
        }
        return cls(**fields)
|
||||
|
||||
|
||||
class ToolInfo(BaseModel):
    """Serializable summary of a tool: identity, type, parameters and status."""

    # Identity
    id: str = Field(..., description="工具ID")
    name: str = Field(..., description="工具名称")
    description: str = Field(..., description="工具描述")

    # Classification and interface
    tool_type: ToolType = Field(..., description="工具类型")
    version: str = Field(default="1.0.0", description="工具版本")
    parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数")

    # State and ownership
    status: ToolStatus = Field(default=ToolStatus.ACTIVE, description="工具状态")
    tags: List[str] = Field(default_factory=list, description="工具标签")
    tenant_id: Optional[str] = Field(default=None, description="租户ID")

    class Config:
        # pydantic option: populate enum-typed fields with the enum's value
        # rather than the enum member itself.
        use_enum_values = True
|
||||
|
||||
|
||||
class BaseTool(ABC):
    """Abstract base class shared by every tool implementation.

    Subclasses provide ``name``, ``description``, ``tool_type``,
    ``parameters`` and ``execute``. This class contributes metadata assembly
    (``get_info``), parameter validation and an exception-safe execution
    wrapper (``safe_execute``).
    """

    def __init__(self, tool_id: str, config: Dict[str, Any]):
        """Store the tool identity and configuration.

        Args:
            tool_id: Identifier of the tool.
            config: Tool configuration mapping. This class reads the
                ``version``, ``tags`` and ``tenant_id`` keys from it.
        """
        self.tool_id = tool_id
        self.config = config
        self._status = ToolStatus.ACTIVE

    @property
    @abstractmethod
    def name(self) -> str:
        """Tool name."""
        pass

    @property
    @abstractmethod
    def description(self) -> str:
        """Tool description."""
        pass

    @property
    @abstractmethod
    def tool_type(self) -> ToolType:
        """Tool type."""
        pass

    @property
    def version(self) -> str:
        """Tool version, read from the configuration (default ``"1.0.0"``)."""
        return self.config.get("version", "1.0.0")

    @property
    def status(self) -> ToolStatus:
        """Current tool status."""
        return self._status

    @status.setter
    def status(self, value: ToolStatus):
        """Replace the current tool status."""
        self._status = value

    @property
    @abstractmethod
    def parameters(self) -> List[ToolParameter]:
        """Definitions of the parameters the tool accepts."""
        pass

    @property
    def tags(self) -> List[str]:
        """Tool tags, read from the configuration (default: empty list)."""
        return self.config.get("tags", [])

    def get_info(self) -> ToolInfo:
        """Assemble a ``ToolInfo`` snapshot from the tool's properties."""
        return ToolInfo(
            id=self.tool_id,
            name=self.name,
            description=self.description,
            tool_type=self.tool_type,
            version=self.version,
            parameters=self.parameters,
            status=self.status,
            tags=self.tags,
            tenant_id=self.config.get("tenant_id")
        )

    def validate_parameters(self, parameters: Dict[str, Any]) -> Dict[str, str]:
        """Validate call arguments against the declared parameter definitions.

        Args:
            parameters: Arguments supplied by the caller, keyed by name.

        Returns:
            Mapping of parameter name to error message. An empty mapping
            means validation passed. Arguments that match no declared
            parameter are ignored.
        """
        errors: Dict[str, str] = {}
        definitions = self.parameters
        definitions_by_name = {p.name: p for p in definitions}

        # Required parameters must be present.
        for param_def in definitions:
            if param_def.required and param_def.name not in parameters:
                errors[param_def.name] = f"Required parameter '{param_def.name}' is missing"

        # Type and constraint checks for every declared argument supplied.
        for param_name, param_value in parameters.items():
            param_def = definitions_by_name.get(param_name)
            if param_def is None:
                continue

            if not self._validate_parameter_type(param_value, param_def):
                errors[param_name] = f"Parameter '{param_name}' has invalid type, expected {param_def.type}"
                # Constraint checks assume a correctly typed value: comparing
                # e.g. a str against a numeric bound raises TypeError, and a
                # constraint message would hide the more useful type error.
                continue

            constraint_error = self._validate_parameter_constraints(param_value, param_def)
            if constraint_error:
                errors[param_name] = constraint_error

        return errors

    def _validate_parameter_type(self, value: Any, param_def: ToolParameter) -> bool:
        """Return whether ``value`` has the Python type declared by ``param_def``.

        ``None`` is accepted only for parameters that are not required.
        Booleans are accepted only for BOOLEAN parameters.
        """
        if value is None:
            return not param_def.required

        type_mapping = {
            ParameterType.STRING: str,
            ParameterType.INTEGER: int,
            ParameterType.NUMBER: (int, float),
            ParameterType.BOOLEAN: bool,
            ParameterType.ARRAY: list,
            ParameterType.OBJECT: dict
        }

        expected_type = type_mapping.get(param_def.type)
        if not expected_type:
            # Unknown declared type: nothing to check against.
            return True

        # bool is a subclass of int, so a plain isinstance() check would let
        # True/False through as INTEGER or NUMBER values.
        if isinstance(value, bool):
            return expected_type is bool

        return isinstance(value, expected_type)

    def _validate_parameter_constraints(self, value: Any, param_def: ToolParameter) -> Optional[str]:
        """Check enum, numeric range and regex constraints.

        Returns:
            An error message for the first violated constraint, or ``None``
            when the value satisfies all of them (``None`` values are never
            constrained).
        """
        if value is None:
            return None

        # Enumerated values
        if param_def.enum and value not in param_def.enum:
            return f"Value must be one of {param_def.enum}"

        # Numeric range
        if param_def.type in [ParameterType.INTEGER, ParameterType.NUMBER]:
            if param_def.minimum is not None and value < param_def.minimum:
                return f"Value must be >= {param_def.minimum}"
            if param_def.maximum is not None and value > param_def.maximum:
                return f"Value must be <= {param_def.maximum}"

        # String pattern; re.match anchors at the start of the value only.
        if param_def.type == ParameterType.STRING and param_def.pattern:
            import re
            if not re.match(param_def.pattern, str(value)):
                return f"Value must match pattern: {param_def.pattern}"

        return None

    @abstractmethod
    async def execute(self, **kwargs) -> ToolResult:
        """Run the tool.

        Args:
            **kwargs: Tool arguments.

        Returns:
            The execution result.
        """
        pass

    async def safe_execute(self, **kwargs) -> ToolResult:
        """Run the tool with parameter validation and exception handling.

        Args:
            **kwargs: Tool arguments.

        Returns:
            The result of ``execute``; a ``VALIDATION_ERROR`` result when the
            arguments are invalid; or an ``EXECUTION_ERROR`` result when any
            exception is raised.
        """
        start_time = time.time()

        try:
            validation_errors = self.validate_parameters(kwargs)
            if validation_errors:
                execution_time = time.time() - start_time
                error_msg = "; ".join([f"{k}: {v}" for k, v in validation_errors.items()])
                return ToolResult.error_result(
                    error=f"Parameter validation failed: {error_msg}",
                    error_code="VALIDATION_ERROR",
                    execution_time=execution_time
                )

            result = await self.execute(**kwargs)
            return result

        except Exception as e:
            # Deliberate catch-all boundary: callers always get a ToolResult.
            execution_time = time.time() - start_time
            return ToolResult.error_result(
                error=str(e),
                error_code="EXECUTION_ERROR",
                execution_time=execution_time
            )

    def to_langchain_tool(self):
        """Convert this tool to the Langchain tool format."""
        # Imported lazily inside the method, as in the original code.
        from .langchain_adapter import LangchainAdapter
        return LangchainAdapter.convert_tool(self)

    def __repr__(self):
        return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>"
|
||||
17
api/app/core/tools/builtin/__init__.py
Normal file
17
api/app/core/tools/builtin/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""内置工具模块"""
|
||||
|
||||
from .base import BuiltinTool
|
||||
from .datetime_tool import DateTimeTool
|
||||
from .json_tool import JsonTool
|
||||
from .baidu_search_tool import BaiduSearchTool
|
||||
from .mineru_tool import MinerUTool
|
||||
from .textin_tool import TextInTool
|
||||
|
||||
__all__ = [
|
||||
"BuiltinTool",
|
||||
"DateTimeTool",
|
||||
"JsonTool",
|
||||
"BaiduSearchTool",
|
||||
"MinerUTool",
|
||||
"TextInTool"
|
||||
]
|
||||
334
api/app/core/tools/builtin/baidu_search_tool.py
Normal file
334
api/app/core/tools/builtin/baidu_search_tool.py
Normal file
@@ -0,0 +1,334 @@
|
||||
"""百度搜索工具 - 搜索引擎服务"""
|
||||
import time
|
||||
from typing import List, Dict, Any
|
||||
import aiohttp
|
||||
|
||||
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
|
||||
from .base import BuiltinTool
|
||||
|
||||
|
||||
class BaiduSearchTool(BuiltinTool):
    """Baidu search tool: web, news, image and video search.

    All searches go through the Baidu AI search endpoint called by
    ``_call_baidu_ai_search_api`` and differ only in the requested resource
    type and the ``top_k`` cap.
    """

    # Lower bound of the page_time range filter for each supported time filter.
    _TIME_FILTER_RANGES = {"day": "now-1d/d", "week": "now-1w/d", "month": "now-1M/d", "year": "now-1y/d"}

    @property
    def name(self) -> str:
        return "baidu_search_tool"

    @property
    def description(self) -> str:
        return "百度搜索 - 搜索引擎服务:网页搜索、新闻搜索、图片搜索、实时结果"

    def get_required_config_parameters(self) -> List[str]:
        """Configuration keys that must be set before the tool can run."""
        return ["api_key"]

    @property
    def parameters(self) -> List[ToolParameter]:
        """Parameter definitions accepted by ``execute``."""
        return [
            ToolParameter(
                name="query",
                type=ParameterType.STRING,
                description="搜索关键词",
                required=True
            ),
            ToolParameter(
                name="search_type",
                type=ParameterType.STRING,
                description="搜索类型",
                required=False,
                default="web",
                enum=["web", "news", "image", "video"]
            ),
            ToolParameter(
                name="page_size",
                type=ParameterType.INTEGER,
                description="每页结果数",
                required=False,
                default=10,
                minimum=1,
                maximum=50
            ),
            ToolParameter(
                name="page_num",
                type=ParameterType.INTEGER,
                description="页码(从1开始)",
                required=False,
                default=1,
                minimum=1,
                maximum=10
            ),
            ToolParameter(
                name="safe_search",
                type=ParameterType.BOOLEAN,
                description="是否启用安全搜索",
                required=False,
                default=True
            ),
            ToolParameter(
                name="region",
                type=ParameterType.STRING,
                description="搜索地区",
                required=False,
                default="cn",
                enum=["cn", "hk", "tw", "us", "jp", "kr"]
            ),
            ToolParameter(
                name="time_filter",
                type=ParameterType.STRING,
                description="时间过滤",
                required=False,
                enum=["all", "day", "week", "month", "year"]
            )
        ]

    async def execute(self, **kwargs) -> ToolResult:
        """Run a Baidu search.

        Args:
            **kwargs: ``query`` (required) plus the optional ``search_type``,
                ``page_size``, ``page_num``, ``safe_search``, ``region`` and
                ``time_filter`` arguments.

        Returns:
            A success result carrying the search payload, or an error result
            with code ``BAIDU_SEARCH_ERROR`` when anything fails.
        """
        start_time = time.time()

        try:
            query = kwargs.get("query")
            search_type = kwargs.get("search_type", "web")
            page_size = kwargs.get("page_size", 10)
            page_num = kwargs.get("page_num", 1)
            safe_search = kwargs.get("safe_search", True)
            region = kwargs.get("region", "cn")
            time_filter = kwargs.get("time_filter")

            if not query:
                raise ValueError("query 参数是必需的")

            # Dispatch to the search variant for the requested type.
            if search_type == "web":
                result = await self._web_search(query, page_size, page_num, safe_search, region, time_filter)
            elif search_type == "news":
                result = await self._news_search(query, page_size, page_num, region, time_filter)
            elif search_type == "image":
                result = await self._image_search(query, page_size, page_num, safe_search)
            elif search_type == "video":
                result = await self._video_search(query, page_size, page_num, safe_search)
            else:
                raise ValueError(f"不支持的搜索类型: {search_type}")

            execution_time = time.time() - start_time
            return ToolResult.success_result(
                data=result,
                execution_time=execution_time
            )

        except Exception as e:
            execution_time = time.time() - start_time
            return ToolResult.error_result(
                error=str(e),
                error_code="BAIDU_SEARCH_ERROR",
                execution_time=execution_time
            )

    async def _run_search(self, query: str, page_size: int, page_num: int,
                          resource_type: str, max_top_k: int, search_type: str,
                          time_filter: str = None) -> Dict[str, Any]:
        """Build the request payload, call the API and normalise the response.

        Args:
            query: Search keywords, sent as the single user message.
            page_size: Requested result count; capped at ``max_top_k``.
            page_num: Echoed back in the result; not sent to the API.
            resource_type: Value for the ``resource_type_filter`` ``type`` key.
            max_top_k: Upper bound applied to ``top_k``.
            search_type: Label reported in the returned ``search_type`` key.
            time_filter: Optional time filter; only keys present in
                ``_TIME_FILTER_RANGES`` add filters to the payload.

        Returns:
            Dict with the ranked results, the answer text and raw references.
        """
        payload = {
            "messages": [{"role": "user", "content": query}],
            "edition": "standard",
            "search_source": "baidu_search_v2",
            "resource_type_filter": [{"type": resource_type, "top_k": min(page_size, max_top_k)}],
            "enable_full_content": True
        }

        if time_filter:
            lower_bound = self._TIME_FILTER_RANGES.get(time_filter)
            if lower_bound is not None:
                payload["search_filter"] = {"range": {"page_time": {"gte": lower_bound, "lt": "now/d"}}}
                payload["search_recency_filter"] = time_filter

        results = await self._call_baidu_ai_search_api(payload)

        search_results = []
        if "references" in results:
            for rank, item in enumerate(results["references"], start=1):
                search_results.append({
                    "title": item.get("title", ""),
                    "url": item.get("url", ""),
                    "snippet": item.get("content", ""),
                    "display_url": item.get("url", ""),
                    "rank": rank
                })

        return {
            "search_type": search_type,
            "query": query,
            "total_results": len(search_results),
            "page_num": page_num,
            "page_size": page_size,
            "results": search_results,
            "answer": results.get("result", ""),
            "references": results.get("references", [])
        }

    async def _web_search(self, query: str, page_size: int, page_num: int,
                          safe_search: bool, region: str, time_filter: str = None) -> Dict[str, Any]:
        """Web search. ``safe_search`` and ``region`` are accepted but not sent to the API."""
        return await self._run_search(query, page_size, page_num, "web", 50, "web", time_filter)

    async def _news_search(self, query: str, page_size: int, page_num: int,
                           region: str, time_filter: str = None) -> Dict[str, Any]:
        """News search. ``region`` is accepted but not sent to the API."""
        # NOTE(review): the resource type sent to the API is "new", exactly as
        # in the original code; confirm against the API reference whether
        # "news" was intended. The returned search_type label is "news" so it
        # matches the value callers pass in.
        return await self._run_search(query, page_size, page_num, "new", 50, "news", time_filter)

    async def _image_search(self, query: str, page_size: int, page_num: int,
                            safe_search: bool) -> Dict[str, Any]:
        """Image search. ``safe_search`` is accepted but not sent to the API."""
        return await self._run_search(query, page_size, page_num, "image", 30, "image")

    async def _video_search(self, query: str, page_size: int, page_num: int,
                            safe_search: bool) -> Dict[str, Any]:
        """Video search. ``safe_search`` is accepted but not sent to the API."""
        return await self._run_search(query, page_size, page_num, "video", 10, "video")

    async def _call_baidu_ai_search_api(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` to the Baidu AI search endpoint and return the JSON body.

        Raises:
            ValueError: If no ``api_key`` is configured.
            Exception: If the response status is not 200.
        """
        api_key = self.get_config_parameter("api_key")
        if not api_key:
            raise ValueError("百度搜索API密钥未配置")

        url = "https://qianfan.baidubce.com/v2/ai_search/chat/completions"
        headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {api_key}'
        }

        # 30-second overall budget for the request.
        timeout = aiohttp.ClientTimeout(total=30)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(url, headers=headers, json=payload) as response:
                if response.status != 200:
                    raise Exception(f"HTTP错误: {response.status}")
                return await response.json()

    async def test_connection(self) -> Dict[str, Any]:
        """Check that the configured API key can reach the search API.

        Returns:
            Dict with ``success`` and either ``message`` plus
            ``api_key_masked`` on success, or ``error`` on failure. Never
            raises.
        """
        try:
            api_key = self.get_config_parameter("api_key")
            if not api_key:
                return {
                    "success": False,
                    "error": "API密钥未配置"
                }

            # Minimal one-result web query used purely to validate the key.
            test_payload = {
                "messages": [{"role": "user", "content": "test"}],
                "edition": "standard",
                "search_source": "baidu_search_v2",
                "resource_type_filter": [{"type": "web", "top_k": 1}]
            }

            try:
                await self._call_baidu_ai_search_api(test_payload)
                return {
                    "success": True,
                    "message": "连接测试成功",
                    # Only the first 8 characters of the key are exposed.
                    "api_key_masked": api_key[:8] + "***" if len(api_key) > 8 else "***"
                }
            except Exception as e:
                return {
                    "success": False,
                    "error": f"API连接失败: {str(e)}"
                }

        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }
|
||||
118
api/app/core/tools/builtin/base.py
Normal file
118
api/app/core/tools/builtin/base.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""内置工具基类"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, List
|
||||
|
||||
from app.models.tool_model import ToolType
|
||||
from app.core.tools.base import BaseTool, ToolResult, ToolParameter
|
||||
|
||||
|
||||
class BuiltinTool(BaseTool, ABC):
    """Base class for built-in tools.

    Adds configuration handling on top of ``BaseTool``: the ``parameters``
    entry of the tool config holds the tool's settings, and ``safe_execute``
    refuses to run until every required setting is present.
    """

    def __init__(self, tool_id: str, config: Dict[str, Any]):
        """Initialise the built-in tool.

        Args:
            tool_id: Identifier of the tool.
            config: Tool configuration; its ``parameters`` entry (default
                ``{}``) is kept as ``parameters_config``.
        """
        super().__init__(tool_id, config)
        self.parameters_config = config.get("parameters", {})

    @property
    def tool_type(self) -> ToolType:
        """Tool type; always ``ToolType.BUILTIN``."""
        return ToolType.BUILTIN

    @property
    @abstractmethod
    def name(self) -> str:
        """Tool name; must be implemented by subclasses."""

    @property
    @abstractmethod
    def description(self) -> str:
        """Tool description; must be implemented by subclasses."""

    @property
    @abstractmethod
    def parameters(self) -> List[ToolParameter]:
        """Tool parameter definitions; must be implemented by subclasses."""

    @abstractmethod
    async def execute(self, **kwargs) -> ToolResult:
        """Run the tool; must be implemented by subclasses.

        Args:
            **kwargs: Tool arguments.

        Returns:
            The execution result.
        """

    @property
    def is_configured(self) -> bool:
        """Whether every required configuration parameter has a truthy value."""
        return all(
            self.parameters_config.get(param)
            for param in self.get_required_config_parameters()
        )

    def get_required_config_parameters(self) -> List[str]:
        """Names of the configuration parameters the tool cannot run without.

        Returns:
            List of required configuration parameter names (empty here;
            subclasses override as needed).
        """
        return []

    def get_config_parameter(self, name: str, default: Any = None) -> Any:
        """Look up a configuration parameter.

        Args:
            name: Parameter name.
            default: Value returned when the parameter is absent.

        Returns:
            The configured value, or ``default``.
        """
        return self.parameters_config.get(name, default)

    def validate_configuration(self) -> tuple[bool, str]:
        """Validate the tool configuration.

        Returns:
            ``(is_valid, error_message)``; the message is empty when valid.
        """
        if self.is_configured:
            return True, ""

        missing_params = [
            p for p in self.get_required_config_parameters()
            if not self.parameters_config.get(p)
        ]
        return False, f"缺少必需的配置参数: {', '.join(missing_params)}"

    async def safe_execute(self, **kwargs) -> ToolResult:
        """Run the tool after checking its configuration.

        Args:
            **kwargs: Tool arguments.

        Returns:
            A ``CONFIGURATION_ERROR`` result when the configuration is
            invalid, otherwise whatever ``BaseTool.safe_execute`` returns.
        """
        is_valid, error_msg = self.validate_configuration()
        if is_valid:
            # Configuration is fine: defer to the parent's safe execution.
            return await super().safe_execute(**kwargs)

        return ToolResult.error_result(
            error=f"工具配置无效: {error_msg}",
            error_code="CONFIGURATION_ERROR",
            execution_time=0.0
        )
|
||||
307
api/app/core/tools/builtin/datetime_tool.py
Normal file
307
api/app/core/tools/builtin/datetime_tool.py
Normal file
@@ -0,0 +1,307 @@
|
||||
"""时间工具 - 日期时间处理"""
|
||||
import time
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import List
|
||||
import pytz
|
||||
|
||||
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
|
||||
from .base import BuiltinTool
|
||||
|
||||
|
||||
class DateTimeTool(BuiltinTool):
    """Date/time tool: formatting, timezone conversion, timestamp conversion
    and simple time arithmetic."""

    # Format applied when input_format / output_format is not supplied.
    _DEFAULT_FORMAT = "%Y-%m-%d %H:%M:%S"

    @property
    def name(self) -> str:
        return "datetime_tool"

    @property
    def description(self) -> str:
        return "时间工具 - 日期时间处理:提供时间格式转化、时区转换、时间戳转换、时间计算"

    @property
    def parameters(self) -> List[ToolParameter]:
        """Parameter definitions accepted by ``execute``."""
        return [
            ToolParameter(
                name="operation",
                type=ParameterType.STRING,
                description="操作类型",
                required=True,
                enum=["format", "convert_timezone", "timestamp_to_datetime", "datetime_to_timestamp", "calculate", "now"]
            ),
            ToolParameter(
                name="input_value",
                type=ParameterType.STRING,
                description="输入值(时间字符串或时间戳)",
                required=False
            ),
            ToolParameter(
                name="input_format",
                type=ParameterType.STRING,
                description="输入时间格式(如:%Y-%m-%d %H:%M:%S)",
                required=False,
                default="%Y-%m-%d %H:%M:%S"
            ),
            ToolParameter(
                name="output_format",
                type=ParameterType.STRING,
                description="输出时间格式(如:%Y-%m-%d %H:%M:%S)",
                required=False,
                default="%Y-%m-%d %H:%M:%S"
            ),
            ToolParameter(
                name="from_timezone",
                type=ParameterType.STRING,
                description="源时区(如:UTC, Asia/Shanghai)",
                required=False,
                default="UTC"
            ),
            ToolParameter(
                name="to_timezone",
                type=ParameterType.STRING,
                description="目标时区(如:UTC, Asia/Shanghai)",
                required=False,
                default="UTC"
            ),
            ToolParameter(
                name="calculation",
                type=ParameterType.STRING,
                description="时间计算表达式(如:+1d, -2h, +30m)",
                required=False
            )
        ]

    async def execute(self, **kwargs) -> ToolResult:
        """Run the requested date/time operation.

        Args:
            **kwargs: ``operation`` plus the arguments that operation reads.

        Returns:
            A success result carrying the operation's dict, or an error
            result with code ``DATETIME_ERROR`` when anything fails.
        """
        start_time = time.time()

        # Each handler takes the raw kwargs dict and returns a result dict.
        handlers = {
            "now": self._get_current_time,
            "format": self._format_datetime,
            "convert_timezone": self._convert_timezone,
            "timestamp_to_datetime": self._timestamp_to_datetime,
            "datetime_to_timestamp": self._datetime_to_timestamp,
            "calculate": self._calculate_datetime,
        }

        try:
            operation = kwargs.get("operation")
            handler = handlers.get(operation)
            if handler is None:
                raise ValueError(f"不支持的操作类型: {operation}")

            result = handler(kwargs)

            execution_time = time.time() - start_time
            return ToolResult.success_result(
                data=result,
                execution_time=execution_time
            )

        except Exception as e:
            execution_time = time.time() - start_time
            return ToolResult.error_result(
                error=str(e),
                error_code="DATETIME_ERROR",
                execution_time=execution_time
            )

    @staticmethod
    def _option(kwargs, key: str, default):
        """Return ``kwargs[key]``, or ``default`` when it is absent or ``None``.

        Base-class validation accepts an explicit ``None`` for optional
        parameters, so ``None`` is treated the same as "not supplied".
        """
        value = kwargs.get(key)
        return default if value is None else value

    @staticmethod
    def _resolve_timezone(timezone_str: str):
        """Return the pytz timezone named ``timezone_str``.

        A pytz object is returned for "UTC" as well, so callers can always
        use ``localize``; ``datetime.timezone.utc`` has no such method.
        """
        if timezone_str == "UTC":
            return pytz.UTC
        return pytz.timezone(timezone_str)

    def _parse_input(self, kwargs):
        """Return ``(input_value, parsed_datetime)`` from the call arguments.

        Raises:
            ValueError: If ``input_value`` is missing or empty, or does not
                match ``input_format``.
        """
        input_value = kwargs.get("input_value")
        if not input_value:
            raise ValueError("input_value 参数是必需的")
        input_format = self._option(kwargs, "input_format", self._DEFAULT_FORMAT)
        return input_value, datetime.strptime(input_value, input_format)

    def _get_current_time(self, kwargs) -> dict:
        """Return the current time in ``to_timezone`` (default UTC)."""
        timezone_str = self._option(kwargs, "to_timezone", "UTC")
        output_format = self._option(kwargs, "output_format", self._DEFAULT_FORMAT)

        now = datetime.now(self._resolve_timezone(timezone_str))

        return {
            "datetime": now.strftime(output_format),
            "timestamp": int(now.timestamp()),
            "timezone": timezone_str,
            "iso_format": now.isoformat()
        }

    def _format_datetime(self, kwargs) -> dict:
        """Re-format ``input_value`` from ``input_format`` to ``output_format``."""
        input_value, dt = self._parse_input(kwargs)
        output_format = self._option(kwargs, "output_format", self._DEFAULT_FORMAT)

        return {
            "original": input_value,
            "formatted": dt.strftime(output_format),
            # NOTE: when the input carries no UTC offset, dt is naive and
            # timestamp() interprets it in the host's local timezone.
            "timestamp": int(dt.timestamp()),
            "iso_format": dt.isoformat()
        }

    def _convert_timezone(self, kwargs) -> dict:
        """Convert ``input_value`` from ``from_timezone`` to ``to_timezone``."""
        input_value, dt = self._parse_input(kwargs)
        output_format = self._option(kwargs, "output_format", self._DEFAULT_FORMAT)
        from_timezone = self._option(kwargs, "from_timezone", "UTC")
        to_timezone = self._option(kwargs, "to_timezone", "UTC")

        from_tz = self._resolve_timezone(from_timezone)
        to_tz = self._resolve_timezone(to_timezone)

        # Attach the source timezone unless the parsed value already has one.
        if dt.tzinfo is None:
            dt = from_tz.localize(dt)

        converted_dt = dt.astimezone(to_tz)

        return {
            "original": input_value,
            "original_timezone": from_timezone,
            "converted": converted_dt.strftime(output_format),
            "converted_timezone": to_timezone,
            "timestamp": int(converted_dt.timestamp())
        }

    def _timestamp_to_datetime(self, kwargs) -> dict:
        """Convert a Unix timestamp to a formatted datetime in ``to_timezone``."""
        input_value = kwargs.get("input_value")
        output_format = self._option(kwargs, "output_format", self._DEFAULT_FORMAT)
        timezone_str = self._option(kwargs, "to_timezone", "UTC")

        if not input_value:
            raise ValueError("input_value 参数是必需的")

        timestamp = float(input_value)
        dt = datetime.fromtimestamp(timestamp, self._resolve_timezone(timezone_str))

        return {
            "timestamp": timestamp,
            "datetime": dt.strftime(output_format),
            "timezone": timezone_str,
            "iso_format": dt.isoformat()
        }

    def _datetime_to_timestamp(self, kwargs) -> dict:
        """Convert a datetime string in ``from_timezone`` to a Unix timestamp."""
        input_value, dt = self._parse_input(kwargs)
        timezone_str = self._option(kwargs, "from_timezone", "UTC")

        tz = self._resolve_timezone(timezone_str)

        # Attach the source timezone unless the parsed value already has one.
        if dt.tzinfo is None:
            dt = tz.localize(dt)

        return {
            "datetime": input_value,
            "timezone": timezone_str,
            "timestamp": int(dt.timestamp()),
            "iso_format": dt.isoformat()
        }

    def _calculate_datetime(self, kwargs) -> dict:
        """Apply a ``calculation`` offset (e.g. ``+1d``) to ``input_value``."""
        input_value = kwargs.get("input_value")
        output_format = self._option(kwargs, "output_format", self._DEFAULT_FORMAT)
        calculation = kwargs.get("calculation")
        timezone_str = self._option(kwargs, "from_timezone", "UTC")

        if not input_value:
            raise ValueError("input_value 参数是必需的")

        if not calculation:
            raise ValueError("calculation 参数是必需的")

        input_format = self._option(kwargs, "input_format", self._DEFAULT_FORMAT)
        dt = datetime.strptime(input_value, input_format)

        tz = self._resolve_timezone(timezone_str)
        if dt.tzinfo is None:
            dt = tz.localize(dt)

        delta = self._parse_time_delta(calculation)
        calculated_dt = dt + delta

        return {
            "original": input_value,
            "calculation": calculation,
            "result": calculated_dt.strftime(output_format),
            "timezone": timezone_str,
            "timestamp": int(calculated_dt.timestamp())
        }

    def _parse_time_delta(self, calculation: str) -> timedelta:
        """Parse an offset expression such as ``+1d -2h +30m`` into a timedelta.

        Supported units: ``d`` (days), ``h`` (hours), ``m`` (minutes) and
        ``s`` (seconds). Every ``<signed integer><unit>`` token found in the
        lower-cased expression is summed; other text is ignored.

        Raises:
            ValueError: If the expression contains no such token.
        """
        import re

        pattern = r'([+-]?\d+)([dhms])'
        matches = re.findall(pattern, calculation.lower())

        if not matches:
            raise ValueError(f"无效的时间计算表达式: {calculation}")

        # Map each unit letter to the timedelta keyword it stands for.
        unit_keywords = {'d': 'days', 'h': 'hours', 'm': 'minutes', 's': 'seconds'}

        total_delta = timedelta()
        for value_str, unit in matches:
            total_delta += timedelta(**{unit_keywords[unit]: int(value_str)})

        return total_delta
|
||||
430
api/app/core/tools/builtin/json_tool.py
Normal file
430
api/app/core/tools/builtin/json_tool.py
Normal file
@@ -0,0 +1,430 @@
|
||||
"""JSON转换工具 - 数据格式转换"""
|
||||
import json
|
||||
import time
|
||||
from typing import List, Any, Dict
|
||||
import yaml
|
||||
import xml.etree.ElementTree as ET
|
||||
from xml.dom import minidom
|
||||
|
||||
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
|
||||
from .base import BuiltinTool
|
||||
|
||||
|
||||
class JsonTool(BuiltinTool):
|
||||
"""JSON转换工具 - 提供JSON格式化、压缩、验证、格式转换功能"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "json_tool"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "JSON转换工具 - 数据格式转换:JSON格式化、JSON压缩、JSON验证、格式转换"
|
||||
|
||||
@property
def parameters(self) -> List[ToolParameter]:
    """Return the parameter declarations of the JSON tool.

    The list contains, in order: ``operation`` and ``input_data`` (both
    required), the formatting options ``indent``, ``ensure_ascii`` and
    ``sort_keys``, and the operation-specific ``merge_data`` (merge) and
    ``json_path`` (extract).
    """
    operation = ToolParameter(
        name="operation",
        type=ParameterType.STRING,
        description="操作类型",
        required=True,
        enum=["format", "minify", "validate", "convert", "to_yaml", "from_yaml", "to_xml", "from_xml", "merge", "extract"]
    )
    input_data = ToolParameter(
        name="input_data",
        type=ParameterType.STRING,
        description="输入数据(JSON字符串、YAML字符串或XML字符串)",
        required=True
    )
    indent = ToolParameter(
        name="indent",
        type=ParameterType.INTEGER,
        description="JSON格式化缩进空格数",
        required=False,
        default=2,
        minimum=0,
        maximum=8
    )
    ensure_ascii = ToolParameter(
        name="ensure_ascii",
        type=ParameterType.BOOLEAN,
        description="是否确保ASCII编码",
        required=False,
        default=False
    )
    sort_keys = ToolParameter(
        name="sort_keys",
        type=ParameterType.BOOLEAN,
        description="是否对键进行排序",
        required=False,
        default=False
    )
    merge_data = ToolParameter(
        name="merge_data",
        type=ParameterType.STRING,
        description="要合并的JSON数据(用于merge操作)",
        required=False
    )
    json_path = ToolParameter(
        name="json_path",
        type=ParameterType.STRING,
        description="JSON路径表达式(用于extract操作,如:$.user.name)",
        required=False
    )
    return [operation, input_data, indent, ensure_ascii, sort_keys, merge_data, json_path]
|
||||
|
||||
async def execute(self, **kwargs) -> ToolResult:
    """Run the requested JSON operation and wrap the outcome in a ToolResult.

    Reads ``operation`` and ``input_data`` from ``kwargs`` and forwards to
    the matching private handler. Any exception raised along the way
    (missing input, unknown operation, parse failure) is converted into an
    error result with code ``JSON_ERROR`` instead of propagating.

    Returns:
        A success result carrying the handler's dictionary, or an error
        result carrying the exception text; both include the elapsed time.
    """
    started_at = time.time()

    try:
        operation = kwargs.get("operation")
        input_data = kwargs.get("input_data")

        if not input_data:
            raise ValueError("input_data 参数是必需的")

        # (operation name, handler, whether the handler also receives kwargs)
        handlers = (
            ("format", self._format_json, True),
            ("minify", self._minify_json, False),
            ("validate", self._validate_json, False),
            ("convert", self._convert_json, False),
            ("to_yaml", self._json_to_yaml, False),
            ("from_yaml", self._yaml_to_json, True),
            ("to_xml", self._json_to_xml, False),
            ("from_xml", self._xml_to_json, True),
            ("merge", self._merge_json, True),
            ("extract", self._extract_json_path, True),
        )
        for op_name, handler, wants_kwargs in handlers:
            if operation == op_name:
                result = handler(input_data, kwargs) if wants_kwargs else handler(input_data)
                break
        else:
            raise ValueError(f"不支持的操作类型: {operation}")

        elapsed = time.time() - started_at
        return ToolResult.success_result(data=result, execution_time=elapsed)

    except Exception as e:
        elapsed = time.time() - started_at
        return ToolResult.error_result(
            error=str(e),
            error_code="JSON_ERROR",
            execution_time=elapsed
        )
|
||||
|
||||
def _format_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""格式化JSON"""
|
||||
indent = kwargs.get("indent", 2)
|
||||
ensure_ascii = kwargs.get("ensure_ascii", False)
|
||||
sort_keys = kwargs.get("sort_keys", False)
|
||||
|
||||
# 解析JSON
|
||||
data = json.loads(input_data)
|
||||
|
||||
# 格式化输出
|
||||
formatted = json.dumps(
|
||||
data,
|
||||
indent=indent,
|
||||
ensure_ascii=ensure_ascii,
|
||||
sort_keys=sort_keys,
|
||||
separators=(',', ': ')
|
||||
)
|
||||
|
||||
return {
|
||||
"original_size": len(input_data),
|
||||
"formatted_size": len(formatted),
|
||||
"formatted_json": formatted,
|
||||
"is_valid": True,
|
||||
"settings": {
|
||||
"indent": indent,
|
||||
"ensure_ascii": ensure_ascii,
|
||||
"sort_keys": sort_keys
|
||||
}
|
||||
}
|
||||
|
||||
def _minify_json(self, input_data: str) -> Dict[str, Any]:
|
||||
"""压缩JSON"""
|
||||
# 解析并压缩
|
||||
data = json.loads(input_data)
|
||||
minified = json.dumps(data, separators=(',', ':'))
|
||||
|
||||
return {
|
||||
"original_size": len(input_data),
|
||||
"minified_size": len(minified),
|
||||
"compression_ratio": round((1 - len(minified) / len(input_data)) * 100, 2),
|
||||
"minified_json": minified,
|
||||
"is_valid": True
|
||||
}
|
||||
|
||||
def _validate_json(self, input_data: str) -> Dict[str, Any]:
|
||||
"""验证JSON"""
|
||||
try:
|
||||
data = json.loads(input_data)
|
||||
|
||||
# 统计信息
|
||||
stats = self._analyze_json_structure(data)
|
||||
|
||||
return {
|
||||
"is_valid": True,
|
||||
"error": None,
|
||||
"size": len(input_data),
|
||||
"structure": stats
|
||||
}
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
return {
|
||||
"is_valid": False,
|
||||
"error": str(e),
|
||||
"error_line": getattr(e, 'lineno', None),
|
||||
"error_column": getattr(e, 'colno', None),
|
||||
"size": len(input_data)
|
||||
}
|
||||
|
||||
def _convert_json(self, input_data: str) -> Dict[str, Any]:
|
||||
"""JSON转义"""
|
||||
data = json.loads(input_data)
|
||||
converted = json.dumps(data, ensure_ascii=False)
|
||||
|
||||
return {
|
||||
"converted_json": converted,
|
||||
"is_valid": True
|
||||
}
|
||||
|
||||
def _json_to_yaml(self, input_data: str) -> Dict[str, Any]:
    """Convert a JSON document to block-style YAML.

    Args:
        input_data: JSON text to convert.

    Returns:
        A dictionary naming the source and target formats, the character
        counts of input and output, and the YAML text in ``converted_data``.

    Raises:
        json.JSONDecodeError: If ``input_data`` is not valid JSON.
    """
    parsed = json.loads(input_data)
    dump_options = {
        "default_flow_style": False,
        "allow_unicode": True,
        "indent": 2,
    }
    yaml_text = yaml.dump(parsed, **dump_options)

    report = {"original_format": "json", "target_format": "yaml"}
    report["original_size"] = len(input_data)
    report["converted_size"] = len(yaml_text)
    report["converted_data"] = yaml_text
    return report
|
||||
|
||||
def _yaml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a YAML document to JSON text.

    Args:
        input_data: YAML text to convert (loaded with ``yaml.safe_load``).
        kwargs: Options; ``indent`` (default 2) and ``ensure_ascii``
            (default False) are honoured.

    Returns:
        A dictionary naming the source and target formats, the character
        counts of input and output, and the JSON text in ``converted_data``.

    Raises:
        TypeError: If the YAML contains a value that has no JSON
            representation other than the date/time values handled below.
    """
    import datetime as _dt

    indent = kwargs.get("indent", 2)
    ensure_ascii = kwargs.get("ensure_ascii", False)

    def _encode_yaml_scalar(value):
        """Serialize YAML-only scalar types that ``json`` does not know."""
        # YAML timestamps (e.g. ``2024-01-01``) are loaded as date/datetime
        # objects; without this hook json.dumps rejects them.
        if isinstance(value, _dt.date):  # datetime is a subclass of date
            return value.isoformat()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    data = yaml.safe_load(input_data)
    json_output = json.dumps(
        data,
        indent=indent,
        ensure_ascii=ensure_ascii,
        default=_encode_yaml_scalar
    )

    return {
        "original_format": "yaml",
        "target_format": "json",
        "original_size": len(input_data),
        "converted_size": len(json_output),
        "converted_data": json_output
    }
|
||||
|
||||
def _json_to_xml(self, input_data: str) -> Dict[str, Any]:
|
||||
"""JSON转XML"""
|
||||
data = json.loads(input_data)
|
||||
|
||||
def dict_to_xml(data, root_name="root"):
|
||||
"""递归转换字典为XML"""
|
||||
if isinstance(data, dict):
|
||||
if len(data) == 1 and not root_name == "root":
|
||||
# 如果字典只有一个键,使用该键作为根元素
|
||||
key, value = next(iter(data.items()))
|
||||
return dict_to_xml(value, key)
|
||||
|
||||
root = ET.Element(root_name)
|
||||
for key, value in data.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
child = dict_to_xml(value, key)
|
||||
root.append(child)
|
||||
else:
|
||||
child = ET.SubElement(root, key)
|
||||
child.text = str(value)
|
||||
return root
|
||||
|
||||
elif isinstance(data, list):
|
||||
root = ET.Element(root_name)
|
||||
for i, item in enumerate(data):
|
||||
if isinstance(item, (dict, list)):
|
||||
child = dict_to_xml(item, f"item_{i}")
|
||||
root.append(child)
|
||||
else:
|
||||
child = ET.SubElement(root, f"item_{i}")
|
||||
child.text = str(item)
|
||||
return root
|
||||
|
||||
else:
|
||||
root = ET.Element(root_name)
|
||||
root.text = str(data)
|
||||
return root
|
||||
|
||||
xml_element = dict_to_xml(data)
|
||||
xml_string = ET.tostring(xml_element, encoding='unicode')
|
||||
|
||||
# 格式化XML
|
||||
dom = minidom.parseString(xml_string)
|
||||
formatted_xml = dom.toprettyxml(indent=" ")
|
||||
|
||||
# 移除空行
|
||||
formatted_xml = '\n'.join([line for line in formatted_xml.split('\n') if line.strip()])
|
||||
|
||||
return {
|
||||
"original_format": "json",
|
||||
"target_format": "xml",
|
||||
"original_size": len(input_data),
|
||||
"converted_size": len(formatted_xml),
|
||||
"converted_data": formatted_xml
|
||||
}
|
||||
|
||||
def _xml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""XML转JSON"""
|
||||
indent = kwargs.get("indent", 2)
|
||||
|
||||
def xml_to_dict(element):
|
||||
"""递归转换XML元素为字典"""
|
||||
result = {}
|
||||
|
||||
# 处理属性
|
||||
if element.attrib:
|
||||
result.update(element.attrib)
|
||||
|
||||
# 处理文本内容
|
||||
if element.text and element.text.strip():
|
||||
if len(element) == 0: # 叶子节点
|
||||
return element.text.strip()
|
||||
else:
|
||||
result['text'] = element.text.strip()
|
||||
|
||||
# 处理子元素
|
||||
for child in element:
|
||||
child_data = xml_to_dict(child)
|
||||
if child.tag in result:
|
||||
# 如果标签已存在,转换为列表
|
||||
if not isinstance(result[child.tag], list):
|
||||
result[child.tag] = [result[child.tag]]
|
||||
result[child.tag].append(child_data)
|
||||
else:
|
||||
result[child.tag] = child_data
|
||||
|
||||
return result
|
||||
|
||||
root = ET.fromstring(input_data)
|
||||
data = {root.tag: xml_to_dict(root)}
|
||||
json_output = json.dumps(data, indent=indent, ensure_ascii=False)
|
||||
|
||||
return {
|
||||
"original_format": "xml",
|
||||
"target_format": "json",
|
||||
"original_size": len(input_data),
|
||||
"converted_size": len(json_output),
|
||||
"converted_data": json_output
|
||||
}
|
||||
|
||||
def _merge_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""合并JSON"""
|
||||
merge_data = kwargs.get("merge_data")
|
||||
if not merge_data:
|
||||
raise ValueError("merge_data 参数是必需的")
|
||||
|
||||
data1 = json.loads(input_data)
|
||||
data2 = json.loads(merge_data)
|
||||
|
||||
def deep_merge(dict1, dict2):
|
||||
"""深度合并字典"""
|
||||
result = dict1.copy()
|
||||
for key, value in dict2.items():
|
||||
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
|
||||
result[key] = deep_merge(result[key], value)
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
if isinstance(data1, dict) and isinstance(data2, dict):
|
||||
merged = deep_merge(data1, data2)
|
||||
elif isinstance(data1, list) and isinstance(data2, list):
|
||||
merged = data1 + data2
|
||||
else:
|
||||
raise ValueError("无法合并不同类型的数据")
|
||||
|
||||
merged_json = json.dumps(merged, indent=2, ensure_ascii=False)
|
||||
|
||||
return {
|
||||
"operation": "merge",
|
||||
"original_size": len(input_data),
|
||||
"merge_size": len(merge_data),
|
||||
"result_size": len(merged_json),
|
||||
"merged_data": merged_json
|
||||
}
|
||||
|
||||
def _extract_json_path(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""提取JSON路径"""
|
||||
json_path = kwargs.get("json_path")
|
||||
if not json_path:
|
||||
raise ValueError("json_path 参数是必需的")
|
||||
|
||||
data = json.loads(input_data)
|
||||
|
||||
# 简单的JSONPath实现(支持基本的点号路径)
|
||||
try:
|
||||
result = data
|
||||
if json_path.startswith('$.'):
|
||||
path_parts = json_path[2:].split('.')
|
||||
else:
|
||||
path_parts = json_path.split('.')
|
||||
|
||||
for part in path_parts:
|
||||
if part.isdigit():
|
||||
result = result[int(part)]
|
||||
else:
|
||||
result = result[part]
|
||||
|
||||
extracted_json = json.dumps(result, indent=2, ensure_ascii=False)
|
||||
|
||||
return {
|
||||
"operation": "extract",
|
||||
"json_path": json_path,
|
||||
"found": True,
|
||||
"extracted_data": extracted_json,
|
||||
"data_type": type(result).__name__
|
||||
}
|
||||
|
||||
except (KeyError, IndexError, TypeError) as e:
|
||||
return {
|
||||
"operation": "extract",
|
||||
"json_path": json_path,
|
||||
"found": False,
|
||||
"error": str(e),
|
||||
"extracted_data": None
|
||||
}
|
||||
|
||||
def _analyze_json_structure(self, data: Any, depth: int = 0) -> Dict[str, Any]:
|
||||
"""分析JSON结构"""
|
||||
if isinstance(data, dict):
|
||||
return {
|
||||
"type": "object",
|
||||
"keys": len(data),
|
||||
"depth": depth,
|
||||
"children": {k: self._analyze_json_structure(v, depth + 1) for k, v in data.items()}
|
||||
}
|
||||
elif isinstance(data, list):
|
||||
return {
|
||||
"type": "array",
|
||||
"length": len(data),
|
||||
"depth": depth,
|
||||
"item_types": list(set(type(item).__name__ for item in data))
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"type": type(data).__name__,
|
||||
"depth": depth,
|
||||
"value": str(data)[:100] + "..." if len(str(data)) > 100 else str(data)
|
||||
}
|
||||
327
api/app/core/tools/builtin/mineru_tool.py
Normal file
327
api/app/core/tools/builtin/mineru_tool.py
Normal file
@@ -0,0 +1,327 @@
|
||||
"""MinerU PDF解析工具"""
|
||||
import time
|
||||
from typing import List, Dict, Any
|
||||
import aiohttp
|
||||
|
||||
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
|
||||
from .base import BuiltinTool
|
||||
|
||||
|
||||
class MinerUTool(BuiltinTool):
|
||||
"""MinerU PDF解析工具 - 提供PDF解析、表格提取、图片识别、文本提取功能"""
|
||||
|
||||
@property
def name(self) -> str:
    """Return the identifier string of the MinerU tool."""
    tool_name = "mineru_tool"
    return tool_name
|
||||
|
||||
@property
def description(self) -> str:
    """Return the human-readable summary of what the MinerU tool offers."""
    summary = "MinerU - PDF解析工具:PDF解析、表格提取、图片识别、文本提取"
    return summary
|
||||
|
||||
def get_required_config_parameters(self) -> List[str]:
    """Return the names of the configuration values this tool requires."""
    required = ["api_key", "api_url"]
    return required
|
||||
|
||||
@property
def parameters(self) -> List[ToolParameter]:
    """Return the parameter declarations of the MinerU tool.

    The list contains, in order: the required ``operation``, the two
    alternative file sources ``file_content`` / ``file_url``, and the
    optional ``parse_mode``, ``extract_images``, ``extract_tables``,
    ``page_range`` and ``output_format``.
    """
    operation = ToolParameter(
        name="operation",
        type=ParameterType.STRING,
        description="操作类型",
        required=True,
        enum=["parse_pdf", "extract_text", "extract_tables", "extract_images", "analyze_layout"]
    )
    file_content = ToolParameter(
        name="file_content",
        type=ParameterType.STRING,
        description="PDF文件内容(Base64编码)",
        required=False
    )
    file_url = ToolParameter(
        name="file_url",
        type=ParameterType.STRING,
        description="PDF文件URL",
        required=False
    )
    parse_mode = ToolParameter(
        name="parse_mode",
        type=ParameterType.STRING,
        description="解析模式",
        required=False,
        default="auto",
        enum=["auto", "text_only", "table_priority", "image_priority", "layout_analysis"]
    )
    extract_images = ToolParameter(
        name="extract_images",
        type=ParameterType.BOOLEAN,
        description="是否提取图片",
        required=False,
        default=True
    )
    extract_tables = ToolParameter(
        name="extract_tables",
        type=ParameterType.BOOLEAN,
        description="是否提取表格",
        required=False,
        default=True
    )
    page_range = ToolParameter(
        name="page_range",
        type=ParameterType.STRING,
        description="页面范围(如:1-5, 1,3,5)",
        required=False
    )
    output_format = ToolParameter(
        name="output_format",
        type=ParameterType.STRING,
        description="输出格式",
        required=False,
        default="json",
        enum=["json", "markdown", "html", "text"]
    )
    return [
        operation,
        file_content,
        file_url,
        parse_mode,
        extract_images,
        extract_tables,
        page_range,
        output_format,
    ]
|
||||
|
||||
async def execute(self, **kwargs) -> ToolResult:
    """Run the requested MinerU operation and wrap the outcome in a ToolResult.

    Requires at least one of ``file_content`` / ``file_url`` in ``kwargs``
    and forwards to the handler matching ``operation``. Any exception
    raised along the way is converted into an error result with code
    ``MINERU_ERROR`` instead of propagating.

    Returns:
        A success result carrying the handler's dictionary, or an error
        result carrying the exception text; both include the elapsed time.
    """
    started_at = time.time()

    try:
        operation = kwargs.get("operation")

        has_source = kwargs.get("file_content") or kwargs.get("file_url")
        if not has_source:
            raise ValueError("必须提供 file_content 或 file_url 参数")

        # (operation name, coroutine method handling it)
        handlers = (
            ("parse_pdf", self._parse_pdf),
            ("extract_text", self._extract_text),
            ("extract_tables", self._extract_tables),
            ("extract_images", self._extract_images),
            ("analyze_layout", self._analyze_layout),
        )
        for op_name, handler in handlers:
            if operation == op_name:
                result = await handler(kwargs)
                break
        else:
            raise ValueError(f"不支持的操作类型: {operation}")

        elapsed = time.time() - started_at
        return ToolResult.success_result(data=result, execution_time=elapsed)

    except Exception as e:
        elapsed = time.time() - started_at
        return ToolResult.error_result(
            error=str(e),
            error_code="MINERU_ERROR",
            execution_time=elapsed
        )
|
||||
|
||||
async def _parse_pdf(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Request a full parse of the PDF from the ``parse`` endpoint.

    Args:
        kwargs: Tool arguments. ``parse_mode``, ``extract_images``,
            ``extract_tables`` and ``output_format`` are sent with their
            defaults; ``page_range`` is sent when given; ``file_content``
            takes precedence over ``file_url`` as the file source.

    Returns:
        A dictionary with the operation name, the parse mode, and the
        fields picked from the API response (with defaults when absent).
    """
    parse_mode = kwargs.get("parse_mode", "auto")

    request_data = {
        "parse_mode": parse_mode,
        "extract_images": kwargs.get("extract_images", True),
        "extract_tables": kwargs.get("extract_tables", True),
        "output_format": kwargs.get("output_format", "json"),
    }

    page_range = kwargs.get("page_range")
    if page_range:
        request_data["page_range"] = page_range

    # Exactly one file source is forwarded; inline content wins over a URL.
    file_content = kwargs.get("file_content")
    file_url = kwargs.get("file_url")
    if file_content:
        request_data["file_content"] = file_content
    elif file_url:
        request_data["file_url"] = file_url

    result = await self._call_mineru_api("parse", request_data)

    response_defaults = (
        ("total_pages", 0),
        ("processed_pages", 0),
        ("text_content", ""),
        ("tables", []),
        ("images", []),
        ("layout_info", {}),
        ("metadata", {}),
        ("processing_time", 0),
    )
    parsed = {"operation": "parse_pdf", "parse_mode": parse_mode}
    for field, fallback in response_defaults:
        parsed[field] = result.get(field, fallback)
    return parsed
|
||||
|
||||
async def _extract_text(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Request text extraction from the ``extract_text`` endpoint.

    Args:
        kwargs: Tool arguments. ``output_format`` defaults to ``"text"``
            here; ``page_range`` is sent when given; ``file_content`` takes
            precedence over ``file_url`` as the file source.

    Returns:
        A dictionary with the operation name, page count, extracted text,
        its whitespace-separated word count, its character count and the
        per-page texts from the API response.
    """
    request_data = {
        "operation": "extract_text",
        "output_format": kwargs.get("output_format", "text"),
    }

    page_range = kwargs.get("page_range")
    if page_range:
        request_data["page_range"] = page_range

    file_content = kwargs.get("file_content")
    file_url = kwargs.get("file_url")
    if file_content:
        request_data["file_content"] = file_content
    elif file_url:
        request_data["file_url"] = file_url

    result = await self._call_mineru_api("extract_text", request_data)

    text_content = result.get("text_content", "")
    return {
        "operation": "extract_text",
        "total_pages": result.get("total_pages", 0),
        "text_content": text_content,
        "word_count": len(text_content.split()),
        "character_count": len(text_content),
        "pages_text": result.get("pages_text", [])
    }
|
||||
|
||||
async def _extract_tables(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Request table extraction from the ``extract_tables`` endpoint.

    Args:
        kwargs: Tool arguments. ``output_format`` defaults to ``"json"``;
            ``page_range`` is sent when given; ``file_content`` takes
            precedence over ``file_url`` as the file source.

    Returns:
        A dictionary with the operation name, the table count, the tables
        and their locations as reported by the API.
    """
    request_data = {
        "operation": "extract_tables",
        "output_format": kwargs.get("output_format", "json"),
    }

    page_range = kwargs.get("page_range")
    if page_range:
        request_data["page_range"] = page_range

    file_content = kwargs.get("file_content")
    file_url = kwargs.get("file_url")
    if file_content:
        request_data["file_content"] = file_content
    elif file_url:
        request_data["file_url"] = file_url

    result = await self._call_mineru_api("extract_tables", request_data)

    summary = {"operation": "extract_tables"}
    summary["total_tables"] = result.get("total_tables", 0)
    summary["tables"] = result.get("tables", [])
    summary["table_locations"] = result.get("table_locations", [])
    return summary
|
||||
|
||||
async def _extract_images(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Request image extraction from the ``extract_images`` endpoint.

    Args:
        kwargs: Tool arguments. ``page_range`` is sent when given;
            ``file_content`` takes precedence over ``file_url`` as the file
            source.

    Returns:
        A dictionary with the operation name, the image count, the images
        and their locations as reported by the API.
    """
    request_data = {"operation": "extract_images"}

    page_range = kwargs.get("page_range")
    if page_range:
        request_data["page_range"] = page_range

    file_content = kwargs.get("file_content")
    file_url = kwargs.get("file_url")
    if file_content:
        request_data["file_content"] = file_content
    elif file_url:
        request_data["file_url"] = file_url

    result = await self._call_mineru_api("extract_images", request_data)

    summary = {"operation": "extract_images"}
    summary["total_images"] = result.get("total_images", 0)
    summary["images"] = result.get("images", [])
    summary["image_locations"] = result.get("image_locations", [])
    return summary
|
||||
|
||||
async def _analyze_layout(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Request layout analysis from the ``analyze_layout`` endpoint.

    Args:
        kwargs: Tool arguments. ``page_range`` is sent when given;
            ``file_content`` takes precedence over ``file_url`` as the file
            source.

    Returns:
        A dictionary with the operation name plus the layout info, page
        layouts and text/image/table blocks as reported by the API.
    """
    request_data = {"operation": "analyze_layout"}

    page_range = kwargs.get("page_range")
    if page_range:
        request_data["page_range"] = page_range

    file_content = kwargs.get("file_content")
    file_url = kwargs.get("file_url")
    if file_content:
        request_data["file_content"] = file_content
    elif file_url:
        request_data["file_url"] = file_url

    result = await self._call_mineru_api("analyze_layout", request_data)

    summary = {
        "operation": "analyze_layout",
        "layout_info": result.get("layout_info", {}),
    }
    for block_key in ("page_layouts", "text_blocks", "image_blocks", "table_blocks"):
        summary[block_key] = result.get(block_key, [])
    return summary
|
||||
|
||||
async def _call_mineru_api(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
    """POST a JSON payload to a MinerU endpoint and return the decoded reply.

    The URL is ``<api_url>/<endpoint>``, authenticated with a bearer token
    from the ``api_key`` config value. The total request timeout comes from
    the ``timeout`` config value (60 when not configured).

    Args:
        endpoint: Path segment appended to the configured API URL.
        data: Request body, sent as JSON.

    Returns:
        The ``data`` member of the reply when present, otherwise the whole
        reply.

    Raises:
        ValueError: If ``api_key`` or ``api_url`` is not configured.
        Exception: On a non-200 status, or when the reply carries a falsy
            ``success`` flag.
    """
    api_key = self.get_config_parameter("api_key")
    api_url = self.get_config_parameter("api_url")
    timeout_seconds = self.get_config_parameter("timeout", 60)

    if not (api_key and api_url):
        raise ValueError("MinerU API配置未完成")

    url = f"{api_url.rstrip('/')}/{endpoint}"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }

    client_timeout = aiohttp.ClientTimeout(total=timeout_seconds)
    async with aiohttp.ClientSession(timeout=client_timeout) as session:
        async with session.post(url, json=data, headers=headers) as response:
            if response.status != 200:
                error_text = await response.text()
                raise Exception(f"HTTP错误 {response.status}: {error_text}")

            payload = await response.json()
            # A reply without a "success" member counts as successful.
            if not payload.get("success", True):
                raise Exception(f"MinerU API错误: {payload.get('message', '未知错误')}")
            return payload.get("data", payload)
|
||||
|
||||
def test_connection(self) -> Dict[str, Any]:
    """Report whether the tool's API configuration is present.

    Only the presence of the ``api_key`` and ``api_url`` config values is
    checked; no request is sent to the API.

    Returns:
        ``success`` True with the API URL and a masked key (first eight
        characters plus ``***``, or just ``***`` for keys of eight
        characters or fewer), or ``success`` False with an ``error`` text.
    """
    try:
        api_key = self.get_config_parameter("api_key")
        api_url = self.get_config_parameter("api_url")

        if not (api_key and api_url):
            return {"success": False, "error": "API配置未完成"}

        if len(api_key) > 8:
            masked_key = api_key[:8] + "***"
        else:
            masked_key = "***"

        return {
            "success": True,
            "message": "连接配置有效",
            "api_url": api_url,
            "api_key_masked": masked_key
        }

    except Exception as e:
        return {"success": False, "error": str(e)}
|
||||
401
api/app/core/tools/builtin/textin_tool.py
Normal file
401
api/app/core/tools/builtin/textin_tool.py
Normal file
@@ -0,0 +1,401 @@
|
||||
"""TextIn OCR文字识别工具"""
|
||||
import time
|
||||
from typing import List, Dict, Any
|
||||
import aiohttp
|
||||
|
||||
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
|
||||
from .base import BuiltinTool
|
||||
|
||||
|
||||
class TextInTool(BuiltinTool):
|
||||
"""TextIn OCR工具 - 提供通用OCR、手写识别、多语言支持、高精度识别"""
|
||||
|
||||
@property
def name(self) -> str:
    """Return the identifier string of the TextIn tool."""
    tool_name = "textin_tool"
    return tool_name
|
||||
|
||||
@property
def description(self) -> str:
    """Return the human-readable summary of what the TextIn tool offers."""
    summary = "TextIn - OCR文字识别:通用OCR、手写识别、多语言支持、高精度识别"
    return summary
|
||||
|
||||
def get_required_config_parameters(self) -> List[str]:
    """Return the names of the configuration values this tool requires."""
    required = ["app_id", "secret_key", "api_url"]
    return required
|
||||
|
||||
@property
def parameters(self) -> List[ToolParameter]:
    """Return the parameter declarations of the TextIn tool.

    The list contains, in order: the two alternative image sources
    ``image_content`` / ``image_url``, then ``language``,
    ``recognition_mode``, ``return_location``, ``return_confidence``,
    ``merge_lines`` and ``output_format``. All of them are optional.
    """
    image_content = ToolParameter(
        name="image_content",
        type=ParameterType.STRING,
        description="图片内容(Base64编码)",
        required=False
    )
    image_url = ToolParameter(
        name="image_url",
        type=ParameterType.STRING,
        description="图片URL",
        required=False
    )
    language = ToolParameter(
        name="language",
        type=ParameterType.STRING,
        description="识别语言",
        required=False,
        default="auto",
        enum=["auto", "zh-cn", "zh-tw", "en", "ja", "ko", "fr", "de", "es", "ru"]
    )
    recognition_mode = ToolParameter(
        name="recognition_mode",
        type=ParameterType.STRING,
        description="识别模式",
        required=False,
        default="general",
        enum=["general", "accurate", "handwriting", "formula", "table", "document"]
    )
    return_location = ToolParameter(
        name="return_location",
        type=ParameterType.BOOLEAN,
        description="是否返回文字位置信息",
        required=False,
        default=False
    )
    return_confidence = ToolParameter(
        name="return_confidence",
        type=ParameterType.BOOLEAN,
        description="是否返回置信度",
        required=False,
        default=True
    )
    merge_lines = ToolParameter(
        name="merge_lines",
        type=ParameterType.BOOLEAN,
        description="是否合并行",
        required=False,
        default=True
    )
    output_format = ToolParameter(
        name="output_format",
        type=ParameterType.STRING,
        description="输出格式",
        required=False,
        default="text",
        enum=["text", "json", "structured"]
    )
    return [
        image_content,
        image_url,
        language,
        recognition_mode,
        return_location,
        return_confidence,
        merge_lines,
        output_format,
    ]
|
||||
|
||||
async def execute(self, **kwargs) -> ToolResult:
    """Run OCR in the requested recognition mode and wrap the outcome.

    Requires at least one of ``image_content`` / ``image_url`` in
    ``kwargs`` and forwards all arguments to the handler matching
    ``recognition_mode`` (default ``"general"``). Each handler reads the
    options it needs from ``kwargs`` itself. Any exception raised along the
    way is converted into an error result with code ``TEXTIN_ERROR``
    instead of propagating.

    Returns:
        A success result carrying the handler's dictionary, or an error
        result carrying the exception text; both include the elapsed time.
    """
    start_time = time.time()

    try:
        image_content = kwargs.get("image_content")
        image_url = kwargs.get("image_url")

        if not image_content and not image_url:
            raise ValueError("必须提供 image_content 或 image_url 参数")

        recognition_mode = kwargs.get("recognition_mode", "general")

        # (recognition mode, coroutine method handling it)
        handlers = (
            ("general", self._general_ocr),
            ("accurate", self._accurate_ocr),
            ("handwriting", self._handwriting_ocr),
            ("formula", self._formula_ocr),
            ("table", self._table_ocr),
            ("document", self._document_ocr),
        )
        for mode_name, handler in handlers:
            if recognition_mode == mode_name:
                result = await handler(kwargs)
                break
        else:
            raise ValueError(f"不支持的识别模式: {recognition_mode}")

        execution_time = time.time() - start_time
        return ToolResult.success_result(
            data=result,
            execution_time=execution_time
        )

    except Exception as e:
        execution_time = time.time() - start_time
        return ToolResult.error_result(
            error=str(e),
            error_code="TEXTIN_ERROR",
            execution_time=execution_time
        )
|
||||
|
||||
async def _general_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Run general-purpose OCR through the ``general_ocr`` endpoint.

    Sends ``language``, ``return_location``, ``return_confidence`` and
    ``merge_lines`` (with their defaults) plus the image, preferring inline
    ``image_content`` over ``image_url``. The reply is shaped by
    ``_format_ocr_result`` according to ``output_format`` (default
    ``"text"``).
    """
    option_defaults = (
        ("language", "auto"),
        ("return_location", False),
        ("return_confidence", True),
        ("merge_lines", True),
    )
    request_data = {option: kwargs.get(option, fallback) for option, fallback in option_defaults}

    image_content = kwargs.get("image_content")
    image_url = kwargs.get("image_url")
    if image_content:
        request_data["image"] = image_content
    elif image_url:
        request_data["image_url"] = image_url

    api_reply = await self._call_textin_api("general_ocr", request_data)
    output_format = kwargs.get("output_format", "text")
    return self._format_ocr_result(api_reply, output_format)
|
||||
|
||||
async def _accurate_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Run high-accuracy OCR through the ``accurate_ocr`` endpoint.

    Sends ``language``, ``return_location``, ``return_confidence`` and
    ``merge_lines`` (with their defaults) plus the image, preferring inline
    ``image_content`` over ``image_url``. The reply is shaped by
    ``_format_ocr_result`` according to ``output_format`` (default
    ``"text"``).
    """
    option_defaults = (
        ("language", "auto"),
        ("return_location", False),
        ("return_confidence", True),
        ("merge_lines", True),
    )
    request_data = {option: kwargs.get(option, fallback) for option, fallback in option_defaults}

    image_content = kwargs.get("image_content")
    image_url = kwargs.get("image_url")
    if image_content:
        request_data["image"] = image_content
    elif image_url:
        request_data["image_url"] = image_url

    api_reply = await self._call_textin_api("accurate_ocr", request_data)
    output_format = kwargs.get("output_format", "text")
    return self._format_ocr_result(api_reply, output_format)
|
||||
|
||||
async def _handwriting_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Run handwriting recognition through the ``handwriting_ocr`` endpoint.

    Sends ``language``, ``return_location`` and ``return_confidence`` (with
    their defaults) plus the image, preferring inline ``image_content``
    over ``image_url``. The reply is shaped by ``_format_ocr_result``
    according to ``output_format`` (default ``"text"``).
    """
    option_defaults = (
        ("language", "auto"),
        ("return_location", False),
        ("return_confidence", True),
    )
    request_data = {option: kwargs.get(option, fallback) for option, fallback in option_defaults}

    image_content = kwargs.get("image_content")
    image_url = kwargs.get("image_url")
    if image_content:
        request_data["image"] = image_content
    elif image_url:
        request_data["image_url"] = image_url

    api_reply = await self._call_textin_api("handwriting_ocr", request_data)
    output_format = kwargs.get("output_format", "text")
    return self._format_ocr_result(api_reply, output_format)
|
||||
|
||||
async def _formula_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Run formula recognition through the ``formula_ocr`` endpoint.

    Sends ``return_location`` and ``return_confidence`` (with their
    defaults), always sets ``output_latex``, and attaches the image,
    preferring inline ``image_content`` over ``image_url``. The reply is
    shaped by ``_format_formula_result``.
    """
    option_defaults = (
        ("return_location", False),
        ("return_confidence", True),
    )
    request_data = {option: kwargs.get(option, fallback) for option, fallback in option_defaults}
    request_data["output_latex"] = True

    image_content = kwargs.get("image_content")
    image_url = kwargs.get("image_url")
    if image_content:
        request_data["image"] = image_content
    elif image_url:
        request_data["image_url"] = image_url

    api_reply = await self._call_textin_api("formula_ocr", request_data)
    output_format = kwargs.get("output_format", "text")
    return self._format_formula_result(api_reply, output_format)
|
||||
|
||||
async def _table_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Run table recognition through the ``table_ocr`` endpoint.

    Sends ``language``, ``return_location`` and ``return_confidence`` (with
    their defaults), always sets ``output_excel``, and attaches the image,
    preferring inline ``image_content`` over ``image_url``. The reply is
    shaped by ``_format_table_result``.
    """
    option_defaults = (
        ("language", "auto"),
        ("return_location", False),
        ("return_confidence", True),
    )
    request_data = {option: kwargs.get(option, fallback) for option, fallback in option_defaults}
    request_data["output_excel"] = True

    image_content = kwargs.get("image_content")
    image_url = kwargs.get("image_url")
    if image_content:
        request_data["image"] = image_content
    elif image_url:
        request_data["image_url"] = image_url

    api_reply = await self._call_textin_api("table_ocr", request_data)
    output_format = kwargs.get("output_format", "text")
    return self._format_table_result(api_reply, output_format)
|
||||
|
||||
async def _document_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Run document recognition through the ``document_ocr`` endpoint.

    Sends ``language``, ``return_location`` and ``return_confidence`` (with
    their defaults), always sets ``layout_analysis``, and attaches the
    image, preferring inline ``image_content`` over ``image_url``. The
    reply is shaped by ``_format_document_result``.
    """
    option_defaults = (
        ("language", "auto"),
        ("return_location", False),
        ("return_confidence", True),
    )
    request_data = {option: kwargs.get(option, fallback) for option, fallback in option_defaults}
    request_data["layout_analysis"] = True

    image_content = kwargs.get("image_content")
    image_url = kwargs.get("image_url")
    if image_content:
        request_data["image"] = image_content
    elif image_url:
        request_data["image_url"] = image_url

    api_reply = await self._call_textin_api("document_ocr", request_data)
    output_format = kwargs.get("output_format", "text")
    return self._format_document_result(api_reply, output_format)
|
||||
|
||||
def _format_ocr_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any] | None:
|
||||
"""格式化OCR结果"""
|
||||
lines = result.get("lines", [])
|
||||
|
||||
if output_format == "text":
|
||||
text_content = "\n".join([line.get("text", "") for line in lines])
|
||||
return {
|
||||
"recognition_mode": "ocr",
|
||||
"text_content": text_content,
|
||||
"line_count": len(lines),
|
||||
"total_confidence": result.get("confidence", 0),
|
||||
"processing_time": result.get("processing_time", 0)
|
||||
}
|
||||
|
||||
elif output_format == "json":
|
||||
return {
|
||||
"recognition_mode": "ocr",
|
||||
"lines": lines,
|
||||
"total_confidence": result.get("confidence", 0),
|
||||
"processing_time": result.get("processing_time", 0)
|
||||
}
|
||||
|
||||
elif output_format == "structured":
|
||||
return {
|
||||
"recognition_mode": "ocr",
|
||||
"text_content": "\n".join([line.get("text", "") for line in lines]),
|
||||
"structured_data": {
|
||||
"lines": lines,
|
||||
"paragraphs": self._group_lines_to_paragraphs(lines),
|
||||
"statistics": {
|
||||
"line_count": len(lines),
|
||||
"word_count": sum(len(line.get("text", "").split()) for line in lines),
|
||||
"character_count": sum(len(line.get("text", "")) for line in lines)
|
||||
}
|
||||
},
|
||||
"total_confidence": result.get("confidence", 0),
|
||||
"processing_time": result.get("processing_time", 0)
|
||||
}
|
||||
|
||||
def _format_formula_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
|
||||
"""格式化公式识别结果"""
|
||||
formulas = result.get("formulas", [])
|
||||
|
||||
return {
|
||||
"recognition_mode": "formula",
|
||||
"formula_count": len(formulas),
|
||||
"formulas": formulas,
|
||||
"latex_content": "\n".join([f.get("latex", "") for f in formulas]),
|
||||
"total_confidence": result.get("confidence", 0),
|
||||
"processing_time": result.get("processing_time", 0)
|
||||
}
|
||||
|
||||
def _format_table_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
|
||||
"""格式化表格识别结果"""
|
||||
tables = result.get("tables", [])
|
||||
|
||||
return {
|
||||
"recognition_mode": "table",
|
||||
"table_count": len(tables),
|
||||
"tables": tables,
|
||||
"excel_data": result.get("excel_data"),
|
||||
"total_confidence": result.get("confidence", 0),
|
||||
"processing_time": result.get("processing_time", 0)
|
||||
}
|
||||
|
||||
def _format_document_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
|
||||
"""格式化文档识别结果"""
|
||||
return {
|
||||
"recognition_mode": "document",
|
||||
"layout_info": result.get("layout_info", {}),
|
||||
"text_blocks": result.get("text_blocks", []),
|
||||
"image_blocks": result.get("image_blocks", []),
|
||||
"table_blocks": result.get("table_blocks", []),
|
||||
"full_text": result.get("full_text", ""),
|
||||
"total_confidence": result.get("confidence", 0),
|
||||
"processing_time": result.get("processing_time", 0)
|
||||
}
|
||||
|
||||
def _group_lines_to_paragraphs(self, lines: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""将行分组为段落"""
|
||||
paragraphs = []
|
||||
current_paragraph = []
|
||||
|
||||
for line in lines:
|
||||
text = line.get("text", "").strip()
|
||||
if text:
|
||||
current_paragraph.append(line)
|
||||
else:
|
||||
if current_paragraph:
|
||||
paragraphs.append({
|
||||
"text": " ".join([l.get("text", "") for l in current_paragraph]),
|
||||
"lines": current_paragraph
|
||||
})
|
||||
current_paragraph = []
|
||||
|
||||
if current_paragraph:
|
||||
paragraphs.append({
|
||||
"text": " ".join([l.get("text", "") for l in current_paragraph]),
|
||||
"lines": current_paragraph
|
||||
})
|
||||
|
||||
return paragraphs
|
||||
|
||||
async def _call_textin_api(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""调用TextIn API"""
|
||||
app_id = self.get_config_parameter("app_id")
|
||||
secret_key = self.get_config_parameter("secret_key")
|
||||
api_url = self.get_config_parameter("api_url")
|
||||
|
||||
if not app_id or not secret_key or not api_url:
|
||||
raise ValueError("TextIn API配置未完成")
|
||||
|
||||
# 构建完整URL
|
||||
url = f"{api_url.rstrip('/')}/{endpoint}"
|
||||
|
||||
# 构建请求头
|
||||
headers = {
|
||||
"X-App-Id": app_id,
|
||||
"X-Secret-Key": secret_key,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# 发送请求
|
||||
timeout = aiohttp.ClientTimeout(total=30)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.post(url, json=data, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
if result.get("code") == 200:
|
||||
return result.get("data", result)
|
||||
else:
|
||||
raise Exception(f"TextIn API错误: {result.get('message', '未知错误')}")
|
||||
else:
|
||||
error_text = await response.text()
|
||||
raise Exception(f"HTTP错误 {response.status}: {error_text}")
|
||||
|
||||
def test_connection(self) -> Dict[str, Any]:
    """Report whether the TextIn connection settings are present.

    No request is sent: this only checks that ``app_id``, ``secret_key``
    and ``api_url`` are all configured.

    Returns:
        ``{"success": False, "error": ...}`` when a setting is missing or
        an exception occurs; otherwise a success dictionary echoing
        ``api_url`` and ``app_id`` plus a masked form of the secret key.
    """
    try:
        app_id, secret_key, api_url = (
            self.get_config_parameter(name)
            for name in ("app_id", "secret_key", "api_url")
        )

        if not (app_id and secret_key and api_url):
            return {"success": False, "error": "API配置未完成"}

        # Keys of 8 characters or fewer are hidden entirely.
        if len(secret_key) > 8:
            masked_key = secret_key[:8] + "***"
        else:
            masked_key = "***"

        return {
            "success": True,
            "message": "连接配置有效",
            "api_url": api_url,
            "app_id": app_id,
            "secret_key_masked": masked_key,
        }
    except Exception as exc:
        return {"success": False, "error": str(exc)}
|
||||
485
api/app/core/tools/chain_manager.py
Normal file
485
api/app/core/tools/chain_manager.py
Normal file
@@ -0,0 +1,485 @@
|
||||
"""工具链管理器 - 支持langchain的工具链模式"""
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from app.core.tools.base import ToolResult
|
||||
from app.core.tools.executor import ToolExecutor
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ChainExecutionMode(str, Enum):
    """How the steps of a tool chain are scheduled."""

    # Steps run one after another.
    SEQUENTIAL = "sequential"
    # Steps are submitted together as one batch.
    PARALLEL = "parallel"
    # Steps run subject to their conditions.
    CONDITIONAL = "conditional"
|
||||
|
||||
|
||||
@dataclass
class ChainStep:
    """One step of a tool chain: which tool to run and how."""

    # Identifier of the tool to execute.
    tool_id: str
    # Tool parameters; a value of the form "${name}" is a variable reference.
    parameters: Dict[str, Any]
    # Optional condition expression; the step is skipped when it is not met.
    condition: Optional[str] = None
    # Optional mapping from output keys to chain variable names.
    output_mapping: Optional[Dict[str, str]] = None
    # Failure policy: "stop", "continue" or "retry".
    error_handling: str = "stop"
|
||||
|
||||
|
||||
@dataclass
class ChainDefinition:
    """Declarative description of a tool chain."""

    # Chain name; used as the registry key.
    name: str
    # Human-readable description.
    description: str
    # Steps in declaration order.
    steps: List[ChainStep]
    # Scheduling strategy; sequential unless stated otherwise.
    execution_mode: ChainExecutionMode = ChainExecutionMode.SEQUENTIAL
    # Optional timeout for the whole chain.
    global_timeout: Optional[float] = None
    # Optional retry policy settings.
    retry_policy: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ChainExecutionContext:
    """Mutable state of a single chain run."""

    def __init__(self, chain_id: str):
        """Create an empty context for the run identified by ``chain_id``."""
        self.chain_id = chain_id

        # Progress markers.
        self.current_step = 0
        self.is_completed = False
        self.is_failed = False
        self.error_message: Optional[str] = None

        # Data accumulated during the run.
        self.variables: Dict[str, Any] = {}
        self.step_results: Dict[int, ToolResult] = {}  # keyed by step index
|
||||
|
||||
|
||||
class ChainManager:
|
||||
"""工具链管理器 - 支持langchain的工具链模式"""
|
||||
|
||||
def __init__(self, executor: ToolExecutor):
|
||||
"""初始化工具链管理器
|
||||
|
||||
Args:
|
||||
executor: 工具执行器
|
||||
"""
|
||||
self.executor = executor
|
||||
self._chains: Dict[str, ChainDefinition] = {}
|
||||
self._running_chains: Dict[str, ChainExecutionContext] = {}
|
||||
|
||||
def register_chain(self, chain: ChainDefinition) -> bool:
    """Validate a chain definition and store it under its name.

    A chain already registered under the same name is replaced.

    Args:
        chain: Chain definition to register.

    Returns:
        ``True`` when the chain was stored; ``False`` when validation
        failed or an exception occurred (both are logged).
    """
    try:
        is_valid, reason = self._validate_chain(chain)
        if is_valid:
            self._chains[chain.name] = chain
            logger.info(f"工具链注册成功: {chain.name}")
            return True

        logger.error(f"工具链验证失败: {chain.name}, 错误: {reason}")
        return False
    except Exception as e:
        logger.error(f"工具链注册失败: {chain.name}, 错误: {e}")
        return False
|
||||
|
||||
def unregister_chain(self, chain_name: str) -> bool:
    """Remove a registered chain.

    Args:
        chain_name: Name of the chain to remove.

    Returns:
        ``True`` when the chain existed and was removed, ``False`` when no
        chain is registered under that name.
    """
    if chain_name not in self._chains:
        return False

    del self._chains[chain_name]
    logger.info(f"工具链注销成功: {chain_name}")
    return True
|
||||
|
||||
def list_chains(self) -> List[Dict[str, Any]]:
    """Summarize every registered chain.

    Returns:
        One dictionary per chain, in registration order, with its name,
        description, number of steps, execution mode value and global
        timeout.
    """
    return [
        {
            "name": name,
            "description": definition.description,
            "step_count": len(definition.steps),
            "execution_mode": definition.execution_mode.value,
            "global_timeout": definition.global_timeout,
        }
        for name, definition in self._chains.items()
    ]
|
||||
|
||||
async def execute_chain(
    self,
    chain_name: str,
    initial_variables: Optional[Dict[str, Any]] = None,
    chain_id: Optional[str] = None
) -> Dict[str, Any] | None:
    """Run a registered chain and return its result dictionary.

    Args:
        chain_name: Name of a registered chain.
        initial_variables: Variables available to the first step. The
            mapping is copied, so the caller's dictionary is never
            modified by the run.
        chain_id: Identifier for this run; a random one is generated when
            omitted or empty.

    Returns:
        The result dictionary built by the mode-specific executor, or a
        ``{"success": False, ...}`` dictionary when the chain is unknown
        or execution raised.
    """
    if chain_name not in self._chains:
        return {
            "success": False,
            "error": f"工具链不存在: {chain_name}",
            "chain_id": chain_id
        }

    chain = self._chains[chain_name]

    if not chain_id:
        import uuid
        chain_id = f"chain_{uuid.uuid4().hex[:16]}"

    context = ChainExecutionContext(chain_id)
    # Copy instead of aliasing: output mappings write into
    # context.variables and must not leak into the caller's dictionary.
    context.variables = dict(initial_variables) if initial_variables else {}
    self._running_chains[chain_id] = context

    try:
        logger.info(f"开始执行工具链: {chain_name} (ID: {chain_id})")

        if chain.execution_mode == ChainExecutionMode.SEQUENTIAL:
            result = await self._execute_sequential(chain, context)
        elif chain.execution_mode == ChainExecutionMode.PARALLEL:
            result = await self._execute_parallel(chain, context)
        elif chain.execution_mode == ChainExecutionMode.CONDITIONAL:
            result = await self._execute_conditional(chain, context)
        else:
            raise ValueError(f"不支持的执行模式: {chain.execution_mode}")

        logger.info(f"工具链执行完成: {chain_name} (ID: {chain_id})")
        return result

    except Exception as e:
        logger.error(f"工具链执行失败: {chain_name} (ID: {chain_id}), 错误: {e}")
        return {
            "success": False,
            "error": str(e),
            "chain_id": chain_id,
            "completed_steps": context.current_step,
            "step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
        }

    finally:
        # The run is over either way; drop its context from the registry.
        self._running_chains.pop(chain_id, None)
|
||||
|
||||
async def _execute_sequential(
|
||||
self,
|
||||
chain: ChainDefinition,
|
||||
context: ChainExecutionContext
|
||||
) -> Dict[str, Any]:
|
||||
"""顺序执行工具链"""
|
||||
for i, step in enumerate(chain.steps):
|
||||
context.current_step = i
|
||||
|
||||
# 检查执行条件
|
||||
if step.condition and not self._evaluate_condition(step.condition, context):
|
||||
logger.debug(f"跳过步骤 {i}: 条件不满足")
|
||||
continue
|
||||
|
||||
# 准备参数
|
||||
parameters = self._prepare_parameters(step.parameters, context)
|
||||
|
||||
# 执行工具
|
||||
try:
|
||||
result = await self.executor.execute_tool(
|
||||
tool_id=step.tool_id,
|
||||
parameters=parameters
|
||||
)
|
||||
|
||||
context.step_results[i] = result
|
||||
|
||||
# 处理输出映射
|
||||
if step.output_mapping and result.success:
|
||||
self._apply_output_mapping(step.output_mapping, result.data, context)
|
||||
|
||||
# 处理执行失败
|
||||
if not result.success:
|
||||
if step.error_handling == "stop":
|
||||
context.is_failed = True
|
||||
context.error_message = result.error
|
||||
break
|
||||
elif step.error_handling == "continue":
|
||||
logger.warning(f"步骤 {i} 执行失败,继续执行: {result.error}")
|
||||
continue
|
||||
elif step.error_handling == "retry":
|
||||
# 简单重试逻辑
|
||||
retry_result = await self.executor.execute_tool(
|
||||
tool_id=step.tool_id,
|
||||
parameters=parameters
|
||||
)
|
||||
context.step_results[i] = retry_result
|
||||
if not retry_result.success and step.error_handling == "stop":
|
||||
context.is_failed = True
|
||||
context.error_message = retry_result.error
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"步骤 {i} 执行异常: {e}")
|
||||
if step.error_handling == "stop":
|
||||
context.is_failed = True
|
||||
context.error_message = str(e)
|
||||
break
|
||||
|
||||
context.is_completed = not context.is_failed
|
||||
|
||||
return {
|
||||
"success": context.is_completed,
|
||||
"error": context.error_message,
|
||||
"chain_id": context.chain_id,
|
||||
"completed_steps": context.current_step + 1,
|
||||
"total_steps": len(chain.steps),
|
||||
"final_variables": context.variables,
|
||||
"step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
|
||||
}
|
||||
|
||||
async def _execute_parallel(
|
||||
self,
|
||||
chain: ChainDefinition,
|
||||
context: ChainExecutionContext
|
||||
) -> Dict[str, Any]:
|
||||
"""并行执行工具链"""
|
||||
# 准备所有步骤的执行配置
|
||||
execution_configs = []
|
||||
|
||||
for i, step in enumerate(chain.steps):
|
||||
# 检查执行条件
|
||||
if step.condition and not self._evaluate_condition(step.condition, context):
|
||||
continue
|
||||
|
||||
parameters = self._prepare_parameters(step.parameters, context)
|
||||
execution_configs.append({
|
||||
"step_index": i,
|
||||
"tool_id": step.tool_id,
|
||||
"parameters": parameters
|
||||
})
|
||||
|
||||
# 并行执行所有步骤
|
||||
try:
|
||||
results = await self.executor.execute_tools_batch(execution_configs)
|
||||
|
||||
# 处理结果
|
||||
for i, result in enumerate(results):
|
||||
step_index = execution_configs[i]["step_index"]
|
||||
context.step_results[step_index] = result
|
||||
|
||||
# 处理输出映射
|
||||
step = chain.steps[step_index]
|
||||
if step.output_mapping and result.success:
|
||||
self._apply_output_mapping(step.output_mapping, result.data, context)
|
||||
|
||||
# 检查是否有失败的步骤
|
||||
failed_steps = [i for i, result in context.step_results.items() if not result.success]
|
||||
|
||||
context.is_completed = len(failed_steps) == 0
|
||||
if failed_steps:
|
||||
context.error_message = f"步骤 {failed_steps} 执行失败"
|
||||
|
||||
except Exception as e:
|
||||
context.is_failed = True
|
||||
context.error_message = str(e)
|
||||
|
||||
return {
|
||||
"success": context.is_completed,
|
||||
"error": context.error_message,
|
||||
"chain_id": context.chain_id,
|
||||
"completed_steps": len(context.step_results),
|
||||
"total_steps": len(chain.steps),
|
||||
"final_variables": context.variables,
|
||||
"step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
|
||||
}
|
||||
|
||||
async def _execute_conditional(
|
||||
self,
|
||||
chain: ChainDefinition,
|
||||
context: ChainExecutionContext
|
||||
) -> Dict[str, Any]:
|
||||
"""条件执行工具链"""
|
||||
# 条件执行类似于顺序执行,但更严格地检查条件
|
||||
return await self._execute_sequential(chain, context)
|
||||
|
||||
def _validate_chain(self, chain: ChainDefinition) -> tuple[bool, Optional[str]]:
|
||||
"""验证工具链定义
|
||||
|
||||
Args:
|
||||
chain: 工具链定义
|
||||
|
||||
Returns:
|
||||
(是否有效, 错误信息)
|
||||
"""
|
||||
if not chain.name:
|
||||
return False, "工具链名称不能为空"
|
||||
|
||||
if not chain.steps:
|
||||
return False, "工具链必须包含至少一个步骤"
|
||||
|
||||
for i, step in enumerate(chain.steps):
|
||||
if not step.tool_id:
|
||||
return False, f"步骤 {i} 缺少工具ID"
|
||||
|
||||
if step.error_handling not in ["stop", "continue", "retry"]:
|
||||
return False, f"步骤 {i} 错误处理策略无效: {step.error_handling}"
|
||||
|
||||
return True, None
|
||||
|
||||
def _prepare_parameters(
|
||||
self,
|
||||
parameters: Dict[str, Any],
|
||||
context: ChainExecutionContext
|
||||
) -> Dict[str, Any]:
|
||||
"""准备参数(支持变量替换)
|
||||
|
||||
Args:
|
||||
parameters: 原始参数
|
||||
context: 执行上下文
|
||||
|
||||
Returns:
|
||||
处理后的参数
|
||||
"""
|
||||
prepared = {}
|
||||
|
||||
for key, value in parameters.items():
|
||||
if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
|
||||
# 变量替换
|
||||
var_name = value[2:-1]
|
||||
if var_name in context.variables:
|
||||
prepared[key] = context.variables[var_name]
|
||||
else:
|
||||
prepared[key] = value # 保持原值
|
||||
else:
|
||||
prepared[key] = value
|
||||
|
||||
return prepared
|
||||
|
||||
def _evaluate_condition(
|
||||
self,
|
||||
condition: str,
|
||||
context: ChainExecutionContext
|
||||
) -> bool:
|
||||
"""评估执行条件
|
||||
|
||||
Args:
|
||||
condition: 条件表达式
|
||||
context: 执行上下文
|
||||
|
||||
Returns:
|
||||
条件是否满足
|
||||
"""
|
||||
try:
|
||||
# 简单的条件评估(可以扩展为更复杂的表达式解析)
|
||||
# 支持格式:variable == value, variable != value, variable > value 等
|
||||
|
||||
if "==" in condition:
|
||||
var_name, expected_value = condition.split("==", 1)
|
||||
var_name = var_name.strip()
|
||||
expected_value = expected_value.strip().strip('"\'')
|
||||
|
||||
return str(context.variables.get(var_name, "")) == expected_value
|
||||
|
||||
elif "!=" in condition:
|
||||
var_name, expected_value = condition.split("!=", 1)
|
||||
var_name = var_name.strip()
|
||||
expected_value = expected_value.strip().strip('"\'')
|
||||
|
||||
return str(context.variables.get(var_name, "")) != expected_value
|
||||
|
||||
elif condition in context.variables:
|
||||
# 简单的布尔检查
|
||||
return bool(context.variables[condition])
|
||||
|
||||
else:
|
||||
# 默认为真
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"条件评估失败: {condition}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def _apply_output_mapping(
|
||||
self,
|
||||
mapping: Dict[str, str],
|
||||
output_data: Any,
|
||||
context: ChainExecutionContext
|
||||
):
|
||||
"""应用输出映射
|
||||
|
||||
Args:
|
||||
mapping: 输出映射配置
|
||||
output_data: 输出数据
|
||||
context: 执行上下文
|
||||
"""
|
||||
try:
|
||||
if isinstance(output_data, dict):
|
||||
for source_key, target_var in mapping.items():
|
||||
if source_key in output_data:
|
||||
context.variables[target_var] = output_data[source_key]
|
||||
else:
|
||||
# 如果输出不是字典,将整个输出映射到指定变量
|
||||
if "result" in mapping:
|
||||
context.variables[mapping["result"]] = output_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"输出映射失败: {e}")
|
||||
|
||||
def _serialize_result(self, result: ToolResult) -> Dict[str, Any]:
|
||||
"""序列化工具结果
|
||||
|
||||
Args:
|
||||
result: 工具结果
|
||||
|
||||
Returns:
|
||||
序列化的结果
|
||||
"""
|
||||
return {
|
||||
"success": result.success,
|
||||
"data": result.data,
|
||||
"error": result.error,
|
||||
"error_code": result.error_code,
|
||||
"execution_time": result.execution_time,
|
||||
"token_usage": result.token_usage,
|
||||
"metadata": result.metadata
|
||||
}
|
||||
|
||||
def get_running_chains(self) -> List[Dict[str, Any]]:
    """Describe the chain runs that are currently registered as running.

    Returns:
        One dictionary per run with its chain id, current step index,
        completion and failure flags, number of variables and number of
        recorded step results.
    """
    return [
        {
            "chain_id": run_id,
            "current_step": run.current_step,
            "is_completed": run.is_completed,
            "is_failed": run.is_failed,
            "variables_count": len(run.variables),
            "completed_steps": len(run.step_results),
        }
        for run_id, run in self._running_chains.items()
    ]
|
||||
264
api/app/core/tools/config_manager.py
Normal file
264
api/app/core/tools/config_manager.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""工具配置管理器 - 管理工具配置的加载和验证"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ToolConfigSchema(BaseModel):
    """Base schema shared by all tool configurations.

    Fields that are not declared here are accepted and kept on the model.
    """

    # Pydantic v2 configuration. This module already relies on the v2 API
    # (model_dump), where the nested `class Config` form is deprecated.
    model_config = {"extra": "allow"}

    # Tool name; also used as the configuration file stem.
    name: str
    # Human-readable description.
    description: str
    # Tool category; subclasses pin it to "builtin", "custom" or "mcp".
    tool_type: str
    version: str = "1.0.0"
    enabled: bool = True
    # Tool-specific settings.
    parameters: Dict[str, Any] = {}
    tags: list[str] = []
|
||||
|
||||
|
||||
class BuiltinToolConfigSchema(ToolConfigSchema):
    """Configuration schema for built-in tools."""

    # Type tag fixed for this schema.
    tool_type: str = "builtin"
    # Name of the class implementing the tool (required).
    tool_class: str
|
||||
|
||||
|
||||
class CustomToolConfigSchema(ToolConfigSchema):
    """Configuration schema for custom tools."""

    # Type tag fixed for this schema.
    tool_type: str = "custom"

    # Where the tool's API schema comes from: a URL or inline content.
    schema_url: Optional[str] = None
    schema_content: Optional[Dict[str, Any]] = None

    # Authentication settings; "none" means unauthenticated.
    auth_type: str = "none"
    auth_config: Dict[str, Any] = {}

    # Request settings.
    base_url: Optional[str] = None
    timeout: int = 30
|
||||
|
||||
|
||||
class MCPToolConfigSchema(ToolConfigSchema):
    """Configuration schema for MCP tools."""

    # Type tag fixed for this schema.
    tool_type: str = "mcp"
    # URL of the MCP server (required).
    server_url: str
    # Extra connection settings.
    connection_config: Dict[str, Any] = {}
    # Names of the tools the server exposes.
    available_tools: list[str] = []
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""工具配置管理器"""
|
||||
|
||||
def __init__(self, config_dir: Optional[str] = None):
    """Create the manager and make sure its configuration directory exists.

    Args:
        config_dir: Directory holding the configuration files. When
            omitted or empty, the default directory is used.
    """
    chosen_dir = config_dir or self._get_default_config_dir()
    self.config_dir = Path(chosen_dir)
    # Create the directory (and parents) on first use.
    self.config_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"配置管理器初始化完成,配置目录: {self.config_dir}")
|
||||
|
||||
def _get_default_config_dir(self) -> str:
|
||||
"""获取默认配置目录"""
|
||||
# 获取tools目录下的configs子目录
|
||||
tools_dir = Path(__file__).parent
|
||||
return str(tools_dir / "configs")
|
||||
|
||||
def load_builtin_tool_configs(self) -> Dict[str, BuiltinToolConfigSchema]:
    """Load every ``*.json`` file of the ``builtin`` sub-directory.

    The directory is created when it does not exist. A file that cannot
    be read or validated is logged and skipped.

    Returns:
        Validated built-in tool configurations keyed by tool name.
    """
    builtin_dir = self.config_dir / "builtin"
    if not builtin_dir.exists():
        logger.info("内置工具配置目录不存在,创建默认配置")
        self._create_default_builtin_configs(builtin_dir)

    loaded = {}
    for config_file in builtin_dir.glob("*.json"):
        try:
            parsed = BuiltinToolConfigSchema(**self._load_config_file(config_file))
            loaded[parsed.name] = parsed
            logger.debug(f"加载内置工具配置: {parsed.name}")
        except Exception as e:
            # Best effort: one bad file must not block the others.
            logger.error(f"加载内置工具配置失败: {config_file}, 错误: {e}")

    return loaded
|
||||
|
||||
def load_builtin_tools_config(self) -> Dict[str, Any]:
    """Load the global ``builtin_tools.json`` file (legacy interface).

    Returns:
        The parsed JSON content, or an empty dictionary when the file
        cannot be read or parsed (the error is logged).
    """
    config_file = self.config_dir / "builtin_tools.json"
    try:
        raw_text = config_file.read_text(encoding='utf-8')
        return json.loads(raw_text)
    except Exception as e:
        logger.error(f"加载内置工具配置失败: {e}")
        return {}
|
||||
|
||||
def ensure_builtin_tools_initialized(self, tenant_id, db_session, tool_config_model, builtin_tool_config_model, tool_type_enum, tool_status_enum):
    """Create the tenant's built-in tool records if it has none yet.

    Does nothing when the tenant already owns at least one built-in tool
    record. Otherwise one tool record and one built-in detail record are
    added per entry of the global built-in configuration, followed by a
    single commit.

    Args:
        tenant_id: Tenant identifier.
        db_session: Database session.
        tool_config_model: ToolConfig model class.
        builtin_tool_config_model: BuiltinToolConfig model class.
        tool_type_enum: ToolType enum.
        tool_status_enum: ToolStatus enum.
    """
    already_present = (
        db_session.query(tool_config_model)
        .filter(
            tool_config_model.tenant_id == tenant_id,
            tool_config_model.tool_type == tool_type_enum.BUILTIN
        )
        .count()
    )
    if already_present > 0:
        return

    for tool_info in self.load_builtin_tools_config().values():
        # Tools that need configuration start inactive.
        if tool_info['requires_config']:
            initial_status = tool_status_enum.INACTIVE.value
        else:
            initial_status = tool_status_enum.ACTIVE.value

        tool_config = tool_config_model(
            name=tool_info['name'],
            description=tool_info['description'],
            tool_type=tool_type_enum.BUILTIN,
            tenant_id=tenant_id,
            status=initial_status
        )
        db_session.add(tool_config)
        # Flush before tool_config.id is read for the detail record below.
        db_session.flush()

        db_session.add(builtin_tool_config_model(
            id=tool_config.id,
            tool_class=tool_info['tool_class'],
            parameters={}
        ))

    db_session.commit()
    logger.info(f"租户 {tenant_id} 的内置工具初始化完成")
|
||||
|
||||
def save_tool_config(self, config: ToolConfigSchema, tool_type: str) -> bool:
    """Write a tool configuration to ``<config_dir>/<tool_type>/<name>.json``.

    Args:
        config: Tool configuration to persist.
        tool_type: Tool type; used as the sub-directory name.

    Returns:
        ``True`` on success, ``False`` when an exception occurred (logged).
    """
    try:
        target_dir = self.config_dir / tool_type
        target_dir.mkdir(parents=True, exist_ok=True)

        target_file = target_dir / f"{config.name}.json"
        payload = config.model_dump()

        # ensure_ascii=False keeps non-ASCII text readable in the file.
        with target_file.open('w', encoding='utf-8') as handle:
            json.dump(payload, handle, indent=2, ensure_ascii=False)
    except Exception as e:
        logger.error(f"工具配置保存失败: {config.name}, 错误: {e}")
        return False

    logger.info(f"工具配置保存成功: {config.name} ({tool_type})")
    return True
|
||||
|
||||
def delete_tool_config(self, tool_name: str, tool_type: str) -> bool:
    """Delete ``<config_dir>/<tool_type>/<tool_name>.json``.

    Args:
        tool_name: Tool name.
        tool_type: Tool type; used as the sub-directory name.

    Returns:
        ``True`` when the file was removed; ``False`` when it does not
        exist or an exception occurred (both are logged).
    """
    try:
        config_file = self.config_dir / tool_type / f"{tool_name}.json"

        if not config_file.exists():
            logger.warning(f"工具配置文件不存在: {tool_name} ({tool_type})")
            return False

        config_file.unlink()
        logger.info(f"工具配置删除成功: {tool_name} ({tool_type})")
        return True

    except Exception as e:
        logger.error(f"工具配置删除失败: {tool_name}, 错误: {e}")
        return False
|
||||
|
||||
def validate_config(self, config_data: Dict[str, Any], tool_type: str) -> tuple[bool, Optional[str]]:
    """Validate raw configuration data against the schema for its type.

    Args:
        config_data: Configuration data.
        tool_type: ``"builtin"``, ``"custom"`` or ``"mcp"``.

    Returns:
        ``(True, None)`` when the data is valid, otherwise
        ``(False, message)``. Validation messages list every failing
        field with its full dotted path.
    """
    try:
        schema_map = {
            "builtin": BuiltinToolConfigSchema,
            "custom": CustomToolConfigSchema,
            "mcp": MCPToolConfigSchema
        }

        schema_class = schema_map.get(tool_type)
        if not schema_class:
            return False, f"不支持的工具类型: {tool_type}"

        # Constructing the model performs the validation.
        schema_class(**config_data)
        return True, None

    except ValidationError as e:
        details = []
        for err in e.errors():
            # `loc` is a tuple path and may be empty for model-level errors;
            # indexing [0] would raise from inside this handler.
            location = ".".join(str(part) for part in err.get("loc", ())) or "__root__"
            details.append(f"{location}: {err['msg']}")
        error_msg = "; ".join(details)
        return False, f"配置验证失败: {error_msg}"
    except Exception as e:
        return False, f"配置验证异常: {str(e)}"
|
||||
|
||||
def _load_config_file(self, config_file: Path) -> Dict[str, Any]:
|
||||
"""加载配置文件
|
||||
|
||||
Args:
|
||||
config_file: 配置文件路径
|
||||
|
||||
Returns:
|
||||
配置数据字典
|
||||
"""
|
||||
try:
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"加载配置文件失败: {config_file}, 错误: {e}")
|
||||
raise
|
||||
|
||||
def _create_default_builtin_configs(self, builtin_dir: Path):
    """Make sure the built-in configuration directory exists.

    No configuration files are written here: the files themselves are
    provided by other means, so only the directory is created.

    Args:
        builtin_dir: Directory for built-in tool configurations.
    """
    builtin_dir.mkdir(exist_ok=True, parents=True)
    logger.info(f"内置工具配置目录已创建: {builtin_dir}")
|
||||
14
api/app/core/tools/configs/builtin/baidu_search_tool.json
Normal file
14
api/app/core/tools/configs/builtin/baidu_search_tool.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"name": "baidu_search_tool",
|
||||
"description": "百度搜索工具 - 网络搜索:提供网页搜索、新闻搜索、图片搜索功能",
|
||||
"tool_type": "builtin",
|
||||
"tool_class": "BaiduSearchTool",
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {
|
||||
"api_key": "",
|
||||
"secret_key": "",
|
||||
"search_type": "web"
|
||||
},
|
||||
"tags": ["search", "web", "baidu", "builtin"]
|
||||
}
|
||||
12
api/app/core/tools/configs/builtin/datetime_tool.json
Normal file
12
api/app/core/tools/configs/builtin/datetime_tool.json
Normal file
@@ -0,0 +1,12 @@
|
||||
{
|
||||
"name": "datetime_tool",
|
||||
"description": "时间工具 - 日期时间处理:提供时间格式转化、时区转换、时间戳转换、时间计算",
|
||||
"tool_type": "builtin",
|
||||
"tool_class": "DateTimeTool",
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {
|
||||
"timezone": "UTC"
|
||||
},
|
||||
"tags": ["time", "utility", "builtin"]
|
||||
}
|
||||
12
api/app/core/tools/configs/builtin/json_tool.json
Normal file
12
api/app/core/tools/configs/builtin/json_tool.json
Normal file
@@ -0,0 +1,12 @@
|
||||
{
|
||||
"name": "json_tool",
|
||||
"description": "JSON工具 - 数据格式处理:提供JSON格式化、压缩、验证、格式转换",
|
||||
"tool_type": "builtin",
|
||||
"tool_class": "JsonTool",
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {
|
||||
"indent": 2
|
||||
},
|
||||
"tags": ["json", "data", "utility", "builtin"]
|
||||
}
|
||||
14
api/app/core/tools/configs/builtin/mineru_tool.json
Normal file
14
api/app/core/tools/configs/builtin/mineru_tool.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"name": "mineru_tool",
|
||||
"description": "MinerU PDF解析工具 - 文档处理:提供PDF解析、表格提取、图片识别、文本提取功能",
|
||||
"tool_type": "builtin",
|
||||
"tool_class": "MinerUTool",
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {
|
||||
"api_key": "",
|
||||
"parse_mode": "auto",
|
||||
"timeout": 60
|
||||
},
|
||||
"tags": ["pdf", "document", "ocr", "builtin"]
|
||||
}
|
||||
14
api/app/core/tools/configs/builtin/textin_tool.json
Normal file
14
api/app/core/tools/configs/builtin/textin_tool.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"name": "textin_tool",
|
||||
"description": "TextIn OCR工具 - 图像识别:提供通用OCR、手写识别、多语言支持功能",
|
||||
"tool_type": "builtin",
|
||||
"tool_class": "TextInTool",
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {
|
||||
"app_id": "",
|
||||
"language": "auto",
|
||||
"recognition_mode": "general"
|
||||
},
|
||||
"tags": ["ocr", "image", "text", "builtin"]
|
||||
}
|
||||
60
api/app/core/tools/configs/builtin_tools.json
Normal file
60
api/app/core/tools/configs/builtin_tools.json
Normal file
@@ -0,0 +1,60 @@
|
||||
{
|
||||
"datetime": {
|
||||
"name": "时间工具",
|
||||
"description": "获取当前时间、日期计算",
|
||||
"tool_class": "DateTimeTool",
|
||||
"category": "utility",
|
||||
"requires_config": false,
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {}
|
||||
},
|
||||
"json_converter": {
|
||||
"name": "JSON转换工具",
|
||||
"description": "JSON数据格式化和转换",
|
||||
"tool_class": "JsonTool",
|
||||
"category": "utility",
|
||||
"requires_config": false,
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {}
|
||||
},
|
||||
"baidu_search": {
|
||||
"name": "百度搜索",
|
||||
"description": "百度网页搜索服务",
|
||||
"tool_class": "BaiduSearchTool",
|
||||
"category": "search",
|
||||
"requires_config": true,
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {
|
||||
"api_key": {"type": "string", "description": "百度搜索API密钥", "sensitive": true, "required": true}
|
||||
}
|
||||
},
|
||||
"mineru": {
|
||||
"name": "MinerU",
|
||||
"description": "PDF文档解析工具",
|
||||
"tool_class": "MinerUTool",
|
||||
"category": "document",
|
||||
"requires_config": true,
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {
|
||||
"api_key": {"type": "string", "description": "MinerU API密钥", "sensitive": true, "required": true},
|
||||
"base_url": {"type": "string", "description": "API地址", "default": "https://api.mineru.com"}
|
||||
}
|
||||
},
|
||||
"textin": {
|
||||
"name": "TextIn",
|
||||
"description": "OCR文字识别服务",
|
||||
"tool_class": "TextInTool",
|
||||
"category": "ocr",
|
||||
"requires_config": true,
|
||||
"version": "1.0.0",
|
||||
"enabled": true,
|
||||
"parameters": {
|
||||
"api_key": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true},
|
||||
"api_secret": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true}
|
||||
}
|
||||
}
|
||||
}
|
||||
11
api/app/core/tools/custom/__init__.py
Normal file
11
api/app/core/tools/custom/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""自定义工具模块"""
|
||||
|
||||
from .base import CustomTool
|
||||
from .schema_parser import OpenAPISchemaParser
|
||||
from .auth_manager import AuthManager
|
||||
|
||||
# Names re-exported as the public API of the custom tools package.
__all__ = ["CustomTool", "OpenAPISchemaParser", "AuthManager"]
|
||||
525
api/app/core/tools/custom/auth_manager.py
Normal file
525
api/app/core/tools/custom/auth_manager.py
Normal file
@@ -0,0 +1,525 @@
|
||||
"""认证管理器 - 处理自定义工具的认证配置"""
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import time
|
||||
from typing import Dict, Any, Tuple
|
||||
from urllib.parse import quote
|
||||
import aiohttp
|
||||
|
||||
from app.models.tool_model import AuthType
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class AuthManager:
|
||||
"""认证管理器 - 支持多种认证方式"""
|
||||
|
||||
def __init__(self):
    """Create the manager and record the authentication schemes it handles."""
    self.supported_auth_types = [AuthType.NONE, AuthType.API_KEY, AuthType.BEARER_TOKEN]
|
||||
|
||||
def validate_auth_config(self, auth_type: AuthType, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
    """Check that ``auth_config`` is well-formed for ``auth_type``.

    Args:
        auth_type: Authentication scheme to validate against.
        auth_config: Scheme-specific settings.

    Returns:
        ``(is_valid, error_message)``; the message is empty on success.
        Exceptions raised while validating are reported as a failed
        result instead of propagating.
    """
    try:
        if auth_type not in self.supported_auth_types:
            return False, f"不支持的认证类型: {auth_type}"

        # Guard-clause dispatch: each supported scheme returns immediately.
        if auth_type == AuthType.NONE:
            return True, ""
        if auth_type == AuthType.API_KEY:
            return self._validate_api_key_config(auth_config)
        if auth_type == AuthType.BEARER_TOKEN:
            return self._validate_bearer_token_config(auth_config)

        return False, "未知的认证类型"
    except Exception as exc:
        return False, f"验证认证配置时出错: {exc}"
|
||||
|
||||
def _validate_api_key_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
    """Validate the settings used for API-key authentication.

    Checks, in order: ``api_key`` is present and a string, ``key_name``
    (default ``"X-API-Key"``) is a string, and ``location`` (default
    ``"header"``) is one of ``header``, ``query`` or ``cookie``.

    Args:
        auth_config: API-key settings.

    Returns:
        ``(is_valid, error_message)``; the message is empty on success.
    """
    secret = auth_config.get("api_key")

    # The first failing rule decides the message; later rules are not evaluated.
    problem = ""
    if not secret:
        problem = "API Key不能为空"
    elif not isinstance(secret, str):
        problem = "API Key必须是字符串"
    elif not isinstance(auth_config.get("key_name", "X-API-Key"), str):
        problem = "API Key名称必须是字符串"
    elif auth_config.get("location", "header") not in ("header", "query", "cookie"):
        problem = "API Key位置必须是 header、query 或 cookie"

    return (not problem), problem
|
||||
|
||||
def _validate_bearer_token_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
    """Validate the settings used for Bearer-token authentication.

    Args:
        auth_config: Settings expected to carry a non-empty string ``token``.

    Returns:
        ``(is_valid, error_message)``; the message is empty on success.
    """
    bearer = auth_config.get("token")
    if bearer:
        if isinstance(bearer, str):
            return True, ""
        return False, "Bearer Token必须是字符串"
    return False, "Bearer Token不能为空"
|
||||
|
||||
def apply_authentication(
    self,
    auth_type: AuthType,
    auth_config: Dict[str, Any],
    url: str,
    headers: Dict[str, str],
    params: Dict[str, Any]
) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
    """Attach credentials for ``auth_type`` to an outgoing request.

    Args:
        auth_type: Authentication scheme to apply.
        auth_config: Scheme-specific settings.
        url: Request URL.
        headers: Request headers.
        params: Request parameters.

    Returns:
        ``(url, headers, params)`` after authentication was applied. For
        ``AuthType.NONE``, an unsupported scheme (logged as a warning) or
        any exception (logged as an error), the inputs are returned as
        received.
    """
    try:
        if auth_type == AuthType.API_KEY:
            return self._apply_api_key_auth(auth_config, url, headers, params)
        if auth_type == AuthType.BEARER_TOKEN:
            return self._apply_bearer_token_auth(auth_config, url, headers, params)
        if auth_type != AuthType.NONE:
            logger.warning(f"不支持的认证类型: {auth_type}")
    except Exception as exc:
        logger.error(f"应用认证时出错: {exc}")

    # Fall-through: nothing to add, unsupported scheme, or a failure above.
    return url, headers, params
|
||||
|
||||
def _apply_api_key_auth(
    self,
    auth_config: Dict[str, Any],
    url: str,
    headers: Dict[str, str],
    params: Dict[str, Any]
) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
    """Attach an API key to the request at the configured location.

    Args:
        auth_config: Settings with ``api_key``, optional ``key_name``
            (default ``"X-API-Key"``) and optional ``location``
            (``header`` by default, or ``query`` / ``cookie``).
        url: Request URL; the key is appended to it for ``query``.
        headers: Request headers; updated in place for ``header`` and
            ``cookie``.
        params: Request parameters; returned unchanged.

    Returns:
        ``(url, headers, params)`` with the key applied. An unrecognised
        ``location`` leaves everything untouched.
    """
    api_key = auth_config.get("api_key")
    key_name = auth_config.get("key_name", "X-API-Key")
    location = auth_config.get("location", "header")

    if location == "header":
        headers[key_name] = api_key

    elif location == "query":
        separator = "&" if "?" in url else "?"
        # Percent-encode the name as well as the value: a name containing
        # "&", "=", "#" or spaces would otherwise corrupt the query string.
        encoded_name = quote(str(key_name))
        encoded_key = quote(str(api_key))
        url += f"{separator}{encoded_name}={encoded_key}"

    elif location == "cookie":
        cookie_value = f"{key_name}={api_key}"
        if "Cookie" in headers:
            # Preserve cookies that are already present on the request.
            headers["Cookie"] += f"; {cookie_value}"
        else:
            headers["Cookie"] = cookie_value

    return url, headers, params
|
||||
|
||||
def _apply_bearer_token_auth(
    self,
    auth_config: Dict[str, Any],
    url: str,
    headers: Dict[str, str],
    params: Dict[str, Any]
) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
    """Set the ``Authorization: Bearer <token>`` header on the request.

    Args:
        auth_config: Settings carrying ``token``.
        url: Request URL; returned unchanged.
        headers: Request headers; updated in place.
        params: Request parameters; returned unchanged.

    Returns:
        ``(url, headers, params)`` with the header applied.
    """
    credential = auth_config.get("token")
    headers.update({"Authorization": f"Bearer {credential}"})
    return url, headers, params
|
||||
|
||||
def encrypt_auth_config(self, auth_config: Dict[str, Any], encryption_key: str) -> Dict[str, Any]:
    """Return a copy of ``auth_config`` with its sensitive fields encoded.

    The fields ``api_key``, ``token``, ``secret`` and ``password`` are
    passed through ``_encrypt_string`` when they hold a non-empty string,
    and a ``<field>_encrypted`` marker set to ``True`` is added for each.

    Args:
        auth_config: Settings to protect; not modified.
        encryption_key: Key handed to ``_encrypt_string``.

    Returns:
        The processed shallow copy. If anything raises, the error is
        logged and the original ``auth_config`` is returned as-is.
    """
    try:
        protected = auth_config.copy()

        for name in ("api_key", "token", "secret", "password"):
            plaintext = protected.get(name)
            # Skip absent, empty and non-string values.
            if not (isinstance(plaintext, str) and plaintext):
                continue
            protected[name] = self._encrypt_string(plaintext, encryption_key)
            protected[f"{name}_encrypted"] = True

        return protected
    except Exception as exc:
        logger.error(f"加密认证配置失败: {exc}")
        return auth_config
|
||||
|
||||
def decrypt_auth_config(self, encrypted_config: Dict[str, Any], encryption_key: str) -> Dict[str, Any]:
    """Return a copy of ``encrypted_config`` with marked fields decoded.

    A field among ``api_key``, ``token``, ``secret`` and ``password`` is
    processed only when its ``<field>_encrypted`` marker is truthy and
    the field holds a non-empty string; the marker is then removed.

    Args:
        encrypted_config: Settings produced by ``encrypt_auth_config``;
            not modified.
        encryption_key: Key handed to ``_decrypt_string``.

    Returns:
        The processed shallow copy. If anything raises, the error is
        logged and ``encrypted_config`` is returned as-is.
    """
    try:
        restored = encrypted_config.copy()

        for name in ("api_key", "token", "secret", "password"):
            marker = f"{name}_encrypted"
            if name not in restored or not restored.get(marker):
                continue
            encoded = restored[name]
            if isinstance(encoded, str) and encoded:
                restored[name] = self._decrypt_string(encoded, encryption_key)
                # Drop the marker now that the field is back in plain form.
                restored.pop(marker, None)

        return restored
    except Exception as exc:
        logger.error(f"解密认证配置失败: {exc}")
        return encrypted_config
|
||||
|
||||
def _encrypt_string(self, value: str, key: str) -> str:
    """Encode ``value`` together with an HMAC-SHA256 tag.

    NOTE: despite the name this does not hide the value. The result is
    Base64 of ``"<value>:<hex hmac>"``, so the original text can be read
    back without the key; the key only protects integrity.

    Args:
        value: Text to encode.
        key: Secret used to compute the HMAC tag.

    Returns:
        The Base64 string, or ``value`` unchanged if encoding fails
        (the failure is logged).
    """
    try:
        tag = hmac.new(key.encode('utf-8'), value.encode('utf-8'), hashlib.sha256).hexdigest()
        payload = f"{value}:{tag}".encode('utf-8')
        return base64.b64encode(payload).decode('utf-8')
    except Exception as exc:
        logger.error(f"加密字符串失败: {exc}")
        return value
|
||||
|
||||
def _decrypt_string(self, encrypted_value: str, key: str) -> str:
    """Decode a value produced by ``_encrypt_string`` and verify its tag.

    Args:
        encrypted_value: Base64 text expected to decode to
            ``"<value>:<hex hmac>"``.
        key: Secret used to recompute the HMAC-SHA256 tag.

    Returns:
        The embedded value when the tag matches. ``encrypted_value`` is
        returned unchanged when the decoded text has no ``:`` separator,
        when the tag does not match (logged as a warning), or when
        decoding raises (logged as an error).
    """
    try:
        decoded = base64.b64decode(encrypted_value.encode('utf-8')).decode('utf-8')

        if ':' not in decoded:
            return encrypted_value  # probably not an encoded value at all

        # The hex tag never contains ":", so split on the last separator to
        # keep values that themselves contain ":" intact.
        value, signature = decoded.rsplit(':', 1)

        expected_signature = hmac.new(
            key.encode('utf-8'), value.encode('utf-8'), hashlib.sha256
        ).hexdigest()

        # Constant-time comparison avoids leaking how many leading characters
        # of the tag matched. Bytes are compared because compare_digest
        # rejects str operands containing non-ASCII characters.
        if hmac.compare_digest(signature.encode('utf-8'), expected_signature.encode('utf-8')):
            return value

        logger.warning("解密时签名验证失败")
        return encrypted_value

    except Exception as e:
        logger.error(f"解密字符串失败: {e}")
        return encrypted_value
|
||||
|
||||
def test_authentication(
    self,
    auth_type: AuthType,
    auth_config: Dict[str, Any],
    test_url: str = None
) -> Dict[str, Any]:
    """Dry-run an authentication configuration without sending a request.

    The configuration is validated and, when ``test_url`` is given,
    applied to a synthetic request so the caller can inspect the outcome.

    Args:
        auth_type: Authentication scheme to test.
        auth_config: Scheme-specific settings.
        test_url: Optional URL to apply the credentials to.

    Returns:
        A result dict with ``success`` and ``auth_type``, plus either
        ``error`` or ``message``. With a ``test_url`` it also carries
        ``test_url``, ``headers`` and ``has_auth_header``; credentials are
        never included in those fields.
    """
    try:
        is_valid, error_msg = self.validate_auth_config(auth_type, auth_config)
        if not is_valid:
            return {
                "success": False,
                "error": error_msg,
                "auth_type": auth_type.value
            }

        # Without a URL there is nothing to apply; validation is the whole test.
        if not test_url:
            return {
                "success": True,
                "message": "认证配置有效",
                "auth_type": auth_type.value
            }

        headers = {"User-Agent": "AuthManager-Test/1.0"}
        params = {}
        original_url = test_url

        test_url, headers, params = self.apply_authentication(
            auth_type, auth_config, test_url, headers, params
        )

        # Header names that can carry the credential. Filtering only
        # "Authorization" would leak API keys sent in a custom header or cookie.
        credential_headers = {"authorization", "cookie"}
        reported_url = test_url
        if auth_type == AuthType.API_KEY:
            credential_headers.add(str(auth_config.get("key_name", "X-API-Key")).lower())
            # With location=query the key is embedded in the URL itself, so
            # report the URL as it was before authentication was applied.
            if auth_config.get("location", "header") == "query":
                reported_url = original_url

        return {
            "success": True,
            "message": "认证配置测试成功",
            "auth_type": auth_type.value,
            "test_url": reported_url,
            "headers": {k: v for k, v in headers.items() if k.lower() not in credential_headers},
            "has_auth_header": "Authorization" in headers
        }

    except Exception as e:
        return {
            "success": False,
            "error": str(e),
            "auth_type": auth_type.value if auth_type else "unknown"
        }
|
||||
|
||||
async def test_authentication_with_request(
    self,
    auth_type: AuthType,
    auth_config: Dict[str, Any],
    test_url: str,
    timeout: int = 10
) -> Dict[str, Any]:
    """Test an authentication configuration with a real HTTP GET.

    Args:
        auth_type: Authentication scheme to test.
        auth_config: Scheme-specific settings.
        test_url: URL to request.
        timeout: Total request timeout in seconds.

    Returns:
        A result dict with ``success`` and ``auth_type``, plus ``error``
        or ``message`` and, once a response arrived, ``status_code``.
        Only 401 and 403 are reported as failures; every other status,
        including other 4xx/5xx codes, is reported as success.
    """
    try:
        config_ok, config_error = self.validate_auth_config(auth_type, auth_config)
        if not config_ok:
            return {
                "success": False,
                "error": config_error,
                "auth_type": auth_type.value
            }

        headers = {"User-Agent": "AuthManager-Test/1.0"}
        params = {}
        test_url, headers, params = self.apply_authentication(
            auth_type, auth_config, test_url, headers, params
        )

        client_timeout = aiohttp.ClientTimeout(total=timeout)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.get(test_url, headers=headers) as response:
                status_code = response.status

                # Status codes that mean the credentials were rejected.
                rejections = {
                    401: "认证失败 - 401 Unauthorized",
                    403: "认证失败 - 403 Forbidden",
                }
                if status_code in rejections:
                    return {
                        "success": False,
                        "error": rejections[status_code],
                        "status_code": status_code,
                        "auth_type": auth_type.value
                    }

                if status_code == 200:
                    message = "认证测试成功"
                else:
                    message = f"请求成功,状态码: {status_code}"
                return {
                    "success": True,
                    "message": message,
                    "status_code": status_code,
                    "auth_type": auth_type.value
                }

    except aiohttp.ClientError as exc:
        return {
            "success": False,
            "error": f"网络请求失败: {exc}",
            "auth_type": auth_type.value
        }
    except Exception as exc:
        return {
            "success": False,
            "error": f"测试认证时出错: {exc}",
            "auth_type": auth_type.value
        }
|
||||
|
||||
def get_auth_config_template(self, auth_type: AuthType) -> Dict[str, Any]:
    """Return a blank configuration template for ``auth_type``.

    Args:
        auth_type: Authentication scheme to get a template for.

    Returns:
        A dict with the scheme's fields set to their defaults, or an
        empty dict for ``AuthType.NONE`` and for unknown schemes.
    """
    api_key_template = {
        "api_key": "",
        "key_name": "X-API-Key",
        "location": "header",  # one of: header, query, cookie
        "description": "API Key认证配置"
    }
    bearer_template = {
        "token": "",
        "description": "Bearer Token认证配置"
    }

    templates = {
        AuthType.NONE: {},
        AuthType.API_KEY: api_key_template,
        AuthType.BEARER_TOKEN: bearer_template,
    }
    return templates.get(auth_type, {})
|
||||
|
||||
def mask_sensitive_config(self, auth_config: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``auth_config`` with secrets masked for display.

    For ``api_key``, ``token``, ``secret`` and ``password``: strings
    longer than four characters keep their first and last two characters
    around ``***``; shorter non-empty strings become ``***``. Empty
    strings and non-string values are left alone.

    Args:
        auth_config: Settings to mask; not modified.

    Returns:
        The masked shallow copy.
    """
    masked = auth_config.copy()

    for name in ("api_key", "token", "secret", "password"):
        secret = masked.get(name)
        if not isinstance(secret, str) or not secret:
            continue
        masked[name] = f"{secret[:2]}***{secret[-2:]}" if len(secret) > 4 else "***"

    return masked
|
||||
318
api/app/core/tools/custom/base.py
Normal file
318
api/app/core/tools/custom/base.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""自定义工具基类"""
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional
|
||||
import aiohttp
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from app.models.tool_model import ToolType, AuthType
|
||||
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class CustomTool(BaseTool):
|
||||
"""自定义工具 - 基于OpenAPI schema的工具"""
|
||||
|
||||
def __init__(self, tool_id: str, config: Dict[str, Any]):
    """Set up a custom tool from its stored configuration.

    Args:
        tool_id: Identifier of the tool.
        config: Tool settings. Reads ``schema_content``, ``schema_url``,
            ``auth_type`` (default ``"none"``), ``auth_config``,
            ``base_url`` and ``timeout`` (default 30).
    """
    super().__init__(tool_id, config)

    setting = config.get
    self.schema_content = setting("schema_content", {})
    self.schema_url = setting("schema_url")
    self.auth_type = AuthType(setting("auth_type", "none"))
    self.auth_config = setting("auth_config", {})
    self.base_url = setting("base_url", "")
    self.timeout = setting("timeout", 30)

    # Parse once up front; the result is stored for later use by the instance.
    self._parsed_operations = self._parse_openapi_schema()
|
||||
|
||||
@property
def name(self) -> str:
    """Tool name: the schema's ``info.title``, else an id-based fallback."""
    fallback = f"custom_tool_{self.tool_id[:8]}"
    if not self.schema_content:
        return fallback
    return self.schema_content.get("info", {}).get("title", fallback)
|
||||
|
||||
@property
def description(self) -> str:
    """Tool description: the schema's ``info.description``, else a default."""
    default_text = "自定义API工具"
    schema = self.schema_content
    if not schema:
        return default_text
    return schema.get("info", {}).get("description", default_text)
|
||||
|
||||
@property
def tool_type(self) -> ToolType:
    """Tool category; always ``ToolType.CUSTOM`` for this class."""
    category = ToolType.CUSTOM
    return category
|
||||
|
||||
@property
def parameters(self) -> List[ToolParameter]:
    """Parameter definitions exposed by this tool.

    When the schema defines more than one operation, a required
    ``operation`` selector comes first. The remaining entries are taken
    from the first parsed operation only.
    """
    operations = self._parsed_operations
    declared = []

    if len(operations) > 1:
        selector = ToolParameter(
            name="operation",
            type=ParameterType.STRING,
            description="要执行的操作",
            required=True,
            enum=list(operations.keys())
        )
        declared.append(selector)

    if operations:
        primary = next(iter(operations.values()))
        declared.extend(
            ToolParameter(
                name=param_name,
                type=self._convert_openapi_type(spec.get("type", "string")),
                description=spec.get("description", ""),
                required=spec.get("required", False),
                default=spec.get("default"),
                enum=spec.get("enum"),
                minimum=spec.get("minimum"),
                maximum=spec.get("maximum"),
                pattern=spec.get("pattern")
            )
            for param_name, spec in primary.get("parameters", {}).items()
        )

    return declared
|
||||
|
||||
async def execute(self, **kwargs) -> ToolResult:
    """Run one operation of the custom tool.

    Args:
        **kwargs: Call arguments. ``operation`` names the operation to
            run and may be omitted when the schema defines exactly one.
            The rest are used as path, query and body values.

    Returns:
        A success result wrapping the HTTP response data, or an error
        result with code ``CUSTOM_TOOL_ERROR`` if anything raised. Both
        carry the elapsed execution time.
    """
    started = time.time()

    try:
        known_operations = self._parsed_operations

        chosen = kwargs.get("operation")
        if not chosen and len(known_operations) == 1:
            # Single-operation schema: the caller need not name it.
            chosen = next(iter(known_operations.keys()))

        if not chosen or chosen not in known_operations:
            raise ValueError(f"无效的操作: {chosen}")

        spec = known_operations[chosen]

        url = self._build_request_url(spec, kwargs)
        headers = self._build_request_headers(spec)
        data = self._build_request_data(spec, kwargs)

        payload = await self._send_http_request(
            method=spec["method"],
            url=url,
            headers=headers,
            data=data
        )

        return ToolResult.success_result(
            data=payload,
            execution_time=time.time() - started
        )

    except Exception as exc:
        return ToolResult.error_result(
            error=str(exc),
            error_code="CUSTOM_TOOL_ERROR",
            execution_time=time.time() - started
        )
|
||||
|
||||
def _parse_openapi_schema(self) -> Dict[str, Any]:
    """Extract the callable operations from ``self.schema_content``.

    Returns:
        A dict keyed by operation id (``operationId`` or
        ``"<method>_<path with '/' replaced by '_'>"``). Each entry holds
        ``method`` (upper case), ``path``, ``summary``, ``description``,
        ``parameters`` (keyed by parameter name) and ``request_body``
        (the ``application/json`` schema, or ``None``). Empty when there
        is no schema or ``paths`` is not an object.
    """
    operations = {}

    if not self.schema_content:
        return operations

    paths = self.schema_content.get("paths", {})
    if not isinstance(paths, dict):
        return operations

    http_methods = ("get", "post", "put", "delete", "patch")

    for path, path_item in paths.items():
        # Malformed entries are skipped rather than aborting construction.
        if not isinstance(path_item, dict):
            continue

        # Path-level parameters apply to every operation under the path;
        # an operation-level parameter with the same name overrides them.
        shared_params = path_item.get("parameters")
        if not isinstance(shared_params, list):
            shared_params = []

        for method, operation in path_item.items():
            if method.lower() not in http_methods or not isinstance(operation, dict):
                continue

            operation_id = operation.get("operationId", f"{method}_{path.replace('/', '_')}")

            own_params = operation.get("parameters")
            if not isinstance(own_params, list):
                own_params = []

            parameters = {}
            for param in shared_params + own_params:
                if not isinstance(param, dict):
                    continue
                param_name = param.get("name")
                if not param_name:
                    # e.g. an unresolved {"$ref": ...} entry
                    continue
                param_schema = param.get("schema")
                if not isinstance(param_schema, dict):
                    param_schema = {}
                parameters[param_name] = {
                    # Schema keywords (enum, default, minimum, ...) go first so
                    # the parameter-level fields below take precedence.
                    **param_schema,
                    "type": param_schema.get("type", "string"),
                    "description": param.get("description", param_schema.get("description", "")),
                    "required": param.get("required", False),
                    "in": param.get("in", "query"),
                }

            request_body = None
            body_definition = operation.get("requestBody")
            if isinstance(body_definition, dict):
                content = body_definition.get("content", {})
                if "application/json" in content:
                    request_body = content["application/json"].get("schema", {})

            operations[operation_id] = {
                "method": method.upper(),
                "path": path,
                "summary": operation.get("summary", ""),
                "description": operation.get("description", ""),
                "parameters": parameters,
                "request_body": request_body
            }

    return operations
|
||||
|
||||
def _convert_openapi_type(self, openapi_type: str) -> ParameterType:
    """Map an OpenAPI ``type`` value to the internal ``ParameterType``.

    Args:
        openapi_type: OpenAPI type name. A list of names, as allowed by
            OpenAPI 3.1 (e.g. ``["string", "null"]``), is also accepted;
            the first entry other than ``"null"`` is used.

    Returns:
        The matching ``ParameterType``; ``ParameterType.STRING`` for
        unknown or malformed values.
    """
    type_mapping = {
        "string": ParameterType.STRING,
        "integer": ParameterType.INTEGER,
        "number": ParameterType.NUMBER,
        "boolean": ParameterType.BOOLEAN,
        "array": ParameterType.ARRAY,
        "object": ParameterType.OBJECT
    }

    # A list is unhashable and would raise TypeError in the dict lookup,
    # so reduce it to a single type name first.
    if isinstance(openapi_type, (list, tuple)):
        openapi_type = next((name for name in openapi_type if name != "null"), None)

    if not isinstance(openapi_type, str):
        return ParameterType.STRING

    return type_mapping.get(openapi_type, ParameterType.STRING)
|
||||
|
||||
def _build_request_url(self, operation: Dict[str, Any], params: Dict[str, Any]) -> str:
    """Build the full request URL for ``operation``.

    Path placeholders (``{name}``) are filled from ``params`` with the
    values percent-encoded. The path is appended to ``self.base_url``
    or, if that is empty, to the first server URL in the schema.
    Parameters declared ``in: query`` are added as a query string.

    Args:
        operation: Parsed operation with ``path`` and ``parameters``.
        params: Call arguments supplying parameter values.

    Returns:
        The request URL; just the path when no base URL is known.
    """
    from urllib.parse import quote, urlencode

    path = operation["path"]
    query_params = {}

    for param_name, param_info in operation.get("parameters", {}).items():
        if param_name not in params:
            continue
        location = param_info.get("in")
        if location == "path":
            # Encode the value so "/", "?", "#" or spaces inside it cannot
            # change the structure of the URL.
            encoded_value = quote(str(params[param_name]), safe="")
            path = path.replace(f"{{{param_name}}}", encoded_value)
        elif location == "query":
            query_params[param_name] = params[param_name]

    base_url = self.base_url
    if not base_url:
        servers = self.schema_content.get("servers", [])
        if servers:
            base_url = servers[0].get("url", "")

    if base_url:
        # Plain concatenation instead of urljoin: urljoin replaces the last
        # segment of a base without a trailing slash, turning
        # "https://host/v1" + "pets" into "https://host/pets".
        url = f"{base_url.rstrip('/')}/{path.lstrip('/')}"
    else:
        url = path

    if query_params:
        separator = "&" if "?" in url else "?"
        url += separator + urlencode(query_params)

    return url
|
||||
|
||||
def _build_request_headers(self, operation: Dict[str, Any]) -> Dict[str, str]:
    """Build the HTTP headers for a request, including credentials.

    Args:
        operation: Parsed operation; not consulted here.

    Returns:
        Headers with a JSON ``Content-Type``, a ``User-Agent`` and, when
        a credential is configured, the API-key or Bearer header.
    """
    # NOTE(review): the API key is always sent as a header; the "location"
    # setting in auth_config (query / cookie) is not consulted here.
    scheme = self.auth_type
    settings = self.auth_config
    credential_headers = {}

    if scheme == AuthType.API_KEY:
        secret = settings.get("api_key")
        header_name = settings.get("key_name", "X-API-Key")
        if secret:
            credential_headers[header_name] = secret
    elif scheme == AuthType.BEARER_TOKEN:
        bearer = settings.get("token")
        if bearer:
            credential_headers["Authorization"] = f"Bearer {bearer}"

    return {
        "Content-Type": "application/json",
        "User-Agent": "CustomTool/1.0",
        **credential_headers,
    }
|
||||
|
||||
def _build_request_data(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Collect the JSON body for a POST, PUT or PATCH operation.

    Args:
        operation: Parsed operation with ``method`` and ``request_body``.
        params: Call arguments supplying body field values.

    Returns:
        A dict holding every top-level property of the request-body
        schema that is present in ``params``. ``None`` for other methods,
        when the operation has no request body, or when no property
        matched.
    """
    if operation["method"] not in ("POST", "PUT", "PATCH"):
        return None

    body_schema = operation.get("request_body")
    if not body_schema:
        return None

    payload = {
        field: params[field]
        for field in body_schema.get("properties", {})
        if field in params
    }
    return payload or None
|
||||
|
||||
async def _send_http_request(
    self,
    method: str,
    url: str,
    headers: Dict[str, str],
    data: Optional[Dict[str, Any]] = None
) -> Any:
    """Send the HTTP request and return the decoded response.

    Args:
        method: Upper-case HTTP method.
        url: Request URL.
        headers: Request headers.
        data: JSON body; sent only for POST, PUT and PATCH.

    Returns:
        The response parsed as JSON, or its text if JSON decoding fails.

    Raises:
        Exception: If the response status is 400 or higher; the message
            carries the status and the response text.
    """
    request_kwargs = {"headers": headers}
    if data and method in ("POST", "PUT", "PATCH"):
        request_kwargs["json"] = data

    client_timeout = aiohttp.ClientTimeout(total=self.timeout)
    async with aiohttp.ClientSession(timeout=client_timeout) as session:
        async with session.request(method, url, **request_kwargs) as response:
            if response.status >= 400:
                error_text = await response.text()
                raise Exception(f"HTTP {response.status}: {error_text}")

            # Prefer JSON; fall back to the raw text for anything else.
            try:
                return await response.json()
            except Exception:
                return await response.text()
|
||||
|
||||
@classmethod
def from_url(cls, schema_url: str, auth_config: Dict[str, Any], tool_id: str = None) -> 'CustomTool':
    """Create a tool that references an OpenAPI schema by URL.

    The schema is not downloaded here; only ``schema_url`` is recorded.

    Args:
        schema_url: Location of the OpenAPI schema.
        auth_config: Authentication settings; its ``type`` entry
            (default ``"none"``) becomes the tool's ``auth_type``.
        tool_id: Tool identifier; a random UUID4 string when omitted.

    Returns:
        The new tool instance.
    """
    import uuid

    resolved_id = tool_id or str(uuid.uuid4())
    return cls(resolved_id, {
        "schema_url": schema_url,
        "auth_config": auth_config,
        "auth_type": auth_config.get("type", "none")
    })
|
||||
|
||||
@classmethod
def from_schema(cls, schema_dict: Dict[str, Any], auth_config: Dict[str, Any], tool_id: str = None) -> 'CustomTool':
    """Create a tool from an OpenAPI schema that is already loaded.

    Args:
        schema_dict: Parsed OpenAPI schema.
        auth_config: Authentication settings; its ``type`` entry
            (default ``"none"``) becomes the tool's ``auth_type``.
        tool_id: Tool identifier; a random UUID4 string when omitted.

    Returns:
        The new tool instance.
    """
    import uuid

    resolved_id = tool_id or str(uuid.uuid4())
    return cls(resolved_id, {
        "schema_content": schema_dict,
        "auth_config": auth_config,
        "auth_type": auth_config.get("type", "none")
    })
|
||||
477
api/app/core/tools/custom/schema_parser.py
Normal file
477
api/app/core/tools/custom/schema_parser.py
Normal file
@@ -0,0 +1,477 @@
|
||||
"""OpenAPI Schema解析器"""
|
||||
import json
|
||||
import yaml
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
import aiohttp
|
||||
import asyncio
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class OpenAPISchemaParser:
|
||||
"""OpenAPI Schema解析器 - 解析OpenAPI 3.0规范"""
|
||||
|
||||
def __init__(self):
    """Create the parser and record the OpenAPI versions it accepts."""
    self.supported_versions = [
        "3.0.0",
        "3.0.1",
        "3.0.2",
        "3.0.3",
        "3.1.0",
    ]
|
||||
|
||||
async def parse_from_url(self, schema_url: str, timeout: int = 30) -> Tuple[bool, Dict[str, Any], str]:
    """Download an OpenAPI schema and parse and validate it.

    Args:
        schema_url: URL of the schema; must have a scheme and a host.
        timeout: Total request timeout in seconds.

    Returns:
        ``(success, schema, error_message)``. On failure the schema is an
        empty dict; on success the message is empty.
    """
    try:
        location = urlparse(schema_url)
        if not (location.scheme and location.netloc):
            return False, {}, "无效的URL格式"

        limits = aiohttp.ClientTimeout(total=timeout)
        async with aiohttp.ClientSession(timeout=limits) as session:
            async with session.get(schema_url) as response:
                if response.status != 200:
                    return False, {}, f"HTTP错误: {response.status}"

                # The declared content type steers the JSON/YAML choice.
                declared_type = response.headers.get('content-type', '').lower()
                body = await response.text()

                parsed = self._parse_content(body, declared_type)
                if not parsed:
                    return False, {}, "无法解析schema内容"

                schema_ok, problem = self.validate_schema(parsed)
                if not schema_ok:
                    return False, {}, problem

                return True, parsed, ""

    except asyncio.TimeoutError:
        return False, {}, "请求超时"
    except Exception as exc:
        logger.error(f"从URL解析schema失败: {schema_url}, 错误: {exc}")
        return False, {}, str(exc)
|
||||
|
||||
def parse_from_content(self, content: str, content_type: str = "application/json") -> Tuple[bool, Dict[str, Any], str]:
    """Parse and validate an OpenAPI schema supplied as text.

    Args:
        content: Schema text.
        content_type: Content type passed on to ``_parse_content``.

    Returns:
        ``(success, schema, error_message)``. On failure the schema is an
        empty dict; on success the message is empty.
    """
    try:
        parsed = self._parse_content(content, content_type)
        if not parsed:
            return False, {}, "无法解析schema内容"

        schema_ok, problem = self.validate_schema(parsed)
        return (True, parsed, "") if schema_ok else (False, {}, problem)

    except Exception as exc:
        logger.error(f"解析schema内容失败: {exc}")
        return False, {}, str(exc)
|
||||
|
||||
def _parse_content(self, content: str, content_type: str) -> Optional[Dict[str, Any]]:
    """Decode schema text as JSON or YAML.

    A content type containing ``json`` selects JSON; one containing
    ``yaml`` or ``yml`` selects YAML. Any other content type tries JSON
    first and falls back to YAML.

    Args:
        content: Text to decode.
        content_type: Declared content type.

    Returns:
        The decoded value, or ``None`` if decoding fails (logged unless
        it was the YAML fallback that failed).
    """
    try:
        if 'json' in content_type:
            return json.loads(content)
        if 'yaml' in content_type or 'yml' in content_type:
            return yaml.safe_load(content)

        # Unknown content type: sniff the format.
        try:
            return json.loads(content)
        except json.JSONDecodeError:
            pass

        try:
            return yaml.safe_load(content)
        except yaml.YAMLError:
            return None

    except Exception as exc:
        logger.error(f"解析内容失败: {exc}")
        return None
|
||||
|
||||
def validate_schema(self, schema_dict: Dict[str, Any]) -> Tuple[bool, str]:
    """Check that ``schema_dict`` has the basic shape of an OpenAPI 3 document.

    Verifies the ``openapi`` version, the presence of ``info`` and
    ``paths``, that ``info`` is an object with a ``title``, and that
    ``paths`` is a non-empty object.

    The version check compares only the major and minor components with
    ``self.supported_versions``, so any patch release of a supported
    line (e.g. 3.0.4, 3.1.1) is accepted.

    Args:
        schema_dict: Parsed schema.

    Returns:
        ``(is_valid, error_message)``; the message is empty on success.
    """
    try:
        if not isinstance(schema_dict, dict):
            return False, "Schema必须是JSON对象"

        openapi_version = schema_dict.get("openapi")
        if not openapi_version:
            return False, "缺少openapi版本字段"

        # The OpenAPI specification says tooling should ignore the patch
        # version, so match on "major.minor" instead of the exact string.
        supported_lines = {tuple(version.split(".")[:2]) for version in self.supported_versions}
        version_parts = openapi_version.split(".") if isinstance(openapi_version, str) else []
        if len(version_parts) < 3 or tuple(version_parts[:2]) not in supported_lines:
            return False, f"不支持的OpenAPI版本: {openapi_version}"

        required_fields = ["info", "paths"]
        for field in required_fields:
            if field not in schema_dict:
                return False, f"缺少必需字段: {field}"

        info = schema_dict.get("info", {})
        if not isinstance(info, dict):
            return False, "info字段必须是对象"

        if "title" not in info:
            return False, "info.title字段是必需的"

        paths = schema_dict.get("paths", {})
        if not isinstance(paths, dict):
            return False, "paths字段必须是对象"

        if not paths:
            return False, "至少需要定义一个API路径"

        return True, ""

    except Exception as e:
        return False, f"验证schema时出错: {e}"
|
||||
|
||||
def extract_tool_info(self, schema_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Summarise a schema as tool metadata.

    Args:
        schema_dict: Parsed OpenAPI schema.

    Returns:
        A dict with ``name``, ``description`` and ``version`` taken from
        the schema's ``info`` object (with defaults), the ``servers``
        list and the extracted ``operations``.
    """
    meta = schema_dict.get("info", {})

    summary = {
        "name": meta.get("title", "Custom API Tool"),
        "description": meta.get("description", ""),
        "version": meta.get("version", "1.0.0"),
    }
    summary["servers"] = schema_dict.get("servers", [])
    summary["operations"] = self._extract_operations(schema_dict)
    return summary
|
||||
|
||||
def _extract_operations(self, schema_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Collect every HTTP operation declared under ``paths``.

    Path items and operations that are not objects are skipped, as are
    path-item keys that are not HTTP methods.

    Args:
        schema_dict: Parsed OpenAPI schema.

    Returns:
        A dict keyed by operation id. When ``operationId`` is missing,
        the id is the lower-case method, an underscore and the path with
        ``/`` turned into ``_`` and braces removed. Each entry holds
        ``method``, ``path``, ``summary``, ``description``,
        ``parameters``, ``request_body``, ``responses`` and ``tags``.
    """
    http_verbs = ("get", "post", "put", "delete", "patch", "head", "options")
    collected = {}

    for route, route_item in schema_dict.get("paths", {}).items():
        if not isinstance(route_item, dict):
            continue

        for verb, spec in route_item.items():
            if verb.lower() in http_verbs and isinstance(spec, dict):
                key = spec.get("operationId")
                if not key:
                    # Derive an id from the method and path.
                    key = f"{verb.lower()}_{route.replace('/', '_').replace('{', '').replace('}', '')}"

                collected[key] = {
                    "method": verb.upper(),
                    "path": route,
                    "summary": spec.get("summary", ""),
                    "description": spec.get("description", ""),
                    "parameters": self._extract_parameters(spec),
                    "request_body": self._extract_request_body(spec),
                    "responses": self._extract_responses(spec),
                    "tags": spec.get("tags", [])
                }

    return collected
|
||||
|
||||
def _extract_parameters(self, operation: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""提取操作参数
|
||||
|
||||
Args:
|
||||
operation: 操作定义
|
||||
|
||||
Returns:
|
||||
参数信息字典
|
||||
"""
|
||||
parameters = {}
|
||||
|
||||
for param in operation.get("parameters", []):
|
||||
if not isinstance(param, dict):
|
||||
continue
|
||||
|
||||
param_name = param.get("name")
|
||||
if not param_name:
|
||||
continue
|
||||
|
||||
param_schema = param.get("schema", {})
|
||||
|
||||
parameters[param_name] = {
|
||||
"name": param_name,
|
||||
"in": param.get("in", "query"),
|
||||
"description": param.get("description", ""),
|
||||
"required": param.get("required", False),
|
||||
"type": param_schema.get("type", "string"),
|
||||
"format": param_schema.get("format"),
|
||||
"enum": param_schema.get("enum"),
|
||||
"default": param_schema.get("default"),
|
||||
"minimum": param_schema.get("minimum"),
|
||||
"maximum": param_schema.get("maximum"),
|
||||
"pattern": param_schema.get("pattern"),
|
||||
"example": param.get("example") or param_schema.get("example")
|
||||
}
|
||||
|
||||
return parameters
|
||||
|
||||
def _extract_request_body(self, operation: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""提取请求体信息
|
||||
|
||||
Args:
|
||||
operation: 操作定义
|
||||
|
||||
Returns:
|
||||
请求体信息,如果没有返回None
|
||||
"""
|
||||
request_body = operation.get("requestBody")
|
||||
if not request_body:
|
||||
return None
|
||||
|
||||
content = request_body.get("content", {})
|
||||
|
||||
# 优先使用application/json
|
||||
if "application/json" in content:
|
||||
schema = content["application/json"].get("schema", {})
|
||||
elif content:
|
||||
# 使用第一个可用的内容类型
|
||||
first_content_type = next(iter(content.keys()))
|
||||
schema = content[first_content_type].get("schema", {})
|
||||
else:
|
||||
return None
|
||||
|
||||
return {
|
||||
"description": request_body.get("description", ""),
|
||||
"required": request_body.get("required", False),
|
||||
"schema": schema,
|
||||
"content_types": list(content.keys())
|
||||
}
|
||||
|
||||
def _extract_responses(self, operation: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""提取响应信息
|
||||
|
||||
Args:
|
||||
operation: 操作定义
|
||||
|
||||
Returns:
|
||||
响应信息字典
|
||||
"""
|
||||
responses = {}
|
||||
|
||||
for status_code, response in operation.get("responses", {}).items():
|
||||
if not isinstance(response, dict):
|
||||
continue
|
||||
|
||||
content = response.get("content", {})
|
||||
schema = None
|
||||
|
||||
# 尝试获取响应schema
|
||||
if "application/json" in content:
|
||||
schema = content["application/json"].get("schema")
|
||||
elif content:
|
||||
first_content_type = next(iter(content.keys()))
|
||||
schema = content[first_content_type].get("schema")
|
||||
|
||||
responses[status_code] = {
|
||||
"description": response.get("description", ""),
|
||||
"schema": schema,
|
||||
"content_types": list(content.keys()) if content else []
|
||||
}
|
||||
|
||||
return responses
|
||||
|
||||
def generate_tool_parameters(self, operations: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Flatten the parameters of all operations into one definition list.

    Args:
        operations: Mapping of operation id to operation info.

    Returns:
        List of parameter definitions. When more than one operation
        exists, the first entry is an ``operation`` selector whose enum
        lists the operation ids. The rest are the path/query parameters
        and request-body properties of every operation, de-duplicated by
        name (the first definition seen wins).
    """
    def describe(name: str, spec: Dict[str, Any], is_required: Any) -> Dict[str, Any]:
        # Shared shape for path/query parameters and request-body properties.
        return {
            "name": name,
            "type": spec.get("type", "string"),
            "description": spec.get("description", ""),
            "required": is_required,
            "enum": spec.get("enum"),
            "default": spec.get("default"),
            "minimum": spec.get("minimum"),
            "maximum": spec.get("maximum"),
            "pattern": spec.get("pattern")
        }

    merged: Dict[str, Dict[str, Any]] = {}

    for operation in operations.values():
        # Path and query parameters.
        for param_name, param_info in operation.get("parameters", {}).items():
            if param_name in merged:
                continue
            merged[param_name] = describe(
                param_name, param_info, param_info.get("required", False)
            )

        # Request-body properties.
        body = operation.get("request_body")
        if not body:
            continue
        schema = body.get("schema", {})
        for prop_name, prop_schema in schema.get("properties", {}).items():
            if prop_name in merged:
                continue
            merged[prop_name] = describe(
                prop_name, prop_schema, prop_name in schema.get("required", [])
            )

    result: List[Dict[str, Any]] = []

    # With several operations the caller has to say which one to run.
    if len(operations) > 1:
        result.append({
            "name": "operation",
            "type": "string",
            "description": "要执行的操作",
            "required": True,
            "enum": list(operations.keys())
        })

    result.extend(merged.values())
    return result
|
||||
|
||||
def validate_operation_parameters(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Tuple[bool, List[str]]:
    """Check caller-supplied values against an operation's declarations.

    Args:
        operation: Operation info with ``parameters`` and, optionally,
            ``request_body`` entries.
        params: Input values keyed by parameter name.

    Returns:
        ``(is_valid, errors)`` where ``errors`` lists every problem found:
        missing required parameters, type mismatches and values outside a
        declared enum.
    """
    problems: List[str] = []

    # Path and query parameters.
    for param_name, param_info in operation.get("parameters", {}).items():
        if param_info.get("required", False) and param_name not in params:
            problems.append(f"缺少必需参数: {param_name}")

        if param_name not in params:
            continue

        value = params[param_name]
        param_type = param_info.get("type", "string")

        if not self._validate_parameter_type(value, param_type):
            problems.append(f"参数 {param_name} 类型错误,期望: {param_type}")

        allowed = param_info.get("enum")
        if allowed and value not in allowed:
            problems.append(f"参数 {param_name} 值无效,必须是: {allowed}")

    # Request-body properties.
    body = operation.get("request_body")
    if body:
        schema = body.get("schema", {})
        required_props = schema.get("required", [])
        properties = schema.get("properties", {})

        problems.extend(
            f"缺少必需的请求体参数: {prop_name}"
            for prop_name in required_props
            if prop_name not in params
        )

        for prop_name, value in params.items():
            if prop_name not in properties:
                continue
            prop_type = properties[prop_name].get("type", "string")
            if not self._validate_parameter_type(value, prop_type):
                problems.append(f"请求体参数 {prop_name} 类型错误,期望: {prop_type}")

    return not problems, problems
|
||||
|
||||
def _validate_parameter_type(self, value: Any, expected_type: str) -> bool:
|
||||
"""验证参数类型
|
||||
|
||||
Args:
|
||||
value: 参数值
|
||||
expected_type: 期望类型
|
||||
|
||||
Returns:
|
||||
是否类型匹配
|
||||
"""
|
||||
if value is None:
|
||||
return True
|
||||
|
||||
type_mapping = {
|
||||
"string": str,
|
||||
"integer": int,
|
||||
"number": (int, float),
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"object": dict
|
||||
}
|
||||
|
||||
expected_python_type = type_mapping.get(expected_type)
|
||||
if expected_python_type:
|
||||
return isinstance(value, expected_python_type)
|
||||
|
||||
return True
|
||||
501
api/app/core/tools/executor.py
Normal file
501
api/app/core/tools/executor.py
Normal file
@@ -0,0 +1,501 @@
|
||||
"""工具执行器 - 负责工具的实际调用和执行管理"""
|
||||
import asyncio
|
||||
import uuid
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.tool_model import ToolExecution, ExecutionStatus
|
||||
from app.core.tools.base import BaseTool, ToolResult
|
||||
from app.core.tools.registry import ToolRegistry
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
# Module-level logger obtained from the project's logging configuration.
logger = get_business_logger()
|
||||
|
||||
|
||||
class ExecutionContext:
    """Bookkeeping record for a single tool invocation.

    Holds the identifiers, timeout, metadata, timestamps and status of one
    execution. A new context starts in ``ExecutionStatus.PENDING`` with
    ``started_at`` set to the moment of construction.
    """

    def __init__(
        self,
        execution_id: str,
        tool_id: str,
        user_id: Optional[uuid.UUID] = None,
        workspace_id: Optional[uuid.UUID] = None,
        timeout: Optional[float] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Create the context for one execution.

        Args:
            execution_id: Identifier of this execution.
            tool_id: Identifier of the tool being run.
            user_id: User on whose behalf the tool runs, if any.
            workspace_id: Workspace the run belongs to, if any.
            timeout: Timeout in seconds; a falsy value selects 60 seconds.
            metadata: Extra caller-supplied data; a falsy value selects an
                empty dictionary.
        """
        # Who and what is being executed.
        self.execution_id = execution_id
        self.tool_id = tool_id
        self.user_id = user_id
        self.workspace_id = workspace_id

        # Execution settings, with defaults for falsy inputs.
        self.timeout = timeout if timeout else 60.0
        self.metadata = metadata if metadata else {}

        # Lifecycle tracking.
        self.started_at = datetime.now()
        self.completed_at: Optional[datetime] = None
        self.status = ExecutionStatus.PENDING
|
||||
|
||||
|
||||
class ToolExecutor:
    """Looks tools up in the registry, runs them with a timeout, and records
    every run as a ``ToolExecution`` row in the database."""

    def __init__(self, db: Session, registry: ToolRegistry):
        """Initialize the executor.

        Args:
            db: Database session used for execution records.
            registry: Registry used to resolve tool ids to tool instances.
        """
        self.db = db
        self.registry = registry
        # execution_id -> context of the runs currently in progress.
        self._running_executions: Dict[str, ExecutionContext] = {}
        self._execution_lock = asyncio.Lock()

    async def execute_tool(
        self,
        tool_id: str,
        parameters: Dict[str, Any],
        user_id: Optional[uuid.UUID] = None,
        workspace_id: Optional[uuid.UUID] = None,
        execution_id: Optional[str] = None,
        timeout: Optional[float] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> ToolResult:
        """Execute one tool and record the outcome.

        Args:
            tool_id: Tool id.
            parameters: Keyword arguments passed to the tool.
            user_id: User id.
            workspace_id: Workspace id.
            execution_id: Execution id; generated when omitted.
            timeout: Timeout in seconds.
            metadata: Extra metadata stored on the execution context.

        Returns:
            The tool's result. Failures are reported as an error result
            (``TOOL_NOT_FOUND``, ``EXECUTION_TIMEOUT`` or
            ``EXECUTION_ERROR``) rather than raised.
        """
        if not execution_id:
            execution_id = f"exec_{uuid.uuid4().hex[:16]}"

        context = ExecutionContext(
            execution_id=execution_id,
            tool_id=tool_id,
            user_id=user_id,
            workspace_id=workspace_id,
            timeout=timeout,
            metadata=metadata
        )

        try:
            tool = self.registry.get_tool(tool_id)
            if not tool:
                return ToolResult.error_result(
                    error=f"工具不存在: {tool_id}",
                    error_code="TOOL_NOT_FOUND",
                    execution_time=0.0
                )

            await self._record_execution_start(context, parameters)

            result = await self._execute_with_timeout(tool, parameters, context)

            await self._record_execution_complete(context, result)

            return result

        except Exception as e:
            logger.error(f"工具执行异常: {execution_id}, 错误: {e}")

            error_result = ToolResult.error_result(
                error=str(e),
                error_code="EXECUTION_ERROR",
                execution_time=time.time() - context.started_at.timestamp()
            )
            await self._record_execution_complete(context, error_result)

            return error_result

        finally:
            # Always drop the context from the in-progress table.
            async with self._execution_lock:
                if execution_id in self._running_executions:
                    del self._running_executions[execution_id]

    async def execute_tools_batch(
        self,
        tool_executions: List[Dict[str, Any]],
        max_concurrency: int = 5
    ) -> List[ToolResult]:
        """Execute several tools concurrently.

        Args:
            tool_executions: Execution configs; each needs ``tool_id`` and
                may carry ``parameters``, ``user_id``, ``workspace_id``,
                ``timeout`` and ``metadata``.
            max_concurrency: Maximum number of tools running at once.

        Returns:
            One result per config, in input order. A config that raised is
            reported as a ``BATCH_EXECUTION_ERROR`` error result.
        """
        semaphore = asyncio.Semaphore(max_concurrency)

        async def execute_single(exec_config: Dict[str, Any]) -> ToolResult:
            async with semaphore:
                return await self.execute_tool(
                    tool_id=exec_config["tool_id"],
                    parameters=exec_config.get("parameters", {}),
                    user_id=exec_config.get("user_id"),
                    workspace_id=exec_config.get("workspace_id"),
                    timeout=exec_config.get("timeout"),
                    metadata=exec_config.get("metadata")
                )

        tasks = [execute_single(config) for config in tool_executions]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Convert raised exceptions into error results so callers always
        # receive a ToolResult per input.
        processed_results = []
        for result in results:
            if isinstance(result, Exception):
                processed_results.append(
                    ToolResult.error_result(
                        error=str(result),
                        error_code="BATCH_EXECUTION_ERROR",
                        execution_time=0.0
                    )
                )
            else:
                processed_results.append(result)

        return processed_results

    async def cancel_execution(self, execution_id: str) -> bool:
        """Mark a running execution as cancelled.

        This only updates the in-memory context and the database record;
        it does not interrupt the coroutine that is running the tool.

        Args:
            execution_id: Execution id.

        Returns:
            ``True`` when the execution was found and marked, ``False``
            when it is not currently running.
        """
        async with self._execution_lock:
            if execution_id not in self._running_executions:
                return False

            context = self._running_executions[execution_id]
            context.status = ExecutionStatus.FAILED

            try:
                execution_record = self.db.query(ToolExecution).filter(
                    ToolExecution.execution_id == execution_id
                ).first()

                if execution_record:
                    execution_record.status = ExecutionStatus.FAILED.value
                    execution_record.error_message = "执行被取消"
                    execution_record.completed_at = datetime.now()
                    self.db.commit()
            except Exception:
                # Leave the shared session usable before propagating.
                self._rollback_session()
                raise

            logger.info(f"工具执行已取消: {execution_id}")
            return True

    def get_running_executions(self) -> List[Dict[str, Any]]:
        """List the executions currently in progress.

        Returns:
            One dictionary per running execution with its ids, start time,
            status and elapsed seconds.
        """
        return [
            {
                "execution_id": execution_id,
                "tool_id": context.tool_id,
                "user_id": str(context.user_id) if context.user_id else None,
                "workspace_id": str(context.workspace_id) if context.workspace_id else None,
                "started_at": context.started_at.isoformat(),
                "status": context.status.value,
                "elapsed_time": (datetime.now() - context.started_at).total_seconds()
            }
            for execution_id, context in self._running_executions.items()
        ]

    async def _execute_with_timeout(
        self,
        tool: BaseTool,
        parameters: Dict[str, Any],
        context: ExecutionContext
    ) -> ToolResult:
        """Run a tool, bounded by the context's timeout.

        Args:
            tool: Tool instance.
            parameters: Keyword arguments passed to the tool.
            context: Execution context; its status is updated as the run
                progresses.

        Returns:
            The tool's result, or an ``EXECUTION_TIMEOUT`` error result
            when the timeout elapses. Other exceptions propagate.
        """
        async with self._execution_lock:
            self._running_executions[context.execution_id] = context
            context.status = ExecutionStatus.RUNNING

        try:
            result = await asyncio.wait_for(
                tool.safe_execute(**parameters),
                timeout=context.timeout
            )

            context.status = ExecutionStatus.COMPLETED
            return result

        except asyncio.TimeoutError:
            context.status = ExecutionStatus.TIMEOUT
            return ToolResult.error_result(
                error=f"工具执行超时({context.timeout}秒)",
                error_code="EXECUTION_TIMEOUT",
                execution_time=context.timeout
            )

        except Exception:
            context.status = ExecutionStatus.FAILED
            raise

    def _rollback_session(self) -> None:
        """Roll back the session after a failed database operation.

        Without a rollback the session stays in a failed transaction and
        later operations on it keep failing.
        """
        try:
            self.db.rollback()
        except Exception as rollback_error:
            logger.error(f"回滚数据库会话失败, 错误: {rollback_error}")

    async def _record_execution_start(
        self,
        context: ExecutionContext,
        parameters: Dict[str, Any]
    ):
        """Insert the execution record for a run that is about to start.

        Best effort: failures are logged and never raised.
        """
        # Build the row first so that a malformed tool id (not a UUID) does
        # not touch the session at all.
        try:
            execution_record = ToolExecution(
                execution_id=context.execution_id,
                tool_config_id=uuid.UUID(context.tool_id),
                status=ExecutionStatus.RUNNING.value,
                input_data=parameters,
                started_at=context.started_at,
                user_id=context.user_id,
                workspace_id=context.workspace_id
            )
        except Exception as e:
            logger.error(f"创建执行记录失败: {context.execution_id}, 错误: {e}")
            return

        try:
            self.db.add(execution_record)
            self.db.commit()
        except Exception as e:
            self._rollback_session()
            logger.error(f"创建执行记录失败: {context.execution_id}, 错误: {e}")
            return

        logger.debug(f"执行记录已创建: {context.execution_id}")

    async def _record_execution_complete(
        self,
        context: ExecutionContext,
        result: ToolResult
    ):
        """Update the execution record with the outcome of a run.

        Best effort: failures are logged and never raised. The stored
        status is COMPLETED for a successful result and FAILED otherwise.
        """
        try:
            context.completed_at = datetime.now()

            execution_record = self.db.query(ToolExecution).filter(
                ToolExecution.execution_id == context.execution_id
            ).first()

            if execution_record:
                execution_record.status = (
                    ExecutionStatus.COMPLETED.value if result.success
                    else ExecutionStatus.FAILED.value
                )
                execution_record.output_data = result.data if result.success else None
                execution_record.error_message = result.error if not result.success else None
                execution_record.completed_at = context.completed_at
                execution_record.execution_time = result.execution_time
                execution_record.token_usage = result.token_usage

                self.db.commit()

                logger.debug(f"执行记录已更新: {context.execution_id}")

        except Exception as e:
            self._rollback_session()
            logger.error(f"更新执行记录失败: {context.execution_id}, 错误: {e}")

    def get_execution_history(
        self,
        tool_id: Optional[str] = None,
        user_id: Optional[uuid.UUID] = None,
        workspace_id: Optional[uuid.UUID] = None,
        limit: int = 100
    ) -> List[Dict[str, Any]]:
        """Return recorded executions, newest first.

        Args:
            tool_id: Only include executions of this tool.
            user_id: Only include executions by this user.
            workspace_id: Only include executions in this workspace.
            limit: Maximum number of records returned.

        Returns:
            List of execution dictionaries; empty when the query fails.
        """
        try:
            query = self.db.query(ToolExecution).order_by(
                ToolExecution.started_at.desc()
            )

            if tool_id:
                query = query.filter(ToolExecution.tool_config_id == uuid.UUID(tool_id))

            if user_id:
                query = query.filter(ToolExecution.user_id == user_id)

            if workspace_id:
                query = query.filter(ToolExecution.workspace_id == workspace_id)

            executions = query.limit(limit).all()

            return [
                {
                    "execution_id": execution.execution_id,
                    "tool_id": str(execution.tool_config_id),
                    "status": execution.status,
                    "started_at": execution.started_at.isoformat() if execution.started_at else None,
                    "completed_at": execution.completed_at.isoformat() if execution.completed_at else None,
                    "execution_time": execution.execution_time,
                    "user_id": str(execution.user_id) if execution.user_id else None,
                    "workspace_id": str(execution.workspace_id) if execution.workspace_id else None,
                    "input_data": execution.input_data,
                    "output_data": execution.output_data,
                    "error_message": execution.error_message,
                    "token_usage": execution.token_usage
                }
                for execution in executions
            ]

        except Exception as e:
            logger.error(f"获取执行历史失败, 错误: {e}")
            return []

    def get_execution_statistics(
        self,
        workspace_id: Optional[uuid.UUID] = None,
        days: int = 7
    ) -> Dict[str, Any]:
        """Aggregate execution statistics over a recent period.

        Args:
            workspace_id: Only include executions in this workspace.
            days: Number of days to look back from now.

        Returns:
            Dictionary with totals, success rate, average execution time
            and per-tool counts; empty when the query fails.
        """
        try:
            from datetime import timedelta

            start_date = datetime.now() - timedelta(days=days)

            query = self.db.query(ToolExecution).filter(
                ToolExecution.started_at >= start_date
            )

            if workspace_id:
                query = query.filter(ToolExecution.workspace_id == workspace_id)

            executions = query.all()

            completed_value = ExecutionStatus.COMPLETED.value
            failed_value = ExecutionStatus.FAILED.value

            total_executions = len(executions)
            successful_executions = sum(1 for e in executions if e.status == completed_value)
            failed_executions = sum(1 for e in executions if e.status == failed_value)

            # Average over every record that has a measured execution time.
            timed_executions = [e for e in executions if e.execution_time is not None]
            avg_execution_time = (
                sum(e.execution_time for e in timed_executions) / len(timed_executions)
                if timed_executions else 0
            )

            # Per-tool breakdown.
            tool_stats = {}
            for execution in executions:
                tool_id = str(execution.tool_config_id)
                if tool_id not in tool_stats:
                    tool_stats[tool_id] = {"total": 0, "successful": 0, "failed": 0}

                tool_stats[tool_id]["total"] += 1
                if execution.status == completed_value:
                    tool_stats[tool_id]["successful"] += 1
                elif execution.status == failed_value:
                    tool_stats[tool_id]["failed"] += 1

            return {
                "period_days": days,
                "total_executions": total_executions,
                "successful_executions": successful_executions,
                "failed_executions": failed_executions,
                "success_rate": successful_executions / total_executions if total_executions > 0 else 0,
                "average_execution_time": avg_execution_time,
                "tool_statistics": tool_stats
            }

        except Exception as e:
            logger.error(f"获取执行统计失败, 错误: {e}")
            return {}

    async def test_tool_connection(
        self,
        tool_id: str,
        user_id: Optional[uuid.UUID] = None,
        workspace_id: Optional[uuid.UUID] = None
    ) -> Dict[str, Any]:
        """Test whether a tool can be reached.

        MCP tools are tested by connecting to their server and listing its
        tools. Other tools are tested through their own ``test_connection``
        method when they have one.

        Args:
            tool_id: Tool id (a UUID string).
            user_id: Not used by this method.
            workspace_id: Not used by this method.

        Returns:
            Dictionary with ``success`` and ``message``, plus ``details``
            or ``error`` where applicable.
        """
        try:
            from app.models.tool_model import ToolConfig, ToolType, MCPToolConfig
            from .mcp.client import MCPClient

            tool_config = self.db.query(ToolConfig).filter(
                ToolConfig.id == uuid.UUID(tool_id)
            ).first()

            if not tool_config:
                return {"success": False, "message": "工具不存在"}

            if tool_config.tool_type == ToolType.MCP.value:
                mcp_config = self.db.query(MCPToolConfig).filter(
                    MCPToolConfig.id == tool_config.id
                ).first()

                if not mcp_config:
                    return {"success": False, "message": "MCP配置不存在"}

                client = MCPClient(mcp_config.server_url, mcp_config.connection_config or {})

                if not await client.connect():
                    return {"success": False, "message": "MCP连接失败"}

                try:
                    tools = await client.list_tools()
                    tool_count = len(tools)
                except Exception as e:
                    # Catch Exception rather than everything so that task
                    # cancellation is not swallowed, and keep the cause.
                    logger.error(f"MCP功能测试失败: {tool_id}, 错误: {e}")
                    return {"success": False, "message": "MCP功能测试失败"}
                finally:
                    # Disconnect on success, failure and cancellation.
                    await client.disconnect()

                return {
                    "success": True,
                    "message": "MCP连接成功",
                    "details": {"server_url": mcp_config.server_url, "tools": tool_count}
                }
            else:
                tool = self.registry.get_tool(tool_id)
                if tool and hasattr(tool, 'test_connection'):
                    result = tool.test_connection()
                    return {"success": result.get("success", False), "message": result.get("message", "")}
                return {"success": True, "message": "工具无需连接测试"}

        except Exception as e:
            return {"success": False, "message": "测试失败", "error": str(e)}
|
||||
375
api/app/core/tools/langchain_adapter.py
Normal file
375
api/app/core/tools/langchain_adapter.py
Normal file
@@ -0,0 +1,375 @@
|
||||
"""Langchain适配器 - 将工具转换为langchain兼容格式"""
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.tools import BaseTool as LangchainBaseTool
|
||||
from langchain_core.tools import ToolException
|
||||
|
||||
from app.core.tools.base import BaseTool, ToolResult, ToolParameter, ParameterType
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
# Module-level logger obtained from the project's logging configuration.
logger = get_business_logger()
|
||||
|
||||
|
||||
class LangchainToolWrapper(LangchainBaseTool):
    """Exposes an internal ``BaseTool`` through the Langchain tool interface.

    The wrapper takes its name, description and argument schema from the
    wrapped tool and forwards asynchronous calls to its ``safe_execute``.
    """

    name: str = Field(..., description="工具名称")
    description: str = Field(..., description="工具描述")
    args_schema: Optional[Type[BaseModel]] = Field(None, description="参数schema")
    return_direct: bool = Field(False, description="是否直接返回结果")

    # Wrapped internal tool that performs the actual work.
    tool_instance: BaseTool = Field(..., description="内部工具实例")

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, tool_instance: BaseTool, **kwargs):
        """Wrap an internal tool.

        Args:
            tool_instance: Internal tool instance to expose.
            **kwargs: Extra field values forwarded to the Langchain base
                class.
        """
        # The argument schema is generated from the tool's parameter list.
        args_schema = LangchainAdapter._create_pydantic_schema(tool_instance.parameters)

        # The keyword must match the declared ``tool_instance`` field;
        # otherwise the required field is reported as missing.
        super().__init__(
            name=tool_instance.name,
            description=tool_instance.description,
            args_schema=args_schema,
            tool_instance=tool_instance,
            **kwargs
        )

    def _run(
        self,
        run_manager=None,
        **kwargs: Any,
    ) -> str:
        """Synchronous entry point required by Langchain; not supported.

        Raises:
            NotImplementedError: Always, because the wrapped tools are
                asynchronous.
        """
        raise NotImplementedError("请使用 _arun 方法进行异步调用")

    async def _arun(
        self,
        run_manager=None,
        **kwargs: Any,
    ) -> str:
        """Run the wrapped tool asynchronously.

        Args:
            run_manager: Accepted for interface compatibility; not used.
            **kwargs: Arguments forwarded to the wrapped tool.

        Returns:
            The tool result formatted as a string.

        Raises:
            ToolException: When execution or result formatting fails.
        """
        try:
            result = await self.tool_instance.safe_execute(**kwargs)

            return LangchainAdapter._format_result_for_langchain(result)

        except Exception as e:
            logger.error(f"工具执行失败: {self.name}, 错误: {e}")
            raise ToolException(f"工具执行失败: {str(e)}") from e
|
||||
|
||||
|
||||
class LangchainAdapter:
    """Converts internal tools, parameters and results into the formats
    Langchain works with."""

    @staticmethod
    def convert_tool(tool: BaseTool) -> LangchainToolWrapper:
        """Wrap one internal tool as a Langchain tool.

        Args:
            tool: Internal tool instance.

        Returns:
            Langchain-compatible wrapper around the tool.

        Raises:
            Exception: Whatever the wrapper construction raised; it is
                logged and re-raised.
        """
        try:
            wrapper = LangchainToolWrapper(tool_instance=tool)
            logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
            return wrapper

        except Exception as e:
            logger.error(f"工具转换失败: {tool.name}, 错误: {e}")
            raise

    @staticmethod
    def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]:
        """Wrap several tools, skipping the ones that fail to convert.

        Args:
            tools: Internal tool instances.

        Returns:
            Wrappers for the tools that converted successfully.
        """
        converted_tools = []

        for tool in tools:
            try:
                converted_tool = LangchainAdapter.convert_tool(tool)
                converted_tools.append(converted_tool)
            except Exception as e:
                logger.error(f"跳过工具转换: {tool.name}, 错误: {e}")

        logger.info(f"批量转换完成: {len(converted_tools)}/{len(tools)} 个工具")
        return converted_tools

    @staticmethod
    def _create_pydantic_schema(parameters: List[ToolParameter]) -> Type[BaseModel]:
        """Create a Pydantic model class describing a tool's arguments.

        Args:
            parameters: Tool parameter definitions.

        Returns:
            A dynamically created ``BaseModel`` subclass with one field per
            parameter. Unknown fields are forbidden.
        """
        from typing import Literal
        from pydantic import VERSION as pydantic_version

        # Pydantic 2 renamed the ``regex`` Field constraint to ``pattern``
        # and rejects the old keyword.
        pattern_kwarg = "regex" if pydantic_version.startswith("1.") else "pattern"

        fields = {}
        annotations = {}

        for param in parameters:
            python_type = LangchainAdapter._get_python_type(param.type)

            field_kwargs = {
                "description": param.description
            }

            if param.default is not None:
                field_kwargs["default"] = param.default
            elif not param.required:
                field_kwargs["default"] = None
            else:
                field_kwargs["default"] = ...  # required field

            if param.enum:
                # Express the allowed values as a Literal type. This works
                # for any value type and needs no regex escaping, unlike a
                # pattern built by joining the values.
                python_type = Literal[tuple(param.enum)]
            else:
                # Range constraints only apply to numeric fields.
                if python_type in (int, float):
                    if param.minimum is not None:
                        field_kwargs["ge"] = param.minimum
                    if param.maximum is not None:
                        field_kwargs["le"] = param.maximum

                # Pattern constraints only apply to string fields.
                if param.pattern and python_type is str:
                    field_kwargs[pattern_kwarg] = param.pattern

            if not param.required:
                python_type = Optional[python_type]

            fields[param.name] = Field(**field_kwargs)
            annotations[param.name] = python_type

        schema_class = type(
            "ToolArgsSchema",
            (BaseModel,),
            {
                "__annotations__": annotations,
                **fields,
                "Config": type("Config", (), {"extra": "forbid"})
            }
        )

        return schema_class

    @staticmethod
    def _get_python_type(param_type: ParameterType) -> type:
        """Map a parameter type to the Python type used in the schema.

        Args:
            param_type: Parameter type.

        Returns:
            The matching Python type; ``str`` for unmapped types.
        """
        type_mapping = {
            ParameterType.STRING: str,
            ParameterType.INTEGER: int,
            ParameterType.NUMBER: float,
            ParameterType.BOOLEAN: bool,
            ParameterType.ARRAY: list,
            ParameterType.OBJECT: dict
        }

        return type_mapping.get(param_type, str)

    @staticmethod
    def _format_result_for_langchain(result: ToolResult) -> str:
        """Render a tool result as the string returned to Langchain.

        Args:
            result: Tool execution result.

        Returns:
            For a failure, a JSON object with ``success``, ``error``,
            ``error_code`` and ``execution_time``. For a success, the data
            itself when it is a string, JSON when it is a dict or list,
            and ``str(data)`` otherwise.
        """
        if not result.success:
            error_info = {
                "success": False,
                "error": result.error,
                "error_code": result.error_code,
                "execution_time": result.execution_time
            }
            return json.dumps(error_info, ensure_ascii=False, indent=2)

        if isinstance(result.data, str):
            return result.data
        elif isinstance(result.data, (dict, list)):
            return json.dumps(result.data, ensure_ascii=False, indent=2)
        else:
            return str(result.data)

    @staticmethod
    def create_tool_description(tool: BaseTool) -> Dict[str, Any]:
        """Describe a tool as a plain dictionary.

        Args:
            tool: Tool instance.

        Returns:
            Dictionary with the tool's name, description, type, version,
            status, tags and parameter definitions.
        """
        return {
            "name": tool.name,
            "description": tool.description,
            "tool_type": tool.tool_type.value,
            "version": tool.version,
            "status": tool.status.value,
            "tags": tool.tags,
            "parameters": [
                {
                    "name": param.name,
                    "type": param.type.value,
                    "description": param.description,
                    "required": param.required,
                    "default": param.default,
                    "enum": param.enum,
                    "minimum": param.minimum,
                    "maximum": param.maximum,
                    "pattern": param.pattern
                }
                for param in tool.parameters
            ],
            "langchain_compatible": True
        }

    @staticmethod
    def validate_langchain_compatibility(tool: BaseTool) -> tuple[bool, List[str]]:
        """Check whether a tool can be wrapped for Langchain.

        Args:
            tool: Tool instance.

        Returns:
            ``(is_compatible, issues)`` where ``issues`` lists every
            problem found.
        """
        issues = []

        if not tool.name or not isinstance(tool.name, str):
            issues.append("工具名称必须是非空字符串")

        if not tool.description or not isinstance(tool.description, str):
            issues.append("工具描述必须是非空字符串")

        for param in tool.parameters:
            if not param.name or not isinstance(param.name, str):
                issues.append(f"参数名称无效: {param.name}")

            # isinstance instead of ``in ParameterType``: the containment
            # test raises TypeError for non-members before Python 3.12.
            if not isinstance(param.type, ParameterType):
                issues.append(f"不支持的参数类型: {param.type}")

            if param.required and param.default is not None:
                issues.append(f"必需参数不应有默认值: {param.name}")

        if not hasattr(tool, 'execute') or not callable(getattr(tool, 'execute')):
            issues.append("工具必须实现execute方法")

        return len(issues) == 0, issues

    @staticmethod
    def get_langchain_tool_schema(tool: BaseTool) -> Dict[str, Any]:
        """Build a function-style schema dictionary for a tool.

        Args:
            tool: Tool instance.

        Returns:
            Dictionary of the form ``{"type": "function", "function":
            {...}}`` whose ``parameters`` entry is an object schema built
            from the tool's parameters.
        """
        properties = {}
        required = []

        for param in tool.parameters:
            prop_schema = {
                "type": LangchainAdapter._get_openapi_type(param.type),
                "description": param.description
            }

            if param.enum:
                prop_schema["enum"] = param.enum

            if param.minimum is not None:
                prop_schema["minimum"] = param.minimum

            if param.maximum is not None:
                prop_schema["maximum"] = param.maximum

            if param.pattern:
                prop_schema["pattern"] = param.pattern

            if param.default is not None:
                prop_schema["default"] = param.default

            properties[param.name] = prop_schema

            if param.required:
                required.append(param.name)

        return {
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required
                }
            }
        }

    @staticmethod
    def _get_openapi_type(param_type: ParameterType) -> str:
        """Map a parameter type to its schema type name.

        Args:
            param_type: Parameter type.

        Returns:
            Schema type string; ``"string"`` for unmapped types.
        """
        type_mapping = {
            ParameterType.STRING: "string",
            ParameterType.INTEGER: "integer",
            ParameterType.NUMBER: "number",
            ParameterType.BOOLEAN: "boolean",
            ParameterType.ARRAY: "array",
            ParameterType.OBJECT: "object"
        }

        return type_mapping.get(param_type, "string")
|
||||
12
api/app/core/tools/mcp/__init__.py
Normal file
12
api/app/core/tools/mcp/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""MCP工具模块"""
|
||||
|
||||
from .base import MCPTool
|
||||
from .client import MCPClient, MCPConnectionPool
|
||||
from .service_manager import MCPServiceManager
|
||||
|
||||
__all__ = [
|
||||
"MCPTool",
|
||||
"MCPClient",
|
||||
"MCPConnectionPool",
|
||||
"MCPServiceManager"
|
||||
]
|
||||
258
api/app/core/tools/mcp/base.py
Normal file
258
api/app/core/tools/mcp/base.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""MCP工具基类"""
|
||||
import time
|
||||
from typing import Dict, Any, List
|
||||
import aiohttp
|
||||
|
||||
from app.models.tool_model import ToolType
|
||||
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class MCPTool(BaseTool):
|
||||
"""MCP工具 - Model Context Protocol工具"""
|
||||
|
||||
def __init__(self, tool_id: str, config: Dict[str, Any]):
|
||||
"""初始化MCP工具
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
config: 工具配置
|
||||
"""
|
||||
super().__init__(tool_id, config)
|
||||
self.server_url = config.get("server_url", "")
|
||||
self.connection_config = config.get("connection_config", {})
|
||||
self.available_tools = config.get("available_tools", [])
|
||||
self._client = None
|
||||
self._connected = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""工具名称"""
|
||||
return f"mcp_tool_{self.tool_id[:8]}"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""工具描述"""
|
||||
return f"MCP工具 - 连接到 {self.server_url}"
|
||||
|
||||
@property
|
||||
def tool_type(self) -> ToolType:
|
||||
"""工具类型"""
|
||||
return ToolType.MCP
|
||||
|
||||
@property
|
||||
def parameters(self) -> List[ToolParameter]:
|
||||
"""工具参数定义"""
|
||||
params = []
|
||||
|
||||
# 添加工具选择参数
|
||||
if len(self.available_tools) > 1:
|
||||
params.append(ToolParameter(
|
||||
name="tool_name",
|
||||
type=ParameterType.STRING,
|
||||
description="要调用的MCP工具名称",
|
||||
required=True,
|
||||
enum=self.available_tools
|
||||
))
|
||||
|
||||
# 添加通用参数
|
||||
params.extend([
|
||||
ToolParameter(
|
||||
name="arguments",
|
||||
type=ParameterType.OBJECT,
|
||||
description="工具参数(JSON对象)",
|
||||
required=False,
|
||||
default={}
|
||||
),
|
||||
ToolParameter(
|
||||
name="timeout",
|
||||
type=ParameterType.INTEGER,
|
||||
description="超时时间(秒)",
|
||||
required=False,
|
||||
default=30,
|
||||
minimum=1,
|
||||
maximum=300
|
||||
)
|
||||
])
|
||||
|
||||
return params
|
||||
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""执行MCP工具"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 确保连接
|
||||
if not self._connected:
|
||||
await self.connect()
|
||||
|
||||
# 确定要调用的工具
|
||||
tool_name = kwargs.get("tool_name")
|
||||
if not tool_name and len(self.available_tools) == 1:
|
||||
tool_name = self.available_tools[0]
|
||||
|
||||
if not tool_name:
|
||||
raise ValueError("必须指定要调用的MCP工具名称")
|
||||
|
||||
if tool_name not in self.available_tools:
|
||||
raise ValueError(f"MCP工具不存在: {tool_name}")
|
||||
|
||||
# 获取参数
|
||||
arguments = kwargs.get("arguments", {})
|
||||
timeout = kwargs.get("timeout", 30)
|
||||
|
||||
# 调用MCP工具
|
||||
result = await self._call_mcp_tool(tool_name, arguments, timeout)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.success_result(
|
||||
data=result,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.error_result(
|
||||
error=str(e),
|
||||
error_code="MCP_ERROR",
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""连接到MCP服务器"""
|
||||
try:
|
||||
# 这里应该实现实际的MCP连接逻辑
|
||||
# 为了简化,这里只是模拟连接
|
||||
|
||||
# 测试服务器连接
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
# 尝试获取服务器信息
|
||||
async with session.get(f"{self.server_url}/info") as response:
|
||||
if response.status == 200:
|
||||
server_info = await response.json()
|
||||
self.available_tools = server_info.get("tools", [])
|
||||
self._connected = True
|
||||
logger.info(f"MCP服务器连接成功: {self.server_url}")
|
||||
return True
|
||||
else:
|
||||
raise Exception(f"服务器响应错误: {response.status}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MCP服务器连接失败: {self.server_url}, 错误: {e}")
|
||||
self._connected = False
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> bool:
|
||||
"""断开MCP服务器连接"""
|
||||
try:
|
||||
if self._client:
|
||||
# 这里应该实现实际的断开逻辑
|
||||
self._client = None
|
||||
|
||||
self._connected = False
|
||||
logger.info(f"MCP服务器连接已断开: {self.server_url}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"断开MCP服务器连接失败: {e}")
|
||||
return False
|
||||
|
||||
def get_health_status(self) -> Dict[str, Any]:
|
||||
"""获取MCP服务健康状态"""
|
||||
return {
|
||||
"connected": self._connected,
|
||||
"server_url": self.server_url,
|
||||
"available_tools": self.available_tools,
|
||||
"last_check": time.time()
|
||||
}
|
||||
|
||||
async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any:
|
||||
"""调用MCP工具"""
|
||||
# 构建MCP请求
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": f"req_{int(time.time() * 1000)}",
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
}
|
||||
|
||||
# 发送请求
|
||||
client_timeout = aiohttp.ClientTimeout(total=timeout)
|
||||
async with aiohttp.ClientSession(timeout=client_timeout) as session:
|
||||
async with session.post(
|
||||
f"{self.server_url}/mcp",
|
||||
json=request_data,
|
||||
headers={"Content-Type": "application/json"}
|
||||
) as response:
|
||||
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise Exception(f"MCP请求失败 {response.status}: {error_text}")
|
||||
|
||||
result = await response.json()
|
||||
|
||||
# 检查MCP响应
|
||||
if "error" in result:
|
||||
error = result["error"]
|
||||
raise Exception(f"MCP工具错误: {error.get('message', '未知错误')}")
|
||||
|
||||
return result.get("result", {})
|
||||
|
||||
async def list_available_tools(self) -> List[Dict[str, Any]]:
|
||||
"""列出可用的MCP工具"""
|
||||
try:
|
||||
if not self._connected:
|
||||
await self.connect()
|
||||
|
||||
# 获取工具列表
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": f"req_{int(time.time() * 1000)}",
|
||||
"method": "tools/list"
|
||||
}
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.post(
|
||||
f"{self.server_url}/mcp",
|
||||
json=request_data,
|
||||
headers={"Content-Type": "application/json"}
|
||||
) as response:
|
||||
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
if "result" in result:
|
||||
tools = result["result"].get("tools", [])
|
||||
self.available_tools = [tool.get("name") for tool in tools]
|
||||
return tools
|
||||
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取MCP工具列表失败: {e}")
|
||||
return []
|
||||
|
||||
def test_connection(self) -> Dict[str, Any]:
|
||||
"""测试MCP连接"""
|
||||
try:
|
||||
# 这里应该实现同步的连接测试
|
||||
# 为了简化,返回基本信息
|
||||
return {
|
||||
"success": bool(self.server_url),
|
||||
"server_url": self.server_url,
|
||||
"connected": self._connected,
|
||||
"available_tools_count": len(self.available_tools),
|
||||
"message": "MCP配置有效" if self.server_url else "缺少服务器URL配置"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
626
api/app/core/tools/mcp/client.py
Normal file
626
api/app/core/tools/mcp/client.py
Normal file
@@ -0,0 +1,626 @@
|
||||
"""MCP客户端 - Model Context Protocol客户端实现"""
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional, Callable
|
||||
from urllib.parse import urlparse
|
||||
import aiohttp
|
||||
import websockets
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class MCPConnectionError(Exception):
    """Raised when an MCP connection cannot be established or used."""
|
||||
|
||||
|
||||
class MCPProtocolError(Exception):
    """Raised when the MCP server answers with a protocol-level error."""
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""MCP客户端 - 支持HTTP和WebSocket连接"""
|
||||
|
||||
def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
|
||||
"""初始化MCP客户端
|
||||
|
||||
Args:
|
||||
server_url: MCP服务器URL
|
||||
connection_config: 连接配置
|
||||
"""
|
||||
self.server_url = server_url
|
||||
self.connection_config = connection_config or {}
|
||||
|
||||
# 解析URL确定连接类型
|
||||
parsed_url = urlparse(server_url)
|
||||
self.connection_type = "websocket" if parsed_url.scheme in ["ws", "wss"] else "http"
|
||||
|
||||
# 连接状态
|
||||
self._connected = False
|
||||
self._websocket = None
|
||||
self._session = None
|
||||
|
||||
# 请求管理
|
||||
self._request_id = 0
|
||||
self._pending_requests: Dict[str, asyncio.Future] = {}
|
||||
|
||||
# 连接池配置
|
||||
self.max_connections = self.connection_config.get("max_connections", 10)
|
||||
self.connection_timeout = self.connection_config.get("timeout", 30)
|
||||
self.retry_attempts = self.connection_config.get("retry_attempts", 3)
|
||||
self.retry_delay = self.connection_config.get("retry_delay", 1)
|
||||
|
||||
# 健康检查
|
||||
self.health_check_interval = self.connection_config.get("health_check_interval", 60)
|
||||
self._health_check_task = None
|
||||
self._last_health_check = None
|
||||
|
||||
# 事件回调
|
||||
self._on_connect_callbacks: List[Callable] = []
|
||||
self._on_disconnect_callbacks: List[Callable] = []
|
||||
self._on_error_callbacks: List[Callable] = []
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""连接到MCP服务器
|
||||
|
||||
Returns:
|
||||
连接是否成功
|
||||
"""
|
||||
try:
|
||||
if self._connected:
|
||||
return True
|
||||
|
||||
logger.info(f"连接MCP服务器: {self.server_url}")
|
||||
|
||||
if self.connection_type == "websocket":
|
||||
success = await self._connect_websocket()
|
||||
else:
|
||||
success = await self._connect_http()
|
||||
|
||||
if success:
|
||||
self._connected = True
|
||||
await self._start_health_check()
|
||||
await self._notify_connect_callbacks()
|
||||
logger.info(f"MCP服务器连接成功: {self.server_url}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"连接MCP服务器失败: {self.server_url}, 错误: {e}")
|
||||
await self._notify_error_callbacks(e)
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> bool:
|
||||
"""断开MCP服务器连接
|
||||
|
||||
Returns:
|
||||
断开是否成功
|
||||
"""
|
||||
try:
|
||||
if not self._connected:
|
||||
return True
|
||||
|
||||
logger.info(f"断开MCP服务器连接: {self.server_url}")
|
||||
|
||||
# 停止健康检查
|
||||
await self._stop_health_check()
|
||||
|
||||
# 取消所有待处理的请求
|
||||
for future in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
self._pending_requests.clear()
|
||||
|
||||
# 断开连接
|
||||
if self.connection_type == "websocket" and self._websocket:
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
elif self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
self._connected = False
|
||||
await self._notify_disconnect_callbacks()
|
||||
logger.info(f"MCP服务器连接已断开: {self.server_url}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"断开MCP服务器连接失败: {e}")
|
||||
return False
|
||||
|
||||
async def _connect_websocket(self) -> bool:
|
||||
"""建立WebSocket连接"""
|
||||
try:
|
||||
# WebSocket连接配置
|
||||
extra_headers = self.connection_config.get("headers", {})
|
||||
|
||||
self._websocket = await websockets.connect(
|
||||
self.server_url,
|
||||
extra_headers=extra_headers,
|
||||
timeout=self.connection_timeout
|
||||
)
|
||||
|
||||
# 启动消息监听
|
||||
asyncio.create_task(self._websocket_message_handler())
|
||||
|
||||
# 发送初始化消息
|
||||
init_message = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_next_request_id(),
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {
|
||||
"tools": {}
|
||||
},
|
||||
"clientInfo": {
|
||||
"name": "ToolManagementSystem",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
await self._websocket.send(json.dumps(init_message))
|
||||
|
||||
# 等待初始化响应
|
||||
response = await asyncio.wait_for(
|
||||
self._websocket.recv(),
|
||||
timeout=self.connection_timeout
|
||||
)
|
||||
|
||||
init_response = json.loads(response)
|
||||
if "error" in init_response:
|
||||
raise MCPProtocolError(f"初始化失败: {init_response['error']}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket连接失败: {e}")
|
||||
return False
|
||||
|
||||
async def _connect_http(self) -> bool:
|
||||
"""建立HTTP连接"""
|
||||
try:
|
||||
# HTTP会话配置
|
||||
timeout = aiohttp.ClientTimeout(total=self.connection_timeout)
|
||||
headers = self.connection_config.get("headers", {})
|
||||
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=timeout,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
# 测试连接
|
||||
test_url = f"{self.server_url}/health" if not self.server_url.endswith('/') else f"{self.server_url}health"
|
||||
|
||||
async with self._session.get(test_url) as response:
|
||||
if response.status == 200:
|
||||
return True
|
||||
else:
|
||||
# 尝试根路径
|
||||
async with self._session.get(self.server_url) as root_response:
|
||||
return root_response.status < 400
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"HTTP连接失败: {e}")
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
return False
|
||||
|
||||
async def _websocket_message_handler(self):
|
||||
"""WebSocket消息处理器"""
|
||||
try:
|
||||
while self._websocket and not self._websocket.closed:
|
||||
try:
|
||||
message = await self._websocket.recv()
|
||||
await self._handle_message(json.loads(message))
|
||||
except ConnectionClosed:
|
||||
break
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"解析WebSocket消息失败: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理WebSocket消息失败: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket消息处理器异常: {e}")
|
||||
finally:
|
||||
self._connected = False
|
||||
await self._notify_disconnect_callbacks()
|
||||
|
||||
async def _handle_message(self, message: Dict[str, Any]):
|
||||
"""处理收到的消息"""
|
||||
try:
|
||||
# 检查是否是响应消息
|
||||
if "id" in message:
|
||||
request_id = str(message["id"])
|
||||
if request_id in self._pending_requests:
|
||||
future = self._pending_requests.pop(request_id)
|
||||
if not future.done():
|
||||
future.set_result(message)
|
||||
|
||||
# 处理通知消息
|
||||
elif "method" in message:
|
||||
await self._handle_notification(message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息失败: {e}")
|
||||
|
||||
async def _handle_notification(self, message: Dict[str, Any]):
|
||||
"""处理通知消息"""
|
||||
method = message.get("method")
|
||||
params = message.get("params", {})
|
||||
|
||||
logger.debug(f"收到MCP通知: {method}, 参数: {params}")
|
||||
|
||||
# 这里可以根据需要处理特定的通知
|
||||
# 例如:工具列表更新、服务器状态变化等
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int = 30) -> Dict[str, Any]:
|
||||
"""调用MCP工具
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
|
||||
Raises:
|
||||
MCPConnectionError: 连接错误
|
||||
MCPProtocolError: 协议错误
|
||||
"""
|
||||
if not self._connected:
|
||||
raise MCPConnectionError("MCP客户端未连接")
|
||||
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_next_request_id(),
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
response = await self._send_request(request_data, timeout)
|
||||
|
||||
if "error" in response:
|
||||
error = response["error"]
|
||||
raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}")
|
||||
|
||||
return response.get("result", {})
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise MCPProtocolError(f"工具调用超时: {tool_name}")
|
||||
|
||||
async def list_tools(self, timeout: int = 10) -> List[Dict[str, Any]]:
|
||||
"""获取可用工具列表
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
工具列表
|
||||
|
||||
Raises:
|
||||
MCPConnectionError: 连接错误
|
||||
MCPProtocolError: 协议错误
|
||||
"""
|
||||
if not self._connected:
|
||||
raise MCPConnectionError("MCP客户端未连接")
|
||||
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_next_request_id(),
|
||||
"method": "tools/list"
|
||||
}
|
||||
|
||||
try:
|
||||
response = await self._send_request(request_data, timeout)
|
||||
|
||||
if not response["error"] is None:
|
||||
error = response["error"]
|
||||
raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}")
|
||||
|
||||
result = response.get("result", {})
|
||||
return result.get("tools", [])
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise MCPProtocolError("获取工具列表超时")
|
||||
|
||||
async def _send_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
|
||||
"""发送请求并等待响应
|
||||
|
||||
Args:
|
||||
request_data: 请求数据
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
request_id = str(request_data["id"])
|
||||
|
||||
if self.connection_type == "websocket":
|
||||
return await self._send_websocket_request(request_data, request_id, timeout)
|
||||
else:
|
||||
return await self._send_http_request(request_data, timeout)
|
||||
|
||||
async def _send_websocket_request(self, request_data: Dict[str, Any], request_id: str, timeout: int) -> Dict[str, Any]:
|
||||
"""发送WebSocket请求"""
|
||||
if not self._websocket or self._websocket.closed:
|
||||
raise MCPConnectionError("WebSocket连接已断开")
|
||||
|
||||
# 创建Future等待响应
|
||||
future = asyncio.Future()
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# 发送请求
|
||||
await self._websocket.send(json.dumps(request_data))
|
||||
|
||||
# 等待响应
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
return response
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
raise
|
||||
except Exception as e:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
raise MCPConnectionError(f"发送WebSocket请求失败: {e}")
|
||||
|
||||
async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
|
||||
"""发送HTTP请求"""
|
||||
if not self._session:
|
||||
raise MCPConnectionError("HTTP会话未建立")
|
||||
|
||||
try:
|
||||
url = f"{self.server_url}/mcp" if not self.server_url.endswith('/') else f"{self.server_url}mcp"
|
||||
|
||||
async with self._session.post(
|
||||
url,
|
||||
json=request_data,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout)
|
||||
) as response:
|
||||
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
|
||||
|
||||
return await response.json()
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise MCPConnectionError(f"HTTP请求失败: {e}")
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""执行健康检查
|
||||
|
||||
Returns:
|
||||
健康状态信息
|
||||
"""
|
||||
try:
|
||||
if not self._connected:
|
||||
return {
|
||||
"healthy": False,
|
||||
"error": "未连接",
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
# 发送ping请求
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_next_request_id(),
|
||||
"method": "ping"
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
response = await self._send_request(request_data, timeout=5)
|
||||
response_time = time.time() - start_time
|
||||
|
||||
self._last_health_check = time.time()
|
||||
|
||||
return {
|
||||
"healthy": True,
|
||||
"response_time": response_time,
|
||||
"timestamp": self._last_health_check,
|
||||
"server_info": response.get("result", {})
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"healthy": False,
|
||||
"error": str(e),
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
async def _start_health_check(self):
|
||||
"""启动健康检查任务"""
|
||||
if self.health_check_interval > 0:
|
||||
self._health_check_task = asyncio.create_task(self._health_check_loop())
|
||||
|
||||
async def _stop_health_check(self):
|
||||
"""停止健康检查任务"""
|
||||
if self._health_check_task:
|
||||
self._health_check_task.cancel()
|
||||
try:
|
||||
await self._health_check_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._health_check_task = None
|
||||
|
||||
async def _health_check_loop(self):
|
||||
"""健康检查循环"""
|
||||
try:
|
||||
while self._connected:
|
||||
await asyncio.sleep(self.health_check_interval)
|
||||
|
||||
if self._connected:
|
||||
health_status = await self.health_check()
|
||||
if not health_status["healthy"]:
|
||||
logger.warning(f"MCP服务器健康检查失败: {health_status.get('error')}")
|
||||
# 可以在这里实现重连逻辑
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"健康检查循环异常: {e}")
|
||||
|
||||
def _get_next_request_id(self) -> str:
|
||||
"""获取下一个请求ID"""
|
||||
self._request_id += 1
|
||||
return f"req_{self._request_id}_{int(time.time() * 1000)}"
|
||||
|
||||
# 事件回调管理
|
||||
def on_connect(self, callback: Callable):
|
||||
"""注册连接回调"""
|
||||
self._on_connect_callbacks.append(callback)
|
||||
|
||||
def on_disconnect(self, callback: Callable):
|
||||
"""注册断开连接回调"""
|
||||
self._on_disconnect_callbacks.append(callback)
|
||||
|
||||
def on_error(self, callback: Callable):
|
||||
"""注册错误回调"""
|
||||
self._on_error_callbacks.append(callback)
|
||||
|
||||
async def _notify_connect_callbacks(self):
|
||||
"""通知连接回调"""
|
||||
for callback in self._on_connect_callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback()
|
||||
else:
|
||||
callback()
|
||||
except Exception as e:
|
||||
logger.error(f"连接回调执行失败: {e}")
|
||||
|
||||
async def _notify_disconnect_callbacks(self):
|
||||
"""通知断开连接回调"""
|
||||
for callback in self._on_disconnect_callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback()
|
||||
else:
|
||||
callback()
|
||||
except Exception as e:
|
||||
logger.error(f"断开连接回调执行失败: {e}")
|
||||
|
||||
async def _notify_error_callbacks(self, error: Exception):
|
||||
"""通知错误回调"""
|
||||
for callback in self._on_error_callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(error)
|
||||
else:
|
||||
callback(error)
|
||||
except Exception as e:
|
||||
logger.error(f"错误回调执行失败: {e}")
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""检查是否已连接"""
|
||||
return self._connected
|
||||
|
||||
@property
|
||||
def last_health_check(self) -> Optional[float]:
|
||||
"""获取最后一次健康检查时间"""
|
||||
return self._last_health_check
|
||||
|
||||
def get_connection_info(self) -> Dict[str, Any]:
|
||||
"""获取连接信息"""
|
||||
return {
|
||||
"server_url": self.server_url,
|
||||
"connection_type": self.connection_type,
|
||||
"connected": self._connected,
|
||||
"last_health_check": self._last_health_check,
|
||||
"pending_requests": len(self._pending_requests),
|
||||
"config": self.connection_config
|
||||
}
|
||||
|
||||
async def __aenter__(self):
|
||||
"""异步上下文管理器入口"""
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器出口"""
|
||||
await self.disconnect()
|
||||
|
||||
|
||||
class MCPConnectionPool:
    """Pool of MCP clients, keyed by server URL."""

    def __init__(self, max_connections: int = 10):
        """Create an empty pool.

        Args:
            max_connections: Maximum number of clients kept in the pool.
        """
        self.max_connections = max_connections
        self._clients: Dict[str, MCPClient] = {}
        self._lock = asyncio.Lock()

    async def get_client(self, server_url: str, connection_config: Dict[str, Any] = None) -> MCPClient:
        """Return a connected client for ``server_url``, creating one if needed.

        A pooled client is reused (after a reconnect attempt when it is not
        connected). When the pool is full, the client that was inserted
        first is disconnected and evicted.

        Args:
            server_url: Server URL.
            connection_config: Connection settings, used only for new clients.

        Returns:
            A connected MCP client.

        Raises:
            MCPConnectionError: A new client could not connect.
        """
        async with self._lock:
            cached = self._clients.get(server_url)
            if cached is not None:
                if cached.is_connected or await cached.connect():
                    return cached
                # Reconnect failed: drop the stale client and build a new one.
                del self._clients[server_url]

            # Enforce the size limit by evicting the oldest entry.
            if len(self._clients) >= self.max_connections:
                oldest_url = next(iter(self._clients))
                victim = self._clients[oldest_url]
                await victim.disconnect()
                del self._clients[oldest_url]

            fresh = MCPClient(server_url, connection_config)
            if not await fresh.connect():
                raise MCPConnectionError(f"无法连接到MCP服务器: {server_url}")

            self._clients[server_url] = fresh
            return fresh

    async def disconnect_all(self):
        """Disconnect every pooled client and empty the pool."""
        async with self._lock:
            for pooled in self._clients.values():
                await pooled.disconnect()
            self._clients.clear()

    def get_pool_status(self) -> Dict[str, Any]:
        """Return the pool size, its limit and per-connection details."""
        details = {}
        for url, pooled in self._clients.items():
            details[url] = pooled.get_connection_info()

        return {
            "total_connections": len(self._clients),
            "max_connections": self.max_connections,
            "connections": details
        }
|
||||
604
api/app/core/tools/mcp/service_manager.py
Normal file
604
api/app/core/tools/mcp/service_manager.py
Normal file
@@ -0,0 +1,604 @@
|
||||
"""MCP服务管理器 - 管理MCP服务的注册、更新、删除和状态监控"""
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType
|
||||
from app.core.logging_config import get_business_logger
|
||||
from .client import MCPClient, MCPConnectionPool
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class MCPServiceManager:
|
||||
"""MCP服务管理器 - 管理MCP服务的生命周期"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
"""初始化MCP服务管理器
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
"""
|
||||
self.db = db
|
||||
self.connection_pool = MCPConnectionPool(max_connections=20)
|
||||
|
||||
# 服务状态管理
|
||||
self._services: Dict[str, Dict[str, Any]] = {} # service_id -> service_info
|
||||
self._monitoring_tasks: Dict[str, asyncio.Task] = {} # service_id -> monitoring_task
|
||||
|
||||
# 配置
|
||||
self.health_check_interval = 60 # 健康检查间隔(秒)
|
||||
self.max_retry_attempts = 3 # 最大重试次数
|
||||
self.retry_delay = 5 # 重试延迟(秒)
|
||||
|
||||
# 状态
|
||||
self._running = False
|
||||
self._manager_task = None
|
||||
|
||||
async def start(self):
|
||||
"""启动服务管理器"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
logger.info("MCP服务管理器启动")
|
||||
|
||||
# 加载现有服务
|
||||
await self._load_existing_services()
|
||||
|
||||
# 启动管理任务
|
||||
self._manager_task = asyncio.create_task(self._management_loop())
|
||||
|
||||
async def stop(self):
|
||||
"""停止服务管理器"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
logger.info("MCP服务管理器停止")
|
||||
|
||||
# 停止管理任务
|
||||
if self._manager_task:
|
||||
self._manager_task.cancel()
|
||||
try:
|
||||
await self._manager_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 停止所有监控任务
|
||||
for task in self._monitoring_tasks.values():
|
||||
task.cancel()
|
||||
|
||||
if self._monitoring_tasks:
|
||||
await asyncio.gather(*self._monitoring_tasks.values(), return_exceptions=True)
|
||||
|
||||
self._monitoring_tasks.clear()
|
||||
|
||||
# 断开所有连接
|
||||
await self.connection_pool.disconnect_all()
|
||||
|
||||
async def register_service(
|
||||
self,
|
||||
server_url: str,
|
||||
connection_config: Dict[str, Any],
|
||||
tenant_id: uuid.UUID,
|
||||
service_name: str = None
|
||||
) -> Tuple[bool, str, Optional[str]]:
|
||||
"""注册MCP服务
|
||||
|
||||
Args:
|
||||
server_url: 服务器URL
|
||||
connection_config: 连接配置
|
||||
tenant_id: 租户ID
|
||||
service_name: 服务名称(可选)
|
||||
|
||||
Returns:
|
||||
(是否成功, 服务ID或错误信息, 错误详情)
|
||||
"""
|
||||
try:
|
||||
# 检查服务是否已存在
|
||||
existing_service = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.server_url == server_url
|
||||
).first()
|
||||
|
||||
if existing_service:
|
||||
return False, "服务已存在", f"URL {server_url} 已被注册"
|
||||
|
||||
# 测试连接
|
||||
try:
|
||||
client = MCPClient(server_url, connection_config)
|
||||
if not await client.connect():
|
||||
return False, "连接测试失败", "无法连接到MCP服务器"
|
||||
|
||||
# 获取可用工具
|
||||
available_tools = await client.list_tools()
|
||||
tool_names = [tool.get("name") for tool in available_tools if tool.get("name")]
|
||||
|
||||
await client.disconnect()
|
||||
|
||||
except Exception as e:
|
||||
return False, "连接测试失败", str(e)
|
||||
|
||||
# 创建工具配置
|
||||
if not service_name:
|
||||
service_name = f"mcp_service_{server_url.split('/')[-1]}"
|
||||
|
||||
tool_config = ToolConfig(
|
||||
name=service_name,
|
||||
description=f"MCP服务 - {server_url}",
|
||||
tool_type=ToolType.MCP.value,
|
||||
tenant_id=tenant_id,
|
||||
version="1.0.0",
|
||||
config_data={
|
||||
"server_url": server_url,
|
||||
"connection_config": connection_config
|
||||
}
|
||||
)
|
||||
|
||||
self.db.add(tool_config)
|
||||
self.db.flush()
|
||||
|
||||
# 创建MCP特定配置
|
||||
mcp_config = MCPToolConfig(
|
||||
id=tool_config.id,
|
||||
server_url=server_url,
|
||||
connection_config=connection_config,
|
||||
available_tools=tool_names,
|
||||
health_status="healthy",
|
||||
last_health_check=datetime.utcnow()
|
||||
)
|
||||
|
||||
self.db.add(mcp_config)
|
||||
self.db.commit()
|
||||
|
||||
service_id = str(tool_config.id)
|
||||
|
||||
# 添加到内存管理
|
||||
self._services[service_id] = {
|
||||
"id": service_id,
|
||||
"server_url": server_url,
|
||||
"connection_config": connection_config,
|
||||
"tenant_id": tenant_id,
|
||||
"available_tools": tool_names,
|
||||
"status": "healthy",
|
||||
"last_health_check": time.time(),
|
||||
"retry_count": 0,
|
||||
"created_at": time.time()
|
||||
}
|
||||
|
||||
# 启动监控
|
||||
await self._start_service_monitoring(service_id)
|
||||
|
||||
logger.info(f"MCP服务注册成功: {service_id} ({server_url})")
|
||||
return True, service_id, None
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"注册MCP服务失败: {server_url}, 错误: {e}")
|
||||
return False, "注册失败", str(e)
|
||||
|
||||
async def unregister_service(self, service_id: str) -> Tuple[bool, str]:
|
||||
"""注销MCP服务
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
|
||||
Returns:
|
||||
(是否成功, 错误信息)
|
||||
"""
|
||||
try:
|
||||
# 从数据库删除
|
||||
tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
|
||||
if not tool_config:
|
||||
return False, "服务不存在"
|
||||
|
||||
self.db.delete(tool_config)
|
||||
self.db.commit()
|
||||
|
||||
# 停止监控
|
||||
await self._stop_service_monitoring(service_id)
|
||||
|
||||
# 从内存移除
|
||||
if service_id in self._services:
|
||||
del self._services[service_id]
|
||||
|
||||
logger.info(f"MCP服务注销成功: {service_id}")
|
||||
return True, ""
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"注销MCP服务失败: {service_id}, 错误: {e}")
|
||||
return False, str(e)
|
||||
|
||||
async def update_service(
    self,
    service_id: str,
    connection_config: Dict[str, Any] = None,
    enabled: bool = None
) -> Tuple[bool, str]:
    """Update the configuration of an MCP service.

    Args:
        service_id: Service ID (string form of the config UUID).
        connection_config: New connection configuration; ``None`` leaves it
            unchanged.
        enabled: New enabled flag; ``None`` leaves it unchanged.

    Returns:
        ``(success, error_message)``; the message is empty on success.
    """
    try:
        # Persist the change first.
        mcp_config = self.db.query(MCPToolConfig).filter(
            MCPToolConfig.id == uuid.UUID(service_id)
        ).first()

        if not mcp_config:
            return False, "服务不存在"

        tool_config = mcp_config.base_config

        if connection_config is not None:
            mcp_config.connection_config = connection_config
            # Assign a fresh dict instead of mutating in place: an in-place
            # change to a JSON value is only persisted when the column uses
            # SQLAlchemy mutation tracking (NOTE(review): column type is not
            # visible here), and config_data may be None.
            tool_config.config_data = {
                **(tool_config.config_data or {}),
                "connection_config": connection_config,
            }

        if enabled is not None:
            tool_config.is_enabled = enabled

        self.db.commit()

        # Mirror the change in memory and restart the monitor so it uses the
        # new connection settings. Only services tracked in memory have a
        # monitor worth restarting.
        if connection_config is not None and service_id in self._services:
            self._services[service_id]["connection_config"] = connection_config
            await self._restart_service_monitoring(service_id)

        logger.info(f"MCP服务更新成功: {service_id}")
        return True, ""

    except Exception as e:
        self.db.rollback()
        logger.error(f"更新MCP服务失败: {service_id}, 错误: {e}")
        return False, str(e)
|
||||
|
||||
async def get_service_status(self, service_id: str) -> Optional[Dict[str, Any]]:
    """Return the status of a service, including a live health probe.

    Args:
        service_id: Service ID.

    Returns:
        A shallow copy of the cached service record with an extra
        ``real_time_health`` entry, or ``None`` when the service is unknown.
    """
    if service_id not in self._services:
        return None

    snapshot = self._services[service_id].copy()

    # Probe the server now; a failure is reported in the result, not raised.
    try:
        client = await self.connection_pool.get_client(
            snapshot["server_url"],
            snapshot["connection_config"]
        )
        live_health = await client.health_check()
    except Exception as exc:
        live_health = {
            "healthy": False,
            "error": str(exc),
            "timestamp": time.time()
        }

    snapshot["real_time_health"] = live_health
    return snapshot
|
||||
|
||||
async def list_services(self, tenant_id: uuid.UUID = None) -> List[Dict[str, Any]]:
    """List the services held in memory.

    Args:
        tenant_id: When given (truthy), only services of this tenant are
            returned.

    Returns:
        Shallow copies of the matching service records.
    """
    return [
        info.copy()
        for info in self._services.values()
        if not tenant_id or info["tenant_id"] == tenant_id
    ]
|
||||
|
||||
async def get_service_tools(self, service_id: str) -> List[Dict[str, Any]]:
    """Fetch the tools currently offered by a service.

    The tool names are also written to the in-memory record and to the
    service's ``MCPToolConfig`` row.

    Args:
        service_id: Service ID.

    Returns:
        The tool list reported by the server; an empty list when the
        service is unknown or any step fails.
    """
    if service_id not in self._services:
        return []

    service_info = self._services[service_id]

    try:
        client = await self.connection_pool.get_client(
            service_info["server_url"],
            service_info["connection_config"]
        )

        tools = await client.list_tools()

        # Refresh the cached tool names (entries without a name are skipped).
        tool_names = [tool.get("name") for tool in tools if tool.get("name")]
        service_info["available_tools"] = tool_names

        # Persist the refreshed list.
        mcp_config = self.db.query(MCPToolConfig).filter(
            MCPToolConfig.id == uuid.UUID(service_id)
        ).first()

        if mcp_config:
            mcp_config.available_tools = tool_names
            self.db.commit()

        return tools

    except Exception as e:
        # Roll back so a failed query/commit does not leave the shared
        # session in a failed transaction, as the sibling methods do.
        self.db.rollback()
        logger.error(f"获取服务工具失败: {service_id}, 错误: {e}")
        return []
|
||||
|
||||
async def call_service_tool(
    self,
    service_id: str,
    tool_name: str,
    arguments: Dict[str, Any],
    timeout: int = 30
) -> Dict[str, Any]:
    """Invoke a tool on a registered service.

    Args:
        service_id: Service ID.
        tool_name: Name of the tool to call.
        arguments: Arguments passed to the tool.
        timeout: Timeout forwarded to the client call.

    Returns:
        The result returned by the client.

    Raises:
        ValueError: The service is not registered in memory.
        Exception: Any error from the connection pool or the client is
            recorded on the service record and re-raised.
    """
    if service_id not in self._services:
        raise ValueError(f"服务不存在: {service_id}")

    record = self._services[service_id]

    try:
        client = await self.connection_pool.get_client(
            record["server_url"],
            record["connection_config"]
        )
        outcome = await client.call_tool(tool_name, arguments, timeout)

        # A successful call counts as a passed health check.
        record.update(
            status="healthy",
            last_health_check=time.time(),
            retry_count=0,
        )
        return outcome

    except Exception as exc:
        # Record the failure, then let the caller see the original error.
        record["status"] = "error"
        record["last_error"] = str(exc)
        record["retry_count"] += 1

        logger.error(f"调用服务工具失败: {service_id}/{tool_name}, 错误: {exc}")
        raise
|
||||
|
||||
async def _load_existing_services(self):
    """Load every enabled MCP service from the database and monitor it.

    Errors are logged and swallowed; services loaded before the error stay
    registered in memory.
    """
    try:
        enabled_configs = (
            self.db.query(MCPToolConfig)
            .join(ToolConfig)
            .filter(ToolConfig.is_enabled == True)  # noqa: E712 - SQL expression
            .all()
        )

        for record in enabled_configs:
            base = record.base_config
            sid = str(record.id)
            checked_at = record.last_health_check

            self._services[sid] = {
                "id": sid,
                "server_url": record.server_url,
                "connection_config": record.connection_config or {},
                "tenant_id": base.tenant_id,
                "available_tools": record.available_tools or [],
                "status": record.health_status or "unknown",
                # 0 marks a service that has never been checked.
                "last_health_check": checked_at.timestamp() if checked_at else 0,
                "retry_count": 0,
                "created_at": base.created_at.timestamp()
            }

            # Start the background health monitor for this service.
            await self._start_service_monitoring(sid)

        logger.info(f"加载了 {len(enabled_configs)} 个MCP服务")

    except Exception as exc:
        logger.error(f"加载现有服务失败: {exc}")
|
||||
|
||||
async def _start_service_monitoring(self, service_id: str):
|
||||
"""启动服务监控"""
|
||||
if service_id in self._monitoring_tasks:
|
||||
return
|
||||
|
||||
task = asyncio.create_task(self._monitor_service(service_id))
|
||||
self._monitoring_tasks[service_id] = task
|
||||
|
||||
async def _stop_service_monitoring(self, service_id: str):
|
||||
"""停止服务监控"""
|
||||
if service_id in self._monitoring_tasks:
|
||||
task = self._monitoring_tasks.pop(service_id)
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _restart_service_monitoring(self, service_id: str):
|
||||
"""重启服务监控"""
|
||||
await self._stop_service_monitoring(service_id)
|
||||
await self._start_service_monitoring(service_id)
|
||||
|
||||
async def _monitor_service(self, service_id: str):
    """Periodically health-check one service until it is removed or the manager stops.

    Each round updates the in-memory record (status, retry count, tool list,
    last check time) and writes the health result to the database. After
    ``max_retry_attempts`` consecutive failures the loop backs off for five
    check intervals and resets the retry counter.

    Args:
        service_id: Service ID.
    """
    try:
        while self._running and service_id in self._services:
            info = self._services[service_id]

            try:
                # Run the health check.
                client = await self.connection_pool.get_client(
                    info["server_url"],
                    info["connection_config"]
                )
                health = await client.health_check()

                if not health["healthy"]:
                    # Service reported itself unhealthy.
                    info["status"] = "unhealthy"
                    info["last_error"] = health.get("error", "健康检查失败")
                    info["retry_count"] += 1
                else:
                    info["status"] = "healthy"
                    info["retry_count"] = 0

                    # Refresh the tool list; a failure here only warns.
                    try:
                        tools = await client.list_tools()
                        names = [tool.get("name") for tool in tools if tool.get("name")]
                        info["available_tools"] = names
                    except Exception as tools_exc:
                        logger.warning(f"更新工具列表失败: {service_id}, 错误: {tools_exc}")

                info["last_health_check"] = time.time()

                # Persist the health result.
                await self._update_service_health_in_db(service_id, health)

            except Exception as check_exc:
                # The check itself failed (connection error, bad payload, ...).
                info["status"] = "error"
                info["last_error"] = str(check_exc)
                info["retry_count"] += 1
                info["last_health_check"] = time.time()

                logger.error(f"服务监控异常: {service_id}, 错误: {check_exc}")

            # Too many consecutive failures: back off, then start counting again.
            if info["retry_count"] >= self.max_retry_attempts:
                logger.warning(f"服务 {service_id} 重试次数过多,暂停监控")
                await asyncio.sleep(self.health_check_interval * 5)
                info["retry_count"] = 0

            # Wait for the next round.
            await asyncio.sleep(self.health_check_interval)

    except asyncio.CancelledError:
        pass
    except Exception as loop_exc:
        logger.error(f"服务监控任务异常: {service_id}, 错误: {loop_exc}")
|
||||
|
||||
async def _update_service_health_in_db(self, service_id: str, health_status: Dict[str, Any]):
    """Write a health-check result to the service's ``MCPToolConfig`` row.

    Failures are logged and rolled back, never raised.

    Args:
        service_id: Service ID.
        health_status: Health-check result; ``"healthy"`` is required and
            ``"error"`` is read when the service is unhealthy.
    """
    try:
        record = self.db.query(MCPToolConfig).filter(
            MCPToolConfig.id == uuid.UUID(service_id)
        ).first()

        if record:
            is_healthy = health_status["healthy"]
            record.health_status = "healthy" if is_healthy else "unhealthy"
            record.last_health_check = datetime.utcnow()
            # Clear the stored error once the service is healthy again.
            record.error_message = None if is_healthy else health_status.get("error", "")

            self.db.commit()

    except Exception as exc:
        logger.error(f"更新数据库健康状态失败: {service_id}, 错误: {exc}")
        self.db.rollback()
|
||||
|
||||
async def _management_loop(self):
|
||||
"""管理循环 - 处理服务清理等任务"""
|
||||
try:
|
||||
while self._running:
|
||||
# 清理失效的服务
|
||||
await self._cleanup_failed_services()
|
||||
|
||||
# 等待下次循环
|
||||
await asyncio.sleep(300) # 5分钟
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"管理循环异常: {e}")
|
||||
|
||||
async def _cleanup_failed_services(self):
    """Disable and forget services that have been failing for over 24 hours.

    A service qualifies when its status is ``"error"`` or ``"unhealthy"`` and
    its ``last_health_check`` is more than 24 hours old. Its monitor is
    stopped, its ``ToolConfig`` row is marked disabled (the row is kept), and
    it is removed from memory. Failures are logged and rolled back.

    NOTE(review): the monitor refreshes ``last_health_check`` on failed
    checks too, so an actively monitored service may never cross the
    threshold; confirm whether "last healthy" time was intended.
    """
    try:
        current_time = time.time()
        cleanup_threshold = 24 * 60 * 60  # 24 hours

        # Collect IDs first so the table is not mutated while iterating.
        services_to_cleanup = [
            service_id
            for service_id, service_info in self._services.items()
            if service_info["status"] in ("error", "unhealthy")
            and current_time - service_info["last_health_check"] > cleanup_threshold
        ]

        for service_id in services_to_cleanup:
            logger.warning(f"清理长期失效的服务: {service_id}")

            # Stop monitoring; the database record is kept.
            await self._stop_service_monitoring(service_id)

            # Mark the service as disabled.
            tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
            if tool_config:
                tool_config.is_enabled = False
                self.db.commit()

            # The entry may have been removed while awaiting above, so do
            # not assume it is still present.
            self._services.pop(service_id, None)

    except Exception as e:
        # Keep the shared session usable after a failed commit.
        self.db.rollback()
        logger.error(f"清理失效服务失败: {e}")
|
||||
|
||||
def get_manager_status(self) -> Dict[str, Any]:
    """Return a summary of the manager's state.

    Returns:
        Running flag, service counts by health, number of monitoring tasks
        and the connection pool's own status report.
    """
    statuses = [svc["status"] for svc in self._services.values()]
    healthy_count = statuses.count("healthy")
    unhealthy_count = sum(1 for state in statuses if state in ("unhealthy", "error"))

    return {
        "running": self._running,
        "total_services": len(self._services),
        "healthy_services": healthy_count,
        "unhealthy_services": unhealthy_count,
        "monitoring_tasks": len(self._monitoring_tasks),
        "connection_pool_status": self.connection_pool.get_pool_status()
    }
|
||||
436
api/app/core/tools/registry.py
Normal file
436
api/app/core/tools/registry.py
Normal file
@@ -0,0 +1,436 @@
|
||||
"""工具注册表 - 管理所有工具的元数据和状态"""
|
||||
import uuid
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Type, Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, or_
|
||||
|
||||
from app.models.tool_model import (
|
||||
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
|
||||
ToolType, ToolStatus, ToolExecution, ExecutionStatus
|
||||
)
|
||||
from app.core.logging_config import get_business_logger
|
||||
from .base import BaseTool, ToolInfo
|
||||
from .custom.base import CustomTool
|
||||
from .mcp.base import MCPTool
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ToolRegistry:
    """Registry of tool metadata and cached tool instances.

    Tool definitions are persisted through the SQLAlchemy session passed to
    the constructor; instantiated tools are cached in memory by tool ID.
    """

    def __init__(self, db: Session):
        """Initialise the registry.

        Args:
            db: Database session used for all reads and writes.
        """
        self.db = db
        self._tools: Dict[str, BaseTool] = {}  # tool ID -> cached instance
        self._tool_classes: Dict[str, Type[BaseTool]] = {}  # class name -> tool class
        self._lock = asyncio.Lock()  # held by register_tool / unregister_tool

    def register_tool_class(self, tool_class: Type[BaseTool], class_name: Optional[str] = None):
        """Register a tool class so built-in tools can be instantiated by name.

        Args:
            tool_class: The tool class.
            class_name: Registration key; defaults to ``tool_class.__name__``.
        """
        class_name = class_name or tool_class.__name__
        self._tool_classes[class_name] = tool_class
        logger.info(f"工具类已注册: {class_name}")

    async def register_tool(self, tool: BaseTool, tenant_id: Optional[uuid.UUID] = None) -> bool:
        """Persist a tool instance and cache it.

        Args:
            tool: Tool instance to register.
            tenant_id: Owning tenant; ``None`` registers a global tool.

        Returns:
            ``True`` on success; ``False`` when a tool with the same name,
            tenant and type already exists or when persisting fails.
        """
        async with self._lock:
            try:
                # A tool is identified by (name, tenant, type); global tools
                # are the ones whose tenant_id is NULL.
                tenant_filter = (
                    ToolConfig.tenant_id == tenant_id
                    if tenant_id
                    else ToolConfig.tenant_id.is_(None)
                )
                existing_config = self.db.query(ToolConfig).filter(
                    and_(
                        ToolConfig.name == tool.name,
                        tenant_filter,
                        ToolConfig.tool_type == tool.tool_type.value
                    )
                ).first()

                if existing_config:
                    logger.warning(f"工具已存在: {tool.name} (tenant: {tenant_id or 'global'})")
                    return False

                # Create the base configuration row.
                tool_config = ToolConfig(
                    name=tool.name,
                    description=tool.description,
                    tool_type=tool.tool_type.value,
                    tenant_id=tenant_id,
                    version=tool.version,
                    tags=tool.tags,
                    config_data=tool.config
                )

                self.db.add(tool_config)
                self.db.flush()  # populate tool_config.id for the child row

                # Create the type-specific configuration row.
                if tool.tool_type == ToolType.BUILTIN:
                    builtin_config = BuiltinToolConfig(
                        id=tool_config.id,
                        tool_class=tool.__class__.__name__,
                        parameters=tool.config.get("parameters", {})
                    )
                    self.db.add(builtin_config)

                elif tool.tool_type == ToolType.CUSTOM:
                    custom_config = CustomToolConfig(
                        id=tool_config.id,
                        schema_url=tool.config.get("schema_url"),
                        schema_content=tool.config.get("schema_content"),
                        auth_type=tool.config.get("auth_type", "none"),
                        auth_config=tool.config.get("auth_config", {}),
                        base_url=tool.config.get("base_url"),
                        timeout=tool.config.get("timeout", 30)
                    )
                    self.db.add(custom_config)

                elif tool.tool_type == ToolType.MCP:
                    mcp_config = MCPToolConfig(
                        id=tool_config.id,
                        server_url=tool.config.get("server_url"),
                        connection_config=tool.config.get("connection_config", {}),
                        available_tools=tool.config.get("available_tools", [])
                    )
                    self.db.add(mcp_config)

                self.db.commit()

                # Cache the instance under its new ID.
                tool.tool_id = str(tool_config.id)
                self._tools[str(tool_config.id)] = tool

                logger.info(f"工具注册成功: {tool.name} (ID: {tool_config.id})")
                return True

            except Exception as e:
                self.db.rollback()
                logger.error(f"工具注册失败: {tool.name}, 错误: {e}")
                return False

    async def unregister_tool(self, tool_id: str) -> bool:
        """Remove a tool from the database and the cache.

        Args:
            tool_id: Tool ID (string form of the config UUID).

        Returns:
            ``True`` on success; ``False`` when the tool does not exist, still
            has pending/running executions, or deletion fails.
        """
        async with self._lock:
            try:
                tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
                if not tool_config:
                    logger.warning(f"工具不存在: {tool_id}")
                    return False

                # Refuse to remove a tool that still has work in flight.
                running_executions = self.db.query(ToolExecution).filter(
                    and_(
                        ToolExecution.tool_config_id == uuid.UUID(tool_id),
                        ToolExecution.status.in_([ExecutionStatus.PENDING.value, ExecutionStatus.RUNNING.value])
                    )
                ).count()

                if running_executions > 0:
                    logger.warning(f"工具有正在执行的任务,无法注销: {tool_id}")
                    return False

                # Delete the config row; related rows are presumably removed
                # by cascade rules on the model (not visible here).
                self.db.delete(tool_config)
                self.db.commit()

                # Drop the cached instance, if any.
                if tool_id in self._tools:
                    del self._tools[tool_id]

                logger.info(f"工具注销成功: {tool_id}")
                return True

            except Exception as e:
                self.db.rollback()
                logger.error(f"工具注销失败: {tool_id}, 错误: {e}")
                return False

    def get_tool(self, tool_id: str) -> Optional[BaseTool]:
        """Return a tool instance, loading it from the database on a cache miss.

        Args:
            tool_id: Tool ID.

        Returns:
            The tool instance, or ``None`` when it is missing, not active, or
            cannot be loaded.
        """
        # Cache first.
        if tool_id in self._tools:
            return self._tools[tool_id]

        # Fall back to the database.
        try:
            tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
            # NOTE(review): every other method derives activity from
            # ToolConfig.is_enabled; confirm ToolConfig really has `status`.
            if not tool_config or not tool_config.status == ToolStatus.ACTIVE.value:
                return None

            tool_instance = self._load_tool_instance(tool_config)
            if tool_instance:
                self._tools[tool_id] = tool_instance
                return tool_instance

        except Exception as e:
            logger.error(f"加载工具失败: {tool_id}, 错误: {e}")

        return None

    def list_tools(
        self,
        tenant_id: Optional[uuid.UUID] = None,
        tool_type: Optional[ToolType] = None,
        status: Optional[ToolStatus] = None,
        tags: Optional[List[str]] = None
    ) -> List[ToolInfo]:
        """List tools matching the given filters.

        Args:
            tenant_id: Return this tenant's tools plus global tools.
            tool_type: Restrict to one tool type.
            status: ``ACTIVE`` / ``INACTIVE`` filter on the enabled flag;
                other values apply no filter.
            tags: Every listed tag must be present on the tool.

        Returns:
            Matching tools as ``ToolInfo``; an empty list on error.
        """
        try:
            query = self.db.query(ToolConfig)

            if tenant_id:
                # Global tools (tenant_id IS NULL) are visible to every tenant.
                query = query.filter(
                    or_(
                        ToolConfig.tenant_id == tenant_id,
                        ToolConfig.tenant_id.is_(None)
                    )
                )

            if tool_type:
                query = query.filter(ToolConfig.tool_type == tool_type.value)

            if status == ToolStatus.ACTIVE:
                query = query.filter(ToolConfig.is_enabled == True)  # noqa: E712 - SQL expression
            elif status == ToolStatus.INACTIVE:
                query = query.filter(ToolConfig.is_enabled == False)  # noqa: E712 - SQL expression

            if tags:
                for tag in tags:
                    query = query.filter(ToolConfig.tags.contains([tag]))

            tool_configs = query.all()

            # Convert rows to ToolInfo.
            tool_infos = []
            for config in tool_configs:
                tool_info = ToolInfo(
                    id=str(config.id),
                    name=config.name,
                    description=config.description or "",
                    tool_type=ToolType(config.tool_type),
                    version=config.version,
                    status=ToolStatus.ACTIVE if config.is_enabled else ToolStatus.INACTIVE,
                    tags=config.tags or [],
                    tenant_id=str(config.tenant_id) if config.tenant_id else None
                )

                # Attach parameter info when the tool can be instantiated.
                tool_instance = self.get_tool(str(config.id))
                if tool_instance:
                    tool_info.parameters = tool_instance.parameters

                tool_infos.append(tool_info)

            return tool_infos

        except Exception as e:
            logger.error(f"列出工具失败, 错误: {e}")
            return []

    async def update_tool_status(self, tool_id: str, status: ToolStatus) -> bool:
        """Enable or disable a tool.

        Args:
            tool_id: Tool ID.
            status: ``ACTIVE`` enables, ``INACTIVE`` disables; other values
                leave the enabled flag untouched.

        Returns:
            ``True`` on success; ``False`` when the tool is missing or the
            update fails.
        """
        try:
            tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
            if not tool_config:
                logger.warning(f"工具不存在: {tool_id}")
                return False

            if status == ToolStatus.ACTIVE:
                tool_config.is_enabled = True
            elif status == ToolStatus.INACTIVE:
                tool_config.is_enabled = False

            self.db.commit()

            # Keep the cached instance in step with the database.
            if tool_id in self._tools:
                self._tools[tool_id].status = status

            logger.info(f"工具状态更新成功: {tool_id} -> {status}")
            return True

        except Exception as e:
            self.db.rollback()
            logger.error(f"工具状态更新失败: {tool_id}, 错误: {e}")
            return False

    def _load_tool_instance(self, tool_config: Optional[ToolConfig]) -> Optional[BaseTool]:
        """Build a tool instance from its persisted configuration.

        Args:
            tool_config: Base configuration row of the tool.

        Returns:
            The tool instance, or ``None`` when the configuration is missing,
            the type-specific row is absent, the built-in class is not
            registered, or construction fails.
        """
        if tool_config is None:
            return None

        try:
            tool_id = str(tool_config.id)
            # config_data may be NULL in the database; treat that as empty.
            base_config = tool_config.config_data or {}
            # Keys shared by every tool type; merged last so they override
            # same-named keys from config_data.
            common = {
                "tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
                "version": tool_config.version,
                "tags": tool_config.tags
            }

            if tool_config.tool_type == ToolType.BUILTIN.value:
                builtin_config = self.db.query(BuiltinToolConfig).filter(
                    BuiltinToolConfig.id == tool_config.id
                ).first()

                if builtin_config and builtin_config.tool_class in self._tool_classes:
                    tool_class = self._tool_classes[builtin_config.tool_class]
                    config = {
                        **base_config,
                        "parameters": builtin_config.parameters,
                        **common
                    }
                    return tool_class(tool_id, config)

            elif tool_config.tool_type == ToolType.CUSTOM.value:
                try:
                    custom_config = self.db.query(CustomToolConfig).filter(
                        CustomToolConfig.id == tool_config.id
                    ).first()

                    if custom_config:
                        config = {
                            **base_config,
                            "schema_url": custom_config.schema_url,
                            "schema_content": custom_config.schema_content,
                            "auth_type": custom_config.auth_type,
                            "auth_config": custom_config.auth_config,
                            "base_url": custom_config.base_url,
                            "timeout": custom_config.timeout,
                            **common
                        }
                        return CustomTool(tool_id, config)
                except ImportError as e:
                    logger.error(f"无法导入自定义工具模块: {e}")

            elif tool_config.tool_type == ToolType.MCP.value:
                try:
                    mcp_config = self.db.query(MCPToolConfig).filter(
                        MCPToolConfig.id == tool_config.id
                    ).first()

                    if mcp_config:
                        config = {
                            **base_config,
                            "server_url": mcp_config.server_url,
                            "connection_config": mcp_config.connection_config,
                            "available_tools": mcp_config.available_tools,
                            **common
                        }
                        return MCPTool(tool_id, config)
                except ImportError as e:
                    logger.error(f"无法导入MCP工具模块: {e}")

        except Exception as e:
            logger.error(f"加载工具实例失败: {tool_config.id}, 错误: {e}")

        return None

    def get_tool_statistics(self, tenant_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
        """Return tool counts, overall and per type.

        Args:
            tenant_id: When given, only this tenant's tools are counted.

        Returns:
            Dict with ``total_tools``, ``active_tools``, ``inactive_tools``
            and ``by_type``; an empty dict on error.
        """
        try:
            query = self.db.query(ToolConfig)
            if tenant_id:
                query = query.filter(ToolConfig.tenant_id == tenant_id)

            total_tools = query.count()
            active_tools = query.filter(ToolConfig.is_enabled == True).count()  # noqa: E712 - SQL expression

            # Count per tool type.
            type_stats = {
                tool_type.value: query.filter(ToolConfig.tool_type == tool_type.value).count()
                for tool_type in ToolType
            }

            return {
                "total_tools": total_tools,
                "active_tools": active_tools,
                "inactive_tools": total_tools - active_tools,
                "by_type": type_stats
            }

        except Exception as e:
            logger.error(f"获取工具统计失败, 错误: {e}")
            return {}

    def clear_cache(self):
        """Drop every cached tool instance."""
        self._tools.clear()
        logger.info("工具缓存已清空")
|
||||
@@ -5,36 +5,41 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
# import uuid
|
||||
import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
from app.models.workflow_model import WorkflowExecution, WorkflowNodeExecution
|
||||
from app.db import get_db
|
||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
# from app.core.tools.registry import ToolRegistry
|
||||
# from app.core.tools.executor import ToolExecutor
|
||||
# from app.core.tools.langchain_adapter import LangchainAdapter
|
||||
# TOOL_MANAGEMENT_AVAILABLE = True
|
||||
# from app.db import get_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowExecutor:
|
||||
"""工作流执行器
|
||||
|
||||
|
||||
负责将工作流配置转换为 LangGraph 并执行。
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workflow_config: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
self,
|
||||
workflow_config: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
):
|
||||
"""初始化执行器
|
||||
|
||||
|
||||
Args:
|
||||
workflow_config: 工作流配置
|
||||
execution_id: 执行 ID
|
||||
@@ -48,25 +53,25 @@ class WorkflowExecutor:
|
||||
self.nodes = workflow_config.get("nodes", [])
|
||||
self.edges = workflow_config.get("edges", [])
|
||||
self.execution_config = workflow_config.get("execution_config", {})
|
||||
|
||||
|
||||
def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState:
|
||||
"""准备初始状态(注入系统变量和会话变量)
|
||||
|
||||
|
||||
变量命名空间:
|
||||
- sys.xxx - 系统变量(execution_id, workspace_id, user_id, message, input_variables 等)
|
||||
- conv.xxx - 会话变量(跨多轮对话保持)
|
||||
- node_id.xxx - 节点输出(执行时动态生成)
|
||||
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
|
||||
|
||||
Returns:
|
||||
初始化的工作流状态
|
||||
"""
|
||||
user_message = input_data.get("message") or ""
|
||||
conversation_vars = input_data.get("conversation_vars") or {}
|
||||
input_variables = input_data.get("variables") or {} # Start 节点的自定义变量
|
||||
|
||||
|
||||
# 构建分层的变量结构
|
||||
variables = {
|
||||
"sys": {
|
||||
@@ -79,7 +84,7 @@ class WorkflowExecutor:
|
||||
},
|
||||
"conv": conversation_vars # 会话级变量(跨多轮对话保持)
|
||||
}
|
||||
|
||||
|
||||
return {
|
||||
"messages": [HumanMessage(content=user_message)],
|
||||
"variables": variables,
|
||||
@@ -89,163 +94,277 @@ class WorkflowExecutor:
|
||||
"workspace_id": self.workspace_id,
|
||||
"user_id": self.user_id,
|
||||
"error": None,
|
||||
"error_node": None
|
||||
"error_node": None,
|
||||
"streaming_buffer": {} # 流式缓冲区
|
||||
}
|
||||
|
||||
|
||||
|
||||
def build_graph(self) -> StateGraph:
|
||||
"""构建 LangGraph
|
||||
|
||||
|
||||
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
    """Work out the literal prefix each End node puts before an upstream reference.

    For every End node with an output template, find the first
    ``{{node_id.field}}`` reference whose node is a direct upstream of that
    End node, and take the template text preceding it as the prefix.

    Returns:
        Tuple ``(prefixes, adjacent_and_referenced)``: a mapping from
        upstream node ID to its non-empty prefix, and the set of node IDs
        that are direct upstreams of an End node and referenced by it.
    """
    import re

    # Matches {{node_id.xxx}} and {{ node_id.xxx }} (inner spaces allowed).
    reference_re = re.compile(r'\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*\}\}')

    prefixes = {}
    adjacent_and_referenced = set()

    end_nodes = [n for n in self.nodes if n.get("type") == "end"]
    logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点")

    for end_node in end_nodes:
        end_id = end_node.get("id")
        template = end_node.get("config", {}).get("output")

        logger.info(f"[前缀分析] End 节点 {end_id} 模板: {template}")

        if not template:
            continue

        # Sources of every edge that points at this End node.
        upstream_ids = [
            edge.get("source")
            for edge in self.edges
            if edge.get("target") == end_id
        ]
        logger.info(f"[前缀分析] End 节点的直接上游节点: {upstream_ids}")

        references = list(reference_re.finditer(template))
        logger.info(f"[前缀分析] 模板中找到 {len(references)} 个节点引用")

        for ref in references:
            ref_id = ref.group(1)
            logger.info(f"[前缀分析] 检查引用: {ref_id}")

            if ref_id not in upstream_ids:
                continue

            # First reference to a direct upstream: everything before it is
            # the prefix.
            prefix = template[:ref.start()]
            logger.info(f"[前缀分析] ✅ 找到直接上游节点 {ref_id} 的引用,前缀: '{prefix}'")

            adjacent_and_referenced.add(ref_id)

            if prefix:
                prefixes[ref_id] = prefix
                logger.info(f"✅ [前缀分析] 为节点 {ref_id} 配置前缀: '{prefix[:50]}...'")

            # Only the first direct-upstream reference is handled.
            break

    logger.info(f"[前缀分析] 最终配置: {prefixes}")
    logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
    return prefixes, adjacent_and_referenced
|
||||
|
||||
def build_graph(self,stream=False) -> CompiledStateGraph:
|
||||
"""构建 LangGraph
|
||||
|
||||
Returns:
|
||||
编译后的状态图
|
||||
"""
|
||||
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
|
||||
|
||||
# 分析 End 节点的前缀配置和相邻且被引用的节点
|
||||
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if stream else ({}, set())
|
||||
|
||||
# 1. 创建状态图
|
||||
workflow = StateGraph(WorkflowState)
|
||||
|
||||
|
||||
# 2. 添加所有节点(包括 start 和 end)
|
||||
start_node_id = None
|
||||
end_node_ids = []
|
||||
|
||||
|
||||
for node in self.nodes:
|
||||
node_type = node.get("type")
|
||||
node_id = node.get("id")
|
||||
|
||||
|
||||
# 记录 start 和 end 节点 ID
|
||||
if node_type == "start":
|
||||
if node_type == NodeType.START:
|
||||
start_node_id = node_id
|
||||
elif node_type == "end":
|
||||
elif node_type == NodeType.END:
|
||||
end_node_ids.append(node_id)
|
||||
|
||||
|
||||
# 创建节点实例(现在 start 和 end 也会被创建)
|
||||
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
||||
|
||||
if node_type in [NodeType.IF_ELSE]:
|
||||
expressions = node_instance.build_conditional_edge_expressions()
|
||||
|
||||
# Number of branches, usually matches the number of conditional expressions
|
||||
branch_number = len(expressions)
|
||||
|
||||
# Find all edges whose source is the current node
|
||||
related_edge = [edge for edge in self.edges if edge.get("source") == node_id]
|
||||
|
||||
# Iterate over each branch
|
||||
for idx in range(branch_number):
|
||||
# Generate a condition expression for each edge
|
||||
# Used later to determine which branch to take based on the node's output
|
||||
# Assumes node output `node.<node_id>.output` matches the edge's label
|
||||
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
|
||||
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
|
||||
|
||||
if node_instance:
|
||||
# 如果是流式模式,且节点有 End 前缀配置,注入配置
|
||||
if stream and node_id in end_prefixes:
|
||||
# 将 End 前缀配置注入到节点实例
|
||||
node_instance._end_node_prefix = end_prefixes[node_id]
|
||||
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
|
||||
|
||||
# 如果是流式模式,标记节点是否与 End 相邻且被引用
|
||||
if stream:
|
||||
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
|
||||
if node_id in adjacent_and_referenced:
|
||||
logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用")
|
||||
|
||||
# 包装节点的 run 方法
|
||||
# 使用函数工厂避免闭包问题
|
||||
def make_node_func(inst):
|
||||
async def node_func(state: WorkflowState):
|
||||
return await inst.run(state)
|
||||
return node_func
|
||||
|
||||
workflow.add_node(node_id, make_node_func(node_instance))
|
||||
logger.debug(f"添加节点: {node_id} (type={node_type})")
|
||||
|
||||
if stream:
|
||||
# 流式模式:创建 async generator 函数
|
||||
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
|
||||
def make_stream_func(inst):
|
||||
async def node_func(state: WorkflowState):
|
||||
# logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}")
|
||||
async for item in inst.run_stream(state):
|
||||
yield item
|
||||
return node_func
|
||||
workflow.add_node(node_id, make_stream_func(node_instance))
|
||||
else:
|
||||
# 非流式模式:创建 async function
|
||||
def make_func(inst):
|
||||
async def node_func(state: WorkflowState):
|
||||
return await inst.run(state)
|
||||
return node_func
|
||||
workflow.add_node(node_id, make_func(node_instance))
|
||||
|
||||
logger.debug(f"添加节点: {node_id} (type={node_type}, stream={stream})")
|
||||
|
||||
# 3. 添加边
|
||||
# 从 START 连接到 start 节点
|
||||
if start_node_id:
|
||||
workflow.add_edge(START, start_node_id)
|
||||
logger.debug(f"添加边: START -> {start_node_id}")
|
||||
|
||||
|
||||
for edge in self.edges:
|
||||
source = edge.get("source")
|
||||
target = edge.get("target")
|
||||
edge_type = edge.get("type")
|
||||
condition = edge.get("condition")
|
||||
|
||||
|
||||
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start)
|
||||
if source == start_node_id:
|
||||
# 但要连接 start 到下一个节点
|
||||
workflow.add_edge(source, target)
|
||||
logger.debug(f"添加边: {source} -> {target}")
|
||||
continue
|
||||
|
||||
|
||||
# 处理到 end 节点的边
|
||||
if target in end_node_ids:
|
||||
# 连接到 end 节点
|
||||
workflow.add_edge(source, target)
|
||||
logger.debug(f"添加边: {source} -> {target}")
|
||||
continue
|
||||
|
||||
|
||||
# 跳过错误边(在节点内部处理)
|
||||
if edge_type == "error":
|
||||
continue
|
||||
|
||||
|
||||
if condition:
|
||||
# 条件边
|
||||
def router(state: WorkflowState, cond=condition, tgt=target):
|
||||
"""条件路由函数"""
|
||||
if evaluate_condition(
|
||||
cond,
|
||||
state.get("variables", {}),
|
||||
state.get("node_outputs", {}),
|
||||
{
|
||||
"execution_id": state.get("execution_id"),
|
||||
"workspace_id": state.get("workspace_id"),
|
||||
"user_id": state.get("user_id")
|
||||
}
|
||||
cond,
|
||||
state.get("variables", {}),
|
||||
state.get("node_outputs", {}),
|
||||
{
|
||||
"execution_id": state.get("execution_id"),
|
||||
"workspace_id": state.get("workspace_id"),
|
||||
"user_id": state.get("user_id")
|
||||
}
|
||||
):
|
||||
return tgt
|
||||
return END # 条件不满足,结束
|
||||
|
||||
|
||||
workflow.add_conditional_edges(source, router)
|
||||
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
|
||||
else:
|
||||
# 普通边
|
||||
workflow.add_edge(source, target)
|
||||
logger.debug(f"添加边: {source} -> {target}")
|
||||
|
||||
|
||||
# 从 end 节点连接到 END
|
||||
for end_node_id in end_node_ids:
|
||||
workflow.add_edge(end_node_id, END)
|
||||
logger.debug(f"添加边: {end_node_id} -> END")
|
||||
|
||||
|
||||
# 4. 编译图
|
||||
graph = workflow.compile()
|
||||
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
|
||||
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
input_data: dict[str, Any]
|
||||
self,
|
||||
input_data: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""执行工作流(非流式)
|
||||
|
||||
|
||||
Args:
|
||||
input_data: 输入数据,包含 message 和 variables
|
||||
|
||||
|
||||
Returns:
|
||||
执行结果,包含 status, output, node_outputs, elapsed_time, token_usage
|
||||
"""
|
||||
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
|
||||
|
||||
|
||||
# 记录开始时间
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
|
||||
# 1. 构建图
|
||||
graph = self.build_graph()
|
||||
|
||||
|
||||
# 2. 初始化状态(自动注入系统变量)
|
||||
initial_state = self._prepare_initial_state(input_data)
|
||||
|
||||
|
||||
# 3. 执行工作流
|
||||
try:
|
||||
result = await graph.ainvoke(initial_state)
|
||||
|
||||
|
||||
# 计算耗时
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
|
||||
# 提取节点输出(现在包含 start 和 end 节点)
|
||||
node_outputs = result.get("node_outputs", {})
|
||||
|
||||
|
||||
# 提取最终输出(从最后一个非 start/end 节点)
|
||||
final_output = self._extract_final_output(node_outputs)
|
||||
|
||||
|
||||
# 聚合 token 使用情况
|
||||
token_usage = self._aggregate_token_usage(node_outputs)
|
||||
|
||||
|
||||
# 提取 conversation_id(从 start 节点输出)
|
||||
conversation_id = None
|
||||
for node_id, node_output in node_outputs.items():
|
||||
if node_output.get("node_type") == "start":
|
||||
conversation_id = node_output.get("output", {}).get("conversation_id")
|
||||
break
|
||||
|
||||
|
||||
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
|
||||
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"output": final_output,
|
||||
@@ -256,12 +375,12 @@ class WorkflowExecutor:
|
||||
"token_usage": token_usage,
|
||||
"error": result.get("error")
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# 计算耗时(即使失败也记录)
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
|
||||
logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True)
|
||||
return {
|
||||
"status": "failed",
|
||||
@@ -271,86 +390,200 @@ class WorkflowExecutor:
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": None
|
||||
}
|
||||
|
||||
|
||||
async def execute_stream(
    self,
    input_data: dict[str, Any]
):
    """Execute the workflow in streaming mode.

    The compiled graph is streamed with three LangGraph stream modes at once:

    1. "updates" - per-node state updates (only logged here)
    2. "debug"   - task / task_result records, re-emitted as node_start / node_end
    3. "custom"  - chunks that nodes push through the stream writer

    Args:
        input_data: Input data passed to ``_prepare_initial_state``.

    Yields:
        Event dicts of the form::

            {
                "event": "workflow_start" | "workflow_end" | "node_start"
                         | "node_end" | "node_chunk" | "message",
                "data": {...}
            }

        A ``workflow_end`` event is always the last event once streaming has
        started; its ``status`` is "completed" or "failed".
    """
    logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}")

    # Record the start time so elapsed time can be reported on success and failure.
    start_time = datetime.datetime.now()

    # Announce the run before any graph work happens.
    yield {
        "event": "workflow_start",
        "data": {
            "execution_id": self.execution_id,
            "workspace_id": self.workspace_id,
            "timestamp": start_time.isoformat()
        }
    }

    # 1. Build the graph in streaming mode.
    graph = self.build_graph(True)

    # 2. Initialise state (system variables are injected by the helper).
    initial_state = self._prepare_initial_state(input_data)

    # 3. Execute the workflow.
    try:
        chunk_count = 0

        async for event in graph.astream(
            initial_state,
            stream_mode=["updates", "debug", "custom"],
        ):
            # With a list of stream modes each item is expected to be (mode, data).
            if isinstance(event, tuple) and len(event) == 2:
                mode, data = event
            else:
                # Unexpected format: log and skip instead of failing the run.
                logger.warning(f"[STREAM] Unexpected event format: {type(event)}, value: {event}")
                continue

            if mode == "custom":
                # Chunks emitted by nodes through the stream writer.
                chunk_count += 1
                event_type = data.get("type", "node_chunk")  # "message" or "node_chunk"
                logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}")

                yield {
                    "event": event_type,
                    "data": {
                        "node_id": data.get("node_id"),
                        "chunk": data.get("chunk"),
                        "full_content": data.get("full_content"),
                        "chunk_index": data.get("chunk_index"),
                        "is_prefix": data.get("is_prefix"),
                        "is_suffix": data.get("is_suffix")
                    }
                }

            elif mode == "debug":
                # Node execution status records.
                event_type = data.get("type")
                payload = data.get("payload", {})
                node_name = payload.get("name")

                if event_type == "task":
                    # A node is about to run; ids are read from its input's sys variables.
                    inputv = payload.get("input", {})
                    variables = inputv.get("variables", {})
                    variables_sys = variables.get("sys", {})
                    conversation_id = variables_sys.get("conversation_id")
                    execution_id = variables_sys.get("execution_id")
                    logger.info(f"[DEBUG] Node starts execution: {node_name}")

                    yield {
                        "event": "node_start",
                        "data": {
                            "node_id": node_name,
                            "conversation_id": conversation_id,
                            "execution_id": execution_id,
                            "timestamp": data.get("timestamp")
                        }
                    }
                elif event_type == "task_result":
                    # A node finished.
                    # NOTE(review): assumes payload["result"] is a dict carrying "input";
                    # confirm against the LangGraph version in use.
                    result = payload.get("result", {})
                    inputv = result.get("input", {})
                    variables = inputv.get("variables", {})
                    variables_sys = variables.get("sys", {})
                    conversation_id = variables_sys.get("conversation_id")
                    execution_id = variables_sys.get("execution_id")
                    logger.info(f"[DEBUG] Node execution completed: {node_name}")

                    yield {
                        "event": "node_end",
                        "data": {
                            "node_id": node_name,
                            "conversation_id": conversation_id,
                            "execution_id": execution_id,
                            "timestamp": data.get("timestamp")
                        }
                    }

            elif mode == "updates":
                # State updates are not forwarded to the client, only logged.
                logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}")

        # Compute elapsed time.
        end_time = datetime.datetime.now()
        elapsed_time = (end_time - start_time).total_seconds()

        logger.info(f"Workflow execution completed (streaming), total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s")

        # Emit the terminal event for a successful run.
        yield {
            "event": "workflow_end",
            "data": {
                "execution_id": self.execution_id,
                "status": "completed",
                "elapsed_time": elapsed_time,
                "timestamp": end_time.isoformat()
            }
        }

    except Exception as e:
        # Compute elapsed time even when the run failed.
        end_time = datetime.datetime.now()
        elapsed_time = (end_time - start_time).total_seconds()

        logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True)

        # Emit the terminal event for a failed run.
        yield {
            "event": "workflow_end",
            "data": {
                "execution_id": self.execution_id,
                "status": "failed",
                "error": str(e),
                "elapsed_time": elapsed_time,
                "timestamp": end_time.isoformat()
            }
        }
|
||||
|
||||
|
||||
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:
|
||||
"""从节点输出中提取最终输出
|
||||
|
||||
|
||||
优先级:
|
||||
1. 最后一个执行的非 start/end 节点的 output
|
||||
2. 如果没有节点输出,返回 None
|
||||
|
||||
|
||||
Args:
|
||||
node_outputs: 所有节点的输出
|
||||
|
||||
|
||||
Returns:
|
||||
最终输出字符串或 None
|
||||
"""
|
||||
if not node_outputs:
|
||||
return None
|
||||
|
||||
|
||||
# 获取最后一个节点的输出
|
||||
last_node_output = list(node_outputs.values())[-1] if node_outputs else None
|
||||
|
||||
|
||||
if last_node_output and isinstance(last_node_output, dict):
|
||||
return last_node_output.get("output")
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _aggregate_token_usage(self, node_outputs: dict[str, Any]) -> dict[str, int] | None:
|
||||
"""聚合所有节点的 token 使用情况
|
||||
|
||||
|
||||
Args:
|
||||
node_outputs: 所有节点的输出
|
||||
|
||||
|
||||
Returns:
|
||||
聚合的 token 使用情况 {"prompt_tokens": x, "completion_tokens": y, "total_tokens": z}
|
||||
如果没有 token 使用信息,返回 None
|
||||
@@ -359,7 +592,7 @@ class WorkflowExecutor:
|
||||
total_completion_tokens = 0
|
||||
total_tokens = 0
|
||||
has_token_info = False
|
||||
|
||||
|
||||
for node_output in node_outputs.values():
|
||||
if isinstance(node_output, dict):
|
||||
token_usage = node_output.get("token_usage")
|
||||
@@ -368,33 +601,33 @@ class WorkflowExecutor:
|
||||
total_prompt_tokens += token_usage.get("prompt_tokens", 0)
|
||||
total_completion_tokens += token_usage.get("completion_tokens", 0)
|
||||
total_tokens += token_usage.get("total_tokens", 0)
|
||||
|
||||
|
||||
if not has_token_info:
|
||||
return None
|
||||
|
||||
|
||||
return {
|
||||
"prompt_tokens": total_prompt_tokens,
|
||||
"completion_tokens": total_completion_tokens,
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
|
||||
|
||||
|
||||
async def execute_workflow(
|
||||
workflow_config: dict[str, Any],
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
workflow_config: dict[str, Any],
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
) -> dict[str, Any]:
|
||||
"""执行工作流(便捷函数)
|
||||
|
||||
|
||||
Args:
|
||||
workflow_config: 工作流配置
|
||||
input_data: 输入数据
|
||||
execution_id: 执行 ID
|
||||
workspace_id: 工作空间 ID
|
||||
user_id: 用户 ID
|
||||
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
@@ -408,21 +641,21 @@ async def execute_workflow(
|
||||
|
||||
|
||||
async def execute_workflow_stream(
|
||||
workflow_config: dict[str, Any],
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
workflow_config: dict[str, Any],
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
):
|
||||
"""执行工作流(流式,便捷函数)
|
||||
|
||||
|
||||
Args:
|
||||
workflow_config: 工作流配置
|
||||
input_data: 输入数据
|
||||
execution_id: 执行 ID
|
||||
workspace_id: 工作空间 ID
|
||||
user_id: 用户 ID
|
||||
|
||||
|
||||
Yields:
|
||||
流式事件
|
||||
"""
|
||||
@@ -434,3 +667,179 @@ async def execute_workflow_stream(
|
||||
)
|
||||
async for event in executor.execute_stream(input_data):
|
||||
yield event
|
||||
|
||||
|
||||
# ==================== 工具管理系统集成 ====================
|
||||
|
||||
# def get_workflow_tools(workspace_id: str, user_id: str) -> list:
|
||||
# """获取工作流可用的工具列表
|
||||
#
|
||||
# Args:
|
||||
# workspace_id: 工作空间ID
|
||||
# user_id: 用户ID
|
||||
#
|
||||
# Returns:
|
||||
# 可用工具列表
|
||||
# """
|
||||
# if not TOOL_MANAGEMENT_AVAILABLE:
|
||||
# logger.warning("工具管理系统不可用")
|
||||
# return []
|
||||
#
|
||||
# try:
|
||||
# db = next(get_db())
|
||||
#
|
||||
# # 创建工具注册表
|
||||
# registry = ToolRegistry(db)
|
||||
#
|
||||
# # 注册内置工具类
|
||||
# from app.core.tools.builtin import (
|
||||
# DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool
|
||||
# )
|
||||
# registry.register_tool_class(DateTimeTool)
|
||||
# registry.register_tool_class(JsonTool)
|
||||
# registry.register_tool_class(BaiduSearchTool)
|
||||
# registry.register_tool_class(MinerUTool)
|
||||
# registry.register_tool_class(TextInTool)
|
||||
#
|
||||
# # 获取活跃的工具
|
||||
# import uuid
|
||||
# tools = registry.list_tools(workspace_id=uuid.UUID(workspace_id))
|
||||
# active_tools = [tool for tool in tools if tool.status.value == "active"]
|
||||
#
|
||||
# # 转换为Langchain工具
|
||||
# langchain_tools = []
|
||||
# for tool_info in active_tools:
|
||||
# try:
|
||||
# tool_instance = registry.get_tool(tool_info.id)
|
||||
# if tool_instance:
|
||||
# langchain_tool = LangchainAdapter.convert_tool(tool_instance)
|
||||
# langchain_tools.append(langchain_tool)
|
||||
# except Exception as e:
|
||||
# logger.error(f"转换工具失败: {tool_info.name}, 错误: {e}")
|
||||
#
|
||||
# logger.info(f"为工作流获取了 {len(langchain_tools)} 个工具")
|
||||
# return langchain_tools
|
||||
#
|
||||
# except Exception as e:
|
||||
# logger.error(f"获取工作流工具失败: {e}")
|
||||
# return []
|
||||
#
|
||||
#
|
||||
# class ToolWorkflowNode:
|
||||
# """工具工作流节点 - 在工作流中执行工具"""
|
||||
#
|
||||
# def __init__(self, node_config: dict, workflow_config: dict):
|
||||
# """初始化工具节点
|
||||
#
|
||||
# Args:
|
||||
# node_config: 节点配置
|
||||
# workflow_config: 工作流配置
|
||||
# """
|
||||
# self.node_config = node_config
|
||||
# self.workflow_config = workflow_config
|
||||
# self.tool_id = node_config.get("tool_id")
|
||||
# self.tool_parameters = node_config.get("parameters", {})
|
||||
#
|
||||
# async def run(self, state: WorkflowState) -> WorkflowState:
|
||||
# """执行工具节点"""
|
||||
# if not TOOL_MANAGEMENT_AVAILABLE:
|
||||
# logger.error("工具管理系统不可用")
|
||||
# state["error"] = "工具管理系统不可用"
|
||||
# return state
|
||||
#
|
||||
# try:
|
||||
# from sqlalchemy.orm import Session
|
||||
# db = next(get_db())
|
||||
#
|
||||
# # 创建工具执行器
|
||||
# registry = ToolRegistry(db)
|
||||
# executor = ToolExecutor(db, registry)
|
||||
#
|
||||
# # 准备参数(支持变量替换)
|
||||
# parameters = self._prepare_parameters(state)
|
||||
#
|
||||
# # 执行工具
|
||||
# result = await executor.execute_tool(
|
||||
# tool_id=self.tool_id,
|
||||
# parameters=parameters,
|
||||
# user_id=uuid.UUID(state["user_id"]),
|
||||
# workspace_id=uuid.UUID(state["workspace_id"])
|
||||
# )
|
||||
#
|
||||
# # 更新状态
|
||||
# node_id = self.node_config.get("id")
|
||||
# if result.success:
|
||||
# state["node_outputs"][node_id] = {
|
||||
# "type": "tool",
|
||||
# "tool_id": self.tool_id,
|
||||
# "output": result.data,
|
||||
# "execution_time": result.execution_time,
|
||||
# "token_usage": result.token_usage
|
||||
# }
|
||||
#
|
||||
# # 更新运行时变量
|
||||
# if isinstance(result.data, dict):
|
||||
# for key, value in result.data.items():
|
||||
# state["runtime_vars"][f"{node_id}.{key}"] = value
|
||||
# else:
|
||||
# state["runtime_vars"][f"{node_id}.result"] = result.data
|
||||
# else:
|
||||
# state["error"] = result.error
|
||||
# state["error_node"] = node_id
|
||||
# state["node_outputs"][node_id] = {
|
||||
# "type": "tool",
|
||||
# "tool_id": self.tool_id,
|
||||
# "error": result.error,
|
||||
# "execution_time": result.execution_time
|
||||
# }
|
||||
#
|
||||
# return state
|
||||
#
|
||||
# except Exception as e:
|
||||
# logger.error(f"工具节点执行失败: {e}")
|
||||
# state["error"] = str(e)
|
||||
# state["error_node"] = self.node_config.get("id")
|
||||
# return state
|
||||
#
|
||||
# def _prepare_parameters(self, state: WorkflowState) -> dict:
|
||||
# """准备工具参数(支持变量替换)"""
|
||||
# parameters = {}
|
||||
#
|
||||
# for key, value in self.tool_parameters.items():
|
||||
# if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
|
||||
# # 变量替换
|
||||
# var_path = value[2:-1]
|
||||
#
|
||||
# # 支持多层级变量访问,如 ${sys.message} 或 ${node1.result}
|
||||
# if "." in var_path:
|
||||
# parts = var_path.split(".")
|
||||
# current = state.get("variables", {})
|
||||
#
|
||||
# for part in parts:
|
||||
# if isinstance(current, dict) and part in current:
|
||||
# current = current[part]
|
||||
# else:
|
||||
# # 尝试从运行时变量获取
|
||||
# runtime_key = ".".join(parts)
|
||||
# current = state.get("runtime_vars", {}).get(runtime_key, value)
|
||||
# break
|
||||
#
|
||||
# parameters[key] = current
|
||||
# else:
|
||||
# # 简单变量
|
||||
# variables = state.get("variables", {})
|
||||
# parameters[key] = variables.get(var_path, value)
|
||||
# else:
|
||||
# parameters[key] = value
|
||||
#
|
||||
# return parameters
|
||||
#
|
||||
#
|
||||
# # 注册工具节点到NodeFactory(如果存在)
|
||||
# try:
|
||||
# from app.core.workflow.nodes import NodeFactory
|
||||
# if hasattr(NodeFactory, 'register_node_type'):
|
||||
# NodeFactory.register_node_type("tool", ToolWorkflowNode)
|
||||
# logger.info("工具节点已注册到工作流系统")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"注册工具节点失败: {e}")
|
||||
@@ -5,6 +5,7 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from simpleeval import simple_eval, NameNotDefined, InvalidExpression
|
||||
@@ -59,9 +60,10 @@ class ExpressionEvaluator:
|
||||
"""
|
||||
# 移除 Jinja2 模板语法的花括号(如果存在)
|
||||
expression = expression.strip()
|
||||
if expression.startswith("{{") and expression.endswith("}}"):
|
||||
expression = expression[2:-2].strip()
|
||||
|
||||
# "{{system.message}} == {{ user.message }}" -> "system.message == user.message"
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
expression = re.sub(pattern, r"\1", expression).strip()
|
||||
|
||||
# 构建命名空间上下文
|
||||
context = {
|
||||
"var": variables, # 用户变量
|
||||
|
||||
@@ -4,13 +4,14 @@
|
||||
提供各种类型的节点实现,用于工作流执行。
|
||||
"""
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.agent import AgentNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.end import EndNode
|
||||
from app.core.workflow.nodes.node_factory import NodeFactory
|
||||
from app.core.workflow.nodes.if_else import IfElseNode
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
|
||||
__all__ = [
|
||||
"BaseNode",
|
||||
@@ -18,7 +19,9 @@ __all__ = [
|
||||
"LLMNode",
|
||||
"AgentNode",
|
||||
"TransformNode",
|
||||
"IfElseNode",
|
||||
"StartNode",
|
||||
"EndNode",
|
||||
"NodeFactory",
|
||||
"WorkflowNode"
|
||||
]
|
||||
|
||||
@@ -50,6 +50,11 @@ class VariableDefinition(BaseModel):
|
||||
description="变量描述"
|
||||
)
|
||||
|
||||
max_length: int = Field(
|
||||
default=200,
|
||||
description="只对字符串类型生效"
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"examples": [
|
||||
|
||||
@@ -10,6 +10,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, TypedDict, Annotated
|
||||
from operator import add
|
||||
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
@@ -43,6 +44,10 @@ class WorkflowState(TypedDict):
|
||||
# 错误信息(用于错误边)
|
||||
error: str | None
|
||||
error_node: str | None
|
||||
|
||||
# 流式缓冲区(存储节点的实时流式输出)
|
||||
# 格式:{node_id: {"chunks": [...], "full_content": "..."}}
|
||||
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
|
||||
class BaseNode(ABC):
|
||||
@@ -201,19 +206,25 @@ class BaseNode(ABC):
|
||||
return self._wrap_error(str(e), elapsed_time, state)
|
||||
|
||||
async def run_stream(self, state: WorkflowState):
|
||||
"""执行节点(带错误处理和输出包装,流式)
|
||||
"""Execute node with error handling and output wrapping (streaming)
|
||||
|
||||
这个方法由 Executor 调用,负责:
|
||||
1. 时间统计
|
||||
2. 调用节点的 execute_stream() 方法
|
||||
3. 将业务数据包装成标准输出格式
|
||||
4. 错误处理
|
||||
This method is called by the Executor and is responsible for:
|
||||
1. Time tracking
|
||||
2. Calling the node's execute_stream() method
|
||||
3. Using LangGraph's stream writer to send chunks
|
||||
4. Updating streaming buffer in state for downstream nodes
|
||||
5. Wrapping business data into standard output format
|
||||
6. Error handling
|
||||
|
||||
Special handling for End nodes:
|
||||
- End nodes don't send chunks via writer (prefix and LLM content already sent)
|
||||
- End nodes only yield suffix for final result assembly
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
state: Workflow state
|
||||
|
||||
Yields:
|
||||
标准化的流式事件
|
||||
State updates with streaming buffer and final result
|
||||
"""
|
||||
import time
|
||||
|
||||
@@ -222,68 +233,143 @@ class BaseNode(ABC):
|
||||
try:
|
||||
timeout = self.get_timeout()
|
||||
|
||||
# 累积完整结果(用于最后的包装)
|
||||
# Get LangGraph's stream writer for sending custom data
|
||||
writer = get_stream_writer()
|
||||
|
||||
# Check if this is an End node
|
||||
# End nodes CAN send chunks (for suffix), but only after LLM content
|
||||
is_end_node = self.node_type == "end"
|
||||
|
||||
# Check if this node is adjacent to End node (for message type)
|
||||
is_adjacent_to_end = getattr(self, '_is_adjacent_to_end', False)
|
||||
|
||||
# Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others
|
||||
chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk"
|
||||
|
||||
logger.debug(f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})")
|
||||
|
||||
# Accumulate complete result (for final wrapping)
|
||||
chunks = []
|
||||
final_result = None
|
||||
chunk_count = 0
|
||||
|
||||
# 使用异步生成器包装,支持超时
|
||||
async def stream_with_timeout():
|
||||
nonlocal final_result
|
||||
loop_start = asyncio.get_event_loop().time()
|
||||
# Stream chunks in real-time
|
||||
loop_start = asyncio.get_event_loop().time()
|
||||
|
||||
async for item in self.execute_stream(state):
|
||||
# Check timeout
|
||||
if asyncio.get_event_loop().time() - loop_start > timeout:
|
||||
raise TimeoutError()
|
||||
|
||||
async for item in self.execute_stream(state):
|
||||
# 检查超时
|
||||
if asyncio.get_event_loop().time() - loop_start > timeout:
|
||||
raise TimeoutError()
|
||||
# Check if it's a completion marker
|
||||
if isinstance(item, dict) and item.get("__final__"):
|
||||
final_result = item["result"]
|
||||
elif isinstance(item, str):
|
||||
# String is a chunk
|
||||
chunk_count += 1
|
||||
chunks.append(item)
|
||||
full_content = "".join(chunks)
|
||||
|
||||
# 检查是否是完成标记
|
||||
if isinstance(item, dict) and item.get("__final__"):
|
||||
final_result = item["result"]
|
||||
elif isinstance(item, str):
|
||||
# 字符串是 chunk
|
||||
chunks.append(item)
|
||||
# Send chunks for all nodes (including End nodes for suffix)
|
||||
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {item[:50]}...")
|
||||
|
||||
# 1. Send via stream writer (for real-time client updates)
|
||||
writer({
|
||||
"type": chunk_type, # "message" or "node_chunk"
|
||||
"node_id": self.node_id,
|
||||
"chunk": item,
|
||||
"full_content": full_content,
|
||||
"chunk_index": chunk_count
|
||||
})
|
||||
|
||||
# 2. Update streaming buffer in state (for downstream nodes)
|
||||
# Only non-End nodes need streaming buffer
|
||||
if not is_end_node:
|
||||
yield {
|
||||
"type": "chunk",
|
||||
"node_id": self.node_id,
|
||||
"content": item,
|
||||
"full_content": "".join(chunks)
|
||||
"streaming_buffer": {
|
||||
self.node_id: {
|
||||
"full_content": full_content,
|
||||
"chunk_count": chunk_count,
|
||||
"is_complete": False
|
||||
}
|
||||
}
|
||||
}
|
||||
else:
|
||||
# 其他类型也当作 chunk 处理
|
||||
chunks.append(str(item))
|
||||
else:
|
||||
# Other types are also treated as chunks
|
||||
chunk_count += 1
|
||||
chunk_str = str(item)
|
||||
chunks.append(chunk_str)
|
||||
full_content = "".join(chunks)
|
||||
|
||||
# Send chunks for all nodes
|
||||
writer({
|
||||
"type": chunk_type, # "message" or "node_chunk"
|
||||
"node_id": self.node_id,
|
||||
"chunk": chunk_str,
|
||||
"full_content": full_content,
|
||||
"chunk_index": chunk_count
|
||||
})
|
||||
|
||||
# Only non-End nodes need streaming buffer
|
||||
if not is_end_node:
|
||||
yield {
|
||||
"type": "chunk",
|
||||
"node_id": self.node_id,
|
||||
"content": str(item),
|
||||
"full_content": "".join(chunks)
|
||||
"streaming_buffer": {
|
||||
self.node_id: {
|
||||
"full_content": full_content,
|
||||
"chunk_count": chunk_count,
|
||||
"is_complete": False
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async for chunk_event in stream_with_timeout():
|
||||
yield chunk_event
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 包装最终结果
|
||||
logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
|
||||
|
||||
# Extract processed output (call subclass's _extract_output)
|
||||
extracted_output = self._extract_output(final_result)
|
||||
|
||||
# Wrap final result
|
||||
final_output = self._wrap_output(final_result, elapsed_time, state)
|
||||
yield {
|
||||
"type": "complete",
|
||||
**final_output
|
||||
|
||||
# Store extracted output in runtime variables (for quick access by subsequent nodes)
|
||||
if isinstance(extracted_output, dict):
|
||||
runtime_var = extracted_output
|
||||
else:
|
||||
runtime_var = {"output": extracted_output}
|
||||
|
||||
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
|
||||
state_update = {
|
||||
**final_output,
|
||||
"runtime_vars": {
|
||||
self.node_id: runtime_var
|
||||
}
|
||||
}
|
||||
|
||||
# Add streaming buffer for non-End nodes
|
||||
if not is_end_node:
|
||||
state_update["streaming_buffer"] = {
|
||||
self.node_id: {
|
||||
"full_content": "".join(chunks),
|
||||
"chunk_count": chunk_count,
|
||||
"is_complete": True # Mark as complete
|
||||
}
|
||||
}
|
||||
|
||||
# Finally yield state update
|
||||
# LangGraph will merge this into state
|
||||
yield state_update
|
||||
|
||||
except TimeoutError:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
|
||||
yield {
|
||||
"type": "error",
|
||||
**self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
|
||||
}
|
||||
logger.error(f"节点 {self.node_id} 执行超时 ({timeout}s)")
|
||||
error_output = self._wrap_error(f"节点执行超时 ({timeout}s)", elapsed_time, state)
|
||||
yield error_output
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
|
||||
yield {
|
||||
"type": "error",
|
||||
**self._wrap_error(str(e), elapsed_time, state)
|
||||
}
|
||||
error_output = self._wrap_error(str(e), elapsed_time, state)
|
||||
yield error_output
|
||||
|
||||
def _wrap_output(
|
||||
self,
|
||||
|
||||
@@ -13,6 +13,7 @@ from app.core.workflow.nodes.end.config import EndNodeConfig
|
||||
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
|
||||
from app.core.workflow.nodes.agent.config import AgentNodeConfig
|
||||
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
||||
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
|
||||
|
||||
__all__ = [
|
||||
# 基础类
|
||||
@@ -26,4 +27,5 @@ __all__ = [
|
||||
"MessageConfig",
|
||||
"AgentNodeConfig",
|
||||
"TransformNodeConfig",
|
||||
"IfElseNodeConfig",
|
||||
]
|
||||
|
||||
@@ -5,7 +5,8 @@ End 节点实现
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
import re
|
||||
import asyncio
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
|
||||
@@ -16,6 +17,7 @@ class EndNode(BaseNode):
|
||||
"""End 节点
|
||||
|
||||
工作流的结束节点,根据配置的模板输出最终结果。
|
||||
支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。
|
||||
"""
|
||||
|
||||
async def execute(self, state: WorkflowState) -> str:
|
||||
@@ -31,11 +33,7 @@ class EndNode(BaseNode):
|
||||
|
||||
# 获取配置的输出模板
|
||||
output_template = self.config.get("output")
|
||||
pool = self.get_variable_pool(state)
|
||||
|
||||
print("="*20)
|
||||
print( pool.get("start.test"))
|
||||
print("="*20)
|
||||
|
||||
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||
if output_template:
|
||||
output = self._render_template(output_template, state)
|
||||
@@ -47,7 +45,228 @@ class EndNode(BaseNode):
|
||||
total_nodes = len(node_outputs)
|
||||
|
||||
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
|
||||
print("="*20)
|
||||
print(output)
|
||||
print("="*20)
|
||||
|
||||
return output
|
||||
|
||||
def _extract_referenced_nodes(self, template: str) -> list[str]:
|
||||
"""从模板中提取引用的节点 ID
|
||||
|
||||
例如:'结果:{{llm_qa.output}}' -> ['llm_qa']
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
|
||||
Returns:
|
||||
引用的节点 ID 列表
|
||||
"""
|
||||
# 匹配 {{node_id.xxx}} 格式
|
||||
pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
|
||||
matches = re.findall(pattern, template)
|
||||
return list(set(matches)) # 去重
|
||||
|
||||
def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]:
|
||||
"""解析模板,分离静态文本和动态引用
|
||||
|
||||
例如:'你好 {{llm.output}}, 这是后缀'
|
||||
返回:[
|
||||
{"type": "static", "content": "你好 "},
|
||||
{"type": "dynamic", "node_id": "llm", "field": "output"},
|
||||
{"type": "static", "content": ", 这是后缀"}
|
||||
]
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
模板部分列表
|
||||
"""
|
||||
import re
|
||||
|
||||
parts = []
|
||||
last_end = 0
|
||||
|
||||
# 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格)
|
||||
pattern = r'\{\{\s*([^}]+?)\s*\}\}'
|
||||
|
||||
for match in re.finditer(pattern, template):
|
||||
start, end = match.span()
|
||||
|
||||
# 添加前面的静态文本
|
||||
if start > last_end:
|
||||
static_text = template[last_end:start]
|
||||
if static_text:
|
||||
parts.append({"type": "static", "content": static_text})
|
||||
|
||||
# 解析动态引用
|
||||
ref = match.group(1).strip()
|
||||
|
||||
# 检查是否是节点引用(如 llm.output 或 llm_qa.output)
|
||||
if '.' in ref:
|
||||
node_id, field = ref.split('.', 1)
|
||||
parts.append({
|
||||
"type": "dynamic",
|
||||
"node_id": node_id,
|
||||
"field": field,
|
||||
"raw": ref
|
||||
})
|
||||
else:
|
||||
# 其他引用(如 {{var.xxx}}),当作静态处理
|
||||
# 直接渲染这部分
|
||||
rendered = self._render_template(f"{{{{{ref}}}}}", state)
|
||||
parts.append({"type": "static", "content": rendered})
|
||||
|
||||
last_end = end
|
||||
|
||||
# 添加最后的静态文本
|
||||
if last_end < len(template):
|
||||
static_text = template[last_end:]
|
||||
if static_text:
|
||||
parts.append({"type": "static", "content": static_text})
|
||||
|
||||
return parts
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
    """Run the End node in streaming mode.

    Output strategy:
    1. Look for the first template reference to a direct upstream node
       (a node that is the ``source`` of an edge targeting this node).
    2. If there is one, send only the template content AFTER that
       reference (the suffix) through the stream writer; the prefix and
       the referenced content are expected to have been streamed while
       the upstream node was running.
    3. Otherwise render the whole template and send it as one chunk.

    Example: ``'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'``
    with direct upstream node ``llm_qa`` -> only ``' lalalalala a'`` is
    sent from here, in a single chunk.

    Args:
        state: Workflow state.

    Yields:
        A single completion marker ``{"__final__": True, "result": ...}``
        whose ``result`` is the fully rendered template (or a fixed
        default message when no template is configured).
    """
    logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")

    template = self.config.get("output")
    if not template:
        yield {"__final__": True, "result": "工作流已完成"}
        return

    # Direct upstream nodes: sources of every edge that targets this node.
    upstream_ids = [
        edge.get("source")
        for edge in self.workflow_config.get("edges", [])
        if edge.get("target") == self.node_id
    ]
    logger.info(f"节点 {self.node_id} 的直接上游节点: {upstream_ids}")

    parts = self._parse_template_parts(template, state)
    logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")

    # Index of the first dynamic part that points at a direct upstream node.
    pivot = next(
        (
            idx
            for idx, part in enumerate(parts)
            if part["type"] == "dynamic" and part["node_id"] in upstream_ids
        ),
        None,
    )
    if pivot is not None:
        logger.info(f"节点 {self.node_id} 找到直接上游节点 {parts[pivot]['node_id']} 的引用,索引: {pivot}")

    if pivot is None:
        # Nothing was streamed upstream on our behalf: emit everything.
        rendered = self._render_template(template, state)
        logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容: '{rendered[:50]}...'")

        from langgraph.config import get_stream_writer
        emit = get_stream_writer()
        emit({
            "type": "message",  # End node output always uses the "message" type
            "node_id": self.node_id,
            "chunk": rendered,
            "full_content": rendered,
            "chunk_index": 1,
            "is_suffix": False
        })
        logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")

        yield {"__final__": True, "result": rendered}
        return

    logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {pivot + 1} 开始)")

    def resolve(part):
        """Return the text for a dynamic part found after the pivot."""
        node_id = part["node_id"]
        field = part["field"]

        # A streaming buffer entry wins; its full_content is used as-is
        # (the requested field is not consulted on this path).
        buffered = state.get("streaming_buffer", {})
        if node_id in buffered:
            return buffered[node_id].get("full_content", "")

        outputs = state.get("node_outputs", {})
        runtime = state.get("runtime_vars", {})
        if node_id in outputs:
            holder = outputs[node_id]
        elif node_id in runtime:
            holder = runtime[node_id]
        else:
            return ""
        # Non-dict holders contribute nothing.
        return str(holder.get(field, "")) if isinstance(holder, dict) else ""

    pieces = []
    for part in parts[pivot + 1:]:
        if part["type"] == "static":
            pieces.append(part["content"])
        elif part["type"] == "dynamic":
            pieces.append(resolve(part))
    suffix = "".join(pieces)

    # Full render (prefix + referenced content + suffix) for the final result.
    full_output = self._render_template(template, state)

    if suffix:
        logger.info(f"节点 {self.node_id} 输出后缀: '{suffix[:50]}...' (长度: {len(suffix)})")
        # Sent through the writer as ONE chunk; per the original author's
        # note, yielding a plain string would make base_node process it
        # character by character.
        from langgraph.config import get_stream_writer
        emit = get_stream_writer()
        emit({
            "type": "message",  # End node output always uses the "message" type
            "node_id": self.node_id,
            "chunk": suffix,
            "full_content": full_output,  # complete render, not just the suffix
            "chunk_index": 1,
            "is_suffix": True
        })
        logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}")
    else:
        logger.info(f"节点 {self.node_id} 没有后缀需要输出")

    total_nodes = len(state.get("node_outputs", {}))
    logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行了 {total_nodes} 个节点")

    # Completion marker carrying the full output.
    yield {"__final__": True, "result": full_output}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class NodeType(StrEnum):
|
||||
START = "start"
|
||||
END = "end"
|
||||
@@ -13,3 +14,23 @@ class NodeType(StrEnum):
|
||||
HTTP_REQUEST = "http-request"
|
||||
TOOL = "tool"
|
||||
AGENT = "agent"
|
||||
|
||||
|
||||
class ComparisonOperator(StrEnum):
    """String-valued comparison operators for a single condition.

    Used as the type of ``ConditionDetail.comparison_operator``. Because this
    is a ``StrEnum``, each member compares equal to its string value.
    """

    # Emptiness checks
    EMPTY = "empty"
    NOT_EMPTY = "not_empty"
    # Containment and prefix/suffix checks
    # (note the values are spelled "startwith"/"endwith", without the "s")
    CONTAINS = "contains"
    NOT_CONTAINS = "not_contains"
    START_WITH = "startwith"
    END_WITH = "endwith"
    # Equality and ordering comparisons
    EQ = "eq"
    NE = "ne"
    LT = "lt"
    LE = "le"
    GT = "gt"
    GE = "ge"
|
||||
|
||||
|
||||
class LogicOperator(StrEnum):
    """String-valued logical connectives used to combine several conditions.

    The values are the Python keywords ``and`` / ``or``.
    """

    AND = "and"
    OR = "or"
|
||||
|
||||
5
api/app/core/workflow/nodes/if_else/__init__.py
Normal file
5
api/app/core/workflow/nodes/if_else/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Condition Node"""
|
||||
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
|
||||
from app.core.workflow.nodes.if_else.node import IfElseNode
|
||||
|
||||
__all__ = ["IfElseNode", "IfElseNodeConfig"]
|
||||
97
api/app/core/workflow/nodes/if_else/config.py
Normal file
97
api/app/core/workflow/nodes/if_else/config.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Condition Configuration"""
|
||||
from pydantic import Field, BaseModel, field_validator
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
|
||||
|
||||
|
||||
class ConditionDetail(BaseModel):
    """A single comparison: ``left <comparison_operator> right``.

    All three fields are required. ``left`` and ``right`` are plain strings;
    string literals are written with their own quotes (e.g. ``"'123'"``), as
    in the schema example on ``IfElseNodeConfig``.
    """

    comparison_operator: ComparisonOperator = Field(
        ...,
        description="Comparison operator used to evaluate the condition"
    )

    # Left-hand operand of the comparison.
    left: str = Field(
        ...,
        description="Value to compare against"
    )

    # Right-hand operand of the comparison.
    right: str = Field(
        ...,
        description="Value to compare with"
    )
|
||||
|
||||
|
||||
class ConditionBranchConfig(BaseModel):
    """Configuration for a conditional branch.

    A branch is a list of conditions joined by one logical operator
    (``and`` when not specified).
    """

    # Use the enum member itself as the default: pydantic does not validate
    # defaults, so a raw "and" string would leave the attribute as a plain
    # str instead of a LogicOperator when the field is omitted.
    logical_operator: LogicOperator = Field(
        default=LogicOperator.AND,
        description="Logical operator used to combine multiple condition expressions"
    )

    conditions: list[ConditionDetail] = Field(
        ...,
        description="List of condition expressions within this branch"
    )
|
||||
|
||||
|
||||
class IfElseNodeConfig(BaseNodeConfig):
    """Configuration of an If-Else node: an ordered list of case branches.

    At least one case is required. No explicit ELSE branch is configured;
    ``IfElseNode`` appends an always-true fallback after the listed cases.
    """

    cases: list[ConditionBranchConfig] = Field(
        ...,
        description="List of branch conditions or expressions"
    )

    @field_validator("cases")
    @classmethod
    def validate_case_number(cls, v, info):
        """Reject an empty ``cases`` list.

        Raises:
            ValueError: If no case is configured.
        """
        if not v:
            raise ValueError("At least one case is required")
        return v

    class Config:
        # The example mirrors the declared schema: each case's "conditions"
        # is a flat list of condition objects (list[ConditionDetail]).
        json_schema_extra = {
            "examples": [
                {
                    "cases": [
                        # CASE1 / IF branch
                        {
                            "logical_operator": "and",
                            "conditions": [
                                {
                                    "left": "node.userinput.message",
                                    "comparison_operator": "eq",
                                    "right": "'123'"
                                },
                                {
                                    "left": "node.userinput.test",
                                    "comparison_operator": "eq",
                                    "right": "True"
                                }
                            ]
                        },
                        # CASE2 / ELIF branch
                        {
                            "logical_operator": "or",
                            "conditions": [
                                {
                                    "left": "node.userinput.test",
                                    "comparison_operator": "eq",
                                    "right": "False"
                                },
                                {
                                    "left": "node.userinput.message",
                                    "comparison_operator": "contains",
                                    "right": "'123'"
                                }
                            ]
                        }
                        # CASE3 / ELSE branch: implicit, not configured here
                    ]
                }
            ]
        }
|
||||
167
api/app/core/workflow/nodes/if_else/node.py
Normal file
167
api/app/core/workflow/nodes/if_else/node.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator
|
||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||
from app.core.workflow.nodes.if_else.config import ConditionDetail
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConditionExpressionBuilder:
    """
    Turn ``(left, operator, right)`` into a Python boolean expression string.

    Nothing is evaluated here: the class only produces text such as
    ``"x == '1'"`` or ``"'1' in x"`` for later evaluation elsewhere.

    NOTE(review): ``left`` and ``right`` are interpolated verbatim, so the
    result is only as trustworthy as those two strings - confirm how the
    expression is evaluated downstream.
    """

    def __init__(self, left: str, operator: ComparisonOperator, right: str):
        self.left = left
        self.operator = operator
        self.right = right

    # --- one renderer per operator -------------------------------------

    def _empty(self):
        return "{} == ''".format(self.left)

    def _not_empty(self):
        return "{} != ''".format(self.left)

    def _contains(self):
        # Membership reads "<right> in <left>": left is the container.
        return "{} in {}".format(self.right, self.left)

    def _not_contains(self):
        return "{} not in {}".format(self.right, self.left)

    def _startwith(self):
        return "{}.startswith({})".format(self.left, self.right)

    def _endwith(self):
        return "{}.endswith({})".format(self.left, self.right)

    def _eq(self):
        return "{} == {}".format(self.left, self.right)

    def _ne(self):
        return "{} != {}".format(self.left, self.right)

    def _lt(self):
        return "{} < {}".format(self.left, self.right)

    def _le(self):
        return "{} <= {}".format(self.left, self.right)

    def _gt(self):
        return "{} > {}".format(self.left, self.right)

    def _ge(self):
        return "{} >= {}".format(self.left, self.right)

    def build(self):
        """Return the expression string for the configured operator.

        Raises:
            ValueError: If the operator is not a known ComparisonOperator.
        """
        renderers = (
            (ComparisonOperator.EMPTY, self._empty),
            (ComparisonOperator.NOT_EMPTY, self._not_empty),
            (ComparisonOperator.CONTAINS, self._contains),
            (ComparisonOperator.NOT_CONTAINS, self._not_contains),
            (ComparisonOperator.START_WITH, self._startwith),
            (ComparisonOperator.END_WITH, self._endwith),
            (ComparisonOperator.EQ, self._eq),
            (ComparisonOperator.NE, self._ne),
            (ComparisonOperator.LT, self._lt),
            (ComparisonOperator.LE, self._le),
            (ComparisonOperator.GT, self._gt),
            (ComparisonOperator.GE, self._ge),
        )
        # Compared with ==, in declaration order; first hit wins.
        for candidate, render in renderers:
            if self.operator == candidate:
                return render()
        raise ValueError(f"Invalid condition: {self.operator}")
|
||||
|
||||
|
||||
class IfElseNode(BaseNode):
    """If-Else node: selects a branch label (``CASE1``, ``CASE2``, ...)."""

    def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
        super().__init__(node_config, workflow_config)
        # Validate the raw node config once, up front.
        self.typed_config = IfElseNodeConfig(**self.config)

    @staticmethod
    def _build_condition_expression(
            condition: ConditionDetail,
    ) -> str:
        """
        Build a single boolean condition expression string.

        This method does NOT evaluate the condition.
        It only generates a Python boolean expression string
        (e.g. "x > 10", "'a' in name").

        Args:
            condition (ConditionDetail): Definition of a single comparison condition.

        Returns:
            str: A Python boolean expression string.
        """
        return ConditionExpressionBuilder(
            left=condition.left,
            operator=condition.comparison_operator,
            right=condition.right
        ).build()

    def build_conditional_edge_expressions(self) -> list[str]:
        """
        Build one boolean expression string per case branch.

        Nothing is evaluated here. Each branch's conditions are joined with
        the branch's logical operator, and a final ``"True"`` expression is
        appended as the default branch, so the result always has
        ``len(cases) + 1`` entries.

        Returns:
            list[str]: Python boolean expression strings, ordered by
            branch priority.

        Raises:
            ValueError: If a branch has no conditions.
        """
        conditions = []

        for branch_index, case_branch in enumerate(self.typed_config.cases, start=1):
            branch_conditions = [
                self._build_condition_expression(condition)
                for condition in case_branch.conditions
            ]
            if not branch_conditions:
                # Previously surfaced as a bare IndexError with no context.
                raise ValueError(f"CASE{branch_index} has no conditions")

            # Joining a single-element list returns that element unchanged.
            conditions.append(
                f' {case_branch.logical_operator} '.join(branch_conditions)
            )

        # Default fallback branch
        conditions.append("True")

        return conditions

    async def execute(self, state: WorkflowState) -> Any:
        """
        Return the label of the first branch whose expression holds.

        Expressions are checked in order via ``self._evaluate_condition``
        (not defined in this class; presumably provided by BaseNode).

        Args:
            state: Workflow state passed through to the evaluator.

        Returns:
            ``'CASE<n>'`` with a 1-based index; if nothing matches, the
            label of the last (fallback) expression.
        """
        expressions = self.build_conditional_edge_expressions()
        for case_number, expression in enumerate(expressions, start=1):
            logger.info(expression)
            if self._evaluate_condition(expression, state):
                return f'CASE{case_number}'
        return f'CASE{len(expressions)}'
|
||||
@@ -10,10 +10,8 @@ from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models import ModelConfig
|
||||
from app.db import get_db, get_db_context
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.services.model_service import ModelConfigService, ModelApiKeyService
|
||||
from app.db import get_db_context
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
@@ -65,7 +63,7 @@ class LLMNode(BaseNode):
|
||||
- ai/assistant: AI 消息(AIMessage)
|
||||
"""
|
||||
|
||||
def _prepare_llm(self, state: WorkflowState) -> tuple[RedBearLLM, list | str]:
|
||||
def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]:
|
||||
"""准备 LLM 实例(公共逻辑)
|
||||
|
||||
Args:
|
||||
@@ -127,16 +125,22 @@ class LLMNode(BaseNode):
|
||||
model_type = config.type
|
||||
|
||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
|
||||
extra_params = {"streaming": stream} if stream else {}
|
||||
|
||||
llm = RedBearLLM(
|
||||
RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base
|
||||
base_url=api_base,
|
||||
extra_params=extra_params
|
||||
),
|
||||
type=model_type
|
||||
)
|
||||
|
||||
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
|
||||
|
||||
return llm, prompt_or_messages
|
||||
|
||||
async def execute(self, state: WorkflowState) -> AIMessage:
|
||||
@@ -148,13 +152,12 @@ class LLMNode(BaseNode):
|
||||
Returns:
|
||||
LLM 响应消息
|
||||
"""
|
||||
llm, prompt_or_messages = self._prepare_llm(state)
|
||||
llm, prompt_or_messages = self._prepare_llm(state,True)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||
|
||||
# 调用 LLM(支持字符串或消息列表)
|
||||
response = await llm.ainvoke(prompt_or_messages)
|
||||
|
||||
# 提取内容
|
||||
if hasattr(response, 'content'):
|
||||
content = response.content
|
||||
@@ -210,13 +213,43 @@ class LLMNode(BaseNode):
|
||||
Yields:
|
||||
文本片段(chunk)或完成标记
|
||||
"""
|
||||
llm, prompt_or_messages = self._prepare_llm(state)
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
||||
|
||||
# 检查是否有注入的 End 节点前缀配置
|
||||
writer = get_stream_writer()
|
||||
end_prefix = getattr(self, '_end_node_prefix', None)
|
||||
|
||||
logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}")
|
||||
if end_prefix:
|
||||
logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'")
|
||||
|
||||
if end_prefix:
|
||||
# 渲染前缀(可能包含其他变量)
|
||||
try:
|
||||
rendered_prefix = self._render_template(end_prefix, state)
|
||||
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
|
||||
|
||||
# 提前发送 End 节点的前缀(使用 "message" 类型)
|
||||
writer({
|
||||
"type": "message", # End 相关的内容都是 message 类型
|
||||
"node_id": "end", # 标记为 end 节点的输出
|
||||
"chunk": rendered_prefix,
|
||||
"full_content": rendered_prefix,
|
||||
"chunk_index": 0,
|
||||
"is_prefix": True # 标记这是前缀
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"渲染/发送 End 节点前缀失败: {e}")
|
||||
|
||||
# 累积完整响应
|
||||
full_response = ""
|
||||
last_chunk = None
|
||||
chunk_count = 0
|
||||
|
||||
# 调用 LLM(流式,支持字符串或消息列表)
|
||||
async for chunk in llm.astream(prompt_or_messages):
|
||||
@@ -226,13 +259,16 @@ class LLMNode(BaseNode):
|
||||
else:
|
||||
content = str(chunk)
|
||||
|
||||
full_response += content
|
||||
last_chunk = chunk
|
||||
|
||||
# 流式返回每个文本片段
|
||||
yield content
|
||||
# 只有当内容不为空时才处理
|
||||
if content:
|
||||
full_response += content
|
||||
last_chunk = chunk
|
||||
chunk_count += 1
|
||||
|
||||
# 流式返回每个文本片段
|
||||
yield content
|
||||
|
||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}")
|
||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
||||
|
||||
# 构建完整的 AIMessage(包含元数据)
|
||||
if isinstance(last_chunk, AIMessage):
|
||||
|
||||
@@ -5,18 +5,29 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, Union
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.agent import AgentNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.end import EndNode
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.nodes.if_else import IfElseNode
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
WorkflowNode = Union[
|
||||
BaseNode,
|
||||
StartNode,
|
||||
EndNode,
|
||||
LLMNode,
|
||||
IfElseNode,
|
||||
AgentNode,
|
||||
TransformNode,
|
||||
]
|
||||
|
||||
|
||||
class NodeFactory:
|
||||
"""节点工厂
|
||||
@@ -25,16 +36,17 @@ class NodeFactory:
|
||||
"""
|
||||
|
||||
# 节点类型注册表
|
||||
_node_types: dict[str, type[BaseNode]] = {
|
||||
_node_types: dict[str, type[WorkflowNode]] = {
|
||||
NodeType.START: StartNode,
|
||||
NodeType.END: EndNode,
|
||||
NodeType.LLM: LLMNode,
|
||||
NodeType.AGENT: AgentNode,
|
||||
NodeType.TRANSFORM: TransformNode,
|
||||
NodeType.IF_ELSE: IfElseNode
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register_node_type(cls, node_type: str, node_class: type[BaseNode]):
|
||||
def register_node_type(cls, node_type: str, node_class: type[WorkflowNode]):
|
||||
"""注册新的节点类型
|
||||
|
||||
Args:
|
||||
@@ -52,10 +64,10 @@ class NodeFactory:
|
||||
|
||||
@classmethod
|
||||
def create_node(
|
||||
cls,
|
||||
node_config: dict[str, Any],
|
||||
workflow_config: dict[str, Any]
|
||||
) -> BaseNode | None:
|
||||
cls,
|
||||
node_config: dict[str, Any],
|
||||
workflow_config: dict[str, Any]
|
||||
) -> WorkflowNode | None:
|
||||
"""创建节点实例
|
||||
|
||||
Args:
|
||||
|
||||
@@ -20,6 +20,11 @@ from .data_config_model import DataConfig
|
||||
from .multi_agent_model import MultiAgentConfig, AgentInvocation
|
||||
from .workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution
|
||||
from .retrieval_info import RetrievalInfo
|
||||
from .prompt_optimizer_model import PromptOptimizerSession, PromptOptimizerSessionHistory
|
||||
from .tool_model import (
|
||||
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
|
||||
ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Tenants",
|
||||
@@ -54,5 +59,17 @@ __all__ = [
|
||||
"WorkflowConfig",
|
||||
"WorkflowExecution",
|
||||
"WorkflowNodeExecution",
|
||||
"RetrievalInfo"
|
||||
"RetrievalInfo",
|
||||
"PromptOptimizerSession",
|
||||
"PromptOptimizerSessionHistory",
|
||||
"RetrievalInfo",
|
||||
"ToolConfig",
|
||||
"BuiltinToolConfig",
|
||||
"CustomToolConfig",
|
||||
"MCPToolConfig",
|
||||
"ToolExecution",
|
||||
"ToolType",
|
||||
"ToolStatus",
|
||||
"AuthType",
|
||||
"ExecutionStatus"
|
||||
]
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from app.db import Base
|
||||
@@ -11,50 +10,53 @@ class DataConfig(Base):
|
||||
|
||||
# 主键
|
||||
config_id = Column(Integer, primary_key=True, autoincrement=True, comment="配置ID")
|
||||
|
||||
|
||||
# 基本信息
|
||||
config_name = Column(String, nullable=False, comment="配置名称")
|
||||
config_desc = Column(String, nullable=True, comment="配置描述")
|
||||
|
||||
|
||||
# 组织信息
|
||||
workspace_id = Column(UUID(as_uuid=True), nullable=True, comment="工作空间ID")
|
||||
group_id = Column(String, nullable=True, comment="组ID")
|
||||
user_id = Column(String, nullable=True, comment="用户ID")
|
||||
apply_id = Column(String, nullable=True, comment="应用ID")
|
||||
|
||||
|
||||
# 模型选择(从workspace继承)
|
||||
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
|
||||
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
|
||||
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
|
||||
llm = Column(String, nullable=True, comment="LLM模型配置ID")
|
||||
|
||||
|
||||
# 记忆萃取引擎配置
|
||||
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
|
||||
enable_llm_disambiguation = Column(Boolean, default=True, comment="启用LLM决策消歧")
|
||||
deep_retrieval = Column(Boolean, default=True, comment="深度检索开关")
|
||||
|
||||
|
||||
# 阈值配置 (0-1 之间的浮点数)
|
||||
t_type_strict = Column(Float, default=0.8, comment="类型严格阈值")
|
||||
t_name_strict = Column(Float, default=0.8, comment="名称严格阈值")
|
||||
t_overall = Column(Float, default=0.8, comment="综合阈值")
|
||||
|
||||
|
||||
# 状态配置
|
||||
state = Column(Boolean, default=False, comment="配置使用状态")
|
||||
|
||||
|
||||
# 分块策略
|
||||
chunker_strategy = Column(String, default="RecursiveChunker", comment="分块策略")
|
||||
|
||||
|
||||
# 剪枝配置
|
||||
pruning_enabled = Column(Boolean, default=False, comment="是否启动智能语义剪枝")
|
||||
pruning_scene = Column(String, nullable=True, comment="智能剪枝场景:education/online_service/outbound")
|
||||
pruning_threshold = Column(Float, nullable=True, comment="智能语义剪枝阈值(0-0.9)")
|
||||
|
||||
|
||||
# 自我反思配置
|
||||
enable_self_reflexion = Column(Boolean, default=False, comment="是否启用自我反思")
|
||||
iteration_period = Column(String, default="3", comment="反思迭代周期")
|
||||
reflexion_range = Column(String, default="retrieval", comment="反思范围:部分/全部")
|
||||
baseline = Column(String, default="time", comment="基线:时间/事实/时间和事实")
|
||||
|
||||
reflection_model_id = Column(String, nullable=True, comment="反思模型ID")
|
||||
memory_verify = Column(Boolean, default=True, comment="记忆验证")
|
||||
quality_assessment = Column(Boolean, default=True, comment="质量评估")
|
||||
|
||||
# 遗忘引擎配置
|
||||
statement_granularity = Column(Integer, default=2, comment="陈述提取颗粒度,挡位 1/2/3")
|
||||
include_dialogue_context = Column(Boolean, default=False, comment="是否包含对话上下文")
|
||||
@@ -63,6 +65,13 @@ class DataConfig(Base):
|
||||
lambda_mem = Column("lambda_mem", Float, default=0.5, comment="遗忘率,0-1 小数")
|
||||
offset = Column("offset", Float, default=0.0, comment="偏移度,0-1 小数")
|
||||
|
||||
# 情绪引擎配置
|
||||
emotion_enabled = Column(Boolean, default=True, comment="是否启用情绪提取")
|
||||
emotion_model_id = Column(String, nullable=True, comment="情绪分析专用模型ID")
|
||||
emotion_extract_keywords = Column(Boolean, default=True, comment="是否提取情绪关键词")
|
||||
emotion_min_intensity = Column(Float, default=0.1, comment="最小情绪强度阈值")
|
||||
emotion_enable_subject = Column(Boolean, default=True, comment="是否启用主体分类")
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
|
||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
|
||||
|
||||
@@ -16,7 +16,26 @@ class Document(Base):
|
||||
file_size = Column(Integer, default=0, comment="file size(byte)")
|
||||
file_meta = Column(JSON, nullable=False, default={})
|
||||
parser_id = Column(String, index=True, nullable=False, comment="default parser ID")
|
||||
parser_config = Column(JSON, nullable=False, default={"layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n"}, comment="default parser config")
|
||||
parser_config = Column(JSON, nullable=False,
|
||||
default={
|
||||
"layout_recognize": "DeepDOC",
|
||||
"chunk_token_num": 128,
|
||||
"delimiter": "\n",
|
||||
"auto_keywords": 0,
|
||||
"auto_questions": 0,
|
||||
"html4excel": False,
|
||||
"graphrag": {
|
||||
"use_graphrag": False,
|
||||
"entity_types": [
|
||||
"organization",
|
||||
"person",
|
||||
"geo",
|
||||
"event",
|
||||
"category",
|
||||
],
|
||||
"method": "general",
|
||||
}
|
||||
}, comment="default parser config")
|
||||
chunk_num = Column(Integer, default=0, comment="chunk num")
|
||||
progress = Column(Float, default=0)
|
||||
progress_msg = Column(String, default="", comment="process message")
|
||||
|
||||
@@ -14,6 +14,7 @@ class EndUser(Base):
|
||||
other_id = Column(String, nullable=True) # Store original user_id
|
||||
other_name = Column(String, default="", nullable=False)
|
||||
other_address = Column(String, default="", nullable=False)
|
||||
reflection_time = Column(DateTime, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.datetime.now)
|
||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
||||
|
||||
|
||||
@@ -56,7 +56,25 @@ class Knowledge(Base):
|
||||
chunk_num = Column(Integer, default=0, comment="chunk num")
|
||||
parser_id = Column(String, index=True, default="naive", comment="default parser ID")
|
||||
parser_config = Column(JSON, nullable=False,
|
||||
default={"layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n"},
|
||||
default={
|
||||
"layout_recognize": "DeepDOC",
|
||||
"chunk_token_num": 128,
|
||||
"delimiter": "\n",
|
||||
"auto_keywords": 0,
|
||||
"auto_questions": 0,
|
||||
"html4excel": False,
|
||||
"graphrag": {
|
||||
"use_graphrag": False,
|
||||
"entity_types": [
|
||||
"organization",
|
||||
"person",
|
||||
"geo",
|
||||
"event",
|
||||
"category",
|
||||
],
|
||||
"method": "general",
|
||||
}
|
||||
},
|
||||
comment="default parser config")
|
||||
status = Column(Integer, index=True, default=1, comment="is it validate(0: disable, 1: enable, 2:Soft-delete)")
|
||||
created_at = Column(DateTime, default=datetime.datetime.now)
|
||||
|
||||
@@ -15,6 +15,25 @@ class ModelType(StrEnum):
|
||||
EMBEDDING = "embedding"
|
||||
RERANK = "rerank"
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, value: str) -> "ModelType":
|
||||
"""
|
||||
Get a ModelType enum instance from a string value.
|
||||
|
||||
Args:
|
||||
value (str): The string representation of the model type.
|
||||
|
||||
Returns:
|
||||
ModelType: The corresponding ModelType enum object.
|
||||
|
||||
Raises:
|
||||
ValueError: If the given value does not match any ModelType.
|
||||
"""
|
||||
try:
|
||||
return cls(value)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid ModelType: {value}")
|
||||
|
||||
|
||||
class ModelProvider(StrEnum):
|
||||
"""模型提供商枚举"""
|
||||
|
||||
130
api/app/models/prompt_optimizer_model.py
Normal file
130
api/app/models/prompt_optimizer_model.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlalchemy import Column, ForeignKey, Text, DateTime, String, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from app.db import Base
|
||||
|
||||
|
||||
class RoleType(StrEnum):
|
||||
"""
|
||||
Enumeration of message roles used in prompt optimization conversations.
|
||||
|
||||
This enum standardizes the role identifiers for messages stored in the
|
||||
prompt optimization session history, ensuring consistency across
|
||||
system-generated messages, user inputs, and assistant responses.
|
||||
|
||||
Attributes:
|
||||
SYSTEM (str): Represents system-level instructions or prompts that
|
||||
define the behavior or constraints of the assistant.
|
||||
USER (str): Represents messages originating from the end user.
|
||||
ASSISTANT (str): Represents messages generated by the AI assistant.
|
||||
"""
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
|
||||
|
||||
class PromptOptimizerSession(Base):
    """One prompt optimization session (table ``prompt_opt_session_list``).

    The row holds only session-level metadata; no message text is stored
    here. Message rows are kept in ``prompt_opt_session_history``.

    Columns:
        id (UUID): Primary key, generated with ``uuid.uuid4``; indexed.
        tenant_id (UUID): Foreign key to ``tenants.id``; required.
        user_id (UUID): Foreign key to ``users.id``; required.
        created_at (DateTime): Filled with ``datetime.datetime.now`` on
            insert; indexed.
    """

    __tablename__ = "prompt_opt_session_list"

    id = Column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        index=True,
        comment="Session ID",
    )
    tenant_id = Column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id"),
        nullable=False,
        comment="Tenant ID",
    )
    # Disabled column, kept for reference:
    # app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False, comment="Application ID")
    user_id = Column(
        UUID(as_uuid=True),
        ForeignKey("users.id"),
        nullable=False,
        comment="User ID",
    )

    created_at = Column(
        DateTime,
        default=datetime.datetime.now,
        comment="Creation Time",
        index=True,
    )
|
||||
|
||||
|
||||
class PromptOptimizerSessionHistory(Base):
    """Message history for prompt optimization sessions.

    Each row is a single message that belongs to a session stored in
    ``prompt_opt_session_list``.

    Table Name:
        prompt_opt_session_history

    Columns:
        id (UUID):
            Primary key, generated with ``uuid.uuid4``.
        tenant_id (UUID):
            Foreign key referencing ``tenants.id``; required.
        session_id (UUID):
            Foreign key referencing ``prompt_opt_session_list.id``; required.
        user_id (UUID):
            Foreign key referencing ``users.id``; required.
        role (String):
            Role of the message sender; required. Presumably one of the
            ``RoleType`` values -- the column itself does not enforce this.
        content (Text):
            Message content; required.
        created_at (DateTime):
            Filled with ``datetime.datetime.now`` on insert; indexed.

    Indexes:
        ``ix_prompt_opt_session_history_session_user_created`` covers
        (session_id, user_id, created_at).

    Notes:
        The ``prompt`` snapshot column is commented out below, so it is not
        part of the mapped table.
    """

    __tablename__ = "prompt_opt_session_history"

    __table_args__ = (
        Index(
            "ix_prompt_opt_session_history_session_user_created",
            "session_id",
            "user_id",
            "created_at"
        ),
    )

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
    tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, comment="Tenant ID")
    # app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False, comment="Application ID")
    session_id = Column(UUID(as_uuid=True), ForeignKey("prompt_opt_session_list.id"), nullable=False, comment="Session ID")
    user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, comment="User ID")
    role = Column(String, nullable=False, comment="Message Role")
    content = Column(Text, nullable=False, comment="Message Content")
    # prompt = Column(Text, nullable=False, comment="Prompt")

    created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation Time", index=True)
|
||||
@@ -21,3 +21,6 @@ class Tenants(Base):
|
||||
|
||||
# Relationship to workspaces owned by the tenant
|
||||
owned_workspaces = relationship("Workspace", back_populates="tenant")
|
||||
|
||||
# Relationship to tool configs owned by the tenant
|
||||
tool_configs = relationship("ToolConfig", back_populates="tenant")
|
||||
|
||||
226
api/app/models/tool_model.py
Normal file
226
api/app/models/tool_model.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""工具管理相关数据模型"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.db import Base
|
||||
|
||||
|
||||
class ToolType(StrEnum):
|
||||
"""工具类型枚举"""
|
||||
BUILTIN = "builtin"
|
||||
CUSTOM = "custom"
|
||||
MCP = "mcp"
|
||||
|
||||
|
||||
class ToolStatus(StrEnum):
|
||||
"""工具状态枚举"""
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
ERROR = "error"
|
||||
LOADING = "loading"
|
||||
|
||||
|
||||
class AuthType(StrEnum):
|
||||
"""认证类型枚举"""
|
||||
NONE = "none"
|
||||
API_KEY = "api_key"
|
||||
BEARER_TOKEN = "bearer_token"
|
||||
|
||||
|
||||
class ExecutionStatus(StrEnum):
|
||||
"""执行状态枚举"""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
TIMEOUT = "timeout"
|
||||
|
||||
|
||||
class ToolConfig(Base):
    """Base configuration row for a tool (table ``tool_configs``).

    Holds the fields common to every tool: identity, owning tenant, type,
    status, a JSON configuration blob, version/tags metadata and timestamps.
    Deleting a ToolConfig cascades to its ``executions``.
    """

    __tablename__ = "tool_configs"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    name = Column(String(255), nullable=False, index=True)
    description = Column(Text)
    tool_type = Column(String(50), nullable=False, index=True)
    # A tool must belong to a tenant.
    tenant_id = Column(
        UUID(as_uuid=True),
        ForeignKey("tenants.id"),
        nullable=False,
        index=True,
    )
    # Tool status; new rows default to ToolStatus.INACTIVE.
    status = Column(
        String(50),
        default=ToolStatus.INACTIVE.value,
        nullable=False,
        index=True,
    )

    # Tool-specific configuration, stored as JSON.
    config_data = Column(JSON, default=dict)

    # Metadata.
    version = Column(String(50), default="1.0.0")
    tags = Column(JSON, default=list)  # list of tags

    # Timestamps.
    created_at = Column(DateTime, default=datetime.now, nullable=False)
    updated_at = Column(
        DateTime,
        default=datetime.now,
        onupdate=datetime.now,
        nullable=False,
    )

    # Relationships.
    tenant = relationship("Tenants", back_populates="tool_configs")
    executions = relationship(
        "ToolExecution",
        back_populates="tool_config",
        cascade="all, delete-orphan",
    )

    def __repr__(self):
        return f"<ToolConfig(id={self.id}, name={self.name}, type={self.tool_type}, status={self.status})>"
|
||||
|
||||
|
||||
class BuiltinToolConfig(Base):
    """Configuration specific to built-in tools (table ``builtin_tool_configs``).

    The primary key is also a foreign key to ``tool_configs.id``, so each
    row extends exactly one base ToolConfig.
    """

    __tablename__ = "builtin_tool_configs"

    id = Column(
        UUID(as_uuid=True),
        ForeignKey("tool_configs.id"),
        primary_key=True,
    )
    tool_class = Column(String(255), nullable=False)  # tool class name
    parameters = Column(JSON, default=dict)  # tool parameter settings

    # Relationship to the shared base row.
    base_config = relationship("ToolConfig", foreign_keys=[id])

    def __repr__(self):
        return f"<BuiltinToolConfig(id={self.id}, tool_class={self.tool_class})>"
|
||||
|
||||
|
||||
class CustomToolConfig(Base):
    """Configuration specific to custom tools (table ``custom_tool_configs``).

    The primary key is also a foreign key to ``tool_configs.id``, so each
    row extends exactly one base ToolConfig.
    """

    __tablename__ = "custom_tool_configs"

    id = Column(
        UUID(as_uuid=True),
        ForeignKey("tool_configs.id"),
        primary_key=True,
    )
    schema_url = Column(String(1000))  # OpenAPI schema URL
    schema_content = Column(JSON)  # OpenAPI schema content

    # Authentication settings.
    auth_type = Column(
        String(50),
        default=AuthType.NONE.value,
        nullable=False,
    )
    # Per the original note this is stored encrypted; nothing in this model
    # performs the encryption.
    auth_config = Column(JSON, default=dict)

    # API settings.
    base_url = Column(String(1000))  # API base URL
    timeout = Column(Integer, default=30)  # timeout in seconds

    # Relationship to the shared base row.
    base_config = relationship("ToolConfig", foreign_keys=[id])

    def __repr__(self):
        return f"<CustomToolConfig(id={self.id}, auth_type={self.auth_type})>"
|
||||
|
||||
|
||||
class MCPToolConfig(Base):
    """Configuration specific to MCP tools (table ``mcp_tool_configs``).

    The primary key is also a foreign key to ``tool_configs.id``, so each
    row extends exactly one base ToolConfig.
    """

    __tablename__ = "mcp_tool_configs"

    id = Column(
        UUID(as_uuid=True),
        ForeignKey("tool_configs.id"),
        primary_key=True,
    )
    server_url = Column(String(1000), nullable=False)  # MCP server URL
    connection_config = Column(JSON, default=dict)  # connection settings

    # Service status.
    last_health_check = Column(DateTime)
    health_status = Column(String(50), default="unknown")
    error_message = Column(Text)

    # List of available tools.
    available_tools = Column(JSON, default=list)

    # Relationship to the shared base row.
    base_config = relationship("ToolConfig", foreign_keys=[id])

    def __repr__(self):
        return f"<MCPToolConfig(id={self.id}, server_url={self.server_url})>"
|
||||
|
||||
|
||||
class ToolExecution(Base):
    """Record of a single tool execution (table ``tool_executions``).

    Links to the executed ToolConfig, and stores input/output payloads,
    timing data, optional token usage, and the user and workspace involved.
    """

    __tablename__ = "tool_executions"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tool_config_id = Column(
        UUID(as_uuid=True),
        ForeignKey("tool_configs.id"),
        nullable=False,
        index=True,
    )

    # Execution info.
    # Execution ID (can be used to correlate with workflows, etc.).
    execution_id = Column(String(255), nullable=False, index=True)
    status = Column(
        String(50),
        default=ExecutionStatus.PENDING.value,
        nullable=False,
        index=True,
    )

    # Input / output.
    input_data = Column(JSON)  # input parameters
    output_data = Column(JSON)  # output result
    error_message = Column(Text)  # error details

    # Performance metrics.
    started_at = Column(DateTime, nullable=False, index=True)
    completed_at = Column(DateTime)
    execution_time = Column(Float)  # execution time in seconds

    # Token usage, when applicable.
    token_usage = Column(JSON)

    # User info.
    user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), index=True)
    workspace_id = Column(
        UUID(as_uuid=True),
        ForeignKey("workspaces.id"),
        nullable=False,
        index=True,
    )

    # Relationships.
    tool_config = relationship("ToolConfig", back_populates="executions")
    user = relationship("User")
    workspace = relationship("Workspace")

    def __repr__(self):
        return f"<ToolExecution(id={self.id}, status={self.status}, tool={self.tool_config_id})>"
|
||||
|
||||
|
||||
# class ToolDependency(Base):
|
||||
# """工具依赖关系模型"""
|
||||
# __tablename__ = "tool_dependencies"
|
||||
#
|
||||
# id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# tool_id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), nullable=False)
|
||||
# depends_on_tool_id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), nullable=False)
|
||||
#
|
||||
# # 依赖类型和版本要求
|
||||
# dependency_type = Column(String(50), default="required") # required, optional
|
||||
# version_constraint = Column(String(100)) # 版本约束,如 ">=1.0.0"
|
||||
#
|
||||
# # 时间戳
|
||||
# created_at = Column(DateTime, default=datetime.now, nullable=False)
|
||||
#
|
||||
# # 关联关系
|
||||
# tool = relationship("ToolConfig", foreign_keys=[tool_id])
|
||||
# depends_on_tool = relationship("ToolConfig", foreign_keys=[depends_on_tool_id])
|
||||
#
|
||||
# def __repr__(self):
|
||||
# return f"<ToolDependency(tool={self.tool_id}, depends_on={self.depends_on_tool_id})>"
|
||||
|
||||
|
||||
# class PluginConfig(Base):
|
||||
# """插件配置模型"""
|
||||
# __tablename__ = "plugin_configs"
|
||||
#
|
||||
# id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
# name = Column(String(255), nullable=False, unique=True, index=True)
|
||||
# description = Column(Text)
|
||||
#
|
||||
# # 插件信息
|
||||
# plugin_path = Column(String(1000), nullable=False) # 插件文件路径
|
||||
# entry_point = Column(String(255), nullable=False) # 入口点
|
||||
# version = Column(String(50), default="1.0.0")
|
||||
#
|
||||
# # 状态
|
||||
# is_enabled = Column(Boolean, default=True, nullable=False)
|
||||
# is_loaded = Column(Boolean, default=False, nullable=False)
|
||||
# load_error = Column(Text) # 加载错误信息
|
||||
#
|
||||
# # 配置
|
||||
# config_schema = Column(JSON) # 配置schema
|
||||
# config_data = Column(JSON, default=dict) # 配置数据
|
||||
#
|
||||
# # 依赖
|
||||
# dependencies = Column(JSON, default=list) # 依赖的其他插件
|
||||
#
|
||||
# # 时间戳
|
||||
# created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||
# updated_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
|
||||
# last_loaded_at = Column(DateTime)
|
||||
#
|
||||
# def __repr__(self):
|
||||
# return f"<PluginConfig(id={self.id}, name={self.name}, version={self.version})>"
|
||||
@@ -1,7 +1,7 @@
|
||||
import datetime
|
||||
from enum import StrEnum
|
||||
import uuid
|
||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, Boolean
|
||||
from sqlalchemy import Column, String, DateTime, ForeignKey, Boolean
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db import Base
|
||||
|
||||
@@ -16,7 +16,6 @@ from app.models.data_config_model import DataConfig
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
@@ -29,37 +28,37 @@ db_logger = get_db_logger()
|
||||
# 获取配置专用日志器
|
||||
config_logger = get_config_logger()
|
||||
|
||||
|
||||
TABLE_NAME = "data_config"
|
||||
class DataConfigRepository:
|
||||
"""数据配置Repository
|
||||
|
||||
|
||||
提供data_config表的数据访问方法,包括:
|
||||
- SQLAlchemy ORM 数据库操作
|
||||
- Neo4j Cypher查询常量
|
||||
"""
|
||||
|
||||
|
||||
# ==================== Neo4j Cypher 查询常量 ====================
|
||||
|
||||
|
||||
# Dialogue count by group
|
||||
SEARCH_FOR_DIALOGUE = """
|
||||
MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
|
||||
"""
|
||||
|
||||
|
||||
# Chunk count by group
|
||||
SEARCH_FOR_CHUNK = """
|
||||
MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
|
||||
"""
|
||||
|
||||
|
||||
# Statement count by group
|
||||
SEARCH_FOR_STATEMENT = """
|
||||
MATCH (n:Statement) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
|
||||
"""
|
||||
|
||||
|
||||
# ExtractedEntity count by group
|
||||
SEARCH_FOR_ENTITY = """
|
||||
MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
|
||||
"""
|
||||
|
||||
|
||||
# All counts by label and total
|
||||
SEARCH_FOR_ALL = """
|
||||
OPTIONAL MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN 'Dialogue' AS Label, COUNT(n) AS Count
|
||||
@@ -72,7 +71,7 @@ class DataConfigRepository:
|
||||
UNION ALL
|
||||
OPTIONAL MATCH (n) WHERE n.group_id = $group_id RETURN 'ALL' AS Label, COUNT(n) AS Count
|
||||
"""
|
||||
|
||||
|
||||
# Extracted entity details within group/app/user
|
||||
SEARCH_FOR_DETIALS = """
|
||||
MATCH (n:ExtractedEntity)
|
||||
@@ -88,7 +87,7 @@ class DataConfigRepository:
|
||||
n.user_id AS user_id,
|
||||
n.id AS id
|
||||
"""
|
||||
|
||||
|
||||
# Edges between extracted entities within group/app/user
|
||||
SEARCH_FOR_EDGES = """
|
||||
MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity)
|
||||
@@ -104,7 +103,7 @@ class DataConfigRepository:
|
||||
r.statement_id AS statement_id,
|
||||
r.statement AS statement
|
||||
"""
|
||||
|
||||
|
||||
# Entity graph within group (source node, edge, target node)
|
||||
SEARCH_FOR_ENTITY_GRAPH = """
|
||||
MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity)
|
||||
@@ -137,22 +136,106 @@ class DataConfigRepository:
|
||||
id: m.id
|
||||
} AS targetNode
|
||||
"""
|
||||
|
||||
|
||||
# ==================== SQLAlchemy ORM 数据库操作方法 ====================
|
||||
|
||||
@staticmethod
def build_update_reflection(config_id: int, **kwargs) -> Tuple[str, Dict]:
    """Build the UPDATE statement for reflection settings.

    The SQL uses named parameters (``:name``), as expected by SQLAlchemy
    ``text()``. Only recognised keyword arguments whose value is not None
    are written; ``updated_at`` is always refreshed.

    Args:
        config_id: Configuration ID used in the WHERE clause.
        **kwargs: Reflection settings to update.

    Returns:
        Tuple[str, Dict]: (SQL string, parameter dict).

    Raises:
        ValueError: If no recognised field carries a non-None value.
    """
    db_logger.debug(f"构建反思配置更新语句: config_id={config_id}")

    # API field name -> table column name (currently identical).
    mapping = {
        "enable_self_reflexion": "enable_self_reflexion",
        "iteration_period": "iteration_period",
        "reflexion_range": "reflexion_range",
        "baseline": "baseline",
        "reflection_model_id": "reflection_model_id",
        "memory_verify": "memory_verify",
        "quality_assessment": "quality_assessment",
    }

    params: Dict = {"config_id": config_id}
    set_fields: List[str] = []
    for api_field, db_col in mapping.items():
        supplied = kwargs.get(api_field)
        if supplied is None:
            # Missing and explicit-None values are both skipped.
            continue
        set_fields.append(f"{db_col} = :{api_field}")
        params[api_field] = supplied

    if not set_fields:
        raise ValueError("No fields to update")

    set_fields.append("updated_at = timezone('Asia/Shanghai', now())")
    key_where = "config_id = :config_id"
    query = f"UPDATE {TABLE_NAME} SET " + ", ".join(set_fields) + f" WHERE {key_where}"
    return query, params
|
||||
|
||||
@staticmethod
def build_select_reflection(config_id: int) -> Tuple[str, Dict]:
    """Build the SELECT statement that loads one reflection config by ID.

    The SQL uses the named parameter ``:config_id``, as expected by
    SQLAlchemy ``text()``.

    Args:
        config_id: Configuration ID.

    Returns:
        Tuple[str, Dict]: (SQL string, parameter dict).
    """
    db_logger.debug(f"构建反思配置查询语句: config_id={config_id}")

    sql = (
        f"SELECT config_id, enable_self_reflexion, iteration_period, reflexion_range, baseline, "
        f"reflection_model_id, memory_verify, quality_assessment, user_id "
        f"FROM {TABLE_NAME} WHERE config_id = :config_id"
    )
    return sql, {"config_id": config_id}
|
||||
|
||||
@staticmethod
def build_select_all(workspace_id: uuid.UUID) -> Tuple[str, Dict]:
    """Build the SELECT statement that lists a workspace's configs.

    The SQL uses the named parameter ``:workspace_id``, as expected by
    SQLAlchemy ``text()``, and orders rows by ``updated_at`` descending.

    Args:
        workspace_id: Workspace ID.

    Returns:
        Tuple[str, Dict]: (SQL string, parameter dict).
    """
    db_logger.debug(f"构建查询所有配置语句: workspace_id={workspace_id}")

    sql = (
        f"SELECT config_id, config_name, enable_self_reflexion, iteration_period, reflexion_range, baseline, "
        f"reflection_model_id, memory_verify, quality_assessment, user_id, created_at, updated_at "
        f"FROM {TABLE_NAME} WHERE workspace_id = :workspace_id ORDER BY updated_at DESC"
    )
    return sql, {"workspace_id": workspace_id}
|
||||
|
||||
@staticmethod
|
||||
def create(db: Session, params: ConfigParamsCreate) -> DataConfig:
|
||||
"""创建数据配置
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
params: 配置参数创建模型
|
||||
|
||||
|
||||
Returns:
|
||||
DataConfig: 创建的配置对象
|
||||
"""
|
||||
db_logger.debug(f"创建数据配置: config_name={params.config_name}, workspace_id={params.workspace_id}")
|
||||
|
||||
|
||||
try:
|
||||
db_config = DataConfig(
|
||||
config_name=params.config_name,
|
||||
@@ -164,37 +247,37 @@ class DataConfigRepository:
|
||||
)
|
||||
db.add(db_config)
|
||||
db.flush() # 获取自增ID但不提交事务
|
||||
|
||||
|
||||
db_logger.info(f"数据配置已添加到会话: {db_config.config_name} (ID: {db_config.config_id})")
|
||||
return db_config
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
db_logger.error(f"创建数据配置失败: {params.config_name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@staticmethod
|
||||
def update(db: Session, update: ConfigUpdate) -> Optional[DataConfig]:
|
||||
"""更新基础配置
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
update: 配置更新模型
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[DataConfig]: 更新后的配置对象,不存在则返回None
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: 没有字段需要更新时抛出
|
||||
"""
|
||||
db_logger.debug(f"更新数据配置: config_id={update.config_id}")
|
||||
|
||||
|
||||
try:
|
||||
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
|
||||
if not db_config:
|
||||
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
|
||||
return None
|
||||
|
||||
|
||||
# 更新字段
|
||||
has_update = False
|
||||
if update.config_name is not None:
|
||||
@@ -203,44 +286,44 @@ class DataConfigRepository:
|
||||
if update.config_desc is not None:
|
||||
db_config.config_desc = update.config_desc
|
||||
has_update = True
|
||||
|
||||
|
||||
if not has_update:
|
||||
raise ValueError("No fields to update")
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_config)
|
||||
|
||||
|
||||
db_logger.info(f"数据配置更新成功: {db_config.config_name} (ID: {update.config_id})")
|
||||
return db_config
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
db_logger.error(f"更新数据配置失败: config_id={update.config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[DataConfig]:
|
||||
"""更新记忆萃取引擎配置
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
update: 萃取配置更新模型
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[DataConfig]: 更新后的配置对象,不存在则返回None
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: 没有字段需要更新时抛出
|
||||
"""
|
||||
db_logger.debug(f"更新萃取配置: config_id={update.config_id}")
|
||||
|
||||
|
||||
try:
|
||||
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
|
||||
if not db_config:
|
||||
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
|
||||
return None
|
||||
|
||||
|
||||
# 更新字段映射
|
||||
field_mapping = {
|
||||
# 模型选择
|
||||
@@ -270,50 +353,50 @@ class DataConfigRepository:
|
||||
"reflexion_range": "reflexion_range",
|
||||
"baseline": "baseline",
|
||||
}
|
||||
|
||||
|
||||
has_update = False
|
||||
for api_field, db_field in field_mapping.items():
|
||||
value = getattr(update, api_field, None)
|
||||
if value is not None:
|
||||
setattr(db_config, db_field, value)
|
||||
has_update = True
|
||||
|
||||
|
||||
if not has_update:
|
||||
raise ValueError("No fields to update")
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_config)
|
||||
|
||||
|
||||
db_logger.info(f"萃取配置更新成功: config_id={update.config_id}")
|
||||
return db_config
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
db_logger.error(f"更新萃取配置失败: config_id={update.config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@staticmethod
|
||||
def update_forget(db: Session, update: ConfigUpdateForget) -> Optional[DataConfig]:
|
||||
"""更新遗忘引擎配置
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
update: 遗忘配置更新模型
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[DataConfig]: 更新后的配置对象,不存在则返回None
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: 没有字段需要更新时抛出
|
||||
"""
|
||||
db_logger.debug(f"更新遗忘配置: config_id={update.config_id}")
|
||||
|
||||
|
||||
try:
|
||||
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
|
||||
if not db_config:
|
||||
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
|
||||
return None
|
||||
|
||||
|
||||
# 更新字段
|
||||
has_update = False
|
||||
if update.lambda_time is not None:
|
||||
@@ -325,40 +408,40 @@ class DataConfigRepository:
|
||||
if update.offset is not None:
|
||||
db_config.offset = update.offset
|
||||
has_update = True
|
||||
|
||||
|
||||
if not has_update:
|
||||
raise ValueError("No fields to update")
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_config)
|
||||
|
||||
|
||||
db_logger.info(f"遗忘配置更新成功: config_id={update.config_id}")
|
||||
return db_config
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
db_logger.error(f"更新遗忘配置失败: config_id={update.config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_extracted_config(db: Session, config_id: int) -> Optional[Dict]:
|
||||
"""获取萃取配置,通过主键查询某条配置
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_id: 配置ID
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: 萃取配置字典,不存在则返回None
|
||||
"""
|
||||
db_logger.debug(f"查询萃取配置: config_id={config_id}")
|
||||
|
||||
|
||||
try:
|
||||
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
|
||||
if not db_config:
|
||||
db_logger.debug(f"萃取配置不存在: config_id={config_id}")
|
||||
return None
|
||||
|
||||
|
||||
result = {
|
||||
"llm_id": db_config.llm_id,
|
||||
"embedding_id": db_config.embedding_id,
|
||||
@@ -381,62 +464,62 @@ class DataConfigRepository:
|
||||
"reflexion_range": db_config.reflexion_range,
|
||||
"baseline": db_config.baseline,
|
||||
}
|
||||
|
||||
|
||||
db_logger.debug(f"萃取配置查询成功: config_id={config_id}")
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db_logger.error(f"查询萃取配置失败: config_id={config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@staticmethod
def get_forget_config(db: Session, config_id: int) -> Optional[Dict]:
    """Load the forgetting settings of one config by primary key.

    Args:
        db: Database session.
        config_id: Configuration ID.

    Returns:
        Optional[Dict]: Dict with ``lambda_time``, ``lambda_mem`` and
        ``offset``, or None when no row matches.

    Raises:
        Exception: Any error raised by the query is logged and re-raised.
    """
    db_logger.debug(f"查询遗忘配置: config_id={config_id}")

    try:
        row = (
            db.query(DataConfig)
            .filter(DataConfig.config_id == config_id)
            .first()
        )
        if not row:
            db_logger.debug(f"遗忘配置不存在: config_id={config_id}")
            return None

        # Copy the three forgetting parameters off the ORM row.
        result = {
            field: getattr(row, field)
            for field in ("lambda_time", "lambda_mem", "offset")
        }

        db_logger.debug(f"遗忘配置查询成功: config_id={config_id}")
        return result

    except Exception as e:
        db_logger.error(f"查询遗忘配置失败: config_id={config_id} - {str(e)}")
        raise
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(db: Session, config_id: int) -> Optional[DataConfig]:
|
||||
"""根据ID获取数据配置
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_id: 配置ID
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[DataConfig]: 配置对象,不存在则返回None
|
||||
"""
|
||||
db_logger.debug(f"根据ID查询数据配置: config_id={config_id}")
|
||||
|
||||
|
||||
try:
|
||||
config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
|
||||
|
||||
|
||||
if config:
|
||||
db_logger.debug(f"数据配置查询成功: {config.config_name} (ID: {config_id})")
|
||||
else:
|
||||
@@ -571,56 +654,56 @@ class DataConfigRepository:
|
||||
@staticmethod
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[DataConfig]:
    """Return config rows, newest update first.

    Args:
        db: Database session.
        workspace_id: When truthy, only rows of this workspace are returned;
            otherwise no workspace filter is applied.

    Returns:
        List[DataConfig]: Configs ordered by ``updated_at`` descending.

    Raises:
        Exception: Any error raised by the query is logged and re-raised.
    """
    db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")

    try:
        base = db.query(DataConfig)
        scoped = (
            base.filter(DataConfig.workspace_id == workspace_id)
            if workspace_id
            else base
        )
        configs = scoped.order_by(desc(DataConfig.updated_at)).all()

        db_logger.debug(f"配置列表查询成功: 数量={len(configs)}")
        return configs

    except Exception as e:
        db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(e)}")
        raise
|
||||
|
||||
|
||||
@staticmethod
|
||||
def delete(db: Session, config_id: int) -> bool:
|
||||
"""删除数据配置
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_id: 配置ID
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 删除成功返回True,配置不存在返回False
|
||||
"""
|
||||
db_logger.debug(f"删除数据配置: config_id={config_id}")
|
||||
|
||||
|
||||
try:
|
||||
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
|
||||
if not db_config:
|
||||
db_logger.warning(f"数据配置不存在: config_id={config_id}")
|
||||
return False
|
||||
|
||||
|
||||
db.delete(db_config)
|
||||
db.commit()
|
||||
|
||||
|
||||
db_logger.info(f"数据配置删除成功: config_id={config_id}")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
db_logger.error(f"删除数据配置失败: config_id={config_id} - {str(e)}")
|
||||
|
||||
@@ -115,7 +115,9 @@ def get_knowledge_by_name(db: Session, name: str, workspace_id: uuid.UUID) -> Kn
|
||||
db_logger.debug(f"Query knowledge base based on name and workspace_id: name={name}, workspace_id={workspace_id}")
|
||||
|
||||
try:
|
||||
knowledge = db.query(Knowledge).filter(Knowledge.name == name).filter(Knowledge.workspace_id == workspace_id).first()
|
||||
knowledge = db.query(Knowledge).filter(Knowledge.name == name,
|
||||
Knowledge.workspace_id == workspace_id,
|
||||
Knowledge.status == 1).first()
|
||||
if knowledge:
|
||||
db_logger.debug(f"knowledge base query successful: {name} (ID: {knowledge.id})")
|
||||
else:
|
||||
|
||||
@@ -3,9 +3,9 @@ from sqlalchemy import and_, or_, func, desc
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
import uuid
|
||||
|
||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelProvider
|
||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType
|
||||
from app.schemas.model_schema import (
|
||||
ModelConfigCreate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
||||
ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
||||
ModelConfigQuery
|
||||
)
|
||||
from app.core.logging_config import get_db_logger
|
||||
@@ -32,7 +32,7 @@ class ModelConfigRepository:
|
||||
query = query.filter(
|
||||
or_(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_public == True
|
||||
ModelConfig.is_public
|
||||
)
|
||||
)
|
||||
|
||||
@@ -60,7 +60,7 @@ class ModelConfigRepository:
|
||||
query = query.filter(
|
||||
or_(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_public == True
|
||||
ModelConfig.is_public
|
||||
)
|
||||
)
|
||||
|
||||
@@ -92,7 +92,7 @@ class ModelConfigRepository:
|
||||
query = query.filter(
|
||||
or_(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_public == True
|
||||
ModelConfig.is_public
|
||||
)
|
||||
)
|
||||
|
||||
@@ -117,13 +117,21 @@ class ModelConfigRepository:
|
||||
filters.append(
|
||||
or_(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_public == True
|
||||
ModelConfig.is_public
|
||||
)
|
||||
)
|
||||
|
||||
# 支持多个 type 值(使用 IN 查询)
|
||||
# 兼容 chat 和 llm 类型:如果查询包含其中一个,则同时匹配两者
|
||||
if query.type:
|
||||
filters.append(ModelConfig.type.in_(query.type))
|
||||
type_values = list(query.type)
|
||||
# 如果包含 chat 或 llm,则同时包含两者
|
||||
if ModelType.CHAT in type_values or ModelType.LLM in type_values:
|
||||
if ModelType.CHAT not in type_values:
|
||||
type_values.append(ModelType.CHAT)
|
||||
if ModelType.LLM not in type_values:
|
||||
type_values.append(ModelType.LLM)
|
||||
filters.append(ModelConfig.type.in_(type_values))
|
||||
|
||||
if query.is_active is not None:
|
||||
filters.append(ModelConfig.is_active == query.is_active)
|
||||
@@ -183,12 +191,12 @@ class ModelConfigRepository:
|
||||
query = query.filter(
|
||||
or_(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_public == True
|
||||
ModelConfig.is_public
|
||||
)
|
||||
)
|
||||
|
||||
if is_active:
|
||||
query = query.filter(ModelConfig.is_active == True)
|
||||
query = query.filter(ModelConfig.is_active)
|
||||
|
||||
models = query.order_by(ModelConfig.name).all()
|
||||
db_logger.debug(f"根据类型查询模型配置成功: 数量={len(models)}")
|
||||
@@ -285,7 +293,7 @@ class ModelConfigRepository:
|
||||
try:
|
||||
# 总数统计
|
||||
total_models = db.query(ModelConfig).count()
|
||||
active_models = db.query(ModelConfig).filter(ModelConfig.is_active == True).count()
|
||||
active_models = db.query(ModelConfig).filter(ModelConfig.is_active).count()
|
||||
|
||||
# 按类型统计
|
||||
llm_count = db.query(ModelConfig).filter(ModelConfig.type == ModelType.LLM).count()
|
||||
@@ -344,7 +352,7 @@ class ModelApiKeyRepository:
|
||||
query = db.query(ModelApiKey).filter(ModelApiKey.model_config_id == model_config_id)
|
||||
|
||||
if is_active:
|
||||
query = query.filter(ModelApiKey.is_active == True)
|
||||
query = query.filter(ModelApiKey.is_active)
|
||||
|
||||
api_keys = query.order_by(ModelApiKey.priority, ModelApiKey.created_at).all()
|
||||
db_logger.debug(f"API Key列表查询成功: 数量={len(api_keys)}")
|
||||
|
||||
@@ -100,7 +100,13 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
||||
# "triplets": [triplet.model_dump() for triplet in statement.triplet_extraction_info.triplets] if statement.triplet_extraction_info else [],
|
||||
# "entities": [entity.model_dump() for entity in statement.triplet_extraction_info.entities] if statement.triplet_extraction_info else []
|
||||
# }) if statement.triplet_extraction_info else json.dumps({"triplets": [], "entities": []}),
|
||||
"statement_embedding": statement.statement_embedding if statement.statement_embedding else None
|
||||
"statement_embedding": statement.statement_embedding if statement.statement_embedding else None,
|
||||
# 添加情绪字段处理
|
||||
"emotion_type": statement.emotion_type,
|
||||
"emotion_intensity": statement.emotion_intensity,
|
||||
"emotion_keywords": statement.emotion_keywords if statement.emotion_keywords else [],
|
||||
"emotion_subject": statement.emotion_subject,
|
||||
"emotion_target": statement.emotion_target
|
||||
}
|
||||
flattened_statements.append(flattened_statement)
|
||||
|
||||
|
||||
@@ -20,20 +20,25 @@ UNWIND $statements AS statement
|
||||
MERGE (s:Statement {id: statement.id})
|
||||
SET s += {
|
||||
id: statement.id,
|
||||
run_id: statement.run_id,
|
||||
chunk_id: statement.chunk_id,
|
||||
group_id: statement.group_id,
|
||||
user_id: statement.user_id,
|
||||
apply_id: statement.apply_id,
|
||||
chunk_id: statement.chunk_id,
|
||||
run_id: statement.run_id,
|
||||
stmt_type: statement.stmt_type,
|
||||
statement: statement.statement,
|
||||
emotion_intensity: statement.emotion_intensity,
|
||||
emotion_target: statement.emotion_target,
|
||||
emotion_subject: statement.emotion_subject,
|
||||
emotion_type: statement.emotion_type,
|
||||
emotion_keywords: statement.emotion_keywords,
|
||||
temporal_info: statement.temporal_info,
|
||||
created_at: statement.created_at,
|
||||
expired_at: statement.expired_at,
|
||||
stmt_type: statement.stmt_type,
|
||||
temporal_info: statement.temporal_info,
|
||||
relevence_info: statement.relevence_info,
|
||||
statement: statement.statement,
|
||||
valid_at: statement.valid_at,
|
||||
invalid_at: statement.invalid_at,
|
||||
statement_embedding: statement.statement_embedding
|
||||
statement_embedding: statement.statement_embedding,
|
||||
relevence_info: statement.relevence_info
|
||||
}
|
||||
RETURN s.id AS uuid
|
||||
"""
|
||||
@@ -746,3 +751,57 @@ DETACH DELETE losing
|
||||
|
||||
RETURN count(losing) as deleted
|
||||
"""
|
||||
|
||||
# Cypher templates that read Statement / ExtractedEntity data for one group.
#
# NOTE(review): every template carries a positional "{}" placeholder where the
# group id goes, so the id ends up inside the query text (presumably filled in
# with str.format by the caller - confirm). Only trusted ids may be passed; a
# $group_id query parameter would remove the injection risk, but changing the
# placeholder here would break the existing callers.

# Statements of the group whose created_at lies within the last 3 days ('P3D').
neo4j_statement_part = '''
MATCH (n:Statement)
WHERE n.group_id = "{}"
AND datetime(n.created_at) >= datetime() - duration('P3D')
RETURN
n.statement as statement_name,
n.id as statement_id,
n.created_at as statement_created_at

'''
# All statements of the group, without any time filter.
neo4j_statement_all = '''
MATCH (n:Statement)
WHERE n.group_id = "{}"
RETURN
n.statement as statement_name,
n.id as statement_id

'''
# Entities connected to group nodes created within the last 3 days, each paired
# with its entity-to-entity relationships. Entities without such a relationship
# are still returned, marked "NO_RELATIONSHIP" / "ISOLATED_NODE".
neo4j_query_part = """
MATCH (n)-[r]-(m:ExtractedEntity)
WHERE n.group_id = "{}"
AND datetime(n.created_at) >= datetime() - duration('P3D')
WITH DISTINCT m
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
RETURN
m.name as entity1_name,
m.description as description,
m.statement_id as statement_id,
m.created_at as created_at,
m.expired_at as expired_at,
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
rel as relationship,
CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name,
other as entity2
"""
# Same projection as neo4j_query_part, but over every node of the group
# (no created_at filter).
neo4j_query_all = """
MATCH (n)-[r]-(m:ExtractedEntity)
WHERE n.group_id = "{}"
WITH DISTINCT m
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
RETURN
m.name as entity1_name,
m.description as description,
m.statement_id as statement_id,
m.created_at as created_at,
m.expired_at as expired_at,
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
rel as relationship,
CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name,
other as entity2
"""
|
||||
|
||||
|
||||
|
||||
246
api/app/repositories/neo4j/emotion_repository.py
Normal file
246
api/app/repositories/neo4j/emotion_repository.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""情绪数据仓储模块
|
||||
|
||||
本模块提供情绪数据的查询功能,用于情绪分析和统计。
|
||||
|
||||
Classes:
|
||||
EmotionRepository: 情绪数据仓储,提供情绪标签、词云、健康指数等查询方法
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Optional, Any
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class EmotionRepository:
    """Read-only queries over the emotion properties of ``Statement`` nodes.

    Offers three lookups, all scoped to one ``group_id``:

    - emotion-type distribution (count, share, average intensity),
    - emotion keyword frequencies for a word cloud,
    - raw emotion records inside a recent time window.

    Every method is best-effort: a failing query is logged and answered
    with an empty list instead of an exception.

    Attributes:
        connector: Neo4j connector used to execute the Cypher queries.
    """

    def __init__(self, connector: Neo4jConnector):
        """Store the connector and log that the repository is ready.

        Args:
            connector: Neo4j connector instance.
        """
        self.connector = connector
        logger.info("情绪数据仓储初始化完成")

    async def get_emotion_tags(
        self,
        group_id: str,
        emotion_type: Optional[str] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        limit: int = 10
    ) -> List[Dict[str, Any]]:
        """Return the emotion-type distribution of a group's statements.

        Args:
            group_id: Group whose statements are inspected.
            emotion_type: Optional filter on a single emotion type.
            start_date: Optional inclusive lower bound for ``created_at``.
            end_date: Optional inclusive upper bound for ``created_at``.
            limit: Maximum number of emotion types returned.

        Returns:
            One dict per emotion type, most frequent first, with the keys
            ``emotion_type``, ``count``, ``percentage`` (rounded to 2
            places) and ``avg_intensity`` (rounded to 3 places, ``0.0``
            when missing or zero). Empty list when the query fails.
        """
        conditions = ["s.group_id = $group_id", "s.emotion_type IS NOT NULL"]
        bindings = {"group_id": group_id, "limit": limit}

        # Each optional filter contributes a WHERE clause plus its parameter,
        # but only when the caller supplied a truthy value.
        optional_filters = (
            ("emotion_type", emotion_type, "s.emotion_type = $emotion_type"),
            ("start_date", start_date, "s.created_at >= $start_date"),
            ("end_date", end_date, "s.created_at <= $end_date"),
        )
        for param_name, param_value, clause in optional_filters:
            if param_value:
                conditions.append(clause)
                bindings[param_name] = param_value

        where_expr = " AND ".join(conditions)

        # The per-type rows are collected so that total_count spans every
        # matched type; the percentage is therefore computed before LIMIT.
        query = f"""
        MATCH (s:Statement)
        WHERE {where_expr}
        WITH s.emotion_type as emotion_type,
             count(*) as count,
             avg(s.emotion_intensity) as avg_intensity
        WITH collect({{emotion_type: emotion_type, count: count, avg_intensity: avg_intensity}}) as results,
             sum(count) as total_count
        UNWIND results as result
        RETURN result.emotion_type as emotion_type,
               result.count as count,
               toFloat(result.count) / total_count * 100 as percentage,
               result.avg_intensity as avg_intensity
        ORDER BY count DESC
        LIMIT $limit
        """

        try:
            rows = await self.connector.execute_query(query, **bindings)
            tags = []
            for row in rows:
                intensity = row["avg_intensity"]
                tags.append({
                    "emotion_type": row["emotion_type"],
                    "count": row["count"],
                    "percentage": round(row["percentage"], 2),
                    "avg_intensity": round(intensity, 3) if intensity else 0.0
                })
            return tags
        except Exception as e:
            logger.error(f"查询情绪标签失败: {str(e)}", exc_info=True)
            return []

    async def get_emotion_wordcloud(
        self,
        group_id: str,
        emotion_type: Optional[str] = None,
        limit: int = 50
    ) -> List[Dict[str, Any]]:
        """Return emotion keyword frequencies for word-cloud rendering.

        Args:
            group_id: Group whose statements are inspected.
            emotion_type: Optional filter on a single emotion type.
            limit: Maximum number of keyword rows returned.

        Returns:
            One dict per (keyword, emotion type) pair, most frequent first,
            with the keys ``keyword``, ``frequency``, ``emotion_type`` and
            ``avg_intensity`` (rounded to 3 places, ``0.0`` when missing or
            zero). Empty list when the query fails.
        """
        conditions = ["s.group_id = $group_id", "s.emotion_keywords IS NOT NULL"]
        bindings = {"group_id": group_id, "limit": limit}

        if emotion_type:
            conditions.append("s.emotion_type = $emotion_type")
            bindings["emotion_type"] = emotion_type

        where_expr = " AND ".join(conditions)

        # Keywords are unwound from the list property and grouped together
        # with the emotion type; null and empty keywords are dropped.
        query = f"""
        MATCH (s:Statement)
        WHERE {where_expr}
        UNWIND s.emotion_keywords as keyword
        WITH keyword,
             s.emotion_type as emotion_type,
             count(*) as frequency,
             avg(s.emotion_intensity) as avg_intensity
        WHERE keyword IS NOT NULL AND keyword <> ''
        RETURN keyword,
               frequency,
               emotion_type,
               avg_intensity
        ORDER BY frequency DESC
        LIMIT $limit
        """

        try:
            rows = await self.connector.execute_query(query, **bindings)
            words = []
            for row in rows:
                intensity = row["avg_intensity"]
                words.append({
                    "keyword": row["keyword"],
                    "frequency": row["frequency"],
                    "emotion_type": row["emotion_type"],
                    "avg_intensity": round(intensity, 3) if intensity else 0.0
                })
            return words
        except Exception as e:
            logger.error(f"查询情绪词云失败: {str(e)}", exc_info=True)
            return []

    async def get_emotions_in_range(
        self,
        group_id: str,
        time_range: str = "30d"
    ) -> List[Dict[str, Any]]:
        """Return every emotion record of a group inside a recent window.

        Args:
            group_id: Group whose statements are inspected.
            time_range: ``"7d"``, ``"30d"`` or ``"90d"``; any other value
                falls back to 30 days.

        Returns:
            Dicts ordered by ``created_at`` ascending with the keys
            ``statement_id``, ``emotion_type``, ``emotion_intensity`` and
            ``created_at`` (ISO string when the stored value offers
            ``isoformat``, ``str()`` of it otherwise). Empty list when the
            query fails.
        """
        window_days = {"7d": 7, "30d": 30, "90d": 90}.get(time_range, 30)

        # The lower bound is an ISO string built from the naive local clock
        # and handed to the query as a parameter for the created_at comparison.
        window_start = datetime.now() - timedelta(days=window_days)
        start_date = window_start.isoformat()

        query = """
        MATCH (s:Statement)
        WHERE s.group_id = $group_id
          AND s.emotion_type IS NOT NULL
          AND s.created_at >= $start_date
        RETURN s.id as statement_id,
               s.emotion_type as emotion_type,
               s.emotion_intensity as emotion_intensity,
               s.created_at as created_at
        ORDER BY s.created_at ASC
        """

        try:
            rows = await self.connector.execute_query(
                query,
                group_id=group_id,
                start_date=start_date
            )
            records = []
            for row in rows:
                created = row["created_at"]
                if hasattr(created, "isoformat"):
                    created_text = created.isoformat()
                else:
                    created_text = str(created)
                records.append({
                    "statement_id": row["statement_id"],
                    "emotion_type": row["emotion_type"],
                    "emotion_intensity": row["emotion_intensity"],
                    "created_at": created_text
                })
            return records
        except Exception as e:
            logger.error(f"查询时间范围情绪数据失败: {str(e)}", exc_info=True)
            return []
|
||||
227
api/app/repositories/neo4j/neo4j_update.py
Normal file
227
api/app/repositories/neo4j/neo4j_update.py
Normal file
@@ -0,0 +1,227 @@
|
||||
from app.repositories import Neo4jConnector
|
||||
|
||||
neo4j_connector = Neo4jConnector()
|
||||
|
||||
async def update_neo4j_data(neo4j_dict_data, update_databases):
    """Update the ``ExtractedEntity`` nodes that match a set of property filters.

    Args:
        neo4j_dict_data: Mapping of property name to the value a node must
            have to be selected. Entries whose value is None are ignored.
        update_databases: Mapping of property name to the new value to
            write. Entries whose value is None are ignored.

    Returns:
        True when at least one node was updated. False when nothing
        matched, when there is nothing to write, when no filter remains,
        when a property name is not a plain identifier, or when any error
        occurs (errors are printed, never raised).
    """
    try:
        # None carries no information: such entries are neither filtered on
        # nor written.
        filters = {key: value for key, value in neo4j_dict_data.items() if value is not None}
        updates = {key: value for key, value in update_databases.items() if value is not None}

        if not updates:
            print("警告: 没有需要更新的字段")
            return False

        # Without a filter the statement would match, and overwrite, every
        # ExtractedEntity node in the database - refuse instead.
        if not filters:
            print("警告: 没有查询条件,拒绝更新全部节点")
            return False

        # Property names become part of the Cypher text (only the values are
        # passed as parameters), so anything that is not a plain identifier
        # is rejected before the query is built.
        for key in (*filters, *updates):
            if not (isinstance(key, str) and key.isidentifier()):
                print(f"警告: 非法的属性名: {key!r}")
                return False

        params = {f"param_{key}": value for key, value in filters.items()}
        params.update({f"update_{key}": value for key, value in updates.items()})

        where_clause = " AND ".join(f"e.{key} = $param_{key}" for key in filters)
        set_clause = ", ".join(f"e.{key} = $update_{key}" for key in updates)

        cypher_query = f"""
        MATCH (e:ExtractedEntity)
        WHERE {where_clause}
        SET {set_clause}
        RETURN count(e) as updated_count, collect(e.name) as updated_names
        """

        print(f"\n执行Cypher查询: {cypher_query}")
        print(f"参数: {params}")

        result = await neo4j_connector.execute_query(cypher_query, **params)
        if not result:
            return False

        updated_count = result[0].get('updated_count', 0)
        updated_names = result[0].get('updated_names', [])
        print(f"成功更新 {updated_count} 个节点")
        if updated_names:
            print(f"更新的实体名称: {updated_names}")
        return updated_count > 0

    except Exception as e:
        # Best-effort by design: report the failure and signal it through
        # the return value instead of propagating.
        print(f"更新过程中出现错误: {e}")
        import traceback
        traceback.print_exc()
        return False
|
||||
|
||||
|
||||
def map_field_names(data_dict):
    """Rename entity-qualified field names to plain node property names.

    Rules, applied per key in iteration order:

    - ``entity1.name``, ``entity1_name``, ``entity2.name`` and
      ``entity2_name`` are all stored under ``name``;
    - ``entity1.description`` and ``entity2.description`` are stored under
      ``description``;
    - ``relationship_type`` is dropped;
    - any other key without a dot is kept unchanged;
    - any other key containing a dot is dropped with a warning.

    When several keys map to the same target, the last one wins.

    Args:
        data_dict: Mapping of field name to value.

    Returns:
        A new dict with the mapped names; the input is not modified.
    """
    renames = {
        'entity1.name': 'name',
        'entity1_name': 'name',
        'entity2.name': 'name',
        'entity2_name': 'name',
        'entity1.description': 'description',
        'entity2.description': 'description',
    }

    # Diagnostic only: the flag is reported but does not influence the mapping.
    has_name_field = any(key in ('name', 'entity2.name', 'entity1.name') for key in data_dict)
    print(f"字段检查: has_name_field = {has_name_field}")

    mapped_dict = {}
    for key, value in data_dict.items():
        target = renames.get(key)
        if target is not None:
            mapped_dict[target] = value
            print(f"字段名映射: {key} -> {target}")
        elif key == 'relationship_type':
            print(f"字段过滤: 跳过不需要的字段 '{key}'")
        elif '.' not in key:
            mapped_dict[key] = value
        else:
            print(f"警告: 跳过不支持的嵌套字段 '{key}'")

    print(f"字段映射结果: {mapped_dict}")
    return mapped_dict
|
||||
async def neo4j_data(solved_data):
    """Apply resolved changes to Neo4j and count the successful updates.

    For every item in ``solved_data`` the records under ``item['results']``
    are walked in order. Each record with a non-empty ``resolved['change']``
    list contributes the fields of its first change (renamed through
    ``map_field_names``) to the update set, and the ``statement_id`` /
    ``description`` / ``original_memory_id`` found in
    ``resolved['resolved_memory']`` to the filter set. The update is then
    issued through ``update_neo4j_data``.

    Filters and updates are accumulated across the records of one item and
    reset for the next item.

    Args:
        solved_data: Iterable of dicts, each with a ``'results'`` list of
            record dicts. Records are mutated: ``entity1_name`` is renamed
            to ``name`` and ``entity2_name`` is removed.

    Returns:
        int: Number of update calls that reported at least one changed node.
    """
    success_count = 0

    for item in solved_data:
        neo4j_dict_data = {}
        update_databases = {}

        for data in item['results']:
            resolved = data.get('resolved')
            if not resolved:
                print("跳过:resolved为None")
                continue

            try:
                change_list = resolved.get('change', [])
            except (AttributeError, TypeError):
                change_list = []

            if change_list == []:
                print("跳过:change_list为空")
                continue

            if change_list:
                # Only the first change of a record is applied.
                change = change_list[0]
                print(f"change: {change}")
                field_data = change.get('field', [])
                print(f"field_data: {field_data}")
                print(f"field_data type: {type(field_data)}")

                # 'field' may be a single dict or a list of dicts.
                if isinstance(field_data, dict):
                    update_databases.update(map_field_names(field_data))
                elif isinstance(field_data, list):
                    for field_item in field_data:
                        if isinstance(field_item, dict):
                            update_databases.update(map_field_names(field_item))
                        else:
                            print(f"警告: field_item不是字典: {field_item}")
                else:
                    print(f"警告: field_data类型不支持: {type(field_data)}")

            if 'entity1_name' in data:
                data['name'] = data.pop('entity1_name')
            data.pop('entity2_name', None)

            resolved_memory = resolved.get('resolved_memory', {})
            if not isinstance(resolved_memory, dict):
                # Nothing to filter on for this record. Skip it and keep the
                # accumulated filters intact so later records still work.
                print("跳过:resolved_memory不是字典")
                continue

            entity2 = resolved_memory.get('entity2')
            if entity2 and isinstance(entity2, dict) and len(entity2) >= 5:
                stat_id = resolved.get('original_memory_id')
                statement_id = resolved_memory.get('statement_id')
                # original_memory_id is used as the node id only while no id
                # has been recorded for this item yet.
                if statement_id and 'id' not in neo4j_dict_data:
                    neo4j_dict_data['id'] = stat_id
                    neo4j_dict_data['statement_id'] = statement_id
            else:
                for key in ('statement_id', 'description'):
                    if key in resolved_memory:
                        neo4j_dict_data[key] = resolved_memory[key]

            print(neo4j_dict_data)
            print(update_databases)
            # Count the record only when the update reports a changed node.
            if await update_neo4j_data(neo4j_dict_data, update_databases):
                success_count += 1

    return success_count
|
||||
|
||||
@@ -58,11 +58,22 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]):
|
||||
n['invalid_at'] = datetime.fromisoformat(n['invalid_at'])
|
||||
|
||||
# 处理temporal_info字段
|
||||
if isinstance(n.get('temporal_info'), dict):
|
||||
if isinstance(n.get('temporal_info'), str):
|
||||
# 从字符串转换为枚举值
|
||||
n['temporal_info'] = TemporalInfo(n['temporal_info'])
|
||||
elif isinstance(n.get('temporal_info'), dict):
|
||||
n['temporal_info'] = TemporalInfo(**n['temporal_info'])
|
||||
elif not n.get('temporal_info'):
|
||||
# 如果没有temporal_info,创建一个默认的
|
||||
n['temporal_info'] = TemporalInfo()
|
||||
n['temporal_info'] = TemporalInfo.STATIC
|
||||
|
||||
# 处理情绪字段 - 映射 Neo4j 节点属性到 StatementNode 模型
|
||||
# 处理空值情况,确保字段存在
|
||||
n['emotion_type'] = n.get('emotion_type')
|
||||
n['emotion_intensity'] = n.get('emotion_intensity')
|
||||
n['emotion_keywords'] = n.get('emotion_keywords', [])
|
||||
n['emotion_subject'] = n.get('emotion_subject')
|
||||
n['emotion_target'] = n.get('emotion_target')
|
||||
|
||||
return StatementNode(**n)
|
||||
|
||||
|
||||
124
api/app/repositories/prompt_optimizer_repository.py
Normal file
124
api/app/repositories/prompt_optimizer_repository.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
from app.models.prompt_optimizer_model import (
|
||||
PromptOptimizerSession, PromptOptimizerSessionHistory, RoleType
|
||||
)
|
||||
|
||||
db_logger = get_db_logger()
|
||||
|
||||
|
||||
class PromptOptimizerSessionRepository:
    """Repository for prompt optimization sessions and their message history.

    Write methods commit on success. On failure they roll the SQLAlchemy
    session back, log the error and re-raise, so the shared session stays
    usable for the caller.
    """

    def __init__(self, db: Session):
        """Bind the repository to a SQLAlchemy session.

        Args:
            db: Session used for every query and commit.
        """
        self.db = db

    def create_session(
            self,
            tenant_id: uuid.UUID,
            user_id: uuid.UUID
    ) -> PromptOptimizerSession:
        """Create and persist a new prompt optimization session.

        Args:
            tenant_id (uuid.UUID): The unique identifier of the tenant.
            user_id (uuid.UUID): The unique identifier of the user.

        Returns:
            PromptOptimizerSession: The committed and refreshed session row.

        Raises:
            Exception: Whatever the database layer raised; the transaction
                is rolled back before re-raising.
        """
        db_logger.debug(f"Create prompt optimization session: tenant_id={tenant_id}, user_id={user_id}")
        try:
            session = PromptOptimizerSession(
                tenant_id=tenant_id,
                user_id=user_id,
            )
            self.db.add(session)
            self.db.commit()
            self.db.refresh(session)
            db_logger.debug(f"Prompt optimization session created: ID:{session.id}")
            return session
        except Exception as e:
            # A failed flush/commit leaves the session unusable until rollback.
            self.db.rollback()
            db_logger.error(f"Error creating prompt optimization session: user_id={user_id} - {str(e)}")
            raise

    def get_session_history(
            self,
            session_id: uuid.UUID,
            user_id: uuid.UUID
    ) -> list[PromptOptimizerSessionHistory]:
        """Return all messages of one session that belongs to the user.

        Args:
            session_id (uuid.UUID): The unique identifier of the session.
            user_id (uuid.UUID): The unique identifier of the user.

        Returns:
            list[PromptOptimizerSessionHistory]: History rows ordered by
            creation time ascending; an empty list when the session does
            not exist or is not owned by ``user_id``.

        Raises:
            Exception: Whatever the database layer raised (logged first).
        """
        db_logger.debug(f"Get prompt optimization session history: "
                        f"user_id={user_id}, session_id={session_id}")

        try:
            # Verify ownership first, so that a foreign session id yields an
            # empty history.
            session = self.db.query(PromptOptimizerSession).filter(
                PromptOptimizerSession.id == session_id,
                PromptOptimizerSession.user_id == user_id
            ).first()

            if not session:
                return []

            history = self.db.query(PromptOptimizerSessionHistory).filter(
                PromptOptimizerSessionHistory.session_id == session.id,
                PromptOptimizerSessionHistory.user_id == user_id
            ).order_by(PromptOptimizerSessionHistory.created_at.asc()).all()
            return history
        except Exception as e:
            db_logger.error(f"Error retrieving prompt optimization session history: session_id={session_id} - {str(e)}")
            raise

    def create_message(
            self,
            tenant_id: uuid.UUID,
            session_id: uuid.UUID,
            user_id: uuid.UUID,
            role: RoleType,
            content: str,
    ) -> PromptOptimizerSessionHistory:
        """Append a message to the history of an existing session.

        Args:
            tenant_id (uuid.UUID): Tenant that must own the session.
            session_id (uuid.UUID): Session the message is added to.
            user_id (uuid.UUID): User that must own the session.
            role (RoleType): Author role; its ``value`` is stored.
            content (str): Message text.

        Returns:
            PromptOptimizerSessionHistory: The committed history row.

        Raises:
            ValueError: If no session matches ``session_id``, ``user_id``
                and ``tenant_id``.
            Exception: Whatever the database layer raised; the transaction
                is rolled back before re-raising.
        """
        try:
            # The session must exist and belong to both the user and the tenant.
            session = self.db.query(PromptOptimizerSession).filter(
                PromptOptimizerSession.id == session_id,
                PromptOptimizerSession.user_id == user_id,
                PromptOptimizerSession.tenant_id == tenant_id
            ).first()

            if not session:
                db_logger.error(f"Session {session_id} not found for user {user_id}")
                raise ValueError(f"Session {session_id} not found for user {user_id}")

            message = PromptOptimizerSessionHistory(
                tenant_id=tenant_id,
                session_id=session.id,
                user_id=user_id,
                role=role.value,
                content=content,
            )
            self.db.add(message)
            self.db.commit()
            return message
        except Exception as e:
            # A failed flush/commit leaves the session unusable until rollback;
            # rolling back after the ValueError above is harmless.
            self.db.rollback()
            db_logger.error(f"Error creating prompt optimization session history: session_id={session_id} - {str(e)}")
            raise
|
||||
32
api/app/schemas/emotion_schema.py
Normal file
32
api/app/schemas/emotion_schema.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""情绪分析相关的请求和响应模型"""
|
||||
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class EmotionTagsRequest(BaseModel):
    """Request body for emotion-tag statistics."""
    # Required group identifier.
    group_id: str = Field(..., description="组ID")
    # Optional emotion-type filter; the description lists
    # joy/sadness/anger/fear/surprise/neutral, but the value is not validated here.
    emotion_type: Optional[str] = Field(None, description="情绪类型过滤(joy/sadness/anger/fear/surprise/neutral)")
    # Optional date bounds, documented as ISO strings; kept as plain strings
    # (no format validation in this model).
    start_date: Optional[str] = Field(None, description="开始日期(ISO格式,如:2024-01-01)")
    end_date: Optional[str] = Field(None, description="结束日期(ISO格式,如:2024-12-31)")
    # Maximum number of rows to return; constrained to 1..100, default 10.
    limit: int = Field(10, ge=1, le=100, description="返回数量限制")
|
||||
|
||||
|
||||
class EmotionWordcloudRequest(BaseModel):
    """Request body for emotion word-cloud data."""
    # Required group identifier.
    group_id: str = Field(..., description="组ID")
    # Optional emotion-type filter; not validated against the listed values.
    emotion_type: Optional[str] = Field(None, description="情绪类型过滤(joy/sadness/anger/fear/surprise/neutral)")
    # Maximum number of words to return; constrained to 1..200, default 50.
    limit: int = Field(50, ge=1, le=200, description="返回词语数量")
|
||||
|
||||
|
||||
class EmotionHealthRequest(BaseModel):
    """Request body for the emotion health index."""
    # Required group identifier.
    group_id: str = Field(..., description="组ID")
    # Time window, documented as 7d/30d/90d; any string is accepted by this
    # model (default "30d").
    time_range: str = Field("30d", description="时间范围(7d/30d/90d)")
|
||||
|
||||
|
||||
class EmotionSuggestionsRequest(BaseModel):
    """Request body for personalised emotion suggestions."""
    # Required group identifier.
    group_id: str = Field(..., description="组ID")
    # Optional configuration id; per its description it selects the LLM model.
    config_id: Optional[int] = Field(None, description="配置ID(用于指定LLM模型)")
|
||||
@@ -13,5 +13,6 @@ class EndUser(BaseModel):
|
||||
other_id: Optional[str] = Field(description="第三方ID", default=None)
|
||||
other_name: Optional[str] = Field(description="其他名称", default="")
|
||||
other_address: Optional[str] = Field(description="其他地址", default="")
|
||||
reflection_time: Optional[datetime.datetime] = Field(description="反思时间", default_factory=datetime.datetime.now)
|
||||
created_at: datetime.datetime = Field(description="创建时间", default_factory=datetime.datetime.now)
|
||||
updated_at: datetime.datetime = Field(description="更新时间", default_factory=datetime.datetime.now)
|
||||
|
||||
52
api/app/schemas/memory_reflection_schemas.py
Normal file
52
api/app/schemas/memory_reflection_schemas.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class OptimizationStrategy(str, Enum):
    """Optimization strategy choices; members are also plain strings."""
    SPEED_FIRST = "speed_first"
    ACCURACY_FIRST = "accuracy_first"
    BALANCED = "balanced"
|
||||
class Memory_Reflection(BaseModel):
    """Memory reflection configuration payload."""
    config_id: Optional[int] = None
    reflection_enabled: bool
    # NOTE(review): the period is typed as a string although the name suggests
    # a number of hours - confirm the expected format with the consumers.
    reflection_period_in_hours: str
    reflexion_range: str
    baseline: str
    reflection_model_id: str
    memory_verify: bool
    quality_assessment: bool

    # Additional tuning parameters for the fast engine; all optional with defaults.
    optimization_strategy: Optional[OptimizationStrategy] = OptimizationStrategy.BALANCED
    use_fast_model: Optional[bool] = True
    enable_caching: Optional[bool] = True
    enable_streaming: Optional[bool] = True
    # Constrained to 1..10, default 3.
    batch_size: Optional[int] = Field(default=3, ge=1, le=10)
    # Constrained to 1..20, default 5.
    max_concurrent: Optional[int] = Field(default=5, ge=1, le=20)

    class Config:
        # Store enum fields as their values rather than as enum members.
        use_enum_values = True
|
||||
|
||||
|
||||
class FastReflectionRequest(BaseModel):
    """Request model for a fast reflection run."""
    reflection: Memory_Reflection
    # NOTE(review): the default looks like a leftover test host id - confirm
    # whether a hard-coded default is intended for production requests.
    host_id: Optional[str] = "88a459f5_text02"
    optimization_strategy: Optional[OptimizationStrategy] = OptimizationStrategy.BALANCED

    class Config:
        # Store enum fields as their values rather than as enum members.
        use_enum_values = True
|
||||
|
||||
|
||||
class ReflectionBenchmarkRequest(BaseModel):
    """Request model for a reflection benchmark run."""
    reflection: Memory_Reflection
    # NOTE(review): the default looks like a leftover test host id - confirm
    # whether a hard-coded default is intended for production requests.
    host_id: Optional[str] = "88a459f5_text02"
    # Number of benchmark iterations; constrained to 1..10, default 3.
    iterations: Optional[int] = Field(default=3, ge=1, le=10)

    class Config:
        # Store enum fields as their values rather than as enum members.
        use_enum_values = True
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
所有的内容是放错误地方了,应该放在models
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, List, Dict, Literal
|
||||
from typing import Any, Optional, List, Dict, Literal, Union
|
||||
import time
|
||||
import uuid
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator
|
||||
@@ -28,25 +28,48 @@ class Write_UserInput(BaseModel):
|
||||
# ============================================================================
|
||||
class BaseDataSchema(BaseModel):
|
||||
"""Base schema for the data"""
|
||||
id: str = Field(..., description="The unique identifier for the data entry.")
|
||||
statement: str = Field(..., description="The statement text.")
|
||||
group_id: str = Field(..., description="The group identifier.")
|
||||
chunk_id: str = Field(..., description="The chunk identifier.")
|
||||
# 保持原有必需字段为可选,以兼容不同数据源
|
||||
id: Optional[str] = Field(None, description="The unique identifier for the data entry.")
|
||||
statement: Optional[str] = Field(None, description="The statement text.")
|
||||
group_id: Optional[str] = Field(None, description="The group identifier.")
|
||||
chunk_id: Optional[str] = Field(None, description="The chunk identifier.")
|
||||
created_at: str = Field(..., description="The creation timestamp in ISO 8601 format.")
|
||||
expired_at: Optional[str] = Field(None, description="The expiration timestamp in ISO 8601 format.")
|
||||
valid_at: Optional[str] = Field(None, description="The validation timestamp in ISO 8601 format.")
|
||||
invalid_at: Optional[str] = Field(None, description="The invalidation timestamp in ISO 8601 format.")
|
||||
entity_ids: List[str] = Field([], description="The list of entity identifiers.")
|
||||
description: Optional[str] = Field(None, description="The description of the data entry.")
|
||||
|
||||
# 新增字段以匹配实际输入数据
|
||||
entity1_name: str = Field(..., description="The first entity name.")
|
||||
entity2_name: Optional[str] = Field(None, description="The second entity name.")
|
||||
statement_id: str = Field(..., description="The statement identifier.")
|
||||
relationship_type: str = Field(..., description="The relationship type.")
|
||||
relationship: Optional[Dict[str, Any]] = Field(None, description="The relationship object.")
|
||||
entity2: Optional[Dict[str, Any]] = Field(None, description="The second entity object.")
|
||||
|
||||
|
||||
class QualityAssessmentSchema(BaseModel):
    """Schema for memory quality assessment results."""

    # Required integer score, constrained to the inclusive range 0-100.
    score: int = Field(
        ...,
        ge=0,
        le=100,
        description="Quality score percentage (0-100).",
    )
    # Required free-text summary accompanying the score.
    summary: str = Field(
        ...,
        description="Brief summary of data quality status, including main issues and strengths.",
    )
|
||||
|
||||
|
||||
class MemoryVerifySchema(BaseModel):
    """Schema for memory privacy verification results."""

    # Required flag reporting whether any privacy information was found.
    has_privacy: bool = Field(
        ...,
        description="Whether privacy information was detected.",
    )
    # Optional list of detected categories; defaults to an empty list.
    privacy_types: List[str] = Field(
        [],
        description="List of detected privacy information types.",
    )
    # Required free-text summary of the detection outcome.
    summary: str = Field(
        ...,
        description="Brief summary of privacy detection results.",
    )
|
||||
|
||||
|
||||
class ConflictResultSchema(BaseModel):
|
||||
"""Schema for the conflict result data in the reflexion_data.json file."""
|
||||
data: List[BaseDataSchema] = Field(..., description="The conflict memory data.")
|
||||
data: List[BaseDataSchema] = Field(..., description="The conflict memory data. Only contains conflicting records when conflict is True.")
|
||||
conflict: bool = Field(..., description="Whether the memory is in conflict.")
|
||||
conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.")
|
||||
quality_assessment: Optional[QualityAssessmentSchema] = Field(None, description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.")
|
||||
memory_verify: Optional[MemoryVerifySchema] = Field(None, description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_data(cls, v):
|
||||
if isinstance(v, dict):
|
||||
d = v.get("data")
|
||||
@@ -61,7 +84,6 @@ class ConflictSchema(BaseModel):
|
||||
conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_data(cls, v):
|
||||
if isinstance(v, dict):
|
||||
d = v.get("data")
|
||||
@@ -76,21 +98,30 @@ class ReflexionSchema(BaseModel):
|
||||
solution: str = Field(..., description="The solution for the reflexion.")
|
||||
|
||||
|
||||
class ChangeRecordSchema(BaseModel):
    """Schema for individual change records"""

    # Required list of string-to-string mappings, one mapping per field change.
    field: List[Dict[str, str]] = Field(
        ...,
        description="List of field changes, each containing field name and new value.",
    )
|
||||
|
||||
class ResolvedSchema(BaseModel):
    """Schema for the resolved memory data in the reflexion_data"""

    # NOTE: the class docstring above is kept verbatim because pydantic exposes
    # it as the schema description.
    original_memory_id: Optional[str] = Field(
        None,
        description="The original memory identifier.",
    )
    # Accepts either one record or a list of records. An earlier single-record
    # declaration of this field (and a commented-out variant) was shadowed by
    # this one and has been removed; the effective field type is unchanged.
    resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(
        None,
        description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.",
    )
    change: Optional[List[ChangeRecordSchema]] = Field(
        None,
        description="List of detailed change records with IDs and field information.",
    )
|
||||
|
||||
|
||||
class SingleReflexionResultSchema(BaseModel):
    """Schema for a single reflexion result item."""

    # Required conflict payload for this item.
    conflict: ConflictResultSchema = Field(
        ...,
        description="The conflict result data for this specific conflict type.",
    )
    # Required reflexion payload for this item.
    reflexion: ReflexionSchema = Field(
        ...,
        description="The reflexion data for this conflict.",
    )
    # Resolution details; may be absent.
    resolved: Optional[ResolvedSchema] = Field(
        None,
        description="The resolved memory data for this conflict.",
    )
    # Type tag; defaults to "reflexion_result" when not supplied.
    type: str = Field(
        "reflexion_result",
        description="The type identifier.",
    )
|
||||
|
||||
class ReflexionResultSchema(BaseModel):
|
||||
"""Schema for the reflexion result data in the reflexion_data.json file."""
|
||||
# 模型输出中 "conflict" 为单个冲突对象(包含 data 与 conflict_memory),而非字典映射
|
||||
conflict: ConflictResultSchema = Field(..., description="The conflict result data.")
|
||||
reflexion: Optional[ReflexionSchema] = Field(None, description="The reflexion data.")
|
||||
resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data.")
|
||||
"""Schema for the complete reflexion result data - a list of individual conflict resolutions."""
|
||||
results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_resolved(cls, v):
|
||||
if isinstance(v, dict):
|
||||
conflict = v.get("conflict")
|
||||
|
||||
99
api/app/schemas/prompt_optimizer_schema.py
Normal file
99
api/app/schemas/prompt_optimizer_schema.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
# =========================================
|
||||
# API Request Schemas
|
||||
# =========================================
|
||||
class PromptOptMessage(BaseModel):
    # Listed under the "API Request Schemas" section of this module.
    # Documented with comments rather than a docstring so the generated
    # schema description stays exactly as before.

    # Required model identifier.
    model_id: UUID = Field(..., description="Model ID")
    # Required user input; an empty string is rejected (min_length=1).
    message: str = Field(..., min_length=1, description="User's input message")
    # Optional prompt text; defaults to an empty string.
    current_prompt: str = Field("", description="currently optimized prompt")
|
||||
|
||||
|
||||
class PromptOptModelSet(BaseModel):
    # Listed under the "API Request Schemas" section of this module.
    # Documented with comments rather than a docstring so the generated
    # schema description stays exactly as before.

    # Optional configuration identifier; None when not supplied.
    id: UUID | None = Field(None, description="Configuration ID")
    # Required system prompt text.
    system_prompt: str = Field(..., description="System Prompt")
|
||||
|
||||
|
||||
# =========================================
|
||||
# Service Layer Results
|
||||
# =========================================
|
||||
class OptimizePromptResult(BaseModel):
    # Listed under the "Service Layer Results" section of this module.
    # Documented with comments rather than a docstring so the generated
    # schema description stays exactly as before.

    # Required optimized prompt text.
    prompt: str = Field(..., description="Optimized Prompt")
    # Required description string.
    desc: str = Field(..., description="Description")
|
||||
|
||||
|
||||
# =========================================
|
||||
# API Response Schemas
|
||||
# =========================================
|
||||
class CreateSessionResponse(BaseModel):
    # Listed under the "API Response Schemas" section of this module.
    # Documented with comments rather than a docstring so the generated
    # schema description stays exactly as before.

    # Enables pydantic's from_attributes mode for this model.
    model_config = {"from_attributes": True}

    # Required session identifier.
    id: UUID = Field(..., description="Session ID")
|
||||
|
||||
|
||||
class OptimizePromptResponse(BaseModel):
    # Listed under the "API Response Schemas" section of this module.
    # Documented with comments rather than a docstring so the generated
    # schema description stays exactly as before.

    # Enables pydantic's from_attributes mode for this model.
    model_config = {"from_attributes": True}

    # Required optimized prompt text.
    prompt: str = Field(..., description="Optimized Prompt")
    # Required description string.
    desc: str = Field(..., description="Description")
    # Required list; the element type is not constrained by this schema.
    variables: list = Field(..., description="Variables")
|
||||
|
||||
|
||||
class SessionMessage(BaseModel):
    # One message entry (a role plus its content).
    # Documented with comments rather than a docstring so the generated
    # schema description stays exactly as before.

    # Required role string; the description names "user" and "assistant",
    # but the schema itself does not restrict the value.
    role: str = Field(..., description="Message role (user/assistant)")
    # Required message text.
    content: str = Field(..., description="Message content")
|
||||
|
||||
|
||||
class SessionHistoryResponse(BaseModel):
    # A session identifier together with that session's messages.
    # Documented with comments rather than a docstring so the generated
    # schema description stays exactly as before.

    # Required session identifier.
    session_id: UUID = Field(..., description="Session ID")
    # Required list of SessionMessage entries.
    messages: list[SessionMessage] = Field(
        ...,
        description="List of messages in the session",
    )
|
||||
@@ -14,6 +14,7 @@ from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.repositories import workspace_repository, knowledge_repository
|
||||
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@@ -328,4 +329,4 @@ def create_agent_invocation_tool(
|
||||
)
|
||||
return f"调用 Agent 失败: {str(e)}"
|
||||
|
||||
return invoke_agent
|
||||
return invoke_agent
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user