添加其他格式文档的解析
This commit is contained in:
@@ -4,10 +4,11 @@ API 路由注册模块
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from app.api.endpoints import (
|
from app.api.endpoints import (
|
||||||
upload,
|
upload,
|
||||||
documents, # 新增:文档上传
|
documents, # 多格式文档上传
|
||||||
tasks, # 新增:任务管理
|
tasks, # 任务管理
|
||||||
library, # 新增:文档库
|
library, # 文档库
|
||||||
rag, # 新增:RAG检索
|
rag, # RAG检索
|
||||||
|
templates, # 表格模板
|
||||||
ai_analyze,
|
ai_analyze,
|
||||||
visualization,
|
visualization,
|
||||||
analysis_charts,
|
analysis_charts,
|
||||||
@@ -18,12 +19,13 @@ from app.api.endpoints import (
|
|||||||
api_router = APIRouter()
|
api_router = APIRouter()
|
||||||
|
|
||||||
# 注册各模块路由
|
# 注册各模块路由
|
||||||
api_router.include_router(health.router) # 健康检查
|
api_router.include_router(health.router) # 健康检查
|
||||||
api_router.include_router(upload.router) # 原有Excel上传
|
api_router.include_router(upload.router) # 原有Excel上传
|
||||||
api_router.include_router(documents.router) # 多格式文档上传
|
api_router.include_router(documents.router) # 多格式文档上传
|
||||||
api_router.include_router(tasks.router) # 任务状态查询
|
api_router.include_router(tasks.router) # 任务状态查询
|
||||||
api_router.include_router(library.router) # 文档库管理
|
api_router.include_router(library.router) # 文档库管理
|
||||||
api_router.include_router(rag.router) # RAG检索
|
api_router.include_router(rag.router) # RAG检索
|
||||||
api_router.include_router(ai_analyze.router) # AI分析
|
api_router.include_router(templates.router) # 表格模板
|
||||||
|
api_router.include_router(ai_analyze.router) # AI分析
|
||||||
api_router.include_router(visualization.router) # 可视化
|
api_router.include_router(visualization.router) # 可视化
|
||||||
api_router.include_router(analysis_charts.router) # 分析图表
|
api_router.include_router(analysis_charts.router) # 分析图表
|
||||||
|
|||||||
228
backend/app/api/endpoints/templates.py
Normal file
228
backend/app/api/endpoints/templates.py
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
"""
|
||||||
|
表格模板 API 接口
|
||||||
|
|
||||||
|
提供模板上传、解析和填写功能
|
||||||
|
"""
|
||||||
|
import io
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, File, HTTPException, Query, UploadFile
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
import pandas as pd
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.services.template_fill_service import template_fill_service, TemplateField
|
||||||
|
from app.services.excel_storage_service import excel_storage_service
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/templates", tags=["表格模板"])
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 请求/响应模型 ====================
|
||||||
|
|
||||||
|
class TemplateFieldRequest(BaseModel):
|
||||||
|
"""模板字段请求"""
|
||||||
|
cell: str
|
||||||
|
name: str
|
||||||
|
field_type: str = "text"
|
||||||
|
required: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class FillRequest(BaseModel):
|
||||||
|
"""填写请求"""
|
||||||
|
template_id: str
|
||||||
|
template_fields: List[TemplateFieldRequest]
|
||||||
|
source_doc_ids: Optional[List[str]] = None
|
||||||
|
user_hint: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ExportRequest(BaseModel):
|
||||||
|
"""导出请求"""
|
||||||
|
template_id: str
|
||||||
|
filled_data: dict
|
||||||
|
format: str = "xlsx" # xlsx 或 docx
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 接口实现 ====================
|
||||||
|
|
||||||
|
@router.post("/upload")
|
||||||
|
async def upload_template(
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
上传表格模板文件
|
||||||
|
|
||||||
|
支持 Excel (.xlsx, .xls) 和 Word (.docx) 格式
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
模板信息,包括提取的字段列表
|
||||||
|
"""
|
||||||
|
if not file.filename:
|
||||||
|
raise HTTPException(status_code=400, detail="文件名为空")
|
||||||
|
|
||||||
|
file_ext = file.filename.split('.')[-1].lower()
|
||||||
|
if file_ext not in ['xlsx', 'xls', 'docx']:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"不支持的模板格式: {file_ext},仅支持 xlsx/xls/docx"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 保存文件
|
||||||
|
from app.services.file_service import file_service
|
||||||
|
content = await file.read()
|
||||||
|
saved_path = file_service.save_uploaded_file(
|
||||||
|
content,
|
||||||
|
file.filename,
|
||||||
|
subfolder="templates"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 提取字段
|
||||||
|
template_fields = await template_fill_service.get_template_fields_from_file(
|
||||||
|
saved_path,
|
||||||
|
file_ext
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"template_id": saved_path, # 使用文件路径作为ID
|
||||||
|
"filename": file.filename,
|
||||||
|
"file_type": file_ext,
|
||||||
|
"fields": [
|
||||||
|
{
|
||||||
|
"cell": f.cell,
|
||||||
|
"name": f.name,
|
||||||
|
"field_type": f.field_type,
|
||||||
|
"required": f.required
|
||||||
|
}
|
||||||
|
for f in template_fields
|
||||||
|
],
|
||||||
|
"field_count": len(template_fields)
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"上传模板失败: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/fields")
|
||||||
|
async def extract_template_fields(
|
||||||
|
template_id: str = Query(..., description="模板ID/文件路径"),
|
||||||
|
file_type: str = Query("xlsx", description="文件类型")
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
从已上传的模板提取字段定义
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_id: 模板ID
|
||||||
|
file_type: 文件类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
字段列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
fields = await template_fill_service.get_template_fields_from_file(
|
||||||
|
template_id,
|
||||||
|
file_type
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"fields": [
|
||||||
|
{
|
||||||
|
"cell": f.cell,
|
||||||
|
"name": f.name,
|
||||||
|
"field_type": f.field_type,
|
||||||
|
"required": f.required
|
||||||
|
}
|
||||||
|
for f in fields
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"提取字段失败: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"提取失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/fill")
|
||||||
|
async def fill_template(
|
||||||
|
request: FillRequest,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
执行表格填写
|
||||||
|
|
||||||
|
根据提供的字段定义,从已上传的文档中检索信息并填写
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 填写请求
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
填写结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 转换字段
|
||||||
|
fields = [
|
||||||
|
TemplateField(
|
||||||
|
cell=f.cell,
|
||||||
|
name=f.name,
|
||||||
|
field_type=f.field_type,
|
||||||
|
required=f.required
|
||||||
|
)
|
||||||
|
for f in request.template_fields
|
||||||
|
]
|
||||||
|
|
||||||
|
# 执行填写
|
||||||
|
result = await template_fill_service.fill_template(
|
||||||
|
template_fields=fields,
|
||||||
|
source_doc_ids=request.source_doc_ids,
|
||||||
|
user_hint=request.user_hint
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"填写表格失败: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"填写失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/export")
|
||||||
|
async def export_filled_template(
|
||||||
|
request: ExportRequest,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
导出填写后的表格
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 导出请求
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
文件流
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 创建 DataFrame
|
||||||
|
df = pd.DataFrame([request.filled_data])
|
||||||
|
|
||||||
|
# 导出为 Excel
|
||||||
|
output = io.BytesIO()
|
||||||
|
with pd.ExcelWriter(output, engine='openpyxl') as writer:
|
||||||
|
df.to_excel(writer, index=False, sheet_name='填写结果')
|
||||||
|
|
||||||
|
output.seek(0)
|
||||||
|
|
||||||
|
# 生成文件名
|
||||||
|
filename = f"filled_template.{request.format}"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
io.BytesIO(output.getvalue()),
|
||||||
|
media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||||
|
headers={"Content-Disposition": f"attachment; filename={filename}"}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"导出失败: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"导出失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 需要添加的 import ====================
|
||||||
|
import logging
|
||||||
@@ -2,26 +2,29 @@
|
|||||||
文档解析模块 - 支持多种文件格式的解析
|
文档解析模块 - 支持多种文件格式的解析
|
||||||
"""
|
"""
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional
|
from typing import Dict
|
||||||
|
|
||||||
from .base import BaseParser, ParseResult
|
from .base import BaseParser, ParseResult
|
||||||
from .xlsx_parser import XlsxParser
|
from .xlsx_parser import XlsxParser
|
||||||
|
from .docx_parser import DocxParser
|
||||||
# 导入其他解析器 (需要先实现)
|
from .md_parser import MarkdownParser
|
||||||
# from .docx_parser import DocxParser
|
from .txt_parser import TxtParser
|
||||||
# from .md_parser import MarkdownParser
|
|
||||||
# from .txt_parser import TxtParser
|
|
||||||
|
|
||||||
|
|
||||||
class ParserFactory:
|
class ParserFactory:
|
||||||
"""解析器工厂,根据文件类型返回对应解析器"""
|
"""解析器工厂,根据文件类型返回对应解析器"""
|
||||||
|
|
||||||
_parsers: Dict[str, BaseParser] = {
|
_parsers: Dict[str, BaseParser] = {
|
||||||
|
# Excel
|
||||||
'.xlsx': XlsxParser(),
|
'.xlsx': XlsxParser(),
|
||||||
'.xls': XlsxParser(),
|
'.xls': XlsxParser(),
|
||||||
# '.docx': DocxParser(), # TODO: 待实现
|
# Word
|
||||||
# '.md': MarkdownParser(), # TODO: 待实现
|
'.docx': DocxParser(),
|
||||||
# '.txt': TxtParser(), # TODO: 待实现
|
# Markdown
|
||||||
|
'.md': MarkdownParser(),
|
||||||
|
'.markdown': MarkdownParser(),
|
||||||
|
# 文本
|
||||||
|
'.txt': TxtParser(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -30,7 +33,8 @@ class ParserFactory:
|
|||||||
ext = Path(file_path).suffix.lower()
|
ext = Path(file_path).suffix.lower()
|
||||||
parser = cls._parsers.get(ext)
|
parser = cls._parsers.get(ext)
|
||||||
if not parser:
|
if not parser:
|
||||||
raise ValueError(f"不支持的文件格式: {ext},支持的格式: {list(cls._parsers.keys())}")
|
supported = list(cls._parsers.keys())
|
||||||
|
raise ValueError(f"不支持的文件格式: {ext},支持的格式: {supported}")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -44,5 +48,18 @@ class ParserFactory:
|
|||||||
"""注册新的解析器"""
|
"""注册新的解析器"""
|
||||||
cls._parsers[ext.lower()] = parser
|
cls._parsers[ext.lower()] = parser
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_supported_extensions(cls) -> list:
|
||||||
|
"""获取所有支持的扩展名"""
|
||||||
|
return list(cls._parsers.keys())
|
||||||
|
|
||||||
__all__ = ['BaseParser', 'ParseResult', 'XlsxParser', 'ParserFactory']
|
|
||||||
|
__all__ = [
|
||||||
|
'BaseParser',
|
||||||
|
'ParseResult',
|
||||||
|
'ParserFactory',
|
||||||
|
'XlsxParser',
|
||||||
|
'DocxParser',
|
||||||
|
'MarkdownParser',
|
||||||
|
'TxtParser',
|
||||||
|
]
|
||||||
|
|||||||
163
backend/app/core/document_parser/docx_parser.py
Normal file
163
backend/app/core/document_parser/docx_parser.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
"""
|
||||||
|
Word 文档 (.docx) 解析器
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from docx import Document
|
||||||
|
|
||||||
|
from .base import BaseParser, ParseResult
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DocxParser(BaseParser):
|
||||||
|
"""Word 文档解析器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.supported_extensions = ['.docx']
|
||||||
|
self.parser_name = "docx_parser"
|
||||||
|
|
||||||
|
def parse(
|
||||||
|
self,
|
||||||
|
file_path: str,
|
||||||
|
**kwargs
|
||||||
|
) -> ParseResult:
|
||||||
|
"""
|
||||||
|
解析 Word 文档
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: 文件路径
|
||||||
|
**kwargs: 其他参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ParseResult: 解析结果
|
||||||
|
"""
|
||||||
|
path = Path(file_path)
|
||||||
|
|
||||||
|
# 检查文件是否存在
|
||||||
|
if not path.exists():
|
||||||
|
return ParseResult(
|
||||||
|
success=False,
|
||||||
|
error=f"文件不存在: {file_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查文件扩展名
|
||||||
|
if path.suffix.lower() not in self.supported_extensions:
|
||||||
|
return ParseResult(
|
||||||
|
success=False,
|
||||||
|
error=f"不支持的文件类型: {path.suffix}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 读取 Word 文档
|
||||||
|
doc = Document(file_path)
|
||||||
|
|
||||||
|
# 提取文本内容
|
||||||
|
paragraphs = []
|
||||||
|
for para in doc.paragraphs:
|
||||||
|
if para.text.strip():
|
||||||
|
paragraphs.append(para.text)
|
||||||
|
|
||||||
|
# 提取表格内容
|
||||||
|
tables_data = []
|
||||||
|
for i, table in enumerate(doc.tables):
|
||||||
|
table_rows = []
|
||||||
|
for row in table.rows:
|
||||||
|
row_data = [cell.text.strip() for cell in row.cells]
|
||||||
|
table_rows.append(row_data)
|
||||||
|
|
||||||
|
if table_rows:
|
||||||
|
tables_data.append({
|
||||||
|
"table_index": i,
|
||||||
|
"rows": table_rows,
|
||||||
|
"row_count": len(table_rows),
|
||||||
|
"column_count": len(table_rows[0]) if table_rows else 0
|
||||||
|
})
|
||||||
|
|
||||||
|
# 合并所有文本
|
||||||
|
full_text = "\n".join(paragraphs)
|
||||||
|
|
||||||
|
# 构建元数据
|
||||||
|
metadata = {
|
||||||
|
"filename": path.name,
|
||||||
|
"extension": path.suffix.lower(),
|
||||||
|
"file_size": path.stat().st_size,
|
||||||
|
"paragraph_count": len(paragraphs),
|
||||||
|
"table_count": len(tables_data),
|
||||||
|
"word_count": len(full_text),
|
||||||
|
"char_count": len(full_text.replace("\n", "")),
|
||||||
|
"has_tables": len(tables_data) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
# 返回结果
|
||||||
|
return ParseResult(
|
||||||
|
success=True,
|
||||||
|
data={
|
||||||
|
"content": full_text,
|
||||||
|
"paragraphs": paragraphs,
|
||||||
|
"tables": tables_data,
|
||||||
|
"word_count": len(full_text),
|
||||||
|
"structured_data": {
|
||||||
|
"paragraphs": paragraphs,
|
||||||
|
"tables": tables_data
|
||||||
|
}
|
||||||
|
},
|
||||||
|
metadata=metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"解析 Word 文档失败: {str(e)}")
|
||||||
|
return ParseResult(
|
||||||
|
success=False,
|
||||||
|
error=f"解析 Word 文档失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def extract_key_sentences(self, text: str, max_sentences: int = 10) -> List[str]:
|
||||||
|
"""
|
||||||
|
从文本中提取关键句子
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 文本内容
|
||||||
|
max_sentences: 最大句子数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
关键句子列表
|
||||||
|
"""
|
||||||
|
# 简单实现:按句号分割,取前N个句子
|
||||||
|
sentences = [s.strip() for s in text.split("。") if s.strip()]
|
||||||
|
return sentences[:max_sentences]
|
||||||
|
|
||||||
|
def extract_structured_fields(self, text: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
尝试提取结构化字段
|
||||||
|
|
||||||
|
针对合同、简历等有固定格式的文档
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 文本内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
提取的字段字典
|
||||||
|
"""
|
||||||
|
fields = {}
|
||||||
|
|
||||||
|
# 常见字段模式
|
||||||
|
patterns = {
|
||||||
|
"姓名": r"姓名[::]\s*(\S+)",
|
||||||
|
"电话": r"电话[::]\s*(\d{11}|\d{3}-\d{8})",
|
||||||
|
"邮箱": r"邮箱[::]\s*(\S+@\S+)",
|
||||||
|
"地址": r"地址[::]\s*(.+?)(?:\n|$)",
|
||||||
|
"金额": r"金额[::]\s*(\d+(?:\.\d+)?)",
|
||||||
|
"日期": r"日期[::]\s*(\d{4}[年/-]\d{1,2}[月/-]\d{1,2})",
|
||||||
|
}
|
||||||
|
|
||||||
|
import re
|
||||||
|
for field_name, pattern in patterns.items():
|
||||||
|
match = re.search(pattern, text)
|
||||||
|
if match:
|
||||||
|
fields[field_name] = match.group(1)
|
||||||
|
|
||||||
|
return fields
|
||||||
262
backend/app/core/document_parser/md_parser.py
Normal file
262
backend/app/core/document_parser/md_parser.py
Normal file
@@ -0,0 +1,262 @@
|
|||||||
|
"""
|
||||||
|
Markdown 文档解析器
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import markdown
|
||||||
|
|
||||||
|
from .base import BaseParser, ParseResult
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MarkdownParser(BaseParser):
|
||||||
|
"""Markdown 文档解析器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.supported_extensions = ['.md', '.markdown']
|
||||||
|
self.parser_name = "markdown_parser"
|
||||||
|
|
||||||
|
def parse(
|
||||||
|
self,
|
||||||
|
file_path: str,
|
||||||
|
**kwargs
|
||||||
|
) -> ParseResult:
|
||||||
|
"""
|
||||||
|
解析 Markdown 文档
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: 文件路径
|
||||||
|
**kwargs: 其他参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ParseResult: 解析结果
|
||||||
|
"""
|
||||||
|
path = Path(file_path)
|
||||||
|
|
||||||
|
# 检查文件是否存在
|
||||||
|
if not path.exists():
|
||||||
|
return ParseResult(
|
||||||
|
success=False,
|
||||||
|
error=f"文件不存在: {file_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查文件扩展名
|
||||||
|
if path.suffix.lower() not in self.supported_extensions:
|
||||||
|
return ParseResult(
|
||||||
|
success=False,
|
||||||
|
error=f"不支持的文件类型: {path.suffix}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 读取文件内容
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
|
raw_content = f.read()
|
||||||
|
|
||||||
|
# 解析 Markdown
|
||||||
|
md = markdown.Markdown(extensions=[
|
||||||
|
'markdown.extensions.tables',
|
||||||
|
'markdown.extensions.fenced_code',
|
||||||
|
'markdown.extensions.codehilite',
|
||||||
|
'markdown.extensions.toc',
|
||||||
|
])
|
||||||
|
|
||||||
|
html_content = md.convert(raw_content)
|
||||||
|
|
||||||
|
# 提取标题结构
|
||||||
|
titles = self._extract_titles(raw_content)
|
||||||
|
|
||||||
|
# 提取代码块
|
||||||
|
code_blocks = self._extract_code_blocks(raw_content)
|
||||||
|
|
||||||
|
# 提取表格
|
||||||
|
tables = self._extract_tables(raw_content)
|
||||||
|
|
||||||
|
# 提取链接和图片
|
||||||
|
links_images = self._extract_links_images(raw_content)
|
||||||
|
|
||||||
|
# 清理后的纯文本(去除 Markdown 语法)
|
||||||
|
plain_text = self._strip_markdown(raw_content)
|
||||||
|
|
||||||
|
# 构建元数据
|
||||||
|
metadata = {
|
||||||
|
"filename": path.name,
|
||||||
|
"extension": path.suffix.lower(),
|
||||||
|
"file_size": path.stat().st_size,
|
||||||
|
"word_count": len(plain_text),
|
||||||
|
"char_count": len(raw_content),
|
||||||
|
"line_count": len(raw_content.splitlines()),
|
||||||
|
"title_count": len(titles),
|
||||||
|
"code_block_count": len(code_blocks),
|
||||||
|
"table_count": len(tables),
|
||||||
|
"link_count": len(links_images.get("links", [])),
|
||||||
|
"image_count": len(links_images.get("images", [])),
|
||||||
|
}
|
||||||
|
|
||||||
|
return ParseResult(
|
||||||
|
success=True,
|
||||||
|
data={
|
||||||
|
"content": plain_text,
|
||||||
|
"raw_content": raw_content,
|
||||||
|
"html_content": html_content,
|
||||||
|
"titles": titles,
|
||||||
|
"code_blocks": code_blocks,
|
||||||
|
"tables": tables,
|
||||||
|
"links_images": links_images,
|
||||||
|
"word_count": len(plain_text),
|
||||||
|
"structured_data": {
|
||||||
|
"titles": titles,
|
||||||
|
"code_blocks": code_blocks,
|
||||||
|
"tables": tables
|
||||||
|
}
|
||||||
|
},
|
||||||
|
metadata=metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"解析 Markdown 文档失败: {str(e)}")
|
||||||
|
return ParseResult(
|
||||||
|
success=False,
|
||||||
|
error=f"解析 Markdown 文档失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _extract_titles(self, content: str) -> List[Dict[str, Any]]:
|
||||||
|
"""提取标题结构"""
|
||||||
|
import re
|
||||||
|
titles = []
|
||||||
|
|
||||||
|
# 匹配 # 标题
|
||||||
|
for match in re.finditer(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE):
|
||||||
|
level = len(match.group(1))
|
||||||
|
title_text = match.group(2).strip()
|
||||||
|
titles.append({
|
||||||
|
"level": level,
|
||||||
|
"text": title_text,
|
||||||
|
"line": content[:match.start()].count('\n') + 1
|
||||||
|
})
|
||||||
|
|
||||||
|
return titles
|
||||||
|
|
||||||
|
def _extract_code_blocks(self, content: str) -> List[Dict[str, str]]:
|
||||||
|
"""提取代码块"""
|
||||||
|
import re
|
||||||
|
code_blocks = []
|
||||||
|
|
||||||
|
# 匹配 ```code ``` 格式
|
||||||
|
pattern = r'```(\w*)\n(.*?)```'
|
||||||
|
for match in re.finditer(pattern, content, re.DOTALL):
|
||||||
|
language = match.group(1) or "text"
|
||||||
|
code = match.group(2).strip()
|
||||||
|
code_blocks.append({
|
||||||
|
"language": language,
|
||||||
|
"code": code
|
||||||
|
})
|
||||||
|
|
||||||
|
return code_blocks
|
||||||
|
|
||||||
|
def _extract_tables(self, content: str) -> List[Dict[str, Any]]:
|
||||||
|
"""提取表格"""
|
||||||
|
import re
|
||||||
|
tables = []
|
||||||
|
|
||||||
|
# 简单表格匹配(| col1 | col2 | 格式)
|
||||||
|
lines = content.split('\n')
|
||||||
|
i = 0
|
||||||
|
while i < len(lines):
|
||||||
|
line = lines[i].strip()
|
||||||
|
|
||||||
|
# 检查是否是表格行
|
||||||
|
if line.startswith('|') and line.endswith('|'):
|
||||||
|
# 找到表头
|
||||||
|
header_row = [cell.strip() for cell in line.split('|')[1:-1]]
|
||||||
|
|
||||||
|
# 检查下一行是否是分隔符
|
||||||
|
if i + 1 < len(lines) and re.match(r'^\|[\s\-:|]+\|$', lines[i + 1]):
|
||||||
|
# 跳过分隔符,读取数据行
|
||||||
|
data_rows = []
|
||||||
|
for j in range(i + 2, len(lines)):
|
||||||
|
row_line = lines[j].strip()
|
||||||
|
if not (row_line.startswith('|') and row_line.endswith('|')):
|
||||||
|
break
|
||||||
|
row_data = [cell.strip() for cell in row_line.split('|')[1:-1]]
|
||||||
|
data_rows.append(row_data)
|
||||||
|
|
||||||
|
if header_row and data_rows:
|
||||||
|
tables.append({
|
||||||
|
"headers": header_row,
|
||||||
|
"rows": data_rows,
|
||||||
|
"row_count": len(data_rows),
|
||||||
|
"column_count": len(header_row)
|
||||||
|
})
|
||||||
|
i = j - 1
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return tables
|
||||||
|
|
||||||
|
def _extract_links_images(self, content: str) -> Dict[str, List[Dict[str, str]]]:
|
||||||
|
"""提取链接和图片"""
|
||||||
|
import re
|
||||||
|
result = {"links": [], "images": []}
|
||||||
|
|
||||||
|
# 提取链接 [text](url)
|
||||||
|
for match in re.finditer(r'\[([^\]]+)\]\(([^\)]+)\)', content):
|
||||||
|
result["links"].append({
|
||||||
|
"text": match.group(1),
|
||||||
|
"url": match.group(2)
|
||||||
|
})
|
||||||
|
|
||||||
|
# 提取图片 
|
||||||
|
for match in re.finditer(r'!\[([^\]]*)\]\(([^\)]+)\)', content):
|
||||||
|
result["images"].append({
|
||||||
|
"alt": match.group(1),
|
||||||
|
"url": match.group(2)
|
||||||
|
})
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _strip_markdown(self, content: str) -> str:
|
||||||
|
"""去除 Markdown 语法,获取纯文本"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
# 去除代码块
|
||||||
|
content = re.sub(r'```[\s\S]*?```', '', content)
|
||||||
|
|
||||||
|
# 去除行内代码
|
||||||
|
content = re.sub(r'`[^`]+`', '', content)
|
||||||
|
|
||||||
|
# 去除图片
|
||||||
|
content = re.sub(r'!\[([^\]]*)\]\([^\)]+\)', r'\1', content)
|
||||||
|
|
||||||
|
# 去除链接,保留文本
|
||||||
|
content = re.sub(r'\[([^\]]+)\]\([^\)]+\)', r'\1', content)
|
||||||
|
|
||||||
|
# 去除标题标记
|
||||||
|
content = re.sub(r'^#{1,6}\s+', '', content, flags=re.MULTILINE)
|
||||||
|
|
||||||
|
# 去除加粗和斜体
|
||||||
|
content = re.sub(r'\*\*([^\*]+)\*\*', r'\1', content)
|
||||||
|
content = re.sub(r'\*([^\*]+)\*', r'\1', content)
|
||||||
|
content = re.sub(r'__([^_]+)__', r'\1', content)
|
||||||
|
content = re.sub(r'_([^_]+)_', r'\1', content)
|
||||||
|
|
||||||
|
# 去除引用标记
|
||||||
|
content = re.sub(r'^>\s+', '', content, flags=re.MULTILINE)
|
||||||
|
|
||||||
|
# 去除列表标记
|
||||||
|
content = re.sub(r'^[-*+]\s+', '', content, flags=re.MULTILINE)
|
||||||
|
content = re.sub(r'^\d+\.\s+', '', content, flags=re.MULTILINE)
|
||||||
|
|
||||||
|
# 去除水平线
|
||||||
|
content = re.sub(r'^[-*_]{3,}$', '', content, flags=re.MULTILINE)
|
||||||
|
|
||||||
|
# 去除表格分隔符
|
||||||
|
content = re.sub(r'^\|[\s\-:|]+\|$', '', content, flags=re.MULTILINE)
|
||||||
|
|
||||||
|
# 清理多余空行
|
||||||
|
content = re.sub(r'\n{3,}', '\n\n', content)
|
||||||
|
|
||||||
|
return content.strip()
|
||||||
278
backend/app/core/document_parser/txt_parser.py
Normal file
278
backend/app/core/document_parser/txt_parser.py
Normal file
@@ -0,0 +1,278 @@
|
|||||||
|
"""
|
||||||
|
纯文本 (.txt) 解析器
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import chardet
|
||||||
|
|
||||||
|
from .base import BaseParser, ParseResult
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TxtParser(BaseParser):
|
||||||
|
"""纯文本文档解析器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.supported_extensions = ['.txt']
|
||||||
|
self.parser_name = "txt_parser"
|
||||||
|
|
||||||
|
def parse(
|
||||||
|
self,
|
||||||
|
file_path: str,
|
||||||
|
encoding: Optional[str] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> ParseResult:
|
||||||
|
"""
|
||||||
|
解析文本文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: 文件路径
|
||||||
|
encoding: 指定编码,不指定则自动检测
|
||||||
|
**kwargs: 其他参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ParseResult: 解析结果
|
||||||
|
"""
|
||||||
|
path = Path(file_path)
|
||||||
|
|
||||||
|
# 检查文件是否存在
|
||||||
|
if not path.exists():
|
||||||
|
return ParseResult(
|
||||||
|
success=False,
|
||||||
|
error=f"文件不存在: {file_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查文件扩展名
|
||||||
|
if path.suffix.lower() not in self.supported_extensions:
|
||||||
|
return ParseResult(
|
||||||
|
success=False,
|
||||||
|
error=f"不支持的文件类型: {path.suffix}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 检测编码
|
||||||
|
if not encoding:
|
||||||
|
encoding = self._detect_encoding(file_path)
|
||||||
|
|
||||||
|
# 读取文件内容
|
||||||
|
with open(file_path, 'r', encoding=encoding) as f:
|
||||||
|
raw_content = f.read()
|
||||||
|
|
||||||
|
# 清理文本
|
||||||
|
content = self._clean_text(raw_content)
|
||||||
|
|
||||||
|
# 提取行信息
|
||||||
|
lines = content.split('\n')
|
||||||
|
|
||||||
|
# 估算字数
|
||||||
|
word_count = len(content.replace('\n', '').replace(' ', ''))
|
||||||
|
|
||||||
|
# 构建元数据
|
||||||
|
metadata = {
|
||||||
|
"filename": path.name,
|
||||||
|
"extension": path.suffix.lower(),
|
||||||
|
"file_size": path.stat().st_size,
|
||||||
|
"encoding": encoding,
|
||||||
|
"line_count": len(lines),
|
||||||
|
"word_count": word_count,
|
||||||
|
"char_count": len(content),
|
||||||
|
"non_empty_line_count": len([l for l in lines if l.strip()])
|
||||||
|
}
|
||||||
|
|
||||||
|
return ParseResult(
|
||||||
|
success=True,
|
||||||
|
data={
|
||||||
|
"content": content,
|
||||||
|
"raw_content": raw_content,
|
||||||
|
"lines": lines,
|
||||||
|
"word_count": word_count,
|
||||||
|
"char_count": len(content),
|
||||||
|
"line_count": len(lines),
|
||||||
|
"structured_data": {
|
||||||
|
"line_count": len(lines),
|
||||||
|
"non_empty_line_count": metadata["non_empty_line_count"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
metadata=metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"解析文本文件失败: {str(e)}")
|
||||||
|
return ParseResult(
|
||||||
|
success=False,
|
||||||
|
error=f"解析文本文件失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _detect_encoding(self, file_path: str) -> str:
|
||||||
|
"""
|
||||||
|
自动检测文件编码
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: 文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
检测到的编码
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with open(file_path, 'rb') as f:
|
||||||
|
raw_data = f.read()
|
||||||
|
|
||||||
|
result = chardet.detect(raw_data)
|
||||||
|
encoding = result.get('encoding', 'utf-8')
|
||||||
|
|
||||||
|
# 验证编码是否有效
|
||||||
|
if encoding:
|
||||||
|
try:
|
||||||
|
raw_data.decode(encoding)
|
||||||
|
return encoding
|
||||||
|
except (UnicodeDecodeError, LookupError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return 'utf-8'
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"编码检测失败,使用默认编码: {str(e)}")
|
||||||
|
return 'utf-8'
|
||||||
|
|
||||||
|
def _clean_text(self, text: str) -> str:
|
||||||
|
"""
|
||||||
|
清理文本内容
|
||||||
|
|
||||||
|
- 去除多余空白字符
|
||||||
|
- 规范化换行符
|
||||||
|
- 去除特殊控制字符
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 原始文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
清理后的文本
|
||||||
|
"""
|
||||||
|
# 规范化换行符
|
||||||
|
text = text.replace('\r\n', '\n').replace('\r', '\n')
|
||||||
|
|
||||||
|
# 去除控制字符(除了换行和tab)
|
||||||
|
text = re.sub(r'[\x00-\x08\x0b-\x0c\x0e-\x1f\x7f]', '', text)
|
||||||
|
|
||||||
|
# 将多个连续空格合并为一个
|
||||||
|
text = re.sub(r'[ \t]+', ' ', text)
|
||||||
|
|
||||||
|
# 将多个连续空行合并为一个
|
||||||
|
text = re.sub(r'\n{3,}', '\n\n', text)
|
||||||
|
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
def extract_structured_data(self, content: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
尝试从文本中提取结构化数据
|
||||||
|
|
||||||
|
支持提取:
|
||||||
|
- 邮箱地址
|
||||||
|
- URL
|
||||||
|
- 电话号码
|
||||||
|
- 日期
|
||||||
|
- 金额
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: 文本内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
结构化数据字典
|
||||||
|
"""
|
||||||
|
data = {
|
||||||
|
"emails": [],
|
||||||
|
"urls": [],
|
||||||
|
"phones": [],
|
||||||
|
"dates": [],
|
||||||
|
"amounts": []
|
||||||
|
}
|
||||||
|
|
||||||
|
# 提取邮箱
|
||||||
|
emails = re.findall(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', content)
|
||||||
|
data["emails"] = list(set(emails))
|
||||||
|
|
||||||
|
# 提取 URL
|
||||||
|
urls = re.findall(r'https?://[^\s<>"{}|\\^`\[\]]+', content)
|
||||||
|
data["urls"] = list(set(urls))
|
||||||
|
|
||||||
|
# 提取电话号码 (支持多种格式)
|
||||||
|
phone_patterns = [
|
||||||
|
r'1[3-9]\d{9}', # 手机号
|
||||||
|
r'\d{3,4}-\d{7,8}', # 固话
|
||||||
|
]
|
||||||
|
phones = []
|
||||||
|
for pattern in phone_patterns:
|
||||||
|
phones.extend(re.findall(pattern, content))
|
||||||
|
data["phones"] = list(set(phones))
|
||||||
|
|
||||||
|
# 提取日期
|
||||||
|
date_patterns = [
|
||||||
|
r'\d{4}[-/年]\d{1,2}[-/月]\d{1,2}[日]?',
|
||||||
|
r'\d{4}\.\d{1,2}\.\d{1,2}',
|
||||||
|
]
|
||||||
|
dates = []
|
||||||
|
for pattern in date_patterns:
|
||||||
|
dates.extend(re.findall(pattern, content))
|
||||||
|
data["dates"] = list(set(dates))
|
||||||
|
|
||||||
|
# 提取金额
|
||||||
|
amount_patterns = [
|
||||||
|
r'¥\s*\d+(?:\.\d{1,2})?',
|
||||||
|
r'\$\s*\d+(?:\.\d{1,2})?',
|
||||||
|
r'\d+(?:\.\d{1,2})?\s*元',
|
||||||
|
]
|
||||||
|
amounts = []
|
||||||
|
for pattern in amount_patterns:
|
||||||
|
amounts.extend(re.findall(pattern, content))
|
||||||
|
data["amounts"] = list(set(amounts))
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def split_into_chunks(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
chunk_size: int = 1000,
|
||||||
|
overlap: int = 100
|
||||||
|
) -> List[str]:
|
||||||
|
"""
|
||||||
|
将长文本分割成块
|
||||||
|
|
||||||
|
用于 RAG 索引或 LLM 处理
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: 文本内容
|
||||||
|
chunk_size: 每块字符数
|
||||||
|
overlap: 块之间的重叠字符数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
文本块列表
|
||||||
|
"""
|
||||||
|
if len(content) <= chunk_size:
|
||||||
|
return [content]
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
start = 0
|
||||||
|
|
||||||
|
while start < len(content):
|
||||||
|
end = start + chunk_size
|
||||||
|
chunk = content[start:end]
|
||||||
|
|
||||||
|
# 尝试在句子边界分割
|
||||||
|
if end < len(content):
|
||||||
|
last_period = chunk.rfind('。')
|
||||||
|
last_newline = chunk.rfind('\n')
|
||||||
|
split_pos = max(last_period, last_newline)
|
||||||
|
|
||||||
|
if split_pos > chunk_size // 2:
|
||||||
|
chunk = chunk[:split_pos + 1]
|
||||||
|
end = start + split_pos + 1
|
||||||
|
|
||||||
|
chunks.append(chunk)
|
||||||
|
start = end - overlap if end < len(content) else end
|
||||||
|
|
||||||
|
return chunks
|
||||||
352
backend/app/services/excel_storage_service.py
Normal file
352
backend/app/services/excel_storage_service.py
Normal file
@@ -0,0 +1,352 @@
|
|||||||
|
"""
|
||||||
|
Excel 存储服务
|
||||||
|
|
||||||
|
将 Excel 数据转换为 MySQL 表结构并存储
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from sqlalchemy import (
|
||||||
|
Column,
|
||||||
|
DateTime,
|
||||||
|
Float,
|
||||||
|
Integer,
|
||||||
|
String,
|
||||||
|
Text,
|
||||||
|
inspect,
|
||||||
|
)
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.core.database.mysql import Base, mysql_db
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ExcelStorageService:
|
||||||
|
"""Excel 数据存储服务"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.mysql_db = mysql_db
|
||||||
|
|
||||||
|
def _sanitize_table_name(self, filename: str) -> str:
|
||||||
|
"""
|
||||||
|
将文件名转换为合法的表名
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: 原始文件名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
合法的表名
|
||||||
|
"""
|
||||||
|
# 移除扩展名
|
||||||
|
name = filename.rsplit('.', 1)[0] if '.' in filename else filename
|
||||||
|
|
||||||
|
# 只保留字母、数字、下划线
|
||||||
|
name = re.sub(r'[^a-zA-Z0-9_]', '_', name)
|
||||||
|
|
||||||
|
# 确保以字母开头
|
||||||
|
if name and name[0].isdigit():
|
||||||
|
name = 't_' + name
|
||||||
|
|
||||||
|
# 限制长度
|
||||||
|
return name[:50]
|
||||||
|
|
||||||
|
def _sanitize_column_name(self, col_name: str) -> str:
|
||||||
|
"""
|
||||||
|
将列名转换为合法的字段名
|
||||||
|
|
||||||
|
Args:
|
||||||
|
col_name: 原始列名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
合法的字段名
|
||||||
|
"""
|
||||||
|
# 只保留字母、数字、下划线
|
||||||
|
name = re.sub(r'[^a-zA-Z0-9_]', '_', str(col_name))
|
||||||
|
|
||||||
|
# 确保以字母开头
|
||||||
|
if name and name[0].isdigit():
|
||||||
|
name = 'col_' + name
|
||||||
|
|
||||||
|
# 限制长度
|
||||||
|
return name[:50]
|
||||||
|
|
||||||
|
def _infer_column_type(self, series: pd.Series) -> str:
|
||||||
|
"""
|
||||||
|
根据数据推断列类型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
series: pandas Series
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
类型名称
|
||||||
|
"""
|
||||||
|
dtype = series.dtype
|
||||||
|
|
||||||
|
if pd.api.types.is_integer_dtype(dtype):
|
||||||
|
return "INTEGER"
|
||||||
|
elif pd.api.types.is_float_dtype(dtype):
|
||||||
|
return "FLOAT"
|
||||||
|
elif pd.api.types.is_datetime64_any_dtype(dtype):
|
||||||
|
return "DATETIME"
|
||||||
|
elif pd.api.types.is_bool_dtype(dtype):
|
||||||
|
return "BOOLEAN"
|
||||||
|
else:
|
||||||
|
return "TEXT"
|
||||||
|
|
||||||
|
def _create_table_model(
|
||||||
|
self,
|
||||||
|
table_name: str,
|
||||||
|
columns: List[str],
|
||||||
|
column_types: Dict[str, str]
|
||||||
|
) -> type:
|
||||||
|
"""
|
||||||
|
动态创建 SQLAlchemy 模型类
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name: 表名
|
||||||
|
columns: 列名列表
|
||||||
|
column_types: 列类型字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SQLAlchemy 模型类
|
||||||
|
"""
|
||||||
|
# 创建属性字典
|
||||||
|
attrs = {
|
||||||
|
'__tablename__': table_name,
|
||||||
|
'__table_args__': {'extend_existing': True},
|
||||||
|
}
|
||||||
|
|
||||||
|
# 添加主键列
|
||||||
|
attrs['id'] = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
|
||||||
|
# 添加数据列
|
||||||
|
for col in columns:
|
||||||
|
col_name = self._sanitize_column_name(col)
|
||||||
|
col_type = column_types.get(col, "TEXT")
|
||||||
|
|
||||||
|
if col_type == "INTEGER":
|
||||||
|
attrs[col_name] = Column(Integer, nullable=True)
|
||||||
|
elif col_type == "FLOAT":
|
||||||
|
attrs[col_name] = Column(Float, nullable=True)
|
||||||
|
elif col_type == "DATETIME":
|
||||||
|
attrs[col_name] = Column(DateTime, nullable=True)
|
||||||
|
elif col_type == "BOOLEAN":
|
||||||
|
attrs[col_name] = Column(Integer, nullable=True) # MySQL 没有原生 BOOLEAN
|
||||||
|
else:
|
||||||
|
attrs[col_name] = Column(Text, nullable=True)
|
||||||
|
|
||||||
|
# 添加元数据列
|
||||||
|
attrs['created_at'] = Column(DateTime, default=datetime.utcnow)
|
||||||
|
attrs['updated_at'] = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||||
|
|
||||||
|
# 创建类
|
||||||
|
return type(table_name, (Base,), attrs)
|
||||||
|
|
||||||
|
async def store_excel(
|
||||||
|
self,
|
||||||
|
file_path: str,
|
||||||
|
filename: str,
|
||||||
|
sheet_name: Optional[str] = None,
|
||||||
|
header_row: int = 0
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
将 Excel 文件存储到 MySQL
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Excel 文件路径
|
||||||
|
filename: 原始文件名
|
||||||
|
sheet_name: 工作表名称
|
||||||
|
header_row: 表头行号
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
存储结果
|
||||||
|
"""
|
||||||
|
table_name = self._sanitize_table_name(filename)
|
||||||
|
results = {
|
||||||
|
"success": True,
|
||||||
|
"table_name": table_name,
|
||||||
|
"row_count": 0,
|
||||||
|
"columns": []
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 读取 Excel
|
||||||
|
if sheet_name:
|
||||||
|
df = pd.read_excel(file_path, sheet_name=sheet_name, header=header_row)
|
||||||
|
else:
|
||||||
|
df = pd.read_excel(file_path, header=header_row)
|
||||||
|
|
||||||
|
if df.empty:
|
||||||
|
return {"success": False, "error": "Excel 文件为空"}
|
||||||
|
|
||||||
|
# 清理列名
|
||||||
|
df.columns = [str(c) for c in df.columns]
|
||||||
|
|
||||||
|
# 推断列类型
|
||||||
|
column_types = {}
|
||||||
|
for col in df.columns:
|
||||||
|
col_name = self._sanitize_column_name(col)
|
||||||
|
col_type = self._infer_column_type(df[col])
|
||||||
|
column_types[col] = col_type
|
||||||
|
results["columns"].append({
|
||||||
|
"original_name": col,
|
||||||
|
"sanitized_name": col_name,
|
||||||
|
"type": col_type
|
||||||
|
})
|
||||||
|
|
||||||
|
# 创建表
|
||||||
|
model_class = self._create_table_model(table_name, df.columns, column_types)
|
||||||
|
|
||||||
|
# 创建表结构
|
||||||
|
async with self.mysql_db.get_session() as session:
|
||||||
|
model_class.__table__.create(session.bind, checkfirst=True)
|
||||||
|
|
||||||
|
# 插入数据
|
||||||
|
records = []
|
||||||
|
for _, row in df.iterrows():
|
||||||
|
record = {}
|
||||||
|
for col in df.columns:
|
||||||
|
col_name = self._sanitize_column_name(col)
|
||||||
|
value = row[col]
|
||||||
|
|
||||||
|
# 处理 NaN 值
|
||||||
|
if pd.isna(value):
|
||||||
|
record[col_name] = None
|
||||||
|
elif column_types[col] == "INTEGER":
|
||||||
|
try:
|
||||||
|
record[col_name] = int(value)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
record[col_name] = None
|
||||||
|
elif column_types[col] == "FLOAT":
|
||||||
|
try:
|
||||||
|
record[col_name] = float(value)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
record[col_name] = None
|
||||||
|
else:
|
||||||
|
record[col_name] = str(value)
|
||||||
|
|
||||||
|
records.append(record)
|
||||||
|
|
||||||
|
# 批量插入
|
||||||
|
async with self.mysql_db.get_session() as session:
|
||||||
|
for record in records:
|
||||||
|
session.add(model_class(**record))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
results["row_count"] = len(records)
|
||||||
|
logger.info(f"Excel 数据已存储到 MySQL 表 {table_name},共 {len(records)} 行")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"存储 Excel 到 MySQL 失败: {str(e)}")
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
|
async def query_table(
|
||||||
|
self,
|
||||||
|
table_name: str,
|
||||||
|
columns: Optional[List[str]] = None,
|
||||||
|
where: Optional[str] = None,
|
||||||
|
limit: int = 100
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
查询 MySQL 表数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name: 表名
|
||||||
|
columns: 要查询的列
|
||||||
|
where: WHERE 条件
|
||||||
|
limit: 限制返回行数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
查询结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 构建查询
|
||||||
|
sql = f"SELECT * FROM `{table_name}`"
|
||||||
|
if where:
|
||||||
|
sql += f" WHERE {where}"
|
||||||
|
sql += f" LIMIT {limit}"
|
||||||
|
|
||||||
|
results = await self.mysql_db.execute_query(sql)
|
||||||
|
return results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"查询表失败: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_table_schema(self, table_name: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
获取表结构信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name: 表名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
表结构信息
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
sql = f"""
|
||||||
|
SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_KEY, COLUMN_COMMENT
|
||||||
|
FROM INFORMATION_SCHEMA.COLUMNS
|
||||||
|
WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = '{table_name}'
|
||||||
|
ORDER BY ORDINAL_POSITION
|
||||||
|
"""
|
||||||
|
results = await self.mysql_db.execute_query(sql)
|
||||||
|
return results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取表结构失败: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def delete_table(self, table_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
删除表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name: 表名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 安全检查:表名必须包含下划线(避免删除系统表)
|
||||||
|
if '_' not in table_name and not table_name.startswith('t_'):
|
||||||
|
raise ValueError("不允许删除此表")
|
||||||
|
|
||||||
|
sql = f"DROP TABLE IF EXISTS `{table_name}`"
|
||||||
|
await self.mysql_db.execute_raw_sql(sql)
|
||||||
|
logger.info(f"表 {table_name} 已删除")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除表失败: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def list_tables(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
列出所有用户表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
表名列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
sql = """
|
||||||
|
SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES
|
||||||
|
WHERE TABLE_SCHEMA = DATABASE() AND TABLE_TYPE = 'BASE TABLE'
|
||||||
|
"""
|
||||||
|
results = await self.mysql_db.execute_query(sql)
|
||||||
|
return [r['TABLE_NAME'] for r in results]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"列出表失败: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 全局单例 ====================
|
||||||
|
|
||||||
|
excel_storage_service = ExcelStorageService()
|
||||||
444
backend/app/services/prompt_service.py
Normal file
444
backend/app/services/prompt_service.py
Normal file
@@ -0,0 +1,444 @@
|
|||||||
|
"""
|
||||||
|
提示词工程服务
|
||||||
|
|
||||||
|
管理和优化与大模型交互的提示词
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PromptType(Enum):
|
||||||
|
"""提示词类型"""
|
||||||
|
DOCUMENT_PARSING = "document_parsing" # 文档解析
|
||||||
|
FIELD_EXTRACTION = "field_extraction" # 字段提取
|
||||||
|
TABLE_FILLING = "table_filling" # 表格填写
|
||||||
|
QUERY_GENERATION = "query_generation" # 查询生成
|
||||||
|
TEXT_SUMMARY = "text_summary" # 文本摘要
|
||||||
|
INTENT_CLASSIFICATION = "intent_classification" # 意图分类
|
||||||
|
DATA_CLASSIFICATION = "data_classification" # 数据分类
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PromptTemplate:
|
||||||
|
"""提示词模板"""
|
||||||
|
name: str
|
||||||
|
type: PromptType
|
||||||
|
system_prompt: str
|
||||||
|
user_template: str
|
||||||
|
examples: List[Dict[str, str]] = field(default_factory=list) # Few-shot 示例
|
||||||
|
rules: List[str] = field(default_factory=list) # 特殊规则
|
||||||
|
|
||||||
|
def format(
|
||||||
|
self,
|
||||||
|
context: Dict[str, Any],
|
||||||
|
user_input: Optional[str] = None
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
|
"""
|
||||||
|
格式化提示词
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: 上下文数据
|
||||||
|
user_input: 用户输入
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
格式化后的消息列表
|
||||||
|
"""
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
# 系统提示词
|
||||||
|
system_content = self.system_prompt
|
||||||
|
|
||||||
|
# 添加规则
|
||||||
|
if self.rules:
|
||||||
|
system_content += "\n\n【输出规则】\n" + "\n".join([f"- {rule}" for rule in self.rules])
|
||||||
|
|
||||||
|
# 添加示例
|
||||||
|
if self.examples:
|
||||||
|
system_content += "\n\n【示例】\n"
|
||||||
|
for i, ex in enumerate(self.examples):
|
||||||
|
system_content += f"\n示例 {i+1}:\n"
|
||||||
|
system_content += f"输入: {ex.get('input', '')}\n"
|
||||||
|
system_content += f"输出: {ex.get('output', '')}\n"
|
||||||
|
|
||||||
|
messages.append({"role": "system", "content": system_content})
|
||||||
|
|
||||||
|
# 用户提示词
|
||||||
|
user_content = self._format_user_template(context, user_input)
|
||||||
|
messages.append({"role": "user", "content": user_content})
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
def _format_user_template(
|
||||||
|
self,
|
||||||
|
context: Dict[str, Any],
|
||||||
|
user_input: Optional[str]
|
||||||
|
) -> str:
|
||||||
|
"""格式化用户模板"""
|
||||||
|
content = self.user_template
|
||||||
|
|
||||||
|
# 替换上下文变量
|
||||||
|
for key, value in context.items():
|
||||||
|
placeholder = f"{{{key}}}"
|
||||||
|
if placeholder in content:
|
||||||
|
if isinstance(value, (dict, list)):
|
||||||
|
content = content.replace(placeholder, json.dumps(value, ensure_ascii=False, indent=2))
|
||||||
|
else:
|
||||||
|
content = content.replace(placeholder, str(value))
|
||||||
|
|
||||||
|
# 添加用户输入
|
||||||
|
if user_input:
|
||||||
|
content += f"\n\n【用户需求】\n{user_input}"
|
||||||
|
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
class PromptEngineeringService:
|
||||||
|
"""提示词工程服务"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.templates: Dict[PromptType, PromptTemplate] = {}
|
||||||
|
self._init_templates()
|
||||||
|
|
||||||
|
def _init_templates(self):
|
||||||
|
"""初始化所有提示词模板"""
|
||||||
|
|
||||||
|
# ==================== 文档解析模板 ====================
|
||||||
|
self.templates[PromptType.DOCUMENT_PARSING] = PromptTemplate(
|
||||||
|
name="文档解析",
|
||||||
|
type=PromptType.DOCUMENT_PARSING,
|
||||||
|
system_prompt="""你是一个专业的文档解析专家。你的任务是从各类文档(Word、Excel、Markdown、纯文本)中提取关键信息。
|
||||||
|
|
||||||
|
请严格按照JSON格式输出解析结果:
|
||||||
|
{
|
||||||
|
"success": true/false,
|
||||||
|
"document_type": "文档类型",
|
||||||
|
"key_fields": {"字段名": "字段值", ...},
|
||||||
|
"summary": "文档摘要(100字内)",
|
||||||
|
"structured_data": {...} // 提取的表格或其他结构化数据
|
||||||
|
}
|
||||||
|
|
||||||
|
重要规则:
|
||||||
|
- 只提取明确存在的信息,不要猜测
|
||||||
|
- 如果是表格数据,请以数组格式输出
|
||||||
|
- 日期请使用 YYYY-MM-DD 格式
|
||||||
|
- 金额请使用数字格式
|
||||||
|
- 如果无法提取某个字段,设置为 null""",
|
||||||
|
user_template="""请解析以下文档内容:
|
||||||
|
|
||||||
|
=== 文档开始 ===
|
||||||
|
{content}
|
||||||
|
=== 文档结束 ===
|
||||||
|
|
||||||
|
请提取文档中的关键信息。""",
|
||||||
|
examples=[
|
||||||
|
{
|
||||||
|
"input": "合同金额:100万元\n签订日期:2024年1月15日\n甲方:张三\n乙方:某某公司",
|
||||||
|
"output": '{"success": true, "document_type": "合同", "key_fields": {"金额": 1000000, "日期": "2024-01-15", "甲方": "张三", "乙方": "某某公司"}, "summary": "甲乙双方签订的金额为100万元的合同", "structured_data": null}'
|
||||||
|
}
|
||||||
|
],
|
||||||
|
rules=[
|
||||||
|
"只输出JSON,不要添加任何解释",
|
||||||
|
"使用严格的JSON格式"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==================== 字段提取模板 ====================
|
||||||
|
self.templates[PromptType.FIELD_EXTRACTION] = PromptTemplate(
|
||||||
|
name="字段提取",
|
||||||
|
type=PromptType.FIELD_EXTRACTION,
|
||||||
|
system_prompt="""你是一个专业的数据提取专家。你的任务是从文档内容中提取指定字段的信息。
|
||||||
|
|
||||||
|
请严格按照以下JSON格式输出:
|
||||||
|
{
|
||||||
|
"value": "提取到的值,找不到则为空字符串",
|
||||||
|
"source": "数据来源描述",
|
||||||
|
"confidence": 0.0到1.0之间的置信度
|
||||||
|
}
|
||||||
|
|
||||||
|
重要规则:
|
||||||
|
- 严格按字段名称匹配,不要提取无关信息
|
||||||
|
- 置信度反映你对提取结果的信心程度
|
||||||
|
- 如果字段不存在或无法确定,value设为空字符串,confidence设为0.0
|
||||||
|
- value必须是实际值,不能是"未找到"之类的描述""",
|
||||||
|
user_template="""请从以下文档内容中提取指定字段的信息。
|
||||||
|
|
||||||
|
【需要提取的字段】
|
||||||
|
字段名称:{field_name}
|
||||||
|
字段类型:{field_type}
|
||||||
|
是否必填:{required}
|
||||||
|
|
||||||
|
【用户提示】
|
||||||
|
{hint}
|
||||||
|
|
||||||
|
【文档内容】
|
||||||
|
{context}
|
||||||
|
|
||||||
|
请提取字段值。""",
|
||||||
|
examples=[
|
||||||
|
{
|
||||||
|
"input": "文档内容:姓名张三,电话13800138000,邮箱zhangsan@example.com",
|
||||||
|
"output": '{"value": "张三", "source": "文档第1行", "confidence": 1.0}'
|
||||||
|
}
|
||||||
|
],
|
||||||
|
rules=[
|
||||||
|
"只输出JSON,不要添加任何解释"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==================== 表格填写模板 ====================
|
||||||
|
self.templates[PromptType.TABLE_FILLING] = PromptTemplate(
|
||||||
|
name="表格填写",
|
||||||
|
type=PromptType.TABLE_FILLING,
|
||||||
|
system_prompt="""你是一个专业的表格填写助手。你的任务是根据提供的文档内容,填写表格模板中的字段。
|
||||||
|
|
||||||
|
请严格按照以下JSON格式输出:
|
||||||
|
{
|
||||||
|
"filled_data": {{"字段1": "值1", "字段2": "值2", ...}},
|
||||||
|
"fill_details": [
|
||||||
|
{{"field": "字段1", "value": "值1", "source": "来源", "confidence": 0.95}},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
重要规则:
|
||||||
|
- 只填写模板中存在的字段
|
||||||
|
- 值必须来自提供的文档内容,不要编造
|
||||||
|
- 如果某个字段在文档中找不到对应值,设为空字符串
|
||||||
|
- fill_details 中记录每个字段的详细信息""",
|
||||||
|
user_template="""请根据以下文档内容,填写表格模板。
|
||||||
|
|
||||||
|
【表格模板字段】
|
||||||
|
{fields}
|
||||||
|
|
||||||
|
【用户需求】
|
||||||
|
{hint}
|
||||||
|
|
||||||
|
【参考文档内容】
|
||||||
|
{context}
|
||||||
|
|
||||||
|
请填写表格。""",
|
||||||
|
examples=[
|
||||||
|
{
|
||||||
|
"input": "字段:姓名、电话\n文档:张三,电话是13800138000",
|
||||||
|
"output": '{"filled_data": {"姓名": "张三", "电话": "13800138000"}, "fill_details": [{"field": "姓名", "value": "张三", "source": "文档第1行", "confidence": 1.0}, {"field": "电话", "value": "13800138000", "source": "文档第1行", "confidence": 1.0}]}'
|
||||||
|
}
|
||||||
|
],
|
||||||
|
rules=[
|
||||||
|
"只输出JSON,不要添加任何解释"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==================== 查询生成模板 ====================
|
||||||
|
self.templates[PromptType.QUERY_GENERATION] = PromptTemplate(
|
||||||
|
name="查询生成",
|
||||||
|
type=PromptType.QUERY_GENERATION,
|
||||||
|
system_prompt="""你是一个SQL查询生成专家。你的任务是根据用户的自然语言需求,生成相应的数据库查询语句。
|
||||||
|
|
||||||
|
请严格按照以下JSON格式输出:
|
||||||
|
{
|
||||||
|
"sql_query": "生成的SQL查询语句",
|
||||||
|
"explanation": "查询逻辑说明"
|
||||||
|
}
|
||||||
|
|
||||||
|
重要规则:
|
||||||
|
- 只生成 SELECT 查询语句,不要生成 INSERT/UPDATE/DELETE
|
||||||
|
- 必须包含 WHERE 条件限制查询范围
|
||||||
|
- 表名和字段名使用反引号包裹
|
||||||
|
- 确保SQL语法正确
|
||||||
|
- 如果无法生成有效的查询,sql_query设为空字符串""",
|
||||||
|
user_template="""根据以下信息生成查询语句。
|
||||||
|
|
||||||
|
【数据库表结构】
|
||||||
|
{table_schema}
|
||||||
|
|
||||||
|
【RAG检索到的上下文】
|
||||||
|
{rag_context}
|
||||||
|
|
||||||
|
【用户查询需求】
|
||||||
|
{user_intent}
|
||||||
|
|
||||||
|
请生成SQL查询。""",
|
||||||
|
examples=[
|
||||||
|
{
|
||||||
|
"input": "表:orders(订单号, 金额, 日期, 客户)\n需求:查询2024年1月销售额超过10000的订单",
|
||||||
|
"output": '{"sql_query": "SELECT * FROM `orders` WHERE `日期` >= \\'2024-01-01\\' AND `日期` < \\'2024-02-01\\' AND `金额` > 10000", "explanation": "筛选2024年1月销售额超过10000的订单"}'
|
||||||
|
}
|
||||||
|
],
|
||||||
|
rules=[
|
||||||
|
"只输出JSON,不要添加任何解释",
|
||||||
|
"禁止生成 DROP、DELETE、TRUNCATE 等危险操作"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==================== 文本摘要模板 ====================
|
||||||
|
self.templates[PromptType.TEXT_SUMMARY] = PromptTemplate(
|
||||||
|
name="文本摘要",
|
||||||
|
type=PromptType.TEXT_SUMMARY,
|
||||||
|
system_prompt="""你是一个专业的文本摘要专家。你的任务是对长文档进行压缩,提取关键信息。
|
||||||
|
|
||||||
|
请严格按照以下JSON格式输出:
|
||||||
|
{
|
||||||
|
"summary": "摘要内容(不超过200字)",
|
||||||
|
"key_points": ["要点1", "要点2", "要点3"],
|
||||||
|
"keywords": ["关键词1", "关键词2", "关键词3"]
|
||||||
|
}""",
|
||||||
|
user_template="""请为以下文档生成摘要:
|
||||||
|
|
||||||
|
=== 文档开始 ===
|
||||||
|
{content}
|
||||||
|
=== 文档结束 ===
|
||||||
|
|
||||||
|
生成简明摘要。""",
|
||||||
|
rules=[
|
||||||
|
"只输出JSON,不要添加任何解释"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==================== 意图分类模板 ====================
|
||||||
|
self.templates[PromptType.INTENT_CLASSIFICATION] = PromptTemplate(
|
||||||
|
name="意图分类",
|
||||||
|
type=PromptType.INTENT_CLASSIFICATION,
|
||||||
|
system_prompt="""你是一个意图分类专家。你的任务是分析用户的自然语言输入,判断用户的真实意图。
|
||||||
|
|
||||||
|
支持的意图类型:
|
||||||
|
- upload: 上传文档
|
||||||
|
- parse: 解析文档
|
||||||
|
- query: 查询数据
|
||||||
|
- fill: 填写表格
|
||||||
|
- export: 导出数据
|
||||||
|
- analyze: 分析数据
|
||||||
|
- other: 其他/未知
|
||||||
|
|
||||||
|
请严格按照以下JSON格式输出:
|
||||||
|
{
|
||||||
|
"intent": "意图类型",
|
||||||
|
"confidence": 0.0到1.0之间的置信度,
|
||||||
|
"entities": {{"实体名": "实体值", ...}}, // 识别出的关键实体
|
||||||
|
"suggestion": "建议的下一步操作"
|
||||||
|
}""",
|
||||||
|
user_template="""请分析以下用户输入,判断其意图:
|
||||||
|
|
||||||
|
【用户输入】
|
||||||
|
{user_input}
|
||||||
|
|
||||||
|
请分类。""",
|
||||||
|
rules=[
|
||||||
|
"只输出JSON,不要添加任何解释"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# ==================== 数据分类模板 ====================
|
||||||
|
self.templates[PromptType.DATA_CLASSIFICATION] = PromptTemplate(
|
||||||
|
name="数据分类",
|
||||||
|
type=PromptType.DATA_CLASSIFICATION,
|
||||||
|
system_prompt="""你是一个数据分类专家。你的任务是判断数据的类型和格式。
|
||||||
|
|
||||||
|
请严格按照以下JSON格式输出:
|
||||||
|
{
|
||||||
|
"data_type": "text/number/date/email/phone/url/amount/other",
|
||||||
|
"format": "具体格式描述",
|
||||||
|
"is_valid": true/false,
|
||||||
|
"normalized_value": "规范化后的值"
|
||||||
|
}""",
|
||||||
|
user_template="""请分析以下数据的类型和格式:
|
||||||
|
|
||||||
|
【数据】
|
||||||
|
{value}
|
||||||
|
|
||||||
|
【期望类型(如果有)】
|
||||||
|
{expected_type}
|
||||||
|
|
||||||
|
请分类。""",
|
||||||
|
rules=[
|
||||||
|
"只输出JSON,不要添加任何解释"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_prompt(
|
||||||
|
self,
|
||||||
|
type: PromptType,
|
||||||
|
context: Dict[str, Any],
|
||||||
|
user_input: Optional[str] = None
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
|
"""
|
||||||
|
获取格式化后的提示词
|
||||||
|
|
||||||
|
Args:
|
||||||
|
type: 提示词类型
|
||||||
|
context: 上下文数据
|
||||||
|
user_input: 用户输入
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
消息列表
|
||||||
|
"""
|
||||||
|
template = self.templates.get(type)
|
||||||
|
if not template:
|
||||||
|
logger.warning(f"未找到提示词模板: {type}")
|
||||||
|
return [{"role": "user", "content": str(context)}]
|
||||||
|
|
||||||
|
return template.format(context, user_input)
|
||||||
|
|
||||||
|
def get_template(self, type: PromptType) -> Optional[PromptTemplate]:
|
||||||
|
"""获取提示词模板"""
|
||||||
|
return self.templates.get(type)
|
||||||
|
|
||||||
|
def add_template(self, template: PromptTemplate):
|
||||||
|
"""添加自定义提示词模板"""
|
||||||
|
self.templates[template.type] = template
|
||||||
|
logger.info(f"已添加提示词模板: {template.name}")
|
||||||
|
|
||||||
|
def update_template(self, type: PromptType, **kwargs):
|
||||||
|
"""更新提示词模板"""
|
||||||
|
template = self.templates.get(type)
|
||||||
|
if template:
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if hasattr(template, key):
|
||||||
|
setattr(template, key, value)
|
||||||
|
|
||||||
|
def optimize_prompt(
|
||||||
|
self,
|
||||||
|
type: PromptType,
|
||||||
|
feedback: str,
|
||||||
|
iteration: int = 1
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
|
"""
|
||||||
|
根据反馈优化提示词
|
||||||
|
|
||||||
|
Args:
|
||||||
|
type: 提示词类型
|
||||||
|
feedback: 优化反馈
|
||||||
|
iteration: 迭代次数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
优化后的提示词
|
||||||
|
"""
|
||||||
|
template = self.templates.get(type)
|
||||||
|
if not template:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 简单优化策略:根据反馈添加规则
|
||||||
|
optimization_rules = {
|
||||||
|
"准确率低": "提高要求,明确指出必须从原文提取,不要猜测",
|
||||||
|
"格式错误": "强调JSON格式要求,提供更详细的格式示例",
|
||||||
|
"遗漏信息": "添加提取更多细节的要求",
|
||||||
|
}
|
||||||
|
|
||||||
|
new_rules = []
|
||||||
|
for keyword, rule in optimization_rules.items():
|
||||||
|
if keyword in feedback:
|
||||||
|
new_rules.append(rule)
|
||||||
|
|
||||||
|
if new_rules:
|
||||||
|
template.rules.extend(new_rules)
|
||||||
|
|
||||||
|
return template.format({}, None)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 全局单例 ====================
|
||||||
|
|
||||||
|
prompt_service = PromptEngineeringService()
|
||||||
307
backend/app/services/template_fill_service.py
Normal file
307
backend/app/services/template_fill_service.py
Normal file
@@ -0,0 +1,307 @@
|
|||||||
|
"""
|
||||||
|
表格模板填写服务
|
||||||
|
|
||||||
|
从非结构化文档中检索信息并填写到表格模板
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from app.core.database import mongodb
|
||||||
|
from app.services.rag_service import rag_service
|
||||||
|
from app.services.llm_service import llm_service
|
||||||
|
from app.services.excel_storage_service import excel_storage_service
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TemplateField:
|
||||||
|
"""模板字段"""
|
||||||
|
cell: str # 单元格位置,如 "A1"
|
||||||
|
name: str # 字段名称
|
||||||
|
field_type: str = "text" # 字段类型: text/number/date
|
||||||
|
required: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FillResult:
|
||||||
|
"""填写结果"""
|
||||||
|
field: str
|
||||||
|
value: Any
|
||||||
|
source: str # 来源文档
|
||||||
|
confidence: float = 1.0 # 置信度
|
||||||
|
|
||||||
|
|
||||||
|
class TemplateFillService:
|
||||||
|
"""表格填写服务"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.llm = llm_service
|
||||||
|
self.rag = rag_service
|
||||||
|
|
||||||
|
async def fill_template(
|
||||||
|
self,
|
||||||
|
template_fields: List[TemplateField],
|
||||||
|
source_doc_ids: Optional[List[str]] = None,
|
||||||
|
user_hint: Optional[str] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
填写表格模板
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template_fields: 模板字段列表
|
||||||
|
source_doc_ids: 源文档ID列表,不指定则从所有文档检索
|
||||||
|
user_hint: 用户提示(如"请从合同文档中提取")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
填写结果
|
||||||
|
"""
|
||||||
|
filled_data = {}
|
||||||
|
fill_details = []
|
||||||
|
|
||||||
|
for field in template_fields:
|
||||||
|
try:
|
||||||
|
# 1. 从 RAG 检索相关上下文
|
||||||
|
rag_results = await self._retrieve_context(field.name, user_hint)
|
||||||
|
|
||||||
|
if not rag_results:
|
||||||
|
# 如果没有检索到结果,尝试直接询问 LLM
|
||||||
|
result = FillResult(
|
||||||
|
field=field.name,
|
||||||
|
value="",
|
||||||
|
source="未找到相关数据",
|
||||||
|
confidence=0.0
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 2. 构建 Prompt 让 LLM 提取信息
|
||||||
|
result = await self._extract_field_value(
|
||||||
|
field=field,
|
||||||
|
rag_context=rag_results,
|
||||||
|
user_hint=user_hint
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 存储结果
|
||||||
|
filled_data[field.name] = result.value
|
||||||
|
fill_details.append({
|
||||||
|
"field": field.name,
|
||||||
|
"cell": field.cell,
|
||||||
|
"value": result.value,
|
||||||
|
"source": result.source,
|
||||||
|
"confidence": result.confidence
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.info(f"字段 {field.name} 填写完成: {result.value}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"填写字段 {field.name} 失败: {str(e)}")
|
||||||
|
filled_data[field.name] = f"[提取失败: {str(e)}]"
|
||||||
|
fill_details.append({
|
||||||
|
"field": field.name,
|
||||||
|
"cell": field.cell,
|
||||||
|
"value": f"[提取失败]",
|
||||||
|
"source": "error",
|
||||||
|
"confidence": 0.0
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"filled_data": filled_data,
|
||||||
|
"fill_details": fill_details
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _retrieve_context(
|
||||||
|
self,
|
||||||
|
field_name: str,
|
||||||
|
user_hint: Optional[str] = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
从 RAG 检索相关上下文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field_name: 字段名称
|
||||||
|
user_hint: 用户提示
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
检索结果列表
|
||||||
|
"""
|
||||||
|
# 构建查询文本
|
||||||
|
query = field_name
|
||||||
|
if user_hint:
|
||||||
|
query = f"{user_hint} {field_name}"
|
||||||
|
|
||||||
|
# 检索相关文档片段
|
||||||
|
results = self.rag.retrieve(query=query, top_k=5)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def _extract_field_value(
|
||||||
|
self,
|
||||||
|
field: TemplateField,
|
||||||
|
rag_context: List[Dict[str, Any]],
|
||||||
|
user_hint: Optional[str] = None
|
||||||
|
) -> FillResult:
|
||||||
|
"""
|
||||||
|
使用 LLM 从上下文中提取字段值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field: 字段定义
|
||||||
|
rag_context: RAG 检索到的上下文
|
||||||
|
user_hint: 用户提示
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
提取结果
|
||||||
|
"""
|
||||||
|
# 构建上下文文本
|
||||||
|
context_text = "\n\n".join([
|
||||||
|
f"【文档 {i+1}】\n{doc['content']}"
|
||||||
|
for i, doc in enumerate(rag_context)
|
||||||
|
])
|
||||||
|
|
||||||
|
# 构建 Prompt
|
||||||
|
prompt = f"""你是一个数据提取专家。请根据以下文档内容,提取指定字段的信息。
|
||||||
|
|
||||||
|
需要提取的字段:
|
||||||
|
- 字段名称:{field.name}
|
||||||
|
- 字段类型:{field.field_type}
|
||||||
|
- 是否必填:{'是' if field.required else '否'}
|
||||||
|
|
||||||
|
{'用户提示:' + user_hint if user_hint else ''}
|
||||||
|
|
||||||
|
参考文档内容:
|
||||||
|
{context_text}
|
||||||
|
|
||||||
|
请严格按照以下 JSON 格式输出,不要添加任何解释:
|
||||||
|
{{
|
||||||
|
"value": "提取到的值,如果没有找到则填写空字符串",
|
||||||
|
"source": "数据来源的文档描述",
|
||||||
|
"confidence": 0.0到1.0之间的置信度
|
||||||
|
}}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 调用 LLM
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "你是一个专业的数据提取助手。请严格按JSON格式输出。"},
|
||||||
|
{"role": "user", "content": prompt}
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self.llm.chat(
|
||||||
|
messages=messages,
|
||||||
|
temperature=0.1,
|
||||||
|
max_tokens=500
|
||||||
|
)
|
||||||
|
|
||||||
|
content = self.llm.extract_message_content(response)
|
||||||
|
|
||||||
|
# 解析 JSON 响应
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
|
# 尝试提取 JSON
|
||||||
|
json_match = re.search(r'\{[\s\S]*\}', content)
|
||||||
|
if json_match:
|
||||||
|
result = json.loads(json_match.group())
|
||||||
|
return FillResult(
|
||||||
|
field=field.name,
|
||||||
|
value=result.get("value", ""),
|
||||||
|
source=result.get("source", "LLM生成"),
|
||||||
|
confidence=result.get("confidence", 0.5)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 如果无法解析,返回原始内容
|
||||||
|
return FillResult(
|
||||||
|
field=field.name,
|
||||||
|
value=content.strip(),
|
||||||
|
source="直接提取",
|
||||||
|
confidence=0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"LLM 提取失败: {str(e)}")
|
||||||
|
return FillResult(
|
||||||
|
field=field.name,
|
||||||
|
value="",
|
||||||
|
source=f"提取失败: {str(e)}",
|
||||||
|
confidence=0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_template_fields_from_file(
|
||||||
|
self,
|
||||||
|
file_path: str,
|
||||||
|
file_type: str = "xlsx"
|
||||||
|
) -> List[TemplateField]:
|
||||||
|
"""
|
||||||
|
从模板文件提取字段定义
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: 模板文件路径
|
||||||
|
file_type: 文件类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
字段列表
|
||||||
|
"""
|
||||||
|
fields = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
if file_type in ["xlsx", "xls"]:
|
||||||
|
# 从 Excel 读取表头
|
||||||
|
import pandas as pd
|
||||||
|
df = pd.read_excel(file_path, nrows=5)
|
||||||
|
|
||||||
|
for idx, col in enumerate(df.columns):
|
||||||
|
# 获取单元格位置 (A, B, C, ...)
|
||||||
|
cell = self._column_to_cell(idx)
|
||||||
|
|
||||||
|
fields.append(TemplateField(
|
||||||
|
cell=cell,
|
||||||
|
name=str(col),
|
||||||
|
field_type=self._infer_field_type(df[col]),
|
||||||
|
required=True
|
||||||
|
))
|
||||||
|
|
||||||
|
elif file_type == "docx":
|
||||||
|
# 从 Word 表格读取
|
||||||
|
from docx import Document
|
||||||
|
doc = Document(file_path)
|
||||||
|
|
||||||
|
for table_idx, table in enumerate(doc.tables):
|
||||||
|
for row_idx, row in enumerate(table.rows):
|
||||||
|
for col_idx, cell in enumerate(row.cells):
|
||||||
|
cell_text = cell.text.strip()
|
||||||
|
if cell_text:
|
||||||
|
fields.append(TemplateField(
|
||||||
|
cell=self._column_to_cell(col_idx),
|
||||||
|
name=cell_text,
|
||||||
|
field_type="text",
|
||||||
|
required=True
|
||||||
|
))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"提取模板字段失败: {str(e)}")
|
||||||
|
|
||||||
|
return fields
|
||||||
|
|
||||||
|
def _column_to_cell(self, col_idx: int) -> str:
|
||||||
|
"""将列索引转换为单元格列名 (0 -> A, 1 -> B, ...)"""
|
||||||
|
result = ""
|
||||||
|
while col_idx >= 0:
|
||||||
|
result = chr(65 + (col_idx % 26)) + result
|
||||||
|
col_idx = col_idx // 26 - 1
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _infer_field_type(self, series) -> str:
|
||||||
|
"""推断字段类型"""
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
if pd.api.types.is_numeric_dtype(series):
|
||||||
|
return "number"
|
||||||
|
elif pd.api.types.is_datetime64_any_dtype(series):
|
||||||
|
return "date"
|
||||||
|
else:
|
||||||
|
return "text"
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 全局单例 ====================
|
||||||
|
|
||||||
|
template_fill_service = TemplateFillService()
|
||||||
Reference in New Issue
Block a user