Files
MemoryBear/api/app/aioRedis.py

209 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import json
import logging
from typing import Dict, Any, Optional
import redis.asyncio as redis
from redis.asyncio import ConnectionPool
from app.core.config import settings
# Module-level logger.
logger = logging.getLogger(__name__)

# Shared async connection pool. Both the module-wide client below and the
# clients returned by get_redis_connection() are built on it; responses are
# decoded to str (decode_responses=True) and the pool is capped via
# max_connections=30.
pool = ConnectionPool.from_url(
    "redis://{}:{}".format(settings.REDIS_HOST, settings.REDIS_PORT),
    db=settings.REDIS_DB,
    password=settings.REDIS_PASSWORD,
    decode_responses=True,
    max_connections=30,
)

# Module-wide client used by the aio_redis_set/get/delete helpers.
aio_redis = redis.StrictRedis(connection_pool=pool)
async def get_redis_connection():
    """Build a new Redis client that borrows connections from the shared ``pool``.

    Returns:
        The client, or None if constructing it raised (the error is logged).
    """
    try:
        client = redis.StrictRedis(connection_pool=pool)
    except Exception as exc:
        logger.error(f"Redis连接失败: {str(exc)}")
        return None
    return client
async def aio_redis_set(key: str, val: str | dict, expire: Optional[int] = None) -> bool:
    """Store ``val`` under ``key`` using the module-wide client.

    Args:
        key: Redis key.
        val: Value to store. A dict is serialised to JSON first
            (``ensure_ascii=False``, so non-ASCII text is kept verbatim).
        expire: Time-to-live in seconds; ``None`` means the key never expires.

    Returns:
        True if the SET command completed, False if serialisation or the
        command raised (the error is logged, never propagated).
    """
    try:
        if isinstance(val, dict):
            val = json.dumps(val, ensure_ascii=False)
        # redis-py only adds the EX option when ``ex is not None``, so one call
        # covers both the expiring and the persistent case.
        await aio_redis.set(key, val, ex=expire)
    except Exception as e:
        logger.error(f"Redis set错误: {str(e)}")
        return False
    return True
async def aio_redis_get(key: str):
    """Read ``key`` with the module-wide client.

    Returns:
        Whatever the client's GET returns (decoded to str by the pool's
        ``decode_responses=True``), or None if the call raised; the error is
        logged rather than propagated.
    """
    result = None
    try:
        result = await aio_redis.get(key)
    except Exception as exc:
        logger.error(f"Redis get错误: {str(exc)}")
    return result
async def aio_redis_delete(key: str):
    """Delete ``key`` with the module-wide client.

    Returns:
        The client's DEL result, or None if the call raised; the error is
        logged rather than propagated.
    """
    outcome = None
    try:
        outcome = await aio_redis.delete(key)
    except Exception as exc:
        logger.error(f"Redis delete错误: {str(exc)}")
    return outcome
async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
    """Publish ``message`` on ``channel`` as a JSON string.

    Args:
        channel: Redis pub/sub channel name.
        message: JSON-serialisable payload (``ensure_ascii=False``).

    Returns:
        True once PUBLISH has been issued; False if no client could be
        obtained or any step (including serialisation) raised. Errors are
        logged, never propagated.
    """
    try:
        client = await get_redis_connection()
        if client:
            payload = json.dumps(message, ensure_ascii=False)
            await client.publish(channel, payload)
            return True
        return False
    except Exception as exc:
        logger.error(f"Redis发布错误: {str(exc)}")
        return False
class RedisSubscriber:
    """Subscribe to one Redis channel and expose its JSON messages through an async queue.

    A background task, started lazily by ``get_message`` or explicitly by
    ``start``, reads from a dedicated pub/sub object. It JSON-decodes each
    string payload and puts the result on an internal ``asyncio.Queue``.
    When the task stops for any reason it enqueues ``None`` as an
    end-of-stream marker.
    """

    def __init__(self, channel: str):
        self.channel = channel
        self.conn = None  # client from get_redis_connection(), set by the receiver task
        self.pubsub = None  # PubSub object created from self.conn
        self.is_closed = False
        self._queue = asyncio.Queue()  # decoded messages; None marks end of stream
        self._task = None  # background receiver task, created by start()

    async def start(self):
        """Start the background receiver task (no-op if closed or already started)."""
        if self.is_closed or self._task:
            return
        self._task = asyncio.create_task(self._receive_messages())
        logger.info(f"开始订阅: {self.channel}")

    async def _receive_messages(self):
        """Subscribe, then forward decoded messages to the queue until closed.

        On an unexpected failure an ``{"type": "error", ...}`` dict is queued.
        In every case that enters this body, ``None`` is queued and resources
        are released on exit.
        """
        try:
            self.conn = await get_redis_connection()
            if self.conn is None:
                return
            # Keep a local handle: close() may detach self.pubsub concurrently.
            pubsub = self.conn.pubsub()
            self.pubsub = pubsub
            await pubsub.subscribe(self.channel)
            while not self.is_closed:
                try:
                    message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=0.01)
                    if not message:
                        # Idle: back off briefly instead of spinning.
                        await asyncio.sleep(0.01)
                        continue
                    data = message.get("data")
                    if isinstance(data, str):
                        try:
                            await self._queue.put(json.loads(data))
                        except json.JSONDecodeError:
                            logger.warning(f"消息解析失败: {message['data']}")
                    # Yield to the event loop between messages without
                    # throttling throughput with a fixed per-message sleep.
                    await asyncio.sleep(0)
                except Exception as e:
                    if "closed" in str(e).lower():
                        break
                    logger.warning(f"接收消息错误: {str(e)}")
                    await asyncio.sleep(0.1)
        except Exception as e:
            logger.error(f"订阅错误: {str(e)}")
            await self._queue.put({"type": "error", "data": {"message": str(e), "status": "error"}})
        finally:
            await self._queue.put(None)
            await self._cleanup()

    async def _cleanup(self):
        """Release the pub/sub object and client; safe to call more than once.

        References are detached before any await, so concurrent calls (from
        close() and from the receiver task's finally) never release the same
        object twice. Each release step is best-effort and independent, so a
        failed unsubscribe does not prevent closing.
        """
        pubsub, self.pubsub = self.pubsub, None
        conn, self.conn = self.conn, None
        if pubsub is not None:
            try:
                await pubsub.unsubscribe(self.channel)
            except Exception as e:
                logger.debug(f"unsubscribe failed for {self.channel}: {e}")
            try:
                await pubsub.close()
            except Exception as e:
                logger.debug(f"pubsub close failed for {self.channel}: {e}")
        if conn is not None:
            try:
                await conn.close()
            except Exception as e:
                logger.debug(f"connection close failed for {self.channel}: {e}")

    async def get_message(self) -> Optional[Dict[str, Any]]:
        """Return the next decoded message, starting the receiver if needed.

        Returns:
            The decoded message dict (or a queued error dict), or None when
            the subscriber is closed or the stream has ended.
        """
        if self.is_closed:
            return None
        if not self._task:
            await self.start()
        # If the receiver has exited and everything it queued (including its
        # terminating None) has been consumed, nothing will ever arrive again.
        # Waiting would block forever.
        if self._task is None or (self._task.done() and self._queue.empty()):
            return None
        try:
            return await self._queue.get()
        except Exception as e:
            logger.error(f"获取消息错误: {str(e)}")
            return None

    async def close(self):
        """Stop the subscriber and release its resources (idempotent)."""
        if self.is_closed:
            return
        self.is_closed = True
        # Wake any consumer blocked in get_message(). If the receiver task is
        # cancelled before it first runs, its body never executes, so its
        # finally clause never enqueues the end-of-stream marker.
        self._queue.put_nowait(None)
        if self._task:
            self._task.cancel()
        await self._cleanup()
class RedisPubSubManager:
    """Registry of RedisSubscriber instances keyed by channel, plus a publish shortcut."""

    def __init__(self):
        # channel name -> RedisSubscriber
        self.subscribers = {}
        # Strong references to fire-and-forget close() tasks. The event loop
        # keeps only weak references to tasks, so an unreferenced task can be
        # garbage-collected before it finishes.
        self._background_tasks = set()

    async def publish(self, channel: str, message: Dict[str, Any]) -> bool:
        """Publish ``message`` on ``channel``; returns aio_redis_publish's result."""
        return await aio_redis_publish(channel, message)

    def get_subscriber(self, channel: str) -> RedisSubscriber:
        """Return the open subscriber for ``channel``, replacing a closed or missing one."""
        existing = self.subscribers.get(channel)
        if existing is not None and not existing.is_closed:
            return existing
        subscriber = RedisSubscriber(channel)
        self.subscribers[channel] = subscriber
        return subscriber

    def _schedule_close(self, subscriber: RedisSubscriber) -> None:
        """Close ``subscriber`` in the background, keeping the task referenced until done.

        Must be called while an event loop is running (asyncio.create_task
        raises RuntimeError otherwise).
        """
        task = asyncio.create_task(subscriber.close())
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    def cancel_subscription(self, channel: str) -> bool:
        """Schedule closing of ``channel``'s subscriber and forget it.

        Returns:
            True if a subscriber was registered for ``channel``, else False.
        """
        subscriber = self.subscribers.get(channel)
        if subscriber is None:
            return False
        self._schedule_close(subscriber)
        del self.subscribers[channel]
        return True

    def cancel_all_subscriptions(self) -> int:
        """Schedule closing of every subscriber and clear the registry.

        Returns:
            The number of subscribers that were registered.
        """
        count = len(self.subscribers)
        for subscriber in self.subscribers.values():
            self._schedule_close(subscriber)
        self.subscribers.clear()
        return count
# Module-wide shared RedisPubSubManager instance.
pubsub_manager = RedisPubSubManager()