309 lines
7.9 KiB
Python
309 lines
7.9 KiB
Python
"""
|
||
Redis 数据库连接管理模块
|
||
|
||
提供缓存和任务队列功能
|
||
"""
|
||
import json
|
||
import logging
|
||
from datetime import timedelta
|
||
from typing import Any, Dict, Optional
|
||
|
||
import redis.asyncio as redis
|
||
|
||
from app.config import settings
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class RedisDB:
|
||
"""Redis 数据库管理类"""
|
||
|
||
def __init__(self):
|
||
self.client: Optional[redis.Redis] = None
|
||
self._connected = False
|
||
|
||
async def connect(self):
|
||
"""建立 Redis 连接"""
|
||
try:
|
||
self.client = redis.from_url(
|
||
settings.REDIS_URL,
|
||
encoding="utf-8",
|
||
decode_responses=True,
|
||
)
|
||
# 验证连接
|
||
await self.client.ping()
|
||
self._connected = True
|
||
logger.info(f"Redis 连接成功: {settings.REDIS_URL}")
|
||
except Exception as e:
|
||
logger.error(f"Redis 连接失败: {e}")
|
||
raise
|
||
|
||
async def close(self):
|
||
"""关闭 Redis 连接"""
|
||
if self.client:
|
||
await self.client.close()
|
||
self._connected = False
|
||
logger.info("Redis 连接已关闭")
|
||
|
||
@property
|
||
def is_connected(self) -> bool:
|
||
"""检查连接状态"""
|
||
return self._connected
|
||
|
||
# ==================== 基础操作 ====================
|
||
|
||
async def get(self, key: str) -> Optional[str]:
|
||
"""获取值"""
|
||
return await self.client.get(key)
|
||
|
||
async def set(
|
||
self,
|
||
key: str,
|
||
value: str,
|
||
expire: Optional[int] = None,
|
||
) -> bool:
|
||
"""
|
||
设置值
|
||
|
||
Args:
|
||
key: 键
|
||
value: 值
|
||
expire: 过期时间(秒)
|
||
|
||
Returns:
|
||
是否成功
|
||
"""
|
||
return await self.client.set(key, value, ex=expire)
|
||
|
||
async def delete(self, key: str) -> int:
|
||
"""删除键"""
|
||
return await self.client.delete(key)
|
||
|
||
async def exists(self, key: str) -> bool:
|
||
"""检查键是否存在"""
|
||
return await self.client.exists(key) > 0
|
||
|
||
# ==================== JSON 操作 ====================
|
||
|
||
async def set_json(
|
||
self,
|
||
key: str,
|
||
data: Dict[str, Any],
|
||
expire: Optional[int] = None,
|
||
) -> bool:
|
||
"""
|
||
设置 JSON 数据
|
||
|
||
Args:
|
||
key: 键
|
||
data: 数据字典
|
||
expire: 过期时间(秒)
|
||
|
||
Returns:
|
||
是否成功
|
||
"""
|
||
json_str = json.dumps(data, ensure_ascii=False, default=str)
|
||
return await self.set(key, json_str, expire)
|
||
|
||
async def get_json(self, key: str) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
获取 JSON 数据
|
||
|
||
Args:
|
||
key: 键
|
||
|
||
Returns:
|
||
数据字典,不存在返回 None
|
||
"""
|
||
value = await self.get(key)
|
||
if value:
|
||
try:
|
||
return json.loads(value)
|
||
except json.JSONDecodeError:
|
||
return None
|
||
return None
|
||
|
||
# ==================== 任务状态管理 ====================
|
||
|
||
async def set_task_status(
|
||
self,
|
||
task_id: str,
|
||
status: str,
|
||
meta: Optional[Dict[str, Any]] = None,
|
||
expire: int = 86400, # 默认24小时过期
|
||
) -> bool:
|
||
"""
|
||
设置任务状态
|
||
|
||
Args:
|
||
task_id: 任务ID
|
||
status: 状态 (pending/processing/success/failure)
|
||
meta: 附加信息
|
||
expire: 过期时间(秒)
|
||
|
||
Returns:
|
||
是否成功
|
||
"""
|
||
if not self._connected or not self.client:
|
||
logger.warning(f"Redis未连接,跳过任务状态更新: {task_id}")
|
||
return False
|
||
try:
|
||
key = f"task:{task_id}"
|
||
data = {
|
||
"status": status,
|
||
"meta": meta or {},
|
||
}
|
||
return await self.set_json(key, data, expire)
|
||
except Exception as e:
|
||
logger.warning(f"设置任务状态失败: {task_id}, error: {e}")
|
||
return False
|
||
|
||
async def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
获取任务状态
|
||
|
||
Args:
|
||
task_id: 任务ID
|
||
|
||
Returns:
|
||
状态信息
|
||
"""
|
||
if not self._connected or not self.client:
|
||
logger.warning(f"Redis未连接,无法获取任务状态: {task_id}")
|
||
return None
|
||
try:
|
||
key = f"task:{task_id}"
|
||
return await self.get_json(key)
|
||
except Exception as e:
|
||
logger.warning(f"获取任务状态失败: {task_id}, error: {e}")
|
||
return None
|
||
|
||
async def update_task_progress(
|
||
self,
|
||
task_id: str,
|
||
progress: int,
|
||
message: Optional[str] = None,
|
||
) -> bool:
|
||
"""
|
||
更新任务进度
|
||
|
||
Args:
|
||
task_id: 任务ID
|
||
progress: 进度值 (0-100)
|
||
message: 进度消息
|
||
|
||
Returns:
|
||
是否成功
|
||
"""
|
||
if not self._connected or not self.client:
|
||
logger.warning(f"Redis未连接,跳过任务进度更新: {task_id}")
|
||
return False
|
||
try:
|
||
data = await self.get_task_status(task_id)
|
||
if data:
|
||
data["meta"]["progress"] = progress
|
||
if message:
|
||
data["meta"]["message"] = message
|
||
key = f"task:{task_id}"
|
||
return await self.set_json(key, data, expire=86400)
|
||
return False
|
||
except Exception as e:
|
||
logger.warning(f"更新任务进度失败: {task_id}, error: {e}")
|
||
return False
|
||
|
||
# ==================== 缓存操作 ====================
|
||
|
||
async def cache_document(
|
||
self,
|
||
doc_id: str,
|
||
data: Dict[str, Any],
|
||
expire: int = 3600, # 默认1小时
|
||
) -> bool:
|
||
"""
|
||
缓存文档数据
|
||
|
||
Args:
|
||
doc_id: 文档ID
|
||
data: 文档数据
|
||
expire: 过期时间(秒)
|
||
|
||
Returns:
|
||
是否成功
|
||
"""
|
||
key = f"doc:{doc_id}"
|
||
return await self.set_json(key, data, expire)
|
||
|
||
async def get_cached_document(self, doc_id: str) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
获取缓存的文档
|
||
|
||
Args:
|
||
doc_id: 文档ID
|
||
|
||
Returns:
|
||
文档数据
|
||
"""
|
||
key = f"doc:{doc_id}"
|
||
return await self.get_json(key)
|
||
|
||
# ==================== 分布式锁 ====================
|
||
|
||
async def acquire_lock(
|
||
self,
|
||
lock_name: str,
|
||
expire: int = 30,
|
||
) -> bool:
|
||
"""
|
||
获取分布式锁
|
||
|
||
Args:
|
||
lock_name: 锁名称
|
||
expire: 过期时间(秒)
|
||
|
||
Returns:
|
||
是否获取成功
|
||
"""
|
||
key = f"lock:{lock_name}"
|
||
# 使用 SET NX EX 原子操作
|
||
result = await self.client.set(key, "1", nx=True, ex=expire)
|
||
return result is not None
|
||
|
||
async def release_lock(self, lock_name: str) -> bool:
|
||
"""
|
||
释放分布式锁
|
||
|
||
Args:
|
||
lock_name: 锁名称
|
||
|
||
Returns:
|
||
是否释放成功
|
||
"""
|
||
key = f"lock:{lock_name}"
|
||
result = await self.client.delete(key)
|
||
return result > 0
|
||
|
||
# ==================== 计数器 ====================
|
||
|
||
async def incr(self, key: str, amount: int = 1) -> int:
|
||
"""递增计数器"""
|
||
return await self.client.incrby(key, amount)
|
||
|
||
async def decr(self, key: str, amount: int = 1) -> int:
|
||
"""递减计数器"""
|
||
return await self.client.decrby(key, amount)
|
||
|
||
# ==================== 过期时间管理 ====================
|
||
|
||
async def expire(self, key: str, seconds: int) -> bool:
|
||
"""设置键的过期时间"""
|
||
return await self.client.expire(key, seconds)
|
||
|
||
async def ttl(self, key: str) -> int:
|
||
"""获取键的剩余生存时间"""
|
||
return await self.client.ttl(key)
|
||
|
||
|
||
# ==================== 全局单例 ====================
|
||
|
||
redis_db = RedisDB()
|