From 42e569b8e59f84be60f73fbf5f33ad23b0ae5677 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= Date: Tue, 23 Dec 2025 08:05:06 +0000 Subject: [PATCH] Merge #31 into develop from memory-summary MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [feature]开发用户记忆详情的接口 * memory-summary: (69 commits squashed) - [feature]Memory Insights and User Summary Cache Storage Ingestion - [feature]Develop a memory classification interface - [feature]Memory Insights and User Summary Cache Storage Ingestion - [feature]Develop a memory classification interface - Merge branch 'memory-summary' of codeup.aliyun.com:redbearai/python/redbear-mem-open into memory-summary - [feature]Develop the relationship graph interface - [feature]Memory Insights and User Summary Cache Storage Ingestion - [feature]Develop a memory classification interface - [feature]Memory Insights and User Summary Cache Storage Ingestion - [feature]Develop a memory classification interface - [feature]Develop the relationship graph interface - Merge branch 'memory-summary' of codeup.aliyun.com:redbearai/python/redbear-mem-open into memory-summary - [feature]Develop the end_user/profile interface - [feature]Memory Insights and User Summary Cache Storage Ingestion - [feature]Memory Insights and User Summary Cache Storage Ingestion - [feature]Develop a memory classification interface - [feature]Develop the relationship graph interface - [feature]Develop the end_user/profile interface - Merge branch 'memory-summary' of codeup.aliyun.com:redbearai/python/redbear-mem-open into memory-summary - [updated]Base change operation - [refactor]1.Convert timestamp;2.Remove unnecessary code - [feature]Memory Insights and User Summary Cache Storage Ingestion - [feature]Develop a memory classification interface - [feature]Memory Insights and User Summary Cache Storage Ingestion - [feature]Develop a memory classification interface - [feature]Develop the relationship graph 
interface - [feature]Memory Insights and User Summary Cache Storage Ingestion - [feature]Memory Insights and User Summary Cache Storage Ingestion - [feature]Develop a memory classification interface - [feature]Develop the relationship graph interface - [feature]Develop the end_user/profile interface - [feature]Memory Insights and User Summary Cache Storage Ingestion - [feature]Develop the relationship graph interface - [feature]Develop the end_user/profile interface - [updated]Base change operation - [refactor]1.Convert timestamp;2.Remove unnecessary code - Merge branch 'memory-summary' of codeup.aliyun.com:redbearai/python/redbear-mem-open into memory-summary - [check]check_code.py checks the quality of the code - [fix]Fix insecure database connections - [refactor]refactor memory_storage_controller and memory_storage_service - [add]The /total_memory_count interface returns the "name" field. - [feature]Memory Insights and User Summary Cache Storage Ingestion - [feature]Develop a memory classification interface - [feature]Memory Insights and User Summary Cache Storage Ingestion - [feature]Develop a memory classification interface - [feature]Develop the relationship graph interface - [feature]Memory Insights and User Summary Cache Storage Ingestion - [feature]Memory Insights and User Summary Cache Storage Ingestion - [feature]Develop a memory classification interface - [feature]Develop the relationship graph interface - [feature]Develop the end_user/profile interface - [feature]Memory Insights and User Summary Cache Storage Ingestion - [feature]Develop the relationship graph interface - [feature]Develop the end_user/profile interface - [updated]Base change operation - [refactor]1.Convert timestamp;2.Remove unnecessary code - [feature]Memory Insights and User Summary Cache Storage Ingestion - [feature]Develop a memory classification interface - [feature]Develop the relationship graph interface - [feature]Develop the end_user/profile interface - [updated]Base change 
operation - [refactor]1.Convert timestamp;2.Remove unnecessary code - [check]check_code.py checks the quality of the code - [fix]Fix insecure database connections - [refactor]refactor memory_storage_controller and memory_storage_service - [add]The /total_memory_count interface returns the "name" field. - Merge branch 'memory-summary' of codeup.aliyun.com:redbearai/python/redbear-mem-open into memory-summary - [refactor]Reconstruct the user's memory location - add uv.lock Signed-off-by: 乐力齐 Reviewed-by: aliyun6762716068 Merged-by: aliyun6762716068 CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/31 --- api/app/celery_app.py | 6 + api/app/controllers/__init__.py | 2 + .../memory_dashboard_controller.py | 2 +- .../controllers/memory_storage_controller.py | 45 +- .../controllers/user_memory_controllers.py | 382 ++++++++ api/app/core/config.py | 3 + api/app/models/end_user_model.py | 17 +- api/app/repositories/end_user_repository.py | 182 +++- api/app/schemas/end_user_schema.py | 34 + api/app/schemas/memory_storage_schema.py | 9 + api/app/services/memory_dashboard_service.py | 20 +- api/app/services/memory_storage_service.py | 20 +- api/app/services/user_memory_service.py | 831 ++++++++++++++++++ api/app/tasks.py | 708 +++++++++------ api/env.example | 5 + 15 files changed, 1948 insertions(+), 318 deletions(-) create mode 100644 api/app/controllers/user_memory_controllers.py create mode 100644 api/app/services/user_memory_service.py diff --git a/api/app/celery_app.py b/api/app/celery_app.py index ce7e9300..44ae9ab2 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -83,6 +83,7 @@ celery_app.autodiscover_tasks(['app']) reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS) health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS) memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS) +memory_cache_regeneration_schedule = 
timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS) workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME # 构建定时任务配置 beat_schedule_config = { @@ -97,6 +98,11 @@ beat_schedule_config = { "schedule": workspace_reflection_schedule, "args": (), }, + "regenerate-memory-cache": { + "task": "app.tasks.regenerate_memory_cache", + "schedule": memory_cache_regeneration_schedule, + "args": (), + }, } # 如果配置了默认工作空间ID,则添加记忆总量统计任务 diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index 5cfbe536..c72072eb 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -35,6 +35,7 @@ from . import ( tool_controller, tool_execution_controller, ) +from . import user_memory_controllers # 创建管理端 API 路由器 manager_router = APIRouter() @@ -58,6 +59,7 @@ manager_router.include_router(upload_controller.router) manager_router.include_router(memory_agent_controller.router) manager_router.include_router(memory_dashboard_controller.router) manager_router.include_router(memory_storage_controller.router) +manager_router.include_router(user_memory_controllers.router) manager_router.include_router(api_key_controller.router) manager_router.include_router(release_share_controller.router) manager_router.include_router(public_share_controller.router) # 公开路由(无需认证) diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index 4a01c575..5166d012 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -287,7 +287,7 @@ async def get_workspace_total_memory_count( "total_memory_count": int, "host_count": int, "details": [ - {"host_id": "uuid", "count": 100}, + {"end_user_id": "uuid", "count": 100, "name": "用户名称"}, ... 
] } diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index 89daf9ce..0fae66fb 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -1,8 +1,9 @@ -from typing import Optional, Union +from typing import Optional import os import uuid +import datetime from sqlalchemy.orm import Session -from fastapi import APIRouter, Depends, UploadFile +from fastapi import APIRouter, Depends from fastapi.responses import StreamingResponse @@ -10,6 +11,7 @@ from app.db import get_db from app.core.logging_config import get_api_logger from app.core.response_utils import success, fail from app.core.error_codes import BizCode +from app.core.memory.utils.self_reflexion_utils import self_reflexion from app.services.memory_storage_service import ( MemoryStorageService, DataConfigService, @@ -23,9 +25,7 @@ from app.services.memory_storage_service import ( search_edges, search_entity_graph, analytics_hot_memory_tags, - analytics_memory_insight_report, analytics_recent_activity_stats, - analytics_user_summary, ) from app.schemas.response_schema import ApiResponse from app.schemas.memory_storage_schema import ( @@ -36,10 +36,16 @@ from app.schemas.memory_storage_schema import ( ConfigUpdateForget, ConfigKey, ConfigPilotRun, + GenerateCacheRequest, ) -from app.core.memory.utils.config.definitions import reload_configuration_from_database +from app.schemas.end_user_schema import ( + EndUserProfileResponse, + EndUserProfileUpdate, +) +from app.models.end_user_model import EndUser from app.dependencies import get_current_user from app.models.user_model import User + # Get API logger api_logger = get_api_logger() @@ -489,20 +495,6 @@ async def get_hot_memory_tags_api( return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e)) -@router.get("/analytics/memory_insight/report", response_model=ApiResponse) -async def get_memory_insight_report_api( - end_user_id: Optional[str] = 
None, - current_user: User = Depends(get_current_user), - ) -> dict: - api_logger.info(f"Memory insight report requested for end_user_id: {end_user_id}") - try: - result = await analytics_memory_insight_report(end_user_id) - return success(data=result, msg="查询成功") - except Exception as e: - api_logger.error(f"Memory insight report failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告生成失败", str(e)) - - @router.get("/analytics/recent_activity_stats", response_model=ApiResponse) async def get_recent_activity_stats_api( current_user: User = Depends(get_current_user), @@ -516,20 +508,6 @@ async def get_recent_activity_stats_api( return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e)) -@router.get("/analytics/user_summary", response_model=ApiResponse) -async def get_user_summary_api( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: - api_logger.info(f"User summary requested for end_user_id: {end_user_id}") - try: - result = await analytics_user_summary(end_user_id) - return success(data=result, msg="查询成功") - except Exception as e: - api_logger.error(f"User summary failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "用户摘要生成失败", str(e)) - -from app.core.memory.utils.self_reflexion_utils import self_reflexion @router.get("/self_reflexion") async def self_reflexion_endpoint(host_id: uuid.UUID) -> str: """ @@ -541,3 +519,4 @@ async def self_reflexion_endpoint(host_id: uuid.UUID) -> str: 自我反思结果。 """ return await self_reflexion(host_id) + diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py new file mode 100644 index 00000000..5ff34d21 --- /dev/null +++ b/api/app/controllers/user_memory_controllers.py @@ -0,0 +1,382 @@ +""" +用户记忆相关的控制器 +包含用户摘要、记忆洞察、节点统计、图数据和用户档案等接口 +""" +from typing import Optional +import datetime +from sqlalchemy.orm import Session +from fastapi import APIRouter, Depends + +from app.db import get_db +from app.core.logging_config import 
get_api_logger +from app.core.response_utils import success, fail +from app.core.error_codes import BizCode +from app.services.user_memory_service import ( + UserMemoryService, + analytics_node_statistics, + analytics_graph_data, +) +from app.schemas.response_schema import ApiResponse +from app.schemas.memory_storage_schema import GenerateCacheRequest +from app.schemas.end_user_schema import ( + EndUserProfileResponse, + EndUserProfileUpdate, +) +from app.models.end_user_model import EndUser +from app.dependencies import get_current_user +from app.models.user_model import User + +# Get API logger +api_logger = get_api_logger() + +# Initialize service +user_memory_service = UserMemoryService() + +router = APIRouter( + prefix="/memory-storage", + tags=["User Memory"], +) + + +@router.get("/analytics/memory_insight/report", response_model=ApiResponse) +async def get_memory_insight_report_api( + end_user_id: str, # 使用 end_user_id + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), + ) -> dict: + """获取缓存的记忆洞察报告""" + api_logger.info(f"记忆洞察报告请求: end_user_id={end_user_id}, user={current_user.username}") + try: + # 调用服务层获取缓存数据 + result = await user_memory_service.get_cached_memory_insight(db, end_user_id) + + if result["is_cached"]: + # 缓存存在,返回缓存数据 + api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}") + return success(data=result, msg="查询成功") + else: + # 缓存不存在,返回提示消息 + api_logger.info(f"记忆洞察报告缓存不存在: end_user_id={end_user_id}") + return success(data=result, msg="查询成功") + except Exception as e: + api_logger.error(f"记忆洞察报告查询失败: end_user_id={end_user_id}, error={str(e)}") + return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告查询失败", str(e)) + + +@router.get("/analytics/user_summary", response_model=ApiResponse) +async def get_user_summary_api( + end_user_id: str, # 使用 end_user_id + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), + ) -> dict: + """获取缓存的用户摘要""" + api_logger.info(f"用户摘要请求: end_user_id={end_user_id}, 
user={current_user.username}") + try: + # 调用服务层获取缓存数据 + result = await user_memory_service.get_cached_user_summary(db, end_user_id) + + if result["is_cached"]: + # 缓存存在,返回缓存数据 + api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}") + return success(data=result, msg="查询成功") + else: + # 缓存不存在,返回提示消息 + api_logger.info(f"用户摘要缓存不存在: end_user_id={end_user_id}") + return success(data=result, msg="查询成功") + except Exception as e: + api_logger.error(f"用户摘要查询失败: end_user_id={end_user_id}, error={str(e)}") + return fail(BizCode.INTERNAL_ERROR, "用户摘要查询失败", str(e)) + + +@router.post("/analytics/generate_cache", response_model=ApiResponse) +async def generate_cache_api( + request: GenerateCacheRequest, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + """ + 手动触发缓存生成 + + - 如果提供 end_user_id,只为该用户生成 + - 如果不提供,为当前工作空间的所有用户生成 + """ + workspace_id = current_user.current_workspace_id + + # 检查用户是否已选择工作空间 + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + group_id = request.end_user_id + + api_logger.info( + f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, " + f"end_user_id={group_id if group_id else '全部用户'}" + ) + + try: + if group_id: + # 为单个用户生成 + api_logger.info(f"开始为单个用户生成缓存: end_user_id={group_id}") + + # 生成记忆洞察 + insight_result = await user_memory_service.generate_and_cache_insight(db, group_id, workspace_id) + + # 生成用户摘要 + summary_result = await user_memory_service.generate_and_cache_summary(db, group_id, workspace_id) + + # 构建响应 + result = { + "end_user_id": group_id, + "insight_success": insight_result["success"], + "summary_success": summary_result["success"], + "errors": [] + } + + # 收集错误信息 + if not insight_result["success"]: + result["errors"].append({ + "type": "insight", + "error": insight_result.get("error") + }) + if not summary_result["success"]: + 
result["errors"].append({ + "type": "summary", + "error": summary_result.get("error") + }) + + # 记录结果 + if result["insight_success"] and result["summary_success"]: + api_logger.info(f"成功为用户 {group_id} 生成缓存") + else: + api_logger.warning(f"用户 {group_id} 的缓存生成部分失败: {result['errors']}") + + return success(data=result, msg="生成完成") + + else: + # 为整个工作空间生成 + api_logger.info(f"开始为工作空间 {workspace_id} 批量生成缓存") + + result = await user_memory_service.generate_cache_for_workspace(db, workspace_id) + + # 记录统计信息 + api_logger.info( + f"工作空间 {workspace_id} 批量生成完成: " + f"总数={result['total_users']}, 成功={result['successful']}, 失败={result['failed']}" + ) + + return success(data=result, msg="批量生成完成") + + except Exception as e: + api_logger.error(f"缓存生成失败: user={current_user.username}, error={str(e)}") + return fail(BizCode.INTERNAL_ERROR, "缓存生成失败", str(e)) + + +@router.get("/analytics/node_statistics", response_model=ApiResponse) +async def get_node_statistics_api( + end_user_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + workspace_id = current_user.current_workspace_id + + # 检查用户是否已选择工作空间 + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + api_logger.info(f"节点统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}") + + try: + result = await analytics_node_statistics(db, end_user_id) + + # 检查是否有错误消息 + if "message" in result and result["total"] == 0: + api_logger.warning(f"节点统计查询返回空结果: {result.get('message')}") + return success(data=result, msg=result.get("message", "查询成功")) + + api_logger.info(f"成功获取节点统计: end_user_id={end_user_id}, total={result['total']}") + return success(data=result, msg="查询成功") + except Exception as e: + api_logger.error(f"用户摘要查询失败: end_user_id={end_user_id}, error={str(e)}") + return fail(BizCode.INTERNAL_ERROR, "用户摘要查询失败", str(e)) + 
+@router.get("/analytics/graph_data", response_model=ApiResponse) +async def get_graph_data_api( + end_user_id: str, + node_types: Optional[str] = None, + limit: int = 100, + depth: int = 1, + center_node_id: Optional[str] = None, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + workspace_id = current_user.current_workspace_id + + # 检查用户是否已选择工作空间 + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试查询图数据但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + # 参数验证 + if limit > 1000: + limit = 1000 + api_logger.warning("limit 参数超过最大值,已调整为 1000") + + if depth > 3: + depth = 3 + api_logger.warning("depth 参数超过最大值,已调整为 3") + + # 解析 node_types 参数 + node_types_list = None + if node_types: + node_types_list = [t.strip() for t in node_types.split(",") if t.strip()] + + api_logger.info( + f"图数据查询请求: end_user_id={end_user_id}, user={current_user.username}, " + f"workspace={workspace_id}, node_types={node_types_list}, limit={limit}, depth={depth}" + ) + + try: + result = await analytics_graph_data( + db=db, + end_user_id=end_user_id, + node_types=node_types_list, + limit=limit, + depth=depth, + center_node_id=center_node_id + ) + + # 检查是否有错误消息 + if "message" in result and result["statistics"]["total_nodes"] == 0: + api_logger.warning(f"图数据查询返回空结果: {result.get('message')}") + return success(data=result, msg=result.get("message", "查询成功")) + + api_logger.info( + f"成功获取图数据: end_user_id={end_user_id}, " + f"nodes={result['statistics']['total_nodes']}, " + f"edges={result['statistics']['total_edges']}" + ) + return success(data=result, msg="查询成功") + + except Exception as e: + api_logger.error(f"图数据查询失败: end_user_id={end_user_id}, error={str(e)}") + return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e)) + + +@router.get("/read_end_user/profile", response_model=ApiResponse) +async def get_end_user_profile( + end_user_id: str, + current_user: User = 
Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + workspace_id = current_user.current_workspace_id + + # 检查用户是否已选择工作空间 + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + api_logger.info( + f"用户信息查询请求: end_user_id={end_user_id}, user={current_user.username}, " + f"workspace={workspace_id}" + ) + + try: + # 查询终端用户 + end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first() + + if not end_user: + api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}") + return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}") + + # 构建响应数据 + profile_data = EndUserProfileResponse( + id=end_user.id, + name=end_user.name, + position=end_user.position, + department=end_user.department, + contact=end_user.contact, + phone=end_user.phone, + hire_date=end_user.hire_date, + updatetime_profile=end_user.updatetime_profile + ) + + api_logger.info(f"成功获取用户信息: end_user_id={end_user_id}") + return success(data=profile_data.model_dump(), msg="查询成功") + + except Exception as e: + api_logger.error(f"用户信息查询失败: end_user_id={end_user_id}, error={str(e)}") + return fail(BizCode.INTERNAL_ERROR, "用户信息查询失败", str(e)) + + +@router.post("/updated_end_user/profile", response_model=ApiResponse) +async def update_end_user_profile( + profile_update: EndUserProfileUpdate, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + """ + 更新终端用户的基本信息 + + 该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。 + 所有字段都是可选的,只更新提供的字段。 + + """ + workspace_id = current_user.current_workspace_id + end_user_id = profile_update.end_user_id + + # 检查用户是否已选择工作空间 + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + api_logger.info( + f"用户信息更新请求: end_user_id={end_user_id}, 
user={current_user.username}, " + f"workspace={workspace_id}" + ) + + try: + # 查询终端用户 + end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first() + + if not end_user: + api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}") + return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}") + + # 更新字段(只更新提供的非 None 字段,排除 end_user_id) + update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'}) + for field, value in update_data.items(): + if value is not None: + setattr(end_user, field, value) + + # 更新 updated_at 时间戳 + end_user.updated_at = datetime.datetime.now() + + # 更新 updatetime_profile 为当前时间戳(毫秒) + current_timestamp = int(datetime.datetime.now().timestamp() * 1000) + end_user.updatetime_profile = current_timestamp + + # 提交更改 + db.commit() + db.refresh(end_user) + + # 构建响应数据 + profile_data = EndUserProfileResponse( + id=end_user.id, + name=end_user.name, + position=end_user.position, + department=end_user.department, + contact=end_user.contact, + phone=end_user.phone, + hire_date=end_user.hire_date, + updatetime_profile=end_user.updatetime_profile + ) + + api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}, updatetime_profile={current_timestamp}") + return success(data=profile_data.model_dump(), msg="更新成功") + + except Exception as e: + db.rollback() + api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}") + return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", str(e)) diff --git a/api/app/core/config.py b/api/app/core/config.py index bf5ff45a..7f4a99ba 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -149,6 +149,9 @@ class Settings: MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24")) DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None) REFLECTION_INTERVAL_TIME:Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30)) + + # Memory Cache Regeneration 
Configuration + MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24")) # Memory Module Configuration (internal) MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output") diff --git a/api/app/models/end_user_model.py b/api/app/models/end_user_model.py index 2a9ed8da..0ef11ffa 100644 --- a/api/app/models/end_user_model.py +++ b/api/app/models/end_user_model.py @@ -1,6 +1,6 @@ import datetime import uuid -from sqlalchemy import Column, String, DateTime, ForeignKey +from sqlalchemy import Column, String, DateTime, ForeignKey, Text, BigInteger from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import relationship from app.db import Base @@ -17,6 +17,21 @@ class EndUser(Base): reflection_time = Column(DateTime, nullable=True) created_at = Column(DateTime, default=datetime.datetime.now) updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now) + + # 用户基本信息字段 + name = Column(String, nullable=True, comment="姓名") + position = Column(String, nullable=True, comment="职位") + department = Column(String, nullable=True, comment="部门") + contact = Column(String, nullable=True, comment="联系方式") + phone = Column(String, nullable=True, comment="电话") + hire_date = Column(BigInteger, nullable=True, comment="入职日期(时间戳,毫秒)") + updatetime_profile = Column(BigInteger, nullable=True, comment="核心档案信息最后更新时间(时间戳,毫秒)") + + # 缓存字段 - Cache fields for pre-computed analytics + memory_insight = Column(Text, nullable=True, comment="缓存的记忆洞察报告") + user_summary = Column(Text, nullable=True, comment="缓存的用户摘要") + memory_insight_updated_at = Column(DateTime, nullable=True, comment="洞察报告最后更新时间") + user_summary_updated_at = Column(DateTime, nullable=True, comment="用户摘要最后更新时间") # 与 App 的反向关系 app = relationship( diff --git a/api/app/repositories/end_user_repository.py b/api/app/repositories/end_user_repository.py index 07e45a48..69932101 100644 --- a/api/app/repositories/end_user_repository.py +++ 
b/api/app/repositories/end_user_repository.py @@ -1,8 +1,11 @@ from sqlalchemy.orm import Session from typing import List, Optional import uuid +import datetime from app.models.end_user_model import EndUser +from app.models.app_model import App +from app.models.workspace_model import Workspace from app.core.logging_config import get_db_logger @@ -92,6 +95,157 @@ class EndUserRepository: db_logger.error(f"获取或创建终端用户时出错: {str(e)}") raise + def get_by_id(self, end_user_id: uuid.UUID) -> Optional[EndUser]: + """根据ID获取终端用户(用于缓存操作) + + Args: + end_user_id: 终端用户ID + + Returns: + Optional[EndUser]: 终端用户对象,如果不存在则返回None + """ + try: + end_user = ( + self.db.query(EndUser) + .filter(EndUser.id == end_user_id) + .first() + ) + if end_user: + db_logger.debug(f"成功查询到终端用户 {end_user_id}") + else: + db_logger.debug(f"未找到终端用户 {end_user_id}") + return end_user + except Exception as e: + self.db.rollback() + db_logger.error(f"查询终端用户 {end_user_id} 时出错: {str(e)}") + raise + + def update_memory_insight( + self, + end_user_id: uuid.UUID, + insight: str + ) -> bool: + """更新记忆洞察缓存 + + Args: + end_user_id: 终端用户ID + insight: 记忆洞察内容 + + Returns: + bool: 更新成功返回True,否则返回False + """ + try: + updated_count = ( + self.db.query(EndUser) + .filter(EndUser.id == end_user_id) + .update( + { + EndUser.memory_insight: insight, + EndUser.memory_insight_updated_at: datetime.datetime.now() + }, + synchronize_session=False + ) + ) + + self.db.commit() + + if updated_count > 0: + db_logger.info(f"成功更新终端用户 {end_user_id} 的记忆洞察缓存") + return True + else: + db_logger.warning(f"未找到终端用户 {end_user_id},无法更新记忆洞察缓存") + return False + + except Exception as e: + self.db.rollback() + db_logger.error(f"更新终端用户 {end_user_id} 的记忆洞察缓存时出错: {str(e)}") + raise + + def update_user_summary( + self, + end_user_id: uuid.UUID, + summary: str + ) -> bool: + """更新用户摘要缓存 + + Args: + end_user_id: 终端用户ID + summary: 用户摘要内容 + + Returns: + bool: 更新成功返回True,否则返回False + """ + try: + updated_count = ( + self.db.query(EndUser) + .filter(EndUser.id 
== end_user_id) + .update( + { + EndUser.user_summary: summary, + EndUser.user_summary_updated_at: datetime.datetime.now() + }, + synchronize_session=False + ) + ) + + self.db.commit() + + if updated_count > 0: + db_logger.info(f"成功更新终端用户 {end_user_id} 的用户摘要缓存") + return True + else: + db_logger.warning(f"未找到终端用户 {end_user_id},无法更新用户摘要缓存") + return False + + except Exception as e: + self.db.rollback() + db_logger.error(f"更新终端用户 {end_user_id} 的用户摘要缓存时出错: {str(e)}") + raise + + def get_all_by_workspace(self, workspace_id: uuid.UUID) -> List[EndUser]: + """获取工作空间的所有终端用户 + + Args: + workspace_id: 工作空间ID + + Returns: + List[EndUser]: 终端用户列表 + """ + try: + end_users = ( + self.db.query(EndUser) + .join(App, EndUser.app_id == App.id) + .filter(App.workspace_id == workspace_id) + .all() + ) + db_logger.info(f"成功查询工作空间 {workspace_id} 下的 {len(end_users)} 个终端用户") + return end_users + except Exception as e: + self.db.rollback() + db_logger.error(f"查询工作空间 {workspace_id} 下的终端用户时出错: {str(e)}") + raise + + def get_all_active_workspaces(self) -> List[uuid.UUID]: + """获取所有活动工作空间的ID + + Returns: + List[uuid.UUID]: 活动工作空间ID列表 + """ + try: + workspace_ids = ( + self.db.query(Workspace.id) + .filter(Workspace.is_active) + .all() + ) + # 提取ID(查询返回的是元组列表) + workspace_id_list = [workspace_id[0] for workspace_id in workspace_ids] + db_logger.info(f"成功查询到 {len(workspace_id_list)} 个活动工作空间") + return workspace_id_list + except Exception as e: + self.db.rollback() + db_logger.error(f"查询活动工作空间时出错: {str(e)}") + raise + def get_end_users_by_app_id(db: Session, app_id: uuid.UUID) -> List[EndUser]: """根据应用ID查询宿主(返回 EndUser ORM 列表)""" repo = EndUserRepository(db) @@ -138,4 +292,30 @@ def update_end_user_other_name( except Exception as e: db.rollback() db_logger.error(f"更新宿主 {end_user_id} 的 other_name 时出错: {str(e)}") - raise \ No newline at end of file + raise + +# 新增的缓存操作函数(保持与类方法一致的接口) +def get_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]: + """根据ID获取终端用户(用于缓存操作)""" + repo = 
EndUserRepository(db) + return repo.get_by_id(end_user_id) + +def update_memory_insight(db: Session, end_user_id: uuid.UUID, insight: str) -> bool: + """更新记忆洞察缓存""" + repo = EndUserRepository(db) + return repo.update_memory_insight(end_user_id, insight) + +def update_user_summary(db: Session, end_user_id: uuid.UUID, summary: str) -> bool: + """更新用户摘要缓存""" + repo = EndUserRepository(db) + return repo.update_user_summary(end_user_id, summary) + +def get_all_by_workspace(db: Session, workspace_id: uuid.UUID) -> List[EndUser]: + """获取工作空间的所有终端用户""" + repo = EndUserRepository(db) + return repo.get_all_by_workspace(workspace_id) + +def get_all_active_workspaces(db: Session) -> List[uuid.UUID]: + """获取所有活动工作空间的ID""" + repo = EndUserRepository(db) + return repo.get_all_active_workspaces() diff --git a/api/app/schemas/end_user_schema.py b/api/app/schemas/end_user_schema.py index 74fc4a14..939d2d3e 100644 --- a/api/app/schemas/end_user_schema.py +++ b/api/app/schemas/end_user_schema.py @@ -16,3 +16,37 @@ class EndUser(BaseModel): reflection_time: Optional[datetime.datetime] = Field(description="反思时间", default_factory=datetime.datetime.now) created_at: datetime.datetime = Field(description="创建时间", default_factory=datetime.datetime.now) updated_at: datetime.datetime = Field(description="更新时间", default_factory=datetime.datetime.now) + + # 用户基本信息字段 + name: Optional[str] = Field(description="姓名", default=None) + position: Optional[str] = Field(description="职位", default=None) + department: Optional[str] = Field(description="部门", default=None) + contact: Optional[str] = Field(description="联系方式", default=None) + phone: Optional[str] = Field(description="电话", default=None) + hire_date: Optional[int] = Field(description="入职日期(时间戳,毫秒)", default=None) + updatetime_profile: Optional[int] = Field(description="核心档案信息最后更新时间(时间戳,毫秒)", default=None) + + +class EndUserProfileResponse(BaseModel): + """终端用户基本信息响应模型""" + model_config = ConfigDict(from_attributes=True) + + id: uuid.UUID = 
Field(description="终端用户ID") + name: Optional[str] = Field(description="姓名", default=None) + position: Optional[str] = Field(description="职位", default=None) + department: Optional[str] = Field(description="部门", default=None) + contact: Optional[str] = Field(description="联系方式", default=None) + phone: Optional[str] = Field(description="电话", default=None) + hire_date: Optional[int] = Field(description="入职日期(时间戳,毫秒)", default=None) + updatetime_profile: Optional[int] = Field(description="核心档案信息最后更新时间(时间戳,毫秒)", default=None) + + +class EndUserProfileUpdate(BaseModel): + """终端用户基本信息更新请求模型""" + end_user_id: str = Field(description="终端用户ID") + name: Optional[str] = Field(description="姓名", default=None) + position: Optional[str] = Field(description="职位", default=None) + department: Optional[str] = Field(description="部门", default=None) + contact: Optional[str] = Field(description="联系方式", default=None) + phone: Optional[str] = Field(description="电话", default=None) + hire_date: Optional[int] = Field(description="入职日期(时间戳,毫秒)", default=None) \ No newline at end of file diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index be249b5e..df70ec77 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -382,3 +382,12 @@ def fail( error=error_code, time=time or _now_ms(), ) + +class GenerateCacheRequest(BaseModel): + """缓存生成请求模型""" + model_config = ConfigDict(populate_by_name=True, extra="forbid") + + end_user_id: Optional[str] = Field( + None, + description="终端用户ID(UUID格式)。如果提供,只为该用户生成;如果不提供,为当前工作空间的所有用户生成" + ) diff --git a/api/app/services/memory_dashboard_service.py b/api/app/services/memory_dashboard_service.py index be4ec12f..6acc699a 100644 --- a/api/app/services/memory_dashboard_service.py +++ b/api/app/services/memory_dashboard_service.py @@ -268,10 +268,20 @@ async def get_workspace_total_memory_count( # 如果提供了 end_user_id,只查询该用户 if end_user_id: search_result = await 
memory_storage_service.search_all(end_user_id=end_user_id) + # 查询用户名称 + from app.repositories.end_user_repository import EndUserRepository + repo = EndUserRepository(db) + end_user = repo.get_by_id(uuid.UUID(end_user_id)) + user_name = end_user.name if end_user else None + return { "total_memory_count": search_result.get("total", 0), "host_count": 1, - "details": [{"end_user_id": end_user_id, "count": search_result.get("total", 0)}] + "details": [{ + "end_user_id": end_user_id, + "count": search_result.get("total", 0), + "name": user_name + }] } for host in hosts: @@ -287,17 +297,19 @@ async def get_workspace_total_memory_count( details.append({ "end_user_id": end_user_id_str, - "count": host_total + "count": host_total, + "name": host.name # 添加 name 字段 }) - business_logger.debug(f"EndUser {end_user_id_str} 记忆数: {host_total}") + business_logger.debug(f"EndUser {end_user_id_str} ({host.name}) 记忆数: {host_total}") except Exception as e: business_logger.warning(f"获取 end_user {host.id} 记忆数失败: {str(e)}") # 失败的 host 记为 0 details.append({ "end_user_id": str(host.id), - "count": 0 + "count": 0, + "name": host.name # 添加 name 字段 }) result = { diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 0548b704..2644cd8d 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -15,11 +15,9 @@ from sqlalchemy.orm import Session from dotenv import load_dotenv from app.models.user_model import User -from app.models.end_user_model import EndUser from app.core.logging_config import get_logger from app.utils.sse_utils import format_sse_message from app.schemas.memory_storage_schema import ( - ConfigFilter, ConfigPilotRun, ConfigParamsCreate, ConfigParamsDelete, @@ -34,7 +32,8 @@ from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags from app.core.memory.analytics.memory_insight import MemoryInsight from app.core.memory.analytics.recent_activity_stats import 
get_recent_activity_stats from app.core.memory.analytics.user_summary import generate_user_summary -from app.repositories.data_config_repository import DataConfigRepository +from app.repositories.end_user_repository import EndUserRepository +import uuid logger = get_logger(__name__) @@ -67,6 +66,7 @@ class MemoryStorageService: } return result + class DataConfigService: # 数据配置服务类(PostgreSQL) """Service layer for config params CRUD. @@ -85,7 +85,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) @staticmethod def _convert_timestamps_to_format(data_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """将 created_at 和 updated_at 字段从 datetime 对象转换为 YYYYMMDDHHmmss 格式""" - from datetime import datetime for item in data_list: for field in ['created_at', 'updated_at']: @@ -576,14 +575,6 @@ async def analytics_hot_memory_tags( return [{"name": t, "frequency": f} for t, f in top_tags] -async def analytics_memory_insight_report(end_user_id: Optional[str] = None) -> Dict[str, Any]: - insight = MemoryInsight(end_user_id) - report = await insight.generate_insight_report() - await insight.close() - data = {"report": report} - return data - - async def analytics_recent_activity_stats() -> Dict[str, Any]: stats, _msg = get_recent_activity_stats() total = ( @@ -617,8 +608,3 @@ async def analytics_recent_activity_stats() -> Dict[str, Any]: data = {"total": total, "stats": stats, "latest_relative": latest_relative} return data - -async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str, Any]: - summary = await generate_user_summary(end_user_id) - data = {"summary": summary} - return data \ No newline at end of file diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py new file mode 100644 index 00000000..a69c776e --- /dev/null +++ b/api/app/services/user_memory_service.py @@ -0,0 +1,831 @@ +""" +User Memory Service + +处理用户记忆相关的业务逻辑,包括记忆洞察、用户摘要、节点统计和图数据等。 +""" + +from typing import Dict, List, Optional, Any +import uuid 
+from sqlalchemy.orm import Session + +from app.core.logging_config import get_logger +from app.repositories.end_user_repository import EndUserRepository +from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.core.memory.analytics.memory_insight import MemoryInsight +from app.core.memory.analytics.user_summary import generate_user_summary + +logger = get_logger(__name__) + +# Neo4j connector instance +_neo4j_connector = Neo4jConnector() + + +class UserMemoryService: + """用户记忆服务类""" + + def __init__(self): + logger.info("UserMemoryService initialized") + + async def get_cached_memory_insight( + self, + db: Session, + end_user_id: str + ) -> Dict[str, Any]: + """ + 从数据库获取缓存的记忆洞察 + + Args: + db: 数据库会话 + end_user_id: 终端用户ID (UUID) + + Returns: + { + "report": str, + "updated_at": datetime, + "is_cached": bool + } + """ + try: + # 转换为UUID并查询用户 + user_uuid = uuid.UUID(end_user_id) + repo = EndUserRepository(db) + end_user = repo.get_by_id(user_uuid) + + if not end_user: + logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户") + return { + "report": None, + "updated_at": None, + "is_cached": False, + "message": "用户不存在" + } + + # 检查是否有缓存数据 + if end_user.memory_insight: + logger.info(f"成功获取 end_user_id {end_user_id} 的缓存记忆洞察") + return { + "report": end_user.memory_insight, + "updated_at": end_user.memory_insight_updated_at, + "is_cached": True + } + else: + logger.info(f"end_user_id {end_user_id} 的记忆洞察缓存为空") + return { + "report": None, + "updated_at": None, + "is_cached": False, + "message": "数据尚未生成,请稍后重试或联系管理员" + } + + except ValueError: + logger.error(f"无效的 end_user_id 格式: {end_user_id}") + return { + "report": None, + "updated_at": None, + "is_cached": False, + "message": "无效的用户ID格式" + } + except Exception as e: + logger.error(f"获取缓存记忆洞察时出错: {str(e)}") + raise + + async def get_cached_user_summary( + self, + db: Session, + end_user_id: str + ) -> Dict[str, Any]: + """ + 从数据库获取缓存的用户摘要 + + Args: + db: 数据库会话 + end_user_id: 终端用户ID (UUID) + + Returns: + 
{ + "summary": str, + "updated_at": datetime, + "is_cached": bool + } + """ + try: + # 转换为UUID并查询用户 + user_uuid = uuid.UUID(end_user_id) + repo = EndUserRepository(db) + end_user = repo.get_by_id(user_uuid) + + if not end_user: + logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户") + return { + "summary": None, + "updated_at": None, + "is_cached": False, + "message": "用户不存在" + } + + # 检查是否有缓存数据 + if end_user.user_summary: + logger.info(f"成功获取 end_user_id {end_user_id} 的缓存用户摘要") + return { + "summary": end_user.user_summary, + "updated_at": end_user.user_summary_updated_at, + "is_cached": True + } + else: + logger.info(f"end_user_id {end_user_id} 的用户摘要缓存为空") + return { + "summary": None, + "updated_at": None, + "is_cached": False, + "message": "数据尚未生成,请稍后重试或联系管理员" + } + + except ValueError: + logger.error(f"无效的 end_user_id 格式: {end_user_id}") + return { + "summary": None, + "updated_at": None, + "is_cached": False, + "message": "无效的用户ID格式" + } + except Exception as e: + logger.error(f"获取缓存用户摘要时出错: {str(e)}") + raise + + async def generate_and_cache_insight( + self, + db: Session, + end_user_id: str, + workspace_id: Optional[uuid.UUID] = None + ) -> Dict[str, Any]: + """ + 生成并缓存记忆洞察 + + Args: + db: 数据库会话 + end_user_id: 终端用户ID (UUID) + workspace_id: 工作空间ID (可选) + + Returns: + { + "success": bool, + "report": str, + "error": Optional[str] + } + """ + try: + logger.info(f"开始为 end_user_id {end_user_id} 生成记忆洞察") + + # 转换为UUID并查询用户 + user_uuid = uuid.UUID(end_user_id) + repo = EndUserRepository(db) + end_user = repo.get_by_id(user_uuid) + + if not end_user: + logger.error(f"end_user_id {end_user_id} 不存在") + return { + "success": False, + "report": None, + "error": "用户不存在" + } + + # 使用 end_user_id 调用分析函数 + try: + logger.info(f"使用 end_user_id={end_user_id} 生成记忆洞察") + result = await analytics_memory_insight_report(end_user_id) + report = result.get("report", "") + + if not report: + logger.warning(f"end_user_id {end_user_id} 的记忆洞察生成结果为空") + return { + "success": False, + 
"report": None, + "error": "生成的洞察报告为空,可能Neo4j中没有该用户的数据" + } + + # 更新数据库缓存 + success = repo.update_memory_insight(user_uuid, report) + + if success: + logger.info(f"成功为 end_user_id {end_user_id} 生成并缓存记忆洞察") + return { + "success": True, + "report": report, + "error": None + } + else: + logger.error(f"更新 end_user_id {end_user_id} 的记忆洞察缓存失败") + return { + "success": False, + "report": report, + "error": "数据库更新失败" + } + + except Exception as e: + logger.error(f"调用分析函数生成记忆洞察时出错: {str(e)}") + return { + "success": False, + "report": None, + "error": f"Neo4j或LLM服务不可用: {str(e)}" + } + + except ValueError: + logger.error(f"无效的 end_user_id 格式: {end_user_id}") + return { + "success": False, + "report": None, + "error": "无效的用户ID格式" + } + except Exception as e: + logger.error(f"生成并缓存记忆洞察时出错: {str(e)}") + return { + "success": False, + "report": None, + "error": str(e) + } + + async def generate_and_cache_summary( + self, + db: Session, + end_user_id: str, + workspace_id: Optional[uuid.UUID] = None + ) -> Dict[str, Any]: + """ + 生成并缓存用户摘要 + + Args: + db: 数据库会话 + end_user_id: 终端用户ID (UUID) + workspace_id: 工作空间ID (可选) + + Returns: + { + "success": bool, + "summary": str, + "error": Optional[str] + } + """ + try: + logger.info(f"开始为 end_user_id {end_user_id} 生成用户摘要") + + # 转换为UUID并查询用户 + user_uuid = uuid.UUID(end_user_id) + repo = EndUserRepository(db) + end_user = repo.get_by_id(user_uuid) + + if not end_user: + logger.error(f"end_user_id {end_user_id} 不存在") + return { + "success": False, + "summary": None, + "error": "用户不存在" + } + + # 使用 end_user_id 调用分析函数 + try: + logger.info(f"使用 end_user_id={end_user_id} 生成用户摘要") + result = await analytics_user_summary(end_user_id) + summary = result.get("summary", "") + + if not summary: + logger.warning(f"end_user_id {end_user_id} 的用户摘要生成结果为空") + return { + "success": False, + "summary": None, + "error": "生成的用户摘要为空,可能Neo4j中没有该用户的数据" + } + + # 更新数据库缓存 + success = repo.update_user_summary(user_uuid, summary) + + if success: + logger.info(f"成功为 
end_user_id {end_user_id} 生成并缓存用户摘要") + return { + "success": True, + "summary": summary, + "error": None + } + else: + logger.error(f"更新 end_user_id {end_user_id} 的用户摘要缓存失败") + return { + "success": False, + "summary": summary, + "error": "数据库更新失败" + } + + except Exception as e: + logger.error(f"调用分析函数生成用户摘要时出错: {str(e)}") + return { + "success": False, + "summary": None, + "error": f"Neo4j或LLM服务不可用: {str(e)}" + } + + except ValueError: + logger.error(f"无效的 end_user_id 格式: {end_user_id}") + return { + "success": False, + "summary": None, + "error": "无效的用户ID格式" + } + except Exception as e: + logger.error(f"生成并缓存用户摘要时出错: {str(e)}") + return { + "success": False, + "summary": None, + "error": str(e) + } + + async def generate_cache_for_workspace( + self, + db: Session, + workspace_id: uuid.UUID + ) -> Dict[str, Any]: + """ + 为整个工作空间生成缓存 + + Args: + db: 数据库会话 + workspace_id: 工作空间ID + + Returns: + { + "total_users": int, + "successful": int, + "failed": int, + "errors": List[Dict] + } + """ + logger.info(f"开始为工作空间 {workspace_id} 批量生成缓存") + + total_users = 0 + successful = 0 + failed = 0 + errors = [] + + try: + # 获取工作空间的所有终端用户 + repo = EndUserRepository(db) + end_users = repo.get_all_by_workspace(workspace_id) + total_users = len(end_users) + + logger.info(f"工作空间 {workspace_id} 共有 {total_users} 个终端用户") + + # 遍历每个用户并生成缓存 + for end_user in end_users: + end_user_id = str(end_user.id) + + try: + # 生成记忆洞察 + insight_result = await self.generate_and_cache_insight(db, end_user_id) + + # 生成用户摘要 + summary_result = await self.generate_and_cache_summary(db, end_user_id) + + # 检查是否都成功 + if insight_result["success"] and summary_result["success"]: + successful += 1 + logger.info(f"成功为终端用户 {end_user_id} 生成缓存") + else: + failed += 1 + error_info = { + "end_user_id": end_user_id, + "insight_error": insight_result.get("error"), + "summary_error": summary_result.get("error") + } + errors.append(error_info) + logger.warning(f"终端用户 {end_user_id} 的缓存生成部分失败: {error_info}") + + except 
Exception as e: + # 单个用户失败不影响其他用户 + failed += 1 + error_info = { + "end_user_id": end_user_id, + "error": str(e) + } + errors.append(error_info) + logger.error(f"为终端用户 {end_user_id} 生成缓存时出错: {str(e)}") + + # 记录统计信息 + logger.info( + f"工作空间 {workspace_id} 批量生成完成: " + f"总数={total_users}, 成功={successful}, 失败={failed}" + ) + + return { + "total_users": total_users, + "successful": successful, + "failed": failed, + "errors": errors + } + + except Exception as e: + logger.error(f"为工作空间 {workspace_id} 批量生成缓存时出错: {str(e)}") + return { + "total_users": total_users, + "successful": successful, + "failed": failed, + "errors": errors + [{"error": f"批量处理失败: {str(e)}"}] + } + + +# 独立的分析函数 + +async def analytics_memory_insight_report(end_user_id: Optional[str] = None) -> Dict[str, Any]: + """ + 生成记忆洞察报告 + + Args: + end_user_id: 可选的终端用户ID + + Returns: + 包含报告的字典 + """ + insight = MemoryInsight(end_user_id) + report = await insight.generate_insight_report() + await insight.close() + data = {"report": report} + return data + + +async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str, Any]: + """ + 生成用户摘要 + + Args: + end_user_id: 可选的终端用户ID + + Returns: + 包含摘要的字典 + """ + summary = await generate_user_summary(end_user_id) + data = {"summary": summary} + return data + + +async def analytics_node_statistics( + db: Session, + end_user_id: Optional[str] = None +) -> Dict[str, Any]: + """ + 统计 Neo4j 中四种节点类型的数量和百分比 + + Args: + db: 数据库会话 + end_user_id: 可选的终端用户ID (UUID),用于过滤特定用户的节点 + + Returns: + { + "total": int, # 总节点数 + "nodes": [ + { + "type": str, # 节点类型 + "count": int, # 节点数量 + "percentage": float # 百分比 + } + ] + } + """ + # 定义四种节点类型的查询 + node_types = ["Chunk", "MemorySummary", "Statement", "ExtractedEntity"] + + # 存储每种节点类型的计数 + node_counts = {} + + # 查询每种节点类型的数量 + for node_type in node_types: + # 构建查询语句 + if end_user_id: + query = f""" + MATCH (n:{node_type}) + WHERE n.group_id = $group_id + RETURN count(n) as count + """ + result = await 
_neo4j_connector.execute_query(query, group_id=end_user_id) + else: + query = f""" + MATCH (n:{node_type}) + RETURN count(n) as count + """ + result = await _neo4j_connector.execute_query(query) + + # 提取计数结果 + count = result[0]["count"] if result and len(result) > 0 else 0 + node_counts[node_type] = count + + # 计算总数 + total = sum(node_counts.values()) + + # 构建返回数据,包含百分比 + nodes = [] + for node_type in node_types: + count = node_counts[node_type] + percentage = round((count / total * 100), 2) if total > 0 else 0.0 + nodes.append({ + "type": node_type, + "count": count, + "percentage": percentage + }) + + data = { + "total": total, + "nodes": nodes + } + + return data + + +async def analytics_graph_data( + db: Session, + end_user_id: str, + node_types: Optional[List[str]] = None, + limit: int = 100, + depth: int = 1, + center_node_id: Optional[str] = None +) -> Dict[str, Any]: + """ + 获取 Neo4j 图数据,用于前端可视化 + + Args: + db: 数据库会话 + end_user_id: 终端用户ID + node_types: 可选的节点类型列表 + limit: 返回节点数量限制 + depth: 图遍历深度 + center_node_id: 可选的中心节点ID + + Returns: + 包含节点、边和统计信息的字典 + """ + try: + # 1. 获取 group_id + user_uuid = uuid.UUID(end_user_id) + repo = EndUserRepository(db) + end_user = repo.get_by_id(user_uuid) + + if not end_user: + logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户") + return { + "nodes": [], + "edges": [], + "statistics": { + "total_nodes": 0, + "total_edges": 0, + "node_types": {}, + "edge_types": {} + }, + "message": "用户不存在" + } + + # 2. 
构建节点查询 + if center_node_id: + # 基于中心节点的扩展查询 + node_query = f""" + MATCH path = (center)-[*1..{depth}]-(connected) + WHERE center.group_id = $group_id + AND elementId(center) = $center_node_id + WITH collect(DISTINCT center) + collect(DISTINCT connected) as all_nodes + UNWIND all_nodes as n + RETURN DISTINCT + elementId(n) as id, + labels(n)[0] as label, + properties(n) as properties + LIMIT $limit + """ + node_params = { + "group_id": end_user_id, + "center_node_id": center_node_id, + "limit": limit + } + elif node_types: + # 按节点类型过滤查询 + node_query = """ + MATCH (n) + WHERE n.group_id = $group_id + AND labels(n)[0] IN $node_types + RETURN + elementId(n) as id, + labels(n)[0] as label, + properties(n) as properties + LIMIT $limit + """ + node_params = { + "group_id": end_user_id, + "node_types": node_types, + "limit": limit + } + else: + # 查询所有节点 + node_query = """ + MATCH (n) + WHERE n.group_id = $group_id + RETURN + elementId(n) as id, + labels(n)[0] as label, + properties(n) as properties + LIMIT $limit + """ + node_params = { + "group_id": end_user_id, + "limit": limit + } + + # 执行节点查询 + node_results = await _neo4j_connector.execute_query(node_query, **node_params) + + # 3. 格式化节点数据 + nodes = [] + node_ids = [] + node_type_counts = {} + + for record in node_results: + node_id = record["id"] + node_label = record["label"] + node_props = record["properties"] + + # 根据节点类型提取需要的属性字段 + filtered_props = _extract_node_properties(node_label, node_props) + + # 直接使用数据库中的 caption,如果没有则使用节点类型作为默认值 + caption = filtered_props.get("caption", node_label) + + nodes.append({ + "id": node_id, + "label": node_label, + "properties": filtered_props, + "caption": caption + }) + + node_ids.append(node_id) + node_type_counts[node_label] = node_type_counts.get(node_label, 0) + 1 + + # 4. 
查询节点之间的关系 + if len(node_ids) > 0: + edge_query = """ + MATCH (n)-[r]->(m) + WHERE elementId(n) IN $node_ids + AND elementId(m) IN $node_ids + RETURN + elementId(r) as id, + elementId(n) as source, + elementId(m) as target, + type(r) as rel_type, + properties(r) as properties + """ + edge_results = await _neo4j_connector.execute_query( + edge_query, + node_ids=node_ids + ) + else: + edge_results = [] + + # 5. 格式化边数据 + edges = [] + edge_type_counts = {} + + for record in edge_results: + edge_id = record["id"] + source = record["source"] + target = record["target"] + rel_type = record["rel_type"] + edge_props = record["properties"] + + # 清理边属性中的 Neo4j 特殊类型 + # 对于边,我们保留所有属性,但清理特殊类型 + cleaned_edge_props = {} + if edge_props: + for key, value in edge_props.items(): + cleaned_edge_props[key] = _clean_neo4j_value(value) + + # 直接使用关系类型作为 caption,如果 properties 中有 caption 则使用它 + caption = cleaned_edge_props.get("caption", rel_type) + + edges.append({ + "id": edge_id, + "source": source, + "target": target, + "type": rel_type, + "properties": cleaned_edge_props, + "caption": caption + }) + + edge_type_counts[rel_type] = edge_type_counts.get(rel_type, 0) + 1 + + # 6. 
构建统计信息 + statistics = { + "total_nodes": len(nodes), + "total_edges": len(edges), + "node_types": node_type_counts, + "edge_types": edge_type_counts + } + + logger.info( + f"成功获取图数据: end_user_id={end_user_id}, " + f"nodes={len(nodes)}, edges={len(edges)}" + ) + + return { + "nodes": nodes, + "edges": edges, + "statistics": statistics + } + + except ValueError: + logger.error(f"无效的 end_user_id 格式: {end_user_id}") + return { + "nodes": [], + "edges": [], + "statistics": { + "total_nodes": 0, + "total_edges": 0, + "node_types": {}, + "edge_types": {} + }, + "message": "无效的用户ID格式" + } + except Exception as e: + logger.error(f"获取图数据失败: {str(e)}", exc_info=True) + raise + + +# 辅助函数 + +def _extract_node_properties(label: str, properties: Dict[str, Any]) -> Dict[str, Any]: + """ + 根据节点类型提取需要的属性字段 + + Args: + label: 节点类型标签 + properties: 节点的所有属性 + + Returns: + 过滤后的属性字典 + """ + # 定义每种节点类型需要的字段(白名单) + field_whitelist = { + "Dialogue": ["content", "created_at"], + "Chunk": ["content", "created_at"], + "Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption"], + "ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption"], + "MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段 + } + + # 获取该节点类型的白名单字段 + allowed_fields = field_whitelist.get(label, []) + + # 如果没有定义白名单,返回空字典(或者可以返回所有字段) + if not allowed_fields: + # 对于未定义的节点类型,只返回基本字段 + allowed_fields = ["name", "created_at", "caption"] + + # 提取白名单中的字段 + filtered_props = {} + for field in allowed_fields: + if field in properties: + value = properties[field] + # 清理 Neo4j 特殊类型 + filtered_props[field] = _clean_neo4j_value(value) + + return filtered_props + + +def _clean_neo4j_value(value: Any) -> Any: + """ + 清理单个值的 Neo4j 特殊类型 + + Args: + value: 需要清理的值 + + Returns: + 清理后的值 + """ + if value is None: + return None + + # 处理列表 + if isinstance(value, list): + return [_clean_neo4j_value(item) for item in value] + + # 处理字典 + if isinstance(value, dict): + 
return {k: _clean_neo4j_value(v) for k, v in value.items()} + + # 处理 Neo4j DateTime 类型 + if hasattr(value, '__class__') and 'neo4j.time' in str(type(value)): + try: + if hasattr(value, 'to_native'): + native_dt = value.to_native() + return native_dt.isoformat() + return str(value) + except Exception: + return str(value) + + # 处理其他 Neo4j 特殊类型 + if hasattr(value, '__class__') and 'neo4j' in str(type(value)): + try: + return str(value) + except Exception: + return None + + # 返回原始值 + return value diff --git a/api/app/tasks.py b/api/app/tasks.py index 39758275..55d6680c 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1,15 +1,13 @@ -import os import asyncio -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List import requests from datetime import datetime, timezone import time import uuid from math import ceil import redis -import json -from app.db import get_db +from app.db import get_db_context from app.models.document_model import Document from app.models.knowledge_model import Knowledge from app.core.rag.llm.cv_model import QWenCV @@ -48,124 +46,122 @@ def parse_document(file_path: str, document_id: uuid.UUID): """ Document parsing, vectorization, and storage """ - db = next(get_db()) # Manually call the generator - db_document = None - db_knowledge = None - progress_msg = f"{datetime.now().strftime('%H:%M:%S')} Task has been received.\n" - try: - db_document = db.query(Document).filter(Document.id == document_id).first() - db_knowledge = db.query(Knowledge).filter(Knowledge.id == db_document.kb_id).first() - # 1. 
Document parsing & segmentation - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Start to parse.\n" - start_time = time.time() - db_document.progress = 0.0 - db_document.progress_msg = progress_msg - db_document.process_begin_at = datetime.now(tz=timezone.utc) - db_document.process_duration = 0.0 - db_document.run = 1 - db.commit() - db.refresh(db_document) - - def progress_callback(prog=None, msg=None): - nonlocal progress_msg # Declare the use of an external progress_msg variable - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.\n" - # Prepare to configure chat_mdl、vision_model information - chat_model = Base( - key=db_knowledge.llm.api_keys[0].api_key, - model_name=db_knowledge.llm.api_keys[0].model_name, - base_url=db_knowledge.llm.api_keys[0].api_base - ) - vision_model = QWenCV( - key=db_knowledge.image2text.api_keys[0].api_key, - model_name=db_knowledge.image2text.api_keys[0].model_name, - lang="Chinese", - base_url=db_knowledge.image2text.api_keys[0].api_base - ) - from app.core.rag.app.naive import chunk - res = chunk(filename=file_path, - from_page=0, - to_page=100000, - callback=progress_callback, - vision_model=vision_model, - parser_config=db_document.parser_config, - is_root=False) - - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Finish parsing.\n" - db_document.progress = 0.8 - db_document.progress_msg = progress_msg - db.commit() - db.refresh(db_document) - - # 2. 
Document vectorization and storage - total_chunks = len(res) - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Generate {total_chunks} chunks.\n" - batch_size = 100 - total_batches = ceil(total_chunks / batch_size) - progress_per_batch = 0.2 / total_batches # Progress of each batch - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - # 2.1 Delete document vector index - vector_service.delete_by_metadata_field(key="document_id", value=str(document_id)) - # 2.2 Vectorize and import batch documents - for batch_start in range(0, total_chunks, batch_size): - batch_end = min(batch_start + batch_size, total_chunks) # prevent out-of-bounds - batch = res[batch_start: batch_end] # Retrieve the current batch - chunks = [] - - # Process the current batch - for idx_in_batch, item in enumerate(batch): - global_idx = batch_start + idx_in_batch # Calculate global index - metadata = { - "doc_id": uuid.uuid4().hex, - "file_id": str(db_document.file_id), - "file_name": db_document.file_name, - "file_created_at": int(db_document.created_at.timestamp() * 1000), - "document_id": str(db_document.id), - "knowledge_id": str(db_document.kb_id), - "sort_id": global_idx, - "status": 1, - } - if db_document.parser_config.get("auto_questions", 0): - topn = db_document.parser_config["auto_questions"] - cached = get_llm_cache(chat_model.model_name, item["content_with_weight"], "question", {"topn": topn}) - if not cached: - cached = question_proposal(chat_model, item["content_with_weight"], topn) - set_llm_cache(chat_model.model_name, item["content_with_weight"], cached, "question", {"topn": topn}) - chunks.append(DocumentChunk(page_content=f"question: {cached} answer: {item['content_with_weight']}", metadata=metadata)) - else: - chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata)) - - # Bulk segmented vector import - vector_service.add_chunks(chunks) - - # Update progress - db_document.progress += progress_per_batch - 
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Embedding progress ({db_document.progress}).\n" + with get_db_context() as db: + db_document = None + db_knowledge = None + progress_msg = f"{datetime.now().strftime('%H:%M:%S')} Task has been received.\n" + try: + db_document = db.query(Document).filter(Document.id == document_id).first() + db_knowledge = db.query(Knowledge).filter(Knowledge.id == db_document.kb_id).first() + # 1. Document parsing & segmentation + progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Start to parse.\n" + start_time = time.time() + db_document.progress = 0.0 db_document.progress_msg = progress_msg - db_document.process_duration = time.time() - start_time - db_document.run = 0 + db_document.process_begin_at = datetime.now(tz=timezone.utc) + db_document.process_duration = 0.0 + db_document.run = 1 db.commit() db.refresh(db_document) - # Vectorization and data entry completed - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Indexing done.\n" - db_document.chunk_num = total_chunks - db_document.progress = 1.0 - db_document.process_duration = time.time() - start_time - progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Task done ({db_document.process_duration}s).\n" - db_document.progress_msg = progress_msg - db_document.run = 0 - db.commit() - result = f"parse document '{db_document.file_name}' processed successfully." 
- return result - except Exception as e: - if 'db_document' in locals(): - db_document.progress_msg += f"Failed to vectorize and import the parsed document:{str(e)}\n" + def progress_callback(prog=None, msg=None): + nonlocal progress_msg # Declare the use of an external progress_msg variable + progress_msg += f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.\n" + # Prepare to configure chat_mdl、vision_model information + chat_model = Base( + key=db_knowledge.llm.api_keys[0].api_key, + model_name=db_knowledge.llm.api_keys[0].model_name, + base_url=db_knowledge.llm.api_keys[0].api_base + ) + vision_model = QWenCV( + key=db_knowledge.image2text.api_keys[0].api_key, + model_name=db_knowledge.image2text.api_keys[0].model_name, + lang="Chinese", + base_url=db_knowledge.image2text.api_keys[0].api_base + ) + from app.core.rag.app.naive import chunk + res = chunk(filename=file_path, + from_page=0, + to_page=100000, + callback=progress_callback, + vision_model=vision_model, + parser_config=db_document.parser_config, + is_root=False) + + progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Finish parsing.\n" + db_document.progress = 0.8 + db_document.progress_msg = progress_msg + db.commit() + db.refresh(db_document) + + # 2. 
Document vectorization and storage + total_chunks = len(res) + progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Generate {total_chunks} chunks.\n" + batch_size = 100 + total_batches = ceil(total_chunks / batch_size) + progress_per_batch = 0.2 / total_batches # Progress of each batch + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + # 2.1 Delete document vector index + vector_service.delete_by_metadata_field(key="document_id", value=str(document_id)) + # 2.2 Vectorize and import batch documents + for batch_start in range(0, total_chunks, batch_size): + batch_end = min(batch_start + batch_size, total_chunks) # prevent out-of-bounds + batch = res[batch_start: batch_end] # Retrieve the current batch + chunks = [] + + # Process the current batch + for idx_in_batch, item in enumerate(batch): + global_idx = batch_start + idx_in_batch # Calculate global index + metadata = { + "doc_id": uuid.uuid4().hex, + "file_id": str(db_document.file_id), + "file_name": db_document.file_name, + "file_created_at": int(db_document.created_at.timestamp() * 1000), + "document_id": str(db_document.id), + "knowledge_id": str(db_document.kb_id), + "sort_id": global_idx, + "status": 1, + } + if db_document.parser_config.get("auto_questions", 0): + topn = db_document.parser_config["auto_questions"] + cached = get_llm_cache(chat_model.model_name, item["content_with_weight"], "question", {"topn": topn}) + if not cached: + cached = question_proposal(chat_model, item["content_with_weight"], topn) + set_llm_cache(chat_model.model_name, item["content_with_weight"], cached, "question", {"topn": topn}) + chunks.append(DocumentChunk(page_content=f"question: {cached} answer: {item['content_with_weight']}", metadata=metadata)) + else: + chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata)) + + # Bulk segmented vector import + vector_service.add_chunks(chunks) + + # Update progress + db_document.progress += progress_per_batch + 
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Embedding progress ({db_document.progress}).\n" + db_document.progress_msg = progress_msg + db_document.process_duration = time.time() - start_time + db_document.run = 0 + db.commit() + db.refresh(db_document) + + # Vectorization and data entry completed + progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Indexing done.\n" + db_document.chunk_num = total_chunks + db_document.progress = 1.0 + db_document.process_duration = time.time() - start_time + progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Task done ({db_document.process_duration}s).\n" + db_document.progress_msg = progress_msg db_document.run = 0 db.commit() - result = f"parse document '{db_document.file_name}' failed." - return result - finally: - db.close() + result = f"parse document '{db_document.file_name}' processed successfully." + return result + except Exception as e: + if 'db_document' in locals(): + db_document.progress_msg += f"Failed to vectorize and import the parsed document:{str(e)}\n" + db_document.run = 0 + db.commit() + result = f"parse document '{db_document.file_name}' failed." + return result @celery_app.task(name="app.core.memory.agent.read_message", bind=True) @@ -362,75 +358,75 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]: from app.models.end_user_model import EndUser from app.models.app_model import App - db = next(get_db()) - try: - workspace_uuid = uuid.UUID(workspace_id) - - # 1. 查询当前workspace下的所有app - apps = db.query(App).filter(App.workspace_id == workspace_uuid).all() - - if not apps: - # 如果没有app,总量为0 + with get_db_context() as db: + try: + workspace_uuid = uuid.UUID(workspace_id) + + # 1. 
查询当前workspace下的所有app + apps = db.query(App).filter(App.workspace_id == workspace_uuid).all() + + if not apps: + # 如果没有app,总量为0 + memory_increment = write_memory_increment( + db=db, + workspace_id=workspace_uuid, + total_num=0 + ) + return { + "status": "SUCCESS", + "workspace_id": workspace_id, + "total_num": 0, + "end_user_count": 0, + "memory_increment_id": str(memory_increment.id), + "created_at": memory_increment.created_at.isoformat(), + } + + # 2. 查询所有app下的end_user_id(去重) + app_ids = [app.id for app in apps] + end_users = db.query(EndUser.id).filter( + EndUser.app_id.in_(app_ids) + ).distinct().all() + + # 3. 遍历所有end_user,查询每个宿主的记忆总量并累加 + total_num = 0 + end_user_details = [] + + for (end_user_id,) in end_users: + try: + # 调用 search_all 接口查询该宿主的总量 + result = await search_all(str(end_user_id)) + user_total = result.get("total", 0) + total_num += user_total + end_user_details.append({ + "end_user_id": str(end_user_id), + "total": user_total + }) + except Exception as e: + # 记录单个用户查询失败,但继续处理其他用户 + end_user_details.append({ + "end_user_id": str(end_user_id), + "total": 0, + "error": str(e) + }) + + # 4. 写入数据库 memory_increment = write_memory_increment( db=db, workspace_id=workspace_uuid, - total_num=0 + total_num=total_num ) + return { "status": "SUCCESS", "workspace_id": workspace_id, - "total_num": 0, - "end_user_count": 0, + "total_num": total_num, + "end_user_count": len(end_users), + "end_user_details": end_user_details, "memory_increment_id": str(memory_increment.id), "created_at": memory_increment.created_at.isoformat(), } - - # 2. 查询所有app下的end_user_id(去重) - app_ids = [app.id for app in apps] - end_users = db.query(EndUser.id).filter( - EndUser.app_id.in_(app_ids) - ).distinct().all() - - # 3. 
遍历所有end_user,查询每个宿主的记忆总量并累加 - total_num = 0 - end_user_details = [] - - for (end_user_id,) in end_users: - try: - # 调用 search_all 接口查询该宿主的总量 - result = await search_all(str(end_user_id)) - user_total = result.get("total", 0) - total_num += user_total - end_user_details.append({ - "end_user_id": str(end_user_id), - "total": user_total - }) - except Exception as e: - # 记录单个用户查询失败,但继续处理其他用户 - end_user_details.append({ - "end_user_id": str(end_user_id), - "total": 0, - "error": str(e) - }) - - # 4. 写入数据库 - memory_increment = write_memory_increment( - db=db, - workspace_id=workspace_uuid, - total_num=total_num - ) - - return { - "status": "SUCCESS", - "workspace_id": workspace_id, - "total_num": total_num, - "end_user_count": len(end_users), - "end_user_details": end_user_details, - "memory_increment_id": str(memory_increment.id), - "created_at": memory_increment.created_at.isoformat(), - } - finally: - db.close() + except Exception as e: + raise e try: result = asyncio.run(_run()) @@ -447,6 +443,198 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]: } +@celery_app.task(name="app.tasks.regenerate_memory_cache", bind=True) +def regenerate_memory_cache(self) -> Dict[str, Any]: + """定时任务:为所有用户重新生成记忆洞察和用户摘要缓存 + + 遍历所有活动工作空间的所有终端用户,为每个用户重新生成记忆洞察和用户摘要。 + 实现错误隔离,单个用户失败不影响其他用户的处理。 + + Returns: + 包含任务执行结果的字典,包括: + - status: 任务状态 (SUCCESS/FAILURE) + - message: 执行消息 + - workspace_count: 处理的工作空间数量 + - total_users: 总用户数 + - successful: 成功生成的用户数 + - failed: 失败的用户数 + - workspace_results: 每个工作空间的详细结果 + - elapsed_time: 执行耗时(秒) + - task_id: 任务ID + """ + start_time = time.time() + + async def _run() -> Dict[str, Any]: + from app.services.user_memory_service import UserMemoryService + from app.repositories.end_user_repository import EndUserRepository + from app.core.logging_config import get_logger + + logger = get_logger(__name__) + logger.info("开始执行记忆缓存重新生成定时任务") + + service = UserMemoryService() + + total_users = 0 + successful = 0 + failed = 0 + workspace_results = [] 
+ + with get_db_context() as db: + try: + # 获取所有活动工作空间 + repo = EndUserRepository(db) + workspaces = repo.get_all_active_workspaces() + logger.info(f"找到 {len(workspaces)} 个活动工作空间") + + # 遍历每个工作空间 + for workspace_id in workspaces: + logger.info(f"开始处理工作空间: {workspace_id}") + workspace_start_time = time.time() + + try: + # 获取工作空间的所有终端用户 + end_users = repo.get_all_by_workspace(workspace_id) + workspace_user_count = len(end_users) + total_users += workspace_user_count + + logger.info(f"工作空间 {workspace_id} 有 {workspace_user_count} 个终端用户") + + workspace_successful = 0 + workspace_failed = 0 + workspace_errors = [] + + # 遍历每个用户并生成缓存 + for end_user in end_users: + end_user_id = str(end_user.id) + + try: + # 生成记忆洞察 + insight_result = await service.generate_and_cache_insight(db, end_user_id) + + # 生成用户摘要 + summary_result = await service.generate_and_cache_summary(db, end_user_id) + + # 检查是否都成功 + if insight_result["success"] and summary_result["success"]: + workspace_successful += 1 + successful += 1 + logger.info(f"成功为终端用户 {end_user_id} 重新生成缓存") + else: + workspace_failed += 1 + failed += 1 + error_info = { + "end_user_id": end_user_id, + "insight_error": insight_result.get("error"), + "summary_error": summary_result.get("error") + } + workspace_errors.append(error_info) + logger.warning(f"终端用户 {end_user_id} 的缓存重新生成部分失败: {error_info}") + + except Exception as e: + # 单个用户失败不影响其他用户(错误隔离) + workspace_failed += 1 + failed += 1 + error_info = { + "end_user_id": end_user_id, + "error": str(e) + } + workspace_errors.append(error_info) + logger.error(f"为终端用户 {end_user_id} 重新生成缓存时出错: {str(e)}") + + workspace_elapsed = time.time() - workspace_start_time + + # 记录工作空间处理结果 + workspace_result = { + "workspace_id": str(workspace_id), + "total_users": workspace_user_count, + "successful": workspace_successful, + "failed": workspace_failed, + "errors": workspace_errors[:10], # 只保留前10个错误 + "elapsed_time": workspace_elapsed + } + workspace_results.append(workspace_result) + + logger.info( + 
f"工作空间 {workspace_id} 处理完成: " + f"总数={workspace_user_count}, 成功={workspace_successful}, " + f"失败={workspace_failed}, 耗时={workspace_elapsed:.2f}秒" + ) + + except Exception as e: + # 工作空间处理失败,记录错误并继续处理下一个 + logger.error(f"处理工作空间 {workspace_id} 时出错: {str(e)}") + workspace_results.append({ + "workspace_id": str(workspace_id), + "error": str(e), + "total_users": 0, + "successful": 0, + "failed": 0, + "errors": [] + }) + + # 记录总体统计信息 + logger.info( + f"记忆缓存重新生成定时任务完成: " + f"工作空间数={len(workspaces)}, 总用户数={total_users}, " + f"成功={successful}, 失败={failed}" + ) + + return { + "status": "SUCCESS", + "message": f"成功处理 {len(workspaces)} 个工作空间,总共 {successful}/{total_users} 个用户缓存重新生成成功", + "workspace_count": len(workspaces), + "total_users": total_users, + "successful": successful, + "failed": failed, + "workspace_results": workspace_results + } + + except Exception as e: + logger.error(f"记忆缓存重新生成定时任务执行失败: {str(e)}") + return { + "status": "FAILURE", + "error": str(e), + "workspace_count": len(workspace_results), + "total_users": total_users, + "successful": successful, + "failed": failed, + "workspace_results": workspace_results + } + + try: + # 使用 nest_asyncio 来避免事件循环冲突 + try: + import nest_asyncio + nest_asyncio.apply() + except ImportError: + pass + + # 尝试获取现有事件循环,如果不存在则创建新的 + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + result = loop.run_until_complete(_run()) + elapsed_time = time.time() - start_time + result["elapsed_time"] = elapsed_time + result["task_id"] = self.request.id + + return result + except Exception as e: + elapsed_time = time.time() - start_time + return { + "status": "FAILURE", + "error": str(e), + "elapsed_time": elapsed_time, + "task_id": self.request.id + } + + @celery_app.task(name="app.tasks.workspace_reflection_task", bind=True) def workspace_reflection_task(self) -> Dict[str, 
Any]: """定时任务:每30秒运行工作空间反思功能 @@ -462,100 +650,98 @@ def workspace_reflection_task(self) -> Dict[str, Any]: from app.core.logging_config import get_api_logger api_logger = get_api_logger() - db = next(get_db()) + + with get_db_context() as db: + try: + # 获取所有工作空间 + workspaces = db.query(Workspace).all() - try: - # 获取所有工作空间 - workspaces = db.query(Workspace).all() + if not workspaces: + return { + "status": "SUCCESS", + "message": "没有找到工作空间", + "workspace_count": 0, + "reflection_results": [] + } + + all_reflection_results = [] + + # 遍历每个工作空间 + for workspace in workspaces: + workspace_id = workspace.id + api_logger.info(f"开始处理工作空间反思,workspace_id: {workspace_id}") + + try: + reflection_service = MemoryReflectionService(db) + + # 使用服务类处理复杂查询逻辑 + service = WorkspaceAppService(db) + result = service.get_workspace_apps_detailed(str(workspace_id)) + + workspace_reflection_results = [] + + for data in result['apps_detailed_info']: + if data['data_configs'] == []: + continue + + releases = data['releases'] + data_configs = data['data_configs'] + end_users = data['end_users'] + + for base, config, user in zip(releases, data_configs, end_users): + if int(base['config']) == int(config['config_id']) and base['app_id'] == user['app_id']: + # 调用反思服务 + api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}") + + reflection_result = await reflection_service.start_reflection_from_data( + config_data=config, + end_user_id=user['id'] + ) + + workspace_reflection_results.append({ + "app_id": base['app_id'], + "config_id": config['config_id'], + "end_user_id": user['id'], + "reflection_result": reflection_result + }) + + all_reflection_results.append({ + "workspace_id": str(workspace_id), + "reflection_count": len(workspace_reflection_results), + "reflection_results": workspace_reflection_results + }) + + api_logger.info( + f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务") + + except Exception as e: + api_logger.error(f"处理工作空间 {workspace_id} 
反思失败: {str(e)}") + all_reflection_results.append({ + "workspace_id": str(workspace_id), + "error": str(e), + "reflection_count": 0, + "reflection_results": [] + }) + + total_reflections = sum(r.get("reflection_count", 0) for r in all_reflection_results) - if not workspaces: return { "status": "SUCCESS", - "message": "没有找到工作空间", + "message": f"成功处理 {len(workspaces)} 个工作空间,总共 {total_reflections} 个反思任务", + "workspace_count": len(workspaces), + "total_reflections": total_reflections, + "workspace_results": all_reflection_results + } + + except Exception as e: + api_logger.error(f"工作空间反思任务执行失败: {str(e)}") + return { + "status": "FAILURE", + "error": str(e), "workspace_count": 0, "reflection_results": [] } - all_reflection_results = [] - - # 遍历每个工作空间 - for workspace in workspaces: - workspace_id = workspace.id - api_logger.info(f"开始处理工作空间反思,workspace_id: {workspace_id}") - - try: - reflection_service = MemoryReflectionService(db) - - # 使用服务类处理复杂查询逻辑 - service = WorkspaceAppService(db) - result = service.get_workspace_apps_detailed(str(workspace_id)) - - workspace_reflection_results = [] - - for data in result['apps_detailed_info']: - if data['data_configs'] == []: - continue - - releases = data['releases'] - data_configs = data['data_configs'] - end_users = data['end_users'] - - for base, config, user in zip(releases, data_configs, end_users): - if int(base['config']) == int(config['config_id']) and base['app_id'] == user['app_id']: - # 调用反思服务 - api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}") - - reflection_result = await reflection_service.start_reflection_from_data( - config_data=config, - end_user_id=user['id'] - ) - - workspace_reflection_results.append({ - "app_id": base['app_id'], - "config_id": config['config_id'], - "end_user_id": user['id'], - "reflection_result": reflection_result - }) - - all_reflection_results.append({ - "workspace_id": str(workspace_id), - "reflection_count": len(workspace_reflection_results), - 
"reflection_results": workspace_reflection_results - }) - - api_logger.info( - f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务") - - except Exception as e: - api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}") - all_reflection_results.append({ - "workspace_id": str(workspace_id), - "error": str(e), - "reflection_count": 0, - "reflection_results": [] - }) - - total_reflections = sum(r.get("reflection_count", 0) for r in all_reflection_results) - - return { - "status": "SUCCESS", - "message": f"成功处理 {len(workspaces)} 个工作空间,总共 {total_reflections} 个反思任务", - "workspace_count": len(workspaces), - "total_reflections": total_reflections, - "workspace_results": all_reflection_results - } - - except Exception as e: - api_logger.error(f"工作空间反思任务执行失败: {str(e)}") - return { - "status": "FAILURE", - "error": str(e), - "workspace_count": 0, - "reflection_results": [] - } - finally: - db.close() - try: # 使用 nest_asyncio 来避免事件循环冲突 try: diff --git a/api/env.example b/api/env.example index c4e0c1eb..1354233d 100644 --- a/api/env.example +++ b/api/env.example @@ -30,6 +30,11 @@ RESULT_BACKEND= CELERY_BROKER= CELERY_BACKEND= +# Memory Cache Regeneration Configuration +# Interval in hours for regenerating memory insight and user summary cache +# Default: 24 hours +MEMORY_CACHE_REGENERATION_HOURS=24 + # ElasticSearch configuration ELASTICSEARCH_HOST= ELASTICSEARCH_PORT=