Improve database calls

This commit is contained in:
2026-03-27 00:06:17 +08:00
parent 6b88e971e8
commit 7c88da9ab1
18 changed files with 133 additions and 129 deletions

View File

@@ -45,11 +45,20 @@ class Settings(BaseSettings):
@property
def mysql_url(self) -> str:
"""生成MySQL连接URL"""
"""生成MySQL连接URL (同步)"""
return (
f"mysql+pymysql://{self.MYSQL_USER}:{self.MYSQL_PASSWORD}"
f"@{self.MYSQL_HOST}:{self.MYSQL_PORT}/{self.MYSQL_DATABASE}"
f"?charset={self.MYSQL_CHARSET}"
)
@property
def async_mysql_url(self) -> str:
"""生成MySQL连接URL (异步)"""
return (
f"mysql+aiomysql://{self.MYSQL_USER}:{self.MYSQL_PASSWORD}"
f"@{self.MYSQL_HOST}:{self.MYSQL_PORT}/{self.MYSQL_DATABASE}"
f"?charset={self.MYSQL_CHARSET}"
)
settings = Settings()
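For reference, with hypothetical settings values (user app, password secret, host localhost, port 3306, database demo, charset utf8mb4), the two properties would resolve to:

settings.mysql_url        # mysql+pymysql://app:secret@localhost:3306/demo?charset=utf8mb4
settings.async_mysql_url  # mysql+aiomysql://app:secret@localhost:3306/demo?charset=utf8mb4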

View File

@@ -37,7 +37,7 @@ class MySQLDB:
def __init__(self):
# Async engine (for FastAPI async operations)
self.async_engine = create_async_engine(
settings.mysql_url,
settings.async_mysql_url,
echo=settings.DEBUG,  # SQL logging
pool_pre_ping=True,  # check connections before use
pool_size=10,
@@ -55,7 +55,7 @@ class MySQLDB:
# Sync engine (for synchronous Celery tasks)
self.sync_engine = create_engine(
settings.mysql_url.replace("mysql+pymysql", "mysql"),
settings.mysql_url,
echo=settings.DEBUG,
pool_pre_ping=True,
pool_size=5,
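A minimal usage sketch of the two engines, assuming session factories are layered on top of them (AsyncSessionLocal, SyncSessionLocal and the two callers below are illustrative names, not taken from this diff):

from sqlalchemy import text
from sqlalchemy.ext.asyncio import async_sessionmaker
from sqlalchemy.orm import sessionmaker

db = MySQLDB()
AsyncSessionLocal = async_sessionmaker(db.async_engine, expire_on_commit=False)  # FastAPI path (aiomysql)
SyncSessionLocal = sessionmaker(db.sync_engine)                                  # Celery path (pymysql)

async def fastapi_handler():
    async with AsyncSessionLocal() as session:
        await session.execute(text("SELECT 1"))

def celery_task():
    with SyncSessionLocal() as session:
        session.execute(text("SELECT 1"))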

View File

@@ -6,7 +6,7 @@ Redis database connection management module
import json
import logging
from datetime import timedelta
from typing import Any, Optional
from typing import Any, Dict, Optional
import redis.asyncio as redis
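The newly imported Dict presumably types a JSON helper on this class; a minimal sketch under that assumption (get_json and self.client are illustrative names, not shown in this hunk):

async def get_json(self, key: str) -> Optional[Dict[str, Any]]:
    """Read a key and decode its JSON payload, or return None if the key is missing."""
    raw = await self.client.get(key)
    return json.loads(raw) if raw is not None else None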

View File

@@ -1,48 +1,54 @@
"""
RAG service module - retrieval-augmented generation
Vector retrieval with LangChain + Faiss
Vector retrieval with sentence-transformers + Faiss
"""
import json
import logging
import os
import pickle
from typing import Any, Dict, List, Optional
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document as LangchainDocument
from langchain.vectorstores import FAISS
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from app.config import settings
logger = logging.getLogger(__name__)
class SimpleDocument:
"""简化文档对象"""
def __init__(self, page_content: str, metadata: Dict[str, Any]):
self.page_content = page_content
self.metadata = metadata
class RAGService:
"""RAG 检索增强服务"""
def __init__(self):
self.embeddings: Optional[HuggingFaceEmbeddings] = None
self.vector_store: Optional[FAISS] = None
self.embedding_model: Optional[SentenceTransformer] = None
self.index: Optional[faiss.Index] = None
self.documents: List[Dict[str, Any]] = []
self.doc_ids: List[str] = []
self._dimension: int = 0
self._initialized = False
self._persist_dir = settings.FAISS_INDEX_DIR
def _init_embeddings(self):
"""初始化嵌入模型"""
if self.embeddings is None:
self.embeddings = HuggingFaceEmbeddings(
model_name=settings.EMBEDDING_MODEL,
model_kwargs={'device': 'cpu'}
)
logger.info(f"RAG 嵌入模型初始化完成: {settings.EMBEDDING_MODEL}")
if self.embedding_model is None:
self.embedding_model = SentenceTransformer(settings.EMBEDDING_MODEL)
self._dimension = self.embedding_model.get_sentence_embedding_dimension()
logger.info(f"RAG 嵌入模型初始化完成: {settings.EMBEDDING_MODEL}, 维度: {self._dimension}")
def _init_vector_store(self):
"""初始化向量存储"""
if self.vector_store is None:
if self.index is None:
self._init_embeddings()
self.vector_store = FAISS(
embedding_function=self.embeddings,
index=None,  # create an empty index
docstore={},
index_to_docstore_id={}
)
self.index = faiss.IndexIDMap(faiss.IndexFlatIP(self._dimension))
logger.info("Faiss 向量存储初始化完成")
async def initialize(self):
@@ -55,6 +61,12 @@ class RAGService:
logger.error(f"RAG 服务初始化失败: {e}")
raise
def _normalize_vectors(self, vectors: np.ndarray) -> np.ndarray:
"""归一化向量"""
norms = np.linalg.norm(vectors, axis=1, keepdims=True)
norms = np.where(norms == 0, 1, norms)
return vectors / norms
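Because the vectors are L2-normalized before being added to an IndexFlatIP index, the inner product returned by Faiss equals cosine similarity, so higher scores mean closer matches. A quick check of that identity:

import numpy as np

a = np.array([[3.0, 4.0]])
b = np.array([[6.0, 8.0]])
a_n = a / np.linalg.norm(a)
b_n = b / np.linalg.norm(b)
print(float(a_n @ b_n.T))  # 1.0: parallel vectors have maximal cosine similarity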
def index_field(
self,
table_name: str,
@@ -62,39 +74,20 @@ class RAGService:
field_description: str,
sample_values: Optional[List[str]] = None
):
"""
Index field information into the vector store
Args:
table_name: table name
field_name: field name
field_description: semantic description of the field
sample_values: sample values
"""
"""Index field information into the vector store"""
if not self._initialized:
self._init_vector_store()
# Build the full text to embed
text = f"表名: {table_name}, 字段: {field_name}, 描述: {field_description}"
if sample_values:
text += f", 示例值: {', '.join(sample_values)}"
# Create the document
doc_id = f"{table_name}.{field_name}"
doc = LangchainDocument(
doc = SimpleDocument(
page_content=text,
metadata={
"table_name": table_name,
"field_name": field_name,
"doc_id": doc_id
}
metadata={"table_name": table_name, "field_name": field_name, "doc_id": doc_id}
)
# Add to the vector store
if self.vector_store is None:
self._init_vector_store()
self.vector_store.add_documents([doc], ids=[doc_id])
self._add_documents([doc], [doc_id])
logger.debug(f"已索引字段: {doc_id}")
def index_document_content(
@@ -103,131 +96,134 @@ class RAGService:
content: str,
metadata: Optional[Dict[str, Any]] = None
):
"""
Index document content into the vector store
Args:
doc_id: document ID
content: document content
metadata: metadata
"""
"""Index document content into the vector store"""
if not self._initialized:
self._init_vector_store()
doc = LangchainDocument(
doc = SimpleDocument(
page_content=content,
metadata=metadata or {"doc_id": doc_id}
)
if self.vector_store is None:
self._init_vector_store()
self.vector_store.add_documents([doc], ids=[doc_id])
self._add_documents([doc], [doc_id])
logger.debug(f"已索引文档: {doc_id}")
def retrieve(
self,
query: str,
top_k: int = 5
) -> List[Dict[str, Any]]:
"""
Retrieve documents relevant to the query
def _add_documents(self, documents: List[SimpleDocument], doc_ids: List[str]):
"""批量添加文档到向量索引"""
if not documents:
return
Args:
query: user query
top_k: number of results to return
texts = [doc.page_content for doc in documents]
embeddings = self.embedding_model.encode(texts, convert_to_numpy=True)
embeddings = self._normalize_vectors(embeddings).astype('float32')
Returns:
List of relevant documents
"""
if self.index is None:
self._init_vector_store()
# Use positional int64 IDs so the labels returned by Faiss map straight back to self.documents
start_id = len(self.doc_ids)
id_array = np.arange(start_id, start_id + len(documents), dtype='int64')
self.index.add_with_ids(embeddings, id_array)
for doc, did in zip(documents, doc_ids):
self.documents.append({"id": did, "content": doc.page_content, "metadata": doc.metadata})
self.doc_ids.append(did)
def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
"""根据查询检索相关文档"""
if not self._initialized:
self._init_vector_store()
if self.vector_store is None:
if self.index is None or self.index.ntotal == 0:
return []
# Run the similarity search
docs_and_scores = self.vector_store.similarity_search_with_score(
query,
k=top_k
)
query_embedding = self.embedding_model.encode([query], convert_to_numpy=True)
query_embedding = self._normalize_vectors(query_embedding).astype('float32')
scores, indices = self.index.search(query_embedding, min(top_k, self.index.ntotal))
results = []
for doc, score in docs_and_scores:
for score, idx in zip(scores[0], indices[0]):
if idx < 0:
continue
doc = self.documents[idx]
results.append({
"content": doc.page_content,
"metadata": doc.metadata,
"score": float(score), # 距离分数,越小越相似
"doc_id": doc.metadata.get("doc_id", "")
"content": doc["content"],
"metadata": doc["metadata"],
"score": float(score),
"doc_id": doc["id"]
})
logger.debug(f"检索到 {len(results)} 条相关文档")
return results
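A minimal end-to-end sketch of the rewritten API (the table, field, description, and query below are illustrative):

rag = RAGService()
rag.index_field("orders", "amount", "Total order amount", sample_values=["99.00", "1200.50"])
hits = rag.retrieve("which field stores the order amount", top_k=3)
for hit in hits:
    print(hit["doc_id"], hit["score"])  # e.g. orders.amount with its cosine-style score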
def retrieve_by_table(self, table_name: str, top_k: int = 5) -> List[Dict[str, Any]]:
"""
Retrieve fields of a given table
Args:
table_name: table name
top_k: number of results to return
Returns:
List of relevant fields
"""
"""Retrieve fields of a given table"""
return self.retrieve(f"表名: {table_name}", top_k)
def get_vector_count(self) -> int:
"""获取向量总数"""
if self.vector_store is None:
if self.index is None:
return 0
return len(self.vector_store.docstore._dict)
return self.index.ntotal
def save_index(self, persist_path: str):
"""
Save the vector index to disk
def save_index(self, persist_path: Optional[str] = None):
"""Save the vector index to disk"""
if persist_path is None:
persist_path = self._persist_dir
Args:
persist_path: save path
"""
if self.vector_store is not None:
self.vector_store.save_local(persist_path)
if self.index is not None:
os.makedirs(persist_path, exist_ok=True)
faiss.write_index(self.index, os.path.join(persist_path, "index.faiss"))
with open(os.path.join(persist_path, "documents.pkl"), "wb") as f:
pickle.dump(self.documents, f)
logger.info(f"向量索引已保存到: {persist_path}")
def load_index(self, persist_path: str):
"""
Load the vector index from disk
def load_index(self, persist_path: Optional[str] = None):
"""Load the vector index from disk"""
if persist_path is None:
persist_path = self._persist_dir
Args:
persist_path: save path
"""
if not os.path.exists(persist_path):
logger.warning(f"向量索引文件不存在: {persist_path}")
index_file = os.path.join(persist_path, "index.faiss")
docs_file = os.path.join(persist_path, "documents.pkl")
if not os.path.exists(index_file):
logger.warning(f"向量索引文件不存在: {index_file}")
return
self._init_embeddings()
self.vector_store = FAISS.load_local(
persist_path,
self.embeddings,
allow_dangerous_deserialization=True
)
self.index = faiss.read_index(index_file)
with open(docs_file, "rb") as f:
self.documents = pickle.load(f)
self.doc_ids = [d["id"] for d in self.documents]
self._initialized = True
logger.info(f"向量索引已从 {persist_path} 加载")
logger.info(f"向量索引已从 {persist_path} 加载,共 {len(self.documents)}")
def delete_by_doc_id(self, doc_id: str):
"""根据文档ID删除索引"""
if self.vector_store is not None:
self.vector_store.delete(ids=[doc_id])
if self.index is not None:
remaining = [d for d in self.documents if d["id"] != doc_id]
self.documents = remaining
self.doc_ids = [d["id"] for d in self.documents]
self.index.reset()
if self.documents:
texts = [d["content"] for d in self.documents]
embeddings = self.embedding_model.encode(texts, convert_to_numpy=True)
embeddings = self._normalize_vectors(embeddings).astype('float32')
id_array = np.arange(len(self.doc_ids), dtype='int64')  # positional IDs, consistent with _add_documents
self.index.add_with_ids(embeddings, id_array)
logger.debug(f"已删除索引: {doc_id}")
def clear(self):
"""清空所有索引"""
self._init_vector_store()
if self.vector_store is not None:
self.vector_store.delete(ids=list(self.vector_store.docstore._dict.keys()))
if self.index is not None:
self.index.reset()
self.documents = []
self.doc_ids = []
logger.info("已清空所有向量索引")
# ==================== Global singleton ====================
rag_service = RAGService()

View File

@@ -15,6 +15,7 @@ python-dotenv==1.0.0
# ==================== Database - MySQL (structured data) ====================
pymysql==1.1.0
aiomysql==0.2.0
sqlalchemy==2.0.25
# ==================== Database - MongoDB (unstructured data) ====================
@@ -29,7 +30,7 @@ celery==5.3.4
# ==================== RAG / vector database ====================
# chromadb==0.4.22  # needs a C++ build toolchain on Windows; use a prebuilt wheel or WSL if required
sentence-transformers==2.2.2
sentence-transformers==2.7.0
faiss-cpu==1.8.0
# ==================== Document parsing ====================
@@ -40,8 +41,6 @@ markdown-it-py==3.0.0
chardet==5.2.0
# ==================== AI / LLM ====================
langchain==0.1.0
langchain-community==0.0.10
httpx==0.25.2
# ==================== Data processing & visualization ====================

View File

@@ -30,7 +30,7 @@ import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@
import { Checkbox } from '@/components/ui/checkbox';
import { toast } from 'sonner';
import { cn } from '@/lib/utils';
import { backendApi, type ExcelParseResult, type ExcelUploadOptions, aiApi, analysisChartsApi } from '@/db/backend-api';
import { backendApi, type ExcelParseResult, type ExcelUploadOptions, aiApi } from '@/db/backend-api';
import {
Table as TableComponent,
TableBody,
@@ -179,7 +179,7 @@ const ExcelParse: React.FC = () => {
setAnalysisCharts(null);
try {
const result = await analysisChartsApi.extractAndGenerateCharts({
const result = await aiApi.extractAndGenerateCharts({
analysis_text: analysisText,
original_filename: uploadedFile?.name || 'unknown',
file_type: 'excel'