zyh
This commit is contained in:
@@ -4,13 +4,12 @@
|
||||
从非结构化文档中检索信息并填写到表格模板
|
||||
"""
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.database import mongodb
|
||||
from app.services.rag_service import rag_service
|
||||
from app.services.llm_service import llm_service
|
||||
from app.services.excel_storage_service import excel_storage_service
|
||||
from app.core.document_parser import ParserFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -22,6 +21,17 @@ class TemplateField:
|
||||
name: str # 字段名称
|
||||
field_type: str = "text" # 字段类型: text/number/date
|
||||
required: bool = True
|
||||
hint: str = "" # 字段提示词
|
||||
|
||||
|
||||
@dataclass
|
||||
class SourceDocument:
|
||||
"""源文档"""
|
||||
doc_id: str
|
||||
filename: str
|
||||
doc_type: str
|
||||
content: str = ""
|
||||
structured_data: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -38,12 +48,12 @@ class TemplateFillService:
|
||||
|
||||
def __init__(self):
|
||||
self.llm = llm_service
|
||||
self.rag = rag_service
|
||||
|
||||
async def fill_template(
|
||||
self,
|
||||
template_fields: List[TemplateField],
|
||||
source_doc_ids: Optional[List[str]] = None,
|
||||
source_file_paths: Optional[List[str]] = None,
|
||||
user_hint: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -51,7 +61,8 @@ class TemplateFillService:
|
||||
|
||||
Args:
|
||||
template_fields: 模板字段列表
|
||||
source_doc_ids: 源文档ID列表,不指定则从所有文档检索
|
||||
source_doc_ids: 源文档 MongoDB ID 列表
|
||||
source_file_paths: 源文档文件路径列表
|
||||
user_hint: 用户提示(如"请从合同文档中提取")
|
||||
|
||||
Returns:
|
||||
@@ -60,28 +71,23 @@ class TemplateFillService:
|
||||
filled_data = {}
|
||||
fill_details = []
|
||||
|
||||
# 1. 加载源文档内容
|
||||
source_docs = await self._load_source_documents(source_doc_ids, source_file_paths)
|
||||
|
||||
if not source_docs:
|
||||
logger.warning("没有找到源文档,填表结果将全部为空")
|
||||
|
||||
# 2. 对每个字段进行提取
|
||||
for field in template_fields:
|
||||
try:
|
||||
# 1. 从 RAG 检索相关上下文
|
||||
rag_results = await self._retrieve_context(field.name, user_hint)
|
||||
# 从源文档中提取字段值
|
||||
result = await self._extract_field_value(
|
||||
field=field,
|
||||
source_docs=source_docs,
|
||||
user_hint=user_hint
|
||||
)
|
||||
|
||||
if not rag_results:
|
||||
# 如果没有检索到结果,尝试直接询问 LLM
|
||||
result = FillResult(
|
||||
field=field.name,
|
||||
value="",
|
||||
source="未找到相关数据",
|
||||
confidence=0.0
|
||||
)
|
||||
else:
|
||||
# 2. 构建 Prompt 让 LLM 提取信息
|
||||
result = await self._extract_field_value(
|
||||
field=field,
|
||||
rag_context=rag_results,
|
||||
user_hint=user_hint
|
||||
)
|
||||
|
||||
# 3. 存储结果
|
||||
# 存储结果
|
||||
filled_data[field.name] = result.value
|
||||
fill_details.append({
|
||||
"field": field.name,
|
||||
@@ -107,75 +113,113 @@ class TemplateFillService:
|
||||
return {
|
||||
"success": True,
|
||||
"filled_data": filled_data,
|
||||
"fill_details": fill_details
|
||||
"fill_details": fill_details,
|
||||
"source_doc_count": len(source_docs)
|
||||
}
|
||||
|
||||
async def _retrieve_context(
|
||||
async def _load_source_documents(
|
||||
self,
|
||||
field_name: str,
|
||||
user_hint: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
source_doc_ids: Optional[List[str]] = None,
|
||||
source_file_paths: Optional[List[str]] = None
|
||||
) -> List[SourceDocument]:
|
||||
"""
|
||||
从 RAG 检索相关上下文
|
||||
加载源文档内容
|
||||
|
||||
Args:
|
||||
field_name: 字段名称
|
||||
user_hint: 用户提示
|
||||
source_doc_ids: MongoDB 文档 ID 列表
|
||||
source_file_paths: 源文档文件路径列表
|
||||
|
||||
Returns:
|
||||
检索结果列表
|
||||
源文档列表
|
||||
"""
|
||||
# 构建查询文本
|
||||
query = field_name
|
||||
if user_hint:
|
||||
query = f"{user_hint} {field_name}"
|
||||
source_docs = []
|
||||
|
||||
# 检索相关文档片段
|
||||
results = self.rag.retrieve(query=query, top_k=5)
|
||||
# 1. 从 MongoDB 加载文档
|
||||
if source_doc_ids:
|
||||
for doc_id in source_doc_ids:
|
||||
try:
|
||||
doc = await mongodb.get_document(doc_id)
|
||||
if doc:
|
||||
source_docs.append(SourceDocument(
|
||||
doc_id=doc_id,
|
||||
filename=doc.get("metadata", {}).get("original_filename", "unknown"),
|
||||
doc_type=doc.get("doc_type", "unknown"),
|
||||
content=doc.get("content", ""),
|
||||
structured_data=doc.get("structured_data", {})
|
||||
))
|
||||
logger.info(f"从MongoDB加载文档: {doc_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"从MongoDB加载文档失败 {doc_id}: {str(e)}")
|
||||
|
||||
return results
|
||||
# 2. 从文件路径加载文档
|
||||
if source_file_paths:
|
||||
for file_path in source_file_paths:
|
||||
try:
|
||||
parser = ParserFactory.get_parser(file_path)
|
||||
result = parser.parse(file_path)
|
||||
if result.success:
|
||||
source_docs.append(SourceDocument(
|
||||
doc_id=file_path,
|
||||
filename=result.metadata.get("filename", file_path.split("/")[-1]),
|
||||
doc_type=result.metadata.get("extension", "unknown").replace(".", ""),
|
||||
content=result.data.get("content", ""),
|
||||
structured_data=result.data.get("structured_data", {})
|
||||
))
|
||||
logger.info(f"从文件加载文档: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"从文件加载文档失败 {file_path}: {str(e)}")
|
||||
|
||||
return source_docs
|
||||
|
||||
async def _extract_field_value(
|
||||
self,
|
||||
field: TemplateField,
|
||||
rag_context: List[Dict[str, Any]],
|
||||
source_docs: List[SourceDocument],
|
||||
user_hint: Optional[str] = None
|
||||
) -> FillResult:
|
||||
"""
|
||||
使用 LLM 从上下文中提取字段值
|
||||
使用 LLM 从源文档中提取字段值
|
||||
|
||||
Args:
|
||||
field: 字段定义
|
||||
rag_context: RAG 检索到的上下文
|
||||
source_docs: 源文档列表
|
||||
user_hint: 用户提示
|
||||
|
||||
Returns:
|
||||
提取结果
|
||||
"""
|
||||
# 构建上下文文本
|
||||
context_text = "\n\n".join([
|
||||
f"【文档 {i+1}】\n{doc['content']}"
|
||||
for i, doc in enumerate(rag_context)
|
||||
])
|
||||
if not source_docs:
|
||||
return FillResult(
|
||||
field=field.name,
|
||||
value="",
|
||||
source="无源文档",
|
||||
confidence=0.0
|
||||
)
|
||||
|
||||
# 构建 Prompt
|
||||
prompt = f"""你是一个数据提取专家。请根据以下文档内容,提取指定字段的信息。
|
||||
# 构建上下文文本
|
||||
context_text = self._build_context_text(source_docs, max_length=8000)
|
||||
|
||||
# 构建提示词
|
||||
hint_text = field.hint if field.hint else f"请提取{field.name}的信息"
|
||||
if user_hint:
|
||||
hint_text = f"{user_hint}。{hint_text}"
|
||||
|
||||
prompt = f"""你是一个专业的数据提取专家。请根据以下文档内容,提取指定字段的信息。
|
||||
|
||||
需要提取的字段:
|
||||
- 字段名称:{field.name}
|
||||
- 字段类型:{field.field_type}
|
||||
- 填写提示:{hint_text}
|
||||
- 是否必填:{'是' if field.required else '否'}
|
||||
|
||||
{'用户提示:' + user_hint if user_hint else ''}
|
||||
|
||||
参考文档内容:
|
||||
{context_text}
|
||||
|
||||
请严格按照以下 JSON 格式输出,不要添加任何解释:
|
||||
{{
|
||||
"value": "提取到的值,如果没有找到则填写空字符串",
|
||||
"source": "数据来源的文档描述",
|
||||
"confidence": 0.0到1.0之间的置信度
|
||||
"source": "数据来源的文档描述(如:来自xxx文档)",
|
||||
"confidence": 0.0到1.0之间的置信度,表示对提取结果的信心程度"
|
||||
}}
|
||||
"""
|
||||
|
||||
@@ -226,6 +270,54 @@ class TemplateFillService:
|
||||
confidence=0.0
|
||||
)
|
||||
|
||||
def _build_context_text(self, source_docs: List[SourceDocument], max_length: int = 8000) -> str:
|
||||
"""
|
||||
构建上下文文本
|
||||
|
||||
Args:
|
||||
source_docs: 源文档列表
|
||||
max_length: 最大字符数
|
||||
|
||||
Returns:
|
||||
上下文文本
|
||||
"""
|
||||
contexts = []
|
||||
total_length = 0
|
||||
|
||||
for doc in source_docs:
|
||||
# 优先使用结构化数据(表格),其次使用文本内容
|
||||
doc_content = ""
|
||||
|
||||
if doc.structured_data and doc.structured_data.get("tables"):
|
||||
# 如果有表格数据,优先使用
|
||||
tables = doc.structured_data.get("tables", [])
|
||||
for table in tables:
|
||||
if isinstance(table, dict):
|
||||
rows = table.get("rows", [])
|
||||
if rows:
|
||||
doc_content += f"\n【文档: {doc.filename} 表格数据】\n"
|
||||
for row in rows[:20]: # 限制每表最多20行
|
||||
if isinstance(row, list):
|
||||
doc_content += " | ".join(str(cell) for cell in row) + "\n"
|
||||
elif isinstance(row, dict):
|
||||
doc_content += " | ".join(str(v) for v in row.values()) + "\n"
|
||||
elif doc.content:
|
||||
doc_content = doc.content[:5000] # 限制文本长度
|
||||
|
||||
if doc_content:
|
||||
doc_context = f"【文档: {doc.filename} ({doc.doc_type})】\n{doc_content}"
|
||||
if total_length + len(doc_context) <= max_length:
|
||||
contexts.append(doc_context)
|
||||
total_length += len(doc_context)
|
||||
else:
|
||||
# 如果超出长度,截断
|
||||
remaining = max_length - total_length
|
||||
if remaining > 100:
|
||||
contexts.append(doc_context[:remaining])
|
||||
break
|
||||
|
||||
return "\n\n".join(contexts) if contexts else "(源文档内容为空)"
|
||||
|
||||
async def get_template_fields_from_file(
|
||||
self,
|
||||
file_path: str,
|
||||
@@ -236,7 +328,7 @@ class TemplateFillService:
|
||||
|
||||
Args:
|
||||
file_path: 模板文件路径
|
||||
file_type: 文件类型
|
||||
file_type: 文件类型 (xlsx/xls/docx)
|
||||
|
||||
Returns:
|
||||
字段列表
|
||||
@@ -245,43 +337,108 @@ class TemplateFillService:
|
||||
|
||||
try:
|
||||
if file_type in ["xlsx", "xls"]:
|
||||
# 从 Excel 读取表头
|
||||
import pandas as pd
|
||||
df = pd.read_excel(file_path, nrows=5)
|
||||
|
||||
for idx, col in enumerate(df.columns):
|
||||
# 获取单元格位置 (A, B, C, ...)
|
||||
cell = self._column_to_cell(idx)
|
||||
|
||||
fields.append(TemplateField(
|
||||
cell=cell,
|
||||
name=str(col),
|
||||
field_type=self._infer_field_type(df[col]),
|
||||
required=True
|
||||
))
|
||||
|
||||
fields = await self._get_template_fields_from_excel(file_path)
|
||||
elif file_type == "docx":
|
||||
# 从 Word 表格读取
|
||||
from docx import Document
|
||||
doc = Document(file_path)
|
||||
|
||||
for table_idx, table in enumerate(doc.tables):
|
||||
for row_idx, row in enumerate(table.rows):
|
||||
for col_idx, cell in enumerate(row.cells):
|
||||
cell_text = cell.text.strip()
|
||||
if cell_text:
|
||||
fields.append(TemplateField(
|
||||
cell=self._column_to_cell(col_idx),
|
||||
name=cell_text,
|
||||
field_type="text",
|
||||
required=True
|
||||
))
|
||||
fields = await self._get_template_fields_from_docx(file_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"提取模板字段失败: {str(e)}")
|
||||
|
||||
return fields
|
||||
|
||||
async def _get_template_fields_from_excel(self, file_path: str) -> List[TemplateField]:
|
||||
"""从 Excel 模板提取字段"""
|
||||
fields = []
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
df = pd.read_excel(file_path, nrows=5)
|
||||
|
||||
for idx, col in enumerate(df.columns):
|
||||
cell = self._column_to_cell(idx)
|
||||
col_str = str(col)
|
||||
|
||||
fields.append(TemplateField(
|
||||
cell=cell,
|
||||
name=col_str,
|
||||
field_type=self._infer_field_type_from_value(df[col].iloc[0] if len(df) > 0 else ""),
|
||||
required=True,
|
||||
hint=""
|
||||
))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从Excel提取字段失败: {str(e)}")
|
||||
|
||||
return fields
|
||||
|
||||
async def _get_template_fields_from_docx(self, file_path: str) -> List[TemplateField]:
|
||||
"""从 Word 模板提取字段"""
|
||||
fields = []
|
||||
|
||||
try:
|
||||
from docx import Document
|
||||
|
||||
doc = Document(file_path)
|
||||
|
||||
for table_idx, table in enumerate(doc.tables):
|
||||
for row_idx, row in enumerate(table.rows):
|
||||
cells = [cell.text.strip() for cell in row.cells]
|
||||
|
||||
# 假设第一列是字段名
|
||||
if cells and cells[0]:
|
||||
field_name = cells[0]
|
||||
hint = cells[1] if len(cells) > 1 else ""
|
||||
|
||||
# 跳过空行或标题行
|
||||
if field_name and field_name not in ["", "字段名", "名称", "项目"]:
|
||||
fields.append(TemplateField(
|
||||
cell=f"T{table_idx}R{row_idx}",
|
||||
name=field_name,
|
||||
field_type=self._infer_field_type_from_hint(hint),
|
||||
required=True,
|
||||
hint=hint
|
||||
))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从Word提取字段失败: {str(e)}")
|
||||
|
||||
return fields
|
||||
|
||||
def _infer_field_type_from_hint(self, hint: str) -> str:
|
||||
"""从提示词推断字段类型"""
|
||||
hint_lower = hint.lower()
|
||||
|
||||
date_keywords = ["年", "月", "日", "日期", "时间", "出生"]
|
||||
if any(kw in hint for kw in date_keywords):
|
||||
return "date"
|
||||
|
||||
number_keywords = ["数量", "金额", "人数", "面积", "增长", "比率", "%", "率", "总计", "合计"]
|
||||
if any(kw in hint_lower for kw in number_keywords):
|
||||
return "number"
|
||||
|
||||
return "text"
|
||||
|
||||
def _infer_field_type_from_value(self, value: Any) -> str:
|
||||
"""从示例值推断字段类型"""
|
||||
if value is None or value == "":
|
||||
return "text"
|
||||
|
||||
value_str = str(value)
|
||||
|
||||
# 检查日期模式
|
||||
import re
|
||||
if re.search(r'\d{4}[年/-]\d{1,2}[月/-]\d{1,2}', value_str):
|
||||
return "date"
|
||||
|
||||
# 检查数值
|
||||
try:
|
||||
float(value_str.replace(',', '').replace('%', ''))
|
||||
return "number"
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return "text"
|
||||
|
||||
def _column_to_cell(self, col_idx: int) -> str:
|
||||
"""将列索引转换为单元格列名 (0 -> A, 1 -> B, ...)"""
|
||||
result = ""
|
||||
@@ -290,17 +447,6 @@ class TemplateFillService:
|
||||
col_idx = col_idx // 26 - 1
|
||||
return result
|
||||
|
||||
def _infer_field_type(self, series) -> str:
|
||||
"""推断字段类型"""
|
||||
import pandas as pd
|
||||
|
||||
if pd.api.types.is_numeric_dtype(series):
|
||||
return "number"
|
||||
elif pd.api.types.is_datetime64_any_dtype(series):
|
||||
return "date"
|
||||
else:
|
||||
return "text"
|
||||
|
||||
|
||||
# ==================== 全局单例 ====================
|
||||
|
||||
|
||||
Reference in New Issue
Block a user