"""
RAG 服务模块 - 检索增强生成

使用 LangChain + Faiss 实现向量检索
"""
import logging
import os
from typing import Any, Dict, List, Optional

from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document as LangchainDocument
from langchain.vectorstores import FAISS

from app.config import settings

logger = logging.getLogger(__name__)

class RAGService:
    """RAG (retrieval-augmented generation) service.

    Indexes table-field descriptions and document contents into a FAISS
    vector store and performs similarity search over them.  Heavyweight
    resources are created lazily: the embedding model on first use, and
    the FAISS store from the first document added (FAISS cannot be
    instantiated empty without a concrete ``faiss`` index object).  An
    existing index can also be loaded from disk via :meth:`load_index`.
    """

    def __init__(self):
        # Both stay None until first use or load_index().
        self.embeddings: Optional[HuggingFaceEmbeddings] = None
        self.vector_store: Optional[FAISS] = None
        self._initialized = False

    def _init_embeddings(self) -> None:
        """Create the HuggingFace embedding model once (idempotent)."""
        if self.embeddings is None:
            self.embeddings = HuggingFaceEmbeddings(
                model_name=settings.EMBEDDING_MODEL,
                model_kwargs={'device': 'cpu'}
            )
            logger.info(f"RAG 嵌入模型初始化完成: {settings.EMBEDDING_MODEL}")

    def _init_vector_store(self) -> None:
        """Prepare the embedding model for vector-store use.

        BUG FIX: the previous code instantiated
        ``FAISS(embedding_function=..., index=None, docstore={}, ...)``,
        which crashes as soon as documents are added — ``index`` must be
        a real ``faiss`` index and ``docstore`` must implement the
        Docstore interface (a plain dict does not).  The store is now
        created lazily in :meth:`_add_document` via
        ``FAISS.from_documents``, which builds a correctly-dimensioned
        index from the first embedded document.
        """
        self._init_embeddings()
        logger.info("Faiss 向量存储初始化完成")

    async def initialize(self):
        """Asynchronously initialize the service (loads the embedding model)."""
        try:
            self._init_vector_store()
            self._initialized = True
            logger.info("RAG 服务初始化成功")
        except Exception as e:
            logger.error(f"RAG 服务初始化失败: {e}")
            raise

    def _add_document(self, doc: LangchainDocument, doc_id: str) -> None:
        """Add one document, creating the FAISS store on first use."""
        self._init_embeddings()
        if self.vector_store is None:
            # from_documents builds the index with the right dimensionality
            # inferred from the first embedded document.
            self.vector_store = FAISS.from_documents(
                [doc], self.embeddings, ids=[doc_id]
            )
        else:
            self.vector_store.add_documents([doc], ids=[doc_id])

    def index_field(
        self,
        table_name: str,
        field_name: str,
        field_description: str,
        sample_values: Optional[List[str]] = None
    ):
        """Index a table field's metadata into the vector store.

        Args:
            table_name: Table the field belongs to.
            field_name: Column name.
            field_description: Semantic description of the field.
            sample_values: Optional example values appended to the text.
        """
        # Build the text that gets embedded.
        text = f"表名: {table_name}, 字段: {field_name}, 描述: {field_description}"
        if sample_values:
            text += f", 示例值: {', '.join(sample_values)}"

        doc_id = f"{table_name}.{field_name}"
        doc = LangchainDocument(
            page_content=text,
            metadata={
                "table_name": table_name,
                "field_name": field_name,
                "doc_id": doc_id
            }
        )
        self._add_document(doc, doc_id)
        logger.debug(f"已索引字段: {doc_id}")

    def index_document_content(
        self,
        doc_id: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Index a free-form document into the vector store.

        Args:
            doc_id: Stable identifier used for later deletion/overwrite.
            content: Text to embed.
            metadata: Optional metadata; defaults to ``{"doc_id": doc_id}``.
        """
        doc = LangchainDocument(
            page_content=content,
            metadata=metadata or {"doc_id": doc_id}
        )
        self._add_document(doc, doc_id)
        logger.debug(f"已索引文档: {doc_id}")

    def retrieve(
        self,
        query: str,
        top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """Return the ``top_k`` documents most similar to ``query``.

        Args:
            query: User query text.
            top_k: Maximum number of results.

        Returns:
            List of dicts with ``content``, ``metadata``, ``score``
            (FAISS distance — smaller is more similar) and ``doc_id``.
            Empty list when nothing has been indexed yet.
        """
        if self.vector_store is None:
            return []

        docs_and_scores = self.vector_store.similarity_search_with_score(
            query,
            k=top_k
        )
        results = [
            {
                "content": doc.page_content,
                "metadata": doc.metadata,
                "score": float(score),  # distance: smaller = more similar
                "doc_id": doc.metadata.get("doc_id", "")
            }
            for doc, score in docs_and_scores
        ]
        logger.debug(f"检索到 {len(results)} 条相关文档")
        return results

    def retrieve_by_table(self, table_name: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Retrieve indexed fields belonging to ``table_name``.

        Args:
            table_name: Table to look up.
            top_k: Maximum number of results.

        Returns:
            Related field entries (same shape as :meth:`retrieve`).
        """
        return self.retrieve(f"表名: {table_name}", top_k)

    def get_vector_count(self) -> int:
        """Return the number of vectors currently indexed."""
        if self.vector_store is None:
            return 0
        # Use the public faiss counter instead of the private docstore._dict.
        return int(self.vector_store.index.ntotal)

    def save_index(self, persist_path: str):
        """Persist the vector index to disk (no-op when nothing is indexed).

        Args:
            persist_path: Directory to save the FAISS index into.
        """
        if self.vector_store is not None:
            self.vector_store.save_local(persist_path)
            logger.info(f"向量索引已保存到: {persist_path}")

    def load_index(self, persist_path: str):
        """Load a previously saved vector index from disk.

        Logs a warning and returns silently when the path does not exist.

        Args:
            persist_path: Directory the index was saved into.
        """
        if not os.path.exists(persist_path):
            logger.warning(f"向量索引文件不存在: {persist_path}")
            return

        self._init_embeddings()
        # NOTE(security): pickle-based deserialization — only load files
        # this application itself produced via save_index().
        self.vector_store = FAISS.load_local(
            persist_path,
            self.embeddings,
            allow_dangerous_deserialization=True
        )
        self._initialized = True
        logger.info(f"向量索引已从 {persist_path} 加载")

    def delete_by_doc_id(self, doc_id: str):
        """Delete a single indexed entry by its document id."""
        if self.vector_store is not None:
            self.vector_store.delete(ids=[doc_id])
            logger.debug(f"已删除索引: {doc_id}")

    def clear(self):
        """Remove every vector from the index."""
        if self.vector_store is not None:
            ids = list(self.vector_store.docstore._dict.keys())
            if ids:  # FAISS.delete raises ValueError on an empty id list
                self.vector_store.delete(ids=ids)
        logger.info("已清空所有向量索引")

# ==================== 全局单例 ====================

# Module-level singleton shared by the application; heavy resources are
# created lazily, so constructing it at import time is cheap.
rag_service = RAGService()