完成后端数据库连接配置
This commit is contained in:
233
backend/app/services/rag_service.py
Normal file
233
backend/app/services/rag_service.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
RAG 服务模块 - 检索增强生成
|
||||
|
||||
使用 LangChain + Faiss 实现向量检索
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
from langchain.schema import Document as LangchainDocument
|
||||
from langchain.vectorstores import FAISS
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RAGService:
    """RAG retrieval service backed by HuggingFace embeddings + FAISS.

    Indexes table-field descriptions and free-form document content as
    vectors, then retrieves the entries most similar to a user query.

    A LangChain FAISS store cannot be constructed empty (it needs a
    concrete faiss index and a real ``Docstore``), so the store is built
    lazily from the first batch of documents via ``FAISS.from_documents``.
    """

    def __init__(self):
        # Both are created lazily: the embedding model on first need,
        # the vector store on the first successful insert.
        self.embeddings: Optional[HuggingFaceEmbeddings] = None
        self.vector_store: Optional[FAISS] = None
        self._initialized = False

    def _init_embeddings(self):
        """Create the embedding model on first use (CPU inference)."""
        if self.embeddings is None:
            self.embeddings = HuggingFaceEmbeddings(
                model_name=settings.EMBEDDING_MODEL,
                model_kwargs={'device': 'cpu'}
            )
            logger.info(f"RAG 嵌入模型初始化完成: {settings.EMBEDDING_MODEL}")

    def _init_vector_store(self):
        """Prepare the service for indexing.

        Fix: the previous implementation instantiated
        ``FAISS(index=None, docstore={}, ...)``, which crashes on the first
        ``add_documents`` call — LangChain's FAISS requires an actual faiss
        index and a ``Docstore`` instance, not ``None`` and a plain dict.
        The store is now created by :meth:`_add_documents` from the first
        inserted batch, so this method only ensures embeddings are ready.
        """
        self._init_embeddings()

    def _add_documents(self, docs: List[LangchainDocument], ids: List[str]):
        """Insert documents, building the FAISS store on the first insert."""
        if self.vector_store is None:
            self._init_embeddings()
            # from_documents builds a correctly-sized index and an
            # InMemoryDocstore in one step.
            self.vector_store = FAISS.from_documents(docs, self.embeddings, ids=ids)
            logger.info("Faiss 向量存储初始化完成")
        else:
            self.vector_store.add_documents(docs, ids=ids)

    async def initialize(self):
        """Asynchronously initialize the service.

        Raises:
            Exception: re-raised after logging if embedding setup fails.
        """
        try:
            self._init_vector_store()
            self._initialized = True
            logger.info("RAG 服务初始化成功")
        except Exception as e:
            logger.error(f"RAG 服务初始化失败: {e}")
            raise

    def index_field(
        self,
        table_name: str,
        field_name: str,
        field_description: str,
        sample_values: Optional[List[str]] = None
    ):
        """Index a table field's semantic description.

        Args:
            table_name: Table the field belongs to.
            field_name: Field (column) name.
            field_description: Human-readable semantic description.
            sample_values: Optional example values appended to the text.
        """
        if not self._initialized:
            self._init_vector_store()

        # Build the searchable text for this field.
        text = f"表名: {table_name}, 字段: {field_name}, 描述: {field_description}"
        if sample_values:
            text += f", 示例值: {', '.join(sample_values)}"

        # "table.field" doubles as the stable vector id.
        doc_id = f"{table_name}.{field_name}"
        doc = LangchainDocument(
            page_content=text,
            metadata={
                "table_name": table_name,
                "field_name": field_name,
                "doc_id": doc_id
            }
        )
        self._add_documents([doc], [doc_id])
        logger.debug(f"已索引字段: {doc_id}")

    def index_document_content(
        self,
        doc_id: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Index arbitrary document content.

        Args:
            doc_id: Unique document id (also used as the vector id).
            content: Document text to embed.
            metadata: Optional metadata; defaults to {"doc_id": doc_id}.
        """
        if not self._initialized:
            self._init_vector_store()

        doc = LangchainDocument(
            page_content=content,
            metadata=metadata or {"doc_id": doc_id}
        )
        self._add_documents([doc], [doc_id])
        logger.debug(f"已索引文档: {doc_id}")

    def retrieve(
        self,
        query: str,
        top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """Retrieve the documents most similar to *query*.

        Args:
            query: User query text.
            top_k: Maximum number of results.

        Returns:
            Dicts with ``content``, ``metadata``, ``score`` (distance —
            smaller is more similar) and ``doc_id``. Empty list if nothing
            has been indexed yet.
        """
        if not self._initialized:
            self._init_vector_store()

        if self.vector_store is None:
            # Nothing indexed yet — no store has been built.
            return []

        docs_and_scores = self.vector_store.similarity_search_with_score(
            query,
            k=top_k
        )
        results = [
            {
                "content": doc.page_content,
                "metadata": doc.metadata,
                "score": float(score),  # distance score: smaller = more similar
                "doc_id": doc.metadata.get("doc_id", "")
            }
            for doc, score in docs_and_scores
        ]
        logger.debug(f"检索到 {len(results)} 条相关文档")
        return results

    def retrieve_by_table(self, table_name: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Retrieve indexed fields for the given table.

        Args:
            table_name: Table name to search for.
            top_k: Maximum number of results.

        Returns:
            Related field entries (same shape as :meth:`retrieve`).
        """
        return self.retrieve(f"表名: {table_name}", top_k)

    def get_vector_count(self) -> int:
        """Return the number of indexed vectors (0 before first insert)."""
        if self.vector_store is None:
            return 0
        # NOTE(review): relies on InMemoryDocstore's private _dict.
        return len(self.vector_store.docstore._dict)

    def save_index(self, persist_path: str):
        """Persist the vector index to disk.

        Args:
            persist_path: Directory to write the index to.
        """
        if self.vector_store is not None:
            self.vector_store.save_local(persist_path)
            logger.info(f"向量索引已保存到: {persist_path}")

    def load_index(self, persist_path: str):
        """Load a previously saved vector index from disk.

        Args:
            persist_path: Directory the index was saved to.
        """
        if not os.path.exists(persist_path):
            logger.warning(f"向量索引文件不存在: {persist_path}")
            return

        self._init_embeddings()
        # SECURITY: allow_dangerous_deserialization unpickles the docstore;
        # only load index files produced by this application itself.
        self.vector_store = FAISS.load_local(
            persist_path,
            self.embeddings,
            allow_dangerous_deserialization=True
        )
        self._initialized = True
        logger.info(f"向量索引已从 {persist_path} 加载")

    def delete_by_doc_id(self, doc_id: str):
        """Delete a single indexed entry by its document id."""
        if self.vector_store is not None:
            self.vector_store.delete(ids=[doc_id])
            logger.debug(f"已删除索引: {doc_id}")

    def clear(self):
        """Drop all indexed vectors; the store is rebuilt on the next insert."""
        self.vector_store = None
        logger.info("已清空所有向量索引")
|
||||
|
||||
|
||||
# ==================== Global singleton ====================

# Module-level shared RAGService instance: import this object instead of
# constructing RAGService directly so all callers share one vector store.
rag_service = RAGService()
|
||||
Reference in New Issue
Block a user