[changes] Spatial verification, retrieval synchronization
This commit is contained in:
@@ -23,6 +23,7 @@ from app.services.memory_entity_relationship_service import MemoryEntityService,
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
||||
from app.repositories.workspace_repository import WorkspaceRepository
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas.end_user_info_schema import (
|
||||
EndUserInfoResponse,
|
||||
EndUserInfoCreate,
|
||||
@@ -361,6 +362,17 @@ async def get_end_user_info(
|
||||
f"workspace={workspace_id}"
|
||||
)
|
||||
|
||||
# 校验 end_user 是否属于当前工作空间
|
||||
end_user_repo = EndUserRepository(db)
|
||||
end_user = end_user_repo.get_end_user_by_id(end_user_id)
|
||||
if end_user is None:
|
||||
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", "end_user not found")
|
||||
if str(end_user.workspace_id) != str(workspace_id):
|
||||
api_logger.warning(
|
||||
f"用户 {current_user.username} 尝试查询不属于工作空间 {workspace_id} 的终端用户 {end_user_id}"
|
||||
)
|
||||
return fail(BizCode.PERMISSION_DENIED, "该终端用户不属于当前工作空间", "end_user workspace mismatch")
|
||||
|
||||
result = user_memory_service.get_end_user_info(db, end_user_id)
|
||||
|
||||
if result["success"]:
|
||||
@@ -409,6 +421,17 @@ async def update_end_user_info(
|
||||
f"workspace={workspace_id}"
|
||||
)
|
||||
|
||||
# 校验 end_user 是否属于当前工作空间
|
||||
end_user_repo = EndUserRepository(db)
|
||||
end_user = end_user_repo.get_end_user_by_id(end_user_id)
|
||||
if end_user is None:
|
||||
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", "end_user not found")
|
||||
if str(end_user.workspace_id) != str(workspace_id):
|
||||
api_logger.warning(
|
||||
f"用户 {current_user.username} 尝试更新不属于工作空间 {workspace_id} 的终端用户 {end_user_id}"
|
||||
)
|
||||
return fail(BizCode.PERMISSION_DENIED, "该终端用户不属于当前工作空间", "end_user workspace mismatch")
|
||||
|
||||
# 获取更新数据(排除 end_user_id)
|
||||
update_data = info_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
|
||||
|
||||
|
||||
@@ -1389,9 +1389,8 @@ class ExtractionOrchestrator:
|
||||
logger.debug(f"end_user 表 other_name 保持不变: {end_user.other_name}")
|
||||
|
||||
# 更新或创建 end_user_info 记录
|
||||
existing_infos = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
|
||||
if existing_infos:
|
||||
info = existing_infos[0]
|
||||
info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
|
||||
if info:
|
||||
new_name_info = self._resolve_other_name(info.other_name, current_aliases, neo4j_aliases)
|
||||
if new_name_info is not None:
|
||||
info.other_name = new_name_info
|
||||
|
||||
@@ -35,9 +35,10 @@ class EndUserInfoRepository:
|
||||
"""根据ID获取用户信息"""
|
||||
return self.db.query(EndUserInfo).filter(EndUserInfo.id == info_id).first()
|
||||
|
||||
def get_by_end_user_id(self, end_user_id: uuid.UUID) -> List[EndUserInfo]:
|
||||
"""获取用户的所有信息记录"""
|
||||
return self.db.query(EndUserInfo).filter(EndUserInfo.end_user_id == end_user_id).all()
|
||||
|
||||
def get_by_end_user_id(self, end_user_id: uuid.UUID) -> Optional[EndUserInfo]:
|
||||
"""获取用户的信息记录"""
|
||||
return self.db.query(EndUserInfo).filter(EndUserInfo.end_user_id == end_user_id).first()
|
||||
|
||||
def update(self, info_id: uuid.UUID, aliases: List[str] = None, meta_data: dict = None) -> Optional[EndUserInfo]:
|
||||
"""更新用户信息"""
|
||||
|
||||
@@ -115,8 +115,8 @@ class EndUserRepository:
|
||||
end_user_info = EndUserInfo(
|
||||
end_user_id=end_user.id,
|
||||
other_name=other_name or "", # 如果没有提供 other_name,使用空字符串
|
||||
aliases=[],
|
||||
meta_data=[]
|
||||
aliases=[],
|
||||
meta_data={}
|
||||
)
|
||||
self.db.add(end_user_info)
|
||||
|
||||
|
||||
@@ -387,9 +387,9 @@ class UserMemoryService:
|
||||
|
||||
# 转换为UUID并查询
|
||||
user_uuid = uuid.UUID(end_user_id)
|
||||
end_user_info_records = EndUserInfoRepository(db).get_by_end_user_id(user_uuid)
|
||||
end_user_info_record = EndUserInfoRepository(db).get_by_end_user_id(user_uuid)
|
||||
|
||||
if not end_user_info_records:
|
||||
if not end_user_info_record:
|
||||
logger.warning(f"终端用户信息记录不存在: end_user_id={end_user_id}")
|
||||
return {
|
||||
"success": False,
|
||||
@@ -397,9 +397,6 @@ class UserMemoryService:
|
||||
"error": "终端用户信息记录不存在"
|
||||
}
|
||||
|
||||
# 获取第一条记录
|
||||
end_user_info_record = end_user_info_records[0]
|
||||
|
||||
# 构建响应数据(转换时间为毫秒时间戳)
|
||||
response_data = {
|
||||
"end_user_info_id": str(end_user_info_record.id),
|
||||
@@ -462,9 +459,9 @@ class UserMemoryService:
|
||||
|
||||
# 转换为UUID并查询
|
||||
user_uuid = uuid.UUID(end_user_id)
|
||||
end_user_info_records = EndUserInfoRepository(db).get_by_end_user_id(user_uuid)
|
||||
end_user_info_record = EndUserInfoRepository(db).get_by_end_user_id(user_uuid)
|
||||
|
||||
if not end_user_info_records:
|
||||
if not end_user_info_record:
|
||||
logger.warning(f"终端用户信息记录不存在: end_user_id={end_user_id}")
|
||||
return {
|
||||
"success": False,
|
||||
@@ -472,9 +469,6 @@ class UserMemoryService:
|
||||
"error": "终端用户信息记录不存在"
|
||||
}
|
||||
|
||||
# 获取第一条记录
|
||||
end_user_info_record = end_user_info_records[0]
|
||||
|
||||
# 定义允许更新的字段白名单
|
||||
allowed_fields = {'other_name', 'aliases', 'meta_data'}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user