Compare commits
2 Commits
f2af27245d
...
8e713be1ca
| Author | SHA1 | Date | |
|---|---|---|---|
| 8e713be1ca | |||
| a9dc0d8b91 |
@@ -448,18 +448,22 @@ async def process_documents_batch(task_id: str, files: List[dict]):
|
|||||||
|
|
||||||
|
|
||||||
async def index_document_to_rag(doc_id: str, filename: str, result: ParseResult, doc_type: str):
|
async def index_document_to_rag(doc_id: str, filename: str, result: ParseResult, doc_type: str):
|
||||||
"""将非结构化文档索引到 RAG"""
|
"""将非结构化文档索引到 RAG(使用分块索引)"""
|
||||||
try:
|
try:
|
||||||
content = result.data.get("content", "")
|
content = result.data.get("content", "")
|
||||||
if content:
|
if content:
|
||||||
|
# 将完整内容传递给 RAG 服务自动分块索引
|
||||||
rag_service.index_document_content(
|
rag_service.index_document_content(
|
||||||
doc_id=doc_id,
|
doc_id=doc_id,
|
||||||
content=content[:5000],
|
content=content, # 传递完整内容,由 RAG 服务自动分块
|
||||||
metadata={
|
metadata={
|
||||||
"filename": filename,
|
"filename": filename,
|
||||||
"doc_type": doc_type
|
"doc_type": doc_type
|
||||||
}
|
},
|
||||||
|
chunk_size=500, # 每块 500 字符
|
||||||
|
chunk_overlap=50 # 块之间 50 字符重叠
|
||||||
)
|
)
|
||||||
|
logger.info(f"RAG 索引完成: {filename}, doc_id={doc_id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"RAG 索引失败: {str(e)}")
|
logger.warning(f"RAG 索引失败: {str(e)}")
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ RAG 服务模块 - 检索增强生成
|
|||||||
|
|
||||||
使用 sentence-transformers + Faiss 实现向量检索
|
使用 sentence-transformers + Faiss 实现向量检索
|
||||||
"""
|
"""
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
@@ -11,12 +10,20 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
import faiss
|
import faiss
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 尝试导入 sentence-transformers
|
||||||
|
try:
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
SENTENCE_TRANSFORMERS_AVAILABLE = True
|
||||||
|
except ImportError as e:
|
||||||
|
logger.warning(f"sentence-transformers 导入失败: {e}")
|
||||||
|
SENTENCE_TRANSFORMERS_AVAILABLE = False
|
||||||
|
SentenceTransformer = None
|
||||||
|
|
||||||
|
|
||||||
class SimpleDocument:
|
class SimpleDocument:
|
||||||
"""简化文档对象"""
|
"""简化文档对象"""
|
||||||
@@ -28,17 +35,24 @@ class SimpleDocument:
|
|||||||
class RAGService:
|
class RAGService:
|
||||||
"""RAG 检索增强服务"""
|
"""RAG 检索增强服务"""
|
||||||
|
|
||||||
|
# 默认分块参数
|
||||||
|
DEFAULT_CHUNK_SIZE = 500 # 每个文本块的大小(字符数)
|
||||||
|
DEFAULT_CHUNK_OVERLAP = 50 # 块之间的重叠(字符数)
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.embedding_model: Optional[SentenceTransformer] = None
|
self.embedding_model = None
|
||||||
self.index: Optional[faiss.Index] = None
|
self.index: Optional[faiss.Index] = None
|
||||||
self.documents: List[Dict[str, Any]] = []
|
self.documents: List[Dict[str, Any]] = []
|
||||||
self.doc_ids: List[str] = []
|
self.doc_ids: List[str] = []
|
||||||
self._dimension: int = 0
|
self._dimension: int = 384 # 默认维度
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
self._persist_dir = settings.FAISS_INDEX_DIR
|
self._persist_dir = settings.FAISS_INDEX_DIR
|
||||||
# 临时禁用 RAG API 调用,仅记录日志
|
# 检查是否可用
|
||||||
self._disabled = True
|
self._disabled = not SENTENCE_TRANSFORMERS_AVAILABLE
|
||||||
logger.info("RAG 服务已禁用(_disabled=True),仅记录索引操作日志")
|
if self._disabled:
|
||||||
|
logger.warning("RAG 服务已禁用(sentence-transformers 不可用),将使用关键词匹配作为后备")
|
||||||
|
else:
|
||||||
|
logger.info("RAG 服务已启用")
|
||||||
|
|
||||||
def _init_embeddings(self):
|
def _init_embeddings(self):
|
||||||
"""初始化嵌入模型"""
|
"""初始化嵌入模型"""
|
||||||
@@ -88,6 +102,63 @@ class RAGService:
|
|||||||
norms = np.where(norms == 0, 1, norms)
|
norms = np.where(norms == 0, 1, norms)
|
||||||
return vectors / norms
|
return vectors / norms
|
||||||
|
|
||||||
|
def _split_into_chunks(self, text: str, chunk_size: int = None, overlap: int = None) -> List[str]:
|
||||||
|
"""
|
||||||
|
将长文本分割成块
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 待分割的文本
|
||||||
|
chunk_size: 每个块的大小(字符数)
|
||||||
|
overlap: 块之间的重叠字符数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
文本块列表
|
||||||
|
"""
|
||||||
|
if chunk_size is None:
|
||||||
|
chunk_size = self.DEFAULT_CHUNK_SIZE
|
||||||
|
if overlap is None:
|
||||||
|
overlap = self.DEFAULT_CHUNK_OVERLAP
|
||||||
|
|
||||||
|
if len(text) <= chunk_size:
|
||||||
|
return [text] if text.strip() else []
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
start = 0
|
||||||
|
text_len = len(text)
|
||||||
|
|
||||||
|
while start < text_len:
|
||||||
|
# 计算当前块的结束位置
|
||||||
|
end = start + chunk_size
|
||||||
|
|
||||||
|
# 如果不是最后一块,尝试在句子边界处切割
|
||||||
|
if end < text_len:
|
||||||
|
# 向前查找最后一个句号、逗号、换行或分号
|
||||||
|
cut_positions = []
|
||||||
|
for i in range(end, max(start, end - 100), -1):
|
||||||
|
if text[i] in '。;,,\n、':
|
||||||
|
cut_positions.append(i + 1)
|
||||||
|
break
|
||||||
|
|
||||||
|
if cut_positions:
|
||||||
|
end = cut_positions[0]
|
||||||
|
else:
|
||||||
|
# 如果没找到句子边界,尝试向后查找
|
||||||
|
for i in range(end, min(text_len, end + 50)):
|
||||||
|
if text[i] in '。;,,\n、':
|
||||||
|
end = i + 1
|
||||||
|
break
|
||||||
|
|
||||||
|
chunk = text[start:end].strip()
|
||||||
|
if chunk:
|
||||||
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
# 移动起始位置(考虑重叠)
|
||||||
|
start = end - overlap
|
||||||
|
if start <= 0:
|
||||||
|
start = end
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
def index_field(
|
def index_field(
|
||||||
self,
|
self,
|
||||||
table_name: str,
|
table_name: str,
|
||||||
@@ -124,9 +195,20 @@ class RAGService:
|
|||||||
self,
|
self,
|
||||||
doc_id: str,
|
doc_id: str,
|
||||||
content: str,
|
content: str,
|
||||||
metadata: Optional[Dict[str, Any]] = None
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
chunk_size: int = None,
|
||||||
|
chunk_overlap: int = None
|
||||||
):
|
):
|
||||||
"""将文档内容索引到向量数据库"""
|
"""
|
||||||
|
将文档内容索引到向量数据库(自动分块)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc_id: 文档唯一标识
|
||||||
|
content: 文档内容
|
||||||
|
metadata: 文档元数据
|
||||||
|
chunk_size: 文本块大小(字符数),默认500
|
||||||
|
chunk_overlap: 块之间的重叠字符数,默认50
|
||||||
|
"""
|
||||||
if self._disabled:
|
if self._disabled:
|
||||||
logger.info(f"[RAG DISABLED] 文档索引操作已跳过: {doc_id}")
|
logger.info(f"[RAG DISABLED] 文档索引操作已跳过: {doc_id}")
|
||||||
return
|
return
|
||||||
@@ -139,18 +221,56 @@ class RAGService:
|
|||||||
logger.debug(f"文档跳过索引 (无嵌入模型): {doc_id}")
|
logger.debug(f"文档跳过索引 (无嵌入模型): {doc_id}")
|
||||||
return
|
return
|
||||||
|
|
||||||
doc = SimpleDocument(
|
# 分割文档为小块
|
||||||
page_content=content,
|
if chunk_size is None:
|
||||||
metadata=metadata or {"doc_id": doc_id}
|
chunk_size = self.DEFAULT_CHUNK_SIZE
|
||||||
)
|
if chunk_overlap is None:
|
||||||
self._add_documents([doc], [doc_id])
|
chunk_overlap = self.DEFAULT_CHUNK_OVERLAP
|
||||||
logger.debug(f"已索引文档: {doc_id}")
|
|
||||||
|
chunks = self._split_into_chunks(content, chunk_size, chunk_overlap)
|
||||||
|
|
||||||
|
if not chunks:
|
||||||
|
logger.warning(f"文档内容为空,跳过索引: {doc_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 为每个块创建文档对象
|
||||||
|
documents = []
|
||||||
|
chunk_ids = []
|
||||||
|
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
chunk_id = f"{doc_id}_chunk_{i}"
|
||||||
|
chunk_metadata = metadata.copy() if metadata else {}
|
||||||
|
chunk_metadata.update({
|
||||||
|
"chunk_index": i,
|
||||||
|
"total_chunks": len(chunks),
|
||||||
|
"doc_id": doc_id
|
||||||
|
})
|
||||||
|
|
||||||
|
documents.append(SimpleDocument(
|
||||||
|
page_content=chunk,
|
||||||
|
metadata=chunk_metadata
|
||||||
|
))
|
||||||
|
chunk_ids.append(chunk_id)
|
||||||
|
|
||||||
|
# 批量添加文档
|
||||||
|
self._add_documents(documents, chunk_ids)
|
||||||
|
logger.info(f"已索引文档 {doc_id},共 {len(chunks)} 个块")
|
||||||
|
|
||||||
def _add_documents(self, documents: List[SimpleDocument], doc_ids: List[str]):
|
def _add_documents(self, documents: List[SimpleDocument], doc_ids: List[str]):
|
||||||
"""批量添加文档到向量索引"""
|
"""批量添加文档到向量索引"""
|
||||||
if not documents:
|
if not documents:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# 总是将文档存储在内存中(用于关键词搜索后备)
|
||||||
|
for doc, did in zip(documents, doc_ids):
|
||||||
|
self.documents.append({"id": did, "content": doc.page_content, "metadata": doc.metadata})
|
||||||
|
self.doc_ids.append(did)
|
||||||
|
|
||||||
|
# 如果没有嵌入模型,跳过向量索引
|
||||||
|
if self.embedding_model is None:
|
||||||
|
logger.debug(f"文档跳过向量索引 (无嵌入模型): {len(documents)} 个文档")
|
||||||
|
return
|
||||||
|
|
||||||
texts = [doc.page_content for doc in documents]
|
texts = [doc.page_content for doc in documents]
|
||||||
embeddings = self.embedding_model.encode(texts, convert_to_numpy=True)
|
embeddings = self.embedding_model.encode(texts, convert_to_numpy=True)
|
||||||
embeddings = self._normalize_vectors(embeddings).astype('float32')
|
embeddings = self._normalize_vectors(embeddings).astype('float32')
|
||||||
@@ -162,12 +282,18 @@ class RAGService:
|
|||||||
id_array = np.array(id_list, dtype='int64')
|
id_array = np.array(id_list, dtype='int64')
|
||||||
self.index.add_with_ids(embeddings, id_array)
|
self.index.add_with_ids(embeddings, id_array)
|
||||||
|
|
||||||
for doc, did in zip(documents, doc_ids):
|
def retrieve(self, query: str, top_k: int = 5, min_score: float = 0.3) -> List[Dict[str, Any]]:
|
||||||
self.documents.append({"id": did, "content": doc.page_content, "metadata": doc.metadata})
|
"""
|
||||||
self.doc_ids.append(did)
|
根据查询检索相关文档块
|
||||||
|
|
||||||
def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
|
Args:
|
||||||
"""根据查询检索相关文档"""
|
query: 查询文本
|
||||||
|
top_k: 返回的最大结果数
|
||||||
|
min_score: 最低相似度分数阈值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
相关文档块列表,每项包含 content, metadata, score, doc_id, chunk_index
|
||||||
|
"""
|
||||||
if self._disabled:
|
if self._disabled:
|
||||||
logger.info(f"[RAG DISABLED] 检索操作已跳过: query={query}, top_k={top_k}")
|
logger.info(f"[RAG DISABLED] 检索操作已跳过: query={query}, top_k={top_k}")
|
||||||
return []
|
return []
|
||||||
@@ -175,28 +301,113 @@ class RAGService:
|
|||||||
if not self._initialized:
|
if not self._initialized:
|
||||||
self._init_vector_store()
|
self._init_vector_store()
|
||||||
|
|
||||||
if self.index is None or self.index.ntotal == 0:
|
# 优先使用向量检索
|
||||||
|
if self.index is not None and self.index.ntotal > 0 and self.embedding_model is not None:
|
||||||
|
try:
|
||||||
|
query_embedding = self.embedding_model.encode([query], convert_to_numpy=True)
|
||||||
|
query_embedding = self._normalize_vectors(query_embedding).astype('float32')
|
||||||
|
|
||||||
|
scores, indices = self.index.search(query_embedding, min(top_k, self.index.ntotal))
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for score, idx in zip(scores[0], indices[0]):
|
||||||
|
if idx < 0:
|
||||||
|
continue
|
||||||
|
if score < min_score:
|
||||||
|
continue
|
||||||
|
doc = self.documents[idx]
|
||||||
|
results.append({
|
||||||
|
"content": doc["content"],
|
||||||
|
"metadata": doc["metadata"],
|
||||||
|
"score": float(score),
|
||||||
|
"doc_id": doc["id"],
|
||||||
|
"chunk_index": doc["metadata"].get("chunk_index", 0)
|
||||||
|
})
|
||||||
|
|
||||||
|
if results:
|
||||||
|
logger.debug(f"向量检索到 {len(results)} 条相关文档块")
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"向量检索失败,使用关键词搜索后备: {e}")
|
||||||
|
|
||||||
|
# 后备:使用关键词搜索
|
||||||
|
logger.debug("使用关键词搜索后备方案")
|
||||||
|
return self._keyword_search(query, top_k)
|
||||||
|
|
||||||
|
def _keyword_search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
关键词搜索后备方案
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 查询文本
|
||||||
|
top_k: 返回的最大结果数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
相关文档块列表
|
||||||
|
"""
|
||||||
|
if not self.documents:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
query_embedding = self.embedding_model.encode([query], convert_to_numpy=True)
|
# 提取查询关键词
|
||||||
query_embedding = self._normalize_vectors(query_embedding).astype('float32')
|
keywords = []
|
||||||
|
for char in query:
|
||||||
|
if '\u4e00' <= char <= '\u9fff': # 中文字符
|
||||||
|
keywords.append(char)
|
||||||
|
# 添加英文单词
|
||||||
|
import re
|
||||||
|
english_words = re.findall(r'[a-zA-Z]+', query)
|
||||||
|
keywords.extend(english_words)
|
||||||
|
|
||||||
scores, indices = self.index.search(query_embedding, min(top_k, self.index.ntotal))
|
if not keywords:
|
||||||
|
return []
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for score, idx in zip(scores[0], indices[0]):
|
for doc in self.documents:
|
||||||
if idx < 0:
|
content = doc["content"]
|
||||||
continue
|
# 计算关键词匹配分数
|
||||||
doc = self.documents[idx]
|
score = 0
|
||||||
results.append({
|
matched_keywords = 0
|
||||||
"content": doc["content"],
|
for kw in keywords:
|
||||||
"metadata": doc["metadata"],
|
if kw in content:
|
||||||
"score": float(score),
|
score += 1
|
||||||
"doc_id": doc["id"]
|
matched_keywords += 1
|
||||||
})
|
|
||||||
|
|
||||||
logger.debug(f"检索到 {len(results)} 条相关文档")
|
if matched_keywords > 0:
|
||||||
return results
|
# 归一化分数
|
||||||
|
score = score / max(len(keywords), 1)
|
||||||
|
results.append({
|
||||||
|
"content": content,
|
||||||
|
"metadata": doc["metadata"],
|
||||||
|
"score": score,
|
||||||
|
"doc_id": doc["id"],
|
||||||
|
"chunk_index": doc["metadata"].get("chunk_index", 0)
|
||||||
|
})
|
||||||
|
|
||||||
|
# 按分数排序
|
||||||
|
results.sort(key=lambda x: x["score"], reverse=True)
|
||||||
|
|
||||||
|
logger.debug(f"关键词搜索返回 {len(results[:top_k])} 条结果")
|
||||||
|
return results[:top_k]
|
||||||
|
|
||||||
|
def retrieve_by_doc_id(self, doc_id: str, top_k: int = 10) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
获取指定文档的所有块
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc_id: 文档ID
|
||||||
|
top_k: 返回的最大结果数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
该文档的所有块
|
||||||
|
"""
|
||||||
|
# 获取属于该文档的所有块
|
||||||
|
doc_chunks = [d for d in self.documents if d["metadata"].get("doc_id") == doc_id]
|
||||||
|
|
||||||
|
# 按 chunk_index 排序
|
||||||
|
doc_chunks.sort(key=lambda x: x["metadata"].get("chunk_index", 0))
|
||||||
|
|
||||||
|
# 返回指定数量
|
||||||
|
return doc_chunks[:top_k]
|
||||||
|
|
||||||
def retrieve_by_table(self, table_name: str, top_k: int = 5) -> List[Dict[str, Any]]:
|
def retrieve_by_table(self, table_name: str, top_k: int = 5) -> List[Dict[str, Any]]:
|
||||||
"""检索指定表的字段"""
|
"""检索指定表的字段"""
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -766,6 +766,7 @@ const Documents: React.FC = () => {
|
|||||||
<div
|
<div
|
||||||
{...getRootProps()}
|
{...getRootProps()}
|
||||||
className="flex items-center justify-center gap-2 p-3 border-2 border-dashed rounded-lg cursor-pointer hover:border-primary/50 hover:bg-primary/5 transition-colors"
|
className="flex items-center justify-center gap-2 p-3 border-2 border-dashed rounded-lg cursor-pointer hover:border-primary/50 hover:bg-primary/5 transition-colors"
|
||||||
|
onClick={(e) => e.stopPropagation()}
|
||||||
>
|
>
|
||||||
<input {...getInputProps()} multiple={true} />
|
<input {...getInputProps()} multiple={true} />
|
||||||
<Plus size={16} className="text-muted-foreground" />
|
<Plus size={16} className="text-muted-foreground" />
|
||||||
|
|||||||
@@ -641,6 +641,16 @@ const TemplateFill: React.FC = () => {
|
|||||||
<div className="text-muted-foreground text-xs mt-1">
|
<div className="text-muted-foreground text-xs mt-1">
|
||||||
来源: {detail.source} | 置信度: {detail.confidence ? (detail.confidence * 100).toFixed(0) + '%' : 'N/A'}
|
来源: {detail.source} | 置信度: {detail.confidence ? (detail.confidence * 100).toFixed(0) + '%' : 'N/A'}
|
||||||
</div>
|
</div>
|
||||||
|
{detail.warning && (
|
||||||
|
<div className="mt-2 p-2 bg-yellow-50 border border-yellow-200 rounded-lg text-yellow-700 text-xs">
|
||||||
|
⚠️ {detail.warning}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
{detail.values && detail.values.length > 1 && !detail.warning && (
|
||||||
|
<div className="mt-2 text-xs text-muted-foreground">
|
||||||
|
多值: {detail.values.join(', ')}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
))}
|
))}
|
||||||
|
|||||||
Reference in New Issue
Block a user