Files
FilesReadSystem/backend/app/services/rag_service.py
dj a9dc0d8b91 优化智能填表功能:提升速度、完善数据提取精度
后端优化 (template_fill_service.py):

1. 速度优化:
   - 使用 asyncio.gather 实现字段并行提取
   - 跳过 AI 审核步骤,减少 LLM 调用次数
   - 新增 _extract_single_field_fast 方法

2. 数据提取优化:
   - 集成 RAG 服务进行智能内容检索
   - 修复 Markdown 表格列匹配跳过空列
   - 修复年份子表头行误识别问题

3. AI 表头生成优化:
   - 精简为 5-7 个代表性字段(原来 8-15 个)
   - 过滤非数据字段(source、备注、说明等)
   - 简化字段名,如"医院数量"而非"医院-公立医院数量"

4. AI 数据提取 prompt 优化:
   - 严格按表头提取,只返回相关数据
   - 每个值必须带标注(年份/地区/分类)
   - 支持多种标注类型:2024年、北京、某省、公立医院、三级医院等
   - 保留原始数值、单位和百分号格式
   - 不返回大段来源说明

5. FillResult 新增 warning 字段:
   - 多值检测提示,如"检测到 2 个值"

前端优化 (TemplateFill.tsx):
- 填写详情显示多值警告(黄色提示框)
- 多值情况下直接显示所有值

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-14 17:14:59 +08:00

490 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RAG 服务模块 - 检索增强生成
使用 sentence-transformers + Faiss 实现向量检索
"""
import logging
import os
import pickle
from typing import Any, Dict, List, Optional
import faiss
import numpy as np
from app.config import settings
logger = logging.getLogger(__name__)
# 尝试导入 sentence-transformers
try:
from sentence_transformers import SentenceTransformer
SENTENCE_TRANSFORMERS_AVAILABLE = True
except ImportError as e:
logger.warning(f"sentence-transformers 导入失败: {e}")
SENTENCE_TRANSFORMERS_AVAILABLE = False
SentenceTransformer = None
class SimpleDocument:
"""简化文档对象"""
def __init__(self, page_content: str, metadata: Dict[str, Any]):
self.page_content = page_content
self.metadata = metadata
class RAGService:
"""RAG 检索增强服务"""
# 默认分块参数
DEFAULT_CHUNK_SIZE = 500 # 每个文本块的大小(字符数)
DEFAULT_CHUNK_OVERLAP = 50 # 块之间的重叠(字符数)
def __init__(self):
self.embedding_model = None
self.index: Optional[faiss.Index] = None
self.documents: List[Dict[str, Any]] = []
self.doc_ids: List[str] = []
self._dimension: int = 384 # 默认维度
self._initialized = False
self._persist_dir = settings.FAISS_INDEX_DIR
# 检查是否可用
self._disabled = not SENTENCE_TRANSFORMERS_AVAILABLE
if self._disabled:
logger.warning("RAG 服务已禁用sentence-transformers 不可用),将使用关键词匹配作为后备")
else:
logger.info("RAG 服务已启用")
def _init_embeddings(self):
"""初始化嵌入模型"""
if self._disabled:
logger.debug("RAG 已禁用,跳过嵌入模型初始化")
return
if self.embedding_model is None:
# 使用轻量级本地模型,避免网络问题
model_name = 'all-MiniLM-L6-v2'
try:
self.embedding_model = SentenceTransformer(model_name)
self._dimension = self.embedding_model.get_sentence_embedding_dimension()
logger.info(f"RAG 嵌入模型初始化完成: {model_name}, 维度: {self._dimension}")
except Exception as e:
logger.warning(f"嵌入模型 {model_name} 加载失败: {e}")
# 如果本地模型也失败使用简单hash作为后备
self.embedding_model = None
self._dimension = 384
logger.info("RAG 使用简化模式 (无向量嵌入)")
def _init_vector_store(self):
"""初始化向量存储"""
if self.index is None:
self._init_embeddings()
if self.embedding_model is None:
# 无法加载嵌入模型,使用简化模式
self._dimension = 384
self.index = None
logger.warning("RAG 嵌入模型未加载,使用简化模式")
else:
self.index = faiss.IndexIDMap(faiss.IndexFlatIP(self._dimension))
logger.info("Faiss 向量存储初始化完成")
async def initialize(self):
"""异步初始化"""
try:
self._init_vector_store()
self._initialized = True
logger.info("RAG 服务初始化成功")
except Exception as e:
logger.error(f"RAG 服务初始化失败: {e}")
raise
def _normalize_vectors(self, vectors: np.ndarray) -> np.ndarray:
"""归一化向量"""
norms = np.linalg.norm(vectors, axis=1, keepdims=True)
norms = np.where(norms == 0, 1, norms)
return vectors / norms
def _split_into_chunks(self, text: str, chunk_size: int = None, overlap: int = None) -> List[str]:
"""
将长文本分割成块
Args:
text: 待分割的文本
chunk_size: 每个块的大小(字符数)
overlap: 块之间的重叠字符数
Returns:
文本块列表
"""
if chunk_size is None:
chunk_size = self.DEFAULT_CHUNK_SIZE
if overlap is None:
overlap = self.DEFAULT_CHUNK_OVERLAP
if len(text) <= chunk_size:
return [text] if text.strip() else []
chunks = []
start = 0
text_len = len(text)
while start < text_len:
# 计算当前块的结束位置
end = start + chunk_size
# 如果不是最后一块,尝试在句子边界处切割
if end < text_len:
# 向前查找最后一个句号、逗号、换行或分号
cut_positions = []
for i in range(end, max(start, end - 100), -1):
if text[i] in '。;,,\n':
cut_positions.append(i + 1)
break
if cut_positions:
end = cut_positions[0]
else:
# 如果没找到句子边界,尝试向后查找
for i in range(end, min(text_len, end + 50)):
if text[i] in '。;,,\n':
end = i + 1
break
chunk = text[start:end].strip()
if chunk:
chunks.append(chunk)
# 移动起始位置(考虑重叠)
start = end - overlap
if start <= 0:
start = end
return chunks
def index_field(
self,
table_name: str,
field_name: str,
field_description: str,
sample_values: Optional[List[str]] = None
):
"""将字段信息索引到向量数据库"""
if self._disabled:
logger.info(f"[RAG DISABLED] 字段索引操作已跳过: {table_name}.{field_name}")
return
if not self._initialized:
self._init_vector_store()
# 如果没有嵌入模型,只记录到日志
if self.embedding_model is None:
logger.debug(f"字段跳过索引 (无嵌入模型): {table_name}.{field_name}")
return
text = f"表名: {table_name}, 字段: {field_name}, 描述: {field_description}"
if sample_values:
text += f", 示例值: {', '.join(sample_values)}"
doc_id = f"{table_name}.{field_name}"
doc = SimpleDocument(
page_content=text,
metadata={"table_name": table_name, "field_name": field_name, "doc_id": doc_id}
)
self._add_documents([doc], [doc_id])
logger.debug(f"已索引字段: {doc_id}")
def index_document_content(
self,
doc_id: str,
content: str,
metadata: Optional[Dict[str, Any]] = None,
chunk_size: int = None,
chunk_overlap: int = None
):
"""
将文档内容索引到向量数据库(自动分块)
Args:
doc_id: 文档唯一标识
content: 文档内容
metadata: 文档元数据
chunk_size: 文本块大小字符数默认500
chunk_overlap: 块之间的重叠字符数默认50
"""
if self._disabled:
logger.info(f"[RAG DISABLED] 文档索引操作已跳过: {doc_id}")
return
if not self._initialized:
self._init_vector_store()
# 如果没有嵌入模型,只记录到日志
if self.embedding_model is None:
logger.debug(f"文档跳过索引 (无嵌入模型): {doc_id}")
return
# 分割文档为小块
if chunk_size is None:
chunk_size = self.DEFAULT_CHUNK_SIZE
if chunk_overlap is None:
chunk_overlap = self.DEFAULT_CHUNK_OVERLAP
chunks = self._split_into_chunks(content, chunk_size, chunk_overlap)
if not chunks:
logger.warning(f"文档内容为空,跳过索引: {doc_id}")
return
# 为每个块创建文档对象
documents = []
chunk_ids = []
for i, chunk in enumerate(chunks):
chunk_id = f"{doc_id}_chunk_{i}"
chunk_metadata = metadata.copy() if metadata else {}
chunk_metadata.update({
"chunk_index": i,
"total_chunks": len(chunks),
"doc_id": doc_id
})
documents.append(SimpleDocument(
page_content=chunk,
metadata=chunk_metadata
))
chunk_ids.append(chunk_id)
# 批量添加文档
self._add_documents(documents, chunk_ids)
logger.info(f"已索引文档 {doc_id},共 {len(chunks)} 个块")
def _add_documents(self, documents: List[SimpleDocument], doc_ids: List[str]):
"""批量添加文档到向量索引"""
if not documents:
return
# 总是将文档存储在内存中(用于关键词搜索后备)
for doc, did in zip(documents, doc_ids):
self.documents.append({"id": did, "content": doc.page_content, "metadata": doc.metadata})
self.doc_ids.append(did)
# 如果没有嵌入模型,跳过向量索引
if self.embedding_model is None:
logger.debug(f"文档跳过向量索引 (无嵌入模型): {len(documents)} 个文档")
return
texts = [doc.page_content for doc in documents]
embeddings = self.embedding_model.encode(texts, convert_to_numpy=True)
embeddings = self._normalize_vectors(embeddings).astype('float32')
if self.index is None:
self._init_vector_store()
id_list = [hash(did) for did in doc_ids]
id_array = np.array(id_list, dtype='int64')
self.index.add_with_ids(embeddings, id_array)
def retrieve(self, query: str, top_k: int = 5, min_score: float = 0.3) -> List[Dict[str, Any]]:
"""
根据查询检索相关文档块
Args:
query: 查询文本
top_k: 返回的最大结果数
min_score: 最低相似度分数阈值
Returns:
相关文档块列表,每项包含 content, metadata, score, doc_id, chunk_index
"""
if self._disabled:
logger.info(f"[RAG DISABLED] 检索操作已跳过: query={query}, top_k={top_k}")
return []
if not self._initialized:
self._init_vector_store()
# 优先使用向量检索
if self.index is not None and self.index.ntotal > 0 and self.embedding_model is not None:
try:
query_embedding = self.embedding_model.encode([query], convert_to_numpy=True)
query_embedding = self._normalize_vectors(query_embedding).astype('float32')
scores, indices = self.index.search(query_embedding, min(top_k, self.index.ntotal))
results = []
for score, idx in zip(scores[0], indices[0]):
if idx < 0:
continue
if score < min_score:
continue
doc = self.documents[idx]
results.append({
"content": doc["content"],
"metadata": doc["metadata"],
"score": float(score),
"doc_id": doc["id"],
"chunk_index": doc["metadata"].get("chunk_index", 0)
})
if results:
logger.debug(f"向量检索到 {len(results)} 条相关文档块")
return results
except Exception as e:
logger.warning(f"向量检索失败,使用关键词搜索后备: {e}")
# 后备:使用关键词搜索
logger.debug("使用关键词搜索后备方案")
return self._keyword_search(query, top_k)
def _keyword_search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
"""
关键词搜索后备方案
Args:
query: 查询文本
top_k: 返回的最大结果数
Returns:
相关文档块列表
"""
if not self.documents:
return []
# 提取查询关键词
keywords = []
for char in query:
if '\u4e00' <= char <= '\u9fff': # 中文字符
keywords.append(char)
# 添加英文单词
import re
english_words = re.findall(r'[a-zA-Z]+', query)
keywords.extend(english_words)
if not keywords:
return []
results = []
for doc in self.documents:
content = doc["content"]
# 计算关键词匹配分数
score = 0
matched_keywords = 0
for kw in keywords:
if kw in content:
score += 1
matched_keywords += 1
if matched_keywords > 0:
# 归一化分数
score = score / max(len(keywords), 1)
results.append({
"content": content,
"metadata": doc["metadata"],
"score": score,
"doc_id": doc["id"],
"chunk_index": doc["metadata"].get("chunk_index", 0)
})
# 按分数排序
results.sort(key=lambda x: x["score"], reverse=True)
logger.debug(f"关键词搜索返回 {len(results[:top_k])} 条结果")
return results[:top_k]
def retrieve_by_doc_id(self, doc_id: str, top_k: int = 10) -> List[Dict[str, Any]]:
"""
获取指定文档的所有块
Args:
doc_id: 文档ID
top_k: 返回的最大结果数
Returns:
该文档的所有块
"""
# 获取属于该文档的所有块
doc_chunks = [d for d in self.documents if d["metadata"].get("doc_id") == doc_id]
# 按 chunk_index 排序
doc_chunks.sort(key=lambda x: x["metadata"].get("chunk_index", 0))
# 返回指定数量
return doc_chunks[:top_k]
def retrieve_by_table(self, table_name: str, top_k: int = 5) -> List[Dict[str, Any]]:
"""检索指定表的字段"""
return self.retrieve(f"表名: {table_name}", top_k)
def get_vector_count(self) -> int:
"""获取向量总数"""
if self._disabled:
logger.info("[RAG DISABLED] get_vector_count 返回 0")
return 0
if self.index is None:
return 0
return self.index.ntotal
def save_index(self, persist_path: str = None):
"""保存向量索引到磁盘"""
if persist_path is None:
persist_path = self._persist_dir
if self.index is not None:
os.makedirs(persist_path, exist_ok=True)
faiss.write_index(self.index, os.path.join(persist_path, "index.faiss"))
with open(os.path.join(persist_path, "documents.pkl"), "wb") as f:
pickle.dump(self.documents, f)
logger.info(f"向量索引已保存到: {persist_path}")
def load_index(self, persist_path: str = None):
"""从磁盘加载向量索引"""
if persist_path is None:
persist_path = self._persist_dir
index_file = os.path.join(persist_path, "index.faiss")
docs_file = os.path.join(persist_path, "documents.pkl")
if not os.path.exists(index_file):
logger.warning(f"向量索引文件不存在: {index_file}")
return
self._init_embeddings()
self.index = faiss.read_index(index_file)
with open(docs_file, "rb") as f:
self.documents = pickle.load(f)
self.doc_ids = [d["id"] for d in self.documents]
self._initialized = True
logger.info(f"向量索引已从 {persist_path} 加载,共 {len(self.documents)}")
def delete_by_doc_id(self, doc_id: str):
"""根据文档ID删除索引"""
if self.index is not None:
remaining = [d for d in self.documents if d["id"] != doc_id]
self.documents = remaining
self.doc_ids = [d["id"] for d in self.documents]
self.index.reset()
if self.documents:
texts = [d["content"] for d in self.documents]
embeddings = self.embedding_model.encode(texts, convert_to_numpy=True)
embeddings = self._normalize_vectors(embeddings).astype('float32')
id_array = np.array([hash(did) for did in self.doc_ids], dtype='int64')
self.index.add_with_ids(embeddings, id_array)
logger.debug(f"已删除索引: {doc_id}")
def clear(self):
"""清空所有索引"""
if self._disabled:
logger.info("[RAG DISABLED] clear 操作已跳过")
return
self._init_vector_store()
if self.index is not None:
self.index.reset()
self.documents = []
self.doc_ids = []
logger.info("已清空所有向量索引")
rag_service = RAGService()