From 9a5ce7f7c65477372486b82d813a7efda12b970b Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 27 Apr 2026 17:57:06 +0800 Subject: [PATCH 1/2] refactor(memory): replace raw dict responses with Pydantic schema models in user memory controllers - Add user_memory_schema.py with typed Pydantic models for all user memory API responses: MemoryInsightReportData, UserSummaryData, GraphData, MemoryTypeStatItem, cache result models, and RelationshipEvolutionData - Refactor user_memory_controllers.py to construct schema instances and return model_dump() instead of raw dicts - Remove unused imports (datetime, timestamp_to_datetime, EndUserInfoResponse, EndUserInfoCreate, EndUser) --- .../controllers/user_memory_controllers.py | 177 ++++++++++++------ api/app/schemas/user_memory_schema.py | 118 ++++++++++++ 2 files changed, 242 insertions(+), 53 deletions(-) create mode 100644 api/app/schemas/user_memory_schema.py diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py index 10b396a7..e7f5db4d 100644 --- a/api/app/controllers/user_memory_controllers.py +++ b/api/app/controllers/user_memory_controllers.py @@ -2,8 +2,8 @@ 用户记忆相关的控制器 包含用户摘要、记忆洞察、节点统计、图数据和用户档案等接口 """ -from typing import Optional -import datetime +from typing import Optional, List + from sqlalchemy.orm import Session from fastapi import APIRouter, Depends, Header @@ -12,7 +12,6 @@ from app.core.language_utils import get_language_from_header from app.core.logging_config import get_api_logger from app.core.response_utils import success, fail from app.core.error_codes import BizCode -from app.core.api_key_utils import timestamp_to_datetime from app.services.user_memory_service import ( UserMemoryService, analytics_memory_types, @@ -22,14 +21,25 @@ from app.services.user_memory_service import ( from app.services.memory_entity_relationship_service import MemoryEntityService, MemoryEmotion, MemoryInteraction from app.schemas.response_schema import ApiResponse from app.schemas.memory_storage_schema import GenerateCacheRequest +from app.schemas.user_memory_schema import ( + MemoryInsightReportData, + UserSummaryData, + SingleUserCacheResultData, + GenerateCacheErrorItem, + WorkspaceCacheResultData, + WorkspaceCacheErrorItem, + MemoryTypeStatItem, + GraphData, + GraphNodeData, + GraphEdgeData, + GraphStatistics, + RelationshipEvolutionData, +) from app.repositories.workspace_repository import WorkspaceRepository from app.repositories.end_user_repository import EndUserRepository from app.schemas.end_user_info_schema import ( - EndUserInfoResponse, - EndUserInfoCreate, EndUserInfoUpdate, ) -from app.models.end_user_model import EndUser from app.dependencies import get_current_user from app.models.user_model import User @@ -61,13 +71,22 @@ async def get_memory_insight_report_api( try: # 调用服务层获取缓存数据 result = await user_memory_service.get_cached_memory_insight(db, end_user_id) + data = MemoryInsightReportData( + memory_insight=result.get("memory_insight"), + behavior_pattern=result.get("behavior_pattern"), + key_findings=result.get("key_findings"), + growth_trajectory=result.get("growth_trajectory"), + updated_at=result.get("updated_at"), + is_cached=result["is_cached"], + message=result.get("message"), + ) - if result["is_cached"]: + if data.is_cached: api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}") - return success(data=result, msg="查询成功") + return success(data=data.model_dump(), msg="查询成功") else: api_logger.info(f"记忆洞察报告缓存不存在: end_user_id={end_user_id}") - return success(data=result, msg="数据尚未生成") + return success(data=data.model_dump(), msg="数据尚未生成") except Exception as e: api_logger.error(f"记忆洞察报告查询失败: end_user_id={end_user_id}, error={str(e)}") return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告查询失败", str(e)) @@ -105,13 +124,22 @@ async def get_user_summary_api( try: # 调用服务层获取缓存数据 result = await user_memory_service.get_cached_user_summary(db, end_user_id, model_id, language) + data = UserSummaryData( + user_summary=result.get("user_summary"), + personality=result.get("personality"), + core_values=result.get("core_values"), + one_sentence=result.get("one_sentence"), + updated_at=result.get("updated_at"), + is_cached=result["is_cached"], + message=result.get("message"), + ) - if result["is_cached"]: + if data.is_cached: api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}") - return success(data=result, msg="查询成功") + return success(data=data.model_dump(), msg="查询成功") else: api_logger.info(f"用户摘要缓存不存在: end_user_id={end_user_id}") - return success(data=result, msg="数据尚未生成") + return success(data=data.model_dump(), msg="数据尚未生成") except Exception as e: api_logger.error(f"用户摘要查询失败: end_user_id={end_user_id}, error={str(e)}") return fail(BizCode.INTERNAL_ERROR, "用户摘要查询失败", str(e)) @@ -165,32 +193,32 @@ async def generate_cache_api( language=language) # 构建响应 - result = { - "end_user_id": end_user_id, - "insight_success": insight_result["success"], - "summary_success": summary_result["success"], - "errors": [] - } - - # 收集错误信息 + errors: List[GenerateCacheErrorItem] = [] if not insight_result["success"]: - result["errors"].append({ - "type": "insight", - "error": insight_result.get("error") - }) + errors.append(GenerateCacheErrorItem( + type="insight", + error=insight_result.get("error"), + )) if not summary_result["success"]: - result["errors"].append({ - "type": "summary", - "error": summary_result.get("error") - }) + errors.append(GenerateCacheErrorItem( + type="summary", + error=summary_result.get("error"), + )) + + data = SingleUserCacheResultData( + end_user_id=end_user_id, + insight_success=insight_result["success"], + summary_success=summary_result["success"], + errors=errors, + ) # 记录结果 - if result["insight_success"] and result["summary_success"]: + if data.insight_success and data.summary_success: api_logger.info(f"成功为用户 {end_user_id} 生成缓存") else: - api_logger.warning(f"用户 {end_user_id} 的缓存生成部分失败: {result['errors']}") + api_logger.warning(f"用户 {end_user_id} 的缓存生成部分失败: {[e.model_dump() for e in errors]}") - return success(data=result, msg="生成完成") + return success(data=data.model_dump(), msg="生成完成") else: # 为整个工作空间生成 @@ -198,13 +226,29 @@ async def generate_cache_api( result = await user_memory_service.generate_cache_for_workspace(db, workspace_id, language=language) + ws_errors = [ + WorkspaceCacheErrorItem( + end_user_id=e.get("end_user_id"), + insight_error=e.get("insight_error"), + summary_error=e.get("summary_error"), + error=e.get("error"), + ) + for e in result.get("errors", []) + ] + data = WorkspaceCacheResultData( + total_users=result["total_users"], + successful=result["successful"], + failed=result["failed"], + errors=ws_errors, + ) + # 记录统计信息 api_logger.info( f"工作空间 {workspace_id} 批量生成完成: " - f"总数={result['total_users']}, 成功={result['successful']}, 失败={result['failed']}" + f"总数={data.total_users}, 成功={data.successful}, 失败={data.failed}" ) - return success(data=result, msg="批量生成完成") + return success(data=data.model_dump(), msg="批量生成完成") except Exception as e: api_logger.error(f"缓存生成失败: user={current_user.username}, error={str(e)}") @@ -231,11 +275,21 @@ async def get_node_statistics_api( # 调用新的记忆类型统计函数 result = await analytics_memory_types(db, end_user_id) + # 使用 schema 模型构建响应 + stat_items = [ + MemoryTypeStatItem( + type=item["type"], + count=item["count"], + percentage=item["percentage"], + ) + for item in result + ] + # 计算总数用于日志 - total_count = sum(item["count"] for item in result) + total_count = sum(item.count for item in stat_items) api_logger.info( - f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}") - return success(data=result, msg="查询成功") + f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(stat_items)}") + return success(data=[item.model_dump() for item in stat_items], msg="查询成功") except Exception as e: api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}") return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e)) @@ -286,17 +340,26 @@ async def get_graph_data_api( depth=depth, center_node_id=center_node_id ) + + # 使用 schema 模型构建响应 + data = GraphData( + nodes=[GraphNodeData(**n) for n in result.get("nodes", [])], + edges=[GraphEdgeData(**e) for e in result.get("edges", [])], + statistics=GraphStatistics(**result.get("statistics", {})), + message=result.get("message"), + ) + # 检查是否有错误消息 - if "message" in result and result["statistics"]["total_nodes"] == 0: - api_logger.warning(f"图数据查询返回空结果: {result.get('message')}") - return success(data=result, msg=result.get("message", "查询成功")) + if data.message and data.statistics.total_nodes == 0: + api_logger.warning(f"图数据查询返回空结果: {data.message}") + return success(data=data.model_dump(), msg=data.message) api_logger.info( f"成功获取图数据: end_user_id={end_user_id}, " - f"nodes={result['statistics']['total_nodes']}, " - f"edges={result['statistics']['total_edges']}" + f"nodes={data.statistics.total_nodes}, " + f"edges={data.statistics.total_edges}" ) - return success(data=result, msg="查询成功") + return success(data=data.model_dump(), msg="查询成功") except Exception as e: api_logger.error(f"图数据查询失败: end_user_id={end_user_id}, error={str(e)}") @@ -323,16 +386,24 @@ async def get_community_graph_data_api( try: result = await analytics_community_graph_data(db=db, end_user_id=end_user_id) - if "message" in result and result["statistics"]["total_nodes"] == 0: - api_logger.warning(f"社区图谱查询返回空结果: {result.get('message')}") - return success(data=result, msg=result.get("message", "查询成功")) + # 使用 schema 模型构建响应 + data = GraphData( + nodes=[GraphNodeData(**n) for n in result.get("nodes", [])], + edges=[GraphEdgeData(**e) for e in result.get("edges", [])], + statistics=GraphStatistics(**result.get("statistics", {})), + message=result.get("message"), + ) + + if data.message and data.statistics.total_nodes == 0: + api_logger.warning(f"社区图谱查询返回空结果: {data.message}") + return success(data=data.model_dump(), msg=data.message) api_logger.info( f"成功获取社区图谱: end_user_id={end_user_id}, " - f"nodes={result['statistics']['total_nodes']}, " - f"edges={result['statistics']['total_edges']}" + f"nodes={data.statistics.total_nodes}, " + f"edges={data.statistics.total_edges}" ) - return success(data=result, msg="查询成功") + return success(data=data.model_dump(), msg="查询成功") except Exception as e: api_logger.error(f"社区图谱查询失败: end_user_id={end_user_id}, error={str(e)}") @@ -495,13 +566,13 @@ async def memory_space_relationship_evolution(id: str, label: str, await emotion.close() await interaction.close() - result = { - "emotion": emotion_result, - "interaction": interaction_result - } + data = RelationshipEvolutionData( + emotion=emotion_result, + interaction=interaction_result, + ) api_logger.info(f"关系演变查询成功: id={id}, table={label}") - return success(data=result, msg="关系演变") + return success(data=data.model_dump(), msg="关系演变") except Exception as e: api_logger.error(f"关系演变查询失败: id={id}, table={label}, error={str(e)}", exc_info=True) diff --git a/api/app/schemas/user_memory_schema.py b/api/app/schemas/user_memory_schema.py new file mode 100644 index 00000000..ea6570b3 --- /dev/null +++ b/api/app/schemas/user_memory_schema.py @@ -0,0 +1,118 @@ +""" +用户记忆相关的请求和响应模型 +包含用户摘要、记忆洞察、节点统计、图数据和用户档案等接口的 Schema +""" +from typing import Optional, List, Dict, Any + +from pydantic import BaseModel, Field + + +# ==================== 记忆洞察报告 ==================== + +class MemoryInsightReportData(BaseModel): + """记忆洞察报告数据""" + memory_insight: Optional[str] = Field(None, description="总体概述") + behavior_pattern: Optional[str] = Field(None, description="行为模式") + key_findings: Optional[List[str]] = Field(None, description="关键发现") + growth_trajectory: Optional[str] = Field(None, description="成长轨迹") + updated_at: Optional[int] = Field(None, description="更新时间戳(毫秒)") + is_cached: bool = Field(..., description="是否有缓存数据") + message: Optional[str] = Field(None, description="附加消息") + + +# ==================== 用户摘要 ==================== + +class UserSummaryData(BaseModel): + """用户摘要数据""" + user_summary: Optional[str] = Field(None, description="用户摘要") + personality: Optional[str] = Field(None, description="性格特征") + core_values: Optional[str] = Field(None, description="核心价值观") + one_sentence: Optional[str] = Field(None, description="一句话总结") + updated_at: Optional[int] = Field(None, description="更新时间戳(毫秒)") + is_cached: bool = Field(..., description="是否有缓存数据") + message: Optional[str] = Field(None, description="附加消息") + + +# ==================== 缓存生成 ==================== + +class GenerateCacheErrorItem(BaseModel): + """缓存生成错误项""" + type: Optional[str] = Field(None, description="错误类型 (insight/summary)") + error: Optional[str] = Field(None, description="错误信息") + + +class SingleUserCacheResultData(BaseModel): + """单用户缓存生成结果""" + end_user_id: str = Field(..., description="终端用户ID") + insight_success: bool = Field(..., description="洞察生成是否成功") + summary_success: bool = Field(..., description="摘要生成是否成功") + errors: List[GenerateCacheErrorItem] = Field(default_factory=list, description="错误列表") + + +class WorkspaceCacheErrorItem(BaseModel): + """工作空间缓存生成错误项""" + end_user_id: Optional[str] = Field(None, description="终端用户ID") + insight_error: Optional[str] = Field(None, description="洞察生成错误") + summary_error: Optional[str] = Field(None, description="摘要生成错误") + error: Optional[str] = Field(None, description="通用错误信息") + + +class WorkspaceCacheResultData(BaseModel): + """工作空间批量缓存生成结果""" + total_users: int = Field(..., description="总用户数") + successful: int = Field(..., description="成功数") + failed: int = Field(..., description="失败数") + errors: List[WorkspaceCacheErrorItem] = Field(default_factory=list, description="错误列表") + + +# ==================== 节点统计 ==================== + +class MemoryTypeStatItem(BaseModel): + """记忆类型统计项""" + type: str = Field(..., description="记忆类型枚举值") + count: int = Field(..., description="该类型的数量") + percentage: float = Field(..., description="该类型在所有记忆中的占比") + + +# ==================== 图数据 ==================== + +class GraphNodeData(BaseModel): + """图节点数据""" + id: str = Field(..., description="节点ID") + label: str = Field(..., description="节点类型标签") + properties: Dict[str, Any] = Field(default_factory=dict, description="节点属性") + caption: Optional[str] = Field(None, description="节点显示名称") + + +class GraphEdgeData(BaseModel): + """图边数据""" + id: str = Field(..., description="边ID") + source: str = Field(..., description="源节点ID") + target: str = Field(..., description="目标节点ID") + type: Optional[str] = Field(None, description="关系类型") + properties: Dict[str, Any] = Field(default_factory=dict, description="边属性") + caption: Optional[str] = Field(None, description="边显示名称") + + +class GraphStatistics(BaseModel): + """图统计信息""" + total_nodes: int = Field(0, description="节点总数") + total_edges: int = Field(0, description="边总数") + node_types: Dict[str, int] = Field(default_factory=dict, description="各节点类型数量") + edge_types: Optional[Dict[str, int]] = Field(default_factory=dict, description="各边类型数量") + + +class GraphData(BaseModel): + """图数据响应""" + nodes: List[GraphNodeData] = Field(..., description="节点列表") + edges: List[GraphEdgeData] = Field(..., description="边列表") + statistics: GraphStatistics = Field(..., description="统计信息") + message: Optional[str] = Field(None, description="附加消息") + + +# ==================== 关系演变 ==================== + +class RelationshipEvolutionData(BaseModel): + """关系演变数据""" + emotion: Any = Field(None, description="情绪数据") + interaction: Any = Field(None, description="交互频率数据") From 2fa4d295487b61c9f23d05f83e438e6d070ab79c Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 27 Apr 2026 18:39:33 +0800 Subject: [PATCH 2/2] fix(memory): use explicit None checks and remove unnecessary Optional type - Replace truthiness checks with 'is not None' for data.message in graph_data and community_graph endpoints to handle empty string correctly - Remove Optional wrapper from GraphStatistics.edge_types since it already has a default_factory --- api/app/controllers/user_memory_controllers.py | 4 ++-- api/app/schemas/user_memory_schema.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py index e7f5db4d..c8d24d92 100644 --- a/api/app/controllers/user_memory_controllers.py +++ b/api/app/controllers/user_memory_controllers.py @@ -350,7 +350,7 @@ async def get_graph_data_api( ) # 检查是否有错误消息 - if data.message and data.statistics.total_nodes == 0: + if data.message is not None and data.statistics.total_nodes == 0: api_logger.warning(f"图数据查询返回空结果: {data.message}") return success(data=data.model_dump(), msg=data.message) @@ -394,7 +394,7 @@ async def get_community_graph_data_api( message=result.get("message"), ) - if data.message and data.statistics.total_nodes == 0: + if data.message is not None and data.statistics.total_nodes == 0: api_logger.warning(f"社区图谱查询返回空结果: {data.message}") return success(data=data.model_dump(), msg=data.message) diff --git a/api/app/schemas/user_memory_schema.py b/api/app/schemas/user_memory_schema.py index ea6570b3..e0149ceb 100644 --- a/api/app/schemas/user_memory_schema.py +++ b/api/app/schemas/user_memory_schema.py @@ -99,7 +99,7 @@ class GraphStatistics(BaseModel): total_nodes: int = Field(0, description="节点总数") total_edges: int = Field(0, description="边总数") node_types: Dict[str, int] = Field(default_factory=dict, description="各节点类型数量") - edge_types: Optional[Dict[str, int]] = Field(default_factory=dict, description="各边类型数量") + edge_types: Dict[str, int] = Field(default_factory=dict, description="各边类型数量") class GraphData(BaseModel):