优化智能填表功能:提升速度、完善数据提取精度

后端优化 (template_fill_service.py):

1. 速度优化:
   - 使用 asyncio.gather 实现字段并行提取
   - 跳过 AI 审核步骤,减少 LLM 调用次数
   - 新增 _extract_single_field_fast 方法

2. 数据提取优化:
   - 集成 RAG 服务进行智能内容检索
   - 修复 Markdown 表格列匹配跳过空列
   - 修复年份子表头行误识别问题

3. AI 表头生成优化:
   - 精简为 5-7 个代表性字段(原来 8-15 个)
   - 过滤非数据字段(source、备注、说明等)
   - 简化字段名,如"医院数量"而非"医院-公立医院数量"

4. AI 数据提取 prompt 优化:
   - 严格按表头提取,只返回相关数据
   - 每个值必须带标注(年份/地区/分类)
   - 支持多种标注类型:2024年、北京、某省、公立医院、三级医院等
   - 保留原始数值、单位和百分号格式
   - 不返回大段来源说明

5. FillResult 新增 warning 字段:
   - 多值检测提示,如"检测到 2 个值"

前端优化 (TemplateFill.tsx):
- 填写详情显示多值警告(黄色提示框)
- 多值情况下直接显示所有值

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
dj
2026-04-14 17:14:59 +08:00
parent 5fca4eb094
commit a9dc0d8b91
5 changed files with 784 additions and 113 deletions

View File

@@ -404,18 +404,22 @@ async def process_documents_batch(task_id: str, files: List[dict]):
async def index_document_to_rag(doc_id: str, filename: str, result: ParseResult, doc_type: str):
"""将非结构化文档索引到 RAG"""
"""将非结构化文档索引到 RAG(使用分块索引)"""
try:
content = result.data.get("content", "")
if content:
# 将完整内容传递给 RAG 服务自动分块索引
rag_service.index_document_content(
doc_id=doc_id,
content=content[:5000],
content=content, # 传递完整内容,由 RAG 服务自动分块
metadata={
"filename": filename,
"doc_type": doc_type
}
},
chunk_size=500, # 每块 500 字符
chunk_overlap=50 # 块之间 50 字符重叠
)
logger.info(f"RAG 索引完成: {filename}, doc_id={doc_id}")
except Exception as e:
logger.warning(f"RAG 索引失败: {str(e)}")

View File

@@ -3,7 +3,6 @@ RAG 服务模块 - 检索增强生成
使用 sentence-transformers + Faiss 实现向量检索
"""
import json
import logging
import os
import pickle
@@ -11,12 +10,20 @@ from typing import Any, Dict, List, Optional
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from app.config import settings
logger = logging.getLogger(__name__)
# 尝试导入 sentence-transformers
try:
from sentence_transformers import SentenceTransformer
SENTENCE_TRANSFORMERS_AVAILABLE = True
except ImportError as e:
logger.warning(f"sentence-transformers 导入失败: {e}")
SENTENCE_TRANSFORMERS_AVAILABLE = False
SentenceTransformer = None
class SimpleDocument:
"""简化文档对象"""
@@ -28,17 +35,24 @@ class SimpleDocument:
class RAGService:
"""RAG 检索增强服务"""
# 默认分块参数
DEFAULT_CHUNK_SIZE = 500 # 每个文本块的大小(字符数)
DEFAULT_CHUNK_OVERLAP = 50 # 块之间的重叠(字符数)
def __init__(self):
self.embedding_model: Optional[SentenceTransformer] = None
self.embedding_model = None
self.index: Optional[faiss.Index] = None
self.documents: List[Dict[str, Any]] = []
self.doc_ids: List[str] = []
self._dimension: int = 0
self._dimension: int = 384 # 默认维度
self._initialized = False
self._persist_dir = settings.FAISS_INDEX_DIR
# 临时禁用 RAG API 调用,仅记录日志
self._disabled = True
logger.info("RAG 服务已禁用_disabled=True仅记录索引操作日志")
# 检查是否可用
self._disabled = not SENTENCE_TRANSFORMERS_AVAILABLE
if self._disabled:
logger.warning("RAG 服务已禁用sentence-transformers 不可用),将使用关键词匹配作为后备")
else:
logger.info("RAG 服务已启用")
def _init_embeddings(self):
"""初始化嵌入模型"""
@@ -88,6 +102,63 @@ class RAGService:
norms = np.where(norms == 0, 1, norms)
return vectors / norms
def _split_into_chunks(self, text: str, chunk_size: int = None, overlap: int = None) -> List[str]:
"""
将长文本分割成块
Args:
text: 待分割的文本
chunk_size: 每个块的大小(字符数)
overlap: 块之间的重叠字符数
Returns:
文本块列表
"""
if chunk_size is None:
chunk_size = self.DEFAULT_CHUNK_SIZE
if overlap is None:
overlap = self.DEFAULT_CHUNK_OVERLAP
if len(text) <= chunk_size:
return [text] if text.strip() else []
chunks = []
start = 0
text_len = len(text)
while start < text_len:
# 计算当前块的结束位置
end = start + chunk_size
# 如果不是最后一块,尝试在句子边界处切割
if end < text_len:
# 向前查找最后一个句号、逗号、换行或分号
cut_positions = []
for i in range(end, max(start, end - 100), -1):
if text[i] in '。;,,\n':
cut_positions.append(i + 1)
break
if cut_positions:
end = cut_positions[0]
else:
# 如果没找到句子边界,尝试向后查找
for i in range(end, min(text_len, end + 50)):
if text[i] in '。;,,\n':
end = i + 1
break
chunk = text[start:end].strip()
if chunk:
chunks.append(chunk)
# 移动起始位置(考虑重叠)
start = end - overlap
if start <= 0:
start = end
return chunks
def index_field(
self,
table_name: str,
@@ -124,9 +195,20 @@ class RAGService:
self,
doc_id: str,
content: str,
metadata: Optional[Dict[str, Any]] = None
metadata: Optional[Dict[str, Any]] = None,
chunk_size: int = None,
chunk_overlap: int = None
):
"""将文档内容索引到向量数据库"""
"""
将文档内容索引到向量数据库(自动分块)
Args:
doc_id: 文档唯一标识
content: 文档内容
metadata: 文档元数据
chunk_size: 文本块大小字符数默认500
chunk_overlap: 块之间的重叠字符数默认50
"""
if self._disabled:
logger.info(f"[RAG DISABLED] 文档索引操作已跳过: {doc_id}")
return
@@ -139,18 +221,56 @@ class RAGService:
logger.debug(f"文档跳过索引 (无嵌入模型): {doc_id}")
return
doc = SimpleDocument(
page_content=content,
metadata=metadata or {"doc_id": doc_id}
)
self._add_documents([doc], [doc_id])
logger.debug(f"已索引文档: {doc_id}")
# 分割文档为小块
if chunk_size is None:
chunk_size = self.DEFAULT_CHUNK_SIZE
if chunk_overlap is None:
chunk_overlap = self.DEFAULT_CHUNK_OVERLAP
chunks = self._split_into_chunks(content, chunk_size, chunk_overlap)
if not chunks:
logger.warning(f"文档内容为空,跳过索引: {doc_id}")
return
# 为每个块创建文档对象
documents = []
chunk_ids = []
for i, chunk in enumerate(chunks):
chunk_id = f"{doc_id}_chunk_{i}"
chunk_metadata = metadata.copy() if metadata else {}
chunk_metadata.update({
"chunk_index": i,
"total_chunks": len(chunks),
"doc_id": doc_id
})
documents.append(SimpleDocument(
page_content=chunk,
metadata=chunk_metadata
))
chunk_ids.append(chunk_id)
# 批量添加文档
self._add_documents(documents, chunk_ids)
logger.info(f"已索引文档 {doc_id},共 {len(chunks)} 个块")
def _add_documents(self, documents: List[SimpleDocument], doc_ids: List[str]):
"""批量添加文档到向量索引"""
if not documents:
return
# 总是将文档存储在内存中(用于关键词搜索后备)
for doc, did in zip(documents, doc_ids):
self.documents.append({"id": did, "content": doc.page_content, "metadata": doc.metadata})
self.doc_ids.append(did)
# 如果没有嵌入模型,跳过向量索引
if self.embedding_model is None:
logger.debug(f"文档跳过向量索引 (无嵌入模型): {len(documents)} 个文档")
return
texts = [doc.page_content for doc in documents]
embeddings = self.embedding_model.encode(texts, convert_to_numpy=True)
embeddings = self._normalize_vectors(embeddings).astype('float32')
@@ -162,12 +282,18 @@ class RAGService:
id_array = np.array(id_list, dtype='int64')
self.index.add_with_ids(embeddings, id_array)
for doc, did in zip(documents, doc_ids):
self.documents.append({"id": did, "content": doc.page_content, "metadata": doc.metadata})
self.doc_ids.append(did)
def retrieve(self, query: str, top_k: int = 5, min_score: float = 0.3) -> List[Dict[str, Any]]:
"""
根据查询检索相关文档块
def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
"""根据查询检索相关文档"""
Args:
query: 查询文本
top_k: 返回的最大结果数
min_score: 最低相似度分数阈值
Returns:
相关文档块列表,每项包含 content, metadata, score, doc_id, chunk_index
"""
if self._disabled:
logger.info(f"[RAG DISABLED] 检索操作已跳过: query={query}, top_k={top_k}")
return []
@@ -175,28 +301,113 @@ class RAGService:
if not self._initialized:
self._init_vector_store()
if self.index is None or self.index.ntotal == 0:
# 优先使用向量检索
if self.index is not None and self.index.ntotal > 0 and self.embedding_model is not None:
try:
query_embedding = self.embedding_model.encode([query], convert_to_numpy=True)
query_embedding = self._normalize_vectors(query_embedding).astype('float32')
scores, indices = self.index.search(query_embedding, min(top_k, self.index.ntotal))
results = []
for score, idx in zip(scores[0], indices[0]):
if idx < 0:
continue
if score < min_score:
continue
doc = self.documents[idx]
results.append({
"content": doc["content"],
"metadata": doc["metadata"],
"score": float(score),
"doc_id": doc["id"],
"chunk_index": doc["metadata"].get("chunk_index", 0)
})
if results:
logger.debug(f"向量检索到 {len(results)} 条相关文档块")
return results
except Exception as e:
logger.warning(f"向量检索失败,使用关键词搜索后备: {e}")
# 后备:使用关键词搜索
logger.debug("使用关键词搜索后备方案")
return self._keyword_search(query, top_k)
def _keyword_search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
"""
关键词搜索后备方案
Args:
query: 查询文本
top_k: 返回的最大结果数
Returns:
相关文档块列表
"""
if not self.documents:
return []
query_embedding = self.embedding_model.encode([query], convert_to_numpy=True)
query_embedding = self._normalize_vectors(query_embedding).astype('float32')
# 提取查询关键词
keywords = []
for char in query:
if '\u4e00' <= char <= '\u9fff': # 中文字符
keywords.append(char)
# 添加英文单词
import re
english_words = re.findall(r'[a-zA-Z]+', query)
keywords.extend(english_words)
scores, indices = self.index.search(query_embedding, min(top_k, self.index.ntotal))
if not keywords:
return []
results = []
for score, idx in zip(scores[0], indices[0]):
if idx < 0:
continue
doc = self.documents[idx]
results.append({
"content": doc["content"],
"metadata": doc["metadata"],
"score": float(score),
"doc_id": doc["id"]
})
for doc in self.documents:
content = doc["content"]
# 计算关键词匹配分数
score = 0
matched_keywords = 0
for kw in keywords:
if kw in content:
score += 1
matched_keywords += 1
logger.debug(f"检索到 {len(results)} 条相关文档")
return results
if matched_keywords > 0:
# 归一化分数
score = score / max(len(keywords), 1)
results.append({
"content": content,
"metadata": doc["metadata"],
"score": score,
"doc_id": doc["id"],
"chunk_index": doc["metadata"].get("chunk_index", 0)
})
# 按分数排序
results.sort(key=lambda x: x["score"], reverse=True)
logger.debug(f"关键词搜索返回 {len(results[:top_k])} 条结果")
return results[:top_k]
def retrieve_by_doc_id(self, doc_id: str, top_k: int = 10) -> List[Dict[str, Any]]:
"""
获取指定文档的所有块
Args:
doc_id: 文档ID
top_k: 返回的最大结果数
Returns:
该文档的所有块
"""
# 获取属于该文档的所有块
doc_chunks = [d for d in self.documents if d["metadata"].get("doc_id") == doc_id]
# 按 chunk_index 排序
doc_chunks.sort(key=lambda x: x["metadata"].get("chunk_index", 0))
# 返回指定数量
return doc_chunks[:top_k]
def retrieve_by_table(self, table_name: str, top_k: int = 5) -> List[Dict[str, Any]]:
"""检索指定表的字段"""

View File

@@ -3,6 +3,7 @@
从非结构化文档中检索信息并填写到表格模板
"""
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
@@ -11,6 +12,7 @@ from app.core.database import mongodb
from app.services.llm_service import llm_service
from app.core.document_parser import ParserFactory
from app.services.markdown_ai_service import markdown_ai_service
from app.services.rag_service import rag_service
logger = logging.getLogger(__name__)
@@ -43,6 +45,7 @@ class FillResult:
value: Any = "" # 保留兼容
source: str = "" # 来源文档
confidence: float = 1.0 # 置信度
warning: str = None # 多值提示
def __post_init__(self):
if self.values is None:
@@ -172,49 +175,30 @@ class TemplateFillService:
if source_docs and template_fields:
logger.info(f"表头看起来正常(非自动生成),无需重新生成: {[f.name for f in template_fields[:5]]}")
# 2. 对每个字段进行提取
# 2. 并行提取所有字段跳过AI审核以提升速度
logger.info(f"开始并行提取 {len(template_fields)} 个字段...")
# 并行处理所有字段
tasks = []
for idx, field in enumerate(template_fields):
try:
logger.info(f"提取字段 [{idx+1}/{len(template_fields)}]: {field.name}")
# 从源文档中提取字段值
result = await self._extract_field_value(
field=field,
source_docs=source_docs,
user_hint=user_hint
)
task = self._extract_single_field_fast(
field=field,
source_docs=source_docs,
user_hint=user_hint,
field_idx=idx,
total_fields=len(template_fields)
)
tasks.append(task)
# AI审核验证提取的值是否合理
if result.values and result.values[0]:
logger.info(f"字段 {field.name} 进入AI审核阶段...")
verified_result = await self._verify_field_value(
field=field,
extracted_values=result.values,
source_docs=source_docs,
user_hint=user_hint
)
if verified_result:
# 审核给出了修正结果
result = verified_result
logger.info(f"字段 {field.name} 审核后修正值: {result.values[:3]}")
else:
logger.info(f"字段 {field.name} 审核通过,使用原提取结果")
# 等待所有任务完成
results = await asyncio.gather(*tasks, return_exceptions=True)
# 存储结果 - 使用 values 数组
filled_data[field.name] = result.values if result.values else [""]
fill_details.append({
"field": field.name,
"cell": field.cell,
"values": result.values,
"value": result.value,
"source": result.source,
"confidence": result.confidence
})
logger.info(f"字段 {field.name} 填写完成: {len(result.values)} 个值")
except Exception as e:
logger.error(f"填写字段 {field.name} 失败: {str(e)}", exc_info=True)
filled_data[field.name] = [f"[提取失败: {str(e)}]"]
# 处理结果
for idx, result in enumerate(results):
field = template_fields[idx]
if isinstance(result, Exception):
logger.error(f"提取字段 {field.name} 失败: {str(result)}")
filled_data[field.name] = [f"[提取失败: {str(result)}]"]
fill_details.append({
"field": field.name,
"cell": field.cell,
@@ -223,6 +207,18 @@ class TemplateFillService:
"source": "error",
"confidence": 0.0
})
else:
filled_data[field.name] = result.values if result.values else [""]
fill_details.append({
"field": field.name,
"cell": field.cell,
"values": result.values,
"value": result.value,
"source": result.source,
"confidence": result.confidence,
"warning": result.warning
})
logger.info(f"字段 {field.name} 填写完成: {len(result.values) if result.values else 0} 个值")
# 计算最大行数
max_rows = max(len(v) for v in filled_data.values()) if filled_data else 1
@@ -551,6 +547,222 @@ class TemplateFillService:
confidence=0.0
)
async def _extract_single_field_fast(
self,
field: TemplateField,
source_docs: List[SourceDocument],
user_hint: Optional[str] = None,
field_idx: int = 0,
total_fields: int = 1
) -> FillResult:
"""
快速提取单个字段跳过AI审核减少LLM调用
Args:
field: 字段定义
source_docs: 源文档列表
user_hint: 用户提示
field_idx: 当前字段索引(用于日志)
total_fields: 总字段数(用于日志)
Returns:
提取结果
"""
try:
if not source_docs:
return FillResult(
field=field.name,
value="",
values=[""],
source="无源文档",
confidence=0.0
)
# 1. 优先尝试直接从结构化数据中提取(最快路径)
direct_values = self._extract_values_from_structured_data(source_docs, field.name)
if direct_values:
logger.info(f"✅ [{field_idx+1}/{total_fields}] 字段 {field.name} 直接从结构化数据提取到 {len(direct_values)} 个值")
return FillResult(
field=field.name,
values=direct_values,
value=direct_values[0] if direct_values else "",
source="结构化数据直接提取",
confidence=1.0
)
# 2. 无法直接从结构化数据提取使用简化版AI提取
logger.info(f"🔍 [{field_idx+1}/{total_fields}] 字段 {field.name} 尝试AI提取...")
# 构建提示词 - 简化版
hint_text = field.hint if field.hint else f"请提取{field.name}的信息"
if user_hint:
hint_text = f"{user_hint}{hint_text}"
# 优先使用 RAG 检索内容,否则使用文档开头部分
context_parts = []
for doc in source_docs:
if not doc.content:
logger.info(f" 文档 {doc.filename} 无content内容")
continue
logger.info(f" 处理文档: {doc.filename}, doc_id={doc.doc_id}, content长度={len(doc.content)}")
# 尝试 RAG 检索
rag_results = rag_service.retrieve(
query=f"{field.name} {hint_text}",
top_k=3,
min_score=0.1
)
if rag_results:
logger.info(f" RAG检索到 {len(rag_results)} 条结果")
# 使用 RAG 检索到的内容
for r in rag_results:
rag_doc_id = r.get("doc_id", "")
if rag_doc_id.startswith(doc.doc_id):
context_parts.append(r["content"])
logger.info(f" 匹配成功使用RAG内容长度={len(r['content'])}")
else:
# RAG 没结果,使用文档内容开头
context_parts.append(doc.content[:2500])
logger.info(f" RAG无结果使用文档开头 {min(2500, len(doc.content))} 字符")
context = "\n\n".join(context_parts[:3]) if context_parts else ""
logger.info(f" 最终context长度: {len(context)}, 内容预览: {context[:200] if context else ''}...")
prompt = f"""你是一个专业的数据提取专家。请严格按照表头字段「{field.name}」从文档中提取数据。
提示: {hint_text}
【重要规则 - 必须遵守】
1. **每个值必须有标注**:根据数据来源添加合适的标注前缀!
- ✅ 正确格式:
- "2024年38710个"
- "北京1234万人次"
- "某省5678万人"
- "公立医院11754个"
- "三级医院4111个"
- "图书馆3246个"
- ❌ 错误格式:"38710个"(缺少标注)
2. **标注类型根据数据决定**
- 年份类数据 → "2024年xxx""2023年xxx"
- 地区类数据 → "北京xxx""广东xxx""某县xxx"
- 机构/分类数据 → "公立医院xxx""三级医院xxx""图书馆xxx"
- 其他分类 → 根据实际情况标注
3. **严格按表头提取**:只提取与「{field.name}」直接相关的数据
4. **多值必须全部提取并标注**:如果文档中提到多个相关数据,每个都要有标注
文档内容:
{context if context else "(无文档内容)"}
请严格按格式返回JSON{{"values": ["标注:数值", "标注:数值", ...]}}
注意values数组中每个元素都必须包含标注前缀不能只有数值
"""
messages = [
{"role": "system", "content": "你是一个专业的数据提取助手擅长从政府统计公报中提取数据。严格按JSON格式输出只返回values数组。"},
{"role": "user", "content": prompt}
]
response = await self.llm.chat(
messages=messages,
temperature=0.1,
max_tokens=1000
)
content = self.llm.extract_message_content(response)
logger.info(f" LLM原始返回: {content[:500]}")
# 解析JSON
import json
import re
cleaned = content.strip()
# 查找JSON开始位置
json_start = -1
for i, c in enumerate(cleaned):
if c == '{':
json_start = i
break
values = []
source = "AI提取"
if json_start >= 0:
try:
json_text = cleaned[json_start:]
result = json.loads(json_text)
values = result.get("values", [])
logger.info(f" JSON解析成功values: {values}")
except json.JSONDecodeError as e:
logger.warning(f" JSON解析失败: {e},尝试修复...")
# 尝试修复常见JSON问题
try:
# 尝试找到values数组
values_match = re.search(r'"values"\s*:\s*\[(.*?)\]', cleaned, re.DOTALL)
if values_match:
values_str = values_match.group(1)
# 提取数组中的字符串
values = re.findall(r'"([^"]*)"', values_str)
logger.info(f" 正则提取values: {values}")
except:
pass
# 如果values为空尝试从文本中用正则提取数字+单位
if not values or values == [""]:
logger.info(f" JSON解析未获取到值尝试正则提取...")
# 匹配数字+单位或百分号的模式
patterns = [
r'(\d+\.?\d*[亿万千百十个]?[%‰℃℃万元亿]?)', # 通用数字+单位
r'(\d+\.?\d*%)', # 百分号
r'(\d+\.?\d*[个万人亿元]?)', # 中文单位
]
for pattern in patterns:
matches = re.findall(pattern, context)
if matches:
# 过滤掉纯数字
filtered = [m for m in matches if not m.replace('.', '').isdigit()]
if filtered:
values = filtered[:10] # 最多取10个
logger.info(f" 正则提取到: {values}")
break
if not values or values == [""]:
values = self._extract_values_by_regex(cleaned)
if not values:
values = [""]
# 生成多值提示(基于实际检测到的值数量)
warning = ""
if len(values) > 1:
warning = f"⚠️ 检测到 {len(values)} 个值:{values[:5]}{'...' if len(values) > 5 else ''}"
logger.info(f"✅ [{field_idx+1}/{total_fields}] 字段 {field.name} AI提取完成: {len(values)} 个值")
if warning:
logger.info(f" {warning}")
return FillResult(
field=field.name,
values=values,
value=values[0] if values else "",
source=source,
confidence=0.8,
warning=warning if warning else None
)
except Exception as e:
logger.error(f"❌ [{field_idx+1}/{total_fields}] 字段 {field.name} 提取失败: {str(e)}")
return FillResult(
field=field.name,
values=[""],
value="",
source=f"提取失败: {str(e)}",
confidence=0.0
)
async def _verify_field_value(
self,
field: TemplateField,
@@ -1172,13 +1384,148 @@ class TemplateFillService:
values = []
for row in rows:
if isinstance(row, list) and target_idx < len(row):
val = row[target_idx]
if isinstance(row, list):
# 跳过子表头行(主要包含年份值的行,如 "1985", "1995"
if self._is_year_subheader_row(row):
logger.info(f"跳过子表头行: {row[:5]}...")
continue
# 跳过章节标题行
if self._is_section_header_row(row):
logger.info(f"跳过章节标题行: {row[:5]}...")
continue
if target_idx < len(row):
val = row[target_idx]
else:
val = ""
else:
val = ""
values.append(self._format_value(val))
return values
# 过滤掉无效值(章节标题、省略号等)
valid_values = self._filter_valid_values(values)
if len(valid_values) < len(values):
logger.info(f"过滤无效值: {len(values)} -> {len(valid_values)}")
return valid_values
def _is_year_subheader_row(self, row: List) -> bool:
"""
检测行是否看起来像年份子表头行
年份子表头行通常包含 "1985", "1995", "2020" 等4位数字
Args:
row: 行数据
Returns:
是否是年份子表头行
"""
if not row:
return False
import re
year_pattern = re.compile(r'^(19|20)\d{2}$')
# 计算看起来像年份的单元格数量
year_like_count = 0
for cell in row:
cell_str = str(cell).strip()
if year_pattern.match(cell_str):
year_like_count += 1
# 如果超过50%的单元格是年份格式,认为是子表头行
if len(row) > 0 and year_like_count / len(row) > 0.5:
return True
return False
def _is_section_header_row(self, row: List) -> bool:
"""
检测行是否看起来像章节标题行
章节标题行通常包含 "其中:""全部工业中:""按...计算" 等关键词
Args:
row: 行数据
Returns:
是否是章节标题行
"""
if not row:
return False
import re
# 章节标题通常包含这些模式
section_patterns = [
r'其中[:]',
r'全部\w+中[:]',
r'\w+计算',
r'小计',
r'合计',
r'总计',
r'^其中$',
r'全部$'
]
for cell in row:
cell_str = str(cell).strip()
if not cell_str:
continue
for pattern in section_patterns:
if re.search(pattern, cell_str):
return True
return False
def _is_valid_data_value(self, val: str) -> bool:
"""
检测值是否是有效的数据值(不是章节标题、省略号等)
Args:
val: 值字符串
Returns:
是否是有效数据值
"""
if not val or not str(val).strip():
return False
val_str = str(val).strip()
# 无效模式
invalid_patterns = [
r'^…$', # 省略号
r'^[\.。]+$', # 只有点或句号
r'其中[:]', # 章节标题
r'全部\w+中', # 章节标题
r'\w+计算', # 计算类型
r'^(小计|合计|总计)$', # 汇总行
r'^其中$',
r'^全部$'
]
for pattern in invalid_patterns:
import re
if re.match(pattern, val_str):
return False
return True
def _filter_valid_values(self, values: List[str]) -> List[str]:
"""
过滤出有效的数据值
Args:
values: 值列表
Returns:
只包含有效值的列表
"""
valid_values = []
for val in values:
if self._is_valid_data_value(val):
valid_values.append(val)
return valid_values
def _find_best_matching_column(self, headers: List, field_name: str) -> Optional[int]:
"""
@@ -1206,6 +1553,11 @@ class TemplateFillService:
header_str = str(header).strip()
header_lower = header_str.lower()
# 跳过空表头(第一列为空的情况)
if not header_str:
logger.info(f"跳过空表头列: 索引 {idx}")
continue
# 策略1: 精确匹配(忽略大小写)
if header_lower == field_lower:
return idx
@@ -1262,6 +1614,12 @@ class TemplateFillService:
values = []
for row in rows:
# 跳过子表头行(主要包含年份值的行,如 "1985", "1995"
if isinstance(row, list) and self._is_year_subheader_row(row):
continue
# 跳过章节标题行
if isinstance(row, list) and self._is_section_header_row(row):
continue
if isinstance(row, dict):
val = row.get(target_col, "")
elif isinstance(row, list) and target_idx < len(row):
@@ -1270,7 +1628,12 @@ class TemplateFillService:
val = ""
values.append(self._format_value(val))
return values
# 过滤掉无效值(章节标题、省略号等)
valid_values = self._filter_valid_values(values)
if len(valid_values) < len(values):
logger.info(f"过滤无效值: {len(values)} -> {len(valid_values)}")
return valid_values
def _format_value(self, val: Any) -> str:
"""
@@ -1604,38 +1967,98 @@ class TemplateFillService:
if user_hint:
hint_text = f"{user_hint}{hint_text}"
# 构建针对字段提取的提示词
prompt = f"""你是一个专业的数据提取专家。请从以下文档内容中提取与"{field.name}"完全匹配的数据。
# 构建查询文本
query_text = f"{field.name} {hint_text}"
# 使用 RAG 向量检索获取相关内容块
rag_results = rag_service.retrieve(
query=query_text,
top_k=5,
min_score=0.3
)
# 构建上下文:优先使用 RAG 检索结果,如果检索不到则使用原始内容
if rag_results:
# 使用 RAG 检索到的相关块
context_parts = []
for result in rag_results:
if result.get("doc_id", "").startswith(doc.doc_id) or not result.get("doc_id"):
context_parts.append(result["content"])
if context_parts:
retrieved_context = "\n\n---\n\n".join(context_parts)
logger.info(f"RAG 检索到 {len(context_parts)} 个相关块用于字段 {field.name}")
# 使用检索到的内容(限制长度)
context_to_use = retrieved_context[:6000]
else:
# RAG 检索结果不属于当前文档,使用原始内容
context_to_use = doc.content[:6000] if doc.content else ""
logger.info(f"字段 {field.name} 使用原始内容RAG结果不属于当前文档")
else:
# 没有 RAG 检索结果,使用原始内容
context_to_use = doc.content[:6000] if doc.content else ""
logger.info(f"字段 {field.name} 使用原始内容无RAG检索结果")
# 构建针对字段提取的提示词 - 增强语义匹配能力
prompt = f"""你是一个专业的数据提取专家。请从以下文档内容中进行**语义匹配**提取。
【重要】字段名: "{field.name}"
【重要】字段提示: {hint_text}
请严格按照以下步骤操作:
1. 在文档中搜索与"{field.name}"完全相同或高度相关的关键词
2. 找到后,提取该关键词后的数值(注意:只要数值,不要单位)
3. 如果是表格中的数据,直接提取该单元格的数值
4. 如果是段落描述,在关键词附近找数值
## 分类数据识别
文档中经常包含分类统计数据,格式如下:
### 1. 直接分类(用"其中:""中,"等分隔)
原文示例:
- "全国医疗卫生机构总数1093551个其中医院38710个基层医疗卫生机构1040023个"
→ 字段"医院数量" 应提取: 38710
→ 字段"基层医疗卫生机构数量" 应提取: 1040023
- "医院中公立医院11754个民营医院26956个"
→ 字段"公立医院数量" 应提取: 11754
→ 字段"民营医院数量" 应提取: 26956
### 2. 嵌套分类(用"按...分:""其中:"等结构)
原文示例:
- "医院按等级分三级医院4111个其中三级甲等医院1876个二级医院12294个"
→ 字段"三级医院数量" 应提取: 4111
→ 字段"三级甲等医院数量" 应提取: 1876
→ 字段"二级医院数量" 应提取: 12294
### 3. 匹配技巧
- "医院数量" 可匹配: "医院38710个""医院数量为"
- "公立医院数量" 可匹配: "公立医院11754个""公立医院有"
- 忽略"数量"""""等后缀的差异
- 数值可能紧跟关键词,也可能分开描述
## 提取规则
1. **全文搜索**:在文档的全部内容中搜索,不要只搜索开头部分
2. **分类定位**:找到包含该分类关键词的句子,理解其完整的数值
3. **保留单位**:提取数值时**要包含单位**
【重要】返回值规则:
- 返回数值,不要单位(如 "4.9" 而不是 "4.9万亿元"
- 如原文"4.9万亿元",返回 "4.9"
- 如原文"144000万册",返回 "144000"
- 如果是百分比如"增长7.7%",返回 "7.7"
- 如果没有找到完全匹配的数据,返回空数组
- **返回数值时必须包含单位**
- 如原文"公共图书馆3246个"提取时应返回 "3246个"
- 如原文"国内旅游收入4.9万亿元"提取时应返回 "4.9万亿元"
- 例如原文"注册护士585.5万人"提取时应返回 "585.5万人"
- 如果字段是"指标"类型,返回具体的指标名称文本(不带单位)
- 如果没有找到任何相关数据,返回空数组
文档内容:
{doc.content[:10000] if doc.content else ""}
{context_to_use}
请用严格的 JSON 格式返回:
{{
"values": ["值1", "值2", ...], // 只填数值,不要单位
"source": "数据来源说明",
"values": ["提取到的值1", "值2", ...],
"source": "数据来源说明从文档第X段提取",
"confidence": 0.0到1.0之间的置信度
}}
示例
- 如果字段是"图书馆总藏量(万册)"且文档说"图书总藏量14.4亿册",返回 values: ["144000"]
- 如果字段是"国内旅游收入(亿元)"且文档说"国内旅游收入4.9万亿元",返回 values: ["49000"]"""
【重要】即使是模糊匹配,也要
- 确保提取的内容确实来自文档
- source中准确说明数据来源位置"""
messages = [
{"role": "system", "content": "你是一个专业的数据提取助手擅长从政府统计公报等文档中提取数据。请严格按JSON格式输出。"},
@@ -1790,31 +2213,45 @@ class TemplateFillService:
source_info += f"【包含表格数】: {tables_count}\n"
if tables_summary:
source_info += f"{tables_summary}\n"
elif content:
source_info += f"内容预览】: {content[:1500]}...\n"
if content:
source_info += f"文档内容】前3000字符{content[:3000]}\n"
prompt = f"""你是一个专业的表格设计助手。请根据源文档内容生成合适的表格表头字段。
prompt = f"""你是一个专业的数据分析助手。请分析源文档中的所有数据,生成表格表头字段。
任务:用户有一些源文档(包含表格数据),需要填写到空白表格模板中。源文档中的表格如下:
任务:分析源文档,找出所有具体的数据指标及其分类。
{source_info}
【重要要求】
1. 请仔细阅读上面的源文档表格,找出所有不同的列名(如"产品名称""1995年产量""按资产总额计算(%)"等)
2. 直接使用这些实际的列名作为表头字段名,不要生成新的或同义词
3. 如果一个源文档有多个表格,请为每个表格选择合适的列名
4. 生成3-8个表头字段优先选择数据量大的表格的列
1. **只生成数据字段名**
- ✅ 正确示例:"医院数量""公立医院数量""病床使用率"
- ❌ 错误示例:"source""备注""说明""数据来源"
2. **识别所有数值数据**
- 例如:"医院38710个""病床使用率78.8%"
- 例如:"公立医院11754个""公立医院病床使用率84.8%"
3. **理解分类层级**
- 顶级分类:如"医院""基层医疗卫生机构"
- 二级分类:如"医院"下分为"公立医院""民营医院"
4. **生成字段**
- 字段名要简洁,如:"医院数量""病床使用率"
- 优先选择:总数 + 主要分类
5. **生成数量**
- 生成5-7个最有代表性的字段
请严格按照以下 JSON 格式输出(只需输出 JSON不要其他内容
{{
"fields": [
{{"name": "实际列名1", "hint": "对该列的说明"}},
{{"name": "实际列名2", "hint": "对该列的说明"}}
{{"name": "字段名1"}},
{{"name": "字段名2"}}
]
}}
"""
messages = [
{"role": "system", "content": "你是一个专业的表格设计助手。请严格按JSON格式输出。"},
{"role": "system", "content": "你是一个专业的表格设计助手。请严格按JSON格式输出只返回纯数据字段名不要source、备注、说明等辅助字段"},
{"role": "user", "content": prompt}
]
@@ -1853,14 +2290,22 @@ class TemplateFillService:
if result and "fields" in result:
fields = []
# 过滤非数据字段
skip_keywords = ["source", "来源", "备注", "说明", "备注列", "说明列", "data_source", "remark", "note", "description"]
for idx, f in enumerate(result["fields"]):
field_name = f.get("name", f"字段{idx+1}")
# 跳过非数据字段
if any(kw in field_name.lower() for kw in skip_keywords):
logger.info(f"跳过非数据字段: {field_name}")
continue
fields.append(TemplateField(
cell=self._column_to_cell(idx),
name=f.get("name", f"字段{idx+1}"),
name=field_name,
field_type="text",
required=False,
hint=f.get("hint", "")
))
logger.info(f"AI 生成表头: {[f.name for f in fields]}")
return fields
except Exception as e:

View File

@@ -766,6 +766,7 @@ const Documents: React.FC = () => {
<div
{...getRootProps()}
className="flex items-center justify-center gap-2 p-3 border-2 border-dashed rounded-lg cursor-pointer hover:border-primary/50 hover:bg-primary/5 transition-colors"
onClick={(e) => e.stopPropagation()}
>
<input {...getInputProps()} multiple={true} />
<Plus size={16} className="text-muted-foreground" />

View File

@@ -626,6 +626,16 @@ const TemplateFill: React.FC = () => {
<div className="text-muted-foreground text-xs mt-1">
: {detail.source} | : {detail.confidence ? (detail.confidence * 100).toFixed(0) + '%' : 'N/A'}
</div>
{detail.warning && (
<div className="mt-2 p-2 bg-yellow-50 border border-yellow-200 rounded-lg text-yellow-700 text-xs">
{detail.warning}
</div>
)}
{detail.values && detail.values.length > 1 && !detail.warning && (
<div className="mt-2 text-xs text-muted-foreground">
: {detail.values.join(', ')}
</div>
)}
</div>
</div>
))}