From 4e9b5736b151cd2947d1fe1a10f424315247fa2d Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Fri, 27 Mar 2026 15:35:47 +0800 Subject: [PATCH] feat(cache): Add thread-safe Redis client and enable activity stats cache - Add get_thread_safe_redis() function with thread-local storage and PID checking to prevent "Future attached to a different loop" errors in Celery thread and prefork pools - Implement health_check_interval=30 to prevent stale connection errors after fork - Uncomment and enable ActivityStatsCache module in cache/memory/__init__.py - Uncomment ActivityStatsCache implementation in activity_stats_cache.py and update to use get_thread_safe_redis() - Update interest_memory.py to use thread-safe Redis client - Update write_tools.py to use thread-safe Redis client - Remove redundant Chinese comments from aioRedis.py for cleaner code - Ensures safe Redis operations across different execution contexts and Celery worker configurations --- api/app/aioRedis.py | 39 +++- api/app/cache/memory/__init__.py | 4 +- api/app/cache/memory/activity_stats_cache.py | 210 +++++++++--------- api/app/cache/memory/interest_memory.py | 8 +- .../core/memory/agent/utils/write_tools.py | 38 ++-- 5 files changed, 167 insertions(+), 132 deletions(-) diff --git a/api/app/aioRedis.py b/api/app/aioRedis.py index aac2aa84..f79ef0e1 100644 --- a/api/app/aioRedis.py +++ b/api/app/aioRedis.py @@ -1,6 +1,8 @@ import asyncio import json import logging +import os +import threading from typing import Dict, Any, Optional import redis.asyncio as redis @@ -21,6 +23,41 @@ pool = ConnectionPool.from_url( ) aio_redis = redis.StrictRedis(connection_pool=pool) +_REDIS_URL = f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}" + +# Thread-local storage for connection pools. +# Each thread (and each forked process) gets its own pool to avoid +# "Future attached to a different loop" errors in Celery --pool=threads +# and stale connections after fork in --pool=prefork. +_thread_local = threading.local() + + +def get_thread_safe_redis() -> redis.StrictRedis: + """Get a Redis client safe for the current execution context. + + Uses thread-local storage with PID checking to ensure: + - Each thread gets its own ConnectionPool (Celery --pool=threads) + - Pools are recreated after fork (Celery --pool=prefork) + - health_check_interval prevents stale connection errors + + Returns: + redis.StrictRedis: A Redis client with a thread/process-local pool. + """ + current_pid = os.getpid() + + if not hasattr(_thread_local, "pool") or getattr(_thread_local, "pid", None) != current_pid: + _thread_local.pid = current_pid + _thread_local.pool = ConnectionPool.from_url( + _REDIS_URL, + db=settings.REDIS_DB, + password=settings.REDIS_PASSWORD, + decode_responses=True, + max_connections=5, + health_check_interval=30, + ) + + return redis.StrictRedis(connection_pool=_thread_local.pool) + async def get_redis_connection(): """获取Redis连接""" @@ -44,10 +81,8 @@ async def aio_redis_set(key: str, val: str | dict, expire: int = None): val = json.dumps(val, ensure_ascii=False) if expire is not None: - # 设置带过期时间的键值 await aio_redis.set(key, val, ex=expire) else: - # 设置永久键值 await aio_redis.set(key, val) except Exception as e: logger.error(f"Redis set错误: {str(e)}") diff --git a/api/app/cache/memory/__init__.py b/api/app/cache/memory/__init__.py index fa9ad1b1..551062ac 100644 --- a/api/app/cache/memory/__init__.py +++ b/api/app/cache/memory/__init__.py @@ -4,9 +4,9 @@ Memory 缓存模块 提供记忆系统相关的缓存功能 """ from .interest_memory import InterestMemoryCache -# from .activity_stats_cache import ActivityStatsCache +from .activity_stats_cache import ActivityStatsCache __all__ = [ "InterestMemoryCache", - # "ActivityStatsCache", + "ActivityStatsCache", ] diff --git a/api/app/cache/memory/activity_stats_cache.py b/api/app/cache/memory/activity_stats_cache.py index 35c702b1..e0008353 100644 --- a/api/app/cache/memory/activity_stats_cache.py +++ b/api/app/cache/memory/activity_stats_cache.py @@ -1,124 +1,124 @@ -# """ -# Recent Activity Stats Cache +""" +Recent Activity Stats Cache -# 记忆提取活动统计缓存模块 -# 用于缓存每次记忆提取流程的统计数据,按 workspace_id 存储,24小时后释放 -# 查询命令:cache:memory:activity_stats:by_workspace:7de31a97-40a6-4fc0-b8d3-15c89f523843 -# """ -# import json -# import logging -# from typing import Optional, Dict, Any -# from datetime import datetime +记忆提取活动统计缓存模块 +用于缓存每次记忆提取流程的统计数据,按 workspace_id 存储,24小时后释放 +查询命令:cache:memory:activity_stats:by_workspace:7de31a97-40a6-4fc0-b8d3-15c89f523843 +""" +import json +import logging +from typing import Optional, Dict, Any +from datetime import datetime -# from app.aioRedis import aio_redis +from app.aioRedis import get_thread_safe_redis -# logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) -# # 缓存过期时间:24小时 -# ACTIVITY_STATS_CACHE_EXPIRE = 86400 +# 缓存过期时间:24小时 +ACTIVITY_STATS_CACHE_EXPIRE = 86400 -# class ActivityStatsCache: -# """记忆提取活动统计缓存类""" +class ActivityStatsCache: + """记忆提取活动统计缓存类""" -# PREFIX = "cache:memory:activity_stats" + PREFIX = "cache:memory:activity_stats" -# @classmethod -# def _get_key(cls, workspace_id: str) -> str: -# """生成 Redis key + @classmethod + def _get_key(cls, workspace_id: str) -> str: + """生成 Redis key -# Args: -# workspace_id: 工作空间ID + Args: + workspace_id: 工作空间ID -# Returns: -# 完整的 Redis key -# """ -# return f"{cls.PREFIX}:by_workspace:{workspace_id}" + Returns: + 完整的 Redis key + """ + return f"{cls.PREFIX}:by_workspace:{workspace_id}" -# @classmethod -# async def set_activity_stats( -# cls, -# workspace_id: str, -# stats: Dict[str, Any], -# expire: int = ACTIVITY_STATS_CACHE_EXPIRE, -# ) -> bool: -# """设置记忆提取活动统计缓存 + @classmethod + async def set_activity_stats( + cls, + workspace_id: str, + stats: Dict[str, Any], + expire: int = ACTIVITY_STATS_CACHE_EXPIRE, + ) -> bool: + """设置记忆提取活动统计缓存 -# Args: -# workspace_id: 工作空间ID -# stats: 统计数据,格式: -# { -# "chunk_count": int, -# "statements_count": int, -# "triplet_entities_count": int, -# "triplet_relations_count": int, -# "temporal_count": int, -# } -# expire: 过期时间(秒),默认24小时 + Args: + workspace_id: 工作空间ID + stats: 统计数据,格式: + { + "chunk_count": int, + "statements_count": int, + "triplet_entities_count": int, + "triplet_relations_count": int, + "temporal_count": int, + } + expire: 过期时间(秒),默认24小时 -# Returns: -# 是否设置成功 -# """ -# try: -# key = cls._get_key(workspace_id) -# payload = { -# "stats": stats, -# "generated_at": datetime.now().isoformat(), -# "workspace_id": workspace_id, -# "cached": True, -# } -# value = json.dumps(payload, ensure_ascii=False) -# await aio_redis.set(key, value, ex=expire) -# logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒") -# return True -# except Exception as e: -# logger.error(f"设置活动统计缓存失败: {e}", exc_info=True) -# return False + Returns: + 是否设置成功 + """ + try: + key = cls._get_key(workspace_id) + payload = { + "stats": stats, + "generated_at": datetime.now().isoformat(), + "workspace_id": workspace_id, + "cached": True, + } + value = json.dumps(payload, ensure_ascii=False) + await get_thread_safe_redis().set(key, value, ex=expire) + logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒") + return True + except Exception as e: + logger.error(f"设置活动统计缓存失败: {e}", exc_info=True) + return False -# @classmethod -# async def get_activity_stats( -# cls, -# workspace_id: str, -# ) -> Optional[Dict[str, Any]]: -# """获取记忆提取活动统计缓存 + @classmethod + async def get_activity_stats( + cls, + workspace_id: str, + ) -> Optional[Dict[str, Any]]: + """获取记忆提取活动统计缓存 -# Args: -# workspace_id: 工作空间ID + Args: + workspace_id: 工作空间ID -# Returns: -# 统计数据字典,缓存不存在或已过期返回 None -# """ -# try: -# key = cls._get_key(workspace_id) -# value = await aio_redis.get(key) -# if value: -# payload = json.loads(value) -# logger.info(f"命中活动统计缓存: {key}") -# return payload -# logger.info(f"活动统计缓存不存在或已过期: {key}") -# return None -# except Exception as e: -# logger.error(f"获取活动统计缓存失败: {e}", exc_info=True) -# return None + Returns: + 统计数据字典,缓存不存在或已过期返回 None + """ + try: + key = cls._get_key(workspace_id) + value = await get_thread_safe_redis().get(key) + if value: + payload = json.loads(value) + logger.info(f"命中活动统计缓存: {key}") + return payload + logger.info(f"活动统计缓存不存在或已过期: {key}") + return None + except Exception as e: + logger.error(f"获取活动统计缓存失败: {e}", exc_info=True) + return None -# @classmethod -# async def delete_activity_stats( -# cls, -# workspace_id: str, -# ) -> bool: -# """删除记忆提取活动统计缓存 + @classmethod + async def delete_activity_stats( + cls, + workspace_id: str, + ) -> bool: + """删除记忆提取活动统计缓存 -# Args: -# workspace_id: 工作空间ID + Args: + workspace_id: 工作空间ID -# Returns: -# 是否删除成功 -# """ -# try: -# key = cls._get_key(workspace_id) -# result = await aio_redis.delete(key) -# logger.info(f"删除活动统计缓存: {key}, 结果: {result}") -# return result > 0 -# except Exception as e: -# logger.error(f"删除活动统计缓存失败: {e}", exc_info=True) -# return False + Returns: + 是否删除成功 + """ + try: + key = cls._get_key(workspace_id) + result = await get_thread_safe_redis().delete(key) + logger.info(f"删除活动统计缓存: {key}, 结果: {result}") + return result > 0 + except Exception as e: + logger.error(f"删除活动统计缓存失败: {e}", exc_info=True) + return False diff --git a/api/app/cache/memory/interest_memory.py b/api/app/cache/memory/interest_memory.py index 108e2a37..2881f06c 100644 --- a/api/app/cache/memory/interest_memory.py +++ b/api/app/cache/memory/interest_memory.py @@ -9,7 +9,7 @@ import logging from typing import Optional, List, Dict, Any from datetime import datetime -from app.aioRedis import aio_redis +from app.aioRedis import get_thread_safe_redis logger = logging.getLogger(__name__) @@ -62,7 +62,7 @@ class InterestMemoryCache: "cached": True, } value = json.dumps(payload, ensure_ascii=False) - await aio_redis.set(key, value, ex=expire) + await get_thread_safe_redis().set(key, value, ex=expire) logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒") return True except Exception as e: @@ -86,7 +86,7 @@ class InterestMemoryCache: """ try: key = cls._get_key(end_user_id, language) - value = await aio_redis.get(key) + value = await get_thread_safe_redis().get(key) if value: payload = json.loads(value) logger.info(f"命中兴趣分布缓存: {key}") @@ -114,7 +114,7 @@ class InterestMemoryCache: """ try: key = cls._get_key(end_user_id, language) - result = await aio_redis.delete(key) + result = await get_thread_safe_redis().delete(key) logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}") return result > 0 except Exception as e: diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index c01a36d1..55bcb8ba 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -260,24 +260,24 @@ async def write( with open(log_file, "a", encoding="utf-8") as f: f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n") - # # 将提取统计写入 Redis,按 workspace_id 存储 - # try: - # from app.cache.memory.activity_stats_cache import ActivityStatsCache + # 将提取统计写入 Redis,按 workspace_id 存储 + try: + from app.cache.memory.activity_stats_cache import ActivityStatsCache - # stats_to_cache = { - # "chunk_count": len(all_chunk_nodes) if all_chunk_nodes else 0, - # "statements_count": len(all_statement_nodes) if all_statement_nodes else 0, - # "triplet_entities_count": len(all_entity_nodes) if all_entity_nodes else 0, - # "triplet_relations_count": len(all_entity_entity_edges) if all_entity_entity_edges else 0, - # "temporal_count": 0, - # } - # await ActivityStatsCache.set_activity_stats( - # workspace_id=str(memory_config.workspace_id), - # stats=stats_to_cache, - # ) - # logger.info(f"[WRITE] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}") - # except Exception as cache_err: - # logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) + stats_to_cache = { + "chunk_count": len(all_chunk_nodes) if all_chunk_nodes else 0, + "statements_count": len(all_statement_nodes) if all_statement_nodes else 0, + "triplet_entities_count": len(all_entity_nodes) if all_entity_nodes else 0, + "triplet_relations_count": len(all_entity_entity_edges) if all_entity_entity_edges else 0, + "temporal_count": 0, + } + await ActivityStatsCache.set_activity_stats( + workspace_id=str(memory_config.workspace_id), + stats=stats_to_cache, + ) + logger.info(f"[WRITE] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}") + except Exception as cache_err: + logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) - # logger.info("=== Pipeline Complete ===") - # logger.info(f"Total execution time: {total_time:.2f} seconds") + logger.info("=== Pipeline Complete ===") + logger.info(f"Total execution time: {total_time:.2f} seconds")