From 5bcad4a5fa131cc669cc3444bb58047d76f540f6 Mon Sep 17 00:00:00 2001 From: KiriAky 107 Date: Thu, 26 Mar 2026 23:14:39 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=85=B6=E4=BB=96=E6=A0=BC?= =?UTF-8?q?=E5=BC=8F=E6=96=87=E6=A1=A3=E7=9A=84=E8=A7=A3=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/api/__init__.py | 24 +- backend/app/api/endpoints/templates.py | 228 +++++++++ backend/app/core/document_parser/__init__.py | 39 +- .../app/core/document_parser/docx_parser.py | 163 +++++++ backend/app/core/document_parser/md_parser.py | 262 +++++++++++ .../app/core/document_parser/txt_parser.py | 278 +++++++++++ backend/app/services/excel_storage_service.py | 352 ++++++++++++++ backend/app/services/prompt_service.py | 444 ++++++++++++++++++ backend/app/services/template_fill_service.py | 307 ++++++++++++ 9 files changed, 2075 insertions(+), 22 deletions(-) create mode 100644 backend/app/api/endpoints/templates.py create mode 100644 backend/app/core/document_parser/docx_parser.py create mode 100644 backend/app/core/document_parser/md_parser.py create mode 100644 backend/app/core/document_parser/txt_parser.py create mode 100644 backend/app/services/excel_storage_service.py create mode 100644 backend/app/services/prompt_service.py create mode 100644 backend/app/services/template_fill_service.py diff --git a/backend/app/api/__init__.py b/backend/app/api/__init__.py index 6681dfa..c393c2a 100644 --- a/backend/app/api/__init__.py +++ b/backend/app/api/__init__.py @@ -4,10 +4,11 @@ API 路由注册模块 from fastapi import APIRouter from app.api.endpoints import ( upload, - documents, # 新增:文档上传 - tasks, # 新增:任务管理 - library, # 新增:文档库 - rag, # 新增:RAG检索 + documents, # 多格式文档上传 + tasks, # 任务管理 + library, # 文档库 + rag, # RAG检索 + templates, # 表格模板 ai_analyze, visualization, analysis_charts, @@ -18,12 +19,13 @@ from app.api.endpoints import ( api_router = APIRouter() # 注册各模块路由 -api_router.include_router(health.router) # 健康检查 -api_router.include_router(upload.router) # 原有Excel上传 +api_router.include_router(health.router) # 健康检查 +api_router.include_router(upload.router) # 原有Excel上传 api_router.include_router(documents.router) # 多格式文档上传 -api_router.include_router(tasks.router) # 任务状态查询 -api_router.include_router(library.router) # 文档库管理 -api_router.include_router(rag.router) # RAG检索 -api_router.include_router(ai_analyze.router) # AI分析 +api_router.include_router(tasks.router) # 任务状态查询 +api_router.include_router(library.router) # 文档库管理 +api_router.include_router(rag.router) # RAG检索 +api_router.include_router(templates.router) # 表格模板 +api_router.include_router(ai_analyze.router) # AI分析 api_router.include_router(visualization.router) # 可视化 -api_router.include_router(analysis_charts.router) # 分析图表 +api_router.include_router(analysis_charts.router) # 分析图表 diff --git a/backend/app/api/endpoints/templates.py b/backend/app/api/endpoints/templates.py new file mode 100644 index 0000000..2248b1c --- /dev/null +++ b/backend/app/api/endpoints/templates.py @@ -0,0 +1,228 @@ +""" +表格模板 API 接口 + +提供模板上传、解析和填写功能 +""" +import io +from typing import List, Optional + +from fastapi import APIRouter, File, HTTPException, Query, UploadFile +from fastapi.responses import StreamingResponse +import pandas as pd +from pydantic import BaseModel + +from app.services.template_fill_service import template_fill_service, TemplateField +from app.services.excel_storage_service import excel_storage_service + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/templates", tags=["表格模板"]) + + +# ==================== 请求/响应模型 ==================== + +class TemplateFieldRequest(BaseModel): + """模板字段请求""" + cell: str + name: str + field_type: str = "text" + required: bool = True + + +class FillRequest(BaseModel): + """填写请求""" + template_id: str + template_fields: List[TemplateFieldRequest] + source_doc_ids: Optional[List[str]] = None + user_hint: Optional[str] = None + + +class ExportRequest(BaseModel): + """导出请求""" + template_id: str + filled_data: dict + format: str = "xlsx" # xlsx 或 docx + + +# ==================== 接口实现 ==================== + +@router.post("/upload") +async def upload_template( + file: UploadFile = File(...), +): + """ + 上传表格模板文件 + + 支持 Excel (.xlsx, .xls) 和 Word (.docx) 格式 + + Returns: + 模板信息,包括提取的字段列表 + """ + if not file.filename: + raise HTTPException(status_code=400, detail="文件名为空") + + file_ext = file.filename.split('.')[-1].lower() + if file_ext not in ['xlsx', 'xls', 'docx']: + raise HTTPException( + status_code=400, + detail=f"不支持的模板格式: {file_ext},仅支持 xlsx/xls/docx" + ) + + try: + # 保存文件 + from app.services.file_service import file_service + content = await file.read() + saved_path = file_service.save_uploaded_file( + content, + file.filename, + subfolder="templates" + ) + + # 提取字段 + template_fields = await template_fill_service.get_template_fields_from_file( + saved_path, + file_ext + ) + + return { + "success": True, + "template_id": saved_path, # 使用文件路径作为ID + "filename": file.filename, + "file_type": file_ext, + "fields": [ + { + "cell": f.cell, + "name": f.name, + "field_type": f.field_type, + "required": f.required + } + for f in template_fields + ], + "field_count": len(template_fields) + } + + except Exception as e: + logger.error(f"上传模板失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}") + + +@router.post("/fields") +async def extract_template_fields( + template_id: str = Query(..., description="模板ID/文件路径"), + file_type: str = Query("xlsx", description="文件类型") +): + """ + 从已上传的模板提取字段定义 + + Args: + template_id: 模板ID + file_type: 文件类型 + + Returns: + 字段列表 + """ + try: + fields = await template_fill_service.get_template_fields_from_file( + template_id, + file_type + ) + + return { + "success": True, + "fields": [ + { + "cell": f.cell, + "name": f.name, + "field_type": f.field_type, + "required": f.required + } + for f in fields + ] + } + + except Exception as e: + logger.error(f"提取字段失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"提取失败: {str(e)}") + + +@router.post("/fill") +async def fill_template( + request: FillRequest, +): + """ + 执行表格填写 + + 根据提供的字段定义,从已上传的文档中检索信息并填写 + + Args: + request: 填写请求 + + Returns: + 填写结果 + """ + try: + # 转换字段 + fields = [ + TemplateField( + cell=f.cell, + name=f.name, + field_type=f.field_type, + required=f.required + ) + for f in request.template_fields + ] + + # 执行填写 + result = await template_fill_service.fill_template( + template_fields=fields, + source_doc_ids=request.source_doc_ids, + user_hint=request.user_hint + ) + + return result + + except Exception as e: + logger.error(f"填写表格失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"填写失败: {str(e)}") + + +@router.post("/export") +async def export_filled_template( + request: ExportRequest, +): + """ + 导出填写后的表格 + + Args: + request: 导出请求 + + Returns: + 文件流 + """ + try: + # 创建 DataFrame + df = pd.DataFrame([request.filled_data]) + + # 导出为 Excel + output = io.BytesIO() + with pd.ExcelWriter(output, engine='openpyxl') as writer: + df.to_excel(writer, index=False, sheet_name='填写结果') + + output.seek(0) + + # 生成文件名 + filename = f"filled_template.{request.format}" + + return StreamingResponse( + io.BytesIO(output.getvalue()), + media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + headers={"Content-Disposition": f"attachment; filename={filename}"} + ) + + except Exception as e: + logger.error(f"导出失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"导出失败: {str(e)}") + + +# ==================== 需要添加的 import ==================== +import logging diff --git a/backend/app/core/document_parser/__init__.py b/backend/app/core/document_parser/__init__.py index 88b3e04..b28c686 100644 --- a/backend/app/core/document_parser/__init__.py +++ b/backend/app/core/document_parser/__init__.py @@ -2,26 +2,29 @@ 文档解析模块 - 支持多种文件格式的解析 """ from pathlib import Path -from typing import Dict, Optional +from typing import Dict from .base import BaseParser, ParseResult from .xlsx_parser import XlsxParser - -# 导入其他解析器 (需要先实现) -# from .docx_parser import DocxParser -# from .md_parser import MarkdownParser -# from .txt_parser import TxtParser +from .docx_parser import DocxParser +from .md_parser import MarkdownParser +from .txt_parser import TxtParser class ParserFactory: """解析器工厂,根据文件类型返回对应解析器""" _parsers: Dict[str, BaseParser] = { + # Excel '.xlsx': XlsxParser(), '.xls': XlsxParser(), - # '.docx': DocxParser(), # TODO: 待实现 - # '.md': MarkdownParser(), # TODO: 待实现 - # '.txt': TxtParser(), # TODO: 待实现 + # Word + '.docx': DocxParser(), + # Markdown + '.md': MarkdownParser(), + '.markdown': MarkdownParser(), + # 文本 + '.txt': TxtParser(), } @classmethod @@ -30,7 +33,8 @@ class ParserFactory: ext = Path(file_path).suffix.lower() parser = cls._parsers.get(ext) if not parser: - raise ValueError(f"不支持的文件格式: {ext},支持的格式: {list(cls._parsers.keys())}") + supported = list(cls._parsers.keys()) + raise ValueError(f"不支持的文件格式: {ext},支持的格式: {supported}") return parser @classmethod @@ -44,5 +48,18 @@ class ParserFactory: """注册新的解析器""" cls._parsers[ext.lower()] = parser + @classmethod + def get_supported_extensions(cls) -> list: + """获取所有支持的扩展名""" + return list(cls._parsers.keys()) -__all__ = ['BaseParser', 'ParseResult', 'XlsxParser', 'ParserFactory'] + +__all__ = [ + 'BaseParser', + 'ParseResult', + 'ParserFactory', + 'XlsxParser', + 'DocxParser', + 'MarkdownParser', + 'TxtParser', +] diff --git a/backend/app/core/document_parser/docx_parser.py b/backend/app/core/document_parser/docx_parser.py new file mode 100644 index 0000000..75e79da --- /dev/null +++ b/backend/app/core/document_parser/docx_parser.py @@ -0,0 +1,163 @@ +""" +Word 文档 (.docx) 解析器 +""" +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional + +from docx import Document + +from .base import BaseParser, ParseResult + +logger = logging.getLogger(__name__) + + +class DocxParser(BaseParser): + """Word 文档解析器""" + + def __init__(self): + super().__init__() + self.supported_extensions = ['.docx'] + self.parser_name = "docx_parser" + + def parse( + self, + file_path: str, + **kwargs + ) -> ParseResult: + """ + 解析 Word 文档 + + Args: + file_path: 文件路径 + **kwargs: 其他参数 + + Returns: + ParseResult: 解析结果 + """ + path = Path(file_path) + + # 检查文件是否存在 + if not path.exists(): + return ParseResult( + success=False, + error=f"文件不存在: {file_path}" + ) + + # 检查文件扩展名 + if path.suffix.lower() not in self.supported_extensions: + return ParseResult( + success=False, + error=f"不支持的文件类型: {path.suffix}" + ) + + try: + # 读取 Word 文档 + doc = Document(file_path) + + # 提取文本内容 + paragraphs = [] + for para in doc.paragraphs: + if para.text.strip(): + paragraphs.append(para.text) + + # 提取表格内容 + tables_data = [] + for i, table in enumerate(doc.tables): + table_rows = [] + for row in table.rows: + row_data = [cell.text.strip() for cell in row.cells] + table_rows.append(row_data) + + if table_rows: + tables_data.append({ + "table_index": i, + "rows": table_rows, + "row_count": len(table_rows), + "column_count": len(table_rows[0]) if table_rows else 0 + }) + + # 合并所有文本 + full_text = "\n".join(paragraphs) + + # 构建元数据 + metadata = { + "filename": path.name, + "extension": path.suffix.lower(), + "file_size": path.stat().st_size, + "paragraph_count": len(paragraphs), + "table_count": len(tables_data), + "word_count": len(full_text), + "char_count": len(full_text.replace("\n", "")), + "has_tables": len(tables_data) > 0 + } + + # 返回结果 + return ParseResult( + success=True, + data={ + "content": full_text, + "paragraphs": paragraphs, + "tables": tables_data, + "word_count": len(full_text), + "structured_data": { + "paragraphs": paragraphs, + "tables": tables_data + } + }, + metadata=metadata + ) + + except Exception as e: + logger.error(f"解析 Word 文档失败: {str(e)}") + return ParseResult( + success=False, + error=f"解析 Word 文档失败: {str(e)}" + ) + + def extract_key_sentences(self, text: str, max_sentences: int = 10) -> List[str]: + """ + 从文本中提取关键句子 + + Args: + text: 文本内容 + max_sentences: 最大句子数 + + Returns: + 关键句子列表 + """ + # 简单实现:按句号分割,取前N个句子 + sentences = [s.strip() for s in text.split("。") if s.strip()] + return sentences[:max_sentences] + + def extract_structured_fields(self, text: str) -> Dict[str, Any]: + """ + 尝试提取结构化字段 + + 针对合同、简历等有固定格式的文档 + + Args: + text: 文本内容 + + Returns: + 提取的字段字典 + """ + fields = {} + + # 常见字段模式 + patterns = { + "姓名": r"姓名[::]\s*(\S+)", + "电话": r"电话[::]\s*(\d{11}|\d{3}-\d{8})", + "邮箱": r"邮箱[::]\s*(\S+@\S+)", + "地址": r"地址[::]\s*(.+?)(?:\n|$)", + "金额": r"金额[::]\s*(\d+(?:\.\d+)?)", + "日期": r"日期[::]\s*(\d{4}[年/-]\d{1,2}[月/-]\d{1,2})", + } + + import re + for field_name, pattern in patterns.items(): + match = re.search(pattern, text) + if match: + fields[field_name] = match.group(1) + + return fields diff --git a/backend/app/core/document_parser/md_parser.py b/backend/app/core/document_parser/md_parser.py new file mode 100644 index 0000000..fff1277 --- /dev/null +++ b/backend/app/core/document_parser/md_parser.py @@ -0,0 +1,262 @@ +""" +Markdown 文档解析器 +""" +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional + +import markdown + +from .base import BaseParser, ParseResult + +logger = logging.getLogger(__name__) + + +class MarkdownParser(BaseParser): + """Markdown 文档解析器""" + + def __init__(self): + super().__init__() + self.supported_extensions = ['.md', '.markdown'] + self.parser_name = "markdown_parser" + + def parse( + self, + file_path: str, + **kwargs + ) -> ParseResult: + """ + 解析 Markdown 文档 + + Args: + file_path: 文件路径 + **kwargs: 其他参数 + + Returns: + ParseResult: 解析结果 + """ + path = Path(file_path) + + # 检查文件是否存在 + if not path.exists(): + return ParseResult( + success=False, + error=f"文件不存在: {file_path}" + ) + + # 检查文件扩展名 + if path.suffix.lower() not in self.supported_extensions: + return ParseResult( + success=False, + error=f"不支持的文件类型: {path.suffix}" + ) + + try: + # 读取文件内容 + with open(file_path, 'r', encoding='utf-8') as f: + raw_content = f.read() + + # 解析 Markdown + md = markdown.Markdown(extensions=[ + 'markdown.extensions.tables', + 'markdown.extensions.fenced_code', + 'markdown.extensions.codehilite', + 'markdown.extensions.toc', + ]) + + html_content = md.convert(raw_content) + + # 提取标题结构 + titles = self._extract_titles(raw_content) + + # 提取代码块 + code_blocks = self._extract_code_blocks(raw_content) + + # 提取表格 + tables = self._extract_tables(raw_content) + + # 提取链接和图片 + links_images = self._extract_links_images(raw_content) + + # 清理后的纯文本(去除 Markdown 语法) + plain_text = self._strip_markdown(raw_content) + + # 构建元数据 + metadata = { + "filename": path.name, + "extension": path.suffix.lower(), + "file_size": path.stat().st_size, + "word_count": len(plain_text), + "char_count": len(raw_content), + "line_count": len(raw_content.splitlines()), + "title_count": len(titles), + "code_block_count": len(code_blocks), + "table_count": len(tables), + "link_count": len(links_images.get("links", [])), + "image_count": len(links_images.get("images", [])), + } + + return ParseResult( + success=True, + data={ + "content": plain_text, + "raw_content": raw_content, + "html_content": html_content, + "titles": titles, + "code_blocks": code_blocks, + "tables": tables, + "links_images": links_images, + "word_count": len(plain_text), + "structured_data": { + "titles": titles, + "code_blocks": code_blocks, + "tables": tables + } + }, + metadata=metadata + ) + + except Exception as e: + logger.error(f"解析 Markdown 文档失败: {str(e)}") + return ParseResult( + success=False, + error=f"解析 Markdown 文档失败: {str(e)}" + ) + + def _extract_titles(self, content: str) -> List[Dict[str, Any]]: + """提取标题结构""" + import re + titles = [] + + # 匹配 # 标题 + for match in re.finditer(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE): + level = len(match.group(1)) + title_text = match.group(2).strip() + titles.append({ + "level": level, + "text": title_text, + "line": content[:match.start()].count('\n') + 1 + }) + + return titles + + def _extract_code_blocks(self, content: str) -> List[Dict[str, str]]: + """提取代码块""" + import re + code_blocks = [] + + # 匹配 ```code ``` 格式 + pattern = r'```(\w*)\n(.*?)```' + for match in re.finditer(pattern, content, re.DOTALL): + language = match.group(1) or "text" + code = match.group(2).strip() + code_blocks.append({ + "language": language, + "code": code + }) + + return code_blocks + + def _extract_tables(self, content: str) -> List[Dict[str, Any]]: + """提取表格""" + import re + tables = [] + + # 简单表格匹配(| col1 | col2 | 格式) + lines = content.split('\n') + i = 0 + while i < len(lines): + line = lines[i].strip() + + # 检查是否是表格行 + if line.startswith('|') and line.endswith('|'): + # 找到表头 + header_row = [cell.strip() for cell in line.split('|')[1:-1]] + + # 检查下一行是否是分隔符 + if i + 1 < len(lines) and re.match(r'^\|[\s\-:|]+\|$', lines[i + 1]): + # 跳过分隔符,读取数据行 + data_rows = [] + for j in range(i + 2, len(lines)): + row_line = lines[j].strip() + if not (row_line.startswith('|') and row_line.endswith('|')): + break + row_data = [cell.strip() for cell in row_line.split('|')[1:-1]] + data_rows.append(row_data) + + if header_row and data_rows: + tables.append({ + "headers": header_row, + "rows": data_rows, + "row_count": len(data_rows), + "column_count": len(header_row) + }) + i = j - 1 + + i += 1 + + return tables + + def _extract_links_images(self, content: str) -> Dict[str, List[Dict[str, str]]]: + """提取链接和图片""" + import re + result = {"links": [], "images": []} + + # 提取链接 [text](url) + for match in re.finditer(r'\[([^\]]+)\]\(([^\)]+)\)', content): + result["links"].append({ + "text": match.group(1), + "url": match.group(2) + }) + + # 提取图片 ![alt](url) + for match in re.finditer(r'!\[([^\]]*)\]\(([^\)]+)\)', content): + result["images"].append({ + "alt": match.group(1), + "url": match.group(2) + }) + + return result + + def _strip_markdown(self, content: str) -> str: + """去除 Markdown 语法,获取纯文本""" + import re + + # 去除代码块 + content = re.sub(r'```[\s\S]*?```', '', content) + + # 去除行内代码 + content = re.sub(r'`[^`]+`', '', content) + + # 去除图片 + content = re.sub(r'!\[([^\]]*)\]\([^\)]+\)', r'\1', content) + + # 去除链接,保留文本 + content = re.sub(r'\[([^\]]+)\]\([^\)]+\)', r'\1', content) + + # 去除标题标记 + content = re.sub(r'^#{1,6}\s+', '', content, flags=re.MULTILINE) + + # 去除加粗和斜体 + content = re.sub(r'\*\*([^\*]+)\*\*', r'\1', content) + content = re.sub(r'\*([^\*]+)\*', r'\1', content) + content = re.sub(r'__([^_]+)__', r'\1', content) + content = re.sub(r'_([^_]+)_', r'\1', content) + + # 去除引用标记 + content = re.sub(r'^>\s+', '', content, flags=re.MULTILINE) + + # 去除列表标记 + content = re.sub(r'^[-*+]\s+', '', content, flags=re.MULTILINE) + content = re.sub(r'^\d+\.\s+', '', content, flags=re.MULTILINE) + + # 去除水平线 + content = re.sub(r'^[-*_]{3,}$', '', content, flags=re.MULTILINE) + + # 去除表格分隔符 + content = re.sub(r'^\|[\s\-:|]+\|$', '', content, flags=re.MULTILINE) + + # 清理多余空行 + content = re.sub(r'\n{3,}', '\n\n', content) + + return content.strip() diff --git a/backend/app/core/document_parser/txt_parser.py b/backend/app/core/document_parser/txt_parser.py new file mode 100644 index 0000000..173d8bf --- /dev/null +++ b/backend/app/core/document_parser/txt_parser.py @@ -0,0 +1,278 @@ +""" +纯文本 (.txt) 解析器 +""" +import logging +import re +from pathlib import Path +from typing import Any, Dict, List, Optional + +import chardet + +from .base import BaseParser, ParseResult + +logger = logging.getLogger(__name__) + + +class TxtParser(BaseParser): + """纯文本文档解析器""" + + def __init__(self): + super().__init__() + self.supported_extensions = ['.txt'] + self.parser_name = "txt_parser" + + def parse( + self, + file_path: str, + encoding: Optional[str] = None, + **kwargs + ) -> ParseResult: + """ + 解析文本文件 + + Args: + file_path: 文件路径 + encoding: 指定编码,不指定则自动检测 + **kwargs: 其他参数 + + Returns: + ParseResult: 解析结果 + """ + path = Path(file_path) + + # 检查文件是否存在 + if not path.exists(): + return ParseResult( + success=False, + error=f"文件不存在: {file_path}" + ) + + # 检查文件扩展名 + if path.suffix.lower() not in self.supported_extensions: + return ParseResult( + success=False, + error=f"不支持的文件类型: {path.suffix}" + ) + + try: + # 检测编码 + if not encoding: + encoding = self._detect_encoding(file_path) + + # 读取文件内容 + with open(file_path, 'r', encoding=encoding) as f: + raw_content = f.read() + + # 清理文本 + content = self._clean_text(raw_content) + + # 提取行信息 + lines = content.split('\n') + + # 估算字数 + word_count = len(content.replace('\n', '').replace(' ', '')) + + # 构建元数据 + metadata = { + "filename": path.name, + "extension": path.suffix.lower(), + "file_size": path.stat().st_size, + "encoding": encoding, + "line_count": len(lines), + "word_count": word_count, + "char_count": len(content), + "non_empty_line_count": len([l for l in lines if l.strip()]) + } + + return ParseResult( + success=True, + data={ + "content": content, + "raw_content": raw_content, + "lines": lines, + "word_count": word_count, + "char_count": len(content), + "line_count": len(lines), + "structured_data": { + "line_count": len(lines), + "non_empty_line_count": metadata["non_empty_line_count"] + } + }, + metadata=metadata + ) + + except Exception as e: + logger.error(f"解析文本文件失败: {str(e)}") + return ParseResult( + success=False, + error=f"解析文本文件失败: {str(e)}" + ) + + def _detect_encoding(self, file_path: str) -> str: + """ + 自动检测文件编码 + + Args: + file_path: 文件路径 + + Returns: + 检测到的编码 + """ + try: + with open(file_path, 'rb') as f: + raw_data = f.read() + + result = chardet.detect(raw_data) + encoding = result.get('encoding', 'utf-8') + + # 验证编码是否有效 + if encoding: + try: + raw_data.decode(encoding) + return encoding + except (UnicodeDecodeError, LookupError): + pass + + return 'utf-8' + + except Exception as e: + logger.warning(f"编码检测失败,使用默认编码: {str(e)}") + return 'utf-8' + + def _clean_text(self, text: str) -> str: + """ + 清理文本内容 + + - 去除多余空白字符 + - 规范化换行符 + - 去除特殊控制字符 + + Args: + text: 原始文本 + + Returns: + 清理后的文本 + """ + # 规范化换行符 + text = text.replace('\r\n', '\n').replace('\r', '\n') + + # 去除控制字符(除了换行和tab) + text = re.sub(r'[\x00-\x08\x0b-\x0c\x0e-\x1f\x7f]', '', text) + + # 将多个连续空格合并为一个 + text = re.sub(r'[ \t]+', ' ', text) + + # 将多个连续空行合并为一个 + text = re.sub(r'\n{3,}', '\n\n', text) + + return text.strip() + + def extract_structured_data(self, content: str) -> Dict[str, Any]: + """ + 尝试从文本中提取结构化数据 + + 支持提取: + - 邮箱地址 + - URL + - 电话号码 + - 日期 + - 金额 + + Args: + content: 文本内容 + + Returns: + 结构化数据字典 + """ + data = { + "emails": [], + "urls": [], + "phones": [], + "dates": [], + "amounts": [] + } + + # 提取邮箱 + emails = re.findall(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', content) + data["emails"] = list(set(emails)) + + # 提取 URL + urls = re.findall(r'https?://[^\s<>"{}|\\^`\[\]]+', content) + data["urls"] = list(set(urls)) + + # 提取电话号码 (支持多种格式) + phone_patterns = [ + r'1[3-9]\d{9}', # 手机号 + r'\d{3,4}-\d{7,8}', # 固话 + ] + phones = [] + for pattern in phone_patterns: + phones.extend(re.findall(pattern, content)) + data["phones"] = list(set(phones)) + + # 提取日期 + date_patterns = [ + r'\d{4}[-/年]\d{1,2}[-/月]\d{1,2}[日]?', + r'\d{4}\.\d{1,2}\.\d{1,2}', + ] + dates = [] + for pattern in date_patterns: + dates.extend(re.findall(pattern, content)) + data["dates"] = list(set(dates)) + + # 提取金额 + amount_patterns = [ + r'¥\s*\d+(?:\.\d{1,2})?', + r'\$\s*\d+(?:\.\d{1,2})?', + r'\d+(?:\.\d{1,2})?\s*元', + ] + amounts = [] + for pattern in amount_patterns: + amounts.extend(re.findall(pattern, content)) + data["amounts"] = list(set(amounts)) + + return data + + def split_into_chunks( + self, + content: str, + chunk_size: int = 1000, + overlap: int = 100 + ) -> List[str]: + """ + 将长文本分割成块 + + 用于 RAG 索引或 LLM 处理 + + Args: + content: 文本内容 + chunk_size: 每块字符数 + overlap: 块之间的重叠字符数 + + Returns: + 文本块列表 + """ + if len(content) <= chunk_size: + return [content] + + chunks = [] + start = 0 + + while start < len(content): + end = start + chunk_size + chunk = content[start:end] + + # 尝试在句子边界分割 + if end < len(content): + last_period = chunk.rfind('。') + last_newline = chunk.rfind('\n') + split_pos = max(last_period, last_newline) + + if split_pos > chunk_size // 2: + chunk = chunk[:split_pos + 1] + end = start + split_pos + 1 + + chunks.append(chunk) + start = end - overlap if end < len(content) else end + + return chunks diff --git a/backend/app/services/excel_storage_service.py b/backend/app/services/excel_storage_service.py new file mode 100644 index 0000000..5f348e1 --- /dev/null +++ b/backend/app/services/excel_storage_service.py @@ -0,0 +1,352 @@ +""" +Excel 存储服务 + +将 Excel 数据转换为 MySQL 表结构并存储 +""" +import logging +import re +from datetime import datetime +from typing import Any, Dict, List, Optional + +import pandas as pd +from sqlalchemy import ( + Column, + DateTime, + Float, + Integer, + String, + Text, + inspect, +) +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database.mysql import Base, mysql_db + +logger = logging.getLogger(__name__) + + +class ExcelStorageService: + """Excel 数据存储服务""" + + def __init__(self): + self.mysql_db = mysql_db + + def _sanitize_table_name(self, filename: str) -> str: + """ + 将文件名转换为合法的表名 + + Args: + filename: 原始文件名 + + Returns: + 合法的表名 + """ + # 移除扩展名 + name = filename.rsplit('.', 1)[0] if '.' in filename else filename + + # 只保留字母、数字、下划线 + name = re.sub(r'[^a-zA-Z0-9_]', '_', name) + + # 确保以字母开头 + if name and name[0].isdigit(): + name = 't_' + name + + # 限制长度 + return name[:50] + + def _sanitize_column_name(self, col_name: str) -> str: + """ + 将列名转换为合法的字段名 + + Args: + col_name: 原始列名 + + Returns: + 合法的字段名 + """ + # 只保留字母、数字、下划线 + name = re.sub(r'[^a-zA-Z0-9_]', '_', str(col_name)) + + # 确保以字母开头 + if name and name[0].isdigit(): + name = 'col_' + name + + # 限制长度 + return name[:50] + + def _infer_column_type(self, series: pd.Series) -> str: + """ + 根据数据推断列类型 + + Args: + series: pandas Series + + Returns: + 类型名称 + """ + dtype = series.dtype + + if pd.api.types.is_integer_dtype(dtype): + return "INTEGER" + elif pd.api.types.is_float_dtype(dtype): + return "FLOAT" + elif pd.api.types.is_datetime64_any_dtype(dtype): + return "DATETIME" + elif pd.api.types.is_bool_dtype(dtype): + return "BOOLEAN" + else: + return "TEXT" + + def _create_table_model( + self, + table_name: str, + columns: List[str], + column_types: Dict[str, str] + ) -> type: + """ + 动态创建 SQLAlchemy 模型类 + + Args: + table_name: 表名 + columns: 列名列表 + column_types: 列类型字典 + + Returns: + SQLAlchemy 模型类 + """ + # 创建属性字典 + attrs = { + '__tablename__': table_name, + '__table_args__': {'extend_existing': True}, + } + + # 添加主键列 + attrs['id'] = Column(Integer, primary_key=True, autoincrement=True) + + # 添加数据列 + for col in columns: + col_name = self._sanitize_column_name(col) + col_type = column_types.get(col, "TEXT") + + if col_type == "INTEGER": + attrs[col_name] = Column(Integer, nullable=True) + elif col_type == "FLOAT": + attrs[col_name] = Column(Float, nullable=True) + elif col_type == "DATETIME": + attrs[col_name] = Column(DateTime, nullable=True) + elif col_type == "BOOLEAN": + attrs[col_name] = Column(Integer, nullable=True) # MySQL 没有原生 BOOLEAN + else: + attrs[col_name] = Column(Text, nullable=True) + + # 添加元数据列 + attrs['created_at'] = Column(DateTime, default=datetime.utcnow) + attrs['updated_at'] = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + + # 创建类 + return type(table_name, (Base,), attrs) + + async def store_excel( + self, + file_path: str, + filename: str, + sheet_name: Optional[str] = None, + header_row: int = 0 + ) -> Dict[str, Any]: + """ + 将 Excel 文件存储到 MySQL + + Args: + file_path: Excel 文件路径 + filename: 原始文件名 + sheet_name: 工作表名称 + header_row: 表头行号 + + Returns: + 存储结果 + """ + table_name = self._sanitize_table_name(filename) + results = { + "success": True, + "table_name": table_name, + "row_count": 0, + "columns": [] + } + + try: + # 读取 Excel + if sheet_name: + df = pd.read_excel(file_path, sheet_name=sheet_name, header=header_row) + else: + df = pd.read_excel(file_path, header=header_row) + + if df.empty: + return {"success": False, "error": "Excel 文件为空"} + + # 清理列名 + df.columns = [str(c) for c in df.columns] + + # 推断列类型 + column_types = {} + for col in df.columns: + col_name = self._sanitize_column_name(col) + col_type = self._infer_column_type(df[col]) + column_types[col] = col_type + results["columns"].append({ + "original_name": col, + "sanitized_name": col_name, + "type": col_type + }) + + # 创建表 + model_class = self._create_table_model(table_name, df.columns, column_types) + + # 创建表结构 + async with self.mysql_db.get_session() as session: + model_class.__table__.create(session.bind, checkfirst=True) + + # 插入数据 + records = [] + for _, row in df.iterrows(): + record = {} + for col in df.columns: + col_name = self._sanitize_column_name(col) + value = row[col] + + # 处理 NaN 值 + if pd.isna(value): + record[col_name] = None + elif column_types[col] == "INTEGER": + try: + record[col_name] = int(value) + except (ValueError, TypeError): + record[col_name] = None + elif column_types[col] == "FLOAT": + try: + record[col_name] = float(value) + except (ValueError, TypeError): + record[col_name] = None + else: + record[col_name] = str(value) + + records.append(record) + + # 批量插入 + async with self.mysql_db.get_session() as session: + for record in records: + session.add(model_class(**record)) + await session.commit() + + results["row_count"] = len(records) + logger.info(f"Excel 数据已存储到 MySQL 表 {table_name},共 {len(records)} 行") + + return results + + except Exception as e: + logger.error(f"存储 Excel 到 MySQL 失败: {str(e)}") + return {"success": False, "error": str(e)} + + async def query_table( + self, + table_name: str, + columns: Optional[List[str]] = None, + where: Optional[str] = None, + limit: int = 100 + ) -> List[Dict[str, Any]]: + """ + 查询 MySQL 表数据 + + Args: + table_name: 表名 + columns: 要查询的列 + where: WHERE 条件 + limit: 限制返回行数 + + Returns: + 查询结果 + """ + try: + # 构建查询 + sql = f"SELECT * FROM `{table_name}`" + if where: + sql += f" WHERE {where}" + sql += f" LIMIT {limit}" + + results = await self.mysql_db.execute_query(sql) + return results + + except Exception as e: + logger.error(f"查询表失败: {str(e)}") + return [] + + async def get_table_schema(self, table_name: str) -> Optional[Dict[str, Any]]: + """ + 获取表结构信息 + + Args: + table_name: 表名 + + Returns: + 表结构信息 + """ + try: + sql = f""" + SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_KEY, COLUMN_COMMENT + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = '{table_name}' + ORDER BY ORDINAL_POSITION + """ + results = await self.mysql_db.execute_query(sql) + return results + + except Exception as e: + logger.error(f"获取表结构失败: {str(e)}") + return None + + async def delete_table(self, table_name: str) -> bool: + """ + 删除表 + + Args: + table_name: 表名 + + Returns: + 是否成功 + """ + try: + # 安全检查:表名必须包含下划线(避免删除系统表) + if '_' not in table_name and not table_name.startswith('t_'): + raise ValueError("不允许删除此表") + + sql = f"DROP TABLE IF EXISTS `{table_name}`" + await self.mysql_db.execute_raw_sql(sql) + logger.info(f"表 {table_name} 已删除") + return True + + except Exception as e: + logger.error(f"删除表失败: {str(e)}") + return False + + async def list_tables(self) -> List[str]: + """ + 列出所有用户表 + + Returns: + 表名列表 + """ + try: + sql = """ + SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_SCHEMA = DATABASE() AND TABLE_TYPE = 'BASE TABLE' + """ + results = await self.mysql_db.execute_query(sql) + return [r['TABLE_NAME'] for r in results] + + except Exception as e: + logger.error(f"列出表失败: {str(e)}") + return [] + + +# ==================== 全局单例 ==================== + +excel_storage_service = ExcelStorageService() diff --git a/backend/app/services/prompt_service.py b/backend/app/services/prompt_service.py new file mode 100644 index 0000000..8293b91 --- /dev/null +++ b/backend/app/services/prompt_service.py @@ -0,0 +1,444 @@ +""" +提示词工程服务 + +管理和优化与大模型交互的提示词 +""" +import json +import logging +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + + +class PromptType(Enum): + """提示词类型""" + DOCUMENT_PARSING = "document_parsing" # 文档解析 + FIELD_EXTRACTION = "field_extraction" # 字段提取 + TABLE_FILLING = "table_filling" # 表格填写 + QUERY_GENERATION = "query_generation" # 查询生成 + TEXT_SUMMARY = "text_summary" # 文本摘要 + INTENT_CLASSIFICATION = "intent_classification" # 意图分类 + DATA_CLASSIFICATION = "data_classification" # 数据分类 + + +@dataclass +class PromptTemplate: + """提示词模板""" + name: str + type: PromptType + system_prompt: str + user_template: str + examples: List[Dict[str, str]] = field(default_factory=list) # Few-shot 示例 + rules: List[str] = field(default_factory=list) # 特殊规则 + + def format( + self, + context: Dict[str, Any], + user_input: Optional[str] = None + ) -> List[Dict[str, str]]: + """ + 格式化提示词 + + Args: + context: 上下文数据 + user_input: 用户输入 + + Returns: + 格式化后的消息列表 + """ + messages = [] + + # 系统提示词 + system_content = self.system_prompt + + # 添加规则 + if self.rules: + system_content += "\n\n【输出规则】\n" + "\n".join([f"- {rule}" for rule in self.rules]) + + # 添加示例 + if self.examples: + system_content += "\n\n【示例】\n" + for i, ex in enumerate(self.examples): + system_content += f"\n示例 {i+1}:\n" + system_content += f"输入: {ex.get('input', '')}\n" + system_content += f"输出: {ex.get('output', '')}\n" + + messages.append({"role": "system", "content": system_content}) + + # 用户提示词 + user_content = self._format_user_template(context, user_input) + messages.append({"role": "user", "content": user_content}) + + return messages + + def _format_user_template( + self, + context: Dict[str, Any], + user_input: Optional[str] + ) -> str: + """格式化用户模板""" + content = self.user_template + + # 替换上下文变量 + for key, value in context.items(): + placeholder = f"{{{key}}}" + if placeholder in content: + if isinstance(value, (dict, list)): + content = content.replace(placeholder, json.dumps(value, ensure_ascii=False, indent=2)) + else: + content = content.replace(placeholder, str(value)) + + # 添加用户输入 + if user_input: + content += f"\n\n【用户需求】\n{user_input}" + + return content + + +class PromptEngineeringService: + """提示词工程服务""" + + def __init__(self): + self.templates: Dict[PromptType, PromptTemplate] = {} + self._init_templates() + + def _init_templates(self): + """初始化所有提示词模板""" + + # ==================== 文档解析模板 ==================== + self.templates[PromptType.DOCUMENT_PARSING] = PromptTemplate( + name="文档解析", + type=PromptType.DOCUMENT_PARSING, + system_prompt="""你是一个专业的文档解析专家。你的任务是从各类文档(Word、Excel、Markdown、纯文本)中提取关键信息。 + +请严格按照JSON格式输出解析结果: +{ + "success": true/false, + "document_type": "文档类型", + "key_fields": {"字段名": "字段值", ...}, + "summary": "文档摘要(100字内)", + "structured_data": {...} // 提取的表格或其他结构化数据 +} + +重要规则: +- 只提取明确存在的信息,不要猜测 +- 如果是表格数据,请以数组格式输出 +- 日期请使用 YYYY-MM-DD 格式 +- 金额请使用数字格式 +- 如果无法提取某个字段,设置为 null""", + user_template="""请解析以下文档内容: + +=== 文档开始 === +{content} +=== 文档结束 === + +请提取文档中的关键信息。""", + examples=[ + { + "input": "合同金额:100万元\n签订日期:2024年1月15日\n甲方:张三\n乙方:某某公司", + "output": '{"success": true, "document_type": "合同", "key_fields": {"金额": 1000000, "日期": "2024-01-15", "甲方": "张三", "乙方": "某某公司"}, "summary": "甲乙双方签订的金额为100万元的合同", "structured_data": null}' + } + ], + rules=[ + "只输出JSON,不要添加任何解释", + "使用严格的JSON格式" + ] + ) + + # ==================== 字段提取模板 ==================== + self.templates[PromptType.FIELD_EXTRACTION] = PromptTemplate( + name="字段提取", + type=PromptType.FIELD_EXTRACTION, + system_prompt="""你是一个专业的数据提取专家。你的任务是从文档内容中提取指定字段的信息。 + +请严格按照以下JSON格式输出: +{ + "value": "提取到的值,找不到则为空字符串", + "source": "数据来源描述", + "confidence": 0.0到1.0之间的置信度 +} + +重要规则: +- 严格按字段名称匹配,不要提取无关信息 +- 置信度反映你对提取结果的信心程度 +- 如果字段不存在或无法确定,value设为空字符串,confidence设为0.0 +- value必须是实际值,不能是"未找到"之类的描述""", + user_template="""请从以下文档内容中提取指定字段的信息。 + +【需要提取的字段】 +字段名称:{field_name} +字段类型:{field_type} +是否必填:{required} + +【用户提示】 +{hint} + +【文档内容】 +{context} + +请提取字段值。""", + examples=[ + { + "input": "文档内容:姓名张三,电话13800138000,邮箱zhangsan@example.com", + "output": '{"value": "张三", "source": "文档第1行", "confidence": 1.0}' + } + ], + rules=[ + "只输出JSON,不要添加任何解释" + ] + ) + + # ==================== 表格填写模板 ==================== + self.templates[PromptType.TABLE_FILLING] = PromptTemplate( + name="表格填写", + type=PromptType.TABLE_FILLING, + system_prompt="""你是一个专业的表格填写助手。你的任务是根据提供的文档内容,填写表格模板中的字段。 + +请严格按照以下JSON格式输出: +{ + "filled_data": {{"字段1": "值1", "字段2": "值2", ...}}, + "fill_details": [ + {{"field": "字段1", "value": "值1", "source": "来源", "confidence": 0.95}}, + ... + ] +} + +重要规则: +- 只填写模板中存在的字段 +- 值必须来自提供的文档内容,不要编造 +- 如果某个字段在文档中找不到对应值,设为空字符串 +- fill_details 中记录每个字段的详细信息""", + user_template="""请根据以下文档内容,填写表格模板。 + +【表格模板字段】 +{fields} + +【用户需求】 +{hint} + +【参考文档内容】 +{context} + +请填写表格。""", + examples=[ + { + "input": "字段:姓名、电话\n文档:张三,电话是13800138000", + "output": '{"filled_data": {"姓名": "张三", "电话": "13800138000"}, "fill_details": [{"field": "姓名", "value": "张三", "source": "文档第1行", "confidence": 1.0}, {"field": "电话", "value": "13800138000", "source": "文档第1行", "confidence": 1.0}]}' + } + ], + rules=[ + "只输出JSON,不要添加任何解释" + ] + ) + + # ==================== 查询生成模板 ==================== + self.templates[PromptType.QUERY_GENERATION] = PromptTemplate( + name="查询生成", + type=PromptType.QUERY_GENERATION, + system_prompt="""你是一个SQL查询生成专家。你的任务是根据用户的自然语言需求,生成相应的数据库查询语句。 + +请严格按照以下JSON格式输出: +{ + "sql_query": "生成的SQL查询语句", + "explanation": "查询逻辑说明" +} + +重要规则: +- 只生成 SELECT 查询语句,不要生成 INSERT/UPDATE/DELETE +- 必须包含 WHERE 条件限制查询范围 +- 表名和字段名使用反引号包裹 +- 确保SQL语法正确 +- 如果无法生成有效的查询,sql_query设为空字符串""", + user_template="""根据以下信息生成查询语句。 + +【数据库表结构】 +{table_schema} + +【RAG检索到的上下文】 +{rag_context} + +【用户查询需求】 +{user_intent} + +请生成SQL查询。""", + examples=[ + { + "input": "表:orders(订单号, 金额, 日期, 客户)\n需求:查询2024年1月销售额超过10000的订单", + "output": '{"sql_query": "SELECT * FROM `orders` WHERE `日期` >= \\'2024-01-01\\' AND `日期` < \\'2024-02-01\\' AND `金额` > 10000", "explanation": "筛选2024年1月销售额超过10000的订单"}' + } + ], + rules=[ + "只输出JSON,不要添加任何解释", + "禁止生成 DROP、DELETE、TRUNCATE 等危险操作" + ] + ) + + # ==================== 文本摘要模板 ==================== + self.templates[PromptType.TEXT_SUMMARY] = PromptTemplate( + name="文本摘要", + type=PromptType.TEXT_SUMMARY, + system_prompt="""你是一个专业的文本摘要专家。你的任务是对长文档进行压缩,提取关键信息。 + +请严格按照以下JSON格式输出: +{ + "summary": "摘要内容(不超过200字)", + "key_points": ["要点1", "要点2", "要点3"], + "keywords": ["关键词1", "关键词2", "关键词3"] +}""", + user_template="""请为以下文档生成摘要: + +=== 文档开始 === +{content} +=== 文档结束 === + +生成简明摘要。""", + rules=[ + "只输出JSON,不要添加任何解释" + ] + ) + + # ==================== 意图分类模板 ==================== + self.templates[PromptType.INTENT_CLASSIFICATION] = PromptTemplate( + name="意图分类", + type=PromptType.INTENT_CLASSIFICATION, + system_prompt="""你是一个意图分类专家。你的任务是分析用户的自然语言输入,判断用户的真实意图。 + +支持的意图类型: +- upload: 上传文档 +- parse: 解析文档 +- query: 查询数据 +- fill: 填写表格 +- export: 导出数据 +- analyze: 分析数据 +- other: 其他/未知 + +请严格按照以下JSON格式输出: +{ + "intent": "意图类型", + "confidence": 0.0到1.0之间的置信度, + "entities": {{"实体名": "实体值", ...}}, // 识别出的关键实体 + "suggestion": "建议的下一步操作" +}""", + user_template="""请分析以下用户输入,判断其意图: + +【用户输入】 +{user_input} + +请分类。""", + rules=[ + "只输出JSON,不要添加任何解释" + ] + ) + + # ==================== 数据分类模板 ==================== + self.templates[PromptType.DATA_CLASSIFICATION] = PromptTemplate( + name="数据分类", + type=PromptType.DATA_CLASSIFICATION, + system_prompt="""你是一个数据分类专家。你的任务是判断数据的类型和格式。 + +请严格按照以下JSON格式输出: +{ + "data_type": "text/number/date/email/phone/url/amount/other", + "format": "具体格式描述", + "is_valid": true/false, + "normalized_value": "规范化后的值" +}""", + user_template="""请分析以下数据的类型和格式: + +【数据】 +{value} + +【期望类型(如果有)】 +{expected_type} + +请分类。""", + rules=[ + "只输出JSON,不要添加任何解释" + ] + ) + + def get_prompt( + self, + type: PromptType, + context: Dict[str, Any], + user_input: Optional[str] = None + ) -> List[Dict[str, str]]: + """ + 获取格式化后的提示词 + + Args: + type: 提示词类型 + context: 上下文数据 + user_input: 用户输入 + + Returns: + 消息列表 + """ + template = self.templates.get(type) + if not template: + logger.warning(f"未找到提示词模板: {type}") + return [{"role": "user", "content": str(context)}] + + return template.format(context, user_input) + + def get_template(self, type: PromptType) -> Optional[PromptTemplate]: + """获取提示词模板""" + return self.templates.get(type) + + def add_template(self, template: PromptTemplate): + """添加自定义提示词模板""" + self.templates[template.type] = template + logger.info(f"已添加提示词模板: {template.name}") + + def update_template(self, type: PromptType, **kwargs): + """更新提示词模板""" + template = self.templates.get(type) + if template: + for key, value in kwargs.items(): + if hasattr(template, key): + setattr(template, key, value) + + def optimize_prompt( + self, + type: PromptType, + feedback: str, + iteration: int = 1 + ) -> List[Dict[str, str]]: + """ + 根据反馈优化提示词 + + Args: + type: 提示词类型 + feedback: 优化反馈 + iteration: 迭代次数 + + Returns: + 优化后的提示词 + """ + template = self.templates.get(type) + if not template: + return [] + + # 简单优化策略:根据反馈添加规则 + optimization_rules = { + "准确率低": "提高要求,明确指出必须从原文提取,不要猜测", + "格式错误": "强调JSON格式要求,提供更详细的格式示例", + "遗漏信息": "添加提取更多细节的要求", + } + + new_rules = [] + for keyword, rule in optimization_rules.items(): + if keyword in feedback: + new_rules.append(rule) + + if new_rules: + template.rules.extend(new_rules) + + return template.format({}, None) + + +# ==================== 全局单例 ==================== + +prompt_service = PromptEngineeringService() diff --git a/backend/app/services/template_fill_service.py b/backend/app/services/template_fill_service.py new file mode 100644 index 0000000..2612354 --- /dev/null +++ b/backend/app/services/template_fill_service.py @@ -0,0 +1,307 @@ +""" +表格模板填写服务 + +从非结构化文档中检索信息并填写到表格模板 +""" +import logging +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from app.core.database import mongodb +from app.services.rag_service import rag_service +from app.services.llm_service import llm_service +from app.services.excel_storage_service import excel_storage_service + +logger = logging.getLogger(__name__) + + +@dataclass +class TemplateField: + """模板字段""" + cell: str # 单元格位置,如 "A1" + name: str # 字段名称 + field_type: str = "text" # 字段类型: text/number/date + required: bool = True + + +@dataclass +class FillResult: + """填写结果""" + field: str + value: Any + source: str # 来源文档 + confidence: float = 1.0 # 置信度 + + +class TemplateFillService: + """表格填写服务""" + + def __init__(self): + self.llm = llm_service + self.rag = rag_service + + async def fill_template( + self, + template_fields: List[TemplateField], + source_doc_ids: Optional[List[str]] = None, + user_hint: Optional[str] = None + ) -> Dict[str, Any]: + """ + 填写表格模板 + + Args: + template_fields: 模板字段列表 + source_doc_ids: 源文档ID列表,不指定则从所有文档检索 + user_hint: 用户提示(如"请从合同文档中提取") + + Returns: + 填写结果 + """ + filled_data = {} + fill_details = [] + + for field in template_fields: + try: + # 1. 从 RAG 检索相关上下文 + rag_results = await self._retrieve_context(field.name, user_hint) + + if not rag_results: + # 如果没有检索到结果,尝试直接询问 LLM + result = FillResult( + field=field.name, + value="", + source="未找到相关数据", + confidence=0.0 + ) + else: + # 2. 构建 Prompt 让 LLM 提取信息 + result = await self._extract_field_value( + field=field, + rag_context=rag_results, + user_hint=user_hint + ) + + # 3. 存储结果 + filled_data[field.name] = result.value + fill_details.append({ + "field": field.name, + "cell": field.cell, + "value": result.value, + "source": result.source, + "confidence": result.confidence + }) + + logger.info(f"字段 {field.name} 填写完成: {result.value}") + + except Exception as e: + logger.error(f"填写字段 {field.name} 失败: {str(e)}") + filled_data[field.name] = f"[提取失败: {str(e)}]" + fill_details.append({ + "field": field.name, + "cell": field.cell, + "value": f"[提取失败]", + "source": "error", + "confidence": 0.0 + }) + + return { + "success": True, + "filled_data": filled_data, + "fill_details": fill_details + } + + async def _retrieve_context( + self, + field_name: str, + user_hint: Optional[str] = None + ) -> List[Dict[str, Any]]: + """ + 从 RAG 检索相关上下文 + + Args: + field_name: 字段名称 + user_hint: 用户提示 + + Returns: + 检索结果列表 + """ + # 构建查询文本 + query = field_name + if user_hint: + query = f"{user_hint} {field_name}" + + # 检索相关文档片段 + results = self.rag.retrieve(query=query, top_k=5) + + return results + + async def _extract_field_value( + self, + field: TemplateField, + rag_context: List[Dict[str, Any]], + user_hint: Optional[str] = None + ) -> FillResult: + """ + 使用 LLM 从上下文中提取字段值 + + Args: + field: 字段定义 + rag_context: RAG 检索到的上下文 + user_hint: 用户提示 + + Returns: + 提取结果 + """ + # 构建上下文文本 + context_text = "\n\n".join([ + f"【文档 {i+1}】\n{doc['content']}" + for i, doc in enumerate(rag_context) + ]) + + # 构建 Prompt + prompt = f"""你是一个数据提取专家。请根据以下文档内容,提取指定字段的信息。 + +需要提取的字段: +- 字段名称:{field.name} +- 字段类型:{field.field_type} +- 是否必填:{'是' if field.required else '否'} + +{'用户提示:' + user_hint if user_hint else ''} + +参考文档内容: +{context_text} + +请严格按照以下 JSON 格式输出,不要添加任何解释: +{{ + "value": "提取到的值,如果没有找到则填写空字符串", + "source": "数据来源的文档描述", + "confidence": 0.0到1.0之间的置信度 +}} +""" + + # 调用 LLM + messages = [ + {"role": "system", "content": "你是一个专业的数据提取助手。请严格按JSON格式输出。"}, + {"role": "user", "content": prompt} + ] + + try: + response = await self.llm.chat( + messages=messages, + temperature=0.1, + max_tokens=500 + ) + + content = self.llm.extract_message_content(response) + + # 解析 JSON 响应 + import json + import re + + # 尝试提取 JSON + json_match = re.search(r'\{[\s\S]*\}', content) + if json_match: + result = json.loads(json_match.group()) + return FillResult( + field=field.name, + value=result.get("value", ""), + source=result.get("source", "LLM生成"), + confidence=result.get("confidence", 0.5) + ) + else: + # 如果无法解析,返回原始内容 + return FillResult( + field=field.name, + value=content.strip(), + source="直接提取", + confidence=0.5 + ) + + except Exception as e: + logger.error(f"LLM 提取失败: {str(e)}") + return FillResult( + field=field.name, + value="", + source=f"提取失败: {str(e)}", + confidence=0.0 + ) + + async def get_template_fields_from_file( + self, + file_path: str, + file_type: str = "xlsx" + ) -> List[TemplateField]: + """ + 从模板文件提取字段定义 + + Args: + file_path: 模板文件路径 + file_type: 文件类型 + + Returns: + 字段列表 + """ + fields = [] + + try: + if file_type in ["xlsx", "xls"]: + # 从 Excel 读取表头 + import pandas as pd + df = pd.read_excel(file_path, nrows=5) + + for idx, col in enumerate(df.columns): + # 获取单元格位置 (A, B, C, ...) + cell = self._column_to_cell(idx) + + fields.append(TemplateField( + cell=cell, + name=str(col), + field_type=self._infer_field_type(df[col]), + required=True + )) + + elif file_type == "docx": + # 从 Word 表格读取 + from docx import Document + doc = Document(file_path) + + for table_idx, table in enumerate(doc.tables): + for row_idx, row in enumerate(table.rows): + for col_idx, cell in enumerate(row.cells): + cell_text = cell.text.strip() + if cell_text: + fields.append(TemplateField( + cell=self._column_to_cell(col_idx), + name=cell_text, + field_type="text", + required=True + )) + + except Exception as e: + logger.error(f"提取模板字段失败: {str(e)}") + + return fields + + def _column_to_cell(self, col_idx: int) -> str: + """将列索引转换为单元格列名 (0 -> A, 1 -> B, ...)""" + result = "" + while col_idx >= 0: + result = chr(65 + (col_idx % 26)) + result + col_idx = col_idx // 26 - 1 + return result + + def _infer_field_type(self, series) -> str: + """推断字段类型""" + import pandas as pd + + if pd.api.types.is_numeric_dtype(series): + return "number" + elif pd.api.types.is_datetime64_any_dtype(series): + return "date" + else: + return "text" + + +# ==================== 全局单例 ==================== + +template_fill_service = TemplateFillService()