This commit is contained in:
dj
2026-04-09 21:42:14 +08:00
5 changed files with 797 additions and 288 deletions

View File

@@ -5,15 +5,18 @@
"""
import io
import logging
import uuid
from typing import List, Optional
from fastapi import APIRouter, File, HTTPException, Query, UploadFile
from fastapi import APIRouter, File, HTTPException, Query, UploadFile, BackgroundTasks
from fastapi.responses import StreamingResponse
import pandas as pd
from pydantic import BaseModel
from app.services.template_fill_service import template_fill_service, TemplateField
from app.services.file_service import file_service
from app.core.database import mongodb
from app.core.document_parser import ParserFactory
logger = logging.getLogger(__name__)
@@ -109,6 +112,172 @@ async def upload_template(
raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}")
@router.post("/upload-joint")
async def upload_joint_template(
background_tasks: BackgroundTasks,
template_file: UploadFile = File(..., description="模板文件"),
source_files: List[UploadFile] = File(..., description="源文档文件列表"),
):
"""
联合上传模板和源文档,一键完成解析和存储
1. 保存模板文件并提取字段
2. 异步处理源文档(解析+存MongoDB
3. 返回模板信息和源文档ID列表
Args:
template_file: 模板文件 (xlsx/xls/docx)
source_files: 源文档列表 (docx/xlsx/md/txt)
Returns:
模板ID、字段列表、源文档ID列表
"""
if not template_file.filename:
raise HTTPException(status_code=400, detail="模板文件名为空")
# 验证模板格式
template_ext = template_file.filename.split('.')[-1].lower()
if template_ext not in ['xlsx', 'xls', 'docx']:
raise HTTPException(
status_code=400,
detail=f"不支持的模板格式: {template_ext},仅支持 xlsx/xls/docx"
)
# 验证源文档格式
valid_exts = ['docx', 'xlsx', 'xls', 'md', 'txt']
for sf in source_files:
if sf.filename:
sf_ext = sf.filename.split('.')[-1].lower()
if sf_ext not in valid_exts:
raise HTTPException(
status_code=400,
detail=f"不支持的源文档格式: {sf_ext},仅支持 docx/xlsx/xls/md/txt"
)
try:
# 1. 保存模板文件并提取字段
template_content = await template_file.read()
template_path = file_service.save_uploaded_file(
template_content,
template_file.filename,
subfolder="templates"
)
template_fields = await template_fill_service.get_template_fields_from_file(
template_path,
template_ext
)
# 2. 处理源文档 - 保存文件
source_file_info = []
for sf in source_files:
if sf.filename:
sf_content = await sf.read()
sf_ext = sf.filename.split('.')[-1].lower()
sf_path = file_service.save_uploaded_file(
sf_content,
sf.filename,
subfolder=sf_ext
)
source_file_info.append({
"path": sf_path,
"filename": sf.filename,
"ext": sf_ext
})
# 3. 异步处理源文档到MongoDB
task_id = str(uuid.uuid4())
if source_file_info:
background_tasks.add_task(
process_source_documents,
task_id=task_id,
files=source_file_info
)
logger.info(f"联合上传完成: 模板={template_file.filename}, 源文档={len(source_file_info)}")
return {
"success": True,
"template_id": template_path,
"filename": template_file.filename,
"file_type": template_ext,
"fields": [
{
"cell": f.cell,
"name": f.name,
"field_type": f.field_type,
"required": f.required,
"hint": f.hint
}
for f in template_fields
],
"field_count": len(template_fields),
"source_file_paths": [f["path"] for f in source_file_info],
"source_filenames": [f["filename"] for f in source_file_info],
"task_id": task_id
}
except HTTPException:
raise
except Exception as e:
logger.error(f"联合上传失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"联合上传失败: {str(e)}")
async def process_source_documents(task_id: str, files: List[dict]):
"""异步处理源文档存入MongoDB"""
from app.core.database import redis_db
try:
await redis_db.set_task_status(
task_id, status="processing",
meta={"progress": 0, "message": "开始处理源文档"}
)
doc_ids = []
for i, file_info in enumerate(files):
try:
parser = ParserFactory.get_parser(file_info["path"])
result = parser.parse(file_info["path"])
if result.success:
doc_id = await mongodb.insert_document(
doc_type=file_info["ext"],
content=result.data.get("content", ""),
metadata={
**result.metadata,
"original_filename": file_info["filename"],
"file_path": file_info["path"]
},
structured_data=result.data.get("structured_data")
)
doc_ids.append(doc_id)
logger.info(f"源文档处理成功: {file_info['filename']}, doc_id: {doc_id}")
else:
logger.error(f"源文档解析失败: {file_info['filename']}, error: {result.error}")
except Exception as e:
logger.error(f"源文档处理异常: {file_info['filename']}, error: {str(e)}")
progress = int((i + 1) / len(files) * 100)
await redis_db.set_task_status(
task_id, status="processing",
meta={"progress": progress, "message": f"已处理 {i+1}/{len(files)}"}
)
await redis_db.set_task_status(
task_id, status="success",
meta={"progress": 100, "message": "源文档处理完成", "doc_ids": doc_ids}
)
logger.info(f"所有源文档处理完成: {len(doc_ids)}")
except Exception as e:
logger.error(f"源文档批量处理失败: {str(e)}")
await redis_db.set_task_status(
task_id, status="failure",
meta={"error": str(e)}
)
@router.post("/fields")
async def extract_template_fields(
template_id: str = Query(..., description="模板ID/文件路径"),

View File

@@ -10,6 +10,7 @@ from typing import Any, Dict, List, Optional
from app.core.database import mongodb
from app.services.llm_service import llm_service
from app.core.document_parser import ParserFactory
from app.services.markdown_ai_service import markdown_ai_service
logger = logging.getLogger(__name__)
@@ -285,6 +286,12 @@ class TemplateFillService:
confidence=1.0
)
# 无法直接从结构化数据提取,尝试 AI 分析非结构化文档
ai_structured = await self._analyze_unstructured_docs_for_fields(source_docs, field, user_hint)
if ai_structured:
logger.info(f"✅ 字段 {field.name} 通过 AI 分析结构化提取到数据")
return ai_structured
# 无法从结构化数据提取,使用 LLM
logger.info(f"字段 {field.name} 无法直接从结构化数据提取,使用 LLM...")
@@ -296,18 +303,20 @@ class TemplateFillService:
if user_hint:
hint_text = f"{user_hint}{hint_text}"
prompt = f"""你是一个专业的数据提取专家。请从以下文档内容中提取"{field.name}"字段的所有行数据
prompt = f"""你是一个专业的数据提取专家。请从以下文档内容中提取"{field.name}"相关的所有信息
参考文档内容(已提取" {field.name}"列的数据):
提示词: {hint_text}
文档内容:
{context_text}
提取上述所有行的" {field.name}"值,存入数组。每一行对应数组中的一个元素
如果某行该字段为空,请用空字符串""占位
分析文档结构(可能包含表格、标题段落等),找出所有与"{field.name}"相关的数据
如果找到表格数据,返回多行值;如果是非表格段落,提取关键信息
请严格按照以下 JSON 格式输出,不要添加任何解释
请严格按照以下 JSON 格式输出:
{{
"values": ["第1行的值", "第2行的值", "第3行的值", ...],
"source": "数据来源的文档描述",
"values": ["第1行的值", "第2行的值", ...],
"source": "数据来源描述",
"confidence": 0.0到1.0之间的置信度
}}
"""
@@ -525,6 +534,29 @@ class TemplateFillService:
elif isinstance(row, list):
doc_content += " | ".join(str(cell) for cell in row) + "\n"
row_count += 1
elif doc.structured_data and doc.structured_data.get("tables"):
# Markdown 表格格式: {tables: [{headers: [...], rows: [...]}]}
tables = doc.structured_data.get("tables", [])
for table in tables:
if isinstance(table, dict):
headers = table.get("headers", [])
rows = table.get("rows", [])
if rows and headers:
doc_content += f"\n【文档: {doc.filename} - 表格】\n"
doc_content += " | ".join(str(h) for h in headers) + "\n"
for row in rows:
if isinstance(row, list):
doc_content += " | ".join(str(cell) for cell in row) + "\n"
row_count += 1
# 如果有标题结构,也添加上下文
if doc.structured_data.get("titles"):
titles = doc.structured_data.get("titles", [])
doc_content += f"\n【文档章节结构】\n"
for title in titles[:20]: # 限制前20个标题
doc_content += f"{'#' * title.get('level', 1)} {title.get('text', '')}\n"
# 如果没有提取到表格内容,使用纯文本
if not doc_content.strip():
doc_content = doc.content[:5000] if doc.content else ""
elif doc.content:
doc_content = doc.content[:5000]
@@ -804,6 +836,21 @@ class TemplateFillService:
logger.info(f"从文档 {doc.filename} 提取到 {len(values)} 个值")
break
# 处理 Markdown 表格格式: {tables: [{headers: [...], rows: [...]}]}
elif structured.get("tables"):
tables = structured.get("tables", [])
for table in tables:
if isinstance(table, dict):
headers = table.get("headers", [])
rows = table.get("rows", [])
values = self._extract_column_values(rows, headers, field_name)
if values:
all_values.extend(values)
logger.info(f"从 Markdown 表格提取到 {len(values)} 个值")
break
if all_values:
break
return all_values
def _extract_values_from_markdown_table(self, headers: List, rows: List, field_name: str) -> List[str]:
@@ -1223,6 +1270,145 @@ class TemplateFillService:
content = text.strip()[:500] if text.strip() else ""
return [content] if content else []
async def _analyze_unstructured_docs_for_fields(
self,
source_docs: List[SourceDocument],
field: TemplateField,
user_hint: Optional[str] = None
) -> Optional[FillResult]:
"""
对非结构化文档进行 AI 分析,尝试提取结构化数据
适用于 Markdown 等没有表格格式的文档,通过 AI 分析提取结构化信息
Args:
source_docs: 源文档列表
field: 字段定义
user_hint: 用户提示
Returns:
FillResult 如果提取成功,否则返回 None
"""
# 找出非结构化的 Markdown/TXT 文档(没有表格的)
unstructured_docs = []
for doc in source_docs:
if doc.doc_type in ["md", "txt", "markdown"]:
# 检查是否有表格
has_tables = (
doc.structured_data and
doc.structured_data.get("tables") and
len(doc.structured_data.get("tables", [])) > 0
)
if not has_tables:
unstructured_docs.append(doc)
if not unstructured_docs:
return None
logger.info(f"发现 {len(unstructured_docs)} 个非结构化文档,尝试 AI 分析...")
# 对每个非结构化文档进行 AI 分析
for doc in unstructured_docs:
try:
# 使用 markdown_ai_service 的 statistics 分析类型
# 这种类型专门用于政府统计公报等包含数据的文档
hint_text = field.hint if field.hint else f"请提取{field.name}的信息"
if user_hint:
hint_text = f"{user_hint}{hint_text}"
# 构建针对字段提取的提示词
prompt = f"""你是一个专业的数据提取专家。请从以下文档内容中提取与"{field.name}"相关的所有数据。
字段提示: {hint_text}
文档内容:
{doc.content[:8000] if doc.content else ""}
请完成以下任务:
1. 仔细阅读文档,找出所有与"{field.name}"相关的数据
2. 如果文档中有表格数据,提取表格中的对应列值
3. 如果文档中是段落描述,提取其中的关键数值或结论
4. 返回提取的所有值(可能多个,用数组存储)
请用严格的 JSON 格式返回:
{{
"values": ["值1", "值2", ...],
"source": "数据来源说明",
"confidence": 0.0到1.0之间的置信度
}}
如果没有找到相关数据,返回空数组 values: []"""
messages = [
{"role": "system", "content": "你是一个专业的数据提取助手擅长从政府统计公报等文档中提取数据。请严格按JSON格式输出。"},
{"role": "user", "content": prompt}
]
response = await self.llm.chat(
messages=messages,
temperature=0.1,
max_tokens=5000
)
content = self.llm.extract_message_content(response)
logger.info(f"AI 分析返回: {content[:500]}")
# 解析 JSON
import json
import re
# 清理 markdown 格式
cleaned = content.strip()
cleaned = re.sub(r'^```json\s*', '', cleaned, flags=re.MULTILINE)
cleaned = re.sub(r'^```\s*', '', cleaned, flags=re.MULTILINE)
cleaned = cleaned.strip()
# 查找 JSON
json_start = -1
for i, c in enumerate(cleaned):
if c == '{' or c == '[':
json_start = i
break
if json_start == -1:
continue
json_text = cleaned[json_start:]
try:
result = json.loads(json_text)
values = self._extract_values_from_json(result)
if values:
return FillResult(
field=field.name,
values=values,
value=values[0] if values else "",
source=f"AI分析: {doc.filename}",
confidence=result.get("confidence", 0.8)
)
except json.JSONDecodeError:
# 尝试修复 JSON
fixed = self._fix_json(json_text)
if fixed:
try:
result = json.loads(fixed)
values = self._extract_values_from_json(result)
if values:
return FillResult(
field=field.name,
values=values,
value=values[0] if values else "",
source=f"AI分析: {doc.filename}",
confidence=result.get("confidence", 0.8)
)
except json.JSONDecodeError:
pass
except Exception as e:
logger.warning(f"AI 分析文档 {doc.filename} 失败: {str(e)}")
continue
return None
# ==================== 全局单例 ====================