diff --git a/api/app/controllers/service/__init__.py b/api/app/controllers/service/__init__.py index 96da0949..52d4b732 100644 --- a/api/app/controllers/service/__init__.py +++ b/api/app/controllers/service/__init__.py @@ -4,7 +4,17 @@ 认证方式: API Key """ from fastapi import APIRouter -from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller, end_user_api_controller + +from . import ( + app_api_controller, + end_user_api_controller, + memory_api_controller, + memory_config_api_controller, + rag_api_chunk_controller, + rag_api_document_controller, + rag_api_file_controller, + rag_api_knowledge_controller, +) # 创建 V1 API 路由器 service_router = APIRouter() @@ -17,5 +27,6 @@ service_router.include_router(rag_api_file_controller.router) service_router.include_router(rag_api_chunk_controller.router) service_router.include_router(memory_api_controller.router) service_router.include_router(end_user_api_controller.router) +service_router.include_router(memory_config_api_controller.router) __all__ = ["service_router"] diff --git a/api/app/controllers/service/end_user_api_controller.py b/api/app/controllers/service/end_user_api_controller.py index 92a9d7c8..1faea6ef 100644 --- a/api/app/controllers/service/end_user_api_controller.py +++ b/api/app/controllers/service/end_user_api_controller.py @@ -5,6 +5,7 @@ import uuid from fastapi import APIRouter, Body, Depends, Request from sqlalchemy.orm import Session +from app.controllers import user_memory_controllers from app.core.api_key_auth import require_api_key from app.core.error_codes import BizCode from app.core.exceptions import BusinessException @@ -14,13 +15,31 @@ from app.core.response_utils import success from app.db import get_db from app.repositories.end_user_repository import EndUserRepository from app.schemas.api_key_schema import ApiKeyAuth +from app.schemas.end_user_info_schema import EndUserInfoUpdate from app.schemas.memory_api_schema import CreateEndUserRequest, CreateEndUserResponse +from app.services import api_key_service from app.services.memory_config_service import MemoryConfigService router = APIRouter(prefix="/end_user", tags=["V1 - End User API"]) logger = get_business_logger() +def _get_current_user(api_key_auth: ApiKeyAuth, db: Session): + """Build a current_user object from API key auth + + Args: + api_key_auth: Validated API key auth info + db: Database session + + Returns: + User object with current_workspace_id set + """ + api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id) + current_user = api_key.creator + current_user.current_workspace_id = api_key_auth.workspace_id + return current_user + + @router.post("/create") @require_api_key(scopes=["memory"]) @check_end_user_quota @@ -39,6 +58,7 @@ async def create_end_user( Optionally accepts a memory_config_id to connect the end user to a specific memory configuration. If not provided, falls back to the workspace default config. + Optionally accepts an app_id to bind the end user to a specific app. """ body = await request.json() payload = CreateEndUserRequest(**body) @@ -73,14 +93,26 @@ async def create_end_user( else: logger.warning(f"No default memory config found for workspace: {workspace_id}") + # Resolve app_id: explicit from payload, otherwise None + app_id = None + if payload.app_id: + try: + app_id = uuid.UUID(payload.app_id) + except ValueError: + raise BusinessException( + f"Invalid app_id format: {payload.app_id}", + BizCode.INVALID_PARAMETER + ) + end_user_repo = EndUserRepository(db) end_user = end_user_repo.get_or_create_end_user_with_config( - app_id=api_key_auth.resource_id, + app_id=app_id, workspace_id=workspace_id, other_id=payload.other_id, memory_config_id=memory_config_id, + other_name=payload.other_name, ) - + end_user.other_name = payload.other_name logger.info(f"End user ready: {end_user.id}") result = { @@ -92,3 +124,50 @@ async def create_end_user( } return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully") + + +@router.get("/info") +@require_api_key(scopes=["memory"]) +async def get_end_user_info( + request: Request, + end_user_id: str, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), +): + """ + Get end user info. + + Retrieves the info record (aliases, meta_data, etc.) for the specified end user. + Delegates to the manager-side controller for shared logic. + """ + current_user = _get_current_user(api_key_auth, db) + return await user_memory_controllers.get_end_user_info( + end_user_id=end_user_id, + current_user=current_user, + db=db, + ) + + +@router.post("/info/update") +@require_api_key(scopes=["memory"]) +async def update_end_user_info( + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), + message: str = Body(None, description="Request body"), +): + """ + Update end user info. + + Updates the info record (other_name, aliases, meta_data) for the specified end user. + Delegates to the manager-side controller for shared logic. + """ + body = await request.json() + payload = EndUserInfoUpdate(**body) + + current_user = _get_current_user(api_key_auth, db) + return await user_memory_controllers.update_end_user_info( + info_update=payload, + current_user=current_user, + db=db, + ) diff --git a/api/app/controllers/service/memory_api_controller.py b/api/app/controllers/service/memory_api_controller.py index 16f1e223..313781d2 100644 --- a/api/app/controllers/service/memory_api_controller.py +++ b/api/app/controllers/service/memory_api_controller.py @@ -1,5 +1,8 @@ """Memory 服务接口 - 基于 API Key 认证""" +from fastapi import APIRouter, Body, Depends, Query, Request +from sqlalchemy.orm import Session + from app.core.api_key_auth import require_api_key from app.core.logging_config import get_business_logger from app.core.quota_stub import check_end_user_quota @@ -7,48 +10,74 @@ from app.core.response_utils import success from app.db import get_db from app.schemas.api_key_schema import ApiKeyAuth from app.schemas.memory_api_schema import ( - CreateEndUserRequest, - CreateEndUserResponse, - ListConfigsResponse, MemoryReadRequest, MemoryReadResponse, + MemoryReadSyncResponse, MemoryWriteRequest, MemoryWriteResponse, + MemoryWriteSyncResponse, ) from app.services.memory_api_service import MemoryAPIService -from fastapi import APIRouter, Body, Depends, Request -from sqlalchemy.orm import Session router = APIRouter(prefix="/memory", tags=["V1 - Memory API"]) logger = get_business_logger() +def _sanitize_task_result(result: dict) -> dict: + """Make Celery task result JSON-serializable. + + Converts UUID and other non-serializable values to strings. + + Args: + result: Raw task result dict from task_service + + Returns: + JSON-safe dict + """ + import uuid as _uuid + from datetime import datetime + + def _convert(obj): + if isinstance(obj, dict): + return {k: _convert(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_convert(i) for i in obj] + if isinstance(obj, _uuid.UUID): + return str(obj) + if isinstance(obj, datetime): + return obj.isoformat() + return obj + + return _convert(result) + + @router.get("") async def get_memory_info(): """获取记忆服务信息(占位)""" return success(data={}, msg="Memory API - Coming Soon") -@router.post("/write_api_service") +@router.post("/write") @require_api_key(scopes=["memory"]) -async def write_memory_api_service( +async def write_memory( request: Request, api_key_auth: ApiKeyAuth = None, db: Session = Depends(get_db), message: str = Body(..., description="Message content"), ): """ - Write memory to storage. - - Stores memory content for the specified end user using the Memory API Service. + Submit a memory write task. + + Validates the end user, then dispatches the write to a Celery background task + with per-user fair locking. Returns a task_id for status polling. """ body = await request.json() payload = MemoryWriteRequest(**body) logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}") - + memory_api_service = MemoryAPIService(db) - - result = await memory_api_service.write_memory( + + result = memory_api_service.write_memory( workspace_id=api_key_auth.workspace_id, end_user_id=payload.end_user_id, message=payload.message, @@ -56,31 +85,53 @@ async def write_memory_api_service( storage_type=payload.storage_type, user_rag_memory_id=payload.user_rag_memory_id, ) - - logger.info(f"Memory write successful for end_user: {payload.end_user_id}") - return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory written successfully") + + logger.info(f"Memory write task submitted: task_id={result['task_id']}, end_user_id: {payload.end_user_id}") + return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory write task submitted") -@router.post("/read_api_service") +@router.get("/write/status") @require_api_key(scopes=["memory"]) -async def read_memory_api_service( +async def get_write_task_status( + request: Request, + task_id: str = Query(..., description="Celery task ID"), + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), +): + """ + Check the status of a memory write task. + + Returns the current status and result (if completed) of a previously submitted write task. + """ + logger.info(f"Write task status check - task_id: {task_id}") + + from app.services.task_service import get_task_memory_write_result + result = get_task_memory_write_result(task_id) + + return success(data=_sanitize_task_result(result), msg="Task status retrieved") + + +@router.post("/read") +@require_api_key(scopes=["memory"]) +async def read_memory( request: Request, api_key_auth: ApiKeyAuth = None, db: Session = Depends(get_db), message: str = Body(..., description="Query message"), ): """ - Read memory from storage. - - Queries and retrieves memories for the specified end user with context-aware responses. + Submit a memory read task. + + Validates the end user, then dispatches the read to a Celery background task. + Returns a task_id for status polling. """ body = await request.json() payload = MemoryReadRequest(**body) logger.info(f"Memory read request - end_user_id: {payload.end_user_id}") - + memory_api_service = MemoryAPIService(db) - - result = await memory_api_service.read_memory( + + result = memory_api_service.read_memory( workspace_id=api_key_auth.workspace_id, end_user_id=payload.end_user_id, message=payload.message, @@ -89,59 +140,95 @@ async def read_memory_api_service( storage_type=payload.storage_type, user_rag_memory_id=payload.user_rag_memory_id, ) - - logger.info(f"Memory read successful for end_user: {payload.end_user_id}") - return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read successfully") + + logger.info(f"Memory read task submitted: task_id={result['task_id']}, end_user_id: {payload.end_user_id}") + return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read task submitted") -@router.get("/configs") +@router.get("/read/status") @require_api_key(scopes=["memory"]) -async def list_memory_configs( +async def get_read_task_status( request: Request, + task_id: str = Query(..., description="Celery task ID"), api_key_auth: ApiKeyAuth = None, db: Session = Depends(get_db), ): """ - List all memory configs for the workspace. - - Returns all available memory configurations associated with the authorized workspace. + Check the status of a memory read task. + + Returns the current status and result (if completed) of a previously submitted read task. """ - logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}") + logger.info(f"Read task status check - task_id: {task_id}") - memory_api_service = MemoryAPIService(db) + from app.services.task_service import get_task_memory_read_result + result = get_task_memory_read_result(task_id) - result = memory_api_service.list_memory_configs( - workspace_id=api_key_auth.workspace_id, - ) - - logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}") - return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully") + return success(data=_sanitize_task_result(result), msg="Task status retrieved") -@router.post("/end_users") +@router.post("/write/sync") @require_api_key(scopes=["memory"]) @check_end_user_quota -async def create_end_user( +async def write_memory_sync( request: Request, api_key_auth: ApiKeyAuth = None, db: Session = Depends(get_db), + message: str = Body(..., description="Message content"), ): """ - Create an end user. - - Creates a new end user for the authorized workspace. - If an end user with the same other_id already exists, returns the existing one. + Write memory synchronously. + + Blocks until the write completes and returns the result directly. + For async processing with task polling, use /write instead. """ body = await request.json() - payload = CreateEndUserRequest(**body) - logger.info(f"Create end user request - other_id: {payload.other_id}, workspace_id: {api_key_auth.workspace_id}") + payload = MemoryWriteRequest(**body) + logger.info(f"Memory write (sync) request - end_user_id: {payload.end_user_id}") memory_api_service = MemoryAPIService(db) - result = memory_api_service.create_end_user( + result = await memory_api_service.write_memory_sync( workspace_id=api_key_auth.workspace_id, - other_id=payload.other_id, + end_user_id=payload.end_user_id, + message=payload.message, + config_id=payload.config_id, + storage_type=payload.storage_type, + user_rag_memory_id=payload.user_rag_memory_id, ) - logger.info(f"End user ready: {result['id']}") - return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully") + logger.info(f"Memory write (sync) successful for end_user: {payload.end_user_id}") + return success(data=MemoryWriteSyncResponse(**result).model_dump(), msg="Memory written successfully") + + +@router.post("/read/sync") +@require_api_key(scopes=["memory"]) +async def read_memory_sync( + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), + message: str = Body(..., description="Query message"), +): + """ + Read memory synchronously. + + Blocks until the read completes and returns the answer directly. + For async processing with task polling, use /read instead. + """ + body = await request.json() + payload = MemoryReadRequest(**body) + logger.info(f"Memory read (sync) request - end_user_id: {payload.end_user_id}") + + memory_api_service = MemoryAPIService(db) + + result = await memory_api_service.read_memory_sync( + workspace_id=api_key_auth.workspace_id, + end_user_id=payload.end_user_id, + message=payload.message, + search_switch=payload.search_switch, + config_id=payload.config_id, + storage_type=payload.storage_type, + user_rag_memory_id=payload.user_rag_memory_id, + ) + + logger.info(f"Memory read (sync) successful for end_user: {payload.end_user_id}") + return success(data=MemoryReadSyncResponse(**result).model_dump(), msg="Memory read successfully") diff --git a/api/app/controllers/service/memory_config_api_controller.py b/api/app/controllers/service/memory_config_api_controller.py new file mode 100644 index 00000000..1e61e0af --- /dev/null +++ b/api/app/controllers/service/memory_config_api_controller.py @@ -0,0 +1,491 @@ +"""Memory Config 服务接口 - 基于 API Key 认证""" + +from typing import Optional +import uuid + +from fastapi import APIRouter, Body, Depends, Header, Query, Request +from fastapi.encoders import jsonable_encoder +from sqlalchemy.orm import Session + +from app.controllers import memory_storage_controller +from app.controllers import memory_forget_controller +from app.controllers import ontology_controller +from app.controllers import emotion_config_controller +from app.controllers import memory_reflection_controller +from app.schemas.memory_storage_schema import ForgettingConfigUpdateRequest +from app.controllers.emotion_config_controller import EmotionConfigUpdate +from app.schemas.memory_reflection_schemas import Memory_Reflection +from app.core.api_key_auth import require_api_key +from app.core.error_codes import BizCode +from app.core.exceptions import BusinessException +from app.core.logging_config import get_business_logger +from app.core.response_utils import success +from app.db import get_db +from app.repositories.memory_config_repository import MemoryConfigRepository +from app.schemas.api_key_schema import ApiKeyAuth +from app.schemas.memory_api_schema import ( + ConfigUpdateExtractedRequest, + ConfigUpdateRequest, + ListConfigsResponse, + ConfigCreateRequest, + ConfigUpdateForgettingRequest, + EmotionConfigUpdateRequest, + ReflectionConfigUpdateRequest, +) +from app.schemas.memory_storage_schema import ( + ConfigUpdate, + ConfigUpdateExtracted, + ConfigParamsCreate, +) +from app.services import api_key_service +from app.services.memory_api_service import MemoryAPIService +from app.utils.config_utils import resolve_config_id + +router = APIRouter(prefix="/memory_config", tags=["V1 - Memory Config API"]) +logger = get_business_logger() + + +def _get_current_user(api_key_auth: ApiKeyAuth, db: Session): + """Build a current_user object from API key auth + + Args: + api_key_auth: Validated API key auth info + db: Database session + + Returns: + User object with current_workspace_id set + """ + api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id) + current_user = api_key.creator + current_user.current_workspace_id = api_key_auth.workspace_id + return current_user + + +def _verify_config_ownership(config_id:str, workspace_id:uuid.UUID, db:Session): + """Verify that the config belongs to the workspace. + + Args: + config_id: The ID of the config to verify + workspace_id: The workspace ID tocheck against + db: Database session for querying + Raises: + BusinessException: If the config does not exist or does not belong to the workspace + """ + try: + resolved_id = resolve_config_id(config_id, db) + except ValueError as e: + raise BusinessException( + message=f"Invalid config_id: {e}", + code=BizCode.INVALID_PARAMETER, + ) + config = MemoryConfigRepository.get_by_id(db, resolved_id) + if not config or config.workspace_id != workspace_id: + raise BusinessException( + message="Config not found or access denied", + code=BizCode.MEMORY_CONFIG_NOT_FOUND, + ) + +# @router.get("/configs") +# @require_api_key(scopes=["memory"]) +# async def list_memory_configs( +# request: Request, +# api_key_auth: ApiKeyAuth = None, +# db: Session = Depends(get_db), +# ): +# """ +# List all memory configs for the workspace. + +# Returns all available memory configurations associated with the authorized workspace. +# """ +# logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}") + +# memory_api_service = MemoryAPIService(db) + +# result = memory_api_service.list_memory_configs( +# workspace_id=api_key_auth.workspace_id, +# ) + +# logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}") +# return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully") + +@router.get("/read_all_config") +@require_api_key(scopes=["memory"]) +async def read_all_config( + request:Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), +): + """ + List all memory configs with full details (enhanced version). + + Returns complete config fields for the authorized workspace. + No config_id ownership check needed — results are filtered by workspace. + """ + logger.info(f"V1 get all configs (full) - workspace: {api_key_auth.workspace_id}") + + current_user = _get_current_user(api_key_auth, db) + + return memory_storage_controller.read_all_config( + current_user=current_user, + db=db, + ) + +@router.get("/scenes/simple") +@require_api_key(scopes=["memory"]) +async def get_ontology_scenes( + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), +): + """ + Get available ontology scenes for the workspace. + + Returns a simple list of scene_id and scene_name for dropdown selection. + Used before creating a memory config to choose which ontology scene to associate. + """ + logger.info(f"V1 get scenes - workspace: {api_key_auth.workspace_id}") + + current_user = _get_current_user(api_key_auth, db) + + return await ontology_controller.get_scenes_simple( + db=db, + current_user=current_user, + ) + +@router.get("/read_config_extracted") +@require_api_key(scopes=["memory"]) +async def read_config_extracted( + request: Request, + config_id: str = Query(..., description="config_id"), + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), +): + """ + Get extraction engine config details for a specific config. + + Only configs belonging to the authorized workspace can be queried. + """ + logger.info(f"V1 read extracted config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}") + + _verify_config_ownership(config_id, api_key_auth.workspace_id, db) + + current_user = _get_current_user(api_key_auth, db) + + return memory_storage_controller.read_config_extracted( + config_id = config_id, + current_user = current_user, + db = db, + ) + +@router.get("/read_config_forgetting") +@require_api_key(scopes=["memory"]) +async def read_config_forgetting( + request: Request, + config_id: str = Query(..., description="config_id"), + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), +): + """ + Get forgetting settings for a specific memory config. + + Only configs belonging to the authorized workspace can be queried. + """ + logger.info(f"V1 read forgetting config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}") + + _verify_config_ownership(config_id, api_key_auth.workspace_id, db) + + current_user = _get_current_user(api_key_auth, db) + + result = await memory_forget_controller.read_forgetting_config( + config_id = config_id, + current_user = current_user, + db = db, + ) + return jsonable_encoder(result) + + + +@router.get("/read_config_emotion") +@require_api_key(scopes=["memory"]) +async def read_config_emotion( + request: Request, + config_id: str = Query(..., description="config_id"), + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), +): + """ + Get emotion engine config details for a specific config. + + Only configs belonging to the authorized workspace can be queried. + """ + logger.info(f"V1 read emotion config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}") + + _verify_config_ownership(config_id, api_key_auth.workspace_id, db) + + current_user = _get_current_user(api_key_auth, db) + + return jsonable_encoder(emotion_config_controller.get_emotion_config( + config_id=config_id, + db=db, + current_user=current_user, + )) + +@router.get("/read_config_reflection") +@require_api_key(scopes=["memory"]) +async def read_config_reflection( + request: Request, + config_id: str = Query(..., description="config_id"), + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), +): + """ + Get reflection engine config details for a specific config. + + Only configs belonging to the authorized workspace can be queried. + """ + logger.info(f"V1 read reflection config - config_id: {config_id}, workspace: {api_key_auth.workspace_id}") + + _verify_config_ownership(config_id, api_key_auth.workspace_id, db) + + current_user = _get_current_user(api_key_auth, db) + + return jsonable_encoder(await memory_reflection_controller.start_reflection_configs( + config_id=config_id, + current_user=current_user, + db=db, + )) + + +@router.post("/create_config") +@require_api_key(scopes=["memory"]) +async def create_memory_config( + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), + message: str = Body(None, description="Request body"), + x_language_type: Optional[str] = Header(None, alias="X-Language-Type"), +): + """ + Create a new memory config for the workspace. + + The config will be associated with the workspace of the API Key. + config_name is required, other fields are optional. + """ + body = await request.json() + payload = ConfigCreateRequest(**body) + + logger.info(f"V1 create config - workspace: {api_key_auth.workspace_id}, config_name: {payload.config_name}") + + # 构造管理端 Schema,workspace_id 从 API Key 注入 + current_user = _get_current_user(api_key_auth, db) + mgmt_payload = ConfigParamsCreate( + config_name=payload.config_name, + config_desc=payload.config_desc or "", + scene_id=payload.scene_id, + llm_id=payload.llm_id, + embedding_id=payload.embedding_id, + rerank_id=payload.rerank_id, + reflection_model_id=payload.reflection_model_id, + emotion_model_id=payload.emotion_model_id, + ) + #将返回数据中UUID序列化处理 + result =memory_storage_controller.create_config( + payload=mgmt_payload, + current_user=current_user, + db=db, + x_language_type=x_language_type, + ) + return jsonable_encoder(result) + +@router.put("/update_config") +@require_api_key(scopes=["memory"]) +async def update_memory_config( + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), + message: str = Body(None, description="Request body"), +): + """ + Update memory config basic info (name, description, scene). + + Requires API Key with 'memory' scope + Only configs belonging to the authorized workspace can be updated. + """ + body = await request.json() + payload = ConfigUpdateRequest(**body) + + logger.info(f"V1 update config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}") + + _verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db) + + current_user = _get_current_user(api_key_auth, db) + mgmt_payload = ConfigUpdate( + config_id = payload.config_id, + config_name = payload.config_name, + config_desc = payload.config_desc, + scene_id = payload.scene_id, + ) + + return memory_storage_controller.update_config( + payload = mgmt_payload, + current_user = current_user, + db = db, + ) + +@router.put("/update_config_extracted") +@require_api_key(scopes=["memory"]) +async def update_memory_config_extracted( + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), + message: str = Body(None, description="Request body"), +): + """ + update memory config extraction engine config (models, thresholds, chunking, pruning, etc.). + + Requires API Key with 'memory' scope. + Only configs belonging to the authorized workspace can be updated. + """ + body = await request.json() + payload = ConfigUpdateExtractedRequest(**body) + + logger.info(f"V1 update extracted config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}") + + #校验权限 + _verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db) + + current_user = _get_current_user(api_key_auth, db) + update_fields = payload.model_dump(exclude_unset=True) + mgmt_payload = ConfigUpdateExtracted(**update_fields) + + return memory_storage_controller.update_config_extracted( + payload = mgmt_payload, + current_user = current_user, + db = db, + ) + +@router.put("/update_config_forgetting") +@require_api_key(scopes=["memory"]) +async def update_memory_config_forgetting( + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), + message: str = Body(None, description="Request body"), +): + """ + update memory config forgetting settings (forgetting strategy, parameters, etc.). + + Requires API Key with 'memory' scope. + Only configs belonging to the authorized workspace can be updated. + """ + body = await request.json() + payload = ConfigUpdateForgettingRequest(**body) + + logger.info(f"V1 update forgetting config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}") + + #校验权限 + _verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db) + + current_user = _get_current_user(api_key_auth, db) + update_fields = payload.model_dump(exclude_unset=True) + mgmt_payload = ForgettingConfigUpdateRequest(**update_fields) + + #将返回数据中UUID序列化处理 + result = await memory_forget_controller.update_forgetting_config( + payload = mgmt_payload, + current_user = current_user, + db = db, + ) + return jsonable_encoder(result) + +@router.put("/update_config_emotion") +@require_api_key(scopes=["memory"]) +async def update_config_emotion( + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), + message: str = Body(None, description="Request body"), +): + """ + Update emotion engine config (full update). + + All fields except emotion_model_id are required. + Only configs belonging to the authorized workspace can be updated. + """ + body = await request.json() + payload = EmotionConfigUpdateRequest(**body) + + logger.info(f"V1 update emotion config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}") + + _verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db) + + current_user = _get_current_user(api_key_auth, db) + update_fields = payload.model_dump(exclude_unset=True) + mgmt_payload = EmotionConfigUpdate(**update_fields) + return jsonable_encoder(emotion_config_controller.update_emotion_config( + config=mgmt_payload, + db=db, + current_user=current_user, + )) + +@router.put("/update_config_reflection") +@require_api_key(scopes=["memory"]) +async def update_config_reflection( + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), + message: str = Body(None, description="Request body"), +): + """ + Update reflection engine config (full update). + + All fields are required. + Only configs belonging to the authorized workspace can be updated. + """ + body = await request.json() + payload = ReflectionConfigUpdateRequest(**body) + + logger.info(f"V1 update reflection config - config_id: {payload.config_id}, workspace: {api_key_auth.workspace_id}") + + _verify_config_ownership(payload.config_id, api_key_auth.workspace_id, db) + + current_user = _get_current_user(api_key_auth, db) + update_fields = payload.model_dump(exclude_unset=True) + mgmt_payload = Memory_Reflection(**update_fields) + + return jsonable_encoder(await memory_reflection_controller.save_reflection_config( + request=mgmt_payload, + current_user=current_user, + db=db, + )) + +@router.delete("/delete_config") +@require_api_key(scopes=["memory"]) +async def delete_memory_config( + config_id: str, + request: Request, + force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"), + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), +): + """ + Delete a memory config. + + - Default configs cannot be deleted. + - If end users are connected and force=False, returns a warning. + - If force=True, clears end user references and deletes the config. + + Only configs belonging to the authorized workspace can be deleted. + """ + logger.info(f"V1 delete config - config_id: {config_id}, force: {force}, workspace: {api_key_auth.workspace_id}") + + _verify_config_ownership(config_id, api_key_auth.workspace_id, db) + + current_user = _get_current_user(api_key_auth, db) + + return memory_storage_controller.delete_config( + config_id=config_id, + force=force, + current_user=current_user, + db=db, + ) diff --git a/api/app/core/memory/storage_services/search/__init__.py b/api/app/core/memory/storage_services/search/__init__.py index c12c39b0..49154e19 100644 --- a/api/app/core/memory/storage_services/search/__init__.py +++ b/api/app/core/memory/storage_services/search/__init__.py @@ -4,11 +4,6 @@ 本模块提供统一的搜索服务接口,支持关键词搜索、语义搜索和混合搜索。 """ -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from app.schemas.memory_config_schema import MemoryConfig - from app.core.memory.storage_services.search.hybrid_search import HybridSearchStrategy from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy from app.core.memory.storage_services.search.search_strategy import ( @@ -29,115 +24,87 @@ __all__ = [ # ============================================================================ -# 向后兼容的函数式API +# 向后兼容的函数式API (DEPRECATED - 未被使用) # ============================================================================ -# 为了兼容旧代码,提供与 src/search.py 相同的函数式接口 +# 所有调用方均直接使用 app.core.memory.src.search.run_hybrid_search +# 保留注释以备参考 - -async def run_hybrid_search( - query_text: str, - search_type: str = "hybrid", - end_user_id: str | None = None, - apply_id: str | None = None, - user_id: str | None = None, - limit: int = 50, - include: list[str] | None = None, - alpha: float = 0.6, - use_forgetting_curve: bool = False, - memory_config: "MemoryConfig" = None, - **kwargs -) -> dict: - """运行混合搜索(向后兼容的函数式API) - - 这是一个向后兼容的包装函数,将旧的函数式API转换为新的基于类的API。 - - Args: - query_text: 查询文本 - search_type: 搜索类型("hybrid", "keyword", "semantic") - end_user_id: 组ID过滤 - apply_id: 应用ID过滤 - user_id: 用户ID过滤 - limit: 每个类别的最大结果数 - include: 要包含的搜索类别列表 - alpha: BM25分数权重(0.0-1.0) - use_forgetting_curve: 是否使用遗忘曲线 - memory_config: MemoryConfig object containing embedding_model_id - **kwargs: 其他参数 - - Returns: - dict: 搜索结果字典,格式与旧API兼容 - """ - from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient - from app.core.models.base import RedBearModelConfig - from app.db import get_db_context - from app.repositories.neo4j.neo4j_connector import Neo4jConnector - from app.services.memory_config_service import MemoryConfigService - - if not memory_config: - raise ValueError("memory_config is required for search") - - # 初始化客户端 - connector = Neo4jConnector() - with get_db_context() as db: - config_service = MemoryConfigService(db) - embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id)) - embedder_config = RedBearModelConfig(**embedder_config_dict) - embedder_client = OpenAIEmbedderClient(embedder_config) - - try: - # 根据搜索类型选择策略 - if search_type == "keyword": - strategy = KeywordSearchStrategy(connector=connector) - elif search_type == "semantic": - strategy = SemanticSearchStrategy( - connector=connector, - embedder_client=embedder_client - ) - else: # hybrid - strategy = HybridSearchStrategy( - connector=connector, - embedder_client=embedder_client, - alpha=alpha, - use_forgetting_curve=use_forgetting_curve - ) - - # 执行搜索 - result = await strategy.search( - query_text=query_text, - end_user_id=end_user_id, - limit=limit, - include=include, - alpha=alpha, - use_forgetting_curve=use_forgetting_curve, - **kwargs - ) - - # 转换为旧格式 - result_dict = result.to_dict() - - # 保存到文件(如果指定了output_path) - output_path = kwargs.get('output_path', 'search_results.json') - if output_path: - import json - import os - from datetime import datetime - - try: - # 确保目录存在 - out_dir = os.path.dirname(output_path) - if out_dir: - os.makedirs(out_dir, exist_ok=True) - - # 保存结果 - with open(output_path, "w", encoding="utf-8") as f: - json.dump(result_dict, f, ensure_ascii=False, indent=2, default=str) - print(f"Search results saved to {output_path}") - except Exception as e: - print(f"Error saving search results: {e}") - return result_dict - - finally: - await connector.close() - - -__all__.append("run_hybrid_search") +# async def run_hybrid_search( +# query_text: str, +# search_type: str = "hybrid", +# end_user_id: str | None = None, +# apply_id: str | None = None, +# user_id: str | None = None, +# limit: int = 50, +# include: list[str] | None = None, +# alpha: float = 0.6, +# use_forgetting_curve: bool = False, +# memory_config: "MemoryConfig" = None, +# **kwargs +# ) -> dict: +# """运行混合搜索(向后兼容的函数式API)""" +# from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient +# from app.core.models.base import RedBearModelConfig +# from app.db import get_db_context +# from app.repositories.neo4j.neo4j_connector import Neo4jConnector +# from app.services.memory_config_service import MemoryConfigService +# +# if not memory_config: +# raise ValueError("memory_config is required for search") +# +# connector = Neo4jConnector() +# with get_db_context() as db: +# config_service = MemoryConfigService(db) +# embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id)) +# embedder_config = RedBearModelConfig(**embedder_config_dict) +# embedder_client = OpenAIEmbedderClient(embedder_config) +# +# try: +# if search_type == "keyword": +# strategy = KeywordSearchStrategy(connector=connector) +# elif search_type == "semantic": +# strategy = SemanticSearchStrategy( +# connector=connector, +# embedder_client=embedder_client +# ) +# else: +# strategy = HybridSearchStrategy( +# connector=connector, +# embedder_client=embedder_client, +# alpha=alpha, +# use_forgetting_curve=use_forgetting_curve +# ) +# +# result = await strategy.search( +# query_text=query_text, +# end_user_id=end_user_id, +# limit=limit, +# include=include, +# alpha=alpha, +# use_forgetting_curve=use_forgetting_curve, +# **kwargs +# ) +# +# result_dict = result.to_dict() +# +# output_path = kwargs.get('output_path', 'search_results.json') +# if output_path: +# import json +# import os +# from datetime import datetime +# +# try: +# out_dir = os.path.dirname(output_path) +# if out_dir: +# os.makedirs(out_dir, exist_ok=True) +# with open(output_path, "w", encoding="utf-8") as f: +# json.dump(result_dict, f, ensure_ascii=False, indent=2, default=str) +# print(f"Search results saved to {output_path}") +# except Exception as e: +# print(f"Error saving search results: {e}") +# return result_dict +# +# finally: +# await connector.close() +# +# __all__.append("run_hybrid_search") diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index 3139b851..072be1e2 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -328,7 +328,7 @@ class MemoryConfigRepository: if not db_config: db_logger.warning(f"记忆配置不存在: config_id={update.config_id}") return None - + #TODO:部分更新没有用patch请求,是在Repository层中用先查再部分更新的方式实现的,后续可以考虑改成patch请求更符合RESTful设计原则 update_data = update.model_dump(exclude_unset=True) update_data.pop("config_id", None) diff --git a/api/app/schemas/memory_api_schema.py b/api/app/schemas/memory_api_schema.py index ff62355f..4cc548f3 100644 --- a/api/app/schemas/memory_api_schema.py +++ b/api/app/schemas/memory_api_schema.py @@ -4,9 +4,10 @@ This module defines Pydantic schemas for the Memory API Service endpoints, including request validation and response structures for read and write operations. """ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Literal, Optional +import uuid -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator class MemoryWriteRequest(BaseModel): @@ -110,6 +111,30 @@ class MemoryReadRequest(BaseModel): class MemoryWriteResponse(BaseModel): """Response schema for memory write operation. + Attributes: + task_id: Celery task ID for status polling + status: Initial task status (PENDING) + end_user_id: End user ID the write was submitted for + """ + task_id: str = Field(..., description="Celery task ID for polling") + status: str = Field(..., description="Task status: PENDING") + end_user_id: str = Field(..., description="End user ID") + + +class TaskStatusResponse(BaseModel): + """Response schema for task status check. + + Attributes: + status: Task status (PENDING, STARTED, SUCCESS, FAILURE, SKIPPED) + result: Task result data (available when status is SUCCESS or FAILURE) + """ + status: str = Field(..., description="Task status") + result: Optional[Dict[str, Any]] = Field(None, description="Task result when completed") + + +class MemoryWriteSyncResponse(BaseModel): + """Response schema for synchronous memory write. + Attributes: status: Operation status (success or failed) end_user_id: End user ID that was written to @@ -118,8 +143,8 @@ class MemoryWriteResponse(BaseModel): end_user_id: str = Field(..., description="End user ID") -class MemoryReadResponse(BaseModel): - """Response schema for memory read operation. +class MemoryReadSyncResponse(BaseModel): + """Response schema for synchronous memory read. Attributes: answer: Generated answer from memory retrieval @@ -128,12 +153,25 @@ class MemoryReadResponse(BaseModel): """ answer: str = Field(..., description="Generated answer") intermediate_outputs: List[Dict[str, Any]] = Field( - default_factory=list, + default_factory=list, description="Intermediate retrieval outputs" ) end_user_id: str = Field(..., description="End user ID") +class MemoryReadResponse(BaseModel): + """Response schema for memory read operation. + + Attributes: + task_id: Celery task ID for status polling + status: Initial task status (PENDING) + end_user_id: End user ID the read was submitted for + """ + task_id: str = Field(..., description="Celery task ID for polling") + status: str = Field(..., description="Task status: PENDING") + end_user_id: str = Field(..., description="End user ID") + + class CreateEndUserRequest(BaseModel): """Request schema for creating an end user. @@ -141,10 +179,12 @@ class CreateEndUserRequest(BaseModel): other_id: External user identifier (required) other_name: Display name for the end user memory_config_id: Optional memory config ID. If not provided, uses workspace default. + app_id: Optional app ID to bind the end user to. """ other_id: str = Field(..., description="External user identifier (required)") other_name: Optional[str] = Field("", description="Display name") memory_config_id: Optional[str] = Field(None, description="Memory config ID. Falls back to workspace default if not provided.") + app_id: Optional[str] = Field(None, description="App ID to bind the end user to") @field_validator("other_id") @classmethod @@ -192,6 +232,7 @@ class MemoryConfigItem(BaseModel): created_at: Optional[str] = Field(None, description="Creation timestamp") updated_at: Optional[str] = Field(None, description="Last update timestamp") +# ========== V1 记忆配置管理接口 Schema ========== class ListConfigsResponse(BaseModel): """Response schema for listing memory configs. @@ -202,3 +243,203 @@ class ListConfigsResponse(BaseModel): """ configs: List[MemoryConfigItem] = Field(default_factory=list, description="List of configs") total: int = Field(0, description="Total number of configs") + +class ConfigCreateRequest(BaseModel): + """Request schema for creating a new memory config.""" + config_name: str = Field(..., description="Configuration name") + config_desc: Optional[str] = Field("", description="Configuration description") + scene_id: uuid.UUID = Field(..., description="Associated ontology scene ID (UUID, required)") + + llm_id: Optional[str] = Field(None, description="LLM model configuration ID") + embedding_id: Optional[str] = Field(None, description="Embedding model configuration ID") + rerank_id: Optional[str] = Field(None, description="Reranking model configuration ID") + reflection_model_id: Optional[str] = Field(None, description="Reflection model ID") + emotion_model_id: Optional[str] = Field(None, description="Emotion analysis model ID") + + @field_validator("config_name") + @classmethod + def validate_config_name(cls, v: str) -> str: + if not v or not v.strip(): + raise ValueError("config_name is required and cannot be empty") + return v.strip() + +class ConfigUpdateRequest(BaseModel): + """Request schema for updating memory config basic info. + + Attributes: + config_id: Configuration UUID to update (required) + config_name: New configuration name + config_desc: New configuration description + scene_id: New associated ontology scene ID + """ + config_id: str = Field(..., description="Configuration ID to update") + config_name: Optional[str] = Field(None, description="Configuration name") + config_desc: Optional[str] = Field(None, description="Configuration description") + scene_id: Optional[uuid.UUID] = Field(None, description="Associated ontology scene ID") + + @field_validator("config_id") + @classmethod + def validate_config_id(cls, v: str) -> str: + """Validate that config_id is not empty.""" + if not v or not v.strip(): + raise ValueError("config_id is required and cannot be empty") + return v.strip() + +class ConfigUpdateExtractedRequest(BaseModel): + """Request schema for updating memory config extracted parameters. + + Attributes: + config_id: Configuration UUID to update (required) + llm_id: Optional LLM model configuration ID + audio_id: Optional audio model configuration ID + vision_id: Optional vision model configuration ID + video_id: Optional video model configuration ID + embedding_id: Optional embedding model configuration ID + rerank_id: Optional reranking model configuration ID + enable_llm_dedup_blockwise: Optional toggle for LLM decision deduplication + enable_llm_disambiguation: Optional toggle for LLM decision disambiguation + deep_retrieval: Optional toggle for deep retrieval + + t_type_strict: Optional float (0-1) for type strictness threshold + t_name_strict: Optional float (0-1) for name strictness threshold + t_overall: Optional float (0-1) for overall strictness threshold + state: Optional boolean for config active state + chunker_strategy: Optional string for memory chunking strategy + statement_granularity: Optional int (1-3) for statement extraction granularity + include_dialogue_context: Optional boolean for including dialogue context in retrieval + max_context: Optional int for maximum dialogue context length in characters + pruning_enabled: Optional boolean to enable intelligent semantic pruning + pruning_scene: Optional string for semantic pruning scene + pruning_threshold: Optional float (0-0.9) for semantic pruning threshold + enable_self_reflexion: Optional boolean to enable self-reflexion + iteration_period: Optional string for reflexion iteration period in hours (1, 3, 6, 12, 24) + reflexion_range: Optional string for reflexion range (partial or all) + baseline: Optional string for baseline (TIME/FACT/TIME-FACT) + + """ + config_id: str = Field(..., description="Configuration ID (UUID)") + llm_id: Optional[str] = Field(None, description="LLM model configuration ID") + audio_id: Optional[str] = Field(None, description="Audio model ID") + vision_id: Optional[str] = Field(None, description="Vision model ID") + video_id: Optional[str] = Field(None, description="Video model ID") + embedding_id: Optional[str] = Field(None, description="Embedding model configuration ID") + rerank_id: Optional[str] = Field(None, description="Reranking model configuration ID") + enable_llm_dedup_blockwise: Optional[bool] = Field(None, description="Enable LLM decision deduplication") + enable_llm_disambiguation: Optional[bool] = Field(None, description="Enable LLM decision disambiguation") + deep_retrieval: Optional[bool] = Field(None, description="Deep retrieval toggle") + + t_type_strict: Optional[float] = Field(None, ge=0.0, le=1.0, description="type strictness threshold") + t_name_strict: Optional[float] = Field(None, ge=0.0, le=1.0, description="name strictness threshold") + t_overall: Optional[float] = Field(None, ge=0.0, le=1.0, description="overall strictness threshold") + state: Optional[bool] = Field(None, description="config active state") + # 句子提取 + chunker_strategy: Optional[str] = Field(None, description="memory chunking strategy") + statement_granularity: Optional[int] = Field(None, ge=1, le=3, description="statement extraction granularity") + include_dialogue_context: Optional[bool] = Field(None, description="whether to include dialogue context in retrieval") + max_context: Optional[int] = Field(None, gt=100, description="maximum dialogue context length in characters") + # 剪枝配置:与 runtime.json 中 pruning 段对应 + pruning_enabled: Optional[bool] = Field(None, description="whether to enable intelligent semantic pruning") + pruning_scene: Optional[str] = Field(None, description="semantic pruning scene") + pruning_threshold: Optional[float] = Field(None, ge=0.0, le=0.9, description="semantic pruning threshold (0-0.9)") + enable_self_reflexion: Optional[bool] = Field(None, description="whether to enable self-reflexion") + iteration_period: Optional[Literal["1", "3", "6", "12", "24"]] = Field(None, description="reflexion iteration period in hours (1, 3, 6, 12, 24)") + reflexion_range: Optional[Literal["partial", "all"]] = Field(None, description="reflexion range: partial/all") + baseline: Optional[Literal["TIME", "FACT", "TIME-FACT"]] = Field(None, description="baseline: TIME/FACT/TIME-FACT") + + @field_validator("config_id") + @classmethod + def validate_config_id(cls, v: str) -> str: + if not v or not v.strip(): + raise ValueError("config_id is required and cannot be empty") + return v.strip() + +class ConfigUpdateForgettingRequest(BaseModel): + """Request schema for updating memory config forgetting parameters. + + Attributes: + config_id: Configuration UUID to update (required) + decay_constant: Decay constant for forgetting + lambda_time: Time decay parameter + lambda_mem: Memory decay parameter + offset: Offset for forgetting curve + max_history_length: Maximum history length to consider for forgetting + forgetting_threshold: Threshold for forgetting + min_days_since_access: Minimum days since last access to trigger forgetting + enable_llm_summary: Whether to use LLM-generated summaries for forgetting + max_merge_batch_size: Maximum batch size for merging nodes during forgetting + forgetting_interval_hours: Interval in hours for periodic forgetting + + """ + model_config = ConfigDict(populate_by_name=True, extra="forbid") + config_id: str = Field(..., description="Configuration ID (UUID)") + decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="Decay constant for forgetting") + lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="Time decay parameter") + lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="Memory decay parameter") + offset: Optional[float] = Field(None, ge=0.0, le=1.0, description="Offset for forgetting curve") + max_history_length: Optional[int] = Field(None, ge=10, le=1000, description="Maximum history length to consider for forgetting") + forgetting_threshold: Optional[float] = Field(None, ge=0.0, le=1.0, description="Forgetting threshold") + min_days_since_access: Optional[int] = Field(None, ge=1, le=365, description="Minimum days since last access to trigger forgetting") + enable_llm_summary: Optional[bool] = Field(None, description="Whether to use LLM-generated summaries for forgetting") + max_merge_batch_size: Optional[int] = Field(None, ge=1, le=1000, description="Maximum batch size for merging nodes during forgetting") + forgetting_interval_hours: Optional[int] = Field(None, ge=1, le=168, description="Interval in hours for periodic forgetting") + + @field_validator("config_id") + @classmethod + def validate_config_id(cls, v: str) -> str: + if not v or not v.strip(): + raise ValueError("config_id is required and cannot be empty") + return v.strip() + +class EmotionConfigUpdateRequest(BaseModel): + """Request schema for updating memory config emotion parameters. + + Attributes: + config_id: Configuration UUID to update (required) + emotion_enabled: Whether to enable emotion extraction + emotion_model_id: Emotion analysis model ID + emotion_extract_keywords: Whether to extract emotion keywords + emotion_min_intensity: Minimum emotion intensity threshold (0.0-1.0) + emotion_enable_subject: Whether to enable subject classification for emotions + """ + config_id: str = Field(..., description="Configuration ID (UUID)") + emotion_enabled: bool = Field(..., description="Whether to enable emotion extraction") + emotion_model_id: Optional[str] = Field(None, description="Emotion analysis model ID") + emotion_extract_keywords: bool = Field(..., description="Whether to extract emotion keywords") + emotion_min_intensity: float = Field(..., ge=0.0, le=1.0, description="Minimum emotion intensity threshold") + emotion_enable_subject: bool = Field(..., description="Whether to enable subject classification for emotions") + + @field_validator("config_id") + @classmethod + def validate_config_id(cls, v: str) -> str: + if not v or not v.strip(): + raise ValueError("config_id is required and cannot be empty") + return v.strip() + +class ReflectionConfigUpdateRequest(BaseModel): + """Request schema for updating memory config reflection parameters. + + Attributes: + config_id: Configuration UUID to update (required) + reflection_enabled: Whether to enable self-reflection + reflection_period_in_hours: Reflection iteration period in hours + reflexion_range: Reflection range (partial or all) + baseline: Baseline for reflection (TIME/FACT/TIME-FACT) + reflection_model_id: Reflection model ID + memory_verify: Whether to enable memory verification + quality_assessment: Whether to enable quality assessment + """ + config_id: str = Field(..., description="Configuration ID (UUID)") + reflection_enabled: bool = Field(..., description="Whether to enable self-reflection") + reflection_period_in_hours: str = Field(..., description="Reflection iteration period in hours") + reflexion_range: Literal["partial", "all"] = Field(..., description="Reflection range: partial/all") + baseline: Literal["TIME", "FACT", "TIME-FACT"] = Field(..., description="Baseline: TIME/FACT/TIME-FACT") + reflection_model_id: str = Field(..., description="Reflection model ID") + memory_verify: bool = Field(..., description="Whether to enable memory verification") + quality_assessment: bool = Field(..., description="Whether to enable quality assessment") + + @field_validator("config_id") + @classmethod + def validate_config_id(cls, v: str) -> str: + if not v or not v.strip(): + raise ValueError("config_id is required and cannot be empty") + return v.strip() diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index bfcf6337..24dddd80 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -291,7 +291,7 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数 pruning_threshold: Optional[float] = Field( None, ge=0.0, le=0.9, description="智能语义剪枝阈值(0-0.9)" ) - + #TODO:萃取引擎的更新的更新会带有反思引擎的参数,需判断业务是否需要,不需要可以重构 # 反思配置 enable_self_reflexion: Optional[bool] = Field(None, description="是否启用自我反思") iteration_period: Optional[Literal["1", "3", "6", "12", "24"]] = Field( diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index f62f526c..330b84ad 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -8,6 +8,8 @@ This service validates inputs and delegates to MemoryAgentService for core memor import uuid from typing import Any, Dict, Optional +from sqlalchemy.orm import Session + from app.core.error_codes import BizCode from app.core.exceptions import BusinessException, ResourceNotFoundException from app.core.logging_config import get_logger @@ -15,7 +17,6 @@ from app.models.app_model import App from app.models.end_user_model import EndUser from app.schemas.memory_config_schema import ConfigurationError from app.services.memory_agent_service import MemoryAgentService -from sqlalchemy.orm import Session logger = get_logger(__name__) @@ -124,7 +125,7 @@ class MemoryAPIService: except Exception as e: logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}") - async def write_memory( + def write_memory( self, workspace_id: uuid.UUID, end_user_id: str, @@ -133,27 +134,28 @@ class MemoryAPIService: storage_type: str = "neo4j", user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: - """Write memory with validation. - + """Submit a memory write task via Celery. + Validates end_user exists and belongs to workspace, updates the end user's - memory_config_id, then delegates to MemoryAgentService.write_memory. - + memory_config_id, then dispatches write_message_task to Celery for async + processing with per-user fair locking. + Args: workspace_id: Workspace ID for resource validation - end_user_id: End user identifier (used as end_user_id) + end_user_id: End user identifier message: Message content to store config_id: Memory configuration ID (required) storage_type: Storage backend (neo4j or rag) user_rag_memory_id: Optional RAG memory ID - + Returns: - Dict with status and end_user_id - + Dict with task_id, status, and end_user_id + Raises: ResourceNotFoundException: If end_user not found - BusinessException: If end_user not in authorized workspace or write fails + BusinessException: If validation fails """ - logger.info(f"Writing memory for end_user: {end_user_id}, workspace: {workspace_id}") + logger.info(f"Submitting memory write for end_user: {end_user_id}, workspace: {workspace_id}") # Validate end_user exists and belongs to workspace self.validate_end_user(end_user_id, workspace_id) @@ -161,9 +163,120 @@ class MemoryAPIService: # Update end user's memory_config_id self._update_end_user_config(end_user_id, config_id) + # Convert to message list format expected by write_message_task + messages = message if isinstance(message, list) else [{"role": "user", "content": message}] + + from app.tasks import write_message_task + task = write_message_task.delay( + end_user_id, + messages, + config_id, + storage_type, + user_rag_memory_id or "", + ) + + logger.info(f"Memory write task submitted: task_id={task.id}, end_user_id={end_user_id}") + + return { + "task_id": task.id, + "status": "PENDING", + "end_user_id": end_user_id, + } + + def read_memory( + self, + workspace_id: uuid.UUID, + end_user_id: str, + message: str, + search_switch: str = "0", + config_id: str = "", + storage_type: str = "neo4j", + user_rag_memory_id: Optional[str] = None, + ) -> Dict[str, Any]: + """Submit a memory read task via Celery. + + Validates end_user exists and belongs to workspace, updates the end user's + memory_config_id, then dispatches read_message_task to Celery for async processing. + + Args: + workspace_id: Workspace ID for resource validation + end_user_id: End user identifier + message: Query message + search_switch: Search mode (0=deep search with verification, 1=deep search, 2=fast search) + config_id: Memory configuration ID (required) + storage_type: Storage backend (neo4j or rag) + user_rag_memory_id: Optional RAG memory ID + + Returns: + Dict with task_id, status, and end_user_id + + Raises: + ResourceNotFoundException: If end_user not found + BusinessException: If validation fails + """ + logger.info(f"Submitting memory read for end_user: {end_user_id}, workspace: {workspace_id}") + + # Validate end_user exists and belongs to workspace + self.validate_end_user(end_user_id, workspace_id) + + # Update end user's memory_config_id + self._update_end_user_config(end_user_id, config_id) + + from app.tasks import read_message_task + task = read_message_task.delay( + end_user_id, + message, + [], # history + search_switch, + config_id, + storage_type, + user_rag_memory_id or "", + ) + + logger.info(f"Memory read task submitted: task_id={task.id}, end_user_id={end_user_id}") + + return { + "task_id": task.id, + "status": "PENDING", + "end_user_id": end_user_id, + } + + async def write_memory_sync( + self, + workspace_id: uuid.UUID, + end_user_id: str, + message: str, + config_id: str, + storage_type: str = "neo4j", + user_rag_memory_id: Optional[str] = None, + ) -> Dict[str, Any]: + """Write memory synchronously (inline, no Celery). + + Validates end_user, then calls MemoryAgentService.write_memory directly. + Blocks until the write completes. Use for cases where the caller needs + immediate confirmation. + + Args: + workspace_id: Workspace ID for resource validation + end_user_id: End user identifier + message: Message content to store + config_id: Memory configuration ID (required) + storage_type: Storage backend (neo4j or rag) + user_rag_memory_id: Optional RAG memory ID + + Returns: + Dict with status and end_user_id + + Raises: + ResourceNotFoundException: If end_user not found + BusinessException: If write fails + """ + logger.info(f"Writing memory (sync) for end_user: {end_user_id}, workspace: {workspace_id}") + + self.validate_end_user(end_user_id, workspace_id) + self._update_end_user_config(end_user_id, config_id) + try: - # Delegate to MemoryAgentService - # Convert string message to list[dict] format expected by MemoryAgentService messages = message if isinstance(message, list) else [{"role": "user", "content": message}] result = await MemoryAgentService().write_memory( end_user_id=end_user_id, @@ -174,11 +287,8 @@ class MemoryAPIService: user_rag_memory_id=user_rag_memory_id or "", ) - logger.info(f"Memory write successful for end_user: {end_user_id}") + logger.info(f"Memory write (sync) successful for end_user: {end_user_id}") - # result may be a string "success" or a dict with a "status" key - # Preserve the full dict so callers don't silently lose extra fields - # (e.g. error codes, metadata) returned by MemoryAgentService. if isinstance(result, dict): return { **result, @@ -192,20 +302,17 @@ class MemoryAPIService: except ConfigurationError as e: logger.error(f"Memory configuration error for end_user {end_user_id}: {e}") - raise BusinessException( - message=str(e), - code=BizCode.MEMORY_CONFIG_NOT_FOUND - ) + raise BusinessException(message=str(e), code=BizCode.MEMORY_CONFIG_NOT_FOUND) except BusinessException: raise except Exception as e: - logger.error(f"Memory write failed for end_user {end_user_id}: {e}") + logger.error(f"Memory write (sync) failed for end_user {end_user_id}: {e}") raise BusinessException( message=f"Memory write failed: {str(e)}", code=BizCode.MEMORY_WRITE_FAILED ) - async def read_memory( + async def read_memory_sync( self, workspace_id: uuid.UUID, end_user_id: str, @@ -215,37 +322,34 @@ class MemoryAPIService: storage_type: str = "neo4j", user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: - """Read memory with validation. - - Validates end_user exists and belongs to workspace, updates the end user's - memory_config_id, then delegates to MemoryAgentService.read_memory. - + """Read memory synchronously (inline, no Celery). + + Validates end_user, then calls MemoryAgentService.read_memory directly. + Blocks until the read completes. Use for cases where the caller needs + the answer immediately. + Args: workspace_id: Workspace ID for resource validation - end_user_id: End user identifier (used as end_user_id) + end_user_id: End user identifier message: Query message search_switch: Search mode (0=deep search with verification, 1=deep search, 2=fast search) config_id: Memory configuration ID (required) storage_type: Storage backend (neo4j or rag) user_rag_memory_id: Optional RAG memory ID - + Returns: Dict with answer, intermediate_outputs, and end_user_id - + Raises: ResourceNotFoundException: If end_user not found - BusinessException: If end_user not in authorized workspace or read fails + BusinessException: If read fails """ - logger.info(f"Reading memory for end_user: {end_user_id}, workspace: {workspace_id}") + logger.info(f"Reading memory (sync) for end_user: {end_user_id}, workspace: {workspace_id}") - # Validate end_user exists and belongs to workspace self.validate_end_user(end_user_id, workspace_id) - - # Update end user's memory_config_id self._update_end_user_config(end_user_id, config_id) try: - # Delegate to MemoryAgentService result = await MemoryAgentService().read_memory( end_user_id=end_user_id, message=message, @@ -257,7 +361,7 @@ class MemoryAPIService: user_rag_memory_id=user_rag_memory_id or "" ) - logger.info(f"Memory read successful for end_user: {end_user_id}") + logger.info(f"Memory read (sync) successful for end_user: {end_user_id}") return { "answer": result.get("answer", ""), @@ -267,14 +371,11 @@ class MemoryAPIService: except ConfigurationError as e: logger.error(f"Memory configuration error for end_user {end_user_id}: {e}") - raise BusinessException( - message=str(e), - code=BizCode.MEMORY_CONFIG_NOT_FOUND - ) + raise BusinessException(message=str(e), code=BizCode.MEMORY_CONFIG_NOT_FOUND) except BusinessException: raise except Exception as e: - logger.error(f"Memory read failed for end_user {end_user_id}: {e}") + logger.error(f"Memory read (sync) failed for end_user {end_user_id}: {e}") raise BusinessException( message=f"Memory read failed: {str(e)}", code=BizCode.MEMORY_READ_FAILED