feat(database): 为MySQL服务添加text函数导入支持

添加了SQLAlchemy的text函数导入,用于支持原始SQL查询操作,增强数据库交互的灵活性和兼容性。

---

feat(excel): 改进Excel存储服务的列名处理机制

优化了列名清理逻辑,支持包括中文字符在内的UTF-8编码,实现唯一列名生成机制,防止列名冲突。同时切换到pymysql直接插入方式,提升批量数据插入性能并解决SQLAlchemy异步问题。

---

fix(rag): 改进RAG服务嵌入模型加载策略

当嵌入模型加载失败时,采用更稳健的降级策略,以简化模式运行RAG服务而非完全失败,确保系统核心功能可用。
215 lines
7.1 KiB
Python
215 lines
7.1 KiB
Python
"""
|
|
MySQL 数据库连接管理模块
|
|
|
|
提供结构化数据的存储和查询功能
|
|
"""
|
|
import logging
|
|
from contextlib import asynccontextmanager
|
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
|
|
|
from sqlalchemy import (
|
|
Column,
|
|
DateTime,
|
|
Enum as SQLEnum,
|
|
Float,
|
|
Integer,
|
|
String,
|
|
Text,
|
|
create_engine,
|
|
text,
|
|
)
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
from sqlalchemy.orm import DeclarativeBase, sessionmaker
|
|
from sqlalchemy.sql import select
|
|
|
|
from app.config import settings
|
|
|
|
# Logger scoped to this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class Base(DeclarativeBase):
    """Declarative base class for the ORM models in this module.

    Subclasses register their table definitions on ``Base.metadata``.
    """
|
|
|
|
|
|
class MySQLDB:
    """MySQL database manager.

    Owns two SQLAlchemy engines built from ``settings``:

    * an async engine (``settings.async_mysql_url``) plus an ``AsyncSession``
      factory, for FastAPI's async code paths;
    * a sync engine (``settings.mysql_url``) plus a ``Session`` factory, for
      Celery's synchronous tasks.
    """

    def __init__(self):
        # Async engine (for FastAPI async operations)
        self.async_engine = create_async_engine(
            settings.async_mysql_url,
            echo=settings.DEBUG,  # log emitted SQL when DEBUG is enabled
            pool_pre_ping=True,  # test pooled connections before handing them out
            pool_size=10,
            max_overflow=20,
        )

        # Async session factory. expire_on_commit=False keeps loaded ORM
        # attributes readable after the session has committed.
        self.async_session_factory = async_sessionmaker(
            bind=self.async_engine,
            class_=AsyncSession,
            expire_on_commit=False,
            autocommit=False,
            autoflush=False,
        )

        # Sync engine (for Celery synchronous tasks)
        self.sync_engine = create_engine(
            settings.mysql_url,
            echo=settings.DEBUG,
            pool_pre_ping=True,
            pool_size=5,
            max_overflow=10,
        )

        # Sync session factory
        self.sync_session_factory = sessionmaker(
            bind=self.sync_engine,
            autocommit=False,
            autoflush=False,
        )

    async def init_db(self):
        """Create the configured database if missing, then create all tables.

        Step 1 connects to the MySQL server without selecting a database and
        runs ``CREATE DATABASE IF NOT EXISTS`` for ``settings.MYSQL_DATABASE``
        (utf8mb4 / utf8mb4_unicode_ci). Step 2 runs
        ``Base.metadata.create_all`` on the main async engine.

        Raises:
            Exception: any failure in either step is logged and re-raised.
        """
        try:
            # Local import: only this method needs it.
            from sqlalchemy.engine import URL

            db_name = settings.MYSQL_DATABASE

            # Server-level URL (no database selected). URL.create escapes the
            # credentials, so passwords containing '@', ':', '/' or '%' no
            # longer corrupt the URL as the previous f-string version did.
            charset = settings.MYSQL_CHARSET
            server_url = URL.create(
                drivername="mysql+aiomysql",
                username=settings.MYSQL_USER,
                password=settings.MYSQL_PASSWORD,
                host=settings.MYSQL_HOST,
                port=settings.MYSQL_PORT,
                query={"charset": charset} if charset else {},
            )

            # Inside a backtick-quoted MySQL identifier, a literal backtick
            # must be doubled.
            quoted_db_name = "`" + str(db_name).replace("`", "``") + "`"

            temp_engine = create_async_engine(server_url, echo=False)
            try:
                async with temp_engine.connect() as conn:
                    await conn.execute(
                        text(
                            f"CREATE DATABASE IF NOT EXISTS {quoted_db_name} "
                            "CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
                        )
                    )
                    await conn.commit()
                logger.info(f"MySQL 数据库 {db_name} 创建或已存在")
            finally:
                # Always release the throwaway server-level engine's pool.
                await temp_engine.dispose()

            # Then create the tables registered on Base.metadata.
            async with self.async_engine.begin() as conn:
                await conn.run_sync(Base.metadata.create_all)
            logger.info("MySQL 数据库表初始化完成")
        except Exception as e:
            logger.error(f"MySQL 数据库初始化失败: {e}")
            raise

    async def close(self):
        """Dispose of the connection pools of both engines."""
        await self.async_engine.dispose()
        self.sync_engine.dispose()
        logger.info("MySQL 数据库连接已关闭")

    @asynccontextmanager
    async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield an async session as a unit of work.

        Commits when the ``async with`` body finishes normally. On any
        exception, rolls back and re-raises. Always closes the session.
        """
        session = self.async_session_factory()
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()

    async def execute_query(
        self,
        query: str,
        params: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Run a raw SQL query and return every row as a dict.

        Args:
            query: SQL text; bound parameters use ``:name`` placeholders.
            params: values for the bound parameters, or None for none.

        Returns:
            One ``{column_name: value}`` dict per result row.
        """
        async with self.get_session() as session:
            # Execute the textual statement as-is. The previous
            # select(text(query)) rendered "SELECT <query>" (for example
            # "SELECT SELECT * FROM t"), which is invalid SQL.
            result = await session.execute(text(query), params or {})
            return [dict(row._mapping) for row in result.fetchall()]

    async def execute_raw_sql(
        self,
        sql: str,
        params: Optional[Dict[str, Any]] = None,
    ) -> Any:
        """Execute a raw SQL write statement (INSERT/UPDATE/DELETE).

        The statement is committed when ``get_session`` exits cleanly and
        rolled back if an exception is raised.

        Args:
            sql: SQL text; bound parameters use ``:name`` placeholders.
            params: values for the bound parameters, or None for none.

        Returns:
            ``result.lastrowid`` if it is truthy, otherwise ``result.rowcount``.
        """
        async with self.get_session() as session:
            result = await session.execute(text(sql), params or {})
            # No explicit commit: get_session() commits on clean exit, so the
            # previous extra commit was redundant.
            return result.lastrowid or result.rowcount
|
|
|
|
|
|
# ==================== Predefined data models ====================


class DocumentTable(Base):
    """Document metadata table: basic information about each parsed document."""

    __tablename__ = "document_tables"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Unique, required; presumably the name of the table holding this
    # document's data — TODO confirm against callers.
    table_name = Column(String(255), unique=True, nullable=False, comment="表名")
    display_name = Column(String(255), comment="显示名称")
    description = Column(Text, comment="表描述")
    source_file = Column(String(512), comment="来源文件")
    column_count = Column(Integer, default=0, comment="列数")
    row_count = Column(Integer, default=0, comment="行数")
    # Size in bytes. NOTE(review): SQLAlchemy's Integer renders as a signed
    # 32-bit INT on MySQL, so files above ~2 GiB would overflow; consider
    # BigInteger if such files are possible.
    file_size = Column(Integer, comment="文件大小(字节)")
    # No default/onupdate is declared, so callers must populate these timestamps.
    created_at = Column(DateTime, comment="创建时间")
    updated_at = Column(DateTime, comment="更新时间")
|
|
|
|
|
|
class DocumentField(Base):
    """Document field table: per-field (column) information for each stored table."""

    __tablename__ = "document_fields"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # ID of the owning table (presumably DocumentTable.id). No ForeignKey is
    # declared, so the database does not enforce this relationship.
    table_id = Column(Integer, nullable=False, comment="所属表ID")
    field_name = Column(String(255), nullable=False, comment="字段名")
    field_type = Column(String(50), comment="字段类型")
    field_description = Column(Text, comment="字段描述/语义")
    # Integer columns used as yes/no flags (defaults 0 and 1 respectively).
    is_key_field = Column(Integer, default=0, comment="是否主键")
    is_nullable = Column(Integer, default=1, comment="是否可空")
    # Example values stored as a single comma-separated string.
    sample_values = Column(Text, comment="示例值(逗号分隔)")
    # No default is declared, so callers must populate this timestamp.
    created_at = Column(DateTime, comment="创建时间")
|
|
|
|
|
|
class TaskRecord(Base):
    """Task record table: stores information about asynchronous tasks."""

    __tablename__ = "task_records"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Celery task ID; unique and required.
    task_id = Column(String(255), unique=True, nullable=False, comment="Celery任务ID")
    task_type = Column(String(50), comment="任务类型")
    # Defaults to "pending"; the other status values are not defined here.
    status = Column(String(50), default="pending", comment="任务状态")
    # Input parameters and results stored as JSON text.
    input_params = Column(Text, comment="输入参数JSON")
    result_data = Column(Text, comment="结果数据JSON")
    error_message = Column(Text, comment="错误信息")
    # No defaults are declared, so callers must populate these timestamps.
    started_at = Column(DateTime, comment="开始时间")
    completed_at = Column(DateTime, comment="完成时间")
    created_at = Column(DateTime, comment="创建时间")
|
|
|
|
|
|
# ==================== Global singleton ====================


# Module-level shared instance. Creating it at import time builds both the
# async and sync engines (and their session factories) from `settings`.
mysql_db = MySQLDB()
|