优化智能填表功能:提升速度、完善数据提取精度
后端优化 (template_fill_service.py): 1. 速度优化: - 使用 asyncio.gather 实现字段并行提取 - 跳过 AI 审核步骤,减少 LLM 调用次数 - 新增 _extract_single_field_fast 方法 2. 数据提取优化: - 集成 RAG 服务进行智能内容检索 - 修复 Markdown 表格列匹配跳过空列 - 修复年份子表头行误识别问题 3. AI 表头生成优化: - 精简为 5-7 个代表性字段(原来 8-15 个) - 过滤非数据字段(source、备注、说明等) - 简化字段名,如"医院数量"而非"医院-公立医院数量" 4. AI 数据提取 prompt 优化: - 严格按表头提取,只返回相关数据 - 每个值必须带标注(年份/地区/分类) - 支持多种标注类型:2024年、北京、某省、公立医院、三级医院等 - 保留原始数值、单位和百分号格式 - 不返回大段来源说明 5. FillResult 新增 warning 字段: - 多值检测提示,如"检测到 2 个值" 前端优化 (TemplateFill.tsx): - 填写详情显示多值警告(黄色提示框) - 多值情况下直接显示所有值 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -3,7 +3,6 @@ RAG 服务模块 - 检索增强生成
|
||||
|
||||
使用 sentence-transformers + Faiss 实现向量检索
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
@@ -11,12 +10,20 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 尝试导入 sentence-transformers
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
SENTENCE_TRANSFORMERS_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
logger.warning(f"sentence-transformers 导入失败: {e}")
|
||||
SENTENCE_TRANSFORMERS_AVAILABLE = False
|
||||
SentenceTransformer = None
|
||||
|
||||
|
||||
class SimpleDocument:
|
||||
"""简化文档对象"""
|
||||
@@ -28,17 +35,24 @@ class SimpleDocument:
|
||||
class RAGService:
|
||||
"""RAG 检索增强服务"""
|
||||
|
||||
# 默认分块参数
|
||||
DEFAULT_CHUNK_SIZE = 500 # 每个文本块的大小(字符数)
|
||||
DEFAULT_CHUNK_OVERLAP = 50 # 块之间的重叠(字符数)
|
||||
|
||||
def __init__(self):
|
||||
self.embedding_model: Optional[SentenceTransformer] = None
|
||||
self.embedding_model = None
|
||||
self.index: Optional[faiss.Index] = None
|
||||
self.documents: List[Dict[str, Any]] = []
|
||||
self.doc_ids: List[str] = []
|
||||
self._dimension: int = 0
|
||||
self._dimension: int = 384 # 默认维度
|
||||
self._initialized = False
|
||||
self._persist_dir = settings.FAISS_INDEX_DIR
|
||||
# 临时禁用 RAG API 调用,仅记录日志
|
||||
self._disabled = True
|
||||
logger.info("RAG 服务已禁用(_disabled=True),仅记录索引操作日志")
|
||||
# 检查是否可用
|
||||
self._disabled = not SENTENCE_TRANSFORMERS_AVAILABLE
|
||||
if self._disabled:
|
||||
logger.warning("RAG 服务已禁用(sentence-transformers 不可用),将使用关键词匹配作为后备")
|
||||
else:
|
||||
logger.info("RAG 服务已启用")
|
||||
|
||||
def _init_embeddings(self):
|
||||
"""初始化嵌入模型"""
|
||||
@@ -88,6 +102,63 @@ class RAGService:
|
||||
norms = np.where(norms == 0, 1, norms)
|
||||
return vectors / norms
|
||||
|
||||
def _split_into_chunks(self, text: str, chunk_size: int = None, overlap: int = None) -> List[str]:
|
||||
"""
|
||||
将长文本分割成块
|
||||
|
||||
Args:
|
||||
text: 待分割的文本
|
||||
chunk_size: 每个块的大小(字符数)
|
||||
overlap: 块之间的重叠字符数
|
||||
|
||||
Returns:
|
||||
文本块列表
|
||||
"""
|
||||
if chunk_size is None:
|
||||
chunk_size = self.DEFAULT_CHUNK_SIZE
|
||||
if overlap is None:
|
||||
overlap = self.DEFAULT_CHUNK_OVERLAP
|
||||
|
||||
if len(text) <= chunk_size:
|
||||
return [text] if text.strip() else []
|
||||
|
||||
chunks = []
|
||||
start = 0
|
||||
text_len = len(text)
|
||||
|
||||
while start < text_len:
|
||||
# 计算当前块的结束位置
|
||||
end = start + chunk_size
|
||||
|
||||
# 如果不是最后一块,尝试在句子边界处切割
|
||||
if end < text_len:
|
||||
# 向前查找最后一个句号、逗号、换行或分号
|
||||
cut_positions = []
|
||||
for i in range(end, max(start, end - 100), -1):
|
||||
if text[i] in '。;,,\n、':
|
||||
cut_positions.append(i + 1)
|
||||
break
|
||||
|
||||
if cut_positions:
|
||||
end = cut_positions[0]
|
||||
else:
|
||||
# 如果没找到句子边界,尝试向后查找
|
||||
for i in range(end, min(text_len, end + 50)):
|
||||
if text[i] in '。;,,\n、':
|
||||
end = i + 1
|
||||
break
|
||||
|
||||
chunk = text[start:end].strip()
|
||||
if chunk:
|
||||
chunks.append(chunk)
|
||||
|
||||
# 移动起始位置(考虑重叠)
|
||||
start = end - overlap
|
||||
if start <= 0:
|
||||
start = end
|
||||
|
||||
return chunks
|
||||
|
||||
def index_field(
|
||||
self,
|
||||
table_name: str,
|
||||
@@ -124,9 +195,20 @@ class RAGService:
|
||||
self,
|
||||
doc_id: str,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
chunk_size: int = None,
|
||||
chunk_overlap: int = None
|
||||
):
|
||||
"""将文档内容索引到向量数据库"""
|
||||
"""
|
||||
将文档内容索引到向量数据库(自动分块)
|
||||
|
||||
Args:
|
||||
doc_id: 文档唯一标识
|
||||
content: 文档内容
|
||||
metadata: 文档元数据
|
||||
chunk_size: 文本块大小(字符数),默认500
|
||||
chunk_overlap: 块之间的重叠字符数,默认50
|
||||
"""
|
||||
if self._disabled:
|
||||
logger.info(f"[RAG DISABLED] 文档索引操作已跳过: {doc_id}")
|
||||
return
|
||||
@@ -139,18 +221,56 @@ class RAGService:
|
||||
logger.debug(f"文档跳过索引 (无嵌入模型): {doc_id}")
|
||||
return
|
||||
|
||||
doc = SimpleDocument(
|
||||
page_content=content,
|
||||
metadata=metadata or {"doc_id": doc_id}
|
||||
)
|
||||
self._add_documents([doc], [doc_id])
|
||||
logger.debug(f"已索引文档: {doc_id}")
|
||||
# 分割文档为小块
|
||||
if chunk_size is None:
|
||||
chunk_size = self.DEFAULT_CHUNK_SIZE
|
||||
if chunk_overlap is None:
|
||||
chunk_overlap = self.DEFAULT_CHUNK_OVERLAP
|
||||
|
||||
chunks = self._split_into_chunks(content, chunk_size, chunk_overlap)
|
||||
|
||||
if not chunks:
|
||||
logger.warning(f"文档内容为空,跳过索引: {doc_id}")
|
||||
return
|
||||
|
||||
# 为每个块创建文档对象
|
||||
documents = []
|
||||
chunk_ids = []
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
chunk_id = f"{doc_id}_chunk_{i}"
|
||||
chunk_metadata = metadata.copy() if metadata else {}
|
||||
chunk_metadata.update({
|
||||
"chunk_index": i,
|
||||
"total_chunks": len(chunks),
|
||||
"doc_id": doc_id
|
||||
})
|
||||
|
||||
documents.append(SimpleDocument(
|
||||
page_content=chunk,
|
||||
metadata=chunk_metadata
|
||||
))
|
||||
chunk_ids.append(chunk_id)
|
||||
|
||||
# 批量添加文档
|
||||
self._add_documents(documents, chunk_ids)
|
||||
logger.info(f"已索引文档 {doc_id},共 {len(chunks)} 个块")
|
||||
|
||||
def _add_documents(self, documents: List[SimpleDocument], doc_ids: List[str]):
|
||||
"""批量添加文档到向量索引"""
|
||||
if not documents:
|
||||
return
|
||||
|
||||
# 总是将文档存储在内存中(用于关键词搜索后备)
|
||||
for doc, did in zip(documents, doc_ids):
|
||||
self.documents.append({"id": did, "content": doc.page_content, "metadata": doc.metadata})
|
||||
self.doc_ids.append(did)
|
||||
|
||||
# 如果没有嵌入模型,跳过向量索引
|
||||
if self.embedding_model is None:
|
||||
logger.debug(f"文档跳过向量索引 (无嵌入模型): {len(documents)} 个文档")
|
||||
return
|
||||
|
||||
texts = [doc.page_content for doc in documents]
|
||||
embeddings = self.embedding_model.encode(texts, convert_to_numpy=True)
|
||||
embeddings = self._normalize_vectors(embeddings).astype('float32')
|
||||
@@ -162,12 +282,18 @@ class RAGService:
|
||||
id_array = np.array(id_list, dtype='int64')
|
||||
self.index.add_with_ids(embeddings, id_array)
|
||||
|
||||
for doc, did in zip(documents, doc_ids):
|
||||
self.documents.append({"id": did, "content": doc.page_content, "metadata": doc.metadata})
|
||||
self.doc_ids.append(did)
|
||||
def retrieve(self, query: str, top_k: int = 5, min_score: float = 0.3) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
根据查询检索相关文档块
|
||||
|
||||
def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
|
||||
"""根据查询检索相关文档"""
|
||||
Args:
|
||||
query: 查询文本
|
||||
top_k: 返回的最大结果数
|
||||
min_score: 最低相似度分数阈值
|
||||
|
||||
Returns:
|
||||
相关文档块列表,每项包含 content, metadata, score, doc_id, chunk_index
|
||||
"""
|
||||
if self._disabled:
|
||||
logger.info(f"[RAG DISABLED] 检索操作已跳过: query={query}, top_k={top_k}")
|
||||
return []
|
||||
@@ -175,28 +301,113 @@ class RAGService:
|
||||
if not self._initialized:
|
||||
self._init_vector_store()
|
||||
|
||||
if self.index is None or self.index.ntotal == 0:
|
||||
# 优先使用向量检索
|
||||
if self.index is not None and self.index.ntotal > 0 and self.embedding_model is not None:
|
||||
try:
|
||||
query_embedding = self.embedding_model.encode([query], convert_to_numpy=True)
|
||||
query_embedding = self._normalize_vectors(query_embedding).astype('float32')
|
||||
|
||||
scores, indices = self.index.search(query_embedding, min(top_k, self.index.ntotal))
|
||||
|
||||
results = []
|
||||
for score, idx in zip(scores[0], indices[0]):
|
||||
if idx < 0:
|
||||
continue
|
||||
if score < min_score:
|
||||
continue
|
||||
doc = self.documents[idx]
|
||||
results.append({
|
||||
"content": doc["content"],
|
||||
"metadata": doc["metadata"],
|
||||
"score": float(score),
|
||||
"doc_id": doc["id"],
|
||||
"chunk_index": doc["metadata"].get("chunk_index", 0)
|
||||
})
|
||||
|
||||
if results:
|
||||
logger.debug(f"向量检索到 {len(results)} 条相关文档块")
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.warning(f"向量检索失败,使用关键词搜索后备: {e}")
|
||||
|
||||
# 后备:使用关键词搜索
|
||||
logger.debug("使用关键词搜索后备方案")
|
||||
return self._keyword_search(query, top_k)
|
||||
|
||||
def _keyword_search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
关键词搜索后备方案
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
top_k: 返回的最大结果数
|
||||
|
||||
Returns:
|
||||
相关文档块列表
|
||||
"""
|
||||
if not self.documents:
|
||||
return []
|
||||
|
||||
query_embedding = self.embedding_model.encode([query], convert_to_numpy=True)
|
||||
query_embedding = self._normalize_vectors(query_embedding).astype('float32')
|
||||
# 提取查询关键词
|
||||
keywords = []
|
||||
for char in query:
|
||||
if '\u4e00' <= char <= '\u9fff': # 中文字符
|
||||
keywords.append(char)
|
||||
# 添加英文单词
|
||||
import re
|
||||
english_words = re.findall(r'[a-zA-Z]+', query)
|
||||
keywords.extend(english_words)
|
||||
|
||||
scores, indices = self.index.search(query_embedding, min(top_k, self.index.ntotal))
|
||||
if not keywords:
|
||||
return []
|
||||
|
||||
results = []
|
||||
for score, idx in zip(scores[0], indices[0]):
|
||||
if idx < 0:
|
||||
continue
|
||||
doc = self.documents[idx]
|
||||
results.append({
|
||||
"content": doc["content"],
|
||||
"metadata": doc["metadata"],
|
||||
"score": float(score),
|
||||
"doc_id": doc["id"]
|
||||
})
|
||||
for doc in self.documents:
|
||||
content = doc["content"]
|
||||
# 计算关键词匹配分数
|
||||
score = 0
|
||||
matched_keywords = 0
|
||||
for kw in keywords:
|
||||
if kw in content:
|
||||
score += 1
|
||||
matched_keywords += 1
|
||||
|
||||
logger.debug(f"检索到 {len(results)} 条相关文档")
|
||||
return results
|
||||
if matched_keywords > 0:
|
||||
# 归一化分数
|
||||
score = score / max(len(keywords), 1)
|
||||
results.append({
|
||||
"content": content,
|
||||
"metadata": doc["metadata"],
|
||||
"score": score,
|
||||
"doc_id": doc["id"],
|
||||
"chunk_index": doc["metadata"].get("chunk_index", 0)
|
||||
})
|
||||
|
||||
# 按分数排序
|
||||
results.sort(key=lambda x: x["score"], reverse=True)
|
||||
|
||||
logger.debug(f"关键词搜索返回 {len(results[:top_k])} 条结果")
|
||||
return results[:top_k]
|
||||
|
||||
def retrieve_by_doc_id(self, doc_id: str, top_k: int = 10) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定文档的所有块
|
||||
|
||||
Args:
|
||||
doc_id: 文档ID
|
||||
top_k: 返回的最大结果数
|
||||
|
||||
Returns:
|
||||
该文档的所有块
|
||||
"""
|
||||
# 获取属于该文档的所有块
|
||||
doc_chunks = [d for d in self.documents if d["metadata"].get("doc_id") == doc_id]
|
||||
|
||||
# 按 chunk_index 排序
|
||||
doc_chunks.sort(key=lambda x: x["metadata"].get("chunk_index", 0))
|
||||
|
||||
# 返回指定数量
|
||||
return doc_chunks[:top_k]
|
||||
|
||||
def retrieve_by_table(self, table_name: str, top_k: int = 5) -> List[Dict[str, Any]]:
|
||||
"""检索指定表的字段"""
|
||||
|
||||
Reference in New Issue
Block a user