Compare commits
3 Commits
hotfix/v0.
...
v0.2.3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
79ab929fb0 | ||
|
|
eab7225d83 | ||
|
|
0159fdf149 |
@@ -3,9 +3,14 @@ import platform
|
||||
from datetime import timedelta
|
||||
from urllib.parse import quote
|
||||
|
||||
from app.core.config import settings
|
||||
from celery import Celery
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# macOS fork() safety - must be set before any Celery initialization
|
||||
if platform.system() == 'Darwin':
|
||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||
|
||||
# 创建 Celery 应用实例
|
||||
# broker: 任务队列(使用 Redis DB 0)
|
||||
# backend: 结果存储(使用 Redis DB 10)
|
||||
@@ -63,15 +68,20 @@ celery_app.conf.update(
|
||||
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
||||
|
||||
# Long-term storage tasks → memory_tasks queue (batched write strategies)
|
||||
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
||||
|
||||
# Document tasks → document_tasks queue (prefork worker)
|
||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||
|
||||
# Beat/periodic tasks → document_tasks queue (prefork worker)
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'document_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'document_tasks'},
|
||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'document_tasks'},
|
||||
'app.controllers.memory_storage_controller.search_all': {'queue': 'document_tasks'},
|
||||
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
|
||||
'app.controllers.memory_storage_controller.search_all': {'queue': 'periodic_tasks'},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -79,40 +89,40 @@ celery_app.conf.update(
|
||||
celery_app.autodiscover_tasks(['app'])
|
||||
|
||||
# Celery Beat schedule for periodic tasks
|
||||
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||
forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
|
||||
# memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||
# memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||
# workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||
# forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
|
||||
|
||||
# 构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
"run-workspace-reflection": {
|
||||
"task": "app.tasks.workspace_reflection_task",
|
||||
"schedule": workspace_reflection_schedule,
|
||||
"args": (),
|
||||
},
|
||||
"regenerate-memory-cache": {
|
||||
"task": "app.tasks.regenerate_memory_cache",
|
||||
"schedule": memory_cache_regeneration_schedule,
|
||||
"args": (),
|
||||
},
|
||||
"run-forgetting-cycle": {
|
||||
"task": "app.tasks.run_forgetting_cycle_task",
|
||||
"schedule": forgetting_cycle_schedule,
|
||||
"kwargs": {
|
||||
"config_id": None, # 使用默认配置,可以通过环境变量配置
|
||||
},
|
||||
},
|
||||
}
|
||||
# beat_schedule_config = {
|
||||
# "run-workspace-reflection": {
|
||||
# "task": "app.tasks.workspace_reflection_task",
|
||||
# "schedule": workspace_reflection_schedule,
|
||||
# "args": (),
|
||||
# },
|
||||
# "regenerate-memory-cache": {
|
||||
# "task": "app.tasks.regenerate_memory_cache",
|
||||
# "schedule": memory_cache_regeneration_schedule,
|
||||
# "args": (),
|
||||
# },
|
||||
# "run-forgetting-cycle": {
|
||||
# "task": "app.tasks.run_forgetting_cycle_task",
|
||||
# "schedule": forgetting_cycle_schedule,
|
||||
# "kwargs": {
|
||||
# "config_id": None, # 使用默认配置,可以通过环境变量配置
|
||||
# },
|
||||
# },
|
||||
# }
|
||||
|
||||
# 如果配置了默认工作空间ID,则添加记忆总量统计任务
|
||||
if settings.DEFAULT_WORKSPACE_ID:
|
||||
beat_schedule_config["write-total-memory"] = {
|
||||
"task": "app.controllers.memory_storage_controller.search_all",
|
||||
"schedule": memory_increment_schedule,
|
||||
"kwargs": {
|
||||
"workspace_id": settings.DEFAULT_WORKSPACE_ID,
|
||||
},
|
||||
}
|
||||
# if settings.DEFAULT_WORKSPACE_ID:
|
||||
# beat_schedule_config["write-total-memory"] = {
|
||||
# "task": "app.controllers.memory_storage_controller.search_all",
|
||||
# "schedule": memory_increment_schedule,
|
||||
# "kwargs": {
|
||||
# "workspace_id": settings.DEFAULT_WORKSPACE_ID,
|
||||
# },
|
||||
# }
|
||||
|
||||
celery_app.conf.beat_schedule = beat_schedule_config
|
||||
# celery_app.conf.beat_schedule = beat_schedule_config
|
||||
|
||||
@@ -45,6 +45,7 @@ from . import (
|
||||
home_page_controller,
|
||||
memory_perceptual_controller,
|
||||
memory_working_controller,
|
||||
ontology_controller,
|
||||
)
|
||||
|
||||
# 创建管理端 API 路由器
|
||||
@@ -90,5 +91,6 @@ manager_router.include_router(implicit_memory_controller.router)
|
||||
manager_router.include_router(memory_perceptual_controller.router)
|
||||
manager_router.include_router(memory_working_controller.router)
|
||||
manager_router.include_router(file_storage_controller.router)
|
||||
manager_router.include_router(ontology_controller.router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
@@ -51,7 +51,6 @@ async def save_reflection_config(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="缺少必需参数: config_id"
|
||||
)
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
||||
|
||||
memory_config = MemoryConfigRepository.update_reflection_config(
|
||||
@@ -102,7 +101,7 @@ async def start_workspace_reflection(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Activate the reflection function for all matching applications in the workspace"""
|
||||
"""启动工作空间中所有匹配应用的反思功能"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
reflection_service = MemoryReflectionService(db)
|
||||
|
||||
@@ -111,42 +110,55 @@ async def start_workspace_reflection(
|
||||
|
||||
service = WorkspaceAppService(db)
|
||||
result = service.get_workspace_apps_detailed(workspace_id)
|
||||
|
||||
reflection_results = []
|
||||
|
||||
for data in result['apps_detailed_info']:
|
||||
if data['memory_configs'] == []:
|
||||
# 跳过没有配置的应用
|
||||
if not data['memory_configs']:
|
||||
api_logger.debug(f"应用 {data['id']} 没有memory_configs,跳过")
|
||||
continue
|
||||
|
||||
|
||||
releases = data['releases']
|
||||
memory_configs = data['memory_configs']
|
||||
end_users = data['end_users']
|
||||
|
||||
for base, config, user in zip(releases, memory_configs, end_users):
|
||||
# 安全地转换为整数,处理空字符串和None的情况
|
||||
print(base['config'])
|
||||
try:
|
||||
base_config = int(base['config']) if base['config'] else 0
|
||||
config_id = int(config['config_id']) if config['config_id'] else 0
|
||||
except (ValueError, TypeError):
|
||||
api_logger.warning(f"无效的配置ID: base['config']={base.get('config')}, config['config_id']={config.get('config_id')}")
|
||||
|
||||
# 为每个配置和用户组合执行反思
|
||||
for config in memory_configs:
|
||||
config_id_str = str(config['config_id'])
|
||||
|
||||
# 找到匹配此配置的所有release
|
||||
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
|
||||
|
||||
if not matching_releases:
|
||||
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
|
||||
continue
|
||||
|
||||
if base_config == config_id and base['app_id'] == user['app_id']:
|
||||
# 调用反思服务
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
|
||||
|
||||
reflection_result = await reflection_service.start_text_reflection(
|
||||
config_data=config,
|
||||
end_user_id=user['id']
|
||||
)
|
||||
|
||||
reflection_results.append({
|
||||
"app_id": base['app_id'],
|
||||
"config_id": config['config_id'],
|
||||
"end_user_id": user['id'],
|
||||
"reflection_result": reflection_result
|
||||
})
|
||||
|
||||
# 为每个用户执行反思
|
||||
for user in end_users:
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}")
|
||||
|
||||
try:
|
||||
reflection_result = await reflection_service.start_text_reflection(
|
||||
config_data=config,
|
||||
end_user_id=user['id']
|
||||
)
|
||||
|
||||
reflection_results.append({
|
||||
"app_id": data['id'],
|
||||
"config_id": config_id_str,
|
||||
"end_user_id": user['id'],
|
||||
"reflection_result": reflection_result
|
||||
})
|
||||
except Exception as e:
|
||||
api_logger.error(f"用户 {user['id']} 反思失败: {str(e)}")
|
||||
reflection_results.append({
|
||||
"app_id": data['id'],
|
||||
"config_id": config_id_str,
|
||||
"end_user_id": user['id'],
|
||||
"reflection_result": {
|
||||
"status": "错误",
|
||||
"message": f"反思失败: {str(e)}"
|
||||
}
|
||||
})
|
||||
|
||||
return success(data=reflection_results, msg="反思配置成功")
|
||||
|
||||
|
||||
@@ -195,6 +195,11 @@ def update_config(
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 校验至少有一个字段需要更新
|
||||
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
|
||||
1005
api/app/controllers/ontology_controller.py
Normal file
1005
api/app/controllers/ontology_controller.py
Normal file
File diff suppressed because it is too large
Load Diff
611
api/app/controllers/ontology_secondary_routes.py
Normal file
611
api/app/controllers/ontology_secondary_routes.py
Normal file
@@ -0,0 +1,611 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""本体场景和类型路由(续)
|
||||
|
||||
由于主Controller文件较大,将剩余路由放在此文件中。
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.ontology_schemas import (
|
||||
SceneResponse,
|
||||
SceneListResponse,
|
||||
PaginationInfo,
|
||||
ClassCreateRequest,
|
||||
ClassUpdateRequest,
|
||||
ClassResponse,
|
||||
ClassListResponse,
|
||||
ClassBatchCreateResponse,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.ontology_service import OntologyService
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
|
||||
def _get_dummy_ontology_service(db: Session) -> OntologyService:
|
||||
"""获取OntologyService实例(不需要LLM)
|
||||
|
||||
场景和类型管理不需要LLM,创建一个dummy配置。
|
||||
"""
|
||||
dummy_config = RedBearModelConfig(
|
||||
model_name="dummy",
|
||||
provider="openai",
|
||||
api_key="dummy",
|
||||
base_url="https://api.openai.com/v1"
|
||||
)
|
||||
llm_client = OpenAIClient(model_config=dummy_config)
|
||||
return OntologyService(llm_client=llm_client, db=db)
|
||||
|
||||
|
||||
# 这些函数将被导入到主Controller中
|
||||
|
||||
async def scenes_handler(
|
||||
workspace_id: Optional[str] = None,
|
||||
scene_name: Optional[str] = None,
|
||||
page: Optional[int] = None,
|
||||
page_size: Optional[int] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取场景列表(支持模糊搜索和全量查询,全量查询支持分页)
|
||||
|
||||
当提供 scene_name 参数时,进行模糊搜索(不分页);
|
||||
当不提供 scene_name 参数时,返回所有场景(支持分页)。
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
||||
scene_name: 场景名称关键词(可选,支持模糊匹配)
|
||||
page: 页码(可选,从1开始,仅在全量查询时有效)
|
||||
page_size: 每页数量(可选,仅在全量查询时有效)
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
"""
|
||||
operation = "search" if scene_name else "list"
|
||||
api_logger.info(
|
||||
f"Scene {operation} requested by user {current_user.id}, "
|
||||
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, page_size={page_size}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 确定工作空间ID
|
||||
if workspace_id:
|
||||
try:
|
||||
ws_uuid = UUID(workspace_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid workspace_id format: {workspace_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的工作空间ID格式")
|
||||
else:
|
||||
ws_uuid = current_user.current_workspace_id
|
||||
if not ws_uuid:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 根据是否提供 scene_name 决定查询方式
|
||||
if scene_name and scene_name.strip():
|
||||
# 验证分页参数(模糊搜索也支持分页)
|
||||
if page is not None and page < 1:
|
||||
api_logger.warning(f"Invalid page number: {page}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||
|
||||
if page_size is not None and page_size < 1:
|
||||
api_logger.warning(f"Invalid page_size: {page_size}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||
|
||||
# 如果只提供了page或page_size中的一个,返回错误
|
||||
if (page is not None and page_size is None) or (page is None and page_size is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||
|
||||
# 模糊搜索场景(支持分页)
|
||||
scenes = service.search_scenes_by_name(scene_name.strip(), ws_uuid)
|
||||
total = len(scenes)
|
||||
|
||||
# 如果提供了分页参数,进行分页处理
|
||||
if page is not None and page_size is not None:
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = start_idx + page_size
|
||||
scenes = scenes[start_idx:end_idx]
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for scene in scenes:
|
||||
# 获取前3个class_name作为entity_type
|
||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||
# 动态计算 type_num
|
||||
type_num = len(scene.classes) if scene.classes else 0
|
||||
|
||||
items.append(SceneResponse(
|
||||
scene_id=scene.scene_id,
|
||||
scene_name=scene.scene_name,
|
||||
scene_description=scene.scene_description,
|
||||
type_num=type_num,
|
||||
entity_type=entity_type,
|
||||
workspace_id=scene.workspace_id,
|
||||
created_at=scene.created_at,
|
||||
updated_at=scene.updated_at,
|
||||
classes_count=type_num
|
||||
))
|
||||
|
||||
# 构建响应(包含分页信息)
|
||||
if page is not None and page_size is not None:
|
||||
# 计算是否有下一页
|
||||
hasnext = (page * page_size) < total
|
||||
|
||||
pagination_info = PaginationInfo(
|
||||
page=page,
|
||||
pagesize=page_size,
|
||||
total=total,
|
||||
hasnext=hasnext
|
||||
)
|
||||
response = SceneListResponse(items=items, page=pagination_info)
|
||||
else:
|
||||
response = SceneListResponse(items=items)
|
||||
|
||||
api_logger.info(
|
||||
f"Scene search completed: found {len(items)} scenes matching '{scene_name}' "
|
||||
f"in workspace {ws_uuid}, total={total}"
|
||||
)
|
||||
else:
|
||||
# 获取所有场景(支持分页)
|
||||
# 验证分页参数
|
||||
if page is not None and page < 1:
|
||||
api_logger.warning(f"Invalid page number: {page}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||
|
||||
if page_size is not None and page_size < 1:
|
||||
api_logger.warning(f"Invalid page_size: {page_size}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||
|
||||
# 如果只提供了page或page_size中的一个,返回错误
|
||||
if (page is not None and page_size is None) or (page is None and page_size is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||
|
||||
scenes, total = service.list_scenes(ws_uuid, page, page_size)
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for scene in scenes:
|
||||
# 获取前3个class_name作为entity_type
|
||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||
# 动态计算 type_num
|
||||
type_num = len(scene.classes) if scene.classes else 0
|
||||
|
||||
items.append(SceneResponse(
|
||||
scene_id=scene.scene_id,
|
||||
scene_name=scene.scene_name,
|
||||
scene_description=scene.scene_description,
|
||||
type_num=type_num,
|
||||
entity_type=entity_type,
|
||||
workspace_id=scene.workspace_id,
|
||||
created_at=scene.created_at,
|
||||
updated_at=scene.updated_at,
|
||||
classes_count=type_num
|
||||
))
|
||||
|
||||
# 构建响应(包含分页信息)
|
||||
if page is not None and page_size is not None:
|
||||
# 计算是否有下一页
|
||||
hasnext = (page * page_size) < total
|
||||
|
||||
pagination_info = PaginationInfo(
|
||||
page=page,
|
||||
pagesize=page_size,
|
||||
total=total,
|
||||
hasnext=hasnext
|
||||
)
|
||||
response = SceneListResponse(items=items, page=pagination_info)
|
||||
else:
|
||||
response = SceneListResponse(items=items)
|
||||
|
||||
api_logger.info(f"Scene list retrieved successfully, count={len(items)}, total={total}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="查询成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in scene {operation}: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in scene {operation}: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in scene {operation}: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
|
||||
# ==================== 本体类型管理接口 ====================
|
||||
|
||||
async def create_class_handler(
|
||||
request: ClassCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""创建本体类型(统一使用列表形式,支持单个或批量)"""
|
||||
|
||||
# 根据列表长度判断是单个还是批量
|
||||
count = len(request.classes)
|
||||
mode = "single" if count == 1 else "batch"
|
||||
|
||||
api_logger.info(
|
||||
f"Class creation ({mode}) requested by user {current_user.id}, "
|
||||
f"scene_id={request.scene_id}, count={count}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 准备类型数据
|
||||
classes_data = [
|
||||
{
|
||||
"class_name": item.class_name,
|
||||
"class_description": item.class_description
|
||||
}
|
||||
for item in request.classes
|
||||
]
|
||||
|
||||
if count == 1:
|
||||
# 单个创建
|
||||
class_data = classes_data[0]
|
||||
ontology_class = service.create_class(
|
||||
scene_id=request.scene_id,
|
||||
class_name=class_data["class_name"],
|
||||
class_description=class_data["class_description"],
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
# 构建单个响应
|
||||
response = ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
)
|
||||
|
||||
api_logger.info(f"Class created successfully: {ontology_class.class_id}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="类型创建成功")
|
||||
|
||||
else:
|
||||
# 批量创建
|
||||
created_classes, errors = service.create_classes_batch(
|
||||
scene_id=request.scene_id,
|
||||
classes=classes_data,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
# 构建批量响应
|
||||
items = []
|
||||
for ontology_class in created_classes:
|
||||
items.append(ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
))
|
||||
|
||||
response = ClassBatchCreateResponse(
|
||||
total=len(classes_data),
|
||||
success_count=len(created_classes),
|
||||
failed_count=len(errors),
|
||||
items=items,
|
||||
errors=errors if errors else None
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"Batch class creation completed: "
|
||||
f"success={len(created_classes)}, failed={len(errors)}"
|
||||
)
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="批量创建完成")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class creation: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class creation: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
|
||||
|
||||
|
||||
async def update_class_handler(
|
||||
class_id: str,
|
||||
request: ClassUpdateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新本体类型"""
|
||||
api_logger.info(
|
||||
f"Class update requested by user {current_user.id}, "
|
||||
f"class_id={class_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证UUID格式
|
||||
try:
|
||||
class_uuid = UUID(class_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid class_id format: {class_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
|
||||
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 更新类型
|
||||
ontology_class = service.update_class(
|
||||
class_id=class_uuid,
|
||||
class_name=request.class_name,
|
||||
class_description=request.class_description,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
response = ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
)
|
||||
|
||||
api_logger.info(f"Class updated successfully: {class_id}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="类型更新成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class update: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class update: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class update: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))
|
||||
|
||||
|
||||
async def delete_class_handler(
|
||||
class_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除本体类型"""
|
||||
api_logger.info(
|
||||
f"Class deletion requested by user {current_user.id}, "
|
||||
f"class_id={class_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证UUID格式
|
||||
try:
|
||||
class_uuid = UUID(class_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid class_id format: {class_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
|
||||
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 删除类型
|
||||
success_flag = service.delete_class(
|
||||
class_id=class_uuid,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
api_logger.info(f"Class deleted successfully: {class_id}")
|
||||
|
||||
return success(data={"deleted": success_flag}, msg="类型删除成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class deletion: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class deletion: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class deletion: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))
|
||||
|
||||
|
||||
async def get_class_handler(
|
||||
class_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取单个本体类型"""
|
||||
api_logger.info(
|
||||
f"Get class requested by user {current_user.id}, "
|
||||
f"class_id={class_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证UUID格式
|
||||
try:
|
||||
class_uuid = UUID(class_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid class_id format: {class_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
|
||||
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 获取类型(会抛出ValueError如果不存在)
|
||||
ontology_class = service.get_class_by_id(class_uuid, workspace_id)
|
||||
|
||||
# 构建响应
|
||||
response = ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
)
|
||||
|
||||
api_logger.info(f"Class retrieved successfully: {class_id}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="查询成功")
|
||||
|
||||
except ValueError as e:
|
||||
# 类型不存在或无权限访问
|
||||
api_logger.warning(f"Validation error in get class: {str(e)}")
|
||||
return fail(BizCode.NOT_FOUND, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in get class: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in get class: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
|
||||
async def classes_handler(
|
||||
scene_id: str,
|
||||
class_name: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取类型列表(支持模糊搜索和全量查询)
|
||||
|
||||
当提供 class_name 参数时,进行模糊搜索;
|
||||
当不提供 class_name 参数时,返回场景下的所有类型。
|
||||
|
||||
Args:
|
||||
scene_id: 场景ID(必填)
|
||||
class_name: 类型名称关键词(可选,支持模糊匹配)
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
"""
|
||||
operation = "search" if class_name else "list"
|
||||
api_logger.info(
|
||||
f"Class {operation} requested by user {current_user.id}, "
|
||||
f"keyword={class_name}, scene_id={scene_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证UUID格式
|
||||
try:
|
||||
scene_uuid = UUID(scene_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid scene_id format: {scene_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的场景ID格式")
|
||||
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 获取场景信息
|
||||
scene = service.get_scene_by_id(scene_uuid, workspace_id)
|
||||
if not scene:
|
||||
api_logger.warning(f"Scene not found: {scene_id}")
|
||||
return fail(BizCode.NOT_FOUND, "场景不存在", f"未找到ID为 {scene_id} 的场景")
|
||||
|
||||
# 根据是否提供 class_name 决定查询方式
|
||||
if class_name and class_name.strip():
|
||||
# 模糊搜索类型
|
||||
classes = service.search_classes_by_name(class_name.strip(), scene_uuid, workspace_id)
|
||||
else:
|
||||
# 获取所有类型
|
||||
classes = service.list_classes_by_scene(scene_uuid, workspace_id)
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for ontology_class in classes:
|
||||
items.append(ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
))
|
||||
|
||||
response = ClassListResponse(
|
||||
total=len(items),
|
||||
scene_id=scene_uuid,
|
||||
scene_name=scene.scene_name,
|
||||
scene_description=scene.scene_description,
|
||||
items=items
|
||||
)
|
||||
|
||||
if class_name:
|
||||
api_logger.info(
|
||||
f"Class search completed: found {len(items)} classes matching '{class_name}' "
|
||||
f"in scene {scene_id}"
|
||||
)
|
||||
else:
|
||||
api_logger.info(f"Class list retrieved successfully, count={len(items)}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="查询成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class {operation}: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class {operation}: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class {operation}: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
@@ -1,5 +1,5 @@
|
||||
import uuid
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -8,9 +8,13 @@ from starlette.responses import StreamingResponse
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.models.prompt_optimizer_model import RoleType
|
||||
from app.schemas.prompt_optimizer_schema import PromptOptMessage, PromptOptModelSet, CreateSessionResponse, \
|
||||
OptimizePromptResponse, SessionHistoryResponse, SessionMessage
|
||||
from app.schemas.prompt_optimizer_schema import (
|
||||
PromptOptMessage,
|
||||
CreateSessionResponse,
|
||||
SessionHistoryResponse,
|
||||
SessionMessage,
|
||||
PromptSaveRequest
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.prompt_optimizer_service import PromptOptimizerService
|
||||
|
||||
@@ -135,3 +139,109 @@ async def get_prompt_opt(
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/releases",
|
||||
summary="Get prompt optimization",
|
||||
response_model=ApiResponse
|
||||
)
|
||||
def save_prompt(
|
||||
data: PromptSaveRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Save a prompt release for the current tenant.
|
||||
|
||||
Args:
|
||||
data (PromptSaveRequest): Request body containing session_id, title, and prompt.
|
||||
db (Session): SQLAlchemy database session, injected via dependency.
|
||||
current_user: Currently authenticated user object, injected via dependency.
|
||||
|
||||
Returns:
|
||||
ApiResponse: Standard API response containing the saved prompt release info:
|
||||
- id: UUID of the prompt release
|
||||
- session_id: associated session
|
||||
- title: prompt title
|
||||
- prompt: prompt content
|
||||
- created_at: timestamp of creation
|
||||
|
||||
Raises:
|
||||
Any database or service exceptions are propagated to the global exception handler.
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
prompt_info = service.save_prompt(
|
||||
tenant_id=current_user.tenant_id,
|
||||
session_id=data.session_id,
|
||||
title=data.title,
|
||||
prompt=data.prompt
|
||||
)
|
||||
return success(data=prompt_info)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/releases/{prompt_id}",
|
||||
summary="Delete prompt (soft delete)",
|
||||
response_model=ApiResponse
|
||||
)
|
||||
def delete_prompt(
|
||||
prompt_id: uuid.UUID = Path(..., description="Prompt ID"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Soft delete a prompt release.
|
||||
|
||||
Args:
|
||||
prompt_id
|
||||
db (Session): Database session
|
||||
current_user: Current logged-in user
|
||||
|
||||
Returns:
|
||||
ApiResponse: Success message confirming deletion
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
service.delete_prompt(
|
||||
tenant_id=current_user.tenant_id,
|
||||
prompt_id=prompt_id
|
||||
)
|
||||
return success(msg="Prompt deleted successfully")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/releases/list",
|
||||
summary="Get paginated list of released prompts with optional filter",
|
||||
response_model=ApiResponse
|
||||
)
|
||||
def get_release_list(
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
keyword: str | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Retrieve paginated list of released prompts for the current tenant.
|
||||
Optionally filter by keyword in title.
|
||||
|
||||
Args:
|
||||
page (int): Page number (starting from 1)
|
||||
page_size (int): Number of items per page (max 100)
|
||||
keyword (str | None): Optional keyword to filter prompt titles
|
||||
db (Session): Database session
|
||||
current_user: Current logged-in user
|
||||
|
||||
Returns:
|
||||
ApiResponse: Contains paginated list of prompt releases with metadata
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
result = service.get_release_list(
|
||||
tenant_id=current_user.tenant_id,
|
||||
page=max(1, page),
|
||||
page_size=min(max(1, page_size), 100),
|
||||
filter_keyword=keyword
|
||||
)
|
||||
return success(data=result)
|
||||
|
||||
|
||||
|
||||
@@ -7,29 +7,21 @@ LangChain Agent 封装
|
||||
- 支持流式输出
|
||||
- 使用 RedBearLLM 支持多提供商
|
||||
"""
|
||||
import os
|
||||
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@@ -106,7 +98,7 @@ class LangChainAgent:
|
||||
"streaming": streaming,
|
||||
"tool_count": len(self.tools),
|
||||
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
|
||||
"tool_count": len(self.tools)
|
||||
# "tool_count": len(self.tools)
|
||||
}
|
||||
)
|
||||
|
||||
@@ -145,106 +137,8 @@ class LangChainAgent:
|
||||
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
|
||||
|
||||
messages.append(HumanMessage(content=user_content))
|
||||
|
||||
return messages
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# async def term_memory_save(self,messages,end_user_end,aimessages):
|
||||
# '''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j'''
|
||||
# end_user_end=f"Term_{end_user_end}"
|
||||
# print(messages)
|
||||
# print(aimessages)
|
||||
# session_id = store.save_session(
|
||||
# userid=end_user_end,
|
||||
# messages=messages,
|
||||
# apply_id=end_user_end,
|
||||
# end_user_id=end_user_end,
|
||||
# aimessages=aimessages
|
||||
# )
|
||||
# store.delete_duplicate_sessions()
|
||||
# # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
|
||||
# return session_id
|
||||
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# async def term_memory_redis_read(self,end_user_end):
|
||||
# end_user_end = f"Term_{end_user_end}"
|
||||
# history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
|
||||
# # logger.info(f'Redis_Agent:{end_user_end};{history}')
|
||||
# messagss_list=[]
|
||||
# retrieved_content=[]
|
||||
# for messages in history:
|
||||
# query = messages.get("Query")
|
||||
# aimessages = messages.get("Answer")
|
||||
# messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
|
||||
# retrieved_content.append({query: aimessages})
|
||||
# return messagss_list,retrieved_content
|
||||
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
|
||||
"""
|
||||
写入记忆(支持结构化消息)
|
||||
|
||||
Args:
|
||||
storage_type: 存储类型 (neo4j/rag)
|
||||
end_user_id: 终端用户ID
|
||||
user_message: 用户消息内容
|
||||
ai_message: AI 回复内容
|
||||
user_rag_memory_id: RAG 记忆ID
|
||||
actual_end_user_id: 实际用户ID
|
||||
actual_config_id: 配置ID
|
||||
|
||||
逻辑说明:
|
||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||
- Neo4j 模式:使用结构化消息列表
|
||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||
"""
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
actual_config_id=resolve_config_id(actual_config_id, db)
|
||||
|
||||
if storage_type == "rag":
|
||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
else:
|
||||
# Neo4j 模式:使用结构化消息列表
|
||||
structured_messages = []
|
||||
|
||||
# 始终添加用户消息(如果不为空)
|
||||
if user_message:
|
||||
structured_messages.append({"role": "user", "content": user_message})
|
||||
|
||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||
if ai_message:
|
||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||
|
||||
# 如果没有消息,直接返回
|
||||
if not structured_messages:
|
||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||
return
|
||||
|
||||
# 调用 Celery 任务,传递结构化消息列表
|
||||
# 数据流:
|
||||
# 1. structured_messages 传递给 write_message_task
|
||||
# 2. write_message_task 调用 memory_agent_service.write_memory
|
||||
# 3. write_memory 调用 write_tools.write,传递 messages 参数
|
||||
# 4. write_tools.write 调用 get_chunked_dialogs,传递 messages 参数
|
||||
# 5. get_chunked_dialogs 为每条消息创建独立的 Chunk,设置 speaker 字段
|
||||
# 6. 每个 Chunk 保存到 Neo4j,包含 speaker 字段
|
||||
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||
write_id = write_message_task.delay(
|
||||
actual_end_user_id, # end_user_id: 用户ID
|
||||
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
actual_config_id, # config_id: 配置ID
|
||||
storage_type, # storage_type: "neo4j"
|
||||
user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||
)
|
||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
finally:
|
||||
db.close()
|
||||
async def chat(
|
||||
self,
|
||||
message: str,
|
||||
@@ -288,30 +182,6 @@ class LangChainAgent:
|
||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
# # TODO 乐力齐,在长短期记忆存储的时候再使用此代码
|
||||
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||
# history_term_memory = history_term_memory_result[0]
|
||||
# db_for_memory = next(get_db())
|
||||
# if memory_flag:
|
||||
# if len(history_term_memory)>=4 and storage_type != "rag":
|
||||
# history_term_memory = ';'.join(history_term_memory)
|
||||
# retrieved_content = history_term_memory_result[1]
|
||||
# print(retrieved_content)
|
||||
# # 为长期记忆操作获取新的数据库连接
|
||||
# try:
|
||||
# repo = LongTermMemoryRepository(db_for_memory)
|
||||
# repo.upsert(end_user_id, retrieved_content)
|
||||
# logger.info(
|
||||
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to write to LongTermMemory: {e}")
|
||||
# raise
|
||||
# finally:
|
||||
# db_for_memory.close()
|
||||
|
||||
# # 长期记忆写入(
|
||||
# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||
# # 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表
|
||||
messages = self._prepare_messages(message, history, context)
|
||||
@@ -332,17 +202,17 @@ class LangChainAgent:
|
||||
# 获取最后的 AI 消息
|
||||
output_messages = result.get("messages", [])
|
||||
content = ""
|
||||
total_tokens = 0
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
content = msg.content
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
|
||||
break
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
if memory_flag:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# await self.term_memory_save(message_chat, end_user_id, content)
|
||||
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id)
|
||||
response = {
|
||||
"content": content,
|
||||
"model": self.model_name,
|
||||
@@ -350,7 +220,7 @@ class LangChainAgent:
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
}
|
||||
|
||||
@@ -410,25 +280,7 @@ class LangChainAgent:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
# # TODO 乐力齐
|
||||
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||
# history_term_memory = history_term_memory_result[0]
|
||||
# if memory_flag:
|
||||
# if len(history_term_memory) >= 4 and storage_type != "rag":
|
||||
# history_term_memory = ';'.join(history_term_memory)
|
||||
# retrieved_content = history_term_memory_result[1]
|
||||
# db_for_memory = next(get_db())
|
||||
# try:
|
||||
# repo = LongTermMemoryRepository(db_for_memory)
|
||||
# repo.upsert(end_user_id, retrieved_content)
|
||||
# logger.info(
|
||||
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
# # 长期记忆写入
|
||||
# await self.write(storage_type, end_user_id, history_term_memory, "", user_rag_memory_id, end_user_id, actual_config_id)
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to write to long term memory: {e}")
|
||||
# finally:
|
||||
# db_for_memory.close()
|
||||
|
||||
|
||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
@@ -444,7 +296,7 @@ class LangChainAgent:
|
||||
|
||||
# 统一使用 agent 的 astream_events 实现流式输出
|
||||
logger.debug("使用 Agent astream_events 实现流式输出")
|
||||
full_content=''
|
||||
full_content = ''
|
||||
try:
|
||||
async for event in self.agent.astream_events(
|
||||
{"messages": messages},
|
||||
@@ -481,12 +333,17 @@ class LangChainAgent:
|
||||
logger.debug(f"工具调用结束: {event.get('name')}")
|
||||
|
||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||
# 统计token消耗
|
||||
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
|
||||
0) if response_meta else 0
|
||||
yield total_tokens
|
||||
break
|
||||
if memory_flag:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# await self.term_memory_save(message_chat, end_user_id, full_content)
|
||||
|
||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
@@ -157,6 +157,11 @@ class Settings:
|
||||
if origin.strip()
|
||||
]
|
||||
|
||||
# Language Configuration
|
||||
# Supported values: "zh" (Chinese), "en" (English)
|
||||
# This controls the language used for memory summary titles and other generated content
|
||||
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
|
||||
|
||||
# Logging settings
|
||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
||||
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
|
||||
@@ -0,0 +1,238 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
|
||||
|
||||
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.core.memory.agent.utils.redis_tool import count_store
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context, get_db
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
logger = get_agent_logger(__name__)
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
|
||||
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
|
||||
actual_config_id, long_term_messages=[]):
|
||||
"""
|
||||
写入记忆(支持结构化消息)
|
||||
|
||||
Args:
|
||||
storage_type: 存储类型 (neo4j/rag)
|
||||
end_user_id: 终端用户ID
|
||||
user_message: 用户消息内容
|
||||
ai_message: AI 回复内容
|
||||
user_rag_memory_id: RAG 记忆ID
|
||||
actual_end_user_id: 实际用户ID
|
||||
actual_config_id: 配置ID
|
||||
|
||||
逻辑说明:
|
||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||
- Neo4j 模式:使用结构化消息列表
|
||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||
"""
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
actual_config_id = resolve_config_id(actual_config_id, db)
|
||||
# Neo4j 模式:使用结构化消息列表
|
||||
structured_messages = []
|
||||
|
||||
# 始终添加用户消息(如果不为空)
|
||||
if isinstance(user_message, str) and user_message.strip() != "":
|
||||
structured_messages.append({"role": "user", "content": user_message})
|
||||
|
||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||
if isinstance(ai_message, str) and ai_message.strip() != "":
|
||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||
|
||||
# 如果提供了 long_term_messages,使用它替代 structured_messages
|
||||
if long_term_messages and isinstance(long_term_messages, list):
|
||||
structured_messages = long_term_messages
|
||||
elif long_term_messages and isinstance(long_term_messages, str):
|
||||
# 如果是 JSON 字符串,先解析
|
||||
try:
|
||||
structured_messages = json.loads(long_term_messages)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
|
||||
|
||||
# 如果没有消息,直接返回
|
||||
if not structured_messages:
|
||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||
write_id = write_message_task.delay(
|
||||
actual_end_user_id, # end_user_id: 用户ID
|
||||
structured_messages, # message: JSON 字符串格式的消息列表
|
||||
str(actual_config_id), # config_id: 配置ID字符串
|
||||
storage_type, # storage_type: "neo4j"
|
||||
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||
)
|
||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
|
||||
with get_db_context() as db_session:
|
||||
repo = LongTermMemoryRepository(db_session)
|
||||
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
result = write_store.get_session_by_userid(end_user_id)
|
||||
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
data = await format_parsing(result, "dict")
|
||||
chunk_data = data[:scope]
|
||||
if len(chunk_data)==scope:
|
||||
repo.upsert(end_user_id, chunk_data)
|
||||
logger.info(f'---------写入短长期-----------')
|
||||
else:
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
|
||||
long_messages = await messages_parse(long_time_data)
|
||||
repo.upsert(end_user_id, long_messages)
|
||||
logger.info(f'写入短长期:')
|
||||
|
||||
|
||||
|
||||
'''根据窗口'''
|
||||
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||
'''
|
||||
根据窗口获取redis数据,写入neo4j:
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_config: 内存配置对象
|
||||
langchain_messages:原始数据LIST
|
||||
scope:窗口大小
|
||||
'''
|
||||
scope=scope
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_id is not False:
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||
redis_messages = count_store.get_sessions_count(end_user_id)[1]
|
||||
if is_end_user_id and int(is_end_user_id) != int(scope):
|
||||
is_end_user_id += 1
|
||||
langchain_messages += redis_messages
|
||||
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
||||
elif int(is_end_user_id) == int(scope):
|
||||
logger.info('写入长期记忆NEO4J')
|
||||
formatted_messages = (redis_messages)
|
||||
# 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用)
|
||||
if hasattr(memory_config, 'config_id'):
|
||||
config_id = memory_config.config_id
|
||||
else:
|
||||
config_id = memory_config
|
||||
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
config_id, formatted_messages)
|
||||
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
||||
else:
|
||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||
|
||||
|
||||
"""根据时间"""
|
||||
async def memory_long_term_storage(end_user_id,memory_config,time):
|
||||
'''
|
||||
根据时间获取redis数据,写入neo4j:
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_config: 内存配置对象
|
||||
'''
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
||||
format_messages = (long_time_data)
|
||||
messages=[]
|
||||
memory_config=memory_config.config_id
|
||||
for i in format_messages:
|
||||
message=json.loads(i['Query'])
|
||||
messages+= message
|
||||
if format_messages!=[]:
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
memory_config, messages)
|
||||
'''聚合判断'''
|
||||
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
||||
"""
|
||||
聚合判断函数:判断输入句子和历史消息是否描述同一事件
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
memory_config: 内存配置对象
|
||||
"""
|
||||
|
||||
try:
|
||||
# 1. 获取历史会话数据(使用新方法)
|
||||
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
||||
history = await format_parsing(result)
|
||||
if not result:
|
||||
history = []
|
||||
else:
|
||||
history = await format_parsing(result)
|
||||
json_schema = WriteAggregateModel.model_json_schema()
|
||||
template_service = TemplateService(template_root)
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='write_aggregate_judgment.jinja2',
|
||||
operation_name='aggregate_judgment',
|
||||
history=history,
|
||||
sentence=ori_messages,
|
||||
json_schema=json_schema
|
||||
)
|
||||
with get_db_context() as db_session:
|
||||
factory = MemoryClientFactory(db_session)
|
||||
llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": system_prompt
|
||||
}
|
||||
]
|
||||
structured = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=WriteAggregateModel
|
||||
)
|
||||
output_value = structured.output
|
||||
if isinstance(output_value, list):
|
||||
output_value = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in output_value
|
||||
]
|
||||
|
||||
result_dict = {
|
||||
"is_same_event": structured.is_same_event,
|
||||
"output": output_value
|
||||
}
|
||||
if not structured.is_same_event:
|
||||
logger.info(result_dict)
|
||||
await write("neo4j", end_user_id, "", "", None, end_user_id,
|
||||
memory_config.config_id, output_value)
|
||||
return result_dict
|
||||
|
||||
except Exception as e:
|
||||
print(f"[aggregate_judgment] 发生错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
return {
|
||||
"is_same_event": False,
|
||||
"output": ori_messages,
|
||||
"messages": ori_messages,
|
||||
"history": history if 'history' in locals() else [],
|
||||
"error": str(e)
|
||||
}
|
||||
@@ -186,10 +186,11 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
清理后的数据
|
||||
"""
|
||||
# 需要过滤的字段列表
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
fields_to_remove = {
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids","fact_summary"
|
||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
|
||||
}
|
||||
|
||||
if isinstance(data, dict):
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
import json
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
async def format_parsing(messages: list,type:str='string'):
|
||||
"""
|
||||
格式化解析消息列表
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
type: 返回类型 ('string' 或 'dict')
|
||||
|
||||
Returns:
|
||||
格式化后的消息列表
|
||||
"""
|
||||
result = []
|
||||
user=[]
|
||||
ai=[]
|
||||
|
||||
for message in messages:
|
||||
hstory_messages = message['messages']
|
||||
for history_messag in hstory_messages.strip().splitlines():
|
||||
history_messag = json.loads(history_messag)
|
||||
for content in history_messag:
|
||||
role = content['role']
|
||||
content = content['content']
|
||||
if type == "string":
|
||||
if role == 'human' or role=="user":
|
||||
content = '用户:' + content
|
||||
else:
|
||||
content = 'AI:' + content
|
||||
result.append(content)
|
||||
if type == "dict" :
|
||||
if role == 'human' or role=="user":
|
||||
user.append( content)
|
||||
else:
|
||||
ai.append(content)
|
||||
if type == "dict":
|
||||
for key,values in zip(user,ai):
|
||||
result.append({key:values})
|
||||
return result
|
||||
|
||||
async def messages_parse(messages: list | dict):
|
||||
user=[]
|
||||
ai=[]
|
||||
database=[]
|
||||
for message in messages:
|
||||
Query = message['Query']
|
||||
Query = json.loads(Query)
|
||||
for data in Query:
|
||||
role = data['role']
|
||||
if role == "human":
|
||||
user.append(data['content'])
|
||||
if role == "ai":
|
||||
ai.append(data['content'])
|
||||
for key, values in zip(user, ai):
|
||||
database.append({key, values})
|
||||
return database
|
||||
|
||||
|
||||
async def agent_chat_messages(user_content,ai_content):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"{user_content}"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"{ai_content}"
|
||||
}
|
||||
|
||||
]
|
||||
return messages
|
||||
@@ -1,22 +1,20 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
|
||||
from app.db import get_db
|
||||
from app.db import get_db, get_db_context
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -34,14 +32,6 @@ async def make_write_graph():
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
"""
|
||||
# workflow = StateGraph(WriteState)
|
||||
# workflow.add_node("content_input", content_input_write)
|
||||
# workflow.add_node("save_neo4j", write_node)
|
||||
# workflow.add_edge(START, "content_input")
|
||||
# workflow.add_edge("content_input", "save_neo4j")
|
||||
# workflow.add_edge("save_neo4j", END)
|
||||
#
|
||||
# graph = workflow.compile()
|
||||
workflow = StateGraph(WriteState)
|
||||
workflow.add_node("save_neo4j", write_node)
|
||||
workflow.add_edge(START, "save_neo4j")
|
||||
@@ -51,43 +41,63 @@ async def make_write_graph():
|
||||
|
||||
yield graph
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数 - 运行工作流"""
|
||||
message = "今天周一"
|
||||
end_user_id = 'new_2025test1103' # 组ID
|
||||
|
||||
|
||||
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
write_store.save_session_write(end_user_id, (langchain_messages))
|
||||
# 获取数据库会话
|
||||
db_session = next(get_db())
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=17, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
try:
|
||||
async with make_write_graph() as graph:
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config}
|
||||
|
||||
# 获取节点更新信息
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
if 'save_neo4j'==node_name:
|
||||
massages=node_data
|
||||
massages=massages.get('write_result')['status']
|
||||
print(massages) # | 更新数据: {node_data}
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
with get_db_context() as db_session:
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=memory_config, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
if long_term_type=='chunk':
|
||||
'''方案一:对话窗口6轮对话'''
|
||||
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
|
||||
if long_term_type=='time':
|
||||
"""时间"""
|
||||
await memory_long_term_storage(end_user_id, memory_config,5)
|
||||
if long_term_type=='aggregate':
|
||||
"""方案三:聚合判断"""
|
||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
|
||||
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
||||
else:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
||||
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
|
||||
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
|
||||
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
|
||||
|
||||
# async def main():
|
||||
# """主函数 - 运行工作流"""
|
||||
# langchain_messages = [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": "今天周五去爬山"
|
||||
# },
|
||||
# {
|
||||
# "role": "assistant",
|
||||
# "content": "好耶"
|
||||
# }
|
||||
#
|
||||
# ]
|
||||
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
|
||||
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
|
||||
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
||||
#
|
||||
#
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# import asyncio
|
||||
# asyncio.run(main())
|
||||
28
api/app/core/memory/agent/models/write_aggregate_model.py
Normal file
28
api/app/core/memory/agent/models/write_aggregate_model.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Pydantic models for write aggregate judgment operations."""
|
||||
|
||||
from typing import List, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MessageItem(BaseModel):
|
||||
"""Individual message item in conversation."""
|
||||
|
||||
role: str = Field(..., description="角色:user 或 assistant")
|
||||
content: str = Field(..., description="消息内容")
|
||||
|
||||
|
||||
class WriteAggregateResponse(BaseModel):
|
||||
"""Response model for aggregate judgment containing judgment result and output."""
|
||||
|
||||
is_same_event: bool = Field(
|
||||
...,
|
||||
description="是否是同一事件。True表示是同一事件,False表示不同事件"
|
||||
)
|
||||
output: Union[List[MessageItem], bool] = Field(
|
||||
...,
|
||||
description="如果is_same_event为True,返回False;如果is_same_event为False,返回消息列表"
|
||||
)
|
||||
|
||||
|
||||
# 为了保持向后兼容,保留旧的类名作为别名
|
||||
WriteAggregateModel = WriteAggregateResponse
|
||||
@@ -0,0 +1,57 @@
|
||||
输入句子:{{sentence}}
|
||||
历史消息:{{history}}
|
||||
|
||||
# 你的角色
|
||||
你是一个擅长事件聚合与语义判断的专家。
|
||||
|
||||
# 你的任务
|
||||
结合历史消息和输入句子,判断它们是否在描述**同一件事件或同一事件链**。
|
||||
|
||||
以下情况视为"同一事件"(需要返回 is_same_event=True, output=False):
|
||||
- 描述的是同一个具体事件或事实
|
||||
- 存在明显的因果关系、前后发展关系
|
||||
- 是对同一事件的补充、解释、追问或延展
|
||||
- 逻辑上属于同一语境下的连续讨论
|
||||
|
||||
以下情况视为"不同事件"(需要返回 is_same_event=False, output=消息列表):
|
||||
- 话题不同,事件主体不同
|
||||
- 时间、地点、对象明显不同
|
||||
- 只是语义相似,但并非同一具体事件
|
||||
- 无直接事件、因果或逻辑关联
|
||||
|
||||
# 输出规则(非常重要)
|
||||
你必须按照以下JSON格式输出:
|
||||
|
||||
**如果是同一事件:**
|
||||
```json
|
||||
{
|
||||
"is_same_event": true,
|
||||
"output": false
|
||||
}
|
||||
```
|
||||
|
||||
**如果不是同一事件:**
|
||||
```json
|
||||
{
|
||||
"is_same_event": false,
|
||||
"output": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "输入句子的内容"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "对应的回复内容"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
# JSON Schema
|
||||
{{json_schema}}
|
||||
|
||||
# 注意事项
|
||||
- 必须严格按照上述格式输出
|
||||
- output 字段:如果是同一事件返回 false,如果不是同一事件返回完整的消息列表
|
||||
- 消息列表必须包含 role 和 content 字段
|
||||
- 不要输出任何解释、分析或多余内容
|
||||
186
api/app/core/memory/agent/utils/redis_base.py
Normal file
186
api/app/core/memory/agent/utils/redis_base.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import json
|
||||
from typing import Any, List, Dict, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
def serialize_messages(messages: Any) -> str:
|
||||
"""
|
||||
将消息序列化为 JSON 字符串,支持 LangChain 消息对象
|
||||
|
||||
Args:
|
||||
messages: 可以是 list、dict、string 或 LangChain 消息对象列表
|
||||
|
||||
Returns:
|
||||
str: JSON 字符串
|
||||
"""
|
||||
if isinstance(messages, str):
|
||||
return messages
|
||||
|
||||
if isinstance(messages, (list, tuple)):
|
||||
# 检查是否是 LangChain 消息对象列表
|
||||
serialized_list = []
|
||||
for msg in messages:
|
||||
if hasattr(msg, 'type') and hasattr(msg, 'content'):
|
||||
# LangChain 消息对象
|
||||
serialized_list.append({
|
||||
'type': msg.type,
|
||||
'content': msg.content,
|
||||
'role': getattr(msg, 'role', msg.type)
|
||||
})
|
||||
elif isinstance(msg, dict):
|
||||
serialized_list.append(msg)
|
||||
else:
|
||||
serialized_list.append(str(msg))
|
||||
return json.dumps(serialized_list, ensure_ascii=False)
|
||||
|
||||
if isinstance(messages, dict):
|
||||
return json.dumps(messages, ensure_ascii=False)
|
||||
|
||||
# 其他类型转为字符串
|
||||
return str(messages)
|
||||
|
||||
|
||||
def deserialize_messages(messages_str: str) -> Any:
|
||||
"""
|
||||
将 JSON 字符串反序列化为原始格式
|
||||
|
||||
Args:
|
||||
messages_str: JSON 字符串
|
||||
|
||||
Returns:
|
||||
反序列化后的对象(list、dict 或 string)
|
||||
"""
|
||||
if not messages_str:
|
||||
return []
|
||||
|
||||
try:
|
||||
return json.loads(messages_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return messages_str
|
||||
|
||||
|
||||
def fix_encoding(text: str) -> str:
|
||||
"""
|
||||
修复错误编码的文本
|
||||
|
||||
Args:
|
||||
text: 需要修复的文本
|
||||
|
||||
Returns:
|
||||
str: 修复后的文本
|
||||
"""
|
||||
if not text or not isinstance(text, str):
|
||||
return text
|
||||
try:
|
||||
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
|
||||
return text.encode('latin-1').decode('utf-8')
|
||||
except (UnicodeDecodeError, UnicodeEncodeError):
|
||||
# 如果修复失败,返回原文本
|
||||
return text
|
||||
|
||||
|
||||
def format_session_data(data: Dict[str, Any], include_time: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
格式化会话数据为统一的输出格式
|
||||
|
||||
Args:
|
||||
data: 原始会话数据
|
||||
include_time: 是否包含时间字段
|
||||
|
||||
Returns:
|
||||
Dict: 格式化后的数据 {"Query": "...", "Answer": "...", "starttime": "..."}
|
||||
"""
|
||||
result = {
|
||||
"Query": fix_encoding(data.get('messages', '')),
|
||||
"Answer": fix_encoding(data.get('aimessages', ''))
|
||||
}
|
||||
|
||||
if include_time:
|
||||
result["starttime"] = data.get('starttime', '')
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def filter_by_time_range(items: List[Dict], minutes: int) -> List[Dict]:
|
||||
"""
|
||||
根据时间范围过滤数据
|
||||
|
||||
Args:
|
||||
items: 包含 starttime 字段的数据列表
|
||||
minutes: 时间范围(分钟)
|
||||
|
||||
Returns:
|
||||
List[Dict]: 过滤后的数据列表
|
||||
"""
|
||||
time_threshold = datetime.now() - timedelta(minutes=minutes)
|
||||
time_threshold_str = time_threshold.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
filtered_items = []
|
||||
for item in items:
|
||||
starttime = item.get('starttime', '')
|
||||
if starttime and starttime >= time_threshold_str:
|
||||
filtered_items.append(item)
|
||||
|
||||
return filtered_items
|
||||
|
||||
|
||||
def sort_and_limit_results(items: List[Dict], limit: int = 6,
|
||||
remove_time: bool = True) -> List[Dict]:
|
||||
"""
|
||||
对结果进行排序、限制数量并移除时间字段
|
||||
|
||||
Args:
|
||||
items: 数据列表
|
||||
limit: 最大返回数量
|
||||
remove_time: 是否移除 starttime 字段
|
||||
|
||||
Returns:
|
||||
List[Dict]: 处理后的数据列表
|
||||
"""
|
||||
# 按时间降序排序(最新的在前)
|
||||
items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||
|
||||
# 限制数量
|
||||
result_items = items[:limit]
|
||||
|
||||
# 移除 starttime 字段
|
||||
if remove_time:
|
||||
for item in result_items:
|
||||
item.pop('starttime', None)
|
||||
|
||||
# 如果结果少于1条,返回空列表
|
||||
if len(result_items) < 1:
|
||||
return []
|
||||
|
||||
return result_items
|
||||
|
||||
|
||||
def generate_session_key(session_id: str, key_type: str = "session") -> str:
|
||||
"""
|
||||
生成 Redis key
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
key_type: key 类型 ("session", "read", "write", "count")
|
||||
|
||||
Returns:
|
||||
str: Redis key
|
||||
"""
|
||||
if key_type == "count":
|
||||
return f"session:count:{session_id}"
|
||||
elif key_type == "write":
|
||||
return f"session:write:{session_id}"
|
||||
elif key_type == "session" or key_type == "read":
|
||||
return f"session:{session_id}"
|
||||
else:
|
||||
return f"session:{session_id}"
|
||||
|
||||
|
||||
def get_current_timestamp() -> str:
|
||||
"""
|
||||
获取当前时间戳字符串
|
||||
|
||||
Returns:
|
||||
str: 格式化的时间字符串 "YYYY-MM-DD HH:MM:SS"
|
||||
"""
|
||||
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
@@ -1,11 +1,36 @@
|
||||
import redis
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from app.core.config import settings
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
|
||||
from app.core.memory.agent.utils.redis_base import (
|
||||
serialize_messages,
|
||||
deserialize_messages,
|
||||
fix_encoding,
|
||||
format_session_data,
|
||||
filter_by_time_range,
|
||||
sort_and_limit_results,
|
||||
generate_session_key,
|
||||
get_current_timestamp
|
||||
)
|
||||
|
||||
|
||||
class RedisSessionStore:
|
||||
|
||||
|
||||
class RedisWriteStore:
|
||||
"""Redis Write 类型存储类,用于管理 save_session_write 相关的数据"""
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
|
||||
Args:
|
||||
host: Redis 主机地址
|
||||
port: Redis 端口
|
||||
db: Redis 数据库编号
|
||||
password: Redis 密码
|
||||
session_id: 会话ID
|
||||
"""
|
||||
self.r = redis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
@@ -16,32 +41,437 @@ class RedisSessionStore:
|
||||
)
|
||||
self.uudi = session_id
|
||||
|
||||
def _fix_encoding(self, text):
|
||||
"""修复错误编码的文本"""
|
||||
if not text or not isinstance(text, str):
|
||||
return text
|
||||
try:
|
||||
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
|
||||
return text.encode('latin-1').decode('utf-8')
|
||||
except (UnicodeDecodeError, UnicodeEncodeError):
|
||||
# 如果修复失败,返回原文本
|
||||
return text
|
||||
|
||||
# 修改后的 save_session 方法
|
||||
def save_session(self, userid, messages, aimessages, apply_id, end_user_id):
|
||||
def save_session_write(self, userid: str, messages: str) -> str:
|
||||
"""
|
||||
写入一条会话数据,返回 session_id
|
||||
优化版本:确保写入时间不超过1秒
|
||||
|
||||
Args:
|
||||
userid: 用户ID
|
||||
messages: 用户消息
|
||||
|
||||
Returns:
|
||||
str: 新生成的 session_id
|
||||
"""
|
||||
try:
|
||||
session_id = str(uuid.uuid4()) # 为每次会话生成新的 ID
|
||||
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
key = f"session:{session_id}" # 使用新生成的 session_id 作为 key
|
||||
messages = serialize_messages(messages)
|
||||
session_id = str(uuid.uuid4())
|
||||
key = generate_session_key(session_id, key_type="write")
|
||||
|
||||
# 使用 pipeline 批量写入,减少网络往返
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"sessionid": userid,
|
||||
"messages": messages,
|
||||
"starttime": get_current_timestamp()
|
||||
})
|
||||
result = pipe.execute()
|
||||
|
||||
# 直接写入数据,decode_responses=True 已经处理了编码
|
||||
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
print(f"[save_session_write] 保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
|
||||
"""
|
||||
通过 save_session_write 的 userid 获取 sessionid 和 messages
|
||||
|
||||
Args:
|
||||
userid: 用户ID (对应 sessionid 字段)
|
||||
|
||||
Returns:
|
||||
List[Dict] 或 False: 如果找到数据返回 [{"sessionid": "...", "messages": "..."}, ...],否则返回 False
|
||||
"""
|
||||
try:
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
return False
|
||||
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 筛选符合 userid 的数据
|
||||
results = []
|
||||
for key, data in zip(keys, all_data):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == userid:
|
||||
# 从 key 中提取 session_id: session:write:{session_id}
|
||||
session_id = key.split(':')[-1]
|
||||
results.append({
|
||||
"sessionid": session_id,
|
||||
"messages": fix_encoding(data.get('messages', ''))
|
||||
})
|
||||
|
||||
if not results:
|
||||
return False
|
||||
|
||||
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"[get_session_by_userid] 查询失败: {e}")
|
||||
return False
|
||||
|
||||
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
|
||||
"""
|
||||
通过 end_user_id 获取所有 write 类型的会话数据
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID (对应 sessionid 字段)
|
||||
|
||||
Returns:
|
||||
List[Dict] 或 False: 如果找到数据返回完整的会话信息列表,否则返回 False
|
||||
|
||||
返回格式:
|
||||
[
|
||||
{
|
||||
"session_id": "uuid",
|
||||
"id": "...",
|
||||
"sessionid": "end_user_id",
|
||||
"messages": "...",
|
||||
"starttime": "timestamp"
|
||||
},
|
||||
...
|
||||
]
|
||||
"""
|
||||
try:
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
||||
return False
|
||||
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 筛选符合 end_user_id 的数据
|
||||
results = []
|
||||
for key, data in zip(keys, all_data):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == end_user_id:
|
||||
# 从 key 中提取 session_id: session:write:{session_id}
|
||||
session_id = key.split(':')[-1]
|
||||
|
||||
# 构建完整的会话信息
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
"id": data.get('id', ''),
|
||||
"sessionid": data.get('sessionid', ''),
|
||||
"messages": fix_encoding(data.get('messages', '')),
|
||||
"starttime": data.get('starttime', '')
|
||||
}
|
||||
results.append(session_info)
|
||||
|
||||
if not results:
|
||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
||||
return False
|
||||
|
||||
# 按时间排序(最新的在前)
|
||||
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||
|
||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def find_user_recent_sessions(self, userid: str,
|
||||
minutes: int = 5) -> List[Dict[str, str]]:
|
||||
"""
|
||||
根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据
|
||||
|
||||
Args:
|
||||
userid: 用户ID (对应 sessionid 字段)
|
||||
minutes: 查询最近几分钟的数据,默认5分钟
|
||||
|
||||
Returns:
|
||||
List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...]
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 筛选符合 userid 的数据
|
||||
matched_items = []
|
||||
for data in all_data:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == userid and data.get('starttime'):
|
||||
# write 类型没有 aimessages,所以 Answer 为空
|
||||
matched_items.append({
|
||||
"Query": fix_encoding(data.get('messages', '')),
|
||||
"Answer": "",
|
||||
"starttime": data.get('starttime', '')
|
||||
})
|
||||
|
||||
# 根据时间范围过滤
|
||||
filtered_items = filter_by_time_range(matched_items, minutes)
|
||||
# 排序并移除时间字段
|
||||
result_items = sort_and_limit_results(filtered_items, limit=None)
|
||||
print(result_items)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
||||
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
|
||||
def delete_all_write_sessions(self) -> int:
|
||||
"""
|
||||
删除所有 write 类型的会话
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
keys = self.r.keys('session:write:*')
|
||||
if keys:
|
||||
return self.r.delete(*keys)
|
||||
return 0
|
||||
|
||||
|
||||
class RedisCountStore:
|
||||
"""Redis Count 类型存储类,用于管理访问次数统计相关的数据"""
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
|
||||
Args:
|
||||
host: Redis 主机地址
|
||||
port: Redis 端口
|
||||
db: Redis 数据库编号
|
||||
password: Redis 密码
|
||||
session_id: 会话ID
|
||||
"""
|
||||
self.r = redis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
db=db,
|
||||
password=password,
|
||||
decode_responses=True,
|
||||
encoding='utf-8'
|
||||
)
|
||||
self.uudi = session_id
|
||||
|
||||
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
|
||||
"""
|
||||
保存用户访问次数统计
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
count: 访问次数
|
||||
messages: 消息内容
|
||||
|
||||
Returns:
|
||||
str: 新生成的 session_id
|
||||
"""
|
||||
session_id = str(uuid.uuid4())
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
index_key = f'session:count:index:{end_user_id}' # 索引键
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"end_user_id": end_user_id,
|
||||
"count": int(count),
|
||||
"messages": serialize_messages(messages),
|
||||
"starttime": get_current_timestamp()
|
||||
})
|
||||
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
|
||||
|
||||
# 创建索引:end_user_id -> session_id 映射
|
||||
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
|
||||
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
||||
return session_id
|
||||
|
||||
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
|
||||
"""
|
||||
通过 end_user_id 查询访问次数统计
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
list 或 False: 如果找到返回 [count, messages],否则返回 False
|
||||
"""
|
||||
try:
|
||||
# 使用索引键快速查找
|
||||
index_key = f'session:count:index:{end_user_id}'
|
||||
|
||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||
try:
|
||||
key_type = self.r.type(index_key)
|
||||
if key_type != 'string' and key_type != 'none':
|
||||
self.r.delete(index_key)
|
||||
return False
|
||||
except Exception as type_error:
|
||||
print(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
session_id = self.r.get(index_key)
|
||||
|
||||
if not session_id:
|
||||
return False
|
||||
|
||||
# 直接获取数据
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
data = self.r.hgetall(key)
|
||||
|
||||
if not data:
|
||||
# 索引存在但数据不存在,清理索引
|
||||
self.r.delete(index_key)
|
||||
return False
|
||||
|
||||
count = data.get('count')
|
||||
messages_str = data.get('messages')
|
||||
|
||||
if count is not None:
|
||||
messages = deserialize_messages(messages_str)
|
||||
return [int(count), messages]
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"[get_sessions_count] 查询失败: {e}")
|
||||
return False
|
||||
def update_sessions_count(self, end_user_id: str, new_count: int,
|
||||
messages: Any) -> bool:
|
||||
"""
|
||||
通过 end_user_id 修改访问次数统计(优化版:使用索引)
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
new_count: 新的 count 值
|
||||
messages: 消息内容
|
||||
|
||||
Returns:
|
||||
bool: 更新成功返回 True,未找到记录返回 False
|
||||
"""
|
||||
try:
|
||||
# 使用索引键快速查找
|
||||
index_key = f'session:count:index:{end_user_id}'
|
||||
|
||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||
try:
|
||||
key_type = self.r.type(index_key)
|
||||
if key_type != 'string' and key_type != 'none':
|
||||
# 索引键类型错误,删除并返回 False
|
||||
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
||||
self.r.delete(index_key)
|
||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
return False
|
||||
except Exception as type_error:
|
||||
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
session_id = self.r.get(index_key)
|
||||
|
||||
if not session_id:
|
||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
return False
|
||||
|
||||
# 直接更新数据
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
messages_str = serialize_messages(messages)
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, 'count', int(new_count))
|
||||
pipe.hset(key, 'messages', messages_str)
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"[update_sessions_count] 更新失败: {e}")
|
||||
return False
|
||||
|
||||
def delete_all_count_sessions(self) -> int:
|
||||
"""
|
||||
删除所有 count 类型的会话
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
keys = self.r.keys('session:count:*')
|
||||
if keys:
|
||||
return self.r.delete(*keys)
|
||||
return 0
|
||||
|
||||
|
||||
class RedisSessionStore:
|
||||
"""Redis 会话存储类,用于管理会话数据"""
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
|
||||
Args:
|
||||
host: Redis 主机地址
|
||||
port: Redis 端口
|
||||
db: Redis 数据库编号
|
||||
password: Redis 密码
|
||||
session_id: 会话ID
|
||||
"""
|
||||
self.r = redis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
db=db,
|
||||
password=password,
|
||||
decode_responses=True,
|
||||
encoding='utf-8'
|
||||
)
|
||||
self.uudi = session_id
|
||||
|
||||
# ==================== 写入操作 ====================
|
||||
|
||||
def save_session(self, userid: str, messages: str, aimessages: str,
|
||||
apply_id: str, end_user_id: str) -> str:
|
||||
"""
|
||||
写入一条会话数据,返回 session_id
|
||||
|
||||
Args:
|
||||
userid: 用户ID
|
||||
messages: 用户消息
|
||||
aimessages: AI回复消息
|
||||
apply_id: 应用ID
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
str: 新生成的 session_id
|
||||
"""
|
||||
try:
|
||||
session_id = str(uuid.uuid4())
|
||||
key = generate_session_key(session_id, key_type="read")
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"sessionid": userid,
|
||||
@@ -49,177 +479,195 @@ class RedisSessionStore:
|
||||
"end_user_id": end_user_id,
|
||||
"messages": messages,
|
||||
"aimessages": aimessages,
|
||||
"starttime": starttime
|
||||
"starttime": get_current_timestamp()
|
||||
})
|
||||
|
||||
# 可选:设置过期时间(例如30天),避免数据无限增长
|
||||
# pipe.expire(key, 30 * 24 * 60 * 60)
|
||||
|
||||
# 执行批量操作
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id # 返回新生成的 session_id
|
||||
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
print(f"保存会话失败: {e}")
|
||||
print(f"[save_session] 保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
def save_sessions_batch(self, sessions_data):
|
||||
"""
|
||||
批量写入多条会话数据,返回 session_id 列表
|
||||
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id
|
||||
优化版本:批量操作,大幅提升性能
|
||||
"""
|
||||
try:
|
||||
session_ids = []
|
||||
pipe = self.r.pipeline()
|
||||
|
||||
for session in sessions_data:
|
||||
session_id = str(uuid.uuid4())
|
||||
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
key = f"session:{session_id}"
|
||||
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"sessionid": session.get('userid'),
|
||||
"apply_id": session.get('apply_id'),
|
||||
"end_user_id": session.get('end_user_id'),
|
||||
"messages": session.get('messages'),
|
||||
"aimessages": session.get('aimessages'),
|
||||
"starttime": starttime
|
||||
})
|
||||
|
||||
session_ids.append(session_id)
|
||||
|
||||
# 一次性执行所有写入操作
|
||||
results = pipe.execute()
|
||||
print(f"批量保存完成: {len(session_ids)} 条记录")
|
||||
return session_ids
|
||||
except Exception as e:
|
||||
print(f"批量保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
# ---------------- 读取 ----------------
|
||||
def get_session(self, session_id):
|
||||
# ==================== 读取操作 ====================
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
读取一条会话数据
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
Dict 或 None: 会话数据
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
key = generate_session_key(session_id)
|
||||
data = self.r.hgetall(key)
|
||||
return data if data else None
|
||||
|
||||
def get_session_apply_group(self, sessionid, apply_id, end_user_id):
|
||||
def get_all_sessions(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据
|
||||
"""
|
||||
result_items = []
|
||||
|
||||
# 遍历所有会话数据
|
||||
for key in self.r.keys('session:*'):
|
||||
data = self.r.hgetall(key)
|
||||
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 检查三个条件是否都匹配
|
||||
if (data.get('sessionid') == sessionid and
|
||||
data.get('apply_id') == apply_id and
|
||||
data.get('end_user_id') == end_user_id):
|
||||
result_items.append(data)
|
||||
|
||||
return result_items
|
||||
|
||||
def get_all_sessions(self):
|
||||
"""
|
||||
获取所有会话数据
|
||||
获取所有会话数据(不包括 count 和 write 类型)
|
||||
|
||||
Returns:
|
||||
Dict: 所有会话数据,key 为 session_id
|
||||
"""
|
||||
sessions = {}
|
||||
for key in self.r.keys('session:*'):
|
||||
sid = key.split(':')[1]
|
||||
sessions[sid] = self.get_session(sid)
|
||||
# 排除 count 和 write 类型的 key
|
||||
if ':count:' not in key and ':write:' not in key:
|
||||
sid = key.split(':')[1]
|
||||
sessions[sid] = self.get_session(sid)
|
||||
return sessions
|
||||
|
||||
# ---------------- 更新 ----------------
|
||||
def update_session(self, session_id, field, value):
|
||||
def find_user_apply_group(self, sessionid: str, apply_id: str,
|
||||
end_user_id: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条
|
||||
|
||||
Args:
|
||||
sessionid: 会话ID(支持模糊匹配)
|
||||
apply_id: 应用ID
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...]
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
keys = self.r.keys('session:*')
|
||||
if not keys:
|
||||
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
# 排除 count 和 write 类型
|
||||
if ':count:' not in key and ':write:' not in key:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 筛选符合条件的数据
|
||||
matched_items = []
|
||||
for data in all_data:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
if (data.get('apply_id') == apply_id and
|
||||
data.get('end_user_id') == end_user_id):
|
||||
# 支持模糊匹配或完全匹配 sessionid
|
||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||
matched_items.append(format_session_data(data, include_time=True))
|
||||
|
||||
# 排序、限制数量并移除时间字段
|
||||
result_items = sort_and_limit_results(matched_items, limit=6)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
|
||||
# ==================== 更新操作 ====================
|
||||
|
||||
def update_session(self, session_id: str, field: str, value: Any) -> bool:
|
||||
"""
|
||||
更新单个字段
|
||||
优化版本:使用 pipeline 减少网络往返
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
field: 字段名
|
||||
value: 字段值
|
||||
|
||||
Returns:
|
||||
bool: 是否更新成功
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
key = generate_session_key(session_id)
|
||||
pipe = self.r.pipeline()
|
||||
pipe.exists(key)
|
||||
pipe.hset(key, field, value)
|
||||
results = pipe.execute()
|
||||
return bool(results[0]) # 返回 key 是否存在
|
||||
return bool(results[0])
|
||||
|
||||
# ---------------- 删除 ----------------
|
||||
def delete_session(self, session_id):
|
||||
# ==================== 删除操作 ====================
|
||||
|
||||
def delete_session(self, session_id: str) -> int:
|
||||
"""
|
||||
删除单条会话
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
key = generate_session_key(session_id)
|
||||
return self.r.delete(key)
|
||||
|
||||
def delete_all_sessions(self):
|
||||
def delete_all_sessions(self) -> int:
|
||||
"""
|
||||
删除所有会话
|
||||
删除所有会话(不包括 count 和 write 类型)
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
keys = self.r.keys('session:*')
|
||||
if keys:
|
||||
return self.r.delete(*keys)
|
||||
# 过滤掉 count 和 write 类型
|
||||
keys_to_delete = [k for k in keys if ':count:' not in k and ':write:' not in k]
|
||||
if keys_to_delete:
|
||||
return self.r.delete(*keys_to_delete)
|
||||
return 0
|
||||
|
||||
def delete_duplicate_sessions(self):
|
||||
def delete_duplicate_sessions(self) -> int:
|
||||
"""
|
||||
删除重复会话数据,条件:
|
||||
"sessionid"、"user_id"、"end_user_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除
|
||||
优化版本:使用 pipeline 批量操作,确保在1秒内完成
|
||||
删除重复会话数据(不包括 count 和 write 类型)
|
||||
条件:sessionid、user_id、end_user_id、messages、aimessages 五个字段都相同的只保留一个
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# 第一步:使用 pipeline 批量获取所有 key
|
||||
keys = self.r.keys('session:*')
|
||||
|
||||
if not keys:
|
||||
print("[delete_duplicate_sessions] 没有会话数据")
|
||||
return 0
|
||||
|
||||
# 第二步:使用 pipeline 批量获取所有数据
|
||||
# 批量获取所有数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
# 排除 count 和 write 类型
|
||||
if ':count:' not in key and ':write:' not in key:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 第三步:在内存中识别重复数据
|
||||
seen = {} # 用字典记录:identifier -> key(保留第一个出现的 key)
|
||||
keys_to_delete = [] # 需要删除的 key 列表
|
||||
# 识别重复数据
|
||||
seen = {}
|
||||
keys_to_delete = []
|
||||
|
||||
for key, data in zip(keys, all_data, strict=False):
|
||||
for key, data in zip([k for k in keys if ':count:' not in k and ':write:' not in k], all_data, strict=False):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 获取五个字段的值
|
||||
sessionid = data.get('sessionid', '')
|
||||
user_id = data.get('id', '')
|
||||
end_user_id = data.get('end_user_id', '')
|
||||
messages = data.get('messages', '')
|
||||
aimessages = data.get('aimessages', '')
|
||||
|
||||
# 用五元组作为唯一标识
|
||||
identifier = (sessionid, user_id, end_user_id, messages, aimessages)
|
||||
identifier = (
|
||||
data.get('sessionid', ''),
|
||||
data.get('id', ''),
|
||||
data.get('end_user_id', ''),
|
||||
data.get('messages', ''),
|
||||
data.get('aimessages', '')
|
||||
)
|
||||
|
||||
if identifier in seen:
|
||||
# 重复,标记为待删除
|
||||
keys_to_delete.append(key)
|
||||
else:
|
||||
# 第一次出现,记录
|
||||
seen[identifier] = key
|
||||
|
||||
# 第四步:使用 pipeline 批量删除重复的 key
|
||||
# 批量删除重复的 key
|
||||
deleted_count = 0
|
||||
if keys_to_delete:
|
||||
# 分批删除,避免单次操作过大
|
||||
batch_size = 1000
|
||||
for i in range(0, len(keys_to_delete), batch_size):
|
||||
batch = keys_to_delete[i:i + batch_size]
|
||||
@@ -233,79 +681,28 @@ class RedisSessionStore:
|
||||
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||
return deleted_count
|
||||
|
||||
def find_user_session(self, sessionid):
|
||||
user_id = sessionid
|
||||
|
||||
result_items = []
|
||||
for key, values in store.get_all_sessions().items():
|
||||
history = {}
|
||||
if user_id == str(values['sessionid']):
|
||||
history["Query"] = values['messages']
|
||||
history["Answer"] = values['aimessages']
|
||||
result_items.append(history)
|
||||
|
||||
if len(result_items) <= 1:
|
||||
result_items = []
|
||||
return (result_items)
|
||||
|
||||
def find_user_apply_group(self, sessionid, apply_id, end_user_id):
|
||||
"""
|
||||
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据,返回最新的6条
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
# 使用 pipeline 批量获取数据,提高性能
|
||||
keys = self.r.keys('session:*')
|
||||
|
||||
if not keys:
|
||||
print(f"查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 使用 pipeline 批量获取所有 hash 数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 解析并筛选符合条件的数据
|
||||
matched_items = []
|
||||
for data in all_data:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 检查是否符合三个条件
|
||||
|
||||
if (data.get('apply_id') == apply_id and
|
||||
data.get('end_user_id') == end_user_id):
|
||||
# 支持模糊匹配 sessionid 或者完全匹配
|
||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||
matched_items.append({
|
||||
"Query": self._fix_encoding(data.get('messages')),
|
||||
"Answer": self._fix_encoding(data.get('aimessages')),
|
||||
"starttime": data.get('starttime', '')
|
||||
})
|
||||
# 按时间降序排序(最新的在前)
|
||||
matched_items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||
# 只保留最新的6条
|
||||
result_items = matched_items[:6]
|
||||
# # 移除 starttime 字段
|
||||
for item in result_items:
|
||||
item.pop('starttime', None)
|
||||
|
||||
# 如果结果少于等于1条,返回空列表
|
||||
if len(result_items) <= 1:
|
||||
result_items = []
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
|
||||
|
||||
# 全局实例
|
||||
store = RedisSessionStore(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
||||
session_id=str(uuid.uuid4())
|
||||
)
|
||||
)
|
||||
|
||||
write_store = RedisWriteStore(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
||||
session_id=str(uuid.uuid4())
|
||||
)
|
||||
|
||||
count_store = RedisCountStore(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
||||
session_id=str(uuid.uuid4())
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ Write Tools for Memory Knowledge Extraction Pipeline
|
||||
This module provides the main write function for executing the knowledge extraction
|
||||
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
@@ -123,23 +124,48 @@ async def write(
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating indexes: {e}", exc_info=True)
|
||||
|
||||
# 添加死锁重试机制
|
||||
max_retries = 3
|
||||
retry_delay = 1 # 秒
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
success = await save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes=all_dialogue_nodes,
|
||||
chunk_nodes=all_chunk_nodes,
|
||||
statement_nodes=all_statement_nodes,
|
||||
entity_nodes=all_entity_nodes,
|
||||
statement_chunk_edges=all_statement_chunk_edges,
|
||||
statement_entity_edges=all_statement_entity_edges,
|
||||
entity_edges=all_entity_entity_edges,
|
||||
connector=neo4j_connector
|
||||
)
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
break
|
||||
else:
|
||||
logger.warning("Failed to save some data to Neo4j")
|
||||
if attempt < max_retries - 1:
|
||||
logger.info(f"Retrying... (attempt {attempt + 2}/{max_retries})")
|
||||
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
# 检查是否是死锁错误
|
||||
if "DeadlockDetected" in error_msg or "deadlock" in error_msg.lower():
|
||||
if attempt < max_retries - 1:
|
||||
logger.warning(f"Deadlock detected, retrying... (attempt {attempt + 2}/{max_retries})")
|
||||
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
||||
else:
|
||||
logger.error(f"Failed after {max_retries} attempts due to deadlock: {e}")
|
||||
raise
|
||||
else:
|
||||
# 非死锁错误,直接抛出
|
||||
raise
|
||||
|
||||
try:
|
||||
success = await save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes=all_dialogue_nodes,
|
||||
chunk_nodes=all_chunk_nodes,
|
||||
statement_nodes=all_statement_nodes,
|
||||
entity_nodes=all_entity_nodes,
|
||||
statement_chunk_edges=all_statement_chunk_edges,
|
||||
statement_entity_edges=all_statement_entity_edges,
|
||||
entity_edges=all_entity_entity_edges,
|
||||
connector=neo4j_connector
|
||||
)
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
else:
|
||||
logger.warning("Failed to save some data to Neo4j")
|
||||
finally:
|
||||
await neo4j_connector.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing Neo4j connector: {e}")
|
||||
|
||||
log_time("Neo4j Database Save", time.time() - step_start, log_file)
|
||||
|
||||
|
||||
@@ -58,6 +58,12 @@ from app.core.memory.models.triplet_models import (
|
||||
TripletExtractionResponse,
|
||||
)
|
||||
|
||||
# Ontology models
|
||||
from app.core.memory.models.ontology_models import (
|
||||
OntologyClass,
|
||||
OntologyExtractionResponse,
|
||||
)
|
||||
|
||||
# Variable configuration models
|
||||
from app.core.memory.models.variate_config import (
|
||||
StatementExtractionConfig,
|
||||
@@ -105,6 +111,9 @@ __all__ = [
|
||||
"Entity",
|
||||
"Triplet",
|
||||
"TripletExtractionResponse",
|
||||
# Ontology models
|
||||
"OntologyClass",
|
||||
"OntologyExtractionResponse",
|
||||
# Variable configuration
|
||||
"StatementExtractionConfig",
|
||||
"ForgettingEngineConfig",
|
||||
|
||||
@@ -413,7 +413,8 @@ class ExtractedEntityNode(Node):
|
||||
description="Entity aliases - alternative names for this entity"
|
||||
)
|
||||
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
|
||||
fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
||||
|
||||
|
||||
135
api/app/core/memory/models/ontology_models.py
Normal file
135
api/app/core/memory/models/ontology_models.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Models for ontology classes and extraction responses.
|
||||
|
||||
This module contains Pydantic models for representing extracted ontology classes
|
||||
from scenario descriptions, following OWL ontology engineering standards.
|
||||
|
||||
Classes:
|
||||
OntologyClass: Represents an extracted ontology class
|
||||
OntologyExtractionResponse: Response model containing extracted ontology classes
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class OntologyClass(BaseModel):
|
||||
"""Represents an extracted ontology class from scenario description.
|
||||
|
||||
An ontology class represents an abstract category or concept in a domain,
|
||||
following OWL ontology engineering standards and naming conventions.
|
||||
|
||||
Attributes:
|
||||
id: Unique string identifier for the ontology class
|
||||
name: Name of the class in PascalCase format (e.g., 'MedicalProcedure')
|
||||
name_chinese: Chinese translation of the class name (e.g., '医疗程序')
|
||||
description: Textual description of the class
|
||||
examples: List of concrete instance examples of this class
|
||||
parent_class: Optional name of the parent class in the hierarchy
|
||||
entity_type: Type/category of the entity (e.g., 'Person', 'Organization', 'Concept')
|
||||
domain: Domain this class belongs to (e.g., 'Healthcare', 'Education')
|
||||
|
||||
Config:
|
||||
extra: Ignore extra fields from LLM output
|
||||
"""
|
||||
model_config = ConfigDict(extra='ignore')
|
||||
|
||||
id: str = Field(
|
||||
default_factory=lambda: uuid4().hex,
|
||||
description="Unique identifier for the ontology class"
|
||||
)
|
||||
name: str = Field(
|
||||
...,
|
||||
description="Name of the class in PascalCase format"
|
||||
)
|
||||
name_chinese: Optional[str] = Field(
|
||||
None,
|
||||
description="Chinese translation of the class name"
|
||||
)
|
||||
description: str = Field(
|
||||
...,
|
||||
description="Description of the class"
|
||||
)
|
||||
examples: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="List of concrete instance examples"
|
||||
)
|
||||
parent_class: Optional[str] = Field(
|
||||
None,
|
||||
description="Name of the parent class in the hierarchy"
|
||||
)
|
||||
entity_type: str = Field(
|
||||
...,
|
||||
description="Type/category of the entity"
|
||||
)
|
||||
domain: str = Field(
|
||||
...,
|
||||
description="Domain this class belongs to"
|
||||
)
|
||||
|
||||
@field_validator('name')
|
||||
@classmethod
|
||||
def validate_pascal_case(cls, v: str) -> str:
|
||||
"""Validate that the class name follows PascalCase convention.
|
||||
|
||||
PascalCase rules:
|
||||
- Must start with an uppercase letter
|
||||
- Cannot contain spaces
|
||||
- Should not contain special characters except underscores
|
||||
|
||||
Args:
|
||||
v: The class name to validate
|
||||
|
||||
Returns:
|
||||
The validated class name
|
||||
|
||||
Raises:
|
||||
ValueError: If the name doesn't follow PascalCase convention
|
||||
"""
|
||||
if not v:
|
||||
raise ValueError("Class name cannot be empty")
|
||||
|
||||
if not v[0].isupper():
|
||||
raise ValueError(
|
||||
f"Class name '{v}' must start with an uppercase letter (PascalCase)"
|
||||
)
|
||||
|
||||
if ' ' in v:
|
||||
raise ValueError(
|
||||
f"Class name '{v}' cannot contain spaces (PascalCase)"
|
||||
)
|
||||
|
||||
# Check for invalid characters (allow alphanumeric and underscore only)
|
||||
if not all(c.isalnum() or c == '_' for c in v):
|
||||
raise ValueError(
|
||||
f"Class name '{v}' contains invalid characters. "
|
||||
"Only alphanumeric characters and underscores are allowed"
|
||||
)
|
||||
|
||||
return v
|
||||
|
||||
|
||||
class OntologyExtractionResponse(BaseModel):
|
||||
"""Response model for ontology extraction from LLM.
|
||||
|
||||
This model represents the structured output from the LLM when
|
||||
extracting ontology classes from scenario descriptions.
|
||||
|
||||
Attributes:
|
||||
classes: List of extracted ontology classes
|
||||
domain: Domain/field the scenario belongs to
|
||||
|
||||
Config:
|
||||
extra: Ignore extra fields from LLM output
|
||||
"""
|
||||
model_config = ConfigDict(extra='ignore')
|
||||
|
||||
classes: List[OntologyClass] = Field(
|
||||
default_factory=list,
|
||||
description="List of extracted ontology classes"
|
||||
)
|
||||
domain: str = Field(
|
||||
...,
|
||||
description="Domain/field the scenario belongs to"
|
||||
)
|
||||
@@ -134,42 +134,45 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
|
||||
if len(desc_b) > len(desc_a):
|
||||
canonical.description = desc_b
|
||||
# 合并事实摘要:统一保留一个“实体: name”行,来源行去重保序
|
||||
fact_a = getattr(canonical, "fact_summary", "") or ""
|
||||
fact_b = getattr(ent, "fact_summary", "") or ""
|
||||
def _extract_sources(txt: str) -> List[str]:
|
||||
sources: List[str] = []
|
||||
if not txt:
|
||||
return sources
|
||||
for line in str(txt).splitlines():
|
||||
ln = line.strip()
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_a = getattr(canonical, "fact_summary", "") or ""
|
||||
# fact_b = getattr(ent, "fact_summary", "") or ""
|
||||
# def _extract_sources(txt: str) -> List[str]:
|
||||
# sources: List[str] = []
|
||||
# if not txt:
|
||||
# return sources
|
||||
# for line in str(txt).splitlines():
|
||||
# ln = line.strip()
|
||||
# 支持“来源:”或“来源:”前缀
|
||||
m = re.match(r"^来源[::]\s*(.+)$", ln)
|
||||
if m:
|
||||
content = m.group(1).strip()
|
||||
if content:
|
||||
sources.append(content)
|
||||
# m = re.match(r"^来源[::]\s*(.+)$", ln)
|
||||
# if m:
|
||||
# content = m.group(1).strip()
|
||||
# if content:
|
||||
# sources.append(content)
|
||||
# 如果不存在“来源”前缀,则将整体文本视为一个来源片段,避免信息丢失
|
||||
if not sources and txt.strip():
|
||||
sources.append(txt.strip())
|
||||
return sources
|
||||
# if not sources and txt.strip():
|
||||
# sources.append(txt.strip())
|
||||
# return sources
|
||||
try:
|
||||
src_a = _extract_sources(fact_a)
|
||||
src_b = _extract_sources(fact_b)
|
||||
seen = set()
|
||||
merged_sources: List[str] = []
|
||||
for s in src_a + src_b:
|
||||
if s and s not in seen:
|
||||
seen.add(s)
|
||||
merged_sources.append(s)
|
||||
if merged_sources:
|
||||
name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
|
||||
canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
|
||||
elif fact_b and not fact_a:
|
||||
canonical.fact_summary = fact_b
|
||||
# src_a = _extract_sources(fact_a)
|
||||
# src_b = _extract_sources(fact_b)
|
||||
# seen = set()
|
||||
# merged_sources: List[str] = []
|
||||
# for s in src_a + src_b:
|
||||
# if s and s not in seen:
|
||||
# seen.add(s)
|
||||
# merged_sources.append(s)
|
||||
# if merged_sources:
|
||||
# name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
|
||||
# canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
|
||||
# elif fact_b and not fact_a:
|
||||
# canonical.fact_summary = fact_b
|
||||
pass
|
||||
except Exception:
|
||||
# 兜底:若解析失败,保留较长文本
|
||||
if len(fact_b) > len(fact_a):
|
||||
canonical.fact_summary = fact_b
|
||||
# if len(fact_b) > len(fact_a):
|
||||
# canonical.fact_summary = fact_b
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -145,10 +145,13 @@ def _choose_canonical(a: ExtractedEntityNode, b: ExtractedEntityNode) -> int: #
|
||||
# 2. 第二优先级:按“描述+事实摘要”的总长度排序(内容越长,信息越完整)
|
||||
desc_a = (getattr(a, "description", "") or "")
|
||||
desc_b = (getattr(b, "description", "") or "")
|
||||
fact_a = (getattr(a, "fact_summary", "") or "")
|
||||
fact_b = (getattr(b, "fact_summary", "") or "")
|
||||
score_a = len(desc_a) + len(fact_a)
|
||||
score_b = len(desc_b) + len(fact_b)
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_a = (getattr(a, "fact_summary", "") or "")
|
||||
# fact_b = (getattr(b, "fact_summary", "") or "")
|
||||
# score_a = len(desc_a) + len(fact_a)
|
||||
# score_b = len(desc_b) + len(fact_b)
|
||||
score_a = len(desc_a)
|
||||
score_b = len(desc_b)
|
||||
if score_a != score_b:
|
||||
return 0 if score_a >= score_b else 1
|
||||
return 0
|
||||
@@ -189,7 +192,8 @@ async def _judge_pair(
|
||||
"entity_type": getattr(a, "entity_type", None),
|
||||
"description": getattr(a, "description", None),
|
||||
"aliases": getattr(a, "aliases", None) or [],
|
||||
"fact_summary": getattr(a, "fact_summary", None),
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# "fact_summary": getattr(a, "fact_summary", None),
|
||||
"connect_strength": getattr(a, "connect_strength", None),
|
||||
}
|
||||
entity_b = {
|
||||
@@ -197,7 +201,8 @@ async def _judge_pair(
|
||||
"entity_type": getattr(b, "entity_type", None),
|
||||
"description": getattr(b, "description", None),
|
||||
"aliases": getattr(b, "aliases", None) or [],
|
||||
"fact_summary": getattr(b, "fact_summary", None),
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# "fact_summary": getattr(b, "fact_summary", None),
|
||||
"connect_strength": getattr(b, "connect_strength", None),
|
||||
}
|
||||
# 5. 渲染LLM提示词(用工具函数填充模板,包含实体信息、上下文、输出格式)
|
||||
@@ -248,7 +253,8 @@ async def _judge_pair_disamb(
|
||||
"entity_type": getattr(a, "entity_type", None),
|
||||
"description": getattr(a, "description", None),
|
||||
"aliases": getattr(a, "aliases", None) or [],
|
||||
"fact_summary": getattr(a, "fact_summary", None),
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# "fact_summary": getattr(a, "fact_summary", None),
|
||||
"connect_strength": getattr(a, "connect_strength", None),
|
||||
}
|
||||
entity_b = {
|
||||
@@ -256,7 +262,8 @@ async def _judge_pair_disamb(
|
||||
"entity_type": getattr(b, "entity_type", None),
|
||||
"description": getattr(b, "description", None),
|
||||
"aliases": getattr(b, "aliases", None) or [],
|
||||
"fact_summary": getattr(b, "fact_summary", None),
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# "fact_summary": getattr(b, "fact_summary", None),
|
||||
"connect_strength": getattr(b, "connect_strength", None),
|
||||
}
|
||||
prompt = render_entity_dedup_prompt(
|
||||
|
||||
@@ -72,7 +72,8 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
|
||||
description=row.get("description") or "",
|
||||
aliases=row.get("aliases") or [],
|
||||
name_embedding=row.get("name_embedding") or [],
|
||||
fact_summary=row.get("fact_summary") or "",
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_summary=row.get("fact_summary") or "",
|
||||
connect_strength=row.get("connect_strength") or "",
|
||||
)
|
||||
|
||||
|
||||
@@ -1085,7 +1085,8 @@ class ExtractionOrchestrator:
|
||||
entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type
|
||||
description=getattr(entity, 'description', ''), # 添加必需的 description 字段
|
||||
example=getattr(entity, 'example', ''), # 新增:传递示例字段
|
||||
fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
||||
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
||||
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
||||
name_embedding=getattr(entity, 'name_embedding', None),
|
||||
|
||||
@@ -8,4 +8,5 @@
|
||||
- TemporalExtractor: 时间信息提取
|
||||
- EmbeddingGenerator: 嵌入向量生成
|
||||
- MemorySummaryGenerator: 记忆摘要生成
|
||||
- OntologyExtractor: 本体类提取
|
||||
"""
|
||||
|
||||
@@ -14,6 +14,34 @@ from pydantic import Field
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
# 支持的语言列表和默认回退值
|
||||
SUPPORTED_LANGUAGES = {"zh", "en"}
|
||||
FALLBACK_LANGUAGE = "en"
|
||||
|
||||
|
||||
def validate_language(language: Optional[str]) -> str:
|
||||
"""
|
||||
校验语言参数,确保其为有效值。
|
||||
|
||||
Args:
|
||||
language: 待校验的语言代码
|
||||
|
||||
Returns:
|
||||
有效的语言代码("zh" 或 "en")
|
||||
"""
|
||||
if language is None:
|
||||
return FALLBACK_LANGUAGE
|
||||
|
||||
lang = str(language).lower().strip()
|
||||
if lang in SUPPORTED_LANGUAGES:
|
||||
return lang
|
||||
|
||||
logger.warning(
|
||||
f"无效的语言参数 '{language}',已回退到默认值 '{FALLBACK_LANGUAGE}'。"
|
||||
f"支持的语言: {SUPPORTED_LANGUAGES}"
|
||||
)
|
||||
return FALLBACK_LANGUAGE
|
||||
|
||||
|
||||
class MemorySummaryResponse(RobustLLMResponse):
|
||||
"""Structured response for summary generation per chunk.
|
||||
@@ -31,7 +59,8 @@ class MemorySummaryResponse(RobustLLMResponse):
|
||||
|
||||
async def generate_title_and_type_for_summary(
|
||||
content: str,
|
||||
llm_client
|
||||
llm_client,
|
||||
language: str = None
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
为MemorySummary生成标题和类型
|
||||
@@ -41,11 +70,18 @@ async def generate_title_and_type_for_summary(
|
||||
Args:
|
||||
content: Summary的内容文本
|
||||
llm_client: LLM客户端实例
|
||||
language: 生成标题使用的语言 ("zh" 中文, "en" 英文),如果为None则从配置读取
|
||||
|
||||
Returns:
|
||||
(标题, 类型)元组
|
||||
"""
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_episodic_title_and_type_prompt
|
||||
from app.core.config import settings
|
||||
|
||||
# 如果没有指定语言,从配置中读取,并校验有效性
|
||||
if language is None:
|
||||
language = settings.DEFAULT_LANGUAGE
|
||||
language = validate_language(language)
|
||||
|
||||
# 定义有效的类型集合
|
||||
VALID_TYPES = {
|
||||
@@ -57,13 +93,19 @@ async def generate_title_and_type_for_summary(
|
||||
}
|
||||
DEFAULT_TYPE = "conversation" # 默认类型
|
||||
|
||||
# 根据语言设置默认标题
|
||||
DEFAULT_TITLE = "空内容" if language == "zh" else "Empty Content"
|
||||
PARSE_ERROR_TITLE = "解析失败" if language == "zh" else "Parse Failed"
|
||||
ERROR_TITLE = "错误" if language == "zh" else "Error"
|
||||
UNKNOWN_TITLE = "未知标题" if language == "zh" else "Unknown Title"
|
||||
|
||||
try:
|
||||
if not content:
|
||||
logger.warning("content为空,无法生成标题和类型")
|
||||
return ("空内容", DEFAULT_TYPE)
|
||||
logger.warning(f"content为空,无法生成标题和类型 (language={language})")
|
||||
return (DEFAULT_TITLE, DEFAULT_TYPE)
|
||||
|
||||
# 1. 渲染Jinja2提示词模板
|
||||
prompt = await render_episodic_title_and_type_prompt(content)
|
||||
# 1. 渲染Jinja2提示词模板,传递语言参数
|
||||
prompt = await render_episodic_title_and_type_prompt(content, language=language)
|
||||
|
||||
# 2. 调用LLM生成标题和类型
|
||||
messages = [
|
||||
@@ -102,7 +144,7 @@ async def generate_title_and_type_for_summary(
|
||||
json_str = json_str.strip()
|
||||
|
||||
result_data = json.loads(json_str)
|
||||
title = result_data.get("title", "未知标题")
|
||||
title = result_data.get("title", UNKNOWN_TITLE)
|
||||
episodic_type_raw = result_data.get("type", DEFAULT_TYPE)
|
||||
|
||||
# 5. 校验和归一化类型
|
||||
@@ -130,16 +172,16 @@ async def generate_title_and_type_for_summary(
|
||||
f"已归一化为 '{episodic_type}'"
|
||||
)
|
||||
|
||||
logger.info(f"成功生成标题和类型: title={title}, type={episodic_type}")
|
||||
logger.info(f"成功生成标题和类型 (language={language}): title={title}, type={episodic_type}")
|
||||
return (title, episodic_type)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"无法解析LLM响应为JSON: {full_response}")
|
||||
return ("解析失败", DEFAULT_TYPE)
|
||||
logger.error(f"无法解析LLM响应为JSON (language={language}): {full_response}")
|
||||
return (PARSE_ERROR_TITLE, DEFAULT_TYPE)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成标题和类型时出错: {str(e)}", exc_info=True)
|
||||
return ("错误", DEFAULT_TYPE)
|
||||
logger.error(f"生成标题和类型时出错 (language={language}): {str(e)}", exc_info=True)
|
||||
return (ERROR_TITLE, DEFAULT_TYPE)
|
||||
|
||||
async def _process_chunk_summary(
|
||||
dialog: DialogData,
|
||||
@@ -153,11 +195,16 @@ async def _process_chunk_summary(
|
||||
return None
|
||||
|
||||
try:
|
||||
# 从配置中获取语言设置(只获取一次,复用),并校验有效性
|
||||
from app.core.config import settings
|
||||
language = validate_language(settings.DEFAULT_LANGUAGE)
|
||||
|
||||
# Render prompt via Jinja2 for a single chunk
|
||||
prompt_content = await render_memory_summary_prompt(
|
||||
chunk_texts=chunk.content,
|
||||
json_schema=MemorySummaryResponse.model_json_schema(),
|
||||
max_words=200,
|
||||
language=language,
|
||||
)
|
||||
|
||||
messages = [
|
||||
@@ -178,9 +225,10 @@ async def _process_chunk_summary(
|
||||
try:
|
||||
title, episodic_type = await generate_title_and_type_for_summary(
|
||||
content=summary_text,
|
||||
llm_client=llm_client
|
||||
llm_client=llm_client,
|
||||
language=language
|
||||
)
|
||||
logger.info(f"Generated title and type for MemorySummary: title={title}, type={episodic_type}")
|
||||
logger.info(f"Generated title and type for MemorySummary (language={language}): title={title}, type={episodic_type}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to generate title and type for chunk {chunk.id}: {e}")
|
||||
# Continue without title and type
|
||||
|
||||
@@ -0,0 +1,482 @@
|
||||
"""Ontology class extraction from scenario descriptions using LLM.
|
||||
|
||||
This module provides the OntologyExtractor class for extracting ontology classes
|
||||
from natural language scenario descriptions. It uses LLM-driven extraction combined
|
||||
with two-layer validation (string validation + OWL semantic validation).
|
||||
|
||||
Classes:
|
||||
OntologyExtractor: Extracts ontology classes from scenario descriptions
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.models.ontology_models import (
|
||||
OntologyClass,
|
||||
OntologyExtractionResponse,
|
||||
)
|
||||
from app.core.memory.utils.validation.ontology_validator import OntologyValidator
|
||||
from app.core.memory.utils.validation.owl_validator import OWLValidator
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_ontology_extraction_prompt
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OntologyExtractor:
|
||||
"""Extractor for ontology classes from scenario descriptions.
|
||||
|
||||
This extractor uses LLM to identify abstract classes and concepts from
|
||||
natural language scenario descriptions, following OWL ontology engineering
|
||||
standards. It performs two-layer validation:
|
||||
1. String validation (naming conventions, reserved words, duplicates)
|
||||
2. OWL semantic validation (consistency checking, circular inheritance)
|
||||
|
||||
Attributes:
|
||||
llm_client: OpenAI client for LLM calls
|
||||
validator: String validator for class names and descriptions
|
||||
owl_validator: OWL validator for semantic validation
|
||||
"""
|
||||
|
||||
def __init__(self, llm_client: OpenAIClient):
|
||||
"""Initialize the OntologyExtractor.
|
||||
|
||||
Args:
|
||||
llm_client: OpenAIClient instance for LLM processing
|
||||
"""
|
||||
self.llm_client = llm_client
|
||||
self.validator = OntologyValidator()
|
||||
self.owl_validator = OWLValidator()
|
||||
|
||||
logger.info("OntologyExtractor initialized")
|
||||
|
||||
async def extract_ontology_classes(
|
||||
self,
|
||||
scenario: str,
|
||||
domain: Optional[str] = None,
|
||||
max_classes: int = 15,
|
||||
min_classes: int = 5,
|
||||
enable_owl_validation: bool = True,
|
||||
llm_temperature: float = 0.3,
|
||||
llm_max_tokens: int = 2000,
|
||||
max_description_length: int = 500,
|
||||
timeout: Optional[float] = None,
|
||||
) -> OntologyExtractionResponse:
|
||||
"""Extract ontology classes from a scenario description.
|
||||
|
||||
This is the main extraction method that orchestrates the entire process:
|
||||
1. Call LLM to extract ontology classes
|
||||
2. Perform first-layer validation (string validation and cleaning)
|
||||
3. Perform second-layer validation (OWL semantic validation)
|
||||
4. Filter invalid classes based on validation errors
|
||||
5. Return validated ontology classes
|
||||
|
||||
Args:
|
||||
scenario: Natural language scenario description
|
||||
domain: Optional domain hint (e.g., "Healthcare", "Education")
|
||||
max_classes: Maximum number of classes to extract (default: 15)
|
||||
min_classes: Minimum number of classes to extract (default: 5)
|
||||
enable_owl_validation: Whether to enable OWL validation (default: True)
|
||||
llm_temperature: LLM temperature parameter (default: 0.3)
|
||||
llm_max_tokens: LLM max tokens parameter (default: 2000)
|
||||
max_description_length: Maximum description length (default: 500)
|
||||
timeout: Optional timeout in seconds for LLM call (default: None, no timeout)
|
||||
|
||||
Returns:
|
||||
OntologyExtractionResponse containing validated ontology classes
|
||||
|
||||
Raises:
|
||||
ValueError: If scenario is empty or invalid
|
||||
asyncio.TimeoutError: If extraction times out
|
||||
|
||||
Examples:
|
||||
>>> extractor = OntologyExtractor(llm_client)
|
||||
>>> response = await extractor.extract_ontology_classes(
|
||||
... scenario="A hospital manages patient records...",
|
||||
... domain="Healthcare",
|
||||
... max_classes=10,
|
||||
... timeout=30.0
|
||||
... )
|
||||
>>> len(response.classes)
|
||||
7
|
||||
"""
|
||||
# Start timing
|
||||
start_time = time.time()
|
||||
|
||||
# Validate input
|
||||
if not scenario or not scenario.strip():
|
||||
logger.error("Scenario description is empty")
|
||||
raise ValueError("Scenario description cannot be empty")
|
||||
|
||||
scenario = scenario.strip()
|
||||
|
||||
logger.info(
|
||||
f"Starting ontology extraction - scenario_length={len(scenario)}, "
|
||||
f"domain={domain}, max_classes={max_classes}, min_classes={min_classes}, "
|
||||
f"timeout={timeout}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Step 1: Call LLM for extraction with timeout
|
||||
logger.info("Step 1: Calling LLM for ontology extraction")
|
||||
llm_start_time = time.time()
|
||||
|
||||
if timeout is not None:
|
||||
# Wrap LLM call with timeout
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
self._call_llm_for_extraction(
|
||||
scenario=scenario,
|
||||
domain=domain,
|
||||
max_classes=max_classes,
|
||||
llm_temperature=llm_temperature,
|
||||
llm_max_tokens=llm_max_tokens,
|
||||
),
|
||||
timeout=timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
llm_duration = time.time() - llm_start_time
|
||||
logger.error(
|
||||
f"LLM extraction timed out after {timeout} seconds "
|
||||
f"(actual duration: {llm_duration:.2f}s)"
|
||||
)
|
||||
# Return empty response on timeout
|
||||
return OntologyExtractionResponse(
|
||||
classes=[],
|
||||
domain=domain or "Unknown",
|
||||
)
|
||||
else:
|
||||
# No timeout specified, call directly
|
||||
response = await self._call_llm_for_extraction(
|
||||
scenario=scenario,
|
||||
domain=domain,
|
||||
max_classes=max_classes,
|
||||
llm_temperature=llm_temperature,
|
||||
llm_max_tokens=llm_max_tokens,
|
||||
)
|
||||
|
||||
llm_duration = time.time() - llm_start_time
|
||||
logger.info(
|
||||
f"LLM returned {len(response.classes)} classes in {llm_duration:.2f}s"
|
||||
)
|
||||
|
||||
# Step 2: First-layer validation (string validation and cleaning)
|
||||
logger.info("Step 2: Performing first-layer validation (string validation)")
|
||||
validation_start_time = time.time()
|
||||
|
||||
response = self._validate_and_clean(
|
||||
response=response,
|
||||
max_description_length=max_description_length,
|
||||
)
|
||||
|
||||
validation_duration = time.time() - validation_start_time
|
||||
logger.info(
|
||||
f"After first-layer validation: {len(response.classes)} classes remain "
|
||||
f"(validation took {validation_duration:.2f}s)"
|
||||
)
|
||||
|
||||
# Check if we have enough classes after first-layer validation
|
||||
if len(response.classes) < min_classes:
|
||||
logger.warning(
|
||||
f"Only {len(response.classes)} classes remain after validation, "
|
||||
f"which is below minimum of {min_classes}"
|
||||
)
|
||||
|
||||
# Step 3: Second-layer validation (OWL semantic validation)
|
||||
if enable_owl_validation and response.classes:
|
||||
logger.info("Step 3: Performing second-layer validation (OWL validation)")
|
||||
owl_start_time = time.time()
|
||||
|
||||
is_valid, errors, world = self.owl_validator.validate_ontology_classes(
|
||||
classes=response.classes,
|
||||
)
|
||||
|
||||
owl_duration = time.time() - owl_start_time
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
f"OWL validation found {len(errors)} issues in {owl_duration:.2f}s: {errors}"
|
||||
)
|
||||
|
||||
# Filter invalid classes based on errors
|
||||
response = self._filter_invalid_classes(
|
||||
response=response,
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"After second-layer validation: {len(response.classes)} classes remain"
|
||||
)
|
||||
else:
|
||||
logger.info(f"OWL validation passed successfully in {owl_duration:.2f}s")
|
||||
else:
|
||||
if not enable_owl_validation:
|
||||
logger.info("Step 3: OWL validation disabled, skipping")
|
||||
else:
|
||||
logger.info("Step 3: No classes to validate, skipping OWL validation")
|
||||
|
||||
# Calculate total duration
|
||||
total_duration = time.time() - start_time
|
||||
|
||||
# Log extraction statistics
|
||||
logger.info(
|
||||
f"Ontology extraction completed - "
|
||||
f"final_class_count={len(response.classes)}, "
|
||||
f"domain={response.domain}, "
|
||||
f"total_duration={total_duration:.2f}s, "
|
||||
f"llm_duration={llm_duration:.2f}s"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# Re-raise timeout errors
|
||||
total_duration = time.time() - start_time
|
||||
logger.error(
|
||||
f"Ontology extraction timed out after {timeout} seconds "
|
||||
f"(total duration: {total_duration:.2f}s)",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
total_duration = time.time() - start_time
|
||||
logger.error(
|
||||
f"Ontology extraction failed after {total_duration:.2f}s: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty response on failure
|
||||
return OntologyExtractionResponse(
|
||||
classes=[],
|
||||
domain=domain or "Unknown",
|
||||
)
|
||||
|
||||
async def _call_llm_for_extraction(
|
||||
self,
|
||||
scenario: str,
|
||||
domain: Optional[str],
|
||||
max_classes: int,
|
||||
llm_temperature: float,
|
||||
llm_max_tokens: int,
|
||||
) -> OntologyExtractionResponse:
|
||||
"""Call LLM to extract ontology classes from scenario.
|
||||
|
||||
This method renders the extraction prompt using the Jinja2 template
|
||||
and calls the LLM with structured output to get ontology classes.
|
||||
|
||||
Args:
|
||||
scenario: Scenario description text
|
||||
domain: Optional domain hint
|
||||
max_classes: Maximum number of classes to extract
|
||||
llm_temperature: LLM temperature parameter
|
||||
llm_max_tokens: LLM max tokens parameter
|
||||
|
||||
Returns:
|
||||
OntologyExtractionResponse from LLM
|
||||
|
||||
Raises:
|
||||
Exception: If LLM call fails
|
||||
"""
|
||||
try:
|
||||
# Render prompt using template
|
||||
prompt_content = await render_ontology_extraction_prompt(
|
||||
scenario=scenario,
|
||||
domain=domain,
|
||||
max_classes=max_classes,
|
||||
json_schema=OntologyExtractionResponse.model_json_schema(),
|
||||
)
|
||||
|
||||
logger.debug(f"Rendered prompt length: {len(prompt_content)}")
|
||||
|
||||
# Create messages for LLM
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are an expert ontology engineer specializing in knowledge "
|
||||
"representation and OWL standards. Extract ontology classes from "
|
||||
"scenario descriptions following the provided instructions. "
|
||||
"Return valid JSON conforming to the schema."
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt_content,
|
||||
},
|
||||
]
|
||||
|
||||
# Call LLM with structured output
|
||||
logger.debug(
|
||||
f"Calling LLM with temperature={llm_temperature}, "
|
||||
f"max_tokens={llm_max_tokens}"
|
||||
)
|
||||
|
||||
response = await self.llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=OntologyExtractionResponse,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"LLM extraction successful - extracted {len(response.classes)} classes"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM extraction failed: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def _validate_and_clean(
|
||||
self,
|
||||
response: OntologyExtractionResponse,
|
||||
max_description_length: int,
|
||||
) -> OntologyExtractionResponse:
|
||||
"""Perform first-layer validation: string validation and cleaning.
|
||||
|
||||
This method validates and cleans the extracted ontology classes:
|
||||
1. Validate class names (PascalCase, no reserved words)
|
||||
2. Sanitize invalid class names
|
||||
3. Truncate long descriptions
|
||||
4. Remove duplicate classes
|
||||
|
||||
Args:
|
||||
response: OntologyExtractionResponse from LLM
|
||||
max_description_length: Maximum description length
|
||||
|
||||
Returns:
|
||||
Cleaned OntologyExtractionResponse
|
||||
"""
|
||||
if not response.classes:
|
||||
logger.debug("No classes to validate")
|
||||
return response
|
||||
|
||||
logger.debug(f"Validating {len(response.classes)} classes")
|
||||
|
||||
validated_classes = []
|
||||
|
||||
for ontology_class in response.classes:
|
||||
# Validate class name
|
||||
is_valid, error_msg = self.validator.validate_class_name(
|
||||
ontology_class.name
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
f"Invalid class name '{ontology_class.name}': {error_msg}"
|
||||
)
|
||||
|
||||
# Attempt to sanitize
|
||||
sanitized_name = self.validator.sanitize_class_name(
|
||||
ontology_class.name
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Sanitized class name: '{ontology_class.name}' -> '{sanitized_name}'"
|
||||
)
|
||||
|
||||
# Update class name
|
||||
ontology_class.name = sanitized_name
|
||||
|
||||
# Re-validate sanitized name
|
||||
is_valid, error_msg = self.validator.validate_class_name(
|
||||
sanitized_name
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.error(
|
||||
f"Failed to sanitize class name '{ontology_class.name}': {error_msg}. "
|
||||
"Skipping this class."
|
||||
)
|
||||
continue
|
||||
|
||||
# Truncate description if too long
|
||||
if ontology_class.description:
|
||||
original_length = len(ontology_class.description)
|
||||
ontology_class.description = self.validator.truncate_description(
|
||||
ontology_class.description,
|
||||
max_length=max_description_length,
|
||||
)
|
||||
|
||||
if len(ontology_class.description) < original_length:
|
||||
logger.debug(
|
||||
f"Truncated description for '{ontology_class.name}': "
|
||||
f"{original_length} -> {len(ontology_class.description)} chars"
|
||||
)
|
||||
|
||||
validated_classes.append(ontology_class)
|
||||
|
||||
# Remove duplicates (case-insensitive)
|
||||
original_count = len(validated_classes)
|
||||
validated_classes = self.validator.remove_duplicates(validated_classes)
|
||||
|
||||
if len(validated_classes) < original_count:
|
||||
logger.info(
|
||||
f"Removed {original_count - len(validated_classes)} duplicate classes"
|
||||
)
|
||||
|
||||
# Return cleaned response
|
||||
return OntologyExtractionResponse(
|
||||
classes=validated_classes,
|
||||
domain=response.domain,
|
||||
)
|
||||
|
||||
def _filter_invalid_classes(
|
||||
self,
|
||||
response: OntologyExtractionResponse,
|
||||
errors: List[str],
|
||||
) -> OntologyExtractionResponse:
|
||||
"""Filter invalid classes based on OWL validation errors.
|
||||
|
||||
This method analyzes OWL validation errors and removes classes
|
||||
that caused validation failures (e.g., circular inheritance,
|
||||
inconsistencies).
|
||||
|
||||
Args:
|
||||
response: OntologyExtractionResponse to filter
|
||||
errors: List of error messages from OWL validation
|
||||
|
||||
Returns:
|
||||
Filtered OntologyExtractionResponse
|
||||
"""
|
||||
if not errors:
|
||||
return response
|
||||
|
||||
logger.debug(f"Filtering classes based on {len(errors)} OWL validation errors")
|
||||
|
||||
# Extract class names mentioned in errors
|
||||
invalid_class_names = set()
|
||||
|
||||
for error in errors:
|
||||
# Look for class names in error messages
|
||||
for ontology_class in response.classes:
|
||||
if ontology_class.name in error:
|
||||
invalid_class_names.add(ontology_class.name)
|
||||
logger.debug(
|
||||
f"Class '{ontology_class.name}' marked as invalid due to error: {error}"
|
||||
)
|
||||
|
||||
# Filter out invalid classes
|
||||
if invalid_class_names:
|
||||
original_count = len(response.classes)
|
||||
|
||||
filtered_classes = [
|
||||
c for c in response.classes
|
||||
if c.name not in invalid_class_names
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Filtered out {original_count - len(filtered_classes)} invalid classes: "
|
||||
f"{invalid_class_names}"
|
||||
)
|
||||
|
||||
return OntologyExtractionResponse(
|
||||
classes=filtered_classes,
|
||||
domain=response.domain,
|
||||
)
|
||||
|
||||
return response
|
||||
@@ -25,6 +25,15 @@ class TripletExtractor:
|
||||
"""
|
||||
self.llm_client = llm_client
|
||||
|
||||
def _get_language(self) -> str:
|
||||
"""Get the configured language for entity descriptions
|
||||
|
||||
Returns:
|
||||
Language code ("zh" or "en")
|
||||
"""
|
||||
from app.core.config import settings
|
||||
return settings.DEFAULT_LANGUAGE
|
||||
|
||||
async def _extract_triplets(self, statement: Statement, chunk_content: str) -> TripletExtractionResponse:
|
||||
"""Process a single statement and return extracted triplets and entities"""
|
||||
# Render the prompt using helper function
|
||||
@@ -40,7 +49,8 @@ class TripletExtractor:
|
||||
statement=statement.statement,
|
||||
chunk_content=chunk_content,
|
||||
json_schema=TripletExtractionResponse.model_json_schema(),
|
||||
predicate_instructions=PREDICATE_DEFINITIONS
|
||||
predicate_instructions=PREDICATE_DEFINITIONS,
|
||||
language=self._get_language()
|
||||
)
|
||||
|
||||
# Create messages for LLM
|
||||
|
||||
@@ -296,7 +296,9 @@ def resolve_alias_cycles(entities: List[Any], cycles: Dict[str, Set[str]]) -> Li
|
||||
key=lambda eid: (
|
||||
_strength_rank(eid),
|
||||
len(getattr(entity_by_id.get(eid), 'description', '') or ''),
|
||||
len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '')
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '')
|
||||
0 # 临时占位
|
||||
),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
@@ -177,7 +177,7 @@ def render_entity_dedup_prompt(
|
||||
|
||||
# Args:
|
||||
# entity_a: Dict of entity A attributes
|
||||
async def render_triplet_extraction_prompt(statement: str, chunk_content: str, json_schema: dict, predicate_instructions: dict = None) -> str:
|
||||
async def render_triplet_extraction_prompt(statement: str, chunk_content: str, json_schema: dict, predicate_instructions: dict = None, language: str = "zh") -> str:
|
||||
"""
|
||||
Renders the triplet extraction prompt using the extract_triplet.jinja2 template.
|
||||
|
||||
@@ -186,6 +186,7 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
|
||||
chunk_content: The content of the chunk to process
|
||||
json_schema: JSON schema for the expected output format
|
||||
predicate_instructions: Optional predicate instructions
|
||||
language: The language to use for entity descriptions ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
@@ -195,7 +196,8 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
|
||||
statement=statement,
|
||||
chunk_content=chunk_content,
|
||||
json_schema=json_schema,
|
||||
predicate_instructions=predicate_instructions
|
||||
predicate_instructions=predicate_instructions,
|
||||
language=language
|
||||
)
|
||||
# 记录渲染结果到提示日志(与示例日志结构一致)
|
||||
log_prompt_rendering('triplet extraction', rendered_prompt)
|
||||
@@ -204,7 +206,8 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
|
||||
'statement': 'str',
|
||||
'chunk_content': 'str',
|
||||
'json_schema': 'TripletExtractionResponse.schema',
|
||||
'predicate_instructions': 'PREDICATE_DEFINITIONS'
|
||||
'predicate_instructions': 'PREDICATE_DEFINITIONS',
|
||||
'language': language
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
@@ -213,6 +216,7 @@ async def render_memory_summary_prompt(
|
||||
chunk_texts: str,
|
||||
json_schema: dict,
|
||||
max_words: int = 200,
|
||||
language: str = "zh",
|
||||
) -> str:
|
||||
"""
|
||||
Renders the memory summary prompt using the memory_summary.jinja2 template.
|
||||
@@ -221,6 +225,7 @@ async def render_memory_summary_prompt(
|
||||
chunk_texts: Concatenated text of conversation chunks
|
||||
json_schema: JSON schema for the expected output format
|
||||
max_words: Maximum words for the summary
|
||||
language: The language to use for summary generation ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string.
|
||||
@@ -230,12 +235,14 @@ async def render_memory_summary_prompt(
|
||||
chunk_texts=chunk_texts,
|
||||
json_schema=json_schema,
|
||||
max_words=max_words,
|
||||
language=language,
|
||||
)
|
||||
log_prompt_rendering('memory summary', rendered_prompt)
|
||||
log_template_rendering('memory_summary.jinja2', {
|
||||
'chunk_texts_len': len(chunk_texts or ""),
|
||||
'max_words': max_words,
|
||||
'json_schema': 'MemorySummaryResponse.schema'
|
||||
'json_schema': 'MemorySummaryResponse.schema',
|
||||
'language': language
|
||||
})
|
||||
return rendered_prompt
|
||||
|
||||
@@ -388,24 +395,65 @@ async def render_memory_insight_prompt(
|
||||
return rendered_prompt
|
||||
|
||||
|
||||
async def render_episodic_title_and_type_prompt(content: str) -> str:
|
||||
async def render_episodic_title_and_type_prompt(content: str, language: str = "zh") -> str:
|
||||
"""
|
||||
Renders the episodic title and type classification prompt using the episodic_type_classification.jinja2 template.
|
||||
|
||||
Args:
|
||||
content: The content of the episodic memory summary to analyze
|
||||
language: The language to use for title generation ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
template = prompt_env.get_template("episodic_type_classification.jinja2")
|
||||
rendered_prompt = template.render(content=content)
|
||||
rendered_prompt = template.render(content=content, language=language)
|
||||
|
||||
# 记录渲染结果到提示日志
|
||||
log_prompt_rendering('episodic title and type classification', rendered_prompt)
|
||||
# 可选:记录模板渲染信息
|
||||
log_template_rendering('episodic_type_classification.jinja2', {
|
||||
'content_len': len(content) if content else 0
|
||||
'content_len': len(content) if content else 0,
|
||||
'language': language
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
|
||||
async def render_ontology_extraction_prompt(
|
||||
scenario: str,
|
||||
domain: str | None = None,
|
||||
max_classes: int = 15,
|
||||
json_schema: dict | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Renders the ontology extraction prompt using the extract_ontology.jinja2 template.
|
||||
|
||||
Args:
|
||||
scenario: The scenario description text to extract ontology classes from
|
||||
domain: Optional domain hint for the scenario (e.g., "Healthcare", "Education")
|
||||
max_classes: Maximum number of classes to extract (default: 15)
|
||||
json_schema: JSON schema for the expected output format
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
template = prompt_env.get_template("extract_ontology.jinja2")
|
||||
rendered_prompt = template.render(
|
||||
scenario=scenario,
|
||||
domain=domain,
|
||||
max_classes=max_classes,
|
||||
json_schema=json_schema
|
||||
)
|
||||
|
||||
# 记录渲染结果到提示日志
|
||||
log_prompt_rendering('ontology extraction', rendered_prompt)
|
||||
# 可选:记录模板渲染信息
|
||||
log_template_rendering('extract_ontology.jinja2', {
|
||||
'scenario_len': len(scenario) if scenario else 0,
|
||||
'domain': domain,
|
||||
'max_classes': max_classes,
|
||||
'json_schema': 'OntologyExtractionResponse.schema'
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
@@ -9,7 +9,8 @@
|
||||
- 类型: "{{ entity_a.entity_type | default('') }}"
|
||||
- 描述: "{{ entity_a.description | default('') }}"
|
||||
- 别名: {{ entity_a.aliases | default([]) }}
|
||||
- 摘要: "{{ entity_a.fact_summary | default('') }}"
|
||||
{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #}
|
||||
{# - 摘要: "{{ entity_a.fact_summary | default('') }}" #}
|
||||
- 连接强弱: "{{ entity_a.connect_strength | default('') }}"
|
||||
|
||||
实体B:
|
||||
@@ -17,7 +18,8 @@
|
||||
- 类型: "{{ entity_b.entity_type | default('') }}"
|
||||
- 描述: "{{ entity_b.description | default('') }}"
|
||||
- 别名: {{ entity_b.aliases | default([]) }}
|
||||
- 摘要: "{{ entity_b.fact_summary | default('') }}"
|
||||
{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #}
|
||||
{# - 摘要: "{{ entity_b.fact_summary | default('') }}" #}
|
||||
- 连接强弱: "{{ entity_b.connect_strength | default('') }}"
|
||||
|
||||
上下文:
|
||||
|
||||
@@ -1,8 +1,19 @@
|
||||
=== Task ===
|
||||
Generate a concise title and classify the episodic memory into the most appropriate category.
|
||||
|
||||
{% if language == "zh" %}
|
||||
**重要:请使用中文生成标题和分类。**
|
||||
{% else %}
|
||||
**Important: Please generate the title and classification in English.**
|
||||
{% endif %}
|
||||
|
||||
=== Requirements ===
|
||||
- Extract a clear, concise title (10-20 characters) that captures the core content
|
||||
{% if language == "zh" %}
|
||||
- 标题必须使用中文
|
||||
{% else %}
|
||||
- Title must be in English
|
||||
{% endif %}
|
||||
- Classify into exactly one category based on the primary theme
|
||||
- Be specific and avoid ambiguity
|
||||
- Output must be valid JSON conforming to the schema below
|
||||
|
||||
210
api/app/core/memory/utils/prompt/prompts/extract_ontology.jinja2
Normal file
210
api/app/core/memory/utils/prompt/prompts/extract_ontology.jinja2
Normal file
@@ -0,0 +1,210 @@
|
||||
===Task===
|
||||
Extract ontology classes from the given scenario description following ontology engineering standards.
|
||||
|
||||
===Role===
|
||||
You are a professional ontology engineer with expertise in knowledge representation and OWL (Web Ontology Language) standards. Your task is to identify abstract classes and concepts from scenario descriptions, not concrete instances.
|
||||
|
||||
===Scenario Description===
|
||||
{{ scenario }}
|
||||
|
||||
{% if domain -%}
|
||||
===Domain Hint===
|
||||
This scenario belongs to the **{{ domain }}** domain. Consider domain-specific concepts and terminology when extracting classes.
|
||||
{%- endif %}
|
||||
|
||||
===Extraction Rules===
|
||||
|
||||
**1. Abstract Classes, Not Instances:**
|
||||
- Extract abstract categories and concepts (e.g., "MedicalProcedure", "Patient", "Diagnosis")
|
||||
- Do NOT extract concrete instances (e.g., "John Smith", "Room 301", "2024-01-15")
|
||||
- Think in terms of "types of things" rather than "specific things"
|
||||
|
||||
**2. Naming Convention (PascalCase):**
|
||||
- Use PascalCase format for the "name" field: start with uppercase letter, capitalize each word, no spaces
|
||||
- Examples: "MedicalProcedure", "HealthcareProvider", "DiagnosticTest"
|
||||
- Avoid: "medical procedure", "healthcare_provider", "diagnostic-test"
|
||||
- Use clear, descriptive names in English
|
||||
- Avoid abbreviations unless they are standard in the domain (e.g., "API", "DNA")
|
||||
- Provide Chinese translation in the "name_chinese" field (e.g., "医疗程序", "医疗服务提供者", "诊断测试")
|
||||
|
||||
**3. Domain Relevance:**
|
||||
- Focus on classes that are central to the scenario's domain
|
||||
- Prioritize classes that represent key concepts, entities, or relationships
|
||||
- Avoid overly generic classes (e.g., "Thing", "Object") unless they have specific domain meaning
|
||||
|
||||
**4. Class Quantity:**
|
||||
- Extract between 5 and {{ max_classes }} classes
|
||||
- Aim for a balanced set covering the main concepts in the scenario
|
||||
- Quality over quantity: prefer well-defined classes over exhaustive lists
|
||||
|
||||
**5. Clear Descriptions:**
|
||||
- Provide concise, informative descriptions in Chinese (max 500 characters)
|
||||
- Describe what the class represents, not specific instances
|
||||
- Use clear, natural Chinese language that explains the class's role in the domain
|
||||
|
||||
**6. Concrete Examples:**
|
||||
- Provide 2-5 concrete instance examples in Chinese for each class
|
||||
- Examples should be specific, realistic instances of the class
|
||||
- Examples help clarify the class's scope and meaning
|
||||
- Use natural Chinese language for examples
|
||||
- Example format: ["示例1", "示例2", "示例3"]
|
||||
|
||||
**7. Class Hierarchy:**
|
||||
- Identify parent-child relationships where applicable
|
||||
- Use the parent_class field to specify inheritance
|
||||
- Parent class must be one of the extracted classes or a standard OWL class
|
||||
- Leave parent_class as null for top-level classes
|
||||
|
||||
**8. Entity Types:**
|
||||
- Classify each class with an appropriate entity_type
|
||||
- Common types: "Person", "Organization", "Location", "Event", "Concept", "Process", "Object", "Role"
|
||||
- Choose the most specific type that applies
|
||||
|
||||
**9. OWL Reserved Words:**
|
||||
- Do NOT use OWL reserved words as class names
|
||||
- Reserved words include: "Thing", "Nothing", "Class", "Property", "ObjectProperty", "DatatypeProperty", "AnnotationProperty", "Ontology", "Individual", "Literal"
|
||||
- If a reserved word is needed, add a domain-specific prefix (e.g., "MedicalClass" instead of "Class")
|
||||
|
||||
**10. Language Consistency:**
|
||||
- Extract all class names in English (PascalCase format) for the "name" field
|
||||
- Provide Chinese translation for class names in the "name_chinese" field
|
||||
- Descriptions MUST be in Chinese (中文)
|
||||
- Examples MUST be in Chinese (中文)
|
||||
- Use clear, natural Chinese language for descriptions and examples
|
||||
|
||||
===Examples===
|
||||
|
||||
**Example 1 (Healthcare Domain):**
|
||||
Scenario: "A hospital manages patient records, schedules appointments, and coordinates medical procedures. Doctors diagnose conditions and prescribe treatments."
|
||||
|
||||
Output:
|
||||
{
|
||||
"classes": [
|
||||
{
|
||||
"name": "Patient",
|
||||
"name_chinese": "患者",
|
||||
"description": "在医疗机构接受医疗护理或治疗的人",
|
||||
"examples": ["张三", "李四", "患有糖尿病的老年患者"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Person",
|
||||
"domain": "Healthcare"
|
||||
},
|
||||
{
|
||||
"name": "MedicalProcedure",
|
||||
"name_chinese": "医疗程序",
|
||||
"description": "为医疗诊断或治疗而执行的系统性操作流程",
|
||||
"examples": ["手术", "血液检查", "X光检查", "疫苗接种"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Process",
|
||||
"domain": "Healthcare"
|
||||
},
|
||||
{
|
||||
"name": "Diagnosis",
|
||||
"name_chinese": "诊断",
|
||||
"description": "基于症状和检查结果对疾病或状况的识别",
|
||||
"examples": ["糖尿病诊断", "癌症诊断", "流感诊断"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Concept",
|
||||
"domain": "Healthcare"
|
||||
},
|
||||
{
|
||||
"name": "Doctor",
|
||||
"name_chinese": "医生",
|
||||
"description": "诊断和治疗患者的持证医疗专业人员",
|
||||
"examples": ["全科医生", "外科医生", "心脏病专家"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Role",
|
||||
"domain": "Healthcare"
|
||||
},
|
||||
{
|
||||
"name": "Treatment",
|
||||
"name_chinese": "治疗",
|
||||
"description": "为治愈或管理疾病状况而提供的医疗护理或疗法",
|
||||
"examples": ["药物治疗", "物理治疗", "化疗", "手术治疗"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Process",
|
||||
"domain": "Healthcare"
|
||||
}
|
||||
],
|
||||
"domain": "Healthcare",
|
||||
"namespace": "http://example.org/healthcare#"
|
||||
}
|
||||
|
||||
**Example 2 (Education Domain):**
|
||||
Scenario: "A university offers courses taught by professors. Students enroll in programs, attend lectures, and complete assignments to earn degrees."
|
||||
|
||||
Output:
|
||||
{
|
||||
"classes": [
|
||||
{
|
||||
"name": "Student",
|
||||
"name_chinese": "学生",
|
||||
"description": "在教育机构注册学习的人",
|
||||
"examples": ["本科生", "研究生", "在职学生"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Role",
|
||||
"domain": "Education"
|
||||
},
|
||||
{
|
||||
"name": "Course",
|
||||
"name_chinese": "课程",
|
||||
"description": "涵盖特定学科或主题的结构化教育课程",
|
||||
"examples": ["计算机科学导论", "微积分I", "世界历史"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Concept",
|
||||
"domain": "Education"
|
||||
},
|
||||
{
|
||||
"name": "Professor",
|
||||
"name_chinese": "教授",
|
||||
"description": "教授课程并进行研究的学术教师",
|
||||
"examples": ["助理教授", "副教授", "正教授"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Role",
|
||||
"domain": "Education"
|
||||
},
|
||||
{
|
||||
"name": "AcademicProgram",
|
||||
"name_chinese": "学术项目",
|
||||
"description": "通向学位或证书的结构化课程体系",
|
||||
"examples": ["理学学士", "文学硕士", "博士项目"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Concept",
|
||||
"domain": "Education"
|
||||
},
|
||||
{
|
||||
"name": "Assignment",
|
||||
"name_chinese": "作业",
|
||||
"description": "分配给学生以评估学习成果的任务或项目",
|
||||
"examples": ["论文", "习题集", "研究报告", "实验报告"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Object",
|
||||
"domain": "Education"
|
||||
},
|
||||
{
|
||||
"name": "Lecture",
|
||||
"name_chinese": "讲座",
|
||||
"description": "由教师进行的教育性演讲或讲座",
|
||||
"examples": ["入门讲座", "客座讲座", "在线讲座"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Event",
|
||||
"domain": "Education"
|
||||
}
|
||||
],
|
||||
"domain": "Education",
|
||||
"namespace": "http://example.org/education#"
|
||||
}
|
||||
|
||||
===Output Format===
|
||||
|
||||
**JSON Requirements:**
|
||||
- Use only ASCII double quotes (") for JSON structure
|
||||
- Never use Chinese quotation marks ("") or Unicode quotes
|
||||
- Escape quotation marks in text with backslashes (\")
|
||||
- Ensure proper string closure and comma separation
|
||||
- No line breaks within JSON string values
|
||||
- All class names must be in PascalCase format
|
||||
- All class names must be unique (case-insensitive)
|
||||
- Extract between 5 and {{ max_classes }} classes
|
||||
|
||||
{{ json_schema }}
|
||||
@@ -5,6 +5,12 @@
|
||||
===Task===
|
||||
Extract entities and knowledge triplets from the given statement.
|
||||
|
||||
{% if language == "zh" %}
|
||||
**重要:请使用中文生成实体描述(description)和示例(example)。**
|
||||
{% else %}
|
||||
**Important: Please generate entity descriptions and examples in English.**
|
||||
{% endif %}
|
||||
|
||||
===Inputs===
|
||||
**Chunk Content:** "{{ chunk_content }}"
|
||||
**Statement:** "{{ statement }}"
|
||||
@@ -13,6 +19,13 @@ Extract entities and knowledge triplets from the given statement.
|
||||
|
||||
**Entity Extraction:**
|
||||
- Extract entities with their types, context-independent descriptions, **concise examples**, aliases, and semantic memory classification
|
||||
{% if language == "zh" %}
|
||||
- **实体描述(description)必须使用中文**
|
||||
- **示例(example)必须使用中文**
|
||||
{% else %}
|
||||
- **Entity descriptions must be in English**
|
||||
- **Examples must be in English**
|
||||
{% endif %}
|
||||
- **Semantic Memory Classification (is_explicit_memory):**
|
||||
* Set to `true` if the entity represents **explicit/semantic memory**:
|
||||
- **Concepts:** "Machine Learning", "Photosynthesis", "Democracy", "人工智能", "光合作用", "民主"
|
||||
@@ -334,9 +347,11 @@ Output:
|
||||
- Escape quotation marks in text with backslashes (\")
|
||||
- Ensure proper string closure and comma separation
|
||||
- No line breaks within JSON string values
|
||||
- The output language should ALWAYS match the input language
|
||||
- If input is in English, extract statements in English
|
||||
- If input is in Chinese, extract statements in Chinese
|
||||
{% if language == "zh" %}
|
||||
- **语言要求:实体描述(description)和示例(example)必须使用中文**
|
||||
{% else %}
|
||||
- **Language Requirement: Entity descriptions and examples must be in English**
|
||||
{% endif %}
|
||||
- Preserve the original language and do not translate
|
||||
|
||||
{{ json_schema }}
|
||||
@@ -5,10 +5,21 @@
|
||||
=== Task ===
|
||||
Summarize the provided conversation chunks into a concise Memory summary.
|
||||
|
||||
{% if language == "zh" %}
|
||||
**重要:请使用中文生成摘要内容。**
|
||||
{% else %}
|
||||
**Important: Please generate the summary content in English.**
|
||||
{% endif %}
|
||||
|
||||
=== Requirements ===
|
||||
- Focus on factual statements, user preferences, relationships, and salient temporal context.
|
||||
- Avoid repetition and filler; be specific.
|
||||
- Keep it under {{ max_words or 200 }} words.
|
||||
{% if language == "zh" %}
|
||||
- 摘要内容必须使用中文
|
||||
{% else %}
|
||||
- Summary content must be in English
|
||||
{% endif %}
|
||||
- Output must be valid JSON conforming to the schema below.
|
||||
|
||||
=== Input ===
|
||||
@@ -24,6 +35,11 @@ Summarize the provided conversation chunks into a concise Memory summary.
|
||||
4. Do not include line breaks within JSON string values
|
||||
5. Example of proper escaping: "statement": "张曼婷说:\"我很喜欢这本书。\""
|
||||
|
||||
The output language should always be the same as the input language.
|
||||
{% if language == "zh" %}
|
||||
**语言要求:输出内容必须使用中文。**
|
||||
{% else %}
|
||||
**Language Requirement: The output content must be in English.**
|
||||
{% endif %}
|
||||
|
||||
Return only a list of extracted labelled statements in the JSON ARRAY of objects that match the schema below:
|
||||
{{ json_schema }}
|
||||
10
api/app/core/memory/utils/validation/__init__.py
Normal file
10
api/app/core/memory/utils/validation/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""Validation utilities for ontology extraction.
|
||||
|
||||
This module provides validation classes for ontology class names,
|
||||
descriptions, and OWL compliance checking.
|
||||
"""
|
||||
|
||||
from .ontology_validator import OntologyValidator
|
||||
from .owl_validator import OWLValidator
|
||||
|
||||
__all__ = ['OntologyValidator', 'OWLValidator']
|
||||
268
api/app/core/memory/utils/validation/ontology_validator.py
Normal file
268
api/app/core/memory/utils/validation/ontology_validator.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""String validation for ontology class names and descriptions.
|
||||
|
||||
This module provides the OntologyValidator class for validating and sanitizing
|
||||
ontology class names according to OWL standards and naming conventions.
|
||||
|
||||
Classes:
|
||||
OntologyValidator: Validates class names, removes duplicates, and truncates descriptions
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Tuple
|
||||
|
||||
from app.core.memory.models.ontology_models import OntologyClass
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OntologyValidator:
|
||||
"""Validator for ontology class names and descriptions.
|
||||
|
||||
This validator performs string-level validation including:
|
||||
- PascalCase naming convention validation
|
||||
- OWL reserved word checking
|
||||
- Duplicate class name removal
|
||||
- Description length truncation
|
||||
|
||||
Attributes:
|
||||
OWL_RESERVED_WORDS: Set of OWL reserved words that cannot be used as class names
|
||||
"""
|
||||
|
||||
# OWL reserved words that cannot be used as class names
|
||||
OWL_RESERVED_WORDS = {
|
||||
'Thing', 'Nothing', 'Class', 'Property',
|
||||
'ObjectProperty', 'DatatypeProperty', 'FunctionalProperty',
|
||||
'InverseFunctionalProperty', 'TransitiveProperty', 'SymmetricProperty',
|
||||
'AsymmetricProperty', 'ReflexiveProperty', 'IrreflexiveProperty',
|
||||
'Restriction', 'Ontology', 'Individual', 'NamedIndividual',
|
||||
'Annotation', 'AnnotationProperty', 'Axiom',
|
||||
'AllDifferent', 'AllDisjointClasses', 'AllDisjointProperties',
|
||||
'Datatype', 'DataRange', 'Literal',
|
||||
'DeprecatedClass', 'DeprecatedProperty',
|
||||
'Imports', 'IncompatibleWith', 'PriorVersion', 'VersionInfo',
|
||||
'BackwardCompatibleWith', 'OntologyProperty',
|
||||
}
|
||||
|
||||
def validate_class_name(self, name: str) -> Tuple[bool, str]:
|
||||
"""Validate that a class name follows OWL naming conventions.
|
||||
|
||||
Validation rules:
|
||||
1. Must not be empty
|
||||
2. Must start with an uppercase letter (PascalCase)
|
||||
3. Cannot contain spaces
|
||||
4. Can only contain alphanumeric characters and underscores
|
||||
5. Cannot be an OWL reserved word
|
||||
|
||||
Args:
|
||||
name: The class name to validate
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
- is_valid: True if the name is valid, False otherwise
|
||||
- error_message: Empty string if valid, error description if invalid
|
||||
|
||||
Examples:
|
||||
>>> validator = OntologyValidator()
|
||||
>>> validator.validate_class_name("MedicalProcedure")
|
||||
(True, "")
|
||||
>>> validator.validate_class_name("medical procedure")
|
||||
(False, "Class name 'medical procedure' cannot contain spaces")
|
||||
>>> validator.validate_class_name("Thing")
|
||||
(False, "Class name 'Thing' is an OWL reserved word")
|
||||
"""
|
||||
logger.debug(f"Validating class name: '{name}'")
|
||||
|
||||
# Check if empty
|
||||
if not name or not name.strip():
|
||||
error_msg = "Class name cannot be empty"
|
||||
logger.warning(f"Validation failed: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
name = name.strip()
|
||||
|
||||
# Check if it's an OWL reserved word
|
||||
if name in self.OWL_RESERVED_WORDS:
|
||||
error_msg = f"Class name '{name}' is an OWL reserved word"
|
||||
logger.warning(f"Validation failed: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
# Check if starts with uppercase letter
|
||||
if not name[0].isupper():
|
||||
error_msg = f"Class name '{name}' must start with an uppercase letter (PascalCase)"
|
||||
logger.warning(f"Validation failed: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
# Check for spaces
|
||||
if ' ' in name:
|
||||
error_msg = f"Class name '{name}' cannot contain spaces"
|
||||
logger.warning(f"Validation failed: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
# Check for invalid characters (only alphanumeric and underscore allowed)
|
||||
if not re.match(r'^[A-Za-z0-9_]+$', name):
|
||||
error_msg = f"Class name '{name}' contains invalid characters. Only alphanumeric characters and underscores are allowed"
|
||||
logger.warning(f"Validation failed: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
logger.debug(f"Class name '{name}' is valid")
|
||||
return True, ""
|
||||
|
||||
def sanitize_class_name(self, name: str) -> str:
|
||||
"""Attempt to sanitize an invalid class name into a valid format.
|
||||
|
||||
Sanitization steps:
|
||||
1. Strip whitespace
|
||||
2. Remove invalid characters
|
||||
3. Replace spaces with empty string (PascalCase)
|
||||
4. Capitalize first letter of each word
|
||||
5. If result is empty or starts with number, prefix with 'Class'
|
||||
|
||||
Args:
|
||||
name: The class name to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized class name that should pass validation
|
||||
|
||||
Examples:
|
||||
>>> validator = OntologyValidator()
|
||||
>>> validator.sanitize_class_name("medical procedure")
|
||||
'MedicalProcedure'
|
||||
>>> validator.sanitize_class_name("patient-record")
|
||||
'PatientRecord'
|
||||
>>> validator.sanitize_class_name("123invalid")
|
||||
'Class123Invalid'
|
||||
"""
|
||||
logger.debug(f"Sanitizing class name: '{name}'")
|
||||
|
||||
if not name or not name.strip():
|
||||
logger.warning("Empty class name provided for sanitization, returning 'UnnamedClass'")
|
||||
return "UnnamedClass"
|
||||
|
||||
# Strip whitespace
|
||||
name = name.strip()
|
||||
original_name = name
|
||||
|
||||
# Split on spaces, hyphens, and underscores, then capitalize each word
|
||||
words = re.split(r'[\s\-_]+', name)
|
||||
|
||||
# Capitalize first letter of each word and keep rest as is
|
||||
sanitized_words = []
|
||||
for word in words:
|
||||
if word:
|
||||
# Remove non-alphanumeric characters except underscore
|
||||
clean_word = re.sub(r'[^A-Za-z0-9_]', '', word)
|
||||
if clean_word:
|
||||
# Capitalize first letter
|
||||
sanitized_words.append(clean_word[0].upper() + clean_word[1:])
|
||||
|
||||
# Join words
|
||||
sanitized = ''.join(sanitized_words)
|
||||
|
||||
# If empty or starts with number, prefix with 'Class'
|
||||
if not sanitized or sanitized[0].isdigit():
|
||||
sanitized = 'Class' + sanitized
|
||||
logger.info(f"Prefixed class name with 'Class': '{original_name}' -> '{sanitized}'")
|
||||
|
||||
# If it's a reserved word, append 'Class' suffix
|
||||
if sanitized in self.OWL_RESERVED_WORDS:
|
||||
sanitized = sanitized + 'Class'
|
||||
logger.info(f"Appended 'Class' suffix to reserved word: '{original_name}' -> '{sanitized}'")
|
||||
|
||||
logger.info(f"Sanitized class name: '{original_name}' -> '{sanitized}'")
|
||||
return sanitized
|
||||
|
||||
def remove_duplicates(self, classes: List[OntologyClass]) -> List[OntologyClass]:
|
||||
"""Remove duplicate ontology classes based on case-insensitive name comparison.
|
||||
|
||||
When duplicates are found, keeps the first occurrence and discards subsequent ones.
|
||||
Comparison is case-insensitive to catch variations like 'Patient' and 'patient'.
|
||||
|
||||
Args:
|
||||
classes: List of OntologyClass objects
|
||||
|
||||
Returns:
|
||||
List of OntologyClass objects with duplicates removed
|
||||
|
||||
Examples:
|
||||
>>> validator = OntologyValidator()
|
||||
>>> classes = [
|
||||
... OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare"),
|
||||
... OntologyClass(name="patient", description="Another patient", entity_type="Person", domain="Healthcare"),
|
||||
... OntologyClass(name="Doctor", description="A doctor", entity_type="Person", domain="Healthcare"),
|
||||
... ]
|
||||
>>> unique = validator.remove_duplicates(classes)
|
||||
>>> len(unique)
|
||||
2
|
||||
>>> [c.name for c in unique]
|
||||
['Patient', 'Doctor']
|
||||
"""
|
||||
if not classes:
|
||||
logger.debug("No classes to check for duplicates")
|
||||
return classes
|
||||
|
||||
logger.debug(f"Checking {len(classes)} classes for duplicates")
|
||||
|
||||
seen_names = set()
|
||||
unique_classes = []
|
||||
duplicates_found = []
|
||||
|
||||
for ontology_class in classes:
|
||||
# Use lowercase for comparison
|
||||
name_lower = ontology_class.name.lower()
|
||||
|
||||
if name_lower not in seen_names:
|
||||
seen_names.add(name_lower)
|
||||
unique_classes.append(ontology_class)
|
||||
else:
|
||||
duplicates_found.append(ontology_class.name)
|
||||
logger.debug(f"Duplicate class found and removed: '{ontology_class.name}'")
|
||||
|
||||
if duplicates_found:
|
||||
logger.info(
|
||||
f"Removed {len(duplicates_found)} duplicate classes: {duplicates_found}"
|
||||
)
|
||||
else:
|
||||
logger.debug("No duplicate classes found")
|
||||
|
||||
return unique_classes
|
||||
|
||||
def truncate_description(self, description: str, max_length: int = 500) -> str:
|
||||
"""Truncate a description to a maximum length.
|
||||
|
||||
If the description exceeds max_length, it will be truncated and
|
||||
an ellipsis (...) will be appended to indicate truncation.
|
||||
|
||||
Args:
|
||||
description: The description text to truncate
|
||||
max_length: Maximum allowed length (default: 500)
|
||||
|
||||
Returns:
|
||||
Truncated description string
|
||||
|
||||
Examples:
|
||||
>>> validator = OntologyValidator()
|
||||
>>> long_desc = "A" * 600
|
||||
>>> truncated = validator.truncate_description(long_desc, max_length=500)
|
||||
>>> len(truncated)
|
||||
500
|
||||
>>> truncated.endswith("...")
|
||||
True
|
||||
"""
|
||||
if not description:
|
||||
return ""
|
||||
|
||||
if len(description) <= max_length:
|
||||
return description
|
||||
|
||||
# Truncate and add ellipsis
|
||||
# Reserve 3 characters for "..."
|
||||
truncate_at = max_length - 3
|
||||
truncated = description[:truncate_at] + "..."
|
||||
|
||||
logger.debug(
|
||||
f"Truncated description from {len(description)} to {len(truncated)} characters"
|
||||
)
|
||||
|
||||
return truncated
|
||||
585
api/app/core/memory/utils/validation/owl_validator.py
Normal file
585
api/app/core/memory/utils/validation/owl_validator.py
Normal file
@@ -0,0 +1,585 @@
|
||||
"""OWL semantic validation for ontology classes using Owlready2.
|
||||
|
||||
This module provides the OWLValidator class for validating ontology classes
|
||||
against OWL standards using the Owlready2 library. It performs semantic
|
||||
validation including consistency checking, circular inheritance detection,
|
||||
and OWL file export.
|
||||
|
||||
Classes:
|
||||
OWLValidator: Validates ontology classes using OWL reasoning and exports to OWL formats
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from owlready2 import (
|
||||
World,
|
||||
Thing,
|
||||
get_ontology,
|
||||
sync_reasoner_pellet,
|
||||
OwlReadyInconsistentOntologyError,
|
||||
)
|
||||
|
||||
from app.core.memory.models.ontology_models import OntologyClass
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OWLValidator:
|
||||
"""Validator for OWL semantic validation of ontology classes.
|
||||
|
||||
This validator performs semantic-level validation using Owlready2 including:
|
||||
- Creating OWL classes from ontology class definitions
|
||||
- Running consistency checking with Pellet reasoner
|
||||
- Detecting circular inheritance
|
||||
- Validating Protégé compatibility
|
||||
- Exporting ontologies to various OWL formats (RDF/XML, Turtle, N-Triples)
|
||||
|
||||
Attributes:
|
||||
base_namespace: Base URI for the ontology namespace
|
||||
"""
|
||||
|
||||
def __init__(self, base_namespace: str = "http://example.org/ontology#"):
|
||||
"""Initialize the OWL validator.
|
||||
|
||||
Args:
|
||||
base_namespace: Base URI for the ontology namespace (default: http://example.org/ontology#)
|
||||
"""
|
||||
self.base_namespace = base_namespace
|
||||
|
||||
def validate_ontology_classes(
|
||||
self,
|
||||
classes: List[OntologyClass],
|
||||
) -> Tuple[bool, List[str], Optional[World]]:
|
||||
"""Validate extracted ontology classes against OWL standards.
|
||||
|
||||
This method creates an OWL ontology from the provided classes using Owlready2,
|
||||
runs consistency checking with the Pellet reasoner, and detects common issues
|
||||
like circular inheritance.
|
||||
|
||||
Args:
|
||||
classes: List of OntologyClass objects to validate
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_messages, world):
|
||||
- is_valid: True if ontology is valid and consistent, False otherwise
|
||||
- error_messages: List of error/warning messages
|
||||
- world: Owlready2 World object containing the ontology (None if validation failed)
|
||||
|
||||
Examples:
|
||||
>>> validator = OWLValidator()
|
||||
>>> classes = [
|
||||
... OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare"),
|
||||
... OntologyClass(name="Doctor", description="A doctor", entity_type="Person", domain="Healthcare"),
|
||||
... ]
|
||||
>>> is_valid, errors, world = validator.validate_ontology_classes(classes)
|
||||
>>> is_valid
|
||||
True
|
||||
>>> len(errors)
|
||||
0
|
||||
"""
|
||||
if not classes:
|
||||
return False, ["No classes provided for validation"], None
|
||||
|
||||
errors = []
|
||||
|
||||
try:
|
||||
# Create a new world (isolated ontology environment)
|
||||
world = World()
|
||||
|
||||
# Use a proper ontology IRI
|
||||
# Owlready2 expects the IRI to end with .owl or similar
|
||||
onto_iri = self.base_namespace.rstrip('#/')
|
||||
if not onto_iri.endswith('.owl'):
|
||||
onto_iri = onto_iri + '.owl'
|
||||
|
||||
# Create ontology
|
||||
onto = world.get_ontology(onto_iri)
|
||||
|
||||
with onto:
|
||||
# Dictionary to store created OWL classes for parent reference
|
||||
owl_classes = {}
|
||||
|
||||
# First pass: Create all classes without parent relationships
|
||||
for ontology_class in classes:
|
||||
try:
|
||||
# Create OWL class dynamically using type() with Thing as base
|
||||
# The key is to NOT set namespace in the dict, let Owlready2 handle it
|
||||
owl_class = type(
|
||||
ontology_class.name, # Class name
|
||||
(Thing,), # Base classes
|
||||
{} # Class dict (empty, let Owlready2 manage)
|
||||
)
|
||||
|
||||
# Add label (rdfs:label) - include both English and Chinese names
|
||||
labels = [ontology_class.name]
|
||||
if ontology_class.name_chinese:
|
||||
labels.append(ontology_class.name_chinese)
|
||||
owl_class.label = labels
|
||||
|
||||
# Add comment (rdfs:comment) with description
|
||||
if ontology_class.description:
|
||||
owl_class.comment = [ontology_class.description]
|
||||
|
||||
# Store for parent relationship setup
|
||||
owl_classes[ontology_class.name] = owl_class
|
||||
|
||||
logger.debug(
|
||||
f"Created OWL class: {ontology_class.name} "
|
||||
f"(Chinese: {ontology_class.name_chinese}) "
|
||||
f"IRI: {owl_class.iri if hasattr(owl_class, 'iri') else 'N/A'}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to create OWL class '{ontology_class.name}': {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(error_msg, exc_info=True)
|
||||
|
||||
# Second pass: Set up parent relationships
|
||||
for ontology_class in classes:
|
||||
if ontology_class.parent_class and ontology_class.name in owl_classes:
|
||||
parent_name = ontology_class.parent_class
|
||||
|
||||
# Check if parent exists
|
||||
if parent_name in owl_classes:
|
||||
try:
|
||||
child_class = owl_classes[ontology_class.name]
|
||||
parent_class = owl_classes[parent_name]
|
||||
|
||||
# Set parent by modifying is_a
|
||||
child_class.is_a = [parent_class]
|
||||
|
||||
logger.debug(
|
||||
f"Set parent relationship: {ontology_class.name} -> {parent_name}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = (
|
||||
f"Failed to set parent relationship "
|
||||
f"'{ontology_class.name}' -> '{parent_name}': {str(e)}"
|
||||
)
|
||||
errors.append(error_msg)
|
||||
logger.warning(error_msg)
|
||||
else:
|
||||
warning_msg = (
|
||||
f"Parent class '{parent_name}' not found for '{ontology_class.name}'"
|
||||
)
|
||||
errors.append(warning_msg)
|
||||
logger.warning(warning_msg)
|
||||
|
||||
# Check for circular inheritance
|
||||
for class_name, owl_class in owl_classes.items():
|
||||
if self._has_circular_inheritance(owl_class):
|
||||
error_msg = f"Circular inheritance detected for class '{class_name}'"
|
||||
errors.append(error_msg)
|
||||
logger.error(error_msg)
|
||||
|
||||
# Run consistency checking with Pellet reasoner
|
||||
try:
|
||||
logger.info("Running Pellet reasoner for consistency checking...")
|
||||
sync_reasoner_pellet(world, infer_property_values=True, infer_data_property_values=True)
|
||||
logger.info("Consistency check passed")
|
||||
|
||||
except OwlReadyInconsistentOntologyError as e:
|
||||
error_msg = f"Ontology is inconsistent: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(error_msg)
|
||||
return False, errors, world
|
||||
|
||||
except Exception as e:
|
||||
# Reasoner errors are often due to Java not being installed or configured
|
||||
# Log as warning but don't fail validation - ontology structure is still valid
|
||||
warning_msg = f"Reasoner check skipped: {str(e)}"
|
||||
if str(e).strip(): # Only log if there's an actual error message
|
||||
logger.warning(warning_msg)
|
||||
else:
|
||||
logger.warning("Reasoner check skipped: Java may not be installed or configured")
|
||||
# Continue - ontology structure is valid even without reasoner check
|
||||
|
||||
# If we have errors (excluding warnings), validation failed
|
||||
is_valid = len(errors) == 0
|
||||
|
||||
return is_valid, errors, world
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"OWL validation failed: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return False, errors, None
|
||||
|
||||
def _has_circular_inheritance(self, owl_class) -> bool:
|
||||
"""Check if an OWL class has circular inheritance.
|
||||
|
||||
Circular inheritance occurs when a class inherits from itself through
|
||||
a chain of parent relationships (e.g., A -> B -> C -> A).
|
||||
|
||||
Args:
|
||||
owl_class: Owlready2 class object to check
|
||||
|
||||
Returns:
|
||||
True if circular inheritance is detected, False otherwise
|
||||
"""
|
||||
visited = set()
|
||||
current = owl_class
|
||||
|
||||
while current:
|
||||
# Get class IRI or name as identifier
|
||||
class_id = str(current.iri) if hasattr(current, 'iri') else str(current)
|
||||
|
||||
if class_id in visited:
|
||||
# Found a cycle
|
||||
return True
|
||||
|
||||
visited.add(class_id)
|
||||
|
||||
# Get parent classes (is_a relationship)
|
||||
parents = getattr(current, 'is_a', [])
|
||||
|
||||
# Filter out Thing and other base classes
|
||||
parent_classes = [p for p in parents if p != Thing and hasattr(p, 'is_a')]
|
||||
|
||||
if not parent_classes:
|
||||
# No more parents, no cycle
|
||||
break
|
||||
|
||||
# Check first parent (in single inheritance)
|
||||
current = parent_classes[0] if parent_classes else None
|
||||
|
||||
return False
|
||||
|
||||
def export_to_owl(
|
||||
self,
|
||||
world: World,
|
||||
output_path: Optional[str] = None,
|
||||
format: str = "rdfxml",
|
||||
classes: Optional[List] = None
|
||||
) -> str:
|
||||
"""Export ontology to OWL file in specified format.
|
||||
|
||||
Supported formats:
|
||||
- rdfxml: RDF/XML format (default, most compatible)
|
||||
- turtle: Turtle format (more readable)
|
||||
- ntriples: N-Triples format (simplest)
|
||||
- json: JSON format (simplified, human-readable)
|
||||
|
||||
Args:
|
||||
world: Owlready2 World object containing the ontology
|
||||
output_path: Optional file path to save the ontology (if None, returns string)
|
||||
format: Export format - "rdfxml", "turtle", "ntriples", or "json" (default: "rdfxml")
|
||||
classes: Optional list of OntologyClass objects (required for json format)
|
||||
|
||||
Returns:
|
||||
String representation of the exported ontology
|
||||
|
||||
Raises:
|
||||
ValueError: If format is not supported
|
||||
RuntimeError: If export fails
|
||||
|
||||
Examples:
|
||||
>>> validator = OWLValidator()
|
||||
>>> is_valid, errors, world = validator.validate_ontology_classes(classes)
|
||||
>>> owl_content = validator.export_to_owl(world, "ontology.owl", format="rdfxml")
|
||||
"""
|
||||
# Validate format
|
||||
valid_formats = ["rdfxml", "turtle", "ntriples", "json"]
|
||||
if format not in valid_formats:
|
||||
raise ValueError(
|
||||
f"Unsupported format '{format}'. Must be one of: {', '.join(valid_formats)}"
|
||||
)
|
||||
|
||||
# JSON format doesn't need OWL processing
|
||||
if format == "json":
|
||||
if not classes:
|
||||
raise ValueError("Classes list is required for JSON format export")
|
||||
return self._export_to_json(classes)
|
||||
|
||||
# For OWL formats, world is required
|
||||
if not world:
|
||||
raise ValueError("World object is None. Cannot export ontology.")
|
||||
|
||||
# Note: Owlready2 has issues with turtle format export
|
||||
# We'll handle it specially by converting from rdfxml
|
||||
use_conversion = (format == "turtle")
|
||||
|
||||
try:
|
||||
# Get all ontologies in the world
|
||||
ontologies = list(world.ontologies.values())
|
||||
|
||||
if not ontologies:
|
||||
raise RuntimeError("No ontologies found in world")
|
||||
|
||||
# Find the ontology with classes (skip anonymous/empty ontologies)
|
||||
onto = None
|
||||
for ont in ontologies:
|
||||
classes_count = len(list(ont.classes()))
|
||||
logger.debug(f"Checking ontology {ont.base_iri}: {classes_count} classes")
|
||||
if classes_count > 0:
|
||||
onto = ont
|
||||
break
|
||||
|
||||
# If no ontology with classes found, use the last non-anonymous one
|
||||
if onto is None:
|
||||
for ont in reversed(ontologies):
|
||||
if ont.base_iri != "http://anonymous/":
|
||||
onto = ont
|
||||
break
|
||||
|
||||
# If still no ontology, use the first one
|
||||
if onto is None:
|
||||
onto = ontologies[0]
|
||||
|
||||
# Log ontology contents for debugging
|
||||
logger.info(f"Ontology IRI: {onto.base_iri}")
|
||||
logger.info(f"Ontology contains {len(list(onto.classes()))} classes")
|
||||
|
||||
# List all classes in the ontology
|
||||
all_classes = list(onto.classes())
|
||||
for cls in all_classes:
|
||||
logger.info(f"Class in ontology: {cls.name} (IRI: {cls.iri})")
|
||||
if hasattr(cls, 'label'):
|
||||
logger.debug(f" Labels: {cls.label}")
|
||||
if hasattr(cls, 'comment'):
|
||||
logger.debug(f" Comments: {cls.comment}")
|
||||
|
||||
if len(all_classes) == 0:
|
||||
logger.warning("No classes found in ontology! This may indicate a problem with class creation.")
|
||||
|
||||
if output_path:
|
||||
# Save to file
|
||||
export_format = "rdfxml" if use_conversion else format
|
||||
logger.info(f"Exporting ontology to {output_path} in {export_format} format")
|
||||
onto.save(file=output_path, format=export_format)
|
||||
|
||||
# Read back the file content to return
|
||||
with open(output_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
# Convert to turtle if needed
|
||||
if use_conversion:
|
||||
content = self._convert_to_turtle(content)
|
||||
|
||||
logger.info(f"Successfully exported ontology to {output_path}")
|
||||
|
||||
# Format the content for better readability
|
||||
content = self._format_owl_content(content, format)
|
||||
|
||||
return content
|
||||
else:
|
||||
# Export to string (save to temporary location and read)
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.owl', delete=False) as tmp:
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
export_format = "rdfxml" if use_conversion else format
|
||||
onto.save(file=tmp_path, format=export_format)
|
||||
|
||||
with open(tmp_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
# Convert to turtle if needed
|
||||
if use_conversion:
|
||||
content = self._convert_to_turtle(content)
|
||||
|
||||
# Format the content for better readability
|
||||
content = self._format_owl_content(content, format)
|
||||
|
||||
return content
|
||||
|
||||
finally:
|
||||
# Clean up temporary file
|
||||
if os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to export ontology: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
def _export_to_json(self, classes: List) -> str:
|
||||
"""Export ontology classes to simplified JSON format.
|
||||
|
||||
This format is more compact and easier to parse than OWL XML.
|
||||
|
||||
Args:
|
||||
classes: List of OntologyClass objects
|
||||
|
||||
Returns:
|
||||
JSON string representation (compact format)
|
||||
"""
|
||||
import json
|
||||
|
||||
result = {
|
||||
"ontology": {
|
||||
"namespace": self.base_namespace,
|
||||
"classes": []
|
||||
}
|
||||
}
|
||||
|
||||
for cls in classes:
|
||||
class_data = {
|
||||
"name": cls.name,
|
||||
"name_chinese": cls.name_chinese,
|
||||
"description": cls.description,
|
||||
"entity_type": cls.entity_type,
|
||||
"domain": cls.domain,
|
||||
"parent_class": cls.parent_class,
|
||||
"examples": cls.examples if hasattr(cls, 'examples') else []
|
||||
}
|
||||
result["ontology"]["classes"].append(class_data)
|
||||
|
||||
# 使用紧凑格式:无缩进,使用分隔符减少空格
|
||||
return json.dumps(result, ensure_ascii=False, separators=(',', ':'))
|
||||
|
||||
def _convert_to_turtle(self, rdfxml_content: str) -> str:
|
||||
"""Convert RDF/XML content to Turtle format using rdflib.
|
||||
|
||||
Args:
|
||||
rdfxml_content: RDF/XML format content
|
||||
|
||||
Returns:
|
||||
Turtle format content
|
||||
"""
|
||||
try:
|
||||
from rdflib import Graph
|
||||
|
||||
# Parse RDF/XML
|
||||
g = Graph()
|
||||
g.parse(data=rdfxml_content, format="xml")
|
||||
|
||||
# Serialize to Turtle
|
||||
turtle_content = g.serialize(format="turtle")
|
||||
|
||||
# Handle bytes vs string
|
||||
if isinstance(turtle_content, bytes):
|
||||
turtle_content = turtle_content.decode('utf-8')
|
||||
|
||||
return turtle_content
|
||||
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"rdflib is not installed. Cannot convert to Turtle format. "
|
||||
"Install with: pip install rdflib"
|
||||
)
|
||||
return rdfxml_content
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert to Turtle format: {e}")
|
||||
return rdfxml_content
|
||||
|
||||
def _format_owl_content(self, content: str, format: str) -> str:
|
||||
"""Format OWL content for better readability.
|
||||
|
||||
Args:
|
||||
content: Raw OWL content string
|
||||
format: Format type (rdfxml, turtle, ntriples)
|
||||
|
||||
Returns:
|
||||
Formatted OWL content string
|
||||
"""
|
||||
if format == "rdfxml":
|
||||
# Format XML with proper indentation
|
||||
try:
|
||||
import xml.dom.minidom as minidom
|
||||
dom = minidom.parseString(content)
|
||||
# Pretty print with 2-space indentation
|
||||
formatted = dom.toprettyxml(indent=" ", encoding="utf-8").decode("utf-8")
|
||||
|
||||
# Remove extra blank lines
|
||||
lines = []
|
||||
prev_blank = False
|
||||
for line in formatted.split('\n'):
|
||||
is_blank = not line.strip()
|
||||
if not (is_blank and prev_blank): # Skip consecutive blank lines
|
||||
lines.append(line)
|
||||
prev_blank = is_blank
|
||||
|
||||
formatted = '\n'.join(lines)
|
||||
|
||||
return formatted
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to format XML content: {e}")
|
||||
return content
|
||||
|
||||
elif format == "turtle":
|
||||
# Turtle format is already relatively readable
|
||||
# Just ensure consistent line endings and not empty
|
||||
if not content or content.strip() == "":
|
||||
logger.warning("Turtle content is empty, this may indicate an export issue")
|
||||
return content.strip() + '\n' if content.strip() else content
|
||||
|
||||
elif format == "ntriples":
|
||||
# N-Triples format is line-based, ensure proper line endings
|
||||
return content.strip() + '\n' if content.strip() else content
|
||||
|
||||
return content
|
||||
|
||||
def validate_with_protege_compatibility(
|
||||
self,
|
||||
classes: List[OntologyClass]
|
||||
) -> Tuple[bool, List[str]]:
|
||||
"""Validate that ontology classes are compatible with Protégé editor.
|
||||
|
||||
Protégé compatibility checks:
|
||||
- Class names are valid OWL identifiers
|
||||
- No special characters that Protégé cannot handle
|
||||
- Namespace is properly formatted
|
||||
- Labels and comments are properly encoded
|
||||
|
||||
Args:
|
||||
classes: List of OntologyClass objects to validate
|
||||
|
||||
Returns:
|
||||
Tuple of (is_compatible, warnings):
|
||||
- is_compatible: True if compatible with Protégé, False otherwise
|
||||
- warnings: List of compatibility warning messages
|
||||
|
||||
Examples:
|
||||
>>> validator = OWLValidator()
|
||||
>>> classes = [OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare")]
|
||||
>>> is_compatible, warnings = validator.validate_with_protege_compatibility(classes)
|
||||
>>> is_compatible
|
||||
True
|
||||
"""
|
||||
warnings = []
|
||||
|
||||
# Check namespace format
|
||||
if not self.base_namespace.startswith(('http://', 'https://')):
|
||||
warnings.append(
|
||||
f"Namespace '{self.base_namespace}' should start with http:// or https:// "
|
||||
"for Protégé compatibility"
|
||||
)
|
||||
|
||||
if not self.base_namespace.endswith(('#', '/')):
|
||||
warnings.append(
|
||||
f"Namespace '{self.base_namespace}' should end with # or / "
|
||||
"for Protégé compatibility"
|
||||
)
|
||||
|
||||
# Check each class
|
||||
for ontology_class in classes:
|
||||
# Check for special characters that might cause issues
|
||||
if any(char in ontology_class.name for char in ['<', '>', '"', '{', '}', '|', '^', '`']):
|
||||
warnings.append(
|
||||
f"Class name '{ontology_class.name}' contains special characters "
|
||||
"that may cause issues in Protégé"
|
||||
)
|
||||
|
||||
# Check description length (Protégé can handle long descriptions but may display poorly)
|
||||
if ontology_class.description and len(ontology_class.description) > 1000:
|
||||
warnings.append(
|
||||
f"Class '{ontology_class.name}' has a very long description ({len(ontology_class.description)} chars) "
|
||||
"which may display poorly in Protégé"
|
||||
)
|
||||
|
||||
# Check for non-ASCII characters (Protégé supports them but encoding issues may occur)
|
||||
if not ontology_class.name.isascii():
|
||||
warnings.append(
|
||||
f"Class name '{ontology_class.name}' contains non-ASCII characters "
|
||||
"which may cause encoding issues in some Protégé versions"
|
||||
)
|
||||
|
||||
# If no warnings, it's compatible
|
||||
is_compatible = len(warnings) == 0
|
||||
|
||||
return is_compatible, warnings
|
||||
@@ -285,7 +285,7 @@ models:
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen-vl-max
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen-vl-max多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
@@ -298,7 +298,7 @@ models:
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen-vl-plus-0809
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen-vl-plus-0809多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,已废弃
|
||||
is_deprecated: true
|
||||
@@ -311,7 +311,7 @@ models:
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen-vl-plus-2025-01-02
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen-vl-plus-2025-01-02多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
@@ -324,7 +324,7 @@ models:
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen-vl-plus-2025-01-25
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen-vl-plus-2025-01-25多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
@@ -337,7 +337,7 @@ models:
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen-vl-plus-latest
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen-vl-plus-latest多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
@@ -350,7 +350,7 @@ models:
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen-vl-plus
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen-vl-plus多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
@@ -616,7 +616,7 @@ models:
|
||||
- audio
|
||||
logo: dashscope
|
||||
- name: qwen3-vl-235b-a22b-instruct
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen3-vl-235b-a22b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
@@ -631,7 +631,7 @@ models:
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen3-vl-235b-a22b-thinking
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen3-vl-235b-a22b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
@@ -646,7 +646,7 @@ models:
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen3-vl-30b-a3b-instruct
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen3-vl-30b-a3b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
@@ -661,7 +661,7 @@ models:
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen3-vl-30b-a3b-thinking
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen3-vl-30b-a3b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
@@ -676,7 +676,7 @@ models:
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen3-vl-flash
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen3-vl-flash多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
@@ -691,7 +691,7 @@ models:
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen3-vl-plus-2025-09-23
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen3-vl-plus-2025-09-23多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
@@ -704,7 +704,7 @@ models:
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen3-vl-plus
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen3-vl-plus多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
|
||||
@@ -28,7 +28,9 @@ from app.core.rag.common.float_utils import get_float
|
||||
from app.core.rag.common.constants import PAGERANK_FLD, TAG_FLD
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def knowledge_retrieval(
|
||||
query: str,
|
||||
@@ -62,7 +64,15 @@ def knowledge_retrieval(
|
||||
merge_strategy = config.get("merge_strategy", "weight")
|
||||
reranker_id = config.get("reranker_id")
|
||||
reranker_top_k = config.get("reranker_top_k", 1024)
|
||||
use_graph = config.get("use_graph", "false").lower() == "true"
|
||||
# use_graph = config.get("use_graph", "false").lower() == "true"
|
||||
|
||||
use_graph_value = config.get("use_graph", False)
|
||||
if isinstance(use_graph_value, bool):
|
||||
use_graph = use_graph_value
|
||||
elif isinstance(use_graph_value, str):
|
||||
use_graph = use_graph_value.lower() in ("true", "1", "yes")
|
||||
else:
|
||||
use_graph = False
|
||||
|
||||
file_names_filter = []
|
||||
if user_ids:
|
||||
@@ -159,13 +169,29 @@ def knowledge_retrieval(
|
||||
|
||||
# Use the specified reranker for re-ranking
|
||||
if reranker_id:
|
||||
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
|
||||
# use graph
|
||||
try:
|
||||
return rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
|
||||
except Exception as rerank_error:
|
||||
# If reranker fails, log warning and continue with original results
|
||||
logger.warning(
|
||||
"Reranker failed, falling back to original results",
|
||||
extra={
|
||||
"reranker_id": reranker_id,
|
||||
"query": query,
|
||||
"doc_count": len(all_results),
|
||||
"error": str(rerank_error),
|
||||
},
|
||||
)
|
||||
|
||||
if use_graph:
|
||||
from app.core.rag.common.settings import kg_retriever
|
||||
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||
if doc:
|
||||
all_results.insert(0, doc)
|
||||
try:
|
||||
from app.core.rag.common.settings import kg_retriever
|
||||
doc = kg_retriever.retrieval(question=query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||
if doc:
|
||||
all_results.insert(0, doc)
|
||||
except Exception as graph_error:
|
||||
print(f"Failed to retrieve from knowledge graph: {str(graph_error)}")
|
||||
|
||||
return all_results
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -44,7 +44,7 @@ class CodeNodeConfig(BaseNodeConfig):
|
||||
description="code content"
|
||||
)
|
||||
|
||||
language: Literal['python3', 'nodejs'] = Field(
|
||||
language: Literal['python3', 'javascript'] = Field(
|
||||
...,
|
||||
description="language"
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import urllib.parse
|
||||
from string import Template
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
@@ -14,7 +15,7 @@ from app.core.workflow.nodes.code.config import CodeNodeConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SCRIPT_TEMPLATE = Template(dedent("""
|
||||
PYTHON_SCRIPT_TEMPLATE = Template(dedent("""
|
||||
$code
|
||||
|
||||
import json
|
||||
@@ -32,6 +33,20 @@ result = "<<RESULT>>" + output_json + "<<RESULT>>"
|
||||
print(result)
|
||||
"""))
|
||||
|
||||
NODEJS_SCRIPT_TEMPLATE = Template(dedent("""
|
||||
$code
|
||||
// decode and prepare input object
|
||||
var inputs_obj = JSON.parse(Buffer.from('$inputs_variable', 'base64').toString('utf-8'))
|
||||
|
||||
// execute main function
|
||||
var output_obj = main(inputs_obj)
|
||||
|
||||
// convert output to json and print
|
||||
var output_json = JSON.stringify(output_obj)
|
||||
var result = `<<RESULT>>$${output_json}<<RESULT>>`
|
||||
console.log(result)
|
||||
"""))
|
||||
|
||||
|
||||
class CodeNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
@@ -83,18 +98,27 @@ class CodeNode(BaseNode):
|
||||
input_variable_dict = {}
|
||||
for input_variable in self.typed_config.input_variables:
|
||||
input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, state)
|
||||
|
||||
code = base64.b64decode(
|
||||
self.typed_config.code
|
||||
).decode("utf-8")
|
||||
code = urllib.parse.unquote(code, encoding='utf-8')
|
||||
|
||||
input_variable_dict = base64.b64encode(
|
||||
json.dumps(input_variable_dict).encode("utf-8")
|
||||
).decode("utf-8")
|
||||
|
||||
final_script = SCRIPT_TEMPLATE.substitute(
|
||||
code=code,
|
||||
inputs_variable=input_variable_dict,
|
||||
)
|
||||
if self.typed_config.language == "python3":
|
||||
final_script = PYTHON_SCRIPT_TEMPLATE.substitute(
|
||||
code=code,
|
||||
inputs_variable=input_variable_dict,
|
||||
)
|
||||
elif self.typed_config.language == 'javascript':
|
||||
final_script = NODEJS_SCRIPT_TEMPLATE.substitute(
|
||||
code=code,
|
||||
inputs_variable=input_variable_dict,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported language: {self.typed_config.language}")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
|
||||
@@ -23,6 +23,18 @@ class ParameterExtractorNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: ParameterExtractorNodeConfig | None = None
|
||||
self.response_metadata = {}
|
||||
|
||||
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||
if self.response_metadata:
|
||||
usage = self.response_metadata.get('token_usage')
|
||||
if usage:
|
||||
return {
|
||||
"prompt_tokens": usage.get('prompt_tokens', 0),
|
||||
"completion_tokens": usage.get('completion_tokens', 0),
|
||||
"total_tokens": usage.get('total_tokens', 0)
|
||||
}
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_prompt():
|
||||
@@ -171,6 +183,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
])
|
||||
|
||||
model_resp = await llm.ainvoke(messages)
|
||||
self.response_metadata = model_resp.response_metadata
|
||||
result = json_repair.repair_json(model_resp.content, return_objects=True)
|
||||
logger.info(f"node: {self.node_id} get params:{result}")
|
||||
|
||||
|
||||
@@ -23,6 +23,18 @@ class QuestionClassifierNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: QuestionClassifierNodeConfig | None = None
|
||||
self.category_to_case_map = {}
|
||||
self.response_metadata = {}
|
||||
|
||||
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||
if self.response_metadata:
|
||||
usage = self.response_metadata.get('token_usage')
|
||||
if usage:
|
||||
return {
|
||||
"prompt_tokens": usage.get('prompt_tokens', 0),
|
||||
"completion_tokens": usage.get('completion_tokens', 0),
|
||||
"total_tokens": usage.get('total_tokens', 0)
|
||||
}
|
||||
return None
|
||||
|
||||
def _get_llm_instance(self) -> RedBearLLM:
|
||||
"""获取LLM实例"""
|
||||
@@ -112,6 +124,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
|
||||
response = await llm.ainvoke(messages)
|
||||
result = response.content.strip()
|
||||
self.response_metadata = response.response_metadata
|
||||
|
||||
if result in category_names:
|
||||
category = result
|
||||
|
||||
@@ -4,16 +4,19 @@
|
||||
从文件系统加载预定义的工作流模板
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
|
||||
TEMPLATE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'templates')
|
||||
|
||||
|
||||
class TemplateLoader:
|
||||
"""工作流模板加载器"""
|
||||
|
||||
def __init__(self, templates_dir: str = "app/templates/workflows"):
|
||||
|
||||
def __init__(self, templates_dir: str = TEMPLATE_DIR):
|
||||
"""初始化模板加载器
|
||||
|
||||
Args:
|
||||
@@ -22,7 +25,7 @@ class TemplateLoader:
|
||||
self.templates_dir = Path(templates_dir)
|
||||
if not self.templates_dir.exists():
|
||||
raise ValueError(f"模板目录不存在: {templates_dir}")
|
||||
|
||||
|
||||
def list_templates(self) -> list[dict]:
|
||||
"""列出所有可用的模板
|
||||
|
||||
@@ -30,22 +33,22 @@ class TemplateLoader:
|
||||
模板列表,每个模板包含 id, name, description 等信息
|
||||
"""
|
||||
templates = []
|
||||
|
||||
|
||||
# 遍历模板目录
|
||||
for template_dir in self.templates_dir.iterdir():
|
||||
if not template_dir.is_dir():
|
||||
continue
|
||||
|
||||
|
||||
# 检查是否有 template.yml 文件
|
||||
template_file = template_dir / "template.yml"
|
||||
if not template_file.exists():
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
# 读取模板配置
|
||||
with open(template_file, 'r', encoding='utf-8') as f:
|
||||
template_data = yaml.safe_load(f)
|
||||
|
||||
|
||||
# 提取模板信息
|
||||
templates.append({
|
||||
"id": template_dir.name,
|
||||
@@ -59,9 +62,9 @@ class TemplateLoader:
|
||||
except Exception as e:
|
||||
print(f"加载模板 {template_dir.name} 失败: {e}")
|
||||
continue
|
||||
|
||||
|
||||
return templates
|
||||
|
||||
|
||||
def load_template(self, template_id: str) -> Optional[dict]:
|
||||
"""加载指定的模板
|
||||
|
||||
@@ -73,14 +76,14 @@ class TemplateLoader:
|
||||
"""
|
||||
template_dir = self.templates_dir / template_id
|
||||
template_file = template_dir / "template.yml"
|
||||
|
||||
|
||||
if not template_file.exists():
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
with open(template_file, 'r', encoding='utf-8') as f:
|
||||
template_data = yaml.safe_load(f)
|
||||
|
||||
|
||||
# 返回工作流配置部分
|
||||
return {
|
||||
"name": template_data.get("name", template_id),
|
||||
@@ -94,7 +97,7 @@ class TemplateLoader:
|
||||
except Exception as e:
|
||||
print(f"加载模板 {template_id} 失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_template_readme(self, template_id: str) -> Optional[str]:
|
||||
"""获取模板的 README 文档
|
||||
|
||||
@@ -106,10 +109,10 @@ class TemplateLoader:
|
||||
"""
|
||||
template_dir = self.templates_dir / template_id
|
||||
readme_file = template_dir / "README.md"
|
||||
|
||||
|
||||
if not readme_file.exists():
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
with open(readme_file, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
|
||||
@@ -28,6 +28,10 @@ from .tool_model import (
|
||||
ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus
|
||||
)
|
||||
from .memory_perceptual_model import MemoryPerceptualModel
|
||||
from .ontology_scene import OntologyScene
|
||||
from .ontology_class import OntologyClass
|
||||
from .ontology_scene import OntologyScene
|
||||
from .ontology_class import OntologyClass
|
||||
|
||||
__all__ = [
|
||||
"Tenants",
|
||||
|
||||
@@ -20,6 +20,9 @@ class MemoryConfig(Base):
|
||||
end_user_id = Column(String, nullable=True, comment="组ID")
|
||||
user_id = Column(String, nullable=True, comment="用户ID")
|
||||
apply_id = Column(String, nullable=True, comment="应用ID")
|
||||
|
||||
# 本体场景关联
|
||||
scene_id = Column(UUID(as_uuid=True), nullable=True, comment="本体场景ID,关联ontology_scene表")
|
||||
|
||||
# 模型选择(从workspace继承)
|
||||
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
|
||||
|
||||
40
api/app/models/ontology_class.py
Normal file
40
api/app/models/ontology_class.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""本体类型模型
|
||||
|
||||
本模块定义本体类型的数据模型。
|
||||
|
||||
Classes:
|
||||
OntologyClass: 本体类型表模型
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, DateTime, Text, ForeignKey
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db import Base
|
||||
|
||||
|
||||
class OntologyClass(Base):
|
||||
"""本体类型表 - 用于存储某个场景提取出来的本体类型信息"""
|
||||
__tablename__ = "ontology_class"
|
||||
|
||||
# 主键
|
||||
class_id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="类型ID")
|
||||
|
||||
# 类型信息
|
||||
class_name = Column(String(200), nullable=False, comment="类型名称")
|
||||
class_description = Column(Text, nullable=True, comment="类型描述")
|
||||
|
||||
# 外键:关联到本体场景
|
||||
scene_id = Column(UUID(as_uuid=True), ForeignKey("ontology_scene.scene_id", ondelete="CASCADE"), nullable=False, index=True, comment="所属场景ID")
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, comment="创建时间")
|
||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, nullable=False, comment="更新时间")
|
||||
|
||||
# 关系:类型属于某个场景
|
||||
scene = relationship("OntologyScene", back_populates="classes")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<OntologyClass(id={self.class_id}, name={self.class_name}, scene_id={self.scene_id})>"
|
||||
43
api/app/models/ontology_scene.py
Normal file
43
api/app/models/ontology_scene.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""本体场景模型
|
||||
|
||||
本模块定义本体场景的数据模型。
|
||||
|
||||
Classes:
|
||||
OntologyScene: 本体场景表模型
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, DateTime, Integer, Text, ForeignKey, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db import Base
|
||||
|
||||
|
||||
class OntologyScene(Base):
|
||||
"""本体场景表 - 用于存储本体场景下不同的类型信息"""
|
||||
__tablename__ = "ontology_scene"
|
||||
__table_args__ = (
|
||||
UniqueConstraint('workspace_id', 'scene_name', name='uq_workspace_scene_name'),
|
||||
)
|
||||
|
||||
# 主键
|
||||
scene_id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="场景ID")
|
||||
|
||||
# 场景信息
|
||||
scene_name = Column(String(200), nullable=False, comment="场景名称")
|
||||
scene_description = Column(Text, nullable=True, comment="场景描述")
|
||||
|
||||
# 外键:关联到工作空间
|
||||
workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="CASCADE"), nullable=False, index=True, comment="所属工作空间ID")
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, comment="创建时间")
|
||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, nullable=False, comment="更新时间")
|
||||
|
||||
# 关系:一个场景可以有多个类型
|
||||
classes = relationship("OntologyClass", back_populates="scene", cascade="all, delete-orphan")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<OntologyScene(id={self.scene_id}, name={self.scene_name})>"
|
||||
@@ -2,7 +2,7 @@ import datetime
|
||||
import uuid
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlalchemy import Column, ForeignKey, Text, DateTime, String, Index
|
||||
from sqlalchemy import Column, ForeignKey, Text, DateTime, String, Index, Boolean
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from app.db import Base
|
||||
@@ -121,10 +121,33 @@ class PromptOptimizerSessionHistory(Base):
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, comment="Tenant ID")
|
||||
# app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False, comment="Application ID")
|
||||
session_id = Column(UUID(as_uuid=True), ForeignKey("prompt_opt_session_list.id"),nullable=False, comment="Session ID")
|
||||
session_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("prompt_opt_session_list.id"),
|
||||
nullable=False,
|
||||
comment="Session ID"
|
||||
)
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, comment="User ID")
|
||||
role = Column(String, nullable=False, comment="Message Role")
|
||||
content = Column(Text, nullable=False, comment="Message Content")
|
||||
# prompt = Column(Text, nullable=False, comment="Prompt")
|
||||
|
||||
created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation Time", index=True)
|
||||
|
||||
|
||||
class PromptHistory(Base):
|
||||
__tablename__ = "prompt_history"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, comment="Tenant ID")
|
||||
|
||||
session_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("prompt_opt_session_list.id"),
|
||||
nullable=False,
|
||||
comment="Session ID"
|
||||
)
|
||||
title = Column(String, nullable=False, comment="Title")
|
||||
prompt = Column(Text, nullable=False, comment="Prompt")
|
||||
created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation Time", index=True)
|
||||
is_delete = Column(Boolean, default=False, comment="Delete")
|
||||
|
||||
@@ -86,7 +86,8 @@ class MemoryConfigRepository:
|
||||
n.description AS description,
|
||||
n.entity_type AS entity_type,
|
||||
n.name AS name,
|
||||
COALESCE(n.fact_summary, '') AS fact_summary,
|
||||
// TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
// COALESCE(n.fact_summary, '') AS fact_summary,
|
||||
n.end_user_id AS end_user_id,
|
||||
n.apply_id AS apply_id,
|
||||
n.user_id AS user_id,
|
||||
@@ -156,7 +157,7 @@ class MemoryConfigRepository:
|
||||
return memory_config_obj
|
||||
|
||||
@staticmethod
|
||||
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID) -> MemoryConfig:
|
||||
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID|int|str) -> MemoryConfig:
|
||||
"""构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数)
|
||||
|
||||
Args:
|
||||
@@ -172,7 +173,6 @@ class MemoryConfigRepository:
|
||||
if not memory_config:
|
||||
raise RuntimeError("reflection config not found")
|
||||
return memory_config
|
||||
|
||||
@staticmethod
|
||||
def query_reflection_config_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> MemoryConfig:
|
||||
"""构建查询所有配置的语句(SQLAlchemy text() 命名参数)
|
||||
@@ -231,9 +231,12 @@ class MemoryConfigRepository:
|
||||
config_name=params.config_name,
|
||||
config_desc=params.config_desc,
|
||||
workspace_id=params.workspace_id,
|
||||
scene_id=params.scene_id,
|
||||
llm_id=params.llm_id,
|
||||
embedding_id=params.embedding_id,
|
||||
rerank_id=params.rerank_id,
|
||||
reflection_model_id=params.reflection_model_id,
|
||||
emotion_model_id=params.emotion_model_id,
|
||||
)
|
||||
db.add(db_config)
|
||||
db.flush() # 获取自增ID但不提交事务
|
||||
@@ -276,6 +279,9 @@ class MemoryConfigRepository:
|
||||
if update.config_desc is not None:
|
||||
db_config.config_desc = update.config_desc
|
||||
has_update = True
|
||||
if update.scene_id is not None:
|
||||
db_config.scene_id = update.scene_id
|
||||
has_update = True
|
||||
|
||||
if not has_update:
|
||||
raise ValueError("No fields to update")
|
||||
@@ -528,7 +534,6 @@ class MemoryConfigRepository:
|
||||
|
||||
Returns:
|
||||
Optional[tuple]: (MemoryConfig, Workspace) tuple, None if not found
|
||||
|
||||
Raises:
|
||||
ValueError: Raised when config exists but workspace doesn't
|
||||
"""
|
||||
@@ -555,7 +560,6 @@ class MemoryConfigRepository:
|
||||
result = db.query(MemoryConfig, Workspace).join(
|
||||
Workspace, MemoryConfig.workspace_id == Workspace.id
|
||||
).filter(MemoryConfig.config_id == config_id).first()
|
||||
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
if not result:
|
||||
@@ -642,33 +646,36 @@ class MemoryConfigRepository:
|
||||
},
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
db_logger.error(f"Failed to query memory config and workspace: config_id={config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[MemoryConfig]:
|
||||
"""获取所有配置参数
|
||||
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[Tuple[MemoryConfig, Optional[str]]]:
|
||||
"""获取所有配置参数,包含关联的场景名称
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
workspace_id: 工作空间ID,用于过滤查询结果
|
||||
|
||||
Returns:
|
||||
List[MemoryConfig]: 配置列表
|
||||
List[Tuple[MemoryConfig, Optional[str]]]: 配置列表,每项为 (配置对象, 场景名称)
|
||||
"""
|
||||
from app.models.ontology_scene import OntologyScene
|
||||
|
||||
db_logger.debug(f"查询所有配置: workspace_id={workspace_id}")
|
||||
|
||||
try:
|
||||
query = db.query(MemoryConfig)
|
||||
query = db.query(MemoryConfig, OntologyScene.scene_name).outerjoin(
|
||||
OntologyScene, MemoryConfig.scene_id == OntologyScene.scene_id
|
||||
)
|
||||
|
||||
if workspace_id:
|
||||
query = query.filter(MemoryConfig.workspace_id == workspace_id)
|
||||
|
||||
configs = query.order_by(desc(MemoryConfig.updated_at)).all()
|
||||
results = query.order_by(desc(MemoryConfig.updated_at)).all()
|
||||
|
||||
db_logger.debug(f"配置列表查询成功: 数量={len(configs)}")
|
||||
return configs
|
||||
db_logger.debug(f"配置列表查询成功: 数量={len(results)}")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
db_logger.error(f"查询所有配置失败: workspace_id={workspace_id} - {str(e)}")
|
||||
|
||||
@@ -79,7 +79,8 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
|
||||
try:
|
||||
edges: List[dict] = []
|
||||
for s in summaries:
|
||||
for chunk_id in getattr(s, "chunk_ids", []) or []:
|
||||
chunk_ids = getattr(s, "chunk_ids", []) or []
|
||||
for chunk_id in chunk_ids:
|
||||
edges.append({
|
||||
"summary_id": s.id,
|
||||
"chunk_id": chunk_id,
|
||||
@@ -91,12 +92,11 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode],
|
||||
|
||||
if not edges:
|
||||
return []
|
||||
|
||||
result = await connector.execute_query(
|
||||
MEMORY_SUMMARY_STATEMENT_EDGE_SAVE,
|
||||
edges=edges
|
||||
)
|
||||
created = [record.get("uuid") for record in result] if result else []
|
||||
return created
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
@@ -217,8 +217,10 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
||||
summaries=flattened
|
||||
)
|
||||
created_ids = [record.get("uuid") for record in result]
|
||||
print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
|
||||
return created_ids
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
print(f"Failed to save MemorySummary nodes to Neo4j: {e}")
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -101,10 +101,11 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
|
||||
e.name_embedding = CASE
|
||||
WHEN entity.name_embedding IS NOT NULL AND size(entity.name_embedding) > 0 THEN entity.name_embedding
|
||||
ELSE e.name_embedding END,
|
||||
e.fact_summary = CASE
|
||||
WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> ''
|
||||
AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary))
|
||||
THEN entity.fact_summary ELSE e.fact_summary END,
|
||||
// TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
// e.fact_summary = CASE
|
||||
// WHEN entity.fact_summary IS NOT NULL AND entity.fact_summary <> ''
|
||||
// AND (e.fact_summary IS NULL OR size(e.fact_summary) = 0 OR size(entity.fact_summary) > size(e.fact_summary))
|
||||
// THEN entity.fact_summary ELSE e.fact_summary END,
|
||||
e.connect_strength = CASE
|
||||
WHEN entity.connect_strength IS NULL OR entity.connect_strength = '' THEN e.connect_strength
|
||||
ELSE CASE
|
||||
@@ -321,7 +322,8 @@ RETURN e.id AS id,
|
||||
e.description AS description,
|
||||
e.aliases AS aliases,
|
||||
e.name_embedding AS name_embedding,
|
||||
COALESCE(e.fact_summary, '') AS fact_summary,
|
||||
// TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
// COALESCE(e.fact_summary, '') AS fact_summary,
|
||||
e.connect_strength AS connect_strength,
|
||||
collect(DISTINCT s.id) AS statement_ids,
|
||||
collect(DISTINCT c.id) AS chunk_ids,
|
||||
@@ -877,7 +879,8 @@ RETURN
|
||||
CASE
|
||||
WHEN ms:ExtractedEntity THEN {
|
||||
text: ms.name,
|
||||
created_at: ms.created_at
|
||||
created_at: ms.created_at,
|
||||
type: "情景记忆"
|
||||
}
|
||||
END
|
||||
) AS ExtractedEntity,
|
||||
@@ -887,7 +890,8 @@ RETURN
|
||||
CASE
|
||||
WHEN n:MemorySummary THEN {
|
||||
text: n.content,
|
||||
created_at: n.created_at
|
||||
created_at: n.created_at,
|
||||
type: "长期沉淀"
|
||||
}
|
||||
END
|
||||
) AS MemorySummary,
|
||||
@@ -895,7 +899,8 @@ RETURN
|
||||
collect(
|
||||
DISTINCT {
|
||||
text: e.statement,
|
||||
created_at: e.created_at
|
||||
created_at: e.created_at,
|
||||
type: "情绪记忆"
|
||||
}
|
||||
) AS statement;
|
||||
"""
|
||||
@@ -999,3 +1004,58 @@ RETURN DISTINCT
|
||||
x.statement as statement,x.created_at as created_at
|
||||
"""
|
||||
|
||||
Graph_Node_query = """
|
||||
MATCH (n:MemorySummary)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
0 AS priority
|
||||
LIMIT $limit
|
||||
|
||||
UNION ALL
|
||||
|
||||
MATCH (n:Dialogue)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
1 AS priority
|
||||
LIMIT 1
|
||||
|
||||
UNION ALL
|
||||
|
||||
MATCH (n:Statement)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
1 AS priority
|
||||
LIMIT $limit
|
||||
|
||||
UNION ALL
|
||||
|
||||
MATCH (n:ExtractedEntity)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
2 AS priority
|
||||
LIMIT $limit
|
||||
|
||||
UNION ALL
|
||||
|
||||
MATCH (n:Chunk)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) AS id,
|
||||
labels(n) AS labels,
|
||||
properties(n) AS properties,
|
||||
3 AS priority
|
||||
LIMIT $limit
|
||||
|
||||
"""
|
||||
@@ -21,7 +21,8 @@ from app.core.memory.models.graph_models import (
|
||||
ExtractedEntityNode,
|
||||
EntityEntityEdge,
|
||||
)
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
async def save_entities_and_relationships(
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
@@ -41,8 +42,8 @@ async def save_entities_and_relationships(
|
||||
'statement': edge.statement,
|
||||
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
|
||||
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
|
||||
'created_at': edge.created_at.isoformat(),
|
||||
'expired_at': edge.expired_at.isoformat(),
|
||||
'created_at': edge.created_at.isoformat() if edge.created_at else None,
|
||||
'expired_at': edge.expired_at.isoformat() if edge.expired_at else None,
|
||||
'run_id': edge.run_id,
|
||||
'end_user_id': edge.end_user_id,
|
||||
}
|
||||
@@ -147,14 +148,14 @@ async def save_statement_entity_edges(
|
||||
|
||||
|
||||
async def save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes: List[DialogueNode],
|
||||
chunk_nodes: List[ChunkNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
entity_edges: List[EntityEntityEdge],
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
connector: Neo4jConnector
|
||||
dialogue_nodes: List[DialogueNode],
|
||||
chunk_nodes: List[ChunkNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
entity_edges: List[EntityEntityEdge],
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
connector: Neo4jConnector
|
||||
) -> bool:
|
||||
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
||||
|
||||
@@ -171,40 +172,127 @@ async def save_dialog_and_statements_to_neo4j(
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Save all dialogue nodes in batch
|
||||
dialogue_uuids = await add_dialogue_nodes(dialogue_nodes, connector)
|
||||
if dialogue_uuids:
|
||||
|
||||
# 定义事务函数,将所有写操作放在一个事务中
|
||||
async def _save_all_in_transaction(tx):
|
||||
"""在单个事务中执行所有保存操作,避免死锁"""
|
||||
results = {}
|
||||
|
||||
# 1. Save all dialogue nodes in batch
|
||||
if dialogue_nodes:
|
||||
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE
|
||||
dialogue_data = [node.model_dump() for node in dialogue_nodes]
|
||||
result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data)
|
||||
dialogue_uuids = [record["uuid"] async for record in result]
|
||||
results['dialogues'] = dialogue_uuids
|
||||
print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
|
||||
else:
|
||||
print("Failed to save dialogues to Neo4j")
|
||||
return False
|
||||
|
||||
# Save all chunk nodes in batch
|
||||
await save_chunk_nodes(chunk_nodes, connector)
|
||||
# 2. Save all chunk nodes in batch
|
||||
if chunk_nodes:
|
||||
from app.repositories.neo4j.cypher_queries import CHUNK_NODE_SAVE
|
||||
chunk_data = [node.model_dump() for node in chunk_nodes]
|
||||
result = await tx.run(CHUNK_NODE_SAVE, chunks=chunk_data)
|
||||
chunk_uuids = [record["uuid"] async for record in result]
|
||||
results['chunks'] = chunk_uuids
|
||||
logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j")
|
||||
|
||||
# Save all statement nodes in batch
|
||||
# 3. Save all statement nodes in batch
|
||||
if statement_nodes:
|
||||
statement_uuids = await add_statement_nodes(statement_nodes, connector)
|
||||
if statement_uuids:
|
||||
print(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j")
|
||||
else:
|
||||
print("Failed to save statement nodes to Neo4j")
|
||||
return False
|
||||
else:
|
||||
print("No statement nodes to save")
|
||||
from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE
|
||||
statement_data = [node.model_dump() for node in statement_nodes]
|
||||
result = await tx.run(STATEMENT_NODE_SAVE, statements=statement_data)
|
||||
statement_uuids = [record["uuid"] async for record in result]
|
||||
results['statements'] = statement_uuids
|
||||
logger.info(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j")
|
||||
|
||||
# Save entities and relationships
|
||||
await save_entities_and_relationships(entity_nodes, entity_edges, connector)
|
||||
print("Successfully saved entities and relationships to Neo4j")
|
||||
# 4. Save entities
|
||||
if entity_nodes:
|
||||
from app.repositories.neo4j.cypher_queries import EXTRACTED_ENTITY_NODE_SAVE
|
||||
entity_data = [entity.model_dump() for entity in entity_nodes]
|
||||
result = await tx.run(EXTRACTED_ENTITY_NODE_SAVE, entities=entity_data)
|
||||
entity_uuids = [record["uuid"] async for record in result]
|
||||
results['entities'] = entity_uuids
|
||||
logger.info(f"Successfully saved {len(entity_uuids)} entity nodes to Neo4j")
|
||||
|
||||
# Save new edges
|
||||
await save_statement_chunk_edges(statement_chunk_edges, connector)
|
||||
await save_statement_entity_edges(statement_entity_edges, connector)
|
||||
# 5. Create entity relationships
|
||||
if entity_edges:
|
||||
from app.repositories.neo4j.cypher_queries import ENTITY_RELATIONSHIP_SAVE
|
||||
relationship_data = []
|
||||
for edge in entity_edges:
|
||||
relationship_data.append({
|
||||
'source_id': edge.source,
|
||||
'target_id': edge.target,
|
||||
'predicate': edge.relation_type,
|
||||
'statement_id': edge.source_statement_id,
|
||||
'value': edge.relation_value,
|
||||
'statement': edge.statement,
|
||||
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
|
||||
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
|
||||
'created_at': edge.created_at.isoformat() if edge.created_at else None,
|
||||
'expired_at': edge.expired_at.isoformat() if edge.expired_at else None,
|
||||
'run_id': edge.run_id,
|
||||
'end_user_id': edge.end_user_id,
|
||||
})
|
||||
result = await tx.run(ENTITY_RELATIONSHIP_SAVE, relationships=relationship_data)
|
||||
rel_uuids = [record["uuid"] async for record in result]
|
||||
results['entity_relationships'] = rel_uuids
|
||||
logger.info(f"Successfully saved {len(rel_uuids)} entity relationships to Neo4j")
|
||||
|
||||
# 6. Save statement-chunk edges
|
||||
if statement_chunk_edges:
|
||||
from app.repositories.neo4j.cypher_queries import CHUNK_STATEMENT_EDGE_SAVE
|
||||
sc_edge_data = []
|
||||
for edge in statement_chunk_edges:
|
||||
sc_edge_data.append({
|
||||
"id": edge.id,
|
||||
"source": edge.source,
|
||||
"target": edge.target,
|
||||
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
||||
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
|
||||
"run_id": edge.run_id,
|
||||
"end_user_id": edge.end_user_id,
|
||||
})
|
||||
result = await tx.run(CHUNK_STATEMENT_EDGE_SAVE, chunk_statement_edges=sc_edge_data)
|
||||
sc_uuids = [record["uuid"] async for record in result]
|
||||
results['statement_chunk_edges'] = sc_uuids
|
||||
logger.info(f"Successfully saved {len(sc_uuids)} statement-chunk edges to Neo4j")
|
||||
|
||||
# 7. Save statement-entity edges
|
||||
if statement_entity_edges:
|
||||
from app.repositories.neo4j.cypher_queries import STATEMENT_ENTITY_EDGE_SAVE
|
||||
se_edge_data = []
|
||||
for edge in statement_entity_edges:
|
||||
se_edge_data.append({
|
||||
"source": edge.source,
|
||||
"target": edge.target,
|
||||
"created_at": edge.created_at.isoformat() if edge.created_at else None,
|
||||
"expired_at": edge.expired_at.isoformat() if edge.expired_at else None,
|
||||
"run_id": edge.run_id,
|
||||
"end_user_id": edge.end_user_id,
|
||||
"connect_strength": getattr(edge, "connect_strength", "strong"),
|
||||
})
|
||||
result = await tx.run(STATEMENT_ENTITY_EDGE_SAVE, relationships=se_edge_data)
|
||||
se_uuids = [record["uuid"] async for record in result]
|
||||
results['statement_entity_edges'] = se_uuids
|
||||
logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j")
|
||||
|
||||
return results
|
||||
|
||||
try:
|
||||
# 使用显式写事务执行所有操作,避免死锁
|
||||
results = await connector.execute_write_transaction(_save_all_in_transaction)
|
||||
summary = {
|
||||
key: len(value)
|
||||
for key, value in results.items()
|
||||
if isinstance(value, (list, tuple, set))
|
||||
}
|
||||
logger.info("Transaction completed. Summary: %s", summary)
|
||||
logger.debug("Full transaction results: %r", results)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Neo4j integration error: {e}", exc_info=True)
|
||||
print(f"Neo4j integration error: {e}")
|
||||
print("Continuing without database storage...")
|
||||
return False
|
||||
|
||||
|
||||
404
api/app/repositories/ontology_class_repository.py
Normal file
404
api/app/repositories/ontology_class_repository.py
Normal file
@@ -0,0 +1,404 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""本体类型Repository层
|
||||
|
||||
本模块提供本体类型的数据访问层实现。
|
||||
|
||||
Classes:
|
||||
OntologyClassRepository: 本体类型数据访问类
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
from app.models.ontology_class import OntologyClass
|
||||
from app.models.ontology_scene import OntologyScene
|
||||
|
||||
|
||||
logger = get_db_logger()
|
||||
|
||||
|
||||
class OntologyClassRepository:
|
||||
"""本体类型Repository
|
||||
|
||||
提供本体类型的CRUD操作和权限检查。
|
||||
|
||||
Attributes:
|
||||
db: SQLAlchemy数据库会话
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
"""初始化Repository
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy数据库会话
|
||||
"""
|
||||
self.db = db
|
||||
|
||||
def create(self, class_data: dict, scene_id: UUID) -> OntologyClass:
|
||||
"""创建本体类型
|
||||
|
||||
Args:
|
||||
class_data: 类型数据字典,包含class_name和class_description
|
||||
scene_id: 所属场景ID
|
||||
|
||||
Returns:
|
||||
OntologyClass: 创建的类型对象
|
||||
|
||||
Raises:
|
||||
Exception: 数据库操作失败
|
||||
|
||||
Examples:
|
||||
>>> repo = OntologyClassRepository(db)
|
||||
>>> ontology_class = repo.create(
|
||||
... {"class_name": "患者", "class_description": "描述"},
|
||||
... scene_id
|
||||
... )
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
f"Creating ontology class - "
|
||||
f"name={class_data.get('class_name')}, "
|
||||
f"scene_id={scene_id}"
|
||||
)
|
||||
|
||||
ontology_class = OntologyClass(
|
||||
class_name=class_data.get("class_name"),
|
||||
class_description=class_data.get("class_description"),
|
||||
scene_id=scene_id
|
||||
)
|
||||
|
||||
self.db.add(ontology_class)
|
||||
self.db.flush() # 获取ID但不提交
|
||||
|
||||
logger.info(
|
||||
f"Ontology class created successfully - "
|
||||
f"class_id={ontology_class.class_id}"
|
||||
)
|
||||
|
||||
return ontology_class
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to create ontology class: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def get_by_id(self, class_id: UUID) -> Optional[OntologyClass]:
|
||||
"""根据ID获取类型
|
||||
|
||||
Args:
|
||||
class_id: 类型ID
|
||||
|
||||
Returns:
|
||||
Optional[OntologyClass]: 类型对象,不存在则返回None
|
||||
|
||||
Examples:
|
||||
>>> repo = OntologyClassRepository(db)
|
||||
>>> ontology_class = repo.get_by_id(class_id)
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"Getting ontology class by ID: {class_id}")
|
||||
|
||||
ontology_class = self.db.query(OntologyClass).filter(
|
||||
OntologyClass.class_id == class_id
|
||||
).first()
|
||||
|
||||
if ontology_class:
|
||||
logger.debug(f"Ontology class found: {class_id}")
|
||||
else:
|
||||
logger.debug(f"Ontology class not found: {class_id}")
|
||||
|
||||
return ontology_class
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get ontology class by ID: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def get_by_name(self, class_name: str, scene_id: UUID) -> Optional[OntologyClass]:
|
||||
"""根据类型名称和场景ID获取类型(精确匹配)
|
||||
|
||||
Args:
|
||||
class_name: 类型名称
|
||||
scene_id: 场景ID
|
||||
|
||||
Returns:
|
||||
Optional[OntologyClass]: 类型对象,不存在则返回None
|
||||
|
||||
Examples:
|
||||
>>> repo = OntologyClassRepository(db)
|
||||
>>> ontology_class = repo.get_by_name("患者", scene_id)
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"Getting ontology class by name: {class_name}, scene_id: {scene_id}")
|
||||
|
||||
ontology_class = self.db.query(OntologyClass).filter(
|
||||
OntologyClass.class_name == class_name,
|
||||
OntologyClass.scene_id == scene_id
|
||||
).first()
|
||||
|
||||
if ontology_class:
|
||||
logger.debug(f"Ontology class found: {class_name}")
|
||||
else:
|
||||
logger.debug(f"Ontology class not found: {class_name}")
|
||||
|
||||
return ontology_class
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get ontology class by name: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def search_by_name(self, keyword: str, scene_id: UUID) -> List[OntologyClass]:
|
||||
"""根据关键词模糊搜索类型
|
||||
|
||||
使用 LIKE 进行模糊匹配,支持中文和英文。
|
||||
|
||||
Args:
|
||||
keyword: 搜索关键词
|
||||
scene_id: 场景ID
|
||||
|
||||
Returns:
|
||||
List[OntologyClass]: 匹配的类型列表
|
||||
|
||||
Examples:
|
||||
>>> repo = OntologyClassRepository(db)
|
||||
>>> classes = repo.search_by_name("患者", scene_id)
|
||||
"""
|
||||
try:
|
||||
logger.debug(
|
||||
f"Searching ontology classes by keyword - "
|
||||
f"keyword={keyword}, scene_id={scene_id}"
|
||||
)
|
||||
|
||||
# 使用 ilike 进行不区分大小写的模糊匹配
|
||||
classes = self.db.query(OntologyClass).filter(
|
||||
OntologyClass.class_name.ilike(f"%{keyword}%"),
|
||||
OntologyClass.scene_id == scene_id
|
||||
).order_by(
|
||||
OntologyClass.created_at.desc()
|
||||
).all()
|
||||
|
||||
logger.info(
|
||||
f"Found {len(classes)} ontology classes matching keyword '{keyword}' "
|
||||
f"in scene {scene_id}"
|
||||
)
|
||||
|
||||
return classes
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to search ontology classes by keyword: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def get_by_scene(self, scene_id: UUID) -> List[OntologyClass]:
|
||||
"""获取场景下的所有类型
|
||||
|
||||
按创建时间倒序排列。
|
||||
|
||||
Args:
|
||||
scene_id: 场景ID
|
||||
|
||||
Returns:
|
||||
List[OntologyClass]: 类型列表
|
||||
|
||||
Examples:
|
||||
>>> repo = OntologyClassRepository(db)
|
||||
>>> classes = repo.get_by_scene(scene_id)
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"Getting ontology classes by scene: {scene_id}")
|
||||
|
||||
classes = self.db.query(OntologyClass).filter(
|
||||
OntologyClass.scene_id == scene_id
|
||||
).order_by(
|
||||
OntologyClass.created_at.desc()
|
||||
).all()
|
||||
|
||||
logger.info(
|
||||
f"Found {len(classes)} ontology classes in scene {scene_id}"
|
||||
)
|
||||
|
||||
return classes
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get ontology classes by scene: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def update(self, class_id: UUID, update_data: dict) -> Optional[OntologyClass]:
|
||||
"""更新类型信息
|
||||
|
||||
Args:
|
||||
class_id: 类型ID
|
||||
update_data: 更新数据字典
|
||||
|
||||
Returns:
|
||||
Optional[OntologyClass]: 更新后的类型对象,不存在则返回None
|
||||
|
||||
Raises:
|
||||
Exception: 数据库操作失败
|
||||
|
||||
Examples:
|
||||
>>> repo = OntologyClassRepository(db)
|
||||
>>> ontology_class = repo.update(
|
||||
... class_id,
|
||||
... {"class_name": "新名称"}
|
||||
... )
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Updating ontology class: {class_id}")
|
||||
|
||||
ontology_class = self.get_by_id(class_id)
|
||||
if not ontology_class:
|
||||
logger.warning(f"Ontology class not found for update: {class_id}")
|
||||
return None
|
||||
|
||||
# 更新字段
|
||||
if "class_name" in update_data and update_data["class_name"] is not None:
|
||||
ontology_class.class_name = update_data["class_name"]
|
||||
|
||||
if "class_description" in update_data:
|
||||
ontology_class.class_description = update_data["class_description"]
|
||||
|
||||
self.db.flush()
|
||||
|
||||
logger.info(f"Ontology class updated successfully: {class_id}")
|
||||
|
||||
return ontology_class
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to update ontology class: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def delete(self, class_id: UUID) -> bool:
|
||||
"""删除类型
|
||||
|
||||
Args:
|
||||
class_id: 类型ID
|
||||
|
||||
Returns:
|
||||
bool: 删除成功返回True,类型不存在返回False
|
||||
|
||||
Raises:
|
||||
Exception: 数据库操作失败
|
||||
|
||||
Examples:
|
||||
>>> repo = OntologyClassRepository(db)
|
||||
>>> success = repo.delete(class_id)
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Deleting ontology class: {class_id}")
|
||||
|
||||
ontology_class = self.get_by_id(class_id)
|
||||
if not ontology_class:
|
||||
logger.warning(f"Ontology class not found for delete: {class_id}")
|
||||
return False
|
||||
|
||||
self.db.delete(ontology_class)
|
||||
self.db.flush()
|
||||
|
||||
logger.info(f"Ontology class deleted successfully: {class_id}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete ontology class: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def check_ownership(self, class_id: UUID, workspace_id: UUID) -> bool:
|
||||
"""检查类型是否属于指定工作空间(通过场景关联)
|
||||
|
||||
Args:
|
||||
class_id: 类型ID
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
bool: 属于返回True,否则返回False
|
||||
|
||||
Examples:
|
||||
>>> repo = OntologyClassRepository(db)
|
||||
>>> is_owner = repo.check_ownership(class_id, workspace_id)
|
||||
"""
|
||||
try:
|
||||
logger.debug(
|
||||
f"Checking class ownership - "
|
||||
f"class_id={class_id}, workspace_id={workspace_id}"
|
||||
)
|
||||
|
||||
count = self.db.query(OntologyClass).join(
|
||||
OntologyScene,
|
||||
OntologyClass.scene_id == OntologyScene.scene_id
|
||||
).filter(
|
||||
OntologyClass.class_id == class_id,
|
||||
OntologyScene.workspace_id == workspace_id
|
||||
).count()
|
||||
|
||||
is_owner = count > 0
|
||||
|
||||
logger.debug(
|
||||
f"Class ownership check result: {is_owner} - "
|
||||
f"class_id={class_id}"
|
||||
)
|
||||
|
||||
return is_owner
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to check class ownership: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def get_scene_id_by_class(self, class_id: UUID) -> Optional[UUID]:
|
||||
"""根据类型ID获取所属场景ID
|
||||
|
||||
Args:
|
||||
class_id: 类型ID
|
||||
|
||||
Returns:
|
||||
Optional[UUID]: 场景ID,类型不存在则返回None
|
||||
|
||||
Examples:
|
||||
>>> repo = OntologyClassRepository(db)
|
||||
>>> scene_id = repo.get_scene_id_by_class(class_id)
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"Getting scene ID by class: {class_id}")
|
||||
|
||||
ontology_class = self.get_by_id(class_id)
|
||||
if not ontology_class:
|
||||
logger.debug(f"Class not found: {class_id}")
|
||||
return None
|
||||
|
||||
logger.debug(
|
||||
f"Found scene ID: {ontology_class.scene_id} for class: {class_id}"
|
||||
)
|
||||
|
||||
return ontology_class.scene_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get scene ID by class: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
439
api/app/repositories/ontology_scene_repository.py
Normal file
439
api/app/repositories/ontology_scene_repository.py
Normal file
@@ -0,0 +1,439 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""本体场景Repository层
|
||||
|
||||
本模块提供本体场景的数据访问层实现。
|
||||
|
||||
Classes:
|
||||
OntologySceneRepository: 本体场景数据访问类
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
from app.models.ontology_scene import OntologyScene
|
||||
|
||||
|
||||
logger = get_db_logger()
|
||||
|
||||
|
||||
class OntologySceneRepository:
|
||||
"""本体场景Repository
|
||||
|
||||
提供本体场景的CRUD操作和权限检查。
|
||||
|
||||
Attributes:
|
||||
db: SQLAlchemy数据库会话
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
"""初始化Repository
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy数据库会话
|
||||
"""
|
||||
self.db = db
|
||||
|
||||
def create(self, scene_data: dict, workspace_id: UUID) -> OntologyScene:
|
||||
"""创建本体场景
|
||||
|
||||
Args:
|
||||
scene_data: 场景数据字典,包含scene_name和scene_description
|
||||
workspace_id: 所属工作空间ID
|
||||
|
||||
Returns:
|
||||
OntologyScene: 创建的场景对象
|
||||
|
||||
Raises:
|
||||
Exception: 数据库操作失败
|
||||
|
||||
Examples:
|
||||
>>> repo = OntologySceneRepository(db)
|
||||
>>> scene = repo.create(
|
||||
... {"scene_name": "医疗场景", "scene_description": "描述"},
|
||||
... workspace_id
|
||||
... )
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
f"Creating ontology scene - "
|
||||
f"name={scene_data.get('scene_name')}, "
|
||||
f"workspace_id={workspace_id}"
|
||||
)
|
||||
|
||||
scene = OntologyScene(
|
||||
scene_name=scene_data.get("scene_name"),
|
||||
scene_description=scene_data.get("scene_description"),
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
self.db.add(scene)
|
||||
self.db.flush() # 获取ID但不提交
|
||||
|
||||
logger.info(
|
||||
f"Ontology scene created successfully - "
|
||||
f"scene_id={scene.scene_id}"
|
||||
)
|
||||
|
||||
return scene
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to create ontology scene: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def get_by_id(self, scene_id: UUID) -> Optional[OntologyScene]:
|
||||
"""根据ID获取场景
|
||||
|
||||
Args:
|
||||
scene_id: 场景ID
|
||||
|
||||
Returns:
|
||||
Optional[OntologyScene]: 场景对象,不存在则返回None
|
||||
|
||||
Examples:
|
||||
>>> repo = OntologySceneRepository(db)
|
||||
>>> scene = repo.get_by_id(scene_id)
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"Getting ontology scene by ID: {scene_id}")
|
||||
|
||||
scene = self.db.query(OntologyScene).filter(
|
||||
OntologyScene.scene_id == scene_id
|
||||
).first()
|
||||
|
||||
if scene:
|
||||
logger.debug(f"Ontology scene found: {scene_id}")
|
||||
else:
|
||||
logger.debug(f"Ontology scene not found: {scene_id}")
|
||||
|
||||
return scene
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get ontology scene by ID: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def get_by_name(self, scene_name: str, workspace_id: UUID) -> Optional[OntologyScene]:
|
||||
"""根据场景名称和工作空间ID获取场景(精确匹配)
|
||||
|
||||
Args:
|
||||
scene_name: 场景名称
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
Optional[OntologyScene]: 场景对象,不存在则返回None
|
||||
|
||||
Examples:
|
||||
>>> repo = OntologySceneRepository(db)
|
||||
>>> scene = repo.get_by_name("医疗场景", workspace_id)
|
||||
"""
|
||||
try:
|
||||
logger.debug(
|
||||
f"Getting ontology scene by name - "
|
||||
f"scene_name={scene_name}, workspace_id={workspace_id}"
|
||||
)
|
||||
|
||||
scene = self.db.query(OntologyScene).options(
|
||||
joinedload(OntologyScene.classes)
|
||||
).filter(
|
||||
OntologyScene.scene_name == scene_name,
|
||||
OntologyScene.workspace_id == workspace_id
|
||||
).first()
|
||||
|
||||
if scene:
|
||||
logger.debug(f"Ontology scene found: {scene_name}")
|
||||
else:
|
||||
logger.debug(f"Ontology scene not found: {scene_name}")
|
||||
|
||||
return scene
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get ontology scene by name: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def search_by_name(self, keyword: str, workspace_id: UUID) -> List[OntologyScene]:
|
||||
"""根据关键词模糊搜索场景
|
||||
|
||||
使用 LIKE 进行模糊匹配,支持中文和英文。
|
||||
|
||||
Args:
|
||||
keyword: 搜索关键词
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
List[OntologyScene]: 匹配的场景列表
|
||||
|
||||
Examples:
|
||||
>>> repo = OntologySceneRepository(db)
|
||||
>>> scenes = repo.search_by_name("医疗", workspace_id)
|
||||
"""
|
||||
try:
|
||||
logger.debug(
|
||||
f"Searching ontology scenes by keyword - "
|
||||
f"keyword={keyword}, workspace_id={workspace_id}"
|
||||
)
|
||||
|
||||
# 使用 ilike 进行不区分大小写的模糊匹配
|
||||
scenes = self.db.query(OntologyScene).options(
|
||||
joinedload(OntologyScene.classes)
|
||||
).filter(
|
||||
OntologyScene.scene_name.ilike(f"%{keyword}%"),
|
||||
OntologyScene.workspace_id == workspace_id
|
||||
).order_by(
|
||||
OntologyScene.updated_at.desc()
|
||||
).all()
|
||||
|
||||
logger.info(
|
||||
f"Found {len(scenes)} ontology scenes matching keyword '{keyword}' "
|
||||
f"in workspace {workspace_id}"
|
||||
)
|
||||
|
||||
return scenes
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to search ontology scenes by keyword: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def get_by_workspace(self, workspace_id: UUID, page: Optional[int] = None, page_size: Optional[int] = None) -> tuple:
|
||||
"""获取工作空间下的所有场景(支持分页)
|
||||
|
||||
使用joinedload预加载classes关系以统计数量。
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
page: 页码(可选,从1开始)
|
||||
page_size: 每页数量(可选)
|
||||
|
||||
Returns:
|
||||
tuple: (场景列表, 总数量)
|
||||
|
||||
Examples:
|
||||
>>> repo = OntologySceneRepository(db)
|
||||
>>> scenes, total = repo.get_by_workspace(workspace_id)
|
||||
>>> scenes, total = repo.get_by_workspace(workspace_id, page=1, page_size=10)
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"Getting ontology scenes by workspace: {workspace_id}, page={page}, page_size={page_size}")
|
||||
|
||||
# 构建基础查询
|
||||
query = self.db.query(OntologyScene).options(
|
||||
joinedload(OntologyScene.classes)
|
||||
).filter(
|
||||
OntologyScene.workspace_id == workspace_id
|
||||
).order_by(
|
||||
OntologyScene.updated_at.desc()
|
||||
)
|
||||
|
||||
# 获取总数
|
||||
total = query.count()
|
||||
|
||||
# 如果提供了分页参数,应用分页
|
||||
if page is not None and page_size is not None:
|
||||
offset = (page - 1) * page_size
|
||||
query = query.offset(offset).limit(page_size)
|
||||
logger.debug(f"Applying pagination: offset={offset}, limit={page_size}")
|
||||
|
||||
scenes = query.all()
|
||||
|
||||
logger.info(
|
||||
f"Found {len(scenes)} ontology scenes (total: {total}) in workspace {workspace_id}"
|
||||
)
|
||||
|
||||
return scenes, total
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get ontology scenes by workspace: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def update(self, scene_id: UUID, update_data: dict) -> Optional[OntologyScene]:
|
||||
"""更新场景信息
|
||||
|
||||
Args:
|
||||
scene_id: 场景ID
|
||||
update_data: 更新数据字典
|
||||
|
||||
Returns:
|
||||
Optional[OntologyScene]: 更新后的场景对象,不存在则返回None
|
||||
|
||||
Raises:
|
||||
Exception: 数据库操作失败
|
||||
|
||||
Examples:
|
||||
>>> repo = OntologySceneRepository(db)
|
||||
>>> scene = repo.update(
|
||||
... scene_id,
|
||||
... {"scene_name": "新名称"}
|
||||
... )
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Updating ontology scene: {scene_id}")
|
||||
|
||||
scene = self.get_by_id(scene_id)
|
||||
if not scene:
|
||||
logger.warning(f"Ontology scene not found for update: {scene_id}")
|
||||
return None
|
||||
|
||||
# 更新字段
|
||||
if "scene_name" in update_data and update_data["scene_name"] is not None:
|
||||
scene.scene_name = update_data["scene_name"]
|
||||
|
||||
if "scene_description" in update_data:
|
||||
scene.scene_description = update_data["scene_description"]
|
||||
|
||||
self.db.flush()
|
||||
|
||||
logger.info(f"Ontology scene updated successfully: {scene_id}")
|
||||
|
||||
return scene
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to update ontology scene: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def delete(self, scene_id: UUID) -> bool:
|
||||
"""删除场景(级联删除类型)
|
||||
|
||||
依赖数据库级联删除配置(ondelete="CASCADE")。
|
||||
|
||||
Args:
|
||||
scene_id: 场景ID
|
||||
|
||||
Returns:
|
||||
bool: 删除成功返回True,场景不存在返回False
|
||||
|
||||
Raises:
|
||||
Exception: 数据库操作失败
|
||||
|
||||
Examples:
|
||||
>>> repo = OntologySceneRepository(db)
|
||||
>>> success = repo.delete(scene_id)
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Deleting ontology scene: {scene_id}")
|
||||
|
||||
scene = self.get_by_id(scene_id)
|
||||
if not scene:
|
||||
logger.warning(f"Ontology scene not found for delete: {scene_id}")
|
||||
return False
|
||||
|
||||
self.db.delete(scene)
|
||||
self.db.flush()
|
||||
|
||||
logger.info(
|
||||
f"Ontology scene deleted successfully (cascade): {scene_id}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete ontology scene: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def check_ownership(self, scene_id: UUID, workspace_id: UUID) -> bool:
|
||||
"""检查场景是否属于指定工作空间
|
||||
|
||||
Args:
|
||||
scene_id: 场景ID
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
bool: 属于返回True,否则返回False
|
||||
|
||||
Examples:
|
||||
>>> repo = OntologySceneRepository(db)
|
||||
>>> is_owner = repo.check_ownership(scene_id, workspace_id)
|
||||
"""
|
||||
try:
|
||||
logger.debug(
|
||||
f"Checking scene ownership - "
|
||||
f"scene_id={scene_id}, workspace_id={workspace_id}"
|
||||
)
|
||||
|
||||
count = self.db.query(OntologyScene).filter(
|
||||
OntologyScene.scene_id == scene_id,
|
||||
OntologyScene.workspace_id == workspace_id
|
||||
).count()
|
||||
|
||||
is_owner = count > 0
|
||||
|
||||
logger.debug(
|
||||
f"Scene ownership check result: {is_owner} - "
|
||||
f"scene_id={scene_id}"
|
||||
)
|
||||
|
||||
return is_owner
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to check scene ownership: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def get_simple_list(self, workspace_id: UUID) -> List[dict]:
|
||||
"""获取场景简单列表(仅包含scene_id和scene_name,用于下拉选择)
|
||||
|
||||
这是一个轻量级查询,不加载关联的classes,响应速度快。
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
List[dict]: 场景简单列表,每项包含scene_id和scene_name
|
||||
|
||||
Examples:
|
||||
>>> repo = OntologySceneRepository(db)
|
||||
>>> scenes = repo.get_simple_list(workspace_id)
|
||||
>>> # [{"scene_id": "xxx", "scene_name": "场景1"}, ...]
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"Getting simple scene list for workspace: {workspace_id}")
|
||||
|
||||
# 只查询需要的字段,不加载关联数据
|
||||
results = self.db.query(
|
||||
OntologyScene.scene_id,
|
||||
OntologyScene.scene_name
|
||||
).filter(
|
||||
OntologyScene.workspace_id == workspace_id
|
||||
).order_by(
|
||||
OntologyScene.updated_at.desc()
|
||||
).all()
|
||||
|
||||
scenes = [
|
||||
{"scene_id": str(r.scene_id), "scene_name": r.scene_name}
|
||||
for r in results
|
||||
]
|
||||
|
||||
logger.info(f"Found {len(scenes)} scenes (simple list) in workspace {workspace_id}")
|
||||
|
||||
return scenes
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to get simple scene list: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
@@ -4,7 +4,10 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
from app.models.prompt_optimizer_model import (
|
||||
PromptOptimizerSession, PromptOptimizerSessionHistory, RoleType
|
||||
PromptOptimizerSession,
|
||||
PromptOptimizerSessionHistory,
|
||||
RoleType,
|
||||
PromptHistory
|
||||
)
|
||||
|
||||
db_logger = get_db_logger()
|
||||
@@ -16,6 +19,12 @@ class PromptOptimizerSessionRepository:
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_session_by_id(self, session_id: uuid.UUID) -> PromptOptimizerSession | None:
|
||||
session = self.db.query(PromptOptimizerSession).filter(
|
||||
PromptOptimizerSession.id == session_id,
|
||||
).first()
|
||||
return session
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
@@ -38,12 +47,9 @@ class PromptOptimizerSessionRepository:
|
||||
user_id=user_id,
|
||||
)
|
||||
self.db.add(session)
|
||||
self.db.commit()
|
||||
self.db.refresh(session)
|
||||
db_logger.debug(f"Prompt optimization session created: ID:{session.id}")
|
||||
return session
|
||||
except Exception as e:
|
||||
db_logger.error(f"Error creating prompt optimization session: user_id={user_id} - {str(e)}")
|
||||
db_logger.error(f"Error creating prompt optimization session: - {str(e)}")
|
||||
raise
|
||||
|
||||
def get_session_history(
|
||||
@@ -71,10 +77,10 @@ class PromptOptimizerSessionRepository:
|
||||
PromptOptimizerSession.id == session_id,
|
||||
PromptOptimizerSession.user_id == user_id
|
||||
).first()
|
||||
|
||||
|
||||
if not session:
|
||||
return []
|
||||
|
||||
|
||||
history = self.db.query(PromptOptimizerSessionHistory).filter(
|
||||
PromptOptimizerSessionHistory.session_id == session.id,
|
||||
PromptOptimizerSessionHistory.user_id == user_id
|
||||
@@ -104,11 +110,11 @@ class PromptOptimizerSessionRepository:
|
||||
PromptOptimizerSession.user_id == user_id,
|
||||
PromptOptimizerSession.tenant_id == tenant_id
|
||||
).first()
|
||||
|
||||
|
||||
if not session:
|
||||
db_logger.error(f"Session {session_id} not found for user {user_id}")
|
||||
raise ValueError(f"Session {session_id} not found for user {user_id}")
|
||||
|
||||
|
||||
message = PromptOptimizerSessionHistory(
|
||||
tenant_id=tenant_id,
|
||||
session_id=session.id,
|
||||
@@ -117,8 +123,199 @@ class PromptOptimizerSessionRepository:
|
||||
content=content,
|
||||
)
|
||||
self.db.add(message)
|
||||
self.db.commit()
|
||||
|
||||
return message
|
||||
except Exception as e:
|
||||
db_logger.error(f"Error creating prompt optimization session history: session_id={session_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
def get_first_user_message(self, session_id: uuid.UUID) -> str | None:
|
||||
"""
|
||||
Get the first user message from a session.
|
||||
|
||||
Args:
|
||||
session_id (uuid.UUID): The session ID.
|
||||
|
||||
Returns:
|
||||
str | None: The content of the first user message, or None if not found.
|
||||
"""
|
||||
try:
|
||||
message = self.db.query(PromptOptimizerSessionHistory).filter(
|
||||
PromptOptimizerSessionHistory.session_id == session_id,
|
||||
PromptOptimizerSessionHistory.role == RoleType.USER.value
|
||||
).order_by(
|
||||
PromptOptimizerSessionHistory.created_at.asc()
|
||||
).first()
|
||||
|
||||
return message.content if message else None
|
||||
except Exception as e:
|
||||
db_logger.error(f"Error getting first user message: session_id={session_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
class PromptReleaseRepository:
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_prompt_by_session_id(self, session_id: uuid.UUID) -> PromptHistory | None:
|
||||
prompt_obj = self.db.query(PromptHistory).filter(
|
||||
PromptHistory.session_id == session_id,
|
||||
PromptHistory.is_delete.is_(False)
|
||||
).first()
|
||||
return prompt_obj
|
||||
|
||||
def create_prompt_release(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
title: str,
|
||||
session_id: uuid.UUID,
|
||||
prompt: str,
|
||||
) -> PromptHistory:
|
||||
try:
|
||||
prompt_obj = PromptHistory(
|
||||
tenant_id=tenant_id,
|
||||
title=title,
|
||||
session_id=session_id,
|
||||
prompt=prompt,
|
||||
)
|
||||
self.db.add(prompt_obj)
|
||||
return prompt_obj
|
||||
except Exception as e:
|
||||
db_logger.error(f"Error creating prompt release: session_id={session_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
def soft_delete_prompt(self, prompt_obj: PromptHistory) -> None:
|
||||
"""
|
||||
Soft delete a prompt release by setting is_delete flag to True.
|
||||
|
||||
Args:
|
||||
prompt_obj (PromptHistory): The prompt release object to delete.
|
||||
"""
|
||||
try:
|
||||
prompt_obj.is_delete = True
|
||||
db_logger.debug(f"Soft deleted prompt release: id={prompt_obj.id}, session_id={prompt_obj.session_id}")
|
||||
except Exception as e:
|
||||
db_logger.error(f"Error soft deleting prompt release: id={prompt_obj.id} - {str(e)}")
|
||||
raise
|
||||
|
||||
def get_prompt_by_id(self, prompt_id: uuid.UUID) -> PromptHistory | None:
|
||||
"""
|
||||
Get a prompt release by its ID.
|
||||
|
||||
Args:
|
||||
prompt_id (uuid.UUID): The prompt release ID.
|
||||
|
||||
Returns:
|
||||
PromptHistory | None: The prompt release object or None if not found.
|
||||
"""
|
||||
try:
|
||||
prompt_obj = self.db.query(PromptHistory).filter(
|
||||
PromptHistory.id == prompt_id
|
||||
).first()
|
||||
return prompt_obj
|
||||
except Exception as e:
|
||||
db_logger.error(f"Error getting prompt release by id: id={prompt_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
def count_prompts(self, tenant_id: uuid.UUID) -> int:
|
||||
"""
|
||||
Count total number of non-deleted prompts for a tenant.
|
||||
|
||||
Args:
|
||||
tenant_id (uuid.UUID): The tenant ID.
|
||||
|
||||
Returns:
|
||||
int: Total count of prompts.
|
||||
"""
|
||||
try:
|
||||
count = self.db.query(PromptHistory).filter(
|
||||
PromptHistory.tenant_id == tenant_id,
|
||||
PromptHistory.is_delete.is_(False)
|
||||
).count()
|
||||
return count
|
||||
except Exception as e:
|
||||
db_logger.error(f"Error counting prompts: tenant_id={tenant_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
def get_prompts_paginated(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
offset: int,
|
||||
limit: int
|
||||
) -> list[PromptHistory]:
|
||||
"""
|
||||
Get paginated list of prompt releases for a tenant.
|
||||
|
||||
Args:
|
||||
tenant_id (uuid.UUID): The tenant ID.
|
||||
offset (int): Number of records to skip.
|
||||
limit (int): Maximum number of records to return.
|
||||
|
||||
Returns:
|
||||
list[PromptHistory]: List of prompt releases.
|
||||
"""
|
||||
try:
|
||||
prompts = self.db.query(PromptHistory).filter(
|
||||
PromptHistory.tenant_id == tenant_id,
|
||||
PromptHistory.is_delete.is_(False)
|
||||
).order_by(
|
||||
PromptHistory.created_at.desc()
|
||||
).offset(offset).limit(limit).all()
|
||||
return prompts
|
||||
except Exception as e:
|
||||
db_logger.error(f"Error getting paginated prompts: tenant_id={tenant_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
def count_prompts_by_keyword(self, tenant_id: uuid.UUID, keyword: str) -> int:
|
||||
"""
|
||||
Count total number of non-deleted prompts matching keyword for a tenant.
|
||||
|
||||
Args:
|
||||
tenant_id (uuid.UUID): The tenant ID.
|
||||
keyword (str): Search keyword for title.
|
||||
|
||||
Returns:
|
||||
int: Total count of matching prompts.
|
||||
"""
|
||||
try:
|
||||
count = self.db.query(PromptHistory).filter(
|
||||
PromptHistory.tenant_id == tenant_id,
|
||||
PromptHistory.is_delete.is_(False),
|
||||
PromptHistory.title.ilike(f"%{keyword}%")
|
||||
).count()
|
||||
return count
|
||||
except Exception as e:
|
||||
db_logger.error(f"Error counting prompts by keyword: tenant_id={tenant_id}, keyword={keyword} - {str(e)}")
|
||||
raise
|
||||
|
||||
def search_prompts_paginated(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
keyword: str,
|
||||
offset: int,
|
||||
limit: int
|
||||
) -> list[PromptHistory]:
|
||||
"""
|
||||
Search prompt releases by keyword in title with pagination.
|
||||
|
||||
Args:
|
||||
tenant_id (uuid.UUID): The tenant ID.
|
||||
keyword (str): Search keyword for title.
|
||||
offset (int): Number of records to skip.
|
||||
limit (int): Maximum number of records to return.
|
||||
|
||||
Returns:
|
||||
list[PromptHistory]: List of matching prompt releases.
|
||||
"""
|
||||
try:
|
||||
prompts = self.db.query(PromptHistory).filter(
|
||||
PromptHistory.tenant_id == tenant_id,
|
||||
PromptHistory.is_delete.is_(False),
|
||||
PromptHistory.title.ilike(f"%{keyword}%")
|
||||
).order_by(
|
||||
PromptHistory.created_at.desc()
|
||||
).offset(offset).limit(limit).all()
|
||||
return prompts
|
||||
except Exception as e:
|
||||
db_logger.error(f"Error searching prompts: tenant_id={tenant_id}, keyword={keyword} - {str(e)}")
|
||||
raise
|
||||
|
||||
@@ -12,8 +12,8 @@ class KnowledgeBaseConfig(BaseModel):
|
||||
kb_id: str = Field(..., description="知识库ID")
|
||||
top_k: int = Field(default=3, ge=1, le=20, description="检索返回的文档数量")
|
||||
similarity_threshold: float = Field(default=0.7, ge=0.0, le=1.0, description="相似度阈值")
|
||||
strategy: str = Field(default="hybrid", description="检索策略: hybrid | bm25 | dense")
|
||||
weight: float = Field(default=1.0, ge=0.0, le=1.0, description="知识库权重(用于多知识库融合)")
|
||||
# strategy: str = Field(default="hybrid", description="检索策略: hybrid | bm25 | dense")
|
||||
# weight: float = Field(default=1.0, ge=0.0, le=1.0, description="知识库权重(用于多知识库融合)")
|
||||
vector_similarity_weight: float = Field(default=0.5, ge=0.0, le=1.0, description="向量相似度权重")
|
||||
retrieve_type: str = Field(default="hybrid", description="检索方式participle| semantic|hybrid")
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from abc import ABC
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -14,4 +15,15 @@ class UserInput(BaseModel):
|
||||
class Write_UserInput(BaseModel):
|
||||
messages: list[dict]
|
||||
end_user_id: str
|
||||
config_id: Optional[str] = None
|
||||
config_id: Optional[str] = None
|
||||
|
||||
class AgentMemory_Long_Term(ABC):
|
||||
"""长期记忆配置常量"""
|
||||
STORAGE_NEO4J = "neo4j"
|
||||
STORAGE_RAG = "rag"
|
||||
STRATEGY_AGGREGATE = "aggregate"
|
||||
STRATEGY_CHUNK = "chunk"
|
||||
STRATEGY_TIME = "time"
|
||||
DEFAULT_SCOPE = 6
|
||||
|
||||
|
||||
|
||||
@@ -229,10 +229,15 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body,
|
||||
config_desc: str = Field("配置描述", description="配置描述(字符串)")
|
||||
workspace_id: Optional[uuid.UUID] = Field(None, description="工作空间ID(UUID)")
|
||||
|
||||
# 本体场景关联(可选)
|
||||
scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID(UUID),关联ontology_scene表")
|
||||
|
||||
# 模型配置字段(可选,用于手动指定或自动填充)
|
||||
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
|
||||
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")
|
||||
rerank_id: Optional[str] = Field(None, description="重排序模型配置ID")
|
||||
reflection_model_id: Optional[str] = Field(None, description="反思模型ID,默认与llm_id一致")
|
||||
emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID,默认与llm_id一致")
|
||||
|
||||
|
||||
class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
|
||||
@@ -243,8 +248,9 @@ class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体)
|
||||
|
||||
class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
||||
config_id: Union[uuid.UUID, int, str] = None
|
||||
config_name: str = Field("配置名称", description="配置名称(字符串)")
|
||||
config_desc: str = Field("配置描述", description="配置描述(字符串)")
|
||||
config_name: Optional[str] = Field(None, description="配置名称(字符串)")
|
||||
config_desc: Optional[str] = Field(None, description="配置描述(字符串)")
|
||||
scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID")
|
||||
|
||||
|
||||
class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型
|
||||
|
||||
461
api/app/schemas/ontology_schemas.py
Normal file
461
api/app/schemas/ontology_schemas.py
Normal file
@@ -0,0 +1,461 @@
|
||||
"""本体提取API的请求和响应模型
|
||||
|
||||
本模块定义了本体提取系统的所有API请求和响应的Pydantic模型。
|
||||
|
||||
Classes:
|
||||
ExtractionRequest: 本体提取请求模型
|
||||
ExtractionResponse: 本体提取响应模型
|
||||
ExportRequest: OWL文件导出请求模型
|
||||
ExportResponse: OWL文件导出响应模型
|
||||
OntologyResultResponse: 本体提取结果响应模型(带毫秒时间戳)
|
||||
SceneCreateRequest: 场景创建请求模型
|
||||
SceneUpdateRequest: 场景更新请求模型
|
||||
SceneResponse: 场景响应模型
|
||||
SceneListResponse: 场景列表响应模型
|
||||
ClassCreateRequest: 类型创建请求模型
|
||||
ClassUpdateRequest: 类型更新请求模型
|
||||
ClassResponse: 类型响应模型
|
||||
ClassListResponse: 类型列表响应模型
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, field_serializer, ConfigDict
|
||||
|
||||
from app.core.memory.models.ontology_models import OntologyClass
|
||||
|
||||
|
||||
class ExtractionRequest(BaseModel):
|
||||
"""本体提取请求模型
|
||||
|
||||
用于POST /api/ontology/extract端点的请求体。
|
||||
|
||||
Attributes:
|
||||
scenario: 场景描述文本,不能为空
|
||||
domain: 可选的领域提示(如Healthcare, Education等)
|
||||
llm_id: LLM模型ID,必须提供
|
||||
scene_id: 场景ID,必须提供,用于将提取的类保存到指定场景
|
||||
|
||||
Examples:
|
||||
>>> request = ExtractionRequest(
|
||||
... scenario="医院管理患者记录...",
|
||||
... domain="Healthcare",
|
||||
... llm_id="550e8400-e29b-41d4-a716-446655440000",
|
||||
... scene_id="660e8400-e29b-41d4-a716-446655440000"
|
||||
... )
|
||||
"""
|
||||
scenario: str = Field(..., description="场景描述文本", min_length=1)
|
||||
domain: Optional[str] = Field(None, description="可选的领域提示")
|
||||
llm_id: str = Field(..., description="LLM模型ID")
|
||||
scene_id: UUID = Field(..., description="场景ID,用于将提取的类保存到指定场景")
|
||||
|
||||
|
||||
class ExtractionResponse(BaseModel):
|
||||
"""本体提取响应模型
|
||||
|
||||
用于POST /api/ontology/extract端点的响应体。
|
||||
|
||||
Attributes:
|
||||
classes: 提取的本体类列表
|
||||
domain: 识别的领域
|
||||
extracted_count: 提取的类数量
|
||||
|
||||
Examples:
|
||||
>>> response = ExtractionResponse(
|
||||
... classes=[...],
|
||||
... domain="Healthcare",
|
||||
... extracted_count=7
|
||||
... )
|
||||
"""
|
||||
classes: List[OntologyClass] = Field(default_factory=list, description="提取的本体类列表")
|
||||
domain: str = Field(..., description="识别的领域")
|
||||
extracted_count: int = Field(..., description="提取的类数量")
|
||||
|
||||
|
||||
class ExportRequest(BaseModel):
|
||||
"""OWL文件导出请求模型
|
||||
|
||||
用于POST /api/ontology/export端点的请求体。
|
||||
|
||||
Attributes:
|
||||
classes: 要导出的本体类列表
|
||||
format: 导出格式,可选值: rdfxml, turtle, ntriples, json
|
||||
include_metadata: 是否包含完整的OWL元数据(命名空间等),默认True
|
||||
|
||||
Examples:
|
||||
>>> request = ExportRequest(
|
||||
... classes=[...],
|
||||
... format="rdfxml",
|
||||
... include_metadata=True
|
||||
... )
|
||||
"""
|
||||
classes: List[OntologyClass] = Field(..., description="要导出的本体类列表", min_length=1)
|
||||
format: str = Field("rdfxml", description="导出格式: rdfxml, turtle, ntriples, json")
|
||||
include_metadata: bool = Field(True, description="是否包含完整的OWL元数据")
|
||||
|
||||
|
||||
class ExportResponse(BaseModel):
|
||||
"""OWL文件导出响应模型
|
||||
|
||||
用于POST /api/ontology/export端点的响应体。
|
||||
|
||||
Attributes:
|
||||
owl_content: OWL文件内容
|
||||
format: 导出格式
|
||||
classes_count: 导出的类数量
|
||||
|
||||
Examples:
|
||||
>>> response = ExportResponse(
|
||||
... owl_content="<?xml version='1.0'?>...",
|
||||
... format="rdfxml",
|
||||
... classes_count=7
|
||||
... )
|
||||
"""
|
||||
owl_content: str = Field(..., description="OWL文件内容")
|
||||
format: str = Field(..., description="导出格式")
|
||||
classes_count: int = Field(..., description="导出的类数量")
|
||||
|
||||
|
||||
class OntologyResultResponse(BaseModel):
|
||||
"""本体提取结果响应模型
|
||||
|
||||
用于返回数据库中存储的提取结果,时间戳为毫秒级。
|
||||
|
||||
Attributes:
|
||||
id: 结果ID (UUID)
|
||||
scenario: 场景描述文本
|
||||
domain: 领域
|
||||
classes_json: 提取的本体类数据(JSON格式)
|
||||
extracted_count: 提取的类数量
|
||||
user_id: 用户ID
|
||||
created_at: 创建时间(毫秒时间戳)
|
||||
|
||||
Examples:
|
||||
>>> response = OntologyResultResponse(
|
||||
... id=uuid.uuid4(),
|
||||
... scenario="医院管理患者记录...",
|
||||
... domain="Healthcare",
|
||||
... classes_json={"classes": [...]},
|
||||
... extracted_count=7,
|
||||
... user_id=123,
|
||||
... created_at=datetime.now()
|
||||
... )
|
||||
"""
|
||||
id: UUID = Field(..., description="结果ID")
|
||||
scenario: str = Field(..., description="场景描述文本")
|
||||
domain: Optional[str] = Field(None, description="领域")
|
||||
classes_json: dict = Field(..., description="提取的本体类数据(JSON格式)")
|
||||
extracted_count: int = Field(..., description="提取的类数量")
|
||||
user_id: Optional[int] = Field(None, description="用户ID")
|
||||
created_at: datetime.datetime = Field(..., description="创建时间")
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
"""将创建时间序列化为毫秒时间戳"""
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
|
||||
# ==================== 本体场景相关 Schema ====================
|
||||
|
||||
class SceneCreateRequest(BaseModel):
|
||||
"""场景创建请求模型
|
||||
|
||||
用于创建新的本体场景。
|
||||
|
||||
Attributes:
|
||||
scene_name: 场景名称,必填,1-200字符
|
||||
scene_description: 场景描述,可选
|
||||
|
||||
Examples:
|
||||
>>> request = SceneCreateRequest(
|
||||
... scene_name="医疗场景",
|
||||
... scene_description="用于医疗领域的本体建模"
|
||||
... )
|
||||
"""
|
||||
scene_name: str = Field(..., min_length=1, max_length=200, description="场景名称")
|
||||
scene_description: Optional[str] = Field(None, description="场景描述")
|
||||
|
||||
|
||||
class SceneUpdateRequest(BaseModel):
|
||||
"""场景更新请求模型
|
||||
|
||||
用于更新已有本体场景信息。
|
||||
|
||||
Attributes:
|
||||
scene_name: 场景名称,可选,1-200字符
|
||||
scene_description: 场景描述,可选
|
||||
|
||||
Examples:
|
||||
>>> request = SceneUpdateRequest(
|
||||
... scene_name="更新后的场景名称",
|
||||
... scene_description="更新后的描述"
|
||||
... )
|
||||
"""
|
||||
scene_name: Optional[str] = Field(None, min_length=1, max_length=200, description="场景名称")
|
||||
scene_description: Optional[str] = Field(None, description="场景描述")
|
||||
|
||||
|
||||
class SceneResponse(BaseModel):
|
||||
"""场景响应模型
|
||||
|
||||
用于返回本体场景信息。
|
||||
|
||||
Attributes:
|
||||
scene_id: 场景ID
|
||||
scene_name: 场景名称
|
||||
scene_description: 场景描述
|
||||
type_num: 类型数量
|
||||
workspace_id: 所属工作空间ID
|
||||
created_at: 创建时间(毫秒时间戳)
|
||||
updated_at: 更新时间(毫秒时间戳)
|
||||
classes_count: 类型数量
|
||||
|
||||
Examples:
|
||||
>>> response = SceneResponse(
|
||||
... scene_id=uuid.uuid4(),
|
||||
... scene_name="医疗场景",
|
||||
... scene_description="用于医疗领域的本体建模",
|
||||
... type_num=0,
|
||||
... workspace_id=uuid.uuid4(),
|
||||
... created_at=datetime.now(),
|
||||
... updated_at=datetime.now(),
|
||||
... classes_count=5
|
||||
... )
|
||||
"""
|
||||
scene_id: UUID = Field(..., description="场景ID")
|
||||
scene_name: str = Field(..., description="场景名称")
|
||||
scene_description: Optional[str] = Field(None, description="场景描述")
|
||||
type_num: int = Field(..., description="类型数量")
|
||||
entity_type: Optional[List[str]] = Field(None, description="实体类型列表(最多3个class_name)")
|
||||
workspace_id: UUID = Field(..., description="所属工作空间ID")
|
||||
created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)")
|
||||
updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)")
|
||||
classes_count: int = Field(0, description="类型数量")
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
"""将创建时间序列化为毫秒时间戳"""
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
"""将更新时间序列化为毫秒时间戳"""
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class PaginationInfo(BaseModel):
|
||||
"""分页信息模型
|
||||
|
||||
Attributes:
|
||||
page: 当前页码
|
||||
pagesize: 每页数量
|
||||
total: 总数量
|
||||
hasnext: 是否有下一页
|
||||
"""
|
||||
page: int = Field(..., description="当前页码")
|
||||
pagesize: int = Field(..., description="每页数量")
|
||||
total: int = Field(..., description="总数量")
|
||||
hasnext: bool = Field(..., description="是否有下一页")
|
||||
|
||||
|
||||
class SceneListResponse(BaseModel):
|
||||
"""场景列表响应模型(支持分页)
|
||||
|
||||
用于返回本体场景列表。
|
||||
|
||||
Attributes:
|
||||
items: 场景列表
|
||||
page: 分页信息(可选,分页时返回)
|
||||
|
||||
Examples:
|
||||
>>> # 不分页
|
||||
>>> response = SceneListResponse(
|
||||
... items=[scene1, scene2]
|
||||
... )
|
||||
>>> # 分页
|
||||
>>> response = SceneListResponse(
|
||||
... items=[scene1, scene2, ...],
|
||||
... page=PaginationInfo(page=1, pagesize=100, total=150, hasnext=True)
|
||||
... )
|
||||
"""
|
||||
items: List[SceneResponse] = Field(..., description="场景列表")
|
||||
page: Optional[PaginationInfo] = Field(None, description="分页信息")
|
||||
|
||||
|
||||
# ==================== 本体类型相关 Schema ====================
|
||||
|
||||
class ClassItem(BaseModel):
|
||||
"""单个类型信息模型
|
||||
|
||||
Attributes:
|
||||
class_name: 类型名称,必填,1-200字符
|
||||
class_description: 类型描述,可选
|
||||
|
||||
Examples:
|
||||
>>> item = ClassItem(
|
||||
... class_name="患者",
|
||||
... class_description="医院患者信息"
|
||||
... )
|
||||
"""
|
||||
class_name: str = Field(..., min_length=1, max_length=200, description="类型名称")
|
||||
class_description: Optional[str] = Field(None, description="类型描述")
|
||||
|
||||
|
||||
class ClassCreateRequest(BaseModel):
|
||||
"""类型创建请求模型(统一使用列表形式)
|
||||
|
||||
通过列表中元素数量决定创建模式:
|
||||
- 列表包含 1 个元素:单个创建
|
||||
- 列表包含多个元素:批量创建
|
||||
|
||||
Attributes:
|
||||
scene_id: 所属场景ID,必填
|
||||
classes: 类型列表,必填,至少包含 1 个元素
|
||||
|
||||
Examples:
|
||||
# 单个创建(列表中 1 个元素)
|
||||
>>> request = ClassCreateRequest(
|
||||
... scene_id=uuid.uuid4(),
|
||||
... classes=[
|
||||
... ClassItem(class_name="患者", class_description="医院患者信息")
|
||||
... ]
|
||||
... )
|
||||
|
||||
# 批量创建(列表中多个元素)
|
||||
>>> request = ClassCreateRequest(
|
||||
... scene_id=uuid.uuid4(),
|
||||
... classes=[
|
||||
... ClassItem(class_name="患者", class_description="医院患者信息"),
|
||||
... ClassItem(class_name="医生", class_description="医院医生信息"),
|
||||
... ClassItem(class_name="药品", class_description="医院药品信息")
|
||||
... ]
|
||||
... )
|
||||
"""
|
||||
scene_id: UUID = Field(..., description="所属场景ID")
|
||||
classes: List[ClassItem] = Field(..., min_length=1, description="类型列表,至少包含 1 个元素")
|
||||
|
||||
|
||||
class ClassUpdateRequest(BaseModel):
|
||||
"""类型更新请求模型
|
||||
|
||||
用于更新已有本体类型信息。
|
||||
|
||||
Attributes:
|
||||
class_name: 类型名称,可选,1-200字符
|
||||
class_description: 类型描述,可选
|
||||
|
||||
Examples:
|
||||
>>> request = ClassUpdateRequest(
|
||||
... class_name="更新后的类型名称",
|
||||
... class_description="更新后的描述"
|
||||
... )
|
||||
"""
|
||||
class_name: Optional[str] = Field(None, min_length=1, max_length=200, description="类型名称")
|
||||
class_description: Optional[str] = Field(None, description="类型描述")
|
||||
|
||||
|
||||
class ClassResponse(BaseModel):
|
||||
"""类型响应模型
|
||||
|
||||
用于返回本体类型信息。
|
||||
|
||||
Attributes:
|
||||
class_id: 类型ID
|
||||
class_name: 类型名称
|
||||
class_description: 类型描述
|
||||
scene_id: 所属场景ID
|
||||
created_at: 创建时间(毫秒时间戳)
|
||||
updated_at: 更新时间(毫秒时间戳)
|
||||
|
||||
Examples:
|
||||
>>> response = ClassResponse(
|
||||
... class_id=uuid.uuid4(),
|
||||
... class_name="患者",
|
||||
... class_description="医院患者信息",
|
||||
... scene_id=uuid.uuid4(),
|
||||
... created_at=datetime.now(),
|
||||
... updated_at=datetime.now()
|
||||
... )
|
||||
"""
|
||||
class_id: UUID = Field(..., description="类型ID")
|
||||
class_name: str = Field(..., description="类型名称")
|
||||
class_description: Optional[str] = Field(None, description="类型描述")
|
||||
scene_id: UUID = Field(..., description="所属场景ID")
|
||||
created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)")
|
||||
updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)")
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
"""将创建时间序列化为毫秒时间戳"""
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
"""将更新时间序列化为毫秒时间戳"""
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ClassBatchCreateResponse(BaseModel):
|
||||
"""批量创建类型响应模型
|
||||
|
||||
用于返回批量创建的结果统计和详情。
|
||||
|
||||
Attributes:
|
||||
total: 总共尝试创建的数量
|
||||
success_count: 成功创建的数量
|
||||
failed_count: 失败的数量
|
||||
items: 成功创建的类型列表
|
||||
errors: 失败的错误信息列表(可选)
|
||||
|
||||
Examples:
|
||||
>>> response = ClassBatchCreateResponse(
|
||||
... total=3,
|
||||
... success_count=2,
|
||||
... failed_count=1,
|
||||
... items=[class1, class2],
|
||||
... errors=["创建类型 '药品' 失败: 类型名称已存在"]
|
||||
... )
|
||||
"""
|
||||
total: int = Field(..., description="总共尝试创建的数量")
|
||||
success_count: int = Field(..., description="成功创建的数量")
|
||||
failed_count: int = Field(0, description="失败的数量")
|
||||
items: List[ClassResponse] = Field(..., description="成功创建的类型列表")
|
||||
errors: Optional[List[str]] = Field(None, description="失败的错误信息列表")
|
||||
|
||||
|
||||
class ClassListResponse(BaseModel):
|
||||
"""类型列表响应模型
|
||||
|
||||
用于返回本体类型列表。
|
||||
|
||||
Attributes:
|
||||
total: 总数量
|
||||
scene_id: 所属场景ID
|
||||
scene_name: 场景名称
|
||||
scene_description: 场景描述
|
||||
items: 类型列表
|
||||
|
||||
Examples:
|
||||
>>> response = ClassListResponse(
|
||||
... total=3,
|
||||
... scene_id=uuid.uuid4(),
|
||||
... scene_name="医疗场景",
|
||||
... scene_description="用于医疗领域的本体建模",
|
||||
... items=[class1, class2, class3]
|
||||
... )
|
||||
"""
|
||||
total: int = Field(..., description="总数量")
|
||||
scene_id: UUID = Field(..., description="所属场景ID")
|
||||
scene_name: str = Field(..., description="场景名称")
|
||||
scene_description: Optional[str] = Field(None, description="场景描述")
|
||||
items: List[ClassResponse] = Field(..., description="类型列表")
|
||||
@@ -22,6 +22,23 @@ class PromptOptMessage(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class PromptSaveRequest(BaseModel):
|
||||
session_id: UUID = Field(
|
||||
...,
|
||||
description="Session ID"
|
||||
)
|
||||
|
||||
title: str = Field(
|
||||
...,
|
||||
description="Prompt Title"
|
||||
)
|
||||
|
||||
prompt: str = Field(
|
||||
...,
|
||||
description="Optimized prompt content"
|
||||
)
|
||||
|
||||
|
||||
class PromptOptModelSet(BaseModel):
|
||||
id: UUID | None = Field(
|
||||
default=None,
|
||||
|
||||
@@ -171,7 +171,14 @@ class AppChatService:
|
||||
self.conversation_service.save_conversation_messages(
|
||||
conversation_id=conversation_id,
|
||||
user_message=message,
|
||||
assistant_message=result["content"]
|
||||
assistant_message=result["content"],
|
||||
meta_data={
|
||||
"usage": result.get("usage", {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
@@ -310,6 +317,7 @@ class AppChatService:
|
||||
|
||||
# 流式调用 Agent
|
||||
full_content = ""
|
||||
total_tokens = 0
|
||||
async for chunk in agent.chat_stream(
|
||||
message=message,
|
||||
history=history,
|
||||
@@ -320,9 +328,12 @@ class AppChatService:
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag
|
||||
):
|
||||
full_content += chunk
|
||||
# 发送消息块事件
|
||||
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
|
||||
if isinstance(chunk, int):
|
||||
total_tokens = chunk
|
||||
else:
|
||||
full_content += chunk
|
||||
# 发送消息块事件
|
||||
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
@@ -339,7 +350,7 @@ class AppChatService:
|
||||
content=full_content,
|
||||
meta_data={
|
||||
"model": api_key_obj.model_name,
|
||||
"usage": {}
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
|
||||
}
|
||||
)
|
||||
|
||||
@@ -416,7 +427,11 @@ class AppChatService:
|
||||
meta_data={
|
||||
"mode": result.get("mode"),
|
||||
"elapsed_time": result.get("elapsed_time"),
|
||||
"sub_results": result.get("sub_results")
|
||||
"usage": result.get("usage", {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
@@ -458,6 +473,7 @@ class AppChatService:
|
||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
full_content = ""
|
||||
total_tokens = 0
|
||||
|
||||
# 2. 创建编排器
|
||||
orchestrator = MultiAgentOrchestrator(self.db, config)
|
||||
@@ -474,16 +490,26 @@ class AppChatService:
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
# 尝试提取内容(用于保存)
|
||||
if "data:" in event:
|
||||
try:
|
||||
data_line = event.split("data: ", 1)[1].strip()
|
||||
data = json.loads(data_line)
|
||||
if "content" in data:
|
||||
full_content += data["content"]
|
||||
except:
|
||||
pass
|
||||
if "sub_usage" in event:
|
||||
if "data:" in event:
|
||||
try:
|
||||
data_line = event.split("data: ", 1)[1].strip()
|
||||
data = json.loads(data_line)
|
||||
if "total_tokens" in data:
|
||||
total_tokens += data["total_tokens"]
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
yield event
|
||||
# 尝试提取内容(用于保存)
|
||||
if "data:" in event:
|
||||
try:
|
||||
data_line = event.split("data: ", 1)[1].strip()
|
||||
data = json.loads(data_line)
|
||||
if "content" in data:
|
||||
full_content += data["content"]
|
||||
except:
|
||||
pass
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
@@ -499,7 +525,12 @@ class AppChatService:
|
||||
role="assistant",
|
||||
content=full_content,
|
||||
meta_data={
|
||||
"elapsed_time": elapsed_time
|
||||
"elapsed_time": elapsed_time,
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ class AppStatisticsService:
|
||||
daily_tokens[date_str] = 0
|
||||
daily_tokens[date_str] += int(tokens)
|
||||
|
||||
daily_data = [{"date": date, "tokens": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
|
||||
total = sum(row["tokens"] for row in daily_data)
|
||||
daily_data = [{"date": date, "count": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
|
||||
total = sum(row["count"] for row in daily_data)
|
||||
|
||||
return {"daily": daily_data, "total": total}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""会话服务"""
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Annotated
|
||||
@@ -298,7 +299,8 @@ class ConversationService:
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
user_message: str,
|
||||
assistant_message: str
|
||||
assistant_message: str,
|
||||
meta_data: Optional[dict] = None
|
||||
):
|
||||
"""
|
||||
Save a pair of user and assistant messages to the conversation.
|
||||
@@ -307,6 +309,7 @@ class ConversationService:
|
||||
conversation_id (uuid.UUID): Conversation UUID.
|
||||
user_message (str): User's message content.
|
||||
assistant_message (str): Assistant's response content.
|
||||
meta_data (Optional[dict]): Optional metadata for the messages.
|
||||
"""
|
||||
self.add_message(
|
||||
conversation_id=conversation_id,
|
||||
@@ -317,7 +320,8 @@ class ConversationService:
|
||||
self.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=assistant_message
|
||||
content=assistant_message,
|
||||
meta_data=meta_data
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
@@ -526,12 +530,12 @@ class ConversationService:
|
||||
takeaways=[],
|
||||
info_score=0,
|
||||
)
|
||||
|
||||
with open('app/services/prompt/conversation_summary_system.jinja2', 'r', encoding='utf-8') as f:
|
||||
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
||||
with open(os.path.join(prompt_path, 'conversation_summary_system.jinja2'), 'r', encoding='utf-8') as f:
|
||||
system_prompt = f.read()
|
||||
rendered_system_message = Template(system_prompt).render()
|
||||
|
||||
with open('app/services/prompt/conversation_summary_user.jinja2', 'r', encoding='utf-8') as f:
|
||||
with open(os.path.join(prompt_path, 'conversation_summary_user.jinja2'), 'r', encoding='utf-8') as f:
|
||||
user_prompt = f.read()
|
||||
rendered_user_message = Template(user_prompt).render(
|
||||
language=language,
|
||||
|
||||
@@ -110,6 +110,8 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
||||
result = task_service.get_task_memory_read_result(task.id)
|
||||
status = result.get("status")
|
||||
logger.info(f"读取任务状态:{status}")
|
||||
if memory_content:
|
||||
memory_content = memory_content['answer']
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
@@ -123,7 +125,6 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
||||
"content_length": len(str(memory_content))
|
||||
}
|
||||
)
|
||||
|
||||
return f"检索到以下历史记忆:\n\n{memory_content}"
|
||||
except Exception as e:
|
||||
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})
|
||||
@@ -442,7 +443,14 @@ class DraftRunService:
|
||||
user_message=message,
|
||||
assistant_message=result["content"],
|
||||
app_id=agent_config.app_id,
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
meta_data={
|
||||
"usage": result.get("usage", {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
response = {
|
||||
@@ -649,6 +657,7 @@ class DraftRunService:
|
||||
|
||||
# 9. 流式调用 Agent
|
||||
full_content = ""
|
||||
total_tokens = 0
|
||||
async for chunk in agent.chat_stream(
|
||||
message=message,
|
||||
history=history,
|
||||
@@ -659,14 +668,22 @@ class DraftRunService:
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_flag=memory_flag
|
||||
):
|
||||
full_content += chunk
|
||||
# 发送消息块事件
|
||||
yield self._format_sse_event("message", {
|
||||
"content": chunk
|
||||
})
|
||||
if isinstance(chunk, int):
|
||||
total_tokens = chunk
|
||||
else:
|
||||
full_content += chunk
|
||||
# 发送消息块事件
|
||||
yield self._format_sse_event("message", {
|
||||
"content": chunk
|
||||
})
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
if sub_agent:
|
||||
yield self._format_sse_event("sub_usage", {
|
||||
"total_tokens": total_tokens
|
||||
})
|
||||
|
||||
# 10. 保存会话消息
|
||||
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
|
||||
await self._save_conversation_message(
|
||||
@@ -674,7 +691,10 @@ class DraftRunService:
|
||||
user_message=message,
|
||||
assistant_message=full_content,
|
||||
app_id=agent_config.app_id,
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
meta_data={
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
|
||||
}
|
||||
)
|
||||
|
||||
# 11. 发送结束事件
|
||||
@@ -898,6 +918,7 @@ class DraftRunService:
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
assistant_message: str,
|
||||
meta_data: dict,
|
||||
app_id: Optional[uuid.UUID] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> None:
|
||||
@@ -909,6 +930,7 @@ class DraftRunService:
|
||||
assistant_message: AI 回复消息
|
||||
app_id: 应用ID(未使用,保留用于兼容性)
|
||||
user_id: 用户ID(未使用,保留用于兼容性)
|
||||
meta_data: token消耗
|
||||
"""
|
||||
try:
|
||||
from app.services.conversation_service import ConversationService
|
||||
@@ -927,7 +949,8 @@ class DraftRunService:
|
||||
conversation_service.add_message(
|
||||
conversation_id=conv_uuid,
|
||||
role="assistant",
|
||||
content=assistant_message
|
||||
content=assistant_message,
|
||||
meta_data=meta_data
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
|
||||
@@ -4,7 +4,7 @@ import uuid
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator, Annotated
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
|
||||
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, AIMessageChunk
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.types import Command
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
@@ -727,9 +727,12 @@ class HandoffsService:
|
||||
|
||||
# 提取响应
|
||||
response_content = ""
|
||||
total_tokens = 0
|
||||
for msg in result.get("messages", []):
|
||||
if isinstance(msg, AIMessage):
|
||||
response_content = msg.content
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
|
||||
break
|
||||
|
||||
return {
|
||||
@@ -737,7 +740,12 @@ class HandoffsService:
|
||||
"active_agent": result.get("active_agent"),
|
||||
"response": response_content,
|
||||
"message_count": len(result.get("messages", [])),
|
||||
"handoff_count": result.get("handoff_count", 0)
|
||||
"handoff_count": result.get("handoff_count", 0),
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
}
|
||||
|
||||
async def chat_stream(
|
||||
@@ -830,6 +838,12 @@ class HandoffsService:
|
||||
|
||||
# 捕获 LLM 结束事件,输出收集到的工具调用
|
||||
elif kind == "on_chat_model_end":
|
||||
output_message = event.get("data", {}).get("output", {})
|
||||
if isinstance(output_message, AIMessageChunk):
|
||||
response_meta = output_message.response_metadata if hasattr(output_message, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
|
||||
0) if response_meta else 0
|
||||
yield f"event: sub_usage\ndata: {json.dumps({"total_tokens": total_tokens}, ensure_ascii=False)}\n\n"
|
||||
if collected_tool_calls:
|
||||
# 找到参数最完整的 transfer 工具调用
|
||||
best_tc = None
|
||||
|
||||
@@ -53,7 +53,10 @@ def get_workspace_end_users(
|
||||
workspace_id: uuid.UUID,
|
||||
current_user: User
|
||||
) -> List[EndUser]:
|
||||
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)"""
|
||||
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)
|
||||
|
||||
返回结果按 updated_at 从新到旧排序(NULL 值排在最后)
|
||||
"""
|
||||
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
@@ -68,9 +71,14 @@ def get_workspace_end_users(
|
||||
app_ids = [app.id for app in apps_orm]
|
||||
|
||||
# 批量查询所有 end_users(一次查询而非循环查询)
|
||||
# 按 updated_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性
|
||||
from app.models.end_user_model import EndUser as EndUserModel
|
||||
from sqlalchemy import desc, nullslast
|
||||
end_users_orm = db.query(EndUserModel).filter(
|
||||
EndUserModel.app_id.in_(app_ids)
|
||||
).order_by(
|
||||
nullslast(desc(EndUserModel.updated_at)),
|
||||
desc(EndUserModel.id)
|
||||
).all()
|
||||
|
||||
# 转换为 Pydantic 模型(只在需要时转换)
|
||||
|
||||
@@ -89,7 +89,6 @@ class WorkspaceAppService:
|
||||
|
||||
for release in app_releases:
|
||||
memory_content = self._extract_memory_content(release.config)
|
||||
memory_content=resolve_config_id(memory_content, self.db)
|
||||
if memory_content and memory_content in processed_configs:
|
||||
continue
|
||||
|
||||
@@ -122,16 +121,12 @@ class WorkspaceAppService:
|
||||
def _get_memory_config(self, memory_content: str) -> Dict[str, Any]:
|
||||
"""Retrieve memory_config information based on memory_content"""
|
||||
try:
|
||||
memory_config_result = MemoryConfigRepository.query_reflection_config_by_id(self.db, int(memory_content))
|
||||
|
||||
# memory_config_query, memory_config_params = MemoryConfigRepository.build_select_reflection(memory_content)
|
||||
# memory_config_result = self.db.execute(text(memory_config_query), memory_config_params).fetchone()
|
||||
# if memory_config_result is None:
|
||||
# return None
|
||||
memory_content = resolve_config_id(memory_content, self.db)
|
||||
memory_config_result = MemoryConfigRepository.query_reflection_config_by_id(self.db, (memory_content))
|
||||
|
||||
if memory_config_result:
|
||||
return {
|
||||
"config_id": memory_config_result.config_id,
|
||||
"config_id": memory_content,
|
||||
"enable_self_reflexion": memory_config_result.enable_self_reflexion,
|
||||
"iteration_period": memory_config_result.iteration_period,
|
||||
"reflexion_range": memory_config_result.reflexion_range,
|
||||
@@ -291,7 +286,7 @@ class MemoryReflectionService:
|
||||
# 检查是否需要执行反思
|
||||
should_execute = False
|
||||
hours_diff = 0
|
||||
|
||||
|
||||
if current_reflection_time is None:
|
||||
# 首次执行反思
|
||||
should_execute = True
|
||||
@@ -303,11 +298,11 @@ class MemoryReflectionService:
|
||||
reflection_time = datetime.fromisoformat(current_reflection_time)
|
||||
else:
|
||||
reflection_time = current_reflection_time
|
||||
|
||||
|
||||
current_time = datetime.now()
|
||||
time_diff = current_time - reflection_time
|
||||
hours_diff = int(time_diff.total_seconds() / 3600)
|
||||
|
||||
|
||||
# 检查是否达到反思周期
|
||||
if hours_diff >= iteration_period:
|
||||
should_execute = True
|
||||
@@ -317,7 +312,7 @@ class MemoryReflectionService:
|
||||
except (ValueError, TypeError) as e:
|
||||
api_logger.warning(f"解析反思时间失败: {e},将执行反思")
|
||||
should_execute = True
|
||||
|
||||
|
||||
if should_execute:
|
||||
api_logger.info(f"与上次的反思时间间隔为: {hours_diff} 小时")
|
||||
# 3. 执行反思引擎
|
||||
@@ -350,7 +345,7 @@ class MemoryReflectionService:
|
||||
"next_reflection_in_hours": iteration_period - hours_diff
|
||||
}
|
||||
|
||||
|
||||
|
||||
except Exception as e:
|
||||
config_id = config_data.get("config_id", "unknown")
|
||||
api_logger.error(f"启动反思失败,config_id: {config_id}, end_user_id: {end_user_id}, 错误: {str(e)}")
|
||||
@@ -361,7 +356,7 @@ class MemoryReflectionService:
|
||||
"end_user_id": end_user_id,
|
||||
"config_data": config_data
|
||||
}
|
||||
|
||||
|
||||
def _create_reflection_config_from_data(self, config_data: Dict[str, Any]) -> ReflectionConfig:
|
||||
"""Create reflective configuration objects from configuration data"""
|
||||
|
||||
@@ -369,12 +364,12 @@ class MemoryReflectionService:
|
||||
if reflexion_range_value is None or reflexion_range_value == "":
|
||||
reflexion_range_value = "partial"
|
||||
reflexion_range = ReflectionRange(reflexion_range_value)
|
||||
|
||||
|
||||
baseline_value = config_data.get("baseline")
|
||||
if baseline_value is None or baseline_value == "":
|
||||
baseline_value = "TIME"
|
||||
baseline = ReflectionBaseline(baseline_value)
|
||||
|
||||
|
||||
# iteration_period =
|
||||
iteration_period = config_data.get("iteration_period", 24)
|
||||
if isinstance(iteration_period, str):
|
||||
@@ -382,7 +377,6 @@ class MemoryReflectionService:
|
||||
iteration_period = int(iteration_period)
|
||||
except (ValueError, TypeError):
|
||||
iteration_period = 24 # 默认24小时
|
||||
|
||||
return ReflectionConfig(
|
||||
enabled=config_data.get("enable_self_reflexion", False),
|
||||
iteration_period=str(iteration_period), # ReflectionConfig期望字符串
|
||||
|
||||
@@ -129,6 +129,12 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
if not params.rerank_id:
|
||||
params.rerank_id = configs.get('rerank')
|
||||
|
||||
# reflection_model_id 和 emotion_model_id 默认与 llm_id 一致
|
||||
if not params.reflection_model_id:
|
||||
params.reflection_model_id = params.llm_id
|
||||
if not params.emotion_model_id:
|
||||
params.emotion_model_id = params.llm_id
|
||||
|
||||
config = MemoryConfigRepository.create(self.db, params)
|
||||
self.db.commit()
|
||||
return {"affected": 1, "config_id": config.config_id}
|
||||
@@ -177,11 +183,11 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
|
||||
# --- Read All ---
|
||||
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
|
||||
configs = MemoryConfigRepository.get_all(self.db, workspace_id)
|
||||
results = MemoryConfigRepository.get_all(self.db, workspace_id)
|
||||
|
||||
# 将 ORM 对象转换为字典列表
|
||||
data_list = []
|
||||
for config in configs:
|
||||
for config, scene_name in results:
|
||||
# 安全地转换 user_id 为 int
|
||||
config_id_old = None
|
||||
if config.config_id_old:
|
||||
@@ -203,6 +209,8 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
"end_user_id": config.end_user_id,
|
||||
"config_id_old": config_id_old,
|
||||
"apply_id": config.apply_id,
|
||||
"scene_id": str(config.scene_id) if config.scene_id else None,
|
||||
"scene_name": scene_name, # 新增:场景名称
|
||||
"llm_id": config.llm_id,
|
||||
"embedding_id": config.embedding_id,
|
||||
"rerank_id": config.rerank_id,
|
||||
@@ -628,10 +636,9 @@ async def analytics_recent_activity_stats() -> Dict[str, Any]:
|
||||
if m < 1:
|
||||
latest_relative = "刚刚"
|
||||
elif m < 60:
|
||||
latest_relative = f"{m}分钟前"
|
||||
latest_relative = "一会前"
|
||||
else:
|
||||
h = int(m // 60)
|
||||
latest_relative = f"{h}小时前" if h < 24 else f"{int(h // 24)}天前"
|
||||
latest_relative = "较早前"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -280,14 +280,22 @@ class MultiAgentOrchestrator:
|
||||
|
||||
# 4. 提取子 Agent 的 conversation_id(用于多轮对话)
|
||||
sub_conversation_id = None
|
||||
total_tokens = 0
|
||||
|
||||
if isinstance(results, dict):
|
||||
sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id")
|
||||
# 提取 token 信息
|
||||
usage = results.get("usage", {}) or results.get("result", {}).get("usage", {})
|
||||
total_tokens += usage.get("total_tokens", 0)
|
||||
elif isinstance(results, list) and results:
|
||||
for item in results:
|
||||
if "result" in item:
|
||||
sub_conversation_id = item["result"].get("conversation_id")
|
||||
if sub_conversation_id:
|
||||
break
|
||||
# 累加每个子 Agent 的 token
|
||||
usage = item.get("usage", {}) or item.get("result", {}).get("usage", {})
|
||||
total_tokens += usage.get("total_tokens", 0)
|
||||
|
||||
logger.info(
|
||||
"多 Agent 任务完成",
|
||||
@@ -301,9 +309,15 @@ class MultiAgentOrchestrator:
|
||||
return {
|
||||
"message": final_result,
|
||||
"conversation_id": sub_conversation_id,
|
||||
"mode": OrchestrationMode.SUPERVISOR,
|
||||
"elapsed_time": elapsed_time,
|
||||
"strategy": routing_decision.get("collaboration_strategy", "single"),
|
||||
"sub_results": results
|
||||
"sub_results": results,
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -1552,10 +1566,12 @@ class MultiAgentOrchestrator:
|
||||
return {
|
||||
"message": result.get("response", ""),
|
||||
"conversation_id": result.get("conversation_id"),
|
||||
"mode": OrchestrationMode.COLLABORATION,
|
||||
"elapsed_time": elapsed_time,
|
||||
"strategy": "collaboration",
|
||||
"active_agent": result.get("active_agent"),
|
||||
"sub_results": result
|
||||
"sub_results": result,
|
||||
"usage": result.get("usage")
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""多 Agent 配置管理服务"""
|
||||
import uuid
|
||||
import json
|
||||
from typing import Optional, List, Tuple, Any, Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
@@ -427,6 +428,23 @@ class MultiAgentService:
|
||||
memory=getattr(request, 'memory', True) # 记忆功能参数
|
||||
)
|
||||
|
||||
await self._save_conversation_message(
|
||||
conversation_id=request.conversation_id,
|
||||
user_message=request.message,
|
||||
assistant_message=result.get("message", ""),
|
||||
app_id=app_id,
|
||||
user_id=request.user_id,
|
||||
meta_data={
|
||||
"mode": result.get("mode"),
|
||||
"elapsed_time": result.get("elapsed_time"),
|
||||
"usage": result.get("usage", {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def run_stream(
|
||||
@@ -451,11 +469,14 @@ class MultiAgentService:
|
||||
raise ResourceNotFoundException("多 Agent 配置", str(app_id))
|
||||
|
||||
if not config.is_active:
|
||||
raise BusinessException("多 Agent 配置已禁用", BizCode.RESOURCE_DISABLED)
|
||||
raise BusinessException("多 Agent 配置已禁用", BizCode.NOT_FOUND)
|
||||
|
||||
# 2. 创建编排器
|
||||
orchestrator = MultiAgentOrchestrator(self.db, config)
|
||||
|
||||
full_content = ""
|
||||
total_tokens = 0
|
||||
|
||||
# 3. 流式执行任务
|
||||
async for event in orchestrator.execute_stream(
|
||||
message=request.message,
|
||||
@@ -468,7 +489,88 @@ class MultiAgentService:
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
if "sub_usage" in event:
|
||||
if "data:" in event:
|
||||
try:
|
||||
data_line = event.split("data: ", 1)[1].strip()
|
||||
data = json.loads(data_line)
|
||||
if "total_tokens" in data:
|
||||
total_tokens += data["total_tokens"]
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
yield event
|
||||
if "data:" in event:
|
||||
try:
|
||||
data_line = event.split("data: ", 1)[1].strip()
|
||||
data = json.loads(data_line)
|
||||
if "content" in data:
|
||||
full_content += data["content"]
|
||||
except:
|
||||
pass
|
||||
|
||||
await self._save_conversation_message(
|
||||
conversation_id=request.conversation_id,
|
||||
user_message=request.message,
|
||||
assistant_message=full_content,
|
||||
app_id=app_id,
|
||||
user_id=request.user_id,
|
||||
meta_data={
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
async def _save_conversation_message(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
user_message: str,
|
||||
assistant_message: str,
|
||||
meta_data: dict,
|
||||
app_id: Optional[uuid.UUID] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> None:
|
||||
"""保存会话消息
|
||||
|
||||
Args:
|
||||
conversation_id: 会话ID
|
||||
user_message: 用户消息
|
||||
assistant_message: AI 回复消息
|
||||
meta_data: 元数据(包括 token 消耗)
|
||||
app_id: 应用ID
|
||||
user_id: 用户ID
|
||||
"""
|
||||
try:
|
||||
from app.services.conversation_service import ConversationService
|
||||
|
||||
conversation_service = ConversationService(self.db)
|
||||
|
||||
conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
content=user_message
|
||||
)
|
||||
conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=assistant_message,
|
||||
meta_data=meta_data
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"保存多 Agent 会话消息",
|
||||
extra={
|
||||
"conversation_id": conversation_id,
|
||||
"user_message_length": len(user_message),
|
||||
"assistant_message_length": len(assistant_message)
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("保存会话消息失败", extra={"error": str(e)})
|
||||
|
||||
# def add_sub_agent(
|
||||
# self,
|
||||
|
||||
1162
api/app/services/ontology_service.py
Normal file
1162
api/app/services/ontology_service.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator
|
||||
@@ -18,7 +19,8 @@ from app.models.prompt_optimizer_model import (
|
||||
)
|
||||
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository
|
||||
from app.repositories.prompt_optimizer_repository import (
|
||||
PromptOptimizerSessionRepository
|
||||
PromptOptimizerSessionRepository,
|
||||
PromptReleaseRepository
|
||||
)
|
||||
from app.schemas.prompt_optimizer_schema import OptimizePromptResult
|
||||
|
||||
@@ -28,6 +30,8 @@ logger = get_business_logger()
|
||||
class PromptOptimizerService:
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.optim_repo = PromptOptimizerSessionRepository(self.db)
|
||||
self.release_repo = PromptReleaseRepository(self.db)
|
||||
|
||||
def get_model_config(
|
||||
self,
|
||||
@@ -78,10 +82,12 @@ class PromptOptimizerService:
|
||||
Returns:
|
||||
PromptOptimzerSession: The newly created prompt optimization session.
|
||||
"""
|
||||
session = PromptOptimizerSessionRepository(self.db).create_session(
|
||||
session = self.optim_repo.create_session(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id
|
||||
)
|
||||
self.db.commit()
|
||||
self.db.refresh(session)
|
||||
return session
|
||||
|
||||
def get_session_message_history(
|
||||
@@ -106,7 +112,7 @@ class PromptOptimizerService:
|
||||
- role (str): The role of the message sender, e.g., 'system', 'user', or 'assistant'.
|
||||
- content (str): The content of the message.
|
||||
"""
|
||||
history = PromptOptimizerSessionRepository(self.db).get_session_history(
|
||||
history = self.optim_repo.get_session_history(
|
||||
session_id=session_id,
|
||||
user_id=user_id
|
||||
)
|
||||
@@ -177,11 +183,12 @@ class PromptOptimizerService:
|
||||
base_url=api_config.api_base
|
||||
), type=ModelType(model_config.type))
|
||||
try:
|
||||
with open('app/services/prompt/prompt_optimizer_system.jinja2', 'r', encoding='utf-8') as f:
|
||||
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
||||
with open(os.path.join(prompt_path, 'prompt_optimizer_system.jinja2'), 'r', encoding='utf-8') as f:
|
||||
opt_system_prompt = f.read()
|
||||
rendered_system_message = Template(opt_system_prompt).render()
|
||||
|
||||
with open('app/services/prompt/prompt_optimizer_user.jinja2', 'r', encoding='utf-8') as f:
|
||||
with open(os.path.join(prompt_path, 'prompt_optimizer_user.jinja2'), 'r', encoding='utf-8') as f:
|
||||
opt_user_prompt = f.read()
|
||||
except FileNotFoundError:
|
||||
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
|
||||
@@ -296,4 +303,165 @@ class PromptOptimizerService:
|
||||
role=role,
|
||||
content=content
|
||||
)
|
||||
self.db.commit()
|
||||
self.db.refresh(message)
|
||||
return message
|
||||
|
||||
def save_prompt(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
session_id: uuid.UUID,
|
||||
title: str,
|
||||
prompt: str
|
||||
) -> dict:
|
||||
"""
|
||||
Create and save a new prompt release for a given session.
|
||||
|
||||
Args:
|
||||
tenant_id (uuid.UUID): The ID of the tenant owning the prompt.
|
||||
session_id (uuid.UUID): The ID of the session to associate with this prompt.
|
||||
title (str): The title of the prompt release.
|
||||
prompt (str): The content of the prompt.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing:
|
||||
- id (UUID): The unique ID of the created prompt release.
|
||||
- session_id (UUID): The session ID linked to the release.
|
||||
- title (str): The title of the prompt.
|
||||
- prompt (str): The prompt content.
|
||||
- created_at (int): Timestamp (in milliseconds) of when the prompt was created.
|
||||
|
||||
Raises:
|
||||
BusinessException: If a prompt release already exists for the given session.
|
||||
"""
|
||||
session = self.optim_repo.get_session_by_id(session_id)
|
||||
if session is None or session.tenant_id != tenant_id:
|
||||
raise BusinessException(
|
||||
"Session does not exist or the current user has no access",
|
||||
BizCode.BAD_REQUEST
|
||||
)
|
||||
|
||||
if self.release_repo.get_prompt_by_session_id(session_id):
|
||||
raise BusinessException(
|
||||
"A release already exists for the current session",
|
||||
BizCode.BAD_REQUEST
|
||||
)
|
||||
|
||||
prompt_obj = self.release_repo.create_prompt_release(
|
||||
tenant_id=tenant_id,
|
||||
title=title,
|
||||
session_id=session_id,
|
||||
prompt=prompt
|
||||
)
|
||||
self.db.commit()
|
||||
self.db.refresh(prompt_obj)
|
||||
return {
|
||||
"id": prompt_obj.id,
|
||||
"session_id": prompt_obj.session_id,
|
||||
"title": prompt_obj.title,
|
||||
"prompt": prompt_obj.prompt,
|
||||
"created_at": int(prompt_obj.created_at.timestamp() * 1000)
|
||||
}
|
||||
|
||||
def delete_prompt(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
prompt_id: uuid.UUID
|
||||
) -> None:
|
||||
"""
|
||||
Soft delete a prompt release by prompt_id.
|
||||
|
||||
Args:
|
||||
tenant_id (uuid.UUID): Tenant identifier.
|
||||
prompt_id (uuid.UUID): Prompt identifier.
|
||||
|
||||
Raises:
|
||||
BusinessException: If the prompt does not exist or already deleted.
|
||||
"""
|
||||
prompt_obj = self.release_repo.get_prompt_by_id(prompt_id)
|
||||
if not prompt_obj or prompt_obj.is_delete:
|
||||
raise BusinessException(
|
||||
"Prompt does not exist or has already been deleted",
|
||||
BizCode.NOT_FOUND
|
||||
)
|
||||
|
||||
if prompt_obj.tenant_id != tenant_id:
|
||||
raise BusinessException(
|
||||
"No permission to delete this prompt",
|
||||
BizCode.FORBIDDEN
|
||||
)
|
||||
|
||||
self.release_repo.soft_delete_prompt(prompt_obj)
|
||||
self.db.commit()
|
||||
logger.info(f"Prompt soft deleted, prompt_id={prompt_id}, tenant_id={tenant_id}")
|
||||
|
||||
def get_release_list(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
page: int,
|
||||
page_size: int,
|
||||
filter_keyword: str | None = None
|
||||
) -> dict[str, int | list[Any]]:
|
||||
"""
|
||||
Get paginated list of prompt releases with optional filter.
|
||||
|
||||
Args:
|
||||
tenant_id (uuid.UUID): Tenant identifier.
|
||||
page (int): Page number (starting from 1).
|
||||
page_size (int): Number of items per page.
|
||||
filter_keyword (str | None): Optional keyword to filter by title.
|
||||
|
||||
Returns:
|
||||
dict: Contains total count, pagination info, and list of releases.
|
||||
"""
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Get total count and releases based on filter
|
||||
if filter_keyword:
|
||||
total = self.release_repo.count_prompts_by_keyword(tenant_id, filter_keyword)
|
||||
releases = self.release_repo.search_prompts_paginated(
|
||||
tenant_id=tenant_id,
|
||||
keyword=filter_keyword,
|
||||
offset=offset,
|
||||
limit=page_size
|
||||
)
|
||||
else:
|
||||
total = self.release_repo.count_prompts(tenant_id)
|
||||
releases = self.release_repo.get_prompts_paginated(
|
||||
tenant_id=tenant_id,
|
||||
offset=offset,
|
||||
limit=page_size
|
||||
)
|
||||
|
||||
items = []
|
||||
for release in releases:
|
||||
# Get first user message from session
|
||||
first_message = self.optim_repo.get_first_user_message(
|
||||
session_id=release.session_id
|
||||
)
|
||||
|
||||
items.append({
|
||||
"id": release.id,
|
||||
"title": release.title,
|
||||
"prompt": release.prompt,
|
||||
"created_at": int(release.created_at.timestamp() * 1000),
|
||||
"first_message": first_message
|
||||
})
|
||||
|
||||
log_msg = f"Retrieved {len(items)} prompt releases, page={page}, tenant_id={tenant_id}"
|
||||
if filter_keyword:
|
||||
log_msg += f", filter='{filter_keyword}'"
|
||||
logger.info(log_msg)
|
||||
|
||||
result = {
|
||||
"page": {
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"hasnext": page * page_size < total
|
||||
},
|
||||
"keyword": filter_keyword,
|
||||
"items": items
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
@@ -282,7 +282,14 @@ class SharedChatService:
|
||||
self.conversation_service.save_conversation_messages(
|
||||
conversation_id=conversation.id,
|
||||
user_message=message,
|
||||
assistant_message=result["content"]
|
||||
assistant_message=result["content"],
|
||||
meta_data={
|
||||
"usage": result.get("usage", {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
})
|
||||
}
|
||||
)
|
||||
# self.conversation_service.add_message(
|
||||
# conversation_id=conversation.id,
|
||||
@@ -469,6 +476,7 @@ class SharedChatService:
|
||||
|
||||
# 流式调用 Agent
|
||||
full_content = ""
|
||||
total_tokens = 0
|
||||
async for chunk in agent.chat_stream(
|
||||
message=message,
|
||||
history=history,
|
||||
@@ -479,9 +487,12 @@ class SharedChatService:
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag
|
||||
):
|
||||
full_content += chunk
|
||||
# 发送消息块事件
|
||||
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
|
||||
if isinstance(chunk, int):
|
||||
total_tokens = chunk
|
||||
else:
|
||||
full_content += chunk
|
||||
# 发送消息块事件
|
||||
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
@@ -498,7 +509,7 @@ class SharedChatService:
|
||||
content=full_content,
|
||||
meta_data={
|
||||
"model": api_key_obj.model_name,
|
||||
"usage": {}
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.conversation_repository import ConversationRepository
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.repositories.neo4j.cypher_queries import Graph_Node_query
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
@@ -1508,7 +1509,6 @@ async def analytics_graph_data(
|
||||
user_uuid = uuid.UUID(end_user_id)
|
||||
repo = EndUserRepository(db)
|
||||
end_user = repo.get_by_id(user_uuid)
|
||||
|
||||
if not end_user:
|
||||
logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户")
|
||||
return {
|
||||
@@ -1562,21 +1562,11 @@ async def analytics_graph_data(
|
||||
}
|
||||
else:
|
||||
# 查询所有节点
|
||||
node_query = """
|
||||
MATCH (n)
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) as id,
|
||||
labels(n)[0] as label,
|
||||
properties(n) as properties
|
||||
LIMIT $limit
|
||||
"""
|
||||
node_query=Graph_Node_query
|
||||
node_params = {
|
||||
"end_user_id": end_user_id,
|
||||
"limit": limit
|
||||
}
|
||||
|
||||
|
||||
# 执行节点查询
|
||||
node_results = await _neo4j_connector.execute_query(node_query, **node_params)
|
||||
|
||||
@@ -1587,9 +1577,9 @@ async def analytics_graph_data(
|
||||
|
||||
for record in node_results:
|
||||
node_id = record["id"]
|
||||
node_label = record["label"]
|
||||
node_labels = record.get("labels", [])
|
||||
node_label = node_labels[0] if node_labels else "Unknown"
|
||||
node_props = record["properties"]
|
||||
|
||||
# 根据节点类型提取需要的属性字段
|
||||
filtered_props = await _extract_node_properties(node_label, node_props,node_id)
|
||||
|
||||
|
||||
320
api/app/tasks.py
320
api/app/tasks.py
@@ -774,7 +774,15 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(name="app.tasks.regenerate_memory_cache", bind=True)
|
||||
@celery_app.task(
|
||||
name="app.tasks.regenerate_memory_cache",
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
max_retries=0,
|
||||
acks_late=False,
|
||||
time_limit=3600,
|
||||
soft_time_limit=3300,
|
||||
)
|
||||
def regenerate_memory_cache(self) -> Dict[str, Any]:
|
||||
"""定时任务:为所有用户重新生成记忆洞察和用户摘要缓存
|
||||
|
||||
@@ -966,7 +974,16 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(name="app.tasks.workspace_reflection_task", bind=True)
|
||||
|
||||
@celery_app.task(
|
||||
name="app.tasks.workspace_reflection_task",
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
max_retries=0,
|
||||
acks_late=False,
|
||||
time_limit=300,
|
||||
soft_time_limit=240,
|
||||
)
|
||||
def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
"""定时任务:每30秒运行工作空间反思功能
|
||||
|
||||
@@ -1049,6 +1066,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务")
|
||||
|
||||
except Exception as e:
|
||||
db.rollback() # Rollback failed transaction to allow next query
|
||||
api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}")
|
||||
all_reflection_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
@@ -1111,7 +1129,16 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
|
||||
|
||||
|
||||
@celery_app.task(name="app.tasks.run_forgetting_cycle_task", bind=True)
|
||||
|
||||
@celery_app.task(
|
||||
name="app.tasks.run_forgetting_cycle_task",
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
max_retries=0,
|
||||
acks_late=False,
|
||||
time_limit=7200,
|
||||
soft_time_limit=7000,
|
||||
)
|
||||
def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
|
||||
"""定时任务:运行遗忘周期
|
||||
|
||||
@@ -1178,3 +1205,290 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
||||
return result
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Long-term Memory Storage Tasks (Batched Write Strategies)
|
||||
# =============================================================================
|
||||
|
||||
@celery_app.task(name="app.core.memory.agent.long_term_storage.window", bind=True)
|
||||
def long_term_storage_window_task(
|
||||
self,
|
||||
end_user_id: str,
|
||||
langchain_messages: List[Dict[str, Any]],
|
||||
config_id: str,
|
||||
scope: int = 6
|
||||
) -> Dict[str, Any]:
|
||||
"""Celery task for window-based long-term memory storage.
|
||||
|
||||
Accumulates messages in Redis buffer until window size (scope) is reached,
|
||||
then writes batched messages to Neo4j.
|
||||
|
||||
Args:
|
||||
end_user_id: End user identifier
|
||||
langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}]
|
||||
config_id: Memory configuration ID
|
||||
scope: Window size (number of messages before triggering write)
|
||||
|
||||
Returns:
|
||||
Dict containing task status and metadata
|
||||
"""
|
||||
from app.core.logging_config import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
logger.info(f"[LONG_TERM_WINDOW] Starting task - end_user_id={end_user_id}, scope={scope}")
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import window_dialogue
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
# Save to Redis buffer first
|
||||
write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
|
||||
|
||||
# Load memory config
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="LongTermStorageTask"
|
||||
)
|
||||
|
||||
# Execute window-based dialogue storage
|
||||
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||
|
||||
return {"status": "SUCCESS", "strategy": "window", "scope": scope}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
logger.info(f"[LONG_TERM_WINDOW] Task completed - elapsed_time={elapsed_time:.2f}s")
|
||||
|
||||
return {
|
||||
**result,
|
||||
"end_user_id": end_user_id,
|
||||
"config_id": config_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"[LONG_TERM_WINDOW] Task failed - error={str(e)}", exc_info=True)
|
||||
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"strategy": "window",
|
||||
"error": str(e),
|
||||
"end_user_id": end_user_id,
|
||||
"config_id": config_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
|
||||
|
||||
# @celery_app.task(name="app.core.memory.agent.long_term_storage.time", bind=True)
|
||||
# def long_term_storage_time_task(
|
||||
# self,
|
||||
# end_user_id: str,
|
||||
# config_id: str,
|
||||
# time_window: int = 5
|
||||
# ) -> Dict[str, Any]:
|
||||
# """Celery task for time-based long-term memory storage.
|
||||
|
||||
# Retrieves recent sessions from Redis within time window and writes to Neo4j.
|
||||
|
||||
# Args:
|
||||
# end_user_id: End user identifier
|
||||
# config_id: Memory configuration ID
|
||||
# time_window: Time window in minutes for retrieving recent sessions
|
||||
|
||||
# Returns:
|
||||
# Dict containing task status and metadata
|
||||
# """
|
||||
# from app.core.logging_config import get_logger
|
||||
# logger = get_logger(__name__)
|
||||
|
||||
# logger.info(f"[LONG_TERM_TIME] Starting task - end_user_id={end_user_id}, time_window={time_window}")
|
||||
# start_time = time.time()
|
||||
|
||||
# async def _run() -> Dict[str, Any]:
|
||||
# from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage
|
||||
# from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# db = next(get_db())
|
||||
# try:
|
||||
# # Load memory config
|
||||
# config_service = MemoryConfigService(db)
|
||||
# memory_config = config_service.load_memory_config(
|
||||
# config_id=config_id,
|
||||
# service_name="LongTermStorageTask"
|
||||
# )
|
||||
|
||||
# # Execute time-based storage
|
||||
# await memory_long_term_storage(end_user_id, memory_config, time_window)
|
||||
|
||||
# return {"status": "SUCCESS", "strategy": "time", "time_window": time_window}
|
||||
# finally:
|
||||
# db.close()
|
||||
|
||||
# try:
|
||||
# import nest_asyncio
|
||||
# nest_asyncio.apply()
|
||||
# except ImportError:
|
||||
# pass
|
||||
|
||||
# try:
|
||||
# loop = asyncio.get_event_loop()
|
||||
# if loop.is_closed():
|
||||
# loop = asyncio.new_event_loop()
|
||||
# asyncio.set_event_loop(loop)
|
||||
# except RuntimeError:
|
||||
# loop = asyncio.new_event_loop()
|
||||
# asyncio.set_event_loop(loop)
|
||||
|
||||
# try:
|
||||
# result = loop.run_until_complete(_run())
|
||||
# elapsed_time = time.time() - start_time
|
||||
|
||||
# logger.info(f"[LONG_TERM_TIME] Task completed - elapsed_time={elapsed_time:.2f}s")
|
||||
|
||||
# return {
|
||||
# **result,
|
||||
# "end_user_id": end_user_id,
|
||||
# "config_id": config_id,
|
||||
# "elapsed_time": elapsed_time,
|
||||
# "task_id": self.request.id
|
||||
# }
|
||||
# except Exception as e:
|
||||
# elapsed_time = time.time() - start_time
|
||||
# logger.error(f"[LONG_TERM_TIME] Task failed - error={str(e)}", exc_info=True)
|
||||
|
||||
# return {
|
||||
# "status": "FAILURE",
|
||||
# "strategy": "time",
|
||||
# "error": str(e),
|
||||
# "end_user_id": end_user_id,
|
||||
# "config_id": config_id,
|
||||
# "elapsed_time": elapsed_time,
|
||||
# "task_id": self.request.id
|
||||
# }
|
||||
|
||||
|
||||
# @celery_app.task(name="app.core.memory.agent.long_term_storage.aggregate", bind=True)
|
||||
# def long_term_storage_aggregate_task(
|
||||
# self,
|
||||
# end_user_id: str,
|
||||
# langchain_messages: List[Dict[str, Any]],
|
||||
# config_id: str
|
||||
# ) -> Dict[str, Any]:
|
||||
# """Celery task for aggregate-based long-term memory storage.
|
||||
|
||||
# Uses LLM to determine if new messages describe the same event as history.
|
||||
# Only writes to Neo4j if messages represent new information (not duplicates).
|
||||
|
||||
# Args:
|
||||
# end_user_id: End user identifier
|
||||
# langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}]
|
||||
# config_id: Memory configuration ID
|
||||
|
||||
# Returns:
|
||||
# Dict containing task status, is_same_event flag, and metadata
|
||||
# """
|
||||
# from app.core.logging_config import get_logger
|
||||
# logger = get_logger(__name__)
|
||||
|
||||
# logger.info(f"[LONG_TERM_AGGREGATE] Starting task - end_user_id={end_user_id}")
|
||||
# start_time = time.time()
|
||||
|
||||
# async def _run() -> Dict[str, Any]:
|
||||
# from app.core.memory.agent.langgraph_graph.routing.write_router import aggregate_judgment
|
||||
# from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
|
||||
# from app.core.memory.agent.utils.redis_tool import write_store
|
||||
# from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# db = next(get_db())
|
||||
# try:
|
||||
# # Save to Redis buffer first
|
||||
# write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
|
||||
|
||||
# # Load memory config
|
||||
# config_service = MemoryConfigService(db)
|
||||
# memory_config = config_service.load_memory_config(
|
||||
# config_id=config_id,
|
||||
# service_name="LongTermStorageTask"
|
||||
# )
|
||||
|
||||
# # Execute aggregate judgment
|
||||
# result = await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||
|
||||
# return {
|
||||
# "status": "SUCCESS",
|
||||
# "strategy": "aggregate",
|
||||
# "is_same_event": result.get("is_same_event", False),
|
||||
# "wrote_to_neo4j": not result.get("is_same_event", False)
|
||||
# }
|
||||
# finally:
|
||||
# db.close()
|
||||
|
||||
# try:
|
||||
# import nest_asyncio
|
||||
# nest_asyncio.apply()
|
||||
# except ImportError:
|
||||
# pass
|
||||
|
||||
# try:
|
||||
# loop = asyncio.get_event_loop()
|
||||
# if loop.is_closed():
|
||||
# loop = asyncio.new_event_loop()
|
||||
# asyncio.set_event_loop(loop)
|
||||
# except RuntimeError:
|
||||
# loop = asyncio.new_event_loop()
|
||||
# asyncio.set_event_loop(loop)
|
||||
|
||||
# try:
|
||||
# result = loop.run_until_complete(_run())
|
||||
# elapsed_time = time.time() - start_time
|
||||
|
||||
# logger.info(f"[LONG_TERM_AGGREGATE] Task completed - is_same_event={result.get('is_same_event')}, elapsed_time={elapsed_time:.2f}s")
|
||||
|
||||
# return {
|
||||
# **result,
|
||||
# "end_user_id": end_user_id,
|
||||
# "config_id": config_id,
|
||||
# "elapsed_time": elapsed_time,
|
||||
# "task_id": self.request.id
|
||||
# }
|
||||
# except Exception as e:
|
||||
# elapsed_time = time.time() - start_time
|
||||
# logger.error(f"[LONG_TERM_AGGREGATE] Task failed - error={str(e)}", exc_info=True)
|
||||
|
||||
# return {
|
||||
# "status": "FAILURE",
|
||||
# "strategy": "aggregate",
|
||||
# "error": str(e),
|
||||
# "end_user_id": end_user_id,
|
||||
# "config_id": config_id,
|
||||
# "elapsed_time": elapsed_time,
|
||||
# "task_id": self.request.id
|
||||
# }
|
||||
|
||||
@@ -5,42 +5,68 @@ Shared utilities for configuration handling to avoid circular imports.
|
||||
"""
|
||||
from uuid import UUID
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid as uuid_module
|
||||
|
||||
|
||||
def resolve_config_id(config_id: UUID | int|str, db: Session) -> UUID:
|
||||
def resolve_config_id(config_id: UUID | int | str, db: Session) -> UUID:
|
||||
"""
|
||||
解析 config_id,如果是整数则通过 config_id_old 查找对应的 UUID
|
||||
解析 config_id,支持 UUID、UUID字符串、整数等多种格式
|
||||
|
||||
Args:
|
||||
config_id: 配置ID(UUID 或整数)
|
||||
config_id: 配置ID(UUID、UUID字符串 或 整数)
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
UUID: 解析后的配置ID
|
||||
|
||||
Raises:
|
||||
ValueError: 当找不到对应的配置时
|
||||
ValueError: 当找不到对应的配置时或格式无效时
|
||||
"""
|
||||
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
if isinstance(config_id, UUID):
|
||||
|
||||
# 1. 如果已经是 UUID 类型,直接返回
|
||||
if isinstance(config_id, UUID):
|
||||
return config_id
|
||||
if isinstance(config_id, str) and len(config_id)<=6:
|
||||
memory_config = db.query(MemoryConfig).filter(
|
||||
MemoryConfig.config_id_old == int(config_id)
|
||||
).first()
|
||||
print(memory_config)
|
||||
if not memory_config:
|
||||
raise ValueError(f"STR 未找到 config_id_old={config_id} 对应的配置")
|
||||
return memory_config.config_id
|
||||
|
||||
# 2. 如果是字符串类型
|
||||
if isinstance(config_id, str):
|
||||
config_id_stripped = config_id.strip()
|
||||
|
||||
# 2.1 尝试解析为 UUID(标准 UUID 字符串长度为 36)
|
||||
try:
|
||||
return uuid_module.UUID(config_id_stripped)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# 2.2 尝试解析为整数(用于查询 config_id_old)
|
||||
try:
|
||||
old_id = int(config_id_stripped)
|
||||
if old_id > 0:
|
||||
memory_config = db.query(MemoryConfig).filter(
|
||||
MemoryConfig.config_id_old == old_id
|
||||
).first()
|
||||
if not memory_config:
|
||||
raise ValueError(f"未找到 config_id_old={old_id} 对应的配置")
|
||||
return memory_config.config_id
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# 2.3 无法解析的字符串格式
|
||||
raise ValueError(f"无效的 config_id 格式: '{config_id}'(必须是 UUID 或正整数)")
|
||||
|
||||
# 3. 如果是整数类型,通过 config_id_old 查找
|
||||
if isinstance(config_id, int):
|
||||
if config_id <= 0:
|
||||
raise ValueError(f"config_id 必须是正整数: {config_id}")
|
||||
|
||||
memory_config = db.query(MemoryConfig).filter(
|
||||
MemoryConfig.config_id_old == config_id
|
||||
).first()
|
||||
|
||||
if not memory_config:
|
||||
raise ValueError(f"INT 未找到 config_id_old={config_id} 对应的配置")
|
||||
raise ValueError(f"未找到 config_id_old={config_id} 对应的配置")
|
||||
|
||||
return memory_config.config_id
|
||||
|
||||
return config_id
|
||||
# 4. 不支持的类型
|
||||
raise ValueError(f"不支持的 config_id 类型: {type(config_id).__name__}")
|
||||
|
||||
@@ -1,4 +1,32 @@
|
||||
{
|
||||
"v0.2.2": {
|
||||
"introduction": {
|
||||
"codeName": "淬锋(Temper)",
|
||||
"releaseDate": "2026-1-31",
|
||||
"upgradePosition": "本次发布聚焦平台稳定性和性能优化。正如\"淬锋\"之名——千锤百炼,淬火成锋,我们通过严格测试和修复打磨系统品质。引入 Agent 工作流的代码执行能力、改进模型并发管理,并修复了记忆系统的多个关键问题。",
|
||||
"coreUpgrades": [
|
||||
"1. Agent平台增强<br>* 模型并发管理:优化模型广场的并发请求处理和资源分配能力。",
|
||||
"2. 记忆系统优化<br>* Celery 队列修复:解决任务队列问题,提升异步记忆处理的可靠性<br>* 记忆 Agent 优化:提升记忆 Agent 的性能和效率<br>* 接口响应速度优化:优化记忆接口响应时间,加快操作速度。",
|
||||
"3. 情绪记忆与识别升级<br>* 情绪记忆角色识别修复:解决情绪记忆上下文中的角色/人物识别问题<br>* 角色识别增强:提升对话记忆中的角色/人物识别准确性。",
|
||||
"<br>",
|
||||
"MemoryBear 持续致力于为 AI 应用提供类人记忆能力。本次以稳定性为核心的发布,进一步夯实了「感知→精炼→关联→遗忘」范式的基础。",
|
||||
"未来版本将在此坚实基础上,扩展 Agent 能力并深化记忆智能特性。"
|
||||
]
|
||||
},
|
||||
"introduction_en": {
|
||||
"codeName": "Temper (淬锋)",
|
||||
"releaseDate": "2026-1-31",
|
||||
"upgradePosition": "This release focuses on platform stability and performance optimization — true to its codename \"淬锋\" (tempered blade), we've refined the system through rigorous testing and fixes. Introducing Python code execution for Agent workflows, improved model concurrency management, and critical fixes across the memory system.",
|
||||
"coreUpgrades": [
|
||||
"1. Agent Platform Enhancements<br>* Model Concurrency Management: Enhanced Model Plaza with improved concurrent model request handling and resource allocation.",
|
||||
"2. Memory System Improvements<br>* Celery Queue Fix: Resolved task queue issues for more reliable asynchronous memory processing<br>* Memory Agent Optimization: Improved memory Agent performance and efficiency<br>* API Response Speed: Optimized memory interface response times for faster operations.",
|
||||
"3. Emotional Memory & Recognition Upgrades<br>* Emotion Memory Role Recognition Fix: Resolved issues with role/character identification in emotional memory contexts<br>* Role Recognition Enhancement: Improved character/role identification accuracy in conversation memory.",
|
||||
"<br>",
|
||||
"MemoryBear continues advancing toward human-like memory capabilities for AI applications. This stability-focused release strengthens the foundation for our Perception → Refinement → Association → Forgetting paradigm.",
|
||||
"Future releases will build on this solid base with expanded Agent capabilities and deeper memory intelligence features."
|
||||
]
|
||||
}
|
||||
},
|
||||
"v0.2.1": {
|
||||
"introduction": {
|
||||
"codeName": "启知",
|
||||
|
||||
@@ -19,6 +19,7 @@ services:
|
||||
depends_on:
|
||||
- worker-memory
|
||||
- worker-document
|
||||
- worker-periodic
|
||||
|
||||
# Memory worker - Memory read/write tasks (threads pool for asyncio)
|
||||
worker-memory:
|
||||
@@ -48,6 +49,20 @@ services:
|
||||
networks:
|
||||
- celery
|
||||
|
||||
# Periodic worker - Scheduled/beat tasks (prefork, low concurrency)
|
||||
worker-periodic:
|
||||
image: redbear-mem-open:latest
|
||||
container_name: worker-periodic
|
||||
env_file:
|
||||
- .env
|
||||
volumes:
|
||||
- ./files:/files
|
||||
- /etc/localtime:/etc/localtime:ro
|
||||
command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=prefork --concurrency=2 --queues=periodic_tasks --max-tasks-per-child=50 -n periodic_worker@%h
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
- celery
|
||||
|
||||
# Celery Beat - scheduler
|
||||
beat:
|
||||
image: redbear-mem-open:latest
|
||||
@@ -69,7 +84,7 @@ services:
|
||||
container_name: sandbox
|
||||
ports:
|
||||
- "8194"
|
||||
command: /code/.venv/bin/python main.py
|
||||
command: /code/.venv/bin/uvicorn main:app --host 0.0.0.0 --port 8194 --log-level debug
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
- sandbox
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
|
||||
# Language Configuration
|
||||
# Supported values: "zh" (Chinese), "en" (English)
|
||||
# This controls the language used for memory summary titles and other generated content
|
||||
DEFAULT_LANGUAGE=zh
|
||||
|
||||
# Neo4j Configuration (记忆系统数据库)
|
||||
NEO4J_URI=
|
||||
NEO4J_USERNAME=
|
||||
|
||||
78
api/migrations/versions/550c10595967_202601301521.py
Normal file
78
api/migrations/versions/550c10595967_202601301521.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""202601301521
|
||||
|
||||
Revision ID: 550c10595967
|
||||
Revises: 5de9b1e28509
|
||||
Create Date: 2026-01-30 15:24:34.647440
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '550c10595967'
|
||||
down_revision: Union[str, None] = '5de9b1e28509'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('ontology_scene',
|
||||
sa.Column('scene_id', sa.UUID(), nullable=False, comment='场景ID'),
|
||||
sa.Column('scene_name', sa.String(length=200), nullable=False, comment='场景名称'),
|
||||
sa.Column('scene_description', sa.Text(), nullable=True, comment='场景描述'),
|
||||
sa.Column('workspace_id', sa.UUID(), nullable=False, comment='所属工作空间ID'),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['workspace_id'], ['workspaces.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('scene_id'),
|
||||
sa.UniqueConstraint('workspace_id', 'scene_name', name='uq_workspace_scene_name')
|
||||
)
|
||||
op.create_index(op.f('ix_ontology_scene_scene_id'), 'ontology_scene', ['scene_id'], unique=False)
|
||||
op.create_index(op.f('ix_ontology_scene_workspace_id'), 'ontology_scene', ['workspace_id'], unique=False)
|
||||
op.create_table('ontology_class',
|
||||
sa.Column('class_id', sa.UUID(), nullable=False, comment='类型ID'),
|
||||
sa.Column('class_name', sa.String(length=200), nullable=False, comment='类型名称'),
|
||||
sa.Column('class_description', sa.Text(), nullable=True, comment='类型描述'),
|
||||
sa.Column('scene_id', sa.UUID(), nullable=False, comment='所属场景ID'),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False, comment='更新时间'),
|
||||
sa.ForeignKeyConstraint(['scene_id'], ['ontology_scene.scene_id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('class_id')
|
||||
)
|
||||
op.create_index(op.f('ix_ontology_class_class_id'), 'ontology_class', ['class_id'], unique=False)
|
||||
op.create_index(op.f('ix_ontology_class_scene_id'), 'ontology_class', ['scene_id'], unique=False)
|
||||
op.create_table('prompt_history',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('tenant_id', sa.UUID(), nullable=False, comment='Tenant ID'),
|
||||
sa.Column('session_id', sa.UUID(), nullable=False, comment='Session ID'),
|
||||
sa.Column('title', sa.String(), nullable=False, comment='Title'),
|
||||
sa.Column('prompt', sa.Text(), nullable=False, comment='Prompt'),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True, comment='Creation Time'),
|
||||
sa.Column('is_delete', sa.Boolean(), nullable=True, comment='Delete'),
|
||||
sa.ForeignKeyConstraint(['session_id'], ['prompt_opt_session_list.id'], ),
|
||||
sa.ForeignKeyConstraint(['tenant_id'], ['tenants.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_prompt_history_created_at'), 'prompt_history', ['created_at'], unique=False)
|
||||
op.create_index(op.f('ix_prompt_history_id'), 'prompt_history', ['id'], unique=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
op.drop_index(op.f('ix_prompt_history_id'), table_name='prompt_history')
|
||||
op.drop_index(op.f('ix_prompt_history_created_at'), table_name='prompt_history')
|
||||
op.drop_table('prompt_history')
|
||||
op.drop_index(op.f('ix_ontology_class_scene_id'), table_name='ontology_class')
|
||||
op.drop_index(op.f('ix_ontology_class_class_id'), table_name='ontology_class')
|
||||
op.drop_table('ontology_class')
|
||||
op.drop_index(op.f('ix_ontology_scene_workspace_id'), table_name='ontology_scene')
|
||||
op.drop_index(op.f('ix_ontology_scene_scene_id'), table_name='ontology_scene')
|
||||
op.drop_table('ontology_scene')
|
||||
# ### end Alembic commands ###
|
||||
30
api/migrations/versions/9def72f79398_202601301850.py
Normal file
30
api/migrations/versions/9def72f79398_202601301850.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""202601301850
|
||||
|
||||
Revision ID: 9def72f79398
|
||||
Revises: 550c10595967
|
||||
Create Date: 2026-01-30 18:51:18.290796
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '9def72f79398'
|
||||
down_revision: Union[str, None] = '550c10595967'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('memory_config', sa.Column('scene_id', sa.UUID(), nullable=True, comment='本体场景ID,关联ontology_scene表'))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('memory_config', 'scene_id')
|
||||
# ### end Alembic commands ###
|
||||
@@ -140,6 +140,7 @@ dependencies = [
|
||||
"oss2>=2.19.1",
|
||||
"flower>=2.0.1",
|
||||
"aiofiles>=23.0.0",
|
||||
"owlready2>=0.46",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
FROM python:3.12-slim
|
||||
USER root
|
||||
WORKDIR /code
|
||||
LABEL authors="Eterntiy"
|
||||
|
||||
ARG NEED_MIRROR=0
|
||||
ARG NEED_MIRROR=1
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
|
||||
RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \
|
||||
if [ "$NEED_MIRROR" == "1" ]; then \
|
||||
@@ -17,11 +18,14 @@ RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \
|
||||
apt --no-install-recommends install -y ca-certificates && \
|
||||
apt update && \
|
||||
apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
|
||||
apt install -y nodejs npm && \
|
||||
apt-get install -y --no-install-recommends tzdata libseccomp2 libseccomp-dev && \
|
||||
ln -snf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \
|
||||
echo "Asia/Shanghai" > /etc/timezone && \
|
||||
apt install -y cargo
|
||||
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
|
||||
COPY ./app /code/app
|
||||
COPY ./dependencies /code/dependencies
|
||||
COPY ./lib /code/lib
|
||||
@@ -33,10 +37,15 @@ COPY ./requirements.txt /code/requirements.txt
|
||||
RUN python -m venv .venv
|
||||
RUN .venv/bin/python3 -m pip install -r requirements.txt
|
||||
|
||||
RUN cargo build --release --manifest-path lib/seccomp_python/Cargo.toml
|
||||
RUN npm install --prefix=/code/dependencies/nodejs koffi
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
|
||||
RUN cargo build --release --manifest-path lib/seccomp_redbear/Cargo.toml --features python3
|
||||
RUN mv lib/seccomp_redbear/target/release/libsandbox.so lib/seccomp_redbear/target/release/libpython.so
|
||||
RUN cargo build --release --manifest-path lib/seccomp_redbear/Cargo.toml --features nodejs
|
||||
RUN mv lib/seccomp_redbear/target/release/libsandbox.so lib/seccomp_redbear/target/release/libnodejs.so
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=5s --start-period=60s --retries=3 \
|
||||
CMD curl 127.0.0.1:8194/health
|
||||
|
||||
|
||||
CMD [".venv/bin/python3", "main.py"]
|
||||
CMD [".venv/bin/uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8194", "--log-level", "debug"]
|
||||
4
sandbox/app/__init__.py
Normal file
4
sandbox/app/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/1/29 14:33
|
||||
@@ -4,9 +4,6 @@ from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
import yaml
|
||||
|
||||
SANDBOX_USER_ID = 1000
|
||||
SANDBOX_GROUP_ID = 1000
|
||||
|
||||
DEFAULT_PYTHON_LIB_REQUIREMENTS_AMD = [
|
||||
"/usr/local/lib/python3.12",
|
||||
"/usr/lib/python3",
|
||||
@@ -15,13 +12,18 @@ DEFAULT_PYTHON_LIB_REQUIREMENTS_AMD = [
|
||||
"/etc/nsswitch.conf",
|
||||
"/etc/hosts",
|
||||
"/etc/resolv.conf",
|
||||
"/run/systemd/resolve/stub-resolv.conf",
|
||||
"/run/resolvconf/resolv.conf",
|
||||
"/etc/localtime",
|
||||
"/usr/share/zoneinfo",
|
||||
"/etc/timezone",
|
||||
]
|
||||
|
||||
DEFAULT_NODEJS_LIB_REQUIREMENTS = [
|
||||
"/etc/ssl/certs/ca-certificates.crt",
|
||||
"/etc/nsswitch.conf",
|
||||
"/etc/resolv.conf",
|
||||
"/etc/hosts",
|
||||
]
|
||||
|
||||
|
||||
class AppConfig(BaseModel):
|
||||
"""Application configuration"""
|
||||
@@ -43,83 +45,77 @@ class Config(BaseModel):
|
||||
max_workers: int = 4
|
||||
max_requests: int = 50
|
||||
worker_timeout: int = 30
|
||||
nodejs_path: str = "node"
|
||||
|
||||
enable_network: bool = True
|
||||
enable_preload: bool = False
|
||||
|
||||
python_path: str = ""
|
||||
python_lib_paths: list = Field(default=DEFAULT_PYTHON_LIB_REQUIREMENTS_AMD)
|
||||
python_deps_update_interval: str = "30m"
|
||||
|
||||
nodejs_path: str = ""
|
||||
nodejs_lib_paths: list = Field(default=DEFAULT_NODEJS_LIB_REQUIREMENTS)
|
||||
|
||||
allowed_syscalls: List[int] = Field(default_factory=list)
|
||||
proxy: ProxyConfig = Field(default_factory=ProxyConfig)
|
||||
|
||||
sandbox_user: str = "sandbox"
|
||||
sandbox_uid: int = 65537
|
||||
sandbox_gid: int = 0
|
||||
|
||||
def set_sandbox_gid(self, gid: int):
|
||||
"""Update sandbox GID dynamically"""
|
||||
self.sandbox_gid = gid
|
||||
|
||||
def override_with_env(self):
|
||||
"""Override configuration with environment variables"""
|
||||
env_map = {
|
||||
"DEBUG": ("app.debug", lambda v: v.lower() in ("true", "1", "yes")),
|
||||
"MAX_WORKERS": ("max_workers", int),
|
||||
"MAX_REQUESTS": ("max_requests", int),
|
||||
"SANDBOX_PORT": ("app.port", int),
|
||||
"WORKER_TIMEOUT": ("worker_timeout", int),
|
||||
"API_KEY": ("app.key", str),
|
||||
"NODEJS_PATH": ("nodejs_path", str),
|
||||
"ENABLE_NETWORK": ("enable_network", lambda v: v.lower() in ("true", "1", "yes")),
|
||||
"ENABLE_PRELOAD": ("enable_preload", lambda v: v.lower() in ("true", "1", "yes")),
|
||||
"ALLOWED_SYSCALLS": ("allowed_syscalls", lambda v: [int(x) for x in v.split(",")]),
|
||||
"SOCKS5_PROXY": ("proxy.socks5", str),
|
||||
"HTTP_PROXY": ("proxy.http", str),
|
||||
"HTTPS_PROXY": ("proxy.https", str),
|
||||
"PYTHON_PATH": ("python_path", str),
|
||||
"PYTHON_LIB_PATH": ("python_lib_paths", lambda v: v.split(",")),
|
||||
"PYTHON_DEPS_UPDATE_INTERVAL": ("python_deps_update_interval", str),
|
||||
"NODEJS_LIB_PATH": ("nodejs_lib_paths", lambda v: v.split(",")),
|
||||
}
|
||||
|
||||
for env_var, (attr_path, cast) in env_map.items():
|
||||
value = os.getenv(env_var)
|
||||
if value is not None:
|
||||
# Support nested attributes like 'app.debug'
|
||||
parts = attr_path.split(".")
|
||||
obj = self
|
||||
for part in parts[:-1]:
|
||||
obj = getattr(obj, part)
|
||||
setattr(obj, parts[-1], cast(value))
|
||||
|
||||
|
||||
# Global configuration instance
|
||||
_config: Optional[Config] = None
|
||||
|
||||
|
||||
def load_config(config_path: str) -> Config:
|
||||
"""Load configuration from YAML file"""
|
||||
def load_config(config_path: str = "config.yaml") -> Config:
|
||||
"""Load configuration from YAML file and override with env variables"""
|
||||
global _config
|
||||
|
||||
# Load from file
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, 'r') as f:
|
||||
data = yaml.safe_load(f)
|
||||
data = yaml.safe_load(f) or {}
|
||||
_config = Config(**data)
|
||||
else:
|
||||
_config = Config()
|
||||
|
||||
# Override with environment variables
|
||||
if os.getenv("DEBUG"):
|
||||
_config.app.debug = os.getenv("DEBUG").lower() in ("true", "1", "yes")
|
||||
|
||||
if os.getenv("MAX_WORKERS"):
|
||||
_config.max_workers = int(os.getenv("MAX_WORKERS"))
|
||||
|
||||
if os.getenv("MAX_REQUESTS"):
|
||||
_config.max_requests = int(os.getenv("MAX_REQUESTS"))
|
||||
|
||||
if os.getenv("SANDBOX_PORT"):
|
||||
_config.app.port = int(os.getenv("SANDBOX_PORT"))
|
||||
|
||||
if os.getenv("WORKER_TIMEOUT"):
|
||||
_config.worker_timeout = int(os.getenv("WORKER_TIMEOUT"))
|
||||
|
||||
if os.getenv("API_KEY"):
|
||||
_config.app.key = os.getenv("API_KEY")
|
||||
|
||||
if os.getenv("NODEJS_PATH"):
|
||||
_config.nodejs_path = os.getenv("NODEJS_PATH")
|
||||
|
||||
if os.getenv("ENABLE_NETWORK"):
|
||||
_config.enable_network = os.getenv("ENABLE_NETWORK").lower() in ("true", "1", "yes")
|
||||
|
||||
if os.getenv("ENABLE_PRELOAD"):
|
||||
_config.enable_preload = os.getenv("ENABLE_PRELOAD").lower() in ("true", "1", "yes")
|
||||
|
||||
if os.getenv("ALLOWED_SYSCALLS"):
|
||||
_config.allowed_syscalls = [int(x) for x in os.getenv("ALLOWED_SYSCALLS").split(",")]
|
||||
|
||||
if os.getenv("SOCKS5_PROXY"):
|
||||
_config.proxy.socks5 = os.getenv("SOCKS5_PROXY")
|
||||
|
||||
if os.getenv("HTTP_PROXY"):
|
||||
_config.proxy.http = os.getenv("HTTP_PROXY")
|
||||
|
||||
if os.getenv("HTTPS_PROXY"):
|
||||
_config.proxy.https = os.getenv("HTTPS_PROXY")
|
||||
|
||||
# python
|
||||
if os.getenv("PYTHON_PATH"):
|
||||
_config.python_path = os.getenv("PYTHON_PATH")
|
||||
|
||||
if os.getenv("PYTHON_LIB_PATH"):
|
||||
_config.python_lib_paths = os.getenv("PYTHON_LIB_PATH").split(',')
|
||||
|
||||
if os.getenv("PYTHON_DEPS_UPDATE_INTERVAL"):
|
||||
_config.python_deps_update_interval = os.getenv("PYTHON_DEPS_UPDATE_INTERVAL")
|
||||
|
||||
# Override from environment
|
||||
_config.override_with_env()
|
||||
return _config
|
||||
|
||||
|
||||
|
||||
@@ -9,4 +9,4 @@ router = APIRouter()
|
||||
@router.get("/health", response_model=HealthResponse)
|
||||
async def health_check():
|
||||
"""Health check endpoint"""
|
||||
return HealthResponse(status="healthy", version="2.0.0")
|
||||
return HealthResponse(status="healthy", version="0.1.0")
|
||||
|
||||
@@ -2,13 +2,15 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app.middleware.auth import verify_api_key
|
||||
from app.middleware.concurrency import check_max_requests, acquire_worker
|
||||
from app.middleware.concurrency import concurrency_guard
|
||||
|
||||
from app.models import (
|
||||
RunCodeRequest,
|
||||
ApiResponse,
|
||||
UpdateDependencyRequest,
|
||||
error_response
|
||||
)
|
||||
from app.services.nodejs_service import run_nodejs_code
|
||||
from app.services.python_service import (
|
||||
run_python_code,
|
||||
list_python_dependencies,
|
||||
@@ -25,16 +27,14 @@ router = APIRouter(
|
||||
@router.post(
|
||||
"/run",
|
||||
response_model=ApiResponse,
|
||||
dependencies=[Depends(check_max_requests),
|
||||
Depends(acquire_worker)]
|
||||
dependencies=[Depends(concurrency_guard)]
|
||||
)
|
||||
async def run_code(request: RunCodeRequest):
|
||||
"""Execute code in sandbox"""
|
||||
if request.language == "python3":
|
||||
return await run_python_code(request.code, request.preload, request.options)
|
||||
elif request.language == "nodejs":
|
||||
# TODO
|
||||
return error_response(-400, "TODO")
|
||||
elif request.language == "javascript":
|
||||
return await run_nodejs_code(request.code, request.preload, request.options)
|
||||
else:
|
||||
return error_response(-400, "unsupported language")
|
||||
|
||||
@@ -55,5 +55,3 @@ async def update_dependencies(request: UpdateDependencyRequest):
|
||||
return await update_python_dependencies()
|
||||
else:
|
||||
return error_response(-400, "unsupported language")
|
||||
|
||||
|
||||
|
||||
@@ -1 +1,40 @@
|
||||
"""Code runners package"""
|
||||
import pwd
|
||||
import subprocess
|
||||
|
||||
from app.config import get_config
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def init_sandbox_user():
|
||||
config = get_config()
|
||||
sandbox_user = config.sandbox_user
|
||||
sandbox_uid = config.sandbox_uid
|
||||
try:
|
||||
pwd.getpwnam(sandbox_user)
|
||||
logger.info(f"User '{sandbox_user}' already exists")
|
||||
except KeyError:
|
||||
try:
|
||||
subprocess.run(
|
||||
["useradd", "-u", str(sandbox_uid), sandbox_user],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
logger.info(f"Created user '{sandbox_user}' with UID {sandbox_uid}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(f"Failed to create user: {e.stderr}")
|
||||
raise RuntimeError(f"Failed to create user '{sandbox_user}': {e.stderr}") from e
|
||||
|
||||
try:
|
||||
user_info = pwd.getpwnam(sandbox_user)
|
||||
config.set_sandbox_gid(user_info.pw_gid)
|
||||
logger.info(f"Sandbox user GID: {config.sandbox_gid}")
|
||||
except KeyError as e:
|
||||
logger.error(f"Failed to get GID for user '{sandbox_user}'")
|
||||
raise RuntimeError(f"Failed to get GID for user '{sandbox_user}'") from e
|
||||
|
||||
|
||||
|
||||
|
||||
3
sandbox/app/core/runners/nodejs/__init__.py
Normal file
3
sandbox/app/core/runners/nodejs/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from app.core.runners.nodejs.env import release_lib_binary
|
||||
|
||||
release_lib_binary(True)
|
||||
124
sandbox/app/core/runners/nodejs/env.py
Normal file
124
sandbox/app/core/runners/nodejs/env.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import asyncio
|
||||
import ctypes
|
||||
import os
|
||||
import shutil
|
||||
import stat
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.config import get_config
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
RELEASE_LIB_PATH = "./lib/seccomp_redbear/target/release/libnodejs.so"
|
||||
LIB_PATH = "/var/sandbox/sandbox-nodejs"
|
||||
LIB_NAME = "libnodejs.so"
|
||||
|
||||
lib = ctypes.CDLL(RELEASE_LIB_PATH)
|
||||
lib.get_lib_version_static.restype = ctypes.c_char_p
|
||||
lib.get_lib_feature_static.restype = ctypes.c_char_p
|
||||
logger.info(f"Seccomp Env: nodejs, "
|
||||
f"Seccomp Feature: {lib.get_lib_feature_static().decode('utf-8')}, "
|
||||
f"Seccomp Version: {lib.get_lib_version_static().decode('utf-8')}")
|
||||
|
||||
try:
|
||||
with open(RELEASE_LIB_PATH, "rb") as f:
|
||||
_NODEJS_LIB = f.read()
|
||||
except:
|
||||
logger.critical("failed to load nodejs lib")
|
||||
raise
|
||||
|
||||
|
||||
def check_lib_avaiable():
|
||||
return os.path.exists(os.path.join(LIB_PATH, LIB_NAME))
|
||||
|
||||
|
||||
def release_lib_binary(force_remove: bool):
|
||||
logger.info("init runtime enviroment")
|
||||
|
||||
lib_file = os.path.join(LIB_PATH, LIB_NAME)
|
||||
if os.path.exists(lib_file):
|
||||
if force_remove:
|
||||
try:
|
||||
os.remove(lib_file)
|
||||
except OSError:
|
||||
logger.critical(f"failed to remove {os.path.join(LIB_PATH, LIB_NAME)}")
|
||||
raise
|
||||
|
||||
try:
|
||||
os.makedirs(LIB_PATH, mode=0o755, exist_ok=True)
|
||||
except OSError:
|
||||
logger.critical(f"failed to create {LIB_PATH}")
|
||||
raise
|
||||
|
||||
try:
|
||||
with open(lib_file, "wb") as f:
|
||||
f.write(_NODEJS_LIB)
|
||||
os.chmod(lib_file, 0o755)
|
||||
except OSError:
|
||||
logger.critical(f"failed to write {lib_file}")
|
||||
raise
|
||||
else:
|
||||
try:
|
||||
os.makedirs(LIB_PATH, mode=0o755, exist_ok=True)
|
||||
except OSError:
|
||||
logger.critical(f"failed to create {LIB_PATH}")
|
||||
raise
|
||||
|
||||
try:
|
||||
with open(lib_file, "wb") as f:
|
||||
f.write(_NODEJS_LIB)
|
||||
os.chmod(lib_file, 0o755)
|
||||
except OSError:
|
||||
logger.critical(f"failed to write {lib_file}")
|
||||
raise
|
||||
|
||||
logger.info("nodejs runner environment initialized")
|
||||
|
||||
|
||||
async def prepare_nodejs_dependencies_env():
|
||||
config = get_config()
|
||||
|
||||
with tempfile.TemporaryDirectory(dir="/") as root_path:
|
||||
root = Path(root_path)
|
||||
|
||||
env_sh = root / "env.sh"
|
||||
with open("script/env.sh") as f:
|
||||
env_sh.write_text(f.read())
|
||||
env_sh.chmod(env_sh.stat().st_mode | stat.S_IXUSR)
|
||||
|
||||
shutil.copytree("dependencies/nodejs", os.path.join(LIB_PATH, "node_temp"), dirs_exist_ok=True)
|
||||
for root, dirs, files in os.walk(os.path.join(LIB_PATH, "node_temp")):
|
||||
for d in dirs:
|
||||
os.chmod(os.path.join(root, d), 0o755)
|
||||
for f in files:
|
||||
os.chmod(os.path.join(root, f), 0o444)
|
||||
|
||||
for lib_path in config.nodejs_lib_paths:
|
||||
lib_path = Path(lib_path)
|
||||
|
||||
if not lib_path.exists():
|
||||
logger.warning("nodejs lib path %s is not available", lib_path)
|
||||
continue
|
||||
|
||||
cmd = [
|
||||
"bash",
|
||||
str(env_sh),
|
||||
str(lib_path),
|
||||
str(LIB_PATH),
|
||||
]
|
||||
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
||||
stdout, stderr = await process.communicate()
|
||||
retcode = process.returncode
|
||||
|
||||
if retcode != 0:
|
||||
logger.error(
|
||||
f"create env error for file {lib_path}: retcode={retcode}, stderr={stderr.decode()}"
|
||||
)
|
||||
138
sandbox/app/core/runners/nodejs/nodejs_runner.py
Normal file
138
sandbox/app/core/runners/nodejs/nodejs_runner.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Nodejs code runner"""
|
||||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from app.core.executor import CodeExecutor, ExecutionResult
|
||||
from app.core.runners.nodejs.env import check_lib_avaiable, release_lib_binary, LIB_PATH
|
||||
from app.logger import get_logger
|
||||
from app.models import RunnerOptions
|
||||
|
||||
# Nodejs sandbox prescript template
|
||||
with open("app/core/runners/nodejs/prescript.js") as f:
|
||||
NODEJS_PRESCRIPT = f.read()
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class NodejsRunner(CodeExecutor):
|
||||
"""Node.js code runner with security isolation"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def init_environment(code: str, preload: str) -> str:
|
||||
if not check_lib_avaiable():
|
||||
release_lib_binary(False)
|
||||
code_file_name = uuid.uuid4().hex.replace("-", "_")
|
||||
|
||||
script = NODEJS_PRESCRIPT.replace("{{preload}}", preload, 1)
|
||||
|
||||
eval_code = f"eval(Buffer.from('{code}', 'base64').toString('utf-8'))"
|
||||
script = script.replace("{{code}}", eval_code, 1)
|
||||
|
||||
code_path = f"{LIB_PATH}/node_temp/tmp/{code_file_name}.js"
|
||||
try:
|
||||
os.makedirs(os.path.dirname(code_path), mode=0o755, exist_ok=True)
|
||||
with open(code_path, "w", encoding="utf-8") as f:
|
||||
f.write(script)
|
||||
os.chmod(code_path, 0o755)
|
||||
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"Failed to write {code_path}") from e
|
||||
|
||||
return code_path
|
||||
|
||||
async def run(
|
||||
self,
|
||||
code: str,
|
||||
options: RunnerOptions,
|
||||
preload: str = "",
|
||||
timeout: Optional[int] = None
|
||||
) -> ExecutionResult:
|
||||
"""Run Python code in sandbox
|
||||
|
||||
Args:
|
||||
options:
|
||||
code: Base64 encoded encrypted code
|
||||
preload: Preload code to execute before main code
|
||||
timeout: Execution timeout in seconds
|
||||
|
||||
Returns:
|
||||
ExecutionResult with stdout, stderr, and exit code
|
||||
"""
|
||||
config = self.config
|
||||
|
||||
if timeout is None:
|
||||
timeout = config.worker_timeout
|
||||
|
||||
# Check if preload is allowed
|
||||
if not preload or not config.enable_preload:
|
||||
preload = ""
|
||||
script_path = self.init_environment(code, preload)
|
||||
|
||||
try:
|
||||
# Setup environment
|
||||
env = {
|
||||
"UV_USE_IO_URING": "0"
|
||||
}
|
||||
|
||||
# Add proxy settings if configured
|
||||
if config.proxy.socks5:
|
||||
env["HTTPS_PROXY"] = config.proxy.socks5
|
||||
env["HTTP_PROXY"] = config.proxy.socks5
|
||||
elif config.proxy.https or config.proxy.http:
|
||||
if config.proxy.https:
|
||||
env["HTTPS_PROXY"] = config.proxy.https
|
||||
if config.proxy.http:
|
||||
env["HTTP_PROXY"] = config.proxy.http
|
||||
|
||||
# Add allowed syscalls if configured
|
||||
if config.allowed_syscalls:
|
||||
env["ALLOWED_SYSCALLS"] = ",".join(map(str, config.allowed_syscalls))
|
||||
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
config.nodejs_path,
|
||||
script_path,
|
||||
LIB_PATH,
|
||||
str(config.sandbox_uid),
|
||||
str(config.sandbox_gid),
|
||||
options.model_dump_json(),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=env,
|
||||
cwd=LIB_PATH
|
||||
)
|
||||
|
||||
# Wait for completion with timeout
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
return ExecutionResult(
|
||||
stdout=stdout.decode('utf-8', errors='replace'),
|
||||
stderr=stderr.decode('utf-8', errors='replace'),
|
||||
exit_code=process.returncode
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# Kill process on timeout
|
||||
try:
|
||||
process.kill()
|
||||
await process.wait()
|
||||
except:
|
||||
pass
|
||||
|
||||
return ExecutionResult(
|
||||
stdout="",
|
||||
stderr="Execution timeout",
|
||||
exit_code=-1,
|
||||
)
|
||||
|
||||
finally:
|
||||
# Cleanup temporary file
|
||||
self.cleanup_temp_file(script_path)
|
||||
31
sandbox/app/core/runners/nodejs/prescript.js
Normal file
31
sandbox/app/core/runners/nodejs/prescript.js
Normal file
@@ -0,0 +1,31 @@
|
||||
let argv = process.argv
|
||||
|
||||
let koffi = require('koffi')
|
||||
|
||||
process.chdir(argv[2])
|
||||
|
||||
let lib = koffi.load("./libnodejs.so")
|
||||
/** @type {(uid: number, gid: number, enableNetwork: boolean) => number} */
|
||||
let initSeccomp = lib.func('int init_seccomp(int, int, bool)')
|
||||
|
||||
let uid = parseInt(argv[3])
|
||||
let gid = parseInt(argv[4])
|
||||
|
||||
let options = JSON.parse(argv[5])
|
||||
|
||||
let seccomp_init = initSeccomp(uid, gid, options['enable_network'])
|
||||
if (seccomp_init !== 0) {
|
||||
throw `code executor err - ${seccomp_init}`
|
||||
}
|
||||
|
||||
delete process.argv
|
||||
argv = undefined
|
||||
koffi = undefined
|
||||
lib = undefined
|
||||
initSeccomp = undefined
|
||||
uid = undefined
|
||||
gid = undefined
|
||||
options = undefined
|
||||
seccomp_init = undefined
|
||||
|
||||
{{code}}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user