diff --git a/backend/app/api/endpoints/documents.py b/backend/app/api/endpoints/documents.py index 4260ec6..e8e206a 100644 --- a/backend/app/api/endpoints/documents.py +++ b/backend/app/api/endpoints/documents.py @@ -404,18 +404,22 @@ async def process_documents_batch(task_id: str, files: List[dict]): async def index_document_to_rag(doc_id: str, filename: str, result: ParseResult, doc_type: str): - """将非结构化文档索引到 RAG""" + """将非结构化文档索引到 RAG(使用分块索引)""" try: content = result.data.get("content", "") if content: + # 将完整内容传递给 RAG 服务自动分块索引 rag_service.index_document_content( doc_id=doc_id, - content=content[:5000], + content=content, # 传递完整内容,由 RAG 服务自动分块 metadata={ "filename": filename, "doc_type": doc_type - } + }, + chunk_size=500, # 每块 500 字符 + chunk_overlap=50 # 块之间 50 字符重叠 ) + logger.info(f"RAG 索引完成: {filename}, doc_id={doc_id}") except Exception as e: logger.warning(f"RAG 索引失败: {str(e)}") diff --git a/backend/app/services/rag_service.py b/backend/app/services/rag_service.py index 3855180..b6e905b 100644 --- a/backend/app/services/rag_service.py +++ b/backend/app/services/rag_service.py @@ -3,7 +3,6 @@ RAG 服务模块 - 检索增强生成 使用 sentence-transformers + Faiss 实现向量检索 """ -import json import logging import os import pickle @@ -11,12 +10,20 @@ from typing import Any, Dict, List, Optional import faiss import numpy as np -from sentence_transformers import SentenceTransformer from app.config import settings logger = logging.getLogger(__name__) +# 尝试导入 sentence-transformers +try: + from sentence_transformers import SentenceTransformer + SENTENCE_TRANSFORMERS_AVAILABLE = True +except ImportError as e: + logger.warning(f"sentence-transformers 导入失败: {e}") + SENTENCE_TRANSFORMERS_AVAILABLE = False + SentenceTransformer = None + class SimpleDocument: """简化文档对象""" @@ -28,17 +35,24 @@ class SimpleDocument: class RAGService: """RAG 检索增强服务""" + # 默认分块参数 + DEFAULT_CHUNK_SIZE = 500 # 每个文本块的大小(字符数) + DEFAULT_CHUNK_OVERLAP = 50 # 块之间的重叠(字符数) + def __init__(self): - self.embedding_model: Optional[SentenceTransformer] = None + self.embedding_model = None self.index: Optional[faiss.Index] = None self.documents: List[Dict[str, Any]] = [] self.doc_ids: List[str] = [] - self._dimension: int = 0 + self._dimension: int = 384 # 默认维度 self._initialized = False self._persist_dir = settings.FAISS_INDEX_DIR - # 临时禁用 RAG API 调用,仅记录日志 - self._disabled = True - logger.info("RAG 服务已禁用(_disabled=True),仅记录索引操作日志") + # 检查是否可用 + self._disabled = not SENTENCE_TRANSFORMERS_AVAILABLE + if self._disabled: + logger.warning("RAG 服务已禁用(sentence-transformers 不可用),将使用关键词匹配作为后备") + else: + logger.info("RAG 服务已启用") def _init_embeddings(self): """初始化嵌入模型""" @@ -88,6 +102,63 @@ class RAGService: norms = np.where(norms == 0, 1, norms) return vectors / norms + def _split_into_chunks(self, text: str, chunk_size: int = None, overlap: int = None) -> List[str]: + """ + 将长文本分割成块 + + Args: + text: 待分割的文本 + chunk_size: 每个块的大小(字符数) + overlap: 块之间的重叠字符数 + + Returns: + 文本块列表 + """ + if chunk_size is None: + chunk_size = self.DEFAULT_CHUNK_SIZE + if overlap is None: + overlap = self.DEFAULT_CHUNK_OVERLAP + + if len(text) <= chunk_size: + return [text] if text.strip() else [] + + chunks = [] + start = 0 + text_len = len(text) + + while start < text_len: + # 计算当前块的结束位置 + end = start + chunk_size + + # 如果不是最后一块,尝试在句子边界处切割 + if end < text_len: + # 向前查找最后一个句号、逗号、换行或分号 + cut_positions = [] + for i in range(end, max(start, end - 100), -1): + if text[i] in '。;,,\n、': + cut_positions.append(i + 1) + break + + if cut_positions: + end = cut_positions[0] + else: + # 如果没找到句子边界,尝试向后查找 + for i in range(end, min(text_len, end + 50)): + if text[i] in '。;,,\n、': + end = i + 1 + break + + chunk = text[start:end].strip() + if chunk: + chunks.append(chunk) + + # 移动起始位置(考虑重叠) + start = end - overlap + if start <= 0: + start = end + + return chunks + def index_field( self, table_name: str, @@ -124,9 +195,20 @@ class RAGService: self, doc_id: str, content: str, - metadata: Optional[Dict[str, Any]] = None + metadata: Optional[Dict[str, Any]] = None, + chunk_size: int = None, + chunk_overlap: int = None ): - """将文档内容索引到向量数据库""" + """ + 将文档内容索引到向量数据库(自动分块) + + Args: + doc_id: 文档唯一标识 + content: 文档内容 + metadata: 文档元数据 + chunk_size: 文本块大小(字符数),默认500 + chunk_overlap: 块之间的重叠字符数,默认50 + """ if self._disabled: logger.info(f"[RAG DISABLED] 文档索引操作已跳过: {doc_id}") return @@ -139,18 +221,56 @@ class RAGService: logger.debug(f"文档跳过索引 (无嵌入模型): {doc_id}") return - doc = SimpleDocument( - page_content=content, - metadata=metadata or {"doc_id": doc_id} - ) - self._add_documents([doc], [doc_id]) - logger.debug(f"已索引文档: {doc_id}") + # 分割文档为小块 + if chunk_size is None: + chunk_size = self.DEFAULT_CHUNK_SIZE + if chunk_overlap is None: + chunk_overlap = self.DEFAULT_CHUNK_OVERLAP + + chunks = self._split_into_chunks(content, chunk_size, chunk_overlap) + + if not chunks: + logger.warning(f"文档内容为空,跳过索引: {doc_id}") + return + + # 为每个块创建文档对象 + documents = [] + chunk_ids = [] + + for i, chunk in enumerate(chunks): + chunk_id = f"{doc_id}_chunk_{i}" + chunk_metadata = metadata.copy() if metadata else {} + chunk_metadata.update({ + "chunk_index": i, + "total_chunks": len(chunks), + "doc_id": doc_id + }) + + documents.append(SimpleDocument( + page_content=chunk, + metadata=chunk_metadata + )) + chunk_ids.append(chunk_id) + + # 批量添加文档 + self._add_documents(documents, chunk_ids) + logger.info(f"已索引文档 {doc_id},共 {len(chunks)} 个块") def _add_documents(self, documents: List[SimpleDocument], doc_ids: List[str]): """批量添加文档到向量索引""" if not documents: return + # 总是将文档存储在内存中(用于关键词搜索后备) + for doc, did in zip(documents, doc_ids): + self.documents.append({"id": did, "content": doc.page_content, "metadata": doc.metadata}) + self.doc_ids.append(did) + + # 如果没有嵌入模型,跳过向量索引 + if self.embedding_model is None: + logger.debug(f"文档跳过向量索引 (无嵌入模型): {len(documents)} 个文档") + return + texts = [doc.page_content for doc in documents] embeddings = self.embedding_model.encode(texts, convert_to_numpy=True) embeddings = self._normalize_vectors(embeddings).astype('float32') @@ -162,12 +282,18 @@ class RAGService: id_array = np.array(id_list, dtype='int64') self.index.add_with_ids(embeddings, id_array) - for doc, did in zip(documents, doc_ids): - self.documents.append({"id": did, "content": doc.page_content, "metadata": doc.metadata}) - self.doc_ids.append(did) + def retrieve(self, query: str, top_k: int = 5, min_score: float = 0.3) -> List[Dict[str, Any]]: + """ + 根据查询检索相关文档块 - def retrieve(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: - """根据查询检索相关文档""" + Args: + query: 查询文本 + top_k: 返回的最大结果数 + min_score: 最低相似度分数阈值 + + Returns: + 相关文档块列表,每项包含 content, metadata, score, doc_id, chunk_index + """ if self._disabled: logger.info(f"[RAG DISABLED] 检索操作已跳过: query={query}, top_k={top_k}") return [] @@ -175,28 +301,113 @@ class RAGService: if not self._initialized: self._init_vector_store() - if self.index is None or self.index.ntotal == 0: + # 优先使用向量检索 + if self.index is not None and self.index.ntotal > 0 and self.embedding_model is not None: + try: + query_embedding = self.embedding_model.encode([query], convert_to_numpy=True) + query_embedding = self._normalize_vectors(query_embedding).astype('float32') + + scores, indices = self.index.search(query_embedding, min(top_k, self.index.ntotal)) + + results = [] + for score, idx in zip(scores[0], indices[0]): + if idx < 0: + continue + if score < min_score: + continue + doc = self.documents[idx] + results.append({ + "content": doc["content"], + "metadata": doc["metadata"], + "score": float(score), + "doc_id": doc["id"], + "chunk_index": doc["metadata"].get("chunk_index", 0) + }) + + if results: + logger.debug(f"向量检索到 {len(results)} 条相关文档块") + return results + except Exception as e: + logger.warning(f"向量检索失败,使用关键词搜索后备: {e}") + + # 后备:使用关键词搜索 + logger.debug("使用关键词搜索后备方案") + return self._keyword_search(query, top_k) + + def _keyword_search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: + """ + 关键词搜索后备方案 + + Args: + query: 查询文本 + top_k: 返回的最大结果数 + + Returns: + 相关文档块列表 + """ + if not self.documents: return [] - query_embedding = self.embedding_model.encode([query], convert_to_numpy=True) - query_embedding = self._normalize_vectors(query_embedding).astype('float32') + # 提取查询关键词 + keywords = [] + for char in query: + if '\u4e00' <= char <= '\u9fff': # 中文字符 + keywords.append(char) + # 添加英文单词 + import re + english_words = re.findall(r'[a-zA-Z]+', query) + keywords.extend(english_words) - scores, indices = self.index.search(query_embedding, min(top_k, self.index.ntotal)) + if not keywords: + return [] results = [] - for score, idx in zip(scores[0], indices[0]): - if idx < 0: - continue - doc = self.documents[idx] - results.append({ - "content": doc["content"], - "metadata": doc["metadata"], - "score": float(score), - "doc_id": doc["id"] - }) + for doc in self.documents: + content = doc["content"] + # 计算关键词匹配分数 + score = 0 + matched_keywords = 0 + for kw in keywords: + if kw in content: + score += 1 + matched_keywords += 1 - logger.debug(f"检索到 {len(results)} 条相关文档") - return results + if matched_keywords > 0: + # 归一化分数 + score = score / max(len(keywords), 1) + results.append({ + "content": content, + "metadata": doc["metadata"], + "score": score, + "doc_id": doc["id"], + "chunk_index": doc["metadata"].get("chunk_index", 0) + }) + + # 按分数排序 + results.sort(key=lambda x: x["score"], reverse=True) + + logger.debug(f"关键词搜索返回 {len(results[:top_k])} 条结果") + return results[:top_k] + + def retrieve_by_doc_id(self, doc_id: str, top_k: int = 10) -> List[Dict[str, Any]]: + """ + 获取指定文档的所有块 + + Args: + doc_id: 文档ID + top_k: 返回的最大结果数 + + Returns: + 该文档的所有块 + """ + # 获取属于该文档的所有块 + doc_chunks = [d for d in self.documents if d["metadata"].get("doc_id") == doc_id] + + # 按 chunk_index 排序 + doc_chunks.sort(key=lambda x: x["metadata"].get("chunk_index", 0)) + + # 返回指定数量 + return doc_chunks[:top_k] def retrieve_by_table(self, table_name: str, top_k: int = 5) -> List[Dict[str, Any]]: """检索指定表的字段""" diff --git a/backend/app/services/template_fill_service.py b/backend/app/services/template_fill_service.py index dfa5b20..9465d35 100644 --- a/backend/app/services/template_fill_service.py +++ b/backend/app/services/template_fill_service.py @@ -3,6 +3,7 @@ 从非结构化文档中检索信息并填写到表格模板 """ +import asyncio import logging from dataclasses import dataclass, field from typing import Any, Dict, List, Optional @@ -11,6 +12,7 @@ from app.core.database import mongodb from app.services.llm_service import llm_service from app.core.document_parser import ParserFactory from app.services.markdown_ai_service import markdown_ai_service +from app.services.rag_service import rag_service logger = logging.getLogger(__name__) @@ -43,6 +45,7 @@ class FillResult: value: Any = "" # 保留兼容 source: str = "" # 来源文档 confidence: float = 1.0 # 置信度 + warning: str = None # 多值提示 def __post_init__(self): if self.values is None: @@ -172,49 +175,30 @@ class TemplateFillService: if source_docs and template_fields: logger.info(f"表头看起来正常(非自动生成),无需重新生成: {[f.name for f in template_fields[:5]]}") - # 2. 对每个字段进行提取 + # 2. 并行提取所有字段(跳过AI审核以提升速度) + logger.info(f"开始并行提取 {len(template_fields)} 个字段...") + + # 并行处理所有字段 + tasks = [] for idx, field in enumerate(template_fields): - try: - logger.info(f"提取字段 [{idx+1}/{len(template_fields)}]: {field.name}") - # 从源文档中提取字段值 - result = await self._extract_field_value( - field=field, - source_docs=source_docs, - user_hint=user_hint - ) + task = self._extract_single_field_fast( + field=field, + source_docs=source_docs, + user_hint=user_hint, + field_idx=idx, + total_fields=len(template_fields) + ) + tasks.append(task) - # AI审核:验证提取的值是否合理 - if result.values and result.values[0]: - logger.info(f"字段 {field.name} 进入AI审核阶段...") - verified_result = await self._verify_field_value( - field=field, - extracted_values=result.values, - source_docs=source_docs, - user_hint=user_hint - ) - if verified_result: - # 审核给出了修正结果 - result = verified_result - logger.info(f"字段 {field.name} 审核后修正值: {result.values[:3]}") - else: - logger.info(f"字段 {field.name} 审核通过,使用原提取结果") + # 等待所有任务完成 + results = await asyncio.gather(*tasks, return_exceptions=True) - # 存储结果 - 使用 values 数组 - filled_data[field.name] = result.values if result.values else [""] - fill_details.append({ - "field": field.name, - "cell": field.cell, - "values": result.values, - "value": result.value, - "source": result.source, - "confidence": result.confidence - }) - - logger.info(f"字段 {field.name} 填写完成: {len(result.values)} 个值") - - except Exception as e: - logger.error(f"填写字段 {field.name} 失败: {str(e)}", exc_info=True) - filled_data[field.name] = [f"[提取失败: {str(e)}]"] + # 处理结果 + for idx, result in enumerate(results): + field = template_fields[idx] + if isinstance(result, Exception): + logger.error(f"提取字段 {field.name} 失败: {str(result)}") + filled_data[field.name] = [f"[提取失败: {str(result)}]"] fill_details.append({ "field": field.name, "cell": field.cell, @@ -223,6 +207,18 @@ class TemplateFillService: "source": "error", "confidence": 0.0 }) + else: + filled_data[field.name] = result.values if result.values else [""] + fill_details.append({ + "field": field.name, + "cell": field.cell, + "values": result.values, + "value": result.value, + "source": result.source, + "confidence": result.confidence, + "warning": result.warning + }) + logger.info(f"字段 {field.name} 填写完成: {len(result.values) if result.values else 0} 个值") # 计算最大行数 max_rows = max(len(v) for v in filled_data.values()) if filled_data else 1 @@ -551,6 +547,222 @@ class TemplateFillService: confidence=0.0 ) + async def _extract_single_field_fast( + self, + field: TemplateField, + source_docs: List[SourceDocument], + user_hint: Optional[str] = None, + field_idx: int = 0, + total_fields: int = 1 + ) -> FillResult: + """ + 快速提取单个字段(跳过AI审核,减少LLM调用) + + Args: + field: 字段定义 + source_docs: 源文档列表 + user_hint: 用户提示 + field_idx: 当前字段索引(用于日志) + total_fields: 总字段数(用于日志) + + Returns: + 提取结果 + """ + try: + if not source_docs: + return FillResult( + field=field.name, + value="", + values=[""], + source="无源文档", + confidence=0.0 + ) + + # 1. 优先尝试直接从结构化数据中提取(最快路径) + direct_values = self._extract_values_from_structured_data(source_docs, field.name) + if direct_values: + logger.info(f"✅ [{field_idx+1}/{total_fields}] 字段 {field.name} 直接从结构化数据提取到 {len(direct_values)} 个值") + return FillResult( + field=field.name, + values=direct_values, + value=direct_values[0] if direct_values else "", + source="结构化数据直接提取", + confidence=1.0 + ) + + # 2. 无法直接从结构化数据提取,使用简化版AI提取 + logger.info(f"🔍 [{field_idx+1}/{total_fields}] 字段 {field.name} 尝试AI提取...") + + # 构建提示词 - 简化版 + hint_text = field.hint if field.hint else f"请提取{field.name}的信息" + if user_hint: + hint_text = f"{user_hint}。{hint_text}" + + # 优先使用 RAG 检索内容,否则使用文档开头部分 + context_parts = [] + for doc in source_docs: + if not doc.content: + logger.info(f" 文档 {doc.filename} 无content内容") + continue + + logger.info(f" 处理文档: {doc.filename}, doc_id={doc.doc_id}, content长度={len(doc.content)}") + + # 尝试 RAG 检索 + rag_results = rag_service.retrieve( + query=f"{field.name} {hint_text}", + top_k=3, + min_score=0.1 + ) + + if rag_results: + logger.info(f" RAG检索到 {len(rag_results)} 条结果") + # 使用 RAG 检索到的内容 + for r in rag_results: + rag_doc_id = r.get("doc_id", "") + if rag_doc_id.startswith(doc.doc_id): + context_parts.append(r["content"]) + logger.info(f" 匹配成功,使用RAG内容,长度={len(r['content'])}") + else: + # RAG 没结果,使用文档内容开头 + context_parts.append(doc.content[:2500]) + logger.info(f" RAG无结果,使用文档开头 {min(2500, len(doc.content))} 字符") + + context = "\n\n".join(context_parts[:3]) if context_parts else "" + logger.info(f" 最终context长度: {len(context)}, 内容预览: {context[:200] if context else '无'}...") + + prompt = f"""你是一个专业的数据提取专家。请严格按照表头字段「{field.name}」从文档中提取数据。 + +提示: {hint_text} + +【重要规则 - 必须遵守】 +1. **每个值必须有标注**:根据数据来源添加合适的标注前缀! + - ✅ 正确格式: + - "2024年:38710个" + - "北京:1234万人次" + - "某省:5678万人" + - "公立医院:11754个" + - "三级医院:4111个" + - "图书馆:3246个" + - ❌ 错误格式:"38710个"(缺少标注) + +2. **标注类型根据数据决定**: + - 年份类数据 → "2024年:xxx"、"2023年:xxx" + - 地区类数据 → "北京:xxx"、"广东:xxx"、"某县:xxx" + - 机构/分类数据 → "公立医院:xxx"、"三级医院:xxx"、"图书馆:xxx" + - 其他分类 → 根据实际情况标注 + +3. **严格按表头提取**:只提取与「{field.name}」直接相关的数据 + +4. **多值必须全部提取并标注**:如果文档中提到多个相关数据,每个都要有标注 + +文档内容: +{context if context else "(无文档内容)"} + +请严格按格式返回JSON:{{"values": ["标注:数值", "标注:数值", ...]}} +注意:values数组中每个元素都必须包含标注前缀,不能只有数值! +""" + + messages = [ + {"role": "system", "content": "你是一个专业的数据提取助手,擅长从政府统计公报中提取数据。严格按JSON格式输出,只返回values数组。"}, + {"role": "user", "content": prompt} + ] + + response = await self.llm.chat( + messages=messages, + temperature=0.1, + max_tokens=1000 + ) + + content = self.llm.extract_message_content(response) + logger.info(f" LLM原始返回: {content[:500]}") + + # 解析JSON + import json + import re + cleaned = content.strip() + + # 查找JSON开始位置 + json_start = -1 + for i, c in enumerate(cleaned): + if c == '{': + json_start = i + break + + values = [] + source = "AI提取" + if json_start >= 0: + try: + json_text = cleaned[json_start:] + result = json.loads(json_text) + values = result.get("values", []) + logger.info(f" JSON解析成功,values: {values}") + except json.JSONDecodeError as e: + logger.warning(f" JSON解析失败: {e},尝试修复...") + # 尝试修复常见JSON问题 + try: + # 尝试找到values数组 + values_match = re.search(r'"values"\s*:\s*\[(.*?)\]', cleaned, re.DOTALL) + if values_match: + values_str = values_match.group(1) + # 提取数组中的字符串 + values = re.findall(r'"([^"]*)"', values_str) + logger.info(f" 正则提取values: {values}") + except: + pass + + # 如果values为空,尝试从文本中用正则提取数字+单位 + if not values or values == [""]: + logger.info(f" JSON解析未获取到值,尝试正则提取...") + # 匹配数字+单位或百分号的模式 + patterns = [ + r'(\d+\.?\d*[亿万千百十个]?[%‰℃℃万元亿]?)', # 通用数字+单位 + r'(\d+\.?\d*%)', # 百分号 + r'(\d+\.?\d*[个万人亿元]?)', # 中文单位 + ] + for pattern in patterns: + matches = re.findall(pattern, context) + if matches: + # 过滤掉纯数字 + filtered = [m for m in matches if not m.replace('.', '').isdigit()] + if filtered: + values = filtered[:10] # 最多取10个 + logger.info(f" 正则提取到: {values}") + break + + if not values or values == [""]: + values = self._extract_values_by_regex(cleaned) + + if not values: + values = [""] + + # 生成多值提示(基于实际检测到的值数量) + warning = "" + if len(values) > 1: + warning = f"⚠️ 检测到 {len(values)} 个值:{values[:5]}{'...' if len(values) > 5 else ''}" + + logger.info(f"✅ [{field_idx+1}/{total_fields}] 字段 {field.name} AI提取完成: {len(values)} 个值") + if warning: + logger.info(f" {warning}") + + return FillResult( + field=field.name, + values=values, + value=values[0] if values else "", + source=source, + confidence=0.8, + warning=warning if warning else None + ) + + except Exception as e: + logger.error(f"❌ [{field_idx+1}/{total_fields}] 字段 {field.name} 提取失败: {str(e)}") + return FillResult( + field=field.name, + values=[""], + value="", + source=f"提取失败: {str(e)}", + confidence=0.0 + ) + async def _verify_field_value( self, field: TemplateField, @@ -1172,13 +1384,148 @@ class TemplateFillService: values = [] for row in rows: - if isinstance(row, list) and target_idx < len(row): - val = row[target_idx] + if isinstance(row, list): + # 跳过子表头行(主要包含年份值的行,如 "1985", "1995") + if self._is_year_subheader_row(row): + logger.info(f"跳过子表头行: {row[:5]}...") + continue + # 跳过章节标题行 + if self._is_section_header_row(row): + logger.info(f"跳过章节标题行: {row[:5]}...") + continue + if target_idx < len(row): + val = row[target_idx] + else: + val = "" else: val = "" values.append(self._format_value(val)) - return values + # 过滤掉无效值(章节标题、省略号等) + valid_values = self._filter_valid_values(values) + if len(valid_values) < len(values): + logger.info(f"过滤无效值: {len(values)} -> {len(valid_values)}") + + return valid_values + + def _is_year_subheader_row(self, row: List) -> bool: + """ + 检测行是否看起来像年份子表头行 + + 年份子表头行通常包含 "1985", "1995", "2020" 等4位数字 + + Args: + row: 行数据 + + Returns: + 是否是年份子表头行 + """ + if not row: + return False + + import re + year_pattern = re.compile(r'^(19|20)\d{2}$') + + # 计算看起来像年份的单元格数量 + year_like_count = 0 + for cell in row: + cell_str = str(cell).strip() + if year_pattern.match(cell_str): + year_like_count += 1 + + # 如果超过50%的单元格是年份格式,认为是子表头行 + if len(row) > 0 and year_like_count / len(row) > 0.5: + return True + + return False + + def _is_section_header_row(self, row: List) -> bool: + """ + 检测行是否看起来像章节标题行 + + 章节标题行通常包含 "其中:"、"全部工业中:"、"按...计算" 等关键词 + + Args: + row: 行数据 + + Returns: + 是否是章节标题行 + """ + if not row: + return False + + import re + # 章节标题通常包含这些模式 + section_patterns = [ + r'其中[::]', + r'全部\w+中[::]', + r'按\w+计算', + r'小计', + r'合计', + r'总计', + r'^其中$', + r'全部$' + ] + + for cell in row: + cell_str = str(cell).strip() + if not cell_str: + continue + for pattern in section_patterns: + if re.search(pattern, cell_str): + return True + + return False + + def _is_valid_data_value(self, val: str) -> bool: + """ + 检测值是否是有效的数据值(不是章节标题、省略号等) + + Args: + val: 值字符串 + + Returns: + 是否是有效数据值 + """ + if not val or not str(val).strip(): + return False + + val_str = str(val).strip() + + # 无效模式 + invalid_patterns = [ + r'^…$', # 省略号 + r'^[\.。]+$', # 只有点或句号 + r'其中[::]', # 章节标题 + r'全部\w+中', # 章节标题 + r'按\w+计算', # 计算类型 + r'^(小计|合计|总计)$', # 汇总行 + r'^其中$', + r'^全部$' + ] + + for pattern in invalid_patterns: + import re + if re.match(pattern, val_str): + return False + + return True + + def _filter_valid_values(self, values: List[str]) -> List[str]: + """ + 过滤出有效的数据值 + + Args: + values: 值列表 + + Returns: + 只包含有效值的列表 + """ + valid_values = [] + for val in values: + if self._is_valid_data_value(val): + valid_values.append(val) + return valid_values def _find_best_matching_column(self, headers: List, field_name: str) -> Optional[int]: """ @@ -1206,6 +1553,11 @@ class TemplateFillService: header_str = str(header).strip() header_lower = header_str.lower() + # 跳过空表头(第一列为空的情况) + if not header_str: + logger.info(f"跳过空表头列: 索引 {idx}") + continue + # 策略1: 精确匹配(忽略大小写) if header_lower == field_lower: return idx @@ -1262,6 +1614,12 @@ class TemplateFillService: values = [] for row in rows: + # 跳过子表头行(主要包含年份值的行,如 "1985", "1995") + if isinstance(row, list) and self._is_year_subheader_row(row): + continue + # 跳过章节标题行 + if isinstance(row, list) and self._is_section_header_row(row): + continue if isinstance(row, dict): val = row.get(target_col, "") elif isinstance(row, list) and target_idx < len(row): @@ -1270,7 +1628,12 @@ class TemplateFillService: val = "" values.append(self._format_value(val)) - return values + # 过滤掉无效值(章节标题、省略号等) + valid_values = self._filter_valid_values(values) + if len(valid_values) < len(values): + logger.info(f"过滤无效值: {len(values)} -> {len(valid_values)}") + + return valid_values def _format_value(self, val: Any) -> str: """ @@ -1604,38 +1967,98 @@ class TemplateFillService: if user_hint: hint_text = f"{user_hint}。{hint_text}" - # 构建针对字段提取的提示词 - prompt = f"""你是一个专业的数据提取专家。请从以下文档内容中提取与"{field.name}"完全匹配的数据。 + # 构建查询文本 + query_text = f"{field.name} {hint_text}" + + # 使用 RAG 向量检索获取相关内容块 + rag_results = rag_service.retrieve( + query=query_text, + top_k=5, + min_score=0.3 + ) + + # 构建上下文:优先使用 RAG 检索结果,如果检索不到则使用原始内容 + if rag_results: + # 使用 RAG 检索到的相关块 + context_parts = [] + for result in rag_results: + if result.get("doc_id", "").startswith(doc.doc_id) or not result.get("doc_id"): + context_parts.append(result["content"]) + + if context_parts: + retrieved_context = "\n\n---\n\n".join(context_parts) + logger.info(f"RAG 检索到 {len(context_parts)} 个相关块用于字段 {field.name}") + # 使用检索到的内容(限制长度) + context_to_use = retrieved_context[:6000] + else: + # RAG 检索结果不属于当前文档,使用原始内容 + context_to_use = doc.content[:6000] if doc.content else "" + logger.info(f"字段 {field.name} 使用原始内容(RAG结果不属于当前文档)") + else: + # 没有 RAG 检索结果,使用原始内容 + context_to_use = doc.content[:6000] if doc.content else "" + logger.info(f"字段 {field.name} 使用原始内容(无RAG检索结果)") + + # 构建针对字段提取的提示词 - 增强语义匹配能力 + prompt = f"""你是一个专业的数据提取专家。请从以下文档内容中进行**语义匹配**提取。 【重要】字段名: "{field.name}" 【重要】字段提示: {hint_text} -请严格按照以下步骤操作: -1. 在文档中搜索与"{field.name}"完全相同或高度相关的关键词 -2. 找到后,提取该关键词后的数值(注意:只要数值,不要单位) -3. 如果是表格中的数据,直接提取该单元格的数值 -4. 如果是段落描述,在关键词附近找数值 +## 分类数据识别 + +文档中经常包含分类统计数据,格式如下: + +### 1. 直接分类(用"其中:"、"中,"等分隔) +原文示例: +- "全国医疗卫生机构总数1093551个,其中:医院38710个,基层医疗卫生机构1040023个" + → 字段"医院数量" 应提取: 38710 + → 字段"基层医疗卫生机构数量" 应提取: 1040023 + +- "医院中,公立医院11754个,民营医院26956个" + → 字段"公立医院数量" 应提取: 11754 + → 字段"民营医院数量" 应提取: 26956 + +### 2. 嵌套分类(用"按...分:"、"其中:"等结构) +原文示例: +- "医院按等级分:三级医院4111个(其中:三级甲等医院1876个),二级医院12294个" + → 字段"三级医院数量" 应提取: 4111 + → 字段"三级甲等医院数量" 应提取: 1876 + → 字段"二级医院数量" 应提取: 12294 + +### 3. 匹配技巧 +- "医院数量" 可匹配: "医院38710个"、"医院数量为" +- "公立医院数量" 可匹配: "公立医院11754个"、"公立医院有" +- 忽略"数量"、"数"、"个"等后缀的差异 +- 数值可能紧跟关键词,也可能分开描述 + +## 提取规则 + +1. **全文搜索**:在文档的全部内容中搜索,不要只搜索开头部分 +2. **分类定位**:找到包含该分类关键词的句子,理解其完整的数值 +3. **保留单位**:提取数值时**要包含单位** 【重要】返回值规则: -- 只返回纯数值,不要单位(如 "4.9" 而不是 "4.9万亿元") -- 如果原文是"4.9万亿元",返回 "4.9" -- 如果原文是"144000万册",返回 "144000" -- 如果是百分比如"增长7.7%",返回 "7.7" -- 如果没有找到完全匹配的数据,返回空数组 +- **返回数值时必须包含单位** +- 例如原文"公共图书馆3246个",提取时应返回 "3246个" +- 例如原文"国内旅游收入4.9万亿元",提取时应返回 "4.9万亿元" +- 例如原文"注册护士585.5万人",提取时应返回 "585.5万人" +- 如果字段是"指标"类型,返回具体的指标名称文本(不带单位) +- 如果没有找到任何相关数据,返回空数组 文档内容: -{doc.content[:10000] if doc.content else ""} +{context_to_use} 请用严格的 JSON 格式返回: {{ - "values": ["值1", "值2", ...], // 只填数值,不要单位 - "source": "数据来源说明", + "values": ["提取到的值1", "值2", ...], + "source": "数据来源说明(如:从文档第X段提取)", "confidence": 0.0到1.0之间的置信度 }} -示例: -- 如果字段是"图书馆总藏量(万册)"且文档说"图书总藏量14.4亿册",返回 values: ["144000"] -- 如果字段是"国内旅游收入(亿元)"且文档说"国内旅游收入4.9万亿元",返回 values: ["49000"]""" +【重要】即使是模糊匹配,也要: +- 确保提取的内容确实来自文档 +- source中准确说明数据来源位置""" messages = [ {"role": "system", "content": "你是一个专业的数据提取助手,擅长从政府统计公报等文档中提取数据。请严格按JSON格式输出。"}, @@ -1790,31 +2213,45 @@ class TemplateFillService: source_info += f"【包含表格数】: {tables_count}\n" if tables_summary: source_info += f"{tables_summary}\n" - elif content: - source_info += f"【内容预览】: {content[:1500]}...\n" + if content: + source_info += f"【文档内容】(前3000字符):{content[:3000]}\n" - prompt = f"""你是一个专业的表格设计助手。请根据源文档内容生成合适的表格表头字段。 + prompt = f"""你是一个专业的数据分析助手。请分析源文档中的所有数据,生成表格表头字段。 -任务:用户有一些源文档(包含表格数据),需要填写到空白表格模板中。源文档中的表格如下: +任务:分析源文档,找出所有具体的数据指标及其分类。 {source_info} 【重要要求】 -1. 请仔细阅读上面的源文档表格,找出所有不同的列名(如"产品名称"、"1995年产量"、"按资产总额计算(%)"等) -2. 直接使用这些实际的列名作为表头字段名,不要生成新的或同义词 -3. 如果一个源文档有多个表格,请为每个表格选择合适的列名 -4. 生成3-8个表头字段,优先选择数据量大的表格的列 +1. **只生成数据字段名**: + - ✅ 正确示例:"医院数量"、"公立医院数量"、"病床使用率" + - ❌ 错误示例:"source"、"备注"、"说明"、"数据来源" + +2. **识别所有数值数据**: + - 例如:"医院38710个"、"病床使用率78.8%" + - 例如:"公立医院11754个"、"公立医院病床使用率84.8%" + +3. **理解分类层级**: + - 顶级分类:如"医院"、"基层医疗卫生机构" + - 二级分类:如"医院"下分为"公立医院"、"民营医院" + +4. **生成字段**: + - 字段名要简洁,如:"医院数量"、"病床使用率" + - 优先选择:总数 + 主要分类 + +5. **生成数量**: + - 生成5-7个最有代表性的字段 请严格按照以下 JSON 格式输出(只需输出 JSON,不要其他内容): {{ "fields": [ - {{"name": "实际列名1", "hint": "对该列的说明"}}, - {{"name": "实际列名2", "hint": "对该列的说明"}} + {{"name": "字段名1"}}, + {{"name": "字段名2"}} ] }} """ messages = [ - {"role": "system", "content": "你是一个专业的表格设计助手。请严格按JSON格式输出。"}, + {"role": "system", "content": "你是一个专业的表格设计助手。请严格按JSON格式输出,只返回纯数据字段名,不要source、备注、说明等辅助字段。"}, {"role": "user", "content": prompt} ] @@ -1853,14 +2290,22 @@ class TemplateFillService: if result and "fields" in result: fields = [] + # 过滤非数据字段 + skip_keywords = ["source", "来源", "备注", "说明", "备注列", "说明列", "data_source", "remark", "note", "description"] for idx, f in enumerate(result["fields"]): + field_name = f.get("name", f"字段{idx+1}") + # 跳过非数据字段 + if any(kw in field_name.lower() for kw in skip_keywords): + logger.info(f"跳过非数据字段: {field_name}") + continue fields.append(TemplateField( cell=self._column_to_cell(idx), - name=f.get("name", f"字段{idx+1}"), + name=field_name, field_type="text", required=False, hint=f.get("hint", "") )) + logger.info(f"AI 生成表头: {[f.name for f in fields]}") return fields except Exception as e: diff --git a/frontend/src/pages/Documents.tsx b/frontend/src/pages/Documents.tsx index afeb54d..aa666d9 100644 --- a/frontend/src/pages/Documents.tsx +++ b/frontend/src/pages/Documents.tsx @@ -766,6 +766,7 @@ const Documents: React.FC = () => {
e.stopPropagation()} > diff --git a/frontend/src/pages/TemplateFill.tsx b/frontend/src/pages/TemplateFill.tsx index 0f7fe88..633604c 100644 --- a/frontend/src/pages/TemplateFill.tsx +++ b/frontend/src/pages/TemplateFill.tsx @@ -626,6 +626,16 @@ const TemplateFill: React.FC = () => {
来源: {detail.source} | 置信度: {detail.confidence ? (detail.confidence * 100).toFixed(0) + '%' : 'N/A'}
+ {detail.warning && ( +
+ ⚠️ {detail.warning} +
+ )} + {detail.values && detail.values.length > 1 && !detail.warning && ( +
+ 多值: {detail.values.join(', ')} +
+ )}
))}