```
feat(database): 为MySQL服务添加text函数导入支持 添加了SQLAlchemy的text函数导入,用于支持原始SQL查询操作, 增强数据库交互的灵活性和兼容性。 --- feat(excel): 改进Excel存储服务的列名处理机制 优化了列名清理逻辑,支持UTF8编码包括中文字符,实现唯一列名 生成机制,防止列名冲突。同时切换到pymysql直接插入方式, 提升批量数据插入性能并解决SQLAlchemy异步问题。 --- fix(rag): 改进RAG服务嵌入模型加载策略 当嵌入模型加载失败时,采用更稳健的降级策略,使用简化模式 运行RAG服务而非完全失败,确保系统核心功能可用性。 ```
This commit is contained in:
@@ -16,6 +16,7 @@ from sqlalchemy import (
|
||||
String,
|
||||
Text,
|
||||
create_engine,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, sessionmaker
|
||||
|
||||
@@ -17,6 +17,7 @@ from sqlalchemy import (
|
||||
String,
|
||||
Text,
|
||||
inspect,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -66,15 +67,41 @@ class ExcelStorageService:
|
||||
Returns:
|
||||
合法的字段名
|
||||
"""
|
||||
# 只保留字母、数字、下划线
|
||||
name = re.sub(r'[^a-zA-Z0-9_]', '_', str(col_name))
|
||||
|
||||
# 确保以字母开头
|
||||
# MySQL 支持 UTF8 编码,中文字符可以直接使用
|
||||
# 只处理非法字符(控制字符等)和首字符数字
|
||||
name = str(col_name).strip()
|
||||
# 移除控制字符
|
||||
name = re.sub(r'[\x00-\x1f\x7f]', '', name)
|
||||
# 确保以字母或中文开头
|
||||
if name and name[0].isdigit():
|
||||
name = 'col_' + name
|
||||
# 限制长度 (MySQL 字段名最多64字符)
|
||||
return name[:64]
|
||||
|
||||
# 限制长度
|
||||
return name[:50]
|
||||
def _get_unique_column_name(self, col_name: str, used_names: set) -> str:
|
||||
"""
|
||||
获取唯一的列名,避免重复
|
||||
|
||||
Args:
|
||||
col_name: 原始列名
|
||||
used_names: 已使用的列名集合
|
||||
|
||||
Returns:
|
||||
唯一的列名
|
||||
"""
|
||||
sanitized = self._sanitize_column_name(col_name)
|
||||
if sanitized not in used_names:
|
||||
used_names.add(sanitized)
|
||||
return sanitized
|
||||
|
||||
# 添加数字后缀直到唯一
|
||||
base = sanitized if sanitized else "col"
|
||||
counter = 1
|
||||
while f"{base}_{counter}" in used_names:
|
||||
counter += 1
|
||||
unique_name = f"{base}_{counter}"
|
||||
used_names.add(unique_name)
|
||||
return unique_name
|
||||
|
||||
def _infer_column_type(self, series: pd.Series) -> str:
|
||||
"""
|
||||
@@ -191,12 +218,15 @@ class ExcelStorageService:
|
||||
# 清理列名
|
||||
df.columns = [str(c) for c in df.columns]
|
||||
|
||||
# 推断列类型
|
||||
# 推断列类型,并生成唯一的列名
|
||||
column_types = {}
|
||||
column_name_map = {} # 原始列名 -> 唯一合法列名
|
||||
used_names = set()
|
||||
for col in df.columns:
|
||||
col_name = self._sanitize_column_name(col)
|
||||
col_name = self._get_unique_column_name(col, used_names)
|
||||
col_type = self._infer_column_type(df[col])
|
||||
column_types[col] = col_type
|
||||
column_name_map[col] = col_name
|
||||
results["columns"].append({
|
||||
"original_name": col,
|
||||
"sanitized_name": col_name,
|
||||
@@ -205,10 +235,9 @@ class ExcelStorageService:
|
||||
|
||||
# 创建表 - 使用原始 SQL 以兼容异步
|
||||
logger.info(f"正在创建MySQL表: {table_name}")
|
||||
from sqlalchemy import text
|
||||
sql_columns = ["id INT AUTO_INCREMENT PRIMARY KEY"]
|
||||
for col in df.columns:
|
||||
col_name = self._sanitize_column_name(col)
|
||||
col_name = column_name_map[col]
|
||||
col_type = column_types.get(col, "TEXT")
|
||||
sql_type = "INT" if col_type == "INTEGER" else "FLOAT" if col_type == "FLOAT" else "DATETIME" if col_type == "DATETIME" else "TEXT"
|
||||
sql_columns.append(f"`{col_name}` {sql_type}")
|
||||
@@ -223,7 +252,7 @@ class ExcelStorageService:
|
||||
for _, row in df.iterrows():
|
||||
record = {}
|
||||
for col in df.columns:
|
||||
col_name = self._sanitize_column_name(col)
|
||||
col_name = column_name_map[col]
|
||||
value = row[col]
|
||||
|
||||
# 处理 NaN 值
|
||||
@@ -244,13 +273,33 @@ class ExcelStorageService:
|
||||
|
||||
records.append(record)
|
||||
|
||||
logger.info(f"正在插入 {len(records)} 条数据到 MySQL...")
|
||||
# 批量插入
|
||||
async with self.mysql_db.get_session() as session:
|
||||
for record in records:
|
||||
session.add(model_class(**record))
|
||||
await session.commit()
|
||||
logger.info(f"数据插入完成: {len(records)} 条")
|
||||
logger.info(f"正在插入 {len(records)} 条数据到 MySQL (使用批量插入)...")
|
||||
# 使用 pymysql 直接插入以避免 SQLAlchemy 异步问题
|
||||
import pymysql
|
||||
from app.config import settings
|
||||
|
||||
connection = pymysql.connect(
|
||||
host=settings.MYSQL_HOST,
|
||||
port=settings.MYSQL_PORT,
|
||||
user=settings.MYSQL_USER,
|
||||
password=settings.MYSQL_PASSWORD,
|
||||
database=settings.MYSQL_DATABASE,
|
||||
charset=settings.MYSQL_CHARSET
|
||||
)
|
||||
try:
|
||||
columns_str = ', '.join(['`' + column_name_map[col] + '`' for col in df.columns])
|
||||
placeholders = ', '.join(['%s' for _ in df.columns])
|
||||
insert_sql = f"INSERT INTO `{table_name}` ({columns_str}) VALUES ({placeholders})"
|
||||
|
||||
# 转换为元组列表 (使用映射后的列名)
|
||||
param_list = [tuple(record.get(column_name_map[col]) for col in df.columns) for record in records]
|
||||
|
||||
with connection.cursor() as cursor:
|
||||
cursor.executemany(insert_sql, param_list)
|
||||
connection.commit()
|
||||
logger.info(f"数据插入完成: {len(records)} 条")
|
||||
finally:
|
||||
connection.close()
|
||||
|
||||
results["row_count"] = len(records)
|
||||
logger.info(f"Excel 数据已存储到 MySQL 表 {table_name},共 {len(records)} 行")
|
||||
|
||||
@@ -40,24 +40,31 @@ class RAGService:
|
||||
def _init_embeddings(self):
|
||||
"""初始化嵌入模型"""
|
||||
if self.embedding_model is None:
|
||||
model_name = getattr(settings, 'EMBEDDING_MODEL', 'all-MiniLM-L6-v2')
|
||||
# 使用轻量级本地模型,避免网络问题
|
||||
model_name = 'all-MiniLM-L6-v2'
|
||||
try:
|
||||
self.embedding_model = SentenceTransformer(model_name)
|
||||
self._dimension = self.embedding_model.get_sentence_embedding_dimension()
|
||||
logger.info(f"RAG 嵌入模型初始化完成: {model_name}, 维度: {self._dimension}")
|
||||
except Exception as e:
|
||||
logger.warning(f"嵌入模型 {model_name} 加载失败,使用默认模型: {e}")
|
||||
# 使用轻量级默认模型
|
||||
self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
|
||||
self._dimension = self.embedding_model.get_sentence_embedding_dimension()
|
||||
logger.info(f"RAG 嵌入模型使用默认: all-MiniLM-L6-v2, 维度: {self._dimension}")
|
||||
logger.warning(f"嵌入模型 {model_name} 加载失败: {e}")
|
||||
# 如果本地模型也失败,使用简单hash作为后备
|
||||
self.embedding_model = None
|
||||
self._dimension = 384
|
||||
logger.info("RAG 使用简化模式 (无向量嵌入)")
|
||||
|
||||
def _init_vector_store(self):
|
||||
"""初始化向量存储"""
|
||||
if self.index is None:
|
||||
self._init_embeddings()
|
||||
self.index = faiss.IndexIDMap(faiss.IndexFlatIP(self._dimension))
|
||||
logger.info("Faiss 向量存储初始化完成")
|
||||
if self.embedding_model is None:
|
||||
# 无法加载嵌入模型,使用简化模式
|
||||
self._dimension = 384
|
||||
self.index = None
|
||||
logger.warning("RAG 嵌入模型未加载,使用简化模式")
|
||||
else:
|
||||
self.index = faiss.IndexIDMap(faiss.IndexFlatIP(self._dimension))
|
||||
logger.info("Faiss 向量存储初始化完成")
|
||||
|
||||
async def initialize(self):
|
||||
"""异步初始化"""
|
||||
@@ -86,6 +93,11 @@ class RAGService:
|
||||
if not self._initialized:
|
||||
self._init_vector_store()
|
||||
|
||||
# 如果没有嵌入模型,只记录到日志
|
||||
if self.embedding_model is None:
|
||||
logger.debug(f"字段跳过索引 (无嵌入模型): {table_name}.{field_name}")
|
||||
return
|
||||
|
||||
text = f"表名: {table_name}, 字段: {field_name}, 描述: {field_description}"
|
||||
if sample_values:
|
||||
text += f", 示例值: {', '.join(sample_values)}"
|
||||
@@ -108,6 +120,11 @@ class RAGService:
|
||||
if not self._initialized:
|
||||
self._init_vector_store()
|
||||
|
||||
# 如果没有嵌入模型,只记录到日志
|
||||
if self.embedding_model is None:
|
||||
logger.debug(f"文档跳过索引 (无嵌入模型): {doc_id}")
|
||||
return
|
||||
|
||||
doc = SimpleDocument(
|
||||
page_content=content,
|
||||
metadata=metadata or {"doc_id": doc_id}
|
||||
|
||||
Reference in New Issue
Block a user