完成后端数据库连接配置
This commit is contained in:
18
backend/app/core/database/__init__.py
Normal file
18
backend/app/core/database/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
数据库连接管理模块
|
||||
|
||||
提供 MySQL、MongoDB、Redis 的连接管理
|
||||
"""
|
||||
from app.core.database.mysql import MySQLDB, mysql_db, Base
|
||||
from app.core.database.mongodb import MongoDB, mongodb
|
||||
from app.core.database.redis_db import RedisDB, redis_db
|
||||
|
||||
__all__ = [
|
||||
"MySQLDB",
|
||||
"mysql_db",
|
||||
"MongoDB",
|
||||
"mongodb",
|
||||
"RedisDB",
|
||||
"redis_db",
|
||||
"Base",
|
||||
]
|
||||
247
backend/app/core/database/mongodb.py
Normal file
247
backend/app/core/database/mongodb.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
MongoDB 数据库连接管理模块
|
||||
|
||||
提供非结构化数据的存储和查询功能
|
||||
"""
|
||||
import logging
import re
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase

from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MongoDB:
    """MongoDB database manager.

    Wraps an async Motor client and exposes the collections used by the
    app (documents, embeddings, RAG field index) plus CRUD / search
    helpers. Call :meth:`connect` once at startup before using any
    accessor; every method below assumes a live connection.
    """

    def __init__(self):
        # Populated by connect(); None until then.
        self.client: Optional[AsyncIOMotorClient] = None
        self.db: Optional[AsyncIOMotorDatabase] = None

    async def connect(self):
        """Open the MongoDB connection and verify it with a ping.

        Raises:
            Exception: any driver error is logged and re-raised so the
                caller (startup code) can abort.
        """
        try:
            self.client = AsyncIOMotorClient(
                settings.MONGODB_URL,
                serverSelectionTimeoutMS=5000,
            )
            self.db = self.client[settings.MONGODB_DB_NAME]
            # Force a round-trip so a bad URL fails here, not on first query.
            await self.client.admin.command('ping')
            logger.info(f"MongoDB 连接成功: {settings.MONGODB_DB_NAME}")
        except Exception as e:
            logger.error(f"MongoDB 连接失败: {e}")
            raise

    async def close(self):
        """Close the MongoDB connection (no-op if never connected)."""
        if self.client:
            self.client.close()
            logger.info("MongoDB 连接已关闭")

    @property
    def documents(self):
        """Documents collection - raw documents and parse results."""
        return self.db["documents"]

    @property
    def embeddings(self):
        """Embeddings collection - text embedding vectors."""
        return self.db["embeddings"]

    @property
    def rag_index(self):
        """RAG index collection - semantic index of table fields."""
        return self.db["rag_index"]

    # ==================== Document operations ====================

    async def insert_document(
        self,
        doc_type: str,
        content: str,
        metadata: Dict[str, Any],
        structured_data: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Insert a document.

        Args:
            doc_type: document type (docx/xlsx/md/txt)
            content: raw text content
            metadata: arbitrary metadata
            structured_data: structured data (tables etc.)

        Returns:
            The inserted document's id as a string.
        """
        # Single timezone-aware timestamp; the original called utcnow()
        # twice, yielding naive and slightly different created/updated times.
        now = datetime.now(timezone.utc)
        document = {
            "doc_type": doc_type,
            "content": content,
            "metadata": metadata,
            "structured_data": structured_data,
            "created_at": now,
            "updated_at": now,
        }
        result = await self.documents.insert_one(document)
        logger.info(f"文档已插入MongoDB: {result.inserted_id}")
        return str(result.inserted_id)

    async def get_document(self, doc_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a document by id; returns None for unknown or malformed ids."""
        from bson import ObjectId
        from bson.errors import InvalidId
        try:
            oid = ObjectId(doc_id)
        except InvalidId:
            # A malformed id cannot match anything - treat as "not found"
            # instead of propagating a bson error to the caller.
            return None
        doc = await self.documents.find_one({"_id": oid})
        if doc:
            doc["_id"] = str(doc["_id"])
        return doc

    async def search_documents(
        self,
        query: str,
        doc_type: Optional[str] = None,
        limit: int = 10,
    ) -> List[Dict[str, Any]]:
        """
        Search documents by keyword.

        Args:
            query: search keyword (treated literally, not as a regex)
            doc_type: optional document-type filter
            limit: max number of results

        Returns:
            List of matching documents with stringified ``_id``.
        """
        # re.escape: the docstring promises keyword search, so regex
        # metacharacters in user input must not be interpreted (and an
        # invalid pattern must not crash the query).
        filter_query: Dict[str, Any] = {"content": {"$regex": re.escape(query)}}
        if doc_type:
            filter_query["doc_type"] = doc_type

        cursor = self.documents.find(filter_query).limit(limit)
        documents = []
        async for doc in cursor:
            doc["_id"] = str(doc["_id"])
            documents.append(doc)
        return documents

    async def delete_document(self, doc_id: str) -> bool:
        """Delete a document; returns False for unknown or malformed ids."""
        from bson import ObjectId
        from bson.errors import InvalidId
        try:
            oid = ObjectId(doc_id)
        except InvalidId:
            return False
        result = await self.documents.delete_one({"_id": oid})
        return result.deleted_count > 0

    # ==================== RAG index operations ====================

    async def insert_rag_entry(
        self,
        table_name: str,
        field_name: str,
        field_description: str,
        embedding: List[float],
        metadata: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Insert a RAG index entry.

        Args:
            table_name: table name
            field_name: field name
            field_description: field description
            embedding: embedding vector
            metadata: extra metadata

        Returns:
            The inserted entry's id as a string.
        """
        entry = {
            "table_name": table_name,
            "field_name": field_name,
            "field_description": field_description,
            "embedding": embedding,
            "metadata": metadata or {},
            "created_at": datetime.now(timezone.utc),
        }
        result = await self.rag_index.insert_one(entry)
        return str(result.inserted_id)

    async def search_rag(
        self,
        query_embedding: List[float],
        top_k: int = 5,
        table_name: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """
        Search the RAG index by vector similarity.

        Computes squared Euclidean distance between the stored embedding
        and *query_embedding* inside the aggregation pipeline (portable,
        does not need an Atlas vector-search index) and returns the
        *top_k* closest entries, smallest distance first.

        Args:
            query_embedding: query vector (must match stored embedding length)
            top_k: number of results
            table_name: optional table-name filter

        Returns:
            Closest index entries with stringified ``_id`` and a
            ``distance`` field added.
        """
        # NOTE(review): assumes every stored embedding has the same length
        # as query_embedding - documents with a missing/shorter embedding
        # would make $arrayElemAt yield nulls. Confirm writers enforce this.
        pipeline = [
            {
                "$addFields": {
                    "distance": {
                        "$reduce": {
                            "input": {"$range": [0, {"$size": "$embedding"}]},
                            "initialValue": 0,
                            "in": {
                                "$add": [
                                    "$$value",
                                    {
                                        "$pow": [
                                            {
                                                "$subtract": [
                                                    {"$arrayElemAt": ["$embedding", "$$this"]},
                                                    {"$arrayElemAt": [query_embedding, "$$this"]},
                                                ]
                                            },
                                            2,
                                        ]
                                    },
                                ]
                            },
                        }
                    }
                }
            },
            {"$sort": {"distance": 1}},
            {"$limit": top_k},
        ]

        # Filter first so the distance computation only runs on candidates.
        if table_name:
            pipeline.insert(0, {"$match": {"table_name": table_name}})

        results = []
        async for doc in self.rag_index.aggregate(pipeline):
            doc["_id"] = str(doc["_id"])
            results.append(doc)
        return results

    # ==================== Collection management ====================

    async def create_indexes(self):
        """Create indexes to speed up the queries used above."""
        # Documents collection
        await self.documents.create_index("doc_type")
        await self.documents.create_index("created_at")
        await self.documents.create_index([("content", "text")])

        # RAG index collection
        await self.rag_index.create_index("table_name")
        await self.rag_index.create_index("field_name")
        # The original [("embedding", "hnsw", {"type": "knnVector"})] is not
        # a valid create_index() key pattern and always raised. kNN/vector
        # indexes cannot be created via create_index(); on MongoDB Atlas they
        # must be defined as a search index (createSearchIndex). search_rag()
        # above works without one, so no embedding index is created here.

        logger.info("MongoDB 索引创建完成")
|
||||
|
||||
|
||||
# ==================== Global singleton ====================

# Shared MongoDB instance; application startup must await mongodb.connect().
mongodb = MongoDB()
|
||||
193
backend/app/core/database/mysql.py
Normal file
193
backend/app/core/database/mysql.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
MySQL 数据库连接管理模块
|
||||
|
||||
提供结构化数据的存储和查询功能
|
||||
"""
|
||||
import logging
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, List, Optional

from sqlalchemy import (
    Column,
    DateTime,
    Enum as SQLEnum,
    Float,
    Integer,
    String,
    Text,
    create_engine,
    text,
)
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, sessionmaker
from sqlalchemy.sql import select

from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
    """SQLAlchemy declarative base shared by all ORM models in this module."""
    pass
|
||||
|
||||
|
||||
class MySQLDB:
    """MySQL database manager.

    Owns one async engine/session factory (FastAPI request path) and one
    sync engine/session factory (Celery tasks), plus helpers for raw SQL.
    """

    def __init__(self):
        # Async engine (for FastAPI async operations)
        self.async_engine = create_async_engine(
            settings.mysql_url,
            echo=settings.DEBUG,  # SQL logging
            pool_pre_ping=True,   # validate connections before use
            pool_size=10,
            max_overflow=20,
        )

        # Async session factory
        self.async_session_factory = async_sessionmaker(
            bind=self.async_engine,
            class_=AsyncSession,
            expire_on_commit=False,
            autocommit=False,
            autoflush=False,
        )

        # Sync engine (for Celery sync tasks).
        # NOTE(review): this assumes settings.mysql_url uses the
        # "mysql+pymysql" driver; replacing it with bare "mysql" selects
        # SQLAlchemy's default (mysqlclient) driver - confirm it is installed,
        # and confirm the async engine above actually gets an async driver URL.
        self.sync_engine = create_engine(
            settings.mysql_url.replace("mysql+pymysql", "mysql"),
            echo=settings.DEBUG,
            pool_pre_ping=True,
            pool_size=5,
            max_overflow=10,
        )

        # Sync session factory
        self.sync_session_factory = sessionmaker(
            bind=self.sync_engine,
            autocommit=False,
            autoflush=False,
        )

    async def init_db(self):
        """Create all tables registered on Base.metadata.

        Raises:
            Exception: initialization errors are logged and re-raised.
        """
        try:
            async with self.async_engine.begin() as conn:
                await conn.run_sync(Base.metadata.create_all)
            logger.info("MySQL 数据库表初始化完成")
        except Exception as e:
            logger.error(f"MySQL 数据库初始化失败: {e}")
            raise

    async def close(self):
        """Dispose both engines and their connection pools."""
        await self.async_engine.dispose()
        self.sync_engine.dispose()
        logger.info("MySQL 数据库连接已关闭")

    @asynccontextmanager
    async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield an async session; commit on success, roll back on error."""
        session = self.async_session_factory()
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()

    async def execute_query(
        self,
        query: str,
        params: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """
        Execute a raw SQL SELECT.

        Args:
            query: SQL query string (use :name placeholders for params)
            params: bind parameters

        Returns:
            List of rows as plain dicts.
        """
        async with self.get_session() as session:
            # text() makes the raw string executable; the original
            # select(text(query)) wrapped it in an extra SELECT and broke
            # every query.
            result = await session.execute(text(query), params or {})
            rows = result.fetchall()
            return [dict(row._mapping) for row in rows]

    async def execute_raw_sql(
        self,
        sql: str,
        params: Optional[Dict[str, Any]] = None
    ) -> Any:
        """
        Execute a raw SQL statement (INSERT/UPDATE/DELETE).

        Args:
            sql: SQL statement (use :name placeholders for params)
            params: bind parameters

        Returns:
            lastrowid for INSERTs that produce one, otherwise the
            affected-row count.
        """
        async with self.get_session() as session:
            result = await session.execute(text(sql), params or {})
            await session.commit()
            # lastrowid only exists/means something for INSERT statements;
            # guard the attribute and fall back to rowcount.
            last_id = getattr(result, "lastrowid", None)
            return last_id if last_id else result.rowcount
|
||||
|
||||
|
||||
# ==================== 预定义的数据模型 ====================
|
||||
|
||||
class DocumentTable(Base):
    """Document-metadata table: one row per parsed document/table."""
    __tablename__ = "document_tables"

    # NOTE(review): created_at/updated_at have no default= - callers must
    # set them explicitly on insert; confirm that is intended.
    id = Column(Integer, primary_key=True, autoincrement=True)
    table_name = Column(String(255), unique=True, nullable=False, comment="表名")
    display_name = Column(String(255), comment="显示名称")
    description = Column(Text, comment="表描述")
    source_file = Column(String(512), comment="来源文件")
    column_count = Column(Integer, default=0, comment="列数")
    row_count = Column(Integer, default=0, comment="行数")
    file_size = Column(Integer, comment="文件大小(字节)")
    created_at = Column(DateTime, comment="创建时间")
    updated_at = Column(DateTime, comment="更新时间")
||||
|
||||
|
||||
class DocumentField(Base):
    """Per-table field table: one row per column of a document table."""
    __tablename__ = "document_fields"

    # NOTE(review): table_id references document_tables.id but has no
    # ForeignKey constraint - integrity is enforced in application code only.
    id = Column(Integer, primary_key=True, autoincrement=True)
    table_id = Column(Integer, nullable=False, comment="所属表ID")
    field_name = Column(String(255), nullable=False, comment="字段名")
    field_type = Column(String(50), comment="字段类型")
    field_description = Column(Text, comment="字段描述/语义")
    # Boolean-like flags stored as Integer (0/1).
    is_key_field = Column(Integer, default=0, comment="是否主键")
    is_nullable = Column(Integer, default=1, comment="是否可空")
    sample_values = Column(Text, comment="示例值(逗号分隔)")
    created_at = Column(DateTime, comment="创建时间")
|
||||
|
||||
|
||||
class TaskRecord(Base):
    """Async-task record table: persists Celery task state and results."""
    __tablename__ = "task_records"

    id = Column(Integer, primary_key=True, autoincrement=True)
    task_id = Column(String(255), unique=True, nullable=False, comment="Celery任务ID")
    task_type = Column(String(50), comment="任务类型")
    # Status values observed elsewhere: pending/processing/success/failure.
    status = Column(String(50), default="pending", comment="任务状态")
    # JSON payloads are stored serialized as text.
    input_params = Column(Text, comment="输入参数JSON")
    result_data = Column(Text, comment="结果数据JSON")
    error_message = Column(Text, comment="错误信息")
    started_at = Column(DateTime, comment="开始时间")
    completed_at = Column(DateTime, comment="完成时间")
    created_at = Column(DateTime, comment="创建时间")
|
||||
|
||||
|
||||
# ==================== Global singleton ====================

# Shared MySQLDB instance; engines are created eagerly at import time.
mysql_db = MySQLDB()
|
||||
287
backend/app/core/database/redis_db.py
Normal file
287
backend/app/core/database/redis_db.py
Normal file
@@ -0,0 +1,287 @@
|
||||
"""
|
||||
Redis 数据库连接管理模块
|
||||
|
||||
提供缓存和任务队列功能
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from typing import Any, Optional
|
||||
|
||||
import redis.asyncio as redis
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedisDB:
    """Redis database manager.

    Provides caching, JSON helpers, task-status tracking, a simple
    distributed lock, counters and TTL management on top of a single
    async redis client. Call :meth:`connect` before any other method.
    """

    def __init__(self):
        # Populated by connect(); None until then.
        self.client: Optional[redis.Redis] = None
        self._connected = False

    async def connect(self):
        """Open the Redis connection and verify it with a ping.

        Raises:
            Exception: any driver error is logged and re-raised.
        """
        try:
            self.client = redis.from_url(
                settings.REDIS_URL,
                encoding="utf-8",
                decode_responses=True,
            )
            # Force a round-trip so a bad URL fails here, not on first use.
            await self.client.ping()
            self._connected = True
            logger.info(f"Redis 连接成功: {settings.REDIS_URL}")
        except Exception as e:
            logger.error(f"Redis 连接失败: {e}")
            raise

    async def close(self):
        """Close the Redis connection (no-op if never connected)."""
        if self.client:
            # NOTE(review): redis-py >= 5 prefers aclose(); close() still
            # works but is deprecated there - confirm the pinned version.
            await self.client.close()
            self._connected = False
            logger.info("Redis 连接已关闭")

    @property
    def is_connected(self) -> bool:
        """Whether connect() has succeeded and close() has not been called."""
        return self._connected

    # ==================== Basic operations ====================

    async def get(self, key: str) -> Optional[str]:
        """Get a value (None if the key does not exist)."""
        return await self.client.get(key)

    async def set(
        self,
        key: str,
        value: str,
        expire: Optional[int] = None,
    ) -> bool:
        """
        Set a value.

        Args:
            key: key
            value: value
            expire: TTL in seconds (None = no expiry)

        Returns:
            Whether the set succeeded.
        """
        return await self.client.set(key, value, ex=expire)

    async def delete(self, key: str) -> int:
        """Delete a key; returns the number of keys removed."""
        return await self.client.delete(key)

    async def exists(self, key: str) -> bool:
        """Check whether a key exists."""
        return await self.client.exists(key) > 0

    # ==================== JSON operations ====================

    async def set_json(
        self,
        key: str,
        data: dict[str, Any],
        expire: Optional[int] = None,
    ) -> bool:
        """
        Store a dict as JSON.

        Args:
            key: key
            data: data dict (non-JSON values are stringified via default=str)
            expire: TTL in seconds

        Returns:
            Whether the set succeeded.
        """
        # Builtin dict[...] generics fix the original Dict annotation, which
        # was never imported and raised NameError when the class was defined.
        json_str = json.dumps(data, ensure_ascii=False, default=str)
        return await self.set(key, json_str, expire)

    async def get_json(self, key: str) -> Optional[dict[str, Any]]:
        """
        Load a JSON value.

        Args:
            key: key

        Returns:
            The decoded dict, or None when the key is missing or the
            stored value is not valid JSON.
        """
        value = await self.get(key)
        if value:
            try:
                return json.loads(value)
            except json.JSONDecodeError:
                return None
        return None

    # ==================== Task-status management ====================

    async def set_task_status(
        self,
        task_id: str,
        status: str,
        meta: Optional[dict[str, Any]] = None,
        expire: int = 86400,  # default: expire after 24h
    ) -> bool:
        """
        Set a task's status under ``task:<task_id>``.

        Args:
            task_id: task id
            status: status (pending/processing/success/failure)
            meta: extra info
            expire: TTL in seconds

        Returns:
            Whether the set succeeded.
        """
        key = f"task:{task_id}"
        data = {
            "status": status,
            "meta": meta or {},
        }
        return await self.set_json(key, data, expire)

    async def get_task_status(self, task_id: str) -> Optional[dict[str, Any]]:
        """
        Get a task's status.

        Args:
            task_id: task id

        Returns:
            The status dict, or None if unknown/expired.
        """
        key = f"task:{task_id}"
        return await self.get_json(key)

    async def update_task_progress(
        self,
        task_id: str,
        progress: int,
        message: Optional[str] = None,
    ) -> bool:
        """
        Update a task's progress in its ``meta`` dict.

        Args:
            task_id: task id
            progress: progress value (0-100)
            message: optional progress message

        Returns:
            True when the task existed and was updated, else False.
        """
        data = await self.get_task_status(task_id)
        if data:
            data["meta"]["progress"] = progress
            if message:
                data["meta"]["message"] = message
            key = f"task:{task_id}"
            # Re-setting resets the TTL to 24h, matching set_task_status.
            return await self.set_json(key, data, expire=86400)
        return False

    # ==================== Cache operations ====================

    async def cache_document(
        self,
        doc_id: str,
        data: dict[str, Any],
        expire: int = 3600,  # default: 1 hour
    ) -> bool:
        """
        Cache document data under ``doc:<doc_id>``.

        Args:
            doc_id: document id
            data: document data
            expire: TTL in seconds

        Returns:
            Whether the set succeeded.
        """
        key = f"doc:{doc_id}"
        return await self.set_json(key, data, expire)

    async def get_cached_document(self, doc_id: str) -> Optional[dict[str, Any]]:
        """
        Get a cached document.

        Args:
            doc_id: document id

        Returns:
            The cached data, or None on miss.
        """
        key = f"doc:{doc_id}"
        return await self.get_json(key)

    # ==================== Distributed lock ====================

    async def acquire_lock(
        self,
        lock_name: str,
        expire: int = 30,
    ) -> bool:
        """
        Acquire a distributed lock.

        Args:
            lock_name: lock name
            expire: TTL in seconds (auto-release safety net)

        Returns:
            Whether the lock was acquired.
        """
        key = f"lock:{lock_name}"
        # SET NX EX is a single atomic operation; returns None when the
        # key already exists (lock held by someone else).
        result = await self.client.set(key, "1", nx=True, ex=expire)
        return result is not None

    async def release_lock(self, lock_name: str) -> bool:
        """
        Release a distributed lock.

        NOTE(review): any caller can release any lock - there is no owner
        token, so a slow holder's lock can be released by another worker.

        Args:
            lock_name: lock name

        Returns:
            Whether a lock key was actually deleted.
        """
        key = f"lock:{lock_name}"
        result = await self.client.delete(key)
        return result > 0

    # ==================== Counters ====================

    async def incr(self, key: str, amount: int = 1) -> int:
        """Increment a counter; returns the new value."""
        return await self.client.incrby(key, amount)

    async def decr(self, key: str, amount: int = 1) -> int:
        """Decrement a counter; returns the new value."""
        return await self.client.decrby(key, amount)

    # ==================== TTL management ====================

    async def expire(self, key: str, seconds: int) -> bool:
        """Set a key's TTL in seconds."""
        return await self.client.expire(key, seconds)

    async def ttl(self, key: str) -> int:
        """Get a key's remaining TTL (-1 no expiry, -2 missing key)."""
        return await self.client.ttl(key)
|
||||
|
||||
|
||||
# ==================== Global singleton ====================

# Shared RedisDB instance; application startup must await redis_db.connect().
redis_db = RedisDB()
|
||||
@@ -1,7 +1,48 @@
|
||||
"""
|
||||
文档解析模块 - 支持多种文件格式的解析
|
||||
"""
|
||||
from .base import BaseParser
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
from .base import BaseParser, ParseResult
|
||||
from .xlsx_parser import XlsxParser
|
||||
|
||||
__all__ = ['BaseParser', 'XlsxParser']
|
||||
# 导入其他解析器 (需要先实现)
|
||||
# from .docx_parser import DocxParser
|
||||
# from .md_parser import MarkdownParser
|
||||
# from .txt_parser import TxtParser
|
||||
|
||||
|
||||
class ParserFactory:
    """Factory mapping file extensions to their parser instances."""

    # Extension -> shared parser instance. One XlsxParser serves both
    # Excel extensions; new formats are added via register_parser().
    _parsers: Dict[str, BaseParser] = {
        '.xlsx': XlsxParser(),
        '.xls': XlsxParser(),
        # '.docx': DocxParser(), # TODO: to be implemented
        # '.md': MarkdownParser(), # TODO: to be implemented
        # '.txt': TxtParser(), # TODO: to be implemented
    }

    @classmethod
    def get_parser(cls, file_path: str) -> BaseParser:
        """Return the parser registered for *file_path*'s extension."""
        ext = Path(file_path).suffix.lower()
        parser = cls._parsers.get(ext)
        if not parser:
            raise ValueError(f"不支持的文件格式: {ext},支持的格式: {list(cls._parsers.keys())}")
        return parser

    @classmethod
    def parse(cls, file_path: str, **kwargs) -> ParseResult:
        """Unified entry point: look up the parser and run it in one step."""
        return cls.get_parser(file_path).parse(file_path, **kwargs)

    @classmethod
    def register_parser(cls, ext: str, parser: BaseParser):
        """Register (or replace) the parser for an extension."""
        cls._parsers[ext.lower()] = parser


__all__ = ['BaseParser', 'ParseResult', 'XlsxParser', 'ParserFactory']
|
||||
|
||||
Reference in New Issue
Block a user