Files
FilesReadSystem/backend/app/core/database/redis_db.py

288 lines
7.0 KiB
Python

"""
Redis 数据库连接管理模块
提供缓存和任务队列功能
"""
import json
import logging
from datetime import timedelta
from typing import Any, Optional
import redis.asyncio as redis
from app.config import settings
logger = logging.getLogger(__name__)
class RedisDB:
"""Redis 数据库管理类"""
def __init__(self):
self.client: Optional[redis.Redis] = None
self._connected = False
async def connect(self):
"""建立 Redis 连接"""
try:
self.client = redis.from_url(
settings.REDIS_URL,
encoding="utf-8",
decode_responses=True,
)
# 验证连接
await self.client.ping()
self._connected = True
logger.info(f"Redis 连接成功: {settings.REDIS_URL}")
except Exception as e:
logger.error(f"Redis 连接失败: {e}")
raise
async def close(self):
"""关闭 Redis 连接"""
if self.client:
await self.client.close()
self._connected = False
logger.info("Redis 连接已关闭")
@property
def is_connected(self) -> bool:
"""检查连接状态"""
return self._connected
# ==================== 基础操作 ====================
async def get(self, key: str) -> Optional[str]:
"""获取值"""
return await self.client.get(key)
async def set(
self,
key: str,
value: str,
expire: Optional[int] = None,
) -> bool:
"""
设置值
Args:
key: 键
value: 值
expire: 过期时间(秒)
Returns:
是否成功
"""
return await self.client.set(key, value, ex=expire)
async def delete(self, key: str) -> int:
"""删除键"""
return await self.client.delete(key)
async def exists(self, key: str) -> bool:
"""检查键是否存在"""
return await self.client.exists(key) > 0
# ==================== JSON 操作 ====================
async def set_json(
self,
key: str,
data: Dict[str, Any],
expire: Optional[int] = None,
) -> bool:
"""
设置 JSON 数据
Args:
key: 键
data: 数据字典
expire: 过期时间(秒)
Returns:
是否成功
"""
json_str = json.dumps(data, ensure_ascii=False, default=str)
return await self.set(key, json_str, expire)
async def get_json(self, key: str) -> Optional[Dict[str, Any]]:
"""
获取 JSON 数据
Args:
key: 键
Returns:
数据字典,不存在返回 None
"""
value = await self.get(key)
if value:
try:
return json.loads(value)
except json.JSONDecodeError:
return None
return None
# ==================== 任务状态管理 ====================
async def set_task_status(
self,
task_id: str,
status: str,
meta: Optional[Dict[str, Any]] = None,
expire: int = 86400, # 默认24小时过期
) -> bool:
"""
设置任务状态
Args:
task_id: 任务ID
status: 状态 (pending/processing/success/failure)
meta: 附加信息
expire: 过期时间(秒)
Returns:
是否成功
"""
key = f"task:{task_id}"
data = {
"status": status,
"meta": meta or {},
}
return await self.set_json(key, data, expire)
async def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
"""
获取任务状态
Args:
task_id: 任务ID
Returns:
状态信息
"""
key = f"task:{task_id}"
return await self.get_json(key)
async def update_task_progress(
self,
task_id: str,
progress: int,
message: Optional[str] = None,
) -> bool:
"""
更新任务进度
Args:
task_id: 任务ID
progress: 进度值 (0-100)
message: 进度消息
Returns:
是否成功
"""
data = await self.get_task_status(task_id)
if data:
data["meta"]["progress"] = progress
if message:
data["meta"]["message"] = message
key = f"task:{task_id}"
return await self.set_json(key, data, expire=86400)
return False
# ==================== 缓存操作 ====================
async def cache_document(
self,
doc_id: str,
data: Dict[str, Any],
expire: int = 3600, # 默认1小时
) -> bool:
"""
缓存文档数据
Args:
doc_id: 文档ID
data: 文档数据
expire: 过期时间(秒)
Returns:
是否成功
"""
key = f"doc:{doc_id}"
return await self.set_json(key, data, expire)
async def get_cached_document(self, doc_id: str) -> Optional[Dict[str, Any]]:
"""
获取缓存的文档
Args:
doc_id: 文档ID
Returns:
文档数据
"""
key = f"doc:{doc_id}"
return await self.get_json(key)
# ==================== 分布式锁 ====================
async def acquire_lock(
self,
lock_name: str,
expire: int = 30,
) -> bool:
"""
获取分布式锁
Args:
lock_name: 锁名称
expire: 过期时间(秒)
Returns:
是否获取成功
"""
key = f"lock:{lock_name}"
# 使用 SET NX EX 原子操作
result = await self.client.set(key, "1", nx=True, ex=expire)
return result is not None
async def release_lock(self, lock_name: str) -> bool:
"""
释放分布式锁
Args:
lock_name: 锁名称
Returns:
是否释放成功
"""
key = f"lock:{lock_name}"
result = await self.client.delete(key)
return result > 0
# ==================== 计数器 ====================
async def incr(self, key: str, amount: int = 1) -> int:
"""递增计数器"""
return await self.client.incrby(key, amount)
async def decr(self, key: str, amount: int = 1) -> int:
"""递减计数器"""
return await self.client.decrby(key, amount)
# ==================== 过期时间管理 ====================
async def expire(self, key: str, seconds: int) -> bool:
"""设置键的过期时间"""
return await self.client.expire(key, seconds)
async def ttl(self, key: str) -> int:
"""获取键的剩余生存时间"""
return await self.client.ttl(key)
# ==================== 全局单例 ====================
redis_db = RedisDB()