288 lines
7.0 KiB
Python
288 lines
7.0 KiB
Python
"""
|
|
Redis 数据库连接管理模块
|
|
|
|
提供缓存和任务队列功能
|
|
"""
|
|
import json
|
|
import logging
|
|
from datetime import timedelta
|
|
from typing import Any, Optional
|
|
|
|
import redis.asyncio as redis
|
|
|
|
from app.config import settings
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RedisDB:
|
|
"""Redis 数据库管理类"""
|
|
|
|
def __init__(self):
|
|
self.client: Optional[redis.Redis] = None
|
|
self._connected = False
|
|
|
|
async def connect(self):
|
|
"""建立 Redis 连接"""
|
|
try:
|
|
self.client = redis.from_url(
|
|
settings.REDIS_URL,
|
|
encoding="utf-8",
|
|
decode_responses=True,
|
|
)
|
|
# 验证连接
|
|
await self.client.ping()
|
|
self._connected = True
|
|
logger.info(f"Redis 连接成功: {settings.REDIS_URL}")
|
|
except Exception as e:
|
|
logger.error(f"Redis 连接失败: {e}")
|
|
raise
|
|
|
|
async def close(self):
|
|
"""关闭 Redis 连接"""
|
|
if self.client:
|
|
await self.client.close()
|
|
self._connected = False
|
|
logger.info("Redis 连接已关闭")
|
|
|
|
@property
|
|
def is_connected(self) -> bool:
|
|
"""检查连接状态"""
|
|
return self._connected
|
|
|
|
# ==================== 基础操作 ====================
|
|
|
|
async def get(self, key: str) -> Optional[str]:
|
|
"""获取值"""
|
|
return await self.client.get(key)
|
|
|
|
async def set(
|
|
self,
|
|
key: str,
|
|
value: str,
|
|
expire: Optional[int] = None,
|
|
) -> bool:
|
|
"""
|
|
设置值
|
|
|
|
Args:
|
|
key: 键
|
|
value: 值
|
|
expire: 过期时间(秒)
|
|
|
|
Returns:
|
|
是否成功
|
|
"""
|
|
return await self.client.set(key, value, ex=expire)
|
|
|
|
async def delete(self, key: str) -> int:
|
|
"""删除键"""
|
|
return await self.client.delete(key)
|
|
|
|
async def exists(self, key: str) -> bool:
|
|
"""检查键是否存在"""
|
|
return await self.client.exists(key) > 0
|
|
|
|
# ==================== JSON 操作 ====================
|
|
|
|
async def set_json(
|
|
self,
|
|
key: str,
|
|
data: Dict[str, Any],
|
|
expire: Optional[int] = None,
|
|
) -> bool:
|
|
"""
|
|
设置 JSON 数据
|
|
|
|
Args:
|
|
key: 键
|
|
data: 数据字典
|
|
expire: 过期时间(秒)
|
|
|
|
Returns:
|
|
是否成功
|
|
"""
|
|
json_str = json.dumps(data, ensure_ascii=False, default=str)
|
|
return await self.set(key, json_str, expire)
|
|
|
|
async def get_json(self, key: str) -> Optional[Dict[str, Any]]:
|
|
"""
|
|
获取 JSON 数据
|
|
|
|
Args:
|
|
key: 键
|
|
|
|
Returns:
|
|
数据字典,不存在返回 None
|
|
"""
|
|
value = await self.get(key)
|
|
if value:
|
|
try:
|
|
return json.loads(value)
|
|
except json.JSONDecodeError:
|
|
return None
|
|
return None
|
|
|
|
# ==================== 任务状态管理 ====================
|
|
|
|
async def set_task_status(
|
|
self,
|
|
task_id: str,
|
|
status: str,
|
|
meta: Optional[Dict[str, Any]] = None,
|
|
expire: int = 86400, # 默认24小时过期
|
|
) -> bool:
|
|
"""
|
|
设置任务状态
|
|
|
|
Args:
|
|
task_id: 任务ID
|
|
status: 状态 (pending/processing/success/failure)
|
|
meta: 附加信息
|
|
expire: 过期时间(秒)
|
|
|
|
Returns:
|
|
是否成功
|
|
"""
|
|
key = f"task:{task_id}"
|
|
data = {
|
|
"status": status,
|
|
"meta": meta or {},
|
|
}
|
|
return await self.set_json(key, data, expire)
|
|
|
|
async def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
|
|
"""
|
|
获取任务状态
|
|
|
|
Args:
|
|
task_id: 任务ID
|
|
|
|
Returns:
|
|
状态信息
|
|
"""
|
|
key = f"task:{task_id}"
|
|
return await self.get_json(key)
|
|
|
|
async def update_task_progress(
|
|
self,
|
|
task_id: str,
|
|
progress: int,
|
|
message: Optional[str] = None,
|
|
) -> bool:
|
|
"""
|
|
更新任务进度
|
|
|
|
Args:
|
|
task_id: 任务ID
|
|
progress: 进度值 (0-100)
|
|
message: 进度消息
|
|
|
|
Returns:
|
|
是否成功
|
|
"""
|
|
data = await self.get_task_status(task_id)
|
|
if data:
|
|
data["meta"]["progress"] = progress
|
|
if message:
|
|
data["meta"]["message"] = message
|
|
key = f"task:{task_id}"
|
|
return await self.set_json(key, data, expire=86400)
|
|
return False
|
|
|
|
# ==================== 缓存操作 ====================
|
|
|
|
async def cache_document(
|
|
self,
|
|
doc_id: str,
|
|
data: Dict[str, Any],
|
|
expire: int = 3600, # 默认1小时
|
|
) -> bool:
|
|
"""
|
|
缓存文档数据
|
|
|
|
Args:
|
|
doc_id: 文档ID
|
|
data: 文档数据
|
|
expire: 过期时间(秒)
|
|
|
|
Returns:
|
|
是否成功
|
|
"""
|
|
key = f"doc:{doc_id}"
|
|
return await self.set_json(key, data, expire)
|
|
|
|
async def get_cached_document(self, doc_id: str) -> Optional[Dict[str, Any]]:
|
|
"""
|
|
获取缓存的文档
|
|
|
|
Args:
|
|
doc_id: 文档ID
|
|
|
|
Returns:
|
|
文档数据
|
|
"""
|
|
key = f"doc:{doc_id}"
|
|
return await self.get_json(key)
|
|
|
|
# ==================== 分布式锁 ====================
|
|
|
|
async def acquire_lock(
|
|
self,
|
|
lock_name: str,
|
|
expire: int = 30,
|
|
) -> bool:
|
|
"""
|
|
获取分布式锁
|
|
|
|
Args:
|
|
lock_name: 锁名称
|
|
expire: 过期时间(秒)
|
|
|
|
Returns:
|
|
是否获取成功
|
|
"""
|
|
key = f"lock:{lock_name}"
|
|
# 使用 SET NX EX 原子操作
|
|
result = await self.client.set(key, "1", nx=True, ex=expire)
|
|
return result is not None
|
|
|
|
async def release_lock(self, lock_name: str) -> bool:
|
|
"""
|
|
释放分布式锁
|
|
|
|
Args:
|
|
lock_name: 锁名称
|
|
|
|
Returns:
|
|
是否释放成功
|
|
"""
|
|
key = f"lock:{lock_name}"
|
|
result = await self.client.delete(key)
|
|
return result > 0
|
|
|
|
# ==================== 计数器 ====================
|
|
|
|
async def incr(self, key: str, amount: int = 1) -> int:
|
|
"""递增计数器"""
|
|
return await self.client.incrby(key, amount)
|
|
|
|
async def decr(self, key: str, amount: int = 1) -> int:
|
|
"""递减计数器"""
|
|
return await self.client.decrby(key, amount)
|
|
|
|
# ==================== 过期时间管理 ====================
|
|
|
|
async def expire(self, key: str, seconds: int) -> bool:
|
|
"""设置键的过期时间"""
|
|
return await self.client.expire(key, seconds)
|
|
|
|
async def ttl(self, key: str) -> int:
|
|
"""获取键的剩余生存时间"""
|
|
return await self.client.ttl(key)
|
|
|
|
|
|
# ==================== 全局单例 ====================
|
|
|
|
redis_db = RedisDB()
|