Files
FilesReadSystem/backend/app/services/template_fill_service.py
2026-04-09 17:40:10 +08:00

676 lines
27 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
表格模板填写服务
从非结构化文档中检索信息并填写到表格模板
"""
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from app.core.database import mongodb
from app.services.llm_service import llm_service
from app.core.document_parser import ParserFactory
logger = logging.getLogger(__name__)
@dataclass
class TemplateField:
    """One field of a table template that has to be filled in."""
    cell: str  # Cell position, e.g. "A1"
    name: str  # Field name
    field_type: str = "text"  # Field type: text/number/date
    required: bool = True
    hint: str = ""  # Prompt hint for this field
@dataclass
class SourceDocument:
    """A source document that field values are extracted from."""
    doc_id: str  # MongoDB document ID, or the file path for file-based sources
    filename: str
    doc_type: str
    content: str = ""  # Plain-text content, if any
    # Tabular data; the loader stores dicts carrying "rows" or "sheets" here
    structured_data: Dict[str, Any] = field(default_factory=dict)
@dataclass
class FillResult:
    """Extraction result for a single template field.

    ``values`` holds every extracted value (one entry per row). ``value`` is
    the legacy single-value slot kept for backward compatibility.
    """
    field: str
    # ``None`` is accepted (and is the default) and is normalized to a fresh
    # empty list in ``__post_init__``, hence the Optional annotation.
    values: Optional[List[Any]] = None  # Supports multiple values
    value: Any = ""  # Kept for backward compatibility
    source: str = ""  # Source document
    confidence: float = 1.0  # Confidence score

    def __post_init__(self):
        # Give each instance its own list instead of sharing a mutable default.
        if self.values is None:
            self.values = []
class TemplateFillService:
"""表格填写服务"""
def __init__(self):
    """Bind the module-level ``llm_service`` client to this instance."""
    self.llm = llm_service
async def fill_template(
    self,
    template_fields: List[TemplateField],
    source_doc_ids: Optional[List[str]] = None,
    source_file_paths: Optional[List[str]] = None,
    user_hint: Optional[str] = None
) -> Dict[str, Any]:
    """
    Fill a table template from the given source documents.

    Args:
        template_fields: Template fields to fill.
        source_doc_ids: MongoDB IDs of the source documents.
        source_file_paths: File paths of the source documents.
        user_hint: Optional user hint (e.g. "extract from the contract").

    Returns:
        Dict with ``success``, ``filled_data`` (field name -> list of values),
        ``fill_details`` (one entry per field), ``source_doc_count`` and
        ``max_rows`` (length of the longest value list, 1 if nothing filled).
    """
    total_fields = len(template_fields)
    logger.info(f"开始填表: {total_fields} 个字段, {len(source_doc_ids or [])} 个源文档")

    # Step 1: load the content of every source document.
    source_docs = await self._load_source_documents(source_doc_ids, source_file_paths)
    logger.info(f"加载了 {len(source_docs)} 个源文档")
    if not source_docs:
        logger.warning("没有找到源文档,填表结果将全部为空")

    filled_data = {}
    fill_details = []

    # Step 2: extract the fields one after another; a failure on one field is
    # recorded in the output and does not abort the remaining fields.
    for position, template_field in enumerate(template_fields, start=1):
        try:
            logger.info(f"提取字段 [{position}/{total_fields}]: {template_field.name}")
            extraction = await self._extract_field_value(
                field=template_field,
                source_docs=source_docs,
                user_hint=user_hint
            )
            # Always store a list; an empty result becomes one blank cell.
            filled_data[template_field.name] = extraction.values or [""]
            detail = {
                "field": template_field.name,
                "cell": template_field.cell,
                "values": extraction.values,
                "value": extraction.value,
                "source": extraction.source,
                "confidence": extraction.confidence
            }
            fill_details.append(detail)
            logger.info(f"字段 {template_field.name} 填写完成: {len(extraction.values)} 个值")
        except Exception as exc:
            logger.error(f"填写字段 {template_field.name} 失败: {str(exc)}", exc_info=True)
            filled_data[template_field.name] = [f"[提取失败: {str(exc)}]"]
            failure = {
                "field": template_field.name,
                "cell": template_field.cell,
                "values": ["[提取失败]"],
                "value": "[提取失败]",
                "source": "error",
                "confidence": 0.0
            }
            fill_details.append(failure)

    # The longest value list decides how many rows the filled table needs.
    max_rows = max(map(len, filled_data.values()), default=1)
    logger.info(f"填表完成: {len(filled_data)} 个字段, 最大行数: {max_rows}")

    summary = {
        "success": True,
        "filled_data": filled_data,
        "fill_details": fill_details,
        "source_doc_count": len(source_docs),
        "max_rows": max_rows
    }
    return summary
async def _load_source_documents(
    self,
    source_doc_ids: Optional[List[str]] = None,
    source_file_paths: Optional[List[str]] = None
) -> List[SourceDocument]:
    """
    Load source documents from MongoDB and/or from files on disk.

    Loading is best-effort: a document that cannot be found, loaded or parsed
    is logged and skipped, it never aborts the whole load.

    Args:
        source_doc_ids: MongoDB document IDs.
        source_file_paths: Paths of source files to parse.

    Returns:
        The successfully loaded source documents (MongoDB ones first).
    """
    import os

    source_docs = []

    # 1. Documents stored in MongoDB.
    for doc_id in source_doc_ids or []:
        try:
            doc = await mongodb.get_document(doc_id)
            if not doc:
                # Previously skipped silently, which made empty results hard
                # to diagnose.
                logger.warning(f"MongoDB中未找到文档: {doc_id}")
                continue
            # Tolerate keys that are present but stored as None.
            metadata = doc.get("metadata") or {}
            source_docs.append(SourceDocument(
                doc_id=doc_id,
                filename=metadata.get("original_filename", "unknown"),
                doc_type=doc.get("doc_type", "unknown"),
                content=doc.get("content") or "",
                structured_data=doc.get("structured_data") or {}
            ))
            logger.info(f"从MongoDB加载文档: {doc_id}")
        except Exception as e:
            logger.error(f"从MongoDB加载文档失败 {doc_id}: {str(e)}")

    # 2. Documents parsed straight from files.
    for file_path in source_file_paths or []:
        try:
            parser = ParserFactory.get_parser(file_path)
            # NOTE(review): parse() is called synchronously inside a coroutine;
            # if it does blocking I/O it will stall the event loop - confirm.
            result = parser.parse(file_path)
            if not result.success:
                logger.warning(f"解析文件失败, 已跳过: {file_path}")
                continue

            # The shape of result.data depends on the parser:
            # - Excel, single sheet: {columns: [...], rows: [...], row_count, column_count}
            # - Excel, multi sheet:  {sheets: {sheet_name: {columns, rows, ...}}}
            # - Word/TXT:            {content: "...", structured_data: {...}}
            doc_data = result.data if isinstance(result.data, dict) else {}
            doc_content = doc_data.get("content") or ""
            # Only tabular payloads are kept as structured data.
            is_tabular = "rows" in doc_data or "sheets" in doc_data
            doc_structured = doc_data if is_tabular else {}

            metadata = result.metadata or {}
            source_docs.append(SourceDocument(
                doc_id=file_path,
                # basename() copes with the platform's path separators,
                # unlike a hard-coded split on "/".
                filename=metadata.get("filename", os.path.basename(file_path)),
                doc_type=metadata.get("extension", "unknown").replace(".", ""),
                content=doc_content,
                structured_data=doc_structured
            ))
            logger.info(f"从文件加载文档: {file_path}, content长度: {len(doc_content)}, structured数据: {bool(doc_structured)}")
        except Exception as e:
            logger.error(f"从文件加载文档失败 {file_path}: {str(e)}")

    return source_docs
async def _extract_field_value(
    self,
    field: TemplateField,
    source_docs: List[SourceDocument],
    user_hint: Optional[str] = None
) -> FillResult:
    """
    Ask the LLM to extract all row values of one field from the sources.

    Args:
        field: Definition of the field to extract.
        source_docs: Loaded source documents.
        user_hint: Optional free-text hint from the user.

    Returns:
        A FillResult. On any failure (LLM error, parsing error) a result with
        ``values == [""]`` and ``confidence == 0.0`` is returned instead of
        raising.
    """
    if not source_docs:
        return FillResult(
            field=field.name,
            value="",
            source="无源文档",
            confidence=0.0
        )

    # Build the context; passing the field name narrows tables to that column.
    context_text = self._build_context_text(source_docs, field_name=field.name, max_length=8000)

    # The hint used to be computed and then dropped; it is now part of the
    # prompt so field-level and user-level hints actually reach the model.
    hint_text = field.hint if field.hint else f"请提取{field.name}的信息"
    if user_hint:
        hint_text = f"{user_hint} {hint_text}"

    prompt_lines = [
        f'你是一个专业的数据提取专家。请从以下文档内容中提取"{field.name}"字段的所有行数据。',
        f'提取要求: {hint_text}',
        f'参考文档内容(已提取" {field.name}"列的数据):',
        context_text,
        f'请提取上述所有行的" {field.name}"值,存入数组。每一行对应数组中的一个元素。',
        '如果某行该字段为空,请用空字符串""占位。',
        '请严格按照以下 JSON 格式输出,不要添加任何解释:',
        '{',
        '"values": ["第1行的值", "第2行的值", "第3行的值", ...],',
        '"source": "数据来源的文档描述",',
        '"confidence": 0.0到1.0之间的置信度',
        '}',
    ]
    prompt = "\n".join(prompt_lines) + "\n"

    messages = [
        {"role": "system", "content": "你是一个专业的数据提取助手。请严格按JSON格式输出。"},
        {"role": "user", "content": prompt}
    ]

    try:
        # NOTE(review): max_tokens=500 may truncate the JSON reply when the
        # source has many rows; confirm against the LLM service limits.
        response = await self.llm.chat(
            messages=messages,
            temperature=0.1,
            max_tokens=500
        )
        content = self.llm.extract_message_content(response)

        values, source, confidence = self._parse_llm_extraction(content, field.name)
        # Always hand back at least one (blank) cell.
        if not values:
            values = [""]
        return FillResult(
            field=field.name,
            values=values,
            value=values[0],
            source=source,
            confidence=confidence
        )
    except Exception as e:
        logger.error(f"LLM 提取失败: {str(e)}")
        return FillResult(
            field=field.name,
            values=[""],
            value="",
            source=f"提取失败: {str(e)}",
            confidence=0.0
        )


def _parse_llm_extraction(self, content, field_name):
    """Turn the raw LLM reply into a ``(values, source, confidence)`` tuple."""
    import json
    import re

    if isinstance(content, str):
        text = content
    else:
        text = "" if content is None else str(content)

    payload = None
    try:
        # Attempt 1: the whole reply is JSON.
        payload = json.loads(text)
    except json.JSONDecodeError:
        # Attempt 2: JSON object embedded in prose or a Markdown code fence.
        json_match = re.search(r'\{[\s\S]*\}', text)
        if json_match:
            try:
                payload = json.loads(json_match.group())
            except json.JSONDecodeError as e:
                logger.error(f"字段 {field_name} JSON 解析失败: {str(e)}")
        else:
            logger.warning(f"字段 {field_name} 未找到 JSON: {text[:200]}")

    # A bare top-level array is treated as the list of values.
    if isinstance(payload, list):
        payload = {"values": payload}

    if not isinstance(payload, dict):
        # No usable JSON: fall back to heuristic text extraction.
        return self._extract_values_from_text(text, field_name), "文本提取", 0.3

    values = []
    raw_values = payload.get("values")
    if isinstance(raw_values, list):
        # Preferred format. JSON null becomes a blank cell, not "None".
        values = ["" if v is None else str(v) for v in raw_values]
        logger.info(f"字段 {field_name} 使用 values 数组格式: {len(values)} 个值")
    elif "value" in payload:
        raw_value = payload.get("value")
        single = "" if raw_value is None else str(raw_value)
        values = [single] if single else []

    raw_source = payload.get("source")
    source = str(raw_source) if raw_source else "LLM生成"
    confidence = self._coerce_confidence(payload.get("confidence", 0.5))
    logger.info(f"字段 {field_name} JSON 解析成功")
    return values, source, confidence


def _coerce_confidence(self, raw, default=0.5):
    """Convert an LLM-supplied confidence to a float in [0, 1]; *default* if unusable."""
    try:
        confidence = float(raw)
    except (TypeError, ValueError):
        # e.g. "high" or null - must not discard values that already parsed.
        return default
    if confidence != confidence:  # NaN
        return default
    return min(1.0, max(0.0, confidence))
def _build_context_text(self, source_docs: List[SourceDocument], field_name: Optional[str] = None, max_length: int = 8000) -> str:
    """
    Build the context text that is sent to the LLM.

    Structured (tabular) data is preferred over plain text. When *field_name*
    matches a column, only that column is rendered (one numbered line per
    row); otherwise the whole table is rendered. If a document's structured
    data yields nothing, its plain-text content (first 5000 chars) is used.

    Args:
        source_docs: Source documents.
        field_name: Field to extract (optional; narrows tables to one column).
        max_length: Character budget for the document contexts.

    Returns:
        The context text, or a placeholder string when every source is empty.
    """
    contexts = []
    total_length = 0

    for doc in source_docs:
        parts = []
        row_count = 0
        structured = doc.structured_data if isinstance(doc.structured_data, dict) else {}
        sheets = structured.get("sheets")

        if sheets:
            # Multi-sheet format: {sheets: {sheet_name: {columns, rows}}}
            if isinstance(sheets, dict):
                for sheet_name, sheet_data in sheets.items():
                    if not isinstance(sheet_data, dict):
                        continue
                    columns = sheet_data.get("columns", [])
                    rows = sheet_data.get("rows", [])
                    if rows and columns:
                        parts.append(f"\n【文档: {doc.filename} - {sheet_name},共 {len(rows)} 行】\n")
                        parts.extend(self._render_table_lines(columns, rows, field_name))
                        row_count += len(rows)
        elif structured.get("rows"):
            # Single-sheet format: {columns: [...], rows: [...], ...}
            columns = structured.get("columns", [])
            rows = structured.get("rows", [])
            if rows and columns:
                parts.append(f"\n【文档: {doc.filename},共 {len(rows)} 行】\n")
                parts.extend(self._render_table_lines(columns, rows, field_name))
                row_count += len(rows)

        # Plain text is the fallback, including when the structured data was
        # present but rendered nothing (e.g. rows without columns).
        if not parts and doc.content:
            parts.append(doc.content[:5000])

        doc_content = "".join(parts)
        if not doc_content:
            continue

        doc_context = f"【文档: {doc.filename} ({doc.doc_type})】\n{doc_content}"
        logger.info(f"文档 {doc.filename} 上下文长度: {len(doc_context)}, 行数: {row_count}")

        if total_length + len(doc_context) <= max_length:
            contexts.append(doc_context)
            total_length += len(doc_context)
        else:
            # Budget exhausted: keep a truncated tail only if it is still
            # large enough to be useful, then stop adding documents.
            remaining = max_length - total_length
            if remaining > 100:
                doc_context = doc_context[:remaining] + "\n...(内容被截断)"
                contexts.append(doc_context)
            logger.warning(f"上下文被截断: {doc.filename}, 总长度: {total_length + len(doc_context)}")
            break

    result = "\n\n".join(contexts) if contexts else "(源文档内容为空)"
    logger.info(f"最终上下文长度: {len(result)}")
    return result


def _find_target_column(self, columns, field_name):
    """Return the index of the column matching *field_name*, or None."""
    needle = str(field_name).strip().lower()
    if not needle:
        return None
    names = [str(col).strip().lower() for col in columns]
    # An exact match wins over a substring match.
    for idx, name in enumerate(names):
        if name == needle:
            return idx
    for idx, name in enumerate(names):
        # Skip empty names: "" is a substring of everything.
        if name and (needle in name or name in needle):
            return idx
    return None


def _render_table_lines(self, columns, rows, field_name):
    """Render a table as text lines: one column if *field_name* matches, else all."""
    columns = list(columns)
    col_idx = self._find_target_column(columns, field_name) if field_name else None
    lines = []

    if col_idx is not None:
        target_col = columns[col_idx]
        lines.append(f"列名: {target_col}\n")
        for row_no, row in enumerate(rows, start=1):
            if isinstance(row, dict):
                val = row.get(target_col, "")
            elif isinstance(row, (list, tuple)) and col_idx < len(row):
                val = row[col_idx]
            else:
                # Unknown row shape or a row shorter than the header.
                val = ""
            lines.append(f"{row_no}: {val}\n")
        return lines

    # No matching column (or no field requested): dump every column.
    lines.append(" | ".join(str(col) for col in columns) + "\n")
    for row in rows:
        if isinstance(row, dict):
            lines.append(" | ".join(str(row.get(col, "")) for col in columns) + "\n")
        elif isinstance(row, (list, tuple)):
            lines.append(" | ".join(str(cell) for cell in row) + "\n")
    return lines
async def get_template_fields_from_file(
    self,
    file_path: str,
    file_type: str = "xlsx"
) -> List[TemplateField]:
    """
    Extract the field definitions from a template file.

    Args:
        file_path: Path of the template file.
        file_type: File type (xlsx/xls/docx). Case and a leading dot are
            ignored, so ".XLSX" is accepted as well.

    Returns:
        The extracted fields; an empty list if the type is unsupported or the
        extraction fails (the error is logged, not raised).
    """
    fields = []
    normalized_type = (file_type or "").strip().lower().lstrip(".")
    try:
        if normalized_type in ("xlsx", "xls"):
            fields = await self._get_template_fields_from_excel(file_path)
        elif normalized_type == "docx":
            fields = await self._get_template_fields_from_docx(file_path)
        else:
            # Previously an unknown type returned [] without any trace.
            logger.warning(f"不支持的模板文件类型: {file_type}")
    except Exception as e:
        logger.error(f"提取模板字段失败: {str(e)}")
    return fields
async def _get_template_fields_from_excel(self, file_path: str) -> List[TemplateField]:
    """
    Extract fields from an Excel template: one field per header column.

    The field type is inferred from the first non-empty sample among the
    first five data rows; a column with no samples is typed as "text".

    Args:
        file_path: Path of the Excel template.

    Returns:
        The fields found; on error the fields collected so far are returned
        and the error is logged.
    """
    fields = []
    try:
        import pandas as pd

        df = pd.read_excel(file_path, nrows=5)
        for idx, col in enumerate(df.columns):
            # Positional access is safe even if two headers share a label;
            # dropna() skips blank cells, which pandas reads as NaN and which
            # would otherwise be mistaken for a number.
            samples = df.iloc[:, idx].dropna()
            sample = samples.iloc[0] if len(samples) > 0 else ""
            fields.append(TemplateField(
                cell=self._column_to_cell(idx),
                name=str(col),
                field_type=self._infer_field_type_from_value(sample),
                required=True,
                hint=""
            ))
    except Exception as e:
        logger.error(f"从Excel提取字段失败: {str(e)}")
    return fields
async def _get_template_fields_from_docx(self, file_path: str) -> List[TemplateField]:
    """
    Extract fields from the tables of a Word template.

    Each table row is read as: first cell = field name, second cell (if any)
    = hint. Rows whose first cell is empty or a known header label are
    skipped. Errors are logged and the fields collected so far are returned.
    """
    collected = []
    try:
        from docx import Document

        # First-cell texts that mark a blank or header row rather than a field.
        header_labels = ["", "字段名", "名称", "项目"]

        document = Document(file_path)
        for table_no, table in enumerate(document.tables):
            for row_no, table_row in enumerate(table.rows):
                texts = [cell.text.strip() for cell in table_row.cells]
                # The first column is assumed to hold the field name.
                if not texts or not texts[0]:
                    continue
                label = texts[0]
                if label in header_labels:
                    continue
                note = texts[1] if len(texts) > 1 else ""
                collected.append(TemplateField(
                    cell=f"T{table_no}R{row_no}",
                    name=label,
                    field_type=self._infer_field_type_from_hint(note),
                    required=True,
                    hint=note
                ))
    except Exception as err:
        logger.error(f"从Word提取字段失败: {str(err)}")
    return collected
def _infer_field_type_from_hint(self, hint: str) -> str:
"""从提示词推断字段类型"""
hint_lower = hint.lower()
date_keywords = ["", "", "", "日期", "时间", "出生"]
if any(kw in hint for kw in date_keywords):
return "date"
number_keywords = ["数量", "金额", "人数", "面积", "增长", "比率", "%", "", "总计", "合计"]
if any(kw in hint_lower for kw in number_keywords):
return "number"
return "text"
def _infer_field_type_from_value(self, value: Any) -> str:
"""从示例值推断字段类型"""
if value is None or value == "":
return "text"
value_str = str(value)
# 检查日期模式
import re
if re.search(r'\d{4}[年/-]\d{1,2}[月/-]\d{1,2}', value_str):
return "date"
# 检查数值
try:
float(value_str.replace(',', '').replace('%', ''))
return "number"
except ValueError:
pass
return "text"
def _column_to_cell(self, col_idx: int) -> str:
"""将列索引转换为单元格列名 (0 -> A, 1 -> B, ...)"""
result = ""
while col_idx >= 0:
result = chr(65 + (col_idx % 26)) + result
col_idx = col_idx // 26 - 1
return result
def _extract_value_from_text(self, text: str, field_name: str) -> str:
"""
从非 JSON 文本中提取字段值(单值版本)
Args:
text: 原始文本
field_name: 字段名称
Returns:
提取的值
"""
values = self._extract_values_from_text(text, field_name)
return values[0] if values else ""
def _extract_values_from_text(self, text: str, field_name: str) -> List[str]:
"""
从非 JSON 文本中提取多个字段值
Args:
text: 原始文本
field_name: 字段名称
Returns:
提取的值列表
"""
import re
# 尝试匹配 JSON 数组格式
array_match = re.search(r'\[[\s\S]*\]', text)
if array_match:
try:
arr = json.loads(array_match.group())
if isinstance(arr, list):
return [str(v) for v in arr if v]
except:
pass
# 尝试用分号分割(如果文本中有分号分隔的多个值)
if '' in text or ';' in text:
separator = '' if '' in text else ';'
parts = text.split(separator)
values = []
for part in parts:
part = part.strip()
if part and len(part) < 500:
# 清理 Markdown 格式
part = re.sub(r'^\*\*|\*\*$', '', part)
part = re.sub(r'^\*|\*$', '', part)
values.append(part.strip())
if values:
return values
# 尝试多种模式匹配
patterns = [
# "字段名: 值" 或 "字段名:值" 格式
rf'{re.escape(field_name)}[:]\s*(.+?)(?:\n|$)',
# "值" 在引号中
rf'"value"\s*:\s*"([^"]+)"',
# "值" 在单引号中
rf"['\"]?value['\"]?\s*:\s*['\"]([^'\"]+)['\"]",
]
for pattern in patterns:
match = re.search(pattern, text, re.DOTALL)
if match:
value = match.group(1).strip()
# 清理 Markdown 格式
value = re.sub(r'^\*\*|\*\*$', '', value)
value = re.sub(r'^\*|\*$', '', value)
value = value.strip()
if value and len(value) < 1000:
return [value]
# 如果无法匹配,返回原始内容
content = text.strip()[:500] if text.strip() else ""
return [content] if content else []
# ==================== Global singleton ====================
template_fill_service = TemplateFillService()