[fix] Fix the code according to the comments

This commit is contained in:
lanceyq
2026-03-25 16:36:20 +08:00
parent 38c6c7f053
commit 1e986c641f
4 changed files with 23 additions and 37 deletions

View File

@@ -61,7 +61,7 @@ __all__ = [
"AppRelease", "AppRelease",
"MemoryIncrement", "MemoryIncrement",
"EndUser", "EndUser",
"UserAlias", "EndUserInfo",
"AppShare", "AppShare",
"ReleaseShare", "ReleaseShare",
"Conversation", "Conversation",

View File

@@ -1,7 +1,7 @@
import datetime import datetime
import uuid import uuid
from sqlalchemy import Column, DateTime, ForeignKey, String, Text from sqlalchemy import Column, DateTime, ForeignKey, String, Text, ARRAY
from sqlalchemy.dialects.postgresql import UUID, JSONB from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
@@ -15,7 +15,7 @@ class EndUserInfo(Base):
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, nullable=False, index=True) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, nullable=False, index=True)
end_user_id = Column(UUID(as_uuid=True), ForeignKey("end_users.id"), nullable=False, index=True, comment="关联的终端用户ID") end_user_id = Column(UUID(as_uuid=True), ForeignKey("end_users.id"), nullable=False, index=True, comment="关联的终端用户ID")
other_name = Column(String, nullable=False, comment="关联的用户名称") other_name = Column(String, nullable=False, comment="关联的用户名称")
aliases = Column(JSONB, nullable=True, comment="用户别名列表(JSON数组)") aliases = Column(ARRAY(String), nullable=True, comment="用户别名列表(字符串数组)")
meta_data = Column(JSONB, nullable=True, comment="用户相关的扩展信息JSON格式") meta_data = Column(JSONB, nullable=True, comment="用户相关的扩展信息JSON格式")
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间") created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")

View File

@@ -17,18 +17,18 @@ class EndUserInfoRepository:
def __init__(self, db: Session): def __init__(self, db: Session):
self.db = db self.db = db
def create(self, end_user_id: uuid.UUID, other_name: str, alias: str = None, meta_data: dict = None) -> EndUserInfo: def create(self, end_user_id: uuid.UUID, other_name: str, aliases: List[str] = None, meta_data: dict = None) -> EndUserInfo:
"""创建终端用户信息""" """创建终端用户信息"""
end_user_info = EndUserInfo( end_user_info = EndUserInfo(
end_user_id=end_user_id, end_user_id=end_user_id,
other_name=other_name, other_name=other_name,
alias=alias, aliases=aliases or [],
meta_data=meta_data meta_data=meta_data
) )
self.db.add(end_user_info) self.db.add(end_user_info)
self.db.commit() self.db.commit()
self.db.refresh(end_user_info) self.db.refresh(end_user_info)
logger.info(f"创建终端用户信息: end_user_id={end_user_id}, alias={alias}") logger.info(f"创建终端用户信息: end_user_id={end_user_id}, aliases={aliases}")
return end_user_info return end_user_info
def get_by_id(self, info_id: uuid.UUID) -> Optional[EndUserInfo]: def get_by_id(self, info_id: uuid.UUID) -> Optional[EndUserInfo]:
@@ -39,12 +39,12 @@ class EndUserInfoRepository:
"""获取用户的所有信息记录""" """获取用户的所有信息记录"""
return self.db.query(EndUserInfo).filter(EndUserInfo.end_user_id == end_user_id).all() return self.db.query(EndUserInfo).filter(EndUserInfo.end_user_id == end_user_id).all()
def update(self, info_id: uuid.UUID, alias: str = None, meta_data: dict = None) -> Optional[EndUserInfo]: def update(self, info_id: uuid.UUID, aliases: List[str] = None, meta_data: dict = None) -> Optional[EndUserInfo]:
"""更新用户信息""" """更新用户信息"""
end_user_info = self.get_by_id(info_id) end_user_info = self.get_by_id(info_id)
if end_user_info: if end_user_info:
if alias is not None: if aliases is not None:
end_user_info.alias = alias end_user_info.aliases = aliases
if meta_data is not None: if meta_data is not None:
end_user_info.meta_data = meta_data end_user_info.meta_data = meta_data
self.db.commit() self.db.commit()
@@ -68,23 +68,3 @@ class EndUserInfoRepository:
self.db.commit() self.db.commit()
logger.info(f"删除用户所有信息记录: end_user_id={end_user_id}, count={count}") logger.info(f"删除用户所有信息记录: end_user_id={end_user_id}, count={count}")
return count return count
def batch_create(self, end_user_id: uuid.UUID, other_name: str, aliases: List[str]) -> List[EndUserInfo]:
"""批量创建用户信息"""
end_user_infos = []
for alias in aliases:
if alias and alias.strip():
end_user_info = EndUserInfo(
end_user_id=end_user_id,
other_name=other_name,
alias=alias.strip()
)
self.db.add(end_user_info)
end_user_infos.append(end_user_info)
self.db.commit()
for end_user_info in end_user_infos:
self.db.refresh(end_user_info)
logger.info(f"批量创建终端用户信息: end_user_id={end_user_id}, count={len(end_user_infos)}")
return end_user_infos

View File

@@ -382,14 +382,14 @@ class UserMemoryService:
} }
""" """
try: try:
from app.models.end_user_info_model import EndUserInfo from app.repositories.end_user_info_repository import EndUserInfoRepository
from app.core.api_key_utils import datetime_to_timestamp from app.core.api_key_utils import datetime_to_timestamp
# 转换为UUID并查询 # 转换为UUID并查询
user_uuid = uuid.UUID(end_user_id) user_uuid = uuid.UUID(end_user_id)
end_user_info_record = db.query(EndUserInfo).filter(EndUserInfo.end_user_id == user_uuid).first() end_user_info_records = EndUserInfoRepository(db).get_by_end_user_id(user_uuid)
if not end_user_info_record: if not end_user_info_records:
logger.warning(f"终端用户信息记录不存在: end_user_id={end_user_id}") logger.warning(f"终端用户信息记录不存在: end_user_id={end_user_id}")
return { return {
"success": False, "success": False,
@@ -397,6 +397,9 @@ class UserMemoryService:
"error": "终端用户信息记录不存在" "error": "终端用户信息记录不存在"
} }
# 获取第一条记录
end_user_info_record = end_user_info_records[0]
# 构建响应数据(转换时间为毫秒时间戳) # 构建响应数据(转换时间为毫秒时间戳)
response_data = { response_data = {
"end_user_info_id": str(end_user_info_record.id), "end_user_info_id": str(end_user_info_record.id),
@@ -453,15 +456,15 @@ class UserMemoryService:
} }
""" """
try: try:
from app.models.end_user_info_model import EndUserInfo from app.repositories.end_user_info_repository import EndUserInfoRepository
from app.models.end_user_model import EndUser from app.repositories.end_user_repository import EndUserRepository
from app.core.api_key_utils import datetime_to_timestamp from app.core.api_key_utils import datetime_to_timestamp
# 转换为UUID并查询 # 转换为UUID并查询
user_uuid = uuid.UUID(end_user_id) user_uuid = uuid.UUID(end_user_id)
end_user_info_record = db.query(EndUserInfo).filter(EndUserInfo.end_user_id == user_uuid).first() end_user_info_records = EndUserInfoRepository(db).get_by_end_user_id(user_uuid)
if not end_user_info_record: if not end_user_info_records:
logger.warning(f"终端用户信息记录不存在: end_user_id={end_user_id}") logger.warning(f"终端用户信息记录不存在: end_user_id={end_user_id}")
return { return {
"success": False, "success": False,
@@ -469,6 +472,9 @@ class UserMemoryService:
"error": "终端用户信息记录不存在" "error": "终端用户信息记录不存在"
} }
# 获取第一条记录
end_user_info_record = end_user_info_records[0]
# 定义允许更新的字段白名单 # 定义允许更新的字段白名单
allowed_fields = {'other_name', 'aliases', 'meta_data'} allowed_fields = {'other_name', 'aliases', 'meta_data'}
@@ -488,7 +494,7 @@ class UserMemoryService:
# 如果 other_name 被更新,同步更新 end_user 表 # 如果 other_name 被更新,同步更新 end_user 表
if other_name_updated: if other_name_updated:
end_user_record = db.query(EndUser).filter(EndUser.id == user_uuid).first() end_user_record = EndUserRepository(db).get_by_id(user_uuid)
if end_user_record: if end_user_record:
end_user_record.other_name = update_data['other_name'] end_user_record.other_name = update_data['other_name']
end_user_record.updated_at = datetime.now() end_user_record.updated_at = datetime.now()