添加其他格式文档的解析

This commit is contained in:
2026-03-26 23:14:39 +08:00
parent 4bdc3f9707
commit 5bcad4a5fa
9 changed files with 2075 additions and 22 deletions

View File

@@ -4,10 +4,11 @@ API 路由注册模块
from fastapi import APIRouter
from app.api.endpoints import (
upload,
documents, # 新增:文档上传
tasks, # 新增:任务管理
library, # 新增:文档库
rag, # 新增:RAG检索
documents, # 多格式文档上传
tasks, # 任务管理
library, # 文档库
rag, # RAG检索
templates, # 表格模板
ai_analyze,
visualization,
analysis_charts,
@@ -18,12 +19,13 @@ from app.api.endpoints import (
api_router = APIRouter()
# 注册各模块路由
api_router.include_router(health.router) # 健康检查
api_router.include_router(upload.router) # 原有Excel上传
api_router.include_router(health.router) # 健康检查
api_router.include_router(upload.router) # 原有Excel上传
api_router.include_router(documents.router) # 多格式文档上传
api_router.include_router(tasks.router) # 任务状态查询
api_router.include_router(library.router) # 文档库管理
api_router.include_router(rag.router) # RAG检索
api_router.include_router(ai_analyze.router) # AI分析
api_router.include_router(tasks.router) # 任务状态查询
api_router.include_router(library.router) # 文档库管理
api_router.include_router(rag.router) # RAG检索
api_router.include_router(templates.router) # 表格模板
api_router.include_router(ai_analyze.router) # AI分析
api_router.include_router(visualization.router) # 可视化
api_router.include_router(analysis_charts.router) # 分析图表
api_router.include_router(analysis_charts.router) # 分析图表

View File

@@ -0,0 +1,228 @@
"""
表格模板 API 接口
提供模板上传、解析和填写功能
"""
import io
import logging
from typing import List, Optional

import pandas as pd
from fastapi import APIRouter, File, HTTPException, Query, UploadFile
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from app.services.excel_storage_service import excel_storage_service
from app.services.template_fill_service import TemplateField, template_fill_service
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/templates", tags=["表格模板"])
# ==================== 请求/响应模型 ====================
class TemplateFieldRequest(BaseModel):
    """One fillable field of a template, as supplied by the client."""
    cell: str                  # cell coordinate, e.g. "A1"
    name: str                  # human-readable field name
    field_type: str = "text"   # value type hint: text / number / date / ...
    required: bool = True      # whether the field must be filled


class FillRequest(BaseModel):
    """Request body for POST /templates/fill."""
    template_id: str                                  # id returned by /upload (a file path)
    template_fields: List[TemplateFieldRequest]       # fields to fill
    source_doc_ids: Optional[List[str]] = None        # restrict retrieval to these docs
    user_hint: Optional[str] = None                   # free-form hint for the LLM


class ExportRequest(BaseModel):
    """Request body for POST /templates/export."""
    template_id: str
    filled_data: dict          # {field name: value} produced by /fill
    format: str = "xlsx"       # requested format: xlsx or docx (only xlsx is implemented)
# ==================== 接口实现 ====================
@router.post("/upload")
async def upload_template(
    file: UploadFile = File(...),
):
    """
    Upload a table template file.

    Accepts Excel (.xlsx, .xls) and Word (.docx) templates, stores the file
    under the "templates" subfolder, then extracts the fillable fields.

    Returns:
        Template info including the extracted field list.

    Raises:
        HTTPException 400 for a missing filename or unsupported extension,
        HTTPException 500 for any storage/parsing failure.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="文件名为空")

    file_ext = file.filename.split('.')[-1].lower()
    if file_ext not in ['xlsx', 'xls', 'docx']:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的模板格式: {file_ext},仅支持 xlsx/xls/docx"
        )

    try:
        # Imported lazily — presumably to avoid a circular import at module
        # load time; confirm before hoisting to the top of the file.
        from app.services.file_service import file_service

        content = await file.read()
        saved_path = file_service.save_uploaded_file(
            content,
            file.filename,
            subfolder="templates"
        )

        # Extract the fillable fields from the stored template.
        template_fields = await template_fill_service.get_template_fields_from_file(
            saved_path,
            file_ext
        )

        return {
            "success": True,
            # NOTE(review): the server-side file path doubles as the template id
            # and is echoed back to the client, then accepted verbatim by
            # /templates/fields — consider an opaque id so callers cannot feed
            # arbitrary paths back in. Confirm threat model.
            "template_id": saved_path,
            "filename": file.filename,
            "file_type": file_ext,
            "fields": [
                {
                    "cell": f.cell,
                    "name": f.name,
                    "field_type": f.field_type,
                    "required": f.required,
                }
                for f in template_fields
            ],
            "field_count": len(template_fields),
        }
    except Exception as e:
        logger.error(f"上传模板失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}")
@router.post("/fields")
async def extract_template_fields(
    template_id: str = Query(..., description="模板ID/文件路径"),
    file_type: str = Query("xlsx", description="文件类型")
):
    """
    Extract the field definitions from an already-uploaded template.

    Args:
        template_id: template id returned by /upload (currently a file path)
        file_type: template file type (xlsx / xls / docx)

    Returns:
        The list of fields found in the template.

    Raises:
        HTTPException 500 when extraction fails.
    """
    try:
        fields = await template_fill_service.get_template_fields_from_file(
            template_id,
            file_type
        )
        return {
            "success": True,
            "fields": [
                {
                    "cell": f.cell,
                    "name": f.name,
                    "field_type": f.field_type,
                    "required": f.required,
                }
                for f in fields
            ],
        }
    except Exception as e:
        logger.error(f"提取字段失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"提取失败: {str(e)}")
@router.post("/fill")
async def fill_template(
    request: FillRequest,
):
    """
    Fill a template: retrieve a value for each declared field from the
    uploaded source documents and return the fill result.

    Args:
        request: field definitions, optional source-doc filter and user hint.

    Returns:
        Whatever template_fill_service.fill_template returns (fill result dict).

    Raises:
        HTTPException 500 when the fill pipeline fails.
    """
    try:
        # Convert API-layer models into the service-layer dataclass.
        fields = [
            TemplateField(
                cell=f.cell,
                name=f.name,
                field_type=f.field_type,
                required=f.required,
            )
            for f in request.template_fields
        ]
        result = await template_fill_service.fill_template(
            template_fields=fields,
            source_doc_ids=request.source_doc_ids,
            user_hint=request.user_hint,
        )
        return result
    except Exception as e:
        logger.error(f"填写表格失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"填写失败: {str(e)}")
@router.post("/export")
async def export_filled_template(
    request: ExportRequest,
):
    """
    Export filled data as a downloadable Excel file.

    Note: only xlsx output is implemented — `request.format` does not change
    the payload, so the filename always uses .xlsx to match the actual
    content type (previously the extension followed the requested format
    even though the body was always xlsx).

    Returns:
        StreamingResponse with an attachment Content-Disposition header.
    """
    try:
        # One-row sheet built from the filled field/value mapping.
        df = pd.DataFrame([request.filled_data])
        output = io.BytesIO()
        with pd.ExcelWriter(output, engine='openpyxl') as writer:
            df.to_excel(writer, index=False, sheet_name='填写结果')
        output.seek(0)
        # BUG FIX: the Content-Disposition header contained a literal
        # placeholder instead of the generated filename.
        filename = "filled_template.xlsx"
        return StreamingResponse(
            output,
            media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            headers={"Content-Disposition": f'attachment; filename="{filename}"'},
        )
    except Exception as e:
        logger.error(f"导出失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"导出失败: {str(e)}")
# ==================== 需要添加的 import ====================
import logging

View File

@@ -2,26 +2,29 @@
文档解析模块 - 支持多种文件格式的解析
"""
from pathlib import Path
from typing import Dict, Optional
from typing import Dict
from .base import BaseParser, ParseResult
from .xlsx_parser import XlsxParser
# 导入其他解析器 (需要先实现)
# from .docx_parser import DocxParser
# from .md_parser import MarkdownParser
# from .txt_parser import TxtParser
from .docx_parser import DocxParser
from .md_parser import MarkdownParser
from .txt_parser import TxtParser
class ParserFactory:
    """Parser factory: maps a file extension to its parser instance."""

    # Shared parser instances keyed by lowercase extension (with dot).
    _parsers: Dict[str, BaseParser] = {
        # Excel
        '.xlsx': XlsxParser(),
        '.xls': XlsxParser(),
        # Word
        '.docx': DocxParser(),
        # Markdown
        '.md': MarkdownParser(),
        '.markdown': MarkdownParser(),
        # Plain text
        '.txt': TxtParser(),
    }

    @classmethod
    def get_parser(cls, file_path: str) -> BaseParser:
        """
        Return the parser registered for file_path's extension.

        Raises:
            ValueError: when no parser is registered for the extension.
        """
        ext = Path(file_path).suffix.lower()
        parser = cls._parsers.get(ext)
        if not parser:
            supported = list(cls._parsers.keys())
            raise ValueError(f"不支持的文件格式: {ext},支持的格式: {supported}")
        return parser

    @classmethod
    def register_parser(cls, ext: str, parser: BaseParser) -> None:
        """Register a new parser for the given extension (case-insensitive)."""
        cls._parsers[ext.lower()] = parser

    @classmethod
    def get_supported_extensions(cls) -> list:
        """Return all supported file extensions."""
        return list(cls._parsers.keys())
__all__ = ['BaseParser', 'ParseResult', 'XlsxParser', 'ParserFactory']
__all__ = [
'BaseParser',
'ParseResult',
'ParserFactory',
'XlsxParser',
'DocxParser',
'MarkdownParser',
'TxtParser',
]

View File

@@ -0,0 +1,163 @@
"""
Word 文档 (.docx) 解析器
"""
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
from docx import Document
from .base import BaseParser, ParseResult
logger = logging.getLogger(__name__)
class DocxParser(BaseParser):
    """Word (.docx) parser: extracts paragraph text and table contents."""

    def __init__(self):
        super().__init__()
        self.supported_extensions = ['.docx']
        self.parser_name = "docx_parser"

    def parse(
        self,
        file_path: str,
        **kwargs
    ) -> ParseResult:
        """
        Parse a Word document.

        Args:
            file_path: path to the .docx file
            **kwargs: unused, accepted for interface compatibility

        Returns:
            ParseResult with full text, paragraph list and per-table data;
            success=False with an error message on any failure.
        """
        path = Path(file_path)
        if not path.exists():
            return ParseResult(
                success=False,
                error=f"文件不存在: {file_path}"
            )
        if path.suffix.lower() not in self.supported_extensions:
            return ParseResult(
                success=False,
                error=f"不支持的文件类型: {path.suffix}"
            )
        try:
            doc = Document(file_path)

            # Keep only non-empty paragraphs.
            paragraphs = [para.text for para in doc.paragraphs if para.text.strip()]

            # Each table becomes a list of rows of stripped cell texts.
            tables_data = []
            for i, table in enumerate(doc.tables):
                table_rows = [
                    [cell.text.strip() for cell in row.cells]
                    for row in table.rows
                ]
                if table_rows:
                    tables_data.append({
                        "table_index": i,
                        "rows": table_rows,
                        "row_count": len(table_rows),
                        "column_count": len(table_rows[0]) if table_rows else 0
                    })

            full_text = "\n".join(paragraphs)

            # NOTE: word_count mirrors the original behaviour — it is a
            # character count (Chinese "字数"), not a whitespace-word count.
            metadata = {
                "filename": path.name,
                "extension": path.suffix.lower(),
                "file_size": path.stat().st_size,
                "paragraph_count": len(paragraphs),
                "table_count": len(tables_data),
                "word_count": len(full_text),
                "char_count": len(full_text.replace("\n", "")),
                "has_tables": len(tables_data) > 0
            }
            return ParseResult(
                success=True,
                data={
                    "content": full_text,
                    "paragraphs": paragraphs,
                    "tables": tables_data,
                    "word_count": len(full_text),
                    "structured_data": {
                        "paragraphs": paragraphs,
                        "tables": tables_data
                    }
                },
                metadata=metadata
            )
        except Exception as e:
            logger.error(f"解析 Word 文档失败: {str(e)}")
            return ParseResult(
                success=False,
                error=f"解析 Word 文档失败: {str(e)}"
            )

    def extract_key_sentences(self, text: str, max_sentences: int = 10) -> List[str]:
        """
        Return up to max_sentences sentences from text.

        BUG FIX: the original called text.split("") with an empty separator,
        which raises ValueError; the intended separator — per the original
        comment "按句号分割" (split on full stops) — is the CJK full stop.
        """
        sentences = [s.strip() for s in text.split("。") if s.strip()]
        return sentences[:max_sentences]

    def extract_structured_fields(self, text: str) -> Dict[str, Any]:
        """
        Best-effort extraction of common labelled fields (name, phone, email,
        address, amount, date) from fixed-layout documents such as contracts
        or resumes.

        Returns:
            Dict of {field label: first matched value}.
        """
        import re
        fields = {}
        # NOTE(review): each character class may originally have contained
        # both the ASCII and full-width colon; the full-width one appears to
        # have been lost in extraction — verify against real documents.
        patterns = {
            "姓名": r"姓名[:]\s*(\S+)",
            "电话": r"电话[:]\s*(\d{11}|\d{3}-\d{8})",
            "邮箱": r"邮箱[:]\s*(\S+@\S+)",
            "地址": r"地址[:]\s*(.+?)(?:\n|$)",
            "金额": r"金额[:]\s*(\d+(?:\.\d+)?)",
            "日期": r"日期[:]\s*(\d{4}[年/-]\d{1,2}[月/-]\d{1,2})",
        }
        for field_name, pattern in patterns.items():
            match = re.search(pattern, text)
            if match:
                fields[field_name] = match.group(1)
        return fields

View File

@@ -0,0 +1,262 @@
"""
Markdown 文档解析器
"""
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
import markdown
from .base import BaseParser, ParseResult
logger = logging.getLogger(__name__)
class MarkdownParser(BaseParser):
    """Markdown (.md / .markdown) parser: HTML rendering plus structural
    extraction of titles, code blocks, tables, links and images."""

    def __init__(self):
        super().__init__()
        self.supported_extensions = ['.md', '.markdown']
        self.parser_name = "markdown_parser"

    def parse(
        self,
        file_path: str,
        **kwargs
    ) -> ParseResult:
        """
        Parse a Markdown document.

        Args:
            file_path: path to the markdown file (read as UTF-8)
            **kwargs: unused, accepted for interface compatibility

        Returns:
            ParseResult with plain text, raw/HTML content and extracted
            structure; success=False with an error message on failure.
        """
        path = Path(file_path)
        if not path.exists():
            return ParseResult(
                success=False,
                error=f"文件不存在: {file_path}"
            )
        if path.suffix.lower() not in self.supported_extensions:
            return ParseResult(
                success=False,
                error=f"不支持的文件类型: {path.suffix}"
            )
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                raw_content = f.read()

            # Render to HTML with the standard extensions.
            md = markdown.Markdown(extensions=[
                'markdown.extensions.tables',
                'markdown.extensions.fenced_code',
                'markdown.extensions.codehilite',
                'markdown.extensions.toc',
            ])
            html_content = md.convert(raw_content)

            # Structural extraction works on the raw markdown source.
            titles = self._extract_titles(raw_content)
            code_blocks = self._extract_code_blocks(raw_content)
            tables = self._extract_tables(raw_content)
            links_images = self._extract_links_images(raw_content)
            plain_text = self._strip_markdown(raw_content)

            metadata = {
                "filename": path.name,
                "extension": path.suffix.lower(),
                "file_size": path.stat().st_size,
                "word_count": len(plain_text),
                "char_count": len(raw_content),
                "line_count": len(raw_content.splitlines()),
                "title_count": len(titles),
                "code_block_count": len(code_blocks),
                "table_count": len(tables),
                "link_count": len(links_images.get("links", [])),
                "image_count": len(links_images.get("images", [])),
            }
            return ParseResult(
                success=True,
                data={
                    "content": plain_text,
                    "raw_content": raw_content,
                    "html_content": html_content,
                    "titles": titles,
                    "code_blocks": code_blocks,
                    "tables": tables,
                    "links_images": links_images,
                    "word_count": len(plain_text),
                    "structured_data": {
                        "titles": titles,
                        "code_blocks": code_blocks,
                        "tables": tables
                    }
                },
                metadata=metadata
            )
        except Exception as e:
            logger.error(f"解析 Markdown 文档失败: {str(e)}")
            return ParseResult(
                success=False,
                error=f"解析 Markdown 文档失败: {str(e)}"
            )

    def _extract_titles(self, content: str) -> List[Dict[str, Any]]:
        """Extract ATX headings (# .. ######) with level and 1-based line."""
        import re
        titles = []
        for match in re.finditer(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE):
            level = len(match.group(1))
            title_text = match.group(2).strip()
            titles.append({
                "level": level,
                "text": title_text,
                "line": content[:match.start()].count('\n') + 1
            })
        return titles

    def _extract_code_blocks(self, content: str) -> List[Dict[str, str]]:
        """Extract fenced ``` code blocks with their language tag."""
        import re
        code_blocks = []
        pattern = r'```(\w*)\n(.*?)```'
        for match in re.finditer(pattern, content, re.DOTALL):
            language = match.group(1) or "text"
            code = match.group(2).strip()
            code_blocks.append({
                "language": language,
                "code": code
            })
        return code_blocks

    def _extract_tables(self, content: str) -> List[Dict[str, Any]]:
        """
        Extract pipe-style tables (| a | b | with a |---|---| separator).

        BUG FIX: the original used a for-loop index `j` after the loop, which
        was unbound when the table had no data rows (NameError) and pointed
        one row short when the table ran to end-of-file, causing the last
        data row to be re-scanned as a new header. A while-loop keeps the
        cursor well-defined in every case.
        """
        import re
        tables = []
        lines = content.split('\n')
        i = 0
        while i < len(lines):
            line = lines[i].strip()
            if line.startswith('|') and line.endswith('|'):
                header_row = [cell.strip() for cell in line.split('|')[1:-1]]
                # The next line must be a |---|:---| separator.
                if i + 1 < len(lines) and re.match(r'^\|[\s\-:|]+\|$', lines[i + 1]):
                    data_rows = []
                    j = i + 2
                    while j < len(lines):
                        row_line = lines[j].strip()
                        if not (row_line.startswith('|') and row_line.endswith('|')):
                            break
                        data_rows.append(
                            [cell.strip() for cell in row_line.split('|')[1:-1]]
                        )
                        j += 1
                    if header_row and data_rows:
                        tables.append({
                            "headers": header_row,
                            "rows": data_rows,
                            "row_count": len(data_rows),
                            "column_count": len(header_row)
                        })
                    i = j - 1  # resume after the last consumed table line
            i += 1
        return tables

    def _extract_links_images(self, content: str) -> Dict[str, List[Dict[str, str]]]:
        """Extract [text](url) links and ![alt](url) images."""
        import re
        result = {"links": [], "images": []}
        for match in re.finditer(r'\[([^\]]+)\]\(([^\)]+)\)', content):
            result["links"].append({
                "text": match.group(1),
                "url": match.group(2)
            })
        for match in re.finditer(r'!\[([^\]]*)\]\(([^\)]+)\)', content):
            result["images"].append({
                "alt": match.group(1),
                "url": match.group(2)
            })
        return result

    def _strip_markdown(self, content: str) -> str:
        """Remove Markdown syntax and return the plain text."""
        import re
        # Code blocks and inline code are dropped entirely.
        content = re.sub(r'```[\s\S]*?```', '', content)
        content = re.sub(r'`[^`]+`', '', content)
        # Images keep their alt text, links keep their label.
        content = re.sub(r'!\[([^\]]*)\]\([^\)]+\)', r'\1', content)
        content = re.sub(r'\[([^\]]+)\]\([^\)]+\)', r'\1', content)
        # Heading markers.
        content = re.sub(r'^#{1,6}\s+', '', content, flags=re.MULTILINE)
        # Bold / italic markers.
        content = re.sub(r'\*\*([^\*]+)\*\*', r'\1', content)
        content = re.sub(r'\*([^\*]+)\*', r'\1', content)
        content = re.sub(r'__([^_]+)__', r'\1', content)
        content = re.sub(r'_([^_]+)_', r'\1', content)
        # Blockquote and list markers.
        content = re.sub(r'^>\s+', '', content, flags=re.MULTILINE)
        content = re.sub(r'^[-*+]\s+', '', content, flags=re.MULTILINE)
        content = re.sub(r'^\d+\.\s+', '', content, flags=re.MULTILINE)
        # Horizontal rules and table separators.
        content = re.sub(r'^[-*_]{3,}$', '', content, flags=re.MULTILINE)
        content = re.sub(r'^\|[\s\-:|]+\|$', '', content, flags=re.MULTILINE)
        # Collapse runs of blank lines.
        content = re.sub(r'\n{3,}', '\n\n', content)
        return content.strip()

View File

@@ -0,0 +1,278 @@
"""
纯文本 (.txt) 解析器
"""
import logging
import re
from pathlib import Path
from typing import Any, Dict, List, Optional
import chardet
from .base import BaseParser, ParseResult
logger = logging.getLogger(__name__)
class TxtParser(BaseParser):
    """Plain-text (.txt) parser with encoding detection and text cleanup."""

    def __init__(self):
        super().__init__()
        self.supported_extensions = ['.txt']
        self.parser_name = "txt_parser"

    def parse(
        self,
        file_path: str,
        encoding: Optional[str] = None,
        **kwargs
    ) -> ParseResult:
        """
        Parse a text file.

        Args:
            file_path: path to the file
            encoding: explicit encoding; auto-detected via chardet when None
            **kwargs: unused, accepted for interface compatibility

        Returns:
            ParseResult with cleaned content, raw content and line info;
            success=False with an error message on failure.
        """
        path = Path(file_path)
        if not path.exists():
            return ParseResult(
                success=False,
                error=f"文件不存在: {file_path}"
            )
        if path.suffix.lower() not in self.supported_extensions:
            return ParseResult(
                success=False,
                error=f"不支持的文件类型: {path.suffix}"
            )
        try:
            if not encoding:
                encoding = self._detect_encoding(file_path)
            with open(file_path, 'r', encoding=encoding) as f:
                raw_content = f.read()

            content = self._clean_text(raw_content)
            lines = content.split('\n')
            # Rough character-based count (whitespace and newlines excluded) —
            # appropriate for CJK text where "word" == character.
            word_count = len(content.replace('\n', '').replace(' ', ''))

            metadata = {
                "filename": path.name,
                "extension": path.suffix.lower(),
                "file_size": path.stat().st_size,
                "encoding": encoding,
                "line_count": len(lines),
                "word_count": word_count,
                "char_count": len(content),
                "non_empty_line_count": len([l for l in lines if l.strip()])
            }
            return ParseResult(
                success=True,
                data={
                    "content": content,
                    "raw_content": raw_content,
                    "lines": lines,
                    "word_count": word_count,
                    "char_count": len(content),
                    "line_count": len(lines),
                    "structured_data": {
                        "line_count": len(lines),
                        "non_empty_line_count": metadata["non_empty_line_count"]
                    }
                },
                metadata=metadata
            )
        except Exception as e:
            logger.error(f"解析文本文件失败: {str(e)}")
            return ParseResult(
                success=False,
                error=f"解析文本文件失败: {str(e)}"
            )

    def _detect_encoding(self, file_path: str) -> str:
        """
        Detect the file's encoding with chardet, falling back to utf-8.

        The detected encoding is validated by actually decoding the bytes;
        any failure falls back to utf-8.
        """
        try:
            with open(file_path, 'rb') as f:
                raw_data = f.read()
            result = chardet.detect(raw_data)
            encoding = result.get('encoding', 'utf-8')
            if encoding:
                try:
                    raw_data.decode(encoding)
                    return encoding
                except (UnicodeDecodeError, LookupError):
                    pass
            return 'utf-8'
        except Exception as e:
            logger.warning(f"编码检测失败,使用默认编码: {str(e)}")
            return 'utf-8'

    def _clean_text(self, text: str) -> str:
        """
        Clean text content: normalize newlines, drop control characters
        (except newline/tab), collapse runs of spaces and blank lines.
        """
        text = text.replace('\r\n', '\n').replace('\r', '\n')
        text = re.sub(r'[\x00-\x08\x0b-\x0c\x0e-\x1f\x7f]', '', text)
        text = re.sub(r'[ \t]+', ' ', text)
        text = re.sub(r'\n{3,}', '\n\n', text)
        return text.strip()

    def extract_structured_data(self, content: str) -> Dict[str, Any]:
        """
        Best-effort extraction of structured items from free text:
        emails, URLs, phone numbers, dates and amounts (deduplicated).
        """
        data = {
            "emails": [],
            "urls": [],
            "phones": [],
            "dates": [],
            "amounts": []
        }
        emails = re.findall(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', content)
        data["emails"] = list(set(emails))

        urls = re.findall(r'https?://[^\s<>"{}|\\^`\[\]]+', content)
        data["urls"] = list(set(urls))

        phone_patterns = [
            r'1[3-9]\d{9}',        # CN mobile numbers
            r'\d{3,4}-\d{7,8}',    # landline with area code
        ]
        phones = []
        for pattern in phone_patterns:
            phones.extend(re.findall(pattern, content))
        data["phones"] = list(set(phones))

        date_patterns = [
            r'\d{4}[-/年]\d{1,2}[-/月]\d{1,2}[日]?',
            r'\d{4}\.\d{1,2}\.\d{1,2}',
        ]
        dates = []
        for pattern in date_patterns:
            dates.extend(re.findall(pattern, content))
        data["dates"] = list(set(dates))

        amount_patterns = [
            r'¥\s*\d+(?:\.\d{1,2})?',
            r'\$\s*\d+(?:\.\d{1,2})?',
            r'\d+(?:\.\d{1,2})?\s*元',
        ]
        amounts = []
        for pattern in amount_patterns:
            amounts.extend(re.findall(pattern, content))
        data["amounts"] = list(set(amounts))
        return data

    def split_into_chunks(
        self,
        content: str,
        chunk_size: int = 1000,
        overlap: int = 100
    ) -> List[str]:
        """
        Split long text into overlapping chunks for RAG indexing / LLM use,
        preferring to cut at a sentence end (CJK full stop) or a newline in
        the second half of the window.

        BUG FIXES:
        - the original searched chunk.rfind('') (empty string), which always
          returns len(chunk); as a result the split position always "hit",
          and one character was silently dropped between every pair of
          chunks. The intended separator is the CJK full stop "。".
        - advancing by `end - overlap` could loop forever when
          overlap >= chunk_size; progress is now guaranteed.
        """
        if len(content) <= chunk_size:
            return [content]
        chunks = []
        start = 0
        while start < len(content):
            end = start + chunk_size
            chunk = content[start:end]
            if end < len(content):
                last_period = chunk.rfind('。')
                last_newline = chunk.rfind('\n')
                split_pos = max(last_period, last_newline)
                # Only cut at the boundary if it falls in the second half.
                if split_pos > chunk_size // 2:
                    chunk = chunk[:split_pos + 1]
                    end = start + split_pos + 1
            chunks.append(chunk)
            if end < len(content):
                start = max(end - overlap, start + 1)  # always make progress
            else:
                start = end
        return chunks

View File

@@ -0,0 +1,352 @@
"""
Excel 存储服务
将 Excel 数据转换为 MySQL 表结构并存储
"""
import logging
import re
from datetime import datetime
from typing import Any, Dict, List, Optional
import pandas as pd
from sqlalchemy import (
Column,
DateTime,
Float,
Integer,
String,
Text,
inspect,
)
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database.mysql import Base, mysql_db
logger = logging.getLogger(__name__)
class ExcelStorageService:
    """Stores Excel data as dynamically created MySQL tables and queries them."""

    # Bare SQL identifier: letters/digits/underscore, not starting with a digit.
    _IDENT_RE = re.compile(r'[A-Za-z_][A-Za-z0-9_]*')

    def __init__(self):
        self.mysql_db = mysql_db

    def _is_safe_identifier(self, name: str) -> bool:
        """True if name can be safely embedded in SQL as a table identifier."""
        return bool(name) and self._IDENT_RE.fullmatch(name) is not None

    def _sanitize_table_name(self, filename: str) -> str:
        """
        Turn a filename into a legal table name: strip the extension, keep
        only [A-Za-z0-9_], prefix 't_' if it starts with a digit, cap at 50.
        """
        name = filename.rsplit('.', 1)[0] if '.' in filename else filename
        name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
        if name and name[0].isdigit():
            name = 't_' + name
        return name[:50]

    def _sanitize_column_name(self, col_name: str) -> str:
        """
        Turn a column header into a legal field name: keep only
        [A-Za-z0-9_], prefix 'col_' if it starts with a digit, cap at 50.
        """
        name = re.sub(r'[^a-zA-Z0-9_]', '_', str(col_name))
        if name and name[0].isdigit():
            name = 'col_' + name
        return name[:50]

    def _infer_column_type(self, series: pd.Series) -> str:
        """Map a pandas dtype to a coarse SQL type name."""
        dtype = series.dtype
        if pd.api.types.is_integer_dtype(dtype):
            return "INTEGER"
        elif pd.api.types.is_float_dtype(dtype):
            return "FLOAT"
        elif pd.api.types.is_datetime64_any_dtype(dtype):
            return "DATETIME"
        elif pd.api.types.is_bool_dtype(dtype):
            return "BOOLEAN"
        else:
            return "TEXT"

    def _create_table_model(
        self,
        table_name: str,
        columns: List[str],
        column_types: Dict[str, str]
    ) -> type:
        """
        Dynamically build a SQLAlchemy model class for table_name with an
        auto-increment id, the sanitized data columns, and created_at /
        updated_at audit columns.
        """
        attrs = {
            '__tablename__': table_name,
            '__table_args__': {'extend_existing': True},
        }
        attrs['id'] = Column(Integer, primary_key=True, autoincrement=True)
        for col in columns:
            col_name = self._sanitize_column_name(col)
            col_type = column_types.get(col, "TEXT")
            if col_type == "INTEGER":
                attrs[col_name] = Column(Integer, nullable=True)
            elif col_type == "FLOAT":
                attrs[col_name] = Column(Float, nullable=True)
            elif col_type == "DATETIME":
                attrs[col_name] = Column(DateTime, nullable=True)
            elif col_type == "BOOLEAN":
                # Stored as INT — matches the original's choice for MySQL.
                attrs[col_name] = Column(Integer, nullable=True)
            else:
                attrs[col_name] = Column(Text, nullable=True)
        attrs['created_at'] = Column(DateTime, default=datetime.utcnow)
        attrs['updated_at'] = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
        return type(table_name, (Base,), attrs)

    async def store_excel(
        self,
        file_path: str,
        filename: str,
        sheet_name: Optional[str] = None,
        header_row: int = 0
    ) -> Dict[str, Any]:
        """
        Load an Excel file into a (newly created) MySQL table.

        Args:
            file_path: path to the Excel file
            filename: original filename (becomes the table name)
            sheet_name: worksheet to read; first sheet when None
            header_row: 0-based header row index

        Returns:
            {"success", "table_name", "row_count", "columns"} on success,
            {"success": False, "error": ...} on failure.
        """
        table_name = self._sanitize_table_name(filename)
        results = {
            "success": True,
            "table_name": table_name,
            "row_count": 0,
            "columns": []
        }
        try:
            if sheet_name:
                df = pd.read_excel(file_path, sheet_name=sheet_name, header=header_row)
            else:
                df = pd.read_excel(file_path, header=header_row)
            if df.empty:
                return {"success": False, "error": "Excel 文件为空"}

            df.columns = [str(c) for c in df.columns]

            # Infer a SQL type per column and record the mapping.
            column_types = {}
            for col in df.columns:
                col_name = self._sanitize_column_name(col)
                col_type = self._infer_column_type(df[col])
                column_types[col] = col_type
                results["columns"].append({
                    "original_name": col,
                    "sanitized_name": col_name,
                    "type": col_type
                })

            model_class = self._create_table_model(table_name, df.columns, column_types)

            # NOTE(review): __table__.create() is a synchronous call; if
            # get_session() yields an AsyncSession bound to an async engine
            # this likely needs conn.run_sync(...) — confirm against
            # mysql_db's implementation.
            async with self.mysql_db.get_session() as session:
                model_class.__table__.create(session.bind, checkfirst=True)

            # Coerce each cell to the inferred type; NaN and failed casts
            # become NULL.
            records = []
            for _, row in df.iterrows():
                record = {}
                for col in df.columns:
                    col_name = self._sanitize_column_name(col)
                    value = row[col]
                    if pd.isna(value):
                        record[col_name] = None
                    elif column_types[col] == "INTEGER":
                        try:
                            record[col_name] = int(value)
                        except (ValueError, TypeError):
                            record[col_name] = None
                    elif column_types[col] == "FLOAT":
                        try:
                            record[col_name] = float(value)
                        except (ValueError, TypeError):
                            record[col_name] = None
                    else:
                        record[col_name] = str(value)
                records.append(record)

            async with self.mysql_db.get_session() as session:
                for record in records:
                    session.add(model_class(**record))
                await session.commit()

            results["row_count"] = len(records)
            logger.info(f"Excel 数据已存储到 MySQL 表 {table_name},共 {len(records)}")
            return results
        except Exception as e:
            logger.error(f"存储 Excel 到 MySQL 失败: {str(e)}")
            return {"success": False, "error": str(e)}

    async def query_table(
        self,
        table_name: str,
        columns: Optional[List[str]] = None,
        where: Optional[str] = None,
        limit: int = 100
    ) -> List[Dict[str, Any]]:
        """
        Query rows from a stored table. Returns [] on any failure.

        Args:
            columns: currently unused — SELECT * is always issued.
            where: raw WHERE clause; see security note below.
        """
        try:
            # Reject table names that are not bare identifiers before they
            # are interpolated into SQL.
            if not self._is_safe_identifier(table_name):
                raise ValueError(f"非法表名: {table_name}")
            sql = f"SELECT * FROM `{table_name}`"
            if where:
                # SECURITY: `where` is interpolated verbatim — it must come
                # from trusted code only, never from user input.
                sql += f" WHERE {where}"
            sql += f" LIMIT {int(limit)}"
            results = await self.mysql_db.execute_query(sql)
            return results
        except Exception as e:
            logger.error(f"查询表失败: {str(e)}")
            return []

    async def get_table_schema(self, table_name: str) -> Optional[Dict[str, Any]]:
        """Return INFORMATION_SCHEMA column info for table_name, or None."""
        try:
            # Identifier validation prevents SQL injection through the
            # string-literal interpolation below.
            if not self._is_safe_identifier(table_name):
                raise ValueError(f"非法表名: {table_name}")
            sql = f"""
            SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_KEY, COLUMN_COMMENT
            FROM INFORMATION_SCHEMA.COLUMNS
            WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = '{table_name}'
            ORDER BY ORDINAL_POSITION
            """
            results = await self.mysql_db.execute_query(sql)
            return results
        except Exception as e:
            logger.error(f"获取表结构失败: {str(e)}")
            return None

    async def delete_table(self, table_name: str) -> bool:
        """Drop a stored table. Returns True on success."""
        try:
            if not self._is_safe_identifier(table_name):
                raise ValueError(f"非法表名: {table_name}")
            # Original guard against dropping system tables, kept as-is.
            # NOTE(review): any name starting with 't_' already contains '_',
            # so the second clause is redundant — confirm intent.
            if '_' not in table_name and not table_name.startswith('t_'):
                raise ValueError("不允许删除此表")
            sql = f"DROP TABLE IF EXISTS `{table_name}`"
            await self.mysql_db.execute_raw_sql(sql)
            logger.info(f"{table_name} 已删除")
            return True
        except Exception as e:
            logger.error(f"删除表失败: {str(e)}")
            return False

    async def list_tables(self) -> List[str]:
        """List all base tables in the current database. Returns [] on error."""
        try:
            sql = """
            SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES
            WHERE TABLE_SCHEMA = DATABASE() AND TABLE_TYPE = 'BASE TABLE'
            """
            results = await self.mysql_db.execute_query(sql)
            return [r['TABLE_NAME'] for r in results]
        except Exception as e:
            logger.error(f"列出表失败: {str(e)}")
            return []
# ==================== 全局单例 ====================
excel_storage_service = ExcelStorageService()

View File

@@ -0,0 +1,444 @@
"""
提示词工程服务
管理和优化与大模型交互的提示词
"""
import json
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
class PromptType(Enum):
    """Categories of prompt templates managed by the service."""
    DOCUMENT_PARSING = "document_parsing"            # parse a whole document
    FIELD_EXTRACTION = "field_extraction"            # extract a single field value
    TABLE_FILLING = "table_filling"                  # fill a table template
    QUERY_GENERATION = "query_generation"            # generate a SQL query
    TEXT_SUMMARY = "text_summary"                    # summarize long text
    INTENT_CLASSIFICATION = "intent_classification"  # classify user intent
    DATA_CLASSIFICATION = "data_classification"      # classify a data value
@dataclass
class PromptTemplate:
    """A named prompt template that renders into chat messages."""
    name: str                # display name
    type: PromptType         # which task this template serves
    system_prompt: str       # base system message
    user_template: str       # user message with {placeholder} slots
    examples: List[Dict[str, str]] = field(default_factory=list)  # few-shot examples
    rules: List[str] = field(default_factory=list)                # extra output rules

    def format(
        self,
        context: Dict[str, Any],
        user_input: Optional[str] = None
    ) -> List[Dict[str, str]]:
        """
        Render the template into a chat message list.

        The system message is the base prompt plus optional rules and
        few-shot examples; the user message is user_template with its
        {placeholders} substituted from context, plus the optional
        free-form user_input.

        Returns:
            [{"role": "system", ...}, {"role": "user", ...}]
        """
        system_content = self.system_prompt
        if self.rules:
            system_content += "\n\n【输出规则】\n" + "\n".join([f"- {rule}" for rule in self.rules])
        if self.examples:
            system_content += "\n\n【示例】\n"
            for i, ex in enumerate(self.examples):
                system_content += f"\n示例 {i+1}:\n"
                system_content += f"输入: {ex.get('input', '')}\n"
                system_content += f"输出: {ex.get('output', '')}\n"

        messages = [{"role": "system", "content": system_content}]
        messages.append({
            "role": "user",
            "content": self._format_user_template(context, user_input),
        })
        return messages

    def _format_user_template(
        self,
        context: Dict[str, Any],
        user_input: Optional[str]
    ) -> str:
        """Substitute {key} placeholders from context; dicts/lists are
        rendered as pretty-printed JSON. Appends user_input if given."""
        content = self.user_template
        for key, value in context.items():
            placeholder = f"{{{key}}}"
            if placeholder in content:
                if isinstance(value, (dict, list)):
                    content = content.replace(placeholder, json.dumps(value, ensure_ascii=False, indent=2))
                else:
                    content = content.replace(placeholder, str(value))
        if user_input:
            content += f"\n\n【用户需求】\n{user_input}"
        return content
class PromptEngineeringService:
"""提示词工程服务"""
def __init__(self):
self.templates: Dict[PromptType, PromptTemplate] = {}
self._init_templates()
def _init_templates(self):
"""初始化所有提示词模板"""
# ==================== 文档解析模板 ====================
self.templates[PromptType.DOCUMENT_PARSING] = PromptTemplate(
name="文档解析",
type=PromptType.DOCUMENT_PARSING,
system_prompt="""你是一个专业的文档解析专家。你的任务是从各类文档Word、Excel、Markdown、纯文本中提取关键信息。
请严格按照JSON格式输出解析结果
{
"success": true/false,
"document_type": "文档类型",
"key_fields": {"字段名": "字段值", ...},
"summary": "文档摘要100字内",
"structured_data": {...} // 提取的表格或其他结构化数据
}
重要规则:
- 只提取明确存在的信息,不要猜测
- 如果是表格数据,请以数组格式输出
- 日期请使用 YYYY-MM-DD 格式
- 金额请使用数字格式
- 如果无法提取某个字段,设置为 null""",
user_template="""请解析以下文档内容:
=== 文档开始 ===
{content}
=== 文档结束 ===
请提取文档中的关键信息。""",
examples=[
{
"input": "合同金额100万元\n签订日期2024年1月15日\n甲方:张三\n乙方:某某公司",
"output": '{"success": true, "document_type": "合同", "key_fields": {"金额": 1000000, "日期": "2024-01-15", "甲方": "张三", "乙方": "某某公司"}, "summary": "甲乙双方签订的金额为100万元的合同", "structured_data": null}'
}
],
rules=[
"只输出JSON不要添加任何解释",
"使用严格的JSON格式"
]
)
# ==================== 字段提取模板 ====================
self.templates[PromptType.FIELD_EXTRACTION] = PromptTemplate(
name="字段提取",
type=PromptType.FIELD_EXTRACTION,
system_prompt="""你是一个专业的数据提取专家。你的任务是从文档内容中提取指定字段的信息。
请严格按照以下JSON格式输出
{
"value": "提取到的值,找不到则为空字符串",
"source": "数据来源描述",
"confidence": 0.0到1.0之间的置信度
}
重要规则:
- 严格按字段名称匹配,不要提取无关信息
- 置信度反映你对提取结果的信心程度
- 如果字段不存在或无法确定value设为空字符串confidence设为0.0
- value必须是实际值不能是"未找到"之类的描述""",
user_template="""请从以下文档内容中提取指定字段的信息。
【需要提取的字段】
字段名称:{field_name}
字段类型:{field_type}
是否必填:{required}
【用户提示】
{hint}
【文档内容】
{context}
请提取字段值。""",
examples=[
{
"input": "文档内容姓名张三电话13800138000邮箱zhangsan@example.com",
"output": '{"value": "张三", "source": "文档第1行", "confidence": 1.0}'
}
],
rules=[
"只输出JSON不要添加任何解释"
]
)
# ==================== 表格填写模板 ====================
self.templates[PromptType.TABLE_FILLING] = PromptTemplate(
name="表格填写",
type=PromptType.TABLE_FILLING,
system_prompt="""你是一个专业的表格填写助手。你的任务是根据提供的文档内容,填写表格模板中的字段。
请严格按照以下JSON格式输出
{
"filled_data": {{"字段1": "值1", "字段2": "值2", ...}},
"fill_details": [
{{"field": "字段1", "value": "值1", "source": "来源", "confidence": 0.95}},
...
]
}
重要规则:
- 只填写模板中存在的字段
- 值必须来自提供的文档内容,不要编造
- 如果某个字段在文档中找不到对应值,设为空字符串
- fill_details 中记录每个字段的详细信息""",
user_template="""请根据以下文档内容,填写表格模板。
【表格模板字段】
{fields}
【用户需求】
{hint}
【参考文档内容】
{context}
请填写表格。""",
examples=[
{
"input": "字段:姓名、电话\n文档张三电话是13800138000",
"output": '{"filled_data": {"姓名": "张三", "电话": "13800138000"}, "fill_details": [{"field": "姓名", "value": "张三", "source": "文档第1行", "confidence": 1.0}, {"field": "电话", "value": "13800138000", "source": "文档第1行", "confidence": 1.0}]}'
}
],
rules=[
"只输出JSON不要添加任何解释"
]
)
# ==================== 查询生成模板 ====================
self.templates[PromptType.QUERY_GENERATION] = PromptTemplate(
name="查询生成",
type=PromptType.QUERY_GENERATION,
system_prompt="""你是一个SQL查询生成专家。你的任务是根据用户的自然语言需求生成相应的数据库查询语句。
请严格按照以下JSON格式输出
{
"sql_query": "生成的SQL查询语句",
"explanation": "查询逻辑说明"
}
重要规则:
- 只生成 SELECT 查询语句,不要生成 INSERT/UPDATE/DELETE
- 必须包含 WHERE 条件限制查询范围
- 表名和字段名使用反引号包裹
- 确保SQL语法正确
- 如果无法生成有效的查询sql_query设为空字符串""",
user_template="""根据以下信息生成查询语句。
【数据库表结构】
{table_schema}
【RAG检索到的上下文】
{rag_context}
【用户查询需求】
{user_intent}
请生成SQL查询。""",
examples=[
{
"input": "orders(订单号, 金额, 日期, 客户)\n需求查询2024年1月销售额超过10000的订单",
"output": '{"sql_query": "SELECT * FROM `orders` WHERE `日期` >= \\'2024-01-01\\' AND `日期` < \\'2024-02-01\\' AND `金额` > 10000", "explanation": "筛选2024年1月销售额超过10000的订单"}'
}
],
rules=[
"只输出JSON不要添加任何解释",
"禁止生成 DROP、DELETE、TRUNCATE 等危险操作"
]
)
# ==================== 文本摘要模板 ====================
self.templates[PromptType.TEXT_SUMMARY] = PromptTemplate(
name="文本摘要",
type=PromptType.TEXT_SUMMARY,
system_prompt="""你是一个专业的文本摘要专家。你的任务是对长文档进行压缩,提取关键信息。
请严格按照以下JSON格式输出
{
"summary": "摘要内容不超过200字",
"key_points": ["要点1", "要点2", "要点3"],
"keywords": ["关键词1", "关键词2", "关键词3"]
}""",
user_template="""请为以下文档生成摘要:
=== 文档开始 ===
{content}
=== 文档结束 ===
生成简明摘要。""",
rules=[
"只输出JSON不要添加任何解释"
]
)
# ==================== 意图分类模板 ====================
self.templates[PromptType.INTENT_CLASSIFICATION] = PromptTemplate(
name="意图分类",
type=PromptType.INTENT_CLASSIFICATION,
system_prompt="""你是一个意图分类专家。你的任务是分析用户的自然语言输入,判断用户的真实意图。
支持的意图类型:
- upload: 上传文档
- parse: 解析文档
- query: 查询数据
- fill: 填写表格
- export: 导出数据
- analyze: 分析数据
- other: 其他/未知
请严格按照以下JSON格式输出
{
"intent": "意图类型",
"confidence": 0.0到1.0之间的置信度,
"entities": {{"实体名": "实体值", ...}}, // 识别出的关键实体
"suggestion": "建议的下一步操作"
}""",
user_template="""请分析以下用户输入,判断其意图:
【用户输入】
{user_input}
请分类。""",
rules=[
"只输出JSON不要添加任何解释"
]
)
# ==================== 数据分类模板 ====================
self.templates[PromptType.DATA_CLASSIFICATION] = PromptTemplate(
name="数据分类",
type=PromptType.DATA_CLASSIFICATION,
system_prompt="""你是一个数据分类专家。你的任务是判断数据的类型和格式。
请严格按照以下JSON格式输出
{
"data_type": "text/number/date/email/phone/url/amount/other",
"format": "具体格式描述",
"is_valid": true/false,
"normalized_value": "规范化后的值"
}""",
user_template="""请分析以下数据的类型和格式:
【数据】
{value}
【期望类型(如果有)】
{expected_type}
请分类。""",
rules=[
"只输出JSON不要添加任何解释"
]
)
def get_prompt(
    self,
    type: PromptType,
    context: Dict[str, Any],
    user_input: Optional[str] = None
) -> List[Dict[str, str]]:
    """Render the prompt messages for a given template type.

    Args:
        type: Which prompt template to use.
        context: Context values substituted into the template.
        user_input: Optional raw user input passed to the template.

    Returns:
        A list of chat messages ({"role": ..., "content": ...}).
    """
    template = self.templates.get(type)
    if template:
        return template.format(context, user_input)
    # No template registered for this type: degrade to a bare user message.
    logger.warning(f"未找到提示词模板: {type}")
    return [{"role": "user", "content": str(context)}]
def get_template(self, type: PromptType) -> Optional[PromptTemplate]:
    """Return the registered template for *type*, or None if absent."""
    try:
        return self.templates[type]
    except KeyError:
        return None
def add_template(self, template: PromptTemplate):
    """Register a custom template, replacing any existing one of the same type."""
    self.templates.update({template.type: template})
    logger.info(f"已添加提示词模板: {template.name}")
def update_template(self, type: PromptType, **kwargs):
    """Update attributes of a registered template in place.

    Unknown template types and attribute names not present on the
    template are ignored silently.
    """
    target = self.templates.get(type)
    if not target:
        return
    for attr, new_value in kwargs.items():
        if hasattr(target, attr):
            setattr(target, attr, new_value)
def optimize_prompt(
    self,
    type: PromptType,
    feedback: str,
    iteration: int = 1
) -> List[Dict[str, str]]:
    """Refine a template based on feedback and return the re-rendered prompt.

    Args:
        type: Prompt template type to optimize.
        feedback: Free-form feedback; matched against known issue keywords.
        iteration: Optimization round counter (not used by the current strategy).

    Returns:
        The re-rendered prompt messages, or [] when the template is unknown.
    """
    template = self.templates.get(type)
    if not template:
        return []
    # Naive strategy: map feedback keywords to extra output-constraint rules.
    optimization_rules = {
        "准确率低": "提高要求,明确指出必须从原文提取,不要猜测",
        "格式错误": "强调JSON格式要求提供更详细的格式示例",
        "遗漏信息": "添加提取更多细节的要求",
    }
    extra_rules = [
        rule for keyword, rule in optimization_rules.items()
        if keyword in feedback
    ]
    if extra_rules:
        template.rules.extend(extra_rules)
    return template.format({}, None)
# ==================== Global singleton ====================
# Module-level shared instance; import `prompt_service` rather than
# constructing a new PromptEngineeringService.
prompt_service = PromptEngineeringService()

View File

@@ -0,0 +1,307 @@
"""
表格模板填写服务
从非结构化文档中检索信息并填写到表格模板
"""
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from app.core.database import mongodb
from app.services.rag_service import rag_service
from app.services.llm_service import llm_service
from app.services.excel_storage_service import excel_storage_service
logger = logging.getLogger(__name__)
@dataclass
class TemplateField:
    """A single fillable field in a table template."""
    cell: str  # target cell position, e.g. "A1"
    name: str  # field name; also used as the RAG retrieval query
    field_type: str = "text"  # coarse type: text / number / date
    required: bool = True  # whether the field must be filled
@dataclass
class FillResult:
    """Result of extracting one field value from source documents."""
    field: str  # field name the result belongs to
    value: Any  # extracted value ("" when nothing was found)
    source: str  # description of where the value came from
    confidence: float = 1.0  # extraction confidence in [0.0, 1.0]
class TemplateFillService:
    """Fill table templates with values retrieved from unstructured documents.

    Per-field pipeline: RAG retrieval for grounding context, then LLM
    extraction of the field value as strict JSON, aggregated into a report
    of filled values with per-field provenance and confidence.
    """

    def __init__(self):
        # Shared project singletons: LLM chat client and RAG retriever.
        self.llm = llm_service
        self.rag = rag_service

    async def fill_template(
        self,
        template_fields: List[TemplateField],
        source_doc_ids: Optional[List[str]] = None,
        user_hint: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Fill a table template, one field at a time.

        Args:
            template_fields: Template field definitions to fill.
            source_doc_ids: Source document IDs to restrict retrieval to.
                NOTE(review): currently ignored — retrieval always searches
                all indexed documents; confirm whether filtering is intended.
            user_hint: Optional user hint (e.g. "extract from the contract").

        Returns:
            Dict with keys "success", "filled_data" (field name -> value) and
            "fill_details" (per-field cell, value, source, confidence).
            A failing field never aborts the batch; it is recorded with an
            error placeholder and confidence 0.0.
        """
        filled_data = {}
        fill_details = []
        for field in template_fields:
            try:
                # 1. Retrieve related context snippets via RAG.
                rag_results = await self._retrieve_context(field.name, user_hint)
                if not rag_results:
                    # Nothing retrieved: record an empty zero-confidence result
                    # rather than querying the LLM without grounding context.
                    result = FillResult(
                        field=field.name,
                        value="",
                        source="未找到相关数据",
                        confidence=0.0
                    )
                else:
                    # 2. Build a prompt and let the LLM extract the value.
                    result = await self._extract_field_value(
                        field=field,
                        rag_context=rag_results,
                        user_hint=user_hint
                    )
                # 3. Record the outcome for this field.
                filled_data[field.name] = result.value
                fill_details.append({
                    "field": field.name,
                    "cell": field.cell,
                    "value": result.value,
                    "source": result.source,
                    "confidence": result.confidence
                })
                logger.info(f"字段 {field.name} 填写完成: {result.value}")
            except Exception as e:
                # Per-field failure: log, surface the error in the value, keep going.
                logger.error(f"填写字段 {field.name} 失败: {str(e)}")
                filled_data[field.name] = f"[提取失败: {str(e)}]"
                fill_details.append({
                    "field": field.name,
                    "cell": field.cell,
                    "value": f"[提取失败]",
                    "source": "error",
                    "confidence": 0.0
                })
        return {
            "success": True,
            "filled_data": filled_data,
            "fill_details": fill_details
        }

    async def _retrieve_context(
        self,
        field_name: str,
        user_hint: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """
        Retrieve context snippets for a field from the RAG index.

        Args:
            field_name: Field name, used as the base query text.
            user_hint: Optional hint prepended to the query to narrow retrieval.

        Returns:
            List of retrieval hits; each is expected to carry a "content" key
            (consumed by _extract_field_value).
        """
        # Build the query text; the user hint narrows retrieval when provided.
        query = field_name
        if user_hint:
            query = f"{user_hint} {field_name}"
        # Top-5 most relevant chunks.
        # NOTE(review): retrieve() is called without `await` inside an async
        # method — confirm rag_service.retrieve is synchronous.
        results = self.rag.retrieve(query=query, top_k=5)
        return results

    async def _extract_field_value(
        self,
        field: TemplateField,
        rag_context: List[Dict[str, Any]],
        user_hint: Optional[str] = None
    ) -> FillResult:
        """
        Extract a field value from retrieved context using the LLM.

        Args:
            field: Field definition (name / type / required flag).
            rag_context: RAG hits; each dict must provide a "content" key.
            user_hint: Optional user hint included in the prompt.

        Returns:
            FillResult. On any LLM or parsing failure an empty value with
            confidence 0.0 is returned instead of raising.
        """
        # Concatenate hits into one numbered context blob for the prompt.
        context_text = "\n\n".join([
            f"【文档 {i+1}\n{doc['content']}"
            for i, doc in enumerate(rag_context)
        ])
        # Prompt instructs the model to answer in strict JSON only.
        prompt = f"""你是一个数据提取专家。请根据以下文档内容,提取指定字段的信息。
需要提取的字段:
- 字段名称:{field.name}
- 字段类型:{field.field_type}
- 是否必填:{'是' if field.required else '否'}
{'用户提示:' + user_hint if user_hint else ''}
参考文档内容:
{context_text}
请严格按照以下 JSON 格式输出,不要添加任何解释:
{{
    "value": "提取到的值,如果没有找到则填写空字符串",
    "source": "数据来源的文档描述",
    "confidence": 0.0到1.0之间的置信度
}}
"""
        messages = [
            {"role": "system", "content": "你是一个专业的数据提取助手。请严格按JSON格式输出。"},
            {"role": "user", "content": prompt}
        ]
        try:
            # Low temperature for deterministic extraction.
            response = await self.llm.chat(
                messages=messages,
                temperature=0.1,
                max_tokens=500
            )
            content = self.llm.extract_message_content(response)
            # Parse the JSON response.
            import json
            import re
            # Grab the outermost {...} span to tolerate extra text around it.
            json_match = re.search(r'\{[\s\S]*\}', content)
            if json_match:
                result = json.loads(json_match.group())
                return FillResult(
                    field=field.name,
                    value=result.get("value", ""),
                    source=result.get("source", "LLM生成"),
                    confidence=result.get("confidence", 0.5)
                )
            else:
                # No JSON found: fall back to the raw reply at half confidence.
                return FillResult(
                    field=field.name,
                    value=content.strip(),
                    source="直接提取",
                    confidence=0.5
                )
        except Exception as e:
            logger.error(f"LLM 提取失败: {str(e)}")
            return FillResult(
                field=field.name,
                value="",
                source=f"提取失败: {str(e)}",
                confidence=0.0
            )

    async def get_template_fields_from_file(
        self,
        file_path: str,
        file_type: str = "xlsx"
    ) -> List[TemplateField]:
        """
        Extract field definitions from a template file.

        Args:
            file_path: Template file path on disk.
            file_type: File type ("xlsx" / "xls" / "docx"); any other value
                yields an empty list.

        Returns:
            Extracted fields; empty list on parse failure or unknown type.
        """
        fields = []
        try:
            if file_type in ["xlsx", "xls"]:
                # Read the Excel header row; a 5-row sample is enough to
                # infer each column's type.
                import pandas as pd
                df = pd.read_excel(file_path, nrows=5)
                for idx, col in enumerate(df.columns):
                    # Map column index to a spreadsheet letter (A, B, C, ...).
                    cell = self._column_to_cell(idx)
                    fields.append(TemplateField(
                        cell=cell,
                        name=str(col),
                        field_type=self._infer_field_type(df[col]),
                        required=True
                    ))
            elif file_type == "docx":
                # Treat every non-empty cell of every Word table as a field.
                # NOTE(review): this also turns value cells into field names;
                # confirm templates are expected to contain header cells only.
                from docx import Document
                doc = Document(file_path)
                for table_idx, table in enumerate(doc.tables):
                    for row_idx, row in enumerate(table.rows):
                        for col_idx, cell in enumerate(row.cells):
                            cell_text = cell.text.strip()
                            if cell_text:
                                fields.append(TemplateField(
                                    cell=self._column_to_cell(col_idx),
                                    name=cell_text,
                                    field_type="text",
                                    required=True
                                ))
        except Exception as e:
            logger.error(f"提取模板字段失败: {str(e)}")
        return fields

    def _column_to_cell(self, col_idx: int) -> str:
        """Convert a 0-based column index to a spreadsheet column letter
        (0 -> A, 25 -> Z, 26 -> AA, ...), i.e. bijective base-26."""
        result = ""
        while col_idx >= 0:
            result = chr(65 + (col_idx % 26)) + result
            # Shift to the next "digit"; the -1 makes the encoding bijective.
            col_idx = col_idx // 26 - 1
        return result

    def _infer_field_type(self, series) -> str:
        """Infer a coarse field type ("number" / "date" / "text") from a
        pandas Series of sample column values."""
        import pandas as pd
        if pd.api.types.is_numeric_dtype(series):
            return "number"
        elif pd.api.types.is_datetime64_any_dtype(series):
            return "date"
        else:
            return "text"
# ==================== Global singleton ====================
# Module-level shared instance; import `template_fill_service` rather than
# constructing a new TemplateFillService.
template_fill_service = TemplateFillService()