Compare commits
71 Commits
96f83042d8
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 47c89d888f | |||
| 6701df613b | |||
| ecad9ccd82 | |||
| 51350e3002 | |||
| 8e713be1ca | |||
| f2af27245d | |||
| a9dc0d8b91 | |||
| 902c28166b | |||
| 4a53be7eeb | |||
| 8b5b24fa2a | |||
| ed66aa346d | |||
| 5b82d40be0 | |||
| bedf1af9c0 | |||
| 5fca4eb094 | |||
| 0dbf74db9d | |||
| 858b594171 | |||
| ed0f51f2a4 | |||
| ecc0c79475 | |||
| 6befc510d8 | |||
| 8f66c235fa | |||
| 886d5ae0cc | |||
| 6752c5c231 | |||
| 610d475ce0 | |||
| 496b96508d | |||
| 07ebdc09bc | |||
| 7f67fa89de | |||
| c1886fb68f | |||
| 78417c898a | |||
| d5df5b8283 | |||
| 718f864926 | |||
| e5711b3f05 | |||
| df35105d16 | |||
| 2c2ab56d2d | |||
| faff1a5977 | |||
| b2ebd3e12d | |||
| 4eda6cf758 | |||
| 38e41c6eff | |||
| 6f8976cf71 | |||
| 44d389a434 | |||
| c75eb26d60 | |||
| 3b82103e87 | |||
| fd435c7fd3 | |||
| 41e5eaaa2d | |||
| 7c19e49988 | |||
| d189ea9620 | |||
| ddf30078f0 | |||
| 1a54d40e01 | |||
| ec4759512d | |||
| 8e1ddb8aff | |||
| 8b12cb9322 | |||
| b9ca11efe5 | |||
| c122f1d63b | |||
| 332f0f636d | |||
| d494e78f70 | |||
| 091c9db0da | |||
| 4e178477fe | |||
| 7c88da9ab1 | |||
| 6b88e971e8 | |||
| 5bcad4a5fa | |||
| 4bdc3f9707 | |||
| d3bdb17e87 | |||
| eab5f88662 | |||
| 2f630695ff | |||
| c23b93bb70 | |||
| 67e29d5800 | |||
| 0b00e27dbd | |||
| 12053a8fb1 | |||
| b32b1983ce | |||
| d8266e6d05 | |||
| 249cb5f6fd | |||
| b4a32748c5 |
48
.gitignore
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
/.git/
|
||||
/.gitignore
|
||||
/.idea/
|
||||
/.vscode/
|
||||
/backend/venv/
|
||||
/backend/command/
|
||||
/backend/.env
|
||||
/backend/.env.local
|
||||
/backend/.env.*.local
|
||||
/backend/app/__pycache__/*
|
||||
/backend/data/uploads
|
||||
/backend/data/charts
|
||||
/backend/data/logs
|
||||
|
||||
/frontend/node_modules/
|
||||
/frontend/dist/
|
||||
/frontend/build/
|
||||
/frontend/.vscode/
|
||||
/frontend/.idea/
|
||||
/frontend/.env
|
||||
/frontend/*.log
|
||||
|
||||
/frontend/src/api/
|
||||
/frontend/src/api/index.js
|
||||
/frontend/src/api/index.ts
|
||||
/frontend/src/api/index.tsx
|
||||
/frontend/src/api/index.py
|
||||
/frontend/src/api/index.go
|
||||
/frontend/src/api/index.java
|
||||
|
||||
/frontend - 副本/
|
||||
|
||||
/docs/
|
||||
/frontendTest/
|
||||
/supabase.txt
|
||||
|
||||
# 取消跟踪的文件 / Untracked files
|
||||
比赛备赛规划.md
|
||||
Q&A.xlsx
|
||||
package.json
|
||||
技术路线.md
|
||||
开发路径.md
|
||||
开发日志_2026-03-16.md
|
||||
/logs/
|
||||
|
||||
# Python cache
|
||||
**/__pycache__/**
|
||||
**.pyc
|
||||
10
.idea/.gitignore
generated
vendored
@@ -1,10 +0,0 @@
|
||||
# 默认忽略的文件
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
/.idea/
|
||||
/venv/
|
||||
# 基于编辑器的 HTTP 客户端请求
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
||||
12
.idea/FilesReadSysteam.iml
generated
@@ -1,12 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="Python 3.12" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="PyDocumentationSettings">
|
||||
<option name="format" value="PLAIN" />
|
||||
<option name="myDocStringFormat" value="Plain" />
|
||||
</component>
|
||||
</module>
|
||||
6
.idea/encodings.xml
generated
@@ -1,6 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Encoding">
|
||||
<file url="file://$PROJECT_DIR$/backend/requirements.txt" charset="UTF-8" />
|
||||
</component>
|
||||
</project>
|
||||
6
.idea/inspectionProfiles/profiles_settings.xml
generated
@@ -1,6 +0,0 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
||||
7
.idea/misc.xml
generated
@@ -1,7 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="Python 3.12" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
8
.idea/modules.xml
generated
@@ -1,8 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/FilesReadSysteam.iml" filepath="$PROJECT_DIR$/.idea/FilesReadSysteam.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
||||
6
.idea/vcs.xml
generated
@@ -1,6 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
||||
238
README.md
Normal file
@@ -0,0 +1,238 @@
|
||||
# FilesReadSystem
|
||||
|
||||
## 项目介绍 / Project Introduction
|
||||
|
||||
基于大语言模型的文档理解与多源数据融合系统,专为第十七届中国大学生服务外包创新创业大赛(A23赛题)开发。本系统利用大语言模型(LLM)解析、分析各类文档格式并提取结构化数据,支持通过自然语言指令自动填写模板表格。
|
||||
|
||||
A document understanding and multi-source data fusion system based on Large Language Models (LLM), developed for the 17th China University Student Service Outsourcing Innovation and Entrepreneurship Competition (Topic A23). This system uses LLMs to parse, analyze, and extract structured data from various document formats, supporting automatic template table filling through natural language instructions.
|
||||
|
||||
---
|
||||
|
||||
## 技术栈 / Technology Stack
|
||||
|
||||
| 层次 / Layer | 组件 / Component | 说明 / Description |
|
||||
|:---|:---|:---|
|
||||
| 后端 / Backend | FastAPI + Uvicorn | RESTful API,异步任务调度 / API & async task scheduling |
|
||||
| 前端 / Frontend | React + TypeScript + Vite | 文件上传、表格配置、聊天界面 / Upload, table config, chat UI |
|
||||
| 异步任务 / Async Tasks | Celery + Redis | 处理耗时的解析与AI提取 / Heavy parsing & AI extraction |
|
||||
| 文档数据库 / Document DB | MongoDB (Motor) | 元数据、提取结果、文档块存储 / Metadata, results, chunk storage |
|
||||
| 关系数据库 / Relational DB | MySQL (SQLAlchemy) | 结构化数据存储 / Structured data storage |
|
||||
| 缓存 / Cache | Redis | 缓存与任务队列 / Caching & task queue |
|
||||
| 向量检索 / Vector Search | FAISS | 高效相似性搜索 / Efficient similarity search |
|
||||
| AI集成 / AI Integration | LangChain-style + MiniMax API | RAG流水线、提示词管理 / RAG pipeline, prompt management |
|
||||
| 文档解析 / Document Parsing | python-docx, pandas, openpyxl, markdown-it | 多格式支持 / Multi-format support |
|
||||
|
||||
---
|
||||
|
||||
## 项目架构 / Project Architecture
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ User Interface │
|
||||
│ (React + TypeScript + shadcn/ui) │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ FastAPI Backend │
|
||||
│ ┌─────────────┐ ┌──────────────┐ ┌─────────────────────────┐ │
|
||||
│ │ Upload API │ │ RAG Search │ │ Natural Language │ │
|
||||
│ │ /documents │ │ /rag/search │ │ /instruction/execute │ │
|
||||
│ └─────────────┘ └──────────────┘ └─────────────────────────┘ │
|
||||
│ ┌─────────────┐ ┌──────────────┐ ┌─────────────────────────┐ │
|
||||
│ │ AI Analyze │ │ Template Fill│ │ Visualization │ │
|
||||
│ │ /ai/analyze │ │ /templates │ │ /visualization │ │
|
||||
│ └─────────────┘ └──────────────┘ └─────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
┌─────────────────────┼─────────────────────┐
|
||||
▼ ▼ ▼
|
||||
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
|
||||
│ MongoDB │ │ MySQL │ │ Redis │
|
||||
│ (Documents) │ │ (Structured) │ │ (Cache/Queue) │
|
||||
└─────────────────┘ └─────────────────┘ └─────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────┐
|
||||
│ FAISS │
|
||||
│ (Vector Index) │
|
||||
└─────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 目录结构 / Directory Structure
|
||||
|
||||
```
|
||||
FilesReadSystem/
|
||||
├── backend/ # 后端服务(Python + FastAPI)
|
||||
│ ├── app/
|
||||
│ │ ├── api/endpoints/ # API路由层 / API endpoints
|
||||
│ │ │ ├── ai_analyze.py # AI分析接口 / AI analysis
|
||||
│ │ │ ├── documents.py # 文档管理 / Document management
|
||||
│ │ │ ├── instruction.py # 自然语言指令 / Natural language instruction
|
||||
│ │ │ ├── rag.py # RAG检索 / RAG retrieval
|
||||
│ │ │ ├── tasks.py # 任务管理 / Task management
|
||||
│ │ │ ├── templates.py # 模板管理 / Template management
|
||||
│ │ │ ├── upload.py # 文件上传 / File upload
|
||||
│ │ │ └── visualization.py # 可视化 / Visualization
|
||||
│ │ ├── core/
|
||||
│ │ │ ├── database/ # 数据库连接 / Database connections
|
||||
│ │ │ └── document_parser/ # 文档解析器 / Document parsers
|
||||
│ │ ├── services/ # 业务逻辑服务 / Business logic services
|
||||
│ │ │ ├── llm_service.py # LLM调用 / LLM service
|
||||
│ │ │ ├── rag_service.py # RAG流水线 / RAG pipeline
|
||||
│ │ │ ├── template_fill_service.py # 模板填充 / Template filling
|
||||
│ │ │ ├── excel_ai_service.py # Excel AI分析 / Excel AI analysis
|
||||
│ │ │ ├── word_ai_service.py # Word AI分析 / Word AI analysis
|
||||
│ │ │ └── table_rag_service.py # 表格RAG / Table RAG
|
||||
│ │ └── instruction/ # 指令解析与执行 / Instruction parsing & execution
|
||||
│ ├── requirements.txt # Python依赖 / Python dependencies
|
||||
│ └── README.md
|
||||
│
|
||||
├── frontend/ # 前端项目(React + TypeScript)
|
||||
│ ├── src/
|
||||
│ │ ├── pages/ # 页面组件 / Page components
|
||||
│ │ │ ├── Dashboard.tsx # 仪表板 / Dashboard
|
||||
│ │ │ ├── Documents.tsx # 文档管理 / Document management
|
||||
│ │ │ ├── TemplateFill.tsx # 模板填充 / Template fill
|
||||
│ │ │ └── InstructionChat.tsx # 指令聊天 / Instruction chat
|
||||
│ │ ├── components/ui/ # shadcn/ui组件库 / shadcn/ui components
|
||||
│ │ ├── contexts/ # React上下文 / React contexts
|
||||
│ │ ├── db/ # API调用封装 / API call wrappers
|
||||
│ │ └── supabase/functions/ # Edge函数 / Edge functions
|
||||
│ ├── package.json
|
||||
│ └── README.md
|
||||
│
|
||||
├── docs/ # 文档与测试数据 / Documentation & test data
|
||||
├── logs/ # 应用日志 / Application logs
|
||||
└── README.md # 本文件 / This file
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 主要功能 / Key Features
|
||||
|
||||
- **多格式文档解析** / Multi-format Document Parsing
|
||||
- Excel (.xlsx)
|
||||
- Word (.docx)
|
||||
- Markdown (.md)
|
||||
- Plain Text (.txt)
|
||||
|
||||
- **AI智能分析** / AI-Powered Analysis
|
||||
- 文档内容理解与摘要
|
||||
- 表格数据自动提取
|
||||
- 多文档联合推理
|
||||
|
||||
- **RAG检索增强** / RAG (Retrieval Augmented Generation)
|
||||
- 语义向量相似度搜索
|
||||
- 上下文感知的答案生成
|
||||
|
||||
- **模板自动填充** / Template Auto-fill
|
||||
- 智能表格模板识别
|
||||
- 自然语言指令驱动填写
|
||||
- 批量数据导入导出
|
||||
|
||||
- **自然语言指令** / Natural Language Instructions
|
||||
- 意图识别与解析
|
||||
- 多步骤任务自动执行
|
||||
|
||||
---
|
||||
|
||||
## API接口 / API Endpoints
|
||||
|
||||
| 方法 / Method | 路径 / Path | 说明 / Description |
|
||||
|:---|:---|:---|
|
||||
| GET | `/health` | 健康检查 / Health check |
|
||||
| POST | `/upload/document` | 单文件上传 / Single file upload |
|
||||
| POST | `/upload/documents` | 批量上传 / Batch upload |
|
||||
| GET | `/documents` | 文档库 / Document library |
|
||||
| GET | `/tasks/{task_id}` | 任务状态 / Task status |
|
||||
| POST | `/rag/search` | RAG语义搜索 / RAG search |
|
||||
| POST | `/templates/upload` | 模板上传 / Template upload |
|
||||
| POST | `/templates/fill` | 执行模板填充 / Execute template fill |
|
||||
| POST | `/ai/analyze/excel` | Excel AI分析 / Excel AI analysis |
|
||||
| POST | `/ai/analyze/word` | Word AI分析 / Word AI analysis |
|
||||
| POST | `/instruction/recognize` | 意图识别 / Intent recognition |
|
||||
| POST | `/instruction/execute` | 执行指令 / Execute instruction |
|
||||
| GET | `/visualization/statistics` | 统计图表 / Statistics charts |
|
||||
|
||||
---
|
||||
|
||||
## 环境配置 / Environment Setup
|
||||
|
||||
### 后端 / Backend
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
|
||||
# 创建虚拟环境 / Create virtual environment
|
||||
python -m venv venv
|
||||
|
||||
# 激活虚拟环境 / Activate virtual environment
|
||||
# Windows PowerShell:
|
||||
.\venv\Scripts\Activate.ps1
|
||||
# Windows CMD:
|
||||
.\venv\Scripts\Activate.bat
|
||||
|
||||
# 安装依赖 / Install dependencies
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 复制环境变量模板 / Copy environment template
|
||||
copy .env.example .env
|
||||
# 编辑 .env 填入API密钥 / Edit .env with your API keys
|
||||
```
|
||||
|
||||
### 前端 / Frontend
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
|
||||
# 安装依赖 / Install dependencies
|
||||
npm install
|
||||
|
||||
# 或使用 pnpm / Or using pnpm
|
||||
pnpm install
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 启动项目 / Starting the Project
|
||||
|
||||
### 后端启动 / Backend Startup
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
./venv/Scripts/python.exe -m uvicorn app.main:app --host 127.0.0.1 --port 8000 --reload
|
||||
```
|
||||
|
||||
### 前端启动 / Frontend Startup
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm run dev
|
||||
# 或 / or
|
||||
pnpm dev
|
||||
```
|
||||
|
||||
前端地址 / Frontend URL: http://localhost:5173
|
||||
|
||||
---
|
||||
|
||||
## 配置说明 / Configuration
|
||||
|
||||
### 环境变量 / Environment Variables
|
||||
|
||||
| 变量 / Variable | 说明 / Description |
|
||||
|:---|:---|
|
||||
| `MONGODB_URL` | MongoDB连接地址 / MongoDB connection URL |
|
||||
| `MYSQL_HOST` | MySQL主机 / MySQL host |
|
||||
| `REDIS_URL` | Redis连接地址 / Redis connection URL |
|
||||
| `MINIMAX_API_KEY` | MiniMax API密钥 / MiniMax API key |
|
||||
| `MINIMAX_API_URL` | MiniMax API地址 / MiniMax API URL |
|
||||
|
||||
---
|
||||
|
||||
## 许可证 / License
|
||||
|
||||
ISC
|
||||
@@ -1,15 +1,61 @@
|
||||
# 基础配置
|
||||
# ============================================================
|
||||
# 基于大语言模型的文档理解与多源数据融合系统
|
||||
# 环境变量配置文件
|
||||
# ============================================================
|
||||
# 复制此文件为 .env 并填入实际值
|
||||
|
||||
# ==================== 应用基础配置 ====================
|
||||
APP_NAME="FilesReadSystem"
|
||||
DEBUG=true
|
||||
API_V1_STR="/api/v1"
|
||||
|
||||
# 数据库
|
||||
MONGODB_URL="mongodb://username:password@host:port"
|
||||
# ==================== MongoDB 配置 ====================
|
||||
# 非结构化数据存储 (原始文档、解析结果)
|
||||
MONGODB_URL="mongodb://localhost:27017"
|
||||
MONGODB_DB_NAME="document_system"
|
||||
|
||||
# ==================== MySQL 配置 ====================
|
||||
# 结构化数据存储 (Excel表格、查询结果)
|
||||
MYSQL_HOST="localhost"
|
||||
MYSQL_PORT=3306
|
||||
MYSQL_USER="root"
|
||||
MYSQL_PASSWORD="your_password_here"
|
||||
MYSQL_DATABASE="document_system"
|
||||
MYSQL_CHARSET="utf8mb4"
|
||||
|
||||
# ==================== Redis 配置 ====================
|
||||
# 缓存/任务队列
|
||||
REDIS_URL="redis://localhost:6379/0"
|
||||
|
||||
# 大模型 API
|
||||
LLM_API_KEY=""
|
||||
LLM_BASE_URL=""
|
||||
# ==================== LLM AI 配置 ====================
|
||||
# 大语言模型 API 配置
|
||||
# 支持 OpenAI 兼容格式 (DeepSeek, 智谱 GLM, 阿里等)
|
||||
# 智谱 AI (Zhipu AI) GLM 系列:
|
||||
# - 模型: glm-4-flash (快速文本模型), glm-4 (标准), glm-4-plus (高性能)
|
||||
# - API: https://open.bigmodel.cn
|
||||
# - API Key: https://open.bigmodel.cn/usercenter/apikeys
|
||||
LLM_API_KEY="your_llm_api_key_here"
|
||||
LLM_BASE_URL="https://open.bigmodel.cn/api/paas/v4"
|
||||
LLM_MODEL_NAME="glm-4v-plus"
|
||||
|
||||
# 文件存储配置
|
||||
# ==================== Supabase 配置 ====================
|
||||
# Supabase 项目配置
|
||||
SUPABASE_URL="your_supabase_url_here"
|
||||
SUPABASE_ANON_KEY="your_supabase_anon_key_here"
|
||||
SUPABASE_SERVICE_KEY="your_supabase_service_key_here"
|
||||
|
||||
# ==================== 文件路径配置 ====================
|
||||
# 上传文件存储目录 (相对于项目根目录)
|
||||
UPLOAD_DIR="./data/uploads"
|
||||
MAX_UPLOAD_SIZE=104857600 # 100MB
|
||||
|
||||
# Faiss 向量数据库持久化目录 (LangChain + Faiss 实现)
|
||||
FAISS_INDEX_DIR="./data/faiss"
|
||||
|
||||
# ==================== RAG 配置 ====================
|
||||
# Embedding 模型名称
|
||||
EMBEDDING_MODEL="all-MiniLM-L6-v2"
|
||||
|
||||
# ==================== Celery 配置 ====================
|
||||
# 异步任务队列 Broker
|
||||
CELERY_BROKER_URL="redis://localhost:6379/1"
|
||||
CELERY_RESULT_BACKEND="redis://localhost:6379/2"
|
||||
|
||||
38
backend/=3.0.0
Normal file
@@ -0,0 +1,38 @@
|
||||
Requirement already satisfied: sentence-transformers in c:\python312\lib\site-packages (2.2.2)
|
||||
Requirement already satisfied: transformers<5.0.0,>=4.6.0 in c:\python312\lib\site-packages (from sentence-transformers) (4.57.6)
|
||||
Requirement already satisfied: tqdm in c:\python312\lib\site-packages (from sentence-transformers) (4.66.1)
|
||||
Requirement already satisfied: torch>=1.6.0 in c:\python312\lib\site-packages (from sentence-transformers) (2.10.0)
|
||||
Requirement already satisfied: torchvision in c:\python312\lib\site-packages (from sentence-transformers) (0.25.0)
|
||||
Requirement already satisfied: numpy in c:\python312\lib\site-packages (from sentence-transformers) (1.26.2)
|
||||
Requirement already satisfied: scikit-learn in c:\python312\lib\site-packages (from sentence-transformers) (1.8.0)
|
||||
Requirement already satisfied: scipy in c:\python312\lib\site-packages (from sentence-transformers) (1.16.3)
|
||||
Requirement already satisfied: nltk in c:\python312\lib\site-packages (from sentence-transformers) (3.9.3)
|
||||
Requirement already satisfied: sentencepiece in c:\python312\lib\site-packages (from sentence-transformers) (0.2.1)
|
||||
Requirement already satisfied: huggingface-hub>=0.4.0 in c:\python312\lib\site-packages (from sentence-transformers) (0.36.2)
|
||||
Requirement already satisfied: filelock in c:\python312\lib\site-packages (from huggingface-hub>=0.4.0->sentence-transformers) (3.25.2)
|
||||
Requirement already satisfied: fsspec>=2023.5.0 in c:\python312\lib\site-packages (from huggingface-hub>=0.4.0->sentence-transformers) (2026.2.0)
|
||||
Requirement already satisfied: packaging>=20.9 in c:\python312\lib\site-packages (from huggingface-hub>=0.4.0->sentence-transformers) (23.2)
|
||||
Requirement already satisfied: pyyaml>=5.1 in c:\python312\lib\site-packages (from huggingface-hub>=0.4.0->sentence-transformers) (6.0.1)
|
||||
Requirement already satisfied: requests in c:\python312\lib\site-packages (from huggingface-hub>=0.4.0->sentence-transformers) (2.31.0)
|
||||
Requirement already satisfied: typing-extensions>=3.7.4.3 in c:\python312\lib\site-packages (from huggingface-hub>=0.4.0->sentence-transformers) (4.15.0)
|
||||
Requirement already satisfied: sympy>=1.13.3 in c:\python312\lib\site-packages (from torch>=1.6.0->sentence-transformers) (1.14.0)
|
||||
Requirement already satisfied: networkx>=2.5.1 in c:\python312\lib\site-packages (from torch>=1.6.0->sentence-transformers) (3.6.1)
|
||||
Requirement already satisfied: jinja2 in c:\python312\lib\site-packages (from torch>=1.6.0->sentence-transformers) (3.1.6)
|
||||
Requirement already satisfied: setuptools in c:\python312\lib\site-packages (from torch>=1.6.0->sentence-transformers) (82.0.1)
|
||||
Requirement already satisfied: colorama in c:\python312\lib\site-packages (from tqdm->sentence-transformers) (0.4.6)
|
||||
Requirement already satisfied: regex!=2019.12.17 in c:\python312\lib\site-packages (from transformers<5.0.0,>=4.6.0->sentence-transformers) (2026.2.28)
|
||||
Requirement already satisfied: tokenizers<=0.23.0,>=0.22.0 in c:\python312\lib\site-packages (from transformers<5.0.0,>=4.6.0->sentence-transformers) (0.22.2)
|
||||
Requirement already satisfied: safetensors>=0.4.3 in c:\python312\lib\site-packages (from transformers<5.0.0,>=4.6.0->sentence-transformers) (0.7.0)
|
||||
Requirement already satisfied: click in c:\python312\lib\site-packages (from nltk->sentence-transformers) (8.3.1)
|
||||
Requirement already satisfied: joblib in c:\python312\lib\site-packages (from nltk->sentence-transformers) (1.5.3)
|
||||
Requirement already satisfied: threadpoolctl>=3.2.0 in c:\python312\lib\site-packages (from scikit-learn->sentence-transformers) (3.6.0)
|
||||
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in c:\python312\lib\site-packages (from torchvision->sentence-transformers) (12.1.1)
|
||||
Requirement already satisfied: mpmath<1.4,>=1.1.0 in c:\python312\lib\site-packages (from sympy>=1.13.3->torch>=1.6.0->sentence-transformers) (1.3.0)
|
||||
Requirement already satisfied: MarkupSafe>=2.0 in c:\python312\lib\site-packages (from jinja2->torch>=1.6.0->sentence-transformers) (3.0.3)
|
||||
Requirement already satisfied: charset-normalizer<4,>=2 in c:\python312\lib\site-packages (from requests->huggingface-hub>=0.4.0->sentence-transformers) (3.4.6)
|
||||
Requirement already satisfied: idna<4,>=2.5 in c:\python312\lib\site-packages (from requests->huggingface-hub>=0.4.0->sentence-transformers) (3.11)
|
||||
Requirement already satisfied: urllib3<3,>=1.21.1 in c:\python312\lib\site-packages (from requests->huggingface-hub>=0.4.0->sentence-transformers) (2.6.3)
|
||||
Requirement already satisfied: certifi>=2017.4.17 in c:\python312\lib\site-packages (from requests->huggingface-hub>=0.4.0->sentence-transformers) (2026.2.25)
|
||||
|
||||
[notice] A new release of pip is available: 24.2 -> 26.0.1
|
||||
[notice] To update, run: python.exe -m pip install --upgrade pip
|
||||
@@ -0,0 +1,33 @@
|
||||
"""
|
||||
API 路由注册模块
|
||||
"""
|
||||
from fastapi import APIRouter
|
||||
from app.api.endpoints import (
|
||||
upload,
|
||||
documents, # 多格式文档上传
|
||||
tasks, # 任务管理
|
||||
library, # 文档库
|
||||
rag, # RAG检索
|
||||
templates, # 表格模板
|
||||
ai_analyze,
|
||||
visualization,
|
||||
analysis_charts,
|
||||
health,
|
||||
instruction, # 智能指令
|
||||
)
|
||||
|
||||
# Top-level router that aggregates every endpoint module's router.
api_router = APIRouter()

# Register each module's router; the tuple order is the registration order.
for _endpoint_module in (
    health,           # health check
    upload,           # original Excel upload
    documents,        # multi-format document upload
    tasks,            # task status queries
    library,          # document library management
    rag,              # RAG retrieval
    templates,        # table templates
    ai_analyze,       # AI analysis
    visualization,    # visualization
    analysis_charts,  # analysis charts
    instruction,      # smart instructions
):
    api_router.include_router(_endpoint_module.router)

# Do not leave the loop variable behind in the module namespace.
del _endpoint_module
|
||||
|
||||
485
backend/app/api/endpoints/ai_analyze.py
Normal file
@@ -0,0 +1,485 @@
|
||||
"""
|
||||
AI 分析 API 接口
|
||||
"""
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException, Query, Body
|
||||
from fastapi.responses import StreamingResponse
|
||||
from typing import Optional
|
||||
import logging
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
from app.services.excel_ai_service import excel_ai_service
|
||||
from app.services.markdown_ai_service import markdown_ai_service
|
||||
from app.services.template_fill_service import template_fill_service
|
||||
from app.services.word_ai_service import word_ai_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/ai", tags=["AI 分析"])
|
||||
|
||||
|
||||
@router.post("/analyze/excel")
async def analyze_excel(
    file: UploadFile = File(...),
    user_prompt: str = Query("", description="用户自定义提示词"),
    analysis_type: str = Query("general", description="分析类型: general, summary, statistics, insights"),
    parse_all_sheets: bool = Query(False, description="是否分析所有工作表")
):
    """
    Upload an Excel workbook and analyse it with the AI service.

    Args:
        file: Uploaded Excel file (.xlsx or .xls).
        user_prompt: Optional custom prompt supplied by the user.
        analysis_type: One of general, summary, statistics, insights.
        parse_all_sheets: When true, all worksheets are analysed in one batch
            call; otherwise a single-file analysis with header row 0 is run.

    Returns:
        The result object returned by excel_ai_service, passed through unchanged.

    Raises:
        HTTPException: 400 for an empty filename, an unsupported extension or
            an unsupported analysis type; 500 when the analysis fails.
    """
    filename = file.filename
    if not filename:
        raise HTTPException(status_code=400, detail="文件名为空")

    # The text after the last dot is treated as the extension.
    extension = filename.rsplit('.', 1)[-1].lower()
    if extension not in ('xlsx', 'xls'):
        raise HTTPException(
            status_code=400,
            detail=f"不支持的文件类型: {extension},仅支持 .xlsx 和 .xls"
        )

    allowed_types = ('general', 'summary', 'statistics', 'insights')
    if analysis_type not in allowed_types:
        allowed_label = ', '.join(allowed_types)
        raise HTTPException(
            status_code=400,
            detail=f"不支持的分析类型: {analysis_type},支持的类型: {allowed_label}"
        )

    try:
        raw_bytes = await file.read()

        logger.info(f"开始分析文件: {filename}, 分析类型: {analysis_type}")

        if not parse_all_sheets:
            # Single-sheet path: the first row is used as the header.
            outcome = await excel_ai_service.analyze_excel_file(
                raw_bytes,
                filename,
                user_prompt=user_prompt,
                analysis_type=analysis_type,
                parse_options={"header_row": 0}
            )
        else:
            outcome = await excel_ai_service.batch_analyze_sheets(
                raw_bytes,
                filename,
                user_prompt=user_prompt,
                analysis_type=analysis_type
            )

        logger.info(f"文件分析完成: {filename}, 成功: {outcome['success']}")
        return outcome

    except HTTPException:
        # Let deliberate HTTP errors through untouched.
        raise
    except Exception as exc:
        reason = str(exc)
        logger.error(f"AI 分析过程中出错: {reason}")
        raise HTTPException(status_code=500, detail=f"分析失败: {reason}")
|
||||
|
||||
|
||||
@router.get("/analysis/types")
async def get_analysis_types():
    """
    List the analysis types supported by the Excel and Markdown AI services.

    Returns:
        dict: "excel_types" and "markdown_types", each holding whatever the
        corresponding service reports as supported.
    """
    excel_types = excel_ai_service.get_supported_analysis_types()
    markdown_types = markdown_ai_service.get_supported_analysis_types()
    return {
        "excel_types": excel_types,
        "markdown_types": markdown_types
    }
|
||||
|
||||
|
||||
@router.post("/analyze/text")
async def analyze_text(
    excel_data: dict = Body(..., description="Excel 解析后的数据"),
    user_prompt: str = Body("", description="用户提示词"),
    analysis_type: str = Body("general", description="分析类型")
):
    """
    Run an AI analysis over Excel data that has already been parsed.

    Args:
        excel_data: Parsed Excel data.
        user_prompt: Optional user prompt; when non-blank, the template-based
            analysis is used instead of the typed analysis.
        analysis_type: Analysis type passed to the typed analysis.

    Returns:
        The result object returned by llm_service, passed through unchanged.

    Raises:
        HTTPException: 500 when anything goes wrong during the analysis.
    """
    try:
        logger.info(f"开始文本分析, 分析类型: {analysis_type}")

        # Imported at call time, as in the original implementation.
        from app.services.llm_service import llm_service

        has_custom_prompt = bool(user_prompt and user_prompt.strip())
        if not has_custom_prompt:
            outcome = await llm_service.analyze_excel_data(
                excel_data,
                user_prompt,
                analysis_type
            )
        else:
            outcome = await llm_service.analyze_with_template(
                excel_data,
                user_prompt
            )

        logger.info(f"文本分析完成, 成功: {outcome['success']}")
        return outcome

    except Exception as exc:
        reason = str(exc)
        logger.error(f"文本分析失败: {reason}")
        raise HTTPException(status_code=500, detail=f"分析失败: {reason}")
|
||||
|
||||
|
||||
@router.post("/analyze/md")
async def analyze_markdown(
    file: UploadFile = File(...),
    analysis_type: str = Query("summary", description="分析类型: summary, outline, key_points, questions, tags, qa, statistics, section"),
    user_prompt: str = Query("", description="用户自定义提示词"),
    section_number: Optional[str] = Query(None, description="指定章节编号,如 '一' 或 '(一)'")
):
    """
    Upload a Markdown file and analyse it with the AI service.

    Args:
        file: Uploaded Markdown file (.md or .markdown).
        analysis_type: Analysis type; must be one the Markdown service supports.
        user_prompt: Optional custom prompt supplied by the user.
        section_number: Optional section identifier to restrict the analysis to.

    Returns:
        The result object returned by markdown_ai_service when it reports success.

    Raises:
        HTTPException: 400 for an empty filename, an unsupported extension or
            an unsupported analysis type; 500 when the service reports failure
            or an unexpected error occurs.
    """
    filename = file.filename
    if not filename:
        raise HTTPException(status_code=400, detail="文件名为空")

    extension = filename.rsplit('.', 1)[-1].lower()
    if extension not in ('md', 'markdown'):
        raise HTTPException(
            status_code=400,
            detail=f"不支持的文件类型: {extension},仅支持 .md 和 .markdown"
        )

    allowed_types = markdown_ai_service.get_supported_analysis_types()
    if analysis_type not in allowed_types:
        allowed_label = ', '.join(allowed_types)
        raise HTTPException(
            status_code=400,
            detail=f"不支持的分析类型: {analysis_type},支持的类型: {allowed_label}"
        )

    try:
        payload = await file.read()

        # The service works on a file path, so persist the upload to disk first.
        with tempfile.NamedTemporaryFile(mode='wb', suffix='.md', delete=False) as handle:
            handle.write(payload)
            temp_path = handle.name

        try:
            logger.info(f"开始分析 Markdown 文件: {filename}, 分析类型: {analysis_type}, 章节: {section_number}")

            outcome = await markdown_ai_service.analyze_markdown(
                file_path=temp_path,
                analysis_type=analysis_type,
                user_prompt=user_prompt,
                section_number=section_number
            )

            logger.info(f"Markdown 分析完成: {filename}, 成功: {outcome['success']}")

            if outcome['success']:
                return outcome
            raise HTTPException(status_code=500, detail=outcome.get('error', '分析失败'))

        finally:
            # Best-effort removal of the temp file on every exit path.
            try:
                if temp_path and os.path.exists(temp_path):
                    os.unlink(temp_path)
            except Exception as cleanup_error:
                logger.warning(f"临时文件清理失败: {temp_path}, error: {cleanup_error}")

    except HTTPException:
        raise
    except Exception as exc:
        reason = str(exc)
        logger.error(f"Markdown AI 分析过程中出错: {reason}")
        raise HTTPException(status_code=500, detail=f"分析失败: {reason}")
|
||||
|
||||
|
||||
@router.post("/analyze/md/stream")
async def analyze_markdown_stream(
    file: UploadFile = File(...),
    analysis_type: str = Query("summary", description="分析类型"),
    user_prompt: str = Query("", description="用户自定义提示词"),
    section_number: Optional[str] = Query(None, description="指定章节编号")
):
    """
    Stream an AI analysis of an uploaded Markdown file as Server-Sent Events.

    The upload is written to a temporary file whose path is handed to the
    streaming service. The response body is produced lazily, after this
    handler has returned, so the temporary file must outlive the handler.
    It is therefore removed from inside the generator once streaming ends,
    not from a ``finally`` around the ``return`` (which would delete the
    file before the first chunk is read).

    Args:
        file: Uploaded Markdown file (.md or .markdown).
        analysis_type: Analysis type forwarded to the Markdown service.
        user_prompt: Optional custom prompt supplied by the user.
        section_number: Optional section identifier forwarded to the service.

    Returns:
        StreamingResponse: SSE response (media type text/event-stream).

    Raises:
        HTTPException: 400 for an empty filename or unsupported extension;
            500 when preparing the stream fails.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="文件名为空")

    file_ext = file.filename.split('.')[-1].lower()
    if file_ext not in ['md', 'markdown']:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的文件类型: {file_ext},仅支持 .md 和 .markdown"
        )

    def _remove_temp_file(path):
        # Best-effort cleanup: a failure here is logged, never raised.
        try:
            if path and os.path.exists(path):
                os.unlink(path)
        except Exception as cleanup_error:
            logger.warning(f"临时文件清理失败: {path}, error: {cleanup_error}")

    tmp_path = None
    try:
        content = await file.read()

        with tempfile.NamedTemporaryFile(mode='wb', suffix='.md', delete=False) as tmp:
            # Record the path before writing so a failed write is still cleaned up.
            tmp_path = tmp.name
            tmp.write(content)

        logger.info(f"开始流式分析 Markdown 文件: {file.filename}, 分析类型: {analysis_type}")

        async def stream_generator():
            try:
                async for chunk in markdown_ai_service.analyze_markdown_stream(
                    file_path=tmp_path,
                    analysis_type=analysis_type,
                    user_prompt=user_prompt,
                    section_number=section_number
                ):
                    yield chunk
            finally:
                # Runs when the stream is exhausted, fails, or is closed early.
                _remove_temp_file(tmp_path)

        return StreamingResponse(
            stream_generator(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no"
            }
        )

    except HTTPException:
        # The generator never started, so clean up here instead.
        _remove_temp_file(tmp_path)
        raise
    except Exception as e:
        _remove_temp_file(tmp_path)
        logger.error(f"Markdown AI 流式分析出错: {str(e)}")
        raise HTTPException(status_code=500, detail=f"流式分析失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/analyze/md/outline")
async def get_markdown_outline(
    file: UploadFile = File(...)
):
    """
    Return the outline (section structure) of an uploaded Markdown document.

    Args:
        file: Uploaded Markdown file (.md or .markdown).

    Returns:
        The outline object returned by markdown_ai_service.extract_outline.

    Raises:
        HTTPException: 400 for an empty filename or unsupported extension;
            500 when reading the upload or extracting the outline fails.
    """
    filename = file.filename
    if not filename:
        raise HTTPException(status_code=400, detail="文件名为空")

    extension = filename.rsplit('.', 1)[-1].lower()
    if extension not in ('md', 'markdown'):
        raise HTTPException(
            status_code=400,
            detail=f"不支持的文件类型: {extension},仅支持 .md 和 .markdown"
        )

    try:
        payload = await file.read()

        # The service works on a file path, so persist the upload to disk first.
        with tempfile.NamedTemporaryFile(mode='wb', suffix='.md', delete=False) as handle:
            handle.write(payload)
            temp_path = handle.name

        try:
            outline = await markdown_ai_service.extract_outline(temp_path)
            return outline
        finally:
            # Best-effort removal of the temp file on every exit path.
            try:
                if temp_path and os.path.exists(temp_path):
                    os.unlink(temp_path)
            except Exception as cleanup_error:
                logger.warning(f"临时文件清理失败: {temp_path}, error: {cleanup_error}")

    except Exception as exc:
        reason = str(exc)
        logger.error(f"获取 Markdown 大纲失败: {reason}")
        raise HTTPException(status_code=500, detail=f"获取大纲失败: {reason}")
|
||||
|
||||
|
||||
@router.post("/analyze/txt")
async def analyze_txt(
    file: UploadFile = File(...),
):
    """
    Upload a TXT file and extract structured data from it with the AI service.

    The uploaded bytes are decoded as UTF-8 (undecodable bytes are replaced)
    and passed straight to ``template_fill_service.analyze_txt_with_ai``; no
    temporary file is needed because the service receives the text itself.

    Args:
        file: Uploaded text file (extension ``txt`` or ``text``).

    Returns:
        dict: ``success``, ``filename`` and ``structured_data``; when the
        service returns an empty result ``success`` is False and an ``error``
        message is included.

    Raises:
        HTTPException: 400 when the filename is empty or the extension is not
            supported; 500 when the analysis raises.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="文件名为空")

    file_ext = file.filename.split('.')[-1].lower()
    if file_ext not in ['txt', 'text']:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的文件类型: {file_ext},仅支持 .txt"
        )

    try:
        content = await file.read()

        logger.info(f"开始 AI 分析 TXT 文件: {file.filename}")

        result = await template_fill_service.analyze_txt_with_ai(
            content=content.decode('utf-8', errors='replace'),
            filename=file.filename
        )

        if result:
            logger.info(f"TXT AI 分析成功: {file.filename}")
            return {
                "success": True,
                "filename": file.filename,
                "structured_data": result
            }

        logger.warning(f"TXT AI 分析返回空结果: {file.filename}")
        return {
            "success": False,
            "filename": file.filename,
            "error": "AI 分析未能提取到结构化数据",
            "structured_data": None
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"TXT AI 分析过程中出错: {str(e)}")
        raise HTTPException(status_code=500, detail=f"分析失败: {str(e)}")
|
||||
|
||||
|
||||
# ==================== Word 文档 AI 解析 ====================
|
||||
|
||||
@router.post("/analyze/word")
async def analyze_word(
    file: UploadFile = File(...),
    user_hint: str = Query("", description="用户提示词,如'请提取表格数据'")
):
    """
    Parse an uploaded Word document with the AI service and return structured data.

    The upload is saved to a temporary file, parsed by
    ``word_ai_service.parse_word_with_ai`` and the temporary file is deleted
    afterwards.

    Args:
        file: Uploaded Word file; only the ``docx`` extension is accepted.
        user_hint: Optional hint for the parser; a default extraction prompt
            is used when it is empty.

    Returns:
        dict: ``success``, ``filename`` and ``result``; on a failed parse
        ``result`` is None and ``error`` carries the reason.

    Raises:
        HTTPException: 400 for an empty filename or unsupported extension,
            500 when parsing raises.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="文件名为空")

    file_ext = file.filename.split('.')[-1].lower()
    if file_ext not in ['docx']:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的文件类型: {file_ext},仅支持 .docx"
        )

    try:
        # Persist the upload so the parser can read it from disk.
        payload = await file.read()
        with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_ext}") as handle:
            handle.write(payload)
            tmp_path = handle.name

        try:
            parsed = await word_ai_service.parse_word_with_ai(
                file_path=tmp_path,
                user_hint=user_hint or "请提取文档中的所有结构化数据,包括表格、键值对等"
            )

            if not parsed.get("success"):
                return {
                    "success": False,
                    "filename": file.filename,
                    "error": parsed.get("error", "AI 解析失败"),
                    "result": None
                }

            return {
                "success": True,
                "filename": file.filename,
                "result": parsed
            }
        finally:
            # Remove the temporary copy whether or not parsing succeeded.
            if os.path.exists(tmp_path):
                os.unlink(tmp_path)

    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Word AI 分析过程中出错: {str(exc)}")
        raise HTTPException(status_code=500, detail=f"分析失败: {str(exc)}")
|
||||
105
backend/app/api/endpoints/analysis_charts.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
分析结果图表 API - 根据文本分析结果生成图表
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import logging
|
||||
|
||||
from app.services.text_analysis_service import text_analysis_service
|
||||
from app.services.chart_generator_service import chart_generator_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/analysis", tags=["分析结果图表"])
|
||||
|
||||
|
||||
class AnalysisChartRequest(BaseModel):
    """Request model for generating charts from an analysis result."""
    # Analysis text to process; the endpoints reject empty or whitespace-only values.
    analysis_text: str
    # Name of the analysed file; the endpoints substitute "unknown" when empty.
    original_filename: Optional[str] = ""
    # Kind of source file; the endpoints fall back to "text" when empty.
    file_type: Optional[str] = "text"
|
||||
|
||||
|
||||
@router.post("/extract-and-chart")
async def extract_and_generate_charts(request: AnalysisChartRequest):
    """
    Extract structured data from an AI analysis text and build charts from it.

    Args:
        request: Request carrying the analysis text plus optional filename
            and file type.

    Returns:
        dict: The result of ``chart_generator_service.generate_charts_from_analysis``.

    Raises:
        HTTPException: 400 when the analysis text is empty; 500 when the
            extraction or chart generation fails or raises.
    """
    text = request.analysis_text
    if not (text and text.strip()):
        raise HTTPException(status_code=400, detail="分析文本不能为空")

    try:
        logger.info("开始从分析结果中提取结构化数据...")

        # Step 1: have the LLM-backed service extract structured data.
        extracted = await text_analysis_service.extract_structured_data(
            analysis_text=text,
            original_filename=request.original_filename or "unknown",
            file_type=request.file_type or "text"
        )
        if not extracted.get("success"):
            raise HTTPException(
                status_code=500,
                detail=f"提取结构化数据失败: {extracted.get('error', '未知错误')}"
            )

        logger.info("结构化数据提取成功,开始生成图表...")

        # Step 2: turn the extracted data into charts.
        charts = chart_generator_service.generate_charts_from_analysis(extracted)
        if not charts.get("success"):
            raise HTTPException(
                status_code=500,
                detail=f"生成图表失败: {charts.get('error', '未知错误')}"
            )

        logger.info("图表生成成功")
        return charts

    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"分析结果图表生成失败: {str(exc)}")
        raise HTTPException(
            status_code=500,
            detail=f"图表生成失败: {str(exc)}"
        )
|
||||
|
||||
|
||||
@router.post("/analyze-text")
async def analyze_text_only(request: AnalysisChartRequest):
    """
    Extract structured data only (no chart generation); intended for debugging.

    Args:
        request: Request carrying the analysis text.

    Returns:
        dict: The structured data returned by
        ``text_analysis_service.extract_structured_data``.

    Raises:
        HTTPException: 400 when the analysis text is empty, 500 when the
            extraction raises.
    """
    text = request.analysis_text
    if not (text and text.strip()):
        raise HTTPException(status_code=400, detail="分析文本不能为空")

    try:
        return await text_analysis_service.extract_structured_data(
            analysis_text=text,
            original_filename=request.original_filename or "unknown",
            file_type=request.file_type or "text"
        )
    except Exception as exc:
        logger.error(f"文本分析失败: {str(exc)}")
        raise HTTPException(
            status_code=500,
            detail=f"文本分析失败: {str(exc)}"
        )
|
||||
447
backend/app/api/endpoints/documents.py
Normal file
@@ -0,0 +1,447 @@
|
||||
"""
|
||||
文档管理 API 接口
|
||||
|
||||
支持多格式文档(docx/xlsx/md/txt)上传、解析、存储和RAG索引
|
||||
集成 Excel 存储和 AI 生成字段描述
|
||||
"""
|
||||
import logging
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException, Query, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.services.file_service import file_service
|
||||
from app.core.database import mongodb, redis_db
|
||||
from app.services.rag_service import rag_service
|
||||
from app.services.table_rag_service import table_rag_service
|
||||
from app.services.excel_storage_service import excel_storage_service
|
||||
from app.core.document_parser import ParserFactory, ParseResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/upload", tags=["文档上传"])
|
||||
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
|
||||
async def update_task_status(
    task_id: str,
    status: str,
    progress: int = 0,
    message: str = "",
    result: dict = None,
    error: str = None
):
    """
    Record a task's status in both Redis and MongoDB.

    Each store is written independently; a failure in one is logged as a
    warning and does not prevent the other write or raise to the caller.

    Args:
        task_id: Task identifier.
        status: New status value.
        progress: Progress value stored in the Redis metadata.
        message: Human-readable status message.
        result: Optional result payload (only stored in Redis when truthy).
        error: Optional error text (only stored in Redis when truthy).
    """
    meta = {"progress": progress, "message": message}
    # Only attach optional fields that actually carry a value.
    meta.update({
        key: value
        for key, value in (("result", result), ("error", error))
        if value
    })

    # Redis write.
    try:
        await redis_db.set_task_status(task_id, status, meta)
    except Exception as redis_err:
        logger.warning(f"Redis 任务状态更新失败: {redis_err}")

    # MongoDB write (kept as a fallback store).
    try:
        await mongodb.update_task(
            task_id=task_id,
            status=status,
            message=message,
            result=result,
            error=error
        )
    except Exception as mongo_err:
        logger.warning(f"MongoDB 任务状态更新失败: {mongo_err}")
|
||||
|
||||
|
||||
# ==================== 请求/响应模型 ====================
|
||||
|
||||
class UploadResponse(BaseModel):
    # Response returned by the upload endpoints once processing has been queued.
    # Identifier of the background task (the endpoints pass a UUID4 string).
    task_id: str
    # Number of files accepted for processing.
    file_count: int
    # Human-readable submission message.
    message: str
    # URL for polling the task; the endpoints use "/api/v1/tasks/{task_id}".
    status_url: str
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
    # Shape of a task-status reply. NOTE(review): not referenced in the visible
    # part of this file; confirm which endpoint returns it.
    task_id: str
    status: str
    # Defaults to 0 when no progress has been reported.
    progress: int = 0
    message: Optional[str] = None
    result: Optional[dict] = None
    error: Optional[str] = None
|
||||
|
||||
|
||||
# ==================== 文档上传接口 ====================
|
||||
|
||||
@router.post("/document", response_model=UploadResponse)
async def upload_document(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    parse_all_sheets: bool = Query(False, description="是否解析所有工作表(仅Excel)"),
    sheet_name: Optional[str] = Query(None, description="指定工作表(仅Excel)"),
    header_row: int = Query(0, description="表头行号(仅Excel)")
):
    """
    Upload a single document and schedule it for asynchronous processing.

    The file is saved through ``file_service`` and ``process_document`` is
    queued as a background task, which then:
    1. parses the content
    2. stores the raw content in MongoDB
    3. for Excel files: stores structured data in MySQL, generates field
       descriptions with AI and builds a RAG index
    4. for other documents: builds a RAG index

    Returns:
        UploadResponse: task id, file count, message and status URL.

    Raises:
        HTTPException: 400 for an empty filename or unsupported extension,
            500 when saving or scheduling fails.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="文件名为空")

    file_ext = file.filename.split('.')[-1].lower()
    if file_ext not in ['docx', 'xlsx', 'xls', 'md', 'txt']:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的文件类型: {file_ext},仅支持 docx/xlsx/xls/md/txt"
        )

    task_id = str(uuid.uuid4())

    try:
        # Persist a task record in MongoDB so the task can still be queried
        # when Redis is unavailable; a failure here is not fatal.
        try:
            await mongodb.insert_task(
                task_id=task_id,
                task_type="document_parse",
                status="pending",
                message=f"文档 {file.filename} 已提交处理"
            )
        except Exception as mongo_err:
            logger.warning(f"MongoDB 保存任务记录失败: {mongo_err}")

        raw_bytes = await file.read()
        stored_at = file_service.save_uploaded_file(
            raw_bytes,
            file.filename,
            subfolder=file_ext
        )

        excel_options = {
            "parse_all_sheets": parse_all_sheets,
            "sheet_name": sheet_name,
            "header_row": header_row
        }
        background_tasks.add_task(
            process_document,
            task_id=task_id,
            file_path=stored_at,
            original_filename=file.filename,
            doc_type=file_ext,
            parse_options=excel_options
        )

        return UploadResponse(
            task_id=task_id,
            file_count=1,
            message=f"文档 {file.filename} 已提交处理",
            status_url=f"/api/v1/tasks/{task_id}"
        )

    except Exception as exc:
        logger.error(f"上传文档失败: {str(exc)}")
        raise HTTPException(status_code=500, detail=f"上传失败: {str(exc)}")
|
||||
|
||||
|
||||
@router.post("/documents", response_model=UploadResponse)
async def upload_documents(
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(...),
):
    """
    Upload several documents and schedule them for batch processing.

    Files without a filename are skipped. Every remaining file is saved under
    the ``batch`` subfolder and ``process_documents_batch`` is queued as a
    single background task.

    Returns:
        UploadResponse: task id, number of saved files, message and status URL.

    Raises:
        HTTPException: 400 when no files were sent, 500 when saving or
            scheduling fails.
    """
    if not files:
        raise HTTPException(status_code=400, detail="没有上传文件")

    task_id = str(uuid.uuid4())
    saved_paths = []

    try:
        # Persist a task record in MongoDB; a failure here is only logged.
        try:
            await mongodb.insert_task(
                task_id=task_id,
                task_type="batch_parse",
                status="pending",
                message=f"已提交 {len(files)} 个文档处理"
            )
        except Exception as mongo_err:
            logger.warning(f"MongoDB 保存批量任务记录失败: {mongo_err}")

        for upload in files:
            name = upload.filename
            if not name:
                continue
            raw_bytes = await upload.read()
            stored_at = file_service.save_uploaded_file(raw_bytes, name, subfolder="batch")
            saved_paths.append({
                "path": stored_at,
                "filename": name,
                "ext": name.split('.')[-1].lower()
            })

        background_tasks.add_task(process_documents_batch, task_id=task_id, files=saved_paths)

        saved_count = len(saved_paths)
        return UploadResponse(
            task_id=task_id,
            file_count=saved_count,
            message=f"已提交 {saved_count} 个文档处理",
            status_url=f"/api/v1/tasks/{task_id}"
        )

    except Exception as exc:
        logger.error(f"批量上传失败: {str(exc)}")
        raise HTTPException(status_code=500, detail=f"批量上传失败: {str(exc)}")
|
||||
|
||||
|
||||
# ==================== 任务处理函数 ====================
|
||||
|
||||
async def process_document(
    task_id: str,
    file_path: str,
    original_filename: str,
    doc_type: str,
    parse_options: dict
):
    """
    Background task: parse one saved document, store it and index it.

    Steps: parse the file, store content and metadata in MongoDB, then either
    build the table/RAG index for Excel files or index embedded tables and the
    text content for other documents. Progress is reported through
    ``update_task_status``. Any exception marks the task as ``failure``;
    nothing is raised to the caller.

    Args:
        task_id: Task identifier used for status updates.
        file_path: Path of the saved upload.
        original_filename: Name of the file as uploaded.
        doc_type: Lower-cased file extension (e.g. ``xlsx``, ``docx``).
        parse_options: Excel options; ``sheet_name`` and ``header_row`` are read.
    """
    try:
        # Status: parsing
        await update_task_status(
            task_id, status="processing",
            progress=10, message="正在解析文档"
        )

        parser = ParserFactory.get_parser(file_path)
        result = parser.parse(file_path)

        if not result.success:
            raise Exception(result.error or "解析失败")

        # Status: storing
        await update_task_status(
            task_id, status="processing",
            progress=30, message="正在存储数据"
        )

        doc_id = await mongodb.insert_document(
            doc_type=doc_type,
            content=result.data.get("content", ""),
            metadata={
                **result.metadata,
                "original_filename": original_filename,
                "file_path": file_path
            },
            structured_data=result.data.get("structured_data")
        )

        if doc_type in ["xlsx", "xls"]:
            # Excel: store in MySQL, generate field descriptions, build RAG index.
            await update_task_status(
                task_id, status="processing",
                progress=50, message="正在存储到MySQL并生成字段描述"
            )

            # Best effort: a failure here is logged but does not fail the task,
            # because the document itself is already stored in MongoDB.
            try:
                logger.info(f"开始存储Excel到MySQL: {original_filename}, file_path: {file_path}")
                rag_result = await table_rag_service.build_table_rag_index(
                    file_path=file_path,
                    filename=original_filename,
                    sheet_name=parse_options.get("sheet_name"),
                    header_row=parse_options.get("header_row", 0)
                )

                if rag_result.get("success"):
                    logger.info(f"Excel存储到MySQL成功: {original_filename}, table: {rag_result.get('table_name')}")
                else:
                    logger.error(f"RAG索引构建失败: {rag_result.get('error')}")
            except Exception as e:
                logger.error(f"Excel存储到MySQL异常: {str(e)}", exc_info=True)

        else:
            # Unstructured document
            await update_task_status(
                task_id, status="processing",
                progress=60, message="正在建立索引"
            )

            # "structured_data" may be present but None; `or {}` keeps the
            # lookup below from raising AttributeError in that case.
            structured_data = result.data.get("structured_data") or {}
            tables = structured_data.get("tables") or []

            # Index every embedded table (MySQL table + RAG index).
            for table_info in tables:
                await table_rag_service.index_document_table(
                    doc_id=doc_id,
                    filename=original_filename,
                    table_data=table_info,
                    source_doc_type=doc_type
                )

            # Also index the document text itself.
            await index_document_to_rag(doc_id, original_filename, result, doc_type)

        # Done
        await update_task_status(
            task_id, status="success",
            progress=100, message="处理完成",
            result={
                "doc_id": doc_id,
                "doc_type": doc_type,
                "filename": original_filename
            }
        )

        logger.info(f"文档处理完成: {original_filename}, doc_id: {doc_id}")

    except Exception as e:
        logger.error(f"文档处理失败: {str(e)}", exc_info=True)
        await update_task_status(
            task_id, status="failure",
            progress=0, message="处理失败",
            error=str(e)
        )
|
||||
|
||||
|
||||
async def process_documents_batch(task_id: str, files: List[dict]):
    """
    Background task: parse, store and index a batch of saved documents.

    Each file is processed independently; a failure is recorded in that
    file's result entry and does not stop the batch. Progress is reported
    after every file and the per-file results are attached to the final
    ``success`` status.

    Args:
        task_id: Task identifier used for status updates.
        files: One dict per file with the keys ``path``, ``filename`` and ``ext``.
    """
    try:
        await update_task_status(
            task_id, status="processing",
            progress=0, message="开始批量处理"
        )

        results = []
        total = len(files)
        for i, file_info in enumerate(files):
            try:
                parser = ParserFactory.get_parser(file_info["path"])
                result = parser.parse(file_info["path"])

                if result.success:
                    doc_id = await mongodb.insert_document(
                        doc_type=file_info["ext"],
                        content=result.data.get("content", ""),
                        metadata={
                            **result.metadata,
                            "original_filename": file_info["filename"],
                            "file_path": file_info["path"]
                        },
                        structured_data=result.data.get("structured_data")
                    )

                    if file_info["ext"] in ["xlsx", "xls"]:
                        # Excel: build the table/RAG index.
                        await table_rag_service.build_table_rag_index(
                            file_path=file_info["path"],
                            filename=file_info["filename"]
                        )
                    else:
                        # Unstructured document: index embedded tables and the content.
                        # "structured_data" may be present but None; guard the lookup.
                        structured_data = result.data.get("structured_data") or {}
                        tables = structured_data.get("tables") or []

                        for table_info in tables:
                            await table_rag_service.index_document_table(
                                doc_id=doc_id,
                                filename=file_info["filename"],
                                table_data=table_info,
                                source_doc_type=file_info["ext"]
                            )

                        await index_document_to_rag(doc_id, file_info["filename"], result, file_info["ext"])

                    results.append({"filename": file_info["filename"], "doc_id": doc_id, "success": True})
                else:
                    results.append({"filename": file_info["filename"], "success": False, "error": result.error})

            except Exception as e:
                # Keep going with the remaining files, but leave a trace in the log.
                logger.error(f"批量处理中单个文档失败: {file_info.get('filename')}: {str(e)}", exc_info=True)
                results.append({"filename": file_info["filename"], "success": False, "error": str(e)})

            progress = int((i + 1) / total * 100)
            await update_task_status(
                task_id, status="processing",
                progress=progress, message=f"已处理 {i+1}/{total}"
            )

        await update_task_status(
            task_id, status="success",
            progress=100, message="批量处理完成",
            result={"results": results}
        )

    except Exception as e:
        logger.error(f"批量处理失败: {str(e)}")
        await update_task_status(
            task_id, status="failure",
            progress=0, message="批量处理失败",
            error=str(e)
        )
|
||||
|
||||
|
||||
async def index_document_to_rag(doc_id: str, filename: str, result: ParseResult, doc_type: str):
    """
    Index an unstructured document's text into the RAG store in chunks.

    Nothing is indexed when the parse result has no content. Failures are
    logged as warnings and never raised.

    Args:
        doc_id: Identifier of the stored document.
        filename: Original filename, stored in the chunk metadata.
        result: Parse result whose ``data["content"]`` holds the text.
        doc_type: Document type, stored in the chunk metadata.
    """
    try:
        text = result.data.get("content", "")
        if not text:
            return

        # Hand over the full text; the RAG service does the chunking.
        rag_service.index_document_content(
            doc_id=doc_id,
            content=text,
            metadata={
                "filename": filename,
                "doc_type": doc_type
            },
            chunk_size=500,    # 500 characters per chunk
            chunk_overlap=50   # 50 characters of overlap between chunks
        )
        logger.info(f"RAG 索引完成: {filename}, doc_id={doc_id}")
    except Exception as exc:
        logger.warning(f"RAG 索引失败: {str(exc)}")
|
||||
|
||||
|
||||
# ==================== 文档解析接口 ====================
|
||||
|
||||
@router.post("/document/parse")
async def parse_uploaded_document(
    file_path: str = Query(..., description="文件路径")
):
    """
    Parse a document that is already on disk and return the parse result.

    Args:
        file_path: Path of the file to parse.

    Returns:
        dict: ``result.to_dict()`` of the successful parse.

    Raises:
        HTTPException: 400 when the parser reports a failure or raises
            ``ValueError``; 500 for any other error.
    """
    # NOTE(review): file_path comes straight from the query string and is
    # passed to the parser unchecked, so a caller can point it at any file the
    # process can read. Restrict it to the upload directory once that root is
    # known here.
    try:
        parser = ParserFactory.get_parser(file_path)
        result = parser.parse(file_path)

        if result.success:
            return result.to_dict()

        raise HTTPException(status_code=400, detail=result.error)

    except HTTPException:
        # Without this clause the 400 above is caught by the generic handler
        # below and reported as a 500.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"解析文档失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"解析失败: {str(e)}")
|
||||
93
backend/app/api/endpoints/health.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
健康检查接口
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.core.database import mysql_db, mongodb, redis_db
|
||||
|
||||
router = APIRouter(tags=["健康检查"])
|
||||
|
||||
|
||||
@router.get("/health")
async def health_check() -> Dict[str, Any]:
    """
    Health check endpoint.

    Probes MySQL (``SELECT 1``), MongoDB (``ping``) and Redis (``ping``) and
    reports each as ``connected``, ``disconnected`` (client not initialised)
    or ``error`` (probe raised).

    Returns:
        dict: ``status`` is ``healthy`` only when all three services are
        connected, otherwise ``degraded``; also a UTC ``timestamp`` and the
        per-service states under ``services``.
    """
    # This module defines no module-level logger; create one locally so a
    # failing probe is logged instead of raising NameError.
    import logging
    logger = logging.getLogger(__name__)

    mysql_status = "unknown"
    mongodb_status = "unknown"
    redis_status = "unknown"

    try:
        if mysql_db.async_engine is None:
            mysql_status = "disconnected"
        else:
            # Run a real query to verify the connection.
            from sqlalchemy import text
            async with mysql_db.async_engine.connect() as conn:
                await conn.execute(text("SELECT 1"))
            mysql_status = "connected"
    except Exception as e:
        logger.warning(f"MySQL 健康检查失败: {e}")
        mysql_status = "error"

    try:
        if mongodb.client is None:
            mongodb_status = "disconnected"
        else:
            # Verify with an actual ping.
            await mongodb.client.admin.command('ping')
            mongodb_status = "connected"
    except Exception as e:
        logger.warning(f"MongoDB 健康检查失败: {e}")
        mongodb_status = "error"

    try:
        if not redis_db.is_connected or redis_db.client is None:
            redis_status = "disconnected"
        else:
            # Verify with an actual ping.
            await redis_db.client.ping()
            redis_status = "connected"
    except Exception as e:
        logger.warning(f"Redis 健康检查失败: {e}")
        redis_status = "error"

    return {
        "status": "healthy" if all([
            mysql_status == "connected",
            mongodb_status == "connected",
            redis_status == "connected"
        ]) else "degraded",
        "timestamp": datetime.utcnow().isoformat(),
        "services": {
            "mysql": mysql_status,
            "mongodb": mongodb_status,
            "redis": redis_status,
        }
    }
|
||||
|
||||
|
||||
@router.get("/health/ready")
async def readiness_check() -> Dict[str, str]:
    """
    Readiness probe.

    Meant for Kubernetes / load-balancer readiness checks; it always answers
    with a fixed payload and performs no dependency checks.
    """
    payload: Dict[str, str] = {"status": "ready"}
    return payload
|
||||
|
||||
|
||||
@router.get("/health/live")
async def liveness_check() -> Dict[str, str]:
    """
    Liveness probe.

    Meant for Kubernetes / load-balancer liveness checks; it always answers
    with a fixed payload.
    """
    payload: Dict[str, str] = {"status": "alive"}
    return payload
|
||||
439
backend/app/api/endpoints/instruction.py
Normal file
@@ -0,0 +1,439 @@
|
||||
"""
|
||||
智能指令 API 接口
|
||||
|
||||
支持自然语言指令解析和执行
|
||||
"""
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.instruction.intent_parser import intent_parser
|
||||
from app.instruction.executor import instruction_executor
|
||||
from app.core.database import mongodb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/instruction", tags=["智能指令"])
|
||||
|
||||
|
||||
# ==================== 请求/响应模型 ====================
|
||||
|
||||
class InstructionRequest(BaseModel):
    # Natural-language instruction to recognise or execute.
    instruction: str
    doc_ids: Optional[List[str]] = None  # IDs of the documents associated with the instruction
    context: Optional[Dict[str, Any]] = None  # extra context
|
||||
|
||||
|
||||
class IntentRecognitionResponse(BaseModel):
    # Reply of the /recognize endpoint.
    success: bool
    # Recognised intent code; the endpoint sets "error" when recognition raised.
    intent: str
    # Parameters returned by the intent parser (empty dict on failure).
    params: Dict[str, Any]
    # Human-readable description of the outcome.
    message: str
|
||||
|
||||
|
||||
class InstructionExecutionResponse(BaseModel):
    # Reply produced when an instruction is executed synchronously.
    success: bool
    # Intent code reported by the executor; "error" when execution raised.
    intent: str
    # Executor result, or {"error": ...} when execution raised.
    result: Dict[str, Any]
    # Human-readable description of the outcome.
    message: str
|
||||
|
||||
|
||||
# ==================== 接口 ====================
|
||||
|
||||
@router.post("/recognize", response_model=IntentRecognitionResponse)
async def recognize_intent(request: InstructionRequest):
    """
    Intent recognition endpoint.

    Parses a natural-language instruction into a structured intent plus
    parameters. When ``doc_ids`` are supplied they are attached to the
    parameters as ``document_refs``. Errors are reported in the response body
    (``success=False``, ``intent="error"``) rather than raised.

    Example instructions:
    - "提取文档中的医院数量和床位数"
    - "根据这些数据填表"
    - "总结一下这份文档"
    - "对比这两个文档的差异"
    """
    # Display names for the known intent codes.
    display_names = {
        "extract": "信息提取",
        "fill_table": "表格填写",
        "summarize": "摘要总结",
        "question": "智能问答",
        "search": "文档搜索",
        "compare": "对比分析",
        "transform": "格式转换",
        "edit": "文档编辑",
        "unknown": "未知"
    }

    try:
        intent, params = await intent_parser.parse(request.instruction)

        # Attach references to the associated documents.
        if request.doc_ids:
            params["document_refs"] = [f"doc_{ref}" for ref in request.doc_ids]

        label = display_names.get(intent, intent)
        return IntentRecognitionResponse(
            success=True,
            intent=intent,
            params=params,
            message=f"识别到意图: {label}"
        )

    except Exception as exc:
        logger.error(f"意图识别失败: {exc}")
        return IntentRecognitionResponse(
            success=False,
            intent="error",
            params={},
            message=f"意图识别失败: {str(exc)}"
        )
|
||||
|
||||
|
||||
@router.post("/execute")
async def execute_instruction(
    background_tasks: BackgroundTasks,
    request: InstructionRequest,
    async_execute: bool = Query(False, description="是否异步执行(仅返回任务ID)")
):
    """
    Instruction execution endpoint.

    Parses and executes a natural-language instruction.

    Examples:
    - Instruction: "提取文档1中的医院数量"
      Returns: {"extracted_data": {"医院数量": ["38710个"]}}

    - Instruction: "填表"
      Returns: {"filled_data": {...}}

    With ``async_execute=true`` the instruction runs as a background task and
    only the task id and status URL are returned.
    """
    task_id = str(uuid.uuid4())

    if not async_execute:
        # Synchronous mode: wait for the execution to finish.
        return await _execute_instruction_task(
            task_id, request.instruction, request.doc_ids, request.context
        )

    # Asynchronous mode: queue the work and reply with the task id right away.
    background_tasks.add_task(
        _execute_instruction_task,
        task_id=task_id,
        instruction=request.instruction,
        doc_ids=request.doc_ids,
        context=request.context
    )
    return {
        "success": True,
        "task_id": task_id,
        "message": "指令已提交执行",
        "status_url": f"/api/v1/tasks/{task_id}"
    }
|
||||
|
||||
|
||||
async def _execute_instruction_task(
    task_id: str,
    instruction: str,
    doc_ids: Optional[List[str]],
    context: Optional[Dict[str, Any]]
) -> InstructionExecutionResponse:
    """
    Execute an instruction and record the outcome as a task in MongoDB.

    Used both as a background task and directly for synchronous execution.
    Task bookkeeping is best effort: a failure to write the task record is
    logged and does not affect the execution. Execution errors are returned
    in the response (``success=False``, ``intent="error"``), not raised.

    Args:
        task_id: Identifier under which the task is recorded.
        instruction: Natural-language instruction to execute.
        doc_ids: Optional IDs of documents to load into the context as
            ``source_docs``.
        context: Optional extra context; it is copied, never modified.

    Returns:
        InstructionExecutionResponse built from the executor result.
    """
    from app.core.database import mongodb as mongo_client

    try:
        # Record the task.
        try:
            await mongo_client.insert_task(
                task_id=task_id,
                task_type="instruction_execute",
                status="processing",
                message="正在执行指令"
            )
        except Exception as record_err:
            logger.warning(f"任务记录写入失败: {record_err}")

        # Work on a copy so the caller's context dict is not mutated.
        ctx: Dict[str, Any] = dict(context) if context else {}

        # Load the referenced documents, skipping IDs that resolve to nothing.
        if doc_ids:
            docs = []
            for doc_id in doc_ids:
                doc = await mongo_client.get_document(doc_id)
                if doc:
                    docs.append(doc)

            if docs:
                ctx["source_docs"] = docs
                logger.info(f"指令执行上下文: 关联了 {len(docs)} 个文档")

        result = await instruction_executor.execute(instruction, ctx)

        # Update the task record.
        try:
            await mongo_client.update_task(
                task_id=task_id,
                status="success",
                message="执行完成",
                result=result
            )
        except Exception as record_err:
            logger.warning(f"任务状态更新失败: {record_err}")

        return InstructionExecutionResponse(
            success=result.get("success", False),
            intent=result.get("intent", "unknown"),
            result=result,
            message=result.get("message", "执行完成")
        )

    except Exception as e:
        logger.error(f"指令执行失败: {e}")
        try:
            await mongo_client.update_task(
                task_id=task_id,
                status="failure",
                message="执行失败",
                error=str(e)
            )
        except Exception as record_err:
            logger.warning(f"任务状态更新失败: {record_err}")

        return InstructionExecutionResponse(
            success=False,
            intent="error",
            result={"error": str(e)},
            message=f"指令执行失败: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.post("/chat")
async def instruction_chat(
    background_tasks: BackgroundTasks,
    request: InstructionRequest,
    async_execute: bool = Query(False, description="是否异步执行(仅返回任务ID)")
):
    """
    Instruction chat endpoint.

    Executes instructions as part of a multi-turn conversation.

    Example conversation:
    1. User: "上传一些文档"
    2. System: "请上传文档"
    3. User: "提取其中的医院数量"
    4. System: returns the extraction result

    With ``async_execute=true`` the instruction runs as a background task and
    only the task id and status URL are returned.
    """
    task_id = str(uuid.uuid4())

    if not async_execute:
        # Synchronous mode: wait for the execution to finish.
        return await _execute_chat_task(
            task_id, request.instruction, request.doc_ids, request.context
        )

    # Asynchronous mode: queue the work and reply with the task id right away.
    background_tasks.add_task(
        _execute_chat_task,
        task_id=task_id,
        instruction=request.instruction,
        doc_ids=request.doc_ids,
        context=request.context
    )
    return {
        "success": True,
        "task_id": task_id,
        "message": "指令已提交执行",
        "status_url": f"/api/v1/tasks/{task_id}"
    }
|
||||
|
||||
|
||||
def _chat_item_count(value: Any) -> int:
    """Return ``len(value)`` for sized values and 0 for None or unsized values."""
    try:
        return len(value)
    except TypeError:
        return 0


def _chat_response_message(result: Dict[str, Any]) -> str:
    """Build the reply message for the intent in an executor result."""
    intent = result.get("intent", "")

    # Only the message for the actual intent is evaluated, so a missing or
    # malformed field belonging to another intent cannot raise here.
    if intent == "extract":
        return f"已提取 {_chat_item_count(result.get('extracted_data'))} 个字段的数据"
    if intent == "fill_table":
        inner = result.get("result")
        filled = inner.get("filled_data") if isinstance(inner, dict) else None
        return f"填表完成,填写了 {_chat_item_count(filled)} 个字段"
    if intent == "search":
        return f"找到 {_chat_item_count(result.get('results'))} 条相关内容"
    if intent == "compare":
        return f"对比了 {_chat_item_count(result.get('comparison'))} 个文档"

    fixed_messages = {
        "summarize": "已生成文档摘要",
        "question": "已找到相关答案",
        "edit": "编辑操作已完成",
        "transform": "格式转换已完成",
        "unknown": "无法理解该指令,请尝试更明确的描述"
    }
    # Unlisted intents fall back to the executor's own message.
    return fixed_messages.get(intent, result.get("message", ""))


async def _execute_chat_task(
    task_id: str,
    instruction: str,
    doc_ids: Optional[List[str]],
    context: Optional[Dict[str, Any]]
):
    """
    Execute a chat instruction and record the outcome as a task in MongoDB.

    Used both as a background task and directly for synchronous execution.
    Task bookkeeping is best effort: a failure to write the task record is
    logged and does not affect the execution.

    Args:
        task_id: Identifier under which the task is recorded.
        instruction: Natural-language instruction to execute.
        doc_ids: Optional IDs of documents to load into the context as
            ``source_docs``.
        context: Optional extra context; it is copied, never modified.

    Returns:
        dict: ``success``, ``intent``, ``result``, ``message`` and ``hint`` on
        completion; ``success=False`` with ``error`` and ``message`` when the
        execution raises.
    """
    from app.core.database import mongodb as mongo_client

    try:
        # Record the task.
        try:
            await mongo_client.insert_task(
                task_id=task_id,
                task_type="instruction_chat",
                status="processing",
                message="正在处理对话"
            )
        except Exception as record_err:
            logger.warning(f"任务记录写入失败: {record_err}")

        # Work on a copy so the caller's context dict is not mutated.
        ctx: Dict[str, Any] = dict(context) if context else {}

        # Load the referenced documents, skipping IDs that resolve to nothing.
        if doc_ids:
            docs = []
            for doc_id in doc_ids:
                doc = await mongo_client.get_document(doc_id)
                if doc:
                    docs.append(doc)
            if docs:
                ctx["source_docs"] = docs

        result = await instruction_executor.execute(instruction, ctx)

        response = {
            "success": result.get("success", False),
            "intent": result.get("intent", "unknown"),
            "result": result,
            "message": _chat_response_message(result),
            "hint": _get_intent_hint(result.get("intent", ""))
        }

        # Update the task record.
        try:
            await mongo_client.update_task(
                task_id=task_id,
                status="success",
                message="处理完成",
                result=response
            )
        except Exception as record_err:
            logger.warning(f"任务状态更新失败: {record_err}")

        return response

    except Exception as e:
        logger.error(f"指令对话失败: {e}")
        try:
            await mongo_client.update_task(
                task_id=task_id,
                status="failure",
                message="处理失败",
                error=str(e)
            )
        except Exception as record_err:
            logger.warning(f"任务状态更新失败: {record_err}")

        return {
            "success": False,
            "error": str(e),
            "message": f"处理失败: {str(e)}"
        }
|
||||
|
||||
|
||||
def _get_intent_hint(intent: str) -> Optional[str]:
|
||||
"""根据意图返回下一步提示"""
|
||||
hints = {
|
||||
"extract": "您可以继续说 '提取更多字段' 或 '将数据填入表格'",
|
||||
"fill_table": "您可以提供表格模板或说 '帮我创建一个表格'",
|
||||
"question": "您可以继续提问或说 '总结一下这些内容'",
|
||||
"search": "您可以查看搜索结果或说 '对比这些内容'",
|
||||
"unknown": "您可以尝试: '提取数据'、'填表'、'总结'、'问答' 等指令"
|
||||
}
|
||||
return hints.get(intent)
|
||||
|
||||
|
||||
@router.get("/intents")
async def list_supported_intents():
    """
    List the supported intent types.

    Returns a static catalogue of the natural-language instruction types,
    each with a display name, example utterances and parameter names.
    """
    # One row per intent: (intent id, display name, examples, params).
    catalogue = [
        (
            "extract",
            "信息提取",
            ["提取文档中的医院数量", "抽取所有机构的名称", "找出表格中的数据"],
            ["field_refs", "document_refs"],
        ),
        (
            "fill_table",
            "表格填写",
            ["填表", "根据这些数据填写表格", "帮我填到Excel里"],
            ["template", "document_refs"],
        ),
        (
            "summarize",
            "摘要总结",
            ["总结一下这份文档", "生成摘要", "概括主要内容"],
            ["document_refs"],
        ),
        (
            "question",
            "智能问答",
            ["这段话说的是什么?", "有多少家医院?", "解释一下这个概念"],
            ["question", "focus"],
        ),
        (
            "search",
            "文档搜索",
            ["搜索相关内容", "找找看有哪些机构", "查询医院相关的数据"],
            ["field_refs", "question"],
        ),
        (
            "compare",
            "对比分析",
            ["对比这两个文档", "比较一下差异", "找出不同点"],
            ["document_refs"],
        ),
        (
            "edit",
            "文档编辑",
            ["润色这段文字", "修改格式", "添加注释"],
            [],
        ),
    ]

    intent_entries = []
    for intent_id, display_name, sample_utterances, param_names in catalogue:
        intent_entries.append({
            "intent": intent_id,
            "name": display_name,
            "examples": sample_utterances,
            "params": param_names,
        })

    return {"intents": intent_entries}
|
||||
170
backend/app/api/endpoints/library.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
文档库管理 API 接口
|
||||
|
||||
提供文档列表、详情查询和删除功能
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.database import mongodb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/documents", tags=["文档库"])
|
||||
|
||||
|
||||
class DocumentItem(BaseModel):
    """Summary of one stored document.

    The field names match the per-document dictionary assembled by the
    list endpoint in this module.
    """

    # Identity and naming
    doc_id: str
    filename: str
    original_filename: str

    # Classification and size
    doc_type: str
    file_size: int

    # Creation timestamp, carried as a string
    created_at: str

    # Optional extra details; omitted when not available
    metadata: Optional[dict] = None
|
||||
|
||||
|
||||
@router.get("")
async def get_documents(
    doc_type: Optional[str] = Query(None, description="文档类型过滤"),
    limit: int = Query(20, ge=1, le=100, description="返回数量"),
    skip: int = Query(0, ge=0, description="跳过数量")
):
    """
    List stored documents, newest first.

    Args:
        doc_type: Optional document-type filter.
        limit: Maximum number of documents to return (1-100).
        skip: Number of documents to skip (pagination offset).

    Returns:
        A dict with "success", "documents" and "total". Note that "total"
        is the number of documents in this page, not the collection size.
        When the query times out, an empty list plus a "warning" is
        returned instead of an error.

    Raises:
        HTTPException: 500 for any failure that is not a timeout.
    """
    try:
        # Build the filter
        query = {}
        if doc_type:
            query["doc_type"] = doc_type

        logger.info(f"开始查询文档列表, query: {query}, limit: {limit}")

        cursor = mongodb.documents.find(
            query,
            {"content": 0}  # exclude the content field to reduce data transfer
        ).sort("created_at", -1).skip(skip).limit(limit)

        # Server-side limit of 10 seconds for the query
        cursor.max_time_ms(10000)

        logger.info("Cursor created with 10s timeout, executing...")

        documents_raw = await cursor.to_list(length=limit)
        logger.info(f"查询到原始文档数: {len(documents_raw)}")

        documents = []
        for doc in documents_raw:
            # A stored null metadata/columns must not crash the listing.
            meta = doc.get("metadata") or {}
            created = doc.get("created_at")
            if not created:
                created_text = ""
            elif hasattr(created, "isoformat"):
                created_text = created.isoformat()
            else:
                # Already stored as text (or another non-datetime type).
                created_text = str(created)

            documents.append({
                "doc_id": str(doc["_id"]),
                "filename": meta.get("filename", ""),
                "original_filename": meta.get("original_filename", ""),
                "doc_type": doc.get("doc_type", ""),
                "file_size": meta.get("file_size", 0),
                "created_at": created_text,
                "metadata": {
                    "row_count": meta.get("row_count"),
                    "column_count": meta.get("column_count"),
                    # Only the first 10 column names are exposed in the list view.
                    "columns": (meta.get("columns") or [])[:10]
                }
            })

        logger.info(f"文档列表处理完成: {len(documents)} 个文档")

        return {
            "success": True,
            "documents": documents,
            "total": len(documents)
        }

    except Exception as e:
        err_str = str(e)
        err_lower = err_str.lower()
        # Only genuine timeouts are downgraded to an empty result. Matching
        # the bare substring "time" would also swallow unrelated errors
        # (e.g. anything mentioning "datetime" or "runtime").
        is_timeout = (
            isinstance(e, TimeoutError)
            or "timeout" in type(e).__name__.lower()
            or any(marker in err_lower for marker in ("timeout", "timed out", "time limit"))
        )
        if is_timeout:
            logger.warning(f"文档查询超时,返回空列表: {err_str}")
            return {
                "success": True,
                "documents": [],
                "total": 0,
                "warning": "查询超时,请稍后重试"
            }
        logger.error(f"获取文档列表失败: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"获取文档列表失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/{doc_id}")
async def get_document(doc_id: str):
    """
    Fetch one document with its full content.

    Args:
        doc_id: Document ID.

    Returns:
        A dict with "success" and a "document" entry holding the raw text
        content, the structured data (if any) and the stored metadata.

    Raises:
        HTTPException: 404 when the document does not exist, 500 on any
            other failure.
    """
    try:
        doc = await mongodb.get_document(doc_id)

        if not doc:
            raise HTTPException(status_code=404, detail="文档不存在")

        # A stored null metadata must not crash the response building.
        meta = doc.get("metadata") or {}
        created = doc.get("created_at")
        if not created:
            created_text = ""
        elif hasattr(created, "isoformat"):
            created_text = created.isoformat()
        else:
            # Already stored as text (or another non-datetime type).
            created_text = str(created)

        return {
            "success": True,
            "document": {
                "doc_id": str(doc["_id"]),
                "filename": meta.get("filename", ""),
                "original_filename": meta.get("original_filename", ""),
                "doc_type": doc.get("doc_type", ""),
                "file_size": meta.get("file_size", 0),
                "created_at": created_text,
                "content": doc.get("content", ""),  # raw text content
                "structured_data": doc.get("structured_data"),  # structured data, if any
                "metadata": meta
            }
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"获取文档详情失败: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"获取文档详情失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.delete("/{doc_id}")
async def delete_document(doc_id: str):
    """
    Delete a document.

    Args:
        doc_id: Document ID.

    Returns:
        A dict with "success" and a confirmation "message".

    Raises:
        HTTPException: 404 when nothing was deleted, 500 when the delete
            call itself fails.
    """
    try:
        # Remove from MongoDB
        deleted = await mongodb.delete_document(doc_id)
    except Exception as e:
        # Log with traceback so storage failures are diagnosable server-side.
        logger.error(f"删除文档失败: {doc_id}, error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}") from e

    if not deleted:
        raise HTTPException(status_code=404, detail="文档不存在")

    # TODO: remove related data from MySQL (for Excel documents)
    # TODO: remove related entries from the RAG index

    return {
        "success": True,
        "message": "文档已删除"
    }
|
||||
116
backend/app/api/endpoints/rag.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
RAG 检索 API 接口
|
||||
|
||||
提供向量检索功能
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.services.rag_service import rag_service
|
||||
|
||||
router = APIRouter(prefix="/rag", tags=["RAG检索"])
|
||||
|
||||
|
||||
class SearchRequest(BaseModel):
    """Request body for the semantic search endpoint."""

    # Query text to search for
    query: str

    # Number of results requested; defaults to 5
    top_k: int = 5
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
    """Shape of a single retrieval hit."""

    # Matched text
    content: str

    # Metadata attached to the hit
    metadata: dict

    # Score of the hit
    score: float

    # ID of the document the hit belongs to
    doc_id: str
|
||||
|
||||
|
||||
@router.post("/search")
async def search_rag(
    request: SearchRequest
):
    """
    RAG semantic search.

    Retrieves the document fragments or fields relevant to the query text.

    Args:
        request.query: Query text.
        request.top_k: Number of results to return.

    Returns:
        A dict with "success" and the retrieved "results".

    Raises:
        HTTPException: 500 when retrieval fails.
    """
    try:
        hits = rag_service.retrieve(query=request.query, top_k=request.top_k)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"检索失败: {str(e)}")

    return {"success": True, "results": hits}
|
||||
|
||||
|
||||
@router.get("/status")
async def get_rag_status():
    """
    Report the state of the RAG index.

    Returns:
        A dict with "success", the current "vector_count" and a fixed
        list of collection names.

    Raises:
        HTTPException: 500 when the vector count cannot be read.
    """
    try:
        vector_total = rag_service.get_vector_count()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}")

    status_payload = {
        "success": True,
        "vector_count": vector_total,
        # Hard-coded placeholder list (reserved for later use)
        "collections": ["document_fields", "document_content"],
    }
    return status_payload
|
||||
|
||||
|
||||
@router.post("/rebuild")
async def rebuild_rag_index():
    """
    Rebuild the RAG index.

    Clears the current index, then re-indexes the content of every
    document stored in MongoDB. Only the first 5000 characters of each
    document are indexed; documents without content are skipped.

    Returns:
        A dict with "success" and a "message" stating how many documents
        were indexed.

    Raises:
        HTTPException: 500 when clearing or indexing fails. The index is
            cleared first, so a failure part-way leaves it partially built.
    """
    from app.core.database import mongodb

    try:
        # Drop the existing index
        rag_service.clear()

        # Read only the fields the indexer needs instead of whole documents.
        cursor = mongodb.documents.find(
            {},
            {"content": 1, "doc_type": 1, "metadata.filename": 1}
        )
        count = 0

        async for doc in cursor:
            content = doc.get("content", "")
            if not content:
                continue

            # A stored null metadata must not abort the rebuild half-way.
            meta = doc.get("metadata") or {}
            rag_service.index_document_content(
                doc_id=str(doc["_id"]),
                content=content[:5000],
                metadata={
                    "filename": meta.get("filename"),
                    "doc_type": doc.get("doc_type")
                }
            )
            # NOTE(review): counts indexed documents only; the original
            # indentation was lost in this view - confirm skipped documents
            # were not meant to be counted.
            count += 1

        return {
            "success": True,
            "message": f"已重建索引,共处理 {count} 个文档"
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"重建索引失败: {str(e)}")
|
||||
116
backend/app/api/endpoints/tasks.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
任务管理 API 接口
|
||||
|
||||
提供异步任务状态查询和历史记录
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.core.database import redis_db, mongodb
|
||||
|
||||
router = APIRouter(prefix="/tasks", tags=["任务管理"])
|
||||
|
||||
|
||||
@router.get("/{task_id}")
async def get_task_status(task_id: str):
    """
    Query the status of a task.

    Redis is consulted first; when it has no entry (or the lookup fails)
    the MongoDB task record is used. If neither yields a record, a
    placeholder with status "unknown" is returned.

    Args:
        task_id: Task ID.

    Returns:
        A dict with "task_id", "status", "progress", "message", "result"
        and "error".
    """
    import logging  # local import: this module defines no module-level logger

    log = logging.getLogger(__name__)

    # Prefer Redis. A failing lookup must fall through to MongoDB rather
    # than turn into a 500 response.
    try:
        status = await redis_db.get_task_status(task_id)
    except Exception as redis_err:
        log.warning(f"Redis 任务状态查询失败: {redis_err}")
        status = None

    if status:
        meta = status.get("meta") or {}
        return {
            "task_id": task_id,
            "status": status.get("status", "unknown"),
            "progress": meta.get("progress", 0),
            "message": meta.get("message"),
            "result": meta.get("result"),
            "error": meta.get("error")
        }

    # Fall back to MongoDB
    try:
        mongo_task = await mongodb.get_task(task_id)
    except Exception as mongo_err:
        log.warning(f"MongoDB 任务状态查询失败: {mongo_err}")
        mongo_task = None

    if mongo_task:
        return {
            "task_id": mongo_task.get("task_id"),
            "status": mongo_task.get("status", "unknown"),
            # MongoDB keeps no progress value; derive it from the final status.
            "progress": 100 if mongo_task.get("status") == "success" else 0,
            "message": mongo_task.get("message"),
            "result": mongo_task.get("result"),
            "error": mongo_task.get("error")
        }

    # Task not found, or neither store could be read
    return {
        "task_id": task_id,
        "status": "unknown",
        "progress": 0,
        "message": "无法获取任务状态(Redis和MongoDB均不可用)",
        "result": None,
        "error": None
    }
|
||||
|
||||
|
||||
@router.get("/")
async def list_tasks(limit: int = 50, skip: int = 0):
    """
    List historical tasks.

    Args:
        limit: Maximum number of tasks to return.
        skip: Number of tasks to skip.

    Returns:
        On success a dict with "success", "tasks" and "count". When the
        lookup fails, "success" is False, the list is empty and "error"
        carries the exception text; no HTTP error is raised.
    """
    try:
        task_records = await mongodb.list_tasks(limit=limit, skip=skip)
        payload = {
            "success": True,
            "tasks": task_records,
            "count": len(task_records),
        }
    except Exception as e:
        # Storage unavailable: answer with an empty list instead of failing.
        payload = {
            "success": False,
            "tasks": [],
            "count": 0,
            "error": str(e),
        }
    return payload
|
||||
|
||||
|
||||
@router.delete("/{task_id}")
async def delete_task(task_id: str):
    """
    Delete a task.

    The Redis entry is removed on a best-effort basis; the MongoDB record
    is the one whose deletion result is reported.

    Args:
        task_id: Task ID.

    Returns:
        A dict with "success" and "deleted" (the value returned by the
        MongoDB delete call).

    Raises:
        HTTPException: 500 when the MongoDB delete fails.
    """
    import logging  # local import: this module defines no module-level logger

    # Remove from Redis. A Redis problem must not prevent the MongoDB
    # record from being deleted, so it is logged and ignored.
    try:
        if redis_db._connected and redis_db.client:
            await redis_db.client.delete(f"task:{task_id}")
    except Exception as redis_err:
        logging.getLogger(__name__).warning(f"Redis 删除任务失败: {redis_err}")

    # Remove from MongoDB
    try:
        deleted = await mongodb.delete_task(task_id)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"删除任务失败: {str(e)}") from e

    return {
        "success": True,
        "deleted": deleted
    }
|
||||
722
backend/app/api/endpoints/templates.py
Normal file
@@ -0,0 +1,722 @@
|
||||
"""
|
||||
表格模板 API 接口
|
||||
|
||||
提供模板上传、解析和填写功能
|
||||
"""
|
||||
import io
|
||||
import logging
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, File, HTTPException, Query, UploadFile, BackgroundTasks
|
||||
from fastapi.responses import StreamingResponse
|
||||
import pandas as pd
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.services.template_fill_service import template_fill_service, TemplateField
|
||||
from app.services.file_service import file_service
|
||||
from app.core.database import mongodb
|
||||
from app.core.document_parser import ParserFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/templates", tags=["表格模板"])
|
||||
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
|
||||
async def update_task_status(
    task_id: str,
    status: str,
    progress: int = 0,
    message: str = "",
    result: Optional[dict] = None,
    error: Optional[str] = None
):
    """
    Update a task's status in both Redis and MongoDB.

    Each store is written independently; a failure in one is logged as a
    warning and does not prevent the other write, so this function does
    not propagate storage errors to the caller.

    Args:
        task_id: Task ID.
        status: New status value.
        progress: Progress value stored in the Redis meta.
        message: Human-readable status message.
        result: Optional result payload.
        error: Optional error text.
    """
    from app.core.database import redis_db

    meta = {"progress": progress, "message": message}
    # Compare against None so an explicitly supplied empty result/error is
    # recorded in Redis too, matching what MongoDB receives below.
    if result is not None:
        meta["result"] = result
    if error is not None:
        meta["error"] = error

    try:
        await redis_db.set_task_status(task_id, status, meta)
    except Exception as e:
        logger.warning(f"Redis 任务状态更新失败: {e}")

    try:
        await mongodb.update_task(
            task_id=task_id,
            status=status,
            message=message,
            result=result,
            error=error
        )
    except Exception as e:
        logger.warning(f"MongoDB 任务状态更新失败: {e}")
|
||||
|
||||
|
||||
# ==================== 请求/响应模型 ====================
|
||||
|
||||
class TemplateFieldRequest(BaseModel):
    """One template field as submitted by the client."""

    # Cell the field belongs to
    cell: str

    # Field name
    name: str

    # Field type; "text" unless stated otherwise
    field_type: str = "text"

    # Whether the field is required; defaults to True
    required: bool = True

    # Free-text hint for the field; empty by default
    hint: str = ""
|
||||
|
||||
|
||||
class FillRequest(BaseModel):
    """Request body for a template-fill run."""

    # Template identifier
    template_id: str

    # Fields to fill
    template_fields: List[TemplateFieldRequest]

    # MongoDB document IDs of the source documents
    source_doc_ids: Optional[List[str]] = None

    # File paths of the source documents
    source_file_paths: Optional[List[str]] = None

    # Optional free-text hint from the user
    user_hint: Optional[str] = None

    # Optional task ID, used for task-history tracking
    task_id: Optional[str] = None
|
||||
|
||||
|
||||
class ExportRequest(BaseModel):
    """Request body for exporting filled data."""

    # Template identifier
    template_id: str

    # Filled values keyed by field name
    filled_data: dict

    # Output format: "xlsx" or "docx"
    format: str = "xlsx"
|
||||
|
||||
|
||||
# ==================== 接口实现 ====================
|
||||
|
||||
@router.post("/upload")
async def upload_template(
    file: UploadFile = File(...),
):
    """
    Upload a table template file.

    Supports Excel (.xlsx, .xls) and Word (.docx) templates.

    Args:
        file: The uploaded template file.

    Returns:
        Template information: the saved path as "template_id", the file
        name and type, the extracted field list and the field count.

    Raises:
        HTTPException: 400 for a missing file name or unsupported
            extension, 500 when saving or field extraction fails.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="文件名为空")

    # A name without a dot has no extension; without this guard a file
    # literally named "xlsx" would pass the format check.
    file_ext = file.filename.rsplit('.', 1)[-1].lower() if '.' in file.filename else ''
    if file_ext not in ['xlsx', 'xls', 'docx']:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的模板格式: {file_ext},仅支持 xlsx/xls/docx"
        )

    try:
        # Persist the upload
        content = await file.read()
        saved_path = file_service.save_uploaded_file(
            content,
            file.filename,
            subfolder="templates"
        )

        # Extract the field definitions
        template_fields = await template_fill_service.get_template_fields_from_file(
            saved_path,
            file_ext
        )

        return {
            "success": True,
            "template_id": saved_path,
            "filename": file.filename,
            "file_type": file_ext,
            "fields": [
                {
                    "cell": f.cell,
                    "name": f.name,
                    "field_type": f.field_type,
                    "required": f.required,
                    "hint": f.hint
                }
                for f in template_fields
            ],
            "field_count": len(template_fields)
        }

    except HTTPException:
        # Keep status codes raised downstream instead of rewrapping as 500.
        raise
    except Exception as e:
        logger.error(f"上传模板失败: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}") from e
|
||||
|
||||
|
||||
def _lower_extension(filename: str) -> str:
    """Return the lower-cased extension of filename, or "" when it has no dot."""
    return filename.rsplit('.', 1)[-1].lower() if '.' in filename else ''


def _summarize_tables(tables) -> str:
    """Render the headers and first 3 rows of up to 5 tables as prompt text."""
    if not tables:
        return ""

    summary = "\n【文档中的表格】:\n"
    for idx, table in enumerate(tables[:5]):  # at most 5 tables
        if not isinstance(table, dict):
            continue
        headers = table.get("headers") or []
        rows = table.get("rows") or []
        if headers:
            summary += f"表格{idx+1}表头: {', '.join(str(h) for h in headers)}\n"
        if rows:
            summary += f"表格{idx+1}前3行: "
            for row in rows[:3]:
                if isinstance(row, list):
                    summary += " | ".join(str(c) for c in row) + "; "
                elif isinstance(row, dict):
                    summary += " | ".join(str(row.get(h, "")) for h in headers) + "; "
            summary += "\n"
    return summary


@router.post("/upload-joint")
async def upload_joint_template(
    background_tasks: BackgroundTasks,
    template_file: UploadFile = File(..., description="模板文件"),
    source_files: List[UploadFile] = File(..., description="源文档文件列表"),
):
    """
    Upload a template together with its source documents in one call.

    1. Save the template file.
    2. Save and parse each source document to collect content, titles and
       a table summary, which are passed to field extraction.
    3. Extract the template fields.
    4. Schedule a background task that stores the source documents in
       MongoDB.

    Args:
        background_tasks: FastAPI background task queue.
        template_file: Template file (xlsx/xls/docx).
        source_files: Source documents (docx/xlsx/xls/md/txt). Entries
            without a file name are ignored.

    Returns:
        Template ID (saved path), field list, saved source file paths and
        names, and a "task_id". The task ID is generated even when no
        source file was saved, in which case no task record is created.

    Raises:
        HTTPException: 400 for a missing template name or unsupported
            format, 500 for any other failure.
    """
    if not template_file.filename:
        raise HTTPException(status_code=400, detail="模板文件名为空")

    # Validate the template format
    template_ext = _lower_extension(template_file.filename)
    if template_ext not in ['xlsx', 'xls', 'docx']:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的模板格式: {template_ext},仅支持 xlsx/xls/docx"
        )

    # Validate the source document formats
    valid_exts = ['docx', 'xlsx', 'xls', 'md', 'txt']
    for sf in source_files:
        if sf.filename:
            sf_ext = _lower_extension(sf.filename)
            if sf_ext not in valid_exts:
                raise HTTPException(
                    status_code=400,
                    detail=f"不支持的源文档格式: {sf_ext},仅支持 docx/xlsx/xls/md/txt"
                )

    try:
        # 1. Save the template file
        template_content = await template_file.read()
        template_path = file_service.save_uploaded_file(
            template_content,
            template_file.filename,
            subfolder="templates"
        )

        # 2. Save and parse the source documents
        source_file_info = []
        source_contents = []
        for sf in source_files:
            if not sf.filename:
                continue

            sf_content = await sf.read()
            sf_ext = _lower_extension(sf.filename)
            sf_path = file_service.save_uploaded_file(
                sf_content,
                sf.filename,
                subfolder=sf_ext
            )
            source_file_info.append({
                "path": sf_path,
                "filename": sf.filename,
                "ext": sf_ext
            })

            # Parse the document for field generation. A parse failure only
            # costs this document's summary; the upload itself continues.
            try:
                parser = ParserFactory.get_parser(sf_path)
                parse_result = parser.parse(sf_path)
                if parse_result.success and parse_result.data:
                    data = parse_result.data
                    structured = data.get("structured_data") or {}

                    # Raw content, capped at 5000 characters
                    content = data.get("content", "")[:5000] if data.get("content") else ""

                    # Titles may be top-level or inside structured_data
                    titles = data.get("titles", []) or structured.get("titles", [])
                    titles = titles[:10] if titles else []

                    # Tables may be top-level or inside structured_data
                    tables = data.get("tables", []) or structured.get("tables", [])
                    tables_count = len(tables) if tables else 0

                    tables_summary = _summarize_tables(tables)

                    source_contents.append({
                        "filename": sf.filename,
                        "doc_type": sf_ext,
                        "content": content,
                        "titles": titles,
                        "tables_count": tables_count,
                        "tables_summary": tables_summary
                    })
                    logger.debug(f"[DEBUG] source_contents built: filename={sf.filename}, content_len={len(content)}, titles_count={len(titles)}, tables_count={tables_count}")
                    if tables_summary:
                        logger.debug(f"[DEBUG] tables_summary preview: {tables_summary[:300]}")
            except Exception as e:
                logger.warning(f"解析源文档失败 {sf.filename}: {e}")

        # 3. Extract the template fields, using the source document summaries
        template_fields = await template_fill_service.get_template_fields_from_file(
            template_path,
            template_ext,
            source_contents=source_contents
        )

        # 4. Store the source documents in MongoDB in the background
        task_id = str(uuid.uuid4())
        if source_file_info:
            # Record the task; a failure here is logged and not fatal.
            try:
                await mongodb.insert_task(
                    task_id=task_id,
                    task_type="source_process",
                    status="pending",
                    message=f"开始处理 {len(source_file_info)} 个源文档"
                )
            except Exception as mongo_err:
                logger.warning(f"MongoDB 保存任务记录失败: {mongo_err}")

            background_tasks.add_task(
                process_source_documents,
                task_id=task_id,
                files=source_file_info
            )

        logger.info(f"联合上传完成: 模板={template_file.filename}, 源文档={len(source_file_info)}个")

        return {
            "success": True,
            "template_id": template_path,
            "filename": template_file.filename,
            "file_type": template_ext,
            "fields": [
                {
                    "cell": f.cell,
                    "name": f.name,
                    "field_type": f.field_type,
                    "required": f.required,
                    "hint": f.hint
                }
                for f in template_fields
            ],
            "field_count": len(template_fields),
            "source_file_paths": [f["path"] for f in source_file_info],
            "source_filenames": [f["filename"] for f in source_file_info],
            "task_id": task_id
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"联合上传失败: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"联合上传失败: {str(e)}") from e
|
||||
|
||||
|
||||
async def process_source_documents(task_id: str, files: List[dict]):
    """
    Parse source documents and store them in MongoDB (background task).

    Each file is handled independently: a file that fails to parse or
    store is logged and recorded, and processing continues with the next.
    Task progress is updated after every file.

    Args:
        task_id: ID of the task whose status is updated.
        files: File descriptors with the keys "path", "filename" and "ext".

    The final task result contains "doc_ids" (IDs of the stored documents)
    and "failed_files" (names of files that could not be processed).
    """
    try:
        await update_task_status(
            task_id, status="processing",
            progress=0, message="开始处理源文档"
        )

        doc_ids = []
        failed_files = []
        total = len(files)

        for i, file_info in enumerate(files):
            try:
                parser = ParserFactory.get_parser(file_info["path"])
                result = parser.parse(file_info["path"])

                if result.success:
                    # A successful parse with no data/metadata must not
                    # fail the file.
                    data = result.data or {}
                    doc_id = await mongodb.insert_document(
                        doc_type=file_info["ext"],
                        content=data.get("content", ""),
                        metadata={
                            **(result.metadata or {}),
                            "original_filename": file_info["filename"],
                            "file_path": file_info["path"]
                        },
                        structured_data=data.get("structured_data")
                    )
                    doc_ids.append(doc_id)
                    logger.info(f"源文档处理成功: {file_info['filename']}, doc_id: {doc_id}")
                else:
                    failed_files.append(file_info["filename"])
                    logger.error(f"源文档解析失败: {file_info['filename']}, error: {result.error}")

            except Exception as e:
                failed_files.append(file_info.get("filename"))
                logger.error(f"源文档处理异常: {file_info.get('filename')}, error: {str(e)}")

            progress = int((i + 1) / total * 100)
            await update_task_status(
                task_id, status="processing",
                progress=progress, message=f"已处理 {i+1}/{total}"
            )

        await update_task_status(
            task_id, status="success",
            progress=100, message="源文档处理完成",
            result={"doc_ids": doc_ids, "failed_files": failed_files}
        )
        logger.info(f"所有源文档处理完成: {len(doc_ids)}个")

    except Exception as e:
        logger.error(f"源文档批量处理失败: {str(e)}")
        await update_task_status(
            task_id, status="failure",
            progress=0, message="源文档处理失败",
            error=str(e)
        )
|
||||
|
||||
|
||||
@router.post("/fields")
async def extract_template_fields(
    template_id: str = Query(..., description="模板ID/文件路径"),
    file_type: str = Query("xlsx", description="文件类型")
):
    """
    Extract the field definitions of an already uploaded template.

    Args:
        template_id: Template ID (the template's file path).
        file_type: File type of the template.

    Returns:
        A dict with "success" and the list of "fields".

    Raises:
        HTTPException: 500 when extraction fails.
    """
    # NOTE(review): template_id is a client-supplied path handed straight to
    # the service; confirm the service restricts it to the upload directory.
    try:
        extracted = await template_fill_service.get_template_fields_from_file(
            template_id,
            file_type
        )

        field_dicts = []
        for field in extracted:
            field_dicts.append({
                "cell": field.cell,
                "name": field.name,
                "field_type": field.field_type,
                "required": field.required,
                "hint": field.hint,
            })

        return {"success": True, "fields": field_dicts}

    except Exception as e:
        logger.error(f"提取字段失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"提取失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/fill")
async def fill_template(
    request: FillRequest,
):
    """
    Run a template fill.

    Looks up information in the source documents according to the given
    field definitions and fills the template. Task progress is recorded
    under the request's task ID, or a newly generated one.

    Args:
        request: Fill request.

    Returns:
        The fill result from the service, plus "task_id".

    Raises:
        HTTPException: Any HTTPException raised while filling is passed
            through unchanged; every other failure becomes a 500.
    """
    # Use the caller's task_id when given, otherwise generate one
    task_id = request.task_id or str(uuid.uuid4())

    try:
        # Create the task record; a failure is logged and not fatal
        try:
            await mongodb.insert_task(
                task_id=task_id,
                task_type="template_fill",
                status="processing",
                message=f"开始填表任务: {len(request.template_fields)} 个字段"
            )
        except Exception as mongo_err:
            logger.warning(f"MongoDB 创建任务记录失败: {mongo_err}")

        # Progress: started
        await update_task_status(
            task_id, "processing",
            progress=0, message="开始处理..."
        )

        # Convert request fields to service fields
        fields = [
            TemplateField(
                cell=f.cell,
                name=f.name,
                field_type=f.field_type,
                required=f.required,
                hint=f.hint
            )
            for f in request.template_fields
        ]

        # Derive the file type from template_id: only a .docx suffix selects
        # "docx"; .xls/.xlsx and anything else use the "xlsx" default.
        template_file_type = "xlsx"
        if request.template_id and request.template_id.split('.')[-1].lower() == "docx":
            template_file_type = "docx"

        # Progress: about to fill
        await update_task_status(
            task_id, "processing",
            progress=10, message=f"准备填写 {len(fields)} 个字段..."
        )

        # Run the fill
        result = await template_fill_service.fill_template(
            template_fields=fields,
            source_doc_ids=request.source_doc_ids,
            source_file_paths=request.source_file_paths,
            user_hint=request.user_hint,
            template_id=request.template_id,
            template_file_type=template_file_type,
            task_id=task_id
        )

        # Mark as successful
        await update_task_status(
            task_id, "success",
            progress=100, message="填表完成",
            result={
                "field_count": len(fields),
                "max_rows": result.get("max_rows", 0)
            }
        )

        return {**result, "task_id": task_id}

    except Exception as e:
        # Mark as failed
        await update_task_status(
            task_id, "failure",
            progress=0, message="填表失败",
            error=str(e)
        )
        logger.error(f"填写表格失败: {str(e)}", exc_info=True)
        if isinstance(e, HTTPException):
            # Keep the original status code instead of rewrapping as 500.
            raise
        raise HTTPException(status_code=500, detail=f"填写失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/export")
async def export_filled_template(
    request: ExportRequest,
):
    """
    Export the filled data as a file.

    Supports Excel (.xlsx) and Word (.docx) output.

    Args:
        request: Export request.

    Returns:
        A streaming file response.

    Raises:
        HTTPException: 400 for an unsupported format, 500 when the export
            itself fails.
    """
    try:
        # Map each supported format to the coroutine that produces it.
        exporters = {
            "xlsx": _export_to_excel,
            "docx": _export_to_word,
        }
        exporter = exporters.get(request.format)
        if exporter is None:
            raise HTTPException(
                status_code=400,
                detail=f"不支持的导出格式: {request.format},仅支持 xlsx/docx"
            )
        return await exporter(request.filled_data, request.template_id)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"导出失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"导出失败: {str(e)}")
|
||||
|
||||
|
||||
async def _export_to_excel(filled_data: dict, template_id: str) -> StreamingResponse:
    """
    Export filled data as an Excel workbook (multi-row aware).

    Each key of filled_data becomes a column. List values are spread over
    consecutive rows and padded with empty strings; a non-list value is
    written to the first row only.

    Args:
        filled_data: Mapping of column name to a value or list of values.
        template_id: Template ID; not used when building the workbook.

    Returns:
        A streaming response carrying "filled_template.xlsx".
    """
    import re  # local import: only needed by this exporter

    # Control characters that the Word exporter also strips; assumed to be
    # rejected by the Excel writer as well, which would abort the export.
    illegal_chars = re.compile(r'[\x00-\x08\x0b\x0c\x0e-\x1f]')

    def sanitize(value):
        """Strip illegal control characters from string cell values."""
        return illegal_chars.sub('', value) if isinstance(value, str) else value

    logger.info(f"导出填表数据: {len(filled_data)} 个字段")

    # Longest list value determines the number of rows (at least 1)
    max_rows = 1
    for k, v in filled_data.items():
        if isinstance(v, list) and len(v) > max_rows:
            max_rows = len(v)
        logger.info(f" {k}: {type(v).__name__} = {str(v)[:80]}")

    logger.info(f"最大行数: {max_rows}")

    # Build the rows
    rows_data = []
    for row_idx in range(max_rows):
        row = {}
        for col_name, values in filled_data.items():
            if isinstance(values, list):
                # Value for this row, or blank when the list is shorter
                row[col_name] = sanitize(values[row_idx]) if row_idx < len(values) else ""
            else:
                # Scalars go into the first row only
                row[col_name] = sanitize(values) if row_idx == 0 else ""
        rows_data.append(row)

    df = pd.DataFrame(rows_data)

    # Keep the column order of filled_data
    if not df.empty:
        df = df[list(filled_data.keys())]

    logger.info(f"DataFrame 形状: {df.shape}")
    logger.info(f"DataFrame 列: {list(df.columns)}")

    output = io.BytesIO()
    with pd.ExcelWriter(output, engine='openpyxl') as writer:
        df.to_excel(writer, index=False, sheet_name='填写结果')

    # Rewind so the response streams the workbook from the start
    output.seek(0)

    filename = "filled_template.xlsx"

    return StreamingResponse(
        output,
        media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        headers={"Content-Disposition": f"attachment; filename={filename}"}
    )
|
||||
|
||||
|
||||
async def _export_to_word(filled_data: dict, template_id: str) -> StreamingResponse:
    """
    Export filled data as a Word document.

    The document holds a heading, the template file name and export time,
    and a three-column table (field name, value, status) with one row per
    field. List values are joined with ", ".

    Args:
        filled_data: Mapping of field name to a value or list of values.
        template_id: Template ID/path; only its last path component is
            shown in the document.

    Returns:
        A streaming response carrying "filled_template.docx".
    """
    import re
    import tempfile
    import os
    from datetime import datetime
    from docx import Document

    def clean_text(text: str) -> str:
        """Remove control characters that can make the Word file invalid."""
        if not text:
            return ""
        text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text)
        return text.strip()

    def is_blank(value) -> bool:
        """True for None and "" only, so values such as 0 still count as filled."""
        return value is None or value == ""

    # Bound before the try block so the cleanup below can never hit an
    # unbound name when creating the temporary file fails.
    tmp_path = None
    try:
        # Save to a temporary file first, then read it back into memory
        with tempfile.NamedTemporaryFile(delete=False, suffix='.docx') as tmp_file:
            tmp_path = tmp_file.name

        doc = Document()
        doc.add_heading('填写结果', level=1)

        info_para = doc.add_paragraph()
        template_filename = template_id.split('/')[-1].split('\\')[-1] if template_id else '未知'
        info_para.add_run(f"模板文件: {clean_text(template_filename)}\n").bold = True
        info_para.add_run(f"导出时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        doc.add_paragraph()

        table = doc.add_table(rows=1, cols=3)
        table.style = 'Table Grid'

        header_cells = table.rows[0].cells
        header_cells[0].text = '字段名'
        header_cells[1].text = '填写值'
        header_cells[2].text = '状态'

        for field_name, field_value in filled_data.items():
            row_cells = table.add_row().cells
            row_cells[0].text = clean_text(str(field_name))

            if isinstance(field_value, list):
                cleaned = (clean_text(str(v)) for v in field_value if not is_blank(v))
                display_value = ', '.join(c for c in cleaned if c)
            else:
                display_value = '' if is_blank(field_value) else clean_text(str(field_value))

            row_cells[1].text = display_value
            row_cells[2].text = '已填写' if display_value else '为空'

        # Write the document to the temporary file
        doc.save(tmp_path)

        # Read it back into memory
        with open(tmp_path, 'rb') as f:
            file_content = f.read()

    finally:
        # Remove the temporary file; failure to delete it is not fatal
        if tmp_path and os.path.exists(tmp_path):
            try:
                os.unlink(tmp_path)
            except OSError:
                pass

    output = io.BytesIO(file_content)
    filename = "filled_template.docx"

    return StreamingResponse(
        output,
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        headers={"Content-Disposition": f"attachment; filename*=UTF-8''{filename}"}
    )
|
||||
|
||||
|
||||
@router.post("/export/excel")
|
||||
async def export_to_excel(
    filled_data: dict,
    template_id: str = Query(..., description="模板ID")
):
    """
    Export the filled-in data as an Excel file.

    Thin HTTP entry point: all work is delegated to ``_export_to_excel``.

    Args:
        filled_data: Mapping of field names to the filled-in values.
        template_id: Identifier of the template the data belongs to.

    Returns:
        Whatever ``_export_to_excel`` returns (an Excel file stream).
    """
    excel_response = await _export_to_excel(filled_data, template_id)
    return excel_response
|
||||
|
||||
|
||||
@router.post("/export/word")
|
||||
async def export_to_word(
    filled_data: dict,
    template_id: str = Query(..., description="模板ID")
):
    """
    Export the filled-in data as a Word document.

    Thin HTTP entry point: all work is delegated to ``_export_to_word``.

    Args:
        filled_data: Mapping of field names to the filled-in values.
        template_id: Identifier of the template the data belongs to.

    Returns:
        Whatever ``_export_to_word`` returns (a Word file stream).
    """
    word_response = await _export_to_word(filled_data, template_id)
    return word_response
|
||||
275
backend/app/api/endpoints/upload.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""
|
||||
文件上传 API 接口
|
||||
"""
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from typing import Optional
|
||||
import logging
|
||||
import os
|
||||
import pandas as pd
|
||||
import io
|
||||
|
||||
from app.services.file_service import file_service
|
||||
from app.core.document_parser import XlsxParser
|
||||
from app.services.table_rag_service import table_rag_service
|
||||
from app.core.database import mongodb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/upload", tags=["文件上传"])
|
||||
|
||||
# 初始化解析器
|
||||
excel_parser = XlsxParser()
|
||||
|
||||
|
||||
@router.post("/excel")
|
||||
async def upload_excel(
    file: UploadFile = File(...),
    parse_all_sheets: bool = Query(False, description="是否解析所有工作表"),
    sheet_name: Optional[str] = Query(None, description="指定解析的工作表名称"),
    header_row: int = Query(0, description="表头所在的行索引")
):
    """
    Upload an Excel workbook, parse it, and persist the outcome.

    The file is saved through ``file_service``, parsed with ``excel_parser``,
    indexed through ``table_rag_service.build_table_rag_index`` and finally
    recorded in MongoDB.  Failures of the two storage steps are logged and do
    not fail the request.

    Args:
        file: The uploaded Excel file (.xlsx or .xls).
        parse_all_sheets: Parse every worksheet instead of a single one.
        sheet_name: Name of the worksheet to parse (first sheet if omitted).
        header_row: Row index that holds the column headers.

    Returns:
        dict: ``result.to_dict()`` of the parse result.

    Raises:
        HTTPException: 400 for a missing file name or unsupported extension,
            500 when saving or parsing fails.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="文件名为空")

    file_ext = file.filename.split('.')[-1].lower()
    if file_ext not in ['xlsx', 'xls']:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的文件类型: {file_ext},仅支持 .xlsx 和 .xls"
        )

    try:
        # Keep the raw upload in its own name: its length is the real file
        # size and must not be confused with the preview text built later.
        file_bytes = await file.read()

        saved_path = file_service.save_uploaded_file(
            file_bytes,
            file.filename,
            subfolder="excel"
        )
        logger.info(f"文件已保存: {saved_path}")

        # Parse either every sheet, the requested sheet, or the default one.
        if parse_all_sheets:
            result = excel_parser.parse_all_sheets(saved_path)
        elif sheet_name:
            result = excel_parser.parse(saved_path, sheet_name=sheet_name, header_row=header_row)
        else:
            result = excel_parser.parse(saved_path, header_row=header_row)

        if result.metadata:
            result.metadata['saved_path'] = saved_path
            result.metadata['original_filename'] = file.filename

        # Structured storage / index building (best effort).
        try:
            store_result = await table_rag_service.build_table_rag_index(
                file_path=saved_path,
                filename=file.filename,
                sheet_name=sheet_name if sheet_name else None,
                header_row=header_row
            )
            if store_result.get("success"):
                result.metadata['mysql_table'] = store_result.get('table_name')
                result.metadata['row_count'] = store_result.get('row_count')
                logger.info(f"Excel已存储到MySQL: {file.filename}, 表: {store_result.get('table_name')}")
            else:
                logger.warning(f"Excel存储到MySQL失败: {store_result.get('error')}")
        except Exception as e:
            logger.error(f"Excel存储到MySQL异常: {str(e)}", exc_info=True)

        # MongoDB record used by the document list (best effort).
        try:
            preview_text = _build_excel_preview_text(result.data, result.metadata)
            meta = result.metadata or {}
            doc_metadata = {
                "filename": os.path.basename(saved_path),
                "original_filename": file.filename,
                "saved_path": saved_path,
                # Size of the uploaded file in bytes (previously this was the
                # length of the preview text because the variable was reused).
                "file_size": len(file_bytes),
                "row_count": meta.get('row_count', 0),
                "column_count": meta.get('column_count', 0),
                "columns": meta.get('columns', []),
                "mysql_table": meta.get('mysql_table'),
                "sheet_count": meta.get('sheet_count', 1),
            }
            await mongodb.insert_document(
                doc_type="xlsx",
                content=preview_text,
                metadata=doc_metadata,
                structured_data=result.data if result.data else None
            )
            logger.info(f"Excel文档已存储到MongoDB: {file.filename}, content长度: {len(preview_text)}")
        except Exception as e:
            logger.error(f"Excel存储到MongoDB异常: {str(e)}", exc_info=True)

        return result.to_dict()

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"解析 Excel 文件时出错: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"解析失败: {str(e)}")


def _format_sheet_preview(title, sheet, max_rows=100):
    """Render one sheet (``{'columns': [...], 'rows': [...]}``) as preview text."""
    columns = sheet['columns']
    rows = sheet['rows']
    lines = [f"Sheet: {title}", ", ".join(str(h) for h in columns)]
    for row in rows[:max_rows]:
        if isinstance(row, dict):
            lines.append(", ".join(str(row.get(col, "")) for col in columns))
        elif isinstance(row, list):
            lines.append(", ".join(str(cell) for cell in row))
        # Rows of any other type are skipped, as before.
    lines.append(f"... (共 {len(rows)} 行)")
    return "\n".join(lines) + "\n\n"


def _build_excel_preview_text(data, metadata):
    """Build the display text for single-sheet or multi-sheet parse data."""
    if not data or not isinstance(data, dict):
        return ""
    # Single-sheet layout: {columns, rows, ...}
    if 'columns' in data and 'rows' in data:
        title = metadata.get('current_sheet', 'Sheet1') if metadata else 'Sheet1'
        return _format_sheet_preview(title, data)
    # Multi-sheet layout: {sheets: {sheet_name: {columns, rows}}}
    if 'sheets' in data:
        return "".join(
            _format_sheet_preview(name, sheet)
            for name, sheet in data['sheets'].items()
            if isinstance(sheet, dict) and 'columns' in sheet and 'rows' in sheet
        )
    return ""
|
||||
|
||||
|
||||
@router.get("/excel/preview/{file_path:path}")
|
||||
async def get_excel_preview(
    file_path: str,
    sheet_name: Optional[str] = Query(None, description="工作表名称"),
    max_rows: int = Query(10, description="最多返回的行数", ge=1, le=100)
):
    """
    Return a preview of an Excel file.

    Args:
        file_path: Path of the file to preview (taken from the URL).
        sheet_name: Worksheet to preview; index 0 is used when omitted.
        max_rows: Maximum number of rows to return (1-100).

    Returns:
        dict: ``to_dict()`` of the parser's preview result.

    Raises:
        HTTPException: 500 on any failure while building the preview.
    """
    # NOTE(review): file_path is passed to the parser unvalidated here;
    # confirm that path restriction happens elsewhere.
    try:
        target_sheet = sheet_name or 0
        preview = excel_parser.get_sheet_preview(
            file_path,
            sheet_name=target_sheet,
            max_rows=max_rows
        )
        return preview.to_dict()
    except Exception as e:
        logger.error(f"获取预览数据时出错: {str(e)}")
        raise HTTPException(status_code=500, detail=f"获取预览失败: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/file")
|
||||
async def delete_uploaded_file(file_path: str = Query(..., description="要删除的文件路径")):
    """
    Delete a previously uploaded file.

    Args:
        file_path: Path of the file to delete.

    Returns:
        dict: ``{"success": bool, "message": str}`` describing the outcome.

    Raises:
        HTTPException: 500 when ``file_service.delete_file`` raises.
    """
    try:
        if file_service.delete_file(file_path):
            return {"success": True, "message": "文件删除成功"}
        return {"success": False, "message": "文件不存在或删除失败"}
    except Exception as e:
        logger.error(f"删除文件时出错: {str(e)}")
        raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/excel/export/{file_path:path}")
|
||||
async def export_excel(
    file_path: str,
    sheet_name: Optional[str] = Query(None, description="工作表名称"),
    columns: Optional[str] = Query(None, description="要导出的列,逗号分隔")
):
    """
    Export an Excel file, optionally restricted to one sheet and some columns.

    Args:
        file_path: Path of the source workbook (taken from the URL).
        sheet_name: Worksheet to export; the default sheet when omitted.
        columns: Comma-separated column names to keep.  Unknown names are
            ignored; if none of them exist, all columns are exported.

    Returns:
        StreamingResponse: The generated .xlsx file as an attachment.

    Raises:
        HTTPException: 404 if the source file does not exist, 500 otherwise.
    """
    # Local import so the block does not depend on a new file-level import.
    from urllib.parse import quote

    # NOTE(review): file_path is read from disk unvalidated here; confirm
    # that path restriction happens elsewhere.
    try:
        if sheet_name:
            df = pd.read_excel(file_path, sheet_name=sheet_name)
        else:
            df = pd.read_excel(file_path)

        column_list = []
        if columns:
            column_list = [col.strip() for col in columns.split(',')]
            # Silently drop requested columns that are not in the sheet.
            available_columns = [col for col in column_list if col in df.columns]
            if available_columns:
                df = df[available_columns]

        output = io.BytesIO()
        with pd.ExcelWriter(output, engine='openpyxl') as writer:
            df.to_excel(writer, index=False, sheet_name=sheet_name or 'Sheet1')
        output.seek(0)

        if columns:
            export_name = f"export_{sheet_name or 'data'}_{len(column_list)}_cols.xlsx"
        else:
            # The payload is always written as xlsx, so force that extension
            # even when the source was a legacy .xls file.
            stem = os.path.splitext(os.path.basename(file_path))[0]
            export_name = f"export_{stem}.xlsx"

        # Header values must be latin-1 encodable; a raw non-ASCII file or
        # sheet name would make building the response fail.  Send an ASCII
        # fallback plus the RFC 5987 percent-encoded UTF-8 name.
        ascii_name = export_name.encode("ascii", "ignore").decode("ascii")
        ascii_name = ascii_name.replace('"', '_').replace('\\', '_') or "export.xlsx"
        disposition = (
            f'attachment; filename="{ascii_name}"; '
            f"filename*=UTF-8''{quote(export_name)}"
        )

        return StreamingResponse(
            output,
            media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            headers={"Content-Disposition": disposition}
        )

    except FileNotFoundError:
        logger.error(f"文件不存在: {file_path}")
        raise HTTPException(status_code=404, detail="文件不存在")
    except Exception as e:
        logger.error(f"导出 Excel 文件时出错: {str(e)}")
        raise HTTPException(status_code=500, detail=f"导出失败: {str(e)}")
|
||||
90
backend/app/api/endpoints/visualization.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
可视化 API 接口 - 生成统计图表
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Body
|
||||
from typing import Dict, Any
|
||||
import logging
|
||||
|
||||
from app.services.visualization_service import visualization_service
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/visualization", tags=["数据可视化"])
|
||||
|
||||
|
||||
class StatisticsRequest(BaseModel):
    """Request body for the statistics / chart generation endpoint."""

    # Parsed Excel payload to analyse (arbitrary JSON object).
    excel_data: Dict[str, Any]
    # Kind of analysis to run; defaults to plain statistics.
    analysis_type: str = "statistics"
|
||||
|
||||
|
||||
@router.post("/statistics")
|
||||
async def generate_statistics(request: StatisticsRequest):
    """
    Produce statistics and chart data for the submitted Excel payload.

    Args:
        request: Body carrying ``excel_data`` and ``analysis_type``.

    Returns:
        dict: The result of ``visualization_service.analyze_and_visualize``.

    Raises:
        HTTPException: 400 when no data is supplied; 500 when the service
            reports failure or raises.
    """
    payload = request.excel_data
    mode = request.analysis_type

    if not payload:
        raise HTTPException(status_code=400, detail="未提供 Excel 数据")

    try:
        outcome = visualization_service.analyze_and_visualize(payload, mode)

        if outcome.get("success"):
            logger.info("统计图表生成成功")
            return outcome

        # Service-reported failure: surface its error message as a 500.
        raise HTTPException(status_code=500, detail=outcome.get("error", "分析失败"))

    except HTTPException:
        # Propagate our own HTTP errors untouched.
        raise
    except Exception as e:
        logger.error(f"统计图表生成失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"图表生成失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/chart-types")
|
||||
async def get_chart_types():
    """
    List the chart types this API can produce.

    Returns:
        dict: ``{"chart_types": [...]}`` where each entry has ``value``,
        ``label`` and ``description`` keys.
    """
    # (value, label, description) triples, in display order.
    catalogue = (
        ("histogram", "直方图", "显示数值型列的分布情况"),
        ("bar_chart", "条形图", "显示分类列的频次分布"),
        ("box_plot", "箱线图", "显示数值列的四分位数和异常值"),
        ("correlation_heatmap", "相关性热力图", "显示数值列之间的相关性"),
    )
    return {
        "chart_types": [
            {"value": value, "label": label, "description": description}
            for value, label, description in catalogue
        ]
    }
|
||||
@@ -6,26 +6,67 @@ class Settings(BaseSettings):
|
||||
APP_NAME: str = "FilesReadSystem"
|
||||
DEBUG: bool = True
|
||||
API_V1_STR: str = "/api/v1"
|
||||
|
||||
# 数据库
|
||||
MONGODB_URL: str
|
||||
MONGODB_DB_NAME: str
|
||||
REDIS_URL: str
|
||||
|
||||
# AI 相关
|
||||
LLM_API_KEY: str
|
||||
LLM_BASE_URL: str
|
||||
LLM_MODEL_NAME: str
|
||||
|
||||
# 文件路径
|
||||
|
||||
# ==================== 数据库配置 ====================
|
||||
|
||||
# MongoDB 配置 (非结构化数据存储)
|
||||
MONGODB_URL: str = "mongodb://localhost:27017"
|
||||
MONGODB_DB_NAME: str = "document_system"
|
||||
|
||||
# MySQL 配置 (结构化数据存储)
|
||||
MYSQL_HOST: str = "localhost"
|
||||
MYSQL_PORT: int = 3306
|
||||
MYSQL_USER: str = "root"
|
||||
MYSQL_PASSWORD: str = ""
|
||||
MYSQL_DATABASE: str = "document_system"
|
||||
MYSQL_CHARSET: str = "utf8mb4"
|
||||
|
||||
# Redis 配置 (缓存/任务队列)
|
||||
REDIS_URL: str = "redis://localhost:6379/0"
|
||||
|
||||
# ==================== AI 相关配置 ====================
|
||||
LLM_API_KEY: str = ""
|
||||
LLM_BASE_URL: str = "https://api.minimax.chat"
|
||||
LLM_MODEL_NAME: str = "MiniMax-Text-01"
|
||||
|
||||
# ==================== RAG/Embedding 配置 ====================
|
||||
EMBEDDING_MODEL: str = "all-MiniLM-L6-v2"
|
||||
|
||||
# ==================== Supabase 配置 ====================
|
||||
SUPABASE_URL: str = ""
|
||||
SUPABASE_ANON_KEY: str = ""
|
||||
SUPABASE_SERVICE_KEY: str = ""
|
||||
|
||||
# ==================== 文件路径配置 ====================
|
||||
BASE_DIR: Path = Path(__file__).resolve().parent.parent.parent
|
||||
UPLOAD_DIR: str = "data/uploads"
|
||||
|
||||
|
||||
# ==================== RAG/向量数据库配置 ====================
|
||||
FAISS_INDEX_DIR: str = "data/faiss"
|
||||
|
||||
# 允许 Pydantic 从 .env 文件读取
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=Path(__file__).parent.parent / ".env",
|
||||
env_file=Path(__file__).parent.parent / ".env",
|
||||
env_file_encoding='utf-8',
|
||||
extra='ignore'
|
||||
)
|
||||
|
||||
@property
def mysql_url(self) -> str:
    """Build the synchronous (PyMySQL) MySQL connection URL.

    The user name and password are percent-encoded so that characters such
    as ``@``, ``:`` or ``/`` in the credentials cannot corrupt the URL.
    Credentials without special characters yield the same URL as before.
    """
    from urllib.parse import quote

    user = quote(self.MYSQL_USER, safe="")
    password = quote(self.MYSQL_PASSWORD, safe="")
    return (
        f"mysql+pymysql://{user}:{password}"
        f"@{self.MYSQL_HOST}:{self.MYSQL_PORT}/{self.MYSQL_DATABASE}"
        f"?charset={self.MYSQL_CHARSET}"
    )
|
||||
|
||||
@property
def async_mysql_url(self) -> str:
    """Build the asynchronous (aiomysql) MySQL connection URL.

    The user name and password are percent-encoded so that characters such
    as ``@``, ``:`` or ``/`` in the credentials cannot corrupt the URL.
    Credentials without special characters yield the same URL as before.
    """
    from urllib.parse import quote

    user = quote(self.MYSQL_USER, safe="")
    password = quote(self.MYSQL_PASSWORD, safe="")
    return (
        f"mysql+aiomysql://{user}:{password}"
        f"@{self.MYSQL_HOST}:{self.MYSQL_PORT}/{self.MYSQL_DATABASE}"
        f"?charset={self.MYSQL_CHARSET}"
    )
|
||||
|
||||
settings = Settings()
|
||||
18
backend/app/core/database/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
数据库连接管理模块
|
||||
|
||||
提供 MySQL、MongoDB、Redis 的连接管理
|
||||
"""
|
||||
from app.core.database.mysql import MySQLDB, mysql_db, Base
|
||||
from app.core.database.mongodb import MongoDB, mongodb
|
||||
from app.core.database.redis_db import RedisDB, redis_db
|
||||
|
||||
__all__ = [
|
||||
"MySQLDB",
|
||||
"mysql_db",
|
||||
"MongoDB",
|
||||
"mongodb",
|
||||
"RedisDB",
|
||||
"redis_db",
|
||||
"Base",
|
||||
]
|
||||
375
backend/app/core/database/mongodb.py
Normal file
@@ -0,0 +1,375 @@
|
||||
"""
|
||||
MongoDB 数据库连接管理模块
|
||||
|
||||
提供非结构化数据的存储和查询功能
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MongoDB:
    """Async MongoDB access layer (documents, RAG index entries, tasks)."""

    def __init__(self):
        # Both stay None until connect() is called.
        self.client: Optional[AsyncIOMotorClient] = None
        self.db: Optional[AsyncIOMotorDatabase] = None

    async def connect(self):
        """Open the MongoDB connection and verify it with a ping.

        Raises:
            Exception: Whatever the driver raises; it is logged and re-raised.
        """
        try:
            self.client = AsyncIOMotorClient(
                settings.MONGODB_URL,
                serverSelectionTimeoutMS=30000,  # 30 s, tolerant of remote servers
                connectTimeoutMS=30000,          # connection timeout
                socketTimeoutMS=60000,           # socket timeout
            )
            self.db = self.client[settings.MONGODB_DB_NAME]
            # Verify the connection.
            await self.client.admin.command('ping')
            logger.info(f"MongoDB 连接成功: {settings.MONGODB_DB_NAME}")
        except Exception as e:
            logger.error(f"MongoDB 连接失败: {e}")
            raise

    async def close(self):
        """Close the MongoDB connection if one was opened."""
        if self.client:
            self.client.close()
            logger.info("MongoDB 连接已关闭")

    @property
    def documents(self):
        """The ``documents`` collection (raw documents and parse results)."""
        return self.db["documents"]

    @property
    def embeddings(self):
        """The ``embeddings`` collection (text embedding vectors)."""
        return self.db["embeddings"]

    @property
    def rag_index(self):
        """The ``rag_index`` collection (semantic field index)."""
        return self.db["rag_index"]

    @property
    def tasks(self):
        """The ``tasks`` collection (task history)."""
        return self.db["tasks"]

    # ==================== Document operations ====================

    async def insert_document(
        self,
        doc_type: str,
        content: str,
        metadata: Dict[str, Any],
        structured_data: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Insert a document.

        Args:
            doc_type: Document type (docx/xlsx/md/txt).
            content: Raw text content.
            metadata: Metadata dictionary.
            structured_data: Structured data such as tables.

        Returns:
            The id of the inserted document, as a string.
        """
        # One timestamp so created_at and updated_at start out identical.
        now = datetime.utcnow()
        document = {
            "doc_type": doc_type,
            "content": content,
            "metadata": metadata,
            "structured_data": structured_data,
            "created_at": now,
            "updated_at": now,
        }
        result = await self.documents.insert_one(document)
        doc_id = str(result.inserted_id)
        filename = metadata.get("original_filename", "unknown")
        logger.info(f"✓ 文档已存入MongoDB: [{doc_type}] {filename} | ID: {doc_id}")
        return doc_id

    async def get_document(self, doc_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a document by id; returns None if missing or the id is malformed."""
        from bson import ObjectId

        # A malformed id cannot match anything; report "not found" instead of
        # letting ObjectId() raise.
        if not ObjectId.is_valid(doc_id):
            return None
        doc = await self.documents.find_one({"_id": ObjectId(doc_id)})
        if doc:
            doc["_id"] = str(doc["_id"])
        return doc

    async def search_documents(
        self,
        query: str,
        doc_type: Optional[str] = None,
        limit: int = 10,
    ) -> List[Dict[str, Any]]:
        """
        Search documents whose content contains a keyword.

        Args:
            query: Search keyword, matched literally (regex metacharacters
                are escaped).
            doc_type: Optional document type filter.
            limit: Maximum number of documents to return.

        Returns:
            List of matching documents with ``_id`` converted to a string.
        """
        import re

        # Escape the keyword: a raw "(" or "*" would otherwise be an invalid
        # or unintended regular expression.
        filter_query = {"content": {"$regex": re.escape(query)}}
        if doc_type:
            filter_query["doc_type"] = doc_type

        cursor = self.documents.find(filter_query).limit(limit)
        documents = []
        async for doc in cursor:
            doc["_id"] = str(doc["_id"])
            documents.append(doc)
        return documents

    async def delete_document(self, doc_id: str) -> bool:
        """Delete a document by id; returns False if nothing was deleted or the id is malformed."""
        from bson import ObjectId

        if not ObjectId.is_valid(doc_id):
            return False
        result = await self.documents.delete_one({"_id": ObjectId(doc_id)})
        return result.deleted_count > 0

    # ==================== RAG index operations ====================

    async def insert_rag_entry(
        self,
        table_name: str,
        field_name: str,
        field_description: str,
        embedding: List[float],
        metadata: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Insert a RAG index entry.

        Args:
            table_name: Table name.
            field_name: Field name.
            field_description: Field description.
            embedding: Embedding vector.
            metadata: Additional metadata (stored as ``{}`` when omitted).

        Returns:
            The id of the inserted entry, as a string.
        """
        entry = {
            "table_name": table_name,
            "field_name": field_name,
            "field_description": field_description,
            "embedding": embedding,
            "metadata": metadata or {},
            "created_at": datetime.utcnow(),
        }
        result = await self.rag_index.insert_one(entry)
        return str(result.inserted_id)

    async def search_rag(
        self,
        query_embedding: List[float],
        top_k: int = 5,
        table_name: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """
        Search the RAG index by vector distance.

        Every candidate entry gets a ``distance`` field holding the sum of
        squared element-wise differences to ``query_embedding`` (squared
        Euclidean distance); results are sorted ascending by it.

        Args:
            query_embedding: Query vector.
            top_k: Number of entries to return.
            table_name: Optional table name filter.

        Returns:
            The closest index entries with ``_id`` converted to a string.
        """
        # Per-element squared difference, indexed by $$this over 0..len-1.
        squared_diff = {
            "$pow": [
                {
                    "$subtract": [
                        {"$arrayElemAt": ["$embedding", "$$this"]},
                        {"$arrayElemAt": [query_embedding, "$$this"]},
                    ]
                },
                2,
            ]
        }
        pipeline = [
            {
                "$addFields": {
                    "distance": {
                        "$reduce": {
                            "input": {"$range": [0, {"$size": "$embedding"}]},
                            "initialValue": 0,
                            "in": {"$add": ["$$value", squared_diff]},
                        }
                    }
                }
            },
            {"$sort": {"distance": 1}},
            {"$limit": top_k},
        ]

        # Filter first so distances are only computed for the wanted table.
        if table_name:
            pipeline.insert(0, {"$match": {"table_name": table_name}})

        results = []
        async for doc in self.rag_index.aggregate(pipeline):
            doc["_id"] = str(doc["_id"])
            results.append(doc)
        return results

    # ==================== Collection management ====================

    async def create_indexes(self):
        """Create the indexes used by document, RAG and task queries."""
        # documents collection
        await self.documents.create_index("doc_type")
        await self.documents.create_index("created_at")
        await self.documents.create_index([("content", "text")])

        # rag_index collection
        await self.rag_index.create_index("table_name")
        await self.rag_index.create_index("field_name")

        # tasks collection
        await self.tasks.create_index("task_id", unique=True)
        await self.tasks.create_index("created_at")

        logger.info("MongoDB 索引创建完成")

    # ==================== Task history operations ====================

    async def insert_task(
        self,
        task_id: str,
        task_type: str,
        status: str = "pending",
        message: str = "",
        result: Optional[Dict[str, Any]] = None,
        error: Optional[str] = None,
    ) -> str:
        """
        Insert a task record.

        Args:
            task_id: Task id.
            task_type: Task type.
            status: Task status.
            message: Task message.
            result: Task result.
            error: Error information.

        Returns:
            The id of the inserted record, as a string.
        """
        now = datetime.utcnow()
        task = {
            "task_id": task_id,
            "task_type": task_type,
            "status": status,
            "message": message,
            "result": result,
            "error": error,
            "created_at": now,
            "updated_at": now,
        }
        result_obj = await self.tasks.insert_one(task)
        return str(result_obj.inserted_id)

    async def update_task(
        self,
        task_id: str,
        status: Optional[str] = None,
        message: Optional[str] = None,
        result: Optional[Dict[str, Any]] = None,
        error: Optional[str] = None,
    ) -> bool:
        """
        Update a task record; only the arguments that are not None are written.

        Args:
            task_id: Task id.
            status: New task status.
            message: New task message.
            result: New task result.
            error: New error information.

        Returns:
            True if a record was modified.
        """
        update_data = {"updated_at": datetime.utcnow()}
        optional_fields = {
            "status": status,
            "message": message,
            "result": result,
            "error": error,
        }
        update_data.update(
            {key: value for key, value in optional_fields.items() if value is not None}
        )

        update_result = await self.tasks.update_one(
            {"task_id": task_id},
            {"$set": update_data}
        )
        return update_result.modified_count > 0

    async def get_task(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a task by its ``task_id``; returns None if it does not exist."""
        task = await self.tasks.find_one({"task_id": task_id})
        if task:
            task["_id"] = str(task["_id"])
        return task

    async def list_tasks(
        self,
        limit: int = 50,
        skip: int = 0,
    ) -> List[Dict[str, Any]]:
        """
        List tasks, newest first.

        Args:
            limit: Maximum number of tasks to return.
            skip: Number of tasks to skip.

        Returns:
            Task records with ``_id`` as a string and the timestamps as
            ISO-8601 strings.
        """
        cursor = self.tasks.find().sort("created_at", -1).skip(skip).limit(limit)
        tasks = []
        async for task in cursor:
            task["_id"] = str(task["_id"])
            # Convert datetimes to strings.
            for stamp in ("created_at", "updated_at"):
                if task.get(stamp):
                    task[stamp] = task[stamp].isoformat()
            tasks.append(task)
        return tasks

    async def delete_task(self, task_id: str) -> bool:
        """Delete a task by its ``task_id``; returns True if one was deleted."""
        result = await self.tasks.delete_one({"task_id": task_id})
        return result.deleted_count > 0
|
||||
|
||||
|
||||
# ==================== 全局单例 ====================
|
||||
|
||||
mongodb = MongoDB()
|
||||
214
backend/app/core/database/mysql.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""
|
||||
MySQL 数据库连接管理模块
|
||||
|
||||
提供结构化数据的存储和查询功能
|
||||
"""
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
DateTime,
|
||||
Enum as SQLEnum,
|
||||
Float,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
create_engine,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, sessionmaker
|
||||
from sqlalchemy.sql import select
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""SQLAlchemy 声明基类"""
|
||||
pass
|
||||
|
||||
|
||||
class MySQLDB:
    """MySQL access layer exposing async and sync SQLAlchemy engines."""

    def __init__(self):
        # Async engine (used by the FastAPI request handlers).
        self.async_engine = create_async_engine(
            settings.async_mysql_url,
            echo=settings.DEBUG,  # SQL logging
            pool_pre_ping=True,   # check connections before use
            pool_size=10,
            max_overflow=20,
        )

        # Async session factory.
        self.async_session_factory = async_sessionmaker(
            bind=self.async_engine,
            class_=AsyncSession,
            expire_on_commit=False,
            autocommit=False,
            autoflush=False,
        )

        # Sync engine (for synchronous task code).
        self.sync_engine = create_engine(
            settings.mysql_url,
            echo=settings.DEBUG,
            pool_pre_ping=True,
            pool_size=5,
            max_overflow=10,
        )

        # Sync session factory.
        self.sync_session_factory = sessionmaker(
            bind=self.sync_engine,
            autocommit=False,
            autoflush=False,
        )

    async def init_db(self):
        """Create the database if it does not exist, then create all tables.

        Raises:
            Exception: Any driver/SQLAlchemy error; it is logged and re-raised.
        """
        from urllib.parse import quote

        try:
            db_name = settings.MYSQL_DATABASE
            # Connect without selecting a database so it can be created.
            # Credentials are percent-encoded so special characters such as
            # "@" or "/" cannot corrupt the URL.
            user = quote(settings.MYSQL_USER, safe="")
            password = quote(settings.MYSQL_PASSWORD, safe="")
            temp_url = (
                f"mysql+aiomysql://{user}:{password}"
                f"@{settings.MYSQL_HOST}:{settings.MYSQL_PORT}/"
                f"?charset={settings.MYSQL_CHARSET}"
            )
            temp_engine = create_async_engine(temp_url, echo=False)
            try:
                async with temp_engine.connect() as conn:
                    await conn.execute(text(f"CREATE DATABASE IF NOT EXISTS `{db_name}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"))
                    await conn.commit()
                logger.info(f"MySQL 数据库 {db_name} 创建或已存在")
            finally:
                await temp_engine.dispose()

            # Then create the tables declared on Base.
            async with self.async_engine.begin() as conn:
                await conn.run_sync(Base.metadata.create_all)
            logger.info("MySQL 数据库表初始化完成")
        except Exception as e:
            logger.error(f"MySQL 数据库初始化失败: {e}")
            raise

    async def close(self):
        """Dispose of both engines."""
        await self.async_engine.dispose()
        self.sync_engine.dispose()
        logger.info("MySQL 数据库连接已关闭")

    @asynccontextmanager
    async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield an async session; commit on success, roll back on error, always close."""
        session = self.async_session_factory()
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()

    async def execute_query(
        self,
        query: str,
        params: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """
        Execute a raw SQL query and return its rows.

        Args:
            query: Complete SQL query statement.
            params: Bound parameters for the statement.

        Returns:
            List of rows, each as a ``{column: value}`` dict.
        """
        async with self.get_session() as session:
            # Run the statement as-is.  Wrapping it in select() would render
            # "SELECT <statement>", which is not valid SQL for a full query.
            result = await session.execute(text(query), params or {})
            rows = result.fetchall()
            return [dict(row._mapping) for row in rows]

    async def execute_raw_sql(
        self,
        sql: str,
        params: Optional[Dict[str, Any]] = None
    ) -> Any:
        """
        Execute a raw SQL statement (INSERT/UPDATE/DELETE).

        Args:
            sql: SQL statement.
            params: Bound parameters for the statement.

        Returns:
            ``lastrowid`` when it is truthy, otherwise ``rowcount``.
        """
        async with self.get_session() as session:
            result = await session.execute(text(sql), params or {})
            await session.commit()
            return result.lastrowid if result.lastrowid else result.rowcount
|
||||
|
||||
|
||||
# ==================== 预定义的数据模型 ====================
|
||||
|
||||
class DocumentTable(Base):
    """ORM model for ``document_tables``: basic info about each parsed document table."""

    __tablename__ = "document_tables"

    # Surrogate primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Unique table name plus human-readable naming and description.
    table_name = Column(String(255), unique=True, nullable=False, comment="表名")
    display_name = Column(String(255), comment="显示名称")
    description = Column(Text, comment="表描述")
    # Provenance and size figures.
    source_file = Column(String(512), comment="来源文件")
    column_count = Column(Integer, default=0, comment="列数")
    row_count = Column(Integer, default=0, comment="行数")
    file_size = Column(Integer, comment="文件大小(字节)")
    # Timestamps (no defaults are declared on the columns).
    created_at = Column(DateTime, comment="创建时间")
    updated_at = Column(DateTime, comment="更新时间")
|
||||
|
||||
|
||||
class DocumentField(Base):
    """ORM model for ``document_fields``: per-field information for each table."""

    __tablename__ = "document_fields"

    # Surrogate primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Id of the owning table (no foreign-key constraint is declared).
    table_id = Column(Integer, nullable=False, comment="所属表ID")
    # Field identity and semantics.
    field_name = Column(String(255), nullable=False, comment="字段名")
    field_type = Column(String(50), comment="字段类型")
    field_description = Column(Text, comment="字段描述/语义")
    # Flags stored as integers (0/1).
    is_key_field = Column(Integer, default=0, comment="是否主键")
    is_nullable = Column(Integer, default=1, comment="是否可空")
    # Comma-separated sample values.
    sample_values = Column(Text, comment="示例值(逗号分隔)")
    created_at = Column(DateTime, comment="创建时间")
|
||||
|
||||
|
||||
class TaskRecord(Base):
    """ORM model for ``task_records``: bookkeeping for asynchronous tasks."""

    __tablename__ = "task_records"

    # Surrogate primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Unique external task id and classification.
    task_id = Column(String(255), unique=True, nullable=False, comment="Celery任务ID")
    task_type = Column(String(50), comment="任务类型")
    status = Column(String(50), default="pending", comment="任务状态")
    # JSON payloads stored as text.
    input_params = Column(Text, comment="输入参数JSON")
    result_data = Column(Text, comment="结果数据JSON")
    error_message = Column(Text, comment="错误信息")
    # Lifecycle timestamps (no defaults are declared on the columns).
    started_at = Column(DateTime, comment="开始时间")
    completed_at = Column(DateTime, comment="完成时间")
    created_at = Column(DateTime, comment="创建时间")
|
||||
|
||||
|
||||
# ==================== 全局单例 ====================
|
||||
|
||||
mysql_db = MySQLDB()
|
||||
308
backend/app/core/database/redis_db.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""
|
||||
Redis 数据库连接管理模块
|
||||
|
||||
提供缓存和任务队列功能
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import redis.asyncio as redis
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedisDB:
    """Async Redis manager: key/value access, JSON helpers, task status, cache and locks.

    ``client`` is ``None`` until :meth:`connect` has been awaited. The basic
    operations raise ``RuntimeError`` when used before that; the task-status
    helpers instead log a warning and return ``False`` / ``None``.
    """

    def __init__(self):
        self.client: Optional[redis.Redis] = None
        self._connected = False

    # ==================== Internal helpers ====================

    @staticmethod
    def _mask_url(url: str) -> str:
        """Return ``url`` with any embedded password replaced by ``***`` (for logging)."""
        from urllib.parse import urlsplit, urlunsplit

        try:
            parts = urlsplit(url)
            if parts.password is None:
                return url
            host = parts.hostname or ""
            if ":" in host:
                # IPv6 literals must be re-bracketed when the netloc is rebuilt.
                host = f"[{host}]"
            netloc = f"{parts.username or ''}:***@{host}"
            if parts.port is not None:
                netloc += f":{parts.port}"
            return urlunsplit(parts._replace(netloc=netloc))
        except ValueError:
            # Malformed URL (e.g. a bad port): never echo it back verbatim.
            return "<unparseable redis url>"

    def _require_client(self):
        """Return the active client, or raise a clear error if connect() was never called."""
        if self.client is None:
            raise RuntimeError("Redis client is not initialised; call connect() first")
        return self.client

    # ==================== Connection management ====================

    async def connect(self):
        """Create the client from ``settings.REDIS_URL`` and verify it with PING.

        Raises:
            Exception: whatever the client raises; it is logged and re-raised.
        """
        try:
            self.client = redis.from_url(
                settings.REDIS_URL,
                encoding="utf-8",
                decode_responses=True,
            )
            # Verify the connection.
            await self.client.ping()
            self._connected = True
            # The URL may carry credentials, so only a masked form is logged.
            logger.info(f"Redis 连接成功: {self._mask_url(settings.REDIS_URL)}")
        except Exception as e:
            logger.error(f"Redis 连接失败: {e}")
            raise

    async def close(self):
        """Close the Redis connection if a client exists."""
        if self.client:
            # redis-py 5 deprecates close() in favour of aclose(); use it when present.
            closer = getattr(self.client, "aclose", None) or self.client.close
            await closer()
            self._connected = False
            logger.info("Redis 连接已关闭")

    @property
    def is_connected(self) -> bool:
        """Whether connect() succeeded and close() has not been called since."""
        return self._connected

    # ==================== Basic operations ====================

    async def get(self, key: str) -> Optional[str]:
        """Return the value stored at ``key`` (``None`` when missing)."""
        return await self._require_client().get(key)

    async def set(
        self,
        key: str,
        value: str,
        expire: Optional[int] = None,
    ) -> bool:
        """Store ``value`` at ``key``.

        Args:
            key: key name.
            value: value to store.
            expire: time to live in seconds; ``None`` means no expiry.

        Returns:
            The client's SET result.
        """
        return await self._require_client().set(key, value, ex=expire)

    async def delete(self, key: str) -> int:
        """Delete ``key`` and return the number of keys removed."""
        return await self._require_client().delete(key)

    async def exists(self, key: str) -> bool:
        """Return ``True`` when ``key`` exists."""
        return await self._require_client().exists(key) > 0

    # ==================== JSON operations ====================

    async def set_json(
        self,
        key: str,
        data: Dict[str, Any],
        expire: Optional[int] = None,
    ) -> bool:
        """Serialise ``data`` to JSON and store it at ``key``.

        Values that are not JSON serialisable are converted with ``str``.

        Args:
            key: key name.
            data: dictionary to store.
            expire: time to live in seconds.

        Returns:
            The result of :meth:`set`.
        """
        json_str = json.dumps(data, ensure_ascii=False, default=str)
        return await self.set(key, json_str, expire)

    async def get_json(self, key: str) -> Optional[Dict[str, Any]]:
        """Load the JSON document stored at ``key``.

        Returns:
            The decoded value, or ``None`` when the key is missing, empty or not valid JSON.
        """
        value = await self.get(key)
        if value:
            try:
                return json.loads(value)
            except json.JSONDecodeError:
                return None
        return None

    # ==================== Task status management ====================

    async def set_task_status(
        self,
        task_id: str,
        status: str,
        meta: Optional[Dict[str, Any]] = None,
        expire: int = 86400,  # 24 hours by default
    ) -> bool:
        """Store ``{"status", "meta"}`` under ``task:<task_id>``.

        Args:
            task_id: task identifier.
            status: status string (pending/processing/success/failure).
            meta: extra information; stored as ``{}`` when omitted.
            expire: time to live in seconds.

        Returns:
            ``False`` when Redis is not connected or the write fails (a warning is logged).
        """
        if not self._connected or not self.client:
            logger.warning(f"Redis未连接,跳过任务状态更新: {task_id}")
            return False
        try:
            key = f"task:{task_id}"
            data = {
                "status": status,
                "meta": meta or {},
            }
            return await self.set_json(key, data, expire)
        except Exception as e:
            logger.warning(f"设置任务状态失败: {task_id}, error: {e}")
            return False

    async def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Return the record stored under ``task:<task_id>``.

        Returns:
            The stored dictionary, or ``None`` when Redis is not connected,
            the key is missing, or the read fails (a warning is logged).
        """
        if not self._connected or not self.client:
            logger.warning(f"Redis未连接,无法获取任务状态: {task_id}")
            return None
        try:
            key = f"task:{task_id}"
            return await self.get_json(key)
        except Exception as e:
            logger.warning(f"获取任务状态失败: {task_id}, error: {e}")
            return None

    async def update_task_progress(
        self,
        task_id: str,
        progress: int,
        message: Optional[str] = None,
    ) -> bool:
        """Update ``meta.progress`` (and optionally ``meta.message``) of a task record.

        The record is read, modified and written back with a fresh 24 hour
        expiry; the read-modify-write is not atomic.

        Args:
            task_id: task identifier.
            progress: progress value (0-100).
            message: progress message; ignored when empty.

        Returns:
            ``False`` when Redis is not connected, the task record does not
            exist, or the update fails.
        """
        if not self._connected or not self.client:
            logger.warning(f"Redis未连接,跳过任务进度更新: {task_id}")
            return False
        try:
            data = await self.get_task_status(task_id)
            if data:
                # Records written elsewhere may lack a dict "meta"; repair it
                # instead of failing with KeyError/TypeError.
                meta = data.get("meta")
                if not isinstance(meta, dict):
                    meta = {}
                    data["meta"] = meta
                meta["progress"] = progress
                if message:
                    meta["message"] = message
                key = f"task:{task_id}"
                return await self.set_json(key, data, expire=86400)
            return False
        except Exception as e:
            logger.warning(f"更新任务进度失败: {task_id}, error: {e}")
            return False

    # ==================== Cache operations ====================

    async def cache_document(
        self,
        doc_id: str,
        data: Dict[str, Any],
        expire: int = 3600,  # 1 hour by default
    ) -> bool:
        """Cache document data under ``doc:<doc_id>``.

        Args:
            doc_id: document identifier.
            data: document data.
            expire: time to live in seconds.
        """
        key = f"doc:{doc_id}"
        return await self.set_json(key, data, expire)

    async def get_cached_document(self, doc_id: str) -> Optional[Dict[str, Any]]:
        """Return the cached document stored under ``doc:<doc_id>``, if any."""
        key = f"doc:{doc_id}"
        return await self.get_json(key)

    # ==================== Distributed lock ====================

    async def acquire_lock(
        self,
        lock_name: str,
        expire: int = 30,
    ) -> bool:
        """Try to take the lock ``lock:<lock_name>`` with a single SET NX EX.

        Args:
            lock_name: lock name.
            expire: lock time to live in seconds.

        Returns:
            ``True`` when the lock was acquired.
        """
        key = f"lock:{lock_name}"
        # SET NX EX is atomic: the key is only written when it does not exist.
        result = await self._require_client().set(key, "1", nx=True, ex=expire)
        return result is not None

    async def release_lock(self, lock_name: str) -> bool:
        """Delete ``lock:<lock_name>``.

        NOTE(review): the lock value is the constant "1", so this deletes the
        key no matter who holds it; a holder whose lock already expired can
        release another holder's lock.

        Returns:
            ``True`` when a key was deleted.
        """
        key = f"lock:{lock_name}"
        result = await self._require_client().delete(key)
        return result > 0

    # ==================== Counters ====================

    async def incr(self, key: str, amount: int = 1) -> int:
        """Increment the counter at ``key`` by ``amount``."""
        return await self._require_client().incrby(key, amount)

    async def decr(self, key: str, amount: int = 1) -> int:
        """Decrement the counter at ``key`` by ``amount``."""
        return await self._require_client().decrby(key, amount)

    # ==================== Expiry management ====================

    async def expire(self, key: str, seconds: int) -> bool:
        """Set the time to live of ``key`` in seconds."""
        return await self._require_client().expire(key, seconds)

    async def ttl(self, key: str) -> int:
        """Return the remaining time to live of ``key``."""
        return await self._require_client().ttl(key)
|
||||
|
||||
|
||||
# ==================== 全局单例 ====================
|
||||
|
||||
redis_db = RedisDB()
|
||||
65
backend/app/core/document_parser/__init__.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
文档解析模块 - 支持多种文件格式的解析
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
from .base import BaseParser, ParseResult
|
||||
from .xlsx_parser import XlsxParser
|
||||
from .docx_parser import DocxParser
|
||||
from .md_parser import MarkdownParser
|
||||
from .txt_parser import TxtParser
|
||||
|
||||
|
||||
class ParserFactory:
    """Parser factory: maps a file extension to the parser instance that handles it."""

    # Registry keyed by lower-case extension including the leading dot.
    _parsers: Dict[str, BaseParser] = {
        # Excel
        '.xlsx': XlsxParser(),
        '.xls': XlsxParser(),
        # Word
        '.docx': DocxParser(),
        # Markdown
        '.md': MarkdownParser(),
        '.markdown': MarkdownParser(),
        # Plain text
        '.txt': TxtParser(),
    }

    @staticmethod
    def _normalize_extension(ext: str) -> str:
        """Lower-case ``ext`` and make sure it starts with a dot, like ``Path.suffix``."""
        ext = ext.strip().lower()
        if ext and not ext.startswith('.'):
            ext = '.' + ext
        return ext

    @classmethod
    def get_parser(cls, file_path: str) -> BaseParser:
        """Return the parser registered for the extension of ``file_path``.

        Raises:
            ValueError: when no parser is registered for the extension.
        """
        ext = Path(file_path).suffix.lower()
        parser = cls._parsers.get(ext)
        if parser is None:
            supported = list(cls._parsers.keys())
            raise ValueError(f"不支持的文件格式: {ext},支持的格式: {supported}")
        return parser

    @classmethod
    def parse(cls, file_path: str, **kwargs) -> ParseResult:
        """Unified entry point: pick the parser for ``file_path`` and run it."""
        parser = cls.get_parser(file_path)
        return parser.parse(file_path, **kwargs)

    @classmethod
    def register_parser(cls, ext: str, parser: BaseParser):
        """Register (or replace) the parser for an extension.

        ``ext`` may be given with or without the leading dot ("csv" or
        ".csv"). Lookups use ``Path.suffix``, which always includes the dot,
        so a key stored without it could never be found.
        """
        cls._parsers[cls._normalize_extension(ext)] = parser

    @classmethod
    def get_supported_extensions(cls) -> list:
        """Return all registered extensions."""
        return list(cls._parsers.keys())
|
||||
|
||||
|
||||
__all__ = [
|
||||
'BaseParser',
|
||||
'ParseResult',
|
||||
'ParserFactory',
|
||||
'XlsxParser',
|
||||
'DocxParser',
|
||||
'MarkdownParser',
|
||||
'TxtParser',
|
||||
]
|
||||
87
backend/app/core/document_parser/base.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
解析器基类 - 定义所有解析器的通用接口
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class ParseResult:
    """Outcome of a parse: success flag, payload, error text and metadata."""

    def __init__(
        self,
        success: bool,
        data: Optional[Dict[str, Any]] = None,
        error: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Store the fields; a falsy ``data`` or ``metadata`` becomes a fresh empty dict."""
        self.success = success
        self.error = error
        self.data = data if data else {}
        self.metadata = metadata if metadata else {}

    def to_dict(self) -> Dict[str, Any]:
        """Return the result as a plain dictionary."""
        return dict(
            success=self.success,
            data=self.data,
            error=self.error,
            metadata=self.metadata,
        )
|
||||
|
||||
|
||||
class BaseParser(ABC):
    """Abstract base class shared by all document parsers."""

    def __init__(self):
        # Subclasses overwrite both attributes in their own __init__.
        self.supported_extensions: List[str] = []
        self.parser_name: str = "base_parser"

    @abstractmethod
    def parse(self, file_path: str, **kwargs) -> ParseResult:
        """
        Parse a file.

        Args:
            file_path: path of the file to parse
            **kwargs: additional parser-specific options

        Returns:
            ParseResult: the parse result
        """

    def can_parse(self, file_path: str) -> bool:
        """
        Tell whether this parser handles the file, judged by its extension.

        Args:
            file_path: path of the file

        Returns:
            bool: True when the lower-cased suffix is in ``supported_extensions``
        """
        return Path(file_path).suffix.lower() in self.supported_extensions

    def get_file_info(self, file_path: str) -> Dict[str, Any]:
        """
        Collect basic information about a file.

        Args:
            file_path: path of the file

        Returns:
            Dict[str, Any]: filename, extension, size and parser name, or
            ``{"error": "File not found"}`` when the path does not exist
        """
        target = Path(file_path)
        if not target.exists():
            return {"error": "File not found"}

        info: Dict[str, Any] = {}
        info["filename"] = target.name
        info["extension"] = target.suffix.lower()
        info["size"] = target.stat().st_size
        info["parser"] = self.parser_name
        return info
|
||||
429
backend/app/core/document_parser/docx_parser.py
Normal file
@@ -0,0 +1,429 @@
|
||||
"""
|
||||
Word 文档 (.docx) 解析器
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from docx import Document
|
||||
|
||||
from .base import BaseParser, ParseResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocxParser(BaseParser):
|
||||
"""Word 文档解析器"""
|
||||
|
||||
def __init__(self):
    """Initialise the base parser, then declare the handled extension and parser name."""
    super().__init__()
    self.parser_name = "docx_parser"
    self.supported_extensions = ['.docx']
|
||||
|
||||
def parse(self, file_path: str, **kwargs) -> ParseResult:
    """
    Parse a Word document into text, tables and image information.

    Args:
        file_path: path of the .docx file
        **kwargs: accepted for interface compatibility; not used here

    Returns:
        ParseResult: on success ``data`` holds the merged text, the
        paragraphs, tables and image info; on any failure ``success`` is
        False and ``error`` carries the message.
    """
    path = Path(file_path)

    # The file must exist ...
    if not path.exists():
        return ParseResult(success=False, error=f"文件不存在: {file_path}")

    # ... and carry a supported extension.
    if path.suffix.lower() not in self.supported_extensions:
        return ParseResult(success=False, error=f"不支持的文件类型: {path.suffix}")

    try:
        doc = Document(file_path)

        # Non-blank paragraphs together with their style name.
        paragraphs = [
            {
                "text": para.text,
                "style": str(para.style.name) if para.style else "Normal",
            }
            for para in doc.paragraphs
            if para.text.strip()
        ]

        # Plain paragraph text (the original comment says this feeds AI parsing).
        paragraphs_text = [entry["text"] for entry in paragraphs if entry["text"].strip()]

        # Tables as lists of stripped cell texts; tables without rows are skipped.
        tables_data = []
        for table_index, table in enumerate(doc.tables):
            grid = [[cell.text.strip() for cell in row.cells] for row in table.rows]
            if not grid:
                continue
            tables_data.append({
                "table_index": table_index,
                "rows": grid,
                "row_count": len(grid),
                "column_count": len(grid[0]),
            })

        # Image / embedded object information.
        images_info = self._extract_images_info(doc, path)
        image_total = images_info.get("image_count", 0)

        # Merge everything into one text: body, tables, then an image note.
        full_text_parts = ["【文档正文】", *paragraphs_text]

        if tables_data:
            full_text_parts.append("\n【文档表格】")
            for position, table in enumerate(tables_data, start=1):
                full_text_parts.append(f"--- 表格 {position} ---")
                full_text_parts.extend(
                    " | ".join(str(cell) for cell in row) for row in table["rows"]
                )

        if image_total > 0:
            full_text_parts.append(f"\n【文档图片】文档包含 {image_total} 张图片/图表")

        full_text = "\n".join(full_text_parts)
        table_total = len(tables_data)

        # "word_count" is the character length of the merged text.
        metadata = {
            "filename": path.name,
            "extension": path.suffix.lower(),
            "file_size": path.stat().st_size,
            "paragraph_count": len(paragraphs),
            "table_count": table_total,
            "word_count": len(full_text),
            "char_count": len(full_text.replace("\n", "")),
            "has_tables": table_total > 0,
            "has_images": image_total > 0,
            "image_count": image_total,
        }

        payload = {
            "content": full_text,
            "paragraphs": paragraphs_text,
            "paragraphs_with_style": paragraphs,
            "tables": tables_data,
            "images": images_info,
            "word_count": len(full_text),
            "structured_data": {
                "paragraphs": paragraphs,
                "paragraphs_text": paragraphs_text,
                "tables": tables_data,
                "images": images_info,
            },
        }
        return ParseResult(success=True, data=payload, metadata=metadata)

    except Exception as e:
        logger.error(f"解析 Word 文档失败: {str(e)}")
        return ParseResult(success=False, error=f"解析 Word 文档失败: {str(e)}")
|
||||
|
||||
def extract_images_as_base64(self, file_path: str) -> List[Dict[str, str]]:
    """
    Extract every file under ``word/media/`` of a .docx as base64.

    Args:
        file_path: path of the Word file

    Returns:
        One dict per media file with ``filename``, ``mime_type``, ``base64``
        and ``size`` (byte length). Returns an empty list when the archive
        cannot be opened; a file that cannot be read is skipped with a warning.
    """
    import base64
    import zipfile

    # Built once instead of once per file. Unknown extensions keep the
    # historical "image/png" fallback so existing callers see no change.
    mime_types = {
        'png': 'image/png',
        'jpg': 'image/jpeg',
        'jpeg': 'image/jpeg',
        'gif': 'image/gif',
        'bmp': 'image/bmp',
        'tif': 'image/tiff',
        'tiff': 'image/tiff',
        'webp': 'image/webp',
        'svg': 'image/svg+xml',
        'emf': 'image/emf',
        'wmf': 'image/wmf',
    }

    images = []

    try:
        with zipfile.ZipFile(file_path, 'r') as zf:
            for filename in zf.namelist():
                if not filename.startswith('word/media/'):
                    continue
                # A ZIP may list the directory itself; it is not an image.
                if filename.endswith('/'):
                    continue

                ext = filename.rpartition('.')[2].lower()
                mime_type = mime_types.get(ext, 'image/png')

                try:
                    image_data = zf.read(filename)
                    base64_data = base64.b64encode(image_data).decode('utf-8')

                    images.append({
                        "filename": filename,
                        "mime_type": mime_type,
                        "base64": base64_data,
                        "size": len(image_data)
                    })
                    logger.info(f"提取图片: {filename}, 大小: {len(image_data)} bytes")
                except Exception as e:
                    logger.warning(f"提取图片失败 {filename}: {str(e)}")

    except Exception as e:
        logger.error(f"打开 Word 文档提取图片失败: {str(e)}")

    logger.info(f"共提取 {len(images)} 张图片")
    return images
|
||||
|
||||
def extract_key_sentences(self, text: str, max_sentences: int = 10) -> List[str]:
    """
    Return the leading sentences of a text.

    Deliberately simple: the text is split on the Chinese full stop "。",
    blank fragments are dropped, and the first ``max_sentences`` remain.

    Args:
        text: text content
        max_sentences: maximum number of sentences to return

    Returns:
        List of sentences, without the trailing full stop
    """
    kept = []
    for fragment in text.split("。"):
        fragment = fragment.strip()
        if fragment:
            kept.append(fragment)
    return kept[:max_sentences]
|
||||
|
||||
def extract_structured_fields(self, text: str) -> Dict[str, Any]:
|
||||
"""
|
||||
尝试提取结构化字段
|
||||
|
||||
针对合同、简历等有固定格式的文档
|
||||
|
||||
Args:
|
||||
text: 文本内容
|
||||
|
||||
Returns:
|
||||
提取的字段字典
|
||||
"""
|
||||
fields = {}
|
||||
|
||||
# 常见字段模式
|
||||
patterns = {
|
||||
"姓名": r"姓名[::]\s*(\S+)",
|
||||
"电话": r"电话[::]\s*(\d{11}|\d{3}-\d{8})",
|
||||
"邮箱": r"邮箱[::]\s*(\S+@\S+)",
|
||||
"地址": r"地址[::]\s*(.+?)(?:\n|$)",
|
||||
"金额": r"金额[::]\s*(\d+(?:\.\d+)?)",
|
||||
"日期": r"日期[::]\s*(\d{4}[年/-]\d{1,2}[月/-]\d{1,2})",
|
||||
}
|
||||
|
||||
import re
|
||||
for field_name, pattern in patterns.items():
|
||||
match = re.search(pattern, text)
|
||||
if match:
|
||||
fields[field_name] = match.group(1)
|
||||
|
||||
return fields
|
||||
|
||||
def parse_tables_for_template(
    self,
    file_path: str
) -> Dict[str, Any]:
    """
    Parse the tables of a Word document and extract template fields.

    The first row of each table is taken as the header. Every following row
    is recorded as a data row, and when the header has at least two cells
    and the row's first cell is non-empty, the row is also recorded as a
    field: first cell = field name, second = hint, third (if any) =
    expected value.

    Args:
        file_path: path of the Word file

    Returns:
        Dict with ``tables`` (per-table info), ``fields`` (flat field list)
        and ``field_count``
    """
    doc = Document(file_path)

    template_info = {
        "tables": [],
        "fields": [],
        "field_count": 0
    }

    for table_idx, table in enumerate(doc.tables):
        table_info = {
            "table_index": table_idx,
            "rows": [],
            "headers": [],
            "data_rows": [],
            "field_hints": {}  # field name -> hint / description
        }

        rows = list(table.rows)
        # A table without rows has no header to index; record it empty.
        if rows:
            header_cells = [cell.text.strip() for cell in rows[0].cells]
            table_info["headers"] = header_cells
            # Template layout is usually: field name | hint | value.
            has_hint_column = len(header_cells) >= 2

            # Single pass: each row's cell texts are read once and reused
            # for both the data-row list and the field extraction.
            for row_idx, row in enumerate(rows[1:], 1):
                cells = [cell.text.strip() for cell in row.cells]
                table_info["data_rows"].append(cells)
                table_info["rows"].append({
                    "row_index": row_idx,
                    "cells": cells
                })

                if has_hint_column and len(cells) >= 2 and cells[0]:
                    field_name = cells[0]
                    hint = cells[1]
                    table_info["field_hints"][field_name] = hint

                    template_info["fields"].append({
                        "table_index": table_idx,
                        "row_index": row_idx,
                        "field_name": field_name,
                        "hint": hint,
                        "expected_value": cells[2] if len(cells) > 2 else ""
                    })

        template_info["tables"].append(table_info)

    template_info["field_count"] = len(template_info["fields"])
    return template_info
|
||||
|
||||
def extract_template_fields_from_docx(
    self,
    file_path: str
) -> List[Dict[str, Any]]:
    """
    Build field definitions from the tables of a Word template.

    Uses ``parse_tables_for_template`` and turns each extracted field into
    a definition whose ``cell`` id has the form ``T<table>R<row>``. Every
    field is marked as required, and its type is inferred from the hint.

    Args:
        file_path: path of the Word file

    Returns:
        List of field definitions
    """
    extracted = self.parse_tables_for_template(file_path)["fields"]

    def to_definition(item: Dict[str, Any]) -> Dict[str, Any]:
        table_no = item["table_index"]
        row_no = item["row_index"]
        return {
            "cell": f"T{table_no}R{row_no}",
            "name": item["field_name"],
            "hint": item["hint"],
            "table_index": table_no,
            "row_index": row_no,
            "field_type": self._infer_field_type_from_hint(item["hint"]),
            "required": True,
        }

    return [to_definition(item) for item in extracted]
|
||||
|
||||
def _extract_images_info(self, doc: Document, path: Path) -> Dict[str, Any]:
|
||||
"""
|
||||
提取 Word 文档中的图片/嵌入式对象信息
|
||||
|
||||
Args:
|
||||
doc: Document 对象
|
||||
path: 文件路径
|
||||
|
||||
Returns:
|
||||
图片信息字典
|
||||
"""
|
||||
import zipfile
|
||||
from io import BytesIO
|
||||
|
||||
image_count = 0
|
||||
image_descriptions = []
|
||||
inline_shapes_count = 0
|
||||
|
||||
try:
|
||||
# 方法1: 通过 inline shapes 统计图片
|
||||
try:
|
||||
inline_shapes_count = len(doc.inline_shapes)
|
||||
if inline_shapes_count > 0:
|
||||
image_count = inline_shapes_count
|
||||
image_descriptions.append(f"文档包含 {inline_shapes_count} 个嵌入式图形/图片")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 方法2: 通过 ZIP 分析 document.xml 获取图片引用
|
||||
try:
|
||||
with zipfile.ZipFile(path, 'r') as zf:
|
||||
# 查找 word/media 目录下的图片文件
|
||||
media_files = [f for f in zf.namelist() if f.startswith('word/media/')]
|
||||
if media_files and not inline_shapes_count:
|
||||
image_count = len(media_files)
|
||||
image_descriptions.append(f"文档包含 {image_count} 个嵌入图片")
|
||||
|
||||
# 检查是否有页眉页脚中的图片
|
||||
header_images = [f for f in zf.namelist() if 'header' in f.lower() and f.endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))]
|
||||
if header_images:
|
||||
image_descriptions.append(f"页眉/页脚包含 {len(header_images)} 个图片")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"提取图片信息失败: {str(e)}")
|
||||
|
||||
return {
|
||||
"image_count": image_count,
|
||||
"inline_shapes_count": inline_shapes_count,
|
||||
"descriptions": image_descriptions,
|
||||
"has_images": image_count > 0
|
||||
}
|
||||
|
||||
def _infer_field_type_from_hint(self, hint: str) -> str:
|
||||
"""
|
||||
从提示词推断字段类型
|
||||
|
||||
Args:
|
||||
hint: 字段提示词
|
||||
|
||||
Returns:
|
||||
字段类型 (text/number/date)
|
||||
"""
|
||||
hint_lower = hint.lower()
|
||||
|
||||
# 日期关键词
|
||||
date_keywords = ["年", "月", "日", "日期", "时间", "出生"]
|
||||
if any(kw in hint for kw in date_keywords):
|
||||
return "date"
|
||||
|
||||
# 数字关键词
|
||||
number_keywords = ["数量", "金额", "人数", "面积", "增长", "比率", "%", "率"]
|
||||
if any(kw in hint_lower for kw in number_keywords):
|
||||
return "number"
|
||||
|
||||
return "text"
|
||||
262
backend/app/core/document_parser/md_parser.py
Normal file
@@ -0,0 +1,262 @@
|
||||
"""
|
||||
Markdown 文档解析器
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import markdown
|
||||
|
||||
from .base import BaseParser, ParseResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MarkdownParser(BaseParser):
|
||||
"""Markdown 文档解析器"""
|
||||
|
||||
def __init__(self):
    """Initialise the base parser, then declare the handled extensions and parser name."""
    super().__init__()
    self.parser_name = "markdown_parser"
    self.supported_extensions = ['.md', '.markdown']
|
||||
|
||||
def parse(
    self,
    file_path: str,
    **kwargs
) -> ParseResult:
    """
    Parse a Markdown document.

    Args:
        file_path: path of the file
        **kwargs: accepted for interface compatibility; not used here

    Returns:
        ParseResult: on success ``data`` holds the plain text, the raw and
        HTML content, titles, code blocks, tables and links/images; on any
        failure ``success`` is False and ``error`` carries the message.
    """
    path = Path(file_path)

    # The file must exist.
    if not path.exists():
        return ParseResult(
            success=False,
            error=f"文件不存在: {file_path}"
        )

    # The extension must be supported.
    if path.suffix.lower() not in self.supported_extensions:
        return ParseResult(
            success=False,
            error=f"不支持的文件类型: {path.suffix}"
        )

    try:
        # "utf-8-sig" reads plain UTF-8 unchanged and drops a leading BOM.
        # With "utf-8" the BOM stayed in the text, in front of the first
        # line, so a heading on line 1 no longer started with "#".
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            raw_content = f.read()

        # Render to HTML.
        md = markdown.Markdown(extensions=[
            'markdown.extensions.tables',
            'markdown.extensions.fenced_code',
            'markdown.extensions.codehilite',
            'markdown.extensions.toc',
        ])

        html_content = md.convert(raw_content)

        # Structural extraction works on the raw Markdown source.
        titles = self._extract_titles(raw_content)
        code_blocks = self._extract_code_blocks(raw_content)
        tables = self._extract_tables(raw_content)
        links_images = self._extract_links_images(raw_content)

        # Plain text with the Markdown syntax removed.
        plain_text = self._strip_markdown(raw_content)

        # "word_count" is the character length of the plain text.
        metadata = {
            "filename": path.name,
            "extension": path.suffix.lower(),
            "file_size": path.stat().st_size,
            "word_count": len(plain_text),
            "char_count": len(raw_content),
            "line_count": len(raw_content.splitlines()),
            "title_count": len(titles),
            "code_block_count": len(code_blocks),
            "table_count": len(tables),
            "link_count": len(links_images.get("links", [])),
            "image_count": len(links_images.get("images", [])),
        }

        return ParseResult(
            success=True,
            data={
                "content": plain_text,
                "raw_content": raw_content,
                "html_content": html_content,
                "titles": titles,
                "code_blocks": code_blocks,
                "tables": tables,
                "links_images": links_images,
                "word_count": len(plain_text),
                "structured_data": {
                    "titles": titles,
                    "code_blocks": code_blocks,
                    "tables": tables
                }
            },
            metadata=metadata
        )

    except Exception as e:
        logger.error(f"解析 Markdown 文档失败: {str(e)}")
        return ParseResult(
            success=False,
            error=f"解析 Markdown 文档失败: {str(e)}"
        )
|
||||
|
||||
def _extract_titles(self, content: str) -> List[Dict[str, Any]]:
|
||||
"""提取标题结构"""
|
||||
import re
|
||||
titles = []
|
||||
|
||||
# 匹配 # 标题
|
||||
for match in re.finditer(r'^(#{1,6})\s+(.+)$', content, re.MULTILINE):
|
||||
level = len(match.group(1))
|
||||
title_text = match.group(2).strip()
|
||||
titles.append({
|
||||
"level": level,
|
||||
"text": title_text,
|
||||
"line": content[:match.start()].count('\n') + 1
|
||||
})
|
||||
|
||||
return titles
|
||||
|
||||
def _extract_code_blocks(self, content: str) -> List[Dict[str, str]]:
|
||||
"""提取代码块"""
|
||||
import re
|
||||
code_blocks = []
|
||||
|
||||
# 匹配 ```code ``` 格式
|
||||
pattern = r'```(\w*)\n(.*?)```'
|
||||
for match in re.finditer(pattern, content, re.DOTALL):
|
||||
language = match.group(1) or "text"
|
||||
code = match.group(2).strip()
|
||||
code_blocks.append({
|
||||
"language": language,
|
||||
"code": code
|
||||
})
|
||||
|
||||
return code_blocks
|
||||
|
||||
def _extract_tables(self, content: str) -> List[Dict[str, Any]]:
|
||||
"""提取表格"""
|
||||
import re
|
||||
tables = []
|
||||
|
||||
# 简单表格匹配(| col1 | col2 | 格式)
|
||||
lines = content.split('\n')
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
line = lines[i].strip()
|
||||
|
||||
# 检查是否是表格行
|
||||
if line.startswith('|') and line.endswith('|'):
|
||||
# 找到表头
|
||||
header_row = [cell.strip() for cell in line.split('|')[1:-1]]
|
||||
|
||||
# 检查下一行是否是分隔符
|
||||
if i + 1 < len(lines) and re.match(r'^\|[\s\-:|]+\|$', lines[i + 1]):
|
||||
# 跳过分隔符,读取数据行
|
||||
data_rows = []
|
||||
for j in range(i + 2, len(lines)):
|
||||
row_line = lines[j].strip()
|
||||
if not (row_line.startswith('|') and row_line.endswith('|')):
|
||||
break
|
||||
row_data = [cell.strip() for cell in row_line.split('|')[1:-1]]
|
||||
data_rows.append(row_data)
|
||||
|
||||
if header_row and data_rows:
|
||||
tables.append({
|
||||
"headers": header_row,
|
||||
"rows": data_rows,
|
||||
"row_count": len(data_rows),
|
||||
"column_count": len(header_row)
|
||||
})
|
||||
i = j - 1
|
||||
|
||||
i += 1
|
||||
|
||||
return tables
|
||||
|
||||
def _extract_links_images(self, content: str) -> Dict[str, List[Dict[str, str]]]:
|
||||
"""提取链接和图片"""
|
||||
import re
|
||||
result = {"links": [], "images": []}
|
||||
|
||||
# 提取链接 [text](url)
|
||||
for match in re.finditer(r'\[([^\]]+)\]\(([^\)]+)\)', content):
|
||||
result["links"].append({
|
||||
"text": match.group(1),
|
||||
"url": match.group(2)
|
||||
})
|
||||
|
||||
# 提取图片 
|
||||
for match in re.finditer(r'!\[([^\]]*)\]\(([^\)]+)\)', content):
|
||||
result["images"].append({
|
||||
"alt": match.group(1),
|
||||
"url": match.group(2)
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
def _strip_markdown(self, content: str) -> str:
|
||||
"""去除 Markdown 语法,获取纯文本"""
|
||||
import re
|
||||
|
||||
# 去除代码块
|
||||
content = re.sub(r'```[\s\S]*?```', '', content)
|
||||
|
||||
# 去除行内代码
|
||||
content = re.sub(r'`[^`]+`', '', content)
|
||||
|
||||
# 去除图片
|
||||
content = re.sub(r'!\[([^\]]*)\]\([^\)]+\)', r'\1', content)
|
||||
|
||||
# 去除链接,保留文本
|
||||
content = re.sub(r'\[([^\]]+)\]\([^\)]+\)', r'\1', content)
|
||||
|
||||
# 去除标题标记
|
||||
content = re.sub(r'^#{1,6}\s+', '', content, flags=re.MULTILINE)
|
||||
|
||||
# 去除加粗和斜体
|
||||
content = re.sub(r'\*\*([^\*]+)\*\*', r'\1', content)
|
||||
content = re.sub(r'\*([^\*]+)\*', r'\1', content)
|
||||
content = re.sub(r'__([^_]+)__', r'\1', content)
|
||||
content = re.sub(r'_([^_]+)_', r'\1', content)
|
||||
|
||||
# 去除引用标记
|
||||
content = re.sub(r'^>\s+', '', content, flags=re.MULTILINE)
|
||||
|
||||
# 去除列表标记
|
||||
content = re.sub(r'^[-*+]\s+', '', content, flags=re.MULTILINE)
|
||||
content = re.sub(r'^\d+\.\s+', '', content, flags=re.MULTILINE)
|
||||
|
||||
# 去除水平线
|
||||
content = re.sub(r'^[-*_]{3,}$', '', content, flags=re.MULTILINE)
|
||||
|
||||
# 去除表格分隔符
|
||||
content = re.sub(r'^\|[\s\-:|]+\|$', '', content, flags=re.MULTILINE)
|
||||
|
||||
# 清理多余空行
|
||||
content = re.sub(r'\n{3,}', '\n\n', content)
|
||||
|
||||
return content.strip()
|
||||
278
backend/app/core/document_parser/txt_parser.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""
|
||||
纯文本 (.txt) 解析器
|
||||
"""
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import chardet
|
||||
|
||||
from .base import BaseParser, ParseResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TxtParser(BaseParser):
|
||||
"""纯文本文档解析器"""
|
||||
|
||||
def __init__(self):
    """Initialise the base parser, then declare the handled extension and parser name."""
    super().__init__()
    self.parser_name = "txt_parser"
    self.supported_extensions = ['.txt']
|
||||
|
||||
def parse(
    self,
    file_path: str,
    encoding: Optional[str] = None,
    **kwargs
) -> ParseResult:
    """
    Parse a plain-text file.

    Args:
        file_path: path of the file
        encoding: encoding to read with; detected automatically when omitted
        **kwargs: accepted for interface compatibility; not used here

    Returns:
        ParseResult: on success ``data`` holds the cleaned and raw content,
        the lines and counts; on any failure ``success`` is False and
        ``error`` carries the message.
    """
    path = Path(file_path)

    # The file must exist ...
    if not path.exists():
        return ParseResult(success=False, error=f"文件不存在: {file_path}")

    # ... and carry a supported extension.
    if path.suffix.lower() not in self.supported_extensions:
        return ParseResult(success=False, error=f"不支持的文件类型: {path.suffix}")

    try:
        # Fall back to detection when no encoding was supplied.
        encoding = encoding or self._detect_encoding(file_path)

        with open(file_path, 'r', encoding=encoding) as f:
            raw_content = f.read()

        content = self._clean_text(raw_content)
        lines = content.split('\n')

        line_total = len(lines)
        char_total = len(content)
        non_empty_total = len([line for line in lines if line.strip()])
        # Rough count: characters excluding newlines and spaces.
        word_count = len(content.replace('\n', '').replace(' ', ''))

        metadata = {
            "filename": path.name,
            "extension": path.suffix.lower(),
            "file_size": path.stat().st_size,
            "encoding": encoding,
            "line_count": line_total,
            "word_count": word_count,
            "char_count": char_total,
            "non_empty_line_count": non_empty_total,
        }

        payload = {
            "content": content,
            "raw_content": raw_content,
            "lines": lines,
            "word_count": word_count,
            "char_count": char_total,
            "line_count": line_total,
            "structured_data": {
                "line_count": line_total,
                "non_empty_line_count": non_empty_total,
            },
        }
        return ParseResult(success=True, data=payload, metadata=metadata)

    except Exception as e:
        logger.error(f"解析文本文件失败: {str(e)}")
        return ParseResult(success=False, error=f"解析文本文件失败: {str(e)}")
|
||||
|
||||
def _detect_encoding(self, file_path: str) -> str:
    """
    Guess the text encoding of a file using chardet.

    The whole file is read as bytes, chardet proposes an encoding, and the
    proposal is accepted only if the bytes actually decode with it.

    Args:
        file_path: Path of the file to inspect.

    Returns:
        The verified encoding name, or 'utf-8' when nothing was detected,
        the detected codec is unknown/does not decode the data, or any
        error occurs while reading the file.
    """
    fallback = 'utf-8'
    try:
        with open(file_path, 'rb') as handle:
            payload = handle.read()

        guess = chardet.detect(payload).get('encoding', 'utf-8')
        if not guess:
            return fallback

        # Only trust the guess if the data really decodes with it.
        try:
            payload.decode(guess)
        except (UnicodeDecodeError, LookupError):
            return fallback
        return guess

    except Exception as e:
        logger.warning(f"编码检测失败,使用默认编码: {str(e)}")
        return fallback
|
||||
|
||||
def _clean_text(self, text: str) -> str:
|
||||
"""
|
||||
清理文本内容
|
||||
|
||||
- 去除多余空白字符
|
||||
- 规范化换行符
|
||||
- 去除特殊控制字符
|
||||
|
||||
Args:
|
||||
text: 原始文本
|
||||
|
||||
Returns:
|
||||
清理后的文本
|
||||
"""
|
||||
# 规范化换行符
|
||||
text = text.replace('\r\n', '\n').replace('\r', '\n')
|
||||
|
||||
# 去除控制字符(除了换行和tab)
|
||||
text = re.sub(r'[\x00-\x08\x0b-\x0c\x0e-\x1f\x7f]', '', text)
|
||||
|
||||
# 将多个连续空格合并为一个
|
||||
text = re.sub(r'[ \t]+', ' ', text)
|
||||
|
||||
# 将多个连续空行合并为一个
|
||||
text = re.sub(r'\n{3,}', '\n\n', text)
|
||||
|
||||
return text.strip()
|
||||
|
||||
def extract_structured_data(self, content: str) -> Dict[str, Any]:
    """
    Extract simple structured items from free text with regular expressions.

    Extracted categories:
    - emails
    - urls (http/https)
    - phones (mainland-China mobile numbers and hyphenated landlines)
    - dates (YYYY-MM-DD, YYYY/MM/DD, YYYY年MM月DD日, YYYY.MM.DD)
    - amounts (¥ / $ prefixed, or 元 suffixed)

    Args:
        content: Text to scan.

    Returns:
        Dict with keys "emails", "urls", "phones", "dates", "amounts"; each
        value is a de-duplicated list in first-seen order (per pattern).
    """
    def unique_matches(patterns: List[str]) -> List[str]:
        """Collect matches of all patterns, de-duplicated, order preserved."""
        found: List[str] = []
        for pattern in patterns:
            found.extend(re.findall(pattern, content))
        # dict.fromkeys keeps first-seen order, unlike set().
        return list(dict.fromkeys(found))

    return {
        # TLD class is letters only; the former "[A-Z|a-z]" also accepted '|'.
        "emails": unique_matches([
            r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
        ]),
        "urls": unique_matches([
            r'https?://[^\s<>"{}|\\^`\[\]]+',
        ]),
        # Digit look-arounds stop matches inside longer digit runs (e.g. ID numbers).
        "phones": unique_matches([
            r'(?<!\d)1[3-9]\d{9}(?!\d)',        # mobile
            r'(?<!\d)\d{3,4}-\d{7,8}(?!\d)',    # landline
        ]),
        "dates": unique_matches([
            r'\d{4}[-/年]\d{1,2}[-/月]\d{1,2}[日]?',
            r'\d{4}\.\d{1,2}\.\d{1,2}',
        ]),
        "amounts": unique_matches([
            r'¥\s*\d+(?:\.\d{1,2})?',
            r'\$\s*\d+(?:\.\d{1,2})?',
            r'\d+(?:\.\d{1,2})?\s*元',
        ]),
    }
|
||||
|
||||
def split_into_chunks(
    self,
    content: str,
    chunk_size: int = 1000,
    overlap: int = 100
) -> List[str]:
    """
    Split long text into overlapping chunks (e.g. for RAG indexing or LLM input).

    A chunk is cut at the last '。' or newline inside it when that boundary
    lies past the middle of the chunk; otherwise it is cut at `chunk_size`.

    Args:
        content: Text to split.
        chunk_size: Maximum number of characters per chunk.
        overlap: Number of characters shared between consecutive chunks.
            Negative values are treated as 0. If the overlap would prevent
            the window from advancing, the next chunk starts with no overlap.

    Returns:
        List of chunks; `[content]` when the text already fits in one chunk.

    Raises:
        ValueError: If `chunk_size` is not positive and the text does not
            fit in a single chunk.
    """
    total = len(content)
    if total <= chunk_size:
        return [content]
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    overlap = max(overlap, 0)
    chunks = []
    start = 0

    while start < total:
        end = start + chunk_size
        chunk = content[start:end]

        # Prefer ending on a sentence/line boundary, but only if it keeps
        # more than half of the chunk.
        if end < total:
            split_pos = max(chunk.rfind('。'), chunk.rfind('\n'))
            if split_pos > chunk_size // 2:
                chunk = chunk[:split_pos + 1]
                end = start + split_pos + 1

        chunks.append(chunk)
        if end >= total:
            break

        # Guarantee forward progress: a large overlap must never move the
        # window backwards or keep it in place (that used to loop forever).
        next_start = end - overlap
        start = next_start if next_start > start else end

    return chunks
|
||||
120
backend/app/core/document_parser/utils.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
文档解析工具函数
|
||||
"""
|
||||
import re
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
|
||||
def clean_text(text: str) -> str:
    """
    Clean text into a single-line, single-spaced, printable string.

    All whitespace runs (including newlines and tabs) become one space,
    non-printable characters are dropped, and the result is stripped.

    Args:
        text: Raw text; falsy values (None, "") yield "".

    Returns:
        str: The cleaned text.
    """
    if not text:
        return ""

    # Collapse every whitespace run to a single space.
    text = re.sub(r'\s+', ' ', text.strip())

    # Drop non-printable characters (no \n, \r or \t remain at this point).
    text = ''.join(char for char in text if char.isprintable())

    # Removing characters can leave doubled or leading/trailing spaces
    # (e.g. "a \x00 b"), so normalise spaces once more.
    return re.sub(r' {2,}', ' ', text).strip()
|
||||
|
||||
|
||||
def chunk_text(
    text: str,
    chunk_size: int = 1000,
    overlap: int = 100
) -> List[str]:
    """
    Split text into fixed-size chunks with a sliding overlap.

    Args:
        text: Text to split; falsy input yields [].
        chunk_size: Number of characters per chunk (must be > 0).
        overlap: Characters shared by consecutive chunks
            (must satisfy 0 <= overlap < chunk_size).

    Returns:
        List[str]: The chunks. The last chunk ends at the end of the text;
        no extra chunk consisting only of already-covered overlap is emitted.

    Raises:
        ValueError: If `chunk_size` or `overlap` is out of range (such values
            previously caused an endless loop).
    """
    if not text:
        return []
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    if overlap < 0 or overlap >= chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

    step = chunk_size - overlap
    text_length = len(text)
    chunks = []
    start = 0

    while start < text_length:
        end = start + chunk_size
        chunks.append(text[start:end])
        # Stop once the end of the text is covered; continuing would only
        # repeat the overlapping tail.
        if end >= text_length:
            break
        start += step

    return chunks
|
||||
|
||||
|
||||
def normalize_string(s: Any) -> str:
    """
    Convert an arbitrary value to a normalised string.

    Args:
        s: Input value.

    Returns:
        str: "" for None, clean_text(s) for strings, str(s) for anything
        else (numbers included).
    """
    if s is None:
        return ""
    # Only real strings go through text cleaning; every other value is
    # simply stringified.
    return clean_text(s) if isinstance(s, str) else str(s)
|
||||
|
||||
|
||||
def detect_encoding(file_path: str) -> Optional[str]:
    """
    Detect a file's encoding (simplified version).

    Only the first 10000 bytes of the file are passed to chardet.

    Args:
        file_path: Path of the file to inspect.

    Returns:
        Optional[str]: The encoding reported by chardet, or None when it
        cannot be detected or the file cannot be read.
    """
    import chardet

    try:
        with open(file_path, 'rb') as handle:
            sample = handle.read(10000)  # first 10000 bytes only
        return chardet.detect(sample).get('encoding')
    except Exception:
        return None
|
||||
|
||||
|
||||
def safe_get(d: Dict[str, Any], key: str, default: Any = None) -> Any:
    """
    Look up a key without ever raising.

    Args:
        d: Mapping to read from.
        key: Key to look up.
        default: Value returned when the key is missing or the lookup fails.

    Returns:
        Any: The stored value, or `default` if the key is absent or
        `d.get` raises (e.g. `d` is not a dict, or the key is unhashable).
    """
    try:
        value = d.get(key, default)
    except Exception:
        value = default
    return value
|
||||
541
backend/app/core/document_parser/xlsx_parser.py
Normal file
@@ -0,0 +1,541 @@
|
||||
"""
|
||||
Excel 文件解析器 - 解析 .xlsx 和 .xls 文件
|
||||
"""
|
||||
from typing import Any, Dict, List, Optional
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
import logging
|
||||
|
||||
from .base import BaseParser, ParseResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class XlsxParser(BaseParser):
|
||||
"""Excel 文件解析器"""
|
||||
|
||||
def __init__(self):
    """Initialise the base parser, then declare the handled extensions and parser name."""
    super().__init__()
    # File suffixes (lower-case, with dot) accepted by parse()/parse_all_sheets().
    self.supported_extensions = ['.xlsx', '.xls']
    self.parser_name = "excel_parser"
|
||||
|
||||
def parse(
    self,
    file_path: str,
    sheet_name: Optional[str | int] = 0,
    header_row: int = 0,
    **kwargs
) -> ParseResult:
    """
    Parse one worksheet of an Excel file.

    Sheet names are read with pandas; if pandas cannot open the workbook or
    reports no sheets, they are extracted directly from the workbook XML.
    The sheet itself is read with pandas.read_excel, falling back to the
    XML reader when that fails.

    Args:
        file_path: Path of the Excel file.
        sheet_name: Sheet name or index (negative indexes count from the
            end). An unknown name, an out-of-range index or None selects
            the first sheet.
        header_row: Zero-based index of the header row.
        **kwargs: Extra arguments forwarded to pandas.read_excel.

    Returns:
        ParseResult: success=True with the sheet data and metadata, or
        success=False with an error message. No exception is propagated
        from the parsing stage.
    """
    path = Path(file_path)

    # Basic file validation.
    if not path.exists():
        return ParseResult(
            success=False,
            error=f"File not found: {file_path}"
        )

    if path.suffix.lower() not in self.supported_extensions:
        return ParseResult(
            success=False,
            error=f"Unsupported file type: {path.suffix}"
        )

    file_size = path.stat().st_size
    if file_size == 0:
        return ParseResult(
            success=False,
            error=f"File is empty: {file_path}"
        )

    try:
        # Read the sheet list; the context manager closes the workbook handle.
        open_error = None
        sheet_names = []
        try:
            with pd.ExcelFile(file_path) as xls_file:
                sheet_names = list(xls_file.sheet_names)
        except Exception as e:
            open_error = e
            logger.warning(f"pandas 打开 Excel 失败,尝试 XML 方式: {e}")

        # pandas failed or returned nothing: try the raw workbook XML.
        if not sheet_names:
            sheet_names = self._extract_sheet_names_from_xml(file_path)

        if not sheet_names:
            if open_error is not None:
                # Keep the original diagnostic when the workbook could not be opened.
                return ParseResult(
                    success=False,
                    error=f"Failed to parse Excel file: {str(open_error)}"
                )
            return ParseResult(
                success=False,
                error=f"Excel 文件没有找到任何工作表: {file_path}"
            )

        # Resolve the requested sheet; anything invalid falls back to the
        # first sheet. The bounds check also covers negative indexes, so no
        # IndexError retry path is needed.
        target_sheet = sheet_names[0]
        if isinstance(sheet_name, int):
            if -len(sheet_names) <= sheet_name < len(sheet_names):
                target_sheet = sheet_names[sheet_name]
        elif isinstance(sheet_name, str) and sheet_name in sheet_names:
            target_sheet = sheet_name

        # Read the sheet, falling back to the XML reader.
        try:
            df = pd.read_excel(
                file_path,
                sheet_name=target_sheet,
                header=header_row,
                **kwargs
            )
        except Exception as e:
            logger.warning(f"pandas 读取 Excel 失败,尝试 XML 方式: {e}")
            df = self._read_excel_sheet_xml(file_path, sheet_name=target_sheet, header_row=header_row)

        if df is None:
            return ParseResult(
                success=False,
                error=f"工作表 '{target_sheet}' 读取失败"
            )

        # A sheet with headers but no rows (e.g. a template) is still valid;
        # only a sheet with neither rows nor columns is rejected.
        if df.empty and len(df.columns) == 0:
            return ParseResult(
                success=False,
                error=f"工作表 '{target_sheet}' 为空,请检查 Excel 文件内容"
            )

        data = self._df_to_dict(df)

        # Columns are reported even when there are no data rows, so that
        # header-only templates expose their fields.
        metadata = {
            "filename": path.name,
            "extension": path.suffix.lower(),
            "sheet_count": len(sheet_names),
            "sheet_names": sheet_names,
            "current_sheet": target_sheet,
            "row_count": len(df),
            "column_count": len(df.columns),
            "columns": df.columns.tolist(),
            "file_size": file_size
        }

        return ParseResult(
            success=True,
            data=data,
            metadata=metadata
        )

    except Exception as e:
        logger.error(f"解析 Excel 文件时出错: {str(e)}")
        return ParseResult(
            success=False,
            error=f"Failed to parse Excel file: {str(e)}"
        )
|
||||
|
||||
def parse_all_sheets(self, file_path: str, **kwargs) -> ParseResult:
    """
    Parse every worksheet of an Excel file.

    pandas.read_excel(sheet_name=None) is tried first; when it fails or
    returns nothing, sheets are read one by one through the XML reader
    (empty sheets are skipped in that fallback).

    Args:
        file_path: Path of the Excel file.
        **kwargs: Extra arguments forwarded to pandas.read_excel.

    Returns:
        ParseResult: On success, data is {"sheets": {name: sheet_dict}} and
        metadata describes the file; otherwise success=False with an error.
    """
    path = Path(file_path)

    # Basic file validation.
    if not path.exists():
        return ParseResult(
            success=False,
            error=f"File not found: {file_path}"
        )

    if path.suffix.lower() not in self.supported_extensions:
        return ParseResult(
            success=False,
            error=f"Unsupported file type: {path.suffix}"
        )

    file_size = path.stat().st_size
    if file_size == 0:
        return ParseResult(
            success=False,
            error=f"File is empty: {file_path}"
        )

    try:
        # First attempt: let pandas load all sheets at once.
        try:
            frames = pd.read_excel(file_path, sheet_name=None, **kwargs)
        except Exception as e:
            logger.warning(f"pandas 读取所有工作表失败: {e}")
            frames = None

        # Second attempt: read each sheet straight from the XML.
        if frames is None or len(frames) == 0:
            names = self._extract_sheet_names_from_xml(file_path)
            if not names:
                return ParseResult(
                    success=False,
                    error=f"无法读取 Excel 文件或文件为空: {file_path}"
                )
            frames = {}
            for name in names:
                frame = self._read_excel_sheet_xml(file_path, sheet_name=name, header_row=0)
                if frame is None or frame.empty:
                    continue
                frames[name] = frame

        if not frames:
            return ParseResult(
                success=False,
                error=f"无法读取 Excel 文件或文件为空: {file_path}"
            )

        # Convert every frame to a serialisable dict.
        sheets_data = {
            name: self._df_to_dict(frame) for name, frame in frames.items()
        }
        all_sheets = list(frames.keys())

        metadata = {
            "filename": path.name,
            "extension": path.suffix.lower(),
            "sheet_count": len(all_sheets),
            "sheet_names": all_sheets,
            "total_rows": sum(len(frame) for frame in frames.values()),
            "file_size": file_size
        }

        return ParseResult(
            success=True,
            data={"sheets": sheets_data},
            metadata=metadata
        )

    except Exception as e:
        logger.error(f"Failed to parse Excel file: {str(e)}")
        return ParseResult(
            success=False,
            error=f"Failed to parse Excel file: {str(e)}"
        )
|
||||
|
||||
def _get_sheet_names(self, file_path: str) -> List[str]:
    """
    Return the names of all worksheets in an Excel file.

    pandas is tried first; if it raises or reports no sheets, the names are
    extracted from the workbook XML instead.

    Args:
        file_path: Path of the Excel file.

    Returns:
        List of sheet names (whatever the XML fallback returns when pandas
        yields nothing).
    """
    sheet_names = []
    try:
        # The context manager closes the workbook handle, which was
        # previously left open.
        with pd.ExcelFile(file_path) as xls:
            sheet_names = list(xls.sheet_names)
    except Exception as e:
        logger.error(f"获取工作表名称失败: {str(e)}")

    if sheet_names:
        return sheet_names

    # pandas failed or returned an empty list: fall back to the XML.
    return self._extract_sheet_names_from_xml(file_path)
|
||||
|
||||
def _extract_sheet_names_from_xml(self, file_path: str) -> List[str]:
    """
    Extract worksheet names directly from the workbook XML inside the file.

    Used when pandas/openpyxl cannot list the sheets (for example because
    of non-standard elements such as mc:AlternateContent). Three strategies
    are tried in turn: lookup with known namespaces, a namespace-agnostic
    scan of all elements, and finally a regular expression over the raw XML.

    Args:
        file_path: Path of the Excel file.

    Returns:
        List of sheet names; [] when workbook.xml is missing or any error
        occurs (errors are logged, not raised).
    """
    import zipfile
    from xml.etree import ElementTree as ET

    # Namespaces tried for the <sheet> elements.
    COMMON_NAMESPACES = [
        'http://schemas.openxmlformats.org/spreadsheetml/2006/main',
        'http://schemas.openxmlformats.org/spreadsheetml/2005/main',
        'http://schemas.openxmlformats.org/spreadsheetml/2004/main',
        'http://schemas.openxmlformats.org/spreadsheetml/2003/main',
    ]

    try:
        with zipfile.ZipFile(file_path, 'r') as archive:
            # Locate workbook.xml under any of the known member names.
            content = None
            for candidate in ('xl/workbook.xml', 'xl\\workbook.xml', 'workbook.xml'):
                if candidate in archive.namelist():
                    content = archive.read(candidate)
                    logger.info(f"找到 workbook.xml at: {candidate}")
                    break

            if content is None:
                logger.warning(f"未找到 workbook.xml,文件列表: {archive.namelist()[:10]}")
                return []

            root = ET.fromstring(content)
            sheet_names = []

            # Strategy 1: namespace-qualified lookup.
            for ns in COMMON_NAMESPACES:
                for sheet in root.findall(f'.//{{{ns}}}sheet'):
                    label = sheet.get('name')
                    if label:
                        sheet_names.append(label)
                if sheet_names:
                    logger.info(f"使用命名空间 {ns} 提取工作表: {sheet_names}")
                    return sheet_names

            # Strategy 2: scan every element whose tag ends with "sheet".
            if not sheet_names:
                for elem in root.iter():
                    if not elem.tag.endswith('sheet'):
                        continue
                    label = elem.get('name')
                    if label:
                        sheet_names.append(label)
                    for child in elem:
                        if child.tag.endswith('sheet'):
                            label = child.get('name')
                            if label and label not in sheet_names:
                                sheet_names.append(label)

            # Strategy 3: regular expression over the raw XML text.
            if not sheet_names:
                import re
                xml_str = content.decode('utf-8', errors='ignore')
                matches = re.findall(r'<sheet\s+[^>]*name=["\']([^"\']+)["\']', xml_str, re.IGNORECASE)
                if matches:
                    sheet_names = matches
                    logger.info(f"使用正则提取工作表: {sheet_names}")

            logger.info(f"从 XML 提取工作表: {sheet_names}")
            return sheet_names

    except Exception as e:
        logger.error(f"从 XML 提取工作表名称失败: {e}")
        return []
|
||||
|
||||
def _read_excel_sheet_xml(self, file_path: str, sheet_name: str = None, header_row: int = 0) -> pd.DataFrame:
    """
    Read one worksheet straight from the XML inside an .xlsx archive.

    Used as a fallback when pandas cannot parse the file. Cell values are
    returned as text (shared/inline strings resolved), booleans for cells
    of type 'b', or None for cells without a value; numeric cells are not
    converted to numbers.

    Args:
        file_path: Path of the Excel file.
        sheet_name: Sheet to read; None or an unknown name selects the
            first sheet.
        header_row: Zero-based index of the header row.

    Returns:
        DataFrame whose columns are the header cell values (the column
        letter is used where a header cell is empty). If the header row is
        missing, the columns are the raw column letters.

    Raises:
        ValueError: If no sheet names can be extracted or the sheet's XML
            member is not found in the archive.
    """
    import zipfile
    from xml.etree import ElementTree as ET

    # Namespaces tried for SpreadsheetML elements.
    COMMON_NAMESPACES = [
        'http://schemas.openxmlformats.org/spreadsheetml/2006/main',
        'http://schemas.openxmlformats.org/spreadsheetml/2005/main',
        'http://schemas.openxmlformats.org/spreadsheetml/2004/main',
        'http://schemas.openxmlformats.org/spreadsheetml/2003/main',
    ]

    def local_name(tag):
        """Tag name without its '{namespace}' prefix ('' for non-string tags)."""
        return tag.rsplit('}', 1)[-1] if isinstance(tag, str) else ''

    def find_elements_with_ns(root, tag_name):
        """Find descendants named tag_name under known namespaces, else under any namespace."""
        results = []
        for ns in COMMON_NAMESPACES:
            results.extend(root.findall(f'.//{{{ns}}}{tag_name}'))
        if not results:
            results = [
                elem for elem in root.iter()
                if isinstance(elem.tag, str) and elem.tag.endswith('}' + tag_name)
            ]
        return results

    def rich_text(node):
        """Text of a string item: its direct <t> plus the <t> of every <r> run.

        Phonetic runs (<rPh>) are deliberately not included.
        """
        parts = []
        for child in node:
            kind = local_name(child.tag)
            if kind == 't':
                parts.append(child.text or '')
            elif kind == 'r':
                parts.extend(
                    grand.text or '' for grand in child if local_name(grand.tag) == 't'
                )
        return ''.join(parts)

    with zipfile.ZipFile(file_path, 'r') as z:
        sheet_names = self._extract_sheet_names_from_xml(file_path)
        if not sheet_names:
            raise ValueError("无法从 Excel 文件中找到工作表")

        # Pick the sheet to read.
        target_sheet = sheet_name if sheet_name and sheet_name in sheet_names else sheet_names[0]
        # NOTE(review): assumes the N-th sheet in workbook.xml is stored as
        # sheetN.xml; the authoritative mapping is in workbook.xml.rels and
        # can differ after sheets are reordered or deleted — confirm.
        sheet_index = sheet_names.index(target_sheet) + 1

        members = set(z.namelist())

        # Shared strings table (optional part).
        shared_strings = []
        for ss_path in ['xl/sharedStrings.xml', 'xl\\sharedStrings.xml', 'sharedStrings.xml']:
            if ss_path in members:
                try:
                    ss_root = ET.fromstring(z.read(ss_path))
                    shared_strings = [
                        rich_text(si) for si in find_elements_with_ns(ss_root, 'si')
                    ]
                    break
                except Exception as e:
                    logger.warning(f"读取 sharedStrings 失败: {e}")

        # Worksheet XML, under any of the known member names.
        sheet_content = None
        sheet_paths = [
            f'xl/worksheets/sheet{sheet_index}.xml',
            f'xl\\worksheets\\sheet{sheet_index}.xml',
            f'worksheets/sheet{sheet_index}.xml',
        ]
        for sp in sheet_paths:
            if sp in members:
                sheet_content = z.read(sp)
                break

        if sheet_content is None:
            raise ValueError(f"工作表文件 sheet{sheet_index}.xml 不存在")

        root = ET.fromstring(sheet_content)

        all_rows = []
        headers = {}

        for row in find_elements_with_ns(root, 'row'):
            row_idx = int(row.get('r', 0))  # 1-based row number from the XML
            row_cells = {}
            for cell in find_elements_with_ns(row, 'c'):
                # "B12" -> "B": cells are keyed by column letters.
                col_letters = ''.join(filter(str.isalpha, cell.get('r', '')))
                cell_type = cell.get('t', 'n')

                if cell_type == 'inlineStr':
                    # Inline strings keep their text in <is>, not in <v>.
                    inline = find_elements_with_ns(cell, 'is')
                    row_cells[col_letters] = rich_text(inline[0]) if inline else None
                    continue

                v_elements = find_elements_with_ns(cell, 'v')
                v = v_elements[0] if v_elements else None

                if v is not None and v.text:
                    if cell_type == 's':
                        # Shared string: <v> holds an index into the table.
                        try:
                            row_cells[col_letters] = shared_strings[int(v.text)]
                        except (ValueError, IndexError):
                            row_cells[col_letters] = v.text
                    elif cell_type == 'b':
                        row_cells[col_letters] = v.text == '1'
                    else:
                        row_cells[col_letters] = v.text
                else:
                    row_cells[col_letters] = None

            if row_idx == header_row + 1:
                headers = {**row_cells}
            elif row_idx > header_row + 1:
                all_rows.append(row_cells)

        # Build the DataFrame.
        if headers:
            col_order = list(headers.keys())
            # Passing `columns` keeps the header columns even when there are
            # no data rows, and fills header columns absent from the data
            # with NaN instead of raising KeyError.
            df = pd.DataFrame(all_rows, columns=col_order)
            df.columns = [
                headers[col] if headers[col] is not None else col
                for col in col_order
            ]
        else:
            df = pd.DataFrame(all_rows)

        return df
|
||||
|
||||
def _df_to_dict(self, df: pd.DataFrame) -> Dict[str, Any]:
|
||||
"""
|
||||
将 DataFrame 转换为字典,处理 NaN 值
|
||||
|
||||
Args:
|
||||
df: pandas DataFrame
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 转换后的字典
|
||||
"""
|
||||
# 将 NaN 替换为 None
|
||||
df = df.replace({pd.NA: None, float('nan'): None})
|
||||
|
||||
# 转换为字典列表(每一行一个字典)
|
||||
rows = df.to_dict(orient='records')
|
||||
|
||||
return {
|
||||
"columns": df.columns.tolist(),
|
||||
"rows": rows,
|
||||
"row_count": len(rows),
|
||||
"column_count": len(df.columns) if not df.empty else 0
|
||||
}
|
||||
14
backend/app/instruction/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
指令执行模块
|
||||
|
||||
支持文档智能操作交互,包括意图解析和指令执行
|
||||
"""
|
||||
from .intent_parser import IntentParser, intent_parser
|
||||
from .executor import InstructionExecutor, instruction_executor
|
||||
|
||||
__all__ = [
|
||||
"IntentParser",
|
||||
"intent_parser",
|
||||
"InstructionExecutor",
|
||||
"instruction_executor",
|
||||
]
|
||||
572
backend/app/instruction/executor.py
Normal file
@@ -0,0 +1,572 @@
|
||||
"""
|
||||
指令执行器模块
|
||||
|
||||
将自然语言指令转换为可执行操作
|
||||
"""
|
||||
import logging
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.services.template_fill_service import template_fill_service
|
||||
from app.services.rag_service import rag_service
|
||||
from app.services.markdown_ai_service import markdown_ai_service
|
||||
from app.core.database import mongodb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InstructionExecutor:
|
||||
"""指令执行器"""
|
||||
|
||||
def __init__(self):
    """Create an executor with no intent parser attached yet."""
    self.intent_parser = None  # assigned later via set_intent_parser
|
||||
|
||||
def set_intent_parser(self, intent_parser):
    """Attach the intent parser that execute() uses to interpret instructions."""
    self.intent_parser = intent_parser
|
||||
|
||||
async def execute(self, instruction: str, context: Dict[str, Any] = None) -> Dict[str, Any]:
    """
    Parse a natural-language instruction and run the matching handler.

    If no intent parser has been set, the module-level `intent_parser` from
    app.instruction.intent_parser is imported and used.

    Args:
        instruction: Natural-language instruction.
        context: Execution context (document information etc.); None is
            treated as an empty dict.

    Returns:
        The handler's result dict, or a failure dict when the parsed intent
        matches no known handler.
    """
    if self.intent_parser is None:
        from app.instruction.intent_parser import intent_parser
        self.intent_parser = intent_parser

    context = context or {}

    # Resolve the intent and its parameters.
    intent, params = await self.intent_parser.parse(instruction)

    # Intent name -> handler method name, checked in this order.
    routes = (
        ("extract", "_execute_extract"),
        ("fill_table", "_execute_fill_table"),
        ("summarize", "_execute_summarize"),
        ("question", "_execute_question"),
        ("search", "_execute_search"),
        ("compare", "_execute_compare"),
        ("edit", "_execute_edit"),
        ("transform", "_execute_transform"),
    )
    for name, method_name in routes:
        if intent == name:
            return await getattr(self, method_name)(params, context)

    return {
        "success": False,
        "error": f"未知意图类型: {intent}",
        "message": "无法理解该指令,请尝试更明确的描述"
    }
|
||||
|
||||
async def _execute_extract(self, params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行信息提取"""
|
||||
try:
|
||||
target_fields = params.get("field_refs", [])
|
||||
doc_ids = params.get("document_refs", [])
|
||||
|
||||
if not target_fields:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "未指定要提取的字段",
|
||||
"message": "请明确说明要提取哪些字段,如:'提取医院数量和床位数'"
|
||||
}
|
||||
|
||||
# 如果指定了文档,验证文档存在
|
||||
if doc_ids and "all_docs" not in doc_ids:
|
||||
valid_docs = []
|
||||
for doc_ref in doc_ids:
|
||||
doc_id = doc_ref.replace("doc_", "")
|
||||
doc = await mongodb.get_document(doc_id)
|
||||
if doc:
|
||||
valid_docs.append(doc)
|
||||
if not valid_docs:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "指定的文档不存在",
|
||||
"message": "请检查文档编号是否正确"
|
||||
}
|
||||
context["source_docs"] = valid_docs
|
||||
|
||||
# 构建字段列表
|
||||
fields = []
|
||||
for i, field_name in enumerate(target_fields):
|
||||
fields.append({
|
||||
"name": field_name,
|
||||
"cell": f"A{i+1}",
|
||||
"field_type": "text",
|
||||
"required": False
|
||||
})
|
||||
|
||||
# 调用填表服务
|
||||
result = await template_fill_service.fill_template(
|
||||
template_fields=fields,
|
||||
source_doc_ids=[doc.get("_id") for doc in context.get("source_docs", [])] if context.get("source_docs") else None,
|
||||
user_hint=f"请提取字段: {', '.join(target_fields)}"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"intent": "extract",
|
||||
"extracted_data": result.get("filled_data", {}),
|
||||
"fields": target_fields,
|
||||
"message": f"成功提取 {len(result.get('filled_data', {}))} 个字段"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"提取执行失败: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": f"提取失败: {str(e)}"
|
||||
}
|
||||
|
||||
async def _execute_fill_table(self, params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行填表操作"""
|
||||
try:
|
||||
template_file = context.get("template_file")
|
||||
if not template_file:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "未提供表格模板",
|
||||
"message": "请先上传要填写的表格模板"
|
||||
}
|
||||
|
||||
# 获取源文档
|
||||
source_docs = context.get("source_docs", [])
|
||||
source_doc_ids = [doc.get("_id") for doc in source_docs if doc.get("_id")]
|
||||
|
||||
# 获取字段
|
||||
fields = context.get("template_fields", [])
|
||||
|
||||
# 调用填表服务
|
||||
result = await template_fill_service.fill_template(
|
||||
template_fields=fields,
|
||||
source_doc_ids=source_doc_ids if source_doc_ids else None,
|
||||
source_file_paths=context.get("source_file_paths"),
|
||||
user_hint=params.get("user_hint"),
|
||||
template_id=template_file if isinstance(template_file, str) else None,
|
||||
template_file_type=params.get("template", {}).get("type", "xlsx")
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"intent": "fill_table",
|
||||
"result": result,
|
||||
"message": f"填表完成,成功填写 {len(result.get('filled_data', {}))} 个字段"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"填表执行失败: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": f"填表失败: {str(e)}"
|
||||
}
|
||||
|
||||
async def _execute_summarize(self, params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行摘要总结"""
|
||||
try:
|
||||
docs = context.get("source_docs", [])
|
||||
if not docs:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "没有可用的文档",
|
||||
"message": "请先上传要总结的文档"
|
||||
}
|
||||
|
||||
summaries = []
|
||||
for doc in docs[:5]: # 最多处理5个文档
|
||||
content = doc.get("content", "")[:5000] # 限制内容长度
|
||||
if content:
|
||||
summaries.append({
|
||||
"filename": doc.get("metadata", {}).get("original_filename", "未知"),
|
||||
"content_preview": content[:500] + "..." if len(content) > 500 else content
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"intent": "summarize",
|
||||
"summaries": summaries,
|
||||
"message": f"找到 {len(summaries)} 个文档可供参考"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"摘要执行失败: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": f"摘要生成失败: {str(e)}"
|
||||
}
|
||||
|
||||
async def _execute_question(self, params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行问答"""
|
||||
try:
|
||||
question = params.get("question", "")
|
||||
if not question:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "未提供问题",
|
||||
"message": "请输入要回答的问题"
|
||||
}
|
||||
|
||||
# 使用 RAG 检索相关文档
|
||||
docs = context.get("source_docs", [])
|
||||
rag_results = []
|
||||
|
||||
for doc in docs:
|
||||
doc_id = doc.get("_id", "")
|
||||
if doc_id:
|
||||
results = rag_service.retrieve_by_doc_id(doc_id, top_k=3)
|
||||
rag_results.extend(results)
|
||||
|
||||
# 构建上下文
|
||||
context_text = "\n\n".join([
|
||||
r.get("content", "") for r in rag_results[:5]
|
||||
]) if rag_results else ""
|
||||
|
||||
# 如果没有 RAG 结果,使用文档内容
|
||||
if not context_text:
|
||||
context_text = "\n\n".join([
|
||||
doc.get("content", "")[:3000] for doc in docs[:3] if doc.get("content")
|
||||
])
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"intent": "question",
|
||||
"question": question,
|
||||
"context_preview": context_text[:500] + "..." if len(context_text) > 500 else context_text,
|
||||
"message": "已找到相关上下文,可进行问答"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"问答执行失败: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": f"问答处理失败: {str(e)}"
|
||||
}
|
||||
|
||||
async def _execute_search(self, params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行搜索"""
|
||||
try:
|
||||
field_refs = params.get("field_refs", [])
|
||||
query = " ".join(field_refs) if field_refs else params.get("question", "")
|
||||
|
||||
if not query:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "未提供搜索关键词",
|
||||
"message": "请输入要搜索的关键词"
|
||||
}
|
||||
|
||||
# 使用 RAG 检索
|
||||
results = rag_service.retrieve(query, top_k=10, min_score=0.3)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"intent": "search",
|
||||
"query": query,
|
||||
"results": [
|
||||
{
|
||||
"content": r.get("content", "")[:200],
|
||||
"score": r.get("score", 0),
|
||||
"doc_id": r.get("doc_id", "")
|
||||
}
|
||||
for r in results[:10]
|
||||
],
|
||||
"message": f"找到 {len(results)} 条相关结果"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"搜索执行失败: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": f"搜索失败: {str(e)}"
|
||||
}
|
||||
|
||||
async def _execute_compare(self, params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行对比分析"""
|
||||
try:
|
||||
docs = context.get("source_docs", [])
|
||||
if len(docs) < 2:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "对比需要至少2个文档",
|
||||
"message": "请上传至少2个文档进行对比"
|
||||
}
|
||||
|
||||
# 提取文档基本信息
|
||||
comparison = []
|
||||
for i, doc in enumerate(docs[:5]):
|
||||
comparison.append({
|
||||
"index": i + 1,
|
||||
"filename": doc.get("metadata", {}).get("original_filename", "未知"),
|
||||
"doc_type": doc.get("doc_type", "未知"),
|
||||
"content_length": len(doc.get("content", "")),
|
||||
"has_tables": bool(doc.get("structured_data", {}).get("tables")),
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"intent": "compare",
|
||||
"comparison": comparison,
|
||||
"message": f"对比了 {len(comparison)} 个文档的基本信息"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"对比执行失败: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": f"对比分析失败: {str(e)}"
|
||||
}
|
||||
|
||||
async def _execute_edit(self, params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行文档编辑操作"""
|
||||
try:
|
||||
docs = context.get("source_docs", [])
|
||||
if not docs:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "没有可用的文档",
|
||||
"message": "请先上传要编辑的文档"
|
||||
}
|
||||
|
||||
doc = docs[0] # 默认编辑第一个文档
|
||||
content = doc.get("content", "")
|
||||
original_filename = doc.get("metadata", {}).get("original_filename", "未知文档")
|
||||
|
||||
if not content:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "文档内容为空",
|
||||
"message": "该文档没有可编辑的内容"
|
||||
}
|
||||
|
||||
# 使用 LLM 进行文本润色/编辑
|
||||
prompt = f"""请对以下文档内容进行编辑处理。
|
||||
|
||||
原文内容:
|
||||
{content[:8000]}
|
||||
|
||||
编辑要求:
|
||||
- 润色表述,使其更加专业流畅
|
||||
- 修正明显的语法错误
|
||||
- 保持原意不变
|
||||
- 只返回编辑后的内容,不要解释
|
||||
|
||||
请直接输出编辑后的内容:"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个专业的文本编辑助手。请直接输出编辑后的内容。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
from app.services.llm_service import llm_service
|
||||
response = await llm_service.chat(messages=messages, temperature=0.3, max_tokens=8000)
|
||||
edited_content = llm_service.extract_message_content(response)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"intent": "edit",
|
||||
"edited_content": edited_content,
|
||||
"original_filename": original_filename,
|
||||
"message": "文档编辑完成,内容已返回"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"编辑执行失败: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": f"编辑处理失败: {str(e)}"
|
||||
}
|
||||
|
||||
async def _execute_transform(self, params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
执行格式转换操作
|
||||
|
||||
支持:
|
||||
- Word -> Excel
|
||||
- Excel -> Word
|
||||
- Markdown -> Word
|
||||
- Word -> Markdown
|
||||
"""
|
||||
try:
|
||||
docs = context.get("source_docs", [])
|
||||
if not docs:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "没有可用的文档",
|
||||
"message": "请先上传要转换的文档"
|
||||
}
|
||||
|
||||
# 获取目标格式
|
||||
template_info = params.get("template", {})
|
||||
target_type = template_info.get("type", "")
|
||||
|
||||
if not target_type:
|
||||
# 尝试从指令中推断
|
||||
instruction = params.get("instruction", "")
|
||||
if "excel" in instruction.lower() or "xlsx" in instruction.lower():
|
||||
target_type = "xlsx"
|
||||
elif "word" in instruction.lower() or "docx" in instruction.lower():
|
||||
target_type = "docx"
|
||||
elif "markdown" in instruction.lower() or "md" in instruction.lower():
|
||||
target_type = "md"
|
||||
|
||||
if not target_type:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "未指定目标格式",
|
||||
"message": "请说明要转换成什么格式(如:转成Excel、转成Word)"
|
||||
}
|
||||
|
||||
doc = docs[0]
|
||||
content = doc.get("content", "")
|
||||
structured_data = doc.get("structured_data", {})
|
||||
original_filename = doc.get("metadata", {}).get("original_filename", "未知文档")
|
||||
|
||||
# 构建转换内容
|
||||
if structured_data.get("tables"):
|
||||
# 有表格数据,生成表格格式的内容
|
||||
tables = structured_data.get("tables", [])
|
||||
table_content = []
|
||||
for i, table in enumerate(tables[:3]): # 最多处理3个表格
|
||||
headers = table.get("headers", [])
|
||||
rows = table.get("rows", [])[:20] # 最多20行
|
||||
if headers:
|
||||
table_content.append(f"【表格 {i+1}】")
|
||||
table_content.append(" | ".join(str(h) for h in headers))
|
||||
table_content.append(" | ".join(["---"] * len(headers)))
|
||||
for row in rows:
|
||||
if isinstance(row, list):
|
||||
table_content.append(" | ".join(str(c) for c in row))
|
||||
elif isinstance(row, dict):
|
||||
table_content.append(" | ".join(str(row.get(h, "")) for h in headers))
|
||||
table_content.append("")
|
||||
|
||||
if target_type == "xlsx":
|
||||
# 生成 Excel 格式的数据(JSON)
|
||||
excel_data = []
|
||||
for table in tables[:1]: # 只处理第一个表格
|
||||
headers = table.get("headers", [])
|
||||
rows = table.get("rows", [])[:100]
|
||||
for row in rows:
|
||||
if isinstance(row, list):
|
||||
excel_data.append(dict(zip(headers, row)))
|
||||
elif isinstance(row, dict):
|
||||
excel_data.append(row)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"intent": "transform",
|
||||
"transform_type": "to_excel",
|
||||
"target_format": "xlsx",
|
||||
"excel_data": excel_data,
|
||||
"headers": headers,
|
||||
"message": f"已转换为 Excel 格式,包含 {len(excel_data)} 行数据"
|
||||
}
|
||||
elif target_type in ["docx", "word"]:
|
||||
# 生成 Word 格式的文本
|
||||
word_content = f"# {original_filename}\n\n"
|
||||
word_content += "\n".join(table_content)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"intent": "transform",
|
||||
"transform_type": "to_word",
|
||||
"target_format": "docx",
|
||||
"content": word_content,
|
||||
"message": "已转换为 Word 格式"
|
||||
}
|
||||
elif target_type == "md":
|
||||
# 生成 Markdown 格式
|
||||
md_content = f"# {original_filename}\n\n"
|
||||
md_content += "\n".join(table_content)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"intent": "transform",
|
||||
"transform_type": "to_markdown",
|
||||
"target_format": "md",
|
||||
"content": md_content,
|
||||
"message": "已转换为 Markdown 格式"
|
||||
}
|
||||
|
||||
# 无表格数据,使用纯文本内容转换
|
||||
if target_type == "xlsx":
|
||||
# 将文本内容转为 Excel 格式(每行作为一列)
|
||||
lines = [line.strip() for line in content.split("\n") if line.strip()][:100]
|
||||
excel_data = [{"行号": i+1, "内容": line} for i, line in enumerate(lines)]
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"intent": "transform",
|
||||
"transform_type": "to_excel",
|
||||
"target_format": "xlsx",
|
||||
"excel_data": excel_data,
|
||||
"headers": ["行号", "内容"],
|
||||
"message": f"已将文本内容转换为 Excel,包含 {len(excel_data)} 行"
|
||||
}
|
||||
elif target_type in ["docx", "word"]:
|
||||
return {
|
||||
"success": True,
|
||||
"intent": "transform",
|
||||
"transform_type": "to_word",
|
||||
"target_format": "docx",
|
||||
"content": content,
|
||||
"message": "文档内容已准备好,可下载为 Word 格式"
|
||||
}
|
||||
elif target_type == "md":
|
||||
# 简单的文本转 Markdown
|
||||
md_lines = []
|
||||
for line in content.split("\n"):
|
||||
line = line.strip()
|
||||
if line:
|
||||
# 简单处理:如果行不长且不是列表格式,作为段落
|
||||
if len(line) < 100 and not line.startswith(("-", "*", "1.", "2.", "3.")):
|
||||
md_lines.append(line)
|
||||
else:
|
||||
md_lines.append(line)
|
||||
else:
|
||||
md_lines.append("")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"intent": "transform",
|
||||
"transform_type": "to_markdown",
|
||||
"target_format": "md",
|
||||
"content": "\n".join(md_lines),
|
||||
"message": "已转换为 Markdown 格式"
|
||||
}
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"error": "不支持的目标格式",
|
||||
"message": f"暂不支持转换为 {target_type} 格式"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"格式转换失败: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": f"格式转换失败: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
# Module-level singleton instance, created once at import time.
instruction_executor = InstructionExecutor()
|
||||
242
backend/app/instruction/intent_parser.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""
|
||||
意图解析器模块
|
||||
|
||||
解析用户自然语言指令,识别意图和参数
|
||||
"""
|
||||
import re
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IntentParser:
    """Keyword- and regex-based parser for natural-language instructions.

    ``parse`` scores the text against ``INTENT_KEYWORDS`` to pick an intent
    and then pulls entities, document/field/template references and a few
    intent-specific parameters out of the text with regular expressions.
    """

    # Intent identifiers
    INTENT_EXTRACT = "extract"        # information extraction
    INTENT_FILL_TABLE = "fill_table"  # table filling
    INTENT_SUMMARIZE = "summarize"    # summarisation
    INTENT_QUESTION = "question"      # question answering
    INTENT_SEARCH = "search"          # search
    INTENT_COMPARE = "compare"        # comparison
    INTENT_TRANSFORM = "transform"    # format conversion
    INTENT_EDIT = "edit"              # document editing
    INTENT_UNKNOWN = "unknown"        # no keyword matched

    # Keywords whose presence in the text votes for an intent
    INTENT_KEYWORDS = {
        INTENT_EXTRACT: ["提取", "抽取", "获取", "找出", "查找", "识别", "找到"],
        INTENT_FILL_TABLE: ["填表", "填写", "填充", "录入", "导入到表格", "填写到"],
        INTENT_SUMMARIZE: ["总结", "摘要", "概括", "概述", "归纳", "提炼"],
        INTENT_QUESTION: ["问答", "回答", "解释", "什么是", "为什么", "如何", "怎样", "多少", "几个"],
        INTENT_SEARCH: ["搜索", "查找", "检索", "查询", "找"],
        INTENT_COMPARE: ["对比", "比较", "差异", "区别", "不同"],
        INTENT_TRANSFORM: ["转换", "转化", "变成", "转为", "导出"],
        INTENT_EDIT: ["修改", "编辑", "调整", "改写", "润色", "优化"],
    }

    # Entity patterns. Groups are non-capturing so that a match always yields
    # the whole entity (e.g. "12.5%") rather than just the fractional part.
    ENTITY_PATTERNS = {
        "number": [r"\d+", r"[一二三四五六七八九十百千万]+"],
        "date": [r"\d{4}年", r"\d{1,2}月", r"\d{1,2}日"],
        "percentage": [r"\d+(?:\.\d+)?%", r"\d+(?:\.\d+)?‰"],
        "currency": [r"\d+(?:\.\d+)?万元", r"\d+(?:\.\d+)?亿元", r"\d+(?:\.\d+)?元"],
    }

    # Upper bound on remembered parses; the parser is used as a process-wide
    # singleton, so an unbounded list would grow for the life of the process.
    MAX_HISTORY = 1000

    # Separators between field names in "提取A和B、C" style phrases
    _FIELD_SEPARATORS = re.compile(r"和|与|、|,|,|plus|\s+and\s+")

    def __init__(self):
        self.intent_history: List[Dict[str, Any]] = []

    async def parse(self, text: str) -> Tuple[str, Dict[str, Any]]:
        """Parse a natural-language instruction.

        Args:
            text: Raw user input.

        Returns:
            ``(intent, params)``. Blank input yields ``(INTENT_UNKNOWN, {})``
            and is not recorded in the history.
        """
        text = text.strip()
        if not text:
            return self.INTENT_UNKNOWN, {}

        intent = self._recognize_intent(text)
        params = self._extract_params(text, intent)

        # Record the parse and keep only the most recent MAX_HISTORY entries.
        self.intent_history.append({"text": text, "intent": intent})
        if len(self.intent_history) > self.MAX_HISTORY:
            del self.intent_history[:-self.MAX_HISTORY]

        logger.info(f"意图解析: text={text[:50]}..., intent={intent}, params={params}")

        return intent, params

    def _recognize_intent(self, text: str) -> str:
        """Return the intent whose keywords occur most often in ``text``.

        Each keyword contained in the text scores one point. Ties go to the
        intent declared first in ``INTENT_KEYWORDS``; no hit at all yields
        ``INTENT_UNKNOWN``.
        """
        intent_scores: Dict[str, float] = {}
        for intent, keywords in self.INTENT_KEYWORDS.items():
            score = sum(1 for keyword in keywords if keyword in text)
            if score > 0:
                intent_scores[intent] = score

        if not intent_scores:
            return self.INTENT_UNKNOWN

        return max(intent_scores, key=intent_scores.get)

    def _extract_params(self, text: str, intent: str) -> Dict[str, Any]:
        """Collect the generic parameters plus the intent-specific ones."""
        params: Dict[str, Any] = {
            "entities": self._extract_entities(text),
            "document_refs": self._extract_document_refs(text),
            "field_refs": self._extract_field_refs(text),
            "template_refs": self._extract_template_refs(text),
        }

        if intent == self.INTENT_QUESTION:
            params["question"] = text
            params["focus"] = self._extract_question_focus(text)
        elif intent == self.INTENT_FILL_TABLE:
            params["template"] = self._extract_template_info(text)
        elif intent == self.INTENT_EXTRACT:
            params["target_fields"] = self._extract_target_fields(text)

        return params

    def _extract_entities(self, text: str) -> Dict[str, List[str]]:
        """Return the distinct matches of every ``ENTITY_PATTERNS`` entry.

        Entity types without a match are omitted. Matches are de-duplicated
        and listed in order of first discovery.
        """
        entities: Dict[str, List[str]] = {}

        for entity_type, patterns in self.ENTITY_PATTERNS.items():
            # finditer + group(0) always yields the full match, independent
            # of any groups inside the pattern.
            matches = [
                match.group(0)
                for pattern in patterns
                for match in re.finditer(pattern, text)
            ]
            if matches:
                entities[entity_type] = list(dict.fromkeys(matches))

        return entities

    def _extract_document_refs(self, text: str) -> List[str]:
        """Return document references such as ``doc_1`` or ``all_docs``.

        Recognises "文档1", "doc1", "第1个文档" and "第1份" (Arabic digits only)
        and the quantifiers "所有" / "全部" / "整个".
        """
        refs: List[str] = []

        # "(?:文档|doc)" is a real alternation; a character class here would
        # match any run of the individual characters (e.g. "cod3").
        num_patterns = [
            r"(?:文档|doc)\s*(\d+)",
            r"第(\d+)个文档",
            r"第(\d+)份",
        ]
        lowered = text.lower()
        for pattern in num_patterns:
            refs.extend(f"doc_{number}" for number in re.findall(pattern, lowered))

        if any(kw in text for kw in ["所有", "全部", "整个"]):
            refs.append("all_docs")

        # The same document can be named by more than one pattern.
        return list(dict.fromkeys(refs))

    def _extract_field_refs(self, text: str) -> List[str]:
        """Return field names: quoted strings and words before 字段/列/数据."""
        fields: List[str] = []

        # Quoted names; straight, curly and corner quotes are all accepted.
        quoted = re.findall(r"['\"“‘『「]([^'\"“”‘’』」]+)['\"”’』」]", text)
        fields.extend(quoted)

        # "xxx字段", "xxx列", "xxx数据"
        field_patterns = [
            r"([^\s]+)字段",
            r"([^\s]+)列",
            r"([^\s]+)数据",
        ]
        for pattern in field_patterns:
            fields.extend(re.findall(pattern, text))

        return list(dict.fromkeys(fields))

    def _extract_template_refs(self, text: str) -> List[str]:
        """Return template references ("xxx模板", "xxx表格", the digits of "表N")."""
        templates: List[str] = []

        template_patterns = [
            r"([^\s]+模板)",
            r"表(\d+)",
            r"([^\s]+表格)",
        ]
        for pattern in template_patterns:
            templates.extend(re.findall(pattern, text))

        return list(dict.fromkeys(templates))

    def _extract_question_focus(self, text: str) -> Optional[str]:
        """Return the subject of a question, or None when no form matches.

        Handles "什么是XXX", "XXX是什么" and "XXX有多少".
        """
        focus_patterns = [
            r"什么是\s*([^??]+)",
            r"([^??]+?)\s*是什么",
            r"([^??]+)有多少",
        ]
        for pattern in focus_patterns:
            match = re.search(pattern, text)
            if match:
                focus = match.group(1).strip()
                if focus:
                    return focus

        return None

    def _extract_template_info(self, text: str) -> Optional[Dict[str, str]]:
        """Return ``{"type": "xlsx" | "docx"}`` inferred from keywords, else None."""
        template_info: Dict[str, str] = {}
        lowered = text.lower()

        if "excel" in lowered or "xlsx" in lowered or "电子表格" in text:
            template_info["type"] = "xlsx"
        elif "word" in lowered or "docx" in lowered or "文档" in text:
            template_info["type"] = "docx"

        return template_info if template_info else None

    def _extract_target_fields(self, text: str) -> List[str]:
        """Return the field names listed after "提取" / "抽取".

        The phrase following the keyword (up to the next sentence
        punctuation) is split on 和/与/、/commas/"plus"/" and ", so that
        "提取姓名和年龄" yields both names, including the last one.
        """
        fields: List[str] = []

        for match in re.finditer(r"(?:提取|抽取)([^。;;!!??\n]+)", text):
            for part in self._FIELD_SEPARATORS.split(match.group(1)):
                part = part.strip()
                if part:
                    fields.append(part)

        return list(dict.fromkeys(fields))

    def get_intent_history(self) -> List[Dict[str, Any]]:
        """Return the recorded parse history (the live list, not a copy)."""
        return self.intent_history

    def clear_history(self):
        """Forget all recorded parses."""
        self.intent_history = []
|
||||
|
||||
|
||||
# Module-level singleton instance, created once at import time.
intent_parser = IntentParser()
|
||||
@@ -1,19 +1,263 @@
|
||||
from fastapi import FastAPI
|
||||
from config import settings
|
||||
"""
|
||||
FastAPI 应用主入口
|
||||
"""
|
||||
# ========== 压制 MongoDB 疯狂刷屏日志 ==========
|
||||
import logging
|
||||
logging.getLogger("pymongo").setLevel(logging.WARNING)
|
||||
logging.getLogger("pymongo.topology").setLevel(logging.WARNING)
|
||||
logging.getLogger("urllib3").setLevel(logging.WARNING)
|
||||
# ==============================================
|
||||
|
||||
import logging
|
||||
import logging.handlers
|
||||
import sys
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Callable
|
||||
from functools import wraps
|
||||
|
||||
from fastapi import FastAPI, Request, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.config import settings
|
||||
from app.api import api_router
|
||||
from app.core.database import mysql_db, mongodb, redis_db
|
||||
|
||||
# ==================== 日志配置 ====================
|
||||
|
||||
def setup_logging():
    """Configure the root logger with console, file and error-file handlers.

    The console level follows ``settings.DEBUG`` (DEBUG when set, INFO
    otherwise). ``data/logs/app.log`` receives everything from DEBUG up and
    ``data/logs/error.log`` receives ERROR and above; both rotate at 10 MB
    with five backups. The log directory is relative to the current working
    directory.

    Returns:
        The configured root logger.
    """
    from pathlib import Path

    # Console verbosity
    log_level = logging.DEBUG if settings.DEBUG else logging.INFO

    # Log directory
    log_dir = Path("data/logs")
    log_dir.mkdir(parents=True, exist_ok=True)

    # Log file paths
    log_file = log_dir / "app.log"
    error_log_file = log_dir / "error.log"

    # Console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(log_level)
    console_formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)-8s | %(name)s:%(lineno)d | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S"
    )
    console_handler.setFormatter(console_formatter)

    # File handler (all records)
    file_handler = logging.handlers.RotatingFileHandler(
        log_file,
        maxBytes=10 * 1024 * 1024,  # 10MB
        backupCount=5,
        encoding="utf-8"
    )
    file_handler.setLevel(logging.DEBUG)
    file_formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)-8s | %(name)s:%(lineno)d | %(funcName)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S"
    )
    file_handler.setFormatter(file_formatter)

    # Error file handler (ERROR and above only)
    error_file_handler = logging.handlers.RotatingFileHandler(
        error_log_file,
        maxBytes=10 * 1024 * 1024,  # 10MB
        backupCount=5,
        encoding="utf-8"
    )
    error_file_handler.setLevel(logging.ERROR)
    error_file_handler.setFormatter(file_formatter)

    # Root logger: detach and close whatever was installed before, so a
    # repeated call does not leave open file handles behind.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG)
    for stale_handler in list(root_logger.handlers):
        root_logger.removeHandler(stale_handler)
        stale_handler.close()
    root_logger.addHandler(console_handler)
    root_logger.addHandler(file_handler)
    root_logger.addHandler(error_file_handler)

    # Quieten chatty third-party loggers
    for lib in ["uvicorn", "uvicorn.access", "fastapi", "httpx", "sqlalchemy"]:
        logging.getLogger(lib).setLevel(logging.WARNING)

    root_logger.info(f"日志系统初始化完成 | 日志目录: {log_dir}")
    root_logger.info(f"主日志文件: {log_file} | 错误日志: {error_log_file}")

    return root_logger
|
||||
|
||||
# 初始化日志
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ==================== 请求日志中间件 ====================
|
||||
|
||||
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Middleware that logs every request and its response.

    Each request gets a short id (the first 8 characters of a UUID4) that is
    stored on ``request.state.request_id``, included in the log lines and
    returned to the client in the ``X-Request-ID`` response header.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Log the request, run the rest of the stack, log the outcome.

        Exceptions raised downstream are logged and re-raised unchanged.
        """
        import time

        # Generate the request id
        request_id = str(uuid.uuid4())[:8]
        request.state.request_id = request_id

        # Log the incoming request
        logger.info(f"→ [{request_id}] {request.method} {request.url.path}")

        # Monotonic clock, so wall-clock adjustments cannot skew the duration.
        started_at = time.perf_counter()

        try:
            response = await call_next(request)
            elapsed_ms = (time.perf_counter() - started_at) * 1000

            # Log the response together with the measured duration
            logger.info(
                f"← [{request_id}] {request.method} {request.url.path} "
                f"| 状态: {response.status_code} | 耗时: {elapsed_ms:.1f}ms"
            )

            # Expose the request id to the client
            response.headers["X-Request-ID"] = request_id
            return response

        except Exception as e:
            logger.error(f"✗ [{request_id}] {request.method} {request.url.path} | 异常: {str(e)}")
            raise
|
||||
|
||||
|
||||
# ==================== 请求追踪装饰器 ====================
|
||||
|
||||
def log_async_function(func: Callable) -> Callable:
    """Decorator that logs the start, completion and failure of a coroutine.

    Start and completion are logged at DEBUG level. An exception is logged
    at ERROR level and then re-raised unchanged.
    """
    @wraps(func)
    async def traced(*args, **kwargs):
        label = func.__name__
        logger.debug(f"→ {label} 开始执行")
        try:
            outcome = await func(*args, **kwargs)
            logger.debug(f"← {label} 执行完成")
        except Exception as exc:
            logger.error(f"✗ {label} 执行失败: {str(exc)}")
            raise
        return outcome

    return traced
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler.

    On startup the MySQL, MongoDB and Redis connections are initialised; a
    failure of one store is logged and does not prevent the others (or the
    application) from starting. On shutdown every connection is closed, and a
    failure to close one store no longer prevents the remaining ones from
    being closed.
    """
    # Startup
    logger.info("正在初始化数据库连接...")

    # MySQL
    try:
        await mysql_db.init_db()
        logger.info("✓ MySQL 初始化成功")
    except Exception as e:
        logger.error(f"✗ MySQL 初始化失败: {e}")

    # MongoDB
    try:
        await mongodb.connect()
        await mongodb.create_indexes()
        logger.info("✓ MongoDB 初始化成功")
    except Exception as e:
        logger.error(f"✗ MongoDB 初始化失败: {e}")

    # Redis
    try:
        await redis_db.connect()
        logger.info("✓ Redis 初始化成功")
    except Exception as e:
        logger.error(f"✗ Redis 初始化失败: {e}")

    logger.info("数据库初始化完成")
    yield

    # Shutdown: close each store independently so one error cannot leak the
    # other connections.
    logger.info("正在关闭数据库连接...")
    for label, close in (
        ("MySQL", mysql_db.close),
        ("MongoDB", mongodb.close),
        ("Redis", redis_db.close),
    ):
        try:
            await close()
        except Exception as e:
            logger.error(f"✗ {label} 连接关闭失败: {e}")
    logger.info("数据库连接已关闭")
|
||||
|
||||
|
||||
# Create the FastAPI application instance.
# (The stale, comma-less duplicate of the ``openapi_url`` keyword left over
# from the previous revision is dropped; it made the call a syntax error.)
app = FastAPI(
    title=settings.APP_NAME,
    description="基于大语言模型的文档理解与多源数据融合系统",
    version="1.0.0",
    openapi_url=f"{settings.API_V1_STR}/openapi.json",
    docs_url=f"{settings.API_V1_STR}/docs",
    redoc_url=f"{settings.API_V1_STR}/redoc",
    lifespan=lifespan,  # startup/shutdown handling
)

# CORS middleware.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed requests; restrict allow_origins to the known
# front-end origins before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Request logging middleware
app.add_middleware(RequestLoggingMiddleware)

# Register the API routes
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
|
||||
|
||||
@app.get("/")
async def root():
    """Root endpoint: welcome message, status, version and docs location."""
    # The stale, comma-less duplicate of the "debug_mode" entry left over from
    # the previous revision is dropped; it made the dict literal invalid.
    return {
        "message": f"Welcome to {settings.APP_NAME}",
        "status": "online",
        "version": "1.0.0",
        "debug_mode": settings.DEBUG,
        "api_docs": f"{settings.API_V1_STR}/docs"
    }
|
||||
|
||||
|
||||
@app.get("/health")
async def health_check():
    """Health-check endpoint.

    Reports, for each backing store, whether its client object/flag is
    currently truthy ("connected") or not ("disconnected"). The top-level
    ``status`` is always "healthy".
    """
    probes = (
        ("mysql", mysql_db.async_engine),
        ("mongodb", mongodb.client),
        ("redis", redis_db.is_connected),
    )
    databases = {
        name: ("connected" if handle else "disconnected")
        for name, handle in probes
    }

    return {
        "status": "healthy",
        "service": settings.APP_NAME,
        "databases": databases
    }
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn

    # Single launch call. The stale ``uvicorn.run("main:app", ...)`` line from
    # the previous revision is dropped: it pointed at the old module path and
    # would have blocked before this call was ever reached.
    uvicorn.run(
        "app.main:app",
        host="127.0.0.1",
        port=8000,
        reload=settings.DEBUG
    )
|
||||
|
||||
18
backend/app/models/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
数据模型模块
|
||||
|
||||
定义数据库表结构和数据模型
|
||||
"""
|
||||
from app.core.database.mysql import (
|
||||
Base,
|
||||
DocumentField,
|
||||
DocumentTable,
|
||||
TaskRecord,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"DocumentTable",
|
||||
"DocumentField",
|
||||
"TaskRecord",
|
||||
]
|
||||
172
backend/app/models/document.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
文档数据模型
|
||||
|
||||
定义文档相关的 Pydantic 模型
|
||||
"""
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DocumentType(str, Enum):
    """Document type enumeration; each value is the lower-case type name."""
    DOCX = "docx"
    XLSX = "xlsx"
    MD = "md"
    TXT = "txt"
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
    """Task status enumeration."""
    PENDING = "pending"
    PROCESSING = "processing"
    SUCCESS = "success"
    FAILURE = "failure"
|
||||
|
||||
|
||||
# ==================== 解析结果模型 ====================
|
||||
|
||||
class DocumentMetadata(BaseModel):
    """Document metadata.

    Only ``filename`` and ``extension`` are required; ``file_size`` defaults
    to 0 and every other field defaults to None.
    """
    filename: str
    extension: str
    file_size: int = 0
    doc_type: Optional[str] = None
    # presumably only populated for spreadsheet documents - TODO confirm
    sheet_count: Optional[int] = None
    sheet_names: Optional[List[str]] = None
    current_sheet: Optional[str] = None
    row_count: Optional[int] = None
    column_count: Optional[int] = None
    columns: Optional[List[str]] = None
    encoding: Optional[str] = None
|
||||
|
||||
|
||||
class ParseResultData(BaseModel):
    """Parsed tabular data: column names, row dicts and their counts."""
    # default_factory gives every instance its own empty list
    columns: List[str] = Field(default_factory=list)
    rows: List[Dict[str, Any]] = Field(default_factory=list)
    # Stored independently of ``rows``/``columns``; the model does not
    # derive or validate these counts.
    row_count: int = 0
    column_count: int = 0
|
||||
|
||||
|
||||
class ParseResult(BaseModel):
    """Document parse result.

    ``success`` is required; ``data``, ``metadata`` and ``error`` are all
    optional and default to None.
    """
    success: bool
    data: Optional[ParseResultData] = None
    metadata: Optional[DocumentMetadata] = None
    error: Optional[str] = None
|
||||
|
||||
|
||||
# ==================== 存储模型 ====================
|
||||
|
||||
class DocumentStore(BaseModel):
    """Document storage model."""
    doc_id: str
    doc_type: DocumentType
    content: str
    metadata: DocumentMetadata
    structured_data: Optional[Dict[str, Any]] = None
    # NOTE(review): datetime.utcnow returns naive datetimes and is deprecated
    # since Python 3.12; both timestamps are set independently at creation.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class RAGEntry(BaseModel):
    """RAG index entry: a table field, its description and its embedding."""
    table_name: str
    field_name: str
    field_description: str
    # presumably the embedding of field_description - TODO confirm
    embedding: List[float]
    metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
# ==================== 任务模型 ====================
|
||||
|
||||
class TaskCreate(BaseModel):
    """Task creation request."""
    task_type: str
    # Free-form parameters; no schema is enforced by this model.
    input_params: Dict[str, Any]
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
    """Task status response."""
    task_id: str
    status: TaskStatus
    # NOTE(review): looks like a percentage, but no range is enforced - confirm
    progress: int = 0
    message: Optional[str] = None
    result: Optional[Any] = None
    error: Optional[str] = None
|
||||
|
||||
|
||||
# ==================== 模板填写模型 ====================
|
||||
|
||||
class TemplateField(BaseModel):
    """A single fillable field of a template.

    The ``description`` strings below are part of the generated schema and
    are therefore left in their original language.
    """
    cell: str = Field(description="单元格位置, 如 A1")
    name: str = Field(description="字段名称")
    field_type: str = Field(default="text", description="字段类型: text/number/date")
    required: bool = Field(default=True, description="是否必填")
|
||||
|
||||
|
||||
class TemplateSheet(BaseModel):
    """A template worksheet: its name and the fields it contains."""
    name: str
    fields: List[TemplateField]
|
||||
|
||||
|
||||
class TemplateInfo(BaseModel):
    """Template information: file location, file type and worksheets."""
    file_path: str
    file_type: str  # xlsx/docx
    sheets: List[TemplateSheet]
|
||||
|
||||
|
||||
class FillRequest(BaseModel):
    """Template fill request."""
    template_path: str
    template_fields: List[TemplateField]
    # Optional; how a missing value is interpreted is decided by the caller.
    source_doc_ids: Optional[List[str]] = None
|
||||
|
||||
|
||||
class FillResult(BaseModel):
    """Template fill result."""
    success: bool
    filled_data: Dict[str, Any]
    fill_details: List[Dict[str, Any]]
    # default_factory gives every instance its own empty list
    source_documents: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ==================== API 响应模型 ====================
|
||||
|
||||
class UploadResponse(BaseModel):
    """Upload response; all four fields are required."""
    task_id: str
    file_count: int
    message: str
    status_url: str
|
||||
|
||||
|
||||
class AnalyzeResponse(BaseModel):
    """Analysis response; everything except ``success`` is optional."""
    success: bool
    analysis: Optional[str] = None
    structured_data: Optional[Dict[str, Any]] = None
    model: Optional[str] = None
    error: Optional[str] = None
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
    """Query request."""
    user_intent: str
    table_name: Optional[str] = None
    # Validated to lie between 1 and 20 inclusive; defaults to 5.
    top_k: int = Field(default=5, ge=1, le=20)
|
||||
|
||||
|
||||
class QueryResponse(BaseModel):
    """Query response; everything except ``success`` is optional."""
    success: bool
    sql_query: Optional[str] = None
    results: Optional[List[Dict[str, Any]]] = None
    rag_context: Optional[List[str]] = None
    error: Optional[str] = None
|
||||
349
backend/app/services/chart_generator_service.py
Normal file
@@ -0,0 +1,349 @@
|
||||
"""
|
||||
图表生成服务 - 根据结构化数据生成图表
|
||||
"""
|
||||
import io
|
||||
import base64
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
|
||||
# 使用字体辅助模块配置中文字体
|
||||
from app.services.font_helper import configure_matplotlib_fonts
|
||||
|
||||
configure_matplotlib_fonts()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChartGeneratorService:
|
||||
"""图表生成服务类"""
|
||||
|
||||
def __init__(self):
    """Resolve the chart output directory and create it if it is missing.

    The directory is ``data/charts`` under the folder three levels above
    this file.
    """
    module_file = Path(__file__).resolve()
    base_dir = module_file.parent.parent.parent
    self.output_dir = base_dir / "data" / "charts"
    self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def generate_charts_from_analysis(
    self,
    structured_data: Dict[str, Any]
) -> Dict[str, Any]:
    """Generate charts from extracted structured data.

    Args:
        structured_data: Structured data extracted from an AI analysis
            result. ``success`` must be truthy; the chart inputs are read
            from ``data`` (keys ``numeric_data``, ``categorical_data``,
            ``time_series_data``, ``comparison_data``, ``table_data`` and
            ``metadata``).

    Returns:
        On success a dict with ``charts``, ``statistics``, ``metadata`` and
        ``data_source``; otherwise ``{"success": False, "error": ...}``.
        Only chart groups whose input is non-empty appear in ``charts``.
    """
    if not structured_data.get("success"):
        return {
            "success": False,
            "error": structured_data.get("error", "数据提取失败")
        }

    # "data" may be present but None; treat that like an empty payload
    # instead of raising outside the try block below.
    data = structured_data.get("data") or {}
    charts = {}
    statistics = {}

    try:
        # 1. Numeric data charts
        numeric_data = data.get("numeric_data", [])
        if numeric_data:
            charts["numeric_charts"] = self._create_numeric_charts(numeric_data)
            statistics["numeric_summary"] = self._create_numeric_summary(numeric_data)

        # 2. Categorical data charts
        categorical_data = data.get("categorical_data", [])
        if categorical_data:
            charts["categorical_charts"] = self._create_categorical_charts(categorical_data)

        # 3. Time-series chart
        time_series_data = data.get("time_series_data", [])
        if time_series_data:
            charts["time_series_chart"] = self._create_time_series_chart(time_series_data)

        # 4. Comparison chart
        comparison_data = data.get("comparison_data", [])
        if comparison_data:
            charts["comparison_chart"] = self._create_comparison_chart(comparison_data)

        # 5. Table preview
        table_data = data.get("table_data")
        if table_data:
            charts["table_preview"] = self._create_table_preview(table_data)

        # Metadata is passed through unchanged
        metadata = data.get("metadata", {})

        return {
            "success": True,
            "charts": charts,
            "statistics": statistics,
            "metadata": metadata,
            "data_source": "ai_analysis"
        }

    except Exception as e:
        logger.error(f"生成图表失败: {str(e)}", exc_info=True)
        return {
            "success": False,
            "error": str(e)
        }
|
||||
|
||||
def _create_numeric_charts(self, numeric_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""创建数值型数据图表"""
|
||||
charts = []
|
||||
|
||||
# 提取数值和标签
|
||||
names = [item.get("name", f"项{i}") for i, item in enumerate(numeric_data)]
|
||||
values = [item.get("value", 0) for item in numeric_data]
|
||||
|
||||
if not values:
|
||||
return charts
|
||||
|
||||
# 1. 柱状图
|
||||
try:
|
||||
fig, ax = plt.subplots(figsize=(12, 7))
|
||||
colors = plt.cm.Set3(np.linspace(0, 1, len(values)))
|
||||
bars = ax.bar(names, values, color=colors, alpha=0.8, edgecolor='black', linewidth=0.5)
|
||||
|
||||
# 添加数值标签
|
||||
for bar, value in zip(bars, values):
|
||||
height = bar.get_height()
|
||||
ax.text(bar.get_x() + bar.get_width() / 2., height,
|
||||
f'{value:,.0f}',
|
||||
ha='center', va='bottom', fontsize=9, fontweight='bold')
|
||||
|
||||
ax.set_xlabel('项目', fontsize=10, labelpad=10, fontweight='bold')
|
||||
ax.set_ylabel('数值', fontsize=10, labelpad=10, fontweight='bold')
|
||||
ax.set_title('数值型数据对比', fontsize=12, fontweight='bold', pad=15)
|
||||
ax.set_xticklabels(names, rotation=30, ha='right', fontsize=9)
|
||||
ax.tick_params(axis='both', which='major', labelsize=9)
|
||||
plt.grid(axis='y', alpha=0.3)
|
||||
plt.tight_layout(pad=1.5)
|
||||
|
||||
img_base64 = self._figure_to_base64(fig)
|
||||
charts.append({
|
||||
"type": "bar",
|
||||
"title": "数值型数据对比",
|
||||
"image": img_base64,
|
||||
"data": [{"name": n, "value": v} for n, v in zip(names, values)]
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"创建柱状图失败: {str(e)}")
|
||||
|
||||
# 2. 饼图
|
||||
if len(values) > 0 and len(values) <= 10:
|
||||
try:
|
||||
fig, ax = plt.subplots(figsize=(10, 10))
|
||||
wedges, texts, autotexts = ax.pie(values, labels=names, autopct='%1.1f%%',
|
||||
startangle=90, colors=plt.cm.Set3.colors[:len(values)])
|
||||
|
||||
for autotext in autotexts:
|
||||
autotext.set_color('white')
|
||||
autotext.set_fontsize(9)
|
||||
autotext.set_fontweight('bold')
|
||||
|
||||
ax.set_title('数值型数据占比', fontsize=12, fontweight='bold', pad=15)
|
||||
|
||||
img_base64 = self._figure_to_base64(fig)
|
||||
charts.append({
|
||||
"type": "pie",
|
||||
"title": "数值型数据占比",
|
||||
"image": img_base64,
|
||||
"data": [{"name": n, "value": v} for n, v in zip(names, values)]
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"创建饼图失败: {str(e)}")
|
||||
|
||||
return charts
|
||||
|
||||
def _create_categorical_charts(self, categorical_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""创建分类数据图表"""
|
||||
charts = []
|
||||
|
||||
# 提取数据
|
||||
names = [item.get("name", f"类{i}") for i, item in enumerate(categorical_data)]
|
||||
counts = [item.get("count", 1) for item in categorical_data]
|
||||
|
||||
if not names or not counts:
|
||||
return charts
|
||||
|
||||
# 水平条形图
|
||||
try:
|
||||
fig, ax = plt.subplots(figsize=(10, max(6, len(names) * 0.8)))
|
||||
y_pos = np.arange(len(names))
|
||||
|
||||
bars = ax.barh(y_pos, counts, align='center', color='#10b981', alpha=0.8, edgecolor='black', linewidth=0.5)
|
||||
|
||||
# 添加数值标签
|
||||
for bar, count in zip(bars, counts):
|
||||
width = bar.get_width()
|
||||
ax.text(width, bar.get_y() + bar.get_height() / 2.,
|
||||
f'{count}',
|
||||
ha='left', va='center', fontsize=10, fontweight='bold')
|
||||
|
||||
ax.set_yticks(y_pos)
|
||||
ax.set_yticklabels(names, fontsize=10)
|
||||
ax.invert_yaxis()
|
||||
ax.set_xlabel('数量', fontsize=10, labelpad=10, fontweight='bold')
|
||||
ax.set_title('分类数据分布', fontsize=12, fontweight='bold', pad=15)
|
||||
ax.tick_params(axis='both', which='major', labelsize=9)
|
||||
ax.grid(axis='x', alpha=0.3)
|
||||
plt.tight_layout(pad=1.5)
|
||||
|
||||
img_base64 = self._figure_to_base64(fig)
|
||||
charts.append({
|
||||
"type": "barh",
|
||||
"title": "分类数据分布",
|
||||
"image": img_base64,
|
||||
"data": [{"name": n, "count": c} for n, c in zip(names, counts)]
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"创建分类图表失败: {str(e)}")
|
||||
|
||||
return charts
|
||||
|
||||
def _create_time_series_chart(self, time_series_data: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||
"""创建时间序列图表"""
|
||||
if not time_series_data:
|
||||
return None
|
||||
|
||||
try:
|
||||
names = [item.get("name", f"时间{i}") for i, item in enumerate(time_series_data)]
|
||||
values = [item.get("value", 0) for item in time_series_data]
|
||||
|
||||
if len(values) < 2:
|
||||
return None
|
||||
|
||||
fig, ax = plt.subplots(figsize=(14, 7))
|
||||
|
||||
# 绘制折线图和柱状图
|
||||
x_pos = np.arange(len(names))
|
||||
bars = ax.bar(x_pos, values, width=0.4, label='数值', color='#3b82f6', alpha=0.7)
|
||||
|
||||
# 添加折线
|
||||
line = ax.plot(x_pos, values, 'o-', color='#ef4444', linewidth=2.5, markersize=8, label='趋势')
|
||||
|
||||
ax.set_xticks(x_pos)
|
||||
ax.set_xticklabels(names, rotation=30, ha='right', fontsize=9)
|
||||
ax.set_ylabel('数值', fontsize=10, labelpad=10, fontweight='bold')
|
||||
ax.set_title('时间序列数据', fontsize=12, fontweight='bold', pad=15)
|
||||
ax.legend(loc='best', fontsize=9)
|
||||
ax.tick_params(axis='both', which='major', labelsize=9)
|
||||
ax.grid(True, alpha=0.3)
|
||||
plt.tight_layout(pad=1.5)
|
||||
|
||||
img_base64 = self._figure_to_base64(fig)
|
||||
return {
|
||||
"type": "time_series",
|
||||
"title": "时间序列数据",
|
||||
"image": img_base64,
|
||||
"data": [{"name": n, "value": v} for n, v in zip(names, values)]
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"创建时间序列图表失败: {str(e)}")
|
||||
return None
|
||||
|
||||
def _create_comparison_chart(self, comparison_data: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||
"""创建对比图表"""
|
||||
if not comparison_data:
|
||||
return None
|
||||
|
||||
try:
|
||||
names = [item.get("name", f"对比{i}") for i, item in enumerate(comparison_data)]
|
||||
values = [item.get("value", 0) for item in comparison_data]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 7))
|
||||
|
||||
# 区分正负值
|
||||
colors = ['#10b981' if v >= 0 else '#ef4444' for v in values]
|
||||
bars = ax.bar(names, values, color=colors, alpha=0.8, edgecolor='black', linewidth=0.8)
|
||||
|
||||
# 添加数值标签
|
||||
for bar, value in zip(bars, values):
|
||||
height = bar.get_height()
|
||||
ax.text(bar.get_x() + bar.get_width() / 2., height,
|
||||
f'{value:,.1f}',
|
||||
ha='center', va='bottom' if value >= 0 else 'top',
|
||||
fontsize=10, fontweight='bold')
|
||||
|
||||
# 添加零线
|
||||
ax.axhline(y=0, color='black', linestyle='-', linewidth=1)
|
||||
|
||||
ax.set_ylabel('值', fontsize=10, labelpad=10, fontweight='bold')
|
||||
ax.set_title('对比数据', fontsize=12, fontweight='bold', pad=15)
|
||||
ax.set_xticklabels(names, rotation=30, ha='right', fontsize=9)
|
||||
ax.tick_params(axis='both', which='major', labelsize=9)
|
||||
plt.grid(axis='y', alpha=0.3)
|
||||
plt.tight_layout(pad=1.5)
|
||||
|
||||
img_base64 = self._figure_to_base64(fig)
|
||||
return {
|
||||
"type": "comparison",
|
||||
"title": "对比数据",
|
||||
"image": img_base64,
|
||||
"data": [{"name": n, "value": v} for n, v in zip(names, values)]
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"创建对比图表失败: {str(e)}")
|
||||
return None
|
||||
|
||||
def _create_table_preview(self, table_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""创建表格预览数据"""
|
||||
if not table_data:
|
||||
return {}
|
||||
|
||||
columns = table_data.get("columns", [])
|
||||
rows = table_data.get("rows", [])
|
||||
|
||||
return {
|
||||
"columns": columns,
|
||||
"rows": rows[:50], # 限制显示前50行
|
||||
"total_rows": len(rows),
|
||||
"preview_rows": min(50, len(rows))
|
||||
}
|
||||
|
||||
def _create_numeric_summary(self, numeric_data: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""创建数值型数据摘要"""
|
||||
values = [item.get("value", 0) for item in numeric_data if isinstance(item.get("value"), (int, float))]
|
||||
|
||||
if not values:
|
||||
return {}
|
||||
|
||||
return {
|
||||
"count": len(values),
|
||||
"sum": float(sum(values)),
|
||||
"mean": float(np.mean(values)),
|
||||
"median": float(np.median(values)),
|
||||
"min": float(min(values)),
|
||||
"max": float(max(values)),
|
||||
"std": float(np.std(values)) if len(values) > 1 else 0
|
||||
}
|
||||
|
||||
def _figure_to_base64(self, fig) -> str:
|
||||
"""将 matplotlib 图形转换为 base64 字符串"""
|
||||
buf = io.BytesIO()
|
||||
fig.savefig(
|
||||
buf,
|
||||
format='png',
|
||||
dpi=120,
|
||||
bbox_inches='tight',
|
||||
pad_inches=0.3,
|
||||
facecolor='white',
|
||||
edgecolor='none',
|
||||
transparent=False
|
||||
)
|
||||
plt.close(fig)
|
||||
buf.seek(0)
|
||||
img_base64 = base64.b64encode(buf.read()).decode('utf-8')
|
||||
return f"data:image/png;base64,{img_base64}"
|
||||
|
||||
|
||||
# Global singleton: module-level shared ChartGeneratorService instance
|
||||
chart_generator_service = ChartGeneratorService()
|
||||
253
backend/app/services/excel_ai_service.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""
|
||||
Excel AI 分析服务 - 集成 Excel 解析和 LLM 分析
|
||||
"""
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
from app.core.document_parser import XlsxParser
|
||||
from app.services.file_service import file_service
|
||||
from app.services.llm_service import llm_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExcelAIService:
    """Save an uploaded Excel file, parse it, and run LLM analysis on the result."""

    def __init__(self):
        self.parser = XlsxParser()
        self.file_service = file_service
        self.llm_service = llm_service

    def _save_upload(self, file_content: bytes, filename: str) -> str:
        """Store the uploaded bytes under the ``excel`` subfolder and return the saved path."""
        saved_path = self.file_service.save_uploaded_file(
            file_content,
            filename,
            subfolder="excel"
        )
        logger.info(f"文件已保存: {saved_path}")
        return saved_path

    async def _run_llm(self, data: Any, user_prompt: str, analysis_type: str) -> Dict[str, Any]:
        """Run template analysis when a non-blank prompt is given, otherwise the standard analysis."""
        if user_prompt and user_prompt.strip():
            return await self.llm_service.analyze_with_template(data, user_prompt)
        return await self.llm_service.analyze_excel_data(data, user_prompt, analysis_type)

    async def analyze_excel_file(
        self,
        file_content: bytes,
        filename: str,
        user_prompt: str = "",
        analysis_type: str = "general",
        parse_options: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Analyze a single Excel file.

        Args:
            file_content: Raw bytes of the uploaded file.
            filename: Original file name.
            user_prompt: Optional custom prompt; when non-blank the template
                analysis is used instead of the standard one.
            analysis_type: Analysis type passed to the standard analysis.
            parse_options: Extra keyword arguments forwarded to the parser.

        Returns:
            Dict[str, Any]: ``success``/``analysis`` plus ``excel`` on success,
            or ``success=False`` with ``error`` when saving, parsing or the LLM
            call fails. No exception is propagated.
        """
        # 1. Save the file
        try:
            saved_path = self._save_upload(file_content, filename)
        except Exception as e:
            logger.error(f"文件保存失败: {str(e)}")
            return {
                "success": False,
                "error": f"文件保存失败: {str(e)}",
                "analysis": None
            }

        # 2. Parse the Excel file
        try:
            parse_result = self.parser.parse(saved_path, **(parse_options or {}))

            if not parse_result.success:
                return {
                    "success": False,
                    "error": parse_result.error,
                    "analysis": None
                }

            excel_data = parse_result.data
            logger.info(f"Excel 解析成功: {parse_result.metadata}")

        except Exception as e:
            logger.error(f"Excel 解析失败: {str(e)}")
            return {
                "success": False,
                "error": f"Excel 解析失败: {str(e)}",
                "analysis": None
            }

        # 3. Run the LLM analysis
        try:
            llm_result = await self._run_llm(excel_data, user_prompt, analysis_type)
            logger.info(f"AI 分析完成: {llm_result.get('success')}")

            # 4. Combine the results. NOTE: "success" only says the pipeline
            # ran; callers must check analysis["success"] for the LLM outcome.
            return {
                "success": True,
                "excel": {
                    "data": excel_data,
                    "metadata": parse_result.metadata,
                    "saved_path": saved_path
                },
                "analysis": llm_result
            }

        except Exception as e:
            logger.error(f"AI 分析失败: {str(e)}")
            return {
                "success": False,
                "error": f"AI 分析失败: {str(e)}",
                "excel": {
                    "data": excel_data,
                    "metadata": parse_result.metadata,
                    "saved_path": saved_path
                },
                "analysis": None
            }

    async def batch_analyze_sheets(
        self,
        file_content: bytes,
        filename: str,
        user_prompt: str = "",
        analysis_type: str = "general"
    ) -> Dict[str, Any]:
        """
        Analyze every worksheet of an Excel file, one LLM call per sheet.

        Args:
            file_content: Raw bytes of the uploaded file.
            filename: Original file name.
            user_prompt: Optional custom prompt (see ``analyze_excel_file``).
            analysis_type: Analysis type passed to the standard analysis.

        Returns:
            Dict[str, Any]: ``success`` is True only when every sheet
            succeeded; ``analysis`` holds per-sheet results, ``total_sheets``,
            ``successful`` and a sheet-name -> message ``errors`` map.
        """
        # 1. Save the file
        try:
            saved_path = self._save_upload(file_content, filename)
        except Exception as e:
            logger.error(f"文件保存失败: {str(e)}")
            return {
                "success": False,
                "error": f"文件保存失败: {str(e)}",
                "analysis": None
            }

        # 2. Parse all worksheets
        try:
            parse_result = self.parser.parse_all_sheets(saved_path)

            if not parse_result.success:
                return {
                    "success": False,
                    "error": parse_result.error,
                    "analysis": None
                }

            sheets_data = parse_result.data.get("sheets", {})
            logger.info(f"Excel 解析成功,共 {len(sheets_data)} 个工作表")

        except Exception as e:
            logger.error(f"Excel 解析失败: {str(e)}")
            return {
                "success": False,
                "error": f"Excel 解析失败: {str(e)}",
                "analysis": None
            }

        # 3. Analyze each worksheet; one failing sheet does not stop the rest
        sheet_analyses = {}
        errors = {}

        for sheet_name, sheet_data in sheets_data.items():
            try:
                llm_result = await self._run_llm(sheet_data, user_prompt, analysis_type)
                sheet_analyses[sheet_name] = llm_result

                if not llm_result.get("success"):
                    errors[sheet_name] = llm_result.get("error", "未知错误")

                logger.info(f"工作表 '{sheet_name}' 分析完成")

            except Exception as e:
                logger.error(f"工作表 '{sheet_name}' 分析失败: {str(e)}")
                errors[sheet_name] = str(e)

        # 4. Combine the results
        return {
            "success": len(errors) == 0,
            "excel": {
                "sheets": sheets_data,
                "metadata": parse_result.metadata,
                "saved_path": saved_path
            },
            "analysis": {
                "sheets": sheet_analyses,
                "total_sheets": len(sheets_data),
                # Every sheet is attempted exactly once and every failure is
                # recorded in ``errors``. Sheets that raised never reach
                # ``sheet_analyses``, so count against the total instead.
                "successful": len(sheets_data) - len(errors),
                "errors": errors
            }
        }

    def get_supported_analysis_types(self) -> List[Dict[str, str]]:
        """Return the supported analysis types as ``value``/``label``/``description`` dicts."""
        return [
            {
                "value": "general",
                "label": "综合分析",
                "description": "提供数据概览、关键发现、质量评估和建议"
            },
            {
                "value": "summary",
                "label": "数据摘要",
                "description": "快速了解数据的结构、范围和主要内容"
            },
            {
                "value": "statistics",
                "label": "统计分析",
                "description": "数值型列的统计信息和分类列的分布"
            },
            {
                "value": "insights",
                "label": "深度洞察",
                "description": "深入挖掘数据,提供异常值和业务建议"
            }
        ]
|
||||
|
||||
|
||||
# Global singleton: module-level shared ExcelAIService instance
|
||||
excel_ai_service = ExcelAIService()
|
||||
722
backend/app/services/excel_storage_service.py
Normal file
@@ -0,0 +1,722 @@
|
||||
"""
|
||||
Excel 存储服务
|
||||
|
||||
将 Excel 数据转换为 MySQL 表结构并存储
|
||||
"""
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pandas as pd
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
DateTime,
|
||||
Float,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
inspect,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database.mysql import Base, mysql_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# 设置该模块的日志级别
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
class ExcelStorageService:
    """Store tabular data (Excel sheets or extracted tables) in MySQL tables."""

    # Columns this service adds to every table it creates. User columns with
    # the same (case-insensitive) name get a "col_" prefix to avoid clashes.
    _RESERVED_COLUMNS = frozenset({"id", "created_at", "updated_at"})
    # MySQL limits identifiers to 64 characters.
    _MAX_IDENTIFIER_LEN = 64

    def __init__(self):
        self.mysql_db = mysql_db

    # ------------------------------------------------------------------
    # SQL quoting helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _quote_identifier(name: Any) -> str:
        """Backtick-quote an identifier, doubling embedded backticks."""
        return "`" + str(name).replace("`", "``") + "`"

    @staticmethod
    def _quote_literal(value: Any) -> str:
        """Single-quote a string literal, escaping backslashes and quotes."""
        escaped = str(value).replace("\\", "\\\\").replace("'", "''")
        return "'" + escaped + "'"

    # ------------------------------------------------------------------
    # Excel reading
    # ------------------------------------------------------------------

    def _extract_sheet_names_from_xml(self, file_path: str) -> list:
        """Read worksheet names from ``xl/workbook.xml``; returns ``[]`` on any failure."""
        import zipfile
        from xml.etree import ElementTree as ET

        try:
            with zipfile.ZipFile(file_path, 'r') as z:
                if 'xl/workbook.xml' not in z.namelist():
                    return []
                root = ET.fromstring(z.read('xl/workbook.xml'))

            # Try the known SpreadsheetML namespaces first.
            for ns_uri in (
                'http://schemas.openxmlformats.org/spreadsheetml/2006/main',
                'http://purl.oclc.org/ooxml/spreadsheetml/main',
            ):
                sheets = root.findall('.//main:sheet', {'main': ns_uri})
                names = [s.get('name') for s in sheets if s.get('name')]
                if names:
                    return names

            # Fall back to any namespace, then to no namespace at all.
            sheets = root.findall('.//{*}sheet') or root.findall('.//sheet')
            return [s.get('name') for s in sheets if s.get('name')]
        except Exception as e:
            logger.debug(f"提取工作表名称失败: {e}")
            return []

    @staticmethod
    def _decode_xml_cell(cell, shared_strings: list) -> Optional[str]:
        """Return the text of a worksheet ``<c>`` element, or ``None`` when it is empty.

        Shared-string cells are resolved through ``shared_strings`` and inline
        strings are read from ``<is>``; everything else is returned as the raw
        ``<v>`` text.
        """
        cell_type = cell.get('t', 'n')
        if cell_type == 'inlineStr':
            runs = cell.findall('{*}is/{*}t') + cell.findall('{*}is/{*}r/{*}t')
            return ''.join(t.text or '' for t in runs) if runs else None

        v = cell.find('{*}v')
        if v is None or not v.text:
            return None
        if cell_type == 's':
            try:
                return shared_strings[int(v.text)]
            except (ValueError, IndexError):
                return v.text
        return v.text

    def _read_excel_sheet(self, file_path: str, sheet_name: str = None, header_row: int = 0) -> pd.DataFrame:
        """Read one worksheet, falling back to raw XML parsing when pandas cannot read the file.

        Args:
            file_path: Path of the Excel file.
            sheet_name: Worksheet name; ``None`` selects the first sheet.
            header_row: Zero-based index of the header row.

        Returns:
            The sheet as a DataFrame. The XML fallback returns every value as
            text (booleans as ``bool``) and an empty DataFrame when the sheet
            has no data rows.

        Raises:
            Exception: Re-raises whatever the XML fallback fails with.
        """
        import zipfile
        from xml.etree import ElementTree as ET

        try:
            # ``sheet_name=None`` would make pandas return a dict of ALL
            # sheets, which silently forced every default call into the XML
            # fallback. Index 0 selects the first sheet instead.
            target = sheet_name if sheet_name is not None else 0
            df = pd.read_excel(file_path, sheet_name=target, header=header_row)
            if df is not None and not df.empty:
                return df
        except Exception as e:
            logger.warning(f"pandas 读取 Excel 失败,改用 XML 方式: {e}")

        # pandas failed or returned nothing: parse the XML directly.
        logger.info(f"使用 XML 方式读取 Excel: {file_path}")

        def _is_tag(elem, local_name):
            return elem.tag == local_name or elem.tag.endswith('}' + local_name)

        try:
            with zipfile.ZipFile(file_path, 'r') as z:
                sheet_names = self._extract_sheet_names_from_xml(file_path)
                if not sheet_names:
                    raise ValueError("无法从 Excel 文件中找到工作表")

                target_sheet = sheet_name if sheet_name and sheet_name in sheet_names else sheet_names[0]
                # NOTE(review): assumes sheetN.xml follows the order of
                # workbook.xml; strictly this should go through
                # xl/_rels/workbook.xml.rels. Confirm for reordered workbooks.
                sheet_index = sheet_names.index(target_sheet) + 1

                shared_strings = []
                if 'xl/sharedStrings.xml' in z.namelist():
                    ss_root = ET.fromstring(z.read('xl/sharedStrings.xml'))
                    for si in ss_root.iter():
                        if _is_tag(si, 'si'):
                            # Rich-text entries are split into <r><t> runs;
                            # join them instead of keeping only the first.
                            runs = si.findall('{*}t') + si.findall('{*}r/{*}t')
                            shared_strings.append(''.join(t.text or '' for t in runs))

                root = ET.fromstring(z.read(f'xl/worksheets/sheet{sheet_index}.xml'))

            rows_data = []
            headers = {}

            for row in root.iter():
                if not _is_tag(row, 'row'):
                    continue
                row_idx = int(row.get('r', 0))  # 1-based row number
                if row_idx <= header_row:
                    continue  # rows above the header are ignored
                is_header = row_idx == header_row + 1

                row_cells = {}
                for cell in row:
                    if not _is_tag(cell, 'c'):
                        continue
                    col_letters = ''.join(filter(str.isalpha, cell.get('r', '')))
                    raw = self._decode_xml_cell(cell, shared_strings)
                    if is_header:
                        # An empty header cell is named after its column letter.
                        headers[col_letters] = raw if raw is not None else col_letters
                    elif raw is not None and cell.get('t') == 'b':
                        row_cells[col_letters] = raw == '1'
                    else:
                        row_cells[col_letters] = raw

                if row_cells:
                    rows_data.append(row_cells)

            if not rows_data:
                return pd.DataFrame()

            df = pd.DataFrame(rows_data)
            if headers:
                df.columns = [headers.get(col, col) for col in df.columns]
            return df
        except Exception as e:
            logger.error(f"XML 解析 Excel 失败: {e}")
            raise

    # ------------------------------------------------------------------
    # Naming and type inference
    # ------------------------------------------------------------------

    def _sanitize_table_name(self, filename: str) -> str:
        """
        Turn a file name into a legal table name.

        Args:
            filename: Original file name.

        Returns:
            The name without extension, with every character outside
            ``[a-zA-Z0-9_]`` replaced by ``_``, prefixed with ``t_`` when it
            starts with a digit, truncated to 50 characters. An empty result
            becomes ``t_unnamed`` so the generated SQL stays valid.
        """
        name = filename.rsplit('.', 1)[0] if '.' in filename else filename
        name = re.sub(r'[^a-zA-Z0-9_]', '_', name)

        if not name:
            name = 't_unnamed'
        elif name[0].isdigit():
            name = 't_' + name

        return name[:50]

    def _sanitize_column_name(self, col_name: str) -> str:
        """
        Turn a column header into a legal field name.

        Non-ASCII characters (e.g. Chinese) are kept. Control characters and
        backticks are removed; backticks would break out of the backtick
        quoting used when the name is embedded in SQL.

        Args:
            col_name: Original column name.

        Returns:
            The cleaned name, prefixed with ``col_`` when it starts with a
            digit, at most 64 characters. May be empty.
        """
        name = re.sub(r'[\x00-\x1f\x7f`]', '', str(col_name)).strip()
        if name and name[0].isdigit():
            name = 'col_' + name
        # MySQL identifiers may not end with a space, even after truncation.
        return name[:self._MAX_IDENTIFIER_LEN].rstrip()

    def _get_unique_column_name(self, col_name: str, used_names: set) -> str:
        """
        Return a unique, legal column name and record it in ``used_names``.

        Uniqueness is checked case-insensitively because MySQL column names
        are case-insensitive. Names clashing with the service's own
        ``id``/``created_at``/``updated_at`` columns get a ``col_`` prefix.

        Args:
            col_name: Original column name.
            used_names: Names already taken; the result is added to it.

        Returns:
            The unique column name.
        """
        sanitized = self._sanitize_column_name(col_name) or "col"
        if sanitized.lower() in self._RESERVED_COLUMNS:
            sanitized = "col_" + sanitized.lower()

        taken = {n.lower() for n in used_names}
        candidate = sanitized
        counter = 0
        while candidate.lower() in taken:
            counter += 1
            suffix = f"_{counter}"
            # Trim the base so the suffixed name still fits the length limit.
            candidate = sanitized[:self._MAX_IDENTIFIER_LEN - len(suffix)] + suffix

        used_names.add(candidate)
        return candidate

    def _infer_column_type(self, series: pd.Series) -> str:
        """
        Infer a storage type from a pandas column.

        Args:
            series: pandas Series

        Returns:
            ``INTEGER``, ``FLOAT``, ``DATETIME``, ``BOOLEAN`` or ``TEXT``.
            Integers outside the signed 32-bit range and non-finite floats
            fall back to ``TEXT``, as do all-null columns.
        """
        non_null = series.dropna()
        if len(non_null) == 0:
            return "TEXT"

        dtype = series.dtype

        if pd.api.types.is_integer_dtype(dtype):
            try:
                int_values = non_null.astype('int64')
                if int_values.min() >= -2147483648 and int_values.max() <= 2147483647:
                    return "INTEGER"
                return "TEXT"  # does not fit a MySQL INT
            except (ValueError, OverflowError):
                return "TEXT"
        if pd.api.types.is_float_dtype(dtype):
            try:
                float_values = non_null.astype('float64')
                if float_values.min() >= -1e308 and float_values.max() <= 1e308:
                    return "FLOAT"
                return "TEXT"
            except (ValueError, OverflowError):
                return "TEXT"
        if pd.api.types.is_datetime64_any_dtype(dtype):
            return "DATETIME"
        if pd.api.types.is_bool_dtype(dtype):
            return "BOOLEAN"
        return "TEXT"

    def _create_table_model(
        self,
        table_name: str,
        columns: List[str],
        column_types: Dict[str, str]
    ) -> type:
        """
        Build a SQLAlchemy model class dynamically.

        Args:
            table_name: Table name.
            columns: Column names; each is passed through ``_sanitize_column_name``.
            column_types: Maps each entry of ``columns`` to a type name
                (unknown or missing entries become ``Text``).

        Returns:
            A model class with an autoincrement ``id`` primary key, the data
            columns, and ``created_at``/``updated_at`` metadata columns.
        """
        attrs = {
            '__tablename__': table_name,
            '__table_args__': {'extend_existing': True},
        }

        attrs['id'] = Column(Integer, primary_key=True, autoincrement=True)

        # BOOLEAN is stored as an integer column.
        sa_types = {"INTEGER": Integer, "FLOAT": Float, "DATETIME": DateTime, "BOOLEAN": Integer}
        for col in columns:
            col_name = self._sanitize_column_name(col)
            sa_type = sa_types.get(column_types.get(col, "TEXT"), Text)
            attrs[col_name] = Column(sa_type, nullable=True)

        attrs['created_at'] = Column(DateTime, default=datetime.utcnow)
        attrs['updated_at'] = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

        return type(table_name, (Base,), attrs)

    @staticmethod
    def _coerce_value(value: Any, col_type: str) -> Any:
        """Convert a non-null value for its column type; unconvertible numbers become ``None``."""
        if col_type == "INTEGER":
            try:
                return int(value)
            except (ValueError, TypeError, OverflowError):
                return None
        if col_type == "FLOAT":
            try:
                return float(value)
            except (ValueError, TypeError, OverflowError):
                return None
        return str(value)

    # ------------------------------------------------------------------
    # Storage
    # ------------------------------------------------------------------

    async def store_excel(
        self,
        file_path: str,
        filename: str,
        sheet_name: Optional[str] = None,
        header_row: int = 0
    ) -> Dict[str, Any]:
        """
        Store an Excel worksheet in a MySQL table named after the file.

        Args:
            file_path: Path of the Excel file.
            filename: Original file name (used to derive the table name).
            sheet_name: Worksheet name; ``None`` selects the first sheet.
            header_row: Zero-based index of the header row.

        Returns:
            On success a dict with ``success``, ``table_name``, ``row_count``
            and per-column ``original_name``/``sanitized_name``/``type``;
            on failure ``{"success": False, "error": ...}``.
        """
        table_name = self._sanitize_table_name(filename)
        results = {
            "success": True,
            "table_name": table_name,
            "row_count": 0,
            "columns": []
        }

        try:
            logger.info(f"开始读取Excel文件: {file_path}")
            df = self._read_excel_sheet(file_path, sheet_name=sheet_name, header_row=header_row)
            logger.info(f"Excel读取完成,行数: {len(df)}, 列数: {len(df.columns)}")

            if df.empty:
                return {"success": False, "error": "Excel 文件为空"}

            # Work by column position: headers can repeat, and ``df[name]``
            # would then return a DataFrame instead of a Series.
            used_names = set()
            column_names = []
            column_types = []
            for idx, original in enumerate(str(c) for c in df.columns):
                col_name = self._get_unique_column_name(original, used_names)
                col_type = self._infer_column_type(df.iloc[:, idx])
                column_names.append(col_name)
                column_types.append(col_type)
                results["columns"].append({
                    "original_name": original,
                    "sanitized_name": col_name,
                    "type": col_type
                })

            # Create the table with raw SQL.
            logger.info(f"正在创建MySQL表: {table_name}")
            # DOUBLE rather than FLOAT: Python floats are double precision and
            # MySQL FLOAT would round them to about 7 significant digits.
            sql_types = {"INTEGER": "INT", "FLOAT": "DOUBLE", "DATETIME": "DATETIME"}
            sql_columns = ["id INT AUTO_INCREMENT PRIMARY KEY"]
            sql_columns.extend(
                f"`{name}` {sql_types.get(kind, 'TEXT')}"
                for name, kind in zip(column_names, column_types)
            )
            sql_columns.append("created_at DATETIME DEFAULT CURRENT_TIMESTAMP")
            sql_columns.append("updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")
            create_sql = f"CREATE TABLE IF NOT EXISTS `{table_name}` ({', '.join(sql_columns)})"
            await self.mysql_db.execute_raw_sql(create_sql)
            logger.info(f"MySQL表创建完成: {table_name}")

            # Build the rows; itertuples keeps each column's own dtype.
            param_list = [
                tuple(
                    None if pd.isna(value) else self._coerce_value(value, kind)
                    for value, kind in zip(row, column_types)
                )
                for row in df.itertuples(index=False, name=None)
            ]

            logger.info(f"正在插入 {len(param_list)} 条数据到 MySQL (使用批量插入)...")
            # NOTE(review): pymysql is synchronous, so this blocks the event
            # loop for the duration of the insert; consider an executor.
            import pymysql
            from app.config import settings

            connection = pymysql.connect(
                host=settings.MYSQL_HOST,
                port=settings.MYSQL_PORT,
                user=settings.MYSQL_USER,
                password=settings.MYSQL_PASSWORD,
                database=settings.MYSQL_DATABASE,
                charset=settings.MYSQL_CHARSET
            )
            try:
                columns_str = ', '.join('`' + name + '`' for name in column_names)
                placeholders = ', '.join(['%s'] * len(column_names))
                insert_sql = f"INSERT INTO `{table_name}` ({columns_str}) VALUES ({placeholders})"

                with connection.cursor() as cursor:
                    cursor.executemany(insert_sql, param_list)
                connection.commit()
                logger.info(f"数据插入完成: {len(param_list)} 条")
            finally:
                connection.close()

            results["row_count"] = len(param_list)
            logger.info(f"Excel 数据已存储到 MySQL 表 {table_name},共 {len(param_list)} 行")

            return results

        except Exception as e:
            logger.error(f"存储 Excel 到 MySQL 失败: {str(e)}", exc_info=True)
            return {"success": False, "error": str(e)}

    async def store_structured_data(
        self,
        table_name: str,
        data: Dict[str, Any],
        source_doc_id: str = None
    ) -> Dict[str, Any]:
        """
        Store structured data (a table extracted from an unstructured document) in MySQL.

        Args:
            table_name: Table name.
            data: Structured data of the form::

                {
                    "columns": ["col1", "col2"],                      # column names
                    "rows": [["val1", "val2"], ["val3", "val4"]]      # data rows
                }
            source_doc_id: Source document ID (currently not used).

        Returns:
            On success a dict with ``success``, ``table_name``, ``row_count``
            and per-column info; on failure ``{"success": False, "error": ...}``.
        """
        results = {
            "success": True,
            "table_name": table_name,
            "row_count": 0,
            "columns": []
        }

        try:
            columns = data.get("columns", [])
            rows = data.get("rows", [])

            if not columns or not rows:
                return {"success": False, "error": "数据为空"}

            # Unique, legal names: duplicates or an "id" header would otherwise
            # overwrite each other (or the primary key) on the model.
            used_names = set()
            column_names = [self._get_unique_column_name(c, used_names) for c in columns]

            column_types = {}
            for i, (col, col_name) in enumerate(zip(columns, column_names)):
                col_values = [row[i] for row in rows if i < len(row)]
                col_type = self._infer_type_from_values(col_values)
                column_types[col_name] = col_type
                results["columns"].append({
                    "original_name": col,
                    "sanitized_name": col_name,
                    "type": col_type
                })

            model_class = self._create_table_model(table_name, column_names, column_types)

            # Create the table. ``Table.create`` is synchronous, so it has to
            # run through ``run_sync`` on the async connection.
            # NOTE(review): assumes get_session() yields an AsyncSession - confirm.
            async with self.mysql_db.get_session() as session:
                conn = await session.connection()
                await conn.run_sync(model_class.__table__.create, checkfirst=True)
                await session.commit()

            records = []
            for row in rows:
                record = {}
                for i, col_name in enumerate(column_names):
                    if i >= len(row):
                        continue  # short row: leave the remaining columns NULL
                    value = row[i]
                    if value is None or str(value).strip() == '':
                        record[col_name] = None
                    else:
                        record[col_name] = self._coerce_value(value, column_types[col_name])
                records.append(record)

            async with self.mysql_db.get_session() as session:
                session.add_all([model_class(**record) for record in records])
                await session.commit()

            results["row_count"] = len(records)
            logger.info(f"结构化数据已存储到 MySQL 表 {table_name},共 {len(records)} 行")

            return results

        except Exception as e:
            logger.error(f"存储结构化数据到 MySQL 失败: {str(e)}")
            return {"success": False, "error": str(e)}

    def _infer_type_from_values(self, values: List[Any]) -> str:
        """
        Infer a column type from a list of raw values.

        Args:
            values: Column values; ``None`` and blank strings are ignored.

        Returns:
            ``INTEGER`` when every value is integral, ``FLOAT`` when every
            value is a finite number, otherwise ``TEXT`` (also for all-blank
            columns).
        """
        non_null_values = [v for v in values if v is not None and str(v).strip() != '']
        if not non_null_values:
            return "TEXT"

        if all(self._is_integer(v) for v in non_null_values):
            return "INTEGER"
        if all(self._is_float(v) for v in non_null_values):
            return "FLOAT"
        return "TEXT"

    def _is_integer(self, value: Any) -> bool:
        """Return True when the value can be stored as an integer without losing data."""
        # int(1.5) succeeds by truncating, so floats need an explicit check
        # (is_integer() is False for NaN and inf as well).
        if isinstance(value, float):
            return value.is_integer()
        try:
            int(value)
            return True
        except (ValueError, TypeError, OverflowError):
            return False

    def _is_float(self, value: Any) -> bool:
        """Return True when the value converts to a finite float."""
        import math

        try:
            # "nan"/"inf" parse as floats but cannot be stored in MySQL.
            return math.isfinite(float(value))
        except (ValueError, TypeError, OverflowError):
            return False

    # ------------------------------------------------------------------
    # Querying and maintenance
    # ------------------------------------------------------------------

    async def query_table(
        self,
        table_name: str,
        columns: Optional[List[str]] = None,
        where: Optional[str] = None,
        limit: int = 100
    ) -> List[Dict[str, Any]]:
        """
        Query rows from a MySQL table.

        Args:
            table_name: Table name.
            columns: Columns to select; ``None`` or empty selects all columns.
            where: Raw SQL condition appended after ``WHERE``.
                SECURITY: inserted verbatim, so it must never contain
                untrusted input.
            limit: Maximum number of rows returned.

        Returns:
            The query result, or ``[]`` when the query fails (logged).
        """
        try:
            select_list = ", ".join(self._quote_identifier(c) for c in columns) if columns else "*"
            sql = f"SELECT {select_list} FROM {self._quote_identifier(table_name)}"
            if where:
                sql += f" WHERE {where}"
            sql += f" LIMIT {int(limit)}"

            results = await self.mysql_db.execute_query(sql)
            return results

        except Exception as e:
            logger.error(f"查询表失败: {str(e)}")
            return []

    async def get_table_schema(self, table_name: str) -> Optional[List[Dict[str, Any]]]:
        """
        Fetch column metadata for a table from ``INFORMATION_SCHEMA.COLUMNS``.

        Args:
            table_name: Table name.

        Returns:
            The rows returned by the query (name, data type, nullability, key
            and comment per column), or ``None`` when the query fails (logged).
        """
        try:
            sql = f"""
            SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_KEY, COLUMN_COMMENT
            FROM INFORMATION_SCHEMA.COLUMNS
            WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = {self._quote_literal(table_name)}
            ORDER BY ORDINAL_POSITION
            """
            results = await self.mysql_db.execute_query(sql)
            return results

        except Exception as e:
            logger.error(f"获取表结构失败: {str(e)}")
            return None

    async def delete_table(self, table_name: str) -> bool:
        """
        Drop a table.

        Args:
            table_name: Table name; must contain an underscore, a coarse guard
                against dropping tables not created by this service.

        Returns:
            True when the table was dropped (or did not exist), False when the
            name was rejected or the statement failed (logged).
        """
        try:
            if '_' not in table_name:
                raise ValueError("不允许删除此表")

            sql = f"DROP TABLE IF EXISTS {self._quote_identifier(table_name)}"
            await self.mysql_db.execute_raw_sql(sql)
            logger.info(f"表 {table_name} 已删除")
            return True

        except Exception as e:
            logger.error(f"删除表失败: {str(e)}")
            return False

    async def list_tables(self) -> List[str]:
        """
        List all base tables of the current database.

        Returns:
            The table names, or ``[]`` when the query fails (logged).
        """
        try:
            sql = """
            SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES
            WHERE TABLE_SCHEMA = DATABASE() AND TABLE_TYPE = 'BASE TABLE'
            """
            results = await self.mysql_db.execute_query(sql)
            return [r['TABLE_NAME'] for r in results]

        except Exception as e:
            logger.error(f"列出表失败: {str(e)}")
            return []
|
||||
|
||||
|
||||
# ==================== 全局单例 ====================
|
||||
|
||||
# Shared module-level instance, constructed once when this module is imported.
excel_storage_service = ExcelStorageService()
|
||||
138
backend/app/services/file_service.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
文件服务模块 - 处理文件存储和读取
|
||||
"""
|
||||
import os
|
||||
import shutil
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileService:
    """File service: stores, reads and manages files under the upload directory."""

    def __init__(self):
        self.upload_dir = Path(settings.UPLOAD_DIR)
        self._ensure_upload_dir()
        logger.info(f"FileService 初始化,上传目录: {self.upload_dir}")

    def _ensure_upload_dir(self):
        """Create the upload directory (and parents) if it does not exist."""
        self.upload_dir.mkdir(parents=True, exist_ok=True)

    def save_uploaded_file(
        self,
        file_content: bytes,
        filename: str,
        subfolder: Optional[str] = None
    ) -> str:
        """Save uploaded bytes under a unique name.

        Args:
            file_content: Raw file bytes.
            filename: Original file name; only its extension is kept.
            subfolder: Optional sub-directory below the upload directory.

        Returns:
            str: Path of the saved file.

        Raises:
            ValueError: If ``subfolder`` resolves outside the upload directory.
        """
        # A random hex name avoids overwriting an earlier upload.
        file_ext = Path(filename).suffix
        unique_name = f"{uuid.uuid4().hex}{file_ext}"

        save_dir = self.upload_dir
        if subfolder:
            save_dir = self.upload_dir / subfolder
            # Reject "../" or absolute sub-folders that would write outside
            # the upload root.
            root = self.upload_dir.resolve()
            target = save_dir.resolve()
            if target != root and root not in target.parents:
                raise ValueError(f"非法的子目录: {subfolder}")
            save_dir.mkdir(parents=True, exist_ok=True)

        file_path = save_dir / unique_name
        file_path.write_bytes(file_content)

        file_size = len(file_content)
        logger.info(f"文件已保存: {filename} -> {file_path} ({file_size} bytes)")
        return str(file_path)

    def read_file(self, file_path: str) -> bytes:
        """Read a file and return its content.

        Args:
            file_path: Path of the file to read.

        Returns:
            bytes: File content.
        """
        return Path(file_path).read_bytes()

    def delete_file(self, file_path: str) -> bool:
        """Delete a file.

        Args:
            file_path: Path of the file to delete.

        Returns:
            bool: True when the file existed and was removed; False when it
            did not exist or removal failed (the failure is logged).
        """
        try:
            file = Path(file_path)
            if file.exists():
                file.unlink()
                return True
            return False
        except Exception as e:
            # Keep the best-effort contract, but leave a trace of the failure.
            logger.warning(f"删除文件失败: {file_path} - {e}")
            return False

    def get_file_info(self, file_path: str) -> dict:
        """Collect basic metadata about a file.

        Args:
            file_path: Path of the file.

        Returns:
            dict: name, path, size, created/modified ISO timestamps and
            lower-cased extension; an empty dict when the file is missing.
        """
        file = Path(file_path)
        if not file.exists():
            return {}

        stat = file.stat()
        return {
            "filename": file.name,
            "filepath": str(file),
            "size": stat.st_size,
            "created": datetime.fromtimestamp(stat.st_ctime).isoformat(),
            "modified": datetime.fromtimestamp(stat.st_mtime).isoformat(),
            "extension": file.suffix.lower()
        }

    def get_file_size(self, file_path: str) -> int:
        """Return the file size in bytes.

        Args:
            file_path: Path of the file.

        Returns:
            int: Size in bytes, or 0 when the file does not exist.
        """
        file = Path(file_path)
        return file.stat().st_size if file.exists() else 0
|
||||
|
||||
|
||||
# 全局单例
|
||||
# Shared module-level instance; constructing it creates the upload directory.
file_service = FileService()
|
||||
105
backend/app/services/font_helper.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
字体辅助模块 - 处理中文字体检测和配置
|
||||
"""
|
||||
import matplotlib
|
||||
import matplotlib.font_manager as fm
|
||||
import platform
|
||||
import os
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_chinese_font() -> str:
    """Pick a font family that can render Chinese text.

    Returns:
        str: The first installed font from a platform-specific preference
        list, else the first installed font whose name contains 'CJK', 'SC'
        or 'TC', else the generic 'sans-serif'.
    """
    # Names of every font known to matplotlib's font manager.
    installed = {entry.name for entry in fm.fontManager.ttflist}

    system = platform.system()
    if system == 'Windows':
        candidates = [
            'Microsoft YaHei',
            'SimHei',
            'SimSun',
            'KaiTi',
            'FangSong',
            'STXihei',
            'STKaiti',
            'STSong',
            'STFangsong',
        ]
    elif system == 'Darwin':
        candidates = [
            'PingFang SC',
            'PingFang TC',
            'Heiti SC',
            'Heiti TC',
            'STHeiti',
            'STSong',
            'STKaiti',
            'Arial Unicode MS',
        ]
    else:
        candidates = [
            'Noto Sans CJK SC',
            'WenQuanYi Micro Hei',
            'AR PL UMing CN',
            'AR PL UKai CN',
            'ZCOOL XiaoWei',
        ]

    # Cross-platform fallbacks tried after the platform-specific names.
    candidates += [
        'SimHei',
        'Microsoft YaHei',
        'Arial Unicode MS',
        'Droid Sans Fallback',
    ]

    preferred = next((name for name in candidates if name in installed), None)
    if preferred is not None:
        logger.info(f"找到中文字体: {preferred}")
        return preferred

    # Nothing from the list is installed: take any font that looks CJK by name.
    for font in fm.fontManager.ttflist:
        if any(marker in font.name for marker in ('CJK', 'SC', 'TC')):
            logger.info(f"使用找到的中文字体: {font.name}")
            return font.name

    logger.warning("未找到合适的中文字体,使用默认字体")
    return 'sans-serif'
|
||||
|
||||
|
||||
def configure_matplotlib_fonts():
    """Configure matplotlib's global font settings for Chinese text.

    Returns:
        The font name chosen by ``get_chinese_font()``.
    """
    chinese_font = get_chinese_font()

    # Put the Chinese font first but keep the existing fallbacks, so glyphs it
    # lacks can still be found. When only the generic 'sans-serif' placeholder
    # came back, leave the list alone: 'sans-serif' is a family alias, not a
    # font name, and a list holding only that alias resolves to nothing.
    if chinese_font != 'sans-serif':
        fallbacks = [
            name for name in matplotlib.rcParams['font.sans-serif']
            if name != chinese_font
        ]
        matplotlib.rcParams['font.sans-serif'] = [chinese_font] + fallbacks

    # Render the minus sign with an ASCII hyphen instead of U+2212.
    matplotlib.rcParams['axes.unicode_minus'] = False
    matplotlib.rcParams['figure.dpi'] = 100
    matplotlib.rcParams['savefig.dpi'] = 120

    # Font sizes
    matplotlib.rcParams['font.size'] = 10
    matplotlib.rcParams['axes.labelsize'] = 10
    matplotlib.rcParams['axes.titlesize'] = 11
    matplotlib.rcParams['xtick.labelsize'] = 9
    matplotlib.rcParams['ytick.labelsize'] = 9
    matplotlib.rcParams['legend.fontsize'] = 9

    logger.info(f"配置完成,使用字体: {chinese_font}")
    return chinese_font
|
||||
491
backend/app/services/llm_service.py
Normal file
@@ -0,0 +1,491 @@
|
||||
"""
|
||||
LLM 服务模块 - 封装大模型 API 调用
|
||||
"""
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional, AsyncGenerator
|
||||
import httpx
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMService:
|
||||
"""大语言模型服务类"""
|
||||
|
||||
def __init__(self):
    """Copy the LLM connection settings from the application config."""
    # Sent as the Bearer token on every request.
    self.api_key = settings.LLM_API_KEY
    # Requests go to f"{base_url}/chat/completions".
    self.base_url = settings.LLM_BASE_URL
    # Value of the "model" field in every request payload.
    self.model_name = settings.LLM_MODEL_NAME
|
||||
|
||||
async def chat(
    self,
    messages: List[Dict[str, str]],
    temperature: float = 0.7,
    max_tokens: Optional[int] = None,
    **kwargs
) -> Dict[str, Any]:
    """Call the chat-completions endpoint and return the decoded JSON reply.

    Args:
        messages: Message list, e.g. [{"role": "user", "content": "..."}].
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens to generate; omitted when falsy.
        **kwargs: Extra payload fields; they override the defaults above.

    Returns:
        Dict[str, Any]: The API response body.

    Raises:
        httpx.HTTPStatusError: When the API answers with an error status.
        Exception: Any other failure, re-raised after being logged.
    """
    import json

    headers = {
        "Authorization": f"Bearer {self.api_key}",
        "Content-Type": "application/json"
    }

    payload = {
        "model": self.model_name,
        "messages": messages,
        "temperature": temperature
    }

    if max_tokens:
        payload["max_tokens"] = max_tokens

    # Caller-supplied extras are applied last.
    payload.update(kwargs)

    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(
                f"{self.base_url}/chat/completions",
                headers=headers,
                json=payload
            )
            response.raise_for_status()
            return response.json()

    except httpx.HTTPStatusError as e:
        error_detail = e.response.text
        logger.error(f"LLM API 请求失败: {e.response.status_code} - {error_detail}")
        # Best effort: pull the structured error code/message out of the body.
        try:
            err_json = json.loads(error_detail)
            err_code = err_json.get("error", {}).get("code", "unknown")
            err_msg = err_json.get("error", {}).get("message", "unknown")
            logger.error(f"API 错误码: {err_code}, 错误信息: {err_msg}")
        except (ValueError, AttributeError, TypeError):
            # Body is not JSON or not shaped like {"error": {...}}; the raw
            # text was already logged above.
            pass
        raise
    except Exception as e:
        logger.error(f"LLM API 调用异常: {str(e)}")
        raise
|
||||
|
||||
def extract_message_content(self, response: Dict[str, Any]) -> str:
    """Pull the assistant text out of a chat-completions response.

    Args:
        response: Decoded API response.

    Returns:
        str: ``response["choices"][0]["message"]["content"]``.

    Raises:
        KeyError, IndexError: When the response lacks that path; the error
        is logged and re-raised.
    """
    try:
        first_choice = response["choices"][0]
        message = first_choice["message"]
        return message["content"]
    except (KeyError, IndexError) as e:
        logger.error(f"解析 API 响应失败: {str(e)}")
        raise
|
||||
|
||||
async def chat_stream(
    self,
    messages: List[Dict[str, str]],
    temperature: float = 0.7,
    max_tokens: Optional[int] = None,
    **kwargs
) -> AsyncGenerator[Dict[str, Any], None]:
    """Call the chat-completions endpoint in streaming (SSE) mode.

    Args:
        messages: Message list.
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens to generate; omitted when falsy.
        **kwargs: Extra payload fields; they override the defaults above.

    Yields:
        Dict[str, Any]: ``{"content": <text delta>}`` for each non-empty delta.

    Raises:
        httpx.HTTPStatusError: When the API answers with an error status.
        Exception: Any other failure, re-raised after being logged.
    """
    import json

    headers = {
        "Authorization": f"Bearer {self.api_key}",
        "Content-Type": "application/json"
    }

    payload = {
        "model": self.model_name,
        "messages": messages,
        "temperature": temperature,
        "stream": True
    }

    if max_tokens:
        payload["max_tokens"] = max_tokens

    payload.update(kwargs)

    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            async with client.stream(
                "POST",
                f"{self.base_url}/chat/completions",
                headers=headers,
                json=payload
            ) as response:
                # Without this an HTTP error body has no "data:" lines and the
                # stream would silently end empty.
                response.raise_for_status()

                async for line in response.aiter_lines():
                    if not line.startswith("data:"):
                        continue
                    data = line[5:].strip()  # drop the "data:" prefix
                    if data == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data)
                    except json.JSONDecodeError:
                        continue
                    if not isinstance(chunk, dict):
                        continue
                    # Some chunks carry an empty "choices" list or a null
                    # "delta"; treat those as "no text".
                    choices = chunk.get("choices") or [{}]
                    delta = (choices[0].get("delta") or {}).get("content", "")
                    if delta:
                        yield {"content": delta}

    except httpx.HTTPStatusError as e:
        logger.error(f"LLM 流式 API 请求失败: {e.response.status_code}")
        raise
    except Exception as e:
        logger.error(f"LLM 流式 API 调用异常: {str(e)}")
        raise
|
||||
|
||||
async def analyze_excel_data(
|
||||
self,
|
||||
excel_data: Dict[str, Any],
|
||||
user_prompt: str,
|
||||
analysis_type: str = "general"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
分析 Excel 数据
|
||||
|
||||
Args:
|
||||
excel_data: Excel 解析后的数据
|
||||
user_prompt: 用户提示词
|
||||
analysis_type: 分析类型 (general, summary, statistics, insights)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 分析结果
|
||||
"""
|
||||
# 构建 Prompt
|
||||
system_prompt = self._get_system_prompt(analysis_type)
|
||||
user_message = self._format_user_message(excel_data, user_prompt)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_message}
|
||||
]
|
||||
|
||||
try:
|
||||
response = await self.chat(
|
||||
messages=messages,
|
||||
temperature=0.3, # 较低的温度以获得更稳定的输出
|
||||
max_tokens=2000
|
||||
)
|
||||
|
||||
content = self.extract_message_content(response)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"analysis": content,
|
||||
"model": self.model_name,
|
||||
"analysis_type": analysis_type
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Excel 数据分析失败: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"analysis": None
|
||||
}
|
||||
|
||||
def _get_system_prompt(self, analysis_type: str) -> str:
|
||||
"""获取系统提示词"""
|
||||
prompts = {
|
||||
"general": """你是一个专业的数据分析师。请分析用户提供的 Excel 数据,提供有价值的见解和建议。
|
||||
|
||||
请按照以下格式输出:
|
||||
1. 数据概览
|
||||
2. 关键发现
|
||||
3. 数据质量评估
|
||||
4. 建议
|
||||
|
||||
输出语言:中文""",
|
||||
"summary": """你是一个专业的数据分析师。请对用户提供的 Excel 数据进行简洁的总结。
|
||||
|
||||
输出格式:
|
||||
- 数据行数和列数
|
||||
- 主要列的说明
|
||||
- 数据范围概述
|
||||
|
||||
输出语言:中文""",
|
||||
"statistics": """你是一个专业的数据分析师。请对用户提供的 Excel 数据进行统计分析。
|
||||
|
||||
请分析:
|
||||
- 数值型列的统计信息(平均值、中位数、最大值、最小值)
|
||||
- 分类列的分布情况
|
||||
- 数据相关性
|
||||
|
||||
输出语言:中文,使用表格或结构化格式展示""",
|
||||
"insights": """你是一个专业的数据分析师。请深入挖掘用户提供的 Excel 数据,提供有价值的洞察。
|
||||
|
||||
请分析:
|
||||
1. 数据中的异常值或特殊模式
|
||||
2. 数据之间的潜在关联
|
||||
3. 基于数据的业务建议
|
||||
4. 数据趋势分析(如适用)
|
||||
|
||||
输出语言:中文,提供详细且可操作的建议"""
|
||||
}
|
||||
|
||||
return prompts.get(analysis_type, prompts["general"])
|
||||
|
||||
def _format_user_message(self, excel_data: Dict[str, Any], user_prompt: str) -> str:
|
||||
"""格式化用户消息"""
|
||||
columns = excel_data.get("columns", [])
|
||||
rows = excel_data.get("rows", [])
|
||||
row_count = excel_data.get("row_count", 0)
|
||||
column_count = excel_data.get("column_count", 0)
|
||||
|
||||
# 构建数据描述
|
||||
data_info = f"""
|
||||
Excel 数据概览:
|
||||
- 行数: {row_count}
|
||||
- 列数: {column_count}
|
||||
- 列名: {', '.join(columns)}
|
||||
|
||||
数据样例(前 5 行):
|
||||
"""
|
||||
|
||||
# 添加数据样例
|
||||
for i, row in enumerate(rows[:5], 1):
|
||||
row_str = " | ".join([f"{col}: {row.get(col, '')}" for col in columns])
|
||||
data_info += f"第 {i} 行: {row_str}\n"
|
||||
|
||||
if row_count > 5:
|
||||
data_info += f"\n(还有 {row_count - 5} 行数据...)\n"
|
||||
|
||||
# 添加用户自定义提示
|
||||
if user_prompt and user_prompt.strip():
|
||||
data_info += f"\n用户需求:\n{user_prompt}"
|
||||
else:
|
||||
data_info += "\n用户需求: 请对上述数据进行分析"
|
||||
|
||||
return data_info
|
||||
|
||||
async def analyze_with_template(
|
||||
self,
|
||||
excel_data: Dict[str, Any],
|
||||
template_prompt: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
使用自定义模板分析 Excel 数据
|
||||
|
||||
Args:
|
||||
excel_data: Excel 解析后的数据
|
||||
template_prompt: 自定义提示词模板
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 分析结果
|
||||
"""
|
||||
system_prompt = """你是一个专业的数据分析师。请根据用户提供的自定义提示词分析 Excel 数据。
|
||||
|
||||
请严格按照用户的要求进行分析,输出清晰、有条理的结果。
|
||||
|
||||
输出语言:中文"""
|
||||
|
||||
user_message = self._format_user_message(excel_data, template_prompt)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_message}
|
||||
]
|
||||
|
||||
try:
|
||||
response = await self.chat(
|
||||
messages=messages,
|
||||
temperature=0.5,
|
||||
max_tokens=3000
|
||||
)
|
||||
|
||||
content = self.extract_message_content(response)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"analysis": content,
|
||||
"model": self.model_name,
|
||||
"is_template": True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"自定义模板分析失败: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"analysis": None
|
||||
}
|
||||
|
||||
async def chat_with_images(
    self,
    text: str,
    images: List[Dict[str, str]],
    temperature: float = 0.7,
    max_tokens: Optional[int] = None
) -> Dict[str, Any]:
    """Call the chat-completions endpoint with text plus inline images.

    Args:
        text: Text part of the user message.
        images: Image list; each item carries ``base64`` and ``mime_type``,
            e.g. [{"base64": "...", "mime_type": "image/png"}, ...].
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens to generate; omitted when falsy.

    Returns:
        Dict[str, Any]: The API response body.

    Raises:
        httpx.HTTPStatusError: When the API answers with an error status.
        Exception: Any other failure, re-raised after being logged.
    """
    import json

    headers = {
        "Authorization": f"Bearer {self.api_key}",
        "Content-Type": "application/json"
    }

    # Each image becomes an "image_url" part holding a data: URL.
    image_contents = [
        {
            "type": "image_url",
            "image_url": {
                "url": f"data:{img['mime_type']};base64,{img['base64']}"
            }
        }
        for img in images
    ]

    # One user message: the text part followed by all image parts.
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": text
                },
                *image_contents
            ]
        }
    ]

    payload = {
        "model": self.model_name,
        "messages": messages,
        "temperature": temperature
    }

    if max_tokens:
        payload["max_tokens"] = max_tokens

    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(
                f"{self.base_url}/chat/completions",
                headers=headers,
                json=payload
            )
            response.raise_for_status()
            return response.json()

    except httpx.HTTPStatusError as e:
        error_detail = e.response.text
        logger.error(f"视觉模型 API 请求失败: {e.response.status_code} - {error_detail}")
        # Best effort: pull the structured error code/message out of the body.
        try:
            err_json = json.loads(error_detail)
            err_code = err_json.get("error", {}).get("code", "unknown")
            err_msg = err_json.get("error", {}).get("message", "unknown")
            logger.error(f"API 错误码: {err_code}, 错误信息: {err_msg}")
            logger.error(f"请求模型: {self.model_name}, base_url: {self.base_url}")
        except (ValueError, AttributeError, TypeError):
            # Body is not JSON or not shaped like {"error": {...}}; the raw
            # text was already logged above.
            pass
        raise
    except Exception as e:
        logger.error(f"视觉模型 API 调用异常: {str(e)}")
        raise
|
||||
|
||||
async def analyze_images(
|
||||
self,
|
||||
images: List[Dict[str, str]],
|
||||
user_prompt: str = ""
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
分析图片内容(使用视觉模型)
|
||||
|
||||
Args:
|
||||
images: 图片列表,每项包含 base64 编码和 mime_type
|
||||
user_prompt: 用户提示词
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 分析结果
|
||||
"""
|
||||
prompt = f"""你是一个专业的视觉分析专家。请分析以下图片内容。
|
||||
|
||||
{user_prompt if user_prompt else "请详细描述图片中的内容,包括文字、数据、图表、流程等所有可见信息。"}
|
||||
|
||||
请按照以下 JSON 格式输出:
|
||||
{{
|
||||
"description": "图片内容的详细描述",
|
||||
"text_content": "图片中的文字内容(如有)",
|
||||
"data_extracted": {{"键": "值"}} // 如果图片中有表格或数据
|
||||
}}
|
||||
|
||||
如果图片不包含有用信息,请返回空的描述。"""
|
||||
|
||||
try:
|
||||
response = await self.chat_with_images(
|
||||
text=prompt,
|
||||
images=images,
|
||||
temperature=0.1,
|
||||
max_tokens=4000
|
||||
)
|
||||
|
||||
content = self.extract_message_content(response)
|
||||
|
||||
# 解析 JSON
|
||||
import json
|
||||
try:
|
||||
result = json.loads(content)
|
||||
return {
|
||||
"success": True,
|
||||
"analysis": result,
|
||||
"model": self.model_name
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
return {
|
||||
"success": True,
|
||||
"analysis": {"description": content},
|
||||
"model": self.model_name
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"图片分析失败: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"analysis": None
|
||||
}
|
||||
|
||||
|
||||
# 全局单例
|
||||
# Shared module-level instance; its settings are read once at import time.
llm_service = LLMService()
|
||||
707
backend/app/services/markdown_ai_service.py
Normal file
@@ -0,0 +1,707 @@
|
||||
"""
|
||||
Markdown 文档 AI 分析服务
|
||||
|
||||
支持:
|
||||
- 分章节解析(中文章节编号:一、二、三, (一)(二)(三))
|
||||
- 结构化数据提取
|
||||
- 流式输出
|
||||
- 多种分析类型
|
||||
- 可视化图表生成
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from app.services.llm_service import llm_service
|
||||
from app.core.document_parser import MarkdownParser
|
||||
from app.services.visualization_service import visualization_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MarkdownSection:
    """One node of a document's section tree."""

    def __init__(self, number: str, title: str, level: int, content: str, line_start: int, line_end: int):
        # Section number as written in the heading, e.g. "一" or "1".
        self.number = number
        self.title = title
        # Nesting depth.
        self.level = level
        # Text of the section.
        self.content = content
        self.line_start = line_start
        self.line_end = line_end
        # Child sections.
        self.subsections: List[MarkdownSection] = []

    def to_dict(self) -> Dict[str, Any]:
        """Serialise the section and its children, previewing at most 200 characters of content."""
        preview = self.content
        if len(preview) > 200:
            preview = preview[:200] + "..."

        children = []
        for child in self.subsections:
            children.append(child.to_dict())

        return {
            "number": self.number,
            "title": self.title,
            "level": self.level,
            "content_preview": preview,
            "line_start": self.line_start,
            "line_end": self.line_end,
            "subsections": children
        }
|
||||
|
||||
|
||||
class MarkdownAIService:
|
||||
"""Markdown 文档 AI 分析服务"""
|
||||
|
||||
# Heading patterns for Chinese-style section numbering.
CHINESE_NUMBERS = ["一", "二", "三", "四", "五", "六", "七", "八", "九", "十"]
CHINESE_SUFFIX = "、"
# Level-2 heading: a Chinese numeral wrapped in parentheses, then the title.
# Both full-width (U+FF08/U+FF09) and ASCII parentheses are accepted; they are
# written as explicit literals so they cannot be mistaken for regex groups.
# Group 1 is the numeral, group 2 the title.
PARENTHESIS_PATTERN = re.compile(r'^[(\uFF08]([一二三四五六七八九十]+)[)\uFF09]\s*(.+)$')
# Level-1 heading: Chinese numeral + "、" + title.
CHINESE_SECTION_PATTERN = re.compile(r'^([一二三四五六七八九十]+)、\s*(.+)$')
# Level-3 heading: Arabic number + "." + whitespace + title.
ARABIC_SECTION_PATTERN = re.compile(r'^(\d+)\.\s+(.+)$')
|
||||
|
||||
def __init__(self):
    """Create the service with its own MarkdownParser instance."""
    self.parser = MarkdownParser()
|
||||
|
||||
def get_supported_analysis_types(self) -> list:
    """Return the names of the supported analysis types."""
    supported = (
        "summary",      # document summary
        "outline",      # outline extraction
        "key_points",   # key-point extraction
        "questions",    # question generation
        "tags",         # tag generation
        "qa",           # question/answer pairs
        "statistics",   # statistical data analysis (suits government bulletins)
        "section",      # detailed per-section analysis
        "charts",       # chart/visualisation generation
    )
    return list(supported)
|
||||
|
||||
def extract_sections(self, content: str, titles: List[Dict]) -> List[MarkdownSection]:
    """Build the section tree of a document from its heading lines.

    Recognised heading formats:
    - level 1: Chinese numeral followed by "、"
    - level 2: Chinese numeral in parentheses
    - level 3: "1. 2. 3. ..."

    A level-2 heading is recognised only inside a level-1 section, and a
    level-3 heading only inside a level-2 section; otherwise the line is
    ordinary text.

    Args:
        content: Full document text.
        titles: Accepted for interface compatibility; not used.

    Returns:
        The level-1 sections, each with nested ``subsections``. For every
        section ``line_end`` is the last line before the next heading of the
        same or a shallower level, and ``content`` is the text of that range
        as produced by ``_get_section_content``.
    """
    lines = content.split('\n')
    sections: List[MarkdownSection] = []
    # Currently open sections, outermost first (levels 1, 2, 3 in order).
    open_sections: List[MarkdownSection] = []

    heading_patterns = (
        (1, self.CHINESE_SECTION_PATTERN),
        (2, self.PARENTHESIS_PATTERN),
        (3, self.ARABIC_SECTION_PATTERN),
    )

    def close_from(level: int, last_line: int) -> None:
        # Finalise every open section at `level` or deeper.
        while open_sections and open_sections[-1].level >= level:
            finished = open_sections.pop()
            finished.line_end = last_line
            finished.content = self._get_section_content(
                lines, finished.line_start, last_line
            )

    for i, line in enumerate(lines, 1):
        stripped = line.strip()
        for level, pattern in heading_patterns:
            match = pattern.match(stripped)
            if not match:
                continue
            # A heading needs its whole parent chain to be open.
            if len(open_sections) < level - 1:
                break

            close_from(level, i - 1)
            section = MarkdownSection(
                number=match.group(1),
                title=match.group(2),
                level=level,
                content="",
                line_start=i,
                line_end=len(lines)
            )
            if open_sections:
                open_sections[-1].subsections.append(section)
            else:
                sections.append(section)
            open_sections.append(section)
            break

    # Close whatever is still open at the end of the document.
    close_from(1, len(lines))

    return sections
|
||||
|
||||
def _get_section_content(self, lines: List[str], start: int, end: int) -> str:
|
||||
"""获取指定行范围的内容"""
|
||||
if start > end:
|
||||
return ""
|
||||
content_lines = lines[start-1:end]
|
||||
# 清理:移除标题行和空行
|
||||
cleaned = []
|
||||
for line in content_lines:
|
||||
stripped = line.strip()
|
||||
if not stripped:
|
||||
continue
|
||||
# 跳过章节标题行
|
||||
if self.CHINESE_SECTION_PATTERN.match(stripped):
|
||||
continue
|
||||
if self.PARENTHESIS_PATTERN.match(stripped):
|
||||
continue
|
||||
if self.ARABIC_SECTION_PATTERN.match(stripped):
|
||||
continue
|
||||
cleaned.append(stripped)
|
||||
return '\n'.join(cleaned)
|
||||
|
||||
async def analyze_markdown(
    self,
    file_path: str,
    analysis_type: str = "summary",
    user_prompt: str = "",
    section_number: Optional[str] = None
) -> Dict[str, Any]:
    """Analyse a Markdown document with the LLM.

    Args:
        file_path: Path of the file.
        analysis_type: Analysis type.
        user_prompt: Custom prompt text from the user.
        section_number: Restrict the analysis to this section.

    Returns:
        dict: Analysis result; ``success`` is False with an ``error`` message
        when parsing, section lookup or the LLM call fails.
    """
    try:
        parsed = self.parser.parse(file_path)
        if not parsed.success:
            return {
                "success": False,
                "error": parsed.error
            }

        doc = parsed.data
        sections = self.extract_sections(doc.get("content", ""), doc.get("titles", []))

        # Default scope: the whole document, labelled with the file name.
        target_content = doc.get("content", "")
        target_title = parsed.metadata.get("filename", "")

        if section_number:
            section = self._find_section(sections, section_number)
            if not section:
                return {
                    "success": False,
                    "error": f"未找到章节: {section_number}"
                }
            target_content = section.content
            target_title = f"{section.number}、{section.title}"

        prompt = self._build_prompt(
            content=target_content,
            analysis_type=analysis_type,
            user_prompt=user_prompt,
            title=target_title
        )

        conversation = [
            {"role": "system", "content": self._get_system_prompt(analysis_type)},
            {"role": "user", "content": prompt}
        ]

        response = await llm_service.chat(
            messages=conversation,
            temperature=0.3,
            max_tokens=4000
        )
        analysis = llm_service.extract_message_content(response)

        result = {
            "success": True,
            "filename": parsed.metadata.get("filename", ""),
            "analysis_type": analysis_type,
            "section": target_title if section_number else None,
            "word_count": len(target_content),
            "structure": {
                "title_count": parsed.metadata.get("title_count", 0),
                "code_block_count": parsed.metadata.get("code_block_count", 0),
                "table_count": parsed.metadata.get("table_count", 0),
                "section_count": len(sections)
            },
            # At most 10 top-level sections are returned.
            "sections": [s.to_dict() for s in sections[:10]],
            "analysis": analysis
        }

        if analysis_type != "charts":
            return result

        # "charts": additionally turn the tables in the LLM's JSON reply into
        # visualisations.
        try:
            chart_data = self._parse_chart_json(analysis)
            if chart_data and chart_data.get("tables"):
                for table_info in chart_data.get("tables", []):
                    columns = table_info.get("columns", [])
                    rows = table_info.get("rows", [])
                    if not (columns and rows):
                        continue
                    vis_result = visualization_service.analyze_and_visualize({
                        "columns": columns,
                        "rows": [dict(zip(columns, row)) for row in rows]
                    })
                    if vis_result.get("success"):
                        table_info["visualization"] = {
                            "statistics": vis_result.get("statistics"),
                            "charts": vis_result.get("charts"),
                            "distributions": vis_result.get("distributions")
                        }
            result["chart_data"] = chart_data
        except Exception as e:
            logger.warning(f"生成可视化图表失败: {e}")
            result["chart_data"] = {"tables": [], "key_statistics": [], "chart_suggestions": []}

        return result

    except Exception as e:
        logger.error(f"Markdown AI 分析失败: {str(e)}")
        return {
            "success": False,
            "error": str(e)
        }
|
||||
|
||||
async def analyze_markdown_stream(
    self,
    file_path: str,
    analysis_type: str = "summary",
    user_prompt: str = "",
    section_number: Optional[str] = None
) -> AsyncGenerator[str, None]:
    """Analyse a Markdown document and stream the result as SSE events.

    Args:
        file_path: Path of the file.
        analysis_type: Analysis type.
        user_prompt: Custom prompt text from the user.
        section_number: Restrict the analysis to this section.

    Yields:
        str: ``data: <json>\\n\\n`` blocks: one ``start`` event with metadata,
        a ``content`` event per text delta, and a final ``done`` event with
        the full response; a payload with an ``error`` key on failure.
    """
    def sse(payload: Dict[str, Any]) -> str:
        # One SSE event. Building the dict outside the f-string keeps the
        # code valid on Python versions before 3.12.
        return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    try:
        parse_result = self.parser.parse(file_path)

        if not parse_result.success:
            yield sse({'error': parse_result.error})
            return

        data = parse_result.data
        sections = self.extract_sections(data.get("content", ""), data.get("titles", []))

        target_content = data.get("content", "")
        target_title = parse_result.metadata.get("filename", "")

        if section_number:
            section = self._find_section(sections, section_number)
            if section:
                target_content = section.content
                target_title = f"{section.number}、{section.title}"
            else:
                yield sse({'error': f'未找到章节: {section_number}'})
                return

        prompt = self._build_prompt(
            content=target_content,
            analysis_type=analysis_type,
            user_prompt=user_prompt,
            title=target_title
        )

        messages = [
            {"role": "system", "content": self._get_system_prompt(analysis_type)},
            {"role": "user", "content": prompt}
        ]

        # Initial metadata event.
        yield sse({
            'type': 'start',
            'filename': parse_result.metadata.get("filename", ""),
            'analysis_type': analysis_type,
            'section': target_title if section_number else None,
            'word_count': len(target_content)
        })

        # Relay the LLM stream, collecting the pieces for the final event.
        pieces = []
        async for chunk in llm_service.chat_stream(messages, temperature=0.3, max_tokens=4000):
            content = chunk.get("content", "")
            if content:
                pieces.append(content)
                yield sse({'type': 'content', 'delta': content})

        # Completion event.
        yield sse({'type': 'done', 'full_response': "".join(pieces)})

    except Exception as e:
        logger.error(f"Markdown AI 流式分析失败: {str(e)}")
        yield sse({'error': str(e)})
|
||||
|
||||
def _find_section(self, sections: List[MarkdownSection], number: str) -> Optional[MarkdownSection]:
|
||||
"""查找指定编号的章节"""
|
||||
# 标准化编号
|
||||
num = number.strip()
|
||||
for section in sections:
|
||||
if section.number == num or section.title == num:
|
||||
return section
|
||||
# 在子章节中查找
|
||||
found = self._find_section(section.subsections, number)
|
||||
if found:
|
||||
return found
|
||||
return None
|
||||
|
||||
def _parse_chart_json(self, json_str: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
解析 LLM 返回的 JSON 字符串
|
||||
|
||||
Args:
|
||||
json_str: LLM 返回的 JSON 字符串
|
||||
|
||||
Returns:
|
||||
解析后的字典,如果解析失败返回 None
|
||||
"""
|
||||
if not json_str:
|
||||
return None
|
||||
|
||||
try:
|
||||
# 尝试直接解析
|
||||
return json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 尝试提取 JSON 代码块
|
||||
import re
|
||||
# 匹配 ```json ... ``` 格式
|
||||
match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', json_str)
|
||||
if match:
|
||||
try:
|
||||
return json.loads(match.group(1))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 尝试找到 JSON 对象的开始和结束
|
||||
start = json_str.find('{')
|
||||
end = json_str.rfind('}')
|
||||
if start != -1 and end != -1 and end > start:
|
||||
try:
|
||||
return json.loads(json_str[start:end+1])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def _get_system_prompt(self, analysis_type: str) -> str:
|
||||
"""根据分析类型获取系统提示词"""
|
||||
prompts = {
|
||||
"summary": "你是一个专业的文档摘要助手,擅长从长文档中提取核心信息。",
|
||||
"outline": "你是一个专业的文档结构分析助手,擅长提取文档大纲和层级结构。",
|
||||
"key_points": "你是一个专业的知识提取助手,擅长从文档中提取关键信息和要点。",
|
||||
"questions": "你是一个专业的教育助手,擅长生成帮助理解文档的问题。",
|
||||
"tags": "你是一个专业的标签生成助手,擅长提取文档的主题标签。",
|
||||
"qa": "你是一个专业的问答助手,擅长基于文档内容生成问答对。",
|
||||
"statistics": "你是一个专业的统计数据分析助手,擅长分析政府统计公报中的数据。",
|
||||
"section": "你是一个专业的章节分析助手,擅长对文档的特定章节进行深入分析。",
|
||||
"charts": "你是一个专业的数据可视化助手,擅长从文档中提取数据并生成适合制作图表的数据结构。"
|
||||
}
|
||||
return prompts.get(analysis_type, "你是一个专业的文档分析助手。")
|
||||
|
||||
def _build_prompt(
|
||||
self,
|
||||
content: str,
|
||||
analysis_type: str,
|
||||
user_prompt: str,
|
||||
title: str = ""
|
||||
) -> str:
|
||||
"""根据分析类型构建提示词"""
|
||||
|
||||
# 截断内容避免超出 token 限制
|
||||
max_content_len = 6000
|
||||
if len(content) > max_content_len:
|
||||
content = content[:max_content_len] + "\n\n[内容已截断...]"
|
||||
|
||||
base_prompts = {
|
||||
"summary": f"""请对以下文档进行摘要分析:
|
||||
|
||||
文档标题:{title}
|
||||
|
||||
文档内容:
|
||||
{content}
|
||||
|
||||
请提供:
|
||||
1. 文档主要内容摘要(300字以内)
|
||||
2. 文档的目的和用途
|
||||
3. 适合的读者群体
|
||||
|
||||
请用中文回答,结构清晰。""",
|
||||
|
||||
"outline": f"""请提取以下文档的大纲结构:
|
||||
|
||||
文档标题:{title}
|
||||
|
||||
文档内容:
|
||||
{content}
|
||||
|
||||
请按层级列出文档大纲,用缩进表示层级关系。
|
||||
格式:
|
||||
一、一级标题
|
||||
(一)二级标题
|
||||
1. 三级标题
|
||||
|
||||
请用中文回答。""",
|
||||
|
||||
"key_points": f"""请从以下文档中提取关键要点:
|
||||
|
||||
文档标题:{title}
|
||||
|
||||
文档内容:
|
||||
{content}
|
||||
|
||||
请列出文档的关键要点(5-10条),每条用简洁的语言描述,并说明其在文档中的重要性。
|
||||
|
||||
请用中文回答,格式清晰。""",
|
||||
|
||||
"questions": f"""请根据以下文档生成有助于理解内容的问题:
|
||||
|
||||
文档标题:{title}
|
||||
|
||||
文档内容:
|
||||
{content}
|
||||
|
||||
请生成5-10个问题,帮助读者更好地理解文档内容。每个问题应该:
|
||||
1. 涵盖文档的重要信息点
|
||||
2. 易于理解和回答
|
||||
3. 具有思考价值
|
||||
|
||||
请用中文回答。""",
|
||||
|
||||
"tags": f"""请为以下文档生成标签:
|
||||
|
||||
文档标题:{title}
|
||||
|
||||
文档内容:
|
||||
{content[:3000]}
|
||||
|
||||
请生成5-8个标签,用逗号分隔。标签应该反映:
|
||||
- 文档的主题领域
|
||||
- 文档的类型
|
||||
- 文档的关键特征
|
||||
|
||||
请用中文回答,只需输出标签,不要其他内容。""",
|
||||
|
||||
"qa": f"""请根据以下文档生成问答对:
|
||||
|
||||
文档标题:{title}
|
||||
|
||||
文档内容:
|
||||
{content[:4000]}
|
||||
|
||||
请生成3-5个问答对,帮助读者通过问答形式理解文档内容。
|
||||
格式:
|
||||
Q1: 问题
|
||||
A1: 回答
|
||||
Q2: 问题
|
||||
A2: 回答
|
||||
|
||||
请用中文回答,内容准确。""",
|
||||
|
||||
"statistics": f"""请分析以下政府统计公报中的数据和结论:
|
||||
|
||||
文档标题:{title}
|
||||
|
||||
文档内容:
|
||||
{content}
|
||||
|
||||
请提供:
|
||||
1. 文档中涉及的主要统计数据(列出关键数字和指标)
|
||||
2. 数据的变化趋势(增长/下降)
|
||||
3. 重要的百分比和对比
|
||||
4. 数据来源和统计口径说明
|
||||
|
||||
请用中文回答,数据准确。""",
|
||||
|
||||
"section": f"""请详细分析以下文档章节:
|
||||
|
||||
章节标题:{title}
|
||||
|
||||
章节内容:
|
||||
{content}
|
||||
|
||||
请提供:
|
||||
1. 章节主要内容概括
|
||||
2. 关键信息和数据
|
||||
3. 与其他部分的关联(如有)
|
||||
4. 重要结论
|
||||
|
||||
请用中文回答,分析深入。""",
|
||||
|
||||
"charts": f"""请从以下文档中提取可用于可视化的数据,并生成适合制作图表的数据结构:
|
||||
|
||||
文档标题:{title}
|
||||
|
||||
文档内容:
|
||||
{content}
|
||||
|
||||
请完成以下任务:
|
||||
1. 识别文档中的表格数据(Markdown表格格式)
|
||||
2. 识别文档中的关键统计数据(百分比、数量、趋势等)
|
||||
3. 识别可用于比较的分类数据
|
||||
|
||||
请用 JSON 格式返回以下结构的数据(如果没有表格数据,返回空结构):
|
||||
{{
|
||||
"tables": [
|
||||
{{
|
||||
"description": "表格的描述",
|
||||
"columns": ["列名1", "列名2", ...],
|
||||
"rows": [
|
||||
["值1", "值2", ...],
|
||||
["值1", "值2", ...]
|
||||
]
|
||||
}}
|
||||
],
|
||||
"key_statistics": [
|
||||
{{
|
||||
"name": "指标名称",
|
||||
"value": "数值",
|
||||
"trend": "增长/下降/持平",
|
||||
"description": "指标说明"
|
||||
}}
|
||||
],
|
||||
"chart_suggestions": [
|
||||
{{
|
||||
"chart_type": "bar/line/pie",
|
||||
"title": "图表标题",
|
||||
"data_source": "数据来源说明"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
请确保返回的是合法的 JSON 格式。"""
|
||||
}
|
||||
|
||||
prompt = base_prompts.get(analysis_type, base_prompts["summary"])
|
||||
|
||||
if user_prompt and user_prompt.strip():
|
||||
prompt += f"\n\n用户额外需求:{user_prompt}"
|
||||
|
||||
return prompt
|
||||
|
||||
async def extract_outline(self, file_path: str) -> Dict[str, Any]:
    """Parse a file and return its two-level section outline.

    Args:
        file_path: Path handed to ``self.parser.parse``.

    Returns:
        ``{"success": True, "outline": [...]}`` where every entry describes a
        top-level section (number, title, level, start line, a content preview
        of at most 100 characters, and its direct subsections), or
        ``{"success": False, "error": ...}`` when parsing fails or any
        exception is raised.
    """
    try:
        parsed = self.parser.parse(file_path)
        if not parsed.success:
            return {"success": False, "error": parsed.error}

        payload = parsed.data
        top_sections = self.extract_sections(
            payload.get("content", ""),
            payload.get("titles", []),
        )

        # Build the structured outline, one entry per top-level section.
        outline = []
        for sec in top_sections:
            text = sec.content
            preview = text if len(text) <= 100 else text[:100] + "..."
            children = [
                {
                    "number": child.number,
                    "title": child.title,
                    "level": child.level,
                    "line": child.line_start,
                }
                for child in sec.subsections
            ]
            outline.append({
                "number": sec.number,
                "title": sec.title,
                "level": sec.level,
                "line": sec.line_start,
                "content_preview": preview,
                "subsections": children,
            })

        return {"success": True, "outline": outline}

    except Exception as e:
        logger.error(f"大纲提取失败: {str(e)}")
        return {"success": False, "error": str(e)}
|
||||
|
||||
async def extract_tables_summary(self, file_path: str) -> Dict[str, Any]:
    """Parse a file and summarise every table the parser reports.

    Args:
        file_path: Path handed to ``self.parser.parse``.

    Returns:
        On success a dict with ``tables`` (one summary per table: 1-based
        index, headers, row/column counts, the first three rows and the first
        cell of the first five rows) and ``table_count``. When the document
        has no tables, ``tables`` is empty and a ``message`` is included.
        On failure ``{"success": False, "error": ...}``.
    """
    try:
        parsed = self.parser.parse(file_path)
        if not parsed.success:
            return {"success": False, "error": parsed.error}

        tables = parsed.data.get("tables", [])
        if not tables:
            return {"success": True, "tables": [], "message": "文档中没有表格"}

        def describe(position, table):
            # Key facts for one table; empty rows contribute "" to first_column.
            rows = table.get("rows", [])
            return {
                "index": position,
                "headers": table.get("headers", []),
                "row_count": table.get("row_count", 0),
                "column_count": table.get("column_count", 0),
                "preview_rows": rows[:3],
                "first_column": [row[0] if row else "" for row in rows[:5]],
            }

        summaries = [describe(pos, tbl) for pos, tbl in enumerate(tables, start=1)]

        return {
            "success": True,
            "tables": summaries,
            "table_count": len(tables),
        }

    except Exception as e:
        logger.error(f"表格提取失败: {str(e)}")
        return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
# 全局单例
|
||||
markdown_ai_service = MarkdownAIService()
|
||||
446
backend/app/services/multi_doc_reasoning_service.py
Normal file
@@ -0,0 +1,446 @@
|
||||
"""
|
||||
多文档关联推理服务
|
||||
|
||||
跨文档信息关联和推理
|
||||
"""
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
from collections import defaultdict
|
||||
|
||||
from app.services.llm_service import llm_service
|
||||
from app.services.rag_service import rag_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MultiDocReasoningService:
|
||||
"""
|
||||
多文档关联推理服务
|
||||
|
||||
功能:
|
||||
1. 实体跨文档追踪 - 追踪同一实体在不同文档中的描述
|
||||
2. 关系抽取与推理 - 抽取实体间关系并进行推理
|
||||
3. 信息补全 - 根据多个文档的信息互补填充缺失数据
|
||||
4. 冲突检测 - 检测不同文档间的信息冲突
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.llm = llm_service
|
||||
|
||||
async def analyze_cross_documents(
|
||||
self,
|
||||
documents: List[Dict[str, Any]],
|
||||
query: Optional[str] = None,
|
||||
entity_types: Optional[List[str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
跨文档分析
|
||||
|
||||
Args:
|
||||
documents: 文档列表
|
||||
query: 查询意图(可选)
|
||||
entity_types: 要追踪的实体类型列表,如 ["机构", "人物", "地点", "数量"]
|
||||
|
||||
Returns:
|
||||
跨文档分析结果
|
||||
"""
|
||||
if not documents:
|
||||
return {"success": False, "error": "没有可用的文档"}
|
||||
|
||||
entity_types = entity_types or ["机构", "数量", "时间", "地点"]
|
||||
|
||||
try:
|
||||
# 1. 提取各文档中的实体
|
||||
entities_per_doc = await self._extract_entities_from_docs(documents, entity_types)
|
||||
|
||||
# 2. 跨文档实体对齐
|
||||
aligned_entities = self._align_entities_across_docs(entities_per_doc)
|
||||
|
||||
# 3. 关系抽取
|
||||
relations = await self._extract_relations(documents)
|
||||
|
||||
# 4. 构建知识图谱
|
||||
knowledge_graph = self._build_knowledge_graph(aligned_entities, relations)
|
||||
|
||||
# 5. 信息补全
|
||||
completed_info = await self._complete_missing_info(knowledge_graph, documents)
|
||||
|
||||
# 6. 冲突检测
|
||||
conflicts = self._detect_conflicts(aligned_entities)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"entities": aligned_entities,
|
||||
"relations": relations,
|
||||
"knowledge_graph": knowledge_graph,
|
||||
"completed_info": completed_info,
|
||||
"conflicts": conflicts,
|
||||
"summary": self._generate_summary(aligned_entities, conflicts)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"跨文档分析失败: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def _extract_entities_from_docs(
|
||||
self,
|
||||
documents: List[Dict[str, Any]],
|
||||
entity_types: List[str]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""从各文档中提取实体"""
|
||||
entities_per_doc = []
|
||||
|
||||
for idx, doc in enumerate(documents):
|
||||
doc_id = doc.get("_id", f"doc_{idx}")
|
||||
content = doc.get("content", "")[:8000] # 限制长度
|
||||
|
||||
# 使用 LLM 提取实体
|
||||
prompt = f"""从以下文档中提取指定的实体类型信息。
|
||||
|
||||
实体类型: {', '.join(entity_types)}
|
||||
|
||||
文档内容:
|
||||
{content}
|
||||
|
||||
请按以下 JSON 格式输出(只需输出 JSON):
|
||||
{{
|
||||
"entities": [
|
||||
{{"type": "机构", "name": "实体名称", "value": "相关数值(如有)", "context": "上下文描述"}},
|
||||
...
|
||||
]
|
||||
}}
|
||||
|
||||
只提取在文档中明确提到的实体,不要推测。"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个实体提取专家。请严格按JSON格式输出。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
try:
|
||||
response = await self.llm.chat(messages=messages, temperature=0.1, max_tokens=3000)
|
||||
content_response = self.llm.extract_message_content(response)
|
||||
|
||||
# 解析 JSON
|
||||
import json
|
||||
import re
|
||||
cleaned = content_response.strip()
|
||||
json_match = re.search(r'\{[\s\S]*\}', cleaned)
|
||||
if json_match:
|
||||
result = json.loads(json_match.group())
|
||||
entities = result.get("entities", [])
|
||||
entities_per_doc.append({
|
||||
"doc_id": doc_id,
|
||||
"doc_name": doc.get("metadata", {}).get("original_filename", f"文档{idx+1}"),
|
||||
"entities": entities
|
||||
})
|
||||
logger.info(f"文档 {doc_id} 提取到 {len(entities)} 个实体")
|
||||
except Exception as e:
|
||||
logger.warning(f"文档 {doc_id} 实体提取失败: {e}")
|
||||
|
||||
return entities_per_doc
|
||||
|
||||
def _align_entities_across_docs(
|
||||
self,
|
||||
entities_per_doc: List[Dict[str, Any]]
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
跨文档实体对齐
|
||||
|
||||
将同一实体在不同文档中的描述进行关联
|
||||
"""
|
||||
aligned: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
|
||||
|
||||
for doc_data in entities_per_doc:
|
||||
doc_id = doc_data["doc_id"]
|
||||
doc_name = doc_data["doc_name"]
|
||||
|
||||
for entity in doc_data.get("entities", []):
|
||||
entity_name = entity.get("name", "")
|
||||
if not entity_name:
|
||||
continue
|
||||
|
||||
# 标准化实体名(去除空格和括号内容)
|
||||
normalized = self._normalize_entity_name(entity_name)
|
||||
|
||||
aligned[normalized].append({
|
||||
"original_name": entity_name,
|
||||
"type": entity.get("type", "未知"),
|
||||
"value": entity.get("value", ""),
|
||||
"context": entity.get("context", ""),
|
||||
"source_doc": doc_name,
|
||||
"source_doc_id": doc_id
|
||||
})
|
||||
|
||||
# 合并相同实体
|
||||
result = {}
|
||||
for normalized, appearances in aligned.items():
|
||||
if len(appearances) > 1:
|
||||
result[normalized] = appearances
|
||||
logger.info(f"实体对齐: {normalized} 在 {len(appearances)} 个文档中出现")
|
||||
|
||||
return result
|
||||
|
||||
def _normalize_entity_name(self, name: str) -> str:
|
||||
"""标准化实体名称"""
|
||||
# 去除空格
|
||||
name = name.strip()
|
||||
# 去除括号内容
|
||||
name = re.sub(r'[((].*?[))]', '', name)
|
||||
# 去除"第X名"等
|
||||
name = re.sub(r'^第\d+[名位个]', '', name)
|
||||
return name.strip()
|
||||
|
||||
async def _extract_relations(
|
||||
self,
|
||||
documents: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, str]]:
|
||||
"""从文档中抽取关系"""
|
||||
relations = []
|
||||
|
||||
# 合并所有文档内容
|
||||
combined_content = "\n\n".join([
|
||||
f"【{doc.get('metadata', {}).get('original_filename', f'文档{i}')}】\n{doc.get('content', '')[:3000]}"
|
||||
for i, doc in enumerate(documents)
|
||||
])
|
||||
|
||||
prompt = f"""从以下文档内容中抽取实体之间的关系。
|
||||
|
||||
文档内容:
|
||||
{combined_content[:8000]}
|
||||
|
||||
请识别以下类型的关系:
|
||||
- 包含关系 (A包含B)
|
||||
- 隶属关系 (A隶属于B)
|
||||
- 合作关系 (A与B合作)
|
||||
- 对比关系 (A vs B)
|
||||
- 时序关系 (A先于B发生)
|
||||
|
||||
请按以下 JSON 格式输出(只需输出 JSON):
|
||||
{{
|
||||
"relations": [
|
||||
{{"entity1": "实体1", "entity2": "实体2", "relation": "关系类型", "description": "关系描述"}},
|
||||
...
|
||||
]
|
||||
}}
|
||||
|
||||
如果没有找到明确的关系,返回空数组。"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个关系抽取专家。请严格按JSON格式输出。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
try:
|
||||
response = await self.llm.chat(messages=messages, temperature=0.1, max_tokens=3000)
|
||||
content_response = self.llm.extract_message_content(response)
|
||||
|
||||
import json
|
||||
import re
|
||||
cleaned = content_response.strip()
|
||||
json_match = re.search(r'\{{[\s\S]*\}}', cleaned)
|
||||
if json_match:
|
||||
result = json.loads(json_match.group())
|
||||
relations = result.get("relations", [])
|
||||
logger.info(f"抽取到 {len(relations)} 个关系")
|
||||
except Exception as e:
|
||||
logger.warning(f"关系抽取失败: {e}")
|
||||
|
||||
return relations
|
||||
|
||||
def _build_knowledge_graph(
|
||||
self,
|
||||
aligned_entities: Dict[str, List[Dict[str, Any]]],
|
||||
relations: List[Dict[str, str]]
|
||||
) -> Dict[str, Any]:
|
||||
"""构建知识图谱"""
|
||||
nodes = []
|
||||
edges = []
|
||||
node_ids = set()
|
||||
|
||||
# 添加实体节点
|
||||
for entity_name, appearances in aligned_entities.items():
|
||||
if len(appearances) < 1:
|
||||
continue
|
||||
|
||||
first_appearance = appearances[0]
|
||||
node_id = f"entity_{len(nodes)}"
|
||||
|
||||
# 收集该实体在所有文档中的值
|
||||
values = [a.get("value", "") for a in appearances if a.get("value")]
|
||||
primary_value = values[0] if values else ""
|
||||
|
||||
nodes.append({
|
||||
"id": node_id,
|
||||
"name": entity_name,
|
||||
"type": first_appearance.get("type", "未知"),
|
||||
"value": primary_value,
|
||||
"occurrence_count": len(appearances),
|
||||
"sources": [a.get("source_doc", "") for a in appearances]
|
||||
})
|
||||
node_ids.add(entity_name)
|
||||
|
||||
# 添加关系边
|
||||
for relation in relations:
|
||||
entity1 = self._normalize_entity_name(relation.get("entity1", ""))
|
||||
entity2 = self._normalize_entity_name(relation.get("entity2", ""))
|
||||
|
||||
if entity1 in node_ids and entity2 in node_ids:
|
||||
edges.append({
|
||||
"source": entity1,
|
||||
"target": entity2,
|
||||
"relation": relation.get("relation", "相关"),
|
||||
"description": relation.get("description", "")
|
||||
})
|
||||
|
||||
return {
|
||||
"nodes": nodes,
|
||||
"edges": edges,
|
||||
"stats": {
|
||||
"entity_count": len(nodes),
|
||||
"relation_count": len(edges)
|
||||
}
|
||||
}
|
||||
|
||||
async def _complete_missing_info(
|
||||
self,
|
||||
knowledge_graph: Dict[str, Any],
|
||||
documents: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""根据多个文档补全信息"""
|
||||
completed = []
|
||||
|
||||
for node in knowledge_graph.get("nodes", []):
|
||||
if not node.get("value") and node.get("occurrence_count", 0) > 1:
|
||||
# 实体在多个文档中出现但没有数值,尝试从 RAG 检索补充
|
||||
query = f"{node['name']} 数值 数据"
|
||||
results = rag_service.retrieve(query, top_k=3, min_score=0.3)
|
||||
|
||||
if results:
|
||||
completed.append({
|
||||
"entity": node["name"],
|
||||
"type": node.get("type", "未知"),
|
||||
"source": "rag_inference",
|
||||
"context": results[0].get("content", "")[:200],
|
||||
"confidence": results[0].get("score", 0)
|
||||
})
|
||||
|
||||
return completed
|
||||
|
||||
def _detect_conflicts(
|
||||
self,
|
||||
aligned_entities: Dict[str, List[Dict[str, Any]]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""检测不同文档间的信息冲突"""
|
||||
conflicts = []
|
||||
|
||||
for entity_name, appearances in aligned_entities.items():
|
||||
if len(appearances) < 2:
|
||||
continue
|
||||
|
||||
# 检查数值冲突
|
||||
values = {}
|
||||
for appearance in appearances:
|
||||
val = appearance.get("value", "")
|
||||
if val:
|
||||
source = appearance.get("source_doc", "未知来源")
|
||||
values[source] = val
|
||||
|
||||
if len(values) > 1:
|
||||
unique_values = set(values.values())
|
||||
if len(unique_values) > 1:
|
||||
conflicts.append({
|
||||
"entity": entity_name,
|
||||
"type": "value_conflict",
|
||||
"details": values,
|
||||
"description": f"实体 '{entity_name}' 在不同文档中有不同数值: {values}"
|
||||
})
|
||||
|
||||
return conflicts
|
||||
|
||||
def _generate_summary(
|
||||
self,
|
||||
aligned_entities: Dict[str, List[Dict[str, Any]]],
|
||||
conflicts: List[Dict[str, Any]]
|
||||
) -> str:
|
||||
"""生成摘要"""
|
||||
summary_parts = []
|
||||
|
||||
total_entities = sum(len(appearances) for appearances in aligned_entities.values())
|
||||
multi_doc_entities = sum(1 for appearances in aligned_entities.values() if len(appearances) > 1)
|
||||
|
||||
summary_parts.append(f"跨文档分析完成:发现 {total_entities} 个实体")
|
||||
summary_parts.append(f"其中 {multi_doc_entities} 个实体在多个文档中被提及")
|
||||
|
||||
if conflicts:
|
||||
summary_parts.append(f"检测到 {len(conflicts)} 个潜在冲突")
|
||||
|
||||
return "; ".join(summary_parts)
|
||||
|
||||
async def answer_cross_doc_question(
|
||||
self,
|
||||
question: str,
|
||||
documents: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
跨文档问答
|
||||
|
||||
Args:
|
||||
question: 问题
|
||||
documents: 文档列表
|
||||
|
||||
Returns:
|
||||
答案结果
|
||||
"""
|
||||
# 先进行跨文档分析
|
||||
analysis_result = await self.analyze_cross_documents(documents, query=question)
|
||||
|
||||
# 构建上下文
|
||||
context_parts = []
|
||||
|
||||
# 添加实体信息
|
||||
for entity_name, appearances in analysis_result.get("entities", {}).items():
|
||||
contexts = [f"{a.get('source_doc')}: {a.get('context', '')}" for a in appearances[:2]]
|
||||
if contexts:
|
||||
context_parts.append(f"【{entity_name}】{' | '.join(contexts)}")
|
||||
|
||||
# 添加关系信息
|
||||
for relation in analysis_result.get("relations", [])[:5]:
|
||||
context_parts.append(f"【关系】{relation.get('entity1')} {relation.get('relation')} {relation.get('entity2')}: {relation.get('description', '')}")
|
||||
|
||||
context_text = "\n\n".join(context_parts) if context_parts else "未找到相关实体和关系"
|
||||
|
||||
# 使用 LLM 生成答案
|
||||
prompt = f"""基于以下跨文档分析结果,回答用户问题。
|
||||
|
||||
问题: {question}
|
||||
|
||||
分析结果:
|
||||
{context_text}
|
||||
|
||||
请直接回答问题,如果分析结果中没有相关信息,请说明"根据提供的文档无法回答该问题"。"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个基于文档的问答助手。请根据提供的信息回答问题。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
try:
|
||||
response = await self.llm.chat(messages=messages, temperature=0.2, max_tokens=2000)
|
||||
answer = self.llm.extract_message_content(response)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"question": question,
|
||||
"answer": answer,
|
||||
"supporting_entities": list(analysis_result.get("entities", {}).keys())[:10],
|
||||
"relations_count": len(analysis_result.get("relations", []))
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"跨文档问答失败: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
# 全局单例
|
||||
multi_doc_reasoning_service = MultiDocReasoningService()
|
||||
444
backend/app/services/prompt_service.py
Normal file
@@ -0,0 +1,444 @@
|
||||
"""
|
||||
提示词工程服务
|
||||
|
||||
管理和优化与大模型交互的提示词
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PromptType(Enum):
    """Kinds of prompt the prompt service keeps a template for."""

    # Document parsing
    DOCUMENT_PARSING = "document_parsing"
    # Field extraction
    FIELD_EXTRACTION = "field_extraction"
    # Table filling
    TABLE_FILLING = "table_filling"
    # Query generation
    QUERY_GENERATION = "query_generation"
    # Text summary
    TEXT_SUMMARY = "text_summary"
    # Intent classification
    INTENT_CLASSIFICATION = "intent_classification"
    # Data classification
    DATA_CLASSIFICATION = "data_classification"
|
||||
|
||||
|
||||
@dataclass
class PromptTemplate:
    """A prompt template: system prompt, user template, rules and few-shot examples."""
    name: str
    type: PromptType
    system_prompt: str
    user_template: str
    examples: List[Dict[str, str]] = field(default_factory=list)  # few-shot examples
    rules: List[str] = field(default_factory=list)  # extra output rules

    def format(
        self,
        context: Dict[str, Any],
        user_input: Optional[str] = None
    ) -> List[Dict[str, str]]:
        """
        Render the template into chat messages.

        Args:
            context: Values for the ``{key}`` placeholders of the user template.
            user_input: Optional free-text request appended to the user message.

        Returns:
            A two-element list: the system message (system prompt, then rules,
            then examples) followed by the user message.
        """
        messages = []

        # System prompt
        system_content = self.system_prompt

        # Append the rules
        if self.rules:
            system_content += "\n\n【输出规则】\n" + "\n".join([f"- {rule}" for rule in self.rules])

        # Append the examples
        if self.examples:
            system_content += "\n\n【示例】\n"
            for i, ex in enumerate(self.examples):
                system_content += f"\n示例 {i+1}:\n"
                system_content += f"输入: {ex.get('input', '')}\n"
                system_content += f"输出: {ex.get('output', '')}\n"

        messages.append({"role": "system", "content": system_content})

        # User prompt
        user_content = self._format_user_template(context, user_input)
        messages.append({"role": "user", "content": user_content})

        return messages

    def _format_user_template(
        self,
        context: Dict[str, Any],
        user_input: Optional[str]
    ) -> str:
        """Fill the ``{key}`` placeholders of the user template.

        Dict and list values are rendered as indented JSON, everything else
        with ``str()``. Placeholders without a context entry are left
        untouched. All placeholders are substituted in one pass, so text
        inserted for one key is never re-scanned: a document body that happens
        to contain ``{hint}`` stays as written instead of being overwritten by
        the ``hint`` value.
        """
        import re  # function-local: this module does not import re at top level

        content = self.user_template

        # Render only the values whose placeholder actually occurs in the template.
        replacements: Dict[str, str] = {}
        for key, value in context.items():
            placeholder = f"{{{key}}}"
            if placeholder not in content:
                continue
            if isinstance(value, (dict, list)):
                replacements[placeholder] = json.dumps(value, ensure_ascii=False, indent=2)
            else:
                replacements[placeholder] = str(value)

        if replacements:
            # Longest placeholder first keeps the alternation unambiguous.
            ordered = sorted(replacements, key=len, reverse=True)
            pattern = re.compile("|".join(re.escape(p) for p in ordered))
            content = pattern.sub(lambda m: replacements[m.group(0)], content)

        # Append the user's request
        if user_input:
            content += f"\n\n【用户需求】\n{user_input}"

        return content
|
||||
|
||||
|
||||
class PromptEngineeringService:
    """Prompt engineering service: owns one prompt template per prompt type."""

    def __init__(self):
        self.templates: Dict[PromptType, PromptTemplate] = {}
        self._init_templates()

    def _init_templates(self):
        """Register the built-in template for every prompt type.

        The prompt texts are plain strings (placeholders are filled later by
        text replacement), so JSON braces are written singly.
        """

        # ==================== Document parsing ====================
        self.templates[PromptType.DOCUMENT_PARSING] = PromptTemplate(
            name="文档解析",
            type=PromptType.DOCUMENT_PARSING,
            system_prompt="""你是一个专业的文档解析专家。你的任务是从各类文档(Word、Excel、Markdown、纯文本)中提取关键信息。

请严格按照JSON格式输出解析结果:
{
"success": true/false,
"document_type": "文档类型",
"key_fields": {"字段名": "字段值", ...},
"summary": "文档摘要(100字内)",
"structured_data": {...} // 提取的表格或其他结构化数据
}

重要规则:
- 只提取明确存在的信息,不要猜测
- 如果是表格数据,请以数组格式输出
- 日期请使用 YYYY-MM-DD 格式
- 金额请使用数字格式
- 如果无法提取某个字段,设置为 null""",
            user_template="""请解析以下文档内容:

=== 文档开始 ===
{content}
=== 文档结束 ===

请提取文档中的关键信息。""",
            examples=[
                {
                    "input": "合同金额:100万元\n签订日期:2024年1月15日\n甲方:张三\n乙方:某某公司",
                    "output": '{"success": true, "document_type": "合同", "key_fields": {"金额": 1000000, "日期": "2024-01-15", "甲方": "张三", "乙方": "某某公司"}, "summary": "甲乙双方签订的金额为100万元的合同", "structured_data": null}'
                }
            ],
            rules=[
                "只输出JSON,不要添加任何解释",
                "使用严格的JSON格式"
            ]
        )

        # ==================== Field extraction ====================
        self.templates[PromptType.FIELD_EXTRACTION] = PromptTemplate(
            name="字段提取",
            type=PromptType.FIELD_EXTRACTION,
            system_prompt="""你是一个专业的数据提取专家。你的任务是从文档内容中提取指定字段的信息。

请严格按照以下JSON格式输出:
{
"value": "提取到的值,找不到则为空字符串",
"source": "数据来源描述",
"confidence": 0.0到1.0之间的置信度
}

重要规则:
- 严格按字段名称匹配,不要提取无关信息
- 置信度反映你对提取结果的信心程度
- 如果字段不存在或无法确定,value设为空字符串,confidence设为0.0
- value必须是实际值,不能是"未找到"之类的描述""",
            user_template="""请从以下文档内容中提取指定字段的信息。

【需要提取的字段】
字段名称:{field_name}
字段类型:{field_type}
是否必填:{required}

【用户提示】
{hint}

【文档内容】
{context}

请提取字段值。""",
            examples=[
                {
                    "input": "文档内容:姓名张三,电话13800138000,邮箱zhangsan@example.com",
                    "output": '{"value": "张三", "source": "文档第1行", "confidence": 1.0}'
                }
            ],
            rules=[
                "只输出JSON,不要添加任何解释"
            ]
        )

        # ==================== Table filling ====================
        # Braces are single here: this is not an f-string, so "{{" would be
        # sent to the model literally.
        self.templates[PromptType.TABLE_FILLING] = PromptTemplate(
            name="表格填写",
            type=PromptType.TABLE_FILLING,
            system_prompt="""你是一个专业的表格填写助手。你的任务是根据提供的文档内容,填写表格模板中的字段。

请严格按照以下JSON格式输出:
{
"filled_data": {"字段1": "值1", "字段2": "值2", ...},
"fill_details": [
{"field": "字段1", "value": "值1", "source": "来源", "confidence": 0.95},
...
]
}

重要规则:
- 只填写模板中存在的字段
- 值必须来自提供的文档内容,不要编造
- 如果某个字段在文档中找不到对应值,设为空字符串
- fill_details 中记录每个字段的详细信息""",
            user_template="""请根据以下文档内容,填写表格模板。

【表格模板字段】
{fields}

【用户需求】
{hint}

【参考文档内容】
{context}

请填写表格。""",
            examples=[
                {
                    "input": "字段:姓名、电话\n文档:张三,电话是13800138000",
                    "output": '{"filled_data": {"姓名": "张三", "电话": "13800138000"}, "fill_details": [{"field": "姓名", "value": "张三", "source": "文档第1行", "confidence": 1.0}, {"field": "电话", "value": "13800138000", "source": "文档第1行", "confidence": 1.0}]}'
                }
            ],
            rules=[
                "只输出JSON,不要添加任何解释"
            ]
        )

        # ==================== Query generation ====================
        self.templates[PromptType.QUERY_GENERATION] = PromptTemplate(
            name="查询生成",
            type=PromptType.QUERY_GENERATION,
            system_prompt="""你是一个SQL查询生成专家。你的任务是根据用户的自然语言需求,生成相应的数据库查询语句。

请严格按照以下JSON格式输出:
{
"sql_query": "生成的SQL查询语句",
"explanation": "查询逻辑说明"
}

重要规则:
- 只生成 SELECT 查询语句,不要生成 INSERT/UPDATE/DELETE
- 必须包含 WHERE 条件限制查询范围
- 表名和字段名使用反引号包裹
- 确保SQL语法正确
- 如果无法生成有效的查询,sql_query设为空字符串""",
            user_template="""根据以下信息生成查询语句。

【数据库表结构】
{table_schema}

【RAG检索到的上下文】
{rag_context}

【用户查询需求】
{user_intent}

请生成SQL查询。""",
            examples=[
                {
                    "input": "表:orders(订单号, 金额, 日期, 客户)\n需求:查询2024年1月销售额超过10000的订单",
                    # \' keeps the SQL quotes inside this single-quoted literal;
                    # a doubled backslash would end the string early.
                    "output": '{"sql_query": "SELECT * FROM `orders` WHERE `日期` >= \'2024-01-01\' AND `日期` < \'2024-02-01\' AND `金额` > 10000", "explanation": "筛选2024年1月销售额超过10000的订单"}'
                }
            ],
            rules=[
                "只输出JSON,不要添加任何解释",
                "禁止生成 DROP、DELETE、TRUNCATE 等危险操作"
            ]
        )

        # ==================== Text summary ====================
        self.templates[PromptType.TEXT_SUMMARY] = PromptTemplate(
            name="文本摘要",
            type=PromptType.TEXT_SUMMARY,
            system_prompt="""你是一个专业的文本摘要专家。你的任务是对长文档进行压缩,提取关键信息。

请严格按照以下JSON格式输出:
{
"summary": "摘要内容(不超过200字)",
"key_points": ["要点1", "要点2", "要点3"],
"keywords": ["关键词1", "关键词2", "关键词3"]
}""",
            user_template="""请为以下文档生成摘要:

=== 文档开始 ===
{content}
=== 文档结束 ===

生成简明摘要。""",
            rules=[
                "只输出JSON,不要添加任何解释"
            ]
        )

        # ==================== Intent classification ====================
        self.templates[PromptType.INTENT_CLASSIFICATION] = PromptTemplate(
            name="意图分类",
            type=PromptType.INTENT_CLASSIFICATION,
            system_prompt="""你是一个意图分类专家。你的任务是分析用户的自然语言输入,判断用户的真实意图。

支持的意图类型:
- upload: 上传文档
- parse: 解析文档
- query: 查询数据
- fill: 填写表格
- export: 导出数据
- analyze: 分析数据
- other: 其他/未知

请严格按照以下JSON格式输出:
{
"intent": "意图类型",
"confidence": 0.0到1.0之间的置信度,
"entities": {"实体名": "实体值", ...}, // 识别出的关键实体
"suggestion": "建议的下一步操作"
}""",
            user_template="""请分析以下用户输入,判断其意图:

【用户输入】
{user_input}

请分类。""",
            rules=[
                "只输出JSON,不要添加任何解释"
            ]
        )

        # ==================== Data classification ====================
        self.templates[PromptType.DATA_CLASSIFICATION] = PromptTemplate(
            name="数据分类",
            type=PromptType.DATA_CLASSIFICATION,
            system_prompt="""你是一个数据分类专家。你的任务是判断数据的类型和格式。

请严格按照以下JSON格式输出:
{
"data_type": "text/number/date/email/phone/url/amount/other",
"format": "具体格式描述",
"is_valid": true/false,
"normalized_value": "规范化后的值"
}""",
            user_template="""请分析以下数据的类型和格式:

【数据】
{value}

【期望类型(如果有)】
{expected_type}

请分类。""",
            rules=[
                "只输出JSON,不要添加任何解释"
            ]
        )

    def get_prompt(
        self,
        type: PromptType,
        context: Dict[str, Any],
        user_input: Optional[str] = None
    ) -> List[Dict[str, str]]:
        """
        Return the rendered chat messages for a prompt type.

        Args:
            type: Prompt type to render.
            context: Values for the template placeholders.
            user_input: Optional free-text user request.

        Returns:
            The template's message list. When no template is registered for
            ``type``, a single user message holding ``str(context)``.
        """
        template = self.templates.get(type)
        if not template:
            logger.warning(f"未找到提示词模板: {type}")
            return [{"role": "user", "content": str(context)}]

        return template.format(context, user_input)

    def get_template(self, type: PromptType) -> Optional[PromptTemplate]:
        """Return the template registered for ``type``, or ``None``."""
        return self.templates.get(type)

    def add_template(self, template: PromptTemplate):
        """Register a custom template, replacing any template of the same type."""
        self.templates[template.type] = template
        logger.info(f"已添加提示词模板: {template.name}")

    def update_template(self, type: PromptType, **kwargs):
        """Overwrite attributes of a registered template.

        Keyword arguments that do not name an existing attribute are ignored;
        an unregistered ``type`` is a no-op.
        """
        template = self.templates.get(type)
        if template:
            for key, value in kwargs.items():
                if hasattr(template, key):
                    setattr(template, key, value)

    def optimize_prompt(
        self,
        type: PromptType,
        feedback: str,
        iteration: int = 1
    ) -> List[Dict[str, str]]:
        """
        Tighten a template according to feedback and return the rendered prompt.

        Known feedback keywords each add one rule to the template. A rule that
        is already present is not added again, so repeated feedback does not
        bloat the system prompt.

        Args:
            type: Prompt type to optimise.
            feedback: Free-text feedback, scanned for known keywords.
            iteration: Iteration counter (accepted for compatibility; unused).

        Returns:
            The messages rendered with an empty context, or ``[]`` when no
            template is registered for ``type``.
        """
        template = self.templates.get(type)
        if not template:
            return []

        # Simple strategy: map feedback keywords to extra rules
        optimization_rules = {
            "准确率低": "提高要求,明确指出必须从原文提取,不要猜测",
            "格式错误": "强调JSON格式要求,提供更详细的格式示例",
            "遗漏信息": "添加提取更多细节的要求",
        }

        for keyword, rule in optimization_rules.items():
            if keyword in feedback and rule not in template.rules:
                template.rules.append(rule)

        return template.format({}, None)
|
||||
|
||||
|
||||
# ==================== 全局单例 ====================
|
||||
|
||||
prompt_service = PromptEngineeringService()
|
||||
764
backend/app/services/rag_service.py
Normal file
@@ -0,0 +1,764 @@
|
||||
"""
|
||||
RAG 服务模块 - 检索增强生成
|
||||
|
||||
使用 sentence-transformers + Faiss 实现向量检索
|
||||
支持 BM25 关键词检索 + 向量检索混合融合
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from collections import Counter, defaultdict
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 尝试导入 sentence-transformers
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
SENTENCE_TRANSFORMERS_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
logger.warning(f"sentence-transformers 导入失败: {e}")
|
||||
SENTENCE_TRANSFORMERS_AVAILABLE = False
|
||||
SentenceTransformer = None
|
||||
|
||||
|
||||
class SimpleDocument:
    """Minimal document record: a text body plus its metadata dictionary."""

    def __init__(self, page_content: str, metadata: Dict[str, Any]):
        self.page_content, self.metadata = page_content, metadata
|
||||
|
||||
|
||||
class BM25:
    """
    BM25 keyword ranking over an in-memory corpus.

    Scores are based on term frequency and document frequency, which suits exact
    keyword matching better than pure vector search.
    """

    def __init__(self, k1: float = 1.5, b: float = 0.75):
        self.k1 = k1  # term-frequency saturation parameter
        self.b = b    # document-length normalisation parameter
        self.documents: List[str] = []
        self.doc_ids: List[str] = []
        self.avg_doc_length = 0
        self.doc_freqs: Dict[str, int] = {}   # term -> number of documents containing it
        self.idf: Dict[str, float] = {}       # term -> IDF value
        self.doc_lengths: List[int] = []
        self.doc_term_freqs: List[Dict[str, int]] = []  # per-document term counts

    def _tokenize(self, text: str) -> List[str]:
        """
        Split text into lowercase tokens.

        A token is a maximal run of CJK characters or of ASCII letters/digits;
        single-character tokens are dropped.
        """
        if not text:
            return []
        tokens = re.findall(r'[\u4e00-\u9fff]+|[a-zA-Z0-9]+', text.lower())
        return [t for t in tokens if len(t) > 1]

    def fit(self, documents: List[str], doc_ids: List[str]):
        """
        Build the BM25 index, replacing any previously fitted state.

        Args:
            documents: Document texts.
            doc_ids: Document identifiers, parallel to ``documents``.
        """
        self.documents = documents
        self.doc_ids = doc_ids
        n = len(documents)

        # Reset every derived structure, including IDF: stale terms from an
        # earlier fit would otherwise keep scoring against the new corpus.
        self.doc_freqs = defaultdict(int)
        self.doc_lengths = []
        self.doc_term_freqs = []
        self.idf = {}

        for doc in documents:
            tokens = self._tokenize(doc)
            self.doc_lengths.append(len(tokens))
            doc_tf = Counter(tokens)
            self.doc_term_freqs.append(doc_tf)
            for term in doc_tf:
                self.doc_freqs[term] += 1

        self.avg_doc_length = sum(self.doc_lengths) / n if n > 0 else 0

        for term, df in self.doc_freqs.items():
            # IDF = log(1 + (n - df + 0.5) / (df + 0.5)); the +1 keeps it positive.
            self.idf[term] = math.log((n - df + 0.5) / (df + 0.5) + 1)

        logger.info(f"BM25 索引构建完成: {n} 个文档, {len(self.idf)} 个词项")

    def search(self, query: str, top_k: int = 10) -> List[Tuple[int, float]]:
        """
        Rank all documents against ``query``.

        Args:
            query: Query text.
            top_k: Maximum number of results.

        Returns:
            ``[(document index, BM25 score), ...]`` sorted by descending score;
            empty when nothing is indexed or the query yields no tokens.
        """
        if not self.documents:
            return []

        query_tokens = self._tokenize(query)
        if not query_tokens:
            return []

        scores = [
            (idx, self._calculate_score(query_tokens, idx))
            for idx in range(len(self.documents))
        ]
        scores.sort(key=lambda x: x[1], reverse=True)
        return scores[:top_k]

    def _calculate_score(self, query_tokens: List[str], doc_idx: int) -> float:
        """Compute the BM25 score of one document for the tokenised query."""
        doc_tf = self.doc_term_freqs[doc_idx]
        doc_len = self.doc_lengths[doc_idx]

        # Guard the length ratio: an all-empty corpus has average length 0.
        avg_len = self.avg_doc_length
        length_ratio = doc_len / avg_len if avg_len > 0 else 0.0
        length_norm = self.k1 * (1 - self.b + self.b * length_ratio)

        score = 0.0
        for term in query_tokens:
            if term not in self.idf:
                continue
            tf = doc_tf.get(term, 0)
            if tf == 0:
                # Contributes nothing; skipping also avoids 0/0 when k1 == 0.
                continue
            numerator = tf * (self.k1 + 1)
            denominator = tf + length_norm
            score += self.idf[term] * numerator / denominator

        return score

    def get_scores(self, query: str) -> List[float]:
        """
        Return the BM25 score of every indexed document, in index order.

        Returns ``[]`` when nothing is indexed and all zeros when the query
        yields no tokens.
        """
        if not self.documents:
            return []

        query_tokens = self._tokenize(query)
        if not query_tokens:
            return [0.0] * len(self.documents)

        return [self._calculate_score(query_tokens, idx) for idx in range(len(self.documents))]
|
||||
|
||||
|
||||
class RAGService:
    """
    Retrieval service that fuses FAISS vector search with BM25 keyword search.

    Every indexed chunk is stored as a dict in ``self.documents``. The FAISS id
    of a vector is the position of its chunk in that list, so a search hit maps
    directly back to its entry and the mapping survives save/load.
    """

    # Default chunking parameters
    DEFAULT_CHUNK_SIZE = 500    # characters per chunk
    DEFAULT_CHUNK_OVERLAP = 50  # characters shared by consecutive chunks

    def __init__(self):
        self.embedding_model = None
        self.index: Optional[faiss.Index] = None
        self.documents: List[Dict[str, Any]] = []
        self.doc_ids: List[str] = []
        self._dimension: int = 384  # default embedding dimension
        self._initialized = False
        self._persist_dir = settings.FAISS_INDEX_DIR
        # BM25 index
        self.bm25: Optional[BM25] = None
        self._bm25_enabled = True  # BM25 is always on
        # The service is disabled when sentence-transformers failed to import.
        self._disabled = not SENTENCE_TRANSFORMERS_AVAILABLE
        if self._disabled:
            logger.warning("RAG 服务已禁用(sentence-transformers 不可用),将使用 BM25 关键词检索")
        else:
            logger.info("RAG 服务已启用(向量检索 + BM25 混合检索)")

    def _init_embeddings(self):
        """Load the sentence-embedding model once; on failure leave it as ``None``."""
        if self._disabled:
            logger.debug("RAG 已禁用,跳过嵌入模型初始化")
            return
        if self.embedding_model is None:
            model_name = 'all-MiniLM-L6-v2'
            try:
                self.embedding_model = SentenceTransformer(model_name)
                self._dimension = self.embedding_model.get_sentence_embedding_dimension()
                logger.info(f"RAG 嵌入模型初始化完成: {model_name}, 维度: {self._dimension}")
            except Exception as e:
                logger.warning(f"嵌入模型 {model_name} 加载失败: {e}")
                # Continue without vector embeddings.
                self.embedding_model = None
                self._dimension = 384
                logger.info("RAG 使用简化模式 (无向量嵌入)")

    def _init_vector_store(self):
        """Create the FAISS index if it does not exist and a model is available."""
        if self.index is None:
            self._init_embeddings()
            if self.embedding_model is None:
                # No embedding model: stay in simplified mode without an index.
                self._dimension = 384
                self.index = None
                logger.warning("RAG 嵌入模型未加载,使用简化模式")
            else:
                # Inner product over normalised vectors, addressed by explicit ids.
                self.index = faiss.IndexIDMap(faiss.IndexFlatIP(self._dimension))
                logger.info("Faiss 向量存储初始化完成")

    async def initialize(self):
        """Initialise the vector store; logs and re-raises any failure."""
        try:
            self._init_vector_store()
            self._initialized = True
            logger.info("RAG 服务初始化成功")
        except Exception as e:
            logger.error(f"RAG 服务初始化失败: {e}")
            raise

    def _normalize_vectors(self, vectors: np.ndarray) -> np.ndarray:
        """L2-normalise each row; all-zero rows are left unchanged."""
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        norms = np.where(norms == 0, 1, norms)
        return vectors / norms

    def _split_into_chunks(self, text: str, chunk_size: int = None, overlap: int = None) -> List[str]:
        """
        Split long text into overlapping chunks, preferring punctuation boundaries.

        Args:
            text: Text to split.
            chunk_size: Target chunk size in characters (default ``DEFAULT_CHUNK_SIZE``).
            overlap: Characters shared by consecutive chunks (default ``DEFAULT_CHUNK_OVERLAP``).

        Returns:
            List of non-empty, stripped chunks.
        """
        if chunk_size is None:
            chunk_size = self.DEFAULT_CHUNK_SIZE
        if overlap is None:
            overlap = self.DEFAULT_CHUNK_OVERLAP
        # Clamp degenerate parameters so the loop below always advances.
        chunk_size = max(1, chunk_size)
        overlap = max(0, overlap)

        if len(text) <= chunk_size:
            return [text] if text.strip() else []

        boundaries = '。;,,\n、'
        chunks = []
        start = 0
        text_len = len(text)

        while start < text_len:
            end = start + chunk_size

            if end < text_len:
                # Not the last chunk: look back up to 100 characters for a boundary.
                cut = None
                for i in range(end, max(start, end - 100), -1):
                    if text[i] in boundaries:
                        cut = i + 1
                        break
                if cut is None:
                    # Nothing behind: look ahead up to 50 characters.
                    for i in range(end, min(text_len, end + 50)):
                        if text[i] in boundaries:
                            cut = i + 1
                            break
                if cut is not None:
                    end = cut
            else:
                end = text_len

            chunk = text[start:end].strip()
            if chunk:
                chunks.append(chunk)

            if end >= text_len:
                # The tail is consumed; stepping back by `overlap` would only
                # emit a fragment already contained in this chunk.
                break

            # Step back by the overlap, but never to or before the current start
            # (overlap >= chunk length would otherwise loop forever).
            next_start = end - overlap
            start = next_start if next_start > start else end

        return chunks

    def index_field(
        self,
        table_name: str,
        field_name: str,
        field_description: str,
        sample_values: Optional[List[str]] = None
    ):
        """
        Index one table field (name, description, sample values).

        The call is skipped when the service is disabled or no embedding model
        could be loaded.
        """
        if self._disabled:
            logger.info(f"[RAG DISABLED] 字段索引操作已跳过: {table_name}.{field_name}")
            return

        if not self._initialized:
            self._init_vector_store()

        if self.embedding_model is None:
            logger.debug(f"字段跳过索引 (无嵌入模型): {table_name}.{field_name}")
            return

        text = f"表名: {table_name}, 字段: {field_name}, 描述: {field_description}"
        if sample_values:
            text += f", 示例值: {', '.join(sample_values)}"

        doc_id = f"{table_name}.{field_name}"
        doc = SimpleDocument(
            page_content=text,
            metadata={"table_name": table_name, "field_name": field_name, "doc_id": doc_id}
        )
        self._add_documents([doc], [doc_id])
        logger.debug(f"已索引字段: {doc_id}")

    def index_document_content(
        self,
        doc_id: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
        chunk_size: int = None,
        chunk_overlap: int = None
    ):
        """
        Chunk a document and index every chunk.

        Args:
            doc_id: Unique document identifier; chunks get ids ``{doc_id}_chunk_{i}``.
            content: Document text.
            metadata: Metadata copied into every chunk.
            chunk_size: Chunk size in characters (default 500).
            chunk_overlap: Overlap between chunks in characters (default 50).
        """
        if self._disabled:
            logger.info(f"[RAG DISABLED] 文档索引操作已跳过: {doc_id}")
            return

        if not self._initialized:
            self._init_vector_store()

        if self.embedding_model is None:
            logger.debug(f"文档跳过索引 (无嵌入模型): {doc_id}")
            return

        if chunk_size is None:
            chunk_size = self.DEFAULT_CHUNK_SIZE
        if chunk_overlap is None:
            chunk_overlap = self.DEFAULT_CHUNK_OVERLAP

        chunks = self._split_into_chunks(content, chunk_size, chunk_overlap)
        if not chunks:
            logger.warning(f"文档内容为空,跳过索引: {doc_id}")
            return

        documents = []
        chunk_ids = []
        for i, chunk in enumerate(chunks):
            chunk_metadata = metadata.copy() if metadata else {}
            chunk_metadata.update({
                "chunk_index": i,
                "total_chunks": len(chunks),
                "doc_id": doc_id
            })
            documents.append(SimpleDocument(page_content=chunk, metadata=chunk_metadata))
            chunk_ids.append(f"{doc_id}_chunk_{i}")

        self._add_documents(documents, chunk_ids)
        logger.info(f"已索引文档 {doc_id},共 {len(chunks)} 个块")

    def _rebuild_bm25(self):
        """Rebuild the BM25 index from all stored documents (BM25 has no incremental update)."""
        if not self._bm25_enabled:
            return
        if not self.documents:
            self.bm25 = None
            return
        self.bm25 = BM25()
        self.bm25.fit([d["content"] for d in self.documents], self.doc_ids.copy())

    def _rebuild_vector_index(self):
        """Re-embed all stored documents so FAISS ids equal list positions again."""
        if self.index is None:
            return
        self.index.reset()
        if not self.documents or self.embedding_model is None:
            return
        texts = [d["content"] for d in self.documents]
        embeddings = self.embedding_model.encode(texts, convert_to_numpy=True)
        embeddings = self._normalize_vectors(embeddings).astype('float32')
        self.index.add_with_ids(embeddings, np.arange(len(texts), dtype='int64'))

    def _add_documents(self, documents: List[SimpleDocument], doc_ids: List[str]):
        """Store documents in memory, refresh BM25, and add their vectors to FAISS."""
        if not documents:
            return

        # Position of the first new entry; FAISS ids are list positions.
        first_position = len(self.documents)
        pairs = list(zip(documents, doc_ids))

        # Always keep the documents in memory (used by BM25 and keyword search).
        for doc, did in pairs:
            self.documents.append({"id": did, "content": doc.page_content, "metadata": doc.metadata})
            self.doc_ids.append(did)

        # Always rebuild from the full list so BM25 row i is self.documents[i],
        # even when documents were loaded from disk before BM25 existed.
        if self._bm25_enabled:
            self._rebuild_bm25()
            logger.debug(f"BM25 索引更新: {len(documents)} 个文档")

        if self.embedding_model is None:
            logger.debug(f"文档跳过向量索引 (无嵌入模型): {len(documents)} 个文档")
            return

        texts = [doc.page_content for doc, _ in pairs]
        embeddings = self.embedding_model.encode(texts, convert_to_numpy=True)
        embeddings = self._normalize_vectors(embeddings).astype('float32')

        if self.index is None:
            self._init_vector_store()

        # hash(doc_id) is randomised per process for strings and is not a list
        # position, so positional ids are used instead.
        id_array = np.arange(first_position, first_position + len(texts), dtype='int64')
        self.index.add_with_ids(embeddings, id_array)

    def retrieve(self, query: str, top_k: int = 5, min_score: float = 0.3) -> List[Dict[str, Any]]:
        """
        Retrieve relevant chunks with hybrid search (vector + BM25).

        Args:
            query: Query text.
            top_k: Maximum number of results.
            min_score: Minimum similarity for vector hits.

        Returns:
            Result dicts with content, metadata, score, doc_id and chunk_index.
            Falls back to BM25 only, then to plain keyword matching.
        """
        if self._disabled:
            logger.info(f"[RAG DISABLED] 检索操作已跳过: query={query}, top_k={top_k}")
            return []

        if not self._initialized:
            self._init_vector_store()

        vector_results = self._vector_search(query, top_k * 2, min_score)
        bm25_results = self._bm25_search(query, top_k * 2)
        hybrid_results = self._hybrid_fusion(vector_results, bm25_results, top_k)

        if hybrid_results:
            logger.info(f"混合检索到 {len(hybrid_results)} 条相关文档块 (向量:{len(vector_results)}, BM25:{len(bm25_results)})")
            return hybrid_results

        # Fallback: BM25 only.
        if bm25_results:
            logger.info(f"降级到 BM25 检索: {len(bm25_results)} 条")
            return bm25_results

        # Fallback: keyword matching.
        logger.info("降级到关键词搜索")
        return self._keyword_search(query, top_k)

    def _vector_search(self, query: str, top_k: int, min_score: float) -> List[Dict[str, Any]]:
        """Run the FAISS search; returns ``[]`` when unavailable or on any error."""
        if self.index is None or self.index.ntotal == 0 or self.embedding_model is None:
            return []

        try:
            query_embedding = self.embedding_model.encode([query], convert_to_numpy=True)
            query_embedding = self._normalize_vectors(query_embedding).astype('float32')

            scores, indices = self.index.search(query_embedding, min(top_k * 2, self.index.ntotal))

            results = []
            for score, idx in zip(scores[0], indices[0]):
                position = int(idx)
                # FAISS pads with -1; ids outside the list come from indexes
                # written with the old hash-based ids and cannot be resolved.
                if position < 0 or position >= len(self.documents):
                    continue
                if score < min_score:
                    continue
                doc = self.documents[position]
                results.append({
                    "content": doc["content"],
                    "metadata": doc["metadata"],
                    "score": float(score),
                    "doc_id": doc["id"],
                    "chunk_index": doc["metadata"].get("chunk_index", 0),
                    "search_type": "vector"
                })

            return results
        except Exception as e:
            logger.warning(f"向量检索失败: {e}")
            return []

    def _bm25_search(self, query: str, top_k: int) -> List[Dict[str, Any]]:
        """Run BM25 and return matches with scores min-max normalised to [0, 1]."""
        if not self.bm25 or not self.documents:
            return []

        try:
            bm25_scores = self.bm25.get_scores(query)
            if not bm25_scores:
                return []

            max_score = max(bm25_scores)
            min_score_bm = min(bm25_scores)
            score_range = max_score - min_score_bm

            results = []
            for idx, score in enumerate(bm25_scores):
                if score <= 0:
                    continue
                # When every score is equal (e.g. a single document) there is no
                # range to scale by; a positive match then counts as a full match.
                normalized_score = (score - min_score_bm) / score_range if score_range > 0 else 1.0
                doc = self.documents[idx]
                results.append({
                    "content": doc["content"],
                    "metadata": doc["metadata"],
                    "score": float(normalized_score),
                    "doc_id": doc["id"],
                    "chunk_index": doc["metadata"].get("chunk_index", 0),
                    "search_type": "bm25"
                })

            results.sort(key=lambda x: x["score"], reverse=True)
            return results[:top_k]

        except Exception as e:
            logger.warning(f"BM25 检索失败: {e}")
            return []

    def _hybrid_fusion(
        self,
        vector_results: List[Dict[str, Any]],
        bm25_results: List[Dict[str, Any]],
        top_k: int
    ) -> List[Dict[str, Any]]:
        """
        Fuse vector and BM25 results with weighted Reciprocal Rank Fusion (RRF).

        Score = weight_vector * (1 / rank_vector) + weight_bm25 * (1 / rank_bm25),
        where a list that does not contain the document contributes 0.

        Args:
            vector_results: Vector hits, best first.
            bm25_results: BM25 hits, best first.
            top_k: Maximum number of fused results.

        Returns:
            Fused results sorted by descending fused score.
        """
        if not vector_results and not bm25_results:
            return []

        weight_vector = 0.6
        weight_bm25 = 0.4

        doc_scores: Dict[str, Dict[str, Any]] = {}

        def _entry_for(result: Dict[str, Any]) -> Dict[str, Any]:
            doc_id = result["doc_id"]
            if doc_id not in doc_scores:
                metadata = result["metadata"]
                doc_scores[doc_id] = {
                    "vector": 0,
                    "bm25": 0,
                    "content": result["content"],
                    "metadata": metadata,
                    "chunk_index": result.get("chunk_index", (metadata or {}).get("chunk_index", 0)),
                }
            return doc_scores[doc_id]

        for rank, result in enumerate(vector_results):
            _entry_for(result)["vector"] = weight_vector / (rank + 1)

        for rank, result in enumerate(bm25_results):
            _entry_for(result)["bm25"] = weight_bm25 / (rank + 1)

        # Raw similarity of the first vector hit per document, for reference.
        raw_vector_scores: Dict[str, float] = {}
        for result in vector_results:
            raw_vector_scores.setdefault(result["doc_id"], result["score"])

        fused_results = []
        for doc_id, scores in doc_scores.items():
            fused_results.append({
                "content": scores["content"],
                "metadata": scores["metadata"],
                "score": scores["vector"] + scores["bm25"],
                "doc_id": doc_id,
                "chunk_index": scores["chunk_index"],
                # 0.5 is the placeholder for documents found by BM25 only.
                "vector_score": raw_vector_scores.get(doc_id, 0.5),
                "bm25_score": scores["bm25"],
                "search_type": "hybrid"
            })

        fused_results.sort(key=lambda x: x["score"], reverse=True)

        logger.debug(f"混合融合: {len(fused_results)} 个文档, 向量:{len(vector_results)}, BM25:{len(bm25_results)}")

        return fused_results[:top_k]

    def _keyword_search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Last-resort keyword matching.

        Keywords are every CJK character of the query plus every ASCII word; a
        document's score is the fraction of keywords found in its content.

        Args:
            query: Query text.
            top_k: Maximum number of results.

        Returns:
            Matching chunks sorted by descending score.
        """
        if not self.documents:
            return []

        keywords = [char for char in query if '\u4e00' <= char <= '\u9fff']
        keywords.extend(re.findall(r'[a-zA-Z]+', query))

        if not keywords:
            return []

        results = []
        for doc in self.documents:
            content = doc["content"]
            matched_keywords = sum(1 for kw in keywords if kw in content)
            if matched_keywords > 0:
                results.append({
                    "content": content,
                    "metadata": doc["metadata"],
                    "score": matched_keywords / max(len(keywords), 1),
                    "doc_id": doc["id"],
                    "chunk_index": doc["metadata"].get("chunk_index", 0)
                })

        results.sort(key=lambda x: x["score"], reverse=True)

        logger.debug(f"关键词搜索返回 {len(results[:top_k])} 条结果")
        return results[:top_k]

    def retrieve_by_doc_id(self, doc_id: str, top_k: int = 10) -> List[Dict[str, Any]]:
        """
        Return the stored chunks of one document, ordered by ``chunk_index``.

        Args:
            doc_id: Document id as recorded in each chunk's metadata.
            top_k: Maximum number of chunks.
        """
        doc_chunks = [d for d in self.documents if d["metadata"].get("doc_id") == doc_id]
        doc_chunks.sort(key=lambda x: x["metadata"].get("chunk_index", 0))
        return doc_chunks[:top_k]

    def retrieve_by_table(self, table_name: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Retrieve indexed fields of a table by querying for its name."""
        return self.retrieve(f"表名: {table_name}", top_k)

    def get_vector_count(self) -> int:
        """Return the number of vectors in the FAISS index (0 when disabled or absent)."""
        if self._disabled:
            logger.info("[RAG DISABLED] get_vector_count 返回 0")
            return 0
        if self.index is None:
            return 0
        return self.index.ntotal

    def save_index(self, persist_path: str = None):
        """Persist the FAISS index and document list; does nothing without an index."""
        if persist_path is None:
            persist_path = self._persist_dir

        if self.index is not None:
            os.makedirs(persist_path, exist_ok=True)
            faiss.write_index(self.index, os.path.join(persist_path, "index.faiss"))
            with open(os.path.join(persist_path, "documents.pkl"), "wb") as f:
                pickle.dump(self.documents, f)
            logger.info(f"向量索引已保存到: {persist_path}")

    def load_index(self, persist_path: str = None):
        """
        Load the FAISS index and document list from disk and rebuild BM25.

        Nothing is changed unless both files exist. Indexes written before
        positional ids were introduced carry unusable ids; their vector hits are
        ignored until the index is rebuilt.
        """
        if persist_path is None:
            persist_path = self._persist_dir

        index_file = os.path.join(persist_path, "index.faiss")
        docs_file = os.path.join(persist_path, "documents.pkl")

        if not os.path.exists(index_file):
            logger.warning(f"向量索引文件不存在: {index_file}")
            return
        if not os.path.exists(docs_file):
            # Loading the index without its documents would leave hits unresolvable.
            logger.warning(f"文档元数据文件不存在: {docs_file}")
            return

        self._init_embeddings()

        # SECURITY: pickle.load executes arbitrary code from the file. Only load
        # from a directory this service controls; prefer JSON if it may be shared.
        with open(docs_file, "rb") as f:
            documents = pickle.load(f)

        self.index = faiss.read_index(index_file)
        self.documents = documents
        self.doc_ids = [d["id"] for d in self.documents]
        self._rebuild_bm25()
        self._initialized = True
        logger.info(f"向量索引已从 {persist_path} 加载,共 {len(self.documents)} 条")

    def delete_by_doc_id(self, doc_id: str):
        """
        Remove a document from memory, BM25 and the vector index.

        An entry is removed when its own id equals ``doc_id`` or when its
        metadata ``doc_id`` does, so all ``{doc_id}_chunk_{i}`` chunks go too.
        """
        def _belongs(entry: Dict[str, Any]) -> bool:
            metadata = entry.get("metadata") or {}
            return entry["id"] == doc_id or metadata.get("doc_id") == doc_id

        remaining = [d for d in self.documents if not _belongs(d)]
        removed = len(self.documents) - len(remaining)
        self.documents = remaining
        self.doc_ids = [d["id"] for d in remaining]

        if removed:
            # Positions shifted, so both indexes must be rebuilt.
            self._rebuild_bm25()
            self._rebuild_vector_index()

        logger.debug(f"已删除索引: {doc_id}")

    def clear(self):
        """Drop all vectors, documents and the BM25 index."""
        if self._disabled:
            logger.info("[RAG DISABLED] clear 操作已跳过")
            return
        self._init_vector_store()
        if self.index is not None:
            self.index.reset()
        self.documents = []
        self.doc_ids = []
        self.bm25 = None  # would otherwise still describe the removed documents
        logger.info("已清空所有向量索引")
|
||||
|
||||
|
||||
rag_service = RAGService()
|
||||
724
backend/app/services/table_rag_service.py
Normal file
@@ -0,0 +1,724 @@
|
||||
"""
|
||||
表结构 RAG 索引服务
|
||||
|
||||
AI 自动生成表字段的语义描述,并建立向量索引
|
||||
"""
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from app.services.llm_service import llm_service
|
||||
from app.services.rag_service import rag_service
|
||||
from app.services.excel_storage_service import excel_storage_service
|
||||
from app.core.database.mysql import mysql_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TableRAGService:
|
||||
"""
|
||||
表结构 RAG 索引服务
|
||||
|
||||
核心功能:
|
||||
1. AI 根据表头和数据生成字段语义描述
|
||||
2. 将字段描述存入向量数据库 (RAG)
|
||||
3. 支持自然语言查询表字段
|
||||
"""
|
||||
|
||||
def __init__(self):
    # RAG index building is temporarily switched off.
    self._disabled = True

    # Shared service singletons this service delegates to.
    self.excel_storage = excel_storage_service
    self.rag = rag_service
    self.llm = llm_service

    logger.info("TableRAG 服务已禁用(_disabled=True),仅记录索引操作日志")
|
||||
|
||||
def _extract_sheet_names_from_xml(self, file_path: str) -> List[str]:
    """
    Read worksheet names straight from ``xl/workbook.xml`` inside an Excel file.

    Used for workbooks whose sheet list pandas/openpyxl cannot parse because
    they contain non-standard elements.

    Args:
        file_path: Path to the Excel file.

    Returns:
        Worksheet names in document order; ``[]`` when none are found or the
        file cannot be read.
    """
    import zipfile
    from xml.etree import ElementTree as ET

    # Known spreadsheet namespaces, tried in order.
    known_namespaces = (
        'http://schemas.openxmlformats.org/spreadsheetml/2006/main',
        'http://purl.oclc.org/ooxml/spreadsheetml/main',
    )

    def _names_of(elements):
        return [el.get('name') for el in elements if el.get('name')]

    try:
        with zipfile.ZipFile(file_path, 'r') as archive:
            if 'xl/workbook.xml' not in archive.namelist():
                return []

            workbook_root = ET.fromstring(archive.read('xl/workbook.xml'))

            for ns_uri in known_namespaces:
                found = _names_of(workbook_root.findall('.//main:sheet', {'main': ns_uri}))
                if found:
                    logger.info(f"使用命名空间 {ns_uri} 提取到工作表: {found}")
                    return found

            # No known namespace matched: try no namespace, then any namespace.
            candidates = workbook_root.findall('.//sheet') or workbook_root.findall('.//{*}sheet')
            found = _names_of(candidates)
            if found:
                logger.info(f"使用通配符提取到工作表: {found}")
                return found

            logger.warning(f"无法从 XML 提取工作表,尝试的文件: {file_path}")
            return []

    except Exception as e:
        logger.warning(f"从 XML 提取工作表失败: {file_path}, error: {e}")
        return []
|
||||
|
||||
def _read_excel_sheet(self, file_path: str, sheet_name: str = None, header_row: int = 0) -> pd.DataFrame:
    """
    Read one worksheet, falling back to raw XML parsing when pandas cannot.

    pandas is tried first. If it raises or yields an empty frame, the sheet is
    parsed directly from the workbook's XML; in that mode numeric cells are
    returned as the strings stored in the file.

    Args:
        file_path: Path to the Excel file.
        sheet_name: Worksheet name; ``None`` reads the first worksheet.
        header_row: Zero-based index of the header row.

    Returns:
        DataFrame with the sheet's data (empty when the sheet has no data rows).

    Raises:
        ValueError: No worksheet, or no worksheet file, could be found.
        Exception: Any other XML-parsing error is logged and re-raised.
    """
    import zipfile
    from xml.etree import ElementTree as ET

    def _local(tag) -> str:
        """Strip the ``{namespace}`` prefix from an element tag."""
        return tag.rsplit('}', 1)[-1] if isinstance(tag, str) else ''

    def _string_item_text(item) -> str:
        """Join the text of an ``si``/``is`` element: plain ``t`` plus all rich-text runs."""
        parts = []
        for child in item:
            name = _local(child.tag)
            if name == 't':
                parts.append(child.text or '')
            elif name == 'r':
                parts.extend(t.text or '' for t in child if _local(t.tag) == 't')
        return ''.join(parts)

    def _cell_value(cell, shared, convert_bool):
        """Decode one ``c`` element; returns ``None`` for an empty cell."""
        cell_type = cell.get('t', 'n')
        if cell_type == 'inlineStr':
            inline = cell.find('{*}is')
            return (_string_item_text(inline) or None) if inline is not None else None
        v = cell.find('{*}v')
        if v is None or not v.text:
            return None
        if cell_type == 's':
            try:
                return shared[int(v.text)]
            except (ValueError, IndexError):
                return v.text
        if convert_bool and cell_type == 'b':
            return v.text == '1'
        return v.text

    try:
        # sheet_name=None makes pandas return a dict of ALL sheets, so map it to
        # index 0 to really read the first sheet.
        pandas_sheet = sheet_name if sheet_name is not None else 0
        df = pd.read_excel(file_path, sheet_name=pandas_sheet, header=header_row)
        if df is not None and not df.empty:
            return df
    except Exception as e:
        # Best effort: fall through to the XML parser, but record why.
        logger.debug(f"pandas 读取 Excel 失败,改用 XML 解析: {e}")

    logger.info(f"使用 XML 方式读取 Excel: {file_path}")

    try:
        with zipfile.ZipFile(file_path, 'r') as z:
            sheet_names = self._extract_sheet_names_from_xml(file_path)
            if not sheet_names:
                raise ValueError("无法从 Excel 文件中找到工作表")

            target_sheet = sheet_name if sheet_name and sheet_name in sheet_names else sheet_names[0]
            # NOTE(review): assumes sheetN.xml follows workbook order; the
            # authoritative mapping lives in xl/_rels/workbook.xml.rels.
            sheet_index = sheet_names.index(target_sheet) + 1

            members = set(z.namelist())

            shared_strings = []
            if 'xl/sharedStrings.xml' in members:
                ss_root = ET.fromstring(z.read('xl/sharedStrings.xml'))
                for si in ss_root.iter():
                    if _local(si.tag) == 'si':
                        shared_strings.append(_string_item_text(si))

            sheet_file = f'xl/worksheets/sheet{sheet_index}.xml'
            if sheet_file not in members:
                raise ValueError(f"工作表文件 {sheet_file} 不存在")

            root = ET.fromstring(z.read(sheet_file))

            rows_data = []
            headers = {}  # column letters -> header text

            for row in root.iter():
                if _local(row.tag) != 'row':
                    continue
                row_idx = int(row.get('r', 0))  # 1-based row number

                if row_idx == header_row + 1:
                    for cell in row:
                        if _local(cell.tag) != 'c':
                            continue
                        col_letters = ''.join(filter(str.isalpha, cell.get('r', '')))
                        value = _cell_value(cell, shared_strings, convert_bool=False)
                        # An empty header cell falls back to its column letters.
                        headers[col_letters] = col_letters if value is None else value
                    continue

                # Skip rows above the header row.
                if row_idx <= header_row + 1:
                    continue

                row_cells = {}
                for cell in row:
                    if _local(cell.tag) != 'c':
                        continue
                    col_letters = ''.join(filter(str.isalpha, cell.get('r', '')))
                    row_cells[col_letters] = _cell_value(cell, shared_strings, convert_bool=True)

                if row_cells:
                    rows_data.append(row_cells)

            if not rows_data:
                logger.warning(f"XML 解析结果为空: {file_path}, sheet: {target_sheet}")
                return pd.DataFrame()

            df = pd.DataFrame(rows_data)

            if headers:
                df.columns = [headers.get(col, col) for col in df.columns]

            logger.info(f"XML 解析完成: {len(df)} 行, {len(df.columns)} 列")
            return df

    except Exception as e:
        logger.error(f"XML 解析 Excel 失败: {e}")
        raise
|
||||
|
||||
async def generate_field_description(
    self,
    table_name: str,
    field_name: str,
    sample_values: List[Any],
    all_fields: Dict[str, List[Any]] = None
) -> str:
    """
    Ask the LLM for a short semantic description of one table field.

    Args:
        table_name: Table name.
        field_name: Field name.
        sample_values: Sample values of the field; the first 10 non-None ones are shown.
        all_fields: Sample values of the other fields (first 3 each), given to
            the model as context.

    Returns:
        The stripped description text, or ``"<field_name>: 数据字段"`` when the
        LLM call fails.
    """
    # Context block listing the sibling fields, if any were supplied.
    context = ""
    if all_fields:
        sibling_lines = [
            f"- {fname}: {', '.join([str(v) for v in values[:3]])}\n"
            for fname, values in all_fields.items()
            if fname != field_name and values
        ]
        context = "\n其他字段示例:\n" + "".join(sibling_lines)

    sample_text = ', '.join([str(v) for v in sample_values[:10] if v is not None])

    prompt = f"""你是一个数据语义分析专家。请根据字段名和示例值,推断该字段的语义含义。

表名:{table_name}
字段名:{field_name}
示例值:{sample_text}
{context}

请生成一段简洁的字段语义描述(不超过50字),说明:
1. 该字段代表什么含义
2. 数据格式或单位(如果有)
3. 可能的业务用途

只输出描述文字,不要其他内容。"""

    system_message = {"role": "system", "content": "你是一个专业的数据分析师。"}
    user_message = {"role": "user", "content": prompt}

    try:
        response = await self.llm.chat(
            messages=[system_message, user_message],
            temperature=0.3,
            max_tokens=200
        )
        return self.llm.extract_message_content(response).strip()
    except Exception as e:
        logger.error(f"生成字段描述失败: {str(e)}")
        return f"{field_name}: 数据字段"
|
||||
|
||||
async def build_table_rag_index(
    self,
    file_path: str,
    filename: str,
    sheet_name: Optional[str] = None,
    header_row: int = 0,
    sample_size: int = 10
) -> Dict[str, Any]:
    """
    Build the RAG index for one Excel sheet and store the workbook in MySQL.

    Steps:
        1. Open the workbook and list its sheets; when pandas reports none,
           fall back to ``self._extract_sheet_names_from_xml``.
        2. Read the sheet through ``self._read_excel_sheet``.
        3. Generate a description for every column (one LLM call per column).
        4. Index each description unless ``self._disabled`` is set.
        5. Store the workbook through ``self.excel_storage.store_excel``.

    Args:
        file_path: Path of the Excel file.
        filename: Original file name; the table name is derived from it.
        sheet_name: Sheet to read. When given but absent from the workbook,
            the first sheet is used instead.
        header_row: Row index of the header row.
        sample_size: Number of non-null values sampled per column.

    Returns:
        On success, a dict with ``success``, ``table_name``, ``field_count``,
        ``indexed_fields``, ``errors``, ``indexed_count`` and either
        ``mysql_table``/``row_count`` or ``mysql_warning``. On failure,
        ``{"success": False, "error": <message>}``.
    """
    results = {
        "success": True,
        "table_name": "",
        "field_count": 0,
        "indexed_fields": [],
        "errors": []
    }

    try:
        # 1. Check that the workbook can be opened and has sheets.
        logger.info(f"正在检查Excel文件: {file_path}")
        try:
            # The context manager closes the workbook handle as soon as the
            # sheet names are read; previously the handle was never closed.
            with pd.ExcelFile(file_path) as xls_file:
                sheet_names = list(xls_file.sheet_names)
            logger.info(f"Excel文件工作表: {sheet_names}")

            # No sheet names from pandas: try to pull them from the XML.
            if not sheet_names:
                sheet_names = self._extract_sheet_names_from_xml(file_path)
                logger.info(f"从XML提取工作表: {sheet_names}")

            if not sheet_names:
                return {"success": False, "error": "Excel 文件没有工作表"}
        except Exception as e:
            logger.error(f"读取Excel文件失败: {file_path}, error: {e}")
            return {"success": False, "error": f"无法读取Excel文件: {str(e)}"}

        # 2. Read the sheet, falling back to the first sheet on a bad name.
        if sheet_name:
            if sheet_name not in sheet_names:
                logger.warning(f"指定的工作表 '{sheet_name}' 不存在,使用第一个工作表: {sheet_names[0]}")
                sheet_name = sheet_names[0]
        df = self._read_excel_sheet(file_path, sheet_name=sheet_name, header_row=header_row)

        logger.info(f"读取到数据: {len(df)} 行, {len(df.columns)} 列")

        if df.empty:
            return {"success": False, "error": "Excel 文件为空"}

        # Normalise column labels to strings.
        df.columns = [str(c) for c in df.columns]
        table_name = self.excel_storage._sanitize_table_name(filename)
        results["table_name"] = table_name
        results["field_count"] = len(df.columns)
        logger.info(f"表名: {table_name}, 字段数: {len(df.columns)}")

        # 3. Initialise the vector store if that has not happened yet.
        if not self.rag._initialized:
            self.rag._init_vector_store()

        # 4. Sample every column first so each prompt can include the
        #    other columns as context.
        all_fields_data = {}
        for col in df.columns:
            sample_values = df[col].dropna().head(sample_size).tolist()
            all_fields_data[col] = sample_values

        # Describe and index column by column; a failing column is recorded
        # in results["errors"] and does not abort the others.
        indexed_count = 0
        for col in df.columns:
            try:
                sample_values = all_fields_data[col]

                description = await self.generate_field_description(
                    table_name=table_name,
                    field_name=col,
                    sample_values=sample_values,
                    all_fields=all_fields_data
                )

                # Write to the RAG store unless indexing is disabled.
                if self._disabled:
                    logger.info(f"[RAG DISABLED] 字段索引已跳过: {table_name}.{col}")
                else:
                    self.rag.index_field(
                        table_name=table_name,
                        field_name=col,
                        field_description=description,
                        sample_values=[str(v) for v in sample_values[:5]]
                    )

                indexed_count += 1
                results["indexed_fields"].append({
                    "field": col,
                    "description": description
                })

                logger.info(f"字段已索引: {table_name}.{col}")

            except Exception as e:
                error_msg = f"字段 {col} 索引失败: {str(e)}"
                logger.error(error_msg)
                results["errors"].append(error_msg)

        # 5. Store the workbook in MySQL; a failure here is only a warning.
        logger.info(f"开始存储到MySQL: {filename}")
        store_result = await self.excel_storage.store_excel(
            file_path=file_path,
            filename=filename,
            sheet_name=sheet_name,
            header_row=header_row
        )

        if store_result.get("success"):
            results["mysql_table"] = store_result.get("table_name")
            results["row_count"] = store_result.get("row_count")
        else:
            results["mysql_warning"] = "MySQL 存储失败: " + str(store_result.get("error"))

        results["indexed_count"] = indexed_count
        logger.info(f"表 {table_name} RAG 索引构建完成,共 {indexed_count} 个字段")

        return results

    except Exception as e:
        logger.error(f"构建 RAG 索引失败: {str(e)}")
        return {"success": False, "error": str(e)}
|
||||
|
||||
async def index_document_table(
    self,
    doc_id: str,
    filename: str,
    table_data: Dict[str, Any],
    source_doc_type: str
) -> Dict[str, Any]:
    """
    Store a table extracted from an unstructured document in MySQL and
    index its fields for RAG.

    Args:
        doc_id: ID of the source document.
        filename: Source file name; the table name is derived from it plus
            ``table_data["table_index"]`` (0 when absent).
        table_data: Table payload in one of two layouts:
            1. docx/txt: ``{"rows": [["col1", "col2"], ["val1", "val2"]], ...}``
               where the first row is the header.
            2. md: ``{"headers": [...], "rows": [...], ...}``.
        source_doc_type: Source document type (docx/md/txt). Not read by
            this method.

    Returns:
        On success, a dict with ``success``, ``table_name``, ``field_count``,
        ``indexed_fields``, ``errors``, ``indexed_count`` and either
        ``mysql_table``/``row_count`` or ``mysql_warning``. On failure,
        ``{"success": False, "error": <message>}``.
    """
    results = {
        "success": True,
        "table_name": "",
        "field_count": 0,
        "indexed_fields": [],
        "errors": []
    }

    try:
        # Accept both payload layouts.
        if "headers" in table_data:
            # md layout: headers and rows are delivered separately.
            columns = table_data.get("headers") or []
            data_rows = table_data.get("rows") or []
        else:
            # docx/txt layout: the first row is the header.
            rows = table_data.get("rows", [])
            if not rows or len(rows) < 2:
                return {"success": False, "error": "表格数据不足"}
            columns = rows[0]
            data_rows = rows[1:]

        # A table without columns can be neither stored nor indexed; reject
        # it in both layouts (the md branch previously had no check).
        if not columns:
            return {"success": False, "error": "表格数据不足"}

        # Table name: sanitised source file name plus the table index.
        base_name = self.excel_storage._sanitize_table_name(filename)
        table_name = f"{base_name}_table{table_data.get('table_index', 0)}"

        results["table_name"] = table_name
        results["field_count"] = len(columns)

        # 1. Initialise the vector store if needed.
        if not self.rag._initialized:
            self.rag._init_vector_store()

        # 2. Payload handed to the storage layer.
        structured_data = {
            "columns": columns,
            "rows": data_rows
        }

        # 3. Store in MySQL; a failure here is only recorded as a warning.
        store_result = await self.excel_storage.store_structured_data(
            table_name=table_name,
            data=structured_data,
            source_doc_id=doc_id
        )

        if store_result.get("success"):
            results["mysql_table"] = store_result.get("table_name")
            results["row_count"] = store_result.get("row_count")
        else:
            results["mysql_warning"] = "MySQL 存储失败: " + str(store_result.get("error"))

        # 4. Collect per-column values; rows shorter than the header simply
        #    contribute nothing to the missing columns.
        all_fields_data = {}
        for i, col in enumerate(columns):
            col_values = [row[i] for row in data_rows if i < len(row)]
            all_fields_data[col] = col_values

        # Describe and index column by column; a failing column is recorded
        # in results["errors"] and does not abort the others.
        indexed_count = 0
        for col in columns:
            try:
                col_values = all_fields_data.get(col, [])

                description = await self.generate_field_description(
                    table_name=table_name,
                    field_name=col,
                    sample_values=col_values[:10],
                    all_fields=all_fields_data
                )

                # Write to the RAG store unless indexing is disabled.
                if self._disabled:
                    logger.info(f"[RAG DISABLED] 文档表格字段索引已跳过: {table_name}.{col}")
                else:
                    self.rag.index_field(
                        table_name=table_name,
                        field_name=col,
                        field_description=description,
                        sample_values=[str(v) for v in col_values[:5]]
                    )

                indexed_count += 1
                results["indexed_fields"].append({
                    "field": col,
                    "description": description
                })

                logger.info(f"文档表格字段已索引: {table_name}.{col}")

            except Exception as e:
                error_msg = f"字段 {col} 索引失败: {str(e)}"
                logger.error(error_msg)
                results["errors"].append(error_msg)

        results["indexed_count"] = indexed_count
        logger.info(f"文档表格 {table_name} RAG 索引构建完成,共 {indexed_count} 个字段")

        return results

    except Exception as e:
        logger.error(f"构建文档表格 RAG 索引失败: {str(e)}")
        return {"success": False, "error": str(e)}
|
||||
|
||||
async def query_table_by_natural_language(
    self,
    user_query: str,
    top_k: int = 5
) -> Dict[str, Any]:
    """
    Find indexed table fields that match a natural-language query.

    Args:
        user_query: The user's query text.
        top_k: Maximum number of hits requested from ``self.rag.retrieve``.

    Returns:
        ``{"success": True, "query", "matched_fields", "count"}`` on
        success, or ``{"success": False, "error": <message>}`` when
        retrieval or result parsing raises.
    """
    def _as_match(hit):
        # Flatten one retrieval hit into the response record.
        meta = hit.get("metadata", {})
        return {
            "table_name": meta.get("table_name", ""),
            "field_name": meta.get("field_name", ""),
            "description": hit.get("content", ""),
            "score": hit.get("score", 0),
            "sample_values": []  # left empty here; may be filled in later
        }

    try:
        # 1. Retrieve, 2. convert every hit.
        hits = self.rag.retrieve(user_query, top_k=top_k)
        matches = [_as_match(hit) for hit in hits]
    except Exception as exc:
        logger.error(f"查询失败: {str(exc)}")
        return {"success": False, "error": str(exc)}

    return {
        "success": True,
        "query": user_query,
        "matched_fields": matches,
        "count": len(matches)
    }
|
||||
|
||||
async def get_table_fields_with_description(
    self,
    table_name: str
) -> List[Dict[str, Any]]:
    """
    Return the indexed fields of one table together with their descriptions.

    Args:
        table_name: Name of the table to look up.

    Returns:
        One dict per hit from ``self.rag.retrieve_by_table`` (requested
        with ``top_k=50``) holding ``table_name``, ``field_name``,
        ``description`` and ``score``. An empty list when the lookup raises.
    """
    try:
        collected = []
        # Pull this table's entries from the RAG store and flatten each hit.
        for hit in self.rag.retrieve_by_table(table_name, top_k=50):
            meta = hit.get("metadata", {})
            entry = {
                "table_name": meta.get("table_name", ""),
                "field_name": meta.get("field_name", ""),
                "description": hit.get("content", ""),
                "score": hit.get("score", 0)
            }
            collected.append(entry)
        return collected
    except Exception as exc:
        logger.error(f"获取字段失败: {str(exc)}")
        return []
|
||||
|
||||
async def rebuild_all_table_indexes(self) -> Dict[str, Any]:
    """
    Rebuild the RAG index for every table known to the storage layer.

    Clears the existing index, then for each table returned by
    ``excel_storage.list_tables`` reads its schema, samples up to 10 rows
    per column, regenerates the column description and re-indexes it.
    The bookkeeping columns ``id``, ``created_at`` and ``updated_at`` are
    skipped.

    Returns:
        ``{"success": True, "tables_processed", "total_fields", "errors"}``;
        a table that raises is recorded in ``errors`` and the remaining
        tables are still processed. ``{"success": False, "error": <message>}``
        when clearing the index or listing the tables fails.
    """
    try:
        # Drop the existing index before rebuilding.
        self.rag.clear()

        tables = await self.excel_storage.list_tables()

        results = {
            "success": True,
            "tables_processed": 0,
            "total_fields": 0,
            "errors": []
        }

        for table_name in tables:
            try:
                schema = await self.excel_storage.get_table_schema(table_name)

                if not schema:
                    continue

                # Initialise the vector store if needed.
                if not self.rag._initialized:
                    self.rag._init_vector_store()

                for col_info in schema:
                    field_name = col_info.get("COLUMN_NAME", "")
                    if field_name in ("id", "created_at", "updated_at"):
                        continue

                    # Sample column data.
                    samples = await self.excel_storage.query_table(
                        table_name,
                        columns=[field_name],
                        limit=10
                    )
                    # Drop only missing values (None / empty string). A
                    # plain truthiness test would also discard legitimate
                    # samples such as 0 or False.
                    sample_values = [
                        value
                        for value in (row.get(field_name) for row in samples)
                        if value is not None and value != ""
                    ]

                    description = await self.generate_field_description(
                        table_name=table_name,
                        field_name=field_name,
                        sample_values=sample_values
                    )

                    # Honour the disabled flag the same way the other
                    # indexing methods of this service do.
                    if self._disabled:
                        logger.info(f"[RAG DISABLED] 字段索引已跳过: {table_name}.{field_name}")
                    else:
                        self.rag.index_field(
                            table_name=table_name,
                            field_name=field_name,
                            field_description=description,
                            sample_values=[str(v) for v in sample_values[:5]]
                        )

                    results["total_fields"] += 1

                results["tables_processed"] += 1
                logger.info(f"表 {table_name} 索引重建完成")

            except Exception as e:
                error_msg = f"表 {table_name} 索引失败: {str(e)}"
                logger.error(error_msg)
                results["errors"].append(error_msg)

        logger.info(f"全部 {results['tables_processed']} 个表索引重建完成")
        return results

    except Exception as e:
        logger.error(f"重建索引失败: {str(e)}")
        return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
# ==================== 全局单例 ====================
|
||||
|
||||
# Module-level TableRAGService instance, created once when this module is imported.
table_rag_service = TableRAGService()
|
||||
2945
backend/app/services/template_fill_service.py
Normal file
218
backend/app/services/text_analysis_service.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""
|
||||
文本分析服务 - 从 AI 分析结果中提取结构化数据用于可视化
|
||||
"""
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
import re
|
||||
import json
|
||||
|
||||
from app.services.llm_service import llm_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TextAnalysisService:
    """Extracts chart-ready structured data from free-text AI analysis results."""

    # System prompt describing the JSON schema the LLM must answer with.
    _SYSTEM_PROMPT = """你是一个专业的数据提取助手。你的任务是从AI分析结果中提取结构化数据,用于生成图表。

请按照以下要求提取数据:

1. 数值型数据:
- 提取所有的数值、统计信息、百分比等
- 为每个数值创建一个条目,包含:名称、值、单位(如果有)
- 格式示例:{"name": "销售额", "value": 123456.78, "unit": "元"}

2. 分类数据:
- 提取所有的类别、状态、枚举值等
- 为每个类别创建一个条目,包含:名称、值、数量(如果有)
- 格式示例:{"name": "产品类别", "value": "电子产品", "count": 25}

3. 时间序列数据:
- 提取所有的时间相关数据(年月、季度、日期等)
- 格式示例:{"name": "2025年1月", "value": 12345}

4. 对比数据:
- 提取所有的对比、排名、趋势等数据
- 格式示例:{"name": "同比增长", "value": 15.3, "unit": "%"}

5. 表格数据:
- 如果分析结果中包含表格或列表形式的数据,提取出来
- 格式:{"columns": ["列1", "列2"], "rows": [{"列1": "值1", "列2": "值2"}]}

重要规则:
- 只提取明确提到的数据和数值
- 如果某种类型的数据不存在,返回空数组 []
- 确保所有数值都是有效的数字类型
- 保持数据的原始精度
- 返回的 JSON 必须完整且格式正确
- 表格数据最多提取 20 行

请以 JSON 格式返回,不要添加任何 Markdown 标记或解释文字,只返回纯 JSON:
{
"success": true,
"data": {
"numeric_data": [
{"name": string, "value": number, "unit": string|null}
],
"categorical_data": [
{"name": string, "value": string, "count": number|null}
],
"time_series_data": [
{"name": string, "value": number}
],
"comparison_data": [
{"name": string, "value": number, "unit": string|null}
],
"table_data": {
"columns": string[],
"rows": object[]
} | null
},
"metadata": {
"total_items": number,
"data_types": string[]
}
}"""

    def __init__(self):
        self.llm_service = llm_service

    async def extract_structured_data(
        self,
        analysis_text: str,
        original_filename: str = "",
        file_type: str = "text"
    ) -> Dict[str, Any]:
        """
        Extract structured data from the text of an AI analysis result.

        Args:
            analysis_text: The AI analysis text; only the first 8000
                characters are sent to the LLM to stay within token limits.
            original_filename: Original file name, included in the prompt.
            file_type: File type, included in the prompt.

        Returns:
            The JSON object parsed from the LLM reply, or
            ``{"success": False, "error": ...}`` (plus ``raw_content`` when
            the reply could not be parsed) on failure.
        """
        # Cap the text length; a missing text is treated as empty.
        max_text_length = 8000
        truncated_text = (analysis_text or "")[:max_text_length]

        user_message = f"""请从以下 AI 分析结果中提取结构化数据:

原始文件名:{original_filename}
文件类型:{file_type}

AI 分析结果:
{truncated_text}

请按照系统提示的要求提取数据并返回纯 JSON 格式。"""

        messages = [
            {"role": "system", "content": self._SYSTEM_PROMPT},
            {"role": "user", "content": user_message}
        ]

        try:
            logger.info(f"开始提取结构化数据,文本长度: {len(truncated_text)}")

            response = await self.llm_service.chat(
                messages=messages,
                temperature=0.1,
                max_tokens=4000
            )

            content = self.llm_service.extract_message_content(response)
            logger.info(f"LLM 返回内容长度: {len(content)}")

            result = self._extract_json_simple(content)

            if not result:
                logger.error("无法从 LLM 响应中提取有效的 JSON")
                return {
                    "success": False,
                    "error": "AI 返回的数据格式不正确或被截断",
                    "raw_content": content[:500]
                }

            logger.info("成功提取结构化数据")
            return result

        except Exception as e:
            logger.error(f"提取结构化数据失败: {str(e)}")
            return {
                "success": False,
                "error": str(e)
            }

    def _extract_json_simple(self, content: str) -> Optional[Dict[str, Any]]:
        """
        Pull a JSON object out of an LLM reply.

        Strategies, in order:
            1. A fenced code block (```json ... ``` or plain ``` ... ```).
            2. The object starting at the first ``{`` in the reply, decoded
               with ``json.JSONDecoder.raw_decode`` so that braces inside
               string values do not end the object early.
            3. Parsing the whole reply.

        Args:
            content: Text returned by the LLM.

        Returns:
            The parsed object, or ``None`` when no JSON object can be
            decoded (including truncated replies).
        """
        if not content:
            return None

        try:
            # Strategy 1: fenced code block. The payload needs its own
            # capture group; the earlier pattern had none, so group(1)
            # raised and every fenced reply was rejected.
            fenced = re.search(r'```(?:json)?\s*(\{[\s\S]*?\})\s*```', content)
            if fenced:
                try:
                    parsed = json.loads(fenced.group(1))
                    logger.info("从代码块中提取 JSON")
                    return parsed
                except json.JSONDecodeError:
                    # Malformed fence content: fall through to strategy 2.
                    pass

            # Strategy 2: decode the object that starts at the first '{'.
            # Only the first '{' is tried so a truncated reply yields None
            # instead of some nested fragment.
            start = content.find('{')
            if start != -1:
                parsed, _ = json.JSONDecoder().raw_decode(content, start)
                logger.info("从大括号中提取 JSON")
                return parsed

            # Strategy 3: parse the reply as-is; only an object is accepted.
            logger.info("尝试直接解析整个内容")
            parsed = json.loads(content)
            return parsed if isinstance(parsed, dict) else None

        except json.JSONDecodeError as e:
            logger.error(f"JSON 解析失败: {str(e)}")
            logger.error(f"原始内容(前 500 字符): {content[:500]}...")
            return None
        except Exception as e:
            logger.error(f"提取 JSON 失败: {str(e)}")
            return None

    def detect_data_types(self, data: Dict[str, Any]) -> List[str]:
        """
        List which data categories are present in an extraction result.

        Args:
            data: Extraction result; categories are read from ``data["data"]``.

        Returns:
            Subset of ``["numeric", "categorical", "time_series",
            "comparison", "table"]`` in that order, one entry per non-empty
            category. Empty when ``data["data"]`` is missing or ``None``.
        """
        payload = data.get("data") or {}
        key_to_type = (
            ("numeric_data", "numeric"),
            ("categorical_data", "categorical"),
            ("time_series_data", "time_series"),
            ("comparison_data", "comparison"),
            ("table_data", "table"),
        )
        return [type_name for key, type_name in key_to_type if payload.get(key)]
|
||||
|
||||
|
||||
# 全局单例
|
||||
# Module-level TextAnalysisService instance, created once when this module is imported.
text_analysis_service = TextAnalysisService()
|
||||
388
backend/app/services/visualization_service.py
Normal file
@@ -0,0 +1,388 @@
|
||||
"""
|
||||
数据可视化服务 - 使用 matplotlib/plotly 生成统计图表
|
||||
"""
|
||||
import io
|
||||
import base64
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
|
||||
# 使用字体辅助模块配置中文字体
|
||||
from app.services.font_helper import configure_matplotlib_fonts
|
||||
|
||||
configure_matplotlib_fonts()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VisualizationService:
    """Generates statistics and base64-encoded matplotlib charts for tabular data."""

    def __init__(self):
        # <three directories above this file>/data/charts, created on demand.
        self.output_dir = Path(__file__).resolve().parent.parent.parent / "data" / "charts"
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def analyze_and_visualize(
        self,
        excel_data: Dict[str, Any],
        analysis_type: str = "statistics"
    ) -> Dict[str, Any]:
        """
        Analyse tabular data and render charts for it.

        Args:
            excel_data: Parsed table with ``columns`` and ``rows`` entries.
            analysis_type: Analysis type; not read by this method.

        Returns:
            On success a dict with ``success``, ``statistics``, ``charts``,
            ``distributions``, ``row_count`` and ``column_count``; otherwise
            ``{"success": False, "error": <message>}``.
        """
        try:
            columns = excel_data.get("columns", [])
            rows = excel_data.get("rows", [])

            if not columns or not rows:
                return {
                    "success": False,
                    "error": "没有数据可用于分析"
                }

            df = pd.DataFrame(rows, columns=columns)

            # Split the columns by dtype.
            numeric_columns = df.select_dtypes(include=[np.number]).columns.tolist()
            categorical_columns = df.select_dtypes(exclude=[np.number]).columns.tolist()

            statistics = self._generate_statistics(df, numeric_columns, categorical_columns)
            charts = self._generate_charts(df, numeric_columns, categorical_columns)
            distributions = self._generate_distributions(df, categorical_columns)

            return {
                "success": True,
                "statistics": statistics,
                "charts": charts,
                "distributions": distributions,
                "row_count": len(df),
                "column_count": len(columns)
            }

        except Exception as e:
            logger.error(f"可视化分析失败: {str(e)}", exc_info=True)
            return {
                "success": False,
                "error": str(e)
            }

    def _generate_statistics(
        self,
        df: pd.DataFrame,
        numeric_columns: List[str],
        categorical_columns: List[str]
    ) -> Dict[str, Any]:
        """
        Compute summary statistics per column.

        Returns a dict with a ``numeric`` and a ``categorical`` section
        keyed by column name. A column whose statistics raise is logged
        and left out.
        """
        statistics = {
            "numeric": {},
            "categorical": {}
        }

        # Numeric columns.
        for col in numeric_columns:
            try:
                stats = {
                    "count": int(df[col].count()),
                    "mean": float(df[col].mean()),
                    "median": float(df[col].median()),
                    # std is reported as 0 when fewer than two values exist.
                    "std": float(df[col].std()) if df[col].count() > 1 else 0,
                    "min": float(df[col].min()),
                    "max": float(df[col].max()),
                    "q25": float(df[col].quantile(0.25)),
                    "q75": float(df[col].quantile(0.75)),
                    "missing": int(df[col].isna().sum())
                }
                statistics["numeric"][col] = stats
            except Exception as e:
                logger.warning(f"列 {col} 统计失败: {str(e)}")

        # Categorical columns.
        for col in categorical_columns:
            try:
                value_counts = df[col].value_counts()
                stats = {
                    "unique": int(df[col].nunique()),
                    "most_common": str(value_counts.index[0]) if len(value_counts) > 0 else "",
                    "most_common_count": int(value_counts.iloc[0]) if len(value_counts) > 0 else 0,
                    "missing": int(df[col].isna().sum()),
                    "distribution": {str(k): int(v) for k, v in value_counts.items()}
                }
                statistics["categorical"][col] = stats
            except Exception as e:
                logger.warning(f"列 {col} 统计失败: {str(e)}")

        return statistics

    def _generate_charts(
        self,
        df: pd.DataFrame,
        numeric_columns: List[str],
        categorical_columns: List[str]
    ) -> Dict[str, Any]:
        """
        Render all charts for the data frame.

        Returns a dict with ``histograms``, ``bar_charts`` and ``box_plots``
        lists, plus ``correlation`` when at least two numeric columns exist
        and the heatmap could be drawn.
        """
        charts = {}

        # 1. Histograms for at most 5 numeric columns.
        charts["histograms"] = []
        for col in numeric_columns[:5]:
            chart_data = self._create_histogram(df[col], col)
            if chart_data:
                charts["histograms"].append(chart_data)

        # 2. Bar charts for at most 5 categorical columns.
        charts["bar_charts"] = []
        for col in categorical_columns[:5]:
            chart_data = self._create_bar_chart(df[col], col)
            if chart_data:
                charts["bar_charts"].append(chart_data)

        # 3. One combined box plot for at most 5 numeric columns.
        charts["box_plots"] = []
        if len(numeric_columns) > 0:
            chart_data = self._create_box_plot(df[numeric_columns[:5]], numeric_columns[:5])
            if chart_data:
                charts["box_plots"].append(chart_data)

        # 4. Correlation heatmap over all numeric columns.
        if len(numeric_columns) >= 2:
            chart_data = self._create_correlation_heatmap(df[numeric_columns], numeric_columns)
            if chart_data:
                charts["correlation"] = chart_data

        return charts

    def _create_histogram(self, series: pd.Series, column_name: str) -> Optional[Dict[str, Any]]:
        """
        Draw a 20-bin histogram of a numeric column.

        Returns a dict with ``type``, ``column``, ``image`` (data URI) and
        ``stats``, or ``None`` when drawing fails.
        """
        fig = None
        try:
            fig, ax = plt.subplots(figsize=(11, 7))
            ax.hist(series.dropna(), bins=20, edgecolor='black', alpha=0.7, color='#3b82f6')
            ax.set_xlabel(column_name, fontsize=10, labelpad=10)
            ax.set_ylabel('频数', fontsize=10, labelpad=10)
            ax.set_title(f'{column_name} 分布', fontsize=12, fontweight='bold', pad=15)
            ax.grid(True, alpha=0.3, axis='y')
            ax.tick_params(axis='both', which='major', labelsize=9)

            plt.tight_layout(pad=1.5, w_pad=1.0, h_pad=1.0)

            img_base64 = self._figure_to_base64(fig)

            return {
                "type": "histogram",
                "column": column_name,
                "image": img_base64,
                "stats": {
                    "mean": float(series.mean()),
                    "median": float(series.median()),
                    "std": float(series.std()) if len(series) > 1 else 0
                }
            }
        except Exception as e:
            logger.error(f"创建直方图失败 ({column_name}): {str(e)}")
            return None
        finally:
            # Release the figure even when drawing failed part-way;
            # otherwise it stays registered with pyplot.
            if fig is not None:
                plt.close(fig)

    def _create_bar_chart(self, series: pd.Series, column_name: str) -> Optional[Dict[str, Any]]:
        """
        Draw a bar chart of the 10 most frequent values of a column.

        Returns a dict with ``type``, ``column``, ``image`` (data URI) and
        ``categories``, or ``None`` when the column has no countable values
        or drawing fails.
        """
        fig = None
        try:
            value_counts = series.value_counts().head(10)  # top 10 only
            if value_counts.empty:
                # Nothing to plot (e.g. every value is missing); skip before
                # a figure is created.
                return None

            fig, ax = plt.subplots(figsize=(12, 7))

            # Shorten long category labels.
            labels = [str(x)[:15] + '...' if len(str(x)) > 15 else str(x) for x in value_counts.index]
            x_pos = range(len(value_counts))
            bars = ax.bar(x_pos, value_counts.values, color='#10b981', alpha=0.8, edgecolor='black', linewidth=0.5)

            ax.set_xticks(x_pos)
            ax.set_xticklabels(labels, rotation=30, ha='right', fontsize=8)
            ax.set_xlabel(column_name, fontsize=10, labelpad=10)
            ax.set_ylabel('数量', fontsize=10, labelpad=10)
            ax.set_title(f'{column_name} 分布 (Top 10)', fontsize=12, fontweight='bold', pad=15)
            ax.grid(True, alpha=0.3, axis='y')
            ax.tick_params(axis='both', which='major', labelsize=9)

            # Value labels, placed slightly above each bar.
            max_val = value_counts.values.max()
            y_offset = max_val * 0.02 if max_val > 0 else 0.5
            for bar, value in zip(bars, value_counts.values):
                ax.text(bar.get_x() + bar.get_width() / 2., value + y_offset,
                        f'{int(value)}',
                        ha='center', va='bottom', fontsize=8, fontweight='bold')

            plt.tight_layout(pad=1.5, w_pad=1.0, h_pad=1.0)

            img_base64 = self._figure_to_base64(fig)

            return {
                "type": "bar_chart",
                "column": column_name,
                "image": img_base64,
                "categories": {str(k): int(v) for k, v in value_counts.items()}
            }
        except Exception as e:
            logger.error(f"创建条形图失败 ({column_name}): {str(e)}")
            return None
        finally:
            if fig is not None:
                plt.close(fig)

    def _create_box_plot(self, df: pd.DataFrame, columns: List[str]) -> Optional[Dict[str, Any]]:
        """
        Draw one box plot comparing the given numeric columns.

        Returns a dict with ``type``, ``columns`` and ``image`` (data URI),
        or ``None`` when drawing fails.
        """
        fig = None
        try:
            fig, ax = plt.subplots(figsize=(14, 7))

            box_data = [df[col].dropna() for col in columns]
            bp = ax.boxplot(box_data, labels=columns, patch_artist=True,
                            notch=True, showcaps=True, showfliers=True)

            # Colour the boxes.
            box_colors = ['#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6']
            for patch, color in zip(bp['boxes'], box_colors[:len(bp['boxes'])]):
                patch.set_facecolor(color)
                patch.set_alpha(0.6)
                patch.set_linewidth(1.5)

            # Line width of the remaining elements.
            for element in ['whiskers', 'fliers', 'means', 'medians', 'caps']:
                plt.setp(bp[element], linewidth=1.5)

            ax.set_ylabel('值', fontsize=10, labelpad=10)
            ax.set_title('数值型列分布对比', fontsize=12, fontweight='bold', pad=15)
            ax.grid(True, alpha=0.3, axis='y')

            # Rotate x labels so they do not overlap.
            plt.setp(ax.get_xticklabels(), rotation=30, ha='right', fontsize=9)
            ax.tick_params(axis='both', which='major', labelsize=9)

            plt.tight_layout(pad=1.5, w_pad=1.5, h_pad=1.0)

            img_base64 = self._figure_to_base64(fig)

            return {
                "type": "box_plot",
                "columns": columns,
                "image": img_base64
            }
        except Exception as e:
            logger.error(f"创建箱线图失败: {str(e)}")
            return None
        finally:
            if fig is not None:
                plt.close(fig)

    def _create_correlation_heatmap(self, df: pd.DataFrame, columns: List[str]) -> Optional[Dict[str, Any]]:
        """
        Draw a heatmap of the pairwise correlations of the numeric columns.

        Returns a dict with ``type``, ``columns``, ``image`` (data URI) and
        ``correlation_matrix``, or ``None`` when drawing fails.
        """
        fig = None
        try:
            corr = df.corr()

            fig, ax = plt.subplots(figsize=(11, 9))
            im = ax.imshow(corr, cmap='RdBu_r', aspect='auto', vmin=-1, vmax=1)

            n_cols = len(corr)
            ax.set_xticks(np.arange(n_cols))
            ax.set_yticks(np.arange(n_cols))

            # Shorten long column names.
            x_labels = [str(col)[:10] + '...' if len(str(col)) > 10 else str(col) for col in corr.columns]
            y_labels = [str(col)[:10] + '...' if len(str(col)) > 10 else str(col) for col in corr.columns]

            ax.set_xticklabels(x_labels, rotation=30, ha='right', fontsize=9)
            ax.set_yticklabels(y_labels, fontsize=9)

            # Cell labels: white text on strongly coloured cells, bold for
            # strong correlations.
            for i in range(n_cols):
                for j in range(n_cols):
                    value = corr.iloc[i, j]
                    text_color = 'white' if abs(value) > 0.5 else 'black'
                    ax.text(j, i, f'{value:.2f}',
                            ha="center", va="center", color=text_color,
                            fontsize=8, fontweight='bold' if abs(value) > 0.7 else 'normal')

            ax.set_title('数值型列相关性热力图', fontsize=12, fontweight='bold', pad=15)
            ax.tick_params(axis='both', which='major', labelsize=9)

            # Colour bar.
            cbar = plt.colorbar(im, ax=ax)
            cbar.set_label('相关系数', rotation=270, labelpad=20, fontsize=10)
            cbar.ax.tick_params(labelsize=9)

            plt.tight_layout(pad=2.0, w_pad=1.0, h_pad=1.0)

            img_base64 = self._figure_to_base64(fig)

            return {
                "type": "correlation_heatmap",
                "columns": columns,
                "image": img_base64,
                "correlation_matrix": corr.to_dict()
            }
        except Exception as e:
            logger.error(f"创建相关性热力图失败: {str(e)}")
            return None
        finally:
            if fig is not None:
                plt.close(fig)

    def _generate_distributions(
        self,
        df: pd.DataFrame,
        categorical_columns: List[str]
    ) -> Dict[str, Any]:
        """
        Compute value counts and percentages for at most 5 categorical columns.

        Percentages are relative to the total row count and rounded to two
        decimals. A column that raises is logged and left out.
        """
        distributions = {}

        for col in categorical_columns[:5]:
            try:
                value_counts = df[col].value_counts()
                total = len(df)

                distributions[col] = {
                    "categories": {str(k): int(v) for k, v in value_counts.items()},
                    "percentages": {str(k): round(v / total * 100, 2) for k, v in value_counts.items()},
                    "unique_count": len(value_counts)
                }
            except Exception as e:
                logger.warning(f"列 {col} 分布生成失败: {str(e)}")

        return distributions

    def _figure_to_base64(self, fig) -> str:
        """Render a matplotlib figure to a PNG data URI and close the figure."""
        buf = io.BytesIO()
        fig.savefig(
            buf,
            format='png',
            dpi=120,
            bbox_inches='tight',
            pad_inches=0.3,
            facecolor='white',
            edgecolor='none',
            transparent=False
        )
        plt.close(fig)
        buf.seek(0)
        img_base64 = base64.b64encode(buf.read()).decode('utf-8')
        return f"data:image/png;base64,{img_base64}"
|
||||
|
||||
|
||||
# 全局单例
|
||||
# Module-level VisualizationService instance, created once when this module is imported.
visualization_service = VisualizationService()
|
||||
639
backend/app/services/word_ai_service.py
Normal file
@@ -0,0 +1,639 @@
|
||||
"""
|
||||
Word 文档 AI 解析服务
|
||||
|
||||
使用 LLM (GLM) 对 Word 文档进行深度理解,提取结构化数据
|
||||
"""
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
import json
|
||||
|
||||
from app.services.llm_service import llm_service
|
||||
from app.core.document_parser.docx_parser import DocxParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WordAIService:
|
||||
"""Word 文档 AI 解析服务"""
|
||||
|
||||
def __init__(self):
|
||||
self.llm = llm_service
|
||||
self.parser = DocxParser()
|
||||
|
||||
async def parse_word_with_ai(
    self,
    file_path: str,
    user_hint: str = ""
) -> Dict[str, Any]:
    """Parse a Word document and extract structured data with the LLM.

    The base parser runs first; embedded images (if any) are sent to the
    vision analysis, and the text/table content is then routed to the
    table-oriented or the text-oriented extraction.

    Args:
        file_path: Path of the Word file.
        user_hint: Optional user prompt describing what should be extracted.

    Returns:
        The dict produced by the chosen extraction step (plus
        ``image_analysis`` when images were analysed), or
        ``{"success": False, "error": ..., "structured_data": None}`` when
        the base parse fails or an unexpected exception occurs.
    """
    try:
        # 1. Raw content from the base parser.
        parse_result = self.parser.parse(file_path)

        if not parse_result.success:
            return {
                "success": False,
                "error": parse_result.error,
                "structured_data": None
            }

        # 2. Unpack the raw data; tolerate missing/None sections.
        raw_data = parse_result.data or {}
        paragraphs = raw_data.get("paragraphs") or []
        tables = raw_data.get("tables") or []
        content = raw_data.get("content") or ""
        images_info = raw_data.get("images") or {}
        metadata = parse_result.metadata or {}

        image_count = images_info.get("image_count", 0) or 0
        image_descriptions = images_info.get("descriptions") or []

        logger.info(f"Word 基础解析完成: {len(paragraphs)} 个段落, {len(tables)} 个表格, {image_count} 张图片")

        # 3. Pull image bytes for the vision analysis (best effort).
        images_base64 = []
        if image_count > 0:
            try:
                images_base64 = self.parser.extract_images_as_base64(file_path)
                logger.info(f"提取到 {len(images_base64)} 张图片的 base64 数据")
            except Exception as e:
                logger.warning(f"提取图片 base64 失败: {str(e)}")

        # 4. Analyse images first so the result can accompany the extraction.
        image_analysis = ""
        if images_base64:
            image_analysis = await self._analyze_images_with_ai(images_base64, user_hint)
            logger.info(f"图片 AI 分析完成: {len(image_analysis)} 字符")

        # Priority: tables (with their surrounding text) > plain text > empty.
        if tables:
            structured_data = await self._extract_tables_with_ai(
                tables, paragraphs, image_count, user_hint, metadata, image_analysis
            )
        elif paragraphs:
            structured_data = await self._extract_from_text_with_ai(
                paragraphs, content, image_count, image_descriptions, user_hint, image_analysis
            )
        else:
            structured_data = {
                "success": True,
                "type": "empty",
                "message": "文档内容为空"
            }

        if image_analysis:
            structured_data["image_analysis"] = image_analysis

        return structured_data

    except Exception as e:
        # Keep the traceback in the log; callers only receive the message.
        logger.error(f"AI 解析 Word 文档失败: {str(e)}", exc_info=True)
        return {
            "success": False,
            "error": str(e),
            "structured_data": None
        }
|
||||
|
||||
async def _extract_tables_with_ai(
|
||||
self,
|
||||
tables: List[Dict],
|
||||
paragraphs: List[str],
|
||||
image_count: int,
|
||||
user_hint: str,
|
||||
metadata: Dict,
|
||||
image_analysis: str = ""
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
使用 AI 从 Word 表格和文本中提取结构化数据
|
||||
|
||||
Args:
|
||||
tables: 表格列表
|
||||
paragraphs: 段落列表
|
||||
image_count: 图片数量
|
||||
user_hint: 用户提示
|
||||
metadata: 文档元数据
|
||||
image_analysis: 图片 AI 分析结果
|
||||
|
||||
Returns:
|
||||
结构化数据
|
||||
"""
|
||||
try:
|
||||
# 构建表格文本描述
|
||||
tables_text = self._build_tables_description(tables)
|
||||
|
||||
# 构建段落描述
|
||||
paragraphs_text = "\n".join(paragraphs[:50]) if paragraphs else "(无正文文本)"
|
||||
if len(paragraphs) > 50:
|
||||
paragraphs_text += f"\n...(共 {len(paragraphs)} 个段落,仅显示前50个)"
|
||||
|
||||
# 图片提示
|
||||
image_hint = f"注意:此文档包含 {image_count} 张图片/图表。" if image_count > 0 else ""
|
||||
|
||||
prompt = f"""你是一个专业的数据提取专家。请从以下 Word 文档的完整内容中提取结构化数据。
|
||||
|
||||
【用户需求】
|
||||
{user_hint if user_hint else "请提取文档中的所有结构化数据,包括表格数据、键值对、列表项等。"}
|
||||
|
||||
【文档正文(段落)】
|
||||
{paragraphs_text}
|
||||
|
||||
【文档表格】
|
||||
{tables_text}
|
||||
|
||||
【文档图片信息】
|
||||
{image_hint}
|
||||
|
||||
请按照以下 JSON 格式输出:
|
||||
{{
|
||||
"type": "table_data",
|
||||
"headers": ["列1", "列2", ...],
|
||||
"rows": [["行1列1", "行1列2", ...], ["行2列1", "行2列2", ...], ...],
|
||||
"key_values": {{"键1": "值1", "键2": "值2", ...}},
|
||||
"list_items": ["项1", "项2", ...],
|
||||
"description": "文档内容描述"
|
||||
}}
|
||||
|
||||
重点:
|
||||
- 优先从表格中提取结构化数据
|
||||
- 如果表格中有表头,headers 是表头,rows 是数据行
|
||||
- 如果文档中有键值对(如 名称: 张三),提取到 key_values 中
|
||||
- 如果文档中有列表项,提取到 list_items 中
|
||||
- 图片内容无法直接提取,但请在 description 中说明图片的大致主题(如"包含流程图"、"包含数据图表"等)
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个专业的数据提取助手。请严格按JSON格式输出。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
response = await self.llm.chat(
|
||||
messages=messages,
|
||||
temperature=0.1,
|
||||
max_tokens=50000
|
||||
)
|
||||
|
||||
content = self.llm.extract_message_content(response)
|
||||
|
||||
# 解析 JSON
|
||||
result = self._parse_json_response(content)
|
||||
|
||||
if result:
|
||||
logger.info(f"AI 表格提取成功: {len(result.get('rows', []))} 行数据, key_values={len(result.get('key_values', {}))}, list_items={len(result.get('list_items', []))}")
|
||||
return {
|
||||
"success": True,
|
||||
"type": "table_data",
|
||||
"headers": result.get("headers", []),
|
||||
"rows": result.get("rows", []),
|
||||
"description": result.get("description", ""),
|
||||
"key_values": result.get("key_values", {}),
|
||||
"list_items": result.get("list_items", [])
|
||||
}
|
||||
else:
|
||||
# 如果 AI 返回格式不对,尝试直接解析表格
|
||||
return self._fallback_table_parse(tables)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AI 表格提取失败: {str(e)}")
|
||||
return self._fallback_table_parse(tables)
|
||||
|
||||
async def _extract_from_text_with_ai(
|
||||
self,
|
||||
paragraphs: List[str],
|
||||
full_text: str,
|
||||
image_count: int,
|
||||
image_descriptions: List[str],
|
||||
user_hint: str,
|
||||
image_analysis: str = ""
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
使用 AI 从 Word 纯文本中提取结构化数据
|
||||
|
||||
Args:
|
||||
paragraphs: 段落列表
|
||||
full_text: 完整文本
|
||||
image_count: 图片数量
|
||||
image_descriptions: 图片描述列表
|
||||
user_hint: 用户提示
|
||||
image_analysis: 图片 AI 分析结果
|
||||
|
||||
Returns:
|
||||
结构化数据
|
||||
"""
|
||||
try:
|
||||
# 限制文本长度
|
||||
text_preview = full_text[:8000] if len(full_text) > 8000 else full_text
|
||||
|
||||
# 图片提示
|
||||
image_hint = f"\n【文档图片】此文档包含 {image_count} 张图片/图表。" if image_count > 0 else ""
|
||||
if image_descriptions:
|
||||
image_hint += "\n" + "\n".join(image_descriptions)
|
||||
|
||||
prompt = f"""你是一个专业的数据提取专家。请从以下 Word 文档的完整内容中提取结构化数据。
|
||||
|
||||
【用户需求】
|
||||
{user_hint if user_hint else "请识别并提取文档中的关键信息,包括:表格数据、键值对、列表项等。"}
|
||||
|
||||
【文档正文】{image_hint}
|
||||
{text_preview}
|
||||
|
||||
请按照以下 JSON 格式输出:
|
||||
{{
|
||||
"type": "structured_text",
|
||||
"tables": [{{"headers": [...], "rows": [...]}}],
|
||||
"key_values": {{"键1": "值1", "键2": "值2", ...}},
|
||||
"list_items": ["项1", "项2", ...],
|
||||
"summary": "文档内容摘要"
|
||||
}}
|
||||
|
||||
重点:
|
||||
- 如果文档包含表格数据,提取到 tables 中
|
||||
- 如果文档包含键值对(如 名称: 张三),提取到 key_values 中
|
||||
- 如果文档包含列表项,提取到 list_items 中
|
||||
- 如果文档包含图片,请根据上下文推断图片内容(如"流程图"、"数据折线图"等)并在 description 中说明
|
||||
- 如果无法提取到结构化数据,至少提供一个详细的摘要
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个专业的数据提取助手。请严格按JSON格式输出。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
response = await self.llm.chat(
|
||||
messages=messages,
|
||||
temperature=0.1,
|
||||
max_tokens=50000
|
||||
)
|
||||
|
||||
content = self.llm.extract_message_content(response)
|
||||
|
||||
result = self._parse_json_response(content)
|
||||
|
||||
if result:
|
||||
logger.info(f"AI 文本提取成功: type={result.get('type')}")
|
||||
return {
|
||||
"success": True,
|
||||
"type": result.get("type", "structured_text"),
|
||||
"tables": result.get("tables", []),
|
||||
"key_values": result.get("key_values", {}),
|
||||
"list_items": result.get("list_items", []),
|
||||
"summary": result.get("summary", ""),
|
||||
"raw_text_preview": text_preview[:500]
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": True,
|
||||
"type": "text",
|
||||
"summary": text_preview[:500],
|
||||
"raw_text_preview": text_preview[:500]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AI 文本提取失败: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def _analyze_images_with_ai(
|
||||
self,
|
||||
images: List[Dict[str, str]],
|
||||
user_hint: str = ""
|
||||
) -> str:
|
||||
"""
|
||||
使用视觉模型分析 Word 文档中的图片
|
||||
|
||||
Args:
|
||||
images: 图片列表,每项包含 base64 和 mime_type
|
||||
user_hint: 用户提示
|
||||
|
||||
Returns:
|
||||
图片分析结果文本
|
||||
"""
|
||||
try:
|
||||
# 调用 LLM 的视觉分析功能
|
||||
result = await self.llm.analyze_images(
|
||||
images=images,
|
||||
user_prompt=user_hint or "请详细描述图片内容,提取所有文字和数据信息。"
|
||||
)
|
||||
|
||||
if result.get("success"):
|
||||
analysis = result.get("analysis", {})
|
||||
if isinstance(analysis, dict):
|
||||
description = analysis.get("description", "")
|
||||
text_content = analysis.get("text_content", "")
|
||||
data_extracted = analysis.get("data_extracted", {})
|
||||
|
||||
result_text = f"【图片分析结果】\n{description}"
|
||||
if text_content:
|
||||
result_text += f"\n\n【图片中的文字】\n{text_content}"
|
||||
if data_extracted:
|
||||
result_text += f"\n\n【提取的数据】\n{json.dumps(data_extracted, ensure_ascii=False)}"
|
||||
return result_text
|
||||
else:
|
||||
return str(analysis)
|
||||
else:
|
||||
logger.warning(f"图片 AI 分析失败: {result.get('error')}")
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"图片 AI 分析异常: {str(e)}")
|
||||
return ""
|
||||
|
||||
def _build_tables_description(self, tables: List[Dict]) -> str:
|
||||
"""构建表格的文本描述"""
|
||||
result = []
|
||||
|
||||
for idx, table in enumerate(tables):
|
||||
rows = table.get("rows", [])
|
||||
if not rows:
|
||||
continue
|
||||
|
||||
result.append(f"\n--- 表格 {idx + 1} ---")
|
||||
|
||||
for row_idx, row in enumerate(rows[:50]): # 限制每表格最多50行
|
||||
if isinstance(row, list):
|
||||
result.append(" | ".join(str(cell).strip() for cell in row))
|
||||
elif isinstance(row, dict):
|
||||
result.append(str(row))
|
||||
|
||||
if len(rows) > 50:
|
||||
result.append(f"...(共 {len(rows)} 行,仅显示前50行)")
|
||||
|
||||
return "\n".join(result) if result else "(无表格内容)"
|
||||
|
||||
def _parse_json_response(self, content: str) -> Optional[Dict]:
|
||||
"""解析 JSON 响应,处理各种格式问题"""
|
||||
import re
|
||||
|
||||
# 清理 markdown 标记
|
||||
cleaned = content.strip()
|
||||
cleaned = re.sub(r'^```json\s*', '', cleaned, flags=re.MULTILINE)
|
||||
cleaned = re.sub(r'^```\s*', '', cleaned, flags=re.MULTILINE)
|
||||
cleaned = cleaned.strip()
|
||||
|
||||
# 找到 JSON 开始位置
|
||||
json_start = -1
|
||||
for i, c in enumerate(cleaned):
|
||||
if c == '{':
|
||||
json_start = i
|
||||
break
|
||||
|
||||
if json_start == -1:
|
||||
logger.warning("无法找到 JSON 开始位置")
|
||||
return None
|
||||
|
||||
json_text = cleaned[json_start:]
|
||||
|
||||
# 尝试直接解析
|
||||
try:
|
||||
return json.loads(json_text)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 尝试修复并解析
|
||||
try:
|
||||
# 找到闭合括号
|
||||
depth = 0
|
||||
end_pos = -1
|
||||
for i, c in enumerate(json_text):
|
||||
if c == '{':
|
||||
depth += 1
|
||||
elif c == '}':
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
end_pos = i + 1
|
||||
break
|
||||
|
||||
if end_pos > 0:
|
||||
fixed = json_text[:end_pos]
|
||||
# 移除末尾逗号
|
||||
fixed = re.sub(r',\s*([}]])', r'\1', fixed)
|
||||
return json.loads(fixed)
|
||||
except Exception as e:
|
||||
logger.warning(f"JSON 修复失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _fallback_table_parse(self, tables: List[Dict]) -> Dict[str, Any]:
|
||||
"""当 AI 解析失败时,直接解析表格"""
|
||||
if not tables:
|
||||
return {
|
||||
"success": True,
|
||||
"type": "empty",
|
||||
"data": {},
|
||||
"message": "无表格内容"
|
||||
}
|
||||
|
||||
all_rows = []
|
||||
all_headers = None
|
||||
|
||||
for table in tables:
|
||||
rows = table.get("rows", [])
|
||||
if not rows:
|
||||
continue
|
||||
|
||||
# 查找真正的表头行(跳过标题行)
|
||||
header_row_idx = 0
|
||||
for idx, row in enumerate(rows[:5]): # 只检查前5行
|
||||
if not isinstance(row, list):
|
||||
continue
|
||||
# 如果某一行包含"表"字开头且单元格内容很长,这可能是标题行
|
||||
first_cell = str(row[0]) if row else ""
|
||||
if first_cell.startswith("表") and len(first_cell) > 15:
|
||||
header_row_idx = idx + 1
|
||||
continue
|
||||
# 如果某一行有超过3个空单元格,可能是无效行
|
||||
empty_count = sum(1 for cell in row if not str(cell).strip())
|
||||
if empty_count > 3:
|
||||
header_row_idx = idx + 1
|
||||
continue
|
||||
# 找到第一行看起来像表头的行(短单元格,大部分有内容)
|
||||
avg_len = sum(len(str(c)) for c in row) / len(row) if row else 0
|
||||
if avg_len < 20: # 表头通常比数据行短
|
||||
header_row_idx = idx
|
||||
break
|
||||
|
||||
if header_row_idx >= len(rows):
|
||||
continue
|
||||
|
||||
# 使用找到的表头行
|
||||
if rows and isinstance(rows[header_row_idx], list):
|
||||
headers = rows[header_row_idx]
|
||||
if all_headers is None:
|
||||
all_headers = headers
|
||||
|
||||
# 数据行(从表头之后开始)
|
||||
for row in rows[header_row_idx + 1:]:
|
||||
if isinstance(row, list) and len(row) == len(headers):
|
||||
all_rows.append(row)
|
||||
|
||||
if all_headers and all_rows:
|
||||
return {
|
||||
"success": True,
|
||||
"type": "table_data",
|
||||
"headers": all_headers,
|
||||
"rows": all_rows,
|
||||
"description": "直接从 Word 表格提取"
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"type": "raw",
|
||||
"tables": tables,
|
||||
"message": "表格数据(未AI处理)"
|
||||
}
|
||||
|
||||
async def fill_template_with_ai(
    self,
    file_path: str,
    template_fields: List[Dict[str, Any]],
    user_hint: str = ""
) -> Dict[str, Any]:
    """Parse a Word document with the AI pipeline and fill template fields.

    Main entry point: parses the document, maps the extracted data onto the
    requested fields and returns the fill result.

    Args:
        file_path: Path of the Word file.
        template_fields: Field specs, e.g. ``[{"name": "字段名", "hint": "提示词"}]``.
            Only ``name`` is read here.
        user_hint: Optional user prompt forwarded to the parser.

    Returns:
        On success: ``success``, ``filled_data`` (field name -> list of
        values), ``fill_details`` (per-field values/source/confidence),
        ``ai_parse_result``, ``source_doc_count`` and ``max_rows``.
        On failure: ``success=False`` with ``error``, empty ``filled_data``
        and empty ``fill_details``.
    """
    def _is_missing(value) -> bool:
        # None / "" / empty containers count as missing; 0 and False do not.
        return value is None or (not value and value != 0)

    try:
        # 1. AI parse of the document.
        parse_result = await self.parse_word_with_ai(file_path, user_hint)

        if not parse_result.get("success"):
            return {
                "success": False,
                "error": parse_result.get("error", "解析失败"),
                "filled_data": {},
                "fill_details": [],
                "source": "ai_parse_failed"
            }

        # 2. Map parsed data onto the template fields.
        filled_data = {}
        extract_details = []
        parse_type = parse_result.get("type", "")

        if parse_type == "table_data":
            # Table data: match fields against column headers.
            headers = parse_result.get("headers") or []
            rows = parse_result.get("rows") or []

            for field in template_fields:
                field_name = field.get("name", "")
                values = self._extract_field_from_table(headers, rows, field_name)
                filled_data[field_name] = values
                extract_details.append({
                    "field": field_name,
                    "values": values,
                    "source": "ai_table_extraction",
                    "confidence": 0.9 if values else 0.0
                })

        elif parse_type == "structured_text":
            # Structured text: look the field up in key_values first.
            key_values = parse_result.get("key_values")
            if not isinstance(key_values, dict):
                key_values = {}
            list_items = parse_result.get("list_items") or []

            for field in template_fields:
                field_name = field.get("name", "")
                value = key_values.get(field_name)
                if _is_missing(value) and list_items:
                    # NOTE(review): every unmatched field falls back to the
                    # same first list item (existing behaviour) — confirm
                    # this is intended.
                    value = list_items[0]
                values = [] if _is_missing(value) else [value]
                filled_data[field_name] = values
                extract_details.append({
                    "field": field_name,
                    "values": list(values),
                    "source": "ai_text_extraction",
                    "confidence": 0.7 if values else 0.0
                })

        else:
            # Any other parse type: nothing can be mapped here.
            for field in template_fields:
                field_name = field.get("name", "")
                filled_data[field_name] = []
                extract_details.append({
                    "field": field_name,
                    "values": [],
                    "source": "no_ai_data",
                    "confidence": 0.0
                })

        # 3. Assemble the result.
        max_rows = max((len(v) for v in filled_data.values()), default=1)

        return {
            "success": True,
            "filled_data": filled_data,
            "fill_details": extract_details,
            "ai_parse_result": {
                "type": parse_type,
                # Text extraction reports "summary" instead of "description".
                "description": parse_result.get("description") or parse_result.get("summary", "")
            },
            "source_doc_count": 1,
            "max_rows": max_rows
        }

    except Exception as e:
        logger.error(f"AI 填表失败: {str(e)}")
        return {
            "success": False,
            "error": str(e),
            "filled_data": {},
            "fill_details": []
        }
|
||||
|
||||
def _extract_field_from_table(
|
||||
self,
|
||||
headers: List[str],
|
||||
rows: List[List],
|
||||
field_name: str
|
||||
) -> List[str]:
|
||||
"""从表格中提取指定字段的值"""
|
||||
# 查找匹配的列
|
||||
target_col_idx = None
|
||||
for col_idx, header in enumerate(headers):
|
||||
if field_name.lower() in str(header).lower() or str(header).lower() in field_name.lower():
|
||||
target_col_idx = col_idx
|
||||
break
|
||||
|
||||
if target_col_idx is None:
|
||||
return []
|
||||
|
||||
# 提取该列所有值
|
||||
values = []
|
||||
for row in rows:
|
||||
if isinstance(row, list) and target_col_idx < len(row):
|
||||
val = str(row[target_col_idx]).strip()
|
||||
if val:
|
||||
values.append(val)
|
||||
|
||||
return values
|
||||
|
||||
|
||||
# 全局单例
|
||||
word_ai_service = WordAIService()
|
||||
BIN
backend/image/image-1.png
Normal file
|
After Width: | Height: | Size: 259 KiB |
BIN
backend/image/image-2.png
Normal file
|
After Width: | Height: | Size: 206 KiB |
BIN
backend/image/image.png
Normal file
|
After Width: | Height: | Size: 167 KiB |
@@ -52,6 +52,18 @@ settings.json内容如下:
|
||||
```
|
||||
保存即可
|
||||
|
||||
或者点击python解释器
|
||||

|
||||
|
||||
如果你完成了上述settings.json的配置,可以直接选择第三个使用 xxx 设置中的python xxx
|
||||
否则点击箭头指示的输入解释器路径
|
||||

|
||||
|
||||
找到你项目路径的\venv\Scripts\python.exe
|
||||

|
||||
例如我的:H:\OwnProject\FilesReadSysteam\backend\venv\Scripts\python.exe (记得加上这个.exe)
|
||||
输入进去即可
|
||||
|
||||
## 关于.gitignore
|
||||
为了在上传git仓库时,不把venv中的软件包和其他关于项目的特殊api key暴露,请将.gitignore文件放在项目根目录下,并添加以下内容:
|
||||
```bash
|
||||
@@ -70,7 +82,51 @@ settings.json内容如下:
|
||||
为了数据安全,请不要把api key暴露,请将api key保存在.env文件中,并添加到.gitignore中(正如前文所示),这样git就不会将api key上传到git仓库中。
|
||||
但,可以保留.env.example文件,以示需要调用的api key
|
||||
|
||||
### 预计项目结构:
|
||||
## 关于git账户
|
||||
直接在终端输入以下命令
|
||||
```bash
|
||||
#全局设置
|
||||
git config --global user.name "你的名字"
|
||||
git config --global user.email "你的邮箱@example.com"
|
||||
|
||||
#单个项目设置
|
||||
cd 你的项目路径
|
||||
git config user.name "你的项目专用名字"
|
||||
git config user.email "你的项目专用邮箱@example.com"
|
||||
|
||||
#验证
|
||||
git config --list #查看所有配置
|
||||
|
||||
git config user.name #查看单条
|
||||
git config user.email #同上
|
||||
|
||||
#如果想看全局的,可以加上 --global,例如 git config --global user.name
|
||||
```
|
||||
|
||||
需要更新以下库
|
||||
先进入虚拟环境
|
||||
```bash
|
||||
cd backend
|
||||
.\venv\Scripts\Activate.ps1
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## 启动后端项目
|
||||
在终端输入以下命令:
|
||||
```bash
|
||||
cd backend #确保启动时在后端根目录下
|
||||
./venv/Scripts/python.exe -m uvicorn app.main:app --host 127.0.0.1 --port 8000 --reload #启动后端项目
|
||||
```
|
||||
先启动后端项目,再启动前端项目
|
||||
|
||||
记得在你的.gitignore中添加:
|
||||
```
|
||||
/backend/data/uploads
|
||||
/backend/data/charts
|
||||
```
|
||||
|
||||
## 预计项目结构:
|
||||
|
||||
```bash
|
||||
FilesReadSystem/
|
||||
├── backend/ # 后端服务(Python + FastAPI)
|
||||
|
||||
@@ -1,22 +1,54 @@
|
||||
fastapi[all]==0.104.1
|
||||
# ============================================================
|
||||
# 基于大语言模型的文档理解与多源数据融合系统
|
||||
# Python 依赖清单
|
||||
# ============================================================
|
||||
|
||||
# ==================== Web 框架 ====================
|
||||
fastapi[all]==0.104.1
|
||||
uvicorn[standard]==0.24.0
|
||||
pydantic==2.5.0
|
||||
python-multipart==0.0.6
|
||||
|
||||
# ==================== 数据验证与配置 ====================
|
||||
pydantic==2.5.0
|
||||
pydantic-settings==2.1.0
|
||||
python-dotenv==1.0.0
|
||||
|
||||
# ==================== 数据库 - MySQL (结构化数据) ====================
|
||||
pymysql==1.1.0
|
||||
aiomysql==0.2.0
|
||||
sqlalchemy==2.0.25
|
||||
|
||||
# ==================== 数据库 - MongoDB (非结构化数据) ====================
|
||||
motor==3.3.2
|
||||
pymongo==4.5.0
|
||||
|
||||
# ==================== 数据库 - Redis (缓存/队列) ====================
|
||||
redis==5.0.0
|
||||
|
||||
# ==================== 异步任务 ====================
|
||||
celery==5.3.4
|
||||
sentence-transformers==2.2.2
|
||||
|
||||
# ==================== RAG / 向量数据库 ====================
|
||||
# chromadb==0.4.22 # Windows 需要 C++ 编译环境,如需安装请使用预编译版本或 WSL
|
||||
sentence-transformers==2.7.0
|
||||
faiss-cpu==1.8.0
|
||||
python-docx==0.8.11
|
||||
|
||||
# ==================== 文档解析 ====================
|
||||
pandas==2.1.4
|
||||
openpyxl==3.1.2
|
||||
markdown==3.5.1
|
||||
langchain==0.1.0
|
||||
langchain-community==0.0.10
|
||||
requests==2.31.0
|
||||
python-docx==0.8.11
|
||||
markdown-it-py==3.0.0
|
||||
chardet==5.2.0
|
||||
|
||||
# ==================== AI / LLM ====================
|
||||
httpx==0.25.2
|
||||
python-dotenv==1.0.0
|
||||
|
||||
# ==================== 数据处理与可视化 ====================
|
||||
matplotlib==3.8.2
|
||||
numpy==1.26.2
|
||||
|
||||
# ==================== 工具库 ====================
|
||||
requests==2.31.0
|
||||
loguru==0.7.2
|
||||
tqdm==4.66.1
|
||||
numpy==1.26.2
|
||||
PyYAML==6.0.1
|
||||
PyYAML==6.0.1
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
print("Hello,World")
|
||||
46
backend/test_mongodb.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
MongoDB 数据库连接测试
|
||||
"""
|
||||
import asyncio
|
||||
from app.core.database.mongodb import mongodb
|
||||
|
||||
|
||||
async def test_mongodb():
|
||||
print("=" * 50)
|
||||
print("MongoDB 数据库连接测试")
|
||||
print("=" * 50)
|
||||
|
||||
try:
|
||||
# 连接
|
||||
await mongodb.connect()
|
||||
print(f"✓ MongoDB 连接成功: {mongodb.client}")
|
||||
|
||||
# 测试插入
|
||||
test_doc = {"test": "hello", "value": 123}
|
||||
doc_id = await mongodb.client.test_database.test_collection.insert_one(test_doc)
|
||||
print(f"✓ 写入测试成功, ID: {doc_id.inserted_id}")
|
||||
|
||||
# 测试查询
|
||||
doc = await mongodb.client.test_database.test_collection.find_one({"test": "hello"})
|
||||
print(f"✓ 读取测试成功: {doc}")
|
||||
|
||||
# 删除测试数据
|
||||
await mongodb.client.test_database.test_collection.delete_one({"test": "hello"})
|
||||
print(f"✓ 删除测试数据成功")
|
||||
|
||||
# 列出数据库
|
||||
dbs = await mongodb.client.list_database_names()
|
||||
print(f"✓ 数据库列表: {dbs}")
|
||||
|
||||
print("\n✓ MongoDB 测试通过!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ MongoDB 测试失败: {e}")
|
||||
return False
|
||||
finally:
|
||||
await mongodb.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_mongodb())
|
||||
37
backend/test_mysql.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
MySQL 数据库连接测试
|
||||
"""
|
||||
import asyncio
|
||||
from sqlalchemy import text
|
||||
from app.core.database.mysql import mysql_db
|
||||
|
||||
|
||||
async def test_mysql():
|
||||
print("=" * 50)
|
||||
print("MySQL 数据库连接测试")
|
||||
print("=" * 50)
|
||||
|
||||
try:
|
||||
# 测试连接
|
||||
async with mysql_db.async_session_factory() as session:
|
||||
result = await session.execute(text("SELECT 1"))
|
||||
print(f"✓ MySQL 连接成功: {result.fetchone()}")
|
||||
|
||||
# 测试查询数据库
|
||||
async with mysql_db.async_session_factory() as session:
|
||||
result = await session.execute(text("SHOW DATABASES"))
|
||||
dbs = result.fetchall()
|
||||
print(f"✓ 数据库列表: {[db[0] for db in dbs]}")
|
||||
|
||||
print("\n✓ MySQL 测试通过!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ MySQL 测试失败: {e}")
|
||||
return False
|
||||
finally:
|
||||
await mysql_db.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_mysql())
|
||||
46
backend/test_redis.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
Redis 数据库连接测试
|
||||
"""
|
||||
import asyncio
|
||||
from app.core.database.redis_db import redis_db
|
||||
|
||||
|
||||
async def test_redis():
|
||||
print("=" * 50)
|
||||
print("Redis 数据库连接测试")
|
||||
print("=" * 50)
|
||||
|
||||
try:
|
||||
# 连接
|
||||
await redis_db.connect()
|
||||
print(f"✓ Redis 连接成功")
|
||||
|
||||
# 测试写入
|
||||
await redis_db.client.set("test_key", "hello_redis")
|
||||
print(f"✓ 写入测试成功")
|
||||
|
||||
# 测试读取
|
||||
value = await redis_db.client.get("test_key")
|
||||
print(f"✓ 读取测试成功: {value}")
|
||||
|
||||
# 测试删除
|
||||
await redis_db.client.delete("test_key")
|
||||
print(f"✓ 删除测试成功")
|
||||
|
||||
# 测试任务状态
|
||||
await redis_db.set_task_status("test_task", "processing", {"progress": 50})
|
||||
status = await redis_db.get_task_status("test_task")
|
||||
print(f"✓ 任务状态测试成功: {status}")
|
||||
|
||||
print("\n✓ Redis 测试通过!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ Redis 测试失败: {e}")
|
||||
return False
|
||||
finally:
|
||||
await redis_db.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_redis())
|
||||
BIN
docs/test/2025山东省环境空气质量监测数据信息.xlsx
Normal file
BIN
docs/test/2025年国考职位表(节选).xlsx
Normal file
BIN
docs/test/COVID-19全球数据集(节选).xlsx
Normal file
BIN
docs/test/电商销售数据.xlsx
Normal file
BIN
docs/test/糖尿病患者数据.xlsx
Normal file
7
frontend/.env
Normal file
@@ -0,0 +1,7 @@
|
||||
VITE_APP_ID=app-a6ww9j3ja3nl
|
||||
|
||||
VITE_SUPABASE_URL=https://ojtxpvjgqoybhmadimym.supabase.co
|
||||
|
||||
VITE_SUPABASE_ANON_KEY=sb_publishable_VMZMg44D-9bKE6bsbUiSsw_x3rUJbu2
|
||||
|
||||
VITE_BACKEND_API_URL=http://localhost:8000/api/v1
|
||||
7
frontend/.env.example
Normal file
@@ -0,0 +1,7 @@
|
||||
VITE_APP_ID=
|
||||
|
||||
VITE_SUPABASE_URL=
|
||||
|
||||
VITE_SUPABASE_ANON_KEY=
|
||||
|
||||
VITE_BACKEND_API_URL=http://localhost:8000/api/v1
|
||||
29
frontend/.gitignore
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
# Logs
|
||||
logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
pnpm-debug.log*
|
||||
lerna-debug.log*
|
||||
|
||||
node_modules
|
||||
dist
|
||||
dist-ssr
|
||||
output
|
||||
*.local
|
||||
package-lock.json
|
||||
|
||||
# Editor directories and files
|
||||
.vscode/*
|
||||
!.vscode/extensions.json
|
||||
.idea
|
||||
.DS_Store
|
||||
*.suo
|
||||
*.ntvs*
|
||||
*.njsproj
|
||||
*.sln
|
||||
*.sw?
|
||||
.sync
|
||||
history/*.json
|
||||
.vite_cache
|
||||
28
frontend/.rules/SelectItem.yml
Normal file
@@ -0,0 +1,28 @@
|
||||
id: selectItemWithEmptyValue
|
||||
language: Tsx
|
||||
files:
|
||||
- src/**/*.tsx
|
||||
rule:
|
||||
kind: jsx_opening_element
|
||||
all:
|
||||
- has:
|
||||
kind: identifier
|
||||
regex: '^SelectItem$'
|
||||
- has:
|
||||
kind: jsx_attribute
|
||||
all:
|
||||
- has:
|
||||
kind: property_identifier
|
||||
regex: '^value$'
|
||||
- any:
|
||||
- has:
|
||||
kind: string
|
||||
regex: '^""$'
|
||||
- has:
|
||||
kind: jsx_expression
|
||||
has:
|
||||
kind: string
|
||||
regex: '^""$'
|
||||
|
||||
message: "检测到 SelectItem 组件使用空字符串 value: $MATCH, 这是错误用法, 运行时会报错, 请修改, 如果想实现全选,建议使用all代替空字符串"
|
||||
severity: error
|
||||
39
frontend/.rules/check.sh
Normal file
@@ -0,0 +1,39 @@
|
||||
#!/bin/bash
|
||||
|
||||
ast-grep scan -r .rules/SelectItem.yml
|
||||
|
||||
ast-grep scan -r .rules/contrast.yml
|
||||
|
||||
ast-grep scan -r .rules/supabase-google-sso.yml
|
||||
|
||||
ast-grep scan -r .rules/toast-hook.yml
|
||||
|
||||
ast-grep scan -r .rules/slot-nesting.yml
|
||||
|
||||
ast-grep scan -r .rules/require-button-interaction.yml
|
||||
|
||||
useauth_output=$(ast-grep scan -r .rules/useAuth.yml 2>/dev/null)
|
||||
|
||||
if [ -z "$useauth_output" ]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
authprovider_output=$(ast-grep scan -r .rules/authProvider.yml 2>/dev/null)
|
||||
|
||||
if [ -n "$authprovider_output" ]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "=== ast-grep scan -r .rules/useAuth.yml output ==="
|
||||
echo "$useauth_output"
|
||||
echo ""
|
||||
echo "=== ast-grep scan -r .rules/authProvider.yml output ==="
|
||||
echo "$authprovider_output"
|
||||
echo ""
|
||||
echo "⚠️ Issue detected:"
|
||||
echo "The code uses useAuth Hook but does not have AuthProvider component wrapping the components."
|
||||
echo "Please ensure that components using useAuth are wrapped with AuthProvider to provide proper authentication context."
|
||||
echo ""
|
||||
echo "Suggested fixes:"
|
||||
echo "1. Add AuthProvider wrapper in app.tsx or corresponding root component"
|
||||
echo "2. Ensure all components using useAuth are within AuthProvider scope"
|
||||
103
frontend/.rules/contrast.yml
Normal file
@@ -0,0 +1,103 @@
|
||||
id: button-outline-text-foreground-contrast
|
||||
language: tsx
|
||||
files:
|
||||
- src/**/*.tsx
|
||||
message: "Outline button with text-foreground class causes invisible text. The outline variant has a transparent background, making text-foreground color blend with the background and become unreadable. Use text-primary or another contrasting color instead."
|
||||
rule:
|
||||
kind: jsx_element
|
||||
has:
|
||||
kind: jsx_opening_element
|
||||
all:
|
||||
- has:
|
||||
field: name
|
||||
regex: "^Button$"
|
||||
- has:
|
||||
kind: jsx_attribute
|
||||
all:
|
||||
- has:
|
||||
kind: property_identifier
|
||||
regex: "^variant$"
|
||||
- has:
|
||||
kind: string
|
||||
has:
|
||||
kind: string_fragment
|
||||
regex: "^outline$"
|
||||
- has:
|
||||
kind: jsx_attribute
|
||||
all:
|
||||
- has:
|
||||
kind: property_identifier
|
||||
regex: "^className$"
|
||||
- has:
|
||||
kind: string
|
||||
has:
|
||||
kind: string_fragment
|
||||
regex: "(^|\\s)text-foreground(\\s|$)"
|
||||
---
|
||||
id: button-default-text-primary-contrast
|
||||
language: tsx
|
||||
files:
|
||||
- src/**/*.tsx
|
||||
message: "Default button with text-primary class causes poor contrast. The default variant has a primary-colored background, making text-primary color blend with the background and become hard to read. Remove the text-primary class or specify a different variant like 'outline' or 'ghost'."
|
||||
rule:
|
||||
kind: jsx_element
|
||||
has:
|
||||
kind: jsx_opening_element
|
||||
all:
|
||||
- has:
|
||||
field: name
|
||||
regex: "^Button$"
|
||||
- has:
|
||||
kind: jsx_attribute
|
||||
all:
|
||||
- has:
|
||||
kind: property_identifier
|
||||
regex: "^className$"
|
||||
- has:
|
||||
kind: string
|
||||
has:
|
||||
kind: string_fragment
|
||||
regex: "(^|\\s)text-primary(\\s|$)"
|
||||
- not:
|
||||
has:
|
||||
kind: jsx_attribute
|
||||
has:
|
||||
kind: property_identifier
|
||||
regex: "^variant$"
|
||||
|
||||
---
|
||||
id: button-outline-white-gray-contrast
|
||||
language: tsx
|
||||
files:
|
||||
- src/**/*.tsx
|
||||
message: "Outline button with white/gray text color has poor contrast. Remove the text color class and use the default button text color."
|
||||
rule:
|
||||
kind: jsx_element
|
||||
has:
|
||||
kind: jsx_opening_element
|
||||
all:
|
||||
- has:
|
||||
field: name
|
||||
regex: "^Button$"
|
||||
- has:
|
||||
kind: jsx_attribute
|
||||
all:
|
||||
- has:
|
||||
kind: property_identifier
|
||||
regex: "^variant$"
|
||||
- has:
|
||||
kind: string
|
||||
has:
|
||||
kind: string_fragment
|
||||
regex: "^outline$"
|
||||
- has:
|
||||
kind: jsx_attribute
|
||||
all:
|
||||
- has:
|
||||
kind: property_identifier
|
||||
regex: "^className$"
|
||||
- has:
|
||||
kind: string
|
||||
has:
|
||||
kind: string_fragment
|
||||
regex: "(^|\\s)text-(white|gray)(-[0-9]+)?(\\s|$)"
|
||||
56
frontend/.rules/require-button-interaction.yml
Normal file
@@ -0,0 +1,56 @@
|
||||
id: require-button-interaction
|
||||
language: Tsx
|
||||
files:
|
||||
- src/**/*.tsx
|
||||
- src/**/*.jsx
|
||||
rule:
|
||||
kind: jsx_opening_element
|
||||
all:
|
||||
# 必须是 <Button> 组件
|
||||
- has:
|
||||
kind: identifier
|
||||
regex: '^Button$'
|
||||
# 没有 onClick
|
||||
- not:
|
||||
has:
|
||||
kind: jsx_attribute
|
||||
has:
|
||||
kind: property_identifier
|
||||
regex: '^onClick$'
|
||||
# 没有 asChild
|
||||
- not:
|
||||
has:
|
||||
kind: jsx_attribute
|
||||
has:
|
||||
kind: property_identifier
|
||||
regex: '^asChild$'
|
||||
# 没有 type="submit" 或 type="reset"
|
||||
- not:
|
||||
has:
|
||||
kind: jsx_attribute
|
||||
all:
|
||||
- has:
|
||||
kind: property_identifier
|
||||
regex: '^type$'
|
||||
- any:
|
||||
- has:
|
||||
kind: string
|
||||
regex: "^[\"'](submit|reset)[\"']$"
|
||||
- has:
|
||||
kind: jsx_expression
|
||||
|
||||
has:
|
||||
kind: string
|
||||
|
||||
regex: "^[\"'](submit|reset)[\"']$"
|
||||
# 不在 *Trigger 组件内部(如 DialogTrigger、SheetTrigger)
|
||||
- not:
|
||||
inside:
|
||||
stopBy: end
|
||||
kind: jsx_element
|
||||
has:
|
||||
kind: jsx_opening_element
|
||||
has:
|
||||
kind: identifier
|
||||
regex: 'Trigger$'
|
||||
|
||||
message: '<Button> 必须是可点击的:请添加 onClick、type="submit"、type="reset"、asChild 属性,或将其包裹在 *Trigger 组件中'
|
||||
severity: error
|
||||
52
frontend/.rules/slot-nesting.yml
Normal file
@@ -0,0 +1,52 @@
|
||||
---
|
||||
id: radix-trigger-formcontrol-nesting
|
||||
language: tsx
|
||||
files:
|
||||
- src/**/*.tsx
|
||||
message: |
|
||||
❌ 检测到危险的 Slot 嵌套:Radix UI Trigger (asChild) 内包裹 FormControl
|
||||
|
||||
问题代码:
|
||||
<PopoverTrigger asChild>
|
||||
<FormControl> ← 会导致点击事件失效
|
||||
<Button>...</Button>
|
||||
</FormControl>
|
||||
</PopoverTrigger>
|
||||
|
||||
正确写法:
|
||||
<PopoverTrigger asChild>
|
||||
<Button>...</Button> ← 直接使用 Button
|
||||
</PopoverTrigger>
|
||||
|
||||
原因:FormControl 和 Trigger 都使用 Radix UI 的 Slot 机制,双层嵌套会导致:
|
||||
- ref 传递链断裂
|
||||
- 点击事件丢失
|
||||
- 内部组件无法交互
|
||||
|
||||
FormControl 只应该用于原生表单控件(Input, Textarea, Select),不要用于触发器按钮。
|
||||
severity: error
|
||||
rule:
|
||||
kind: jsx_element
|
||||
all:
|
||||
# 开始标签需要满足的条件
|
||||
- has:
|
||||
kind: jsx_opening_element
|
||||
all:
|
||||
# 匹配所有 Radix UI 的 Trigger 组件
|
||||
- has:
|
||||
field: name
|
||||
regex: "^(Popover|Dialog|DropdownMenu|AlertDialog|HoverCard|Menubar|NavigationMenu|ContextMenu|Tooltip)Trigger$"
|
||||
# 必须有 asChild 属性
|
||||
- has:
|
||||
kind: jsx_attribute
|
||||
has:
|
||||
kind: property_identifier
|
||||
regex: "^asChild$"
|
||||
# 直接子元素包含 FormControl
|
||||
- has:
|
||||
kind: jsx_element
|
||||
has:
|
||||
kind: jsx_opening_element
|
||||
has:
|
||||
field: name
|
||||
regex: "^FormControl$"
|
||||
20
frontend/.rules/supabase-google-sso.yml
Normal file
@@ -0,0 +1,20 @@
|
||||
id: supabase-google-sso
|
||||
language: Tsx
|
||||
files:
|
||||
- src/**/*.tsx
|
||||
rule:
|
||||
pattern: |
|
||||
$AUTH.signInWithOAuth({ provider: 'google', $$$ })
|
||||
message: |
|
||||
Replace `signInWithOAuth` with `signInWithSSO` for Google authentication (Supabase).
|
||||
|
||||
Refactor to:
|
||||
```typescript
|
||||
const { data, error } = await supabase.auth.signInWithSSO({
|
||||
domain: 'miaoda-gg.com',
|
||||
options: { redirectTo: window.location.origin },
|
||||
});
|
||||
if (data?.url) window.open(data.url, '_self');
|
||||
```
|
||||
Ensure `window.open` uses `_self` target.
|
||||
severity: warning
|
||||
10
frontend/.rules/testBuild.sh
Normal file
@@ -0,0 +1,10 @@
|
||||
#!/bin/bash
|
||||
|
||||
OUTPUT=$(npx vite build --minify false --logLevel error --outDir /workspace/.dist 2>&1)
|
||||
EXIT_CODE=$?
|
||||
|
||||
if [ $EXIT_CODE -ne 0 ]; then
|
||||
echo "$OUTPUT"
|
||||
fi
|
||||
|
||||
exit $EXIT_CODE
|
||||
11
frontend/.rules/toast-hook.yml
Normal file
@@ -0,0 +1,11 @@
|
||||
id: use-toast-import
|
||||
message: Use 'import { toast } from "sonner"' instead of "@/hooks/use-toast"
|
||||
severity: error
|
||||
language: Tsx
|
||||
note: |
|
||||
The new shadcn/ui pattern uses sonner for toast notifications.
|
||||
Replace: import { toast } from "@/hooks/use-toast"
|
||||
With: import { toast } from "sonner"
|
||||
|
||||
rule:
|
||||
pattern: import { $$$IMPORTS } from "@/hooks/use-toast"
|
||||
67
frontend/README.md
Normal file
@@ -0,0 +1,67 @@
|
||||
|
||||
|
||||
## 介绍
|
||||
|
||||
项目介绍
|
||||
|
||||
## 目录结构
|
||||
|
||||
```
|
||||
├── README.md # 说明文档
|
||||
├── components.json # 组件库配置
|
||||
├── index.html # 入口文件
|
||||
├── package.json # 包管理
|
||||
├── postcss.config.js # postcss 配置
|
||||
├── public # 静态资源目录
|
||||
│ ├── favicon.png # 图标
|
||||
│ └── images # 图片资源
|
||||
├── src # 源码目录
|
||||
│ ├── App.tsx # 应用根组件
|
||||
│ ├── components # 组件目录
|
||||
│ ├── contexts # 上下文目录
|
||||
│ ├── db # 数据库配置目录
|
||||
│ ├── hooks # 通用钩子函数目录
|
||||
│ ├── index.css # 全局样式
|
||||
│ ├── layout # 布局目录
|
||||
│ ├── lib # 工具库目录
|
||||
│ ├── main.tsx # 入口文件
|
||||
│ ├── routes.tsx # 路由配置
|
||||
│ ├── pages # 页面目录
|
||||
│ ├── services # 数据库交互目录
|
||||
│ └── types # 类型定义目录
|
||||
├── tsconfig.app.json # ts 前端配置文件
|
||||
├── tsconfig.json # ts 配置文件
|
||||
├── tsconfig.node.json # ts node端配置文件
|
||||
└── vite.config.ts # vite 配置文件
|
||||
```
|
||||
|
||||
## 技术栈
|
||||
|
||||
Vite、TypeScript、React、Supabase
|
||||
|
||||
## 本地开发
|
||||
|
||||
首先进行包安装:
|
||||
```bash
|
||||
cd frontend #进入前端目录
|
||||
npm install #确定目录中有package.json文件后输入命令安装依赖包
|
||||
```
|
||||
|
||||
## 启动项目
|
||||
启动项目:
|
||||
```bash
|
||||
npm run dev #启动项目,需要确保后端已启动,否则前端功能无法使用
|
||||
```
|
||||
启动后在终端ctrl+左键点击项目地址打开浏览器,一般是http://localhost:5173
|
||||
|
||||
|
||||
记得在你根目录下的.gitignore文件中添加:
|
||||
|
||||
```gitignore
|
||||
/frontend/node_modules/
|
||||
/frontend/dist/
|
||||
/frontend/build/
|
||||
/frontend/.vscode/
|
||||
/frontend/.idea/
|
||||
/frontend/*.log
|
||||
```
|
||||
37
frontend/TODO.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# Task: 基于大语言模型的文档理解与多源数据融合系统
|
||||
|
||||
## Plan
|
||||
- [x] 数据库初始化与权限配置 (Supabase)
|
||||
- [x] 创建 `profiles` 表及触发器 (登录同步)
|
||||
- [x] 创建 `documents` 表 (存储上传的文档信息)
|
||||
- [x] 创建 `extracted_entities` 表 (存储从文档提取的结构化数据)
|
||||
- [x] 创建 `templates` 表 (存储表格模板)
|
||||
- [x] 创建 `fill_tasks` 表 (存储填写任务)
|
||||
- [x] 配置 RLS 策略 (Row Level Security)
|
||||
- [x] 创建 Storage 存储桶 `document_storage` (存储文档和模板)
|
||||
- [x] 基础架构与登录模块
|
||||
- [x] 配置路由 `@/routes.tsx`
|
||||
- [x] 创建登录/注册页面
|
||||
- [x] 实现 `AuthContext` 与 `RouteGuard` (登录状态管理)
|
||||
- [x] 创建系统主布局 `MainLayout` (含侧边栏导航)
|
||||
- [x] 文档上传与智能提取功能
|
||||
- [x] 实现文档上传组件 (支持 docx, md, xlsx, txt)
|
||||
- [x] 部署 Edge Function `process-document` (调用 MiniMax 处理文档提取)
|
||||
- [x] 实现文档列表与详情页 (显示提取的结构化数据)
|
||||
- [x] 表格模板与自动填写模块
|
||||
- [x] 实现模板上传与管理
|
||||
- [x] 部署 Edge Function `fill-template` (基于提取数据填充表格)
|
||||
- [x] 实现任务监控与结果下载
|
||||
- [x] 智能对话交互模块
|
||||
- [x] 实现智能助手聊天界面 (侧边栏或独立页面)
|
||||
- [x] 部署 Edge Function `chat-assistant` (解析自然语言指令执行操作)
|
||||
- [x] 系统优化与美化
|
||||
- [x] 全面应用科技蓝办公风格 (index.css, tailwind.config.js)
|
||||
- [x] 响应式适配 (移动端兼容)
|
||||
- [x] 完善错误处理与加载状态 (Skeleton, Toast)
|
||||
|
||||
## Notes
|
||||
- 所有 Edge Functions 已部署并集成 MiniMax API
|
||||
- 文档解析使用 mammoth (docx), xlsx (excel), 原生 TextDecoder (txt/md)
|
||||
- 系统采用科技蓝主题,支持暗色模式
|
||||
- 所有代码已通过 lint 检查
|
||||
24
frontend/biome.json
Normal file
@@ -0,0 +1,24 @@
|
||||
{
|
||||
"$schema": "./node_modules/@biomejs/biome/configuration_schema.json",
|
||||
"files": {
|
||||
"includes": ["src/**/*.{js,jsx,ts,tsx}"]
|
||||
},
|
||||
"linter": {
|
||||
"enabled": true,
|
||||
"rules": {
|
||||
"recommended": false,
|
||||
"correctness": {
|
||||
"noUndeclaredDependencies": "error"
|
||||
},
|
||||
"suspicious": {
|
||||
"noRedeclare": "error"
|
||||
},
|
||||
"style": {
|
||||
"noCommonJs": "error"
|
||||
}
|
||||
}
|
||||
},
|
||||
"formatter": {
|
||||
"enabled": false
|
||||
}
|
||||
}
|
||||
21
frontend/components.json
Normal file
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"$schema": "https://ui.shadcn.com/schema.json",
|
||||
"style": "new-york",
|
||||
"rsc": false,
|
||||
"tsx": true,
|
||||
"tailwind": {
|
||||
"config": "tailwind.config.js",
|
||||
"css": "src/index.css",
|
||||
"baseColor": "slate",
|
||||
"cssVariables": true,
|
||||
"prefix": ""
|
||||
},
|
||||
"iconLibrary": "lucide",
|
||||
"aliases": {
|
||||
"components": "@/components",
|
||||
"utils": "@/lib/utils",
|
||||
"ui": "@/components/ui",
|
||||
"lib": "@/lib",
|
||||
"hooks": "@/hooks"
|
||||
}
|
||||
}
|
||||
95
frontend/docs/prd.md
Normal file
@@ -0,0 +1,95 @@
|
||||
# 基于大语言模型的文档理解与多源数据融合系统需求文档
|
||||
|
||||
## 1. 应用概述
|
||||
|
||||
### 1.1 应用名称
|
||||
基于大语言模型的文档理解与多源数据融合系统
|
||||
|
||||
### 1.2 应用描述
|
||||
本系统旨在解决企事业单位在日常办公中面临的文本信息处理效率低下问题,通过引入人工智能技术实现文档的智能理解、信息自动提取、结构化存储以及智能表格填写,帮助用户从繁琐的重复性劳动中解放出来,提升整体工作效率。
|
||||
|
||||
## 2. 核心功能
|
||||
|
||||
### 2.1 文档智能操作交互模块
|
||||
- 支持用户通过自然语言指令对文档进行操作
|
||||
- 自动解析用户指令并执行相应的文档编辑、排版、格式调整、内容提取等操作
|
||||
- 基于自然语言处理与文档结构理解技术实现人机交互
|
||||
|
||||
### 2.2 非结构化文档信息提取模块
|
||||
- 支持用户导入各类非结构化文档(包括但不限于docx、md、xlsx、txt格式)
|
||||
- 自动识别并提取文档中的关键信息、实体数据或用户指定内容
|
||||
- 将提取的信息进行数据库存储
|
||||
- 确保信息提取的准确性和入库的规范性
|
||||
- 支持桌面端、Web网站或第三方平台部署
|
||||
|
||||
### 2.3 表格自定义数据填写模块
|
||||
- 支持用户提供表格模板(word或excel格式)
|
||||
- 从用户提供的非结构化数据中自动搜索相关信息
|
||||
- 将搜索到的信息自动填写到表格中
|
||||
- 生成具备直接业务应用价值的、格式严谨的汇总表格
|
||||
|
||||
## 3. 技术要求
|
||||
|
||||
### 3.1 系统架构
|
||||
- 可基于开源或第三方商业AI平台构建
|
||||
- 也可采用自研创新算法
|
||||
- 系统可运行在H5小程序、原生App、Web网站、PC端软件等平台上
|
||||
|
||||
### 3.2 性能指标
|
||||
- 信息提取准确率需高于80%
|
||||
- 每个文档的响应时间至多为90秒
|
||||
- 支持异步调用的API接口
|
||||
|
||||
### 3.3 数据处理能力
|
||||
- 能够准确识别多种数据类型并在不同数据类型间稳定运行
|
||||
- 支持比赛方提供的测试文档样本集(包括5个docx文档、3个md文档、5个xlsx文档、3个txt文档)
|
||||
|
||||
## 4. 交互流程
|
||||
|
||||
### 4.1 文档上传与处理流程
|
||||
- 用户上传多个文档文件(支持docx、md、xlsx、txt格式)
|
||||
- 系统自动识别并提取文档中的关键信息
|
||||
- 将提取的信息进行数据库存储
|
||||
|
||||
### 4.2 表格填写流程
|
||||
- 用户上传表格模板文件(word或excel格式)
|
||||
- 系统从已存储的非结构化数据中自动搜索相关信息
|
||||
- 将相关信息自动填写到表格中
|
||||
- 完成填写后返回或展示结果表格
|
||||
|
||||
### 4.3 智能交互流程
|
||||
- 用户通过自然语言输入操作指令
|
||||
- 系统解析指令并识别用户需求
|
||||
- 执行相应的文档操作并反馈结果
|
||||
|
||||
## 5. 参考信息
|
||||
|
||||
### 5.1 测试文档样本集
|
||||
- 5个不小于500KB的docx格式文档
|
||||
- 3个不小于15KB的md格式文档
|
||||
- 5个不小于500KB的xlsx格式文档
|
||||
- 3个不小于15KB的txt文档
|
||||
|
||||
### 5.2 评分标准
|
||||
- 信息填写准确率(平均准确率)
|
||||
- 响应时间(平均响应时间)
|
||||
- 准确率差距2%以上时,准确率越高系统越好
|
||||
- 准确率差距小于2%时,结合响应时间综合评价
|
||||
|
||||
## 6. 其他说明
|
||||
|
||||
### 6.1 开发工具
|
||||
- 开发工具及平台不限
|
||||
- 可借助开源工具
|
||||
- 数据与功能API需提供技术说明
|
||||
|
||||
### 6.2 提交材料
|
||||
- 项目概要介绍
|
||||
- 项目简介PPT
|
||||
- 项目详细方案
|
||||
- 项目演示视频
|
||||
- 企业要求提交的材料:
|
||||
- 训练素材详细的素材介绍与来源说明
|
||||
- 关键模块的概要设计和创新要点说明文档
|
||||
- 可运行的Demo实现程序
|
||||
- 团队自愿提交的其他补充材料
|
||||
12
frontend/index.html
Normal file
@@ -0,0 +1,12 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<link rel="icon" type="image/png" href="/favicon.png" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/src/main.tsx"></script>
|
||||
</body>
|
||||
</html>
|
||||
96
frontend/package.json
Normal file
@@ -0,0 +1,96 @@
|
||||
{
|
||||
"name": "miaoda-react-admin",
|
||||
"version": "0.0.1",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "vite build",
|
||||
"lint": "tsgo -p tsconfig.check.json; npx biome lint; .rules/check.sh;npx tailwindcss -i ./src/index.css -o /dev/null 2>&1 | grep -E '^(CssSyntaxError|Error):.*' || true;.rules/testBuild.sh"
|
||||
},
|
||||
"dependencies": {
|
||||
"@hookform/resolvers": "^5.2.2",
|
||||
"@radix-ui/react-accordion": "^1.2.12",
|
||||
"@radix-ui/react-alert-dialog": "^1.1.15",
|
||||
"@radix-ui/react-aspect-ratio": "^1.1.7",
|
||||
"@radix-ui/react-avatar": "^1.1.10",
|
||||
"@radix-ui/react-checkbox": "^1.3.3",
|
||||
"@radix-ui/react-collapsible": "^1.1.12",
|
||||
"@radix-ui/react-context-menu": "^2.2.16",
|
||||
"@radix-ui/react-dialog": "^1.1.15",
|
||||
"@radix-ui/react-dropdown-menu": "^2.1.16",
|
||||
"@radix-ui/react-hover-card": "^1.1.15",
|
||||
"@radix-ui/react-icons": "^1.3.2",
|
||||
"@radix-ui/react-label": "^2.1.7",
|
||||
"@radix-ui/react-menubar": "^1.1.16",
|
||||
"@radix-ui/react-navigation-menu": "^1.2.14",
|
||||
"@radix-ui/react-popover": "^1.1.15",
|
||||
"@radix-ui/react-progress": "^1.1.7",
|
||||
"@radix-ui/react-radio-group": "^1.3.8",
|
||||
"@radix-ui/react-scroll-area": "^1.2.10",
|
||||
"@radix-ui/react-select": "^2.2.6",
|
||||
"@radix-ui/react-separator": "^1.1.7",
|
||||
"@radix-ui/react-slider": "^1.3.6",
|
||||
"@radix-ui/react-slot": "^1.2.3",
|
||||
"@radix-ui/react-switch": "^1.2.6",
|
||||
"@radix-ui/react-tabs": "^1.1.13",
|
||||
"@radix-ui/react-toast": "^1.2.15",
|
||||
"@radix-ui/react-toggle": "^1.1.10",
|
||||
"@radix-ui/react-toggle-group": "^1.1.11",
|
||||
"@radix-ui/react-tooltip": "^1.2.8",
|
||||
"@supabase/supabase-js": "^2.98.0",
|
||||
"axios": "^1.13.1",
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
"cmdk": "^1.1.1",
|
||||
"date-fns": "^3.6.0",
|
||||
"embla-carousel-react": "^8.6.0",
|
||||
"eventsource-parser": "^3.0.6",
|
||||
"framer-motion": "^12.35.2",
|
||||
"input-otp": "^1.4.2",
|
||||
"ky": "^1.13.0",
|
||||
"lucide-react": "^0.576.0",
|
||||
"mammoth": "^1.11.0",
|
||||
"miaoda-auth-react": "2.0.6",
|
||||
"miaoda-sc-plugin": "1.0.56",
|
||||
"motion": "^12.23.25",
|
||||
"next-themes": "^0.4.6",
|
||||
"qrcode": "^1.5.4",
|
||||
"react": "^18.0.0",
|
||||
"react-day-picker": "^9.13.0",
|
||||
"react-dom": "^18.0.0",
|
||||
"react-dropzone": "^14.3.8",
|
||||
"react-helmet-async": "^2.0.5",
|
||||
"react-hook-form": "^7.66.0",
|
||||
"react-markdown": "^10.1.0",
|
||||
"react-resizable-panels": "^2.1.8",
|
||||
"react-router": "^7.9.5",
|
||||
"react-router-dom": "^7.9.5",
|
||||
"recharts": "2.15.4",
|
||||
"remark-gfm": "^4.0.1",
|
||||
"sonner": "^2.0.7",
|
||||
"streamdown": "^1.4.0",
|
||||
"tailwind-merge": "^3.3.1",
|
||||
"tailwindcss-animate": "^1.0.7",
|
||||
"tailwindcss-intersect": "^2.2.0",
|
||||
"vaul": "^1.1.2",
|
||||
"video-react": "^0.16.0",
|
||||
"xlsx": "^0.18.5",
|
||||
"zod": "^3.25.76"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@biomejs/biome": "2.4.5",
|
||||
"@tailwindcss/container-queries": "^0.1.1",
|
||||
"@types/lodash": "^4.17.24",
|
||||
"@types/react": "^19.2.2",
|
||||
"@types/react-dom": "^19.2.2",
|
||||
"@types/video-react": "^0.15.8",
|
||||
"@typescript/native-preview": "7.0.0-dev.20260303.1",
|
||||
"@vitejs/plugin-react": "^5.1.4",
|
||||
"autoprefixer": "^10.4.27",
|
||||
"postcss": "^8.5.6",
|
||||
"tailwindcss": "^3.4.11",
|
||||
"typescript": "~5.9.3",
|
||||
"vite": "npm:rolldown-vite@latest",
|
||||
"vite-plugin-svgr": "^4.5.0"
|
||||
}
|
||||
}
|
||||
7665
frontend/pnpm-lock.yaml
generated
Normal file
6
frontend/pnpm-workspace.yaml
Normal file
@@ -0,0 +1,6 @@
|
||||
catalog:
|
||||
'@react-three/drei': 9.122.0
|
||||
'@react-three/fiber': 8.18.0
|
||||
three: 0.180.0
|
||||
|
||||
lockfileIncludeTarballUrl: false
|
||||
6
frontend/postcss.config.js
Normal file
@@ -0,0 +1,6 @@
|
||||
export default {
|
||||
plugins: {
|
||||
tailwindcss: {},
|
||||
autoprefixer: {},
|
||||
},
|
||||
};
|
||||
BIN
frontend/public/favicon.png
Normal file
|
After Width: | Height: | Size: 5.4 KiB |
20
frontend/public/images/error/404-dark.svg
Normal file
@@ -0,0 +1,20 @@
|
||||
<svg width="472" height="158" viewBox="0 0 472 158" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect x="203.103" y="41.7015" width="22.1453" height="20.7141" rx="2.63433" fill="#7592FF" stroke="#7592FF" stroke-width="0.752667"/>
|
||||
<rect x="246.752" y="41.7015" width="22.1453" height="20.7141" rx="2.63433" fill="#7592FF" stroke="#7592FF" stroke-width="0.752667"/>
|
||||
<rect x="258.201" y="98.2308" width="22.1453" height="20.7141" rx="2.63433" fill="#7592FF" stroke="#7592FF" stroke-width="0.752667"/>
|
||||
<rect x="191.654" y="98.2308" width="22.1453" height="20.7141" rx="2.63433" fill="#7592FF" stroke="#7592FF" stroke-width="0.752667"/>
|
||||
<rect x="207.396" y="82.847" width="57.5655" height="20.7141" rx="2.63433" fill="#7592FF" stroke="#7592FF" stroke-width="0.752667"/>
|
||||
<rect x="152.769" y="15.167" width="166.462" height="130.311" rx="28" stroke="#7592FF" stroke-width="24"/>
|
||||
<rect x="0.0405273" y="0.522461" width="32.6255" height="77.5957" rx="6.26271" fill="#7592FF"/>
|
||||
<rect x="0.0405273" y="0.522461" width="32.6255" height="77.5957" rx="6.26271" stroke="#7592FF"/>
|
||||
<rect x="75.8726" y="3.16797" width="32.6255" height="154.31" rx="6.26271" fill="#7592FF"/>
|
||||
<rect x="75.8726" y="3.16797" width="32.6255" height="154.31" rx="6.26271" stroke="#7592FF"/>
|
||||
<rect x="16.7939" y="91.3438" width="32.6255" height="77.5957" rx="6.26271" transform="rotate(-90 16.7939 91.3438)" fill="#7592FF"/>
|
||||
<rect x="16.7939" y="91.3438" width="32.6255" height="77.5957" rx="6.26271" transform="rotate(-90 16.7939 91.3438)" stroke="#7592FF"/>
|
||||
<rect x="363.502" y="0.522461" width="32.6255" height="77.5957" rx="6.26271" fill="#7592FF"/>
|
||||
<rect x="363.502" y="0.522461" width="32.6255" height="77.5957" rx="6.26271" stroke="#7592FF"/>
|
||||
<rect x="439.334" y="3.16797" width="32.6255" height="154.31" rx="6.26271" fill="#7592FF"/>
|
||||
<rect x="439.334" y="3.16797" width="32.6255" height="154.31" rx="6.26271" stroke="#7592FF"/>
|
||||
<rect x="380.255" y="91.3438" width="32.6255" height="77.5957" rx="6.26271" transform="rotate(-90 380.255 91.3438)" fill="#7592FF"/>
|
||||
<rect x="380.255" y="91.3438" width="32.6255" height="77.5957" rx="6.26271" transform="rotate(-90 380.255 91.3438)" stroke="#7592FF"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 2.1 KiB |
20
frontend/public/images/error/404.svg
Normal file
@@ -0,0 +1,20 @@
|
||||
<svg width="472" height="158" viewBox="0 0 472 158" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect x="203.103" y="41.7015" width="22.1453" height="20.7141" rx="2.63433" fill="#465FFF" stroke="#465FFF" stroke-width="0.752667"/>
|
||||
<rect x="246.752" y="41.7015" width="22.1453" height="20.7141" rx="2.63433" fill="#465FFF" stroke="#465FFF" stroke-width="0.752667"/>
|
||||
<rect x="258.201" y="98.2303" width="22.1453" height="20.7141" rx="2.63433" fill="#465FFF" stroke="#465FFF" stroke-width="0.752667"/>
|
||||
<rect x="191.654" y="98.2303" width="22.1453" height="20.7141" rx="2.63433" fill="#465FFF" stroke="#465FFF" stroke-width="0.752667"/>
|
||||
<rect x="207.396" y="82.847" width="57.5655" height="20.7141" rx="2.63433" fill="#465FFF" stroke="#465FFF" stroke-width="0.752667"/>
|
||||
<rect x="152.769" y="15.167" width="166.462" height="130.311" rx="28" stroke="#465FFF" stroke-width="24"/>
|
||||
<rect x="0.0405273" y="0.522461" width="32.6255" height="77.5957" rx="6.26271" fill="#465FFF"/>
|
||||
<rect x="0.0405273" y="0.522461" width="32.6255" height="77.5957" rx="6.26271" stroke="#465FFF"/>
|
||||
<rect x="75.8726" y="3.16748" width="32.6255" height="154.31" rx="6.26271" fill="#465FFF"/>
|
||||
<rect x="75.8726" y="3.16748" width="32.6255" height="154.31" rx="6.26271" stroke="#465FFF"/>
|
||||
<rect x="16.7939" y="91.3442" width="32.6255" height="77.5957" rx="6.26271" transform="rotate(-90 16.7939 91.3442)" fill="#465FFF"/>
|
||||
<rect x="16.7939" y="91.3442" width="32.6255" height="77.5957" rx="6.26271" transform="rotate(-90 16.7939 91.3442)" stroke="#465FFF"/>
|
||||
<rect x="363.502" y="0.522461" width="32.6255" height="77.5957" rx="6.26271" fill="#465FFF"/>
|
||||
<rect x="363.502" y="0.522461" width="32.6255" height="77.5957" rx="6.26271" stroke="#465FFF"/>
|
||||
<rect x="439.334" y="3.16748" width="32.6255" height="154.31" rx="6.26271" fill="#465FFF"/>
|
||||
<rect x="439.334" y="3.16748" width="32.6255" height="154.31" rx="6.26271" stroke="#465FFF"/>
|
||||
<rect x="380.255" y="91.3442" width="32.6255" height="77.5957" rx="6.26271" transform="rotate(-90 380.255 91.3442)" fill="#465FFF"/>
|
||||
<rect x="380.255" y="91.3442" width="32.6255" height="77.5957" rx="6.26271" transform="rotate(-90 380.255 91.3442)" stroke="#465FFF"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 2.1 KiB |
24
frontend/public/images/error/500-dark.svg
Normal file
@@ -0,0 +1,24 @@
|
||||
<svg width="562" height="156" viewBox="0 0 562 156" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect x="0.161133" y="13.4297" width="32.6255" height="71" rx="6.26271" fill="#7592FF"/>
|
||||
<rect x="0.161133" y="13.4297" width="32.6255" height="71" rx="6.26271" stroke="#7592FF"/>
|
||||
<rect x="88.2891" y="80.1504" width="32.6255" height="63.5801" rx="6.26271" fill="#7592FF"/>
|
||||
<rect x="88.2891" y="80.1504" width="32.6255" height="63.5801" rx="6.26271" stroke="#7592FF"/>
|
||||
<rect x="15.5254" y="33.4668" width="32.6255" height="105.389" rx="6.26271" transform="rotate(-90 15.5254 33.4668)" fill="#7592FF"/>
|
||||
<rect x="15.5254" y="33.4668" width="32.6255" height="105.389" rx="6.26271" transform="rotate(-90 15.5254 33.4668)" stroke="#7592FF"/>
|
||||
<rect x="0.161133" y="155.16" width="30" height="107.028" rx="6.26271" transform="rotate(-90 0.161133 155.16)" fill="#7592FF"/>
|
||||
<rect x="0.161133" y="155.16" width="30" height="107.028" rx="6.26271" transform="rotate(-90 0.161133 155.16)" stroke="#7592FF"/>
|
||||
<rect x="15.5254" y="96.3398" width="32.6255" height="91.6638" rx="6.26271" transform="rotate(-90 15.5254 96.3398)" fill="#7592FF"/>
|
||||
<rect x="15.5254" y="96.3398" width="32.6255" height="91.6638" rx="6.26271" transform="rotate(-90 15.5254 96.3398)" stroke="#7592FF"/>
|
||||
<rect x="162.915" y="12.8496" width="166.462" height="130.311" rx="28" stroke="#7592FF" stroke-width="24"/>
|
||||
<rect x="213.52" y="42.0287" width="22.1453" height="20.7141" rx="2.63433" fill="#7592FF" stroke="#7592FF" stroke-width="0.752667"/>
|
||||
<rect x="257.168" y="42.0287" width="22.1453" height="20.7141" rx="2.63433" fill="#7592FF" stroke="#7592FF" stroke-width="0.752667"/>
|
||||
<rect x="268.618" y="98.558" width="22.1453" height="20.7141" rx="2.63433" fill="#7592FF" stroke="#7592FF" stroke-width="0.752667"/>
|
||||
<rect x="202.071" y="98.558" width="22.1453" height="20.7141" rx="2.63433" fill="#7592FF" stroke="#7592FF" stroke-width="0.752667"/>
|
||||
<rect x="217.813" y="83.1732" width="57.5655" height="20.7141" rx="2.63433" fill="#7592FF" stroke="#7592FF" stroke-width="0.752667"/>
|
||||
<rect x="383.377" y="12.8496" width="166.462" height="130.311" rx="28" stroke="#7592FF" stroke-width="24"/>
|
||||
<rect x="433.982" y="42.0287" width="22.1453" height="20.7141" rx="2.63433" fill="#7592FF" stroke="#7592FF" stroke-width="0.752667"/>
|
||||
<rect x="477.63" y="42.0287" width="22.1453" height="20.7141" rx="2.63433" fill="#7592FF" stroke="#7592FF" stroke-width="0.752667"/>
|
||||
<rect x="489.079" y="98.558" width="22.1453" height="20.7141" rx="2.63433" fill="#7592FF" stroke="#7592FF" stroke-width="0.752667"/>
|
||||
<rect x="422.533" y="98.558" width="22.1453" height="20.7141" rx="2.63433" fill="#7592FF" stroke="#7592FF" stroke-width="0.752667"/>
|
||||
<rect x="438.275" y="83.1732" width="57.5655" height="20.7141" rx="2.63433" fill="#7592FF" stroke="#7592FF" stroke-width="0.752667"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 2.8 KiB |
24
frontend/public/images/error/500.svg
Normal file
@@ -0,0 +1,24 @@
|
||||
<svg width="562" height="156" viewBox="0 0 562 156" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect x="0.161133" y="13.4292" width="32.6255" height="71" rx="6.26271" fill="#465FFF"/>
|
||||
<rect x="0.161133" y="13.4292" width="32.6255" height="71" rx="6.26271" stroke="#465FFF"/>
|
||||
<rect x="88.2891" y="80.1499" width="32.6255" height="63.5801" rx="6.26271" fill="#465FFF"/>
|
||||
<rect x="88.2891" y="80.1499" width="32.6255" height="63.5801" rx="6.26271" stroke="#465FFF"/>
|
||||
<rect x="15.5254" y="33.4673" width="32.6255" height="105.389" rx="6.26271" transform="rotate(-90 15.5254 33.4673)" fill="#465FFF"/>
|
||||
<rect x="15.5254" y="33.4673" width="32.6255" height="105.389" rx="6.26271" transform="rotate(-90 15.5254 33.4673)" stroke="#465FFF"/>
|
||||
<rect x="0.161133" y="155.16" width="30" height="107.028" rx="6.26271" transform="rotate(-90 0.161133 155.16)" fill="#465FFF"/>
|
||||
<rect x="0.161133" y="155.16" width="30" height="107.028" rx="6.26271" transform="rotate(-90 0.161133 155.16)" stroke="#465FFF"/>
|
||||
<rect x="15.5254" y="96.3398" width="32.6255" height="91.6638" rx="6.26271" transform="rotate(-90 15.5254 96.3398)" fill="#465FFF"/>
|
||||
<rect x="15.5254" y="96.3398" width="32.6255" height="91.6638" rx="6.26271" transform="rotate(-90 15.5254 96.3398)" stroke="#465FFF"/>
|
||||
<rect x="162.915" y="12.8496" width="166.462" height="130.311" rx="28" stroke="#465FFF" stroke-width="24"/>
|
||||
<rect x="213.52" y="42.0287" width="22.1453" height="20.7141" rx="2.63433" fill="#465FFF" stroke="#465FFF" stroke-width="0.752667"/>
|
||||
<rect x="257.168" y="42.0287" width="22.1453" height="20.7141" rx="2.63433" fill="#465FFF" stroke="#465FFF" stroke-width="0.752667"/>
|
||||
<rect x="268.618" y="98.558" width="22.1453" height="20.7141" rx="2.63433" fill="#465FFF" stroke="#465FFF" stroke-width="0.752667"/>
|
||||
<rect x="202.071" y="98.558" width="22.1453" height="20.7141" rx="2.63433" fill="#465FFF" stroke="#465FFF" stroke-width="0.752667"/>
|
||||
<rect x="217.813" y="83.1732" width="57.5655" height="20.7141" rx="2.63433" fill="#465FFF" stroke="#465FFF" stroke-width="0.752667"/>
|
||||
<rect x="383.377" y="12.8496" width="166.462" height="130.311" rx="28" stroke="#465FFF" stroke-width="24"/>
|
||||
<rect x="433.982" y="42.0287" width="22.1453" height="20.7141" rx="2.63433" fill="#465FFF" stroke="#465FFF" stroke-width="0.752667"/>
|
||||
<rect x="477.63" y="42.0287" width="22.1453" height="20.7141" rx="2.63433" fill="#465FFF" stroke="#465FFF" stroke-width="0.752667"/>
|
||||
<rect x="489.079" y="98.558" width="22.1453" height="20.7141" rx="2.63433" fill="#465FFF" stroke="#465FFF" stroke-width="0.752667"/>
|
||||
<rect x="422.533" y="98.558" width="22.1453" height="20.7141" rx="2.63433" fill="#465FFF" stroke="#465FFF" stroke-width="0.752667"/>
|
||||
<rect x="438.275" y="83.1732" width="57.5655" height="20.7141" rx="2.63433" fill="#465FFF" stroke="#465FFF" stroke-width="0.752667"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 2.8 KiB |
26
frontend/public/images/error/503-dark.svg
Normal file
@@ -0,0 +1,26 @@
|
||||
<svg width="494" height="156" viewBox="0 0 494 156" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect x="0.515625" y="13.4492" width="32.6255" height="71" rx="6.26271" fill="#7592FF"/>
|
||||
<rect x="0.515625" y="13.4492" width="32.6255" height="71" rx="6.26271" stroke="#7592FF"/>
|
||||
<rect x="88.6436" y="80.1699" width="32.6255" height="63.5801" rx="6.26271" fill="#7592FF"/>
|
||||
<rect x="88.6436" y="80.1699" width="32.6255" height="63.5801" rx="6.26271" stroke="#7592FF"/>
|
||||
<rect x="15.8799" y="33.4863" width="32.6255" height="105.389" rx="6.26271" transform="rotate(-90 15.8799 33.4863)" fill="#7592FF"/>
|
||||
<rect x="15.8799" y="33.4863" width="32.6255" height="105.389" rx="6.26271" transform="rotate(-90 15.8799 33.4863)" stroke="#7592FF"/>
|
||||
<rect x="0.515625" y="155.18" width="30" height="107.028" rx="6.26271" transform="rotate(-90 0.515625 155.18)" fill="#7592FF"/>
|
||||
<rect x="0.515625" y="155.18" width="30" height="107.028" rx="6.26271" transform="rotate(-90 0.515625 155.18)" stroke="#7592FF"/>
|
||||
<rect x="15.8799" y="96.3594" width="32.6255" height="91.6638" rx="6.26271" transform="rotate(-90 15.8799 96.3594)" fill="#7592FF"/>
|
||||
<rect x="15.8799" y="96.3594" width="32.6255" height="91.6638" rx="6.26271" transform="rotate(-90 15.8799 96.3594)" stroke="#7592FF"/>
|
||||
<rect x="163.27" y="12.8691" width="166.462" height="130.311" rx="28" stroke="#7592FF" stroke-width="24"/>
|
||||
<rect x="213.874" y="42.0482" width="22.1453" height="20.7141" rx="2.63433" fill="#7592FF" stroke="#7592FF" stroke-width="0.752667"/>
|
||||
<rect x="257.523" y="42.0482" width="22.1453" height="20.7141" rx="2.63433" fill="#7592FF" stroke="#7592FF" stroke-width="0.752667"/>
|
||||
<rect x="268.972" y="98.5775" width="22.1453" height="20.7141" rx="2.63433" fill="#7592FF" stroke="#7592FF" stroke-width="0.752667"/>
|
||||
<rect x="202.425" y="98.5775" width="22.1453" height="20.7141" rx="2.63433" fill="#7592FF" stroke="#7592FF" stroke-width="0.752667"/>
|
||||
<rect x="218.167" y="83.1927" width="57.5655" height="20.7141" rx="2.63433" fill="#7592FF" stroke="#7592FF" stroke-width="0.752667"/>
|
||||
<rect x="460.859" y="11.1885" width="32.6255" height="132.562" rx="6.26271" fill="#7592FF"/>
|
||||
<rect x="460.859" y="11.1885" width="32.6255" height="132.562" rx="6.26271" stroke="#7592FF"/>
|
||||
<rect x="371.731" y="33.4453" width="32.6255" height="107.028" rx="6.26271" transform="rotate(-90 371.731 33.4453)" fill="#7592FF"/>
|
||||
<rect x="371.731" y="33.4453" width="32.6255" height="107.028" rx="6.26271" transform="rotate(-90 371.731 33.4453)" stroke="#7592FF"/>
|
||||
<rect x="371.731" y="155.18" width="30" height="107.028" rx="6.26271" transform="rotate(-90 371.731 155.18)" fill="#7592FF"/>
|
||||
<rect x="371.731" y="155.18" width="30" height="107.028" rx="6.26271" transform="rotate(-90 371.731 155.18)" stroke="#7592FF"/>
|
||||
<rect x="388.096" y="93.7812" width="32.6255" height="91.6638" rx="6.26271" transform="rotate(-90 388.096 93.7812)" fill="#7592FF"/>
|
||||
<rect x="388.096" y="93.7812" width="32.6255" height="91.6638" rx="6.26271" transform="rotate(-90 388.096 93.7812)" stroke="#7592FF"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 3.0 KiB |
26
frontend/public/images/error/503.svg
Normal file
@@ -0,0 +1,26 @@
|
||||
<svg width="494" height="156" viewBox="0 0 494 156" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect x="0.515625" y="13.4492" width="32.6255" height="71" rx="6.26271" fill="#465FFF"/>
|
||||
<rect x="0.515625" y="13.4492" width="32.6255" height="71" rx="6.26271" stroke="#465FFF"/>
|
||||
<rect x="88.6436" y="80.1699" width="32.6255" height="63.5801" rx="6.26271" fill="#465FFF"/>
|
||||
<rect x="88.6436" y="80.1699" width="32.6255" height="63.5801" rx="6.26271" stroke="#465FFF"/>
|
||||
<rect x="15.8799" y="33.4873" width="32.6255" height="105.389" rx="6.26271" transform="rotate(-90 15.8799 33.4873)" fill="#465FFF"/>
|
||||
<rect x="15.8799" y="33.4873" width="32.6255" height="105.389" rx="6.26271" transform="rotate(-90 15.8799 33.4873)" stroke="#465FFF"/>
|
||||
<rect x="0.515625" y="155.18" width="30" height="107.028" rx="6.26271" transform="rotate(-90 0.515625 155.18)" fill="#465FFF"/>
|
||||
<rect x="0.515625" y="155.18" width="30" height="107.028" rx="6.26271" transform="rotate(-90 0.515625 155.18)" stroke="#465FFF"/>
|
||||
<rect x="15.8799" y="96.3599" width="32.6255" height="91.6638" rx="6.26271" transform="rotate(-90 15.8799 96.3599)" fill="#465FFF"/>
|
||||
<rect x="15.8799" y="96.3599" width="32.6255" height="91.6638" rx="6.26271" transform="rotate(-90 15.8799 96.3599)" stroke="#465FFF"/>
|
||||
<rect x="163.27" y="12.8696" width="166.462" height="130.311" rx="28" stroke="#465FFF" stroke-width="24"/>
|
||||
<rect x="213.874" y="42.0487" width="22.1453" height="20.7141" rx="2.63433" fill="#465FFF" stroke="#465FFF" stroke-width="0.752667"/>
|
||||
<rect x="257.523" y="42.0487" width="22.1453" height="20.7141" rx="2.63433" fill="#465FFF" stroke="#465FFF" stroke-width="0.752667"/>
|
||||
<rect x="268.972" y="98.578" width="22.1453" height="20.7141" rx="2.63433" fill="#465FFF" stroke="#465FFF" stroke-width="0.752667"/>
|
||||
<rect x="202.425" y="98.578" width="22.1453" height="20.7141" rx="2.63433" fill="#465FFF" stroke="#465FFF" stroke-width="0.752667"/>
|
||||
<rect x="218.167" y="83.1932" width="57.5655" height="20.7141" rx="2.63433" fill="#465FFF" stroke="#465FFF" stroke-width="0.752667"/>
|
||||
<rect x="460.859" y="11.188" width="32.6255" height="132.562" rx="6.26271" fill="#465FFF"/>
|
||||
<rect x="460.859" y="11.188" width="32.6255" height="132.562" rx="6.26271" stroke="#465FFF"/>
|
||||
<rect x="371.731" y="33.4458" width="32.6255" height="107.028" rx="6.26271" transform="rotate(-90 371.731 33.4458)" fill="#465FFF"/>
|
||||
<rect x="371.731" y="33.4458" width="32.6255" height="107.028" rx="6.26271" transform="rotate(-90 371.731 33.4458)" stroke="#465FFF"/>
|
||||
<rect x="371.731" y="155.18" width="30" height="107.028" rx="6.26271" transform="rotate(-90 371.731 155.18)" fill="#465FFF"/>
|
||||
<rect x="371.731" y="155.18" width="30" height="107.028" rx="6.26271" transform="rotate(-90 371.731 155.18)" stroke="#465FFF"/>
|
||||
<rect x="388.096" y="93.7812" width="32.6255" height="91.6638" rx="6.26271" transform="rotate(-90 388.096 93.7812)" fill="#465FFF"/>
|
||||
<rect x="388.096" y="93.7812" width="32.6255" height="91.6638" rx="6.26271" transform="rotate(-90 388.096 93.7812)" stroke="#465FFF"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 2.9 KiB |