feat(cache): Add thread-safe Redis client and enable activity stats cache
- Add get_thread_safe_redis() function with thread-local storage and PID checking to prevent "Future attached to a different loop" errors in Celery thread and prefork pools
- Set health_check_interval=30 to prevent stale connection errors after fork
- Uncomment and enable the ActivityStatsCache module in cache/memory/__init__.py
- Uncomment the ActivityStatsCache implementation in activity_stats_cache.py and update it to use get_thread_safe_redis()
- Update interest_memory.py to use the thread-safe Redis client
- Update write_tools.py to use the thread-safe Redis client
- Remove redundant Chinese comments from aioRedis.py for cleaner code
- Ensure safe Redis operations across different execution contexts and Celery worker configurations
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
|
|
||||||
import redis.asyncio as redis
|
import redis.asyncio as redis
|
||||||
@@ -21,6 +23,41 @@ pool = ConnectionPool.from_url(
|
|||||||
)
|
)
|
||||||
aio_redis = redis.StrictRedis(connection_pool=pool)
|
aio_redis = redis.StrictRedis(connection_pool=pool)
|
||||||
|
|
||||||
|
_REDIS_URL = f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}"
|
||||||
|
|
||||||
|
# Thread-local storage for connection pools.
|
||||||
|
# Each thread (and each forked process) gets its own pool to avoid
|
||||||
|
# "Future attached to a different loop" errors in Celery --pool=threads
|
||||||
|
# and stale connections after fork in --pool=prefork.
|
||||||
|
_thread_local = threading.local()
|
||||||
|
|
||||||
|
|
||||||
|
def get_thread_safe_redis() -> redis.StrictRedis:
    """Return a Redis client backed by a pool private to the current context.

    The ``ConnectionPool`` is cached in thread-local storage, so every thread
    owns its own pool (Celery ``--pool=threads``).  The cached pool is
    discarded and rebuilt when it was created:

    - in a different process, detected by a PID mismatch after a fork
      (Celery ``--pool=prefork``), or
    - while a different asyncio event loop was running, e.g. a worker thread
      that calls ``asyncio.run()`` once per task.  Async connections are tied
      to the loop that opened them, so reusing them from a later loop is what
      produces "Future attached to a different loop" / "Event loop is closed"
      errors.

    The pool is created with ``health_check_interval=30`` and at most five
    connections.

    Returns:
        redis.StrictRedis: A new client object sharing the context-local pool.
    """
    current_pid = os.getpid()
    try:
        current_loop = asyncio.get_running_loop()
    except RuntimeError:
        # Called from synchronous code: no loop is running yet.
        current_loop = None

    needs_new_pool = (
        getattr(_thread_local, "pool", None) is None
        or getattr(_thread_local, "pid", None) != current_pid
        # Compare by identity against a strong reference rather than by id(),
        # because id() values can be reused once a closed loop is collected.
        or getattr(_thread_local, "loop", None) is not current_loop
    )

    if needs_new_pool:
        # NOTE(review): the previous pool is dropped without an explicit
        # disconnect, since its connections belong to another process or loop
        # and cannot be awaited here -- confirm this is acceptable for the
        # redis-py version in use.
        new_pool = ConnectionPool.from_url(
            _REDIS_URL,
            db=settings.REDIS_DB,
            password=settings.REDIS_PASSWORD,
            decode_responses=True,
            max_connections=5,
            health_check_interval=30,
        )
        # Record the context only after the pool was built successfully.
        _thread_local.pool = new_pool
        _thread_local.pid = current_pid
        _thread_local.loop = current_loop

    return redis.StrictRedis(connection_pool=_thread_local.pool)
|
||||||
|
|
||||||
|
|
||||||
async def get_redis_connection():
|
async def get_redis_connection():
|
||||||
"""获取Redis连接"""
|
"""获取Redis连接"""
|
||||||
@@ -44,10 +81,8 @@ async def aio_redis_set(key: str, val: str | dict, expire: int = None):
|
|||||||
val = json.dumps(val, ensure_ascii=False)
|
val = json.dumps(val, ensure_ascii=False)
|
||||||
|
|
||||||
if expire is not None:
|
if expire is not None:
|
||||||
# 设置带过期时间的键值
|
|
||||||
await aio_redis.set(key, val, ex=expire)
|
await aio_redis.set(key, val, ex=expire)
|
||||||
else:
|
else:
|
||||||
# 设置永久键值
|
|
||||||
await aio_redis.set(key, val)
|
await aio_redis.set(key, val)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Redis set错误: {str(e)}")
|
logger.error(f"Redis set错误: {str(e)}")
|
||||||
|
|||||||
4
api/app/cache/memory/__init__.py
vendored
4
api/app/cache/memory/__init__.py
vendored
@@ -4,9 +4,9 @@ Memory 缓存模块
|
|||||||
提供记忆系统相关的缓存功能
|
提供记忆系统相关的缓存功能
|
||||||
"""
|
"""
|
||||||
from .interest_memory import InterestMemoryCache
|
from .interest_memory import InterestMemoryCache
|
||||||
# from .activity_stats_cache import ActivityStatsCache
|
from .activity_stats_cache import ActivityStatsCache
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"InterestMemoryCache",
|
"InterestMemoryCache",
|
||||||
# "ActivityStatsCache",
|
"ActivityStatsCache",
|
||||||
]
|
]
|
||||||
|
|||||||
210
api/app/cache/memory/activity_stats_cache.py
vendored
210
api/app/cache/memory/activity_stats_cache.py
vendored
@@ -1,124 +1,124 @@
|
|||||||
# """
|
"""
|
||||||
# Recent Activity Stats Cache
|
Recent Activity Stats Cache
|
||||||
|
|
||||||
# 记忆提取活动统计缓存模块
|
记忆提取活动统计缓存模块
|
||||||
# 用于缓存每次记忆提取流程的统计数据,按 workspace_id 存储,24小时后释放
|
用于缓存每次记忆提取流程的统计数据,按 workspace_id 存储,24小时后释放
|
||||||
# 查询命令:cache:memory:activity_stats:by_workspace:7de31a97-40a6-4fc0-b8d3-15c89f523843
|
查询命令:cache:memory:activity_stats:by_workspace:7de31a97-40a6-4fc0-b8d3-15c89f523843
|
||||||
# """
|
"""
|
||||||
# import json
|
import json
|
||||||
# import logging
|
import logging
|
||||||
# from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
# from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
# from app.aioRedis import aio_redis
|
from app.aioRedis import get_thread_safe_redis
|
||||||
|
|
||||||
# logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# # 缓存过期时间:24小时
|
# 缓存过期时间:24小时
|
||||||
# ACTIVITY_STATS_CACHE_EXPIRE = 86400
|
ACTIVITY_STATS_CACHE_EXPIRE = 86400
|
||||||
|
|
||||||
|
|
||||||
# class ActivityStatsCache:
|
class ActivityStatsCache:
    """Redis-backed cache for memory-extraction activity statistics.

    Entries are stored per workspace under
    ``cache:memory:activity_stats:by_workspace:<workspace_id>``.
    """

    PREFIX = "cache:memory:activity_stats"

    @classmethod
    def _get_key(cls, workspace_id: str) -> str:
        """Build the Redis key for a workspace.

        Args:
            workspace_id: Workspace identifier.

        Returns:
            The full Redis key.
        """
        return "{}:by_workspace:{}".format(cls.PREFIX, workspace_id)

    @classmethod
    async def set_activity_stats(
        cls,
        workspace_id: str,
        stats: Dict[str, Any],
        expire: int = ACTIVITY_STATS_CACHE_EXPIRE,
    ) -> bool:
        """Store the activity statistics of a workspace.

        Args:
            workspace_id: Workspace identifier.
            stats: Statistics to cache, expected in the form::

                {
                    "chunk_count": int,
                    "statements_count": int,
                    "triplet_entities_count": int,
                    "triplet_relations_count": int,
                    "temporal_count": int,
                }

            expire: Time to live in seconds; defaults to
                ``ACTIVITY_STATS_CACHE_EXPIRE``.

        Returns:
            True when the value was written, False when any error occurred
            (the error is logged, not raised).
        """
        try:
            redis_key = cls._get_key(workspace_id)
            # Wrap the raw stats with metadata before serialising.
            envelope = dict(
                stats=stats,
                generated_at=datetime.now().isoformat(),
                workspace_id=workspace_id,
                cached=True,
            )
            serialized = json.dumps(envelope, ensure_ascii=False)
            client = get_thread_safe_redis()
            await client.set(redis_key, serialized, ex=expire)
            logger.info(f"设置活动统计缓存成功: {redis_key}, 过期时间: {expire}秒")
            return True
        except Exception as exc:
            logger.error(f"设置活动统计缓存失败: {exc}", exc_info=True)
            return False

    @classmethod
    async def get_activity_stats(
        cls,
        workspace_id: str,
    ) -> Optional[Dict[str, Any]]:
        """Fetch the cached activity statistics of a workspace.

        Args:
            workspace_id: Workspace identifier.

        Returns:
            The decoded payload, or None when the key is missing, the stored
            value is empty, or an error occurred (the error is logged).
        """
        try:
            redis_key = cls._get_key(workspace_id)
            client = get_thread_safe_redis()
            raw = await client.get(redis_key)
            if not raw:
                logger.info(f"活动统计缓存不存在或已过期: {redis_key}")
                return None
            decoded = json.loads(raw)
            logger.info(f"命中活动统计缓存: {redis_key}")
            return decoded
        except Exception as exc:
            logger.error(f"获取活动统计缓存失败: {exc}", exc_info=True)
            return None

    @classmethod
    async def delete_activity_stats(
        cls,
        workspace_id: str,
    ) -> bool:
        """Remove the cached activity statistics of a workspace.

        Args:
            workspace_id: Workspace identifier.

        Returns:
            True when at least one key was deleted, False when nothing was
            deleted or an error occurred (the error is logged).
        """
        try:
            redis_key = cls._get_key(workspace_id)
            client = get_thread_safe_redis()
            removed = await client.delete(redis_key)
            logger.info(f"删除活动统计缓存: {redis_key}, 结果: {removed}")
            return removed > 0
        except Exception as exc:
            logger.error(f"删除活动统计缓存失败: {exc}", exc_info=True)
            return False
|
||||||
|
|||||||
8
api/app/cache/memory/interest_memory.py
vendored
8
api/app/cache/memory/interest_memory.py
vendored
@@ -9,7 +9,7 @@ import logging
|
|||||||
from typing import Optional, List, Dict, Any
|
from typing import Optional, List, Dict, Any
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from app.aioRedis import aio_redis
|
from app.aioRedis import get_thread_safe_redis
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -62,7 +62,7 @@ class InterestMemoryCache:
|
|||||||
"cached": True,
|
"cached": True,
|
||||||
}
|
}
|
||||||
value = json.dumps(payload, ensure_ascii=False)
|
value = json.dumps(payload, ensure_ascii=False)
|
||||||
await aio_redis.set(key, value, ex=expire)
|
await get_thread_safe_redis().set(key, value, ex=expire)
|
||||||
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒")
|
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -86,7 +86,7 @@ class InterestMemoryCache:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
key = cls._get_key(end_user_id, language)
|
key = cls._get_key(end_user_id, language)
|
||||||
value = await aio_redis.get(key)
|
value = await get_thread_safe_redis().get(key)
|
||||||
if value:
|
if value:
|
||||||
payload = json.loads(value)
|
payload = json.loads(value)
|
||||||
logger.info(f"命中兴趣分布缓存: {key}")
|
logger.info(f"命中兴趣分布缓存: {key}")
|
||||||
@@ -114,7 +114,7 @@ class InterestMemoryCache:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
key = cls._get_key(end_user_id, language)
|
key = cls._get_key(end_user_id, language)
|
||||||
result = await aio_redis.delete(key)
|
result = await get_thread_safe_redis().delete(key)
|
||||||
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
|
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
|
||||||
return result > 0
|
return result > 0
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -260,24 +260,24 @@ async def write(
|
|||||||
with open(log_file, "a", encoding="utf-8") as f:
|
with open(log_file, "a", encoding="utf-8") as f:
|
||||||
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
|
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
|
||||||
|
|
||||||
# # 将提取统计写入 Redis,按 workspace_id 存储
|
# 将提取统计写入 Redis,按 workspace_id 存储
|
||||||
# try:
|
try:
|
||||||
# from app.cache.memory.activity_stats_cache import ActivityStatsCache
|
from app.cache.memory.activity_stats_cache import ActivityStatsCache
|
||||||
|
|
||||||
# stats_to_cache = {
|
stats_to_cache = {
|
||||||
# "chunk_count": len(all_chunk_nodes) if all_chunk_nodes else 0,
|
"chunk_count": len(all_chunk_nodes) if all_chunk_nodes else 0,
|
||||||
# "statements_count": len(all_statement_nodes) if all_statement_nodes else 0,
|
"statements_count": len(all_statement_nodes) if all_statement_nodes else 0,
|
||||||
# "triplet_entities_count": len(all_entity_nodes) if all_entity_nodes else 0,
|
"triplet_entities_count": len(all_entity_nodes) if all_entity_nodes else 0,
|
||||||
# "triplet_relations_count": len(all_entity_entity_edges) if all_entity_entity_edges else 0,
|
"triplet_relations_count": len(all_entity_entity_edges) if all_entity_entity_edges else 0,
|
||||||
# "temporal_count": 0,
|
"temporal_count": 0,
|
||||||
# }
|
}
|
||||||
# await ActivityStatsCache.set_activity_stats(
|
await ActivityStatsCache.set_activity_stats(
|
||||||
# workspace_id=str(memory_config.workspace_id),
|
workspace_id=str(memory_config.workspace_id),
|
||||||
# stats=stats_to_cache,
|
stats=stats_to_cache,
|
||||||
# )
|
)
|
||||||
# logger.info(f"[WRITE] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}")
|
logger.info(f"[WRITE] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}")
|
||||||
# except Exception as cache_err:
|
except Exception as cache_err:
|
||||||
# logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
||||||
|
|
||||||
# logger.info("=== Pipeline Complete ===")
|
logger.info("=== Pipeline Complete ===")
|
||||||
# logger.info(f"Total execution time: {total_time:.2f} seconds")
|
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||||
|
|||||||
Reference in New Issue
Block a user