Merge remote-tracking branch 'origin/develop' into refactor/memory-config-management

This commit is contained in:
Ke Sun
2025-12-22 11:37:08 +08:00
119 changed files with 18212 additions and 2208 deletions

View File

@@ -83,17 +83,18 @@ celery_app.autodiscover_tasks(['app'])
reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS)
health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS)
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
# 构建定时任务配置
beat_schedule_config = {
"run-reflection-engine": {
"task": "app.core.memory.agent.reflection.timer",
"schedule": reflection_schedule,
"args": (),
},
"check-read-service": {
"task": "app.core.memory.agent.health.check_read_service",
"schedule": health_schedule,
# "check-read-service": {
# "task": "app.core.memory.agent.health.check_read_service",
# "schedule": health_schedule,
# "args": (),
# },
"run-workspace-reflection": {
"task": "app.tasks.workspace_reflection_task",
"schedule": workspace_reflection_schedule,
"args": (),
},
}

View File

@@ -23,11 +23,17 @@ from . import (
memory_dashboard_controller,
memory_storage_controller,
memory_dashboard_controller,
memory_reflection_controller,
api_key_controller,
release_share_controller,
public_share_controller,
multi_agent_controller,
workflow_controller,
emotion_controller,
emotion_config_controller,
prompt_optimizer_controller,
tool_controller,
tool_execution_controller,
)
# 创建管理端 API 路由器
@@ -58,5 +64,11 @@ manager_router.include_router(public_share_controller.router) # 公开路由(
manager_router.include_router(memory_dashboard_controller.router)
manager_router.include_router(multi_agent_controller.router)
manager_router.include_router(workflow_controller.router)
manager_router.include_router(emotion_controller.router)
manager_router.include_router(emotion_config_controller.router)
manager_router.include_router(prompt_optimizer_controller.router)
manager_router.include_router(memory_reflection_controller.router)
manager_router.include_router(tool_controller.router)
manager_router.include_router(tool_execution_controller.router)
__all__ = ["manager_router"]

View File

@@ -8,6 +8,7 @@ from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.models import ApiKeyType
from app.models.user_model import User
from app.core.response_utils import success
from app.schemas import api_key_schema
@@ -39,6 +40,8 @@ def create_api_key(
"""
try:
workspace_id = current_user.current_workspace_id
if data.type == ApiKeyType.SERVICE.value and not data.resource_id:
data.resource_id = workspace_id
# 创建 API Key
api_key_obj = ApiKeyService.create_api_key(

View File

@@ -421,8 +421,8 @@ async def draft_run(
# 流式返回
if payload.stream:
async def event_generator():
async for event in draft_service.run_stream(
agent_config=agent_cfg,
model_config=model_config,
@@ -574,7 +574,7 @@ async def draft_run(
# 3. 流式返回
if payload.stream:
logger.debug(
"开始多智能体流式试运行",
"开始工作流流式试运行",
extra={
"app_id": str(app_id),
"message_length": len(payload.message),
@@ -583,18 +583,27 @@ async def draft_run(
)
async def event_generator():
"""多智能体流式事件生成器"""
multiservice = MultiAgentService(db)
# 调用多智能体服务的流式方法
async for event in multiservice.run_stream(
"""工作流事件生成器
将事件转换为标准 SSE 格式:
event: <event_type>
data: <json_data>
"""
import json
# 调用工作流服务的流式方法
async for event in workflow_service.run_stream(
app_id=app_id,
request=multi_agent_request,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
payload=payload,
config=config
):
yield event
# 提取事件类型和数据
event_type = event.get("event", "message")
event_data = event.get("data", {})
# 转换为标准 SSE 格式(字符串)
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
yield sse_message
return StreamingResponse(
event_generator(),
@@ -617,7 +626,7 @@ async def draft_run(
)
result = await workflow_service.run(app_id, payload,config)
logger.debug(
"工作流试运行返回结果",
extra={

View File

@@ -0,0 +1,207 @@
# -*- coding: utf-8 -*-
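"""Emotion configuration controller.

Provides the API endpoints for managing the emotion-engine configuration:
reading the current settings and updating them.

Routes (relative to wherever this router is mounted):
    GET  /memory/emotion/read_config    - fetch the emotion-engine configuration
    POST /memory/emotion/updated_config - update the emotion-engine configuration
"""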
from fastapi import APIRouter, Depends, Query, HTTPException, status
from pydantic import BaseModel, Field
from typing import Optional
from sqlalchemy.orm import Session
from app.core.response_utils import success
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.response_schema import ApiResponse
from app.services.emotion_config_service import EmotionConfigService
from app.core.logging_config import get_api_logger
from app.db import get_db
# Logger dedicated to API-layer events.
api_logger = get_api_logger()
# Router for the emotion-configuration endpoints; all routes below are
# registered under the "/memory/emotion" prefix.
router = APIRouter(
    prefix="/memory/emotion",
    tags=["Emotion Config"],
    dependencies=[Depends(get_current_user)]  # every route requires authentication
)
class EmotionConfigQuery(BaseModel):
    """Request model identifying which emotion configuration to look up."""

    # ID of the configuration record; required.
    config_id: int = Field(
        default=...,
        description="配置ID",
    )
class EmotionConfigUpdate(BaseModel):
    """Request model carrying the full set of emotion settings to store.

    Every field except ``emotion_model_id`` is required.
    """

    # ID of the configuration record to update.
    config_id: int = Field(default=..., description="配置ID")
    # Whether emotion extraction is switched on.
    emotion_enabled: bool = Field(default=..., description="是否启用情绪提取")
    # Optional ID of a model dedicated to emotion analysis.
    emotion_model_id: Optional[str] = Field(default=None, description="情绪分析专用模型ID")
    # Whether emotion keywords should be extracted.
    emotion_extract_keywords: bool = Field(default=..., description="是否提取情绪关键词")
    # Minimum emotion intensity threshold, constrained to the range [0.0, 1.0].
    emotion_min_intensity: float = Field(
        default=...,
        ge=0.0,
        le=1.0,
        description="最小情绪强度阈值0.0-1.0",
    )
    # Whether subject classification is enabled.
    emotion_enable_subject: bool = Field(default=..., description="是否启用主体分类")
@router.get("/read_config", response_model=ApiResponse)
def get_emotion_config(
    config_id: int = Query(..., description="配置ID"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the emotion-related settings stored for one configuration.

    Args:
        config_id: ID of the configuration to read (query parameter).
        db: Database session injected by FastAPI.
        current_user: Authenticated user; only its username is used, for logging.

    Returns:
        The ``success(...)`` envelope wrapping the data returned by
        ``EmotionConfigService.get_emotion_config``.

    Raises:
        HTTPException: 404 when a ``ValueError`` is raised while handling the
            request; 500 for any other exception.

    Example response:
        {
            "code": 2000,
            "msg": "情绪配置获取成功",
            "data": {
                "config_id": 17,
                "emotion_enabled": true,
                "emotion_model_id": "gpt-4",
                "emotion_extract_keywords": true,
                "emotion_min_intensity": 0.1,
                "emotion_enable_subject": true
            }
        }
    """
    try:
        api_logger.info(
            f"用户 {current_user.username} 请求获取情绪配置",
            extra={"config_id": config_id},
        )

        # Delegate the lookup to the service layer.
        emotion_config = EmotionConfigService(db).get_emotion_config(config_id)

        success_context = {
            "config_id": config_id,
            "emotion_enabled": emotion_config.get("emotion_enabled", False),
        }
        api_logger.info("情绪配置获取成功", extra=success_context)
        return success(data=emotion_config, msg="情绪配置获取成功")
    except ValueError as e:
        # The service signals problems with ValueError; reported to the client as 404.
        reason = str(e)
        api_logger.warning(
            f"获取情绪配置失败: {reason}",
            extra={"config_id": config_id},
        )
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=reason)
    except Exception as e:
        failure_message = f"获取情绪配置失败: {str(e)}"
        api_logger.error(
            failure_message,
            extra={"config_id": config_id},
            exc_info=True,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=failure_message,
        )
@router.post("/updated_config", response_model=ApiResponse)
def update_emotion_config(
    config: EmotionConfigUpdate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Update the emotion-related settings of one configuration.

    Args:
        config: New settings, including the ``config_id`` of the record to update.
        db: Database session injected by FastAPI.
        current_user: Authenticated user; only its username is used, for logging.

    Returns:
        The ``success(...)`` envelope wrapping the data returned by
        ``EmotionConfigService.update_emotion_config``.

    Raises:
        HTTPException: 400 when a ``ValueError`` is raised while handling the
            request; 500 for any other exception.

    Example request:
        {
            "config_id": 2,
            "emotion_enabled": true,
            "emotion_model_id": "gpt-4",
            "emotion_extract_keywords": true,
            "emotion_min_intensity": 0.1,
            "emotion_enable_subject": true
        }

    Example response:
        {
            "code": 2000,
            "msg": "情绪配置更新成功",
            "data": {
                "config_id": 17,
                "emotion_enabled": true,
                "emotion_model_id": "gpt-4",
                "emotion_extract_keywords": true,
                "emotion_min_intensity": 0.2,
                "emotion_enable_subject": true
            }
        }
    """
    try:
        request_context = {
            "config_id": config.config_id,
            "emotion_enabled": config.emotion_enabled,
            "emotion_min_intensity": config.emotion_min_intensity,
        }
        api_logger.info(
            f"用户 {current_user.username} 请求更新情绪配置",
            extra=request_context,
        )

        service = EmotionConfigService(db)
        # config_id travels as a separate argument, so it is left out of the payload.
        payload = config.model_dump(exclude={'config_id'})
        updated = service.update_emotion_config(config.config_id, payload)

        api_logger.info(
            "情绪配置更新成功",
            extra={
                "config_id": config.config_id,
                "emotion_enabled": updated.get("emotion_enabled", False),
            },
        )
        return success(data=updated, msg="情绪配置更新成功")
    except ValueError as e:
        reason = str(e)
        api_logger.warning(
            f"更新情绪配置失败: {reason}",
            extra={"config_id": config.config_id},
        )
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=reason)
    except Exception as e:
        failure_message = f"更新情绪配置失败: {str(e)}"
        api_logger.error(
            failure_message,
            extra={"config_id": config.config_id},
            exc_info=True,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=failure_message,
        )

View File

@@ -0,0 +1,255 @@
# -*- coding: utf-8 -*-
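"""Emotion analytics controller.

Provides the emotion-analysis API endpoints: emotion tag statistics, word-cloud
data, the emotion health index and personalised suggestions.

Routes (relative to wherever this router is mounted):
    POST /memory/emotion/tags        - emotion tag statistics
    POST /memory/emotion/wordcloud   - emotion word-cloud data
    POST /memory/emotion/health      - emotion health index
    POST /memory/emotion/suggestions - personalised emotion suggestions
"""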
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from app.core.response_utils import success, fail
from app.core.error_codes import BizCode
from app.dependencies import get_current_user, get_db
from app.models.user_model import User
from app.schemas.response_schema import ApiResponse
from app.schemas.emotion_schema import (
EmotionTagsRequest,
EmotionWordcloudRequest,
EmotionHealthRequest,
EmotionSuggestionsRequest
)
from app.services.emotion_analytics_service import EmotionAnalyticsService
from app.core.logging_config import get_api_logger
# Logger dedicated to API-layer events.
api_logger = get_api_logger()
# Router for the emotion-analytics endpoints; all routes below are registered
# under the "/memory/emotion" prefix.
router = APIRouter(
    prefix="/memory/emotion",
    tags=["Emotion Analysis"],
    dependencies=[Depends(get_current_user)]  # every route requires authentication
)
# Single module-level emotion analytics service instance, shared by every
# handler in this module.
emotion_service = EmotionAnalyticsService()
@router.post("/tags", response_model=ApiResponse)
async def get_emotion_tags(
    request: EmotionTagsRequest,
    current_user: User = Depends(get_current_user),
):
    """Return emotion tag statistics for the requested group.

    ``request.group_id`` is forwarded to the service as ``end_user_id``; the
    remaining filters (emotion type, date range, limit) are passed through
    unchanged.

    Raises:
        HTTPException: 500 when anything fails while handling the request.
    """
    try:
        filters = {
            "emotion_type": request.emotion_type,
            "start_date": request.start_date,
            "end_date": request.end_date,
            "limit": request.limit,
        }
        api_logger.info(
            f"用户 {current_user.username} 请求获取情绪标签统计",
            extra={"group_id": request.group_id, **filters},
        )

        # Delegate to the service layer.
        stats = await emotion_service.get_emotion_tags(
            end_user_id=request.group_id,
            **filters,
        )

        api_logger.info(
            "情绪标签统计获取成功",
            extra={
                "group_id": request.group_id,
                "total_count": stats.get("total_count", 0),
                "tags_count": len(stats.get("tags", [])),
            },
        )
        return success(data=stats, msg="情绪标签获取成功")
    except Exception as e:
        failure_message = f"获取情绪标签统计失败: {str(e)}"
        api_logger.error(
            failure_message,
            extra={"group_id": request.group_id},
            exc_info=True,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=failure_message,
        )
@router.post("/wordcloud", response_model=ApiResponse)
async def get_emotion_wordcloud(
    request: EmotionWordcloudRequest,
    current_user: User = Depends(get_current_user),
):
    """Return emotion word-cloud data for the requested group.

    ``request.group_id`` is forwarded to the service as ``end_user_id``.

    Raises:
        HTTPException: 500 when anything fails while handling the request.
    """
    try:
        api_logger.info(
            f"用户 {current_user.username} 请求获取情绪词云数据",
            extra={
                "group_id": request.group_id,
                "emotion_type": request.emotion_type,
                "limit": request.limit,
            },
        )

        # Delegate to the service layer.
        wordcloud = await emotion_service.get_emotion_wordcloud(
            end_user_id=request.group_id,
            emotion_type=request.emotion_type,
            limit=request.limit,
        )

        summary = {
            "group_id": request.group_id,
            "total_keywords": wordcloud.get("total_keywords", 0),
        }
        api_logger.info("情绪词云数据获取成功", extra=summary)
        return success(data=wordcloud, msg="情绪词云获取成功")
    except Exception as e:
        failure_message = f"获取情绪词云数据失败: {str(e)}"
        api_logger.error(
            failure_message,
            extra={"group_id": request.group_id},
            exc_info=True,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=failure_message,
        )
@router.post("/health", response_model=ApiResponse)
async def get_emotion_health(
    request: EmotionHealthRequest,
    current_user: User = Depends(get_current_user),
):
    """Return the emotion health index for the requested group and time range.

    ``request.group_id`` is forwarded to the service as ``end_user_id``.

    Raises:
        HTTPException: 400 when ``time_range`` is not one of 7d/30d/90d;
            500 when anything else fails while handling the request.
    """
    supported_ranges = ("7d", "30d", "90d")
    try:
        # Reject unsupported time ranges before doing any work.
        if request.time_range not in supported_ranges:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="时间范围参数无效,必须是 7d、30d 或 90d",
            )

        api_logger.info(
            f"用户 {current_user.username} 请求获取情绪健康指数",
            extra={
                "group_id": request.group_id,
                "time_range": request.time_range,
            },
        )

        # Delegate to the service layer.
        health = await emotion_service.calculate_emotion_health_index(
            end_user_id=request.group_id,
            time_range=request.time_range,
        )

        api_logger.info(
            "情绪健康指数获取成功",
            extra={
                "group_id": request.group_id,
                "health_score": health.get("health_score", 0),
                "level": health.get("level", "未知"),
            },
        )
        return success(data=health, msg="情绪健康指数获取成功")
    except HTTPException:
        # Keep the 400 raised above from being rewrapped as a 500 below.
        raise
    except Exception as e:
        failure_message = f"获取情绪健康指数失败: {str(e)}"
        api_logger.error(
            failure_message,
            extra={"group_id": request.group_id},
            exc_info=True,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=failure_message,
        )
@router.post("/suggestions", response_model=ApiResponse)
async def get_emotion_suggestions(
    request: EmotionSuggestionsRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return personalised emotion suggestions for the requested group.

    Args:
        request: Carries ``group_id`` and an optional ``config_id``.
        db: Database session, used only to validate ``config_id``.
        current_user: Authenticated user; only its username is used, for logging.

    Returns:
        The ``success(...)`` envelope with the generated suggestions, or a
        ``fail(...)`` envelope with ``BizCode.INVALID_PARAMETER`` when the
        supplied ``config_id`` fails validation.

    Raises:
        HTTPException: 500 when anything else fails while handling the request.
    """
    try:
        config_id = request.config_id
        # Validate config_id only when the caller supplied one.
        if config_id is not None:
            # Imported here rather than at module level, as in the original code.
            from app.controllers.memory_agent_controller import validate_config_id
            try:
                config_id = validate_config_id(config_id, db)
            except ValueError as e:
                return fail(BizCode.INVALID_PARAMETER, "配置ID无效", str(e))

        request_context = {
            "group_id": request.group_id,
            "config_id": config_id,
        }
        api_logger.info(
            f"用户 {current_user.username} 请求获取个性化情绪建议",
            extra=request_context,
        )

        # Delegate to the service layer.
        suggestions = await emotion_service.generate_emotion_suggestions(
            end_user_id=request.group_id,
            config_id=config_id,
        )

        api_logger.info(
            "个性化建议获取成功",
            extra={
                "group_id": request.group_id,
                "suggestions_count": len(suggestions.get("suggestions", [])),
            },
        )
        return success(data=suggestions, msg="个性化建议获取成功")
    except Exception as e:
        failure_message = f"获取个性化建议失败: {str(e)}"
        api_logger.error(
            failure_message,
            extra={"group_id": request.group_id},
            exc_info=True,
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=failure_message,
        )

View File

@@ -0,0 +1,269 @@
import asyncio
import time
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from sqlalchemy import text
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.core.memory.storage_services.reflection_engine.self_reflexion import ReflectionConfig, ReflectionEngine
from app.dependencies import get_current_user
from app.db import get_db
from app.models.user_model import User
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_reflection_service import WorkspaceAppService, MemoryReflectionService
from app.schemas.memory_reflection_schemas import Memory_Reflection
from app.services.model_service import ModelConfigService
# Load variables from a .env file (python-dotenv) at import time.
load_dotenv()
# Logger dedicated to API-layer events.
api_logger = get_api_logger()
# Router for the memory reflection endpoints, registered under "/memory".
router = APIRouter(
    prefix="/memory",
    tags=["Memory"],
)
@router.post("/reflection/save")
async def save_reflection_config(
    request: Memory_Reflection,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> dict:
    """Save reflection settings to the data_config table and return the stored row.

    The update statement and the follow-up select are both built by
    ``DataConfigRepository``; the update is committed before the row is re-read.

    Args:
        request: Reflection settings, including the ``config_id`` to update.
        current_user: Authenticated user; only its username is used, for logging.
        db: Database session injected by FastAPI.

    Returns:
        The ``success(...)`` envelope containing the reflection fields read
        back after the update.

    Raises:
        HTTPException: 400 when ``config_id`` is missing or a ``ValueError`` is
            raised; 404 when no row matches ``config_id``; 500 for any other
            failure.
    """
    try:
        config_id = request.config_id
        if not config_id:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="缺少必需参数: config_id"
            )
        api_logger.info(f"用户 {current_user.username} 保存反思配置config_id: {config_id}")

        # Map request field names onto the column names expected by the repository.
        update_params = {
            "enable_self_reflexion": request.reflection_enabled,
            "iteration_period": request.reflection_period_in_hours,
            "reflexion_range": request.reflexion_range,
            "baseline": request.baseline,
            "reflection_model_id": request.reflection_model_id,
            "memory_verify": request.memory_verify,
            "quality_assessment": request.quality_assessment,
        }
        query, params = DataConfigRepository.build_update_reflection(config_id, **update_params)
        result = db.execute(text(query), params)
        if result.rowcount == 0:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"未找到config_id为 {config_id} 的配置"
            )
        db.commit()

        # Re-read the row so the response reflects what is actually stored.
        select_query, select_params = DataConfigRepository.build_select_reflection(config_id)
        result = db.execute(text(select_query), select_params).fetchone()
        if not result:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"更新后未找到config_id为 {config_id} 的配置"
            )
        api_logger.info(f"成功保存反思配置到数据库config_id: {config_id}")
        reflection_result = {
            "config_id": result.config_id,
            "enable_self_reflexion": result.enable_self_reflexion,
            "iteration_period": result.iteration_period,
            "reflexion_range": result.reflexion_range,
            "baseline": result.baseline,
            "reflection_model_id": result.reflection_model_id,
            "memory_verify": result.memory_verify,
            "quality_assessment": result.quality_assessment,
            "user_id": result.user_id,
        }
        return success(data=reflection_result, msg="反思配置成功")
    except HTTPException:
        # The 400/404 responses raised above are deliberate; without this
        # clause the generic handler below would rewrap them as 500 errors.
        raise
    except ValueError as ve:
        api_logger.error(f"参数错误: {str(ve)}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"参数错误: {str(ve)}"
        )
    except Exception as e:
        api_logger.error(f"反思配置保存失败: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"反思配置保存失败: {str(e)}"
        )
@router.post("/reflection")
async def start_workspace_reflection(
    config_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> dict:
    """Start reflection for every matching release/config/end-user in the workspace.

    The workspace is taken from ``current_user.current_workspace_id``. For each
    app that has data configs, releases, configs and end users are walked in
    lockstep and reflection is started for each aligned triple whose config IDs
    and app IDs match.

    Args:
        config_id: Required by the route signature but not used by this handler.
        current_user: Authenticated user supplying the workspace ID.
        db: Database session injected by FastAPI.

    Returns:
        The ``success(...)`` envelope with one entry per reflection that was started.

    Raises:
        HTTPException: 500 when anything fails while handling the request.
    """
    workspace_id = current_user.current_workspace_id
    reflection_service = MemoryReflectionService(db)
    try:
        api_logger.info(f"用户 {current_user.username} 启动workspace反思workspace_id: {workspace_id}")
        apps_overview = WorkspaceAppService(db).get_workspace_apps_detailed(workspace_id)

        started = []
        for app_info in apps_overview['apps_detailed_info']:
            # Apps without any data config have nothing to reflect on.
            if app_info['data_configs'] == []:
                continue

            # NOTE(review): zip pairs the three lists by position, so only
            # entries at the same index are compared — confirm this is intended.
            triples = zip(app_info['releases'], app_info['data_configs'], app_info['end_users'])
            for release, data_config, end_user in triples:
                config_matches = int(release['config']) == int(data_config['config_id'])
                if not (config_matches and release['app_id'] == end_user['app_id']):
                    continue

                api_logger.info(f"为用户 {end_user['id']} 启动反思config_id: {data_config['config_id']}")
                outcome = await reflection_service.start_reflection_from_data(
                    config_data=data_config,
                    end_user_id=end_user['id'],
                )
                started.append({
                    "app_id": release['app_id'],
                    "config_id": data_config['config_id'],
                    "end_user_id": end_user['id'],
                    "reflection_result": outcome,
                })
        return success(data=started, msg="反思配置成功")
    except Exception as e:
        failure_message = f"启动workspace反思失败: {str(e)}"
        api_logger.error(failure_message)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=failure_message,
        )
@router.get("/reflection/configs")
async def start_reflection_configs(
    config_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> dict:
    """Look up the reflection settings stored in data_config for ``config_id``.

    Args:
        config_id: ID of the configuration row to read.
        current_user: Authenticated user; only its username is used, for logging.
        db: Database session injected by FastAPI.

    Returns:
        The ``success(...)`` envelope with the reflection fields of the row.

    Raises:
        HTTPException: 404 when no row matches ``config_id``; 500 for any
            other failure.
    """
    try:
        api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}")

        # The select statement is built by the repository and run as raw SQL.
        sql, sql_params = DataConfigRepository.build_select_reflection(config_id)
        row = db.execute(text(sql), sql_params).fetchone()
        if not row:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"未找到config_id为 {config_id} 的配置",
            )

        # Column names are translated to the field names used by the API.
        payload = dict(
            config_id=row.config_id,
            reflection_enabled=row.enable_self_reflexion,
            reflection_period_in_hours=row.iteration_period,
            reflexion_range=row.reflexion_range,
            baseline=row.baseline,
            reflection_model_id=row.reflection_model_id,
            memory_verify=row.memory_verify,
            quality_assessment=row.quality_assessment,
            user_id=row.user_id,
        )
        api_logger.info(f"成功查询反思配置config_id: {config_id}")
        return success(data=payload, msg="反思配置查询成功")
    except HTTPException:
        # Pass the 404 above through unchanged.
        raise
    except Exception as e:
        failure_message = f"查询反思配置失败: {str(e)}"
        api_logger.error(failure_message)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=failure_message,
        )
@router.get("/reflection/run")
async def reflection_run(
    config_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> dict:
    """Run the reflection engine once using the settings stored for ``config_id``.

    Loads the reflection settings row, checks that the configured model ID can
    be resolved (falling back to ``None`` when the lookup fails), builds a
    ``ReflectionEngine`` and awaits its ``reflection_run()``.

    Args:
        config_id: ID of the configuration row providing the reflection settings.
        current_user: Authenticated user; only its username is used, for logging.
        db: Database session injected by FastAPI.

    Returns:
        The ``success(...)`` envelope wrapping the engine's result.

    Raises:
        HTTPException: 404 when no row matches ``config_id``.
    """
    api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}")

    # The select statement is built by the repository and run as raw SQL.
    sql, sql_params = DataConfigRepository.build_select_reflection(config_id)
    row = db.execute(text(sql), sql_params).fetchone()
    if not row:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"未找到config_id为 {config_id} 的配置",
        )
    api_logger.info(f"成功查询反思配置config_id: {config_id}")

    # Verify the configured model; any lookup failure downgrades to None.
    model_id = row.reflection_model_id
    if model_id:
        try:
            ModelConfigService.get_model_by_id(db=db, model_id=model_id)
            api_logger.info(f"模型ID验证成功: {model_id}")
        except Exception as e:
            api_logger.warning(f"模型ID '{model_id}' 不存在,将使用默认模型: {str(e)}")
            # presumably None makes the engine fall back to its default model — TODO confirm
            model_id = None

    engine_config = ReflectionConfig(
        enabled=row.enable_self_reflexion,
        iteration_period=row.iteration_period,
        reflexion_range=row.reflexion_range,
        baseline=row.baseline,
        output_example='',
        memory_verify=row.memory_verify,
        quality_assessment=row.quality_assessment,
        violation_handling_strategy="block",
        model_id=model_id,
    )
    # NOTE(review): the connector is never explicitly closed here — confirm
    # whether Neo4jConnector needs cleanup.
    engine = ReflectionEngine(
        config=engine_config,
        neo4j_connector=Neo4jConnector(),
        llm_client=model_id,  # the verified model ID (or None) is passed as llm_client
    )
    outcome = await engine.reflection_run()
    return success(data=outcome, msg="反思试运行")

View File

@@ -1,13 +1,9 @@
from fastapi import APIRouter, Depends, status, Query
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from sqlalchemy.orm import Session
from typing import List, Optional
from typing import Optional
import uuid
from app.core.models import RedBearLLM
from app.core.models.base import RedBearModelConfig
from app.db import get_db
from app.dependencies import get_current_user
from app.models.models_model import ModelProvider, ModelType
@@ -39,7 +35,7 @@ def get_model_providers():
@router.get("", response_model=ApiResponse)
def get_model_list(
type: Optional[List[model_schema.ModelType]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM&type=EMBEDDING"),
type: Optional[str] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING"),
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
@@ -54,13 +50,21 @@ def get_model_list(
支持多个 type 参数:
- 单个:?type=LLM
- 多个:?type=LLM&type=EMBEDDING
- 多个(逗号分隔)?type=LLM,EMBEDDING
- 多个(重复参数):?type=LLM&type=EMBEDDING
"""
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
try:
# 解析 type 参数(支持逗号分隔)
type_list = None
if type:
type_values = [t.strip() for t in type.split(',')]
type_list = [model_schema.ModelType(t.lower()) for t in type_values if t]
api_logger.error(f"获取模型type_list: {type_list}")
query = model_schema.ModelConfigQuery(
type=type,
type=type_list,
provider=provider,
is_active=is_active,
is_public=is_public,

View File

@@ -0,0 +1,138 @@
import uuid
from fastapi import APIRouter, Depends, Path
from sqlalchemy.orm import Session
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.dependencies import get_current_user, get_db
from app.models.prompt_optimizer_model import RoleType
from app.schemas.prompt_optimizer_schema import PromptOptMessage, PromptOptModelSet, CreateSessionResponse, \
OptimizePromptResponse, SessionHistoryResponse, SessionMessage
from app.schemas.response_schema import ApiResponse
from app.services.prompt_optimizer_service import PromptOptimizerService
# Router for the prompt-optimization endpoints, registered under "/prompt".
router = APIRouter(prefix="/prompt", tags=["Prompts-Optimization"])
# Logger dedicated to API-layer events.
logger = get_api_logger()
@router.post(
    "/sessions",
    summary="Create a new prompt optimization session",
    response_model=ApiResponse
)
def create_prompt_session(
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """
    Open a new prompt optimization session owned by the current user.

    Args:
        db (Session): Database session
        current_user: Current logged-in user; supplies tenant ID and user ID

    Returns:
        ApiResponse: Wraps the new session validated as ``CreateSessionResponse``.
    """
    new_session = PromptOptimizerService(db).create_session(
        current_user.tenant_id,
        current_user.id,
    )
    return success(data=CreateSessionResponse.model_validate(new_session))
@router.get(
    "/sessions/{session_id}",
    summary="获取 prompt 优化历史对话",
    response_model=ApiResponse
)
def get_prompt_session(
    session_id: uuid.UUID = Path(..., description="Session ID"),
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """
    Return the message history of one prompt optimization session.

    Args:
        session_id (UUID): The ID of the session to retrieve
        db (Session): Database session
        current_user: Current logged-in user; the history is fetched for this user's ID

    Returns:
        ApiResponse: Wraps a ``SessionHistoryResponse`` holding the session ID
        and its messages.
    """
    optimizer = PromptOptimizerService(db)
    # The service yields (role, content) pairs.
    role_content_pairs = optimizer.get_session_message_history(
        session_id=session_id,
        user_id=current_user.id
    )

    session_messages = []
    for role, content in role_content_pairs:
        session_messages.append(SessionMessage(role=role, content=content))

    return success(
        data=SessionHistoryResponse(
            session_id=session_id,
            messages=session_messages,
        )
    )
@router.post(
    "/sessions/{session_id}/messages",
    summary="Get prompt optimization",
    response_model=ApiResponse
)
async def get_prompt_opt(
    session_id: uuid.UUID = Path(..., description="Session ID"),
    data: PromptOptMessage = ...,
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user),
):
    """
    Record the user's message, optimize the prompt and return the result.

    The user message is stored first, then the optimizer runs, then the
    optimizer's description is stored as the assistant message.

    Args:
        session_id (UUID): The session ID
        data (PromptOptMessage): Contains the user message, model ID, and current prompt
        db (Session): Database session
        current_user: Current user information

    Returns:
        ApiResponse: Wraps an ``OptimizePromptResponse`` with the optimized
        prompt, its description, and the variables parsed from the prompt.
    """
    optimizer = PromptOptimizerService(db)
    # Fields shared by both message records written below.
    message_scope = {
        "tenant_id": current_user.tenant_id,
        "session_id": session_id,
        "user_id": current_user.id,
    }

    optimizer.create_message(role=RoleType.USER, content=data.message, **message_scope)

    optimized = await optimizer.optimize_prompt(
        tenant_id=current_user.tenant_id,
        model_id=data.model_id,
        session_id=session_id,
        user_id=current_user.id,
        current_prompt=data.current_prompt,
        user_require=data.message
    )

    optimizer.create_message(role=RoleType.ASSISTANT, content=optimized.desc, **message_scope)

    response_model = OptimizePromptResponse.model_validate({
        "prompt": optimized.prompt,
        "desc": optimized.desc,
        "variables": optimizer.parser_prompt_variables(optimized.prompt),
    })
    return success(data=response_model)

View File

@@ -1,10 +1,14 @@
"""Memory 服务接口 - 基于 API Key 认证"""
from fastapi import APIRouter, Depends
import uuid
from fastapi import APIRouter, Depends, Request, Body
from sqlalchemy.orm import Session
from app.db import get_db
from app.core.response_utils import success
from app.core.logging_config import get_business_logger
from app.core.api_key_auth import require_api_key
from app.schemas.api_key_schema import ApiKeyAuth
router = APIRouter(prefix="/memory", tags=["V1 - Memory API"])
logger = get_business_logger()
@@ -14,3 +18,31 @@ logger = get_business_logger()
async def get_memory_info():
"""获取记忆服务信息(占位)"""
return success(data={}, msg="Memory API - Coming Soon")
# /v1/memory/{resource_id}/chat
@router.post("/{resource_id}/chat")
@require_api_key(scopes=["memory"])
async def chat_with_agent_demo(
    resource_id: uuid.UUID,
    request: Request,
    api_key_auth: ApiKeyAuth = None,
    db: Session = Depends(get_db),
    message: str = Body(..., description="聊天消息内容"),
):
    """
    Agent chat endpoint (demo): logs what it received and acknowledges it.

    Possible API-key scopes: ["app", "rag", "memory"]; this route requires "memory".

    Args:
        resource_id: For an application API key, the application ID; for a
            service API key, the workspace ID
        message: Chat message content from the request body
        request: The incoming request object
        api_key_auth: Information about the verified API key
        db: db_session
    """
    received = (
        ("API Key Auth", api_key_auth),
        ("Resource ID", resource_id),
        ("Message", message),
    )
    for label, value in received:
        logger.info(f"{label}: {value}")
    return success(data={"received": True}, msg="消息已接收")

View File

@@ -0,0 +1,585 @@
"""工具管理API控制器"""
import base64
from typing import List, Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, Body
from langfuse.api.core import jsonable_encoder
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from pydantic import BaseModel, Field, PositiveInt, field_validator
from cryptography.fernet import Fernet
from app.db import get_db
from app.dependencies import get_current_user
from app.models import User
from app.models.tool_model import ToolConfig, BuiltinToolConfig, ToolType, ToolStatus, CustomToolConfig, MCPToolConfig
from app.core.logging_config import get_business_logger
from app.core.config import settings
from app.core.tools.config_manager import ConfigManager
# Business-level logger for this module.
logger = get_business_logger()
# Router for the tool-management endpoints, registered under "/tools".
router = APIRouter(prefix="/tools", tags=["工具管理"])
# ==================== 辅助函数 ====================
def _encrypt_sensitive_params(parameters: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``parameters`` with sensitive values Fernet-encrypted.

    A value is encrypted when it is truthy and its key contains, case
    insensitively, one of the sensitive markers. The Fernet key is derived from
    the first 32 characters of ``settings.SECRET_KEY``, right-padded with '0'.

    Args:
        parameters: Tool parameters keyed by name.

    Returns:
        A new dict; sensitive values are replaced by their encrypted token as
        ``str``, all other values are passed through unchanged.
    """
    cipher_key = base64.urlsafe_b64encode(settings.SECRET_KEY[:32].ljust(32, '0').encode())
    cipher = Fernet(cipher_key)

    # Uses 'secret' (not just 'api_secret') so this list matches the markers
    # _decrypt_sensitive_params looks for; otherwise keys such as
    # 'client_secret' would be stored in plaintext while decrypt treats them
    # as encrypted.
    sensitive_keys = ('api_key', 'token', 'secret', 'password')

    def _is_sensitive(key: str) -> bool:
        lowered = key.lower()
        return any(marker in lowered for marker in sensitive_keys)

    return {
        key: cipher.encrypt(str(value).encode()).decode()
        if _is_sensitive(key) and value
        else value
        for key, value in parameters.items()
    }
def _decrypt_sensitive_params(parameters: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``parameters`` with sensitive values Fernet-decrypted.

    A value is decrypted when it is truthy and its key contains, case
    insensitively, one of the sensitive markers. If decryption fails for any
    reason the stored value is returned unchanged.

    Args:
        parameters: Tool parameters keyed by name.

    Returns:
        A new dict with decrypted values where decryption succeeded.
    """
    key_material = settings.SECRET_KEY[:32].ljust(32, '0').encode()
    fernet = Fernet(base64.urlsafe_b64encode(key_material))
    markers = ['api_key', 'token', 'secret', 'password']

    result = {}
    for name, stored in parameters.items():
        looks_sensitive = any(marker in name.lower() for marker in markers)
        if not (looks_sensitive and stored):
            result[name] = stored
            continue
        try:
            result[name] = fernet.decrypt(stored.encode()).decode()
        except Exception:
            # Best effort: keep the stored value when it cannot be decrypted.
            result[name] = stored
    return result
def _update_tool_status(tool_config: ToolConfig, builtin_config: BuiltinToolConfig = None, tool_info: Dict = None) -> str:
    """Work out a tool's status, store it on ``tool_config`` and return it.

    Rules:
        * Built-in tool that needs no configuration (``tool_info`` missing or
          ``requires_config`` falsy): ACTIVE.
        * Built-in tool needing configuration but without parameters: INACTIVE.
        * Built-in tool with parameters: ACTIVE when ``api_key`` or ``token``
          is set, otherwise INACTIVE.
        * Any other tool type (custom / MCP): ACTIVE when ``config_data`` is
          truthy, otherwise ERROR.

    Args:
        tool_config: Tool record; its ``status`` attribute is updated in place
            when the computed status differs.
        builtin_config: Built-in tool configuration, if any.
        tool_info: Built-in tool metadata, if any.

    Returns:
        The computed status value.
    """
    if not (tool_config.tool_type == ToolType.BUILTIN):
        # Custom and MCP tools.
        status_value = ToolStatus.ACTIVE.value if tool_config.config_data else ToolStatus.ERROR.value
    elif not tool_info or not tool_info.get('requires_config', False):
        # Built-in tool that works without configuration.
        status_value = ToolStatus.ACTIVE.value
    elif not builtin_config or not builtin_config.parameters:
        status_value = ToolStatus.INACTIVE.value
    else:
        # A configured built-in tool needs a credential to be usable.
        credential = builtin_config.parameters.get('api_key') or builtin_config.parameters.get('token')
        if bool(credential):
            status_value = ToolStatus.ACTIVE.value
        else:
            status_value = ToolStatus.INACTIVE.value

    # Only assign when the value actually changes.
    if tool_config.status != status_value:
        tool_config.status = status_value
    return status_value
# ==================== 请求/响应模型 ====================
class ToolListResponse(BaseModel):
    """Response item describing one tool in the tool list."""

    # Pydantic v2 configuration; replaces the deprecated nested ``class Config``.
    # ``from_attributes`` lets the model be built from attribute-bearing objects.
    model_config = {"from_attributes": True}

    id: str
    name: str
    description: str
    tool_type: str
    category: str
    version: str = "1.0.0"
    status: str  # one of: active, inactive, error, loading
    requires_config: bool = False
    # is_configured: bool = False
class BuiltinToolConfigRequest(BaseModel):
    """Request body for configuring a built-in tool."""

    # Tool parameters keyed by name; defaults to a fresh empty dict.
    parameters: Dict[str, Any] = Field(
        default_factory=dict,
        description="工具参数",
    )
class CustomToolCreateRequest(BaseModel):
    """Request body for creating a custom tool, including validation rules."""

    # Tool name; required, 1-100 characters.
    name: str = Field(..., min_length=1, max_length=100, description="工具名称,必填")
    # The next three fields default to None, so they are typed Optional to
    # also accept an explicit null in the request body.
    description: Optional[str] = Field(None, description="工具描述")
    base_url: Optional[str] = Field(None, description="工具基础URL")
    schema_url: Optional[str] = Field(None, description="工具Schema URL")
    schema_content: Optional[Dict[str, Any]] = Field(None, description="工具Schema内容可选")
    # Declared before auth_config so the validator below can read it.
    auth_type: str = Field("none", pattern=r"^(none|api_key|bearer_token)$", description="认证类型")
    # validate_default=True makes the validator below run even when
    # auth_config is omitted; pydantic v2 otherwise skips validators for
    # defaulted fields, which let auth_type="api_key" through with no credentials.
    auth_config: Optional[Dict[str, Any]] = Field(
        None,
        validate_default=True,
        description="认证配置,默认空字典",
    )
    # Timeout in seconds, 1-300, default 30.
    timeout: PositiveInt = Field(30, ge=1, le=300, description="超时时间1-300秒默认30")

    @field_validator("auth_config")
    @classmethod
    def validate_auth_config(cls, v, values):
        """Require the credential entry that matches the selected ``auth_type``.

        Raises:
            ValueError: If ``auth_type`` is ``api_key`` or ``bearer_token`` and
                ``auth_config`` lacks the entry of the same name.
        """
        # ``values`` is pydantic's validation info; ``.data`` holds the fields
        # validated so far. auth_type is absent there if it failed validation.
        auth_type = values.data.get("auth_type")
        if auth_type == "api_key" and (not v or "api_key" not in v):
            raise ValueError("认证类型为api_key时auth_config必须包含api_key字段")
        if auth_type == "bearer_token" and (not v or "bearer_token" not in v):
            raise ValueError("认证类型为bearer_token时auth_config必须包含bearer_token字段")
        return v
class MCPToolCreateRequest(BaseModel):
    """Request body for creating an MCP tool."""

    # Required basic field with length validation.
    name: str = Field(..., min_length=1, max_length=100, description="MCP工具名称")
    # Defaults to None, so the annotation is Optional.
    description: Optional[str] = Field(None, description="MCP工具描述")
    # MCP server URL. NOTE(review): the description says only http/https is
    # supported, but no scheme validation is performed in this model.
    server_url: str = Field(..., description="MCP服务端URL仅支持http/https协议")
    # Connection settings; default_factory matches the other request models
    # and avoids a shared mutable literal as default.
    connection_config: Dict[str, Any] = Field(
        default_factory=dict,
        description="MCP连接配置如认证信息、超时、重试等默认空字典",
    )

    @field_validator("connection_config")
    @classmethod
    def validate_connection_config(cls, v):
        """Validate optional entries of ``connection_config``.

        If ``timeout`` is present it must be an integer between 1 and 300.

        Raises:
            ValueError: when ``timeout`` is not an integer in [1, 300].
        """
        if "timeout" in v:
            timeout = v["timeout"]
            # bool is a subclass of int; reject it explicitly so True is not read as 1.
            is_int = isinstance(timeout, int) and not isinstance(timeout, bool)
            if not is_int or timeout < 1 or timeout > 300:
                raise ValueError("connection_config.timeout必须是1-300的整数")
        return v
# @field_validator("server_url")
# def validate_server_url_protocol(cls, v):
# if v.scheme != "https":
# raise ValueError("MCP服务端URL仅支持HTTPS协议安全要求")
# return v
# ==================== API端点 ====================
@router.get("", response_model=List[ToolListResponse])
async def list_tools(
    name: Optional[str] = None,
    tool_type: Optional[str] = None,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """List the current tenant's tools (built-in, custom and MCP).

    Args:
        name: Optional case-insensitive substring filter on the tool name.
        tool_type: Optional exact filter on the tool type.

    Returns:
        A list of ``ToolListResponse`` entries with a freshly computed status.

    Raises:
        HTTPException: 500 on any unexpected failure.
    """
    try:
        # Make sure the tenant's built-in tools exist before querying.
        config_manager = ConfigManager()
        config_manager.ensure_builtin_tools_initialized(
            current_user.tenant_id, db, ToolConfig, BuiltinToolConfig, ToolType, ToolStatus
        )

        query = db.query(ToolConfig).filter(
            ToolConfig.tenant_id == current_user.tenant_id
        )
        if tool_type:
            query = query.filter(ToolConfig.tool_type == tool_type)
        if name:
            query = query.filter(ToolConfig.name.ilike(f"%{name}%"))
        tools = query.all()

        # Index the built-in catalog by tool class for lookup below.
        builtin_tools = config_manager.load_builtin_tools_config()
        configured_tools = {info["tool_class"]: info for info in builtin_tools.values()}

        response_tools = []
        for tool_config in tools:
            # Reset per iteration so a previous tool's catalog entry is never reused.
            tool_info = None
            if tool_config.tool_type == ToolType.BUILTIN.value:
                builtin_config = db.query(BuiltinToolConfig).filter(
                    BuiltinToolConfig.id == tool_config.id
                ).first()
                # The config row may be missing; do not dereference None.
                if builtin_config is not None:
                    tool_info = configured_tools.get(builtin_config.tool_class)
                status = _update_tool_status(tool_config, builtin_config, tool_info)
            else:
                status = _update_tool_status(tool_config)

            response_tools.append(ToolListResponse(
                id=str(tool_config.id),
                name=tool_config.name,
                # description is a required str in the response model; stored value may be None.
                description=tool_config.description or "",
                tool_type=tool_config.tool_type,
                # Fall back to the tool type when no catalog entry is available.
                category=tool_info['category'] if tool_info else tool_config.tool_type,
                version="1.0.0",
                status=status,
                requires_config=tool_info['requires_config'] if tool_info else False,
            ))
        return response_tools
    except Exception as e:
        logger.error(f"获取工具列表失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/builtin/{tool_id}")
async def get_builtin_tool_detail(
    tool_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return details of one built-in tool belonging to the current tenant.

    Sensitive parameters (names containing key/secret/token/password) are
    never included in ``config_parameters``.

    Raises:
        HTTPException: 404 when the tool, its built-in configuration or its
            catalog entry cannot be found; 500 on unexpected failures.
    """
    try:
        tool_config = db.query(ToolConfig).filter(
            ToolConfig.tenant_id == current_user.tenant_id,
            ToolConfig.id == tool_id
        ).first()
        # Guard before dereferencing: a missing row must be a 404, not a 500.
        if not tool_config:
            raise HTTPException(status_code=404, detail="工具不存在")

        builtin_config = db.query(BuiltinToolConfig).filter(
            BuiltinToolConfig.id == tool_config.id
        ).first()
        if not builtin_config:
            raise HTTPException(status_code=404, detail="内置工具配置不存在")

        config_manager = ConfigManager()
        builtin_tools = config_manager.load_builtin_tools_config()
        configured_tools = {info["tool_class"]: info for info in builtin_tools.values()}
        tool_info = configured_tools.get(builtin_config.tool_class)
        if not tool_info:
            raise HTTPException(status_code=404, detail="工具信息不存在")

        is_configured = False
        config_parameters = {}
        if builtin_config.parameters:
            params = builtin_config.parameters
            is_configured = bool(params.get('api_key') or params.get('token'))
            # Expose only non-sensitive settings.
            sensitive_markers = ('key', 'secret', 'token', 'password')
            config_parameters = {
                k: v for k, v in params.items()
                if not any(marker in k.lower() for marker in sensitive_markers)
            }

        return {
            "id": tool_config.id,
            "name": tool_config.name,
            "description": tool_config.description,
            "category": tool_info['category'],
            # Previously returned tool_type here; the key is "status".
            "status": tool_config.status,
            "requires_config": tool_info['requires_config'],
            "is_configured": is_configured,
            "config_parameters": config_parameters
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取工具详情失败: {tool_id}, 错误: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/builtin/{tool_id}/configure")
async def configure_builtin_tool(
    tool_id: str,
    request: BuiltinToolConfigRequest = Body(...),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Store tenant-level parameters for a built-in tool.

    The parameters are passed through ``_encrypt_sensitive_params`` before
    being saved, then the tool status is recomputed and the change committed.

    Raises:
        HTTPException: 404 when the tool, its built-in configuration or its
            catalog entry is missing; 500 on unexpected failures.
    """
    try:
        # Locate the tenant's built-in tool row.
        tool_row = db.query(ToolConfig).filter(
            ToolConfig.tenant_id == current_user.tenant_id,
            ToolConfig.id == tool_id,
            ToolConfig.tool_type == ToolType.BUILTIN
        ).first()
        if not tool_row:
            raise HTTPException(status_code=404, detail="工具不存在")

        # Locate its built-in configuration row.
        builtin_row = db.query(BuiltinToolConfig).filter(
            BuiltinToolConfig.id == tool_row.id
        ).first()
        if not builtin_row:
            raise HTTPException(status_code=404, detail="内置工具配置不存在")

        # Find the catalog entry whose tool_class matches this configuration.
        catalog = ConfigManager().load_builtin_tools_config()
        matched_info = next(
            (info for info in catalog.values() if info['tool_class'] == builtin_row.tool_class),
            None,
        )
        if not matched_info:
            raise HTTPException(status_code=404, detail="工具信息不存在")

        # Save the (encrypted) parameters and refresh the status.
        builtin_row.parameters = _encrypt_sensitive_params(request.parameters)
        _update_tool_status(tool_row, builtin_row, matched_info)
        db.commit()

        return {
            "success": True,
            "message": f"工具 {tool_row.name} 配置成功"
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"配置内置工具失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/builtin/{tool_id}/config")
async def get_builtin_tool_config(
    tool_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return a built-in tool's configuration for use by callers.

    The stored parameters are passed through ``_decrypt_sensitive_params``
    before being returned.

    Raises:
        HTTPException: 404 when the tool or its built-in configuration is
            missing; 500 on unexpected failures.
    """
    try:
        tool_row = db.query(ToolConfig).filter(
            ToolConfig.tenant_id == current_user.tenant_id,
            ToolConfig.id == tool_id,
            ToolConfig.tool_type == ToolType.BUILTIN
        ).first()
        if not tool_row:
            raise HTTPException(status_code=404, detail="工具不存在")

        builtin_row = db.query(BuiltinToolConfig).filter(
            BuiltinToolConfig.id == tool_row.id
        ).first()
        if not builtin_row:
            raise HTTPException(status_code=404, detail="内置工具配置不存在")

        # Missing parameters are treated as an empty mapping.
        stored_params = builtin_row.parameters or {}
        payload = {
            "tool_id": tool_id,
            "tool_class": builtin_row.tool_class,
            "name": tool_row.name,
            "parameters": _decrypt_sensitive_params(stored_params),
            "status": tool_row.status
        }
        return payload
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取工具配置失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/custom")
async def create_custom_tool(
    request: CustomToolCreateRequest = Body(...),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Create a custom tool for the current tenant.

    The request is validated with ``ConfigManager.validate_config`` and then
    stored as a ``ToolConfig`` row plus a ``CustomToolConfig`` row sharing the
    same id.

    Raises:
        HTTPException: 400 when the configuration is invalid; 500 on
            unexpected failures (the session is rolled back first).
    """
    try:
        config_data = jsonable_encoder(request.model_dump())
        config_data["tool_type"] = "custom"

        config_manager = ConfigManager()
        is_valid, error_msg = config_manager.validate_config(config_data, "custom")
        if not is_valid:
            raise HTTPException(status_code=400, detail=error_msg)

        # Parent row first; flush so its id is available for the child row.
        tool_config = ToolConfig(
            name=request.name,
            description=request.description,
            tool_type=ToolType.CUSTOM,
            tenant_id=current_user.tenant_id,
            status=ToolStatus.ACTIVE.value,
            config_data=config_data
        )
        db.add(tool_config)
        db.flush()

        custom_config = CustomToolConfig(
            id=tool_config.id,
            base_url=request.base_url,
            schema_url=request.schema_url,
            schema_content=request.schema_content,
            auth_type=request.auth_type,
            auth_config=request.auth_config,
            timeout=request.timeout
        )
        db.add(custom_config)
        db.commit()

        return {
            "success": True,
            "message": f"自定义工具 {request.name} 创建成功",
            "tool_id": str(tool_config.id)
        }
    except HTTPException:
        raise
    except Exception as e:
        # Undo the flushed parent row so a half-created tool is never left
        # pending in the session (mirrors the rollback in create_mcp_tool).
        db.rollback()
        logger.error(f"创建自定义工具失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/mcp")
async def create_mcp_tool(
    request: MCPToolCreateRequest = Body(..., description="MCP工具创建参数"),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Create an MCP tool for the current tenant.

    The request is validated with ``ConfigManager.validate_config`` and then
    stored as a ``ToolConfig`` row plus an ``MCPToolConfig`` row sharing the
    same id. Database errors trigger a rollback.

    Raises:
        HTTPException: 400 when the configuration is invalid; 500 on database
            or other unexpected failures.
    """
    try:
        payload = jsonable_encoder(request.model_dump())
        payload["tool_type"] = "mcp"

        is_valid, error_msg = ConfigManager().validate_config(payload, "mcp")
        if not is_valid:
            raise HTTPException(status_code=400, detail=error_msg)

        try:
            # Parent row first; flush so its id is available for the child row.
            tool_config = ToolConfig(
                name=request.name,
                description=request.description,
                tool_type=ToolType.MCP,
                tenant_id=current_user.tenant_id,
                status=ToolStatus.ACTIVE.value,
                config_data=payload
            )
            db.add(tool_config)
            db.flush()

            db.add(MCPToolConfig(
                id=tool_config.id,
                server_url=request.server_url,
                connection_config=request.connection_config
            ))
            db.commit()
        except SQLAlchemyError as db_e:
            db.rollback()
            logger.error(f"创建MCP工具数据库操作失败租户ID{current_user.tenant_id},工具名:{request.name}: {str(db_e)}",
                         exc_info=True)
            raise HTTPException(status_code=500, detail=f"创建MCP工具数据库操作失败租户ID{current_user.tenant_id}"
                                                        f"工具名:{request.name}{str(db_e)}")

        return {
            "success": True,
            "message": f"MCP工具 {request.name} 创建成功",
            "tool_id": str(tool_config.id)
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"创建MCP工具失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.delete("/{tool_id}")
async def delete_tool(
    tool_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Delete one of the tenant's tools (custom and MCP tools only).

    Raises:
        HTTPException: 404 when the tool does not exist for this tenant; 403
            when it is a built-in tool; 500 on unexpected failures.
    """
    try:
        target = db.query(ToolConfig).filter(
            ToolConfig.id == tool_id,
            ToolConfig.tenant_id == current_user.tenant_id
        ).first()

        if not target:
            raise HTTPException(status_code=404, detail="工具不存在")
        # Built-in tools are protected from deletion.
        if target.tool_type == ToolType.BUILTIN:
            raise HTTPException(status_code=403, detail="内置工具不允许删除")

        db.delete(target)
        db.commit()

        outcome = {
            "success": True,
            "message": f"工具 {target.name} 删除成功"
        }
        return outcome
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"删除工具失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.put("/{tool_id}")
async def update_tool(
    tool_id: str,
    config_data: Optional[Dict[str, Any]] = None,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Update a tool's ``config_data`` (custom and MCP tools only).

    When ``config_data`` is provided it replaces the stored value; the status
    is then recomputed and the change committed.

    Raises:
        HTTPException: 404 when the tool does not exist for this tenant; 403
            when it is a built-in tool; 500 on unexpected failures.
    """
    try:
        target = db.query(ToolConfig).filter(
            ToolConfig.id == tool_id,
            ToolConfig.tenant_id == current_user.tenant_id
        ).first()

        if not target:
            raise HTTPException(status_code=404, detail="工具不存在")
        # Built-in tools are configured through their own endpoint.
        if target.tool_type == ToolType.BUILTIN:
            raise HTTPException(status_code=403, detail="内置工具不允许修改")

        if config_data is not None:
            target.config_data = config_data

        # Recompute the status from the (possibly new) configuration.
        _update_tool_status(target)
        db.commit()
        db.refresh(target)

        outcome = {
            "success": True,
            "message": f"工具 {target.name} 更新成功",
            "status": target.status
        }
        return outcome
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"更新工具失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/{tool_id}/toggle")
async def toggle_tool_status(
    tool_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Flip a tool between the active and inactive states.

    Raises:
        HTTPException: 404 when the tool does not exist for this tenant; 400
            when its current status is neither active nor inactive; 500 on
            unexpected failures.
    """
    try:
        target = db.query(ToolConfig).filter(
            ToolConfig.id == tool_id,
            ToolConfig.tenant_id == current_user.tenant_id
        ).first()
        if not target:
            raise HTTPException(status_code=404, detail="工具不存在")

        active = ToolStatus.ACTIVE.value
        inactive = ToolStatus.INACTIVE.value
        # Only these two states can be toggled.
        if target.status != active and target.status != inactive:
            raise HTTPException(status_code=400, detail="只有可用或非活跃状态的工具可以切换")
        target.status = inactive if target.status == active else active

        db.commit()
        db.refresh(target)

        return {
            "success": True,
            "message": f"工具 {target.name} 状态已更新为 {target.status}",
            "status": target.status
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"切换工具状态失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))

View File

@@ -0,0 +1,430 @@
"""工具执行API控制器"""
import uuid
from typing import Dict, Any, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Path, Query
from sqlalchemy.orm import Session
from pydantic import BaseModel, Field
from app.db import get_db
from app.dependencies import get_current_user
from app.models import User
from app.core.tools.registry import ToolRegistry
from app.core.tools.executor import ToolExecutor
from app.core.tools.chain_manager import ChainManager, ChainDefinition, ChainStep, ChainExecutionMode
from app.core.tools.builtin import *
from app.core.logging_config import get_business_logger
logger = get_business_logger()
router = APIRouter(prefix="/tools/execution", tags=["工具执行"])
# ==================== 请求/响应模型 ====================
class ToolExecutionRequest(BaseModel):
    """Request body for running a single tool."""

    tool_id: str = Field(
        ...,
        description="工具ID",
    )
    parameters: Dict[str, Any] = Field(
        default_factory=dict,
        description="工具参数",
    )
    # Optional per-call timeout, constrained to 1..300.
    timeout: Optional[float] = Field(
        None,
        ge=1,
        le=300,
        description="超时时间(秒)",
    )
    metadata: Optional[Dict[str, Any]] = Field(
        None,
        description="额外元数据",
    )
class BatchExecutionRequest(BaseModel):
    """Request body for running several tools in one call."""

    executions: List[ToolExecutionRequest] = Field(
        ...,
        description="执行列表",
    )
    # Concurrency cap, constrained to 1..20 (default 5).
    max_concurrency: int = Field(
        5,
        ge=1,
        le=20,
        description="最大并发数",
    )
class ToolExecutionResponse(BaseModel):
    """Result of one tool execution as returned to the client."""

    # Identification and outcome.
    success: bool
    execution_id: str
    tool_id: str

    # Payload on success / error details on failure.
    data: Any = None
    error: Optional[str] = None
    error_code: Optional[str] = None

    # Measurements and extra information.
    execution_time: float
    token_usage: Optional[Dict[str, int]] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)
class ChainStepRequest(BaseModel):
    """One step of a tool chain request."""

    tool_id: str = Field(
        ...,
        description="工具ID",
    )
    parameters: Dict[str, Any] = Field(
        default_factory=dict,
        description="工具参数",
    )
    condition: Optional[str] = Field(
        None,
        description="执行条件",
    )
    output_mapping: Optional[Dict[str, str]] = Field(
        None,
        description="输出映射",
    )
    # Error handling strategy; defaults to "stop".
    error_handling: str = Field(
        "stop",
        description="错误处理策略",
    )
class ChainExecutionRequest(BaseModel):
    """Request body describing a tool chain to register and run."""

    name: str = Field(
        ...,
        description="链名称",
    )
    description: str = Field(
        "",
        description="链描述",
    )
    steps: List[ChainStepRequest] = Field(
        ...,
        description="执行步骤",
    )
    # Execution mode name; defaults to "sequential".
    execution_mode: str = Field(
        "sequential",
        description="执行模式",
    )
    initial_variables: Optional[Dict[str, Any]] = Field(
        None,
        description="初始变量",
    )
    global_timeout: Optional[float] = Field(
        None,
        description="全局超时",
    )
class ExecutionHistoryResponse(BaseModel):
    """One record of the execution history."""

    # Identification.
    execution_id: str
    tool_id: str
    status: str

    # Timing.
    started_at: Optional[str]
    completed_at: Optional[str]
    execution_time: Optional[float]

    # Who ran it and where.
    user_id: Optional[str]
    workspace_id: Optional[str]

    # Inputs, outputs and diagnostics.
    input_data: Optional[Dict[str, Any]]
    output_data: Optional[Any]
    error_message: Optional[str]
    token_usage: Optional[Dict[str, int]]
class ToolConnectionTestResponse(BaseModel):
    """Outcome of a tool connection test."""

    success: bool
    message: str
    # Optional failure description and extra details.
    error: Optional[str] = None
    details: Optional[Dict[str, Any]] = None
# ==================== 依赖注入 ====================
def get_tool_registry(db: Session = Depends(get_db)) -> ToolRegistry:
    """Dependency: build a ``ToolRegistry`` with the built-in tool classes registered."""
    tool_registry = ToolRegistry(db)
    # Register the built-in tool classes in a fixed order.
    for tool_cls in (DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool):
        tool_registry.register_tool_class(tool_cls)
    return tool_registry
def get_tool_executor(
    db: Session = Depends(get_db),
    registry: ToolRegistry = Depends(get_tool_registry)
) -> ToolExecutor:
    """Dependency: build a ``ToolExecutor`` from the session and the tool registry."""
    tool_executor = ToolExecutor(db, registry)
    return tool_executor
def get_chain_manager(executor: ToolExecutor = Depends(get_tool_executor)) -> ChainManager:
    """Dependency: build a ``ChainManager`` on top of the tool executor."""
    manager = ChainManager(executor)
    return manager
# ==================== API端点 ====================
@router.post("/execute", response_model=ToolExecutionResponse)
async def execute_tool(
    request: ToolExecutionRequest,
    current_user: User = Depends(get_current_user),
    executor: ToolExecutor = Depends(get_tool_executor)
):
    """Run a single tool on behalf of the current user.

    Returns:
        A ``ToolExecutionResponse`` built from the executor result and a
        freshly generated execution id.

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        # Execution id: "exec_" followed by 16 hex characters of a random UUID.
        run_id = f"exec_{uuid.uuid4().hex[:16]}"

        outcome = await executor.execute_tool(
            tool_id=request.tool_id,
            parameters=request.parameters,
            user_id=current_user.id,
            workspace_id=current_user.current_workspace_id,
            execution_id=run_id,
            timeout=request.timeout,
            metadata=request.metadata
        )

        response = ToolExecutionResponse(
            success=outcome.success,
            execution_id=run_id,
            tool_id=request.tool_id,
            data=outcome.data,
            error=outcome.error,
            error_code=outcome.error_code,
            execution_time=outcome.execution_time,
            token_usage=outcome.token_usage,
            metadata=outcome.metadata
        )
        return response
    except Exception as e:
        logger.error(f"工具执行失败: {request.tool_id}, 错误: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/batch", response_model=List[ToolExecutionResponse])
async def execute_tools_batch(
    request: BatchExecutionRequest,
    current_user: User = Depends(get_current_user),
    executor: ToolExecutor = Depends(get_tool_executor)
):
    """Run several tools in one call.

    Each requested execution gets its own generated execution id; results are
    matched to requests by position.

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        # One generated id per requested execution, in request order.
        execution_ids = [f"exec_{uuid.uuid4().hex[:16]}" for _ in request.executions]

        execution_configs = [
            {
                "tool_id": item.tool_id,
                "parameters": item.parameters,
                "user_id": current_user.id,
                "workspace_id": current_user.current_workspace_id,
                "execution_id": run_id,
                "timeout": item.timeout,
                "metadata": item.metadata
            }
            for item, run_id in zip(request.executions, execution_ids)
        ]

        results = await executor.execute_tools_batch(
            execution_configs,
            max_concurrency=request.max_concurrency
        )

        # Result i corresponds to request i.
        return [
            ToolExecutionResponse(
                success=outcome.success,
                execution_id=execution_ids[idx],
                tool_id=request.executions[idx].tool_id,
                data=outcome.data,
                error=outcome.error,
                error_code=outcome.error_code,
                execution_time=outcome.execution_time,
                token_usage=outcome.token_usage,
                metadata=outcome.metadata
            )
            for idx, outcome in enumerate(results)
        ]
    except Exception as e:
        logger.error(f"批量执行失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/chain", response_model=Dict[str, Any])
async def execute_tool_chain(
    request: ChainExecutionRequest,
    current_user: User = Depends(get_current_user),
    chain_manager: ChainManager = Depends(get_chain_manager)
):
    """Register a tool chain from the request and run it.

    Returns:
        Whatever ``chain_manager.execute_chain`` returns for the chain.

    Raises:
        HTTPException: 400 when ``execution_mode`` is not a valid
            ``ChainExecutionMode`` value; 500 on other failures.
    """
    try:
        # An unknown mode is a client error, not a server error.
        # Presumably ChainExecutionMode is an Enum that raises ValueError
        # for unknown values - confirm against chain_manager.
        try:
            execution_mode = ChainExecutionMode(request.execution_mode)
        except ValueError:
            raise HTTPException(
                status_code=400,
                detail=f"不支持的执行模式: {request.execution_mode}"
            )

        # Convert the request steps into chain steps.
        steps = [
            ChainStep(
                tool_id=step_request.tool_id,
                parameters=step_request.parameters,
                condition=step_request.condition,
                output_mapping=step_request.output_mapping,
                error_handling=step_request.error_handling
            )
            for step_request in request.steps
        ]

        chain_definition = ChainDefinition(
            name=request.name,
            description=request.description,
            steps=steps,
            execution_mode=execution_mode,
            global_timeout=request.global_timeout
        )

        # Register the chain, then run it by name.
        chain_manager.register_chain(chain_definition)
        result = await chain_manager.execute_chain(
            chain_name=request.name,
            initial_variables=request.initial_variables
        )
        return result
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"工具链执行失败: {request.name}, 错误: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/running", response_model=List[Dict[str, Any]])
async def get_running_executions(
    current_user: User = Depends(get_current_user),
    executor: ToolExecutor = Depends(get_tool_executor)
):
    """Return the running executions that belong to the user's current workspace.

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        all_running = executor.get_running_executions()
        # Keep only entries whose workspace_id matches the current workspace.
        matching = []
        for entry in all_running:
            if entry.get("workspace_id") == str(current_user.current_workspace_id):
                matching.append(entry)
        return matching
    except Exception as e:
        logger.error(f"获取运行中执行失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.delete("/cancel/{execution_id}", response_model=Dict[str, Any])
async def cancel_execution(
    execution_id: str = Path(..., description="执行ID"),
    current_user: User = Depends(get_current_user),
    executor: ToolExecutor = Depends(get_tool_executor)
):
    """Cancel a tool execution by id.

    NOTE(review): the execution's owner/workspace is not checked here before
    cancelling - confirm whether the executor enforces that.

    Raises:
        HTTPException: 404 when the executor reports nothing was cancelled;
            500 on unexpected failures.
    """
    try:
        cancelled = await executor.cancel_execution(execution_id)
        # Guard clause: a falsy result means there was nothing to cancel.
        if not cancelled:
            raise HTTPException(status_code=404, detail="执行不存在或已完成")
        return {
            "success": True,
            "message": "执行已取消"
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"取消执行失败: {execution_id}, 错误: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/history", response_model=List[ExecutionHistoryResponse])
async def get_execution_history(
    tool_id: Optional[str] = Query(None, description="工具ID过滤"),
    limit: int = Query(50, ge=1, le=200, description="返回数量限制"),
    current_user: User = Depends(get_current_user),
    executor: ToolExecutor = Depends(get_tool_executor)
):
    """Return execution history for the current user and workspace.

    Args:
        tool_id: Optional tool id filter.
        limit: Maximum number of records (1..200, default 50).

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        records = executor.get_execution_history(
            tool_id=tool_id,
            user_id=current_user.id,
            workspace_id=current_user.current_workspace_id,
            limit=limit
        )

        # Map each history record (a mapping) onto the response model.
        return [
            ExecutionHistoryResponse(
                execution_id=entry["execution_id"],
                tool_id=entry["tool_id"],
                status=entry["status"],
                started_at=entry["started_at"],
                completed_at=entry["completed_at"],
                execution_time=entry["execution_time"],
                user_id=entry["user_id"],
                workspace_id=entry["workspace_id"],
                input_data=entry["input_data"],
                output_data=entry["output_data"],
                error_message=entry["error_message"],
                token_usage=entry["token_usage"]
            )
            for entry in records
        ]
    except Exception as e:
        logger.error(f"获取执行历史失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/statistics", response_model=Dict[str, Any])
async def get_execution_statistics(
    days: int = Query(7, ge=1, le=90, description="统计天数"),
    current_user: User = Depends(get_current_user),
    executor: ToolExecutor = Depends(get_tool_executor)
):
    """Return execution statistics for the current workspace.

    Args:
        days: Number of days to cover (1..90, default 7).

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        payload = {
            "success": True,
            "statistics": executor.get_execution_statistics(
                workspace_id=current_user.current_workspace_id,
                days=days
            )
        }
        return payload
    except Exception as e:
        logger.error(f"获取执行统计失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/chains/running", response_model=List[Dict[str, Any]])
async def get_running_chains(
    current_user: User = Depends(get_current_user),
    chain_manager: ChainManager = Depends(get_chain_manager)
):
    """Return the tool chains the chain manager reports as running.

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        return chain_manager.get_running_chains()
    except Exception as e:
        logger.error(f"获取运行中工具链失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/chains", response_model=List[Dict[str, Any]])
async def list_tool_chains(
    current_user: User = Depends(get_current_user),
    chain_manager: ChainManager = Depends(get_chain_manager)
):
    """Return the tool chains known to the chain manager.

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        return chain_manager.list_chains()
    except Exception as e:
        logger.error(f"获取工具链列表失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/test-connection/{tool_id}", response_model=ToolConnectionTestResponse)
async def test_tool_connection(
    tool_id: str = Path(..., description="工具ID"),
    current_user: User = Depends(get_current_user),
    executor: ToolExecutor = Depends(get_tool_executor)
):
    """Test connectivity of a tool.

    Unlike the other endpoints, a failure is not raised as an HTTP error: it
    is reported in the response body with ``success=False``.
    """
    try:
        outcome = await executor.test_tool_connection(
            tool_id=tool_id,
            user_id=current_user.id,
            workspace_id=current_user.current_workspace_id
        )
    except Exception as e:
        logger.error(f"工具连接测试失败: {tool_id}, 错误: {e}")
        return ToolConnectionTestResponse(
            success=False,
            message="连接测试失败",
            error=str(e)
        )

    # Missing keys fall back to neutral defaults.
    return ToolConnectionTestResponse(
        success=outcome.get("success", False),
        message=outcome.get("message", ""),
        error=outcome.get("error"),
        details=outcome.get("details")
    )

View File

@@ -471,28 +471,52 @@ async def run_workflow(
import json
async def event_generator():
"""生成 SSE 事件"""
"""生成 SSE 事件
SSE 格式:
event: <event_type>
data: <json_data>
支持的事件类型:
- workflow_start: 工作流开始
- workflow_end: 工作流结束
- node_start: 节点开始执行
- node_end: 节点执行完成
- node_chunk: 中间节点的流式输出
- message: 最终消息的流式输出End 节点及其相邻节点)
"""
try:
async for event in service.run_workflow(
async for event in await service.run_workflow(
app_id=app_id,
input_data=input_data,
triggered_by=current_user.id,
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
stream=True
):
# 转换为 SSE 格式
yield f"data: {json.dumps(event)}\n\n"
# 提取事件类型和数据
event_type = event.get("event", "message")
event_data = event.get("data", {})
# 转换为标准 SSE 格式(字符串)
# event: <type>
# data: <json>
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
yield sse_message
except Exception as e:
logger.error(f"流式执行异常: {e}", exc_info=True)
error_event = {
"type": "error",
"error": str(e)
}
yield f"data: {json.dumps(error_event)}\n\n"
# 发送错误事件
sse_error = f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
yield sse_error
return StreamingResponse(
event_generator(),
media_type="text/event-stream"
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no" # 禁用 nginx 缓冲
}
)
else:
# 非流式执行

View File

@@ -9,18 +9,15 @@ LangChain Agent 封装
"""
import os
import time
import asyncio
from typing import Dict, Any, List, Optional, AsyncGenerator, Sequence
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from langchain_core.tools import BaseTool
from langchain.agents import create_agent
from app.core.memory.agent.mcp_server.services import session_service
from app.core.memory.agent.utils.redis_tool import store
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
from app.core.logging_config import get_business_logger
from app.services.memory_agent_service import MemoryAgentService
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task

View File

@@ -37,9 +37,10 @@ def require_api_key(
@require_api_key(scopes=["app"])
def chat_with_app(
resource_id: uuid.UUID,
api_key_auth: ApiKeyAuth = Depends(),
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
message: str
message: str = Query(..., description="聊天消息内容")
):
# api_key_auth 包含验证后的API Key 信息
pass
@@ -70,29 +71,6 @@ def require_api_key(
})
raise BusinessException("API Key 无效或已过期", BizCode.API_KEY_INVALID)
rate_limiter = RateLimiterService()
is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj)
if not is_allowed:
logger.warning("API Key 限流触发", extra={
"api_key_id": str(api_key_obj.id),
"endpoint": str(request.url),
"method": request.method,
"error_msg": error_msg
})
# 根据错误消息判断限流类型
if "QPS" in error_msg:
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED
elif "Daily" in error_msg:
code = BizCode.API_KEY_DAILY_LIMIT_EXCEEDED
else:
code = BizCode.API_KEY_QUOTA_EXCEEDED
raise RateLimitException(
error_msg,
code,
rate_headers=rate_headers
)
if scopes:
missing_scopes = []
for scope in scopes:
@@ -138,6 +116,30 @@ def require_api_key(
scopes=api_key_obj.scopes,
resource_id=api_key_obj.resource_id,
)
rate_limiter = RateLimiterService()
is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj)
if not is_allowed:
logger.warning("API Key 限流触发", extra={
"api_key_id": str(api_key_obj.id),
"endpoint": str(request.url),
"method": request.method,
"error_msg": error_msg
})
# 根据错误消息判断限流类型
if "QPS" in error_msg:
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED
elif "Daily" in error_msg:
code = BizCode.API_KEY_DAILY_LIMIT_EXCEEDED
else:
code = BizCode.API_KEY_QUOTA_EXCEEDED
raise RateLimitException(
error_msg,
code,
rate_headers=rate_headers
)
start_time = time.perf_counter()
response = await func(*args, **kwargs)
end_time = time.perf_counter()

View File

@@ -16,7 +16,7 @@ def generate_api_key(key_type: ApiKeyType) -> str:
key_type: API Key 类型
Returns:
tuple: (api_key, key_hash, key_prefix)
str: api_key
"""
# 前缀映射
prefix_map = {

View File

@@ -148,6 +148,7 @@ class Settings:
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
# Reflection interval read from the REFLECTION_INTERVAL_TIME env var (default 30; presumably seconds - confirm with the scheduler).
REFLECTION_INTERVAL_TIME: int = int(os.getenv("REFLECTION_INTERVAL_TIME", "30"))
# Memory Module Configuration (internal)
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
@@ -156,6 +157,12 @@ class Settings:
MEMORY_RUNTIME_FILE: str = os.getenv("MEMORY_RUNTIME_FILE", "runtime.json")
MEMORY_DBRUN_FILE: str = os.getenv("MEMORY_DBRUN_FILE", "dbrun.json")
# Tool Management Configuration
TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools")
TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60"))
TOOL_MAX_CONCURRENCY: int = int(os.getenv("TOOL_MAX_CONCURRENCY", "10"))
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
def get_memory_output_path(self, filename: str = "") -> str:
"""
Get the full path for memory module output files.

View File

@@ -0,0 +1,85 @@
"""Emotion extraction models for LLM structured output.
This module contains Pydantic models for emotion extraction from statements,
designed to be used with LLM structured output capabilities.
Classes:
EmotionExtraction: Model for emotion extraction results from statements
"""
from pydantic import BaseModel, Field, field_validator
from typing import List, Optional
class EmotionExtraction(BaseModel):
    """Emotion extraction result model for LLM structured output.

    Represents the structured emotion information extracted from a statement:
    emotion type, intensity, keywords, subject classification and an optional
    target.

    Attributes:
        emotion_type: Type of emotion (joy/sadness/anger/fear/surprise/neutral)
        emotion_intensity: Intensity of emotion (0.0-1.0)
        emotion_keywords: List of emotion keywords from the statement (max 3)
        emotion_subject: Subject of emotion (self/other/object)
        emotion_target: Optional target of emotion (person or object name)
    """

    emotion_type: str = Field(
        ...,
        description="Emotion type: joy/sadness/anger/fear/surprise/neutral"
    )
    emotion_intensity: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Emotion intensity from 0.0 to 1.0"
    )
    emotion_keywords: List[str] = Field(
        default_factory=list,
        description="Emotion keywords extracted from the statement (max 3)"
    )
    emotion_subject: str = Field(
        ...,
        description="Emotion subject: self/other/object"
    )
    emotion_target: Optional[str] = Field(
        None,
        description="Emotion target: person or object name"
    )

    @field_validator('emotion_type')
    @classmethod
    def validate_emotion_type(cls, v):
        """Validate emotion type is one of the valid values."""
        valid_types = ['joy', 'sadness', 'anger', 'fear', 'surprise', 'neutral']
        if v not in valid_types:
            raise ValueError(f"emotion_type must be one of {valid_types}, got {v}")
        return v

    @field_validator('emotion_subject')
    @classmethod
    def validate_emotion_subject(cls, v):
        """Validate emotion subject is one of the valid values."""
        valid_subjects = ['self', 'other', 'object']
        if v not in valid_subjects:
            raise ValueError(f"emotion_subject must be one of {valid_subjects}, got {v}")
        return v

    # mode='before' so the raw value is seen prior to List[str] validation.
    # In the default "after" mode the value is already a validated list, which
    # made the non-list fallback unreachable (e.g. None failed validation
    # instead of becoming an empty list).
    @field_validator('emotion_keywords', mode='before')
    @classmethod
    def validate_emotion_keywords(cls, v):
        """Coerce non-sequence input to an empty list and keep at most 3 keywords."""
        if not isinstance(v, (list, tuple)):
            return []
        # Limit to max 3 keywords; element types are validated afterwards.
        return list(v)[:3]

    @field_validator('emotion_intensity')
    @classmethod
    def validate_emotion_intensity(cls, v):
        """Validate emotion intensity is within valid range.

        The same bounds are also declared on the field (ge/le); this check is
        kept as an explicit safeguard.
        """
        if not (0.0 <= v <= 1.0):
            raise ValueError(f"emotion_intensity must be between 0.0 and 1.0, got {v}")
        return v

View File

@@ -215,24 +215,58 @@ class StatementNode(Node):
Attributes:
chunk_id: ID of the parent chunk this statement belongs to
stmt_type: Type of the statement (from ontology)
temporal_info: Temporal information extracted from the statement
statement: The actual statement text content
connect_strength: Classification of connection strength ('Strong' or 'Weak')
emotion_intensity: Optional emotion intensity (0.0-1.0) - displayed on node
emotion_target: Optional emotion target (person or object name)
emotion_subject: Optional emotion subject (self/other/object)
emotion_type: Optional emotion type (joy/sadness/anger/fear/surprise/neutral)
emotion_keywords: Optional list of emotion keywords (max 3)
temporal_info: Temporal information extracted from the statement
valid_at: Optional start date of temporal validity
invalid_at: Optional end date of temporal validity
statement_embedding: Optional embedding vector for the statement
chunk_embedding: Optional embedding vector for the parent chunk
connect_strength: Classification of connection strength ('Strong' or 'Weak')
config_id: Configuration ID used to process this statement
"""
# Core fields (ordered as requested)
chunk_id: str = Field(..., description="ID of the parent chunk")
stmt_type: str = Field(..., description="Type of the statement")
temporal_info: TemporalInfo = Field(..., description="Temporal information")
statement: str = Field(..., description="The statement text content")
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
# Emotion fields (ordered as requested, emotion_intensity first for display)
emotion_intensity: Optional[float] = Field(
None,
ge=0.0,
le=1.0,
description="Emotion intensity: 0.0-1.0 (displayed on node)"
)
emotion_target: Optional[str] = Field(
None,
description="Emotion target: person or object name"
)
emotion_subject: Optional[str] = Field(
None,
description="Emotion subject: self/other/object"
)
emotion_type: Optional[str] = Field(
None,
description="Emotion type: joy/sadness/anger/fear/surprise/neutral"
)
emotion_keywords: Optional[List[str]] = Field(
default_factory=list,
description="Emotion keywords list, max 3 items"
)
# Temporal fields
temporal_info: TemporalInfo = Field(..., description="Temporal information")
valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
# Embedding and other fields
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)")
@field_validator('valid_at', 'invalid_at', mode='before')
@@ -240,6 +274,39 @@ class StatementNode(Node):
def validate_datetime(cls, v):
"""使用通用的历史日期解析函数"""
return parse_historical_datetime(v)
@field_validator('emotion_type', mode='before')
@classmethod
def validate_emotion_type(cls, v):
    """Allow None (no emotion recorded) or one of the six supported emotion labels."""
    allowed = ['joy', 'sadness', 'anger', 'fear', 'surprise', 'neutral']
    # None means the field was not provided and is passed through untouched.
    if v is None or v in allowed:
        return v
    raise ValueError(f"emotion_type must be one of {allowed}, got {v}")
@field_validator('emotion_subject', mode='before')
@classmethod
def validate_emotion_subject(cls, v):
    """Allow None (no subject recorded) or one of 'self', 'other', 'object'."""
    allowed = ['self', 'other', 'object']
    # None means the field was not provided and is passed through untouched.
    if v is None or v in allowed:
        return v
    raise ValueError(f"emotion_subject must be one of {allowed}, got {v}")
@field_validator('emotion_keywords', mode='before')
@classmethod
def validate_emotion_keywords(cls, v):
    """Normalise keywords: None or non-list input becomes [], lists are cut to three items."""
    # None is not a list, so a single isinstance check covers both fallback cases.
    return v[:3] if isinstance(v, list) else []
class ChunkNode(Node):

View File

@@ -64,6 +64,11 @@ class Statement(BaseModel):
connect_strength: Optional connection strength ('Strong' or 'Weak')
temporal_validity: Optional temporal validity range
triplet_extraction_info: Optional triplet extraction results
emotion_type: Optional emotion type (joy/sadness/anger/fear/surprise/neutral)
emotion_intensity: Optional emotion intensity (0.0-1.0)
emotion_keywords: Optional list of emotion keywords
emotion_subject: Optional emotion subject (self/other/object)
emotion_target: Optional emotion target (person or object name)
"""
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.")
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
@@ -80,6 +85,12 @@ class Statement(BaseModel):
triplet_extraction_info: Optional[TripletExtractionResponse] = Field(
None, description="The triplet extraction information of the statement."
)
# Emotion fields
emotion_type: Optional[str] = Field(None, description="Emotion type: joy/sadness/anger/fear/surprise/neutral")
emotion_intensity: Optional[float] = Field(None, ge=0.0, le=1.0, description="Emotion intensity: 0.0-1.0")
emotion_keywords: Optional[List[str]] = Field(default_factory=list, description="Emotion keywords, max 3")
emotion_subject: Optional[str] = Field(None, description="Emotion subject: self/other/object")
emotion_target: Optional[str] = Field(None, description="Emotion target: person or object name")
class ConversationContext(BaseModel):

View File

@@ -480,7 +480,6 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
- global_redirect: dict losing_id -> canonical_id accumulated across rounds
- records: textual logs including per-round/per-block summaries and per-pair decisions
"""
import asyncio
import random
# 初始化全局日志和全局ID映射存储所有轮次的结果
records: List[str] = []

View File

@@ -36,7 +36,6 @@ from app.core.memory.models.graph_models import (
from app.core.memory.models.message_models import DialogData
from app.core.memory.models.variate_config import (
ExtractionPipelineConfig,
StatementExtractionConfig,
)
from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_dedup import (
dedup_layers_and_merge_and_return,
@@ -182,11 +181,12 @@ class ExtractionOrchestrator:
all_statements_list.extend(chunk.statements)
total_statements = len(all_statements_list)
# 步骤 2: 并行执行三元组提取、时间信息提取和基础嵌入生成
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取和嵌入生成")
# 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成")
(
triplet_maps,
temporal_maps,
emotion_maps,
statement_embedding_maps,
chunk_embedding_maps,
dialog_embeddings,
@@ -209,78 +209,13 @@ class ExtractionOrchestrator:
logger.info("步骤 3/6: 生成实体嵌入")
triplet_maps = await self._generate_entity_embeddings(triplet_maps)
# 进度回调:按三个阶段分别输出知识抽取结果
if self.progress_callback:
# 第一阶段:陈述句提取结果
for i, stmt in enumerate(all_statements_list[:10]): # 只输出前10个陈述句
stmt_result = {
"extraction_type": "statement",
"statement_index": i + 1,
"statement": stmt.statement,
"statement_id": stmt.id
}
await self.progress_callback("knowledge_extraction_result", "陈述句提取完成", stmt_result)
# 第二阶段:三元组提取结果
for i, triplet in enumerate(all_triplets_list[:10]): # 只输出前10个三元组
triplet_result = {
"extraction_type": "triplet",
"triplet_index": i + 1,
"subject": triplet.subject_name,
"predicate": triplet.predicate,
"object": triplet.object_name
}
await self.progress_callback("knowledge_extraction_result", "三元组提取完成", triplet_result)
# 第三阶段:时间提取结果
if total_temporal > 0:
# 收集时间信息
temporal_results = []
for dialog in dialog_data_list:
for chunk in dialog.chunks:
for statement in chunk.statements:
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
temporal_results.append({
"statement_id": statement.id,
"statement": statement.statement,
"valid_at": statement.temporal_validity.valid_at,
"invalid_at": statement.temporal_validity.invalid_at
})
# 输出时间提取结果
for i, temporal_result in enumerate(temporal_results[:5]): # 只输出前5个时间提取结果
time_result = {
"extraction_type": "temporal",
"temporal_index": i + 1,
"statement": temporal_result["statement"],
"valid_at": temporal_result["valid_at"],
"invalid_at": temporal_result["invalid_at"]
}
await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result)
else:
# 如果没有时间信息,也发送一个时间提取完成的消息
time_result = {
"extraction_type": "temporal",
"temporal_index": 0,
"message": "未发现时间信息"
}
await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result)
# 进度回调:知识抽取完成,传递知识抽取的统计信息
extraction_stats = {
"statements_count": total_statements,
"entities_count": total_entities,
"triplets_count": total_triplets,
"temporal_ranges_count": total_temporal,
}
await self.progress_callback("knowledge_extraction_complete", "知识抽取完成", extraction_stats)
# 步骤 4: 将提取的数据赋值到语句
logger.info("步骤 4/6: 数据赋值")
dialog_data_list = await self._assign_extracted_data(
dialog_data_list,
temporal_maps,
triplet_maps,
emotion_maps,
statement_embedding_maps,
chunk_embedding_maps,
dialog_embeddings,
@@ -288,6 +223,9 @@ class ExtractionOrchestrator:
# 步骤 5: 创建节点和边
logger.info("步骤 5/6: 创建节点和边")
# 注意creating_nodes_edges 消息已在知识抽取完成后立即发送
(
dialogue_nodes,
chunk_nodes,
@@ -307,6 +245,8 @@ class ExtractionOrchestrator:
else:
logger.info("步骤 6/6: 两阶段去重和消歧")
# 注意deduplication 消息已在创建节点和边完成后立即发送
result = await self._run_dedup_and_write_summary(
dialogue_nodes,
chunk_nodes,
@@ -331,7 +271,7 @@ class ExtractionOrchestrator:
self, dialog_data_list: List[DialogData]
) -> List[DialogData]:
"""
从对话中提取陈述句(优化版:全局分块级并行
从对话中提取陈述句(流式输出版本:边提取边发送进度
Args:
dialog_data_list: 对话数据列表
@@ -339,7 +279,7 @@ class ExtractionOrchestrator:
Returns:
更新后的对话数据列表(包含提取的陈述句)
"""
logger.info("开始陈述句提取(全局分块级并行)")
logger.info("开始陈述句提取(全局分块级并行 + 流式输出")
# 收集所有分块及其元数据
all_chunks = []
@@ -352,17 +292,44 @@ class ExtractionOrchestrator:
chunk_metadata.append((d_idx, c_idx))
logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取")
# 用于跟踪已完成的分块数量
completed_chunks = 0
total_chunks = len(all_chunks)
# 全局并行处理所有分块
async def extract_for_chunk(chunk_data):
async def extract_for_chunk(chunk_data, chunk_index):
nonlocal completed_chunks
chunk, group_id, dialogue_content = chunk_data
try:
return await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content)
statements = await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content)
# 流式输出:每提取完一个分块的陈述句,立即发送进度
# 注意:只在试运行模式下发送陈述句详情,正式模式不发送
completed_chunks += 1
if self.progress_callback and statements and self.is_pilot_run:
# 发送前3个陈述句作为示例
for idx, stmt in enumerate(statements[:3]):
stmt_result = {
"extraction_type": "statement",
"statement": stmt.statement,
"statement_id": stmt.id,
"chunk_progress": f"{completed_chunks}/{total_chunks}",
"statement_index_in_chunk": idx + 1
}
await self.progress_callback(
"knowledge_extraction_result",
f"陈述句提取中 ({completed_chunks}/{total_chunks})",
stmt_result
)
return statements
except Exception as e:
logger.error(f"分块 {chunk.id} 陈述句提取失败: {e}")
completed_chunks += 1
return []
tasks = [extract_for_chunk(chunk_data) for chunk_data in all_chunks]
tasks = [extract_for_chunk(chunk_data, i) for i, chunk_data in enumerate(all_chunks)]
results = await asyncio.gather(*tasks, return_exceptions=True)
# 将结果分配回对话
@@ -394,7 +361,7 @@ class ExtractionOrchestrator:
self, dialog_data_list: List[DialogData]
) -> List[Dict[str, Any]]:
"""
从对话中提取三元组(优化版:全局陈述句级并行
从对话中提取三元组(流式输出版本:边提取边发送进度
Args:
dialog_data_list: 对话数据列表
@@ -402,7 +369,7 @@ class ExtractionOrchestrator:
Returns:
三元组映射列表,每个对话对应一个字典
"""
logger.info("开始三元组提取(全局陈述句级并行)")
logger.info("开始三元组提取(全局陈述句级并行 + 流式输出")
# 收集所有陈述句及其元数据
all_statements = []
@@ -415,20 +382,32 @@ class ExtractionOrchestrator:
statement_metadata.append((d_idx, statement.id))
logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取三元组")
# 用于跟踪已完成的陈述句数量
completed_statements = 0
total_statements = len(all_statements)
# 全局并行处理所有陈述句
async def extract_for_statement(stmt_data):
async def extract_for_statement(stmt_data, stmt_index):
nonlocal completed_statements
statement, chunk_content = stmt_data
try:
return await self.triplet_extractor._extract_triplets(statement, chunk_content)
triplet_info = await self.triplet_extractor._extract_triplets(statement, chunk_content)
# 注意:不再发送三元组提取的流式输出
# 三元组提取在后台执行,但不向前端发送详细信息
completed_statements += 1
return triplet_info
except Exception as e:
logger.error(f"陈述句 {statement.id} 三元组提取失败: {e}")
completed_statements += 1
from app.core.memory.models.triplet_models import (
TripletExtractionResponse,
)
return TripletExtractionResponse(triplets=[], entities=[])
tasks = [extract_for_statement(stmt_data) for stmt_data in all_statements]
tasks = [extract_for_statement(stmt_data, i) for i, stmt_data in enumerate(all_statements)]
results = await asyncio.gather(*tasks, return_exceptions=True)
# 将结果组织成对话级别的映射
@@ -465,7 +444,7 @@ class ExtractionOrchestrator:
self, dialog_data_list: List[DialogData]
) -> List[Dict[str, Any]]:
"""
从对话中提取时间信息(优化版:全局陈述句级并行
从对话中提取时间信息(流式输出版本:边提取边发送进度
Args:
dialog_data_list: 对话数据列表
@@ -473,7 +452,21 @@ class ExtractionOrchestrator:
Returns:
时间信息映射列表,每个对话对应一个字典
"""
logger.info("开始时间信息提取(全局陈述句级并行)")
# 试运行模式:跳过时间提取以节省时间
if self.is_pilot_run:
logger.info("试运行模式:跳过时间信息提取(节省约 10-15 秒)")
# 为所有陈述句返回空的时间范围
from app.core.memory.models.message_models import TemporalValidityRange
temporal_maps = []
for dialog in dialog_data_list:
temporal_map = {}
for chunk in dialog.chunks:
for statement in chunk.statements:
temporal_map[statement.id] = TemporalValidityRange(valid_at=None, invalid_at=None)
temporal_maps.append(temporal_map)
return temporal_maps
logger.info("开始时间信息提取(全局陈述句级并行 + 流式输出)")
# 收集所有需要提取时间的陈述句
all_statements = []
@@ -501,18 +494,30 @@ class ExtractionOrchestrator:
statement_metadata.append((d_idx, statement.id))
logger.info(f"收集到 {len(all_statements)} 个需要时间提取的陈述句,开始全局并行提取")
# 用于跟踪已完成的时间提取数量
completed_temporal = 0
total_temporal_statements = len(all_statements)
# 全局并行处理所有陈述句
async def extract_for_statement(stmt_data):
async def extract_for_statement(stmt_data, stmt_index):
nonlocal completed_temporal
statement, ref_dates = stmt_data
try:
return await self.temporal_extractor._extract_temporal_ranges(statement, ref_dates)
temporal_range = await self.temporal_extractor._extract_temporal_ranges(statement, ref_dates)
# 注意:不再发送时间提取的流式输出
# 时间提取在后台执行,但不向前端发送详细信息
completed_temporal += 1
return temporal_range
except Exception as e:
logger.error(f"陈述句 {statement.id} 时间信息提取失败: {e}")
completed_temporal += 1
from app.core.memory.models.message_models import TemporalValidityRange
return TemporalValidityRange(valid_at=None, invalid_at=None)
tasks = [extract_for_statement(stmt_data) for stmt_data in all_statements]
tasks = [extract_for_statement(stmt_data, i) for i, stmt_data in enumerate(all_statements)]
results = await asyncio.gather(*tasks, return_exceptions=True)
# 将结果组织成对话级别的映射
@@ -542,9 +547,108 @@ class ExtractionOrchestrator:
return temporal_maps
async def _extract_emotions(
    self, dialog_data_list: List[DialogData]
) -> List[Dict[str, Any]]:
    """
    Extract emotion information for every statement, in parallel across all dialogs.

    The DataConfig identified by the first dialog's ``config_id`` decides whether
    extraction runs at all (``emotion_enabled``) and which model is used
    (``emotion_model_id``). Extraction is skipped when there is no config id, the
    config cannot be loaded or found, or emotion extraction is disabled in it.

    Args:
        dialog_data_list: Dialogs whose chunks already hold extracted statements.

    Returns:
        One dict per dialog, in input order, mapping statement id to the value
        returned by ``EmotionExtractionService.extract_emotion``, or to ``None``
        when extraction failed for that statement. Every dict is empty when
        extraction is skipped.
    """
    logger.info("开始情绪信息提取(全局陈述句级并行)")

    def empty_maps() -> List[Dict[str, Any]]:
        # A fresh dict per dialog so the maps can be filled independently.
        return [{} for _ in dialog_data_list]

    # The first dialog's config_id selects the DataConfig for the whole batch.
    config_id = (
        getattr(dialog_data_list[0], 'config_id', None) if dialog_data_list else None
    )
    if not config_id:
        logger.info("未找到config_id跳过情绪提取")
        return empty_maps()

    # Load the DataConfig; any failure here downgrades to "skip extraction".
    try:
        from app.db import SessionLocal
        from app.repositories.data_config_repository import DataConfigRepository

        db = SessionLocal()
        try:
            data_config = DataConfigRepository.get_by_id(db, config_id)
        finally:
            db.close()
        # Read the flag inside the try so attribute errors are handled too.
        emotion_enabled = bool(data_config and data_config.emotion_enabled)
    except Exception as e:
        logger.warning(f"加载DataConfig失败: {e},将跳过情绪提取")
        return empty_maps()

    if not data_config:
        logger.info("情绪提取未启用,跳过")
        return empty_maps()
    if not emotion_enabled:
        logger.info("情绪提取已在配置中禁用,跳过情绪提取")
        return empty_maps()

    # Flatten all statements, remembering which dialog each one belongs to.
    targets = [
        (d_idx, statement.id, statement)
        for d_idx, dialog in enumerate(dialog_data_list)
        for chunk in dialog.chunks
        for statement in chunk.statements
    ]
    logger.info(f"收集到 {len(targets)} 个陈述句,开始全局并行提取情绪")

    results: List[Any] = []
    if targets:
        # Only build the extraction service when there is something to extract.
        from app.services.emotion_extraction_service import EmotionExtractionService

        emotion_service = EmotionExtractionService(
            llm_id=data_config.emotion_model_id or None
        )

        async def extract_for_statement(statement):
            # A per-statement failure must not abort the batch; it yields None.
            try:
                return await emotion_service.extract_emotion(
                    statement.statement, data_config
                )
            except Exception as e:
                logger.error(f"陈述句 {statement.id} 情绪提取失败: {e}")
                return None

        results = await asyncio.gather(
            *(extract_for_statement(statement) for _, _, statement in targets),
            return_exceptions=True,
        )

    # Regroup the flat results into one map per dialog.
    emotion_maps = empty_maps()
    successful_extractions = 0
    for (d_idx, stmt_id, _), result in zip(targets, results):
        # gather(return_exceptions=True) may also hand back CancelledError, which
        # derives from BaseException rather than Exception; never store it as data.
        if isinstance(result, BaseException):
            logger.error(f"陈述句处理异常: {result}")
            result = None
        elif result is not None:
            successful_extractions += 1
        emotion_maps[d_idx][stmt_id] = result

    logger.info(f"情绪信息提取完成,共成功提取 {successful_extractions}/{len(targets)} 个情绪")

    return emotion_maps
async def _parallel_extract_and_embed(
self, dialog_data_list: List[DialogData]
) -> Tuple[
List[Dict[str, Any]],
List[Dict[str, Any]],
List[Dict[str, Any]],
List[Dict[str, List[float]]],
@@ -552,35 +656,39 @@ class ExtractionOrchestrator:
List[List[float]],
]:
"""
并行执行三元组提取、时间信息提取和基础嵌入生成
并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成
个任务都依赖陈述句提取的结果,但彼此独立,可以并行执行:
个任务都依赖陈述句提取的结果,但彼此独立,可以并行执行:
- 三元组提取:从陈述句中提取实体和关系
- 时间信息提取:从陈述句中提取时间范围
- 情绪提取:从陈述句中提取情绪信息
- 嵌入生成:为陈述句、分块和对话生成向量(不依赖三元组)
Args:
dialog_data_list: 对话数据列表
Returns:
个列表的元组:
个列表的元组:
- 三元组映射列表
- 时间信息映射列表
- 情绪映射列表
- 陈述句嵌入映射列表
- 分块嵌入映射列表
- 对话嵌入列表
"""
logger.info("并行执行:三元组提取 + 时间信息提取 + 基础嵌入生成")
logger.info("并行执行:三元组提取 + 时间信息提取 + 情绪提取 + 基础嵌入生成")
# 创建个并行任务
# 创建个并行任务
triplet_task = self._extract_triplets(dialog_data_list)
temporal_task = self._extract_temporal(dialog_data_list)
emotion_task = self._extract_emotions(dialog_data_list)
embedding_task = self._generate_basic_embeddings(dialog_data_list)
# 并行执行
results = await asyncio.gather(
triplet_task,
temporal_task,
emotion_task,
embedding_task,
return_exceptions=True
)
@@ -588,19 +696,21 @@ class ExtractionOrchestrator:
# 解包结果
triplet_maps = results[0] if not isinstance(results[0], Exception) else [{} for _ in dialog_data_list]
temporal_maps = results[1] if not isinstance(results[1], Exception) else [{} for _ in dialog_data_list]
emotion_maps = results[2] if not isinstance(results[2], Exception) else [{} for _ in dialog_data_list]
if isinstance(results[2], Exception):
logger.error(f"基础嵌入生成失败: {results[2]}")
if isinstance(results[3], Exception):
logger.error(f"基础嵌入生成失败: {results[3]}")
statement_embedding_maps = [{} for _ in dialog_data_list]
chunk_embedding_maps = [{} for _ in dialog_data_list]
dialog_embeddings = [[] for _ in dialog_data_list]
else:
statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = results[2]
statement_embedding_maps, chunk_embedding_maps, dialog_embeddings = results[3]
logger.info("并行任务执行完成")
return (
triplet_maps,
temporal_maps,
emotion_maps,
statement_embedding_maps,
chunk_embedding_maps,
dialog_embeddings,
@@ -711,6 +821,7 @@ class ExtractionOrchestrator:
dialog_data_list: List[DialogData],
temporal_maps: List[Dict[str, Any]],
triplet_maps: List[Dict[str, Any]],
emotion_maps: List[Dict[str, Any]],
statement_embedding_maps: List[Dict[str, List[float]]],
chunk_embedding_maps: List[Dict[str, List[float]]],
dialog_embeddings: List[List[float]],
@@ -722,6 +833,7 @@ class ExtractionOrchestrator:
dialog_data_list: 对话数据列表
temporal_maps: 时间信息映射列表
triplet_maps: 三元组映射列表
emotion_maps: 情绪信息映射列表
statement_embedding_maps: 陈述句嵌入映射列表
chunk_embedding_maps: 分块嵌入映射列表
dialog_embeddings: 对话嵌入列表
@@ -736,6 +848,7 @@ class ExtractionOrchestrator:
if (
len(temporal_maps) != expected_length
or len(triplet_maps) != expected_length
or len(emotion_maps) != expected_length
or len(statement_embedding_maps) != expected_length
or len(chunk_embedding_maps) != expected_length
or len(dialog_embeddings) != expected_length
@@ -743,6 +856,7 @@ class ExtractionOrchestrator:
logger.warning(
f"数据大小不匹配 - 对话: {len(dialog_data_list)}, "
f"时间映射: {len(temporal_maps)}, 三元组映射: {len(triplet_maps)}, "
f"情绪映射: {len(emotion_maps)}, "
f"陈述句嵌入: {len(statement_embedding_maps)}, "
f"分块嵌入: {len(chunk_embedding_maps)}, "
f"对话嵌入: {len(dialog_embeddings)}"
@@ -751,6 +865,7 @@ class ExtractionOrchestrator:
total_statements = 0
assigned_temporal = 0
assigned_triplets = 0
assigned_emotions = 0
assigned_statement_embeddings = 0
assigned_chunk_embeddings = 0
assigned_dialog_embeddings = 0
@@ -758,12 +873,13 @@ class ExtractionOrchestrator:
# 处理每个对话
for i, dialog_data in enumerate(dialog_data_list):
# 检查是否有缺失的数据
if i >= len(temporal_maps) or i >= len(triplet_maps):
if i >= len(temporal_maps) or i >= len(triplet_maps) or i >= len(emotion_maps):
logger.warning(f"对话 {dialog_data.id} 缺少提取数据,跳过赋值")
continue
temporal_map = temporal_maps[i]
triplet_map = triplet_maps[i]
emotion_map = emotion_maps[i]
statement_embedding_map = statement_embedding_maps[i] if i < len(statement_embedding_maps) else {}
chunk_embedding_map = chunk_embedding_maps[i] if i < len(chunk_embedding_maps) else {}
dialog_embedding = dialog_embeddings[i] if i < len(dialog_embeddings) else []
@@ -794,6 +910,18 @@ class ExtractionOrchestrator:
statement.triplet_extraction_info = triplet_map[statement.id]
assigned_triplets += 1
# 赋值情绪信息
if statement.id in emotion_map:
emotion_data = emotion_map[statement.id]
if emotion_data is not None:
# 将EmotionExtraction对象的字段赋值到Statement
statement.emotion_type = emotion_data.emotion_type
statement.emotion_intensity = emotion_data.emotion_intensity
statement.emotion_keywords = emotion_data.emotion_keywords
statement.emotion_subject = emotion_data.emotion_subject
statement.emotion_target = emotion_data.emotion_target
assigned_emotions += 1
# 赋值陈述句嵌入
if statement.id in statement_embedding_map:
statement.statement_embedding = statement_embedding_map[statement.id]
@@ -802,6 +930,7 @@ class ExtractionOrchestrator:
logger.info(
f"数据赋值完成 - 总陈述句: {total_statements}, "
f"时间信息: {assigned_temporal}, 三元组: {assigned_triplets}, "
f"情绪信息: {assigned_emotions}, "
f"陈述句嵌入: {assigned_statement_embeddings}, "
f"分块嵌入: {assigned_chunk_embeddings}, "
f"对话嵌入: {assigned_dialog_embeddings}"
@@ -833,9 +962,7 @@ class ExtractionOrchestrator:
"""
logger.info("开始创建节点和边")
# 进度回调:正在创建节点和边
if self.progress_callback:
await self.progress_callback("creating_nodes_edges", "正在创建节点和边...")
# 注意:开始消息已在 run 方法中发送,这里不再重复发送
dialogue_nodes = []
chunk_nodes = []
@@ -847,8 +974,13 @@ class ExtractionOrchestrator:
# 用于去重的集合
entity_id_set = set()
# 用于跟踪进度
total_dialogs = len(dialog_data_list)
processed_dialogs = 0
for dialog_data in dialog_data_list:
processed_dialogs += 1
# 创建对话节点
dialogue_node = DialogueNode(
id=dialog_data.id,
@@ -908,6 +1040,12 @@ class ExtractionOrchestrator:
created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at,
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
# Emotion fields
emotion_type=getattr(statement, 'emotion_type', None),
emotion_intensity=getattr(statement, 'emotion_intensity', None),
emotion_keywords=getattr(statement, 'emotion_keywords', None),
emotion_subject=getattr(statement, 'emotion_subject', None),
emotion_target=getattr(statement, 'emotion_target', None),
)
statement_nodes.append(statement_node)
@@ -995,6 +1133,26 @@ class ExtractionOrchestrator:
expired_at=dialog_data.expired_at,
)
entity_entity_edges.append(entity_entity_edge)
# 流式输出:每创建一个关系边,立即发送进度(限制发送数量)
if self.progress_callback and len(entity_entity_edges) <= 10:
# 获取实体名称
source_name = triplet.subject_name
target_name = triplet.object_name
relationship_result = {
"result_type": "relationship_creation",
"relationship_index": len(entity_entity_edges),
"source_entity": source_name,
"relation_type": triplet.predicate,
"target_entity": target_name,
"relationship_text": f"{source_name} -[{triplet.predicate}]-> {target_name}",
"dialog_progress": f"{processed_dialogs}/{total_dialogs}"
}
await self.progress_callback(
"creating_nodes_edges_result",
f"关系创建中 ({processed_dialogs}/{total_dialogs})",
relationship_result
)
else:
logger.warning(
f"跳过三元组 - 无法找到实体ID: subject_id={triplet.subject_id}, "
@@ -1009,12 +1167,9 @@ class ExtractionOrchestrator:
f"实体-实体边: {len(entity_entity_edges)}"
)
# 进度回调:只输出关系创建结果
# 进度回调:创建节点和边完成,传递结果统计
# 注意:具体的关系创建结果已经在创建过程中实时发送了
if self.progress_callback:
# 输出关系创建结果
await self._output_relationship_creation_results(entity_entity_edges, entity_nodes)
# 进度回调:创建节点和边完成,传递结果统计
nodes_edges_stats = {
"dialogue_nodes_count": len(dialogue_nodes),
"chunk_nodes_count": len(chunk_nodes),
@@ -1072,7 +1227,7 @@ class ExtractionOrchestrator:
"""
logger.info("开始两阶段实体去重和消歧")
# 进度回调:正在去重消歧
# 进度回调:发送去重消歧开始消息
if self.progress_callback:
await self.progress_callback("deduplication", "正在去重消歧...")
@@ -1157,25 +1312,26 @@ class ExtractionOrchestrator:
f"实体-实体边减少 {len(entity_entity_edges) - len(final_entity_entity_edges)}"
)
# 进度回调:输出去重消歧的具体结果
# 流式输出:实时输出去重消歧的具体结果
if self.progress_callback:
# 分析实体合并情况
# 分析实体合并情况(使用内存中的记录)
merge_info = await self._analyze_entity_merges(entity_nodes, final_entity_nodes)
# 输出去重合并的实体示例
# 逐个输出去重合并的实体示例
for i, merge_detail in enumerate(merge_info[:5]): # 输出前5个去重结果
dedup_result = {
"result_type": "entity_merge",
"merged_entity_name": merge_detail["main_entity_name"],
"merged_count": merge_detail["merged_count"],
"merge_progress": f"{i + 1}/{min(len(merge_info), 5)}",
"message": f"{merge_detail['main_entity_name']}合并{merge_detail['merged_count']}个:相似实体已合并"
}
await self.progress_callback("dedup_disambiguation_result", "实体去重完成", dedup_result)
await self.progress_callback("dedup_disambiguation_result", "实体去重", dedup_result)
# 分析实体消歧情况
# 分析实体消歧情况(使用内存中的记录)
disamb_info = await self._analyze_entity_disambiguation(entity_nodes, final_entity_nodes)
# 输出实体消歧的结果
# 逐个输出实体消歧的结果
for i, disamb_detail in enumerate(disamb_info[:5]): # 输出前5个消歧结果
disamb_result = {
"result_type": "entity_disambiguation",
@@ -1183,11 +1339,10 @@ class ExtractionOrchestrator:
"disambiguation_type": disamb_detail["disamb_type"],
"confidence": disamb_detail.get("confidence", "unknown"),
"reason": disamb_detail.get("reason", ""),
"disamb_progress": f"{i + 1}/{min(len(disamb_info), 5)}",
"message": f"{disamb_detail['entity_name']}消歧完成:{disamb_detail['disamb_type']}"
}
await self.progress_callback("dedup_disambiguation_result", "实体消歧完成", disamb_result)
await self.progress_callback("dedup_disambiguation_result", "实体消歧", disamb_result)
# 进度回调:去重消歧完成,传递去重和消歧的具体效果
await self._send_dedup_progress_callback(
@@ -1299,7 +1454,7 @@ class ExtractionOrchestrator:
if match:
entity1_name = match.group(1).strip()
entity1_type = match.group(2)
entity2_name = match.group(3).strip()
match.group(3).strip()
entity2_type = match.group(4)
# 提取置信度和原因
@@ -1611,7 +1766,6 @@ async def get_chunked_dialogs(
包含分块的 DialogData 对象列表
"""
import json
import os
import re
# 加载测试数据
@@ -1794,7 +1948,6 @@ async def get_chunked_dialogs_with_preprocessing(
Returns:
带 chunks 的 DialogData 列表
"""
import os
print("\n=== 完整数据处理流程(包含预处理)===")
if input_data_path is None:

View File

@@ -0,0 +1,210 @@
{
"memory_verify": {
"source_data": [
{
"statement_name": "用户是2023年春天去北京工作的。",
"statement_id": "62beac695b1346f4871740a45db88782",
"statement_created_at": "2025-12-19T10:31:15.239252"
},
{
"statement_name": "用户后来基本一直都在北京上班。",
"statement_id": "4cba5ac08b674d7fb1e2ae634d2b8f0b",
"statement_created_at": "2025-12-19T10:31:15.239252"
},
{
"statement_name": "用户从2023年开始就一直在北京生活。",
"statement_id": "e612a44da4db483993c350df7c97a1a1",
"statement_created_at": "2025-12-19T10:31:15.239252"
},
{
"statement_name": "用户从来没有长期离开过北京。",
"statement_id": "b3c787a2e33c49f7981accabbbb4538a",
"statement_created_at": "2025-12-19T10:31:15.239252"
},
{
"statement_name": "由于公司调整用户在2024年上半年被调到上海待了差不多半年。",
"statement_id": "64cde4230cb24a4da726e7db9e7aa616",
"statement_created_at": "2025-12-19T10:31:15.239252"
},
{
"statement_name": "用户在被调到上海期间每天都是在上海办公室打卡。",
"statement_id": "8b1b12e23b844b8088dfeb67da6ad669",
"statement_created_at": "2025-12-19T10:31:15.239252"
},
{
"statement_name": "用户在入职时使用的身份信息是之前的身份证号为11010119950308123X。",
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
"statement_created_at": "2025-12-19T10:31:15.239252"
},
{
"statement_name": "用户的银行卡号是6222023847595898。",
"statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f",
"statement_created_at": "2025-12-19T10:31:15.239252"
},
{
"statement_name": "用户的身份信息和银行卡信息一直没变。",
"statement_id": "b3ca618e1e204b83bebd70e75cf2073f",
"statement_created_at": "2025-12-19T10:31:15.239252"
},
{
"statement_name": "用户认为在上海的那段时间更多算是远程配合。",
"statement_id": "150af89d2c154e6eb41ff1a91e37f962",
"statement_created_at": "2025-12-19T10:31:15.239252"
}
],
"databasets": [
{
"entity1_name": "Person",
"description": "表示人类个体的通用类型",
"statement_id": "62beac695b1346f4871740a45db88782",
"created_at": "2025-12-19T10:31:15.239252000",
"expired_at": "9999-12-31T00:00:00.000000000",
"relationship_type": "EXTRACTED_RELATIONSHIP",
"relationship": {},
"entity2_name": "用户",
"entity2": {
"entity_idx": 0,
"run_id": "62b59cfebeea43dd94d91763056f069a",
"connect_strength": "strong",
"created_at": "2025-12-19T10:31:15.239252000",
"description": "叙述者,讲述个人工作与生活经历的个体",
"statement_id": "62beac695b1346f4871740a45db88782",
"expired_at": "9999-12-31T00:00:00.000000000",
"entity_type": "Person",
"group_id": "88a459f5_text08",
"user_id": "88a459f5_text08",
"name": "用户",
"apply_id": "88a459f5_text08",
"id": "3d3896797b334572a80d57590026063d"
}
},
{
"entity1_name": "用户",
"description": "叙述者,讲述个人工作与生活经历的个体",
"statement_id": "62beac695b1346f4871740a45db88782",
"created_at": "2025-12-19T10:31:15.239252000",
"expired_at": "9999-12-31T00:00:00.000000000",
"relationship_type": "EXTRACTED_RELATIONSHIP",
"relationship": {},
"entity2_name": "身份信息",
"entity2": {
"entity_idx": 1,
"run_id": "62b59cfebeea43dd94d91763056f069a",
"connect_strength": "Strong",
"description": "用于个人身份识别的数据",
"created_at": "2025-12-19T10:31:15.239252000",
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
"expired_at": "9999-12-31T00:00:00.000000000",
"entity_type": "Information",
"group_id": "88a459f5_text08",
"user_id": "88a459f5_text08",
"name": "身份信息",
"apply_id": "88a459f5_text08",
"id": "aa766a517e82490599a9b3af54cfd933"
}
},
{
"entity1_name": "用户",
"description": "叙述者,讲述个人工作与生活经历的个体",
"statement_id": "62beac695b1346f4871740a45db88782",
"created_at": "2025-12-19T10:31:15.239252000",
"expired_at": "9999-12-31T00:00:00.000000000",
"relationship_type": "EXTRACTED_RELATIONSHIP",
"relationship": {},
"entity2_name": "6222023847595898",
"entity2": {
"entity_idx": 1,
"run_id": "62b59cfebeea43dd94d91763056f069a",
"connect_strength": "Strong",
"description": "用户的银行卡号码",
"created_at": "2025-12-19T10:31:15.239252000",
"statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f",
"expired_at": "9999-12-31T00:00:00.000000000",
"entity_type": "Numeric",
"group_id": "88a459f5_text08",
"user_id": "88a459f5_text08",
"name": "6222023847595898",
"apply_id": "88a459f5_text08",
"id": "610ba361918f4e68a65ce6ad06e5c7a0"
}
},
{
"entity1_name": "用户",
"description": "叙述者,讲述个人工作与生活经历的个体",
"statement_id": "62beac695b1346f4871740a45db88782",
"created_at": "2025-12-19T10:31:15.239252000",
"expired_at": "9999-12-31T00:00:00.000000000",
"relationship_type": "EXTRACTED_RELATIONSHIP",
"relationship": {},
"entity2_name": "上海办公室",
"entity2": {
"entity_idx": 1,
"run_id": "62b59cfebeea43dd94d91763056f069a",
"aliases": ["上海办"],
"connect_strength": "Strong",
"created_at": "2025-12-19T10:31:15.239252000",
"description": "位于上海的工作办公场所",
"statement_id": "8b1b12e23b844b8088dfeb67da6ad669",
"expired_at": "9999-12-31T00:00:00.000000000",
"entity_type": "Location",
"group_id": "88a459f5_text08",
"user_id": "88a459f5_text08",
"name": "上海办公室",
"apply_id": "88a459f5_text08",
"id": "fb702ef695c14e14af3e56786bc8815b"
}
},
{
"entity1_name": "用户",
"description": "叙述者,讲述个人工作与生活经历的个体",
"statement_id": "62beac695b1346f4871740a45db88782",
"created_at": "2025-12-19T10:31:15.239252000",
"expired_at": "9999-12-31T00:00:00.000000000",
"relationship_type": "EXTRACTED_RELATIONSHIP",
"relationship": {},
"entity2_name": "北京",
"entity2": {
"entity_idx": 2,
"run_id": "62b59cfebeea43dd94d91763056f069a",
"aliases": ["京", "京城", "北平"],
"connect_strength": "strong",
"created_at": "2025-12-19T10:31:15.239252000",
"description": "中国的首都城市,用户主要工作和生活所在地",
"statement_id": "62beac695b1346f4871740a45db88782",
"expired_at": "9999-12-31T00:00:00.000000000",
"entity_type": "Location",
"group_id": "88a459f5_text08",
"user_id": "88a459f5_text08",
"name": "北京",
"apply_id": "88a459f5_text08",
"id": "81b2d1a571bb46a08a2d7a1e87efb945"
}
},
{
"entity1_name": "11010119950308123X",
"description": "具体的身份证号码值",
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
"created_at": "2025-12-19T10:31:15.239252000",
"expired_at": "9999-12-31T00:00:00.000000000",
"relationship_type": "EXTRACTED_RELATIONSHIP",
"relationship": {},
"entity2_name": "身份证号",
"entity2": {
"entity_idx": 2,
"run_id": "62b59cfebeea43dd94d91763056f069a",
"connect_strength": "strong",
"description": "中华人民共和国公民的身份号码",
"created_at": "2025-12-19T10:31:15.239252000",
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
"expired_at": "9999-12-31T00:00:00.000000000",
"entity_type": "Identifier",
"group_id": "88a459f5_text08",
"user_id": "88a459f5_text08",
"name": "身份证号",
"apply_id": "88a459f5_text08",
"id": "3e5f920645b2404fadb0e9ff60d1306e"
}
}
]
}
}

View File

@@ -8,17 +8,21 @@
4. 反思结果应用 - 更新记忆库
"""
import os
import json
import logging
import asyncio
import os
import time
from typing import List, Dict, Any, Optional
from datetime import datetime
from enum import Enum
import uuid
from pydantic import BaseModel, Field
from pydantic import BaseModel
from app.core.response_utils import success
from app.repositories.neo4j.cypher_queries import neo4j_query_part, neo4j_statement_part, neo4j_query_all, neo4j_statement_all
from app.repositories.neo4j.neo4j_update import neo4j_data
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# 配置日志
_root_logger = logging.getLogger()
@@ -33,14 +37,14 @@ else:
class ReflectionRange(str, Enum):
"""反思范围枚举"""
RETRIEVAL = "retrieval" # 从检索结果中反思
DATABASE = "database" # 从整个数据库中反思
PARTIAL = "partial" # 从检索结果中反思
ALL = "all" # 从整个数据库中反思
class ReflectionBaseline(str, Enum):
"""反思基线枚举"""
TIME = "TIME" # 基于时间的反思
FACT = "FACT" # 基于事实的反思
TIME = "TIME" # 基于时间的反思
FACT = "FACT" # 基于事实的反思
HYBRID = "HYBRID" # 混合反思
@@ -48,9 +52,16 @@ class ReflectionConfig(BaseModel):
"""反思引擎配置"""
enabled: bool = False
iteration_period: str = "3" # 反思周期
reflexion_range: ReflectionRange = ReflectionRange.RETRIEVAL
reflexion_range: ReflectionRange = ReflectionRange.PARTIAL
baseline: ReflectionBaseline = ReflectionBaseline.TIME
concurrency: int = Field(default=5, description="并发数量")
model_id: Optional[str] = None # 模型ID
end_user_id: Optional[str] = None
output_example: Optional[str] = None # 输出示例
# 评估相关字段
memory_verify: bool = True # 记忆验证
quality_assessment: bool = True # 质量评估
violation_handling_strategy: str = "warn" # 违规处理策略
class Config:
use_enum_values = True
@@ -75,16 +86,16 @@ class ReflectionEngine:
"""
def __init__(
self,
config: ReflectionConfig,
neo4j_connector: Optional[Any] = None,
llm_client: Optional[Any] = None,
get_data_func: Optional[Any] = None,
render_evaluate_prompt_func: Optional[Any] = None,
render_reflexion_prompt_func: Optional[Any] = None,
conflict_schema: Optional[Any] = None,
reflexion_schema: Optional[Any] = None,
update_query: Optional[str] = None
self,
config: ReflectionConfig,
neo4j_connector: Optional[Any] = None,
llm_client: Optional[Any] = None,
get_data_func: Optional[Any] = None,
render_evaluate_prompt_func: Optional[Any] = None,
render_reflexion_prompt_func: Optional[Any] = None,
conflict_schema: Optional[Any] = None,
reflexion_schema: Optional[Any] = None,
update_query: Optional[str] = None
):
"""
初始化反思引擎
@@ -109,7 +120,7 @@ class ReflectionEngine:
self.conflict_schema = conflict_schema
self.reflexion_schema = reflexion_schema
self.update_query = update_query
self._semaphore = asyncio.Semaphore(config.concurrency)
self._semaphore = asyncio.Semaphore(5) # 默认并发数为5
# 延迟导入以避免循环依赖
self._lazy_init_done = False
@@ -127,11 +138,21 @@ class ReflectionEngine:
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config import definitions as config_defs
self.llm_client = get_llm_client(config_defs.SELECTED_LLM_ID)
elif isinstance(self.llm_client, str):
# 如果 llm_client 是字符串model_id则用它初始化客户端
from app.core.memory.utils.llm.llm_utils import get_llm_client
model_id = self.llm_client
self.llm_client = get_llm_client(model_id)
if self.get_data_func is None:
from app.core.memory.utils.config.get_data import get_data
self.get_data_func = get_data
# 导入get_data_statement函数
if not hasattr(self, 'get_data_statement'):
from app.core.memory.utils.config.get_data import get_data_statement
self.get_data_statement = get_data_statement
if self.render_evaluate_prompt_func is None:
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
self.render_evaluate_prompt_func = render_evaluate_prompt
@@ -154,13 +175,11 @@ class ReflectionEngine:
self._lazy_init_done = True
async def execute_reflection(self, host_id: uuid.UUID) -> ReflectionResult:
async def execute_reflection(self, host_id) -> ReflectionResult:
"""
执行完整的反思流程
Args:
host_id: 主机ID
Returns:
ReflectionResult: 反思结果
"""
@@ -176,9 +195,10 @@ class ReflectionEngine:
start_time = asyncio.get_event_loop().time()
logging.info("====== 自我反思流程开始 ======")
print(self.config.baseline, self.config.memory_verify, self.config.quality_assessment)
try:
# 1. 获取反思数据
reflexion_data = await self._get_reflexion_data(host_id)
reflexion_data, statement_databasets = await self._get_reflexion_data(host_id)
if not reflexion_data:
return ReflectionResult(
success=True,
@@ -187,22 +207,21 @@ class ReflectionEngine:
)
# 2. 检测冲突(基于事实的反思)
conflict_data = await self._detect_conflicts(reflexion_data)
if not conflict_data:
return ReflectionResult(
success=True,
message="无冲突,无需反思",
execution_time=asyncio.get_event_loop().time() - start_time
)
conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets)
print(100 * '-')
print(conflict_data)
print(100 * '-')
conflicts_found = len(conflict_data)
logging.info(f"发现 {conflicts_found} 个冲突")
# 检查是否真的有冲突
has_conflict = conflict_data[0].get('conflict', False)
conflicts_found = len(conflict_data[0]['data']) if has_conflict else 0
logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突")
# 记录冲突数据
await self._log_data("conflict", conflict_data)
# 3. 解决冲突
solved_data = await self._resolve_conflicts(conflict_data)
solved_data = await self._resolve_conflicts(conflict_data, statement_databasets)
if not solved_data:
return ReflectionResult(
success=False,
@@ -210,6 +229,9 @@ class ReflectionEngine:
conflicts_found=conflicts_found,
execution_time=asyncio.get_event_loop().time() - start_time
)
print(100 * '*')
print(solved_data)
print(100 * '*')
conflicts_resolved = len(solved_data)
logging.info(f"解决了 {conflicts_resolved} 个冲突")
@@ -230,7 +252,8 @@ class ReflectionEngine:
conflicts_found=conflicts_found,
conflicts_resolved=conflicts_resolved,
memories_updated=memories_updated,
execution_time=execution_time
execution_time=execution_time,
)
except Exception as e:
@@ -241,6 +264,79 @@ class ReflectionEngine:
execution_time=asyncio.get_event_loop().time() - start_time
)
async def reflection_run(self):
    """Run conflict detection and resolution against the bundled example data.

    The inputs come from ``extract_fields_from_json`` rather than from Neo4j.
    The per-item ``quality_assessment`` / ``memory_verify`` entries and the
    ``reflexion`` entries produced by conflict resolution are collected into
    a plain dict.

    Returns:
        dict: on success, with the keys ``baseline``, ``source_data``,
        ``quality_assessments``, ``memory_verifies`` and ``reflexion_data``.
        ReflectionResult: with ``success=False`` when ``_resolve_conflicts``
        returns nothing.
        NOTE(review): the two return types differ; callers must handle both.
    """
    self._lazy_init()
    # Take both ends of the elapsed-time measurement from the event-loop
    # clock. The start used to come from time.time() (wall-clock epoch
    # seconds), which cannot be subtracted from the loop's monotonic clock.
    start_time = asyncio.get_event_loop().time()
    logging.info("====== 自我反思流程开始 ======")

    source_data, databasets = await self.extract_fields_from_json()
    result_data = {
        'baseline': self.config.baseline,
        # NOTE(review): hard-coded demo text, not the loaded source_data —
        # confirm whether this should come from example.json instead.
        'source_data': "我是 2023 年春天去北京工作的后来基本一直都在北京上班也没怎么换过城市。不过后来公司调整2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X银行卡是 6222023847595898这些一直没变。对了其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合",
    }

    # 2. Detect conflicts (fact-based reflection).
    conflict_data = await self._detect_conflicts(databasets, source_data)

    # Collect the per-item assessment fields.
    quality_assessments = []
    memory_verifies = []
    for item in conflict_data:
        logging.debug("conflict item: %s", item)
        quality_assessments.append(item['quality_assessment'])
        memory_verifies.append(item['memory_verify'])
    result_data['quality_assessments'] = quality_assessments
    result_data['memory_verifies'] = memory_verifies

    # _detect_conflicts may return an empty list; treat that as
    # "no conflict" instead of failing on conflict_data[0].
    first_item = conflict_data[0] if conflict_data else {}
    has_conflict = first_item.get('conflict', False)
    conflicts_found = len(first_item['data']) if has_conflict else 0
    logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突")

    # Record the conflict data.
    await self._log_data("conflict", conflict_data)

    # 3. Resolve conflicts.
    solved_data = await self._resolve_conflicts(conflict_data, source_data)
    if not solved_data:
        return ReflectionResult(
            success=False,
            message="反思失败,未解决冲突",
            conflicts_found=conflicts_found,
            execution_time=asyncio.get_event_loop().time() - start_time
        )

    # Flatten the ``reflexion`` entry of every result of every solved item.
    result_data['reflexion_data'] = [
        result['reflexion']
        for item in solved_data
        if 'results' in item
        for result in item['results']
    ]
    return result_data
async def extract_fields_from_json(self):
    """Load ``source_data`` and ``databasets`` from ``example/example.json``.

    The file is looked up in the ``example`` directory next to this module,
    and both fields are read from its top-level ``memory_verify`` object.

    Returns:
        tuple: ``(source_data, databasets)``; each defaults to ``[]`` when
        the key is absent. ``([], [])`` is returned when the file cannot be
        read or parsed, and the failure is logged.
    """
    example_path = os.path.join(os.path.dirname(__file__), "example", "example.json")
    try:
        with open(example_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        memory_verify = data.get("memory_verify", {})
        source_data = memory_verify.get("source_data", [])
        databasets = memory_verify.get("databasets", [])
        return source_data, databasets
    except Exception as e:
        # Best-effort loader: keep the empty fallback, but leave a trace of
        # why the example data is missing.
        logging.error(f"读取example.json失败: {e}")
        return [], []
async def _get_reflexion_data(self, host_id: uuid.UUID) -> List[Any]:
"""
获取反思数据
@@ -253,17 +349,28 @@ class ReflectionEngine:
Returns:
List[Any]: 反思数据列表
"""
if self.config.reflexion_range == ReflectionRange.RETRIEVAL:
# 从检索结果中获取数据
return await self.get_data_func(host_id)
elif self.config.reflexion_range == ReflectionRange.DATABASE:
# 从整个数据库中获取数据(待实现)
logging.warning("从数据库获取反思数据功能尚未实现")
return []
else:
raise ValueError(f"未知的反思范围: {self.config.reflexion_range}")
async def _detect_conflicts(self, data: List[Any]) -> List[Any]:
if self.config.reflexion_range == ReflectionRange.PARTIAL:
neo4j_query = neo4j_query_part.format(host_id)
neo4j_statement = neo4j_statement_part.format(host_id)
elif self.config.reflexion_range == ReflectionRange.ALL:
neo4j_query = neo4j_query_all.format(host_id)
neo4j_statement = neo4j_statement_all.format(host_id)
try:
result = await self.neo4j_connector.execute_query(neo4j_query)
result_statement = await self.neo4j_connector.execute_query(neo4j_statement)
neo4j_databasets = await self.get_data_func(result)
neo4j_state = await self.get_data_statement(result_statement)
return neo4j_databasets, neo4j_state
except Exception as e:
logging.error(f"Neo4j查询失败: {e}")
return [], []
async def _detect_conflicts(self, data: List[Any], statement_databasets: List[Any]) -> List[Any]:
"""
检测冲突(基于事实的反思)
@@ -278,14 +385,28 @@ class ReflectionEngine:
if not data:
return []
# 数据预处理:如果数据量太少,直接返回无冲突
if len(data) < 2:
logging.info("数据量不足,无需检测冲突")
return []
# 使用转换后的数据
print("转换后的数据:", data[:2] if len(data) > 2 else data) # 只打印前2条避免日志过长
memory_verify = self.config.memory_verify
logging.info("====== 冲突检测开始 ======")
start_time = asyncio.get_event_loop().time()
quality_assessment = self.config.quality_assessment
try:
# 渲染冲突检测提示词
rendered_prompt = await self.render_evaluate_prompt_func(
data,
self.conflict_schema
self.conflict_schema,
self.config.baseline,
memory_verify,
quality_assessment,
statement_databasets
)
messages = [{"role": "user", "content": rendered_prompt}]
@@ -316,7 +437,7 @@ class ReflectionEngine:
logging.error(f"冲突检测失败: {e}", exc_info=True)
return []
async def _resolve_conflicts(self, conflicts: List[Any]) -> List[Any]:
async def _resolve_conflicts(self, conflicts: List[Any], statement_databasets: List[Any]) -> List[Any]:
"""
解决冲突
@@ -332,6 +453,8 @@ class ReflectionEngine:
return []
logging.info("====== 冲突解决开始 ======")
baseline = self.config.baseline
memory_verify = self.config.memory_verify
# 并行处理每个冲突
async def _resolve_one(conflict: Any) -> Optional[Dict[str, Any]]:
@@ -341,7 +464,10 @@ class ReflectionEngine:
# 渲染反思提示词
rendered_prompt = await self.render_reflexion_prompt_func(
[conflict],
self.reflexion_schema
self.reflexion_schema,
baseline,
memory_verify,
statement_databasets
)
messages = [{"role": "user", "content": rendered_prompt}]
@@ -381,8 +507,8 @@ class ReflectionEngine:
return solved
async def _apply_reflection_results(
self,
solved_data: List[Dict[str, Any]]
self,
solved_data: List[Dict[str, Any]]
) -> int:
"""
应用反思结果(更新记忆库)
@@ -395,57 +521,7 @@ class ReflectionEngine:
Returns:
int: 成功更新的记忆数量
"""
if not solved_data:
logging.warning("无解决方案数据,跳过更新")
return 0
logging.info("====== 记忆更新开始 ======")
success_count = 0
async def _update_one(item: Dict[str, Any]) -> bool:
"""更新单条记忆"""
async with self._semaphore:
try:
if not isinstance(item, dict):
return False
# 提取更新参数
resolved = item.get("resolved", {})
resolved_mem = resolved.get("resolved_memory", {})
group_id = resolved_mem.get("group_id")
memory_id = resolved_mem.get("id")
new_invalid_at = resolved_mem.get("invalid_at")
if not all([group_id, memory_id, new_invalid_at]):
logging.warning(f"记忆更新参数缺失,跳过此项: {item}")
return False
# 执行更新
await self.neo4j_connector.execute_query(
self.update_query,
group_id=group_id,
id=memory_id,
new_invalid_at=new_invalid_at,
)
return True
except Exception as e:
logging.error(f"更新单条记忆失败: {e}")
return False
# 并发执行所有更新任务
tasks = [
_update_one(item)
for item in solved_data
if isinstance(item, dict)
]
results = await asyncio.gather(*tasks, return_exceptions=False)
success_count = sum(1 for r in results if r)
logging.info(f"成功更新 {success_count}/{len(solved_data)} 条记忆")
success_count = await neo4j_data(solved_data)
return success_count
async def _log_data(self, label: str, data: Any) -> None:
@@ -456,6 +532,7 @@ class ReflectionEngine:
label: 数据标签
data: 要记录的数据
"""
def _write():
try:
with open("reflexion_data.json", "a", encoding="utf-8") as f:
@@ -470,9 +547,9 @@ class ReflectionEngine:
# 基于时间的反思方法
async def time_based_reflection(
self,
host_id: uuid.UUID,
time_period: Optional[str] = None
self,
host_id: uuid.UUID,
time_period: Optional[str] = None
) -> ReflectionResult:
"""
基于时间的反思
@@ -494,8 +571,8 @@ class ReflectionEngine:
# 基于事实的反思方法
async def fact_based_reflection(
self,
host_id: uuid.UUID
self,
host_id: uuid.UUID
) -> ReflectionResult:
"""
基于事实的反思
@@ -515,8 +592,8 @@ class ReflectionEngine:
# 综合反思方法
async def comprehensive_reflection(
self,
host_id: uuid.UUID
self,
host_id: uuid.UUID
) -> ReflectionResult:
"""
综合反思
@@ -553,33 +630,3 @@ class ReflectionEngine:
else:
raise ValueError(f"未知的反思基线: {self.config.baseline}")
# 便捷函数:创建默认配置的反思引擎
def create_reflection_engine(
    enabled: bool = False,
    iteration_period: str = "3",
    reflexion_range: str = "retrieval",
    baseline: str = "TIME",
    concurrency: int = 5
) -> ReflectionEngine:
    """
    Build a ReflectionEngine from individual configuration values.

    Args:
        enabled: whether reflection is enabled
        iteration_period: reflection period
        reflexion_range: reflection range
        baseline: reflection baseline
        concurrency: concurrency level

    Returns:
        ReflectionEngine: an engine constructed with the resulting ReflectionConfig
    """
    # All arguments are forwarded unchanged into the config object.
    return ReflectionEngine(
        ReflectionConfig(
            enabled=enabled,
            iteration_period=iteration_period,
            reflexion_range=reflexion_range,
            baseline=baseline,
            concurrency=concurrency,
        )
    )

View File

@@ -1,13 +1,8 @@
import json
import os
import uuid
from typing import List, Dict, Any, Optional
from sqlalchemy.orm import Session
from app.db import get_db
from app.models.retrieval_info import RetrievalInfo
from app.schemas.memory_storage_schema import BaseDataSchema
import logging
from typing import List, Dict, Any
logger = logging.getLogger(__name__)
async def _load_(data: List[Any]) -> List[Dict]:
@@ -60,27 +55,46 @@ async def _load_(data: List[Any]) -> List[Dict]:
return results
async def get_data(host_id: uuid.UUID) -> List[Dict]:
async def get_data(result):
"""
从数据库中获取数据
"""
# 从数据库会话中获取会话
db: Session = next(get_db())
try:
data = db.query(RetrievalInfo.retrieve_info).filter(RetrievalInfo.host_id == host_id).all()
neo4j_databasets=[]
for item in result:
filtered_item = {}
for key, value in item.items():
if 'name_embedding' not in key.lower():
if key == 'relationship' and value is not None:
# 只保留relationship的指定字段
rel_filtered = {}
if hasattr(value, 'get'):
rel_filtered['run_id'] = value.get('run_id')
rel_filtered['statement'] = value.get('statement')
rel_filtered['statement_id'] = value.get('statement_id')
rel_filtered['expired_at'] = value.get('expired_at')
rel_filtered['created_at'] = value.get('created_at')
filtered_item[key] = rel_filtered
elif key == 'entity2' and value is not None:
# 过滤entity2的name_embedding字段
entity2_filtered = {}
if hasattr(value, 'items'):
for e_key, e_value in value.items():
if 'name_embedding' not in e_key.lower():
entity2_filtered[e_key] = e_value
filtered_item[key] = entity2_filtered
else:
filtered_item[key] = value
# 直接将字典添加到列表中
neo4j_databasets.append(filtered_item)
return neo4j_databasets
async def get_data_statement(result):
    """Copy the rows of a statement query result into a plain list.

    Args:
        result: iterable of rows.

    Returns:
        A new list holding the rows in their original order.
    """
    return list(result)
# print(f"data:\n{data}")
# 解析,提取为字典的列表
results = await _load_(data)
return results
except Exception as e:
logger.error(f"failed to get data from database, host_id: {host_id}, error: {e}")
raise e
finally:
try:
db.close()
except Exception:
pass
if __name__ == "__main__":

View File

@@ -238,3 +238,81 @@ async def render_memory_summary_prompt(
'json_schema': 'MemorySummaryResponse.schema'
})
return rendered_prompt
async def render_emotion_extraction_prompt(
    statement: str,
    extract_keywords: bool,
    enable_subject: bool
) -> str:
    """
    Render the emotion extraction prompt from the extract_emotion.jinja2 template.

    Args:
        statement: The statement to analyze
        extract_keywords: Whether to extract emotion keywords
        enable_subject: Whether to enable subject classification

    Returns:
        Rendered prompt content as string
    """
    prompt_text = prompt_env.get_template("extract_emotion.jinja2").render(
        statement=statement,
        extract_keywords=extract_keywords,
        enable_subject=enable_subject,
    )

    # Write the rendered prompt to the prompt log.
    log_prompt_rendering('emotion extraction', prompt_text)

    # Log the template inputs; the statement is recorded only as the
    # placeholder 'str', not its content.
    template_info = {
        'statement': 'str',
        'extract_keywords': extract_keywords,
        'enable_subject': enable_subject,
    }
    log_template_rendering('extract_emotion.jinja2', template_info)

    return prompt_text
async def render_emotion_suggestions_prompt(
    health_data: dict,
    patterns: dict,
    user_profile: dict
) -> str:
    """
    Render the emotion suggestions prompt from the generate_emotion_suggestions.jinja2 template.

    Args:
        health_data: emotion health data
        patterns: emotion pattern analysis result
        user_profile: user profile data

    Returns:
        Rendered prompt content as string
    """
    import json

    # Pre-serialise the emotion distribution so the template receives a
    # JSON string for it.
    distribution = health_data.get('emotion_distribution', {})
    emotion_distribution_json = json.dumps(distribution, ensure_ascii=False, indent=2)

    suggestions_template = prompt_env.get_template("generate_emotion_suggestions.jinja2")
    prompt_text = suggestions_template.render(
        health_data=health_data,
        patterns=patterns,
        user_profile=user_profile,
        emotion_distribution_json=emotion_distribution_json,
    )

    # Write the rendered prompt to the prompt log.
    log_prompt_rendering('emotion suggestions', prompt_text)

    # Log a short summary of the template inputs.
    template_info = {
        'health_score': health_data.get('health_score'),
        'health_level': health_data.get('level'),
        'user_interests': user_profile.get('interests', []),
    }
    log_template_rendering('generate_emotion_suggestions.jinja2', template_info)

    return prompt_text

View File

@@ -1,19 +1,222 @@
你将收到一组记忆对象:{{ evaluate_data }}。
任务:多维度判断这些记忆是否与已有记忆存在冲突,并给出冲突的对应记忆。(冗余不算冲突)
你将收到一组用户历史记忆原始数据(来源于 Neo4j以及相关配置参数
原本的输入句子:{{statement_databasets}}
需要检测冲突对象:{{ evaluate_data }}
冲突判定类型:{{ baseline }}(取值为 TIME / FACT / HYBRID)
记忆审核开关:{{ memory_verify }}(取值为 true / false)
记忆质量评估开关:{{ quality_assessment }}(取值为 true / false)
仅输出一个合法 JSON 对象,严格遵循下述结构
你的任务是
对用户历史记忆数据进行冲突检测和记忆审核,并输出严格结构化的 JSON 分析结果
数据的结构:
statement_databasets里面,statement_name是输入的句子,statement_id是连接evaluate_data里面的statement_id,代表这个句子被拆分成几个实体,需要根据整体的内容进行判断。
需要根据以下内容做处理(冲突检测、记忆审核、记忆的质量评估)
## 冲突定义
### 时间冲突
时间冲突是指同一用户的相关事件在时间维度上存在逻辑矛盾:
1. **同一活动的时间冲突**
- 同一用户的同一活动在不同时间点被记录(如"周五打球"和"周六打球"
- 同一用户在同一时间段内被记录进行不同的互斥活动
2. **时间逻辑错误**
- expired_at 早于 created_at
- 同一事实的 created_at 时间差异超过合理误差范围(>5分钟
3. **日期属性冲突**
- 同一人的生日记录为不同日期(如"2月10号"和"2月16号"
4.存在明确先后约束 A -> B但 t(A) > t(B)
-例:入学时间晚于毕业时间。
-处理:标记异常、降权、触发逻辑反思或人工审查。
5.时间属性冲突
-单值日期属性出现多值(生日、入职日期)
-注意:本质属于事实冲突的日期特例,归入事实冲突仲裁框架。
6.互斥重叠冲突
-例:同一主体的两个事件区间重叠且互斥(如同一时间出现在两地)
-处理证据仲裁、保留多版本active + candidate
### 事实冲突
事实冲突是指同一实体的属性或关系存在相互矛盾的陈述:
1. **属性互斥**:同一实体的相反属性(喜欢↔不喜欢、有↔没有、是↔不是)
2. **关系矛盾**:同一实体在相同语境下的不同关系描述
3. **身份冲突**:同一实体被赋予不同的类型或角色
### 混合冲突检测
检测所有类型的冲突,包括但不限于时间冲突和事实冲突:
检测任何逻辑上不一致或相互矛盾的记录
## 记忆审核定义
### 隐私信息检测(隐私冲突)
当memory_verify为true时需要额外检测包含个人隐私信息的记录
1. **身份证信息**:包含身份证号码、身份证相关描述
2. **手机号码**:包含手机号、电话号码等联系方式
3. **社交账号**包含微信号、QQ号、邮箱地址等社交平台信息
4. **银行信息**:包含银行卡号、账户信息、支付信息
5. **税务信息**:包含税号、纳税信息、发票信息
6. **贷款信息**:包含贷款记录、信贷信息、借款信息
7. **其他敏感信息**包含密码、PIN码、验证码等安全信息
### 隐私检测原则
- 检测description、entity1_name、entity2_name等字段中的隐私信息
- 识别数字模式如手机号11位数字、身份证18位等
- 识别关键词(如"身份证"、"银行卡"、"密码"等)
- 检测敏感实体类型和关系
## 冲突检测原则
**全面检测**:不区分冲突类型,检测所有可能的冲突
**完整输出**如果发现任何冲突或隐私信息必须将所有相关记录都放入data字段
**实体关联**重点检查涉及相同实体entity1_name, entity2_name的记录
**语义分析**分析description字段的语义相似性和冲突性
**时间逻辑**:检查时间字段的逻辑一致性
**隐私检测**当memory_verify为true时检测所有包含隐私信息的记录
## 不符合冲突检测
-称呼
## 重要检测示例
### 冲突检测示例
- 用户与不同时间点的关系(周五 vs 周六2月10号 vs 2月16号
- 同一实体的重复定义但描述不同
- 同一关系的不同表述但含义冲突
- 任何逻辑上不可能同时为真的记录
### 隐私信息检测示例
- 包含手机号的记录:"用户的手机号是13812345678"
- 包含身份证的记录:"身份证号码为110101199001011234"
- 包含银行卡的记录:"银行卡号6222021234567890"
- 包含社交账号的记录:"微信号是user123456"
- 包含敏感信息的实体名称或描述
## 输出要求
**关键原则**
1. 当存在冲突或检测到隐私信息时,conflict才为true,data字段才包含相关记录
2. 如果发现冲突,必须将所有相关的冲突记录都放入data数组中
3. 如果memory_verify为true且检测到隐私信息,必须将包含隐私信息的记录也放入data数组中
4. 既没有冲突也没有隐私信息时,conflict为false,data为空数组
5. 如果quality_assessment为true,独立分析数据质量并输出评估结果;如果为false,quality_assessment字段输出null
6. 冲突检测、隐私审核和质量评估三个功能完全独立,互不影响
7. 不输出conflict_memory字段
**处理逻辑**
- 首先进行冲突检测将冲突记录加入data数组
- 如果memory_verify为true再进行隐私信息检测将包含隐私信息的记录也加入data数组
- 如果quality_assessment为true独立进行质量评估分析所有输入数据的质量并输出评估结果
- 最终data数组包含所有冲突记录和隐私信息记录去重
- quality_assessment字段独立输出不影响冲突检测和隐私审核结果
- memory_verify字段独立输出隐私检测结果包含检测到的隐私信息类型和概述
返回数据格式以json方式输出:
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
- 关键的JSON格式要求{"statement":识别出的文本内容}
1.JSON结构:仅使用标准ASCII双引号(")-切勿使用中文引号(“”)或其他Unicode引号
2.如果提取的语句文本包含引号,请使用反斜杠(\")正确转义它们
3.确保所有JSON字符串都正确关闭并以逗号分隔
4.JSON字符串值中不包括换行符
5.正确转义的例子:"statement": "Zhang Xinhua said \"我非常喜欢这本书\""
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```
## 记忆质量评估定义
### 质量评估标准
当quality_assessment为true时需要对记忆数据进行质量评估
1. **数据完整性**
- 检查必要字段是否完整entity1_name、entity2_name、description等
- 检查关系描述是否清晰明确
- 检查时间字段的有效性
2. **重复字段检测**
- 识别相同或高度相似的记录
- 检测冗余的实体关系
- 分析描述内容的重复度
3. **无意义字段检测**
- 识别空值、无效值或占位符内容
- 检测过于简单或无信息量的描述
- 识别格式错误或不规范的数据
4. **上下文依赖性**
- 评估记录是否需要额外上下文才能理解
- 检查实体名称的明确性
- 分析关系描述的自包含性
### 质量评估输出
- **质量百分比**基于上述标准计算的整体质量分数0-100
- **质量概述**:简要描述数据质量状况,包括主要问题和优点
输出是仅输出一个合法 JSON 对象,严格遵循下述结构:
{
"data": [ ...与输入同结构的记忆对象数组... ],
"conflict": true 或 false,
"conflict_memory": 若冲突为 true则填写与其冲突的记忆对象否则为 null
"data": [
{
"entity1_name": "实体1名称",
"description": "描述信息",
"statement_id": "陈述ID",
"created_at": "创建时间戳",
"expired_at": "过期时间戳",
"relationship_type": "关系类型",
"relationship": "关系对象",
"entity2_name": "实体2名称",
"entity2": "实体2对象"
}
],
"conflict": true或false,
"quality_assessment": {
"score": 质量百分比数字,
"summary": "质量概述文本"
} 或 null,
"memory_verify": {
"has_privacy": true或false,
"privacy_types": ["检测到的隐私信息类型列表"],
"summary": "隐私检测结果概述"
} 或 null
}
必须遵守:
- 只输出 JSON不要添加解释或多余文本。
- 使用标准双引号,必要时对内部引号进行转义。
- 字段名与结构必须与给定模式一致。
- data数组中包含冲突记录和隐私信息记录如果都没有则为空数组。
- quality_assessment字段当quality_assessment参数为true时输出评估对象为false时输出null。
- memory_verify字段当memory_verify参数为true时输出隐私检测结果对象为false时输出null。
### memory_verify字段说明
当memory_verify为true时需要输出隐私检测结果
- **has_privacy**: 布尔值,表示是否检测到隐私信息
- **privacy_types**: 字符串数组,包含检测到的隐私信息类型(如["手机号码", "身份证信息"]
- **summary**: 字符串,简要描述隐私检测结果
当memory_verify为false时memory_verify字段输出null。
### memory_verify字段示例
**示例1检测到隐私信息**
```json
"memory_verify": {
"has_privacy": true,
"privacy_types": ["手机号码", "身份证信息"],
"summary": "检测到2条记录包含隐私信息1个手机号码1个身份证号码"
}
```
**示例2未检测到隐私信息**
```json
"memory_verify": {
"has_privacy": false,
"privacy_types": [],
"summary": "未检测到隐私信息"
}
```
**示例3memory_verify为false时**
```json
"memory_verify": null
```
模式参考:
[
{{ json_schema }}
]
{{ json_schema }}

View File

@@ -0,0 +1,57 @@
你是一个专业的情绪分析专家。请分析以下陈述句的情绪信息。
陈述句:{{ statement }}
请提取以下信息:
1. emotion_type情绪类型
- joy: 喜悦、开心、高兴、满意、愉快
- sadness: 悲伤、难过、失落、沮丧、遗憾
- anger: 愤怒、生气、不满、恼火、烦躁
- fear: 恐惧、害怕、担心、焦虑、紧张
- surprise: 惊讶、意外、震惊、吃惊
- neutral: 中性、客观陈述、无明显情绪
2. emotion_intensity情绪强度
- 0.0-0.3: 弱情绪
- 0.3-0.7: 中等情绪
- 0.7-1.0: 强情绪
{% if extract_keywords %}
3. emotion_keywords情绪关键词
- 原句中直接表达情绪的词语
- 最多提取3个关键词
- 如果没有明显的情绪词,返回空列表
{% else %}
3. emotion_keywords情绪关键词
- 返回空列表
{% endif %}
{% if enable_subject %}
4. emotion_subject情绪主体
- self: 用户本人的情绪(包含"我"、"我们"、"咱们"等第一人称)
- other: 他人的情绪(包含人名、"他/她"等第三人称)
- object: 对事物的评价(针对产品、地点、事件等)
注意:
- 如果同时包含多个主体优先识别用户本人self
- 如果无法明确判断主体,默认为 self
5. emotion_target情绪对象
- 如果有明确的情绪对象,提取其名称
- 如果没有明确对象,返回 null
{% else %}
4. emotion_subject情绪主体
- 默认为 self
5. emotion_target情绪对象
- 返回 null
{% endif %}
注意事项:
- 如果陈述句是客观事实陈述,无明显情绪,标记为 neutral
- 情绪强度要符合语境,不要过度解读
- 情绪关键词要准确,不要添加原句中没有的词
- 主体分类要准确优先识别用户本人self
请以 JSON 格式返回结果。

View File

@@ -0,0 +1,63 @@
你是一位专业的心理健康顾问。请根据以下用户的情绪健康数据和个人信息生成3-5条个性化的情绪改善建议。
## 用户情绪健康数据
健康分数:{{ health_data.health_score }}/100
健康等级:{{ health_data.level }}
维度分析:
- 积极率:{{ health_data.dimensions.positivity_rate.score }}/100
- 正面情绪:{{ health_data.dimensions.positivity_rate.positive_count }}次
- 负面情绪:{{ health_data.dimensions.positivity_rate.negative_count }}次
- 中性情绪:{{ health_data.dimensions.positivity_rate.neutral_count }}次
- 稳定性:{{ health_data.dimensions.stability.score }}/100
- 标准差:{{ health_data.dimensions.stability.std_deviation }}
- 恢复力:{{ health_data.dimensions.resilience.score }}/100
- 恢复率:{{ health_data.dimensions.resilience.recovery_rate }}
情绪分布:
{{ emotion_distribution_json }}
## 情绪模式分析
主要负面情绪:{{ patterns.dominant_negative_emotion|default('无') }}
情绪波动性:{{ patterns.emotion_volatility|default('未知') }}
高强度情绪次数:{{ patterns.high_intensity_emotions|default([])|length }}
## 用户兴趣
{{ user_profile.interests|default(['未知'])|join(', ') }}
## 任务要求
请生成3-5条个性化建议每条建议包含
1. type: 建议类型emotion_balance/activity_recommendation/social_connection/stress_management
2. title: 建议标题(简短有力)
3. content: 建议内容详细说明50-100字
4. priority: 优先级high/medium/low
5. actionable_steps: 3个可执行的具体步骤
同时提供一个health_summary不超过50字概括用户的整体情绪状态。
请以JSON格式返回格式如下
{
"health_summary": "您的情绪健康状况...",
"suggestions": [
{
"type": "emotion_balance",
"title": "建议标题",
"content": "建议内容...",
"priority": "high",
"actionable_steps": ["步骤1", "步骤2", "步骤3"]
}
]
}
注意事项:
- 建议要具体、可执行,避免空泛
- 结合用户的兴趣爱好提供个性化建议
- 针对主要问题(如主要负面情绪)提供针对性建议
- 优先级要合理分配:至少1个high,1-2个medium,其余low
- 每个建议的3个步骤要循序渐进、易于实施

View File

@@ -1,23 +1,300 @@
你将收到一组用户历史记忆原始数据(来源于 Neo4j
你将收到一条冲突判定对象:{{ data }}。
任务:分析冲突产生原因,给出解决方案,并生成设为失效后的记忆。
需要检测冲突对象:{{ statement_databasets }}
以及需要识别的冲突对象为:{{ baseline }}
记忆审核开关:{{ memory_verify }}(取值为 true / false
角色:
- 你是数据领域中解决数据冲突的专家
任务:分析冲突产生原因,按冲突类型分组处理,为每种冲突类型生成独立的解决方案。
数据的结构:
statement_databasets里面statement_name是输入的句子statement_id是连接data里面的statement_id代表这个句子被拆分成几个实体需要根据整体的内容
需要根据以下内容做处理(冲突检测、记忆审核、记忆的质量评估),data里面的statement_created_at是用户输入的时间
**处理模式**
- 当memory_verify为false时仅处理数据冲突
- 当memory_verify为true时处理数据冲突 + 隐私信息脱敏
## 分组处理原则
**冲突类型识别与分组**
1. **日期冲突**
1.1. 涉及用户生日的不同日期记录(如2月10号 vs 2月16号)
1.2. 涉及同一活动的不同时间记录(如周五打球 vs 周六打球)
2. **事实属性冲突**
2.1. **属性互斥**:同一实体的相反属性(喜欢↔不喜欢、有↔没有、是↔不是)
2.2. **关系矛盾**:同一实体在相同语境下的不同关系描述
2.3. **身份冲突**:同一实体被赋予不同的类型或角色
3. **其他冲突类型/混合冲突(时间+事实)**:根据具体数据识别
**分组输出要求**
- 每种冲突类型生成一个独立的reflexion_result对象
- 同一类型的多个冲突记录归并到一个结果中
- 不同类型的冲突分别处理,各自生成独立结果
## 冲突类型定义
### 时间冲突TIME
时间维度冲突是指两个事件发生时间重叠,或者用户同一件事情和场景等情况下,时间出现了变化。
### 事实冲突FACT
事实冲突是指同一事实对象(同一个人、同一个时间、同一个状态)但陈述内容相互矛盾,主要为真假不能共存的情况。
### 混合冲突HYBRID
检测所有类型的冲突,包括但不限于时间冲突和事实冲突:检测任何逻辑上不一致或相互矛盾的记录
{% if memory_verify %}
## 隐私信息处理memory_verify为true时启用
### 隐私信息识别
需要识别并处理以下类型的隐私信息:
1. **身份证信息**:包含身份证号码、身份证相关描述
2. **手机号码**:包含手机号、电话号码等联系方式
3. **社交账号**包含微信号、QQ号、邮箱地址等社交平台信息
4. **银行信息**:包含银行卡号、账户信息、支付信息
5. **税务信息**:包含税号、纳税信息、发票信息
6. **贷款信息**:包含贷款记录、信贷信息、借款信息
7. **其他敏感信息**包含密码、PIN码、验证码等安全信息
### 隐私数据脱敏规则
对于检测到的隐私信息,按以下规则进行脱敏处理:
**数字类隐私信息脱敏**
- 保留前三位和后四位,中间用*代替
- 示例手机号13812345678 → 138****5678
- 示例身份证110101199001011234 → 110***********1234
- 示例银行卡6222021234567890 → 622***********7890
**文本类隐私信息脱敏**
- 社交账号:保留前三后四位字符,中间用*代替
- 示例微信号user123456 → use****3456
- 示例邮箱zhang.san@example.com → zha****@example.com
**脱敏处理字段**
- name字段如包含隐私信息需脱敏
- entity1_name字段如包含隐私信息需脱敏
- entity2_name字段如包含隐私信息需脱敏
- description字段如包含隐私信息需脱敏
{% endif %}
## 工作步骤
### 第一步:分析冲突类型匹配
首先判断输入的冲突数据是否符合baseline要求的类型
**类型匹配规则**
- 如果baseline是"TIME":只处理时间相关的冲突(涉及时间表达式、日期、时间点的冲突)
- 如果baseline是"FACT":只处理事实相关的冲突(属性矛盾、关系冲突、描述不一致)
- 如果baseline是"HYBRID":处理所有类型的冲突,也可以当作混合冲突类型处理
**类型识别**
- 时间冲突标识entity2的entity_type包含"TimeExpression"、"TemporalExpression"或entity2_name包含时间词汇周一到周日、月份日期等
- 事实冲突标识:相同实体的不同属性描述、互斥的关系陈述
**重要**如果输入的冲突类型与baseline不匹配必须输出空结果resolved为null
### 第二步:筛选并分组冲突数据
按冲突类型对数据进行分组:
**分组策略**
1. **时间冲突组**:筛选涉及用户时间的所有记录
2. **活动时间冲突组**:筛选涉及同一活动不同时间的记录
3. **事实冲突组**:筛选涉及同一实体不同属性的记录
4. **其他冲突组**:其他类型的冲突记录
**筛选条件**
- 只处理与baseline匹配的冲突类型
- 相同entity1_name但entity2_name不同的记录
- 相同关系但描述矛盾的记录
- 时间逻辑不一致的记录
### 第三步:冲突解决策略
**不可以解决的冲突情况**:
1. 数据被判定为正确的情况下,不可以进行修改
**仅当冲突类型与baseline匹配时**,对筛选出的冲突数据进行处理:
**智能解决策略**
1. **分析冲突数据**:识别哪些记录是正确的,哪些是错误的,需要结合statement_databasets的输入原文来判定
2. **判断正确答案是否存在**
- 如果正确答案已存在于data中只需将错误记录的expired_at设为当前日期2025-12-16T12:00:00
- 如果正确答案已存在于data中错误记录的expired_at已经设为日期则不需要对正确的数据进行修改
- 如果正确答案不存在于data中需要修改现有记录的内容以包含正确信息
{% if memory_verify %}
**隐私处理集成**
- 在处理冲突的同时,需要对涉及的记录进行隐私脱敏
- 脱敏处理应该在冲突解决之后进行,确保最终输出的记录都已脱敏
- 在change字段中记录隐私脱敏的变更
{% endif %}
**具体处理规则**
**情况1正确答案存在于data中**
- 保留正确的记录不变
- 基于时间关系的冲突:
需要只修改错误记录的expired_at为当前时间2025-12-16T12:00:00
- 基于事实的关系冲突
- resolved.resolved_memory只包含被设为失效的错误记录
- change字段只记录expired_at的变更`[{"expired_at": "2025-12-16T12:00:00"}]`(注意:如果已存在时间,则不需要对其修改,也不需要变更 时间)
**情况2正确答案不存在于data中**
- 选择最合适的记录进行修改
- 更新该记录的相关字段:
- description字段添加或修改描述信息{% if memory_verify %}(如包含隐私信息,需脱敏处理){% endif %}
- name字段修改名称字段{% if memory_verify %}(如需要,包含隐私信息时需脱敏){% endif %}
- resolved.resolved_memory包含修改后的完整记录{% if memory_verify %}(已脱敏){% endif %}
- change字段记录所有被修改的字段{% if memory_verify %},包括脱敏变更{% endif %},例如:`[{"description": "新描述"{% if memory_verify %}, "entity2_name": "138****5678"{% endif %}}]`
**重要原则**
- **只输出需要修改的记录**resolved.resolved_memory只包含实际需要修改的数据
- **优先保留策略**时间冲突保留最可信的created_at时间的记录事实冲突选择最新且可信度最高的记录
- **精确记录变更**change字段必须包含记录ID、字段名称、新值和旧值
{% if memory_verify %}- **隐私保护优先**:所有输出的记录必须完成隐私脱敏处理
- **脱敏变更记录**隐私脱敏的变更也必须在change字段中详细记录{% endif %}
- **不可修改数据**:数据被判定为正确时,不可以进行修改,如果没有数据可输出空
**变更记录格式**
```json
"change": [
{
"field": [
{"字段名1": "修改后的值1"},
{"字段名2": "修改后的值2"}
]
}
]
```
**类型不匹配处理**
- 如果冲突类型与baseline不匹配resolved必须设为null
- reflexion.reason说明类型不匹配的原因
- reflexion.solution说明无需处理
### 第四步:输出解决方案
## 输出要求
**嵌套字段映射**(系统会自动处理):
- `entity2.name` → 自动映射为 `name`
- `entity1.name` → 自动映射为 `name`
- `entity1.description` → 自动映射为 `description`
- `entity2.description` → 自动映射为 `description`
返回数据格式以json方式输出
- 必须通过json.loads()的格式支持的形式输出
- 响应必须是与此确切模式匹配的有效JSON对象
- 不要在JSON之前或之后包含任何文本
JSON格式要求
1. JSON结构仅使用标准ASCII双引号"
2. 如果提取的语句文本包含引号,请使用反斜杠(\")正确转义
3. 确保所有JSON字符串都正确关闭并以逗号分隔
4. JSON字符串值中不包括换行符
5. 不允许输出```json```相关符号
仅输出一个合法 JSON 对象,严格遵循下述结构:
**输出格式:按冲突类型分组的列表**
{
"conflict": 与输入同结构,包含 data 与 conflict_memory,
"reflexion": { "reason": string, "solution": string },
"resolved": {
"original_memory_id": 被设为失效的记忆 id,
"resolved_memory": 完整的设为失效后的记忆对象
}
"results": [
{
"conflict": {
"data": [该冲突类型相关的数据记录],
"conflict": true
},
"reflexion": {
"reason": "该冲突类型的原因分析",
"solution": "该冲突类型的解决方案"
},
"resolved": {
"original_memory_id": "被设为失效的记忆id",
"resolved_memory": {
"entity1_name": "实体1名称",
"entity2_name": "实体2名称",
"description": "描述信息",
"statement_id": "陈述ID",
"created_at": "创建时间",
"expired_at": "过期时间",
"relationship_type": "关系类型",
"relationship": {},
"entity2": {...}
},
"change": [
{
"field": [
{"字段名1": "修改后的值1"},
{"字段名2": "修改后的值2"}
]
}
]
},
"type": "reflexion_result"
}
]
}
**示例:多种冲突类型的输出**
{
"results": [
{
"conflict": {
"data": [生日冲突相关的记录],
"conflict": true
},
"reflexion": {
"reason": "检测到生日冲突用户同时关联2月10号和2月16号两个不同日期",
"solution": "保留最新记录2月16号将旧记录2月10号设为失效"
},
"resolved": {
"original_memory_id": "df066210883545a08e727ccd8ad4ec77",
"resolved_memory": {...},
"change": [
{
"field": [
{"expired_at": "2025-12-16T12:00:00"}
]
}
]
},
"type": "reflexion_result"
},
{
"conflict": {
"data": [篮球时间冲突相关的记录],
"conflict": true
},
"reflexion": {
"reason": "检测到活动时间冲突:用户打篮球时间存在周五和周六的冲突",
"solution": "保留最可信的时间记录,将冲突记录设为失效"
},
"resolved": {
"original_memory_id": "另一个记录ID",
"resolved_memory": {...},
"change": [
{
"field": [
{"description": "使用系统的个人,指代说话者本人,篮球时间为周六"},
{"entity2_name": "周六"}
]
}
]
},
"type": "reflexion_result"
}
]
}
必须遵守:
- 只输出 JSON不要添加解释或多余文本
- 使用标准双引号,必要时对内部引号进行转义
- 字段名与结构必须与给定模式一致
- 当 conflict 为 false 时resolved 必须为 null。
- 其中 conflict.data 必须为数组形式,即使只有一个对象也需使用 [ ] 包裹。
- 只输出 JSON不要添加解释或多余文本
- 使用标准双引号,必要时对内部引号进行转义
- 字段名与结构必须与给定模式一致
- **输出必须是results数组格式**,每个冲突类型作为一个独立的对象
- **按冲突类型分组**相同类型的冲突记录归并到一个result对象中
- **每个result对象的conflict.data**只包含该冲突类型相关的记录
- **resolved.resolved_memory 只包含需要修改的记录**,不需要修改的记录不要输出
- **resolved.change 必须包含详细的变更信息**field数组包含所有被修改的字段及其新值
- 如果某个冲突类型经分析无需修改任何数据该类型的resolved 必须为 null
- 如果与baseline不匹配的冲突类型不要在results中包含该类型
模式参考:
[
{{ json_schema }}
]
{{ json_schema }}

View File

@@ -7,36 +7,50 @@ from typing import List, Dict, Any
prompt_dir = os.path.join(os.path.dirname(__file__), "prompts")
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any],
                                 baseline: str = "TIME",
                                 memory_verify: bool = False,
                                 quality_assessment: bool = False,
                                 statement_databasets: List[str] = None) -> str:
    """
    Render the evaluate prompt using the evaluate.jinja2 template.

    Args:
        evaluate_data: The data to evaluate.
        schema: The JSON schema to use for the output.
        baseline: The baseline type for conflict detection (TIME/FACT/TIME-FACT).
        memory_verify: Whether to enable memory verification for privacy detection.
        quality_assessment: Flag forwarded unchanged to the template as
            ``quality_assessment``.
        statement_databasets: Statement data forwarded to the template.
            ``None`` (the default) is rendered as an empty list.

    Returns:
        Rendered prompt content as string.
    """
    # A fresh list per call: a ``[]`` default would be one object shared by
    # every invocation of the function.
    if statement_databasets is None:
        statement_databasets = []

    template = prompt_env.get_template("evaluate.jinja2")
    return template.render(
        evaluate_data=evaluate_data,
        json_schema=schema,
        baseline=baseline,
        memory_verify=memory_verify,
        quality_assessment=quality_assessment,
        statement_databasets=statement_databasets,
    )
async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any], baseline: str,
                                  memory_verify: bool = False,
                                  statement_databasets: List[str] = None) -> str:
    """
    Render the reflexion prompt using the reflexion.jinja2 template.

    Args:
        data: The data to reflex on.
        schema: The JSON schema to use for the output.
        baseline: The baseline type for conflict resolution.
        memory_verify: Flag forwarded unchanged to the template as
            ``memory_verify``.
        statement_databasets: Statement data forwarded to the template.
            ``None`` (the default) is rendered as an empty list.

    Returns:
        Rendered prompt content as a string.
    """
    # A fresh list per call: a ``[]`` default would be one object shared by
    # every invocation of the function.
    if statement_databasets is None:
        statement_databasets = []

    template = prompt_env.get_template("reflexion.jinja2")
    return template.render(
        data=data,
        json_schema=schema,
        baseline=baseline,
        memory_verify=memory_verify,
        statement_databasets=statement_databasets,
    )

View File

@@ -1,8 +1,8 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional
from typing import Any, Iterator, AsyncIterator, List, Optional
from langchain_core.callbacks import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
from langchain_core.language_models import BaseLLM
from langchain_core.outputs import LLMResult
from langchain_core.outputs import LLMResult, GenerationChunk
from app.core.models import RedBearModelConfig, RedBearModelFactory, get_provider_llm_class
from app.models.models_model import ModelType
@@ -10,21 +10,36 @@ from app.models.models_model import ModelType
class RedBearLLM(BaseLLM):
"""
RedBear LLM 模型包装器 - 完全动态代理实现
RedBear LLM Model Wrapper
这个包装器自动将所有方法调用委托给内部模型,
同时提供优雅的回退机制和错误处理。
This wrapper provides a unified interface to access different LLM providers,
while maintaining all LangChain functionality, including streaming output.
Features:
- Support for multiple LLM providers (OpenAI, Qwen, Ollama, etc.)
- Full streaming output support
- Elegant error handling and fallback mechanism
- Automatic proxying of all underlying model methods and attributes
"""
def __init__(self, config: RedBearModelConfig, type: ModelType=ModelType.LLM):
self._model = self._create_model(config, type)
def __init__(self, config: RedBearModelConfig, type: ModelType = ModelType.LLM):
"""Initialize RedBear LLM wrapper
Args:
config: Model configuration
type: Model type (LLM or CHAT)
"""
super().__init__()
self._config = config
self._model = self._create_model(config, type)
@property
def _llm_type(self) -> str:
"""返回LLM类型标识符"""
return self._model._llm_type
"""Return LLM type identifier"""
return getattr(self._model, '_llm_type', 'redbear_llm')
# ==================== Core Methods (Required by BaseLLM) ====================
def _generate(
self,
prompts: List[str],
@@ -32,7 +47,7 @@ class RedBearLLM(BaseLLM):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any
) -> LLMResult:
"""同步生成文本"""
"""Synchronous text generation (required by BaseLLM)"""
return self._model._generate(prompts, stop=stop, run_manager=run_manager, **kwargs)
async def _agenerate(
@@ -42,92 +57,233 @@ class RedBearLLM(BaseLLM):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any
) -> LLMResult:
"""异步生成文本"""
"""Asynchronous text generation (required by BaseLLM)"""
return await self._model._agenerate(prompts, stop=stop, run_manager=run_manager, **kwargs)
# 关键:覆盖 invoke/ainvoke直接委托到底层模型避免 BaseLLM 的字符串化行为
# ==================== Advanced Methods (Support Message Lists) ====================
def invoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any:
"""直接调用底层模型以支持 ChatPrompt 和消息列表。"""
"""Synchronous model invocation
Supports various input formats including strings and message lists.
Directly delegates to the underlying model to avoid BaseLLM's string conversion.
Args:
input: Input (string, message list, etc.)
config: Runtime configuration
**kwargs: Additional arguments
Returns:
Model response
"""
try:
return self._model.invoke(input, config=config, **kwargs)
except AttributeError as e:
# 只在属性错误时回退(说明底层模型不支持该方法)
if 'invoke' in str(e):
# Underlying model doesn't support invoke, fallback to parent implementation
return super().invoke(input, config=config, **kwargs)
# 其他 AttributeError 直接抛出
raise
except Exception:
# 其他所有异常(包括 ValidationException直接抛出不回退
# Other exceptions are raised directly
raise
async def ainvoke(self, input: Any, config: Optional[dict] = None, **kwargs: Any) -> Any:
"""异步直接调用底层模型以支持 ChatPrompt 和消息列表。"""
"""Asynchronous model invocation
Supports various input formats including strings and message lists.
Directly delegates to the underlying model to avoid BaseLLM's string conversion.
Args:
input: Input (string, message list, etc.)
config: Runtime configuration
**kwargs: Additional arguments
Returns:
Model response
"""
try:
return await self._model.ainvoke(input, config=config, **kwargs)
except AttributeError as e:
# 只在属性错误时回退(说明底层模型不支持该方法)
if 'ainvoke' in str(e):
# Underlying model doesn't support ainvoke, fallback to parent implementation
return await super().ainvoke(input, config=config, **kwargs)
# 其他 AttributeError 直接抛出
raise
except Exception:
# 其他所有异常(包括 ValidationException直接抛出不回退
# Other exceptions are raised directly
raise
def __getattr__(self, name):
"""
动态代理:将所有未定义的属性和方法调用委托给内部模型
# ==================== Streaming Methods (Critical) ====================
def stream(
self,
input: Any,
config: Optional[dict] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any
) -> Iterator[GenerationChunk]:
"""Synchronous streaming model invocation
这是最优雅的包装器实现方式,完全避免了方法重复定义
"""
# 处理特殊属性以避免递归
if name in ('__isabstractmethod__', '__dict__', '__class__'):
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
Args:
input: Input (string, message list, etc.)
config: Runtime configuration
stop: List of stop words
**kwargs: Additional arguments
# 检查内部模型是否有该属性(使用安全的方式避免递归)
Yields:
GenerationChunk: Generated text chunks
"""
try:
yield from self._model.stream(input, config=config, stop=stop, **kwargs)
except AttributeError as e:
if 'stream' in str(e):
# Underlying model doesn't support stream, fallback to parent implementation
yield from super().stream(input, config=config, stop=stop, **kwargs)
else:
raise
except Exception:
raise
async def astream(
self,
input: Any,
config: Optional[dict] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any
) -> AsyncIterator[GenerationChunk]:
"""Asynchronous streaming model invocation
This is the core method for streaming output. It directly proxies to the
underlying model's astream method, maintaining generator characteristics
to ensure each chunk is delivered in real-time.
Args:
input: Input (string, message list, etc.)
config: Runtime configuration
stop: List of stop words
**kwargs: Additional arguments
Yields:
GenerationChunk: Generated text chunks
"""
try:
async for chunk in self._model.astream(input, config=config, stop=stop, **kwargs):
yield chunk
except AttributeError as e:
if 'astream' in str(e):
# Underlying model doesn't support astream, fallback to parent implementation
async for chunk in super().astream(input, config=config, stop=stop, **kwargs):
yield chunk
else:
raise
except Exception:
raise
# ==================== Dynamic Proxy ====================
def __getattr__(self, name: str) -> Any:
"""Dynamic proxy: delegate undefined attributes and method calls to internal model
This method allows RedBearLLM to transparently access all attributes and methods
of the underlying model without explicitly defining each one.
Args:
name: Attribute or method name
Returns:
Attribute value or method
Raises:
AttributeError: If attribute doesn't exist
"""
# Avoid recursion: raise error directly for special attributes
if name in ('__isabstractmethod__', '__dict__', '__class__', '_model', '_config'):
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
# Try to get attribute from internal model
try:
# 使用 object.__getattribute__ 来安全地检查内部模型的属性
attr = object.__getattribute__(self._model, name)
# 如果是方法,返回一个包装器来处理调用
# If it's callable (a method)
if callable(attr):
# 流式方法直接返回,不包装(保持生成器特性)
if name in ('_stream', '_astream', 'stream', 'astream'):
# Streaming methods are returned directly to maintain generator characteristics
# Note: Although we've explicitly implemented stream/astream,
# this is kept to handle internal methods like _stream/_astream
if name in ('_stream', '_astream'):
return attr
# 非流式方法使用包装器处理异常
# Wrap other methods for easier debugging and error handling
def method_wrapper(*args, **kwargs):
return attr(*args, **kwargs)
try:
return attr(*args, **kwargs)
except Exception:
# Can add logging or error handling here
raise
# 保持方法的元信息
# Preserve method metadata
method_wrapper.__name__ = name
method_wrapper.__doc__ = getattr(attr, '__doc__', f"Delegated method: {name}")
return method_wrapper
# 如果是普通属性,直接返回
# If it's a regular attribute, return directly
return attr
except AttributeError:
# 内部模型没有该属性,尝试回退实现
# Internal model doesn't have this attribute either
pass
# 检查是否有回退方法(使用安全的方式避免递归)
# Check if there's a fallback method
fallback_name = f'_fallback_{name}'
try:
fallback_method = object.__getattribute__(self, fallback_name)
return fallback_method
return object.__getattribute__(self, fallback_name)
except AttributeError:
# 没有回退方法,抛出适当的错误
pass
# 如果都没有,抛出适当的错误
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
# Nothing found, raise error
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'. "
f"The underlying model '{type(self._model).__name__}' also doesn't have this attribute."
)
# ==================== Helper Methods ====================
def _create_model(self, config: RedBearModelConfig, type: ModelType) -> BaseLLM:
"""创建内部模型实例"""
"""Create internal model instance
Args:
config: Model configuration
type: Model type
Returns:
Created model instance
"""
llm_class = get_provider_llm_class(config, type)
model_params = RedBearModelFactory.get_model_params(config)
return llm_class(**model_params)
def get_config(self) -> RedBearModelConfig:
"""Get model configuration
Returns:
Model configuration object
"""
return self._config
def get_underlying_model(self) -> BaseLLM:
"""Get underlying model instance
Returns:
Underlying model instance
"""
return self._model
def __repr__(self) -> str:
"""Return string representation of the object"""
return (
f"RedBearLLM("
f"provider={self._config.provider}, "
f"model={self._config.model_name}, "
f"type={type(self._model).__name__}"
f")"
)

View File

@@ -1,12 +1,23 @@
import xxhash
from app.aioRedis import aio_redis_set, aio_redis_get
import redis
from app.core.config import settings
redis_client = redis.StrictRedis(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD,
decode_responses=True,
max_connections=30
)
def get_llm_cache(llmnm, txt, history, genconf):
    """Look up a previously cached LLM answer in Redis.

    The cache key is the xxh64 hex digest of the string forms of the four
    arguments concatenated in the order ``llmnm, txt, history, genconf``.

    Returns:
        The value stored under that key, or ``None`` when Redis returns a
        falsy value (missing key or empty value).
    """
    key_material = "".join(str(part) for part in (llmnm, txt, history, genconf))
    cache_key = xxhash.xxh64(key_material.encode("utf-8")).hexdigest()
    cached = redis_client.get(cache_key)
    return cached if cached else None
@@ -14,6 +25,6 @@ def get_llm_cache(llmnm, txt, history, genconf):
def set_llm_cache(llmnm, txt, v, history, genconf):
    """Store an LLM answer in Redis.

    The cache key is the xxh64 hex digest of the string forms of
    ``llmnm, txt, history, genconf`` concatenated in that order. ``v`` is
    stored UTF-8 encoded, and ``24 * 3600`` is passed as the third positional
    argument of ``redis_client.set``.
    """
    key_material = "".join(str(part) for part in (llmnm, txt, history, genconf))
    cache_key = xxhash.xxh64(key_material.encode("utf-8")).hexdigest()
    payload = v.encode("utf-8")
    redis_client.set(cache_key, payload, 24 * 3600)

View File

@@ -119,7 +119,7 @@ def keyword_extraction(chat_mdl, content, topn=3):
rendered_prompt = template.render(content=content, topn=topn)
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
_, msg = message_fit_in(msg, chat_mdl.max_length)
_, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096))
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple):
kwd = kwd[0]
@@ -194,7 +194,7 @@ def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
)
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
_, msg = message_fit_in(msg, chat_mdl.max_length)
_, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096))
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.5})
if isinstance(kwd, tuple):
kwd = kwd[0]
@@ -314,7 +314,7 @@ def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defi
hist[-1]["content"] += user_prompt
else:
hist.append({"role": "user", "content": user_prompt})
_, msg = message_fit_in(hist, chat_mdl.max_length)
_, msg = message_fit_in(hist, getattr(chat_mdl, 'max_length', 8096))
ans = chat_mdl.chat(msg[0]["content"], msg[1:])
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
return """
@@ -341,7 +341,7 @@ def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defin
params=json.dumps(params, ensure_ascii=False, indent=2),
result=result)
user_prompt = "→ Summary: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
_, msg = message_fit_in(form_message(system_prompt, user_prompt), getattr(chat_mdl, 'max_length', 8096))
ans = chat_mdl.chat(msg[0]["content"], msg[1:])
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
@@ -350,7 +350,7 @@ def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[st
template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
user_prompt = " → rank: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
_, msg = message_fit_in(form_message(system_prompt, user_prompt), getattr(chat_mdl, 'max_length', 8096))
ans = chat_mdl.chat(msg[0]["content"], msg[1:], stop="<|stop|>")
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
@@ -378,7 +378,7 @@ def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None):
cached = get_llm_cache(chat_mdl.llm_name, system_prompt, user_prompt, gen_conf)
if cached:
return json_repair.loads(cached)
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
_, msg = message_fit_in(form_message(system_prompt, user_prompt), getattr(chat_mdl, 'max_length', 8096))
ans = chat_mdl.chat(msg[0]["content"], msg[1:],gen_conf=gen_conf)
ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
try:
@@ -641,7 +641,7 @@ def split_chunks(chunks, max_length: int):
async def run_toc_from_text(chunks, chat_mdl, callback=None):
input_budget = int(chat_mdl.max_length * INPUT_UTILIZATION) - num_tokens_from_string(
input_budget = int(getattr(chat_mdl, 'max_length', 8096) * INPUT_UTILIZATION) - num_tokens_from_string(
TOC_FROM_TEXT_USER + TOC_FROM_TEXT_SYSTEM
)

View File

@@ -0,0 +1,37 @@
"""工具管理核心模块"""
from .base import BaseTool, ToolResult, ToolParameter
from .registry import ToolRegistry
from .executor import ToolExecutor
from .langchain_adapter import LangchainAdapter
from .config_manager import ConfigManager
from .chain_manager import ChainManager
# Optional tool families: when the import fails the name is bound to None so
# that importing this package never breaks because of them.
try:
    from .custom.base import CustomTool
except ImportError:
    CustomTool = None

try:
    from .mcp.base import MCPTool
except ImportError:
    MCPTool = None

__all__ = [
    "BaseTool",
    "ToolResult",
    "ToolParameter",
    "ToolRegistry",
    "ToolExecutor",
    "LangchainAdapter",
    "ConfigManager",
    "ChainManager",
]

# Export the optional classes only when their import succeeded.
__all__ += [
    exported_name
    for exported_name, imported in (("CustomTool", CustomTool), ("MCPTool", MCPTool))
    if imported
]

302
api/app/core/tools/base.py Normal file
View File

@@ -0,0 +1,302 @@
"""工具基础接口定义"""
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
from enum import Enum
from app.models.tool_model import ToolType, ToolStatus
class ParameterType(str, Enum):
    """Type names a tool parameter may declare.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value.
    """

    # Scalar types
    STRING = "string"
    INTEGER = "integer"
    NUMBER = "number"
    BOOLEAN = "boolean"

    # Container types
    ARRAY = "array"
    OBJECT = "object"
class ToolParameter(BaseModel):
    """Declarative description of one parameter accepted by a tool."""

    # Identity and documentation
    name: str = Field(..., description="参数名称")
    type: ParameterType = Field(..., description="参数类型")
    description: str = Field(default="", description="参数描述")

    # Presence and default
    required: bool = Field(default=False, description="是否必需")
    default: Any = Field(default=None, description="默认值")

    # Optional value constraints
    enum: Optional[List[Any]] = Field(default=None, description="枚举值")
    minimum: Optional[Union[int, float]] = Field(default=None, description="最小值")
    maximum: Optional[Union[int, float]] = Field(default=None, description="最大值")
    pattern: Optional[str] = Field(default=None, description="正则表达式模式")

    class Config:
        use_enum_values = True
class ToolResult(BaseModel):
    """Outcome of one tool execution, successful or not."""

    success: bool = Field(..., description="执行是否成功")
    data: Any = Field(default=None, description="返回数据")
    error: Optional[str] = Field(default=None, description="错误信息")
    error_code: Optional[str] = Field(default=None, description="错误代码")
    execution_time: float = Field(..., description="执行时间（秒）")
    token_usage: Optional[Dict[str, int]] = Field(default=None, description="Token使用情况")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="额外元数据")

    @classmethod
    def success_result(
        cls,
        data: Any,
        execution_time: float,
        token_usage: Optional[Dict[str, int]] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> "ToolResult":
        """Build a result with ``success=True``.

        A falsy ``metadata`` (``None`` or an empty dict) is replaced by a new
        empty dict.
        """
        extra = metadata or {}
        return cls(
            success=True,
            data=data,
            execution_time=execution_time,
            token_usage=token_usage,
            metadata=extra,
        )

    @classmethod
    def error_result(
        cls,
        error: str,
        execution_time: float,
        error_code: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> "ToolResult":
        """Build a result with ``success=False`` carrying the error details.

        A falsy ``metadata`` (``None`` or an empty dict) is replaced by a new
        empty dict.
        """
        extra = metadata or {}
        return cls(
            success=False,
            error=error,
            error_code=error_code,
            execution_time=execution_time,
            metadata=extra,
        )
class ToolInfo(BaseModel):
    """Descriptive snapshot of a tool: identity, parameters, status and tags."""

    # Identity
    id: str = Field(..., description="工具ID")
    name: str = Field(..., description="工具名称")
    description: str = Field(..., description="工具描述")
    tool_type: ToolType = Field(..., description="工具类型")
    version: str = Field(default="1.0.0", description="工具版本")

    # Interface and state
    parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数")
    status: ToolStatus = Field(default=ToolStatus.ACTIVE, description="工具状态")
    tags: List[str] = Field(default_factory=list, description="工具标签")

    # Ownership
    tenant_id: Optional[str] = Field(default=None, description="租户ID")

    class Config:
        use_enum_values = True
class BaseTool(ABC):
    """Abstract base class for all tools.

    Subclasses provide ``name``, ``description``, ``tool_type``,
    ``parameters`` and ``execute``. This class supplies parameter validation,
    an exception-safe execution wrapper and conversion helpers.
    """

    def __init__(self, tool_id: str, config: Dict[str, Any]):
        """Initialize the tool.

        Args:
            tool_id: Tool ID.
            config: Tool configuration. The keys ``version``, ``tags`` and
                ``tenant_id`` are read by this class.
        """
        self.tool_id = tool_id
        self.config = config
        self._status = ToolStatus.ACTIVE

    @property
    @abstractmethod
    def name(self) -> str:
        """Tool name."""
        pass

    @property
    @abstractmethod
    def description(self) -> str:
        """Tool description."""
        pass

    @property
    @abstractmethod
    def tool_type(self) -> ToolType:
        """Tool type."""
        pass

    @property
    def version(self) -> str:
        """Tool version, read from the config (defaults to "1.0.0")."""
        return self.config.get("version", "1.0.0")

    @property
    def status(self) -> ToolStatus:
        """Current tool status."""
        return self._status

    @status.setter
    def status(self, value: ToolStatus):
        """Set the tool status."""
        self._status = value

    @property
    @abstractmethod
    def parameters(self) -> List[ToolParameter]:
        """Definitions of the parameters the tool accepts."""
        pass

    @property
    def tags(self) -> List[str]:
        """Tool tags, read from the config (defaults to an empty list)."""
        return self.config.get("tags", [])

    def get_info(self) -> ToolInfo:
        """Return a ``ToolInfo`` built from this tool's properties and config."""
        return ToolInfo(
            id=self.tool_id,
            name=self.name,
            description=self.description,
            tool_type=self.tool_type,
            version=self.version,
            parameters=self.parameters,
            status=self.status,
            tags=self.tags,
            tenant_id=self.config.get("tenant_id")
        )

    def validate_parameters(self, parameters: Dict[str, Any]) -> Dict[str, str]:
        """Validate input parameters against the tool's parameter definitions.

        Parameters without a matching definition are ignored.

        Args:
            parameters: Input parameters keyed by parameter name.

        Returns:
            Mapping of parameter name to error message; an empty dict means
            validation passed.
        """
        errors: Dict[str, str] = {}
        definitions = self.parameters
        param_definitions = {p.name: p for p in definitions}

        # Required parameters must be present.
        for param_def in definitions:
            if param_def.required and param_def.name not in parameters:
                errors[param_def.name] = f"Required parameter '{param_def.name}' is missing"

        # Type first, then constraints.
        for param_name, param_value in parameters.items():
            param_def = param_definitions.get(param_name)
            if param_def is None:
                continue

            if not self._validate_parameter_type(param_value, param_def):
                errors[param_name] = f"Parameter '{param_name}' has invalid type, expected {param_def.type}"
                # Constraint checks assume a correctly typed value: comparing
                # e.g. a str against a numeric minimum raises TypeError, and a
                # constraint message would hide the more basic type error.
                continue

            constraint_error = self._validate_parameter_constraints(param_value, param_def)
            if constraint_error:
                errors[param_name] = constraint_error

        return errors

    def _validate_parameter_type(self, value: Any, param_def: ToolParameter) -> bool:
        """Return True when ``value`` has the Python type declared by ``param_def``.

        ``None`` is accepted only for parameters that are not required.
        """
        if value is None:
            return not param_def.required

        # bool is a subclass of int, so a plain isinstance check would accept
        # True/False as integer or number values.
        if isinstance(value, bool) and param_def.type in (ParameterType.INTEGER, ParameterType.NUMBER):
            return False

        type_mapping = {
            ParameterType.STRING: str,
            ParameterType.INTEGER: int,
            ParameterType.NUMBER: (int, float),
            ParameterType.BOOLEAN: bool,
            ParameterType.ARRAY: list,
            ParameterType.OBJECT: dict
        }

        expected_type = type_mapping.get(param_def.type)
        if expected_type:
            return isinstance(value, expected_type)

        # Unknown declared type: nothing to check against.
        return True

    def _validate_parameter_constraints(self, value: Any, param_def: ToolParameter) -> Optional[str]:
        """Check the enum, numeric range and pattern constraints.

        Returns:
            An error message for the first violated constraint, or ``None``
            when every constraint holds (``None`` values are never checked).
        """
        if value is None:
            return None

        # Enum membership.
        if param_def.enum and value not in param_def.enum:
            return f"Value must be one of {param_def.enum}"

        # Numeric range; skipped for non-numeric values, which cannot be
        # ordered against the bounds.
        if param_def.type in [ParameterType.INTEGER, ParameterType.NUMBER] and isinstance(value, (int, float)):
            if param_def.minimum is not None and value < param_def.minimum:
                return f"Value must be >= {param_def.minimum}"
            if param_def.maximum is not None and value > param_def.maximum:
                return f"Value must be <= {param_def.maximum}"

        # String pattern, anchored at the start of the value (re.match).
        if param_def.type == ParameterType.STRING and param_def.pattern:
            import re
            if not re.match(param_def.pattern, str(value)):
                return f"Value must match pattern: {param_def.pattern}"

        return None

    @abstractmethod
    async def execute(self, **kwargs) -> ToolResult:
        """Execute the tool.

        Args:
            **kwargs: Tool parameters.

        Returns:
            Execution result.
        """
        pass

    async def safe_execute(self, **kwargs) -> ToolResult:
        """Execute the tool with parameter validation and exception handling.

        Args:
            **kwargs: Tool parameters.

        Returns:
            The result of ``execute``; an error result with code
            ``VALIDATION_ERROR`` when validation fails, or with code
            ``EXECUTION_ERROR`` when any exception is raised.
        """
        start_time = time.time()

        try:
            validation_errors = self.validate_parameters(kwargs)
            if validation_errors:
                execution_time = time.time() - start_time
                error_msg = "; ".join([f"{k}: {v}" for k, v in validation_errors.items()])
                return ToolResult.error_result(
                    error=f"Parameter validation failed: {error_msg}",
                    error_code="VALIDATION_ERROR",
                    execution_time=execution_time
                )

            return await self.execute(**kwargs)

        except Exception as e:
            # Boundary handler: callers always receive a ToolResult.
            execution_time = time.time() - start_time
            return ToolResult.error_result(
                error=str(e),
                error_code="EXECUTION_ERROR",
                execution_time=execution_time
            )

    def to_langchain_tool(self):
        """Convert this tool through ``LangchainAdapter.convert_tool``."""
        # Imported here rather than at module level, as in the original code.
        from .langchain_adapter import LangchainAdapter
        return LangchainAdapter.convert_tool(self)

    def __repr__(self):
        return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>"

View File

@@ -0,0 +1,17 @@
"""内置工具模块"""
from .base import BuiltinTool
from .datetime_tool import DateTimeTool
from .json_tool import JsonTool
from .baidu_search_tool import BaiduSearchTool
from .mineru_tool import MinerUTool
from .textin_tool import TextInTool
__all__ = [
"BuiltinTool",
"DateTimeTool",
"JsonTool",
"BaiduSearchTool",
"MinerUTool",
"TextInTool"
]

View File

@@ -0,0 +1,334 @@
"""百度搜索工具 - 搜索引擎服务"""
import time
from typing import List, Dict, Any
import aiohttp
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
class BaiduSearchTool(BuiltinTool):
"""百度搜索工具 - 提供网页搜索、新闻搜索、图片搜索、实时结果"""
@property
def name(self) -> str:
return "baidu_search_tool"
@property
def description(self) -> str:
return "百度搜索 - 搜索引擎服务:网页搜索、新闻搜索、图片搜索、实时结果"
def get_required_config_parameters(self) -> List[str]:
return ["api_key"]
@property
def parameters(self) -> List[ToolParameter]:
return [
ToolParameter(
name="query",
type=ParameterType.STRING,
description="搜索关键词",
required=True
),
ToolParameter(
name="search_type",
type=ParameterType.STRING,
description="搜索类型",
required=False,
default="web",
enum=["web", "news", "image", "video"]
),
ToolParameter(
name="page_size",
type=ParameterType.INTEGER,
description="每页结果数",
required=False,
default=10,
minimum=1,
maximum=50
),
ToolParameter(
name="page_num",
type=ParameterType.INTEGER,
description="页码从1开始",
required=False,
default=1,
minimum=1,
maximum=10
),
ToolParameter(
name="safe_search",
type=ParameterType.BOOLEAN,
description="是否启用安全搜索",
required=False,
default=True
),
ToolParameter(
name="region",
type=ParameterType.STRING,
description="搜索地区",
required=False,
default="cn",
enum=["cn", "hk", "tw", "us", "jp", "kr"]
),
ToolParameter(
name="time_filter",
type=ParameterType.STRING,
description="时间过滤",
required=False,
enum=["all", "day", "week", "month", "year"]
)
]
async def execute(self, **kwargs) -> ToolResult:
"""执行百度搜索"""
start_time = time.time()
try:
query = kwargs.get("query")
search_type = kwargs.get("search_type", "web")
page_size = kwargs.get("page_size", 10)
page_num = kwargs.get("page_num", 1)
safe_search = kwargs.get("safe_search", True)
region = kwargs.get("region", "cn")
time_filter = kwargs.get("time_filter")
if not query:
raise ValueError("query 参数是必需的")
# 根据搜索类型调用不同的API
if search_type == "web":
result = await self._web_search(query, page_size, page_num, safe_search, region, time_filter)
elif search_type == "news":
result = await self._news_search(query, page_size, page_num, region, time_filter)
elif search_type == "image":
result = await self._image_search(query, page_size, page_num, safe_search)
elif search_type == "video":
result = await self._video_search(query, page_size, page_num, safe_search)
else:
raise ValueError(f"不支持的搜索类型: {search_type}")
execution_time = time.time() - start_time
return ToolResult.success_result(
data=result,
execution_time=execution_time
)
except Exception as e:
execution_time = time.time() - start_time
return ToolResult.error_result(
error=str(e),
error_code="BAIDU_SEARCH_ERROR",
execution_time=execution_time
)
async def _web_search(self, query: str, page_size: int, page_num: int,
safe_search: bool, region: str, time_filter: str = None) -> Dict[str, Any]:
"""网页搜索"""
payload = {
"messages": [{"role": "user", "content": query}],
"edition": "standard",
"search_source": "baidu_search_v2",
"resource_type_filter": [{"type": "web", "top_k": min(page_size, 50)}],
"enable_full_content": True
}
if time_filter:
time_map = {"day": "now-1d/d", "week": "now-1w/d", "month": "now-1M/d", "year": "now-1y/d"}
if time_filter in time_map:
payload["search_filter"] = {"range": {"page_time": {"gte": time_map[time_filter], "lt": "now/d"}}}
payload["search_recency_filter"] = time_filter
results = await self._call_baidu_ai_search_api(payload)
search_results = []
if "references" in results:
for item in results["references"]:
search_results.append({
"title": item.get("title", ""),
"url": item.get("url", ""),
"snippet": item.get("content", ""),
"display_url": item.get("url", ""),
"rank": len(search_results) + 1
})
return {
"search_type": "web",
"query": query,
"total_results": len(search_results),
"page_num": page_num,
"page_size": page_size,
"results": search_results,
"answer": results.get("result", ""),
"references": results.get("references", [])
}
async def _news_search(self, query: str, page_size: int, page_num: int,
region: str, time_filter: str = None) -> Dict[str, Any]:
"""新闻搜索"""
payload = {
"messages": [{"role": "user", "content": query}],
"edition": "standard",
"search_source": "baidu_search_v2",
"resource_type_filter": [{"type": "new", "top_k": min(page_size, 50)}],
"enable_full_content": True
}
if time_filter:
time_map = {"day": "now-1d/d", "week": "now-1w/d", "month": "now-1M/d", "year": "now-1y/d"}
if time_filter in time_map:
payload["search_filter"] = {"range": {"page_time": {"gte": time_map[time_filter], "lt": "now/d"}}}
payload["search_recency_filter"] = time_filter
results = await self._call_baidu_ai_search_api(payload)
search_results = []
if "references" in results:
for item in results["references"]:
search_results.append({
"title": item.get("title", ""),
"url": item.get("url", ""),
"snippet": item.get("content", ""),
"display_url": item.get("url", ""),
"rank": len(search_results) + 1
})
return {
"search_type": "new",
"query": query,
"total_results": len(search_results),
"page_num": page_num,
"page_size": page_size,
"results": search_results,
"answer": results.get("result", ""),
"references": results.get("references", [])
}
async def _image_search(self, query: str, page_size: int, page_num: int,
safe_search: bool) -> Dict[str, Any]:
"""图片搜索"""
payload = {
"messages": [{"role": "user", "content": query}],
"edition": "standard",
"search_source": "baidu_search_v2",
"resource_type_filter": [{"type": "image", "top_k": min(page_size, 30)}],
"enable_full_content": True
}
results = await self._call_baidu_ai_search_api(payload)
search_results = []
if "references" in results:
for item in results["references"]:
search_results.append({
"title": item.get("title", ""),
"url": item.get("url", ""),
"snippet": item.get("content", ""),
"display_url": item.get("url", ""),
"rank": len(search_results) + 1
})
return {
"search_type": "image",
"query": query,
"total_results": len(search_results),
"page_num": page_num,
"page_size": page_size,
"results": search_results,
"answer": results.get("result", ""),
"references": results.get("references", [])
}
async def _video_search(self, query: str, page_size: int, page_num: int,
safe_search: bool) -> Dict[str, Any]:
"""视频搜索"""
payload = {
"messages": [{"role": "user", "content": query}],
"edition": "standard",
"search_source": "baidu_search_v2",
"resource_type_filter": [{"type": "video", "top_k": min(page_size, 10)}],
"enable_full_content": True
}
results = await self._call_baidu_ai_search_api(payload)
search_results = []
if "references" in results:
for item in results["references"]:
search_results.append({
"title": item.get("title", ""),
"url": item.get("url", ""),
"snippet": item.get("content", ""),
"display_url": item.get("url", ""),
"rank": len(search_results) + 1
})
return {
"search_type": "video",
"query": query,
"total_results": len(search_results),
"page_num": page_num,
"page_size": page_size,
"results": search_results,
"answer": results.get("result", ""),
"references": results.get("references", [])
}
async def _call_baidu_ai_search_api(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""调用百度AI搜索API"""
api_key = self.get_config_parameter("api_key")
if not api_key:
raise ValueError("百度搜索API密钥未配置")
url = "https://qianfan.baidubce.com/v2/ai_search/chat/completions"
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {api_key}'
}
timeout = aiohttp.ClientTimeout(total=30)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.post(url, headers=headers, json=payload) as response:
if response.status == 200:
return await response.json()
else:
raise Exception(f"HTTP错误: {response.status}")
async def test_connection(self) -> Dict[str, Any]:
    """Check that the configured API key is accepted by the search API.

    Sends a minimal one-result web search. This method does not raise:
    every failure is reported in the returned dict.

    Returns:
        ``{"success": True, "message": ..., "api_key_masked": ...}`` when
        the probe request succeeds, otherwise
        ``{"success": False, "error": ...}``.
    """
    def failure(message: str) -> Dict[str, Any]:
        return {
            "success": False,
            "error": message
        }

    try:
        api_key = self.get_config_parameter("api_key")
        if not api_key:
            return failure("API密钥未配置")

        # Smallest possible request; only used to validate the key.
        probe_payload = {
            "messages": [{"role": "user", "content": "test"}],
            "edition": "standard",
            "search_source": "baidu_search_v2",
            "resource_type_filter": [{"type": "web", "top_k": 1}]
        }

        try:
            await self._call_baidu_ai_search_api(probe_payload)
            # Expose at most the first 8 characters of the key.
            masked_key = api_key[:8] + "***" if len(api_key) > 8 else "***"
            return {
                "success": True,
                "message": "连接测试成功",
                "api_key_masked": masked_key
            }
        except Exception as api_error:
            return failure(f"API连接失败: {str(api_error)}")
    except Exception as unexpected_error:
        return failure(str(unexpected_error))

View File

@@ -0,0 +1,118 @@
"""内置工具基类"""
from abc import ABC, abstractmethod
from typing import Dict, Any, List
from app.models.tool_model import ToolType
from app.core.tools.base import BaseTool, ToolResult, ToolParameter
class BuiltinTool(BaseTool, ABC):
    """Base class for built-in tools.

    Adds configuration handling on top of ``BaseTool``: the values stored
    under the ``"parameters"`` key of ``config`` are kept in
    ``parameters_config`` and checked against
    ``get_required_config_parameters()`` before a tool is executed through
    ``safe_execute``.
    """

    def __init__(self, tool_id: str, config: Dict[str, Any]):
        """Initialize the built-in tool.

        Args:
            tool_id: Tool identifier, forwarded to ``BaseTool``.
            config: Tool configuration; its ``"parameters"`` entry holds
                the per-tool settings (API keys, URLs, ...).
        """
        super().__init__(tool_id, config)
        # ``or {}`` also covers an explicit ``"parameters": None``, which
        # ``config.get("parameters", {})`` would pass through and which
        # would make every later ``.get`` on it raise AttributeError.
        self.parameters_config = config.get("parameters") or {}

    @property
    def tool_type(self) -> ToolType:
        """Tool type; always ``ToolType.BUILTIN``."""
        return ToolType.BUILTIN

    @property
    @abstractmethod
    def name(self) -> str:
        """Tool name; subclasses must implement."""
        pass

    @property
    @abstractmethod
    def description(self) -> str:
        """Tool description; subclasses must implement."""
        pass

    @property
    @abstractmethod
    def parameters(self) -> List[ToolParameter]:
        """Tool parameter definitions; subclasses must implement."""
        pass

    @abstractmethod
    async def execute(self, **kwargs) -> ToolResult:
        """Execute the tool; subclasses must implement.

        Args:
            **kwargs: Tool arguments.

        Returns:
            The execution result.
        """
        pass

    @property
    def is_configured(self) -> bool:
        """Whether every required config parameter has a truthy value."""
        return all(
            self.parameters_config.get(param)
            for param in self.get_required_config_parameters()
        )

    def get_required_config_parameters(self) -> List[str]:
        """Return the names of the required config parameters.

        The base implementation requires nothing; subclasses override it.

        Returns:
            List of required config parameter names.
        """
        return []

    def get_config_parameter(self, name: str, default: Any = None) -> Any:
        """Return one config parameter value.

        Args:
            name: Parameter name.
            default: Value returned when the parameter is not set.

        Returns:
            The stored value, or ``default``.
        """
        return self.parameters_config.get(name, default)

    def validate_configuration(self) -> tuple[bool, str]:
        """Validate the tool configuration.

        Returns:
            ``(True, "")`` when the tool is configured, otherwise
            ``(False, message)`` where ``message`` lists the missing
            required parameters.
        """
        # Goes through the ``is_configured`` property (not a local
        # re-check) so subclass overrides of it are honoured.
        if self.is_configured:
            return True, ""

        missing_params = [
            param
            for param in self.get_required_config_parameters()
            if not self.parameters_config.get(param)
        ]
        return False, f"缺少必需的配置参数: {', '.join(missing_params)}"

    async def safe_execute(self, **kwargs) -> ToolResult:
        """Execute the tool after validating its configuration.

        Args:
            **kwargs: Tool arguments.

        Returns:
            A ``CONFIGURATION_ERROR`` result when the configuration is
            invalid; otherwise whatever the parent ``safe_execute`` returns.
        """
        is_valid, error_msg = self.validate_configuration()
        if not is_valid:
            return ToolResult.error_result(
                error=f"工具配置无效: {error_msg}",
                error_code="CONFIGURATION_ERROR",
                execution_time=0.0
            )

        # Configuration is fine: delegate to the parent's safe execution.
        return await super().safe_execute(**kwargs)

View File

@@ -0,0 +1,307 @@
"""时间工具 - 日期时间处理"""
import time
from datetime import datetime, timezone, timedelta
from typing import List
import pytz
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
class DateTimeTool(BuiltinTool):
    """Date/time tool: formatting, timezone conversion, timestamp conversion and date arithmetic."""

    @property
    def name(self) -> str:
        """Tool name."""
        return "datetime_tool"

    @property
    def description(self) -> str:
        """Tool description."""
        return "时间工具 - 日期时间处理:提供时间格式转化、时区转换、时间戳转换、时间计算"

    @property
    def parameters(self) -> List[ToolParameter]:
        """Parameter definitions accepted by ``execute``."""
        return [
            ToolParameter(
                name="operation",
                type=ParameterType.STRING,
                description="操作类型",
                required=True,
                enum=["format", "convert_timezone", "timestamp_to_datetime", "datetime_to_timestamp", "calculate", "now"]
            ),
            ToolParameter(
                name="input_value",
                type=ParameterType.STRING,
                description="输入值(时间字符串或时间戳)",
                required=False
            ),
            ToolParameter(
                name="input_format",
                type=ParameterType.STRING,
                description="输入时间格式(如:%Y-%m-%d %H:%M:%S",
                required=False,
                default="%Y-%m-%d %H:%M:%S"
            ),
            ToolParameter(
                name="output_format",
                type=ParameterType.STRING,
                description="输出时间格式(如:%Y-%m-%d %H:%M:%S",
                required=False,
                default="%Y-%m-%d %H:%M:%S"
            ),
            ToolParameter(
                name="from_timezone",
                type=ParameterType.STRING,
                description="源时区UTC, Asia/Shanghai",
                required=False,
                default="UTC"
            ),
            ToolParameter(
                name="to_timezone",
                type=ParameterType.STRING,
                description="目标时区UTC, Asia/Shanghai",
                required=False,
                default="UTC"
            ),
            ToolParameter(
                name="calculation",
                type=ParameterType.STRING,
                description="时间计算表达式(如:+1d, -2h, +30m",
                required=False
            )
        ]

    async def execute(self, **kwargs) -> ToolResult:
        """Run the requested date/time operation.

        Args:
            **kwargs: Tool arguments; ``operation`` selects the handler and
                the remaining keys are read by that handler.

        Returns:
            A success result carrying the handler's dict, or an error
            result with code ``DATETIME_ERROR`` if anything raised.
        """
        start_time = time.time()

        try:
            operation = kwargs.get("operation")

            if operation == "now":
                result = self._get_current_time(kwargs)
            elif operation == "format":
                result = self._format_datetime(kwargs)
            elif operation == "convert_timezone":
                result = self._convert_timezone(kwargs)
            elif operation == "timestamp_to_datetime":
                result = self._timestamp_to_datetime(kwargs)
            elif operation == "datetime_to_timestamp":
                result = self._datetime_to_timestamp(kwargs)
            elif operation == "calculate":
                result = self._calculate_datetime(kwargs)
            else:
                raise ValueError(f"不支持的操作类型: {operation}")

            execution_time = time.time() - start_time
            return ToolResult.success_result(
                data=result,
                execution_time=execution_time
            )

        except Exception as e:
            # Every handler failure is reported as a tool error result
            # rather than propagated to the caller.
            execution_time = time.time() - start_time
            return ToolResult.error_result(
                error=str(e),
                error_code="DATETIME_ERROR",
                execution_time=execution_time
            )

    @staticmethod
    def _resolve_timezone(timezone_name: str):
        """Return the pytz timezone object for ``timezone_name``.

        Always returns a pytz object (never ``datetime.timezone.utc``) so
        callers can rely on ``.localize()`` being available.

        Raises:
            ValueError: If pytz does not know the timezone name.
        """
        if timezone_name == "UTC":
            return pytz.UTC
        try:
            return pytz.timezone(timezone_name)
        except pytz.UnknownTimeZoneError as exc:
            # UnknownTimeZoneError is a KeyError; its str() is just the
            # quoted name, which is useless as a user-facing message.
            raise ValueError(f"未知的时区: {timezone_name}") from exc

    def _get_current_time(self, kwargs) -> dict:
        """Return the current time in ``to_timezone`` (default UTC)."""
        timezone_str = kwargs.get("to_timezone", "UTC")
        output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")

        now = datetime.now(self._resolve_timezone(timezone_str))

        return {
            "datetime": now.strftime(output_format),
            "timestamp": int(now.timestamp()),
            "timezone": timezone_str,
            "iso_format": now.isoformat()
        }

    def _format_datetime(self, kwargs) -> dict:
        """Re-format ``input_value`` from ``input_format`` to ``output_format``.

        Raises:
            ValueError: If ``input_value`` is missing or does not match
                ``input_format``.
        """
        input_value = kwargs.get("input_value")
        input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
        output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")

        if not input_value:
            raise ValueError("input_value 参数是必需的")

        dt = datetime.strptime(input_value, input_format)

        return {
            "original": input_value,
            "formatted": dt.strftime(output_format),
            # NOTE: no timezone is applied here; for a naive datetime
            # Python's ``timestamp()`` assumes the process's local time.
            "timestamp": int(dt.timestamp()),
            "iso_format": dt.isoformat()
        }

    def _convert_timezone(self, kwargs) -> dict:
        """Convert ``input_value`` from ``from_timezone`` to ``to_timezone``.

        Raises:
            ValueError: If ``input_value`` is missing, unparsable, or a
                timezone name is unknown.
        """
        input_value = kwargs.get("input_value")
        input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
        output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
        from_timezone = kwargs.get("from_timezone", "UTC")
        to_timezone = kwargs.get("to_timezone", "UTC")

        if not input_value:
            raise ValueError("input_value 参数是必需的")

        dt = datetime.strptime(input_value, input_format)
        from_tz = self._resolve_timezone(from_timezone)
        to_tz = self._resolve_timezone(to_timezone)

        # Only attach the source zone when the parsed value carries no
        # offset of its own (e.g. the format had no %z).
        if dt.tzinfo is None:
            dt = from_tz.localize(dt)

        converted_dt = dt.astimezone(to_tz)

        return {
            "original": input_value,
            "original_timezone": from_timezone,
            "converted": converted_dt.strftime(output_format),
            "converted_timezone": to_timezone,
            "timestamp": int(converted_dt.timestamp())
        }

    def _timestamp_to_datetime(self, kwargs) -> dict:
        """Convert a Unix timestamp in ``input_value`` to a datetime in ``to_timezone``.

        Raises:
            ValueError: If ``input_value`` is missing or not a number.
        """
        input_value = kwargs.get("input_value")
        output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
        timezone_str = kwargs.get("to_timezone", "UTC")

        if not input_value:
            raise ValueError("input_value 参数是必需的")

        timestamp = float(input_value)
        dt = datetime.fromtimestamp(timestamp, self._resolve_timezone(timezone_str))

        return {
            "timestamp": timestamp,
            "datetime": dt.strftime(output_format),
            "timezone": timezone_str,
            "iso_format": dt.isoformat()
        }

    def _datetime_to_timestamp(self, kwargs) -> dict:
        """Convert ``input_value`` (interpreted in ``from_timezone``) to a Unix timestamp.

        Raises:
            ValueError: If ``input_value`` is missing, unparsable, or the
                timezone name is unknown.
        """
        input_value = kwargs.get("input_value")
        input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
        timezone_str = kwargs.get("from_timezone", "UTC")

        if not input_value:
            raise ValueError("input_value 参数是必需的")

        dt = datetime.strptime(input_value, input_format)

        # The resolver returns a pytz zone even for "UTC";
        # ``datetime.timezone.utc`` has no ``localize`` and made the
        # default-UTC path fail.
        tz = self._resolve_timezone(timezone_str)
        if dt.tzinfo is None:
            dt = tz.localize(dt)

        return {
            "datetime": input_value,
            "timezone": timezone_str,
            "timestamp": int(dt.timestamp()),
            "iso_format": dt.isoformat()
        }

    def _calculate_datetime(self, kwargs) -> dict:
        """Add the ``calculation`` offset (e.g. ``+1d``, ``-2h``) to ``input_value``.

        Raises:
            ValueError: If ``input_value`` or ``calculation`` is missing or
                invalid, or the timezone name is unknown.
        """
        input_value = kwargs.get("input_value")
        input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
        output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
        calculation = kwargs.get("calculation")
        timezone_str = kwargs.get("from_timezone", "UTC")

        if not input_value:
            raise ValueError("input_value 参数是必需的")

        if not calculation:
            raise ValueError("calculation 参数是必需的")

        dt = datetime.strptime(input_value, input_format)

        # Same pytz-only resolution as in ``_datetime_to_timestamp``.
        tz = self._resolve_timezone(timezone_str)
        if dt.tzinfo is None:
            dt = tz.localize(dt)

        delta = self._parse_time_delta(calculation)
        # NOTE(review): the sum keeps the UTC offset of ``dt``; it is not
        # re-normalized if the result crosses a DST change — confirm
        # whether that is the desired wall-clock behaviour.
        calculated_dt = dt + delta

        return {
            "original": input_value,
            "calculation": calculation,
            "result": calculated_dt.strftime(output_format),
            "timezone": timezone_str,
            "timestamp": int(calculated_dt.timestamp())
        }

    def _parse_time_delta(self, calculation: str) -> timedelta:
        """Parse an offset expression into a ``timedelta``.

        Every ``<sign?><digits><unit>`` group found in the lower-cased
        expression is summed; units are ``d`` (days), ``h`` (hours),
        ``m`` (minutes) and ``s`` (seconds). Text between groups is
        ignored, so ``"+1month"`` is read as ``+1m`` (one minute).

        Raises:
            ValueError: If the expression contains no such group.
        """
        import re

        pattern = r'([+-]?\d+)([dhms])'
        matches = re.findall(pattern, calculation.lower())

        if not matches:
            raise ValueError(f"无效的时间计算表达式: {calculation}")

        unit_to_argument = {'d': 'days', 'h': 'hours', 'm': 'minutes', 's': 'seconds'}
        total_delta = timedelta()
        for value_str, unit in matches:
            total_delta += timedelta(**{unit_to_argument[unit]: int(value_str)})

        return total_delta

View File

@@ -0,0 +1,430 @@
"""JSON转换工具 - 数据格式转换"""
import json
import time
from typing import List, Any, Dict
import yaml
import xml.etree.ElementTree as ET
from xml.dom import minidom
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
class JsonTool(BuiltinTool):
    """JSON tool: format, minify, validate, convert to/from YAML and XML, merge, and extract by path."""

    @property
    def name(self) -> str:
        """Tool name."""
        return "json_tool"

    @property
    def description(self) -> str:
        """Tool description."""
        return "JSON转换工具 - 数据格式转换JSON格式化、JSON压缩、JSON验证、格式转换"

    @property
    def parameters(self) -> List[ToolParameter]:
        """Parameter definitions accepted by ``execute``."""
        return [
            ToolParameter(
                name="operation",
                type=ParameterType.STRING,
                description="操作类型",
                required=True,
                enum=["format", "minify", "validate", "convert", "to_yaml", "from_yaml", "to_xml", "from_xml", "merge", "extract"]
            ),
            ToolParameter(
                name="input_data",
                type=ParameterType.STRING,
                description="输入数据JSON字符串、YAML字符串或XML字符串",
                required=True
            ),
            ToolParameter(
                name="indent",
                type=ParameterType.INTEGER,
                description="JSON格式化缩进空格数",
                required=False,
                default=2,
                minimum=0,
                maximum=8
            ),
            ToolParameter(
                name="ensure_ascii",
                type=ParameterType.BOOLEAN,
                description="是否确保ASCII编码",
                required=False,
                default=False
            ),
            ToolParameter(
                name="sort_keys",
                type=ParameterType.BOOLEAN,
                description="是否对键进行排序",
                required=False,
                default=False
            ),
            ToolParameter(
                name="merge_data",
                type=ParameterType.STRING,
                description="要合并的JSON数据用于merge操作",
                required=False
            ),
            ToolParameter(
                name="json_path",
                type=ParameterType.STRING,
                description="JSON路径表达式用于extract操作$.user.name",
                required=False
            )
        ]

    async def execute(self, **kwargs) -> ToolResult:
        """Run the requested JSON operation.

        Args:
            **kwargs: Tool arguments; ``operation`` selects the handler and
                ``input_data`` is the text it works on.

        Returns:
            A success result carrying the handler's dict, or an error
            result with code ``JSON_ERROR`` if anything raised.
        """
        start_time = time.time()

        try:
            operation = kwargs.get("operation")
            input_data = kwargs.get("input_data")

            if not input_data:
                raise ValueError("input_data 参数是必需的")

            if operation == "format":
                result = self._format_json(input_data, kwargs)
            elif operation == "minify":
                result = self._minify_json(input_data, kwargs.get("ensure_ascii", False))
            elif operation == "validate":
                result = self._validate_json(input_data)
            elif operation == "convert":
                result = self._convert_json(input_data)
            elif operation == "to_yaml":
                result = self._json_to_yaml(input_data)
            elif operation == "from_yaml":
                result = self._yaml_to_json(input_data, kwargs)
            elif operation == "to_xml":
                result = self._json_to_xml(input_data)
            elif operation == "from_xml":
                result = self._xml_to_json(input_data, kwargs)
            elif operation == "merge":
                result = self._merge_json(input_data, kwargs)
            elif operation == "extract":
                result = self._extract_json_path(input_data, kwargs)
            else:
                raise ValueError(f"不支持的操作类型: {operation}")

            execution_time = time.time() - start_time
            return ToolResult.success_result(
                data=result,
                execution_time=execution_time
            )

        except Exception as e:
            # Parse errors and handler failures all become a tool error
            # result rather than propagating.
            execution_time = time.time() - start_time
            return ToolResult.error_result(
                error=str(e),
                error_code="JSON_ERROR",
                execution_time=execution_time
            )

    def _format_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Pretty-print JSON using the ``indent``, ``ensure_ascii`` and ``sort_keys`` options.

        Raises:
            json.JSONDecodeError: If ``input_data`` is not valid JSON.
        """
        indent = kwargs.get("indent", 2)
        ensure_ascii = kwargs.get("ensure_ascii", False)
        sort_keys = kwargs.get("sort_keys", False)

        data = json.loads(input_data)

        formatted = json.dumps(
            data,
            indent=indent,
            ensure_ascii=ensure_ascii,
            sort_keys=sort_keys,
            separators=(',', ': ')
        )

        return {
            "original_size": len(input_data),
            "formatted_size": len(formatted),
            "formatted_json": formatted,
            "is_valid": True,
            "settings": {
                "indent": indent,
                "ensure_ascii": ensure_ascii,
                "sort_keys": sort_keys
            }
        }

    def _minify_json(self, input_data: str, ensure_ascii: bool = False) -> Dict[str, Any]:
        """Re-serialize JSON without any insignificant whitespace.

        Args:
            input_data: Non-empty JSON text.
            ensure_ascii: Escape non-ASCII characters when True. Defaults
                to False (the tool parameter's declared default) so
                non-ASCII text is not expanded into ``\\uXXXX`` escapes,
                which would inflate the "minified" size.

        Raises:
            json.JSONDecodeError: If ``input_data`` is not valid JSON.
        """
        data = json.loads(input_data)
        minified = json.dumps(data, separators=(',', ':'), ensure_ascii=ensure_ascii)

        return {
            "original_size": len(input_data),
            "minified_size": len(minified),
            "compression_ratio": round((1 - len(minified) / len(input_data)) * 100, 2),
            "minified_json": minified,
            "is_valid": True
        }

    def _validate_json(self, input_data: str) -> Dict[str, Any]:
        """Report whether ``input_data`` is valid JSON.

        Returns:
            For valid input, its size and a structure summary; for invalid
            input, the decoder's message with line and column.
        """
        try:
            data = json.loads(input_data)
        except json.JSONDecodeError as e:
            return {
                "is_valid": False,
                "error": str(e),
                "error_line": getattr(e, 'lineno', None),
                "error_column": getattr(e, 'colno', None),
                "size": len(input_data)
            }

        return {
            "is_valid": True,
            "error": None,
            "size": len(input_data),
            "structure": self._analyze_json_structure(data)
        }

    def _convert_json(self, input_data: str) -> Dict[str, Any]:
        """Re-serialize JSON with non-ASCII characters left unescaped.

        Raises:
            json.JSONDecodeError: If ``input_data`` is not valid JSON.
        """
        data = json.loads(input_data)
        converted = json.dumps(data, ensure_ascii=False)

        return {
            "converted_json": converted,
            "is_valid": True
        }

    def _json_to_yaml(self, input_data: str) -> Dict[str, Any]:
        """Convert JSON text to block-style YAML.

        Raises:
            json.JSONDecodeError: If ``input_data`` is not valid JSON.
        """
        data = json.loads(input_data)
        yaml_output = yaml.dump(data, default_flow_style=False, allow_unicode=True, indent=2)

        return {
            "original_format": "json",
            "target_format": "yaml",
            "original_size": len(input_data),
            "converted_size": len(yaml_output),
            "converted_data": yaml_output
        }

    def _yaml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Convert YAML text to JSON using the ``indent`` and ``ensure_ascii`` options."""
        indent = kwargs.get("indent", 2)
        ensure_ascii = kwargs.get("ensure_ascii", False)

        # safe_load: the YAML comes from tool input and must not be able
        # to construct arbitrary Python objects.
        data = yaml.safe_load(input_data)
        json_output = json.dumps(data, indent=indent, ensure_ascii=ensure_ascii)

        return {
            "original_format": "yaml",
            "target_format": "json",
            "original_size": len(input_data),
            "converted_size": len(json_output),
            "converted_data": json_output
        }

    def _json_to_xml(self, input_data: str) -> Dict[str, Any]:
        """Convert JSON text to pretty-printed XML under a ``<root>`` element.

        Object keys become child element names, array items become
        ``item_<index>`` elements and scalars become element text.

        Raises:
            json.JSONDecodeError: If ``input_data`` is not valid JSON.
        """
        data = json.loads(input_data)

        def build_element(value, tag):
            """Recursively build the element named ``tag`` for ``value``."""
            element = ET.Element(tag)
            if isinstance(value, dict):
                # Every key gets its own element. Nested single-key
                # objects are NOT collapsed into their parent: doing so
                # replaced the parent's tag with the child's key and
                # silently lost the parent key.
                for key, child_value in value.items():
                    element.append(build_element(child_value, key))
            elif isinstance(value, list):
                for index, item in enumerate(value):
                    element.append(build_element(item, f"item_{index}"))
            else:
                element.text = str(value)
            return element

        xml_element = build_element(data, "root")
        xml_string = ET.tostring(xml_element, encoding='unicode')

        # Pretty-print, then drop the blank lines toprettyxml may emit.
        dom = minidom.parseString(xml_string)
        formatted_xml = dom.toprettyxml(indent="  ")
        formatted_xml = '\n'.join([line for line in formatted_xml.split('\n') if line.strip()])

        return {
            "original_format": "json",
            "target_format": "xml",
            "original_size": len(input_data),
            "converted_size": len(formatted_xml),
            "converted_data": formatted_xml
        }

    def _xml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Convert XML text to JSON using the ``indent`` option.

        Attributes become keys, element text is stored under ``"text"``
        (or is the value itself for an attribute-less leaf), and repeated
        child tags are collected into a list.

        Raises:
            xml.etree.ElementTree.ParseError: If ``input_data`` is not
                well-formed XML.
        """
        indent = kwargs.get("indent", 2)

        def xml_to_dict(element):
            """Recursively convert one XML element."""
            result = {}

            if element.attrib:
                result.update(element.attrib)

            text = element.text.strip() if element.text else ""
            if text:
                # Only an attribute-less leaf collapses to a bare string;
                # a leaf that has attributes keeps them next to "text"
                # instead of dropping them.
                if len(element) == 0 and not element.attrib:
                    return text
                result['text'] = text

            for child in element:
                child_data = xml_to_dict(child)
                if child.tag in result:
                    # Repeated tag: promote the existing value to a list.
                    if not isinstance(result[child.tag], list):
                        result[child.tag] = [result[child.tag]]
                    result[child.tag].append(child_data)
                else:
                    result[child.tag] = child_data

            return result

        # NOTE(security): the standard-library XML parsers are documented
        # as not hardened against maliciously constructed documents
        # (e.g. entity-expansion attacks); input_data comes from tool
        # input, so consider a hardened parser if it can be untrusted.
        root = ET.fromstring(input_data)
        data = {root.tag: xml_to_dict(root)}

        json_output = json.dumps(data, indent=indent, ensure_ascii=False)

        return {
            "original_format": "xml",
            "target_format": "json",
            "original_size": len(input_data),
            "converted_size": len(json_output),
            "converted_data": json_output
        }

    def _merge_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Merge the JSON in ``merge_data`` into ``input_data``.

        Two objects are merged recursively (values from ``merge_data``
        win); two arrays are concatenated.

        Raises:
            ValueError: If ``merge_data`` is missing, or the two documents
                are not both objects or both arrays.
            json.JSONDecodeError: If either document is not valid JSON.
        """
        merge_data = kwargs.get("merge_data")
        if not merge_data:
            raise ValueError("merge_data 参数是必需的")

        data1 = json.loads(input_data)
        data2 = json.loads(merge_data)

        def deep_merge(dict1, dict2):
            """Return a copy of ``dict1`` with ``dict2`` merged in recursively."""
            result = dict1.copy()
            for key, value in dict2.items():
                if key in result and isinstance(result[key], dict) and isinstance(value, dict):
                    result[key] = deep_merge(result[key], value)
                else:
                    result[key] = value
            return result

        if isinstance(data1, dict) and isinstance(data2, dict):
            merged = deep_merge(data1, data2)
        elif isinstance(data1, list) and isinstance(data2, list):
            merged = data1 + data2
        else:
            raise ValueError("无法合并不同类型的数据")

        merged_json = json.dumps(merged, indent=2, ensure_ascii=False)

        return {
            "operation": "merge",
            "original_size": len(input_data),
            "merge_size": len(merge_data),
            "result_size": len(merged_json),
            "merged_data": merged_json
        }

    def _extract_json_path(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Extract the value at a dotted ``json_path`` (optional ``$.`` prefix).

        Only plain dot-separated keys and numeric indexes are supported,
        e.g. ``$.users.0.name``.

        Returns:
            A dict with ``found`` True and the value serialized as JSON,
            or ``found`` False and the lookup error message.

        Raises:
            ValueError: If ``json_path`` is missing.
            json.JSONDecodeError: If ``input_data`` is not valid JSON.
        """
        json_path = kwargs.get("json_path")
        if not json_path:
            raise ValueError("json_path 参数是必需的")

        data = json.loads(input_data)

        if json_path.startswith('$.'):
            path_parts = json_path[2:].split('.')
        else:
            path_parts = json_path.split('.')

        try:
            result = data
            for part in path_parts:
                # JSON object keys are always strings, so a numeric part
                # is used as an integer index only when the current value
                # is not an object; otherwise keys like "0" could never
                # be reached.
                if part.isdigit() and not isinstance(result, dict):
                    result = result[int(part)]
                else:
                    result = result[part]
        except (KeyError, IndexError, TypeError) as e:
            return {
                "operation": "extract",
                "json_path": json_path,
                "found": False,
                "error": str(e),
                "extracted_data": None
            }

        extracted_json = json.dumps(result, indent=2, ensure_ascii=False)

        return {
            "operation": "extract",
            "json_path": json_path,
            "found": True,
            "extracted_data": extracted_json,
            "data_type": type(result).__name__
        }

    def _analyze_json_structure(self, data: Any, depth: int = 0) -> Dict[str, Any]:
        """Summarize the shape of a parsed JSON value.

        Objects are described recursively per key; arrays report their
        length and the type names of their items (not recursed into);
        scalars report their type and a string preview capped at 100
        characters.
        """
        if isinstance(data, dict):
            return {
                "type": "object",
                "keys": len(data),
                "depth": depth,
                "children": {k: self._analyze_json_structure(v, depth + 1) for k, v in data.items()}
            }

        if isinstance(data, list):
            return {
                "type": "array",
                "length": len(data),
                "depth": depth,
                # Sorted so the output is deterministic (set order is not).
                "item_types": sorted({type(item).__name__ for item in data})
            }

        preview = str(data)
        return {
            "type": type(data).__name__,
            "depth": depth,
            "value": preview[:100] + "..." if len(preview) > 100 else preview
        }

View File

@@ -0,0 +1,327 @@
"""MinerU PDF解析工具"""
import time
from typing import List, Dict, Any
import aiohttp
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
class MinerUTool(BuiltinTool):
    """MinerU PDF tool: full parsing plus text, table, image and layout extraction via the MinerU API."""

    @property
    def name(self) -> str:
        """Tool name."""
        return "mineru_tool"

    @property
    def description(self) -> str:
        """Tool description."""
        return "MinerU - PDF解析工具PDF解析、表格提取、图片识别、文本提取"

    def get_required_config_parameters(self) -> List[str]:
        """Config parameters that must be set before the tool can run."""
        return ["api_key", "api_url"]

    @property
    def parameters(self) -> List[ToolParameter]:
        """Parameter definitions accepted by ``execute``."""
        return [
            ToolParameter(
                name="operation",
                type=ParameterType.STRING,
                description="操作类型",
                required=True,
                enum=["parse_pdf", "extract_text", "extract_tables", "extract_images", "analyze_layout"]
            ),
            ToolParameter(
                name="file_content",
                type=ParameterType.STRING,
                description="PDF文件内容Base64编码",
                required=False
            ),
            ToolParameter(
                name="file_url",
                type=ParameterType.STRING,
                description="PDF文件URL",
                required=False
            ),
            ToolParameter(
                name="parse_mode",
                type=ParameterType.STRING,
                description="解析模式",
                required=False,
                default="auto",
                enum=["auto", "text_only", "table_priority", "image_priority", "layout_analysis"]
            ),
            ToolParameter(
                name="extract_images",
                type=ParameterType.BOOLEAN,
                description="是否提取图片",
                required=False,
                default=True
            ),
            ToolParameter(
                name="extract_tables",
                type=ParameterType.BOOLEAN,
                description="是否提取表格",
                required=False,
                default=True
            ),
            ToolParameter(
                name="page_range",
                type=ParameterType.STRING,
                description="页面范围1-5, 1,3,5",
                required=False
            ),
            ToolParameter(
                name="output_format",
                type=ParameterType.STRING,
                description="输出格式",
                required=False,
                default="json",
                enum=["json", "markdown", "html", "text"]
            )
        ]

    async def execute(self, **kwargs) -> ToolResult:
        """Run the requested MinerU operation.

        Args:
            **kwargs: Tool arguments; ``operation`` selects the handler and
                one of ``file_content`` / ``file_url`` must be given.

        Returns:
            A success result carrying the handler's dict, or an error
            result with code ``MINERU_ERROR`` if anything raised.
        """
        start_time = time.time()

        try:
            operation = kwargs.get("operation")
            file_content = kwargs.get("file_content")
            file_url = kwargs.get("file_url")

            if not file_content and not file_url:
                raise ValueError("必须提供 file_content 或 file_url 参数")

            if operation == "parse_pdf":
                result = await self._parse_pdf(kwargs)
            elif operation == "extract_text":
                result = await self._extract_text(kwargs)
            elif operation == "extract_tables":
                result = await self._extract_tables(kwargs)
            elif operation == "extract_images":
                result = await self._extract_images(kwargs)
            elif operation == "analyze_layout":
                result = await self._analyze_layout(kwargs)
            else:
                raise ValueError(f"不支持的操作类型: {operation}")

            execution_time = time.time() - start_time
            return ToolResult.success_result(
                data=result,
                execution_time=execution_time
            )

        except Exception as e:
            # Validation, HTTP and API failures all become a tool error
            # result rather than propagating.
            execution_time = time.time() - start_time
            return ToolResult.error_result(
                error=str(e),
                error_code="MINERU_ERROR",
                execution_time=execution_time
            )

    @staticmethod
    def _add_document_fields(request_data: Dict[str, Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Add the optional page range and the file source to ``request_data``.

        ``file_content`` takes precedence over ``file_url`` when both are
        supplied. The dict is modified in place and also returned.
        """
        page_range = kwargs.get("page_range")
        if page_range:
            request_data["page_range"] = page_range

        if kwargs.get("file_content"):
            request_data["file_content"] = kwargs["file_content"]
        elif kwargs.get("file_url"):
            request_data["file_url"] = kwargs["file_url"]

        return request_data

    async def _parse_pdf(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Run a full parse through the ``parse`` endpoint."""
        parse_mode = kwargs.get("parse_mode", "auto")

        request_data = self._add_document_fields({
            "parse_mode": parse_mode,
            "extract_images": kwargs.get("extract_images", True),
            "extract_tables": kwargs.get("extract_tables", True),
            "output_format": kwargs.get("output_format", "json")
        }, kwargs)

        result = await self._call_mineru_api("parse", request_data)

        return {
            "operation": "parse_pdf",
            "parse_mode": parse_mode,
            "total_pages": result.get("total_pages", 0),
            "processed_pages": result.get("processed_pages", 0),
            "text_content": result.get("text_content", ""),
            "tables": result.get("tables", []),
            "images": result.get("images", []),
            "layout_info": result.get("layout_info", {}),
            "metadata": result.get("metadata", {}),
            "processing_time": result.get("processing_time", 0)
        }

    async def _extract_text(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Extract text through the ``extract_text`` endpoint and add word/character counts."""
        request_data = self._add_document_fields({
            "operation": "extract_text",
            "output_format": kwargs.get("output_format", "text")
        }, kwargs)

        result = await self._call_mineru_api("extract_text", request_data)

        # ``or ""`` also covers an explicit null from the API, which
        # would otherwise break the counts below.
        text_content = result.get("text_content") or ""

        return {
            "operation": "extract_text",
            "total_pages": result.get("total_pages", 0),
            "text_content": text_content,
            "word_count": len(text_content.split()),
            "character_count": len(text_content),
            "pages_text": result.get("pages_text", [])
        }

    async def _extract_tables(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Extract tables through the ``extract_tables`` endpoint."""
        request_data = self._add_document_fields({
            "operation": "extract_tables",
            "output_format": kwargs.get("output_format", "json")
        }, kwargs)

        result = await self._call_mineru_api("extract_tables", request_data)

        return {
            "operation": "extract_tables",
            "total_tables": result.get("total_tables", 0),
            "tables": result.get("tables", []),
            "table_locations": result.get("table_locations", [])
        }

    async def _extract_images(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Extract images through the ``extract_images`` endpoint."""
        request_data = self._add_document_fields({
            "operation": "extract_images"
        }, kwargs)

        result = await self._call_mineru_api("extract_images", request_data)

        return {
            "operation": "extract_images",
            "total_images": result.get("total_images", 0),
            "images": result.get("images", []),
            "image_locations": result.get("image_locations", [])
        }

    async def _analyze_layout(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Analyze page layout through the ``analyze_layout`` endpoint."""
        request_data = self._add_document_fields({
            "operation": "analyze_layout"
        }, kwargs)

        result = await self._call_mineru_api("analyze_layout", request_data)

        return {
            "operation": "analyze_layout",
            "layout_info": result.get("layout_info", {}),
            "page_layouts": result.get("page_layouts", []),
            "text_blocks": result.get("text_blocks", []),
            "image_blocks": result.get("image_blocks", []),
            "table_blocks": result.get("table_blocks", [])
        }

    async def _call_mineru_api(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``data`` as JSON to ``<api_url>/<endpoint>`` and return the payload.

        Args:
            endpoint: Path segment appended to the configured ``api_url``.
            data: JSON-serializable request body.

        Returns:
            The response's ``data`` member when present, otherwise the
            whole decoded JSON body.

        Raises:
            ValueError: If ``api_key`` or ``api_url`` is not configured.
            Exception: On a non-200 status, or when the body reports
                ``success`` as false.
        """
        api_key = self.get_config_parameter("api_key")
        api_url = self.get_config_parameter("api_url")
        timeout_seconds = self.get_config_parameter("timeout", 60)

        if not api_key or not api_url:
            raise ValueError("MinerU API配置未完成")

        url = f"{api_url.rstrip('/')}/{endpoint}"
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }

        timeout = aiohttp.ClientTimeout(total=timeout_seconds)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(url, json=data, headers=headers) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(f"HTTP错误 {response.status}: {error_text}")

                result = await response.json()
                # A body without a "success" flag is treated as success.
                if not result.get("success", True):
                    raise Exception(f"MinerU API错误: {result.get('message', '未知错误')}")
                return result.get("data", result)

    def test_connection(self) -> Dict[str, Any]:
        """Check that the API key and URL are configured.

        Only inspects the stored configuration; no request is sent.

        Returns:
            ``{"success": True, ...}`` with the URL and a masked key, or
            ``{"success": False, "error": ...}``.
        """
        try:
            api_key = self.get_config_parameter("api_key")
            api_url = self.get_config_parameter("api_url")

            if not api_key or not api_url:
                return {
                    "success": False,
                    "error": "API配置未完成"
                }

            return {
                "success": True,
                "message": "连接配置有效",
                "api_url": api_url,
                # Expose at most the first 8 characters of the key.
                "api_key_masked": api_key[:8] + "***" if len(api_key) > 8 else "***"
            }

        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }

View File

@@ -0,0 +1,401 @@
"""TextIn OCR文字识别工具"""
import time
from typing import List, Dict, Any
import aiohttp
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
class TextInTool(BuiltinTool):
"""TextIn OCR工具 - 提供通用OCR、手写识别、多语言支持、高精度识别"""
@property
def name(self) -> str:
    """Return the tool's registered name."""
    tool_name = "textin_tool"
    return tool_name
@property
def description(self) -> str:
    """Return the tool's description text."""
    summary = "TextIn - OCR文字识别通用OCR、手写识别、多语言支持、高精度识别"
    return summary
def get_required_config_parameters(self) -> List[str]:
    """Return the config parameter names that must be set for this tool."""
    required = ["app_id", "secret_key", "api_url"]
    return required
@property
def parameters(self) -> List[ToolParameter]:
    """Return the parameter definitions accepted by ``execute``.

    The image is given either inline (``image_content``) or by URL
    (``image_url``); the remaining parameters tune the recognition.
    """
    image_content = ToolParameter(
        name="image_content",
        type=ParameterType.STRING,
        description="图片内容Base64编码",
        required=False
    )
    image_url = ToolParameter(
        name="image_url",
        type=ParameterType.STRING,
        description="图片URL",
        required=False
    )
    language = ToolParameter(
        name="language",
        type=ParameterType.STRING,
        description="识别语言",
        required=False,
        default="auto",
        enum=["auto", "zh-cn", "zh-tw", "en", "ja", "ko", "fr", "de", "es", "ru"]
    )
    recognition_mode = ToolParameter(
        name="recognition_mode",
        type=ParameterType.STRING,
        description="识别模式",
        required=False,
        default="general",
        enum=["general", "accurate", "handwriting", "formula", "table", "document"]
    )
    return_location = ToolParameter(
        name="return_location",
        type=ParameterType.BOOLEAN,
        description="是否返回文字位置信息",
        required=False,
        default=False
    )
    return_confidence = ToolParameter(
        name="return_confidence",
        type=ParameterType.BOOLEAN,
        description="是否返回置信度",
        required=False,
        default=True
    )
    merge_lines = ToolParameter(
        name="merge_lines",
        type=ParameterType.BOOLEAN,
        description="是否合并行",
        required=False,
        default=True
    )
    output_format = ToolParameter(
        name="output_format",
        type=ParameterType.STRING,
        description="输出格式",
        required=False,
        default="text",
        enum=["text", "json", "structured"]
    )

    return [
        image_content,
        image_url,
        language,
        recognition_mode,
        return_location,
        return_confidence,
        merge_lines,
        output_format,
    ]
async def execute(self, **kwargs) -> ToolResult:
    """Run TextIn OCR in the requested recognition mode.

    Args:
        **kwargs: Tool arguments. One of ``image_content`` /
            ``image_url`` is required; ``recognition_mode`` (default
            ``"general"``) selects the handler, which reads the remaining
            options from ``kwargs`` itself.

    Returns:
        A success result carrying the handler's dict, or an error result
        with code ``TEXTIN_ERROR`` if anything raised.
    """
    start_time = time.time()

    try:
        image_content = kwargs.get("image_content")
        image_url = kwargs.get("image_url")

        if not image_content and not image_url:
            raise ValueError("必须提供 image_content 或 image_url 参数")

        # Only the mode is needed here; the previously unpacked
        # language/location/confidence/merge/format locals were never
        # used, since each handler reads its options from kwargs.
        recognition_mode = kwargs.get("recognition_mode", "general")

        if recognition_mode == "general":
            result = await self._general_ocr(kwargs)
        elif recognition_mode == "accurate":
            result = await self._accurate_ocr(kwargs)
        elif recognition_mode == "handwriting":
            result = await self._handwriting_ocr(kwargs)
        elif recognition_mode == "formula":
            result = await self._formula_ocr(kwargs)
        elif recognition_mode == "table":
            result = await self._table_ocr(kwargs)
        elif recognition_mode == "document":
            result = await self._document_ocr(kwargs)
        else:
            raise ValueError(f"不支持的识别模式: {recognition_mode}")

        execution_time = time.time() - start_time
        return ToolResult.success_result(
            data=result,
            execution_time=execution_time
        )

    except Exception as e:
        # Validation and API failures all become a tool error result
        # rather than propagating.
        execution_time = time.time() - start_time
        return ToolResult.error_result(
            error=str(e),
            error_code="TEXTIN_ERROR",
            execution_time=execution_time
        )
async def _general_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""通用OCR识别"""
request_data = {
"language": kwargs.get("language", "auto"),
"return_location": kwargs.get("return_location", False),
"return_confidence": kwargs.get("return_confidence", True),
"merge_lines": kwargs.get("merge_lines", True)
}
if kwargs.get("image_content"):
request_data["image"] = kwargs["image_content"]
elif kwargs.get("image_url"):
request_data["image_url"] = kwargs["image_url"]
result = await self._call_textin_api("general_ocr", request_data)
return self._format_ocr_result(result, kwargs.get("output_format", "text"))
async def _accurate_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""高精度OCR识别"""
request_data = {
"language": kwargs.get("language", "auto"),
"return_location": kwargs.get("return_location", False),
"return_confidence": kwargs.get("return_confidence", True),
"merge_lines": kwargs.get("merge_lines", True)
}
if kwargs.get("image_content"):
request_data["image"] = kwargs["image_content"]
elif kwargs.get("image_url"):
request_data["image_url"] = kwargs["image_url"]
result = await self._call_textin_api("accurate_ocr", request_data)
return self._format_ocr_result(result, kwargs.get("output_format", "text"))
async def _handwriting_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Run handwriting recognition through the ``handwriting_ocr`` endpoint.

    Args:
        kwargs: Call options; reads ``language``, ``return_location``,
            ``return_confidence``, ``image_content``, ``image_url`` and
            ``output_format``.

    Returns:
        The API result as shaped by ``_format_ocr_result``.
    """
    payload = {
        name: kwargs.get(name, fallback)
        for name, fallback in (
            ("language", "auto"),
            ("return_location", False),
            ("return_confidence", True),
        )
    }

    # Inline content is preferred; the URL is only used when it is absent.
    if kwargs.get("image_content"):
        payload["image"] = kwargs["image_content"]
    elif kwargs.get("image_url"):
        payload["image_url"] = kwargs["image_url"]

    api_result = await self._call_textin_api("handwriting_ocr", payload)
    return self._format_ocr_result(api_result, kwargs.get("output_format", "text"))
async def _formula_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Run formula recognition through the ``formula_ocr`` endpoint.

    The request always asks for LaTeX output (``output_latex`` is True).

    Args:
        kwargs: Call options; reads ``return_location``,
            ``return_confidence``, ``image_content``, ``image_url`` and
            ``output_format``.

    Returns:
        The API result as shaped by ``_format_formula_result``.
    """
    payload = {
        "return_location": kwargs.get("return_location", False),
        "return_confidence": kwargs.get("return_confidence", True),
    }
    payload["output_latex"] = True

    inline_image = kwargs.get("image_content")
    if inline_image:
        payload["image"] = inline_image
    elif kwargs.get("image_url"):
        payload["image_url"] = kwargs["image_url"]

    api_result = await self._call_textin_api("formula_ocr", payload)
    chosen_format = kwargs.get("output_format", "text")
    return self._format_formula_result(api_result, chosen_format)
async def _table_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Run table recognition through the ``table_ocr`` endpoint.

    The request always asks for Excel output (``output_excel`` is True).

    Args:
        kwargs: Call options; reads ``language``, ``return_location``,
            ``return_confidence``, ``image_content``, ``image_url`` and
            ``output_format``.

    Returns:
        The API result as shaped by ``_format_table_result``.
    """
    payload = dict(
        language=kwargs.get("language", "auto"),
        return_location=kwargs.get("return_location", False),
        return_confidence=kwargs.get("return_confidence", True),
        output_excel=True,
    )

    # First truthy image source wins: inline content, then URL.
    for source_key, request_key in (("image_content", "image"), ("image_url", "image_url")):
        if kwargs.get(source_key):
            payload[request_key] = kwargs[source_key]
            break

    api_result = await self._call_textin_api("table_ocr", payload)
    return self._format_table_result(api_result, kwargs.get("output_format", "text"))
async def _document_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Run document recognition through the ``document_ocr`` endpoint.

    The request always enables layout analysis (``layout_analysis`` is True).

    Args:
        kwargs: Call options; reads ``language``, ``return_location``,
            ``return_confidence``, ``image_content``, ``image_url`` and
            ``output_format``.

    Returns:
        The API result as shaped by ``_format_document_result``.
    """
    payload = {
        name: kwargs.get(name, fallback)
        for name, fallback in (
            ("language", "auto"),
            ("return_location", False),
            ("return_confidence", True),
        )
    }
    payload["layout_analysis"] = True

    if kwargs.get("image_content"):
        payload["image"] = kwargs["image_content"]
    elif kwargs.get("image_url"):
        payload["image_url"] = kwargs["image_url"]

    api_result = await self._call_textin_api("document_ocr", payload)
    chosen_format = kwargs.get("output_format", "text")
    return self._format_document_result(api_result, chosen_format)
def _format_ocr_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any] | None:
"""格式化OCR结果"""
lines = result.get("lines", [])
if output_format == "text":
text_content = "\n".join([line.get("text", "") for line in lines])
return {
"recognition_mode": "ocr",
"text_content": text_content,
"line_count": len(lines),
"total_confidence": result.get("confidence", 0),
"processing_time": result.get("processing_time", 0)
}
elif output_format == "json":
return {
"recognition_mode": "ocr",
"lines": lines,
"total_confidence": result.get("confidence", 0),
"processing_time": result.get("processing_time", 0)
}
elif output_format == "structured":
return {
"recognition_mode": "ocr",
"text_content": "\n".join([line.get("text", "") for line in lines]),
"structured_data": {
"lines": lines,
"paragraphs": self._group_lines_to_paragraphs(lines),
"statistics": {
"line_count": len(lines),
"word_count": sum(len(line.get("text", "").split()) for line in lines),
"character_count": sum(len(line.get("text", "")) for line in lines)
}
},
"total_confidence": result.get("confidence", 0),
"processing_time": result.get("processing_time", 0)
}
def _format_formula_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
"""格式化公式识别结果"""
formulas = result.get("formulas", [])
return {
"recognition_mode": "formula",
"formula_count": len(formulas),
"formulas": formulas,
"latex_content": "\n".join([f.get("latex", "") for f in formulas]),
"total_confidence": result.get("confidence", 0),
"processing_time": result.get("processing_time", 0)
}
def _format_table_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
"""格式化表格识别结果"""
tables = result.get("tables", [])
return {
"recognition_mode": "table",
"table_count": len(tables),
"tables": tables,
"excel_data": result.get("excel_data"),
"total_confidence": result.get("confidence", 0),
"processing_time": result.get("processing_time", 0)
}
def _format_document_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
"""格式化文档识别结果"""
return {
"recognition_mode": "document",
"layout_info": result.get("layout_info", {}),
"text_blocks": result.get("text_blocks", []),
"image_blocks": result.get("image_blocks", []),
"table_blocks": result.get("table_blocks", []),
"full_text": result.get("full_text", ""),
"total_confidence": result.get("confidence", 0),
"processing_time": result.get("processing_time", 0)
}
def _group_lines_to_paragraphs(self, lines: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""将行分组为段落"""
paragraphs = []
current_paragraph = []
for line in lines:
text = line.get("text", "").strip()
if text:
current_paragraph.append(line)
else:
if current_paragraph:
paragraphs.append({
"text": " ".join([l.get("text", "") for l in current_paragraph]),
"lines": current_paragraph
})
current_paragraph = []
if current_paragraph:
paragraphs.append({
"text": " ".join([l.get("text", "") for l in current_paragraph]),
"lines": current_paragraph
})
return paragraphs
async def _call_textin_api(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
    """POST ``data`` as JSON to a TextIn endpoint and return its payload.

    Reads ``app_id``, ``secret_key`` and ``api_url`` from the tool
    configuration and sends the credentials as request headers. The
    request uses a 30 second total timeout.

    Args:
        endpoint: Path segment appended to the configured ``api_url``.
        data: JSON-serializable request body.

    Returns:
        The response's ``data`` field, or the whole response body when
        that field is absent.

    Raises:
        ValueError: If any of the three configuration values is missing.
        Exception: On a non-200 HTTP status, or when the response body's
            ``code`` field is not 200.
    """
    app_id = self.get_config_parameter("app_id")
    secret_key = self.get_config_parameter("secret_key")
    api_url = self.get_config_parameter("api_url")
    if not (app_id and secret_key and api_url):
        raise ValueError("TextIn API配置未完成")

    # Avoid a double slash when the configured base URL ends with "/".
    url = f"{api_url.rstrip('/')}/{endpoint}"
    headers = {
        "X-App-Id": app_id,
        "X-Secret-Key": secret_key,
        "Content-Type": "application/json"
    }

    timeout = aiohttp.ClientTimeout(total=30)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.post(url, json=data, headers=headers) as response:
            if response.status != 200:
                error_text = await response.text()
                raise Exception(f"HTTP错误 {response.status}: {error_text}")

            result = await response.json()
            if result.get("code") != 200:
                raise Exception(f"TextIn API错误: {result.get('message', '未知错误')}")
            return result.get("data", result)
def test_connection(self) -> Dict[str, Any]:
    """Check that the connection settings are present.

    No network request is made; only the ``app_id``, ``secret_key`` and
    ``api_url`` configuration values are inspected.

    Returns:
        ``{"success": False, "error": ...}`` when a value is missing or an
        exception occurs; otherwise a success dict with the API URL, the
        app id and the secret key masked after its first 8 characters
        (fully masked when it has 8 characters or fewer).
    """
    try:
        settings = {
            name: self.get_config_parameter(name)
            for name in ("app_id", "secret_key", "api_url")
        }
        if not all(settings.values()):
            return {
                "success": False,
                "error": "API配置未完成"
            }

        secret_key = settings["secret_key"]
        masked_secret = "***"
        if len(secret_key) > 8:
            masked_secret = secret_key[:8] + "***"

        return {
            "success": True,
            "message": "连接配置有效",
            "api_url": settings["api_url"],
            "app_id": settings["app_id"],
            "secret_key_masked": masked_secret
        }
    except Exception as e:
        return {
            "success": False,
            "error": str(e)
        }

View File

@@ -0,0 +1,485 @@
"""工具链管理器 - 支持langchain的工具链模式"""
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from enum import Enum
from app.core.tools.base import ToolResult
from app.core.tools.executor import ToolExecutor
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class ChainExecutionMode(str, Enum):
    """How the steps of a tool chain are executed."""

    # Sequential execution.
    SEQUENTIAL = "sequential"
    # Parallel execution.
    PARALLEL = "parallel"
    # Conditional execution.
    CONDITIONAL = "conditional"
@dataclass
class ChainStep:
    """A single step of a tool chain."""

    tool_id: str
    parameters: Dict[str, Any]
    # Expression deciding whether the step runs; None means no condition.
    condition: Optional[str] = None
    # Maps keys of the tool output to chain variable names.
    output_mapping: Optional[Dict[str, str]] = None
    # Failure policy: "stop", "continue" or "retry".
    error_handling: str = "stop"
@dataclass
class ChainDefinition:
    """A named tool chain: its steps and how they are run."""

    name: str
    description: str
    steps: List[ChainStep]
    # Defaults to running the steps one after another.
    execution_mode: ChainExecutionMode = ChainExecutionMode.SEQUENTIAL
    global_timeout: Optional[float] = None
    retry_policy: Optional[Dict[str, Any]] = None
class ChainExecutionContext:
    """Mutable state of one chain execution."""

    def __init__(self, chain_id: str):
        """Create an empty context for the execution ``chain_id``."""
        self.chain_id = chain_id

        # Data accumulated while the chain runs.
        self.variables: Dict[str, Any] = {}
        self.step_results: Dict[int, ToolResult] = {}

        # Progress and outcome.
        self.current_step = 0
        self.is_completed = False
        self.is_failed = False
        self.error_message: Optional[str] = None
class ChainManager:
    """Tool chain manager supporting LangChain-style tool chaining."""

    def __init__(self, executor: ToolExecutor):
        """Set up an empty manager.

        Args:
            executor: Tool executor used to run the steps of a chain.
        """
        self.executor = executor
        # Registered chain definitions, keyed by chain name.
        self._chains: Dict[str, ChainDefinition] = {}
        # Contexts of executions in progress, keyed by chain execution id.
        self._running_chains: Dict[str, ChainExecutionContext] = {}
def register_chain(self, chain: ChainDefinition) -> bool:
    """Validate a chain definition and store it under its name.

    An existing chain with the same name is replaced.

    Args:
        chain: The chain definition to register.

    Returns:
        True when the chain was stored; False when validation failed or
        an exception occurred (both cases are logged).
    """
    try:
        validation_result = self._validate_chain(chain)
        if validation_result[0]:
            self._chains[chain.name] = chain
            logger.info(f"工具链注册成功: {chain.name}")
            return True

        logger.error(f"工具链验证失败: {chain.name}, 错误: {validation_result[1]}")
        return False
    except Exception as e:
        logger.error(f"工具链注册失败: {chain.name}, 错误: {e}")
        return False
def unregister_chain(self, chain_name: str) -> bool:
    """Remove a registered chain.

    Args:
        chain_name: Name of the chain to remove.

    Returns:
        True when the chain existed and was removed, False otherwise.
    """
    if chain_name not in self._chains:
        return False

    del self._chains[chain_name]
    logger.info(f"工具链注销成功: {chain_name}")
    return True
def list_chains(self) -> List[Dict[str, Any]]:
    """Summarize every registered chain.

    Returns:
        One dict per chain, in registration order, with ``name``,
        ``description``, ``step_count``, ``execution_mode`` (the enum's
        value) and ``global_timeout``.
    """
    return [
        {
            "name": name,
            "description": chain.description,
            "step_count": len(chain.steps),
            "execution_mode": chain.execution_mode.value,
            "global_timeout": chain.global_timeout
        }
        for name, chain in self._chains.items()
    ]
async def execute_chain(
    self,
    chain_name: str,
    initial_variables: Optional[Dict[str, Any]] = None,
    chain_id: Optional[str] = None
) -> Dict[str, Any] | None:
    """Execute a registered chain and report the outcome.

    Args:
        chain_name: Name of a previously registered chain.
        initial_variables: Starting chain variables; the passed dict is
            used directly as the context's variable store.
        chain_id: Execution id; a random ``chain_<hex>`` id is generated
            when omitted.

    Returns:
        The result dict of the mode-specific runner, or a failure dict
        (``success`` False plus ``error``) when the chain is unknown or
        an exception escapes the runner.
    """
    if chain_name not in self._chains:
        return {
            "success": False,
            "error": f"工具链不存在: {chain_name}",
            "chain_id": chain_id
        }
    chain = self._chains[chain_name]

    if not chain_id:
        import uuid
        chain_id = f"chain_{uuid.uuid4().hex[:16]}"

    # Track the execution while it runs; the entry is removed in `finally`.
    context = ChainExecutionContext(chain_id)
    context.variables = initial_variables or {}
    self._running_chains[chain_id] = context

    # Runner per execution mode, checked in this order.
    runners = (
        (ChainExecutionMode.SEQUENTIAL, self._execute_sequential),
        (ChainExecutionMode.PARALLEL, self._execute_parallel),
        (ChainExecutionMode.CONDITIONAL, self._execute_conditional),
    )

    try:
        logger.info(f"开始执行工具链: {chain_name} (ID: {chain_id})")

        for mode, runner in runners:
            if chain.execution_mode == mode:
                result = await runner(chain, context)
                break
        else:
            raise ValueError(f"不支持的执行模式: {chain.execution_mode}")

        logger.info(f"工具链执行完成: {chain_name} (ID: {chain_id})")
        return result
    except Exception as e:
        logger.error(f"工具链执行失败: {chain_name} (ID: {chain_id}), 错误: {e}")
        return {
            "success": False,
            "error": str(e),
            "chain_id": chain_id,
            "completed_steps": context.current_step,
            "step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
        }
    finally:
        if chain_id in self._running_chains:
            del self._running_chains[chain_id]
async def _execute_sequential(
    self,
    chain: ChainDefinition,
    context: ChainExecutionContext
) -> Dict[str, Any]:
    """Run the chain's steps one at a time, in definition order.

    For each step the condition (if any) is evaluated, the parameters are
    resolved against the context variables and the tool is executed. A
    failed result is handled according to ``step.error_handling``:

    * ``"stop"``: mark the chain as failed and stop.
    * ``"continue"``: log a warning and go on with the next step.
    * ``"retry"``: execute the step once more. A successful retry is
      treated like any successful step (its output mapping is applied);
      a failed retry marks the chain as failed and stops.

    An exception raised while executing a step stops the chain only when
    the step's policy is ``"stop"``; otherwise it is logged and the next
    step runs.

    Args:
        chain: The chain definition to run.
        context: Execution context updated in place.

    Returns:
        Summary dict with ``success``, ``error``, ``chain_id``,
        ``completed_steps`` (index of the last visited step plus one),
        ``total_steps``, ``final_variables`` and ``step_results``.
    """
    for i, step in enumerate(chain.steps):
        context.current_step = i

        # Skip steps whose condition is not met.
        if step.condition and not self._evaluate_condition(step.condition, context):
            logger.debug(f"跳过步骤 {i}: 条件不满足")
            continue

        parameters = self._prepare_parameters(step.parameters, context)

        try:
            result = await self.executor.execute_tool(
                tool_id=step.tool_id,
                parameters=parameters
            )
            context.step_results[i] = result

            if not result.success and step.error_handling == "retry":
                # Single retry. The previous code re-checked for the "stop"
                # policy here, which can never hold inside the retry branch,
                # so a failed retry was silently reported as success.
                logger.warning(f"步骤 {i} 执行失败，重试一次: {result.error}")
                result = await self.executor.execute_tool(
                    tool_id=step.tool_id,
                    parameters=parameters
                )
                context.step_results[i] = result
                if not result.success:
                    context.is_failed = True
                    context.error_message = result.error
                    break

            if result.success:
                # Applies to first-try successes and successful retries alike.
                if step.output_mapping:
                    self._apply_output_mapping(step.output_mapping, result.data, context)
            elif step.error_handling == "stop":
                context.is_failed = True
                context.error_message = result.error
                break
            elif step.error_handling == "continue":
                logger.warning(f"步骤 {i} 执行失败，继续执行: {result.error}")
                continue
        except Exception as e:
            logger.error(f"步骤 {i} 执行异常: {e}")
            if step.error_handling == "stop":
                context.is_failed = True
                context.error_message = str(e)
                break

    context.is_completed = not context.is_failed
    return {
        "success": context.is_completed,
        "error": context.error_message,
        "chain_id": context.chain_id,
        "completed_steps": context.current_step + 1,
        "total_steps": len(chain.steps),
        "final_variables": context.variables,
        "step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
    }
async def _execute_parallel(
    self,
    chain: ChainDefinition,
    context: ChainExecutionContext
) -> Dict[str, Any]:
    """Run all eligible steps as one batch.

    Conditions and parameters are evaluated up front against the initial
    context; steps whose condition is not met are left out of the batch.
    The batch is handed to ``executor.execute_tools_batch`` and results
    are matched back to steps by their position in the batch.

    Args:
        chain: The chain definition to run.
        context: Execution context updated in place.

    Returns:
        Summary dict with ``success`` (True only when no recorded step
        failed and no exception occurred), ``error``, ``chain_id``,
        ``completed_steps``, ``total_steps``, ``final_variables`` and
        ``step_results``.
    """
    execution_configs = []
    for step_index, step in enumerate(chain.steps):
        if step.condition and not self._evaluate_condition(step.condition, context):
            continue
        parameters = self._prepare_parameters(step.parameters, context)
        execution_configs.append({
            "step_index": step_index,
            "tool_id": step.tool_id,
            "parameters": parameters
        })

    try:
        results = await self.executor.execute_tools_batch(execution_configs)

        for position, result in enumerate(results):
            step_index = execution_configs[position]["step_index"]
            context.step_results[step_index] = result

            step = chain.steps[step_index]
            if step.output_mapping and result.success:
                self._apply_output_mapping(step.output_mapping, result.data, context)

        # The chain counts as completed only when every recorded step succeeded.
        failed_steps = [
            index for index, outcome in context.step_results.items() if not outcome.success
        ]
        context.is_completed = not failed_steps
        if failed_steps:
            context.error_message = f"步骤 {failed_steps} 执行失败"
    except Exception as e:
        context.is_failed = True
        context.error_message = str(e)

    return {
        "success": context.is_completed,
        "error": context.error_message,
        "chain_id": context.chain_id,
        "completed_steps": len(context.step_results),
        "total_steps": len(chain.steps),
        "final_variables": context.variables,
        "step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
    }
async def _execute_conditional(
    self,
    chain: ChainDefinition,
    context: ChainExecutionContext
) -> Dict[str, Any]:
    """Run the chain in conditional mode.

    Currently this delegates to ``_execute_sequential``, which already
    skips steps whose condition is not met.

    Args:
        chain: The chain definition to run.
        context: Execution context updated in place.

    Returns:
        The summary dict produced by ``_execute_sequential``.
    """
    outcome = await self._execute_sequential(chain, context)
    return outcome
def _validate_chain(self, chain: ChainDefinition) -> tuple[bool, Optional[str]]:
    """Check that a chain definition is usable.

    Args:
        chain: The chain definition to check.

    Returns:
        ``(True, None)`` when valid, otherwise ``(False, message)`` for
        the first problem found: empty name, no steps, a step without a
        tool id, or an unknown error-handling policy.
    """
    if not chain.name:
        return False, "工具链名称不能为空"
    if not chain.steps:
        return False, "工具链必须包含至少一个步骤"

    allowed_policies = ("stop", "continue", "retry")
    for i, step in enumerate(chain.steps):
        if not step.tool_id:
            return False, f"步骤 {i} 缺少工具ID"
        if step.error_handling not in allowed_policies:
            return False, f"步骤 {i} 错误处理策略无效: {step.error_handling}"

    return True, None
def _prepare_parameters(
    self,
    parameters: Dict[str, Any],
    context: ChainExecutionContext
) -> Dict[str, Any]:
    """Resolve ``${name}`` placeholders against the chain variables.

    Only top-level string values that consist entirely of one
    ``${name}`` placeholder are replaced. A placeholder naming an unknown
    variable is left as the literal string.

    Args:
        parameters: Raw step parameters.
        context: Execution context providing ``variables``.

    Returns:
        A new dict with placeholders replaced.
    """
    resolved = {}
    for key, value in parameters.items():
        is_placeholder = (
            isinstance(value, str)
            and value.startswith("${")
            and value.endswith("}")
        )
        if is_placeholder:
            # Strip the "${" prefix and "}" suffix to get the variable name.
            resolved[key] = context.variables.get(value[2:-1], value)
        else:
            resolved[key] = value
    return resolved
def _evaluate_condition(
    self,
    condition: str,
    context: ChainExecutionContext
) -> bool:
    """Evaluate a step condition against the chain variables.

    Supported forms, checked in this order:

    * ``name == value``: the variable's string form equals ``value``.
    * ``name != value``: the variable's string form differs from ``value``.
    * ``name``: truthiness of the variable, when it exists.

    ``value`` is stripped of surrounding whitespace and quotes; a missing
    variable compares as the empty string. Any other condition is
    treated as satisfied.

    Args:
        condition: The condition expression.
        context: Execution context providing ``variables``.

    Returns:
        Whether the condition holds; False when evaluation raises.
    """
    try:
        for operator, expect_equal in (("==", True), ("!=", False)):
            if operator not in condition:
                continue
            var_name, expected_value = condition.split(operator, 1)
            var_name = var_name.strip()
            expected_value = expected_value.strip().strip('"\'')
            is_equal = str(context.variables.get(var_name, "")) == expected_value
            return is_equal if expect_equal else not is_equal

        if condition in context.variables:
            return bool(context.variables[condition])

        # Unrecognized conditions do not block the step.
        return True
    except Exception as e:
        logger.error(f"条件评估失败: {condition}, 错误: {e}")
        return False
def _apply_output_mapping(
    self,
    mapping: Dict[str, str],
    output_data: Any,
    context: ChainExecutionContext
):
    """Copy tool output into chain variables as described by ``mapping``.

    For dict output, each mapping entry ``source_key -> variable`` copies
    ``output_data[source_key]`` when that key exists. For any other
    output, the whole value is stored under the variable named by the
    mapping's ``"result"`` entry, if present. Errors are logged, not
    raised.

    Args:
        mapping: Output mapping configuration.
        output_data: Data returned by the tool.
        context: Execution context whose ``variables`` are updated.
    """
    try:
        if not isinstance(output_data, dict):
            if "result" in mapping:
                context.variables[mapping["result"]] = output_data
            return

        for source_key, target_var in mapping.items():
            if source_key in output_data:
                context.variables[target_var] = output_data[source_key]
    except Exception as e:
        logger.error(f"输出映射失败: {e}")
def _serialize_result(self, result: ToolResult) -> Dict[str, Any]:
    """Convert a tool result into a plain dict.

    Args:
        result: The tool result to convert.

    Returns:
        Dict with the result's ``success``, ``data``, ``error``,
        ``error_code``, ``execution_time``, ``token_usage`` and
        ``metadata`` attributes.
    """
    exported_fields = (
        "success",
        "data",
        "error",
        "error_code",
        "execution_time",
        "token_usage",
        "metadata",
    )
    return {field_name: getattr(result, field_name) for field_name in exported_fields}
def get_running_chains(self) -> List[Dict[str, Any]]:
    """Describe the chain executions currently in progress.

    Returns:
        One dict per running execution with ``chain_id``,
        ``current_step``, ``is_completed``, ``is_failed``,
        ``variables_count`` and ``completed_steps`` (the number of step
        results recorded so far).
    """
    return [
        {
            "chain_id": chain_id,
            "current_step": context.current_step,
            "is_completed": context.is_completed,
            "is_failed": context.is_failed,
            "variables_count": len(context.variables),
            "completed_steps": len(context.step_results)
        }
        for chain_id, context in self._running_chains.items()
    ]

View File

@@ -0,0 +1,264 @@
"""工具配置管理器 - 管理工具配置的加载和验证"""
import json
from pathlib import Path
from typing import Dict, Any, Optional
from pydantic import BaseModel, ValidationError
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class ToolConfigSchema(BaseModel):
    """Fields shared by every tool configuration.

    Keys not declared here are accepted (``extra = "allow"``).
    """

    name: str
    description: str
    tool_type: str
    version: str = "1.0.0"
    enabled: bool = True
    parameters: Dict[str, Any] = {}
    tags: list[str] = []

    class Config:
        extra = "allow"
class BuiltinToolConfigSchema(ToolConfigSchema):
    """Configuration of a built-in tool."""

    # Name of the class implementing the tool.
    tool_class: str
    tool_type: str = "builtin"
class CustomToolConfigSchema(ToolConfigSchema):
    """Configuration of a custom tool."""

    # Schema source: a URL and/or inline schema content.
    schema_url: Optional[str] = None
    schema_content: Optional[Dict[str, Any]] = None
    # Authentication settings.
    auth_type: str = "none"
    auth_config: Dict[str, Any] = {}
    # Request settings.
    base_url: Optional[str] = None
    timeout: int = 30
    tool_type: str = "custom"
class MCPToolConfigSchema(ToolConfigSchema):
    """Configuration of an MCP tool."""

    server_url: str
    connection_config: Dict[str, Any] = {}
    available_tools: list[str] = []
    tool_type: str = "mcp"
class ConfigManager:
    """Loads, stores and validates tool configuration files."""

    def __init__(self, config_dir: Optional[str] = None):
        """Set up the manager and make sure the config directory exists.

        Args:
            config_dir: Directory holding the configuration files; the
                default directory is used when omitted or empty.
        """
        chosen_dir = config_dir or self._get_default_config_dir()
        self.config_dir = Path(chosen_dir)
        self.config_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"配置管理器初始化完成，配置目录: {self.config_dir}")
def _get_default_config_dir(self) -> str:
"""获取默认配置目录"""
# 获取tools目录下的configs子目录
tools_dir = Path(__file__).parent
return str(tools_dir / "configs")
def load_builtin_tool_configs(self) -> Dict[str, BuiltinToolConfigSchema]:
    """Load every built-in tool config from the ``builtin`` subdirectory.

    The directory is created when missing. Each ``*.json`` file is parsed
    and validated; files that fail to load or validate are logged and
    skipped.

    Returns:
        Validated configs keyed by tool name.
    """
    builtin_dir = self.config_dir / "builtin"
    if not builtin_dir.exists():
        logger.info("内置工具配置目录不存在，创建默认配置")
        self._create_default_builtin_configs(builtin_dir)

    loaded = {}
    for config_file in builtin_dir.glob("*.json"):
        try:
            config = BuiltinToolConfigSchema(**self._load_config_file(config_file))
            loaded[config.name] = config
            logger.debug(f"加载内置工具配置: {config.name}")
        except Exception as e:
            logger.error(f"加载内置工具配置失败: {config_file}, 错误: {e}")
    return loaded
def load_builtin_tools_config(self) -> Dict[str, Any]:
    """Load the global ``builtin_tools.json`` file (legacy-compatible API).

    Returns:
        The parsed JSON content, or an empty dict when the file cannot be
        read or parsed (the error is logged).
    """
    config_file = self.config_dir / "builtin_tools.json"
    try:
        with open(config_file, 'r', encoding='utf-8') as handle:
            parsed = json.load(handle)
    except Exception as e:
        logger.error(f"加载内置工具配置失败: {e}")
        return {}
    else:
        return parsed
def ensure_builtin_tools_initialized(self, tenant_id, db_session, tool_config_model, builtin_tool_config_model, tool_type_enum, tool_status_enum):
"""Make sure the built-in tools of a tenant exist in the database.

Does nothing when the tenant already has at least one built-in tool row.
Otherwise one tool row plus one built-in detail row is created for every
entry of the global built-in tools configuration.

Args:
tenant_id: Tenant ID.
db_session: Database session.
tool_config_model: ToolConfig model class.
builtin_tool_config_model: BuiltinToolConfig model class.
tool_type_enum: ToolType enum.
tool_status_enum: ToolStatus enum.
"""
# Check whether the tenant has already been initialized.
existing_count = db_session.query(tool_config_model).filter(
tool_config_model.tenant_id == tenant_id,
tool_config_model.tool_type == tool_type_enum.BUILTIN
).count()
if existing_count > 0:
return # already initialized
# Load the global configuration (an empty dict when loading fails).
builtin_tools = self.load_builtin_tools_config()
# Create the built-in tool records for the tenant.
for tool_key, tool_info in builtin_tools.items():
# Initial status: active unless the tool requires configuration first.
# NOTE(review): 'requires_config' is read by subscript, so an entry
# without that key would raise KeyError - confirm all entries carry it.
initial_status = tool_status_enum.ACTIVE.value if not tool_info['requires_config'] else tool_status_enum.INACTIVE.value
tool_config = tool_config_model(
name=tool_info['name'],
description=tool_info['description'],
tool_type=tool_type_enum.BUILTIN,
tenant_id=tenant_id,
status=initial_status
)
db_session.add(tool_config)
# Flush before tool_config.id is reused below as the detail row's id.
db_session.flush()
builtin_config = builtin_tool_config_model(
id=tool_config.id,
tool_class=tool_info['tool_class'],
parameters={}
)
db_session.add(builtin_config)
db_session.commit()
logger.info(f"租户 {tenant_id} 的内置工具初始化完成")
def save_tool_config(self, config: ToolConfigSchema, tool_type: str) -> bool:
    """Write a tool config to ``<config_dir>/<tool_type>/<name>.json``.

    The target directory is created when missing and an existing file is
    overwritten.

    Args:
        config: The tool configuration to save.
        tool_type: Tool type, used as the subdirectory name.

    Returns:
        True on success, False when anything fails (the error is logged).
    """
    try:
        target_dir = self.config_dir / tool_type
        target_dir.mkdir(parents=True, exist_ok=True)
        target_file = target_dir / f"{config.name}.json"

        serialized = config.model_dump()
        with open(target_file, 'w', encoding='utf-8') as handle:
            # ensure_ascii=False keeps non-ASCII text readable in the file.
            json.dump(serialized, handle, indent=2, ensure_ascii=False)

        logger.info(f"工具配置保存成功: {config.name} ({tool_type})")
        return True
    except Exception as e:
        logger.error(f"工具配置保存失败: {config.name}, 错误: {e}")
        return False
def delete_tool_config(self, tool_name: str, tool_type: str) -> bool:
    """Delete ``<config_dir>/<tool_type>/<tool_name>.json``.

    Args:
        tool_name: Tool name, used as the file name without extension.
        tool_type: Tool type, used as the subdirectory name.

    Returns:
        True when the file existed and was removed; False when it does
        not exist or removal failed (both cases are logged).
    """
    try:
        config_file = self.config_dir / tool_type / f"{tool_name}.json"
        if not config_file.exists():
            logger.warning(f"工具配置文件不存在: {tool_name} ({tool_type})")
            return False

        config_file.unlink()
        logger.info(f"工具配置删除成功: {tool_name} ({tool_type})")
        return True
    except Exception as e:
        logger.error(f"工具配置删除失败: {tool_name}, 错误: {e}")
        return False
def validate_config(self, config_data: Dict[str, Any], tool_type: str) -> tuple[bool, Optional[str]]:
    """Validate raw config data against the schema of its tool type.

    Args:
        config_data: The configuration data to check.
        tool_type: ``"builtin"``, ``"custom"`` or ``"mcp"``.

    Returns:
        ``(True, None)`` when valid, otherwise ``(False, message)``. For
        schema violations the message lists each error as
        ``<first location element>: <message>``.
    """
    schemas_by_type = {
        "builtin": BuiltinToolConfigSchema,
        "custom": CustomToolConfigSchema,
        "mcp": MCPToolConfigSchema
    }
    try:
        schema_class = schemas_by_type.get(tool_type)
        if not schema_class:
            return False, f"不支持的工具类型: {tool_type}"

        # Instantiating the schema performs the validation.
        schema_class(**config_data)
        return True, None
    except ValidationError as e:
        error_msg = "; ".join(f"{err['loc'][0]}: {err['msg']}" for err in e.errors())
        return False, f"配置验证失败: {error_msg}"
    except Exception as e:
        return False, f"配置验证异常: {str(e)}"
def _load_config_file(self, config_file: Path) -> Dict[str, Any]:
"""加载配置文件
Args:
config_file: 配置文件路径
Returns:
配置数据字典
"""
try:
with open(config_file, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"加载配置文件失败: {config_file}, 错误: {e}")
raise
def _create_default_builtin_configs(self, builtin_dir: Path):
    """Make sure the built-in tool config directory exists.

    No config files are written here; only the directory is created.

    Args:
        builtin_dir: Directory for built-in tool configs.
    """
    builtin_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"内置工具配置目录已创建: {builtin_dir}")

View File

@@ -0,0 +1,14 @@
{
"name": "baidu_search_tool",
"description": "百度搜索工具 - 网络搜索:提供网页搜索、新闻搜索、图片搜索功能",
"tool_type": "builtin",
"tool_class": "BaiduSearchTool",
"version": "1.0.0",
"enabled": true,
"parameters": {
"api_key": "",
"secret_key": "",
"search_type": "web"
},
"tags": ["search", "web", "baidu", "builtin"]
}

View File

@@ -0,0 +1,12 @@
{
"name": "datetime_tool",
"description": "时间工具 - 日期时间处理:提供时间格式转化、时区转换、时间戳转换、时间计算",
"tool_type": "builtin",
"tool_class": "DateTimeTool",
"version": "1.0.0",
"enabled": true,
"parameters": {
"timezone": "UTC"
},
"tags": ["time", "utility", "builtin"]
}

View File

@@ -0,0 +1,12 @@
{
"name": "json_tool",
"description": "JSON工具 - 数据格式处理：提供JSON格式化、压缩、验证、格式转换",
"tool_type": "builtin",
"tool_class": "JsonTool",
"version": "1.0.0",
"enabled": true,
"parameters": {
"indent": 2
},
"tags": ["json", "data", "utility", "builtin"]
}

View File

@@ -0,0 +1,14 @@
{
"name": "mineru_tool",
"description": "MinerU PDF解析工具 - 文档处理：提供PDF解析、表格提取、图片识别、文本提取功能",
"tool_type": "builtin",
"tool_class": "MinerUTool",
"version": "1.0.0",
"enabled": true,
"parameters": {
"api_key": "",
"parse_mode": "auto",
"timeout": 60
},
"tags": ["pdf", "document", "ocr", "builtin"]
}

View File

@@ -0,0 +1,14 @@
{
"name": "textin_tool",
"description": "TextIn OCR工具 - 图像识别：提供通用OCR、手写识别、多语言支持功能",
"tool_type": "builtin",
"tool_class": "TextInTool",
"version": "1.0.0",
"enabled": true,
"parameters": {
"app_id": "",
"language": "auto",
"recognition_mode": "general"
},
"tags": ["ocr", "image", "text", "builtin"]
}

View File

@@ -0,0 +1,60 @@
{
"datetime": {
"name": "时间工具",
"description": "获取当前时间、日期计算",
"tool_class": "DateTimeTool",
"category": "utility",
"requires_config": false,
"version": "1.0.0",
"enabled": true,
"parameters": {}
},
"json_converter": {
"name": "JSON转换工具",
"description": "JSON数据格式化和转换",
"tool_class": "JsonTool",
"category": "utility",
"requires_config": false,
"version": "1.0.0",
"enabled": true,
"parameters": {}
},
"baidu_search": {
"name": "百度搜索",
"description": "百度网页搜索服务",
"tool_class": "BaiduSearchTool",
"category": "search",
"requires_config": true,
"version": "1.0.0",
"enabled": true,
"parameters": {
"api_key": {"type": "string", "description": "百度搜索API密钥", "sensitive": true, "required": true}
}
},
"mineru": {
"name": "MinerU",
"description": "PDF文档解析工具",
"tool_class": "MinerUTool",
"category": "document",
"requires_config": true,
"version": "1.0.0",
"enabled": true,
"parameters": {
"api_key": {"type": "string", "description": "MinerU API密钥", "sensitive": true, "required": true},
"base_url": {"type": "string", "description": "API地址", "default": "https://api.mineru.com"}
}
},
"textin": {
"name": "TextIn",
"description": "OCR文字识别服务",
"tool_class": "TextInTool",
"category": "ocr",
"requires_config": true,
"version": "1.0.0",
"enabled": true,
"parameters": {
"api_key": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true},
"api_secret": {"type": "string", "description": "TextIn API Secret", "sensitive": true, "required": true}
}
}
}

View File

@@ -0,0 +1,11 @@
"""Custom tool package."""
from .base import CustomTool
from .schema_parser import OpenAPISchemaParser
from .auth_manager import AuthManager
# Names exported by this package.
__all__ = [
"CustomTool",
"OpenAPISchemaParser",
"AuthManager"
]

View File

@@ -0,0 +1,525 @@
"""认证管理器 - 处理自定义工具的认证配置"""
import base64
import hashlib
import hmac
import time
from typing import Dict, Any, Tuple
from urllib.parse import quote
import aiohttp
from app.models.tool_model import AuthType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class AuthManager:
    """Authentication manager supporting several authentication types."""

    def __init__(self):
        """Record the authentication types this manager accepts."""
        self.supported_auth_types = [
            AuthType.NONE,
            AuthType.API_KEY,
            AuthType.BEARER_TOKEN,
        ]
def validate_auth_config(self, auth_type: AuthType, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
    """Validate an authentication config for the given type.

    Args:
        auth_type: The authentication type.
        auth_config: The authentication settings to check.

    Returns:
        ``(True, "")`` when valid, otherwise ``(False, message)``. An
        exception during validation is reported as an invalid config.
    """
    try:
        if auth_type not in self.supported_auth_types:
            return False, f"不支持的认证类型: {auth_type}"

        if auth_type == AuthType.NONE:
            # Nothing to validate without authentication.
            return True, ""
        if auth_type == AuthType.API_KEY:
            return self._validate_api_key_config(auth_config)
        if auth_type == AuthType.BEARER_TOKEN:
            return self._validate_bearer_token_config(auth_config)

        return False, "未知的认证类型"
    except Exception as e:
        return False, f"验证认证配置时出错: {e}"
def _validate_api_key_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
"""验证API Key认证配置
Args:
auth_config: 认证配置
Returns:
(是否有效, 错误信息)
"""
api_key = auth_config.get("api_key")
if not api_key:
return False, "API Key不能为空"
if not isinstance(api_key, str):
return False, "API Key必须是字符串"
# 验证key名称
key_name = auth_config.get("key_name", "X-API-Key")
if not isinstance(key_name, str):
return False, "API Key名称必须是字符串"
# 验证位置
key_location = auth_config.get("location", "header")
if key_location not in ["header", "query", "cookie"]:
return False, "API Key位置必须是 header、query 或 cookie"
return True, ""
def _validate_bearer_token_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
"""验证Bearer Token认证配置
Args:
auth_config: 认证配置
Returns:
(是否有效, 错误信息)
"""
token = auth_config.get("token")
if not token:
return False, "Bearer Token不能为空"
if not isinstance(token, str):
return False, "Bearer Token必须是字符串"
return True, ""
def apply_authentication(
    self,
    auth_type: AuthType,
    auth_config: Dict[str, Any],
    url: str,
    headers: Dict[str, str],
    params: Dict[str, Any]
) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
    """Add the configured credentials to a request.

    Unsupported types are logged and leave the request untouched. If
    applying the credentials raises, the error is logged and the URL,
    headers and params are returned as they are at that point.

    Args:
        auth_type: The authentication type.
        auth_config: The authentication settings.
        url: Request URL.
        headers: Request headers.
        params: Request parameters.

    Returns:
        ``(url, headers, params)`` after authentication was applied.
    """
    try:
        if auth_type == AuthType.NONE:
            return url, headers, params
        if auth_type == AuthType.API_KEY:
            return self._apply_api_key_auth(auth_config, url, headers, params)
        if auth_type == AuthType.BEARER_TOKEN:
            return self._apply_bearer_token_auth(auth_config, url, headers, params)

        logger.warning(f"不支持的认证类型: {auth_type}")
        return url, headers, params
    except Exception as e:
        logger.error(f"应用认证时出错: {e}")
        return url, headers, params
def _apply_api_key_auth(
self,
auth_config: Dict[str, Any],
url: str,
headers: Dict[str, str],
params: Dict[str, Any]
) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
"""应用API Key认证
Args:
auth_config: 认证配置
url: 请求URL
headers: 请求头
params: 请求参数
Returns:
(修改后的URL, 修改后的headers, 修改后的params)
"""
api_key = auth_config.get("api_key")
key_name = auth_config.get("key_name", "X-API-Key")
location = auth_config.get("location", "header")
if location == "header":
headers[key_name] = api_key
elif location == "query":
# 添加到URL查询参数
separator = "&" if "?" in url else "?"
encoded_key = quote(str(api_key))
url += f"{separator}{key_name}={encoded_key}"
elif location == "cookie":
# 添加到Cookie头
cookie_value = f"{key_name}={api_key}"
if "Cookie" in headers:
headers["Cookie"] += f"; {cookie_value}"
else:
headers["Cookie"] = cookie_value
return url, headers, params
def _apply_bearer_token_auth(
self,
auth_config: Dict[str, Any],
url: str,
headers: Dict[str, str],
params: Dict[str, Any]
) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
"""应用Bearer Token认证
Args:
auth_config: 认证配置
url: 请求URL
headers: 请求头
params: 请求参数
Returns:
(修改后的URL, 修改后的headers, 修改后的params)
"""
token = auth_config.get("token")
headers["Authorization"] = f"Bearer {token}"
return url, headers, params
def encrypt_auth_config(self, auth_config: Dict[str, Any], encryption_key: str) -> Dict[str, Any]:
    """Encrypt the sensitive fields of an authentication config.

    Works on a shallow copy of ``auth_config``. Every non-empty string
    found under ``api_key``, ``token``, ``secret`` or ``password`` is
    replaced by its encrypted form and marked with a truthy
    ``<field>_encrypted`` flag.

    Fields that already carry that flag are left alone, so calling this
    on an already-encrypted config is a no-op. Without this guard the
    stored value would be encrypted a second time and a single decrypt
    would leave ciphertext behind.

    Args:
        auth_config: The authentication settings.
        encryption_key: Key passed to ``_encrypt_string``.

    Returns:
        The encrypted copy, or the original ``auth_config`` when an
        exception occurs (the error is logged).
    """
    try:
        encrypted_config = auth_config.copy()

        for field in ("api_key", "token", "secret", "password"):
            value = encrypted_config.get(field)
            if not (isinstance(value, str) and value):
                continue
            if encrypted_config.get(f"{field}_encrypted"):
                # Already encrypted by an earlier call.
                continue
            encrypted_config[field] = self._encrypt_string(value, encryption_key)
            encrypted_config[f"{field}_encrypted"] = True

        return encrypted_config
    except Exception as e:
        logger.error(f"加密认证配置失败: {e}")
        return auth_config
def decrypt_auth_config(self, encrypted_config: Dict[str, Any], encryption_key: str) -> Dict[str, Any]:
"""Decrypt the sensitive fields of an authentication config.

Works on a shallow copy of ``encrypted_config``. A field is only passed to
``_decrypt_string`` when it is present and its ``<field>_encrypted`` marker
is truthy.

Args:
encrypted_config: The encrypted authentication settings.
encryption_key: Key passed to ``_decrypt_string``.

Returns:
The decrypted copy, or the original ``encrypted_config`` when an
exception occurs (the error is logged).
"""
try:
decrypted_config = encrypted_config.copy()
# Fields that may hold encrypted values.
sensitive_fields = ["api_key", "token", "secret", "password"]
for field in sensitive_fields:
if field in decrypted_config and decrypted_config.get(f"{field}_encrypted"):
encrypted_value = decrypted_config[field]
if isinstance(encrypted_value, str) and encrypted_value:
decrypted_value = self._decrypt_string(encrypted_value, encryption_key)
decrypted_config[field] = decrypted_value
# Remove the "<field>_encrypted" marker.
# NOTE(review): the marker is dropped without checking whether decryption
# actually succeeded - confirm this is intended.
decrypted_config.pop(f"{field}_encrypted", None)
return decrypted_config
except Exception as e:
logger.error(f"解密认证配置失败: {e}")
return encrypted_config
def _encrypt_string(self, value: str, key: str) -> str:
"""加密字符串
Args:
value: 要加密的字符串
key: 加密密钥
Returns:
加密后的字符串Base64编码
"""
try:
# 使用HMAC-SHA256进行简单加密
key_bytes = key.encode('utf-8')
value_bytes = value.encode('utf-8')
# 生成HMAC
hmac_obj = hmac.new(key_bytes, value_bytes, hashlib.sha256)
signature = hmac_obj.hexdigest()
# 组合原始值和签名然后Base64编码
combined = f"{value}:{signature}"
encrypted = base64.b64encode(combined.encode('utf-8')).decode('utf-8')
return encrypted
except Exception as e:
logger.error(f"加密字符串失败: {e}")
return value
def _decrypt_string(self, encrypted_value: str, key: str) -> str:
"""解密字符串
Args:
encrypted_value: 加密的字符串
key: 解密密钥
Returns:
解密后的字符串
"""
try:
# Base64解码
decoded = base64.b64decode(encrypted_value.encode('utf-8')).decode('utf-8')
# 分离原始值和签名
if ':' not in decoded:
return encrypted_value # 可能不是加密的值
value, signature = decoded.rsplit(':', 1)
# 验证签名
key_bytes = key.encode('utf-8')
value_bytes = value.encode('utf-8')
hmac_obj = hmac.new(key_bytes, value_bytes, hashlib.sha256)
expected_signature = hmac_obj.hexdigest()
if signature == expected_signature:
return value
else:
logger.warning("解密时签名验证失败")
return encrypted_value
except Exception as e:
logger.error(f"解密字符串失败: {e}")
return encrypted_value
def test_authentication(
    self,
    auth_type: AuthType,
    auth_config: Dict[str, Any],
    test_url: str = None
) -> Dict[str, Any]:
    """Dry-run an authentication config without sending a request.

    The config is validated first. If a ``test_url`` is given, the
    authentication is also applied to a sample request so the caller can see
    the outcome.

    Args:
        auth_type: Authentication type.
        auth_config: Authentication config.
        test_url: Optional URL to apply the authentication to.

    Returns:
        A result dict with ``success`` and ``auth_type``, plus ``error`` or
        ``message``. With a ``test_url`` it also carries ``test_url``,
        ``headers`` (only headers the authentication step did not add or
        change) and ``has_auth_header``.
    """
    try:
        is_valid, error_msg = self.validate_auth_config(auth_type, auth_config)
        if not is_valid:
            return {
                "success": False,
                "error": error_msg,
                "auth_type": auth_type.value
            }
        # Without a URL there is nothing to apply; validation is the test.
        if not test_url:
            return {
                "success": True,
                "message": "认证配置有效",
                "auth_type": auth_type.value
            }
        base_headers = {"User-Agent": "AuthManager-Test/1.0"}
        test_url, headers, _ = self.apply_authentication(
            auth_type, auth_config, test_url, dict(base_headers), {}
        )
        # Credentials can land in any header (custom API-key header, Cookie,
        # Authorization), so only headers left untouched by the auth step are
        # echoed back.
        visible_headers = {
            k: v for k, v in headers.items()
            if k != "Authorization" and base_headers.get(k) == v
        }
        # NOTE(review): test_url is returned as produced by
        # apply_authentication; confirm it can never embed a credential.
        return {
            "success": True,
            "message": "认证配置测试成功",
            "auth_type": auth_type.value,
            "test_url": test_url,
            "headers": visible_headers,
            "has_auth_header": "Authorization" in headers
        }
    except Exception as e:
        return {
            "success": False,
            "error": str(e),
            "auth_type": auth_type.value if auth_type else "unknown"
        }
async def test_authentication_with_request(
    self,
    auth_type: AuthType,
    auth_config: Dict[str, Any],
    test_url: str,
    timeout: int = 10
) -> Dict[str, Any]:
    """Test an authentication config by sending a real HTTP GET request.

    Args:
        auth_type: Authentication type.
        auth_config: Authentication config.
        test_url: URL to request.
        timeout: Total request timeout in seconds.

    Returns:
        A result dict with ``success`` and ``auth_type``, plus ``message`` or
        ``error``, and ``status_code`` once a response was received. Status
        401 and 403 are reported as failures; every other status is reported
        as success.
    """
    # Resolved once so the error paths cannot fail on a missing auth_type.
    auth_label = auth_type.value if auth_type else "unknown"
    auth_failures = {
        401: "认证失败 - 401 Unauthorized",
        403: "认证失败 - 403 Forbidden",
    }
    try:
        is_valid, error_msg = self.validate_auth_config(auth_type, auth_config)
        if not is_valid:
            return {
                "success": False,
                "error": error_msg,
                "auth_type": auth_label
            }
        headers = {"User-Agent": "AuthManager-Test/1.0"}
        params = {}
        test_url, headers, params = self.apply_authentication(
            auth_type, auth_config, test_url, headers, params
        )
        client_timeout = aiohttp.ClientTimeout(total=timeout)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            # params must be forwarded too: credentials placed in the query
            # string would otherwise never reach the server.
            async with session.get(test_url, headers=headers, params=params or None) as response:
                status_code = response.status
        if status_code in auth_failures:
            return {
                "success": False,
                "error": auth_failures[status_code],
                "status_code": status_code,
                "auth_type": auth_label
            }
        if status_code == 200:
            message = "认证测试成功"
        else:
            message = f"请求成功,状态码: {status_code}"
        return {
            "success": True,
            "message": message,
            "status_code": status_code,
            "auth_type": auth_label
        }
    except aiohttp.ClientError as e:
        return {
            "success": False,
            "error": f"网络请求失败: {e}",
            "auth_type": auth_label
        }
    except Exception as e:
        return {
            "success": False,
            "error": f"测试认证时出错: {e}",
            "auth_type": auth_label
        }
def get_auth_config_template(self, auth_type: AuthType) -> Dict[str, Any]:
    """Return a blank config template for the given authentication type.

    Args:
        auth_type: Authentication type.

    Returns:
        A fresh dict with the keys expected for ``auth_type``; an empty dict
        for ``AuthType.NONE`` and for any type without a template.
    """
    if auth_type == AuthType.API_KEY:
        return {
            "api_key": "",
            "key_name": "X-API-Key",
            "location": "header",  # header, query, cookie
            "description": "API Key认证配置"
        }
    if auth_type == AuthType.BEARER_TOKEN:
        return {
            "token": "",
            "description": "Bearer Token认证配置"
        }
    return {}
def mask_sensitive_config(self, auth_config: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of the config with sensitive values masked for display.

    Non-empty string values under ``api_key``, ``token``, ``secret`` and
    ``password`` are masked: values longer than four characters keep their
    first two and last two characters around ``***``; shorter ones become
    ``***``. Everything else is copied as is.

    Args:
        auth_config: Authentication config (not mutated).

    Returns:
        The masked shallow copy.
    """
    masked_config = auth_config.copy()
    for field in ("api_key", "token", "secret", "password"):
        value = masked_config.get(field)
        if not isinstance(value, str) or not value:
            continue
        masked_config[field] = f"{value[:2]}***{value[-2:]}" if len(value) > 4 else "***"
    return masked_config

View File

@@ -0,0 +1,318 @@
"""自定义工具基类"""
import time
from typing import Dict, Any, List, Optional
import aiohttp
from urllib.parse import urljoin
from app.models.tool_model import ToolType, AuthType
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class CustomTool(BaseTool):
"""自定义工具 - 基于OpenAPI schema的工具"""
def __init__(self, tool_id: str, config: Dict[str, Any]):
    """Initialise the custom tool from its stored configuration.

    Args:
        tool_id: Tool ID.
        config: Tool configuration. The keys ``schema_content``,
            ``schema_url``, ``auth_type``, ``auth_config``, ``base_url`` and
            ``timeout`` are read, each with a default.
    """
    super().__init__(tool_id, config)
    option = config.get
    self.schema_content = option("schema_content", {})
    self.schema_url = option("schema_url")
    self.auth_type = AuthType(option("auth_type", "none"))
    self.auth_config = option("auth_config", {})
    self.base_url = option("base_url", "")
    self.timeout = option("timeout", 30)
    # The schema is parsed once up front; it needs schema_content to be set.
    self._parsed_operations = self._parse_openapi_schema()
@property
def name(self) -> str:
    """Tool name: the schema's ``info.title``, else an ID-based fallback."""
    fallback = f"custom_tool_{self.tool_id[:8]}"
    if not self.schema_content:
        return fallback
    return self.schema_content.get("info", {}).get("title", fallback)
@property
def description(self) -> str:
    """Tool description: the schema's ``info.description``, else a generic text."""
    fallback = "自定义API工具"
    if not self.schema_content:
        return fallback
    return self.schema_content.get("info", {}).get("description", fallback)
@property
def tool_type(self) -> ToolType:
    """Tool type; always ``ToolType.CUSTOM`` for this class."""
    return ToolType.CUSTOM
@property
def parameters(self) -> List[ToolParameter]:
    """Parameter definitions exposed by this tool.

    With more than one operation, a required ``operation`` selector comes
    first. After that, the parameters of every parsed operation are listed
    once per name (first declaration wins), followed by the properties of
    each operation's JSON request body, which ``_build_request_data`` reads
    from the call arguments.

    A parameter is marked required only when every operation requires it, so
    calling one operation never demands an argument that belongs to another.

    Returns:
        The list of ``ToolParameter`` definitions.
    """
    params = []
    operations = self._parsed_operations
    if len(operations) > 1:
        params.append(ToolParameter(
            name="operation",
            type=ParameterType.STRING,
            description="要执行的操作",
            required=True,
            enum=list(operations.keys())
        ))

    merged = {}           # parameter name -> first declaration seen
    required_counts = {}  # parameter name -> operations that require it
    for operation in operations.values():
        specs = dict(operation.get("parameters", {}))
        body_schema = operation.get("request_body") or {}
        body_properties = body_schema.get("properties", {})
        body_required = body_schema.get("required", [])
        if not isinstance(body_required, list):
            body_required = []
        if isinstance(body_properties, dict):
            for prop_name, prop_schema in body_properties.items():
                # A path/query parameter of the same name takes precedence.
                if prop_name in specs or not isinstance(prop_schema, dict):
                    continue
                specs[prop_name] = {**prop_schema, "required": prop_name in body_required}
        for param_name, param_info in specs.items():
            merged.setdefault(param_name, param_info)
            if param_info.get("required", False):
                required_counts[param_name] = required_counts.get(param_name, 0) + 1

    for param_name, param_info in merged.items():
        params.append(ToolParameter(
            name=param_name,
            type=self._convert_openapi_type(param_info.get("type", "string")),
            description=param_info.get("description", ""),
            required=required_counts.get(param_name, 0) == len(operations),
            default=param_info.get("default"),
            enum=param_info.get("enum"),
            minimum=param_info.get("minimum"),
            maximum=param_info.get("maximum"),
            pattern=param_info.get("pattern")
        ))
    return params
async def execute(self, **kwargs) -> ToolResult:
    """Run one operation of the custom tool.

    The operation is taken from the ``operation`` argument, or chosen
    automatically when the schema defines exactly one. URL, headers and body
    are built from the call arguments and the HTTP request is sent.

    Returns:
        A success ``ToolResult`` carrying the response, or an error
        ``ToolResult`` with code ``CUSTOM_TOOL_ERROR``. Both include the
        elapsed time.
    """
    start_time = time.time()
    try:
        known_operations = self._parsed_operations
        operation_name = kwargs.get("operation")
        if not operation_name and len(known_operations) == 1:
            # Only one candidate, so the selector may be omitted.
            operation_name = next(iter(known_operations))
        if not operation_name or operation_name not in known_operations:
            raise ValueError(f"无效的操作: {operation_name}")
        operation = known_operations[operation_name]

        request_url = self._build_request_url(operation, kwargs)
        request_headers = self._build_request_headers(operation)
        request_body = self._build_request_data(operation, kwargs)
        response_data = await self._send_http_request(
            method=operation["method"],
            url=request_url,
            headers=request_headers,
            data=request_body
        )
        return ToolResult.success_result(
            data=response_data,
            execution_time=time.time() - start_time
        )
    except Exception as e:
        return ToolResult.error_result(
            error=str(e),
            error_code="CUSTOM_TOOL_ERROR",
            execution_time=time.time() - start_time
        )
def _parse_openapi_schema(self) -> Dict[str, Any]:
"""解析OpenAPI schema"""
operations = {}
if not self.schema_content:
return operations
paths = self.schema_content.get("paths", {})
for path, path_item in paths.items():
for method, operation in path_item.items():
if method.lower() in ["get", "post", "put", "delete", "patch"]:
operation_id = operation.get("operationId", f"{method}_{path.replace('/', '_')}")
# 解析参数
parameters = {}
if "parameters" in operation:
for param in operation["parameters"]:
param_name = param.get("name")
param_schema = param.get("schema", {})
parameters[param_name] = {
"type": param_schema.get("type", "string"),
"description": param.get("description", ""),
"required": param.get("required", False),
"in": param.get("in", "query"),
**param_schema
}
# 解析请求体
request_body = None
if "requestBody" in operation:
content = operation["requestBody"].get("content", {})
if "application/json" in content:
request_body = content["application/json"].get("schema", {})
operations[operation_id] = {
"method": method.upper(),
"path": path,
"summary": operation.get("summary", ""),
"description": operation.get("description", ""),
"parameters": parameters,
"request_body": request_body
}
return operations
def _convert_openapi_type(self, openapi_type: str) -> ParameterType:
    """Map an OpenAPI type name to the internal ``ParameterType``.

    OpenAPI 3.1 allows ``type`` to be a list such as ``["string", "null"]``.
    In that case the first entry other than ``"null"`` is used.

    Args:
        openapi_type: OpenAPI type name, or a list of type names.

    Returns:
        The matching ``ParameterType``; ``ParameterType.STRING`` for anything
        unknown.
    """
    if isinstance(openapi_type, (list, tuple)):
        openapi_type = next((t for t in openapi_type if t != "null"), "string")
    # A non-string here cannot be a type name (and may be unhashable).
    if not isinstance(openapi_type, str):
        return ParameterType.STRING
    type_mapping = {
        "string": ParameterType.STRING,
        "integer": ParameterType.INTEGER,
        "number": ParameterType.NUMBER,
        "boolean": ParameterType.BOOLEAN,
        "array": ParameterType.ARRAY,
        "object": ParameterType.OBJECT
    }
    return type_mapping.get(openapi_type, ParameterType.STRING)
def _build_request_url(self, operation: Dict[str, Any], params: Dict[str, Any]) -> str:
"""构建请求URL"""
path = operation["path"]
# 替换路径参数
for param_name, param_info in operation.get("parameters", {}).items():
if param_info.get("in") == "path" and param_name in params:
path = path.replace(f"{{{param_name}}}", str(params[param_name]))
# 构建完整URL
if self.base_url:
url = urljoin(self.base_url, path.lstrip("/"))
else:
# 从schema中获取服务器URL
servers = self.schema_content.get("servers", [])
if servers:
base_url = servers[0].get("url", "")
url = urljoin(base_url, path.lstrip("/"))
else:
url = path
# 添加查询参数
query_params = {}
for param_name, param_info in operation.get("parameters", {}).items():
if param_info.get("in") == "query" and param_name in params:
query_params[param_name] = params[param_name]
if query_params:
from urllib.parse import urlencode
url += "?" + urlencode(query_params)
return url
def _build_request_headers(self, operation: Dict[str, Any]) -> Dict[str, str]:
    """Build the request headers, including authentication.

    NOTE(review): for API-key auth the key is always sent as a header named
    by ``key_name``; a ``location`` entry in ``auth_config`` is not consulted
    here. Confirm whether query/cookie placement should be supported.

    Args:
        operation: Parsed operation (currently unused).

    Returns:
        The header dict for the request.
    """
    headers = {
        "Content-Type": "application/json",
        "User-Agent": "CustomTool/1.0"
    }
    credentials = self.auth_config
    if self.auth_type == AuthType.API_KEY:
        api_key = credentials.get("api_key")
        if api_key:
            headers[credentials.get("key_name", "X-API-Key")] = api_key
    elif self.auth_type == AuthType.BEARER_TOKEN:
        token = credentials.get("token")
        if token:
            headers["Authorization"] = f"Bearer {token}"
    return headers
def _build_request_data(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""构建请求数据"""
if operation["method"] in ["POST", "PUT", "PATCH"]:
request_body = operation.get("request_body")
if request_body:
# 构建请求体数据
data = {}
properties = request_body.get("properties", {})
for prop_name, prop_schema in properties.items():
if prop_name in params:
data[prop_name] = params[prop_name]
return data if data else None
return None
async def _send_http_request(
    self,
    method: str,
    url: str,
    headers: Dict[str, str],
    data: Optional[Dict[str, Any]] = None
) -> Any:
    """Send the HTTP request and return the decoded response.

    Args:
        method: HTTP method name.
        url: Request URL.
        headers: Request headers.
        data: JSON body; only sent for POST, PUT and PATCH.

    Returns:
        The parsed JSON body, or the raw text when JSON decoding fails.

    Raises:
        Exception: When the response status is 400 or above; the message
            carries the status and the response text.
    """
    request_options = {"headers": headers}
    if data and method in ("POST", "PUT", "PATCH"):
        request_options["json"] = data

    client_timeout = aiohttp.ClientTimeout(total=self.timeout)
    async with aiohttp.ClientSession(timeout=client_timeout) as session:
        async with session.request(method, url, **request_options) as response:
            if response.status >= 400:
                error_text = await response.text()
                raise Exception(f"HTTP {response.status}: {error_text}")
            # Prefer JSON and fall back to plain text for anything else.
            try:
                return await response.json()
            except Exception:
                return await response.text()
@classmethod
def from_url(cls, schema_url: str, auth_config: Dict[str, Any], tool_id: str = None) -> 'CustomTool':
    """Create a tool that refers to an OpenAPI schema by URL.

    The schema is not downloaded here, so the returned tool is built from a
    config without ``schema_content``.

    Args:
        schema_url: URL of the OpenAPI schema.
        auth_config: Authentication config; its ``type`` entry selects the
            auth type (default ``"none"``).
        tool_id: Tool ID; a random UUID string is generated when omitted.

    Returns:
        The new tool instance.
    """
    import uuid
    resolved_id = tool_id or str(uuid.uuid4())
    return cls(resolved_id, {
        "schema_url": schema_url,
        "auth_config": auth_config,
        "auth_type": auth_config.get("type", "none")
    })
@classmethod
def from_schema(cls, schema_dict: Dict[str, Any], auth_config: Dict[str, Any], tool_id: str = None) -> 'CustomTool':
    """Create a tool from an already-loaded OpenAPI schema dict.

    Args:
        schema_dict: OpenAPI schema content.
        auth_config: Authentication config; its ``type`` entry selects the
            auth type (default ``"none"``).
        tool_id: Tool ID; a random UUID string is generated when omitted.

    Returns:
        The new tool instance.
    """
    import uuid
    resolved_id = tool_id or str(uuid.uuid4())
    return cls(resolved_id, {
        "schema_content": schema_dict,
        "auth_config": auth_config,
        "auth_type": auth_config.get("type", "none")
    })

View File

@@ -0,0 +1,477 @@
"""OpenAPI Schema解析器"""
import json
import yaml
from typing import Dict, Any, List, Optional, Tuple
from urllib.parse import urlparse
import aiohttp
import asyncio
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class OpenAPISchemaParser:
"""OpenAPI Schema解析器 - 解析OpenAPI 3.0规范"""
def __init__(self):
"""初始化解析器"""
self.supported_versions = ["3.0.0", "3.0.1", "3.0.2", "3.0.3", "3.1.0"]
async def parse_from_url(self, schema_url: str, timeout: int = 30) -> Tuple[bool, Dict[str, Any], str]:
    """Download and parse an OpenAPI schema from a URL.

    Only ``http`` and ``https`` URLs with a host are accepted; anything else
    is rejected before a request is made.

    Args:
        schema_url: Schema URL.
        timeout: Total request timeout in seconds.

    Returns:
        ``(success, schema, error_message)``. On failure the schema is an
        empty dict.
    """
    try:
        parsed_url = urlparse(schema_url)
        # The HTTP client can only fetch http(s); rejecting other schemes here
        # gives a clear error instead of a client-internal one.
        if parsed_url.scheme not in ("http", "https") or not parsed_url.netloc:
            return False, {}, "无效的URL格式"

        client_timeout = aiohttp.ClientTimeout(total=timeout)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.get(schema_url) as response:
                if response.status != 200:
                    return False, {}, f"HTTP错误: {response.status}"
                content_type = response.headers.get('content-type', '').lower()
                content = await response.text()

        schema_dict = self._parse_content(content, content_type)
        if not schema_dict:
            return False, {}, "无法解析schema内容"

        is_valid, error_msg = self.validate_schema(schema_dict)
        if not is_valid:
            return False, {}, error_msg
        return True, schema_dict, ""
    except asyncio.TimeoutError:
        return False, {}, "请求超时"
    except Exception as e:
        logger.error(f"从URL解析schema失败: {schema_url}, 错误: {e}")
        return False, {}, str(e)
def parse_from_content(self, content: str, content_type: str = "application/json") -> Tuple[bool, Dict[str, Any], str]:
    """Parse and validate an OpenAPI schema given as text.

    Args:
        content: Schema text.
        content_type: Content type hint used to choose the parser.

    Returns:
        ``(success, schema, error_message)``. On failure the schema is an
        empty dict.
    """
    try:
        schema_dict = self._parse_content(content, content_type)
        if schema_dict:
            is_valid, error_msg = self.validate_schema(schema_dict)
            if is_valid:
                return True, schema_dict, ""
            return False, {}, error_msg
        return False, {}, "无法解析schema内容"
    except Exception as e:
        logger.error(f"解析schema内容失败: {e}")
        return False, {}, str(e)
def _parse_content(self, content: str, content_type: str) -> Optional[Dict[str, Any]]:
"""解析内容为字典
Args:
content: 内容字符串
content_type: 内容类型
Returns:
解析后的字典失败返回None
"""
try:
# 根据内容类型解析
if 'json' in content_type:
return json.loads(content)
elif 'yaml' in content_type or 'yml' in content_type:
return yaml.safe_load(content)
else:
# 尝试自动检测格式
try:
return json.loads(content)
except json.JSONDecodeError:
try:
return yaml.safe_load(content)
except yaml.YAMLError:
return None
except Exception as e:
logger.error(f"解析内容失败: {e}")
return None
def validate_schema(self, schema_dict: Dict[str, Any]) -> Tuple[bool, str]:
    """Check that a parsed document looks like a usable OpenAPI schema.

    The ``openapi`` version is accepted when it is listed in
    ``self.supported_versions`` or is another patch release
    (``major.minor.patch``) of a listed ``major.minor`` line. Patch releases
    such as 3.0.4 therefore validate without extending the list.

    Args:
        schema_dict: Parsed schema.

    Returns:
        ``(is_valid, error_message)``; the message is empty when valid.
    """
    try:
        if not isinstance(schema_dict, dict):
            return False, "Schema必须是JSON对象"

        openapi_version = schema_dict.get("openapi")
        if not openapi_version:
            return False, "缺少openapi版本字段"
        version_ok = openapi_version in self.supported_versions
        if not version_ok and isinstance(openapi_version, str):
            parts = openapi_version.split(".")
            supported_lines = {".".join(v.split(".")[:2]) for v in self.supported_versions}
            version_ok = (
                len(parts) == 3
                and all(part.isdigit() for part in parts)
                and ".".join(parts[:2]) in supported_lines
            )
        if not version_ok:
            return False, f"不支持的OpenAPI版本: {openapi_version}"

        for field in ("info", "paths"):
            if field not in schema_dict:
                return False, f"缺少必需字段: {field}"

        info = schema_dict.get("info", {})
        if not isinstance(info, dict):
            return False, "info字段必须是对象"
        if "title" not in info:
            return False, "info.title字段是必需的"

        paths = schema_dict.get("paths", {})
        if not isinstance(paths, dict):
            return False, "paths字段必须是对象"
        # A tool with no path has nothing to call.
        if not paths:
            return False, "至少需要定义一个API路径"
        return True, ""
    except Exception as e:
        return False, f"验证schema时出错: {e}"
def extract_tool_info(self, schema_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Summarise a schema as tool metadata.

    Args:
        schema_dict: Parsed schema.

    Returns:
        A dict with ``name``, ``description`` and ``version`` (from ``info``,
        with defaults), ``servers``, and ``operations`` as produced by
        ``_extract_operations``.
    """
    info = schema_dict.get("info", {})
    tool_info = {
        "name": info.get("title", "Custom API Tool"),
        "description": info.get("description", ""),
        "version": info.get("version", "1.0.0"),
    }
    tool_info["servers"] = schema_dict.get("servers", [])
    tool_info["operations"] = self._extract_operations(schema_dict)
    return tool_info
def _extract_operations(self, schema_dict: Dict[str, Any]) -> Dict[str, Any]:
"""提取API操作信息
Args:
schema_dict: Schema字典
Returns:
操作信息字典
"""
operations = {}
paths = schema_dict.get("paths", {})
for path, path_item in paths.items():
if not isinstance(path_item, dict):
continue
for method, operation in path_item.items():
if method.lower() not in ["get", "post", "put", "delete", "patch", "head", "options"]:
continue
if not isinstance(operation, dict):
continue
# 生成操作ID
operation_id = operation.get("operationId")
if not operation_id:
operation_id = f"{method.lower()}_{path.replace('/', '_').replace('{', '').replace('}', '')}"
# 提取操作信息
operations[operation_id] = {
"method": method.upper(),
"path": path,
"summary": operation.get("summary", ""),
"description": operation.get("description", ""),
"parameters": self._extract_parameters(operation),
"request_body": self._extract_request_body(operation),
"responses": self._extract_responses(operation),
"tags": operation.get("tags", [])
}
return operations
def _extract_parameters(self, operation: Dict[str, Any]) -> Dict[str, Any]:
"""提取操作参数
Args:
operation: 操作定义
Returns:
参数信息字典
"""
parameters = {}
for param in operation.get("parameters", []):
if not isinstance(param, dict):
continue
param_name = param.get("name")
if not param_name:
continue
param_schema = param.get("schema", {})
parameters[param_name] = {
"name": param_name,
"in": param.get("in", "query"),
"description": param.get("description", ""),
"required": param.get("required", False),
"type": param_schema.get("type", "string"),
"format": param_schema.get("format"),
"enum": param_schema.get("enum"),
"default": param_schema.get("default"),
"minimum": param_schema.get("minimum"),
"maximum": param_schema.get("maximum"),
"pattern": param_schema.get("pattern"),
"example": param.get("example") or param_schema.get("example")
}
return parameters
def _extract_request_body(self, operation: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""提取请求体信息
Args:
operation: 操作定义
Returns:
请求体信息如果没有返回None
"""
request_body = operation.get("requestBody")
if not request_body:
return None
content = request_body.get("content", {})
# 优先使用application/json
if "application/json" in content:
schema = content["application/json"].get("schema", {})
elif content:
# 使用第一个可用的内容类型
first_content_type = next(iter(content.keys()))
schema = content[first_content_type].get("schema", {})
else:
return None
return {
"description": request_body.get("description", ""),
"required": request_body.get("required", False),
"schema": schema,
"content_types": list(content.keys())
}
def _extract_responses(self, operation: Dict[str, Any]) -> Dict[str, Any]:
"""提取响应信息
Args:
operation: 操作定义
Returns:
响应信息字典
"""
responses = {}
for status_code, response in operation.get("responses", {}).items():
if not isinstance(response, dict):
continue
content = response.get("content", {})
schema = None
# 尝试获取响应schema
if "application/json" in content:
schema = content["application/json"].get("schema")
elif content:
first_content_type = next(iter(content.keys()))
schema = content[first_content_type].get("schema")
responses[status_code] = {
"description": response.get("description", ""),
"schema": schema,
"content_types": list(content.keys()) if content else []
}
return responses
def generate_tool_parameters(self, operations: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Build a flat list of tool parameter definitions from parsed operations.

    With more than one operation, a required ``operation`` selector comes
    first. Path/query parameters and request-body properties of all operations
    follow, de-duplicated by name. The first declaration of a name wins,
    including its ``required`` flag.

    Args:
        operations: Operation mapping as produced by ``_extract_operations``.

    Returns:
        The list of parameter definition dicts.
    """
    def describe(name, spec, required):
        """Project a parameter or property spec onto the definition shape."""
        return {
            "name": name,
            "type": spec.get("type", "string"),
            "description": spec.get("description", ""),
            "required": required,
            "enum": spec.get("enum"),
            "default": spec.get("default"),
            "minimum": spec.get("minimum"),
            "maximum": spec.get("maximum"),
            "pattern": spec.get("pattern")
        }

    parameters = []
    if len(operations) > 1:
        parameters.append({
            "name": "operation",
            "type": "string",
            "description": "要执行的操作",
            "required": True,
            "enum": list(operations.keys())
        })

    # Insertion order of this dict fixes the output order.
    all_params = {}
    for operation in operations.values():
        # Path and query parameters.
        for param_name, param_info in operation.get("parameters", {}).items():
            if param_name not in all_params:
                all_params[param_name] = describe(
                    param_name, param_info, param_info.get("required", False)
                )
        # Request-body properties.
        request_body = operation.get("request_body")
        if not request_body:
            continue
        schema = request_body.get("schema", {})
        for prop_name, prop_schema in schema.get("properties", {}).items():
            if prop_name not in all_params:
                all_params[prop_name] = describe(
                    prop_name, prop_schema, prop_name in schema.get("required", [])
                )

    parameters.extend(all_params.values())
    return parameters
def validate_operation_parameters(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Tuple[bool, List[str]]:
    """Validate call arguments against an operation definition.

    Declared parameters are checked for presence (when required), type and
    enum membership. If the operation has a request body, its required
    properties must be present and supplied properties are type-checked.

    Args:
        operation: Operation definition as produced by ``_extract_operations``.
        params: Call arguments keyed by name.

    Returns:
        ``(is_valid, errors)`` where ``errors`` lists every problem found.
    """
    errors = []

    # Path and query parameters.
    for param_name, param_info in operation.get("parameters", {}).items():
        supplied = param_name in params
        if param_info.get("required", False) and not supplied:
            errors.append(f"缺少必需参数: {param_name}")
        if not supplied:
            continue
        value = params[param_name]
        param_type = param_info.get("type", "string")
        if not self._validate_parameter_type(value, param_type):
            errors.append(f"参数 {param_name} 类型错误,期望: {param_type}")
        enum_values = param_info.get("enum")
        if enum_values and value not in enum_values:
            errors.append(f"参数 {param_name} 值无效,必须是: {enum_values}")

    # Request-body properties.
    request_body = operation.get("request_body")
    if request_body:
        schema = request_body.get("schema", {})
        properties = schema.get("properties", {})
        for prop_name in schema.get("required", []):
            if prop_name not in params:
                errors.append(f"缺少必需的请求体参数: {prop_name}")
        for prop_name, value in params.items():
            if prop_name not in properties:
                continue
            prop_type = properties[prop_name].get("type", "string")
            if not self._validate_parameter_type(value, prop_type):
                errors.append(f"请求体参数 {prop_name} 类型错误,期望: {prop_type}")

    return not errors, errors
def _validate_parameter_type(self, value: Any, expected_type: str) -> bool:
"""验证参数类型
Args:
value: 参数值
expected_type: 期望类型
Returns:
是否类型匹配
"""
if value is None:
return True
type_mapping = {
"string": str,
"integer": int,
"number": (int, float),
"boolean": bool,
"array": list,
"object": dict
}
expected_python_type = type_mapping.get(expected_type)
if expected_python_type:
return isinstance(value, expected_python_type)
return True

View File

@@ -0,0 +1,501 @@
"""工具执行器 - 负责工具的实际调用和执行管理"""
import asyncio
import uuid
import time
from typing import Dict, Any, List, Optional
from datetime import datetime
from sqlalchemy.orm import Session
from app.models.tool_model import ToolExecution, ExecutionStatus
from app.core.tools.base import BaseTool, ToolResult
from app.core.tools.registry import ToolRegistry
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class ExecutionContext:
    """State tracked for a single tool execution."""

    def __init__(
        self,
        execution_id: str,
        tool_id: str,
        user_id: Optional[uuid.UUID] = None,
        workspace_id: Optional[uuid.UUID] = None,
        timeout: Optional[float] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Record who runs which tool and start the clock.

        Args:
            execution_id: Execution ID.
            tool_id: Tool ID.
            user_id: User ID, if known.
            workspace_id: Workspace ID, if known.
            timeout: Timeout in seconds; a falsy value selects 60 seconds.
            metadata: Extra metadata; a falsy value selects an empty dict.
        """
        # Identity of the run.
        self.execution_id, self.tool_id = execution_id, tool_id
        self.user_id, self.workspace_id = user_id, workspace_id
        # Options, with defaults for anything falsy.
        self.timeout = timeout if timeout else 60.0
        self.metadata = metadata if metadata else {}
        # Lifecycle: starts now, pending, not yet completed.
        self.started_at = datetime.now()
        self.completed_at: Optional[datetime] = None
        self.status = ExecutionStatus.PENDING
class ToolExecutor:
"""工具执行器 - 使用langchain标准接口执行工具"""
def __init__(self, db: Session, registry: ToolRegistry):
    """Initialise the tool executor.

    Args:
        db: Database session used for execution records.
        registry: Tool registry used to look tools up by ID.
    """
    self.registry = registry
    self.db = db
    # In-flight executions keyed by execution ID, guarded by an asyncio lock.
    self._execution_lock = asyncio.Lock()
    self._running_executions: Dict[str, ExecutionContext] = {}
async def execute_tool(
    self,
    tool_id: str,
    parameters: Dict[str, Any],
    user_id: Optional[uuid.UUID] = None,
    workspace_id: Optional[uuid.UUID] = None,
    execution_id: Optional[str] = None,
    timeout: Optional[float] = None,
    metadata: Optional[Dict[str, Any]] = None
) -> ToolResult:
    """Execute a tool and record the execution.

    Args:
        tool_id: Tool ID.
        parameters: Tool arguments.
        user_id: User ID.
        workspace_id: Workspace ID.
        execution_id: Execution ID; generated when omitted.
        timeout: Timeout in seconds.
        metadata: Extra metadata.

    Returns:
        The tool's result. An error result is returned with code
        ``TOOL_NOT_FOUND`` when the tool is unknown, or ``EXECUTION_ERROR``
        when execution raises.
    """
    execution_id = execution_id or f"exec_{uuid.uuid4().hex[:16]}"
    context = ExecutionContext(
        execution_id=execution_id,
        tool_id=tool_id,
        user_id=user_id,
        workspace_id=workspace_id,
        timeout=timeout,
        metadata=metadata
    )
    try:
        tool = self.registry.get_tool(tool_id)
        if tool:
            # Record start, run, record completion.
            await self._record_execution_start(context, parameters)
            result = await self._execute_with_timeout(tool, parameters, context)
            await self._record_execution_complete(context, result)
            return result
        return ToolResult.error_result(
            error=f"工具不存在: {tool_id}",
            error_code="TOOL_NOT_FOUND",
            execution_time=0.0
        )
    except Exception as e:
        logger.error(f"工具执行异常: {execution_id}, 错误: {e}")
        elapsed = time.time() - context.started_at.timestamp()
        error_result = ToolResult.error_result(
            error=str(e),
            error_code="EXECUTION_ERROR",
            execution_time=elapsed
        )
        await self._record_execution_complete(context, error_result)
        return error_result
    finally:
        # Whatever happened, the execution is no longer running.
        async with self._execution_lock:
            self._running_executions.pop(execution_id, None)
async def execute_tools_batch(
    self,
    tool_executions: List[Dict[str, Any]],
    max_concurrency: int = 5
) -> List[ToolResult]:
    """Execute several tools concurrently.

    Args:
        tool_executions: One config dict per execution. ``tool_id`` is
            required; ``parameters``, ``user_id``, ``workspace_id``,
            ``timeout`` and ``metadata`` are optional.
        max_concurrency: Maximum number of executions running at once.

    Returns:
        One result per config, in input order. An execution that raised is
        reported as an error result with code ``BATCH_EXECUTION_ERROR``.
    """
    semaphore = asyncio.Semaphore(max_concurrency)

    async def execute_single(exec_config: Dict[str, Any]) -> ToolResult:
        # The semaphore caps how many executions are in flight.
        async with semaphore:
            return await self.execute_tool(
                tool_id=exec_config["tool_id"],
                parameters=exec_config.get("parameters", {}),
                user_id=exec_config.get("user_id"),
                workspace_id=exec_config.get("workspace_id"),
                timeout=exec_config.get("timeout"),
                metadata=exec_config.get("metadata")
            )

    results = await asyncio.gather(
        *(execute_single(config) for config in tool_executions),
        return_exceptions=True
    )
    # Exceptions come back as values; turn them into error results.
    return [
        ToolResult.error_result(
            error=str(result),
            error_code="BATCH_EXECUTION_ERROR",
            execution_time=0.0
        ) if isinstance(result, Exception) else result
        for result in results
    ]
async def cancel_execution(self, execution_id: str) -> bool:
    """Mark a running execution as cancelled.

    NOTE(review): this sets the context status and the database record to
    FAILED. Nothing here cancels the awaited tool call itself. Confirm whether
    the running task is meant to be interrupted elsewhere.

    Args:
        execution_id: Execution ID.

    Returns:
        True when the execution was running and has been marked, else False.
    """
    async with self._execution_lock:
        context = self._running_executions.get(execution_id)
        if execution_id not in self._running_executions:
            return False
        context.status = ExecutionStatus.FAILED

        # Mirror the cancellation in the stored record, if there is one.
        execution_record = (
            self.db.query(ToolExecution)
            .filter(ToolExecution.execution_id == execution_id)
            .first()
        )
        if execution_record:
            execution_record.status = ExecutionStatus.FAILED.value
            execution_record.error_message = "执行被取消"
            execution_record.completed_at = datetime.now()
            self.db.commit()
        logger.info(f"工具执行已取消: {execution_id}")
        return True
def get_running_executions(self) -> List[Dict[str, Any]]:
    """List the executions that are currently in flight.

    Returns:
        One dict per running execution with ``execution_id``, ``tool_id``,
        ``user_id``, ``workspace_id``, ``started_at`` (ISO format),
        ``status`` and ``elapsed_time`` in seconds.
    """
    def as_text(identifier):
        """Stringify an ID, keeping None for a missing one."""
        return str(identifier) if identifier else None

    return [
        {
            "execution_id": execution_id,
            "tool_id": context.tool_id,
            "user_id": as_text(context.user_id),
            "workspace_id": as_text(context.workspace_id),
            "started_at": context.started_at.isoformat(),
            "status": context.status.value,
            "elapsed_time": (datetime.now() - context.started_at).total_seconds()
        }
        for execution_id, context in self._running_executions.items()
    ]
async def _execute_with_timeout(
    self,
    tool: BaseTool,
    parameters: Dict[str, Any],
    context: ExecutionContext
) -> ToolResult:
    """Run a tool under the context's timeout.

    The context is registered as running first. Its status then follows the
    outcome: COMPLETED, TIMEOUT, or FAILED.

    Args:
        tool: Tool instance.
        parameters: Tool arguments.
        context: Execution context.

    Returns:
        The tool's result, or an ``EXECUTION_TIMEOUT`` error result when the
        timeout elapses. Other exceptions are re-raised.
    """
    async with self._execution_lock:
        self._running_executions[context.execution_id] = context
        context.status = ExecutionStatus.RUNNING

    try:
        # wait_for enforces the timeout on the awaited call.
        outcome = await asyncio.wait_for(
            tool.safe_execute(**parameters),
            timeout=context.timeout
        )
    except asyncio.TimeoutError:
        context.status = ExecutionStatus.TIMEOUT
        return ToolResult.error_result(
            error=f"工具执行超时({context.timeout}秒)",
            error_code="EXECUTION_TIMEOUT",
            execution_time=context.timeout
        )
    except Exception:
        context.status = ExecutionStatus.FAILED
        raise
    context.status = ExecutionStatus.COMPLETED
    return outcome
async def _record_execution_start(
    self,
    context: ExecutionContext,
    parameters: Dict[str, Any]
):
    """Persist a RUNNING execution record for the context.

    Failures are logged and swallowed so that bookkeeping never blocks the
    tool run itself. After a failed write the session is rolled back, since a
    session whose flush or commit failed cannot be used again until then.

    Args:
        context: Execution context.
        parameters: Tool arguments, stored as the record's input data.
    """
    try:
        execution_record = ToolExecution(
            execution_id=context.execution_id,
            tool_config_id=uuid.UUID(context.tool_id),
            status=ExecutionStatus.RUNNING.value,
            input_data=parameters,
            started_at=context.started_at,
            user_id=context.user_id,
            workspace_id=context.workspace_id
        )
    except Exception as e:
        # Nothing has touched the session yet (for example tool_id is not a
        # UUID), so there is nothing to roll back.
        logger.error(f"创建执行记录失败: {context.execution_id}, 错误: {e}")
        return

    try:
        self.db.add(execution_record)
        self.db.commit()
        logger.debug(f"执行记录已创建: {context.execution_id}")
    except Exception as e:
        logger.error(f"创建执行记录失败: {context.execution_id}, 错误: {e}")
        try:
            self.db.rollback()
        except Exception as rollback_error:
            logger.error(f"回滚执行记录事务失败: {context.execution_id}, 错误: {rollback_error}")
async def _record_execution_complete(
    self,
    context: ExecutionContext,
    result: ToolResult
):
    """Update the execution record with the outcome.

    Failures are logged and swallowed. After a failure the session is rolled
    back so that later database work on the same session still functions.

    NOTE(review): any unsuccessful result is stored as FAILED, including
    timeouts, although the context status distinguishes TIMEOUT. Confirm
    whether the record should carry that distinction.

    Args:
        context: Execution context; ``completed_at`` is set here.
        result: Result of the execution.
    """
    try:
        context.completed_at = datetime.now()
        execution_record = self.db.query(ToolExecution).filter(
            ToolExecution.execution_id == context.execution_id
        ).first()
        if execution_record:
            execution_record.status = (
                ExecutionStatus.COMPLETED.value if result.success
                else ExecutionStatus.FAILED.value
            )
            # Output is kept only on success, the error only on failure.
            execution_record.output_data = result.data if result.success else None
            execution_record.error_message = result.error if not result.success else None
            execution_record.completed_at = context.completed_at
            execution_record.execution_time = result.execution_time
            execution_record.token_usage = result.token_usage
            self.db.commit()
            logger.debug(f"执行记录已更新: {context.execution_id}")
    except Exception as e:
        logger.error(f"更新执行记录失败: {context.execution_id}, 错误: {e}")
        try:
            self.db.rollback()
        except Exception as rollback_error:
            logger.error(f"回滚执行记录事务失败: {context.execution_id}, 错误: {rollback_error}")
def get_execution_history(
    self,
    tool_id: Optional[str] = None,
    user_id: Optional[uuid.UUID] = None,
    workspace_id: Optional[uuid.UUID] = None,
    limit: int = 100
) -> List[Dict[str, Any]]:
    """Fetch past executions, newest first.

    Args:
        tool_id: Only executions of this tool.
        user_id: Only executions by this user.
        workspace_id: Only executions in this workspace.
        limit: Maximum number of rows returned.

    Returns:
        One dict per execution; an empty list when the lookup fails.
    """
    def serialise(execution):
        """Turn a record into a plain dict with string IDs and ISO timestamps."""
        return {
            "execution_id": execution.execution_id,
            "tool_id": str(execution.tool_config_id),
            "status": execution.status,
            "started_at": execution.started_at.isoformat() if execution.started_at else None,
            "completed_at": execution.completed_at.isoformat() if execution.completed_at else None,
            "execution_time": execution.execution_time,
            "user_id": str(execution.user_id) if execution.user_id else None,
            "workspace_id": str(execution.workspace_id) if execution.workspace_id else None,
            "input_data": execution.input_data,
            "output_data": execution.output_data,
            "error_message": execution.error_message,
            "token_usage": execution.token_usage
        }

    try:
        query = self.db.query(ToolExecution).order_by(
            ToolExecution.started_at.desc()
        )
        # Each filter is applied only when its argument was supplied.
        if tool_id:
            query = query.filter(ToolExecution.tool_config_id == uuid.UUID(tool_id))
        if user_id:
            query = query.filter(ToolExecution.user_id == user_id)
        if workspace_id:
            query = query.filter(ToolExecution.workspace_id == workspace_id)
        return [serialise(execution) for execution in query.limit(limit).all()]
    except Exception as e:
        logger.error(f"获取执行历史失败, 错误: {e}")
        return []
def get_execution_statistics(
    self,
    workspace_id: Optional[uuid.UUID] = None,
    days: int = 7
) -> Dict[str, Any]:
    """Aggregate execution counts and timings over the last ``days`` days.

    Args:
        workspace_id: Optional workspace ID to restrict the statistics to.
        days: Size of the look-back window, in days.

    Returns:
        A dict with overall totals, success rate, average execution time
        and a per-tool breakdown. An empty dict is returned (and the error
        logged) if the query or aggregation fails.
    """
    try:
        from datetime import timedelta

        window_start = datetime.now() - timedelta(days=days)
        query = self.db.query(ToolExecution).filter(
            ToolExecution.started_at >= window_start
        )
        if workspace_id:
            query = query.filter(ToolExecution.workspace_id == workspace_id)
        records = query.all()

        completed_status = ExecutionStatus.COMPLETED.value
        failed_status = ExecutionStatus.FAILED.value

        # Single pass: overall counters, timing samples and per-tool buckets.
        success_count = 0
        failure_count = 0
        durations = []
        per_tool = {}
        for record in records:
            bucket = per_tool.setdefault(
                str(record.tool_config_id),
                {"total": 0, "successful": 0, "failed": 0},
            )
            bucket["total"] += 1
            if record.status == completed_status:
                success_count += 1
                bucket["successful"] += 1
            elif record.status == failed_status:
                failure_count += 1
                bucket["failed"] += 1
            # Only records that carry a measured duration count toward the average.
            if record.execution_time is not None:
                durations.append(record.execution_time)

        total_count = len(records)
        if durations:
            average_time = sum(durations) / len(durations)
        else:
            average_time = 0

        return {
            "period_days": days,
            "total_executions": total_count,
            "successful_executions": success_count,
            "failed_executions": failure_count,
            "success_rate": success_count / total_count if total_count > 0 else 0,
            "average_execution_time": average_time,
            "tool_statistics": per_tool
        }
    except Exception as e:
        logger.error(f"获取执行统计失败, 错误: {e}")
        return {}
async def test_tool_connection(
    self,
    tool_id: str,
    user_id: Optional[uuid.UUID] = None,
    workspace_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
    """Check whether a configured tool can be reached.

    MCP tools are tested by opening a client connection and listing the
    server's tools. Other tools are tested through their own
    ``test_connection`` method when they have one.

    Args:
        tool_id: Tool configuration ID (UUID string).
        user_id: Not used by this method.
        workspace_id: Not used by this method.

    Returns:
        A dict with ``success`` and ``message``; MCP successes also carry
        ``details``, and unexpected failures carry ``error``.
    """
    try:
        from app.models.tool_model import ToolConfig, ToolType, MCPToolConfig
        from .mcp.client import MCPClient

        tool_config = self.db.query(ToolConfig).filter(
            ToolConfig.id == uuid.UUID(tool_id)
        ).first()
        if not tool_config:
            return {"success": False, "message": "工具不存在"}

        if tool_config.tool_type == ToolType.MCP.value:
            mcp_config = self.db.query(MCPToolConfig).filter(
                MCPToolConfig.id == tool_config.id
            ).first()
            if not mcp_config:
                return {"success": False, "message": "MCP配置不存在"}

            client = MCPClient(mcp_config.server_url, mcp_config.connection_config or {})
            if not await client.connect():
                return {"success": False, "message": "MCP连接失败"}

            try:
                tools = await client.list_tools()
            except Exception as e:
                # Deliberately not a bare ``except``: task cancellation and
                # interpreter-exit signals must keep propagating.
                logger.warning(f"MCP功能测试失败: {tool_id}, 错误: {e}")
                return {"success": False, "message": "MCP功能测试失败"}
            finally:
                # Release the connection on every path, including cancellation.
                await client.disconnect()

            return {
                "success": True,
                "message": "MCP连接成功",
                "details": {"server_url": mcp_config.server_url, "tools": len(tools)}
            }

        tool = self.registry.get_tool(tool_id)
        if tool and hasattr(tool, 'test_connection'):
            result = tool.test_connection()
            return {"success": result.get("success", False), "message": result.get("message", "")}
        return {"success": True, "message": "工具无需连接测试"}
    except Exception as e:
        return {"success": False, "message": "测试失败", "error": str(e)}

View File

@@ -0,0 +1,375 @@
"""Langchain适配器 - 将工具转换为langchain兼容格式"""
import json
from typing import Dict, Any, List, Optional, Type
from pydantic import BaseModel, Field
from langchain.tools import BaseTool as LangchainBaseTool
from langchain_core.tools import ToolException
from app.core.tools.base import BaseTool, ToolResult, ToolParameter, ParameterType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class LangchainToolWrapper(LangchainBaseTool):
    """Expose an internal ``BaseTool`` through the LangChain tool interface."""

    name: str = Field(..., description="工具名称")
    description: str = Field(..., description="工具描述")
    args_schema: Optional[Type[BaseModel]] = Field(None, description="参数schema")
    return_direct: bool = Field(False, description="是否直接返回结果")

    # The wrapped internal tool; every invocation is delegated to it.
    tool_instance: BaseTool = Field(..., description="内部工具实例")

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, tool_instance: BaseTool, **kwargs):
        """Build the wrapper around an internal tool.

        Args:
            tool_instance: The internal tool to wrap. Its name, description
                and parameter list become the LangChain-facing metadata.
            **kwargs: Extra field values forwarded to the LangChain base class.
        """
        # The argument schema is generated from the tool's parameter list.
        args_schema = LangchainAdapter._create_pydantic_schema(tool_instance.parameters)

        # The tool must be passed under the declared field name
        # ``tool_instance``; a ``_tool_instance`` keyword does not populate
        # that required field.
        super().__init__(
            name=tool_instance.name,
            description=tool_instance.description,
            args_schema=args_schema,
            tool_instance=tool_instance,
            **kwargs
        )

    def _run(
        self,
        run_manager=None,
        **kwargs: Any,
    ) -> str:
        """Synchronous entry point; not supported.

        Raises:
            NotImplementedError: Always. Use ``_arun`` instead.
        """
        raise NotImplementedError("请使用 _arun 方法进行异步调用")

    async def _arun(
        self,
        run_manager=None,
        **kwargs: Any,
    ) -> str:
        """Run the wrapped tool and return its result as a string.

        Args:
            run_manager: Accepted for the LangChain calling convention; unused.
            **kwargs: Arguments passed to the wrapped tool's ``safe_execute``.

        Returns:
            The result formatted by
            ``LangchainAdapter._format_result_for_langchain``.

        Raises:
            ToolException: If execution or formatting raises.
        """
        try:
            result = await self.tool_instance.safe_execute(**kwargs)
            return LangchainAdapter._format_result_for_langchain(result)
        except Exception as e:
            logger.error(f"工具执行失败: {self.name}, 错误: {e}")
            raise ToolException(f"工具执行失败: {str(e)}")
class LangchainAdapter:
    """Convert internal tools into LangChain-compatible wrappers and schemas."""

    @staticmethod
    def convert_tool(tool: BaseTool) -> LangchainToolWrapper:
        """Wrap a single internal tool for LangChain.

        Args:
            tool: Internal tool instance.

        Returns:
            A ``LangchainToolWrapper`` around ``tool``.

        Raises:
            Exception: Whatever the wrapper constructor raises; it is
                logged and re-raised unchanged.
        """
        try:
            wrapper = LangchainToolWrapper(tool_instance=tool)
            logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
            return wrapper
        except Exception as e:
            logger.error(f"工具转换失败: {tool.name}, 错误: {e}")
            raise

    @staticmethod
    def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]:
        """Wrap several tools, skipping the ones that fail to convert.

        Args:
            tools: Internal tool instances.

        Returns:
            Wrappers for every tool that converted successfully.
        """
        converted_tools = []
        for tool in tools:
            try:
                converted_tools.append(LangchainAdapter.convert_tool(tool))
            except Exception as e:
                # Best effort: one broken tool must not block the others.
                logger.error(f"跳过工具转换: {tool.name}, 错误: {e}")
        logger.info(f"批量转换完成: {len(converted_tools)}/{len(tools)} 个工具")
        return converted_tools

    @staticmethod
    def _create_pydantic_schema(parameters: List[ToolParameter]) -> Type[BaseModel]:
        """Build a pydantic model describing a tool's arguments.

        Args:
            parameters: The tool's parameter definitions.

        Returns:
            A dynamically created ``BaseModel`` subclass with one field per
            parameter; unknown fields are forbidden.
        """
        from typing import Literal
        from pydantic import VERSION as pydantic_version

        # pydantic 1.x names the string-pattern constraint ``regex``;
        # 2.x renamed it to ``pattern``.
        pattern_keyword = "regex" if pydantic_version.startswith("1.") else "pattern"

        fields = {}
        annotations = {}
        for param in parameters:
            python_type = LangchainAdapter._get_python_type(param.type)

            field_kwargs = {"description": param.description}
            if param.default is not None:
                field_kwargs["default"] = param.default
            elif not param.required:
                field_kwargs["default"] = None
            else:
                field_kwargs["default"] = ...  # required field

            if param.enum:
                # A Literal type is an exact whitelist. It works for
                # non-string values and needs no regex escaping, so it
                # replaces the range and pattern constraints.
                python_type = Literal[tuple(param.enum)]
            else:
                if param.minimum is not None:
                    field_kwargs["ge"] = param.minimum
                if param.maximum is not None:
                    field_kwargs["le"] = param.maximum
                if param.pattern:
                    field_kwargs[pattern_keyword] = param.pattern

            if not param.required:
                python_type = Optional[python_type]

            fields[param.name] = Field(**field_kwargs)
            annotations[param.name] = python_type

        return type(
            "ToolArgsSchema",
            (BaseModel,),
            {
                "__annotations__": annotations,
                **fields,
                "Config": type("Config", (), {"extra": "forbid"})
            }
        )

    @staticmethod
    def _get_python_type(param_type: ParameterType) -> type:
        """Map a ``ParameterType`` to its Python type.

        Args:
            param_type: Parameter type to translate.

        Returns:
            The matching Python type, or ``str`` for an unmapped value.
        """
        type_mapping = {
            ParameterType.STRING: str,
            ParameterType.INTEGER: int,
            ParameterType.NUMBER: float,
            ParameterType.BOOLEAN: bool,
            ParameterType.ARRAY: list,
            ParameterType.OBJECT: dict
        }
        return type_mapping.get(param_type, str)

    @staticmethod
    def _format_result_for_langchain(result: ToolResult) -> str:
        """Turn a ``ToolResult`` into the string LangChain expects.

        Args:
            result: Tool execution result.

        Returns:
            For failures, a JSON object with the error details. For
            successes, the data itself when it is a string, JSON when it is
            a dict or list, and ``str(data)`` otherwise.
        """
        if not result.success:
            error_info = {
                "success": False,
                "error": result.error,
                "error_code": result.error_code,
                "execution_time": result.execution_time
            }
            return json.dumps(error_info, ensure_ascii=False, indent=2)

        data = result.data
        if isinstance(data, str):
            return data
        if isinstance(data, (dict, list)):
            return json.dumps(data, ensure_ascii=False, indent=2)
        return str(data)

    @staticmethod
    def create_tool_description(tool: BaseTool) -> Dict[str, Any]:
        """Describe a tool as a plain dict.

        Args:
            tool: Tool instance.

        Returns:
            The tool's metadata and parameter definitions.
        """
        return {
            "name": tool.name,
            "description": tool.description,
            "tool_type": tool.tool_type.value,
            "version": tool.version,
            "status": tool.status.value,
            "tags": tool.tags,
            "parameters": [
                {
                    "name": param.name,
                    "type": param.type.value,
                    "description": param.description,
                    "required": param.required,
                    "default": param.default,
                    "enum": param.enum,
                    "minimum": param.minimum,
                    "maximum": param.maximum,
                    "pattern": param.pattern
                }
                for param in tool.parameters
            ],
            "langchain_compatible": True
        }

    @staticmethod
    def validate_langchain_compatibility(tool: BaseTool) -> tuple[bool, List[str]]:
        """Check that a tool has what the LangChain wrapper needs.

        Args:
            tool: Tool instance.

        Returns:
            ``(compatible, issues)``; ``compatible`` is True exactly when
            ``issues`` is empty.
        """
        issues = []

        if not tool.name or not isinstance(tool.name, str):
            issues.append("工具名称必须是非空字符串")

        if not tool.description or not isinstance(tool.description, str):
            issues.append("工具描述必须是非空字符串")

        for param in tool.parameters:
            if not param.name or not isinstance(param.name, str):
                issues.append(f"参数名称无效: {param.name}")
            if param.type not in ParameterType:
                issues.append(f"不支持的参数类型: {param.type}")
            if param.required and param.default is not None:
                issues.append(f"必需参数不应有默认值: {param.name}")

        if not hasattr(tool, 'execute') or not callable(getattr(tool, 'execute')):
            issues.append("工具必须实现execute方法")

        return len(issues) == 0, issues

    @staticmethod
    def get_langchain_tool_schema(tool: BaseTool) -> Dict[str, Any]:
        """Build a function-calling schema for a tool.

        Args:
            tool: Tool instance.

        Returns:
            A dict of the form ``{"type": "function", "function": {...}}``
            whose ``parameters`` entry is a JSON-schema style object.
        """
        properties = {}
        required = []
        for param in tool.parameters:
            prop_schema = {
                "type": LangchainAdapter._get_openapi_type(param.type),
                "description": param.description
            }
            if param.enum:
                prop_schema["enum"] = param.enum
            if param.minimum is not None:
                prop_schema["minimum"] = param.minimum
            if param.maximum is not None:
                prop_schema["maximum"] = param.maximum
            if param.pattern:
                prop_schema["pattern"] = param.pattern
            if param.default is not None:
                prop_schema["default"] = param.default

            properties[param.name] = prop_schema
            if param.required:
                required.append(param.name)

        return {
            "type": "function",
            "function": {
                "name": tool.name,
                "description": tool.description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required
                }
            }
        }

    @staticmethod
    def _get_openapi_type(param_type: ParameterType) -> str:
        """Map a ``ParameterType`` to its OpenAPI type name.

        Args:
            param_type: Parameter type to translate.

        Returns:
            The OpenAPI type string, or ``"string"`` for an unmapped value.
        """
        type_mapping = {
            ParameterType.STRING: "string",
            ParameterType.INTEGER: "integer",
            ParameterType.NUMBER: "number",
            ParameterType.BOOLEAN: "boolean",
            ParameterType.ARRAY: "array",
            ParameterType.OBJECT: "object"
        }
        return type_mapping.get(param_type, "string")

View File

@@ -0,0 +1,12 @@
"""MCP工具模块"""
from .base import MCPTool
from .client import MCPClient, MCPConnectionPool
from .service_manager import MCPServiceManager
__all__ = [
"MCPTool",
"MCPClient",
"MCPConnectionPool",
"MCPServiceManager"
]

View File

@@ -0,0 +1,258 @@
"""MCP工具基类"""
import time
from typing import Dict, Any, List
import aiohttp
from app.models.tool_model import ToolType
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class MCPTool(BaseTool):
    """Tool that forwards calls to an MCP (Model Context Protocol) server over HTTP."""

    def __init__(self, tool_id: str, config: Dict[str, Any]):
        """Initialise the tool from its stored configuration.

        Args:
            tool_id: Tool ID.
            config: Tool configuration; ``server_url``,
                ``connection_config`` and ``available_tools`` are read.
        """
        super().__init__(tool_id, config)
        self.server_url = config.get("server_url", "")
        self.connection_config = config.get("connection_config", {})
        self.available_tools = config.get("available_tools", [])
        self._client = None
        self._connected = False

    def _endpoint(self, path: str) -> str:
        """Join ``path`` onto the server URL without producing a double slash."""
        return f"{str(self.server_url).rstrip('/')}/{path}"

    @property
    def name(self) -> str:
        """Tool name, derived from the first 8 characters of the tool ID."""
        return f"mcp_tool_{self.tool_id[:8]}"

    @property
    def description(self) -> str:
        """Human-readable description that includes the server URL."""
        return f"MCP工具 - 连接到 {self.server_url}"

    @property
    def tool_type(self) -> ToolType:
        """Tool type; always ``ToolType.MCP``."""
        return ToolType.MCP

    @property
    def parameters(self) -> List[ToolParameter]:
        """Parameter definitions accepted by ``execute``."""
        params = []

        # A tool selector is only needed when there is more than one tool.
        if len(self.available_tools) > 1:
            params.append(ToolParameter(
                name="tool_name",
                type=ParameterType.STRING,
                description="要调用的MCP工具名称",
                required=True,
                enum=self.available_tools
            ))

        params.extend([
            ToolParameter(
                name="arguments",
                type=ParameterType.OBJECT,
                description="工具参数JSON对象",
                required=False,
                default={}
            ),
            ToolParameter(
                name="timeout",
                type=ParameterType.INTEGER,
                description="超时时间(秒)",
                required=False,
                default=30,
                minimum=1,
                maximum=300
            )
        ])
        return params

    async def execute(self, **kwargs) -> ToolResult:
        """Call one tool on the MCP server.

        Args:
            **kwargs: ``tool_name`` (optional when exactly one tool is
                available), ``arguments`` (dict, defaults to ``{}``) and
                ``timeout`` in seconds (defaults to 30).

        Returns:
            A success result carrying the server's ``result`` payload, or
            an error result with code ``MCP_ERROR``. Exceptions are not
            propagated.
        """
        start_time = time.time()
        try:
            # The return value of connect() is not checked, so the request
            # below is attempted even when the connection probe failed.
            if not self._connected:
                await self.connect()

            tool_name = kwargs.get("tool_name")
            if not tool_name and len(self.available_tools) == 1:
                tool_name = self.available_tools[0]
            if not tool_name:
                raise ValueError("必须指定要调用的MCP工具名称")
            if tool_name not in self.available_tools:
                raise ValueError(f"MCP工具不存在: {tool_name}")

            # Both parameters are optional, so a caller may pass an explicit
            # None; treat that the same as "not provided".
            arguments = kwargs.get("arguments") or {}
            timeout = kwargs.get("timeout") or 30

            result = await self._call_mcp_tool(tool_name, arguments, timeout)
            return ToolResult.success_result(
                data=result,
                execution_time=time.time() - start_time
            )
        except Exception as e:
            return ToolResult.error_result(
                error=str(e),
                error_code="MCP_ERROR",
                execution_time=time.time() - start_time
            )

    async def connect(self) -> bool:
        """Probe the server's ``/info`` endpoint and mark the tool connected.

        Returns:
            True when the server answered with HTTP 200, False otherwise.
        """
        try:
            timeout = aiohttp.ClientTimeout(total=10)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.get(self._endpoint("info")) as response:
                    if response.status != 200:
                        raise Exception(f"服务器响应错误: {response.status}")
                    server_info = await response.json()
                    # Keep the configured list when the server does not
                    # report one, so it is not wiped to empty.
                    self.available_tools = server_info.get("tools", self.available_tools)
                    self._connected = True
                    logger.info(f"MCP服务器连接成功: {self.server_url}")
                    return True
        except Exception as e:
            logger.error(f"MCP服务器连接失败: {self.server_url}, 错误: {e}")
            self._connected = False
            return False

    async def disconnect(self) -> bool:
        """Mark the tool disconnected and drop any client reference.

        Returns:
            True on success, False if an error occurred.
        """
        try:
            if self._client:
                self._client = None
            self._connected = False
            logger.info(f"MCP服务器连接已断开: {self.server_url}")
            return True
        except Exception as e:
            logger.error(f"断开MCP服务器连接失败: {e}")
            return False

    def get_health_status(self) -> Dict[str, Any]:
        """Return the current connection flag, URL, tool list and a timestamp."""
        return {
            "connected": self._connected,
            "server_url": self.server_url,
            "available_tools": self.available_tools,
            "last_check": time.time()
        }

    async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any:
        """POST a JSON-RPC ``tools/call`` request to the server.

        Args:
            tool_name: Name of the remote tool.
            arguments: Arguments for the remote tool.
            timeout: Total request timeout in seconds.

        Returns:
            The ``result`` member of the response, or ``{}`` when absent.

        Raises:
            Exception: On a non-200 HTTP status or a JSON-RPC error response.
        """
        request_data = {
            "jsonrpc": "2.0",
            "id": f"req_{int(time.time() * 1000)}",
            "method": "tools/call",
            "params": {
                "name": tool_name,
                "arguments": arguments
            }
        }

        client_timeout = aiohttp.ClientTimeout(total=timeout)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.post(
                self._endpoint("mcp"),
                json=request_data,
                headers={"Content-Type": "application/json"}
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(f"MCP请求失败 {response.status}: {error_text}")

                result = await response.json()
                if "error" in result:
                    error = result["error"]
                    # Tolerate servers that send the error as a bare value.
                    message = error.get('message', '未知错误') if isinstance(error, dict) else str(error)
                    raise Exception(f"MCP工具错误: {message}")
                return result.get("result", {})

    async def list_available_tools(self) -> List[Dict[str, Any]]:
        """Fetch the server's tool list and refresh ``available_tools``.

        Returns:
            The tool descriptors reported by the server, or an empty list
            when the request fails or carries no result.
        """
        try:
            if not self._connected:
                await self.connect()

            request_data = {
                "jsonrpc": "2.0",
                "id": f"req_{int(time.time() * 1000)}",
                "method": "tools/list"
            }
            timeout = aiohttp.ClientTimeout(total=10)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.post(
                    self._endpoint("mcp"),
                    json=request_data,
                    headers={"Content-Type": "application/json"}
                ) as response:
                    if response.status == 200:
                        result = await response.json()
                        if "result" in result:
                            tools = result["result"].get("tools", [])
                            # Descriptors without a name cannot be selected.
                            self.available_tools = [
                                tool.get("name") for tool in tools if tool.get("name")
                            ]
                            return tools
            return []
        except Exception as e:
            logger.error(f"获取MCP工具列表失败: {e}")
            return []

    def test_connection(self) -> Dict[str, Any]:
        """Report whether the tool is configured, without contacting the server.

        Returns:
            A dict whose ``success`` is True when a server URL is set.
        """
        try:
            return {
                "success": bool(self.server_url),
                "server_url": self.server_url,
                "connected": self._connected,
                "available_tools_count": len(self.available_tools),
                "message": "MCP配置有效" if self.server_url else "缺少服务器URL配置"
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }

View File

@@ -0,0 +1,626 @@
"""MCP客户端 - Model Context Protocol客户端实现"""
import asyncio
import json
import time
from typing import Dict, Any, List, Optional, Callable
from urllib.parse import urlparse
import aiohttp
import websockets
from websockets.exceptions import ConnectionClosed
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class MCPConnectionError(Exception):
    """MCP connection error."""


class MCPProtocolError(Exception):
    """MCP protocol error."""
class MCPClient:
    """MCP client that talks JSON-RPC over HTTP or WebSocket.

    The transport is chosen from the URL scheme: ``ws``/``wss`` selects
    WebSocket and anything else selects HTTP.
    """

    def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
        """Initialise the client without opening a connection.

        Args:
            server_url: MCP server URL.
            connection_config: Optional settings; the keys read are
                ``max_connections``, ``timeout``, ``retry_attempts``,
                ``retry_delay``, ``health_check_interval`` and ``headers``.
        """
        self.server_url = server_url
        self.connection_config = connection_config or {}

        parsed_url = urlparse(server_url)
        self.connection_type = "websocket" if parsed_url.scheme in ["ws", "wss"] else "http"

        # Connection state
        self._connected = False
        self._websocket = None
        self._session = None
        # Held so the listener task is not garbage-collected while running.
        self._listener_task = None

        # Request bookkeeping: request id -> future awaiting the response
        self._request_id = 0
        self._pending_requests: Dict[str, asyncio.Future] = {}

        # Connection settings
        self.max_connections = self.connection_config.get("max_connections", 10)
        self.connection_timeout = self.connection_config.get("timeout", 30)
        self.retry_attempts = self.connection_config.get("retry_attempts", 3)
        self.retry_delay = self.connection_config.get("retry_delay", 1)

        # Health check
        self.health_check_interval = self.connection_config.get("health_check_interval", 60)
        self._health_check_task = None
        self._last_health_check = None

        # Event callbacks
        self._on_connect_callbacks: List[Callable] = []
        self._on_disconnect_callbacks: List[Callable] = []
        self._on_error_callbacks: List[Callable] = []

    async def connect(self) -> bool:
        """Connect to the MCP server.

        Returns:
            True when connected (or already connected), False on failure.
        """
        try:
            if self._connected:
                return True

            logger.info(f"连接MCP服务器: {self.server_url}")
            if self.connection_type == "websocket":
                success = await self._connect_websocket()
            else:
                success = await self._connect_http()

            if success:
                self._connected = True
                await self._start_health_check()
                await self._notify_connect_callbacks()
                logger.info(f"MCP服务器连接成功: {self.server_url}")
            return success
        except Exception as e:
            logger.error(f"连接MCP服务器失败: {self.server_url}, 错误: {e}")
            await self._notify_error_callbacks(e)
            return False

    async def disconnect(self) -> bool:
        """Disconnect from the MCP server.

        Stops the health check, cancels pending requests, closes the
        transport and notifies the disconnect callbacks.

        Returns:
            True on success (or when not connected), False on error.
        """
        try:
            if not self._connected:
                return True

            logger.info(f"断开MCP服务器连接: {self.server_url}")
            await self._stop_health_check()

            for future in self._pending_requests.values():
                if not future.done():
                    future.cancel()
            self._pending_requests.clear()

            if self.connection_type == "websocket" and self._websocket:
                await self._websocket.close()
                self._websocket = None
            elif self._session:
                await self._session.close()
                self._session = None

            self._connected = False
            await self._notify_disconnect_callbacks()
            logger.info(f"MCP服务器连接已断开: {self.server_url}")
            return True
        except Exception as e:
            logger.error(f"断开MCP服务器连接失败: {e}")
            return False

    async def _connect_websocket(self) -> bool:
        """Open the WebSocket and perform the ``initialize`` handshake.

        Returns:
            True when the handshake succeeded, False otherwise.
        """
        try:
            extra_headers = self.connection_config.get("headers", {})
            # NOTE(review): ``timeout`` here may not be a connect timeout in
            # the installed websockets version - confirm against its docs.
            self._websocket = await websockets.connect(
                self.server_url,
                extra_headers=extra_headers,
                timeout=self.connection_timeout
            )

            init_message = {
                "jsonrpc": "2.0",
                "id": self._get_next_request_id(),
                "method": "initialize",
                "params": {
                    "protocolVersion": "2024-11-05",
                    "capabilities": {
                        "tools": {}
                    },
                    "clientInfo": {
                        "name": "ToolManagementSystem",
                        "version": "1.0.0"
                    }
                }
            }
            await self._websocket.send(json.dumps(init_message))

            # The handshake reads from the socket directly, so the
            # background listener must not be running yet: two coroutines
            # must not wait on recv() at the same time.
            response = await asyncio.wait_for(
                self._websocket.recv(),
                timeout=self.connection_timeout
            )
            init_response = json.loads(response)
            if "error" in init_response:
                raise MCPProtocolError(f"初始化失败: {init_response['error']}")

            # From here on the listener is the only reader of the socket.
            self._listener_task = asyncio.create_task(self._websocket_message_handler())
            return True
        except Exception as e:
            logger.error(f"WebSocket连接失败: {e}")
            # Do not leave a half-initialised socket open.
            if self._websocket is not None:
                try:
                    await self._websocket.close()
                except Exception as close_error:
                    logger.debug(f"WebSocket close after failed connect raised: {close_error}")
                self._websocket = None
            return False

    async def _close_session(self):
        """Close and forget the HTTP session, if one exists."""
        if self._session:
            await self._session.close()
            self._session = None

    async def _connect_http(self) -> bool:
        """Create the HTTP session and probe the server.

        ``<server_url>/health`` is tried first; if it does not return 200,
        the server URL itself is requested and any status below 400 counts
        as reachable.

        Returns:
            True when the server is reachable, False otherwise.
        """
        try:
            timeout = aiohttp.ClientTimeout(total=self.connection_timeout)
            headers = self.connection_config.get("headers", {})
            self._session = aiohttp.ClientSession(
                timeout=timeout,
                headers=headers
            )

            test_url = f"{self.server_url}/health" if not self.server_url.endswith('/') else f"{self.server_url}health"
            async with self._session.get(test_url) as response:
                reachable = response.status == 200

            if not reachable:
                async with self._session.get(self.server_url) as root_response:
                    reachable = root_response.status < 400

            if not reachable:
                # disconnect() returns early while not connected, so the
                # session has to be released here or it would leak.
                await self._close_session()
            return reachable
        except Exception as e:
            logger.error(f"HTTP连接失败: {e}")
            await self._close_session()
            return False

    async def _websocket_message_handler(self):
        """Read WebSocket messages until the connection closes.

        When the loop ends, the client is marked disconnected and the
        disconnect callbacks are invoked.
        """
        try:
            while self._websocket and not self._websocket.closed:
                try:
                    message = await self._websocket.recv()
                    await self._handle_message(json.loads(message))
                except ConnectionClosed:
                    break
                except json.JSONDecodeError as e:
                    logger.error(f"解析WebSocket消息失败: {e}")
                except Exception as e:
                    logger.error(f"处理WebSocket消息失败: {e}")
        except Exception as e:
            logger.error(f"WebSocket消息处理器异常: {e}")
        finally:
            self._connected = False
            await self._notify_disconnect_callbacks()

    async def _handle_message(self, message: Dict[str, Any]):
        """Dispatch one incoming message.

        A message with an ``id`` resolves the matching pending request; a
        message with only a ``method`` is treated as a notification.
        """
        try:
            if "id" in message:
                request_id = str(message["id"])
                if request_id in self._pending_requests:
                    future = self._pending_requests.pop(request_id)
                    if not future.done():
                        future.set_result(message)
            elif "method" in message:
                await self._handle_notification(message)
        except Exception as e:
            logger.error(f"处理消息失败: {e}")

    async def _handle_notification(self, message: Dict[str, Any]):
        """Log a server notification; no other handling is implemented."""
        method = message.get("method")
        params = message.get("params", {})
        logger.debug(f"收到MCP通知: {method}, 参数: {params}")

    async def call_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int = 30) -> Dict[str, Any]:
        """Call a tool on the MCP server.

        Args:
            tool_name: Tool name.
            arguments: Tool arguments.
            timeout: Timeout in seconds.

        Returns:
            The ``result`` member of the response, or ``{}`` when absent.

        Raises:
            MCPConnectionError: If the client is not connected or the
                transport fails.
            MCPProtocolError: If the server returns an error or the call
                times out.
        """
        if not self._connected:
            raise MCPConnectionError("MCP客户端未连接")

        request_data = {
            "jsonrpc": "2.0",
            "id": self._get_next_request_id(),
            "method": "tools/call",
            "params": {
                "name": tool_name,
                "arguments": arguments
            }
        }
        try:
            response = await self._send_request(request_data, timeout)
            if "error" in response:
                error = response["error"]
                raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}")
            return response.get("result", {})
        except asyncio.TimeoutError:
            raise MCPProtocolError(f"工具调用超时: {tool_name}")

    async def list_tools(self, timeout: int = 10) -> List[Dict[str, Any]]:
        """Fetch the list of tools offered by the server.

        Args:
            timeout: Timeout in seconds.

        Returns:
            The ``tools`` list from the response, or ``[]`` when absent.

        Raises:
            MCPConnectionError: If the client is not connected or the
                transport fails.
            MCPProtocolError: If the server returns an error or the request
                times out.
        """
        if not self._connected:
            raise MCPConnectionError("MCP客户端未连接")

        request_data = {
            "jsonrpc": "2.0",
            "id": self._get_next_request_id(),
            "method": "tools/list"
        }
        try:
            response = await self._send_request(request_data, timeout)
            # A successful response has no "error" member, so it must be
            # looked up with .get() rather than indexed.
            error = response.get("error")
            if error is not None:
                raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}")
            result = response.get("result", {})
            return result.get("tools", [])
        except asyncio.TimeoutError:
            raise MCPProtocolError("获取工具列表超时")

    async def _send_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
        """Send a request over the active transport and wait for the response.

        Args:
            request_data: JSON-RPC request.
            timeout: Timeout in seconds.

        Returns:
            The decoded response.
        """
        request_id = str(request_data["id"])
        if self.connection_type == "websocket":
            return await self._send_websocket_request(request_data, request_id, timeout)
        return await self._send_http_request(request_data, timeout)

    async def _send_websocket_request(self, request_data: Dict[str, Any], request_id: str, timeout: int) -> Dict[str, Any]:
        """Send a request over the WebSocket and await the matching response.

        Raises:
            MCPConnectionError: If the socket is closed or sending fails.
            asyncio.TimeoutError: If no response arrives within ``timeout``.
        """
        if not self._websocket or self._websocket.closed:
            raise MCPConnectionError("WebSocket连接已断开")

        # The listener resolves this future when a response with the same
        # id arrives.
        future = asyncio.get_running_loop().create_future()
        self._pending_requests[request_id] = future
        try:
            await self._websocket.send(json.dumps(request_data))
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError:
            raise
        except Exception as e:
            raise MCPConnectionError(f"发送WebSocket请求失败: {e}") from e
        finally:
            # Also drops the entry on cancellation, so no stale future stays
            # registered.
            self._pending_requests.pop(request_id, None)

    async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
        """POST a request to ``<server_url>/mcp`` and return the JSON response.

        Raises:
            MCPConnectionError: If there is no session, the status is not
                200, or the HTTP client reports an error.
        """
        if not self._session:
            raise MCPConnectionError("HTTP会话未建立")

        try:
            url = f"{self.server_url}/mcp" if not self.server_url.endswith('/') else f"{self.server_url}mcp"
            async with self._session.post(
                url,
                json=request_data,
                timeout=aiohttp.ClientTimeout(total=timeout)
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
                return await response.json()
        except aiohttp.ClientError as e:
            raise MCPConnectionError(f"HTTP请求失败: {e}") from e

    async def health_check(self) -> Dict[str, Any]:
        """Send a ``ping`` request and report the outcome.

        Returns:
            A dict with ``healthy`` and ``timestamp``; on success also
            ``response_time`` and ``server_info``, on failure ``error``.
        """
        try:
            if not self._connected:
                return {
                    "healthy": False,
                    "error": "未连接",
                    "timestamp": time.time()
                }

            request_data = {
                "jsonrpc": "2.0",
                "id": self._get_next_request_id(),
                "method": "ping"
            }
            start_time = time.time()
            response = await self._send_request(request_data, timeout=5)
            response_time = time.time() - start_time
            self._last_health_check = time.time()
            return {
                "healthy": True,
                "response_time": response_time,
                "timestamp": self._last_health_check,
                "server_info": response.get("result", {})
            }
        except Exception as e:
            return {
                "healthy": False,
                "error": str(e),
                "timestamp": time.time()
            }

    async def _start_health_check(self):
        """Start the periodic health check when the interval is positive."""
        if self.health_check_interval > 0:
            self._health_check_task = asyncio.create_task(self._health_check_loop())

    async def _stop_health_check(self):
        """Cancel the health check task and wait for it to finish."""
        if self._health_check_task:
            self._health_check_task.cancel()
            try:
                await self._health_check_task
            except asyncio.CancelledError:
                pass
            self._health_check_task = None

    async def _health_check_loop(self):
        """Run ``health_check`` every ``health_check_interval`` seconds while connected."""
        try:
            while self._connected:
                await asyncio.sleep(self.health_check_interval)
                if self._connected:
                    health_status = await self.health_check()
                    if not health_status["healthy"]:
                        logger.warning(f"MCP服务器健康检查失败: {health_status.get('error')}")
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"健康检查循环异常: {e}")

    def _get_next_request_id(self) -> str:
        """Return a new request id built from a counter and the current time."""
        self._request_id += 1
        return f"req_{self._request_id}_{int(time.time() * 1000)}"

    # Event callback registration

    def on_connect(self, callback: Callable):
        """Register a callback invoked after a successful connect."""
        self._on_connect_callbacks.append(callback)

    def on_disconnect(self, callback: Callable):
        """Register a callback invoked when the connection ends."""
        self._on_disconnect_callbacks.append(callback)

    def on_error(self, callback: Callable):
        """Register a callback invoked with the exception when connecting fails."""
        self._on_error_callbacks.append(callback)

    async def _notify_connect_callbacks(self):
        """Invoke the connect callbacks; callback errors are logged, not raised."""
        for callback in self._on_connect_callbacks:
            try:
                if asyncio.iscoroutinefunction(callback):
                    await callback()
                else:
                    callback()
            except Exception as e:
                logger.error(f"连接回调执行失败: {e}")

    async def _notify_disconnect_callbacks(self):
        """Invoke the disconnect callbacks; callback errors are logged, not raised."""
        for callback in self._on_disconnect_callbacks:
            try:
                if asyncio.iscoroutinefunction(callback):
                    await callback()
                else:
                    callback()
            except Exception as e:
                logger.error(f"断开连接回调执行失败: {e}")

    async def _notify_error_callbacks(self, error: Exception):
        """Invoke the error callbacks with ``error``; callback errors are logged."""
        for callback in self._on_error_callbacks:
            try:
                if asyncio.iscoroutinefunction(callback):
                    await callback(error)
                else:
                    callback(error)
            except Exception as e:
                logger.error(f"错误回调执行失败: {e}")

    @property
    def is_connected(self) -> bool:
        """Whether the client currently considers itself connected."""
        return self._connected

    @property
    def last_health_check(self) -> Optional[float]:
        """Timestamp of the last successful health check, if any."""
        return self._last_health_check

    def get_connection_info(self) -> Dict[str, Any]:
        """Return a snapshot of the connection state and configuration."""
        return {
            "server_url": self.server_url,
            "connection_type": self.connection_type,
            "connected": self._connected,
            "last_health_check": self._last_health_check,
            "pending_requests": len(self._pending_requests),
            "config": self.connection_config
        }

    async def __aenter__(self):
        """Connect on entering ``async with`` and return the client."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Disconnect on leaving ``async with``."""
        await self.disconnect()
class MCPConnectionPool:
    """Pool of ``MCPClient`` instances keyed by server URL."""

    def __init__(self, max_connections: int = 10):
        """Create an empty pool.

        Args:
            max_connections: Maximum number of clients kept in the pool.
        """
        self.max_connections = max_connections
        self._clients: Dict[str, MCPClient] = {}
        self._lock = asyncio.Lock()

    async def get_client(self, server_url: str, connection_config: Dict[str, Any] = None) -> MCPClient:
        """Return a connected client for ``server_url``, creating one if needed.

        Args:
            server_url: Server URL.
            connection_config: Connection settings used for a new client.

        Returns:
            A connected ``MCPClient``.

        Raises:
            MCPConnectionError: If a new client cannot connect.
        """
        async with self._lock:
            cached = self._clients.get(server_url)
            if cached is not None:
                # Reuse a live client, or try to revive a dropped one.
                if cached.is_connected or await cached.connect():
                    return cached
                # Reconnect failed: forget the stale client and build a new one.
                del self._clients[server_url]

            # At capacity: evict the entry that was inserted first.
            if len(self._clients) >= self.max_connections:
                oldest_url = next(iter(self._clients))
                await self._clients[oldest_url].disconnect()
                del self._clients[oldest_url]

            fresh = MCPClient(server_url, connection_config)
            if not await fresh.connect():
                raise MCPConnectionError(f"无法连接到MCP服务器: {server_url}")
            self._clients[server_url] = fresh
            return fresh

    async def disconnect_all(self):
        """Disconnect every pooled client and empty the pool."""
        async with self._lock:
            for pooled in self._clients.values():
                await pooled.disconnect()
            self._clients.clear()

    def get_pool_status(self) -> Dict[str, Any]:
        """Return pool size, capacity and per-URL connection info."""
        connections = {}
        for url, pooled in self._clients.items():
            connections[url] = pooled.get_connection_info()
        return {
            "total_connections": len(self._clients),
            "max_connections": self.max_connections,
            "connections": connections
        }

View File

@@ -0,0 +1,604 @@
"""MCP服务管理器 - 管理MCP服务的注册、更新、删除和状态监控"""
import asyncio
import time
import uuid
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime
from sqlalchemy.orm import Session
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType
from app.core.logging_config import get_business_logger
from .client import MCPClient, MCPConnectionPool
logger = get_business_logger()
class MCPServiceManager:
"""MCP服务管理器 - 管理MCP服务的生命周期"""
def __init__(self, db: Session):
    """Initialise the MCP service manager.

    Args:
        db: Database session.
    """
    self.db = db
    self.connection_pool = MCPConnectionPool(max_connections=20)

    # In-memory service state
    self._services: Dict[str, Dict[str, Any]] = {}  # service_id -> service_info
    self._monitoring_tasks: Dict[str, asyncio.Task] = {}  # service_id -> monitoring_task

    # Tunables
    self.health_check_interval = 60  # seconds between health checks
    self.max_retry_attempts = 3  # maximum number of retries
    self.retry_delay = 5  # seconds between retries

    # Lifecycle state
    self._running = False
    self._manager_task = None
async def start(self):
    """Start the service manager; does nothing if it is already running.

    Loads the existing services, then launches the management loop as a
    background task.
    """
    if not self._running:
        self._running = True
        logger.info("MCP服务管理器启动")

        # Existing services are loaded before the management loop starts.
        await self._load_existing_services()
        self._manager_task = asyncio.create_task(self._management_loop())
async def stop(self):
    """Stop the service manager; does nothing if it is not running.

    Cancels the management task and every monitoring task, waits for them
    to finish, then disconnects all pooled connections.
    """
    if not self._running:
        return

    self._running = False
    logger.info("MCP服务管理器停止")

    # Stop the management task and wait for the cancellation to land.
    manager_task = self._manager_task
    if manager_task:
        manager_task.cancel()
        try:
            await manager_task
        except asyncio.CancelledError:
            pass

    # Stop all monitoring tasks; gather collects their cancellations.
    monitors = self._monitoring_tasks
    for monitor in monitors.values():
        monitor.cancel()
    if monitors:
        await asyncio.gather(*monitors.values(), return_exceptions=True)
    monitors.clear()

    # Finally drop every pooled connection.
    await self.connection_pool.disconnect_all()
async def register_service(
self,
server_url: str,
connection_config: Dict[str, Any],
tenant_id: uuid.UUID,
service_name: str = None
) -> Tuple[bool, str, Optional[str]]:
"""注册MCP服务
Args:
server_url: 服务器URL
connection_config: 连接配置
tenant_id: 租户ID
service_name: 服务名称(可选)
Returns:
(是否成功, 服务ID或错误信息, 错误详情)
"""
try:
# 检查服务是否已存在
existing_service = self.db.query(MCPToolConfig).filter(
MCPToolConfig.server_url == server_url
).first()
if existing_service:
return False, "服务已存在", f"URL {server_url} 已被注册"
# 测试连接
try:
client = MCPClient(server_url, connection_config)
if not await client.connect():
return False, "连接测试失败", "无法连接到MCP服务器"
# 获取可用工具
available_tools = await client.list_tools()
tool_names = [tool.get("name") for tool in available_tools if tool.get("name")]
await client.disconnect()
except Exception as e:
return False, "连接测试失败", str(e)
# 创建工具配置
if not service_name:
service_name = f"mcp_service_{server_url.split('/')[-1]}"
tool_config = ToolConfig(
name=service_name,
description=f"MCP服务 - {server_url}",
tool_type=ToolType.MCP.value,
tenant_id=tenant_id,
version="1.0.0",
config_data={
"server_url": server_url,
"connection_config": connection_config
}
)
self.db.add(tool_config)
self.db.flush()
# 创建MCP特定配置
mcp_config = MCPToolConfig(
id=tool_config.id,
server_url=server_url,
connection_config=connection_config,
available_tools=tool_names,
health_status="healthy",
last_health_check=datetime.utcnow()
)
self.db.add(mcp_config)
self.db.commit()
service_id = str(tool_config.id)
# 添加到内存管理
self._services[service_id] = {
"id": service_id,
"server_url": server_url,
"connection_config": connection_config,
"tenant_id": tenant_id,
"available_tools": tool_names,
"status": "healthy",
"last_health_check": time.time(),
"retry_count": 0,
"created_at": time.time()
}
# 启动监控
await self._start_service_monitoring(service_id)
logger.info(f"MCP服务注册成功: {service_id} ({server_url})")
return True, service_id, None
except Exception as e:
self.db.rollback()
logger.error(f"注册MCP服务失败: {server_url}, 错误: {e}")
return False, "注册失败", str(e)
async def unregister_service(self, service_id: str) -> Tuple[bool, str]:
"""注销MCP服务
Args:
service_id: 服务ID
Returns:
(是否成功, 错误信息)
"""
try:
# 从数据库删除
tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
if not tool_config:
return False, "服务不存在"
self.db.delete(tool_config)
self.db.commit()
# 停止监控
await self._stop_service_monitoring(service_id)
# 从内存移除
if service_id in self._services:
del self._services[service_id]
logger.info(f"MCP服务注销成功: {service_id}")
return True, ""
except Exception as e:
self.db.rollback()
logger.error(f"注销MCP服务失败: {service_id}, 错误: {e}")
return False, str(e)
async def update_service(
self,
service_id: str,
connection_config: Dict[str, Any] = None,
enabled: bool = None
) -> Tuple[bool, str]:
"""更新MCP服务配置
Args:
service_id: 服务ID
connection_config: 新的连接配置
enabled: 是否启用
Returns:
(是否成功, 错误信息)
"""
try:
# 更新数据库
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == uuid.UUID(service_id)
).first()
if not mcp_config:
return False, "服务不存在"
tool_config = mcp_config.base_config
if connection_config is not None:
mcp_config.connection_config = connection_config
tool_config.config_data["connection_config"] = connection_config
if enabled is not None:
tool_config.is_enabled = enabled
self.db.commit()
# 更新内存状态
if service_id in self._services:
if connection_config is not None:
self._services[service_id]["connection_config"] = connection_config
# 如果配置有变化,重启监控
if connection_config is not None:
await self._restart_service_monitoring(service_id)
logger.info(f"MCP服务更新成功: {service_id}")
return True, ""
except Exception as e:
self.db.rollback()
logger.error(f"更新MCP服务失败: {service_id}, 错误: {e}")
return False, str(e)
async def get_service_status(self, service_id: str) -> Optional[Dict[str, Any]]:
"""获取服务状态
Args:
service_id: 服务ID
Returns:
服务状态信息
"""
if service_id not in self._services:
return None
service_info = self._services[service_id].copy()
# 添加实时健康检查
try:
client = await self.connection_pool.get_client(
service_info["server_url"],
service_info["connection_config"]
)
health_status = await client.health_check()
service_info["real_time_health"] = health_status
except Exception as e:
service_info["real_time_health"] = {
"healthy": False,
"error": str(e),
"timestamp": time.time()
}
return service_info
async def list_services(self, tenant_id: uuid.UUID = None) -> List[Dict[str, Any]]:
"""列出所有服务
Args:
tenant_id: 租户ID过滤
Returns:
服务列表
"""
services = []
for service_id, service_info in self._services.items():
if tenant_id and service_info["tenant_id"] != tenant_id:
continue
services.append(service_info.copy())
return services
async def get_service_tools(self, service_id: str) -> List[Dict[str, Any]]:
"""获取服务的可用工具
Args:
service_id: 服务ID
Returns:
工具列表
"""
if service_id not in self._services:
return []
service_info = self._services[service_id]
try:
client = await self.connection_pool.get_client(
service_info["server_url"],
service_info["connection_config"]
)
tools = await client.list_tools()
# 更新缓存的工具列表
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
service_info["available_tools"] = tool_names
# 更新数据库
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == uuid.UUID(service_id)
).first()
if mcp_config:
mcp_config.available_tools = tool_names
self.db.commit()
return tools
except Exception as e:
logger.error(f"获取服务工具失败: {service_id}, 错误: {e}")
return []
async def call_service_tool(
self,
service_id: str,
tool_name: str,
arguments: Dict[str, Any],
timeout: int = 30
) -> Dict[str, Any]:
"""调用服务工具
Args:
service_id: 服务ID
tool_name: 工具名称
arguments: 工具参数
timeout: 超时时间
Returns:
执行结果
"""
if service_id not in self._services:
raise ValueError(f"服务不存在: {service_id}")
service_info = self._services[service_id]
try:
client = await self.connection_pool.get_client(
service_info["server_url"],
service_info["connection_config"]
)
result = await client.call_tool(tool_name, arguments, timeout)
# 更新服务状态为健康
service_info["status"] = "healthy"
service_info["last_health_check"] = time.time()
service_info["retry_count"] = 0
return result
except Exception as e:
# 更新服务状态为错误
service_info["status"] = "error"
service_info["last_error"] = str(e)
service_info["retry_count"] += 1
logger.error(f"调用服务工具失败: {service_id}/{tool_name}, 错误: {e}")
raise
async def _load_existing_services(self):
"""加载现有服务"""
try:
mcp_configs = self.db.query(MCPToolConfig).join(ToolConfig).filter(
ToolConfig.is_enabled == True
).all()
for mcp_config in mcp_configs:
tool_config = mcp_config.base_config
service_id = str(mcp_config.id)
self._services[service_id] = {
"id": service_id,
"server_url": mcp_config.server_url,
"connection_config": mcp_config.connection_config or {},
"tenant_id": tool_config.tenant_id,
"available_tools": mcp_config.available_tools or [],
"status": mcp_config.health_status or "unknown",
"last_health_check": mcp_config.last_health_check.timestamp() if mcp_config.last_health_check else 0,
"retry_count": 0,
"created_at": tool_config.created_at.timestamp()
}
# 启动监控
await self._start_service_monitoring(service_id)
logger.info(f"加载了 {len(mcp_configs)} 个MCP服务")
except Exception as e:
logger.error(f"加载现有服务失败: {e}")
async def _start_service_monitoring(self, service_id: str):
"""启动服务监控"""
if service_id in self._monitoring_tasks:
return
task = asyncio.create_task(self._monitor_service(service_id))
self._monitoring_tasks[service_id] = task
async def _stop_service_monitoring(self, service_id: str):
"""停止服务监控"""
if service_id in self._monitoring_tasks:
task = self._monitoring_tasks.pop(service_id)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
async def _restart_service_monitoring(self, service_id: str):
"""重启服务监控"""
await self._stop_service_monitoring(service_id)
await self._start_service_monitoring(service_id)
async def _monitor_service(self, service_id: str):
"""监控单个服务"""
try:
while self._running and service_id in self._services:
service_info = self._services[service_id]
try:
# 执行健康检查
client = await self.connection_pool.get_client(
service_info["server_url"],
service_info["connection_config"]
)
health_status = await client.health_check()
if health_status["healthy"]:
# 服务健康
service_info["status"] = "healthy"
service_info["retry_count"] = 0
# 更新工具列表
try:
tools = await client.list_tools()
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
service_info["available_tools"] = tool_names
except Exception as e:
logger.warning(f"更新工具列表失败: {service_id}, 错误: {e}")
else:
# 服务不健康
service_info["status"] = "unhealthy"
service_info["last_error"] = health_status.get("error", "健康检查失败")
service_info["retry_count"] += 1
service_info["last_health_check"] = time.time()
# 更新数据库
await self._update_service_health_in_db(service_id, health_status)
except Exception as e:
# 监控异常
service_info["status"] = "error"
service_info["last_error"] = str(e)
service_info["retry_count"] += 1
service_info["last_health_check"] = time.time()
logger.error(f"服务监控异常: {service_id}, 错误: {e}")
# 如果重试次数过多,暂停监控
if service_info["retry_count"] >= self.max_retry_attempts:
logger.warning(f"服务 {service_id} 重试次数过多,暂停监控")
await asyncio.sleep(self.health_check_interval * 5) # 延长等待时间
service_info["retry_count"] = 0 # 重置重试计数
# 等待下次检查
await asyncio.sleep(self.health_check_interval)
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"服务监控任务异常: {service_id}, 错误: {e}")
async def _update_service_health_in_db(self, service_id: str, health_status: Dict[str, Any]):
"""更新数据库中的服务健康状态"""
try:
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == uuid.UUID(service_id)
).first()
if mcp_config:
mcp_config.health_status = "healthy" if health_status["healthy"] else "unhealthy"
mcp_config.last_health_check = datetime.utcnow()
if not health_status["healthy"]:
mcp_config.error_message = health_status.get("error", "")
else:
mcp_config.error_message = None
self.db.commit()
except Exception as e:
logger.error(f"更新数据库健康状态失败: {service_id}, 错误: {e}")
self.db.rollback()
async def _management_loop(self):
"""管理循环 - 处理服务清理等任务"""
try:
while self._running:
# 清理失效的服务
await self._cleanup_failed_services()
# 等待下次循环
await asyncio.sleep(300) # 5分钟
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"管理循环异常: {e}")
async def _cleanup_failed_services(self):
"""清理长期失效的服务"""
try:
current_time = time.time()
cleanup_threshold = 24 * 60 * 60 # 24小时
services_to_cleanup = []
for service_id, service_info in self._services.items():
# 检查服务是否长期失效
if (service_info["status"] in ["error", "unhealthy"] and
current_time - service_info["last_health_check"] > cleanup_threshold):
services_to_cleanup.append(service_id)
for service_id in services_to_cleanup:
logger.warning(f"清理长期失效的服务: {service_id}")
# 停止监控但不删除数据库记录
await self._stop_service_monitoring(service_id)
# 标记为禁用
tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
if tool_config:
tool_config.is_enabled = False
self.db.commit()
# 从内存移除
del self._services[service_id]
except Exception as e:
logger.error(f"清理失效服务失败: {e}")
def get_manager_status(self) -> Dict[str, Any]:
"""获取管理器状态"""
return {
"running": self._running,
"total_services": len(self._services),
"healthy_services": len([s for s in self._services.values() if s["status"] == "healthy"]),
"unhealthy_services": len([s for s in self._services.values() if s["status"] in ["unhealthy", "error"]]),
"monitoring_tasks": len(self._monitoring_tasks),
"connection_pool_status": self.connection_pool.get_pool_status()
}

View File

@@ -0,0 +1,436 @@
"""工具注册表 - 管理所有工具的元数据和状态"""
import uuid
import asyncio
from typing import Dict, List, Optional, Type, Any
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
from app.models.tool_model import (
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
ToolType, ToolStatus, ToolExecution, ExecutionStatus
)
from app.core.logging_config import get_business_logger
from .base import BaseTool, ToolInfo
from .custom.base import CustomTool
from .mcp.base import MCPTool
logger = get_business_logger()
class ToolRegistry:
    """Tool registry: manages tool metadata in the database and cached tool instances."""

    def __init__(self, db: Session):
        """Initialise the registry.

        Args:
            db: Database session used for every read and write.
        """
        self.db = db
        self._tools: Dict[str, BaseTool] = {}  # tool_id -> cached tool instance
        self._tool_classes: Dict[str, Type[BaseTool]] = {}  # class name -> tool class
        self._lock = asyncio.Lock()  # serialises register / unregister

    def register_tool_class(self, tool_class: Type[BaseTool], class_name: Optional[str] = None):
        """Register a tool class so that built-in tools can be instantiated by name.

        Args:
            tool_class: The tool class.
            class_name: Name to register it under; defaults to ``tool_class.__name__``.
        """
        class_name = class_name or tool_class.__name__
        self._tool_classes[class_name] = tool_class
        logger.info(f"工具类已注册: {class_name}")

    async def register_tool(self, tool: BaseTool, tenant_id: Optional[uuid.UUID] = None) -> bool:
        """Persist a tool instance and cache it.

        Args:
            tool: The tool instance.
            tenant_id: Owning tenant; ``None`` registers a global tool.

        Returns:
            ``True`` on success, ``False`` when a tool with the same name,
            tenant and type already exists or when persisting fails.
        """
        async with self._lock:
            try:
                # A falsy tenant_id means "global tool": match rows without a tenant
                if tenant_id:
                    tenant_filter = ToolConfig.tenant_id == tenant_id
                else:
                    tenant_filter = ToolConfig.tenant_id.is_(None)

                existing_config = self.db.query(ToolConfig).filter(
                    and_(
                        ToolConfig.name == tool.name,
                        tenant_filter,
                        ToolConfig.tool_type == tool.tool_type.value
                    )
                ).first()

                if existing_config:
                    logger.warning(f"工具已存在: {tool.name} (tenant: {tenant_id or 'global'})")
                    return False

                # Create the generic tool configuration
                tool_config = ToolConfig(
                    name=tool.name,
                    description=tool.description,
                    tool_type=tool.tool_type.value,
                    tenant_id=tenant_id,
                    version=tool.version,
                    tags=tool.tags,
                    config_data=tool.config
                )

                self.db.add(tool_config)
                self.db.flush()  # populate tool_config.id

                # Create the type-specific configuration, sharing the same primary key
                if tool.tool_type == ToolType.BUILTIN:
                    builtin_config = BuiltinToolConfig(
                        id=tool_config.id,
                        tool_class=tool.__class__.__name__,
                        parameters=tool.config.get("parameters", {})
                    )
                    self.db.add(builtin_config)

                elif tool.tool_type == ToolType.CUSTOM:
                    custom_config = CustomToolConfig(
                        id=tool_config.id,
                        schema_url=tool.config.get("schema_url"),
                        schema_content=tool.config.get("schema_content"),
                        auth_type=tool.config.get("auth_type", "none"),
                        auth_config=tool.config.get("auth_config", {}),
                        base_url=tool.config.get("base_url"),
                        timeout=tool.config.get("timeout", 30)
                    )
                    self.db.add(custom_config)

                elif tool.tool_type == ToolType.MCP:
                    mcp_config = MCPToolConfig(
                        id=tool_config.id,
                        server_url=tool.config.get("server_url"),
                        connection_config=tool.config.get("connection_config", {}),
                        available_tools=tool.config.get("available_tools", [])
                    )
                    self.db.add(mcp_config)

                self.db.commit()

                # Cache the instance under its new ID
                tool.tool_id = str(tool_config.id)
                self._tools[str(tool_config.id)] = tool

                logger.info(f"工具注册成功: {tool.name} (ID: {tool_config.id})")
                return True

            except Exception as e:
                self.db.rollback()
                logger.error(f"工具注册失败: {tool.name}, 错误: {e}")
                return False

    async def unregister_tool(self, tool_id: str) -> bool:
        """Remove a tool from the database and from the cache.

        Args:
            tool_id: Tool ID (string form of the ``ToolConfig`` UUID).

        Returns:
            ``True`` on success; ``False`` when the tool does not exist, still
            has pending or running executions, or deletion fails.
        """
        async with self._lock:
            try:
                tool_uuid = uuid.UUID(tool_id)

                # Make sure the tool exists
                tool_config = self.db.get(ToolConfig, tool_uuid)
                if not tool_config:
                    logger.warning(f"工具不存在: {tool_id}")
                    return False

                # Refuse while executions are pending or running
                running_executions = self.db.query(ToolExecution).filter(
                    and_(
                        ToolExecution.tool_config_id == tool_uuid,
                        ToolExecution.status.in_([ExecutionStatus.PENDING.value, ExecutionStatus.RUNNING.value])
                    )
                ).count()

                if running_executions > 0:
                    logger.warning(f"工具有正在执行的任务,无法注销: {tool_id}")
                    return False

                # Delete the configuration row
                # (presumably related rows cascade - confirm against the model definitions)
                self.db.delete(tool_config)
                self.db.commit()

                # Drop the cached instance
                self._tools.pop(tool_id, None)

                logger.info(f"工具注销成功: {tool_id}")
                return True

            except Exception as e:
                self.db.rollback()
                logger.error(f"工具注销失败: {tool_id}, 错误: {e}")
                return False

    def get_tool(self, tool_id: str) -> Optional[BaseTool]:
        """Return a tool instance, loading and caching it on first use.

        Args:
            tool_id: Tool ID.

        Returns:
            The tool instance, or ``None`` when it cannot be found or loaded.
        """
        # Serve from the cache when possible
        if tool_id in self._tools:
            return self._tools[tool_id]

        # Otherwise load from the database
        try:
            tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
            # NOTE(review): the rest of this class derives a tool's state from
            # ``is_enabled``; confirm that ``ToolConfig`` really has a ``status``
            # attribute, otherwise this lookup always lands in the except branch.
            if not tool_config or tool_config.status != ToolStatus.ACTIVE.value:
                return None

            # Build the instance for the tool's type
            tool_instance = self._load_tool_instance(tool_config)
            if tool_instance:
                self._tools[tool_id] = tool_instance
                return tool_instance

        except Exception as e:
            logger.error(f"加载工具失败: {tool_id}, 错误: {e}")

        return None

    def list_tools(
        self,
        tenant_id: Optional[uuid.UUID] = None,
        tool_type: Optional[ToolType] = None,
        status: Optional[ToolStatus] = None,
        tags: Optional[List[str]] = None
    ) -> List[ToolInfo]:
        """List tools, optionally filtered.

        Args:
            tenant_id: When given, return this tenant's tools plus global tools.
            tool_type: Restrict to one tool type.
            status: ``ACTIVE`` / ``INACTIVE`` filter on the enabled flag; other
                values apply no filter.
            tags: Every listed tag must be present on the tool.

        Returns:
            The matching tools, or an empty list when the query fails.
        """
        try:
            query = self.db.query(ToolConfig)

            # Apply the filters
            if tenant_id:
                # Global tools (no tenant) are visible to every tenant
                query = query.filter(
                    or_(
                        ToolConfig.tenant_id == tenant_id,
                        ToolConfig.tenant_id.is_(None)
                    )
                )

            if tool_type:
                query = query.filter(ToolConfig.tool_type == tool_type.value)

            if status == ToolStatus.ACTIVE:
                query = query.filter(ToolConfig.is_enabled == True)  # noqa: E712 - SQLAlchemy expression
            elif status == ToolStatus.INACTIVE:
                query = query.filter(ToolConfig.is_enabled == False)  # noqa: E712 - SQLAlchemy expression

            if tags:
                for tag in tags:
                    query = query.filter(ToolConfig.tags.contains([tag]))

            tool_configs = query.all()

            # Convert the rows to ToolInfo objects
            tool_infos = []
            for config in tool_configs:
                tool_info = ToolInfo(
                    id=str(config.id),
                    name=config.name,
                    description=config.description or "",
                    tool_type=ToolType(config.tool_type),
                    version=config.version,
                    status=ToolStatus.ACTIVE if config.is_enabled else ToolStatus.INACTIVE,
                    tags=config.tags or [],
                    tenant_id=str(config.tenant_id) if config.tenant_id else None
                )

                # Attach parameter information when the instance can be loaded
                tool_instance = self.get_tool(str(config.id))
                if tool_instance:
                    tool_info.parameters = tool_instance.parameters

                tool_infos.append(tool_info)

            return tool_infos

        except Exception as e:
            logger.error(f"列出工具失败, 错误: {e}")
            return []

    async def update_tool_status(self, tool_id: str, status: ToolStatus) -> bool:
        """Update the enabled state of a tool.

        Args:
            tool_id: Tool ID.
            status: New status. Only ``ACTIVE`` and ``INACTIVE`` change the
                persisted enabled flag; any status is copied onto a cached
                instance.

        Returns:
            ``True`` on success, ``False`` when the tool does not exist or the
            update fails.
        """
        try:
            tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
            if not tool_config:
                logger.warning(f"工具不存在: {tool_id}")
                return False

            # Update the persisted flag
            if status == ToolStatus.ACTIVE:
                tool_config.is_enabled = True
            elif status == ToolStatus.INACTIVE:
                tool_config.is_enabled = False
            self.db.commit()

            # Keep the cached instance in sync
            if tool_id in self._tools:
                self._tools[tool_id].status = status

            logger.info(f"工具状态更新成功: {tool_id} -> {status}")
            return True

        except Exception as e:
            self.db.rollback()
            logger.error(f"工具状态更新失败: {tool_id}, 错误: {e}")
            return False

    def _load_tool_instance(self, tool_config: Optional[ToolConfig]) -> Optional[BaseTool]:
        """Build a tool instance from its stored configuration.

        Args:
            tool_config: The ``ToolConfig`` row (an instance, not the class).

        Returns:
            The tool instance, or ``None`` when the row is missing, the
            type-specific row is missing, a built-in class is not registered,
            or construction fails.
        """
        if tool_config is None:
            return None

        try:
            # A NULL config_data must not make the tool unloadable
            base_config = tool_config.config_data or {}
            tool_id = str(tool_config.id)
            # Fields shared by every tool type
            common = {
                "tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
                "version": tool_config.version,
                "tags": tool_config.tags
            }

            if tool_config.tool_type == ToolType.BUILTIN.value:
                # Built-in tool: instantiate the registered class
                builtin_config = self.db.query(BuiltinToolConfig).filter(
                    BuiltinToolConfig.id == tool_config.id
                ).first()

                if builtin_config and builtin_config.tool_class in self._tool_classes:
                    tool_class = self._tool_classes[builtin_config.tool_class]
                    config = {
                        **base_config,
                        "parameters": builtin_config.parameters,
                        **common
                    }
                    return tool_class(tool_id, config)

            elif tool_config.tool_type == ToolType.CUSTOM.value:
                # Custom tool
                try:
                    custom_config = self.db.query(CustomToolConfig).filter(
                        CustomToolConfig.id == tool_config.id
                    ).first()

                    if custom_config:
                        config = {
                            **base_config,
                            "schema_url": custom_config.schema_url,
                            "schema_content": custom_config.schema_content,
                            "auth_type": custom_config.auth_type,
                            "auth_config": custom_config.auth_config,
                            "base_url": custom_config.base_url,
                            "timeout": custom_config.timeout,
                            **common
                        }
                        return CustomTool(tool_id, config)
                except ImportError as e:
                    logger.error(f"无法导入自定义工具模块: {e}")

            elif tool_config.tool_type == ToolType.MCP.value:
                # MCP tool
                try:
                    mcp_config = self.db.query(MCPToolConfig).filter(
                        MCPToolConfig.id == tool_config.id
                    ).first()

                    if mcp_config:
                        config = {
                            **base_config,
                            "server_url": mcp_config.server_url,
                            "connection_config": mcp_config.connection_config,
                            "available_tools": mcp_config.available_tools,
                            **common
                        }
                        return MCPTool(tool_id, config)
                except ImportError as e:
                    logger.error(f"无法导入MCP工具模块: {e}")

        except Exception as e:
            logger.error(f"加载工具实例失败: {tool_config.id}, 错误: {e}")

        return None

    def get_tool_statistics(self, tenant_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
        """Return tool counts: total, active, inactive and per type.

        Args:
            tenant_id: When given, count only this tenant's tools.

        Returns:
            The statistics dict, or an empty dict when the query fails.
        """
        try:
            query = self.db.query(ToolConfig)

            if tenant_id:
                query = query.filter(ToolConfig.tenant_id == tenant_id)

            total_tools = query.count()
            active_tools = query.filter(ToolConfig.is_enabled == True).count()  # noqa: E712

            # Per-type counts
            type_stats = {
                tool_type.value: query.filter(ToolConfig.tool_type == tool_type.value).count()
                for tool_type in ToolType
            }

            return {
                "total_tools": total_tools,
                "active_tools": active_tools,
                "inactive_tools": total_tools - active_tools,
                "by_type": type_stats
            }

        except Exception as e:
            logger.error(f"获取工具统计失败, 错误: {e}")
            return {}

    def clear_cache(self):
        """Drop every cached tool instance."""
        self._tools.clear()
        logger.info("工具缓存已清空")

View File

@@ -5,36 +5,41 @@
"""
import logging
import uuid
# import uuid
import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.expression_evaluator import evaluate_condition
from app.models.workflow_model import WorkflowExecution, WorkflowNodeExecution
from app.db import get_db
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.nodes.enums import NodeType
# from app.core.tools.registry import ToolRegistry
# from app.core.tools.executor import ToolExecutor
# from app.core.tools.langchain_adapter import LangchainAdapter
# TOOL_MANAGEMENT_AVAILABLE = True
# from app.db import get_db
logger = logging.getLogger(__name__)
class WorkflowExecutor:
"""工作流执行器
负责将工作流配置转换为 LangGraph 并执行。
"""
def __init__(
self,
workflow_config: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
self,
workflow_config: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
):
"""初始化执行器
Args:
workflow_config: 工作流配置
execution_id: 执行 ID
@@ -48,25 +53,25 @@ class WorkflowExecutor:
self.nodes = workflow_config.get("nodes", [])
self.edges = workflow_config.get("edges", [])
self.execution_config = workflow_config.get("execution_config", {})
def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState:
"""准备初始状态(注入系统变量和会话变量)
变量命名空间:
- sys.xxx - 系统变量execution_id, workspace_id, user_id, message, input_variables 等)
- conv.xxx - 会话变量(跨多轮对话保持)
- node_id.xxx - 节点输出(执行时动态生成)
Args:
input_data: 输入数据
Returns:
初始化的工作流状态
"""
user_message = input_data.get("message") or ""
conversation_vars = input_data.get("conversation_vars") or {}
input_variables = input_data.get("variables") or {} # Start 节点的自定义变量
# 构建分层的变量结构
variables = {
"sys": {
@@ -79,7 +84,7 @@ class WorkflowExecutor:
},
"conv": conversation_vars # 会话级变量(跨多轮对话保持)
}
return {
"messages": [HumanMessage(content=user_message)],
"variables": variables,
@@ -89,163 +94,277 @@ class WorkflowExecutor:
"workspace_id": self.workspace_id,
"user_id": self.user_id,
"error": None,
"error_node": None
"error_node": None,
"streaming_buffer": {} # 流式缓冲区
}
def build_graph(self) -> StateGraph:
"""构建 LangGraph
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
    """Find the literal text that precedes the first upstream reference in each End template.

    For every node of type ``"end"`` the ``config.output`` template is scanned
    for ``{{node_id.field}}`` references. The first reference that names a
    node with an edge pointing straight at that End node is selected, and the
    template text before it is treated as that node's prefix.

    Returns:
        A tuple ``(prefixes, adjacent_and_referenced)``: a mapping from
        upstream node ID to its non-empty prefix, and the set of node IDs that
        are both directly connected to an End node and referenced by it.
    """
    import re

    # Matches {{node_id.xxx}} as well as {{ node_id.xxx }} (inner whitespace allowed)
    reference_re = re.compile(r'\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*\}\}')

    prefix_by_node = {}
    referenced_neighbours = set()  # nodes adjacent to an End node and referenced by it

    # Collect every End node
    terminal_nodes = [candidate for candidate in self.nodes if candidate.get("type") == "end"]
    logger.info(f"[前缀分析] 找到 {len(terminal_nodes)} 个 End 节点")

    for terminal in terminal_nodes:
        terminal_id = terminal.get("id")
        template = terminal.get("config", {}).get("output")
        logger.info(f"[前缀分析] End 节点 {terminal_id} 模板: {template}")

        if template:
            # Sources of every edge that ends at this End node
            upstream_ids = [
                link.get("source") for link in self.edges if link.get("target") == terminal_id
            ]
            logger.info(f"[前缀分析] End 节点的直接上游节点: {upstream_ids}")

            # Every node reference found in the template, in order
            hits = list(reference_re.finditer(template))
            logger.info(f"[前缀分析] 模板中找到 {len(hits)} 个节点引用")

            for hit in hits:
                node_ref = hit.group(1)
                logger.info(f"[前缀分析] 检查引用: {node_ref}")

                if node_ref not in upstream_ids:
                    continue

                # First reference to a direct upstream node: everything before it is the prefix
                leading_text = template[:hit.start()]
                logger.info(f"[前缀分析] ✅ 找到直接上游节点 {node_ref} 的引用,前缀: '{leading_text}'")

                referenced_neighbours.add(node_ref)

                if leading_text:
                    prefix_by_node[node_ref] = leading_text
                    logger.info(f"✅ [前缀分析] 为节点 {node_ref} 配置前缀: '{leading_text[:50]}...'")

                # Only the first direct-upstream reference is considered
                break

    logger.info(f"[前缀分析] 最终配置: {prefix_by_node}")
    logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {referenced_neighbours}")
    return prefix_by_node, referenced_neighbours
def build_graph(self,stream=False) -> CompiledStateGraph:
"""Build and compile the LangGraph state graph for this workflow.
Args:
stream: When true, nodes are wrapped as async generators that relay
``run_stream`` and End-node prefix information is attached to the
node instances; otherwise nodes are wrapped as coroutines that
await ``run``.
Returns:
The compiled state graph.
"""
# NOTE(review): this span appears to come from a rendered diff. Superseded and
# replacement lines sit side by side (e.g. the paired `if node_type == ...`
# checks and the two node-wrapping variants below) and indentation is missing.
# Reconcile against the real file before relying on the statement order here.
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
# Analyse End-node prefixes and the nodes that are adjacent to an End node and
# referenced by it; this is only done in streaming mode
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if stream else ({}, set())
# 1. Create the state graph
workflow = StateGraph(WorkflowState)
# 2. Add every node (including start and end)
start_node_id = None
end_node_ids = []
for node in self.nodes:
node_type = node.get("type")
node_id = node.get("id")
# Remember the start and end node IDs
if node_type == "start":
if node_type == NodeType.START:
start_node_id = node_id
elif node_type == "end":
elif node_type == NodeType.END:
end_node_ids.append(node_id)
# Create the node instance (start and end nodes are instantiated as well)
node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_type in [NodeType.IF_ELSE]:
expressions = node_instance.build_conditional_edge_expressions()
# Number of branches, usually matches the number of conditional expressions
branch_number = len(expressions)
# Find all edges whose source is the current node
related_edge = [edge for edge in self.edges if edge.get("source") == node_id]
# Iterate over each branch
for idx in range(branch_number):
# Generate a condition expression for each edge
# Used later to determine which branch to take based on the node's output
# Assumes node output `node.<node_id>.output` matches the edge's label
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
# NOTE(review): indexes related_edge by branch position, so this relies on there
# being at least as many outgoing edges as expressions - confirm
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
if node_instance:
# In streaming mode, inject the End prefix configured for this node, if any
if stream and node_id in end_prefixes:
# Attach the End prefix to the node instance
node_instance._end_node_prefix = end_prefixes[node_id]
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
# In streaming mode, flag whether the node is adjacent to an End node and referenced by it
if stream:
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
if node_id in adjacent_and_referenced:
logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用")
# Wrap the node's run method
# A function factory is used to avoid the late-binding closure problem
def make_node_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
return node_func
workflow.add_node(node_id, make_node_func(node_instance))
logger.debug(f"添加节点: {node_id} (type={node_type})")
if stream:
# Streaming mode: create an async generator function
# LangGraph collects every yielded value; the last yielded dict is merged into the state
def make_stream_func(inst):
async def node_func(state: WorkflowState):
# logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}")
async for item in inst.run_stream(state):
yield item
return node_func
workflow.add_node(node_id, make_stream_func(node_instance))
else:
# Non-streaming mode: create an async function
def make_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
return node_func
workflow.add_node(node_id, make_func(node_instance))
logger.debug(f"添加节点: {node_id} (type={node_type}, stream={stream})")
# 3. Add the edges
# Connect START to the start node
if start_node_id:
workflow.add_edge(START, start_node_id)
logger.debug(f"添加边: START -> {start_node_id}")
for edge in self.edges:
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
condition = edge.get("condition")
# Edges leaving the start node get no condition handling (START is already wired to start)
if source == start_node_id:
# ...but the start node still has to be connected to its successor
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
continue
# Edges that lead to an end node
if target in end_node_ids:
# Connect to the end node
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
continue
# Skip error edges (they are handled inside the node)
if edge_type == "error":
continue
if condition:
# Conditional edge
# cond/tgt are bound as default arguments so each router keeps its own edge's values
def router(state: WorkflowState, cond=condition, tgt=target):
"""Conditional routing function."""
if evaluate_condition(
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
):
return tgt
return END # condition not met: finish
workflow.add_conditional_edges(source, router)
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
else:
# Plain edge
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
# Connect every end node to END
for end_node_id in end_node_ids:
workflow.add_edge(end_node_id, END)
logger.debug(f"添加边: {end_node_id} -> END")
# 4. Compile the graph
graph = workflow.compile()
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
return graph
async def execute(
self,
input_data: dict[str, Any]
self,
input_data: dict[str, Any]
) -> dict[str, Any]:
"""执行工作流(非流式)
Args:
input_data: 输入数据,包含 message 和 variables
Returns:
执行结果,包含 status, output, node_outputs, elapsed_time, token_usage
"""
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
# 记录开始时间
start_time = datetime.datetime.now()
# 1. 构建图
graph = self.build_graph()
# 2. 初始化状态(自动注入系统变量)
initial_state = self._prepare_initial_state(input_data)
# 3. 执行工作流
try:
result = await graph.ainvoke(initial_state)
# 计算耗时
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
# 提取节点输出(现在包含 start 和 end 节点)
node_outputs = result.get("node_outputs", {})
# 提取最终输出(从最后一个非 start/end 节点)
final_output = self._extract_final_output(node_outputs)
# 聚合 token 使用情况
token_usage = self._aggregate_token_usage(node_outputs)
# 提取 conversation_id从 start 节点输出)
conversation_id = None
for node_id, node_output in node_outputs.items():
if node_output.get("node_type") == "start":
conversation_id = node_output.get("output", {}).get("conversation_id")
break
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
return {
"status": "completed",
"output": final_output,
@@ -256,12 +375,12 @@ class WorkflowExecutor:
"token_usage": token_usage,
"error": result.get("error")
}
except Exception as e:
# 计算耗时(即使失败也记录)
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True)
return {
"status": "failed",
@@ -271,86 +390,200 @@ class WorkflowExecutor:
"elapsed_time": elapsed_time,
"token_usage": None
}
async def execute_stream(
self,
input_data: dict[str, Any]
self,
input_data: dict[str, Any]
):
"""执行工作流(流式)
使用多个 stream_mode 来获取:
1. "updates" - 节点的 state 更新和流式 chunk
2. "debug" - 节点执行的详细信息(开始/完成时间)
3. "custom" - 自定义流式数据chunks
Args:
input_data: 输入数据
Yields:
流式事件
流式事件,格式:
{
"event": "workflow_start" | "workflow_end" | "node_start" | "node_end" | "node_chunk" | "message",
"data": {...}
}
"""
logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}")
# 记录开始时间
start_time = datetime.datetime.now()
# 发送 workflow_start 事件
yield {
"event": "workflow_start",
"data": {
"execution_id": self.execution_id,
"workspace_id": self.workspace_id,
"timestamp": start_time.isoformat()
}
}
# 1. 构建图
graph = self.build_graph()
graph = self.build_graph(True)
# 2. 初始化状态(自动注入系统变量)
initial_state = self._prepare_initial_state(input_data)
# 3. 流式执行工作流
# 3. Execute workflow
try:
# 使用 astream 获取节点级别的更新
async for event in graph.astream(initial_state, stream_mode="updates"):
for node_name, state_update in event.items():
chunk_count = 0
final_state = None
async for event in graph.astream(
initial_state,
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
):
# event should be a tuple: (mode, data)
# But let's handle both cases
if isinstance(event, tuple) and len(event) == 2:
mode, data = event
else:
# Unexpected format, log and skip
logger.warning(f"[STREAM] Unexpected event format: {type(event)}, value: {event}")
continue
if mode == "custom":
# Handle custom streaming events (chunks from nodes via stream writer)
chunk_count += 1
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}")
yield {
"type": "node_complete",
"node": node_name,
"data": state_update,
"execution_id": self.execution_id
"event": event_type, # "message" or "node_chunk"
"data": {
"node_id": data.get("node_id"),
"chunk": data.get("chunk"),
"full_content": data.get("full_content"),
"chunk_index": data.get("chunk_index"),
"is_prefix": data.get("is_prefix"),
"is_suffix": data.get("is_suffix")
}
}
elif mode == "debug":
# Handle debug information (node execution status)
event_type = data.get("type")
payload = data.get("payload", {})
node_name = payload.get("name")
if event_type == "task":
# Node starts execution
inputv = payload.get("input", {})
variables = inputv.get("variables", {})
variables_sys = variables.get("sys", {})
conversation_id = variables_sys.get("conversation_id")
execution_id = variables_sys.get("execution_id")
logger.info(f"[DEBUG] Node starts execution: {node_name}")
yield {
"event": "node_start",
"data": {
"node_id": node_name,
"conversation_id": conversation_id,
"execution_id": execution_id,
"timestamp": data.get("timestamp")
}
}
elif event_type == "task_result":
# Node execution completed
result = payload.get("result", {})
inputv = result.get("input", {})
variables = inputv.get("variables", {})
variables_sys = variables.get("sys", {})
conversation_id = variables_sys.get("conversation_id")
execution_id = variables_sys.get("execution_id")
logger.info(f"[DEBUG] Node execution completed: {node_name}")
yield {
"event": "node_end",
"data": {
"node_id": node_name,
"conversation_id": conversation_id,
"execution_id": execution_id,
"timestamp": data.get("timestamp")
}
}
elif mode == "updates":
# Handle state updates - store final state
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())}")
final_state = data
logger.info(f"工作流执行完成(流式): execution_id={self.execution_id}")
# 计算耗时
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
# 发送完成事件
logger.info(f"Workflow execution completed (streaming), total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s")
# 发送 workflow_end 事件
yield {
"type": "workflow_complete",
"execution_id": self.execution_id
"event": "workflow_end",
"data": {
"execution_id": self.execution_id,
"status": "completed",
"elapsed_time": elapsed_time,
"timestamp": end_time.isoformat()
}
}
except Exception as e:
logger.error(f"工作流执行失败(流式): execution_id={self.execution_id}, error={e}", exc_info=True)
# 计算耗时(即使失败也记录)
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True)
# 发送 workflow_end 事件(失败)
yield {
"type": "workflow_error",
"execution_id": self.execution_id,
"error": str(e)
"event": "workflow_end",
"data": {
"execution_id": self.execution_id,
"status": "failed",
"error": str(e),
"elapsed_time": elapsed_time,
"timestamp": end_time.isoformat()
}
}
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:
"""从节点输出中提取最终输出
优先级:
1. 最后一个执行的非 start/end 节点的 output
2. 如果没有节点输出,返回 None
Args:
node_outputs: 所有节点的输出
Returns:
最终输出字符串或 None
"""
if not node_outputs:
return None
# 获取最后一个节点的输出
last_node_output = list(node_outputs.values())[-1] if node_outputs else None
if last_node_output and isinstance(last_node_output, dict):
return last_node_output.get("output")
return None
def _aggregate_token_usage(self, node_outputs: dict[str, Any]) -> dict[str, int] | None:
"""聚合所有节点的 token 使用情况
Args:
node_outputs: 所有节点的输出
Returns:
聚合的 token 使用情况 {"prompt_tokens": x, "completion_tokens": y, "total_tokens": z}
如果没有 token 使用信息,返回 None
@@ -359,7 +592,7 @@ class WorkflowExecutor:
total_completion_tokens = 0
total_tokens = 0
has_token_info = False
for node_output in node_outputs.values():
if isinstance(node_output, dict):
token_usage = node_output.get("token_usage")
@@ -368,33 +601,33 @@ class WorkflowExecutor:
total_prompt_tokens += token_usage.get("prompt_tokens", 0)
total_completion_tokens += token_usage.get("completion_tokens", 0)
total_tokens += token_usage.get("total_tokens", 0)
if not has_token_info:
return None
return {
"prompt_tokens": total_prompt_tokens,
"completion_tokens": total_completion_tokens,
"total_tokens": total_tokens
}
async def execute_workflow(
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
) -> dict[str, Any]:
"""执行工作流(便捷函数)
Args:
workflow_config: 工作流配置
input_data: 输入数据
execution_id: 执行 ID
workspace_id: 工作空间 ID
user_id: 用户 ID
Returns:
执行结果
"""
@@ -408,21 +641,21 @@ async def execute_workflow(
async def execute_workflow_stream(
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
):
"""执行工作流(流式,便捷函数)
Args:
workflow_config: 工作流配置
input_data: 输入数据
execution_id: 执行 ID
workspace_id: 工作空间 ID
user_id: 用户 ID
Yields:
流式事件
"""
@@ -434,3 +667,179 @@ async def execute_workflow_stream(
)
async for event in executor.execute_stream(input_data):
yield event
# ==================== 工具管理系统集成 ====================
# def get_workflow_tools(workspace_id: str, user_id: str) -> list:
# """获取工作流可用的工具列表
#
# Args:
# workspace_id: 工作空间ID
# user_id: 用户ID
#
# Returns:
# 可用工具列表
# """
# if not TOOL_MANAGEMENT_AVAILABLE:
# logger.warning("工具管理系统不可用")
# return []
#
# try:
# db = next(get_db())
#
# # 创建工具注册表
# registry = ToolRegistry(db)
#
# # 注册内置工具类
# from app.core.tools.builtin import (
# DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool
# )
# registry.register_tool_class(DateTimeTool)
# registry.register_tool_class(JsonTool)
# registry.register_tool_class(BaiduSearchTool)
# registry.register_tool_class(MinerUTool)
# registry.register_tool_class(TextInTool)
#
# # 获取活跃的工具
# import uuid
# tools = registry.list_tools(workspace_id=uuid.UUID(workspace_id))
# active_tools = [tool for tool in tools if tool.status.value == "active"]
#
# # 转换为Langchain工具
# langchain_tools = []
# for tool_info in active_tools:
# try:
# tool_instance = registry.get_tool(tool_info.id)
# if tool_instance:
# langchain_tool = LangchainAdapter.convert_tool(tool_instance)
# langchain_tools.append(langchain_tool)
# except Exception as e:
# logger.error(f"转换工具失败: {tool_info.name}, 错误: {e}")
#
# logger.info(f"为工作流获取了 {len(langchain_tools)} 个工具")
# return langchain_tools
#
# except Exception as e:
# logger.error(f"获取工作流工具失败: {e}")
# return []
#
#
# class ToolWorkflowNode:
# """工具工作流节点 - 在工作流中执行工具"""
#
# def __init__(self, node_config: dict, workflow_config: dict):
# """初始化工具节点
#
# Args:
# node_config: 节点配置
# workflow_config: 工作流配置
# """
# self.node_config = node_config
# self.workflow_config = workflow_config
# self.tool_id = node_config.get("tool_id")
# self.tool_parameters = node_config.get("parameters", {})
#
# async def run(self, state: WorkflowState) -> WorkflowState:
# """执行工具节点"""
# if not TOOL_MANAGEMENT_AVAILABLE:
# logger.error("工具管理系统不可用")
# state["error"] = "工具管理系统不可用"
# return state
#
# try:
# from sqlalchemy.orm import Session
# db = next(get_db())
#
# # 创建工具执行器
# registry = ToolRegistry(db)
# executor = ToolExecutor(db, registry)
#
# # 准备参数(支持变量替换)
# parameters = self._prepare_parameters(state)
#
# # 执行工具
# result = await executor.execute_tool(
# tool_id=self.tool_id,
# parameters=parameters,
# user_id=uuid.UUID(state["user_id"]),
# workspace_id=uuid.UUID(state["workspace_id"])
# )
#
# # 更新状态
# node_id = self.node_config.get("id")
# if result.success:
# state["node_outputs"][node_id] = {
# "type": "tool",
# "tool_id": self.tool_id,
# "output": result.data,
# "execution_time": result.execution_time,
# "token_usage": result.token_usage
# }
#
# # 更新运行时变量
# if isinstance(result.data, dict):
# for key, value in result.data.items():
# state["runtime_vars"][f"{node_id}.{key}"] = value
# else:
# state["runtime_vars"][f"{node_id}.result"] = result.data
# else:
# state["error"] = result.error
# state["error_node"] = node_id
# state["node_outputs"][node_id] = {
# "type": "tool",
# "tool_id": self.tool_id,
# "error": result.error,
# "execution_time": result.execution_time
# }
#
# return state
#
# except Exception as e:
# logger.error(f"工具节点执行失败: {e}")
# state["error"] = str(e)
# state["error_node"] = self.node_config.get("id")
# return state
#
# def _prepare_parameters(self, state: WorkflowState) -> dict:
# """准备工具参数(支持变量替换)"""
# parameters = {}
#
# for key, value in self.tool_parameters.items():
# if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
# # 变量替换
# var_path = value[2:-1]
#
# # 支持多层级变量访问,如 ${sys.message} 或 ${node1.result}
# if "." in var_path:
# parts = var_path.split(".")
# current = state.get("variables", {})
#
# for part in parts:
# if isinstance(current, dict) and part in current:
# current = current[part]
# else:
# # 尝试从运行时变量获取
# runtime_key = ".".join(parts)
# current = state.get("runtime_vars", {}).get(runtime_key, value)
# break
#
# parameters[key] = current
# else:
# # 简单变量
# variables = state.get("variables", {})
# parameters[key] = variables.get(var_path, value)
# else:
# parameters[key] = value
#
# return parameters
#
#
# # 注册工具节点到NodeFactory如果存在
# try:
# from app.core.workflow.nodes import NodeFactory
# if hasattr(NodeFactory, 'register_node_type'):
# NodeFactory.register_node_type("tool", ToolWorkflowNode)
# logger.info("工具节点已注册到工作流系统")
# except Exception as e:
# logger.warning(f"注册工具节点失败: {e}")

View File

@@ -5,6 +5,7 @@
"""
import logging
import re
from typing import Any
from simpleeval import simple_eval, NameNotDefined, InvalidExpression
@@ -59,9 +60,10 @@ class ExpressionEvaluator:
"""
# 移除 Jinja2 模板语法的花括号(如果存在)
expression = expression.strip()
if expression.startswith("{{") and expression.endswith("}}"):
expression = expression[2:-2].strip()
# "{{system.message}} == {{ user.message }}" -> "system.message == user.message"
pattern = r"\{\{\s*(.*?)\s*\}\}"
expression = re.sub(pattern, r"\1", expression).strip()
# 构建命名空间上下文
context = {
"var": variables, # 用户变量

View File

@@ -4,13 +4,14 @@
提供各种类型的节点实现,用于工作流执行。
"""
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.end import EndNode
from app.core.workflow.nodes.node_factory import NodeFactory
from app.core.workflow.nodes.if_else import IfElseNode
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode
__all__ = [
"BaseNode",
@@ -18,7 +19,9 @@ __all__ = [
"LLMNode",
"AgentNode",
"TransformNode",
"IfElseNode",
"StartNode",
"EndNode",
"NodeFactory",
"WorkflowNode"
]

View File

@@ -50,6 +50,11 @@ class VariableDefinition(BaseModel):
description="变量描述"
)
max_length: int = Field(
default=200,
description="只对字符串类型生效"
)
class Config:
json_schema_extra = {
"examples": [

View File

@@ -10,6 +10,7 @@ from abc import ABC, abstractmethod
from typing import Any, TypedDict, Annotated
from operator import add
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
from langgraph.config import get_stream_writer
from app.core.workflow.variable_pool import VariablePool
@@ -43,6 +44,10 @@ class WorkflowState(TypedDict):
# 错误信息(用于错误边)
error: str | None
error_node: str | None
# 流式缓冲区(存储节点的实时流式输出)
# 格式:{node_id: {"chunks": [...], "full_content": "..."}}
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
class BaseNode(ABC):
@@ -201,19 +206,25 @@ class BaseNode(ABC):
return self._wrap_error(str(e), elapsed_time, state)
async def run_stream(self, state: WorkflowState):
"""执行节点(带错误处理和输出包装,流式)
"""Execute node with error handling and output wrapping (streaming)
这个方法由 Executor 调用,负责:
1. 时间统计
2. 调用节点的 execute_stream() 方法
3. 将业务数据包装成标准输出格式
4. 错误处理
This method is called by the Executor and is responsible for:
1. Time tracking
2. Calling the node's execute_stream() method
3. Using LangGraph's stream writer to send chunks
4. Updating streaming buffer in state for downstream nodes
5. Wrapping business data into standard output format
6. Error handling
Special handling for End nodes:
- End nodes don't send chunks via writer (prefix and LLM content already sent)
- End nodes only yield suffix for final result assembly
Args:
state: 工作流状态
state: Workflow state
Yields:
标准化的流式事件
State updates with streaming buffer and final result
"""
import time
@@ -222,68 +233,143 @@ class BaseNode(ABC):
try:
timeout = self.get_timeout()
# 累积完整结果(用于最后的包装)
# Get LangGraph's stream writer for sending custom data
writer = get_stream_writer()
# Check if this is an End node
# End nodes CAN send chunks (for suffix), but only after LLM content
is_end_node = self.node_type == "end"
# Check if this node is adjacent to End node (for message type)
is_adjacent_to_end = getattr(self, '_is_adjacent_to_end', False)
# Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others
chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk"
logger.debug(f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})")
# Accumulate complete result (for final wrapping)
chunks = []
final_result = None
chunk_count = 0
# 使用异步生成器包装,支持超时
async def stream_with_timeout():
nonlocal final_result
loop_start = asyncio.get_event_loop().time()
# Stream chunks in real-time
loop_start = asyncio.get_event_loop().time()
async for item in self.execute_stream(state):
# Check timeout
if asyncio.get_event_loop().time() - loop_start > timeout:
raise TimeoutError()
async for item in self.execute_stream(state):
# 检查超时
if asyncio.get_event_loop().time() - loop_start > timeout:
raise TimeoutError()
# Check if it's a completion marker
if isinstance(item, dict) and item.get("__final__"):
final_result = item["result"]
elif isinstance(item, str):
# String is a chunk
chunk_count += 1
chunks.append(item)
full_content = "".join(chunks)
# 检查是否是完成标记
if isinstance(item, dict) and item.get("__final__"):
final_result = item["result"]
elif isinstance(item, str):
# 字符串是 chunk
chunks.append(item)
# Send chunks for all nodes (including End nodes for suffix)
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {item[:50]}...")
# 1. Send via stream writer (for real-time client updates)
writer({
"type": chunk_type, # "message" or "node_chunk"
"node_id": self.node_id,
"chunk": item,
"full_content": full_content,
"chunk_index": chunk_count
})
# 2. Update streaming buffer in state (for downstream nodes)
# Only non-End nodes need streaming buffer
if not is_end_node:
yield {
"type": "chunk",
"node_id": self.node_id,
"content": item,
"full_content": "".join(chunks)
"streaming_buffer": {
self.node_id: {
"full_content": full_content,
"chunk_count": chunk_count,
"is_complete": False
}
}
}
else:
# 其他类型也当作 chunk 处理
chunks.append(str(item))
else:
# Other types are also treated as chunks
chunk_count += 1
chunk_str = str(item)
chunks.append(chunk_str)
full_content = "".join(chunks)
# Send chunks for all nodes
writer({
"type": chunk_type, # "message" or "node_chunk"
"node_id": self.node_id,
"chunk": chunk_str,
"full_content": full_content,
"chunk_index": chunk_count
})
# Only non-End nodes need streaming buffer
if not is_end_node:
yield {
"type": "chunk",
"node_id": self.node_id,
"content": str(item),
"full_content": "".join(chunks)
"streaming_buffer": {
self.node_id: {
"full_content": full_content,
"chunk_count": chunk_count,
"is_complete": False
}
}
}
async for chunk_event in stream_with_timeout():
yield chunk_event
elapsed_time = time.time() - start_time
# 包装最终结果
logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
# Extract processed output (call subclass's _extract_output)
extracted_output = self._extract_output(final_result)
# Wrap final result
final_output = self._wrap_output(final_result, elapsed_time, state)
yield {
"type": "complete",
**final_output
# Store extracted output in runtime variables (for quick access by subsequent nodes)
if isinstance(extracted_output, dict):
runtime_var = extracted_output
else:
runtime_var = {"output": extracted_output}
# Build complete state update (including node_outputs, runtime_vars, and final streaming buffer)
state_update = {
**final_output,
"runtime_vars": {
self.node_id: runtime_var
}
}
# Add streaming buffer for non-End nodes
if not is_end_node:
state_update["streaming_buffer"] = {
self.node_id: {
"full_content": "".join(chunks),
"chunk_count": chunk_count,
"is_complete": True # Mark as complete
}
}
# Finally yield state update
# LangGraph will merge this into state
yield state_update
except TimeoutError:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行超时{timeout}秒)")
yield {
"type": "error",
**self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
}
logger.error(f"节点 {self.node_id} 执行超时 ({timeout}s)")
error_output = self._wrap_error(f"节点执行超时 ({timeout}s)", elapsed_time, state)
yield error_output
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
yield {
"type": "error",
**self._wrap_error(str(e), elapsed_time, state)
}
error_output = self._wrap_error(str(e), elapsed_time, state)
yield error_output
def _wrap_output(
self,

View File

@@ -13,6 +13,7 @@ from app.core.workflow.nodes.end.config import EndNodeConfig
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
from app.core.workflow.nodes.agent.config import AgentNodeConfig
from app.core.workflow.nodes.transform.config import TransformNodeConfig
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
__all__ = [
# 基础类
@@ -26,4 +27,5 @@ __all__ = [
"MessageConfig",
"AgentNodeConfig",
"TransformNodeConfig",
"IfElseNodeConfig",
]

View File

@@ -5,7 +5,8 @@ End 节点实现
"""
import logging
from typing import Any
import re
import asyncio
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
@@ -16,6 +17,7 @@ class EndNode(BaseNode):
"""End 节点
工作流的结束节点,根据配置的模板输出最终结果。
支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。
"""
async def execute(self, state: WorkflowState) -> str:
@@ -31,11 +33,7 @@ class EndNode(BaseNode):
# 获取配置的输出模板
output_template = self.config.get("output")
pool = self.get_variable_pool(state)
print("="*20)
print( pool.get("start.test"))
print("="*20)
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template:
output = self._render_template(output_template, state)
@@ -47,7 +45,228 @@ class EndNode(BaseNode):
total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
print("="*20)
print(output)
print("="*20)
return output
def _extract_referenced_nodes(self, template: str) -> list[str]:
"""从模板中提取引用的节点 ID
例如:'结果:{{llm_qa.output}}' -> ['llm_qa']
Args:
template: 模板字符串
Returns:
引用的节点 ID 列表
"""
# 匹配 {{node_id.xxx}} 格式
pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
matches = re.findall(pattern, template)
return list(set(matches)) # 去重
def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]:
"""解析模板,分离静态文本和动态引用
例如:'你好 {{llm.output}}, 这是后缀'
返回:[
{"type": "static", "content": "你好 "},
{"type": "dynamic", "node_id": "llm", "field": "output"},
{"type": "static", "content": ", 这是后缀"}
]
Args:
template: 模板字符串
state: 工作流状态
Returns:
模板部分列表
"""
import re
parts = []
last_end = 0
# 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格)
pattern = r'\{\{\s*([^}]+?)\s*\}\}'
for match in re.finditer(pattern, template):
start, end = match.span()
# 添加前面的静态文本
if start > last_end:
static_text = template[last_end:start]
if static_text:
parts.append({"type": "static", "content": static_text})
# 解析动态引用
ref = match.group(1).strip()
# 检查是否是节点引用(如 llm.output 或 llm_qa.output
if '.' in ref:
node_id, field = ref.split('.', 1)
parts.append({
"type": "dynamic",
"node_id": node_id,
"field": field,
"raw": ref
})
else:
# 其他引用(如 {{var.xxx}}),当作静态处理
# 直接渲染这部分
rendered = self._render_template(f"{{{{{ref}}}}}", state)
parts.append({"type": "static", "content": rendered})
last_end = end
# 添加最后的静态文本
if last_end < len(template):
static_text = template[last_end:]
if static_text:
parts.append({"type": "static", "content": static_text})
return parts
async def execute_stream(self, state: WorkflowState):
    """Run the End node in streaming mode.

    Output strategy implemented below:
    1. Parse the output template and look for the first placeholder that
       references a direct upstream node (a source of an edge targeting
       this node).
    2. If there is none, render the whole template and send it as a single
       "message" chunk through the LangGraph stream writer.
    3. If there is one, send only the parts AFTER that placeholder (the
       suffix) as a single chunk. NOTE(review): this relies on the prefix
       and the referenced content having already been streamed by the
       upstream node; that happens outside this method — confirm.

    Example: '{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
    - with llm_qa as the direct upstream node, only ' lalalalala a' is
      sent from here.

    Args:
        state: Workflow state.

    Yields:
        A single completion marker ``{"__final__": True, "result": ...}``
        whose result is the fully rendered template (or a default message
        when no template is configured).
    """
    logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")

    # Read the configured output template
    output_template = self.config.get("output")

    if not output_template:
        # No template configured: finish with a fixed default message
        output = "工作流已完成"
        yield {"__final__": True, "result": output}
        return

    # Collect direct upstream nodes: sources of edges that target this node
    direct_upstream_nodes = []
    for edge in self.workflow_config.get("edges", []):
        if edge.get("target") == self.node_id:
            source_node_id = edge.get("source")
            direct_upstream_nodes.append(source_node_id)

    logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}")

    # Split the template into static / dynamic parts
    parts = self._parse_template_parts(output_template, state)
    logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")

    # Find the first dynamic part that references a direct upstream node
    upstream_ref_index = None
    for i, part in enumerate(parts):
        if part["type"] == "dynamic" and part["node_id"] in direct_upstream_nodes:
            upstream_ref_index = i
            logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}")
            break

    if upstream_ref_index is None:
        # No direct upstream node is referenced: emit the fully rendered template
        output = self._render_template(output_template, state)
        logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容: '{output[:50]}...'")

        # Send the complete content through the writer as one message chunk
        from langgraph.config import get_stream_writer
        writer = get_stream_writer()
        writer({
            "type": "message",  # End node output uses the "message" type
            "node_id": self.node_id,
            "chunk": output,
            "full_content": output,
            "chunk_index": 1,
            "is_suffix": False
        })

        logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")

        # Yield the completion marker
        yield {"__final__": True, "result": output}
        return

    # A direct upstream node is referenced: only the parts after that
    # reference (the suffix) are emitted from here
    logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)")

    # Collect the suffix parts
    suffix_parts = []
    for i in range(upstream_ref_index + 1, len(parts)):
        part = parts[i]
        if part["type"] == "static":
            # Static text is appended verbatim
            suffix_parts.append(part["content"])
        elif part["type"] == "dynamic":
            # Further dynamic references appearing after the first upstream one
            node_id = part["node_id"]
            field = part["field"]

            # Prefer streaming_buffer, then node_outputs, then runtime_vars.
            # NOTE(review): the streaming_buffer path always uses "full_content"
            # and ignores `field`.
            streaming_buffer = state.get("streaming_buffer", {})
            if node_id in streaming_buffer:
                buffer_data = streaming_buffer[node_id]
                content = buffer_data.get("full_content", "")
            else:
                node_outputs = state.get("node_outputs", {})
                runtime_vars = state.get("runtime_vars", {})

                # Falls back to an empty string when the node/field is not found
                content = ""
                if node_id in node_outputs:
                    node_output = node_outputs[node_id]
                    if isinstance(node_output, dict):
                        content = str(node_output.get(field, ""))
                elif node_id in runtime_vars:
                    runtime_var = runtime_vars[node_id]
                    if isinstance(runtime_var, dict):
                        content = str(runtime_var.get(field, ""))

            suffix_parts.append(content)

    # Join the suffix
    suffix = "".join(suffix_parts)

    # Build the complete output (prefix + dynamic content + suffix) for the final result
    full_output = self._render_template(output_template, state)

    if suffix:
        logger.info(f"节点 {self.node_id} 输出后缀: '{suffix[:50]}...' (长度: {len(suffix)})")

        # Emit the suffix in one piece through the writer rather than
        # yielding the string. NOTE(review): the original comment says the
        # base node would otherwise process a yielded string
        # character by character — confirm against BaseNode.run_stream.
        from langgraph.config import get_stream_writer
        writer = get_stream_writer()
        writer({
            "type": "message",  # End node output uses the "message" type
            "node_id": self.node_id,
            "chunk": suffix,
            "full_content": full_output,  # full_content is the complete rendered result (prefix + LLM + suffix)
            "chunk_index": 1,
            "is_suffix": True
        })

        logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀full_content 长度: {len(full_output)}")
    else:
        logger.info(f"节点 {self.node_id} 没有后缀需要输出")

    # Statistics for the completion log line
    node_outputs = state.get("node_outputs", {})
    total_nodes = len(node_outputs)

    logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行了 {total_nodes} 个节点")

    # Yield the completion marker carrying the complete output
    yield {"__final__": True, "result": full_output}

View File

@@ -1,5 +1,6 @@
from enum import StrEnum
class NodeType(StrEnum):
START = "start"
END = "end"
@@ -13,3 +14,23 @@ class NodeType(StrEnum):
HTTP_REQUEST = "http-request"
TOOL = "tool"
AGENT = "agent"
class ComparisonOperator(StrEnum):
    """String-valued comparison operators usable in a single condition."""

    # Emptiness checks
    EMPTY = "empty"
    NOT_EMPTY = "not_empty"
    # Membership checks
    CONTAINS = "contains"
    NOT_CONTAINS = "not_contains"
    # Prefix / suffix checks (note: the values contain no underscore)
    START_WITH = "startwith"
    END_WITH = "endwith"
    # Equality / inequality
    EQ = "eq"
    NE = "ne"
    # Ordering comparisons
    LT = "lt"
    LE = "le"
    GT = "gt"
    GE = "ge"
class LogicOperator(StrEnum):
    """String-valued logical operators for combining several conditions."""

    AND = "and"
    OR = "or"

View File

@@ -0,0 +1,5 @@
"""Condition Node"""
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
from app.core.workflow.nodes.if_else.node import IfElseNode
__all__ = ["IfElseNode", "IfElseNodeConfig"]

View File

@@ -0,0 +1,97 @@
"""Condition Configuration"""
from pydantic import Field, BaseModel, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
class ConditionDetail(BaseModel):
    """A single comparison between a left and a right operand."""

    # Operator applied to `left` and `right`.
    comparison_operator: ComparisonOperator = Field(
        ...,
        description="Comparison operator used to evaluate the condition"
    )

    # Left-hand operand, stored as a string.
    left: str = Field(
        ...,
        description="Value to compare against"
    )

    # Right-hand operand, stored as a string. It is a required field for
    # every operator, including the emptiness checks.
    right: str = Field(
        ...,
        description="Value to compare with"
    )
class ConditionBranchConfig(BaseModel):
    """Configuration for a conditional branch"""

    # How the entries of `conditions` are combined; defaults to "and".
    logical_operator: LogicOperator = Field(
        default=LogicOperator.AND.value,
        description="Logical operator used to combine multiple condition expressions"
    )

    # The individual comparisons that make up this branch (required).
    conditions: list[ConditionDetail] = Field(
        ...,
        description="List of condition expressions within this branch"
    )
class IfElseNodeConfig(BaseNodeConfig):
    """Configuration of an If-Else node: an ordered list of case branches.

    Each entry of ``cases`` is one branch (IF, then ELIF ...). The schema
    example below lists only the explicit branches.
    """

    cases: list[ConditionBranchConfig] = Field(
        ...,
        description="List of branch conditions or expressions"
    )

    @field_validator("cases")
    @classmethod
    def validate_case_number(cls, v, info):
        """Reject an empty ``cases`` list.

        Raises:
            ValueError: If no case branch is configured.
        """
        if len(v) < 1:
            raise ValueError("At least one case is required")
        return v

    class Config:
        # Each branch's "conditions" is a flat list of condition objects,
        # matching ``ConditionBranchConfig.conditions: list[ConditionDetail]``.
        json_schema_extra = {
            "examples": [
                {
                    "cases": [
                        # CASE1 / IF Branch
                        {
                            "logical_operator": "and",
                            "conditions": [
                                {
                                    "left": "node.userinput.message",
                                    "comparison_operator": "eq",
                                    "right": "'123'"
                                },
                                {
                                    "left": "node.userinput.test",
                                    "comparison_operator": "eq",
                                    "right": "True"
                                }
                            ]
                        },
                        # CASE2 / ELIF Branch
                        {
                            "logical_operator": "or",
                            "conditions": [
                                {
                                    "left": "node.userinput.test",
                                    "comparison_operator": "eq",
                                    "right": "False"
                                },
                                {
                                    "left": "node.userinput.message",
                                    "comparison_operator": "contains",
                                    "right": "'123'"
                                }
                            ]
                        }
                        # CASE3 / ELSE Branch
                    ]
                }
            ]
        }

View File

@@ -0,0 +1,167 @@
import logging
from typing import Any
from app.core.workflow.nodes import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import ComparisonOperator
from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.if_else.config import ConditionDetail
logger = logging.getLogger(__name__)
class ConditionExpressionBuilder:
    """
    Render one comparison as a Python boolean expression string.

    Nothing is evaluated here: ``left`` and ``right`` are spliced into the
    expression text verbatim, and the resulting string is meant to be
    evaluated later in a workflow context.
    """

    # Ordered (operator, expression template) pairs; ``build`` picks the
    # first pair whose operator equals ``self.operator``.
    _TEMPLATES = (
        (ComparisonOperator.EMPTY, "{left} == ''"),
        (ComparisonOperator.NOT_EMPTY, "{left} != ''"),
        (ComparisonOperator.CONTAINS, "{right} in {left}"),
        (ComparisonOperator.NOT_CONTAINS, "{right} not in {left}"),
        (ComparisonOperator.START_WITH, "{left}.startswith({right})"),
        (ComparisonOperator.END_WITH, "{left}.endswith({right})"),
        (ComparisonOperator.EQ, "{left} == {right}"),
        (ComparisonOperator.NE, "{left} != {right}"),
        (ComparisonOperator.LT, "{left} < {right}"),
        (ComparisonOperator.LE, "{left} <= {right}"),
        (ComparisonOperator.GT, "{left} > {right}"),
        (ComparisonOperator.GE, "{left} >= {right}"),
    )

    def __init__(self, left: str, operator: ComparisonOperator, right: str):
        self.left = left
        self.operator = operator
        self.right = right

    def build(self):
        """Return the expression string for the configured operator.

        Raises:
            ValueError: If the operator matches none of the known operators.
        """
        for candidate, template in self._TEMPLATES:
            if self.operator == candidate:
                return template.format(left=self.left, right=self.right)
        raise ValueError(f"Invalid condition: {self.operator}")
class IfElseNode(BaseNode):
    """If-Else node: evaluates its case branches in order and reports the first match."""

    def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
        super().__init__(node_config, workflow_config)
        # Validated, typed view of the raw node config
        self.typed_config = IfElseNodeConfig(**self.config)

    @staticmethod
    def _build_condition_expression(
            condition: ConditionDetail,
    ) -> str:
        """
        Render a single condition as a Python boolean expression string.

        The condition is not evaluated here; only its textual form
        (e.g. "x > 10", "'a' in name") is produced.

        Args:
            condition (ConditionDetail): Definition of a single comparison condition.

        Returns:
            str: A Python boolean expression string.
        """
        builder = ConditionExpressionBuilder(
            left=condition.left,
            operator=condition.comparison_operator,
            right=condition.right
        )
        return builder.build()

    def build_conditional_edge_expressions(self) -> list[str]:
        """
        Render every case branch as one Python boolean expression string.

        Nothing is evaluated here. Within a branch, multiple conditions are
        joined with that branch's logical operator. The returned list keeps
        the configured branch order and ends with a literal 'True', which
        acts as the default branch when no earlier expression matches.

        Returns:
            list[str]: Expression strings ordered by branch priority.
        """
        expressions = []

        for branch in self.typed_config.cases:
            rendered = [
                self._build_condition_expression(item)
                for item in branch.conditions
            ]
            if len(rendered) > 1:
                glue = f' {branch.logical_operator} '
                expressions.append(glue.join(rendered))
            else:
                expressions.append(rendered[0])

        # Default fallback branch
        expressions.append("True")

        return expressions

    async def execute(self, state: WorkflowState) -> Any:
        """
        Evaluate the branch expressions in order and name the first match.

        Args:
            state: Workflow state passed to ``_evaluate_condition``.

        Returns:
            'CASE<n>' where n is the 1-based position of the first expression
            that evaluates truthy; if none does, the position of the last
            expression is used.
        """
        expressions = self.build_conditional_edge_expressions()
        for position, expression in enumerate(expressions, start=1):
            logger.info(expression)
            if self._evaluate_condition(expression, state):
                return f'CASE{position}'
        return f'CASE{len(expressions)}'

View File

@@ -10,10 +10,8 @@ from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models import ModelConfig
from app.db import get_db, get_db_context
from app.models.models_model import ModelApiKey
from app.services.model_service import ModelConfigService, ModelApiKeyService
from app.db import get_db_context
from app.services.model_service import ModelConfigService
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
@@ -65,7 +63,7 @@ class LLMNode(BaseNode):
- ai/assistant: AI 消息AIMessage
"""
def _prepare_llm(self, state: WorkflowState) -> tuple[RedBearLLM, list | str]:
def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]:
"""准备 LLM 实例(公共逻辑)
Args:
@@ -127,16 +125,22 @@ class LLMNode(BaseNode):
model_type = config.type
# 4. 创建 LLM 实例(使用已提取的数据)
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
extra_params = {"streaming": stream} if stream else {}
llm = RedBearLLM(
RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base
base_url=api_base,
extra_params=extra_params
),
type=model_type
)
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
return llm, prompt_or_messages
async def execute(self, state: WorkflowState) -> AIMessage:
@@ -148,13 +152,12 @@ class LLMNode(BaseNode):
Returns:
LLM 响应消息
"""
llm, prompt_or_messages = self._prepare_llm(state)
llm, prompt_or_messages = self._prepare_llm(state,True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
# 调用 LLM支持字符串或消息列表
response = await llm.ainvoke(prompt_or_messages)
# 提取内容
if hasattr(response, 'content'):
content = response.content
@@ -210,13 +213,43 @@ class LLMNode(BaseNode):
Yields:
文本片段chunk或完成标记
"""
llm, prompt_or_messages = self._prepare_llm(state)
from langgraph.config import get_stream_writer
llm, prompt_or_messages = self._prepare_llm(state, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
# 检查是否有注入的 End 节点前缀配置
writer = get_stream_writer()
end_prefix = getattr(self, '_end_node_prefix', None)
logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}")
if end_prefix:
logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'")
if end_prefix:
# 渲染前缀(可能包含其他变量)
try:
rendered_prefix = self._render_template(end_prefix, state)
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
# 提前发送 End 节点的前缀(使用 "message" 类型)
writer({
"type": "message", # End 相关的内容都是 message 类型
"node_id": "end", # 标记为 end 节点的输出
"chunk": rendered_prefix,
"full_content": rendered_prefix,
"chunk_index": 0,
"is_prefix": True # 标记这是前缀
})
except Exception as e:
logger.warning(f"渲染/发送 End 节点前缀失败: {e}")
# 累积完整响应
full_response = ""
last_chunk = None
chunk_count = 0
# 调用 LLM流式支持字符串或消息列表
async for chunk in llm.astream(prompt_or_messages):
@@ -226,13 +259,16 @@ class LLMNode(BaseNode):
else:
content = str(chunk)
full_response += content
last_chunk = chunk
# 流式返回每个文本片段
yield content
# 只有当内容不为空时才处理
if content:
full_response += content
last_chunk = chunk
chunk_count += 1
# 流式返回每个文本片段
yield content
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}")
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
# 构建完整的 AIMessage包含元数据
if isinstance(last_chunk, AIMessage):

View File

@@ -5,18 +5,29 @@
"""
import logging
from typing import Any
from typing import Any, Union
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.end import EndNode
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.nodes.if_else import IfElseNode
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode
logger = logging.getLogger(__name__)
WorkflowNode = Union[
BaseNode,
StartNode,
EndNode,
LLMNode,
IfElseNode,
AgentNode,
TransformNode,
]
class NodeFactory:
"""节点工厂
@@ -25,16 +36,17 @@ class NodeFactory:
"""
# 节点类型注册表
_node_types: dict[str, type[BaseNode]] = {
_node_types: dict[str, type[WorkflowNode]] = {
NodeType.START: StartNode,
NodeType.END: EndNode,
NodeType.LLM: LLMNode,
NodeType.AGENT: AgentNode,
NodeType.TRANSFORM: TransformNode,
NodeType.IF_ELSE: IfElseNode
}
@classmethod
def register_node_type(cls, node_type: str, node_class: type[BaseNode]):
def register_node_type(cls, node_type: str, node_class: type[WorkflowNode]):
"""注册新的节点类型
Args:
@@ -52,10 +64,10 @@ class NodeFactory:
@classmethod
def create_node(
cls,
node_config: dict[str, Any],
workflow_config: dict[str, Any]
) -> BaseNode | None:
cls,
node_config: dict[str, Any],
workflow_config: dict[str, Any]
) -> WorkflowNode | None:
"""创建节点实例
Args:

View File

@@ -20,6 +20,11 @@ from .data_config_model import DataConfig
from .multi_agent_model import MultiAgentConfig, AgentInvocation
from .workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution
from .retrieval_info import RetrievalInfo
from .prompt_optimizer_model import PromptOptimizerSession, PromptOptimizerSessionHistory
from .tool_model import (
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus
)
__all__ = [
"Tenants",
@@ -54,5 +59,17 @@ __all__ = [
"WorkflowConfig",
"WorkflowExecution",
"WorkflowNodeExecution",
"RetrievalInfo"
"RetrievalInfo",
"PromptOptimizerSession",
"PromptOptimizerSessionHistory",
"RetrievalInfo",
"ToolConfig",
"BuiltinToolConfig",
"CustomToolConfig",
"MCPToolConfig",
"ToolExecution",
"ToolType",
"ToolStatus",
"AuthType",
"ExecutionStatus"
]

View File

@@ -1,5 +1,4 @@
import datetime
import uuid
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float
from sqlalchemy.dialects.postgresql import UUID
from app.db import Base
@@ -11,50 +10,53 @@ class DataConfig(Base):
# 主键
config_id = Column(Integer, primary_key=True, autoincrement=True, comment="配置ID")
# 基本信息
config_name = Column(String, nullable=False, comment="配置名称")
config_desc = Column(String, nullable=True, comment="配置描述")
# 组织信息
workspace_id = Column(UUID(as_uuid=True), nullable=True, comment="工作空间ID")
group_id = Column(String, nullable=True, comment="组ID")
user_id = Column(String, nullable=True, comment="用户ID")
apply_id = Column(String, nullable=True, comment="应用ID")
# 模型选择从workspace继承
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
llm = Column(String, nullable=True, comment="LLM模型配置ID")
# 记忆萃取引擎配置
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
enable_llm_disambiguation = Column(Boolean, default=True, comment="启用LLM决策消歧")
deep_retrieval = Column(Boolean, default=True, comment="深度检索开关")
# 阈值配置 (0-1 之间的浮点数)
t_type_strict = Column(Float, default=0.8, comment="类型严格阈值")
t_name_strict = Column(Float, default=0.8, comment="名称严格阈值")
t_overall = Column(Float, default=0.8, comment="综合阈值")
# 状态配置
state = Column(Boolean, default=False, comment="配置使用状态")
# 分块策略
chunker_strategy = Column(String, default="RecursiveChunker", comment="分块策略")
# 剪枝配置
pruning_enabled = Column(Boolean, default=False, comment="是否启动智能语义剪枝")
pruning_scene = Column(String, nullable=True, comment="智能剪枝场景education/online_service/outbound")
pruning_threshold = Column(Float, nullable=True, comment="智能语义剪枝阈值0-0.9")
# 自我反思配置
enable_self_reflexion = Column(Boolean, default=False, comment="是否启用自我反思")
iteration_period = Column(String, default="3", comment="反思迭代周期")
reflexion_range = Column(String, default="retrieval", comment="反思范围:部分/全部")
baseline = Column(String, default="time", comment="基线:时间/事实/时间和事实")
reflection_model_id = Column(String, nullable=True, comment="反思模型ID")
memory_verify = Column(Boolean, default=True, comment="记忆验证")
quality_assessment = Column(Boolean, default=True, comment="质量评估")
# 遗忘引擎配置
statement_granularity = Column(Integer, default=2, comment="陈述提取颗粒度,挡位 1/2/3")
include_dialogue_context = Column(Boolean, default=False, comment="是否包含对话上下文")
@@ -63,6 +65,13 @@ class DataConfig(Base):
lambda_mem = Column("lambda_mem", Float, default=0.5, comment="遗忘率0-1 小数")
offset = Column("offset", Float, default=0.0, comment="偏移度0-1 小数")
# 情绪引擎配置
emotion_enabled = Column(Boolean, default=True, comment="是否启用情绪提取")
emotion_model_id = Column(String, nullable=True, comment="情绪分析专用模型ID")
emotion_extract_keywords = Column(Boolean, default=True, comment="是否提取情绪关键词")
emotion_min_intensity = Column(Float, default=0.1, comment="最小情绪强度阈值")
emotion_enable_subject = Column(Boolean, default=True, comment="是否启用主体分类")
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")

View File

@@ -16,7 +16,26 @@ class Document(Base):
file_size = Column(Integer, default=0, comment="file size(byte)")
file_meta = Column(JSON, nullable=False, default={})
parser_id = Column(String, index=True, nullable=False, comment="default parser ID")
parser_config = Column(JSON, nullable=False, default={"layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n"}, comment="default parser config")
parser_config = Column(JSON, nullable=False,
default={
"layout_recognize": "DeepDOC",
"chunk_token_num": 128,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": False,
"graphrag": {
"use_graphrag": False,
"entity_types": [
"organization",
"person",
"geo",
"event",
"category",
],
"method": "general",
}
}, comment="default parser config")
chunk_num = Column(Integer, default=0, comment="chunk num")
progress = Column(Float, default=0)
progress_msg = Column(String, default="", comment="process message")

View File

@@ -14,6 +14,7 @@ class EndUser(Base):
other_id = Column(String, nullable=True) # Store original user_id
other_name = Column(String, default="", nullable=False)
other_address = Column(String, default="", nullable=False)
reflection_time = Column(DateTime, nullable=True)
created_at = Column(DateTime, default=datetime.datetime.now)
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)

View File

@@ -56,7 +56,25 @@ class Knowledge(Base):
chunk_num = Column(Integer, default=0, comment="chunk num")
parser_id = Column(String, index=True, default="naive", comment="default parser ID")
parser_config = Column(JSON, nullable=False,
default={"layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n"},
default={
"layout_recognize": "DeepDOC",
"chunk_token_num": 128,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": False,
"graphrag": {
"use_graphrag": False,
"entity_types": [
"organization",
"person",
"geo",
"event",
"category",
],
"method": "general",
}
},
comment="default parser config")
status = Column(Integer, index=True, default=1, comment="is it validate(0: disable, 1: enable, 2:Soft-delete)")
created_at = Column(DateTime, default=datetime.datetime.now)

View File

@@ -15,6 +15,25 @@ class ModelType(StrEnum):
EMBEDDING = "embedding"
RERANK = "rerank"
@classmethod
def from_str(cls, value: str) -> "ModelType":
    """
    Get a ModelType enum instance from a string value.

    Args:
        value (str): The string representation of the model type.

    Returns:
        ModelType: The corresponding ModelType enum object.

    Raises:
        ValueError: If the given value does not match any ModelType. The
            enum's own lookup error is attached as ``__cause__``.
    """
    try:
        return cls(value)
    except ValueError as exc:
        # Chain explicitly so the traceback reads as a deliberate
        # translation of the lookup failure, not an error in the handler.
        raise ValueError(f"Invalid ModelType: {value}") from exc
class ModelProvider(StrEnum):
"""模型提供商枚举"""

View File

@@ -0,0 +1,130 @@
import datetime
import uuid
from enum import StrEnum
from sqlalchemy import Column, ForeignKey, Text, DateTime, String, Index
from sqlalchemy.dialects.postgresql import UUID
from app.db import Base
class RoleType(StrEnum):
    """
    Enumeration of message roles used in prompt optimization conversations.

    Each member is a ``StrEnum`` value, so it compares equal to its
    lowercase string form.

    Attributes:
        SYSTEM (str): ``"system"`` - system-level instructions or prompts.
        USER (str): ``"user"`` - messages originating from the end user.
        ASSISTANT (str): ``"assistant"`` - messages generated by the AI
            assistant.
    """
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
class PromptOptimizerSession(Base):
    """
    Prompt Optimization Session Registry.

    One row per prompt optimization session. The row carries only
    session-level metadata; messages are stored in
    `prompt_opt_session_history`.

    Table Name:
        prompt_opt_session_list

    Columns:
        id (UUID):
            Primary key; generated with `uuid.uuid4` when not supplied.
        tenant_id (UUID):
            Required foreign key referencing `tenants.id`.
        user_id (UUID):
            Required foreign key referencing `users.id`.
        created_at (DateTime):
            Defaults to `datetime.datetime.now` (naive timestamp); indexed.

    Design Notes:
        - This table intentionally does not store message content
        - The `app_id` column is currently commented out and therefore
          not part of the table
    """
    __tablename__ = "prompt_opt_session_list"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="Session ID")
    tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, comment="Tenant ID")
    # app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False, comment="Application ID")
    user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, comment="User ID")
    created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation Time", index=True)
class PromptOptimizerSessionHistory(Base):
    """
    Prompt Optimization Session Message History.

    One row per message exchanged inside a prompt optimization session.

    Table Name:
        prompt_opt_session_history

    Columns:
        id (UUID):
            Primary key; generated with `uuid.uuid4` when not supplied.
        tenant_id (UUID):
            Required foreign key referencing `tenants.id`.
        session_id (UUID):
            Required foreign key referencing `prompt_opt_session_list.id`;
            links the message to its session.
        user_id (UUID):
            Required foreign key referencing `users.id`.
        role (String):
            Role of the message sender. Presumably one of the `RoleType`
            values (system, user, assistant); the column itself is a plain
            string and does not enforce that.
        content (Text):
            Message content.
        created_at (DateTime):
            Defaults to `datetime.datetime.now` (naive timestamp); indexed.

    Indexes:
        - `ix_prompt_opt_session_history_session_user_created` on
          (`session_id`, `user_id`, `created_at`)
        - Single-column indexes on `id` and `created_at`

    Design Notes:
        - The `app_id` and `prompt` columns are currently commented out
          and therefore not part of the table
    """
    __tablename__ = "prompt_opt_session_history"
    __table_args__ = (
        Index(
            "ix_prompt_opt_session_history_session_user_created",
            "session_id",
            "user_id",
            "created_at"
        ),
    )

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
    tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, comment="Tenant ID")
    # app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False, comment="Application ID")
    session_id = Column(UUID(as_uuid=True), ForeignKey("prompt_opt_session_list.id"),nullable=False, comment="Session ID")
    user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, comment="User ID")
    role = Column(String, nullable=False, comment="Message Role")
    content = Column(Text, nullable=False, comment="Message Content")
    # prompt = Column(Text, nullable=False, comment="Prompt")
    created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation Time", index=True)

View File

@@ -21,3 +21,6 @@ class Tenants(Base):
# Relationship to workspaces owned by the tenant
owned_workspaces = relationship("Workspace", back_populates="tenant")
# Relationship to tool configs owned by the tenant
tool_configs = relationship("ToolConfig", back_populates="tenant")

View File

@@ -0,0 +1,226 @@
"""工具管理相关数据模型"""
import uuid
from datetime import datetime
from enum import StrEnum
from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from app.db import Base
class ToolType(StrEnum):
    """Kinds of tool: built-in, custom, or MCP."""
    BUILTIN = "builtin"
    CUSTOM = "custom"
    MCP = "mcp"
class ToolStatus(StrEnum):
    """Status values a tool can be in."""
    ACTIVE = "active"
    INACTIVE = "inactive"
    ERROR = "error"
    LOADING = "loading"
class AuthType(StrEnum):
    """Authentication schemes: none, API key, or bearer token."""
    NONE = "none"
    API_KEY = "api_key"
    BEARER_TOKEN = "bearer_token"
class ExecutionStatus(StrEnum):
    """Status values of a single tool execution."""
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    TIMEOUT = "timeout"
class ToolConfig(Base):
    """Base tool configuration model (table ``tool_configs``).

    Holds the fields shared by every tool. The type-specific tables
    ``builtin_tool_configs``, ``custom_tool_configs`` and
    ``mcp_tool_configs`` use this table's ``id`` as their own primary key.
    """
    __tablename__ = "tool_configs"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    name = Column(String(255), nullable=False, index=True)
    description = Column(Text)
    # Stored as a plain string; presumably one of the ToolType values (not enforced here).
    tool_type = Column(String(50), nullable=False, index=True)
    tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, index=True)  # must belong to a tenant
    status = Column(String(50), default=ToolStatus.INACTIVE.value, nullable=False, index=True)  # tool status, defaults to "inactive"

    # Tool-specific configuration, stored as JSON
    config_data = Column(JSON, default=dict)

    # Metadata
    version = Column(String(50), default="1.0.0")
    tags = Column(JSON, default=list)  # list of tags

    # Timestamps (naive datetime.now values)
    created_at = Column(DateTime, default=datetime.now, nullable=False)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False)

    # Relationships
    tenant = relationship("Tenants", back_populates="tool_configs")
    # Execution records are deleted together with the tool config (delete-orphan cascade).
    executions = relationship("ToolExecution", back_populates="tool_config", cascade="all, delete-orphan")

    def __repr__(self):
        return f"<ToolConfig(id={self.id}, name={self.name}, type={self.tool_type}, status={self.status})>"
class BuiltinToolConfig(Base):
    """Built-in tool configuration model (table ``builtin_tool_configs``)."""
    __tablename__ = "builtin_tool_configs"

    # Primary key doubles as the foreign key to the shared tool_configs row.
    id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True)
    tool_class = Column(String(255), nullable=False)  # tool class name
    parameters = Column(JSON, default=dict)  # tool parameter configuration

    # Relationship to the shared ToolConfig row
    base_config = relationship("ToolConfig", foreign_keys=[id])

    def __repr__(self):
        return f"<BuiltinToolConfig(id={self.id}, tool_class={self.tool_class})>"
class CustomToolConfig(Base):
    """Custom tool configuration model (table ``custom_tool_configs``)."""
    __tablename__ = "custom_tool_configs"

    # Primary key doubles as the foreign key to the shared tool_configs row.
    id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True)
    schema_url = Column(String(1000))  # OpenAPI schema URL
    schema_content = Column(JSON)  # OpenAPI schema content

    # Authentication settings
    auth_type = Column(String(50), default=AuthType.NONE.value, nullable=False)
    # NOTE(review): the original comment says this is stored encrypted; no
    # encryption is performed in this model - confirm where it happens.
    auth_config = Column(JSON, default=dict)

    # API settings
    base_url = Column(String(1000))  # API base URL
    timeout = Column(Integer, default=30)  # timeout in seconds

    # Relationship to the shared ToolConfig row
    base_config = relationship("ToolConfig", foreign_keys=[id])

    def __repr__(self):
        return f"<CustomToolConfig(id={self.id}, auth_type={self.auth_type})>"
class MCPToolConfig(Base):
    """MCP tool configuration model (table ``mcp_tool_configs``)."""
    __tablename__ = "mcp_tool_configs"

    # Primary key doubles as the foreign key to the shared tool_configs row.
    id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True)
    server_url = Column(String(1000), nullable=False)  # MCP server URL
    connection_config = Column(JSON, default=dict)  # connection settings

    # Service health
    last_health_check = Column(DateTime)
    health_status = Column(String(50), default="unknown")
    error_message = Column(Text)

    # List of available tools
    available_tools = Column(JSON, default=list)

    # Relationship to the shared ToolConfig row
    base_config = relationship("ToolConfig", foreign_keys=[id])

    def __repr__(self):
        return f"<MCPToolConfig(id={self.id}, server_url={self.server_url})>"
class ToolExecution(Base):
    """Tool execution record model (table ``tool_executions``)."""
    __tablename__ = "tool_executions"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tool_config_id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), nullable=False, index=True)

    # Execution info
    execution_id = Column(String(255), nullable=False, index=True)  # execution ID, usable for linking to workflows etc.
    status = Column(String(50), default=ExecutionStatus.PENDING.value, nullable=False, index=True)

    # Input / output
    input_data = Column(JSON)  # input parameters
    output_data = Column(JSON)  # output result
    error_message = Column(Text)  # error message

    # Performance metrics
    started_at = Column(DateTime, nullable=False, index=True)  # no default: the caller must supply it
    completed_at = Column(DateTime)
    execution_time = Column(Float)  # execution time in seconds

    # Token usage, if applicable
    token_usage = Column(JSON)

    # User info (user_id is optional, workspace_id is required)
    user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), index=True)
    workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id"), nullable=False, index=True)

    # Relationships
    tool_config = relationship("ToolConfig", back_populates="executions")
    user = relationship("User")
    workspace = relationship("Workspace")

    def __repr__(self):
        return f"<ToolExecution(id={self.id}, status={self.status}, tool={self.tool_config_id})>"
# class ToolDependency(Base):
# """工具依赖关系模型"""
# __tablename__ = "tool_dependencies"
#
# id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# tool_id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), nullable=False)
# depends_on_tool_id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), nullable=False)
#
# # 依赖类型和版本要求
# dependency_type = Column(String(50), default="required") # required, optional
# version_constraint = Column(String(100)) # 版本约束,如 ">=1.0.0"
#
# # 时间戳
# created_at = Column(DateTime, default=datetime.now, nullable=False)
#
# # 关联关系
# tool = relationship("ToolConfig", foreign_keys=[tool_id])
# depends_on_tool = relationship("ToolConfig", foreign_keys=[depends_on_tool_id])
#
# def __repr__(self):
# return f"<ToolDependency(tool={self.tool_id}, depends_on={self.depends_on_tool_id})>"
# class PluginConfig(Base):
# """插件配置模型"""
# __tablename__ = "plugin_configs"
#
# id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# name = Column(String(255), nullable=False, unique=True, index=True)
# description = Column(Text)
#
# # 插件信息
# plugin_path = Column(String(1000), nullable=False) # 插件文件路径
# entry_point = Column(String(255), nullable=False) # 入口点
# version = Column(String(50), default="1.0.0")
#
# # 状态
# is_enabled = Column(Boolean, default=True, nullable=False)
# is_loaded = Column(Boolean, default=False, nullable=False)
# load_error = Column(Text) # 加载错误信息
#
# # 配置
# config_schema = Column(JSON) # 配置schema
# config_data = Column(JSON, default=dict) # 配置数据
#
# # 依赖
# dependencies = Column(JSON, default=list) # 依赖的其他插件
#
# # 时间戳
# created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
# updated_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
# last_loaded_at = Column(DateTime)
#
# def __repr__(self):
# return f"<PluginConfig(id={self.id}, name={self.name}, version={self.version})>"

View File

@@ -1,7 +1,7 @@
import datetime
from enum import StrEnum
import uuid
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, Boolean
from sqlalchemy import Column, String, DateTime, ForeignKey, Boolean
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from app.db import Base

View File

@@ -16,7 +16,6 @@ from app.models.data_config_model import DataConfig
from app.schemas.memory_storage_schema import (
ConfigKey,
ConfigParamsCreate,
ConfigParamsDelete,
ConfigUpdate,
ConfigUpdateExtracted,
ConfigUpdateForget,
@@ -29,37 +28,37 @@ db_logger = get_db_logger()
# 获取配置专用日志器
config_logger = get_config_logger()
TABLE_NAME = "data_config"
class DataConfigRepository:
"""数据配置Repository
提供data_config表的数据访问方法包括
- SQLAlchemy ORM 数据库操作
- Neo4j Cypher查询常量
"""
# ==================== Neo4j Cypher 查询常量 ====================
# Dialogue count by group
SEARCH_FOR_DIALOGUE = """
MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
"""
# Chunk count by group
SEARCH_FOR_CHUNK = """
MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
"""
# Statement count by group
SEARCH_FOR_STATEMENT = """
MATCH (n:Statement) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
"""
# ExtractedEntity count by group
SEARCH_FOR_ENTITY = """
MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN COUNT(n) AS num
"""
# All counts by label and total
SEARCH_FOR_ALL = """
OPTIONAL MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN 'Dialogue' AS Label, COUNT(n) AS Count
@@ -72,7 +71,7 @@ class DataConfigRepository:
UNION ALL
OPTIONAL MATCH (n) WHERE n.group_id = $group_id RETURN 'ALL' AS Label, COUNT(n) AS Count
"""
# Extracted entity details within group/app/user
SEARCH_FOR_DETIALS = """
MATCH (n:ExtractedEntity)
@@ -88,7 +87,7 @@ class DataConfigRepository:
n.user_id AS user_id,
n.id AS id
"""
# Edges between extracted entities within group/app/user
SEARCH_FOR_EDGES = """
MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity)
@@ -104,7 +103,7 @@ class DataConfigRepository:
r.statement_id AS statement_id,
r.statement AS statement
"""
# Entity graph within group (source node, edge, target node)
SEARCH_FOR_ENTITY_GRAPH = """
MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity)
@@ -137,22 +136,106 @@ class DataConfigRepository:
id: m.id
} AS targetNode
"""
# ==================== SQLAlchemy ORM 数据库操作方法 ====================
@staticmethod
def build_update_reflection(config_id: int, **kwargs) -> Tuple[str, Dict]:
    """Build the UPDATE statement for the reflection settings of one config.

    Only the recognised reflection fields that are present in ``kwargs``
    with a non-``None`` value are written. Values are passed as named bind
    parameters (SQLAlchemy ``text()`` style); ``updated_at`` is always
    refreshed by the statement itself.

    Args:
        config_id: ID of the configuration row to update.
        **kwargs: Reflection settings keyed by field name.

    Returns:
        Tuple[str, Dict]: The SQL string and its bind-parameter dict.

    Raises:
        ValueError: If no recognised field carries a value.
    """
    db_logger.debug(f"构建反思配置更新语句: config_id={config_id}")

    # API field name -> database column name
    reflection_columns = {
        "enable_self_reflexion": "enable_self_reflexion",
        "iteration_period": "iteration_period",
        "reflexion_range": "reflexion_range",
        "baseline": "baseline",
        "reflection_model_id": "reflection_model_id",
        "memory_verify": "memory_verify",
        "quality_assessment": "quality_assessment",
    }

    bind_values: Dict = {"config_id": config_id}
    assignments: List[str] = []
    for field_name, column_name in reflection_columns.items():
        new_value = kwargs.get(field_name)
        if new_value is None:
            # Absent and explicit-None fields are both left untouched.
            continue
        assignments.append(f"{column_name} = :{field_name}")
        bind_values[field_name] = new_value

    if not assignments:
        raise ValueError("No fields to update")

    assignments.append("updated_at = timezone('Asia/Shanghai', now())")
    statement = "UPDATE {} SET {} WHERE {}".format(
        TABLE_NAME,
        ", ".join(assignments),
        "config_id = :config_id",
    )
    return statement, bind_values
@staticmethod
def build_select_reflection(config_id: int) -> Tuple[str, Dict]:
    """Build the SELECT statement that reads one config's reflection settings.

    Args:
        config_id: ID of the configuration row to read.

    Returns:
        Tuple[str, Dict]: The SQL string (SQLAlchemy ``text()`` named
        parameters) and its bind-parameter dict.
    """
    db_logger.debug(f"构建反思配置查询语句: config_id={config_id}")

    selected_columns = ", ".join([
        "config_id",
        "enable_self_reflexion",
        "iteration_period",
        "reflexion_range",
        "baseline",
        "reflection_model_id",
        "memory_verify",
        "quality_assessment",
        "user_id",
    ])
    statement = f"SELECT {selected_columns} FROM {TABLE_NAME} WHERE config_id = :config_id"
    return statement, {"config_id": config_id}
@staticmethod
def build_select_all(workspace_id: uuid.UUID) -> Tuple[str, Dict]:
    """Build the SELECT statement listing all configs of a workspace.

    Rows are ordered by ``updated_at`` descending.

    Args:
        workspace_id: ID of the workspace whose configs are listed.

    Returns:
        Tuple[str, Dict]: The SQL string (SQLAlchemy ``text()`` named
        parameters) and its bind-parameter dict.
    """
    db_logger.debug(f"构建查询所有配置语句: workspace_id={workspace_id}")

    selected_columns = ", ".join([
        "config_id",
        "config_name",
        "enable_self_reflexion",
        "iteration_period",
        "reflexion_range",
        "baseline",
        "reflection_model_id",
        "memory_verify",
        "quality_assessment",
        "user_id",
        "created_at",
        "updated_at",
    ])
    statement = (
        f"SELECT {selected_columns} FROM {TABLE_NAME} "
        "WHERE workspace_id = :workspace_id ORDER BY updated_at DESC"
    )
    return statement, {"workspace_id": workspace_id}
@staticmethod
def create(db: Session, params: ConfigParamsCreate) -> DataConfig:
"""创建数据配置
Args:
db: 数据库会话
params: 配置参数创建模型
Returns:
DataConfig: 创建的配置对象
"""
db_logger.debug(f"创建数据配置: config_name={params.config_name}, workspace_id={params.workspace_id}")
try:
db_config = DataConfig(
config_name=params.config_name,
@@ -164,37 +247,37 @@ class DataConfigRepository:
)
db.add(db_config)
db.flush() # 获取自增ID但不提交事务
db_logger.info(f"数据配置已添加到会话: {db_config.config_name} (ID: {db_config.config_id})")
return db_config
except Exception as e:
db.rollback()
db_logger.error(f"创建数据配置失败: {params.config_name} - {str(e)}")
raise
@staticmethod
def update(db: Session, update: ConfigUpdate) -> Optional[DataConfig]:
"""更新基础配置
Args:
db: 数据库会话
update: 配置更新模型
Returns:
Optional[DataConfig]: 更新后的配置对象不存在则返回None
Raises:
ValueError: 没有字段需要更新时抛出
"""
db_logger.debug(f"更新数据配置: config_id={update.config_id}")
try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
if not db_config:
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
return None
# 更新字段
has_update = False
if update.config_name is not None:
@@ -203,44 +286,44 @@ class DataConfigRepository:
if update.config_desc is not None:
db_config.config_desc = update.config_desc
has_update = True
if not has_update:
raise ValueError("No fields to update")
db.commit()
db.refresh(db_config)
db_logger.info(f"数据配置更新成功: {db_config.config_name} (ID: {update.config_id})")
return db_config
except Exception as e:
db.rollback()
db_logger.error(f"更新数据配置失败: config_id={update.config_id} - {str(e)}")
raise
@staticmethod
def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[DataConfig]:
"""更新记忆萃取引擎配置
Args:
db: 数据库会话
update: 萃取配置更新模型
Returns:
Optional[DataConfig]: 更新后的配置对象不存在则返回None
Raises:
ValueError: 没有字段需要更新时抛出
"""
db_logger.debug(f"更新萃取配置: config_id={update.config_id}")
try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
if not db_config:
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
return None
# 更新字段映射
field_mapping = {
# 模型选择
@@ -270,50 +353,50 @@ class DataConfigRepository:
"reflexion_range": "reflexion_range",
"baseline": "baseline",
}
has_update = False
for api_field, db_field in field_mapping.items():
value = getattr(update, api_field, None)
if value is not None:
setattr(db_config, db_field, value)
has_update = True
if not has_update:
raise ValueError("No fields to update")
db.commit()
db.refresh(db_config)
db_logger.info(f"萃取配置更新成功: config_id={update.config_id}")
return db_config
except Exception as e:
db.rollback()
db_logger.error(f"更新萃取配置失败: config_id={update.config_id} - {str(e)}")
raise
@staticmethod
def update_forget(db: Session, update: ConfigUpdateForget) -> Optional[DataConfig]:
"""更新遗忘引擎配置
Args:
db: 数据库会话
update: 遗忘配置更新模型
Returns:
Optional[DataConfig]: 更新后的配置对象不存在则返回None
Raises:
ValueError: 没有字段需要更新时抛出
"""
db_logger.debug(f"更新遗忘配置: config_id={update.config_id}")
try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first()
if not db_config:
db_logger.warning(f"数据配置不存在: config_id={update.config_id}")
return None
# 更新字段
has_update = False
if update.lambda_time is not None:
@@ -325,40 +408,40 @@ class DataConfigRepository:
if update.offset is not None:
db_config.offset = update.offset
has_update = True
if not has_update:
raise ValueError("No fields to update")
db.commit()
db.refresh(db_config)
db_logger.info(f"遗忘配置更新成功: config_id={update.config_id}")
return db_config
except Exception as e:
db.rollback()
db_logger.error(f"更新遗忘配置失败: config_id={update.config_id} - {str(e)}")
raise
@staticmethod
def get_extracted_config(db: Session, config_id: int) -> Optional[Dict]:
"""获取萃取配置,通过主键查询某条配置
Args:
db: 数据库会话
config_id: 配置ID
Returns:
Optional[Dict]: 萃取配置字典不存在则返回None
"""
db_logger.debug(f"查询萃取配置: config_id={config_id}")
try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
if not db_config:
db_logger.debug(f"萃取配置不存在: config_id={config_id}")
return None
result = {
"llm_id": db_config.llm_id,
"embedding_id": db_config.embedding_id,
@@ -381,62 +464,62 @@ class DataConfigRepository:
"reflexion_range": db_config.reflexion_range,
"baseline": db_config.baseline,
}
db_logger.debug(f"萃取配置查询成功: config_id={config_id}")
return result
except Exception as e:
db_logger.error(f"查询萃取配置失败: config_id={config_id} - {str(e)}")
raise
@staticmethod
def get_forget_config(db: Session, config_id: int) -> Optional[Dict]:
    """Fetch the forgetting-engine parameters of one configuration row.

    Args:
        db: Database session.
        config_id: Primary key of the configuration.

    Returns:
        Optional[Dict]: ``lambda_time``, ``lambda_mem`` and ``offset`` of the
        row, or ``None`` when no row has this ``config_id``.

    Raises:
        Exception: Any error raised while querying is logged and re-raised.
    """
    db_logger.debug(f"查询遗忘配置: config_id={config_id}")
    try:
        row = (
            db.query(DataConfig)
            .filter(DataConfig.config_id == config_id)
            .first()
        )
        if row:
            # Only the three forgetting-curve columns are exposed.
            forget_fields = ("lambda_time", "lambda_mem", "offset")
            payload = {name: getattr(row, name) for name in forget_fields}
            db_logger.debug(f"遗忘配置查询成功: config_id={config_id}")
            return payload
        db_logger.debug(f"遗忘配置不存在: config_id={config_id}")
        return None
    except Exception as err:
        db_logger.error(f"查询遗忘配置失败: config_id={config_id} - {str(err)}")
        raise
@staticmethod
def get_by_id(db: Session, config_id: int) -> Optional[DataConfig]:
"""根据ID获取数据配置
Args:
db: 数据库会话
config_id: 配置ID
Returns:
Optional[DataConfig]: 配置对象不存在则返回None
"""
db_logger.debug(f"根据ID查询数据配置: config_id={config_id}")
try:
config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
if config:
db_logger.debug(f"数据配置查询成功: {config.config_name} (ID: {config_id})")
else:
@@ -571,56 +654,56 @@ class DataConfigRepository:
@staticmethod
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[DataConfig]:
    """List configurations, most recently updated first.

    Args:
        db: Database session.
        workspace_id: When truthy, only configurations of this workspace
            are returned; otherwise no workspace filter is applied.

    Returns:
        List[DataConfig]: Matching rows ordered by ``updated_at`` descending.

    Raises:
        Exception: Any error raised while querying is logged and re-raised.
    """
    db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
    try:
        base_query = db.query(DataConfig)
        scoped_query = (
            base_query.filter(DataConfig.workspace_id == workspace_id)
            if workspace_id
            else base_query
        )
        rows = scoped_query.order_by(desc(DataConfig.updated_at)).all()
        db_logger.debug(f"配置列表查询成功: 数量={len(rows)}")
        return rows
    except Exception as err:
        db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(err)}")
        raise
@staticmethod
def delete(db: Session, config_id: int) -> bool:
"""删除数据配置
Args:
db: 数据库会话
config_id: 配置ID
Returns:
bool: 删除成功返回True配置不存在返回False
"""
db_logger.debug(f"删除数据配置: config_id={config_id}")
try:
db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
if not db_config:
db_logger.warning(f"数据配置不存在: config_id={config_id}")
return False
db.delete(db_config)
db.commit()
db_logger.info(f"数据配置删除成功: config_id={config_id}")
return True
except Exception as e:
db.rollback()
db_logger.error(f"删除数据配置失败: config_id={config_id} - {str(e)}")

View File

@@ -115,7 +115,9 @@ def get_knowledge_by_name(db: Session, name: str, workspace_id: uuid.UUID) -> Kn
db_logger.debug(f"Query knowledge base based on name and workspace_id: name={name}, workspace_id={workspace_id}")
try:
knowledge = db.query(Knowledge).filter(Knowledge.name == name).filter(Knowledge.workspace_id == workspace_id).first()
knowledge = db.query(Knowledge).filter(Knowledge.name == name,
Knowledge.workspace_id == workspace_id,
Knowledge.status == 1).first()
if knowledge:
db_logger.debug(f"knowledge base query successful: {name} (ID: {knowledge.id})")
else:

View File

@@ -3,9 +3,9 @@ from sqlalchemy import and_, or_, func, desc
from typing import List, Optional, Dict, Any, Tuple
import uuid
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelProvider
from app.models.models_model import ModelConfig, ModelApiKey, ModelType
from app.schemas.model_schema import (
ModelConfigCreate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
ModelConfigQuery
)
from app.core.logging_config import get_db_logger
@@ -32,7 +32,7 @@ class ModelConfigRepository:
query = query.filter(
or_(
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_public == True
ModelConfig.is_public
)
)
@@ -60,7 +60,7 @@ class ModelConfigRepository:
query = query.filter(
or_(
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_public == True
ModelConfig.is_public
)
)
@@ -92,7 +92,7 @@ class ModelConfigRepository:
query = query.filter(
or_(
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_public == True
ModelConfig.is_public
)
)
@@ -117,13 +117,21 @@ class ModelConfigRepository:
filters.append(
or_(
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_public == True
ModelConfig.is_public
)
)
# 支持多个 type 值(使用 IN 查询)
# 兼容 chat 和 llm 类型:如果查询包含其中一个,则同时匹配两者
if query.type:
filters.append(ModelConfig.type.in_(query.type))
type_values = list(query.type)
# 如果包含 chat 或 llm则同时包含两者
if ModelType.CHAT in type_values or ModelType.LLM in type_values:
if ModelType.CHAT not in type_values:
type_values.append(ModelType.CHAT)
if ModelType.LLM not in type_values:
type_values.append(ModelType.LLM)
filters.append(ModelConfig.type.in_(type_values))
if query.is_active is not None:
filters.append(ModelConfig.is_active == query.is_active)
@@ -183,12 +191,12 @@ class ModelConfigRepository:
query = query.filter(
or_(
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_public == True
ModelConfig.is_public
)
)
if is_active:
query = query.filter(ModelConfig.is_active == True)
query = query.filter(ModelConfig.is_active)
models = query.order_by(ModelConfig.name).all()
db_logger.debug(f"根据类型查询模型配置成功: 数量={len(models)}")
@@ -285,7 +293,7 @@ class ModelConfigRepository:
try:
# 总数统计
total_models = db.query(ModelConfig).count()
active_models = db.query(ModelConfig).filter(ModelConfig.is_active == True).count()
active_models = db.query(ModelConfig).filter(ModelConfig.is_active).count()
# 按类型统计
llm_count = db.query(ModelConfig).filter(ModelConfig.type == ModelType.LLM).count()
@@ -344,7 +352,7 @@ class ModelApiKeyRepository:
query = db.query(ModelApiKey).filter(ModelApiKey.model_config_id == model_config_id)
if is_active:
query = query.filter(ModelApiKey.is_active == True)
query = query.filter(ModelApiKey.is_active)
api_keys = query.order_by(ModelApiKey.priority, ModelApiKey.created_at).all()
db_logger.debug(f"API Key列表查询成功: 数量={len(api_keys)}")

View File

@@ -100,7 +100,13 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
# "triplets": [triplet.model_dump() for triplet in statement.triplet_extraction_info.triplets] if statement.triplet_extraction_info else [],
# "entities": [entity.model_dump() for entity in statement.triplet_extraction_info.entities] if statement.triplet_extraction_info else []
# }) if statement.triplet_extraction_info else json.dumps({"triplets": [], "entities": []}),
"statement_embedding": statement.statement_embedding if statement.statement_embedding else None
"statement_embedding": statement.statement_embedding if statement.statement_embedding else None,
# 添加情绪字段处理
"emotion_type": statement.emotion_type,
"emotion_intensity": statement.emotion_intensity,
"emotion_keywords": statement.emotion_keywords if statement.emotion_keywords else [],
"emotion_subject": statement.emotion_subject,
"emotion_target": statement.emotion_target
}
flattened_statements.append(flattened_statement)

View File

@@ -20,20 +20,25 @@ UNWIND $statements AS statement
MERGE (s:Statement {id: statement.id})
SET s += {
id: statement.id,
run_id: statement.run_id,
chunk_id: statement.chunk_id,
group_id: statement.group_id,
user_id: statement.user_id,
apply_id: statement.apply_id,
chunk_id: statement.chunk_id,
run_id: statement.run_id,
stmt_type: statement.stmt_type,
statement: statement.statement,
emotion_intensity: statement.emotion_intensity,
emotion_target: statement.emotion_target,
emotion_subject: statement.emotion_subject,
emotion_type: statement.emotion_type,
emotion_keywords: statement.emotion_keywords,
temporal_info: statement.temporal_info,
created_at: statement.created_at,
expired_at: statement.expired_at,
stmt_type: statement.stmt_type,
temporal_info: statement.temporal_info,
relevence_info: statement.relevence_info,
statement: statement.statement,
valid_at: statement.valid_at,
invalid_at: statement.invalid_at,
statement_embedding: statement.statement_embedding
statement_embedding: statement.statement_embedding,
relevence_info: statement.relevence_info
}
RETURN s.id AS uuid
"""
@@ -746,3 +751,57 @@ DETACH DELETE losing
RETURN count(losing) as deleted
"""
# Statements of one group created during the last three days (duration 'P3D').
# NOTE(review): "{}" looks like a str.format placeholder for the group id, i.e.
# the value would be spliced into the query text instead of being bound as a
# query parameter - confirm callers only pass trusted ids, or migrate to a
# $group_id parameter together with the call sites.
neo4j_statement_part = '''
MATCH (n:Statement)
WHERE n.group_id = "{}"
AND datetime(n.created_at) >= datetime() - duration('P3D')
RETURN
n.statement as statement_name,
n.id as statement_id,
n.created_at as statement_created_at
'''
# Every statement of one group, without a time filter (same "{}" placeholder,
# same review note as above applies).
neo4j_statement_all = '''
MATCH (n:Statement)
WHERE n.group_id = "{}"
RETURN
n.statement as statement_name,
n.id as statement_id
'''
# Distinct ExtractedEntity nodes connected to a node of the group created in
# the last three days, each returned with its optional relationships to other
# ExtractedEntity nodes. Entities without such a relationship are reported with
# the markers "NO_RELATIONSHIP" / "ISOLATED_NODE".
neo4j_query_part = """
MATCH (n)-[r]-(m:ExtractedEntity)
WHERE n.group_id = "{}"
AND datetime(n.created_at) >= datetime() - duration('P3D')
WITH DISTINCT m
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
RETURN
m.name as entity1_name,
m.description as description,
m.statement_id as statement_id,
m.created_at as created_at,
m.expired_at as expired_at,
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
rel as relationship,
CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name,
other as entity2
"""
# Same entity/relationship projection as neo4j_query_part, but without the
# three-day window on n.created_at.
neo4j_query_all = """
MATCH (n)-[r]-(m:ExtractedEntity)
WHERE n.group_id = "{}"
WITH DISTINCT m
OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity)
RETURN
m.name as entity1_name,
m.description as description,
m.statement_id as statement_id,
m.created_at as created_at,
m.expired_at as expired_at,
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
rel as relationship,
CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name,
other as entity2
"""

View File

@@ -0,0 +1,246 @@
# -*- coding: utf-8 -*-
"""情绪数据仓储模块
本模块提供情绪数据的查询功能,用于情绪分析和统计。
Classes:
EmotionRepository: 情绪数据仓储,提供情绪标签、词云、健康指数等查询方法
"""
from typing import List, Dict, Optional, Any
from datetime import datetime, timedelta
import json
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class EmotionRepository:
    """Read-only queries over the emotion fields stored on ``Statement`` nodes.

    Provides:
    - emotion-type distribution (tags)
    - emotion keyword frequencies (word cloud)
    - raw emotion records inside a recent time window

    Attributes:
        connector: Neo4j connector used to execute every query.
    """

    def __init__(self, connector: Neo4jConnector):
        """Store the connector used by all queries.

        Args:
            connector: Neo4j connector instance.
        """
        self.connector = connector
        logger.info("情绪数据仓储初始化完成")

    @staticmethod
    def _round_intensity(value):
        """Round an average intensity to 3 decimals; falsy values become 0.0."""
        return round(value, 3) if value else 0.0

    async def get_emotion_tags(
        self,
        group_id: str,
        emotion_type: Optional[str] = None,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        limit: int = 10
    ) -> List[Dict[str, Any]]:
        """Return the distribution of emotion types for one group.

        Args:
            group_id: Group identifier the statements belong to.
            emotion_type: Optional emotion type to restrict the query to.
            start_date: Optional lower bound compared against ``created_at``.
            end_date: Optional upper bound compared against ``created_at``.
            limit: Maximum number of rows returned.

        Returns:
            List[Dict]: One entry per emotion type with ``emotion_type``,
            ``count``, ``percentage`` (2 decimals) and ``avg_intensity``
            (3 decimals). An empty list is returned when the query fails.
        """
        conditions = ["s.group_id = $group_id", "s.emotion_type IS NOT NULL"]
        query_params = {"group_id": group_id, "limit": limit}
        optional_filters = (
            ("emotion_type", emotion_type, "s.emotion_type = $emotion_type"),
            ("start_date", start_date, "s.created_at >= $start_date"),
            ("end_date", end_date, "s.created_at <= $end_date"),
        )
        for param_name, param_value, clause in optional_filters:
            if param_value:
                conditions.append(clause)
                query_params[param_name] = param_value
        predicate = " AND ".join(conditions)

        # Aggregate per emotion type first, then derive each share of the total.
        query = f"""
        MATCH (s:Statement)
        WHERE {predicate}
        WITH s.emotion_type as emotion_type,
             count(*) as count,
             avg(s.emotion_intensity) as avg_intensity
        WITH collect({{emotion_type: emotion_type, count: count, avg_intensity: avg_intensity}}) as results,
             sum(count) as total_count
        UNWIND results as result
        RETURN result.emotion_type as emotion_type,
               result.count as count,
               toFloat(result.count) / total_count * 100 as percentage,
               result.avg_intensity as avg_intensity
        ORDER BY count DESC
        LIMIT $limit
        """

        try:
            rows = await self.connector.execute_query(query, **query_params)
            tags = []
            for row in rows:
                tags.append({
                    "emotion_type": row["emotion_type"],
                    "count": row["count"],
                    "percentage": round(row["percentage"], 2),
                    "avg_intensity": self._round_intensity(row["avg_intensity"]),
                })
            return tags
        except Exception as err:
            logger.error(f"查询情绪标签失败: {str(err)}", exc_info=True)
            return []

    async def get_emotion_wordcloud(
        self,
        group_id: str,
        emotion_type: Optional[str] = None,
        limit: int = 50
    ) -> List[Dict[str, Any]]:
        """Return emotion keywords with their frequencies for one group.

        Args:
            group_id: Group identifier the statements belong to.
            emotion_type: Optional emotion type to restrict the query to.
            limit: Maximum number of keywords returned.

        Returns:
            List[Dict]: Entries with ``keyword``, ``frequency``,
            ``emotion_type`` and ``avg_intensity`` (3 decimals), most frequent
            first. An empty list is returned when the query fails.
        """
        conditions = ["s.group_id = $group_id", "s.emotion_keywords IS NOT NULL"]
        query_params = {"group_id": group_id, "limit": limit}
        if emotion_type:
            conditions.append("s.emotion_type = $emotion_type")
            query_params["emotion_type"] = emotion_type
        predicate = " AND ".join(conditions)

        # Expand the keyword lists, then count per (keyword, emotion type).
        query = f"""
        MATCH (s:Statement)
        WHERE {predicate}
        UNWIND s.emotion_keywords as keyword
        WITH keyword,
             s.emotion_type as emotion_type,
             count(*) as frequency,
             avg(s.emotion_intensity) as avg_intensity
        WHERE keyword IS NOT NULL AND keyword <> ''
        RETURN keyword,
               frequency,
               emotion_type,
               avg_intensity
        ORDER BY frequency DESC
        LIMIT $limit
        """

        try:
            rows = await self.connector.execute_query(query, **query_params)
            words = []
            for row in rows:
                words.append({
                    "keyword": row["keyword"],
                    "frequency": row["frequency"],
                    "emotion_type": row["emotion_type"],
                    "avg_intensity": self._round_intensity(row["avg_intensity"]),
                })
            return words
        except Exception as err:
            logger.error(f"查询情绪词云失败: {str(err)}", exc_info=True)
            return []

    async def get_emotions_in_range(
        self,
        group_id: str,
        time_range: str = "30d"
    ) -> List[Dict[str, Any]]:
        """Return the emotion records of one group inside a recent window.

        Args:
            group_id: Group identifier the statements belong to.
            time_range: ``"7d"``, ``"30d"`` or ``"90d"``; any other value
                falls back to 30 days.

        Returns:
            List[Dict]: Entries with ``statement_id``, ``emotion_type``,
            ``emotion_intensity`` and ``created_at`` (as a string), oldest
            first. An empty list is returned when the query fails.
        """
        window_days = {"7d": 7, "30d": 30, "90d": 90}.get(time_range, 30)
        # The lower bound is passed as an ISO string and compared as-is in Cypher.
        window_start = datetime.now() - timedelta(days=window_days)
        start_date = window_start.isoformat()

        query = """
        MATCH (s:Statement)
        WHERE s.group_id = $group_id
          AND s.emotion_type IS NOT NULL
          AND s.created_at >= $start_date
        RETURN s.id as statement_id,
               s.emotion_type as emotion_type,
               s.emotion_intensity as emotion_intensity,
               s.created_at as created_at
        ORDER BY s.created_at ASC
        """

        try:
            rows = await self.connector.execute_query(
                query,
                group_id=group_id,
                start_date=start_date
            )
            records = []
            for row in rows:
                created = row["created_at"]
                if hasattr(created, "isoformat"):
                    created_text = created.isoformat()
                else:
                    created_text = str(created)
                records.append({
                    "statement_id": row["statement_id"],
                    "emotion_type": row["emotion_type"],
                    "emotion_intensity": row["emotion_intensity"],
                    "created_at": created_text,
                })
            return records
        except Exception as err:
            logger.error(f"查询时间范围情绪数据失败: {str(err)}", exc_info=True)
            return []

View File

@@ -0,0 +1,227 @@
from app.repositories import Neo4jConnector
neo4j_connector = Neo4jConnector()
async def update_neo4j_data(neo4j_dict_data, update_databases):
    """Update ``ExtractedEntity`` nodes that match a set of property filters.

    Args:
        neo4j_dict_data: Mapping of property name -> value used as equality
            filters. Entries whose value is ``None`` are ignored.
        update_databases: Mapping of property name -> new value. Entries whose
            value is ``None`` are ignored.

    Returns:
        bool: ``True`` when at least one node was updated. ``False`` when there
        is nothing to set, when no filter remains, when no node matched, or
        when any error occurred (errors are printed, never raised).
    """

    def _checked_property(key):
        # Property names are written into the Cypher text itself (only values
        # are bound as parameters), so accept nothing but plain identifiers.
        if not (isinstance(key, str) and key.isidentifier()):
            raise ValueError(f"invalid property name: {key!r}")
        return key

    try:
        # Build the WHERE conditions.
        where_conditions = []
        params = {}
        for key, value in neo4j_dict_data.items():
            if value is not None:
                param_name = f"param_{_checked_property(key)}"
                where_conditions.append(f"e.{key} = ${param_name}")
                params[param_name] = value

        # Build the SET assignments.
        set_conditions = []
        for key, value in update_databases.items():
            if value is not None:
                param_name = f"update_{_checked_property(key)}"
                set_conditions.append(f"e.{key} = ${param_name}")
                params[param_name] = value
        set_clause = ", ".join(set_conditions)
        if not set_clause:
            print("警告: 没有需要更新的字段")
            return False

        if not where_conditions:
            # Without any filter the statement would overwrite every
            # ExtractedEntity node, so refuse instead of matching everything.
            print("警告: 没有查询条件, 已跳过更新")
            return False
        where_clause = " AND ".join(where_conditions)

        cypher_query = f"""
        MATCH (e:ExtractedEntity)
        WHERE {where_clause}
        SET {set_clause}
        RETURN count(e) as updated_count, collect(e.name) as updated_names
        """
        print(f"\n执行Cypher查询: {cypher_query}")
        print(f"参数: {params}")

        result = await neo4j_connector.execute_query(cypher_query, **params)
        if not result:
            return False
        updated_count = result[0].get('updated_count', 0)
        updated_names = result[0].get('updated_names', [])
        print(f"成功更新 {updated_count} 个节点")
        if updated_names:
            print(f"更新的实体名称: {updated_names}")
        return updated_count > 0
    except Exception as e:
        print(f"更新过程中出现错误: {e}")
        import traceback
        traceback.print_exc()
        return False
def map_field_names(data_dict):
    """Map entity-prefixed field names onto flat ``name``/``description`` keys.

    Rules, applied per key in iteration order (later keys overwrite earlier
    ones that map to the same target):
    - ``entity1.name``, ``entity1_name``, ``entity2.name``, ``entity2_name``
      become ``name``.
    - ``entity1.description`` and ``entity2.description`` become
      ``description``.
    - ``relationship_type`` is dropped.
    - Any other key without a dot is kept unchanged.
    - Any other dotted key is dropped with a warning.

    Args:
        data_dict: Mapping of field name -> value.

    Returns:
        dict: The mapped and filtered fields.
    """
    # Only reported for diagnostics; it does not influence the mapping.
    has_name_field = any(
        key in ('name', 'entity2.name', 'entity1.name') for key in data_dict
    )
    print(f"字段检查: has_name_field = {has_name_field}")

    name_aliases = ('entity2.name', 'entity2_name', 'entity1.name', 'entity1_name')
    description_aliases = ('entity1.description', 'entity2.description')

    mapped_dict = {}
    for key, value in data_dict.items():
        if key in name_aliases:
            # NOTE(review): the underscore aliases are mapped to ``name`` here.
            # The previous separate branches that skipped them could never
            # run - confirm that mapping (not skipping) is the intended rule.
            mapped_dict['name'] = value
            print(f"字段名映射: {key} -> name")
        elif key in description_aliases:
            mapped_dict['description'] = value
            print(f"字段名映射: {key} -> description")
        elif key == 'relationship_type':
            print(f"字段过滤: 跳过不需要的字段 '{key}'")
        elif '.' not in key:
            mapped_dict[key] = value
        else:
            print(f"警告: 跳过不支持的嵌套字段 '{key}'")
    print(f"字段映射结果: {mapped_dict}")
    return mapped_dict
async def neo4j_data(solved_data):
"""
Apply resolved reflexion changes to Neo4j.

For every item of ``solved_data`` the entries of ``item['results']`` are
inspected: the first element of ``resolved['change']`` supplies the fields to
write (after ``map_field_names``), and ``resolved['resolved_memory']`` supplies
the properties used to locate the node. The update itself is delegated to
``update_neo4j_data``.

Args:
solved_data: Iterable of dicts, each with a ``'results'`` list.
Returns:
int: The counter that is incremented after each ``update_neo4j_data`` call.
NOTE(review): the boolean result of that call is not used, so this is
the number of attempted updates rather than confirmed ones - confirm.
"""
success_count = 0
for i in solved_data:
# Filters and updates are created once per item of solved_data.
neo4j_dict_data = {}
update_databases = {}
results = i['results']
for data in results:
resolved = data.get('resolved')
if not resolved:
print("跳过resolved为None")
continue
try:
change_list = resolved.get('change', [])
except (AttributeError, TypeError):
change_list = []
if change_list == []:
print("跳过change_list为空")
continue
if change_list and len(change_list) > 0:
# Only the first change record is applied.
change = change_list[0]
print(f"change: {change}")
field_data = change.get('field', [])
print(f"field_data: {field_data}")
print(f"field_data type: {type(field_data)}")
# Map and filter the field names.
# The field payload may be a dict or a list of dicts.
if isinstance(field_data, dict):
# A dict: map its field names, then merge it into the updates.
mapped_data = map_field_names(field_data)
update_databases.update(mapped_data)
elif isinstance(field_data, list):
# A list: map and merge every dict element.
for field_item in field_data:
if isinstance(field_item, dict):
mapped_item = map_field_names(field_item)
update_databases.update(mapped_item)
else:
print(f"警告: field_item不是字典: {field_item}")
else:
print(f"警告: field_data类型不支持: {type(field_data)}")
# Normalise the record in place: entity1_name becomes name, entity2_name is dropped.
if 'entity1_name' in data:
data['name'] = data.pop('entity1_name')
if 'entity2_name' in data:
data.pop('entity2_name', None)
resolved_memory = resolved.get('resolved_memory', {})
entity2 = None
if isinstance(resolved_memory, dict):
entity2 = resolved_memory.get('entity2')
if entity2 and isinstance(entity2, dict) and len(entity2) >= 5:
stat_id = resolved.get('original_memory_id')
# Read statement_id defensively (resolved_memory may not be a dict).
statement_id = None
if isinstance(resolved_memory, dict):
statement_id = resolved_memory.get('statement_id')
# original_memory_id is only used while no 'id' filter has been set yet.
if statement_id and 'id' not in neo4j_dict_data:
neo4j_dict_data['id'] = stat_id
neo4j_dict_data['statement_id'] = statement_id
else:
# Fall back to the statement_id / description carried by resolved_memory.
try:
for key, value in resolved_memory.items():
if key == 'statement_id':
neo4j_dict_data['statement_id'] = value
if key == 'description':
neo4j_dict_data['description'] = value
except AttributeError:
# NOTE(review): the filter dict is replaced by a list used as a "skip" marker.
# If a later record reaches the subscript assignments above while the list is
# still in place they would raise TypeError - confirm and consider a flag.
neo4j_dict_data=[]
print(neo4j_dict_data)
print(update_databases)
# NOTE(review): an empty dict also passes this check ({} != []), which would
# call update_neo4j_data without any filter - confirm this is intended.
if neo4j_dict_data!=[]:
await update_neo4j_data(neo4j_dict_data, update_databases)
success_count += 1
return success_count

View File

@@ -58,11 +58,22 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]):
n['invalid_at'] = datetime.fromisoformat(n['invalid_at'])
# 处理temporal_info字段
if isinstance(n.get('temporal_info'), dict):
if isinstance(n.get('temporal_info'), str):
# 从字符串转换为枚举值
n['temporal_info'] = TemporalInfo(n['temporal_info'])
elif isinstance(n.get('temporal_info'), dict):
n['temporal_info'] = TemporalInfo(**n['temporal_info'])
elif not n.get('temporal_info'):
# 如果没有temporal_info创建一个默认的
n['temporal_info'] = TemporalInfo()
n['temporal_info'] = TemporalInfo.STATIC
# 处理情绪字段 - 映射 Neo4j 节点属性到 StatementNode 模型
# 处理空值情况,确保字段存在
n['emotion_type'] = n.get('emotion_type')
n['emotion_intensity'] = n.get('emotion_intensity')
n['emotion_keywords'] = n.get('emotion_keywords', [])
n['emotion_subject'] = n.get('emotion_subject')
n['emotion_target'] = n.get('emotion_target')
return StatementNode(**n)

View File

@@ -0,0 +1,124 @@
import uuid
from sqlalchemy.orm import Session
from app.core.logging_config import get_db_logger
from app.models.prompt_optimizer_model import (
PromptOptimizerSession, PromptOptimizerSessionHistory, RoleType
)
db_logger = get_db_logger()
class PromptOptimizerSessionRepository:
    """Repository for prompt optimization sessions and their message history."""

    def __init__(self, db: Session):
        """Bind the repository to a database session.

        Args:
            db: SQLAlchemy session used for every query and commit.
        """
        self.db = db

    def create_session(
        self,
        tenant_id: uuid.UUID,
        user_id: uuid.UUID
    ) -> PromptOptimizerSession:
        """Create and persist a new prompt optimization session for a user.

        Args:
            tenant_id (uuid.UUID): The unique identifier of the tenant.
            user_id (uuid.UUID): The unique identifier of the user.

        Returns:
            PromptOptimizerSession: The newly created, refreshed session row.

        Raises:
            Exception: Any database error; the transaction is rolled back and
                the error is logged before it is re-raised.
        """
        db_logger.debug(f"Create prompt optimization session: tenant_id={tenant_id}, user_id={user_id}")
        try:
            session = PromptOptimizerSession(
                tenant_id=tenant_id,
                user_id=user_id,
            )
            self.db.add(session)
            self.db.commit()
            self.db.refresh(session)
            db_logger.debug(f"Prompt optimization session created: ID:{session.id}")
            return session
        except Exception as e:
            # A failed flush/commit leaves the session unusable until rollback.
            self.db.rollback()
            db_logger.error(f"Error creating prompt optimization session: user_id={user_id} - {str(e)}")
            raise

    def get_session_history(
        self,
        session_id: uuid.UUID,
        user_id: uuid.UUID
    ) -> list[PromptOptimizerSessionHistory]:
        """Retrieve the full message history of one session owned by a user.

        Args:
            session_id (uuid.UUID): The unique identifier of the session.
            user_id (uuid.UUID): The unique identifier of the user.

        Returns:
            list[PromptOptimizerSessionHistory]: History records ordered by
            creation time ascending; an empty list when the session does not
            exist or does not belong to ``user_id``.

        Raises:
            Exception: Any database error is logged and re-raised.
        """
        db_logger.debug(f"Get prompt optimization session history: "
                        f"user_id={user_id}, session_id={session_id}")
        try:
            # Verify ownership first so another user's history is never returned.
            session = self.db.query(PromptOptimizerSession).filter(
                PromptOptimizerSession.id == session_id,
                PromptOptimizerSession.user_id == user_id
            ).first()
            if not session:
                return []
            history = self.db.query(PromptOptimizerSessionHistory).filter(
                PromptOptimizerSessionHistory.session_id == session.id,
                PromptOptimizerSessionHistory.user_id == user_id
            ).order_by(PromptOptimizerSessionHistory.created_at.asc()).all()
            return history
        except Exception as e:
            db_logger.error(f"Error retrieving prompt optimization session history: session_id={session_id} - {str(e)}")
            raise

    def create_message(
        self,
        tenant_id: uuid.UUID,
        session_id: uuid.UUID,
        user_id: uuid.UUID,
        role: RoleType,
        content: str,
    ) -> PromptOptimizerSessionHistory:
        """Append a message to the history of an existing session.

        Args:
            tenant_id (uuid.UUID): The unique identifier of the tenant.
            session_id (uuid.UUID): The session the message belongs to.
            user_id (uuid.UUID): The unique identifier of the user.
            role (RoleType): Author role of the message; its ``value`` is stored.
            content (str): Message text.

        Returns:
            PromptOptimizerSessionHistory: The persisted, refreshed message row.

        Raises:
            ValueError: If no session matches ``session_id``, ``user_id`` and
                ``tenant_id``.
            Exception: Any database error; the transaction is rolled back and
                the error is logged before it is re-raised.
        """
        try:
            # The session must exist and belong to this user and tenant.
            session = self.db.query(PromptOptimizerSession).filter(
                PromptOptimizerSession.id == session_id,
                PromptOptimizerSession.user_id == user_id,
                PromptOptimizerSession.tenant_id == tenant_id
            ).first()
            if not session:
                db_logger.error(f"Session {session_id} not found for user {user_id}")
                raise ValueError(f"Session {session_id} not found for user {user_id}")
            message = PromptOptimizerSessionHistory(
                tenant_id=tenant_id,
                session_id=session.id,
                user_id=user_id,
                role=role.value,
                content=content,
            )
            self.db.add(message)
            self.db.commit()
            # Reload the row, mirroring create_session, so the returned object
            # is populated after the commit.
            self.db.refresh(message)
            return message
        except Exception as e:
            # A failed flush/commit leaves the session unusable until rollback.
            self.db.rollback()
            db_logger.error(f"Error creating prompt optimization session history: session_id={session_id} - {str(e)}")
            raise

View File

@@ -0,0 +1,32 @@
"""情绪分析相关的请求和响应模型"""
from typing import Optional
from pydantic import BaseModel, Field
class EmotionTagsRequest(BaseModel):
"""Request body for the emotion tag statistics query."""
# Required group identifier.
group_id: str = Field(..., description="组ID")
# Optional filter on a single emotion type; None means no filter.
emotion_type: Optional[str] = Field(None, description="情绪类型过滤joy/sadness/anger/fear/surprise/neutral")
# Optional date bounds, kept as plain strings (no format validation is declared here).
start_date: Optional[str] = Field(None, description="开始日期ISO格式2024-01-01")
end_date: Optional[str] = Field(None, description="结束日期ISO格式2024-12-31")
# Maximum number of rows; validated to the range 1..100, default 10.
limit: int = Field(10, ge=1, le=100, description="返回数量限制")
class EmotionWordcloudRequest(BaseModel):
"""Request body for the emotion word-cloud query."""
# Required group identifier.
group_id: str = Field(..., description="组ID")
# Optional filter on a single emotion type; None means no filter.
emotion_type: Optional[str] = Field(None, description="情绪类型过滤joy/sadness/anger/fear/surprise/neutral")
# Maximum number of keywords; validated to the range 1..200, default 50.
limit: int = Field(50, ge=1, le=200, description="返回词语数量")
class EmotionHealthRequest(BaseModel):
"""Request body for the emotion health index query."""
# Required group identifier.
group_id: str = Field(..., description="组ID")
# Time window label, default "30d".
# NOTE(review): the description lists 7d/30d/90d but no validation is declared here - confirm.
time_range: str = Field("30d", description="时间范围7d/30d/90d")
class EmotionSuggestionsRequest(BaseModel):
"""Request body for personalised emotion suggestions."""
# Required group identifier.
group_id: str = Field(..., description="组ID")
# Optional configuration id; per its description it selects the LLM model to use.
config_id: Optional[int] = Field(None, description="配置ID用于指定LLM模型")

View File

@@ -13,5 +13,6 @@ class EndUser(BaseModel):
other_id: Optional[str] = Field(description="第三方ID", default=None)
other_name: Optional[str] = Field(description="其他名称", default="")
other_address: Optional[str] = Field(description="其他地址", default="")
reflection_time: Optional[datetime.datetime] = Field(description="反思时间", default_factory=datetime.datetime.now)
created_at: datetime.datetime = Field(description="创建时间", default_factory=datetime.datetime.now)
updated_at: datetime.datetime = Field(description="更新时间", default_factory=datetime.datetime.now)

View File

@@ -0,0 +1,52 @@
from pydantic import BaseModel, Field
from typing import Optional
from enum import Enum
class OptimizationStrategy(str, Enum):
"""Optimization strategy choices (string-valued enum)."""
# Members subclass str, so each one compares equal to its literal value.
SPEED_FIRST = "speed_first"
ACCURACY_FIRST = "accuracy_first"
BALANCED = "balanced"
# Reflection configuration payload: the core reflection switches plus optional
# tuning parameters for the fast engine.
class Memory_Reflection(BaseModel):
# Optional configuration id; None when not supplied.
config_id: Optional[int] = None
# Required core settings (no defaults are declared).
reflection_enabled: bool
# NOTE(review): typed as str although the name suggests a number of hours - confirm with callers.
reflection_period_in_hours: str
reflexion_range: str
baseline: str
reflection_model_id: str
memory_verify: bool
quality_assessment: bool
# Fast-engine optimization parameters (all optional, with defaults).
optimization_strategy: Optional[OptimizationStrategy] = OptimizationStrategy.BALANCED
use_fast_model: Optional[bool] = True
enable_caching: Optional[bool] = True
enable_streaming: Optional[bool] = True
# batch_size is validated to 1..10, max_concurrent to 1..20.
batch_size: Optional[int] = Field(default=3, ge=1, le=10)
max_concurrent: Optional[int] = Field(default=5, ge=1, le=20)
class Config:
# Pydantic option: store the enum's value rather than the enum member.
use_enum_values = True
class FastReflectionRequest(BaseModel):
"""Request model for a fast reflection run."""
# Reflection settings to run with.
reflection: Memory_Reflection
# NOTE(review): the default looks like a hard-coded test/sample host id - confirm it is
# acceptable as a production default.
host_id: Optional[str] = "88a459f5_text02"
optimization_strategy: Optional[OptimizationStrategy] = OptimizationStrategy.BALANCED
class Config:
# Pydantic option: store the enum's value rather than the enum member.
use_enum_values = True
class ReflectionBenchmarkRequest(BaseModel):
"""Request model for a reflection benchmark run."""
# Reflection settings to benchmark.
reflection: Memory_Reflection
# NOTE(review): the default looks like a hard-coded test/sample host id - confirm it is
# acceptable as a production default.
host_id: Optional[str] = "88a459f5_text02"
# Number of benchmark iterations; validated to the range 1..10, default 3.
iterations: Optional[int] = Field(default=3, ge=1, le=10)
class Config:
# Pydantic option: store the enum's value rather than the enum member.
use_enum_values = True

View File

@@ -2,7 +2,7 @@
所有的内容是放错误地方了应该放在models
"""
from typing import Any, Optional, List, Dict, Literal
from typing import Any, Optional, List, Dict, Literal, Union
import time
import uuid
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator
@@ -28,25 +28,48 @@ class Write_UserInput(BaseModel):
# ============================================================================
class BaseDataSchema(BaseModel):
"""Base schema for the data"""
id: str = Field(..., description="The unique identifier for the data entry.")
statement: str = Field(..., description="The statement text.")
group_id: str = Field(..., description="The group identifier.")
chunk_id: str = Field(..., description="The chunk identifier.")
# 保持原有必需字段为可选,以兼容不同数据源
id: Optional[str] = Field(None, description="The unique identifier for the data entry.")
statement: Optional[str] = Field(None, description="The statement text.")
group_id: Optional[str] = Field(None, description="The group identifier.")
chunk_id: Optional[str] = Field(None, description="The chunk identifier.")
created_at: str = Field(..., description="The creation timestamp in ISO 8601 format.")
expired_at: Optional[str] = Field(None, description="The expiration timestamp in ISO 8601 format.")
valid_at: Optional[str] = Field(None, description="The validation timestamp in ISO 8601 format.")
invalid_at: Optional[str] = Field(None, description="The invalidation timestamp in ISO 8601 format.")
entity_ids: List[str] = Field([], description="The list of entity identifiers.")
description: Optional[str] = Field(None, description="The description of the data entry.")
# 新增字段以匹配实际输入数据
entity1_name: str = Field(..., description="The first entity name.")
entity2_name: Optional[str] = Field(None, description="The second entity name.")
statement_id: str = Field(..., description="The statement identifier.")
relationship_type: str = Field(..., description="The relationship type.")
relationship: Optional[Dict[str, Any]] = Field(None, description="The relationship object.")
entity2: Optional[Dict[str, Any]] = Field(None, description="The second entity object.")
class QualityAssessmentSchema(BaseModel):
"""Schema for memory quality assessment results."""
# Integer score, validated to the inclusive range 0..100.
score: int = Field(..., ge=0, le=100, description="Quality score percentage (0-100).")
# Required free-text summary.
summary: str = Field(..., description="Brief summary of data quality status, including main issues and strengths.")
class MemoryVerifySchema(BaseModel):
"""Schema for memory privacy verification results."""
# Required flag: whether any privacy information was detected.
has_privacy: bool = Field(..., description="Whether privacy information was detected.")
# Detected privacy categories; defaults to an empty list.
privacy_types: List[str] = Field([], description="List of detected privacy information types.")
# Required free-text summary.
summary: str = Field(..., description="Brief summary of privacy detection results.")
class ConflictResultSchema(BaseModel):
"""Schema for the conflict result data in the reflexion_data.json file."""
data: List[BaseDataSchema] = Field(..., description="The conflict memory data.")
data: List[BaseDataSchema] = Field(..., description="The conflict memory data. Only contains conflicting records when conflict is True.")
conflict: bool = Field(..., description="Whether the memory is in conflict.")
conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.")
quality_assessment: Optional[QualityAssessmentSchema] = Field(None, description="The quality assessment object. Contains score and summary when quality_assessment is enabled, null otherwise.")
memory_verify: Optional[MemoryVerifySchema] = Field(None, description="The memory privacy verification object. Contains privacy detection results when memory_verify is enabled, null otherwise.")
@model_validator(mode="before")
@classmethod
def _normalize_data(cls, v):
if isinstance(v, dict):
d = v.get("data")
@@ -61,7 +84,6 @@ class ConflictSchema(BaseModel):
conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.")
@model_validator(mode="before")
@classmethod
def _normalize_data(cls, v):
if isinstance(v, dict):
d = v.get("data")
@@ -76,21 +98,30 @@ class ReflexionSchema(BaseModel):
solution: str = Field(..., description="The solution for the reflexion.")
class ChangeRecordSchema(BaseModel):
    """Schema for individual change records"""

    # The docstring above is kept verbatim on purpose: pydantic publishes a
    # model's docstring as the `description` of its generated JSON schema.

    # Required. A list of string-to-string mappings, one mapping per change.
    field: List[Dict[str, str]] = Field(
        ...,
        description="List of field changes, each containing field name and new value.",
    )
class ResolvedSchema(BaseModel):
    """Schema for the resolved memory data in the reflexion_data"""

    # Every field is optional and defaults to None.
    # The docstring above is kept verbatim on purpose: pydantic publishes a
    # model's docstring as the `description` of its generated JSON schema.

    original_memory_id: Optional[str] = Field(
        None,
        description="The original memory identifier.",
    )
    # Declared exactly once. Accepts either one record or a list of records;
    # when a name is declared twice in a class body only the last declaration
    # takes effect, so this is the single effective definition.
    resolved_memory: Optional[Union[BaseDataSchema, List[BaseDataSchema]]] = Field(
        None,
        description="The resolved memory data (only contains records that need modification). Can be a single record or list of records.",
    )
    change: Optional[List[ChangeRecordSchema]] = Field(
        None,
        description="List of detailed change records with IDs and field information.",
    )
class SingleReflexionResultSchema(BaseModel):
    """Schema for a single reflexion result item."""

    # The docstring above is kept verbatim on purpose: pydantic publishes a
    # model's docstring as the `description` of its generated JSON schema.

    # Required: the conflict payload this item is about.
    conflict: ConflictResultSchema = Field(
        ...,
        description="The conflict result data for this specific conflict type.",
    )
    # Required: the reflexion attached to that conflict.
    reflexion: ReflexionSchema = Field(
        ...,
        description="The reflexion data for this conflict.",
    )
    # Optional: None when no resolution is supplied.
    resolved: Optional[ResolvedSchema] = Field(
        default=None,
        description="The resolved memory data for this conflict.",
    )
    # Defaults to the literal "reflexion_result" when omitted.
    type: str = Field(
        default="reflexion_result",
        description="The type identifier.",
    )
class ReflexionResultSchema(BaseModel):
"""Schema for the reflexion result data in the reflexion_data.json file."""
# 模型输出中 "conflict" 为单个冲突对象(包含 data 与 conflict_memory而非字典映射
conflict: ConflictResultSchema = Field(..., description="The conflict result data.")
reflexion: Optional[ReflexionSchema] = Field(None, description="The reflexion data.")
resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data.")
"""Schema for the complete reflexion result data - a list of individual conflict resolutions."""
results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.")
@model_validator(mode="before")
@classmethod
def _normalize_resolved(cls, v):
if isinstance(v, dict):
conflict = v.get("conflict")

View File

@@ -0,0 +1,99 @@
from pydantic import BaseModel, Field
from uuid import UUID
# =========================================
# API Request Schemas
# =========================================
class PromptOptMessage(BaseModel):
    # Request payload carrying one user message for prompt optimization.

    # Required identifier of the model.
    model_id: UUID = Field(..., description="Model ID")
    # Required and must be non-empty (min_length=1).
    message: str = Field(..., min_length=1, description="User's input message")
    # Optional; an empty string is used when the caller sends nothing.
    current_prompt: str = Field("", description="currently optimized prompt")
class PromptOptModelSet(BaseModel):
    # Request payload pairing an optional configuration ID with a system prompt.

    # Optional; None when the caller does not reference a configuration.
    id: UUID | None = Field(None, description="Configuration ID")
    # Required system prompt text.
    system_prompt: str = Field(..., description="System Prompt")
# =========================================
# Service Layer Results
# =========================================
class OptimizePromptResult(BaseModel):
    # Service-layer result: an optimized prompt plus its description.
    # Both fields are required strings.

    prompt: str = Field(..., description="Optimized Prompt")
    desc: str = Field(..., description="Description")
# =========================================
# API Response Schemas
# =========================================
class CreateSessionResponse(BaseModel):
    # Response body returned after a session is created.

    # from_attributes lets pydantic populate the model from an object's
    # attributes rather than only from a mapping.
    model_config = {"from_attributes": True}

    # Required identifier of the session.
    id: UUID = Field(..., description="Session ID")
class OptimizePromptResponse(BaseModel):
    # Response body for a prompt-optimization call. All fields are required.

    # from_attributes lets pydantic populate the model from an object's
    # attributes rather than only from a mapping.
    model_config = {"from_attributes": True}

    prompt: str = Field(..., description="Optimized Prompt")
    desc: str = Field(..., description="Description")
    # Plain `list`: the element type is not constrained by this schema.
    variables: list = Field(..., description="Variables")
class SessionMessage(BaseModel):
    # One message of a session: who sent it and what it says.

    # Typed as a free-form str; the user/assistant values named in the
    # description are not enforced by this schema.
    role: str = Field(..., description="Message role (user/assistant)")
    # Required message text.
    content: str = Field(..., description="Message content")
class SessionHistoryResponse(BaseModel):
    # Response body listing the messages of one session.

    # Required identifier of the session.
    session_id: UUID = Field(..., description="Session ID")
    # Required; each entry is validated as a SessionMessage.
    messages: list[SessionMessage] = Field(..., description="List of messages in the session")

View File

@@ -14,6 +14,7 @@ from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
from app.repositories import workspace_repository, knowledge_repository
logger = get_business_logger()
@@ -328,4 +329,4 @@ def create_agent_invocation_tool(
)
return f"调用 Agent 失败: {str(e)}"
return invoke_agent
return invoke_agent

Some files were not shown because too many files have changed in this diff Show More