完成后端数据库连接配置

This commit is contained in:
2026-03-26 19:49:40 +08:00
parent d3bdb17e87
commit 4bdc3f9707
19 changed files with 2843 additions and 302 deletions

View File

@@ -0,0 +1,18 @@
"""
数据库连接管理模块
提供 MySQL、MongoDB、Redis 的连接管理
"""
from app.core.database.mysql import MySQLDB, mysql_db, Base
from app.core.database.mongodb import MongoDB, mongodb
from app.core.database.redis_db import RedisDB, redis_db
__all__ = [
"MySQLDB",
"mysql_db",
"MongoDB",
"mongodb",
"RedisDB",
"redis_db",
"Base",
]

View File

@@ -0,0 +1,247 @@
"""
MongoDB 数据库连接管理模块
提供非结构化数据的存储和查询功能
"""
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
from app.config import settings
logger = logging.getLogger(__name__)
class MongoDB:
    """MongoDB database manager.

    Provides storage and retrieval for unstructured data: raw documents,
    text embeddings, and a RAG (retrieval-augmented generation) field
    index with brute-force vector similarity search.
    """

    def __init__(self):
        # Both handles are populated by connect(); every other method
        # assumes connect() has been awaited first.
        self.client: Optional[AsyncIOMotorClient] = None
        self.db: Optional[AsyncIOMotorDatabase] = None

    async def connect(self):
        """Open the MongoDB connection and verify it with a ping.

        Raises:
            Exception: propagated from the driver when the server is
                unreachable or authentication fails.
        """
        try:
            self.client = AsyncIOMotorClient(
                settings.MONGODB_URL,
                serverSelectionTimeoutMS=5000,
            )
            self.db = self.client[settings.MONGODB_DB_NAME]
            # Motor connects lazily; a ping forces a round trip so failures
            # surface here instead of on the first real query.
            await self.client.admin.command('ping')
            logger.info(f"MongoDB 连接成功: {settings.MONGODB_DB_NAME}")
        except Exception as e:
            logger.error(f"MongoDB 连接失败: {e}")
            raise

    async def close(self):
        """Close the MongoDB connection (no-op when not connected)."""
        if self.client:
            self.client.close()
            # Reset the handles so accidental use after close() fails fast
            # and a later connect() starts from a clean state.
            self.client = None
            self.db = None
            logger.info("MongoDB 连接已关闭")

    @property
    def documents(self):
        """Collection of raw documents and their parse results."""
        return self.db["documents"]

    @property
    def embeddings(self):
        """Collection of text embedding vectors."""
        return self.db["embeddings"]

    @property
    def rag_index(self):
        """Collection of field-semantics entries used for RAG retrieval."""
        return self.db["rag_index"]

    # ==================== Document operations ====================

    async def insert_document(
        self,
        doc_type: str,
        content: str,
        metadata: Dict[str, Any],
        structured_data: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Insert a document.

        Args:
            doc_type: Document type (docx/xlsx/md/txt).
            content: Raw text content.
            metadata: Arbitrary metadata dict.
            structured_data: Optional structured payload (tables etc.).

        Returns:
            The inserted document's id as a string.
        """
        # Take the timestamp once so created_at and updated_at are equal
        # on insert.  NOTE: datetime.utcnow() is deprecated in Python
        # 3.12+; kept for compatibility with existing naive-UTC records.
        now = datetime.utcnow()
        document = {
            "doc_type": doc_type,
            "content": content,
            "metadata": metadata,
            "structured_data": structured_data,
            "created_at": now,
            "updated_at": now,
        }
        result = await self.documents.insert_one(document)
        logger.info(f"文档已插入MongoDB: {result.inserted_id}")
        return str(result.inserted_id)

    async def get_document(self, doc_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a document by id; returns None when not found."""
        from bson import ObjectId
        doc = await self.documents.find_one({"_id": ObjectId(doc_id)})
        if doc:
            # ObjectId is not JSON-serializable; expose it as a string.
            doc["_id"] = str(doc["_id"])
        return doc

    async def search_documents(
        self,
        query: str,
        doc_type: Optional[str] = None,
        limit: int = 10,
    ) -> List[Dict[str, Any]]:
        """Keyword-search documents by content.

        Args:
            query: Search keyword (treated as a literal, not a regex).
            doc_type: Optional document-type filter.
            limit: Maximum number of results.

        Returns:
            Matching documents with stringified _id fields.
        """
        import re
        # BUG FIX: escape the keyword so user input cannot inject regex
        # syntax (previously raw input could raise or trigger pathological
        # regex scans).
        filter_query: Dict[str, Any] = {"content": {"$regex": re.escape(query)}}
        if doc_type:
            filter_query["doc_type"] = doc_type
        cursor = self.documents.find(filter_query).limit(limit)
        documents = []
        async for doc in cursor:
            doc["_id"] = str(doc["_id"])
            documents.append(doc)
        return documents

    async def delete_document(self, doc_id: str) -> bool:
        """Delete a document by id; returns True when one was removed."""
        from bson import ObjectId
        result = await self.documents.delete_one({"_id": ObjectId(doc_id)})
        return result.deleted_count > 0

    # ==================== RAG index operations ====================

    async def insert_rag_entry(
        self,
        table_name: str,
        field_name: str,
        field_description: str,
        embedding: List[float],
        metadata: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Insert a RAG index entry.

        Args:
            table_name: Owning table name.
            field_name: Field name.
            field_description: Field description / semantics.
            embedding: Embedding vector for the description.
            metadata: Optional extra metadata.

        Returns:
            The inserted entry's id as a string.
        """
        entry = {
            "table_name": table_name,
            "field_name": field_name,
            "field_description": field_description,
            "embedding": embedding,
            "metadata": metadata or {},
            "created_at": datetime.utcnow(),
        }
        result = await self.rag_index.insert_one(entry)
        return str(result.inserted_id)

    async def search_rag(
        self,
        query_embedding: List[float],
        top_k: int = 5,
        table_name: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Search the RAG index by vector similarity.

        Uses a brute-force aggregation that computes the squared Euclidean
        distance between each stored embedding and the query vector, then
        keeps the top_k closest entries.  (MongoDB Atlas offers native
        vector search; this fallback works on any server version.)

        Args:
            query_embedding: Query vector.
            top_k: Number of entries to return.
            table_name: Optional table-name filter.

        Returns:
            The closest entries, each with a computed "distance" field.
        """
        pipeline = [
            {
                "$addFields": {
                    # Squared Euclidean distance: sum_i (a_i - b_i)^2.
                    # Assumes stored embeddings have the same length as
                    # query_embedding — TODO confirm at ingest time.
                    "distance": {
                        "$reduce": {
                            "input": {"$range": [0, {"$size": "$embedding"}]},
                            "initialValue": 0,
                            "in": {
                                "$add": [
                                    "$$value",
                                    {
                                        "$pow": [
                                            {
                                                "$subtract": [
                                                    {"$arrayElemAt": ["$embedding", "$$this"]},
                                                    {"$arrayElemAt": [query_embedding, "$$this"]},
                                                ]
                                            },
                                            2,
                                        ]
                                    },
                                ]
                            },
                        }
                    }
                }
            },
            {"$sort": {"distance": 1}},
            {"$limit": top_k},
        ]
        if table_name:
            # Filter first so distances are only computed for one table.
            pipeline.insert(0, {"$match": {"table_name": table_name}})
        results = []
        async for doc in self.rag_index.aggregate(pipeline):
            doc["_id"] = str(doc["_id"])
            results.append(doc)
        return results

    # ==================== Collection management ====================

    async def create_indexes(self):
        """Create indexes to speed up common queries."""
        # Document collection indexes.
        await self.documents.create_index("doc_type")
        await self.documents.create_index("created_at")
        await self.documents.create_index([("content", "text")])
        # RAG index collection indexes.
        await self.rag_index.create_index("table_name")
        await self.rag_index.create_index("field_name")
        # BUG FIX: the previous 3-tuple spec ("embedding", "hnsw",
        # {"type": "knnVector"}) is not a valid create_index() key
        # specification (pymongo expects (field, direction) pairs) and
        # raised at runtime.  A knnVector/HNSW index must be created as an
        # Atlas Search index (create_search_index), so it is not attempted
        # here; search_rag() works without it via brute-force aggregation.
        logger.info("MongoDB 索引创建完成")


# ==================== Global singleton ====================
mongodb = MongoDB()

View File

@@ -0,0 +1,193 @@
"""
MySQL 数据库连接管理模块
提供结构化数据的存储和查询功能
"""
import logging
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, List, Optional

from sqlalchemy import (
    Column,
    DateTime,
    Enum as SQLEnum,
    Float,
    Integer,
    String,
    Text,
    create_engine,
    text,
)
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, sessionmaker
from sqlalchemy.sql import select

from app.config import settings
logger = logging.getLogger(__name__)
class Base(DeclarativeBase):
    """SQLAlchemy declarative base shared by all ORM models in this app."""
    pass
class MySQLDB:
    """MySQL database manager.

    Owns an async engine/session factory for FastAPI request handling and
    a sync engine/session factory for Celery worker tasks.
    """

    def __init__(self):
        # Async engine for FastAPI coroutines.
        # NOTE(review): create_async_engine requires an async driver in
        # settings.mysql_url (e.g. mysql+aiomysql) — confirm the config.
        self.async_engine = create_async_engine(
            settings.mysql_url,
            echo=settings.DEBUG,  # log emitted SQL when debugging
            pool_pre_ping=True,  # validate pooled connections before use
            pool_size=10,
            max_overflow=20,
        )
        # expire_on_commit=False keeps ORM objects usable after commit.
        self.async_session_factory = async_sessionmaker(
            bind=self.async_engine,
            class_=AsyncSession,
            expire_on_commit=False,
            autocommit=False,
            autoflush=False,
        )
        # Sync engine for Celery tasks.
        # NOTE(review): replacing "mysql+pymysql" with "mysql" selects the
        # default (mysqlclient) dialect — verify that driver is installed.
        self.sync_engine = create_engine(
            settings.mysql_url.replace("mysql+pymysql", "mysql"),
            echo=settings.DEBUG,
            pool_pre_ping=True,
            pool_size=5,
            max_overflow=10,
        )
        self.sync_session_factory = sessionmaker(
            bind=self.sync_engine,
            autocommit=False,
            autoflush=False,
        )

    async def init_db(self):
        """Create all tables registered on Base.metadata.

        Raises:
            Exception: propagated when DDL execution fails.
        """
        try:
            async with self.async_engine.begin() as conn:
                await conn.run_sync(Base.metadata.create_all)
            logger.info("MySQL 数据库表初始化完成")
        except Exception as e:
            logger.error(f"MySQL 数据库初始化失败: {e}")
            raise

    async def close(self):
        """Dispose both engines and release their connection pools."""
        await self.async_engine.dispose()
        self.sync_engine.dispose()
        logger.info("MySQL 数据库连接已关闭")

    @asynccontextmanager
    async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield an async session; commits on success, rolls back on error,
        and always closes the session."""
        session = self.async_session_factory()
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()

    async def execute_query(
        self,
        query: str,
        params: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Execute a raw SELECT and return rows as dictionaries.

        Args:
            query: SQL query text (use :name placeholders for params).
            params: Bound parameter values.

        Returns:
            One dict per result row.
        """
        async with self.get_session() as session:
            # BUG FIX: the original wrapped the statement as
            # select(text(query)), which renders "SELECT (<query>)" instead
            # of executing the query itself (and `text` was not imported).
            result = await session.execute(text(query), params or {})
            rows = result.fetchall()
            return [dict(row._mapping) for row in rows]

    async def execute_raw_sql(
        self,
        sql: str,
        params: Optional[Dict[str, Any]] = None,
    ) -> Any:
        """Execute a raw INSERT/UPDATE/DELETE statement.

        Args:
            sql: SQL statement text (use :name placeholders for params).
            params: Bound parameter values.

        Returns:
            The last inserted row id when available, else the affected
            row count.
        """
        async with self.get_session() as session:
            result = await session.execute(text(sql), params or {})
            # get_session() commits on normal exit, so the original's
            # explicit commit here was redundant and has been removed.
            return result.lastrowid if result.lastrowid else result.rowcount
# ==================== 预定义的数据模型 ====================
class DocumentTable(Base):
    """Document metadata table: one row per parsed tabular document."""
    __tablename__ = "document_tables"

    id = Column(Integer, primary_key=True, autoincrement=True)
    table_name = Column(String(255), unique=True, nullable=False, comment="表名")  # unique physical table name
    display_name = Column(String(255), comment="显示名称")  # human-readable name
    description = Column(Text, comment="表描述")  # free-text table description
    source_file = Column(String(512), comment="来源文件")  # originating file
    column_count = Column(Integer, default=0, comment="列数")  # number of columns
    row_count = Column(Integer, default=0, comment="行数")  # number of rows
    file_size = Column(Integer, comment="文件大小(字节)")  # size in bytes
    # NOTE(review): no default/onupdate is defined — callers are expected
    # to set these timestamps themselves; confirm at the insert sites.
    created_at = Column(DateTime, comment="创建时间")
    updated_at = Column(DateTime, comment="更新时间")
class DocumentField(Base):
    """Per-table field catalog: one row per column of a DocumentTable."""
    __tablename__ = "document_fields"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # NOTE(review): plain Integer, not a ForeignKey — referential integrity
    # against document_tables.id is not enforced by the database.
    table_id = Column(Integer, nullable=False, comment="所属表ID")
    field_name = Column(String(255), nullable=False, comment="字段名")  # column name
    field_type = Column(String(50), comment="字段类型")  # column type
    field_description = Column(Text, comment="字段描述/语义")  # semantics used for RAG
    is_key_field = Column(Integer, default=0, comment="是否主键")  # 0/1 flag
    is_nullable = Column(Integer, default=1, comment="是否可空")  # 0/1 flag
    sample_values = Column(Text, comment="示例值(逗号分隔)")  # comma-separated samples
    created_at = Column(DateTime, comment="创建时间")  # set by callers; no default
class TaskRecord(Base):
    """Async task bookkeeping: one row per Celery task."""
    __tablename__ = "task_records"

    id = Column(Integer, primary_key=True, autoincrement=True)
    task_id = Column(String(255), unique=True, nullable=False, comment="Celery任务ID")  # unique Celery id
    task_type = Column(String(50), comment="任务类型")  # task category
    # Presumably one of pending/processing/success/failure — confirm
    # against the Redis task-status writer.
    status = Column(String(50), default="pending", comment="任务状态")
    input_params = Column(Text, comment="输入参数JSON")  # JSON-encoded input
    result_data = Column(Text, comment="结果数据JSON")  # JSON-encoded result
    error_message = Column(Text, comment="错误信息")  # failure details
    started_at = Column(DateTime, comment="开始时间")
    completed_at = Column(DateTime, comment="完成时间")
    created_at = Column(DateTime, comment="创建时间")  # set by callers; no default
# ==================== Global singleton ====================
# Module-level instance; the engines are created eagerly at import time.
mysql_db = MySQLDB()

View File

@@ -0,0 +1,287 @@
"""
Redis 数据库连接管理模块
提供缓存和任务队列功能
"""
import json
import logging
from datetime import timedelta
from typing import Any, Optional
import redis.asyncio as redis
from app.config import settings
logger = logging.getLogger(__name__)
class RedisDB:
    """Redis database manager.

    Provides caching, JSON storage, task-status tracking, simple
    distributed locks, and counters on top of a single async client.
    """

    def __init__(self):
        # Populated by connect(); all other methods assume connect() ran.
        self.client: Optional[redis.Redis] = None
        self._connected = False

    async def connect(self):
        """Open the Redis connection and verify it with a ping.

        Raises:
            Exception: propagated from the driver on connection failure.
        """
        try:
            self.client = redis.from_url(
                settings.REDIS_URL,
                encoding="utf-8",
                decode_responses=True,  # return str instead of bytes
            )
            # Ping forces a round trip so failures surface here.
            await self.client.ping()
            self._connected = True
            logger.info(f"Redis 连接成功: {settings.REDIS_URL}")
        except Exception as e:
            logger.error(f"Redis 连接失败: {e}")
            raise

    async def close(self):
        """Close the Redis connection (no-op when not connected)."""
        if self.client:
            await self.client.close()
            # Drop the handle so use-after-close fails fast.
            self.client = None
            self._connected = False
            logger.info("Redis 连接已关闭")

    @property
    def is_connected(self) -> bool:
        """True after connect() succeeds and before close() is called."""
        return self._connected

    # ==================== Basic operations ====================

    async def get(self, key: str) -> Optional[str]:
        """Return the string value of key, or None when it does not exist."""
        return await self.client.get(key)

    async def set(
        self,
        key: str,
        value: str,
        expire: Optional[int] = None,
    ) -> bool:
        """Set a string value.

        Args:
            key: Key.
            value: Value.
            expire: Optional TTL in seconds.

        Returns:
            True on success.
        """
        return await self.client.set(key, value, ex=expire)

    async def delete(self, key: str) -> int:
        """Delete a key; returns the number of keys removed."""
        return await self.client.delete(key)

    async def exists(self, key: str) -> bool:
        """Return True when the key exists."""
        return await self.client.exists(key) > 0

    # ==================== JSON operations ====================

    async def set_json(
        self,
        key: str,
        data: dict[str, Any],
        expire: Optional[int] = None,
    ) -> bool:
        """Store a dict as a JSON string.

        BUG FIX: annotations previously used typing.Dict, which this
        module never imports, so the class body raised NameError at import
        time; the builtin ``dict`` generic (3.9+) needs no import.

        Args:
            key: Key.
            data: Dict to serialize (non-serializable values fall back to
                str() via default=str).
            expire: Optional TTL in seconds.

        Returns:
            True on success.
        """
        json_str = json.dumps(data, ensure_ascii=False, default=str)
        return await self.set(key, json_str, expire)

    async def get_json(self, key: str) -> Optional[dict[str, Any]]:
        """Load a JSON value; returns None when missing or unparseable."""
        value = await self.get(key)
        if value:
            try:
                return json.loads(value)
            except json.JSONDecodeError:
                return None
        return None

    # ==================== Task status management ====================

    async def set_task_status(
        self,
        task_id: str,
        status: str,
        meta: Optional[dict[str, Any]] = None,
        expire: int = 86400,  # default: 24 hours
    ) -> bool:
        """Store a task's status under ``task:{task_id}``.

        Args:
            task_id: Task id.
            status: Status string (pending/processing/success/failure).
            meta: Optional extra info.
            expire: TTL in seconds.

        Returns:
            True on success.
        """
        key = f"task:{task_id}"
        data = {
            "status": status,
            "meta": meta or {},
        }
        return await self.set_json(key, data, expire)

    async def get_task_status(self, task_id: str) -> Optional[dict[str, Any]]:
        """Return a task's stored status dict, or None when unknown."""
        key = f"task:{task_id}"
        return await self.get_json(key)

    async def update_task_progress(
        self,
        task_id: str,
        progress: int,
        message: Optional[str] = None,
    ) -> bool:
        """Update a task's progress (0-100) and optional message.

        Returns:
            True when the task record existed and was rewritten.
        """
        data = await self.get_task_status(task_id)
        if data:
            # setdefault guards against records written without a meta
            # dict (the original KeyError'd on such records).
            meta = data.setdefault("meta", {})
            meta["progress"] = progress
            if message:
                meta["message"] = message
            key = f"task:{task_id}"
            # NOTE: rewriting resets the TTL to 24h on each update.
            return await self.set_json(key, data, expire=86400)
        return False

    # ==================== Cache operations ====================

    async def cache_document(
        self,
        doc_id: str,
        data: dict[str, Any],
        expire: int = 3600,  # default: 1 hour
    ) -> bool:
        """Cache document data under ``doc:{doc_id}``.

        Args:
            doc_id: Document id.
            data: Document payload.
            expire: TTL in seconds.

        Returns:
            True on success.
        """
        key = f"doc:{doc_id}"
        return await self.set_json(key, data, expire)

    async def get_cached_document(self, doc_id: str) -> Optional[dict[str, Any]]:
        """Return a cached document's data, or None on miss."""
        key = f"doc:{doc_id}"
        return await self.get_json(key)

    # ==================== Distributed lock ====================

    async def acquire_lock(
        self,
        lock_name: str,
        expire: int = 30,
    ) -> bool:
        """Try to take a lock; non-blocking.

        Uses an atomic SET NX EX so only one holder can exist; the lock
        self-expires after ``expire`` seconds.

        Returns:
            True when the lock was acquired.
        """
        key = f"lock:{lock_name}"
        result = await self.client.set(key, "1", nx=True, ex=expire)
        return result is not None

    async def release_lock(self, lock_name: str) -> bool:
        """Release a lock.

        NOTE(review): deletes unconditionally — any caller can release a
        lock it does not hold; a token plus check-and-delete script would
        be needed for strict ownership.

        Returns:
            True when a lock key was actually deleted.
        """
        key = f"lock:{lock_name}"
        result = await self.client.delete(key)
        return result > 0

    # ==================== Counters ====================

    async def incr(self, key: str, amount: int = 1) -> int:
        """Increment a counter; returns the new value."""
        return await self.client.incrby(key, amount)

    async def decr(self, key: str, amount: int = 1) -> int:
        """Decrement a counter; returns the new value."""
        return await self.client.decrby(key, amount)

    # ==================== Expiry management ====================

    async def expire(self, key: str, seconds: int) -> bool:
        """Set a key's TTL; returns True when the key exists."""
        return await self.client.expire(key, seconds)

    async def ttl(self, key: str) -> int:
        """Return a key's remaining TTL in seconds (-1 no TTL, -2 missing)."""
        return await self.client.ttl(key)
# ==================== Global singleton ====================
# Module-level instance; the connection itself is opened by connect().
redis_db = RedisDB()