Compare commits
1 Commits
release/v0
...
revert-218
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
524aed19d4 |
@@ -3,13 +3,8 @@ import platform
|
|||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
from celery import Celery
|
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
from celery import Celery
|
||||||
# macOS fork() safety - must be set before any Celery initialization
|
|
||||||
if platform.system() == 'Darwin':
|
|
||||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
|
||||||
|
|
||||||
# 创建 Celery 应用实例
|
# 创建 Celery 应用实例
|
||||||
# broker: 任务队列(使用 Redis DB 0)
|
# broker: 任务队列(使用 Redis DB 0)
|
||||||
@@ -68,20 +63,15 @@ celery_app.conf.update(
|
|||||||
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
||||||
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
||||||
|
|
||||||
# Long-term storage tasks → memory_tasks queue (batched write strategies)
|
|
||||||
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
|
|
||||||
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
|
||||||
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
|
||||||
|
|
||||||
# Document tasks → document_tasks queue (prefork worker)
|
# Document tasks → document_tasks queue (prefork worker)
|
||||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||||
|
|
||||||
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
|
# Beat/periodic tasks → document_tasks queue (prefork worker)
|
||||||
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
'app.tasks.workspace_reflection_task': {'queue': 'document_tasks'},
|
||||||
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
'app.tasks.regenerate_memory_cache': {'queue': 'document_tasks'},
|
||||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
|
'app.tasks.run_forgetting_cycle_task': {'queue': 'document_tasks'},
|
||||||
'app.controllers.memory_storage_controller.search_all': {'queue': 'periodic_tasks'},
|
'app.controllers.memory_storage_controller.search_all': {'queue': 'document_tasks'},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -89,40 +79,40 @@ celery_app.conf.update(
|
|||||||
celery_app.autodiscover_tasks(['app'])
|
celery_app.autodiscover_tasks(['app'])
|
||||||
|
|
||||||
# Celery Beat schedule for periodic tasks
|
# Celery Beat schedule for periodic tasks
|
||||||
# memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||||
# memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||||
# workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||||
# forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
|
forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
|
||||||
|
|
||||||
# 构建定时任务配置
|
# 构建定时任务配置
|
||||||
# beat_schedule_config = {
|
beat_schedule_config = {
|
||||||
# "run-workspace-reflection": {
|
"run-workspace-reflection": {
|
||||||
# "task": "app.tasks.workspace_reflection_task",
|
"task": "app.tasks.workspace_reflection_task",
|
||||||
# "schedule": workspace_reflection_schedule,
|
"schedule": workspace_reflection_schedule,
|
||||||
# "args": (),
|
"args": (),
|
||||||
# },
|
},
|
||||||
# "regenerate-memory-cache": {
|
"regenerate-memory-cache": {
|
||||||
# "task": "app.tasks.regenerate_memory_cache",
|
"task": "app.tasks.regenerate_memory_cache",
|
||||||
# "schedule": memory_cache_regeneration_schedule,
|
"schedule": memory_cache_regeneration_schedule,
|
||||||
# "args": (),
|
"args": (),
|
||||||
# },
|
},
|
||||||
# "run-forgetting-cycle": {
|
"run-forgetting-cycle": {
|
||||||
# "task": "app.tasks.run_forgetting_cycle_task",
|
"task": "app.tasks.run_forgetting_cycle_task",
|
||||||
# "schedule": forgetting_cycle_schedule,
|
"schedule": forgetting_cycle_schedule,
|
||||||
# "kwargs": {
|
"kwargs": {
|
||||||
# "config_id": None, # 使用默认配置,可以通过环境变量配置
|
"config_id": None, # 使用默认配置,可以通过环境变量配置
|
||||||
# },
|
},
|
||||||
# },
|
},
|
||||||
# }
|
}
|
||||||
|
|
||||||
# 如果配置了默认工作空间ID,则添加记忆总量统计任务
|
# 如果配置了默认工作空间ID,则添加记忆总量统计任务
|
||||||
# if settings.DEFAULT_WORKSPACE_ID:
|
if settings.DEFAULT_WORKSPACE_ID:
|
||||||
# beat_schedule_config["write-total-memory"] = {
|
beat_schedule_config["write-total-memory"] = {
|
||||||
# "task": "app.controllers.memory_storage_controller.search_all",
|
"task": "app.controllers.memory_storage_controller.search_all",
|
||||||
# "schedule": memory_increment_schedule,
|
"schedule": memory_increment_schedule,
|
||||||
# "kwargs": {
|
"kwargs": {
|
||||||
# "workspace_id": settings.DEFAULT_WORKSPACE_ID,
|
"workspace_id": settings.DEFAULT_WORKSPACE_ID,
|
||||||
# },
|
},
|
||||||
# }
|
}
|
||||||
|
|
||||||
# celery_app.conf.beat_schedule = beat_schedule_config
|
celery_app.conf.beat_schedule = beat_schedule_config
|
||||||
|
|||||||
@@ -45,7 +45,6 @@ from . import (
|
|||||||
home_page_controller,
|
home_page_controller,
|
||||||
memory_perceptual_controller,
|
memory_perceptual_controller,
|
||||||
memory_working_controller,
|
memory_working_controller,
|
||||||
ontology_controller,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建管理端 API 路由器
|
# 创建管理端 API 路由器
|
||||||
@@ -91,6 +90,5 @@ manager_router.include_router(implicit_memory_controller.router)
|
|||||||
manager_router.include_router(memory_perceptual_controller.router)
|
manager_router.include_router(memory_perceptual_controller.router)
|
||||||
manager_router.include_router(memory_working_controller.router)
|
manager_router.include_router(memory_working_controller.router)
|
||||||
manager_router.include_router(file_storage_controller.router)
|
manager_router.include_router(file_storage_controller.router)
|
||||||
manager_router.include_router(ontology_controller.router)
|
|
||||||
|
|
||||||
__all__ = ["manager_router"]
|
__all__ = ["manager_router"]
|
||||||
|
|||||||
@@ -872,44 +872,3 @@ async def update_workflow_config(
|
|||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{app_id}/statistics", summary="应用统计数据")
|
|
||||||
@cur_workspace_access_guard()
|
|
||||||
def get_app_statistics(
|
|
||||||
app_id: uuid.UUID,
|
|
||||||
start_date: int,
|
|
||||||
end_date: int,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user=Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""获取应用统计数据
|
|
||||||
|
|
||||||
Args:
|
|
||||||
app_id: 应用ID
|
|
||||||
start_date: 开始时间戳(毫秒)
|
|
||||||
end_date: 结束时间戳(毫秒)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- daily_conversations: 每日会话数统计
|
|
||||||
- total_conversations: 总会话数
|
|
||||||
- daily_new_users: 每日新增用户数
|
|
||||||
- total_new_users: 总新增用户数
|
|
||||||
- daily_api_calls: 每日API调用次数
|
|
||||||
- total_api_calls: 总API调用次数
|
|
||||||
- daily_tokens: 每日token消耗
|
|
||||||
- total_tokens: 总token消耗
|
|
||||||
"""
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
|
|
||||||
from app.services.app_statistics_service import AppStatisticsService
|
|
||||||
stats_service = AppStatisticsService(db)
|
|
||||||
|
|
||||||
result = stats_service.get_app_statistics(
|
|
||||||
app_id=app_id,
|
|
||||||
workspace_id=workspace_id,
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date
|
|
||||||
)
|
|
||||||
|
|
||||||
return success(data=result)
|
|
||||||
|
|||||||
@@ -7,11 +7,10 @@ Routes:
|
|||||||
GET /memory/config/emotion - 获取情绪引擎配置
|
GET /memory/config/emotion - 获取情绪引擎配置
|
||||||
POST /memory/config/emotion - 更新情绪引擎配置
|
POST /memory/config/emotion - 更新情绪引擎配置
|
||||||
"""
|
"""
|
||||||
import uuid
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query, HTTPException, status
|
from fastapi import APIRouter, Depends, Query, HTTPException, status
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import Optional, Union
|
from typing import Optional
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
@@ -22,7 +21,6 @@ from app.schemas.response_schema import ApiResponse
|
|||||||
from app.services.emotion_config_service import EmotionConfigService
|
from app.services.emotion_config_service import EmotionConfigService
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.utils.config_utils import resolve_config_id
|
|
||||||
|
|
||||||
# 获取API专用日志器
|
# 获取API专用日志器
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -39,7 +37,7 @@ class EmotionConfigQuery(BaseModel):
|
|||||||
|
|
||||||
class EmotionConfigUpdate(BaseModel):
|
class EmotionConfigUpdate(BaseModel):
|
||||||
"""情绪配置更新请求模型"""
|
"""情绪配置更新请求模型"""
|
||||||
config_id: Union[uuid.UUID, int, str]= Field(..., description="配置ID")
|
config_id: UUID = Field(..., description="配置ID")
|
||||||
emotion_enabled: bool = Field(..., description="是否启用情绪提取")
|
emotion_enabled: bool = Field(..., description="是否启用情绪提取")
|
||||||
emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID")
|
emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID")
|
||||||
emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词")
|
emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词")
|
||||||
@@ -48,7 +46,7 @@ class EmotionConfigUpdate(BaseModel):
|
|||||||
|
|
||||||
@router.get("/read_config", response_model=ApiResponse)
|
@router.get("/read_config", response_model=ApiResponse)
|
||||||
def get_emotion_config(
|
def get_emotion_config(
|
||||||
config_id: UUID|int = Query(..., description="配置ID"),
|
config_id: UUID = Query(..., description="配置ID"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
@@ -81,7 +79,7 @@ def get_emotion_config(
|
|||||||
f"用户 {current_user.username} 请求获取情绪配置",
|
f"用户 {current_user.username} 请求获取情绪配置",
|
||||||
extra={"config_id": config_id}
|
extra={"config_id": config_id}
|
||||||
)
|
)
|
||||||
config_id=resolve_config_id(config_id, db)
|
|
||||||
# 初始化服务
|
# 初始化服务
|
||||||
config_service = EmotionConfigService(db)
|
config_service = EmotionConfigService(db)
|
||||||
|
|
||||||
@@ -160,7 +158,6 @@ def update_emotion_config(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
config.config_id=resolve_config_id(config.config_id, db)
|
|
||||||
try:
|
try:
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"用户 {current_user.username} 请求更新情绪配置",
|
f"用户 {current_user.username} 请求更新情绪配置",
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ from app.schemas.memory_storage_schema import (
|
|||||||
)
|
)
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services.memory_forget_service import MemoryForgetService
|
from app.services.memory_forget_service import MemoryForgetService
|
||||||
from app.utils.config_utils import resolve_config_id
|
|
||||||
|
|
||||||
# 获取API专用日志器
|
# 获取API专用日志器
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -84,7 +84,6 @@ async def trigger_forgetting_cycle(
|
|||||||
|
|
||||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
config_id = connected_config.get("memory_config_id")
|
config_id = connected_config.get("memory_config_id")
|
||||||
config_id = resolve_config_id((config_id), db)
|
|
||||||
|
|
||||||
if config_id is None:
|
if config_id is None:
|
||||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||||
@@ -130,7 +129,7 @@ async def trigger_forgetting_cycle(
|
|||||||
|
|
||||||
@router.get("/read_config", response_model=ApiResponse)
|
@router.get("/read_config", response_model=ApiResponse)
|
||||||
async def read_forgetting_config(
|
async def read_forgetting_config(
|
||||||
config_id: UUID|int,
|
config_id: UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
@@ -159,7 +158,6 @@ async def read_forgetting_config(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
config_id=resolve_config_id(config_id, db)
|
|
||||||
# 调用服务层读取配置
|
# 调用服务层读取配置
|
||||||
config = forget_service.read_forgetting_config(db=db, config_id=config_id)
|
config = forget_service.read_forgetting_config(db=db, config_id=config_id)
|
||||||
|
|
||||||
@@ -197,8 +195,6 @@ async def update_forgetting_config(
|
|||||||
ApiResponse: 包含更新结果的响应
|
ApiResponse: 包含更新结果的响应
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
payload.config_id=resolve_config_id((payload.config_id), db)
|
|
||||||
|
|
||||||
|
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
@@ -259,10 +255,12 @@ async def get_forgetting_stats(
|
|||||||
ApiResponse: 包含统计信息的响应
|
ApiResponse: 包含统计信息的响应
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
# 如果提供了 end_user_id,通过它获取 config_id
|
# 如果提供了 end_user_id,通过它获取 config_id
|
||||||
config_id = None
|
config_id = None
|
||||||
if end_user_id:
|
if end_user_id:
|
||||||
@@ -271,7 +269,6 @@ async def get_forgetting_stats(
|
|||||||
|
|
||||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
config_id = connected_config.get("memory_config_id")
|
config_id = connected_config.get("memory_config_id")
|
||||||
config_id = resolve_config_id(config_id, db)
|
|
||||||
|
|
||||||
if config_id is None:
|
if config_id is None:
|
||||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||||
@@ -328,7 +325,7 @@ async def get_forgetting_curve(
|
|||||||
ApiResponse: 包含遗忘曲线数据的响应
|
ApiResponse: 包含遗忘曲线数据的响应
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
request.config_id = resolve_config_id((request.config_id), db)
|
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘曲线但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘曲线但未选择工作空间")
|
||||||
|
|||||||
@@ -25,8 +25,6 @@ from fastapi import APIRouter, Depends, HTTPException, status,Header
|
|||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.utils.config_utils import resolve_config_id
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
@@ -45,12 +43,12 @@ async def save_reflection_config(
|
|||||||
"""Save reflection configuration to data_comfig table"""
|
"""Save reflection configuration to data_comfig table"""
|
||||||
try:
|
try:
|
||||||
config_id = request.config_id
|
config_id = request.config_id
|
||||||
config_id = resolve_config_id(config_id, db)
|
|
||||||
if not config_id:
|
if not config_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail="缺少必需参数: config_id"
|
detail="缺少必需参数: config_id"
|
||||||
)
|
)
|
||||||
|
|
||||||
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
||||||
|
|
||||||
memory_config = MemoryConfigRepository.update_reflection_config(
|
memory_config = MemoryConfigRepository.update_reflection_config(
|
||||||
@@ -101,7 +99,7 @@ async def start_workspace_reflection(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""启动工作空间中所有匹配应用的反思功能"""
|
"""Activate the reflection function for all matching applications in the workspace"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
reflection_service = MemoryReflectionService(db)
|
reflection_service = MemoryReflectionService(db)
|
||||||
|
|
||||||
@@ -110,55 +108,42 @@ async def start_workspace_reflection(
|
|||||||
|
|
||||||
service = WorkspaceAppService(db)
|
service = WorkspaceAppService(db)
|
||||||
result = service.get_workspace_apps_detailed(workspace_id)
|
result = service.get_workspace_apps_detailed(workspace_id)
|
||||||
|
|
||||||
reflection_results = []
|
reflection_results = []
|
||||||
|
|
||||||
for data in result['apps_detailed_info']:
|
for data in result['apps_detailed_info']:
|
||||||
# 跳过没有配置的应用
|
if data['memory_configs'] == []:
|
||||||
if not data['memory_configs']:
|
|
||||||
api_logger.debug(f"应用 {data['id']} 没有memory_configs,跳过")
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
releases = data['releases']
|
releases = data['releases']
|
||||||
memory_configs = data['memory_configs']
|
memory_configs = data['memory_configs']
|
||||||
end_users = data['end_users']
|
end_users = data['end_users']
|
||||||
|
|
||||||
# 为每个配置和用户组合执行反思
|
for base, config, user in zip(releases, memory_configs, end_users):
|
||||||
for config in memory_configs:
|
# 安全地转换为整数,处理空字符串和None的情况
|
||||||
config_id_str = str(config['config_id'])
|
print(base['config'])
|
||||||
|
try:
|
||||||
# 找到匹配此配置的所有release
|
base_config = int(base['config']) if base['config'] else 0
|
||||||
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
|
config_id = int(config['config_id']) if config['config_id'] else 0
|
||||||
|
except (ValueError, TypeError):
|
||||||
if not matching_releases:
|
api_logger.warning(f"无效的配置ID: base['config']={base.get('config')}, config['config_id']={config.get('config_id')}")
|
||||||
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 为每个用户执行反思
|
if base_config == config_id and base['app_id'] == user['app_id']:
|
||||||
for user in end_users:
|
# 调用反思服务
|
||||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}")
|
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
|
||||||
|
|
||||||
try:
|
reflection_result = await reflection_service.start_text_reflection(
|
||||||
reflection_result = await reflection_service.start_text_reflection(
|
config_data=config,
|
||||||
config_data=config,
|
end_user_id=user['id']
|
||||||
end_user_id=user['id']
|
)
|
||||||
)
|
|
||||||
|
|
||||||
reflection_results.append({
|
reflection_results.append({
|
||||||
"app_id": data['id'],
|
"app_id": base['app_id'],
|
||||||
"config_id": config_id_str,
|
"config_id": config['config_id'],
|
||||||
"end_user_id": user['id'],
|
"end_user_id": user['id'],
|
||||||
"reflection_result": reflection_result
|
"reflection_result": reflection_result
|
||||||
})
|
})
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"用户 {user['id']} 反思失败: {str(e)}")
|
|
||||||
reflection_results.append({
|
|
||||||
"app_id": data['id'],
|
|
||||||
"config_id": config_id_str,
|
|
||||||
"end_user_id": user['id'],
|
|
||||||
"reflection_result": {
|
|
||||||
"status": "错误",
|
|
||||||
"message": f"反思失败: {str(e)}"
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
return success(data=reflection_results, msg="反思配置成功")
|
return success(data=reflection_results, msg="反思配置成功")
|
||||||
|
|
||||||
@@ -172,20 +157,17 @@ async def start_workspace_reflection(
|
|||||||
|
|
||||||
@router.get("/reflection/configs")
|
@router.get("/reflection/configs")
|
||||||
async def start_reflection_configs(
|
async def start_reflection_configs(
|
||||||
config_id: uuid.UUID|int,
|
config_id: uuid.UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""通过config_id查询memory_config表中的反思配置信息"""
|
"""通过config_id查询memory_config表中的反思配置信息"""
|
||||||
config_id = resolve_config_id(config_id, db)
|
|
||||||
try:
|
try:
|
||||||
config_id=resolve_config_id(config_id,db)
|
|
||||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||||
memory_config_id = resolve_config_id(result.config_id, db)
|
|
||||||
# 构建返回数据
|
# 构建返回数据
|
||||||
reflection_config = {
|
reflection_config = {
|
||||||
"config_id": memory_config_id,
|
"config_id": result.config_id,
|
||||||
"reflection_enabled": result.enable_self_reflexion,
|
"reflection_enabled": result.enable_self_reflexion,
|
||||||
"reflection_period_in_hours": result.iteration_period,
|
"reflection_period_in_hours": result.iteration_period,
|
||||||
"reflexion_range": result.reflexion_range,
|
"reflexion_range": result.reflexion_range,
|
||||||
@@ -210,7 +192,7 @@ async def start_reflection_configs(
|
|||||||
|
|
||||||
@router.get("/reflection/run")
|
@router.get("/reflection/run")
|
||||||
async def reflection_run(
|
async def reflection_run(
|
||||||
config_id: UUID|int,
|
config_id: UUID,
|
||||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
@@ -218,7 +200,7 @@ async def reflection_run(
|
|||||||
"""Activate the reflection function for all matching applications in the workspace"""
|
"""Activate the reflection function for all matching applications in the workspace"""
|
||||||
|
|
||||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||||
config_id = resolve_config_id(config_id, db)
|
|
||||||
# 使用MemoryConfigRepository查询反思配置
|
# 使用MemoryConfigRepository查询反思配置
|
||||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||||
if not result:
|
if not result:
|
||||||
|
|||||||
@@ -35,8 +35,6 @@ from fastapi import APIRouter, Depends
|
|||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.utils.config_utils import resolve_config_id
|
|
||||||
|
|
||||||
# Get API logger
|
# Get API logger
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
@@ -143,6 +141,7 @@ def create_config(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
|
||||||
@@ -162,12 +161,12 @@ def create_config(
|
|||||||
|
|
||||||
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
||||||
def delete_config(
|
def delete_config(
|
||||||
config_id: UUID|int,
|
config_id: UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
config_id=resolve_config_id(config_id, db)
|
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
||||||
@@ -189,17 +188,12 @@ def update_config(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
# 校验至少有一个字段需要更新
|
|
||||||
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
|
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
|
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
|
|
||||||
|
|
||||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
||||||
try:
|
try:
|
||||||
svc = DataConfigService(db)
|
svc = DataConfigService(db)
|
||||||
@@ -217,7 +211,7 @@ def update_config_extracted(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
|
||||||
@@ -239,12 +233,12 @@ def update_config_extracted(
|
|||||||
|
|
||||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||||
def read_config_extracted(
|
def read_config_extracted(
|
||||||
config_id: UUID | int,
|
config_id: UUID,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
config_id = resolve_config_id(config_id, db)
|
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间")
|
||||||
@@ -292,7 +286,6 @@ async def pilot_run(
|
|||||||
f"Pilot run requested: config_id={payload.config_id}, "
|
f"Pilot run requested: config_id={payload.config_id}, "
|
||||||
f"dialogue_text_length={len(payload.dialogue_text)}"
|
f"dialogue_text_length={len(payload.dialogue_text)}"
|
||||||
)
|
)
|
||||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
|
||||||
svc = DataConfigService(db)
|
svc = DataConfigService(db)
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
svc.pilot_run_stream(payload),
|
svc.pilot_run_stream(payload),
|
||||||
|
|||||||
@@ -3,17 +3,15 @@ from sqlalchemy.orm import Session
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
|
||||||
from app.core.exceptions import BusinessException
|
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.models.models_model import ModelProvider, ModelType, LoadBalanceStrategy
|
from app.models.models_model import ModelProvider, ModelType
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.repositories.model_repository import ModelConfigRepository
|
|
||||||
from app.schemas import model_schema
|
from app.schemas import model_schema
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.schemas.response_schema import ApiResponse, PageData
|
from app.schemas.response_schema import ApiResponse, PageData
|
||||||
from app.services.model_service import ModelConfigService, ModelApiKeyService, ModelBaseService
|
from app.services.model_service import ModelConfigService, ModelApiKeyService
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
|
|
||||||
# 获取API专用日志器
|
# 获取API专用日志器
|
||||||
@@ -26,30 +24,26 @@ router = APIRouter(
|
|||||||
|
|
||||||
@router.get("/type", response_model=ApiResponse)
|
@router.get("/type", response_model=ApiResponse)
|
||||||
def get_model_types():
|
def get_model_types():
|
||||||
|
|
||||||
return success(msg="获取模型类型成功", data=list(ModelType))
|
return success(msg="获取模型类型成功", data=list(ModelType))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/provider", response_model=ApiResponse)
|
@router.get("/provider", response_model=ApiResponse)
|
||||||
def get_model_providers():
|
def get_model_providers():
|
||||||
providers = [p for p in ModelProvider if p != ModelProvider.COMPOSITE]
|
return success(msg="获取模型提供商成功", data=list(ModelProvider))
|
||||||
return success(msg="获取模型提供商成功", data=providers)
|
|
||||||
|
|
||||||
@router.get("/strategy", response_model=ApiResponse)
|
|
||||||
def get_model_strategies():
|
|
||||||
return success(msg="获取模型策略成功", data=list(LoadBalanceStrategy))
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=ApiResponse)
|
@router.get("", response_model=ApiResponse)
|
||||||
def get_model_list(
|
def get_model_list(
|
||||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
type: Optional[str] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||||
page: int = Query(1, ge=1, description="页码"),
|
page: int = Query(1, ge=1, description="页码"),
|
||||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取模型配置列表
|
获取模型配置列表
|
||||||
@@ -59,20 +53,14 @@ def get_model_list(
|
|||||||
- 多个(逗号分隔):?type=LLM,EMBEDDING
|
- 多个(逗号分隔):?type=LLM,EMBEDDING
|
||||||
- 多个(重复参数):?type=LLM&type=EMBEDDING
|
- 多个(重复参数):?type=LLM&type=EMBEDDING
|
||||||
"""
|
"""
|
||||||
api_logger.info(
|
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
|
||||||
f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 解析 type 参数(支持逗号分隔)
|
# 解析 type 参数(支持逗号分隔)
|
||||||
type_list = []
|
type_list = None
|
||||||
if type is not None:
|
if type:
|
||||||
flat_type = []
|
type_values = [t.strip() for t in type.split(',')]
|
||||||
for item in type:
|
type_list = [model_schema.ModelType(t.lower()) for t in type_values if t]
|
||||||
split_items = [t.strip() for t in item.split(',') if t.strip()]
|
|
||||||
flat_type.extend(split_items)
|
|
||||||
|
|
||||||
unique_flat_type = list(dict.fromkeys(flat_type))
|
|
||||||
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
|
||||||
|
|
||||||
api_logger.error(f"获取模型type_list: {type_list}")
|
api_logger.error(f"获取模型type_list: {type_list}")
|
||||||
query = model_schema.ModelConfigQuery(
|
query = model_schema.ModelConfigQuery(
|
||||||
@@ -95,146 +83,6 @@ def get_model_list(
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@router.get("/new", response_model=ApiResponse)
|
|
||||||
def get_model_list_new(
|
|
||||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
|
||||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于ModelConfig)"),
|
|
||||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
|
||||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
|
||||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
|
||||||
is_composite: Optional[bool] = Query(None, description="组合模型筛选"),
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
获取模型配置列表
|
|
||||||
|
|
||||||
支持多个 type 参数:
|
|
||||||
- 单个:?type=LLM
|
|
||||||
- 多个(逗号分隔):?type=LLM,EMBEDDING
|
|
||||||
- 多个(重复参数):?type=LLM&type=EMBEDDING
|
|
||||||
"""
|
|
||||||
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, tenant_id={current_user.tenant_id}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 解析 type 参数(支持逗号分隔)
|
|
||||||
type_list = []
|
|
||||||
if type is not None:
|
|
||||||
flat_type = []
|
|
||||||
for item in type:
|
|
||||||
split_items = [t.strip() for t in item.split(',') if t.strip()]
|
|
||||||
flat_type.extend(split_items)
|
|
||||||
|
|
||||||
unique_flat_type = list(dict.fromkeys(flat_type))
|
|
||||||
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
|
||||||
|
|
||||||
api_logger.info(f"获取模型type_list: {type_list}")
|
|
||||||
query = model_schema.ModelConfigQueryNew(
|
|
||||||
type=type_list,
|
|
||||||
provider=provider,
|
|
||||||
is_active=is_active,
|
|
||||||
is_public=is_public,
|
|
||||||
is_composite=is_composite,
|
|
||||||
search=search
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.debug(f"开始获取模型配置列表: {query.model_dump()}")
|
|
||||||
result = ModelConfigService.get_model_list_new(db=db, query=query, tenant_id=current_user.tenant_id)
|
|
||||||
api_logger.info(f"模型配置列表获取成功: 分组数={len(result)}, 总模型数={sum(len(item['models']) for item in result)}")
|
|
||||||
return success(data=result, msg="模型配置列表获取成功")
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"获取模型配置列表失败: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/model_plaza", response_model=ApiResponse)
|
|
||||||
def get_model_plaza_list(
|
|
||||||
type: Optional[ModelType] = Query(None, description="模型类型"),
|
|
||||||
provider: Optional[ModelProvider] = Query(None, description="供应商"),
|
|
||||||
is_official: Optional[bool] = Query(None, description="是否官方模型"),
|
|
||||||
is_deprecated: Optional[bool] = Query(None, description="是否弃用"),
|
|
||||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""模型广场查询接口(按供应商分组)"""
|
|
||||||
|
|
||||||
query = model_schema.ModelBaseQuery(
|
|
||||||
type=type,
|
|
||||||
provider=provider,
|
|
||||||
is_official=is_official,
|
|
||||||
is_deprecated=is_deprecated,
|
|
||||||
search=search
|
|
||||||
)
|
|
||||||
result = ModelBaseService.get_model_base_list(db=db, query=query, tenant_id=current_user.tenant_id)
|
|
||||||
return success(data=result, msg="模型广场列表获取成功")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/model_plaza/{model_base_id}", response_model=ApiResponse)
|
|
||||||
def get_model_base_by_id(
|
|
||||||
model_base_id: uuid.UUID,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""获取基础模型详情"""
|
|
||||||
|
|
||||||
result = ModelBaseService.get_model_base_by_id(db=db, model_base_id=model_base_id)
|
|
||||||
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型获取成功")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/model_plaza", response_model=ApiResponse)
|
|
||||||
def create_model_base(
|
|
||||||
data: model_schema.ModelBaseCreate,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""创建基础模型"""
|
|
||||||
|
|
||||||
result = ModelBaseService.create_model_base(db=db, data=data)
|
|
||||||
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型创建成功")
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/model_plaza/{model_base_id}", response_model=ApiResponse)
|
|
||||||
def update_model_base(
|
|
||||||
model_base_id: uuid.UUID,
|
|
||||||
data: model_schema.ModelBaseUpdate,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""更新基础模型"""
|
|
||||||
|
|
||||||
# 不允许更改type类型
|
|
||||||
if data.type is not None or data.provider is not None:
|
|
||||||
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
|
||||||
|
|
||||||
result = ModelBaseService.update_model_base(db=db, model_base_id=model_base_id, data=data)
|
|
||||||
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型更新成功")
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/model_plaza/{model_base_id}", response_model=ApiResponse)
|
|
||||||
def delete_model_base(
|
|
||||||
model_base_id: uuid.UUID,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""删除基础模型"""
|
|
||||||
|
|
||||||
ModelBaseService.delete_model_base(db=db, model_base_id=model_base_id)
|
|
||||||
return success(msg="基础模型删除成功")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/model_plaza/{model_base_id}/add", response_model=ApiResponse)
|
|
||||||
def add_model_from_plaza(
|
|
||||||
model_base_id: uuid.UUID,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""从模型广场添加模型到模型列表"""
|
|
||||||
|
|
||||||
result = ModelBaseService.add_model_from_plaza(db=db, model_base_id=model_base_id, tenant_id=current_user.tenant_id)
|
|
||||||
return success(data=model_schema.ModelConfig.model_validate(result), msg="模型添加成功")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{model_id}", response_model=ApiResponse)
|
@router.get("/{model_id}", response_model=ApiResponse)
|
||||||
def get_model_by_id(
|
def get_model_by_id(
|
||||||
model_id: uuid.UUID,
|
model_id: uuid.UUID,
|
||||||
@@ -290,73 +138,6 @@ async def create_model(
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@router.post("/composite", response_model=ApiResponse)
|
|
||||||
async def create_composite_model(
|
|
||||||
model_data: model_schema.CompositeModelCreate,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
创建组合模型
|
|
||||||
|
|
||||||
- 绑定一个或多个现有的 API Key
|
|
||||||
- 所有 API Key 必须来自非组合模型
|
|
||||||
- 所有 API Key 关联的模型类型必须与组合模型类型一致
|
|
||||||
"""
|
|
||||||
api_logger.info(f"创建组合模型请求: {model_data.name}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
result_orm = await ModelConfigService.create_composite_model(db=db, model_data=model_data, tenant_id=current_user.tenant_id)
|
|
||||||
api_logger.info(f"组合模型创建成功: {result_orm.name} (ID: {result_orm.id})")
|
|
||||||
|
|
||||||
result = model_schema.ModelConfig.model_validate(result_orm)
|
|
||||||
return success(data=result, msg="组合模型创建成功")
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"创建组合模型失败: {model_data.name} - {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/composite/{model_id}", response_model=ApiResponse)
|
|
||||||
async def update_composite_model(
|
|
||||||
model_id: uuid.UUID,
|
|
||||||
model_data: model_schema.CompositeModelCreate,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""更新组合模型"""
|
|
||||||
api_logger.info(f"更新组合模型请求: model_id={model_id}, 用户: {current_user.username}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
if model_data.type is not None:
|
|
||||||
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
|
||||||
result_orm = await ModelConfigService.update_composite_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id)
|
|
||||||
api_logger.info(f"组合模型更新成功: {result_orm.name} (ID: {model_id})")
|
|
||||||
|
|
||||||
result = model_schema.ModelConfig.model_validate(result_orm)
|
|
||||||
return success(data=result, msg="组合模型更新成功")
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"更新组合模型失败: model_id={model_id} - {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/composite/{model_id}", response_model=ApiResponse)
|
|
||||||
def delete_composite_model(
|
|
||||||
model_id: uuid.UUID,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""删除组合模型"""
|
|
||||||
api_logger.info(f"删除组合模型请求: model_id={model_id}, 用户: {current_user.username}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
ModelConfigService.delete_model(db=db, model_id=model_id, tenant_id=current_user.tenant_id)
|
|
||||||
api_logger.info(f"组合模型删除成功: model_id={model_id}")
|
|
||||||
return success(msg="组合模型删除成功")
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"删除组合模型失败: model_id={model_id} - {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{model_id}", response_model=ApiResponse)
|
@router.put("/{model_id}", response_model=ApiResponse)
|
||||||
def update_model(
|
def update_model(
|
||||||
model_id: uuid.UUID,
|
model_id: uuid.UUID,
|
||||||
@@ -433,53 +214,6 @@ def get_model_api_keys(
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@router.post("/provider/apikeys", response_model=ApiResponse)
|
|
||||||
async def create_model_api_key_by_provider(
|
|
||||||
api_key_data: model_schema.ModelApiKeyCreateByProvider,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
根据供应商为所有匹配的模型创建API Key
|
|
||||||
"""
|
|
||||||
api_logger.info(f"创建API Key请求: provider={api_key_data.provider}, 用户: {current_user.username}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 根据tenant_id和provider筛选model_config_id列表
|
|
||||||
model_config_ids = api_key_data.model_config_ids
|
|
||||||
if not model_config_ids:
|
|
||||||
model_config_ids = ModelConfigRepository.get_model_config_ids_by_provider(
|
|
||||||
db=db,
|
|
||||||
tenant_id=current_user.tenant_id,
|
|
||||||
provider=api_key_data.provider
|
|
||||||
)
|
|
||||||
|
|
||||||
if not model_config_ids:
|
|
||||||
raise BusinessException(f"未找到供应商 {api_key_data.provider} 的模型配置", BizCode.MODEL_NOT_FOUND)
|
|
||||||
|
|
||||||
# 构造schema并调用service
|
|
||||||
create_data = model_schema.ModelApiKeyCreateByProvider(
|
|
||||||
provider=api_key_data.provider,
|
|
||||||
api_key=api_key_data.api_key,
|
|
||||||
api_base=api_key_data.api_base,
|
|
||||||
description=api_key_data.description,
|
|
||||||
config=api_key_data.config,
|
|
||||||
is_active=api_key_data.is_active,
|
|
||||||
priority=api_key_data.priority,
|
|
||||||
model_config_ids=model_config_ids
|
|
||||||
)
|
|
||||||
created_keys, failed_models = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)
|
|
||||||
|
|
||||||
api_logger.info(f"API Key创建成功: 关联{len(created_keys)}个模型")
|
|
||||||
# result_list = [model_schema.ModelApiKey.model_validate(key) for key in created_keys]
|
|
||||||
result = "API Key已存在" if len(created_keys) == 0 and len(failed_models) == 0 else \
|
|
||||||
f"成功为 {len(created_keys)} 个模型创建API Key, 失败模型列表{failed_models}"
|
|
||||||
return success(data=result, msg=f"成功为 {len(created_keys)} 个模型创建API Key")
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"创建API Key失败: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{model_id}/apikeys", response_model=ApiResponse, status_code=status.HTTP_201_CREATED)
|
@router.post("/{model_id}/apikeys", response_model=ApiResponse, status_code=status.HTTP_201_CREATED)
|
||||||
async def create_model_api_key(
|
async def create_model_api_key(
|
||||||
model_id: uuid.UUID,
|
model_id: uuid.UUID,
|
||||||
@@ -494,12 +228,11 @@ async def create_model_api_key(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 设置模型配置ID
|
# 设置模型配置ID
|
||||||
api_key_data.model_config_ids = [model_id]
|
api_key_data.model_config_id = model_id
|
||||||
|
|
||||||
api_logger.debug(f"开始创建模型API Key: {api_key_data.model_name}")
|
api_logger.debug(f"开始创建模型API Key: {api_key_data.model_name}")
|
||||||
result_orm = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data)
|
result = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data)
|
||||||
api_logger.info(f"模型API Key创建成功: {result_orm.model_name} (ID: {result_orm.id})")
|
api_logger.info(f"模型API Key创建成功: {result.model_name} (ID: {result.id})")
|
||||||
result = model_schema.ModelApiKey.model_validate(result_orm)
|
|
||||||
return success(data=result, msg="模型API Key创建成功")
|
return success(data=result, msg="模型API Key创建成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"创建模型API Key失败: {api_key_data.model_name} - {str(e)}")
|
api_logger.error(f"创建模型API Key失败: {api_key_data.model_name} - {str(e)}")
|
||||||
@@ -601,3 +334,5 @@ async def validate_model_config(
|
|||||||
return success(data=model_schema.ModelValidateResponse(**result), msg="验证完成")
|
return success(data=model_schema.ModelValidateResponse(**result), msg="验证完成")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,611 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""本体场景和类型路由(续)
|
|
||||||
|
|
||||||
由于主Controller文件较大,将剩余路由放在此文件中。
|
|
||||||
"""
|
|
||||||
|
|
||||||
from uuid import UUID
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from fastapi import Depends
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
|
||||||
from app.core.logging_config import get_api_logger
|
|
||||||
from app.core.response_utils import fail, success
|
|
||||||
from app.db import get_db
|
|
||||||
from app.dependencies import get_current_user
|
|
||||||
from app.models.user_model import User
|
|
||||||
from app.schemas.ontology_schemas import (
|
|
||||||
SceneResponse,
|
|
||||||
SceneListResponse,
|
|
||||||
PaginationInfo,
|
|
||||||
ClassCreateRequest,
|
|
||||||
ClassUpdateRequest,
|
|
||||||
ClassResponse,
|
|
||||||
ClassListResponse,
|
|
||||||
ClassBatchCreateResponse,
|
|
||||||
)
|
|
||||||
from app.schemas.response_schema import ApiResponse
|
|
||||||
from app.services.ontology_service import OntologyService
|
|
||||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
|
||||||
from app.core.models.base import RedBearModelConfig
|
|
||||||
|
|
||||||
|
|
||||||
api_logger = get_api_logger()
|
|
||||||
|
|
||||||
|
|
||||||
def _get_dummy_ontology_service(db: Session) -> OntologyService:
|
|
||||||
"""获取OntologyService实例(不需要LLM)
|
|
||||||
|
|
||||||
场景和类型管理不需要LLM,创建一个dummy配置。
|
|
||||||
"""
|
|
||||||
dummy_config = RedBearModelConfig(
|
|
||||||
model_name="dummy",
|
|
||||||
provider="openai",
|
|
||||||
api_key="dummy",
|
|
||||||
base_url="https://api.openai.com/v1"
|
|
||||||
)
|
|
||||||
llm_client = OpenAIClient(model_config=dummy_config)
|
|
||||||
return OntologyService(llm_client=llm_client, db=db)
|
|
||||||
|
|
||||||
|
|
||||||
# 这些函数将被导入到主Controller中
|
|
||||||
|
|
||||||
async def scenes_handler(
|
|
||||||
workspace_id: Optional[str] = None,
|
|
||||||
scene_name: Optional[str] = None,
|
|
||||||
page: Optional[int] = None,
|
|
||||||
page_size: Optional[int] = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""获取场景列表(支持模糊搜索和全量查询,全量查询支持分页)
|
|
||||||
|
|
||||||
当提供 scene_name 参数时,进行模糊搜索(不分页);
|
|
||||||
当不提供 scene_name 参数时,返回所有场景(支持分页)。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
|
||||||
scene_name: 场景名称关键词(可选,支持模糊匹配)
|
|
||||||
page: 页码(可选,从1开始,仅在全量查询时有效)
|
|
||||||
page_size: 每页数量(可选,仅在全量查询时有效)
|
|
||||||
db: 数据库会话
|
|
||||||
current_user: 当前用户
|
|
||||||
"""
|
|
||||||
operation = "search" if scene_name else "list"
|
|
||||||
api_logger.info(
|
|
||||||
f"Scene {operation} requested by user {current_user.id}, "
|
|
||||||
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, page_size={page_size}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 确定工作空间ID
|
|
||||||
if workspace_id:
|
|
||||||
try:
|
|
||||||
ws_uuid = UUID(workspace_id)
|
|
||||||
except ValueError:
|
|
||||||
api_logger.warning(f"Invalid workspace_id format: {workspace_id}")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的工作空间ID格式")
|
|
||||||
else:
|
|
||||||
ws_uuid = current_user.current_workspace_id
|
|
||||||
if not ws_uuid:
|
|
||||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
|
||||||
|
|
||||||
# 创建Service
|
|
||||||
service = _get_dummy_ontology_service(db)
|
|
||||||
|
|
||||||
# 根据是否提供 scene_name 决定查询方式
|
|
||||||
if scene_name and scene_name.strip():
|
|
||||||
# 验证分页参数(模糊搜索也支持分页)
|
|
||||||
if page is not None and page < 1:
|
|
||||||
api_logger.warning(f"Invalid page number: {page}")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
|
||||||
|
|
||||||
if page_size is not None and page_size < 1:
|
|
||||||
api_logger.warning(f"Invalid page_size: {page_size}")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
|
||||||
|
|
||||||
# 如果只提供了page或page_size中的一个,返回错误
|
|
||||||
if (page is not None and page_size is None) or (page is None and page_size is not None):
|
|
||||||
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
|
||||||
|
|
||||||
# 模糊搜索场景(支持分页)
|
|
||||||
scenes = service.search_scenes_by_name(scene_name.strip(), ws_uuid)
|
|
||||||
total = len(scenes)
|
|
||||||
|
|
||||||
# 如果提供了分页参数,进行分页处理
|
|
||||||
if page is not None and page_size is not None:
|
|
||||||
start_idx = (page - 1) * page_size
|
|
||||||
end_idx = start_idx + page_size
|
|
||||||
scenes = scenes[start_idx:end_idx]
|
|
||||||
|
|
||||||
# 构建响应
|
|
||||||
items = []
|
|
||||||
for scene in scenes:
|
|
||||||
# 获取前3个class_name作为entity_type
|
|
||||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
|
||||||
# 动态计算 type_num
|
|
||||||
type_num = len(scene.classes) if scene.classes else 0
|
|
||||||
|
|
||||||
items.append(SceneResponse(
|
|
||||||
scene_id=scene.scene_id,
|
|
||||||
scene_name=scene.scene_name,
|
|
||||||
scene_description=scene.scene_description,
|
|
||||||
type_num=type_num,
|
|
||||||
entity_type=entity_type,
|
|
||||||
workspace_id=scene.workspace_id,
|
|
||||||
created_at=scene.created_at,
|
|
||||||
updated_at=scene.updated_at,
|
|
||||||
classes_count=type_num
|
|
||||||
))
|
|
||||||
|
|
||||||
# 构建响应(包含分页信息)
|
|
||||||
if page is not None and page_size is not None:
|
|
||||||
# 计算是否有下一页
|
|
||||||
hasnext = (page * page_size) < total
|
|
||||||
|
|
||||||
pagination_info = PaginationInfo(
|
|
||||||
page=page,
|
|
||||||
pagesize=page_size,
|
|
||||||
total=total,
|
|
||||||
hasnext=hasnext
|
|
||||||
)
|
|
||||||
response = SceneListResponse(items=items, page=pagination_info)
|
|
||||||
else:
|
|
||||||
response = SceneListResponse(items=items)
|
|
||||||
|
|
||||||
api_logger.info(
|
|
||||||
f"Scene search completed: found {len(items)} scenes matching '{scene_name}' "
|
|
||||||
f"in workspace {ws_uuid}, total={total}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 获取所有场景(支持分页)
|
|
||||||
# 验证分页参数
|
|
||||||
if page is not None and page < 1:
|
|
||||||
api_logger.warning(f"Invalid page number: {page}")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
|
||||||
|
|
||||||
if page_size is not None and page_size < 1:
|
|
||||||
api_logger.warning(f"Invalid page_size: {page_size}")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
|
||||||
|
|
||||||
# 如果只提供了page或page_size中的一个,返回错误
|
|
||||||
if (page is not None and page_size is None) or (page is None and page_size is not None):
|
|
||||||
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
|
||||||
|
|
||||||
scenes, total = service.list_scenes(ws_uuid, page, page_size)
|
|
||||||
|
|
||||||
# 构建响应
|
|
||||||
items = []
|
|
||||||
for scene in scenes:
|
|
||||||
# 获取前3个class_name作为entity_type
|
|
||||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
|
||||||
# 动态计算 type_num
|
|
||||||
type_num = len(scene.classes) if scene.classes else 0
|
|
||||||
|
|
||||||
items.append(SceneResponse(
|
|
||||||
scene_id=scene.scene_id,
|
|
||||||
scene_name=scene.scene_name,
|
|
||||||
scene_description=scene.scene_description,
|
|
||||||
type_num=type_num,
|
|
||||||
entity_type=entity_type,
|
|
||||||
workspace_id=scene.workspace_id,
|
|
||||||
created_at=scene.created_at,
|
|
||||||
updated_at=scene.updated_at,
|
|
||||||
classes_count=type_num
|
|
||||||
))
|
|
||||||
|
|
||||||
# 构建响应(包含分页信息)
|
|
||||||
if page is not None and page_size is not None:
|
|
||||||
# 计算是否有下一页
|
|
||||||
hasnext = (page * page_size) < total
|
|
||||||
|
|
||||||
pagination_info = PaginationInfo(
|
|
||||||
page=page,
|
|
||||||
pagesize=page_size,
|
|
||||||
total=total,
|
|
||||||
hasnext=hasnext
|
|
||||||
)
|
|
||||||
response = SceneListResponse(items=items, page=pagination_info)
|
|
||||||
else:
|
|
||||||
response = SceneListResponse(items=items)
|
|
||||||
|
|
||||||
api_logger.info(f"Scene list retrieved successfully, count={len(items)}, total={total}")
|
|
||||||
|
|
||||||
return success(data=response.model_dump(mode='json'), msg="查询成功")
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
api_logger.warning(f"Validation error in scene {operation}: {str(e)}")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
|
||||||
|
|
||||||
except RuntimeError as e:
|
|
||||||
api_logger.error(f"Runtime error in scene {operation}: {str(e)}", exc_info=True)
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"Unexpected error in scene {operation}: {str(e)}", exc_info=True)
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 本体类型管理接口 ====================
|
|
||||||
|
|
||||||
async def create_class_handler(
|
|
||||||
request: ClassCreateRequest,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""创建本体类型(统一使用列表形式,支持单个或批量)"""
|
|
||||||
|
|
||||||
# 根据列表长度判断是单个还是批量
|
|
||||||
count = len(request.classes)
|
|
||||||
mode = "single" if count == 1 else "batch"
|
|
||||||
|
|
||||||
api_logger.info(
|
|
||||||
f"Class creation ({mode}) requested by user {current_user.id}, "
|
|
||||||
f"scene_id={request.scene_id}, count={count}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 获取当前工作空间ID
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
if not workspace_id:
|
|
||||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
|
||||||
|
|
||||||
# 创建Service
|
|
||||||
service = _get_dummy_ontology_service(db)
|
|
||||||
|
|
||||||
# 准备类型数据
|
|
||||||
classes_data = [
|
|
||||||
{
|
|
||||||
"class_name": item.class_name,
|
|
||||||
"class_description": item.class_description
|
|
||||||
}
|
|
||||||
for item in request.classes
|
|
||||||
]
|
|
||||||
|
|
||||||
if count == 1:
|
|
||||||
# 单个创建
|
|
||||||
class_data = classes_data[0]
|
|
||||||
ontology_class = service.create_class(
|
|
||||||
scene_id=request.scene_id,
|
|
||||||
class_name=class_data["class_name"],
|
|
||||||
class_description=class_data["class_description"],
|
|
||||||
workspace_id=workspace_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# 构建单个响应
|
|
||||||
response = ClassResponse(
|
|
||||||
class_id=ontology_class.class_id,
|
|
||||||
class_name=ontology_class.class_name,
|
|
||||||
class_description=ontology_class.class_description,
|
|
||||||
scene_id=ontology_class.scene_id,
|
|
||||||
created_at=ontology_class.created_at,
|
|
||||||
updated_at=ontology_class.updated_at
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(f"Class created successfully: {ontology_class.class_id}")
|
|
||||||
|
|
||||||
return success(data=response.model_dump(mode='json'), msg="类型创建成功")
|
|
||||||
|
|
||||||
else:
|
|
||||||
# 批量创建
|
|
||||||
created_classes, errors = service.create_classes_batch(
|
|
||||||
scene_id=request.scene_id,
|
|
||||||
classes=classes_data,
|
|
||||||
workspace_id=workspace_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# 构建批量响应
|
|
||||||
items = []
|
|
||||||
for ontology_class in created_classes:
|
|
||||||
items.append(ClassResponse(
|
|
||||||
class_id=ontology_class.class_id,
|
|
||||||
class_name=ontology_class.class_name,
|
|
||||||
class_description=ontology_class.class_description,
|
|
||||||
scene_id=ontology_class.scene_id,
|
|
||||||
created_at=ontology_class.created_at,
|
|
||||||
updated_at=ontology_class.updated_at
|
|
||||||
))
|
|
||||||
|
|
||||||
response = ClassBatchCreateResponse(
|
|
||||||
total=len(classes_data),
|
|
||||||
success_count=len(created_classes),
|
|
||||||
failed_count=len(errors),
|
|
||||||
items=items,
|
|
||||||
errors=errors if errors else None
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(
|
|
||||||
f"Batch class creation completed: "
|
|
||||||
f"success={len(created_classes)}, failed={len(errors)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return success(data=response.model_dump(mode='json'), msg="批量创建完成")
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
api_logger.warning(f"Validation error in class creation: {str(e)}")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
|
||||||
|
|
||||||
except RuntimeError as e:
|
|
||||||
api_logger.error(f"Runtime error in class creation: {str(e)}", exc_info=True)
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
|
|
||||||
|
|
||||||
|
|
||||||
async def update_class_handler(
|
|
||||||
class_id: str,
|
|
||||||
request: ClassUpdateRequest,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""更新本体类型"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Class update requested by user {current_user.id}, "
|
|
||||||
f"class_id={class_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 验证UUID格式
|
|
||||||
try:
|
|
||||||
class_uuid = UUID(class_id)
|
|
||||||
except ValueError:
|
|
||||||
api_logger.warning(f"Invalid class_id format: {class_id}")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
|
|
||||||
|
|
||||||
# 获取当前工作空间ID
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
if not workspace_id:
|
|
||||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
|
||||||
|
|
||||||
# 创建Service
|
|
||||||
service = _get_dummy_ontology_service(db)
|
|
||||||
|
|
||||||
# 更新类型
|
|
||||||
ontology_class = service.update_class(
|
|
||||||
class_id=class_uuid,
|
|
||||||
class_name=request.class_name,
|
|
||||||
class_description=request.class_description,
|
|
||||||
workspace_id=workspace_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# 构建响应
|
|
||||||
response = ClassResponse(
|
|
||||||
class_id=ontology_class.class_id,
|
|
||||||
class_name=ontology_class.class_name,
|
|
||||||
class_description=ontology_class.class_description,
|
|
||||||
scene_id=ontology_class.scene_id,
|
|
||||||
created_at=ontology_class.created_at,
|
|
||||||
updated_at=ontology_class.updated_at
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(f"Class updated successfully: {class_id}")
|
|
||||||
|
|
||||||
return success(data=response.model_dump(mode='json'), msg="类型更新成功")
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
api_logger.warning(f"Validation error in class update: {str(e)}")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
|
||||||
|
|
||||||
except RuntimeError as e:
|
|
||||||
api_logger.error(f"Runtime error in class update: {str(e)}", exc_info=True)
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"Unexpected error in class update: {str(e)}", exc_info=True)
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))
|
|
||||||
|
|
||||||
|
|
||||||
async def delete_class_handler(
|
|
||||||
class_id: str,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""删除本体类型"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Class deletion requested by user {current_user.id}, "
|
|
||||||
f"class_id={class_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 验证UUID格式
|
|
||||||
try:
|
|
||||||
class_uuid = UUID(class_id)
|
|
||||||
except ValueError:
|
|
||||||
api_logger.warning(f"Invalid class_id format: {class_id}")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
|
|
||||||
|
|
||||||
# 获取当前工作空间ID
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
if not workspace_id:
|
|
||||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
|
||||||
|
|
||||||
# 创建Service
|
|
||||||
service = _get_dummy_ontology_service(db)
|
|
||||||
|
|
||||||
# 删除类型
|
|
||||||
success_flag = service.delete_class(
|
|
||||||
class_id=class_uuid,
|
|
||||||
workspace_id=workspace_id
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(f"Class deleted successfully: {class_id}")
|
|
||||||
|
|
||||||
return success(data={"deleted": success_flag}, msg="类型删除成功")
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
api_logger.warning(f"Validation error in class deletion: {str(e)}")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
|
||||||
|
|
||||||
except RuntimeError as e:
|
|
||||||
api_logger.error(f"Runtime error in class deletion: {str(e)}", exc_info=True)
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"Unexpected error in class deletion: {str(e)}", exc_info=True)
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))
|
|
||||||
|
|
||||||
|
|
||||||
async def get_class_handler(
|
|
||||||
class_id: str,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""获取单个本体类型"""
|
|
||||||
api_logger.info(
|
|
||||||
f"Get class requested by user {current_user.id}, "
|
|
||||||
f"class_id={class_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 验证UUID格式
|
|
||||||
try:
|
|
||||||
class_uuid = UUID(class_id)
|
|
||||||
except ValueError:
|
|
||||||
api_logger.warning(f"Invalid class_id format: {class_id}")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
|
|
||||||
|
|
||||||
# 获取当前工作空间ID
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
if not workspace_id:
|
|
||||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
|
||||||
|
|
||||||
# 创建Service
|
|
||||||
service = _get_dummy_ontology_service(db)
|
|
||||||
|
|
||||||
# 获取类型(会抛出ValueError如果不存在)
|
|
||||||
ontology_class = service.get_class_by_id(class_uuid, workspace_id)
|
|
||||||
|
|
||||||
# 构建响应
|
|
||||||
response = ClassResponse(
|
|
||||||
class_id=ontology_class.class_id,
|
|
||||||
class_name=ontology_class.class_name,
|
|
||||||
class_description=ontology_class.class_description,
|
|
||||||
scene_id=ontology_class.scene_id,
|
|
||||||
created_at=ontology_class.created_at,
|
|
||||||
updated_at=ontology_class.updated_at
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(f"Class retrieved successfully: {class_id}")
|
|
||||||
|
|
||||||
return success(data=response.model_dump(mode='json'), msg="查询成功")
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
# 类型不存在或无权限访问
|
|
||||||
api_logger.warning(f"Validation error in get class: {str(e)}")
|
|
||||||
return fail(BizCode.NOT_FOUND, "请求参数无效", str(e))
|
|
||||||
|
|
||||||
except RuntimeError as e:
|
|
||||||
api_logger.error(f"Runtime error in get class: {str(e)}", exc_info=True)
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"Unexpected error in get class: {str(e)}", exc_info=True)
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
|
||||||
|
|
||||||
|
|
||||||
async def classes_handler(
|
|
||||||
scene_id: str,
|
|
||||||
class_name: Optional[str] = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""获取类型列表(支持模糊搜索和全量查询)
|
|
||||||
|
|
||||||
当提供 class_name 参数时,进行模糊搜索;
|
|
||||||
当不提供 class_name 参数时,返回场景下的所有类型。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scene_id: 场景ID(必填)
|
|
||||||
class_name: 类型名称关键词(可选,支持模糊匹配)
|
|
||||||
db: 数据库会话
|
|
||||||
current_user: 当前用户
|
|
||||||
"""
|
|
||||||
operation = "search" if class_name else "list"
|
|
||||||
api_logger.info(
|
|
||||||
f"Class {operation} requested by user {current_user.id}, "
|
|
||||||
f"keyword={class_name}, scene_id={scene_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 验证UUID格式
|
|
||||||
try:
|
|
||||||
scene_uuid = UUID(scene_id)
|
|
||||||
except ValueError:
|
|
||||||
api_logger.warning(f"Invalid scene_id format: {scene_id}")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的场景ID格式")
|
|
||||||
|
|
||||||
# 获取当前工作空间ID
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
if not workspace_id:
|
|
||||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
|
||||||
|
|
||||||
# 创建Service
|
|
||||||
service = _get_dummy_ontology_service(db)
|
|
||||||
|
|
||||||
# 获取场景信息
|
|
||||||
scene = service.get_scene_by_id(scene_uuid, workspace_id)
|
|
||||||
if not scene:
|
|
||||||
api_logger.warning(f"Scene not found: {scene_id}")
|
|
||||||
return fail(BizCode.NOT_FOUND, "场景不存在", f"未找到ID为 {scene_id} 的场景")
|
|
||||||
|
|
||||||
# 根据是否提供 class_name 决定查询方式
|
|
||||||
if class_name and class_name.strip():
|
|
||||||
# 模糊搜索类型
|
|
||||||
classes = service.search_classes_by_name(class_name.strip(), scene_uuid, workspace_id)
|
|
||||||
else:
|
|
||||||
# 获取所有类型
|
|
||||||
classes = service.list_classes_by_scene(scene_uuid, workspace_id)
|
|
||||||
|
|
||||||
# 构建响应
|
|
||||||
items = []
|
|
||||||
for ontology_class in classes:
|
|
||||||
items.append(ClassResponse(
|
|
||||||
class_id=ontology_class.class_id,
|
|
||||||
class_name=ontology_class.class_name,
|
|
||||||
class_description=ontology_class.class_description,
|
|
||||||
scene_id=ontology_class.scene_id,
|
|
||||||
created_at=ontology_class.created_at,
|
|
||||||
updated_at=ontology_class.updated_at
|
|
||||||
))
|
|
||||||
|
|
||||||
response = ClassListResponse(
|
|
||||||
total=len(items),
|
|
||||||
scene_id=scene_uuid,
|
|
||||||
scene_name=scene.scene_name,
|
|
||||||
scene_description=scene.scene_description,
|
|
||||||
items=items
|
|
||||||
)
|
|
||||||
|
|
||||||
if class_name:
|
|
||||||
api_logger.info(
|
|
||||||
f"Class search completed: found {len(items)} classes matching '{class_name}' "
|
|
||||||
f"in scene {scene_id}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
api_logger.info(f"Class list retrieved successfully, count={len(items)}")
|
|
||||||
|
|
||||||
return success(data=response.model_dump(mode='json'), msg="查询成功")
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
api_logger.warning(f"Validation error in class {operation}: {str(e)}")
|
|
||||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
|
||||||
|
|
||||||
except RuntimeError as e:
|
|
||||||
api_logger.error(f"Runtime error in class {operation}: {str(e)}", exc_info=True)
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"Unexpected error in class {operation}: {str(e)}", exc_info=True)
|
|
||||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
import json
|
|
||||||
import uuid
|
import uuid
|
||||||
|
import json
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Path
|
from fastapi import APIRouter, Depends, Path
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -8,13 +8,9 @@ from starlette.responses import StreamingResponse
|
|||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.dependencies import get_current_user, get_db
|
from app.dependencies import get_current_user, get_db
|
||||||
from app.schemas.prompt_optimizer_schema import (
|
from app.models.prompt_optimizer_model import RoleType
|
||||||
PromptOptMessage,
|
from app.schemas.prompt_optimizer_schema import PromptOptMessage, PromptOptModelSet, CreateSessionResponse, \
|
||||||
CreateSessionResponse,
|
OptimizePromptResponse, SessionHistoryResponse, SessionMessage
|
||||||
SessionHistoryResponse,
|
|
||||||
SessionMessage,
|
|
||||||
PromptSaveRequest
|
|
||||||
)
|
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services.prompt_optimizer_service import PromptOptimizerService
|
from app.services.prompt_optimizer_service import PromptOptimizerService
|
||||||
|
|
||||||
@@ -139,109 +135,3 @@ async def get_prompt_opt(
|
|||||||
"X-Accel-Buffering": "no"
|
"X-Accel-Buffering": "no"
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/releases",
|
|
||||||
summary="Get prompt optimization",
|
|
||||||
response_model=ApiResponse
|
|
||||||
)
|
|
||||||
def save_prompt(
|
|
||||||
data: PromptSaveRequest,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user=Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Save a prompt release for the current tenant.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data (PromptSaveRequest): Request body containing session_id, title, and prompt.
|
|
||||||
db (Session): SQLAlchemy database session, injected via dependency.
|
|
||||||
current_user: Currently authenticated user object, injected via dependency.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ApiResponse: Standard API response containing the saved prompt release info:
|
|
||||||
- id: UUID of the prompt release
|
|
||||||
- session_id: associated session
|
|
||||||
- title: prompt title
|
|
||||||
- prompt: prompt content
|
|
||||||
- created_at: timestamp of creation
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
Any database or service exceptions are propagated to the global exception handler.
|
|
||||||
"""
|
|
||||||
service = PromptOptimizerService(db)
|
|
||||||
prompt_info = service.save_prompt(
|
|
||||||
tenant_id=current_user.tenant_id,
|
|
||||||
session_id=data.session_id,
|
|
||||||
title=data.title,
|
|
||||||
prompt=data.prompt
|
|
||||||
)
|
|
||||||
return success(data=prompt_info)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
|
||||||
"/releases/{prompt_id}",
|
|
||||||
summary="Delete prompt (soft delete)",
|
|
||||||
response_model=ApiResponse
|
|
||||||
)
|
|
||||||
def delete_prompt(
|
|
||||||
prompt_id: uuid.UUID = Path(..., description="Prompt ID"),
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user=Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Soft delete a prompt release.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt_id
|
|
||||||
db (Session): Database session
|
|
||||||
current_user: Current logged-in user
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ApiResponse: Success message confirming deletion
|
|
||||||
"""
|
|
||||||
service = PromptOptimizerService(db)
|
|
||||||
service.delete_prompt(
|
|
||||||
tenant_id=current_user.tenant_id,
|
|
||||||
prompt_id=prompt_id
|
|
||||||
)
|
|
||||||
return success(msg="Prompt deleted successfully")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/releases/list",
|
|
||||||
summary="Get paginated list of released prompts with optional filter",
|
|
||||||
response_model=ApiResponse
|
|
||||||
)
|
|
||||||
def get_release_list(
|
|
||||||
page: int = 1,
|
|
||||||
page_size: int = 20,
|
|
||||||
keyword: str | None = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user=Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Retrieve paginated list of released prompts for the current tenant.
|
|
||||||
Optionally filter by keyword in title.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
page (int): Page number (starting from 1)
|
|
||||||
page_size (int): Number of items per page (max 100)
|
|
||||||
keyword (str | None): Optional keyword to filter prompt titles
|
|
||||||
db (Session): Database session
|
|
||||||
current_user: Current logged-in user
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ApiResponse: Contains paginated list of prompt releases with metadata
|
|
||||||
"""
|
|
||||||
service = PromptOptimizerService(db)
|
|
||||||
result = service.get_release_list(
|
|
||||||
tenant_id=current_user.tenant_id,
|
|
||||||
page=max(1, page),
|
|
||||||
page_size=min(max(1, page_size), 100),
|
|
||||||
filter_keyword=keyword
|
|
||||||
)
|
|
||||||
return success(data=result)
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -235,11 +235,11 @@ async def chat(
|
|||||||
|
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
user_id=end_user_id, # 转换为字符串
|
user_id=new_end_user.id, # 转换为字符串
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
config=config,
|
config=config,
|
||||||
web_search=web_search,
|
web_search=payload.web_search,
|
||||||
memory=memory,
|
memory=payload.memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
app_id=app.id,
|
app_id=app.id,
|
||||||
@@ -268,11 +268,11 @@ async def chat(
|
|||||||
|
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
user_id=end_user_id, # 转换为字符串
|
user_id=new_end_user.id, # 转换为字符串
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
config=config,
|
config=config,
|
||||||
web_search=web_search,
|
web_search=payload.web_search,
|
||||||
memory=memory,
|
memory=payload.memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
app_id=app.id,
|
app_id=app.id,
|
||||||
|
|||||||
@@ -7,21 +7,27 @@ LangChain Agent 封装
|
|||||||
- 支持流式输出
|
- 支持流式输出
|
||||||
- 使用 RedBearLLM 支持多提供商
|
- 使用 RedBearLLM 支持多提供商
|
||||||
"""
|
"""
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||||
|
|
||||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||||
from app.models.models_model import ModelType
|
from app.models.models_model import ModelType
|
||||||
|
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||||
from app.services.memory_agent_service import (
|
from app.services.memory_agent_service import (
|
||||||
get_end_user_connected_config,
|
get_end_user_connected_config,
|
||||||
)
|
)
|
||||||
|
from app.services.memory_konwledges_server import write_rag
|
||||||
|
from app.services.task_service import get_task_memory_write_result
|
||||||
|
from app.tasks import write_message_task
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
@@ -98,7 +104,7 @@ class LangChainAgent:
|
|||||||
"streaming": streaming,
|
"streaming": streaming,
|
||||||
"tool_count": len(self.tools),
|
"tool_count": len(self.tools),
|
||||||
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
|
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
|
||||||
# "tool_count": len(self.tools)
|
"tool_count": len(self.tools)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -137,7 +143,100 @@ class LangChainAgent:
|
|||||||
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
|
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
|
||||||
|
|
||||||
messages.append(HumanMessage(content=user_content))
|
messages.append(HumanMessage(content=user_content))
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||||
|
# async def term_memory_save(self,messages,end_user_end,aimessages):
|
||||||
|
# '''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j'''
|
||||||
|
# end_user_end=f"Term_{end_user_end}"
|
||||||
|
# print(messages)
|
||||||
|
# print(aimessages)
|
||||||
|
# session_id = store.save_session(
|
||||||
|
# userid=end_user_end,
|
||||||
|
# messages=messages,
|
||||||
|
# apply_id=end_user_end,
|
||||||
|
# end_user_id=end_user_end,
|
||||||
|
# aimessages=aimessages
|
||||||
|
# )
|
||||||
|
# store.delete_duplicate_sessions()
|
||||||
|
# # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
|
||||||
|
# return session_id
|
||||||
|
|
||||||
|
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||||
|
# async def term_memory_redis_read(self,end_user_end):
|
||||||
|
# end_user_end = f"Term_{end_user_end}"
|
||||||
|
# history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
|
||||||
|
# # logger.info(f'Redis_Agent:{end_user_end};{history}')
|
||||||
|
# messagss_list=[]
|
||||||
|
# retrieved_content=[]
|
||||||
|
# for messages in history:
|
||||||
|
# query = messages.get("Query")
|
||||||
|
# aimessages = messages.get("Answer")
|
||||||
|
# messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
|
||||||
|
# retrieved_content.append({query: aimessages})
|
||||||
|
# return messagss_list,retrieved_content
|
||||||
|
|
||||||
|
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
|
||||||
|
"""
|
||||||
|
写入记忆(支持结构化消息)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_type: 存储类型 (neo4j/rag)
|
||||||
|
end_user_id: 终端用户ID
|
||||||
|
user_message: 用户消息内容
|
||||||
|
ai_message: AI 回复内容
|
||||||
|
user_rag_memory_id: RAG 记忆ID
|
||||||
|
actual_end_user_id: 实际用户ID
|
||||||
|
actual_config_id: 配置ID
|
||||||
|
|
||||||
|
逻辑说明:
|
||||||
|
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||||
|
- Neo4j 模式:使用结构化消息列表
|
||||||
|
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||||
|
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||||
|
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||||
|
"""
|
||||||
|
if storage_type == "rag":
|
||||||
|
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||||
|
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||||
|
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||||
|
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||||
|
else:
|
||||||
|
# Neo4j 模式:使用结构化消息列表
|
||||||
|
structured_messages = []
|
||||||
|
|
||||||
|
# 始终添加用户消息(如果不为空)
|
||||||
|
if user_message:
|
||||||
|
structured_messages.append({"role": "user", "content": user_message})
|
||||||
|
|
||||||
|
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||||
|
if ai_message:
|
||||||
|
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||||
|
|
||||||
|
# 如果没有消息,直接返回
|
||||||
|
if not structured_messages:
|
||||||
|
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 调用 Celery 任务,传递结构化消息列表
|
||||||
|
# 数据流:
|
||||||
|
# 1. structured_messages 传递给 write_message_task
|
||||||
|
# 2. write_message_task 调用 memory_agent_service.write_memory
|
||||||
|
# 3. write_memory 调用 write_tools.write,传递 messages 参数
|
||||||
|
# 4. write_tools.write 调用 get_chunked_dialogs,传递 messages 参数
|
||||||
|
# 5. get_chunked_dialogs 为每条消息创建独立的 Chunk,设置 speaker 字段
|
||||||
|
# 6. 每个 Chunk 保存到 Neo4j,包含 speaker 字段
|
||||||
|
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||||
|
write_id = write_message_task.delay(
|
||||||
|
actual_end_user_id, # end_user_id: 用户ID
|
||||||
|
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||||
|
actual_config_id, # config_id: 配置ID
|
||||||
|
storage_type, # storage_type: "neo4j"
|
||||||
|
user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||||
|
)
|
||||||
|
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||||
|
write_status = get_task_memory_write_result(str(write_id))
|
||||||
|
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
@@ -182,6 +281,30 @@ class LangChainAgent:
|
|||||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||||
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||||
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||||
|
# # TODO 乐力齐,在长短期记忆存储的时候再使用此代码
|
||||||
|
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||||
|
# history_term_memory = history_term_memory_result[0]
|
||||||
|
# db_for_memory = next(get_db())
|
||||||
|
# if memory_flag:
|
||||||
|
# if len(history_term_memory)>=4 and storage_type != "rag":
|
||||||
|
# history_term_memory = ';'.join(history_term_memory)
|
||||||
|
# retrieved_content = history_term_memory_result[1]
|
||||||
|
# print(retrieved_content)
|
||||||
|
# # 为长期记忆操作获取新的数据库连接
|
||||||
|
# try:
|
||||||
|
# repo = LongTermMemoryRepository(db_for_memory)
|
||||||
|
# repo.upsert(end_user_id, retrieved_content)
|
||||||
|
# logger.info(
|
||||||
|
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.error(f"Failed to write to LongTermMemory: {e}")
|
||||||
|
# raise
|
||||||
|
# finally:
|
||||||
|
# db_for_memory.close()
|
||||||
|
|
||||||
|
# # 长期记忆写入(
|
||||||
|
# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||||
|
# # 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||||
try:
|
try:
|
||||||
# 准备消息列表
|
# 准备消息列表
|
||||||
messages = self._prepare_messages(message, history, context)
|
messages = self._prepare_messages(message, history, context)
|
||||||
@@ -202,17 +325,17 @@ class LangChainAgent:
|
|||||||
# 获取最后的 AI 消息
|
# 获取最后的 AI 消息
|
||||||
output_messages = result.get("messages", [])
|
output_messages = result.get("messages", [])
|
||||||
content = ""
|
content = ""
|
||||||
total_tokens = 0
|
|
||||||
for msg in reversed(output_messages):
|
for msg in reversed(output_messages):
|
||||||
if isinstance(msg, AIMessage):
|
if isinstance(msg, AIMessage):
|
||||||
content = msg.content
|
content = msg.content
|
||||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
|
||||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
|
|
||||||
break
|
break
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
if memory_flag:
|
if memory_flag:
|
||||||
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id)
|
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||||
|
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||||
|
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||||
|
# await self.term_memory_save(message_chat, end_user_id, content)
|
||||||
response = {
|
response = {
|
||||||
"content": content,
|
"content": content,
|
||||||
"model": self.model_name,
|
"model": self.model_name,
|
||||||
@@ -220,7 +343,7 @@ class LangChainAgent:
|
|||||||
"usage": {
|
"usage": {
|
||||||
"prompt_tokens": 0,
|
"prompt_tokens": 0,
|
||||||
"completion_tokens": 0,
|
"completion_tokens": 0,
|
||||||
"total_tokens": total_tokens
|
"total_tokens": 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -280,7 +403,25 @@ class LangChainAgent:
|
|||||||
db.close()
|
db.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to get db session: {e}")
|
logger.warning(f"Failed to get db session: {e}")
|
||||||
|
# # TODO 乐力齐
|
||||||
|
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||||
|
# history_term_memory = history_term_memory_result[0]
|
||||||
|
# if memory_flag:
|
||||||
|
# if len(history_term_memory) >= 4 and storage_type != "rag":
|
||||||
|
# history_term_memory = ';'.join(history_term_memory)
|
||||||
|
# retrieved_content = history_term_memory_result[1]
|
||||||
|
# db_for_memory = next(get_db())
|
||||||
|
# try:
|
||||||
|
# repo = LongTermMemoryRepository(db_for_memory)
|
||||||
|
# repo.upsert(end_user_id, retrieved_content)
|
||||||
|
# logger.info(
|
||||||
|
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||||
|
# # 长期记忆写入
|
||||||
|
# await self.write(storage_type, end_user_id, history_term_memory, "", user_rag_memory_id, end_user_id, actual_config_id)
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.error(f"Failed to write to long term memory: {e}")
|
||||||
|
# finally:
|
||||||
|
# db_for_memory.close()
|
||||||
|
|
||||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||||
try:
|
try:
|
||||||
@@ -296,7 +437,7 @@ class LangChainAgent:
|
|||||||
|
|
||||||
# 统一使用 agent 的 astream_events 实现流式输出
|
# 统一使用 agent 的 astream_events 实现流式输出
|
||||||
logger.debug("使用 Agent astream_events 实现流式输出")
|
logger.debug("使用 Agent astream_events 实现流式输出")
|
||||||
full_content = ''
|
full_content=''
|
||||||
try:
|
try:
|
||||||
async for event in self.agent.astream_events(
|
async for event in self.agent.astream_events(
|
||||||
{"messages": messages},
|
{"messages": messages},
|
||||||
@@ -333,17 +474,12 @@ class LangChainAgent:
|
|||||||
logger.debug(f"工具调用结束: {event.get('name')}")
|
logger.debug(f"工具调用结束: {event.get('name')}")
|
||||||
|
|
||||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||||
# 统计token消耗
|
|
||||||
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
|
|
||||||
for msg in reversed(output_messages):
|
|
||||||
if isinstance(msg, AIMessage):
|
|
||||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
|
||||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
|
|
||||||
0) if response_meta else 0
|
|
||||||
yield total_tokens
|
|
||||||
break
|
|
||||||
if memory_flag:
|
if memory_flag:
|
||||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id)
|
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||||
|
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
|
||||||
|
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||||
|
# await self.term_memory_save(message_chat, end_user_id, full_content)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -157,11 +157,6 @@ class Settings:
|
|||||||
if origin.strip()
|
if origin.strip()
|
||||||
]
|
]
|
||||||
|
|
||||||
# Language Configuration
|
|
||||||
# Supported values: "zh" (Chinese), "en" (English)
|
|
||||||
# This controls the language used for memory summary titles and other generated content
|
|
||||||
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
|
|
||||||
|
|
||||||
# Logging settings
|
# Logging settings
|
||||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
||||||
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||||
|
|||||||
@@ -1,238 +0,0 @@
|
|||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger
|
|
||||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
|
||||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
|
|
||||||
|
|
||||||
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
|
||||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
|
||||||
from app.core.memory.agent.utils.redis_tool import write_store
|
|
||||||
from app.core.memory.agent.utils.redis_tool import count_store
|
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
|
||||||
from app.db import get_db_context, get_db
|
|
||||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
|
||||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
|
||||||
from app.services.memory_konwledges_server import write_rag
|
|
||||||
from app.services.task_service import get_task_memory_write_result
|
|
||||||
from app.tasks import write_message_task
|
|
||||||
from app.utils.config_utils import resolve_config_id
|
|
||||||
logger = get_agent_logger(__name__)
|
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
|
||||||
|
|
||||||
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
|
||||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
|
||||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
|
||||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
|
||||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
|
||||||
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
|
|
||||||
actual_config_id, long_term_messages=[]):
|
|
||||||
"""
|
|
||||||
写入记忆(支持结构化消息)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
storage_type: 存储类型 (neo4j/rag)
|
|
||||||
end_user_id: 终端用户ID
|
|
||||||
user_message: 用户消息内容
|
|
||||||
ai_message: AI 回复内容
|
|
||||||
user_rag_memory_id: RAG 记忆ID
|
|
||||||
actual_end_user_id: 实际用户ID
|
|
||||||
actual_config_id: 配置ID
|
|
||||||
|
|
||||||
逻辑说明:
|
|
||||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
|
||||||
- Neo4j 模式:使用结构化消息列表
|
|
||||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
|
||||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
|
||||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
|
||||||
"""
|
|
||||||
|
|
||||||
db = next(get_db())
|
|
||||||
try:
|
|
||||||
actual_config_id = resolve_config_id(actual_config_id, db)
|
|
||||||
# Neo4j 模式:使用结构化消息列表
|
|
||||||
structured_messages = []
|
|
||||||
|
|
||||||
# 始终添加用户消息(如果不为空)
|
|
||||||
if isinstance(user_message, str) and user_message.strip() != "":
|
|
||||||
structured_messages.append({"role": "user", "content": user_message})
|
|
||||||
|
|
||||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
|
||||||
if isinstance(ai_message, str) and ai_message.strip() != "":
|
|
||||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
|
||||||
|
|
||||||
# 如果提供了 long_term_messages,使用它替代 structured_messages
|
|
||||||
if long_term_messages and isinstance(long_term_messages, list):
|
|
||||||
structured_messages = long_term_messages
|
|
||||||
elif long_term_messages and isinstance(long_term_messages, str):
|
|
||||||
# 如果是 JSON 字符串,先解析
|
|
||||||
try:
|
|
||||||
structured_messages = json.loads(long_term_messages)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
|
|
||||||
|
|
||||||
# 如果没有消息,直接返回
|
|
||||||
if not structured_messages:
|
|
||||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
|
||||||
write_id = write_message_task.delay(
|
|
||||||
actual_end_user_id, # end_user_id: 用户ID
|
|
||||||
structured_messages, # message: JSON 字符串格式的消息列表
|
|
||||||
str(actual_config_id), # config_id: 配置ID字符串
|
|
||||||
storage_type, # storage_type: "neo4j"
|
|
||||||
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
|
||||||
)
|
|
||||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
|
||||||
write_status = get_task_memory_write_result(str(write_id))
|
|
||||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
|
|
||||||
with get_db_context() as db_session:
|
|
||||||
repo = LongTermMemoryRepository(db_session)
|
|
||||||
|
|
||||||
|
|
||||||
from app.core.memory.agent.utils.redis_tool import write_store
|
|
||||||
result = write_store.get_session_by_userid(end_user_id)
|
|
||||||
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
|
||||||
data = await format_parsing(result, "dict")
|
|
||||||
chunk_data = data[:scope]
|
|
||||||
if len(chunk_data)==scope:
|
|
||||||
repo.upsert(end_user_id, chunk_data)
|
|
||||||
logger.info(f'---------写入短长期-----------')
|
|
||||||
else:
|
|
||||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
|
|
||||||
long_messages = await messages_parse(long_time_data)
|
|
||||||
repo.upsert(end_user_id, long_messages)
|
|
||||||
logger.info(f'写入短长期:')
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
'''根据窗口'''
|
|
||||||
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
|
||||||
'''
|
|
||||||
根据窗口获取redis数据,写入neo4j:
|
|
||||||
Args:
|
|
||||||
end_user_id: 终端用户ID
|
|
||||||
memory_config: 内存配置对象
|
|
||||||
langchain_messages:原始数据LIST
|
|
||||||
scope:窗口大小
|
|
||||||
'''
|
|
||||||
scope=scope
|
|
||||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
|
||||||
if is_end_user_id is not False:
|
|
||||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
|
||||||
redis_messages = count_store.get_sessions_count(end_user_id)[1]
|
|
||||||
if is_end_user_id and int(is_end_user_id) != int(scope):
|
|
||||||
is_end_user_id += 1
|
|
||||||
langchain_messages += redis_messages
|
|
||||||
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
|
||||||
elif int(is_end_user_id) == int(scope):
|
|
||||||
logger.info('写入长期记忆NEO4J')
|
|
||||||
formatted_messages = (redis_messages)
|
|
||||||
# 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用)
|
|
||||||
if hasattr(memory_config, 'config_id'):
|
|
||||||
config_id = memory_config.config_id
|
|
||||||
else:
|
|
||||||
config_id = memory_config
|
|
||||||
|
|
||||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
|
||||||
config_id, formatted_messages)
|
|
||||||
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
|
||||||
else:
|
|
||||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
|
||||||
|
|
||||||
|
|
||||||
"""根据时间"""
|
|
||||||
async def memory_long_term_storage(end_user_id,memory_config,time):
|
|
||||||
'''
|
|
||||||
根据时间获取redis数据,写入neo4j:
|
|
||||||
Args:
|
|
||||||
end_user_id: 终端用户ID
|
|
||||||
memory_config: 内存配置对象
|
|
||||||
'''
|
|
||||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
|
||||||
format_messages = (long_time_data)
|
|
||||||
messages=[]
|
|
||||||
memory_config=memory_config.config_id
|
|
||||||
for i in format_messages:
|
|
||||||
message=json.loads(i['Query'])
|
|
||||||
messages+= message
|
|
||||||
if format_messages!=[]:
|
|
||||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
|
||||||
memory_config, messages)
|
|
||||||
'''聚合判断'''
|
|
||||||
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
|
||||||
"""
|
|
||||||
聚合判断函数:判断输入句子和历史消息是否描述同一事件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: 终端用户ID
|
|
||||||
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
|
||||||
memory_config: 内存配置对象
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 1. 获取历史会话数据(使用新方法)
|
|
||||||
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
|
||||||
history = await format_parsing(result)
|
|
||||||
if not result:
|
|
||||||
history = []
|
|
||||||
else:
|
|
||||||
history = await format_parsing(result)
|
|
||||||
json_schema = WriteAggregateModel.model_json_schema()
|
|
||||||
template_service = TemplateService(template_root)
|
|
||||||
system_prompt = await template_service.render_template(
|
|
||||||
template_name='write_aggregate_judgment.jinja2',
|
|
||||||
operation_name='aggregate_judgment',
|
|
||||||
history=history,
|
|
||||||
sentence=ori_messages,
|
|
||||||
json_schema=json_schema
|
|
||||||
)
|
|
||||||
with get_db_context() as db_session:
|
|
||||||
factory = MemoryClientFactory(db_session)
|
|
||||||
llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
|
||||||
messages = [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": system_prompt
|
|
||||||
}
|
|
||||||
]
|
|
||||||
structured = await llm_client.response_structured(
|
|
||||||
messages=messages,
|
|
||||||
response_model=WriteAggregateModel
|
|
||||||
)
|
|
||||||
output_value = structured.output
|
|
||||||
if isinstance(output_value, list):
|
|
||||||
output_value = [
|
|
||||||
{"role": msg.role, "content": msg.content}
|
|
||||||
for msg in output_value
|
|
||||||
]
|
|
||||||
|
|
||||||
result_dict = {
|
|
||||||
"is_same_event": structured.is_same_event,
|
|
||||||
"output": output_value
|
|
||||||
}
|
|
||||||
if not structured.is_same_event:
|
|
||||||
logger.info(result_dict)
|
|
||||||
await write("neo4j", end_user_id, "", "", None, end_user_id,
|
|
||||||
memory_config.config_id, output_value)
|
|
||||||
return result_dict
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[aggregate_judgment] 发生错误: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"is_same_event": False,
|
|
||||||
"output": ori_messages,
|
|
||||||
"messages": ori_messages,
|
|
||||||
"history": history if 'history' in locals() else [],
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
@@ -186,11 +186,10 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
|||||||
清理后的数据
|
清理后的数据
|
||||||
"""
|
"""
|
||||||
# 需要过滤的字段列表
|
# 需要过滤的字段列表
|
||||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
|
||||||
fields_to_remove = {
|
fields_to_remove = {
|
||||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||||
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
||||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
|
'user_id', 'statement_ids', 'updated_at',"chunk_ids","fact_summary"
|
||||||
}
|
}
|
||||||
|
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
|
|||||||
@@ -1,72 +0,0 @@
|
|||||||
import json
|
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, AIMessage
|
|
||||||
async def format_parsing(messages: list,type:str='string'):
|
|
||||||
"""
|
|
||||||
格式化解析消息列表
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: 消息列表
|
|
||||||
type: 返回类型 ('string' 或 'dict')
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
格式化后的消息列表
|
|
||||||
"""
|
|
||||||
result = []
|
|
||||||
user=[]
|
|
||||||
ai=[]
|
|
||||||
|
|
||||||
for message in messages:
|
|
||||||
hstory_messages = message['messages']
|
|
||||||
for history_messag in hstory_messages.strip().splitlines():
|
|
||||||
history_messag = json.loads(history_messag)
|
|
||||||
for content in history_messag:
|
|
||||||
role = content['role']
|
|
||||||
content = content['content']
|
|
||||||
if type == "string":
|
|
||||||
if role == 'human' or role=="user":
|
|
||||||
content = '用户:' + content
|
|
||||||
else:
|
|
||||||
content = 'AI:' + content
|
|
||||||
result.append(content)
|
|
||||||
if type == "dict" :
|
|
||||||
if role == 'human' or role=="user":
|
|
||||||
user.append( content)
|
|
||||||
else:
|
|
||||||
ai.append(content)
|
|
||||||
if type == "dict":
|
|
||||||
for key,values in zip(user,ai):
|
|
||||||
result.append({key:values})
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def messages_parse(messages: list | dict):
|
|
||||||
user=[]
|
|
||||||
ai=[]
|
|
||||||
database=[]
|
|
||||||
for message in messages:
|
|
||||||
Query = message['Query']
|
|
||||||
Query = json.loads(Query)
|
|
||||||
for data in Query:
|
|
||||||
role = data['role']
|
|
||||||
if role == "human":
|
|
||||||
user.append(data['content'])
|
|
||||||
if role == "ai":
|
|
||||||
ai.append(data['content'])
|
|
||||||
for key, values in zip(user, ai):
|
|
||||||
database.append({key, values})
|
|
||||||
return database
|
|
||||||
|
|
||||||
|
|
||||||
async def agent_chat_messages(user_content,ai_content):
|
|
||||||
messages = [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": f"{user_content}"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": f"{ai_content}"
|
|
||||||
}
|
|
||||||
|
|
||||||
]
|
|
||||||
return messages
|
|
||||||
@@ -1,20 +1,22 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
from langgraph.constants import END, START
|
from langgraph.constants import END, START
|
||||||
from langgraph.graph import StateGraph
|
from langgraph.graph import StateGraph
|
||||||
|
|
||||||
from app.db import get_db, get_db_context
|
|
||||||
|
from app.db import get_db
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
@@ -32,6 +34,14 @@ async def make_write_graph():
|
|||||||
end_user_id: Group identifier
|
end_user_id: Group identifier
|
||||||
memory_config: MemoryConfig object containing all configuration
|
memory_config: MemoryConfig object containing all configuration
|
||||||
"""
|
"""
|
||||||
|
# workflow = StateGraph(WriteState)
|
||||||
|
# workflow.add_node("content_input", content_input_write)
|
||||||
|
# workflow.add_node("save_neo4j", write_node)
|
||||||
|
# workflow.add_edge(START, "content_input")
|
||||||
|
# workflow.add_edge("content_input", "save_neo4j")
|
||||||
|
# workflow.add_edge("save_neo4j", END)
|
||||||
|
#
|
||||||
|
# graph = workflow.compile()
|
||||||
workflow = StateGraph(WriteState)
|
workflow = StateGraph(WriteState)
|
||||||
workflow.add_node("save_neo4j", write_node)
|
workflow.add_node("save_neo4j", write_node)
|
||||||
workflow.add_edge(START, "save_neo4j")
|
workflow.add_edge(START, "save_neo4j")
|
||||||
@@ -41,63 +51,43 @@ async def make_write_graph():
|
|||||||
|
|
||||||
yield graph
|
yield graph
|
||||||
|
|
||||||
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
|
||||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
async def main():
|
||||||
from app.core.memory.agent.utils.redis_tool import write_store
|
"""主函数 - 运行工作流"""
|
||||||
write_store.save_session_write(end_user_id, (langchain_messages))
|
message = "今天周一"
|
||||||
|
end_user_id = 'new_2025test1103' # 组ID
|
||||||
|
|
||||||
|
|
||||||
# 获取数据库会话
|
# 获取数据库会话
|
||||||
with get_db_context() as db_session:
|
db_session = next(get_db())
|
||||||
config_service = MemoryConfigService(db_session)
|
config_service = MemoryConfigService(db_session)
|
||||||
memory_config = config_service.load_memory_config(
|
memory_config = config_service.load_memory_config(
|
||||||
config_id=memory_config, # 改为整数
|
config_id=17, # 改为整数
|
||||||
service_name="MemoryAgentService"
|
service_name="MemoryAgentService"
|
||||||
)
|
)
|
||||||
if long_term_type=='chunk':
|
try:
|
||||||
'''方案一:对话窗口6轮对话'''
|
async with make_write_graph() as graph:
|
||||||
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
|
config = {"configurable": {"thread_id": end_user_id}}
|
||||||
if long_term_type=='time':
|
# 初始状态 - 包含所有必要字段
|
||||||
"""时间"""
|
initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config}
|
||||||
await memory_long_term_storage(end_user_id, memory_config,5)
|
|
||||||
if long_term_type=='aggregate':
|
# 获取节点更新信息
|
||||||
"""方案三:聚合判断"""
|
async for update_event in graph.astream(
|
||||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
initial_state,
|
||||||
|
stream_mode="updates",
|
||||||
|
config=config
|
||||||
|
):
|
||||||
|
for node_name, node_data in update_event.items():
|
||||||
|
if 'save_neo4j'==node_name:
|
||||||
|
massages=node_data
|
||||||
|
massages=massages.get('write_result')['status']
|
||||||
|
print(massages) # | 更新数据: {node_data}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
|
import asyncio
|
||||||
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
asyncio.run(main())
|
||||||
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
|
||||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
|
||||||
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
|
||||||
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
|
||||||
else:
|
|
||||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
|
||||||
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
|
||||||
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
|
||||||
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
|
||||||
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
|
|
||||||
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
|
|
||||||
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
|
|
||||||
|
|
||||||
# async def main():
|
|
||||||
# """主函数 - 运行工作流"""
|
|
||||||
# langchain_messages = [
|
|
||||||
# {
|
|
||||||
# "role": "user",
|
|
||||||
# "content": "今天周五去爬山"
|
|
||||||
# },
|
|
||||||
# {
|
|
||||||
# "role": "assistant",
|
|
||||||
# "content": "好耶"
|
|
||||||
# }
|
|
||||||
#
|
|
||||||
# ]
|
|
||||||
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
|
|
||||||
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
|
|
||||||
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
|
||||||
#
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# if __name__ == "__main__":
|
|
||||||
# import asyncio
|
|
||||||
# asyncio.run(main())
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
"""Pydantic models for write aggregate judgment operations."""
|
|
||||||
|
|
||||||
from typing import List, Union
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
|
|
||||||
class MessageItem(BaseModel):
|
|
||||||
"""Individual message item in conversation."""
|
|
||||||
|
|
||||||
role: str = Field(..., description="角色:user 或 assistant")
|
|
||||||
content: str = Field(..., description="消息内容")
|
|
||||||
|
|
||||||
|
|
||||||
class WriteAggregateResponse(BaseModel):
|
|
||||||
"""Response model for aggregate judgment containing judgment result and output."""
|
|
||||||
|
|
||||||
is_same_event: bool = Field(
|
|
||||||
...,
|
|
||||||
description="是否是同一事件。True表示是同一事件,False表示不同事件"
|
|
||||||
)
|
|
||||||
output: Union[List[MessageItem], bool] = Field(
|
|
||||||
...,
|
|
||||||
description="如果is_same_event为True,返回False;如果is_same_event为False,返回消息列表"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# 为了保持向后兼容,保留旧的类名作为别名
|
|
||||||
WriteAggregateModel = WriteAggregateResponse
|
|
||||||
@@ -1,57 +0,0 @@
|
|||||||
输入句子:{{sentence}}
|
|
||||||
历史消息:{{history}}
|
|
||||||
|
|
||||||
# 你的角色
|
|
||||||
你是一个擅长事件聚合与语义判断的专家。
|
|
||||||
|
|
||||||
# 你的任务
|
|
||||||
结合历史消息和输入句子,判断它们是否在描述**同一件事件或同一事件链**。
|
|
||||||
|
|
||||||
以下情况视为"同一事件"(需要返回 is_same_event=True, output=False):
|
|
||||||
- 描述的是同一个具体事件或事实
|
|
||||||
- 存在明显的因果关系、前后发展关系
|
|
||||||
- 是对同一事件的补充、解释、追问或延展
|
|
||||||
- 逻辑上属于同一语境下的连续讨论
|
|
||||||
|
|
||||||
以下情况视为"不同事件"(需要返回 is_same_event=False, output=消息列表):
|
|
||||||
- 话题不同,事件主体不同
|
|
||||||
- 时间、地点、对象明显不同
|
|
||||||
- 只是语义相似,但并非同一具体事件
|
|
||||||
- 无直接事件、因果或逻辑关联
|
|
||||||
|
|
||||||
# 输出规则(非常重要)
|
|
||||||
你必须按照以下JSON格式输出:
|
|
||||||
|
|
||||||
**如果是同一事件:**
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"is_same_event": true,
|
|
||||||
"output": false
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**如果不是同一事件:**
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"is_same_event": false,
|
|
||||||
"output": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "输入句子的内容"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "对应的回复内容"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
# JSON Schema
|
|
||||||
{{json_schema}}
|
|
||||||
|
|
||||||
# 注意事项
|
|
||||||
- 必须严格按照上述格式输出
|
|
||||||
- output 字段:如果是同一事件返回 false,如果不是同一事件返回完整的消息列表
|
|
||||||
- 消息列表必须包含 role 和 content 字段
|
|
||||||
- 不要输出任何解释、分析或多余内容
|
|
||||||
@@ -1,186 +0,0 @@
|
|||||||
import json
|
|
||||||
from typing import Any, List, Dict, Optional
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_messages(messages: Any) -> str:
|
|
||||||
"""
|
|
||||||
将消息序列化为 JSON 字符串,支持 LangChain 消息对象
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: 可以是 list、dict、string 或 LangChain 消息对象列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: JSON 字符串
|
|
||||||
"""
|
|
||||||
if isinstance(messages, str):
|
|
||||||
return messages
|
|
||||||
|
|
||||||
if isinstance(messages, (list, tuple)):
|
|
||||||
# 检查是否是 LangChain 消息对象列表
|
|
||||||
serialized_list = []
|
|
||||||
for msg in messages:
|
|
||||||
if hasattr(msg, 'type') and hasattr(msg, 'content'):
|
|
||||||
# LangChain 消息对象
|
|
||||||
serialized_list.append({
|
|
||||||
'type': msg.type,
|
|
||||||
'content': msg.content,
|
|
||||||
'role': getattr(msg, 'role', msg.type)
|
|
||||||
})
|
|
||||||
elif isinstance(msg, dict):
|
|
||||||
serialized_list.append(msg)
|
|
||||||
else:
|
|
||||||
serialized_list.append(str(msg))
|
|
||||||
return json.dumps(serialized_list, ensure_ascii=False)
|
|
||||||
|
|
||||||
if isinstance(messages, dict):
|
|
||||||
return json.dumps(messages, ensure_ascii=False)
|
|
||||||
|
|
||||||
# 其他类型转为字符串
|
|
||||||
return str(messages)
|
|
||||||
|
|
||||||
|
|
||||||
def deserialize_messages(messages_str: str) -> Any:
|
|
||||||
"""
|
|
||||||
将 JSON 字符串反序列化为原始格式
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages_str: JSON 字符串
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
反序列化后的对象(list、dict 或 string)
|
|
||||||
"""
|
|
||||||
if not messages_str:
|
|
||||||
return []
|
|
||||||
|
|
||||||
try:
|
|
||||||
return json.loads(messages_str)
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
|
||||||
return messages_str
|
|
||||||
|
|
||||||
|
|
||||||
def fix_encoding(text: str) -> str:
|
|
||||||
"""
|
|
||||||
修复错误编码的文本
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: 需要修复的文本
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 修复后的文本
|
|
||||||
"""
|
|
||||||
if not text or not isinstance(text, str):
|
|
||||||
return text
|
|
||||||
try:
|
|
||||||
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
|
|
||||||
return text.encode('latin-1').decode('utf-8')
|
|
||||||
except (UnicodeDecodeError, UnicodeEncodeError):
|
|
||||||
# 如果修复失败,返回原文本
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
def format_session_data(data: Dict[str, Any], include_time: bool = False) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
格式化会话数据为统一的输出格式
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: 原始会话数据
|
|
||||||
include_time: 是否包含时间字段
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: 格式化后的数据 {"Query": "...", "Answer": "...", "starttime": "..."}
|
|
||||||
"""
|
|
||||||
result = {
|
|
||||||
"Query": fix_encoding(data.get('messages', '')),
|
|
||||||
"Answer": fix_encoding(data.get('aimessages', ''))
|
|
||||||
}
|
|
||||||
|
|
||||||
if include_time:
|
|
||||||
result["starttime"] = data.get('starttime', '')
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def filter_by_time_range(items: List[Dict], minutes: int) -> List[Dict]:
|
|
||||||
"""
|
|
||||||
根据时间范围过滤数据
|
|
||||||
|
|
||||||
Args:
|
|
||||||
items: 包含 starttime 字段的数据列表
|
|
||||||
minutes: 时间范围(分钟)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict]: 过滤后的数据列表
|
|
||||||
"""
|
|
||||||
time_threshold = datetime.now() - timedelta(minutes=minutes)
|
|
||||||
time_threshold_str = time_threshold.strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
|
|
||||||
filtered_items = []
|
|
||||||
for item in items:
|
|
||||||
starttime = item.get('starttime', '')
|
|
||||||
if starttime and starttime >= time_threshold_str:
|
|
||||||
filtered_items.append(item)
|
|
||||||
|
|
||||||
return filtered_items
|
|
||||||
|
|
||||||
|
|
||||||
def sort_and_limit_results(items: List[Dict], limit: int = 6,
|
|
||||||
remove_time: bool = True) -> List[Dict]:
|
|
||||||
"""
|
|
||||||
对结果进行排序、限制数量并移除时间字段
|
|
||||||
|
|
||||||
Args:
|
|
||||||
items: 数据列表
|
|
||||||
limit: 最大返回数量
|
|
||||||
remove_time: 是否移除 starttime 字段
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict]: 处理后的数据列表
|
|
||||||
"""
|
|
||||||
# 按时间降序排序(最新的在前)
|
|
||||||
items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
|
||||||
|
|
||||||
# 限制数量
|
|
||||||
result_items = items[:limit]
|
|
||||||
|
|
||||||
# 移除 starttime 字段
|
|
||||||
if remove_time:
|
|
||||||
for item in result_items:
|
|
||||||
item.pop('starttime', None)
|
|
||||||
|
|
||||||
# 如果结果少于1条,返回空列表
|
|
||||||
if len(result_items) < 1:
|
|
||||||
return []
|
|
||||||
|
|
||||||
return result_items
|
|
||||||
|
|
||||||
|
|
||||||
def generate_session_key(session_id: str, key_type: str = "session") -> str:
|
|
||||||
"""
|
|
||||||
生成 Redis key
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session_id: 会话ID
|
|
||||||
key_type: key 类型 ("session", "read", "write", "count")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Redis key
|
|
||||||
"""
|
|
||||||
if key_type == "count":
|
|
||||||
return f"session:count:{session_id}"
|
|
||||||
elif key_type == "write":
|
|
||||||
return f"session:write:{session_id}"
|
|
||||||
elif key_type == "session" or key_type == "read":
|
|
||||||
return f"session:{session_id}"
|
|
||||||
else:
|
|
||||||
return f"session:{session_id}"
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_timestamp() -> str:
|
|
||||||
"""
|
|
||||||
获取当前时间戳字符串
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 格式化的时间字符串 "YYYY-MM-DD HH:MM:SS"
|
|
||||||
"""
|
|
||||||
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
@@ -1,445 +1,11 @@
|
|||||||
import redis
|
import redis
|
||||||
import uuid
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from typing import List, Dict, Any, Optional, Union
|
|
||||||
|
|
||||||
from app.core.memory.agent.utils.redis_base import (
|
|
||||||
serialize_messages,
|
|
||||||
deserialize_messages,
|
|
||||||
fix_encoding,
|
|
||||||
format_session_data,
|
|
||||||
filter_by_time_range,
|
|
||||||
sort_and_limit_results,
|
|
||||||
generate_session_key,
|
|
||||||
get_current_timestamp
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class RedisWriteStore:
|
|
||||||
"""Redis Write 类型存储类,用于管理 save_session_write 相关的数据"""
|
|
||||||
|
|
||||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
|
||||||
"""
|
|
||||||
初始化 Redis 连接
|
|
||||||
|
|
||||||
Args:
|
|
||||||
host: Redis 主机地址
|
|
||||||
port: Redis 端口
|
|
||||||
db: Redis 数据库编号
|
|
||||||
password: Redis 密码
|
|
||||||
session_id: 会话ID
|
|
||||||
"""
|
|
||||||
self.r = redis.Redis(
|
|
||||||
host=host,
|
|
||||||
port=port,
|
|
||||||
db=db,
|
|
||||||
password=password,
|
|
||||||
decode_responses=True,
|
|
||||||
encoding='utf-8'
|
|
||||||
)
|
|
||||||
self.uudi = session_id
|
|
||||||
|
|
||||||
def save_session_write(self, userid: str, messages: str) -> str:
|
|
||||||
"""
|
|
||||||
写入一条会话数据,返回 session_id
|
|
||||||
|
|
||||||
Args:
|
|
||||||
userid: 用户ID
|
|
||||||
messages: 用户消息
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 新生成的 session_id
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
messages = serialize_messages(messages)
|
|
||||||
session_id = str(uuid.uuid4())
|
|
||||||
key = generate_session_key(session_id, key_type="write")
|
|
||||||
|
|
||||||
pipe = self.r.pipeline()
|
|
||||||
pipe.hset(key, mapping={
|
|
||||||
"id": self.uudi,
|
|
||||||
"sessionid": userid,
|
|
||||||
"messages": messages,
|
|
||||||
"starttime": get_current_timestamp()
|
|
||||||
})
|
|
||||||
result = pipe.execute()
|
|
||||||
|
|
||||||
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
|
||||||
return session_id
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[save_session_write] 保存会话失败: {e}")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
|
|
||||||
"""
|
|
||||||
通过 save_session_write 的 userid 获取 sessionid 和 messages
|
|
||||||
|
|
||||||
Args:
|
|
||||||
userid: 用户ID (对应 sessionid 字段)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict] 或 False: 如果找到数据返回 [{"sessionid": "...", "messages": "..."}, ...],否则返回 False
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 只查询 write 类型的 key
|
|
||||||
keys = self.r.keys('session:write:*')
|
|
||||||
if not keys:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 批量获取数据
|
|
||||||
pipe = self.r.pipeline()
|
|
||||||
for key in keys:
|
|
||||||
pipe.hgetall(key)
|
|
||||||
all_data = pipe.execute()
|
|
||||||
|
|
||||||
# 筛选符合 userid 的数据
|
|
||||||
results = []
|
|
||||||
for key, data in zip(keys, all_data):
|
|
||||||
if not data:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 从 write 类型读取,匹配 sessionid 字段
|
|
||||||
if data.get('sessionid') == userid:
|
|
||||||
# 从 key 中提取 session_id: session:write:{session_id}
|
|
||||||
session_id = key.split(':')[-1]
|
|
||||||
results.append({
|
|
||||||
"sessionid": session_id,
|
|
||||||
"messages": fix_encoding(data.get('messages', ''))
|
|
||||||
})
|
|
||||||
|
|
||||||
if not results:
|
|
||||||
return False
|
|
||||||
|
|
||||||
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
|
||||||
return results
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[get_session_by_userid] 查询失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
|
|
||||||
"""
|
|
||||||
通过 end_user_id 获取所有 write 类型的会话数据
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: 终端用户ID (对应 sessionid 字段)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict] 或 False: 如果找到数据返回完整的会话信息列表,否则返回 False
|
|
||||||
|
|
||||||
返回格式:
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"session_id": "uuid",
|
|
||||||
"id": "...",
|
|
||||||
"sessionid": "end_user_id",
|
|
||||||
"messages": "...",
|
|
||||||
"starttime": "timestamp"
|
|
||||||
},
|
|
||||||
...
|
|
||||||
]
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 只查询 write 类型的 key
|
|
||||||
keys = self.r.keys('session:write:*')
|
|
||||||
if not keys:
|
|
||||||
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 批量获取数据
|
|
||||||
pipe = self.r.pipeline()
|
|
||||||
for key in keys:
|
|
||||||
pipe.hgetall(key)
|
|
||||||
all_data = pipe.execute()
|
|
||||||
|
|
||||||
# 筛选符合 end_user_id 的数据
|
|
||||||
results = []
|
|
||||||
for key, data in zip(keys, all_data):
|
|
||||||
if not data:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 从 write 类型读取,匹配 sessionid 字段
|
|
||||||
if data.get('sessionid') == end_user_id:
|
|
||||||
# 从 key 中提取 session_id: session:write:{session_id}
|
|
||||||
session_id = key.split(':')[-1]
|
|
||||||
|
|
||||||
# 构建完整的会话信息
|
|
||||||
session_info = {
|
|
||||||
"session_id": session_id,
|
|
||||||
"id": data.get('id', ''),
|
|
||||||
"sessionid": data.get('sessionid', ''),
|
|
||||||
"messages": fix_encoding(data.get('messages', '')),
|
|
||||||
"starttime": data.get('starttime', '')
|
|
||||||
}
|
|
||||||
results.append(session_info)
|
|
||||||
|
|
||||||
if not results:
|
|
||||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 按时间排序(最新的在前)
|
|
||||||
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
|
||||||
|
|
||||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
|
||||||
return results
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return False
|
|
||||||
|
|
||||||
def find_user_recent_sessions(self, userid: str,
|
|
||||||
minutes: int = 5) -> List[Dict[str, str]]:
|
|
||||||
"""
|
|
||||||
根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据
|
|
||||||
|
|
||||||
Args:
|
|
||||||
userid: 用户ID (对应 sessionid 字段)
|
|
||||||
minutes: 查询最近几分钟的数据,默认5分钟
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...]
|
|
||||||
"""
|
|
||||||
import time
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
# 只查询 write 类型的 key
|
|
||||||
keys = self.r.keys('session:write:*')
|
|
||||||
if not keys:
|
|
||||||
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
|
||||||
return []
|
|
||||||
|
|
||||||
# 批量获取数据
|
|
||||||
pipe = self.r.pipeline()
|
|
||||||
for key in keys:
|
|
||||||
pipe.hgetall(key)
|
|
||||||
all_data = pipe.execute()
|
|
||||||
|
|
||||||
# 筛选符合 userid 的数据
|
|
||||||
matched_items = []
|
|
||||||
for data in all_data:
|
|
||||||
if not data:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 从 write 类型读取,匹配 sessionid 字段
|
|
||||||
if data.get('sessionid') == userid and data.get('starttime'):
|
|
||||||
# write 类型没有 aimessages,所以 Answer 为空
|
|
||||||
matched_items.append({
|
|
||||||
"Query": fix_encoding(data.get('messages', '')),
|
|
||||||
"Answer": "",
|
|
||||||
"starttime": data.get('starttime', '')
|
|
||||||
})
|
|
||||||
|
|
||||||
# 根据时间范围过滤
|
|
||||||
filtered_items = filter_by_time_range(matched_items, minutes)
|
|
||||||
# 排序并移除时间字段
|
|
||||||
result_items = sort_and_limit_results(filtered_items, limit=None)
|
|
||||||
print(result_items)
|
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
|
||||||
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
|
||||||
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
|
||||||
|
|
||||||
return result_items
|
|
||||||
|
|
||||||
def delete_all_write_sessions(self) -> int:
|
|
||||||
"""
|
|
||||||
删除所有 write 类型的会话
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: 删除的数量
|
|
||||||
"""
|
|
||||||
keys = self.r.keys('session:write:*')
|
|
||||||
if keys:
|
|
||||||
return self.r.delete(*keys)
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
class RedisCountStore:
|
|
||||||
"""Redis Count 类型存储类,用于管理访问次数统计相关的数据"""
|
|
||||||
|
|
||||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
|
||||||
"""
|
|
||||||
初始化 Redis 连接
|
|
||||||
|
|
||||||
Args:
|
|
||||||
host: Redis 主机地址
|
|
||||||
port: Redis 端口
|
|
||||||
db: Redis 数据库编号
|
|
||||||
password: Redis 密码
|
|
||||||
session_id: 会话ID
|
|
||||||
"""
|
|
||||||
self.r = redis.Redis(
|
|
||||||
host=host,
|
|
||||||
port=port,
|
|
||||||
db=db,
|
|
||||||
password=password,
|
|
||||||
decode_responses=True,
|
|
||||||
encoding='utf-8'
|
|
||||||
)
|
|
||||||
self.uudi = session_id
|
|
||||||
|
|
||||||
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
|
|
||||||
"""
|
|
||||||
保存用户访问次数统计
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: 终端用户ID
|
|
||||||
count: 访问次数
|
|
||||||
messages: 消息内容
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 新生成的 session_id
|
|
||||||
"""
|
|
||||||
session_id = str(uuid.uuid4())
|
|
||||||
key = generate_session_key(session_id, key_type="count")
|
|
||||||
index_key = f'session:count:index:{end_user_id}' # 索引键
|
|
||||||
|
|
||||||
pipe = self.r.pipeline()
|
|
||||||
pipe.hset(key, mapping={
|
|
||||||
"id": self.uudi,
|
|
||||||
"end_user_id": end_user_id,
|
|
||||||
"count": int(count),
|
|
||||||
"messages": serialize_messages(messages),
|
|
||||||
"starttime": get_current_timestamp()
|
|
||||||
})
|
|
||||||
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
|
|
||||||
|
|
||||||
# 创建索引:end_user_id -> session_id 映射
|
|
||||||
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
|
|
||||||
|
|
||||||
result = pipe.execute()
|
|
||||||
|
|
||||||
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
|
||||||
return session_id
|
|
||||||
|
|
||||||
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
|
|
||||||
"""
|
|
||||||
通过 end_user_id 查询访问次数统计
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: 终端用户ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list 或 False: 如果找到返回 [count, messages],否则返回 False
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 使用索引键快速查找
|
|
||||||
index_key = f'session:count:index:{end_user_id}'
|
|
||||||
|
|
||||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
|
||||||
try:
|
|
||||||
key_type = self.r.type(index_key)
|
|
||||||
if key_type != 'string' and key_type != 'none':
|
|
||||||
self.r.delete(index_key)
|
|
||||||
return False
|
|
||||||
except Exception as type_error:
|
|
||||||
print(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
|
||||||
|
|
||||||
session_id = self.r.get(index_key)
|
|
||||||
|
|
||||||
if not session_id:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 直接获取数据
|
|
||||||
key = generate_session_key(session_id, key_type="count")
|
|
||||||
data = self.r.hgetall(key)
|
|
||||||
|
|
||||||
if not data:
|
|
||||||
# 索引存在但数据不存在,清理索引
|
|
||||||
self.r.delete(index_key)
|
|
||||||
return False
|
|
||||||
|
|
||||||
count = data.get('count')
|
|
||||||
messages_str = data.get('messages')
|
|
||||||
|
|
||||||
if count is not None:
|
|
||||||
messages = deserialize_messages(messages_str)
|
|
||||||
return [int(count), messages]
|
|
||||||
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[get_sessions_count] 查询失败: {e}")
|
|
||||||
return False
|
|
||||||
def update_sessions_count(self, end_user_id: str, new_count: int,
|
|
||||||
messages: Any) -> bool:
|
|
||||||
"""
|
|
||||||
通过 end_user_id 修改访问次数统计(优化版:使用索引)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: 终端用户ID
|
|
||||||
new_count: 新的 count 值
|
|
||||||
messages: 消息内容
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 更新成功返回 True,未找到记录返回 False
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 使用索引键快速查找
|
|
||||||
index_key = f'session:count:index:{end_user_id}'
|
|
||||||
|
|
||||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
|
||||||
try:
|
|
||||||
key_type = self.r.type(index_key)
|
|
||||||
if key_type != 'string' and key_type != 'none':
|
|
||||||
# 索引键类型错误,删除并返回 False
|
|
||||||
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
|
||||||
self.r.delete(index_key)
|
|
||||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
|
||||||
return False
|
|
||||||
except Exception as type_error:
|
|
||||||
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
|
||||||
|
|
||||||
session_id = self.r.get(index_key)
|
|
||||||
|
|
||||||
if not session_id:
|
|
||||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 直接更新数据
|
|
||||||
key = generate_session_key(session_id, key_type="count")
|
|
||||||
messages_str = serialize_messages(messages)
|
|
||||||
|
|
||||||
pipe = self.r.pipeline()
|
|
||||||
pipe.hset(key, 'count', int(new_count))
|
|
||||||
pipe.hset(key, 'messages', messages_str)
|
|
||||||
result = pipe.execute()
|
|
||||||
|
|
||||||
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[update_sessions_count] 更新失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def delete_all_count_sessions(self) -> int:
|
|
||||||
"""
|
|
||||||
删除所有 count 类型的会话
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: 删除的数量
|
|
||||||
"""
|
|
||||||
keys = self.r.keys('session:count:*')
|
|
||||||
if keys:
|
|
||||||
return self.r.delete(*keys)
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
class RedisSessionStore:
|
class RedisSessionStore:
|
||||||
"""Redis 会话存储类,用于管理会话数据"""
|
|
||||||
|
|
||||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||||
"""
|
|
||||||
初始化 Redis 连接
|
|
||||||
|
|
||||||
Args:
|
|
||||||
host: Redis 主机地址
|
|
||||||
port: Redis 端口
|
|
||||||
db: Redis 数据库编号
|
|
||||||
password: Redis 密码
|
|
||||||
session_id: 会话ID
|
|
||||||
"""
|
|
||||||
self.r = redis.Redis(
|
self.r = redis.Redis(
|
||||||
host=host,
|
host=host,
|
||||||
port=port,
|
port=port,
|
||||||
@@ -450,28 +16,32 @@ class RedisSessionStore:
|
|||||||
)
|
)
|
||||||
self.uudi = session_id
|
self.uudi = session_id
|
||||||
|
|
||||||
# ==================== 写入操作 ====================
|
def _fix_encoding(self, text):
|
||||||
|
"""修复错误编码的文本"""
|
||||||
|
if not text or not isinstance(text, str):
|
||||||
|
return text
|
||||||
|
try:
|
||||||
|
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
|
||||||
|
return text.encode('latin-1').decode('utf-8')
|
||||||
|
except (UnicodeDecodeError, UnicodeEncodeError):
|
||||||
|
# 如果修复失败,返回原文本
|
||||||
|
return text
|
||||||
|
|
||||||
def save_session(self, userid: str, messages: str, aimessages: str,
|
# 修改后的 save_session 方法
|
||||||
apply_id: str, end_user_id: str) -> str:
|
def save_session(self, userid, messages, aimessages, apply_id, end_user_id):
|
||||||
"""
|
"""
|
||||||
写入一条会话数据,返回 session_id
|
写入一条会话数据,返回 session_id
|
||||||
|
优化版本:确保写入时间不超过1秒
|
||||||
Args:
|
|
||||||
userid: 用户ID
|
|
||||||
messages: 用户消息
|
|
||||||
aimessages: AI回复消息
|
|
||||||
apply_id: 应用ID
|
|
||||||
end_user_id: 终端用户ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 新生成的 session_id
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
session_id = str(uuid.uuid4())
|
session_id = str(uuid.uuid4()) # 为每次会话生成新的 ID
|
||||||
key = generate_session_key(session_id, key_type="read")
|
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
key = f"session:{session_id}" # 使用新生成的 session_id 作为 key
|
||||||
|
|
||||||
|
# 使用 pipeline 批量写入,减少网络往返
|
||||||
pipe = self.r.pipeline()
|
pipe = self.r.pipeline()
|
||||||
|
|
||||||
|
# 直接写入数据,decode_responses=True 已经处理了编码
|
||||||
pipe.hset(key, mapping={
|
pipe.hset(key, mapping={
|
||||||
"id": self.uudi,
|
"id": self.uudi,
|
||||||
"sessionid": userid,
|
"sessionid": userid,
|
||||||
@@ -479,195 +49,177 @@ class RedisSessionStore:
|
|||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"aimessages": aimessages,
|
"aimessages": aimessages,
|
||||||
"starttime": get_current_timestamp()
|
"starttime": starttime
|
||||||
})
|
})
|
||||||
|
|
||||||
|
# 可选:设置过期时间(例如30天),避免数据无限增长
|
||||||
|
# pipe.expire(key, 30 * 24 * 60 * 60)
|
||||||
|
|
||||||
|
# 执行批量操作
|
||||||
result = pipe.execute()
|
result = pipe.execute()
|
||||||
|
|
||||||
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
print(f"保存结果: {result[0]}, session_id: {session_id}")
|
||||||
return session_id
|
return session_id # 返回新生成的 session_id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[save_session] 保存会话失败: {e}")
|
print(f"保存会话失败: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
# ==================== 读取操作 ====================
|
def save_sessions_batch(self, sessions_data):
|
||||||
|
"""
|
||||||
|
批量写入多条会话数据,返回 session_id 列表
|
||||||
|
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id
|
||||||
|
优化版本:批量操作,大幅提升性能
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
session_ids = []
|
||||||
|
pipe = self.r.pipeline()
|
||||||
|
|
||||||
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
for session in sessions_data:
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
key = f"session:{session_id}"
|
||||||
|
|
||||||
|
pipe.hset(key, mapping={
|
||||||
|
"id": self.uudi,
|
||||||
|
"sessionid": session.get('userid'),
|
||||||
|
"apply_id": session.get('apply_id'),
|
||||||
|
"end_user_id": session.get('end_user_id'),
|
||||||
|
"messages": session.get('messages'),
|
||||||
|
"aimessages": session.get('aimessages'),
|
||||||
|
"starttime": starttime
|
||||||
|
})
|
||||||
|
|
||||||
|
session_ids.append(session_id)
|
||||||
|
|
||||||
|
# 一次性执行所有写入操作
|
||||||
|
results = pipe.execute()
|
||||||
|
print(f"批量保存完成: {len(session_ids)} 条记录")
|
||||||
|
return session_ids
|
||||||
|
except Exception as e:
|
||||||
|
print(f"批量保存会话失败: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# ---------------- 读取 ----------------
|
||||||
|
def get_session(self, session_id):
|
||||||
"""
|
"""
|
||||||
读取一条会话数据
|
读取一条会话数据
|
||||||
|
|
||||||
Args:
|
|
||||||
session_id: 会话ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict 或 None: 会话数据
|
|
||||||
"""
|
"""
|
||||||
key = generate_session_key(session_id)
|
key = f"session:{session_id}"
|
||||||
data = self.r.hgetall(key)
|
data = self.r.hgetall(key)
|
||||||
return data if data else None
|
return data if data else None
|
||||||
|
|
||||||
def get_all_sessions(self) -> Dict[str, Dict[str, Any]]:
|
def get_session_apply_group(self, sessionid, apply_id, end_user_id):
|
||||||
"""
|
"""
|
||||||
获取所有会话数据(不包括 count 和 write 类型)
|
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据
|
||||||
|
"""
|
||||||
|
result_items = []
|
||||||
|
|
||||||
Returns:
|
# 遍历所有会话数据
|
||||||
Dict: 所有会话数据,key 为 session_id
|
|
||||||
"""
|
|
||||||
sessions = {}
|
|
||||||
for key in self.r.keys('session:*'):
|
for key in self.r.keys('session:*'):
|
||||||
# 排除 count 和 write 类型的 key
|
data = self.r.hgetall(key)
|
||||||
if ':count:' not in key and ':write:' not in key:
|
|
||||||
sid = key.split(':')[1]
|
|
||||||
sessions[sid] = self.get_session(sid)
|
|
||||||
return sessions
|
|
||||||
|
|
||||||
def find_user_apply_group(self, sessionid: str, apply_id: str,
|
|
||||||
end_user_id: str) -> List[Dict[str, str]]:
|
|
||||||
"""
|
|
||||||
根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sessionid: 会话ID(支持模糊匹配)
|
|
||||||
apply_id: 应用ID
|
|
||||||
end_user_id: 终端用户ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...]
|
|
||||||
"""
|
|
||||||
import time
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
keys = self.r.keys('session:*')
|
|
||||||
if not keys:
|
|
||||||
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
|
||||||
return []
|
|
||||||
|
|
||||||
# 批量获取数据
|
|
||||||
pipe = self.r.pipeline()
|
|
||||||
for key in keys:
|
|
||||||
# 排除 count 和 write 类型
|
|
||||||
if ':count:' not in key and ':write:' not in key:
|
|
||||||
pipe.hgetall(key)
|
|
||||||
all_data = pipe.execute()
|
|
||||||
|
|
||||||
# 筛选符合条件的数据
|
|
||||||
matched_items = []
|
|
||||||
for data in all_data:
|
|
||||||
if not data:
|
if not data:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if (data.get('apply_id') == apply_id and
|
# 检查三个条件是否都匹配
|
||||||
data.get('end_user_id') == end_user_id):
|
if (data.get('sessionid') == sessionid and
|
||||||
# 支持模糊匹配或完全匹配 sessionid
|
data.get('apply_id') == apply_id and
|
||||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
data.get('end_user_id') == end_user_id):
|
||||||
matched_items.append(format_session_data(data, include_time=True))
|
result_items.append(data)
|
||||||
|
|
||||||
# 排序、限制数量并移除时间字段
|
|
||||||
result_items = sort_and_limit_results(matched_items, limit=6)
|
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
|
||||||
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
|
||||||
|
|
||||||
return result_items
|
return result_items
|
||||||
|
|
||||||
# ==================== 更新操作 ====================
|
def get_all_sessions(self):
|
||||||
|
"""
|
||||||
|
获取所有会话数据
|
||||||
|
"""
|
||||||
|
sessions = {}
|
||||||
|
for key in self.r.keys('session:*'):
|
||||||
|
sid = key.split(':')[1]
|
||||||
|
sessions[sid] = self.get_session(sid)
|
||||||
|
return sessions
|
||||||
|
|
||||||
def update_session(self, session_id: str, field: str, value: Any) -> bool:
|
# ---------------- 更新 ----------------
|
||||||
|
def update_session(self, session_id, field, value):
|
||||||
"""
|
"""
|
||||||
更新单个字段
|
更新单个字段
|
||||||
|
优化版本:使用 pipeline 减少网络往返
|
||||||
Args:
|
|
||||||
session_id: 会话ID
|
|
||||||
field: 字段名
|
|
||||||
value: 字段值
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否更新成功
|
|
||||||
"""
|
"""
|
||||||
key = generate_session_key(session_id)
|
key = f"session:{session_id}"
|
||||||
pipe = self.r.pipeline()
|
pipe = self.r.pipeline()
|
||||||
pipe.exists(key)
|
pipe.exists(key)
|
||||||
pipe.hset(key, field, value)
|
pipe.hset(key, field, value)
|
||||||
results = pipe.execute()
|
results = pipe.execute()
|
||||||
return bool(results[0])
|
return bool(results[0]) # 返回 key 是否存在
|
||||||
|
|
||||||
# ==================== 删除操作 ====================
|
# ---------------- 删除 ----------------
|
||||||
|
def delete_session(self, session_id):
|
||||||
def delete_session(self, session_id: str) -> int:
|
|
||||||
"""
|
"""
|
||||||
删除单条会话
|
删除单条会话
|
||||||
|
|
||||||
Args:
|
|
||||||
session_id: 会话ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: 删除的数量
|
|
||||||
"""
|
"""
|
||||||
key = generate_session_key(session_id)
|
key = f"session:{session_id}"
|
||||||
return self.r.delete(key)
|
return self.r.delete(key)
|
||||||
|
|
||||||
def delete_all_sessions(self) -> int:
|
def delete_all_sessions(self):
|
||||||
"""
|
"""
|
||||||
删除所有会话(不包括 count 和 write 类型)
|
删除所有会话
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: 删除的数量
|
|
||||||
"""
|
"""
|
||||||
keys = self.r.keys('session:*')
|
keys = self.r.keys('session:*')
|
||||||
# 过滤掉 count 和 write 类型
|
if keys:
|
||||||
keys_to_delete = [k for k in keys if ':count:' not in k and ':write:' not in k]
|
return self.r.delete(*keys)
|
||||||
if keys_to_delete:
|
|
||||||
return self.r.delete(*keys_to_delete)
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def delete_duplicate_sessions(self) -> int:
|
def delete_duplicate_sessions(self):
|
||||||
"""
|
"""
|
||||||
删除重复会话数据(不包括 count 和 write 类型)
|
删除重复会话数据,条件:
|
||||||
条件:sessionid、user_id、end_user_id、messages、aimessages 五个字段都相同的只保留一个
|
"sessionid"、"user_id"、"end_user_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除
|
||||||
|
优化版本:使用 pipeline 批量操作,确保在1秒内完成
|
||||||
Returns:
|
|
||||||
int: 删除的数量
|
|
||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 第一步:使用 pipeline 批量获取所有 key
|
||||||
keys = self.r.keys('session:*')
|
keys = self.r.keys('session:*')
|
||||||
|
|
||||||
if not keys:
|
if not keys:
|
||||||
print("[delete_duplicate_sessions] 没有会话数据")
|
print("[delete_duplicate_sessions] 没有会话数据")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# 批量获取所有数据
|
# 第二步:使用 pipeline 批量获取所有数据
|
||||||
pipe = self.r.pipeline()
|
pipe = self.r.pipeline()
|
||||||
for key in keys:
|
for key in keys:
|
||||||
# 排除 count 和 write 类型
|
pipe.hgetall(key)
|
||||||
if ':count:' not in key and ':write:' not in key:
|
|
||||||
pipe.hgetall(key)
|
|
||||||
all_data = pipe.execute()
|
all_data = pipe.execute()
|
||||||
|
|
||||||
# 识别重复数据
|
# 第三步:在内存中识别重复数据
|
||||||
seen = {}
|
seen = {} # 用字典记录:identifier -> key(保留第一个出现的 key)
|
||||||
keys_to_delete = []
|
keys_to_delete = [] # 需要删除的 key 列表
|
||||||
|
|
||||||
for key, data in zip([k for k in keys if ':count:' not in k and ':write:' not in k], all_data, strict=False):
|
for key, data in zip(keys, all_data, strict=False):
|
||||||
if not data:
|
if not data:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# 获取五个字段的值
|
||||||
|
sessionid = data.get('sessionid', '')
|
||||||
|
user_id = data.get('id', '')
|
||||||
|
end_user_id = data.get('end_user_id', '')
|
||||||
|
messages = data.get('messages', '')
|
||||||
|
aimessages = data.get('aimessages', '')
|
||||||
|
|
||||||
# 用五元组作为唯一标识
|
# 用五元组作为唯一标识
|
||||||
identifier = (
|
identifier = (sessionid, user_id, end_user_id, messages, aimessages)
|
||||||
data.get('sessionid', ''),
|
|
||||||
data.get('id', ''),
|
|
||||||
data.get('end_user_id', ''),
|
|
||||||
data.get('messages', ''),
|
|
||||||
data.get('aimessages', '')
|
|
||||||
)
|
|
||||||
|
|
||||||
if identifier in seen:
|
if identifier in seen:
|
||||||
|
# 重复,标记为待删除
|
||||||
keys_to_delete.append(key)
|
keys_to_delete.append(key)
|
||||||
else:
|
else:
|
||||||
|
# 第一次出现,记录
|
||||||
seen[identifier] = key
|
seen[identifier] = key
|
||||||
|
|
||||||
# 批量删除重复的 key
|
# 第四步:使用 pipeline 批量删除重复的 key
|
||||||
deleted_count = 0
|
deleted_count = 0
|
||||||
if keys_to_delete:
|
if keys_to_delete:
|
||||||
|
# 分批删除,避免单次操作过大
|
||||||
batch_size = 1000
|
batch_size = 1000
|
||||||
for i in range(0, len(keys_to_delete), batch_size):
|
for i in range(0, len(keys_to_delete), batch_size):
|
||||||
batch = keys_to_delete[i:i + batch_size]
|
batch = keys_to_delete[i:i + batch_size]
|
||||||
@@ -681,8 +233,75 @@ class RedisSessionStore:
|
|||||||
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||||
return deleted_count
|
return deleted_count
|
||||||
|
|
||||||
|
def find_user_session(self, sessionid):
|
||||||
|
user_id = sessionid
|
||||||
|
|
||||||
|
result_items = []
|
||||||
|
for key, values in store.get_all_sessions().items():
|
||||||
|
history = {}
|
||||||
|
if user_id == str(values['sessionid']):
|
||||||
|
history["Query"] = values['messages']
|
||||||
|
history["Answer"] = values['aimessages']
|
||||||
|
result_items.append(history)
|
||||||
|
|
||||||
|
if len(result_items) <= 1:
|
||||||
|
result_items = []
|
||||||
|
return (result_items)
|
||||||
|
|
||||||
|
def find_user_apply_group(self, sessionid, apply_id, end_user_id):
|
||||||
|
"""
|
||||||
|
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据,返回最新的6条
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
start_time = time.time()
|
||||||
|
# 使用 pipeline 批量获取数据,提高性能
|
||||||
|
keys = self.r.keys('session:*')
|
||||||
|
|
||||||
|
if not keys:
|
||||||
|
print(f"查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 使用 pipeline 批量获取所有 hash 数据
|
||||||
|
pipe = self.r.pipeline()
|
||||||
|
for key in keys:
|
||||||
|
pipe.hgetall(key)
|
||||||
|
all_data = pipe.execute()
|
||||||
|
|
||||||
|
# 解析并筛选符合条件的数据
|
||||||
|
matched_items = []
|
||||||
|
for data in all_data:
|
||||||
|
if not data:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查是否符合三个条件
|
||||||
|
|
||||||
|
if (data.get('apply_id') == apply_id and
|
||||||
|
data.get('end_user_id') == end_user_id):
|
||||||
|
# 支持模糊匹配 sessionid 或者完全匹配
|
||||||
|
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||||
|
matched_items.append({
|
||||||
|
"Query": self._fix_encoding(data.get('messages')),
|
||||||
|
"Answer": self._fix_encoding(data.get('aimessages')),
|
||||||
|
"starttime": data.get('starttime', '')
|
||||||
|
})
|
||||||
|
# 按时间降序排序(最新的在前)
|
||||||
|
matched_items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||||
|
# 只保留最新的6条
|
||||||
|
result_items = matched_items[:6]
|
||||||
|
# # 移除 starttime 字段
|
||||||
|
for item in result_items:
|
||||||
|
item.pop('starttime', None)
|
||||||
|
|
||||||
|
# 如果结果少于等于1条,返回空列表
|
||||||
|
if len(result_items) <= 1:
|
||||||
|
result_items = []
|
||||||
|
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
print(f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||||
|
|
||||||
|
return result_items
|
||||||
|
|
||||||
|
|
||||||
# 全局实例
|
|
||||||
store = RedisSessionStore(
|
store = RedisSessionStore(
|
||||||
host=settings.REDIS_HOST,
|
host=settings.REDIS_HOST,
|
||||||
port=settings.REDIS_PORT,
|
port=settings.REDIS_PORT,
|
||||||
@@ -690,19 +309,3 @@ store = RedisSessionStore(
|
|||||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
||||||
session_id=str(uuid.uuid4())
|
session_id=str(uuid.uuid4())
|
||||||
)
|
)
|
||||||
|
|
||||||
write_store = RedisWriteStore(
|
|
||||||
host=settings.REDIS_HOST,
|
|
||||||
port=settings.REDIS_PORT,
|
|
||||||
db=settings.REDIS_DB,
|
|
||||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
|
||||||
session_id=str(uuid.uuid4())
|
|
||||||
)
|
|
||||||
|
|
||||||
count_store = RedisCountStore(
|
|
||||||
host=settings.REDIS_HOST,
|
|
||||||
port=settings.REDIS_PORT,
|
|
||||||
db=settings.REDIS_DB,
|
|
||||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
|
||||||
session_id=str(uuid.uuid4())
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ Write Tools for Memory Knowledge Extraction Pipeline
|
|||||||
This module provides the main write function for executing the knowledge extraction
|
This module provides the main write function for executing the knowledge extraction
|
||||||
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
||||||
"""
|
"""
|
||||||
import asyncio
|
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
@@ -124,48 +123,23 @@ async def write(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error creating indexes: {e}", exc_info=True)
|
logger.error(f"Error creating indexes: {e}", exc_info=True)
|
||||||
|
|
||||||
# 添加死锁重试机制
|
|
||||||
max_retries = 3
|
|
||||||
retry_delay = 1 # 秒
|
|
||||||
|
|
||||||
for attempt in range(max_retries):
|
|
||||||
try:
|
|
||||||
success = await save_dialog_and_statements_to_neo4j(
|
|
||||||
dialogue_nodes=all_dialogue_nodes,
|
|
||||||
chunk_nodes=all_chunk_nodes,
|
|
||||||
statement_nodes=all_statement_nodes,
|
|
||||||
entity_nodes=all_entity_nodes,
|
|
||||||
statement_chunk_edges=all_statement_chunk_edges,
|
|
||||||
statement_entity_edges=all_statement_entity_edges,
|
|
||||||
entity_edges=all_entity_entity_edges,
|
|
||||||
connector=neo4j_connector
|
|
||||||
)
|
|
||||||
if success:
|
|
||||||
logger.info("Successfully saved all data to Neo4j")
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
logger.warning("Failed to save some data to Neo4j")
|
|
||||||
if attempt < max_retries - 1:
|
|
||||||
logger.info(f"Retrying... (attempt {attempt + 2}/{max_retries})")
|
|
||||||
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = str(e)
|
|
||||||
# 检查是否是死锁错误
|
|
||||||
if "DeadlockDetected" in error_msg or "deadlock" in error_msg.lower():
|
|
||||||
if attempt < max_retries - 1:
|
|
||||||
logger.warning(f"Deadlock detected, retrying... (attempt {attempt + 2}/{max_retries})")
|
|
||||||
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
|
||||||
else:
|
|
||||||
logger.error(f"Failed after {max_retries} attempts due to deadlock: {e}")
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
# 非死锁错误,直接抛出
|
|
||||||
raise
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
success = await save_dialog_and_statements_to_neo4j(
|
||||||
|
dialogue_nodes=all_dialogue_nodes,
|
||||||
|
chunk_nodes=all_chunk_nodes,
|
||||||
|
statement_nodes=all_statement_nodes,
|
||||||
|
entity_nodes=all_entity_nodes,
|
||||||
|
statement_chunk_edges=all_statement_chunk_edges,
|
||||||
|
statement_entity_edges=all_statement_entity_edges,
|
||||||
|
entity_edges=all_entity_entity_edges,
|
||||||
|
connector=neo4j_connector
|
||||||
|
)
|
||||||
|
if success:
|
||||||
|
logger.info("Successfully saved all data to Neo4j")
|
||||||
|
else:
|
||||||
|
logger.warning("Failed to save some data to Neo4j")
|
||||||
|
finally:
|
||||||
await neo4j_connector.close()
|
await neo4j_connector.close()
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error closing Neo4j connector: {e}")
|
|
||||||
|
|
||||||
log_time("Neo4j Database Save", time.time() - step_start, log_file)
|
log_time("Neo4j Database Save", time.time() - step_start, log_file)
|
||||||
|
|
||||||
|
|||||||
@@ -58,12 +58,6 @@ from app.core.memory.models.triplet_models import (
|
|||||||
TripletExtractionResponse,
|
TripletExtractionResponse,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ontology models
|
|
||||||
from app.core.memory.models.ontology_models import (
|
|
||||||
OntologyClass,
|
|
||||||
OntologyExtractionResponse,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Variable configuration models
|
# Variable configuration models
|
||||||
from app.core.memory.models.variate_config import (
|
from app.core.memory.models.variate_config import (
|
||||||
StatementExtractionConfig,
|
StatementExtractionConfig,
|
||||||
@@ -111,9 +105,6 @@ __all__ = [
|
|||||||
"Entity",
|
"Entity",
|
||||||
"Triplet",
|
"Triplet",
|
||||||
"TripletExtractionResponse",
|
"TripletExtractionResponse",
|
||||||
# Ontology models
|
|
||||||
"OntologyClass",
|
|
||||||
"OntologyExtractionResponse",
|
|
||||||
# Variable configuration
|
# Variable configuration
|
||||||
"StatementExtractionConfig",
|
"StatementExtractionConfig",
|
||||||
"ForgettingEngineConfig",
|
"ForgettingEngineConfig",
|
||||||
|
|||||||
@@ -413,8 +413,7 @@ class ExtractedEntityNode(Node):
|
|||||||
description="Entity aliases - alternative names for this entity"
|
description="Entity aliases - alternative names for this entity"
|
||||||
)
|
)
|
||||||
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
|
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
|
||||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||||
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
|
||||||
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
||||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
||||||
|
|
||||||
|
|||||||
@@ -1,135 +0,0 @@
|
|||||||
"""Models for ontology classes and extraction responses.
|
|
||||||
|
|
||||||
This module contains Pydantic models for representing extracted ontology classes
|
|
||||||
from scenario descriptions, following OWL ontology engineering standards.
|
|
||||||
|
|
||||||
Classes:
|
|
||||||
OntologyClass: Represents an extracted ontology class
|
|
||||||
OntologyExtractionResponse: Response model containing extracted ontology classes
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import List, Optional
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
||||||
|
|
||||||
|
|
||||||
class OntologyClass(BaseModel):
|
|
||||||
"""Represents an extracted ontology class from scenario description.
|
|
||||||
|
|
||||||
An ontology class represents an abstract category or concept in a domain,
|
|
||||||
following OWL ontology engineering standards and naming conventions.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
id: Unique string identifier for the ontology class
|
|
||||||
name: Name of the class in PascalCase format (e.g., 'MedicalProcedure')
|
|
||||||
name_chinese: Chinese translation of the class name (e.g., '医疗程序')
|
|
||||||
description: Textual description of the class
|
|
||||||
examples: List of concrete instance examples of this class
|
|
||||||
parent_class: Optional name of the parent class in the hierarchy
|
|
||||||
entity_type: Type/category of the entity (e.g., 'Person', 'Organization', 'Concept')
|
|
||||||
domain: Domain this class belongs to (e.g., 'Healthcare', 'Education')
|
|
||||||
|
|
||||||
Config:
|
|
||||||
extra: Ignore extra fields from LLM output
|
|
||||||
"""
|
|
||||||
model_config = ConfigDict(extra='ignore')
|
|
||||||
|
|
||||||
id: str = Field(
|
|
||||||
default_factory=lambda: uuid4().hex,
|
|
||||||
description="Unique identifier for the ontology class"
|
|
||||||
)
|
|
||||||
name: str = Field(
|
|
||||||
...,
|
|
||||||
description="Name of the class in PascalCase format"
|
|
||||||
)
|
|
||||||
name_chinese: Optional[str] = Field(
|
|
||||||
None,
|
|
||||||
description="Chinese translation of the class name"
|
|
||||||
)
|
|
||||||
description: str = Field(
|
|
||||||
...,
|
|
||||||
description="Description of the class"
|
|
||||||
)
|
|
||||||
examples: List[str] = Field(
|
|
||||||
default_factory=list,
|
|
||||||
description="List of concrete instance examples"
|
|
||||||
)
|
|
||||||
parent_class: Optional[str] = Field(
|
|
||||||
None,
|
|
||||||
description="Name of the parent class in the hierarchy"
|
|
||||||
)
|
|
||||||
entity_type: str = Field(
|
|
||||||
...,
|
|
||||||
description="Type/category of the entity"
|
|
||||||
)
|
|
||||||
domain: str = Field(
|
|
||||||
...,
|
|
||||||
description="Domain this class belongs to"
|
|
||||||
)
|
|
||||||
|
|
||||||
@field_validator('name')
|
|
||||||
@classmethod
|
|
||||||
def validate_pascal_case(cls, v: str) -> str:
|
|
||||||
"""Validate that the class name follows PascalCase convention.
|
|
||||||
|
|
||||||
PascalCase rules:
|
|
||||||
- Must start with an uppercase letter
|
|
||||||
- Cannot contain spaces
|
|
||||||
- Should not contain special characters except underscores
|
|
||||||
|
|
||||||
Args:
|
|
||||||
v: The class name to validate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The validated class name
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the name doesn't follow PascalCase convention
|
|
||||||
"""
|
|
||||||
if not v:
|
|
||||||
raise ValueError("Class name cannot be empty")
|
|
||||||
|
|
||||||
if not v[0].isupper():
|
|
||||||
raise ValueError(
|
|
||||||
f"Class name '{v}' must start with an uppercase letter (PascalCase)"
|
|
||||||
)
|
|
||||||
|
|
||||||
if ' ' in v:
|
|
||||||
raise ValueError(
|
|
||||||
f"Class name '{v}' cannot contain spaces (PascalCase)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check for invalid characters (allow alphanumeric and underscore only)
|
|
||||||
if not all(c.isalnum() or c == '_' for c in v):
|
|
||||||
raise ValueError(
|
|
||||||
f"Class name '{v}' contains invalid characters. "
|
|
||||||
"Only alphanumeric characters and underscores are allowed"
|
|
||||||
)
|
|
||||||
|
|
||||||
return v
|
|
||||||
|
|
||||||
|
|
||||||
class OntologyExtractionResponse(BaseModel):
|
|
||||||
"""Response model for ontology extraction from LLM.
|
|
||||||
|
|
||||||
This model represents the structured output from the LLM when
|
|
||||||
extracting ontology classes from scenario descriptions.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
classes: List of extracted ontology classes
|
|
||||||
domain: Domain/field the scenario belongs to
|
|
||||||
|
|
||||||
Config:
|
|
||||||
extra: Ignore extra fields from LLM output
|
|
||||||
"""
|
|
||||||
model_config = ConfigDict(extra='ignore')
|
|
||||||
|
|
||||||
classes: List[OntologyClass] = Field(
|
|
||||||
default_factory=list,
|
|
||||||
description="List of extracted ontology classes"
|
|
||||||
)
|
|
||||||
domain: str = Field(
|
|
||||||
...,
|
|
||||||
description="Domain/field the scenario belongs to"
|
|
||||||
)
|
|
||||||
@@ -134,45 +134,42 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
|
|||||||
if len(desc_b) > len(desc_a):
|
if len(desc_b) > len(desc_a):
|
||||||
canonical.description = desc_b
|
canonical.description = desc_b
|
||||||
# 合并事实摘要:统一保留一个“实体: name”行,来源行去重保序
|
# 合并事实摘要:统一保留一个“实体: name”行,来源行去重保序
|
||||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
fact_a = getattr(canonical, "fact_summary", "") or ""
|
||||||
# fact_a = getattr(canonical, "fact_summary", "") or ""
|
fact_b = getattr(ent, "fact_summary", "") or ""
|
||||||
# fact_b = getattr(ent, "fact_summary", "") or ""
|
def _extract_sources(txt: str) -> List[str]:
|
||||||
# def _extract_sources(txt: str) -> List[str]:
|
sources: List[str] = []
|
||||||
# sources: List[str] = []
|
if not txt:
|
||||||
# if not txt:
|
return sources
|
||||||
# return sources
|
for line in str(txt).splitlines():
|
||||||
# for line in str(txt).splitlines():
|
ln = line.strip()
|
||||||
# ln = line.strip()
|
|
||||||
# 支持“来源:”或“来源:”前缀
|
# 支持“来源:”或“来源:”前缀
|
||||||
# m = re.match(r"^来源[::]\s*(.+)$", ln)
|
m = re.match(r"^来源[::]\s*(.+)$", ln)
|
||||||
# if m:
|
if m:
|
||||||
# content = m.group(1).strip()
|
content = m.group(1).strip()
|
||||||
# if content:
|
if content:
|
||||||
# sources.append(content)
|
sources.append(content)
|
||||||
# 如果不存在“来源”前缀,则将整体文本视为一个来源片段,避免信息丢失
|
# 如果不存在“来源”前缀,则将整体文本视为一个来源片段,避免信息丢失
|
||||||
# if not sources and txt.strip():
|
if not sources and txt.strip():
|
||||||
# sources.append(txt.strip())
|
sources.append(txt.strip())
|
||||||
# return sources
|
return sources
|
||||||
try:
|
try:
|
||||||
# src_a = _extract_sources(fact_a)
|
src_a = _extract_sources(fact_a)
|
||||||
# src_b = _extract_sources(fact_b)
|
src_b = _extract_sources(fact_b)
|
||||||
# seen = set()
|
seen = set()
|
||||||
# merged_sources: List[str] = []
|
merged_sources: List[str] = []
|
||||||
# for s in src_a + src_b:
|
for s in src_a + src_b:
|
||||||
# if s and s not in seen:
|
if s and s not in seen:
|
||||||
# seen.add(s)
|
seen.add(s)
|
||||||
# merged_sources.append(s)
|
merged_sources.append(s)
|
||||||
# if merged_sources:
|
if merged_sources:
|
||||||
# name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
|
name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
|
||||||
# canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
|
canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
|
||||||
# elif fact_b and not fact_a:
|
elif fact_b and not fact_a:
|
||||||
# canonical.fact_summary = fact_b
|
canonical.fact_summary = fact_b
|
||||||
pass
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# 兜底:若解析失败,保留较长文本
|
# 兜底:若解析失败,保留较长文本
|
||||||
# if len(fact_b) > len(fact_a):
|
if len(fact_b) > len(fact_a):
|
||||||
# canonical.fact_summary = fact_b
|
canonical.fact_summary = fact_b
|
||||||
pass
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -145,13 +145,10 @@ def _choose_canonical(a: ExtractedEntityNode, b: ExtractedEntityNode) -> int: #
|
|||||||
# 2. 第二优先级:按“描述+事实摘要”的总长度排序(内容越长,信息越完整)
|
# 2. 第二优先级:按“描述+事实摘要”的总长度排序(内容越长,信息越完整)
|
||||||
desc_a = (getattr(a, "description", "") or "")
|
desc_a = (getattr(a, "description", "") or "")
|
||||||
desc_b = (getattr(b, "description", "") or "")
|
desc_b = (getattr(b, "description", "") or "")
|
||||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
fact_a = (getattr(a, "fact_summary", "") or "")
|
||||||
# fact_a = (getattr(a, "fact_summary", "") or "")
|
fact_b = (getattr(b, "fact_summary", "") or "")
|
||||||
# fact_b = (getattr(b, "fact_summary", "") or "")
|
score_a = len(desc_a) + len(fact_a)
|
||||||
# score_a = len(desc_a) + len(fact_a)
|
score_b = len(desc_b) + len(fact_b)
|
||||||
# score_b = len(desc_b) + len(fact_b)
|
|
||||||
score_a = len(desc_a)
|
|
||||||
score_b = len(desc_b)
|
|
||||||
if score_a != score_b:
|
if score_a != score_b:
|
||||||
return 0 if score_a >= score_b else 1
|
return 0 if score_a >= score_b else 1
|
||||||
return 0
|
return 0
|
||||||
@@ -192,8 +189,7 @@ async def _judge_pair(
|
|||||||
"entity_type": getattr(a, "entity_type", None),
|
"entity_type": getattr(a, "entity_type", None),
|
||||||
"description": getattr(a, "description", None),
|
"description": getattr(a, "description", None),
|
||||||
"aliases": getattr(a, "aliases", None) or [],
|
"aliases": getattr(a, "aliases", None) or [],
|
||||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
"fact_summary": getattr(a, "fact_summary", None),
|
||||||
# "fact_summary": getattr(a, "fact_summary", None),
|
|
||||||
"connect_strength": getattr(a, "connect_strength", None),
|
"connect_strength": getattr(a, "connect_strength", None),
|
||||||
}
|
}
|
||||||
entity_b = {
|
entity_b = {
|
||||||
@@ -201,8 +197,7 @@ async def _judge_pair(
|
|||||||
"entity_type": getattr(b, "entity_type", None),
|
"entity_type": getattr(b, "entity_type", None),
|
||||||
"description": getattr(b, "description", None),
|
"description": getattr(b, "description", None),
|
||||||
"aliases": getattr(b, "aliases", None) or [],
|
"aliases": getattr(b, "aliases", None) or [],
|
||||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
"fact_summary": getattr(b, "fact_summary", None),
|
||||||
# "fact_summary": getattr(b, "fact_summary", None),
|
|
||||||
"connect_strength": getattr(b, "connect_strength", None),
|
"connect_strength": getattr(b, "connect_strength", None),
|
||||||
}
|
}
|
||||||
# 5. 渲染LLM提示词(用工具函数填充模板,包含实体信息、上下文、输出格式)
|
# 5. 渲染LLM提示词(用工具函数填充模板,包含实体信息、上下文、输出格式)
|
||||||
@@ -253,8 +248,7 @@ async def _judge_pair_disamb(
|
|||||||
"entity_type": getattr(a, "entity_type", None),
|
"entity_type": getattr(a, "entity_type", None),
|
||||||
"description": getattr(a, "description", None),
|
"description": getattr(a, "description", None),
|
||||||
"aliases": getattr(a, "aliases", None) or [],
|
"aliases": getattr(a, "aliases", None) or [],
|
||||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
"fact_summary": getattr(a, "fact_summary", None),
|
||||||
# "fact_summary": getattr(a, "fact_summary", None),
|
|
||||||
"connect_strength": getattr(a, "connect_strength", None),
|
"connect_strength": getattr(a, "connect_strength", None),
|
||||||
}
|
}
|
||||||
entity_b = {
|
entity_b = {
|
||||||
@@ -262,8 +256,7 @@ async def _judge_pair_disamb(
|
|||||||
"entity_type": getattr(b, "entity_type", None),
|
"entity_type": getattr(b, "entity_type", None),
|
||||||
"description": getattr(b, "description", None),
|
"description": getattr(b, "description", None),
|
||||||
"aliases": getattr(b, "aliases", None) or [],
|
"aliases": getattr(b, "aliases", None) or [],
|
||||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
"fact_summary": getattr(b, "fact_summary", None),
|
||||||
# "fact_summary": getattr(b, "fact_summary", None),
|
|
||||||
"connect_strength": getattr(b, "connect_strength", None),
|
"connect_strength": getattr(b, "connect_strength", None),
|
||||||
}
|
}
|
||||||
prompt = render_entity_dedup_prompt(
|
prompt = render_entity_dedup_prompt(
|
||||||
|
|||||||
@@ -72,8 +72,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
|
|||||||
description=row.get("description") or "",
|
description=row.get("description") or "",
|
||||||
aliases=row.get("aliases") or [],
|
aliases=row.get("aliases") or [],
|
||||||
name_embedding=row.get("name_embedding") or [],
|
name_embedding=row.get("name_embedding") or [],
|
||||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
fact_summary=row.get("fact_summary") or "",
|
||||||
# fact_summary=row.get("fact_summary") or "",
|
|
||||||
connect_strength=row.get("connect_strength") or "",
|
connect_strength=row.get("connect_strength") or "",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1085,8 +1085,7 @@ class ExtractionOrchestrator:
|
|||||||
entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type
|
entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type
|
||||||
description=getattr(entity, 'description', ''), # 添加必需的 description 字段
|
description=getattr(entity, 'description', ''), # 添加必需的 description 字段
|
||||||
example=getattr(entity, 'example', ''), # 新增:传递示例字段
|
example=getattr(entity, 'example', ''), # 新增:传递示例字段
|
||||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
||||||
# fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
|
||||||
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
||||||
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
||||||
name_embedding=getattr(entity, 'name_embedding', None),
|
name_embedding=getattr(entity, 'name_embedding', None),
|
||||||
|
|||||||
@@ -8,5 +8,4 @@
|
|||||||
- TemporalExtractor: 时间信息提取
|
- TemporalExtractor: 时间信息提取
|
||||||
- EmbeddingGenerator: 嵌入向量生成
|
- EmbeddingGenerator: 嵌入向量生成
|
||||||
- MemorySummaryGenerator: 记忆摘要生成
|
- MemorySummaryGenerator: 记忆摘要生成
|
||||||
- OntologyExtractor: 本体类提取
|
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -14,34 +14,6 @@ from pydantic import Field
|
|||||||
|
|
||||||
logger = get_memory_logger(__name__)
|
logger = get_memory_logger(__name__)
|
||||||
|
|
||||||
# 支持的语言列表和默认回退值
|
|
||||||
SUPPORTED_LANGUAGES = {"zh", "en"}
|
|
||||||
FALLBACK_LANGUAGE = "en"
|
|
||||||
|
|
||||||
|
|
||||||
def validate_language(language: Optional[str]) -> str:
|
|
||||||
"""
|
|
||||||
校验语言参数,确保其为有效值。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
language: 待校验的语言代码
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
有效的语言代码("zh" 或 "en")
|
|
||||||
"""
|
|
||||||
if language is None:
|
|
||||||
return FALLBACK_LANGUAGE
|
|
||||||
|
|
||||||
lang = str(language).lower().strip()
|
|
||||||
if lang in SUPPORTED_LANGUAGES:
|
|
||||||
return lang
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
f"无效的语言参数 '{language}',已回退到默认值 '{FALLBACK_LANGUAGE}'。"
|
|
||||||
f"支持的语言: {SUPPORTED_LANGUAGES}"
|
|
||||||
)
|
|
||||||
return FALLBACK_LANGUAGE
|
|
||||||
|
|
||||||
|
|
||||||
class MemorySummaryResponse(RobustLLMResponse):
|
class MemorySummaryResponse(RobustLLMResponse):
|
||||||
"""Structured response for summary generation per chunk.
|
"""Structured response for summary generation per chunk.
|
||||||
@@ -59,8 +31,7 @@ class MemorySummaryResponse(RobustLLMResponse):
|
|||||||
|
|
||||||
async def generate_title_and_type_for_summary(
|
async def generate_title_and_type_for_summary(
|
||||||
content: str,
|
content: str,
|
||||||
llm_client,
|
llm_client
|
||||||
language: str = None
|
|
||||||
) -> Tuple[str, str]:
|
) -> Tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
为MemorySummary生成标题和类型
|
为MemorySummary生成标题和类型
|
||||||
@@ -70,18 +41,11 @@ async def generate_title_and_type_for_summary(
|
|||||||
Args:
|
Args:
|
||||||
content: Summary的内容文本
|
content: Summary的内容文本
|
||||||
llm_client: LLM客户端实例
|
llm_client: LLM客户端实例
|
||||||
language: 生成标题使用的语言 ("zh" 中文, "en" 英文),如果为None则从配置读取
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(标题, 类型)元组
|
(标题, 类型)元组
|
||||||
"""
|
"""
|
||||||
from app.core.memory.utils.prompt.prompt_utils import render_episodic_title_and_type_prompt
|
from app.core.memory.utils.prompt.prompt_utils import render_episodic_title_and_type_prompt
|
||||||
from app.core.config import settings
|
|
||||||
|
|
||||||
# 如果没有指定语言,从配置中读取,并校验有效性
|
|
||||||
if language is None:
|
|
||||||
language = settings.DEFAULT_LANGUAGE
|
|
||||||
language = validate_language(language)
|
|
||||||
|
|
||||||
# 定义有效的类型集合
|
# 定义有效的类型集合
|
||||||
VALID_TYPES = {
|
VALID_TYPES = {
|
||||||
@@ -93,19 +57,13 @@ async def generate_title_and_type_for_summary(
|
|||||||
}
|
}
|
||||||
DEFAULT_TYPE = "conversation" # 默认类型
|
DEFAULT_TYPE = "conversation" # 默认类型
|
||||||
|
|
||||||
# 根据语言设置默认标题
|
|
||||||
DEFAULT_TITLE = "空内容" if language == "zh" else "Empty Content"
|
|
||||||
PARSE_ERROR_TITLE = "解析失败" if language == "zh" else "Parse Failed"
|
|
||||||
ERROR_TITLE = "错误" if language == "zh" else "Error"
|
|
||||||
UNKNOWN_TITLE = "未知标题" if language == "zh" else "Unknown Title"
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not content:
|
if not content:
|
||||||
logger.warning(f"content为空,无法生成标题和类型 (language={language})")
|
logger.warning("content为空,无法生成标题和类型")
|
||||||
return (DEFAULT_TITLE, DEFAULT_TYPE)
|
return ("空内容", DEFAULT_TYPE)
|
||||||
|
|
||||||
# 1. 渲染Jinja2提示词模板,传递语言参数
|
# 1. 渲染Jinja2提示词模板
|
||||||
prompt = await render_episodic_title_and_type_prompt(content, language=language)
|
prompt = await render_episodic_title_and_type_prompt(content)
|
||||||
|
|
||||||
# 2. 调用LLM生成标题和类型
|
# 2. 调用LLM生成标题和类型
|
||||||
messages = [
|
messages = [
|
||||||
@@ -144,7 +102,7 @@ async def generate_title_and_type_for_summary(
|
|||||||
json_str = json_str.strip()
|
json_str = json_str.strip()
|
||||||
|
|
||||||
result_data = json.loads(json_str)
|
result_data = json.loads(json_str)
|
||||||
title = result_data.get("title", UNKNOWN_TITLE)
|
title = result_data.get("title", "未知标题")
|
||||||
episodic_type_raw = result_data.get("type", DEFAULT_TYPE)
|
episodic_type_raw = result_data.get("type", DEFAULT_TYPE)
|
||||||
|
|
||||||
# 5. 校验和归一化类型
|
# 5. 校验和归一化类型
|
||||||
@@ -172,16 +130,16 @@ async def generate_title_and_type_for_summary(
|
|||||||
f"已归一化为 '{episodic_type}'"
|
f"已归一化为 '{episodic_type}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"成功生成标题和类型 (language={language}): title={title}, type={episodic_type}")
|
logger.info(f"成功生成标题和类型: title={title}, type={episodic_type}")
|
||||||
return (title, episodic_type)
|
return (title, episodic_type)
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.error(f"无法解析LLM响应为JSON (language={language}): {full_response}")
|
logger.error(f"无法解析LLM响应为JSON: {full_response}")
|
||||||
return (PARSE_ERROR_TITLE, DEFAULT_TYPE)
|
return ("解析失败", DEFAULT_TYPE)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"生成标题和类型时出错 (language={language}): {str(e)}", exc_info=True)
|
logger.error(f"生成标题和类型时出错: {str(e)}", exc_info=True)
|
||||||
return (ERROR_TITLE, DEFAULT_TYPE)
|
return ("错误", DEFAULT_TYPE)
|
||||||
|
|
||||||
async def _process_chunk_summary(
|
async def _process_chunk_summary(
|
||||||
dialog: DialogData,
|
dialog: DialogData,
|
||||||
@@ -195,16 +153,11 @@ async def _process_chunk_summary(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 从配置中获取语言设置(只获取一次,复用),并校验有效性
|
|
||||||
from app.core.config import settings
|
|
||||||
language = validate_language(settings.DEFAULT_LANGUAGE)
|
|
||||||
|
|
||||||
# Render prompt via Jinja2 for a single chunk
|
# Render prompt via Jinja2 for a single chunk
|
||||||
prompt_content = await render_memory_summary_prompt(
|
prompt_content = await render_memory_summary_prompt(
|
||||||
chunk_texts=chunk.content,
|
chunk_texts=chunk.content,
|
||||||
json_schema=MemorySummaryResponse.model_json_schema(),
|
json_schema=MemorySummaryResponse.model_json_schema(),
|
||||||
max_words=200,
|
max_words=200,
|
||||||
language=language,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
@@ -225,10 +178,9 @@ async def _process_chunk_summary(
|
|||||||
try:
|
try:
|
||||||
title, episodic_type = await generate_title_and_type_for_summary(
|
title, episodic_type = await generate_title_and_type_for_summary(
|
||||||
content=summary_text,
|
content=summary_text,
|
||||||
llm_client=llm_client,
|
llm_client=llm_client
|
||||||
language=language
|
|
||||||
)
|
)
|
||||||
logger.info(f"Generated title and type for MemorySummary (language={language}): title={title}, type={episodic_type}")
|
logger.info(f"Generated title and type for MemorySummary: title={title}, type={episodic_type}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to generate title and type for chunk {chunk.id}: {e}")
|
logger.warning(f"Failed to generate title and type for chunk {chunk.id}: {e}")
|
||||||
# Continue without title and type
|
# Continue without title and type
|
||||||
|
|||||||
@@ -1,482 +0,0 @@
|
|||||||
"""Ontology class extraction from scenario descriptions using LLM.
|
|
||||||
|
|
||||||
This module provides the OntologyExtractor class for extracting ontology classes
|
|
||||||
from natural language scenario descriptions. It uses LLM-driven extraction combined
|
|
||||||
with two-layer validation (string validation + OWL semantic validation).
|
|
||||||
|
|
||||||
Classes:
|
|
||||||
OntologyExtractor: Extracts ontology classes from scenario descriptions
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
|
||||||
from app.core.memory.models.ontology_models import (
|
|
||||||
OntologyClass,
|
|
||||||
OntologyExtractionResponse,
|
|
||||||
)
|
|
||||||
from app.core.memory.utils.validation.ontology_validator import OntologyValidator
|
|
||||||
from app.core.memory.utils.validation.owl_validator import OWLValidator
|
|
||||||
from app.core.memory.utils.prompt.prompt_utils import render_ontology_extraction_prompt
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class OntologyExtractor:
|
|
||||||
"""Extractor for ontology classes from scenario descriptions.
|
|
||||||
|
|
||||||
This extractor uses LLM to identify abstract classes and concepts from
|
|
||||||
natural language scenario descriptions, following OWL ontology engineering
|
|
||||||
standards. It performs two-layer validation:
|
|
||||||
1. String validation (naming conventions, reserved words, duplicates)
|
|
||||||
2. OWL semantic validation (consistency checking, circular inheritance)
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
llm_client: OpenAI client for LLM calls
|
|
||||||
validator: String validator for class names and descriptions
|
|
||||||
owl_validator: OWL validator for semantic validation
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, llm_client: OpenAIClient):
|
|
||||||
"""Initialize the OntologyExtractor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
llm_client: OpenAIClient instance for LLM processing
|
|
||||||
"""
|
|
||||||
self.llm_client = llm_client
|
|
||||||
self.validator = OntologyValidator()
|
|
||||||
self.owl_validator = OWLValidator()
|
|
||||||
|
|
||||||
logger.info("OntologyExtractor initialized")
|
|
||||||
|
|
||||||
async def extract_ontology_classes(
|
|
||||||
self,
|
|
||||||
scenario: str,
|
|
||||||
domain: Optional[str] = None,
|
|
||||||
max_classes: int = 15,
|
|
||||||
min_classes: int = 5,
|
|
||||||
enable_owl_validation: bool = True,
|
|
||||||
llm_temperature: float = 0.3,
|
|
||||||
llm_max_tokens: int = 2000,
|
|
||||||
max_description_length: int = 500,
|
|
||||||
timeout: Optional[float] = None,
|
|
||||||
) -> OntologyExtractionResponse:
|
|
||||||
"""Extract ontology classes from a scenario description.
|
|
||||||
|
|
||||||
This is the main extraction method that orchestrates the entire process:
|
|
||||||
1. Call LLM to extract ontology classes
|
|
||||||
2. Perform first-layer validation (string validation and cleaning)
|
|
||||||
3. Perform second-layer validation (OWL semantic validation)
|
|
||||||
4. Filter invalid classes based on validation errors
|
|
||||||
5. Return validated ontology classes
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scenario: Natural language scenario description
|
|
||||||
domain: Optional domain hint (e.g., "Healthcare", "Education")
|
|
||||||
max_classes: Maximum number of classes to extract (default: 15)
|
|
||||||
min_classes: Minimum number of classes to extract (default: 5)
|
|
||||||
enable_owl_validation: Whether to enable OWL validation (default: True)
|
|
||||||
llm_temperature: LLM temperature parameter (default: 0.3)
|
|
||||||
llm_max_tokens: LLM max tokens parameter (default: 2000)
|
|
||||||
max_description_length: Maximum description length (default: 500)
|
|
||||||
timeout: Optional timeout in seconds for LLM call (default: None, no timeout)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
OntologyExtractionResponse containing validated ontology classes
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If scenario is empty or invalid
|
|
||||||
asyncio.TimeoutError: If extraction times out
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> extractor = OntologyExtractor(llm_client)
|
|
||||||
>>> response = await extractor.extract_ontology_classes(
|
|
||||||
... scenario="A hospital manages patient records...",
|
|
||||||
... domain="Healthcare",
|
|
||||||
... max_classes=10,
|
|
||||||
... timeout=30.0
|
|
||||||
... )
|
|
||||||
>>> len(response.classes)
|
|
||||||
7
|
|
||||||
"""
|
|
||||||
# Start timing
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
# Validate input
|
|
||||||
if not scenario or not scenario.strip():
|
|
||||||
logger.error("Scenario description is empty")
|
|
||||||
raise ValueError("Scenario description cannot be empty")
|
|
||||||
|
|
||||||
scenario = scenario.strip()
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Starting ontology extraction - scenario_length={len(scenario)}, "
|
|
||||||
f"domain={domain}, max_classes={max_classes}, min_classes={min_classes}, "
|
|
||||||
f"timeout={timeout}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Step 1: Call LLM for extraction with timeout
|
|
||||||
logger.info("Step 1: Calling LLM for ontology extraction")
|
|
||||||
llm_start_time = time.time()
|
|
||||||
|
|
||||||
if timeout is not None:
|
|
||||||
# Wrap LLM call with timeout
|
|
||||||
try:
|
|
||||||
response = await asyncio.wait_for(
|
|
||||||
self._call_llm_for_extraction(
|
|
||||||
scenario=scenario,
|
|
||||||
domain=domain,
|
|
||||||
max_classes=max_classes,
|
|
||||||
llm_temperature=llm_temperature,
|
|
||||||
llm_max_tokens=llm_max_tokens,
|
|
||||||
),
|
|
||||||
timeout=timeout
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
llm_duration = time.time() - llm_start_time
|
|
||||||
logger.error(
|
|
||||||
f"LLM extraction timed out after {timeout} seconds "
|
|
||||||
f"(actual duration: {llm_duration:.2f}s)"
|
|
||||||
)
|
|
||||||
# Return empty response on timeout
|
|
||||||
return OntologyExtractionResponse(
|
|
||||||
classes=[],
|
|
||||||
domain=domain or "Unknown",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# No timeout specified, call directly
|
|
||||||
response = await self._call_llm_for_extraction(
|
|
||||||
scenario=scenario,
|
|
||||||
domain=domain,
|
|
||||||
max_classes=max_classes,
|
|
||||||
llm_temperature=llm_temperature,
|
|
||||||
llm_max_tokens=llm_max_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
llm_duration = time.time() - llm_start_time
|
|
||||||
logger.info(
|
|
||||||
f"LLM returned {len(response.classes)} classes in {llm_duration:.2f}s"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Step 2: First-layer validation (string validation and cleaning)
|
|
||||||
logger.info("Step 2: Performing first-layer validation (string validation)")
|
|
||||||
validation_start_time = time.time()
|
|
||||||
|
|
||||||
response = self._validate_and_clean(
|
|
||||||
response=response,
|
|
||||||
max_description_length=max_description_length,
|
|
||||||
)
|
|
||||||
|
|
||||||
validation_duration = time.time() - validation_start_time
|
|
||||||
logger.info(
|
|
||||||
f"After first-layer validation: {len(response.classes)} classes remain "
|
|
||||||
f"(validation took {validation_duration:.2f}s)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if we have enough classes after first-layer validation
|
|
||||||
if len(response.classes) < min_classes:
|
|
||||||
logger.warning(
|
|
||||||
f"Only {len(response.classes)} classes remain after validation, "
|
|
||||||
f"which is below minimum of {min_classes}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Step 3: Second-layer validation (OWL semantic validation)
|
|
||||||
if enable_owl_validation and response.classes:
|
|
||||||
logger.info("Step 3: Performing second-layer validation (OWL validation)")
|
|
||||||
owl_start_time = time.time()
|
|
||||||
|
|
||||||
is_valid, errors, world = self.owl_validator.validate_ontology_classes(
|
|
||||||
classes=response.classes,
|
|
||||||
)
|
|
||||||
|
|
||||||
owl_duration = time.time() - owl_start_time
|
|
||||||
|
|
||||||
if not is_valid:
|
|
||||||
logger.warning(
|
|
||||||
f"OWL validation found {len(errors)} issues in {owl_duration:.2f}s: {errors}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Filter invalid classes based on errors
|
|
||||||
response = self._filter_invalid_classes(
|
|
||||||
response=response,
|
|
||||||
errors=errors,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"After second-layer validation: {len(response.classes)} classes remain"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info(f"OWL validation passed successfully in {owl_duration:.2f}s")
|
|
||||||
else:
|
|
||||||
if not enable_owl_validation:
|
|
||||||
logger.info("Step 3: OWL validation disabled, skipping")
|
|
||||||
else:
|
|
||||||
logger.info("Step 3: No classes to validate, skipping OWL validation")
|
|
||||||
|
|
||||||
# Calculate total duration
|
|
||||||
total_duration = time.time() - start_time
|
|
||||||
|
|
||||||
# Log extraction statistics
|
|
||||||
logger.info(
|
|
||||||
f"Ontology extraction completed - "
|
|
||||||
f"final_class_count={len(response.classes)}, "
|
|
||||||
f"domain={response.domain}, "
|
|
||||||
f"total_duration={total_duration:.2f}s, "
|
|
||||||
f"llm_duration={llm_duration:.2f}s"
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
# Re-raise timeout errors
|
|
||||||
total_duration = time.time() - start_time
|
|
||||||
logger.error(
|
|
||||||
f"Ontology extraction timed out after {timeout} seconds "
|
|
||||||
f"(total duration: {total_duration:.2f}s)",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
total_duration = time.time() - start_time
|
|
||||||
logger.error(
|
|
||||||
f"Ontology extraction failed after {total_duration:.2f}s: {str(e)}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
# Return empty response on failure
|
|
||||||
return OntologyExtractionResponse(
|
|
||||||
classes=[],
|
|
||||||
domain=domain or "Unknown",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _call_llm_for_extraction(
|
|
||||||
self,
|
|
||||||
scenario: str,
|
|
||||||
domain: Optional[str],
|
|
||||||
max_classes: int,
|
|
||||||
llm_temperature: float,
|
|
||||||
llm_max_tokens: int,
|
|
||||||
) -> OntologyExtractionResponse:
|
|
||||||
"""Call LLM to extract ontology classes from scenario.
|
|
||||||
|
|
||||||
This method renders the extraction prompt using the Jinja2 template
|
|
||||||
and calls the LLM with structured output to get ontology classes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scenario: Scenario description text
|
|
||||||
domain: Optional domain hint
|
|
||||||
max_classes: Maximum number of classes to extract
|
|
||||||
llm_temperature: LLM temperature parameter
|
|
||||||
llm_max_tokens: LLM max tokens parameter
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
OntologyExtractionResponse from LLM
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
Exception: If LLM call fails
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Render prompt using template
|
|
||||||
prompt_content = await render_ontology_extraction_prompt(
|
|
||||||
scenario=scenario,
|
|
||||||
domain=domain,
|
|
||||||
max_classes=max_classes,
|
|
||||||
json_schema=OntologyExtractionResponse.model_json_schema(),
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"Rendered prompt length: {len(prompt_content)}")
|
|
||||||
|
|
||||||
# Create messages for LLM
|
|
||||||
messages = [
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": (
|
|
||||||
"You are an expert ontology engineer specializing in knowledge "
|
|
||||||
"representation and OWL standards. Extract ontology classes from "
|
|
||||||
"scenario descriptions following the provided instructions. "
|
|
||||||
"Return valid JSON conforming to the schema."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": prompt_content,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
# Call LLM with structured output
|
|
||||||
logger.debug(
|
|
||||||
f"Calling LLM with temperature={llm_temperature}, "
|
|
||||||
f"max_tokens={llm_max_tokens}"
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await self.llm_client.response_structured(
|
|
||||||
messages=messages,
|
|
||||||
response_model=OntologyExtractionResponse,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"LLM extraction successful - extracted {len(response.classes)} classes"
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"LLM extraction failed: {str(e)}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _validate_and_clean(
|
|
||||||
self,
|
|
||||||
response: OntologyExtractionResponse,
|
|
||||||
max_description_length: int,
|
|
||||||
) -> OntologyExtractionResponse:
|
|
||||||
"""Perform first-layer validation: string validation and cleaning.
|
|
||||||
|
|
||||||
This method validates and cleans the extracted ontology classes:
|
|
||||||
1. Validate class names (PascalCase, no reserved words)
|
|
||||||
2. Sanitize invalid class names
|
|
||||||
3. Truncate long descriptions
|
|
||||||
4. Remove duplicate classes
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response: OntologyExtractionResponse from LLM
|
|
||||||
max_description_length: Maximum description length
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Cleaned OntologyExtractionResponse
|
|
||||||
"""
|
|
||||||
if not response.classes:
|
|
||||||
logger.debug("No classes to validate")
|
|
||||||
return response
|
|
||||||
|
|
||||||
logger.debug(f"Validating {len(response.classes)} classes")
|
|
||||||
|
|
||||||
validated_classes = []
|
|
||||||
|
|
||||||
for ontology_class in response.classes:
|
|
||||||
# Validate class name
|
|
||||||
is_valid, error_msg = self.validator.validate_class_name(
|
|
||||||
ontology_class.name
|
|
||||||
)
|
|
||||||
|
|
||||||
if not is_valid:
|
|
||||||
logger.warning(
|
|
||||||
f"Invalid class name '{ontology_class.name}': {error_msg}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Attempt to sanitize
|
|
||||||
sanitized_name = self.validator.sanitize_class_name(
|
|
||||||
ontology_class.name
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Sanitized class name: '{ontology_class.name}' -> '{sanitized_name}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update class name
|
|
||||||
ontology_class.name = sanitized_name
|
|
||||||
|
|
||||||
# Re-validate sanitized name
|
|
||||||
is_valid, error_msg = self.validator.validate_class_name(
|
|
||||||
sanitized_name
|
|
||||||
)
|
|
||||||
|
|
||||||
if not is_valid:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to sanitize class name '{ontology_class.name}': {error_msg}. "
|
|
||||||
"Skipping this class."
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Truncate description if too long
|
|
||||||
if ontology_class.description:
|
|
||||||
original_length = len(ontology_class.description)
|
|
||||||
ontology_class.description = self.validator.truncate_description(
|
|
||||||
ontology_class.description,
|
|
||||||
max_length=max_description_length,
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(ontology_class.description) < original_length:
|
|
||||||
logger.debug(
|
|
||||||
f"Truncated description for '{ontology_class.name}': "
|
|
||||||
f"{original_length} -> {len(ontology_class.description)} chars"
|
|
||||||
)
|
|
||||||
|
|
||||||
validated_classes.append(ontology_class)
|
|
||||||
|
|
||||||
# Remove duplicates (case-insensitive)
|
|
||||||
original_count = len(validated_classes)
|
|
||||||
validated_classes = self.validator.remove_duplicates(validated_classes)
|
|
||||||
|
|
||||||
if len(validated_classes) < original_count:
|
|
||||||
logger.info(
|
|
||||||
f"Removed {original_count - len(validated_classes)} duplicate classes"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Return cleaned response
|
|
||||||
return OntologyExtractionResponse(
|
|
||||||
classes=validated_classes,
|
|
||||||
domain=response.domain,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _filter_invalid_classes(
|
|
||||||
self,
|
|
||||||
response: OntologyExtractionResponse,
|
|
||||||
errors: List[str],
|
|
||||||
) -> OntologyExtractionResponse:
|
|
||||||
"""Filter invalid classes based on OWL validation errors.
|
|
||||||
|
|
||||||
This method analyzes OWL validation errors and removes classes
|
|
||||||
that caused validation failures (e.g., circular inheritance,
|
|
||||||
inconsistencies).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response: OntologyExtractionResponse to filter
|
|
||||||
errors: List of error messages from OWL validation
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Filtered OntologyExtractionResponse
|
|
||||||
"""
|
|
||||||
if not errors:
|
|
||||||
return response
|
|
||||||
|
|
||||||
logger.debug(f"Filtering classes based on {len(errors)} OWL validation errors")
|
|
||||||
|
|
||||||
# Extract class names mentioned in errors
|
|
||||||
invalid_class_names = set()
|
|
||||||
|
|
||||||
for error in errors:
|
|
||||||
# Look for class names in error messages
|
|
||||||
for ontology_class in response.classes:
|
|
||||||
if ontology_class.name in error:
|
|
||||||
invalid_class_names.add(ontology_class.name)
|
|
||||||
logger.debug(
|
|
||||||
f"Class '{ontology_class.name}' marked as invalid due to error: {error}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Filter out invalid classes
|
|
||||||
if invalid_class_names:
|
|
||||||
original_count = len(response.classes)
|
|
||||||
|
|
||||||
filtered_classes = [
|
|
||||||
c for c in response.classes
|
|
||||||
if c.name not in invalid_class_names
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Filtered out {original_count - len(filtered_classes)} invalid classes: "
|
|
||||||
f"{invalid_class_names}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return OntologyExtractionResponse(
|
|
||||||
classes=filtered_classes,
|
|
||||||
domain=response.domain,
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
@@ -25,15 +25,6 @@ class TripletExtractor:
|
|||||||
"""
|
"""
|
||||||
self.llm_client = llm_client
|
self.llm_client = llm_client
|
||||||
|
|
||||||
def _get_language(self) -> str:
|
|
||||||
"""Get the configured language for entity descriptions
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Language code ("zh" or "en")
|
|
||||||
"""
|
|
||||||
from app.core.config import settings
|
|
||||||
return settings.DEFAULT_LANGUAGE
|
|
||||||
|
|
||||||
async def _extract_triplets(self, statement: Statement, chunk_content: str) -> TripletExtractionResponse:
|
async def _extract_triplets(self, statement: Statement, chunk_content: str) -> TripletExtractionResponse:
|
||||||
"""Process a single statement and return extracted triplets and entities"""
|
"""Process a single statement and return extracted triplets and entities"""
|
||||||
# Render the prompt using helper function
|
# Render the prompt using helper function
|
||||||
@@ -49,8 +40,7 @@ class TripletExtractor:
|
|||||||
statement=statement.statement,
|
statement=statement.statement,
|
||||||
chunk_content=chunk_content,
|
chunk_content=chunk_content,
|
||||||
json_schema=TripletExtractionResponse.model_json_schema(),
|
json_schema=TripletExtractionResponse.model_json_schema(),
|
||||||
predicate_instructions=PREDICATE_DEFINITIONS,
|
predicate_instructions=PREDICATE_DEFINITIONS
|
||||||
language=self._get_language()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create messages for LLM
|
# Create messages for LLM
|
||||||
|
|||||||
@@ -296,9 +296,7 @@ def resolve_alias_cycles(entities: List[Any], cycles: Dict[str, Set[str]]) -> Li
|
|||||||
key=lambda eid: (
|
key=lambda eid: (
|
||||||
_strength_rank(eid),
|
_strength_rank(eid),
|
||||||
len(getattr(entity_by_id.get(eid), 'description', '') or ''),
|
len(getattr(entity_by_id.get(eid), 'description', '') or ''),
|
||||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '')
|
||||||
# len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '')
|
|
||||||
0 # 临时占位
|
|
||||||
),
|
),
|
||||||
reverse=True
|
reverse=True
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -177,7 +177,7 @@ def render_entity_dedup_prompt(
|
|||||||
|
|
||||||
# Args:
|
# Args:
|
||||||
# entity_a: Dict of entity A attributes
|
# entity_a: Dict of entity A attributes
|
||||||
async def render_triplet_extraction_prompt(statement: str, chunk_content: str, json_schema: dict, predicate_instructions: dict = None, language: str = "zh") -> str:
|
async def render_triplet_extraction_prompt(statement: str, chunk_content: str, json_schema: dict, predicate_instructions: dict = None) -> str:
|
||||||
"""
|
"""
|
||||||
Renders the triplet extraction prompt using the extract_triplet.jinja2 template.
|
Renders the triplet extraction prompt using the extract_triplet.jinja2 template.
|
||||||
|
|
||||||
@@ -186,7 +186,6 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
|
|||||||
chunk_content: The content of the chunk to process
|
chunk_content: The content of the chunk to process
|
||||||
json_schema: JSON schema for the expected output format
|
json_schema: JSON schema for the expected output format
|
||||||
predicate_instructions: Optional predicate instructions
|
predicate_instructions: Optional predicate instructions
|
||||||
language: The language to use for entity descriptions ("zh" for Chinese, "en" for English)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Rendered prompt content as string
|
Rendered prompt content as string
|
||||||
@@ -196,8 +195,7 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
|
|||||||
statement=statement,
|
statement=statement,
|
||||||
chunk_content=chunk_content,
|
chunk_content=chunk_content,
|
||||||
json_schema=json_schema,
|
json_schema=json_schema,
|
||||||
predicate_instructions=predicate_instructions,
|
predicate_instructions=predicate_instructions
|
||||||
language=language
|
|
||||||
)
|
)
|
||||||
# 记录渲染结果到提示日志(与示例日志结构一致)
|
# 记录渲染结果到提示日志(与示例日志结构一致)
|
||||||
log_prompt_rendering('triplet extraction', rendered_prompt)
|
log_prompt_rendering('triplet extraction', rendered_prompt)
|
||||||
@@ -206,8 +204,7 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
|
|||||||
'statement': 'str',
|
'statement': 'str',
|
||||||
'chunk_content': 'str',
|
'chunk_content': 'str',
|
||||||
'json_schema': 'TripletExtractionResponse.schema',
|
'json_schema': 'TripletExtractionResponse.schema',
|
||||||
'predicate_instructions': 'PREDICATE_DEFINITIONS',
|
'predicate_instructions': 'PREDICATE_DEFINITIONS'
|
||||||
'language': language
|
|
||||||
})
|
})
|
||||||
|
|
||||||
return rendered_prompt
|
return rendered_prompt
|
||||||
@@ -216,7 +213,6 @@ async def render_memory_summary_prompt(
|
|||||||
chunk_texts: str,
|
chunk_texts: str,
|
||||||
json_schema: dict,
|
json_schema: dict,
|
||||||
max_words: int = 200,
|
max_words: int = 200,
|
||||||
language: str = "zh",
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Renders the memory summary prompt using the memory_summary.jinja2 template.
|
Renders the memory summary prompt using the memory_summary.jinja2 template.
|
||||||
@@ -225,7 +221,6 @@ async def render_memory_summary_prompt(
|
|||||||
chunk_texts: Concatenated text of conversation chunks
|
chunk_texts: Concatenated text of conversation chunks
|
||||||
json_schema: JSON schema for the expected output format
|
json_schema: JSON schema for the expected output format
|
||||||
max_words: Maximum words for the summary
|
max_words: Maximum words for the summary
|
||||||
language: The language to use for summary generation ("zh" for Chinese, "en" for English)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Rendered prompt content as string.
|
Rendered prompt content as string.
|
||||||
@@ -235,14 +230,12 @@ async def render_memory_summary_prompt(
|
|||||||
chunk_texts=chunk_texts,
|
chunk_texts=chunk_texts,
|
||||||
json_schema=json_schema,
|
json_schema=json_schema,
|
||||||
max_words=max_words,
|
max_words=max_words,
|
||||||
language=language,
|
|
||||||
)
|
)
|
||||||
log_prompt_rendering('memory summary', rendered_prompt)
|
log_prompt_rendering('memory summary', rendered_prompt)
|
||||||
log_template_rendering('memory_summary.jinja2', {
|
log_template_rendering('memory_summary.jinja2', {
|
||||||
'chunk_texts_len': len(chunk_texts or ""),
|
'chunk_texts_len': len(chunk_texts or ""),
|
||||||
'max_words': max_words,
|
'max_words': max_words,
|
||||||
'json_schema': 'MemorySummaryResponse.schema',
|
'json_schema': 'MemorySummaryResponse.schema'
|
||||||
'language': language
|
|
||||||
})
|
})
|
||||||
return rendered_prompt
|
return rendered_prompt
|
||||||
|
|
||||||
@@ -395,65 +388,24 @@ async def render_memory_insight_prompt(
|
|||||||
return rendered_prompt
|
return rendered_prompt
|
||||||
|
|
||||||
|
|
||||||
async def render_episodic_title_and_type_prompt(content: str, language: str = "zh") -> str:
|
async def render_episodic_title_and_type_prompt(content: str) -> str:
|
||||||
"""
|
"""
|
||||||
Renders the episodic title and type classification prompt using the episodic_type_classification.jinja2 template.
|
Renders the episodic title and type classification prompt using the episodic_type_classification.jinja2 template.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: The content of the episodic memory summary to analyze
|
content: The content of the episodic memory summary to analyze
|
||||||
language: The language to use for title generation ("zh" for Chinese, "en" for English)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Rendered prompt content as string
|
Rendered prompt content as string
|
||||||
"""
|
"""
|
||||||
template = prompt_env.get_template("episodic_type_classification.jinja2")
|
template = prompt_env.get_template("episodic_type_classification.jinja2")
|
||||||
rendered_prompt = template.render(content=content, language=language)
|
rendered_prompt = template.render(content=content)
|
||||||
|
|
||||||
# 记录渲染结果到提示日志
|
# 记录渲染结果到提示日志
|
||||||
log_prompt_rendering('episodic title and type classification', rendered_prompt)
|
log_prompt_rendering('episodic title and type classification', rendered_prompt)
|
||||||
# 可选:记录模板渲染信息
|
# 可选:记录模板渲染信息
|
||||||
log_template_rendering('episodic_type_classification.jinja2', {
|
log_template_rendering('episodic_type_classification.jinja2', {
|
||||||
'content_len': len(content) if content else 0,
|
'content_len': len(content) if content else 0
|
||||||
'language': language
|
|
||||||
})
|
|
||||||
|
|
||||||
return rendered_prompt
|
|
||||||
|
|
||||||
|
|
||||||
async def render_ontology_extraction_prompt(
|
|
||||||
scenario: str,
|
|
||||||
domain: str | None = None,
|
|
||||||
max_classes: int = 15,
|
|
||||||
json_schema: dict | None = None
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Renders the ontology extraction prompt using the extract_ontology.jinja2 template.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scenario: The scenario description text to extract ontology classes from
|
|
||||||
domain: Optional domain hint for the scenario (e.g., "Healthcare", "Education")
|
|
||||||
max_classes: Maximum number of classes to extract (default: 15)
|
|
||||||
json_schema: JSON schema for the expected output format
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Rendered prompt content as string
|
|
||||||
"""
|
|
||||||
template = prompt_env.get_template("extract_ontology.jinja2")
|
|
||||||
rendered_prompt = template.render(
|
|
||||||
scenario=scenario,
|
|
||||||
domain=domain,
|
|
||||||
max_classes=max_classes,
|
|
||||||
json_schema=json_schema
|
|
||||||
)
|
|
||||||
|
|
||||||
# 记录渲染结果到提示日志
|
|
||||||
log_prompt_rendering('ontology extraction', rendered_prompt)
|
|
||||||
# 可选:记录模板渲染信息
|
|
||||||
log_template_rendering('extract_ontology.jinja2', {
|
|
||||||
'scenario_len': len(scenario) if scenario else 0,
|
|
||||||
'domain': domain,
|
|
||||||
'max_classes': max_classes,
|
|
||||||
'json_schema': 'OntologyExtractionResponse.schema'
|
|
||||||
})
|
})
|
||||||
|
|
||||||
return rendered_prompt
|
return rendered_prompt
|
||||||
|
|||||||
@@ -9,8 +9,7 @@
|
|||||||
- 类型: "{{ entity_a.entity_type | default('') }}"
|
- 类型: "{{ entity_a.entity_type | default('') }}"
|
||||||
- 描述: "{{ entity_a.description | default('') }}"
|
- 描述: "{{ entity_a.description | default('') }}"
|
||||||
- 别名: {{ entity_a.aliases | default([]) }}
|
- 别名: {{ entity_a.aliases | default([]) }}
|
||||||
{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #}
|
- 摘要: "{{ entity_a.fact_summary | default('') }}"
|
||||||
{# - 摘要: "{{ entity_a.fact_summary | default('') }}" #}
|
|
||||||
- 连接强弱: "{{ entity_a.connect_strength | default('') }}"
|
- 连接强弱: "{{ entity_a.connect_strength | default('') }}"
|
||||||
|
|
||||||
实体B:
|
实体B:
|
||||||
@@ -18,8 +17,7 @@
|
|||||||
- 类型: "{{ entity_b.entity_type | default('') }}"
|
- 类型: "{{ entity_b.entity_type | default('') }}"
|
||||||
- 描述: "{{ entity_b.description | default('') }}"
|
- 描述: "{{ entity_b.description | default('') }}"
|
||||||
- 别名: {{ entity_b.aliases | default([]) }}
|
- 别名: {{ entity_b.aliases | default([]) }}
|
||||||
{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #}
|
- 摘要: "{{ entity_b.fact_summary | default('') }}"
|
||||||
{# - 摘要: "{{ entity_b.fact_summary | default('') }}" #}
|
|
||||||
- 连接强弱: "{{ entity_b.connect_strength | default('') }}"
|
- 连接强弱: "{{ entity_b.connect_strength | default('') }}"
|
||||||
|
|
||||||
上下文:
|
上下文:
|
||||||
|
|||||||
@@ -1,19 +1,8 @@
|
|||||||
=== Task ===
|
=== Task ===
|
||||||
Generate a concise title and classify the episodic memory into the most appropriate category.
|
Generate a concise title and classify the episodic memory into the most appropriate category.
|
||||||
|
|
||||||
{% if language == "zh" %}
|
|
||||||
**重要:请使用中文生成标题和分类。**
|
|
||||||
{% else %}
|
|
||||||
**Important: Please generate the title and classification in English.**
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
=== Requirements ===
|
=== Requirements ===
|
||||||
- Extract a clear, concise title (10-20 characters) that captures the core content
|
- Extract a clear, concise title (10-20 characters) that captures the core content
|
||||||
{% if language == "zh" %}
|
|
||||||
- 标题必须使用中文
|
|
||||||
{% else %}
|
|
||||||
- Title must be in English
|
|
||||||
{% endif %}
|
|
||||||
- Classify into exactly one category based on the primary theme
|
- Classify into exactly one category based on the primary theme
|
||||||
- Be specific and avoid ambiguity
|
- Be specific and avoid ambiguity
|
||||||
- Output must be valid JSON conforming to the schema below
|
- Output must be valid JSON conforming to the schema below
|
||||||
|
|||||||
@@ -1,210 +0,0 @@
|
|||||||
===Task===
|
|
||||||
Extract ontology classes from the given scenario description following ontology engineering standards.
|
|
||||||
|
|
||||||
===Role===
|
|
||||||
You are a professional ontology engineer with expertise in knowledge representation and OWL (Web Ontology Language) standards. Your task is to identify abstract classes and concepts from scenario descriptions, not concrete instances.
|
|
||||||
|
|
||||||
===Scenario Description===
|
|
||||||
{{ scenario }}
|
|
||||||
|
|
||||||
{% if domain -%}
|
|
||||||
===Domain Hint===
|
|
||||||
This scenario belongs to the **{{ domain }}** domain. Consider domain-specific concepts and terminology when extracting classes.
|
|
||||||
{%- endif %}
|
|
||||||
|
|
||||||
===Extraction Rules===
|
|
||||||
|
|
||||||
**1. Abstract Classes, Not Instances:**
|
|
||||||
- Extract abstract categories and concepts (e.g., "MedicalProcedure", "Patient", "Diagnosis")
|
|
||||||
- Do NOT extract concrete instances (e.g., "John Smith", "Room 301", "2024-01-15")
|
|
||||||
- Think in terms of "types of things" rather than "specific things"
|
|
||||||
|
|
||||||
**2. Naming Convention (PascalCase):**
|
|
||||||
- Use PascalCase format for the "name" field: start with uppercase letter, capitalize each word, no spaces
|
|
||||||
- Examples: "MedicalProcedure", "HealthcareProvider", "DiagnosticTest"
|
|
||||||
- Avoid: "medical procedure", "healthcare_provider", "diagnostic-test"
|
|
||||||
- Use clear, descriptive names in English
|
|
||||||
- Avoid abbreviations unless they are standard in the domain (e.g., "API", "DNA")
|
|
||||||
- Provide Chinese translation in the "name_chinese" field (e.g., "医疗程序", "医疗服务提供者", "诊断测试")
|
|
||||||
|
|
||||||
**3. Domain Relevance:**
|
|
||||||
- Focus on classes that are central to the scenario's domain
|
|
||||||
- Prioritize classes that represent key concepts, entities, or relationships
|
|
||||||
- Avoid overly generic classes (e.g., "Thing", "Object") unless they have specific domain meaning
|
|
||||||
|
|
||||||
**4. Class Quantity:**
|
|
||||||
- Extract between 5 and {{ max_classes }} classes
|
|
||||||
- Aim for a balanced set covering the main concepts in the scenario
|
|
||||||
- Quality over quantity: prefer well-defined classes over exhaustive lists
|
|
||||||
|
|
||||||
**5. Clear Descriptions:**
|
|
||||||
- Provide concise, informative descriptions in Chinese (max 500 characters)
|
|
||||||
- Describe what the class represents, not specific instances
|
|
||||||
- Use clear, natural Chinese language that explains the class's role in the domain
|
|
||||||
|
|
||||||
**6. Concrete Examples:**
|
|
||||||
- Provide 2-5 concrete instance examples in Chinese for each class
|
|
||||||
- Examples should be specific, realistic instances of the class
|
|
||||||
- Examples help clarify the class's scope and meaning
|
|
||||||
- Use natural Chinese language for examples
|
|
||||||
- Example format: ["示例1", "示例2", "示例3"]
|
|
||||||
|
|
||||||
**7. Class Hierarchy:**
|
|
||||||
- Identify parent-child relationships where applicable
|
|
||||||
- Use the parent_class field to specify inheritance
|
|
||||||
- Parent class must be one of the extracted classes or a standard OWL class
|
|
||||||
- Leave parent_class as null for top-level classes
|
|
||||||
|
|
||||||
**8. Entity Types:**
|
|
||||||
- Classify each class with an appropriate entity_type
|
|
||||||
- Common types: "Person", "Organization", "Location", "Event", "Concept", "Process", "Object", "Role"
|
|
||||||
- Choose the most specific type that applies
|
|
||||||
|
|
||||||
**9. OWL Reserved Words:**
|
|
||||||
- Do NOT use OWL reserved words as class names
|
|
||||||
- Reserved words include: "Thing", "Nothing", "Class", "Property", "ObjectProperty", "DatatypeProperty", "AnnotationProperty", "Ontology", "Individual", "Literal"
|
|
||||||
- If a reserved word is needed, add a domain-specific prefix (e.g., "MedicalClass" instead of "Class")
|
|
||||||
|
|
||||||
**10. Language Consistency:**
|
|
||||||
- Extract all class names in English (PascalCase format) for the "name" field
|
|
||||||
- Provide Chinese translation for class names in the "name_chinese" field
|
|
||||||
- Descriptions MUST be in Chinese (中文)
|
|
||||||
- Examples MUST be in Chinese (中文)
|
|
||||||
- Use clear, natural Chinese language for descriptions and examples
|
|
||||||
|
|
||||||
===Examples===
|
|
||||||
|
|
||||||
**Example 1 (Healthcare Domain):**
|
|
||||||
Scenario: "A hospital manages patient records, schedules appointments, and coordinates medical procedures. Doctors diagnose conditions and prescribe treatments."
|
|
||||||
|
|
||||||
Output:
|
|
||||||
{
|
|
||||||
"classes": [
|
|
||||||
{
|
|
||||||
"name": "Patient",
|
|
||||||
"name_chinese": "患者",
|
|
||||||
"description": "在医疗机构接受医疗护理或治疗的人",
|
|
||||||
"examples": ["张三", "李四", "患有糖尿病的老年患者"],
|
|
||||||
"parent_class": null,
|
|
||||||
"entity_type": "Person",
|
|
||||||
"domain": "Healthcare"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "MedicalProcedure",
|
|
||||||
"name_chinese": "医疗程序",
|
|
||||||
"description": "为医疗诊断或治疗而执行的系统性操作流程",
|
|
||||||
"examples": ["手术", "血液检查", "X光检查", "疫苗接种"],
|
|
||||||
"parent_class": null,
|
|
||||||
"entity_type": "Process",
|
|
||||||
"domain": "Healthcare"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Diagnosis",
|
|
||||||
"name_chinese": "诊断",
|
|
||||||
"description": "基于症状和检查结果对疾病或状况的识别",
|
|
||||||
"examples": ["糖尿病诊断", "癌症诊断", "流感诊断"],
|
|
||||||
"parent_class": null,
|
|
||||||
"entity_type": "Concept",
|
|
||||||
"domain": "Healthcare"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Doctor",
|
|
||||||
"name_chinese": "医生",
|
|
||||||
"description": "诊断和治疗患者的持证医疗专业人员",
|
|
||||||
"examples": ["全科医生", "外科医生", "心脏病专家"],
|
|
||||||
"parent_class": null,
|
|
||||||
"entity_type": "Role",
|
|
||||||
"domain": "Healthcare"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Treatment",
|
|
||||||
"name_chinese": "治疗",
|
|
||||||
"description": "为治愈或管理疾病状况而提供的医疗护理或疗法",
|
|
||||||
"examples": ["药物治疗", "物理治疗", "化疗", "手术治疗"],
|
|
||||||
"parent_class": null,
|
|
||||||
"entity_type": "Process",
|
|
||||||
"domain": "Healthcare"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"domain": "Healthcare",
|
|
||||||
"namespace": "http://example.org/healthcare#"
|
|
||||||
}
|
|
||||||
|
|
||||||
**Example 2 (Education Domain):**
|
|
||||||
Scenario: "A university offers courses taught by professors. Students enroll in programs, attend lectures, and complete assignments to earn degrees."
|
|
||||||
|
|
||||||
Output:
|
|
||||||
{
|
|
||||||
"classes": [
|
|
||||||
{
|
|
||||||
"name": "Student",
|
|
||||||
"name_chinese": "学生",
|
|
||||||
"description": "在教育机构注册学习的人",
|
|
||||||
"examples": ["本科生", "研究生", "在职学生"],
|
|
||||||
"parent_class": null,
|
|
||||||
"entity_type": "Role",
|
|
||||||
"domain": "Education"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Course",
|
|
||||||
"name_chinese": "课程",
|
|
||||||
"description": "涵盖特定学科或主题的结构化教育课程",
|
|
||||||
"examples": ["计算机科学导论", "微积分I", "世界历史"],
|
|
||||||
"parent_class": null,
|
|
||||||
"entity_type": "Concept",
|
|
||||||
"domain": "Education"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Professor",
|
|
||||||
"name_chinese": "教授",
|
|
||||||
"description": "教授课程并进行研究的学术教师",
|
|
||||||
"examples": ["助理教授", "副教授", "正教授"],
|
|
||||||
"parent_class": null,
|
|
||||||
"entity_type": "Role",
|
|
||||||
"domain": "Education"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "AcademicProgram",
|
|
||||||
"name_chinese": "学术项目",
|
|
||||||
"description": "通向学位或证书的结构化课程体系",
|
|
||||||
"examples": ["理学学士", "文学硕士", "博士项目"],
|
|
||||||
"parent_class": null,
|
|
||||||
"entity_type": "Concept",
|
|
||||||
"domain": "Education"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Assignment",
|
|
||||||
"name_chinese": "作业",
|
|
||||||
"description": "分配给学生以评估学习成果的任务或项目",
|
|
||||||
"examples": ["论文", "习题集", "研究报告", "实验报告"],
|
|
||||||
"parent_class": null,
|
|
||||||
"entity_type": "Object",
|
|
||||||
"domain": "Education"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Lecture",
|
|
||||||
"name_chinese": "讲座",
|
|
||||||
"description": "由教师进行的教育性演讲或讲座",
|
|
||||||
"examples": ["入门讲座", "客座讲座", "在线讲座"],
|
|
||||||
"parent_class": null,
|
|
||||||
"entity_type": "Event",
|
|
||||||
"domain": "Education"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"domain": "Education",
|
|
||||||
"namespace": "http://example.org/education#"
|
|
||||||
}
|
|
||||||
|
|
||||||
===Output Format===
|
|
||||||
|
|
||||||
**JSON Requirements:**
|
|
||||||
- Use only ASCII double quotes (") for JSON structure
|
|
||||||
- Never use Chinese quotation marks ("") or Unicode quotes
|
|
||||||
- Escape quotation marks in text with backslashes (\")
|
|
||||||
- Ensure proper string closure and comma separation
|
|
||||||
- No line breaks within JSON string values
|
|
||||||
- All class names must be in PascalCase format
|
|
||||||
- All class names must be unique (case-insensitive)
|
|
||||||
- Extract between 5 and {{ max_classes }} classes
|
|
||||||
|
|
||||||
{{ json_schema }}
|
|
||||||
@@ -5,12 +5,6 @@
|
|||||||
===Task===
|
===Task===
|
||||||
Extract entities and knowledge triplets from the given statement.
|
Extract entities and knowledge triplets from the given statement.
|
||||||
|
|
||||||
{% if language == "zh" %}
|
|
||||||
**重要:请使用中文生成实体描述(description)和示例(example)。**
|
|
||||||
{% else %}
|
|
||||||
**Important: Please generate entity descriptions and examples in English.**
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
===Inputs===
|
===Inputs===
|
||||||
**Chunk Content:** "{{ chunk_content }}"
|
**Chunk Content:** "{{ chunk_content }}"
|
||||||
**Statement:** "{{ statement }}"
|
**Statement:** "{{ statement }}"
|
||||||
@@ -19,13 +13,6 @@ Extract entities and knowledge triplets from the given statement.
|
|||||||
|
|
||||||
**Entity Extraction:**
|
**Entity Extraction:**
|
||||||
- Extract entities with their types, context-independent descriptions, **concise examples**, aliases, and semantic memory classification
|
- Extract entities with their types, context-independent descriptions, **concise examples**, aliases, and semantic memory classification
|
||||||
{% if language == "zh" %}
|
|
||||||
- **实体描述(description)必须使用中文**
|
|
||||||
- **示例(example)必须使用中文**
|
|
||||||
{% else %}
|
|
||||||
- **Entity descriptions must be in English**
|
|
||||||
- **Examples must be in English**
|
|
||||||
{% endif %}
|
|
||||||
- **Semantic Memory Classification (is_explicit_memory):**
|
- **Semantic Memory Classification (is_explicit_memory):**
|
||||||
* Set to `true` if the entity represents **explicit/semantic memory**:
|
* Set to `true` if the entity represents **explicit/semantic memory**:
|
||||||
- **Concepts:** "Machine Learning", "Photosynthesis", "Democracy", "人工智能", "光合作用", "民主"
|
- **Concepts:** "Machine Learning", "Photosynthesis", "Democracy", "人工智能", "光合作用", "民主"
|
||||||
@@ -347,11 +334,9 @@ Output:
|
|||||||
- Escape quotation marks in text with backslashes (\")
|
- Escape quotation marks in text with backslashes (\")
|
||||||
- Ensure proper string closure and comma separation
|
- Ensure proper string closure and comma separation
|
||||||
- No line breaks within JSON string values
|
- No line breaks within JSON string values
|
||||||
{% if language == "zh" %}
|
- The output language should ALWAYS match the input language
|
||||||
- **语言要求:实体描述(description)和示例(example)必须使用中文**
|
- If input is in English, extract statements in English
|
||||||
{% else %}
|
- If input is in Chinese, extract statements in Chinese
|
||||||
- **Language Requirement: Entity descriptions and examples must be in English**
|
|
||||||
{% endif %}
|
|
||||||
- Preserve the original language and do not translate
|
- Preserve the original language and do not translate
|
||||||
|
|
||||||
{{ json_schema }}
|
{{ json_schema }}
|
||||||
@@ -5,21 +5,10 @@
|
|||||||
=== Task ===
|
=== Task ===
|
||||||
Summarize the provided conversation chunks into a concise Memory summary.
|
Summarize the provided conversation chunks into a concise Memory summary.
|
||||||
|
|
||||||
{% if language == "zh" %}
|
|
||||||
**重要:请使用中文生成摘要内容。**
|
|
||||||
{% else %}
|
|
||||||
**Important: Please generate the summary content in English.**
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
=== Requirements ===
|
=== Requirements ===
|
||||||
- Focus on factual statements, user preferences, relationships, and salient temporal context.
|
- Focus on factual statements, user preferences, relationships, and salient temporal context.
|
||||||
- Avoid repetition and filler; be specific.
|
- Avoid repetition and filler; be specific.
|
||||||
- Keep it under {{ max_words or 200 }} words.
|
- Keep it under {{ max_words or 200 }} words.
|
||||||
{% if language == "zh" %}
|
|
||||||
- 摘要内容必须使用中文
|
|
||||||
{% else %}
|
|
||||||
- Summary content must be in English
|
|
||||||
{% endif %}
|
|
||||||
- Output must be valid JSON conforming to the schema below.
|
- Output must be valid JSON conforming to the schema below.
|
||||||
|
|
||||||
=== Input ===
|
=== Input ===
|
||||||
@@ -35,11 +24,6 @@ Summarize the provided conversation chunks into a concise Memory summary.
|
|||||||
4. Do not include line breaks within JSON string values
|
4. Do not include line breaks within JSON string values
|
||||||
5. Example of proper escaping: "statement": "张曼婷说:\"我很喜欢这本书。\""
|
5. Example of proper escaping: "statement": "张曼婷说:\"我很喜欢这本书。\""
|
||||||
|
|
||||||
{% if language == "zh" %}
|
The output language should always be the same as the input language.
|
||||||
**语言要求:输出内容必须使用中文。**
|
|
||||||
{% else %}
|
|
||||||
**Language Requirement: The output content must be in English.**
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
Return only a list of extracted labelled statements in the JSON ARRAY of objects that match the schema below:
|
Return only a list of extracted labelled statements in the JSON ARRAY of objects that match the schema below:
|
||||||
{{ json_schema }}
|
{{ json_schema }}
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
"""Validation utilities for ontology extraction.
|
|
||||||
|
|
||||||
This module provides validation classes for ontology class names,
|
|
||||||
descriptions, and OWL compliance checking.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .ontology_validator import OntologyValidator
|
|
||||||
from .owl_validator import OWLValidator
|
|
||||||
|
|
||||||
__all__ = ['OntologyValidator', 'OWLValidator']
|
|
||||||
@@ -1,268 +0,0 @@
|
|||||||
"""String validation for ontology class names and descriptions.
|
|
||||||
|
|
||||||
This module provides the OntologyValidator class for validating and sanitizing
|
|
||||||
ontology class names according to OWL standards and naming conventions.
|
|
||||||
|
|
||||||
Classes:
|
|
||||||
OntologyValidator: Validates class names, removes duplicates, and truncates descriptions
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
from app.core.memory.models.ontology_models import OntologyClass
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class OntologyValidator:
|
|
||||||
"""Validator for ontology class names and descriptions.
|
|
||||||
|
|
||||||
This validator performs string-level validation including:
|
|
||||||
- PascalCase naming convention validation
|
|
||||||
- OWL reserved word checking
|
|
||||||
- Duplicate class name removal
|
|
||||||
- Description length truncation
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
OWL_RESERVED_WORDS: Set of OWL reserved words that cannot be used as class names
|
|
||||||
"""
|
|
||||||
|
|
||||||
# OWL reserved words that cannot be used as class names
|
|
||||||
OWL_RESERVED_WORDS = {
|
|
||||||
'Thing', 'Nothing', 'Class', 'Property',
|
|
||||||
'ObjectProperty', 'DatatypeProperty', 'FunctionalProperty',
|
|
||||||
'InverseFunctionalProperty', 'TransitiveProperty', 'SymmetricProperty',
|
|
||||||
'AsymmetricProperty', 'ReflexiveProperty', 'IrreflexiveProperty',
|
|
||||||
'Restriction', 'Ontology', 'Individual', 'NamedIndividual',
|
|
||||||
'Annotation', 'AnnotationProperty', 'Axiom',
|
|
||||||
'AllDifferent', 'AllDisjointClasses', 'AllDisjointProperties',
|
|
||||||
'Datatype', 'DataRange', 'Literal',
|
|
||||||
'DeprecatedClass', 'DeprecatedProperty',
|
|
||||||
'Imports', 'IncompatibleWith', 'PriorVersion', 'VersionInfo',
|
|
||||||
'BackwardCompatibleWith', 'OntologyProperty',
|
|
||||||
}
|
|
||||||
|
|
||||||
def validate_class_name(self, name: str) -> Tuple[bool, str]:
|
|
||||||
"""Validate that a class name follows OWL naming conventions.
|
|
||||||
|
|
||||||
Validation rules:
|
|
||||||
1. Must not be empty
|
|
||||||
2. Must start with an uppercase letter (PascalCase)
|
|
||||||
3. Cannot contain spaces
|
|
||||||
4. Can only contain alphanumeric characters and underscores
|
|
||||||
5. Cannot be an OWL reserved word
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: The class name to validate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (is_valid, error_message)
|
|
||||||
- is_valid: True if the name is valid, False otherwise
|
|
||||||
- error_message: Empty string if valid, error description if invalid
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> validator = OntologyValidator()
|
|
||||||
>>> validator.validate_class_name("MedicalProcedure")
|
|
||||||
(True, "")
|
|
||||||
>>> validator.validate_class_name("medical procedure")
|
|
||||||
(False, "Class name 'medical procedure' cannot contain spaces")
|
|
||||||
>>> validator.validate_class_name("Thing")
|
|
||||||
(False, "Class name 'Thing' is an OWL reserved word")
|
|
||||||
"""
|
|
||||||
logger.debug(f"Validating class name: '{name}'")
|
|
||||||
|
|
||||||
# Check if empty
|
|
||||||
if not name or not name.strip():
|
|
||||||
error_msg = "Class name cannot be empty"
|
|
||||||
logger.warning(f"Validation failed: {error_msg}")
|
|
||||||
return False, error_msg
|
|
||||||
|
|
||||||
name = name.strip()
|
|
||||||
|
|
||||||
# Check if it's an OWL reserved word
|
|
||||||
if name in self.OWL_RESERVED_WORDS:
|
|
||||||
error_msg = f"Class name '{name}' is an OWL reserved word"
|
|
||||||
logger.warning(f"Validation failed: {error_msg}")
|
|
||||||
return False, error_msg
|
|
||||||
|
|
||||||
# Check if starts with uppercase letter
|
|
||||||
if not name[0].isupper():
|
|
||||||
error_msg = f"Class name '{name}' must start with an uppercase letter (PascalCase)"
|
|
||||||
logger.warning(f"Validation failed: {error_msg}")
|
|
||||||
return False, error_msg
|
|
||||||
|
|
||||||
# Check for spaces
|
|
||||||
if ' ' in name:
|
|
||||||
error_msg = f"Class name '{name}' cannot contain spaces"
|
|
||||||
logger.warning(f"Validation failed: {error_msg}")
|
|
||||||
return False, error_msg
|
|
||||||
|
|
||||||
# Check for invalid characters (only alphanumeric and underscore allowed)
|
|
||||||
if not re.match(r'^[A-Za-z0-9_]+$', name):
|
|
||||||
error_msg = f"Class name '{name}' contains invalid characters. Only alphanumeric characters and underscores are allowed"
|
|
||||||
logger.warning(f"Validation failed: {error_msg}")
|
|
||||||
return False, error_msg
|
|
||||||
|
|
||||||
logger.debug(f"Class name '{name}' is valid")
|
|
||||||
return True, ""
|
|
||||||
|
|
||||||
def sanitize_class_name(self, name: str) -> str:
|
|
||||||
"""Attempt to sanitize an invalid class name into a valid format.
|
|
||||||
|
|
||||||
Sanitization steps:
|
|
||||||
1. Strip whitespace
|
|
||||||
2. Remove invalid characters
|
|
||||||
3. Replace spaces with empty string (PascalCase)
|
|
||||||
4. Capitalize first letter of each word
|
|
||||||
5. If result is empty or starts with number, prefix with 'Class'
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: The class name to sanitize
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Sanitized class name that should pass validation
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> validator = OntologyValidator()
|
|
||||||
>>> validator.sanitize_class_name("medical procedure")
|
|
||||||
'MedicalProcedure'
|
|
||||||
>>> validator.sanitize_class_name("patient-record")
|
|
||||||
'PatientRecord'
|
|
||||||
>>> validator.sanitize_class_name("123invalid")
|
|
||||||
'Class123Invalid'
|
|
||||||
"""
|
|
||||||
logger.debug(f"Sanitizing class name: '{name}'")
|
|
||||||
|
|
||||||
if not name or not name.strip():
|
|
||||||
logger.warning("Empty class name provided for sanitization, returning 'UnnamedClass'")
|
|
||||||
return "UnnamedClass"
|
|
||||||
|
|
||||||
# Strip whitespace
|
|
||||||
name = name.strip()
|
|
||||||
original_name = name
|
|
||||||
|
|
||||||
# Split on spaces, hyphens, and underscores, then capitalize each word
|
|
||||||
words = re.split(r'[\s\-_]+', name)
|
|
||||||
|
|
||||||
# Capitalize first letter of each word and keep rest as is
|
|
||||||
sanitized_words = []
|
|
||||||
for word in words:
|
|
||||||
if word:
|
|
||||||
# Remove non-alphanumeric characters except underscore
|
|
||||||
clean_word = re.sub(r'[^A-Za-z0-9_]', '', word)
|
|
||||||
if clean_word:
|
|
||||||
# Capitalize first letter
|
|
||||||
sanitized_words.append(clean_word[0].upper() + clean_word[1:])
|
|
||||||
|
|
||||||
# Join words
|
|
||||||
sanitized = ''.join(sanitized_words)
|
|
||||||
|
|
||||||
# If empty or starts with number, prefix with 'Class'
|
|
||||||
if not sanitized or sanitized[0].isdigit():
|
|
||||||
sanitized = 'Class' + sanitized
|
|
||||||
logger.info(f"Prefixed class name with 'Class': '{original_name}' -> '{sanitized}'")
|
|
||||||
|
|
||||||
# If it's a reserved word, append 'Class' suffix
|
|
||||||
if sanitized in self.OWL_RESERVED_WORDS:
|
|
||||||
sanitized = sanitized + 'Class'
|
|
||||||
logger.info(f"Appended 'Class' suffix to reserved word: '{original_name}' -> '{sanitized}'")
|
|
||||||
|
|
||||||
logger.info(f"Sanitized class name: '{original_name}' -> '{sanitized}'")
|
|
||||||
return sanitized
|
|
||||||
|
|
||||||
def remove_duplicates(self, classes: List[OntologyClass]) -> List[OntologyClass]:
|
|
||||||
"""Remove duplicate ontology classes based on case-insensitive name comparison.
|
|
||||||
|
|
||||||
When duplicates are found, keeps the first occurrence and discards subsequent ones.
|
|
||||||
Comparison is case-insensitive to catch variations like 'Patient' and 'patient'.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
classes: List of OntologyClass objects
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of OntologyClass objects with duplicates removed
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> validator = OntologyValidator()
|
|
||||||
>>> classes = [
|
|
||||||
... OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare"),
|
|
||||||
... OntologyClass(name="patient", description="Another patient", entity_type="Person", domain="Healthcare"),
|
|
||||||
... OntologyClass(name="Doctor", description="A doctor", entity_type="Person", domain="Healthcare"),
|
|
||||||
... ]
|
|
||||||
>>> unique = validator.remove_duplicates(classes)
|
|
||||||
>>> len(unique)
|
|
||||||
2
|
|
||||||
>>> [c.name for c in unique]
|
|
||||||
['Patient', 'Doctor']
|
|
||||||
"""
|
|
||||||
if not classes:
|
|
||||||
logger.debug("No classes to check for duplicates")
|
|
||||||
return classes
|
|
||||||
|
|
||||||
logger.debug(f"Checking {len(classes)} classes for duplicates")
|
|
||||||
|
|
||||||
seen_names = set()
|
|
||||||
unique_classes = []
|
|
||||||
duplicates_found = []
|
|
||||||
|
|
||||||
for ontology_class in classes:
|
|
||||||
# Use lowercase for comparison
|
|
||||||
name_lower = ontology_class.name.lower()
|
|
||||||
|
|
||||||
if name_lower not in seen_names:
|
|
||||||
seen_names.add(name_lower)
|
|
||||||
unique_classes.append(ontology_class)
|
|
||||||
else:
|
|
||||||
duplicates_found.append(ontology_class.name)
|
|
||||||
logger.debug(f"Duplicate class found and removed: '{ontology_class.name}'")
|
|
||||||
|
|
||||||
if duplicates_found:
|
|
||||||
logger.info(
|
|
||||||
f"Removed {len(duplicates_found)} duplicate classes: {duplicates_found}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.debug("No duplicate classes found")
|
|
||||||
|
|
||||||
return unique_classes
|
|
||||||
|
|
||||||
def truncate_description(self, description: str, max_length: int = 500) -> str:
|
|
||||||
"""Truncate a description to a maximum length.
|
|
||||||
|
|
||||||
If the description exceeds max_length, it will be truncated and
|
|
||||||
an ellipsis (...) will be appended to indicate truncation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
description: The description text to truncate
|
|
||||||
max_length: Maximum allowed length (default: 500)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Truncated description string
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> validator = OntologyValidator()
|
|
||||||
>>> long_desc = "A" * 600
|
|
||||||
>>> truncated = validator.truncate_description(long_desc, max_length=500)
|
|
||||||
>>> len(truncated)
|
|
||||||
500
|
|
||||||
>>> truncated.endswith("...")
|
|
||||||
True
|
|
||||||
"""
|
|
||||||
if not description:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
if len(description) <= max_length:
|
|
||||||
return description
|
|
||||||
|
|
||||||
# Truncate and add ellipsis
|
|
||||||
# Reserve 3 characters for "..."
|
|
||||||
truncate_at = max_length - 3
|
|
||||||
truncated = description[:truncate_at] + "..."
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"Truncated description from {len(description)} to {len(truncated)} characters"
|
|
||||||
)
|
|
||||||
|
|
||||||
return truncated
|
|
||||||
@@ -1,585 +0,0 @@
|
|||||||
"""OWL semantic validation for ontology classes using Owlready2.
|
|
||||||
|
|
||||||
This module provides the OWLValidator class for validating ontology classes
|
|
||||||
against OWL standards using the Owlready2 library. It performs semantic
|
|
||||||
validation including consistency checking, circular inheritance detection,
|
|
||||||
and OWL file export.
|
|
||||||
|
|
||||||
Classes:
|
|
||||||
OWLValidator: Validates ontology classes using OWL reasoning and exports to OWL formats
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
from owlready2 import (
|
|
||||||
World,
|
|
||||||
Thing,
|
|
||||||
get_ontology,
|
|
||||||
sync_reasoner_pellet,
|
|
||||||
OwlReadyInconsistentOntologyError,
|
|
||||||
)
|
|
||||||
|
|
||||||
from app.core.memory.models.ontology_models import OntologyClass
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class OWLValidator:
|
|
||||||
"""Validator for OWL semantic validation of ontology classes.
|
|
||||||
|
|
||||||
This validator performs semantic-level validation using Owlready2 including:
|
|
||||||
- Creating OWL classes from ontology class definitions
|
|
||||||
- Running consistency checking with Pellet reasoner
|
|
||||||
- Detecting circular inheritance
|
|
||||||
- Validating Protégé compatibility
|
|
||||||
- Exporting ontologies to various OWL formats (RDF/XML, Turtle, N-Triples)
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
base_namespace: Base URI for the ontology namespace
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, base_namespace: str = "http://example.org/ontology#"):
|
|
||||||
"""Initialize the OWL validator.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
base_namespace: Base URI for the ontology namespace (default: http://example.org/ontology#)
|
|
||||||
"""
|
|
||||||
self.base_namespace = base_namespace
|
|
||||||
|
|
||||||
def validate_ontology_classes(
|
|
||||||
self,
|
|
||||||
classes: List[OntologyClass],
|
|
||||||
) -> Tuple[bool, List[str], Optional[World]]:
|
|
||||||
"""Validate extracted ontology classes against OWL standards.
|
|
||||||
|
|
||||||
This method creates an OWL ontology from the provided classes using Owlready2,
|
|
||||||
runs consistency checking with the Pellet reasoner, and detects common issues
|
|
||||||
like circular inheritance.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
classes: List of OntologyClass objects to validate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (is_valid, error_messages, world):
|
|
||||||
- is_valid: True if ontology is valid and consistent, False otherwise
|
|
||||||
- error_messages: List of error/warning messages
|
|
||||||
- world: Owlready2 World object containing the ontology (None if validation failed)
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> validator = OWLValidator()
|
|
||||||
>>> classes = [
|
|
||||||
... OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare"),
|
|
||||||
... OntologyClass(name="Doctor", description="A doctor", entity_type="Person", domain="Healthcare"),
|
|
||||||
... ]
|
|
||||||
>>> is_valid, errors, world = validator.validate_ontology_classes(classes)
|
|
||||||
>>> is_valid
|
|
||||||
True
|
|
||||||
>>> len(errors)
|
|
||||||
0
|
|
||||||
"""
|
|
||||||
if not classes:
|
|
||||||
return False, ["No classes provided for validation"], None
|
|
||||||
|
|
||||||
errors = []
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Create a new world (isolated ontology environment)
|
|
||||||
world = World()
|
|
||||||
|
|
||||||
# Use a proper ontology IRI
|
|
||||||
# Owlready2 expects the IRI to end with .owl or similar
|
|
||||||
onto_iri = self.base_namespace.rstrip('#/')
|
|
||||||
if not onto_iri.endswith('.owl'):
|
|
||||||
onto_iri = onto_iri + '.owl'
|
|
||||||
|
|
||||||
# Create ontology
|
|
||||||
onto = world.get_ontology(onto_iri)
|
|
||||||
|
|
||||||
with onto:
|
|
||||||
# Dictionary to store created OWL classes for parent reference
|
|
||||||
owl_classes = {}
|
|
||||||
|
|
||||||
# First pass: Create all classes without parent relationships
|
|
||||||
for ontology_class in classes:
|
|
||||||
try:
|
|
||||||
# Create OWL class dynamically using type() with Thing as base
|
|
||||||
# The key is to NOT set namespace in the dict, let Owlready2 handle it
|
|
||||||
owl_class = type(
|
|
||||||
ontology_class.name, # Class name
|
|
||||||
(Thing,), # Base classes
|
|
||||||
{} # Class dict (empty, let Owlready2 manage)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add label (rdfs:label) - include both English and Chinese names
|
|
||||||
labels = [ontology_class.name]
|
|
||||||
if ontology_class.name_chinese:
|
|
||||||
labels.append(ontology_class.name_chinese)
|
|
||||||
owl_class.label = labels
|
|
||||||
|
|
||||||
# Add comment (rdfs:comment) with description
|
|
||||||
if ontology_class.description:
|
|
||||||
owl_class.comment = [ontology_class.description]
|
|
||||||
|
|
||||||
# Store for parent relationship setup
|
|
||||||
owl_classes[ontology_class.name] = owl_class
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"Created OWL class: {ontology_class.name} "
|
|
||||||
f"(Chinese: {ontology_class.name_chinese}) "
|
|
||||||
f"IRI: {owl_class.iri if hasattr(owl_class, 'iri') else 'N/A'}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"Failed to create OWL class '{ontology_class.name}': {str(e)}"
|
|
||||||
errors.append(error_msg)
|
|
||||||
logger.error(error_msg, exc_info=True)
|
|
||||||
|
|
||||||
# Second pass: Set up parent relationships
|
|
||||||
for ontology_class in classes:
|
|
||||||
if ontology_class.parent_class and ontology_class.name in owl_classes:
|
|
||||||
parent_name = ontology_class.parent_class
|
|
||||||
|
|
||||||
# Check if parent exists
|
|
||||||
if parent_name in owl_classes:
|
|
||||||
try:
|
|
||||||
child_class = owl_classes[ontology_class.name]
|
|
||||||
parent_class = owl_classes[parent_name]
|
|
||||||
|
|
||||||
# Set parent by modifying is_a
|
|
||||||
child_class.is_a = [parent_class]
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"Set parent relationship: {ontology_class.name} -> {parent_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = (
|
|
||||||
f"Failed to set parent relationship "
|
|
||||||
f"'{ontology_class.name}' -> '{parent_name}': {str(e)}"
|
|
||||||
)
|
|
||||||
errors.append(error_msg)
|
|
||||||
logger.warning(error_msg)
|
|
||||||
else:
|
|
||||||
warning_msg = (
|
|
||||||
f"Parent class '{parent_name}' not found for '{ontology_class.name}'"
|
|
||||||
)
|
|
||||||
errors.append(warning_msg)
|
|
||||||
logger.warning(warning_msg)
|
|
||||||
|
|
||||||
# Check for circular inheritance
|
|
||||||
for class_name, owl_class in owl_classes.items():
|
|
||||||
if self._has_circular_inheritance(owl_class):
|
|
||||||
error_msg = f"Circular inheritance detected for class '{class_name}'"
|
|
||||||
errors.append(error_msg)
|
|
||||||
logger.error(error_msg)
|
|
||||||
|
|
||||||
# Run consistency checking with Pellet reasoner
|
|
||||||
try:
|
|
||||||
logger.info("Running Pellet reasoner for consistency checking...")
|
|
||||||
sync_reasoner_pellet(world, infer_property_values=True, infer_data_property_values=True)
|
|
||||||
logger.info("Consistency check passed")
|
|
||||||
|
|
||||||
except OwlReadyInconsistentOntologyError as e:
|
|
||||||
error_msg = f"Ontology is inconsistent: {str(e)}"
|
|
||||||
errors.append(error_msg)
|
|
||||||
logger.error(error_msg)
|
|
||||||
return False, errors, world
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# Reasoner errors are often due to Java not being installed or configured
|
|
||||||
# Log as warning but don't fail validation - ontology structure is still valid
|
|
||||||
warning_msg = f"Reasoner check skipped: {str(e)}"
|
|
||||||
if str(e).strip(): # Only log if there's an actual error message
|
|
||||||
logger.warning(warning_msg)
|
|
||||||
else:
|
|
||||||
logger.warning("Reasoner check skipped: Java may not be installed or configured")
|
|
||||||
# Continue - ontology structure is valid even without reasoner check
|
|
||||||
|
|
||||||
# If we have errors (excluding warnings), validation failed
|
|
||||||
is_valid = len(errors) == 0
|
|
||||||
|
|
||||||
return is_valid, errors, world
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"OWL validation failed: {str(e)}"
|
|
||||||
errors.append(error_msg)
|
|
||||||
logger.error(error_msg, exc_info=True)
|
|
||||||
return False, errors, None
|
|
||||||
|
|
||||||
def _has_circular_inheritance(self, owl_class) -> bool:
|
|
||||||
"""Check if an OWL class has circular inheritance.
|
|
||||||
|
|
||||||
Circular inheritance occurs when a class inherits from itself through
|
|
||||||
a chain of parent relationships (e.g., A -> B -> C -> A).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
owl_class: Owlready2 class object to check
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if circular inheritance is detected, False otherwise
|
|
||||||
"""
|
|
||||||
visited = set()
|
|
||||||
current = owl_class
|
|
||||||
|
|
||||||
while current:
|
|
||||||
# Get class IRI or name as identifier
|
|
||||||
class_id = str(current.iri) if hasattr(current, 'iri') else str(current)
|
|
||||||
|
|
||||||
if class_id in visited:
|
|
||||||
# Found a cycle
|
|
||||||
return True
|
|
||||||
|
|
||||||
visited.add(class_id)
|
|
||||||
|
|
||||||
# Get parent classes (is_a relationship)
|
|
||||||
parents = getattr(current, 'is_a', [])
|
|
||||||
|
|
||||||
# Filter out Thing and other base classes
|
|
||||||
parent_classes = [p for p in parents if p != Thing and hasattr(p, 'is_a')]
|
|
||||||
|
|
||||||
if not parent_classes:
|
|
||||||
# No more parents, no cycle
|
|
||||||
break
|
|
||||||
|
|
||||||
# Check first parent (in single inheritance)
|
|
||||||
current = parent_classes[0] if parent_classes else None
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def export_to_owl(
|
|
||||||
self,
|
|
||||||
world: World,
|
|
||||||
output_path: Optional[str] = None,
|
|
||||||
format: str = "rdfxml",
|
|
||||||
classes: Optional[List] = None
|
|
||||||
) -> str:
|
|
||||||
"""Export ontology to OWL file in specified format.
|
|
||||||
|
|
||||||
Supported formats:
|
|
||||||
- rdfxml: RDF/XML format (default, most compatible)
|
|
||||||
- turtle: Turtle format (more readable)
|
|
||||||
- ntriples: N-Triples format (simplest)
|
|
||||||
- json: JSON format (simplified, human-readable)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
world: Owlready2 World object containing the ontology
|
|
||||||
output_path: Optional file path to save the ontology (if None, returns string)
|
|
||||||
format: Export format - "rdfxml", "turtle", "ntriples", or "json" (default: "rdfxml")
|
|
||||||
classes: Optional list of OntologyClass objects (required for json format)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
String representation of the exported ontology
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If format is not supported
|
|
||||||
RuntimeError: If export fails
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> validator = OWLValidator()
|
|
||||||
>>> is_valid, errors, world = validator.validate_ontology_classes(classes)
|
|
||||||
>>> owl_content = validator.export_to_owl(world, "ontology.owl", format="rdfxml")
|
|
||||||
"""
|
|
||||||
# Validate format
|
|
||||||
valid_formats = ["rdfxml", "turtle", "ntriples", "json"]
|
|
||||||
if format not in valid_formats:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported format '{format}'. Must be one of: {', '.join(valid_formats)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# JSON format doesn't need OWL processing
|
|
||||||
if format == "json":
|
|
||||||
if not classes:
|
|
||||||
raise ValueError("Classes list is required for JSON format export")
|
|
||||||
return self._export_to_json(classes)
|
|
||||||
|
|
||||||
# For OWL formats, world is required
|
|
||||||
if not world:
|
|
||||||
raise ValueError("World object is None. Cannot export ontology.")
|
|
||||||
|
|
||||||
# Note: Owlready2 has issues with turtle format export
|
|
||||||
# We'll handle it specially by converting from rdfxml
|
|
||||||
use_conversion = (format == "turtle")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Get all ontologies in the world
|
|
||||||
ontologies = list(world.ontologies.values())
|
|
||||||
|
|
||||||
if not ontologies:
|
|
||||||
raise RuntimeError("No ontologies found in world")
|
|
||||||
|
|
||||||
# Find the ontology with classes (skip anonymous/empty ontologies)
|
|
||||||
onto = None
|
|
||||||
for ont in ontologies:
|
|
||||||
classes_count = len(list(ont.classes()))
|
|
||||||
logger.debug(f"Checking ontology {ont.base_iri}: {classes_count} classes")
|
|
||||||
if classes_count > 0:
|
|
||||||
onto = ont
|
|
||||||
break
|
|
||||||
|
|
||||||
# If no ontology with classes found, use the last non-anonymous one
|
|
||||||
if onto is None:
|
|
||||||
for ont in reversed(ontologies):
|
|
||||||
if ont.base_iri != "http://anonymous/":
|
|
||||||
onto = ont
|
|
||||||
break
|
|
||||||
|
|
||||||
# If still no ontology, use the first one
|
|
||||||
if onto is None:
|
|
||||||
onto = ontologies[0]
|
|
||||||
|
|
||||||
# Log ontology contents for debugging
|
|
||||||
logger.info(f"Ontology IRI: {onto.base_iri}")
|
|
||||||
logger.info(f"Ontology contains {len(list(onto.classes()))} classes")
|
|
||||||
|
|
||||||
# List all classes in the ontology
|
|
||||||
all_classes = list(onto.classes())
|
|
||||||
for cls in all_classes:
|
|
||||||
logger.info(f"Class in ontology: {cls.name} (IRI: {cls.iri})")
|
|
||||||
if hasattr(cls, 'label'):
|
|
||||||
logger.debug(f" Labels: {cls.label}")
|
|
||||||
if hasattr(cls, 'comment'):
|
|
||||||
logger.debug(f" Comments: {cls.comment}")
|
|
||||||
|
|
||||||
if len(all_classes) == 0:
|
|
||||||
logger.warning("No classes found in ontology! This may indicate a problem with class creation.")
|
|
||||||
|
|
||||||
if output_path:
|
|
||||||
# Save to file
|
|
||||||
export_format = "rdfxml" if use_conversion else format
|
|
||||||
logger.info(f"Exporting ontology to {output_path} in {export_format} format")
|
|
||||||
onto.save(file=output_path, format=export_format)
|
|
||||||
|
|
||||||
# Read back the file content to return
|
|
||||||
with open(output_path, 'r', encoding='utf-8') as f:
|
|
||||||
content = f.read()
|
|
||||||
|
|
||||||
# Convert to turtle if needed
|
|
||||||
if use_conversion:
|
|
||||||
content = self._convert_to_turtle(content)
|
|
||||||
|
|
||||||
logger.info(f"Successfully exported ontology to {output_path}")
|
|
||||||
|
|
||||||
# Format the content for better readability
|
|
||||||
content = self._format_owl_content(content, format)
|
|
||||||
|
|
||||||
return content
|
|
||||||
else:
|
|
||||||
# Export to string (save to temporary location and read)
|
|
||||||
import tempfile
|
|
||||||
import os
|
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.owl', delete=False) as tmp:
|
|
||||||
tmp_path = tmp.name
|
|
||||||
|
|
||||||
try:
|
|
||||||
export_format = "rdfxml" if use_conversion else format
|
|
||||||
onto.save(file=tmp_path, format=export_format)
|
|
||||||
|
|
||||||
with open(tmp_path, 'r', encoding='utf-8') as f:
|
|
||||||
content = f.read()
|
|
||||||
|
|
||||||
# Convert to turtle if needed
|
|
||||||
if use_conversion:
|
|
||||||
content = self._convert_to_turtle(content)
|
|
||||||
|
|
||||||
# Format the content for better readability
|
|
||||||
content = self._format_owl_content(content, format)
|
|
||||||
|
|
||||||
return content
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# Clean up temporary file
|
|
||||||
if os.path.exists(tmp_path):
|
|
||||||
os.remove(tmp_path)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"Failed to export ontology: {str(e)}"
|
|
||||||
logger.error(error_msg, exc_info=True)
|
|
||||||
raise RuntimeError(error_msg) from e
|
|
||||||
|
|
||||||
def _export_to_json(self, classes: List) -> str:
|
|
||||||
"""Export ontology classes to simplified JSON format.
|
|
||||||
|
|
||||||
This format is more compact and easier to parse than OWL XML.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
classes: List of OntologyClass objects
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
JSON string representation (compact format)
|
|
||||||
"""
|
|
||||||
import json
|
|
||||||
|
|
||||||
result = {
|
|
||||||
"ontology": {
|
|
||||||
"namespace": self.base_namespace,
|
|
||||||
"classes": []
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for cls in classes:
|
|
||||||
class_data = {
|
|
||||||
"name": cls.name,
|
|
||||||
"name_chinese": cls.name_chinese,
|
|
||||||
"description": cls.description,
|
|
||||||
"entity_type": cls.entity_type,
|
|
||||||
"domain": cls.domain,
|
|
||||||
"parent_class": cls.parent_class,
|
|
||||||
"examples": cls.examples if hasattr(cls, 'examples') else []
|
|
||||||
}
|
|
||||||
result["ontology"]["classes"].append(class_data)
|
|
||||||
|
|
||||||
# 使用紧凑格式:无缩进,使用分隔符减少空格
|
|
||||||
return json.dumps(result, ensure_ascii=False, separators=(',', ':'))
|
|
||||||
|
|
||||||
def _convert_to_turtle(self, rdfxml_content: str) -> str:
|
|
||||||
"""Convert RDF/XML content to Turtle format using rdflib.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
rdfxml_content: RDF/XML format content
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Turtle format content
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from rdflib import Graph
|
|
||||||
|
|
||||||
# Parse RDF/XML
|
|
||||||
g = Graph()
|
|
||||||
g.parse(data=rdfxml_content, format="xml")
|
|
||||||
|
|
||||||
# Serialize to Turtle
|
|
||||||
turtle_content = g.serialize(format="turtle")
|
|
||||||
|
|
||||||
# Handle bytes vs string
|
|
||||||
if isinstance(turtle_content, bytes):
|
|
||||||
turtle_content = turtle_content.decode('utf-8')
|
|
||||||
|
|
||||||
return turtle_content
|
|
||||||
|
|
||||||
except ImportError:
|
|
||||||
logger.warning(
|
|
||||||
"rdflib is not installed. Cannot convert to Turtle format. "
|
|
||||||
"Install with: pip install rdflib"
|
|
||||||
)
|
|
||||||
return rdfxml_content
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to convert to Turtle format: {e}")
|
|
||||||
return rdfxml_content
|
|
||||||
|
|
||||||
def _format_owl_content(self, content: str, format: str) -> str:
|
|
||||||
"""Format OWL content for better readability.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: Raw OWL content string
|
|
||||||
format: Format type (rdfxml, turtle, ntriples)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted OWL content string
|
|
||||||
"""
|
|
||||||
if format == "rdfxml":
|
|
||||||
# Format XML with proper indentation
|
|
||||||
try:
|
|
||||||
import xml.dom.minidom as minidom
|
|
||||||
dom = minidom.parseString(content)
|
|
||||||
# Pretty print with 2-space indentation
|
|
||||||
formatted = dom.toprettyxml(indent=" ", encoding="utf-8").decode("utf-8")
|
|
||||||
|
|
||||||
# Remove extra blank lines
|
|
||||||
lines = []
|
|
||||||
prev_blank = False
|
|
||||||
for line in formatted.split('\n'):
|
|
||||||
is_blank = not line.strip()
|
|
||||||
if not (is_blank and prev_blank): # Skip consecutive blank lines
|
|
||||||
lines.append(line)
|
|
||||||
prev_blank = is_blank
|
|
||||||
|
|
||||||
formatted = '\n'.join(lines)
|
|
||||||
|
|
||||||
return formatted
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to format XML content: {e}")
|
|
||||||
return content
|
|
||||||
|
|
||||||
elif format == "turtle":
|
|
||||||
# Turtle format is already relatively readable
|
|
||||||
# Just ensure consistent line endings and not empty
|
|
||||||
if not content or content.strip() == "":
|
|
||||||
logger.warning("Turtle content is empty, this may indicate an export issue")
|
|
||||||
return content.strip() + '\n' if content.strip() else content
|
|
||||||
|
|
||||||
elif format == "ntriples":
|
|
||||||
# N-Triples format is line-based, ensure proper line endings
|
|
||||||
return content.strip() + '\n' if content.strip() else content
|
|
||||||
|
|
||||||
return content
|
|
||||||
|
|
||||||
def validate_with_protege_compatibility(
|
|
||||||
self,
|
|
||||||
classes: List[OntologyClass]
|
|
||||||
) -> Tuple[bool, List[str]]:
|
|
||||||
"""Validate that ontology classes are compatible with Protégé editor.
|
|
||||||
|
|
||||||
Protégé compatibility checks:
|
|
||||||
- Class names are valid OWL identifiers
|
|
||||||
- No special characters that Protégé cannot handle
|
|
||||||
- Namespace is properly formatted
|
|
||||||
- Labels and comments are properly encoded
|
|
||||||
|
|
||||||
Args:
|
|
||||||
classes: List of OntologyClass objects to validate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (is_compatible, warnings):
|
|
||||||
- is_compatible: True if compatible with Protégé, False otherwise
|
|
||||||
- warnings: List of compatibility warning messages
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> validator = OWLValidator()
|
|
||||||
>>> classes = [OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare")]
|
|
||||||
>>> is_compatible, warnings = validator.validate_with_protege_compatibility(classes)
|
|
||||||
>>> is_compatible
|
|
||||||
True
|
|
||||||
"""
|
|
||||||
warnings = []
|
|
||||||
|
|
||||||
# Check namespace format
|
|
||||||
if not self.base_namespace.startswith(('http://', 'https://')):
|
|
||||||
warnings.append(
|
|
||||||
f"Namespace '{self.base_namespace}' should start with http:// or https:// "
|
|
||||||
"for Protégé compatibility"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.base_namespace.endswith(('#', '/')):
|
|
||||||
warnings.append(
|
|
||||||
f"Namespace '{self.base_namespace}' should end with # or / "
|
|
||||||
"for Protégé compatibility"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check each class
|
|
||||||
for ontology_class in classes:
|
|
||||||
# Check for special characters that might cause issues
|
|
||||||
if any(char in ontology_class.name for char in ['<', '>', '"', '{', '}', '|', '^', '`']):
|
|
||||||
warnings.append(
|
|
||||||
f"Class name '{ontology_class.name}' contains special characters "
|
|
||||||
"that may cause issues in Protégé"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check description length (Protégé can handle long descriptions but may display poorly)
|
|
||||||
if ontology_class.description and len(ontology_class.description) > 1000:
|
|
||||||
warnings.append(
|
|
||||||
f"Class '{ontology_class.name}' has a very long description ({len(ontology_class.description)} chars) "
|
|
||||||
"which may display poorly in Protégé"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check for non-ASCII characters (Protégé supports them but encoding issues may occur)
|
|
||||||
if not ontology_class.name.isascii():
|
|
||||||
warnings.append(
|
|
||||||
f"Class name '{ontology_class.name}' contains non-ASCII characters "
|
|
||||||
"which may cause encoding issues in some Protégé versions"
|
|
||||||
)
|
|
||||||
|
|
||||||
# If no warnings, it's compatible
|
|
||||||
is_compatible = len(warnings) == 0
|
|
||||||
|
|
||||||
return is_compatible, warnings
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
"""模型配置脚本模块"""
|
|
||||||
@@ -1,174 +0,0 @@
|
|||||||
provider: bedrock
|
|
||||||
enabled: true
|
|
||||||
models:
|
|
||||||
- name: ai21
|
|
||||||
type: llm
|
|
||||||
provider: bedrock
|
|
||||||
description: AI21 Labs大语言模型,completion生成模式,256000上下文窗口
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
logo: bedrock
|
|
||||||
- name: amazon nova
|
|
||||||
type: llm
|
|
||||||
provider: bedrock
|
|
||||||
description: Amazon Nova大语言模型,支持智能体思考、工具调用、流式工具调用、视觉能力,300000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
- tool-call
|
|
||||||
- stream-tool-call
|
|
||||||
- vision
|
|
||||||
logo: bedrock
|
|
||||||
- name: anthropic claude
|
|
||||||
type: llm
|
|
||||||
provider: bedrock
|
|
||||||
description: Anthropic Claude大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用、文档处理,200000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
- vision
|
|
||||||
- tool-call
|
|
||||||
- stream-tool-call
|
|
||||||
- document
|
|
||||||
logo: bedrock
|
|
||||||
- name: cohere
|
|
||||||
type: llm
|
|
||||||
provider: bedrock
|
|
||||||
description: Cohere大语言模型,支持智能体思考、工具调用、流式工具调用,128000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
- tool-call
|
|
||||||
- stream-tool-call
|
|
||||||
logo: bedrock
|
|
||||||
- name: deepseek
|
|
||||||
type: llm
|
|
||||||
provider: bedrock
|
|
||||||
description: DeepSeek大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用,32768上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
- vision
|
|
||||||
- tool-call
|
|
||||||
- stream-tool-call
|
|
||||||
logo: bedrock
|
|
||||||
- name: meta
|
|
||||||
type: llm
|
|
||||||
provider: bedrock
|
|
||||||
description: Meta Llama大语言模型,支持智能体思考、工具调用,128000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
- tool-call
|
|
||||||
logo: bedrock
|
|
||||||
- name: mistral
|
|
||||||
type: llm
|
|
||||||
provider: bedrock
|
|
||||||
description: Mistral AI大语言模型,支持智能体思考、工具调用,32000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
- tool-call
|
|
||||||
logo: bedrock
|
|
||||||
- name: openai
|
|
||||||
type: llm
|
|
||||||
provider: bedrock
|
|
||||||
description: OpenAI大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
- tool-call
|
|
||||||
- stream-tool-call
|
|
||||||
logo: bedrock
|
|
||||||
- name: qwen
|
|
||||||
type: llm
|
|
||||||
provider: bedrock
|
|
||||||
description: Qwen大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
- tool-call
|
|
||||||
- stream-tool-call
|
|
||||||
logo: bedrock
|
|
||||||
- name: amazon.rerank-v1:0
|
|
||||||
type: rerank
|
|
||||||
provider: bedrock
|
|
||||||
description: amazon.rerank-v1:0重排序模型,5120上下文窗口
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 重排序模型
|
|
||||||
logo: bedrock
|
|
||||||
- name: cohere.rerank-v3-5:0
|
|
||||||
type: rerank
|
|
||||||
provider: bedrock
|
|
||||||
description: cohere.rerank-v3-5:0重排序模型,5120上下文窗口
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 重排序模型
|
|
||||||
logo: bedrock
|
|
||||||
- name: amazon.nova-2-multimodal-embeddings-v1:0
|
|
||||||
type: embedding
|
|
||||||
provider: bedrock
|
|
||||||
description: amazon.nova-2-multimodal-embeddings-v1:0文本嵌入模型,支持视觉能力,8192上下文窗口
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 文本嵌入模型
|
|
||||||
- vision
|
|
||||||
logo: bedrock
|
|
||||||
- name: amazon.titan-embed-text-v1
|
|
||||||
type: embedding
|
|
||||||
provider: bedrock
|
|
||||||
description: amazon.titan-embed-text-v1文本嵌入模型,8192上下文窗口
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 文本嵌入模型
|
|
||||||
logo: bedrock
|
|
||||||
- name: amazon.titan-embed-text-v2:0
|
|
||||||
type: embedding
|
|
||||||
provider: bedrock
|
|
||||||
description: amazon.titan-embed-text-v2:0文本嵌入模型,8192上下文窗口
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 文本嵌入模型
|
|
||||||
logo: bedrock
|
|
||||||
- name: cohere.embed-english-v3
|
|
||||||
type: embedding
|
|
||||||
provider: bedrock
|
|
||||||
description: Cohere Embed 3 English文本嵌入模型,512上下文窗口
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 文本嵌入模型
|
|
||||||
logo: bedrock
|
|
||||||
- name: cohere.embed-multilingual-v3
|
|
||||||
type: embedding
|
|
||||||
provider: bedrock
|
|
||||||
description: Cohere Embed 3 Multilingual文本嵌入模型,512上下文窗口
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 文本嵌入模型
|
|
||||||
logo: bedrock
|
|
||||||
@@ -1,820 +0,0 @@
|
|||||||
provider: dashscope
|
|
||||||
enabled: true
|
|
||||||
models:
|
|
||||||
- name: deepseek-r1-distill-qwen-14b
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: DeepSeek-R1-Distill-Qwen-14B大语言模型,支持智能体思考,32000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
logo: dashscope
|
|
||||||
- name: deepseek-r1-distill-qwen-32b
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: DeepSeek-R1-Distill-Qwen-32B大语言模型,支持智能体思考,32000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
logo: dashscope
|
|
||||||
- name: deepseek-r1
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: DeepSeek-R1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
logo: dashscope
|
|
||||||
- name: deepseek-v3.1
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: DeepSeek-V3.1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
logo: dashscope
|
|
||||||
- name: deepseek-v3.2-exp
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: DeepSeek-V3.2-exp实验版大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
logo: dashscope
|
|
||||||
- name: deepseek-v3.2
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: DeepSeek-V3.2大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
logo: dashscope
|
|
||||||
- name: deepseek-v3
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: DeepSeek-V3大语言模型,支持智能体思考,64000上下文窗口,对话模式,支持文本与JSON格式输出
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
logo: dashscope
|
|
||||||
- name: farui-plus
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: farui-plus大语言模型,支持多工具调用、智能体思考、流式工具调用,12288上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: glm-4.7
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: GLM-4.7大语言模型,支持多工具调用、智能体思考、流式工具调用,202752超大上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qvq-max-latest
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qvq-max-latest大语言模型,支持视觉、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- vision
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qvq-max
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qvq-max大语言模型,支持视觉、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- vision
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen-coder-turbo-0919
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen-coder-turbo-0919代码专用大语言模型,支持智能体思考,131072上下文窗口,对话模式,已废弃
|
|
||||||
is_deprecated: true
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- 代码模型
|
|
||||||
- agent-thought
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen-max-latest
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen-max-latest大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen-max-longcontext
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen-max-longcontext长上下文大语言模型,支持多工具调用、智能体思考、流式工具调用,32000上下文窗口,对话模式,已废弃
|
|
||||||
is_deprecated: true
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen-max
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen-max大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,支持联网搜索
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen-mt-plus
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen-mt-plus多语言翻译大语言模型,支持智能体思考,16384上下文窗口,对话模式,支持多语种互译与领域翻译适配
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- 翻译模型
|
|
||||||
- agent-thought
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen-mt-turbo
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen-mt-turbo轻量化多语言翻译大语言模型,支持智能体思考,16384上下文窗口,对话模式,支持多语种互译与领域翻译适配
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- 翻译模型
|
|
||||||
- agent-thought
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen-plus-0112
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen-plus-0112大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
|
||||||
is_deprecated: true
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen-plus-0125
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen-plus-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
|
||||||
is_deprecated: true
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen-plus-0723
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen-plus-0723大语言模型,支持多工具调用、智能体思考、流式工具调用,32000上下文窗口,对话模式,支持联网搜索,已废弃
|
|
||||||
is_deprecated: true
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen-plus-0806
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen-plus-0806大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
|
||||||
is_deprecated: true
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen-plus-0919
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen-plus-0919大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
|
||||||
is_deprecated: true
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen-plus-1125
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen-plus-1125大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
|
||||||
is_deprecated: true
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen-plus-1127
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen-plus-1127大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
|
||||||
is_deprecated: true
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen-plus-1220
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen-plus-1220大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,已废弃
|
|
||||||
is_deprecated: true
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen-vl-max
|
|
||||||
type: chat
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen-vl-max多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- 多模态模型
|
|
||||||
- vision
|
|
||||||
- agent-thought
|
|
||||||
- video
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen-vl-plus-0809
|
|
||||||
type: chat
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen-vl-plus-0809多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,已废弃
|
|
||||||
is_deprecated: true
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- 多模态模型
|
|
||||||
- vision
|
|
||||||
- agent-thought
|
|
||||||
- video
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen-vl-plus-2025-01-02
|
|
||||||
type: chat
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen-vl-plus-2025-01-02多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,未废弃
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- 多模态模型
|
|
||||||
- vision
|
|
||||||
- agent-thought
|
|
||||||
- video
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen-vl-plus-2025-01-25
|
|
||||||
type: chat
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen-vl-plus-2025-01-25多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- 多模态模型
|
|
||||||
- vision
|
|
||||||
- agent-thought
|
|
||||||
- video
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen-vl-plus-latest
|
|
||||||
type: chat
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen-vl-plus-latest多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- 多模态模型
|
|
||||||
- vision
|
|
||||||
- agent-thought
|
|
||||||
- video
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen-vl-plus
|
|
||||||
type: chat
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen-vl-plus多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- 多模态模型
|
|
||||||
- vision
|
|
||||||
- agent-thought
|
|
||||||
- video
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen2.5-0.5b-instruct
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen2.5-0.5b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,未废弃
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-14b
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-14b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-235b-a22b-instruct-2507
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-235b-a22b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-235b-a22b-thinking-2507
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-235b-a22b-thinking-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-235b-a22b
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-235b-a22b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-30b-a3b-instruct-2507
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-30b-a3b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-30b-a3b
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-30b-a3b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-32b
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-32b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-4b
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-4b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-8b
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-8b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-coder-30b-a3b-instruct
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-coder-30b-a3b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- 代码模型
|
|
||||||
- agent-thought
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-coder-480b-a35b-instruct
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-coder-480b-a35b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- 代码模型
|
|
||||||
- agent-thought
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-coder-plus-2025-09-23
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-coder-plus-2025-09-23大语言模型,支持智能体思考,1000000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- 代码模型
|
|
||||||
- agent-thought
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-coder-plus
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-coder-plus大语言模型,支持智能体思考,1000000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- 代码模型
|
|
||||||
- agent-thought
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-max-2025-09-23
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-max-2025-09-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
- 联网搜索
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-max-2026-01-23
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-max-2026-01-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
- 联网搜索
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-max-preview
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-max-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-max
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-max大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
- 联网搜索
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-next-80b-a3b-instruct
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-next-80b-a3b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-next-80b-a3b-thinking
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-next-80b-a3b-thinking大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-omni-flash-2025-12-01
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-omni-flash-2025-12-01多模态大语言模型,支持视觉、智能体思考、视频、音频能力,65536上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- 多模态模型
|
|
||||||
- vision
|
|
||||||
- agent-thought
|
|
||||||
- video
|
|
||||||
- audio
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-vl-235b-a22b-instruct
|
|
||||||
type: chat
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-vl-235b-a22b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- 多模态模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
- vision
|
|
||||||
- video
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-vl-235b-a22b-thinking
|
|
||||||
type: chat
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-vl-235b-a22b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- 多模态模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
- vision
|
|
||||||
- video
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-vl-30b-a3b-instruct
|
|
||||||
type: chat
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-vl-30b-a3b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- 多模态模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
- vision
|
|
||||||
- video
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-vl-30b-a3b-thinking
|
|
||||||
type: chat
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-vl-30b-a3b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- 多模态模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
- vision
|
|
||||||
- video
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-vl-flash
|
|
||||||
type: chat
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-vl-flash多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- 多模态模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
- vision
|
|
||||||
- video
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-vl-plus-2025-09-23
|
|
||||||
type: chat
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-vl-plus-2025-09-23多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- 多模态模型
|
|
||||||
- vision
|
|
||||||
- agent-thought
|
|
||||||
- video
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwen3-vl-plus
|
|
||||||
type: chat
|
|
||||||
provider: dashscope
|
|
||||||
description: qwen3-vl-plus多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- 多模态模型
|
|
||||||
- vision
|
|
||||||
- agent-thought
|
|
||||||
- video
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwq-32b
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwq-32b大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwq-plus-0305
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwq-plus-0305大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: qwq-plus
|
|
||||||
type: llm
|
|
||||||
provider: dashscope
|
|
||||||
description: qwq-plus大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: dashscope
|
|
||||||
- name: gte-rerank-v2
|
|
||||||
type: rerank
|
|
||||||
provider: dashscope
|
|
||||||
description: gte-rerank-v2重排序模型,4000上下文窗口
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 重排序模型
|
|
||||||
logo: dashscope
|
|
||||||
- name: gte-rerank
|
|
||||||
type: rerank
|
|
||||||
provider: dashscope
|
|
||||||
description: gte-rerank重排序模型,4000上下文窗口
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 重排序模型
|
|
||||||
logo: dashscope
|
|
||||||
- name: multimodal-embedding-v1
|
|
||||||
type: embedding
|
|
||||||
provider: dashscope
|
|
||||||
description: multimodal-embedding-v1多模态嵌入模型,支持视觉能力,8192上下文窗口,最大分块数10
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 嵌入模型
|
|
||||||
- 多模态模型
|
|
||||||
- vision
|
|
||||||
logo: dashscope
|
|
||||||
- name: text-embedding-v1
|
|
||||||
type: embedding
|
|
||||||
provider: dashscope
|
|
||||||
description: text-embedding-v1文本嵌入模型,2048上下文窗口,最大分块数25
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 嵌入模型
|
|
||||||
- 文本嵌入
|
|
||||||
logo: dashscope
|
|
||||||
- name: text-embedding-v2
|
|
||||||
type: embedding
|
|
||||||
provider: dashscope
|
|
||||||
description: text-embedding-v2文本嵌入模型,2048上下文窗口,最大分块数25
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 嵌入模型
|
|
||||||
- 文本嵌入
|
|
||||||
logo: dashscope
|
|
||||||
- name: text-embedding-v3
|
|
||||||
type: embedding
|
|
||||||
provider: dashscope
|
|
||||||
description: text-embedding-v3文本嵌入模型,8192上下文窗口,最大分块数10
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 嵌入模型
|
|
||||||
- 文本嵌入
|
|
||||||
logo: dashscope
|
|
||||||
- name: text-embedding-v4
|
|
||||||
type: embedding
|
|
||||||
provider: dashscope
|
|
||||||
description: text-embedding-v4文本嵌入模型,8192上下文窗口,最大分块数10
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 嵌入模型
|
|
||||||
- 文本嵌入
|
|
||||||
logo: dashscope
|
|
||||||
@@ -1,143 +0,0 @@
|
|||||||
"""模型配置加载器 - 用于将预定义模型批量导入到数据库"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
import yaml
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from app.models.models_model import ModelBase, ModelProvider
|
|
||||||
|
|
||||||
|
|
||||||
def _load_yaml_config(provider: ModelProvider) -> list[dict]:
|
|
||||||
"""从YAML文件加载指定供应商的模型配置"""
|
|
||||||
config_dir = Path(__file__).parent
|
|
||||||
config_file = config_dir / f"{provider.value}_models.yaml"
|
|
||||||
|
|
||||||
if not config_file.exists():
|
|
||||||
return []
|
|
||||||
|
|
||||||
with open(config_file, 'r', encoding='utf-8') as f:
|
|
||||||
data = yaml.safe_load(f)
|
|
||||||
|
|
||||||
# 检查是否需要加载(默认为 true)
|
|
||||||
if not data.get('enabled', True):
|
|
||||||
return []
|
|
||||||
|
|
||||||
return data.get('models', [])
|
|
||||||
|
|
||||||
|
|
||||||
def _disable_yaml_config(provider: ModelProvider) -> None:
|
|
||||||
"""将YAML文件的enabled标志设置为false"""
|
|
||||||
config_dir = Path(__file__).parent
|
|
||||||
config_file = config_dir / f"{provider.value}_models.yaml"
|
|
||||||
|
|
||||||
if not config_file.exists():
|
|
||||||
return
|
|
||||||
|
|
||||||
with open(config_file, 'r', encoding='utf-8') as f:
|
|
||||||
data = yaml.safe_load(f)
|
|
||||||
|
|
||||||
data['enabled'] = False
|
|
||||||
|
|
||||||
with open(config_file, 'w', encoding='utf-8') as f:
|
|
||||||
yaml.dump(data, f, allow_unicode=True, sort_keys=False)
|
|
||||||
|
|
||||||
|
|
||||||
def load_models(db: Session, providers: list[str] = None, silent: bool = False) -> dict:
|
|
||||||
"""
|
|
||||||
加载模型配置到数据库
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: 数据库会话
|
|
||||||
providers: 要加载的供应商列表,None表示加载所有
|
|
||||||
silent: 是否静默模式(不输出详细日志)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: 加载结果统计 {"success": int, "skipped": int, "failed": int}
|
|
||||||
"""
|
|
||||||
result = {"success": 0, "skipped": 0, "failed": 0}
|
|
||||||
|
|
||||||
# 确定要加载的供应商
|
|
||||||
if providers:
|
|
||||||
target_providers = [ModelProvider(p) if isinstance(p, str) else p for p in providers]
|
|
||||||
else:
|
|
||||||
target_providers = [p for p in ModelProvider if p != ModelProvider.COMPOSITE]
|
|
||||||
|
|
||||||
for provider in target_providers:
|
|
||||||
# 从YAML文件加载模型配置
|
|
||||||
models = _load_yaml_config(provider)
|
|
||||||
|
|
||||||
if not models:
|
|
||||||
if not silent:
|
|
||||||
print(f"警告: 供应商 '{provider.value}' 暂无预定义模型")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not silent:
|
|
||||||
print(f"\n正在加载 {provider.value} 的 {len(models)} 个模型...")
|
|
||||||
|
|
||||||
# provider_success = 0
|
|
||||||
for model_data in models:
|
|
||||||
try:
|
|
||||||
# 检查模型是否已存在
|
|
||||||
existing = db.query(ModelBase).filter(
|
|
||||||
ModelBase.name == model_data["name"],
|
|
||||||
ModelBase.provider == model_data["provider"]
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if existing:
|
|
||||||
# 更新现有模型配置
|
|
||||||
for key, value in model_data.items():
|
|
||||||
setattr(existing, key, value)
|
|
||||||
db.commit()
|
|
||||||
if not silent:
|
|
||||||
print(f"更新成功: {model_data['name']}")
|
|
||||||
result["success"] += 1
|
|
||||||
# provider_success += 1
|
|
||||||
else:
|
|
||||||
# 创建新模型
|
|
||||||
model = ModelBase(**model_data)
|
|
||||||
db.add(model)
|
|
||||||
db.commit()
|
|
||||||
if not silent:
|
|
||||||
print(f"添加成功: {model_data['name']}")
|
|
||||||
result["success"] += 1
|
|
||||||
# provider_success += 1
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
db.rollback()
|
|
||||||
if not silent:
|
|
||||||
print(f"添加失败: {model_data['name']} - {str(e)}")
|
|
||||||
result["failed"] += 1
|
|
||||||
|
|
||||||
# 如果该供应商的模型全部加载成功,将enabled设置为false
|
|
||||||
# if provider_success == len(models):
|
|
||||||
_disable_yaml_config(provider)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def load_models_by_provider(db: Session, provider: str) -> dict:
|
|
||||||
"""
|
|
||||||
加载指定供应商的模型配置
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: 数据库会话
|
|
||||||
provider: 供应商名称(字符串或ModelProvider枚举)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: 加载结果统计
|
|
||||||
"""
|
|
||||||
provider_enum = ModelProvider(provider) if isinstance(provider, str) else provider
|
|
||||||
return load_models(db, providers=[provider_enum])
|
|
||||||
|
|
||||||
|
|
||||||
def get_available_providers() -> list[Callable[[], str]]:
|
|
||||||
"""获取所有可用的供应商列表(从ModelProvider枚举获取,排除COMPOSITE)"""
|
|
||||||
return [p.value for p in ModelProvider if p != ModelProvider.COMPOSITE]
|
|
||||||
|
|
||||||
|
|
||||||
def get_models_by_provider(provider: str) -> list[dict]:
|
|
||||||
"""获取指定供应商的模型配置列表"""
|
|
||||||
provider_enum = ModelProvider(provider) if isinstance(provider, str) else provider
|
|
||||||
return _load_yaml_config(provider_enum)
|
|
||||||
@@ -1,294 +0,0 @@
|
|||||||
provider: openai
|
|
||||||
enabled: true
|
|
||||||
models:
|
|
||||||
- name: chatgpt-4o-latest
|
|
||||||
type: llm
|
|
||||||
provider: openai
|
|
||||||
description: chatgpt-4o-latest大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
- vision
|
|
||||||
logo: openai
|
|
||||||
- name: gpt-3.5-turbo-0125
|
|
||||||
type: llm
|
|
||||||
provider: openai
|
|
||||||
description: gpt-3.5-turbo-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: openai
|
|
||||||
- name: gpt-3.5-turbo-1106
|
|
||||||
type: llm
|
|
||||||
provider: openai
|
|
||||||
description: gpt-3.5-turbo-1106大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: openai
|
|
||||||
- name: gpt-3.5-turbo-16k
|
|
||||||
type: llm
|
|
||||||
provider: openai
|
|
||||||
description: gpt-3.5-turbo-16k大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: openai
|
|
||||||
- name: gpt-3.5-turbo-instruct
|
|
||||||
type: llm
|
|
||||||
provider: openai
|
|
||||||
description: gpt-3.5-turbo-instruct大语言模型,4096上下文窗口,文本补全模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
logo: openai
|
|
||||||
- name: gpt-3.5-turbo
|
|
||||||
type: llm
|
|
||||||
provider: openai
|
|
||||||
description: gpt-3.5-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: openai
|
|
||||||
- name: gpt-4-0125-preview
|
|
||||||
type: llm
|
|
||||||
provider: openai
|
|
||||||
description: gpt-4-0125-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: openai
|
|
||||||
- name: gpt-4-1106-preview
|
|
||||||
type: llm
|
|
||||||
provider: openai
|
|
||||||
description: gpt-4-1106-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: openai
|
|
||||||
- name: gpt-4-turbo-2024-04-09
|
|
||||||
type: llm
|
|
||||||
provider: openai
|
|
||||||
description: gpt-4-turbo-2024-04-09大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
- vision
|
|
||||||
logo: openai
|
|
||||||
- name: gpt-4-turbo-preview
|
|
||||||
type: llm
|
|
||||||
provider: openai
|
|
||||||
description: gpt-4-turbo-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
logo: openai
|
|
||||||
- name: gpt-4-turbo
|
|
||||||
type: llm
|
|
||||||
provider: openai
|
|
||||||
description: gpt-4-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
- vision
|
|
||||||
logo: openai
|
|
||||||
- name: o1-preview
|
|
||||||
type: llm
|
|
||||||
provider: openai
|
|
||||||
description: o1-preview大语言模型,支持智能体思考,128000上下文窗口,对话模式,已废弃
|
|
||||||
is_deprecated: true
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
logo: openai
|
|
||||||
- name: o1
|
|
||||||
type: llm
|
|
||||||
provider: openai
|
|
||||||
description: o1大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- multi-tool-call
|
|
||||||
- agent-thought
|
|
||||||
- stream-tool-call
|
|
||||||
- vision
|
|
||||||
- structured-output
|
|
||||||
logo: openai
|
|
||||||
- name: o3-2025-04-16
|
|
||||||
type: llm
|
|
||||||
provider: openai
|
|
||||||
description: o3-2025-04-16大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
- tool-call
|
|
||||||
- vision
|
|
||||||
- stream-tool-call
|
|
||||||
- structured-output
|
|
||||||
logo: openai
|
|
||||||
- name: o3-mini-2025-01-31
|
|
||||||
type: llm
|
|
||||||
provider: openai
|
|
||||||
description: o3-mini-2025-01-31大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
- tool-call
|
|
||||||
- stream-tool-call
|
|
||||||
- structured-output
|
|
||||||
logo: openai
|
|
||||||
- name: o3-mini
|
|
||||||
type: llm
|
|
||||||
provider: openai
|
|
||||||
description: o3-mini大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
- tool-call
|
|
||||||
- stream-tool-call
|
|
||||||
- structured-output
|
|
||||||
logo: openai
|
|
||||||
- name: o3-pro-2025-06-10
|
|
||||||
type: llm
|
|
||||||
provider: openai
|
|
||||||
description: o3-pro-2025-06-10大语言模型,支持智能体思考、工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
- tool-call
|
|
||||||
- vision
|
|
||||||
- structured-output
|
|
||||||
logo: openai
|
|
||||||
- name: o3-pro
|
|
||||||
type: llm
|
|
||||||
provider: openai
|
|
||||||
description: o3-pro大语言模型,支持智能体思考、工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
- tool-call
|
|
||||||
- vision
|
|
||||||
- structured-output
|
|
||||||
logo: openai
|
|
||||||
- name: o3
|
|
||||||
type: llm
|
|
||||||
provider: openai
|
|
||||||
description: o3大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
- vision
|
|
||||||
- tool-call
|
|
||||||
- stream-tool-call
|
|
||||||
- structured-output
|
|
||||||
logo: openai
|
|
||||||
- name: o4-mini-2025-04-16
|
|
||||||
type: llm
|
|
||||||
provider: openai
|
|
||||||
description: o4-mini-2025-04-16大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
- tool-call
|
|
||||||
- vision
|
|
||||||
- stream-tool-call
|
|
||||||
- structured-output
|
|
||||||
logo: openai
|
|
||||||
- name: o4-mini
|
|
||||||
type: llm
|
|
||||||
provider: openai
|
|
||||||
description: o4-mini大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 大语言模型
|
|
||||||
- agent-thought
|
|
||||||
- tool-call
|
|
||||||
- vision
|
|
||||||
- stream-tool-call
|
|
||||||
- structured-output
|
|
||||||
logo: openai
|
|
||||||
- name: text-embedding-3-large
|
|
||||||
type: embedding
|
|
||||||
provider: openai
|
|
||||||
description: text-embedding-3-large文本向量模型,8191上下文窗口,最大分块数32
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 文本向量模型
|
|
||||||
logo: openai
|
|
||||||
- name: text-embedding-3-small
|
|
||||||
type: embedding
|
|
||||||
provider: openai
|
|
||||||
description: text-embedding-3-small文本向量模型,8191上下文窗口,最大分块数32
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 文本向量模型
|
|
||||||
logo: openai
|
|
||||||
- name: text-embedding-ada-002
|
|
||||||
type: embedding
|
|
||||||
provider: openai
|
|
||||||
description: text-embedding-ada-002文本向量模型,8097上下文窗口,最大分块数32
|
|
||||||
is_deprecated: false
|
|
||||||
is_official: true
|
|
||||||
tags:
|
|
||||||
- 文本向量模型
|
|
||||||
logo: openai
|
|
||||||
@@ -28,9 +28,7 @@ from app.core.rag.common.float_utils import get_float
|
|||||||
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
|
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
|
||||||
from app.core.rag.llm.chat_model import Base
|
from app.core.rag.llm.chat_model import Base
|
||||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
def knowledge_retrieval(
|
def knowledge_retrieval(
|
||||||
query: str,
|
query: str,
|
||||||
@@ -64,15 +62,7 @@ def knowledge_retrieval(
|
|||||||
merge_strategy = config.get("merge_strategy", "weight")
|
merge_strategy = config.get("merge_strategy", "weight")
|
||||||
reranker_id = config.get("reranker_id")
|
reranker_id = config.get("reranker_id")
|
||||||
reranker_top_k = config.get("reranker_top_k", 1024)
|
reranker_top_k = config.get("reranker_top_k", 1024)
|
||||||
# use_graph = config.get("use_graph", "false").lower() == "true"
|
use_graph = config.get("use_graph", "false").lower() == "true"
|
||||||
|
|
||||||
use_graph_value = config.get("use_graph", False)
|
|
||||||
if isinstance(use_graph_value, bool):
|
|
||||||
use_graph = use_graph_value
|
|
||||||
elif isinstance(use_graph_value, str):
|
|
||||||
use_graph = use_graph_value.lower() in ("true", "1", "yes")
|
|
||||||
else:
|
|
||||||
use_graph = False
|
|
||||||
|
|
||||||
file_names_filter = []
|
file_names_filter = []
|
||||||
if user_ids:
|
if user_ids:
|
||||||
@@ -169,29 +159,13 @@ def knowledge_retrieval(
|
|||||||
|
|
||||||
# Use the specified reranker for re-ranking
|
# Use the specified reranker for re-ranking
|
||||||
if reranker_id:
|
if reranker_id:
|
||||||
try:
|
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
|
||||||
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
|
# use graph
|
||||||
except Exception as rerank_error:
|
|
||||||
# If reranker fails, log warning and continue with original results
|
|
||||||
logger.warning(
|
|
||||||
"Reranker failed, falling back to original results",
|
|
||||||
extra={
|
|
||||||
"reranker_id": reranker_id,
|
|
||||||
"query": query,
|
|
||||||
"doc_count": len(all_results),
|
|
||||||
"error": str(rerank_error),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if use_graph:
|
if use_graph:
|
||||||
try:
|
from app.core.rag.common.settings import kg_retriever
|
||||||
from app.core.rag.common.settings import kg_retriever
|
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||||
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
if doc:
|
||||||
if doc:
|
all_results.insert(0, doc)
|
||||||
all_results.insert(0, doc)
|
|
||||||
except Exception as graph_error:
|
|
||||||
print(f"Failed to retrieve from knowledge graph: {str(graph_error)}")
|
|
||||||
|
|
||||||
return all_results
|
return all_results
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from app.core.workflow.graph_builder import GraphBuilder, StreamOutputConfig
|
|||||||
from app.core.workflow.nodes import WorkflowState
|
from app.core.workflow.nodes import WorkflowState
|
||||||
from app.core.workflow.nodes.base_config import VariableType
|
from app.core.workflow.nodes.base_config import VariableType
|
||||||
from app.core.workflow.nodes.enums import NodeType
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
|
from app.core.workflow.template_renderer import render_template
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -156,137 +157,12 @@ class WorkflowExecutor:
|
|||||||
"error": result.get("error"),
|
"error": result.get("error"),
|
||||||
}
|
}
|
||||||
|
|
||||||
def _update_scope_activate(self, scope, status=None):
|
def _update_end_activate(self, node_id):
|
||||||
"""
|
|
||||||
Update the activation state of all End nodes based on a completed scope (node or variable).
|
|
||||||
|
|
||||||
Iterates over all End nodes in `self.end_outputs` and calls
|
|
||||||
`update_activate` on each, which may:
|
|
||||||
- Activate variable segments that depend on the completed node/scope.
|
|
||||||
- Activate the entire End node output if all control conditions are met.
|
|
||||||
|
|
||||||
If any End node becomes active and `self.activate_end` is not yet set,
|
|
||||||
this node will be marked as the currently active End node.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scope (str): The node ID or scope that has completed execution.
|
|
||||||
status (str | None): Optional status of the node (used for branch/control nodes).
|
|
||||||
"""
|
|
||||||
for node in self.end_outputs.keys():
|
for node in self.end_outputs.keys():
|
||||||
self.end_outputs[node].update_activate(scope, status)
|
self.end_outputs[node].update_activate(node_id)
|
||||||
if self.end_outputs[node].activate and self.activate_end is None:
|
if self.end_outputs[node].activate and self.activate_end is None:
|
||||||
self.activate_end = node
|
self.activate_end = node
|
||||||
|
|
||||||
def _update_stream_output_status(self, activate, data):
|
|
||||||
"""
|
|
||||||
Update the stream output state of End nodes based on workflow state updates.
|
|
||||||
|
|
||||||
This method checks which nodes/scopes are activated and propagates
|
|
||||||
activation to End nodes accordingly.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
activate (dict): Mapping of node_id -> bool indicating which nodes/scopes are activated.
|
|
||||||
data (dict): Mapping of node_id -> node runtime data, including outputs.
|
|
||||||
|
|
||||||
Behavior:
|
|
||||||
For each node in `data`:
|
|
||||||
1. If the node is activated (`activate[node_id]` is True),
|
|
||||||
retrieve its output status from `runtime_vars`.
|
|
||||||
2. Call `_update_scope_activate` to propagate the activation
|
|
||||||
to all relevant End nodes and update `self.activate_end`.
|
|
||||||
"""
|
|
||||||
for node_id in data.keys():
|
|
||||||
if activate.get(node_id):
|
|
||||||
node_output_status = (
|
|
||||||
data[node_id]
|
|
||||||
.get('runtime_vars', {})
|
|
||||||
.get(node_id)
|
|
||||||
.get("output")
|
|
||||||
)
|
|
||||||
self._update_scope_activate(node_id, status=node_output_status)
|
|
||||||
|
|
||||||
async def _emit_active_chunks(
|
|
||||||
self,
|
|
||||||
node_outputs: dict,
|
|
||||||
variables: dict,
|
|
||||||
force=False
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Process and yield all currently active output segments for the currently active End node.
|
|
||||||
|
|
||||||
This method handles stream-mode output for an End node by iterating through its output segments
|
|
||||||
(`OutputContent`). Only segments marked as active (`activate=True`) are processed, unless
|
|
||||||
`force=True`, which allows all segments to be processed regardless of their activation state.
|
|
||||||
|
|
||||||
Behavior:
|
|
||||||
1. Iterates from the current `cursor` position to the end of the outputs list.
|
|
||||||
2. For each segment:
|
|
||||||
- If the segment is literal text (`is_variable=False`), append it directly.
|
|
||||||
- If the segment is a variable (`is_variable=True`), evaluate it using
|
|
||||||
`evaluate_expression` with the given `node_outputs` and `variables`,
|
|
||||||
then transform the result with `_trans_output_string`.
|
|
||||||
3. Yield a stream event of type "message" containing the processed chunk.
|
|
||||||
4. Move the `cursor` forward after processing each segment.
|
|
||||||
5. When all segments have been processed, remove this End node from `end_outputs`
|
|
||||||
and reset `activate_end` to None.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node_outputs (dict): Current runtime node outputs, used for variable evaluation.
|
|
||||||
variables (dict): Current runtime variables, used for variable evaluation.
|
|
||||||
force (bool, default=False): If True, process segments even if `activate=False`.
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
dict: A stream event of type "message" containing the processed chunk.
|
|
||||||
|
|
||||||
Notes:
|
|
||||||
- Segments that fail evaluation (ValueError) are skipped with a warning logged.
|
|
||||||
- This method only processes the currently active End node (`self.activate_end`).
|
|
||||||
- Use `force=True` for final emission regardless of activation state.
|
|
||||||
"""
|
|
||||||
|
|
||||||
end_info = self.end_outputs[self.activate_end]
|
|
||||||
|
|
||||||
while end_info.cursor < len(end_info.outputs):
|
|
||||||
final_chunk = ''
|
|
||||||
current_segment = end_info.outputs[end_info.cursor]
|
|
||||||
|
|
||||||
if not current_segment.activate and not force:
|
|
||||||
# Stop processing until this segment becomes active
|
|
||||||
break
|
|
||||||
|
|
||||||
# Literal segment
|
|
||||||
if not current_segment.is_variable:
|
|
||||||
final_chunk += current_segment.literal
|
|
||||||
else:
|
|
||||||
# Variable segment: evaluate and transform
|
|
||||||
try:
|
|
||||||
chunk = evaluate_expression(
|
|
||||||
current_segment.literal,
|
|
||||||
variables=variables,
|
|
||||||
node_outputs=node_outputs
|
|
||||||
)
|
|
||||||
chunk = self._trans_output_string(chunk)
|
|
||||||
final_chunk += chunk
|
|
||||||
except ValueError:
|
|
||||||
# Log failed evaluation but continue streaming
|
|
||||||
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}")
|
|
||||||
|
|
||||||
if final_chunk:
|
|
||||||
yield {
|
|
||||||
"event": "message",
|
|
||||||
"data": {
|
|
||||||
"chunk": final_chunk
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Advance cursor after processing
|
|
||||||
end_info.cursor += 1
|
|
||||||
|
|
||||||
# Remove End node from active tracking if all segments have been processed
|
|
||||||
if end_info.cursor >= len(end_info.outputs):
|
|
||||||
self.end_outputs.pop(self.activate_end)
|
|
||||||
self.activate_end = None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _trans_output_string(content):
|
def _trans_output_string(content):
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
@@ -342,8 +218,14 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
|
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
|
||||||
full_content = ''
|
full_content = ''
|
||||||
for end_id in self.end_outputs.keys():
|
for end_info in self.end_outputs.values():
|
||||||
full_content += result.get('runtime_vars', {}).get(end_id, {}).get('output', '')
|
output_template = "".join([output.literal for output in end_info.outputs])
|
||||||
|
full_content += render_template(
|
||||||
|
output_template,
|
||||||
|
result.get("variables", {}),
|
||||||
|
result.get("runtime_vars", {}),
|
||||||
|
strict=False
|
||||||
|
)
|
||||||
result["messages"].extend(
|
result["messages"].extend(
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
@@ -424,7 +306,7 @@ class WorkflowExecutor:
|
|||||||
try:
|
try:
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
full_content = ''
|
full_content = ''
|
||||||
self._update_scope_activate("sys")
|
|
||||||
async for event in graph.astream(
|
async for event in graph.astream(
|
||||||
initial_state,
|
initial_state,
|
||||||
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
|
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
|
||||||
@@ -451,12 +333,9 @@ class WorkflowExecutor:
|
|||||||
if not end_info or end_info.cursor >= len(end_info.outputs):
|
if not end_info or end_info.cursor >= len(end_info.outputs):
|
||||||
continue
|
continue
|
||||||
current_output = end_info.outputs[end_info.cursor]
|
current_output = end_info.outputs[end_info.cursor]
|
||||||
if current_output.is_variable and current_output.depends_on_scope(node_id):
|
if current_output.is_variable and current_output.depends_on_node(node_id):
|
||||||
if data.get("done"):
|
if data.get("done"):
|
||||||
end_info.cursor += 1
|
end_info.cursor += 1
|
||||||
if end_info.cursor >= len(end_info.outputs):
|
|
||||||
self.end_outputs.pop(self.activate_end)
|
|
||||||
self.activate_end = None
|
|
||||||
else:
|
else:
|
||||||
full_content += data.get("chunk")
|
full_content += data.get("chunk")
|
||||||
yield {
|
yield {
|
||||||
@@ -536,53 +415,91 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
elif mode == "updates":
|
elif mode == "updates":
|
||||||
# Handle state updates - store final state
|
# Handle state updates - store final state
|
||||||
state = graph.get_state(config=self.checkpoint_config).values
|
for node_id in data.keys():
|
||||||
node_outputs = state.get("runtime_vars", {})
|
self._update_end_activate(node_id)
|
||||||
variables = state.get("variables", {})
|
|
||||||
activate = state.get("activate", {})
|
|
||||||
for _, node_data in data.items():
|
|
||||||
node_outputs |= node_data.get("runtime_vars", {})
|
|
||||||
variables |= node_data.get("variables", {})
|
|
||||||
|
|
||||||
self._update_stream_output_status(activate, data)
|
|
||||||
wait = False
|
wait = False
|
||||||
while self.activate_end and not wait:
|
state = graph.get_state(config=self.checkpoint_config)
|
||||||
async for msg_event in self._emit_active_chunks(
|
node_outputs = state.values.get("runtime_vars", {})
|
||||||
node_outputs=node_outputs,
|
for _ in data.keys():
|
||||||
variables=variables
|
node_outputs = node_outputs | data.get(_).get("runtime_vars", {})
|
||||||
):
|
|
||||||
full_content += msg_event["data"]['chunk']
|
|
||||||
yield msg_event
|
|
||||||
|
|
||||||
if self.activate_end:
|
while self.activate_end and not wait:
|
||||||
|
message = ''
|
||||||
|
logger.info(self.activate_end)
|
||||||
|
end_info = self.end_outputs[self.activate_end]
|
||||||
|
content = end_info.outputs[end_info.cursor]
|
||||||
|
while content.activate:
|
||||||
|
if not content.is_variable:
|
||||||
|
full_content += content.literal
|
||||||
|
message += content.literal
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
chunk = evaluate_expression(
|
||||||
|
content.literal,
|
||||||
|
variables={},
|
||||||
|
node_outputs=node_outputs
|
||||||
|
)
|
||||||
|
chunk = self._trans_output_string(chunk)
|
||||||
|
message += chunk
|
||||||
|
full_content += chunk
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
end_info.cursor += 1
|
||||||
|
if end_info.cursor == len(end_info.outputs):
|
||||||
|
break
|
||||||
|
content = end_info.outputs[end_info.cursor]
|
||||||
|
if end_info.cursor != len(end_info.outputs):
|
||||||
wait = True
|
wait = True
|
||||||
else:
|
else:
|
||||||
self._update_stream_output_status(activate, data)
|
self.end_outputs.pop(self.activate_end)
|
||||||
|
self.activate_end = None
|
||||||
|
for node_id in data.keys():
|
||||||
|
self._update_end_activate(node_id)
|
||||||
|
if message:
|
||||||
|
yield {
|
||||||
|
"event": "message",
|
||||||
|
"data": {
|
||||||
|
"chunk": message
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
|
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
|
||||||
f"- execution_id: {self.execution_id}")
|
f"- execution_id: {self.execution_id}")
|
||||||
|
|
||||||
result = graph.get_state(self.checkpoint_config).values
|
result = graph.get_state(self.checkpoint_config).values
|
||||||
node_outputs = result.get("runtime_vars", {})
|
while self.activate_end:
|
||||||
variables = result.get("variables", {})
|
message = ''
|
||||||
self.end_outputs = {
|
end_info = self.end_outputs[self.activate_end]
|
||||||
node_id: node_info
|
content = end_info.outputs[end_info.cursor]
|
||||||
for node_id, node_info in self.end_outputs.items()
|
if not content.is_variable:
|
||||||
if node_info.activate
|
message += content.literal
|
||||||
}
|
else:
|
||||||
|
node_outputs = result.get("runtime_vars", {})
|
||||||
if self.end_outputs or self.activate_end:
|
variables = result.get("variables", {})
|
||||||
while self.activate_end:
|
try:
|
||||||
async for msg_event in self._emit_active_chunks(
|
chunk = evaluate_expression(
|
||||||
node_outputs=node_outputs,
|
content.literal,
|
||||||
variables=variables,
|
variables=variables,
|
||||||
force=True
|
node_outputs=node_outputs
|
||||||
):
|
)
|
||||||
full_content += msg_event["data"]['chunk']
|
chunk = self._trans_output_string(chunk)
|
||||||
yield msg_event
|
message += chunk
|
||||||
|
full_content += chunk
|
||||||
if not self.activate_end and self.end_outputs:
|
except ValueError:
|
||||||
|
pass
|
||||||
|
end_info.cursor += 1
|
||||||
|
if end_info.cursor == len(end_info.outputs):
|
||||||
|
self.end_outputs.pop(self.activate_end)
|
||||||
|
self.activate_end = None
|
||||||
|
if self.end_outputs:
|
||||||
self.activate_end = list(self.end_outputs.keys())[0]
|
self.activate_end = list(self.end_outputs.keys())[0]
|
||||||
|
if message:
|
||||||
|
yield {
|
||||||
|
"event": "message",
|
||||||
|
"data": {
|
||||||
|
"chunk": message
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
# 计算耗时
|
# 计算耗时
|
||||||
end_time = datetime.datetime.now()
|
end_time = datetime.datetime.now()
|
||||||
|
|||||||
@@ -53,110 +53,114 @@ class OutputContent(BaseModel):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def depends_on_scope(self, scope: str) -> bool:
|
def depends_on_node(self, node_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if this segment depends on a given scope.
|
Check if this output segment depends on a specific node's variable.
|
||||||
|
|
||||||
|
This method examines the `literal` of the output segment to see if it
|
||||||
|
contains a variable placeholder referencing the given node in the form:
|
||||||
|
|
||||||
|
{{ node_id.field_name }}
|
||||||
|
|
||||||
|
It uses a regular expression to match the exact node ID, avoiding
|
||||||
|
false positives from substring matches (e.g., 'node1' should not match 'node10').
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
scope (str): Node ID or special variable prefix (e.g., "sys").
|
node_id (str): The ID of the node to check for in this segment's variable placeholders.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if this segment references the given scope.
|
bool:
|
||||||
|
- True if the segment contains a variable referencing the given node.
|
||||||
|
- False otherwise.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
literal = "{{node1.name}}"
|
||||||
|
|
||||||
|
depends_on_node("node1") -> True
|
||||||
|
depends_on_node("node2") -> False
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
This method is primarily used in stream mode to determine whether
|
||||||
|
a particular variable output segment should be activated when a
|
||||||
|
specific upstream node completes execution.
|
||||||
"""
|
"""
|
||||||
pattern = rf"\{{\{{\s*{re.escape(scope)}\.[a-zA-Z0-9_]+\s*\}}\}}"
|
variable_pattern = rf"\{{\{{\s*{re.escape(node_id)}\.[a-zA-Z0-9_]+\s*\}}\}}"
|
||||||
return bool(re.search(pattern, self.literal))
|
pattern = re.compile(variable_pattern)
|
||||||
|
match = pattern.search(self.literal)
|
||||||
|
if match:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class StreamOutputConfig(BaseModel):
|
class StreamOutputConfig(BaseModel):
|
||||||
"""
|
"""
|
||||||
Streaming output configuration for an End node.
|
Streaming output configuration for an End node.
|
||||||
|
|
||||||
This configuration describes how the End node output behaves in streaming mode,
|
This structure controls:
|
||||||
including:
|
- whether the End node output is globally active
|
||||||
- whether output emission is globally activated
|
- which upstream branch nodes are responsible for activation
|
||||||
- which upstream branch/control nodes gate the activation
|
- how each output segment behaves in streaming mode
|
||||||
- how each parsed output segment is streamed and activated
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
activate: bool = Field(
|
activate: bool = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Global activation flag for the End node output.\n"
|
"Global activation state of the End node output.\n"
|
||||||
"When False, output segments should not be emitted even if available.\n"
|
"If False, no output should be emitted until all control nodes are resolved."
|
||||||
"This flag typically becomes True once required control branch conditions "
|
|
||||||
"are satisfied."
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
control_nodes: dict[str, str] = Field(
|
control_nodes: list[str] = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Control branch conditions for this End node output.\n"
|
"List of upstream branch node IDs that control this End node.\n"
|
||||||
"Mapping of `branch_node_id -> expected_branch_label`.\n"
|
"Each node must signal completion before output becomes active."
|
||||||
"The End node output becomes globally active when a controlling branch node "
|
|
||||||
"reports a matching completion status."
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs: list[OutputContent] = Field(
|
outputs: list[OutputContent] = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description="Ordered list of output segments parsed from the output template."
|
||||||
"Ordered list of output segments parsed from the output template.\n"
|
|
||||||
"Each segment represents either a literal text block or a variable placeholder "
|
|
||||||
"that may be activated independently."
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
cursor: int = Field(
|
cursor: int = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Streaming cursor index.\n"
|
"Streaming cursor index.\n"
|
||||||
"Indicates the next output segment index to be emitted.\n"
|
"Indicates how many output segments have already been emitted."
|
||||||
"Segments with index < cursor are considered already streamed."
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_activate(self, scope: str, status=None):
|
def update_activate(self, node_id):
|
||||||
"""
|
"""
|
||||||
Update streaming activation state based on an upstream node or special variable.
|
Update activation state based on an upstream node completion.
|
||||||
|
|
||||||
Args:
|
This method is typically called when a branch/control node finishes execution.
|
||||||
scope (str):
|
|
||||||
Identifier of the completed upstream entity.
|
|
||||||
- If a control branch node, it should match a key in `control_nodes`.
|
|
||||||
- If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
|
|
||||||
status (optional):
|
|
||||||
Completion status of the control branch node.
|
|
||||||
Required when `scope` refers to a control node.
|
|
||||||
|
|
||||||
Behavior:
|
Behavior:
|
||||||
1. Control branch nodes:
|
1. If the node is a control node:
|
||||||
- If `scope` matches a key in `control_nodes` and `status` matches the expected
|
- Remove it from `control_nodes`
|
||||||
branch label, the End node output becomes globally active (`activate = True`).
|
- If all control nodes are resolved, activate the entire output
|
||||||
|
|
||||||
2. Variable output segments:
|
2. Activate variable output segments that depend on this node:
|
||||||
- For each segment that is a variable (`is_variable=True`):
|
- If an output segment is a variable
|
||||||
- If the segment literal references `scope`, mark the segment as active.
|
- And its literal references the completed node_id
|
||||||
- This applies both to regular node variables (e.g., "node_id.field")
|
- Mark that segment as active
|
||||||
and special system variables (e.g., "sys.xxx").
|
|
||||||
|
|
||||||
Notes:
|
|
||||||
- This method does not emit output or advance the streaming cursor.
|
|
||||||
- It only updates activation flags based on upstream events or special variables.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Case 1: resolve control branch dependency
|
# Case 1: resolve control branch dependency
|
||||||
if scope in self.control_nodes.keys():
|
if node_id in self.control_nodes:
|
||||||
if status is None:
|
self.control_nodes.remove(node_id)
|
||||||
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
|
||||||
if status == self.control_nodes[scope]:
|
# All branch constraints resolved → enable output
|
||||||
|
if not self.control_nodes:
|
||||||
self.activate = True
|
self.activate = True
|
||||||
|
|
||||||
# Case 2: activate variable segments related to this node
|
# Case 2: activate variable segments related to this node
|
||||||
for i in range(len(self.outputs)):
|
for i in range(len(self.outputs)):
|
||||||
if (
|
if (
|
||||||
self.outputs[i].is_variable
|
self.outputs[i].is_variable
|
||||||
and self.outputs[i].depends_on_scope(scope)
|
and self.outputs[i].depends_on_node(node_id)
|
||||||
):
|
):
|
||||||
self.outputs[i].activate = True
|
self.outputs[i].activate = True
|
||||||
|
|
||||||
@@ -180,11 +184,11 @@ class GraphBuilder:
|
|||||||
self._find_upstream_branch_node = lru_cache(
|
self._find_upstream_branch_node = lru_cache(
|
||||||
maxsize=len(self.nodes) * 2
|
maxsize=len(self.nodes) * 2
|
||||||
)(self._find_upstream_branch_node)
|
)(self._find_upstream_branch_node)
|
||||||
|
self._analyze_end_node_output()
|
||||||
|
|
||||||
self.graph = StateGraph(WorkflowState)
|
self.graph = StateGraph(WorkflowState)
|
||||||
self.add_nodes()
|
self.add_nodes()
|
||||||
self.add_edges()
|
self.add_edges()
|
||||||
self._analyze_end_node_output()
|
|
||||||
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -212,53 +216,30 @@ class GraphBuilder:
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
raise RuntimeError(f"Node not found: Id={node_id}")
|
raise RuntimeError(f"Node not found: Id={node_id}")
|
||||||
|
|
||||||
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[tuple[str, str]]]:
|
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[str]]:
|
||||||
"""
|
"""Find upstream branch nodes for a given target node in the workflow graph.
|
||||||
Recursively find all upstream branch (control) nodes that influence the execution
|
|
||||||
of the given target node.
|
|
||||||
|
|
||||||
This method walks upstream along the workflow graph starting from `target_node`.
|
This method identifies all upstream control (branch) nodes that can affect
|
||||||
It distinguishes between:
|
the execution of `target_node`. If `target_node` is reachable from a start
|
||||||
- branch nodes (node types listed in `BRANCH_NODES`)
|
node (i.e., a node with no upstream nodes), the method returns an empty tuple.
|
||||||
- non-branch nodes (ordinary processing nodes)
|
|
||||||
|
|
||||||
Traversal rules:
|
The function distinguishes between branch nodes (defined in `BRANCH_NODES`)
|
||||||
1. For each immediate upstream node:
|
and non-branch nodes, recursively traversing upstream through non-branch
|
||||||
- If it is a branch node, it is recorded as an affecting control node.
|
nodes. If any non-branch upstream path does not lead to a branch node,
|
||||||
- If it is a non-branch node, the traversal continues recursively upstream.
|
the result will indicate that no valid upstream branch node exists.
|
||||||
2. If ANY upstream path reaches a START / CYCLE_START node without encountering
|
|
||||||
a branch node, the traversal is considered invalid:
|
|
||||||
- `has_branch` will be False
|
|
||||||
- no branch nodes are returned.
|
|
||||||
3. Only when ALL upstream non-branch paths eventually lead to at least one
|
|
||||||
branch node will `has_branch` be True.
|
|
||||||
|
|
||||||
Special case:
|
|
||||||
- If `target_node` has no upstream nodes AND its type is START or CYCLE_START,
|
|
||||||
it is considered directly reachable from the workflow entry, and therefore
|
|
||||||
has no controlling branch nodes.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target_node (str):
|
target_node (str): The identifier of the target node.
|
||||||
The identifier of the node whose upstream control branches
|
|
||||||
are to be resolved.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[bool, tuple[tuple[str, str]]]:
|
tuple[bool, tuple[str]]:
|
||||||
- has_branch (bool):
|
- has_branch (bool): True if all upstream non-branch paths lead to at least
|
||||||
True if every upstream path from `target_node` encounters
|
one branch node; False if any path reaches a start node without a branch.
|
||||||
at least one branch node.
|
- branch_nodes (tuple[str]): A deduplicated tuple of upstream branch node IDs
|
||||||
False if any path reaches a start node without a branch.
|
affecting `target_node`. Returns an empty tuple if `has_branch` is False.
|
||||||
- branch_nodes (tuple[tuple[str, str]]):
|
|
||||||
A deduplicated tuple of `(branch_node_id, branch_label)` pairs
|
|
||||||
representing all branch nodes that can influence `target_node`.
|
|
||||||
Returns an empty tuple if `has_branch` is False.
|
|
||||||
"""
|
"""
|
||||||
source_nodes = [
|
source_nodes = [
|
||||||
{
|
edge.get("source")
|
||||||
"id": edge.get("source"),
|
|
||||||
"branch": edge.get("label")
|
|
||||||
}
|
|
||||||
for edge in self.edges
|
for edge in self.edges
|
||||||
if edge.get("target") == target_node
|
if edge.get("target") == target_node
|
||||||
]
|
]
|
||||||
@@ -268,13 +249,11 @@ class GraphBuilder:
|
|||||||
branch_nodes = []
|
branch_nodes = []
|
||||||
non_branch_nodes = []
|
non_branch_nodes = []
|
||||||
|
|
||||||
for node_info in source_nodes:
|
for node_id in source_nodes:
|
||||||
if self.get_node_type(node_info["id"]) in BRANCH_NODES:
|
if self.get_node_type(node_id) in BRANCH_NODES:
|
||||||
branch_nodes.append(
|
branch_nodes.append(node_id)
|
||||||
(node_info["id"], node_info["branch"])
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
non_branch_nodes.append(node_info["id"])
|
non_branch_nodes.append(node_id)
|
||||||
|
|
||||||
has_branch = True
|
has_branch = True
|
||||||
for node_id in non_branch_nodes:
|
for node_id in non_branch_nodes:
|
||||||
@@ -355,7 +334,7 @@ class GraphBuilder:
|
|||||||
activate=not has_branch,
|
activate=not has_branch,
|
||||||
|
|
||||||
# Branch nodes that control activation of this End node
|
# Branch nodes that control activation of this End node
|
||||||
control_nodes=dict(control_nodes),
|
control_nodes=list(control_nodes),
|
||||||
|
|
||||||
# Convert output segments into OutputContent objects
|
# Convert output segments into OutputContent objects
|
||||||
outputs=list(
|
outputs=list(
|
||||||
@@ -383,7 +362,7 @@ class GraphBuilder:
|
|||||||
else:
|
else:
|
||||||
self.end_node_map[end_node_id] = StreamOutputConfig(
|
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||||
activate=True,
|
activate=True,
|
||||||
control_nodes={},
|
control_nodes=[],
|
||||||
outputs=list(
|
outputs=list(
|
||||||
[
|
[
|
||||||
OutputContent(
|
OutputContent(
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ class CodeNodeConfig(BaseNodeConfig):
|
|||||||
description="code content"
|
description="code content"
|
||||||
)
|
)
|
||||||
|
|
||||||
language: Literal['python3', 'javascript'] = Field(
|
language: Literal['python3', 'nodejs'] = Field(
|
||||||
...,
|
...,
|
||||||
description="language"
|
description="language"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import base64
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import urllib.parse
|
|
||||||
from string import Template
|
from string import Template
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -15,7 +14,7 @@ from app.core.workflow.nodes.code.config import CodeNodeConfig
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
PYTHON_SCRIPT_TEMPLATE = Template(dedent("""
|
SCRIPT_TEMPLATE = Template(dedent("""
|
||||||
$code
|
$code
|
||||||
|
|
||||||
import json
|
import json
|
||||||
@@ -33,20 +32,6 @@ result = "<<RESULT>>" + output_json + "<<RESULT>>"
|
|||||||
print(result)
|
print(result)
|
||||||
"""))
|
"""))
|
||||||
|
|
||||||
NODEJS_SCRIPT_TEMPLATE = Template(dedent("""
|
|
||||||
$code
|
|
||||||
// decode and prepare input object
|
|
||||||
var inputs_obj = JSON.parse(Buffer.from('$inputs_variable', 'base64').toString('utf-8'))
|
|
||||||
|
|
||||||
// execute main function
|
|
||||||
var output_obj = main(inputs_obj)
|
|
||||||
|
|
||||||
// convert output to json and print
|
|
||||||
var output_json = JSON.stringify(output_obj)
|
|
||||||
var result = `<<RESULT>>$${output_json}<<RESULT>>`
|
|
||||||
console.log(result)
|
|
||||||
"""))
|
|
||||||
|
|
||||||
|
|
||||||
class CodeNode(BaseNode):
|
class CodeNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
@@ -98,27 +83,18 @@ class CodeNode(BaseNode):
|
|||||||
input_variable_dict = {}
|
input_variable_dict = {}
|
||||||
for input_variable in self.typed_config.input_variables:
|
for input_variable in self.typed_config.input_variables:
|
||||||
input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, state)
|
input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, state)
|
||||||
|
|
||||||
code = base64.b64decode(
|
code = base64.b64decode(
|
||||||
self.typed_config.code
|
self.typed_config.code
|
||||||
).decode("utf-8")
|
).decode("utf-8")
|
||||||
code = urllib.parse.unquote(code, encoding='utf-8')
|
|
||||||
|
|
||||||
input_variable_dict = base64.b64encode(
|
input_variable_dict = base64.b64encode(
|
||||||
json.dumps(input_variable_dict).encode("utf-8")
|
json.dumps(input_variable_dict).encode("utf-8")
|
||||||
).decode("utf-8")
|
).decode("utf-8")
|
||||||
if self.typed_config.language == "python3":
|
|
||||||
final_script = PYTHON_SCRIPT_TEMPLATE.substitute(
|
final_script = SCRIPT_TEMPLATE.substitute(
|
||||||
code=code,
|
code=code,
|
||||||
inputs_variable=input_variable_dict,
|
inputs_variable=input_variable_dict,
|
||||||
)
|
)
|
||||||
elif self.typed_config.language == 'javascript':
|
|
||||||
final_script = NODEJS_SCRIPT_TEMPLATE.substitute(
|
|
||||||
code=code,
|
|
||||||
inputs_variable=input_variable_dict,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported language: {self.typed_config.language}")
|
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
|
|||||||
@@ -25,6 +25,6 @@ class MemoryWriteNodeConfig(BaseNodeConfig):
|
|||||||
...
|
...
|
||||||
)
|
)
|
||||||
|
|
||||||
config_id: UUID | int = Field(
|
config_id: UUID = Field(
|
||||||
...
|
...
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -36,10 +36,9 @@ class MemoryReadNode(BaseNode):
|
|||||||
class MemoryWriteNode(BaseNode):
|
class MemoryWriteNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config: MemoryWriteNodeConfig | None = None
|
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> Any:
|
async def execute(self, state: WorkflowState) -> Any:
|
||||||
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
|
||||||
end_user_id = self.get_variable("sys.user_id", state)
|
end_user_id = self.get_variable("sys.user_id", state)
|
||||||
|
|
||||||
if not end_user_id:
|
if not end_user_id:
|
||||||
|
|||||||
@@ -23,18 +23,6 @@ class ParameterExtractorNode(BaseNode):
|
|||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config: ParameterExtractorNodeConfig | None = None
|
self.typed_config: ParameterExtractorNodeConfig | None = None
|
||||||
self.response_metadata = {}
|
|
||||||
|
|
||||||
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
|
||||||
if self.response_metadata:
|
|
||||||
usage = self.response_metadata.get('token_usage')
|
|
||||||
if usage:
|
|
||||||
return {
|
|
||||||
"prompt_tokens": usage.get('prompt_tokens', 0),
|
|
||||||
"completion_tokens": usage.get('completion_tokens', 0),
|
|
||||||
"total_tokens": usage.get('total_tokens', 0)
|
|
||||||
}
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_prompt():
|
def _get_prompt():
|
||||||
@@ -183,7 +171,6 @@ class ParameterExtractorNode(BaseNode):
|
|||||||
])
|
])
|
||||||
|
|
||||||
model_resp = await llm.ainvoke(messages)
|
model_resp = await llm.ainvoke(messages)
|
||||||
self.response_metadata = model_resp.response_metadata
|
|
||||||
result = json_repair.repair_json(model_resp.content, return_objects=True)
|
result = json_repair.repair_json(model_resp.content, return_objects=True)
|
||||||
logger.info(f"node: {self.node_id} get params:{result}")
|
logger.info(f"node: {self.node_id} get params:{result}")
|
||||||
|
|
||||||
|
|||||||
@@ -23,18 +23,6 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config: QuestionClassifierNodeConfig | None = None
|
self.typed_config: QuestionClassifierNodeConfig | None = None
|
||||||
self.category_to_case_map = {}
|
self.category_to_case_map = {}
|
||||||
self.response_metadata = {}
|
|
||||||
|
|
||||||
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
|
||||||
if self.response_metadata:
|
|
||||||
usage = self.response_metadata.get('token_usage')
|
|
||||||
if usage:
|
|
||||||
return {
|
|
||||||
"prompt_tokens": usage.get('prompt_tokens', 0),
|
|
||||||
"completion_tokens": usage.get('completion_tokens', 0),
|
|
||||||
"total_tokens": usage.get('total_tokens', 0)
|
|
||||||
}
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _get_llm_instance(self) -> RedBearLLM:
|
def _get_llm_instance(self) -> RedBearLLM:
|
||||||
"""获取LLM实例"""
|
"""获取LLM实例"""
|
||||||
@@ -124,7 +112,6 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
|
|
||||||
response = await llm.ainvoke(messages)
|
response = await llm.ainvoke(messages)
|
||||||
result = response.content.strip()
|
result = response.content.strip()
|
||||||
self.response_metadata = response.response_metadata
|
|
||||||
|
|
||||||
if result in category_names:
|
if result in category_names:
|
||||||
category = result
|
category = result
|
||||||
|
|||||||
@@ -4,19 +4,16 @@
|
|||||||
从文件系统加载预定义的工作流模板
|
从文件系统加载预定义的工作流模板
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
TEMPLATE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'templates')
|
|
||||||
|
|
||||||
|
|
||||||
class TemplateLoader:
|
class TemplateLoader:
|
||||||
"""工作流模板加载器"""
|
"""工作流模板加载器"""
|
||||||
|
|
||||||
def __init__(self, templates_dir: str = TEMPLATE_DIR):
|
def __init__(self, templates_dir: str = "app/templates/workflows"):
|
||||||
"""初始化模板加载器
|
"""初始化模板加载器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -16,8 +16,6 @@ from app.core.error_codes import BizCode, HTTP_MAPPING
|
|||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.logging_config import LoggingConfig, get_logger
|
from app.core.logging_config import LoggingConfig, get_logger
|
||||||
from app.core.response_utils import fail
|
from app.core.response_utils import fail
|
||||||
from app.core.models.scripts.loader import load_models
|
|
||||||
from app.db import get_db_context
|
|
||||||
|
|
||||||
# Initialize logging system
|
# Initialize logging system
|
||||||
LoggingConfig.setup_logging()
|
LoggingConfig.setup_logging()
|
||||||
@@ -49,15 +47,6 @@ async def lifespan(app: FastAPI):
|
|||||||
else:
|
else:
|
||||||
logger.info("自动数据库升级已禁用 (DB_AUTO_UPGRADE=false)")
|
logger.info("自动数据库升级已禁用 (DB_AUTO_UPGRADE=false)")
|
||||||
|
|
||||||
# 加载预定义模型
|
|
||||||
logger.info("开始加载预定义模型...")
|
|
||||||
try:
|
|
||||||
with get_db_context() as db:
|
|
||||||
result = load_models(db, silent=True)
|
|
||||||
logger.info(f"预定义模型加载完成: 成功{result['success']}个, 跳过{result['skipped']}个, 失败{result['failed']}个")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"加载预定义模型时出错: {str(e)}")
|
|
||||||
|
|
||||||
logger.info("应用程序启动完成")
|
logger.info("应用程序启动完成")
|
||||||
yield
|
yield
|
||||||
# 应用关闭事件
|
# 应用关闭事件
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from .document_model import Document
|
|||||||
from .file_model import File
|
from .file_model import File
|
||||||
from .file_metadata_model import FileMetadata
|
from .file_metadata_model import FileMetadata
|
||||||
from .generic_file_model import GenericFile
|
from .generic_file_model import GenericFile
|
||||||
from .models_model import ModelConfig, ModelProvider, ModelType, ModelApiKey, ModelBase, LoadBalanceStrategy
|
from .models_model import ModelConfig, ModelProvider, ModelType, ModelApiKey
|
||||||
from .memory_short_model import ShortTermMemory, LongTermMemory
|
from .memory_short_model import ShortTermMemory, LongTermMemory
|
||||||
from .knowledgeshare_model import KnowledgeShare
|
from .knowledgeshare_model import KnowledgeShare
|
||||||
from .app_model import App
|
from .app_model import App
|
||||||
@@ -28,10 +28,6 @@ from .tool_model import (
|
|||||||
ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus
|
ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus
|
||||||
)
|
)
|
||||||
from .memory_perceptual_model import MemoryPerceptualModel
|
from .memory_perceptual_model import MemoryPerceptualModel
|
||||||
from .ontology_scene import OntologyScene
|
|
||||||
from .ontology_class import OntologyClass
|
|
||||||
from .ontology_scene import OntologyScene
|
|
||||||
from .ontology_class import OntologyClass
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Tenants",
|
"Tenants",
|
||||||
@@ -83,6 +79,4 @@ __all__ = [
|
|||||||
"AuthType",
|
"AuthType",
|
||||||
"ExecutionStatus",
|
"ExecutionStatus",
|
||||||
"MemoryPerceptualModel",
|
"MemoryPerceptualModel",
|
||||||
"ModelBase",
|
|
||||||
"LoadBalanceStrategy"
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -21,9 +21,6 @@ class MemoryConfig(Base):
|
|||||||
user_id = Column(String, nullable=True, comment="用户ID")
|
user_id = Column(String, nullable=True, comment="用户ID")
|
||||||
apply_id = Column(String, nullable=True, comment="应用ID")
|
apply_id = Column(String, nullable=True, comment="应用ID")
|
||||||
|
|
||||||
# 本体场景关联
|
|
||||||
scene_id = Column(UUID(as_uuid=True), nullable=True, comment="本体场景ID,关联ontology_scene表")
|
|
||||||
|
|
||||||
# 模型选择(从workspace继承)
|
# 模型选择(从workspace继承)
|
||||||
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
|
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
|
||||||
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
|
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
|
||||||
|
|||||||
@@ -1,34 +1,19 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import uuid
|
import uuid
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
from typing import Optional, List
|
||||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table
|
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum
|
||||||
from sqlalchemy.dialects.postgresql import UUID, JSON
|
from sqlalchemy.dialects.postgresql import UUID, JSON
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
from sqlalchemy.sql import func
|
|
||||||
from app.db import Base
|
from app.db import Base
|
||||||
|
|
||||||
|
|
||||||
class BaseModel(Base):
|
|
||||||
"""基础模型(抽象类,提取公共字段)"""
|
|
||||||
__abstract__ = True # 标记为抽象类,不生成表
|
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
|
||||||
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
|
|
||||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
|
|
||||||
is_active = Column(Boolean, default=True, nullable=False, comment="是否激活")
|
|
||||||
|
|
||||||
|
|
||||||
class ModelType(StrEnum):
|
class ModelType(StrEnum):
|
||||||
"""模型类型枚举"""
|
"""模型类型枚举"""
|
||||||
LLM = "llm"
|
LLM = "llm"
|
||||||
CHAT = "chat"
|
CHAT = "chat"
|
||||||
EMBEDDING = "embedding"
|
EMBEDDING = "embedding"
|
||||||
RERANK = "rerank"
|
RERANK = "rerank"
|
||||||
# TTS = "tts"
|
|
||||||
# SPEECH2TEXT = "speech2text"
|
|
||||||
# IMAGE = "image"
|
|
||||||
# AUDIO = "audio"
|
|
||||||
# VISION = "vision"
|
|
||||||
|
|
||||||
|
|
||||||
class ModelProvider(StrEnum):
|
class ModelProvider(StrEnum):
|
||||||
@@ -45,36 +30,16 @@ class ModelProvider(StrEnum):
|
|||||||
XINFERENCE = "xinference"
|
XINFERENCE = "xinference"
|
||||||
GPUSTACK = "gpustack"
|
GPUSTACK = "gpustack"
|
||||||
BEDROCK = "bedrock"
|
BEDROCK = "bedrock"
|
||||||
COMPOSITE = "composite"
|
|
||||||
|
|
||||||
|
|
||||||
class LoadBalanceStrategy(StrEnum):
|
class ModelConfig(Base):
|
||||||
"""API Key负载均衡策略枚举"""
|
|
||||||
ROUND_ROBIN = "round_robin" # 轮询
|
|
||||||
NONE = "none" # 无
|
|
||||||
|
|
||||||
|
|
||||||
# 多对多关联表
|
|
||||||
model_config_api_key_association = Table(
|
|
||||||
'model_config_api_key_association',
|
|
||||||
Base.metadata,
|
|
||||||
Column('model_config_id', UUID(as_uuid=True), ForeignKey('model_configs.id'), primary_key=True),
|
|
||||||
Column('api_key_id', UUID(as_uuid=True), ForeignKey('model_api_keys.id'), primary_key=True),
|
|
||||||
Column('created_at', DateTime, default=datetime.datetime.now)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig(BaseModel):
|
|
||||||
"""模型配置表"""
|
"""模型配置表"""
|
||||||
__tablename__ = "model_configs"
|
__tablename__ = "model_configs"
|
||||||
|
|
||||||
model_id = Column(UUID(as_uuid=True), ForeignKey("model_bases.id"), nullable=True, index=True, comment="基础模型ID")
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||||
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, index=True, comment="租户ID")
|
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, index=True, comment="租户ID")
|
||||||
logo = Column(String(255), nullable=True, comment="模型logo图片URL")
|
|
||||||
name = Column(String, nullable=False, comment="模型显示名称")
|
name = Column(String, nullable=False, comment="模型显示名称")
|
||||||
provider = Column(String, nullable=False, comment="供应商", server_default=ModelProvider.COMPOSITE)
|
|
||||||
type = Column(String, nullable=False, index=True, comment="模型类型")
|
type = Column(String, nullable=False, index=True, comment="模型类型")
|
||||||
is_composite = Column(Boolean, default=False, server_default="true", nullable=False, comment="是否为组合模型")
|
|
||||||
description = Column(String, comment="模型描述")
|
description = Column(String, comment="模型描述")
|
||||||
|
|
||||||
# 模型配置参数
|
# 模型配置参数
|
||||||
@@ -91,29 +56,29 @@ class ModelConfig(BaseModel):
|
|||||||
# context_length = Column(String, comment="上下文长度")
|
# context_length = Column(String, comment="上下文长度")
|
||||||
|
|
||||||
# 状态管理
|
# 状态管理
|
||||||
|
is_active = Column(Boolean, default=True, nullable=False, comment="是否激活")
|
||||||
is_public = Column(Boolean, default=False, nullable=False, comment="是否公开")
|
is_public = Column(Boolean, default=False, nullable=False, comment="是否公开")
|
||||||
load_balance_strategy = Column(String, nullable=True, comment="负载均衡策略", default=LoadBalanceStrategy.NONE,
|
|
||||||
server_default=LoadBalanceStrategy.NONE)
|
# 时间戳
|
||||||
|
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
|
||||||
|
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
|
||||||
|
|
||||||
# 关联关系
|
# 关联关系
|
||||||
model_base = relationship("ModelBase", back_populates="configs")
|
api_keys = relationship("ModelApiKey", back_populates="model_config", cascade="all, delete-orphan")
|
||||||
api_keys = relationship(
|
|
||||||
"ModelApiKey",
|
|
||||||
secondary=model_config_api_key_association,
|
|
||||||
back_populates="model_configs"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<ModelConfig(id={self.id}, name={self.name}, type={self.type})>"
|
return f"<ModelConfig(id={self.id}, name={self.name}, type={self.type})>"
|
||||||
|
|
||||||
|
|
||||||
class ModelApiKey(BaseModel):
|
class ModelApiKey(Base):
|
||||||
"""模型API密钥表"""
|
"""模型API密钥表"""
|
||||||
__tablename__ = "model_api_keys"
|
__tablename__ = "model_api_keys"
|
||||||
|
|
||||||
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||||
|
model_config_id = Column(UUID(as_uuid=True), ForeignKey("model_configs.id"), nullable=False, comment="模型配置ID")
|
||||||
|
|
||||||
# API Key 信息
|
# API Key 信息
|
||||||
model_name = Column(String, nullable=False, comment="模型实际名称")
|
model_name = Column(String, nullable=False, comment="模型实际名称")
|
||||||
description = Column(String, comment="备注")
|
|
||||||
provider = Column(String, nullable=False, comment="API Key提供商")
|
provider = Column(String, nullable=False, comment="API Key提供商")
|
||||||
api_key = Column(String, nullable=False, comment="API密钥")
|
api_key = Column(String, nullable=False, comment="API密钥")
|
||||||
api_base = Column(String, comment="API基础URL")
|
api_base = Column(String, comment="API基础URL")
|
||||||
@@ -126,42 +91,15 @@ class ModelApiKey(BaseModel):
|
|||||||
last_used_at = Column(DateTime, comment="最后使用时间")
|
last_used_at = Column(DateTime, comment="最后使用时间")
|
||||||
|
|
||||||
# 状态管理
|
# 状态管理
|
||||||
|
is_active = Column(Boolean, default=True, nullable=False, comment="是否激活")
|
||||||
priority = Column(String, default="1", comment="优先级")
|
priority = Column(String, default="1", comment="优先级")
|
||||||
|
|
||||||
# 关联关系
|
# 时间戳
|
||||||
model_configs = relationship(
|
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
|
||||||
"ModelConfig",
|
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
|
||||||
secondary=model_config_api_key_association,
|
|
||||||
back_populates="api_keys"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"<ModelApiKey(id={self.id}, model_name={self.model_name}, provider={self.provider})>"
|
|
||||||
|
|
||||||
|
|
||||||
class ModelBase(Base):
|
|
||||||
"""基础模型信息表(模型广场)"""
|
|
||||||
__tablename__ = "model_bases"
|
|
||||||
|
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
|
||||||
logo = Column(String(255), nullable=True, comment="模型logo图片URL")
|
|
||||||
name = Column(String, nullable=False, comment="模型唯一标识(如gpt-3.5-turbo)")
|
|
||||||
type = Column(String, nullable=False, index=True, comment="模型类型")
|
|
||||||
provider = Column(String, nullable=False, index=True)
|
|
||||||
description = Column(Text, comment="模型描述")
|
|
||||||
is_deprecated = Column(Boolean, default=False, nullable=False, comment="是否弃用")
|
|
||||||
is_official = Column(Boolean, default=True, comment="是否供应商官方模型(区分自定义)")
|
|
||||||
tags = Column(ARRAY(String), default=list, nullable=False, comment="模型标签(如['聊天', '创作'])")
|
|
||||||
add_count = Column(Integer, default=0, nullable=False, comment="模型被用户添加的次数")
|
|
||||||
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间", server_default=func.now())
|
|
||||||
|
|
||||||
# 关联关系
|
# 关联关系
|
||||||
configs = relationship("ModelConfig", back_populates="model_base", cascade="all, delete-orphan")
|
model_config = relationship("ModelConfig", back_populates="api_keys")
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
UniqueConstraint("name", "provider", name="uk_model_name_provider"),
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<ModelBase(name={self.name}, provider={self.provider}, type={self.type})>"
|
return f"<ModelApiKey(id={self.id}, model_name={self.model_name}, provider={self.provider}, model_config_id={self.model_config_id})>"
|
||||||
|
|||||||
@@ -1,40 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""本体类型模型
|
|
||||||
|
|
||||||
本模块定义本体类型的数据模型。
|
|
||||||
|
|
||||||
Classes:
|
|
||||||
OntologyClass: 本体类型表模型
|
|
||||||
"""
|
|
||||||
|
|
||||||
import datetime
|
|
||||||
import uuid
|
|
||||||
from sqlalchemy import Column, String, DateTime, Text, ForeignKey
|
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
|
||||||
from sqlalchemy.orm import relationship
|
|
||||||
from app.db import Base
|
|
||||||
|
|
||||||
|
|
||||||
class OntologyClass(Base):
|
|
||||||
"""本体类型表 - 用于存储某个场景提取出来的本体类型信息"""
|
|
||||||
__tablename__ = "ontology_class"
|
|
||||||
|
|
||||||
# 主键
|
|
||||||
class_id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="类型ID")
|
|
||||||
|
|
||||||
# 类型信息
|
|
||||||
class_name = Column(String(200), nullable=False, comment="类型名称")
|
|
||||||
class_description = Column(Text, nullable=True, comment="类型描述")
|
|
||||||
|
|
||||||
# 外键:关联到本体场景
|
|
||||||
scene_id = Column(UUID(as_uuid=True), ForeignKey("ontology_scene.scene_id", ondelete="CASCADE"), nullable=False, index=True, comment="所属场景ID")
|
|
||||||
|
|
||||||
# 时间戳
|
|
||||||
created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, comment="创建时间")
|
|
||||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, nullable=False, comment="更新时间")
|
|
||||||
|
|
||||||
# 关系:类型属于某个场景
|
|
||||||
scene = relationship("OntologyScene", back_populates="classes")
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"<OntologyClass(id={self.class_id}, name={self.class_name}, scene_id={self.scene_id})>"
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""本体场景模型
|
|
||||||
|
|
||||||
本模块定义本体场景的数据模型。
|
|
||||||
|
|
||||||
Classes:
|
|
||||||
OntologyScene: 本体场景表模型
|
|
||||||
"""
|
|
||||||
|
|
||||||
import datetime
|
|
||||||
import uuid
|
|
||||||
from sqlalchemy import Column, String, DateTime, Integer, Text, ForeignKey, UniqueConstraint
|
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
|
||||||
from sqlalchemy.orm import relationship
|
|
||||||
from app.db import Base
|
|
||||||
|
|
||||||
|
|
||||||
class OntologyScene(Base):
|
|
||||||
"""本体场景表 - 用于存储本体场景下不同的类型信息"""
|
|
||||||
__tablename__ = "ontology_scene"
|
|
||||||
__table_args__ = (
|
|
||||||
UniqueConstraint('workspace_id', 'scene_name', name='uq_workspace_scene_name'),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 主键
|
|
||||||
scene_id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="场景ID")
|
|
||||||
|
|
||||||
# 场景信息
|
|
||||||
scene_name = Column(String(200), nullable=False, comment="场景名称")
|
|
||||||
scene_description = Column(Text, nullable=True, comment="场景描述")
|
|
||||||
|
|
||||||
# 外键:关联到工作空间
|
|
||||||
workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="CASCADE"), nullable=False, index=True, comment="所属工作空间ID")
|
|
||||||
|
|
||||||
# 时间戳
|
|
||||||
created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, comment="创建时间")
|
|
||||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, nullable=False, comment="更新时间")
|
|
||||||
|
|
||||||
# 关系:一个场景可以有多个类型
|
|
||||||
classes = relationship("OntologyClass", back_populates="scene", cascade="all, delete-orphan")
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return f"<OntologyScene(id={self.scene_id}, name={self.scene_name})>"
|
|
||||||
@@ -2,7 +2,7 @@ import datetime
|
|||||||
import uuid
|
import uuid
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
|
||||||
from sqlalchemy import Column, ForeignKey, Text, DateTime, String, Index, Boolean
|
from sqlalchemy import Column, ForeignKey, Text, DateTime, String, Index
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
from app.db import Base
|
from app.db import Base
|
||||||
@@ -121,33 +121,10 @@ class PromptOptimizerSessionHistory(Base):
|
|||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||||
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, comment="Tenant ID")
|
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, comment="Tenant ID")
|
||||||
# app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False, comment="Application ID")
|
# app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False, comment="Application ID")
|
||||||
session_id = Column(
|
session_id = Column(UUID(as_uuid=True), ForeignKey("prompt_opt_session_list.id"),nullable=False, comment="Session ID")
|
||||||
UUID(as_uuid=True),
|
|
||||||
ForeignKey("prompt_opt_session_list.id"),
|
|
||||||
nullable=False,
|
|
||||||
comment="Session ID"
|
|
||||||
)
|
|
||||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, comment="User ID")
|
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, comment="User ID")
|
||||||
role = Column(String, nullable=False, comment="Message Role")
|
role = Column(String, nullable=False, comment="Message Role")
|
||||||
content = Column(Text, nullable=False, comment="Message Content")
|
content = Column(Text, nullable=False, comment="Message Content")
|
||||||
# prompt = Column(Text, nullable=False, comment="Prompt")
|
# prompt = Column(Text, nullable=False, comment="Prompt")
|
||||||
|
|
||||||
created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation Time", index=True)
|
created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation Time", index=True)
|
||||||
|
|
||||||
|
|
||||||
class PromptHistory(Base):
|
|
||||||
__tablename__ = "prompt_history"
|
|
||||||
|
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
|
||||||
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, comment="Tenant ID")
|
|
||||||
|
|
||||||
session_id = Column(
|
|
||||||
UUID(as_uuid=True),
|
|
||||||
ForeignKey("prompt_opt_session_list.id"),
|
|
||||||
nullable=False,
|
|
||||||
comment="Session ID"
|
|
||||||
)
|
|
||||||
title = Column(String, nullable=False, comment="Title")
|
|
||||||
prompt = Column(Text, nullable=False, comment="Prompt")
|
|
||||||
created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation Time", index=True)
|
|
||||||
is_delete = Column(Boolean, default=False, comment="Delete")
|
|
||||||
|
|||||||
@@ -24,16 +24,12 @@ from app.schemas.memory_storage_schema import (
|
|||||||
from sqlalchemy import desc, select
|
from sqlalchemy import desc, select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.utils.config_utils import resolve_config_id
|
|
||||||
|
|
||||||
# 获取数据库专用日志器
|
# 获取数据库专用日志器
|
||||||
db_logger = get_db_logger()
|
db_logger = get_db_logger()
|
||||||
# 获取配置专用日志器
|
# 获取配置专用日志器
|
||||||
config_logger = get_config_logger()
|
config_logger = get_config_logger()
|
||||||
|
|
||||||
TABLE_NAME = "memory_config"
|
TABLE_NAME = "memory_config"
|
||||||
|
|
||||||
|
|
||||||
class MemoryConfigRepository:
|
class MemoryConfigRepository:
|
||||||
"""记忆配置Repository
|
"""记忆配置Repository
|
||||||
|
|
||||||
@@ -86,8 +82,7 @@ class MemoryConfigRepository:
|
|||||||
n.description AS description,
|
n.description AS description,
|
||||||
n.entity_type AS entity_type,
|
n.entity_type AS entity_type,
|
||||||
n.name AS name,
|
n.name AS name,
|
||||||
// TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
COALESCE(n.fact_summary, '') AS fact_summary,
|
||||||
// COALESCE(n.fact_summary, '') AS fact_summary,
|
|
||||||
n.end_user_id AS end_user_id,
|
n.end_user_id AS end_user_id,
|
||||||
n.apply_id AS apply_id,
|
n.apply_id AS apply_id,
|
||||||
n.user_id AS user_id,
|
n.user_id AS user_id,
|
||||||
@@ -157,7 +152,7 @@ class MemoryConfigRepository:
|
|||||||
return memory_config_obj
|
return memory_config_obj
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID|int|str) -> MemoryConfig:
|
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID) -> MemoryConfig:
|
||||||
"""构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数)
|
"""构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -173,7 +168,6 @@ class MemoryConfigRepository:
|
|||||||
if not memory_config:
|
if not memory_config:
|
||||||
raise RuntimeError("reflection config not found")
|
raise RuntimeError("reflection config not found")
|
||||||
return memory_config
|
return memory_config
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def query_reflection_config_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> MemoryConfig:
|
def query_reflection_config_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> MemoryConfig:
|
||||||
"""构建查询所有配置的语句(SQLAlchemy text() 命名参数)
|
"""构建查询所有配置的语句(SQLAlchemy text() 命名参数)
|
||||||
@@ -193,6 +187,7 @@ class MemoryConfigRepository:
|
|||||||
raise RuntimeError("reflection config not found")
|
raise RuntimeError("reflection config not found")
|
||||||
return memory_config
|
return memory_config
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def build_select_all(workspace_id: uuid.UUID) -> Tuple[str, Dict]:
|
def build_select_all(workspace_id: uuid.UUID) -> Tuple[str, Dict]:
|
||||||
"""构建查询所有配置的语句(SQLAlchemy text() 命名参数)
|
"""构建查询所有配置的语句(SQLAlchemy text() 命名参数)
|
||||||
@@ -232,12 +227,9 @@ class MemoryConfigRepository:
|
|||||||
config_name=params.config_name,
|
config_name=params.config_name,
|
||||||
config_desc=params.config_desc,
|
config_desc=params.config_desc,
|
||||||
workspace_id=params.workspace_id,
|
workspace_id=params.workspace_id,
|
||||||
scene_id=params.scene_id,
|
|
||||||
llm_id=params.llm_id,
|
llm_id=params.llm_id,
|
||||||
embedding_id=params.embedding_id,
|
embedding_id=params.embedding_id,
|
||||||
rerank_id=params.rerank_id,
|
rerank_id=params.rerank_id,
|
||||||
reflection_model_id=params.reflection_model_id,
|
|
||||||
emotion_model_id=params.emotion_model_id,
|
|
||||||
)
|
)
|
||||||
db.add(db_config)
|
db.add(db_config)
|
||||||
db.flush() # 获取自增ID但不提交事务
|
db.flush() # 获取自增ID但不提交事务
|
||||||
@@ -280,9 +272,6 @@ class MemoryConfigRepository:
|
|||||||
if update.config_desc is not None:
|
if update.config_desc is not None:
|
||||||
db_config.config_desc = update.config_desc
|
db_config.config_desc = update.config_desc
|
||||||
has_update = True
|
has_update = True
|
||||||
if update.scene_id is not None:
|
|
||||||
db_config.scene_id = update.scene_id
|
|
||||||
has_update = True
|
|
||||||
|
|
||||||
if not has_update:
|
if not has_update:
|
||||||
raise ValueError("No fields to update")
|
raise ValueError("No fields to update")
|
||||||
@@ -298,6 +287,7 @@ class MemoryConfigRepository:
|
|||||||
db_logger.error(f"更新记忆配置失败: config_id={update.config_id} - {str(e)}")
|
db_logger.error(f"更新记忆配置失败: config_id={update.config_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[MemoryConfig]:
|
def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[MemoryConfig]:
|
||||||
"""更新记忆萃取引擎配置
|
"""更新记忆萃取引擎配置
|
||||||
@@ -420,7 +410,7 @@ class MemoryConfigRepository:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_extracted_config(db: Session, config_id: UUID | int) -> Optional[Dict]:
|
def get_extracted_config(db: Session, config_id: UUID) -> Optional[Dict]:
|
||||||
"""获取萃取配置,通过主键查询某条配置
|
"""获取萃取配置,通过主键查询某条配置
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -430,8 +420,8 @@ class MemoryConfigRepository:
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[Dict]: 萃取配置字典,不存在则返回None
|
Optional[Dict]: 萃取配置字典,不存在则返回None
|
||||||
"""
|
"""
|
||||||
config_id = resolve_config_id(config_id, db)
|
|
||||||
db_logger.debug(f"查询萃取配置: config_id={config_id}")
|
db_logger.debug(f"查询萃取配置: config_id={config_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
|
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
|
||||||
if not db_config:
|
if not db_config:
|
||||||
@@ -524,9 +514,8 @@ class MemoryConfigRepository:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
db_logger.error(f"根据ID查询记忆配置失败: config_id={config_id} - {str(e)}")
|
db_logger.error(f"根据ID查询记忆配置失败: config_id={config_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_config_with_workspace(db: Session, config_id: uuid.UUID | int | str) -> Optional[tuple]:
|
def get_config_with_workspace(db: Session, config_id: uuid.UUID) -> Optional[tuple]:
|
||||||
"""Get memory config and its associated workspace information
|
"""Get memory config and its associated workspace information
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -544,7 +533,6 @@ class MemoryConfigRepository:
|
|||||||
from app.models.workspace_model import Workspace
|
from app.models.workspace_model import Workspace
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
config_id = resolve_config_id(config_id, db)
|
|
||||||
|
|
||||||
# Log configuration loading start
|
# Log configuration loading start
|
||||||
config_logger.info(
|
config_logger.info(
|
||||||
@@ -593,10 +581,8 @@ class MemoryConfigRepository:
|
|||||||
"elapsed_ms": elapsed_ms
|
"elapsed_ms": elapsed_ms
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
db_logger.error(
|
db_logger.error(f"Memory config {config_id} references non-existent workspace {config_only.workspace_id}")
|
||||||
f"Memory config {config_id} references non-existent workspace {config_only.workspace_id}")
|
raise ValueError(f"Workspace {config_only.workspace_id} not found for configuration {config_id}")
|
||||||
raise ValueError(
|
|
||||||
f"Workspace {config_only.workspace_id} not found for configuration {config_id}")
|
|
||||||
|
|
||||||
config_logger.debug(
|
config_logger.debug(
|
||||||
"Configuration not found",
|
"Configuration not found",
|
||||||
@@ -627,8 +613,7 @@ class MemoryConfigRepository:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
db_logger.debug(
|
db_logger.debug(f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
|
||||||
f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
|
|
||||||
return (config, workspace)
|
return (config, workspace)
|
||||||
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@@ -652,34 +637,29 @@ class MemoryConfigRepository:
|
|||||||
|
|
||||||
db_logger.error(f"Failed to query memory config and workspace: config_id={config_id} - {str(e)}")
|
db_logger.error(f"Failed to query memory config and workspace: config_id={config_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[Tuple[MemoryConfig, Optional[str]]]:
|
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[MemoryConfig]:
|
||||||
"""获取所有配置参数,包含关联的场景名称
|
"""获取所有配置参数
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
workspace_id: 工作空间ID,用于过滤查询结果
|
workspace_id: 工作空间ID,用于过滤查询结果
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Tuple[MemoryConfig, Optional[str]]]: 配置列表,每项为 (配置对象, 场景名称)
|
List[MemoryConfig]: 配置列表
|
||||||
"""
|
"""
|
||||||
from app.models.ontology_scene import OntologyScene
|
|
||||||
|
|
||||||
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
|
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
query = db.query(MemoryConfig, OntologyScene.scene_name).outerjoin(
|
query = db.query(MemoryConfig)
|
||||||
OntologyScene, MemoryConfig.scene_id == OntologyScene.scene_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if workspace_id:
|
if workspace_id:
|
||||||
query = query.filter(MemoryConfig.workspace_id == workspace_id)
|
query = query.filter(MemoryConfig.workspace_id == workspace_id)
|
||||||
|
|
||||||
results = query.order_by(desc(MemoryConfig.updated_at)).all()
|
configs = query.order_by(desc(MemoryConfig.updated_at)).all()
|
||||||
|
|
||||||
db_logger.debug(f"配置列表查询成功: 数量={len(results)}")
|
db_logger.debug(f"配置列表查询成功: 数量={len(configs)}")
|
||||||
return results
|
return configs
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(e)}")
|
db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(e)}")
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
from sqlalchemy.orm import Session, joinedload, selectinload
|
from sqlalchemy.orm import Session, joinedload
|
||||||
from sqlalchemy import and_, or_, func, desc, select
|
from sqlalchemy import and_, or_, func, desc
|
||||||
from typing import List, Optional, Dict, Any, Tuple
|
from typing import List, Optional, Dict, Any, Tuple
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelBase, model_config_api_key_association
|
from app.models.models_model import ModelConfig, ModelApiKey, ModelType
|
||||||
from app.schemas.model_schema import (
|
from app.schemas.model_schema import (
|
||||||
ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
||||||
ModelConfigQuery, ModelConfigQueryNew
|
ModelConfigQuery
|
||||||
)
|
)
|
||||||
from app.core.logging_config import get_db_logger
|
from app.core.logging_config import get_db_logger
|
||||||
|
|
||||||
@@ -165,7 +165,7 @@ class ModelConfigRepository:
|
|||||||
total = base_query.count()
|
total = base_query.count()
|
||||||
|
|
||||||
# 分页查询
|
# 分页查询
|
||||||
models = base_query.order_by(desc(ModelConfig.created_at)).offset(
|
models = base_query.order_by(desc(ModelConfig.updated_at)).offset(
|
||||||
(query.page - 1) * query.pagesize
|
(query.page - 1) * query.pagesize
|
||||||
).limit(query.pagesize).all()
|
).limit(query.pagesize).all()
|
||||||
|
|
||||||
@@ -176,84 +176,6 @@ class ModelConfigRepository:
|
|||||||
db_logger.error(f"查询模型配置列表失败: {str(e)}")
|
db_logger.error(f"查询模型配置列表失败: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_list_new(db: Session, query: ModelConfigQueryNew, tenant_id: uuid.UUID | None = None) -> tuple[
|
|
||||||
dict[str, list[ModelConfig]], Any]:
|
|
||||||
"""获取模型配置列表"""
|
|
||||||
db_logger.debug(f"查询模型配置列表: {query.model_dump()}, tenant_id={tenant_id}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 构建查询条件
|
|
||||||
filters = []
|
|
||||||
|
|
||||||
# 添加租户过滤(查询本租户的模型或公开模型)
|
|
||||||
if tenant_id:
|
|
||||||
filters.append(
|
|
||||||
or_(
|
|
||||||
ModelConfig.tenant_id == tenant_id,
|
|
||||||
ModelConfig.is_public
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 支持多个 type 值(使用 IN 查询)
|
|
||||||
# 兼容 chat 和 llm 类型:如果查询包含其中一个,则同时匹配两者
|
|
||||||
if query.type:
|
|
||||||
type_values = list(query.type)
|
|
||||||
# 如果包含 chat 或 llm,则同时包含两者
|
|
||||||
if ModelType.CHAT in type_values or ModelType.LLM in type_values:
|
|
||||||
if ModelType.CHAT not in type_values:
|
|
||||||
type_values.append(ModelType.CHAT)
|
|
||||||
if ModelType.LLM not in type_values:
|
|
||||||
type_values.append(ModelType.LLM)
|
|
||||||
filters.append(ModelConfig.type.in_(type_values))
|
|
||||||
|
|
||||||
if query.is_active is not None:
|
|
||||||
filters.append(ModelConfig.is_active == query.is_active)
|
|
||||||
|
|
||||||
if query.is_public is not None:
|
|
||||||
filters.append(ModelConfig.is_public == query.is_public)
|
|
||||||
|
|
||||||
if query.is_composite is not None:
|
|
||||||
filters.append(ModelConfig.is_composite == query.is_composite)
|
|
||||||
|
|
||||||
if query.provider:
|
|
||||||
filters.append(ModelConfig.provider == query.provider)
|
|
||||||
|
|
||||||
if query.search:
|
|
||||||
search_filter = ModelConfig.name.ilike(f"%{query.search}%")
|
|
||||||
filters.append(search_filter)
|
|
||||||
|
|
||||||
# 构建基础查询
|
|
||||||
base_query = db.query(ModelConfig).options(
|
|
||||||
joinedload(ModelConfig.api_keys)
|
|
||||||
)
|
|
||||||
|
|
||||||
if filters:
|
|
||||||
base_query = base_query.filter(and_(*filters))
|
|
||||||
|
|
||||||
# 获取总数
|
|
||||||
total = base_query.count()
|
|
||||||
|
|
||||||
query_results = base_query.order_by(desc(ModelConfig.created_at)).all()
|
|
||||||
|
|
||||||
provider_groups: Dict[str, List[ModelConfig]] = {}
|
|
||||||
for model_config in query_results:
|
|
||||||
provider = model_config.provider
|
|
||||||
if provider not in provider_groups:
|
|
||||||
provider_groups[provider] = []
|
|
||||||
provider_groups[provider].append(model_config)
|
|
||||||
|
|
||||||
db_logger.debug(
|
|
||||||
f"模型配置列表查询成功: 总数={total}, "
|
|
||||||
f"分组数={len(provider_groups)}, "
|
|
||||||
f"各分组模型数={[len(v) for v in provider_groups.values()]}, "
|
|
||||||
f"type筛选={query.type}")
|
|
||||||
return provider_groups, total
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
db_logger.error(f"查询模型配置列表失败(按provider分组/无分页): {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_by_type(db: Session, model_type: ModelType, tenant_id: uuid.UUID | None = None, is_active: bool = True) -> List[ModelConfig]:
|
def get_by_type(db: Session, model_type: ModelType, tenant_id: uuid.UUID | None = None, is_active: bool = True) -> List[ModelConfig]:
|
||||||
"""根据类型获取模型配置"""
|
"""根据类型获取模型配置"""
|
||||||
@@ -319,7 +241,7 @@ class ModelConfigRepository:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# 更新字段
|
# 更新字段
|
||||||
update_data = model_data.model_dump(exclude_unset=True)
|
update_data = model_data.dict(exclude_unset=True)
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
setattr(db_model, field, value)
|
setattr(db_model, field, value)
|
||||||
|
|
||||||
@@ -381,18 +303,8 @@ class ModelConfigRepository:
|
|||||||
# 按提供商统计 - 现在从ModelApiKey表获取
|
# 按提供商统计 - 现在从ModelApiKey表获取
|
||||||
provider_stats = {}
|
provider_stats = {}
|
||||||
provider_results = db.query(
|
provider_results = db.query(
|
||||||
# 保留 provider 字段
|
ModelApiKey.provider, func.count(func.distinct(ModelApiKey.model_config_id))
|
||||||
ModelApiKey.provider,
|
).group_by(ModelApiKey.provider).all()
|
||||||
# 统计中间表中 唯一的 model_config_id 数量(替换原 ModelApiKey.model_config_id)
|
|
||||||
func.count(func.distinct(model_config_api_key_association.c.model_config_id))
|
|
||||||
).join(
|
|
||||||
# 联表:ModelApiKey <-> 中间表(多对多关联)
|
|
||||||
model_config_api_key_association,
|
|
||||||
ModelApiKey.id == model_config_api_key_association.c.api_key_id
|
|
||||||
).group_by(
|
|
||||||
# 按 provider 分组(保留原有逻辑)
|
|
||||||
ModelApiKey.provider
|
|
||||||
).all()
|
|
||||||
|
|
||||||
for provider, count in provider_results:
|
for provider, count in provider_results:
|
||||||
provider_stats[provider.value] = count
|
provider_stats[provider.value] = count
|
||||||
@@ -413,38 +325,6 @@ class ModelConfigRepository:
|
|||||||
db_logger.error(f"获取模型统计信息失败: {str(e)}")
|
db_logger.error(f"获取模型统计信息失败: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_model_config_ids_by_provider(
|
|
||||||
db: Session,
|
|
||||||
tenant_id: uuid.UUID,
|
|
||||||
provider: Any
|
|
||||||
) -> List[uuid.UUID]:
|
|
||||||
"""根据tenant_id和provider获取model_config_id列表"""
|
|
||||||
db_logger.debug(f"查询model_config_id列表: tenant_id={tenant_id}, provider={provider}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 查询ModelConfig关联的ModelApiKey,筛选出匹配的model_config_id
|
|
||||||
model_config_ids = db.query(ModelConfig.id).join(
|
|
||||||
ModelBase, ModelConfig.model_id == ModelBase.id
|
|
||||||
).filter(
|
|
||||||
and_(
|
|
||||||
or_(
|
|
||||||
ModelConfig.tenant_id == tenant_id,
|
|
||||||
ModelConfig.is_public
|
|
||||||
),
|
|
||||||
ModelBase.provider == provider,
|
|
||||||
ModelConfig.is_active,
|
|
||||||
~ModelConfig.is_composite
|
|
||||||
)
|
|
||||||
).distinct().all()
|
|
||||||
|
|
||||||
db_logger.debug(f"查询成功: 数量={len(model_config_ids)}")
|
|
||||||
return [row[0] for row in model_config_ids]
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
db_logger.error(f"查询model_config_id列表失败: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
class ModelApiKeyRepository:
|
class ModelApiKeyRepository:
|
||||||
"""模型API Key Repository"""
|
"""模型API Key Repository"""
|
||||||
@@ -469,14 +349,7 @@ class ModelApiKeyRepository:
|
|||||||
db_logger.debug(f"根据模型配置ID查询API Key: model_config_id={model_config_id}")
|
db_logger.debug(f"根据模型配置ID查询API Key: model_config_id={model_config_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from app.models.models_model import ModelConfig, model_config_api_key_association
|
query = db.query(ModelApiKey).filter(ModelApiKey.model_config_id == model_config_id)
|
||||||
|
|
||||||
query = db.query(ModelApiKey).join(
|
|
||||||
model_config_api_key_association,
|
|
||||||
ModelApiKey.id == model_config_api_key_association.c.api_key_id
|
|
||||||
).filter(
|
|
||||||
model_config_api_key_association.c.model_config_id == model_config_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_active:
|
if is_active:
|
||||||
query = query.filter(ModelApiKey.is_active)
|
query = query.filter(ModelApiKey.is_active)
|
||||||
@@ -495,20 +368,8 @@ class ModelApiKeyRepository:
|
|||||||
db_logger.debug(f"创建API Key: {api_key_data.provider}")
|
db_logger.debug(f"创建API Key: {api_key_data.provider}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from app.models.models_model import ModelConfig
|
db_api_key = ModelApiKey(**api_key_data.dict())
|
||||||
|
|
||||||
# 创建API Key,不包含model_config_ids
|
|
||||||
api_key_dict = api_key_data.model_dump(exclude={"model_config_ids"})
|
|
||||||
db_api_key = ModelApiKey(**api_key_dict)
|
|
||||||
db.add(db_api_key)
|
db.add(db_api_key)
|
||||||
db.flush() # 获取生成的ID
|
|
||||||
|
|
||||||
# 关联ModelConfig
|
|
||||||
if api_key_data.model_config_ids:
|
|
||||||
for model_config_id in api_key_data.model_config_ids:
|
|
||||||
model_config = db.query(ModelConfig).filter(ModelConfig.id == model_config_id).first()
|
|
||||||
if model_config:
|
|
||||||
db_api_key.model_configs.append(model_config)
|
|
||||||
|
|
||||||
db_logger.info(f"API Key已添加到会话: {db_api_key.provider}")
|
db_logger.info(f"API Key已添加到会话: {db_api_key.provider}")
|
||||||
return db_api_key
|
return db_api_key
|
||||||
@@ -530,7 +391,7 @@ class ModelApiKeyRepository:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# 更新字段
|
# 更新字段
|
||||||
update_data = api_key_data.model_dump(exclude_unset=True)
|
update_data = api_key_data.dict(exclude_unset=True)
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
setattr(db_api_key, field, value)
|
setattr(db_api_key, field, value)
|
||||||
|
|
||||||
@@ -591,91 +452,3 @@ class ModelApiKeyRepository:
|
|||||||
db.rollback()
|
db.rollback()
|
||||||
db_logger.error(f"更新API Key使用统计失败: api_key_id={api_key_id} - {str(e)}")
|
db_logger.error(f"更新API Key使用统计失败: api_key_id={api_key_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
class ModelBaseRepository:
|
|
||||||
"""基础模型Repository"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_by_id(db: Session, model_base_id: uuid.UUID) -> Optional['ModelBase']:
|
|
||||||
return db.query(ModelBase).filter(ModelBase.id == model_base_id).first()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_list(db: Session, query: 'ModelBaseQuery') -> List['ModelBase']:
|
|
||||||
|
|
||||||
filters = []
|
|
||||||
if query.type:
|
|
||||||
filters.append(ModelBase.type == query.type)
|
|
||||||
if query.provider:
|
|
||||||
filters.append(ModelBase.provider == query.provider)
|
|
||||||
if query.is_official is not None:
|
|
||||||
filters.append(ModelBase.is_official == query.is_official)
|
|
||||||
if query.is_deprecated is not None:
|
|
||||||
filters.append(ModelBase.is_deprecated == query.is_deprecated)
|
|
||||||
if query.search:
|
|
||||||
filters.append(or_(
|
|
||||||
ModelBase.name.ilike(f"%{query.search}%"),
|
|
||||||
# ModelBase.description.ilike(f"%{query.search}%")
|
|
||||||
))
|
|
||||||
|
|
||||||
q = db.query(ModelBase)
|
|
||||||
if filters:
|
|
||||||
q = q.filter(and_(*filters))
|
|
||||||
|
|
||||||
return q.order_by(ModelBase.add_count.desc(), ModelBase.created_at.desc()).all()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create(db: Session, data: dict) -> 'ModelBase':
|
|
||||||
model_base = ModelBase(**data)
|
|
||||||
db.add(model_base)
|
|
||||||
return model_base
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_by_name_and_provider(db: Session, name: str, provider: str) -> Optional['ModelBase']:
|
|
||||||
return db.query(ModelBase).filter(
|
|
||||||
ModelBase.name == name,
|
|
||||||
ModelBase.provider == provider
|
|
||||||
).first()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def update(db: Session, model_base_id: uuid.UUID, data: dict) -> Optional['ModelBase']:
|
|
||||||
model_base = db.query(ModelBase).filter(ModelBase.id == model_base_id).first()
|
|
||||||
if not model_base:
|
|
||||||
return None
|
|
||||||
for key, value in data.items():
|
|
||||||
setattr(model_base, key, value)
|
|
||||||
|
|
||||||
# 同步更新绑定的非组合模型配置
|
|
||||||
if any(k in data for k in ['name', 'description', 'logo']):
|
|
||||||
db.query(ModelConfig).filter(
|
|
||||||
ModelConfig.model_id == model_base_id,
|
|
||||||
ModelConfig.is_composite == False
|
|
||||||
).update({
|
|
||||||
k: v for k, v in data.items()
|
|
||||||
if k in ['name', 'description', 'logo']
|
|
||||||
}, synchronize_session=False)
|
|
||||||
|
|
||||||
return model_base
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def delete(db: Session, model_base_id: uuid.UUID) -> bool:
|
|
||||||
model_base = db.query(ModelBase).filter(ModelBase.id == model_base_id).first()
|
|
||||||
if not model_base:
|
|
||||||
return False
|
|
||||||
db.delete(model_base)
|
|
||||||
return True
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def increment_add_count(db: Session, model_base_id: uuid.UUID) -> bool:
|
|
||||||
model_base = db.query(ModelBase).filter(ModelBase.id == model_base_id).first()
|
|
||||||
if not model_base:
|
|
||||||
return False
|
|
||||||
model_base.add_count += 1
|
|
||||||
return True
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def check_added_by_tenant(db: Session, model_base_id: uuid.UUID, tenant_id: uuid.UUID) -> bool:
|
|
||||||
return db.query(ModelConfig).filter(
|
|
||||||
ModelConfig.model_id == model_base_id,
|
|
||||||
ModelConfig.tenant_id == tenant_id
|
|
||||||
).first() is not None
|
|
||||||
|
|||||||
@@ -79,8 +79,7 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
|
|||||||
try:
|
try:
|
||||||
edges: List[dict] = []
|
edges: List[dict] = []
|
||||||
for s in summaries:
|
for s in summaries:
|
||||||
chunk_ids = getattr(s, "chunk_ids", []) or []
|
for chunk_id in getattr(s, "chunk_ids", []) or []:
|
||||||
for chunk_id in chunk_ids:
|
|
||||||
edges.append({
|
edges.append({
|
||||||
"summary_id": s.id,
|
"summary_id": s.id,
|
||||||
"chunk_id": chunk_id,
|
"chunk_id": chunk_id,
|
||||||
@@ -92,11 +91,12 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
|
|||||||
|
|
||||||
if not edges:
|
if not edges:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
result = await connector.execute_query(
|
result = await connector.execute_query(
|
||||||
MEMORY_SUMMARY_STATEMENT_EDGE_SAVE,
|
MEMORY_SUMMARY_STATEMENT_EDGE_SAVE,
|
||||||
edges=edges
|
edges=edges
|
||||||
)
|
)
|
||||||
created = [record.get("uuid") for record in result] if result else []
|
created = [record.get("uuid") for record in result] if result else []
|
||||||
return created
|
return created
|
||||||
except Exception as e:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -217,10 +217,8 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
|||||||
summaries=flattened
|
summaries=flattened
|
||||||
)
|
)
|
||||||
created_ids = [record.get("uuid") for record in result]
|
created_ids = [record.get("uuid") for record in result]
|
||||||
print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
|
|
||||||
return created_ids
|
return created_ids
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print(f"Failed to save MemorySummary nodes to Neo4j: {e}")
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -101,11 +101,10 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
|
|||||||
e.name_embedding = CASE
|
e.name_embedding = CASE
|
||||||
WHEN entity.name_embedding IS NOT NULL AND size(entity.name_embedding) > 0 THEN entity.name_embedding
|
WHEN entity.name_embedding IS NOT NULL AND size(entity.name_embedding) > 0 THEN entity.name_embedding
|
||||||
ELSE e.name_embedding END,
|
ELSE e.name_embedding END,
|
||||||
// TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
e.fact_summary = CASE
|
||||||
// e.fact_summary = CASE
|
WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> ''
|
||||||
// WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> ''
|
AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary))
|
||||||
// AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary))
|
THEN entity.fact_summary ELSE e.fact_summary END,
|
||||||
// THEN entity.fact_summary ELSE e.fact_summary END,
|
|
||||||
e.connect_strength = CASE
|
e.connect_strength = CASE
|
||||||
WHEN entity.connect_strength IS NULL OR entity.connect_strength = '' THEN e.connect_strength
|
WHEN entity.connect_strength IS NULL OR entity.connect_strength = '' THEN e.connect_strength
|
||||||
ELSE CASE
|
ELSE CASE
|
||||||
@@ -322,8 +321,7 @@ RETURN e.id AS id,
|
|||||||
e.description AS description,
|
e.description AS description,
|
||||||
e.aliases AS aliases,
|
e.aliases AS aliases,
|
||||||
e.name_embedding AS name_embedding,
|
e.name_embedding AS name_embedding,
|
||||||
// TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
COALESCE(e.fact_summary, '') AS fact_summary,
|
||||||
// COALESCE(e.fact_summary, '') AS fact_summary,
|
|
||||||
e.connect_strength AS connect_strength,
|
e.connect_strength AS connect_strength,
|
||||||
collect(DISTINCT s.id) AS statement_ids,
|
collect(DISTINCT s.id) AS statement_ids,
|
||||||
collect(DISTINCT c.id) AS chunk_ids,
|
collect(DISTINCT c.id) AS chunk_ids,
|
||||||
@@ -879,8 +877,7 @@ RETURN
|
|||||||
CASE
|
CASE
|
||||||
WHEN ms:ExtractedEntity THEN {
|
WHEN ms:ExtractedEntity THEN {
|
||||||
text: ms.name,
|
text: ms.name,
|
||||||
created_at: ms.created_at,
|
created_at: ms.created_at
|
||||||
type: "情景记忆"
|
|
||||||
}
|
}
|
||||||
END
|
END
|
||||||
) AS ExtractedEntity,
|
) AS ExtractedEntity,
|
||||||
@@ -890,8 +887,7 @@ RETURN
|
|||||||
CASE
|
CASE
|
||||||
WHEN n:MemorySummary THEN {
|
WHEN n:MemorySummary THEN {
|
||||||
text: n.content,
|
text: n.content,
|
||||||
created_at: n.created_at,
|
created_at: n.created_at
|
||||||
type: "长期沉淀"
|
|
||||||
}
|
}
|
||||||
END
|
END
|
||||||
) AS MemorySummary,
|
) AS MemorySummary,
|
||||||
@@ -899,8 +895,7 @@ RETURN
|
|||||||
collect(
|
collect(
|
||||||
DISTINCT {
|
DISTINCT {
|
||||||
text: e.statement,
|
text: e.statement,
|
||||||
created_at: e.created_at,
|
created_at: e.created_at
|
||||||
type: "情绪记忆"
|
|
||||||
}
|
}
|
||||||
) AS statement;
|
) AS statement;
|
||||||
"""
|
"""
|
||||||
@@ -1004,58 +999,3 @@ RETURN DISTINCT
|
|||||||
x.statement as statement,x.created_at as created_at
|
x.statement as statement,x.created_at as created_at
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Graph_Node_query = """
|
|
||||||
MATCH (n:MemorySummary)
|
|
||||||
WHERE n.end_user_id = $end_user_id
|
|
||||||
RETURN
|
|
||||||
elementId(n) AS id,
|
|
||||||
labels(n) AS labels,
|
|
||||||
properties(n) AS properties,
|
|
||||||
0 AS priority
|
|
||||||
LIMIT $limit
|
|
||||||
|
|
||||||
UNION ALL
|
|
||||||
|
|
||||||
MATCH (n:Dialogue)
|
|
||||||
WHERE n.end_user_id = $end_user_id
|
|
||||||
RETURN
|
|
||||||
elementId(n) AS id,
|
|
||||||
labels(n) AS labels,
|
|
||||||
properties(n) AS properties,
|
|
||||||
1 AS priority
|
|
||||||
LIMIT 1
|
|
||||||
|
|
||||||
UNION ALL
|
|
||||||
|
|
||||||
MATCH (n:Statement)
|
|
||||||
WHERE n.end_user_id = $end_user_id
|
|
||||||
RETURN
|
|
||||||
elementId(n) AS id,
|
|
||||||
labels(n) AS labels,
|
|
||||||
properties(n) AS properties,
|
|
||||||
1 AS priority
|
|
||||||
LIMIT $limit
|
|
||||||
|
|
||||||
UNION ALL
|
|
||||||
|
|
||||||
MATCH (n:ExtractedEntity)
|
|
||||||
WHERE n.end_user_id = $end_user_id
|
|
||||||
RETURN
|
|
||||||
elementId(n) AS id,
|
|
||||||
labels(n) AS labels,
|
|
||||||
properties(n) AS properties,
|
|
||||||
2 AS priority
|
|
||||||
LIMIT $limit
|
|
||||||
|
|
||||||
UNION ALL
|
|
||||||
|
|
||||||
MATCH (n:Chunk)
|
|
||||||
WHERE n.end_user_id = $end_user_id
|
|
||||||
RETURN
|
|
||||||
elementId(n) AS id,
|
|
||||||
labels(n) AS labels,
|
|
||||||
properties(n) AS properties,
|
|
||||||
3 AS priority
|
|
||||||
LIMIT $limit
|
|
||||||
|
|
||||||
"""
|
|
||||||
@@ -21,8 +21,7 @@ from app.core.memory.models.graph_models import (
|
|||||||
ExtractedEntityNode,
|
ExtractedEntityNode,
|
||||||
EntityEntityEdge,
|
EntityEntityEdge,
|
||||||
)
|
)
|
||||||
import logging
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
async def save_entities_and_relationships(
|
async def save_entities_and_relationships(
|
||||||
entity_nodes: List[ExtractedEntityNode],
|
entity_nodes: List[ExtractedEntityNode],
|
||||||
entity_entity_edges: List[EntityEntityEdge],
|
entity_entity_edges: List[EntityEntityEdge],
|
||||||
@@ -42,8 +41,8 @@ async def save_entities_and_relationships(
|
|||||||
'statement': edge.statement,
|
'statement': edge.statement,
|
||||||
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
|
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
|
||||||
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
|
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
|
||||||
'created_at': edge.created_at.isoformat() if edge.created_at else None,
|
'created_at': edge.created_at.isoformat(),
|
||||||
'expired_at': edge.expired_at.isoformat() if edge.expired_at else None,
|
'expired_at': edge.expired_at.isoformat(),
|
||||||
'run_id': edge.run_id,
|
'run_id': edge.run_id,
|
||||||
'end_user_id': edge.end_user_id,
|
'end_user_id': edge.end_user_id,
|
||||||
}
|
}
|
||||||
@@ -148,14 +147,14 @@ async def save_statement_entity_edges(
|
|||||||
|
|
||||||
|
|
||||||
async def save_dialog_and_statements_to_neo4j(
|
async def save_dialog_and_statements_to_neo4j(
|
||||||
dialogue_nodes: List[DialogueNode],
|
dialogue_nodes: List[DialogueNode],
|
||||||
chunk_nodes: List[ChunkNode],
|
chunk_nodes: List[ChunkNode],
|
||||||
statement_nodes: List[StatementNode],
|
statement_nodes: List[StatementNode],
|
||||||
entity_nodes: List[ExtractedEntityNode],
|
entity_nodes: List[ExtractedEntityNode],
|
||||||
entity_edges: List[EntityEntityEdge],
|
entity_edges: List[EntityEntityEdge],
|
||||||
statement_chunk_edges: List[StatementChunkEdge],
|
statement_chunk_edges: List[StatementChunkEdge],
|
||||||
statement_entity_edges: List[StatementEntityEdge],
|
statement_entity_edges: List[StatementEntityEdge],
|
||||||
connector: Neo4jConnector
|
connector: Neo4jConnector
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
||||||
|
|
||||||
@@ -172,127 +171,40 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
Returns:
|
Returns:
|
||||||
bool: True if successful, False otherwise
|
bool: True if successful, False otherwise
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 定义事务函数,将所有写操作放在一个事务中
|
|
||||||
async def _save_all_in_transaction(tx):
|
|
||||||
"""在单个事务中执行所有保存操作,避免死锁"""
|
|
||||||
results = {}
|
|
||||||
|
|
||||||
# 1. Save all dialogue nodes in batch
|
|
||||||
if dialogue_nodes:
|
|
||||||
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE
|
|
||||||
dialogue_data = [node.model_dump() for node in dialogue_nodes]
|
|
||||||
result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data)
|
|
||||||
dialogue_uuids = [record["uuid"] async for record in result]
|
|
||||||
results['dialogues'] = dialogue_uuids
|
|
||||||
print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
|
|
||||||
|
|
||||||
# 2. Save all chunk nodes in batch
|
|
||||||
if chunk_nodes:
|
|
||||||
from app.repositories.neo4j.cypher_queries import CHUNK_NODE_SAVE
|
|
||||||
chunk_data = [node.model_dump() for node in chunk_nodes]
|
|
||||||
result = await tx.run(CHUNK_NODE_SAVE, chunks=chunk_data)
|
|
||||||
chunk_uuids = [record["uuid"] async for record in result]
|
|
||||||
results['chunks'] = chunk_uuids
|
|
||||||
logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j")
|
|
||||||
|
|
||||||
# 3. Save all statement nodes in batch
|
|
||||||
if statement_nodes:
|
|
||||||
from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE
|
|
||||||
statement_data = [node.model_dump() for node in statement_nodes]
|
|
||||||
result = await tx.run(STATEMENT_NODE_SAVE, statements=statement_data)
|
|
||||||
statement_uuids = [record["uuid"] async for record in result]
|
|
||||||
results['statements'] = statement_uuids
|
|
||||||
logger.info(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j")
|
|
||||||
|
|
||||||
# 4. Save entities
|
|
||||||
if entity_nodes:
|
|
||||||
from app.repositories.neo4j.cypher_queries import EXTRACTED_ENTITY_NODE_SAVE
|
|
||||||
entity_data = [entity.model_dump() for entity in entity_nodes]
|
|
||||||
result = await tx.run(EXTRACTED_ENTITY_NODE_SAVE, entities=entity_data)
|
|
||||||
entity_uuids = [record["uuid"] async for record in result]
|
|
||||||
results['entities'] = entity_uuids
|
|
||||||
logger.info(f"Successfully saved {len(entity_uuids)} entity nodes to Neo4j")
|
|
||||||
|
|
||||||
# 5. Create entity relationships
|
|
||||||
if entity_edges:
|
|
||||||
from app.repositories.neo4j.cypher_queries import ENTITY_RELATIONSHIP_SAVE
|
|
||||||
relationship_data = []
|
|
||||||
for edge in entity_edges:
|
|
||||||
relationship_data.append({
|
|
||||||
'source_id': edge.source,
|
|
||||||
'target_id': edge.target,
|
|
||||||
'predicate': edge.relation_type,
|
|
||||||
'statement_id': edge.source_statement_id,
|
|
||||||
'value': edge.relation_value,
|
|
||||||
'statement': edge.statement,
|
|
||||||
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
|
|
||||||
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
|
|
||||||
'created_at': edge.created_at.isoformat() if edge.created_at else None,
|
|
||||||
'expired_at': edge.expired_at.isoformat() if edge.expired_at else None,
|
|
||||||
'run_id': edge.run_id,
|
|
||||||
'end_user_id': edge.end_user_id,
|
|
||||||
})
|
|
||||||
result = await tx.run(ENTITY_RELATIONSHIP_SAVE, relationships=relationship_data)
|
|
||||||
rel_uuids = [record["uuid"] async for record in result]
|
|
||||||
results['entity_relationships'] = rel_uuids
|
|
||||||
logger.info(f"Successfully saved {len(rel_uuids)} entity relationships to Neo4j")
|
|
||||||
|
|
||||||
# 6. Save statement-chunk edges
|
|
||||||
if statement_chunk_edges:
|
|
||||||
from app.repositories.neo4j.cypher_queries import CHUNK_STATEMENT_EDGE_SAVE
|
|
||||||
sc_edge_data = []
|
|
||||||
for edge in statement_chunk_edges:
|
|
||||||
sc_edge_data.append({
|
|
||||||
"id": edge.id,
|
|
||||||
"source": edge.source,
|
|
||||||
"target": edge.target,
|
|
||||||
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
|
||||||
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
|
|
||||||
"run_id": edge.run_id,
|
|
||||||
"end_user_id": edge.end_user_id,
|
|
||||||
})
|
|
||||||
result = await tx.run(CHUNK_STATEMENT_EDGE_SAVE, chunk_statement_edges=sc_edge_data)
|
|
||||||
sc_uuids = [record["uuid"] async for record in result]
|
|
||||||
results['statement_chunk_edges'] = sc_uuids
|
|
||||||
logger.info(f"Successfully saved {len(sc_uuids)} statement-chunk edges to Neo4j")
|
|
||||||
|
|
||||||
# 7. Save statement-entity edges
|
|
||||||
if statement_entity_edges:
|
|
||||||
from app.repositories.neo4j.cypher_queries import STATEMENT_ENTITY_EDGE_SAVE
|
|
||||||
se_edge_data = []
|
|
||||||
for edge in statement_entity_edges:
|
|
||||||
se_edge_data.append({
|
|
||||||
"source": edge.source,
|
|
||||||
"target": edge.target,
|
|
||||||
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
|
||||||
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
|
|
||||||
"run_id": edge.run_id,
|
|
||||||
"end_user_id": edge.end_user_id,
|
|
||||||
"connect_strength": getattr(edge, "connect_strength", "strong"),
|
|
||||||
})
|
|
||||||
result = await tx.run(STATEMENT_ENTITY_EDGE_SAVE, relationships=se_edge_data)
|
|
||||||
se_uuids = [record["uuid"] async for record in result]
|
|
||||||
results['statement_entity_edges'] = se_uuids
|
|
||||||
logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j")
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用显式写事务执行所有操作,避免死锁
|
# Save all dialogue nodes in batch
|
||||||
results = await connector.execute_write_transaction(_save_all_in_transaction)
|
dialogue_uuids = await add_dialogue_nodes(dialogue_nodes, connector)
|
||||||
summary = {
|
if dialogue_uuids:
|
||||||
key: len(value)
|
print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
|
||||||
for key, value in results.items()
|
else:
|
||||||
if isinstance(value, (list, tuple, set))
|
print("Failed to save dialogues to Neo4j")
|
||||||
}
|
return False
|
||||||
logger.info("Transaction completed. Summary: %s", summary)
|
|
||||||
logger.debug("Full transaction results: %r", results)
|
# Save all chunk nodes in batch
|
||||||
|
await save_chunk_nodes(chunk_nodes, connector)
|
||||||
|
|
||||||
|
# Save all statement nodes in batch
|
||||||
|
if statement_nodes:
|
||||||
|
statement_uuids = await add_statement_nodes(statement_nodes, connector)
|
||||||
|
if statement_uuids:
|
||||||
|
print(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j")
|
||||||
|
else:
|
||||||
|
print("Failed to save statement nodes to Neo4j")
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
print("No statement nodes to save")
|
||||||
|
|
||||||
|
# Save entities and relationships
|
||||||
|
await save_entities_and_relationships(entity_nodes, entity_edges, connector)
|
||||||
|
print("Successfully saved entities and relationships to Neo4j")
|
||||||
|
|
||||||
|
# Save new edges
|
||||||
|
await save_statement_chunk_edges(statement_chunk_edges, connector)
|
||||||
|
await save_statement_entity_edges(statement_entity_edges, connector)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Neo4j integration error: {e}", exc_info=True)
|
|
||||||
print(f"Neo4j integration error: {e}")
|
print(f"Neo4j integration error: {e}")
|
||||||
print("Continuing without database storage...")
|
print("Continuing without database storage...")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@@ -1,404 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""本体类型Repository层
|
|
||||||
|
|
||||||
本模块提供本体类型的数据访问层实现。
|
|
||||||
|
|
||||||
Classes:
|
|
||||||
OntologyClassRepository: 本体类型数据访问类
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import List, Optional
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from sqlalchemy.orm import Session, joinedload
|
|
||||||
|
|
||||||
from app.core.logging_config import get_db_logger
|
|
||||||
from app.models.ontology_class import OntologyClass
|
|
||||||
from app.models.ontology_scene import OntologyScene
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_db_logger()
|
|
||||||
|
|
||||||
|
|
||||||
class OntologyClassRepository:
|
|
||||||
"""本体类型Repository
|
|
||||||
|
|
||||||
提供本体类型的CRUD操作和权限检查。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
db: SQLAlchemy数据库会话
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, db: Session):
|
|
||||||
"""初始化Repository
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: SQLAlchemy数据库会话
|
|
||||||
"""
|
|
||||||
self.db = db
|
|
||||||
|
|
||||||
def create(self, class_data: dict, scene_id: UUID) -> OntologyClass:
|
|
||||||
"""创建本体类型
|
|
||||||
|
|
||||||
Args:
|
|
||||||
class_data: 类型数据字典,包含class_name和class_description
|
|
||||||
scene_id: 所属场景ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
OntologyClass: 创建的类型对象
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
Exception: 数据库操作失败
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> repo = OntologyClassRepository(db)
|
|
||||||
>>> ontology_class = repo.create(
|
|
||||||
... {"class_name": "患者", "class_description": "描述"},
|
|
||||||
... scene_id
|
|
||||||
... )
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.info(
|
|
||||||
f"Creating ontology class - "
|
|
||||||
f"name={class_data.get('class_name')}, "
|
|
||||||
f"scene_id={scene_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
ontology_class = OntologyClass(
|
|
||||||
class_name=class_data.get("class_name"),
|
|
||||||
class_description=class_data.get("class_description"),
|
|
||||||
scene_id=scene_id
|
|
||||||
)
|
|
||||||
|
|
||||||
self.db.add(ontology_class)
|
|
||||||
self.db.flush() # 获取ID但不提交
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Ontology class created successfully - "
|
|
||||||
f"class_id={ontology_class.class_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return ontology_class
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to create ontology class: {str(e)}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def get_by_id(self, class_id: UUID) -> Optional[OntologyClass]:
|
|
||||||
"""根据ID获取类型
|
|
||||||
|
|
||||||
Args:
|
|
||||||
class_id: 类型ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[OntologyClass]: 类型对象,不存在则返回None
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> repo = OntologyClassRepository(db)
|
|
||||||
>>> ontology_class = repo.get_by_id(class_id)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.debug(f"Getting ontology class by ID: {class_id}")
|
|
||||||
|
|
||||||
ontology_class = self.db.query(OntologyClass).filter(
|
|
||||||
OntologyClass.class_id == class_id
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if ontology_class:
|
|
||||||
logger.debug(f"Ontology class found: {class_id}")
|
|
||||||
else:
|
|
||||||
logger.debug(f"Ontology class not found: {class_id}")
|
|
||||||
|
|
||||||
return ontology_class
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to get ontology class by ID: {str(e)}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def get_by_name(self, class_name: str, scene_id: UUID) -> Optional[OntologyClass]:
|
|
||||||
"""根据类型名称和场景ID获取类型(精确匹配)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
class_name: 类型名称
|
|
||||||
scene_id: 场景ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[OntologyClass]: 类型对象,不存在则返回None
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> repo = OntologyClassRepository(db)
|
|
||||||
>>> ontology_class = repo.get_by_name("患者", scene_id)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.debug(f"Getting ontology class by name: {class_name}, scene_id: {scene_id}")
|
|
||||||
|
|
||||||
ontology_class = self.db.query(OntologyClass).filter(
|
|
||||||
OntologyClass.class_name == class_name,
|
|
||||||
OntologyClass.scene_id == scene_id
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if ontology_class:
|
|
||||||
logger.debug(f"Ontology class found: {class_name}")
|
|
||||||
else:
|
|
||||||
logger.debug(f"Ontology class not found: {class_name}")
|
|
||||||
|
|
||||||
return ontology_class
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to get ontology class by name: {str(e)}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def search_by_name(self, keyword: str, scene_id: UUID) -> List[OntologyClass]:
|
|
||||||
"""根据关键词模糊搜索类型
|
|
||||||
|
|
||||||
使用 LIKE 进行模糊匹配,支持中文和英文。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
keyword: 搜索关键词
|
|
||||||
scene_id: 场景ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[OntologyClass]: 匹配的类型列表
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> repo = OntologyClassRepository(db)
|
|
||||||
>>> classes = repo.search_by_name("患者", scene_id)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.debug(
|
|
||||||
f"Searching ontology classes by keyword - "
|
|
||||||
f"keyword={keyword}, scene_id={scene_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 使用 ilike 进行不区分大小写的模糊匹配
|
|
||||||
classes = self.db.query(OntologyClass).filter(
|
|
||||||
OntologyClass.class_name.ilike(f"%{keyword}%"),
|
|
||||||
OntologyClass.scene_id == scene_id
|
|
||||||
).order_by(
|
|
||||||
OntologyClass.created_at.desc()
|
|
||||||
).all()
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Found {len(classes)} ontology classes matching keyword '{keyword}' "
|
|
||||||
f"in scene {scene_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return classes
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to search ontology classes by keyword: {str(e)}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def get_by_scene(self, scene_id: UUID) -> List[OntologyClass]:
|
|
||||||
"""获取场景下的所有类型
|
|
||||||
|
|
||||||
按创建时间倒序排列。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scene_id: 场景ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[OntologyClass]: 类型列表
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> repo = OntologyClassRepository(db)
|
|
||||||
>>> classes = repo.get_by_scene(scene_id)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.debug(f"Getting ontology classes by scene: {scene_id}")
|
|
||||||
|
|
||||||
classes = self.db.query(OntologyClass).filter(
|
|
||||||
OntologyClass.scene_id == scene_id
|
|
||||||
).order_by(
|
|
||||||
OntologyClass.created_at.desc()
|
|
||||||
).all()
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Found {len(classes)} ontology classes in scene {scene_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return classes
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to get ontology classes by scene: {str(e)}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def update(self, class_id: UUID, update_data: dict) -> Optional[OntologyClass]:
|
|
||||||
"""更新类型信息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
class_id: 类型ID
|
|
||||||
update_data: 更新数据字典
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[OntologyClass]: 更新后的类型对象,不存在则返回None
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
Exception: 数据库操作失败
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> repo = OntologyClassRepository(db)
|
|
||||||
>>> ontology_class = repo.update(
|
|
||||||
... class_id,
|
|
||||||
... {"class_name": "新名称"}
|
|
||||||
... )
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.info(f"Updating ontology class: {class_id}")
|
|
||||||
|
|
||||||
ontology_class = self.get_by_id(class_id)
|
|
||||||
if not ontology_class:
|
|
||||||
logger.warning(f"Ontology class not found for update: {class_id}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 更新字段
|
|
||||||
if "class_name" in update_data and update_data["class_name"] is not None:
|
|
||||||
ontology_class.class_name = update_data["class_name"]
|
|
||||||
|
|
||||||
if "class_description" in update_data:
|
|
||||||
ontology_class.class_description = update_data["class_description"]
|
|
||||||
|
|
||||||
self.db.flush()
|
|
||||||
|
|
||||||
logger.info(f"Ontology class updated successfully: {class_id}")
|
|
||||||
|
|
||||||
return ontology_class
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to update ontology class: {str(e)}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def delete(self, class_id: UUID) -> bool:
|
|
||||||
"""删除类型
|
|
||||||
|
|
||||||
Args:
|
|
||||||
class_id: 类型ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 删除成功返回True,类型不存在返回False
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
Exception: 数据库操作失败
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> repo = OntologyClassRepository(db)
|
|
||||||
>>> success = repo.delete(class_id)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.info(f"Deleting ontology class: {class_id}")
|
|
||||||
|
|
||||||
ontology_class = self.get_by_id(class_id)
|
|
||||||
if not ontology_class:
|
|
||||||
logger.warning(f"Ontology class not found for delete: {class_id}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
self.db.delete(ontology_class)
|
|
||||||
self.db.flush()
|
|
||||||
|
|
||||||
logger.info(f"Ontology class deleted successfully: {class_id}")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to delete ontology class: {str(e)}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def check_ownership(self, class_id: UUID, workspace_id: UUID) -> bool:
|
|
||||||
"""检查类型是否属于指定工作空间(通过场景关联)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
class_id: 类型ID
|
|
||||||
workspace_id: 工作空间ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 属于返回True,否则返回False
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> repo = OntologyClassRepository(db)
|
|
||||||
>>> is_owner = repo.check_ownership(class_id, workspace_id)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.debug(
|
|
||||||
f"Checking class ownership - "
|
|
||||||
f"class_id={class_id}, workspace_id={workspace_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
count = self.db.query(OntologyClass).join(
|
|
||||||
OntologyScene,
|
|
||||||
OntologyClass.scene_id == OntologyScene.scene_id
|
|
||||||
).filter(
|
|
||||||
OntologyClass.class_id == class_id,
|
|
||||||
OntologyScene.workspace_id == workspace_id
|
|
||||||
).count()
|
|
||||||
|
|
||||||
is_owner = count > 0
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"Class ownership check result: {is_owner} - "
|
|
||||||
f"class_id={class_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return is_owner
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to check class ownership: {str(e)}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def get_scene_id_by_class(self, class_id: UUID) -> Optional[UUID]:
|
|
||||||
"""根据类型ID获取所属场景ID
|
|
||||||
|
|
||||||
Args:
|
|
||||||
class_id: 类型ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[UUID]: 场景ID,类型不存在则返回None
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> repo = OntologyClassRepository(db)
|
|
||||||
>>> scene_id = repo.get_scene_id_by_class(class_id)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.debug(f"Getting scene ID by class: {class_id}")
|
|
||||||
|
|
||||||
ontology_class = self.get_by_id(class_id)
|
|
||||||
if not ontology_class:
|
|
||||||
logger.debug(f"Class not found: {class_id}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"Found scene ID: {ontology_class.scene_id} for class: {class_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return ontology_class.scene_id
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to get scene ID by class: {str(e)}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
@@ -1,439 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""本体场景Repository层
|
|
||||||
|
|
||||||
本模块提供本体场景的数据访问层实现。
|
|
||||||
|
|
||||||
Classes:
|
|
||||||
OntologySceneRepository: 本体场景数据访问类
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import List, Optional
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from sqlalchemy.orm import Session, joinedload
|
|
||||||
|
|
||||||
from app.core.logging_config import get_db_logger
|
|
||||||
from app.models.ontology_scene import OntologyScene
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_db_logger()
|
|
||||||
|
|
||||||
|
|
||||||
class OntologySceneRepository:
|
|
||||||
"""本体场景Repository
|
|
||||||
|
|
||||||
提供本体场景的CRUD操作和权限检查。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
db: SQLAlchemy数据库会话
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, db: Session):
|
|
||||||
"""初始化Repository
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: SQLAlchemy数据库会话
|
|
||||||
"""
|
|
||||||
self.db = db
|
|
||||||
|
|
||||||
def create(self, scene_data: dict, workspace_id: UUID) -> OntologyScene:
|
|
||||||
"""创建本体场景
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scene_data: 场景数据字典,包含scene_name和scene_description
|
|
||||||
workspace_id: 所属工作空间ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
OntologyScene: 创建的场景对象
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
Exception: 数据库操作失败
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> repo = OntologySceneRepository(db)
|
|
||||||
>>> scene = repo.create(
|
|
||||||
... {"scene_name": "医疗场景", "scene_description": "描述"},
|
|
||||||
... workspace_id
|
|
||||||
... )
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.info(
|
|
||||||
f"Creating ontology scene - "
|
|
||||||
f"name={scene_data.get('scene_name')}, "
|
|
||||||
f"workspace_id={workspace_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
scene = OntologyScene(
|
|
||||||
scene_name=scene_data.get("scene_name"),
|
|
||||||
scene_description=scene_data.get("scene_description"),
|
|
||||||
workspace_id=workspace_id
|
|
||||||
)
|
|
||||||
|
|
||||||
self.db.add(scene)
|
|
||||||
self.db.flush() # 获取ID但不提交
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Ontology scene created successfully - "
|
|
||||||
f"scene_id={scene.scene_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return scene
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to create ontology scene: {str(e)}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def get_by_id(self, scene_id: UUID) -> Optional[OntologyScene]:
|
|
||||||
"""根据ID获取场景
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scene_id: 场景ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[OntologyScene]: 场景对象,不存在则返回None
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> repo = OntologySceneRepository(db)
|
|
||||||
>>> scene = repo.get_by_id(scene_id)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.debug(f"Getting ontology scene by ID: {scene_id}")
|
|
||||||
|
|
||||||
scene = self.db.query(OntologyScene).filter(
|
|
||||||
OntologyScene.scene_id == scene_id
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if scene:
|
|
||||||
logger.debug(f"Ontology scene found: {scene_id}")
|
|
||||||
else:
|
|
||||||
logger.debug(f"Ontology scene not found: {scene_id}")
|
|
||||||
|
|
||||||
return scene
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to get ontology scene by ID: {str(e)}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def get_by_name(self, scene_name: str, workspace_id: UUID) -> Optional[OntologyScene]:
|
|
||||||
"""根据场景名称和工作空间ID获取场景(精确匹配)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scene_name: 场景名称
|
|
||||||
workspace_id: 工作空间ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[OntologyScene]: 场景对象,不存在则返回None
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> repo = OntologySceneRepository(db)
|
|
||||||
>>> scene = repo.get_by_name("医疗场景", workspace_id)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.debug(
|
|
||||||
f"Getting ontology scene by name - "
|
|
||||||
f"scene_name={scene_name}, workspace_id={workspace_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
scene = self.db.query(OntologyScene).options(
|
|
||||||
joinedload(OntologyScene.classes)
|
|
||||||
).filter(
|
|
||||||
OntologyScene.scene_name == scene_name,
|
|
||||||
OntologyScene.workspace_id == workspace_id
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if scene:
|
|
||||||
logger.debug(f"Ontology scene found: {scene_name}")
|
|
||||||
else:
|
|
||||||
logger.debug(f"Ontology scene not found: {scene_name}")
|
|
||||||
|
|
||||||
return scene
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to get ontology scene by name: {str(e)}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def search_by_name(self, keyword: str, workspace_id: UUID) -> List[OntologyScene]:
|
|
||||||
"""根据关键词模糊搜索场景
|
|
||||||
|
|
||||||
使用 LIKE 进行模糊匹配,支持中文和英文。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
keyword: 搜索关键词
|
|
||||||
workspace_id: 工作空间ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[OntologyScene]: 匹配的场景列表
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> repo = OntologySceneRepository(db)
|
|
||||||
>>> scenes = repo.search_by_name("医疗", workspace_id)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.debug(
|
|
||||||
f"Searching ontology scenes by keyword - "
|
|
||||||
f"keyword={keyword}, workspace_id={workspace_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 使用 ilike 进行不区分大小写的模糊匹配
|
|
||||||
scenes = self.db.query(OntologyScene).options(
|
|
||||||
joinedload(OntologyScene.classes)
|
|
||||||
).filter(
|
|
||||||
OntologyScene.scene_name.ilike(f"%{keyword}%"),
|
|
||||||
OntologyScene.workspace_id == workspace_id
|
|
||||||
).order_by(
|
|
||||||
OntologyScene.updated_at.desc()
|
|
||||||
).all()
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Found {len(scenes)} ontology scenes matching keyword '{keyword}' "
|
|
||||||
f"in workspace {workspace_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return scenes
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to search ontology scenes by keyword: {str(e)}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def get_by_workspace(self, workspace_id: UUID, page: Optional[int] = None, page_size: Optional[int] = None) -> tuple:
|
|
||||||
"""获取工作空间下的所有场景(支持分页)
|
|
||||||
|
|
||||||
使用joinedload预加载classes关系以统计数量。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
workspace_id: 工作空间ID
|
|
||||||
page: 页码(可选,从1开始)
|
|
||||||
page_size: 每页数量(可选)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (场景列表, 总数量)
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> repo = OntologySceneRepository(db)
|
|
||||||
>>> scenes, total = repo.get_by_workspace(workspace_id)
|
|
||||||
>>> scenes, total = repo.get_by_workspace(workspace_id, page=1, page_size=10)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.debug(f"Getting ontology scenes by workspace: {workspace_id}, page={page}, page_size={page_size}")
|
|
||||||
|
|
||||||
# 构建基础查询
|
|
||||||
query = self.db.query(OntologyScene).options(
|
|
||||||
joinedload(OntologyScene.classes)
|
|
||||||
).filter(
|
|
||||||
OntologyScene.workspace_id == workspace_id
|
|
||||||
).order_by(
|
|
||||||
OntologyScene.updated_at.desc()
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取总数
|
|
||||||
total = query.count()
|
|
||||||
|
|
||||||
# 如果提供了分页参数,应用分页
|
|
||||||
if page is not None and page_size is not None:
|
|
||||||
offset = (page - 1) * page_size
|
|
||||||
query = query.offset(offset).limit(page_size)
|
|
||||||
logger.debug(f"Applying pagination: offset={offset}, limit={page_size}")
|
|
||||||
|
|
||||||
scenes = query.all()
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Found {len(scenes)} ontology scenes (total: {total}) in workspace {workspace_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return scenes, total
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to get ontology scenes by workspace: {str(e)}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def update(self, scene_id: UUID, update_data: dict) -> Optional[OntologyScene]:
|
|
||||||
"""更新场景信息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scene_id: 场景ID
|
|
||||||
update_data: 更新数据字典
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[OntologyScene]: 更新后的场景对象,不存在则返回None
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
Exception: 数据库操作失败
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> repo = OntologySceneRepository(db)
|
|
||||||
>>> scene = repo.update(
|
|
||||||
... scene_id,
|
|
||||||
... {"scene_name": "新名称"}
|
|
||||||
... )
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.info(f"Updating ontology scene: {scene_id}")
|
|
||||||
|
|
||||||
scene = self.get_by_id(scene_id)
|
|
||||||
if not scene:
|
|
||||||
logger.warning(f"Ontology scene not found for update: {scene_id}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 更新字段
|
|
||||||
if "scene_name" in update_data and update_data["scene_name"] is not None:
|
|
||||||
scene.scene_name = update_data["scene_name"]
|
|
||||||
|
|
||||||
if "scene_description" in update_data:
|
|
||||||
scene.scene_description = update_data["scene_description"]
|
|
||||||
|
|
||||||
self.db.flush()
|
|
||||||
|
|
||||||
logger.info(f"Ontology scene updated successfully: {scene_id}")
|
|
||||||
|
|
||||||
return scene
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to update ontology scene: {str(e)}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def delete(self, scene_id: UUID) -> bool:
|
|
||||||
"""删除场景(级联删除类型)
|
|
||||||
|
|
||||||
依赖数据库级联删除配置(ondelete="CASCADE")。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scene_id: 场景ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 删除成功返回True,场景不存在返回False
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
Exception: 数据库操作失败
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> repo = OntologySceneRepository(db)
|
|
||||||
>>> success = repo.delete(scene_id)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.info(f"Deleting ontology scene: {scene_id}")
|
|
||||||
|
|
||||||
scene = self.get_by_id(scene_id)
|
|
||||||
if not scene:
|
|
||||||
logger.warning(f"Ontology scene not found for delete: {scene_id}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
self.db.delete(scene)
|
|
||||||
self.db.flush()
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Ontology scene deleted successfully (cascade): {scene_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to delete ontology scene: {str(e)}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def check_ownership(self, scene_id: UUID, workspace_id: UUID) -> bool:
|
|
||||||
"""检查场景是否属于指定工作空间
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scene_id: 场景ID
|
|
||||||
workspace_id: 工作空间ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 属于返回True,否则返回False
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> repo = OntologySceneRepository(db)
|
|
||||||
>>> is_owner = repo.check_ownership(scene_id, workspace_id)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.debug(
|
|
||||||
f"Checking scene ownership - "
|
|
||||||
f"scene_id={scene_id}, workspace_id={workspace_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
count = self.db.query(OntologyScene).filter(
|
|
||||||
OntologyScene.scene_id == scene_id,
|
|
||||||
OntologyScene.workspace_id == workspace_id
|
|
||||||
).count()
|
|
||||||
|
|
||||||
is_owner = count > 0
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"Scene ownership check result: {is_owner} - "
|
|
||||||
f"scene_id={scene_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return is_owner
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to check scene ownership: {str(e)}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def get_simple_list(self, workspace_id: UUID) -> List[dict]:
|
|
||||||
"""获取场景简单列表(仅包含scene_id和scene_name,用于下拉选择)
|
|
||||||
|
|
||||||
这是一个轻量级查询,不加载关联的classes,响应速度快。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
workspace_id: 工作空间ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[dict]: 场景简单列表,每项包含scene_id和scene_name
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> repo = OntologySceneRepository(db)
|
|
||||||
>>> scenes = repo.get_simple_list(workspace_id)
|
|
||||||
>>> # [{"scene_id": "xxx", "scene_name": "场景1"}, ...]
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.debug(f"Getting simple scene list for workspace: {workspace_id}")
|
|
||||||
|
|
||||||
# 只查询需要的字段,不加载关联数据
|
|
||||||
results = self.db.query(
|
|
||||||
OntologyScene.scene_id,
|
|
||||||
OntologyScene.scene_name
|
|
||||||
).filter(
|
|
||||||
OntologyScene.workspace_id == workspace_id
|
|
||||||
).order_by(
|
|
||||||
OntologyScene.updated_at.desc()
|
|
||||||
).all()
|
|
||||||
|
|
||||||
scenes = [
|
|
||||||
{"scene_id": str(r.scene_id), "scene_name": r.scene_name}
|
|
||||||
for r in results
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.info(f"Found {len(scenes)} scenes (simple list) in workspace {workspace_id}")
|
|
||||||
|
|
||||||
return scenes
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to get simple scene list: {str(e)}",
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
@@ -4,10 +4,7 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from app.core.logging_config import get_db_logger
|
from app.core.logging_config import get_db_logger
|
||||||
from app.models.prompt_optimizer_model import (
|
from app.models.prompt_optimizer_model import (
|
||||||
PromptOptimizerSession,
|
PromptOptimizerSession, PromptOptimizerSessionHistory, RoleType
|
||||||
PromptOptimizerSessionHistory,
|
|
||||||
RoleType,
|
|
||||||
PromptHistory
|
|
||||||
)
|
)
|
||||||
|
|
||||||
db_logger = get_db_logger()
|
db_logger = get_db_logger()
|
||||||
@@ -19,12 +16,6 @@ class PromptOptimizerSessionRepository:
|
|||||||
def __init__(self, db: Session):
|
def __init__(self, db: Session):
|
||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
def get_session_by_id(self, session_id: uuid.UUID) -> PromptOptimizerSession | None:
|
|
||||||
session = self.db.query(PromptOptimizerSession).filter(
|
|
||||||
PromptOptimizerSession.id == session_id,
|
|
||||||
).first()
|
|
||||||
return session
|
|
||||||
|
|
||||||
def create_session(
|
def create_session(
|
||||||
self,
|
self,
|
||||||
tenant_id: uuid.UUID,
|
tenant_id: uuid.UUID,
|
||||||
@@ -47,9 +38,12 @@ class PromptOptimizerSessionRepository:
|
|||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
self.db.add(session)
|
self.db.add(session)
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(session)
|
||||||
|
db_logger.debug(f"Prompt optimization session created: ID:{session.id}")
|
||||||
return session
|
return session
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db_logger.error(f"Error creating prompt optimization session: - {str(e)}")
|
db_logger.error(f"Error creating prompt optimization session: user_id={user_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def get_session_history(
|
def get_session_history(
|
||||||
@@ -123,199 +117,8 @@ class PromptOptimizerSessionRepository:
|
|||||||
content=content,
|
content=content,
|
||||||
)
|
)
|
||||||
self.db.add(message)
|
self.db.add(message)
|
||||||
|
self.db.commit()
|
||||||
return message
|
return message
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db_logger.error(f"Error creating prompt optimization session history: session_id={session_id} - {str(e)}")
|
db_logger.error(f"Error creating prompt optimization session history: session_id={session_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def get_first_user_message(self, session_id: uuid.UUID) -> str | None:
|
|
||||||
"""
|
|
||||||
Get the first user message from a session.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session_id (uuid.UUID): The session ID.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str | None: The content of the first user message, or None if not found.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
message = self.db.query(PromptOptimizerSessionHistory).filter(
|
|
||||||
PromptOptimizerSessionHistory.session_id == session_id,
|
|
||||||
PromptOptimizerSessionHistory.role == RoleType.USER.value
|
|
||||||
).order_by(
|
|
||||||
PromptOptimizerSessionHistory.created_at.asc()
|
|
||||||
).first()
|
|
||||||
|
|
||||||
return message.content if message else None
|
|
||||||
except Exception as e:
|
|
||||||
db_logger.error(f"Error getting first user message: session_id={session_id} - {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
class PromptReleaseRepository:
|
|
||||||
def __init__(self, db: Session):
|
|
||||||
self.db = db
|
|
||||||
|
|
||||||
def get_prompt_by_session_id(self, session_id: uuid.UUID) -> PromptHistory | None:
|
|
||||||
prompt_obj = self.db.query(PromptHistory).filter(
|
|
||||||
PromptHistory.session_id == session_id,
|
|
||||||
PromptHistory.is_delete.is_(False)
|
|
||||||
).first()
|
|
||||||
return prompt_obj
|
|
||||||
|
|
||||||
def create_prompt_release(
|
|
||||||
self,
|
|
||||||
tenant_id: uuid.UUID,
|
|
||||||
title: str,
|
|
||||||
session_id: uuid.UUID,
|
|
||||||
prompt: str,
|
|
||||||
) -> PromptHistory:
|
|
||||||
try:
|
|
||||||
prompt_obj = PromptHistory(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
title=title,
|
|
||||||
session_id=session_id,
|
|
||||||
prompt=prompt,
|
|
||||||
)
|
|
||||||
self.db.add(prompt_obj)
|
|
||||||
return prompt_obj
|
|
||||||
except Exception as e:
|
|
||||||
db_logger.error(f"Error creating prompt release: session_id={session_id} - {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def soft_delete_prompt(self, prompt_obj: PromptHistory) -> None:
|
|
||||||
"""
|
|
||||||
Soft delete a prompt release by setting is_delete flag to True.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt_obj (PromptHistory): The prompt release object to delete.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
prompt_obj.is_delete = True
|
|
||||||
db_logger.debug(f"Soft deleted prompt release: id={prompt_obj.id}, session_id={prompt_obj.session_id}")
|
|
||||||
except Exception as e:
|
|
||||||
db_logger.error(f"Error soft deleting prompt release: id={prompt_obj.id} - {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def get_prompt_by_id(self, prompt_id: uuid.UUID) -> PromptHistory | None:
|
|
||||||
"""
|
|
||||||
Get a prompt release by its ID.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt_id (uuid.UUID): The prompt release ID.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
PromptHistory | None: The prompt release object or None if not found.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
prompt_obj = self.db.query(PromptHistory).filter(
|
|
||||||
PromptHistory.id == prompt_id
|
|
||||||
).first()
|
|
||||||
return prompt_obj
|
|
||||||
except Exception as e:
|
|
||||||
db_logger.error(f"Error getting prompt release by id: id={prompt_id} - {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def count_prompts(self, tenant_id: uuid.UUID) -> int:
|
|
||||||
"""
|
|
||||||
Count total number of non-deleted prompts for a tenant.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tenant_id (uuid.UUID): The tenant ID.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: Total count of prompts.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
count = self.db.query(PromptHistory).filter(
|
|
||||||
PromptHistory.tenant_id == tenant_id,
|
|
||||||
PromptHistory.is_delete.is_(False)
|
|
||||||
).count()
|
|
||||||
return count
|
|
||||||
except Exception as e:
|
|
||||||
db_logger.error(f"Error counting prompts: tenant_id={tenant_id} - {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def get_prompts_paginated(
|
|
||||||
self,
|
|
||||||
tenant_id: uuid.UUID,
|
|
||||||
offset: int,
|
|
||||||
limit: int
|
|
||||||
) -> list[PromptHistory]:
|
|
||||||
"""
|
|
||||||
Get paginated list of prompt releases for a tenant.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tenant_id (uuid.UUID): The tenant ID.
|
|
||||||
offset (int): Number of records to skip.
|
|
||||||
limit (int): Maximum number of records to return.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[PromptHistory]: List of prompt releases.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
prompts = self.db.query(PromptHistory).filter(
|
|
||||||
PromptHistory.tenant_id == tenant_id,
|
|
||||||
PromptHistory.is_delete.is_(False)
|
|
||||||
).order_by(
|
|
||||||
PromptHistory.created_at.desc()
|
|
||||||
).offset(offset).limit(limit).all()
|
|
||||||
return prompts
|
|
||||||
except Exception as e:
|
|
||||||
db_logger.error(f"Error getting paginated prompts: tenant_id={tenant_id} - {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def count_prompts_by_keyword(self, tenant_id: uuid.UUID, keyword: str) -> int:
|
|
||||||
"""
|
|
||||||
Count total number of non-deleted prompts matching keyword for a tenant.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tenant_id (uuid.UUID): The tenant ID.
|
|
||||||
keyword (str): Search keyword for title.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: Total count of matching prompts.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
count = self.db.query(PromptHistory).filter(
|
|
||||||
PromptHistory.tenant_id == tenant_id,
|
|
||||||
PromptHistory.is_delete.is_(False),
|
|
||||||
PromptHistory.title.ilike(f"%{keyword}%")
|
|
||||||
).count()
|
|
||||||
return count
|
|
||||||
except Exception as e:
|
|
||||||
db_logger.error(f"Error counting prompts by keyword: tenant_id={tenant_id}, keyword={keyword} - {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def search_prompts_paginated(
|
|
||||||
self,
|
|
||||||
tenant_id: uuid.UUID,
|
|
||||||
keyword: str,
|
|
||||||
offset: int,
|
|
||||||
limit: int
|
|
||||||
) -> list[PromptHistory]:
|
|
||||||
"""
|
|
||||||
Search prompt releases by keyword in title with pagination.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tenant_id (uuid.UUID): The tenant ID.
|
|
||||||
keyword (str): Search keyword for title.
|
|
||||||
offset (int): Number of records to skip.
|
|
||||||
limit (int): Maximum number of records to return.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[PromptHistory]: List of matching prompt releases.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
prompts = self.db.query(PromptHistory).filter(
|
|
||||||
PromptHistory.tenant_id == tenant_id,
|
|
||||||
PromptHistory.is_delete.is_(False),
|
|
||||||
PromptHistory.title.ilike(f"%{keyword}%")
|
|
||||||
).order_by(
|
|
||||||
PromptHistory.created_at.desc()
|
|
||||||
).offset(offset).limit(limit).all()
|
|
||||||
return prompts
|
|
||||||
except Exception as e:
|
|
||||||
db_logger.error(f"Error searching prompts: tenant_id={tenant_id}, keyword={keyword} - {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ class KnowledgeBaseConfig(BaseModel):
|
|||||||
kb_id: str = Field(..., description="知识库ID")
|
kb_id: str = Field(..., description="知识库ID")
|
||||||
top_k: int = Field(default=3, ge=1, le=20, description="检索返回的文档数量")
|
top_k: int = Field(default=3, ge=1, le=20, description="检索返回的文档数量")
|
||||||
similarity_threshold: float = Field(default=0.7, ge=0.0, le=1.0, description="相似度阈值")
|
similarity_threshold: float = Field(default=0.7, ge=0.0, le=1.0, description="相似度阈值")
|
||||||
# strategy: str = Field(default="hybrid", description="检索策略: hybrid | bm25 | dense")
|
strategy: str = Field(default="hybrid", description="检索策略: hybrid | bm25 | dense")
|
||||||
# weight: float = Field(default=1.0, ge=0.0, le=1.0, description="知识库权重(用于多知识库融合)")
|
weight: float = Field(default=1.0, ge=0.0, le=1.0, description="知识库权重(用于多知识库融合)")
|
||||||
vector_similarity_weight: float = Field(default=0.5, ge=0.0, le=1.0, description="向量相似度权重")
|
vector_similarity_weight: float = Field(default=0.5, ge=0.0, le=1.0, description="向量相似度权重")
|
||||||
retrieve_type: str = Field(default="hybrid", description="检索方式participle| semantic|hybrid")
|
retrieve_type: str = Field(default="hybrid", description="检索方式participle| semantic|hybrid")
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from abc import ABC
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -16,14 +15,3 @@ class Write_UserInput(BaseModel):
|
|||||||
messages: list[dict]
|
messages: list[dict]
|
||||||
end_user_id: str
|
end_user_id: str
|
||||||
config_id: Optional[str] = None
|
config_id: Optional[str] = None
|
||||||
|
|
||||||
class AgentMemory_Long_Term(ABC):
|
|
||||||
"""长期记忆配置常量"""
|
|
||||||
STORAGE_NEO4J = "neo4j"
|
|
||||||
STORAGE_RAG = "rag"
|
|
||||||
STRATEGY_AGGREGATE = "aggregate"
|
|
||||||
STRATEGY_CHUNK = "chunk"
|
|
||||||
STRATEGY_TIME = "time"
|
|
||||||
DEFAULT_SCOPE = 6
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
import uuid
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import Optional, Union
|
from typing import Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
@@ -12,7 +10,7 @@ class OptimizationStrategy(str, Enum):
|
|||||||
ACCURACY_FIRST = "accuracy_first"
|
ACCURACY_FIRST = "accuracy_first"
|
||||||
BALANCED = "balanced"
|
BALANCED = "balanced"
|
||||||
class Memory_Reflection(BaseModel):
|
class Memory_Reflection(BaseModel):
|
||||||
config_id: Union[uuid.UUID, int, str] = None
|
config_id: Optional[UUID] = None
|
||||||
reflection_enabled: bool
|
reflection_enabled: bool
|
||||||
reflection_period_in_hours: str
|
reflection_period_in_hours: str
|
||||||
reflexion_range: Optional[str] = "partial"
|
reflexion_range: Optional[str] = "partial"
|
||||||
|
|||||||
@@ -147,7 +147,7 @@ class ReflexionResultSchema(BaseModel):
|
|||||||
# Composite key identifying a config row
|
# Composite key identifying a config row
|
||||||
class ConfigKey(BaseModel): # 配置参数键模型
|
class ConfigKey(BaseModel): # 配置参数键模型
|
||||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)")
|
config_id: uuid.UUID = Field("config_id", description="配置唯一标识(UUID)")
|
||||||
user_id: str = Field("user_id", description="用户标识(字符串)")
|
user_id: str = Field("user_id", description="用户标识(字符串)")
|
||||||
apply_id: str = Field("apply_id", description="应用或场景标识(字符串)")
|
apply_id: str = Field("apply_id", description="应用或场景标识(字符串)")
|
||||||
|
|
||||||
@@ -229,32 +229,26 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body,
|
|||||||
config_desc: str = Field("配置描述", description="配置描述(字符串)")
|
config_desc: str = Field("配置描述", description="配置描述(字符串)")
|
||||||
workspace_id: Optional[uuid.UUID] = Field(None, description="工作空间ID(UUID)")
|
workspace_id: Optional[uuid.UUID] = Field(None, description="工作空间ID(UUID)")
|
||||||
|
|
||||||
# 本体场景关联(可选)
|
|
||||||
scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID(UUID),关联ontology_scene表")
|
|
||||||
|
|
||||||
# 模型配置字段(可选,用于手动指定或自动填充)
|
# 模型配置字段(可选,用于手动指定或自动填充)
|
||||||
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
||||||
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
||||||
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
||||||
reflection_model_id: Optional[str] = Field(None, description="反思模型ID,默认与llm_id一致")
|
|
||||||
emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID,默认与llm_id一致")
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
|
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
|
||||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
# config_name: str = Field("配置名称", description="配置名称(字符串)")
|
# config_name: str = Field("配置名称", description="配置名称(字符串)")
|
||||||
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)")
|
config_id: uuid.UUID = Field("配置ID", description="配置ID(UUID)")
|
||||||
|
|
||||||
|
|
||||||
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
||||||
config_id: Union[uuid.UUID, int, str] = None
|
config_id: Optional[uuid.UUID] = None
|
||||||
config_name: Optional[str] = Field(None, description="配置名称(字符串)")
|
config_name: str = Field("配置名称", description="配置名称(字符串)")
|
||||||
config_desc: Optional[str] = Field(None, description="配置描述(字符串)")
|
config_desc: str = Field("配置描述", description="配置描述(字符串)")
|
||||||
scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID")
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
||||||
config_id:Union[uuid.UUID, int, str] = None
|
config_id: Optional[uuid.UUID] = None
|
||||||
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
||||||
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
||||||
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
||||||
@@ -321,14 +315,14 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数
|
|||||||
|
|
||||||
class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型
|
class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型
|
||||||
# 遗忘引擎配置参数更新模型
|
# 遗忘引擎配置参数更新模型
|
||||||
config_id:Union[uuid.UUID, int, str] = None
|
config_id: Optional[uuid.UUID] = None
|
||||||
lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度,0-1 小数;默认 0.5")
|
lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度,0-1 小数;默认 0.5")
|
||||||
lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率,0-1 小数;默认 0.5")
|
lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率,0-1 小数;默认 0.5")
|
||||||
offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度,0-1 小数;默认 0.0")
|
offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度,0-1 小数;默认 0.0")
|
||||||
|
|
||||||
|
|
||||||
class ConfigPilotRun(BaseModel): # 试运行触发请求模型
|
class ConfigPilotRun(BaseModel): # 试运行触发请求模型
|
||||||
config_id:Union[uuid.UUID, int, str] = Field(..., description="配置ID(唯一,支持UUID、整数或字符串)")
|
config_id: uuid.UUID = Field(..., description="配置ID(唯一)")
|
||||||
dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填")
|
dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填")
|
||||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
|
|
||||||
@@ -336,7 +330,7 @@ class ConfigPilotRun(BaseModel): # 试运行触发请求模型
|
|||||||
class ConfigFilter(BaseModel): # 查询配置参数时使用的模型
|
class ConfigFilter(BaseModel): # 查询配置参数时使用的模型
|
||||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
|
|
||||||
config_id: Union[uuid.UUID, int, str] = None
|
config_id: Optional[uuid.UUID] = None
|
||||||
user_id: Optional[str] = None
|
user_id: Optional[str] = None
|
||||||
apply_id: Optional[str] = None
|
apply_id: Optional[str] = None
|
||||||
|
|
||||||
@@ -412,7 +406,7 @@ class ForgettingConfigResponse(BaseModel):
|
|||||||
"""遗忘引擎配置响应模型"""
|
"""遗忘引擎配置响应模型"""
|
||||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
|
|
||||||
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置ID(支持UUID、整数或字符串)")
|
config_id: uuid.UUID = Field(..., description="配置ID")
|
||||||
decay_constant: float = Field(..., description="衰减常数 d")
|
decay_constant: float = Field(..., description="衰减常数 d")
|
||||||
lambda_time: float = Field(..., description="时间衰减参数")
|
lambda_time: float = Field(..., description="时间衰减参数")
|
||||||
lambda_mem: float = Field(..., description="记忆衰减参数")
|
lambda_mem: float = Field(..., description="记忆衰减参数")
|
||||||
@@ -430,7 +424,7 @@ class ForgettingConfigUpdateRequest(BaseModel):
|
|||||||
"""遗忘引擎配置更新请求模型"""
|
"""遗忘引擎配置更新请求模型"""
|
||||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
|
|
||||||
config_id: Union[uuid.UUID, int,str] = Field(..., description="配置唯一标识(UUID或int)")
|
config_id: uuid.UUID = Field(..., description="配置ID")
|
||||||
decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d")
|
decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d")
|
||||||
lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数")
|
lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数")
|
||||||
lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数")
|
lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数")
|
||||||
@@ -505,7 +499,7 @@ class ForgettingCurveRequest(BaseModel):
|
|||||||
|
|
||||||
importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数(0-1)")
|
importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数(0-1)")
|
||||||
days: int = Field(60, ge=1, le=365, description="模拟天数(默认60天)")
|
days: int = Field(60, ge=1, le=365, description="模拟天数(默认60天)")
|
||||||
config_id: Union[uuid.UUID, int, str] = Field(..., description="配置唯一标识(UUID或int)")
|
config_id: Optional[uuid.UUID] = Field(None, description="配置ID(可选,如果为None则使用默认配置)")
|
||||||
|
|
||||||
|
|
||||||
class ForgettingCurveResponse(BaseModel):
|
class ForgettingCurveResponse(BaseModel):
|
||||||
|
|||||||
@@ -3,10 +3,8 @@ from typing import Optional, List, Dict, Any
|
|||||||
import datetime
|
import datetime
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from app.models.models_model import ModelProvider, ModelType, LoadBalanceStrategy
|
from app.models.models_model import ModelProvider, ModelType
|
||||||
from app.core.logging_config import get_business_logger
|
|
||||||
|
|
||||||
schema_logger = get_business_logger()
|
|
||||||
|
|
||||||
|
|
||||||
# ModelConfig Schemas
|
# ModelConfig Schemas
|
||||||
@@ -14,19 +12,15 @@ class ModelConfigBase(BaseModel):
|
|||||||
"""模型配置基础Schema"""
|
"""模型配置基础Schema"""
|
||||||
name: str = Field(..., description="模型显示名称", max_length=255)
|
name: str = Field(..., description="模型显示名称", max_length=255)
|
||||||
type: ModelType = Field(..., description="模型类型")
|
type: ModelType = Field(..., description="模型类型")
|
||||||
logo: Optional[str] = Field(None, description="模型logo图片URL", max_length=255)
|
|
||||||
description: Optional[str] = Field(None, description="模型描述")
|
description: Optional[str] = Field(None, description="模型描述")
|
||||||
provider: str = Field(..., description="供应商")
|
|
||||||
config: Optional[Dict[str, Any]] = Field({}, description="模型配置参数")
|
config: Optional[Dict[str, Any]] = Field({}, description="模型配置参数")
|
||||||
is_active: bool = Field(True, description="是否激活")
|
is_active: bool = Field(True, description="是否激活")
|
||||||
is_public: bool = Field(False, description="是否公开")
|
is_public: bool = Field(False, description="是否公开")
|
||||||
load_balance_strategy: Optional[str] = Field(LoadBalanceStrategy.NONE.value, description="负载均衡策略")
|
|
||||||
|
|
||||||
|
|
||||||
class ApiKeyCreateNested(BaseModel):
|
class ApiKeyCreateNested(BaseModel):
|
||||||
"""用于在创建模型时内嵌创建API Key的Schema"""
|
"""用于在创建模型时内嵌创建API Key的Schema"""
|
||||||
model_name: str = Field(..., description="模型实际名称", max_length=255)
|
model_name: str = Field(..., description="模型实际名称", max_length=255)
|
||||||
description: Optional[str] = Field(None, description="备注")
|
|
||||||
provider: ModelProvider = Field(..., description="API Key提供商")
|
provider: ModelProvider = Field(..., description="API Key提供商")
|
||||||
api_key: str = Field(..., description="API密钥", max_length=500)
|
api_key: str = Field(..., description="API密钥", max_length=500)
|
||||||
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
|
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
|
||||||
@@ -36,23 +30,10 @@ class ApiKeyCreateNested(BaseModel):
|
|||||||
|
|
||||||
class ModelConfigCreate(ModelConfigBase):
|
class ModelConfigCreate(ModelConfigBase):
|
||||||
"""创建模型配置Schema"""
|
"""创建模型配置Schema"""
|
||||||
api_keys: Optional[List[ApiKeyCreateNested]] = Field(None, description="同时创建的API Key配置")
|
api_keys: Optional[ApiKeyCreateNested] = Field(None, description="同时创建的API Key配置")
|
||||||
skip_validation: Optional[bool] = Field(False, description="是否跳过配置验证")
|
skip_validation: Optional[bool] = Field(False, description="是否跳过配置验证")
|
||||||
|
|
||||||
|
|
||||||
class CompositeModelCreate(BaseModel):
|
|
||||||
"""创建组合模型Schema"""
|
|
||||||
name: str = Field(..., description="组合模型名称", max_length=255)
|
|
||||||
type: Optional[ModelType] = Field(None, description="模型类型")
|
|
||||||
logo: Optional[str] = Field(None, description="模型logo图片URL", max_length=255)
|
|
||||||
description: Optional[str] = Field(None, description="模型描述")
|
|
||||||
config: Optional[Dict[str, Any]] = Field({}, description="模型配置参数")
|
|
||||||
is_active: bool = Field(True, description="是否激活")
|
|
||||||
is_public: bool = Field(False, description="是否公开")
|
|
||||||
api_key_ids: List[uuid.UUID] = Field(..., description="绑定的API Key ID列表")
|
|
||||||
load_balance_strategy: Optional[str] = Field(default=LoadBalanceStrategy.NONE.value, description="负载均衡策略")
|
|
||||||
|
|
||||||
|
|
||||||
class ModelConfigUpdate(BaseModel):
|
class ModelConfigUpdate(BaseModel):
|
||||||
"""更新模型配置Schema"""
|
"""更新模型配置Schema"""
|
||||||
name: Optional[str] = Field(None, description="模型显示名称", max_length=255)
|
name: Optional[str] = Field(None, description="模型显示名称", max_length=255)
|
||||||
@@ -72,48 +53,22 @@ class ModelConfig(ModelConfigBase):
|
|||||||
updated_at: datetime.datetime
|
updated_at: datetime.datetime
|
||||||
api_keys: List["ModelApiKey"] = []
|
api_keys: List["ModelApiKey"] = []
|
||||||
|
|
||||||
@field_validator("api_keys", mode="after")
|
|
||||||
@classmethod
|
|
||||||
def filter_active_api_keys(cls, api_keys: List["ModelApiKey"]) -> List["ModelApiKey"]:
|
|
||||||
return [key for key in api_keys if key.is_active]
|
|
||||||
|
|
||||||
@field_serializer("created_at", when_used="json")
|
|
||||||
def _serialize_created_at(self, dt: datetime.datetime | None):
|
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
|
||||||
|
|
||||||
@field_serializer("updated_at", when_used="json")
|
|
||||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
|
||||||
|
|
||||||
|
|
||||||
# ModelApiKey Schemas
|
# ModelApiKey Schemas
|
||||||
class ModelApiKeyCreateByProvider(BaseModel):
|
|
||||||
"""基于供应商创建API Key Schema"""
|
|
||||||
provider: ModelProvider = Field(..., description="API Key提供商")
|
|
||||||
api_key: str = Field(..., description="API密钥", max_length=500)
|
|
||||||
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
|
|
||||||
description: Optional[str] = Field(None, description="备注")
|
|
||||||
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
|
|
||||||
is_active: bool = Field(True, description="是否激活")
|
|
||||||
priority: str = Field("1", description="优先级", max_length=10)
|
|
||||||
model_config_ids: Optional[List[uuid.UUID]] = Field(None, description="关联的模型配置ID列表")
|
|
||||||
|
|
||||||
|
|
||||||
class ModelApiKeyBase(BaseModel):
|
class ModelApiKeyBase(BaseModel):
|
||||||
"""API Key基础Schema"""
|
"""API Key基础Schema"""
|
||||||
model_name: str = Field(..., description="模型实际名称", max_length=255)
|
model_name: str = Field(..., description="模型实际名称", max_length=255)
|
||||||
description: Optional[str] = Field(None, description="备注")
|
|
||||||
provider: ModelProvider = Field(..., description="API Key提供商")
|
provider: ModelProvider = Field(..., description="API Key提供商")
|
||||||
api_key: str = Field(..., description="API密钥", max_length=500)
|
api_key: str = Field(..., description="API密钥", max_length=500)
|
||||||
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
|
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
|
||||||
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
|
config: Optional[Dict[str, Any]] = Field(None, description="API Key特定配置")
|
||||||
is_active: bool = Field(True, description="是否激活")
|
is_active: bool = Field(True, description="是否激活")
|
||||||
priority: str = Field("1", description="优先级", max_length=10)
|
priority: str = Field("1", description="优先级", max_length=10)
|
||||||
|
|
||||||
|
|
||||||
class ModelApiKeyCreate(ModelApiKeyBase):
|
class ModelApiKeyCreate(ModelApiKeyBase):
|
||||||
"""创建API Key Schema"""
|
"""创建API Key Schema"""
|
||||||
model_config_ids: Optional[List[uuid.UUID]] = Field(None, description="关联的模型配置ID列表")
|
model_config_id: uuid.UUID = Field(..., description="模型配置ID")
|
||||||
|
|
||||||
|
|
||||||
class ModelApiKeyUpdate(BaseModel):
|
class ModelApiKeyUpdate(BaseModel):
|
||||||
@@ -130,54 +85,23 @@ class ModelApiKeyUpdate(BaseModel):
|
|||||||
class ModelApiKey(ModelApiKeyBase):
|
class ModelApiKey(ModelApiKeyBase):
|
||||||
"""API Key Schema"""
|
"""API Key Schema"""
|
||||||
id: uuid.UUID
|
id: uuid.UUID
|
||||||
|
model_config_id: uuid.UUID
|
||||||
usage_count: str
|
usage_count: str
|
||||||
last_used_at: Optional[datetime.datetime]
|
last_used_at: Optional[datetime.datetime]
|
||||||
created_at: datetime.datetime
|
created_at: datetime.datetime
|
||||||
updated_at: datetime.datetime
|
updated_at: datetime.datetime
|
||||||
model_configs: Any = Field(default=None, exclude=True)
|
|
||||||
model_config_ids: List[uuid.UUID] = Field(default_factory=list, description="关联的模型配置ID列表")
|
|
||||||
|
|
||||||
def model_post_init(self, __context: Any) -> None:
|
@field_validator("config", mode="before")
|
||||||
"""实例化后强制提取 model_configs 的ID到 model_config_ids"""
|
@classmethod
|
||||||
# 如果手动传入了 model_config_ids,不覆盖
|
def parse_config(cls, v):
|
||||||
if self.model_config_ids and len(self.model_config_ids) > 0:
|
"""处理 config 字段,如果是字符串则解析为字典"""
|
||||||
return
|
if isinstance(v, str):
|
||||||
|
import json
|
||||||
# 从 model_configs 提取ID(只提取与 model_name 相同的非组合模型)
|
|
||||||
if self.model_configs is not None:
|
|
||||||
try:
|
try:
|
||||||
# 情况1:ORM 对象列表(SQLAlchemy 关联)
|
return json.loads(v)
|
||||||
if hasattr(self.model_configs, '__iter__') and not isinstance(self.model_configs, dict):
|
except json.JSONDecodeError:
|
||||||
self.model_config_ids = [
|
return {}
|
||||||
mc.id for mc in self.model_configs
|
return v
|
||||||
if hasattr(mc, 'id')
|
|
||||||
and not getattr(mc, 'is_composite', False)
|
|
||||||
and getattr(mc, 'name', None) == self.model_name
|
|
||||||
]
|
|
||||||
# 情况2:字典列表
|
|
||||||
elif isinstance(self.model_configs, list):
|
|
||||||
self.model_config_ids = [
|
|
||||||
mc['id'] if isinstance(mc, dict) else mc.id
|
|
||||||
for mc in self.model_configs
|
|
||||||
if ((isinstance(mc, dict)
|
|
||||||
and 'id' in mc
|
|
||||||
and not mc.get('is_composite', False)
|
|
||||||
and mc.get('name') == self.model_name) or
|
|
||||||
(hasattr(mc, 'id')
|
|
||||||
and not getattr(mc, 'is_composite', False)
|
|
||||||
and getattr(mc, 'name', None) == self.model_name))
|
|
||||||
]
|
|
||||||
except Exception as e:
|
|
||||||
schema_logger.warning(f"提取 model_config_ids 失败:{e}")
|
|
||||||
self.model_config_ids = []
|
|
||||||
|
|
||||||
model_config = ConfigDict(
|
|
||||||
from_attributes=True, # 支持从 ORM 解析
|
|
||||||
arbitrary_types_allowed=True, # 允许任意类型(ORM 对象)
|
|
||||||
populate_by_name=True, # 按属性名匹配字段
|
|
||||||
validate_assignment=True # 确保赋值触发校验
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@field_serializer("created_at", when_used="json")
|
@field_serializer("created_at", when_used="json")
|
||||||
def _serialize_created_at(self, dt: datetime.datetime):
|
def _serialize_created_at(self, dt: datetime.datetime):
|
||||||
@@ -187,11 +111,14 @@ class ModelApiKey(ModelApiKeyBase):
|
|||||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
@field_serializer("last_used_at", when_used="json")
|
@field_serializer("last_used_at", when_used="json")
|
||||||
def _serialize_last_used_at(self, dt: datetime.datetime):
|
def _serialize_last_used_at(self, dt: datetime.datetime):
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
|
||||||
|
# 查询和响应Schemas
|
||||||
class ModelConfigQuery(BaseModel):
|
class ModelConfigQuery(BaseModel):
|
||||||
"""模型配置查询Schema"""
|
"""模型配置查询Schema"""
|
||||||
type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)")
|
type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)")
|
||||||
@@ -202,17 +129,6 @@ class ModelConfigQuery(BaseModel):
|
|||||||
page: int = Field(1, description="页码", ge=1)
|
page: int = Field(1, description="页码", ge=1)
|
||||||
pagesize: int = Field(10, description="每页数量", ge=1, le=100)
|
pagesize: int = Field(10, description="每页数量", ge=1, le=100)
|
||||||
|
|
||||||
|
|
||||||
# 查询和响应Schemas
|
|
||||||
class ModelConfigQueryNew(BaseModel):
|
|
||||||
"""模型配置查询Schema"""
|
|
||||||
type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)")
|
|
||||||
provider: Optional[ModelProvider] = Field(None, description="提供商筛选(通过API Key)")
|
|
||||||
is_active: Optional[bool] = Field(None, description="激活状态筛选")
|
|
||||||
is_public: Optional[bool] = Field(None, description="公开状态筛选")
|
|
||||||
is_composite: Optional[bool] = Field(None, description="组合模型筛选")
|
|
||||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
|
||||||
|
|
||||||
class ModelMarketplace(BaseModel):
|
class ModelMarketplace(BaseModel):
|
||||||
"""模型广场响应Schema"""
|
"""模型广场响应Schema"""
|
||||||
llm_models: List[ModelConfig] = []
|
llm_models: List[ModelConfig] = []
|
||||||
@@ -256,52 +172,3 @@ class ModelValidateResponse(BaseModel):
|
|||||||
|
|
||||||
# 更新前向引用
|
# 更新前向引用
|
||||||
ModelConfig.model_rebuild()
|
ModelConfig.model_rebuild()
|
||||||
|
|
||||||
|
|
||||||
# ModelBase Schemas
|
|
||||||
class ModelBaseCreate(BaseModel):
|
|
||||||
"""创建基础模型Schema"""
|
|
||||||
name: str = Field(..., description="模型唯一标识", max_length=255)
|
|
||||||
type: ModelType = Field(..., description="模型类型")
|
|
||||||
provider: ModelProvider = Field(..., description="提供商")
|
|
||||||
logo: Optional[str] = Field(None, description="模型logo图片URL", max_length=255)
|
|
||||||
description: Optional[str] = Field(None, description="模型描述")
|
|
||||||
is_official: bool = Field(True, description="是否供应商官方模型")
|
|
||||||
tags: List[str] = Field(default_factory=list, description="模型标签")
|
|
||||||
|
|
||||||
|
|
||||||
class ModelBaseUpdate(BaseModel):
|
|
||||||
"""更新基础模型Schema"""
|
|
||||||
name: Optional[str] = Field(None, description="模型唯一标识", max_length=255)
|
|
||||||
type: Optional[ModelType] = Field(None, description="模型类型")
|
|
||||||
provider: Optional[ModelProvider] = Field(None, description="提供商")
|
|
||||||
logo: Optional[str] = Field(None, description="模型logo图片URL", max_length=255)
|
|
||||||
description: Optional[str] = Field(None, description="模型描述")
|
|
||||||
is_deprecated: Optional[bool] = Field(None, description="是否弃用")
|
|
||||||
is_official: Optional[bool] = Field(None, description="是否供应商官方模型")
|
|
||||||
tags: Optional[List[str]] = Field(None, description="模型标签")
|
|
||||||
|
|
||||||
|
|
||||||
class ModelBase(BaseModel):
|
|
||||||
"""基础模型Schema"""
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
id: uuid.UUID
|
|
||||||
name: str
|
|
||||||
type: str
|
|
||||||
provider: str
|
|
||||||
logo: Optional[str]
|
|
||||||
description: Optional[str]
|
|
||||||
is_deprecated: bool
|
|
||||||
is_official: bool
|
|
||||||
tags: List[str]
|
|
||||||
add_count: int
|
|
||||||
|
|
||||||
|
|
||||||
class ModelBaseQuery(BaseModel):
|
|
||||||
"""基础模型查询Schema"""
|
|
||||||
type: Optional[ModelType] = Field(None, description="模型类型")
|
|
||||||
provider: Optional[ModelProvider] = Field(None, description="提供商")
|
|
||||||
is_official: Optional[bool] = Field(None, description="是否官方模型")
|
|
||||||
is_deprecated: Optional[bool] = Field(None, description="是否弃用")
|
|
||||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
|
||||||
|
|||||||
@@ -1,461 +0,0 @@
|
|||||||
"""本体提取API的请求和响应模型
|
|
||||||
|
|
||||||
本模块定义了本体提取系统的所有API请求和响应的Pydantic模型。
|
|
||||||
|
|
||||||
Classes:
|
|
||||||
ExtractionRequest: 本体提取请求模型
|
|
||||||
ExtractionResponse: 本体提取响应模型
|
|
||||||
ExportRequest: OWL文件导出请求模型
|
|
||||||
ExportResponse: OWL文件导出响应模型
|
|
||||||
OntologyResultResponse: 本体提取结果响应模型(带毫秒时间戳)
|
|
||||||
SceneCreateRequest: 场景创建请求模型
|
|
||||||
SceneUpdateRequest: 场景更新请求模型
|
|
||||||
SceneResponse: 场景响应模型
|
|
||||||
SceneListResponse: 场景列表响应模型
|
|
||||||
ClassCreateRequest: 类型创建请求模型
|
|
||||||
ClassUpdateRequest: 类型更新请求模型
|
|
||||||
ClassResponse: 类型响应模型
|
|
||||||
ClassListResponse: 类型列表响应模型
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import List, Optional
|
|
||||||
import datetime
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_serializer, ConfigDict
|
|
||||||
|
|
||||||
from app.core.memory.models.ontology_models import OntologyClass
|
|
||||||
|
|
||||||
|
|
||||||
class ExtractionRequest(BaseModel):
|
|
||||||
"""本体提取请求模型
|
|
||||||
|
|
||||||
用于POST /api/ontology/extract端点的请求体。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
scenario: 场景描述文本,不能为空
|
|
||||||
domain: 可选的领域提示(如Healthcare, Education等)
|
|
||||||
llm_id: LLM模型ID,必须提供
|
|
||||||
scene_id: 场景ID,必须提供,用于将提取的类保存到指定场景
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> request = ExtractionRequest(
|
|
||||||
... scenario="医院管理患者记录...",
|
|
||||||
... domain="Healthcare",
|
|
||||||
... llm_id="550e8400-e29b-41d4-a716-446655440000",
|
|
||||||
... scene_id="660e8400-e29b-41d4-a716-446655440000"
|
|
||||||
... )
|
|
||||||
"""
|
|
||||||
scenario: str = Field(..., description="场景描述文本", min_length=1)
|
|
||||||
domain: Optional[str] = Field(None, description="可选的领域提示")
|
|
||||||
llm_id: str = Field(..., description="LLM模型ID")
|
|
||||||
scene_id: UUID = Field(..., description="场景ID,用于将提取的类保存到指定场景")
|
|
||||||
|
|
||||||
|
|
||||||
class ExtractionResponse(BaseModel):
|
|
||||||
"""本体提取响应模型
|
|
||||||
|
|
||||||
用于POST /api/ontology/extract端点的响应体。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
classes: 提取的本体类列表
|
|
||||||
domain: 识别的领域
|
|
||||||
extracted_count: 提取的类数量
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> response = ExtractionResponse(
|
|
||||||
... classes=[...],
|
|
||||||
... domain="Healthcare",
|
|
||||||
... extracted_count=7
|
|
||||||
... )
|
|
||||||
"""
|
|
||||||
classes: List[OntologyClass] = Field(default_factory=list, description="提取的本体类列表")
|
|
||||||
domain: str = Field(..., description="识别的领域")
|
|
||||||
extracted_count: int = Field(..., description="提取的类数量")
|
|
||||||
|
|
||||||
|
|
||||||
class ExportRequest(BaseModel):
|
|
||||||
"""OWL文件导出请求模型
|
|
||||||
|
|
||||||
用于POST /api/ontology/export端点的请求体。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
classes: 要导出的本体类列表
|
|
||||||
format: 导出格式,可选值: rdfxml, turtle, ntriples, json
|
|
||||||
include_metadata: 是否包含完整的OWL元数据(命名空间等),默认True
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> request = ExportRequest(
|
|
||||||
... classes=[...],
|
|
||||||
... format="rdfxml",
|
|
||||||
... include_metadata=True
|
|
||||||
... )
|
|
||||||
"""
|
|
||||||
classes: List[OntologyClass] = Field(..., description="要导出的本体类列表", min_length=1)
|
|
||||||
format: str = Field("rdfxml", description="导出格式: rdfxml, turtle, ntriples, json")
|
|
||||||
include_metadata: bool = Field(True, description="是否包含完整的OWL元数据")
|
|
||||||
|
|
||||||
|
|
||||||
class ExportResponse(BaseModel):
|
|
||||||
"""OWL文件导出响应模型
|
|
||||||
|
|
||||||
用于POST /api/ontology/export端点的响应体。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
owl_content: OWL文件内容
|
|
||||||
format: 导出格式
|
|
||||||
classes_count: 导出的类数量
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> response = ExportResponse(
|
|
||||||
... owl_content="<?xml version='1.0'?>...",
|
|
||||||
... format="rdfxml",
|
|
||||||
... classes_count=7
|
|
||||||
... )
|
|
||||||
"""
|
|
||||||
owl_content: str = Field(..., description="OWL文件内容")
|
|
||||||
format: str = Field(..., description="导出格式")
|
|
||||||
classes_count: int = Field(..., description="导出的类数量")
|
|
||||||
|
|
||||||
|
|
||||||
class OntologyResultResponse(BaseModel):
|
|
||||||
"""本体提取结果响应模型
|
|
||||||
|
|
||||||
用于返回数据库中存储的提取结果,时间戳为毫秒级。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
id: 结果ID (UUID)
|
|
||||||
scenario: 场景描述文本
|
|
||||||
domain: 领域
|
|
||||||
classes_json: 提取的本体类数据(JSON格式)
|
|
||||||
extracted_count: 提取的类数量
|
|
||||||
user_id: 用户ID
|
|
||||||
created_at: 创建时间(毫秒时间戳)
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> response = OntologyResultResponse(
|
|
||||||
... id=uuid.uuid4(),
|
|
||||||
... scenario="医院管理患者记录...",
|
|
||||||
... domain="Healthcare",
|
|
||||||
... classes_json={"classes": [...]},
|
|
||||||
... extracted_count=7,
|
|
||||||
... user_id=123,
|
|
||||||
... created_at=datetime.now()
|
|
||||||
... )
|
|
||||||
"""
|
|
||||||
id: UUID = Field(..., description="结果ID")
|
|
||||||
scenario: str = Field(..., description="场景描述文本")
|
|
||||||
domain: Optional[str] = Field(None, description="领域")
|
|
||||||
classes_json: dict = Field(..., description="提取的本体类数据(JSON格式)")
|
|
||||||
extracted_count: int = Field(..., description="提取的类数量")
|
|
||||||
user_id: Optional[int] = Field(None, description="用户ID")
|
|
||||||
created_at: datetime.datetime = Field(..., description="创建时间")
|
|
||||||
|
|
||||||
@field_serializer("created_at", when_used="json")
|
|
||||||
def _serialize_created_at(self, dt: datetime.datetime):
|
|
||||||
"""将创建时间序列化为毫秒时间戳"""
|
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
from_attributes = True
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 本体场景相关 Schema ====================
|
|
||||||
|
|
||||||
class SceneCreateRequest(BaseModel):
|
|
||||||
"""场景创建请求模型
|
|
||||||
|
|
||||||
用于创建新的本体场景。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
scene_name: 场景名称,必填,1-200字符
|
|
||||||
scene_description: 场景描述,可选
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> request = SceneCreateRequest(
|
|
||||||
... scene_name="医疗场景",
|
|
||||||
... scene_description="用于医疗领域的本体建模"
|
|
||||||
... )
|
|
||||||
"""
|
|
||||||
scene_name: str = Field(..., min_length=1, max_length=200, description="场景名称")
|
|
||||||
scene_description: Optional[str] = Field(None, description="场景描述")
|
|
||||||
|
|
||||||
|
|
||||||
class SceneUpdateRequest(BaseModel):
|
|
||||||
"""场景更新请求模型
|
|
||||||
|
|
||||||
用于更新已有本体场景信息。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
scene_name: 场景名称,可选,1-200字符
|
|
||||||
scene_description: 场景描述,可选
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> request = SceneUpdateRequest(
|
|
||||||
... scene_name="更新后的场景名称",
|
|
||||||
... scene_description="更新后的描述"
|
|
||||||
... )
|
|
||||||
"""
|
|
||||||
scene_name: Optional[str] = Field(None, min_length=1, max_length=200, description="场景名称")
|
|
||||||
scene_description: Optional[str] = Field(None, description="场景描述")
|
|
||||||
|
|
||||||
|
|
||||||
class SceneResponse(BaseModel):
|
|
||||||
"""场景响应模型
|
|
||||||
|
|
||||||
用于返回本体场景信息。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
scene_id: 场景ID
|
|
||||||
scene_name: 场景名称
|
|
||||||
scene_description: 场景描述
|
|
||||||
type_num: 类型数量
|
|
||||||
workspace_id: 所属工作空间ID
|
|
||||||
created_at: 创建时间(毫秒时间戳)
|
|
||||||
updated_at: 更新时间(毫秒时间戳)
|
|
||||||
classes_count: 类型数量
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> response = SceneResponse(
|
|
||||||
... scene_id=uuid.uuid4(),
|
|
||||||
... scene_name="医疗场景",
|
|
||||||
... scene_description="用于医疗领域的本体建模",
|
|
||||||
... type_num=0,
|
|
||||||
... workspace_id=uuid.uuid4(),
|
|
||||||
... created_at=datetime.now(),
|
|
||||||
... updated_at=datetime.now(),
|
|
||||||
... classes_count=5
|
|
||||||
... )
|
|
||||||
"""
|
|
||||||
scene_id: UUID = Field(..., description="场景ID")
|
|
||||||
scene_name: str = Field(..., description="场景名称")
|
|
||||||
scene_description: Optional[str] = Field(None, description="场景描述")
|
|
||||||
type_num: int = Field(..., description="类型数量")
|
|
||||||
entity_type: Optional[List[str]] = Field(None, description="实体类型列表(最多3个class_name)")
|
|
||||||
workspace_id: UUID = Field(..., description="所属工作空间ID")
|
|
||||||
created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)")
|
|
||||||
updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)")
|
|
||||||
classes_count: int = Field(0, description="类型数量")
|
|
||||||
|
|
||||||
@field_serializer("created_at", when_used="json")
|
|
||||||
def _serialize_created_at(self, dt: datetime.datetime):
|
|
||||||
"""将创建时间序列化为毫秒时间戳"""
|
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
|
||||||
|
|
||||||
@field_serializer("updated_at", when_used="json")
|
|
||||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
|
||||||
"""将更新时间序列化为毫秒时间戳"""
|
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
|
|
||||||
class PaginationInfo(BaseModel):
|
|
||||||
"""分页信息模型
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
page: 当前页码
|
|
||||||
pagesize: 每页数量
|
|
||||||
total: 总数量
|
|
||||||
hasnext: 是否有下一页
|
|
||||||
"""
|
|
||||||
page: int = Field(..., description="当前页码")
|
|
||||||
pagesize: int = Field(..., description="每页数量")
|
|
||||||
total: int = Field(..., description="总数量")
|
|
||||||
hasnext: bool = Field(..., description="是否有下一页")
|
|
||||||
|
|
||||||
|
|
||||||
class SceneListResponse(BaseModel):
|
|
||||||
"""场景列表响应模型(支持分页)
|
|
||||||
|
|
||||||
用于返回本体场景列表。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
items: 场景列表
|
|
||||||
page: 分页信息(可选,分页时返回)
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> # 不分页
|
|
||||||
>>> response = SceneListResponse(
|
|
||||||
... items=[scene1, scene2]
|
|
||||||
... )
|
|
||||||
>>> # 分页
|
|
||||||
>>> response = SceneListResponse(
|
|
||||||
... items=[scene1, scene2, ...],
|
|
||||||
... page=PaginationInfo(page=1, pagesize=100, total=150, hasnext=True)
|
|
||||||
... )
|
|
||||||
"""
|
|
||||||
items: List[SceneResponse] = Field(..., description="场景列表")
|
|
||||||
page: Optional[PaginationInfo] = Field(None, description="分页信息")
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 本体类型相关 Schema ====================
|
|
||||||
|
|
||||||
class ClassItem(BaseModel):
|
|
||||||
"""单个类型信息模型
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
class_name: 类型名称,必填,1-200字符
|
|
||||||
class_description: 类型描述,可选
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> item = ClassItem(
|
|
||||||
... class_name="患者",
|
|
||||||
... class_description="医院患者信息"
|
|
||||||
... )
|
|
||||||
"""
|
|
||||||
class_name: str = Field(..., min_length=1, max_length=200, description="类型名称")
|
|
||||||
class_description: Optional[str] = Field(None, description="类型描述")
|
|
||||||
|
|
||||||
|
|
||||||
class ClassCreateRequest(BaseModel):
|
|
||||||
"""类型创建请求模型(统一使用列表形式)
|
|
||||||
|
|
||||||
通过列表中元素数量决定创建模式:
|
|
||||||
- 列表包含 1 个元素:单个创建
|
|
||||||
- 列表包含多个元素:批量创建
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
scene_id: 所属场景ID,必填
|
|
||||||
classes: 类型列表,必填,至少包含 1 个元素
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
# 单个创建(列表中 1 个元素)
|
|
||||||
>>> request = ClassCreateRequest(
|
|
||||||
... scene_id=uuid.uuid4(),
|
|
||||||
... classes=[
|
|
||||||
... ClassItem(class_name="患者", class_description="医院患者信息")
|
|
||||||
... ]
|
|
||||||
... )
|
|
||||||
|
|
||||||
# 批量创建(列表中多个元素)
|
|
||||||
>>> request = ClassCreateRequest(
|
|
||||||
... scene_id=uuid.uuid4(),
|
|
||||||
... classes=[
|
|
||||||
... ClassItem(class_name="患者", class_description="医院患者信息"),
|
|
||||||
... ClassItem(class_name="医生", class_description="医院医生信息"),
|
|
||||||
... ClassItem(class_name="药品", class_description="医院药品信息")
|
|
||||||
... ]
|
|
||||||
... )
|
|
||||||
"""
|
|
||||||
scene_id: UUID = Field(..., description="所属场景ID")
|
|
||||||
classes: List[ClassItem] = Field(..., min_length=1, description="类型列表,至少包含 1 个元素")
|
|
||||||
|
|
||||||
|
|
||||||
class ClassUpdateRequest(BaseModel):
|
|
||||||
"""类型更新请求模型
|
|
||||||
|
|
||||||
用于更新已有本体类型信息。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
class_name: 类型名称,可选,1-200字符
|
|
||||||
class_description: 类型描述,可选
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> request = ClassUpdateRequest(
|
|
||||||
... class_name="更新后的类型名称",
|
|
||||||
... class_description="更新后的描述"
|
|
||||||
... )
|
|
||||||
"""
|
|
||||||
class_name: Optional[str] = Field(None, min_length=1, max_length=200, description="类型名称")
|
|
||||||
class_description: Optional[str] = Field(None, description="类型描述")
|
|
||||||
|
|
||||||
|
|
||||||
class ClassResponse(BaseModel):
|
|
||||||
"""类型响应模型
|
|
||||||
|
|
||||||
用于返回本体类型信息。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
class_id: 类型ID
|
|
||||||
class_name: 类型名称
|
|
||||||
class_description: 类型描述
|
|
||||||
scene_id: 所属场景ID
|
|
||||||
created_at: 创建时间(毫秒时间戳)
|
|
||||||
updated_at: 更新时间(毫秒时间戳)
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> response = ClassResponse(
|
|
||||||
... class_id=uuid.uuid4(),
|
|
||||||
... class_name="患者",
|
|
||||||
... class_description="医院患者信息",
|
|
||||||
... scene_id=uuid.uuid4(),
|
|
||||||
... created_at=datetime.now(),
|
|
||||||
... updated_at=datetime.now()
|
|
||||||
... )
|
|
||||||
"""
|
|
||||||
class_id: UUID = Field(..., description="类型ID")
|
|
||||||
class_name: str = Field(..., description="类型名称")
|
|
||||||
class_description: Optional[str] = Field(None, description="类型描述")
|
|
||||||
scene_id: UUID = Field(..., description="所属场景ID")
|
|
||||||
created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)")
|
|
||||||
updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)")
|
|
||||||
|
|
||||||
@field_serializer("created_at", when_used="json")
|
|
||||||
def _serialize_created_at(self, dt: datetime.datetime):
|
|
||||||
"""将创建时间序列化为毫秒时间戳"""
|
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
|
||||||
|
|
||||||
@field_serializer("updated_at", when_used="json")
|
|
||||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
|
||||||
"""将更新时间序列化为毫秒时间戳"""
|
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
|
|
||||||
class ClassBatchCreateResponse(BaseModel):
|
|
||||||
"""批量创建类型响应模型
|
|
||||||
|
|
||||||
用于返回批量创建的结果统计和详情。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
total: 总共尝试创建的数量
|
|
||||||
success_count: 成功创建的数量
|
|
||||||
failed_count: 失败的数量
|
|
||||||
items: 成功创建的类型列表
|
|
||||||
errors: 失败的错误信息列表(可选)
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> response = ClassBatchCreateResponse(
|
|
||||||
... total=3,
|
|
||||||
... success_count=2,
|
|
||||||
... failed_count=1,
|
|
||||||
... items=[class1, class2],
|
|
||||||
... errors=["创建类型 '药品' 失败: 类型名称已存在"]
|
|
||||||
... )
|
|
||||||
"""
|
|
||||||
total: int = Field(..., description="总共尝试创建的数量")
|
|
||||||
success_count: int = Field(..., description="成功创建的数量")
|
|
||||||
failed_count: int = Field(0, description="失败的数量")
|
|
||||||
items: List[ClassResponse] = Field(..., description="成功创建的类型列表")
|
|
||||||
errors: Optional[List[str]] = Field(None, description="失败的错误信息列表")
|
|
||||||
|
|
||||||
|
|
||||||
class ClassListResponse(BaseModel):
|
|
||||||
"""类型列表响应模型
|
|
||||||
|
|
||||||
用于返回本体类型列表。
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
total: 总数量
|
|
||||||
scene_id: 所属场景ID
|
|
||||||
scene_name: 场景名称
|
|
||||||
scene_description: 场景描述
|
|
||||||
items: 类型列表
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> response = ClassListResponse(
|
|
||||||
... total=3,
|
|
||||||
... scene_id=uuid.uuid4(),
|
|
||||||
... scene_name="医疗场景",
|
|
||||||
... scene_description="用于医疗领域的本体建模",
|
|
||||||
... items=[class1, class2, class3]
|
|
||||||
... )
|
|
||||||
"""
|
|
||||||
total: int = Field(..., description="总数量")
|
|
||||||
scene_id: UUID = Field(..., description="所属场景ID")
|
|
||||||
scene_name: str = Field(..., description="场景名称")
|
|
||||||
scene_description: Optional[str] = Field(None, description="场景描述")
|
|
||||||
items: List[ClassResponse] = Field(..., description="类型列表")
|
|
||||||
@@ -22,23 +22,6 @@ class PromptOptMessage(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PromptSaveRequest(BaseModel):
|
|
||||||
session_id: UUID = Field(
|
|
||||||
...,
|
|
||||||
description="Session ID"
|
|
||||||
)
|
|
||||||
|
|
||||||
title: str = Field(
|
|
||||||
...,
|
|
||||||
description="Prompt Title"
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt: str = Field(
|
|
||||||
...,
|
|
||||||
description="Optimized prompt content"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PromptOptModelSet(BaseModel):
|
class PromptOptModelSet(BaseModel):
|
||||||
id: UUID | None = Field(
|
id: UUID | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
|
|||||||
@@ -171,14 +171,7 @@ class AppChatService:
|
|||||||
self.conversation_service.save_conversation_messages(
|
self.conversation_service.save_conversation_messages(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
user_message=message,
|
user_message=message,
|
||||||
assistant_message=result["content"],
|
assistant_message=result["content"]
|
||||||
meta_data={
|
|
||||||
"usage": result.get("usage", {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"completion_tokens": 0,
|
|
||||||
"total_tokens": 0
|
|
||||||
})
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
@@ -317,7 +310,6 @@ class AppChatService:
|
|||||||
|
|
||||||
# 流式调用 Agent
|
# 流式调用 Agent
|
||||||
full_content = ""
|
full_content = ""
|
||||||
total_tokens = 0
|
|
||||||
async for chunk in agent.chat_stream(
|
async for chunk in agent.chat_stream(
|
||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
@@ -328,12 +320,9 @@ class AppChatService:
|
|||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
memory_flag=memory_flag
|
memory_flag=memory_flag
|
||||||
):
|
):
|
||||||
if isinstance(chunk, int):
|
full_content += chunk
|
||||||
total_tokens = chunk
|
# 发送消息块事件
|
||||||
else:
|
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
|
||||||
full_content += chunk
|
|
||||||
# 发送消息块事件
|
|
||||||
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
|
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
@@ -350,7 +339,7 @@ class AppChatService:
|
|||||||
content=full_content,
|
content=full_content,
|
||||||
meta_data={
|
meta_data={
|
||||||
"model": api_key_obj.model_name,
|
"model": api_key_obj.model_name,
|
||||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
|
"usage": {}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -427,11 +416,7 @@ class AppChatService:
|
|||||||
meta_data={
|
meta_data={
|
||||||
"mode": result.get("mode"),
|
"mode": result.get("mode"),
|
||||||
"elapsed_time": result.get("elapsed_time"),
|
"elapsed_time": result.get("elapsed_time"),
|
||||||
"usage": result.get("usage", {
|
"sub_results": result.get("sub_results")
|
||||||
"prompt_tokens": 0,
|
|
||||||
"completion_tokens": 0,
|
|
||||||
"total_tokens": 0
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -473,7 +458,6 @@ class AppChatService:
|
|||||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
full_content = ""
|
full_content = ""
|
||||||
total_tokens = 0
|
|
||||||
|
|
||||||
# 2. 创建编排器
|
# 2. 创建编排器
|
||||||
orchestrator = MultiAgentOrchestrator(self.db, config)
|
orchestrator = MultiAgentOrchestrator(self.db, config)
|
||||||
@@ -490,26 +474,16 @@ class AppChatService:
|
|||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id
|
||||||
):
|
):
|
||||||
if "sub_usage" in event:
|
yield event
|
||||||
if "data:" in event:
|
# 尝试提取内容(用于保存)
|
||||||
try:
|
if "data:" in event:
|
||||||
data_line = event.split("data: ", 1)[1].strip()
|
try:
|
||||||
data = json.loads(data_line)
|
data_line = event.split("data: ", 1)[1].strip()
|
||||||
if "total_tokens" in data:
|
data = json.loads(data_line)
|
||||||
total_tokens += data["total_tokens"]
|
if "content" in data:
|
||||||
except:
|
full_content += data["content"]
|
||||||
pass
|
except:
|
||||||
else:
|
pass
|
||||||
yield event
|
|
||||||
# 尝试提取内容(用于保存)
|
|
||||||
if "data:" in event:
|
|
||||||
try:
|
|
||||||
data_line = event.split("data: ", 1)[1].strip()
|
|
||||||
data = json.loads(data_line)
|
|
||||||
if "content" in data:
|
|
||||||
full_content += data["content"]
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
@@ -525,12 +499,7 @@ class AppChatService:
|
|||||||
role="assistant",
|
role="assistant",
|
||||||
content=full_content,
|
content=full_content,
|
||||||
meta_data={
|
meta_data={
|
||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": elapsed_time
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"completion_tokens": 0,
|
|
||||||
"total_tokens": total_tokens
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,193 +0,0 @@
|
|||||||
"""应用统计服务"""
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Dict, Any, List
|
|
||||||
import uuid
|
|
||||||
from sqlalchemy import func, and_, cast, Date
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from app.models.conversation_model import Conversation, Message
|
|
||||||
from app.models.end_user_model import EndUser
|
|
||||||
from app.models.api_key_model import ApiKey, ApiKeyLog
|
|
||||||
from app.core.exceptions import BusinessException
|
|
||||||
from app.core.error_codes import BizCode
|
|
||||||
|
|
||||||
|
|
||||||
class AppStatisticsService:
|
|
||||||
"""应用统计服务"""
|
|
||||||
|
|
||||||
def __init__(self, db: Session):
|
|
||||||
self.db = db
|
|
||||||
|
|
||||||
def get_app_statistics(
|
|
||||||
self,
|
|
||||||
app_id: uuid.UUID,
|
|
||||||
workspace_id: uuid.UUID,
|
|
||||||
start_date: int,
|
|
||||||
end_date: int
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""获取应用统计数据
|
|
||||||
|
|
||||||
Args:
|
|
||||||
app_id: 应用ID
|
|
||||||
workspace_id: 工作空间ID
|
|
||||||
start_date: 开始时间戳(毫秒)
|
|
||||||
end_date: 结束时间戳(毫秒)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
统计数据字典
|
|
||||||
"""
|
|
||||||
# 将毫秒时间戳转换为 datetime
|
|
||||||
start_dt = datetime.fromtimestamp(start_date / 1000)
|
|
||||||
end_dt = datetime.fromtimestamp(end_date / 1000) + timedelta(days=1)
|
|
||||||
|
|
||||||
# 1. 会话统计
|
|
||||||
conversations_stats = self._get_conversations_statistics(app_id, workspace_id, start_dt, end_dt)
|
|
||||||
|
|
||||||
# 2. 新增用户统计
|
|
||||||
users_stats = self._get_new_users_statistics(app_id, start_dt, end_dt)
|
|
||||||
|
|
||||||
# 3. API调用统计
|
|
||||||
api_stats = self._get_api_calls_statistics(app_id, start_dt, end_dt)
|
|
||||||
|
|
||||||
# 4. Token消耗统计
|
|
||||||
token_stats = self._get_token_statistics(app_id, start_dt, end_dt)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"daily_conversations": conversations_stats["daily"],
|
|
||||||
"total_conversations": conversations_stats["total"],
|
|
||||||
"daily_new_users": users_stats["daily"],
|
|
||||||
"total_new_users": users_stats["total"],
|
|
||||||
"daily_api_calls": api_stats["daily"],
|
|
||||||
"total_api_calls": api_stats["total"],
|
|
||||||
"daily_tokens": token_stats["daily"],
|
|
||||||
"total_tokens": token_stats["total"]
|
|
||||||
}
|
|
||||||
|
|
||||||
def _get_conversations_statistics(
|
|
||||||
self,
|
|
||||||
app_id: uuid.UUID,
|
|
||||||
workspace_id: uuid.UUID,
|
|
||||||
start_dt: datetime,
|
|
||||||
end_dt: datetime
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""获取会话统计"""
|
|
||||||
# 每日会话数
|
|
||||||
daily_query = self.db.query(
|
|
||||||
cast(Conversation.created_at, Date).label('date'),
|
|
||||||
func.count(Conversation.id).label('count')
|
|
||||||
).filter(
|
|
||||||
and_(
|
|
||||||
Conversation.app_id == app_id,
|
|
||||||
Conversation.workspace_id == workspace_id,
|
|
||||||
Conversation.created_at >= start_dt,
|
|
||||||
Conversation.created_at < end_dt
|
|
||||||
)
|
|
||||||
).group_by(cast(Conversation.created_at, Date)).all()
|
|
||||||
|
|
||||||
daily_data = [{"date": str(row.date), "count": row.count} for row in daily_query]
|
|
||||||
total = sum(row["count"] for row in daily_data)
|
|
||||||
|
|
||||||
return {"daily": daily_data, "total": total}
|
|
||||||
|
|
||||||
def _get_new_users_statistics(
|
|
||||||
self,
|
|
||||||
app_id: uuid.UUID,
|
|
||||||
start_dt: datetime,
|
|
||||||
end_dt: datetime
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""获取新增用户统计"""
|
|
||||||
# 每日新增用户数
|
|
||||||
daily_query = self.db.query(
|
|
||||||
cast(EndUser.created_at, Date).label('date'),
|
|
||||||
func.count(EndUser.id).label('count')
|
|
||||||
).filter(
|
|
||||||
and_(
|
|
||||||
EndUser.app_id == app_id,
|
|
||||||
EndUser.created_at >= start_dt,
|
|
||||||
EndUser.created_at < end_dt
|
|
||||||
)
|
|
||||||
).group_by(cast(EndUser.created_at, Date)).all()
|
|
||||||
|
|
||||||
daily_data = [{"date": str(row.date), "count": row.count} for row in daily_query]
|
|
||||||
total = sum(row["count"] for row in daily_data)
|
|
||||||
|
|
||||||
return {"daily": daily_data, "total": total}
|
|
||||||
|
|
||||||
def _get_api_calls_statistics(
|
|
||||||
self,
|
|
||||||
app_id: uuid.UUID,
|
|
||||||
start_dt: datetime,
|
|
||||||
end_dt: datetime
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""获取API调用统计"""
|
|
||||||
# 每日API调用次数
|
|
||||||
daily_query = self.db.query(
|
|
||||||
cast(ApiKeyLog.created_at, Date).label('date'),
|
|
||||||
func.count(ApiKeyLog.id).label('count')
|
|
||||||
).join(
|
|
||||||
ApiKey, ApiKeyLog.api_key_id == ApiKey.id
|
|
||||||
).filter(
|
|
||||||
and_(
|
|
||||||
ApiKey.resource_id == app_id,
|
|
||||||
ApiKeyLog.created_at >= start_dt,
|
|
||||||
ApiKeyLog.created_at < end_dt
|
|
||||||
)
|
|
||||||
).group_by(cast(ApiKeyLog.created_at, Date)).all()
|
|
||||||
|
|
||||||
daily_data = [{"date": str(row.date), "count": row.count} for row in daily_query]
|
|
||||||
total = sum(row["count"] for row in daily_data)
|
|
||||||
|
|
||||||
return {"daily": daily_data, "total": total}
|
|
||||||
|
|
||||||
def _get_token_statistics(
|
|
||||||
self,
|
|
||||||
app_id: uuid.UUID,
|
|
||||||
start_dt: datetime,
|
|
||||||
end_dt: datetime
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""获取Token消耗统计(从Message的meta_data中提取)"""
|
|
||||||
from sqlalchemy import text
|
|
||||||
|
|
||||||
# 查询所有相关消息的token使用情况
|
|
||||||
# meta_data中可能包含: {"usage": {"total_tokens": 100}} 或 {"tokens": 100}
|
|
||||||
daily_query = self.db.query(
|
|
||||||
cast(Message.created_at, Date).label('date'),
|
|
||||||
Message.meta_data
|
|
||||||
).join(
|
|
||||||
Conversation, Message.conversation_id == Conversation.id
|
|
||||||
).filter(
|
|
||||||
and_(
|
|
||||||
Conversation.app_id == app_id,
|
|
||||||
Message.created_at >= start_dt,
|
|
||||||
Message.created_at < end_dt,
|
|
||||||
Message.meta_data.isnot(None)
|
|
||||||
)
|
|
||||||
).all()
|
|
||||||
|
|
||||||
# 按日期聚合token
|
|
||||||
daily_tokens = {}
|
|
||||||
for row in daily_query:
|
|
||||||
date_str = str(row.date)
|
|
||||||
meta = row.meta_data or {}
|
|
||||||
|
|
||||||
# 提取token数量(支持多种格式)
|
|
||||||
tokens = 0
|
|
||||||
if isinstance(meta, dict):
|
|
||||||
# 格式1: {"usage": {"total_tokens": 100}}
|
|
||||||
if "usage" in meta and isinstance(meta["usage"], dict):
|
|
||||||
tokens = meta["usage"].get("total_tokens", 0)
|
|
||||||
# 格式2: {"tokens": 100}
|
|
||||||
elif "tokens" in meta:
|
|
||||||
tokens = meta.get("tokens", 0)
|
|
||||||
# 格式3: {"total_tokens": 100}
|
|
||||||
elif "total_tokens" in meta:
|
|
||||||
tokens = meta.get("total_tokens", 0)
|
|
||||||
|
|
||||||
if date_str not in daily_tokens:
|
|
||||||
daily_tokens[date_str] = 0
|
|
||||||
daily_tokens[date_str] += int(tokens)
|
|
||||||
|
|
||||||
daily_data = [{"date": date, "count": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
|
|
||||||
total = sum(row["count"] for row in daily_data)
|
|
||||||
|
|
||||||
return {"daily": daily_data, "total": total}
|
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
"""会话服务"""
|
"""会话服务"""
|
||||||
import os
|
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
@@ -299,8 +298,7 @@ class ConversationService:
|
|||||||
self,
|
self,
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
user_message: str,
|
user_message: str,
|
||||||
assistant_message: str,
|
assistant_message: str
|
||||||
meta_data: Optional[dict] = None
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Save a pair of user and assistant messages to the conversation.
|
Save a pair of user and assistant messages to the conversation.
|
||||||
@@ -309,7 +307,6 @@ class ConversationService:
|
|||||||
conversation_id (uuid.UUID): Conversation UUID.
|
conversation_id (uuid.UUID): Conversation UUID.
|
||||||
user_message (str): User's message content.
|
user_message (str): User's message content.
|
||||||
assistant_message (str): Assistant's response content.
|
assistant_message (str): Assistant's response content.
|
||||||
meta_data (Optional[dict]): Optional metadata for the messages.
|
|
||||||
"""
|
"""
|
||||||
self.add_message(
|
self.add_message(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
@@ -320,8 +317,7 @@ class ConversationService:
|
|||||||
self.add_message(
|
self.add_message(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=assistant_message,
|
content=assistant_message
|
||||||
meta_data=meta_data
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -530,12 +526,12 @@ class ConversationService:
|
|||||||
takeaways=[],
|
takeaways=[],
|
||||||
info_score=0,
|
info_score=0,
|
||||||
)
|
)
|
||||||
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
|
||||||
with open(os.path.join(prompt_path, 'conversation_summary_system.jinja2'), 'r', encoding='utf-8') as f:
|
with open('app/services/prompt/conversation_summary_system.jinja2', 'r', encoding='utf-8') as f:
|
||||||
system_prompt = f.read()
|
system_prompt = f.read()
|
||||||
rendered_system_message = Template(system_prompt).render()
|
rendered_system_message = Template(system_prompt).render()
|
||||||
|
|
||||||
with open(os.path.join(prompt_path, 'conversation_summary_user.jinja2'), 'r', encoding='utf-8') as f:
|
with open('app/services/prompt/conversation_summary_user.jinja2', 'r', encoding='utf-8') as f:
|
||||||
user_prompt = f.read()
|
user_prompt = f.read()
|
||||||
rendered_user_message = Template(user_prompt).render(
|
rendered_user_message = Template(user_prompt).render(
|
||||||
language=language,
|
language=language,
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ from app.core.exceptions import BusinessException
|
|||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.rag.nlp.search import knowledge_retrieval
|
from app.core.rag.nlp.search import knowledge_retrieval
|
||||||
from app.models import AgentConfig, ModelApiKey, ModelConfig
|
from app.models import AgentConfig, ModelApiKey, ModelConfig
|
||||||
from app.repositories.model_repository import ModelApiKeyRepository
|
|
||||||
from app.repositories.tool_repository import ToolRepository
|
from app.repositories.tool_repository import ToolRepository
|
||||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||||
from app.services import task_service
|
from app.services import task_service
|
||||||
@@ -110,8 +109,6 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
|||||||
result = task_service.get_task_memory_read_result(task.id)
|
result = task_service.get_task_memory_read_result(task.id)
|
||||||
status = result.get("status")
|
status = result.get("status")
|
||||||
logger.info(f"读取任务状态:{status}")
|
logger.info(f"读取任务状态:{status}")
|
||||||
if memory_content:
|
|
||||||
memory_content = memory_content['answer']
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
@@ -125,6 +122,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
|||||||
"content_length": len(str(memory_content))
|
"content_length": len(str(memory_content))
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return f"检索到以下历史记忆:\n\n{memory_content}"
|
return f"检索到以下历史记忆:\n\n{memory_content}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})
|
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})
|
||||||
@@ -443,14 +441,7 @@ class DraftRunService:
|
|||||||
user_message=message,
|
user_message=message,
|
||||||
assistant_message=result["content"],
|
assistant_message=result["content"],
|
||||||
app_id=agent_config.app_id,
|
app_id=agent_config.app_id,
|
||||||
user_id=user_id,
|
user_id=user_id
|
||||||
meta_data={
|
|
||||||
"usage": result.get("usage", {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"completion_tokens": 0,
|
|
||||||
"total_tokens": 0
|
|
||||||
})
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
response = {
|
response = {
|
||||||
@@ -657,7 +648,6 @@ class DraftRunService:
|
|||||||
|
|
||||||
# 9. 流式调用 Agent
|
# 9. 流式调用 Agent
|
||||||
full_content = ""
|
full_content = ""
|
||||||
total_tokens = 0
|
|
||||||
async for chunk in agent.chat_stream(
|
async for chunk in agent.chat_stream(
|
||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
@@ -668,22 +658,14 @@ class DraftRunService:
|
|||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
memory_flag=memory_flag
|
memory_flag=memory_flag
|
||||||
):
|
):
|
||||||
if isinstance(chunk, int):
|
full_content += chunk
|
||||||
total_tokens = chunk
|
# 发送消息块事件
|
||||||
else:
|
yield self._format_sse_event("message", {
|
||||||
full_content += chunk
|
"content": chunk
|
||||||
# 发送消息块事件
|
})
|
||||||
yield self._format_sse_event("message", {
|
|
||||||
"content": chunk
|
|
||||||
})
|
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
if sub_agent:
|
|
||||||
yield self._format_sse_event("sub_usage", {
|
|
||||||
"total_tokens": total_tokens
|
|
||||||
})
|
|
||||||
|
|
||||||
# 10. 保存会话消息
|
# 10. 保存会话消息
|
||||||
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
|
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
|
||||||
await self._save_conversation_message(
|
await self._save_conversation_message(
|
||||||
@@ -691,10 +673,7 @@ class DraftRunService:
|
|||||||
user_message=message,
|
user_message=message,
|
||||||
assistant_message=full_content,
|
assistant_message=full_content,
|
||||||
app_id=agent_config.app_id,
|
app_id=agent_config.app_id,
|
||||||
user_id=user_id,
|
user_id=user_id
|
||||||
meta_data={
|
|
||||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 11. 发送结束事件
|
# 11. 发送结束事件
|
||||||
@@ -745,21 +724,17 @@ class DraftRunService:
|
|||||||
Raises:
|
Raises:
|
||||||
BusinessException: 当没有可用的 API Key 时
|
BusinessException: 当没有可用的 API Key 时
|
||||||
"""
|
"""
|
||||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
|
stmt = (
|
||||||
# stmt = (
|
select(ModelApiKey)
|
||||||
# select(ModelApiKey).join(
|
.where(
|
||||||
# ModelConfig, ModelApiKey.model_configs
|
ModelApiKey.model_config_id == model_config_id,
|
||||||
# )
|
ModelApiKey.is_active.is_(True)
|
||||||
# .where(
|
)
|
||||||
# ModelConfig.id == model_config_id,
|
.order_by(ModelApiKey.priority.desc())
|
||||||
# ModelApiKey.is_active.is_(True)
|
.limit(1)
|
||||||
# )
|
)
|
||||||
# .order_by(ModelApiKey.priority.desc())
|
|
||||||
# .limit(1)
|
api_key = self.db.scalars(stmt).first()
|
||||||
# )
|
|
||||||
#
|
|
||||||
# api_key = self.db.scalars(stmt).first()
|
|
||||||
api_key = api_keys[0] if api_keys else None
|
|
||||||
|
|
||||||
if not api_key:
|
if not api_key:
|
||||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||||
@@ -918,7 +893,6 @@ class DraftRunService:
|
|||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
user_message: str,
|
user_message: str,
|
||||||
assistant_message: str,
|
assistant_message: str,
|
||||||
meta_data: dict,
|
|
||||||
app_id: Optional[uuid.UUID] = None,
|
app_id: Optional[uuid.UUID] = None,
|
||||||
user_id: Optional[str] = None
|
user_id: Optional[str] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -930,7 +904,6 @@ class DraftRunService:
|
|||||||
assistant_message: AI 回复消息
|
assistant_message: AI 回复消息
|
||||||
app_id: 应用ID(未使用,保留用于兼容性)
|
app_id: 应用ID(未使用,保留用于兼容性)
|
||||||
user_id: 用户ID(未使用,保留用于兼容性)
|
user_id: 用户ID(未使用,保留用于兼容性)
|
||||||
meta_data: token消耗
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
@@ -949,8 +922,7 @@ class DraftRunService:
|
|||||||
conversation_service.add_message(
|
conversation_service.add_message(
|
||||||
conversation_id=conv_uuid,
|
conversation_id=conv_uuid,
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=assistant_message,
|
content=assistant_message
|
||||||
meta_data=meta_data
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|||||||
@@ -17,15 +17,12 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.utils.config_utils import resolve_config_id
|
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
class EmotionSuggestion(BaseModel):
|
class EmotionSuggestion(BaseModel):
|
||||||
"""情绪建议模型"""
|
"""情绪建议模型"""
|
||||||
type: str = Field(...,
|
type: str = Field(..., description="建议类型:emotion_balance/activity_recommendation/social_connection/stress_management")
|
||||||
description="建议类型:emotion_balance/activity_recommendation/social_connection/stress_management")
|
|
||||||
title: str = Field(..., description="建议标题")
|
title: str = Field(..., description="建议标题")
|
||||||
content: str = Field(..., description="建议内容")
|
content: str = Field(..., description="建议内容")
|
||||||
priority: str = Field(..., description="优先级:high/medium/low")
|
priority: str = Field(..., description="优先级:high/medium/low")
|
||||||
@@ -58,12 +55,12 @@ class EmotionAnalyticsService:
|
|||||||
logger.info("情绪分析服务初始化完成")
|
logger.info("情绪分析服务初始化完成")
|
||||||
|
|
||||||
async def get_emotion_tags(
|
async def get_emotion_tags(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
emotion_type: Optional[str] = None,
|
emotion_type: Optional[str] = None,
|
||||||
start_date: Optional[str] = None,
|
start_date: Optional[str] = None,
|
||||||
end_date: Optional[str] = None,
|
end_date: Optional[str] = None,
|
||||||
limit: int = 10
|
limit: int = 10
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""获取情绪标签统计
|
"""获取情绪标签统计
|
||||||
|
|
||||||
@@ -74,7 +71,7 @@ class EmotionAnalyticsService:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info(f"获取情绪标签统计: user={end_user_id}, type={emotion_type}, "
|
logger.info(f"获取情绪标签统计: user={end_user_id}, type={emotion_type}, "
|
||||||
f"start={start_date}, end={end_date}, limit={limit}")
|
f"start={start_date}, end={end_date}, limit={limit}")
|
||||||
|
|
||||||
# 调用仓储层查询
|
# 调用仓储层查询
|
||||||
tags = await self.emotion_repo.get_emotion_tags(
|
tags = await self.emotion_repo.get_emotion_tags(
|
||||||
@@ -136,10 +133,10 @@ class EmotionAnalyticsService:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_emotion_wordcloud(
|
async def get_emotion_wordcloud(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
emotion_type: Optional[str] = None,
|
emotion_type: Optional[str] = None,
|
||||||
limit: int = 50
|
limit: int = 50
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""获取情绪词云数据
|
"""获取情绪词云数据
|
||||||
|
|
||||||
@@ -214,7 +211,7 @@ class EmotionAnalyticsService:
|
|||||||
score = 50.0 # 如果没有非中性情绪,默认为50
|
score = 50.0 # 如果没有非中性情绪,默认为50
|
||||||
|
|
||||||
logger.debug(f"积极率计算: positive={positive_count}, negative={negative_count}, "
|
logger.debug(f"积极率计算: positive={positive_count}, negative={negative_count}, "
|
||||||
f"neutral={neutral_count}, score={score:.2f}")
|
f"neutral={neutral_count}, score={score:.2f}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"score": round(score, 2),
|
"score": round(score, 2),
|
||||||
@@ -253,7 +250,7 @@ class EmotionAnalyticsService:
|
|||||||
score = (1 - min(std_deviation, 1.0)) * 100
|
score = (1 - min(std_deviation, 1.0)) * 100
|
||||||
|
|
||||||
logger.debug(f"稳定性计算: intensities_count={len(intensities)}, "
|
logger.debug(f"稳定性计算: intensities_count={len(intensities)}, "
|
||||||
f"std_deviation={std_deviation:.3f}, score={score:.2f}")
|
f"std_deviation={std_deviation:.3f}, score={score:.2f}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"score": round(score, 2),
|
"score": round(score, 2),
|
||||||
@@ -306,7 +303,7 @@ class EmotionAnalyticsService:
|
|||||||
score = 100.0
|
score = 100.0
|
||||||
|
|
||||||
logger.debug(f"恢复力计算: negative_count={negative_count}, "
|
logger.debug(f"恢复力计算: negative_count={negative_count}, "
|
||||||
f"recovery_count={recovery_count}, score={score:.2f}")
|
f"recovery_count={recovery_count}, score={score:.2f}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"score": round(score, 2),
|
"score": round(score, 2),
|
||||||
@@ -314,9 +311,9 @@ class EmotionAnalyticsService:
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def calculate_emotion_health_index(
|
async def calculate_emotion_health_index(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
time_range: str = "30d"
|
time_range: str = "30d"
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""计算情绪健康指数
|
"""计算情绪健康指数
|
||||||
|
|
||||||
@@ -369,9 +366,9 @@ class EmotionAnalyticsService:
|
|||||||
# 计算综合健康分数
|
# 计算综合健康分数
|
||||||
# 公式:positivity_rate * 0.4 + stability * 0.3 + resilience * 0.3
|
# 公式:positivity_rate * 0.4 + stability * 0.3 + resilience * 0.3
|
||||||
health_score = (
|
health_score = (
|
||||||
positivity_rate["score"] * 0.4 +
|
positivity_rate["score"] * 0.4 +
|
||||||
stability["score"] * 0.3 +
|
stability["score"] * 0.3 +
|
||||||
resilience["score"] * 0.3
|
resilience["score"] * 0.3
|
||||||
)
|
)
|
||||||
|
|
||||||
# 确定健康等级
|
# 确定健康等级
|
||||||
@@ -463,7 +460,7 @@ class EmotionAnalyticsService:
|
|||||||
volatility = "未知"
|
volatility = "未知"
|
||||||
|
|
||||||
logger.debug(f"情绪模式分析: dominant_negative={dominant_negative_emotion}, "
|
logger.debug(f"情绪模式分析: dominant_negative={dominant_negative_emotion}, "
|
||||||
f"high_intensity_count={len(high_intensity_emotions)}, volatility={volatility}")
|
f"high_intensity_count={len(high_intensity_emotions)}, volatility={volatility}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"dominant_negative_emotion": dominant_negative_emotion,
|
"dominant_negative_emotion": dominant_negative_emotion,
|
||||||
@@ -472,9 +469,9 @@ class EmotionAnalyticsService:
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def generate_emotion_suggestions(
|
async def generate_emotion_suggestions(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
db: Session,
|
db: Session,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""生成个性化情绪建议
|
"""生成个性化情绪建议
|
||||||
|
|
||||||
@@ -501,7 +498,7 @@ class EmotionAnalyticsService:
|
|||||||
|
|
||||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
config_id = connected_config.get("memory_config_id")
|
config_id = connected_config.get("memory_config_id")
|
||||||
config_id = resolve_config_id(config_id, db)
|
|
||||||
if config_id is not None:
|
if config_id is not None:
|
||||||
from app.services.memory_config_service import (
|
from app.services.memory_config_service import (
|
||||||
MemoryConfigService,
|
MemoryConfigService,
|
||||||
@@ -621,10 +618,10 @@ class EmotionAnalyticsService:
|
|||||||
return {"interests": ["未知"]}
|
return {"interests": ["未知"]}
|
||||||
|
|
||||||
async def _build_suggestion_prompt(
|
async def _build_suggestion_prompt(
|
||||||
self,
|
self,
|
||||||
health_data: Dict[str, Any],
|
health_data: Dict[str, Any],
|
||||||
patterns: Dict[str, Any],
|
patterns: Dict[str, Any],
|
||||||
user_profile: Dict[str, Any]
|
user_profile: Dict[str, Any]
|
||||||
) -> str:
|
) -> str:
|
||||||
"""构建情绪建议生成的prompt
|
"""构建情绪建议生成的prompt
|
||||||
|
|
||||||
@@ -710,9 +707,9 @@ class EmotionAnalyticsService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def get_cached_suggestions(
|
async def get_cached_suggestions(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
db: Session,
|
db: Session,
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
"""从 Redis 缓存获取个性化情绪建议
|
"""从 Redis 缓存获取个性化情绪建议
|
||||||
|
|
||||||
@@ -743,11 +740,11 @@ class EmotionAnalyticsService:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def save_suggestions_cache(
|
async def save_suggestions_cache(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
suggestions_data: Dict[str, Any],
|
suggestions_data: Dict[str, Any],
|
||||||
db: Session,
|
db: Session,
|
||||||
expires_hours: int = 24
|
expires_hours: int = 24
|
||||||
) -> None:
|
) -> None:
|
||||||
"""保存建议到 Redis 缓存
|
"""保存建议到 Redis 缓存
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import uuid
|
|||||||
from typing import List, Dict, Any, Optional, AsyncGenerator, Annotated
|
from typing import List, Dict, Any, Optional, AsyncGenerator, Annotated
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, AIMessageChunk
|
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
|
||||||
from langgraph.graph import StateGraph, START, END
|
from langgraph.graph import StateGraph, START, END
|
||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
@@ -727,12 +727,9 @@ class HandoffsService:
|
|||||||
|
|
||||||
# 提取响应
|
# 提取响应
|
||||||
response_content = ""
|
response_content = ""
|
||||||
total_tokens = 0
|
|
||||||
for msg in result.get("messages", []):
|
for msg in result.get("messages", []):
|
||||||
if isinstance(msg, AIMessage):
|
if isinstance(msg, AIMessage):
|
||||||
response_content = msg.content
|
response_content = msg.content
|
||||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
|
||||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
|
|
||||||
break
|
break
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -740,12 +737,7 @@ class HandoffsService:
|
|||||||
"active_agent": result.get("active_agent"),
|
"active_agent": result.get("active_agent"),
|
||||||
"response": response_content,
|
"response": response_content,
|
||||||
"message_count": len(result.get("messages", [])),
|
"message_count": len(result.get("messages", [])),
|
||||||
"handoff_count": result.get("handoff_count", 0),
|
"handoff_count": result.get("handoff_count", 0)
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"completion_tokens": 0,
|
|
||||||
"total_tokens": total_tokens
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async def chat_stream(
|
async def chat_stream(
|
||||||
@@ -838,12 +830,6 @@ class HandoffsService:
|
|||||||
|
|
||||||
# 捕获 LLM 结束事件,输出收集到的工具调用
|
# 捕获 LLM 结束事件,输出收集到的工具调用
|
||||||
elif kind == "on_chat_model_end":
|
elif kind == "on_chat_model_end":
|
||||||
output_message = event.get("data", {}).get("output", {})
|
|
||||||
if isinstance(output_message, AIMessageChunk):
|
|
||||||
response_meta = output_message.response_metadata if hasattr(output_message, 'response_metadata') else None
|
|
||||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
|
|
||||||
0) if response_meta else 0
|
|
||||||
yield f"event: sub_usage\ndata: {json.dumps({"total_tokens": total_tokens}, ensure_ascii=False)}\n\n"
|
|
||||||
if collected_tool_calls:
|
if collected_tool_calls:
|
||||||
# 找到参数最完整的 transfer 工具调用
|
# 找到参数最完整的 transfer 工具调用
|
||||||
best_tc = None
|
best_tc = None
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import uuid
|
|||||||
from typing import Dict, Any, List, Optional, Tuple
|
from typing import Dict, Any, List, Optional, Tuple
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.repositories.model_repository import ModelApiKeyRepository
|
|
||||||
from app.services.conversation_state_manager import ConversationStateManager
|
from app.services.conversation_state_manager import ConversationStateManager
|
||||||
from app.models import ModelConfig, AgentConfig
|
from app.models import ModelConfig, AgentConfig
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
@@ -383,14 +382,11 @@ class LLMRouter:
|
|||||||
from app.core.models.base import RedBearModelConfig
|
from app.core.models.base import RedBearModelConfig
|
||||||
from app.models import ModelApiKey, ModelType
|
from app.models import ModelApiKey, ModelType
|
||||||
|
|
||||||
# 获取 API Key 配置(通过关联关系)
|
# 获取 API Key 配置
|
||||||
# api_key_config = self.db.query(ModelApiKey).join(
|
api_key_config = self.db.query(ModelApiKey).filter(
|
||||||
# ModelConfig, ModelApiKey.model_configs
|
ModelApiKey.model_config_id == self.routing_model_config.id,
|
||||||
# ).filter(ModelConfig.id == self.routing_model_config.id,
|
ModelApiKey.is_active
|
||||||
# ModelApiKey.is_active == True
|
).first()
|
||||||
# ).first()
|
|
||||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, self.routing_model_config.id)
|
|
||||||
api_key_config = api_keys[0] if api_keys else None
|
|
||||||
|
|
||||||
if not api_key_config:
|
if not api_key_config:
|
||||||
raise Exception("路由模型没有可用的 API Key")
|
raise Exception("路由模型没有可用的 API Key")
|
||||||
@@ -424,9 +420,6 @@ class LLMRouter:
|
|||||||
# 调用模型
|
# 调用模型
|
||||||
response = await llm.ainvoke(prompt)
|
response = await llm.ainvoke(prompt)
|
||||||
|
|
||||||
from app.services.model_service import ModelApiKeyService
|
|
||||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
|
|
||||||
|
|
||||||
# 提取响应内容
|
# 提取响应内容
|
||||||
if hasattr(response, 'content'):
|
if hasattr(response, 'content'):
|
||||||
return response.content
|
return response.content
|
||||||
|
|||||||
@@ -334,9 +334,7 @@ class MemoryAgentService:
|
|||||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
langchain_messages.append(HumanMessage(content=msg['content']))
|
||||||
elif msg['role'] == 'assistant':
|
elif msg['role'] == 'assistant':
|
||||||
langchain_messages.append(AIMessage(content=msg['content']))
|
langchain_messages.append(AIMessage(content=msg['content']))
|
||||||
print(100*'-')
|
|
||||||
print(langchain_messages)
|
|
||||||
print(100*'-')
|
|
||||||
# 初始状态 - 包含所有必要字段
|
# 初始状态 - 包含所有必要字段
|
||||||
initial_state = {
|
initial_state = {
|
||||||
"messages": langchain_messages,
|
"messages": langchain_messages,
|
||||||
|
|||||||
@@ -338,7 +338,7 @@ class MemoryConfigService:
|
|||||||
"provider": api_config.provider,
|
"provider": api_config.provider,
|
||||||
"api_key": api_config.api_key,
|
"api_key": api_config.api_key,
|
||||||
"base_url": api_config.api_base,
|
"base_url": api_config.api_base,
|
||||||
"model_config_id": str(config.id),
|
"model_config_id": api_config.model_config_id,
|
||||||
"type": config.type,
|
"type": config.type,
|
||||||
"timeout": settings.LLM_TIMEOUT,
|
"timeout": settings.LLM_TIMEOUT,
|
||||||
"max_retries": settings.LLM_MAX_RETRIES,
|
"max_retries": settings.LLM_MAX_RETRIES,
|
||||||
@@ -370,7 +370,7 @@ class MemoryConfigService:
|
|||||||
"provider": api_config.provider,
|
"provider": api_config.provider,
|
||||||
"api_key": api_config.api_key,
|
"api_key": api_config.api_key,
|
||||||
"base_url": api_config.api_base,
|
"base_url": api_config.api_base,
|
||||||
"model_config_id": str(config.id),
|
"model_config_id": api_config.model_config_id,
|
||||||
"type": config.type,
|
"type": config.type,
|
||||||
"timeout": 120.0,
|
"timeout": 120.0,
|
||||||
"max_retries": 5,
|
"max_retries": 5,
|
||||||
|
|||||||
@@ -53,10 +53,7 @@ def get_workspace_end_users(
|
|||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
current_user: User
|
current_user: User
|
||||||
) -> List[EndUser]:
|
) -> List[EndUser]:
|
||||||
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)
|
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)"""
|
||||||
|
|
||||||
返回结果按 updated_at 从新到旧排序(NULL 值排在最后)
|
|
||||||
"""
|
|
||||||
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -71,14 +68,9 @@ def get_workspace_end_users(
|
|||||||
app_ids = [app.id for app in apps_orm]
|
app_ids = [app.id for app in apps_orm]
|
||||||
|
|
||||||
# 批量查询所有 end_users(一次查询而非循环查询)
|
# 批量查询所有 end_users(一次查询而非循环查询)
|
||||||
# 按 updated_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性
|
|
||||||
from app.models.end_user_model import EndUser as EndUserModel
|
from app.models.end_user_model import EndUser as EndUserModel
|
||||||
from sqlalchemy import desc, nullslast
|
|
||||||
end_users_orm = db.query(EndUserModel).filter(
|
end_users_orm = db.query(EndUserModel).filter(
|
||||||
EndUserModel.app_id.in_(app_ids)
|
EndUserModel.app_id.in_(app_ids)
|
||||||
).order_by(
|
|
||||||
nullslast(desc(EndUserModel.updated_at)),
|
|
||||||
desc(EndUserModel.id)
|
|
||||||
).all()
|
).all()
|
||||||
|
|
||||||
# 转换为 Pydantic 模型(只在需要时转换)
|
# 转换为 Pydantic 模型(只在需要时转换)
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
|||||||
from app.models.app_model import App
|
from app.models.app_model import App
|
||||||
from app.models.app_release_model import AppRelease
|
from app.models.app_release_model import AppRelease
|
||||||
from app.models.end_user_model import EndUser
|
from app.models.end_user_model import EndUser
|
||||||
from app.utils.config_utils import resolve_config_id
|
|
||||||
|
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
@@ -89,6 +88,8 @@ class WorkspaceAppService:
|
|||||||
|
|
||||||
for release in app_releases:
|
for release in app_releases:
|
||||||
memory_content = self._extract_memory_content(release.config)
|
memory_content = self._extract_memory_content(release.config)
|
||||||
|
|
||||||
|
|
||||||
if memory_content and memory_content in processed_configs:
|
if memory_content and memory_content in processed_configs:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -101,6 +102,7 @@ class WorkspaceAppService:
|
|||||||
if memory_content:
|
if memory_content:
|
||||||
processed_configs.add(memory_content)
|
processed_configs.add(memory_content)
|
||||||
memory_config_info = self._get_memory_config(memory_content)
|
memory_config_info = self._get_memory_config(memory_content)
|
||||||
|
|
||||||
if memory_config_info:
|
if memory_config_info:
|
||||||
if not any(dc["config_id"] == memory_config_info["config_id"] for dc in app_info["memory_configs"]):
|
if not any(dc["config_id"] == memory_config_info["config_id"] for dc in app_info["memory_configs"]):
|
||||||
app_info["memory_configs"].append(memory_config_info)
|
app_info["memory_configs"].append(memory_config_info)
|
||||||
@@ -121,12 +123,16 @@ class WorkspaceAppService:
|
|||||||
def _get_memory_config(self, memory_content: str) -> Dict[str, Any]:
|
def _get_memory_config(self, memory_content: str) -> Dict[str, Any]:
|
||||||
"""Retrieve memory_config information based on memory_content"""
|
"""Retrieve memory_config information based on memory_content"""
|
||||||
try:
|
try:
|
||||||
memory_content = resolve_config_id(memory_content, self.db)
|
memory_config_result = MemoryConfigRepository.query_reflection_config_by_id(self.db, int(memory_content))
|
||||||
memory_config_result = MemoryConfigRepository.query_reflection_config_by_id(self.db, (memory_content))
|
|
||||||
|
# memory_config_query, memory_config_params = MemoryConfigRepository.build_select_reflection(memory_content)
|
||||||
|
# memory_config_result = self.db.execute(text(memory_config_query), memory_config_params).fetchone()
|
||||||
|
# if memory_config_result is None:
|
||||||
|
# return None
|
||||||
|
|
||||||
if memory_config_result:
|
if memory_config_result:
|
||||||
return {
|
return {
|
||||||
"config_id": memory_content,
|
"config_id": memory_config_result.config_id,
|
||||||
"enable_self_reflexion": memory_config_result.enable_self_reflexion,
|
"enable_self_reflexion": memory_config_result.enable_self_reflexion,
|
||||||
"iteration_period": memory_config_result.iteration_period,
|
"iteration_period": memory_config_result.iteration_period,
|
||||||
"reflexion_range": memory_config_result.reflexion_range,
|
"reflexion_range": memory_config_result.reflexion_range,
|
||||||
@@ -151,8 +157,6 @@ class WorkspaceAppService:
|
|||||||
"app_id": str(end_user.app_id)
|
"app_id": str(end_user.app_id)
|
||||||
}
|
}
|
||||||
app_info["end_users"].append(end_user_info)
|
app_info["end_users"].append(end_user_info)
|
||||||
print(100*'-')
|
|
||||||
print(app_info)
|
|
||||||
|
|
||||||
def get_end_user_reflection_time(self, end_user_id: str) -> Optional[Any]:
|
def get_end_user_reflection_time(self, end_user_id: str) -> Optional[Any]:
|
||||||
"""
|
"""
|
||||||
@@ -377,6 +381,7 @@ class MemoryReflectionService:
|
|||||||
iteration_period = int(iteration_period)
|
iteration_period = int(iteration_period)
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
iteration_period = 24 # 默认24小时
|
iteration_period = 24 # 默认24小时
|
||||||
|
|
||||||
return ReflectionConfig(
|
return ReflectionConfig(
|
||||||
enabled=config_data.get("enable_self_reflexion", False),
|
enabled=config_data.get("enable_self_reflexion", False),
|
||||||
iteration_period=str(iteration_period), # ReflectionConfig期望字符串
|
iteration_period=str(iteration_period), # ReflectionConfig期望字符串
|
||||||
|
|||||||
@@ -129,12 +129,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
if not params.rerank_id:
|
if not params.rerank_id:
|
||||||
params.rerank_id = configs.get('rerank')
|
params.rerank_id = configs.get('rerank')
|
||||||
|
|
||||||
# reflection_model_id 和 emotion_model_id 默认与 llm_id 一致
|
|
||||||
if not params.reflection_model_id:
|
|
||||||
params.reflection_model_id = params.llm_id
|
|
||||||
if not params.emotion_model_id:
|
|
||||||
params.emotion_model_id = params.llm_id
|
|
||||||
|
|
||||||
config = MemoryConfigRepository.create(self.db, params)
|
config = MemoryConfigRepository.create(self.db, params)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
return {"affected": 1, "config_id": config.config_id}
|
return {"affected": 1, "config_id": config.config_id}
|
||||||
@@ -183,11 +177,11 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
|
|
||||||
# --- Read All ---
|
# --- Read All ---
|
||||||
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
|
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
|
||||||
results = MemoryConfigRepository.get_all(self.db, workspace_id)
|
configs = MemoryConfigRepository.get_all(self.db, workspace_id)
|
||||||
|
|
||||||
# 将 ORM 对象转换为字典列表
|
# 将 ORM 对象转换为字典列表
|
||||||
data_list = []
|
data_list = []
|
||||||
for config, scene_name in results:
|
for config in configs:
|
||||||
# 安全地转换 user_id 为 int
|
# 安全地转换 user_id 为 int
|
||||||
config_id_old = None
|
config_id_old = None
|
||||||
if config.config_id_old:
|
if config.config_id_old:
|
||||||
@@ -209,8 +203,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
|||||||
"end_user_id": config.end_user_id,
|
"end_user_id": config.end_user_id,
|
||||||
"config_id_old": config_id_old,
|
"config_id_old": config_id_old,
|
||||||
"apply_id": config.apply_id,
|
"apply_id": config.apply_id,
|
||||||
"scene_id": str(config.scene_id) if config.scene_id else None,
|
|
||||||
"scene_name": scene_name, # 新增:场景名称
|
|
||||||
"llm_id": config.llm_id,
|
"llm_id": config.llm_id,
|
||||||
"embedding_id": config.embedding_id,
|
"embedding_id": config.embedding_id,
|
||||||
"rerank_id": config.rerank_id,
|
"rerank_id": config.rerank_id,
|
||||||
@@ -636,9 +628,10 @@ async def analytics_recent_activity_stats() -> Dict[str, Any]:
|
|||||||
if m < 1:
|
if m < 1:
|
||||||
latest_relative = "刚刚"
|
latest_relative = "刚刚"
|
||||||
elif m < 60:
|
elif m < 60:
|
||||||
latest_relative = "一会前"
|
latest_relative = f"{m}分钟前"
|
||||||
else:
|
else:
|
||||||
latest_relative = "较早前"
|
h = int(m // 60)
|
||||||
|
latest_relative = f"{h}小时前" if h < 24 else f"{int(h // 24)}天前"
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from datetime import datetime
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from typing import List, Optional, Dict, Any
|
from typing import List, Optional, Dict, Any
|
||||||
import uuid
|
import uuid
|
||||||
@@ -7,11 +6,11 @@ import time
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType
|
from app.models.models_model import ModelConfig, ModelApiKey, ModelType
|
||||||
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository, ModelBaseRepository
|
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository
|
||||||
from app.schemas import model_schema
|
from app.schemas import model_schema
|
||||||
from app.schemas.model_schema import (
|
from app.schemas.model_schema import (
|
||||||
ModelConfigCreate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
ModelConfigCreate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
||||||
ModelConfigQuery, ModelStats, ModelConfigQueryNew
|
ModelConfigQuery, ModelStats
|
||||||
)
|
)
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.schemas.response_schema import PageData, PageMeta
|
from app.schemas.response_schema import PageData, PageMeta
|
||||||
@@ -48,26 +47,6 @@ class ModelConfigService:
|
|||||||
items=[model_schema.ModelConfig.model_validate(model) for model in models]
|
items=[model_schema.ModelConfig.model_validate(model) for model in models]
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_model_list_new(db: Session, query: ModelConfigQueryNew, tenant_id: uuid.UUID | None = None) -> List[dict]:
|
|
||||||
"""获取模型配置列表"""
|
|
||||||
provider_groups, total = ModelConfigRepository.get_list_new(db, query, tenant_id=tenant_id)
|
|
||||||
|
|
||||||
items = []
|
|
||||||
for provider, models in provider_groups.items():
|
|
||||||
# 验证每个模型并封装分组信息
|
|
||||||
validated_models = [model_schema.ModelConfig.model_validate(model) for model in models]
|
|
||||||
tags = list({model.type for model in validated_models})
|
|
||||||
group_item = {
|
|
||||||
"provider": provider, # 服务商名称
|
|
||||||
"logo": validated_models[0].logo,
|
|
||||||
"tags": tags,
|
|
||||||
"models": validated_models # 该服务商下的所有模型
|
|
||||||
}
|
|
||||||
items.append(group_item)
|
|
||||||
|
|
||||||
return items
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_model_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
def get_model_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||||
"""根据名称获取模型配置"""
|
"""根据名称获取模型配置"""
|
||||||
@@ -249,39 +228,37 @@ class ModelConfigService:
|
|||||||
|
|
||||||
# 验证配置
|
# 验证配置
|
||||||
if not model_data.skip_validation and model_data.api_keys:
|
if not model_data.skip_validation and model_data.api_keys:
|
||||||
api_key_data_list = model_data.api_keys
|
api_key_data = model_data.api_keys
|
||||||
for api_key_data in api_key_data_list:
|
validation_result = await ModelConfigService.validate_model_config(
|
||||||
validation_result = await ModelConfigService.validate_model_config(
|
db=db,
|
||||||
db=db,
|
model_name=api_key_data.model_name,
|
||||||
model_name=api_key_data.model_name,
|
provider=api_key_data.provider,
|
||||||
provider=api_key_data.provider,
|
api_key=api_key_data.api_key,
|
||||||
api_key=api_key_data.api_key,
|
api_base=api_key_data.api_base,
|
||||||
api_base=api_key_data.api_base,
|
model_type=model_data.type, # 传递模型类型
|
||||||
model_type=model_data.type, # 传递模型类型
|
test_message="Hello"
|
||||||
test_message="Hello"
|
)
|
||||||
|
if not validation_result["valid"]:
|
||||||
|
raise BusinessException(
|
||||||
|
f"模型配置验证失败: {validation_result['error']}",
|
||||||
|
BizCode.INVALID_PARAMETER
|
||||||
)
|
)
|
||||||
if not validation_result["valid"]:
|
|
||||||
raise BusinessException(
|
|
||||||
f"模型配置验证失败: {validation_result['error']}",
|
|
||||||
BizCode.INVALID_PARAMETER
|
|
||||||
)
|
|
||||||
|
|
||||||
# 事务处理
|
# 事务处理
|
||||||
api_key_datas = model_data.api_keys
|
api_key_data = model_data.api_keys
|
||||||
model_config_data = model_data.model_dump(exclude={"api_keys", "skip_validation"})
|
model_config_data = model_data.dict(exclude={"api_keys", "skip_validation"})
|
||||||
# 添加租户ID
|
# 添加租户ID
|
||||||
model_config_data["tenant_id"] = tenant_id
|
model_config_data["tenant_id"] = tenant_id
|
||||||
|
|
||||||
model = ModelConfigRepository.create(db, model_config_data)
|
model = ModelConfigRepository.create(db, model_config_data)
|
||||||
db.flush() # 获取生成的 ID
|
db.flush() # 获取生成的 ID
|
||||||
|
|
||||||
if api_key_datas:
|
if api_key_data:
|
||||||
for api_key_data in api_key_datas:
|
api_key_create_schema = ModelApiKeyCreate(
|
||||||
api_key_create_schema = ModelApiKeyCreate(
|
model_config_id=model.id,
|
||||||
model_config_ids=[model.id],
|
**api_key_data.dict()
|
||||||
**api_key_data.model_dump()
|
)
|
||||||
)
|
ModelApiKeyRepository.create(db, api_key_create_schema)
|
||||||
ModelApiKeyRepository.create(db, api_key_create_schema)
|
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(model)
|
db.refresh(model)
|
||||||
@@ -303,116 +280,6 @@ class ModelConfigService:
|
|||||||
db.refresh(model)
|
db.refresh(model)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
|
||||||
"""创建组合模型"""
|
|
||||||
if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
|
|
||||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
|
||||||
|
|
||||||
# 验证所有 API Key 存在且类型匹配
|
|
||||||
for api_key_id in model_data.api_key_ids:
|
|
||||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
|
||||||
if not api_key:
|
|
||||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
|
|
||||||
|
|
||||||
# 检查 API Key 关联的模型配置类型
|
|
||||||
for model_config in api_key.model_configs:
|
|
||||||
# chat 和 llm 类型可以兼容
|
|
||||||
compatible_types = {ModelType.LLM, ModelType.CHAT}
|
|
||||||
config_type = model_config.type
|
|
||||||
request_type = model_data.type
|
|
||||||
|
|
||||||
if not (config_type == request_type or
|
|
||||||
(config_type in compatible_types and request_type in compatible_types)):
|
|
||||||
raise BusinessException(
|
|
||||||
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
|
|
||||||
BizCode.INVALID_PARAMETER
|
|
||||||
)
|
|
||||||
# if model_config.is_composite:
|
|
||||||
# raise BusinessException(
|
|
||||||
# f"API Key {api_key_id} 关联的模型是组合模型,不能用于创建新的组合模型",
|
|
||||||
# BizCode.INVALID_PARAMETER
|
|
||||||
# )
|
|
||||||
|
|
||||||
# 创建组合模型
|
|
||||||
model_config_data = {
|
|
||||||
"tenant_id": tenant_id,
|
|
||||||
"name": model_data.name,
|
|
||||||
"type": model_data.type,
|
|
||||||
"logo": model_data.logo,
|
|
||||||
"description": model_data.description,
|
|
||||||
"provider": "composite",
|
|
||||||
"config": model_data.config,
|
|
||||||
"is_active": model_data.is_active,
|
|
||||||
"is_public": model_data.is_public,
|
|
||||||
"is_composite": True
|
|
||||||
}
|
|
||||||
if "load_balance_strategy" in model_data.model_fields_set:
|
|
||||||
model_config_data["load_balance_strategy"] = model_data.load_balance_strategy
|
|
||||||
|
|
||||||
model = ModelConfigRepository.create(db, model_config_data)
|
|
||||||
db.flush()
|
|
||||||
|
|
||||||
# 关联 API Keys
|
|
||||||
for api_key_id in model_data.api_key_ids:
|
|
||||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
|
||||||
if api_key:
|
|
||||||
model.api_keys.append(api_key)
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
db.refresh(model)
|
|
||||||
return model
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
|
||||||
"""更新组合模型"""
|
|
||||||
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
|
||||||
if not existing_model:
|
|
||||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
|
||||||
|
|
||||||
if not existing_model.is_composite:
|
|
||||||
raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER)
|
|
||||||
|
|
||||||
# 验证所有 API Key 存在且类型匹配
|
|
||||||
for api_key_id in model_data.api_key_ids:
|
|
||||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
|
||||||
if not api_key:
|
|
||||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
|
|
||||||
|
|
||||||
for model_config in api_key.model_configs:
|
|
||||||
compatible_types = {ModelType.LLM, ModelType.CHAT}
|
|
||||||
config_type = model_config.type
|
|
||||||
request_type = existing_model.type
|
|
||||||
|
|
||||||
if not (config_type == request_type or
|
|
||||||
(config_type in compatible_types and request_type in compatible_types)):
|
|
||||||
raise BusinessException(
|
|
||||||
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
|
|
||||||
BizCode.INVALID_PARAMETER
|
|
||||||
)
|
|
||||||
|
|
||||||
# 更新基本信息
|
|
||||||
existing_model.name = model_data.name
|
|
||||||
# existing_model.type = model_data.type
|
|
||||||
existing_model.logo = model_data.logo
|
|
||||||
existing_model.description = model_data.description
|
|
||||||
existing_model.config = model_data.config
|
|
||||||
existing_model.is_active = model_data.is_active
|
|
||||||
existing_model.is_public = model_data.is_public
|
|
||||||
if "load_balance_strategy" in model_data.model_fields_set:
|
|
||||||
existing_model.load_balance_strategy = model_data.load_balance_strategy
|
|
||||||
|
|
||||||
# 更新 API Keys 关联
|
|
||||||
existing_model.api_keys.clear()
|
|
||||||
for api_key_id in model_data.api_key_ids:
|
|
||||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
|
||||||
if api_key:
|
|
||||||
existing_model.api_keys.append(api_key)
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
db.refresh(existing_model)
|
|
||||||
return existing_model
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def delete_model(db: Session, model_id: uuid.UUID, tenant_id: uuid.UUID | None = None) -> bool:
|
def delete_model(db: Session, model_id: uuid.UUID, tenant_id: uuid.UUID | None = None) -> bool:
|
||||||
"""删除模型配置"""
|
"""删除模型配置"""
|
||||||
@@ -457,133 +324,27 @@ class ModelApiKeyService:
|
|||||||
return ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active)
|
return ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def create_api_key_by_provider(db: Session, data: model_schema.ModelApiKeyCreateByProvider) -> tuple[
|
async def create_api_key(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey:
|
||||||
list[Any], list[Any]]:
|
"""创建API Key"""
|
||||||
"""根据provider为多个ModelConfig创建API Key"""
|
model_config = ModelConfigRepository.get_by_id(db, api_key_data.model_config_id)
|
||||||
created_keys = []
|
if not model_config:
|
||||||
failed_models = [] # 记录验证失败的模型
|
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||||
|
|
||||||
for model_config_id in data.model_config_ids:
|
validation_result = await ModelConfigService.validate_model_config(
|
||||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
|
||||||
if not model_config:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 从ModelBase获取model_name
|
|
||||||
model_name = model_config.model_base.name if model_config.model_base else model_config.name
|
|
||||||
|
|
||||||
# 检查是否存在API Key(包括软删除)
|
|
||||||
existing_key = db.query(ModelApiKey).filter(
|
|
||||||
ModelApiKey.api_key == data.api_key,
|
|
||||||
ModelApiKey.provider == data.provider,
|
|
||||||
ModelApiKey.model_name == model_name
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if existing_key:
|
|
||||||
# 如果已存在,重新激活并更新
|
|
||||||
if existing_key.is_active:
|
|
||||||
continue
|
|
||||||
existing_key.is_active = True
|
|
||||||
existing_key.api_base = data.api_base
|
|
||||||
existing_key.description = data.description
|
|
||||||
existing_key.config = data.config
|
|
||||||
existing_key.priority = data.priority
|
|
||||||
existing_key.model_name = model_name
|
|
||||||
|
|
||||||
# 检查是否已关联该模型配置
|
|
||||||
if model_config not in existing_key.model_configs:
|
|
||||||
existing_key.model_configs.append(model_config)
|
|
||||||
|
|
||||||
created_keys.append(existing_key)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 验证配置
|
|
||||||
validation_result = await ModelConfigService.validate_model_config(
|
|
||||||
db=db,
|
db=db,
|
||||||
model_name=model_name,
|
model_name=api_key_data.model_name,
|
||||||
provider=data.provider,
|
provider=api_key_data.provider,
|
||||||
api_key=data.api_key,
|
api_key=api_key_data.api_key,
|
||||||
api_base=data.api_base,
|
api_base=api_key_data.api_base,
|
||||||
model_type=model_config.type,
|
model_type=model_config.type, # 传递模型类型
|
||||||
test_message="Hello"
|
test_message="Hello"
|
||||||
)
|
)
|
||||||
if not validation_result["valid"]:
|
print(validation_result)
|
||||||
# 记录验证失败的模型,但不抛出异常
|
if not validation_result["valid"]:
|
||||||
failed_models.append(model_name)
|
raise BusinessException(
|
||||||
continue
|
f"模型配置验证失败: {validation_result['error']}",
|
||||||
|
BizCode.INVALID_PARAMETER
|
||||||
# 创建API Key
|
|
||||||
api_key_data = ModelApiKeyCreate(
|
|
||||||
model_config_ids=[model_config_id],
|
|
||||||
model_name=model_name,
|
|
||||||
description=data.description,
|
|
||||||
provider=data.provider,
|
|
||||||
api_key=data.api_key,
|
|
||||||
api_base=data.api_base,
|
|
||||||
config=data.config,
|
|
||||||
is_active=data.is_active,
|
|
||||||
priority=data.priority
|
|
||||||
)
|
|
||||||
api_key_obj = ModelApiKeyRepository.create(db, api_key_data)
|
|
||||||
created_keys.append(api_key_obj)
|
|
||||||
|
|
||||||
if created_keys:
|
|
||||||
db.commit()
|
|
||||||
for key in created_keys:
|
|
||||||
db.refresh(key)
|
|
||||||
|
|
||||||
return created_keys, failed_models
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def create_api_key(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey:
|
|
||||||
# 验证所有关联的模型配置是否存在
|
|
||||||
if api_key_data.model_config_ids:
|
|
||||||
for model_config_id in api_key_data.model_config_ids:
|
|
||||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
|
||||||
if not model_config:
|
|
||||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
|
||||||
|
|
||||||
# 检查API Key是否已存在(包括软删除)
|
|
||||||
existing_key = db.query(ModelApiKey).filter(
|
|
||||||
ModelApiKey.api_key == api_key_data.api_key,
|
|
||||||
ModelApiKey.provider == api_key_data.provider,
|
|
||||||
ModelApiKey.model_name == api_key_data.model_name
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if existing_key:
|
|
||||||
if existing_key.is_active:
|
|
||||||
# 如果已激活,跳过
|
|
||||||
raise BusinessException("该API Key已存在", BizCode.DUPLICATE_NAME)
|
|
||||||
# 如果已存在,重新激活并更新
|
|
||||||
existing_key.is_active = True
|
|
||||||
existing_key.api_base = api_key_data.api_base
|
|
||||||
existing_key.description = api_key_data.description
|
|
||||||
existing_key.config = api_key_data.config
|
|
||||||
existing_key.priority = api_key_data.priority
|
|
||||||
existing_key.model_name = api_key_data.model_name
|
|
||||||
|
|
||||||
# 检查是否已关联该模型配置
|
|
||||||
if model_config not in existing_key.model_configs:
|
|
||||||
existing_key.model_configs.append(model_config)
|
|
||||||
|
|
||||||
db.commit()
|
|
||||||
db.refresh(existing_key)
|
|
||||||
return existing_key
|
|
||||||
|
|
||||||
# 验证配置
|
|
||||||
validation_result = await ModelConfigService.validate_model_config(
|
|
||||||
db=db,
|
|
||||||
model_name=api_key_data.model_name,
|
|
||||||
provider=api_key_data.provider,
|
|
||||||
api_key=api_key_data.api_key,
|
|
||||||
api_base=api_key_data.api_base,
|
|
||||||
model_type=model_config.type,
|
|
||||||
test_message="Hello"
|
|
||||||
)
|
)
|
||||||
if not validation_result["valid"]:
|
|
||||||
raise BusinessException(
|
|
||||||
f"模型配置验证失败: {validation_result['error']}",
|
|
||||||
BizCode.INVALID_PARAMETER
|
|
||||||
)
|
|
||||||
|
|
||||||
api_key = ModelApiKeyRepository.create(db, api_key_data)
|
api_key = ModelApiKeyRepository.create(db, api_key_data)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -598,19 +359,21 @@ class ModelApiKeyService:
|
|||||||
raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
|
raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
|
||||||
|
|
||||||
# 获取关联的模型配置以获取模型类型
|
# 获取关联的模型配置以获取模型类型
|
||||||
if existing_api_key.model_configs:
|
model_config = ModelConfigRepository.get_by_id(db, existing_api_key.model_config_id)
|
||||||
model_config = existing_api_key.model_configs[0]
|
if not model_config:
|
||||||
|
raise BusinessException("关联的模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||||
|
|
||||||
validation_result = await ModelConfigService.validate_model_config(
|
validation_result = await ModelConfigService.validate_model_config(
|
||||||
db=db,
|
db=db,
|
||||||
model_name=api_key_data.model_name or existing_api_key.model_name,
|
model_name=api_key_data.model_name,
|
||||||
provider=api_key_data.provider or existing_api_key.provider,
|
provider=api_key_data.provider,
|
||||||
api_key=api_key_data.api_key or existing_api_key.api_key,
|
api_key=api_key_data.api_key,
|
||||||
api_base=api_key_data.api_base or existing_api_key.api_base,
|
api_base=api_key_data.api_base,
|
||||||
model_type=model_config.type,
|
model_type=model_config.type, # 传递模型类型
|
||||||
test_message="Hello"
|
test_message="Hello"
|
||||||
)
|
)
|
||||||
if not validation_result["valid"]:
|
print(validation_result)
|
||||||
|
if not validation_result["valid"]:
|
||||||
raise BusinessException(
|
raise BusinessException(
|
||||||
f"模型配置验证失败: {validation_result['error']}",
|
f"模型配置验证失败: {validation_result['error']}",
|
||||||
BizCode.INVALID_PARAMETER
|
BizCode.INVALID_PARAMETER
|
||||||
@@ -654,87 +417,3 @@ class ModelApiKeyService:
|
|||||||
if api_kes and len(api_kes) > 0:
|
if api_kes and len(api_kes) > 0:
|
||||||
return api_kes[0]
|
return api_kes[0]
|
||||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelBaseService:
|
|
||||||
"""基础模型服务"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List:
|
|
||||||
models = ModelBaseRepository.get_list(db, query)
|
|
||||||
|
|
||||||
provider_groups = {}
|
|
||||||
for m in models:
|
|
||||||
model_dict = model_schema.ModelBase.model_validate(m).model_dump()
|
|
||||||
if tenant_id:
|
|
||||||
model_dict['is_added'] = ModelBaseRepository.check_added_by_tenant(db, m.id, tenant_id)
|
|
||||||
|
|
||||||
provider = m.provider
|
|
||||||
if provider not in provider_groups:
|
|
||||||
provider_groups[provider] = {
|
|
||||||
"provider": provider,
|
|
||||||
"models": []
|
|
||||||
}
|
|
||||||
provider_groups[provider]["models"].append(model_dict)
|
|
||||||
|
|
||||||
return list(provider_groups.values())
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_model_base_by_id(db: Session, model_base_id: uuid.UUID):
|
|
||||||
model = ModelBaseRepository.get_by_id(db, model_base_id)
|
|
||||||
if not model:
|
|
||||||
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
|
|
||||||
return model
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_model_base(db: Session, data: model_schema.ModelBaseCreate):
|
|
||||||
existing = ModelBaseRepository.get_by_name_and_provider(db, data.name, data.provider)
|
|
||||||
if existing:
|
|
||||||
raise BusinessException("模型已存在", BizCode.DUPLICATE_NAME)
|
|
||||||
model_base = ModelBaseRepository.create(db, data.model_dump())
|
|
||||||
db.commit()
|
|
||||||
db.refresh(model_base)
|
|
||||||
return model_base
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def update_model_base(db: Session, model_base_id: uuid.UUID, data: model_schema.ModelBaseUpdate):
|
|
||||||
model_base = ModelBaseRepository.update(db, model_base_id, data.model_dump(exclude_unset=True))
|
|
||||||
if not model_base:
|
|
||||||
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(model_base)
|
|
||||||
return model_base
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def delete_model_base(db: Session, model_base_id: uuid.UUID) -> bool:
|
|
||||||
success = ModelBaseRepository.delete(db, model_base_id)
|
|
||||||
if not success:
|
|
||||||
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
|
|
||||||
db.commit()
|
|
||||||
return success
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def add_model_from_plaza(db: Session, model_base_id: uuid.UUID, tenant_id: uuid.UUID) -> ModelConfig:
|
|
||||||
model_base = ModelBaseRepository.get_by_id(db, model_base_id)
|
|
||||||
if not model_base:
|
|
||||||
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
|
|
||||||
|
|
||||||
if ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id):
|
|
||||||
raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME)
|
|
||||||
|
|
||||||
model_config_data = {
|
|
||||||
"model_id": model_base_id,
|
|
||||||
"tenant_id": tenant_id,
|
|
||||||
"name": model_base.name,
|
|
||||||
"provider": model_base.provider,
|
|
||||||
"type": model_base.type,
|
|
||||||
"logo": model_base.logo,
|
|
||||||
"description": model_base.description,
|
|
||||||
"is_composite": False
|
|
||||||
}
|
|
||||||
model_config = ModelConfigRepository.create(db, model_config_data)
|
|
||||||
ModelBaseRepository.increment_add_count(db, model_base_id)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(model_config)
|
|
||||||
return model_config
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from app.models import MultiAgentConfig, AgentConfig, ModelConfig
|
from app.models import MultiAgentConfig, AgentConfig, ModelConfig
|
||||||
from app.models.multi_agent_model import AggregationStrategy, OrchestrationMode
|
from app.models.multi_agent_model import AggregationStrategy, OrchestrationMode
|
||||||
from app.repositories.model_repository import ModelApiKeyRepository
|
|
||||||
from app.services.agent_registry import AgentRegistry
|
from app.services.agent_registry import AgentRegistry
|
||||||
from app.services.master_agent_router import MasterAgentRouter
|
from app.services.master_agent_router import MasterAgentRouter
|
||||||
from app.services.conversation_state_manager import ConversationStateManager
|
from app.services.conversation_state_manager import ConversationStateManager
|
||||||
@@ -280,22 +279,14 @@ class MultiAgentOrchestrator:
|
|||||||
|
|
||||||
# 4. 提取子 Agent 的 conversation_id(用于多轮对话)
|
# 4. 提取子 Agent 的 conversation_id(用于多轮对话)
|
||||||
sub_conversation_id = None
|
sub_conversation_id = None
|
||||||
total_tokens = 0
|
|
||||||
|
|
||||||
if isinstance(results, dict):
|
if isinstance(results, dict):
|
||||||
sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id")
|
sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id")
|
||||||
# 提取 token 信息
|
|
||||||
usage = results.get("usage", {}) or results.get("result", {}).get("usage", {})
|
|
||||||
total_tokens += usage.get("total_tokens", 0)
|
|
||||||
elif isinstance(results, list) and results:
|
elif isinstance(results, list) and results:
|
||||||
for item in results:
|
for item in results:
|
||||||
if "result" in item:
|
if "result" in item:
|
||||||
sub_conversation_id = item["result"].get("conversation_id")
|
sub_conversation_id = item["result"].get("conversation_id")
|
||||||
if sub_conversation_id:
|
if sub_conversation_id:
|
||||||
break
|
break
|
||||||
# 累加每个子 Agent 的 token
|
|
||||||
usage = item.get("usage", {}) or item.get("result", {}).get("usage", {})
|
|
||||||
total_tokens += usage.get("total_tokens", 0)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"多 Agent 任务完成",
|
"多 Agent 任务完成",
|
||||||
@@ -309,15 +300,9 @@ class MultiAgentOrchestrator:
|
|||||||
return {
|
return {
|
||||||
"message": final_result,
|
"message": final_result,
|
||||||
"conversation_id": sub_conversation_id,
|
"conversation_id": sub_conversation_id,
|
||||||
"mode": OrchestrationMode.SUPERVISOR,
|
|
||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": elapsed_time,
|
||||||
"strategy": routing_decision.get("collaboration_strategy", "single"),
|
"strategy": routing_decision.get("collaboration_strategy", "single"),
|
||||||
"sub_results": results,
|
"sub_results": results
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"completion_tokens": 0,
|
|
||||||
"total_tokens": total_tokens
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1566,12 +1551,10 @@ class MultiAgentOrchestrator:
|
|||||||
return {
|
return {
|
||||||
"message": result.get("response", ""),
|
"message": result.get("response", ""),
|
||||||
"conversation_id": result.get("conversation_id"),
|
"conversation_id": result.get("conversation_id"),
|
||||||
"mode": OrchestrationMode.COLLABORATION,
|
|
||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": elapsed_time,
|
||||||
"strategy": "collaboration",
|
"strategy": "collaboration",
|
||||||
"active_agent": result.get("active_agent"),
|
"active_agent": result.get("active_agent"),
|
||||||
"sub_results": result,
|
"sub_results": result
|
||||||
"usage": result.get("usage")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -2563,14 +2546,10 @@ class MultiAgentOrchestrator:
|
|||||||
return self._smart_merge_results(results, strategy)
|
return self._smart_merge_results(results, strategy)
|
||||||
|
|
||||||
# 获取 API Key 配置
|
# 获取 API Key 配置
|
||||||
# api_key_config = self.db.query(ModelApiKey).join(
|
api_key_config = self.db.query(ModelApiKey).filter(
|
||||||
# ModelConfig, ModelApiKey.model_configs
|
ModelApiKey.model_config_id == default_model_config_id,
|
||||||
# ).filter(
|
ModelApiKey.is_active.is_(True)
|
||||||
# ModelConfig.id == default_model_config_id,
|
).first()
|
||||||
# ModelApiKey.is_active.is_(True)
|
|
||||||
# ).first()
|
|
||||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id)
|
|
||||||
api_key_config = api_keys[0] if api_keys else None
|
|
||||||
|
|
||||||
if not api_key_config:
|
if not api_key_config:
|
||||||
logger.warning("Master Agent 没有可用的 API Key,使用简单整合")
|
logger.warning("Master Agent 没有可用的 API Key,使用简单整合")
|
||||||
@@ -2724,14 +2703,10 @@ class MultiAgentOrchestrator:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# 获取 API Key 配置
|
# 获取 API Key 配置
|
||||||
# api_key_config = self.db.query(ModelApiKey).join(
|
api_key_config = self.db.query(ModelApiKey).filter(
|
||||||
# ModelConfig, ModelApiKey.model_configs
|
ModelApiKey.model_config_id == default_model_config_id,
|
||||||
# ).filter(
|
ModelApiKey.is_active.is_(True)
|
||||||
# ModelConfig.id == default_model_config_id,
|
).first()
|
||||||
# ModelApiKey.is_active.is_(True)
|
|
||||||
# ).first()
|
|
||||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id)
|
|
||||||
api_key_config = api_keys[0] if api_keys else None
|
|
||||||
|
|
||||||
if not api_key_config:
|
if not api_key_config:
|
||||||
logger.warning("Master Agent 没有可用的 API Key,使用简单整合")
|
logger.warning("Master Agent 没有可用的 API Key,使用简单整合")
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""多 Agent 配置管理服务"""
|
"""多 Agent 配置管理服务"""
|
||||||
import uuid
|
import uuid
|
||||||
import json
|
|
||||||
from typing import Optional, List, Tuple, Any, Annotated
|
from typing import Optional, List, Tuple, Any, Annotated
|
||||||
|
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
@@ -428,23 +427,6 @@ class MultiAgentService:
|
|||||||
memory=getattr(request, 'memory', True) # 记忆功能参数
|
memory=getattr(request, 'memory', True) # 记忆功能参数
|
||||||
)
|
)
|
||||||
|
|
||||||
await self._save_conversation_message(
|
|
||||||
conversation_id=request.conversation_id,
|
|
||||||
user_message=request.message,
|
|
||||||
assistant_message=result.get("message", ""),
|
|
||||||
app_id=app_id,
|
|
||||||
user_id=request.user_id,
|
|
||||||
meta_data={
|
|
||||||
"mode": result.get("mode"),
|
|
||||||
"elapsed_time": result.get("elapsed_time"),
|
|
||||||
"usage": result.get("usage", {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"completion_tokens": 0,
|
|
||||||
"total_tokens": 0
|
|
||||||
})
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def run_stream(
|
async def run_stream(
|
||||||
@@ -469,14 +451,11 @@ class MultiAgentService:
|
|||||||
raise ResourceNotFoundException("多 Agent 配置", str(app_id))
|
raise ResourceNotFoundException("多 Agent 配置", str(app_id))
|
||||||
|
|
||||||
if not config.is_active:
|
if not config.is_active:
|
||||||
raise BusinessException("多 Agent 配置已禁用", BizCode.NOT_FOUND)
|
raise BusinessException("多 Agent 配置已禁用", BizCode.RESOURCE_DISABLED)
|
||||||
|
|
||||||
# 2. 创建编排器
|
# 2. 创建编排器
|
||||||
orchestrator = MultiAgentOrchestrator(self.db, config)
|
orchestrator = MultiAgentOrchestrator(self.db, config)
|
||||||
|
|
||||||
full_content = ""
|
|
||||||
total_tokens = 0
|
|
||||||
|
|
||||||
# 3. 流式执行任务
|
# 3. 流式执行任务
|
||||||
async for event in orchestrator.execute_stream(
|
async for event in orchestrator.execute_stream(
|
||||||
message=request.message,
|
message=request.message,
|
||||||
@@ -489,88 +468,7 @@ class MultiAgentService:
|
|||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id
|
||||||
):
|
):
|
||||||
if "sub_usage" in event:
|
yield event
|
||||||
if "data:" in event:
|
|
||||||
try:
|
|
||||||
data_line = event.split("data: ", 1)[1].strip()
|
|
||||||
data = json.loads(data_line)
|
|
||||||
if "total_tokens" in data:
|
|
||||||
total_tokens += data["total_tokens"]
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
yield event
|
|
||||||
if "data:" in event:
|
|
||||||
try:
|
|
||||||
data_line = event.split("data: ", 1)[1].strip()
|
|
||||||
data = json.loads(data_line)
|
|
||||||
if "content" in data:
|
|
||||||
full_content += data["content"]
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
await self._save_conversation_message(
|
|
||||||
conversation_id=request.conversation_id,
|
|
||||||
user_message=request.message,
|
|
||||||
assistant_message=full_content,
|
|
||||||
app_id=app_id,
|
|
||||||
user_id=request.user_id,
|
|
||||||
meta_data={
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"completion_tokens": 0,
|
|
||||||
"total_tokens": total_tokens
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _save_conversation_message(
|
|
||||||
self,
|
|
||||||
conversation_id: uuid.UUID,
|
|
||||||
user_message: str,
|
|
||||||
assistant_message: str,
|
|
||||||
meta_data: dict,
|
|
||||||
app_id: Optional[uuid.UUID] = None,
|
|
||||||
user_id: Optional[str] = None
|
|
||||||
) -> None:
|
|
||||||
"""保存会话消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
conversation_id: 会话ID
|
|
||||||
user_message: 用户消息
|
|
||||||
assistant_message: AI 回复消息
|
|
||||||
meta_data: 元数据(包括 token 消耗)
|
|
||||||
app_id: 应用ID
|
|
||||||
user_id: 用户ID
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from app.services.conversation_service import ConversationService
|
|
||||||
|
|
||||||
conversation_service = ConversationService(self.db)
|
|
||||||
|
|
||||||
conversation_service.add_message(
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
role="user",
|
|
||||||
content=user_message
|
|
||||||
)
|
|
||||||
conversation_service.add_message(
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
role="assistant",
|
|
||||||
content=assistant_message,
|
|
||||||
meta_data=meta_data
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"保存多 Agent 会话消息",
|
|
||||||
extra={
|
|
||||||
"conversation_id": conversation_id,
|
|
||||||
"user_message_length": len(user_message),
|
|
||||||
"assistant_message_length": len(assistant_message)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning("保存会话消息失败", extra={"error": str(e)})
|
|
||||||
|
|
||||||
# def add_sub_agent(
|
# def add_sub_agent(
|
||||||
# self,
|
# self,
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,3 @@
|
|||||||
import os
|
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, AsyncGenerator
|
from typing import Any, AsyncGenerator
|
||||||
@@ -17,10 +16,9 @@ from app.models.prompt_optimizer_model import (
|
|||||||
PromptOptimizerSession,
|
PromptOptimizerSession,
|
||||||
RoleType
|
RoleType
|
||||||
)
|
)
|
||||||
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository
|
from app.repositories.model_repository import ModelConfigRepository
|
||||||
from app.repositories.prompt_optimizer_repository import (
|
from app.repositories.prompt_optimizer_repository import (
|
||||||
PromptOptimizerSessionRepository,
|
PromptOptimizerSessionRepository
|
||||||
PromptReleaseRepository
|
|
||||||
)
|
)
|
||||||
from app.schemas.prompt_optimizer_schema import OptimizePromptResult
|
from app.schemas.prompt_optimizer_schema import OptimizePromptResult
|
||||||
|
|
||||||
@@ -30,8 +28,6 @@ logger = get_business_logger()
|
|||||||
class PromptOptimizerService:
|
class PromptOptimizerService:
|
||||||
def __init__(self, db: Session):
|
def __init__(self, db: Session):
|
||||||
self.db = db
|
self.db = db
|
||||||
self.optim_repo = PromptOptimizerSessionRepository(self.db)
|
|
||||||
self.release_repo = PromptReleaseRepository(self.db)
|
|
||||||
|
|
||||||
def get_model_config(
|
def get_model_config(
|
||||||
self,
|
self,
|
||||||
@@ -82,12 +78,10 @@ class PromptOptimizerService:
|
|||||||
Returns:
|
Returns:
|
||||||
PromptOptimzerSession: The newly created prompt optimization session.
|
PromptOptimzerSession: The newly created prompt optimization session.
|
||||||
"""
|
"""
|
||||||
session = self.optim_repo.create_session(
|
session = PromptOptimizerSessionRepository(self.db).create_session(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
user_id=user_id
|
user_id=user_id
|
||||||
)
|
)
|
||||||
self.db.commit()
|
|
||||||
self.db.refresh(session)
|
|
||||||
return session
|
return session
|
||||||
|
|
||||||
def get_session_message_history(
|
def get_session_message_history(
|
||||||
@@ -112,7 +106,7 @@ class PromptOptimizerService:
|
|||||||
- role (str): The role of the message sender, e.g., 'system', 'user', or 'assistant'.
|
- role (str): The role of the message sender, e.g., 'system', 'user', or 'assistant'.
|
||||||
- content (str): The content of the message.
|
- content (str): The content of the message.
|
||||||
"""
|
"""
|
||||||
history = self.optim_repo.get_session_history(
|
history = PromptOptimizerSessionRepository(self.db).get_session_history(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
user_id=user_id
|
user_id=user_id
|
||||||
)
|
)
|
||||||
@@ -174,8 +168,7 @@ class PromptOptimizerService:
|
|||||||
logger.info(f"Prompt optimization started, user_id={user_id}, session_id={session_id}")
|
logger.info(f"Prompt optimization started, user_id={user_id}, session_id={session_id}")
|
||||||
|
|
||||||
# Create LLM instance
|
# Create LLM instance
|
||||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config.id)
|
api_config: ModelApiKey = model_config.api_keys[0]
|
||||||
api_config: ModelApiKey = api_keys[0] if api_keys else None
|
|
||||||
llm = RedBearLLM(RedBearModelConfig(
|
llm = RedBearLLM(RedBearModelConfig(
|
||||||
model_name=api_config.model_name,
|
model_name=api_config.model_name,
|
||||||
provider=api_config.provider,
|
provider=api_config.provider,
|
||||||
@@ -183,12 +176,11 @@ class PromptOptimizerService:
|
|||||||
base_url=api_config.api_base
|
base_url=api_config.api_base
|
||||||
), type=ModelType(model_config.type))
|
), type=ModelType(model_config.type))
|
||||||
try:
|
try:
|
||||||
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
with open('app/services/prompt/prompt_optimizer_system.jinja2', 'r', encoding='utf-8') as f:
|
||||||
with open(os.path.join(prompt_path, 'prompt_optimizer_system.jinja2'), 'r', encoding='utf-8') as f:
|
|
||||||
opt_system_prompt = f.read()
|
opt_system_prompt = f.read()
|
||||||
rendered_system_message = Template(opt_system_prompt).render()
|
rendered_system_message = Template(opt_system_prompt).render()
|
||||||
|
|
||||||
with open(os.path.join(prompt_path, 'prompt_optimizer_user.jinja2'), 'r', encoding='utf-8') as f:
|
with open('app/services/prompt/prompt_optimizer_user.jinja2', 'r', encoding='utf-8') as f:
|
||||||
opt_user_prompt = f.read()
|
opt_user_prompt = f.read()
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
|
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
|
||||||
@@ -303,165 +295,4 @@ class PromptOptimizerService:
|
|||||||
role=role,
|
role=role,
|
||||||
content=content
|
content=content
|
||||||
)
|
)
|
||||||
self.db.commit()
|
|
||||||
self.db.refresh(message)
|
|
||||||
return message
|
return message
|
||||||
|
|
||||||
def save_prompt(
|
|
||||||
self,
|
|
||||||
tenant_id: uuid.UUID,
|
|
||||||
session_id: uuid.UUID,
|
|
||||||
title: str,
|
|
||||||
prompt: str
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Create and save a new prompt release for a given session.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tenant_id (uuid.UUID): The ID of the tenant owning the prompt.
|
|
||||||
session_id (uuid.UUID): The ID of the session to associate with this prompt.
|
|
||||||
title (str): The title of the prompt release.
|
|
||||||
prompt (str): The content of the prompt.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: A dictionary containing:
|
|
||||||
- id (UUID): The unique ID of the created prompt release.
|
|
||||||
- session_id (UUID): The session ID linked to the release.
|
|
||||||
- title (str): The title of the prompt.
|
|
||||||
- prompt (str): The prompt content.
|
|
||||||
- created_at (int): Timestamp (in milliseconds) of when the prompt was created.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
BusinessException: If a prompt release already exists for the given session.
|
|
||||||
"""
|
|
||||||
session = self.optim_repo.get_session_by_id(session_id)
|
|
||||||
if session is None or session.tenant_id != tenant_id:
|
|
||||||
raise BusinessException(
|
|
||||||
"Session does not exist or the current user has no access",
|
|
||||||
BizCode.BAD_REQUEST
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.release_repo.get_prompt_by_session_id(session_id):
|
|
||||||
raise BusinessException(
|
|
||||||
"A release already exists for the current session",
|
|
||||||
BizCode.BAD_REQUEST
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt_obj = self.release_repo.create_prompt_release(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
title=title,
|
|
||||||
session_id=session_id,
|
|
||||||
prompt=prompt
|
|
||||||
)
|
|
||||||
self.db.commit()
|
|
||||||
self.db.refresh(prompt_obj)
|
|
||||||
return {
|
|
||||||
"id": prompt_obj.id,
|
|
||||||
"session_id": prompt_obj.session_id,
|
|
||||||
"title": prompt_obj.title,
|
|
||||||
"prompt": prompt_obj.prompt,
|
|
||||||
"created_at": int(prompt_obj.created_at.timestamp() * 1000)
|
|
||||||
}
|
|
||||||
|
|
||||||
def delete_prompt(
|
|
||||||
self,
|
|
||||||
tenant_id: uuid.UUID,
|
|
||||||
prompt_id: uuid.UUID
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Soft delete a prompt release by prompt_id.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tenant_id (uuid.UUID): Tenant identifier.
|
|
||||||
prompt_id (uuid.UUID): Prompt identifier.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
BusinessException: If the prompt does not exist or already deleted.
|
|
||||||
"""
|
|
||||||
prompt_obj = self.release_repo.get_prompt_by_id(prompt_id)
|
|
||||||
if not prompt_obj or prompt_obj.is_delete:
|
|
||||||
raise BusinessException(
|
|
||||||
"Prompt does not exist or has already been deleted",
|
|
||||||
BizCode.NOT_FOUND
|
|
||||||
)
|
|
||||||
|
|
||||||
if prompt_obj.tenant_id != tenant_id:
|
|
||||||
raise BusinessException(
|
|
||||||
"No permission to delete this prompt",
|
|
||||||
BizCode.FORBIDDEN
|
|
||||||
)
|
|
||||||
|
|
||||||
self.release_repo.soft_delete_prompt(prompt_obj)
|
|
||||||
self.db.commit()
|
|
||||||
logger.info(f"Prompt soft deleted, prompt_id={prompt_id}, tenant_id={tenant_id}")
|
|
||||||
|
|
||||||
def get_release_list(
|
|
||||||
self,
|
|
||||||
tenant_id: uuid.UUID,
|
|
||||||
page: int,
|
|
||||||
page_size: int,
|
|
||||||
filter_keyword: str | None = None
|
|
||||||
) -> dict[str, int | list[Any]]:
|
|
||||||
"""
|
|
||||||
Get paginated list of prompt releases with optional filter.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tenant_id (uuid.UUID): Tenant identifier.
|
|
||||||
page (int): Page number (starting from 1).
|
|
||||||
page_size (int): Number of items per page.
|
|
||||||
filter_keyword (str | None): Optional keyword to filter by title.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Contains total count, pagination info, and list of releases.
|
|
||||||
"""
|
|
||||||
offset = (page - 1) * page_size
|
|
||||||
|
|
||||||
# Get total count and releases based on filter
|
|
||||||
if filter_keyword:
|
|
||||||
total = self.release_repo.count_prompts_by_keyword(tenant_id, filter_keyword)
|
|
||||||
releases = self.release_repo.search_prompts_paginated(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
keyword=filter_keyword,
|
|
||||||
offset=offset,
|
|
||||||
limit=page_size
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
total = self.release_repo.count_prompts(tenant_id)
|
|
||||||
releases = self.release_repo.get_prompts_paginated(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
offset=offset,
|
|
||||||
limit=page_size
|
|
||||||
)
|
|
||||||
|
|
||||||
items = []
|
|
||||||
for release in releases:
|
|
||||||
# Get first user message from session
|
|
||||||
first_message = self.optim_repo.get_first_user_message(
|
|
||||||
session_id=release.session_id
|
|
||||||
)
|
|
||||||
|
|
||||||
items.append({
|
|
||||||
"id": release.id,
|
|
||||||
"title": release.title,
|
|
||||||
"prompt": release.prompt,
|
|
||||||
"created_at": int(release.created_at.timestamp() * 1000),
|
|
||||||
"first_message": first_message
|
|
||||||
})
|
|
||||||
|
|
||||||
log_msg = f"Retrieved {len(items)} prompt releases, page={page}, tenant_id={tenant_id}"
|
|
||||||
if filter_keyword:
|
|
||||||
log_msg += f", filter='{filter_keyword}'"
|
|
||||||
logger.info(log_msg)
|
|
||||||
|
|
||||||
result = {
|
|
||||||
"page": {
|
|
||||||
"total": total,
|
|
||||||
"page": page,
|
|
||||||
"page_size": page_size,
|
|
||||||
"hasnext": page * page_size < total
|
|
||||||
},
|
|
||||||
"keyword": filter_keyword,
|
|
||||||
"items": items
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ import time
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import Optional, Dict, Any, AsyncGenerator
|
from typing import Optional, Dict, Any, AsyncGenerator
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.repositories.model_repository import ModelApiKeyRepository
|
|
||||||
from app.services.memory_konwledges_server import write_rag
|
from app.services.memory_konwledges_server import write_rag
|
||||||
from app.models import ReleaseShare, AppRelease, Conversation
|
from app.models import ReleaseShare, AppRelease, Conversation
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
@@ -166,20 +164,16 @@ class SharedChatService:
|
|||||||
raise ResourceNotFoundException("模型配置", str(model_config_id))
|
raise ResourceNotFoundException("模型配置", str(model_config_id))
|
||||||
|
|
||||||
# 获取 API Key
|
# 获取 API Key
|
||||||
# stmt = (
|
stmt = (
|
||||||
# select(ModelApiKey).join(
|
select(ModelApiKey)
|
||||||
# ModelConfig, ModelApiKey.model_configs
|
.where(
|
||||||
# )
|
ModelApiKey.model_config_id == model_config_id,
|
||||||
# .where(
|
ModelApiKey.is_active.is_(True)
|
||||||
# ModelConfig.id == model_config_id,
|
)
|
||||||
# ModelApiKey.is_active.is_(True)
|
.order_by(ModelApiKey.priority.desc())
|
||||||
# )
|
.limit(1)
|
||||||
# .order_by(ModelApiKey.priority.desc())
|
)
|
||||||
# .limit(1)
|
api_key_obj = self.db.scalars(stmt).first()
|
||||||
# )
|
|
||||||
# api_key_obj = self.db.scalars(stmt).first()
|
|
||||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
|
|
||||||
api_key_obj = api_keys[0] if api_keys else None
|
|
||||||
if not api_key_obj:
|
if not api_key_obj:
|
||||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
|
||||||
@@ -282,14 +276,7 @@ class SharedChatService:
|
|||||||
self.conversation_service.save_conversation_messages(
|
self.conversation_service.save_conversation_messages(
|
||||||
conversation_id=conversation.id,
|
conversation_id=conversation.id,
|
||||||
user_message=message,
|
user_message=message,
|
||||||
assistant_message=result["content"],
|
assistant_message=result["content"]
|
||||||
meta_data={
|
|
||||||
"usage": result.get("usage", {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"completion_tokens": 0,
|
|
||||||
"total_tokens": 0
|
|
||||||
})
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
# self.conversation_service.add_message(
|
# self.conversation_service.add_message(
|
||||||
# conversation_id=conversation.id,
|
# conversation_id=conversation.id,
|
||||||
@@ -371,20 +358,16 @@ class SharedChatService:
|
|||||||
raise ResourceNotFoundException("模型配置", str(model_config_id))
|
raise ResourceNotFoundException("模型配置", str(model_config_id))
|
||||||
|
|
||||||
# 获取 API Key
|
# 获取 API Key
|
||||||
# stmt = (
|
stmt = (
|
||||||
# select(ModelApiKey).join(
|
select(ModelApiKey)
|
||||||
# ModelConfig, ModelApiKey.model_configs
|
.where(
|
||||||
# )
|
ModelApiKey.model_config_id == model_config_id,
|
||||||
# .where(
|
ModelApiKey.is_active.is_(True)
|
||||||
# ModelConfig.id == model_config_id,
|
)
|
||||||
# ModelApiKey.is_active.is_(True)
|
.order_by(ModelApiKey.priority.desc())
|
||||||
# )
|
.limit(1)
|
||||||
# .order_by(ModelApiKey.priority.desc())
|
)
|
||||||
# .limit(1)
|
api_key_obj = self.db.scalars(stmt).first()
|
||||||
# )
|
|
||||||
# api_key_obj = self.db.scalars(stmt).first()
|
|
||||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
|
|
||||||
api_key_obj = api_keys[0] if api_keys else None
|
|
||||||
if not api_key_obj:
|
if not api_key_obj:
|
||||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
|
||||||
@@ -476,7 +459,6 @@ class SharedChatService:
|
|||||||
|
|
||||||
# 流式调用 Agent
|
# 流式调用 Agent
|
||||||
full_content = ""
|
full_content = ""
|
||||||
total_tokens = 0
|
|
||||||
async for chunk in agent.chat_stream(
|
async for chunk in agent.chat_stream(
|
||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
@@ -487,12 +469,9 @@ class SharedChatService:
|
|||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
memory_flag=memory_flag
|
memory_flag=memory_flag
|
||||||
):
|
):
|
||||||
if isinstance(chunk, int):
|
full_content += chunk
|
||||||
total_tokens = chunk
|
# 发送消息块事件
|
||||||
else:
|
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
|
||||||
full_content += chunk
|
|
||||||
# 发送消息块事件
|
|
||||||
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
|
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
@@ -509,7 +488,7 @@ class SharedChatService:
|
|||||||
content=full_content,
|
content=full_content,
|
||||||
meta_data={
|
meta_data={
|
||||||
"model": api_key_obj.model_name,
|
"model": api_key_obj.model_name,
|
||||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
|
"usage": {}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user