From 863be50aafef5fd4d29b3ef040913fee9c9e7ae6 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Thu, 26 Mar 2026 15:03:33 +0800 Subject: [PATCH] [changes] Spatial verification, retrieval synchronization --- .../controllers/user_memory_controllers.py | 23 +++++++++++++++++++ .../extraction_orchestrator.py | 5 ++-- .../repositories/end_user_info_repository.py | 7 +++--- api/app/repositories/end_user_repository.py | 4 ++-- api/app/services/user_memory_service.py | 14 ++++------- 5 files changed, 35 insertions(+), 18 deletions(-) diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py index b0dc82a0..10b396a7 100644 --- a/api/app/controllers/user_memory_controllers.py +++ b/api/app/controllers/user_memory_controllers.py @@ -23,6 +23,7 @@ from app.services.memory_entity_relationship_service import MemoryEntityService, from app.schemas.response_schema import ApiResponse from app.schemas.memory_storage_schema import GenerateCacheRequest from app.repositories.workspace_repository import WorkspaceRepository +from app.repositories.end_user_repository import EndUserRepository from app.schemas.end_user_info_schema import ( EndUserInfoResponse, EndUserInfoCreate, @@ -361,6 +362,17 @@ async def get_end_user_info( f"workspace={workspace_id}" ) + # 校验 end_user 是否属于当前工作空间 + end_user_repo = EndUserRepository(db) + end_user = end_user_repo.get_end_user_by_id(end_user_id) + if end_user is None: + return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", "end_user not found") + if str(end_user.workspace_id) != str(workspace_id): + api_logger.warning( + f"用户 {current_user.username} 尝试查询不属于工作空间 {workspace_id} 的终端用户 {end_user_id}" + ) + return fail(BizCode.PERMISSION_DENIED, "该终端用户不属于当前工作空间", "end_user workspace mismatch") + result = user_memory_service.get_end_user_info(db, end_user_id) if result["success"]: @@ -409,6 +421,17 @@ async def update_end_user_info( f"workspace={workspace_id}" ) + # 校验 end_user 是否属于当前工作空间 + end_user_repo = EndUserRepository(db) + end_user = end_user_repo.get_end_user_by_id(end_user_id) + if end_user is None: + return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", "end_user not found") + if str(end_user.workspace_id) != str(workspace_id): + api_logger.warning( + f"用户 {current_user.username} 尝试更新不属于工作空间 {workspace_id} 的终端用户 {end_user_id}" + ) + return fail(BizCode.PERMISSION_DENIED, "该终端用户不属于当前工作空间", "end_user workspace mismatch") + # 获取更新数据(排除 end_user_id) update_data = info_update.model_dump(exclude_unset=True, exclude={'end_user_id'}) diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index d5681da9..58a4c441 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -1389,9 +1389,8 @@ class ExtractionOrchestrator: logger.debug(f"end_user 表 other_name 保持不变: {end_user.other_name}") # 更新或创建 end_user_info 记录 - existing_infos = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid) - if existing_infos: - info = existing_infos[0] + info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid) + if info: new_name_info = self._resolve_other_name(info.other_name, current_aliases, neo4j_aliases) if new_name_info is not None: info.other_name = new_name_info diff --git a/api/app/repositories/end_user_info_repository.py b/api/app/repositories/end_user_info_repository.py index f9f4665c..f627b46f 100644 --- a/api/app/repositories/end_user_info_repository.py +++ b/api/app/repositories/end_user_info_repository.py @@ -35,9 +35,10 @@ class EndUserInfoRepository: """根据ID获取用户信息""" return self.db.query(EndUserInfo).filter(EndUserInfo.id == info_id).first() - def get_by_end_user_id(self, end_user_id: uuid.UUID) -> List[EndUserInfo]: - """获取用户的所有信息记录""" - return self.db.query(EndUserInfo).filter(EndUserInfo.end_user_id == end_user_id).all() + + def get_by_end_user_id(self, end_user_id: uuid.UUID) -> Optional[EndUserInfo]: + """获取用户的信息记录""" + return self.db.query(EndUserInfo).filter(EndUserInfo.end_user_id == end_user_id).first() def update(self, info_id: uuid.UUID, aliases: List[str] = None, meta_data: dict = None) -> Optional[EndUserInfo]: """更新用户信息""" diff --git a/api/app/repositories/end_user_repository.py b/api/app/repositories/end_user_repository.py index d8d30618..3c1dd16f 100644 --- a/api/app/repositories/end_user_repository.py +++ b/api/app/repositories/end_user_repository.py @@ -115,8 +115,8 @@ class EndUserRepository: end_user_info = EndUserInfo( end_user_id=end_user.id, other_name=other_name or "", # 如果没有提供 other_name,使用空字符串 - aliases=[], - meta_data=[] + aliases=[], + meta_data={} ) self.db.add(end_user_info) diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index f6239c76..942e01a0 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -387,9 +387,9 @@ class UserMemoryService: # 转换为UUID并查询 user_uuid = uuid.UUID(end_user_id) - end_user_info_records = EndUserInfoRepository(db).get_by_end_user_id(user_uuid) + end_user_info_record = EndUserInfoRepository(db).get_by_end_user_id(user_uuid) - if not end_user_info_records: + if not end_user_info_record: logger.warning(f"终端用户信息记录不存在: end_user_id={end_user_id}") return { "success": False, @@ -397,9 +397,6 @@ class UserMemoryService: "error": "终端用户信息记录不存在" } - # 获取第一条记录 - end_user_info_record = end_user_info_records[0] - # 构建响应数据(转换时间为毫秒时间戳) response_data = { "end_user_info_id": str(end_user_info_record.id), @@ -462,9 +459,9 @@ class UserMemoryService: # 转换为UUID并查询 user_uuid = uuid.UUID(end_user_id) - end_user_info_records = EndUserInfoRepository(db).get_by_end_user_id(user_uuid) + end_user_info_record = EndUserInfoRepository(db).get_by_end_user_id(user_uuid) - if not end_user_info_records: + if not end_user_info_record: logger.warning(f"终端用户信息记录不存在: end_user_id={end_user_id}") return { "success": False, @@ -472,9 +469,6 @@ class UserMemoryService: "error": "终端用户信息记录不存在" } - # 获取第一条记录 - end_user_info_record = end_user_info_records[0] - # 定义允许更新的字段白名单 allowed_fields = {'other_name', 'aliases', 'meta_data'}