refactor(memory): replace raw dict responses with Pydantic schema models in user memory controllers

- Add user_memory_schema.py with typed Pydantic models for all user memory
  API responses: MemoryInsightReportData, UserSummaryData, GraphData,
  MemoryTypeStatItem, cache result models, and RelationshipEvolutionData
- Refactor user_memory_controllers.py to construct schema instances and
  return model_dump() instead of raw dicts
- Remove unused imports (datetime, timestamp_to_datetime, EndUserInfoResponse,
  EndUserInfoCreate, EndUser)
This commit is contained in:
lanceyq
2026-04-27 17:57:06 +08:00
parent 4619b40d03
commit 9a5ce7f7c6
2 changed files with 242 additions and 53 deletions

View File

@@ -2,8 +2,8 @@
用户记忆相关的控制器 用户记忆相关的控制器
包含用户摘要、记忆洞察、节点统计、图数据和用户档案等接口 包含用户摘要、记忆洞察、节点统计、图数据和用户档案等接口
""" """
from typing import Optional from typing import Optional, List
import datetime
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from fastapi import APIRouter, Depends, Header from fastapi import APIRouter, Depends, Header
@@ -12,7 +12,6 @@ from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail from app.core.response_utils import success, fail
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.api_key_utils import timestamp_to_datetime
from app.services.user_memory_service import ( from app.services.user_memory_service import (
UserMemoryService, UserMemoryService,
analytics_memory_types, analytics_memory_types,
@@ -22,14 +21,25 @@ from app.services.user_memory_service import (
from app.services.memory_entity_relationship_service import MemoryEntityService, MemoryEmotion, MemoryInteraction from app.services.memory_entity_relationship_service import MemoryEntityService, MemoryEmotion, MemoryInteraction
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.schemas.memory_storage_schema import GenerateCacheRequest from app.schemas.memory_storage_schema import GenerateCacheRequest
from app.schemas.user_memory_schema import (
MemoryInsightReportData,
UserSummaryData,
SingleUserCacheResultData,
GenerateCacheErrorItem,
WorkspaceCacheResultData,
WorkspaceCacheErrorItem,
MemoryTypeStatItem,
GraphData,
GraphNodeData,
GraphEdgeData,
GraphStatistics,
RelationshipEvolutionData,
)
from app.repositories.workspace_repository import WorkspaceRepository from app.repositories.workspace_repository import WorkspaceRepository
from app.repositories.end_user_repository import EndUserRepository from app.repositories.end_user_repository import EndUserRepository
from app.schemas.end_user_info_schema import ( from app.schemas.end_user_info_schema import (
EndUserInfoResponse,
EndUserInfoCreate,
EndUserInfoUpdate, EndUserInfoUpdate,
) )
from app.models.end_user_model import EndUser
from app.dependencies import get_current_user from app.dependencies import get_current_user
from app.models.user_model import User from app.models.user_model import User
@@ -61,13 +71,22 @@ async def get_memory_insight_report_api(
try: try:
# 调用服务层获取缓存数据 # 调用服务层获取缓存数据
result = await user_memory_service.get_cached_memory_insight(db, end_user_id) result = await user_memory_service.get_cached_memory_insight(db, end_user_id)
data = MemoryInsightReportData(
memory_insight=result.get("memory_insight"),
behavior_pattern=result.get("behavior_pattern"),
key_findings=result.get("key_findings"),
growth_trajectory=result.get("growth_trajectory"),
updated_at=result.get("updated_at"),
is_cached=result["is_cached"],
message=result.get("message"),
)
if result["is_cached"]: if data.is_cached:
api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}") api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}")
return success(data=result, msg="查询成功") return success(data=data.model_dump(), msg="查询成功")
else: else:
api_logger.info(f"记忆洞察报告缓存不存在: end_user_id={end_user_id}") api_logger.info(f"记忆洞察报告缓存不存在: end_user_id={end_user_id}")
return success(data=result, msg="数据尚未生成") return success(data=data.model_dump(), msg="数据尚未生成")
except Exception as e: except Exception as e:
api_logger.error(f"记忆洞察报告查询失败: end_user_id={end_user_id}, error={str(e)}") api_logger.error(f"记忆洞察报告查询失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告查询失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告查询失败", str(e))
@@ -105,13 +124,22 @@ async def get_user_summary_api(
try: try:
# 调用服务层获取缓存数据 # 调用服务层获取缓存数据
result = await user_memory_service.get_cached_user_summary(db, end_user_id, model_id, language) result = await user_memory_service.get_cached_user_summary(db, end_user_id, model_id, language)
data = UserSummaryData(
user_summary=result.get("user_summary"),
personality=result.get("personality"),
core_values=result.get("core_values"),
one_sentence=result.get("one_sentence"),
updated_at=result.get("updated_at"),
is_cached=result["is_cached"],
message=result.get("message"),
)
if result["is_cached"]: if data.is_cached:
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}") api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
return success(data=result, msg="查询成功") return success(data=data.model_dump(), msg="查询成功")
else: else:
api_logger.info(f"用户摘要缓存不存在: end_user_id={end_user_id}") api_logger.info(f"用户摘要缓存不存在: end_user_id={end_user_id}")
return success(data=result, msg="数据尚未生成") return success(data=data.model_dump(), msg="数据尚未生成")
except Exception as e: except Exception as e:
api_logger.error(f"用户摘要查询失败: end_user_id={end_user_id}, error={str(e)}") api_logger.error(f"用户摘要查询失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "用户摘要查询失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "用户摘要查询失败", str(e))
@@ -165,32 +193,32 @@ async def generate_cache_api(
language=language) language=language)
# 构建响应 # 构建响应
result = { errors: List[GenerateCacheErrorItem] = []
"end_user_id": end_user_id,
"insight_success": insight_result["success"],
"summary_success": summary_result["success"],
"errors": []
}
# 收集错误信息
if not insight_result["success"]: if not insight_result["success"]:
result["errors"].append({ errors.append(GenerateCacheErrorItem(
"type": "insight", type="insight",
"error": insight_result.get("error") error=insight_result.get("error"),
}) ))
if not summary_result["success"]: if not summary_result["success"]:
result["errors"].append({ errors.append(GenerateCacheErrorItem(
"type": "summary", type="summary",
"error": summary_result.get("error") error=summary_result.get("error"),
}) ))
data = SingleUserCacheResultData(
end_user_id=end_user_id,
insight_success=insight_result["success"],
summary_success=summary_result["success"],
errors=errors,
)
# 记录结果 # 记录结果
if result["insight_success"] and result["summary_success"]: if data.insight_success and data.summary_success:
api_logger.info(f"成功为用户 {end_user_id} 生成缓存") api_logger.info(f"成功为用户 {end_user_id} 生成缓存")
else: else:
api_logger.warning(f"用户 {end_user_id} 的缓存生成部分失败: {result['errors']}") api_logger.warning(f"用户 {end_user_id} 的缓存生成部分失败: {[e.model_dump() for e in errors]}")
return success(data=result, msg="生成完成") return success(data=data.model_dump(), msg="生成完成")
else: else:
# 为整个工作空间生成 # 为整个工作空间生成
@@ -198,13 +226,29 @@ async def generate_cache_api(
result = await user_memory_service.generate_cache_for_workspace(db, workspace_id, language=language) result = await user_memory_service.generate_cache_for_workspace(db, workspace_id, language=language)
ws_errors = [
WorkspaceCacheErrorItem(
end_user_id=e.get("end_user_id"),
insight_error=e.get("insight_error"),
summary_error=e.get("summary_error"),
error=e.get("error"),
)
for e in result.get("errors", [])
]
data = WorkspaceCacheResultData(
total_users=result["total_users"],
successful=result["successful"],
failed=result["failed"],
errors=ws_errors,
)
# 记录统计信息 # 记录统计信息
api_logger.info( api_logger.info(
f"工作空间 {workspace_id} 批量生成完成: " f"工作空间 {workspace_id} 批量生成完成: "
f"总数={result['total_users']}, 成功={result['successful']}, 失败={result['failed']}" f"总数={data.total_users}, 成功={data.successful}, 失败={data.failed}"
) )
return success(data=result, msg="批量生成完成") return success(data=data.model_dump(), msg="批量生成完成")
except Exception as e: except Exception as e:
api_logger.error(f"缓存生成失败: user={current_user.username}, error={str(e)}") api_logger.error(f"缓存生成失败: user={current_user.username}, error={str(e)}")
@@ -231,11 +275,21 @@ async def get_node_statistics_api(
# 调用新的记忆类型统计函数 # 调用新的记忆类型统计函数
result = await analytics_memory_types(db, end_user_id) result = await analytics_memory_types(db, end_user_id)
# 使用 schema 模型构建响应
stat_items = [
MemoryTypeStatItem(
type=item["type"],
count=item["count"],
percentage=item["percentage"],
)
for item in result
]
# 计算总数用于日志 # 计算总数用于日志
total_count = sum(item["count"] for item in result) total_count = sum(item.count for item in stat_items)
api_logger.info( api_logger.info(
f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}") f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(stat_items)}")
return success(data=result, msg="查询成功") return success(data=[item.model_dump() for item in stat_items], msg="查询成功")
except Exception as e: except Exception as e:
api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}") api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e))
@@ -286,17 +340,26 @@ async def get_graph_data_api(
depth=depth, depth=depth,
center_node_id=center_node_id center_node_id=center_node_id
) )
# 使用 schema 模型构建响应
data = GraphData(
nodes=[GraphNodeData(**n) for n in result.get("nodes", [])],
edges=[GraphEdgeData(**e) for e in result.get("edges", [])],
statistics=GraphStatistics(**result.get("statistics", {})),
message=result.get("message"),
)
# 检查是否有错误消息 # 检查是否有错误消息
if "message" in result and result["statistics"]["total_nodes"] == 0: if data.message and data.statistics.total_nodes == 0:
api_logger.warning(f"图数据查询返回空结果: {result.get('message')}") api_logger.warning(f"图数据查询返回空结果: {data.message}")
return success(data=result, msg=result.get("message", "查询成功")) return success(data=data.model_dump(), msg=data.message)
api_logger.info( api_logger.info(
f"成功获取图数据: end_user_id={end_user_id}, " f"成功获取图数据: end_user_id={end_user_id}, "
f"nodes={result['statistics']['total_nodes']}, " f"nodes={data.statistics.total_nodes}, "
f"edges={result['statistics']['total_edges']}" f"edges={data.statistics.total_edges}"
) )
return success(data=result, msg="查询成功") return success(data=data.model_dump(), msg="查询成功")
except Exception as e: except Exception as e:
api_logger.error(f"图数据查询失败: end_user_id={end_user_id}, error={str(e)}") api_logger.error(f"图数据查询失败: end_user_id={end_user_id}, error={str(e)}")
@@ -323,16 +386,24 @@ async def get_community_graph_data_api(
try: try:
result = await analytics_community_graph_data(db=db, end_user_id=end_user_id) result = await analytics_community_graph_data(db=db, end_user_id=end_user_id)
if "message" in result and result["statistics"]["total_nodes"] == 0: # 使用 schema 模型构建响应
api_logger.warning(f"社区图谱查询返回空结果: {result.get('message')}") data = GraphData(
return success(data=result, msg=result.get("message", "查询成功")) nodes=[GraphNodeData(**n) for n in result.get("nodes", [])],
edges=[GraphEdgeData(**e) for e in result.get("edges", [])],
statistics=GraphStatistics(**result.get("statistics", {})),
message=result.get("message"),
)
if data.message and data.statistics.total_nodes == 0:
api_logger.warning(f"社区图谱查询返回空结果: {data.message}")
return success(data=data.model_dump(), msg=data.message)
api_logger.info( api_logger.info(
f"成功获取社区图谱: end_user_id={end_user_id}, " f"成功获取社区图谱: end_user_id={end_user_id}, "
f"nodes={result['statistics']['total_nodes']}, " f"nodes={data.statistics.total_nodes}, "
f"edges={result['statistics']['total_edges']}" f"edges={data.statistics.total_edges}"
) )
return success(data=result, msg="查询成功") return success(data=data.model_dump(), msg="查询成功")
except Exception as e: except Exception as e:
api_logger.error(f"社区图谱查询失败: end_user_id={end_user_id}, error={str(e)}") api_logger.error(f"社区图谱查询失败: end_user_id={end_user_id}, error={str(e)}")
@@ -495,13 +566,13 @@ async def memory_space_relationship_evolution(id: str, label: str,
await emotion.close() await emotion.close()
await interaction.close() await interaction.close()
result = { data = RelationshipEvolutionData(
"emotion": emotion_result, emotion=emotion_result,
"interaction": interaction_result interaction=interaction_result,
} )
api_logger.info(f"关系演变查询成功: id={id}, table={label}") api_logger.info(f"关系演变查询成功: id={id}, table={label}")
return success(data=result, msg="关系演变") return success(data=data.model_dump(), msg="关系演变")
except Exception as e: except Exception as e:
api_logger.error(f"关系演变查询失败: id={id}, table={label}, error={str(e)}", exc_info=True) api_logger.error(f"关系演变查询失败: id={id}, table={label}, error={str(e)}", exc_info=True)

View File

@@ -0,0 +1,118 @@
"""
用户记忆相关的请求和响应模型
包含用户摘要、记忆洞察、节点统计、图数据和用户档案等接口的 Schema
"""
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
# ==================== 记忆洞察报告 ====================
class MemoryInsightReportData(BaseModel):
"""记忆洞察报告数据"""
memory_insight: Optional[str] = Field(None, description="总体概述")
behavior_pattern: Optional[str] = Field(None, description="行为模式")
key_findings: Optional[List[str]] = Field(None, description="关键发现")
growth_trajectory: Optional[str] = Field(None, description="成长轨迹")
updated_at: Optional[int] = Field(None, description="更新时间戳(毫秒)")
is_cached: bool = Field(..., description="是否有缓存数据")
message: Optional[str] = Field(None, description="附加消息")
# ==================== 用户摘要 ====================
class UserSummaryData(BaseModel):
"""用户摘要数据"""
user_summary: Optional[str] = Field(None, description="用户摘要")
personality: Optional[str] = Field(None, description="性格特征")
core_values: Optional[str] = Field(None, description="核心价值观")
one_sentence: Optional[str] = Field(None, description="一句话总结")
updated_at: Optional[int] = Field(None, description="更新时间戳(毫秒)")
is_cached: bool = Field(..., description="是否有缓存数据")
message: Optional[str] = Field(None, description="附加消息")
# ==================== 缓存生成 ====================
class GenerateCacheErrorItem(BaseModel):
"""缓存生成错误项"""
type: Optional[str] = Field(None, description="错误类型 (insight/summary)")
error: Optional[str] = Field(None, description="错误信息")
class SingleUserCacheResultData(BaseModel):
"""单用户缓存生成结果"""
end_user_id: str = Field(..., description="终端用户ID")
insight_success: bool = Field(..., description="洞察生成是否成功")
summary_success: bool = Field(..., description="摘要生成是否成功")
errors: List[GenerateCacheErrorItem] = Field(default_factory=list, description="错误列表")
class WorkspaceCacheErrorItem(BaseModel):
"""工作空间缓存生成错误项"""
end_user_id: Optional[str] = Field(None, description="终端用户ID")
insight_error: Optional[str] = Field(None, description="洞察生成错误")
summary_error: Optional[str] = Field(None, description="摘要生成错误")
error: Optional[str] = Field(None, description="通用错误信息")
class WorkspaceCacheResultData(BaseModel):
"""工作空间批量缓存生成结果"""
total_users: int = Field(..., description="总用户数")
successful: int = Field(..., description="成功数")
failed: int = Field(..., description="失败数")
errors: List[WorkspaceCacheErrorItem] = Field(default_factory=list, description="错误列表")
# ==================== 节点统计 ====================
class MemoryTypeStatItem(BaseModel):
"""记忆类型统计项"""
type: str = Field(..., description="记忆类型枚举值")
count: int = Field(..., description="该类型的数量")
percentage: float = Field(..., description="该类型在所有记忆中的占比")
# ==================== 图数据 ====================
class GraphNodeData(BaseModel):
"""图节点数据"""
id: str = Field(..., description="节点ID")
label: str = Field(..., description="节点类型标签")
properties: Dict[str, Any] = Field(default_factory=dict, description="节点属性")
caption: Optional[str] = Field(None, description="节点显示名称")
class GraphEdgeData(BaseModel):
"""图边数据"""
id: str = Field(..., description="边ID")
source: str = Field(..., description="源节点ID")
target: str = Field(..., description="目标节点ID")
type: Optional[str] = Field(None, description="关系类型")
properties: Dict[str, Any] = Field(default_factory=dict, description="边属性")
caption: Optional[str] = Field(None, description="边显示名称")
class GraphStatistics(BaseModel):
"""图统计信息"""
total_nodes: int = Field(0, description="节点总数")
total_edges: int = Field(0, description="边总数")
node_types: Dict[str, int] = Field(default_factory=dict, description="各节点类型数量")
edge_types: Optional[Dict[str, int]] = Field(default_factory=dict, description="各边类型数量")
class GraphData(BaseModel):
"""图数据响应"""
nodes: List[GraphNodeData] = Field(..., description="节点列表")
edges: List[GraphEdgeData] = Field(..., description="边列表")
statistics: GraphStatistics = Field(..., description="统计信息")
message: Optional[str] = Field(None, description="附加消息")
# ==================== 关系演变 ====================
class RelationshipEvolutionData(BaseModel):
"""关系演变数据"""
emotion: Any = Field(None, description="情绪数据")
interaction: Any = Field(None, description="交互频率数据")