feat: 实现智能指令的格式转换和文档编辑功能
主要更新: - 新增 transform 意图:支持 Word/Excel/Markdown 格式互转 - 新增 edit 意图:使用 LLM 润色编辑文档内容 - 智能指令接口增加异步执行模式(async_execute 参数) - 修复 Word 模板导出文档损坏问题(改用临时文件方式) - 优化 intent_parser 增加 transform/edit 关键词识别 新增文件: - app/api/endpoints/instruction.py: 智能指令 API 端点 - app/services/multi_doc_reasoning_service.py: 多文档推理服务 其他优化: - RAG 服务混合搜索(BM25 + 向量)融合 - 模板填充服务表头匹配增强 - Word AI 解析服务返回结构完善 - 前端 InstructionChat 组件对接真实 API
This commit is contained in:
446
backend/app/services/multi_doc_reasoning_service.py
Normal file
446
backend/app/services/multi_doc_reasoning_service.py
Normal file
@@ -0,0 +1,446 @@
|
||||
"""
|
||||
多文档关联推理服务
|
||||
|
||||
跨文档信息关联和推理
|
||||
"""
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
from collections import defaultdict
|
||||
|
||||
from app.services.llm_service import llm_service
|
||||
from app.services.rag_service import rag_service
|
||||
|
||||
logger = logging.getLogger(__name__)


class MultiDocReasoningService:
    """Multi-document cross-reasoning service.

    Capabilities:
    1. Cross-document entity tracking - follow the same entity across documents.
    2. Relation extraction and reasoning between entities.
    3. Information completion - fill missing values from complementary documents.
    4. Conflict detection - flag contradictory values between documents.
    """

    def __init__(self):
        # Shared project LLM client singleton (imported at module level).
        self.llm = llm_service

    async def analyze_cross_documents(
        self,
        documents: List[Dict[str, Any]],
        query: Optional[str] = None,
        entity_types: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Run the full cross-document analysis pipeline.

        Args:
            documents: list of parsed document dicts (expects "content",
                optional "_id" and "metadata").
            query: optional query intent. NOTE(review): currently accepted but
                not used by the pipeline - confirm whether it should steer
                entity extraction.
            entity_types: entity types to track, e.g. ["机构", "人物", "地点", "数量"].

        Returns:
            Dict with "success" plus entities / relations / knowledge_graph /
            completed_info / conflicts / summary, or an "error" message.
        """
        if not documents:
            return {"success": False, "error": "没有可用的文档"}

        entity_types = entity_types or ["机构", "数量", "时间", "地点"]

        try:
            # 1. Extract entities per document via the LLM.
            entities_per_doc = await self._extract_entities_from_docs(documents, entity_types)

            # 2. Align entities that appear in more than one document.
            aligned_entities = self._align_entities_across_docs(entities_per_doc)

            # 3. Extract inter-entity relations from the combined text.
            relations = await self._extract_relations(documents)

            # 4. Build a small knowledge graph (nodes + edges + stats).
            knowledge_graph = self._build_knowledge_graph(aligned_entities, relations)

            # 5. Try to fill in missing values via RAG retrieval.
            completed_info = await self._complete_missing_info(knowledge_graph, documents)

            # 6. Detect value conflicts between documents.
            conflicts = self._detect_conflicts(aligned_entities)

            return {
                "success": True,
                "entities": aligned_entities,
                "relations": relations,
                "knowledge_graph": knowledge_graph,
                "completed_info": completed_info,
                "conflicts": conflicts,
                "summary": self._generate_summary(aligned_entities, conflicts)
            }

        except Exception as e:
            logger.error(f"跨文档分析失败: {e}")
            return {"success": False, "error": str(e)}

    async def _extract_entities_from_docs(
        self,
        documents: List[Dict[str, Any]],
        entity_types: List[str]
    ) -> List[Dict[str, Any]]:
        """Extract entities of the requested types from each document.

        Documents whose LLM call or JSON parse fails are skipped (logged as a
        warning), so the result may contain fewer entries than ``documents``.
        """
        entities_per_doc = []

        for idx, doc in enumerate(documents):
            doc_id = doc.get("_id", f"doc_{idx}")
            content = doc.get("content", "")[:8000]  # cap prompt size

            # Ask the LLM for a strict-JSON entity list.
            prompt = f"""从以下文档中提取指定的实体类型信息。

实体类型: {', '.join(entity_types)}

文档内容:
{content}

请按以下 JSON 格式输出(只需输出 JSON):
{{
  "entities": [
    {{"type": "机构", "name": "实体名称", "value": "相关数值(如有)", "context": "上下文描述"}},
    ...
  ]
}}

只提取在文档中明确提到的实体,不要推测。"""

            messages = [
                {"role": "system", "content": "你是一个实体提取专家。请严格按JSON格式输出。"},
                {"role": "user", "content": prompt}
            ]

            try:
                response = await self.llm.chat(messages=messages, temperature=0.1, max_tokens=3000)
                content_response = self.llm.extract_message_content(response)

                # Tolerant JSON parse: grab the outermost {...} span, since the
                # model may wrap the JSON in prose or code fences.
                import json
                cleaned = content_response.strip()
                json_match = re.search(r'\{[\s\S]*\}', cleaned)
                if json_match:
                    result = json.loads(json_match.group())
                    entities = result.get("entities", [])
                    entities_per_doc.append({
                        "doc_id": doc_id,
                        "doc_name": doc.get("metadata", {}).get("original_filename", f"文档{idx+1}"),
                        "entities": entities
                    })
                    logger.info(f"文档 {doc_id} 提取到 {len(entities)} 个实体")
            except Exception as e:
                logger.warning(f"文档 {doc_id} 实体提取失败: {e}")

        return entities_per_doc

    def _align_entities_across_docs(
        self,
        entities_per_doc: List[Dict[str, Any]]
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Align mentions of the same entity across documents.

        Groups entity mentions by normalized name and keeps only entities that
        appear in more than one document (single-document entities are dropped).

        Returns:
            Mapping of normalized entity name -> list of appearance records
            (original name, type, value, context, source doc name/id).
        """
        aligned: Dict[str, List[Dict[str, Any]]] = defaultdict(list)

        for doc_data in entities_per_doc:
            doc_id = doc_data["doc_id"]
            doc_name = doc_data["doc_name"]

            for entity in doc_data.get("entities", []):
                entity_name = entity.get("name", "")
                if not entity_name:
                    continue

                # Normalize (strip whitespace, parenthesised suffixes, ranks)
                # so spelling variants bucket together.
                normalized = self._normalize_entity_name(entity_name)

                aligned[normalized].append({
                    "original_name": entity_name,
                    "type": entity.get("type", "未知"),
                    "value": entity.get("value", ""),
                    "context": entity.get("context", ""),
                    "source_doc": doc_name,
                    "source_doc_id": doc_id
                })

        # Keep only entities mentioned by at least two appearances.
        result = {}
        for normalized, appearances in aligned.items():
            if len(appearances) > 1:
                result[normalized] = appearances
                logger.info(f"实体对齐: {normalized} 在 {len(appearances)} 个文档中出现")

        return result

    def _normalize_entity_name(self, name: str) -> str:
        """Normalize an entity name for cross-document matching."""
        # Strip surrounding whitespace.
        name = name.strip()
        # Drop parenthesised qualifiers (both ASCII and fullwidth brackets).
        name = re.sub(r'[((].*?[))]', '', name)
        # Drop leading ordinal prefixes like "第3名" / "第1位".
        name = re.sub(r'^第\d+[名位个]', '', name)
        return name.strip()

    async def _extract_relations(
        self,
        documents: List[Dict[str, Any]]
    ) -> List[Dict[str, str]]:
        """Extract entity-to-entity relations from the combined documents.

        Returns a list of {entity1, entity2, relation, description} dicts;
        empty on LLM/parse failure (logged as a warning).
        """
        relations = []

        # Concatenate per-document excerpts, each tagged with its filename.
        combined_content = "\n\n".join([
            f"【{doc.get('metadata', {}).get('original_filename', f'文档{i}')}】\n{doc.get('content', '')[:3000]}"
            for i, doc in enumerate(documents)
        ])

        prompt = f"""从以下文档内容中抽取实体之间的关系。

文档内容:
{combined_content[:8000]}

请识别以下类型的关系:
- 包含关系 (A包含B)
- 隶属关系 (A隶属于B)
- 合作关系 (A与B合作)
- 对比关系 (A vs B)
- 时序关系 (A先于B发生)

请按以下 JSON 格式输出(只需输出 JSON):
{{
  "relations": [
    {{"entity1": "实体1", "entity2": "实体2", "relation": "关系类型", "description": "关系描述"}},
    ...
  ]
}}

如果没有找到明确的关系,返回空数组。"""

        messages = [
            {"role": "system", "content": "你是一个关系抽取专家。请严格按JSON格式输出。"},
            {"role": "user", "content": prompt}
        ]

        try:
            response = await self.llm.chat(messages=messages, temperature=0.1, max_tokens=3000)
            content_response = self.llm.extract_message_content(response)

            import json
            cleaned = content_response.strip()
            # FIX: the pattern was r'\{{[\s\S]*\}}' (f-string brace doubling
            # left in a plain regex), which only matches literal "{{...}}" and
            # therefore never matched real JSON output - relations were always
            # empty. Match a single-brace JSON object instead.
            json_match = re.search(r'\{[\s\S]*\}', cleaned)
            if json_match:
                result = json.loads(json_match.group())
                relations = result.get("relations", [])
                logger.info(f"抽取到 {len(relations)} 个关系")
        except Exception as e:
            logger.warning(f"关系抽取失败: {e}")

        return relations

    def _build_knowledge_graph(
        self,
        aligned_entities: Dict[str, List[Dict[str, Any]]],
        relations: List[Dict[str, str]]
    ) -> Dict[str, Any]:
        """Assemble a lightweight knowledge graph from entities and relations.

        NOTE(review): node "id" values are synthetic ("entity_N") while edge
        "source"/"target" carry normalized entity *names* - confirm that graph
        consumers key edges by name rather than by node id.
        """
        nodes = []
        edges = []
        node_ids = set()

        # One node per aligned entity; the first appearance provides the type.
        for entity_name, appearances in aligned_entities.items():
            if len(appearances) < 1:
                continue

            first_appearance = appearances[0]
            node_id = f"entity_{len(nodes)}"

            # Prefer the first non-empty value seen across documents.
            values = [a.get("value", "") for a in appearances if a.get("value")]
            primary_value = values[0] if values else ""

            nodes.append({
                "id": node_id,
                "name": entity_name,
                "type": first_appearance.get("type", "未知"),
                "value": primary_value,
                "occurrence_count": len(appearances),
                "sources": [a.get("source_doc", "") for a in appearances]
            })
            node_ids.add(entity_name)

        # Keep only edges whose both endpoints are known entities.
        for relation in relations:
            entity1 = self._normalize_entity_name(relation.get("entity1", ""))
            entity2 = self._normalize_entity_name(relation.get("entity2", ""))

            if entity1 in node_ids and entity2 in node_ids:
                edges.append({
                    "source": entity1,
                    "target": entity2,
                    "relation": relation.get("relation", "相关"),
                    "description": relation.get("description", "")
                })

        return {
            "nodes": nodes,
            "edges": edges,
            "stats": {
                "entity_count": len(nodes),
                "relation_count": len(edges)
            }
        }

    async def _complete_missing_info(
        self,
        knowledge_graph: Dict[str, Any],
        documents: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Fill missing entity values via RAG retrieval.

        For entities seen in multiple documents but lacking a value, query the
        RAG index and record the best supporting snippet with its score.
        """
        completed = []

        for node in knowledge_graph.get("nodes", []):
            if not node.get("value") and node.get("occurrence_count", 0) > 1:
                # Multi-document entity without a value: try RAG lookup.
                query = f"{node['name']} 数值 数据"
                results = rag_service.retrieve(query, top_k=3, min_score=0.3)

                if results:
                    completed.append({
                        "entity": node["name"],
                        "type": node.get("type", "未知"),
                        "source": "rag_inference",
                        "context": results[0].get("content", "")[:200],
                        "confidence": results[0].get("score", 0)
                    })

        return completed

    def _detect_conflicts(
        self,
        aligned_entities: Dict[str, List[Dict[str, Any]]]
    ) -> List[Dict[str, Any]]:
        """Detect value conflicts for the same entity across documents.

        A conflict is recorded when at least two source documents report
        different non-empty values for one aligned entity.
        """
        conflicts = []

        for entity_name, appearances in aligned_entities.items():
            if len(appearances) < 2:
                continue

            # Collect one value per source document (later docs overwrite).
            values = {}
            for appearance in appearances:
                val = appearance.get("value", "")
                if val:
                    source = appearance.get("source_doc", "未知来源")
                    values[source] = val

            if len(values) > 1:
                unique_values = set(values.values())
                if len(unique_values) > 1:
                    conflicts.append({
                        "entity": entity_name,
                        "type": "value_conflict",
                        "details": values,
                        "description": f"实体 '{entity_name}' 在不同文档中有不同数值: {values}"
                    })

        return conflicts

    def _generate_summary(
        self,
        aligned_entities: Dict[str, List[Dict[str, Any]]],
        conflicts: List[Dict[str, Any]]
    ) -> str:
        """Build a short human-readable summary of the analysis."""
        summary_parts = []

        # Note: counts appearances (mentions), not distinct entities.
        total_entities = sum(len(appearances) for appearances in aligned_entities.values())
        multi_doc_entities = sum(1 for appearances in aligned_entities.values() if len(appearances) > 1)

        summary_parts.append(f"跨文档分析完成:发现 {total_entities} 个实体")
        summary_parts.append(f"其中 {multi_doc_entities} 个实体在多个文档中被提及")

        if conflicts:
            summary_parts.append(f"检测到 {len(conflicts)} 个潜在冲突")

        return "; ".join(summary_parts)

    async def answer_cross_doc_question(
        self,
        question: str,
        documents: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Answer a question using cross-document analysis as context.

        Args:
            question: the user question.
            documents: list of parsed document dicts.

        Returns:
            {"success": True, "question", "answer", "supporting_entities",
            "relations_count"} or {"success": False, "error": ...}.
        """
        # Run the analysis pipeline first to gather entities and relations.
        analysis_result = await self.analyze_cross_documents(documents, query=question)

        # Build an evidence context from aligned entities and relations.
        context_parts = []

        for entity_name, appearances in analysis_result.get("entities", {}).items():
            contexts = [f"{a.get('source_doc')}: {a.get('context', '')}" for a in appearances[:2]]
            if contexts:
                context_parts.append(f"【{entity_name}】{' | '.join(contexts)}")

        for relation in analysis_result.get("relations", [])[:5]:
            context_parts.append(f"【关系】{relation.get('entity1')} {relation.get('relation')} {relation.get('entity2')}: {relation.get('description', '')}")

        context_text = "\n\n".join(context_parts) if context_parts else "未找到相关实体和关系"

        # Ask the LLM to answer grounded in the analysis context only.
        prompt = f"""基于以下跨文档分析结果,回答用户问题。

问题: {question}

分析结果:
{context_text}

请直接回答问题,如果分析结果中没有相关信息,请说明"根据提供的文档无法回答该问题"。"""

        messages = [
            {"role": "system", "content": "你是一个基于文档的问答助手。请根据提供的信息回答问题。"},
            {"role": "user", "content": prompt}
        ]

        try:
            response = await self.llm.chat(messages=messages, temperature=0.2, max_tokens=2000)
            answer = self.llm.extract_message_content(response)

            return {
                "success": True,
                "question": question,
                "answer": answer,
                "supporting_entities": list(analysis_result.get("entities", {}).keys())[:10],
                "relations_count": len(analysis_result.get("relations", []))
            }
        except Exception as e:
            logger.error(f"跨文档问答失败: {e}")
            return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
# Module-level singleton shared by all importers of this service.
multi_doc_reasoning_service = MultiDocReasoningService()
|
||||
@@ -2,11 +2,15 @@
|
||||
RAG 服务模块 - 检索增强生成
|
||||
|
||||
使用 sentence-transformers + Faiss 实现向量检索
|
||||
支持 BM25 关键词检索 + 向量检索混合融合
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
from typing import Any, Dict, List, Optional
|
||||
import re
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from collections import Counter, defaultdict
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
@@ -32,6 +36,132 @@ class SimpleDocument:
|
||||
self.metadata = metadata
|
||||
|
||||
|
||||
class BM25:
    """Okapi BM25 keyword-retrieval ranking.

    A term-frequency / document-frequency ranking function that complements
    dense vector search with exact keyword matching. Index state is rebuilt
    wholesale by :meth:`fit`; there is no incremental update.
    """

    def __init__(self, k1: float = 1.5, b: float = 0.75):
        self.k1 = k1  # term-frequency saturation parameter
        self.b = b  # document-length normalization parameter
        self.documents: List[str] = []
        self.doc_ids: List[str] = []
        self.avg_doc_length = 0
        self.doc_freqs: Dict[str, int] = {}  # term -> number of docs containing it
        self.idf: Dict[str, float] = {}  # term -> IDF value
        self.doc_lengths: List[int] = []
        self.doc_term_freqs: List[Dict[str, int]] = []  # per-doc term frequencies

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize text into lowercase CJK runs and alphanumeric words.

        NOTE: contiguous Chinese text is kept as one token per run (no word
        segmentation); single-character tokens are discarded.
        """
        if not text:
            return []
        tokens = re.findall(r'[\u4e00-\u9fff]+|[a-zA-Z0-9]+', text.lower())
        return [t for t in tokens if len(t) > 1]

    def fit(self, documents: List[str], doc_ids: List[str]):
        """Build (or rebuild) the BM25 index from scratch.

        Args:
            documents: document texts.
            doc_ids: parallel list of document IDs.
        """
        self.documents = documents
        self.doc_ids = doc_ids
        n = len(documents)

        # Reset ALL derived state. FIX: self.idf was previously not cleared,
        # so refitting with a smaller corpus kept stale IDF entries from the
        # previous fit and skewed scores.
        self.doc_freqs = defaultdict(int)
        self.idf = {}
        self.doc_lengths = []
        self.doc_term_freqs = []

        for doc in documents:
            tokens = self._tokenize(doc)
            self.doc_lengths.append(len(tokens))
            doc_tf = Counter(tokens)
            self.doc_term_freqs.append(doc_tf)

            for term in doc_tf:
                self.doc_freqs[term] += 1

        self.avg_doc_length = sum(self.doc_lengths) / n if n > 0 else 0

        # Smoothed IDF; the +1 inside the log keeps values positive even for
        # terms present in most documents.
        for term, df in self.doc_freqs.items():
            self.idf[term] = math.log((n - df + 0.5) / (df + 0.5) + 1)

        # getLogger(__name__) resolves the same module logger lazily, keeping
        # the class importable standalone.
        logging.getLogger(__name__).info(f"BM25 索引构建完成: {n} 个文档, {len(self.idf)} 个词项")

    def search(self, query: str, top_k: int = 10) -> List[Tuple[int, float]]:
        """Rank all documents against the query.

        Args:
            query: query text.
            top_k: number of results to return.

        Returns:
            [(doc index, BM25 score), ...] sorted by score descending. Note
            that zero-score documents may be included when fewer than
            ``top_k`` documents match.
        """
        if not self.documents:
            return []

        query_tokens = self._tokenize(query)
        if not query_tokens:
            return []

        scores = []
        n = len(self.documents)

        for idx in range(n):
            score = self._calculate_score(query_tokens, idx)
            scores.append((idx, score))

        scores.sort(key=lambda x: x[1], reverse=True)

        return scores[:top_k]

    def _calculate_score(self, query_tokens: List[str], doc_idx: int) -> float:
        """Compute the BM25 score of one document for the tokenized query."""
        doc_tf = self.doc_term_freqs[doc_idx]
        doc_len = self.doc_lengths[doc_idx]
        score = 0.0

        for term in query_tokens:
            # Unknown terms contribute nothing.
            if term not in self.idf:
                continue

            tf = doc_tf.get(term, 0)
            idf = self.idf[term]

            # Standard BM25: idf * tf*(k1+1) / (tf + k1*(1 - b + b*len/avglen))
            numerator = tf * (self.k1 + 1)
            denominator = tf + self.k1 * (1 - self.b + self.b * doc_len / self.avg_doc_length)

            score += idf * numerator / denominator

        return score

    def get_scores(self, query: str) -> List[float]:
        """Return the BM25 score of every document (index-aligned)."""
        if not self.documents:
            return []

        query_tokens = self._tokenize(query)
        if not query_tokens:
            return [0.0] * len(self.documents)

        return [self._calculate_score(query_tokens, idx) for idx in range(len(self.documents))]
|
||||
|
||||
|
||||
class RAGService:
|
||||
"""RAG 检索增强服务"""
|
||||
|
||||
@@ -47,12 +177,15 @@ class RAGService:
|
||||
self._dimension: int = 384 # 默认维度
|
||||
self._initialized = False
|
||||
self._persist_dir = settings.FAISS_INDEX_DIR
|
||||
# BM25 索引
|
||||
self.bm25: Optional[BM25] = None
|
||||
self._bm25_enabled = True # 始终启用 BM25
|
||||
# 检查是否可用
|
||||
self._disabled = not SENTENCE_TRANSFORMERS_AVAILABLE
|
||||
if self._disabled:
|
||||
logger.warning("RAG 服务已禁用(sentence-transformers 不可用),将使用关键词匹配作为后备")
|
||||
logger.warning("RAG 服务已禁用(sentence-transformers 不可用),将使用 BM25 关键词检索")
|
||||
else:
|
||||
logger.info("RAG 服务已启用")
|
||||
logger.info("RAG 服务已启用(向量检索 + BM25 混合检索)")
|
||||
|
||||
def _init_embeddings(self):
|
||||
"""初始化嵌入模型"""
|
||||
@@ -261,11 +394,25 @@ class RAGService:
|
||||
if not documents:
|
||||
return
|
||||
|
||||
# 总是将文档存储在内存中(用于关键词搜索后备)
|
||||
# 总是将文档存储在内存中(用于 BM25 和关键词搜索)
|
||||
for doc, did in zip(documents, doc_ids):
|
||||
self.documents.append({"id": did, "content": doc.page_content, "metadata": doc.metadata})
|
||||
self.doc_ids.append(did)
|
||||
|
||||
# 构建 BM25 索引
|
||||
if self._bm25_enabled and documents:
|
||||
bm25_texts = [doc.page_content for doc in documents]
|
||||
if self.bm25 is None:
|
||||
self.bm25 = BM25()
|
||||
self.bm25.fit(bm25_texts, doc_ids)
|
||||
else:
|
||||
# 增量添加:重新构建(BM25 不支持增量)
|
||||
all_texts = [d["content"] for d in self.documents]
|
||||
all_ids = self.doc_ids.copy()
|
||||
self.bm25 = BM25()
|
||||
self.bm25.fit(all_texts, all_ids)
|
||||
logger.debug(f"BM25 索引更新: {len(documents)} 个文档")
|
||||
|
||||
# 如果没有嵌入模型,跳过向量索引
|
||||
if self.embedding_model is None:
|
||||
logger.debug(f"文档跳过向量索引 (无嵌入模型): {len(documents)} 个文档")
|
||||
@@ -284,7 +431,7 @@ class RAGService:
|
||||
|
||||
def retrieve(self, query: str, top_k: int = 5, min_score: float = 0.3) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
根据查询检索相关文档块
|
||||
根据查询检索相关文档块(混合检索:向量 + BM25)
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
@@ -301,39 +448,167 @@ class RAGService:
|
||||
if not self._initialized:
|
||||
self._init_vector_store()
|
||||
|
||||
# 优先使用向量检索
|
||||
if self.index is not None and self.index.ntotal > 0 and self.embedding_model is not None:
|
||||
try:
|
||||
query_embedding = self.embedding_model.encode([query], convert_to_numpy=True)
|
||||
query_embedding = self._normalize_vectors(query_embedding).astype('float32')
|
||||
# 获取向量检索结果
|
||||
vector_results = self._vector_search(query, top_k * 2, min_score)
|
||||
|
||||
scores, indices = self.index.search(query_embedding, min(top_k, self.index.ntotal))
|
||||
# 获取 BM25 检索结果
|
||||
bm25_results = self._bm25_search(query, top_k * 2)
|
||||
|
||||
results = []
|
||||
for score, idx in zip(scores[0], indices[0]):
|
||||
if idx < 0:
|
||||
continue
|
||||
if score < min_score:
|
||||
continue
|
||||
doc = self.documents[idx]
|
||||
results.append({
|
||||
"content": doc["content"],
|
||||
"metadata": doc["metadata"],
|
||||
"score": float(score),
|
||||
"doc_id": doc["id"],
|
||||
"chunk_index": doc["metadata"].get("chunk_index", 0)
|
||||
})
|
||||
# 混合融合
|
||||
hybrid_results = self._hybrid_fusion(vector_results, bm25_results, top_k)
|
||||
|
||||
if results:
|
||||
logger.debug(f"向量检索到 {len(results)} 条相关文档块")
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.warning(f"向量检索失败,使用关键词搜索后备: {e}")
|
||||
if hybrid_results:
|
||||
logger.info(f"混合检索到 {len(hybrid_results)} 条相关文档块 (向量:{len(vector_results)}, BM25:{len(bm25_results)})")
|
||||
return hybrid_results
|
||||
|
||||
# 后备:使用关键词搜索
|
||||
logger.debug("使用关键词搜索后备方案")
|
||||
# 降级:只使用 BM25
|
||||
if bm25_results:
|
||||
logger.info(f"降级到 BM25 检索: {len(bm25_results)} 条")
|
||||
return bm25_results
|
||||
|
||||
# 降级:使用关键词搜索
|
||||
logger.info("降级到关键词搜索")
|
||||
return self._keyword_search(query, top_k)
|
||||
|
||||
def _vector_search(self, query: str, top_k: int, min_score: float) -> List[Dict[str, Any]]:
|
||||
"""向量检索"""
|
||||
if self.index is None or self.index.ntotal == 0 or self.embedding_model is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
query_embedding = self.embedding_model.encode([query], convert_to_numpy=True)
|
||||
query_embedding = self._normalize_vectors(query_embedding).astype('float32')
|
||||
|
||||
scores, indices = self.index.search(query_embedding, min(top_k * 2, self.index.ntotal))
|
||||
|
||||
results = []
|
||||
for score, idx in zip(scores[0], indices[0]):
|
||||
if idx < 0:
|
||||
continue
|
||||
if score < min_score:
|
||||
continue
|
||||
doc = self.documents[idx]
|
||||
results.append({
|
||||
"content": doc["content"],
|
||||
"metadata": doc["metadata"],
|
||||
"score": float(score),
|
||||
"doc_id": doc["id"],
|
||||
"chunk_index": doc["metadata"].get("chunk_index", 0),
|
||||
"search_type": "vector"
|
||||
})
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.warning(f"向量检索失败: {e}")
|
||||
return []
|
||||
|
||||
def _bm25_search(self, query: str, top_k: int) -> List[Dict[str, Any]]:
|
||||
"""BM25 检索"""
|
||||
if not self.bm25 or not self.documents:
|
||||
return []
|
||||
|
||||
try:
|
||||
bm25_scores = self.bm25.get_scores(query)
|
||||
if not bm25_scores:
|
||||
return []
|
||||
|
||||
# 归一化 BM25 分数到 [0, 1]
|
||||
max_score = max(bm25_scores) if bm25_scores else 1
|
||||
min_score_bm = min(bm25_scores) if bm25_scores else 0
|
||||
score_range = max_score - min_score_bm if max_score != min_score_bm else 1
|
||||
|
||||
results = []
|
||||
for idx, score in enumerate(bm25_scores):
|
||||
if score <= 0:
|
||||
continue
|
||||
# 归一化
|
||||
normalized_score = (score - min_score_bm) / score_range if score_range > 0 else 0
|
||||
doc = self.documents[idx]
|
||||
results.append({
|
||||
"content": doc["content"],
|
||||
"metadata": doc["metadata"],
|
||||
"score": float(normalized_score),
|
||||
"doc_id": doc["id"],
|
||||
"chunk_index": doc["metadata"].get("chunk_index", 0),
|
||||
"search_type": "bm25"
|
||||
})
|
||||
|
||||
# 按分数降序
|
||||
results.sort(key=lambda x: x["score"], reverse=True)
|
||||
return results[:top_k]
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"BM25 检索失败: {e}")
|
||||
return []
|
||||
|
||||
def _hybrid_fusion(
|
||||
self,
|
||||
vector_results: List[Dict[str, Any]],
|
||||
bm25_results: List[Dict[str, Any]],
|
||||
top_k: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
混合融合向量和 BM25 检索结果
|
||||
|
||||
使用 RRFR (Reciprocal Rank Fusion) 算法:
|
||||
Score = weight_vector * (1 / rank_vector) + weight_bm25 * (1 / rank_bm25)
|
||||
|
||||
Args:
|
||||
vector_results: 向量检索结果
|
||||
bm25_results: BM25 检索结果
|
||||
top_k: 返回数量
|
||||
|
||||
Returns:
|
||||
融合后的结果
|
||||
"""
|
||||
if not vector_results and not bm25_results:
|
||||
return []
|
||||
|
||||
# 融合权重
|
||||
weight_vector = 0.6
|
||||
weight_bm25 = 0.4
|
||||
|
||||
# 构建文档分数映射
|
||||
doc_scores: Dict[str, Dict[str, float]] = {}
|
||||
|
||||
# 添加向量检索结果
|
||||
for rank, result in enumerate(vector_results):
|
||||
doc_id = result["doc_id"]
|
||||
if doc_id not in doc_scores:
|
||||
doc_scores[doc_id] = {"vector": 0, "bm25": 0, "content": result["content"], "metadata": result["metadata"]}
|
||||
# 使用倒数排名 (Reciprocal Rank)
|
||||
doc_scores[doc_id]["vector"] = weight_vector / (rank + 1)
|
||||
|
||||
# 添加 BM25 检索结果
|
||||
for rank, result in enumerate(bm25_results):
|
||||
doc_id = result["doc_id"]
|
||||
if doc_id not in doc_scores:
|
||||
doc_scores[doc_id] = {"vector": 0, "bm25": 0, "content": result["content"], "metadata": result["metadata"]}
|
||||
doc_scores[doc_id]["bm25"] = weight_bm25 / (rank + 1)
|
||||
|
||||
# 计算融合分数
|
||||
fused_results = []
|
||||
for doc_id, scores in doc_scores.items():
|
||||
fused_score = scores["vector"] + scores["bm25"]
|
||||
# 使用向量检索结果的原始分数作为参考
|
||||
vector_score = next((r["score"] for r in vector_results if r["doc_id"] == doc_id), 0.5)
|
||||
fused_results.append({
|
||||
"content": scores["content"],
|
||||
"metadata": scores["metadata"],
|
||||
"score": fused_score,
|
||||
"doc_id": doc_id,
|
||||
"vector_score": vector_score,
|
||||
"bm25_score": scores["bm25"],
|
||||
"search_type": "hybrid"
|
||||
})
|
||||
|
||||
# 按融合分数降序排序
|
||||
fused_results.sort(key=lambda x: x["score"], reverse=True)
|
||||
|
||||
logger.debug(f"混合融合: {len(fused_results)} 个文档, 向量:{len(vector_results)}, BM25:{len(bm25_results)}")
|
||||
|
||||
return fused_results[:top_k]
|
||||
|
||||
def _keyword_search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
关键词搜索后备方案
|
||||
|
||||
@@ -13,6 +13,7 @@ from app.services.llm_service import llm_service
|
||||
from app.core.document_parser import ParserFactory
|
||||
from app.services.markdown_ai_service import markdown_ai_service
|
||||
from app.services.rag_service import rag_service
|
||||
from app.services.word_ai_service import word_ai_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -55,6 +56,249 @@ class FillResult:
|
||||
class TemplateFillService:
    """Template/table fill service."""

    # Semantic expansion map for generic template headers: a generic header
    # word (key) also matches any of the more specific source-document headers
    # in its value list. These strings are runtime matching data - edits here
    # change header-matching behavior.
    GENERIC_HEADER_EXPANSION = {
        "机构": ["医院", "学校", "企业", "机关", "团体", "协会", "基金会", "研究所", "医院数量", "学校数量", "企业数量"],
        "名称": ["医院名称", "学校名称", "企业名称", "机构名称", "单位名称", "名称"],
        "类型": ["医院类型", "学校类型", "企业类型", "机构类型", "类型分类"],
        "数量": ["医院数量", "学校数量", "企业数量", "机构数量", "个数", "总数", "人员数量"],
        "金额": ["金额", "收入", "支出", "产值", "销售额", "利润", "税收"],
        "比率": ["增长率", "占比", "比重", "比率", "百分比", "使用率", "就业率"],
        "面积": ["占地面积", "建筑面积", "用地面积", "耕地面积", "绿化面积"],
        "人口": ["常住人口", "户籍人口", "流动人口", "城镇人口", "农村人口"],
        "价格": ["价格", "物价", "CPI", "涨幅", "指数"],
        "增长": ["增速", "增长率", "增幅", "增长", "上涨", "下降"],
    }

    # Cache of template-header -> source-header mappings.
    # NOTE(review): class-level mutable dict, shared across all instances and
    # never invalidated here - confirm that is intended.
    _header_mapping_cache: Dict[str, Dict[str, str]] = {}
|
||||
|
||||
def _analyze_source_table_structure(self, source_docs: List["SourceDocument"]) -> Dict[str, Any]:
|
||||
"""
|
||||
分析源文档的表格结构
|
||||
|
||||
Args:
|
||||
source_docs: 源文档列表
|
||||
|
||||
Returns:
|
||||
表格结构分析结果,包含所有表头和样本数据
|
||||
"""
|
||||
table_structures = {}
|
||||
|
||||
for doc_idx, doc in enumerate(source_docs):
|
||||
structured = doc.structured_data if doc.structured_data else {}
|
||||
|
||||
# 处理多 sheet 格式
|
||||
if structured.get("sheets"):
|
||||
for sheet_name, sheet_data in structured.get("sheets", {}).items():
|
||||
if isinstance(sheet_data, dict):
|
||||
columns = sheet_data.get("columns", [])
|
||||
rows = sheet_data.get("rows", [])[:10] # 只取前10行作为样本
|
||||
key = f"doc{doc_idx}_{sheet_name}"
|
||||
table_structures[key] = {
|
||||
"doc_idx": doc_idx,
|
||||
"sheet_name": sheet_name,
|
||||
"columns": columns,
|
||||
"sample_rows": rows,
|
||||
"column_count": len(columns),
|
||||
"row_count": len(sheet_data.get("rows", []))
|
||||
}
|
||||
|
||||
# 处理 tables 格式
|
||||
elif structured.get("tables"):
|
||||
for table_idx, table in enumerate(structured.get("tables", [])[:5]):
|
||||
if isinstance(table, dict):
|
||||
headers = table.get("headers", [])
|
||||
rows = table.get("rows", [])[:10]
|
||||
key = f"doc{doc_idx}_table{table_idx}"
|
||||
table_structures[key] = {
|
||||
"doc_idx": doc_idx,
|
||||
"table_idx": table_idx,
|
||||
"columns": headers,
|
||||
"sample_rows": rows,
|
||||
"column_count": len(headers),
|
||||
"row_count": len(table.get("rows", []))
|
||||
}
|
||||
|
||||
# 处理单 sheet 格式
|
||||
elif structured.get("columns") and structured.get("rows"):
|
||||
columns = structured.get("columns", [])
|
||||
rows = structured.get("rows", [])[:10]
|
||||
key = f"doc{doc_idx}_default"
|
||||
table_structures[key] = {
|
||||
"doc_idx": doc_idx,
|
||||
"columns": columns,
|
||||
"sample_rows": rows,
|
||||
"column_count": len(columns),
|
||||
"row_count": len(structured.get("rows", []))
|
||||
}
|
||||
|
||||
logger.info(f"分析源文档表格结构: {len(table_structures)} 个表格")
|
||||
return table_structures
|
||||
|
||||
def _build_adaptive_header_mapping(
|
||||
self,
|
||||
template_fields: List["TemplateField"],
|
||||
source_table_structures: Dict[str, Any]
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
自适应构建模板表头到源文档表头的映射
|
||||
|
||||
Args:
|
||||
template_fields: 模板字段列表
|
||||
source_table_structures: 源文档表格结构
|
||||
|
||||
Returns:
|
||||
映射字典: {field_name: {source_table_key: {column: idx, match_score: score}}}
|
||||
"""
|
||||
mappings = {}
|
||||
|
||||
for field in template_fields:
|
||||
field_name = field.name
|
||||
field_lower = field_name.lower()
|
||||
field_keywords = set(field_lower.replace(" ", "").split())
|
||||
|
||||
best_matches = {}
|
||||
|
||||
for table_key, table_info in source_table_structures.items():
|
||||
columns = table_info.get("columns", [])
|
||||
if not columns:
|
||||
continue
|
||||
|
||||
best_col_idx = None
|
||||
best_col_name = None
|
||||
best_score = 0
|
||||
|
||||
for col_idx, col in enumerate(columns):
|
||||
col_str = str(col).strip()
|
||||
col_lower = col_str.lower()
|
||||
col_keywords = set(col_lower.replace(" ", "").split())
|
||||
|
||||
score = 0
|
||||
|
||||
# 1. 精确匹配
|
||||
if col_lower == field_lower:
|
||||
score = 1.0
|
||||
# 2. 子字符串匹配
|
||||
elif field_lower in col_lower or col_lower in field_lower:
|
||||
score = 0.8 * max(len(field_lower), len(col_lower)) / min(len(field_lower) + 1, len(col_lower) + 1)
|
||||
# 3. 关键词重叠
|
||||
else:
|
||||
overlap = field_keywords & col_keywords
|
||||
if overlap:
|
||||
score = 0.6 * len(overlap) / max(len(field_keywords), len(col_keywords), 1)
|
||||
|
||||
# 4. 检查通用表头扩展
|
||||
if score < 0.5:
|
||||
for generic, specifics in self.GENERIC_HEADER_EXPANSION.items():
|
||||
if generic in field_lower:
|
||||
for specific in specifics:
|
||||
if specific in col_lower or col_lower in specific:
|
||||
score = 0.7
|
||||
break
|
||||
if score >= 0.5:
|
||||
break
|
||||
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_col_idx = col_idx
|
||||
best_col_name = col_str
|
||||
|
||||
if best_score >= 0.3 and best_col_idx is not None:
|
||||
best_matches[table_key] = {
|
||||
"column_index": best_col_idx,
|
||||
"column_name": best_col_name,
|
||||
"match_score": best_score,
|
||||
"table_info": table_info
|
||||
}
|
||||
|
||||
if best_matches:
|
||||
mappings[field_name] = best_matches
|
||||
logger.info(f"字段 '{field_name}' 匹配到 {len(best_matches)} 个源表头,最佳匹配: {list(best_matches.values())[0].get('column_name')}")
|
||||
|
||||
return mappings
|
||||
|
||||
def _extract_with_adaptive_mapping(
|
||||
self,
|
||||
source_docs: List["SourceDocument"],
|
||||
field_name: str,
|
||||
mapping: Dict[str, Dict[str, Any]]
|
||||
) -> List[str]:
|
||||
"""
|
||||
使用自适应映射提取字段值
|
||||
|
||||
Args:
|
||||
source_docs: 源文档列表
|
||||
field_name: 字段名
|
||||
mapping: 字段到源表头的映射
|
||||
|
||||
Returns:
|
||||
提取的值列表
|
||||
"""
|
||||
values = []
|
||||
|
||||
if field_name not in mapping:
|
||||
return values
|
||||
|
||||
best_matches = mapping[field_name]
|
||||
|
||||
for table_key, match_info in best_matches.items():
|
||||
table_info = match_info.get("table_info", {})
|
||||
col_idx = match_info.get("column_index", 0)
|
||||
doc_idx = table_info.get("doc_idx", 0)
|
||||
|
||||
if doc_idx >= len(source_docs):
|
||||
continue
|
||||
|
||||
doc = source_docs[doc_idx]
|
||||
structured = doc.structured_data if doc.structured_data else {}
|
||||
|
||||
# 根据表格类型提取值
|
||||
rows = []
|
||||
|
||||
# 多 sheet 格式
|
||||
if structured.get("sheets"):
|
||||
sheet_name = table_info.get("sheet_name")
|
||||
if sheet_name:
|
||||
sheet_data = structured.get("sheets", {}).get(sheet_name, {})
|
||||
rows = sheet_data.get("rows", [])
|
||||
|
||||
# tables 格式
|
||||
elif structured.get("tables"):
|
||||
table_idx = table_info.get("table_idx", 0)
|
||||
tables = structured.get("tables", [])
|
||||
if table_idx < len(tables):
|
||||
rows = tables[table_idx].get("rows", [])
|
||||
|
||||
# 单 sheet 格式
|
||||
elif structured.get("rows"):
|
||||
rows = structured.get("rows", [])
|
||||
|
||||
# 提取指定列的值
|
||||
for row in rows:
|
||||
if isinstance(row, list) and col_idx < len(row):
|
||||
val = self._format_value(row[col_idx])
|
||||
if val and self._is_valid_data_value(val):
|
||||
values.append(val)
|
||||
elif isinstance(row, dict):
|
||||
# 对于 dict 格式的行
|
||||
columns = table_info.get("columns", [])
|
||||
if col_idx < len(columns):
|
||||
col_name = columns[col_idx]
|
||||
val = self._format_value(row.get(col_name, ""))
|
||||
if val and self._is_valid_data_value(val):
|
||||
values.append(val)
|
||||
|
||||
# 过滤和去重
|
||||
seen = set()
|
||||
unique_values = []
|
||||
for v in values:
|
||||
if v not in seen:
|
||||
seen.add(v)
|
||||
unique_values.append(v)
|
||||
|
||||
return unique_values
|
||||
|
||||
def __init__(self):
    """Bind the shared module-level LLM client used by this service's LLM calls."""
    self.llm = llm_service
|
||||
|
||||
@@ -305,6 +549,62 @@ class TemplateFillService:
|
||||
if source_file_paths:
|
||||
for file_path in source_file_paths:
|
||||
try:
|
||||
file_ext = file_path.lower().split('.')[-1]
|
||||
|
||||
# 对于 Word 文档,优先使用 AI 解析
|
||||
if file_ext == 'docx':
|
||||
# 使用 AI 深度解析 Word 文档
|
||||
ai_result = await word_ai_service.parse_word_with_ai(
|
||||
file_path=file_path,
|
||||
user_hint="请提取文档中的所有结构化数据,包括表格、键值对等"
|
||||
)
|
||||
|
||||
if ai_result.get("success"):
|
||||
# AI 解析成功,转换为 SourceDocument 格式
|
||||
parse_type = ai_result.get("type", "unknown")
|
||||
|
||||
# 构建 structured_data
|
||||
doc_structured = {
|
||||
"ai_parsed": True,
|
||||
"parse_type": parse_type,
|
||||
"tables": [],
|
||||
"key_values": ai_result.get("key_values", {}) if "key_values" in ai_result else {},
|
||||
"list_items": ai_result.get("list_items", []) if "list_items" in ai_result else [],
|
||||
"summary": ai_result.get("summary", "") if "summary" in ai_result else ""
|
||||
}
|
||||
|
||||
# 如果 AI 返回了表格数据
|
||||
if parse_type == "table_data":
|
||||
headers = ai_result.get("headers", [])
|
||||
rows = ai_result.get("rows", [])
|
||||
if headers and rows:
|
||||
doc_structured["tables"] = [{
|
||||
"headers": headers,
|
||||
"rows": rows
|
||||
}]
|
||||
doc_structured["columns"] = headers
|
||||
doc_structured["rows"] = rows
|
||||
logger.info(f"AI 表格数据: {len(headers)} 列, {len(rows)} 行")
|
||||
elif parse_type == "structured_text":
|
||||
tables = ai_result.get("tables", [])
|
||||
if tables:
|
||||
doc_structured["tables"] = tables
|
||||
logger.info(f"AI 结构化文本提取到 {len(tables)} 个表格")
|
||||
|
||||
# 获取摘要内容
|
||||
content_text = doc_structured.get("summary", "") or ai_result.get("description", "")
|
||||
|
||||
source_docs.append(SourceDocument(
|
||||
doc_id=file_path,
|
||||
filename=file_path.split("/")[-1] if "/" in file_path else file_path.split("\\")[-1],
|
||||
doc_type="docx",
|
||||
content=content_text,
|
||||
structured_data=doc_structured
|
||||
))
|
||||
logger.info(f"AI 解析 Word 文档: {file_path}, type={parse_type}, tables={len(doc_structured.get('tables', []))}")
|
||||
continue # 跳后续的基础解析
|
||||
|
||||
# 基础解析(Excel 或非 AI 解析的 Word)
|
||||
parser = ParserFactory.get_parser(file_path)
|
||||
result = parser.parse(file_path)
|
||||
if result.success:
|
||||
@@ -1351,6 +1651,36 @@ class TemplateFillService:
|
||||
if all_values:
|
||||
break
|
||||
|
||||
# 处理 AI 解析的 Word 文档键值对格式: {key_values: {"键": "值"}, ...}
|
||||
if structured.get("key_values") and isinstance(structured.get("key_values"), dict):
|
||||
key_values = structured.get("key_values", {})
|
||||
logger.info(f" 检测到 AI 解析键值对格式,共 {len(key_values)} 个键值对")
|
||||
values = self._extract_from_key_values(key_values, field_name)
|
||||
if values:
|
||||
all_values.extend(values)
|
||||
logger.info(f"从 Word AI 键值对提取到 {len(values)} 个值: {values}")
|
||||
break
|
||||
|
||||
# 处理 AI 解析的 list_items 格式
|
||||
if structured.get("list_items") and isinstance(structured.get("list_items"), list):
|
||||
list_items = structured.get("list_items", [])
|
||||
logger.info(f" 检测到 AI 解析列表格式,共 {len(list_items)} 个列表项")
|
||||
values = self._extract_from_list_items(list_items, field_name)
|
||||
if values:
|
||||
all_values.extend(values)
|
||||
logger.info(f"从 Word AI 列表提取到 {len(values)} 个值")
|
||||
break
|
||||
|
||||
# 如果从结构化数据中没有提取到值,且字段是通用表头,搜索文本内容
|
||||
if not all_values and field_name in self.GENERIC_HEADER_EXPANSION:
|
||||
for doc in source_docs:
|
||||
if doc.content:
|
||||
text_values = self._search_generic_header_in_text(doc.content, field_name)
|
||||
if text_values:
|
||||
all_values.extend(text_values)
|
||||
logger.info(f"从文本内容通过通用表头匹配提取到 {len(text_values)} 个值")
|
||||
break
|
||||
|
||||
return all_values
|
||||
|
||||
def _extract_values_from_markdown_table(self, headers: List, rows: List, field_name: str) -> List[str]:
|
||||
@@ -1376,10 +1706,27 @@ class TemplateFillService:
|
||||
# 查找匹配的列索引 - 使用增强的匹配算法
|
||||
target_idx = self._find_best_matching_column(headers, field_name)
|
||||
|
||||
if target_idx is None:
|
||||
# 如果没有找到列匹配,尝试在第一列中搜索字段名(适用于指标在行的文档)
|
||||
matched_row_idx = None
|
||||
if target_idx is None and rows:
|
||||
matched_row_idx = self._search_row_in_first_column(rows, field_name)
|
||||
if matched_row_idx is not None:
|
||||
logger.info(f"在第一列找到匹配: {field_name} -> 行索引 {matched_row_idx} (转置表格结构)")
|
||||
|
||||
if target_idx is None and matched_row_idx is None:
|
||||
logger.warning(f"未找到匹配列: {field_name}, 可用表头: {headers}")
|
||||
return []
|
||||
|
||||
# 如果在第一列找到匹配(转置表格),提取该行的其他列作为值
|
||||
if matched_row_idx is not None:
|
||||
matched_row = rows[matched_row_idx]
|
||||
if isinstance(matched_row, list):
|
||||
# 跳过第一列(指标名),提取后续列的值
|
||||
for val in matched_row[1:]:
|
||||
values.append(self._format_value(val))
|
||||
logger.info(f"转置表格提取到 {len(values)} 个值: {values[:5]}...")
|
||||
return self._filter_valid_values(values)
|
||||
|
||||
logger.info(f"列匹配成功: {field_name} -> {headers[target_idx]} (索引: {target_idx})")
|
||||
|
||||
values = []
|
||||
@@ -1527,6 +1874,149 @@ class TemplateFillService:
|
||||
valid_values.append(val)
|
||||
return valid_values
|
||||
|
||||
def _extract_from_key_values(self, key_values: Dict[str, str], field_name: str) -> List[str]:
|
||||
"""
|
||||
从键值对字典中提取与字段名匹配的值
|
||||
|
||||
Args:
|
||||
key_values: 键值对字典,如 {"医院数量": "38710个", "床位总数": "456789张"}
|
||||
field_name: 要匹配的字段名
|
||||
|
||||
Returns:
|
||||
匹配的值列表
|
||||
"""
|
||||
if not key_values:
|
||||
return []
|
||||
|
||||
field_lower = field_name.lower().strip()
|
||||
field_chars = set(field_lower.replace(" ", ""))
|
||||
field_keywords = set(field_lower.replace(" ", "").split())
|
||||
|
||||
best_match_key = None
|
||||
best_match_score = 0
|
||||
|
||||
for key, value in key_values.items():
|
||||
key_str = str(key).strip()
|
||||
key_lower = key_str.lower()
|
||||
key_chars = set(key_lower.replace(" ", ""))
|
||||
|
||||
if not key_str or not value:
|
||||
continue
|
||||
|
||||
# 策略1: 精确匹配(忽略大小写)
|
||||
if key_lower == field_lower:
|
||||
logger.info(f"键值对精确匹配: {field_name} -> {key_str}: {value}")
|
||||
return [str(value)]
|
||||
|
||||
# 策略2: 子字符串匹配
|
||||
if field_lower in key_lower or key_lower in field_lower:
|
||||
score = max(len(field_lower), len(key_lower)) / min(len(field_lower) + 1, len(key_lower) + 1)
|
||||
if score > best_match_score:
|
||||
best_match_score = score
|
||||
best_match_key = (key_str, value)
|
||||
|
||||
# 策略3: 关键词重叠匹配(适用于中文)
|
||||
key_keywords = set(key_lower.replace(" ", "").split())
|
||||
overlap = field_keywords & key_keywords
|
||||
if overlap and len(overlap) > 0:
|
||||
score = len(overlap) / max(len(field_keywords), len(key_keywords), 1)
|
||||
if score > best_match_score:
|
||||
best_match_score = score
|
||||
best_match_key = (key_str, value)
|
||||
|
||||
# 策略4: 字符级包含匹配(适用于中文短字段)
|
||||
char_overlap = field_chars & key_chars
|
||||
if char_overlap:
|
||||
char_score = len(char_overlap) / max(len(field_chars), len(key_chars), 1)
|
||||
# 对于短字段(<=4字符),降低要求
|
||||
if len(field_chars) <= 4 and char_score >= 0.5:
|
||||
if char_score > best_match_score:
|
||||
best_match_score = char_score
|
||||
best_match_key = (key_str, value)
|
||||
elif char_score > best_match_score and len(char_overlap) >= 2:
|
||||
best_match_score = char_score
|
||||
best_match_key = (key_str, value)
|
||||
|
||||
# 降低阈值到 0.2,允许更多模糊匹配
|
||||
if best_match_score >= 0.2 and best_match_key:
|
||||
logger.info(f"键值对模糊匹配: {field_name} -> {best_match_key[0]}: {best_match_key[1]} (分数: {best_match_score:.2f})")
|
||||
return [str(best_match_key[1])]
|
||||
|
||||
logger.warning(f"键值对未匹配到: {field_name}, 可用键: {list(key_values.keys())}")
|
||||
return []
|
||||
|
||||
def _extract_from_list_items(self, list_items: List[str], field_name: str) -> List[str]:
|
||||
"""
|
||||
从列表项中提取与字段名匹配的值
|
||||
|
||||
Args:
|
||||
list_items: 列表项,如 ["医院数量: 38710个", "床位总数: 456789张", ...]
|
||||
field_name: 要匹配的字段名
|
||||
|
||||
Returns:
|
||||
匹配的值列表
|
||||
"""
|
||||
if not list_items:
|
||||
return []
|
||||
|
||||
field_lower = field_name.lower().strip()
|
||||
field_keywords = set(field_lower.replace(" ", "").split())
|
||||
|
||||
matched_values = []
|
||||
|
||||
for item in list_items:
|
||||
item_str = str(item).strip()
|
||||
if not item_str:
|
||||
continue
|
||||
|
||||
item_lower = item_str.lower()
|
||||
|
||||
# 策略1: 检查列表项是否以字段名开头(格式如 "医院数量: 38710个")
|
||||
if ':' in item_str or ':' in item_str:
|
||||
parts = item_str.replace(':', ':').split(':', 1)
|
||||
if len(parts) == 2:
|
||||
key = parts[0].strip()
|
||||
value = parts[1].strip()
|
||||
key_lower = key.lower()
|
||||
|
||||
# 精确匹配键
|
||||
if key_lower == field_lower:
|
||||
logger.info(f"列表项键值精确匹配: {field_name} -> {value}")
|
||||
return [value]
|
||||
|
||||
# 子字符串匹配
|
||||
if field_lower in key_lower or key_lower in field_lower:
|
||||
score = max(len(field_lower), len(key_lower)) / min(len(field_lower) + 1, len(key_lower) + 1)
|
||||
if score >= 0.2:
|
||||
logger.info(f"列表项键值模糊匹配: {field_name} -> {key}: {value} (分数: {score:.2f})")
|
||||
matched_values.append(value)
|
||||
|
||||
# 关键词重叠
|
||||
key_keywords = set(key_lower.replace(" ", "").split())
|
||||
overlap = field_keywords & key_keywords
|
||||
if overlap:
|
||||
score = len(overlap) / max(len(field_keywords), len(key_keywords), 1)
|
||||
if score >= 0.2:
|
||||
matched_values.append(value)
|
||||
|
||||
# 策略2: 直接匹配整个列表项
|
||||
if field_lower in item_lower or item_lower in field_lower:
|
||||
matched_values.append(item_str)
|
||||
continue
|
||||
|
||||
# 策略3: 关键词重叠
|
||||
item_keywords = set(item_lower.replace(" ", "").split())
|
||||
overlap = field_keywords & item_keywords
|
||||
if overlap and len(overlap) >= 2: # 至少2个关键词重叠
|
||||
score = len(overlap) / max(len(field_keywords), len(item_keywords), 1)
|
||||
if score >= 0.2:
|
||||
matched_values.append(item_str)
|
||||
|
||||
if matched_values:
|
||||
logger.info(f"列表项匹配到 {len(matched_values)} 个: {matched_values[:5]}")
|
||||
|
||||
return matched_values
|
||||
|
||||
def _find_best_matching_column(self, headers: List, field_name: str) -> Optional[int]:
|
||||
"""
|
||||
查找最佳匹配的列索引
|
||||
@@ -1535,6 +2025,7 @@ class TemplateFillService:
|
||||
1. 精确匹配(忽略大小写)
|
||||
2. 子字符串匹配(字段名在表头中,或表头在字段名中)
|
||||
3. 关键词重叠匹配(中文字符串分割后比对)
|
||||
4. 字符级包含匹配(适用于中文短字段)
|
||||
|
||||
Args:
|
||||
headers: 表头列表
|
||||
@@ -1544,6 +2035,8 @@ class TemplateFillService:
|
||||
匹配的列索引,找不到返回 None
|
||||
"""
|
||||
field_lower = field_name.lower().strip()
|
||||
# 对中文进行字符级拆分,增加匹配的灵活性
|
||||
field_chars = set(field_lower.replace(" ", ""))
|
||||
field_keywords = set(field_lower.replace(" ", "").split())
|
||||
|
||||
best_match_idx = None
|
||||
@@ -1560,6 +2053,7 @@ class TemplateFillService:
|
||||
|
||||
# 策略1: 精确匹配(忽略大小写)
|
||||
if header_lower == field_lower:
|
||||
logger.info(f"精确匹配: {field_name} -> {header_str}")
|
||||
return idx
|
||||
|
||||
# 策略2: 子字符串匹配
|
||||
@@ -1580,13 +2074,93 @@ class TemplateFillService:
|
||||
best_match_score = score
|
||||
best_match_idx = idx
|
||||
|
||||
# 只有当匹配分数超过阈值时才返回
|
||||
if best_match_score >= 0.3:
|
||||
# 策略4: 字符级包含匹配(适用于中文短字段,如"医院"匹配"医院数量")
|
||||
header_chars = set(header_lower.replace(" ", ""))
|
||||
char_overlap = field_chars & header_chars
|
||||
if char_overlap:
|
||||
# 计算字符重叠率,但要求至少有一定数量的重叠字符
|
||||
char_score = len(char_overlap) / max(len(field_chars), len(header_chars), 1)
|
||||
# 对于短字段(<=4字符),降低要求,只要有重叠且字符score较高即可
|
||||
if len(field_chars) <= 4 and char_score >= 0.5:
|
||||
if char_score > best_match_score:
|
||||
best_match_score = char_score
|
||||
best_match_idx = idx
|
||||
elif char_score > best_match_score and len(char_overlap) >= 2:
|
||||
# 对于较长字段,要求至少2个字符重叠
|
||||
best_match_score = char_score
|
||||
best_match_idx = idx
|
||||
|
||||
# 降低阈值到 0.2,允许更多模糊匹配
|
||||
if best_match_score >= 0.2:
|
||||
logger.info(f"模糊匹配: {field_name} -> {headers[best_match_idx]} (分数: {best_match_score:.2f})")
|
||||
return best_match_idx
|
||||
|
||||
return None
|
||||
|
||||
def _search_row_in_first_column(self, rows: List, field_name: str) -> Optional[int]:
|
||||
"""
|
||||
在表格第一列中搜索字段名(适用于指标在行的转置表格结构)
|
||||
|
||||
对于某些中文统计文档,表格结构是转置的:
|
||||
- 第一列是指标名称(如"医院数量")
|
||||
- 其他列是年份或数值
|
||||
|
||||
Args:
|
||||
rows: 数据行列表
|
||||
field_name: 要搜索的字段名
|
||||
|
||||
Returns:
|
||||
匹配的列索引(始终返回0,因为是第一列),如果没找到返回None
|
||||
"""
|
||||
if not rows or not field_name:
|
||||
return None
|
||||
|
||||
field_lower = field_name.lower().strip()
|
||||
field_chars = set(field_lower.replace(" ", ""))
|
||||
field_keywords = set(field_lower.replace(" ", "").split())
|
||||
|
||||
for row_idx, row in enumerate(rows):
|
||||
if not isinstance(row, list) or len(row) == 0:
|
||||
continue
|
||||
|
||||
first_cell = str(row[0]).strip()
|
||||
if not first_cell:
|
||||
continue
|
||||
|
||||
first_cell_lower = first_cell.lower()
|
||||
|
||||
# 精确匹配
|
||||
if first_cell_lower == field_lower:
|
||||
logger.info(f"第一列精确匹配字段: {field_name} -> {first_cell} (行{row_idx})")
|
||||
return 0
|
||||
|
||||
# 子字符串匹配
|
||||
if field_lower in first_cell_lower or first_cell_lower in field_lower:
|
||||
score = max(len(field_lower), len(first_cell_lower)) / min(len(field_lower) + 1, len(first_cell_lower) + 1)
|
||||
if score >= 0.5:
|
||||
logger.info(f"第一列模糊匹配字段: {field_name} -> {first_cell} (行{row_idx}, 分数:{score:.2f})")
|
||||
return 0
|
||||
|
||||
# 关键词重叠匹配
|
||||
first_keywords = set(first_cell_lower.replace(" ", "").split())
|
||||
overlap = field_keywords & first_keywords
|
||||
if overlap and len(overlap) >= 2:
|
||||
score = len(overlap) / max(len(field_keywords), len(first_keywords), 1)
|
||||
if score >= 0.3:
|
||||
logger.info(f"第一列关键词匹配: {field_name} -> {first_cell} (行{row_idx}, 分数:{score:.2f})")
|
||||
return 0
|
||||
|
||||
# 字符级匹配(短字段)
|
||||
first_chars = set(first_cell_lower.replace(" ", ""))
|
||||
char_overlap = field_chars & first_chars
|
||||
if char_overlap and len(field_chars) <= 4:
|
||||
char_score = len(char_overlap) / max(len(field_chars), len(first_chars), 1)
|
||||
if char_score >= 0.5:
|
||||
logger.info(f"第一列字符匹配: {field_name} -> {first_cell} (行{row_idx}, 分数:{char_score:.2f})")
|
||||
return 0
|
||||
|
||||
return None
|
||||
|
||||
def _extract_column_values(self, rows: List, columns: List, field_name: str) -> List[str]:
|
||||
"""
|
||||
从 rows 和 columns 中提取指定列的值
|
||||
@@ -1677,6 +2251,55 @@ class TemplateFillService:
|
||||
|
||||
return str(val)
|
||||
|
||||
def _search_generic_header_in_text(self, text: str, field_name: str) -> List[str]:
|
||||
"""
|
||||
从文本中搜索通用表头对应的具体值
|
||||
|
||||
例如:表头"机构" -> 搜索文本中的"医院"、"学校"、"企业"等
|
||||
|
||||
Args:
|
||||
text: 文档文本内容
|
||||
field_name: 字段名称(可能是通用表头)
|
||||
|
||||
Returns:
|
||||
匹配到的值列表
|
||||
"""
|
||||
import re
|
||||
|
||||
# 检查是否是通用表头
|
||||
generic_terms = self.GENERIC_HEADER_EXPANSION.get(field_name, [])
|
||||
if not generic_terms:
|
||||
return []
|
||||
|
||||
matched_values = []
|
||||
|
||||
for term in generic_terms:
|
||||
# 搜索 term + 数字/量词 的模式,如 "医院 100所"
|
||||
patterns = [
|
||||
rf'{re.escape(term)}[\s\d所个家级人万元亿元%‰]+', # 医院100所, 企业50家
|
||||
rf'{re.escape(term)}[::\s]+(\d+[\d。,,]?\d*)', # 医院:100
|
||||
rf'(\d+[\d。,,]?\d*)[^\d]*{re.escape(term)}', # 100家医院
|
||||
]
|
||||
for pattern in patterns:
|
||||
matches = re.findall(pattern, text, re.IGNORECASE)
|
||||
for match in matches:
|
||||
val = match.strip() if isinstance(match, str) else match
|
||||
if val and len(str(val)) < 100:
|
||||
matched_values.append(str(val))
|
||||
|
||||
# 去重并保持顺序
|
||||
seen = set()
|
||||
unique_values = []
|
||||
for v in matched_values:
|
||||
if v not in seen:
|
||||
seen.add(v)
|
||||
unique_values.append(v)
|
||||
|
||||
if unique_values:
|
||||
logger.info(f"通用表头 '{field_name}' 匹配到值: {unique_values[:10]}")
|
||||
|
||||
return unique_values
|
||||
|
||||
def _extract_values_from_json(self, result) -> List[str]:
|
||||
"""
|
||||
从解析后的 JSON 对象/数组中提取值数组
|
||||
@@ -2236,29 +2859,32 @@ class TemplateFillService:
|
||||
- 二级分类:如"医院"下分为"公立医院"、"民营医院"
|
||||
|
||||
4. **生成字段**:
|
||||
- 字段名要简洁,如:"医院数量"、"病床使用率"
|
||||
- 优先选择:总数 + 主要分类
|
||||
- 字段名要详细具体,能区分不同数据,如:"医院数量(个)"、"病床使用率(%)"、"公立医院数量"
|
||||
- 优先选择:总数 + 主要分类 + 重要指标
|
||||
|
||||
5. **生成数量**:
|
||||
- 生成5-7个最有代表性的字段
|
||||
- 生成10-15个最有代表性的字段,确保覆盖主要数据指标
|
||||
|
||||
6. **添加字段说明**:
|
||||
- 每个字段可以添加 hint 说明字段的含义和数据来源
|
||||
|
||||
请严格按照以下 JSON 格式输出(只需输出 JSON,不要其他内容):
|
||||
{{
|
||||
"fields": [
|
||||
{{"name": "字段名1"}},
|
||||
{{"name": "字段名2"}}
|
||||
{{"name": "医院数量", "hint": "从文档中提取医院总数,包括公立和民营医院"}},
|
||||
{{"name": "病床使用率", "hint": "提取病床使用率数据"}}
|
||||
]
|
||||
}}
|
||||
"""
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个专业的表格设计助手。请严格按JSON格式输出,只返回纯数据字段名,不要source、备注、说明等辅助字段。"},
|
||||
{"role": "system", "content": "你是一个专业的表格设计助手。请严格按JSON格式输出,为每个字段生成详细名称和hint说明。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
response = await self.llm.chat(
|
||||
messages=messages,
|
||||
temperature=0.3,
|
||||
max_tokens=2000
|
||||
max_tokens=4000
|
||||
)
|
||||
|
||||
content = self.llm.extract_message_content(response)
|
||||
|
||||
@@ -192,13 +192,15 @@ class WordAIService:
|
||||
result = self._parse_json_response(content)
|
||||
|
||||
if result:
|
||||
logger.info(f"AI 表格提取成功: {len(result.get('rows', []))} 行数据")
|
||||
logger.info(f"AI 表格提取成功: {len(result.get('rows', []))} 行数据, key_values={len(result.get('key_values', {}))}, list_items={len(result.get('list_items', []))}")
|
||||
return {
|
||||
"success": True,
|
||||
"type": "table_data",
|
||||
"headers": result.get("headers", []),
|
||||
"rows": result.get("rows", []),
|
||||
"description": result.get("description", "")
|
||||
"description": result.get("description", ""),
|
||||
"key_values": result.get("key_values", {}),
|
||||
"list_items": result.get("list_items", [])
|
||||
}
|
||||
else:
|
||||
# 如果 AI 返回格式不对,尝试直接解析表格
|
||||
|
||||
Reference in New Issue
Block a user