""" MySQL 数据库连接管理模块 提供结构化数据的存储和查询功能 """ import logging from contextlib import asynccontextmanager from typing import Any, AsyncGenerator, Dict, List, Optional from sqlalchemy import ( Column, DateTime, Enum as SQLEnum, Float, Integer, String, Text, create_engine, ) from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.orm import DeclarativeBase, sessionmaker from sqlalchemy.sql import select from app.config import settings logger = logging.getLogger(__name__) class Base(DeclarativeBase): """SQLAlchemy 声明基类""" pass class MySQLDB: """MySQL 数据库管理类""" def __init__(self): # 异步引擎 (用于 FastAPI 异步操作) self.async_engine = create_async_engine( settings.async_mysql_url, echo=settings.DEBUG, # SQL 日志 pool_pre_ping=True, # 连接前检测 pool_size=10, max_overflow=20, ) # 异步会话工厂 self.async_session_factory = async_sessionmaker( bind=self.async_engine, class_=AsyncSession, expire_on_commit=False, autocommit=False, autoflush=False, ) # 同步引擎 (用于 Celery 同步任务) self.sync_engine = create_engine( settings.mysql_url, echo=settings.DEBUG, pool_pre_ping=True, pool_size=5, max_overflow=10, ) # 同步会话工厂 self.sync_session_factory = sessionmaker( bind=self.sync_engine, autocommit=False, autoflush=False, ) async def init_db(self): """初始化数据库,创建所有表""" try: # 先创建数据库(如果不存在) from sqlalchemy import text db_name = settings.MYSQL_DATABASE # 连接时不指定数据库来创建数据库 temp_url = ( f"mysql+aiomysql://{settings.MYSQL_USER}:{settings.MYSQL_PASSWORD}" f"@{settings.MYSQL_HOST}:{settings.MYSQL_PORT}/" f"?charset={settings.MYSQL_CHARSET}" ) from sqlalchemy.ext.asyncio import create_async_engine temp_engine = create_async_engine(temp_url, echo=False) try: async with temp_engine.connect() as conn: await conn.execute(text(f"CREATE DATABASE IF NOT EXISTS `{db_name}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci")) await conn.commit() logger.info(f"MySQL 数据库 {db_name} 创建或已存在") finally: await temp_engine.dispose() # 然后创建表 async with self.async_engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) logger.info("MySQL 数据库表初始化完成") except Exception as e: logger.error(f"MySQL 数据库初始化失败: {e}") raise async def close(self): """关闭数据库连接""" await self.async_engine.dispose() self.sync_engine.dispose() logger.info("MySQL 数据库连接已关闭") @asynccontextmanager async def get_session(self) -> AsyncGenerator[AsyncSession, None]: """获取异步数据库会话""" session = self.async_session_factory() try: yield session await session.commit() except Exception: await session.rollback() raise finally: await session.close() async def execute_query( self, query: str, params: Optional[Dict[str, Any]] = None ) -> List[Dict[str, Any]]: """ 执行原始 SQL 查询 Args: query: SQL 查询语句 params: 查询参数 Returns: 查询结果列表 """ async with self.get_session() as session: result = await session.execute(select(text(query)), params or {}) rows = result.fetchall() return [dict(row._mapping) for row in rows] async def execute_raw_sql( self, sql: str, params: Optional[Dict[str, Any]] = None ) -> Any: """ 执行原始 SQL 语句 (INSERT/UPDATE/DELETE) Args: sql: SQL 语句 params: 语句参数 Returns: 执行结果 """ async with self.get_session() as session: result = await session.execute(text(sql), params or {}) await session.commit() return result.lastrowid if result.lastrowid else result.rowcount # ==================== 预定义的数据模型 ==================== class DocumentTable(Base): """文档元数据表 - 存储已解析文档的基本信息""" __tablename__ = "document_tables" id = Column(Integer, primary_key=True, autoincrement=True) table_name = Column(String(255), unique=True, nullable=False, comment="表名") display_name = Column(String(255), comment="显示名称") description = Column(Text, comment="表描述") source_file = Column(String(512), comment="来源文件") column_count = Column(Integer, default=0, comment="列数") row_count = Column(Integer, default=0, comment="行数") file_size = Column(Integer, comment="文件大小(字节)") created_at = Column(DateTime, comment="创建时间") updated_at = Column(DateTime, comment="更新时间") class DocumentField(Base): """文档字段表 - 存储每个表的字段信息""" __tablename__ = "document_fields" id = Column(Integer, primary_key=True, autoincrement=True) table_id = Column(Integer, nullable=False, comment="所属表ID") field_name = Column(String(255), nullable=False, comment="字段名") field_type = Column(String(50), comment="字段类型") field_description = Column(Text, comment="字段描述/语义") is_key_field = Column(Integer, default=0, comment="是否主键") is_nullable = Column(Integer, default=1, comment="是否可空") sample_values = Column(Text, comment="示例值(逗号分隔)") created_at = Column(DateTime, comment="创建时间") class TaskRecord(Base): """任务记录表 - 存储异步任务信息""" __tablename__ = "task_records" id = Column(Integer, primary_key=True, autoincrement=True) task_id = Column(String(255), unique=True, nullable=False, comment="Celery任务ID") task_type = Column(String(50), comment="任务类型") status = Column(String(50), default="pending", comment="任务状态") input_params = Column(Text, comment="输入参数JSON") result_data = Column(Text, comment="结果数据JSON") error_message = Column(Text, comment="错误信息") started_at = Column(DateTime, comment="开始时间") completed_at = Column(DateTime, comment="完成时间") created_at = Column(DateTime, comment="创建时间") # ==================== 全局单例 ==================== mysql_db = MySQLDB()