diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 185d746c..3e7db8cb 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -3,9 +3,10 @@ import platform from datetime import timedelta from urllib.parse import quote -from app.core.config import settings from celery import Celery +from app.core.config import settings + # 创建 Celery 应用实例 # broker: 任务队列(使用 Redis DB 0) # backend: 结果存储(使用 Redis DB 10) @@ -67,11 +68,11 @@ celery_app.conf.update( 'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'}, - # Beat/periodic tasks → document_tasks queue (prefork worker) - 'app.tasks.workspace_reflection_task': {'queue': 'document_tasks'}, - 'app.tasks.regenerate_memory_cache': {'queue': 'document_tasks'}, - 'app.tasks.run_forgetting_cycle_task': {'queue': 'document_tasks'}, - 'app.controllers.memory_storage_controller.search_all': {'queue': 'document_tasks'}, + # Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker) + 'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'}, + 'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'}, + 'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'}, + 'app.controllers.memory_storage_controller.search_all': {'queue': 'periodic_tasks'}, }, ) @@ -79,40 +80,40 @@ celery_app.conf.update( celery_app.autodiscover_tasks(['app']) # Celery Beat schedule for periodic tasks -memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS) -memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS) -workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME -forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期 +# memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS) +# memory_cache_regeneration_schedule = 
timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS) +# workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME +# forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期 # 构建定时任务配置 -beat_schedule_config = { - "run-workspace-reflection": { - "task": "app.tasks.workspace_reflection_task", - "schedule": workspace_reflection_schedule, - "args": (), - }, - "regenerate-memory-cache": { - "task": "app.tasks.regenerate_memory_cache", - "schedule": memory_cache_regeneration_schedule, - "args": (), - }, - "run-forgetting-cycle": { - "task": "app.tasks.run_forgetting_cycle_task", - "schedule": forgetting_cycle_schedule, - "kwargs": { - "config_id": None, # 使用默认配置,可以通过环境变量配置 - }, - }, -} +# beat_schedule_config = { +# "run-workspace-reflection": { +# "task": "app.tasks.workspace_reflection_task", +# "schedule": workspace_reflection_schedule, +# "args": (), +# }, +# "regenerate-memory-cache": { +# "task": "app.tasks.regenerate_memory_cache", +# "schedule": memory_cache_regeneration_schedule, +# "args": (), +# }, +# "run-forgetting-cycle": { +# "task": "app.tasks.run_forgetting_cycle_task", +# "schedule": forgetting_cycle_schedule, +# "kwargs": { +# "config_id": None, # 使用默认配置,可以通过环境变量配置 +# }, +# }, +# } # 如果配置了默认工作空间ID,则添加记忆总量统计任务 -if settings.DEFAULT_WORKSPACE_ID: - beat_schedule_config["write-total-memory"] = { - "task": "app.controllers.memory_storage_controller.search_all", - "schedule": memory_increment_schedule, - "kwargs": { - "workspace_id": settings.DEFAULT_WORKSPACE_ID, - }, - } +# if settings.DEFAULT_WORKSPACE_ID: +# beat_schedule_config["write-total-memory"] = { +# "task": "app.controllers.memory_storage_controller.search_all", +# "schedule": memory_increment_schedule, +# "kwargs": { +# "workspace_id": settings.DEFAULT_WORKSPACE_ID, +# }, +# } -celery_app.conf.beat_schedule = beat_schedule_config +# celery_app.conf.beat_schedule = beat_schedule_config diff --git a/api/app/controllers/__init__.py 
b/api/app/controllers/__init__.py index f96d0b7e..c4a2f984 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -43,6 +43,7 @@ from . import ( user_memory_controllers, workflow_controller, workspace_controller, + ontology_controller, ) # 创建管理端 API 路由器 @@ -88,5 +89,6 @@ manager_router.include_router(implicit_memory_controller.router) manager_router.include_router(memory_perceptual_controller.router) manager_router.include_router(memory_working_controller.router) manager_router.include_router(file_storage_controller.router) +manager_router.include_router(ontology_controller.router) __all__ = ["manager_router"] diff --git a/api/app/controllers/memory_reflection_controller.py b/api/app/controllers/memory_reflection_controller.py index 7941be35..8d5408f1 100644 --- a/api/app/controllers/memory_reflection_controller.py +++ b/api/app/controllers/memory_reflection_controller.py @@ -51,7 +51,6 @@ async def save_reflection_config( status_code=status.HTTP_400_BAD_REQUEST, detail="缺少必需参数: config_id" ) - api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}") memory_config = MemoryConfigRepository.update_reflection_config( @@ -102,7 +101,7 @@ async def start_workspace_reflection( current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: - """Activate the reflection function for all matching applications in the workspace""" + """启动工作空间中所有匹配应用的反思功能""" workspace_id = current_user.current_workspace_id reflection_service = MemoryReflectionService(db) @@ -111,42 +110,55 @@ async def start_workspace_reflection( service = WorkspaceAppService(db) result = service.get_workspace_apps_detailed(workspace_id) - reflection_results = [] - for data in result['apps_detailed_info']: - if data['memory_configs'] == []: + # 跳过没有配置的应用 + if not data['memory_configs']: + api_logger.debug(f"应用 {data['id']} 没有memory_configs,跳过") continue - + releases = data['releases'] memory_configs = data['memory_configs'] end_users = 
data['end_users'] - - for base, config, user in zip(releases, memory_configs, end_users): - # 安全地转换为整数,处理空字符串和None的情况 - print(base['config']) - try: - base_config = int(base['config']) if base['config'] else 0 - config_id = int(config['config_id']) if config['config_id'] else 0 - except (ValueError, TypeError): - api_logger.warning(f"无效的配置ID: base['config']={base.get('config')}, config['config_id']={config.get('config_id')}") + + # 为每个配置和用户组合执行反思 + for config in memory_configs: + config_id_str = str(config['config_id']) + + # 找到匹配此配置的所有release + matching_releases = [r for r in releases if str(r['config']) == config_id_str] + + if not matching_releases: + api_logger.debug(f"配置 {config_id_str} 没有匹配的release") continue - - if base_config == config_id and base['app_id'] == user['app_id']: - # 调用反思服务 - api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}") - - reflection_result = await reflection_service.start_text_reflection( - config_data=config, - end_user_id=user['id'] - ) - - reflection_results.append({ - "app_id": base['app_id'], - "config_id": config['config_id'], - "end_user_id": user['id'], - "reflection_result": reflection_result - }) + + # 为每个用户执行反思 + for user in end_users: + api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}") + + try: + reflection_result = await reflection_service.start_text_reflection( + config_data=config, + end_user_id=user['id'] + ) + + reflection_results.append({ + "app_id": data['id'], + "config_id": config_id_str, + "end_user_id": user['id'], + "reflection_result": reflection_result + }) + except Exception as e: + api_logger.error(f"用户 {user['id']} 反思失败: {str(e)}") + reflection_results.append({ + "app_id": data['id'], + "config_id": config_id_str, + "end_user_id": user['id'], + "reflection_result": { + "status": "错误", + "message": f"反思失败: {str(e)}" + } + }) return success(data=reflection_results, msg="反思配置成功") diff --git a/api/app/controllers/ontology_controller.py 
b/api/app/controllers/ontology_controller.py new file mode 100644 index 00000000..1cf8e64e --- /dev/null +++ b/api/app/controllers/ontology_controller.py @@ -0,0 +1,964 @@ +"""本体提取API控制器 + +本模块提供本体提取系统的RESTful API端点。 + +Endpoints: + POST /api/memory/ontology/extract - 提取本体类 + POST /api/memory/ontology/export - 导出OWL文件 + POST /api/memory/ontology/scene - 创建本体场景 + PUT /api/memory/ontology/scene/{scene_id} - 更新本体场景 + DELETE /api/memory/ontology/scene/{scene_id} - 删除本体场景 + GET /api/memory/ontology/scene/{scene_id} - 获取单个场景 + GET /api/memory/ontology/scenes - 获取场景列表 + POST /api/memory/ontology/class - 创建本体类型 + PUT /api/memory/ontology/class/{class_id} - 更新本体类型 + DELETE /api/memory/ontology/class/{class_id} - 删除本体类型 + GET /api/memory/ontology/class/{class_id} - 获取单个类型 + GET /api/memory/ontology/classes - 获取类型列表 +""" + +import logging +import tempfile +from typing import Dict, Optional + +from fastapi import APIRouter, Depends, HTTPException, Header +from sqlalchemy.orm import Session + +from app.core.error_codes import BizCode +from app.core.logging_config import get_api_logger +from app.core.response_utils import fail, success +from app.db import get_db +from app.dependencies import get_current_user +from app.models.user_model import User +from app.services.memory_base_service import Translation_English +from app.core.memory.models.ontology_models import OntologyClass +from typing import List +from app.schemas.ontology_schemas import ( + ExportRequest, + ExportResponse, + ExtractionRequest, + ExtractionResponse, + SceneCreateRequest, + SceneUpdateRequest, + SceneResponse, + SceneListResponse, + ClassCreateRequest, + ClassUpdateRequest, + ClassResponse, + ClassListResponse, +) +from app.schemas.response_schema import ApiResponse +from app.services.ontology_service import OntologyService +from app.core.memory.llm_tools.openai_client import OpenAIClient +from app.core.memory.utils.validation.owl_validator import OWLValidator +from app.services.model_service import 
async def _translate_or_keep(model_id: str, text: str, label: str) -> str:
    """Translate ``text`` via ``Translation_English``; return ``text`` unchanged on failure.

    Args:
        model_id: LLM model ID forwarded to ``Translation_English``.
        text: The text to translate.
        label: Field name used in the warning message when translation fails.

    Returns:
        The value returned by ``Translation_English``, or the original text
        if the call raised.
    """
    try:
        return await Translation_English(model_id, text)
    except Exception as e:
        # Best-effort: one failed field must not abort the whole response.
        logger.warning(f"Failed to translate {label}: {e}")
        return text


async def translate_ontology_classes(
    classes: List[OntologyClass],
    model_id: str
) -> List[OntologyClass]:
    """Return translated copies of the given ontology classes.

    For every class the ``name_chinese``, ``description`` and ``examples``
    fields are passed through ``Translation_English``. Empty/falsy fields are
    skipped, and a field whose translation fails keeps its original text.

    Args:
        classes: Ontology classes to translate. The inputs are not mutated;
            each one is deep-copied first.
        model_id: LLM model ID used for translation.

    Returns:
        List[OntologyClass]: Translated copies, in the same order as ``classes``.
    """
    translated_classes = []

    for ontology_class in classes:
        # Work on a deep copy so the caller's objects stay untouched.
        translated_class = ontology_class.model_copy(deep=True)

        if translated_class.name_chinese:
            translated_class.name_chinese = await _translate_or_keep(
                model_id, translated_class.name_chinese, "name_chinese"
            )

        if translated_class.description:
            translated_class.description = await _translate_or_keep(
                model_id, translated_class.description, "description"
            )

        if translated_class.examples:
            # Examples are translated one by one, in order; a failed example
            # is kept as-is rather than dropped.
            translated_class.examples = [
                await _translate_or_keep(model_id, example, "example")
                for example in translated_class.examples
            ]

        translated_classes.append(translated_class)

    return translated_classes
def _get_ontology_service(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
    llm_id: Optional[str] = None
) -> OntologyService:
    """Build an ``OntologyService`` backed by the LLM model identified by ``llm_id``.

    Although ``llm_id`` defaults to ``None`` in the signature, it is mandatory:
    this function does not fall back to any default model and rejects a
    missing value with HTTP 400.

    Args:
        db: Database session.
        current_user: The authenticated user (used for logging only).
        llm_id: UUID string of the LLM model configuration to use.

    Returns:
        OntologyService: Service wired to an ``OpenAIClient`` for that model.

    Raises:
        HTTPException: 400 when ``llm_id`` is missing or malformed, the model
            cannot be loaded, the model is composite, or it has no API key;
            500 for any other failure while building the service.
    """
    try:
        import uuid

        # llm_id is required.
        if not llm_id:
            logger.error(f"llm_id is required but not provided - user: {current_user.id}")
            raise HTTPException(
                status_code=400,
                detail="必须提供llm_id参数"
            )

        logger.info(f"Using specified LLM model: {llm_id}")

        # Validate the llm_id format. Non-string values make uuid.UUID raise
        # AttributeError/TypeError; those are client errors too, not 500s.
        try:
            model_id = uuid.UUID(llm_id)
        except (ValueError, AttributeError, TypeError) as e:
            logger.error(f"Invalid llm_id format: {llm_id}")
            raise HTTPException(
                status_code=400,
                detail="无效的LLM模型ID格式"
            ) from e

        # Load the requested model configuration.
        try:
            model_config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
        except Exception as e:
            logger.error(f"Model {llm_id} not found: {str(e)}")
            raise HTTPException(
                status_code=400,
                detail=f"找不到指定的LLM模型: {llm_id}"
            ) from e

        # Guard against a lookup that returns nothing instead of raising;
        # otherwise the attribute access below would surface as a 500.
        if model_config is None:
            logger.error(f"Model {llm_id} not found: lookup returned no result")
            raise HTTPException(
                status_code=400,
                detail=f"找不到指定的LLM模型: {llm_id}"
            )

        # Composite models are not supported for ontology extraction.
        if getattr(model_config, 'is_composite', False):
            logger.error(f"Model {llm_id} is a composite model, which is not supported for ontology extraction")
            raise HTTPException(
                status_code=400,
                detail="本体提取不支持使用组合模型,请选择单个模型"
            )

        # The model must have at least one API key configured.
        if not model_config.api_keys:
            logger.error(f"Model {llm_id} has no API key configuration")
            raise HTTPException(
                status_code=400,
                detail="指定的LLM模型没有配置API密钥"
            )

        # Only the first configured API key is used.
        api_key_config = model_config.api_keys[0]

        logger.info(
            f"Using specified model - user: {current_user.id}, "
            f"model_id: {llm_id}, model_name: {api_key_config.model_name}"
        )

        from app.core.models.base import RedBearModelConfig

        llm_model_config = RedBearModelConfig(
            model_name=api_key_config.model_name,
            # Fall back to "openai" when the model config has no provider attribute.
            provider=getattr(model_config, 'provider', "openai"),
            api_key=api_key_config.api_key,
            base_url=api_key_config.api_base,
            max_retries=3,
            timeout=60.0
        )

        llm_client = OpenAIClient(model_config=llm_model_config)
        service = OntologyService(llm_client=llm_client, db=db)

        logger.debug(
            f"OntologyService created successfully - "
            f"user: {current_user.id}, model: {api_key_config.model_name}"
        )

        return service

    except HTTPException:
        # Already carries the intended status code; pass it through untouched.
        raise
    except Exception as e:
        logger.error(f"Failed to create OntologyService: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"创建本体提取服务失败: {str(e)}"
        ) from e
@router.post("/extract", response_model=ApiResponse)
async def extract_ontology(
    request: ExtractionRequest,
    language_type: str = Header(default="zh", alias="X-Language-Type"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Extract ontology classes from a scenario description.

    The extraction result is returned to the caller; the client is expected to
    pick the classes it wants and create them through the ``/class`` endpoint.
    When the ``X-Language-Type`` header is anything other than ``'zh'``, the
    classes and the domain are translated with the same ``llm_id``.

    Args:
        request: Extraction request carrying ``scenario``, ``domain``,
            ``llm_id`` and ``scene_id``.
        language_type: ``'zh'`` (default) or another value to request translation.
        db: Database session.
        current_user: The authenticated user.

    Returns:
        ApiResponse: On success ``data`` holds ``classes``, ``domain`` and
        ``extracted_count``. Failures are reported through ``fail(...)`` with
        ``BizCode.BAD_REQUEST`` or ``BizCode.INTERNAL_ERROR``.

    Response format:
        {
            "code": 200,
            "msg": "本体提取成功",
            "data": {
                "classes": [{"id": "...", "name": "Patient", "name_chinese": "患者", ...}],
                "domain": "Healthcare",
                "extracted_count": 7
            }
        }
    """
    api_logger.info(
        f"Ontology extraction requested by user {current_user.id}, "
        f"scenario_length={len(request.scenario)}, "
        f"domain={request.domain}, "
        f"llm_id={request.llm_id}, "
        f"scene_id={request.scene_id}, "
        f"language_type={language_type}"
    )

    try:
        # Resolve the current workspace.
        workspace_id = current_user.current_workspace_id
        if not workspace_id:
            api_logger.warning(f"User {current_user.id} has no current workspace")
            return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")

        # Build the service for the requested LLM model.
        service = _get_ontology_service(
            db=db,
            current_user=current_user,
            llm_id=request.llm_id
        )

        # Run the extraction.
        result = await service.extract_ontology(
            scenario=request.scenario,
            domain=request.domain,
            scene_id=request.scene_id,
            workspace_id=workspace_id
        )

        # Translate the result when a non-Chinese language was requested.
        if language_type != 'zh':
            api_logger.info("Translating extraction result to English")

            result.classes = await translate_ontology_classes(
                result.classes,
                request.llm_id
            )

            if result.domain:
                try:
                    result.domain = await Translation_English(
                        request.llm_id,
                        result.domain
                    )
                except Exception as e:
                    # Best-effort: keep the original domain text.
                    logger.warning(f"Failed to translate domain: {e}")

        response = ExtractionResponse(
            classes=result.classes,
            domain=result.domain,
            extracted_count=len(result.classes)
        )

        api_logger.info(
            f"Ontology extraction completed, extracted {len(result.classes)} classes, "
            f"saved to scene {request.scene_id}, language={language_type}"
        )

        return success(data=response.model_dump(), msg="本体提取成功")

    except HTTPException as e:
        # _get_ontology_service reports a missing/invalid llm_id or an unusable
        # model as HTTPException(400). HTTPException is an Exception subclass,
        # so without this branch the generic handler below turned those client
        # errors into INTERNAL_ERROR responses.
        if e.status_code < 500:
            api_logger.warning(f"Validation error in extraction: {e.detail}")
            return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e.detail))
        api_logger.error(f"Runtime error in extraction: {e.detail}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "本体提取失败", str(e.detail))

    except ValueError as e:
        # Validation error (400)
        api_logger.warning(f"Validation error in extraction: {str(e)}")
        return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))

    except RuntimeError as e:
        # Runtime error (500)
        api_logger.error(f"Runtime error in extraction: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "本体提取失败", str(e))

    except Exception as e:
        # Unknown error (500)
        api_logger.error(f"Unexpected error in extraction: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "本体提取失败", str(e))
@router.post("/export", response_model=ApiResponse)
async def export_owl(
    request: ExportRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Export ontology classes as OWL content.

    No LLM is involved; only ``OWLValidator`` is used. The ``json`` format is
    exported directly without validation. The other formats are validated
    first and exported through a temporary file, which is removed before the
    response is returned.

    Args:
        request: Export request carrying ``classes``, ``format`` and
            ``include_metadata``.
        db: Database session.
        current_user: The authenticated user.

    Returns:
        ApiResponse: On success ``data`` holds ``owl_content``, ``format`` and
        ``classes_count``. Failures are reported through ``fail(...)``.

    Supported formats:
        - rdfxml
        - turtle
        - ntriples
        - json

    Response format:
        {
            "code": 200,
            "msg": "OWL文件导出成功",
            "data": {
                "owl_content": "...",
                "format": "rdfxml",
                "classes_count": 7
            }
        }
    """
    import os

    api_logger.info(
        f"OWL export requested by user {current_user.id}, "
        f"classes_count={len(request.classes)}, "
        f"format={request.format}, "
        f"include_metadata={request.include_metadata}"
    )

    try:
        # Reject unknown formats up front.
        valid_formats = ["rdfxml", "turtle", "ntriples", "json"]
        if request.format not in valid_formats:
            api_logger.warning(f"Invalid export format: {request.format}")
            return fail(
                BizCode.BAD_REQUEST,
                "不支持的导出格式",
                f"format必须是以下之一: {', '.join(valid_formats)}"
            )

        owl_validator = OWLValidator()

        # JSON is exported directly; no OWL validation and no temp file needed.
        if request.format == "json":
            owl_content = owl_validator.export_to_owl(
                world=None,
                format="json",
                classes=request.classes
            )

            response = ExportResponse(
                owl_content=owl_content,
                format=request.format,
                classes_count=len(request.classes)
            )

            api_logger.info(
                f"JSON export completed, content_length={len(owl_content)}"
            )

            return success(data=response.model_dump(), msg="OWL文件导出成功")

        # Reserve a temp file path for the exporter. delete=False keeps the
        # file after the handle is closed, so it must be removed explicitly.
        with tempfile.NamedTemporaryFile(
            mode='w',
            suffix='.owl',
            delete=False
        ) as tmp_file:
            output_path = tmp_file.name

        try:
            logger.debug("Validating ontology classes")
            is_valid, errors, world = owl_validator.validate_ontology_classes(
                classes=request.classes,
            )

            if not is_valid:
                # Validation issues are logged but do not block the export.
                logger.warning(
                    f"OWL validation found {len(errors)} issues during export: {errors}"
                )

            if not world:
                error_msg = "Failed to create OWL world for export"
                logger.error(error_msg)
                return fail(BizCode.INTERNAL_ERROR, "创建OWL世界失败", error_msg)

            logger.info(f"Exporting to {request.format} format")
            owl_content = owl_validator.export_to_owl(
                world=world,
                output_path=output_path,
                format=request.format,
                classes=request.classes
            )

            response = ExportResponse(
                owl_content=owl_content,
                format=request.format,
                classes_count=len(request.classes)
            )

            api_logger.info(
                f"OWL export completed, format={request.format}, "
                f"content_length={len(owl_content)}"
            )

            return success(data=response.model_dump(), msg="OWL文件导出成功")
        finally:
            # The response carries the content inline, so the temp file is no
            # longer needed. Previously it was never removed and one file
            # leaked per export request.
            try:
                os.remove(output_path)
            except OSError as cleanup_error:
                logger.warning(
                    f"Failed to remove temporary OWL file {output_path}: {cleanup_error}"
                )

    except ValueError as e:
        # Validation error (400)
        api_logger.warning(f"Validation error in export: {str(e)}")
        return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))

    except RuntimeError as e:
        # Runtime error (500)
        api_logger.error(f"Runtime error in export: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "OWL文件导出失败", str(e))

    except Exception as e:
        # Unknown error (500)
        api_logger.error(f"Unexpected error in export: {str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "OWL文件导出失败", str(e))
@router.post("/scene", response_model=ApiResponse)
async def create_scene(
    request: SceneCreateRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Create an ontology scene in the caller's current workspace.

    Args:
        request: Scene creation request (``scene_name``, ``scene_description``).
        db: Database session.
        current_user: The authenticated user; their ``current_workspace_id``
            determines where the scene is created.

    Returns:
        ApiResponse: The created scene on success, otherwise a ``fail(...)``
        response (BAD_REQUEST for a missing workspace or ``ValueError``,
        INTERNAL_ERROR for anything else).
    """
    api_logger.info(
        f"Scene creation requested by user {current_user.id}, "
        f"name={request.scene_name}"
    )

    try:
        ws_id = current_user.current_workspace_id
        if not ws_id:
            api_logger.warning(f"User {current_user.id} has no current workspace")
            return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")

        from app.core.memory.llm_tools.openai_client import OpenAIClient
        from app.core.models.base import RedBearModelConfig

        # Scene management never calls the LLM; the service constructor still
        # takes a client, so a placeholder configuration is supplied.
        placeholder_client = OpenAIClient(
            model_config=RedBearModelConfig(
                model_name="dummy",
                provider="openai",
                api_key="dummy",
                base_url="https://api.openai.com/v1"
            )
        )
        ontology_service = OntologyService(llm_client=placeholder_client, db=db)

        created = ontology_service.create_scene(
            scene_name=request.scene_name,
            scene_description=request.scene_description,
            workspace_id=ws_id
        )

        # The type count is derived from the scene's classes on every call.
        class_total = len(created.classes) if created.classes else 0

        payload = SceneResponse(
            scene_id=created.scene_id,
            scene_name=created.scene_name,
            scene_description=created.scene_description,
            type_num=class_total,
            workspace_id=created.workspace_id,
            created_at=created.created_at,
            updated_at=created.updated_at,
            classes_count=class_total
        )

        api_logger.info(f"Scene created successfully: {created.scene_id}")

        return success(data=payload.model_dump(), msg="场景创建成功")

    except ValueError as exc:
        api_logger.warning(f"Validation error in scene creation: {str(exc)}")
        return fail(BizCode.BAD_REQUEST, "请求参数无效", str(exc))

    except RuntimeError as exc:
        api_logger.error(f"Runtime error in scene creation: {str(exc)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "场景创建失败", str(exc))

    except Exception as exc:
        api_logger.error(f"Unexpected error in scene creation: {str(exc)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "场景创建失败", str(exc))
@router.put("/scene/{scene_id}", response_model=ApiResponse)
async def update_scene(
    scene_id: str,
    request: SceneUpdateRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Update an ontology scene in the caller's current workspace.

    Args:
        scene_id: Scene ID as a UUID string.
        request: Scene update request (``scene_name``, ``scene_description``).
        db: Database session.
        current_user: The authenticated user; their ``current_workspace_id``
            is passed to the service together with the scene ID.

    Returns:
        ApiResponse: The updated scene on success, otherwise a ``fail(...)``
        response (BAD_REQUEST for a malformed ID, a missing workspace or
        ``ValueError``; INTERNAL_ERROR for anything else).
    """
    api_logger.info(
        f"Scene update requested by user {current_user.id}, "
        f"scene_id={scene_id}"
    )

    try:
        from uuid import UUID

        # Reject malformed IDs before touching the service layer.
        try:
            parsed_scene_id = UUID(scene_id)
        except ValueError:
            api_logger.warning(f"Invalid scene_id format: {scene_id}")
            return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的场景ID格式")

        ws_id = current_user.current_workspace_id
        if not ws_id:
            api_logger.warning(f"User {current_user.id} has no current workspace")
            return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")

        from app.core.memory.llm_tools.openai_client import OpenAIClient
        from app.core.models.base import RedBearModelConfig

        # Placeholder LLM client: the service constructor requires one.
        placeholder_client = OpenAIClient(
            model_config=RedBearModelConfig(
                model_name="dummy",
                provider="openai",
                api_key="dummy",
                base_url="https://api.openai.com/v1"
            )
        )
        ontology_service = OntologyService(llm_client=placeholder_client, db=db)

        updated = ontology_service.update_scene(
            scene_id=parsed_scene_id,
            scene_name=request.scene_name,
            scene_description=request.scene_description,
            workspace_id=ws_id
        )

        # The type count is derived from the scene's classes on every call.
        class_total = len(updated.classes) if updated.classes else 0

        payload = SceneResponse(
            scene_id=updated.scene_id,
            scene_name=updated.scene_name,
            scene_description=updated.scene_description,
            type_num=class_total,
            workspace_id=updated.workspace_id,
            created_at=updated.created_at,
            updated_at=updated.updated_at,
            classes_count=class_total
        )

        api_logger.info(f"Scene updated successfully: {scene_id}")

        return success(data=payload.model_dump(), msg="场景更新成功")

    except ValueError as exc:
        api_logger.warning(f"Validation error in scene update: {str(exc)}")
        return fail(BizCode.BAD_REQUEST, "请求参数无效", str(exc))

    except RuntimeError as exc:
        api_logger.error(f"Runtime error in scene update: {str(exc)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "场景更新失败", str(exc))

    except Exception as exc:
        api_logger.error(f"Unexpected error in scene update: {str(exc)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "场景更新失败", str(exc))
@router.delete("/scene/{scene_id}", response_model=ApiResponse)
async def delete_scene(
    scene_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Delete an ontology scene in the caller's current workspace.

    Args:
        scene_id: Scene ID as a UUID string.
        db: Database session.
        current_user: The authenticated user; their ``current_workspace_id``
            is passed to the service together with the scene ID.

    Returns:
        ApiResponse: ``{"deleted": <service result>}`` on success, otherwise a
        ``fail(...)`` response (BAD_REQUEST for a malformed ID, a missing
        workspace or ``ValueError``; INTERNAL_ERROR for anything else).
    """
    api_logger.info(
        f"Scene deletion requested by user {current_user.id}, "
        f"scene_id={scene_id}"
    )

    try:
        from uuid import UUID

        # Reject malformed IDs before touching the service layer.
        try:
            parsed_scene_id = UUID(scene_id)
        except ValueError:
            api_logger.warning(f"Invalid scene_id format: {scene_id}")
            return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的场景ID格式")

        ws_id = current_user.current_workspace_id
        if not ws_id:
            api_logger.warning(f"User {current_user.id} has no current workspace")
            return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")

        from app.core.memory.llm_tools.openai_client import OpenAIClient
        from app.core.models.base import RedBearModelConfig

        # Placeholder LLM client: the service constructor requires one.
        placeholder_client = OpenAIClient(
            model_config=RedBearModelConfig(
                model_name="dummy",
                provider="openai",
                api_key="dummy",
                base_url="https://api.openai.com/v1"
            )
        )
        ontology_service = OntologyService(llm_client=placeholder_client, db=db)

        was_deleted = ontology_service.delete_scene(
            scene_id=parsed_scene_id,
            workspace_id=ws_id
        )

        api_logger.info(f"Scene deleted successfully: {scene_id}")

        return success(data={"deleted": was_deleted}, msg="场景删除成功")

    except ValueError as exc:
        api_logger.warning(f"Validation error in scene deletion: {str(exc)}")
        return fail(BizCode.BAD_REQUEST, "请求参数无效", str(exc))

    except RuntimeError as exc:
        api_logger.error(f"Runtime error in scene deletion: {str(exc)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "场景删除失败", str(exc))

    except Exception as exc:
        api_logger.error(f"Unexpected error in scene deletion: {str(exc)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "场景删除失败", str(exc))
@router.get("/scenes", response_model=ApiResponse)
async def get_scenes(
    workspace_id: Optional[str] = None,
    scene_name: Optional[str] = None,
    page: Optional[int] = None,
    pagesize: Optional[int] = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """List scenes, optionally filtered by a name keyword and optionally paginated.

    This route only declares the query parameters; all work is delegated to
    ``scenes_handler`` in ``app.controllers.ontology_secondary_routes``.

    Args:
        workspace_id: Workspace ID (optional; the handler falls back to the
            user's current workspace when omitted).
        scene_name: Scene name keyword for fuzzy search (optional).
        page: Page number, starting at 1 (optional).
        pagesize: Items per page (optional). Forwarded to the handler as its
            ``page_size`` parameter.
        db: Database session.
        current_user: The authenticated user.

    Returns:
        ApiResponse: Scene list plus pagination info as produced by the handler.

    Examples:
        - Fuzzy search:            GET /scenes?workspace_id=xxx&scene_name=医疗
        - Fuzzy search, paginated: GET /scenes?workspace_id=xxx&scene_name=医疗&page=1&pagesize=10
        - Full listing:            GET /scenes?workspace_id=xxx
        - Full listing, paginated: GET /scenes?workspace_id=xxx&page=1&pagesize=10

    Notes:
        - ``page`` and ``pagesize`` must be provided together.
        - ``page`` starts at 1 and ``pagesize`` must be greater than 0.
        - Shape: {"items": [...], "page": {"page": 1, "pagesize": 10, "total": 100, "hasnext": true}}
        - Without pagination the ``page`` field is null.
    """
    from app.controllers.ontology_secondary_routes import scenes_handler

    # The handler names its fourth parameter ``page_size``; map it explicitly.
    return await scenes_handler(
        workspace_id=workspace_id,
        scene_name=scene_name,
        page=page,
        page_size=pagesize,
        db=db,
        current_user=current_user,
    )
@router.post("/class", response_model=ApiResponse)
async def create_class(
    request: ClassCreateRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Create an ontology class type under a scene.

    Delegates to ``create_class_handler`` in
    ``app.controllers.ontology_secondary_routes``.

    Args:
        request: Class creation request.
        db: Database session.
        current_user: The authenticated user.

    Returns:
        ApiResponse: Whatever the handler returns for the created class.
    """
    from app.controllers import ontology_secondary_routes as secondary_routes

    handler_response = await secondary_routes.create_class_handler(request, db, current_user)
    return handler_response


@router.put("/class/{class_id}", response_model=ApiResponse)
async def update_class(
    class_id: str,
    request: ClassUpdateRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Update an ontology class type.

    Delegates to ``update_class_handler`` in
    ``app.controllers.ontology_secondary_routes``.

    Args:
        class_id: Class ID.
        request: Class update request.
        db: Database session.
        current_user: The authenticated user.

    Returns:
        ApiResponse: Whatever the handler returns for the updated class.
    """
    from app.controllers import ontology_secondary_routes as secondary_routes

    handler_response = await secondary_routes.update_class_handler(
        class_id, request, db, current_user
    )
    return handler_response


@router.delete("/class/{class_id}", response_model=ApiResponse)
async def delete_class(
    class_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Delete an ontology class type.

    Delegates to ``delete_class_handler`` in
    ``app.controllers.ontology_secondary_routes``.

    Args:
        class_id: Class ID.
        db: Database session.
        current_user: The authenticated user.

    Returns:
        ApiResponse: The deletion result produced by the handler.
    """
    from app.controllers import ontology_secondary_routes as secondary_routes

    handler_response = await secondary_routes.delete_class_handler(class_id, db, current_user)
    return handler_response


@router.get("/classes", response_model=ApiResponse)
async def get_classes(
    scene_id: str,
    class_name: Optional[str] = None,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """List the class types of a scene, optionally filtered by a name keyword.

    Delegates to ``classes_handler`` in
    ``app.controllers.ontology_secondary_routes``.

    Args:
        scene_id: Scene ID (required).
        class_name: Class name keyword for fuzzy search (optional); when
            omitted, all classes of the scene are requested.
        db: Database session.
        current_user: The authenticated user.

    Returns:
        ApiResponse: Class list and scene information as produced by the handler.

    Examples:
        - Fuzzy search: GET /classes?scene_id=xxx&class_name=患者
        - Full listing: GET /classes?scene_id=xxx

    Response Format:
        {
            "total": 3,
            "scene_id": "xxx",
            "scene_name": "医疗场景",
            "scene_description": "用于医疗领域的本体建模",
            "items": [...]
        }
    """
    from app.controllers import ontology_secondary_routes as secondary_routes

    handler_response = await secondary_routes.classes_handler(
        scene_id, class_name, db, current_user
    )
    return handler_response
@router.get("/class/{class_id}", response_model=ApiResponse)
async def get_class(
    class_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Fetch a single ontology class type by ID.

    Delegates to ``get_class_handler`` in
    ``app.controllers.ontology_secondary_routes``.

    Args:
        class_id: Class ID.
        db: Database session.
        current_user: The authenticated user.

    Returns:
        ApiResponse: Class details as produced by the handler.

    Response Format:
        {
            "code": 0,
            "msg": "查询成功",
            "data": {
                "class_id": "xxx",
                "class_name": "患者",
                "class_description": "在医疗机构中接受诊疗的个体",
                "scene_id": "xxx",
                "created_at": "2026-01-29T10:00:00",
                "updated_at": "2026-01-29T10:00:00"
            }
        }
    """
    from app.controllers import ontology_secondary_routes as secondary_routes

    handler_response = await secondary_routes.get_class_handler(class_id, db, current_user)
    return handler_response
ClassUpdateRequest, + ClassResponse, + ClassListResponse, + ClassBatchCreateResponse, +) +from app.schemas.response_schema import ApiResponse +from app.services.ontology_service import OntologyService +from app.core.memory.llm_tools.openai_client import OpenAIClient +from app.core.models.base import RedBearModelConfig + + +api_logger = get_api_logger() + + +def _get_dummy_ontology_service(db: Session) -> OntologyService: + """获取OntologyService实例(不需要LLM) + + 场景和类型管理不需要LLM,创建一个dummy配置。 + """ + dummy_config = RedBearModelConfig( + model_name="dummy", + provider="openai", + api_key="dummy", + base_url="https://api.openai.com/v1" + ) + llm_client = OpenAIClient(model_config=dummy_config) + return OntologyService(llm_client=llm_client, db=db) + + +# 这些函数将被导入到主Controller中 + +async def scenes_handler( + workspace_id: Optional[str] = None, + scene_name: Optional[str] = None, + page: Optional[int] = None, + page_size: Optional[int] = None, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """获取场景列表(支持模糊搜索和全量查询,全量查询支持分页) + + 当提供 scene_name 参数时,进行模糊搜索(不分页); + 当不提供 scene_name 参数时,返回所有场景(支持分页)。 + + Args: + workspace_id: 工作空间ID(可选,默认当前用户工作空间) + scene_name: 场景名称关键词(可选,支持模糊匹配) + page: 页码(可选,从1开始,仅在全量查询时有效) + page_size: 每页数量(可选,仅在全量查询时有效) + db: 数据库会话 + current_user: 当前用户 + """ + operation = "search" if scene_name else "list" + api_logger.info( + f"Scene {operation} requested by user {current_user.id}, " + f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, page_size={page_size}" + ) + + try: + # 确定工作空间ID + if workspace_id: + try: + ws_uuid = UUID(workspace_id) + except ValueError: + api_logger.warning(f"Invalid workspace_id format: {workspace_id}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的工作空间ID格式") + else: + ws_uuid = current_user.current_workspace_id + if not ws_uuid: + api_logger.warning(f"User {current_user.id} has no current workspace") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + + # 创建Service + 
service = _get_dummy_ontology_service(db) + + # 根据是否提供 scene_name 决定查询方式 + if scene_name and scene_name.strip(): + # 验证分页参数(模糊搜索也支持分页) + if page is not None and page < 1: + api_logger.warning(f"Invalid page number: {page}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0") + + if page_size is not None and page_size < 1: + api_logger.warning(f"Invalid page_size: {page_size}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0") + + # 如果只提供了page或page_size中的一个,返回错误 + if (page is not None and page_size is None) or (page is None and page_size is not None): + api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供") + + # 模糊搜索场景(支持分页) + scenes = service.search_scenes_by_name(scene_name.strip(), ws_uuid) + total = len(scenes) + + # 如果提供了分页参数,进行分页处理 + if page is not None and page_size is not None: + start_idx = (page - 1) * page_size + end_idx = start_idx + page_size + scenes = scenes[start_idx:end_idx] + + # 构建响应 + items = [] + for scene in scenes: + # 获取前3个class_name作为entity_type + entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None + # 动态计算 type_num + type_num = len(scene.classes) if scene.classes else 0 + + items.append(SceneResponse( + scene_id=scene.scene_id, + scene_name=scene.scene_name, + scene_description=scene.scene_description, + type_num=type_num, + entity_type=entity_type, + workspace_id=scene.workspace_id, + created_at=scene.created_at, + updated_at=scene.updated_at, + classes_count=type_num + )) + + # 构建响应(包含分页信息) + if page is not None and page_size is not None: + # 计算是否有下一页 + hasnext = (page * page_size) < total + + pagination_info = PaginationInfo( + page=page, + pagesize=page_size, + total=total, + hasnext=hasnext + ) + response = SceneListResponse(items=items, page=pagination_info) + else: + response = SceneListResponse(items=items) + + api_logger.info( + f"Scene search completed: found {len(items)} scenes 
matching '{scene_name}' " + f"in workspace {ws_uuid}, total={total}" + ) + else: + # 获取所有场景(支持分页) + # 验证分页参数 + if page is not None and page < 1: + api_logger.warning(f"Invalid page number: {page}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0") + + if page_size is not None and page_size < 1: + api_logger.warning(f"Invalid page_size: {page_size}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0") + + # 如果只提供了page或page_size中的一个,返回错误 + if (page is not None and page_size is None) or (page is None and page_size is not None): + api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供") + + scenes, total = service.list_scenes(ws_uuid, page, page_size) + + # 构建响应 + items = [] + for scene in scenes: + # 获取前3个class_name作为entity_type + entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None + # 动态计算 type_num + type_num = len(scene.classes) if scene.classes else 0 + + items.append(SceneResponse( + scene_id=scene.scene_id, + scene_name=scene.scene_name, + scene_description=scene.scene_description, + type_num=type_num, + entity_type=entity_type, + workspace_id=scene.workspace_id, + created_at=scene.created_at, + updated_at=scene.updated_at, + classes_count=type_num + )) + + # 构建响应(包含分页信息) + if page is not None and page_size is not None: + # 计算是否有下一页 + hasnext = (page * page_size) < total + + pagination_info = PaginationInfo( + page=page, + pagesize=page_size, + total=total, + hasnext=hasnext + ) + response = SceneListResponse(items=items, page=pagination_info) + else: + response = SceneListResponse(items=items) + + api_logger.info(f"Scene list retrieved successfully, count={len(items)}, total={total}") + + return success(data=response.model_dump(mode='json'), msg="查询成功") + + except ValueError as e: + api_logger.warning(f"Validation error in scene {operation}: {str(e)}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) + + except 
RuntimeError as e: + api_logger.error(f"Runtime error in scene {operation}: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e)) + + except Exception as e: + api_logger.error(f"Unexpected error in scene {operation}: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e)) + + +# ==================== 本体类型管理接口 ==================== + +async def create_class_handler( + request: ClassCreateRequest, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """创建本体类型(统一使用列表形式,支持单个或批量)""" + + # 根据列表长度判断是单个还是批量 + count = len(request.classes) + mode = "single" if count == 1 else "batch" + + api_logger.info( + f"Class creation ({mode}) requested by user {current_user.id}, " + f"scene_id={request.scene_id}, count={count}" + ) + + try: + # 获取当前工作空间ID + workspace_id = current_user.current_workspace_id + if not workspace_id: + api_logger.warning(f"User {current_user.id} has no current workspace") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + + # 创建Service + service = _get_dummy_ontology_service(db) + + # 准备类型数据 + classes_data = [ + { + "class_name": item.class_name, + "class_description": item.class_description + } + for item in request.classes + ] + + if count == 1: + # 单个创建 + class_data = classes_data[0] + ontology_class = service.create_class( + scene_id=request.scene_id, + class_name=class_data["class_name"], + class_description=class_data["class_description"], + workspace_id=workspace_id + ) + + # 构建单个响应 + response = ClassResponse( + class_id=ontology_class.class_id, + class_name=ontology_class.class_name, + class_description=ontology_class.class_description, + scene_id=ontology_class.scene_id, + created_at=ontology_class.created_at, + updated_at=ontology_class.updated_at + ) + + api_logger.info(f"Class created successfully: {ontology_class.class_id}") + + return success(data=response.model_dump(mode='json'), msg="类型创建成功") + + else: + # 批量创建 + created_classes, errors = 
service.create_classes_batch( + scene_id=request.scene_id, + classes=classes_data, + workspace_id=workspace_id + ) + + # 构建批量响应 + items = [] + for ontology_class in created_classes: + items.append(ClassResponse( + class_id=ontology_class.class_id, + class_name=ontology_class.class_name, + class_description=ontology_class.class_description, + scene_id=ontology_class.scene_id, + created_at=ontology_class.created_at, + updated_at=ontology_class.updated_at + )) + + response = ClassBatchCreateResponse( + total=len(classes_data), + success_count=len(created_classes), + failed_count=len(errors), + items=items, + errors=errors if errors else None + ) + + api_logger.info( + f"Batch class creation completed: " + f"success={len(created_classes)}, failed={len(errors)}" + ) + + return success(data=response.model_dump(mode='json'), msg="批量创建完成") + + except ValueError as e: + api_logger.warning(f"Validation error in class creation: {str(e)}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) + + except RuntimeError as e: + api_logger.error(f"Runtime error in class creation: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e)) + + except Exception as e: + api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e)) + + +async def update_class_handler( + class_id: str, + request: ClassUpdateRequest, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """更新本体类型""" + api_logger.info( + f"Class update requested by user {current_user.id}, " + f"class_id={class_id}" + ) + + try: + # 验证UUID格式 + try: + class_uuid = UUID(class_id) + except ValueError: + api_logger.warning(f"Invalid class_id format: {class_id}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式") + + # 获取当前工作空间ID + workspace_id = current_user.current_workspace_id + if not workspace_id: + api_logger.warning(f"User {current_user.id} has no current workspace") + return 
fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + + # 创建Service + service = _get_dummy_ontology_service(db) + + # 更新类型 + ontology_class = service.update_class( + class_id=class_uuid, + class_name=request.class_name, + class_description=request.class_description, + workspace_id=workspace_id + ) + + # 构建响应 + response = ClassResponse( + class_id=ontology_class.class_id, + class_name=ontology_class.class_name, + class_description=ontology_class.class_description, + scene_id=ontology_class.scene_id, + created_at=ontology_class.created_at, + updated_at=ontology_class.updated_at + ) + + api_logger.info(f"Class updated successfully: {class_id}") + + return success(data=response.model_dump(mode='json'), msg="类型更新成功") + + except ValueError as e: + api_logger.warning(f"Validation error in class update: {str(e)}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) + + except RuntimeError as e: + api_logger.error(f"Runtime error in class update: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e)) + + except Exception as e: + api_logger.error(f"Unexpected error in class update: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e)) + + +async def delete_class_handler( + class_id: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """删除本体类型""" + api_logger.info( + f"Class deletion requested by user {current_user.id}, " + f"class_id={class_id}" + ) + + try: + # 验证UUID格式 + try: + class_uuid = UUID(class_id) + except ValueError: + api_logger.warning(f"Invalid class_id format: {class_id}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式") + + # 获取当前工作空间ID + workspace_id = current_user.current_workspace_id + if not workspace_id: + api_logger.warning(f"User {current_user.id} has no current workspace") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + + # 创建Service + service = _get_dummy_ontology_service(db) + + # 删除类型 + success_flag = service.delete_class( + 
class_id=class_uuid, + workspace_id=workspace_id + ) + + api_logger.info(f"Class deleted successfully: {class_id}") + + return success(data={"deleted": success_flag}, msg="类型删除成功") + + except ValueError as e: + api_logger.warning(f"Validation error in class deletion: {str(e)}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) + + except RuntimeError as e: + api_logger.error(f"Runtime error in class deletion: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e)) + + except Exception as e: + api_logger.error(f"Unexpected error in class deletion: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e)) + + +async def get_class_handler( + class_id: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """获取单个本体类型""" + api_logger.info( + f"Get class requested by user {current_user.id}, " + f"class_id={class_id}" + ) + + try: + # 验证UUID格式 + try: + class_uuid = UUID(class_id) + except ValueError: + api_logger.warning(f"Invalid class_id format: {class_id}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式") + + # 获取当前工作空间ID + workspace_id = current_user.current_workspace_id + if not workspace_id: + api_logger.warning(f"User {current_user.id} has no current workspace") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + + # 创建Service + service = _get_dummy_ontology_service(db) + + # 获取类型(会抛出ValueError如果不存在) + ontology_class = service.get_class_by_id(class_uuid, workspace_id) + + # 构建响应 + response = ClassResponse( + class_id=ontology_class.class_id, + class_name=ontology_class.class_name, + class_description=ontology_class.class_description, + scene_id=ontology_class.scene_id, + created_at=ontology_class.created_at, + updated_at=ontology_class.updated_at + ) + + api_logger.info(f"Class retrieved successfully: {class_id}") + + return success(data=response.model_dump(mode='json'), msg="查询成功") + + except ValueError as e: + # 类型不存在或无权限访问 + 
api_logger.warning(f"Validation error in get class: {str(e)}") + return fail(BizCode.NOT_FOUND, "请求参数无效", str(e)) + + except RuntimeError as e: + api_logger.error(f"Runtime error in get class: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e)) + + except Exception as e: + api_logger.error(f"Unexpected error in get class: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e)) + + +async def classes_handler( + scene_id: str, + class_name: Optional[str] = None, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """获取类型列表(支持模糊搜索和全量查询) + + 当提供 class_name 参数时,进行模糊搜索; + 当不提供 class_name 参数时,返回场景下的所有类型。 + + Args: + scene_id: 场景ID(必填) + class_name: 类型名称关键词(可选,支持模糊匹配) + db: 数据库会话 + current_user: 当前用户 + """ + operation = "search" if class_name else "list" + api_logger.info( + f"Class {operation} requested by user {current_user.id}, " + f"keyword={class_name}, scene_id={scene_id}" + ) + + try: + # 验证UUID格式 + try: + scene_uuid = UUID(scene_id) + except ValueError: + api_logger.warning(f"Invalid scene_id format: {scene_id}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的场景ID格式") + + # 获取当前工作空间ID + workspace_id = current_user.current_workspace_id + if not workspace_id: + api_logger.warning(f"User {current_user.id} has no current workspace") + return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + + # 创建Service + service = _get_dummy_ontology_service(db) + + # 获取场景信息 + scene = service.get_scene_by_id(scene_uuid, workspace_id) + if not scene: + api_logger.warning(f"Scene not found: {scene_id}") + return fail(BizCode.NOT_FOUND, "场景不存在", f"未找到ID为 {scene_id} 的场景") + + # 根据是否提供 class_name 决定查询方式 + if class_name and class_name.strip(): + # 模糊搜索类型 + classes = service.search_classes_by_name(class_name.strip(), scene_uuid, workspace_id) + else: + # 获取所有类型 + classes = service.list_classes_by_scene(scene_uuid, workspace_id) + + # 构建响应 + items = [] + for ontology_class in classes: + 
items.append(ClassResponse( + class_id=ontology_class.class_id, + class_name=ontology_class.class_name, + class_description=ontology_class.class_description, + scene_id=ontology_class.scene_id, + created_at=ontology_class.created_at, + updated_at=ontology_class.updated_at + )) + + response = ClassListResponse( + total=len(items), + scene_id=scene_uuid, + scene_name=scene.scene_name, + scene_description=scene.scene_description, + items=items + ) + + if class_name: + api_logger.info( + f"Class search completed: found {len(items)} classes matching '{class_name}' " + f"in scene {scene_id}" + ) + else: + api_logger.info(f"Class list retrieved successfully, count={len(items)}") + + return success(data=response.model_dump(mode='json'), msg="查询成功") + + except ValueError as e: + api_logger.warning(f"Validation error in class {operation}: {str(e)}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) + + except RuntimeError as e: + api_logger.error(f"Runtime error in class {operation}: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e)) + + except Exception as e: + api_logger.error(f"Unexpected error in class {operation}: {str(e)}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e)) diff --git a/api/app/controllers/prompt_optimizer_controller.py b/api/app/controllers/prompt_optimizer_controller.py index dba52d0b..61195deb 100644 --- a/api/app/controllers/prompt_optimizer_controller.py +++ b/api/app/controllers/prompt_optimizer_controller.py @@ -1,5 +1,5 @@ -import uuid import json +import uuid from fastapi import APIRouter, Depends, Path from sqlalchemy.orm import Session @@ -8,9 +8,13 @@ from starlette.responses import StreamingResponse from app.core.logging_config import get_api_logger from app.core.response_utils import success from app.dependencies import get_current_user, get_db -from app.models.prompt_optimizer_model import RoleType -from app.schemas.prompt_optimizer_schema import PromptOptMessage, PromptOptModelSet, 
CreateSessionResponse, \ - OptimizePromptResponse, SessionHistoryResponse, SessionMessage +from app.schemas.prompt_optimizer_schema import ( + PromptOptMessage, + CreateSessionResponse, + SessionHistoryResponse, + SessionMessage, + PromptSaveRequest +) from app.schemas.response_schema import ApiResponse from app.services.prompt_optimizer_service import PromptOptimizerService @@ -135,3 +139,109 @@ async def get_prompt_opt( "X-Accel-Buffering": "no" } ) + + +@router.post( + "/releases", + summary="Get prompt optimization", + response_model=ApiResponse +) +def save_prompt( + data: PromptSaveRequest, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), +): + """ + Save a prompt release for the current tenant. + + Args: + data (PromptSaveRequest): Request body containing session_id, title, and prompt. + db (Session): SQLAlchemy database session, injected via dependency. + current_user: Currently authenticated user object, injected via dependency. + + Returns: + ApiResponse: Standard API response containing the saved prompt release info: + - id: UUID of the prompt release + - session_id: associated session + - title: prompt title + - prompt: prompt content + - created_at: timestamp of creation + + Raises: + Any database or service exceptions are propagated to the global exception handler. + """ + service = PromptOptimizerService(db) + prompt_info = service.save_prompt( + tenant_id=current_user.tenant_id, + session_id=data.session_id, + title=data.title, + prompt=data.prompt + ) + return success(data=prompt_info) + + +@router.delete( + "/releases/{prompt_id}", + summary="Delete prompt (soft delete)", + response_model=ApiResponse +) +def delete_prompt( + prompt_id: uuid.UUID = Path(..., description="Prompt ID"), + db: Session = Depends(get_db), + current_user=Depends(get_current_user), +): + """ + Soft delete a prompt release. 
+ + Args: + prompt_id + db (Session): Database session + current_user: Current logged-in user + + Returns: + ApiResponse: Success message confirming deletion + """ + service = PromptOptimizerService(db) + service.delete_prompt( + tenant_id=current_user.tenant_id, + prompt_id=prompt_id + ) + return success(msg="Prompt deleted successfully") + + +@router.get( + "/releases/list", + summary="Get paginated list of released prompts with optional filter", + response_model=ApiResponse +) +def get_release_list( + page: int = 1, + page_size: int = 20, + keyword: str | None = None, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), +): + """ + Retrieve paginated list of released prompts for the current tenant. + Optionally filter by keyword in title. + + Args: + page (int): Page number (starting from 1) + page_size (int): Number of items per page (max 100) + keyword (str | None): Optional keyword to filter prompt titles + db (Session): Database session + current_user: Current logged-in user + + Returns: + ApiResponse: Contains paginated list of prompt releases with metadata + """ + service = PromptOptimizerService(db) + result = service.get_release_list( + tenant_id=current_user.tenant_id, + page=max(1, page), + page_size=min(max(1, page_size), 100), + filter_keyword=keyword + ) + return success(data=result) + + diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index a34c781f..441609ac 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -11,7 +11,8 @@ import os import time from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence - +from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages, format_parsing, messages_parse +from app.core.memory.agent.langgraph_graph.write_graph import long_term_storage from app.db import get_db from app.core.logging_config import get_business_logger from app.core.memory.agent.utils.redis_tool 
import store @@ -106,7 +107,7 @@ class LangChainAgent: "streaming": streaming, "tool_count": len(self.tools), "tool_names": [tool.name for tool in self.tools] if self.tools else [], - "tool_count": len(self.tools) + # "tool_count": len(self.tools) } ) @@ -145,38 +146,33 @@ class LangChainAgent: user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}" messages.append(HumanMessage(content=user_content)) - return messages -# TODO 乐力齐 - 累积多组对话批量写入功能已禁用 - # async def term_memory_save(self,messages,end_user_end,aimessages): - # '''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j''' - # end_user_end=f"Term_{end_user_end}" - # print(messages) - # print(aimessages) - # session_id = store.save_session( - # userid=end_user_end, - # messages=messages, - # apply_id=end_user_end, - # end_user_id=end_user_end, - # aimessages=aimessages - # ) - # store.delete_duplicate_sessions() - # # logger.info(f'Redis_Agent:{end_user_end};{session_id}') - # return session_id -# TODO 乐力齐 - 累积多组对话批量写入功能已禁用 - # async def term_memory_redis_read(self,end_user_end): - # end_user_end = f"Term_{end_user_end}" - # history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end) - # # logger.info(f'Redis_Agent:{end_user_end};{history}') - # messagss_list=[] - # retrieved_content=[] - # for messages in history: - # query = messages.get("Query") - # aimessages = messages.get("Answer") - # messagss_list.append(f'用户:{query}。AI回复:{aimessages}') - # retrieved_content.append({query: aimessages}) - # return messagss_list,retrieved_content + async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type): + db = next(get_db()) + scope=6 + + try: + repo = LongTermMemoryRepository(db) + await long_term_storage(long_term_type="chunk", langchain_messages=long_term_messages, + memory_config=actual_config_id, end_user_id=end_user_id, scope=scope) + + from app.core.memory.agent.utils.redis_tool import write_store + result = write_store.get_session_by_userid(end_user_id) + if 
type=="chunk" or type=="aggregate": + data = await format_parsing(result, "dict") + chunk_data = data[:scope] + if len(chunk_data)==scope: + repo.upsert(end_user_id, chunk_data) + logger.info(f'写入短长期:') + else: + long_time_data = write_store.find_user_recent_sessions(end_user_id, 5) + long_messages = await messages_parse(long_time_data) + repo.upsert(end_user_id, long_messages) + logger.info(f'写入短长期:') + finally: + db.close() + async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id): """ 写入记忆(支持结构化消息) @@ -224,14 +220,6 @@ class LangChainAgent: logger.warning(f"No messages to write for user {actual_end_user_id}") return - # 调用 Celery 任务,传递结构化消息列表 - # 数据流: - # 1. structured_messages 传递给 write_message_task - # 2. write_message_task 调用 memory_agent_service.write_memory - # 3. write_memory 调用 write_tools.write,传递 messages 参数 - # 4. write_tools.write 调用 get_chunked_dialogs,传递 messages 参数 - # 5. get_chunked_dialogs 为每条消息创建独立的 Chunk,设置 speaker 字段 - # 6. 
每个 Chunk 保存到 Neo4j,包含 speaker 字段 logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") write_id = write_message_task.delay( actual_end_user_id, # end_user_id: 用户ID @@ -288,30 +276,6 @@ class LangChainAgent: actual_end_user_id = end_user_id if end_user_id is not None else "unknown" logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') -# # TODO 乐力齐,在长短期记忆存储的时候再使用此代码 -# history_term_memory_result = await self.term_memory_redis_read(end_user_id) -# history_term_memory = history_term_memory_result[0] -# db_for_memory = next(get_db()) -# if memory_flag: -# if len(history_term_memory)>=4 and storage_type != "rag": -# history_term_memory = ';'.join(history_term_memory) -# retrieved_content = history_term_memory_result[1] -# print(retrieved_content) -# # 为长期记忆操作获取新的数据库连接 -# try: -# repo = LongTermMemoryRepository(db_for_memory) -# repo.upsert(end_user_id, retrieved_content) -# logger.info( -# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') -# except Exception as e: -# logger.error(f"Failed to write to LongTermMemory: {e}") -# raise -# finally: -# db_for_memory.close() - -# # 长期记忆写入( -# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id) -# # 注意:不在这里写入用户消息,等 AI 回复后一起写入 try: # 准备消息列表 messages = self._prepare_messages(message, history, context) @@ -332,17 +296,21 @@ class LangChainAgent: # 获取最后的 AI 消息 output_messages = result.get("messages", []) content = "" + total_tokens = 0 for msg in reversed(output_messages): if isinstance(msg, AIMessage): content = msg.content + response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None + total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0 break 
elapsed_time = time.time() - start_time if memory_flag: + long_term_messages=await agent_chat_messages(message_chat,content) # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id) - # TODO 乐力齐 - 累积多组对话批量写入功能已禁用 - # await self.term_memory_save(message_chat, end_user_id, content) + '''长期''' + await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk") response = { "content": content, "model": self.model_name, @@ -350,7 +318,7 @@ class LangChainAgent: "usage": { "prompt_tokens": 0, "completion_tokens": 0, - "total_tokens": 0 + "total_tokens": total_tokens } } @@ -410,25 +378,7 @@ class LangChainAgent: db.close() except Exception as e: logger.warning(f"Failed to get db session: {e}") -# # TODO 乐力齐 -# history_term_memory_result = await self.term_memory_redis_read(end_user_id) -# history_term_memory = history_term_memory_result[0] -# if memory_flag: -# if len(history_term_memory) >= 4 and storage_type != "rag": -# history_term_memory = ';'.join(history_term_memory) -# retrieved_content = history_term_memory_result[1] -# db_for_memory = next(get_db()) -# try: -# repo = LongTermMemoryRepository(db_for_memory) -# repo.upsert(end_user_id, retrieved_content) -# logger.info( -# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') -# # 长期记忆写入 -# await self.write(storage_type, end_user_id, history_term_memory, "", user_rag_memory_id, end_user_id, actual_config_id) -# except Exception as e: -# logger.error(f"Failed to write to long term memory: {e}") -# finally: -# db_for_memory.close() + # 注意:不在这里写入用户消息,等 AI 回复后一起写入 try: @@ -444,7 +394,7 @@ class LangChainAgent: # 统一使用 agent 的 astream_events 实现流式输出 logger.debug("使用 Agent astream_events 实现流式输出") - full_content='' + full_content = '' try: async for event in self.agent.astream_events( {"messages": messages}, @@ -481,11 +431,20 @@ class LangChainAgent: 
logger.debug(f"工具调用结束: {event.get('name')}") logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件") + # 统计token消耗 + output_messages = event.get("data", {}).get("output", {}).get("messages", []) + for msg in reversed(output_messages): + if isinstance(msg, AIMessage): + response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None + total_tokens = response_meta.get("token_usage", {}).get("total_tokens", + 0) if response_meta else 0 + yield total_tokens + break if memory_flag: # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + long_term_messages = await agent_chat_messages(message_chat, full_content) await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id) - # TODO 乐力齐 - 累积多组对话批量写入功能已禁用 - # await self.term_memory_save(message_chat, end_user_id, full_content) + await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk") except Exception as e: logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) diff --git a/api/app/core/config.py b/api/app/core/config.py index a8981054..0de957c7 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -157,6 +157,11 @@ class Settings: if origin.strip() ] + # Language Configuration + # Supported values: "zh" (Chinese), "en" (English) + # This controls the language used for memory summary titles and other generated content + DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh") + # Logging settings LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO") LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s") diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py new file mode 100644 index 00000000..d6fbbb38 --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -0,0 +1,165 @@ +import os + +from app.core.logging_config import get_agent_logger +from 
app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format, format_parsing +from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph + +from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel +from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ +from app.core.memory.agent.utils.redis_tool import write_store +from app.core.memory.agent.utils.redis_tool import count_store +from app.core.memory.agent.utils.template_tools import TemplateService +from app.core.memory.utils.llm.llm_utils import MemoryClientFactory +from app.db import get_db_context +logger = get_agent_logger(__name__) +template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') + + +async def write_messages(end_user_id,langchain_messages,memory_config): + ''' + 写入数据到neo4j: + Args: + end_user_id: 终端用户ID + memory_config: 内存配置对象 + langchain_messages:原始数据LIST + ''' + try: + + async with make_write_graph() as graph: + config = {"configurable": {"thread_id": end_user_id}} + # 初始状态 - 包含所有必要字段 + initial_state = { + "messages": langchain_messages, + "end_user_id": end_user_id, + "memory_config": memory_config + } + + # 获取节点更新信息 + async for update_event in graph.astream( + initial_state, + stream_mode="updates", + config=config + ): + for node_name, node_data in update_event.items(): + if 'save_neo4j' == node_name: + massages = node_data + massagesstatus = massages.get('write_result')['status'] + contents = massages.get('write_result') + print(contents) + except Exception as e: + import traceback + traceback.print_exc() +'''根据窗口''' +async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): + ''' + 根据窗口获取redis数据,写入neo4j: + Args: + end_user_id: 终端用户ID + memory_config: 内存配置对象 + langchain_messages:原始数据LIST + scope:窗口大小 + ''' + scope=scope + is_end_user_id = count_store.get_sessions_count(end_user_id) + if is_end_user_id is not False: + is_end_user_id = 
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
    """Buffer dialogue turns in Redis and flush them to Neo4j every ``scope`` turns.

    Args:
        end_user_id: End-user identifier.
        langchain_messages: Messages of the current turn (not mutated).
        memory_config: Memory configuration object forwarded to the writer.
        scope: Window size, i.e. number of turns collected before a flush.

    Behaviour:
        * The counter/buffer record is read from Redis exactly once.
        * The current turn is appended after the buffered history, so the
          buffer stays in chronological order.
        * When the counter reaches ``scope`` the whole buffer, including the
          current turn, is written to long-term memory and the counter is
          reset to 0 with an empty buffer.
        * A counter of 0 (the state right after a flush) is a normal state and
          starts a new window.
    """
    window = int(scope)

    state = count_store.get_sessions_count(end_user_id)
    is_new_user = state is False
    if is_new_user:
        count, history = 0, []
    else:
        count, history = state

    # The stored buffer may deserialize to a non-list (e.g. an empty string);
    # treat anything else as "no history".
    if not isinstance(history, list):
        history = []

    pending = history + list(langchain_messages or [])
    count = int(count) + 1

    if count >= window:
        formatted_messages = await chat_data_format(pending)
        await write_messages(end_user_id, formatted_messages, memory_config)
        count, pending = 0, []

    if is_new_user:
        count_store.save_sessions_count(end_user_id, count, pending)
    else:
        count_store.update_sessions_count(end_user_id, count, pending)


async def memory_long_term_storage(end_user_id, memory_config, time):
    """Write the user's recent Redis sessions to Neo4j.

    Args:
        end_user_id: End-user identifier.
        memory_config: Memory configuration object forwarded to the writer.
        time: Look-back window in minutes passed to
            ``write_store.find_user_recent_sessions``.
    """
    long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
    format_messages = await chat_data_format(long_time_data)
    if format_messages:
        await write_messages(end_user_id, format_messages, memory_config)
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
    """Ask the LLM whether the new messages continue the stored history's event.

    Args:
        end_user_id: End-user identifier.
        ori_messages: New messages, e.g.
            ``[{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]``.
        memory_config: Memory configuration object; ``llm_model_id`` is read from it.

    Returns:
        On success ``{"is_same_event": bool, "output": list | bool}``.
        On failure a fallback dict with ``is_same_event=False``, the original
        messages under ``output`` and ``messages``, the history parsed so far
        and the error text under ``error``.
    """
    history = []
    try:
        # 1. Load the stored write sessions. The store returns False when the
        #    user has no history yet, so only parse a truthy result.
        stored_sessions = write_store.get_all_sessions_by_end_user_id(end_user_id)
        if stored_sessions:
            history = await format_parsing(stored_sessions)

        # 2. Render the judgment prompt.
        template_service = TemplateService(template_root)
        system_prompt = await template_service.render_template(
            template_name='write_aggregate_judgment.jinja2',
            operation_name='aggregate_judgment',
            history=history,
            sentence=ori_messages,
            json_schema=WriteAggregateModel.model_json_schema()
        )

        # 3. Ask the LLM for a structured verdict. The call stays inside the
        #    session scope so the client never outlives the DB session it was
        #    built from.
        with get_db_context() as db_session:
            factory = MemoryClientFactory(db_session)
            llm_client = factory.get_llm_client(memory_config.llm_model_id)
            structured = await llm_client.response_structured(
                messages=[{"role": "user", "content": system_prompt}],
                response_model=WriteAggregateModel
            )

        output_value = structured.output
        if isinstance(output_value, list):
            output_value = [
                {"role": msg.role, "content": msg.content}
                for msg in output_value
            ]

        result_dict = {
            "is_same_event": structured.is_same_event,
            "output": output_value
        }

        # 4. A different event is persisted to long-term memory.
        if not structured.is_same_event:
            logger.info(result_dict)
            if isinstance(output_value, list):
                await write_messages(end_user_id, output_value, memory_config)
            else:
                # The model broke the contract (boolean output for a new
                # event); there is nothing well-formed to write.
                logger.warning(
                    "[aggregate_judgment] is_same_event=False but output is not a list; skip write"
                )
        return result_dict

    except Exception as e:
        logger.error(f"[aggregate_judgment] 发生错误: {e}", exc_info=True)
        return {
            "is_same_event": False,
            "output": ori_messages,
            "messages": ori_messages,
            "history": history,
            "error": str(e)
        }
list,type:str='string'): + """ + 格式化解析消息列表 + + Args: + messages: 消息列表 + type: 返回类型 ('string' 或 'dict') + + Returns: + 格式化后的消息列表 + """ + result = [] + user=[] + ai=[] + + for message in messages: + hstory_messages = message['messages'] + for history_messag in hstory_messages.strip().splitlines(): + history_messag = json.loads(history_messag) + for content in history_messag: + role = content['role'] + content = content['content'] + if type == "string": + if role == 'human': + content = '用户:' + content + else: + content = 'AI:' + content + result.append(content) + if type == "dict": + if role == 'human': + user.append( content) + else: + ai.append(content) + if type == "dict": + for key,values in zip(user,ai): + result.append({key:values}) + return result + +async def messages_parse(messages: list | dict): + user=[] + ai=[] + database=[] + for message in messages: + Query = message['Query'] + Query = json.loads(Query) + for data in Query: + role = data['role'] + if role == "human": + user.append(data['content']) + if role == "ai": + ai.append(data['content']) + for key, values in zip(user, ai): + database.append({key, values}) + return database +async def chat_data_format(messages: list | dict): + """ + 将消息格式化为 LangChain 消息格式 + + Args: + messages: 消息列表或字典 + + Returns: + LangChain 消息列表 + """ + langchain_messages = [] + if isinstance(messages, list): + for msg in messages: + if 'role' in msg.keys(): + if msg['role'] == 'user': + langchain_messages.append(HumanMessage(content=msg['content'])) + elif msg['role'] == 'assistant': + langchain_messages.append(AIMessage(content=msg['content'])) + if "Query" in msg.keys(): + langchain_messages.append(HumanMessage(content=msg['Query'])) + langchain_messages.append(AIMessage(content=msg['Answer'])) + if isinstance(messages, dict): + if messages['type'] == 'human': + langchain_messages.append(HumanMessage(content=messages['content'])) + elif messages['type'] == 'ai': + 
async def agent_chat_messages(user_content, ai_content):
    """Build one user/assistant turn as a list of role/content dicts.

    Args:
        user_content: What the user said; rendered to text.
        ai_content: What the assistant answered; rendered to text.

    Returns:
        list[dict]: The user message followed by the assistant message.
    """
    turn = (("user", user_content), ("assistant", ai_content))
    return [{"role": speaker, "content": f"{text}"} for speaker, text in turn]
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = None, memory_config: str = '', end_user_id: str = '', scope: int = 6):
    """Record one dialogue turn and run the selected long-term storage strategy.

    Args:
        long_term_type: ``"chunk"`` (flush every ``scope`` turns), ``"time"``
            (flush the sessions of the last 5 minutes) or ``"aggregate"``
            (LLM decides whether the turn starts a new event). Any other value
            only records the turn.
        langchain_messages: Messages of the current turn; defaults to an empty
            list created per call.
        memory_config: Identifier of the memory configuration to load.
        end_user_id: End-user identifier.
        scope: Window size used by the ``"chunk"`` strategy.
    """
    # Imported lazily: the router module itself imports this module.
    from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, aggregate_judgment
    from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
    from app.core.memory.agent.utils.redis_tool import write_store

    # A fresh list per call: a shared default list would accumulate turns
    # across calls if anything downstream extended it.
    if langchain_messages is None:
        langchain_messages = []

    write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))

    # Drive the session generator by hand and always close it, so the
    # generator's own cleanup runs and the DB session is released.
    db_gen = get_db()
    db_session = next(db_gen)
    try:
        config_service = MemoryConfigService(db_session)
        loaded_config = config_service.load_memory_config(
            config_id=memory_config,
            service_name="MemoryAgentService"
        )

        if long_term_type == 'chunk':
            # Strategy 1: dialogue window of ``scope`` turns.
            await window_dialogue(end_user_id, langchain_messages, loaded_config, scope)
        elif long_term_type == 'time':
            # Strategy 2: everything from the last 5 minutes.
            await memory_long_term_storage(end_user_id, loaded_config, 5)
        elif long_term_type == 'aggregate':
            # Strategy 3: LLM-based event aggregation.
            await aggregate_judgment(end_user_id, langchain_messages, loaded_config)
    finally:
        db_gen.close()
# NOTE: the class docstrings and Field descriptions below end up in the output
# of ``model_json_schema()``, which is rendered into the aggregate-judgment
# prompt. Treat them as runtime text, not as free-form documentation.
class MessageItem(BaseModel):
    """Individual message item in conversation."""

    # Speaker of the message; the description names "user" and "assistant".
    role: str = Field(..., description="角色:user 或 assistant")
    # Text of the message.
    content: str = Field(..., description="消息内容")


class WriteAggregateResponse(BaseModel):
    """Response model for aggregate judgment containing judgment result and output."""

    # True when the new input belongs to the same event as the history.
    is_same_event: bool = Field(
        ...,
        description="是否是同一事件。True表示是同一事件,False表示不同事件"
    )
    # False for the same event; otherwise the message list to persist.
    output: Union[List[MessageItem], bool] = Field(
        ...,
        description="如果is_same_event为True,返回False;如果is_same_event为False,返回消息列表"
    )


# Kept for backward compatibility: the old class name remains as an alias.
WriteAggregateModel = WriteAggregateResponse
b/api/app/core/memory/agent/utils/prompt/write_aggregate_judgment.jinja2 @@ -0,0 +1,57 @@ +输入句子:{{sentence}} +历史消息:{{history}} + +# 你的角色 +你是一个擅长事件聚合与语义判断的专家。 + +# 你的任务 +结合历史消息和输入句子,判断它们是否在描述**同一件事件或同一事件链**。 + +以下情况视为"同一事件"(需要返回 is_same_event=True, output=False): +- 描述的是同一个具体事件或事实 +- 存在明显的因果关系、前后发展关系 +- 是对同一事件的补充、解释、追问或延展 +- 逻辑上属于同一语境下的连续讨论 + +以下情况视为"不同事件"(需要返回 is_same_event=False, output=消息列表): +- 话题不同,事件主体不同 +- 时间、地点、对象明显不同 +- 只是语义相似,但并非同一具体事件 +- 无直接事件、因果或逻辑关联 + +# 输出规则(非常重要) +你必须按照以下JSON格式输出: + +**如果是同一事件:** +```json +{ + "is_same_event": true, + "output": false +} +``` + +**如果不是同一事件:** +```json +{ + "is_same_event": false, + "output": [ + { + "role": "user", + "content": "输入句子的内容" + }, + { + "role": "assistant", + "content": "对应的回复内容" + } + ] +} +``` + +# JSON Schema +{{json_schema}} + +# 注意事项 +- 必须严格按照上述格式输出 +- output 字段:如果是同一事件返回 false,如果不是同一事件返回完整的消息列表 +- 消息列表必须包含 role 和 content 字段 +- 不要输出任何解释、分析或多余内容 diff --git a/api/app/core/memory/agent/utils/redis_base.py b/api/app/core/memory/agent/utils/redis_base.py new file mode 100644 index 00000000..59bac109 --- /dev/null +++ b/api/app/core/memory/agent/utils/redis_base.py @@ -0,0 +1,186 @@ +import json +from typing import Any, List, Dict, Optional +from datetime import datetime, timedelta + + +def serialize_messages(messages: Any) -> str: + """ + 将消息序列化为 JSON 字符串,支持 LangChain 消息对象 + + Args: + messages: 可以是 list、dict、string 或 LangChain 消息对象列表 + + Returns: + str: JSON 字符串 + """ + if isinstance(messages, str): + return messages + + if isinstance(messages, (list, tuple)): + # 检查是否是 LangChain 消息对象列表 + serialized_list = [] + for msg in messages: + if hasattr(msg, 'type') and hasattr(msg, 'content'): + # LangChain 消息对象 + serialized_list.append({ + 'type': msg.type, + 'content': msg.content, + 'role': getattr(msg, 'role', msg.type) + }) + elif isinstance(msg, dict): + serialized_list.append(msg) + else: + serialized_list.append(str(msg)) + return json.dumps(serialized_list, ensure_ascii=False) + + if isinstance(messages, 
def serialize_messages(messages: Any) -> str:
    """Serialize messages to a JSON string, with support for LangChain messages.

    Args:
        messages: A string (returned unchanged), a list/tuple (items that have
            ``type`` and ``content`` attributes become dicts, dict items are
            kept, anything else is stringified), a dict, or any other object.

    Returns:
        str: JSON text for lists, tuples and dicts; ``str(messages)`` otherwise.
    """
    if isinstance(messages, str):
        return messages

    if isinstance(messages, (list, tuple)):
        serialized_list = []
        for msg in messages:
            if hasattr(msg, 'type') and hasattr(msg, 'content'):
                # Message-like object: keep type, content and role
                # (role falls back to the type when the object has none).
                serialized_list.append({
                    'type': msg.type,
                    'content': msg.content,
                    'role': getattr(msg, 'role', msg.type)
                })
            elif isinstance(msg, dict):
                serialized_list.append(msg)
            else:
                serialized_list.append(str(msg))
        return json.dumps(serialized_list, ensure_ascii=False)

    if isinstance(messages, dict):
        return json.dumps(messages, ensure_ascii=False)

    return str(messages)


def deserialize_messages(messages_str: str) -> Any:
    """Parse a JSON string produced by ``serialize_messages``.

    Args:
        messages_str: JSON text; may be empty or ``None``.

    Returns:
        ``[]`` for empty input, the decoded object for valid JSON, and the
        input itself when it is not valid JSON.
    """
    if not messages_str:
        return []

    try:
        return json.loads(messages_str)
    except (json.JSONDecodeError, TypeError):
        return messages_str


def fix_encoding(text: str) -> str:
    """Undo UTF-8 text that was wrongly decoded as Latin-1.

    Args:
        text: Text to repair.

    Returns:
        str: The repaired text, or ``text`` unchanged when it is empty, not a
        string, or does not round-trip through Latin-1 -> UTF-8.
    """
    if not text or not isinstance(text, str):
        return text
    try:
        return text.encode('latin-1').decode('utf-8')
    except (UnicodeDecodeError, UnicodeEncodeError):
        return text


def format_session_data(data: Dict[str, Any], include_time: bool = False) -> Dict[str, Any]:
    """Map a raw session hash to the unified ``Query`` / ``Answer`` shape.

    Args:
        data: Raw session hash (``messages``, ``aimessages``, ``starttime``).
        include_time: Whether to copy ``starttime`` into the result.

    Returns:
        Dict: ``{"Query": ..., "Answer": ...}`` plus ``starttime`` on request.
    """
    result = {
        "Query": fix_encoding(data.get('messages', '')),
        "Answer": fix_encoding(data.get('aimessages', ''))
    }

    if include_time:
        result["starttime"] = data.get('starttime', '')

    return result


def filter_by_time_range(items: List[Dict], minutes: int) -> List[Dict]:
    """Keep the items whose ``starttime`` lies within the last ``minutes``.

    Args:
        items: Dicts carrying a ``starttime`` string in
            ``"YYYY-MM-DD HH:MM:SS"`` format.
        minutes: Size of the look-back window in minutes.

    Returns:
        List[Dict]: The matching items, in input order.
    """
    # The fixed-width timestamp format sorts lexicographically, so plain
    # string comparison is enough.
    threshold = (datetime.now() - timedelta(minutes=minutes)).strftime("%Y-%m-%d %H:%M:%S")
    return [
        item for item in items
        if item.get('starttime', '') and item.get('starttime', '') >= threshold
    ]


def sort_and_limit_results(items: List[Dict], limit: Optional[int] = 6,
                           remove_time: bool = True) -> List[Dict]:
    """Return the newest ``limit`` items, optionally without ``starttime``.

    The input list and its dicts are left untouched: sorting happens on a
    copy, and ``starttime`` is dropped from copies of the dicts.

    Args:
        items: Dicts that may carry a ``starttime`` string.
        limit: Maximum number of items to return; ``None`` means no limit.
        remove_time: Whether to drop the ``starttime`` field from the result.

    Returns:
        List[Dict]: Items sorted newest first (stable for equal timestamps).
    """
    newest_first = sorted(items, key=lambda x: x.get('starttime', ''), reverse=True)
    selected = newest_first[:limit]

    if remove_time:
        return [
            {field: value for field, value in item.items() if field != 'starttime'}
            for item in selected
        ]
    return selected


def generate_session_key(session_id: str, key_type: str = "session") -> str:
    """Build the Redis key for a session.

    Args:
        session_id: Session identifier.
        key_type: ``"count"`` or ``"write"`` select a dedicated namespace;
            every other value (including ``"session"`` and ``"read"``) maps to
            the plain session namespace.

    Returns:
        str: The Redis key.
    """
    namespaces = {"count": "session:count:", "write": "session:write:"}
    return f"{namespaces.get(key_type, 'session:')}{session_id}"


def get_current_timestamp() -> str:
    """Return the current local time formatted as ``"YYYY-MM-DD HH:MM:SS"``."""
    return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
class RedisWriteStore:
    """Redis store for the "write" sessions created by ``save_session_write``.

    Each record is a hash under ``session:write:<uuid>`` with the fields
    ``id``, ``sessionid`` (the owning user id), ``messages`` and ``starttime``.
    """

    _KEY_PATTERN = 'session:write:*'

    def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
        """Open the Redis connection.

        Args:
            host: Redis host.
            port: Redis port.
            db: Redis database number.
            password: Redis password, if any.
            session_id: Identifier stored in the ``id`` field of every record.
        """
        self.r = redis.Redis(
            host=host,
            port=port,
            db=db,
            password=password,
            decode_responses=True,
            encoding='utf-8'
        )
        self.uudi = session_id

    def _load_write_sessions(self) -> List[Any]:
        """Return ``(key, hash)`` pairs for every non-empty write session.

        Keys are collected with the incremental SCAN cursor instead of the
        blocking KEYS command; the hashes are then fetched in one pipeline.
        """
        keys = list(self.r.scan_iter(match=self._KEY_PATTERN))
        if not keys:
            return []

        pipe = self.r.pipeline()
        for key in keys:
            pipe.hgetall(key)
        return [(key, data) for key, data in zip(keys, pipe.execute()) if data]

    def save_session_write(self, userid: str, messages: Any) -> str:
        """Store one write session and return its generated session id.

        Args:
            userid: Owning user id (stored in the ``sessionid`` field).
            messages: Messages to store; serialized with ``serialize_messages``.

        Returns:
            str: The newly generated session id.

        Raises:
            Exception: Whatever serialization or Redis raised, after logging.
        """
        try:
            session_id = str(uuid.uuid4())
            key = generate_session_key(session_id, key_type="write")
            result = self.r.hset(key, mapping={
                "id": self.uudi,
                "sessionid": userid,
                "messages": serialize_messages(messages),
                "starttime": get_current_timestamp()
            })

            print(f"[save_session_write] 保存结果: {result}, session_id: {session_id}")
            return session_id
        except Exception as e:
            print(f"[save_session_write] 保存会话失败: {e}")
            raise

    def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
        """Return session id and messages of every write session of a user.

        Args:
            userid: User id matched against the ``sessionid`` field.

        Returns:
            ``[{"sessionid": ..., "messages": ...}, ...]``, or ``False`` when
            nothing matches or the lookup fails.
        """
        try:
            results = [
                {
                    # Key layout: session:write:{session_id}
                    "sessionid": key.split(':')[-1],
                    "messages": fix_encoding(data.get('messages', ''))
                }
                for key, data in self._load_write_sessions()
                if data.get('sessionid') == userid
            ]
            if not results:
                return False

            print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
            return results
        except Exception as e:
            print(f"[get_session_by_userid] 查询失败: {e}")
            return False

    def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
        """Return the full write-session records of an end user, newest first.

        Args:
            end_user_id: End-user id matched against the ``sessionid`` field.

        Returns:
            A list of dicts with ``session_id``, ``id``, ``sessionid``,
            ``messages`` and ``starttime``; or ``False`` when nothing matches
            or the lookup fails.
        """
        try:
            results = [
                {
                    "session_id": key.split(':')[-1],
                    "id": data.get('id', ''),
                    "sessionid": data.get('sessionid', ''),
                    "messages": fix_encoding(data.get('messages', '')),
                    "starttime": data.get('starttime', '')
                }
                for key, data in self._load_write_sessions()
                if data.get('sessionid') == end_user_id
            ]
            if not results:
                print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
                return False

            results.sort(key=lambda x: x.get('starttime', ''), reverse=True)

            print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
            return results
        except Exception as e:
            print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
            import traceback
            traceback.print_exc()
            return False

    def find_user_recent_sessions(self, userid: str,
                                  minutes: int = 5) -> List[Dict[str, str]]:
        """Return a user's write sessions from the last ``minutes`` minutes.

        Args:
            userid: User id matched against the ``sessionid`` field.
            minutes: Size of the look-back window; defaults to 5.

        Returns:
            List[Dict]: ``[{"Query": ..., "Answer": ""}, ...]`` newest first.
            ``Answer`` is always empty because write sessions store no AI
            message.
        """
        import time
        start_time = time.time()

        matched_items = [
            {
                "Query": fix_encoding(data.get('messages', '')),
                "Answer": "",
                "starttime": data.get('starttime', '')
            }
            for _key, data in self._load_write_sessions()
            if data.get('sessionid') == userid and data.get('starttime')
        ]

        filtered_items = filter_by_time_range(matched_items, minutes)
        # No size limit here; only sort and strip the timestamp.
        result_items = sort_and_limit_results(filtered_items, limit=None)

        elapsed_time = time.time() - start_time
        print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
              f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")

        return result_items

    def delete_all_write_sessions(self) -> int:
        """Delete every write session.

        Returns:
            int: Number of deleted keys.
        """
        keys = list(self.r.scan_iter(match=self._KEY_PATTERN))
        if keys:
            return self.r.delete(*keys)
        return 0
class RedisCountStore:
    """Redis store for the per-user turn counter and its buffered messages.

    Each record is a hash under ``session:count:<uuid>`` with the fields
    ``id``, ``end_user_id``, ``count``, ``messages`` and ``starttime``.
    """

    _KEY_PATTERN = 'session:count:*'
    _TTL_SECONDS = 30 * 24 * 60 * 60  # records expire after 30 days

    def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
        """Open the Redis connection.

        Args:
            host: Redis host.
            port: Redis port.
            db: Redis database number.
            password: Redis password, if any.
            session_id: Identifier stored in the ``id`` field of every record.
        """
        self.r = redis.Redis(
            host=host,
            port=port,
            db=db,
            password=password,
            decode_responses=True,
            encoding='utf-8'
        )
        self.uudi = session_id

    def _iter_user_records(self, end_user_id: str):
        """Yield ``(key, hash)`` for every count record of ``end_user_id``.

        Keys come from the incremental SCAN cursor (not the blocking KEYS
        command) and all hashes are fetched in a single pipeline round trip.
        """
        keys = list(self.r.scan_iter(match=self._KEY_PATTERN))
        if not keys:
            return

        pipe = self.r.pipeline()
        for key in keys:
            pipe.hgetall(key)
        for key, data in zip(keys, pipe.execute()):
            if data and data.get('end_user_id') == end_user_id:
                yield key, data

    def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
        """Create a counter record for a user.

        Args:
            end_user_id: End-user identifier.
            count: Initial counter value.
            messages: Buffered messages; serialized with ``serialize_messages``.

        Returns:
            str: The newly generated session id.
        """
        session_id = str(uuid.uuid4())
        key = generate_session_key(session_id, key_type="count")

        pipe = self.r.pipeline()
        pipe.hset(key, mapping={
            "id": self.uudi,
            "end_user_id": end_user_id,
            "count": int(count),
            "messages": serialize_messages(messages),
            "starttime": get_current_timestamp()
        })
        pipe.expire(key, self._TTL_SECONDS)
        result = pipe.execute()

        print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
        return session_id

    def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
        """Look up a user's counter and buffered messages.

        Args:
            end_user_id: End-user identifier.

        Returns:
            ``[count, messages]`` from the first record that has a ``count``
            field, or ``False`` when none exists or the lookup fails.
        """
        try:
            for _key, data in self._iter_user_records(end_user_id):
                count = data.get('count')
                if count is not None:
                    return [int(count), deserialize_messages(data.get('messages'))]
            return False
        except Exception as e:
            print(f"[get_sessions_count] 查询失败: {e}")
            return False

    def update_sessions_count(self, end_user_id: str, new_count: int,
                              messages: Any) -> bool:
        """Overwrite the counter and buffered messages of a user's record.

        Args:
            end_user_id: End-user identifier.
            new_count: New counter value.
            messages: New buffered messages.

        Returns:
            bool: ``True`` when a record was updated, ``False`` when none was
            found or the update failed.
        """
        try:
            messages_str = serialize_messages(messages)
            for key, _data in self._iter_user_records(end_user_id):
                # One HSET for both fields so readers never observe a new
                # count paired with the old messages.
                self.r.hset(key, mapping={'count': int(new_count), 'messages': messages_str})
                print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
                return True

            print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
            return False
        except Exception as e:
            print(f"[update_sessions_count] 更新失败: {e}")
            return False

    def delete_all_count_sessions(self) -> int:
        """Delete every counter record.

        Returns:
            int: Number of deleted keys.
        """
        keys = list(self.r.scan_iter(match=self._KEY_PATTERN))
        if keys:
            return self.r.delete(*keys)
        return 0
class RedisSessionStore:
    """Redis store for plain chat sessions kept under ``session:<uuid>``.

    Records in the ``session:count:`` and ``session:write:`` namespaces share
    the ``session:`` prefix and are always excluded here.
    """

    def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
        """Open the Redis connection.

        Args:
            host: Redis host.
            port: Redis port.
            db: Redis database number.
            password: Redis password, if any.
            session_id: Identifier stored in the ``id`` field of every record.
        """
        self.r = redis.Redis(
            host=host,
            port=port,
            db=db,
            password=password,
            decode_responses=True,
            encoding='utf-8'
        )
        self.uudi = session_id

    # ==================== helpers ====================

    def _session_keys(self) -> List[str]:
        """Return all plain session keys (count/write records filtered out).

        Uses the incremental SCAN cursor instead of the blocking KEYS command.
        """
        return [
            key for key in self.r.scan_iter(match='session:*')
            if ':count:' not in key and ':write:' not in key
        ]

    def _load_sessions(self) -> List[Any]:
        """Return ``(key, hash)`` pairs for all plain sessions in one pipeline."""
        keys = self._session_keys()
        if not keys:
            return []
        pipe = self.r.pipeline()
        for key in keys:
            pipe.hgetall(key)
        return list(zip(keys, pipe.execute()))

    # ==================== write ====================

    def save_session(self, userid: str, messages: str, aimessages: str,
                     apply_id: str, end_user_id: str) -> str:
        """Store one chat session and return its generated session id.

        Args:
            userid: User id (stored in the ``sessionid`` field).
            messages: User message.
            aimessages: AI reply.
            apply_id: Application id.
            end_user_id: End-user id.

        Returns:
            str: The newly generated session id.

        Raises:
            Exception: Whatever Redis raised, after logging.
        """
        try:
            session_id = str(uuid.uuid4())
            key = generate_session_key(session_id, key_type="read")
            result = self.r.hset(key, mapping={
                "id": self.uudi,
                "sessionid": userid,
                "apply_id": apply_id,
                "end_user_id": end_user_id,
                "messages": messages,
                "aimessages": aimessages,
                "starttime": get_current_timestamp()
            })

            print(f"[save_session] 保存结果: {result}, session_id: {session_id}")
            return session_id
        except Exception as e:
            print(f"[save_session] 保存会话失败: {e}")
            raise

    # ==================== read ====================

    def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Return one session hash, or ``None`` when it does not exist.

        Args:
            session_id: Session id.
        """
        data = self.r.hgetall(generate_session_key(session_id))
        return data if data else None

    def get_all_sessions(self) -> Dict[str, Dict[str, Any]]:
        """Return every plain session, keyed by session id."""
        sessions = {}
        for key in self._session_keys():
            sid = key.split(':')[1]
            sessions[sid] = self.get_session(sid)
        return sessions

    def find_user_apply_group(self, sessionid: str, apply_id: str,
                              end_user_id: str) -> List[Dict[str, str]]:
        """Return the six newest sessions matching user, application and end user.

        Args:
            sessionid: User id; matches when it is contained in the stored
                ``sessionid`` field (substring match, which includes equality).
            apply_id: Application id (exact match).
            end_user_id: End-user id (exact match).

        Returns:
            List[Dict]: ``[{"Query": ..., "Answer": ...}, ...]`` newest first.
        """
        import time
        start_time = time.time()

        matched_items = [
            format_session_data(data, include_time=True)
            for _key, data in self._load_sessions()
            if data
            and data.get('apply_id') == apply_id
            and data.get('end_user_id') == end_user_id
            and sessionid in data.get('sessionid', '')
        ]

        # Sort newest first, keep six, drop the timestamp.
        result_items = sort_and_limit_results(matched_items, limit=6)

        elapsed_time = time.time() - start_time
        print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")

        return result_items

    # ==================== update ====================

    def update_session(self, session_id: str, field: str, value: Any) -> bool:
        """Update a single field of an existing session.

        Nothing is written when the session does not exist, so a missing
        session is never turned into a one-field ghost record.

        Args:
            session_id: Session id.
            field: Field name.
            value: New field value.

        Returns:
            bool: ``True`` when the session existed and was updated.
        """
        key = generate_session_key(session_id)
        if not self.r.exists(key):
            return False
        self.r.hset(key, field, value)
        return True

    # ==================== delete ====================

    def delete_session(self, session_id: str) -> int:
        """Delete one session.

        Args:
            session_id: Session id.

        Returns:
            int: Number of deleted keys.
        """
        return self.r.delete(generate_session_key(session_id))

    def delete_all_sessions(self) -> int:
        """Delete every plain session (count/write records are kept).

        Returns:
            int: Number of deleted keys.
        """
        keys_to_delete = self._session_keys()
        if keys_to_delete:
            return self.r.delete(*keys_to_delete)
        return 0

    def delete_duplicate_sessions(self) -> int:
        """Delete duplicated plain sessions, keeping the first of each group.

        Two sessions are duplicates when ``sessionid``, ``id``,
        ``end_user_id``, ``messages`` and ``aimessages`` are all equal.

        Returns:
            int: Number of deleted keys.
        """
        import time
        start_time = time.time()

        records = self._load_sessions()
        if not records:
            print("[delete_duplicate_sessions] 没有会话数据")
            return 0

        seen = set()
        keys_to_delete = []
        for key, data in records:
            if not data:
                continue

            identifier = (
                data.get('sessionid', ''),
                data.get('id', ''),
                data.get('end_user_id', ''),
                data.get('messages', ''),
                data.get('aimessages', '')
            )
            if identifier in seen:
                keys_to_delete.append(key)
            else:
                seen.add(identifier)

        # Delete in batches so a single DEL never carries too many keys.
        deleted_count = 0
        batch_size = 1000
        for i in range(0, len(keys_to_delete), batch_size):
            batch = keys_to_delete[i:i + batch_size]
            deleted_count += self.r.delete(*batch)

        elapsed_time = time.time() - start_time
        print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
        return deleted_count
data: - continue - - # 检查是否符合三个条件 - - if (data.get('apply_id') == apply_id and - data.get('end_user_id') == end_user_id): - # 支持模糊匹配 sessionid 或者完全匹配 - if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid: - matched_items.append({ - "Query": self._fix_encoding(data.get('messages')), - "Answer": self._fix_encoding(data.get('aimessages')), - "starttime": data.get('starttime', '') - }) - # 按时间降序排序(最新的在前) - matched_items.sort(key=lambda x: x.get('starttime', ''), reverse=True) - # 只保留最新的6条 - result_items = matched_items[:6] - # # 移除 starttime 字段 - for item in result_items: - item.pop('starttime', None) - - # 如果结果少于等于1条,返回空列表 - if len(result_items) <= 1: - result_items = [] - - elapsed_time = time.time() - start_time - print(f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") - - return result_items - +# 全局实例 store = RedisSessionStore( host=settings.REDIS_HOST, port=settings.REDIS_PORT, db=settings.REDIS_DB, password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None, session_id=str(uuid.uuid4()) -) \ No newline at end of file +) + +write_store = RedisWriteStore( + host=settings.REDIS_HOST, + port=settings.REDIS_PORT, + db=settings.REDIS_DB, + password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None, + session_id=str(uuid.uuid4()) +) + +count_store = RedisCountStore( + host=settings.REDIS_HOST, + port=settings.REDIS_PORT, + db=settings.REDIS_DB, + password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None, + session_id=str(uuid.uuid4()) +) diff --git a/api/app/core/memory/models/__init__.py b/api/app/core/memory/models/__init__.py index 1de3424a..8c573b7a 100644 --- a/api/app/core/memory/models/__init__.py +++ b/api/app/core/memory/models/__init__.py @@ -58,6 +58,12 @@ from app.core.memory.models.triplet_models import ( TripletExtractionResponse, ) +# Ontology models +from app.core.memory.models.ontology_models import ( + OntologyClass, + OntologyExtractionResponse, +) + # Variable configuration models 
from app.core.memory.models.variate_config import ( StatementExtractionConfig, @@ -105,6 +111,9 @@ __all__ = [ "Entity", "Triplet", "TripletExtractionResponse", + # Ontology models + "OntologyClass", + "OntologyExtractionResponse", # Variable configuration "StatementExtractionConfig", "ForgettingEngineConfig", diff --git a/api/app/core/memory/models/ontology_models.py b/api/app/core/memory/models/ontology_models.py new file mode 100644 index 00000000..24a61f5f --- /dev/null +++ b/api/app/core/memory/models/ontology_models.py @@ -0,0 +1,135 @@ +"""Models for ontology classes and extraction responses. + +This module contains Pydantic models for representing extracted ontology classes +from scenario descriptions, following OWL ontology engineering standards. + +Classes: + OntologyClass: Represents an extracted ontology class + OntologyExtractionResponse: Response model containing extracted ontology classes +""" + +from typing import List, Optional +from uuid import uuid4 + +from pydantic import BaseModel, ConfigDict, Field, field_validator + + +class OntologyClass(BaseModel): + """Represents an extracted ontology class from scenario description. + + An ontology class represents an abstract category or concept in a domain, + following OWL ontology engineering standards and naming conventions. 
+ + Attributes: + id: Unique string identifier for the ontology class + name: Name of the class in PascalCase format (e.g., 'MedicalProcedure') + name_chinese: Chinese translation of the class name (e.g., '医疗程序') + description: Textual description of the class + examples: List of concrete instance examples of this class + parent_class: Optional name of the parent class in the hierarchy + entity_type: Type/category of the entity (e.g., 'Person', 'Organization', 'Concept') + domain: Domain this class belongs to (e.g., 'Healthcare', 'Education') + + Config: + extra: Ignore extra fields from LLM output + """ + model_config = ConfigDict(extra='ignore') + + id: str = Field( + default_factory=lambda: uuid4().hex, + description="Unique identifier for the ontology class" + ) + name: str = Field( + ..., + description="Name of the class in PascalCase format" + ) + name_chinese: Optional[str] = Field( + None, + description="Chinese translation of the class name" + ) + description: str = Field( + ..., + description="Description of the class" + ) + examples: List[str] = Field( + default_factory=list, + description="List of concrete instance examples" + ) + parent_class: Optional[str] = Field( + None, + description="Name of the parent class in the hierarchy" + ) + entity_type: str = Field( + ..., + description="Type/category of the entity" + ) + domain: str = Field( + ..., + description="Domain this class belongs to" + ) + + @field_validator('name') + @classmethod + def validate_pascal_case(cls, v: str) -> str: + """Validate that the class name follows PascalCase convention. 
+ + PascalCase rules: + - Must start with an uppercase letter + - Cannot contain spaces + - Should not contain special characters except underscores + + Args: + v: The class name to validate + + Returns: + The validated class name + + Raises: + ValueError: If the name doesn't follow PascalCase convention + """ + if not v: + raise ValueError("Class name cannot be empty") + + if not v[0].isupper(): + raise ValueError( + f"Class name '{v}' must start with an uppercase letter (PascalCase)" + ) + + if ' ' in v: + raise ValueError( + f"Class name '{v}' cannot contain spaces (PascalCase)" + ) + + # Check for invalid characters (allow alphanumeric and underscore only) + if not all(c.isalnum() or c == '_' for c in v): + raise ValueError( + f"Class name '{v}' contains invalid characters. " + "Only alphanumeric characters and underscores are allowed" + ) + + return v + + +class OntologyExtractionResponse(BaseModel): + """Response model for ontology extraction from LLM. + + This model represents the structured output from the LLM when + extracting ontology classes from scenario descriptions. 
+ + Attributes: + classes: List of extracted ontology classes + domain: Domain/field the scenario belongs to + + Config: + extra: Ignore extra fields from LLM output + """ + model_config = ConfigDict(extra='ignore') + + classes: List[OntologyClass] = Field( + default_factory=list, + description="List of extracted ontology classes" + ) + domain: str = Field( + ..., + description="Domain/field the scenario belongs to" + ) diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/__init__.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/__init__.py index 53815124..0bc09622 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/__init__.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/__init__.py @@ -8,4 +8,5 @@ - TemporalExtractor: 时间信息提取 - EmbeddingGenerator: 嵌入向量生成 - MemorySummaryGenerator: 记忆摘要生成 +- OntologyExtractor: 本体类提取 """ diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py index f39313a8..58633363 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py @@ -14,6 +14,34 @@ from pydantic import Field logger = get_memory_logger(__name__) +# 支持的语言列表和默认回退值 +SUPPORTED_LANGUAGES = {"zh", "en"} +FALLBACK_LANGUAGE = "en" + + +def validate_language(language: Optional[str]) -> str: + """ + 校验语言参数,确保其为有效值。 + + Args: + language: 待校验的语言代码 + + Returns: + 有效的语言代码("zh" 或 "en") + """ + if language is None: + return FALLBACK_LANGUAGE + + lang = str(language).lower().strip() + if lang in SUPPORTED_LANGUAGES: + return lang + + logger.warning( + f"无效的语言参数 '{language}',已回退到默认值 '{FALLBACK_LANGUAGE}'。" + f"支持的语言: {SUPPORTED_LANGUAGES}" + ) + return 
FALLBACK_LANGUAGE + class MemorySummaryResponse(RobustLLMResponse): """Structured response for summary generation per chunk. @@ -31,7 +59,8 @@ class MemorySummaryResponse(RobustLLMResponse): async def generate_title_and_type_for_summary( content: str, - llm_client + llm_client, + language: str = None ) -> Tuple[str, str]: """ 为MemorySummary生成标题和类型 @@ -41,11 +70,18 @@ async def generate_title_and_type_for_summary( Args: content: Summary的内容文本 llm_client: LLM客户端实例 + language: 生成标题使用的语言 ("zh" 中文, "en" 英文),如果为None则从配置读取 Returns: (标题, 类型)元组 """ from app.core.memory.utils.prompt.prompt_utils import render_episodic_title_and_type_prompt + from app.core.config import settings + + # 如果没有指定语言,从配置中读取,并校验有效性 + if language is None: + language = settings.DEFAULT_LANGUAGE + language = validate_language(language) # 定义有效的类型集合 VALID_TYPES = { @@ -57,13 +93,19 @@ async def generate_title_and_type_for_summary( } DEFAULT_TYPE = "conversation" # 默认类型 + # 根据语言设置默认标题 + DEFAULT_TITLE = "空内容" if language == "zh" else "Empty Content" + PARSE_ERROR_TITLE = "解析失败" if language == "zh" else "Parse Failed" + ERROR_TITLE = "错误" if language == "zh" else "Error" + UNKNOWN_TITLE = "未知标题" if language == "zh" else "Unknown Title" + try: if not content: - logger.warning("content为空,无法生成标题和类型") - return ("空内容", DEFAULT_TYPE) + logger.warning(f"content为空,无法生成标题和类型 (language={language})") + return (DEFAULT_TITLE, DEFAULT_TYPE) - # 1. 渲染Jinja2提示词模板 - prompt = await render_episodic_title_and_type_prompt(content) + # 1. 渲染Jinja2提示词模板,传递语言参数 + prompt = await render_episodic_title_and_type_prompt(content, language=language) # 2. 调用LLM生成标题和类型 messages = [ @@ -102,7 +144,7 @@ async def generate_title_and_type_for_summary( json_str = json_str.strip() result_data = json.loads(json_str) - title = result_data.get("title", "未知标题") + title = result_data.get("title", UNKNOWN_TITLE) episodic_type_raw = result_data.get("type", DEFAULT_TYPE) # 5. 
校验和归一化类型 @@ -130,16 +172,16 @@ async def generate_title_and_type_for_summary( f"已归一化为 '{episodic_type}'" ) - logger.info(f"成功生成标题和类型: title={title}, type={episodic_type}") + logger.info(f"成功生成标题和类型 (language={language}): title={title}, type={episodic_type}") return (title, episodic_type) except json.JSONDecodeError: - logger.error(f"无法解析LLM响应为JSON: {full_response}") - return ("解析失败", DEFAULT_TYPE) + logger.error(f"无法解析LLM响应为JSON (language={language}): {full_response}") + return (PARSE_ERROR_TITLE, DEFAULT_TYPE) except Exception as e: - logger.error(f"生成标题和类型时出错: {str(e)}", exc_info=True) - return ("错误", DEFAULT_TYPE) + logger.error(f"生成标题和类型时出错 (language={language}): {str(e)}", exc_info=True) + return (ERROR_TITLE, DEFAULT_TYPE) async def _process_chunk_summary( dialog: DialogData, @@ -153,11 +195,16 @@ async def _process_chunk_summary( return None try: + # 从配置中获取语言设置(只获取一次,复用),并校验有效性 + from app.core.config import settings + language = validate_language(settings.DEFAULT_LANGUAGE) + # Render prompt via Jinja2 for a single chunk prompt_content = await render_memory_summary_prompt( chunk_texts=chunk.content, json_schema=MemorySummaryResponse.model_json_schema(), max_words=200, + language=language, ) messages = [ @@ -178,9 +225,10 @@ async def _process_chunk_summary( try: title, episodic_type = await generate_title_and_type_for_summary( content=summary_text, - llm_client=llm_client + llm_client=llm_client, + language=language ) - logger.info(f"Generated title and type for MemorySummary: title={title}, type={episodic_type}") + logger.info(f"Generated title and type for MemorySummary (language={language}): title={title}, type={episodic_type}") except Exception as e: logger.warning(f"Failed to generate title and type for chunk {chunk.id}: {e}") # Continue without title and type diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/ontology_extraction.py 
b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/ontology_extraction.py new file mode 100644 index 00000000..d1b79bd1 --- /dev/null +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/ontology_extraction.py @@ -0,0 +1,482 @@ +"""Ontology class extraction from scenario descriptions using LLM. + +This module provides the OntologyExtractor class for extracting ontology classes +from natural language scenario descriptions. It uses LLM-driven extraction combined +with two-layer validation (string validation + OWL semantic validation). + +Classes: + OntologyExtractor: Extracts ontology classes from scenario descriptions +""" + +import asyncio +import logging +import time +from typing import List, Optional + +from app.core.memory.llm_tools.openai_client import OpenAIClient +from app.core.memory.models.ontology_models import ( + OntologyClass, + OntologyExtractionResponse, +) +from app.core.memory.utils.validation.ontology_validator import OntologyValidator +from app.core.memory.utils.validation.owl_validator import OWLValidator +from app.core.memory.utils.prompt.prompt_utils import render_ontology_extraction_prompt + + +logger = logging.getLogger(__name__) + + +class OntologyExtractor: + """Extractor for ontology classes from scenario descriptions. + + This extractor uses LLM to identify abstract classes and concepts from + natural language scenario descriptions, following OWL ontology engineering + standards. It performs two-layer validation: + 1. String validation (naming conventions, reserved words, duplicates) + 2. OWL semantic validation (consistency checking, circular inheritance) + + Attributes: + llm_client: OpenAI client for LLM calls + validator: String validator for class names and descriptions + owl_validator: OWL validator for semantic validation + """ + + def __init__(self, llm_client: OpenAIClient): + """Initialize the OntologyExtractor. 
+ + Args: + llm_client: OpenAIClient instance for LLM processing + """ + self.llm_client = llm_client + self.validator = OntologyValidator() + self.owl_validator = OWLValidator() + + logger.info("OntologyExtractor initialized") + + async def extract_ontology_classes( + self, + scenario: str, + domain: Optional[str] = None, + max_classes: int = 15, + min_classes: int = 5, + enable_owl_validation: bool = True, + llm_temperature: float = 0.3, + llm_max_tokens: int = 2000, + max_description_length: int = 500, + timeout: Optional[float] = None, + ) -> OntologyExtractionResponse: + """Extract ontology classes from a scenario description. + + This is the main extraction method that orchestrates the entire process: + 1. Call LLM to extract ontology classes + 2. Perform first-layer validation (string validation and cleaning) + 3. Perform second-layer validation (OWL semantic validation) + 4. Filter invalid classes based on validation errors + 5. Return validated ontology classes + + Args: + scenario: Natural language scenario description + domain: Optional domain hint (e.g., "Healthcare", "Education") + max_classes: Maximum number of classes to extract (default: 15) + min_classes: Minimum number of classes to extract (default: 5) + enable_owl_validation: Whether to enable OWL validation (default: True) + llm_temperature: LLM temperature parameter (default: 0.3) + llm_max_tokens: LLM max tokens parameter (default: 2000) + max_description_length: Maximum description length (default: 500) + timeout: Optional timeout in seconds for LLM call (default: None, no timeout) + + Returns: + OntologyExtractionResponse containing validated ontology classes + + Raises: + ValueError: If scenario is empty or invalid + asyncio.TimeoutError: If extraction times out + + Examples: + >>> extractor = OntologyExtractor(llm_client) + >>> response = await extractor.extract_ontology_classes( + ... scenario="A hospital manages patient records...", + ... domain="Healthcare", + ... max_classes=10, + ... 
timeout=30.0 + ... ) + >>> len(response.classes) + 7 + """ + # Start timing + start_time = time.time() + + # Validate input + if not scenario or not scenario.strip(): + logger.error("Scenario description is empty") + raise ValueError("Scenario description cannot be empty") + + scenario = scenario.strip() + + logger.info( + f"Starting ontology extraction - scenario_length={len(scenario)}, " + f"domain={domain}, max_classes={max_classes}, min_classes={min_classes}, " + f"timeout={timeout}" + ) + + try: + # Step 1: Call LLM for extraction with timeout + logger.info("Step 1: Calling LLM for ontology extraction") + llm_start_time = time.time() + + if timeout is not None: + # Wrap LLM call with timeout + try: + response = await asyncio.wait_for( + self._call_llm_for_extraction( + scenario=scenario, + domain=domain, + max_classes=max_classes, + llm_temperature=llm_temperature, + llm_max_tokens=llm_max_tokens, + ), + timeout=timeout + ) + except asyncio.TimeoutError: + llm_duration = time.time() - llm_start_time + logger.error( + f"LLM extraction timed out after {timeout} seconds " + f"(actual duration: {llm_duration:.2f}s)" + ) + # Return empty response on timeout + return OntologyExtractionResponse( + classes=[], + domain=domain or "Unknown", + ) + else: + # No timeout specified, call directly + response = await self._call_llm_for_extraction( + scenario=scenario, + domain=domain, + max_classes=max_classes, + llm_temperature=llm_temperature, + llm_max_tokens=llm_max_tokens, + ) + + llm_duration = time.time() - llm_start_time + logger.info( + f"LLM returned {len(response.classes)} classes in {llm_duration:.2f}s" + ) + + # Step 2: First-layer validation (string validation and cleaning) + logger.info("Step 2: Performing first-layer validation (string validation)") + validation_start_time = time.time() + + response = self._validate_and_clean( + response=response, + max_description_length=max_description_length, + ) + + validation_duration = time.time() - validation_start_time 
+ logger.info( + f"After first-layer validation: {len(response.classes)} classes remain " + f"(validation took {validation_duration:.2f}s)" + ) + + # Check if we have enough classes after first-layer validation + if len(response.classes) < min_classes: + logger.warning( + f"Only {len(response.classes)} classes remain after validation, " + f"which is below minimum of {min_classes}" + ) + + # Step 3: Second-layer validation (OWL semantic validation) + if enable_owl_validation and response.classes: + logger.info("Step 3: Performing second-layer validation (OWL validation)") + owl_start_time = time.time() + + is_valid, errors, world = self.owl_validator.validate_ontology_classes( + classes=response.classes, + ) + + owl_duration = time.time() - owl_start_time + + if not is_valid: + logger.warning( + f"OWL validation found {len(errors)} issues in {owl_duration:.2f}s: {errors}" + ) + + # Filter invalid classes based on errors + response = self._filter_invalid_classes( + response=response, + errors=errors, + ) + + logger.info( + f"After second-layer validation: {len(response.classes)} classes remain" + ) + else: + logger.info(f"OWL validation passed successfully in {owl_duration:.2f}s") + else: + if not enable_owl_validation: + logger.info("Step 3: OWL validation disabled, skipping") + else: + logger.info("Step 3: No classes to validate, skipping OWL validation") + + # Calculate total duration + total_duration = time.time() - start_time + + # Log extraction statistics + logger.info( + f"Ontology extraction completed - " + f"final_class_count={len(response.classes)}, " + f"domain={response.domain}, " + f"total_duration={total_duration:.2f}s, " + f"llm_duration={llm_duration:.2f}s" + ) + + return response + + except asyncio.TimeoutError: + # Re-raise timeout errors + total_duration = time.time() - start_time + logger.error( + f"Ontology extraction timed out after {timeout} seconds " + f"(total duration: {total_duration:.2f}s)", + exc_info=True + ) + raise + except Exception 
as e: + total_duration = time.time() - start_time + logger.error( + f"Ontology extraction failed after {total_duration:.2f}s: {str(e)}", + exc_info=True + ) + # Return empty response on failure + return OntologyExtractionResponse( + classes=[], + domain=domain or "Unknown", + ) + + async def _call_llm_for_extraction( + self, + scenario: str, + domain: Optional[str], + max_classes: int, + llm_temperature: float, + llm_max_tokens: int, + ) -> OntologyExtractionResponse: + """Call LLM to extract ontology classes from scenario. + + This method renders the extraction prompt using the Jinja2 template + and calls the LLM with structured output to get ontology classes. + + Args: + scenario: Scenario description text + domain: Optional domain hint + max_classes: Maximum number of classes to extract + llm_temperature: LLM temperature parameter + llm_max_tokens: LLM max tokens parameter + + Returns: + OntologyExtractionResponse from LLM + + Raises: + Exception: If LLM call fails + """ + try: + # Render prompt using template + prompt_content = await render_ontology_extraction_prompt( + scenario=scenario, + domain=domain, + max_classes=max_classes, + json_schema=OntologyExtractionResponse.model_json_schema(), + ) + + logger.debug(f"Rendered prompt length: {len(prompt_content)}") + + # Create messages for LLM + messages = [ + { + "role": "system", + "content": ( + "You are an expert ontology engineer specializing in knowledge " + "representation and OWL standards. Extract ontology classes from " + "scenario descriptions following the provided instructions. " + "Return valid JSON conforming to the schema." 
+ ), + }, + { + "role": "user", + "content": prompt_content, + }, + ] + + # Call LLM with structured output + logger.debug( + f"Calling LLM with temperature={llm_temperature}, " + f"max_tokens={llm_max_tokens}" + ) + + response = await self.llm_client.response_structured( + messages=messages, + response_model=OntologyExtractionResponse, + ) + + logger.info( + f"LLM extraction successful - extracted {len(response.classes)} classes" + ) + + return response + + except Exception as e: + logger.error( + f"LLM extraction failed: {str(e)}", + exc_info=True + ) + raise + + def _validate_and_clean( + self, + response: OntologyExtractionResponse, + max_description_length: int, + ) -> OntologyExtractionResponse: + """Perform first-layer validation: string validation and cleaning. + + This method validates and cleans the extracted ontology classes: + 1. Validate class names (PascalCase, no reserved words) + 2. Sanitize invalid class names + 3. Truncate long descriptions + 4. Remove duplicate classes + + Args: + response: OntologyExtractionResponse from LLM + max_description_length: Maximum description length + + Returns: + Cleaned OntologyExtractionResponse + """ + if not response.classes: + logger.debug("No classes to validate") + return response + + logger.debug(f"Validating {len(response.classes)} classes") + + validated_classes = [] + + for ontology_class in response.classes: + # Validate class name + is_valid, error_msg = self.validator.validate_class_name( + ontology_class.name + ) + + if not is_valid: + logger.warning( + f"Invalid class name '{ontology_class.name}': {error_msg}" + ) + + # Attempt to sanitize + sanitized_name = self.validator.sanitize_class_name( + ontology_class.name + ) + + logger.info( + f"Sanitized class name: '{ontology_class.name}' -> '{sanitized_name}'" + ) + + # Update class name + ontology_class.name = sanitized_name + + # Re-validate sanitized name + is_valid, error_msg = self.validator.validate_class_name( + sanitized_name + ) + + if not 
is_valid: + logger.error( + f"Failed to sanitize class name '{ontology_class.name}': {error_msg}. " + "Skipping this class." + ) + continue + + # Truncate description if too long + if ontology_class.description: + original_length = len(ontology_class.description) + ontology_class.description = self.validator.truncate_description( + ontology_class.description, + max_length=max_description_length, + ) + + if len(ontology_class.description) < original_length: + logger.debug( + f"Truncated description for '{ontology_class.name}': " + f"{original_length} -> {len(ontology_class.description)} chars" + ) + + validated_classes.append(ontology_class) + + # Remove duplicates (case-insensitive) + original_count = len(validated_classes) + validated_classes = self.validator.remove_duplicates(validated_classes) + + if len(validated_classes) < original_count: + logger.info( + f"Removed {original_count - len(validated_classes)} duplicate classes" + ) + + # Return cleaned response + return OntologyExtractionResponse( + classes=validated_classes, + domain=response.domain, + ) + + def _filter_invalid_classes( + self, + response: OntologyExtractionResponse, + errors: List[str], + ) -> OntologyExtractionResponse: + """Filter invalid classes based on OWL validation errors. + + This method analyzes OWL validation errors and removes classes + that caused validation failures (e.g., circular inheritance, + inconsistencies). 
+ + Args: + response: OntologyExtractionResponse to filter + errors: List of error messages from OWL validation + + Returns: + Filtered OntologyExtractionResponse + """ + if not errors: + return response + + logger.debug(f"Filtering classes based on {len(errors)} OWL validation errors") + + # Extract class names mentioned in errors + invalid_class_names = set() + + for error in errors: + # Look for class names in error messages + for ontology_class in response.classes: + if ontology_class.name in error: + invalid_class_names.add(ontology_class.name) + logger.debug( + f"Class '{ontology_class.name}' marked as invalid due to error: {error}" + ) + + # Filter out invalid classes + if invalid_class_names: + original_count = len(response.classes) + + filtered_classes = [ + c for c in response.classes + if c.name not in invalid_class_names + ] + + logger.info( + f"Filtered out {original_count - len(filtered_classes)} invalid classes: " + f"{invalid_class_names}" + ) + + return OntologyExtractionResponse( + classes=filtered_classes, + domain=response.domain, + ) + + return response diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py index bfc0bc88..8c3e31b4 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py @@ -25,6 +25,15 @@ class TripletExtractor: """ self.llm_client = llm_client + def _get_language(self) -> str: + """Get the configured language for entity descriptions + + Returns: + Language code ("zh" or "en") + """ + from app.core.config import settings + return settings.DEFAULT_LANGUAGE + async def _extract_triplets(self, statement: Statement, chunk_content: str) -> TripletExtractionResponse: """Process a single statement and return extracted 
triplets and entities""" # Render the prompt using helper function @@ -40,7 +49,8 @@ class TripletExtractor: statement=statement.statement, chunk_content=chunk_content, json_schema=TripletExtractionResponse.model_json_schema(), - predicate_instructions=PREDICATE_DEFINITIONS + predicate_instructions=PREDICATE_DEFINITIONS, + language=self._get_language() ) # Create messages for LLM diff --git a/api/app/core/memory/utils/prompt/prompt_utils.py b/api/app/core/memory/utils/prompt/prompt_utils.py index 50593e49..a4d2af95 100644 --- a/api/app/core/memory/utils/prompt/prompt_utils.py +++ b/api/app/core/memory/utils/prompt/prompt_utils.py @@ -177,7 +177,7 @@ def render_entity_dedup_prompt( # Args: # entity_a: Dict of entity A attributes -async def render_triplet_extraction_prompt(statement: str, chunk_content: str, json_schema: dict, predicate_instructions: dict = None) -> str: +async def render_triplet_extraction_prompt(statement: str, chunk_content: str, json_schema: dict, predicate_instructions: dict = None, language: str = "zh") -> str: """ Renders the triplet extraction prompt using the extract_triplet.jinja2 template. 
@@ -186,6 +186,7 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j chunk_content: The content of the chunk to process json_schema: JSON schema for the expected output format predicate_instructions: Optional predicate instructions + language: The language to use for entity descriptions ("zh" for Chinese, "en" for English) Returns: Rendered prompt content as string @@ -195,7 +196,8 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j statement=statement, chunk_content=chunk_content, json_schema=json_schema, - predicate_instructions=predicate_instructions + predicate_instructions=predicate_instructions, + language=language ) # 记录渲染结果到提示日志(与示例日志结构一致) log_prompt_rendering('triplet extraction', rendered_prompt) @@ -204,7 +206,8 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j 'statement': 'str', 'chunk_content': 'str', 'json_schema': 'TripletExtractionResponse.schema', - 'predicate_instructions': 'PREDICATE_DEFINITIONS' + 'predicate_instructions': 'PREDICATE_DEFINITIONS', + 'language': language }) return rendered_prompt @@ -213,6 +216,7 @@ async def render_memory_summary_prompt( chunk_texts: str, json_schema: dict, max_words: int = 200, + language: str = "zh", ) -> str: """ Renders the memory summary prompt using the memory_summary.jinja2 template. @@ -221,6 +225,7 @@ async def render_memory_summary_prompt( chunk_texts: Concatenated text of conversation chunks json_schema: JSON schema for the expected output format max_words: Maximum words for the summary + language: The language to use for summary generation ("zh" for Chinese, "en" for English) Returns: Rendered prompt content as string. 
async def render_episodic_title_and_type_prompt(content: str, language: str = "zh") -> str:
    """
    Build the prompt that asks for an episodic-memory title and category.

    The prompt is rendered from the episodic_type_classification.jinja2
    template, and both the rendered text and a short summary of the template
    inputs are written to the prompt logs.

    Args:
        content: The episodic memory summary text to analyze
        language: Language passed to the template for title generation
            ("zh" for Chinese, "en" for English)

    Returns:
        Rendered prompt content as string
    """
    template_name = "episodic_type_classification.jinja2"
    prompt_text = prompt_env.get_template(template_name).render(
        content=content,
        language=language,
    )

    # Log the rendered prompt itself
    log_prompt_rendering('episodic title and type classification', prompt_text)

    # Log a compact description of what was fed into the template
    render_info = {
        'content_len': len(content) if content else 0,
        'language': language,
    }
    log_template_rendering(template_name, render_info)

    return prompt_text


async def render_ontology_extraction_prompt(
    scenario: str,
    domain: str | None = None,
    max_classes: int = 15,
    json_schema: dict | None = None
) -> str:
    """
    Build the ontology extraction prompt from the extract_ontology.jinja2 template.

    The rendered text and a short summary of the template inputs are written
    to the prompt logs before the prompt is returned.

    Args:
        scenario: The scenario description text to extract ontology classes from
        domain: Optional domain hint for the scenario (e.g., "Healthcare", "Education")
        max_classes: Maximum number of classes to extract (default: 15)
        json_schema: JSON schema for the expected output format

    Returns:
        Rendered prompt content as string
    """
    template_name = "extract_ontology.jinja2"
    template_inputs = {
        'scenario': scenario,
        'domain': domain,
        'max_classes': max_classes,
        'json_schema': json_schema,
    }
    prompt_text = prompt_env.get_template(template_name).render(**template_inputs)

    # Log the rendered prompt itself
    log_prompt_rendering('ontology extraction', prompt_text)

    # Log a compact description of the inputs (the schema is logged by label only)
    render_info = {
        'scenario_len': len(scenario) if scenario else 0,
        'domain': domain,
        'max_classes': max_classes,
        'json_schema': 'OntologyExtractionResponse.schema',
    }
    log_template_rendering(template_name, render_info)

    return prompt_text
+{% if language == "zh" %} +**重要:请使用中文生成标题和分类。** +{% else %} +**Important: Please generate the title and classification in English.** +{% endif %} + === Requirements === - Extract a clear, concise title (10-20 characters) that captures the core content +{% if language == "zh" %} +- 标题必须使用中文 +{% else %} +- Title must be in English +{% endif %} - Classify into exactly one category based on the primary theme - Be specific and avoid ambiguity - Output must be valid JSON conforming to the schema below diff --git a/api/app/core/memory/utils/prompt/prompts/extract_ontology.jinja2 b/api/app/core/memory/utils/prompt/prompts/extract_ontology.jinja2 new file mode 100644 index 00000000..80594ad9 --- /dev/null +++ b/api/app/core/memory/utils/prompt/prompts/extract_ontology.jinja2 @@ -0,0 +1,210 @@ +===Task=== +Extract ontology classes from the given scenario description following ontology engineering standards. + +===Role=== +You are a professional ontology engineer with expertise in knowledge representation and OWL (Web Ontology Language) standards. Your task is to identify abstract classes and concepts from scenario descriptions, not concrete instances. + +===Scenario Description=== +{{ scenario }} + +{% if domain -%} +===Domain Hint=== +This scenario belongs to the **{{ domain }}** domain. Consider domain-specific concepts and terminology when extracting classes. +{%- endif %} + +===Extraction Rules=== + +**1. Abstract Classes, Not Instances:** +- Extract abstract categories and concepts (e.g., "MedicalProcedure", "Patient", "Diagnosis") +- Do NOT extract concrete instances (e.g., "John Smith", "Room 301", "2024-01-15") +- Think in terms of "types of things" rather than "specific things" + +**2. 
Naming Convention (PascalCase):** +- Use PascalCase format for the "name" field: start with uppercase letter, capitalize each word, no spaces +- Examples: "MedicalProcedure", "HealthcareProvider", "DiagnosticTest" +- Avoid: "medical procedure", "healthcare_provider", "diagnostic-test" +- Use clear, descriptive names in English +- Avoid abbreviations unless they are standard in the domain (e.g., "API", "DNA") +- Provide Chinese translation in the "name_chinese" field (e.g., "医疗程序", "医疗服务提供者", "诊断测试") + +**3. Domain Relevance:** +- Focus on classes that are central to the scenario's domain +- Prioritize classes that represent key concepts, entities, or relationships +- Avoid overly generic classes (e.g., "Thing", "Object") unless they have specific domain meaning + +**4. Class Quantity:** +- Extract between 5 and {{ max_classes }} classes +- Aim for a balanced set covering the main concepts in the scenario +- Quality over quantity: prefer well-defined classes over exhaustive lists + +**5. Clear Descriptions:** +- Provide concise, informative descriptions in Chinese (max 500 characters) +- Describe what the class represents, not specific instances +- Use clear, natural Chinese language that explains the class's role in the domain + +**6. Concrete Examples:** +- Provide 2-5 concrete instance examples in Chinese for each class +- Examples should be specific, realistic instances of the class +- Examples help clarify the class's scope and meaning +- Use natural Chinese language for examples +- Example format: ["示例1", "示例2", "示例3"] + +**7. Class Hierarchy:** +- Identify parent-child relationships where applicable +- Use the parent_class field to specify inheritance +- Parent class must be one of the extracted classes or a standard OWL class +- Leave parent_class as null for top-level classes + +**8. 
Entity Types:** +- Classify each class with an appropriate entity_type +- Common types: "Person", "Organization", "Location", "Event", "Concept", "Process", "Object", "Role" +- Choose the most specific type that applies + +**9. OWL Reserved Words:** +- Do NOT use OWL reserved words as class names +- Reserved words include: "Thing", "Nothing", "Class", "Property", "ObjectProperty", "DatatypeProperty", "AnnotationProperty", "Ontology", "Individual", "Literal" +- If a reserved word is needed, add a domain-specific prefix (e.g., "MedicalClass" instead of "Class") + +**10. Language Consistency:** +- Extract all class names in English (PascalCase format) for the "name" field +- Provide Chinese translation for class names in the "name_chinese" field +- Descriptions MUST be in Chinese (中文) +- Examples MUST be in Chinese (中文) +- Use clear, natural Chinese language for descriptions and examples + +===Examples=== + +**Example 1 (Healthcare Domain):** +Scenario: "A hospital manages patient records, schedules appointments, and coordinates medical procedures. Doctors diagnose conditions and prescribe treatments." 
+ +Output: +{ + "classes": [ + { + "name": "Patient", + "name_chinese": "患者", + "description": "在医疗机构接受医疗护理或治疗的人", + "examples": ["张三", "李四", "患有糖尿病的老年患者"], + "parent_class": null, + "entity_type": "Person", + "domain": "Healthcare" + }, + { + "name": "MedicalProcedure", + "name_chinese": "医疗程序", + "description": "为医疗诊断或治疗而执行的系统性操作流程", + "examples": ["手术", "血液检查", "X光检查", "疫苗接种"], + "parent_class": null, + "entity_type": "Process", + "domain": "Healthcare" + }, + { + "name": "Diagnosis", + "name_chinese": "诊断", + "description": "基于症状和检查结果对疾病或状况的识别", + "examples": ["糖尿病诊断", "癌症诊断", "流感诊断"], + "parent_class": null, + "entity_type": "Concept", + "domain": "Healthcare" + }, + { + "name": "Doctor", + "name_chinese": "医生", + "description": "诊断和治疗患者的持证医疗专业人员", + "examples": ["全科医生", "外科医生", "心脏病专家"], + "parent_class": null, + "entity_type": "Role", + "domain": "Healthcare" + }, + { + "name": "Treatment", + "name_chinese": "治疗", + "description": "为治愈或管理疾病状况而提供的医疗护理或疗法", + "examples": ["药物治疗", "物理治疗", "化疗", "手术治疗"], + "parent_class": null, + "entity_type": "Process", + "domain": "Healthcare" + } + ], + "domain": "Healthcare", + "namespace": "http://example.org/healthcare#" +} + +**Example 2 (Education Domain):** +Scenario: "A university offers courses taught by professors. Students enroll in programs, attend lectures, and complete assignments to earn degrees." 
+ +Output: +{ + "classes": [ + { + "name": "Student", + "name_chinese": "学生", + "description": "在教育机构注册学习的人", + "examples": ["本科生", "研究生", "在职学生"], + "parent_class": null, + "entity_type": "Role", + "domain": "Education" + }, + { + "name": "Course", + "name_chinese": "课程", + "description": "涵盖特定学科或主题的结构化教育课程", + "examples": ["计算机科学导论", "微积分I", "世界历史"], + "parent_class": null, + "entity_type": "Concept", + "domain": "Education" + }, + { + "name": "Professor", + "name_chinese": "教授", + "description": "教授课程并进行研究的学术教师", + "examples": ["助理教授", "副教授", "正教授"], + "parent_class": null, + "entity_type": "Role", + "domain": "Education" + }, + { + "name": "AcademicProgram", + "name_chinese": "学术项目", + "description": "通向学位或证书的结构化课程体系", + "examples": ["理学学士", "文学硕士", "博士项目"], + "parent_class": null, + "entity_type": "Concept", + "domain": "Education" + }, + { + "name": "Assignment", + "name_chinese": "作业", + "description": "分配给学生以评估学习成果的任务或项目", + "examples": ["论文", "习题集", "研究报告", "实验报告"], + "parent_class": null, + "entity_type": "Object", + "domain": "Education" + }, + { + "name": "Lecture", + "name_chinese": "讲座", + "description": "由教师进行的教育性演讲或讲座", + "examples": ["入门讲座", "客座讲座", "在线讲座"], + "parent_class": null, + "entity_type": "Event", + "domain": "Education" + } + ], + "domain": "Education", + "namespace": "http://example.org/education#" +} + +===Output Format=== + +**JSON Requirements:** +- Use only ASCII double quotes (") for JSON structure +- Never use Chinese quotation marks ("") or Unicode quotes +- Escape quotation marks in text with backslashes (\") +- Ensure proper string closure and comma separation +- No line breaks within JSON string values +- All class names must be in PascalCase format +- All class names must be unique (case-insensitive) +- Extract between 5 and {{ max_classes }} classes + +{{ json_schema }} diff --git a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 index 
03691a04..67df162a 100644 --- a/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/extract_triplet.jinja2 @@ -5,6 +5,12 @@ ===Task=== Extract entities and knowledge triplets from the given statement. +{% if language == "zh" %} +**重要:请使用中文生成实体描述(description)和示例(example)。** +{% else %} +**Important: Please generate entity descriptions and examples in English.** +{% endif %} + ===Inputs=== **Chunk Content:** "{{ chunk_content }}" **Statement:** "{{ statement }}" @@ -13,6 +19,13 @@ Extract entities and knowledge triplets from the given statement. **Entity Extraction:** - Extract entities with their types, context-independent descriptions, **concise examples**, aliases, and semantic memory classification +{% if language == "zh" %} +- **实体描述(description)必须使用中文** +- **示例(example)必须使用中文** +{% else %} +- **Entity descriptions must be in English** +- **Examples must be in English** +{% endif %} - **Semantic Memory Classification (is_explicit_memory):** * Set to `true` if the entity represents **explicit/semantic memory**: - **Concepts:** "Machine Learning", "Photosynthesis", "Democracy", "人工智能", "光合作用", "民主" @@ -334,9 +347,11 @@ Output: - Escape quotation marks in text with backslashes (\") - Ensure proper string closure and comma separation - No line breaks within JSON string values -- The output language should ALWAYS match the input language -- If input is in English, extract statements in English -- If input is in Chinese, extract statements in Chinese +{% if language == "zh" %} +- **语言要求:实体描述(description)和示例(example)必须使用中文** +{% else %} +- **Language Requirement: Entity descriptions and examples must be in English** +{% endif %} - Preserve the original language and do not translate {{ json_schema }} \ No newline at end of file diff --git a/api/app/core/memory/utils/prompt/prompts/memory_summary.jinja2 b/api/app/core/memory/utils/prompt/prompts/memory_summary.jinja2 index 1dd86ca3..82f91cc4 100644 --- 
a/api/app/core/memory/utils/prompt/prompts/memory_summary.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/memory_summary.jinja2 @@ -5,10 +5,21 @@ === Task === Summarize the provided conversation chunks into a concise Memory summary. +{% if language == "zh" %} +**重要:请使用中文生成摘要内容。** +{% else %} +**Important: Please generate the summary content in English.** +{% endif %} + === Requirements === - Focus on factual statements, user preferences, relationships, and salient temporal context. - Avoid repetition and filler; be specific. - Keep it under {{ max_words or 200 }} words. +{% if language == "zh" %} +- 摘要内容必须使用中文 +{% else %} +- Summary content must be in English +{% endif %} - Output must be valid JSON conforming to the schema below. === Input === @@ -24,6 +35,11 @@ Summarize the provided conversation chunks into a concise Memory summary. 4. Do not include line breaks within JSON string values 5. Example of proper escaping: "statement": "张曼婷说:\"我很喜欢这本书。\"" -The output language should always be the same as the input language. +{% if language == "zh" %} +**语言要求:输出内容必须使用中文。** +{% else %} +**Language Requirement: The output content must be in English.** +{% endif %} + Return only a list of extracted labelled statements in the JSON ARRAY of objects that match the schema below: {{ json_schema }} \ No newline at end of file diff --git a/api/app/core/memory/utils/validation/__init__.py b/api/app/core/memory/utils/validation/__init__.py new file mode 100644 index 00000000..d5dd41e7 --- /dev/null +++ b/api/app/core/memory/utils/validation/__init__.py @@ -0,0 +1,10 @@ +"""Validation utilities for ontology extraction. + +This module provides validation classes for ontology class names, +descriptions, and OWL compliance checking. 
logger = logging.getLogger(__name__)


class OntologyValidator:
    """Validator for ontology class names and descriptions.

    This validator performs string-level validation including:
    - PascalCase naming convention validation
    - OWL reserved word checking
    - Duplicate class name removal
    - Description length truncation

    Attributes:
        OWL_RESERVED_WORDS: Set of OWL reserved words that cannot be used as class names
    """

    # OWL reserved words that cannot be used as class names
    OWL_RESERVED_WORDS = {
        'Thing', 'Nothing', 'Class', 'Property',
        'ObjectProperty', 'DatatypeProperty', 'FunctionalProperty',
        'InverseFunctionalProperty', 'TransitiveProperty', 'SymmetricProperty',
        'AsymmetricProperty', 'ReflexiveProperty', 'IrreflexiveProperty',
        'Restriction', 'Ontology', 'Individual', 'NamedIndividual',
        'Annotation', 'AnnotationProperty', 'Axiom',
        'AllDifferent', 'AllDisjointClasses', 'AllDisjointProperties',
        'Datatype', 'DataRange', 'Literal',
        'DeprecatedClass', 'DeprecatedProperty',
        'Imports', 'IncompatibleWith', 'PriorVersion', 'VersionInfo',
        'BackwardCompatibleWith', 'OntologyProperty',
    }

    # A valid name consists solely of ASCII letters, digits and underscores.
    _VALID_NAME_PATTERN = re.compile(r'[A-Za-z0-9_]+')
    # Whitespace, hyphens and underscores separate words when sanitizing.
    _WORD_SEPARATOR_PATTERN = re.compile(r'[\s\-_]+')
    # Underscores never survive the word split, so only letters/digits are kept.
    _NON_ALNUM_PATTERN = re.compile(r'[^A-Za-z0-9]')
    # Marker appended to truncated descriptions.
    _ELLIPSIS = "..."

    @staticmethod
    def _reject(error_msg: str) -> Tuple[bool, str]:
        """Log a validation failure and return the matching ``(False, message)`` pair."""
        logger.warning("Validation failed: %s", error_msg)
        return False, error_msg

    def validate_class_name(self, name: str) -> Tuple[bool, str]:
        """Validate that a class name follows OWL naming conventions.

        Checks are applied to the stripped name in this order, and the first
        failing check determines the returned message:
        1. Must not be empty
        2. Cannot be an OWL reserved word
        3. Must start with an uppercase letter (PascalCase)
        4. Cannot contain spaces
        5. Can only contain alphanumeric characters and underscores

        Args:
            name: The class name to validate

        Returns:
            Tuple of (is_valid, error_message)
            - is_valid: True if the name is valid, False otherwise
            - error_message: Empty string if valid, error description if invalid

        Examples:
            >>> validator = OntologyValidator()
            >>> validator.validate_class_name("MedicalProcedure")
            (True, '')
            >>> validator.validate_class_name("Medical Procedure")
            (False, "Class name 'Medical Procedure' cannot contain spaces")
            >>> validator.validate_class_name("medical procedure")[0]
            False
            >>> validator.validate_class_name("Thing")
            (False, "Class name 'Thing' is an OWL reserved word")
        """
        logger.debug("Validating class name: '%s'", name)

        if not name or not name.strip():
            return self._reject("Class name cannot be empty")

        name = name.strip()

        if name in self.OWL_RESERVED_WORDS:
            return self._reject(f"Class name '{name}' is an OWL reserved word")

        if not name[0].isupper():
            return self._reject(
                f"Class name '{name}' must start with an uppercase letter (PascalCase)"
            )

        # Checked before the generic character test so that the most common
        # mistake gets a specific message.
        if ' ' in name:
            return self._reject(f"Class name '{name}' cannot contain spaces")

        if not self._VALID_NAME_PATTERN.fullmatch(name):
            return self._reject(
                f"Class name '{name}' contains invalid characters. "
                f"Only alphanumeric characters and underscores are allowed"
            )

        logger.debug("Class name '%s' is valid", name)
        return True, ""

    def sanitize_class_name(self, name: str) -> str:
        """Attempt to sanitize an invalid class name into a valid format.

        Sanitization steps:
        1. Strip surrounding whitespace
        2. Split into words on whitespace, hyphens and underscores
        3. Drop every character that is not an ASCII letter or digit
        4. Uppercase the first character of each word (the rest is kept as is)
        5. If the result is empty or starts with a digit, prefix it with 'Class'
        6. If the result is an OWL reserved word, append 'Class'

        Args:
            name: The class name to sanitize

        Returns:
            Sanitized class name that should pass validation;
            'UnnamedClass' when the input is empty or whitespace only

        Examples:
            >>> validator = OntologyValidator()
            >>> validator.sanitize_class_name("medical procedure")
            'MedicalProcedure'
            >>> validator.sanitize_class_name("patient-record")
            'PatientRecord'
            >>> validator.sanitize_class_name("123invalid")
            'Class123invalid'
        """
        logger.debug("Sanitizing class name: '%s'", name)

        if not name or not name.strip():
            logger.warning("Empty class name provided for sanitization, returning 'UnnamedClass'")
            return "UnnamedClass"

        original_name = name.strip()

        sanitized_words = []
        for raw_word in self._WORD_SEPARATOR_PATTERN.split(original_name):
            clean_word = self._NON_ALNUM_PATTERN.sub('', raw_word)
            if clean_word:
                # Only the first character is changed; inner casing is preserved.
                sanitized_words.append(clean_word[0].upper() + clean_word[1:])

        sanitized = ''.join(sanitized_words)

        # A name must start with an uppercase letter, so empty results and
        # digit-leading results get a neutral prefix.
        if not sanitized or sanitized[0].isdigit():
            sanitized = 'Class' + sanitized
            logger.info(
                "Prefixed class name with 'Class': '%s' -> '%s'", original_name, sanitized
            )

        if sanitized in self.OWL_RESERVED_WORDS:
            sanitized = sanitized + 'Class'
            logger.info(
                "Appended 'Class' suffix to reserved word: '%s' -> '%s'",
                original_name,
                sanitized,
            )

        logger.info("Sanitized class name: '%s' -> '%s'", original_name, sanitized)
        return sanitized

    def remove_duplicates(self, classes: "List[OntologyClass]") -> "List[OntologyClass]":
        """Remove duplicate ontology classes based on case-insensitive name comparison.

        When duplicates are found, keeps the first occurrence and discards subsequent ones.
        Comparison is case-insensitive to catch variations like 'Patient' and 'patient'.
        Only the ``name`` attribute of each item is read.

        Args:
            classes: List of OntologyClass objects

        Returns:
            List of OntologyClass objects with duplicates removed, in original
            order. An empty (or None) input is returned unchanged.

        Examples:
            >>> validator = OntologyValidator()
            >>> classes = [
            ...     OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare"),
            ...     OntologyClass(name="patient", description="Another patient", entity_type="Person", domain="Healthcare"),
            ...     OntologyClass(name="Doctor", description="A doctor", entity_type="Person", domain="Healthcare"),
            ... ]
            >>> unique = validator.remove_duplicates(classes)
            >>> len(unique)
            2
            >>> [c.name for c in unique]
            ['Patient', 'Doctor']
        """
        if not classes:
            logger.debug("No classes to check for duplicates")
            return classes

        logger.debug("Checking %d classes for duplicates", len(classes))

        seen_names = set()
        unique_classes = []
        duplicates_found = []

        for ontology_class in classes:
            # Use lowercase for comparison
            name_lower = ontology_class.name.lower()

            if name_lower in seen_names:
                duplicates_found.append(ontology_class.name)
                logger.debug(
                    "Duplicate class found and removed: '%s'", ontology_class.name
                )
                continue

            seen_names.add(name_lower)
            unique_classes.append(ontology_class)

        if duplicates_found:
            logger.info(
                "Removed %d duplicate classes: %s", len(duplicates_found), duplicates_found
            )
        else:
            logger.debug("No duplicate classes found")

        return unique_classes

    def truncate_description(self, description: str, max_length: int = 500) -> str:
        """Truncate a description to a maximum length.

        If the description exceeds max_length, it is cut so that the result,
        including a trailing ellipsis (...), is exactly max_length characters.
        When max_length is too small to hold the ellipsis (less than 3), the
        text is hard-cut to max_length characters without an ellipsis, so the
        result never exceeds the limit. A negative max_length is treated as 0.

        Args:
            description: The description text to truncate
            max_length: Maximum allowed length (default: 500)

        Returns:
            Truncated description string ("" for an empty or None description)

        Examples:
            >>> validator = OntologyValidator()
            >>> long_desc = "A" * 600
            >>> truncated = validator.truncate_description(long_desc, max_length=500)
            >>> len(truncated)
            500
            >>> truncated.endswith("...")
            True
            >>> validator.truncate_description("abcdef", max_length=2)
            'ab'
        """
        if not description:
            return ""

        if len(description) <= max_length:
            return description

        if max_length < len(self._ELLIPSIS):
            # No room for the ellipsis: a negative slice index would otherwise
            # keep almost the whole text and overshoot the limit.
            truncated = description[:max(max_length, 0)]
        else:
            truncated = description[:max_length - len(self._ELLIPSIS)] + self._ELLIPSIS

        logger.debug(
            "Truncated description from %d to %d characters",
            len(description),
            len(truncated),
        )

        return truncated
    def validate_ontology_classes(
        self,
        classes: List[OntologyClass],
    ) -> Tuple[bool, List[str], Optional[World]]:
        """Validate extracted ontology classes against OWL standards.

        This method creates an OWL ontology from the provided classes using Owlready2,
        runs consistency checking with the Pellet reasoner, and detects common issues
        like circular inheritance.

        Every problem found while building the ontology (class creation failure,
        failure to set a parent, a parent name that is not among the provided
        classes, circular inheritance) is appended to the returned message list
        and makes the result invalid. A reasoner failure other than an
        inconsistent ontology is only logged and does not affect the result.

        Args:
            classes: List of OntologyClass objects to validate

        Returns:
            Tuple of (is_valid, error_messages, world):
            - is_valid: True only if no error messages were collected and the
              reasoner did not report an inconsistent ontology
            - error_messages: List of error/warning messages
            - world: Owlready2 World object containing the ontology. It is
              returned even when is_valid is False; it is None only when
              ``classes`` is empty or an unexpected exception aborted validation

        Examples:
            >>> validator = OWLValidator()
            >>> classes = [
            ...     OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare"),
            ...     OntologyClass(name="Doctor", description="A doctor", entity_type="Person", domain="Healthcare"),
            ... ]
            >>> is_valid, errors, world = validator.validate_ontology_classes(classes)
            >>> is_valid
            True
            >>> len(errors)
            0
        """
        if not classes:
            return False, ["No classes provided for validation"], None

        errors = []

        try:
            # Create a new world (isolated ontology environment)
            world = World()

            # Use a proper ontology IRI
            # Owlready2 expects the IRI to end with .owl or similar
            onto_iri = self.base_namespace.rstrip('#/')
            if not onto_iri.endswith('.owl'):
                onto_iri = onto_iri + '.owl'

            # Create ontology
            onto = world.get_ontology(onto_iri)

            with onto:
                # Dictionary to store created OWL classes for parent reference
                owl_classes = {}

                # First pass: Create all classes without parent relationships
                for ontology_class in classes:
                    try:
                        # Create OWL class dynamically using type() with Thing as base
                        # The key is to NOT set namespace in the dict, let Owlready2 handle it
                        owl_class = type(
                            ontology_class.name,  # Class name
                            (Thing,),  # Base classes
                            {}  # Class dict (empty, let Owlready2 manage)
                        )

                        # Add label (rdfs:label) - include both English and Chinese names
                        labels = [ontology_class.name]
                        if ontology_class.name_chinese:
                            labels.append(ontology_class.name_chinese)
                        owl_class.label = labels

                        # Add comment (rdfs:comment) with description
                        if ontology_class.description:
                            owl_class.comment = [ontology_class.description]

                        # Store for parent relationship setup
                        owl_classes[ontology_class.name] = owl_class

                        logger.debug(
                            f"Created OWL class: {ontology_class.name} "
                            f"(Chinese: {ontology_class.name_chinese}) "
                            f"IRI: {owl_class.iri if hasattr(owl_class, 'iri') else 'N/A'}"
                        )

                    except Exception as e:
                        # One bad class must not abort the rest: record and continue
                        error_msg = f"Failed to create OWL class '{ontology_class.name}': {str(e)}"
                        errors.append(error_msg)
                        logger.error(error_msg, exc_info=True)

                # Second pass: Set up parent relationships
                # (classes that failed to be created in the first pass are skipped)
                for ontology_class in classes:
                    if ontology_class.parent_class and ontology_class.name in owl_classes:
                        parent_name = ontology_class.parent_class

                        # Check if parent exists
                        if parent_name in owl_classes:
                            try:
                                child_class = owl_classes[ontology_class.name]
                                parent_class = owl_classes[parent_name]

                                # Set parent by modifying is_a
                                # (replaces the initial [Thing] parent list)
                                child_class.is_a = [parent_class]

                                logger.debug(
                                    f"Set parent relationship: {ontology_class.name} -> {parent_name}"
                                )

                            except Exception as e:
                                error_msg = (
                                    f"Failed to set parent relationship "
                                    f"'{ontology_class.name}' -> '{parent_name}': {str(e)}"
                                )
                                errors.append(error_msg)
                                logger.warning(error_msg)
                        else:
                            # Logged as a warning, but it is still added to
                            # `errors` and therefore makes the result invalid
                            warning_msg = (
                                f"Parent class '{parent_name}' not found for '{ontology_class.name}'"
                            )
                            errors.append(warning_msg)
                            logger.warning(warning_msg)

                # Check for circular inheritance
                for class_name, owl_class in owl_classes.items():
                    if self._has_circular_inheritance(owl_class):
                        error_msg = f"Circular inheritance detected for class '{class_name}'"
                        errors.append(error_msg)
                        logger.error(error_msg)

            # Run consistency checking with Pellet reasoner
            try:
                logger.info("Running Pellet reasoner for consistency checking...")
                sync_reasoner_pellet(world, infer_property_values=True, infer_data_property_values=True)
                logger.info("Consistency check passed")

            except OwlReadyInconsistentOntologyError as e:
                error_msg = f"Ontology is inconsistent: {str(e)}"
                errors.append(error_msg)
                logger.error(error_msg)
                # The world is still handed back so the caller can inspect it
                return False, errors, world

            except Exception as e:
                # Reasoner errors are often due to Java not being installed or configured
                # Log as warning but don't fail validation - ontology structure is still valid
                warning_msg = f"Reasoner check skipped: {str(e)}"
                if str(e).strip():  # Only log if there's an actual error message
                    logger.warning(warning_msg)
                else:
                    logger.warning("Reasoner check skipped: Java may not be installed or configured")
                # Continue - ontology structure is valid even without reasoner check

            # Any collected message (including the "parent not found" warnings)
            # makes validation fail
            is_valid = len(errors) == 0

            return is_valid, errors, world

        except Exception as e:
            # Last-resort handler: anything unexpected above ends validation
            # without a usable world
            error_msg = f"OWL validation failed: {str(e)}"
            errors.append(error_msg)
            logger.error(error_msg, exc_info=True)
            return False, errors, None

    def _has_circular_inheritance(self, owl_class) -> bool:
        """Check if an OWL class has circular inheritance.

        Circular inheritance occurs when a class inherits from itself through
        a chain of parent relationships (e.g., A -> B -> C -> A).

        The chain is walked upwards through ``is_a``. At each step only the
        first parent that is not ``Thing`` and itself has an ``is_a``
        attribute is followed, so with several parents the remaining
        branches are not inspected.

        Args:
            owl_class: Owlready2 class object to check

        Returns:
            True if circular inheritance is detected, False otherwise
        """
        # Identifiers of the classes already seen on the way up
        visited = set()
        current = owl_class

        while current:
            # Get class IRI or name as identifier
            class_id = str(current.iri) if hasattr(current, 'iri') else str(current)

            if class_id in visited:
                # Found a cycle
                return True

            visited.add(class_id)

            # Get parent classes (is_a relationship)
            parents = getattr(current, 'is_a', [])

            # Filter out Thing and other base classes
            parent_classes = [p for p in parents if p != Thing and hasattr(p, 'is_a')]

            if not parent_classes:
                # No more parents, no cycle
                break

            # Check first parent (in single inheritance)
            current = parent_classes[0] if parent_classes else None

        return False
+ + Supported formats: + - rdfxml: RDF/XML format (default, most compatible) + - turtle: Turtle format (more readable) + - ntriples: N-Triples format (simplest) + - json: JSON format (simplified, human-readable) + + Args: + world: Owlready2 World object containing the ontology + output_path: Optional file path to save the ontology (if None, returns string) + format: Export format - "rdfxml", "turtle", "ntriples", or "json" (default: "rdfxml") + classes: Optional list of OntologyClass objects (required for json format) + + Returns: + String representation of the exported ontology + + Raises: + ValueError: If format is not supported + RuntimeError: If export fails + + Examples: + >>> validator = OWLValidator() + >>> is_valid, errors, world = validator.validate_ontology_classes(classes) + >>> owl_content = validator.export_to_owl(world, "ontology.owl", format="rdfxml") + """ + # Validate format + valid_formats = ["rdfxml", "turtle", "ntriples", "json"] + if format not in valid_formats: + raise ValueError( + f"Unsupported format '{format}'. Must be one of: {', '.join(valid_formats)}" + ) + + # JSON format doesn't need OWL processing + if format == "json": + if not classes: + raise ValueError("Classes list is required for JSON format export") + return self._export_to_json(classes) + + # For OWL formats, world is required + if not world: + raise ValueError("World object is None. 
Cannot export ontology.") + + # Note: Owlready2 has issues with turtle format export + # We'll handle it specially by converting from rdfxml + use_conversion = (format == "turtle") + + try: + # Get all ontologies in the world + ontologies = list(world.ontologies.values()) + + if not ontologies: + raise RuntimeError("No ontologies found in world") + + # Find the ontology with classes (skip anonymous/empty ontologies) + onto = None + for ont in ontologies: + classes_count = len(list(ont.classes())) + logger.debug(f"Checking ontology {ont.base_iri}: {classes_count} classes") + if classes_count > 0: + onto = ont + break + + # If no ontology with classes found, use the last non-anonymous one + if onto is None: + for ont in reversed(ontologies): + if ont.base_iri != "http://anonymous/": + onto = ont + break + + # If still no ontology, use the first one + if onto is None: + onto = ontologies[0] + + # Log ontology contents for debugging + logger.info(f"Ontology IRI: {onto.base_iri}") + logger.info(f"Ontology contains {len(list(onto.classes()))} classes") + + # List all classes in the ontology + all_classes = list(onto.classes()) + for cls in all_classes: + logger.info(f"Class in ontology: {cls.name} (IRI: {cls.iri})") + if hasattr(cls, 'label'): + logger.debug(f" Labels: {cls.label}") + if hasattr(cls, 'comment'): + logger.debug(f" Comments: {cls.comment}") + + if len(all_classes) == 0: + logger.warning("No classes found in ontology! 
def _export_to_json(self, classes: List) -> str:
    """Serialize ontology classes into a compact JSON document.

    The document has a single ``ontology`` object holding the validator's
    ``base_namespace`` and one record per class.

    Args:
        classes: Objects exposing ``name``, ``name_chinese``,
            ``description``, ``entity_type``, ``domain`` and
            ``parent_class``; ``examples`` is optional.

    Returns:
        JSON string without indentation or spaces after separators;
        non-ASCII characters are emitted unescaped.
    """
    import json

    def _as_record(item):
        # A class without an ``examples`` attribute is exported with an empty list.
        return {
            "name": item.name,
            "name_chinese": item.name_chinese,
            "description": item.description,
            "entity_type": item.entity_type,
            "domain": item.domain,
            "parent_class": item.parent_class,
            "examples": getattr(item, 'examples', []),
        }

    payload = {
        "ontology": {
            "namespace": self.base_namespace,
            "classes": [_as_record(item) for item in classes],
        }
    }

    # Compact form: no indentation, minimal separators.
    return json.dumps(payload, ensure_ascii=False, separators=(',', ':'))
def validate_with_protege_compatibility(
    self,
    classes: List[OntologyClass]
) -> Tuple[bool, List[str]]:
    """Collect warnings about features that may not work well in Protégé.

    Checks performed:
    - the base namespace starts with http:// or https:// and ends with # or /
    - class names contain none of a fixed set of special characters
    - descriptions are at most 1000 characters long
    - class names are pure ASCII

    Args:
        classes: Objects exposing ``name`` and ``description``.

    Returns:
        Tuple ``(is_compatible, warnings)`` where ``is_compatible`` is True
        only when no warning was produced.

    Examples:
        >>> validator = OWLValidator()
        >>> classes = [OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare")]
        >>> is_compatible, warnings = validator.validate_with_protege_compatibility(classes)
        >>> is_compatible
        True
    """
    namespace = self.base_namespace
    warnings = []

    # Namespace-level checks.
    if not namespace.startswith(('http://', 'https://')):
        warnings.append(
            f"Namespace '{namespace}' should start with http:// or https:// "
            "for Protégé compatibility"
        )
    if not namespace.endswith(('#', '/')):
        warnings.append(
            f"Namespace '{namespace}' should end with # or / "
            "for Protégé compatibility"
        )

    problem_chars = ('<', '>', '"', '{', '}', '|', '^', '`')

    # Per-class checks, in the order: special characters, length, ASCII.
    for item in classes:
        name = item.name
        description = item.description

        if any(ch in name for ch in problem_chars):
            warnings.append(
                f"Class name '{name}' contains special characters "
                "that may cause issues in Protégé"
            )

        if description and len(description) > 1000:
            warnings.append(
                f"Class '{name}' has a very long description ({len(description)} chars) "
                "which may display poorly in Protégé"
            )

        if not name.isascii():
            warnings.append(
                f"Class name '{name}' contains non-ASCII characters "
                "which may cause encoding issues in some Protégé versions"
            )

    return not warnings, warnings
class OntologyClass(Base):
    """Ontology class table: stores the ontology types extracted for a scene."""
    __tablename__ = "ontology_class"

    # Primary key; a new UUID4 is generated when none is supplied.
    class_id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="类型ID")

    # Descriptive fields.
    class_name = Column(String(200), nullable=False, comment="类型名称")
    class_description = Column(Text, nullable=True, comment="类型描述")

    # Foreign key to the owning scene; rows are removed with the scene (ON DELETE CASCADE).
    scene_id = Column(UUID(as_uuid=True), ForeignKey("ontology_scene.scene_id", ondelete="CASCADE"), nullable=False, index=True, comment="所属场景ID")

    # Timestamps, both filled from datetime.datetime.now; updated_at is refreshed on update.
    created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, comment="创建时间")
    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, nullable=False, comment="更新时间")

    # Relationship: each class belongs to one scene (mirrors OntologyScene.classes).
    scene = relationship("OntologyScene", back_populates="classes")

    def __repr__(self):
        """Return a debugging representation that identifies the row."""
        # The previous implementation returned an empty string, which made
        # instances indistinguishable in logs and debugger output.
        return (
            f"<OntologyClass(class_id={self.class_id}, "
            f"class_name={self.class_name!r}, scene_id={self.scene_id})>"
        )
class PromptHistory(Base):
    """ORM model mapped to the ``prompt_history`` table.

    Each row stores a titled prompt text linked to a tenant and to a row of
    ``prompt_opt_session_list``.
    """
    __tablename__ = "prompt_history"

    # Primary key; a new UUID4 is generated when none is supplied.
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
    # Owning tenant (foreign key to tenants.id); required.
    tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, comment="Tenant ID")

    # Session this prompt belongs to (foreign key to prompt_opt_session_list.id); required.
    session_id = Column(
        UUID(as_uuid=True),
        ForeignKey("prompt_opt_session_list.id"),
        nullable=False,
        comment="Session ID"
    )
    # Title and full prompt text; both required.
    title = Column(String, nullable=False, comment="Title")
    prompt = Column(Text, nullable=False, comment="Prompt")
    # Creation timestamp from datetime.datetime.now; indexed.
    created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation Time", index=True)
    # Flag defaulting to False; presumably a soft-delete marker — confirm against the code that reads it.
    is_delete = Column(Boolean, default=False, comment="Delete")
llm_id=params.llm_id, embedding_id=params.embedding_id, rerank_id=params.rerank_id, @@ -289,7 +292,6 @@ class MemoryConfigRepository: db_logger.error(f"更新记忆配置失败: config_id={update.config_id} - {str(e)}") raise - @staticmethod def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[MemoryConfig]: """更新记忆萃取引擎配置 @@ -412,7 +414,7 @@ class MemoryConfigRepository: raise @staticmethod - def get_extracted_config(db: Session, config_id: UUID |int) -> Optional[Dict]: + def get_extracted_config(db: Session, config_id: UUID | int) -> Optional[Dict]: """获取萃取配置,通过主键查询某条配置 Args: @@ -422,7 +424,7 @@ class MemoryConfigRepository: Returns: Optional[Dict]: 萃取配置字典,不存在则返回None """ - config_id=resolve_config_id(config_id,db) + config_id = resolve_config_id(config_id, db) db_logger.debug(f"查询萃取配置: config_id={config_id}") try: db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first() @@ -516,26 +518,28 @@ class MemoryConfigRepository: except Exception as e: db_logger.error(f"根据ID查询记忆配置失败: config_id={config_id} - {str(e)}") raise + @staticmethod - def get_config_with_workspace(db: Session, config_id: uuid.UUID) -> Optional[tuple]: + def get_config_with_workspace(db: Session, config_id: uuid.UUID | int | str) -> Optional[tuple]: """Get memory config and its associated workspace information - + Args: db: Database session config_id: Configuration ID - + Returns: Optional[tuple]: (MemoryConfig, Workspace) tuple, None if not found - + Raises: ValueError: Raised when config exists but workspace doesn't """ import time from app.models.workspace_model import Workspace - + start_time = time.time() - + config_id = resolve_config_id(config_id, db) + # Log configuration loading start config_logger.info( "Loading configuration with workspace", @@ -544,17 +548,17 @@ class MemoryConfigRepository: "config_id": config_id } ) - + db_logger.debug(f"Querying memory config and workspace: config_id={config_id}") - + try: # Use join query to get both config and 
workspace result = db.query(MemoryConfig, Workspace).join( Workspace, MemoryConfig.workspace_id == Workspace.id ).filter(MemoryConfig.config_id == config_id).first() - + elapsed_ms = (time.time() - start_time) * 1000 - + if not result: # Check if config exists but workspace is missing config_only = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first() @@ -583,9 +587,11 @@ class MemoryConfigRepository: "elapsed_ms": elapsed_ms } ) - db_logger.error(f"Memory config {config_id} references non-existent workspace {config_only.workspace_id}") - raise ValueError(f"Workspace {config_only.workspace_id} not found for configuration {config_id}") - + db_logger.error( + f"Memory config {config_id} references non-existent workspace {config_only.workspace_id}") + raise ValueError( + f"Workspace {config_only.workspace_id} not found for configuration {config_id}") + config_logger.debug( "Configuration not found", extra={ @@ -597,9 +603,9 @@ class MemoryConfigRepository: ) db_logger.debug(f"Memory config not found: config_id={config_id}") return None - + config, workspace = result - + # Log successful configuration loading config_logger.info( "Configuration with workspace loaded successfully", @@ -614,16 +620,17 @@ class MemoryConfigRepository: "elapsed_ms": elapsed_ms } ) - - db_logger.debug(f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}") + + db_logger.debug( + f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}") return (config, workspace) - + except ValueError: # Re-raise known business exceptions raise except Exception as e: elapsed_ms = (time.time() - start_time) * 1000 - + config_logger.error( "Failed to load configuration with workspace", extra={ @@ -636,9 +643,10 @@ class MemoryConfigRepository: }, exc_info=True ) - + db_logger.error(f"Failed to query memory config and workspace: config_id={config_id} - {str(e)}") raise + @staticmethod def 
get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[MemoryConfig]: """获取所有配置参数 diff --git a/api/app/repositories/model_repository.py b/api/app/repositories/model_repository.py index 36f7062f..3d66964a 100644 --- a/api/app/repositories/model_repository.py +++ b/api/app/repositories/model_repository.py @@ -630,6 +630,13 @@ class ModelBaseRepository: db.add(model_base) return model_base + @staticmethod + def get_by_name_and_provider(db: Session, name: str, provider: str) -> Optional['ModelBase']: + return db.query(ModelBase).filter( + ModelBase.name == name, + ModelBase.provider == provider + ).first() + @staticmethod def update(db: Session, model_base_id: uuid.UUID, data: dict) -> Optional['ModelBase']: model_base = db.query(ModelBase).filter(ModelBase.id == model_base_id).first() diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index c93e75b3..cf1732fd 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -877,7 +877,8 @@ RETURN CASE WHEN ms:ExtractedEntity THEN { text: ms.name, - created_at: ms.created_at + created_at: ms.created_at, + type: "情景记忆" } END ) AS ExtractedEntity, @@ -887,7 +888,8 @@ RETURN CASE WHEN n:MemorySummary THEN { text: n.content, - created_at: n.created_at + created_at: n.created_at, + type: "长期沉淀" } END ) AS MemorySummary, @@ -895,7 +897,8 @@ RETURN collect( DISTINCT { text: e.statement, - created_at: e.created_at + created_at: e.created_at, + type: "情绪记忆" } ) AS statement; """ diff --git a/api/app/repositories/ontology_class_repository.py b/api/app/repositories/ontology_class_repository.py new file mode 100644 index 00000000..68f261ff --- /dev/null +++ b/api/app/repositories/ontology_class_repository.py @@ -0,0 +1,404 @@ +# -*- coding: utf-8 -*- +"""本体类型Repository层 + +本模块提供本体类型的数据访问层实现。 + +Classes: + OntologyClassRepository: 本体类型数据访问类 +""" + +import logging +from typing import List, Optional +from uuid import 
UUID + +from sqlalchemy.orm import Session, joinedload + +from app.core.logging_config import get_db_logger +from app.models.ontology_class import OntologyClass +from app.models.ontology_scene import OntologyScene + + +logger = get_db_logger() + + +class OntologyClassRepository: + """本体类型Repository + + 提供本体类型的CRUD操作和权限检查。 + + Attributes: + db: SQLAlchemy数据库会话 + """ + + def __init__(self, db: Session): + """初始化Repository + + Args: + db: SQLAlchemy数据库会话 + """ + self.db = db + + def create(self, class_data: dict, scene_id: UUID) -> OntologyClass: + """创建本体类型 + + Args: + class_data: 类型数据字典,包含class_name和class_description + scene_id: 所属场景ID + + Returns: + OntologyClass: 创建的类型对象 + + Raises: + Exception: 数据库操作失败 + + Examples: + >>> repo = OntologyClassRepository(db) + >>> ontology_class = repo.create( + ... {"class_name": "患者", "class_description": "描述"}, + ... scene_id + ... ) + """ + try: + logger.info( + f"Creating ontology class - " + f"name={class_data.get('class_name')}, " + f"scene_id={scene_id}" + ) + + ontology_class = OntologyClass( + class_name=class_data.get("class_name"), + class_description=class_data.get("class_description"), + scene_id=scene_id + ) + + self.db.add(ontology_class) + self.db.flush() # 获取ID但不提交 + + logger.info( + f"Ontology class created successfully - " + f"class_id={ontology_class.class_id}" + ) + + return ontology_class + + except Exception as e: + logger.error( + f"Failed to create ontology class: {str(e)}", + exc_info=True + ) + raise + + def get_by_id(self, class_id: UUID) -> Optional[OntologyClass]: + """根据ID获取类型 + + Args: + class_id: 类型ID + + Returns: + Optional[OntologyClass]: 类型对象,不存在则返回None + + Examples: + >>> repo = OntologyClassRepository(db) + >>> ontology_class = repo.get_by_id(class_id) + """ + try: + logger.debug(f"Getting ontology class by ID: {class_id}") + + ontology_class = self.db.query(OntologyClass).filter( + OntologyClass.class_id == class_id + ).first() + + if ontology_class: + logger.debug(f"Ontology class found: 
def search_by_name(self, keyword: str, scene_id: UUID) -> List[OntologyClass]:
    """Search the classes of a scene whose name contains ``keyword``.

    Matching is case-insensitive (``ILIKE``) and the keyword is treated as
    literal text: the LIKE metacharacters ``%`` and ``_`` (and the escape
    character itself) are escaped before the pattern is built, so user
    input cannot widen the match.

    Args:
        keyword: Text to look for inside ``class_name``.
        scene_id: Scene whose classes are searched.

    Returns:
        List[OntologyClass]: Matching classes, newest first by ``created_at``.

    Raises:
        Exception: Any database error is logged and re-raised.

    Examples:
        >>> repo = OntologyClassRepository(db)
        >>> classes = repo.search_by_name("患者", scene_id)
    """
    try:
        logger.debug(
            f"Searching ontology classes by keyword - "
            f"keyword={keyword}, scene_id={scene_id}"
        )

        # Escape the escape character first, then the LIKE wildcards.
        escaped_keyword = (
            keyword.replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )

        classes = self.db.query(OntologyClass).filter(
            OntologyClass.class_name.ilike(f"%{escaped_keyword}%", escape="\\"),
            OntologyClass.scene_id == scene_id
        ).order_by(
            OntologyClass.created_at.desc()
        ).all()

        logger.info(
            f"Found {len(classes)} ontology classes matching keyword '{keyword}' "
            f"in scene {scene_id}"
        )

        return classes

    except Exception as e:
        logger.error(
            f"Failed to search ontology classes by keyword: {str(e)}",
            exc_info=True
        )
        raise
def delete(self, class_id: UUID) -> bool:
    """Delete the ontology class with the given ID.

    The row is looked up with ``get_by_id``; when found it is deleted and
    the session is flushed (not committed).

    Args:
        class_id: ID of the class to delete.

    Returns:
        bool: True when a row was deleted, False when no such class exists.

    Raises:
        Exception: Any error raised during lookup, delete or flush is
            logged and re-raised.

    Examples:
        >>> repo = OntologyClassRepository(db)
        >>> success = repo.delete(class_id)
    """
    try:
        logger.info(f"Deleting ontology class: {class_id}")

        target = self.get_by_id(class_id)
        if target:
            self.db.delete(target)
            self.db.flush()
            logger.info(f"Ontology class deleted successfully: {class_id}")
            removed = True
        else:
            logger.warning(f"Ontology class not found for delete: {class_id}")
            removed = False

        return removed

    except Exception as exc:
        logger.error(
            f"Failed to delete ontology class: {exc!s}",
            exc_info=True
        )
        raise
).filter( + OntologyClass.class_id == class_id, + OntologyScene.workspace_id == workspace_id + ).count() + + is_owner = count > 0 + + logger.debug( + f"Class ownership check result: {is_owner} - " + f"class_id={class_id}" + ) + + return is_owner + + except Exception as e: + logger.error( + f"Failed to check class ownership: {str(e)}", + exc_info=True + ) + raise + + def get_scene_id_by_class(self, class_id: UUID) -> Optional[UUID]: + """根据类型ID获取所属场景ID + + Args: + class_id: 类型ID + + Returns: + Optional[UUID]: 场景ID,类型不存在则返回None + + Examples: + >>> repo = OntologyClassRepository(db) + >>> scene_id = repo.get_scene_id_by_class(class_id) + """ + try: + logger.debug(f"Getting scene ID by class: {class_id}") + + ontology_class = self.get_by_id(class_id) + if not ontology_class: + logger.debug(f"Class not found: {class_id}") + return None + + logger.debug( + f"Found scene ID: {ontology_class.scene_id} for class: {class_id}" + ) + + return ontology_class.scene_id + + except Exception as e: + logger.error( + f"Failed to get scene ID by class: {str(e)}", + exc_info=True + ) + raise diff --git a/api/app/repositories/ontology_scene_repository.py b/api/app/repositories/ontology_scene_repository.py new file mode 100644 index 00000000..322e111c --- /dev/null +++ b/api/app/repositories/ontology_scene_repository.py @@ -0,0 +1,394 @@ +# -*- coding: utf-8 -*- +"""本体场景Repository层 + +本模块提供本体场景的数据访问层实现。 + +Classes: + OntologySceneRepository: 本体场景数据访问类 +""" + +import logging +from typing import List, Optional +from uuid import UUID + +from sqlalchemy.orm import Session, joinedload + +from app.core.logging_config import get_db_logger +from app.models.ontology_scene import OntologyScene + + +logger = get_db_logger() + + +class OntologySceneRepository: + """本体场景Repository + + 提供本体场景的CRUD操作和权限检查。 + + Attributes: + db: SQLAlchemy数据库会话 + """ + + def __init__(self, db: Session): + """初始化Repository + + Args: + db: SQLAlchemy数据库会话 + """ + self.db = db + + def create(self, scene_data: dict, 
workspace_id: UUID) -> OntologyScene: + """创建本体场景 + + Args: + scene_data: 场景数据字典,包含scene_name和scene_description + workspace_id: 所属工作空间ID + + Returns: + OntologyScene: 创建的场景对象 + + Raises: + Exception: 数据库操作失败 + + Examples: + >>> repo = OntologySceneRepository(db) + >>> scene = repo.create( + ... {"scene_name": "医疗场景", "scene_description": "描述"}, + ... workspace_id + ... ) + """ + try: + logger.info( + f"Creating ontology scene - " + f"name={scene_data.get('scene_name')}, " + f"workspace_id={workspace_id}" + ) + + scene = OntologyScene( + scene_name=scene_data.get("scene_name"), + scene_description=scene_data.get("scene_description"), + workspace_id=workspace_id + ) + + self.db.add(scene) + self.db.flush() # 获取ID但不提交 + + logger.info( + f"Ontology scene created successfully - " + f"scene_id={scene.scene_id}" + ) + + return scene + + except Exception as e: + logger.error( + f"Failed to create ontology scene: {str(e)}", + exc_info=True + ) + raise + + def get_by_id(self, scene_id: UUID) -> Optional[OntologyScene]: + """根据ID获取场景 + + Args: + scene_id: 场景ID + + Returns: + Optional[OntologyScene]: 场景对象,不存在则返回None + + Examples: + >>> repo = OntologySceneRepository(db) + >>> scene = repo.get_by_id(scene_id) + """ + try: + logger.debug(f"Getting ontology scene by ID: {scene_id}") + + scene = self.db.query(OntologyScene).filter( + OntologyScene.scene_id == scene_id + ).first() + + if scene: + logger.debug(f"Ontology scene found: {scene_id}") + else: + logger.debug(f"Ontology scene not found: {scene_id}") + + return scene + + except Exception as e: + logger.error( + f"Failed to get ontology scene by ID: {str(e)}", + exc_info=True + ) + raise + + def get_by_name(self, scene_name: str, workspace_id: UUID) -> Optional[OntologyScene]: + """根据场景名称和工作空间ID获取场景(精确匹配) + + Args: + scene_name: 场景名称 + workspace_id: 工作空间ID + + Returns: + Optional[OntologyScene]: 场景对象,不存在则返回None + + Examples: + >>> repo = OntologySceneRepository(db) + >>> scene = repo.get_by_name("医疗场景", workspace_id) + """ + 
try: + logger.debug( + f"Getting ontology scene by name - " + f"scene_name={scene_name}, workspace_id={workspace_id}" + ) + + scene = self.db.query(OntologyScene).options( + joinedload(OntologyScene.classes) + ).filter( + OntologyScene.scene_name == scene_name, + OntologyScene.workspace_id == workspace_id + ).first() + + if scene: + logger.debug(f"Ontology scene found: {scene_name}") + else: + logger.debug(f"Ontology scene not found: {scene_name}") + + return scene + + except Exception as e: + logger.error( + f"Failed to get ontology scene by name: {str(e)}", + exc_info=True + ) + raise + + def search_by_name(self, keyword: str, workspace_id: UUID) -> List[OntologyScene]: + """根据关键词模糊搜索场景 + + 使用 LIKE 进行模糊匹配,支持中文和英文。 + + Args: + keyword: 搜索关键词 + workspace_id: 工作空间ID + + Returns: + List[OntologyScene]: 匹配的场景列表 + + Examples: + >>> repo = OntologySceneRepository(db) + >>> scenes = repo.search_by_name("医疗", workspace_id) + """ + try: + logger.debug( + f"Searching ontology scenes by keyword - " + f"keyword={keyword}, workspace_id={workspace_id}" + ) + + # 使用 ilike 进行不区分大小写的模糊匹配 + scenes = self.db.query(OntologyScene).options( + joinedload(OntologyScene.classes) + ).filter( + OntologyScene.scene_name.ilike(f"%{keyword}%"), + OntologyScene.workspace_id == workspace_id + ).order_by( + OntologyScene.updated_at.desc() + ).all() + + logger.info( + f"Found {len(scenes)} ontology scenes matching keyword '{keyword}' " + f"in workspace {workspace_id}" + ) + + return scenes + + except Exception as e: + logger.error( + f"Failed to search ontology scenes by keyword: {str(e)}", + exc_info=True + ) + raise + + def get_by_workspace(self, workspace_id: UUID, page: Optional[int] = None, page_size: Optional[int] = None) -> tuple: + """获取工作空间下的所有场景(支持分页) + + 使用joinedload预加载classes关系以统计数量。 + + Args: + workspace_id: 工作空间ID + page: 页码(可选,从1开始) + page_size: 每页数量(可选) + + Returns: + tuple: (场景列表, 总数量) + + Examples: + >>> repo = OntologySceneRepository(db) + >>> scenes, total = 
repo.get_by_workspace(workspace_id) + >>> scenes, total = repo.get_by_workspace(workspace_id, page=1, page_size=10) + """ + try: + logger.debug(f"Getting ontology scenes by workspace: {workspace_id}, page={page}, page_size={page_size}") + + # 构建基础查询 + query = self.db.query(OntologyScene).options( + joinedload(OntologyScene.classes) + ).filter( + OntologyScene.workspace_id == workspace_id + ).order_by( + OntologyScene.updated_at.desc() + ) + + # 获取总数 + total = query.count() + + # 如果提供了分页参数,应用分页 + if page is not None and page_size is not None: + offset = (page - 1) * page_size + query = query.offset(offset).limit(page_size) + logger.debug(f"Applying pagination: offset={offset}, limit={page_size}") + + scenes = query.all() + + logger.info( + f"Found {len(scenes)} ontology scenes (total: {total}) in workspace {workspace_id}" + ) + + return scenes, total + + except Exception as e: + logger.error( + f"Failed to get ontology scenes by workspace: {str(e)}", + exc_info=True + ) + raise + + def update(self, scene_id: UUID, update_data: dict) -> Optional[OntologyScene]: + """更新场景信息 + + Args: + scene_id: 场景ID + update_data: 更新数据字典 + + Returns: + Optional[OntologyScene]: 更新后的场景对象,不存在则返回None + + Raises: + Exception: 数据库操作失败 + + Examples: + >>> repo = OntologySceneRepository(db) + >>> scene = repo.update( + ... scene_id, + ... {"scene_name": "新名称"} + ... 
) + """ + try: + logger.info(f"Updating ontology scene: {scene_id}") + + scene = self.get_by_id(scene_id) + if not scene: + logger.warning(f"Ontology scene not found for update: {scene_id}") + return None + + # 更新字段 + if "scene_name" in update_data and update_data["scene_name"] is not None: + scene.scene_name = update_data["scene_name"] + + if "scene_description" in update_data: + scene.scene_description = update_data["scene_description"] + + self.db.flush() + + logger.info(f"Ontology scene updated successfully: {scene_id}") + + return scene + + except Exception as e: + logger.error( + f"Failed to update ontology scene: {str(e)}", + exc_info=True + ) + raise + + def delete(self, scene_id: UUID) -> bool: + """删除场景(级联删除类型) + + 依赖数据库级联删除配置(ondelete="CASCADE")。 + + Args: + scene_id: 场景ID + + Returns: + bool: 删除成功返回True,场景不存在返回False + + Raises: + Exception: 数据库操作失败 + + Examples: + >>> repo = OntologySceneRepository(db) + >>> success = repo.delete(scene_id) + """ + try: + logger.info(f"Deleting ontology scene: {scene_id}") + + scene = self.get_by_id(scene_id) + if not scene: + logger.warning(f"Ontology scene not found for delete: {scene_id}") + return False + + self.db.delete(scene) + self.db.flush() + + logger.info( + f"Ontology scene deleted successfully (cascade): {scene_id}" + ) + + return True + + except Exception as e: + logger.error( + f"Failed to delete ontology scene: {str(e)}", + exc_info=True + ) + raise + + def check_ownership(self, scene_id: UUID, workspace_id: UUID) -> bool: + """检查场景是否属于指定工作空间 + + Args: + scene_id: 场景ID + workspace_id: 工作空间ID + + Returns: + bool: 属于返回True,否则返回False + + Examples: + >>> repo = OntologySceneRepository(db) + >>> is_owner = repo.check_ownership(scene_id, workspace_id) + """ + try: + logger.debug( + f"Checking scene ownership - " + f"scene_id={scene_id}, workspace_id={workspace_id}" + ) + + count = self.db.query(OntologyScene).filter( + OntologyScene.scene_id == scene_id, + OntologyScene.workspace_id == workspace_id + ).count() 
+ + is_owner = count > 0 + + logger.debug( + f"Scene ownership check result: {is_owner} - " + f"scene_id={scene_id}" + ) + + return is_owner + + except Exception as e: + logger.error( + f"Failed to check scene ownership: {str(e)}", + exc_info=True + ) + raise diff --git a/api/app/repositories/prompt_optimizer_repository.py b/api/app/repositories/prompt_optimizer_repository.py index ba65257a..e73ab513 100644 --- a/api/app/repositories/prompt_optimizer_repository.py +++ b/api/app/repositories/prompt_optimizer_repository.py @@ -4,7 +4,10 @@ from sqlalchemy.orm import Session from app.core.logging_config import get_db_logger from app.models.prompt_optimizer_model import ( - PromptOptimizerSession, PromptOptimizerSessionHistory, RoleType + PromptOptimizerSession, + PromptOptimizerSessionHistory, + RoleType, + PromptHistory ) db_logger = get_db_logger() @@ -16,6 +19,12 @@ class PromptOptimizerSessionRepository: def __init__(self, db: Session): self.db = db + def get_session_by_id(self, session_id: uuid.UUID) -> PromptOptimizerSession | None: + session = self.db.query(PromptOptimizerSession).filter( + PromptOptimizerSession.id == session_id, + ).first() + return session + def create_session( self, tenant_id: uuid.UUID, @@ -38,12 +47,9 @@ class PromptOptimizerSessionRepository: user_id=user_id, ) self.db.add(session) - self.db.commit() - self.db.refresh(session) - db_logger.debug(f"Prompt optimization session created: ID:{session.id}") return session except Exception as e: - db_logger.error(f"Error creating prompt optimization session: user_id={user_id} - {str(e)}") + db_logger.error(f"Error creating prompt optimization session: - {str(e)}") raise def get_session_history( @@ -71,10 +77,10 @@ class PromptOptimizerSessionRepository: PromptOptimizerSession.id == session_id, PromptOptimizerSession.user_id == user_id ).first() - + if not session: return [] - + history = self.db.query(PromptOptimizerSessionHistory).filter( PromptOptimizerSessionHistory.session_id == session.id, 
PromptOptimizerSessionHistory.user_id == user_id @@ -104,11 +110,11 @@ class PromptOptimizerSessionRepository: PromptOptimizerSession.user_id == user_id, PromptOptimizerSession.tenant_id == tenant_id ).first() - + if not session: db_logger.error(f"Session {session_id} not found for user {user_id}") raise ValueError(f"Session {session_id} not found for user {user_id}") - + message = PromptOptimizerSessionHistory( tenant_id=tenant_id, session_id=session.id, @@ -117,8 +123,199 @@ class PromptOptimizerSessionRepository: content=content, ) self.db.add(message) - self.db.commit() + return message except Exception as e: db_logger.error(f"Error creating prompt optimization session history: session_id={session_id} - {str(e)}") raise + + def get_first_user_message(self, session_id: uuid.UUID) -> str | None: + """ + Get the first user message from a session. + + Args: + session_id (uuid.UUID): The session ID. + + Returns: + str | None: The content of the first user message, or None if not found. + """ + try: + message = self.db.query(PromptOptimizerSessionHistory).filter( + PromptOptimizerSessionHistory.session_id == session_id, + PromptOptimizerSessionHistory.role == RoleType.USER.value + ).order_by( + PromptOptimizerSessionHistory.created_at.asc() + ).first() + + return message.content if message else None + except Exception as e: + db_logger.error(f"Error getting first user message: session_id={session_id} - {str(e)}") + raise + + +class PromptReleaseRepository: + def __init__(self, db: Session): + self.db = db + + def get_prompt_by_session_id(self, session_id: uuid.UUID) -> PromptHistory | None: + prompt_obj = self.db.query(PromptHistory).filter( + PromptHistory.session_id == session_id, + PromptHistory.is_delete.is_(False) + ).first() + return prompt_obj + + def create_prompt_release( + self, + tenant_id: uuid.UUID, + title: str, + session_id: uuid.UUID, + prompt: str, + ) -> PromptHistory: + try: + prompt_obj = PromptHistory( + tenant_id=tenant_id, + title=title, + 
session_id=session_id, + prompt=prompt, + ) + self.db.add(prompt_obj) + return prompt_obj + except Exception as e: + db_logger.error(f"Error creating prompt release: session_id={session_id} - {str(e)}") + raise + + def soft_delete_prompt(self, prompt_obj: PromptHistory) -> None: + """ + Soft delete a prompt release by setting is_delete flag to True. + + Args: + prompt_obj (PromptHistory): The prompt release object to delete. + """ + try: + prompt_obj.is_delete = True + db_logger.debug(f"Soft deleted prompt release: id={prompt_obj.id}, session_id={prompt_obj.session_id}") + except Exception as e: + db_logger.error(f"Error soft deleting prompt release: id={prompt_obj.id} - {str(e)}") + raise + + def get_prompt_by_id(self, prompt_id: uuid.UUID) -> PromptHistory | None: + """ + Get a prompt release by its ID. + + Args: + prompt_id (uuid.UUID): The prompt release ID. + + Returns: + PromptHistory | None: The prompt release object or None if not found. + """ + try: + prompt_obj = self.db.query(PromptHistory).filter( + PromptHistory.id == prompt_id + ).first() + return prompt_obj + except Exception as e: + db_logger.error(f"Error getting prompt release by id: id={prompt_id} - {str(e)}") + raise + + def count_prompts(self, tenant_id: uuid.UUID) -> int: + """ + Count total number of non-deleted prompts for a tenant. + + Args: + tenant_id (uuid.UUID): The tenant ID. + + Returns: + int: Total count of prompts. + """ + try: + count = self.db.query(PromptHistory).filter( + PromptHistory.tenant_id == tenant_id, + PromptHistory.is_delete.is_(False) + ).count() + return count + except Exception as e: + db_logger.error(f"Error counting prompts: tenant_id={tenant_id} - {str(e)}") + raise + + def get_prompts_paginated( + self, + tenant_id: uuid.UUID, + offset: int, + limit: int + ) -> list[PromptHistory]: + """ + Get paginated list of prompt releases for a tenant. + + Args: + tenant_id (uuid.UUID): The tenant ID. + offset (int): Number of records to skip. 
+ limit (int): Maximum number of records to return. + + Returns: + list[PromptHistory]: List of prompt releases. + """ + try: + prompts = self.db.query(PromptHistory).filter( + PromptHistory.tenant_id == tenant_id, + PromptHistory.is_delete.is_(False) + ).order_by( + PromptHistory.created_at.desc() + ).offset(offset).limit(limit).all() + return prompts + except Exception as e: + db_logger.error(f"Error getting paginated prompts: tenant_id={tenant_id} - {str(e)}") + raise + + def count_prompts_by_keyword(self, tenant_id: uuid.UUID, keyword: str) -> int: + """ + Count total number of non-deleted prompts matching keyword for a tenant. + + Args: + tenant_id (uuid.UUID): The tenant ID. + keyword (str): Search keyword for title. + + Returns: + int: Total count of matching prompts. + """ + try: + count = self.db.query(PromptHistory).filter( + PromptHistory.tenant_id == tenant_id, + PromptHistory.is_delete.is_(False), + PromptHistory.title.ilike(f"%{keyword}%") + ).count() + return count + except Exception as e: + db_logger.error(f"Error counting prompts by keyword: tenant_id={tenant_id}, keyword={keyword} - {str(e)}") + raise + + def search_prompts_paginated( + self, + tenant_id: uuid.UUID, + keyword: str, + offset: int, + limit: int + ) -> list[PromptHistory]: + """ + Search prompt releases by keyword in title with pagination. + + Args: + tenant_id (uuid.UUID): The tenant ID. + keyword (str): Search keyword for title. + offset (int): Number of records to skip. + limit (int): Maximum number of records to return. + + Returns: + list[PromptHistory]: List of matching prompt releases. 
+ """ + try: + prompts = self.db.query(PromptHistory).filter( + PromptHistory.tenant_id == tenant_id, + PromptHistory.is_delete.is_(False), + PromptHistory.title.ilike(f"%{keyword}%") + ).order_by( + PromptHistory.created_at.desc() + ).offset(offset).limit(limit).all() + return prompts + except Exception as e: + db_logger.error(f"Error searching prompts: tenant_id={tenant_id}, keyword={keyword} - {str(e)}") + raise diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 09410091..ddaed685 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -12,8 +12,8 @@ class KnowledgeBaseConfig(BaseModel): kb_id: str = Field(..., description="知识库ID") top_k: int = Field(default=3, ge=1, le=20, description="检索返回的文档数量") similarity_threshold: float = Field(default=0.7, ge=0.0, le=1.0, description="相似度阈值") - strategy: str = Field(default="hybrid", description="检索策略: hybrid | bm25 | dense") - weight: float = Field(default=1.0, ge=0.0, le=1.0, description="知识库权重(用于多知识库融合)") + # strategy: str = Field(default="hybrid", description="检索策略: hybrid | bm25 | dense") + # weight: float = Field(default=1.0, ge=0.0, le=1.0, description="知识库权重(用于多知识库融合)") vector_similarity_weight: float = Field(default=0.5, ge=0.0, le=1.0, description="向量相似度权重") retrieve_type: str = Field(default="hybrid", description="检索方式participle| semantic|hybrid") diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index 5fda0a1d..5e22d70f 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -229,6 +229,9 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body, config_desc: str = Field("配置描述", description="配置描述(字符串)") workspace_id: Optional[uuid.UUID] = Field(None, description="工作空间ID(UUID)") + # 本体场景关联(可选) + scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID(UUID),关联ontology_scene表") + # 模型配置字段(可选,用于手动指定或自动填充) llm_id: Optional[str] = Field(None, 
description="LLM模型配置ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") diff --git a/api/app/schemas/ontology_schemas.py b/api/app/schemas/ontology_schemas.py new file mode 100644 index 00000000..5a88f84d --- /dev/null +++ b/api/app/schemas/ontology_schemas.py @@ -0,0 +1,461 @@ +"""本体提取API的请求和响应模型 + +本模块定义了本体提取系统的所有API请求和响应的Pydantic模型。 + +Classes: + ExtractionRequest: 本体提取请求模型 + ExtractionResponse: 本体提取响应模型 + ExportRequest: OWL文件导出请求模型 + ExportResponse: OWL文件导出响应模型 + OntologyResultResponse: 本体提取结果响应模型(带毫秒时间戳) + SceneCreateRequest: 场景创建请求模型 + SceneUpdateRequest: 场景更新请求模型 + SceneResponse: 场景响应模型 + SceneListResponse: 场景列表响应模型 + ClassCreateRequest: 类型创建请求模型 + ClassUpdateRequest: 类型更新请求模型 + ClassResponse: 类型响应模型 + ClassListResponse: 类型列表响应模型 +""" + +from typing import List, Optional +import datetime +from uuid import UUID + +from pydantic import BaseModel, Field, field_serializer, ConfigDict + +from app.core.memory.models.ontology_models import OntologyClass + + +class ExtractionRequest(BaseModel): + """本体提取请求模型 + + 用于POST /api/ontology/extract端点的请求体。 + + Attributes: + scenario: 场景描述文本,不能为空 + domain: 可选的领域提示(如Healthcare, Education等) + llm_id: LLM模型ID,必须提供 + scene_id: 场景ID,必须提供,用于将提取的类保存到指定场景 + + Examples: + >>> request = ExtractionRequest( + ... scenario="医院管理患者记录...", + ... domain="Healthcare", + ... llm_id="550e8400-e29b-41d4-a716-446655440000", + ... scene_id="660e8400-e29b-41d4-a716-446655440000" + ... ) + """ + scenario: str = Field(..., description="场景描述文本", min_length=1) + domain: Optional[str] = Field(None, description="可选的领域提示") + llm_id: str = Field(..., description="LLM模型ID") + scene_id: UUID = Field(..., description="场景ID,用于将提取的类保存到指定场景") + + +class ExtractionResponse(BaseModel): + """本体提取响应模型 + + 用于POST /api/ontology/extract端点的响应体。 + + Attributes: + classes: 提取的本体类列表 + domain: 识别的领域 + extracted_count: 提取的类数量 + + Examples: + >>> response = ExtractionResponse( + ... classes=[...], + ... domain="Healthcare", + ... extracted_count=7 + ... 
) + """ + classes: List[OntologyClass] = Field(default_factory=list, description="提取的本体类列表") + domain: str = Field(..., description="识别的领域") + extracted_count: int = Field(..., description="提取的类数量") + + +class ExportRequest(BaseModel): + """OWL文件导出请求模型 + + 用于POST /api/ontology/export端点的请求体。 + + Attributes: + classes: 要导出的本体类列表 + format: 导出格式,可选值: rdfxml, turtle, ntriples, json + include_metadata: 是否包含完整的OWL元数据(命名空间等),默认True + + Examples: + >>> request = ExportRequest( + ... classes=[...], + ... format="rdfxml", + ... include_metadata=True + ... ) + """ + classes: List[OntologyClass] = Field(..., description="要导出的本体类列表", min_length=1) + format: str = Field("rdfxml", description="导出格式: rdfxml, turtle, ntriples, json") + include_metadata: bool = Field(True, description="是否包含完整的OWL元数据") + + +class ExportResponse(BaseModel): + """OWL文件导出响应模型 + + 用于POST /api/ontology/export端点的响应体。 + + Attributes: + owl_content: OWL文件内容 + format: 导出格式 + classes_count: 导出的类数量 + + Examples: + >>> response = ExportResponse( + ... owl_content="...", + ... format="rdfxml", + ... classes_count=7 + ... ) + """ + owl_content: str = Field(..., description="OWL文件内容") + format: str = Field(..., description="导出格式") + classes_count: int = Field(..., description="导出的类数量") + + +class OntologyResultResponse(BaseModel): + """本体提取结果响应模型 + + 用于返回数据库中存储的提取结果,时间戳为毫秒级。 + + Attributes: + id: 结果ID (UUID) + scenario: 场景描述文本 + domain: 领域 + classes_json: 提取的本体类数据(JSON格式) + extracted_count: 提取的类数量 + user_id: 用户ID + created_at: 创建时间(毫秒时间戳) + + Examples: + >>> response = OntologyResultResponse( + ... id=uuid.uuid4(), + ... scenario="医院管理患者记录...", + ... domain="Healthcare", + ... classes_json={"classes": [...]}, + ... extracted_count=7, + ... user_id=123, + ... created_at=datetime.now() + ... 
) + """ + id: UUID = Field(..., description="结果ID") + scenario: str = Field(..., description="场景描述文本") + domain: Optional[str] = Field(None, description="领域") + classes_json: dict = Field(..., description="提取的本体类数据(JSON格式)") + extracted_count: int = Field(..., description="提取的类数量") + user_id: Optional[int] = Field(None, description="用户ID") + created_at: datetime.datetime = Field(..., description="创建时间") + + @field_serializer("created_at", when_used="json") + def _serialize_created_at(self, dt: datetime.datetime): + """将创建时间序列化为毫秒时间戳""" + return int(dt.timestamp() * 1000) if dt else None + + class Config: + from_attributes = True + + + +# ==================== 本体场景相关 Schema ==================== + +class SceneCreateRequest(BaseModel): + """场景创建请求模型 + + 用于创建新的本体场景。 + + Attributes: + scene_name: 场景名称,必填,1-200字符 + scene_description: 场景描述,可选 + + Examples: + >>> request = SceneCreateRequest( + ... scene_name="医疗场景", + ... scene_description="用于医疗领域的本体建模" + ... ) + """ + scene_name: str = Field(..., min_length=1, max_length=200, description="场景名称") + scene_description: Optional[str] = Field(None, description="场景描述") + + +class SceneUpdateRequest(BaseModel): + """场景更新请求模型 + + 用于更新已有本体场景信息。 + + Attributes: + scene_name: 场景名称,可选,1-200字符 + scene_description: 场景描述,可选 + + Examples: + >>> request = SceneUpdateRequest( + ... scene_name="更新后的场景名称", + ... scene_description="更新后的描述" + ... ) + """ + scene_name: Optional[str] = Field(None, min_length=1, max_length=200, description="场景名称") + scene_description: Optional[str] = Field(None, description="场景描述") + + +class SceneResponse(BaseModel): + """场景响应模型 + + 用于返回本体场景信息。 + + Attributes: + scene_id: 场景ID + scene_name: 场景名称 + scene_description: 场景描述 + type_num: 类型数量 + workspace_id: 所属工作空间ID + created_at: 创建时间(毫秒时间戳) + updated_at: 更新时间(毫秒时间戳) + classes_count: 类型数量 + + Examples: + >>> response = SceneResponse( + ... scene_id=uuid.uuid4(), + ... scene_name="医疗场景", + ... scene_description="用于医疗领域的本体建模", + ... type_num=0, + ... 
workspace_id=uuid.uuid4(), + ... created_at=datetime.now(), + ... updated_at=datetime.now(), + ... classes_count=5 + ... ) + """ + scene_id: UUID = Field(..., description="场景ID") + scene_name: str = Field(..., description="场景名称") + scene_description: Optional[str] = Field(None, description="场景描述") + type_num: int = Field(..., description="类型数量") + entity_type: Optional[List[str]] = Field(None, description="实体类型列表(最多3个class_name)") + workspace_id: UUID = Field(..., description="所属工作空间ID") + created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)") + updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)") + classes_count: int = Field(0, description="类型数量") + + @field_serializer("created_at", when_used="json") + def _serialize_created_at(self, dt: datetime.datetime): + """将创建时间序列化为毫秒时间戳""" + return int(dt.timestamp() * 1000) if dt else None + + @field_serializer("updated_at", when_used="json") + def _serialize_updated_at(self, dt: datetime.datetime): + """将更新时间序列化为毫秒时间戳""" + return int(dt.timestamp() * 1000) if dt else None + + model_config = ConfigDict(from_attributes=True) + + +class PaginationInfo(BaseModel): + """分页信息模型 + + Attributes: + page: 当前页码 + pagesize: 每页数量 + total: 总数量 + hasnext: 是否有下一页 + """ + page: int = Field(..., description="当前页码") + pagesize: int = Field(..., description="每页数量") + total: int = Field(..., description="总数量") + hasnext: bool = Field(..., description="是否有下一页") + + +class SceneListResponse(BaseModel): + """场景列表响应模型(支持分页) + + 用于返回本体场景列表。 + + Attributes: + items: 场景列表 + page: 分页信息(可选,分页时返回) + + Examples: + >>> # 不分页 + >>> response = SceneListResponse( + ... items=[scene1, scene2] + ... ) + >>> # 分页 + >>> response = SceneListResponse( + ... items=[scene1, scene2, ...], + ... page=PaginationInfo(page=1, pagesize=100, total=150, hasnext=True) + ... 
) + """ + items: List[SceneResponse] = Field(..., description="场景列表") + page: Optional[PaginationInfo] = Field(None, description="分页信息") + + +# ==================== 本体类型相关 Schema ==================== + +class ClassItem(BaseModel): + """单个类型信息模型 + + Attributes: + class_name: 类型名称,必填,1-200字符 + class_description: 类型描述,可选 + + Examples: + >>> item = ClassItem( + ... class_name="患者", + ... class_description="医院患者信息" + ... ) + """ + class_name: str = Field(..., min_length=1, max_length=200, description="类型名称") + class_description: Optional[str] = Field(None, description="类型描述") + + +class ClassCreateRequest(BaseModel): + """类型创建请求模型(统一使用列表形式) + + 通过列表中元素数量决定创建模式: + - 列表包含 1 个元素:单个创建 + - 列表包含多个元素:批量创建 + + Attributes: + scene_id: 所属场景ID,必填 + classes: 类型列表,必填,至少包含 1 个元素 + + Examples: + # 单个创建(列表中 1 个元素) + >>> request = ClassCreateRequest( + ... scene_id=uuid.uuid4(), + ... classes=[ + ... ClassItem(class_name="患者", class_description="医院患者信息") + ... ] + ... ) + + # 批量创建(列表中多个元素) + >>> request = ClassCreateRequest( + ... scene_id=uuid.uuid4(), + ... classes=[ + ... ClassItem(class_name="患者", class_description="医院患者信息"), + ... ClassItem(class_name="医生", class_description="医院医生信息"), + ... ClassItem(class_name="药品", class_description="医院药品信息") + ... ] + ... ) + """ + scene_id: UUID = Field(..., description="所属场景ID") + classes: List[ClassItem] = Field(..., min_length=1, description="类型列表,至少包含 1 个元素") + + +class ClassUpdateRequest(BaseModel): + """类型更新请求模型 + + 用于更新已有本体类型信息。 + + Attributes: + class_name: 类型名称,可选,1-200字符 + class_description: 类型描述,可选 + + Examples: + >>> request = ClassUpdateRequest( + ... class_name="更新后的类型名称", + ... class_description="更新后的描述" + ... 
) + """ + class_name: Optional[str] = Field(None, min_length=1, max_length=200, description="类型名称") + class_description: Optional[str] = Field(None, description="类型描述") + + +class ClassResponse(BaseModel): + """类型响应模型 + + 用于返回本体类型信息。 + + Attributes: + class_id: 类型ID + class_name: 类型名称 + class_description: 类型描述 + scene_id: 所属场景ID + created_at: 创建时间(毫秒时间戳) + updated_at: 更新时间(毫秒时间戳) + + Examples: + >>> response = ClassResponse( + ... class_id=uuid.uuid4(), + ... class_name="患者", + ... class_description="医院患者信息", + ... scene_id=uuid.uuid4(), + ... created_at=datetime.now(), + ... updated_at=datetime.now() + ... ) + """ + class_id: UUID = Field(..., description="类型ID") + class_name: str = Field(..., description="类型名称") + class_description: Optional[str] = Field(None, description="类型描述") + scene_id: UUID = Field(..., description="所属场景ID") + created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)") + updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)") + + @field_serializer("created_at", when_used="json") + def _serialize_created_at(self, dt: datetime.datetime): + """将创建时间序列化为毫秒时间戳""" + return int(dt.timestamp() * 1000) if dt else None + + @field_serializer("updated_at", when_used="json") + def _serialize_updated_at(self, dt: datetime.datetime): + """将更新时间序列化为毫秒时间戳""" + return int(dt.timestamp() * 1000) if dt else None + + model_config = ConfigDict(from_attributes=True) + + +class ClassBatchCreateResponse(BaseModel): + """批量创建类型响应模型 + + 用于返回批量创建的结果统计和详情。 + + Attributes: + total: 总共尝试创建的数量 + success_count: 成功创建的数量 + failed_count: 失败的数量 + items: 成功创建的类型列表 + errors: 失败的错误信息列表(可选) + + Examples: + >>> response = ClassBatchCreateResponse( + ... total=3, + ... success_count=2, + ... failed_count=1, + ... items=[class1, class2], + ... errors=["创建类型 '药品' 失败: 类型名称已存在"] + ... 
) + """ + total: int = Field(..., description="总共尝试创建的数量") + success_count: int = Field(..., description="成功创建的数量") + failed_count: int = Field(0, description="失败的数量") + items: List[ClassResponse] = Field(..., description="成功创建的类型列表") + errors: Optional[List[str]] = Field(None, description="失败的错误信息列表") + + +class ClassListResponse(BaseModel): + """类型列表响应模型 + + 用于返回本体类型列表。 + + Attributes: + total: 总数量 + scene_id: 所属场景ID + scene_name: 场景名称 + scene_description: 场景描述 + items: 类型列表 + + Examples: + >>> response = ClassListResponse( + ... total=3, + ... scene_id=uuid.uuid4(), + ... scene_name="医疗场景", + ... scene_description="用于医疗领域的本体建模", + ... items=[class1, class2, class3] + ... ) + """ + total: int = Field(..., description="总数量") + scene_id: UUID = Field(..., description="所属场景ID") + scene_name: str = Field(..., description="场景名称") + scene_description: Optional[str] = Field(None, description="场景描述") + items: List[ClassResponse] = Field(..., description="类型列表") diff --git a/api/app/schemas/prompt_optimizer_schema.py b/api/app/schemas/prompt_optimizer_schema.py index e1f27be0..08a11317 100644 --- a/api/app/schemas/prompt_optimizer_schema.py +++ b/api/app/schemas/prompt_optimizer_schema.py @@ -22,6 +22,23 @@ class PromptOptMessage(BaseModel): ) +class PromptSaveRequest(BaseModel): + session_id: UUID = Field( + ..., + description="Session ID" + ) + + title: str = Field( + ..., + description="Prompt Title" + ) + + prompt: str = Field( + ..., + description="Optimized prompt content" + ) + + class PromptOptModelSet(BaseModel): id: UUID | None = Field( default=None, diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index c0a66e03..bd9106e5 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -171,7 +171,14 @@ class AppChatService: self.conversation_service.save_conversation_messages( conversation_id=conversation_id, user_message=message, - assistant_message=result["content"] + 
assistant_message=result["content"], + meta_data={ + "usage": result.get("usage", { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + }) + } ) elapsed_time = time.time() - start_time @@ -310,6 +317,7 @@ class AppChatService: # 流式调用 Agent full_content = "" + total_tokens = 0 async for chunk in agent.chat_stream( message=message, history=history, @@ -320,9 +328,12 @@ class AppChatService: config_id=config_id, memory_flag=memory_flag ): - full_content += chunk - # 发送消息块事件 - yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n" + if isinstance(chunk, int): + total_tokens = chunk + else: + full_content += chunk + # 发送消息块事件 + yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n" elapsed_time = time.time() - start_time @@ -339,7 +350,7 @@ class AppChatService: content=full_content, meta_data={ "model": api_key_obj.model_name, - "usage": {} + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens} } ) @@ -416,7 +427,11 @@ class AppChatService: meta_data={ "mode": result.get("mode"), "elapsed_time": result.get("elapsed_time"), - "sub_results": result.get("sub_results") + "usage": result.get("usage", { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + }) } ) @@ -458,6 +473,7 @@ class AppChatService: yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n" full_content = "" + total_tokens = 0 # 2. 
创建编排器 orchestrator = MultiAgentOrchestrator(self.db, config) @@ -474,16 +490,26 @@ class AppChatService: storage_type=storage_type, user_rag_memory_id=user_rag_memory_id ): - yield event - # 尝试提取内容(用于保存) - if "data:" in event: - try: - data_line = event.split("data: ", 1)[1].strip() - data = json.loads(data_line) - if "content" in data: - full_content += data["content"] - except: - pass + if "sub_usage" in event: + if "data:" in event: + try: + data_line = event.split("data: ", 1)[1].strip() + data = json.loads(data_line) + if "total_tokens" in data: + total_tokens += data["total_tokens"] + except: + pass + else: + yield event + # 尝试提取内容(用于保存) + if "data:" in event: + try: + data_line = event.split("data: ", 1)[1].strip() + data = json.loads(data_line) + if "content" in data: + full_content += data["content"] + except: + pass elapsed_time = time.time() - start_time @@ -499,7 +525,12 @@ class AppChatService: role="assistant", content=full_content, meta_data={ - "elapsed_time": elapsed_time + "elapsed_time": elapsed_time, + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": total_tokens + } } ) diff --git a/api/app/services/conversation_service.py b/api/app/services/conversation_service.py index 275d6413..553aefc4 100644 --- a/api/app/services/conversation_service.py +++ b/api/app/services/conversation_service.py @@ -1,4 +1,5 @@ """会话服务""" +import os import uuid from datetime import datetime, timedelta from typing import Annotated @@ -298,7 +299,8 @@ class ConversationService: self, conversation_id: uuid.UUID, user_message: str, - assistant_message: str + assistant_message: str, + meta_data: Optional[dict] = None ): """ Save a pair of user and assistant messages to the conversation. @@ -307,6 +309,7 @@ class ConversationService: conversation_id (uuid.UUID): Conversation UUID. user_message (str): User's message content. assistant_message (str): Assistant's response content. + meta_data (Optional[dict]): Optional metadata for the messages. 
""" self.add_message( conversation_id=conversation_id, @@ -317,7 +320,8 @@ class ConversationService: self.add_message( conversation_id=conversation_id, role="assistant", - content=assistant_message + content=assistant_message, + meta_data=meta_data ) logger.debug( @@ -526,12 +530,12 @@ class ConversationService: takeaways=[], info_score=0, ) - - with open('app/services/prompt/conversation_summary_system.jinja2', 'r', encoding='utf-8') as f: + prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt') + with open(os.path.join(prompt_path, 'conversation_summary_system.jinja2'), 'r', encoding='utf-8') as f: system_prompt = f.read() rendered_system_message = Template(system_prompt).render() - with open('app/services/prompt/conversation_summary_user.jinja2', 'r', encoding='utf-8') as f: + with open(os.path.join(prompt_path, 'conversation_summary_user.jinja2'), 'r', encoding='utf-8') as f: user_prompt = f.read() rendered_user_message = Template(user_prompt).render( language=language, diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 524c9ff6..9a3e1d37 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -442,7 +442,14 @@ class DraftRunService: user_message=message, assistant_message=result["content"], app_id=agent_config.app_id, - user_id=user_id + user_id=user_id, + meta_data={ + "usage": result.get("usage", { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + }) + } ) response = { @@ -649,6 +656,7 @@ class DraftRunService: # 9. 
流式调用 Agent full_content = "" + total_tokens = 0 async for chunk in agent.chat_stream( message=message, history=history, @@ -659,14 +667,22 @@ class DraftRunService: user_rag_memory_id=user_rag_memory_id, memory_flag=memory_flag ): - full_content += chunk - # 发送消息块事件 - yield self._format_sse_event("message", { - "content": chunk - }) + if isinstance(chunk, int): + total_tokens = chunk + else: + full_content += chunk + # 发送消息块事件 + yield self._format_sse_event("message", { + "content": chunk + }) elapsed_time = time.time() - start_time + if sub_agent: + yield self._format_sse_event("sub_usage", { + "total_tokens": total_tokens + }) + # 10. 保存会话消息 if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"): await self._save_conversation_message( @@ -674,7 +690,10 @@ class DraftRunService: user_message=message, assistant_message=full_content, app_id=agent_config.app_id, - user_id=user_id + user_id=user_id, + meta_data={ + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens} + } ) # 11. 
发送结束事件 @@ -898,6 +917,7 @@ class DraftRunService: conversation_id: str, user_message: str, assistant_message: str, + meta_data: dict, app_id: Optional[uuid.UUID] = None, user_id: Optional[str] = None ) -> None: @@ -909,6 +929,7 @@ class DraftRunService: assistant_message: AI 回复消息 app_id: 应用ID(未使用,保留用于兼容性) user_id: 用户ID(未使用,保留用于兼容性) + meta_data: token消耗 """ try: from app.services.conversation_service import ConversationService @@ -927,7 +948,8 @@ class DraftRunService: conversation_service.add_message( conversation_id=conv_uuid, role="assistant", - content=assistant_message + content=assistant_message, + meta_data=meta_data ) logger.debug( diff --git a/api/app/services/emotion_analytics_service.py b/api/app/services/emotion_analytics_service.py index af98fb52..7bc776ed 100644 --- a/api/app/services/emotion_analytics_service.py +++ b/api/app/services/emotion_analytics_service.py @@ -17,12 +17,15 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector from pydantic import BaseModel, Field from sqlalchemy.orm import Session +from app.utils.config_utils import resolve_config_id + logger = get_business_logger() class EmotionSuggestion(BaseModel): """情绪建议模型""" - type: str = Field(..., description="建议类型:emotion_balance/activity_recommendation/social_connection/stress_management") + type: str = Field(..., + description="建议类型:emotion_balance/activity_recommendation/social_connection/stress_management") title: str = Field(..., description="建议标题") content: str = Field(..., description="建议内容") priority: str = Field(..., description="优先级:high/medium/low") @@ -37,33 +40,33 @@ class EmotionSuggestionsResponse(BaseModel): class EmotionAnalyticsService: """情绪分析服务 - + 提供情绪数据的分析和统计功能,包括: - 情绪标签统计 - 情绪词云数据 - 情绪健康指数计算 - 个性化情绪建议生成 - + Attributes: emotion_repo: 情绪数据仓储实例 """ - + def __init__(self): """初始化情绪分析服务""" connector = Neo4jConnector() self.emotion_repo = EmotionRepository(connector) logger.info("情绪分析服务初始化完成") - + async def get_emotion_tags( - self, - end_user_id: str, 
- emotion_type: Optional[str] = None, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - limit: int = 10 + self, + end_user_id: str, + emotion_type: Optional[str] = None, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + limit: int = 10 ) -> Dict[str, Any]: """获取情绪标签统计 - + 查询指定用户的情绪类型分布,包括计数、百分比和平均强度。 确保返回所有6个情绪维度(joy、sadness、anger、fear、surprise、neutral), 即使某些维度没有数据也会返回count=0的记录。 @@ -71,8 +74,8 @@ class EmotionAnalyticsService: """ try: logger.info(f"获取情绪标签统计: user={end_user_id}, type={emotion_type}, " - f"start={start_date}, end={end_date}, limit={limit}") - + f"start={start_date}, end={end_date}, limit={limit}") + # 调用仓储层查询 tags = await self.emotion_repo.get_emotion_tags( end_user_id=end_user_id, @@ -81,13 +84,13 @@ class EmotionAnalyticsService: end_date=end_date, limit=limit ) - + # 定义所有6个情绪维度 all_emotion_types = ['joy', 'sadness', 'anger', 'fear', 'surprise', 'neutral'] - + # 将查询结果转换为字典,方便查找 tags_dict = {tag["emotion_type"]: tag for tag in tags} - + # 补全缺失的情绪维度 complete_tags = [] for emotion in all_emotion_types: @@ -101,52 +104,52 @@ class EmotionAnalyticsService: "percentage": 0.0, "avg_intensity": 0.0 }) - + # 计算总数 total_count = sum(tag["count"] for tag in complete_tags) - + # 如果有数据,重新计算百分比(因为补全了0值项) if total_count > 0: for tag in complete_tags: if tag["count"] > 0: tag["percentage"] = round((tag["count"] / total_count) * 100, 2) - + # 构建时间范围信息 time_range = {} if start_date: time_range["start_date"] = start_date if end_date: time_range["end_date"] = end_date - + # 格式化响应 response = { "tags": complete_tags, "total_count": total_count, "time_range": time_range if time_range else None } - + logger.info(f"情绪标签统计完成: total_count={total_count}, tags_count={len(complete_tags)}") return response - + except Exception as e: logger.error(f"获取情绪标签统计失败: {str(e)}", exc_info=True) raise - + async def get_emotion_wordcloud( - self, - end_user_id: str, - emotion_type: Optional[str] = None, - limit: int = 50 + self, + end_user_id: 
str, + emotion_type: Optional[str] = None, + limit: int = 50 ) -> Dict[str, Any]: """获取情绪词云数据 - + 查询情绪关键词及其频率,用于生成词云可视化。 - + Args: end_user_id: 宿主ID(用户组ID) emotion_type: 可选的情绪类型过滤 limit: 返回关键词的最大数量 - + Returns: Dict: 包含情绪词云数据的响应: - keywords: 关键词列表 @@ -154,39 +157,39 @@ class EmotionAnalyticsService: """ try: logger.info(f"获取情绪词云数据: user={end_user_id}, type={emotion_type}, limit={limit}") - + # 调用仓储层查询 keywords = await self.emotion_repo.get_emotion_wordcloud( end_user_id=end_user_id, emotion_type=emotion_type, limit=limit ) - + # 计算总关键词数量 total_keywords = len(keywords) - + # 格式化响应 response = { "keywords": keywords, "total_keywords": total_keywords } - + logger.info(f"情绪词云数据获取完成: total_keywords={total_keywords}") return response - + except Exception as e: logger.error(f"获取情绪词云数据失败: {str(e)}", exc_info=True) raise - + def _calculate_positivity_rate(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]: """计算积极率 - + 根据情绪类型分类正面、负面和中性情绪,计算积极率。 公式:(正面数 / (正面数 + 负面数)) * 100 - + Args: emotions: 情绪数据列表,每个包含 emotion_type 字段 - + Returns: Dict: 包含积极率计算结果: - score: 积极率分数(0-100) @@ -197,38 +200,38 @@ class EmotionAnalyticsService: # 定义情绪分类 positive_emotions = {'joy', 'surprise'} negative_emotions = {'sadness', 'anger', 'fear'} - + # 统计各类情绪数量 positive_count = sum(1 for e in emotions if e.get('emotion_type') in positive_emotions) negative_count = sum(1 for e in emotions if e.get('emotion_type') in negative_emotions) neutral_count = sum(1 for e in emotions if e.get('emotion_type') == 'neutral') - + # 计算积极率 total_non_neutral = positive_count + negative_count if total_non_neutral > 0: score = (positive_count / total_non_neutral) * 100 else: score = 50.0 # 如果没有非中性情绪,默认为50 - + logger.debug(f"积极率计算: positive={positive_count}, negative={negative_count}, " - f"neutral={neutral_count}, score={score:.2f}") - + f"neutral={neutral_count}, score={score:.2f}") + return { "score": round(score, 2), "positive_count": positive_count, "negative_count": negative_count, "neutral_count": neutral_count 
} - + def _calculate_stability(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]: """计算稳定性 - + 基于情绪强度的标准差计算情绪稳定性。 公式:(1 - min(std_deviation, 1.0)) * 100 - + Args: emotions: 情绪数据列表,每个包含 emotion_intensity 字段 - + Returns: Dict: 包含稳定性计算结果: - score: 稳定性分数(0-100) @@ -236,7 +239,7 @@ class EmotionAnalyticsService: """ # 提取所有情绪强度 intensities = [e.get('emotion_intensity', 0.0) for e in emotions if e.get('emotion_intensity') is not None] - + # 计算标准差 if len(intensities) >= 2: std_deviation = statistics.stdev(intensities) @@ -244,29 +247,29 @@ class EmotionAnalyticsService: std_deviation = 0.0 # 只有一个数据点,标准差为0 else: std_deviation = 0.0 # 没有数据,标准差为0 - + # 计算稳定性分数 # 标准差越小,稳定性越高 score = (1 - min(std_deviation, 1.0)) * 100 - + logger.debug(f"稳定性计算: intensities_count={len(intensities)}, " - f"std_deviation={std_deviation:.3f}, score={score:.2f}") - + f"std_deviation={std_deviation:.3f}, score={score:.2f}") + return { "score": round(score, 2), "std_deviation": round(std_deviation, 3) } - + def _calculate_resilience(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]: """计算恢复力 - + 分析情绪转换模式,统计从负面情绪恢复到正面情绪的能力。 公式:(负面到正面转换次数 / 总负面情绪数) * 100 - + Args: emotions: 情绪数据列表,每个包含 emotion_type 和 created_at 字段 应该按时间顺序排列 - + Returns: Dict: 包含恢复力计算结果: - score: 恢复力分数(0-100) @@ -275,24 +278,24 @@ class EmotionAnalyticsService: # 定义情绪分类 positive_emotions = {'joy', 'surprise'} negative_emotions = {'sadness', 'anger', 'fear'} - + # 统计负面到正面的转换次数 recovery_count = 0 negative_count = 0 - + for i in range(len(emotions)): current_emotion = emotions[i].get('emotion_type') - + # 统计负面情绪总数 if current_emotion in negative_emotions: negative_count += 1 - + # 检查下一个情绪是否为正面 if i + 1 < len(emotions): next_emotion = emotions[i + 1].get('emotion_type') if next_emotion in positive_emotions: recovery_count += 1 - + # 计算恢复力分数 if negative_count > 0: recovery_rate = recovery_count / negative_count @@ -301,28 +304,28 @@ class EmotionAnalyticsService: # 如果没有负面情绪,恢复力设为100(最佳状态) recovery_rate = 1.0 score = 100.0 - + 
logger.debug(f"恢复力计算: negative_count={negative_count}, " - f"recovery_count={recovery_count}, score={score:.2f}") - + f"recovery_count={recovery_count}, score={score:.2f}") + return { "score": round(score, 2), "recovery_rate": round(recovery_rate, 3) } - + async def calculate_emotion_health_index( - self, - end_user_id: str, - time_range: str = "30d" + self, + end_user_id: str, + time_range: str = "30d" ) -> Dict[str, Any]: """计算情绪健康指数 - + 综合积极率、稳定性和恢复力计算情绪健康指数。 - + Args: end_user_id: 宿主ID(用户组ID) time_range: 时间范围(7d/30d/90d) - + Returns: Dict: 包含情绪健康指数的完整响应: - health_score: 综合健康分数(0-100) @@ -336,13 +339,13 @@ class EmotionAnalyticsService: """ try: logger.info(f"计算情绪健康指数: user={end_user_id}, time_range={time_range}") - + # 获取时间范围内的情绪数据 emotions = await self.emotion_repo.get_emotions_in_range( end_user_id=end_user_id, time_range=time_range ) - + # 如果没有数据,返回默认值 if not emotions: logger.warning(f"用户 {end_user_id} 在时间范围 {time_range} 内没有情绪数据") @@ -357,20 +360,20 @@ class EmotionAnalyticsService: "emotion_distribution": {}, "time_range": time_range } - + # 计算各维度指标 positivity_rate = self._calculate_positivity_rate(emotions) stability = self._calculate_stability(emotions) resilience = self._calculate_resilience(emotions) - + # 计算综合健康分数 # 公式:positivity_rate * 0.4 + stability * 0.3 + resilience * 0.3 health_score = ( - positivity_rate["score"] * 0.4 + - stability["score"] * 0.3 + - resilience["score"] * 0.3 + positivity_rate["score"] * 0.4 + + stability["score"] * 0.3 + + resilience["score"] * 0.3 ) - + # 确定健康等级 if health_score >= 80: level = "优秀" @@ -380,13 +383,13 @@ class EmotionAnalyticsService: level = "一般" else: level = "较差" - + # 统计情绪分布 emotion_distribution = {} for emotion_type in ['joy', 'sadness', 'anger', 'fear', 'surprise', 'neutral']: count = sum(1 for e in emotions if e.get('emotion_type') == emotion_type) emotion_distribution[emotion_type] = count - + # 格式化响应 response = { "health_score": round(health_score, 2), @@ -399,22 +402,22 @@ class 
EmotionAnalyticsService: "emotion_distribution": emotion_distribution, "time_range": time_range } - + logger.info(f"情绪健康指数计算完成: score={health_score:.2f}, level={level}") return response - + except Exception as e: logger.error(f"计算情绪健康指数失败: {str(e)}", exc_info=True) raise - + def _analyze_emotion_patterns(self, emotions: List[Dict[str, Any]]) -> Dict[str, Any]: """分析情绪模式 - + 识别主要负面情绪、情绪触发因素和波动时段。 - + Args: emotions: 情绪数据列表,每个包含 emotion_type、emotion_intensity、created_at 字段 - + Returns: Dict: 包含情绪模式分析结果: - dominant_negative_emotion: 主要负面情绪类型 @@ -422,19 +425,19 @@ class EmotionAnalyticsService: - emotion_volatility: 情绪波动性(高/中/低) """ negative_emotions = {'sadness', 'anger', 'fear'} - + # 统计负面情绪分布 negative_emotion_counts = {} for emotion in emotions: emotion_type = emotion.get('emotion_type') if emotion_type in negative_emotions: negative_emotion_counts[emotion_type] = negative_emotion_counts.get(emotion_type, 0) + 1 - + # 识别主要负面情绪 dominant_negative_emotion = None if negative_emotion_counts: dominant_negative_emotion = max(negative_emotion_counts, key=negative_emotion_counts.get) - + # 识别高强度情绪(强度 >= 0.7) high_intensity_emotions = [ { @@ -445,7 +448,7 @@ class EmotionAnalyticsService: for e in emotions if e.get('emotion_intensity', 0) >= 0.7 ] - + # 评估情绪波动性 intensities = [e.get('emotion_intensity', 0.0) for e in emotions if e.get('emotion_intensity') is not None] if len(intensities) >= 2: @@ -458,29 +461,29 @@ class EmotionAnalyticsService: volatility = "低" else: volatility = "未知" - + logger.debug(f"情绪模式分析: dominant_negative={dominant_negative_emotion}, " - f"high_intensity_count={len(high_intensity_emotions)}, volatility={volatility}") - + f"high_intensity_count={len(high_intensity_emotions)}, volatility={volatility}") + return { "dominant_negative_emotion": dominant_negative_emotion, "high_intensity_emotions": high_intensity_emotions[:5], # 最多返回5个 "emotion_volatility": volatility } - + async def generate_emotion_suggestions( - self, - end_user_id: str, - db: Session, + 
self, + end_user_id: str, + db: Session, ) -> Dict[str, Any]: """生成个性化情绪建议 - + 基于情绪健康数据和用户画像生成个性化建议。 - + Args: end_user_id: 宿主ID(用户组ID) db: 数据库会话 - + Returns: Dict: 包含个性化建议的响应: - health_summary: 健康状态摘要 @@ -488,17 +491,17 @@ class EmotionAnalyticsService: """ try: logger.info(f"生成个性化情绪建议: user={end_user_id}") - + # 1. 从 end_user_id 获取关联的 memory_config_id llm_client = None try: from app.services.memory_agent_service import ( get_end_user_connected_config, ) - + connected_config = get_end_user_connected_config(end_user_id, db) config_id = connected_config.get("memory_config_id") - + config_id = resolve_config_id(config_id, db) if config_id is not None: from app.services.memory_config_service import ( MemoryConfigService, @@ -513,35 +516,35 @@ class EmotionAnalyticsService: llm_client = factory.get_llm_client(str(memory_config.llm_model_id)) except Exception as e: logger.warning(f"无法获取 end_user {end_user_id} 的配置,将使用默认配置: {e}") - + # 2. 获取情绪健康数据 health_data = await self.calculate_emotion_health_index(end_user_id, time_range="30d") - + # 3. 获取情绪数据用于模式分析 emotions = await self.emotion_repo.get_emotions_in_range( end_user_id=end_user_id, time_range="30d" ) - + # 4. 分析情绪模式 patterns = self._analyze_emotion_patterns(emotions) - + # 5. 获取用户画像数据(简化版,直接从Neo4j获取) user_profile = await self._get_simple_user_profile(end_user_id) - + # 6. 构建LLM prompt prompt = await self._build_suggestion_prompt(health_data, patterns, user_profile) - + # 7. 调用LLM生成建议(使用配置中的LLM) if llm_client is None: # 无法获取配置时,抛出错误而不是使用默认配置 raise ValueError("无法获取LLM配置,请确保end_user关联了有效的memory_config") - + # 将 prompt 转换为 messages 格式 messages = [ {"role": "user", "content": prompt} ] - + # 8. 使用结构化输出直接获取 Pydantic 模型 try: suggestions_response = await llm_client.response_structured( @@ -552,7 +555,7 @@ class EmotionAnalyticsService: logger.error(f"LLM 结构化输出失败: {str(e)}") # 返回默认建议 suggestions_response = self._get_default_suggestions(health_data) - + # 8. 
验证建议数量(3-5条) if len(suggestions_response.suggestions) < 3: logger.warning(f"建议数量不足: {len(suggestions_response.suggestions)}") @@ -560,7 +563,7 @@ class EmotionAnalyticsService: elif len(suggestions_response.suggestions) > 5: logger.warning(f"建议数量过多: {len(suggestions_response.suggestions)}") suggestions_response.suggestions = suggestions_response.suggestions[:5] - + # 9. 格式化响应 response = { "health_summary": suggestions_response.health_summary, @@ -575,26 +578,26 @@ class EmotionAnalyticsService: for s in suggestions_response.suggestions ] } - + logger.info(f"个性化建议生成完成: suggestions_count={len(response['suggestions'])}") return response - + except Exception as e: logger.error(f"生成个性化建议失败: {str(e)}", exc_info=True) raise - + async def _get_simple_user_profile(self, end_user_id: str) -> Dict[str, Any]: """获取简化的用户画像数据 - + Args: end_user_id: 用户ID - + Returns: Dict: 用户画像数据 """ try: connector = Neo4jConnector() - + # 查询用户的实体和标签 query = """ MATCH (e:Entity) @@ -603,59 +606,59 @@ class EmotionAnalyticsService: ORDER BY e.created_at DESC LIMIT 20 """ - + entities = await connector.execute_query(query, end_user_id=end_user_id) - + # 提取兴趣标签 interests = [e["name"] for e in entities if e.get("type") in ["INTEREST", "HOBBY"]][:5] # 后期会引入用户的习惯。。 return { "interests": interests if interests else ["未知"] } - + except Exception as e: logger.error(f"获取用户画像失败: {str(e)}") return {"interests": ["未知"]} - + async def _build_suggestion_prompt( - self, - health_data: Dict[str, Any], - patterns: Dict[str, Any], - user_profile: Dict[str, Any] + self, + health_data: Dict[str, Any], + patterns: Dict[str, Any], + user_profile: Dict[str, Any] ) -> str: """构建情绪建议生成的prompt - + Args: health_data: 情绪健康数据 patterns: 情绪模式分析结果 user_profile: 用户画像数据 - + Returns: str: LLM prompt """ from app.core.memory.utils.prompt.prompt_utils import ( render_emotion_suggestions_prompt, ) - + prompt = await render_emotion_suggestions_prompt( health_data=health_data, patterns=patterns, user_profile=user_profile ) - + return 
prompt - + def _get_default_suggestions(self, health_data: Dict[str, Any]) -> EmotionSuggestionsResponse: """获取默认建议(当LLM调用失败时使用) - + Args: health_data: 情绪健康数据 - + Returns: EmotionSuggestionsResponse: 默认建议 """ health_score = health_data.get('health_score', 0) - + if health_score >= 80: summary = "您的情绪健康状况优秀,请继续保持积极的生活态度。" elif health_score >= 60: @@ -664,7 +667,7 @@ class EmotionAnalyticsService: summary = "您的情绪健康需要关注,建议采取一些改善措施。" else: summary = "您的情绪健康需要重点关注,建议寻求专业帮助。" - + suggestions = [ EmotionSuggestion( type="emotion_balance", @@ -700,54 +703,54 @@ class EmotionAnalyticsService: ] ) ] - + return EmotionSuggestionsResponse( health_summary=summary, suggestions=suggestions ) - + async def get_cached_suggestions( - self, - end_user_id: str, - db: Session, + self, + end_user_id: str, + db: Session, ) -> Optional[Dict[str, Any]]: """从 Redis 缓存获取个性化情绪建议 - + Args: end_user_id: 宿主ID(用户组ID) db: 数据库会话(保留参数以保持接口兼容性) - + Returns: Dict: 缓存的建议数据,如果不存在或已过期返回 None """ try: from app.cache.memory.emotion_memory import EmotionMemoryCache - + logger.info(f"尝试从 Redis 缓存获取情绪建议: user={end_user_id}") - + # 从 Redis 获取缓存 cached_data = await EmotionMemoryCache.get_emotion_suggestions(end_user_id) - + if cached_data is None: logger.info(f"用户 {end_user_id} 的建议缓存不存在或已过期") return None - + logger.info(f"成功从 Redis 缓存获取建议: user={end_user_id}") return cached_data - + except Exception as e: logger.error(f"从 Redis 缓存获取建议失败: {str(e)}", exc_info=True) return None - + async def save_suggestions_cache( - self, - end_user_id: str, - suggestions_data: Dict[str, Any], - db: Session, - expires_hours: int = 24 + self, + end_user_id: str, + suggestions_data: Dict[str, Any], + db: Session, + expires_hours: int = 24 ) -> None: """保存建议到 Redis 缓存 - + Args: end_user_id: 宿主ID(用户组ID) suggestions_data: 建议数据 @@ -756,24 +759,24 @@ class EmotionAnalyticsService: """ try: from app.cache.memory.emotion_memory import EmotionMemoryCache - + logger.info(f"保存建议到 Redis 缓存: user={end_user_id}, expires={expires_hours}小时") - + 
# 计算过期时间(秒) expire_seconds = expires_hours * 3600 - + # 保存到 Redis success = await EmotionMemoryCache.set_emotion_suggestions( user_id=end_user_id, suggestions_data=suggestions_data, expire=expire_seconds ) - + if success: logger.info(f"建议缓存保存成功: user={end_user_id}") else: logger.warning(f"建议缓存保存失败: user={end_user_id}") - + except Exception as e: logger.error(f"保存建议缓存失败: {str(e)}", exc_info=True) # 不抛出异常,缓存失败不应影响主流程 \ No newline at end of file diff --git a/api/app/services/handoffs_service.py b/api/app/services/handoffs_service.py index 114e9945..10e4d646 100644 --- a/api/app/services/handoffs_service.py +++ b/api/app/services/handoffs_service.py @@ -4,7 +4,7 @@ import uuid from typing import List, Dict, Any, Optional, AsyncGenerator, Annotated from typing_extensions import TypedDict -from langchain_core.messages import HumanMessage, AIMessage, BaseMessage +from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, AIMessageChunk from langgraph.graph import StateGraph, START, END from langgraph.types import Command from langgraph.checkpoint.memory import MemorySaver @@ -727,9 +727,12 @@ class HandoffsService: # 提取响应 response_content = "" + total_tokens = 0 for msg in result.get("messages", []): if isinstance(msg, AIMessage): response_content = msg.content + response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None + total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0 break return { @@ -737,7 +740,12 @@ class HandoffsService: "active_agent": result.get("active_agent"), "response": response_content, "message_count": len(result.get("messages", [])), - "handoff_count": result.get("handoff_count", 0) + "handoff_count": result.get("handoff_count", 0), + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": total_tokens + } } async def chat_stream( @@ -830,6 +838,12 @@ class HandoffsService: # 捕获 LLM 结束事件,输出收集到的工具调用 elif kind == "on_chat_model_end": + 
output_message = event.get("data", {}).get("output", {}) + if isinstance(output_message, AIMessageChunk): + response_meta = output_message.response_metadata if hasattr(output_message, 'response_metadata') else None + total_tokens = response_meta.get("token_usage", {}).get("total_tokens", + 0) if response_meta else 0 + yield f"event: sub_usage\ndata: {json.dumps({"total_tokens": total_tokens}, ensure_ascii=False)}\n\n" if collected_tool_calls: # 找到参数最完整的 transfer 工具调用 best_tc = None diff --git a/api/app/services/memory_reflection_service.py b/api/app/services/memory_reflection_service.py index b92a5d06..e025c1b3 100644 --- a/api/app/services/memory_reflection_service.py +++ b/api/app/services/memory_reflection_service.py @@ -89,7 +89,6 @@ class WorkspaceAppService: for release in app_releases: memory_content = self._extract_memory_content(release.config) - memory_content=resolve_config_id(memory_content, self.db) if memory_content and memory_content in processed_configs: continue @@ -122,16 +121,12 @@ class WorkspaceAppService: def _get_memory_config(self, memory_content: str) -> Dict[str, Any]: """Retrieve memory_config information based on memory_content""" try: - memory_config_result = MemoryConfigRepository.query_reflection_config_by_id(self.db, int(memory_content)) - - # memory_config_query, memory_config_params = MemoryConfigRepository.build_select_reflection(memory_content) - # memory_config_result = self.db.execute(text(memory_config_query), memory_config_params).fetchone() - # if memory_config_result is None: - # return None + memory_content = resolve_config_id(memory_content, self.db) + memory_config_result = MemoryConfigRepository.query_reflection_config_by_id(self.db, (memory_content)) if memory_config_result: return { - "config_id": memory_config_result.config_id, + "config_id": memory_content, "enable_self_reflexion": memory_config_result.enable_self_reflexion, "iteration_period": memory_config_result.iteration_period, "reflexion_range": 
memory_config_result.reflexion_range, @@ -291,7 +286,7 @@ class MemoryReflectionService: # 检查是否需要执行反思 should_execute = False hours_diff = 0 - + if current_reflection_time is None: # 首次执行反思 should_execute = True @@ -303,11 +298,11 @@ class MemoryReflectionService: reflection_time = datetime.fromisoformat(current_reflection_time) else: reflection_time = current_reflection_time - + current_time = datetime.now() time_diff = current_time - reflection_time hours_diff = int(time_diff.total_seconds() / 3600) - + # 检查是否达到反思周期 if hours_diff >= iteration_period: should_execute = True @@ -317,7 +312,7 @@ class MemoryReflectionService: except (ValueError, TypeError) as e: api_logger.warning(f"解析反思时间失败: {e},将执行反思") should_execute = True - + if should_execute: api_logger.info(f"与上次的反思时间间隔为: {hours_diff} 小时") # 3. 执行反思引擎 @@ -350,7 +345,7 @@ class MemoryReflectionService: "next_reflection_in_hours": iteration_period - hours_diff } - + except Exception as e: config_id = config_data.get("config_id", "unknown") api_logger.error(f"启动反思失败,config_id: {config_id}, end_user_id: {end_user_id}, 错误: {str(e)}") @@ -361,7 +356,7 @@ class MemoryReflectionService: "end_user_id": end_user_id, "config_data": config_data } - + def _create_reflection_config_from_data(self, config_data: Dict[str, Any]) -> ReflectionConfig: """Create reflective configuration objects from configuration data""" @@ -369,12 +364,12 @@ class MemoryReflectionService: if reflexion_range_value is None or reflexion_range_value == "": reflexion_range_value = "partial" reflexion_range = ReflectionRange(reflexion_range_value) - + baseline_value = config_data.get("baseline") if baseline_value is None or baseline_value == "": baseline_value = "TIME" baseline = ReflectionBaseline(baseline_value) - + # iteration_period = iteration_period = config_data.get("iteration_period", 24) if isinstance(iteration_period, str): @@ -382,7 +377,6 @@ class MemoryReflectionService: iteration_period = int(iteration_period) except (ValueError, 
TypeError): iteration_period = 24 # 默认24小时 - return ReflectionConfig( enabled=config_data.get("enable_self_reflexion", False), iteration_period=str(iteration_period), # ReflectionConfig期望字符串 diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index 904821c1..dee6cd1d 100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -508,10 +508,7 @@ class ModelApiKeyService: ) if not validation_result["valid"]: # 记录验证失败的模型,但不抛出异常 - failed_models.append({ - "model_name": model_name, - "error": validation_result["error"] - }) + failed_models.append(model_name) continue # 创建API Key @@ -692,6 +689,9 @@ class ModelBaseService: @staticmethod def create_model_base(db: Session, data: model_schema.ModelBaseCreate): + existing = ModelBaseRepository.get_by_name_and_provider(db, data.name, data.provider) + if existing: + raise BusinessException("模型已存在", BizCode.DUPLICATE_NAME) model_base = ModelBaseRepository.create(db, data.model_dump()) db.commit() db.refresh(model_base) diff --git a/api/app/services/multi_agent_orchestrator.py b/api/app/services/multi_agent_orchestrator.py index d9062eaf..b28bafbf 100644 --- a/api/app/services/multi_agent_orchestrator.py +++ b/api/app/services/multi_agent_orchestrator.py @@ -280,14 +280,22 @@ class MultiAgentOrchestrator: # 4. 
提取子 Agent 的 conversation_id(用于多轮对话) sub_conversation_id = None + total_tokens = 0 + if isinstance(results, dict): sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id") + # 提取 token 信息 + usage = results.get("usage", {}) or results.get("result", {}).get("usage", {}) + total_tokens += usage.get("total_tokens", 0) elif isinstance(results, list) and results: for item in results: if "result" in item: sub_conversation_id = item["result"].get("conversation_id") if sub_conversation_id: break + # 累加每个子 Agent 的 token + usage = item.get("usage", {}) or item.get("result", {}).get("usage", {}) + total_tokens += usage.get("total_tokens", 0) logger.info( "多 Agent 任务完成", @@ -301,9 +309,15 @@ class MultiAgentOrchestrator: return { "message": final_result, "conversation_id": sub_conversation_id, + "mode": OrchestrationMode.SUPERVISOR, "elapsed_time": elapsed_time, "strategy": routing_decision.get("collaboration_strategy", "single"), - "sub_results": results + "sub_results": results, + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": total_tokens + } } except Exception as e: @@ -1552,10 +1566,12 @@ class MultiAgentOrchestrator: return { "message": result.get("response", ""), "conversation_id": result.get("conversation_id"), + "mode": OrchestrationMode.COLLABORATION, "elapsed_time": elapsed_time, "strategy": "collaboration", "active_agent": result.get("active_agent"), - "sub_results": result + "sub_results": result, + "usage": result.get("usage") } except Exception as e: diff --git a/api/app/services/multi_agent_service.py b/api/app/services/multi_agent_service.py index da984d16..c52814ed 100644 --- a/api/app/services/multi_agent_service.py +++ b/api/app/services/multi_agent_service.py @@ -1,5 +1,6 @@ """多 Agent 配置管理服务""" import uuid +import json from typing import Optional, List, Tuple, Any, Annotated from fastapi import Depends @@ -427,6 +428,23 @@ class MultiAgentService: memory=getattr(request, 'memory', 
True) # 记忆功能参数 ) + await self._save_conversation_message( + conversation_id=request.conversation_id, + user_message=request.message, + assistant_message=result.get("message", ""), + app_id=app_id, + user_id=request.user_id, + meta_data={ + "mode": result.get("mode"), + "elapsed_time": result.get("elapsed_time"), + "usage": result.get("usage", { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + }) + } + ) + return result async def run_stream( @@ -451,11 +469,14 @@ class MultiAgentService: raise ResourceNotFoundException("多 Agent 配置", str(app_id)) if not config.is_active: - raise BusinessException("多 Agent 配置已禁用", BizCode.RESOURCE_DISABLED) + raise BusinessException("多 Agent 配置已禁用", BizCode.NOT_FOUND) # 2. 创建编排器 orchestrator = MultiAgentOrchestrator(self.db, config) + full_content = "" + total_tokens = 0 + # 3. 流式执行任务 async for event in orchestrator.execute_stream( message=request.message, @@ -468,7 +489,88 @@ class MultiAgentService: storage_type=storage_type, user_rag_memory_id=user_rag_memory_id ): - yield event + if "sub_usage" in event: + if "data:" in event: + try: + data_line = event.split("data: ", 1)[1].strip() + data = json.loads(data_line) + if "total_tokens" in data: + total_tokens += data["total_tokens"] + except: + pass + else: + yield event + if "data:" in event: + try: + data_line = event.split("data: ", 1)[1].strip() + data = json.loads(data_line) + if "content" in data: + full_content += data["content"] + except: + pass + + await self._save_conversation_message( + conversation_id=request.conversation_id, + user_message=request.message, + assistant_message=full_content, + app_id=app_id, + user_id=request.user_id, + meta_data={ + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": total_tokens + } + } + ) + + async def _save_conversation_message( + self, + conversation_id: uuid.UUID, + user_message: str, + assistant_message: str, + meta_data: dict, + app_id: Optional[uuid.UUID] = None, + user_id: 
Optional[str] = None + ) -> None: + """保存会话消息 + + Args: + conversation_id: 会话ID + user_message: 用户消息 + assistant_message: AI 回复消息 + meta_data: 元数据(包括 token 消耗) + app_id: 应用ID + user_id: 用户ID + """ + try: + from app.services.conversation_service import ConversationService + + conversation_service = ConversationService(self.db) + + conversation_service.add_message( + conversation_id=conversation_id, + role="user", + content=user_message + ) + conversation_service.add_message( + conversation_id=conversation_id, + role="assistant", + content=assistant_message, + meta_data=meta_data + ) + + logger.debug( + "保存多 Agent 会话消息", + extra={ + "conversation_id": conversation_id, + "user_message_length": len(user_message), + "assistant_message_length": len(assistant_message) + } + ) + + except Exception as e: + logger.warning("保存会话消息失败", extra={"error": str(e)}) # def add_sub_agent( # self, diff --git a/api/app/services/ontology_service.py b/api/app/services/ontology_service.py new file mode 100644 index 00000000..c832b0cc --- /dev/null +++ b/api/app/services/ontology_service.py @@ -0,0 +1,1162 @@ +"""本体提取服务层 + +本模块提供本体提取的业务逻辑封装,协调OntologyExtractor和OWLValidator。 +包括本体提取、OWL文件导出等功能。 + +Classes: + OntologyService: 本体提取服务类,封装业务逻辑 +""" + +import logging +import time +from typing import Any, Dict, List, Optional + +from sqlalchemy.orm import Session + +from app.core.memory.llm_tools.openai_client import OpenAIClient +from app.core.memory.models.ontology_models import ( + OntologyClass, + OntologyExtractionResponse, +) +from app.core.memory.storage_services.extraction_engine.knowledge_extraction.ontology_extraction import ( + OntologyExtractor, +) +from app.core.memory.utils.validation.owl_validator import OWLValidator + + +logger = logging.getLogger(__name__) + + +class OntologyService: + """本体提取服务层 + + 封装本体提取的业务逻辑,协调各个组件: + - OntologyExtractor: 执行LLM驱动的本体提取 + - OWLValidator: OWL语义验证 + + Attributes: + extractor: 本体提取器实例 + owl_validator: OWL验证器实例 + db: 数据库会话 + """ + + # 默认配置参数 + 
DEFAULT_MAX_CLASSES = 15 + DEFAULT_MIN_CLASSES = 5 + DEFAULT_MAX_DESCRIPTION_LENGTH = 500 + DEFAULT_LLM_TEMPERATURE = 0.3 + DEFAULT_LLM_MAX_TOKENS = 2000 + DEFAULT_LLM_TIMEOUT = 30.0 + DEFAULT_ENABLE_OWL_VALIDATION = True + + def __init__( + self, + llm_client: OpenAIClient, + db: Session + ): + """初始化本体提取服务 + + Args: + llm_client: OpenAI客户端实例 + db: SQLAlchemy数据库会话 + """ + self.extractor = OntologyExtractor(llm_client) + self.owl_validator = OWLValidator() + self.db = db + + # 初始化Repository + from app.repositories.ontology_scene_repository import OntologySceneRepository + from app.repositories.ontology_class_repository import OntologyClassRepository + + self.scene_repo = OntologySceneRepository(db) + self.class_repo = OntologyClassRepository(db) + + logger.info("OntologyService initialized") + + async def extract_ontology( + self, + scenario: str, + domain: Optional[str] = None, + scene_id: Optional[Any] = None, + workspace_id: Optional[Any] = None + ) -> OntologyExtractionResponse: + """执行本体提取 + + 使用默认配置参数调用OntologyExtractor执行提取。 + 提取结果仅返回给前端,不会自动保存到数据库。 + 前端需要调用 /class 接口来保存选中的类型。 + + Args: + scenario: 场景描述文本 + domain: 可选的领域提示 + scene_id: 可选的场景ID,用于权限验证(不再用于自动保存) + workspace_id: 可选的工作空间ID,用于权限验证 + + Returns: + OntologyExtractionResponse: 提取结果 + + Raises: + ValueError: 场景描述为空、场景不存在或无权限 + RuntimeError: 提取过程失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> response = await service.extract_ontology( + ... scenario="医院管理患者记录...", + ... domain="Healthcare", + ... scene_id=scene_uuid, + ... workspace_id=workspace_uuid + ... 
) + >>> len(response.classes) + 7 + """ + # 开始计时 + start_time = time.time() + + # 验证输入 + if not scenario or not scenario.strip(): + logger.error("Scenario description is empty") + raise ValueError("Scenario description cannot be empty") + + # 如果提供了scene_id,验证场景是否存在且有权限 + if scene_id and workspace_id: + logger.info(f"Validating scene access - scene_id={scene_id}, workspace_id={workspace_id}") + scene = self.scene_repo.get_by_id(scene_id) + if not scene: + logger.warning(f"Scene not found: {scene_id}") + raise ValueError("场景不存在") + + if not self.scene_repo.check_ownership(scene_id, workspace_id): + logger.warning( + f"Permission denied - scene_id={scene_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限在该场景下创建类型") + + logger.info( + f"Starting ontology extraction service - " + f"scenario_length={len(scenario)}, " + f"domain={domain}, " + f"scene_id={scene_id}" + ) + + try: + # 调用提取器执行提取(使用默认配置) + logger.info("Calling OntologyExtractor with default config") + extraction_start_time = time.time() + + response = await self.extractor.extract_ontology_classes( + scenario=scenario, + domain=domain, + max_classes=self.DEFAULT_MAX_CLASSES, + min_classes=self.DEFAULT_MIN_CLASSES, + enable_owl_validation=self.DEFAULT_ENABLE_OWL_VALIDATION, + llm_temperature=self.DEFAULT_LLM_TEMPERATURE, + llm_max_tokens=self.DEFAULT_LLM_MAX_TOKENS, + max_description_length=self.DEFAULT_MAX_DESCRIPTION_LENGTH, + timeout=self.DEFAULT_LLM_TIMEOUT, + ) + + extraction_duration = time.time() - extraction_start_time + + # 检查是否成功提取到类 + if not response.classes: + logger.error("Ontology extraction failed: No classes extracted (structured output may have failed)") + raise RuntimeError("本体提取失败:结构化输出失败,未能提取到任何本体类") + + # 注释:提取结果仅返回给前端,不保存到数据库 + # 前端将从返回结果中选择需要的类型,然后调用 /class 接口创建 + logger.info( + f"Extraction completed. 
Classes will be saved to ontology_class " + f"via /class endpoint based on user selection" + ) + + total_duration = time.time() - start_time + + # 记录提取统计 + logger.info( + f"Ontology extraction service completed - " + f"extracted_classes={len(response.classes)}, " + f"domain={response.domain}, " + f"extraction_duration={extraction_duration:.2f}s, " + f"total_duration={total_duration:.2f}s" + ) + + return response + + except ValueError: + # 重新抛出验证错误 + total_duration = time.time() - start_time + logger.error( + f"Validation error after {total_duration:.2f}s", + exc_info=True + ) + raise + except Exception as e: + total_duration = time.time() - start_time + error_msg = f"Ontology extraction failed after {total_duration:.2f}s: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + async def export_owl_file( + self, + classes: List[OntologyClass], + output_path: str, + format: str = "rdfxml", + ) -> str: + """导出OWL文件 + + 将提取的本体类导出为OWL文件,支持多种格式。 + + Args: + classes: 本体类列表 + output_path: 输出文件路径 + format: 导出格式,可选值: "rdfxml", "turtle", "ntriples" (默认: "rdfxml") + + Returns: + str: 导出的OWL文件内容 + + Raises: + ValueError: 类列表为空或格式不支持 + RuntimeError: 导出失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> owl_content = await service.export_owl_file( + ... classes=response.classes, + ... output_path="ontology.owl", + ... format="rdfxml" + ... ) + """ + # 验证输入 + if not classes: + logger.error("Classes list is empty") + raise ValueError("Classes list cannot be empty") + + valid_formats = ["rdfxml", "turtle", "ntriples"] + if format not in valid_formats: + error_msg = f"Unsupported format '{format}'. 
Must be one of: {', '.join(valid_formats)}" + logger.error(error_msg) + raise ValueError(error_msg) + + logger.info( + f"Starting OWL export - " + f"classes_count={len(classes)}, " + f"output_path={output_path}, " + f"format={format}" + ) + + try: + # 步骤1: 验证本体类 + logger.debug("Validating ontology classes") + is_valid, errors, world = self.owl_validator.validate_ontology_classes( + classes=classes, + ) + + if not is_valid: + logger.warning( + f"OWL validation found {len(errors)} issues during export: {errors}" + ) + # 继续导出,但记录警告 + + if not world: + error_msg = "Failed to create OWL world for export" + logger.error(error_msg) + raise RuntimeError(error_msg) + + # 步骤2: 导出OWL文件 + logger.info(f"Exporting to {format} format") + owl_content = self.owl_validator.export_to_owl( + world=world, + output_path=output_path, + format=format + ) + + logger.info( + f"OWL export completed - " + f"output_path={output_path}, " + f"content_length={len(owl_content)}" + ) + + return owl_content + + except Exception as e: + error_msg = f"OWL export failed: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + + # ==================== 本体场景管理方法 ==================== + + def create_scene( + self, + scene_name: str, + scene_description: Optional[str], + workspace_id: Any + ): + """创建本体场景 + + Args: + scene_name: 场景名称 + scene_description: 场景描述 + workspace_id: 所属工作空间ID + + Returns: + OntologyScene: 创建的场景对象 + + Raises: + ValueError: 场景名称为空 + RuntimeError: 创建失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> scene = service.create_scene( + ... "医疗场景", + ... "用于医疗领域的本体建模", + ... workspace_id + ... 
) + """ + # 验证输入 + if not scene_name or not scene_name.strip(): + logger.error("Scene name is empty") + raise ValueError("场景名称不能为空") + + logger.info( + f"Creating scene - " + f"name={scene_name}, workspace_id={workspace_id}" + ) + + try: + scene_data = { + "scene_name": scene_name.strip(), + "scene_description": scene_description + } + + scene = self.scene_repo.create(scene_data, workspace_id) + self.db.commit() + + logger.info(f"Scene created successfully: {scene.scene_id}") + + return scene + + except ValueError: + raise + except Exception as e: + self.db.rollback() + error_msg = f"Failed to create scene: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def update_scene( + self, + scene_id: Any, + scene_name: Optional[str], + scene_description: Optional[str], + workspace_id: Any + ): + """更新本体场景 + + Args: + scene_id: 场景ID + scene_name: 场景名称(可选) + scene_description: 场景描述(可选) + workspace_id: 工作空间ID(用于权限验证) + + Returns: + OntologyScene: 更新后的场景对象 + + Raises: + ValueError: 场景不存在或无权限 + RuntimeError: 更新失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> scene = service.update_scene( + ... scene_id, + ... "新名称", + ... "新描述", + ... workspace_id + ... 
) + """ + logger.info(f"Updating scene: {scene_id}") + + try: + # 检查场景是否存在 + scene = self.scene_repo.get_by_id(scene_id) + if not scene: + logger.warning(f"Scene not found: {scene_id}") + raise ValueError("场景不存在") + + # 检查权限 + if not self.scene_repo.check_ownership(scene_id, workspace_id): + logger.warning( + f"Permission denied - scene_id={scene_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限操作该场景") + + # 准备更新数据 + update_data = {} + if scene_name is not None: + if not scene_name.strip(): + raise ValueError("场景名称不能为空") + update_data["scene_name"] = scene_name.strip() + + if scene_description is not None: + update_data["scene_description"] = scene_description + + # 如果没有更新数据,直接返回 + if not update_data: + logger.info("No update data provided, returning existing scene") + return scene + + # 执行更新 + updated_scene = self.scene_repo.update(scene_id, update_data) + self.db.commit() + + logger.info(f"Scene updated successfully: {scene_id}") + + return updated_scene + + except ValueError: + raise + except Exception as e: + self.db.rollback() + error_msg = f"Failed to update scene: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def delete_scene( + self, + scene_id: Any, + workspace_id: Any + ) -> bool: + """删除本体场景 + + Args: + scene_id: 场景ID + workspace_id: 工作空间ID(用于权限验证) + + Returns: + bool: 删除成功返回True + + Raises: + ValueError: 场景不存在或无权限 + RuntimeError: 删除失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> success = service.delete_scene(scene_id, workspace_id) + """ + logger.info(f"Deleting scene: {scene_id}") + + try: + # 检查场景是否存在 + scene = self.scene_repo.get_by_id(scene_id) + if not scene: + logger.warning(f"Scene not found: {scene_id}") + raise ValueError("场景不存在") + + # 检查权限 + if not self.scene_repo.check_ownership(scene_id, workspace_id): + logger.warning( + f"Permission denied - scene_id={scene_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限操作该场景") + + # 执行删除 + 
success = self.scene_repo.delete(scene_id) + self.db.commit() + + logger.info(f"Scene deleted successfully: {scene_id}") + + return success + + except ValueError: + raise + except Exception as e: + self.db.rollback() + error_msg = f"Failed to delete scene: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def get_scene_by_id( + self, + scene_id: Any, + workspace_id: Any + ): + """获取单个场景 + + Args: + scene_id: 场景ID + workspace_id: 工作空间ID(用于权限验证) + + Returns: + Optional[OntologyScene]: 场景对象 + + Raises: + ValueError: 场景不存在或无权限 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> scene = service.get_scene_by_id(scene_id, workspace_id) + """ + logger.debug(f"Getting scene by ID: {scene_id}") + + try: + # 获取场景 + scene = self.scene_repo.get_by_id(scene_id) + if not scene: + logger.warning(f"Scene not found: {scene_id}") + raise ValueError("场景不存在") + + # 检查权限 + if not self.scene_repo.check_ownership(scene_id, workspace_id): + logger.warning( + f"Permission denied - scene_id={scene_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限访问该场景") + + return scene + + except ValueError: + raise + except Exception as e: + error_msg = f"Failed to get scene: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def get_scene_by_name( + self, + scene_name: str, + workspace_id: Any + ): + """根据场景名称获取场景(精确匹配) + + Args: + scene_name: 场景名称 + workspace_id: 工作空间ID + + Returns: + Optional[OntologyScene]: 场景对象 + + Raises: + ValueError: 场景不存在 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> scene = service.get_scene_by_name("医疗场景", workspace_id) + """ + logger.debug(f"Getting scene by name: {scene_name}, workspace_id: {workspace_id}") + + try: + # 获取场景 + scene = self.scene_repo.get_by_name(scene_name, workspace_id) + if not scene: + logger.warning(f"Scene not found: {scene_name} in workspace {workspace_id}") + raise ValueError("场景不存在") + + return scene + + except 
ValueError: + raise + except Exception as e: + error_msg = f"Failed to get scene by name: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def search_scenes_by_name( + self, + keyword: str, + workspace_id: Any + ) -> List: + """根据关键词模糊搜索场景 + + Args: + keyword: 搜索关键词 + workspace_id: 工作空间ID + + Returns: + List[OntologyScene]: 匹配的场景列表 + + Raises: + RuntimeError: 搜索失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> scenes = service.search_scenes_by_name("医疗", workspace_id) + """ + logger.debug(f"Searching scenes by keyword: {keyword}, workspace_id: {workspace_id}") + + try: + scenes = self.scene_repo.search_by_name(keyword, workspace_id) + + logger.info( + f"Found {len(scenes)} scenes matching keyword '{keyword}' " + f"in workspace {workspace_id}" + ) + + return scenes + + except Exception as e: + error_msg = f"Failed to search scenes by keyword: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def list_scenes( + self, + workspace_id: Any, + page: Optional[int] = None, + page_size: Optional[int] = None + ) -> tuple: + """获取工作空间下的所有场景(支持分页) + + Args: + workspace_id: 工作空间ID + page: 页码(可选,从1开始) + page_size: 每页数量(可选) + + Returns: + tuple: (场景列表, 总数量) + + Raises: + RuntimeError: 查询失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> scenes, total = service.list_scenes(workspace_id) + >>> scenes, total = service.list_scenes(workspace_id, page=1, page_size=10) + """ + logger.debug(f"Listing scenes for workspace: {workspace_id}, page={page}, page_size={page_size}") + + try: + scenes, total = self.scene_repo.get_by_workspace(workspace_id, page, page_size) + + logger.info(f"Found {len(scenes)} scenes (total: {total}) in workspace {workspace_id}") + + return scenes, total + + except Exception as e: + error_msg = f"Failed to list scenes: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + # ==================== 
本体类型管理方法 ==================== + + def create_class( + self, + scene_id: Any, + class_name: str, + class_description: Optional[str], + workspace_id: Any + ): + """创建本体类型 + + Args: + scene_id: 所属场景ID + class_name: 类型名称 + class_description: 类型描述 + workspace_id: 工作空间ID(用于权限验证) + + Returns: + OntologyClass: 创建的类型对象 + + Raises: + ValueError: 类型名称为空、场景不存在或无权限 + RuntimeError: 创建失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> ontology_class = service.create_class( + ... scene_id, + ... "患者", + ... "医院患者信息", + ... workspace_id + ... ) + """ + # 验证输入 + if not class_name or not class_name.strip(): + logger.error("Class name is empty") + raise ValueError("类型名称不能为空") + + logger.info( + f"Creating class - " + f"name={class_name}, scene_id={scene_id}" + ) + + try: + # 检查场景是否存在且属于当前工作空间 + scene = self.scene_repo.get_by_id(scene_id) + if not scene: + logger.warning(f"Scene not found: {scene_id}") + raise ValueError("所属场景不存在") + + if not self.scene_repo.check_ownership(scene_id, workspace_id): + logger.warning( + f"Permission denied - scene_id={scene_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限在该场景下创建类型") + + # 创建类型 + class_data = { + "class_name": class_name.strip(), + "class_description": class_description + } + + ontology_class = self.class_repo.create(class_data, scene_id) + self.db.commit() + + logger.info(f"Class created successfully: {ontology_class.class_id}") + + return ontology_class + + except ValueError: + raise + except Exception as e: + self.db.rollback() + error_msg = f"Failed to create class: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def create_classes_batch( + self, + scene_id: Any, + classes: List[Dict[str, Optional[str]]], + workspace_id: Any + ): + """批量创建本体类型 + + Args: + scene_id: 所属场景ID + classes: 类型列表,每个元素包含 class_name 和 class_description + workspace_id: 工作空间ID(用于权限验证) + + Returns: + Tuple[List, List[str]]: (成功创建的类型列表, 错误信息列表) + + Raises: + ValueError: 场景不存在或无权限 + 
+ Examples: + >>> service = OntologyService(llm_client, db) + >>> classes_data = [ + ... {"class_name": "患者", "class_description": "医院患者信息"}, + ... {"class_name": "医生", "class_description": "医院医生信息"} + ... ] + >>> created_classes, errors = service.create_classes_batch( + ... scene_id, + ... classes_data, + ... workspace_id + ... ) + """ + logger.info( + f"Batch creating classes - " + f"count={len(classes)}, scene_id={scene_id}" + ) + + # 检查场景是否存在且属于当前工作空间(只检查一次) + scene = self.scene_repo.get_by_id(scene_id) + if not scene: + logger.warning(f"Scene not found: {scene_id}") + raise ValueError("所属场景不存在") + + if not self.scene_repo.check_ownership(scene_id, workspace_id): + logger.warning( + f"Permission denied - scene_id={scene_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限在该场景下创建类型") + + created_classes = [] + errors = [] + + for idx, class_data in enumerate(classes): + class_name = class_data.get("class_name", "").strip() + class_description = class_data.get("class_description") + + if not class_name: + error_msg = f"第 {idx + 1} 个类型名称为空,已跳过" + logger.warning(error_msg) + errors.append(error_msg) + continue + + try: + # 创建类型(不需要再次检查权限) + create_data = { + "class_name": class_name, + "class_description": class_description + } + + ontology_class = self.class_repo.create(create_data, scene_id) + created_classes.append(ontology_class) + logger.info(f"Class created successfully: {class_name}") + + except Exception as e: + error_msg = f"创建类型 '{class_name}' 失败: {str(e)}" + logger.error(error_msg) + errors.append(error_msg) + + # 统一提交所有成功的创建 + try: + self.db.commit() + logger.info( + f"Batch creation completed - " + f"success={len(created_classes)}, failed={len(errors)}" + ) + except Exception as e: + self.db.rollback() + error_msg = f"批量创建提交失败: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + return created_classes, errors + + def update_class( + self, + class_id: Any, + class_name: Optional[str], + 
class_description: Optional[str], + workspace_id: Any + ): + """更新本体类型 + + Args: + class_id: 类型ID + class_name: 类型名称(可选) + class_description: 类型描述(可选) + workspace_id: 工作空间ID(用于权限验证) + + Returns: + OntologyClass: 更新后的类型对象 + + Raises: + ValueError: 类型不存在或无权限 + RuntimeError: 更新失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> ontology_class = service.update_class( + ... class_id, + ... "新名称", + ... "新描述", + ... workspace_id + ... ) + """ + logger.info(f"Updating class: {class_id}") + + try: + # 检查类型是否存在 + ontology_class = self.class_repo.get_by_id(class_id) + if not ontology_class: + logger.warning(f"Class not found: {class_id}") + raise ValueError("类型不存在") + + # 检查权限(通过场景关联) + if not self.class_repo.check_ownership(class_id, workspace_id): + logger.warning( + f"Permission denied - class_id={class_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限操作该类型") + + # 准备更新数据 + update_data = {} + if class_name is not None: + if not class_name.strip(): + raise ValueError("类型名称不能为空") + update_data["class_name"] = class_name.strip() + + if class_description is not None: + update_data["class_description"] = class_description + + # 如果没有更新数据,直接返回 + if not update_data: + logger.info("No update data provided, returning existing class") + return ontology_class + + # 执行更新 + updated_class = self.class_repo.update(class_id, update_data) + self.db.commit() + + logger.info(f"Class updated successfully: {class_id}") + + return updated_class + + except ValueError: + raise + except Exception as e: + self.db.rollback() + error_msg = f"Failed to update class: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def delete_class( + self, + class_id: Any, + workspace_id: Any + ) -> bool: + """删除本体类型 + + Args: + class_id: 类型ID + workspace_id: 工作空间ID(用于权限验证) + + Returns: + bool: 删除成功返回True + + Raises: + ValueError: 类型不存在或无权限 + RuntimeError: 删除失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> success = 
service.delete_class(class_id, workspace_id) + """ + logger.info(f"Deleting class: {class_id}") + + try: + # 检查类型是否存在 + ontology_class = self.class_repo.get_by_id(class_id) + if not ontology_class: + logger.warning(f"Class not found: {class_id}") + raise ValueError("类型不存在") + + # 检查权限(通过场景关联) + if not self.class_repo.check_ownership(class_id, workspace_id): + logger.warning( + f"Permission denied - class_id={class_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限操作该类型") + + # 执行删除 + success = self.class_repo.delete(class_id) + self.db.commit() + + logger.info(f"Class deleted successfully: {class_id}") + + return success + + except ValueError: + raise + except Exception as e: + self.db.rollback() + error_msg = f"Failed to delete class: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def get_class_by_id( + self, + class_id: Any, + workspace_id: Any + ): + """获取单个类型 + + Args: + class_id: 类型ID + workspace_id: 工作空间ID(用于权限验证) + + Returns: + Optional[OntologyClass]: 类型对象 + + Raises: + ValueError: 类型不存在或无权限 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> ontology_class = service.get_class_by_id(class_id, workspace_id) + """ + logger.debug(f"Getting class by ID: {class_id}") + + try: + # 获取类型 + ontology_class = self.class_repo.get_by_id(class_id) + if not ontology_class: + logger.warning(f"Class not found: {class_id}") + raise ValueError("类型不存在") + + # 检查权限(通过场景关联) + if not self.class_repo.check_ownership(class_id, workspace_id): + logger.warning( + f"Permission denied - class_id={class_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限访问该类型") + + return ontology_class + + except ValueError: + raise + except Exception as e: + error_msg = f"Failed to get class: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def get_class_by_name( + self, + class_name: str, + scene_id: Any, + workspace_id: Any + ): + """根据类型名称获取类型(精确匹配) + + Args: 
+ class_name: 类型名称 + scene_id: 场景ID + workspace_id: 工作空间ID(用于权限验证) + + Returns: + Optional[OntologyClass]: 类型对象 + + Raises: + ValueError: 类型不存在或无权限 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> ontology_class = service.get_class_by_name("患者", scene_id, workspace_id) + """ + logger.debug(f"Getting class by name: {class_name}, scene_id: {scene_id}") + + try: + # 检查场景是否存在且属于当前工作空间 + scene = self.scene_repo.get_by_id(scene_id) + if not scene: + logger.warning(f"Scene not found: {scene_id}") + raise ValueError("场景不存在") + + if not self.scene_repo.check_ownership(scene_id, workspace_id): + logger.warning( + f"Permission denied - scene_id={scene_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限访问该场景") + + # 获取类型 + ontology_class = self.class_repo.get_by_name(class_name, scene_id) + if not ontology_class: + logger.warning(f"Class not found: {class_name} in scene {scene_id}") + raise ValueError("类型不存在") + + return ontology_class + + except ValueError: + raise + except Exception as e: + error_msg = f"Failed to get class by name: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def search_classes_by_name( + self, + keyword: str, + scene_id: Any, + workspace_id: Any + ) -> List: + """根据关键词模糊搜索类型 + + Args: + keyword: 搜索关键词 + scene_id: 场景ID + workspace_id: 工作空间ID(用于权限验证) + + Returns: + List[OntologyClass]: 匹配的类型列表 + + Raises: + ValueError: 场景不存在或无权限 + RuntimeError: 搜索失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> classes = service.search_classes_by_name("患者", scene_id, workspace_id) + """ + logger.debug( + f"Searching classes by keyword: {keyword}, " + f"scene_id: {scene_id}, workspace_id: {workspace_id}" + ) + + try: + # 检查场景是否存在且属于当前工作空间 + scene = self.scene_repo.get_by_id(scene_id) + if not scene: + logger.warning(f"Scene not found: {scene_id}") + raise ValueError("场景不存在") + + if not self.scene_repo.check_ownership(scene_id, workspace_id): + logger.warning( + 
f"Permission denied - scene_id={scene_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限访问该场景") + + # 搜索类型 + classes = self.class_repo.search_by_name(keyword, scene_id) + + logger.info( + f"Found {len(classes)} classes matching keyword '{keyword}' " + f"in scene {scene_id}" + ) + + return classes + + except ValueError: + raise + except Exception as e: + error_msg = f"Failed to search classes by keyword: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e + + def list_classes_by_scene( + self, + scene_id: Any, + workspace_id: Any + ) -> List: + """获取场景下的所有类型 + + Args: + scene_id: 场景ID + workspace_id: 工作空间ID(用于权限验证) + + Returns: + List[OntologyClass]: 类型列表 + + Raises: + ValueError: 场景不存在或无权限 + RuntimeError: 查询失败 + + Examples: + >>> service = OntologyService(llm_client, db) + >>> classes = service.list_classes_by_scene(scene_id, workspace_id) + """ + logger.debug(f"Listing classes for scene: {scene_id}") + + try: + # 检查场景是否存在且属于当前工作空间 + scene = self.scene_repo.get_by_id(scene_id) + if not scene: + logger.warning(f"Scene not found: {scene_id}") + raise ValueError("场景不存在") + + if not self.scene_repo.check_ownership(scene_id, workspace_id): + logger.warning( + f"Permission denied - scene_id={scene_id}, " + f"workspace_id={workspace_id}" + ) + raise ValueError("无权限访问该场景的类型") + + # 获取类型列表 + classes = self.class_repo.get_by_scene(scene_id) + + logger.info(f"Found {len(classes)} classes in scene {scene_id}") + + return classes + + except ValueError: + raise + except Exception as e: + error_msg = f"Failed to list classes: {str(e)}" + logger.error(error_msg, exc_info=True) + raise RuntimeError(error_msg) from e diff --git a/api/app/services/prompt_optimizer_service.py b/api/app/services/prompt_optimizer_service.py index 9e447214..2c0b57ac 100644 --- a/api/app/services/prompt_optimizer_service.py +++ b/api/app/services/prompt_optimizer_service.py @@ -1,3 +1,4 @@ +import os import re import uuid from typing import Any, 
AsyncGenerator @@ -18,7 +19,8 @@ from app.models.prompt_optimizer_model import ( ) from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository from app.repositories.prompt_optimizer_repository import ( - PromptOptimizerSessionRepository + PromptOptimizerSessionRepository, + PromptReleaseRepository ) from app.schemas.prompt_optimizer_schema import OptimizePromptResult @@ -28,6 +30,8 @@ logger = get_business_logger() class PromptOptimizerService: def __init__(self, db: Session): self.db = db + self.optim_repo = PromptOptimizerSessionRepository(self.db) + self.release_repo = PromptReleaseRepository(self.db) def get_model_config( self, @@ -78,10 +82,12 @@ class PromptOptimizerService: Returns: PromptOptimzerSession: The newly created prompt optimization session. """ - session = PromptOptimizerSessionRepository(self.db).create_session( + session = self.optim_repo.create_session( tenant_id=tenant_id, user_id=user_id ) + self.db.commit() + self.db.refresh(session) return session def get_session_message_history( @@ -106,7 +112,7 @@ class PromptOptimizerService: - role (str): The role of the message sender, e.g., 'system', 'user', or 'assistant'. - content (str): The content of the message. 
""" - history = PromptOptimizerSessionRepository(self.db).get_session_history( + history = self.optim_repo.get_session_history( session_id=session_id, user_id=user_id ) @@ -177,11 +183,12 @@ class PromptOptimizerService: base_url=api_config.api_base ), type=ModelType(model_config.type)) try: - with open('app/services/prompt/prompt_optimizer_system.jinja2', 'r', encoding='utf-8') as f: + prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt') + with open(os.path.join(prompt_path, 'prompt_optimizer_system.jinja2'), 'r', encoding='utf-8') as f: opt_system_prompt = f.read() rendered_system_message = Template(opt_system_prompt).render() - with open('app/services/prompt/prompt_optimizer_user.jinja2', 'r', encoding='utf-8') as f: + with open(os.path.join(prompt_path, 'prompt_optimizer_user.jinja2'), 'r', encoding='utf-8') as f: opt_user_prompt = f.read() except FileNotFoundError: raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND) @@ -296,4 +303,165 @@ class PromptOptimizerService: role=role, content=content ) + self.db.commit() + self.db.refresh(message) return message + + def save_prompt( + self, + tenant_id: uuid.UUID, + session_id: uuid.UUID, + title: str, + prompt: str + ) -> dict: + """ + Create and save a new prompt release for a given session. + + Args: + tenant_id (uuid.UUID): The ID of the tenant owning the prompt. + session_id (uuid.UUID): The ID of the session to associate with this prompt. + title (str): The title of the prompt release. + prompt (str): The content of the prompt. + + Returns: + dict: A dictionary containing: + - id (UUID): The unique ID of the created prompt release. + - session_id (UUID): The session ID linked to the release. + - title (str): The title of the prompt. + - prompt (str): The prompt content. + - created_at (int): Timestamp (in milliseconds) of when the prompt was created. + + Raises: + BusinessException: If a prompt release already exists for the given session. 
+ """ + session = self.optim_repo.get_session_by_id(session_id) + if session is None or session.tenant_id != tenant_id: + raise BusinessException( + "Session does not exist or the current user has no access", + BizCode.BAD_REQUEST + ) + + if self.release_repo.get_prompt_by_session_id(session_id): + raise BusinessException( + "A release already exists for the current session", + BizCode.BAD_REQUEST + ) + + prompt_obj = self.release_repo.create_prompt_release( + tenant_id=tenant_id, + title=title, + session_id=session_id, + prompt=prompt + ) + self.db.commit() + self.db.refresh(prompt_obj) + return { + "id": prompt_obj.id, + "session_id": prompt_obj.session_id, + "title": prompt_obj.title, + "prompt": prompt_obj.prompt, + "created_at": int(prompt_obj.created_at.timestamp() * 1000) + } + + def delete_prompt( + self, + tenant_id: uuid.UUID, + prompt_id: uuid.UUID + ) -> None: + """ + Soft delete a prompt release by prompt_id. + + Args: + tenant_id (uuid.UUID): Tenant identifier. + prompt_id (uuid.UUID): Prompt identifier. + + Raises: + BusinessException: If the prompt does not exist or already deleted. + """ + prompt_obj = self.release_repo.get_prompt_by_id(prompt_id) + if not prompt_obj or prompt_obj.is_delete: + raise BusinessException( + "Prompt does not exist or has already been deleted", + BizCode.NOT_FOUND + ) + + if prompt_obj.tenant_id != tenant_id: + raise BusinessException( + "No permission to delete this prompt", + BizCode.FORBIDDEN + ) + + self.release_repo.soft_delete_prompt(prompt_obj) + self.db.commit() + logger.info(f"Prompt soft deleted, prompt_id={prompt_id}, tenant_id={tenant_id}") + + def get_release_list( + self, + tenant_id: uuid.UUID, + page: int, + page_size: int, + filter_keyword: str | None = None + ) -> dict[str, int | list[Any]]: + """ + Get paginated list of prompt releases with optional filter. + + Args: + tenant_id (uuid.UUID): Tenant identifier. + page (int): Page number (starting from 1). + page_size (int): Number of items per page. 
+ filter_keyword (str | None): Optional keyword to filter by title. + + Returns: + dict: Contains total count, pagination info, and list of releases. + """ + offset = (page - 1) * page_size + + # Get total count and releases based on filter + if filter_keyword: + total = self.release_repo.count_prompts_by_keyword(tenant_id, filter_keyword) + releases = self.release_repo.search_prompts_paginated( + tenant_id=tenant_id, + keyword=filter_keyword, + offset=offset, + limit=page_size + ) + else: + total = self.release_repo.count_prompts(tenant_id) + releases = self.release_repo.get_prompts_paginated( + tenant_id=tenant_id, + offset=offset, + limit=page_size + ) + + items = [] + for release in releases: + # Get first user message from session + first_message = self.optim_repo.get_first_user_message( + session_id=release.session_id + ) + + items.append({ + "id": release.id, + "title": release.title, + "prompt": release.prompt, + "created_at": int(release.created_at.timestamp() * 1000), + "first_message": first_message + }) + + log_msg = f"Retrieved {len(items)} prompt releases, page={page}, tenant_id={tenant_id}" + if filter_keyword: + log_msg += f", filter='{filter_keyword}'" + logger.info(log_msg) + + result = { + "page": { + "total": total, + "page": page, + "page_size": page_size, + "hasnext": page * page_size < total + }, + "keyword": filter_keyword, + "items": items + } + + return result diff --git a/api/app/services/shared_chat_service.py b/api/app/services/shared_chat_service.py index 1d012088..a92c2649 100644 --- a/api/app/services/shared_chat_service.py +++ b/api/app/services/shared_chat_service.py @@ -282,7 +282,14 @@ class SharedChatService: self.conversation_service.save_conversation_messages( conversation_id=conversation.id, user_message=message, - assistant_message=result["content"] + assistant_message=result["content"], + meta_data={ + "usage": result.get("usage", { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + }) + } ) # 
self.conversation_service.add_message( # conversation_id=conversation.id, @@ -469,6 +476,7 @@ class SharedChatService: # 流式调用 Agent full_content = "" + total_tokens = 0 async for chunk in agent.chat_stream( message=message, history=history, @@ -479,9 +487,12 @@ class SharedChatService: config_id=config_id, memory_flag=memory_flag ): - full_content += chunk - # 发送消息块事件 - yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n" + if isinstance(chunk, int): + total_tokens = chunk + else: + full_content += chunk + # 发送消息块事件 + yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n" elapsed_time = time.time() - start_time @@ -498,7 +509,7 @@ class SharedChatService: content=full_content, meta_data={ "model": api_key_obj.model_name, - "usage": {} + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens} } ) diff --git a/api/app/tasks.py b/api/app/tasks.py index cdd7945e..48b41e4f 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -774,7 +774,15 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]: } -@celery_app.task(name="app.tasks.regenerate_memory_cache", bind=True) +@celery_app.task( + name="app.tasks.regenerate_memory_cache", + bind=True, + ignore_result=True, + max_retries=0, + acks_late=False, + time_limit=3600, + soft_time_limit=3300, +) def regenerate_memory_cache(self) -> Dict[str, Any]: """定时任务:为所有用户重新生成记忆洞察和用户摘要缓存 @@ -966,7 +974,16 @@ def regenerate_memory_cache(self) -> Dict[str, Any]: } -@celery_app.task(name="app.tasks.workspace_reflection_task", bind=True) + +@celery_app.task( + name="app.tasks.workspace_reflection_task", + bind=True, + ignore_result=True, + max_retries=0, + acks_late=False, + time_limit=300, + soft_time_limit=240, +) def workspace_reflection_task(self) -> Dict[str, Any]: """定时任务:每30秒运行工作空间反思功能 @@ -1111,7 +1128,16 @@ def workspace_reflection_task(self) -> Dict[str, Any]: 
-@celery_app.task(name="app.tasks.run_forgetting_cycle_task", bind=True) + +@celery_app.task( + name="app.tasks.run_forgetting_cycle_task", + bind=True, + ignore_result=True, + max_retries=0, + acks_late=False, + time_limit=7200, + soft_time_limit=7000, +) def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Dict[str, Any]: """定时任务:运行遗忘周期 diff --git a/api/app/utils/config_utils.py b/api/app/utils/config_utils.py index 8863ea78..cc67afd2 100644 --- a/api/app/utils/config_utils.py +++ b/api/app/utils/config_utils.py @@ -7,30 +7,31 @@ from uuid import UUID from sqlalchemy.orm import Session -def resolve_config_id(config_id: UUID | int, db: Session) -> UUID: +def resolve_config_id(config_id: UUID | int|str, db: Session) -> UUID: """ 解析 config_id,如果是整数则通过 config_id_old 查找对应的 UUID - + Args: config_id: 配置ID(UUID 或整数) db: 数据库会话 - + Returns: UUID: 解析后的配置ID - + Raises: ValueError: 当找不到对应的配置时 """ + from app.models.memory_config_model import MemoryConfig if isinstance(config_id, UUID): return config_id if isinstance(config_id, str) and len(config_id)<=6: memory_config = db.query(MemoryConfig).filter( - MemoryConfig.config_id_old == config_id + MemoryConfig.config_id_old == int(config_id) ).first() - + print(memory_config) if not memory_config: - raise ValueError(f"未找到 config_id_old={config_id} 对应的配置") + raise ValueError(f"STR 未找到 config_id_old={config_id} 对应的配置") return memory_config.config_id if isinstance(config_id, int): memory_config = db.query(MemoryConfig).filter( @@ -38,7 +39,7 @@ def resolve_config_id(config_id: UUID | int, db: Session) -> UUID: ).first() if not memory_config: - raise ValueError(f"未找到 config_id_old={config_id} 对应的配置") + raise ValueError(f"INT 未找到 config_id_old={config_id} 对应的配置") return memory_config.config_id diff --git a/api/app/version_info.json b/api/app/version_info.json index 86a5e33e..e82243a4 100644 --- a/api/app/version_info.json +++ b/api/app/version_info.json @@ -1,4 +1,32 @@ { + "v0.2.2": { + "introduction": { + 
"codeName": "淬锋(Temper)", + "releaseDate": "2026-1-31", + "upgradePosition": "本次发布聚焦平台稳定性和性能优化。正如\"淬锋\"之名——千锤百炼,淬火成锋,我们通过严格测试和修复打磨系统品质。引入 Agent 工作流的代码执行能力、改进模型并发管理,并修复了记忆系统的多个关键问题。", + "coreUpgrades": [ + "1. Agent平台增强
* 模型并发管理:优化模型广场的并发请求处理和资源分配能力。", + "2. 记忆系统优化
* Celery 队列修复:解决任务队列问题,提升异步记忆处理的可靠性
* 记忆 Agent 优化:提升记忆 Agent 的性能和效率
* 接口响应速度优化:优化记忆接口响应时间,加快操作速度。", + "3. 情绪记忆与识别升级
* 情绪记忆角色识别修复:解决情绪记忆上下文中的角色/人物识别问题
* 角色识别增强:提升对话记忆中的角色/人物识别准确性。", + "
", + "MemoryBear 持续致力于为 AI 应用提供类人记忆能力。本次以稳定性为核心的发布,进一步夯实了「感知→精炼→关联→遗忘」范式的基础。", + "未来版本将在此坚实基础上,扩展 Agent 能力并深化记忆智能特性。" + ] + }, + "introduction_en": { + "codeName": "Temper (淬锋)", + "releaseDate": "2026-1-31", + "upgradePosition": "This release focuses on platform stability and performance optimization — true to its codename \"淬锋\" (tempered blade), we've refined the system through rigorous testing and fixes. Introducing Python code execution for Agent workflows, improved model concurrency management, and critical fixes across the memory system.", + "coreUpgrades": [ + "1. Agent Platform Enhancements
* Model Concurrency Management: Enhanced Model Plaza with improved concurrent model request handling and resource allocation.", + "2. Memory System Improvements
* Celery Queue Fix: Resolved task queue issues for more reliable asynchronous memory processing
* Memory Agent Optimization: Improved memory Agent performance and efficiency
* API Response Speed: Optimized memory interface response times for faster operations.", + "3. Emotional Memory & Recognition Upgrades
* Emotion Memory Role Recognition Fix: Resolved issues with role/character identification in emotional memory contexts
* Role Recognition Enhancement: Improved character/role identification accuracy in conversation memory.", + "
", + "MemoryBear continues advancing toward human-like memory capabilities for AI applications. This stability-focused release strengthens the foundation for our Perception → Refinement → Association → Forgetting paradigm.", + "Future releases will build on this solid base with expanded Agent capabilities and deeper memory intelligence features." + ] + } + }, "v0.2.1": { "introduction": { "codeName": "启知", diff --git a/api/docker-compose.yml b/api/docker-compose.yml index f30220cb..69763de2 100644 --- a/api/docker-compose.yml +++ b/api/docker-compose.yml @@ -19,6 +19,7 @@ services: depends_on: - worker-memory - worker-document + - worker-periodic # Memory worker - Memory read/write tasks (threads pool for asyncio) worker-memory: @@ -48,6 +49,20 @@ services: networks: - celery + # Periodic worker - Scheduled/beat tasks (prefork, low concurrency) + worker-periodic: + image: redbear-mem-open:latest + container_name: worker-periodic + env_file: + - .env + volumes: + - ./files:/files + - /etc/localtime:/etc/localtime:ro + command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=prefork --concurrency=2 --queues=periodic_tasks --max-tasks-per-child=50 -n periodic_worker@%h + restart: unless-stopped + networks: + - celery + # Celery Beat - scheduler beat: image: redbear-mem-open:latest @@ -69,7 +84,7 @@ services: container_name: sandbox ports: - "8194" - command: /code/.venv/bin/python main.py + command: /code/.venv/bin/uvicorn main:app --host 0.0.0.0 --port 8194 --log-level debug restart: unless-stopped networks: - sandbox diff --git a/api/env.example b/api/env.example index 274049b9..98c96edc 100644 --- a/api/env.example +++ b/api/env.example @@ -1,4 +1,9 @@ +# Language Configuration +# Supported values: "zh" (Chinese), "en" (English) +# This controls the language used for memory summary titles and other generated content +DEFAULT_LANGUAGE=zh + # Neo4j Configuration (记忆系统数据库) NEO4J_URI= NEO4J_USERNAME= diff --git 
a/api/migrations/versions/325b759cd66b_2026011240.py b/api/migrations/versions/325b759cd66b_2026011240.py index 3d7443a8..048b109b 100644 --- a/api/migrations/versions/325b759cd66b_2026011240.py +++ b/api/migrations/versions/325b759cd66b_2026011240.py @@ -28,7 +28,15 @@ def upgrade() -> None: op.drop_constraint('data_config_pkey', 'memory_config', type_='primary') op.alter_column('memory_config', 'config_id', new_column_name='config_id_old', nullable=True) op.add_column('memory_config', sa.Column('config_id', sa.UUID(), nullable=True)) - op.execute("UPDATE memory_config SET config_id = apply_id::uuid") + # Handle rows where apply_id might be NULL or invalid - generate new UUIDs for those + op.execute(""" + UPDATE memory_config + SET config_id = CASE + WHEN apply_id IS NOT NULL AND apply_id ~ '^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$' + THEN apply_id::uuid + ELSE gen_random_uuid() + END + """) op.alter_column('memory_config', 'config_id', nullable=False) op.create_primary_key('memory_config_pkey', 'memory_config', ['config_id']) op.execute("ALTER TABLE memory_config ALTER COLUMN config_id_old DROP DEFAULT") diff --git a/api/migrations/versions/550c10595967_202601301521.py b/api/migrations/versions/550c10595967_202601301521.py new file mode 100644 index 00000000..b2f531db --- /dev/null +++ b/api/migrations/versions/550c10595967_202601301521.py @@ -0,0 +1,78 @@ +"""202601301521 + +Revision ID: 550c10595967 +Revises: 5de9b1e28509 +Create Date: 2026-01-30 15:24:34.647440 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '550c10595967' +down_revision: Union[str, None] = '5de9b1e28509' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('ontology_scene', + sa.Column('scene_id', sa.UUID(), nullable=False, comment='场景ID'), + sa.Column('scene_name', sa.String(length=200), nullable=False, comment='场景名称'), + sa.Column('scene_description', sa.Text(), nullable=True, comment='场景描述'), + sa.Column('workspace_id', sa.UUID(), nullable=False, comment='所属工作空间ID'), + sa.Column('created_at', sa.DateTime(), nullable=False, comment='创建时间'), + sa.Column('updated_at', sa.DateTime(), nullable=False, comment='更新时间'), + sa.ForeignKeyConstraint(['workspace_id'], ['workspaces.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('scene_id'), + sa.UniqueConstraint('workspace_id', 'scene_name', name='uq_workspace_scene_name') + ) + op.create_index(op.f('ix_ontology_scene_scene_id'), 'ontology_scene', ['scene_id'], unique=False) + op.create_index(op.f('ix_ontology_scene_workspace_id'), 'ontology_scene', ['workspace_id'], unique=False) + op.create_table('ontology_class', + sa.Column('class_id', sa.UUID(), nullable=False, comment='类型ID'), + sa.Column('class_name', sa.String(length=200), nullable=False, comment='类型名称'), + sa.Column('class_description', sa.Text(), nullable=True, comment='类型描述'), + sa.Column('scene_id', sa.UUID(), nullable=False, comment='所属场景ID'), + sa.Column('created_at', sa.DateTime(), nullable=False, comment='创建时间'), + sa.Column('updated_at', sa.DateTime(), nullable=False, comment='更新时间'), + sa.ForeignKeyConstraint(['scene_id'], ['ontology_scene.scene_id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('class_id') + ) + op.create_index(op.f('ix_ontology_class_class_id'), 'ontology_class', ['class_id'], unique=False) + op.create_index(op.f('ix_ontology_class_scene_id'), 'ontology_class', ['scene_id'], unique=False) + op.create_table('prompt_history', + sa.Column('id', sa.UUID(), nullable=False), + sa.Column('tenant_id', sa.UUID(), nullable=False, comment='Tenant ID'), + sa.Column('session_id', sa.UUID(), nullable=False, comment='Session ID'), + sa.Column('title', sa.String(), 
nullable=False, comment='Title'), + sa.Column('prompt', sa.Text(), nullable=False, comment='Prompt'), + sa.Column('created_at', sa.DateTime(), nullable=True, comment='Creation Time'), + sa.Column('is_delete', sa.Boolean(), nullable=True, comment='Delete'), + sa.ForeignKeyConstraint(['session_id'], ['prompt_opt_session_list.id'], ), + sa.ForeignKeyConstraint(['tenant_id'], ['tenants.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_prompt_history_created_at'), 'prompt_history', ['created_at'], unique=False) + op.create_index(op.f('ix_prompt_history_id'), 'prompt_history', ['id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + + op.drop_index(op.f('ix_prompt_history_id'), table_name='prompt_history') + op.drop_index(op.f('ix_prompt_history_created_at'), table_name='prompt_history') + op.drop_table('prompt_history') + op.drop_index(op.f('ix_ontology_class_scene_id'), table_name='ontology_class') + op.drop_index(op.f('ix_ontology_class_class_id'), table_name='ontology_class') + op.drop_table('ontology_class') + op.drop_index(op.f('ix_ontology_scene_workspace_id'), table_name='ontology_scene') + op.drop_index(op.f('ix_ontology_scene_scene_id'), table_name='ontology_scene') + op.drop_table('ontology_scene') + # ### end Alembic commands ### diff --git a/api/migrations/versions/5de9b1e28509_20260129212722.py b/api/migrations/versions/5de9b1e28509_20260129212722.py new file mode 100644 index 00000000..cbffad68 --- /dev/null +++ b/api/migrations/versions/5de9b1e28509_20260129212722.py @@ -0,0 +1,80 @@ +"""20260129212722 + +Revision ID: 5de9b1e28509 +Revises: 5ca246ee7dd4 +Create Date: 2026-01-29 21:34:30.978031 + +""" +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision: str = '5de9b1e28509' +down_revision: Union[str, None] = '5ca246ee7dd4' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Neo4j migration: rename group_id to end_user_id + import asyncio + + from app.repositories.neo4j.neo4j_connector import Neo4jConnector + + async def run_neo4j_upgrade(): + connector = Neo4jConnector() + try: + async def transaction_func(tx): + result = await tx.run(""" + MATCH (n) + WHERE n.group_id IS NOT NULL + SET n.end_user_id = n.group_id + REMOVE n.group_id + WITH count(n) AS node_count + MATCH ()-[r]->() + WHERE r.group_id IS NOT NULL + SET r.end_user_id = r.group_id + REMOVE r.group_id + RETURN node_count, count(r) AS rel_count + """) + return await result.data() + + await connector.execute_write_transaction(transaction_func) + finally: + await connector.close() + + asyncio.run(run_neo4j_upgrade()) + + +def downgrade() -> None: + # Neo4j migration: rename end_user_id back to group_id + import asyncio + + from app.repositories.neo4j.neo4j_connector import Neo4jConnector + + async def run_neo4j_downgrade(): + connector = Neo4jConnector() + try: + async def transaction_func(tx): + result = await tx.run(""" + MATCH (n) + WHERE n.end_user_id IS NOT NULL + SET n.group_id = n.end_user_id + REMOVE n.end_user_id + WITH count(n) AS node_count + MATCH ()-[r]->() + WHERE r.end_user_id IS NOT NULL + SET r.group_id = r.end_user_id + REMOVE r.end_user_id + RETURN node_count, count(r) AS rel_count + """) + return await result.data() + + await connector.execute_write_transaction(transaction_func) + finally: + await connector.close() + + asyncio.run(run_neo4j_downgrade()) \ No newline at end of file diff --git a/api/migrations/versions/9def72f79398_202601301850.py b/api/migrations/versions/9def72f79398_202601301850.py new file mode 100644 index 00000000..303a1578 --- /dev/null +++ b/api/migrations/versions/9def72f79398_202601301850.py @@ -0,0 +1,30 @@ 
+"""202601301850 + +Revision ID: 9def72f79398 +Revises: 550c10595967 +Create Date: 2026-01-30 18:51:18.290796 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '9def72f79398' +down_revision: Union[str, None] = '550c10595967' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('memory_config', sa.Column('scene_id', sa.UUID(), nullable=True, comment='本体场景ID,关联ontology_scene表')) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('memory_config', 'scene_id') + # ### end Alembic commands ### diff --git a/api/pyproject.toml b/api/pyproject.toml index 29597409..6d23a3b9 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -140,6 +140,7 @@ dependencies = [ "oss2>=2.19.1", "flower>=2.0.1", "aiofiles>=23.0.0", + "owlready2>=0.46", ] [tool.pytest.ini_options] diff --git a/sandbox/Dockerfile b/sandbox/Dockerfile index 677b991c..e34b88dd 100644 --- a/sandbox/Dockerfile +++ b/sandbox/Dockerfile @@ -1,9 +1,10 @@ FROM python:3.12-slim USER root WORKDIR /code -LABEL authors="Eterntiy" -ARG NEED_MIRROR=0 +ARG NEED_MIRROR=1 +ENV DEBIAN_FRONTEND=noninteractive + RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \ if [ "$NEED_MIRROR" == "1" ]; then \ @@ -17,11 +18,14 @@ RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \ apt --no-install-recommends install -y ca-certificates && \ apt update && \ apt install -y python3-pip pipx nginx unzip curl wget git vim less && \ + apt install -y nodejs npm && \ apt-get install -y --no-install-recommends tzdata libseccomp2 libseccomp-dev && \ ln -snf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \ 
echo "Asia/Shanghai" > /etc/timezone && \ apt install -y cargo +ENV PYTHONDONTWRITEBYTECODE=1 + COPY ./app /code/app COPY ./dependencies /code/dependencies COPY ./lib /code/lib @@ -33,10 +37,15 @@ COPY ./requirements.txt /code/requirements.txt RUN python -m venv .venv RUN .venv/bin/python3 -m pip install -r requirements.txt -RUN cargo build --release --manifest-path lib/seccomp_python/Cargo.toml +RUN npm install --prefix=/code/dependencies/nodejs koffi -HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \ +RUN cargo build --release --manifest-path lib/seccomp_redbear/Cargo.toml --features python3 +RUN mv lib/seccomp_redbear/target/release/libsandbox.so lib/seccomp_redbear/target/release/libpython.so +RUN cargo build --release --manifest-path lib/seccomp_redbear/Cargo.toml --features nodejs +RUN mv lib/seccomp_redbear/target/release/libsandbox.so lib/seccomp_redbear/target/release/libnodejs.so + +HEALTHCHECK --interval=30s --timeout=5s --start-period=60s --retries=3 \ CMD curl 127.0.0.1:8194/health -CMD [".venv/bin/python3", "main.py"] \ No newline at end of file +CMD [".venv/bin/uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8194", "--log-level", "debug"] \ No newline at end of file diff --git a/sandbox/app/__init__.py b/sandbox/app/__init__.py new file mode 100644 index 00000000..1b201ce5 --- /dev/null +++ b/sandbox/app/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/1/29 14:33 diff --git a/sandbox/app/config.py b/sandbox/app/config.py index 3fa4cab5..e4930465 100644 --- a/sandbox/app/config.py +++ b/sandbox/app/config.py @@ -4,9 +4,6 @@ from typing import List, Optional from pydantic import BaseModel, Field import yaml -SANDBOX_USER_ID = 1000 -SANDBOX_GROUP_ID = 1000 - DEFAULT_PYTHON_LIB_REQUIREMENTS_AMD = [ "/usr/local/lib/python3.12", "/usr/lib/python3", @@ -15,13 +12,18 @@ DEFAULT_PYTHON_LIB_REQUIREMENTS_AMD = [ "/etc/nsswitch.conf", "/etc/hosts", 
"/etc/resolv.conf", - "/run/systemd/resolve/stub-resolv.conf", - "/run/resolvconf/resolv.conf", "/etc/localtime", "/usr/share/zoneinfo", "/etc/timezone", ] +DEFAULT_NODEJS_LIB_REQUIREMENTS = [ + "/etc/ssl/certs/ca-certificates.crt", + "/etc/nsswitch.conf", + "/etc/resolv.conf", + "/etc/hosts", +] + class AppConfig(BaseModel): """Application configuration""" @@ -43,83 +45,77 @@ class Config(BaseModel): max_workers: int = 4 max_requests: int = 50 worker_timeout: int = 30 - nodejs_path: str = "node" + enable_network: bool = True enable_preload: bool = False python_path: str = "" python_lib_paths: list = Field(default=DEFAULT_PYTHON_LIB_REQUIREMENTS_AMD) python_deps_update_interval: str = "30m" + + nodejs_path: str = "" + nodejs_lib_paths: list = Field(default=DEFAULT_NODEJS_LIB_REQUIREMENTS) + allowed_syscalls: List[int] = Field(default_factory=list) proxy: ProxyConfig = Field(default_factory=ProxyConfig) + sandbox_user: str = "sandbox" + sandbox_uid: int = 65537 + sandbox_gid: int = 0 + + def set_sandbox_gid(self, gid: int): + """Update sandbox GID dynamically""" + self.sandbox_gid = gid + + def override_with_env(self): + """Override configuration with environment variables""" + env_map = { + "DEBUG": ("app.debug", lambda v: v.lower() in ("true", "1", "yes")), + "MAX_WORKERS": ("max_workers", int), + "MAX_REQUESTS": ("max_requests", int), + "SANDBOX_PORT": ("app.port", int), + "WORKER_TIMEOUT": ("worker_timeout", int), + "API_KEY": ("app.key", str), + "NODEJS_PATH": ("nodejs_path", str), + "ENABLE_NETWORK": ("enable_network", lambda v: v.lower() in ("true", "1", "yes")), + "ENABLE_PRELOAD": ("enable_preload", lambda v: v.lower() in ("true", "1", "yes")), + "ALLOWED_SYSCALLS": ("allowed_syscalls", lambda v: [int(x) for x in v.split(",")]), + "SOCKS5_PROXY": ("proxy.socks5", str), + "HTTP_PROXY": ("proxy.http", str), + "HTTPS_PROXY": ("proxy.https", str), + "PYTHON_PATH": ("python_path", str), + "PYTHON_LIB_PATH": ("python_lib_paths", lambda v: v.split(",")), + 
"PYTHON_DEPS_UPDATE_INTERVAL": ("python_deps_update_interval", str), + "NODEJS_LIB_PATH": ("nodejs_lib_paths", lambda v: v.split(",")), + } + + for env_var, (attr_path, cast) in env_map.items(): + value = os.getenv(env_var) + if value is not None: + # Support nested attributes like 'app.debug' + parts = attr_path.split(".") + obj = self + for part in parts[:-1]: + obj = getattr(obj, part) + setattr(obj, parts[-1], cast(value)) + # Global configuration instance _config: Optional[Config] = None -def load_config(config_path: str) -> Config: - """Load configuration from YAML file""" +def load_config(config_path: str = "config.yaml") -> Config: + """Load configuration from YAML file and override with env variables""" global _config - - # Load from file if os.path.exists(config_path): with open(config_path, 'r') as f: - data = yaml.safe_load(f) + data = yaml.safe_load(f) or {} _config = Config(**data) else: _config = Config() - # Override with environment variables - if os.getenv("DEBUG"): - _config.app.debug = os.getenv("DEBUG").lower() in ("true", "1", "yes") - - if os.getenv("MAX_WORKERS"): - _config.max_workers = int(os.getenv("MAX_WORKERS")) - - if os.getenv("MAX_REQUESTS"): - _config.max_requests = int(os.getenv("MAX_REQUESTS")) - - if os.getenv("SANDBOX_PORT"): - _config.app.port = int(os.getenv("SANDBOX_PORT")) - - if os.getenv("WORKER_TIMEOUT"): - _config.worker_timeout = int(os.getenv("WORKER_TIMEOUT")) - - if os.getenv("API_KEY"): - _config.app.key = os.getenv("API_KEY") - - if os.getenv("NODEJS_PATH"): - _config.nodejs_path = os.getenv("NODEJS_PATH") - - if os.getenv("ENABLE_NETWORK"): - _config.enable_network = os.getenv("ENABLE_NETWORK").lower() in ("true", "1", "yes") - - if os.getenv("ENABLE_PRELOAD"): - _config.enable_preload = os.getenv("ENABLE_PRELOAD").lower() in ("true", "1", "yes") - - if os.getenv("ALLOWED_SYSCALLS"): - _config.allowed_syscalls = [int(x) for x in os.getenv("ALLOWED_SYSCALLS").split(",")] - - if os.getenv("SOCKS5_PROXY"): - 
_config.proxy.socks5 = os.getenv("SOCKS5_PROXY") - - if os.getenv("HTTP_PROXY"): - _config.proxy.http = os.getenv("HTTP_PROXY") - - if os.getenv("HTTPS_PROXY"): - _config.proxy.https = os.getenv("HTTPS_PROXY") - - # python - if os.getenv("PYTHON_PATH"): - _config.python_path = os.getenv("PYTHON_PATH") - - if os.getenv("PYTHON_LIB_PATH"): - _config.python_lib_paths = os.getenv("PYTHON_LIB_PATH").split(',') - - if os.getenv("PYTHON_DEPS_UPDATE_INTERVAL"): - _config.python_deps_update_interval = os.getenv("PYTHON_DEPS_UPDATE_INTERVAL") - + # Override from environment + _config.override_with_env() return _config diff --git a/sandbox/app/controllers/health_controller.py b/sandbox/app/controllers/health_controller.py index 4d872e58..882578ec 100644 --- a/sandbox/app/controllers/health_controller.py +++ b/sandbox/app/controllers/health_controller.py @@ -9,4 +9,4 @@ router = APIRouter() @router.get("/health", response_model=HealthResponse) async def health_check(): """Health check endpoint""" - return HealthResponse(status="healthy", version="2.0.0") + return HealthResponse(status="healthy", version="0.1.0") diff --git a/sandbox/app/controllers/sandbox_controller.py b/sandbox/app/controllers/sandbox_controller.py index 1a713f52..c5cce40c 100644 --- a/sandbox/app/controllers/sandbox_controller.py +++ b/sandbox/app/controllers/sandbox_controller.py @@ -2,13 +2,15 @@ from fastapi import APIRouter, Depends from app.middleware.auth import verify_api_key -from app.middleware.concurrency import check_max_requests, acquire_worker +from app.middleware.concurrency import concurrency_guard + from app.models import ( RunCodeRequest, ApiResponse, UpdateDependencyRequest, error_response ) +from app.services.nodejs_service import run_nodejs_code from app.services.python_service import ( run_python_code, list_python_dependencies, @@ -25,16 +27,14 @@ router = APIRouter( @router.post( "/run", response_model=ApiResponse, - dependencies=[Depends(check_max_requests), - Depends(acquire_worker)] 
+ dependencies=[Depends(concurrency_guard)] ) async def run_code(request: RunCodeRequest): """Execute code in sandbox""" if request.language == "python3": return await run_python_code(request.code, request.preload, request.options) elif request.language == "nodejs": - # TODO - return error_response(-400, "TODO") + return await run_nodejs_code(request.code, request.preload, request.options) else: return error_response(-400, "unsupported language") @@ -55,5 +55,3 @@ async def update_dependencies(request: UpdateDependencyRequest): return await update_python_dependencies() else: return error_response(-400, "unsupported language") - - diff --git a/sandbox/app/core/runners/__init__.py b/sandbox/app/core/runners/__init__.py index 96c5e380..b8021009 100644 --- a/sandbox/app/core/runners/__init__.py +++ b/sandbox/app/core/runners/__init__.py @@ -1 +1,40 @@ """Code runners package""" +import pwd +import subprocess + +from app.config import get_config +from app.logger import get_logger + +logger = get_logger() + + +def init_sandbox_user(): + config = get_config() + sandbox_user = config.sandbox_user + sandbox_uid = config.sandbox_uid + try: + pwd.getpwnam(sandbox_user) + logger.info(f"User '{sandbox_user}' already exists") + except KeyError: + try: + subprocess.run( + ["useradd", "-u", str(sandbox_uid), sandbox_user], + check=True, + capture_output=True, + text=True + ) + logger.info(f"Created user '{sandbox_user}' with UID {sandbox_uid}") + except subprocess.CalledProcessError as e: + logger.error(f"Failed to create user: {e.stderr}") + raise RuntimeError(f"Failed to create user '{sandbox_user}': {e.stderr}") from e + + try: + user_info = pwd.getpwnam(sandbox_user) + config.set_sandbox_gid(user_info.pw_gid) + logger.info(f"Sandbox user GID: {config.sandbox_gid}") + except KeyError as e: + logger.error(f"Failed to get GID for user '{sandbox_user}'") + raise RuntimeError(f"Failed to get GID for user '{sandbox_user}'") from e + + + diff --git 
a/sandbox/app/core/runners/nodejs/__init__.py b/sandbox/app/core/runners/nodejs/__init__.py new file mode 100644 index 00000000..fa5243b7 --- /dev/null +++ b/sandbox/app/core/runners/nodejs/__init__.py @@ -0,0 +1,3 @@ +from app.core.runners.nodejs.env import release_lib_binary + +release_lib_binary(True) diff --git a/sandbox/app/core/runners/nodejs/env.py b/sandbox/app/core/runners/nodejs/env.py new file mode 100644 index 00000000..8c6a55aa --- /dev/null +++ b/sandbox/app/core/runners/nodejs/env.py @@ -0,0 +1,124 @@ +import asyncio +import ctypes +import os +import shutil +import stat +import tempfile +from pathlib import Path + +from app.logger import get_logger +from app.config import get_config + +logger = get_logger() + +RELEASE_LIB_PATH = "./lib/seccomp_redbear/target/release/libnodejs.so" +LIB_PATH = "/var/sandbox/sandbox-nodejs" +LIB_NAME = "libnodejs.so" + +lib = ctypes.CDLL(RELEASE_LIB_PATH) +lib.get_lib_version_static.restype = ctypes.c_char_p +lib.get_lib_feature_static.restype = ctypes.c_char_p +logger.info(f"Seccomp Env: nodejs, " + f"Seccomp Feature: {lib.get_lib_feature_static().decode('utf-8')}, " + f"Seccomp Version: {lib.get_lib_version_static().decode('utf-8')}") + +try: + with open(RELEASE_LIB_PATH, "rb") as f: + _NODEJS_LIB = f.read() +except: + logger.critical("failed to load nodejs lib") + raise + + +def check_lib_avaiable(): + return os.path.exists(os.path.join(LIB_PATH, LIB_NAME)) + + +def release_lib_binary(force_remove: bool): + logger.info("init runtime enviroment") + + lib_file = os.path.join(LIB_PATH, LIB_NAME) + if os.path.exists(lib_file): + if force_remove: + try: + os.remove(lib_file) + except OSError: + logger.critical(f"failed to remove {os.path.join(LIB_PATH, LIB_NAME)}") + raise + + try: + os.makedirs(LIB_PATH, mode=0o755, exist_ok=True) + except OSError: + logger.critical(f"failed to create {LIB_PATH}") + raise + + try: + with open(lib_file, "wb") as f: + f.write(_NODEJS_LIB) + os.chmod(lib_file, 0o755) + except OSError: + 
logger.critical(f"failed to write {lib_file}") + raise + else: + try: + os.makedirs(LIB_PATH, mode=0o755, exist_ok=True) + except OSError: + logger.critical(f"failed to create {LIB_PATH}") + raise + + try: + with open(lib_file, "wb") as f: + f.write(_NODEJS_LIB) + os.chmod(lib_file, 0o755) + except OSError: + logger.critical(f"failed to write {lib_file}") + raise + + logger.info("nodejs runner environment initialized") + + +async def prepare_nodejs_dependencies_env(): + config = get_config() + + with tempfile.TemporaryDirectory(dir="/") as root_path: + root = Path(root_path) + + env_sh = root / "env.sh" + with open("script/env.sh") as f: + env_sh.write_text(f.read()) + env_sh.chmod(env_sh.stat().st_mode | stat.S_IXUSR) + + shutil.copytree("dependencies/nodejs", os.path.join(LIB_PATH, "node_temp"), dirs_exist_ok=True) + for root, dirs, files in os.walk(os.path.join(LIB_PATH, "node_temp")): + for d in dirs: + os.chmod(os.path.join(root, d), 0o755) + for f in files: + os.chmod(os.path.join(root, f), 0o444) + + for lib_path in config.nodejs_lib_paths: + lib_path = Path(lib_path) + + if not lib_path.exists(): + logger.warning("nodejs lib path %s is not available", lib_path) + continue + + cmd = [ + "bash", + str(env_sh), + str(lib_path), + str(LIB_PATH), + ] + + process = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + + stdout, stderr = await process.communicate() + retcode = process.returncode + + if retcode != 0: + logger.error( + f"create env error for file {lib_path}: retcode={retcode}, stderr={stderr.decode()}" + ) diff --git a/sandbox/app/core/runners/nodejs/nodejs_runner.py b/sandbox/app/core/runners/nodejs/nodejs_runner.py new file mode 100644 index 00000000..59560eee --- /dev/null +++ b/sandbox/app/core/runners/nodejs/nodejs_runner.py @@ -0,0 +1,138 @@ +"""Nodejs code runner""" +import asyncio +import os +import uuid +from typing import Optional + +from app.core.executor import 
CodeExecutor, ExecutionResult +from app.core.runners.nodejs.env import check_lib_avaiable, release_lib_binary, LIB_PATH +from app.logger import get_logger +from app.models import RunnerOptions + +# Nodejs sandbox prescript template +with open("app/core/runners/nodejs/prescript.js") as f: + NODEJS_PRESCRIPT = f.read() + +logger = get_logger() + + +class NodejsRunner(CodeExecutor): + """Node.js code runner with security isolation""" + + def __init__(self): + super().__init__() + + @staticmethod + def init_environment(code: str, preload: str) -> str: + if not check_lib_avaiable(): + release_lib_binary(False) + code_file_name = uuid.uuid4().hex.replace("-", "_") + + script = NODEJS_PRESCRIPT.replace("{{preload}}", preload, 1) + + eval_code = f"eval(Buffer.from('{code}', 'base64').toString('utf-8'))" + script = script.replace("{{code}}", eval_code, 1) + + code_path = f"{LIB_PATH}/node_temp/tmp/{code_file_name}.js" + try: + os.makedirs(os.path.dirname(code_path), mode=0o755, exist_ok=True) + with open(code_path, "w", encoding="utf-8") as f: + f.write(script) + os.chmod(code_path, 0o755) + + except OSError as e: + raise RuntimeError(f"Failed to write {code_path}") from e + + return code_path + + async def run( + self, + code: str, + options: RunnerOptions, + preload: str = "", + timeout: Optional[int] = None + ) -> ExecutionResult: + """Run Python code in sandbox + + Args: + options: + code: Base64 encoded encrypted code + preload: Preload code to execute before main code + timeout: Execution timeout in seconds + + Returns: + ExecutionResult with stdout, stderr, and exit code + """ + config = self.config + + if timeout is None: + timeout = config.worker_timeout + + # Check if preload is allowed + if not preload or not config.enable_preload: + preload = "" + script_path = self.init_environment(code, preload) + + try: + # Setup environment + env = { + "UV_USE_IO_URING": "0" + } + + # Add proxy settings if configured + if config.proxy.socks5: + env["HTTPS_PROXY"] = 
config.proxy.socks5 + env["HTTP_PROXY"] = config.proxy.socks5 + elif config.proxy.https or config.proxy.http: + if config.proxy.https: + env["HTTPS_PROXY"] = config.proxy.https + if config.proxy.http: + env["HTTP_PROXY"] = config.proxy.http + + # Add allowed syscalls if configured + if config.allowed_syscalls: + env["ALLOWED_SYSCALLS"] = ",".join(map(str, config.allowed_syscalls)) + + process = await asyncio.create_subprocess_exec( + config.nodejs_path, + script_path, + LIB_PATH, + str(config.sandbox_uid), + str(config.sandbox_gid), + options.model_dump_json(), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env, + cwd=LIB_PATH + ) + + # Wait for completion with timeout + try: + stdout, stderr = await asyncio.wait_for( + process.communicate(), + timeout=timeout + ) + + return ExecutionResult( + stdout=stdout.decode('utf-8', errors='replace'), + stderr=stderr.decode('utf-8', errors='replace'), + exit_code=process.returncode + ) + + except asyncio.TimeoutError: + # Kill process on timeout + try: + process.kill() + await process.wait() + except: + pass + + return ExecutionResult( + stdout="", + stderr="Execution timeout", + exit_code=-1, + ) + + finally: + # Cleanup temporary file + self.cleanup_temp_file(script_path) diff --git a/sandbox/app/core/runners/nodejs/prescript.js b/sandbox/app/core/runners/nodejs/prescript.js new file mode 100644 index 00000000..460aa108 --- /dev/null +++ b/sandbox/app/core/runners/nodejs/prescript.js @@ -0,0 +1,31 @@ +let argv = process.argv + +let koffi = require('koffi') + +process.chdir(argv[2]) + +let lib = koffi.load("./libnodejs.so") +/** @type {(uid: number, gid: number, enableNetwork: boolean) => number} */ +let initSeccomp = lib.func('int init_seccomp(int, int, bool)') + +let uid = parseInt(argv[3]) +let gid = parseInt(argv[4]) + +let options = JSON.parse(argv[5]) + +let seccomp_init = initSeccomp(uid, gid, options['enable_network']) +if (seccomp_init !== 0) { + throw `code executor err - 
${seccomp_init}` +} + +delete process.argv +argv = undefined +koffi = undefined +lib = undefined +initSeccomp = undefined +uid = undefined +gid = undefined +options = undefined +seccomp_init = undefined + +{{code}} diff --git a/sandbox/app/core/runners/python/__init__.py b/sandbox/app/core/runners/python/__init__.py index 99a56ef7..e1a34906 100644 --- a/sandbox/app/core/runners/python/__init__.py +++ b/sandbox/app/core/runners/python/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: UTF-8 -*- -# Author: Eternity -# @Email: 1533512157@qq.com -# @Time : 2026/1/23 11:27 +from app.core.runners.python.env import release_lib_binary + +release_lib_binary(True) diff --git a/sandbox/app/core/runners/python/env.py b/sandbox/app/core/runners/python/env.py index d82b0522..541acc73 100644 --- a/sandbox/app/core/runners/python/env.py +++ b/sandbox/app/core/runners/python/env.py @@ -1,14 +1,80 @@ import asyncio -import tempfile +import ctypes +import os import stat +import tempfile from pathlib import Path from app.config import get_config -from app.core.runners.python.settings import LIB_PATH from app.logger import get_logger logger = get_logger() +RELEASE_LIB_PATH = "./lib/seccomp_redbear/target/release/libpython.so" +LIB_PATH = "/var/sandbox/sandbox-python" +LIB_NAME = "libpython.so" + +lib = ctypes.CDLL(RELEASE_LIB_PATH) +lib.get_lib_version_static.restype = ctypes.c_char_p +lib.get_lib_feature_static.restype = ctypes.c_char_p +logger.info(f"Seccomp Env: python3, " + f"Seccomp Feature: {lib.get_lib_feature_static().decode('utf-8')}, " + f"Seccomp Version: {lib.get_lib_version_static().decode('utf-8')}") + +try: + with open(RELEASE_LIB_PATH, "rb") as f: + _PYTHON_LIB = f.read() +except: + logger.critical("failed to load python lib") + raise + + +def check_lib_avaiable(): + return os.path.exists(os.path.join(LIB_PATH, LIB_NAME)) + + +def release_lib_binary(force_remove: bool): + logger.info("init runtime enviroment") + + lib_file = os.path.join(LIB_PATH, LIB_NAME) + if 
os.path.exists(lib_file): + if force_remove: + try: + os.remove(lib_file) + except OSError: + logger.critical(f"failed to remove {os.path.join(LIB_PATH, LIB_NAME)}") + raise + + try: + os.makedirs(LIB_PATH, mode=0o755, exist_ok=True) + except OSError: + logger.critical(f"failed to create {LIB_PATH}") + raise + + try: + with open(lib_file, "wb") as f: + f.write(_PYTHON_LIB) + os.chmod(lib_file, 0o755) + except OSError: + logger.critical(f"failed to write {lib_file}") + raise + else: + try: + os.makedirs(LIB_PATH, mode=0o755, exist_ok=True) + except OSError: + logger.critical(f"failed to create {LIB_PATH}") + raise + + try: + with open(lib_file, "wb") as f: + f.write(_PYTHON_LIB) + os.chmod(lib_file, 0o755) + except OSError: + logger.critical(f"failed to write {lib_file}") + raise + + logger.info("python runner environment initialized") + async def prepare_python_dependencies_env(): config = get_config() diff --git a/sandbox/app/core/runners/python/prescript.py b/sandbox/app/core/runners/python/prescript.py index 950710ea..b694fe9b 100644 --- a/sandbox/app/core/runners/python/prescript.py +++ b/sandbox/app/core/runners/python/prescript.py @@ -17,7 +17,7 @@ sys.excepthook = excepthook # Load security library if available lib = ctypes.CDLL("./libpython.so") lib.init_seccomp.argtypes = [ctypes.c_uint32, ctypes.c_uint32, ctypes.c_bool] -lib.init_seccomp.restype = None # TODO: raise error info +lib.init_seccomp.restype = ctypes.c_int # Get running path running_path = sys.argv[1] @@ -37,7 +37,10 @@ os.chdir(running_path) {{preload}} # Apply security if library is available -lib.init_seccomp({{uid}}, {{gid}}, {{enable_network}}) +init_status = lib.init_seccomp({{uid}}, {{gid}}, {{enable_network}}) +if init_status != 0: + raise Exception(f"code executor err - {str(init_status)}") +del lib # Decrypt and execute code code = b64decode("{{code}}") diff --git a/sandbox/app/core/runners/python/python_runner.py b/sandbox/app/core/runners/python/python_runner.py index 
30792b91..eccd16e0 100644 --- a/sandbox/app/core/runners/python/python_runner.py +++ b/sandbox/app/core/runners/python/python_runner.py @@ -5,10 +5,10 @@ import os import uuid from typing import Optional -from app.config import SANDBOX_USER_ID, SANDBOX_GROUP_ID, get_config +from app.config import get_config from app.core.encryption import generate_key, encrypt_code from app.core.executor import CodeExecutor, ExecutionResult -from app.core.runners.python.settings import check_lib_avaiable, release_lib_binary, LIB_PATH +from app.core.runners.python.env import check_lib_avaiable, release_lib_binary, LIB_PATH from app.logger import get_logger from app.models import RunnerOptions @@ -32,8 +32,8 @@ class PythonRunner(CodeExecutor): config = get_config() code_file_name = uuid.uuid4().hex.replace("-", "_") - script = PYTHON_PRESCRIPT.replace("{{uid}}", str(SANDBOX_USER_ID), 1) - script = script.replace("{{gid}}", str(SANDBOX_GROUP_ID), 1) + script = PYTHON_PRESCRIPT.replace("{{uid}}", str(config.sandbox_uid), 1) + script = script.replace("{{gid}}", str(config.sandbox_gid), 1) script = script.replace( "{{enable_network}}", str(int(options.enable_network and config.enable_network) diff --git a/sandbox/app/core/runners/python/settings.py b/sandbox/app/core/runners/python/settings.py deleted file mode 100644 index aee8827b..00000000 --- a/sandbox/app/core/runners/python/settings.py +++ /dev/null @@ -1,62 +0,0 @@ -import os - -from app.logger import get_logger - -logger = get_logger() - -RELEASE_LIB_PATH = "./lib/seccomp_python/target/release/libpython.so" -LIB_PATH = "/var/sandbox/sandbox-python" -LIB_NAME = "libpython.so" - -try: - with open(RELEASE_LIB_PATH, "rb") as f: - _PYTHON_LIB = f.read() -except: - logger.critical("failed to load python lib") - raise - - -def check_lib_avaiable(): - return os.path.exists(os.path.join(LIB_PATH, LIB_NAME)) - - -def release_lib_binary(force_remove: bool): - logger.info("init runtime enviroment") - lib_file = os.path.join(LIB_PATH, 
LIB_NAME) - if os.path.exists(lib_file): - if force_remove: - try: - os.remove(lib_file) - except OSError: - logger.critical(f"failed to remove {os.path.join(LIB_PATH, LIB_NAME)}") - raise - - try: - os.makedirs(LIB_PATH, mode=0o755, exist_ok=True) - except OSError: - logger.critical(f"failed to create {LIB_PATH}") - raise - - try: - with open(lib_file, "wb") as f: - f.write(_PYTHON_LIB) - os.chmod(lib_file, 0o755) - except OSError: - logger.critical(f"failed to write {lib_file}") - raise - else: - try: - os.makedirs(LIB_PATH, mode=0o755, exist_ok=True) - except OSError: - logger.critical(f"failed to create {LIB_PATH}") - raise - - try: - with open(lib_file, "wb") as f: - f.write(_PYTHON_LIB) - os.chmod(lib_file, 0o755) - except OSError: - logger.critical(f"failed to write {lib_file}") - raise - - logger.info("python runner environment initialized") diff --git a/sandbox/app/dependencies.py b/sandbox/app/dependencies.py index 6e88aaf2..6fe05ee4 100644 --- a/sandbox/app/dependencies.py +++ b/sandbox/app/dependencies.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import List, Dict from app.config import get_config +from app.core.runners.nodejs.env import prepare_nodejs_dependencies_env from app.core.runners.python.env import prepare_python_dependencies_env from app.logger import get_logger @@ -19,7 +20,10 @@ async def setup_dependencies(): logger.info("Preparing Python dependencies environment...") await prepare_python_dependencies_env() - logger.info("Python dependencies environment ready") + logger.info("Python Environment Ready ....") + logger.info("Preparing Nodejs dependencies environment...") + await prepare_nodejs_dependencies_env() + logger.info("Nodejs Environment Ready ...") except Exception as e: logger.error(f"Failed to setup dependencies: {e}") @@ -36,7 +40,7 @@ async def install_python_dependencies(): config = get_config() # Check if requirements file exists - req_file = Path("dependencies/python-requirements.txt") + req_file = 
Path("dependencies/python/python-requirements.txt") if not req_file.exists(): logger.warning("Python requirements file not found, skipping installation") return diff --git a/sandbox/app/logger.py b/sandbox/app/logger.py index de2ccc9e..9e63c8e5 100644 --- a/sandbox/app/logger.py +++ b/sandbox/app/logger.py @@ -12,25 +12,27 @@ def setup_logger() -> logging.Logger: """Setup application logger""" global _logger + if _logger is not None: + return _logger + config = get_config() # Create logger _logger = logging.getLogger("sandbox") _logger.setLevel(logging.DEBUG if config.app.debug else logging.INFO) - # Create console handler - handler = logging.StreamHandler(sys.stdout) - handler.setLevel(logging.DEBUG if config.app.debug else logging.INFO) + # 只在 logger 没有 handler 时才添加 + if not _logger.handlers: + handler = logging.StreamHandler(sys.stdout) + handler.setLevel(logging.DEBUG if config.app.debug else logging.INFO) - # Create formatter - formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' - ) - handler.setFormatter(formatter) + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + handler.setFormatter(formatter) - # Add handler to logger - _logger.addHandler(handler) + _logger.addHandler(handler) return _logger diff --git a/sandbox/app/middleware/concurrency.py b/sandbox/app/middleware/concurrency.py index 8d8325a4..e931f846 100644 --- a/sandbox/app/middleware/concurrency.py +++ b/sandbox/app/middleware/concurrency.py @@ -1,48 +1,66 @@ -"""Concurrency control middleware""" +""" +Concurrency control middleware +""" import asyncio +from contextlib import asynccontextmanager + from fastapi import HTTPException, status from app.config import get_config -from app.models import error_response +from app.logger import get_logger + +logger = get_logger() -# Global semaphores -_worker_semaphore: None | asyncio.Semaphore = None -_request_counter 
= 0 -_request_lock = asyncio.Lock() +class ConcurrencyController: + def __init__(self): + self._worker_semaphore: asyncio.Semaphore | None = None + self._request_counter = 0 + self._lock = asyncio.Lock() + + config = get_config() + self.max_requests = config.max_requests + + def init(self): + config = get_config() + self._worker_semaphore = asyncio.Semaphore(config.max_workers) + + async def _acquire_worker(self): + if self._worker_semaphore is None: + self.init() + async with self._worker_semaphore: + yield + + async def _limit_requests(self): + async with self._lock: + logger.info(f"Current requests: {self._request_counter}/{self.max_requests}") + if self._request_counter >= self.max_requests: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail={ + "code": 503, + "message": "Too many requests", + "data": None, + } + ) + self._request_counter += 1 + try: + yield + finally: + async with self._lock: + self._request_counter -= 1 + + def acquire_worker(self): + return asynccontextmanager(self._acquire_worker)() + + def limit_requests(self): + return asynccontextmanager(self._limit_requests)() -def init_concurrency_control(): - """Initialize concurrency control""" - global _worker_semaphore - config = get_config() - _worker_semaphore = asyncio.Semaphore(config.max_workers) +concurrency = ConcurrencyController() -async def check_max_requests(): - """Check if max requests limit is reached""" - global _request_counter - config = get_config() - - async with _request_lock: - if _request_counter >= config.max_requests: - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail=error_response(-503, "Too many requests") - ) - _request_counter += 1 - - try: - yield - finally: - async with _request_lock: - _request_counter -= 1 - - -async def acquire_worker(): - """Acquire a worker slot""" - if _worker_semaphore is None: - init_concurrency_control() - - async with _worker_semaphore: - yield +async def concurrency_guard(): + 
async with concurrency.limit_requests(): + async with concurrency.acquire_worker(): + yield diff --git a/sandbox/app/services/nodejs_service.py b/sandbox/app/services/nodejs_service.py new file mode 100644 index 00000000..ffd6127b --- /dev/null +++ b/sandbox/app/services/nodejs_service.py @@ -0,0 +1,43 @@ +"""Nodejs execution service""" +import signal + +from app.core.runners.nodejs.nodejs_runner import NodejsRunner +from app.logger import get_logger +from app.models import ( + success_response, + error_response, + RunCodeResponse, + RunnerOptions +) + + +async def run_nodejs_code(code: str, preload: str, options: RunnerOptions): + """Execute Node.js code in sandbox + + Args: + options: + code: Base64 encoded encrypted code + preload: Preload code + + Returns: + API response with execution result + """ + logger = get_logger() + + try: + runner = NodejsRunner() + result = await runner.run(code, options, preload) + if result.exit_code == signal.SIGSYS + 0x80: + return error_response(31, "sandbox security policy violation") + + if result.exit_code != 0: + return error_response(500, result.stderr) + + return success_response(RunCodeResponse( + stdout=result.stdout, + stderr=result.stderr + )) + + except Exception as e: + logger.error(f"Python execution failed: {e}", exc_info=True) + return error_response(-500, str(e)) diff --git a/sandbox/config.yaml b/sandbox/config.yaml index d9581b34..26fb9af3 100644 --- a/sandbox/config.yaml +++ b/sandbox/config.yaml @@ -1,13 +1,11 @@ app: - port: 8194 - debug: true key: redbear-sandbox -max_workers: 4 -max_requests: 50 -worker_timeout: 30 +max_workers: 10 +max_requests: 300 +worker_timeout: 15 python_path: /usr/local/bin/python -nodejs_path: /usr/local/bin/node +nodejs_path: /usr/bin/node enable_network: true enable_preload: false python_deps_update_interval: 30m diff --git a/sandbox/dependencies/nodejs/node_modules/.package-lock.json b/sandbox/dependencies/nodejs/node_modules/.package-lock.json new file mode 100644 index 
00000000..28b290ef --- /dev/null +++ b/sandbox/dependencies/nodejs/node_modules/.package-lock.json @@ -0,0 +1,6 @@ +{ + "name": "nodejs", + "lockfileVersion": 3, + "requires": true, + "packages": {} +} diff --git a/sandbox/dependencies/nodejs/package-lock.json b/sandbox/dependencies/nodejs/package-lock.json new file mode 100644 index 00000000..28b290ef --- /dev/null +++ b/sandbox/dependencies/nodejs/package-lock.json @@ -0,0 +1,6 @@ +{ + "name": "nodejs", + "lockfileVersion": 3, + "requires": true, + "packages": {} +} diff --git a/sandbox/dependencies/nodejs/package.json b/sandbox/dependencies/nodejs/package.json new file mode 100644 index 00000000..0967ef42 --- /dev/null +++ b/sandbox/dependencies/nodejs/package.json @@ -0,0 +1 @@ +{} diff --git a/sandbox/dependencies/python-requirements.txt b/sandbox/dependencies/python/python-requirements.txt similarity index 100% rename from sandbox/dependencies/python-requirements.txt rename to sandbox/dependencies/python/python-requirements.txt diff --git a/sandbox/lib/seccomp_nodejs/Cargo.lock b/sandbox/lib/seccomp_nodejs/Cargo.lock deleted file mode 100644 index b37698ee..00000000 --- a/sandbox/lib/seccomp_nodejs/Cargo.lock +++ /dev/null @@ -1,7 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. 
-version = 4 - -[[package]] -name = "seccomp_nodejs" -version = "0.1.0" diff --git a/sandbox/lib/seccomp_nodejs/Cargo.toml b/sandbox/lib/seccomp_nodejs/Cargo.toml deleted file mode 100644 index a8bd8932..00000000 --- a/sandbox/lib/seccomp_nodejs/Cargo.toml +++ /dev/null @@ -1,6 +0,0 @@ -[package] -name = "seccomp_nodejs" -version = "0.1.0" -edition = "2024" - -[dependencies] \ No newline at end of file diff --git a/sandbox/lib/seccomp_nodejs/src/lib.rs b/sandbox/lib/seccomp_nodejs/src/lib.rs deleted file mode 100644 index e69de29b..00000000 diff --git a/sandbox/lib/seccomp_python/Cargo.lock b/sandbox/lib/seccomp_redbear/Cargo.lock similarity index 92% rename from sandbox/lib/seccomp_python/Cargo.lock rename to sandbox/lib/seccomp_redbear/Cargo.lock index 881ad177..f81d17c0 100644 --- a/sandbox/lib/seccomp_python/Cargo.lock +++ b/sandbox/lib/seccomp_redbear/Cargo.lock @@ -15,8 +15,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60276e2d41bbb68b323e566047a1bfbf952050b157d8b5cdc74c07c1bf4ca3b6" [[package]] -name = "seccomp_python" -version = "0.1.0" +name = "seccomp_redbear" +version = "0.1.1" dependencies = [ "libc", "libseccomp-sys", diff --git a/sandbox/lib/seccomp_python/Cargo.toml b/sandbox/lib/seccomp_redbear/Cargo.toml similarity index 51% rename from sandbox/lib/seccomp_python/Cargo.toml rename to sandbox/lib/seccomp_redbear/Cargo.toml index 07037172..d6535987 100644 --- a/sandbox/lib/seccomp_python/Cargo.toml +++ b/sandbox/lib/seccomp_redbear/Cargo.toml @@ -1,12 +1,17 @@ [package] -name = "seccomp_python" -version = "0.1.0" +name = "seccomp_redbear" +version = "0.1.1" edition = "2024" [lib] -name = "python" +name = "sandbox" crate-type = ["cdylib"] [dependencies] libc = "0.2.180" libseccomp-sys = "0.3.0" + +[features] +default = [] +python3 = [] +nodejs = [] diff --git a/sandbox/lib/seccomp_python/src/lib.rs b/sandbox/lib/seccomp_redbear/src/lib.rs similarity index 82% rename from sandbox/lib/seccomp_python/src/lib.rs rename 
to sandbox/lib/seccomp_redbear/src/lib.rs index 08b46c54..9de38a56 100644 --- a/sandbox/lib/seccomp_python/src/lib.rs +++ b/sandbox/lib/seccomp_redbear/src/lib.rs @@ -1,13 +1,25 @@ -mod syscalls; +#[cfg(all(feature = "python3", feature = "nodejs"))] +compile_error!("Only one feature can be enabled: either python3 or nodejs, not both!"); -use crate::syscalls::*; -use libc::{chdir, chroot, gid_t, uid_t, c_int}; +#[cfg(not(any(feature = "python3", feature = "nodejs")))] +compile_error!("You must enable one feature: either python3 or nodejs"); + +#[cfg(feature = "python3")] +mod python_syscalls; +#[cfg(feature = "python3")] +use crate::python_syscalls::*; + +#[cfg(feature = "nodejs")] +mod nodejs_syscalls; +#[cfg(feature = "nodejs")] +use crate::nodejs_syscalls::*; + +use libc::{c_char, c_int, chdir, chroot, gid_t, uid_t}; use libseccomp_sys::*; use std::env; use std::ffi::CString; use std::str::FromStr; - /* * get_allowed_syscalls - retrieve allowed syscalls for the sandbox * @enable_network: enable network-related syscalls if non-zero @@ -193,3 +205,20 @@ pub unsafe extern "C" fn init_seccomp(uid: uid_t, gid: gid_t, enable_network: i3 Err(code) => code, } } + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn get_lib_version_static() -> *const c_char { + concat!(env!("CARGO_PKG_VERSION"), "\0").as_ptr() as *const c_char +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn get_lib_feature_static() -> *const c_char { + #[cfg(feature = "python3")] + let s = b"python3\0"; + #[cfg(feature = "nodejs")] + let s = b"nodejs\0"; + #[cfg(not(any(feature = "python3", feature = "nodejs")))] + let s = b"none\0"; + + s.as_ptr() as *const c_char +} diff --git a/sandbox/lib/seccomp_redbear/src/nodejs_syscalls.rs b/sandbox/lib/seccomp_redbear/src/nodejs_syscalls.rs new file mode 100644 index 00000000..7cf36664 --- /dev/null +++ b/sandbox/lib/seccomp_redbear/src/nodejs_syscalls.rs @@ -0,0 +1,74 @@ +// src/nodejs_syscalls.rs + +pub static ALLOW_SYSCALLS: &[i32] = &[ + // File IO + 
libc::SYS_open as i32, + libc::SYS_write as i32, + libc::SYS_close as i32, + libc::SYS_read as i32, + libc::SYS_openat as i32, + libc::SYS_newfstatat as i32, + libc::SYS_ioctl as i32, + libc::SYS_lseek as i32, + libc::SYS_fstat as i32, + libc::SYS_readlink as i32, + libc::SYS_dup3 as i32, + libc::SYS_fcntl as i32, + libc::SYS_fsync as i32, + // Memory + libc::SYS_mprotect as i32, + libc::SYS_mmap as i32, + libc::SYS_munmap as i32, + libc::SYS_mremap as i32, + libc::SYS_brk as i32, + libc::SYS_madvise as i32, + // Signal + libc::SYS_rt_sigaction as i32, + libc::SYS_rt_sigprocmask as i32, + libc::SYS_sigaltstack as i32, + libc::SYS_rt_sigreturn as i32, + libc::SYS_tgkill as i32, + // Thread + libc::SYS_futex as i32, + libc::SYS_sched_yield as i32, + libc::SYS_set_robust_list as i32, + libc::SYS_rseq as i32, + // User / Group + libc::SYS_getuid as i32, + // Process + libc::SYS_getpid as i32, + libc::SYS_gettid as i32, + libc::SYS_exit as i32, + libc::SYS_exit_group as i32, + libc::SYS_sched_getaffinity as i32, + // Time + libc::SYS_clock_gettime as i32, + libc::SYS_gettimeofday as i32, + libc::SYS_nanosleep as i32, + libc::SYS_time as i32, + // Epoll / Event (I/O multiplexing) + libc::SYS_epoll_ctl as i32, + libc::SYS_epoll_pwait as i32, +]; + +pub static ALLOW_ERROR_SYSCALLS: &[i32] = &[libc::SYS_clone as i32, libc::SYS_clone3 as i32]; + +pub static ALLOW_NETWORK_SYSCALLS: &[i32] = &[ + libc::SYS_socket as i32, + libc::SYS_connect as i32, + libc::SYS_bind as i32, + libc::SYS_listen as i32, + libc::SYS_accept as i32, + libc::SYS_sendto as i32, + libc::SYS_recvfrom as i32, + libc::SYS_getsockname as i32, + libc::SYS_recvmsg as i32, + libc::SYS_getpeername as i32, + libc::SYS_setsockopt as i32, + libc::SYS_ppoll as i32, + libc::SYS_uname as i32, + libc::SYS_sendmsg as i32, + libc::SYS_getsockopt as i32, + libc::SYS_fcntl as i32, + libc::SYS_fstatfs as i32, +]; diff --git a/sandbox/lib/seccomp_python/src/syscalls.rs b/sandbox/lib/seccomp_redbear/src/python_syscalls.rs 
similarity index 90% rename from sandbox/lib/seccomp_python/src/syscalls.rs rename to sandbox/lib/seccomp_redbear/src/python_syscalls.rs index 961fffac..998ae390 100644 --- a/sandbox/lib/seccomp_python/src/syscalls.rs +++ b/sandbox/lib/seccomp_redbear/src/python_syscalls.rs @@ -1,7 +1,7 @@ -// src/syscalls.rs +// src/python_syscalls.rs pub static ALLOW_SYSCALLS: &[i32] = &[ - // file io + // File IO libc::SYS_read as i32, libc::SYS_write as i32, libc::SYS_openat as i32, @@ -11,48 +11,44 @@ pub static ALLOW_SYSCALLS: &[i32] = &[ libc::SYS_lseek as i32, libc::SYS_getdents64 as i32, libc::SYS_fstat as i32, - - // thread + // Signal + libc::SYS_rt_sigreturn as i32, + libc::SYS_rt_sigaction as i32, + libc::SYS_rt_sigprocmask as i32, + libc::SYS_sigaltstack as i32, + libc::SYS_tgkill as i32, + // Thread libc::SYS_futex as i32, - - // memory + // Memory libc::SYS_mmap as i32, libc::SYS_brk as i32, libc::SYS_mprotect as i32, libc::SYS_munmap as i32, - libc::SYS_rt_sigreturn as i32, libc::SYS_mremap as i32, - - // user / group - libc::SYS_setuid as i32, - libc::SYS_setgid as i32, + // User / Group libc::SYS_getuid as i32, - - // process + // Process libc::SYS_getpid as i32, libc::SYS_getppid as i32, libc::SYS_gettid as i32, libc::SYS_exit as i32, libc::SYS_exit_group as i32, - libc::SYS_tgkill as i32, - libc::SYS_rt_sigaction as i32, libc::SYS_sched_yield as i32, libc::SYS_set_robust_list as i32, libc::SYS_get_robust_list as i32, libc::SYS_rseq as i32, - - // time + // Time libc::SYS_clock_gettime as i32, libc::SYS_gettimeofday as i32, + libc::SYS_time as i32, libc::SYS_nanosleep as i32, + libc::SYS_clock_nanosleep as i32, + // Epoll / Event (I/O multiplexing) libc::SYS_epoll_create1 as i32, libc::SYS_epoll_ctl as i32, - libc::SYS_clock_nanosleep as i32, libc::SYS_pselect6 as i32, - libc::SYS_rt_sigprocmask as i32, - libc::SYS_sigaltstack as i32, + // Randomness libc::SYS_getrandom as i32, - ]; pub static ALLOW_ERROR_SYSCALLS: &[i32] = &[ diff --git a/sandbox/main.py 
b/sandbox/main.py index fc417563..99b7b0a6 100644 --- a/sandbox/main.py +++ b/sandbox/main.py @@ -11,51 +11,15 @@ from fastapi import FastAPI from app.config import get_config from app.controllers import manager_router +from app.core.runners import init_sandbox_user from app.dependencies import setup_dependencies, update_dependencies_periodically from app.logger import setup_logger, get_logger +setup_logger() +config = get_config() logger = get_logger() -@asynccontextmanager -async def lifespan(app: FastAPI): - """Application lifespan manager""" - logger = get_logger() - - # Startup - logger.info("Starting RedBear Sandbox...") - - # Setup dependencies in background - asyncio.create_task(setup_dependencies()) - - # Start periodic dependency updates - config = get_config() - if config.python_deps_update_interval: - asyncio.create_task(update_dependencies_periodically()) - - yield - - # Shutdown - logger.info("Shutting down Redbear Sandbox...") - - -def create_app() -> FastAPI: - """Create FastAPI application""" - config = get_config() - - app = FastAPI( - title="Sandbox", - description="Secure code execution sandbox", - version="2.0.0", - lifespan=lifespan, - debug=config.app.debug - ) - - app.include_router(manager_router) - - return app - - def check_root_privileges(): """Check if running with root privileges""" if os.geteuid() != 0: @@ -63,35 +27,38 @@ def check_root_privileges(): sys.exit(1) -def main(): - """Main entry point""" - # Check root privileges - check_root_privileges() +check_root_privileges() - # Setup logging - setup_logger() - config = get_config() +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan manager""" logger = get_logger() - + config = get_config() + # Startup + logger.info("Starting RedBear Sandbox...") logger.info(f"Starting server on port {config.app.port}") logger.info(f"Debug mode: {config.app.debug}") logger.info(f"Max workers: {config.max_workers}") logger.info(f"Max requests: {config.max_requests}") 
logger.info(f"Network enabled: {config.enable_network}") + init_sandbox_user() + await setup_dependencies() - # Create app - app = create_app() + if config.python_deps_update_interval: + asyncio.create_task(update_dependencies_periodically()) - # Run server - uvicorn.run( - app, - host="0.0.0.0", - port=config.app.port, - log_level="debug" if config.app.debug else "info", - access_log=config.app.debug - ) + yield + # Shutdown + logger.info("Shutting down Redbear Sandbox...") -if __name__ == "__main__": - main() +app = FastAPI( + title="Sandbox", + description="Secure code execution sandbox", + version="0.1.0", + lifespan=lifespan, + debug=config.app.debug +) + +app.include_router(manager_router) diff --git a/web/src/api/ontology.ts b/web/src/api/ontology.ts new file mode 100644 index 00000000..4213d362 --- /dev/null +++ b/web/src/api/ontology.ts @@ -0,0 +1,39 @@ +import { request } from '@/utils/request' +import type { Query, OntologyModalData, OntologyClassModalData, OntologyClassExtractModalData } from '@/views/Ontology/types' + +// Scene list +export const getOntologyScenesUrl = '/memory/ontology/scenes' +export const getOntologyScenesList = (data: Query) => { + return request.get(getOntologyScenesUrl, data) +} + +// Create scene +export const createOntologyScene = (data: OntologyModalData) => { + return request.post('/memory/ontology/scene', data) +} +// Update scene +export const updateOntologyScene = (scene_id: string, data: OntologyModalData) => { + return request.put(`/memory/ontology/scene/${scene_id}`, data) +} +// Delete scene +export const deleteOntologyScene = (scene_id: string) => { + return request.delete(`/memory/ontology/scene/${scene_id}`) +} + +// Get class list +export const getOntologyclassesUrl = '/memory/ontology/classes' +export const getOntologyClassList = (data: { scene_id: string; class_name?: string; }) => { + return request.get(getOntologyclassesUrl, data) +} +// Extract ontology types +export const extractOntologyTypes = (data: 
OntologyClassExtractModalData) => { + return request.post('/memory/ontology/extract', data) +} +// Create ontology class +export const createOntologyClass = (data: OntologyClassModalData) => { + return request.post('/memory/ontology/class', data) +} +// Delete ontology class +export const deleteOntologyClass = (class_id: string) => { + return request.delete(`/memory/ontology/class/${class_id}`) +} diff --git a/web/src/api/prompt.ts b/web/src/api/prompt.ts index 526f50ac..79ea374c 100644 --- a/web/src/api/prompt.ts +++ b/web/src/api/prompt.ts @@ -1,13 +1,26 @@ import { request } from '@/utils/request' import type { AiPromptForm } from '@/views/ApplicationConfig/types' +import type { PromptReleaseData } from '@/views/Prompt/types' import { handleSSE, type SSEMessage } from '@/utils/stream' +// Create session export const createPromptSessions = () => { return request.post(`/prompt/sessions`) } -export const getPrompt = (session_id: string) => { - return request.get(`/prompt/sessions/${session_id}`) -} +// Get prompt optimization export const updatePromptMessages = (session_id: string, data: AiPromptForm, onMessage?: (data: SSEMessage[]) => void) => { return handleSSE(`/prompt/sessions/${session_id}/messages`, data, onMessage) +} +// Prompt release list +export const getPromptReleaseListUrl = '/prompt/releases/list' +export const getPromptReleaseList = () => { + return request.get(getPromptReleaseListUrl) +} +// Save prompt +export const savePrompt = (data: PromptReleaseData) => { + return request.post('/prompt/releases', data) +} +// Delete prompt +export const deletePrompt = (prompt_id: string) => { + return request.delete(`/prompt/releases/${prompt_id}`) } \ No newline at end of file diff --git a/web/src/assets/images/menu/ontology.svg b/web/src/assets/images/menu/ontology.svg new file mode 100644 index 00000000..9bfda42b --- /dev/null +++ b/web/src/assets/images/menu/ontology.svg @@ -0,0 +1,11 @@ + + + 本体管理备份 + + + + + + + + \ No newline at end of file diff --git 
a/web/src/assets/images/menu/ontology_active.svg b/web/src/assets/images/menu/ontology_active.svg new file mode 100644 index 00000000..1271c2c3 --- /dev/null +++ b/web/src/assets/images/menu/ontology_active.svg @@ -0,0 +1,11 @@ + + + 本体管理 + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menu/prompt.svg b/web/src/assets/images/menu/prompt.svg new file mode 100644 index 00000000..ffef9a34 --- /dev/null +++ b/web/src/assets/images/menu/prompt.svg @@ -0,0 +1,15 @@ + + + 提示词备份 + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menu/prompt_active.svg b/web/src/assets/images/menu/prompt_active.svg new file mode 100644 index 00000000..ac45e13c --- /dev/null +++ b/web/src/assets/images/menu/prompt_active.svg @@ -0,0 +1,15 @@ + + + 提示词 + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/space/neo4j.png b/web/src/assets/images/space/neo4j.png new file mode 100644 index 00000000..74fc7a86 Binary files /dev/null and b/web/src/assets/images/space/neo4j.png differ diff --git a/web/src/assets/images/space/rag.png b/web/src/assets/images/space/rag.png new file mode 100644 index 00000000..4506efda Binary files /dev/null and b/web/src/assets/images/space/rag.png differ diff --git a/web/src/components/CustomSelect/index.tsx b/web/src/components/CustomSelect/index.tsx index 1887d635..f93014c9 100644 --- a/web/src/components/CustomSelect/index.tsx +++ b/web/src/components/CustomSelect/index.tsx @@ -1,4 +1,4 @@ -import { useEffect, useState, type FC, type Key } from 'react'; +import { useEffect, useState, useMemo, type FC, type Key } from 'react'; import { Select } from 'antd'; import type { SelectProps, DefaultOptionType } from 'antd/es/select'; import { useTranslation } from 'react-i18next'; @@ -47,13 +47,14 @@ const CustomSelect: FC = ({ }) => { const { t } = useTranslation(); const [options, setOptions] = useState([]); + const memoizedParams = useMemo(() => params, 
[JSON.stringify(params)]); useEffect(() => { - request.get>(url, params).then((res) => { + request.get>(url, memoizedParams).then((res) => { const data = Array.isArray(res) ? res : res?.items || []; setOptions(data); }); - }, [url, params]); + }, [url, memoizedParams]); const displayOptions = format ? format(options) : options; diff --git a/web/src/components/Empty/BodyWrapper.tsx b/web/src/components/Empty/BodyWrapper.tsx index f9978184..9cdeb0e8 100644 --- a/web/src/components/Empty/BodyWrapper.tsx +++ b/web/src/components/Empty/BodyWrapper.tsx @@ -1,6 +1,6 @@ import type { FC, ReactNode } from 'react' -import { Skeleton } from 'antd' -import Empty from './index' +import PageEmpty from './PageEmpty' +import PageLoading from './PageLoading' interface BodyWrapperProps { children: ReactNode @@ -9,10 +9,10 @@ interface BodyWrapperProps { } const BodyWrapper: FC = ({ children, loading = false, empty }) => { if (loading) { - return + return } if (!loading && empty) { - return + return } return children } diff --git a/web/src/components/Markdown/index.tsx b/web/src/components/Markdown/index.tsx index 58650207..6737f15a 100644 --- a/web/src/components/Markdown/index.tsx +++ b/web/src/components/Markdown/index.tsx @@ -19,6 +19,7 @@ interface RbMarkdownProps { showHtmlComments?: boolean; // 是否显示 HTML 注释,默认为 false(隐藏) editable?: boolean; // 是否可编辑,默认为 false onContentChange?: (content: string) => void; // 内容变化回调 + className?: string; } const components = { @@ -98,6 +99,7 @@ const RbMarkdown: FC = ({ showHtmlComments = false, editable = false, onContentChange, + className }) => { const [editContent, setEditContent] = useState(content) const textareaRef = useRef(null) @@ -162,7 +164,7 @@ const RbMarkdown: FC = ({ // 预览模式 return ( -
+