完成后端数据库连接配置
This commit is contained in:
287
backend/app/core/database/redis_db.py
Normal file
287
backend/app/core/database/redis_db.py
Normal file
@@ -0,0 +1,287 @@
|
||||
"""
|
||||
Redis 数据库连接管理模块
|
||||
|
||||
提供缓存和任务队列功能
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from typing import Any, Optional
|
||||
|
||||
import redis.asyncio as redis
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedisDB:
|
||||
"""Redis 数据库管理类"""
|
||||
|
||||
def __init__(self):
|
||||
self.client: Optional[redis.Redis] = None
|
||||
self._connected = False
|
||||
|
||||
async def connect(self):
|
||||
"""建立 Redis 连接"""
|
||||
try:
|
||||
self.client = redis.from_url(
|
||||
settings.REDIS_URL,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
# 验证连接
|
||||
await self.client.ping()
|
||||
self._connected = True
|
||||
logger.info(f"Redis 连接成功: {settings.REDIS_URL}")
|
||||
except Exception as e:
|
||||
logger.error(f"Redis 连接失败: {e}")
|
||||
raise
|
||||
|
||||
async def close(self):
|
||||
"""关闭 Redis 连接"""
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
self._connected = False
|
||||
logger.info("Redis 连接已关闭")
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""检查连接状态"""
|
||||
return self._connected
|
||||
|
||||
# ==================== 基础操作 ====================
|
||||
|
||||
async def get(self, key: str) -> Optional[str]:
|
||||
"""获取值"""
|
||||
return await self.client.get(key)
|
||||
|
||||
async def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str,
|
||||
expire: Optional[int] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
设置值
|
||||
|
||||
Args:
|
||||
key: 键
|
||||
value: 值
|
||||
expire: 过期时间(秒)
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
return await self.client.set(key, value, ex=expire)
|
||||
|
||||
async def delete(self, key: str) -> int:
|
||||
"""删除键"""
|
||||
return await self.client.delete(key)
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""检查键是否存在"""
|
||||
return await self.client.exists(key) > 0
|
||||
|
||||
# ==================== JSON 操作 ====================
|
||||
|
||||
async def set_json(
|
||||
self,
|
||||
key: str,
|
||||
data: Dict[str, Any],
|
||||
expire: Optional[int] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
设置 JSON 数据
|
||||
|
||||
Args:
|
||||
key: 键
|
||||
data: 数据字典
|
||||
expire: 过期时间(秒)
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
json_str = json.dumps(data, ensure_ascii=False, default=str)
|
||||
return await self.set(key, json_str, expire)
|
||||
|
||||
async def get_json(self, key: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取 JSON 数据
|
||||
|
||||
Args:
|
||||
key: 键
|
||||
|
||||
Returns:
|
||||
数据字典,不存在返回 None
|
||||
"""
|
||||
value = await self.get(key)
|
||||
if value:
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
return None
|
||||
|
||||
# ==================== 任务状态管理 ====================
|
||||
|
||||
async def set_task_status(
|
||||
self,
|
||||
task_id: str,
|
||||
status: str,
|
||||
meta: Optional[Dict[str, Any]] = None,
|
||||
expire: int = 86400, # 默认24小时过期
|
||||
) -> bool:
|
||||
"""
|
||||
设置任务状态
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
status: 状态 (pending/processing/success/failure)
|
||||
meta: 附加信息
|
||||
expire: 过期时间(秒)
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
key = f"task:{task_id}"
|
||||
data = {
|
||||
"status": status,
|
||||
"meta": meta or {},
|
||||
}
|
||||
return await self.set_json(key, data, expire)
|
||||
|
||||
async def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取任务状态
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
|
||||
Returns:
|
||||
状态信息
|
||||
"""
|
||||
key = f"task:{task_id}"
|
||||
return await self.get_json(key)
|
||||
|
||||
async def update_task_progress(
|
||||
self,
|
||||
task_id: str,
|
||||
progress: int,
|
||||
message: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
更新任务进度
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
progress: 进度值 (0-100)
|
||||
message: 进度消息
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
data = await self.get_task_status(task_id)
|
||||
if data:
|
||||
data["meta"]["progress"] = progress
|
||||
if message:
|
||||
data["meta"]["message"] = message
|
||||
key = f"task:{task_id}"
|
||||
return await self.set_json(key, data, expire=86400)
|
||||
return False
|
||||
|
||||
# ==================== 缓存操作 ====================
|
||||
|
||||
async def cache_document(
|
||||
self,
|
||||
doc_id: str,
|
||||
data: Dict[str, Any],
|
||||
expire: int = 3600, # 默认1小时
|
||||
) -> bool:
|
||||
"""
|
||||
缓存文档数据
|
||||
|
||||
Args:
|
||||
doc_id: 文档ID
|
||||
data: 文档数据
|
||||
expire: 过期时间(秒)
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
key = f"doc:{doc_id}"
|
||||
return await self.set_json(key, data, expire)
|
||||
|
||||
async def get_cached_document(self, doc_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取缓存的文档
|
||||
|
||||
Args:
|
||||
doc_id: 文档ID
|
||||
|
||||
Returns:
|
||||
文档数据
|
||||
"""
|
||||
key = f"doc:{doc_id}"
|
||||
return await self.get_json(key)
|
||||
|
||||
# ==================== 分布式锁 ====================
|
||||
|
||||
async def acquire_lock(
|
||||
self,
|
||||
lock_name: str,
|
||||
expire: int = 30,
|
||||
) -> bool:
|
||||
"""
|
||||
获取分布式锁
|
||||
|
||||
Args:
|
||||
lock_name: 锁名称
|
||||
expire: 过期时间(秒)
|
||||
|
||||
Returns:
|
||||
是否获取成功
|
||||
"""
|
||||
key = f"lock:{lock_name}"
|
||||
# 使用 SET NX EX 原子操作
|
||||
result = await self.client.set(key, "1", nx=True, ex=expire)
|
||||
return result is not None
|
||||
|
||||
async def release_lock(self, lock_name: str) -> bool:
|
||||
"""
|
||||
释放分布式锁
|
||||
|
||||
Args:
|
||||
lock_name: 锁名称
|
||||
|
||||
Returns:
|
||||
是否释放成功
|
||||
"""
|
||||
key = f"lock:{lock_name}"
|
||||
result = await self.client.delete(key)
|
||||
return result > 0
|
||||
|
||||
# ==================== 计数器 ====================
|
||||
|
||||
async def incr(self, key: str, amount: int = 1) -> int:
|
||||
"""递增计数器"""
|
||||
return await self.client.incrby(key, amount)
|
||||
|
||||
async def decr(self, key: str, amount: int = 1) -> int:
|
||||
"""递减计数器"""
|
||||
return await self.client.decrby(key, amount)
|
||||
|
||||
# ==================== 过期时间管理 ====================
|
||||
|
||||
async def expire(self, key: str, seconds: int) -> bool:
|
||||
"""设置键的过期时间"""
|
||||
return await self.client.expire(key, seconds)
|
||||
|
||||
async def ttl(self, key: str) -> int:
|
||||
"""获取键的剩余生存时间"""
|
||||
return await self.client.ttl(key)
|
||||
|
||||
|
||||
# ==================== 全局单例 ====================
|
||||
|
||||
redis_db = RedisDB()
|
||||
Reference in New Issue
Block a user