feat(cache): Add thread-safe Redis client and enable activity stats cache

- Add get_thread_safe_redis() function with thread-local storage and PID checking to prevent "Future attached to a different loop" errors in Celery thread and prefork pools
- Implement health_check_interval=30 to prevent stale connection errors after fork
- Uncomment and enable ActivityStatsCache module in cache/memory/__init__.py
- Uncomment ActivityStatsCache implementation in activity_stats_cache.py and update to use get_thread_safe_redis()
- Update interest_memory.py to use thread-safe Redis client
- Update write_tools.py to use thread-safe Redis client
- Remove redundant Chinese comments from aioRedis.py for cleaner code
- Ensures safe Redis operations across different execution contexts and Celery worker configurations
This commit is contained in:
Ke Sun
2026-03-27 15:35:47 +08:00
committed by lanceyq
parent 17ea92357d
commit 4e9b5736b1
5 changed files with 167 additions and 132 deletions

View File

@@ -1,6 +1,8 @@
import asyncio import asyncio
import json import json
import logging import logging
import os
import threading
from typing import Dict, Any, Optional from typing import Dict, Any, Optional
import redis.asyncio as redis import redis.asyncio as redis
@@ -21,6 +23,41 @@ pool = ConnectionPool.from_url(
) )
aio_redis = redis.StrictRedis(connection_pool=pool) aio_redis = redis.StrictRedis(connection_pool=pool)
_REDIS_URL = f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}"
# Thread-local storage for connection pools.
# Each thread (and each forked process) gets its own pool to avoid
# "Future attached to a different loop" errors in Celery --pool=threads
# and stale connections after fork in --pool=prefork.
_thread_local = threading.local()
def get_thread_safe_redis() -> redis.StrictRedis:
"""Get a Redis client safe for the current execution context.
Uses thread-local storage with PID checking to ensure:
- Each thread gets its own ConnectionPool (Celery --pool=threads)
- Pools are recreated after fork (Celery --pool=prefork)
- health_check_interval prevents stale connection errors
Returns:
redis.StrictRedis: A Redis client with a thread/process-local pool.
"""
current_pid = os.getpid()
if not hasattr(_thread_local, "pool") or getattr(_thread_local, "pid", None) != current_pid:
_thread_local.pid = current_pid
_thread_local.pool = ConnectionPool.from_url(
_REDIS_URL,
db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD,
decode_responses=True,
max_connections=5,
health_check_interval=30,
)
return redis.StrictRedis(connection_pool=_thread_local.pool)
async def get_redis_connection(): async def get_redis_connection():
"""获取Redis连接""" """获取Redis连接"""
@@ -44,10 +81,8 @@ async def aio_redis_set(key: str, val: str | dict, expire: int = None):
val = json.dumps(val, ensure_ascii=False) val = json.dumps(val, ensure_ascii=False)
if expire is not None: if expire is not None:
# 设置带过期时间的键值
await aio_redis.set(key, val, ex=expire) await aio_redis.set(key, val, ex=expire)
else: else:
# 设置永久键值
await aio_redis.set(key, val) await aio_redis.set(key, val)
except Exception as e: except Exception as e:
logger.error(f"Redis set错误: {str(e)}") logger.error(f"Redis set错误: {str(e)}")

View File

@@ -4,9 +4,9 @@ Memory 缓存模块
提供记忆系统相关的缓存功能 提供记忆系统相关的缓存功能
""" """
from .interest_memory import InterestMemoryCache from .interest_memory import InterestMemoryCache
# from .activity_stats_cache import ActivityStatsCache from .activity_stats_cache import ActivityStatsCache
__all__ = [ __all__ = [
"InterestMemoryCache", "InterestMemoryCache",
# "ActivityStatsCache", "ActivityStatsCache",
] ]

View File

@@ -1,124 +1,124 @@
# """ """
# Recent Activity Stats Cache Recent Activity Stats Cache
# 记忆提取活动统计缓存模块 记忆提取活动统计缓存模块
# 用于缓存每次记忆提取流程的统计数据,按 workspace_id 存储24小时后释放 用于缓存每次记忆提取流程的统计数据,按 workspace_id 存储24小时后释放
# 查询命令cache:memory:activity_stats:by_workspace:7de31a97-40a6-4fc0-b8d3-15c89f523843 查询命令cache:memory:activity_stats:by_workspace:7de31a97-40a6-4fc0-b8d3-15c89f523843
# """ """
# import json import json
# import logging import logging
# from typing import Optional, Dict, Any from typing import Optional, Dict, Any
# from datetime import datetime from datetime import datetime
# from app.aioRedis import aio_redis from app.aioRedis import get_thread_safe_redis
# logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# # 缓存过期时间24小时 # 缓存过期时间24小时
# ACTIVITY_STATS_CACHE_EXPIRE = 86400 ACTIVITY_STATS_CACHE_EXPIRE = 86400
# class ActivityStatsCache: class ActivityStatsCache:
# """记忆提取活动统计缓存类""" """记忆提取活动统计缓存类"""
# PREFIX = "cache:memory:activity_stats" PREFIX = "cache:memory:activity_stats"
# @classmethod @classmethod
# def _get_key(cls, workspace_id: str) -> str: def _get_key(cls, workspace_id: str) -> str:
# """生成 Redis key """生成 Redis key
# Args: Args:
# workspace_id: 工作空间ID workspace_id: 工作空间ID
# Returns: Returns:
# 完整的 Redis key 完整的 Redis key
# """ """
# return f"{cls.PREFIX}:by_workspace:{workspace_id}" return f"{cls.PREFIX}:by_workspace:{workspace_id}"
# @classmethod @classmethod
# async def set_activity_stats( async def set_activity_stats(
# cls, cls,
# workspace_id: str, workspace_id: str,
# stats: Dict[str, Any], stats: Dict[str, Any],
# expire: int = ACTIVITY_STATS_CACHE_EXPIRE, expire: int = ACTIVITY_STATS_CACHE_EXPIRE,
# ) -> bool: ) -> bool:
# """设置记忆提取活动统计缓存 """设置记忆提取活动统计缓存
# Args: Args:
# workspace_id: 工作空间ID workspace_id: 工作空间ID
# stats: 统计数据,格式: stats: 统计数据,格式:
# { {
# "chunk_count": int, "chunk_count": int,
# "statements_count": int, "statements_count": int,
# "triplet_entities_count": int, "triplet_entities_count": int,
# "triplet_relations_count": int, "triplet_relations_count": int,
# "temporal_count": int, "temporal_count": int,
# } }
# expire: 过期时间默认24小时 expire: 过期时间默认24小时
# Returns: Returns:
# 是否设置成功 是否设置成功
# """ """
# try: try:
# key = cls._get_key(workspace_id) key = cls._get_key(workspace_id)
# payload = { payload = {
# "stats": stats, "stats": stats,
# "generated_at": datetime.now().isoformat(), "generated_at": datetime.now().isoformat(),
# "workspace_id": workspace_id, "workspace_id": workspace_id,
# "cached": True, "cached": True,
# } }
# value = json.dumps(payload, ensure_ascii=False) value = json.dumps(payload, ensure_ascii=False)
# await aio_redis.set(key, value, ex=expire) await get_thread_safe_redis().set(key, value, ex=expire)
# logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒") logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}")
# return True return True
# except Exception as e: except Exception as e:
# logger.error(f"设置活动统计缓存失败: {e}", exc_info=True) logger.error(f"设置活动统计缓存失败: {e}", exc_info=True)
# return False return False
# @classmethod @classmethod
# async def get_activity_stats( async def get_activity_stats(
# cls, cls,
# workspace_id: str, workspace_id: str,
# ) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
# """获取记忆提取活动统计缓存 """获取记忆提取活动统计缓存
# Args: Args:
# workspace_id: 工作空间ID workspace_id: 工作空间ID
# Returns: Returns:
# 统计数据字典,缓存不存在或已过期返回 None 统计数据字典,缓存不存在或已过期返回 None
# """ """
# try: try:
# key = cls._get_key(workspace_id) key = cls._get_key(workspace_id)
# value = await aio_redis.get(key) value = await get_thread_safe_redis().get(key)
# if value: if value:
# payload = json.loads(value) payload = json.loads(value)
# logger.info(f"命中活动统计缓存: {key}") logger.info(f"命中活动统计缓存: {key}")
# return payload return payload
# logger.info(f"活动统计缓存不存在或已过期: {key}") logger.info(f"活动统计缓存不存在或已过期: {key}")
# return None return None
# except Exception as e: except Exception as e:
# logger.error(f"获取活动统计缓存失败: {e}", exc_info=True) logger.error(f"获取活动统计缓存失败: {e}", exc_info=True)
# return None return None
# @classmethod @classmethod
# async def delete_activity_stats( async def delete_activity_stats(
# cls, cls,
# workspace_id: str, workspace_id: str,
# ) -> bool: ) -> bool:
# """删除记忆提取活动统计缓存 """删除记忆提取活动统计缓存
# Args: Args:
# workspace_id: 工作空间ID workspace_id: 工作空间ID
# Returns: Returns:
# 是否删除成功 是否删除成功
# """ """
# try: try:
# key = cls._get_key(workspace_id) key = cls._get_key(workspace_id)
# result = await aio_redis.delete(key) result = await get_thread_safe_redis().delete(key)
# logger.info(f"删除活动统计缓存: {key}, 结果: {result}") logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
# return result > 0 return result > 0
# except Exception as e: except Exception as e:
# logger.error(f"删除活动统计缓存失败: {e}", exc_info=True) logger.error(f"删除活动统计缓存失败: {e}", exc_info=True)
# return False return False

View File

@@ -9,7 +9,7 @@ import logging
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Any
from datetime import datetime from datetime import datetime
from app.aioRedis import aio_redis from app.aioRedis import get_thread_safe_redis
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -62,7 +62,7 @@ class InterestMemoryCache:
"cached": True, "cached": True,
} }
value = json.dumps(payload, ensure_ascii=False) value = json.dumps(payload, ensure_ascii=False)
await aio_redis.set(key, value, ex=expire) await get_thread_safe_redis().set(key, value, ex=expire)
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}") logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}")
return True return True
except Exception as e: except Exception as e:
@@ -86,7 +86,7 @@ class InterestMemoryCache:
""" """
try: try:
key = cls._get_key(end_user_id, language) key = cls._get_key(end_user_id, language)
value = await aio_redis.get(key) value = await get_thread_safe_redis().get(key)
if value: if value:
payload = json.loads(value) payload = json.loads(value)
logger.info(f"命中兴趣分布缓存: {key}") logger.info(f"命中兴趣分布缓存: {key}")
@@ -114,7 +114,7 @@ class InterestMemoryCache:
""" """
try: try:
key = cls._get_key(end_user_id, language) key = cls._get_key(end_user_id, language)
result = await aio_redis.delete(key) result = await get_thread_safe_redis().delete(key)
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}") logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
return result > 0 return result > 0
except Exception as e: except Exception as e:

View File

@@ -260,24 +260,24 @@ async def write(
with open(log_file, "a", encoding="utf-8") as f: with open(log_file, "a", encoding="utf-8") as f:
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n") f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
# # 将提取统计写入 Redis按 workspace_id 存储 # 将提取统计写入 Redis按 workspace_id 存储
# try: try:
# from app.cache.memory.activity_stats_cache import ActivityStatsCache from app.cache.memory.activity_stats_cache import ActivityStatsCache
# stats_to_cache = { stats_to_cache = {
# "chunk_count": len(all_chunk_nodes) if all_chunk_nodes else 0, "chunk_count": len(all_chunk_nodes) if all_chunk_nodes else 0,
# "statements_count": len(all_statement_nodes) if all_statement_nodes else 0, "statements_count": len(all_statement_nodes) if all_statement_nodes else 0,
# "triplet_entities_count": len(all_entity_nodes) if all_entity_nodes else 0, "triplet_entities_count": len(all_entity_nodes) if all_entity_nodes else 0,
# "triplet_relations_count": len(all_entity_entity_edges) if all_entity_entity_edges else 0, "triplet_relations_count": len(all_entity_entity_edges) if all_entity_entity_edges else 0,
# "temporal_count": 0, "temporal_count": 0,
# } }
# await ActivityStatsCache.set_activity_stats( await ActivityStatsCache.set_activity_stats(
# workspace_id=str(memory_config.workspace_id), workspace_id=str(memory_config.workspace_id),
# stats=stats_to_cache, stats=stats_to_cache,
# ) )
# logger.info(f"[WRITE] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}") logger.info(f"[WRITE] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}")
# except Exception as cache_err: except Exception as cache_err:
# logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
# logger.info("=== Pipeline Complete ===") logger.info("=== Pipeline Complete ===")
# logger.info(f"Total execution time: {total_time:.2f} seconds") logger.info(f"Total execution time: {total_time:.2f} seconds")