Merge branch 'develop' into fix/memory-enduser-config

This commit is contained in:
Ke Sun
2026-02-03 19:38:21 +08:00
151 changed files with 11318 additions and 1208 deletions

View File

@@ -3,9 +3,10 @@ import platform
from datetime import timedelta
from urllib.parse import quote
from app.core.config import settings
from celery import Celery
from app.core.config import settings
# 创建 Celery 应用实例
# broker: 任务队列(使用 Redis DB 0
# backend: 结果存储(使用 Redis DB 10
@@ -67,11 +68,11 @@ celery_app.conf.update(
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
# Beat/periodic tasks → document_tasks queue (prefork worker)
'app.tasks.workspace_reflection_task': {'queue': 'document_tasks'},
'app.tasks.regenerate_memory_cache': {'queue': 'document_tasks'},
'app.tasks.run_forgetting_cycle_task': {'queue': 'document_tasks'},
'app.controllers.memory_storage_controller.search_all': {'queue': 'document_tasks'},
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
'app.controllers.memory_storage_controller.search_all': {'queue': 'periodic_tasks'},
},
)
@@ -79,40 +80,40 @@ celery_app.conf.update(
celery_app.autodiscover_tasks(['app'])
# Celery Beat schedule for periodic tasks
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
# memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
# memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
# workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
# forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
# 构建定时任务配置
beat_schedule_config = {
"run-workspace-reflection": {
"task": "app.tasks.workspace_reflection_task",
"schedule": workspace_reflection_schedule,
"args": (),
},
"regenerate-memory-cache": {
"task": "app.tasks.regenerate_memory_cache",
"schedule": memory_cache_regeneration_schedule,
"args": (),
},
"run-forgetting-cycle": {
"task": "app.tasks.run_forgetting_cycle_task",
"schedule": forgetting_cycle_schedule,
"kwargs": {
"config_id": None, # 使用默认配置,可以通过环境变量配置
},
},
}
# beat_schedule_config = {
# "run-workspace-reflection": {
# "task": "app.tasks.workspace_reflection_task",
# "schedule": workspace_reflection_schedule,
# "args": (),
# },
# "regenerate-memory-cache": {
# "task": "app.tasks.regenerate_memory_cache",
# "schedule": memory_cache_regeneration_schedule,
# "args": (),
# },
# "run-forgetting-cycle": {
# "task": "app.tasks.run_forgetting_cycle_task",
# "schedule": forgetting_cycle_schedule,
# "kwargs": {
# "config_id": None, # 使用默认配置,可以通过环境变量配置
# },
# },
# }
# 如果配置了默认工作空间ID则添加记忆总量统计任务
if settings.DEFAULT_WORKSPACE_ID:
beat_schedule_config["write-total-memory"] = {
"task": "app.controllers.memory_storage_controller.search_all",
"schedule": memory_increment_schedule,
"kwargs": {
"workspace_id": settings.DEFAULT_WORKSPACE_ID,
},
}
# if settings.DEFAULT_WORKSPACE_ID:
# beat_schedule_config["write-total-memory"] = {
# "task": "app.controllers.memory_storage_controller.search_all",
# "schedule": memory_increment_schedule,
# "kwargs": {
# "workspace_id": settings.DEFAULT_WORKSPACE_ID,
# },
# }
celery_app.conf.beat_schedule = beat_schedule_config
# celery_app.conf.beat_schedule = beat_schedule_config

View File

@@ -43,6 +43,7 @@ from . import (
user_memory_controllers,
workflow_controller,
workspace_controller,
ontology_controller,
)
# 创建管理端 API 路由器
@@ -88,5 +89,6 @@ manager_router.include_router(implicit_memory_controller.router)
manager_router.include_router(memory_perceptual_controller.router)
manager_router.include_router(memory_working_controller.router)
manager_router.include_router(file_storage_controller.router)
manager_router.include_router(ontology_controller.router)
__all__ = ["manager_router"]

View File

@@ -51,7 +51,6 @@ async def save_reflection_config(
status_code=status.HTTP_400_BAD_REQUEST,
detail="缺少必需参数: config_id"
)
api_logger.info(f"用户 {current_user.username} 保存反思配置config_id: {config_id}")
memory_config = MemoryConfigRepository.update_reflection_config(
@@ -102,7 +101,7 @@ async def start_workspace_reflection(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""Activate the reflection function for all matching applications in the workspace"""
"""启动工作空间中所有匹配应用的反思功能"""
workspace_id = current_user.current_workspace_id
reflection_service = MemoryReflectionService(db)
@@ -111,42 +110,55 @@ async def start_workspace_reflection(
service = WorkspaceAppService(db)
result = service.get_workspace_apps_detailed(workspace_id)
reflection_results = []
for data in result['apps_detailed_info']:
if data['memory_configs'] == []:
# 跳过没有配置的应用
if not data['memory_configs']:
api_logger.debug(f"应用 {data['id']} 没有memory_configs跳过")
continue
releases = data['releases']
memory_configs = data['memory_configs']
end_users = data['end_users']
for base, config, user in zip(releases, memory_configs, end_users):
# 安全地转换为整数处理空字符串和None的情况
print(base['config'])
try:
base_config = int(base['config']) if base['config'] else 0
config_id = int(config['config_id']) if config['config_id'] else 0
except (ValueError, TypeError):
api_logger.warning(f"无效的配置ID: base['config']={base.get('config')}, config['config_id']={config.get('config_id')}")
# 为每个配置和用户组合执行反思
for config in memory_configs:
config_id_str = str(config['config_id'])
# 找到匹配此配置的所有release
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
if not matching_releases:
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
continue
if base_config == config_id and base['app_id'] == user['app_id']:
# 调用反思服务
api_logger.info(f"为用户 {user['id']} 启动反思config_id: {config['config_id']}")
reflection_result = await reflection_service.start_text_reflection(
config_data=config,
end_user_id=user['id']
)
reflection_results.append({
"app_id": base['app_id'],
"config_id": config['config_id'],
"end_user_id": user['id'],
"reflection_result": reflection_result
})
# 为每个用户执行反思
for user in end_users:
api_logger.info(f"为用户 {user['id']} 启动反思config_id: {config_id_str}")
try:
reflection_result = await reflection_service.start_text_reflection(
config_data=config,
end_user_id=user['id']
)
reflection_results.append({
"app_id": data['id'],
"config_id": config_id_str,
"end_user_id": user['id'],
"reflection_result": reflection_result
})
except Exception as e:
api_logger.error(f"用户 {user['id']} 反思失败: {str(e)}")
reflection_results.append({
"app_id": data['id'],
"config_id": config_id_str,
"end_user_id": user['id'],
"reflection_result": {
"status": "错误",
"message": f"反思失败: {str(e)}"
}
})
return success(data=reflection_results, msg="反思配置成功")

View File

@@ -0,0 +1,964 @@
"""本体提取API控制器
本模块提供本体提取系统的RESTful API端点。
Endpoints:
POST /api/memory/ontology/extract - 提取本体类
POST /api/memory/ontology/export - 导出OWL文件
POST /api/memory/ontology/scene - 创建本体场景
PUT /api/memory/ontology/scene/{scene_id} - 更新本体场景
DELETE /api/memory/ontology/scene/{scene_id} - 删除本体场景
GET /api/memory/ontology/scene/{scene_id} - 获取单个场景
GET /api/memory/ontology/scenes - 获取场景列表
POST /api/memory/ontology/class - 创建本体类型
PUT /api/memory/ontology/class/{class_id} - 更新本体类型
DELETE /api/memory/ontology/class/{class_id} - 删除本体类型
GET /api/memory/ontology/class/{class_id} - 获取单个类型
GET /api/memory/ontology/classes - 获取类型列表
"""
import logging
import tempfile
from typing import Dict, Optional
from fastapi import APIRouter, Depends, HTTPException, Header
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.services.memory_base_service import Translation_English
from app.core.memory.models.ontology_models import OntologyClass
from typing import List
from app.schemas.ontology_schemas import (
ExportRequest,
ExportResponse,
ExtractionRequest,
ExtractionResponse,
SceneCreateRequest,
SceneUpdateRequest,
SceneResponse,
SceneListResponse,
ClassCreateRequest,
ClassUpdateRequest,
ClassResponse,
ClassListResponse,
)
from app.schemas.response_schema import ApiResponse
from app.services.ontology_service import OntologyService
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.utils.validation.owl_validator import OWLValidator
from app.services.model_service import ModelConfigService
api_logger = get_api_logger()
logger = logging.getLogger(__name__)
router = APIRouter(
prefix="/memory/ontology",
tags=["Ontology"],
)
async def translate_ontology_classes(
classes: List[OntologyClass],
model_id: str
) -> List[OntologyClass]:
"""翻译本体类列表
将本体类的中文字段翻译为英文,包括:
- name_chinese: 中文名称
- description: 描述
- examples: 示例列表
Args:
classes: 本体类列表
model_id: LLM模型ID用于翻译
Returns:
List[OntologyClass]: 翻译后的本体类列表
"""
translated_classes = []
for ontology_class in classes:
# 创建类的副本,避免修改原对象
translated_class = ontology_class.model_copy(deep=True)
# 翻译 name_chinese 字段
if translated_class.name_chinese:
try:
translated_class.name_chinese = await Translation_English(
model_id,
translated_class.name_chinese
)
except Exception as e:
logger.warning(f"Failed to translate name_chinese: {e}")
# 保留原文
# 翻译 description 字段
if translated_class.description:
try:
translated_class.description = await Translation_English(
model_id,
translated_class.description
)
except Exception as e:
logger.warning(f"Failed to translate description: {e}")
# 保留原文
# 翻译 examples 列表
if translated_class.examples:
translated_examples = []
for example in translated_class.examples:
try:
translated_example = await Translation_English(
model_id,
example
)
translated_examples.append(translated_example)
except Exception as e:
logger.warning(f"Failed to translate example: {e}")
translated_examples.append(example) # 保留原文
translated_class.examples = translated_examples
translated_classes.append(translated_class)
return translated_classes
def _get_ontology_service(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
llm_id: str = None
) -> OntologyService:
"""获取OntologyService实例的依赖注入函数
指定的llm_id获取LLM配置,创建OpenAIClient和OntologyService实例。
Args:
db: 数据库会话
current_user: 当前用户
llm_id: 可选的LLM模型ID,如果提供则使用指定模型,否则使用工作空间默认模型
Returns:
OntologyService: 本体提取服务实例
Raises:
HTTPException: 如果无法获取LLM配置
"""
try:
import uuid
# 必须提供llm_id
if not llm_id:
logger.error(f"llm_id is required but not provided - user: {current_user.id}")
raise HTTPException(
status_code=400,
detail="必须提供llm_id参数"
)
logger.info(f"Using specified LLM model: {llm_id}")
# 验证llm_id格式
try:
model_id = uuid.UUID(llm_id)
except ValueError:
logger.error(f"Invalid llm_id format: {llm_id}")
raise HTTPException(
status_code=400,
detail="无效的LLM模型ID格式"
)
# 获取指定的模型配置
try:
model_config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
except Exception as e:
logger.error(f"Model {llm_id} not found: {str(e)}")
raise HTTPException(
status_code=400,
detail=f"找不到指定的LLM模型: {llm_id}"
)
# 检查是否为组合模型
if hasattr(model_config, 'is_composite') and model_config.is_composite:
logger.error(f"Model {llm_id} is a composite model, which is not supported for ontology extraction")
raise HTTPException(
status_code=400,
detail="本体提取不支持使用组合模型,请选择单个模型"
)
# 验证模型配置了API密钥
if not model_config.api_keys:
logger.error(f"Model {llm_id} has no API key configuration")
raise HTTPException(
status_code=400,
detail="指定的LLM模型没有配置API密钥"
)
api_key_config = model_config.api_keys[0]
logger.info(
f"Using specified model - user: {current_user.id}, "
f"model_id: {llm_id}, model_name: {api_key_config.model_name}"
)
# 创建模型配置对象
from app.core.models.base import RedBearModelConfig
llm_model_config = RedBearModelConfig(
model_name=api_key_config.model_name,
provider=model_config.provider if hasattr(model_config, 'provider') else "openai",
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
max_retries=3,
timeout=60.0
)
# 创建OpenAI客户端
llm_client = OpenAIClient(model_config=llm_model_config)
# 创建OntologyService
service = OntologyService(llm_client=llm_client, db=db)
logger.debug(
f"OntologyService created successfully - "
f"user: {current_user.id}, model: {api_key_config.model_name}"
)
return service
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to create OntologyService: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"创建本体提取服务失败: {str(e)}"
)
@router.post("/extract", response_model=ApiResponse)
async def extract_ontology(
request: ExtractionRequest,
language_type: str = Header(default="zh", alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""提取本体类
从场景描述中提取符合OWL规范的本体类。
提取结果仅返回给前端,不会自动保存到数据库。
前端可以从返回结果中选择需要的类型,然后调用 /class 接口创建类型。
支持中英文切换,通过 X-Language-Type Header 指定语言。
Args:
request: 提取请求,包含scenario、domain、llm_id和scene_id
language_type: 语言类型,'zh'(中文)或 'en'(英文),默认 'zh'
db: 数据库会话
current_user: 当前用户
Returns:
ApiResponse: 包含提取结果的响应
Response format:
{
"code": 200,
"msg": "本体提取成功",
"data": {
"classes": [
{
"id": "147d9db50b524a9e909e01a753d3acdd",
"name": "Patient",
"name_chinese": "患者",
"description": "在医疗机构中接受诊疗、护理或健康管理的个体",
"examples": ["糖尿病患者", "术后康复患者", "门诊初诊患者"],
"parent_class": null,
"entity_type": "Person",
"domain": "Healthcare"
},
...
],
"domain": "Healthcare",
"extracted_count": 7
}
}
"""
api_logger.info(
f"Ontology extraction requested by user {current_user.id}, "
f"scenario_length={len(request.scenario)}, "
f"domain={request.domain}, "
f"llm_id={request.llm_id}, "
f"scene_id={request.scene_id}, "
f"language_type={language_type}"
)
try:
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建OntologyService实例,传入llm_id
service = _get_ontology_service(
db=db,
current_user=current_user,
llm_id=request.llm_id
)
# 调用服务层执行提取传入scene_id和workspace_id
result = await service.extract_ontology(
scenario=request.scenario,
domain=request.domain,
scene_id=request.scene_id,
workspace_id=workspace_id
)
# ===== 新增:翻译逻辑 =====
# 如果需要英文,则翻译数据
if language_type != 'zh':
api_logger.info(f"Translating extraction result to English")
# 翻译 classes 列表
result.classes = await translate_ontology_classes(
result.classes,
request.llm_id
)
# 翻译 domain 字段
if result.domain:
try:
result.domain = await Translation_English(
request.llm_id,
result.domain
)
except Exception as e:
logger.warning(f"Failed to translate domain: {e}")
# 保留原文
# ===== 翻译逻辑结束 =====
# 构建响应
response = ExtractionResponse(
classes=result.classes,
domain=result.domain,
extracted_count=len(result.classes)
)
api_logger.info(
f"Ontology extraction completed, extracted {len(result.classes)} classes, "
f"saved to scene {request.scene_id}, language={language_type}"
)
return success(data=response.model_dump(), msg="本体提取成功")
except ValueError as e:
# 验证错误 (400)
api_logger.warning(f"Validation error in extraction: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
# 运行时错误 (500)
api_logger.error(f"Runtime error in extraction: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "本体提取失败", str(e))
except Exception as e:
# 未知错误 (500)
api_logger.error(f"Unexpected error in extraction: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "本体提取失败", str(e))
@router.post("/export", response_model=ApiResponse)
async def export_owl(
request: ExportRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""导出OWL文件
将提取的本体类导出为OWL文件,支持多种格式。
导出操作不需要LLM,只使用OWL验证器和Owlready2库。
Args:
request: 导出请求,包含classes、format和include_metadata
db: 数据库会话
current_user: 当前用户
Returns:
ApiResponse: 包含OWL文件内容的响应
Supported formats:
- rdfxml: 标准OWL RDF/XML格式(完整)
- turtle: Turtle格式(可读性好)
- ntriples: N-Triples格式(简单)
- json: JSON格式(简化,只包含类信息)
Response format:
{
"code": 200,
"msg": "OWL文件导出成功",
"data": {
"owl_content": "...",
"format": "rdfxml",
"classes_count": 7
}
}
"""
api_logger.info(
f"OWL export requested by user {current_user.id}, "
f"classes_count={len(request.classes)}, "
f"format={request.format}, "
f"include_metadata={request.include_metadata}"
)
try:
# 验证格式
valid_formats = ["rdfxml", "turtle", "ntriples", "json"]
if request.format not in valid_formats:
api_logger.warning(f"Invalid export format: {request.format}")
return fail(
BizCode.BAD_REQUEST,
"不支持的导出格式",
f"format必须是以下之一: {', '.join(valid_formats)}"
)
# JSON格式直接导出,不需要OWL验证
if request.format == "json":
owl_validator = OWLValidator()
owl_content = owl_validator.export_to_owl(
world=None,
format="json",
classes=request.classes
)
response = ExportResponse(
owl_content=owl_content,
format=request.format,
classes_count=len(request.classes)
)
api_logger.info(
f"JSON export completed, content_length={len(owl_content)}"
)
return success(data=response.model_dump(), msg="OWL文件导出成功")
# 创建临时文件路径
with tempfile.NamedTemporaryFile(
mode='w',
suffix='.owl',
delete=False
) as tmp_file:
output_path = tmp_file.name
# 导出操作不需要LLM,直接使用OWL验证器
owl_validator = OWLValidator()
# 验证本体类
logger.debug("Validating ontology classes")
is_valid, errors, world = owl_validator.validate_ontology_classes(
classes=request.classes,
)
if not is_valid:
logger.warning(
f"OWL validation found {len(errors)} issues during export: {errors}"
)
# 继续导出,但记录警告
if not world:
error_msg = "Failed to create OWL world for export"
logger.error(error_msg)
return fail(BizCode.INTERNAL_ERROR, "创建OWL世界失败", error_msg)
# 导出OWL文件
logger.info(f"Exporting to {request.format} format")
owl_content = owl_validator.export_to_owl(
world=world,
output_path=output_path,
format=request.format,
classes=request.classes
)
# 构建响应
response = ExportResponse(
owl_content=owl_content,
format=request.format,
classes_count=len(request.classes)
)
api_logger.info(
f"OWL export completed, format={request.format}, "
f"content_length={len(owl_content)}"
)
return success(data=response.model_dump(), msg="OWL文件导出成功")
except ValueError as e:
# 验证错误 (400)
api_logger.warning(f"Validation error in export: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
# 运行时错误 (500)
api_logger.error(f"Runtime error in export: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "OWL文件导出失败", str(e))
except Exception as e:
# 未知错误 (500)
api_logger.error(f"Unexpected error in export: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "OWL文件导出失败", str(e))
# ==================== 本体场景管理接口 ====================
@router.post("/scene", response_model=ApiResponse)
async def create_scene(
request: SceneCreateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""创建本体场景
在当前工作空间下创建新的本体场景。
Args:
request: 场景创建请求
db: 数据库会话
current_user: 当前用户
Returns:
ApiResponse: 包含创建的场景信息
"""
api_logger.info(
f"Scene creation requested by user {current_user.id}, "
f"name={request.scene_name}"
)
try:
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建OntologyService实例不需要LLM
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.models.base import RedBearModelConfig
# 创建一个空的LLM配置场景管理不需要LLM
dummy_config = RedBearModelConfig(
model_name="dummy",
provider="openai",
api_key="dummy",
base_url="https://api.openai.com/v1"
)
llm_client = OpenAIClient(model_config=dummy_config)
service = OntologyService(llm_client=llm_client, db=db)
# 调用服务层创建场景
scene = service.create_scene(
scene_name=request.scene_name,
scene_description=request.scene_description,
workspace_id=workspace_id
)
# 构建响应
# 动态计算 type_num
type_num = len(scene.classes) if scene.classes else 0
response = SceneResponse(
scene_id=scene.scene_id,
scene_name=scene.scene_name,
scene_description=scene.scene_description,
type_num=type_num,
workspace_id=scene.workspace_id,
created_at=scene.created_at,
updated_at=scene.updated_at,
classes_count=type_num
)
api_logger.info(f"Scene created successfully: {scene.scene_id}")
return success(data=response.model_dump(), msg="场景创建成功")
except ValueError as e:
api_logger.warning(f"Validation error in scene creation: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in scene creation: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in scene creation: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", str(e))
@router.put("/scene/{scene_id}", response_model=ApiResponse)
async def update_scene(
scene_id: str,
request: SceneUpdateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""更新本体场景
更新指定场景的信息,只能更新当前工作空间下的场景。
Args:
scene_id: 场景ID
request: 场景更新请求
db: 数据库会话
current_user: 当前用户
Returns:
ApiResponse: 包含更新后的场景信息
"""
api_logger.info(
f"Scene update requested by user {current_user.id}, "
f"scene_id={scene_id}"
)
try:
from uuid import UUID
# 验证UUID格式
try:
scene_uuid = UUID(scene_id)
except ValueError:
api_logger.warning(f"Invalid scene_id format: {scene_id}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的场景ID格式")
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建OntologyService实例
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.models.base import RedBearModelConfig
dummy_config = RedBearModelConfig(
model_name="dummy",
provider="openai",
api_key="dummy",
base_url="https://api.openai.com/v1"
)
llm_client = OpenAIClient(model_config=dummy_config)
service = OntologyService(llm_client=llm_client, db=db)
# 调用服务层更新场景
scene = service.update_scene(
scene_id=scene_uuid,
scene_name=request.scene_name,
scene_description=request.scene_description,
workspace_id=workspace_id
)
# 构建响应
# 动态计算 type_num
type_num = len(scene.classes) if scene.classes else 0
response = SceneResponse(
scene_id=scene.scene_id,
scene_name=scene.scene_name,
scene_description=scene.scene_description,
type_num=type_num,
workspace_id=scene.workspace_id,
created_at=scene.created_at,
updated_at=scene.updated_at,
classes_count=type_num
)
api_logger.info(f"Scene updated successfully: {scene_id}")
return success(data=response.model_dump(), msg="场景更新成功")
except ValueError as e:
api_logger.warning(f"Validation error in scene update: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in scene update: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "场景更新失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in scene update: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "场景更新失败", str(e))
@router.delete("/scene/{scene_id}", response_model=ApiResponse)
async def delete_scene(
scene_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""删除本体场景
删除指定场景及其所有关联类型,只能删除当前工作空间下的场景。
Args:
scene_id: 场景ID
db: 数据库会话
current_user: 当前用户
Returns:
ApiResponse: 删除结果
"""
api_logger.info(
f"Scene deletion requested by user {current_user.id}, "
f"scene_id={scene_id}"
)
try:
from uuid import UUID
# 验证UUID格式
try:
scene_uuid = UUID(scene_id)
except ValueError:
api_logger.warning(f"Invalid scene_id format: {scene_id}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的场景ID格式")
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建OntologyService实例
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.models.base import RedBearModelConfig
dummy_config = RedBearModelConfig(
model_name="dummy",
provider="openai",
api_key="dummy",
base_url="https://api.openai.com/v1"
)
llm_client = OpenAIClient(model_config=dummy_config)
service = OntologyService(llm_client=llm_client, db=db)
# 调用服务层删除场景
success_flag = service.delete_scene(
scene_id=scene_uuid,
workspace_id=workspace_id
)
api_logger.info(f"Scene deleted successfully: {scene_id}")
return success(data={"deleted": success_flag}, msg="场景删除成功")
except ValueError as e:
api_logger.warning(f"Validation error in scene deletion: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in scene deletion: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "场景删除失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in scene deletion: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "场景删除失败", str(e))
@router.get("/scenes", response_model=ApiResponse)
async def get_scenes(
workspace_id: Optional[str] = None,
scene_name: Optional[str] = None,
page: Optional[int] = None,
pagesize: Optional[int] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取场景列表(支持模糊搜索和全量查询,全量查询支持分页)
根据是否提供 scene_name 参数,执行不同的查询:
- 提供 scene_name进行模糊搜索返回匹配的场景列表支持分页
- 不提供 scene_name返回工作空间下的所有场景支持分页
支持中文和英文的模糊匹配,不区分大小写。
Args:
workspace_id: 工作空间ID可选默认当前用户工作空间
scene_name: 场景名称关键词(可选,支持模糊匹配)
page: 页码可选从1开始
pagesize: 每页数量(可选)
db: 数据库会话
current_user: 当前用户
Returns:
ApiResponse: 包含场景列表和分页信息
Examples:
- 模糊搜索不分页GET /scenes?workspace_id=xxx&scene_name=医疗
输入 "医疗" 可以匹配到 "医疗场景""智慧医疗""医疗管理系统"
- 模糊搜索分页GET /scenes?workspace_id=xxx&scene_name=医疗&page=1&pagesize=10
返回匹配 "医疗" 的第1页每页10条数据
- 全量查询不分页GET /scenes?workspace_id=xxx
返回工作空间下的所有场景
- 全量查询分页GET /scenes?workspace_id=xxx&page=1&pagesize=10
返回第1页每页10条数据
Notes:
- 分页参数 page 和 pagesize 必须同时提供
- page 从1开始pagesize 必须大于0
- 返回格式:{"items": [...], "page": {"page": 1, "pagesize": 10, "total": 100, "hasnext": true}}
- 不分页时page 字段为 null
"""
from app.controllers.ontology_secondary_routes import scenes_handler
return await scenes_handler(workspace_id, scene_name, page, pagesize, db, current_user)
# ==================== 本体类型管理接口 ====================
@router.post("/class", response_model=ApiResponse)
async def create_class(
request: ClassCreateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""创建本体类型
在指定场景下创建新的本体类型。
Args:
request: 类型创建请求
db: 数据库会话
current_user: 当前用户
Returns:
ApiResponse: 包含创建的类型信息
"""
from app.controllers.ontology_secondary_routes import create_class_handler
return await create_class_handler(request, db, current_user)
@router.put("/class/{class_id}", response_model=ApiResponse)
async def update_class(
class_id: str,
request: ClassUpdateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""更新本体类型
更新指定类型的信息,只能更新当前工作空间下场景的类型。
Args:
class_id: 类型ID
request: 类型更新请求
db: 数据库会话
current_user: 当前用户
Returns:
ApiResponse: 包含更新后的类型信息
"""
from app.controllers.ontology_secondary_routes import update_class_handler
return await update_class_handler(class_id, request, db, current_user)
@router.delete("/class/{class_id}", response_model=ApiResponse)
async def delete_class(
class_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""删除本体类型
删除指定类型,只能删除当前工作空间下场景的类型。
Args:
class_id: 类型ID
db: 数据库会话
current_user: 当前用户
Returns:
ApiResponse: 删除结果
"""
from app.controllers.ontology_secondary_routes import delete_class_handler
return await delete_class_handler(class_id, db, current_user)
@router.get("/classes", response_model=ApiResponse)
async def get_classes(
scene_id: str,
class_name: Optional[str] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取类型列表(支持模糊搜索和全量查询)
根据是否提供 class_name 参数,执行不同的查询:
- 提供 class_name进行模糊搜索返回匹配的类型列表
- 不提供 class_name返回场景下的所有类型
支持中文和英文的模糊匹配,不区分大小写。
返回结果包含场景的基本信息scene_name 和 scene_description
Args:
scene_id: 场景ID必填
class_name: 类型名称关键词(可选,支持模糊匹配)
db: 数据库会话
current_user: 当前用户
Returns:
ApiResponse: 包含类型列表和场景信息
Examples:
- 模糊搜索GET /classes?scene_id=xxx&class_name=患者
输入 "患者" 可以匹配到 "患者""患者信息""门诊患者"
- 全量查询GET /classes?scene_id=xxx
返回场景下的所有类型
Response Format:
{
"total": 3,
"scene_id": "xxx",
"scene_name": "医疗场景",
"scene_description": "用于医疗领域的本体建模",
"items": [...]
}
"""
from app.controllers.ontology_secondary_routes import classes_handler
return await classes_handler(scene_id, class_name, db, current_user)
@router.get("/class/{class_id}", response_model=ApiResponse)
async def get_class(
class_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取单个本体类型
根据类型ID获取类型的详细信息只能查询当前工作空间下场景的类型。
Args:
class_id: 类型ID
db: 数据库会话
current_user: 当前用户
Returns:
ApiResponse: 包含类型详细信息
Response Format:
{
"code": 0,
"msg": "查询成功",
"data": {
"class_id": "xxx",
"class_name": "患者",
"class_description": "在医疗机构中接受诊疗的个体",
"scene_id": "xxx",
"created_at": "2026-01-29T10:00:00",
"updated_at": "2026-01-29T10:00:00"
}
}
"""
from app.controllers.ontology_secondary_routes import get_class_handler
return await get_class_handler(class_id, db, current_user)

View File

@@ -0,0 +1,611 @@
# -*- coding: utf-8 -*-
"""本体场景和类型路由(续)
由于主Controller文件较大将剩余路由放在此文件中。
"""
from uuid import UUID
from typing import Optional
from fastapi import Depends
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.ontology_schemas import (
SceneResponse,
SceneListResponse,
PaginationInfo,
ClassCreateRequest,
ClassUpdateRequest,
ClassResponse,
ClassListResponse,
ClassBatchCreateResponse,
)
from app.schemas.response_schema import ApiResponse
from app.services.ontology_service import OntologyService
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.models.base import RedBearModelConfig
api_logger = get_api_logger()
def _get_dummy_ontology_service(db: Session) -> OntologyService:
"""获取OntologyService实例不需要LLM
场景和类型管理不需要LLM创建一个dummy配置。
"""
dummy_config = RedBearModelConfig(
model_name="dummy",
provider="openai",
api_key="dummy",
base_url="https://api.openai.com/v1"
)
llm_client = OpenAIClient(model_config=dummy_config)
return OntologyService(llm_client=llm_client, db=db)
# 这些函数将被导入到主Controller中
async def scenes_handler(
workspace_id: Optional[str] = None,
scene_name: Optional[str] = None,
page: Optional[int] = None,
page_size: Optional[int] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取场景列表(支持模糊搜索和全量查询,全量查询支持分页)
当提供 scene_name 参数时,进行模糊搜索(不分页);
当不提供 scene_name 参数时,返回所有场景(支持分页)。
Args:
workspace_id: 工作空间ID可选默认当前用户工作空间
scene_name: 场景名称关键词(可选,支持模糊匹配)
page: 页码可选从1开始仅在全量查询时有效
page_size: 每页数量(可选,仅在全量查询时有效)
db: 数据库会话
current_user: 当前用户
"""
operation = "search" if scene_name else "list"
api_logger.info(
f"Scene {operation} requested by user {current_user.id}, "
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, page_size={page_size}"
)
try:
# 确定工作空间ID
if workspace_id:
try:
ws_uuid = UUID(workspace_id)
except ValueError:
api_logger.warning(f"Invalid workspace_id format: {workspace_id}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的工作空间ID格式")
else:
ws_uuid = current_user.current_workspace_id
if not ws_uuid:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建Service
service = _get_dummy_ontology_service(db)
# 根据是否提供 scene_name 决定查询方式
if scene_name and scene_name.strip():
# 验证分页参数(模糊搜索也支持分页)
if page is not None and page < 1:
api_logger.warning(f"Invalid page number: {page}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
if page_size is not None and page_size < 1:
api_logger.warning(f"Invalid page_size: {page_size}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
# 如果只提供了page或page_size中的一个返回错误
if (page is not None and page_size is None) or (page is None and page_size is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
# 模糊搜索场景(支持分页)
scenes = service.search_scenes_by_name(scene_name.strip(), ws_uuid)
total = len(scenes)
# 如果提供了分页参数,进行分页处理
if page is not None and page_size is not None:
start_idx = (page - 1) * page_size
end_idx = start_idx + page_size
scenes = scenes[start_idx:end_idx]
# 构建响应
items = []
for scene in scenes:
# 获取前3个class_name作为entity_type
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
# 动态计算 type_num
type_num = len(scene.classes) if scene.classes else 0
items.append(SceneResponse(
scene_id=scene.scene_id,
scene_name=scene.scene_name,
scene_description=scene.scene_description,
type_num=type_num,
entity_type=entity_type,
workspace_id=scene.workspace_id,
created_at=scene.created_at,
updated_at=scene.updated_at,
classes_count=type_num
))
# 构建响应(包含分页信息)
if page is not None and page_size is not None:
# 计算是否有下一页
hasnext = (page * page_size) < total
pagination_info = PaginationInfo(
page=page,
pagesize=page_size,
total=total,
hasnext=hasnext
)
response = SceneListResponse(items=items, page=pagination_info)
else:
response = SceneListResponse(items=items)
api_logger.info(
f"Scene search completed: found {len(items)} scenes matching '{scene_name}' "
f"in workspace {ws_uuid}, total={total}"
)
else:
# 获取所有场景(支持分页)
# 验证分页参数
if page is not None and page < 1:
api_logger.warning(f"Invalid page number: {page}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
if page_size is not None and page_size < 1:
api_logger.warning(f"Invalid page_size: {page_size}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
# 如果只提供了page或page_size中的一个返回错误
if (page is not None and page_size is None) or (page is None and page_size is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
scenes, total = service.list_scenes(ws_uuid, page, page_size)
# 构建响应
items = []
for scene in scenes:
# 获取前3个class_name作为entity_type
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
# 动态计算 type_num
type_num = len(scene.classes) if scene.classes else 0
items.append(SceneResponse(
scene_id=scene.scene_id,
scene_name=scene.scene_name,
scene_description=scene.scene_description,
type_num=type_num,
entity_type=entity_type,
workspace_id=scene.workspace_id,
created_at=scene.created_at,
updated_at=scene.updated_at,
classes_count=type_num
))
# 构建响应(包含分页信息)
if page is not None and page_size is not None:
# 计算是否有下一页
hasnext = (page * page_size) < total
pagination_info = PaginationInfo(
page=page,
pagesize=page_size,
total=total,
hasnext=hasnext
)
response = SceneListResponse(items=items, page=pagination_info)
else:
response = SceneListResponse(items=items)
api_logger.info(f"Scene list retrieved successfully, count={len(items)}, total={total}")
return success(data=response.model_dump(mode='json'), msg="查询成功")
except ValueError as e:
api_logger.warning(f"Validation error in scene {operation}: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in scene {operation}: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in scene {operation}: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
# ==================== 本体类型管理接口 ====================
async def create_class_handler(
request: ClassCreateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""创建本体类型(统一使用列表形式,支持单个或批量)"""
# 根据列表长度判断是单个还是批量
count = len(request.classes)
mode = "single" if count == 1 else "batch"
api_logger.info(
f"Class creation ({mode}) requested by user {current_user.id}, "
f"scene_id={request.scene_id}, count={count}"
)
try:
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建Service
service = _get_dummy_ontology_service(db)
# 准备类型数据
classes_data = [
{
"class_name": item.class_name,
"class_description": item.class_description
}
for item in request.classes
]
if count == 1:
# 单个创建
class_data = classes_data[0]
ontology_class = service.create_class(
scene_id=request.scene_id,
class_name=class_data["class_name"],
class_description=class_data["class_description"],
workspace_id=workspace_id
)
# 构建单个响应
response = ClassResponse(
class_id=ontology_class.class_id,
class_name=ontology_class.class_name,
class_description=ontology_class.class_description,
scene_id=ontology_class.scene_id,
created_at=ontology_class.created_at,
updated_at=ontology_class.updated_at
)
api_logger.info(f"Class created successfully: {ontology_class.class_id}")
return success(data=response.model_dump(mode='json'), msg="类型创建成功")
else:
# 批量创建
created_classes, errors = service.create_classes_batch(
scene_id=request.scene_id,
classes=classes_data,
workspace_id=workspace_id
)
# 构建批量响应
items = []
for ontology_class in created_classes:
items.append(ClassResponse(
class_id=ontology_class.class_id,
class_name=ontology_class.class_name,
class_description=ontology_class.class_description,
scene_id=ontology_class.scene_id,
created_at=ontology_class.created_at,
updated_at=ontology_class.updated_at
))
response = ClassBatchCreateResponse(
total=len(classes_data),
success_count=len(created_classes),
failed_count=len(errors),
items=items,
errors=errors if errors else None
)
api_logger.info(
f"Batch class creation completed: "
f"success={len(created_classes)}, failed={len(errors)}"
)
return success(data=response.model_dump(mode='json'), msg="批量创建完成")
except ValueError as e:
api_logger.warning(f"Validation error in class creation: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in class creation: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
async def update_class_handler(
class_id: str,
request: ClassUpdateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""更新本体类型"""
api_logger.info(
f"Class update requested by user {current_user.id}, "
f"class_id={class_id}"
)
try:
# 验证UUID格式
try:
class_uuid = UUID(class_id)
except ValueError:
api_logger.warning(f"Invalid class_id format: {class_id}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建Service
service = _get_dummy_ontology_service(db)
# 更新类型
ontology_class = service.update_class(
class_id=class_uuid,
class_name=request.class_name,
class_description=request.class_description,
workspace_id=workspace_id
)
# 构建响应
response = ClassResponse(
class_id=ontology_class.class_id,
class_name=ontology_class.class_name,
class_description=ontology_class.class_description,
scene_id=ontology_class.scene_id,
created_at=ontology_class.created_at,
updated_at=ontology_class.updated_at
)
api_logger.info(f"Class updated successfully: {class_id}")
return success(data=response.model_dump(mode='json'), msg="类型更新成功")
except ValueError as e:
api_logger.warning(f"Validation error in class update: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in class update: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in class update: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))
async def delete_class_handler(
class_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""删除本体类型"""
api_logger.info(
f"Class deletion requested by user {current_user.id}, "
f"class_id={class_id}"
)
try:
# 验证UUID格式
try:
class_uuid = UUID(class_id)
except ValueError:
api_logger.warning(f"Invalid class_id format: {class_id}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建Service
service = _get_dummy_ontology_service(db)
# 删除类型
success_flag = service.delete_class(
class_id=class_uuid,
workspace_id=workspace_id
)
api_logger.info(f"Class deleted successfully: {class_id}")
return success(data={"deleted": success_flag}, msg="类型删除成功")
except ValueError as e:
api_logger.warning(f"Validation error in class deletion: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in class deletion: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in class deletion: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))
async def get_class_handler(
class_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取单个本体类型"""
api_logger.info(
f"Get class requested by user {current_user.id}, "
f"class_id={class_id}"
)
try:
# 验证UUID格式
try:
class_uuid = UUID(class_id)
except ValueError:
api_logger.warning(f"Invalid class_id format: {class_id}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建Service
service = _get_dummy_ontology_service(db)
# 获取类型会抛出ValueError如果不存在
ontology_class = service.get_class_by_id(class_uuid, workspace_id)
# 构建响应
response = ClassResponse(
class_id=ontology_class.class_id,
class_name=ontology_class.class_name,
class_description=ontology_class.class_description,
scene_id=ontology_class.scene_id,
created_at=ontology_class.created_at,
updated_at=ontology_class.updated_at
)
api_logger.info(f"Class retrieved successfully: {class_id}")
return success(data=response.model_dump(mode='json'), msg="查询成功")
except ValueError as e:
# 类型不存在或无权限访问
api_logger.warning(f"Validation error in get class: {str(e)}")
return fail(BizCode.NOT_FOUND, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in get class: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in get class: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
async def classes_handler(
scene_id: str,
class_name: Optional[str] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取类型列表(支持模糊搜索和全量查询)
当提供 class_name 参数时,进行模糊搜索;
当不提供 class_name 参数时,返回场景下的所有类型。
Args:
scene_id: 场景ID必填
class_name: 类型名称关键词(可选,支持模糊匹配)
db: 数据库会话
current_user: 当前用户
"""
operation = "search" if class_name else "list"
api_logger.info(
f"Class {operation} requested by user {current_user.id}, "
f"keyword={class_name}, scene_id={scene_id}"
)
try:
# 验证UUID格式
try:
scene_uuid = UUID(scene_id)
except ValueError:
api_logger.warning(f"Invalid scene_id format: {scene_id}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的场景ID格式")
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建Service
service = _get_dummy_ontology_service(db)
# 获取场景信息
scene = service.get_scene_by_id(scene_uuid, workspace_id)
if not scene:
api_logger.warning(f"Scene not found: {scene_id}")
return fail(BizCode.NOT_FOUND, "场景不存在", f"未找到ID为 {scene_id} 的场景")
# 根据是否提供 class_name 决定查询方式
if class_name and class_name.strip():
# 模糊搜索类型
classes = service.search_classes_by_name(class_name.strip(), scene_uuid, workspace_id)
else:
# 获取所有类型
classes = service.list_classes_by_scene(scene_uuid, workspace_id)
# 构建响应
items = []
for ontology_class in classes:
items.append(ClassResponse(
class_id=ontology_class.class_id,
class_name=ontology_class.class_name,
class_description=ontology_class.class_description,
scene_id=ontology_class.scene_id,
created_at=ontology_class.created_at,
updated_at=ontology_class.updated_at
))
response = ClassListResponse(
total=len(items),
scene_id=scene_uuid,
scene_name=scene.scene_name,
scene_description=scene.scene_description,
items=items
)
if class_name:
api_logger.info(
f"Class search completed: found {len(items)} classes matching '{class_name}' "
f"in scene {scene_id}"
)
else:
api_logger.info(f"Class list retrieved successfully, count={len(items)}")
return success(data=response.model_dump(mode='json'), msg="查询成功")
except ValueError as e:
api_logger.warning(f"Validation error in class {operation}: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in class {operation}: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in class {operation}: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))

View File

@@ -1,5 +1,5 @@
import uuid
import json
import uuid
from fastapi import APIRouter, Depends, Path
from sqlalchemy.orm import Session
@@ -8,9 +8,13 @@ from starlette.responses import StreamingResponse
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.dependencies import get_current_user, get_db
from app.models.prompt_optimizer_model import RoleType
from app.schemas.prompt_optimizer_schema import PromptOptMessage, PromptOptModelSet, CreateSessionResponse, \
OptimizePromptResponse, SessionHistoryResponse, SessionMessage
from app.schemas.prompt_optimizer_schema import (
PromptOptMessage,
CreateSessionResponse,
SessionHistoryResponse,
SessionMessage,
PromptSaveRequest
)
from app.schemas.response_schema import ApiResponse
from app.services.prompt_optimizer_service import PromptOptimizerService
@@ -135,3 +139,109 @@ async def get_prompt_opt(
"X-Accel-Buffering": "no"
}
)
@router.post(
"/releases",
summary="Get prompt optimization",
response_model=ApiResponse
)
def save_prompt(
data: PromptSaveRequest,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""
Save a prompt release for the current tenant.
Args:
data (PromptSaveRequest): Request body containing session_id, title, and prompt.
db (Session): SQLAlchemy database session, injected via dependency.
current_user: Currently authenticated user object, injected via dependency.
Returns:
ApiResponse: Standard API response containing the saved prompt release info:
- id: UUID of the prompt release
- session_id: associated session
- title: prompt title
- prompt: prompt content
- created_at: timestamp of creation
Raises:
Any database or service exceptions are propagated to the global exception handler.
"""
service = PromptOptimizerService(db)
prompt_info = service.save_prompt(
tenant_id=current_user.tenant_id,
session_id=data.session_id,
title=data.title,
prompt=data.prompt
)
return success(data=prompt_info)
@router.delete(
"/releases/{prompt_id}",
summary="Delete prompt (soft delete)",
response_model=ApiResponse
)
def delete_prompt(
prompt_id: uuid.UUID = Path(..., description="Prompt ID"),
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""
Soft delete a prompt release.
Args:
prompt_id
db (Session): Database session
current_user: Current logged-in user
Returns:
ApiResponse: Success message confirming deletion
"""
service = PromptOptimizerService(db)
service.delete_prompt(
tenant_id=current_user.tenant_id,
prompt_id=prompt_id
)
return success(msg="Prompt deleted successfully")
@router.get(
"/releases/list",
summary="Get paginated list of released prompts with optional filter",
response_model=ApiResponse
)
def get_release_list(
page: int = 1,
page_size: int = 20,
keyword: str | None = None,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""
Retrieve paginated list of released prompts for the current tenant.
Optionally filter by keyword in title.
Args:
page (int): Page number (starting from 1)
page_size (int): Number of items per page (max 100)
keyword (str | None): Optional keyword to filter prompt titles
db (Session): Database session
current_user: Current logged-in user
Returns:
ApiResponse: Contains paginated list of prompt releases with metadata
"""
service = PromptOptimizerService(db)
result = service.get_release_list(
tenant_id=current_user.tenant_id,
page=max(1, page),
page_size=min(max(1, page_size), 100),
filter_keyword=keyword
)
return success(data=result)

View File

@@ -11,7 +11,8 @@ import os
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages, format_parsing, messages_parse
from app.core.memory.agent.langgraph_graph.write_graph import long_term_storage
from app.db import get_db
from app.core.logging_config import get_business_logger
from app.core.memory.agent.utils.redis_tool import store
@@ -106,7 +107,7 @@ class LangChainAgent:
"streaming": streaming,
"tool_count": len(self.tools),
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
"tool_count": len(self.tools)
# "tool_count": len(self.tools)
}
)
@@ -145,38 +146,33 @@ class LangChainAgent:
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
messages.append(HumanMessage(content=user_content))
return messages
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# async def term_memory_save(self,messages,end_user_end,aimessages):
# '''短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j'''
# end_user_end=f"Term_{end_user_end}"
# print(messages)
# print(aimessages)
# session_id = store.save_session(
# userid=end_user_end,
# messages=messages,
# apply_id=end_user_end,
# end_user_id=end_user_end,
# aimessages=aimessages
# )
# store.delete_duplicate_sessions()
# # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
# return session_id
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# async def term_memory_redis_read(self,end_user_end):
# end_user_end = f"Term_{end_user_end}"
# history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
# # logger.info(f'Redis_Agent:{end_user_end};{history}')
# messagss_list=[]
# retrieved_content=[]
# for messages in history:
# query = messages.get("Query")
# aimessages = messages.get("Answer")
# messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
# retrieved_content.append({query: aimessages})
# return messagss_list,retrieved_content
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
db = next(get_db())
scope=6
try:
repo = LongTermMemoryRepository(db)
await long_term_storage(long_term_type="chunk", langchain_messages=long_term_messages,
memory_config=actual_config_id, end_user_id=end_user_id, scope=scope)
from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id)
if type=="chunk" or type=="aggregate":
data = await format_parsing(result, "dict")
chunk_data = data[:scope]
if len(chunk_data)==scope:
repo.upsert(end_user_id, chunk_data)
logger.info(f'写入短长期:')
else:
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
long_messages = await messages_parse(long_time_data)
repo.upsert(end_user_id, long_messages)
logger.info(f'写入短长期:')
finally:
db.close()
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
"""
写入记忆(支持结构化消息)
@@ -224,14 +220,6 @@ class LangChainAgent:
logger.warning(f"No messages to write for user {actual_end_user_id}")
return
# 调用 Celery 任务,传递结构化消息列表
# 数据流:
# 1. structured_messages 传递给 write_message_task
# 2. write_message_task 调用 memory_agent_service.write_memory
# 3. write_memory 调用 write_tools.write传递 messages 参数
# 4. write_tools.write 调用 get_chunked_dialogs传递 messages 参数
# 5. get_chunked_dialogs 为每条消息创建独立的 Chunk设置 speaker 字段
# 6. 每个 Chunk 保存到 Neo4j包含 speaker 字段
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # end_user_id: 用户ID
@@ -288,30 +276,6 @@ class LangChainAgent:
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
# # TODO 乐力齐,在长短期记忆存储的时候再使用此代码
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
# history_term_memory = history_term_memory_result[0]
# db_for_memory = next(get_db())
# if memory_flag:
# if len(history_term_memory)>=4 and storage_type != "rag":
# history_term_memory = ';'.join(history_term_memory)
# retrieved_content = history_term_memory_result[1]
# print(retrieved_content)
# # 为长期记忆操作获取新的数据库连接
# try:
# repo = LongTermMemoryRepository(db_for_memory)
# repo.upsert(end_user_id, retrieved_content)
# logger.info(
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
# except Exception as e:
# logger.error(f"Failed to write to LongTermMemory: {e}")
# raise
# finally:
# db_for_memory.close()
# # 长期记忆写入(
# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id)
# # 注意:不在这里写入用户消息,等 AI 回复后一起写入
try:
# 准备消息列表
messages = self._prepare_messages(message, history, context)
@@ -332,17 +296,21 @@ class LangChainAgent:
# 获取最后的 AI 消息
output_messages = result.get("messages", [])
content = ""
total_tokens = 0
for msg in reversed(output_messages):
if isinstance(msg, AIMessage):
content = msg.content
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
break
elapsed_time = time.time() - start_time
if memory_flag:
long_term_messages=await agent_chat_messages(message_chat,content)
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# await self.term_memory_save(message_chat, end_user_id, content)
'''长期'''
await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk")
response = {
"content": content,
"model": self.model_name,
@@ -350,7 +318,7 @@ class LangChainAgent:
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
"total_tokens": total_tokens
}
}
@@ -410,25 +378,7 @@ class LangChainAgent:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
# # TODO 乐力齐
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
# history_term_memory = history_term_memory_result[0]
# if memory_flag:
# if len(history_term_memory) >= 4 and storage_type != "rag":
# history_term_memory = ';'.join(history_term_memory)
# retrieved_content = history_term_memory_result[1]
# db_for_memory = next(get_db())
# try:
# repo = LongTermMemoryRepository(db_for_memory)
# repo.upsert(end_user_id, retrieved_content)
# logger.info(
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
# # 长期记忆写入
# await self.write(storage_type, end_user_id, history_term_memory, "", user_rag_memory_id, end_user_id, actual_config_id)
# except Exception as e:
# logger.error(f"Failed to write to long term memory: {e}")
# finally:
# db_for_memory.close()
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
try:
@@ -444,7 +394,7 @@ class LangChainAgent:
# 统一使用 agent 的 astream_events 实现流式输出
logger.debug("使用 Agent astream_events 实现流式输出")
full_content=''
full_content = ''
try:
async for event in self.agent.astream_events(
{"messages": messages},
@@ -481,11 +431,20 @@ class LangChainAgent:
logger.debug(f"工具调用结束: {event.get('name')}")
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
# 统计token消耗
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
for msg in reversed(output_messages):
if isinstance(msg, AIMessage):
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
0) if response_meta else 0
yield total_tokens
break
if memory_flag:
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
long_term_messages = await agent_chat_messages(message_chat, full_content)
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# await self.term_memory_save(message_chat, end_user_id, full_content)
await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk")
except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)

View File

@@ -157,6 +157,11 @@ class Settings:
if origin.strip()
]
# Language Configuration
# Supported values: "zh" (Chinese), "en" (English)
# This controls the language used for memory summary titles and other generated content
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
# Logging settings
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")

View File

@@ -0,0 +1,165 @@
import os
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format, format_parsing
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.redis_tool import write_store
from app.core.memory.agent.utils.redis_tool import count_store
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
logger = get_agent_logger(__name__)
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
async def write_messages(end_user_id,langchain_messages,memory_config):
'''
写入数据到neo4j
Args:
end_user_id: 终端用户ID
memory_config: 内存配置对象
langchain_messages原始数据LIST
'''
try:
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {
"messages": langchain_messages,
"end_user_id": end_user_id,
"memory_config": memory_config
}
# 获取节点更新信息
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
for node_name, node_data in update_event.items():
if 'save_neo4j' == node_name:
massages = node_data
massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result')
print(contents)
except Exception as e:
import traceback
traceback.print_exc()
'''根据窗口'''
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
'''
根据窗口获取redis数据,写入neo4j
Args:
end_user_id: 终端用户ID
memory_config: 内存配置对象
langchain_messages原始数据LIST
scope窗口大小
'''
scope=scope
is_end_user_id = count_store.get_sessions_count(end_user_id)
if is_end_user_id is not False:
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
redis_messages = count_store.get_sessions_count(end_user_id)[1]
if is_end_user_id and int(is_end_user_id) != int(scope):
print(is_end_user_id)
is_end_user_id += 1
langchain_messages += redis_messages
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
elif int(is_end_user_id) == int(scope):
print('写入长期记忆并且设置为0')
print(is_end_user_id)
formatted_messages = await chat_data_format(redis_messages)
print(100*'-')
print(formatted_messages)
print(100*'-')
await write_messages(end_user_id, formatted_messages, memory_config)
count_store.update_sessions_count(end_user_id, 0, '')
else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
"""根据时间"""
async def memory_long_term_storage(end_user_id,memory_config,time):
'''
根据时间获取redis数据,写入neo4j
Args:
end_user_id: 终端用户ID
memory_config: 内存配置对象
'''
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
format_messages = await chat_data_format(long_time_data)
if format_messages!=[]:
await write_messages(end_user_id, format_messages, memory_config)
'''聚合判断'''
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
"""
聚合判断函数:判断输入句子和历史消息是否描述同一事件
Args:
end_user_id: 终端用户ID
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
memory_config: 内存配置对象
"""
try:
# 1. 获取历史会话数据(使用新方法)
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
history = await format_parsing(result)
if not result:
history = []
else:
history = await format_parsing(result)
json_schema = WriteAggregateModel.model_json_schema()
template_service = TemplateService(template_root)
system_prompt = await template_service.render_template(
template_name='write_aggregate_judgment.jinja2',
operation_name='aggregate_judgment',
history=history,
sentence=ori_messages,
json_schema=json_schema
)
with get_db_context() as db_session:
factory = MemoryClientFactory(db_session)
llm_client = factory.get_llm_client(memory_config.llm_model_id)
messages = [
{
"role": "user",
"content": system_prompt
}
]
structured = await llm_client.response_structured(
messages=messages,
response_model=WriteAggregateModel
)
output_value = structured.output
if isinstance(output_value, list):
output_value = [
{"role": msg.role, "content": msg.content}
for msg in output_value
]
result_dict = {
"is_same_event": structured.is_same_event,
"output": output_value
}
if not structured.is_same_event:
logger.info(result_dict)
await write_messages(end_user_id, output_value, memory_config)
return result_dict
except Exception as e:
print(f"[aggregate_judgment] 发生错误: {e}")
import traceback
traceback.print_exc()
return {
"is_same_event": False,
"output": ori_messages,
"messages": ori_messages,
"history": history if 'history' in locals() else [],
"error": str(e)
}

View File

@@ -0,0 +1,100 @@
import json
from langchain_core.messages import HumanMessage, AIMessage
async def format_parsing(messages: list,type:str='string'):
"""
格式化解析消息列表
Args:
messages: 消息列表
type: 返回类型 ('string''dict')
Returns:
格式化后的消息列表
"""
result = []
user=[]
ai=[]
for message in messages:
hstory_messages = message['messages']
for history_messag in hstory_messages.strip().splitlines():
history_messag = json.loads(history_messag)
for content in history_messag:
role = content['role']
content = content['content']
if type == "string":
if role == 'human':
content = '用户:' + content
else:
content = 'AI:' + content
result.append(content)
if type == "dict":
if role == 'human':
user.append( content)
else:
ai.append(content)
if type == "dict":
for key,values in zip(user,ai):
result.append({key:values})
return result
async def messages_parse(messages: list | dict):
user=[]
ai=[]
database=[]
for message in messages:
Query = message['Query']
Query = json.loads(Query)
for data in Query:
role = data['role']
if role == "human":
user.append(data['content'])
if role == "ai":
ai.append(data['content'])
for key, values in zip(user, ai):
database.append({key, values})
return database
async def chat_data_format(messages: list | dict):
"""
将消息格式化为 LangChain 消息格式
Args:
messages: 消息列表或字典
Returns:
LangChain 消息列表
"""
langchain_messages = []
if isinstance(messages, list):
for msg in messages:
if 'role' in msg.keys():
if msg['role'] == 'user':
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content']))
if "Query" in msg.keys():
langchain_messages.append(HumanMessage(content=msg['Query']))
langchain_messages.append(AIMessage(content=msg['Answer']))
if isinstance(messages, dict):
if messages['type'] == 'human':
langchain_messages.append(HumanMessage(content=messages['content']))
elif messages['type'] == 'ai':
langchain_messages.append(AIMessage(content=messages['content']))
return langchain_messages
async def agent_chat_messages(user_content,ai_content):
messages = [
{
"role": "user",
"content": f"{user_content}"
},
{
"role": "assistant",
"content": f"{ai_content}"
}
]
return messages

View File

@@ -1,20 +1,17 @@
import asyncio
import json
import sys
import warnings
from contextlib import asynccontextmanager
from langchain_core.messages import HumanMessage
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, chat_data_format, messages_parse
from app.db import get_db
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write
from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning)
@@ -34,14 +31,6 @@ async def make_write_graph():
end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration
"""
# workflow = StateGraph(WriteState)
# workflow.add_node("content_input", content_input_write)
# workflow.add_node("save_neo4j", write_node)
# workflow.add_edge(START, "content_input")
# workflow.add_edge("content_input", "save_neo4j")
# workflow.add_edge("save_neo4j", END)
#
# graph = workflow.compile()
workflow = StateGraph(WriteState)
workflow.add_node("save_neo4j", write_node)
workflow.add_edge(START, "save_neo4j")
@@ -50,44 +39,49 @@ async def make_write_graph():
graph = workflow.compile()
yield graph
async def main():
"""主函数 - 运行工作流"""
message = "今天周一"
end_user_id = 'new_2025test1103' # 组ID
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
from app.core.memory.agent.utils.redis_tool import write_store
write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
# 获取数据库会话
db_session = next(get_db())
config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config(
config_id=17, # 改为整数
config_id=memory_config, # 改为整数
service_name="MemoryAgentService"
)
try:
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config}
if long_term_type=='chunk':
'''方案一:对话窗口6轮对话'''
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
if long_term_type=='time':
"""时间"""
await memory_long_term_storage(end_user_id, memory_config,5)
if long_term_type=='aggregate':
# 获取节点更新信息
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
for node_name, node_data in update_event.items():
if 'save_neo4j'==node_name:
massages=node_data
massages=massages.get('write_result')['status']
print(massages) # | 更新数据: {node_data}
except Exception as e:
import traceback
traceback.print_exc()
"""方案三:聚合判断"""
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
if __name__ == "__main__":
import asyncio
asyncio.run(main())
# async def main():
# """主函数 - 运行工作流"""
# langchain_messages = [
# {
# "role": "user",
# "content": "今天周五好开心啊"
# },
# {
# "role": "assistant",
# "content": "你也这么觉得,我也是耶"
# }
#
# ]
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
# # await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
# result=await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
#
#
# if __name__ == "__main__":
# import asyncio
# asyncio.run(main())

View File

@@ -0,0 +1,28 @@
"""Pydantic models for write aggregate judgment operations."""
from typing import List, Union
from pydantic import BaseModel, Field
class MessageItem(BaseModel):
"""Individual message item in conversation."""
role: str = Field(..., description="角色user 或 assistant")
content: str = Field(..., description="消息内容")
class WriteAggregateResponse(BaseModel):
"""Response model for aggregate judgment containing judgment result and output."""
is_same_event: bool = Field(
...,
description="是否是同一事件。True表示是同一事件False表示不同事件"
)
output: Union[List[MessageItem], bool] = Field(
...,
description="如果is_same_event为True返回False如果is_same_event为False返回消息列表"
)
# 为了保持向后兼容,保留旧的类名作为别名
WriteAggregateModel = WriteAggregateResponse

View File

@@ -0,0 +1,57 @@
输入句子:{{sentence}}
历史消息:{{history}}
# 你的角色
你是一个擅长事件聚合与语义判断的专家。
# 你的任务
结合历史消息和输入句子,判断它们是否在描述**同一件事件或同一事件链**。
以下情况视为"同一事件"(需要返回 is_same_event=True, output=False
- 描述的是同一个具体事件或事实
- 存在明显的因果关系、前后发展关系
- 是对同一事件的补充、解释、追问或延展
- 逻辑上属于同一语境下的连续讨论
以下情况视为"不同事件"(需要返回 is_same_event=False, output=消息列表):
- 话题不同,事件主体不同
- 时间、地点、对象明显不同
- 只是语义相似,但并非同一具体事件
- 无直接事件、因果或逻辑关联
# 输出规则(非常重要)
你必须按照以下JSON格式输出
**如果是同一事件:**
```json
{
"is_same_event": true,
"output": false
}
```
**如果不是同一事件:**
```json
{
"is_same_event": false,
"output": [
{
"role": "user",
"content": "输入句子的内容"
},
{
"role": "assistant",
"content": "对应的回复内容"
}
]
}
```
# JSON Schema
{{json_schema}}
# 注意事项
- 必须严格按照上述格式输出
- output 字段:如果是同一事件返回 false如果不是同一事件返回完整的消息列表
- 消息列表必须包含 role 和 content 字段
- 不要输出任何解释、分析或多余内容

View File

@@ -0,0 +1,186 @@
import json
from typing import Any, List, Dict, Optional
from datetime import datetime, timedelta
def serialize_messages(messages: Any) -> str:
"""
将消息序列化为 JSON 字符串,支持 LangChain 消息对象
Args:
messages: 可以是 list、dict、string 或 LangChain 消息对象列表
Returns:
str: JSON 字符串
"""
if isinstance(messages, str):
return messages
if isinstance(messages, (list, tuple)):
# 检查是否是 LangChain 消息对象列表
serialized_list = []
for msg in messages:
if hasattr(msg, 'type') and hasattr(msg, 'content'):
# LangChain 消息对象
serialized_list.append({
'type': msg.type,
'content': msg.content,
'role': getattr(msg, 'role', msg.type)
})
elif isinstance(msg, dict):
serialized_list.append(msg)
else:
serialized_list.append(str(msg))
return json.dumps(serialized_list, ensure_ascii=False)
if isinstance(messages, dict):
return json.dumps(messages, ensure_ascii=False)
# 其他类型转为字符串
return str(messages)
def deserialize_messages(messages_str: str) -> Any:
"""
将 JSON 字符串反序列化为原始格式
Args:
messages_str: JSON 字符串
Returns:
反序列化后的对象list、dict 或 string
"""
if not messages_str:
return []
try:
return json.loads(messages_str)
except (json.JSONDecodeError, TypeError):
return messages_str
def fix_encoding(text: str) -> str:
"""
修复错误编码的文本
Args:
text: 需要修复的文本
Returns:
str: 修复后的文本
"""
if not text or not isinstance(text, str):
return text
try:
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
return text.encode('latin-1').decode('utf-8')
except (UnicodeDecodeError, UnicodeEncodeError):
# 如果修复失败,返回原文本
return text
def format_session_data(data: Dict[str, Any], include_time: bool = False) -> Dict[str, Any]:
"""
格式化会话数据为统一的输出格式
Args:
data: 原始会话数据
include_time: 是否包含时间字段
Returns:
Dict: 格式化后的数据 {"Query": "...", "Answer": "...", "starttime": "..."}
"""
result = {
"Query": fix_encoding(data.get('messages', '')),
"Answer": fix_encoding(data.get('aimessages', ''))
}
if include_time:
result["starttime"] = data.get('starttime', '')
return result
def filter_by_time_range(items: List[Dict], minutes: int) -> List[Dict]:
"""
根据时间范围过滤数据
Args:
items: 包含 starttime 字段的数据列表
minutes: 时间范围(分钟)
Returns:
List[Dict]: 过滤后的数据列表
"""
time_threshold = datetime.now() - timedelta(minutes=minutes)
time_threshold_str = time_threshold.strftime("%Y-%m-%d %H:%M:%S")
filtered_items = []
for item in items:
starttime = item.get('starttime', '')
if starttime and starttime >= time_threshold_str:
filtered_items.append(item)
return filtered_items
def sort_and_limit_results(items: List[Dict], limit: int = 6,
remove_time: bool = True) -> List[Dict]:
"""
对结果进行排序、限制数量并移除时间字段
Args:
items: 数据列表
limit: 最大返回数量
remove_time: 是否移除 starttime 字段
Returns:
List[Dict]: 处理后的数据列表
"""
# 按时间降序排序(最新的在前)
items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
# 限制数量
result_items = items[:limit]
# 移除 starttime 字段
if remove_time:
for item in result_items:
item.pop('starttime', None)
# 如果结果少于1条返回空列表
if len(result_items) < 1:
return []
return result_items
def generate_session_key(session_id: str, key_type: str = "session") -> str:
"""
生成 Redis key
Args:
session_id: 会话ID
key_type: key 类型 ("session", "read", "write", "count")
Returns:
str: Redis key
"""
if key_type == "count":
return f"session:count:{session_id}"
elif key_type == "write":
return f"session:write:{session_id}"
elif key_type == "session" or key_type == "read":
return f"session:{session_id}"
else:
return f"session:{session_id}"
def get_current_timestamp() -> str:
"""
获取当前时间戳字符串
Returns:
str: 格式化的时间字符串 "YYYY-MM-DD HH:MM:SS"
"""
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")

View File

@@ -1,11 +1,36 @@
import redis
import uuid
from datetime import datetime
from app.core.config import settings
from typing import List, Dict, Any, Optional, Union
from app.core.memory.agent.utils.redis_base import (
serialize_messages,
deserialize_messages,
fix_encoding,
format_session_data,
filter_by_time_range,
sort_and_limit_results,
generate_session_key,
get_current_timestamp
)
class RedisSessionStore:
class RedisWriteStore:
"""Redis Write 类型存储类,用于管理 save_session_write 相关的数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
"""
初始化 Redis 连接
Args:
host: Redis 主机地址
port: Redis 端口
db: Redis 数据库编号
password: Redis 密码
session_id: 会话ID
"""
self.r = redis.Redis(
host=host,
port=port,
@@ -16,32 +41,400 @@ class RedisSessionStore:
)
self.uudi = session_id
def _fix_encoding(self, text):
"""修复错误编码的文本"""
if not text or not isinstance(text, str):
return text
try:
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
return text.encode('latin-1').decode('utf-8')
except (UnicodeDecodeError, UnicodeEncodeError):
# 如果修复失败,返回原文本
return text
# 修改后的 save_session 方法
def save_session(self, userid, messages, aimessages, apply_id, end_user_id):
def save_session_write(self, userid: str, messages: str) -> str:
"""
写入一条会话数据,返回 session_id
优化版本确保写入时间不超过1秒
Args:
userid: 用户ID
messages: 用户消息
Returns:
str: 新生成的 session_id
"""
try:
session_id = str(uuid.uuid4()) # 为每次会话生成新的 ID
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
key = f"session:{session_id}" # 使用新生成的 session_id 作为 key
messages = serialize_messages(messages)
session_id = str(uuid.uuid4())
key = generate_session_key(session_id, key_type="write")
# 使用 pipeline 批量写入,减少网络往返
pipe = self.r.pipeline()
pipe.hset(key, mapping={
"id": self.uudi,
"sessionid": userid,
"messages": messages,
"starttime": get_current_timestamp()
})
result = pipe.execute()
# 直接写入数据decode_responses=True 已经处理了编码
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
return session_id
except Exception as e:
print(f"[save_session_write] 保存会话失败: {e}")
raise e
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
"""
通过 save_session_write 的 userid 获取 sessionid 和 messages
Args:
userid: 用户ID (对应 sessionid 字段)
Returns:
List[Dict] 或 False: 如果找到数据返回 [{"sessionid": "...", "messages": "..."}, ...],否则返回 False
"""
try:
# 只查询 write 类型的 key
keys = self.r.keys('session:write:*')
if not keys:
return False
# 批量获取数据
pipe = self.r.pipeline()
for key in keys:
pipe.hgetall(key)
all_data = pipe.execute()
# 筛选符合 userid 的数据
results = []
for key, data in zip(keys, all_data):
if not data:
continue
# 从 write 类型读取,匹配 sessionid 字段
if data.get('sessionid') == userid:
# 从 key 中提取 session_id: session:write:{session_id}
session_id = key.split(':')[-1]
results.append({
"sessionid": session_id,
"messages": fix_encoding(data.get('messages', ''))
})
if not results:
return False
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
return results
except Exception as e:
print(f"[get_session_by_userid] 查询失败: {e}")
return False
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
"""
通过 end_user_id 获取所有 write 类型的会话数据
Args:
end_user_id: 终端用户ID (对应 sessionid 字段)
Returns:
List[Dict] 或 False: 如果找到数据返回完整的会话信息列表,否则返回 False
返回格式:
[
{
"session_id": "uuid",
"id": "...",
"sessionid": "end_user_id",
"messages": "...",
"starttime": "timestamp"
},
...
]
"""
try:
# 只查询 write 类型的 key
keys = self.r.keys('session:write:*')
if not keys:
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
return False
# 批量获取数据
pipe = self.r.pipeline()
for key in keys:
pipe.hgetall(key)
all_data = pipe.execute()
# 筛选符合 end_user_id 的数据
results = []
for key, data in zip(keys, all_data):
if not data:
continue
# 从 write 类型读取,匹配 sessionid 字段
if data.get('sessionid') == end_user_id:
# 从 key 中提取 session_id: session:write:{session_id}
session_id = key.split(':')[-1]
# 构建完整的会话信息
session_info = {
"session_id": session_id,
"id": data.get('id', ''),
"sessionid": data.get('sessionid', ''),
"messages": fix_encoding(data.get('messages', '')),
"starttime": data.get('starttime', '')
}
results.append(session_info)
if not results:
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
return False
# 按时间排序(最新的在前)
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
return results
except Exception as e:
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
import traceback
traceback.print_exc()
return False
def find_user_recent_sessions(self, userid: str,
minutes: int = 5) -> List[Dict[str, str]]:
"""
根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据
Args:
userid: 用户ID (对应 sessionid 字段)
minutes: 查询最近几分钟的数据默认5分钟
Returns:
List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...]
"""
import time
start_time = time.time()
# 只查询 write 类型的 key
keys = self.r.keys('session:write:*')
if not keys:
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return []
# 批量获取数据
pipe = self.r.pipeline()
for key in keys:
pipe.hgetall(key)
all_data = pipe.execute()
# 筛选符合 userid 的数据
matched_items = []
for data in all_data:
if not data:
continue
# 从 write 类型读取,匹配 sessionid 字段
if data.get('sessionid') == userid and data.get('starttime'):
# write 类型没有 aimessages所以 Answer 为空
matched_items.append({
"Query": fix_encoding(data.get('messages', '')),
"Answer": "",
"starttime": data.get('starttime', '')
})
# 根据时间范围过滤
filtered_items = filter_by_time_range(matched_items, minutes)
# 排序并移除时间字段
result_items = sort_and_limit_results(filtered_items, limit=None)
print(result_items)
elapsed_time = time.time() - start_time
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items
def delete_all_write_sessions(self) -> int:
"""
删除所有 write 类型的会话
Returns:
int: 删除的数量
"""
keys = self.r.keys('session:write:*')
if keys:
return self.r.delete(*keys)
return 0
class RedisCountStore:
"""Redis Count 类型存储类,用于管理访问次数统计相关的数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
"""
初始化 Redis 连接
Args:
host: Redis 主机地址
port: Redis 端口
db: Redis 数据库编号
password: Redis 密码
session_id: 会话ID
"""
self.r = redis.Redis(
host=host,
port=port,
db=db,
password=password,
decode_responses=True,
encoding='utf-8'
)
self.uudi = session_id
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
"""
保存用户访问次数统计
Args:
end_user_id: 终端用户ID
count: 访问次数
messages: 消息内容
Returns:
str: 新生成的 session_id
"""
session_id = str(uuid.uuid4())
key = generate_session_key(session_id, key_type="count")
pipe = self.r.pipeline()
pipe.hset(key, mapping={
"id": self.uudi,
"end_user_id": end_user_id,
"count": int(count),
"messages": serialize_messages(messages),
"starttime": get_current_timestamp()
})
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
result = pipe.execute()
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
return session_id
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
"""
通过 end_user_id 查询访问次数统计
Args:
end_user_id: 终端用户ID
Returns:
list 或 False: 如果找到返回 [count, messages],否则返回 False
"""
try:
search_pattern = 'session:count:*'
for key in self.r.keys(search_pattern):
data = self.r.hgetall(key)
if not data:
continue
if data.get('end_user_id') == end_user_id:
count = data.get('count')
messages_str = data.get('messages')
if count is not None:
messages = deserialize_messages(messages_str)
return [int(count), messages]
return False
except Exception as e:
print(f"[get_sessions_count] 查询失败: {e}")
return False
def update_sessions_count(self, end_user_id: str, new_count: int,
messages: Any) -> bool:
"""
通过 end_user_id 修改访问次数统计
Args:
end_user_id: 终端用户ID
new_count: 新的 count 值
messages: 消息内容
Returns:
bool: 更新成功返回 True未找到记录返回 False
"""
try:
messages_str = serialize_messages(messages)
search_pattern = 'session:count:*'
for key in self.r.keys(search_pattern):
data = self.r.hgetall(key)
if not data:
continue
if data.get('end_user_id') == end_user_id:
self.r.hset(key, 'count', int(new_count))
self.r.hset(key, 'messages', messages_str)
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
return True
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False
except Exception as e:
print(f"[update_sessions_count] 更新失败: {e}")
return False
def delete_all_count_sessions(self) -> int:
"""
删除所有 count 类型的会话
Returns:
int: 删除的数量
"""
keys = self.r.keys('session:count:*')
if keys:
return self.r.delete(*keys)
return 0
class RedisSessionStore:
"""Redis 会话存储类,用于管理会话数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
"""
初始化 Redis 连接
Args:
host: Redis 主机地址
port: Redis 端口
db: Redis 数据库编号
password: Redis 密码
session_id: 会话ID
"""
self.r = redis.Redis(
host=host,
port=port,
db=db,
password=password,
decode_responses=True,
encoding='utf-8'
)
self.uudi = session_id
# ==================== 写入操作 ====================
def save_session(self, userid: str, messages: str, aimessages: str,
apply_id: str, end_user_id: str) -> str:
"""
写入一条会话数据,返回 session_id
Args:
userid: 用户ID
messages: 用户消息
aimessages: AI回复消息
apply_id: 应用ID
end_user_id: 终端用户ID
Returns:
str: 新生成的 session_id
"""
try:
session_id = str(uuid.uuid4())
key = generate_session_key(session_id, key_type="read")
pipe = self.r.pipeline()
pipe.hset(key, mapping={
"id": self.uudi,
"sessionid": userid,
@@ -49,177 +442,195 @@ class RedisSessionStore:
"end_user_id": end_user_id,
"messages": messages,
"aimessages": aimessages,
"starttime": starttime
"starttime": get_current_timestamp()
})
# 可选设置过期时间例如30天避免数据无限增长
# pipe.expire(key, 30 * 24 * 60 * 60)
# 执行批量操作
result = pipe.execute()
print(f"保存结果: {result[0]}, session_id: {session_id}")
return session_id # 返回新生成的 session_id
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
return session_id
except Exception as e:
print(f"保存会话失败: {e}")
print(f"[save_session] 保存会话失败: {e}")
raise e
def save_sessions_batch(self, sessions_data):
"""
批量写入多条会话数据,返回 session_id 列表
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id
优化版本:批量操作,大幅提升性能
"""
try:
session_ids = []
pipe = self.r.pipeline()
for session in sessions_data:
session_id = str(uuid.uuid4())
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
key = f"session:{session_id}"
pipe.hset(key, mapping={
"id": self.uudi,
"sessionid": session.get('userid'),
"apply_id": session.get('apply_id'),
"end_user_id": session.get('end_user_id'),
"messages": session.get('messages'),
"aimessages": session.get('aimessages'),
"starttime": starttime
})
session_ids.append(session_id)
# 一次性执行所有写入操作
results = pipe.execute()
print(f"批量保存完成: {len(session_ids)} 条记录")
return session_ids
except Exception as e:
print(f"批量保存会话失败: {e}")
raise e
# ---------------- 读取 ----------------
def get_session(self, session_id):
# ==================== 读取操作 ====================
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
"""
读取一条会话数据
Args:
session_id: 会话ID
Returns:
Dict 或 None: 会话数据
"""
key = f"session:{session_id}"
key = generate_session_key(session_id)
data = self.r.hgetall(key)
return data if data else None
def get_session_apply_group(self, sessionid, apply_id, end_user_id):
def get_all_sessions(self) -> Dict[str, Dict[str, Any]]:
"""
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据
"""
result_items = []
# 遍历所有会话数据
for key in self.r.keys('session:*'):
data = self.r.hgetall(key)
if not data:
continue
# 检查三个条件是否都匹配
if (data.get('sessionid') == sessionid and
data.get('apply_id') == apply_id and
data.get('end_user_id') == end_user_id):
result_items.append(data)
return result_items
def get_all_sessions(self):
"""
获取所有会话数据
获取所有会话数据(不包括 count 和 write 类型)
Returns:
Dict: 所有会话数据key 为 session_id
"""
sessions = {}
for key in self.r.keys('session:*'):
sid = key.split(':')[1]
sessions[sid] = self.get_session(sid)
# 排除 count 和 write 类型的 key
if ':count:' not in key and ':write:' not in key:
sid = key.split(':')[1]
sessions[sid] = self.get_session(sid)
return sessions
# ---------------- 更新 ----------------
def update_session(self, session_id, field, value):
def find_user_apply_group(self, sessionid: str, apply_id: str,
end_user_id: str) -> List[Dict[str, str]]:
"""
根据 sessionid、apply_id 和 end_user_id 查询会话数据返回最新的6条
Args:
sessionid: 会话ID支持模糊匹配
apply_id: 应用ID
end_user_id: 终端用户ID
Returns:
List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...]
"""
import time
start_time = time.time()
keys = self.r.keys('session:*')
if not keys:
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return []
# 批量获取数据
pipe = self.r.pipeline()
for key in keys:
# 排除 count 和 write 类型
if ':count:' not in key and ':write:' not in key:
pipe.hgetall(key)
all_data = pipe.execute()
# 筛选符合条件的数据
matched_items = []
for data in all_data:
if not data:
continue
if (data.get('apply_id') == apply_id and
data.get('end_user_id') == end_user_id):
# 支持模糊匹配或完全匹配 sessionid
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
matched_items.append(format_session_data(data, include_time=True))
# 排序、限制数量并移除时间字段
result_items = sort_and_limit_results(matched_items, limit=6)
elapsed_time = time.time() - start_time
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items
# ==================== 更新操作 ====================
def update_session(self, session_id: str, field: str, value: Any) -> bool:
"""
更新单个字段
优化版本:使用 pipeline 减少网络往返
Args:
session_id: 会话ID
field: 字段名
value: 字段值
Returns:
bool: 是否更新成功
"""
key = f"session:{session_id}"
key = generate_session_key(session_id)
pipe = self.r.pipeline()
pipe.exists(key)
pipe.hset(key, field, value)
results = pipe.execute()
return bool(results[0]) # 返回 key 是否存在
return bool(results[0])
# ---------------- 删除 ----------------
def delete_session(self, session_id):
# ==================== 删除操作 ====================
def delete_session(self, session_id: str) -> int:
"""
删除单条会话
Args:
session_id: 会话ID
Returns:
int: 删除的数量
"""
key = f"session:{session_id}"
key = generate_session_key(session_id)
return self.r.delete(key)
def delete_all_sessions(self):
def delete_all_sessions(self) -> int:
"""
删除所有会话
删除所有会话(不包括 count 和 write 类型)
Returns:
int: 删除的数量
"""
keys = self.r.keys('session:*')
if keys:
return self.r.delete(*keys)
# 过滤掉 count 和 write 类型
keys_to_delete = [k for k in keys if ':count:' not in k and ':write:' not in k]
if keys_to_delete:
return self.r.delete(*keys_to_delete)
return 0
def delete_duplicate_sessions(self):
def delete_duplicate_sessions(self) -> int:
"""
删除重复会话数据,条件:
"sessionid""user_id""end_user_id""messages""aimessages" 五个字段都相同的只保留一个,其他删除
优化版本:使用 pipeline 批量操作确保在1秒内完成
删除重复会话数据(不包括 count 和 write 类型)
条件:sessionid、user_id、end_user_id、messages、aimessages 五个字段都相同的只保留一个
Returns:
int: 删除的数量
"""
import time
start_time = time.time()
# 第一步:使用 pipeline 批量获取所有 key
keys = self.r.keys('session:*')
if not keys:
print("[delete_duplicate_sessions] 没有会话数据")
return 0
# 第二步:使用 pipeline 批量获取所有数据
# 批量获取所有数据
pipe = self.r.pipeline()
for key in keys:
pipe.hgetall(key)
# 排除 count 和 write 类型
if ':count:' not in key and ':write:' not in key:
pipe.hgetall(key)
all_data = pipe.execute()
# 第三步:在内存中识别重复数据
seen = {} # 用字典记录identifier -> key保留第一个出现的 key
keys_to_delete = [] # 需要删除的 key 列表
# 识别重复数据
seen = {}
keys_to_delete = []
for key, data in zip(keys, all_data, strict=False):
for key, data in zip([k for k in keys if ':count:' not in k and ':write:' not in k], all_data, strict=False):
if not data:
continue
# 获取五个字段的值
sessionid = data.get('sessionid', '')
user_id = data.get('id', '')
end_user_id = data.get('end_user_id', '')
messages = data.get('messages', '')
aimessages = data.get('aimessages', '')
# 用五元组作为唯一标识
identifier = (sessionid, user_id, end_user_id, messages, aimessages)
identifier = (
data.get('sessionid', ''),
data.get('id', ''),
data.get('end_user_id', ''),
data.get('messages', ''),
data.get('aimessages', '')
)
if identifier in seen:
# 重复,标记为待删除
keys_to_delete.append(key)
else:
# 第一次出现,记录
seen[identifier] = key
# 第四步:使用 pipeline 批量删除重复的 key
# 批量删除重复的 key
deleted_count = 0
if keys_to_delete:
# 分批删除,避免单次操作过大
batch_size = 1000
for i in range(0, len(keys_to_delete), batch_size):
batch = keys_to_delete[i:i + batch_size]
@@ -233,79 +644,28 @@ class RedisSessionStore:
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}")
return deleted_count
def find_user_session(self, sessionid):
user_id = sessionid
result_items = []
for key, values in store.get_all_sessions().items():
history = {}
if user_id == str(values['sessionid']):
history["Query"] = values['messages']
history["Answer"] = values['aimessages']
result_items.append(history)
if len(result_items) <= 1:
result_items = []
return (result_items)
def find_user_apply_group(self, sessionid, apply_id, end_user_id):
"""
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据返回最新的6条
"""
import time
start_time = time.time()
# 使用 pipeline 批量获取数据,提高性能
keys = self.r.keys('session:*')
if not keys:
print(f"查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return []
# 使用 pipeline 批量获取所有 hash 数据
pipe = self.r.pipeline()
for key in keys:
pipe.hgetall(key)
all_data = pipe.execute()
# 解析并筛选符合条件的数据
matched_items = []
for data in all_data:
if not data:
continue
# 检查是否符合三个条件
if (data.get('apply_id') == apply_id and
data.get('end_user_id') == end_user_id):
# 支持模糊匹配 sessionid 或者完全匹配
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
matched_items.append({
"Query": self._fix_encoding(data.get('messages')),
"Answer": self._fix_encoding(data.get('aimessages')),
"starttime": data.get('starttime', '')
})
# 按时间降序排序(最新的在前)
matched_items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
# 只保留最新的6条
result_items = matched_items[:6]
# # 移除 starttime 字段
for item in result_items:
item.pop('starttime', None)
# 如果结果少于等于1条返回空列表
if len(result_items) <= 1:
result_items = []
elapsed_time = time.time() - start_time
print(f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items
# 全局实例
store = RedisSessionStore(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
session_id=str(uuid.uuid4())
)
)
write_store = RedisWriteStore(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
session_id=str(uuid.uuid4())
)
count_store = RedisCountStore(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
session_id=str(uuid.uuid4())
)

View File

@@ -58,6 +58,12 @@ from app.core.memory.models.triplet_models import (
TripletExtractionResponse,
)
# Ontology models
from app.core.memory.models.ontology_models import (
OntologyClass,
OntologyExtractionResponse,
)
# Variable configuration models
from app.core.memory.models.variate_config import (
StatementExtractionConfig,
@@ -105,6 +111,9 @@ __all__ = [
"Entity",
"Triplet",
"TripletExtractionResponse",
# Ontology models
"OntologyClass",
"OntologyExtractionResponse",
# Variable configuration
"StatementExtractionConfig",
"ForgettingEngineConfig",

View File

@@ -0,0 +1,135 @@
"""Models for ontology classes and extraction responses.
This module contains Pydantic models for representing extracted ontology classes
from scenario descriptions, following OWL ontology engineering standards.
Classes:
OntologyClass: Represents an extracted ontology class
OntologyExtractionResponse: Response model containing extracted ontology classes
"""
from typing import List, Optional
from uuid import uuid4
from pydantic import BaseModel, ConfigDict, Field, field_validator
class OntologyClass(BaseModel):
"""Represents an extracted ontology class from scenario description.
An ontology class represents an abstract category or concept in a domain,
following OWL ontology engineering standards and naming conventions.
Attributes:
id: Unique string identifier for the ontology class
name: Name of the class in PascalCase format (e.g., 'MedicalProcedure')
name_chinese: Chinese translation of the class name (e.g., '医疗程序')
description: Textual description of the class
examples: List of concrete instance examples of this class
parent_class: Optional name of the parent class in the hierarchy
entity_type: Type/category of the entity (e.g., 'Person', 'Organization', 'Concept')
domain: Domain this class belongs to (e.g., 'Healthcare', 'Education')
Config:
extra: Ignore extra fields from LLM output
"""
model_config = ConfigDict(extra='ignore')
id: str = Field(
default_factory=lambda: uuid4().hex,
description="Unique identifier for the ontology class"
)
name: str = Field(
...,
description="Name of the class in PascalCase format"
)
name_chinese: Optional[str] = Field(
None,
description="Chinese translation of the class name"
)
description: str = Field(
...,
description="Description of the class"
)
examples: List[str] = Field(
default_factory=list,
description="List of concrete instance examples"
)
parent_class: Optional[str] = Field(
None,
description="Name of the parent class in the hierarchy"
)
entity_type: str = Field(
...,
description="Type/category of the entity"
)
domain: str = Field(
...,
description="Domain this class belongs to"
)
@field_validator('name')
@classmethod
def validate_pascal_case(cls, v: str) -> str:
"""Validate that the class name follows PascalCase convention.
PascalCase rules:
- Must start with an uppercase letter
- Cannot contain spaces
- Should not contain special characters except underscores
Args:
v: The class name to validate
Returns:
The validated class name
Raises:
ValueError: If the name doesn't follow PascalCase convention
"""
if not v:
raise ValueError("Class name cannot be empty")
if not v[0].isupper():
raise ValueError(
f"Class name '{v}' must start with an uppercase letter (PascalCase)"
)
if ' ' in v:
raise ValueError(
f"Class name '{v}' cannot contain spaces (PascalCase)"
)
# Check for invalid characters (allow alphanumeric and underscore only)
if not all(c.isalnum() or c == '_' for c in v):
raise ValueError(
f"Class name '{v}' contains invalid characters. "
"Only alphanumeric characters and underscores are allowed"
)
return v
class OntologyExtractionResponse(BaseModel):
"""Response model for ontology extraction from LLM.
This model represents the structured output from the LLM when
extracting ontology classes from scenario descriptions.
Attributes:
classes: List of extracted ontology classes
domain: Domain/field the scenario belongs to
Config:
extra: Ignore extra fields from LLM output
"""
model_config = ConfigDict(extra='ignore')
classes: List[OntologyClass] = Field(
default_factory=list,
description="List of extracted ontology classes"
)
domain: str = Field(
...,
description="Domain/field the scenario belongs to"
)

View File

@@ -8,4 +8,5 @@
- TemporalExtractor: 时间信息提取
- EmbeddingGenerator: 嵌入向量生成
- MemorySummaryGenerator: 记忆摘要生成
- OntologyExtractor: 本体类提取
"""

View File

@@ -14,6 +14,34 @@ from pydantic import Field
logger = get_memory_logger(__name__)
# 支持的语言列表和默认回退值
SUPPORTED_LANGUAGES = {"zh", "en"}
FALLBACK_LANGUAGE = "en"
def validate_language(language: Optional[str]) -> str:
"""
校验语言参数,确保其为有效值。
Args:
language: 待校验的语言代码
Returns:
有效的语言代码("zh""en"
"""
if language is None:
return FALLBACK_LANGUAGE
lang = str(language).lower().strip()
if lang in SUPPORTED_LANGUAGES:
return lang
logger.warning(
f"无效的语言参数 '{language}',已回退到默认值 '{FALLBACK_LANGUAGE}'"
f"支持的语言: {SUPPORTED_LANGUAGES}"
)
return FALLBACK_LANGUAGE
class MemorySummaryResponse(RobustLLMResponse):
"""Structured response for summary generation per chunk.
@@ -31,7 +59,8 @@ class MemorySummaryResponse(RobustLLMResponse):
async def generate_title_and_type_for_summary(
content: str,
llm_client
llm_client,
language: str = None
) -> Tuple[str, str]:
"""
为MemorySummary生成标题和类型
@@ -41,11 +70,18 @@ async def generate_title_and_type_for_summary(
Args:
content: Summary的内容文本
llm_client: LLM客户端实例
language: 生成标题使用的语言 ("zh" 中文, "en" 英文)如果为None则从配置读取
Returns:
(标题, 类型)元组
"""
from app.core.memory.utils.prompt.prompt_utils import render_episodic_title_and_type_prompt
from app.core.config import settings
# 如果没有指定语言,从配置中读取,并校验有效性
if language is None:
language = settings.DEFAULT_LANGUAGE
language = validate_language(language)
# 定义有效的类型集合
VALID_TYPES = {
@@ -57,13 +93,19 @@ async def generate_title_and_type_for_summary(
}
DEFAULT_TYPE = "conversation" # 默认类型
# 根据语言设置默认标题
DEFAULT_TITLE = "空内容" if language == "zh" else "Empty Content"
PARSE_ERROR_TITLE = "解析失败" if language == "zh" else "Parse Failed"
ERROR_TITLE = "错误" if language == "zh" else "Error"
UNKNOWN_TITLE = "未知标题" if language == "zh" else "Unknown Title"
try:
if not content:
logger.warning("content为空无法生成标题和类型")
return ("空内容", DEFAULT_TYPE)
logger.warning(f"content为空无法生成标题和类型 (language={language})")
return (DEFAULT_TITLE, DEFAULT_TYPE)
# 1. 渲染Jinja2提示词模板
prompt = await render_episodic_title_and_type_prompt(content)
# 1. 渲染Jinja2提示词模板,传递语言参数
prompt = await render_episodic_title_and_type_prompt(content, language=language)
# 2. 调用LLM生成标题和类型
messages = [
@@ -102,7 +144,7 @@ async def generate_title_and_type_for_summary(
json_str = json_str.strip()
result_data = json.loads(json_str)
title = result_data.get("title", "未知标题")
title = result_data.get("title", UNKNOWN_TITLE)
episodic_type_raw = result_data.get("type", DEFAULT_TYPE)
# 5. 校验和归一化类型
@@ -130,16 +172,16 @@ async def generate_title_and_type_for_summary(
f"已归一化为 '{episodic_type}'"
)
logger.info(f"成功生成标题和类型: title={title}, type={episodic_type}")
logger.info(f"成功生成标题和类型 (language={language}): title={title}, type={episodic_type}")
return (title, episodic_type)
except json.JSONDecodeError:
logger.error(f"无法解析LLM响应为JSON: {full_response}")
return ("解析失败", DEFAULT_TYPE)
logger.error(f"无法解析LLM响应为JSON (language={language}): {full_response}")
return (PARSE_ERROR_TITLE, DEFAULT_TYPE)
except Exception as e:
logger.error(f"生成标题和类型时出错: {str(e)}", exc_info=True)
return ("错误", DEFAULT_TYPE)
logger.error(f"生成标题和类型时出错 (language={language}): {str(e)}", exc_info=True)
return (ERROR_TITLE, DEFAULT_TYPE)
async def _process_chunk_summary(
dialog: DialogData,
@@ -153,11 +195,16 @@ async def _process_chunk_summary(
return None
try:
# 从配置中获取语言设置(只获取一次,复用),并校验有效性
from app.core.config import settings
language = validate_language(settings.DEFAULT_LANGUAGE)
# Render prompt via Jinja2 for a single chunk
prompt_content = await render_memory_summary_prompt(
chunk_texts=chunk.content,
json_schema=MemorySummaryResponse.model_json_schema(),
max_words=200,
language=language,
)
messages = [
@@ -178,9 +225,10 @@ async def _process_chunk_summary(
try:
title, episodic_type = await generate_title_and_type_for_summary(
content=summary_text,
llm_client=llm_client
llm_client=llm_client,
language=language
)
logger.info(f"Generated title and type for MemorySummary: title={title}, type={episodic_type}")
logger.info(f"Generated title and type for MemorySummary (language={language}): title={title}, type={episodic_type}")
except Exception as e:
logger.warning(f"Failed to generate title and type for chunk {chunk.id}: {e}")
# Continue without title and type

View File

@@ -0,0 +1,482 @@
"""Ontology class extraction from scenario descriptions using LLM.
This module provides the OntologyExtractor class for extracting ontology classes
from natural language scenario descriptions. It uses LLM-driven extraction combined
with two-layer validation (string validation + OWL semantic validation).
Classes:
OntologyExtractor: Extracts ontology classes from scenario descriptions
"""
import asyncio
import logging
import time
from typing import List, Optional
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.models.ontology_models import (
OntologyClass,
OntologyExtractionResponse,
)
from app.core.memory.utils.validation.ontology_validator import OntologyValidator
from app.core.memory.utils.validation.owl_validator import OWLValidator
from app.core.memory.utils.prompt.prompt_utils import render_ontology_extraction_prompt
logger = logging.getLogger(__name__)
class OntologyExtractor:
"""Extractor for ontology classes from scenario descriptions.
This extractor uses LLM to identify abstract classes and concepts from
natural language scenario descriptions, following OWL ontology engineering
standards. It performs two-layer validation:
1. String validation (naming conventions, reserved words, duplicates)
2. OWL semantic validation (consistency checking, circular inheritance)
Attributes:
llm_client: OpenAI client for LLM calls
validator: String validator for class names and descriptions
owl_validator: OWL validator for semantic validation
"""
def __init__(self, llm_client: OpenAIClient):
"""Initialize the OntologyExtractor.
Args:
llm_client: OpenAIClient instance for LLM processing
"""
self.llm_client = llm_client
self.validator = OntologyValidator()
self.owl_validator = OWLValidator()
logger.info("OntologyExtractor initialized")
async def extract_ontology_classes(
self,
scenario: str,
domain: Optional[str] = None,
max_classes: int = 15,
min_classes: int = 5,
enable_owl_validation: bool = True,
llm_temperature: float = 0.3,
llm_max_tokens: int = 2000,
max_description_length: int = 500,
timeout: Optional[float] = None,
) -> OntologyExtractionResponse:
"""Extract ontology classes from a scenario description.
This is the main extraction method that orchestrates the entire process:
1. Call LLM to extract ontology classes
2. Perform first-layer validation (string validation and cleaning)
3. Perform second-layer validation (OWL semantic validation)
4. Filter invalid classes based on validation errors
5. Return validated ontology classes
Args:
scenario: Natural language scenario description
domain: Optional domain hint (e.g., "Healthcare", "Education")
max_classes: Maximum number of classes to extract (default: 15)
min_classes: Minimum number of classes to extract (default: 5)
enable_owl_validation: Whether to enable OWL validation (default: True)
llm_temperature: LLM temperature parameter (default: 0.3)
llm_max_tokens: LLM max tokens parameter (default: 2000)
max_description_length: Maximum description length (default: 500)
timeout: Optional timeout in seconds for LLM call (default: None, no timeout)
Returns:
OntologyExtractionResponse containing validated ontology classes
Raises:
ValueError: If scenario is empty or invalid
asyncio.TimeoutError: If extraction times out
Examples:
>>> extractor = OntologyExtractor(llm_client)
>>> response = await extractor.extract_ontology_classes(
... scenario="A hospital manages patient records...",
... domain="Healthcare",
... max_classes=10,
... timeout=30.0
... )
>>> len(response.classes)
7
"""
# Start timing
start_time = time.time()
# Validate input
if not scenario or not scenario.strip():
logger.error("Scenario description is empty")
raise ValueError("Scenario description cannot be empty")
scenario = scenario.strip()
logger.info(
f"Starting ontology extraction - scenario_length={len(scenario)}, "
f"domain={domain}, max_classes={max_classes}, min_classes={min_classes}, "
f"timeout={timeout}"
)
try:
# Step 1: Call LLM for extraction with timeout
logger.info("Step 1: Calling LLM for ontology extraction")
llm_start_time = time.time()
if timeout is not None:
# Wrap LLM call with timeout
try:
response = await asyncio.wait_for(
self._call_llm_for_extraction(
scenario=scenario,
domain=domain,
max_classes=max_classes,
llm_temperature=llm_temperature,
llm_max_tokens=llm_max_tokens,
),
timeout=timeout
)
except asyncio.TimeoutError:
llm_duration = time.time() - llm_start_time
logger.error(
f"LLM extraction timed out after {timeout} seconds "
f"(actual duration: {llm_duration:.2f}s)"
)
# Return empty response on timeout
return OntologyExtractionResponse(
classes=[],
domain=domain or "Unknown",
)
else:
# No timeout specified, call directly
response = await self._call_llm_for_extraction(
scenario=scenario,
domain=domain,
max_classes=max_classes,
llm_temperature=llm_temperature,
llm_max_tokens=llm_max_tokens,
)
llm_duration = time.time() - llm_start_time
logger.info(
f"LLM returned {len(response.classes)} classes in {llm_duration:.2f}s"
)
# Step 2: First-layer validation (string validation and cleaning)
logger.info("Step 2: Performing first-layer validation (string validation)")
validation_start_time = time.time()
response = self._validate_and_clean(
response=response,
max_description_length=max_description_length,
)
validation_duration = time.time() - validation_start_time
logger.info(
f"After first-layer validation: {len(response.classes)} classes remain "
f"(validation took {validation_duration:.2f}s)"
)
# Check if we have enough classes after first-layer validation
if len(response.classes) < min_classes:
logger.warning(
f"Only {len(response.classes)} classes remain after validation, "
f"which is below minimum of {min_classes}"
)
# Step 3: Second-layer validation (OWL semantic validation)
if enable_owl_validation and response.classes:
logger.info("Step 3: Performing second-layer validation (OWL validation)")
owl_start_time = time.time()
is_valid, errors, world = self.owl_validator.validate_ontology_classes(
classes=response.classes,
)
owl_duration = time.time() - owl_start_time
if not is_valid:
logger.warning(
f"OWL validation found {len(errors)} issues in {owl_duration:.2f}s: {errors}"
)
# Filter invalid classes based on errors
response = self._filter_invalid_classes(
response=response,
errors=errors,
)
logger.info(
f"After second-layer validation: {len(response.classes)} classes remain"
)
else:
logger.info(f"OWL validation passed successfully in {owl_duration:.2f}s")
else:
if not enable_owl_validation:
logger.info("Step 3: OWL validation disabled, skipping")
else:
logger.info("Step 3: No classes to validate, skipping OWL validation")
# Calculate total duration
total_duration = time.time() - start_time
# Log extraction statistics
logger.info(
f"Ontology extraction completed - "
f"final_class_count={len(response.classes)}, "
f"domain={response.domain}, "
f"total_duration={total_duration:.2f}s, "
f"llm_duration={llm_duration:.2f}s"
)
return response
except asyncio.TimeoutError:
# Re-raise timeout errors
total_duration = time.time() - start_time
logger.error(
f"Ontology extraction timed out after {timeout} seconds "
f"(total duration: {total_duration:.2f}s)",
exc_info=True
)
raise
except Exception as e:
total_duration = time.time() - start_time
logger.error(
f"Ontology extraction failed after {total_duration:.2f}s: {str(e)}",
exc_info=True
)
# Return empty response on failure
return OntologyExtractionResponse(
classes=[],
domain=domain or "Unknown",
)
async def _call_llm_for_extraction(
self,
scenario: str,
domain: Optional[str],
max_classes: int,
llm_temperature: float,
llm_max_tokens: int,
) -> OntologyExtractionResponse:
"""Call LLM to extract ontology classes from scenario.
This method renders the extraction prompt using the Jinja2 template
and calls the LLM with structured output to get ontology classes.
Args:
scenario: Scenario description text
domain: Optional domain hint
max_classes: Maximum number of classes to extract
llm_temperature: LLM temperature parameter
llm_max_tokens: LLM max tokens parameter
Returns:
OntologyExtractionResponse from LLM
Raises:
Exception: If LLM call fails
"""
try:
# Render prompt using template
prompt_content = await render_ontology_extraction_prompt(
scenario=scenario,
domain=domain,
max_classes=max_classes,
json_schema=OntologyExtractionResponse.model_json_schema(),
)
logger.debug(f"Rendered prompt length: {len(prompt_content)}")
# Create messages for LLM
messages = [
{
"role": "system",
"content": (
"You are an expert ontology engineer specializing in knowledge "
"representation and OWL standards. Extract ontology classes from "
"scenario descriptions following the provided instructions. "
"Return valid JSON conforming to the schema."
),
},
{
"role": "user",
"content": prompt_content,
},
]
# Call LLM with structured output
logger.debug(
f"Calling LLM with temperature={llm_temperature}, "
f"max_tokens={llm_max_tokens}"
)
response = await self.llm_client.response_structured(
messages=messages,
response_model=OntologyExtractionResponse,
)
logger.info(
f"LLM extraction successful - extracted {len(response.classes)} classes"
)
return response
except Exception as e:
logger.error(
f"LLM extraction failed: {str(e)}",
exc_info=True
)
raise
def _validate_and_clean(
self,
response: OntologyExtractionResponse,
max_description_length: int,
) -> OntologyExtractionResponse:
"""Perform first-layer validation: string validation and cleaning.
This method validates and cleans the extracted ontology classes:
1. Validate class names (PascalCase, no reserved words)
2. Sanitize invalid class names
3. Truncate long descriptions
4. Remove duplicate classes
Args:
response: OntologyExtractionResponse from LLM
max_description_length: Maximum description length
Returns:
Cleaned OntologyExtractionResponse
"""
if not response.classes:
logger.debug("No classes to validate")
return response
logger.debug(f"Validating {len(response.classes)} classes")
validated_classes = []
for ontology_class in response.classes:
# Validate class name
is_valid, error_msg = self.validator.validate_class_name(
ontology_class.name
)
if not is_valid:
logger.warning(
f"Invalid class name '{ontology_class.name}': {error_msg}"
)
# Attempt to sanitize
sanitized_name = self.validator.sanitize_class_name(
ontology_class.name
)
logger.info(
f"Sanitized class name: '{ontology_class.name}' -> '{sanitized_name}'"
)
# Update class name
ontology_class.name = sanitized_name
# Re-validate sanitized name
is_valid, error_msg = self.validator.validate_class_name(
sanitized_name
)
if not is_valid:
logger.error(
f"Failed to sanitize class name '{ontology_class.name}': {error_msg}. "
"Skipping this class."
)
continue
# Truncate description if too long
if ontology_class.description:
original_length = len(ontology_class.description)
ontology_class.description = self.validator.truncate_description(
ontology_class.description,
max_length=max_description_length,
)
if len(ontology_class.description) < original_length:
logger.debug(
f"Truncated description for '{ontology_class.name}': "
f"{original_length} -> {len(ontology_class.description)} chars"
)
validated_classes.append(ontology_class)
# Remove duplicates (case-insensitive)
original_count = len(validated_classes)
validated_classes = self.validator.remove_duplicates(validated_classes)
if len(validated_classes) < original_count:
logger.info(
f"Removed {original_count - len(validated_classes)} duplicate classes"
)
# Return cleaned response
return OntologyExtractionResponse(
classes=validated_classes,
domain=response.domain,
)
def _filter_invalid_classes(
self,
response: OntologyExtractionResponse,
errors: List[str],
) -> OntologyExtractionResponse:
"""Filter invalid classes based on OWL validation errors.
This method analyzes OWL validation errors and removes classes
that caused validation failures (e.g., circular inheritance,
inconsistencies).
Args:
response: OntologyExtractionResponse to filter
errors: List of error messages from OWL validation
Returns:
Filtered OntologyExtractionResponse
"""
if not errors:
return response
logger.debug(f"Filtering classes based on {len(errors)} OWL validation errors")
# Extract class names mentioned in errors
invalid_class_names = set()
for error in errors:
# Look for class names in error messages
for ontology_class in response.classes:
if ontology_class.name in error:
invalid_class_names.add(ontology_class.name)
logger.debug(
f"Class '{ontology_class.name}' marked as invalid due to error: {error}"
)
# Filter out invalid classes
if invalid_class_names:
original_count = len(response.classes)
filtered_classes = [
c for c in response.classes
if c.name not in invalid_class_names
]
logger.info(
f"Filtered out {original_count - len(filtered_classes)} invalid classes: "
f"{invalid_class_names}"
)
return OntologyExtractionResponse(
classes=filtered_classes,
domain=response.domain,
)
return response

View File

@@ -25,6 +25,15 @@ class TripletExtractor:
"""
self.llm_client = llm_client
def _get_language(self) -> str:
"""Get the configured language for entity descriptions
Returns:
Language code ("zh" or "en")
"""
from app.core.config import settings
return settings.DEFAULT_LANGUAGE
async def _extract_triplets(self, statement: Statement, chunk_content: str) -> TripletExtractionResponse:
"""Process a single statement and return extracted triplets and entities"""
# Render the prompt using helper function
@@ -40,7 +49,8 @@ class TripletExtractor:
statement=statement.statement,
chunk_content=chunk_content,
json_schema=TripletExtractionResponse.model_json_schema(),
predicate_instructions=PREDICATE_DEFINITIONS
predicate_instructions=PREDICATE_DEFINITIONS,
language=self._get_language()
)
# Create messages for LLM

View File

@@ -177,7 +177,7 @@ def render_entity_dedup_prompt(
# Args:
# entity_a: Dict of entity A attributes
async def render_triplet_extraction_prompt(statement: str, chunk_content: str, json_schema: dict, predicate_instructions: dict = None) -> str:
async def render_triplet_extraction_prompt(statement: str, chunk_content: str, json_schema: dict, predicate_instructions: dict = None, language: str = "zh") -> str:
"""
Renders the triplet extraction prompt using the extract_triplet.jinja2 template.
@@ -186,6 +186,7 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
chunk_content: The content of the chunk to process
json_schema: JSON schema for the expected output format
predicate_instructions: Optional predicate instructions
language: The language to use for entity descriptions ("zh" for Chinese, "en" for English)
Returns:
Rendered prompt content as string
@@ -195,7 +196,8 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
statement=statement,
chunk_content=chunk_content,
json_schema=json_schema,
predicate_instructions=predicate_instructions
predicate_instructions=predicate_instructions,
language=language
)
# 记录渲染结果到提示日志(与示例日志结构一致)
log_prompt_rendering('triplet extraction', rendered_prompt)
@@ -204,7 +206,8 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
'statement': 'str',
'chunk_content': 'str',
'json_schema': 'TripletExtractionResponse.schema',
'predicate_instructions': 'PREDICATE_DEFINITIONS'
'predicate_instructions': 'PREDICATE_DEFINITIONS',
'language': language
})
return rendered_prompt
@@ -213,6 +216,7 @@ async def render_memory_summary_prompt(
chunk_texts: str,
json_schema: dict,
max_words: int = 200,
language: str = "zh",
) -> str:
"""
Renders the memory summary prompt using the memory_summary.jinja2 template.
@@ -221,6 +225,7 @@ async def render_memory_summary_prompt(
chunk_texts: Concatenated text of conversation chunks
json_schema: JSON schema for the expected output format
max_words: Maximum words for the summary
language: The language to use for summary generation ("zh" for Chinese, "en" for English)
Returns:
Rendered prompt content as string.
@@ -230,12 +235,14 @@ async def render_memory_summary_prompt(
chunk_texts=chunk_texts,
json_schema=json_schema,
max_words=max_words,
language=language,
)
log_prompt_rendering('memory summary', rendered_prompt)
log_template_rendering('memory_summary.jinja2', {
'chunk_texts_len': len(chunk_texts or ""),
'max_words': max_words,
'json_schema': 'MemorySummaryResponse.schema'
'json_schema': 'MemorySummaryResponse.schema',
'language': language
})
return rendered_prompt
@@ -388,24 +395,65 @@ async def render_memory_insight_prompt(
return rendered_prompt
async def render_episodic_title_and_type_prompt(content: str) -> str:
async def render_episodic_title_and_type_prompt(content: str, language: str = "zh") -> str:
"""
Renders the episodic title and type classification prompt using the episodic_type_classification.jinja2 template.
Args:
content: The content of the episodic memory summary to analyze
language: The language to use for title generation ("zh" for Chinese, "en" for English)
Returns:
Rendered prompt content as string
"""
template = prompt_env.get_template("episodic_type_classification.jinja2")
rendered_prompt = template.render(content=content)
rendered_prompt = template.render(content=content, language=language)
# 记录渲染结果到提示日志
log_prompt_rendering('episodic title and type classification', rendered_prompt)
# 可选:记录模板渲染信息
log_template_rendering('episodic_type_classification.jinja2', {
'content_len': len(content) if content else 0
'content_len': len(content) if content else 0,
'language': language
})
return rendered_prompt
async def render_ontology_extraction_prompt(
scenario: str,
domain: str | None = None,
max_classes: int = 15,
json_schema: dict | None = None
) -> str:
"""
Renders the ontology extraction prompt using the extract_ontology.jinja2 template.
Args:
scenario: The scenario description text to extract ontology classes from
domain: Optional domain hint for the scenario (e.g., "Healthcare", "Education")
max_classes: Maximum number of classes to extract (default: 15)
json_schema: JSON schema for the expected output format
Returns:
Rendered prompt content as string
"""
template = prompt_env.get_template("extract_ontology.jinja2")
rendered_prompt = template.render(
scenario=scenario,
domain=domain,
max_classes=max_classes,
json_schema=json_schema
)
# 记录渲染结果到提示日志
log_prompt_rendering('ontology extraction', rendered_prompt)
# 可选:记录模板渲染信息
log_template_rendering('extract_ontology.jinja2', {
'scenario_len': len(scenario) if scenario else 0,
'domain': domain,
'max_classes': max_classes,
'json_schema': 'OntologyExtractionResponse.schema'
})
return rendered_prompt

View File

@@ -1,8 +1,19 @@
=== Task ===
Generate a concise title and classify the episodic memory into the most appropriate category.
{% if language == "zh" %}
**重要:请使用中文生成标题和分类。**
{% else %}
**Important: Please generate the title and classification in English.**
{% endif %}
=== Requirements ===
- Extract a clear, concise title (10-20 characters) that captures the core content
{% if language == "zh" %}
- 标题必须使用中文
{% else %}
- Title must be in English
{% endif %}
- Classify into exactly one category based on the primary theme
- Be specific and avoid ambiguity
- Output must be valid JSON conforming to the schema below

View File

@@ -0,0 +1,210 @@
===Task===
Extract ontology classes from the given scenario description following ontology engineering standards.
===Role===
You are a professional ontology engineer with expertise in knowledge representation and OWL (Web Ontology Language) standards. Your task is to identify abstract classes and concepts from scenario descriptions, not concrete instances.
===Scenario Description===
{{ scenario }}
{% if domain -%}
===Domain Hint===
This scenario belongs to the **{{ domain }}** domain. Consider domain-specific concepts and terminology when extracting classes.
{%- endif %}
===Extraction Rules===
**1. Abstract Classes, Not Instances:**
- Extract abstract categories and concepts (e.g., "MedicalProcedure", "Patient", "Diagnosis")
- Do NOT extract concrete instances (e.g., "John Smith", "Room 301", "2024-01-15")
- Think in terms of "types of things" rather than "specific things"
**2. Naming Convention (PascalCase):**
- Use PascalCase format for the "name" field: start with uppercase letter, capitalize each word, no spaces
- Examples: "MedicalProcedure", "HealthcareProvider", "DiagnosticTest"
- Avoid: "medical procedure", "healthcare_provider", "diagnostic-test"
- Use clear, descriptive names in English
- Avoid abbreviations unless they are standard in the domain (e.g., "API", "DNA")
- Provide Chinese translation in the "name_chinese" field (e.g., "医疗程序", "医疗服务提供者", "诊断测试")
**3. Domain Relevance:**
- Focus on classes that are central to the scenario's domain
- Prioritize classes that represent key concepts, entities, or relationships
- Avoid overly generic classes (e.g., "Thing", "Object") unless they have specific domain meaning
**4. Class Quantity:**
- Extract between 5 and {{ max_classes }} classes
- Aim for a balanced set covering the main concepts in the scenario
- Quality over quantity: prefer well-defined classes over exhaustive lists
**5. Clear Descriptions:**
- Provide concise, informative descriptions in Chinese (max 500 characters)
- Describe what the class represents, not specific instances
- Use clear, natural Chinese language that explains the class's role in the domain
**6. Concrete Examples:**
- Provide 2-5 concrete instance examples in Chinese for each class
- Examples should be specific, realistic instances of the class
- Examples help clarify the class's scope and meaning
- Use natural Chinese language for examples
- Example format: ["示例1", "示例2", "示例3"]
**7. Class Hierarchy:**
- Identify parent-child relationships where applicable
- Use the parent_class field to specify inheritance
- Parent class must be one of the extracted classes or a standard OWL class
- Leave parent_class as null for top-level classes
**8. Entity Types:**
- Classify each class with an appropriate entity_type
- Common types: "Person", "Organization", "Location", "Event", "Concept", "Process", "Object", "Role"
- Choose the most specific type that applies
**9. OWL Reserved Words:**
- Do NOT use OWL reserved words as class names
- Reserved words include: "Thing", "Nothing", "Class", "Property", "ObjectProperty", "DatatypeProperty", "AnnotationProperty", "Ontology", "Individual", "Literal"
- If a reserved word is needed, add a domain-specific prefix (e.g., "MedicalClass" instead of "Class")
**10. Language Consistency:**
- Extract all class names in English (PascalCase format) for the "name" field
- Provide Chinese translation for class names in the "name_chinese" field
- Descriptions MUST be in Chinese (中文)
- Examples MUST be in Chinese (中文)
- Use clear, natural Chinese language for descriptions and examples
===Examples===
**Example 1 (Healthcare Domain):**
Scenario: "A hospital manages patient records, schedules appointments, and coordinates medical procedures. Doctors diagnose conditions and prescribe treatments."
Output:
{
"classes": [
{
"name": "Patient",
"name_chinese": "患者",
"description": "在医疗机构接受医疗护理或治疗的人",
"examples": ["张三", "李四", "患有糖尿病的老年患者"],
"parent_class": null,
"entity_type": "Person",
"domain": "Healthcare"
},
{
"name": "MedicalProcedure",
"name_chinese": "医疗程序",
"description": "为医疗诊断或治疗而执行的系统性操作流程",
"examples": ["手术", "血液检查", "X光检查", "疫苗接种"],
"parent_class": null,
"entity_type": "Process",
"domain": "Healthcare"
},
{
"name": "Diagnosis",
"name_chinese": "诊断",
"description": "基于症状和检查结果对疾病或状况的识别",
"examples": ["糖尿病诊断", "癌症诊断", "流感诊断"],
"parent_class": null,
"entity_type": "Concept",
"domain": "Healthcare"
},
{
"name": "Doctor",
"name_chinese": "医生",
"description": "诊断和治疗患者的持证医疗专业人员",
"examples": ["全科医生", "外科医生", "心脏病专家"],
"parent_class": null,
"entity_type": "Role",
"domain": "Healthcare"
},
{
"name": "Treatment",
"name_chinese": "治疗",
"description": "为治愈或管理疾病状况而提供的医疗护理或疗法",
"examples": ["药物治疗", "物理治疗", "化疗", "手术治疗"],
"parent_class": null,
"entity_type": "Process",
"domain": "Healthcare"
}
],
"domain": "Healthcare",
"namespace": "http://example.org/healthcare#"
}
**Example 2 (Education Domain):**
Scenario: "A university offers courses taught by professors. Students enroll in programs, attend lectures, and complete assignments to earn degrees."
Output:
{
"classes": [
{
"name": "Student",
"name_chinese": "学生",
"description": "在教育机构注册学习的人",
"examples": ["本科生", "研究生", "在职学生"],
"parent_class": null,
"entity_type": "Role",
"domain": "Education"
},
{
"name": "Course",
"name_chinese": "课程",
"description": "涵盖特定学科或主题的结构化教育课程",
"examples": ["计算机科学导论", "微积分I", "世界历史"],
"parent_class": null,
"entity_type": "Concept",
"domain": "Education"
},
{
"name": "Professor",
"name_chinese": "教授",
"description": "教授课程并进行研究的学术教师",
"examples": ["助理教授", "副教授", "正教授"],
"parent_class": null,
"entity_type": "Role",
"domain": "Education"
},
{
"name": "AcademicProgram",
"name_chinese": "学术项目",
"description": "通向学位或证书的结构化课程体系",
"examples": ["理学学士", "文学硕士", "博士项目"],
"parent_class": null,
"entity_type": "Concept",
"domain": "Education"
},
{
"name": "Assignment",
"name_chinese": "作业",
"description": "分配给学生以评估学习成果的任务或项目",
"examples": ["论文", "习题集", "研究报告", "实验报告"],
"parent_class": null,
"entity_type": "Object",
"domain": "Education"
},
{
"name": "Lecture",
"name_chinese": "讲座",
"description": "由教师进行的教育性演讲或讲座",
"examples": ["入门讲座", "客座讲座", "在线讲座"],
"parent_class": null,
"entity_type": "Event",
"domain": "Education"
}
],
"domain": "Education",
"namespace": "http://example.org/education#"
}
===Output Format===
**JSON Requirements:**
- Use only ASCII double quotes (") for JSON structure
- Never use Chinese quotation marks ("") or Unicode quotes
- Escape quotation marks in text with backslashes (\")
- Ensure proper string closure and comma separation
- No line breaks within JSON string values
- All class names must be in PascalCase format
- All class names must be unique (case-insensitive)
- Extract between 5 and {{ max_classes }} classes
{{ json_schema }}

View File

@@ -5,6 +5,12 @@
===Task===
Extract entities and knowledge triplets from the given statement.
{% if language == "zh" %}
**重要请使用中文生成实体描述description和示例example。**
{% else %}
**Important: Please generate entity descriptions and examples in English.**
{% endif %}
===Inputs===
**Chunk Content:** "{{ chunk_content }}"
**Statement:** "{{ statement }}"
@@ -13,6 +19,13 @@ Extract entities and knowledge triplets from the given statement.
**Entity Extraction:**
- Extract entities with their types, context-independent descriptions, **concise examples**, aliases, and semantic memory classification
{% if language == "zh" %}
- **实体描述description必须使用中文**
- **示例example必须使用中文**
{% else %}
- **Entity descriptions must be in English**
- **Examples must be in English**
{% endif %}
- **Semantic Memory Classification (is_explicit_memory):**
* Set to `true` if the entity represents **explicit/semantic memory**:
- **Concepts:** "Machine Learning", "Photosynthesis", "Democracy", "人工智能", "光合作用", "民主"
@@ -334,9 +347,11 @@ Output:
- Escape quotation marks in text with backslashes (\")
- Ensure proper string closure and comma separation
- No line breaks within JSON string values
- The output language should ALWAYS match the input language
- If input is in English, extract statements in English
- If input is in Chinese, extract statements in Chinese
{% if language == "zh" %}
- **语言要求实体描述description和示例example必须使用中文**
{% else %}
- **Language Requirement: Entity descriptions and examples must be in English**
{% endif %}
- Preserve the original language and do not translate
{{ json_schema }}

View File

@@ -5,10 +5,21 @@
=== Task ===
Summarize the provided conversation chunks into a concise Memory summary.
{% if language == "zh" %}
**重要:请使用中文生成摘要内容。**
{% else %}
**Important: Please generate the summary content in English.**
{% endif %}
=== Requirements ===
- Focus on factual statements, user preferences, relationships, and salient temporal context.
- Avoid repetition and filler; be specific.
- Keep it under {{ max_words or 200 }} words.
{% if language == "zh" %}
- 摘要内容必须使用中文
{% else %}
- Summary content must be in English
{% endif %}
- Output must be valid JSON conforming to the schema below.
=== Input ===
@@ -24,6 +35,11 @@ Summarize the provided conversation chunks into a concise Memory summary.
4. Do not include line breaks within JSON string values
5. Example of proper escaping: "statement": "张曼婷说:\"我很喜欢这本书。\""
The output language should always be the same as the input language.
{% if language == "zh" %}
**语言要求:输出内容必须使用中文。**
{% else %}
**Language Requirement: The output content must be in English.**
{% endif %}
Return only a list of extracted labelled statements in the JSON ARRAY of objects that match the schema below:
{{ json_schema }}

View File

@@ -0,0 +1,10 @@
"""Validation utilities for ontology extraction.
This module provides validation classes for ontology class names,
descriptions, and OWL compliance checking.
"""
from .ontology_validator import OntologyValidator
from .owl_validator import OWLValidator
__all__ = ['OntologyValidator', 'OWLValidator']

View File

@@ -0,0 +1,268 @@
"""String validation for ontology class names and descriptions.
This module provides the OntologyValidator class for validating and sanitizing
ontology class names according to OWL standards and naming conventions.
Classes:
OntologyValidator: Validates class names, removes duplicates, and truncates descriptions
"""
import logging
import re
from typing import List, Tuple
from app.core.memory.models.ontology_models import OntologyClass
logger = logging.getLogger(__name__)
class OntologyValidator:
"""Validator for ontology class names and descriptions.
This validator performs string-level validation including:
- PascalCase naming convention validation
- OWL reserved word checking
- Duplicate class name removal
- Description length truncation
Attributes:
OWL_RESERVED_WORDS: Set of OWL reserved words that cannot be used as class names
"""
# OWL reserved words that cannot be used as class names
OWL_RESERVED_WORDS = {
'Thing', 'Nothing', 'Class', 'Property',
'ObjectProperty', 'DatatypeProperty', 'FunctionalProperty',
'InverseFunctionalProperty', 'TransitiveProperty', 'SymmetricProperty',
'AsymmetricProperty', 'ReflexiveProperty', 'IrreflexiveProperty',
'Restriction', 'Ontology', 'Individual', 'NamedIndividual',
'Annotation', 'AnnotationProperty', 'Axiom',
'AllDifferent', 'AllDisjointClasses', 'AllDisjointProperties',
'Datatype', 'DataRange', 'Literal',
'DeprecatedClass', 'DeprecatedProperty',
'Imports', 'IncompatibleWith', 'PriorVersion', 'VersionInfo',
'BackwardCompatibleWith', 'OntologyProperty',
}
def validate_class_name(self, name: str) -> Tuple[bool, str]:
"""Validate that a class name follows OWL naming conventions.
Validation rules:
1. Must not be empty
2. Must start with an uppercase letter (PascalCase)
3. Cannot contain spaces
4. Can only contain alphanumeric characters and underscores
5. Cannot be an OWL reserved word
Args:
name: The class name to validate
Returns:
Tuple of (is_valid, error_message)
- is_valid: True if the name is valid, False otherwise
- error_message: Empty string if valid, error description if invalid
Examples:
>>> validator = OntologyValidator()
>>> validator.validate_class_name("MedicalProcedure")
(True, "")
>>> validator.validate_class_name("medical procedure")
(False, "Class name 'medical procedure' cannot contain spaces")
>>> validator.validate_class_name("Thing")
(False, "Class name 'Thing' is an OWL reserved word")
"""
logger.debug(f"Validating class name: '{name}'")
# Check if empty
if not name or not name.strip():
error_msg = "Class name cannot be empty"
logger.warning(f"Validation failed: {error_msg}")
return False, error_msg
name = name.strip()
# Check if it's an OWL reserved word
if name in self.OWL_RESERVED_WORDS:
error_msg = f"Class name '{name}' is an OWL reserved word"
logger.warning(f"Validation failed: {error_msg}")
return False, error_msg
# Check if starts with uppercase letter
if not name[0].isupper():
error_msg = f"Class name '{name}' must start with an uppercase letter (PascalCase)"
logger.warning(f"Validation failed: {error_msg}")
return False, error_msg
# Check for spaces
if ' ' in name:
error_msg = f"Class name '{name}' cannot contain spaces"
logger.warning(f"Validation failed: {error_msg}")
return False, error_msg
# Check for invalid characters (only alphanumeric and underscore allowed)
if not re.match(r'^[A-Za-z0-9_]+$', name):
error_msg = f"Class name '{name}' contains invalid characters. Only alphanumeric characters and underscores are allowed"
logger.warning(f"Validation failed: {error_msg}")
return False, error_msg
logger.debug(f"Class name '{name}' is valid")
return True, ""
def sanitize_class_name(self, name: str) -> str:
"""Attempt to sanitize an invalid class name into a valid format.
Sanitization steps:
1. Strip whitespace
2. Remove invalid characters
3. Replace spaces with empty string (PascalCase)
4. Capitalize first letter of each word
5. If result is empty or starts with number, prefix with 'Class'
Args:
name: The class name to sanitize
Returns:
Sanitized class name that should pass validation
Examples:
>>> validator = OntologyValidator()
>>> validator.sanitize_class_name("medical procedure")
'MedicalProcedure'
>>> validator.sanitize_class_name("patient-record")
'PatientRecord'
>>> validator.sanitize_class_name("123invalid")
'Class123Invalid'
"""
logger.debug(f"Sanitizing class name: '{name}'")
if not name or not name.strip():
logger.warning("Empty class name provided for sanitization, returning 'UnnamedClass'")
return "UnnamedClass"
# Strip whitespace
name = name.strip()
original_name = name
# Split on spaces, hyphens, and underscores, then capitalize each word
words = re.split(r'[\s\-_]+', name)
# Capitalize first letter of each word and keep rest as is
sanitized_words = []
for word in words:
if word:
# Remove non-alphanumeric characters except underscore
clean_word = re.sub(r'[^A-Za-z0-9_]', '', word)
if clean_word:
# Capitalize first letter
sanitized_words.append(clean_word[0].upper() + clean_word[1:])
# Join words
sanitized = ''.join(sanitized_words)
# If empty or starts with number, prefix with 'Class'
if not sanitized or sanitized[0].isdigit():
sanitized = 'Class' + sanitized
logger.info(f"Prefixed class name with 'Class': '{original_name}' -> '{sanitized}'")
# If it's a reserved word, append 'Class' suffix
if sanitized in self.OWL_RESERVED_WORDS:
sanitized = sanitized + 'Class'
logger.info(f"Appended 'Class' suffix to reserved word: '{original_name}' -> '{sanitized}'")
logger.info(f"Sanitized class name: '{original_name}' -> '{sanitized}'")
return sanitized
def remove_duplicates(self, classes: List[OntologyClass]) -> List[OntologyClass]:
"""Remove duplicate ontology classes based on case-insensitive name comparison.
When duplicates are found, keeps the first occurrence and discards subsequent ones.
Comparison is case-insensitive to catch variations like 'Patient' and 'patient'.
Args:
classes: List of OntologyClass objects
Returns:
List of OntologyClass objects with duplicates removed
Examples:
>>> validator = OntologyValidator()
>>> classes = [
... OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare"),
... OntologyClass(name="patient", description="Another patient", entity_type="Person", domain="Healthcare"),
... OntologyClass(name="Doctor", description="A doctor", entity_type="Person", domain="Healthcare"),
... ]
>>> unique = validator.remove_duplicates(classes)
>>> len(unique)
2
>>> [c.name for c in unique]
['Patient', 'Doctor']
"""
if not classes:
logger.debug("No classes to check for duplicates")
return classes
logger.debug(f"Checking {len(classes)} classes for duplicates")
seen_names = set()
unique_classes = []
duplicates_found = []
for ontology_class in classes:
# Use lowercase for comparison
name_lower = ontology_class.name.lower()
if name_lower not in seen_names:
seen_names.add(name_lower)
unique_classes.append(ontology_class)
else:
duplicates_found.append(ontology_class.name)
logger.debug(f"Duplicate class found and removed: '{ontology_class.name}'")
if duplicates_found:
logger.info(
f"Removed {len(duplicates_found)} duplicate classes: {duplicates_found}"
)
else:
logger.debug("No duplicate classes found")
return unique_classes
def truncate_description(self, description: str, max_length: int = 500) -> str:
"""Truncate a description to a maximum length.
If the description exceeds max_length, it will be truncated and
an ellipsis (...) will be appended to indicate truncation.
Args:
description: The description text to truncate
max_length: Maximum allowed length (default: 500)
Returns:
Truncated description string
Examples:
>>> validator = OntologyValidator()
>>> long_desc = "A" * 600
>>> truncated = validator.truncate_description(long_desc, max_length=500)
>>> len(truncated)
500
>>> truncated.endswith("...")
True
"""
if not description:
return ""
if len(description) <= max_length:
return description
# Truncate and add ellipsis
# Reserve 3 characters for "..."
truncate_at = max_length - 3
truncated = description[:truncate_at] + "..."
logger.debug(
f"Truncated description from {len(description)} to {len(truncated)} characters"
)
return truncated

View File

@@ -0,0 +1,585 @@
"""OWL semantic validation for ontology classes using Owlready2.
This module provides the OWLValidator class for validating ontology classes
against OWL standards using the Owlready2 library. It performs semantic
validation including consistency checking, circular inheritance detection,
and OWL file export.
Classes:
OWLValidator: Validates ontology classes using OWL reasoning and exports to OWL formats
"""
import logging
from typing import List, Optional, Tuple
from owlready2 import (
World,
Thing,
get_ontology,
sync_reasoner_pellet,
OwlReadyInconsistentOntologyError,
)
from app.core.memory.models.ontology_models import OntologyClass
logger = logging.getLogger(__name__)
class OWLValidator:
"""Validator for OWL semantic validation of ontology classes.
This validator performs semantic-level validation using Owlready2 including:
- Creating OWL classes from ontology class definitions
- Running consistency checking with Pellet reasoner
- Detecting circular inheritance
- Validating Protégé compatibility
- Exporting ontologies to various OWL formats (RDF/XML, Turtle, N-Triples)
Attributes:
base_namespace: Base URI for the ontology namespace
"""
def __init__(self, base_namespace: str = "http://example.org/ontology#"):
"""Initialize the OWL validator.
Args:
base_namespace: Base URI for the ontology namespace (default: http://example.org/ontology#)
"""
self.base_namespace = base_namespace
def validate_ontology_classes(
self,
classes: List[OntologyClass],
) -> Tuple[bool, List[str], Optional[World]]:
"""Validate extracted ontology classes against OWL standards.
This method creates an OWL ontology from the provided classes using Owlready2,
runs consistency checking with the Pellet reasoner, and detects common issues
like circular inheritance.
Args:
classes: List of OntologyClass objects to validate
Returns:
Tuple of (is_valid, error_messages, world):
- is_valid: True if ontology is valid and consistent, False otherwise
- error_messages: List of error/warning messages
- world: Owlready2 World object containing the ontology (None if validation failed)
Examples:
>>> validator = OWLValidator()
>>> classes = [
... OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare"),
... OntologyClass(name="Doctor", description="A doctor", entity_type="Person", domain="Healthcare"),
... ]
>>> is_valid, errors, world = validator.validate_ontology_classes(classes)
>>> is_valid
True
>>> len(errors)
0
"""
if not classes:
return False, ["No classes provided for validation"], None
errors = []
try:
# Create a new world (isolated ontology environment)
world = World()
# Use a proper ontology IRI
# Owlready2 expects the IRI to end with .owl or similar
onto_iri = self.base_namespace.rstrip('#/')
if not onto_iri.endswith('.owl'):
onto_iri = onto_iri + '.owl'
# Create ontology
onto = world.get_ontology(onto_iri)
with onto:
# Dictionary to store created OWL classes for parent reference
owl_classes = {}
# First pass: Create all classes without parent relationships
for ontology_class in classes:
try:
# Create OWL class dynamically using type() with Thing as base
# The key is to NOT set namespace in the dict, let Owlready2 handle it
owl_class = type(
ontology_class.name, # Class name
(Thing,), # Base classes
{} # Class dict (empty, let Owlready2 manage)
)
# Add label (rdfs:label) - include both English and Chinese names
labels = [ontology_class.name]
if ontology_class.name_chinese:
labels.append(ontology_class.name_chinese)
owl_class.label = labels
# Add comment (rdfs:comment) with description
if ontology_class.description:
owl_class.comment = [ontology_class.description]
# Store for parent relationship setup
owl_classes[ontology_class.name] = owl_class
logger.debug(
f"Created OWL class: {ontology_class.name} "
f"(Chinese: {ontology_class.name_chinese}) "
f"IRI: {owl_class.iri if hasattr(owl_class, 'iri') else 'N/A'}"
)
except Exception as e:
error_msg = f"Failed to create OWL class '{ontology_class.name}': {str(e)}"
errors.append(error_msg)
logger.error(error_msg, exc_info=True)
# Second pass: Set up parent relationships
for ontology_class in classes:
if ontology_class.parent_class and ontology_class.name in owl_classes:
parent_name = ontology_class.parent_class
# Check if parent exists
if parent_name in owl_classes:
try:
child_class = owl_classes[ontology_class.name]
parent_class = owl_classes[parent_name]
# Set parent by modifying is_a
child_class.is_a = [parent_class]
logger.debug(
f"Set parent relationship: {ontology_class.name} -> {parent_name}"
)
except Exception as e:
error_msg = (
f"Failed to set parent relationship "
f"'{ontology_class.name}' -> '{parent_name}': {str(e)}"
)
errors.append(error_msg)
logger.warning(error_msg)
else:
warning_msg = (
f"Parent class '{parent_name}' not found for '{ontology_class.name}'"
)
errors.append(warning_msg)
logger.warning(warning_msg)
# Check for circular inheritance
for class_name, owl_class in owl_classes.items():
if self._has_circular_inheritance(owl_class):
error_msg = f"Circular inheritance detected for class '{class_name}'"
errors.append(error_msg)
logger.error(error_msg)
# Run consistency checking with Pellet reasoner
try:
logger.info("Running Pellet reasoner for consistency checking...")
sync_reasoner_pellet(world, infer_property_values=True, infer_data_property_values=True)
logger.info("Consistency check passed")
except OwlReadyInconsistentOntologyError as e:
error_msg = f"Ontology is inconsistent: {str(e)}"
errors.append(error_msg)
logger.error(error_msg)
return False, errors, world
except Exception as e:
# Reasoner errors are often due to Java not being installed or configured
# Log as warning but don't fail validation - ontology structure is still valid
warning_msg = f"Reasoner check skipped: {str(e)}"
if str(e).strip(): # Only log if there's an actual error message
logger.warning(warning_msg)
else:
logger.warning("Reasoner check skipped: Java may not be installed or configured")
# Continue - ontology structure is valid even without reasoner check
# If we have errors (excluding warnings), validation failed
is_valid = len(errors) == 0
return is_valid, errors, world
except Exception as e:
error_msg = f"OWL validation failed: {str(e)}"
errors.append(error_msg)
logger.error(error_msg, exc_info=True)
return False, errors, None
def _has_circular_inheritance(self, owl_class) -> bool:
"""Check if an OWL class has circular inheritance.
Circular inheritance occurs when a class inherits from itself through
a chain of parent relationships (e.g., A -> B -> C -> A).
Args:
owl_class: Owlready2 class object to check
Returns:
True if circular inheritance is detected, False otherwise
"""
visited = set()
current = owl_class
while current:
# Get class IRI or name as identifier
class_id = str(current.iri) if hasattr(current, 'iri') else str(current)
if class_id in visited:
# Found a cycle
return True
visited.add(class_id)
# Get parent classes (is_a relationship)
parents = getattr(current, 'is_a', [])
# Filter out Thing and other base classes
parent_classes = [p for p in parents if p != Thing and hasattr(p, 'is_a')]
if not parent_classes:
# No more parents, no cycle
break
# Check first parent (in single inheritance)
current = parent_classes[0] if parent_classes else None
return False
def export_to_owl(
self,
world: World,
output_path: Optional[str] = None,
format: str = "rdfxml",
classes: Optional[List] = None
) -> str:
"""Export ontology to OWL file in specified format.
Supported formats:
- rdfxml: RDF/XML format (default, most compatible)
- turtle: Turtle format (more readable)
- ntriples: N-Triples format (simplest)
- json: JSON format (simplified, human-readable)
Args:
world: Owlready2 World object containing the ontology
output_path: Optional file path to save the ontology (if None, returns string)
format: Export format - "rdfxml", "turtle", "ntriples", or "json" (default: "rdfxml")
classes: Optional list of OntologyClass objects (required for json format)
Returns:
String representation of the exported ontology
Raises:
ValueError: If format is not supported
RuntimeError: If export fails
Examples:
>>> validator = OWLValidator()
>>> is_valid, errors, world = validator.validate_ontology_classes(classes)
>>> owl_content = validator.export_to_owl(world, "ontology.owl", format="rdfxml")
"""
# Validate format
valid_formats = ["rdfxml", "turtle", "ntriples", "json"]
if format not in valid_formats:
raise ValueError(
f"Unsupported format '{format}'. Must be one of: {', '.join(valid_formats)}"
)
# JSON format doesn't need OWL processing
if format == "json":
if not classes:
raise ValueError("Classes list is required for JSON format export")
return self._export_to_json(classes)
# For OWL formats, world is required
if not world:
raise ValueError("World object is None. Cannot export ontology.")
# Note: Owlready2 has issues with turtle format export
# We'll handle it specially by converting from rdfxml
use_conversion = (format == "turtle")
try:
# Get all ontologies in the world
ontologies = list(world.ontologies.values())
if not ontologies:
raise RuntimeError("No ontologies found in world")
# Find the ontology with classes (skip anonymous/empty ontologies)
onto = None
for ont in ontologies:
classes_count = len(list(ont.classes()))
logger.debug(f"Checking ontology {ont.base_iri}: {classes_count} classes")
if classes_count > 0:
onto = ont
break
# If no ontology with classes found, use the last non-anonymous one
if onto is None:
for ont in reversed(ontologies):
if ont.base_iri != "http://anonymous/":
onto = ont
break
# If still no ontology, use the first one
if onto is None:
onto = ontologies[0]
# Log ontology contents for debugging
logger.info(f"Ontology IRI: {onto.base_iri}")
logger.info(f"Ontology contains {len(list(onto.classes()))} classes")
# List all classes in the ontology
all_classes = list(onto.classes())
for cls in all_classes:
logger.info(f"Class in ontology: {cls.name} (IRI: {cls.iri})")
if hasattr(cls, 'label'):
logger.debug(f" Labels: {cls.label}")
if hasattr(cls, 'comment'):
logger.debug(f" Comments: {cls.comment}")
if len(all_classes) == 0:
logger.warning("No classes found in ontology! This may indicate a problem with class creation.")
if output_path:
# Save to file
export_format = "rdfxml" if use_conversion else format
logger.info(f"Exporting ontology to {output_path} in {export_format} format")
onto.save(file=output_path, format=export_format)
# Read back the file content to return
with open(output_path, 'r', encoding='utf-8') as f:
content = f.read()
# Convert to turtle if needed
if use_conversion:
content = self._convert_to_turtle(content)
logger.info(f"Successfully exported ontology to {output_path}")
# Format the content for better readability
content = self._format_owl_content(content, format)
return content
else:
# Export to string (save to temporary location and read)
import tempfile
import os
with tempfile.NamedTemporaryFile(mode='w', suffix='.owl', delete=False) as tmp:
tmp_path = tmp.name
try:
export_format = "rdfxml" if use_conversion else format
onto.save(file=tmp_path, format=export_format)
with open(tmp_path, 'r', encoding='utf-8') as f:
content = f.read()
# Convert to turtle if needed
if use_conversion:
content = self._convert_to_turtle(content)
# Format the content for better readability
content = self._format_owl_content(content, format)
return content
finally:
# Clean up temporary file
if os.path.exists(tmp_path):
os.remove(tmp_path)
except Exception as e:
error_msg = f"Failed to export ontology: {str(e)}"
logger.error(error_msg, exc_info=True)
raise RuntimeError(error_msg) from e
def _export_to_json(self, classes: List) -> str:
"""Export ontology classes to simplified JSON format.
This format is more compact and easier to parse than OWL XML.
Args:
classes: List of OntologyClass objects
Returns:
JSON string representation (compact format)
"""
import json
result = {
"ontology": {
"namespace": self.base_namespace,
"classes": []
}
}
for cls in classes:
class_data = {
"name": cls.name,
"name_chinese": cls.name_chinese,
"description": cls.description,
"entity_type": cls.entity_type,
"domain": cls.domain,
"parent_class": cls.parent_class,
"examples": cls.examples if hasattr(cls, 'examples') else []
}
result["ontology"]["classes"].append(class_data)
# 使用紧凑格式:无缩进,使用分隔符减少空格
return json.dumps(result, ensure_ascii=False, separators=(',', ':'))
def _convert_to_turtle(self, rdfxml_content: str) -> str:
"""Convert RDF/XML content to Turtle format using rdflib.
Args:
rdfxml_content: RDF/XML format content
Returns:
Turtle format content
"""
try:
from rdflib import Graph
# Parse RDF/XML
g = Graph()
g.parse(data=rdfxml_content, format="xml")
# Serialize to Turtle
turtle_content = g.serialize(format="turtle")
# Handle bytes vs string
if isinstance(turtle_content, bytes):
turtle_content = turtle_content.decode('utf-8')
return turtle_content
except ImportError:
logger.warning(
"rdflib is not installed. Cannot convert to Turtle format. "
"Install with: pip install rdflib"
)
return rdfxml_content
except Exception as e:
logger.error(f"Failed to convert to Turtle format: {e}")
return rdfxml_content
def _format_owl_content(self, content: str, format: str) -> str:
"""Format OWL content for better readability.
Args:
content: Raw OWL content string
format: Format type (rdfxml, turtle, ntriples)
Returns:
Formatted OWL content string
"""
if format == "rdfxml":
# Format XML with proper indentation
try:
import xml.dom.minidom as minidom
dom = minidom.parseString(content)
# Pretty print with 2-space indentation
formatted = dom.toprettyxml(indent=" ", encoding="utf-8").decode("utf-8")
# Remove extra blank lines
lines = []
prev_blank = False
for line in formatted.split('\n'):
is_blank = not line.strip()
if not (is_blank and prev_blank): # Skip consecutive blank lines
lines.append(line)
prev_blank = is_blank
formatted = '\n'.join(lines)
return formatted
except Exception as e:
logger.warning(f"Failed to format XML content: {e}")
return content
elif format == "turtle":
# Turtle format is already relatively readable
# Just ensure consistent line endings and not empty
if not content or content.strip() == "":
logger.warning("Turtle content is empty, this may indicate an export issue")
return content.strip() + '\n' if content.strip() else content
elif format == "ntriples":
# N-Triples format is line-based, ensure proper line endings
return content.strip() + '\n' if content.strip() else content
return content
def validate_with_protege_compatibility(
self,
classes: List[OntologyClass]
) -> Tuple[bool, List[str]]:
"""Validate that ontology classes are compatible with Protégé editor.
Protégé compatibility checks:
- Class names are valid OWL identifiers
- No special characters that Protégé cannot handle
- Namespace is properly formatted
- Labels and comments are properly encoded
Args:
classes: List of OntologyClass objects to validate
Returns:
Tuple of (is_compatible, warnings):
- is_compatible: True if compatible with Protégé, False otherwise
- warnings: List of compatibility warning messages
Examples:
>>> validator = OWLValidator()
>>> classes = [OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare")]
>>> is_compatible, warnings = validator.validate_with_protege_compatibility(classes)
>>> is_compatible
True
"""
warnings = []
# Check namespace format
if not self.base_namespace.startswith(('http://', 'https://')):
warnings.append(
f"Namespace '{self.base_namespace}' should start with http:// or https:// "
"for Protégé compatibility"
)
if not self.base_namespace.endswith(('#', '/')):
warnings.append(
f"Namespace '{self.base_namespace}' should end with # or / "
"for Protégé compatibility"
)
# Check each class
for ontology_class in classes:
# Check for special characters that might cause issues
if any(char in ontology_class.name for char in ['<', '>', '"', '{', '}', '|', '^', '`']):
warnings.append(
f"Class name '{ontology_class.name}' contains special characters "
"that may cause issues in Protégé"
)
# Check description length (Protégé can handle long descriptions but may display poorly)
if ontology_class.description and len(ontology_class.description) > 1000:
warnings.append(
f"Class '{ontology_class.name}' has a very long description ({len(ontology_class.description)} chars) "
"which may display poorly in Protégé"
)
# Check for non-ASCII characters (Protégé supports them but encoding issues may occur)
if not ontology_class.name.isascii():
warnings.append(
f"Class name '{ontology_class.name}' contains non-ASCII characters "
"which may cause encoding issues in some Protégé versions"
)
# If no warnings, it's compatible
is_compatible = len(warnings) == 0
return is_compatible, warnings

View File

@@ -14,7 +14,7 @@ from app.core.workflow.nodes.code.config import CodeNodeConfig
logger = logging.getLogger(__name__)
SCRIPT_TEMPLATE = Template(dedent("""
PYTHON_SCRIPT_TEMPLATE = Template(dedent("""
$code
import json
@@ -32,6 +32,20 @@ result = "<<RESULT>>" + output_json + "<<RESULT>>"
print(result)
"""))
NODEJS_SCRIPT_TEMPLATE = Template(dedent("""
$code
// decode and prepare input object
var inputs_obj = JSON.parse(Buffer.from('$inputs_variable', 'base64').toString('utf-8'))
// execute main function
var output_obj = main(inputs_obj)
// convert output to json and print
var output_json = JSON.stringify(output_obj)
var result = `<<RESULT>>$${output_json}<<RESULT>>`
console.log(result)
"""))
class CodeNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
@@ -83,6 +97,7 @@ class CodeNode(BaseNode):
input_variable_dict = {}
for input_variable in self.typed_config.input_variables:
input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, state)
code = base64.b64decode(
self.typed_config.code
).decode("utf-8")
@@ -90,11 +105,18 @@ class CodeNode(BaseNode):
input_variable_dict = base64.b64encode(
json.dumps(input_variable_dict).encode("utf-8")
).decode("utf-8")
final_script = SCRIPT_TEMPLATE.substitute(
code=code,
inputs_variable=input_variable_dict,
)
if self.typed_config.language == "python3":
final_script = PYTHON_SCRIPT_TEMPLATE.substitute(
code=code,
inputs_variable=input_variable_dict,
)
elif self.typed_config.language == 'nodejs':
final_script = NODEJS_SCRIPT_TEMPLATE.substitute(
code=code,
inputs_variable=input_variable_dict,
)
else:
raise ValueError(f"Unsupported language: {self.typed_config.language}")
async with httpx.AsyncClient() as client:
response = await client.post(

View File

@@ -28,6 +28,10 @@ from .tool_model import (
ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus
)
from .memory_perceptual_model import MemoryPerceptualModel
from .ontology_scene import OntologyScene
from .ontology_class import OntologyClass
from .ontology_scene import OntologyScene
from .ontology_class import OntologyClass
__all__ = [
"Tenants",

View File

@@ -22,6 +22,9 @@ class MemoryConfig(Base):
end_user_id = Column(String, nullable=True, comment="组ID")
user_id = Column(String, nullable=True, comment="用户ID")
apply_id = Column(String, nullable=True, comment="应用ID")
# 本体场景关联
scene_id = Column(UUID(as_uuid=True), nullable=True, comment="本体场景ID关联ontology_scene表")
# 模型选择从workspace继承
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")

View File

@@ -0,0 +1,40 @@
# -*- coding: utf-8 -*-
"""本体类型模型
本模块定义本体类型的数据模型。
Classes:
OntologyClass: 本体类型表模型
"""
import datetime
import uuid
from sqlalchemy import Column, String, DateTime, Text, ForeignKey
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from app.db import Base
class OntologyClass(Base):
"""本体类型表 - 用于存储某个场景提取出来的本体类型信息"""
__tablename__ = "ontology_class"
# 主键
class_id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="类型ID")
# 类型信息
class_name = Column(String(200), nullable=False, comment="类型名称")
class_description = Column(Text, nullable=True, comment="类型描述")
# 外键:关联到本体场景
scene_id = Column(UUID(as_uuid=True), ForeignKey("ontology_scene.scene_id", ondelete="CASCADE"), nullable=False, index=True, comment="所属场景ID")
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, comment="创建时间")
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, nullable=False, comment="更新时间")
# 关系:类型属于某个场景
scene = relationship("OntologyScene", back_populates="classes")
def __repr__(self):
return f"<OntologyClass(id={self.class_id}, name={self.class_name}, scene_id={self.scene_id})>"

View File

@@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
"""本体场景模型
本模块定义本体场景的数据模型。
Classes:
OntologyScene: 本体场景表模型
"""
import datetime
import uuid
from sqlalchemy import Column, String, DateTime, Integer, Text, ForeignKey, UniqueConstraint
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from app.db import Base
class OntologyScene(Base):
"""本体场景表 - 用于存储本体场景下不同的类型信息"""
__tablename__ = "ontology_scene"
__table_args__ = (
UniqueConstraint('workspace_id', 'scene_name', name='uq_workspace_scene_name'),
)
# 主键
scene_id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="场景ID")
# 场景信息
scene_name = Column(String(200), nullable=False, comment="场景名称")
scene_description = Column(Text, nullable=True, comment="场景描述")
# 外键:关联到工作空间
workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="CASCADE"), nullable=False, index=True, comment="所属工作空间ID")
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, comment="创建时间")
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, nullable=False, comment="更新时间")
# 关系:一个场景可以有多个类型
classes = relationship("OntologyClass", back_populates="scene", cascade="all, delete-orphan")
def __repr__(self):
return f"<OntologyScene(id={self.scene_id}, name={self.scene_name})>"

View File

@@ -2,7 +2,7 @@ import datetime
import uuid
from enum import StrEnum
from sqlalchemy import Column, ForeignKey, Text, DateTime, String, Index
from sqlalchemy import Column, ForeignKey, Text, DateTime, String, Index, Boolean
from sqlalchemy.dialects.postgresql import UUID
from app.db import Base
@@ -121,10 +121,33 @@ class PromptOptimizerSessionHistory(Base):
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, comment="Tenant ID")
# app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False, comment="Application ID")
session_id = Column(UUID(as_uuid=True), ForeignKey("prompt_opt_session_list.id"),nullable=False, comment="Session ID")
session_id = Column(
UUID(as_uuid=True),
ForeignKey("prompt_opt_session_list.id"),
nullable=False,
comment="Session ID"
)
user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, comment="User ID")
role = Column(String, nullable=False, comment="Message Role")
content = Column(Text, nullable=False, comment="Message Content")
# prompt = Column(Text, nullable=False, comment="Prompt")
created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation Time", index=True)
class PromptHistory(Base):
__tablename__ = "prompt_history"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, comment="Tenant ID")
session_id = Column(
UUID(as_uuid=True),
ForeignKey("prompt_opt_session_list.id"),
nullable=False,
comment="Session ID"
)
title = Column(String, nullable=False, comment="Title")
prompt = Column(Text, nullable=False, comment="Prompt")
created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation Time", index=True)
is_delete = Column(Boolean, default=False, comment="Delete")

View File

@@ -32,6 +32,8 @@ db_logger = get_db_logger()
config_logger = get_config_logger()
TABLE_NAME = "memory_config"
class MemoryConfigRepository:
"""记忆配置Repository
@@ -154,7 +156,7 @@ class MemoryConfigRepository:
return memory_config_obj
@staticmethod
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID) -> MemoryConfig:
def query_reflection_config_by_id(db: Session, config_id: uuid.UUID|int|str) -> MemoryConfig:
"""构建反思配置查询语句通过config_id查询反思配置SQLAlchemy text() 命名参数)
Args:
@@ -170,6 +172,7 @@ class MemoryConfigRepository:
if not memory_config:
raise RuntimeError("reflection config not found")
return memory_config
@staticmethod
def query_reflection_config_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> MemoryConfig:
"""构建查询所有配置的语句SQLAlchemy text() 命名参数)
@@ -189,7 +192,6 @@ class MemoryConfigRepository:
raise RuntimeError("reflection config not found")
return memory_config
@staticmethod
def build_select_all(workspace_id: uuid.UUID) -> Tuple[str, Dict]:
"""构建查询所有配置的语句SQLAlchemy text() 命名参数)
@@ -229,6 +231,7 @@ class MemoryConfigRepository:
config_name=params.config_name,
config_desc=params.config_desc,
workspace_id=params.workspace_id,
scene_id=params.scene_id,
llm_id=params.llm_id,
embedding_id=params.embedding_id,
rerank_id=params.rerank_id,
@@ -289,7 +292,6 @@ class MemoryConfigRepository:
db_logger.error(f"更新记忆配置失败: config_id={update.config_id} - {str(e)}")
raise
@staticmethod
def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[MemoryConfig]:
"""更新记忆萃取引擎配置
@@ -412,7 +414,7 @@ class MemoryConfigRepository:
raise
@staticmethod
def get_extracted_config(db: Session, config_id: UUID |int) -> Optional[Dict]:
def get_extracted_config(db: Session, config_id: UUID | int) -> Optional[Dict]:
"""获取萃取配置,通过主键查询某条配置
Args:
@@ -422,7 +424,7 @@ class MemoryConfigRepository:
Returns:
Optional[Dict]: 萃取配置字典不存在则返回None
"""
config_id=resolve_config_id(config_id,db)
config_id = resolve_config_id(config_id, db)
db_logger.debug(f"查询萃取配置: config_id={config_id}")
try:
db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
@@ -516,26 +518,28 @@ class MemoryConfigRepository:
except Exception as e:
db_logger.error(f"根据ID查询记忆配置失败: config_id={config_id} - {str(e)}")
raise
@staticmethod
def get_config_with_workspace(db: Session, config_id: uuid.UUID) -> Optional[tuple]:
def get_config_with_workspace(db: Session, config_id: uuid.UUID | int | str) -> Optional[tuple]:
"""Get memory config and its associated workspace information
Args:
db: Database session
config_id: Configuration ID
Returns:
Optional[tuple]: (MemoryConfig, Workspace) tuple, None if not found
Raises:
ValueError: Raised when config exists but workspace doesn't
"""
import time
from app.models.workspace_model import Workspace
start_time = time.time()
config_id = resolve_config_id(config_id, db)
# Log configuration loading start
config_logger.info(
"Loading configuration with workspace",
@@ -544,17 +548,17 @@ class MemoryConfigRepository:
"config_id": config_id
}
)
db_logger.debug(f"Querying memory config and workspace: config_id={config_id}")
try:
# Use join query to get both config and workspace
result = db.query(MemoryConfig, Workspace).join(
Workspace, MemoryConfig.workspace_id == Workspace.id
).filter(MemoryConfig.config_id == config_id).first()
elapsed_ms = (time.time() - start_time) * 1000
if not result:
# Check if config exists but workspace is missing
config_only = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first()
@@ -583,9 +587,11 @@ class MemoryConfigRepository:
"elapsed_ms": elapsed_ms
}
)
db_logger.error(f"Memory config {config_id} references non-existent workspace {config_only.workspace_id}")
raise ValueError(f"Workspace {config_only.workspace_id} not found for configuration {config_id}")
db_logger.error(
f"Memory config {config_id} references non-existent workspace {config_only.workspace_id}")
raise ValueError(
f"Workspace {config_only.workspace_id} not found for configuration {config_id}")
config_logger.debug(
"Configuration not found",
extra={
@@ -597,9 +603,9 @@ class MemoryConfigRepository:
)
db_logger.debug(f"Memory config not found: config_id={config_id}")
return None
config, workspace = result
# Log successful configuration loading
config_logger.info(
"Configuration with workspace loaded successfully",
@@ -614,16 +620,17 @@ class MemoryConfigRepository:
"elapsed_ms": elapsed_ms
}
)
db_logger.debug(f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
db_logger.debug(
f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}")
return (config, workspace)
except ValueError:
# Re-raise known business exceptions
raise
except Exception as e:
elapsed_ms = (time.time() - start_time) * 1000
config_logger.error(
"Failed to load configuration with workspace",
extra={
@@ -636,9 +643,10 @@ class MemoryConfigRepository:
},
exc_info=True
)
db_logger.error(f"Failed to query memory config and workspace: config_id={config_id} - {str(e)}")
raise
@staticmethod
def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[MemoryConfig]:
"""获取所有配置参数

View File

@@ -630,6 +630,13 @@ class ModelBaseRepository:
db.add(model_base)
return model_base
@staticmethod
def get_by_name_and_provider(db: Session, name: str, provider: str) -> Optional['ModelBase']:
return db.query(ModelBase).filter(
ModelBase.name == name,
ModelBase.provider == provider
).first()
@staticmethod
def update(db: Session, model_base_id: uuid.UUID, data: dict) -> Optional['ModelBase']:
model_base = db.query(ModelBase).filter(ModelBase.id == model_base_id).first()

View File

@@ -877,7 +877,8 @@ RETURN
CASE
WHEN ms:ExtractedEntity THEN {
text: ms.name,
created_at: ms.created_at
created_at: ms.created_at,
type: "情景记忆"
}
END
) AS ExtractedEntity,
@@ -887,7 +888,8 @@ RETURN
CASE
WHEN n:MemorySummary THEN {
text: n.content,
created_at: n.created_at
created_at: n.created_at,
type: "长期沉淀"
}
END
) AS MemorySummary,
@@ -895,7 +897,8 @@ RETURN
collect(
DISTINCT {
text: e.statement,
created_at: e.created_at
created_at: e.created_at,
type: "情绪记忆"
}
) AS statement;
"""

View File

@@ -0,0 +1,404 @@
# -*- coding: utf-8 -*-
"""本体类型Repository层
本模块提供本体类型的数据访问层实现。
Classes:
OntologyClassRepository: 本体类型数据访问类
"""
import logging
from typing import List, Optional
from uuid import UUID
from sqlalchemy.orm import Session, joinedload
from app.core.logging_config import get_db_logger
from app.models.ontology_class import OntologyClass
from app.models.ontology_scene import OntologyScene
logger = get_db_logger()
class OntologyClassRepository:
"""本体类型Repository
提供本体类型的CRUD操作和权限检查。
Attributes:
db: SQLAlchemy数据库会话
"""
def __init__(self, db: Session):
"""初始化Repository
Args:
db: SQLAlchemy数据库会话
"""
self.db = db
def create(self, class_data: dict, scene_id: UUID) -> OntologyClass:
"""创建本体类型
Args:
class_data: 类型数据字典包含class_name和class_description
scene_id: 所属场景ID
Returns:
OntologyClass: 创建的类型对象
Raises:
Exception: 数据库操作失败
Examples:
>>> repo = OntologyClassRepository(db)
>>> ontology_class = repo.create(
... {"class_name": "患者", "class_description": "描述"},
... scene_id
... )
"""
try:
logger.info(
f"Creating ontology class - "
f"name={class_data.get('class_name')}, "
f"scene_id={scene_id}"
)
ontology_class = OntologyClass(
class_name=class_data.get("class_name"),
class_description=class_data.get("class_description"),
scene_id=scene_id
)
self.db.add(ontology_class)
self.db.flush() # 获取ID但不提交
logger.info(
f"Ontology class created successfully - "
f"class_id={ontology_class.class_id}"
)
return ontology_class
except Exception as e:
logger.error(
f"Failed to create ontology class: {str(e)}",
exc_info=True
)
raise
def get_by_id(self, class_id: UUID) -> Optional[OntologyClass]:
"""根据ID获取类型
Args:
class_id: 类型ID
Returns:
Optional[OntologyClass]: 类型对象不存在则返回None
Examples:
>>> repo = OntologyClassRepository(db)
>>> ontology_class = repo.get_by_id(class_id)
"""
try:
logger.debug(f"Getting ontology class by ID: {class_id}")
ontology_class = self.db.query(OntologyClass).filter(
OntologyClass.class_id == class_id
).first()
if ontology_class:
logger.debug(f"Ontology class found: {class_id}")
else:
logger.debug(f"Ontology class not found: {class_id}")
return ontology_class
except Exception as e:
logger.error(
f"Failed to get ontology class by ID: {str(e)}",
exc_info=True
)
raise
def get_by_name(self, class_name: str, scene_id: UUID) -> Optional[OntologyClass]:
"""根据类型名称和场景ID获取类型精确匹配
Args:
class_name: 类型名称
scene_id: 场景ID
Returns:
Optional[OntologyClass]: 类型对象不存在则返回None
Examples:
>>> repo = OntologyClassRepository(db)
>>> ontology_class = repo.get_by_name("患者", scene_id)
"""
try:
logger.debug(f"Getting ontology class by name: {class_name}, scene_id: {scene_id}")
ontology_class = self.db.query(OntologyClass).filter(
OntologyClass.class_name == class_name,
OntologyClass.scene_id == scene_id
).first()
if ontology_class:
logger.debug(f"Ontology class found: {class_name}")
else:
logger.debug(f"Ontology class not found: {class_name}")
return ontology_class
except Exception as e:
logger.error(
f"Failed to get ontology class by name: {str(e)}",
exc_info=True
)
raise
def search_by_name(self, keyword: str, scene_id: UUID) -> List[OntologyClass]:
"""根据关键词模糊搜索类型
使用 LIKE 进行模糊匹配,支持中文和英文。
Args:
keyword: 搜索关键词
scene_id: 场景ID
Returns:
List[OntologyClass]: 匹配的类型列表
Examples:
>>> repo = OntologyClassRepository(db)
>>> classes = repo.search_by_name("患者", scene_id)
"""
try:
logger.debug(
f"Searching ontology classes by keyword - "
f"keyword={keyword}, scene_id={scene_id}"
)
# 使用 ilike 进行不区分大小写的模糊匹配
classes = self.db.query(OntologyClass).filter(
OntologyClass.class_name.ilike(f"%{keyword}%"),
OntologyClass.scene_id == scene_id
).order_by(
OntologyClass.created_at.desc()
).all()
logger.info(
f"Found {len(classes)} ontology classes matching keyword '{keyword}' "
f"in scene {scene_id}"
)
return classes
except Exception as e:
logger.error(
f"Failed to search ontology classes by keyword: {str(e)}",
exc_info=True
)
raise
def get_by_scene(self, scene_id: UUID) -> List[OntologyClass]:
"""获取场景下的所有类型
按创建时间倒序排列。
Args:
scene_id: 场景ID
Returns:
List[OntologyClass]: 类型列表
Examples:
>>> repo = OntologyClassRepository(db)
>>> classes = repo.get_by_scene(scene_id)
"""
try:
logger.debug(f"Getting ontology classes by scene: {scene_id}")
classes = self.db.query(OntologyClass).filter(
OntologyClass.scene_id == scene_id
).order_by(
OntologyClass.created_at.desc()
).all()
logger.info(
f"Found {len(classes)} ontology classes in scene {scene_id}"
)
return classes
except Exception as e:
logger.error(
f"Failed to get ontology classes by scene: {str(e)}",
exc_info=True
)
raise
def update(self, class_id: UUID, update_data: dict) -> Optional[OntologyClass]:
"""更新类型信息
Args:
class_id: 类型ID
update_data: 更新数据字典
Returns:
Optional[OntologyClass]: 更新后的类型对象不存在则返回None
Raises:
Exception: 数据库操作失败
Examples:
>>> repo = OntologyClassRepository(db)
>>> ontology_class = repo.update(
... class_id,
... {"class_name": "新名称"}
... )
"""
try:
logger.info(f"Updating ontology class: {class_id}")
ontology_class = self.get_by_id(class_id)
if not ontology_class:
logger.warning(f"Ontology class not found for update: {class_id}")
return None
# 更新字段
if "class_name" in update_data and update_data["class_name"] is not None:
ontology_class.class_name = update_data["class_name"]
if "class_description" in update_data:
ontology_class.class_description = update_data["class_description"]
self.db.flush()
logger.info(f"Ontology class updated successfully: {class_id}")
return ontology_class
except Exception as e:
logger.error(
f"Failed to update ontology class: {str(e)}",
exc_info=True
)
raise
def delete(self, class_id: UUID) -> bool:
"""删除类型
Args:
class_id: 类型ID
Returns:
bool: 删除成功返回True类型不存在返回False
Raises:
Exception: 数据库操作失败
Examples:
>>> repo = OntologyClassRepository(db)
>>> success = repo.delete(class_id)
"""
try:
logger.info(f"Deleting ontology class: {class_id}")
ontology_class = self.get_by_id(class_id)
if not ontology_class:
logger.warning(f"Ontology class not found for delete: {class_id}")
return False
self.db.delete(ontology_class)
self.db.flush()
logger.info(f"Ontology class deleted successfully: {class_id}")
return True
except Exception as e:
logger.error(
f"Failed to delete ontology class: {str(e)}",
exc_info=True
)
raise
def check_ownership(self, class_id: UUID, workspace_id: UUID) -> bool:
"""检查类型是否属于指定工作空间(通过场景关联)
Args:
class_id: 类型ID
workspace_id: 工作空间ID
Returns:
bool: 属于返回True否则返回False
Examples:
>>> repo = OntologyClassRepository(db)
>>> is_owner = repo.check_ownership(class_id, workspace_id)
"""
try:
logger.debug(
f"Checking class ownership - "
f"class_id={class_id}, workspace_id={workspace_id}"
)
count = self.db.query(OntologyClass).join(
OntologyScene,
OntologyClass.scene_id == OntologyScene.scene_id
).filter(
OntologyClass.class_id == class_id,
OntologyScene.workspace_id == workspace_id
).count()
is_owner = count > 0
logger.debug(
f"Class ownership check result: {is_owner} - "
f"class_id={class_id}"
)
return is_owner
except Exception as e:
logger.error(
f"Failed to check class ownership: {str(e)}",
exc_info=True
)
raise
def get_scene_id_by_class(self, class_id: UUID) -> Optional[UUID]:
"""根据类型ID获取所属场景ID
Args:
class_id: 类型ID
Returns:
Optional[UUID]: 场景ID类型不存在则返回None
Examples:
>>> repo = OntologyClassRepository(db)
>>> scene_id = repo.get_scene_id_by_class(class_id)
"""
try:
logger.debug(f"Getting scene ID by class: {class_id}")
ontology_class = self.get_by_id(class_id)
if not ontology_class:
logger.debug(f"Class not found: {class_id}")
return None
logger.debug(
f"Found scene ID: {ontology_class.scene_id} for class: {class_id}"
)
return ontology_class.scene_id
except Exception as e:
logger.error(
f"Failed to get scene ID by class: {str(e)}",
exc_info=True
)
raise

View File

@@ -0,0 +1,394 @@
# -*- coding: utf-8 -*-
"""本体场景Repository层
本模块提供本体场景的数据访问层实现。
Classes:
OntologySceneRepository: 本体场景数据访问类
"""
import logging
from typing import List, Optional
from uuid import UUID
from sqlalchemy.orm import Session, joinedload
from app.core.logging_config import get_db_logger
from app.models.ontology_scene import OntologyScene
logger = get_db_logger()
class OntologySceneRepository:
"""本体场景Repository
提供本体场景的CRUD操作和权限检查。
Attributes:
db: SQLAlchemy数据库会话
"""
def __init__(self, db: Session):
"""初始化Repository
Args:
db: SQLAlchemy数据库会话
"""
self.db = db
def create(self, scene_data: dict, workspace_id: UUID) -> OntologyScene:
"""创建本体场景
Args:
scene_data: 场景数据字典包含scene_name和scene_description
workspace_id: 所属工作空间ID
Returns:
OntologyScene: 创建的场景对象
Raises:
Exception: 数据库操作失败
Examples:
>>> repo = OntologySceneRepository(db)
>>> scene = repo.create(
... {"scene_name": "医疗场景", "scene_description": "描述"},
... workspace_id
... )
"""
try:
logger.info(
f"Creating ontology scene - "
f"name={scene_data.get('scene_name')}, "
f"workspace_id={workspace_id}"
)
scene = OntologyScene(
scene_name=scene_data.get("scene_name"),
scene_description=scene_data.get("scene_description"),
workspace_id=workspace_id
)
self.db.add(scene)
self.db.flush() # 获取ID但不提交
logger.info(
f"Ontology scene created successfully - "
f"scene_id={scene.scene_id}"
)
return scene
except Exception as e:
logger.error(
f"Failed to create ontology scene: {str(e)}",
exc_info=True
)
raise
def get_by_id(self, scene_id: UUID) -> Optional[OntologyScene]:
"""根据ID获取场景
Args:
scene_id: 场景ID
Returns:
Optional[OntologyScene]: 场景对象不存在则返回None
Examples:
>>> repo = OntologySceneRepository(db)
>>> scene = repo.get_by_id(scene_id)
"""
try:
logger.debug(f"Getting ontology scene by ID: {scene_id}")
scene = self.db.query(OntologyScene).filter(
OntologyScene.scene_id == scene_id
).first()
if scene:
logger.debug(f"Ontology scene found: {scene_id}")
else:
logger.debug(f"Ontology scene not found: {scene_id}")
return scene
except Exception as e:
logger.error(
f"Failed to get ontology scene by ID: {str(e)}",
exc_info=True
)
raise
def get_by_name(self, scene_name: str, workspace_id: UUID) -> Optional[OntologyScene]:
"""根据场景名称和工作空间ID获取场景精确匹配
Args:
scene_name: 场景名称
workspace_id: 工作空间ID
Returns:
Optional[OntologyScene]: 场景对象不存在则返回None
Examples:
>>> repo = OntologySceneRepository(db)
>>> scene = repo.get_by_name("医疗场景", workspace_id)
"""
try:
logger.debug(
f"Getting ontology scene by name - "
f"scene_name={scene_name}, workspace_id={workspace_id}"
)
scene = self.db.query(OntologyScene).options(
joinedload(OntologyScene.classes)
).filter(
OntologyScene.scene_name == scene_name,
OntologyScene.workspace_id == workspace_id
).first()
if scene:
logger.debug(f"Ontology scene found: {scene_name}")
else:
logger.debug(f"Ontology scene not found: {scene_name}")
return scene
except Exception as e:
logger.error(
f"Failed to get ontology scene by name: {str(e)}",
exc_info=True
)
raise
def search_by_name(self, keyword: str, workspace_id: UUID) -> List[OntologyScene]:
"""根据关键词模糊搜索场景
使用 LIKE 进行模糊匹配,支持中文和英文。
Args:
keyword: 搜索关键词
workspace_id: 工作空间ID
Returns:
List[OntologyScene]: 匹配的场景列表
Examples:
>>> repo = OntologySceneRepository(db)
>>> scenes = repo.search_by_name("医疗", workspace_id)
"""
try:
logger.debug(
f"Searching ontology scenes by keyword - "
f"keyword={keyword}, workspace_id={workspace_id}"
)
# 使用 ilike 进行不区分大小写的模糊匹配
scenes = self.db.query(OntologyScene).options(
joinedload(OntologyScene.classes)
).filter(
OntologyScene.scene_name.ilike(f"%{keyword}%"),
OntologyScene.workspace_id == workspace_id
).order_by(
OntologyScene.updated_at.desc()
).all()
logger.info(
f"Found {len(scenes)} ontology scenes matching keyword '{keyword}' "
f"in workspace {workspace_id}"
)
return scenes
except Exception as e:
logger.error(
f"Failed to search ontology scenes by keyword: {str(e)}",
exc_info=True
)
raise
def get_by_workspace(self, workspace_id: UUID, page: Optional[int] = None, page_size: Optional[int] = None) -> tuple:
"""获取工作空间下的所有场景(支持分页)
使用joinedload预加载classes关系以统计数量。
Args:
workspace_id: 工作空间ID
page: 页码可选从1开始
page_size: 每页数量(可选)
Returns:
tuple: (场景列表, 总数量)
Examples:
>>> repo = OntologySceneRepository(db)
>>> scenes, total = repo.get_by_workspace(workspace_id)
>>> scenes, total = repo.get_by_workspace(workspace_id, page=1, page_size=10)
"""
try:
logger.debug(f"Getting ontology scenes by workspace: {workspace_id}, page={page}, page_size={page_size}")
# 构建基础查询
query = self.db.query(OntologyScene).options(
joinedload(OntologyScene.classes)
).filter(
OntologyScene.workspace_id == workspace_id
).order_by(
OntologyScene.updated_at.desc()
)
# 获取总数
total = query.count()
# 如果提供了分页参数,应用分页
if page is not None and page_size is not None:
offset = (page - 1) * page_size
query = query.offset(offset).limit(page_size)
logger.debug(f"Applying pagination: offset={offset}, limit={page_size}")
scenes = query.all()
logger.info(
f"Found {len(scenes)} ontology scenes (total: {total}) in workspace {workspace_id}"
)
return scenes, total
except Exception as e:
logger.error(
f"Failed to get ontology scenes by workspace: {str(e)}",
exc_info=True
)
raise
def update(self, scene_id: UUID, update_data: dict) -> Optional[OntologyScene]:
"""更新场景信息
Args:
scene_id: 场景ID
update_data: 更新数据字典
Returns:
Optional[OntologyScene]: 更新后的场景对象不存在则返回None
Raises:
Exception: 数据库操作失败
Examples:
>>> repo = OntologySceneRepository(db)
>>> scene = repo.update(
... scene_id,
... {"scene_name": "新名称"}
... )
"""
try:
logger.info(f"Updating ontology scene: {scene_id}")
scene = self.get_by_id(scene_id)
if not scene:
logger.warning(f"Ontology scene not found for update: {scene_id}")
return None
# 更新字段
if "scene_name" in update_data and update_data["scene_name"] is not None:
scene.scene_name = update_data["scene_name"]
if "scene_description" in update_data:
scene.scene_description = update_data["scene_description"]
self.db.flush()
logger.info(f"Ontology scene updated successfully: {scene_id}")
return scene
except Exception as e:
logger.error(
f"Failed to update ontology scene: {str(e)}",
exc_info=True
)
raise
def delete(self, scene_id: UUID) -> bool:
"""删除场景(级联删除类型)
依赖数据库级联删除配置ondelete="CASCADE")。
Args:
scene_id: 场景ID
Returns:
bool: 删除成功返回True场景不存在返回False
Raises:
Exception: 数据库操作失败
Examples:
>>> repo = OntologySceneRepository(db)
>>> success = repo.delete(scene_id)
"""
try:
logger.info(f"Deleting ontology scene: {scene_id}")
scene = self.get_by_id(scene_id)
if not scene:
logger.warning(f"Ontology scene not found for delete: {scene_id}")
return False
self.db.delete(scene)
self.db.flush()
logger.info(
f"Ontology scene deleted successfully (cascade): {scene_id}"
)
return True
except Exception as e:
logger.error(
f"Failed to delete ontology scene: {str(e)}",
exc_info=True
)
raise
def check_ownership(self, scene_id: UUID, workspace_id: UUID) -> bool:
"""检查场景是否属于指定工作空间
Args:
scene_id: 场景ID
workspace_id: 工作空间ID
Returns:
bool: 属于返回True否则返回False
Examples:
>>> repo = OntologySceneRepository(db)
>>> is_owner = repo.check_ownership(scene_id, workspace_id)
"""
try:
logger.debug(
f"Checking scene ownership - "
f"scene_id={scene_id}, workspace_id={workspace_id}"
)
count = self.db.query(OntologyScene).filter(
OntologyScene.scene_id == scene_id,
OntologyScene.workspace_id == workspace_id
).count()
is_owner = count > 0
logger.debug(
f"Scene ownership check result: {is_owner} - "
f"scene_id={scene_id}"
)
return is_owner
except Exception as e:
logger.error(
f"Failed to check scene ownership: {str(e)}",
exc_info=True
)
raise

View File

@@ -4,7 +4,10 @@ from sqlalchemy.orm import Session
from app.core.logging_config import get_db_logger
from app.models.prompt_optimizer_model import (
PromptOptimizerSession, PromptOptimizerSessionHistory, RoleType
PromptOptimizerSession,
PromptOptimizerSessionHistory,
RoleType,
PromptHistory
)
db_logger = get_db_logger()
@@ -16,6 +19,12 @@ class PromptOptimizerSessionRepository:
def __init__(self, db: Session):
self.db = db
def get_session_by_id(self, session_id: uuid.UUID) -> PromptOptimizerSession | None:
session = self.db.query(PromptOptimizerSession).filter(
PromptOptimizerSession.id == session_id,
).first()
return session
def create_session(
self,
tenant_id: uuid.UUID,
@@ -38,12 +47,9 @@ class PromptOptimizerSessionRepository:
user_id=user_id,
)
self.db.add(session)
self.db.commit()
self.db.refresh(session)
db_logger.debug(f"Prompt optimization session created: ID:{session.id}")
return session
except Exception as e:
db_logger.error(f"Error creating prompt optimization session: user_id={user_id} - {str(e)}")
db_logger.error(f"Error creating prompt optimization session: - {str(e)}")
raise
def get_session_history(
@@ -71,10 +77,10 @@ class PromptOptimizerSessionRepository:
PromptOptimizerSession.id == session_id,
PromptOptimizerSession.user_id == user_id
).first()
if not session:
return []
history = self.db.query(PromptOptimizerSessionHistory).filter(
PromptOptimizerSessionHistory.session_id == session.id,
PromptOptimizerSessionHistory.user_id == user_id
@@ -104,11 +110,11 @@ class PromptOptimizerSessionRepository:
PromptOptimizerSession.user_id == user_id,
PromptOptimizerSession.tenant_id == tenant_id
).first()
if not session:
db_logger.error(f"Session {session_id} not found for user {user_id}")
raise ValueError(f"Session {session_id} not found for user {user_id}")
message = PromptOptimizerSessionHistory(
tenant_id=tenant_id,
session_id=session.id,
@@ -117,8 +123,199 @@ class PromptOptimizerSessionRepository:
content=content,
)
self.db.add(message)
self.db.commit()
return message
except Exception as e:
db_logger.error(f"Error creating prompt optimization session history: session_id={session_id} - {str(e)}")
raise
def get_first_user_message(self, session_id: uuid.UUID) -> str | None:
"""
Get the first user message from a session.
Args:
session_id (uuid.UUID): The session ID.
Returns:
str | None: The content of the first user message, or None if not found.
"""
try:
message = self.db.query(PromptOptimizerSessionHistory).filter(
PromptOptimizerSessionHistory.session_id == session_id,
PromptOptimizerSessionHistory.role == RoleType.USER.value
).order_by(
PromptOptimizerSessionHistory.created_at.asc()
).first()
return message.content if message else None
except Exception as e:
db_logger.error(f"Error getting first user message: session_id={session_id} - {str(e)}")
raise
class PromptReleaseRepository:
def __init__(self, db: Session):
self.db = db
def get_prompt_by_session_id(self, session_id: uuid.UUID) -> PromptHistory | None:
prompt_obj = self.db.query(PromptHistory).filter(
PromptHistory.session_id == session_id,
PromptHistory.is_delete.is_(False)
).first()
return prompt_obj
def create_prompt_release(
self,
tenant_id: uuid.UUID,
title: str,
session_id: uuid.UUID,
prompt: str,
) -> PromptHistory:
try:
prompt_obj = PromptHistory(
tenant_id=tenant_id,
title=title,
session_id=session_id,
prompt=prompt,
)
self.db.add(prompt_obj)
return prompt_obj
except Exception as e:
db_logger.error(f"Error creating prompt release: session_id={session_id} - {str(e)}")
raise
def soft_delete_prompt(self, prompt_obj: PromptHistory) -> None:
"""
Soft delete a prompt release by setting is_delete flag to True.
Args:
prompt_obj (PromptHistory): The prompt release object to delete.
"""
try:
prompt_obj.is_delete = True
db_logger.debug(f"Soft deleted prompt release: id={prompt_obj.id}, session_id={prompt_obj.session_id}")
except Exception as e:
db_logger.error(f"Error soft deleting prompt release: id={prompt_obj.id} - {str(e)}")
raise
def get_prompt_by_id(self, prompt_id: uuid.UUID) -> PromptHistory | None:
"""
Get a prompt release by its ID.
Args:
prompt_id (uuid.UUID): The prompt release ID.
Returns:
PromptHistory | None: The prompt release object or None if not found.
"""
try:
prompt_obj = self.db.query(PromptHistory).filter(
PromptHistory.id == prompt_id
).first()
return prompt_obj
except Exception as e:
db_logger.error(f"Error getting prompt release by id: id={prompt_id} - {str(e)}")
raise
def count_prompts(self, tenant_id: uuid.UUID) -> int:
"""
Count total number of non-deleted prompts for a tenant.
Args:
tenant_id (uuid.UUID): The tenant ID.
Returns:
int: Total count of prompts.
"""
try:
count = self.db.query(PromptHistory).filter(
PromptHistory.tenant_id == tenant_id,
PromptHistory.is_delete.is_(False)
).count()
return count
except Exception as e:
db_logger.error(f"Error counting prompts: tenant_id={tenant_id} - {str(e)}")
raise
def get_prompts_paginated(
self,
tenant_id: uuid.UUID,
offset: int,
limit: int
) -> list[PromptHistory]:
"""
Get paginated list of prompt releases for a tenant.
Args:
tenant_id (uuid.UUID): The tenant ID.
offset (int): Number of records to skip.
limit (int): Maximum number of records to return.
Returns:
list[PromptHistory]: List of prompt releases.
"""
try:
prompts = self.db.query(PromptHistory).filter(
PromptHistory.tenant_id == tenant_id,
PromptHistory.is_delete.is_(False)
).order_by(
PromptHistory.created_at.desc()
).offset(offset).limit(limit).all()
return prompts
except Exception as e:
db_logger.error(f"Error getting paginated prompts: tenant_id={tenant_id} - {str(e)}")
raise
def count_prompts_by_keyword(self, tenant_id: uuid.UUID, keyword: str) -> int:
"""
Count total number of non-deleted prompts matching keyword for a tenant.
Args:
tenant_id (uuid.UUID): The tenant ID.
keyword (str): Search keyword for title.
Returns:
int: Total count of matching prompts.
"""
try:
count = self.db.query(PromptHistory).filter(
PromptHistory.tenant_id == tenant_id,
PromptHistory.is_delete.is_(False),
PromptHistory.title.ilike(f"%{keyword}%")
).count()
return count
except Exception as e:
db_logger.error(f"Error counting prompts by keyword: tenant_id={tenant_id}, keyword={keyword} - {str(e)}")
raise
def search_prompts_paginated(
self,
tenant_id: uuid.UUID,
keyword: str,
offset: int,
limit: int
) -> list[PromptHistory]:
"""
Search prompt releases by keyword in title with pagination.
Args:
tenant_id (uuid.UUID): The tenant ID.
keyword (str): Search keyword for title.
offset (int): Number of records to skip.
limit (int): Maximum number of records to return.
Returns:
list[PromptHistory]: List of matching prompt releases.
"""
try:
prompts = self.db.query(PromptHistory).filter(
PromptHistory.tenant_id == tenant_id,
PromptHistory.is_delete.is_(False),
PromptHistory.title.ilike(f"%{keyword}%")
).order_by(
PromptHistory.created_at.desc()
).offset(offset).limit(limit).all()
return prompts
except Exception as e:
db_logger.error(f"Error searching prompts: tenant_id={tenant_id}, keyword={keyword} - {str(e)}")
raise

View File

@@ -12,8 +12,8 @@ class KnowledgeBaseConfig(BaseModel):
kb_id: str = Field(..., description="知识库ID")
top_k: int = Field(default=3, ge=1, le=20, description="检索返回的文档数量")
similarity_threshold: float = Field(default=0.7, ge=0.0, le=1.0, description="相似度阈值")
strategy: str = Field(default="hybrid", description="检索策略: hybrid | bm25 | dense")
weight: float = Field(default=1.0, ge=0.0, le=1.0, description="知识库权重(用于多知识库融合)")
# strategy: str = Field(default="hybrid", description="检索策略: hybrid | bm25 | dense")
# weight: float = Field(default=1.0, ge=0.0, le=1.0, description="知识库权重(用于多知识库融合)")
vector_similarity_weight: float = Field(default=0.5, ge=0.0, le=1.0, description="向量相似度权重")
retrieve_type: str = Field(default="hybrid", description="检索方式participle semantichybrid")

View File

@@ -229,6 +229,9 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body
config_desc: str = Field("配置描述", description="配置描述(字符串)")
workspace_id: Optional[uuid.UUID] = Field(None, description="工作空间IDUUID")
# 本体场景关联(可选)
scene_id: Optional[uuid.UUID] = Field(None, description="本体场景IDUUID关联ontology_scene表")
# 模型配置字段(可选,用于手动指定或自动填充)
llm_id: Optional[str] = Field(None, description="LLM模型配置ID")
embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID")

View File

@@ -0,0 +1,461 @@
"""本体提取API的请求和响应模型
本模块定义了本体提取系统的所有API请求和响应的Pydantic模型。
Classes:
ExtractionRequest: 本体提取请求模型
ExtractionResponse: 本体提取响应模型
ExportRequest: OWL文件导出请求模型
ExportResponse: OWL文件导出响应模型
OntologyResultResponse: 本体提取结果响应模型(带毫秒时间戳)
SceneCreateRequest: 场景创建请求模型
SceneUpdateRequest: 场景更新请求模型
SceneResponse: 场景响应模型
SceneListResponse: 场景列表响应模型
ClassCreateRequest: 类型创建请求模型
ClassUpdateRequest: 类型更新请求模型
ClassResponse: 类型响应模型
ClassListResponse: 类型列表响应模型
"""
from typing import List, Optional
import datetime
from uuid import UUID
from pydantic import BaseModel, Field, field_serializer, ConfigDict
from app.core.memory.models.ontology_models import OntologyClass
class ExtractionRequest(BaseModel):
"""本体提取请求模型
用于POST /api/ontology/extract端点的请求体。
Attributes:
scenario: 场景描述文本,不能为空
domain: 可选的领域提示(如Healthcare, Education等)
llm_id: LLM模型ID,必须提供
scene_id: 场景ID,必须提供,用于将提取的类保存到指定场景
Examples:
>>> request = ExtractionRequest(
... scenario="医院管理患者记录...",
... domain="Healthcare",
... llm_id="550e8400-e29b-41d4-a716-446655440000",
... scene_id="660e8400-e29b-41d4-a716-446655440000"
... )
"""
scenario: str = Field(..., description="场景描述文本", min_length=1)
domain: Optional[str] = Field(None, description="可选的领域提示")
llm_id: str = Field(..., description="LLM模型ID")
scene_id: UUID = Field(..., description="场景ID,用于将提取的类保存到指定场景")
class ExtractionResponse(BaseModel):
"""本体提取响应模型
用于POST /api/ontology/extract端点的响应体。
Attributes:
classes: 提取的本体类列表
domain: 识别的领域
extracted_count: 提取的类数量
Examples:
>>> response = ExtractionResponse(
... classes=[...],
... domain="Healthcare",
... extracted_count=7
... )
"""
classes: List[OntologyClass] = Field(default_factory=list, description="提取的本体类列表")
domain: str = Field(..., description="识别的领域")
extracted_count: int = Field(..., description="提取的类数量")
class ExportRequest(BaseModel):
"""OWL文件导出请求模型
用于POST /api/ontology/export端点的请求体。
Attributes:
classes: 要导出的本体类列表
format: 导出格式,可选值: rdfxml, turtle, ntriples, json
include_metadata: 是否包含完整的OWL元数据(命名空间等),默认True
Examples:
>>> request = ExportRequest(
... classes=[...],
... format="rdfxml",
... include_metadata=True
... )
"""
classes: List[OntologyClass] = Field(..., description="要导出的本体类列表", min_length=1)
format: str = Field("rdfxml", description="导出格式: rdfxml, turtle, ntriples, json")
include_metadata: bool = Field(True, description="是否包含完整的OWL元数据")
class ExportResponse(BaseModel):
"""OWL文件导出响应模型
用于POST /api/ontology/export端点的响应体。
Attributes:
owl_content: OWL文件内容
format: 导出格式
classes_count: 导出的类数量
Examples:
>>> response = ExportResponse(
... owl_content="<?xml version='1.0'?>...",
... format="rdfxml",
... classes_count=7
... )
"""
owl_content: str = Field(..., description="OWL文件内容")
format: str = Field(..., description="导出格式")
classes_count: int = Field(..., description="导出的类数量")
class OntologyResultResponse(BaseModel):
"""本体提取结果响应模型
用于返回数据库中存储的提取结果,时间戳为毫秒级。
Attributes:
id: 结果ID (UUID)
scenario: 场景描述文本
domain: 领域
classes_json: 提取的本体类数据(JSON格式)
extracted_count: 提取的类数量
user_id: 用户ID
created_at: 创建时间(毫秒时间戳)
Examples:
>>> response = OntologyResultResponse(
... id=uuid.uuid4(),
... scenario="医院管理患者记录...",
... domain="Healthcare",
... classes_json={"classes": [...]},
... extracted_count=7,
... user_id=123,
... created_at=datetime.now()
... )
"""
id: UUID = Field(..., description="结果ID")
scenario: str = Field(..., description="场景描述文本")
domain: Optional[str] = Field(None, description="领域")
classes_json: dict = Field(..., description="提取的本体类数据(JSON格式)")
extracted_count: int = Field(..., description="提取的类数量")
user_id: Optional[int] = Field(None, description="用户ID")
created_at: datetime.datetime = Field(..., description="创建时间")
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
"""将创建时间序列化为毫秒时间戳"""
return int(dt.timestamp() * 1000) if dt else None
class Config:
from_attributes = True
# ==================== 本体场景相关 Schema ====================
class SceneCreateRequest(BaseModel):
"""场景创建请求模型
用于创建新的本体场景。
Attributes:
scene_name: 场景名称必填1-200字符
scene_description: 场景描述,可选
Examples:
>>> request = SceneCreateRequest(
... scene_name="医疗场景",
... scene_description="用于医疗领域的本体建模"
... )
"""
scene_name: str = Field(..., min_length=1, max_length=200, description="场景名称")
scene_description: Optional[str] = Field(None, description="场景描述")
class SceneUpdateRequest(BaseModel):
"""场景更新请求模型
用于更新已有本体场景信息。
Attributes:
scene_name: 场景名称可选1-200字符
scene_description: 场景描述,可选
Examples:
>>> request = SceneUpdateRequest(
... scene_name="更新后的场景名称",
... scene_description="更新后的描述"
... )
"""
scene_name: Optional[str] = Field(None, min_length=1, max_length=200, description="场景名称")
scene_description: Optional[str] = Field(None, description="场景描述")
class SceneResponse(BaseModel):
"""场景响应模型
用于返回本体场景信息。
Attributes:
scene_id: 场景ID
scene_name: 场景名称
scene_description: 场景描述
type_num: 类型数量
workspace_id: 所属工作空间ID
created_at: 创建时间(毫秒时间戳)
updated_at: 更新时间(毫秒时间戳)
classes_count: 类型数量
Examples:
>>> response = SceneResponse(
... scene_id=uuid.uuid4(),
... scene_name="医疗场景",
... scene_description="用于医疗领域的本体建模",
... type_num=0,
... workspace_id=uuid.uuid4(),
... created_at=datetime.now(),
... updated_at=datetime.now(),
... classes_count=5
... )
"""
scene_id: UUID = Field(..., description="场景ID")
scene_name: str = Field(..., description="场景名称")
scene_description: Optional[str] = Field(None, description="场景描述")
type_num: int = Field(..., description="类型数量")
entity_type: Optional[List[str]] = Field(None, description="实体类型列表最多3个class_name")
workspace_id: UUID = Field(..., description="所属工作空间ID")
created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)")
updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)")
classes_count: int = Field(0, description="类型数量")
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
"""将创建时间序列化为毫秒时间戳"""
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("updated_at", when_used="json")
def _serialize_updated_at(self, dt: datetime.datetime):
"""将更新时间序列化为毫秒时间戳"""
return int(dt.timestamp() * 1000) if dt else None
model_config = ConfigDict(from_attributes=True)
class PaginationInfo(BaseModel):
"""分页信息模型
Attributes:
page: 当前页码
pagesize: 每页数量
total: 总数量
hasnext: 是否有下一页
"""
page: int = Field(..., description="当前页码")
pagesize: int = Field(..., description="每页数量")
total: int = Field(..., description="总数量")
hasnext: bool = Field(..., description="是否有下一页")
class SceneListResponse(BaseModel):
"""场景列表响应模型(支持分页)
用于返回本体场景列表。
Attributes:
items: 场景列表
page: 分页信息(可选,分页时返回)
Examples:
>>> # 不分页
>>> response = SceneListResponse(
... items=[scene1, scene2]
... )
>>> # 分页
>>> response = SceneListResponse(
... items=[scene1, scene2, ...],
... page=PaginationInfo(page=1, pagesize=100, total=150, hasnext=True)
... )
"""
items: List[SceneResponse] = Field(..., description="场景列表")
page: Optional[PaginationInfo] = Field(None, description="分页信息")
# ==================== 本体类型相关 Schema ====================
class ClassItem(BaseModel):
"""单个类型信息模型
Attributes:
class_name: 类型名称必填1-200字符
class_description: 类型描述,可选
Examples:
>>> item = ClassItem(
... class_name="患者",
... class_description="医院患者信息"
... )
"""
class_name: str = Field(..., min_length=1, max_length=200, description="类型名称")
class_description: Optional[str] = Field(None, description="类型描述")
class ClassCreateRequest(BaseModel):
"""类型创建请求模型(统一使用列表形式)
通过列表中元素数量决定创建模式:
- 列表包含 1 个元素:单个创建
- 列表包含多个元素:批量创建
Attributes:
scene_id: 所属场景ID必填
classes: 类型列表,必填,至少包含 1 个元素
Examples:
# 单个创建(列表中 1 个元素)
>>> request = ClassCreateRequest(
... scene_id=uuid.uuid4(),
... classes=[
... ClassItem(class_name="患者", class_description="医院患者信息")
... ]
... )
# 批量创建(列表中多个元素)
>>> request = ClassCreateRequest(
... scene_id=uuid.uuid4(),
... classes=[
... ClassItem(class_name="患者", class_description="医院患者信息"),
... ClassItem(class_name="医生", class_description="医院医生信息"),
... ClassItem(class_name="药品", class_description="医院药品信息")
... ]
... )
"""
scene_id: UUID = Field(..., description="所属场景ID")
classes: List[ClassItem] = Field(..., min_length=1, description="类型列表,至少包含 1 个元素")
class ClassUpdateRequest(BaseModel):
"""类型更新请求模型
用于更新已有本体类型信息。
Attributes:
class_name: 类型名称可选1-200字符
class_description: 类型描述,可选
Examples:
>>> request = ClassUpdateRequest(
... class_name="更新后的类型名称",
... class_description="更新后的描述"
... )
"""
class_name: Optional[str] = Field(None, min_length=1, max_length=200, description="类型名称")
class_description: Optional[str] = Field(None, description="类型描述")
class ClassResponse(BaseModel):
"""类型响应模型
用于返回本体类型信息。
Attributes:
class_id: 类型ID
class_name: 类型名称
class_description: 类型描述
scene_id: 所属场景ID
created_at: 创建时间(毫秒时间戳)
updated_at: 更新时间(毫秒时间戳)
Examples:
>>> response = ClassResponse(
... class_id=uuid.uuid4(),
... class_name="患者",
... class_description="医院患者信息",
... scene_id=uuid.uuid4(),
... created_at=datetime.now(),
... updated_at=datetime.now()
... )
"""
class_id: UUID = Field(..., description="类型ID")
class_name: str = Field(..., description="类型名称")
class_description: Optional[str] = Field(None, description="类型描述")
scene_id: UUID = Field(..., description="所属场景ID")
created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)")
updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)")
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
"""将创建时间序列化为毫秒时间戳"""
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("updated_at", when_used="json")
def _serialize_updated_at(self, dt: datetime.datetime):
"""将更新时间序列化为毫秒时间戳"""
return int(dt.timestamp() * 1000) if dt else None
model_config = ConfigDict(from_attributes=True)
class ClassBatchCreateResponse(BaseModel):
"""批量创建类型响应模型
用于返回批量创建的结果统计和详情。
Attributes:
total: 总共尝试创建的数量
success_count: 成功创建的数量
failed_count: 失败的数量
items: 成功创建的类型列表
errors: 失败的错误信息列表(可选)
Examples:
>>> response = ClassBatchCreateResponse(
... total=3,
... success_count=2,
... failed_count=1,
... items=[class1, class2],
... errors=["创建类型 '药品' 失败: 类型名称已存在"]
... )
"""
total: int = Field(..., description="总共尝试创建的数量")
success_count: int = Field(..., description="成功创建的数量")
failed_count: int = Field(0, description="失败的数量")
items: List[ClassResponse] = Field(..., description="成功创建的类型列表")
errors: Optional[List[str]] = Field(None, description="失败的错误信息列表")
class ClassListResponse(BaseModel):
"""类型列表响应模型
用于返回本体类型列表。
Attributes:
total: 总数量
scene_id: 所属场景ID
scene_name: 场景名称
scene_description: 场景描述
items: 类型列表
Examples:
>>> response = ClassListResponse(
... total=3,
... scene_id=uuid.uuid4(),
... scene_name="医疗场景",
... scene_description="用于医疗领域的本体建模",
... items=[class1, class2, class3]
... )
"""
total: int = Field(..., description="总数量")
scene_id: UUID = Field(..., description="所属场景ID")
scene_name: str = Field(..., description="场景名称")
scene_description: Optional[str] = Field(None, description="场景描述")
items: List[ClassResponse] = Field(..., description="类型列表")

View File

@@ -22,6 +22,23 @@ class PromptOptMessage(BaseModel):
)
class PromptSaveRequest(BaseModel):
session_id: UUID = Field(
...,
description="Session ID"
)
title: str = Field(
...,
description="Prompt Title"
)
prompt: str = Field(
...,
description="Optimized prompt content"
)
class PromptOptModelSet(BaseModel):
id: UUID | None = Field(
default=None,

View File

@@ -171,7 +171,14 @@ class AppChatService:
self.conversation_service.save_conversation_messages(
conversation_id=conversation_id,
user_message=message,
assistant_message=result["content"]
assistant_message=result["content"],
meta_data={
"usage": result.get("usage", {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
})
}
)
elapsed_time = time.time() - start_time
@@ -310,6 +317,7 @@ class AppChatService:
# 流式调用 Agent
full_content = ""
total_tokens = 0
async for chunk in agent.chat_stream(
message=message,
history=history,
@@ -320,9 +328,12 @@ class AppChatService:
config_id=config_id,
memory_flag=memory_flag
):
full_content += chunk
# 发送消息块事件
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
if isinstance(chunk, int):
total_tokens = chunk
else:
full_content += chunk
# 发送消息块事件
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
elapsed_time = time.time() - start_time
@@ -339,7 +350,7 @@ class AppChatService:
content=full_content,
meta_data={
"model": api_key_obj.model_name,
"usage": {}
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
}
)
@@ -416,7 +427,11 @@ class AppChatService:
meta_data={
"mode": result.get("mode"),
"elapsed_time": result.get("elapsed_time"),
"sub_results": result.get("sub_results")
"usage": result.get("usage", {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
})
}
)
@@ -458,6 +473,7 @@ class AppChatService:
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
full_content = ""
total_tokens = 0
# 2. 创建编排器
orchestrator = MultiAgentOrchestrator(self.db, config)
@@ -474,16 +490,26 @@ class AppChatService:
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
# 尝试提取内容(用于保存)
if "data:" in event:
try:
data_line = event.split("data: ", 1)[1].strip()
data = json.loads(data_line)
if "content" in data:
full_content += data["content"]
except:
pass
if "sub_usage" in event:
if "data:" in event:
try:
data_line = event.split("data: ", 1)[1].strip()
data = json.loads(data_line)
if "total_tokens" in data:
total_tokens += data["total_tokens"]
except:
pass
else:
yield event
# 尝试提取内容(用于保存)
if "data:" in event:
try:
data_line = event.split("data: ", 1)[1].strip()
data = json.loads(data_line)
if "content" in data:
full_content += data["content"]
except:
pass
elapsed_time = time.time() - start_time
@@ -499,7 +525,12 @@ class AppChatService:
role="assistant",
content=full_content,
meta_data={
"elapsed_time": elapsed_time
"elapsed_time": elapsed_time,
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": total_tokens
}
}
)

View File

@@ -1,4 +1,5 @@
"""会话服务"""
import os
import uuid
from datetime import datetime, timedelta
from typing import Annotated
@@ -298,7 +299,8 @@ class ConversationService:
self,
conversation_id: uuid.UUID,
user_message: str,
assistant_message: str
assistant_message: str,
meta_data: Optional[dict] = None
):
"""
Save a pair of user and assistant messages to the conversation.
@@ -307,6 +309,7 @@ class ConversationService:
conversation_id (uuid.UUID): Conversation UUID.
user_message (str): User's message content.
assistant_message (str): Assistant's response content.
meta_data (Optional[dict]): Optional metadata for the messages.
"""
self.add_message(
conversation_id=conversation_id,
@@ -317,7 +320,8 @@ class ConversationService:
self.add_message(
conversation_id=conversation_id,
role="assistant",
content=assistant_message
content=assistant_message,
meta_data=meta_data
)
logger.debug(
@@ -526,12 +530,12 @@ class ConversationService:
takeaways=[],
info_score=0,
)
with open('app/services/prompt/conversation_summary_system.jinja2', 'r', encoding='utf-8') as f:
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
with open(os.path.join(prompt_path, 'conversation_summary_system.jinja2'), 'r', encoding='utf-8') as f:
system_prompt = f.read()
rendered_system_message = Template(system_prompt).render()
with open('app/services/prompt/conversation_summary_user.jinja2', 'r', encoding='utf-8') as f:
with open(os.path.join(prompt_path, 'conversation_summary_user.jinja2'), 'r', encoding='utf-8') as f:
user_prompt = f.read()
rendered_user_message = Template(user_prompt).render(
language=language,

View File

@@ -442,7 +442,14 @@ class DraftRunService:
user_message=message,
assistant_message=result["content"],
app_id=agent_config.app_id,
user_id=user_id
user_id=user_id,
meta_data={
"usage": result.get("usage", {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
})
}
)
response = {
@@ -649,6 +656,7 @@ class DraftRunService:
# 9. 流式调用 Agent
full_content = ""
total_tokens = 0
async for chunk in agent.chat_stream(
message=message,
history=history,
@@ -659,14 +667,22 @@ class DraftRunService:
user_rag_memory_id=user_rag_memory_id,
memory_flag=memory_flag
):
full_content += chunk
# 发送消息块事件
yield self._format_sse_event("message", {
"content": chunk
})
if isinstance(chunk, int):
total_tokens = chunk
else:
full_content += chunk
# 发送消息块事件
yield self._format_sse_event("message", {
"content": chunk
})
elapsed_time = time.time() - start_time
if sub_agent:
yield self._format_sse_event("sub_usage", {
"total_tokens": total_tokens
})
# 10. 保存会话消息
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
await self._save_conversation_message(
@@ -674,7 +690,10 @@ class DraftRunService:
user_message=message,
assistant_message=full_content,
app_id=agent_config.app_id,
user_id=user_id
user_id=user_id,
meta_data={
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
}
)
# 11. 发送结束事件
@@ -898,6 +917,7 @@ class DraftRunService:
conversation_id: str,
user_message: str,
assistant_message: str,
meta_data: dict,
app_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None
) -> None:
@@ -909,6 +929,7 @@ class DraftRunService:
assistant_message: AI 回复消息
app_id: 应用ID未使用保留用于兼容性
user_id: 用户ID未使用保留用于兼容性
meta_data: token消耗
"""
try:
from app.services.conversation_service import ConversationService
@@ -927,7 +948,8 @@ class DraftRunService:
conversation_service.add_message(
conversation_id=conv_uuid,
role="assistant",
content=assistant_message
content=assistant_message,
meta_data=meta_data
)
logger.debug(

View File

@@ -17,12 +17,15 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from app.utils.config_utils import resolve_config_id
logger = get_business_logger()
class EmotionSuggestion(BaseModel):
"""情绪建议模型"""
type: str = Field(..., description="建议类型emotion_balance/activity_recommendation/social_connection/stress_management")
type: str = Field(...,
description="建议类型emotion_balance/activity_recommendation/social_connection/stress_management")
title: str = Field(..., description="建议标题")
content: str = Field(..., description="建议内容")
priority: str = Field(..., description="优先级high/medium/low")
@@ -37,33 +40,33 @@ class EmotionSuggestionsResponse(BaseModel):
class EmotionAnalyticsService:
"""情绪分析服务
提供情绪数据的分析和统计功能,包括:
- 情绪标签统计
- 情绪词云数据
- 情绪健康指数计算
- 个性化情绪建议生成
Attributes:
emotion_repo: 情绪数据仓储实例
"""
def __init__(self):
"""初始化情绪分析服务"""
connector = Neo4jConnector()
self.emotion_repo = EmotionRepository(connector)
logger.info("情绪分析服务初始化完成")
async def get_emotion_tags(
self,
end_user_id: str,
emotion_type: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
limit: int = 10
self,
end_user_id: str,
emotion_type: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
limit: int = 10
) -> Dict[str, Any]:
"""获取情绪标签统计
查询指定用户的情绪类型分布,包括计数、百分比和平均强度。
确保返回所有6个情绪维度joy、sadness、anger、fear、surprise、neutral
即使某些维度没有数据也会返回count=0的记录。
@@ -71,8 +74,8 @@ class EmotionAnalyticsService:
"""
try:
logger.info(f"获取情绪标签统计: user={end_user_id}, type={emotion_type}, "
f"start={start_date}, end={end_date}, limit={limit}")
f"start={start_date}, end={end_date}, limit={limit}")
# 调用仓储层查询
tags = await self.emotion_repo.get_emotion_tags(
end_user_id=end_user_id,
@@ -81,13 +84,13 @@ class EmotionAnalyticsService:
end_date=end_date,
limit=limit
)
# 定义所有6个情绪维度
all_emotion_types = ['joy', 'sadness', 'anger', 'fear', 'surprise', 'neutral']
# 将查询结果转换为字典,方便查找
tags_dict = {tag["emotion_type"]: tag for tag in tags}
# 补全缺失的情绪维度
complete_tags = []
for emotion in all_emotion_types:
@@ -101,52 +104,52 @@ class EmotionAnalyticsService:
"percentage": 0.0,
"avg_intensity": 0.0
})
# 计算总数
total_count = sum(tag["count"] for tag in complete_tags)
# 如果有数据重新计算百分比因为补全了0值项
if total_count > 0:
for tag in complete_tags:
if tag["count"] > 0:
tag["percentage"] = round((tag["count"] / total_count) * 100, 2)
# 构建时间范围信息
time_range = {}
if start_date:
time_range["start_date"] = start_date
if end_date:
time_range["end_date"] = end_date
# 格式化响应
response = {
"tags": complete_tags,
"total_count": total_count,
"time_range": time_range if time_range else None
}
logger.info(f"情绪标签统计完成: total_count={total_count}, tags_count={len(complete_tags)}")
return response
except Exception as e:
logger.error(f"获取情绪标签统计失败: {str(e)}", exc_info=True)
raise
async def get_emotion_wordcloud(
self,
end_user_id: str,
emotion_type: Optional[str] = None,
limit: int = 50
self,
end_user_id: str,
emotion_type: Optional[str] = None,
limit: int = 50
) -> Dict[str, Any]:
"""获取情绪词云数据
查询情绪关键词及其频率,用于生成词云可视化。
Args:
end_user_id: 宿主ID用户组ID
emotion_type: 可选的情绪类型过滤
limit: 返回关键词的最大数量
Returns:
Dict: 包含情绪词云数据的响应:
- keywords: 关键词列表
@@ -154,39 +157,39 @@ class EmotionAnalyticsService:
"""
try:
logger.info(f"获取情绪词云数据: user={end_user_id}, type={emotion_type}, limit={limit}")
# 调用仓储层查询
keywords = await self.emotion_repo.get_emotion_wordcloud(
end_user_id=end_user_id,
emotion_type=emotion_type,
limit=limit
)
# 计算总关键词数量
total_keywords = len(keywords)
# 格式化响应
response = {
"keywords": keywords,
"total_keywords": total_keywords
}
logger.info(f"情绪词云数据获取完成: total_keywords={total_keywords}")
return response
except Exception as e:
logger.error(f"获取情绪词云数据失败: {str(e)}", exc_info=True)
raise
def _calculate_positivity_rate(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]:
"""计算积极率
根据情绪类型分类正面、负面和中性情绪,计算积极率。
公式:(正面数 / (正面数 + 负面数)) * 100
Args:
emotions: 情绪数据列表,每个包含 emotion_type 字段
Returns:
Dict: 包含积极率计算结果:
- score: 积极率分数0-100
@@ -197,38 +200,38 @@ class EmotionAnalyticsService:
# 定义情绪分类
positive_emotions = {'joy', 'surprise'}
negative_emotions = {'sadness', 'anger', 'fear'}
# 统计各类情绪数量
positive_count = sum(1 for e in emotions if e.get('emotion_type') in positive_emotions)
negative_count = sum(1 for e in emotions if e.get('emotion_type') in negative_emotions)
neutral_count = sum(1 for e in emotions if e.get('emotion_type') == 'neutral')
# 计算积极率
total_non_neutral = positive_count + negative_count
if total_non_neutral > 0:
score = (positive_count / total_non_neutral) * 100
else:
score = 50.0 # 如果没有非中性情绪默认为50
logger.debug(f"积极率计算: positive={positive_count}, negative={negative_count}, "
f"neutral={neutral_count}, score={score:.2f}")
f"neutral={neutral_count}, score={score:.2f}")
return {
"score": round(score, 2),
"positive_count": positive_count,
"negative_count": negative_count,
"neutral_count": neutral_count
}
def _calculate_stability(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]:
"""计算稳定性
基于情绪强度的标准差计算情绪稳定性。
公式:(1 - min(std_deviation, 1.0)) * 100
Args:
emotions: 情绪数据列表,每个包含 emotion_intensity 字段
Returns:
Dict: 包含稳定性计算结果:
- score: 稳定性分数0-100
@@ -236,7 +239,7 @@ class EmotionAnalyticsService:
"""
# 提取所有情绪强度
intensities = [e.get('emotion_intensity', 0.0) for e in emotions if e.get('emotion_intensity') is not None]
# 计算标准差
if len(intensities) >= 2:
std_deviation = statistics.stdev(intensities)
@@ -244,29 +247,29 @@ class EmotionAnalyticsService:
std_deviation = 0.0 # 只有一个数据点标准差为0
else:
std_deviation = 0.0 # 没有数据标准差为0
# 计算稳定性分数
# 标准差越小,稳定性越高
score = (1 - min(std_deviation, 1.0)) * 100
logger.debug(f"稳定性计算: intensities_count={len(intensities)}, "
f"std_deviation={std_deviation:.3f}, score={score:.2f}")
f"std_deviation={std_deviation:.3f}, score={score:.2f}")
return {
"score": round(score, 2),
"std_deviation": round(std_deviation, 3)
}
def _calculate_resilience(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]:
"""计算恢复力
分析情绪转换模式,统计从负面情绪恢复到正面情绪的能力。
公式:(负面到正面转换次数 / 总负面情绪数) * 100
Args:
emotions: 情绪数据列表,每个包含 emotion_type 和 created_at 字段
应该按时间顺序排列
Returns:
Dict: 包含恢复力计算结果:
- score: 恢复力分数0-100
@@ -275,24 +278,24 @@ class EmotionAnalyticsService:
# 定义情绪分类
positive_emotions = {'joy', 'surprise'}
negative_emotions = {'sadness', 'anger', 'fear'}
# 统计负面到正面的转换次数
recovery_count = 0
negative_count = 0
for i in range(len(emotions)):
current_emotion = emotions[i].get('emotion_type')
# 统计负面情绪总数
if current_emotion in negative_emotions:
negative_count += 1
# 检查下一个情绪是否为正面
if i + 1 < len(emotions):
next_emotion = emotions[i + 1].get('emotion_type')
if next_emotion in positive_emotions:
recovery_count += 1
# 计算恢复力分数
if negative_count > 0:
recovery_rate = recovery_count / negative_count
@@ -301,28 +304,28 @@ class EmotionAnalyticsService:
# 如果没有负面情绪恢复力设为100最佳状态
recovery_rate = 1.0
score = 100.0
logger.debug(f"恢复力计算: negative_count={negative_count}, "
f"recovery_count={recovery_count}, score={score:.2f}")
f"recovery_count={recovery_count}, score={score:.2f}")
return {
"score": round(score, 2),
"recovery_rate": round(recovery_rate, 3)
}
async def calculate_emotion_health_index(
self,
end_user_id: str,
time_range: str = "30d"
self,
end_user_id: str,
time_range: str = "30d"
) -> Dict[str, Any]:
"""计算情绪健康指数
综合积极率、稳定性和恢复力计算情绪健康指数。
Args:
end_user_id: 宿主ID用户组ID
time_range: 时间范围7d/30d/90d
Returns:
Dict: 包含情绪健康指数的完整响应:
- health_score: 综合健康分数0-100
@@ -336,13 +339,13 @@ class EmotionAnalyticsService:
"""
try:
logger.info(f"计算情绪健康指数: user={end_user_id}, time_range={time_range}")
# 获取时间范围内的情绪数据
emotions = await self.emotion_repo.get_emotions_in_range(
end_user_id=end_user_id,
time_range=time_range
)
# 如果没有数据,返回默认值
if not emotions:
logger.warning(f"用户 {end_user_id} 在时间范围 {time_range} 内没有情绪数据")
@@ -357,20 +360,20 @@ class EmotionAnalyticsService:
"emotion_distribution": {},
"time_range": time_range
}
# 计算各维度指标
positivity_rate = self._calculate_positivity_rate(emotions)
stability = self._calculate_stability(emotions)
resilience = self._calculate_resilience(emotions)
# 计算综合健康分数
# 公式positivity_rate * 0.4 + stability * 0.3 + resilience * 0.3
health_score = (
positivity_rate["score"] * 0.4 +
stability["score"] * 0.3 +
resilience["score"] * 0.3
positivity_rate["score"] * 0.4 +
stability["score"] * 0.3 +
resilience["score"] * 0.3
)
# 确定健康等级
if health_score >= 80:
level = "优秀"
@@ -380,13 +383,13 @@ class EmotionAnalyticsService:
level = "一般"
else:
level = "较差"
# 统计情绪分布
emotion_distribution = {}
for emotion_type in ['joy', 'sadness', 'anger', 'fear', 'surprise', 'neutral']:
count = sum(1 for e in emotions if e.get('emotion_type') == emotion_type)
emotion_distribution[emotion_type] = count
# 格式化响应
response = {
"health_score": round(health_score, 2),
@@ -399,22 +402,22 @@ class EmotionAnalyticsService:
"emotion_distribution": emotion_distribution,
"time_range": time_range
}
logger.info(f"情绪健康指数计算完成: score={health_score:.2f}, level={level}")
return response
except Exception as e:
logger.error(f"计算情绪健康指数失败: {str(e)}", exc_info=True)
raise
def _analyze_emotion_patterns(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]:
"""分析情绪模式
识别主要负面情绪、情绪触发因素和波动时段。
Args:
emotions: 情绪数据列表,每个包含 emotion_type、emotion_intensity、created_at 字段
Returns:
Dict: 包含情绪模式分析结果:
- dominant_negative_emotion: 主要负面情绪类型
@@ -422,19 +425,19 @@ class EmotionAnalyticsService:
- emotion_volatility: 情绪波动性(高/中/低)
"""
negative_emotions = {'sadness', 'anger', 'fear'}
# 统计负面情绪分布
negative_emotion_counts = {}
for emotion in emotions:
emotion_type = emotion.get('emotion_type')
if emotion_type in negative_emotions:
negative_emotion_counts[emotion_type] = negative_emotion_counts.get(emotion_type, 0) + 1
# 识别主要负面情绪
dominant_negative_emotion = None
if negative_emotion_counts:
dominant_negative_emotion = max(negative_emotion_counts, key=negative_emotion_counts.get)
# 识别高强度情绪(强度 >= 0.7
high_intensity_emotions = [
{
@@ -445,7 +448,7 @@ class EmotionAnalyticsService:
for e in emotions
if e.get('emotion_intensity', 0) >= 0.7
]
# 评估情绪波动性
intensities = [e.get('emotion_intensity', 0.0) for e in emotions if e.get('emotion_intensity') is not None]
if len(intensities) >= 2:
@@ -458,29 +461,29 @@ class EmotionAnalyticsService:
volatility = ""
else:
volatility = "未知"
logger.debug(f"情绪模式分析: dominant_negative={dominant_negative_emotion}, "
f"high_intensity_count={len(high_intensity_emotions)}, volatility={volatility}")
f"high_intensity_count={len(high_intensity_emotions)}, volatility={volatility}")
return {
"dominant_negative_emotion": dominant_negative_emotion,
"high_intensity_emotions": high_intensity_emotions[:5], # 最多返回5个
"emotion_volatility": volatility
}
async def generate_emotion_suggestions(
self,
end_user_id: str,
db: Session,
self,
end_user_id: str,
db: Session,
) -> Dict[str, Any]:
"""生成个性化情绪建议
基于情绪健康数据和用户画像生成个性化建议。
Args:
end_user_id: 宿主ID用户组ID
db: 数据库会话
Returns:
Dict: 包含个性化建议的响应:
- health_summary: 健康状态摘要
@@ -488,17 +491,17 @@ class EmotionAnalyticsService:
"""
try:
logger.info(f"生成个性化情绪建议: user={end_user_id}")
# 1. 从 end_user_id 获取关联的 memory_config_id
llm_client = None
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
config_id = resolve_config_id(config_id, db)
if config_id is not None:
from app.services.memory_config_service import (
MemoryConfigService,
@@ -513,35 +516,35 @@ class EmotionAnalyticsService:
llm_client = factory.get_llm_client(str(memory_config.llm_model_id))
except Exception as e:
logger.warning(f"无法获取 end_user {end_user_id} 的配置,将使用默认配置: {e}")
# 2. 获取情绪健康数据
health_data = await self.calculate_emotion_health_index(end_user_id, time_range="30d")
# 3. 获取情绪数据用于模式分析
emotions = await self.emotion_repo.get_emotions_in_range(
end_user_id=end_user_id,
time_range="30d"
)
# 4. 分析情绪模式
patterns = self._analyze_emotion_patterns(emotions)
# 5. 获取用户画像数据简化版直接从Neo4j获取
user_profile = await self._get_simple_user_profile(end_user_id)
# 6. 构建LLM prompt
prompt = await self._build_suggestion_prompt(health_data, patterns, user_profile)
# 7. 调用LLM生成建议使用配置中的LLM
if llm_client is None:
# 无法获取配置时,抛出错误而不是使用默认配置
raise ValueError("无法获取LLM配置请确保end_user关联了有效的memory_config")
# 将 prompt 转换为 messages 格式
messages = [
{"role": "user", "content": prompt}
]
# 8. 使用结构化输出直接获取 Pydantic 模型
try:
suggestions_response = await llm_client.response_structured(
@@ -552,7 +555,7 @@ class EmotionAnalyticsService:
logger.error(f"LLM 结构化输出失败: {str(e)}")
# 返回默认建议
suggestions_response = self._get_default_suggestions(health_data)
# 8. 验证建议数量3-5条
if len(suggestions_response.suggestions) < 3:
logger.warning(f"建议数量不足: {len(suggestions_response.suggestions)}")
@@ -560,7 +563,7 @@ class EmotionAnalyticsService:
elif len(suggestions_response.suggestions) > 5:
logger.warning(f"建议数量过多: {len(suggestions_response.suggestions)}")
suggestions_response.suggestions = suggestions_response.suggestions[:5]
# 9. 格式化响应
response = {
"health_summary": suggestions_response.health_summary,
@@ -575,26 +578,26 @@ class EmotionAnalyticsService:
for s in suggestions_response.suggestions
]
}
logger.info(f"个性化建议生成完成: suggestions_count={len(response['suggestions'])}")
return response
except Exception as e:
logger.error(f"生成个性化建议失败: {str(e)}", exc_info=True)
raise
async def _get_simple_user_profile(self, end_user_id: str) -> Dict[str, Any]:
"""获取简化的用户画像数据
Args:
end_user_id: 用户ID
Returns:
Dict: 用户画像数据
"""
try:
connector = Neo4jConnector()
# 查询用户的实体和标签
query = """
MATCH (e:Entity)
@@ -603,59 +606,59 @@ class EmotionAnalyticsService:
ORDER BY e.created_at DESC
LIMIT 20
"""
entities = await connector.execute_query(query, end_user_id=end_user_id)
# 提取兴趣标签
interests = [e["name"] for e in entities if e.get("type") in ["INTEREST", "HOBBY"]][:5]
# 后期会引入用户的习惯。。
return {
"interests": interests if interests else ["未知"]
}
except Exception as e:
logger.error(f"获取用户画像失败: {str(e)}")
return {"interests": ["未知"]}
async def _build_suggestion_prompt(
self,
health_data: Dict[str, Any],
patterns: Dict[str, Any],
user_profile: Dict[str, Any]
self,
health_data: Dict[str, Any],
patterns: Dict[str, Any],
user_profile: Dict[str, Any]
) -> str:
"""构建情绪建议生成的prompt
Args:
health_data: 情绪健康数据
patterns: 情绪模式分析结果
user_profile: 用户画像数据
Returns:
str: LLM prompt
"""
from app.core.memory.utils.prompt.prompt_utils import (
render_emotion_suggestions_prompt,
)
prompt = await render_emotion_suggestions_prompt(
health_data=health_data,
patterns=patterns,
user_profile=user_profile
)
return prompt
def _get_default_suggestions(self, health_data: Dict[str, Any]) -> EmotionSuggestionsResponse:
"""获取默认建议当LLM调用失败时使用
Args:
health_data: 情绪健康数据
Returns:
EmotionSuggestionsResponse: 默认建议
"""
health_score = health_data.get('health_score', 0)
if health_score >= 80:
summary = "您的情绪健康状况优秀,请继续保持积极的生活态度。"
elif health_score >= 60:
@@ -664,7 +667,7 @@ class EmotionAnalyticsService:
summary = "您的情绪健康需要关注,建议采取一些改善措施。"
else:
summary = "您的情绪健康需要重点关注,建议寻求专业帮助。"
suggestions = [
EmotionSuggestion(
type="emotion_balance",
@@ -700,54 +703,54 @@ class EmotionAnalyticsService:
]
)
]
return EmotionSuggestionsResponse(
health_summary=summary,
suggestions=suggestions
)
async def get_cached_suggestions(
self,
end_user_id: str,
db: Session,
self,
end_user_id: str,
db: Session,
) -> Optional[Dict[str, Any]]:
"""从 Redis 缓存获取个性化情绪建议
Args:
end_user_id: 宿主ID用户组ID
db: 数据库会话(保留参数以保持接口兼容性)
Returns:
Dict: 缓存的建议数据,如果不存在或已过期返回 None
"""
try:
from app.cache.memory.emotion_memory import EmotionMemoryCache
logger.info(f"尝试从 Redis 缓存获取情绪建议: user={end_user_id}")
# 从 Redis 获取缓存
cached_data = await EmotionMemoryCache.get_emotion_suggestions(end_user_id)
if cached_data is None:
logger.info(f"用户 {end_user_id} 的建议缓存不存在或已过期")
return None
logger.info(f"成功从 Redis 缓存获取建议: user={end_user_id}")
return cached_data
except Exception as e:
logger.error(f"从 Redis 缓存获取建议失败: {str(e)}", exc_info=True)
return None
async def save_suggestions_cache(
self,
end_user_id: str,
suggestions_data: Dict[str, Any],
db: Session,
expires_hours: int = 24
self,
end_user_id: str,
suggestions_data: Dict[str, Any],
db: Session,
expires_hours: int = 24
) -> None:
"""保存建议到 Redis 缓存
Args:
end_user_id: 宿主ID用户组ID
suggestions_data: 建议数据
@@ -756,24 +759,24 @@ class EmotionAnalyticsService:
"""
try:
from app.cache.memory.emotion_memory import EmotionMemoryCache
logger.info(f"保存建议到 Redis 缓存: user={end_user_id}, expires={expires_hours}小时")
# 计算过期时间(秒)
expire_seconds = expires_hours * 3600
# 保存到 Redis
success = await EmotionMemoryCache.set_emotion_suggestions(
user_id=end_user_id,
suggestions_data=suggestions_data,
expire=expire_seconds
)
if success:
logger.info(f"建议缓存保存成功: user={end_user_id}")
else:
logger.warning(f"建议缓存保存失败: user={end_user_id}")
except Exception as e:
logger.error(f"保存建议缓存失败: {str(e)}", exc_info=True)
# 不抛出异常,缓存失败不应影响主流程

View File

@@ -4,7 +4,7 @@ import uuid
from typing import List, Dict, Any, Optional, AsyncGenerator, Annotated
from typing_extensions import TypedDict
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, AIMessageChunk
from langgraph.graph import StateGraph, START, END
from langgraph.types import Command
from langgraph.checkpoint.memory import MemorySaver
@@ -727,9 +727,12 @@ class HandoffsService:
# 提取响应
response_content = ""
total_tokens = 0
for msg in result.get("messages", []):
if isinstance(msg, AIMessage):
response_content = msg.content
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
break
return {
@@ -737,7 +740,12 @@ class HandoffsService:
"active_agent": result.get("active_agent"),
"response": response_content,
"message_count": len(result.get("messages", [])),
"handoff_count": result.get("handoff_count", 0)
"handoff_count": result.get("handoff_count", 0),
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": total_tokens
}
}
async def chat_stream(
@@ -830,6 +838,12 @@ class HandoffsService:
# 捕获 LLM 结束事件,输出收集到的工具调用
elif kind == "on_chat_model_end":
output_message = event.get("data", {}).get("output", {})
if isinstance(output_message, AIMessageChunk):
response_meta = output_message.response_metadata if hasattr(output_message, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
0) if response_meta else 0
yield f"event: sub_usage\ndata: {json.dumps({"total_tokens": total_tokens}, ensure_ascii=False)}\n\n"
if collected_tool_calls:
# 找到参数最完整的 transfer 工具调用
best_tc = None

View File

@@ -89,7 +89,6 @@ class WorkspaceAppService:
for release in app_releases:
memory_content = self._extract_memory_content(release.config)
memory_content=resolve_config_id(memory_content, self.db)
if memory_content and memory_content in processed_configs:
continue
@@ -122,16 +121,12 @@ class WorkspaceAppService:
def _get_memory_config(self, memory_content: str) -> Dict[str, Any]:
"""Retrieve memory_config information based on memory_content"""
try:
memory_config_result = MemoryConfigRepository.query_reflection_config_by_id(self.db, int(memory_content))
# memory_config_query, memory_config_params = MemoryConfigRepository.build_select_reflection(memory_content)
# memory_config_result = self.db.execute(text(memory_config_query), memory_config_params).fetchone()
# if memory_config_result is None:
# return None
memory_content = resolve_config_id(memory_content, self.db)
memory_config_result = MemoryConfigRepository.query_reflection_config_by_id(self.db, (memory_content))
if memory_config_result:
return {
"config_id": memory_config_result.config_id,
"config_id": memory_content,
"enable_self_reflexion": memory_config_result.enable_self_reflexion,
"iteration_period": memory_config_result.iteration_period,
"reflexion_range": memory_config_result.reflexion_range,
@@ -291,7 +286,7 @@ class MemoryReflectionService:
# 检查是否需要执行反思
should_execute = False
hours_diff = 0
if current_reflection_time is None:
# 首次执行反思
should_execute = True
@@ -303,11 +298,11 @@ class MemoryReflectionService:
reflection_time = datetime.fromisoformat(current_reflection_time)
else:
reflection_time = current_reflection_time
current_time = datetime.now()
time_diff = current_time - reflection_time
hours_diff = int(time_diff.total_seconds() / 3600)
# 检查是否达到反思周期
if hours_diff >= iteration_period:
should_execute = True
@@ -317,7 +312,7 @@ class MemoryReflectionService:
except (ValueError, TypeError) as e:
api_logger.warning(f"解析反思时间失败: {e},将执行反思")
should_execute = True
if should_execute:
api_logger.info(f"与上次的反思时间间隔为: {hours_diff} 小时")
# 3. 执行反思引擎
@@ -350,7 +345,7 @@ class MemoryReflectionService:
"next_reflection_in_hours": iteration_period - hours_diff
}
except Exception as e:
config_id = config_data.get("config_id", "unknown")
api_logger.error(f"启动反思失败config_id: {config_id}, end_user_id: {end_user_id}, 错误: {str(e)}")
@@ -361,7 +356,7 @@ class MemoryReflectionService:
"end_user_id": end_user_id,
"config_data": config_data
}
def _create_reflection_config_from_data(self, config_data: Dict[str, Any]) -> ReflectionConfig:
"""Create reflective configuration objects from configuration data"""
@@ -369,12 +364,12 @@ class MemoryReflectionService:
if reflexion_range_value is None or reflexion_range_value == "":
reflexion_range_value = "partial"
reflexion_range = ReflectionRange(reflexion_range_value)
baseline_value = config_data.get("baseline")
if baseline_value is None or baseline_value == "":
baseline_value = "TIME"
baseline = ReflectionBaseline(baseline_value)
# iteration_period =
iteration_period = config_data.get("iteration_period", 24)
if isinstance(iteration_period, str):
@@ -382,7 +377,6 @@ class MemoryReflectionService:
iteration_period = int(iteration_period)
except (ValueError, TypeError):
iteration_period = 24 # 默认24小时
return ReflectionConfig(
enabled=config_data.get("enable_self_reflexion", False),
iteration_period=str(iteration_period), # ReflectionConfig期望字符串

View File

@@ -508,10 +508,7 @@ class ModelApiKeyService:
)
if not validation_result["valid"]:
# 记录验证失败的模型,但不抛出异常
failed_models.append({
"model_name": model_name,
"error": validation_result["error"]
})
failed_models.append(model_name)
continue
# 创建API Key
@@ -692,6 +689,9 @@ class ModelBaseService:
@staticmethod
def create_model_base(db: Session, data: model_schema.ModelBaseCreate):
existing = ModelBaseRepository.get_by_name_and_provider(db, data.name, data.provider)
if existing:
raise BusinessException("模型已存在", BizCode.DUPLICATE_NAME)
model_base = ModelBaseRepository.create(db, data.model_dump())
db.commit()
db.refresh(model_base)

View File

@@ -280,14 +280,22 @@ class MultiAgentOrchestrator:
# 4. 提取子 Agent 的 conversation_id用于多轮对话
sub_conversation_id = None
total_tokens = 0
if isinstance(results, dict):
sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id")
# 提取 token 信息
usage = results.get("usage", {}) or results.get("result", {}).get("usage", {})
total_tokens += usage.get("total_tokens", 0)
elif isinstance(results, list) and results:
for item in results:
if "result" in item:
sub_conversation_id = item["result"].get("conversation_id")
if sub_conversation_id:
break
# 累加每个子 Agent 的 token
usage = item.get("usage", {}) or item.get("result", {}).get("usage", {})
total_tokens += usage.get("total_tokens", 0)
logger.info(
"多 Agent 任务完成",
@@ -301,9 +309,15 @@ class MultiAgentOrchestrator:
return {
"message": final_result,
"conversation_id": sub_conversation_id,
"mode": OrchestrationMode.SUPERVISOR,
"elapsed_time": elapsed_time,
"strategy": routing_decision.get("collaboration_strategy", "single"),
"sub_results": results
"sub_results": results,
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": total_tokens
}
}
except Exception as e:
@@ -1552,10 +1566,12 @@ class MultiAgentOrchestrator:
return {
"message": result.get("response", ""),
"conversation_id": result.get("conversation_id"),
"mode": OrchestrationMode.COLLABORATION,
"elapsed_time": elapsed_time,
"strategy": "collaboration",
"active_agent": result.get("active_agent"),
"sub_results": result
"sub_results": result,
"usage": result.get("usage")
}
except Exception as e:

View File

@@ -1,5 +1,6 @@
"""多 Agent 配置管理服务"""
import uuid
import json
from typing import Optional, List, Tuple, Any, Annotated
from fastapi import Depends
@@ -427,6 +428,23 @@ class MultiAgentService:
memory=getattr(request, 'memory', True) # 记忆功能参数
)
await self._save_conversation_message(
conversation_id=request.conversation_id,
user_message=request.message,
assistant_message=result.get("message", ""),
app_id=app_id,
user_id=request.user_id,
meta_data={
"mode": result.get("mode"),
"elapsed_time": result.get("elapsed_time"),
"usage": result.get("usage", {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
})
}
)
return result
async def run_stream(
@@ -451,11 +469,14 @@ class MultiAgentService:
raise ResourceNotFoundException("多 Agent 配置", str(app_id))
if not config.is_active:
raise BusinessException("多 Agent 配置已禁用", BizCode.RESOURCE_DISABLED)
raise BusinessException("多 Agent 配置已禁用", BizCode.NOT_FOUND)
# 2. 创建编排器
orchestrator = MultiAgentOrchestrator(self.db, config)
full_content = ""
total_tokens = 0
# 3. 流式执行任务
async for event in orchestrator.execute_stream(
message=request.message,
@@ -468,7 +489,88 @@ class MultiAgentService:
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
if "sub_usage" in event:
if "data:" in event:
try:
data_line = event.split("data: ", 1)[1].strip()
data = json.loads(data_line)
if "total_tokens" in data:
total_tokens += data["total_tokens"]
except:
pass
else:
yield event
if "data:" in event:
try:
data_line = event.split("data: ", 1)[1].strip()
data = json.loads(data_line)
if "content" in data:
full_content += data["content"]
except:
pass
await self._save_conversation_message(
conversation_id=request.conversation_id,
user_message=request.message,
assistant_message=full_content,
app_id=app_id,
user_id=request.user_id,
meta_data={
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": total_tokens
}
}
)
async def _save_conversation_message(
self,
conversation_id: uuid.UUID,
user_message: str,
assistant_message: str,
meta_data: dict,
app_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None
) -> None:
"""保存会话消息
Args:
conversation_id: 会话ID
user_message: 用户消息
assistant_message: AI 回复消息
meta_data: 元数据(包括 token 消耗)
app_id: 应用ID
user_id: 用户ID
"""
try:
from app.services.conversation_service import ConversationService
conversation_service = ConversationService(self.db)
conversation_service.add_message(
conversation_id=conversation_id,
role="user",
content=user_message
)
conversation_service.add_message(
conversation_id=conversation_id,
role="assistant",
content=assistant_message,
meta_data=meta_data
)
logger.debug(
"保存多 Agent 会话消息",
extra={
"conversation_id": conversation_id,
"user_message_length": len(user_message),
"assistant_message_length": len(assistant_message)
}
)
except Exception as e:
logger.warning("保存会话消息失败", extra={"error": str(e)})
# def add_sub_agent(
# self,

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +1,4 @@
import os
import re
import uuid
from typing import Any, AsyncGenerator
@@ -18,7 +19,8 @@ from app.models.prompt_optimizer_model import (
)
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository
from app.repositories.prompt_optimizer_repository import (
PromptOptimizerSessionRepository
PromptOptimizerSessionRepository,
PromptReleaseRepository
)
from app.schemas.prompt_optimizer_schema import OptimizePromptResult
@@ -28,6 +30,8 @@ logger = get_business_logger()
class PromptOptimizerService:
def __init__(self, db: Session):
self.db = db
self.optim_repo = PromptOptimizerSessionRepository(self.db)
self.release_repo = PromptReleaseRepository(self.db)
def get_model_config(
self,
@@ -78,10 +82,12 @@ class PromptOptimizerService:
Returns:
PromptOptimzerSession: The newly created prompt optimization session.
"""
session = PromptOptimizerSessionRepository(self.db).create_session(
session = self.optim_repo.create_session(
tenant_id=tenant_id,
user_id=user_id
)
self.db.commit()
self.db.refresh(session)
return session
def get_session_message_history(
@@ -106,7 +112,7 @@ class PromptOptimizerService:
- role (str): The role of the message sender, e.g., 'system', 'user', or 'assistant'.
- content (str): The content of the message.
"""
history = PromptOptimizerSessionRepository(self.db).get_session_history(
history = self.optim_repo.get_session_history(
session_id=session_id,
user_id=user_id
)
@@ -177,11 +183,12 @@ class PromptOptimizerService:
base_url=api_config.api_base
), type=ModelType(model_config.type))
try:
with open('app/services/prompt/prompt_optimizer_system.jinja2', 'r', encoding='utf-8') as f:
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
with open(os.path.join(prompt_path, 'prompt_optimizer_system.jinja2'), 'r', encoding='utf-8') as f:
opt_system_prompt = f.read()
rendered_system_message = Template(opt_system_prompt).render()
with open('app/services/prompt/prompt_optimizer_user.jinja2', 'r', encoding='utf-8') as f:
with open(os.path.join(prompt_path, 'prompt_optimizer_user.jinja2'), 'r', encoding='utf-8') as f:
opt_user_prompt = f.read()
except FileNotFoundError:
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
@@ -296,4 +303,165 @@ class PromptOptimizerService:
role=role,
content=content
)
self.db.commit()
self.db.refresh(message)
return message
def save_prompt(
self,
tenant_id: uuid.UUID,
session_id: uuid.UUID,
title: str,
prompt: str
) -> dict:
"""
Create and save a new prompt release for a given session.
Args:
tenant_id (uuid.UUID): The ID of the tenant owning the prompt.
session_id (uuid.UUID): The ID of the session to associate with this prompt.
title (str): The title of the prompt release.
prompt (str): The content of the prompt.
Returns:
dict: A dictionary containing:
- id (UUID): The unique ID of the created prompt release.
- session_id (UUID): The session ID linked to the release.
- title (str): The title of the prompt.
- prompt (str): The prompt content.
- created_at (int): Timestamp (in milliseconds) of when the prompt was created.
Raises:
BusinessException: If a prompt release already exists for the given session.
"""
session = self.optim_repo.get_session_by_id(session_id)
if session is None or session.tenant_id != tenant_id:
raise BusinessException(
"Session does not exist or the current user has no access",
BizCode.BAD_REQUEST
)
if self.release_repo.get_prompt_by_session_id(session_id):
raise BusinessException(
"A release already exists for the current session",
BizCode.BAD_REQUEST
)
prompt_obj = self.release_repo.create_prompt_release(
tenant_id=tenant_id,
title=title,
session_id=session_id,
prompt=prompt
)
self.db.commit()
self.db.refresh(prompt_obj)
return {
"id": prompt_obj.id,
"session_id": prompt_obj.session_id,
"title": prompt_obj.title,
"prompt": prompt_obj.prompt,
"created_at": int(prompt_obj.created_at.timestamp() * 1000)
}
def delete_prompt(
self,
tenant_id: uuid.UUID,
prompt_id: uuid.UUID
) -> None:
"""
Soft delete a prompt release by prompt_id.
Args:
tenant_id (uuid.UUID): Tenant identifier.
prompt_id (uuid.UUID): Prompt identifier.
Raises:
BusinessException: If the prompt does not exist or already deleted.
"""
prompt_obj = self.release_repo.get_prompt_by_id(prompt_id)
if not prompt_obj or prompt_obj.is_delete:
raise BusinessException(
"Prompt does not exist or has already been deleted",
BizCode.NOT_FOUND
)
if prompt_obj.tenant_id != tenant_id:
raise BusinessException(
"No permission to delete this prompt",
BizCode.FORBIDDEN
)
self.release_repo.soft_delete_prompt(prompt_obj)
self.db.commit()
logger.info(f"Prompt soft deleted, prompt_id={prompt_id}, tenant_id={tenant_id}")
def get_release_list(
self,
tenant_id: uuid.UUID,
page: int,
page_size: int,
filter_keyword: str | None = None
) -> dict[str, int | list[Any]]:
"""
Get paginated list of prompt releases with optional filter.
Args:
tenant_id (uuid.UUID): Tenant identifier.
page (int): Page number (starting from 1).
page_size (int): Number of items per page.
filter_keyword (str | None): Optional keyword to filter by title.
Returns:
dict: Contains total count, pagination info, and list of releases.
"""
offset = (page - 1) * page_size
# Get total count and releases based on filter
if filter_keyword:
total = self.release_repo.count_prompts_by_keyword(tenant_id, filter_keyword)
releases = self.release_repo.search_prompts_paginated(
tenant_id=tenant_id,
keyword=filter_keyword,
offset=offset,
limit=page_size
)
else:
total = self.release_repo.count_prompts(tenant_id)
releases = self.release_repo.get_prompts_paginated(
tenant_id=tenant_id,
offset=offset,
limit=page_size
)
items = []
for release in releases:
# Get first user message from session
first_message = self.optim_repo.get_first_user_message(
session_id=release.session_id
)
items.append({
"id": release.id,
"title": release.title,
"prompt": release.prompt,
"created_at": int(release.created_at.timestamp() * 1000),
"first_message": first_message
})
log_msg = f"Retrieved {len(items)} prompt releases, page={page}, tenant_id={tenant_id}"
if filter_keyword:
log_msg += f", filter='{filter_keyword}'"
logger.info(log_msg)
result = {
"page": {
"total": total,
"page": page,
"page_size": page_size,
"hasnext": page * page_size < total
},
"keyword": filter_keyword,
"items": items
}
return result

View File

@@ -282,7 +282,14 @@ class SharedChatService:
self.conversation_service.save_conversation_messages(
conversation_id=conversation.id,
user_message=message,
assistant_message=result["content"]
assistant_message=result["content"],
meta_data={
"usage": result.get("usage", {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
})
}
)
# self.conversation_service.add_message(
# conversation_id=conversation.id,
@@ -469,6 +476,7 @@ class SharedChatService:
# 流式调用 Agent
full_content = ""
total_tokens = 0
async for chunk in agent.chat_stream(
message=message,
history=history,
@@ -479,9 +487,12 @@ class SharedChatService:
config_id=config_id,
memory_flag=memory_flag
):
full_content += chunk
# 发送消息块事件
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
if isinstance(chunk, int):
total_tokens = chunk
else:
full_content += chunk
# 发送消息块事件
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
elapsed_time = time.time() - start_time
@@ -498,7 +509,7 @@ class SharedChatService:
content=full_content,
meta_data={
"model": api_key_obj.model_name,
"usage": {}
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
}
)

View File

@@ -774,7 +774,15 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
}
@celery_app.task(name="app.tasks.regenerate_memory_cache", bind=True)
@celery_app.task(
name="app.tasks.regenerate_memory_cache",
bind=True,
ignore_result=True,
max_retries=0,
acks_late=False,
time_limit=3600,
soft_time_limit=3300,
)
def regenerate_memory_cache(self) -> Dict[str, Any]:
"""定时任务:为所有用户重新生成记忆洞察和用户摘要缓存
@@ -966,7 +974,16 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
}
@celery_app.task(name="app.tasks.workspace_reflection_task", bind=True)
@celery_app.task(
name="app.tasks.workspace_reflection_task",
bind=True,
ignore_result=True,
max_retries=0,
acks_late=False,
time_limit=300,
soft_time_limit=240,
)
def workspace_reflection_task(self) -> Dict[str, Any]:
"""定时任务每30秒运行工作空间反思功能
@@ -1111,7 +1128,16 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
@celery_app.task(name="app.tasks.run_forgetting_cycle_task", bind=True)
@celery_app.task(
name="app.tasks.run_forgetting_cycle_task",
bind=True,
ignore_result=True,
max_retries=0,
acks_late=False,
time_limit=7200,
soft_time_limit=7000,
)
def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
"""定时任务:运行遗忘周期

View File

@@ -7,30 +7,31 @@ from uuid import UUID
from sqlalchemy.orm import Session
def resolve_config_id(config_id: UUID | int, db: Session) -> UUID:
def resolve_config_id(config_id: UUID | int|str, db: Session) -> UUID:
"""
解析 config_id如果是整数则通过 config_id_old 查找对应的 UUID
Args:
config_id: 配置IDUUID 或整数)
db: 数据库会话
Returns:
UUID: 解析后的配置ID
Raises:
ValueError: 当找不到对应的配置时
"""
from app.models.memory_config_model import MemoryConfig
if isinstance(config_id, UUID):
return config_id
if isinstance(config_id, str) and len(config_id)<=6:
memory_config = db.query(MemoryConfig).filter(
MemoryConfig.config_id_old == config_id
MemoryConfig.config_id_old == int(config_id)
).first()
print(memory_config)
if not memory_config:
raise ValueError(f"未找到 config_id_old={config_id} 对应的配置")
raise ValueError(f"STR 未找到 config_id_old={config_id} 对应的配置")
return memory_config.config_id
if isinstance(config_id, int):
memory_config = db.query(MemoryConfig).filter(
@@ -38,7 +39,7 @@ def resolve_config_id(config_id: UUID | int, db: Session) -> UUID:
).first()
if not memory_config:
raise ValueError(f"未找到 config_id_old={config_id} 对应的配置")
raise ValueError(f"INT 未找到 config_id_old={config_id} 对应的配置")
return memory_config.config_id

View File

@@ -1,4 +1,32 @@
{
"v0.2.2": {
"introduction": {
"codeName": "淬锋Temper",
"releaseDate": "2026-1-31",
"upgradePosition": "本次发布聚焦平台稳定性和性能优化。正如\"淬锋\"之名——千锤百炼,淬火成锋,我们通过严格测试和修复打磨系统品质。引入 Agent 工作流的代码执行能力、改进模型并发管理,并修复了记忆系统的多个关键问题。",
"coreUpgrades": [
"1. Agent平台增强<br>* 模型并发管理:优化模型广场的并发请求处理和资源分配能力。",
"2. 记忆系统优化<br>* Celery 队列修复:解决任务队列问题,提升异步记忆处理的可靠性<br>* 记忆 Agent 优化:提升记忆 Agent 的性能和效率<br>* 接口响应速度优化:优化记忆接口响应时间,加快操作速度。",
"3. 情绪记忆与识别升级<br>* 情绪记忆角色识别修复:解决情绪记忆上下文中的角色/人物识别问题<br>* 角色识别增强:提升对话记忆中的角色/人物识别准确性。",
"<br>",
"MemoryBear 持续致力于为 AI 应用提供类人记忆能力。本次以稳定性为核心的发布,进一步夯实了「感知→精炼→关联→遗忘」范式的基础。",
"未来版本将在此坚实基础上,扩展 Agent 能力并深化记忆智能特性。"
]
},
"introduction_en": {
"codeName": "Temper (淬锋)",
"releaseDate": "2026-1-31",
"upgradePosition": "This release focuses on platform stability and performance optimization — true to its codename \"淬锋\" (tempered blade), we've refined the system through rigorous testing and fixes. Introducing Python code execution for Agent workflows, improved model concurrency management, and critical fixes across the memory system.",
"coreUpgrades": [
"1. Agent Platform Enhancements<br>* Model Concurrency Management: Enhanced Model Plaza with improved concurrent model request handling and resource allocation.",
"2. Memory System Improvements<br>* Celery Queue Fix: Resolved task queue issues for more reliable asynchronous memory processing<br>* Memory Agent Optimization: Improved memory Agent performance and efficiency<br>* API Response Speed: Optimized memory interface response times for faster operations.",
"3. Emotional Memory & Recognition Upgrades<br>* Emotion Memory Role Recognition Fix: Resolved issues with role/character identification in emotional memory contexts<br>* Role Recognition Enhancement: Improved character/role identification accuracy in conversation memory.",
"<br>",
"MemoryBear continues advancing toward human-like memory capabilities for AI applications. This stability-focused release strengthens the foundation for our Perception → Refinement → Association → Forgetting paradigm.",
"Future releases will build on this solid base with expanded Agent capabilities and deeper memory intelligence features."
]
}
},
"v0.2.1": {
"introduction": {
"codeName": "启知",

View File

@@ -19,6 +19,7 @@ services:
depends_on:
- worker-memory
- worker-document
- worker-periodic
# Memory worker - Memory read/write tasks (threads pool for asyncio)
worker-memory:
@@ -48,6 +49,20 @@ services:
networks:
- celery
# Periodic worker - Scheduled/beat tasks (prefork, low concurrency)
worker-periodic:
image: redbear-mem-open:latest
container_name: worker-periodic
env_file:
- .env
volumes:
- ./files:/files
- /etc/localtime:/etc/localtime:ro
command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=prefork --concurrency=2 --queues=periodic_tasks --max-tasks-per-child=50 -n periodic_worker@%h
restart: unless-stopped
networks:
- celery
# Celery Beat - scheduler
beat:
image: redbear-mem-open:latest
@@ -69,7 +84,7 @@ services:
container_name: sandbox
ports:
- "8194"
command: /code/.venv/bin/python main.py
command: /code/.venv/bin/uvicorn main:app --host 0.0.0.0 --port 8194 --log-level debug
restart: unless-stopped
networks:
- sandbox

View File

@@ -1,4 +1,9 @@
# Language Configuration
# Supported values: "zh" (Chinese), "en" (English)
# This controls the language used for memory summary titles and other generated content
DEFAULT_LANGUAGE=zh
# Neo4j Configuration (记忆系统数据库)
NEO4J_URI=
NEO4J_USERNAME=

View File

@@ -28,7 +28,15 @@ def upgrade() -> None:
op.drop_constraint('data_config_pkey', 'memory_config', type_='primary')
op.alter_column('memory_config', 'config_id', new_column_name='config_id_old', nullable=True)
op.add_column('memory_config', sa.Column('config_id', sa.UUID(), nullable=True))
op.execute("UPDATE memory_config SET config_id = apply_id::uuid")
# Handle rows where apply_id might be NULL or invalid - generate new UUIDs for those
op.execute("""
UPDATE memory_config
SET config_id = CASE
WHEN apply_id IS NOT NULL AND apply_id ~ '^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$'
THEN apply_id::uuid
ELSE gen_random_uuid()
END
""")
op.alter_column('memory_config', 'config_id', nullable=False)
op.create_primary_key('memory_config_pkey', 'memory_config', ['config_id'])
op.execute("ALTER TABLE memory_config ALTER COLUMN config_id_old DROP DEFAULT")

View File

@@ -0,0 +1,78 @@
"""202601301521
Revision ID: 550c10595967
Revises: 5de9b1e28509
Create Date: 2026-01-30 15:24:34.647440
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '550c10595967'
down_revision: Union[str, None] = '5de9b1e28509'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('ontology_scene',
sa.Column('scene_id', sa.UUID(), nullable=False, comment='场景ID'),
sa.Column('scene_name', sa.String(length=200), nullable=False, comment='场景名称'),
sa.Column('scene_description', sa.Text(), nullable=True, comment='场景描述'),
sa.Column('workspace_id', sa.UUID(), nullable=False, comment='所属工作空间ID'),
sa.Column('created_at', sa.DateTime(), nullable=False, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), nullable=False, comment='更新时间'),
sa.ForeignKeyConstraint(['workspace_id'], ['workspaces.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('scene_id'),
sa.UniqueConstraint('workspace_id', 'scene_name', name='uq_workspace_scene_name')
)
op.create_index(op.f('ix_ontology_scene_scene_id'), 'ontology_scene', ['scene_id'], unique=False)
op.create_index(op.f('ix_ontology_scene_workspace_id'), 'ontology_scene', ['workspace_id'], unique=False)
op.create_table('ontology_class',
sa.Column('class_id', sa.UUID(), nullable=False, comment='类型ID'),
sa.Column('class_name', sa.String(length=200), nullable=False, comment='类型名称'),
sa.Column('class_description', sa.Text(), nullable=True, comment='类型描述'),
sa.Column('scene_id', sa.UUID(), nullable=False, comment='所属场景ID'),
sa.Column('created_at', sa.DateTime(), nullable=False, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), nullable=False, comment='更新时间'),
sa.ForeignKeyConstraint(['scene_id'], ['ontology_scene.scene_id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('class_id')
)
op.create_index(op.f('ix_ontology_class_class_id'), 'ontology_class', ['class_id'], unique=False)
op.create_index(op.f('ix_ontology_class_scene_id'), 'ontology_class', ['scene_id'], unique=False)
op.create_table('prompt_history',
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('tenant_id', sa.UUID(), nullable=False, comment='Tenant ID'),
sa.Column('session_id', sa.UUID(), nullable=False, comment='Session ID'),
sa.Column('title', sa.String(), nullable=False, comment='Title'),
sa.Column('prompt', sa.Text(), nullable=False, comment='Prompt'),
sa.Column('created_at', sa.DateTime(), nullable=True, comment='Creation Time'),
sa.Column('is_delete', sa.Boolean(), nullable=True, comment='Delete'),
sa.ForeignKeyConstraint(['session_id'], ['prompt_opt_session_list.id'], ),
sa.ForeignKeyConstraint(['tenant_id'], ['tenants.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_prompt_history_created_at'), 'prompt_history', ['created_at'], unique=False)
op.create_index(op.f('ix_prompt_history_id'), 'prompt_history', ['id'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_prompt_history_id'), table_name='prompt_history')
op.drop_index(op.f('ix_prompt_history_created_at'), table_name='prompt_history')
op.drop_table('prompt_history')
op.drop_index(op.f('ix_ontology_class_scene_id'), table_name='ontology_class')
op.drop_index(op.f('ix_ontology_class_class_id'), table_name='ontology_class')
op.drop_table('ontology_class')
op.drop_index(op.f('ix_ontology_scene_workspace_id'), table_name='ontology_scene')
op.drop_index(op.f('ix_ontology_scene_scene_id'), table_name='ontology_scene')
op.drop_table('ontology_scene')
# ### end Alembic commands ###

View File

@@ -0,0 +1,80 @@
"""20260129212722
Revision ID: 5de9b1e28509
Revises: 5ca246ee7dd4
Create Date: 2026-01-29 21:34:30.978031
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '5de9b1e28509'
down_revision: Union[str, None] = '5ca246ee7dd4'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Neo4j migration: rename group_id to end_user_id
import asyncio
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def run_neo4j_upgrade():
connector = Neo4jConnector()
try:
async def transaction_func(tx):
result = await tx.run("""
MATCH (n)
WHERE n.group_id IS NOT NULL
SET n.end_user_id = n.group_id
REMOVE n.group_id
WITH count(n) AS node_count
MATCH ()-[r]->()
WHERE r.group_id IS NOT NULL
SET r.end_user_id = r.group_id
REMOVE r.group_id
RETURN node_count, count(r) AS rel_count
""")
return await result.data()
await connector.execute_write_transaction(transaction_func)
finally:
await connector.close()
asyncio.run(run_neo4j_upgrade())
def downgrade() -> None:
# Neo4j migration: rename end_user_id back to group_id
import asyncio
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def run_neo4j_downgrade():
connector = Neo4jConnector()
try:
async def transaction_func(tx):
result = await tx.run("""
MATCH (n)
WHERE n.end_user_id IS NOT NULL
SET n.group_id = n.end_user_id
REMOVE n.end_user_id
WITH count(n) AS node_count
MATCH ()-[r]->()
WHERE r.end_user_id IS NOT NULL
SET r.group_id = r.end_user_id
REMOVE r.end_user_id
RETURN node_count, count(r) AS rel_count
""")
return await result.data()
await connector.execute_write_transaction(transaction_func)
finally:
await connector.close()
asyncio.run(run_neo4j_downgrade())

View File

@@ -0,0 +1,30 @@
"""202601301850
Revision ID: 9def72f79398
Revises: 550c10595967
Create Date: 2026-01-30 18:51:18.290796
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '9def72f79398'
down_revision: Union[str, None] = '550c10595967'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('memory_config', sa.Column('scene_id', sa.UUID(), nullable=True, comment='本体场景ID关联ontology_scene表'))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('memory_config', 'scene_id')
# ### end Alembic commands ###

View File

@@ -140,6 +140,7 @@ dependencies = [
"oss2>=2.19.1",
"flower>=2.0.1",
"aiofiles>=23.0.0",
"owlready2>=0.46",
]
[tool.pytest.ini_options]

View File

@@ -1,9 +1,10 @@
FROM python:3.12-slim
USER root
WORKDIR /code
LABEL authors="Eterntiy"
ARG NEED_MIRROR=0
ARG NEED_MIRROR=1
ENV DEBIAN_FRONTEND=noninteractive
RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \
if [ "$NEED_MIRROR" == "1" ]; then \
@@ -17,11 +18,14 @@ RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \
apt --no-install-recommends install -y ca-certificates && \
apt update && \
apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
apt install -y nodejs npm && \
apt-get install -y --no-install-recommends tzdata libseccomp2 libseccomp-dev && \
ln -snf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \
echo "Asia/Shanghai" > /etc/timezone && \
apt install -y cargo
ENV PYTHONDONTWRITEBYTECODE=1
COPY ./app /code/app
COPY ./dependencies /code/dependencies
COPY ./lib /code/lib
@@ -33,10 +37,15 @@ COPY ./requirements.txt /code/requirements.txt
RUN python -m venv .venv
RUN .venv/bin/python3 -m pip install -r requirements.txt
RUN cargo build --release --manifest-path lib/seccomp_python/Cargo.toml
RUN npm install --prefix=/code/dependencies/nodejs koffi
HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
RUN cargo build --release --manifest-path lib/seccomp_redbear/Cargo.toml --features python3
RUN mv lib/seccomp_redbear/target/release/libsandbox.so lib/seccomp_redbear/target/release/libpython.so
RUN cargo build --release --manifest-path lib/seccomp_redbear/Cargo.toml --features nodejs
RUN mv lib/seccomp_redbear/target/release/libsandbox.so lib/seccomp_redbear/target/release/libnodejs.so
HEALTHCHECK --interval=30s --timeout=5s --start-period=60s --retries=3 \
CMD curl 127.0.0.1:8194/health
CMD [".venv/bin/python3", "main.py"]
CMD [".venv/bin/uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8194", "--log-level", "debug"]

4
sandbox/app/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/1/29 14:33

View File

@@ -4,9 +4,6 @@ from typing import List, Optional
from pydantic import BaseModel, Field
import yaml
SANDBOX_USER_ID = 1000
SANDBOX_GROUP_ID = 1000
DEFAULT_PYTHON_LIB_REQUIREMENTS_AMD = [
"/usr/local/lib/python3.12",
"/usr/lib/python3",
@@ -15,13 +12,18 @@ DEFAULT_PYTHON_LIB_REQUIREMENTS_AMD = [
"/etc/nsswitch.conf",
"/etc/hosts",
"/etc/resolv.conf",
"/run/systemd/resolve/stub-resolv.conf",
"/run/resolvconf/resolv.conf",
"/etc/localtime",
"/usr/share/zoneinfo",
"/etc/timezone",
]
DEFAULT_NODEJS_LIB_REQUIREMENTS = [
"/etc/ssl/certs/ca-certificates.crt",
"/etc/nsswitch.conf",
"/etc/resolv.conf",
"/etc/hosts",
]
class AppConfig(BaseModel):
"""Application configuration"""
@@ -43,83 +45,77 @@ class Config(BaseModel):
max_workers: int = 4
max_requests: int = 50
worker_timeout: int = 30
nodejs_path: str = "node"
enable_network: bool = True
enable_preload: bool = False
python_path: str = ""
python_lib_paths: list = Field(default=DEFAULT_PYTHON_LIB_REQUIREMENTS_AMD)
python_deps_update_interval: str = "30m"
nodejs_path: str = ""
nodejs_lib_paths: list = Field(default=DEFAULT_NODEJS_LIB_REQUIREMENTS)
allowed_syscalls: List[int] = Field(default_factory=list)
proxy: ProxyConfig = Field(default_factory=ProxyConfig)
sandbox_user: str = "sandbox"
sandbox_uid: int = 65537
sandbox_gid: int = 0
def set_sandbox_gid(self, gid: int):
"""Update sandbox GID dynamically"""
self.sandbox_gid = gid
def override_with_env(self):
"""Override configuration with environment variables"""
env_map = {
"DEBUG": ("app.debug", lambda v: v.lower() in ("true", "1", "yes")),
"MAX_WORKERS": ("max_workers", int),
"MAX_REQUESTS": ("max_requests", int),
"SANDBOX_PORT": ("app.port", int),
"WORKER_TIMEOUT": ("worker_timeout", int),
"API_KEY": ("app.key", str),
"NODEJS_PATH": ("nodejs_path", str),
"ENABLE_NETWORK": ("enable_network", lambda v: v.lower() in ("true", "1", "yes")),
"ENABLE_PRELOAD": ("enable_preload", lambda v: v.lower() in ("true", "1", "yes")),
"ALLOWED_SYSCALLS": ("allowed_syscalls", lambda v: [int(x) for x in v.split(",")]),
"SOCKS5_PROXY": ("proxy.socks5", str),
"HTTP_PROXY": ("proxy.http", str),
"HTTPS_PROXY": ("proxy.https", str),
"PYTHON_PATH": ("python_path", str),
"PYTHON_LIB_PATH": ("python_lib_paths", lambda v: v.split(",")),
"PYTHON_DEPS_UPDATE_INTERVAL": ("python_deps_update_interval", str),
"NODEJS_LIB_PATH": ("nodejs_lib_paths", lambda v: v.split(",")),
}
for env_var, (attr_path, cast) in env_map.items():
value = os.getenv(env_var)
if value is not None:
# Support nested attributes like 'app.debug'
parts = attr_path.split(".")
obj = self
for part in parts[:-1]:
obj = getattr(obj, part)
setattr(obj, parts[-1], cast(value))
# Global configuration instance
_config: Optional[Config] = None
def load_config(config_path: str) -> Config:
"""Load configuration from YAML file"""
def load_config(config_path: str = "config.yaml") -> Config:
"""Load configuration from YAML file and override with env variables"""
global _config
# Load from file
if os.path.exists(config_path):
with open(config_path, 'r') as f:
data = yaml.safe_load(f)
data = yaml.safe_load(f) or {}
_config = Config(**data)
else:
_config = Config()
# Override with environment variables
if os.getenv("DEBUG"):
_config.app.debug = os.getenv("DEBUG").lower() in ("true", "1", "yes")
if os.getenv("MAX_WORKERS"):
_config.max_workers = int(os.getenv("MAX_WORKERS"))
if os.getenv("MAX_REQUESTS"):
_config.max_requests = int(os.getenv("MAX_REQUESTS"))
if os.getenv("SANDBOX_PORT"):
_config.app.port = int(os.getenv("SANDBOX_PORT"))
if os.getenv("WORKER_TIMEOUT"):
_config.worker_timeout = int(os.getenv("WORKER_TIMEOUT"))
if os.getenv("API_KEY"):
_config.app.key = os.getenv("API_KEY")
if os.getenv("NODEJS_PATH"):
_config.nodejs_path = os.getenv("NODEJS_PATH")
if os.getenv("ENABLE_NETWORK"):
_config.enable_network = os.getenv("ENABLE_NETWORK").lower() in ("true", "1", "yes")
if os.getenv("ENABLE_PRELOAD"):
_config.enable_preload = os.getenv("ENABLE_PRELOAD").lower() in ("true", "1", "yes")
if os.getenv("ALLOWED_SYSCALLS"):
_config.allowed_syscalls = [int(x) for x in os.getenv("ALLOWED_SYSCALLS").split(",")]
if os.getenv("SOCKS5_PROXY"):
_config.proxy.socks5 = os.getenv("SOCKS5_PROXY")
if os.getenv("HTTP_PROXY"):
_config.proxy.http = os.getenv("HTTP_PROXY")
if os.getenv("HTTPS_PROXY"):
_config.proxy.https = os.getenv("HTTPS_PROXY")
# python
if os.getenv("PYTHON_PATH"):
_config.python_path = os.getenv("PYTHON_PATH")
if os.getenv("PYTHON_LIB_PATH"):
_config.python_lib_paths = os.getenv("PYTHON_LIB_PATH").split(',')
if os.getenv("PYTHON_DEPS_UPDATE_INTERVAL"):
_config.python_deps_update_interval = os.getenv("PYTHON_DEPS_UPDATE_INTERVAL")
# Override from environment
_config.override_with_env()
return _config

View File

@@ -9,4 +9,4 @@ router = APIRouter()
@router.get("/health", response_model=HealthResponse)
async def health_check():
"""Health check endpoint"""
return HealthResponse(status="healthy", version="2.0.0")
return HealthResponse(status="healthy", version="0.1.0")

View File

@@ -2,13 +2,15 @@
from fastapi import APIRouter, Depends
from app.middleware.auth import verify_api_key
from app.middleware.concurrency import check_max_requests, acquire_worker
from app.middleware.concurrency import concurrency_guard
from app.models import (
RunCodeRequest,
ApiResponse,
UpdateDependencyRequest,
error_response
)
from app.services.nodejs_service import run_nodejs_code
from app.services.python_service import (
run_python_code,
list_python_dependencies,
@@ -25,16 +27,14 @@ router = APIRouter(
@router.post(
"/run",
response_model=ApiResponse,
dependencies=[Depends(check_max_requests),
Depends(acquire_worker)]
dependencies=[Depends(concurrency_guard)]
)
async def run_code(request: RunCodeRequest):
"""Execute code in sandbox"""
if request.language == "python3":
return await run_python_code(request.code, request.preload, request.options)
elif request.language == "nodejs":
# TODO
return error_response(-400, "TODO")
return await run_nodejs_code(request.code, request.preload, request.options)
else:
return error_response(-400, "unsupported language")
@@ -55,5 +55,3 @@ async def update_dependencies(request: UpdateDependencyRequest):
return await update_python_dependencies()
else:
return error_response(-400, "unsupported language")

View File

@@ -1 +1,40 @@
"""Code runners package"""
import pwd
import subprocess
from app.config import get_config
from app.logger import get_logger
logger = get_logger()
def init_sandbox_user():
config = get_config()
sandbox_user = config.sandbox_user
sandbox_uid = config.sandbox_uid
try:
pwd.getpwnam(sandbox_user)
logger.info(f"User '{sandbox_user}' already exists")
except KeyError:
try:
subprocess.run(
["useradd", "-u", str(sandbox_uid), sandbox_user],
check=True,
capture_output=True,
text=True
)
logger.info(f"Created user '{sandbox_user}' with UID {sandbox_uid}")
except subprocess.CalledProcessError as e:
logger.error(f"Failed to create user: {e.stderr}")
raise RuntimeError(f"Failed to create user '{sandbox_user}': {e.stderr}") from e
try:
user_info = pwd.getpwnam(sandbox_user)
config.set_sandbox_gid(user_info.pw_gid)
logger.info(f"Sandbox user GID: {config.sandbox_gid}")
except KeyError as e:
logger.error(f"Failed to get GID for user '{sandbox_user}'")
raise RuntimeError(f"Failed to get GID for user '{sandbox_user}'") from e

View File

@@ -0,0 +1,3 @@
from app.core.runners.nodejs.env import release_lib_binary
release_lib_binary(True)

View File

@@ -0,0 +1,124 @@
import asyncio
import ctypes
import os
import shutil
import stat
import tempfile
from pathlib import Path
from app.logger import get_logger
from app.config import get_config
logger = get_logger()
RELEASE_LIB_PATH = "./lib/seccomp_redbear/target/release/libnodejs.so"
LIB_PATH = "/var/sandbox/sandbox-nodejs"
LIB_NAME = "libnodejs.so"
lib = ctypes.CDLL(RELEASE_LIB_PATH)
lib.get_lib_version_static.restype = ctypes.c_char_p
lib.get_lib_feature_static.restype = ctypes.c_char_p
logger.info(f"Seccomp Env: nodejs, "
f"Seccomp Feature: {lib.get_lib_feature_static().decode('utf-8')}, "
f"Seccomp Version: {lib.get_lib_version_static().decode('utf-8')}")
try:
with open(RELEASE_LIB_PATH, "rb") as f:
_NODEJS_LIB = f.read()
except:
logger.critical("failed to load nodejs lib")
raise
def check_lib_avaiable():
return os.path.exists(os.path.join(LIB_PATH, LIB_NAME))
def release_lib_binary(force_remove: bool):
logger.info("init runtime enviroment")
lib_file = os.path.join(LIB_PATH, LIB_NAME)
if os.path.exists(lib_file):
if force_remove:
try:
os.remove(lib_file)
except OSError:
logger.critical(f"failed to remove {os.path.join(LIB_PATH, LIB_NAME)}")
raise
try:
os.makedirs(LIB_PATH, mode=0o755, exist_ok=True)
except OSError:
logger.critical(f"failed to create {LIB_PATH}")
raise
try:
with open(lib_file, "wb") as f:
f.write(_NODEJS_LIB)
os.chmod(lib_file, 0o755)
except OSError:
logger.critical(f"failed to write {lib_file}")
raise
else:
try:
os.makedirs(LIB_PATH, mode=0o755, exist_ok=True)
except OSError:
logger.critical(f"failed to create {LIB_PATH}")
raise
try:
with open(lib_file, "wb") as f:
f.write(_NODEJS_LIB)
os.chmod(lib_file, 0o755)
except OSError:
logger.critical(f"failed to write {lib_file}")
raise
logger.info("nodejs runner environment initialized")
async def prepare_nodejs_dependencies_env():
config = get_config()
with tempfile.TemporaryDirectory(dir="/") as root_path:
root = Path(root_path)
env_sh = root / "env.sh"
with open("script/env.sh") as f:
env_sh.write_text(f.read())
env_sh.chmod(env_sh.stat().st_mode | stat.S_IXUSR)
shutil.copytree("dependencies/nodejs", os.path.join(LIB_PATH, "node_temp"), dirs_exist_ok=True)
for root, dirs, files in os.walk(os.path.join(LIB_PATH, "node_temp")):
for d in dirs:
os.chmod(os.path.join(root, d), 0o755)
for f in files:
os.chmod(os.path.join(root, f), 0o444)
for lib_path in config.nodejs_lib_paths:
lib_path = Path(lib_path)
if not lib_path.exists():
logger.warning("nodejs lib path %s is not available", lib_path)
continue
cmd = [
"bash",
str(env_sh),
str(lib_path),
str(LIB_PATH),
]
process = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
retcode = process.returncode
if retcode != 0:
logger.error(
f"create env error for file {lib_path}: retcode={retcode}, stderr={stderr.decode()}"
)

View File

@@ -0,0 +1,138 @@
"""Nodejs code runner"""
import asyncio
import os
import uuid
from typing import Optional
from app.core.executor import CodeExecutor, ExecutionResult
from app.core.runners.nodejs.env import check_lib_avaiable, release_lib_binary, LIB_PATH
from app.logger import get_logger
from app.models import RunnerOptions
# Nodejs sandbox prescript template
with open("app/core/runners/nodejs/prescript.js") as f:
NODEJS_PRESCRIPT = f.read()
logger = get_logger()
class NodejsRunner(CodeExecutor):
"""Node.js code runner with security isolation"""
def __init__(self):
super().__init__()
@staticmethod
def init_environment(code: str, preload: str) -> str:
if not check_lib_avaiable():
release_lib_binary(False)
code_file_name = uuid.uuid4().hex.replace("-", "_")
script = NODEJS_PRESCRIPT.replace("{{preload}}", preload, 1)
eval_code = f"eval(Buffer.from('{code}', 'base64').toString('utf-8'))"
script = script.replace("{{code}}", eval_code, 1)
code_path = f"{LIB_PATH}/node_temp/tmp/{code_file_name}.js"
try:
os.makedirs(os.path.dirname(code_path), mode=0o755, exist_ok=True)
with open(code_path, "w", encoding="utf-8") as f:
f.write(script)
os.chmod(code_path, 0o755)
except OSError as e:
raise RuntimeError(f"Failed to write {code_path}") from e
return code_path
async def run(
self,
code: str,
options: RunnerOptions,
preload: str = "",
timeout: Optional[int] = None
) -> ExecutionResult:
"""Run Python code in sandbox
Args:
options:
code: Base64 encoded encrypted code
preload: Preload code to execute before main code
timeout: Execution timeout in seconds
Returns:
ExecutionResult with stdout, stderr, and exit code
"""
config = self.config
if timeout is None:
timeout = config.worker_timeout
# Check if preload is allowed
if not preload or not config.enable_preload:
preload = ""
script_path = self.init_environment(code, preload)
try:
# Setup environment
env = {
"UV_USE_IO_URING": "0"
}
# Add proxy settings if configured
if config.proxy.socks5:
env["HTTPS_PROXY"] = config.proxy.socks5
env["HTTP_PROXY"] = config.proxy.socks5
elif config.proxy.https or config.proxy.http:
if config.proxy.https:
env["HTTPS_PROXY"] = config.proxy.https
if config.proxy.http:
env["HTTP_PROXY"] = config.proxy.http
# Add allowed syscalls if configured
if config.allowed_syscalls:
env["ALLOWED_SYSCALLS"] = ",".join(map(str, config.allowed_syscalls))
process = await asyncio.create_subprocess_exec(
config.nodejs_path,
script_path,
LIB_PATH,
str(config.sandbox_uid),
str(config.sandbox_gid),
options.model_dump_json(),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=env,
cwd=LIB_PATH
)
# Wait for completion with timeout
try:
stdout, stderr = await asyncio.wait_for(
process.communicate(),
timeout=timeout
)
return ExecutionResult(
stdout=stdout.decode('utf-8', errors='replace'),
stderr=stderr.decode('utf-8', errors='replace'),
exit_code=process.returncode
)
except asyncio.TimeoutError:
# Kill process on timeout
try:
process.kill()
await process.wait()
except:
pass
return ExecutionResult(
stdout="",
stderr="Execution timeout",
exit_code=-1,
)
finally:
# Cleanup temporary file
self.cleanup_temp_file(script_path)

View File

@@ -0,0 +1,31 @@
let argv = process.argv
let koffi = require('koffi')
process.chdir(argv[2])
let lib = koffi.load("./libnodejs.so")
/** @type {(uid: number, gid: number, enableNetwork: boolean) => number} */
let initSeccomp = lib.func('int init_seccomp(int, int, bool)')
let uid = parseInt(argv[3])
let gid = parseInt(argv[4])
let options = JSON.parse(argv[5])
let seccomp_init = initSeccomp(uid, gid, options['enable_network'])
if (seccomp_init !== 0) {
throw `code executor err - ${seccomp_init}`
}
delete process.argv
argv = undefined
koffi = undefined
lib = undefined
initSeccomp = undefined
uid = undefined
gid = undefined
options = undefined
seccomp_init = undefined
{{code}}

View File

@@ -1,4 +1,3 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/1/23 11:27
from app.core.runners.python.env import release_lib_binary
release_lib_binary(True)

View File

@@ -1,14 +1,80 @@
import asyncio
import tempfile
import ctypes
import os
import stat
import tempfile
from pathlib import Path
from app.config import get_config
from app.core.runners.python.settings import LIB_PATH
from app.logger import get_logger
logger = get_logger()
RELEASE_LIB_PATH = "./lib/seccomp_redbear/target/release/libpython.so"
LIB_PATH = "/var/sandbox/sandbox-python"
LIB_NAME = "libpython.so"
lib = ctypes.CDLL(RELEASE_LIB_PATH)
lib.get_lib_version_static.restype = ctypes.c_char_p
lib.get_lib_feature_static.restype = ctypes.c_char_p
logger.info(f"Seccomp Env: python3, "
f"Seccomp Feature: {lib.get_lib_feature_static().decode('utf-8')}, "
f"Seccomp Version: {lib.get_lib_version_static().decode('utf-8')}")
try:
with open(RELEASE_LIB_PATH, "rb") as f:
_PYTHON_LIB = f.read()
except:
logger.critical("failed to load python lib")
raise
def check_lib_avaiable():
return os.path.exists(os.path.join(LIB_PATH, LIB_NAME))
def release_lib_binary(force_remove: bool):
logger.info("init runtime enviroment")
lib_file = os.path.join(LIB_PATH, LIB_NAME)
if os.path.exists(lib_file):
if force_remove:
try:
os.remove(lib_file)
except OSError:
logger.critical(f"failed to remove {os.path.join(LIB_PATH, LIB_NAME)}")
raise
try:
os.makedirs(LIB_PATH, mode=0o755, exist_ok=True)
except OSError:
logger.critical(f"failed to create {LIB_PATH}")
raise
try:
with open(lib_file, "wb") as f:
f.write(_PYTHON_LIB)
os.chmod(lib_file, 0o755)
except OSError:
logger.critical(f"failed to write {lib_file}")
raise
else:
try:
os.makedirs(LIB_PATH, mode=0o755, exist_ok=True)
except OSError:
logger.critical(f"failed to create {LIB_PATH}")
raise
try:
with open(lib_file, "wb") as f:
f.write(_PYTHON_LIB)
os.chmod(lib_file, 0o755)
except OSError:
logger.critical(f"failed to write {lib_file}")
raise
logger.info("python runner environment initialized")
async def prepare_python_dependencies_env():
config = get_config()

View File

@@ -17,7 +17,7 @@ sys.excepthook = excepthook
# Load security library if available
lib = ctypes.CDLL("./libpython.so")
lib.init_seccomp.argtypes = [ctypes.c_uint32, ctypes.c_uint32, ctypes.c_bool]
lib.init_seccomp.restype = None # TODO: raise error info
lib.init_seccomp.restype = ctypes.c_int
# Get running path
running_path = sys.argv[1]
@@ -37,7 +37,10 @@ os.chdir(running_path)
{{preload}}
# Apply security if library is available
lib.init_seccomp({{uid}}, {{gid}}, {{enable_network}})
init_status = lib.init_seccomp({{uid}}, {{gid}}, {{enable_network}})
if init_status != 0:
raise Exception(f"code executor err - {str(init_status)}")
del lib
# Decrypt and execute code
code = b64decode("{{code}}")

View File

@@ -5,10 +5,10 @@ import os
import uuid
from typing import Optional
from app.config import SANDBOX_USER_ID, SANDBOX_GROUP_ID, get_config
from app.config import get_config
from app.core.encryption import generate_key, encrypt_code
from app.core.executor import CodeExecutor, ExecutionResult
from app.core.runners.python.settings import check_lib_avaiable, release_lib_binary, LIB_PATH
from app.core.runners.python.env import check_lib_avaiable, release_lib_binary, LIB_PATH
from app.logger import get_logger
from app.models import RunnerOptions
@@ -32,8 +32,8 @@ class PythonRunner(CodeExecutor):
config = get_config()
code_file_name = uuid.uuid4().hex.replace("-", "_")
script = PYTHON_PRESCRIPT.replace("{{uid}}", str(SANDBOX_USER_ID), 1)
script = script.replace("{{gid}}", str(SANDBOX_GROUP_ID), 1)
script = PYTHON_PRESCRIPT.replace("{{uid}}", str(config.sandbox_uid), 1)
script = script.replace("{{gid}}", str(config.sandbox_gid), 1)
script = script.replace(
"{{enable_network}}",
str(int(options.enable_network and config.enable_network)

View File

@@ -1,62 +0,0 @@
import os
from app.logger import get_logger
logger = get_logger()
RELEASE_LIB_PATH = "./lib/seccomp_python/target/release/libpython.so"
LIB_PATH = "/var/sandbox/sandbox-python"
LIB_NAME = "libpython.so"
try:
with open(RELEASE_LIB_PATH, "rb") as f:
_PYTHON_LIB = f.read()
except:
logger.critical("failed to load python lib")
raise
def check_lib_avaiable():
return os.path.exists(os.path.join(LIB_PATH, LIB_NAME))
def release_lib_binary(force_remove: bool):
logger.info("init runtime enviroment")
lib_file = os.path.join(LIB_PATH, LIB_NAME)
if os.path.exists(lib_file):
if force_remove:
try:
os.remove(lib_file)
except OSError:
logger.critical(f"failed to remove {os.path.join(LIB_PATH, LIB_NAME)}")
raise
try:
os.makedirs(LIB_PATH, mode=0o755, exist_ok=True)
except OSError:
logger.critical(f"failed to create {LIB_PATH}")
raise
try:
with open(lib_file, "wb") as f:
f.write(_PYTHON_LIB)
os.chmod(lib_file, 0o755)
except OSError:
logger.critical(f"failed to write {lib_file}")
raise
else:
try:
os.makedirs(LIB_PATH, mode=0o755, exist_ok=True)
except OSError:
logger.critical(f"failed to create {LIB_PATH}")
raise
try:
with open(lib_file, "wb") as f:
f.write(_PYTHON_LIB)
os.chmod(lib_file, 0o755)
except OSError:
logger.critical(f"failed to write {lib_file}")
raise
logger.info("python runner environment initialized")

View File

@@ -4,6 +4,7 @@ from pathlib import Path
from typing import List, Dict
from app.config import get_config
from app.core.runners.nodejs.env import prepare_nodejs_dependencies_env
from app.core.runners.python.env import prepare_python_dependencies_env
from app.logger import get_logger
@@ -19,7 +20,10 @@ async def setup_dependencies():
logger.info("Preparing Python dependencies environment...")
await prepare_python_dependencies_env()
logger.info("Python dependencies environment ready")
logger.info("Python Environment Ready ....")
logger.info("Preparing Nodejs dependencies environment...")
await prepare_nodejs_dependencies_env()
logger.info("Nodejs Environment Ready ...")
except Exception as e:
logger.error(f"Failed to setup dependencies: {e}")
@@ -36,7 +40,7 @@ async def install_python_dependencies():
config = get_config()
# Check if requirements file exists
req_file = Path("dependencies/python-requirements.txt")
req_file = Path("dependencies/python/python-requirements.txt")
if not req_file.exists():
logger.warning("Python requirements file not found, skipping installation")
return

View File

@@ -12,25 +12,27 @@ def setup_logger() -> logging.Logger:
"""Setup application logger"""
global _logger
if _logger is not None:
return _logger
config = get_config()
# Create logger
_logger = logging.getLogger("sandbox")
_logger.setLevel(logging.DEBUG if config.app.debug else logging.INFO)
# Create console handler
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.DEBUG if config.app.debug else logging.INFO)
# 只在 logger 没有 handler 时才添加
if not _logger.handlers:
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.DEBUG if config.app.debug else logging.INFO)
# Create formatter
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
handler.setFormatter(formatter)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
handler.setFormatter(formatter)
# Add handler to logger
_logger.addHandler(handler)
_logger.addHandler(handler)
return _logger

View File

@@ -1,48 +1,66 @@
"""Concurrency control middleware"""
"""
Concurrency control middleware
"""
import asyncio
from contextlib import asynccontextmanager
from fastapi import HTTPException, status
from app.config import get_config
from app.models import error_response
from app.logger import get_logger
logger = get_logger()
# Global semaphores
_worker_semaphore: None | asyncio.Semaphore = None
_request_counter = 0
_request_lock = asyncio.Lock()
class ConcurrencyController:
def __init__(self):
self._worker_semaphore: asyncio.Semaphore | None = None
self._request_counter = 0
self._lock = asyncio.Lock()
config = get_config()
self.max_requests = config.max_requests
def init(self):
config = get_config()
self._worker_semaphore = asyncio.Semaphore(config.max_workers)
async def _acquire_worker(self):
if self._worker_semaphore is None:
self.init()
async with self._worker_semaphore:
yield
async def _limit_requests(self):
async with self._lock:
logger.info(f"Current requests: {self._request_counter}/{self.max_requests}")
if self._request_counter >= self.max_requests:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail={
"code": 503,
"message": "Too many requests",
"data": None,
}
)
self._request_counter += 1
try:
yield
finally:
async with self._lock:
self._request_counter -= 1
def acquire_worker(self):
return asynccontextmanager(self._acquire_worker)()
def limit_requests(self):
return asynccontextmanager(self._limit_requests)()
def init_concurrency_control():
"""Initialize concurrency control"""
global _worker_semaphore
config = get_config()
_worker_semaphore = asyncio.Semaphore(config.max_workers)
concurrency = ConcurrencyController()
async def check_max_requests():
"""Check if max requests limit is reached"""
global _request_counter
config = get_config()
async with _request_lock:
if _request_counter >= config.max_requests:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=error_response(-503, "Too many requests")
)
_request_counter += 1
try:
yield
finally:
async with _request_lock:
_request_counter -= 1
async def acquire_worker():
"""Acquire a worker slot"""
if _worker_semaphore is None:
init_concurrency_control()
async with _worker_semaphore:
yield
async def concurrency_guard():
async with concurrency.limit_requests():
async with concurrency.acquire_worker():
yield

View File

@@ -0,0 +1,43 @@
"""Nodejs execution service"""
import signal
from app.core.runners.nodejs.nodejs_runner import NodejsRunner
from app.logger import get_logger
from app.models import (
success_response,
error_response,
RunCodeResponse,
RunnerOptions
)
async def run_nodejs_code(code: str, preload: str, options: RunnerOptions):
"""Execute Node.js code in sandbox
Args:
options:
code: Base64 encoded encrypted code
preload: Preload code
Returns:
API response with execution result
"""
logger = get_logger()
try:
runner = NodejsRunner()
result = await runner.run(code, options, preload)
if result.exit_code == signal.SIGSYS + 0x80:
return error_response(31, "sandbox security policy violation")
if result.exit_code != 0:
return error_response(500, result.stderr)
return success_response(RunCodeResponse(
stdout=result.stdout,
stderr=result.stderr
))
except Exception as e:
logger.error(f"Python execution failed: {e}", exc_info=True)
return error_response(-500, str(e))

View File

@@ -1,13 +1,11 @@
app:
port: 8194
debug: true
key: redbear-sandbox
max_workers: 4
max_requests: 50
worker_timeout: 30
max_workers: 10
max_requests: 300
worker_timeout: 15
python_path: /usr/local/bin/python
nodejs_path: /usr/local/bin/node
nodejs_path: /usr/bin/node
enable_network: true
enable_preload: false
python_deps_update_interval: 30m

View File

@@ -0,0 +1,6 @@
{
"name": "nodejs",
"lockfileVersion": 3,
"requires": true,
"packages": {}
}

View File

@@ -0,0 +1,6 @@
{
"name": "nodejs",
"lockfileVersion": 3,
"requires": true,
"packages": {}
}

View File

@@ -0,0 +1 @@
{}

View File

@@ -1,7 +0,0 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 4
[[package]]
name = "seccomp_nodejs"
version = "0.1.0"

View File

@@ -1,6 +0,0 @@
[package]
name = "seccomp_nodejs"
version = "0.1.0"
edition = "2024"
[dependencies]

View File

@@ -15,8 +15,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60276e2d41bbb68b323e566047a1bfbf952050b157d8b5cdc74c07c1bf4ca3b6"
[[package]]
name = "seccomp_python"
version = "0.1.0"
name = "seccomp_redbear"
version = "0.1.1"
dependencies = [
"libc",
"libseccomp-sys",

View File

@@ -1,12 +1,17 @@
[package]
name = "seccomp_python"
version = "0.1.0"
name = "seccomp_redbear"
version = "0.1.1"
edition = "2024"
[lib]
name = "python"
name = "sandbox"
crate-type = ["cdylib"]
[dependencies]
libc = "0.2.180"
libseccomp-sys = "0.3.0"
[features]
default = []
python3 = []
nodejs = []

View File

@@ -1,13 +1,25 @@
mod syscalls;
#[cfg(all(feature = "python3", feature = "nodejs"))]
compile_error!("Only one feature can be enabled: either python3 or nodejs, not both!");
use crate::syscalls::*;
use libc::{chdir, chroot, gid_t, uid_t, c_int};
#[cfg(not(any(feature = "python3", feature = "nodejs")))]
compile_error!("You must enable one feature: either python3 or nodejs");
#[cfg(feature = "python3")]
mod python_syscalls;
#[cfg(feature = "python3")]
use crate::python_syscalls::*;
#[cfg(feature = "nodejs")]
mod nodejs_syscalls;
#[cfg(feature = "nodejs")]
use crate::nodejs_syscalls::*;
use libc::{c_char, c_int, chdir, chroot, gid_t, uid_t};
use libseccomp_sys::*;
use std::env;
use std::ffi::CString;
use std::str::FromStr;
/*
* get_allowed_syscalls - retrieve allowed syscalls for the sandbox
* @enable_network: enable network-related syscalls if non-zero
@@ -193,3 +205,20 @@ pub unsafe extern "C" fn init_seccomp(uid: uid_t, gid: gid_t, enable_network: i3
Err(code) => code,
}
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn get_lib_version_static() -> *const c_char {
concat!(env!("CARGO_PKG_VERSION"), "\0").as_ptr() as *const c_char
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn get_lib_feature_static() -> *const c_char {
#[cfg(feature = "python3")]
let s = b"python3\0";
#[cfg(feature = "nodejs")]
let s = b"nodejs\0";
#[cfg(not(any(feature = "python3", feature = "nodejs")))]
let s = b"none\0";
s.as_ptr() as *const c_char
}

View File

@@ -0,0 +1,74 @@
// src/nodejs_syscalls.rs
pub static ALLOW_SYSCALLS: &[i32] = &[
// File IO
libc::SYS_open as i32,
libc::SYS_write as i32,
libc::SYS_close as i32,
libc::SYS_read as i32,
libc::SYS_openat as i32,
libc::SYS_newfstatat as i32,
libc::SYS_ioctl as i32,
libc::SYS_lseek as i32,
libc::SYS_fstat as i32,
libc::SYS_readlink as i32,
libc::SYS_dup3 as i32,
libc::SYS_fcntl as i32,
libc::SYS_fsync as i32,
// Memory
libc::SYS_mprotect as i32,
libc::SYS_mmap as i32,
libc::SYS_munmap as i32,
libc::SYS_mremap as i32,
libc::SYS_brk as i32,
libc::SYS_madvise as i32,
// Signal
libc::SYS_rt_sigaction as i32,
libc::SYS_rt_sigprocmask as i32,
libc::SYS_sigaltstack as i32,
libc::SYS_rt_sigreturn as i32,
libc::SYS_tgkill as i32,
// Thread
libc::SYS_futex as i32,
libc::SYS_sched_yield as i32,
libc::SYS_set_robust_list as i32,
libc::SYS_rseq as i32,
// User / Group
libc::SYS_getuid as i32,
// Process
libc::SYS_getpid as i32,
libc::SYS_gettid as i32,
libc::SYS_exit as i32,
libc::SYS_exit_group as i32,
libc::SYS_sched_getaffinity as i32,
// Time
libc::SYS_clock_gettime as i32,
libc::SYS_gettimeofday as i32,
libc::SYS_nanosleep as i32,
libc::SYS_time as i32,
// Epoll / Event (I/O multiplexing)
libc::SYS_epoll_ctl as i32,
libc::SYS_epoll_pwait as i32,
];
pub static ALLOW_ERROR_SYSCALLS: &[i32] = &[libc::SYS_clone as i32, libc::SYS_clone3 as i32];
pub static ALLOW_NETWORK_SYSCALLS: &[i32] = &[
libc::SYS_socket as i32,
libc::SYS_connect as i32,
libc::SYS_bind as i32,
libc::SYS_listen as i32,
libc::SYS_accept as i32,
libc::SYS_sendto as i32,
libc::SYS_recvfrom as i32,
libc::SYS_getsockname as i32,
libc::SYS_recvmsg as i32,
libc::SYS_getpeername as i32,
libc::SYS_setsockopt as i32,
libc::SYS_ppoll as i32,
libc::SYS_uname as i32,
libc::SYS_sendmsg as i32,
libc::SYS_getsockopt as i32,
libc::SYS_fcntl as i32,
libc::SYS_fstatfs as i32,
];

View File

@@ -1,7 +1,7 @@
// src/syscalls.rs
// src/python_syscalls.rs
pub static ALLOW_SYSCALLS: &[i32] = &[
// file io
// File IO
libc::SYS_read as i32,
libc::SYS_write as i32,
libc::SYS_openat as i32,
@@ -11,48 +11,44 @@ pub static ALLOW_SYSCALLS: &[i32] = &[
libc::SYS_lseek as i32,
libc::SYS_getdents64 as i32,
libc::SYS_fstat as i32,
// thread
// Signal
libc::SYS_rt_sigreturn as i32,
libc::SYS_rt_sigaction as i32,
libc::SYS_rt_sigprocmask as i32,
libc::SYS_sigaltstack as i32,
libc::SYS_tgkill as i32,
// Thread
libc::SYS_futex as i32,
// memory
// Memory
libc::SYS_mmap as i32,
libc::SYS_brk as i32,
libc::SYS_mprotect as i32,
libc::SYS_munmap as i32,
libc::SYS_rt_sigreturn as i32,
libc::SYS_mremap as i32,
// user / group
libc::SYS_setuid as i32,
libc::SYS_setgid as i32,
// User / Group
libc::SYS_getuid as i32,
// process
// Process
libc::SYS_getpid as i32,
libc::SYS_getppid as i32,
libc::SYS_gettid as i32,
libc::SYS_exit as i32,
libc::SYS_exit_group as i32,
libc::SYS_tgkill as i32,
libc::SYS_rt_sigaction as i32,
libc::SYS_sched_yield as i32,
libc::SYS_set_robust_list as i32,
libc::SYS_get_robust_list as i32,
libc::SYS_rseq as i32,
// time
// Time
libc::SYS_clock_gettime as i32,
libc::SYS_gettimeofday as i32,
libc::SYS_time as i32,
libc::SYS_nanosleep as i32,
libc::SYS_clock_nanosleep as i32,
// Epoll / Event (I/O multiplexing)
libc::SYS_epoll_create1 as i32,
libc::SYS_epoll_ctl as i32,
libc::SYS_clock_nanosleep as i32,
libc::SYS_pselect6 as i32,
libc::SYS_rt_sigprocmask as i32,
libc::SYS_sigaltstack as i32,
// Randomness
libc::SYS_getrandom as i32,
];
pub static ALLOW_ERROR_SYSCALLS: &[i32] = &[

View File

@@ -11,51 +11,15 @@ from fastapi import FastAPI
from app.config import get_config
from app.controllers import manager_router
from app.core.runners import init_sandbox_user
from app.dependencies import setup_dependencies, update_dependencies_periodically
from app.logger import setup_logger, get_logger
setup_logger()
config = get_config()
logger = get_logger()
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan manager"""
logger = get_logger()
# Startup
logger.info("Starting RedBear Sandbox...")
# Setup dependencies in background
asyncio.create_task(setup_dependencies())
# Start periodic dependency updates
config = get_config()
if config.python_deps_update_interval:
asyncio.create_task(update_dependencies_periodically())
yield
# Shutdown
logger.info("Shutting down Redbear Sandbox...")
def create_app() -> FastAPI:
"""Create FastAPI application"""
config = get_config()
app = FastAPI(
title="Sandbox",
description="Secure code execution sandbox",
version="2.0.0",
lifespan=lifespan,
debug=config.app.debug
)
app.include_router(manager_router)
return app
def check_root_privileges():
"""Check if running with root privileges"""
if os.geteuid() != 0:
@@ -63,35 +27,38 @@ def check_root_privileges():
sys.exit(1)
def main():
"""Main entry point"""
# Check root privileges
check_root_privileges()
check_root_privileges()
# Setup logging
setup_logger()
config = get_config()
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan manager"""
logger = get_logger()
config = get_config()
# Startup
logger.info("Starting RedBear Sandbox...")
logger.info(f"Starting server on port {config.app.port}")
logger.info(f"Debug mode: {config.app.debug}")
logger.info(f"Max workers: {config.max_workers}")
logger.info(f"Max requests: {config.max_requests}")
logger.info(f"Network enabled: {config.enable_network}")
init_sandbox_user()
await setup_dependencies()
# Create app
app = create_app()
if config.python_deps_update_interval:
asyncio.create_task(update_dependencies_periodically())
# Run server
uvicorn.run(
app,
host="0.0.0.0",
port=config.app.port,
log_level="debug" if config.app.debug else "info",
access_log=config.app.debug
)
yield
# Shutdown
logger.info("Shutting down Redbear Sandbox...")
if __name__ == "__main__":
main()
app = FastAPI(
title="Sandbox",
description="Secure code execution sandbox",
version="0.1.0",
lifespan=lifespan,
debug=config.app.debug
)
app.include_router(manager_router)

Some files were not shown because too many files have changed in this diff Show More