完成后端数据库连接配置

This commit is contained in:
2026-03-26 19:49:40 +08:00
parent d3bdb17e87
commit 4bdc3f9707
19 changed files with 2843 additions and 302 deletions

View File

@@ -1,16 +1,50 @@
# 基础配置
# ============================================================
# 基于大语言模型的文档理解与多源数据融合系统
# 环境变量配置文件
# ============================================================
# 复制此文件为 .env 并填入实际值
# ==================== 应用基础配置 ====================
APP_NAME="FilesReadSystem"
DEBUG=true
API_V1_STR="/api/v1"
# 数据库
MONGODB_URL="mongodb://username:password@host:port"
MONGODB_DB_NAME=""
# ==================== MongoDB 配置 ====================
# 非结构化数据存储 (原始文档、解析结果)
MONGODB_URL="mongodb://localhost:27017"
MONGODB_DB_NAME="document_system"
# ==================== MySQL 配置 ====================
# 结构化数据存储 (Excel表格、查询结果)
MYSQL_HOST="localhost"
MYSQL_PORT=3306
MYSQL_USER="root"
MYSQL_PASSWORD="your_password_here"
MYSQL_DATABASE="document_system"
MYSQL_CHARSET="utf8mb4"
# ==================== Redis 配置 ====================
# 缓存/任务队列
REDIS_URL="redis://localhost:6379/0"
# 大模型 API
LLM_API_KEY=""
LLM_BASE_URL=""
# ==================== LLM AI 配置 ====================
# 大语言模型 API 配置
LLM_API_KEY="your_api_key_here"
LLM_BASE_URL="https://api.minimax.chat/v1"
LLM_MODEL_NAME="MiniMax-Text-01"
# 文件存储配置
# ==================== 文件路径配置 ====================
# 上传文件存储目录 (相对于项目根目录)
UPLOAD_DIR="./data/uploads"
MAX_UPLOAD_SIZE=104857600 # 100MB
# ChromaDB 向量数据库持久化目录
CHROMADB_PERSIST_DIR="./data/chromadb"
# ==================== RAG 配置 ====================
# Embedding 模型名称
EMBEDDING_MODEL="all-MiniLM-L6-v2"
# ==================== Celery 配置 ====================
# 异步任务队列 Broker
CELERY_BROKER_URL="redis://localhost:6379/1"
CELERY_RESULT_BACKEND="redis://localhost:6379/2"

View File

@@ -2,13 +2,28 @@
API 路由注册模块
"""
from fastapi import APIRouter
from app.api.endpoints import upload, ai_analyze, visualization, analysis_charts
from app.api.endpoints import (
upload,
documents, # 新增:文档上传
tasks, # 新增:任务管理
library, # 新增:文档库
rag, # 新增RAG检索
ai_analyze,
visualization,
analysis_charts,
health,
)
# 创建主路由
api_router = APIRouter()
# 注册各模块路由
api_router.include_router(upload.router)
api_router.include_router(ai_analyze.router)
api_router.include_router(visualization.router)
api_router.include_router(analysis_charts.router)
api_router.include_router(health.router) # 健康检查
api_router.include_router(upload.router) # 原有Excel上传
api_router.include_router(documents.router) # 多格式文档上传
api_router.include_router(tasks.router) # 任务状态查询
api_router.include_router(library.router) # 文档库管理
api_router.include_router(rag.router) # RAG检索
api_router.include_router(ai_analyze.router) # AI分析
api_router.include_router(visualization.router) # 可视化
api_router.include_router(analysis_charts.router) # 分析图表

View File

@@ -0,0 +1,371 @@
"""
文档管理 API 接口
支持多格式文档(docx/xlsx/md/txt)上传、解析、存储和RAG索引
"""
import logging
import uuid
from datetime import datetime
from typing import List, Optional

from fastapi import APIRouter, UploadFile, File, HTTPException, Query, BackgroundTasks
from pydantic import BaseModel

from app.core.database import mongodb, mysql_db
from app.core.document_parser import ParserFactory, ParseResult
from app.services.file_service import file_service
from app.services.rag_service import rag_service
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/upload", tags=["文档上传"])
# ==================== 请求/响应模型 ====================
class UploadResponse(BaseModel):
    """Response returned when an upload is accepted for asynchronous processing."""
    task_id: str      # id used to poll task progress
    file_count: int   # number of files accepted in this request
    message: str      # human-readable confirmation message
    status_url: str   # URL to poll for the task's status
class TaskStatusResponse(BaseModel):
    """Shape of a task-status payload.

    NOTE(review): declared here but not referenced by any endpoint in this
    file — presumably intended for the /tasks endpoints; confirm.
    """
    task_id: str
    status: str  # pending, processing, success, failure
    progress: int = 0              # 0-100
    message: Optional[str] = None  # latest progress message
    result: Optional[dict] = None  # final result payload on success
    error: Optional[str] = None    # error text on failure
# ==================== 文档上传接口 ====================
@router.post("/document", response_model=UploadResponse)
async def upload_document(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    doc_type: Optional[str] = Query(None, description="文档类型: docx/xlsx/md/txt"),
    parse_all_sheets: bool = Query(False, description="是否解析所有工作表(仅Excel)"),
    sheet_name: Optional[str] = Query(None, description="指定工作表(仅Excel)"),
    header_row: int = Query(0, description="表头行号(仅Excel)")
):
    """Accept a single document and hand it off for background processing.

    The background pipeline saves the file, parses it, stores it in MongoDB
    (and MySQL for Excel), then builds the RAG index. The caller polls the
    returned status_url with the task id.
    """
    # Guard: a filename is required to derive the document type.
    if not file.filename:
        raise HTTPException(status_code=400, detail="文件名为空")

    # The document type is derived from the extension, not the doc_type query.
    file_ext = file.filename.split('.')[-1].lower()
    if file_ext not in {'docx', 'xlsx', 'xls', 'md', 'txt'}:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的文件类型: {file_ext},仅支持 docx/xlsx/xls/md/txt"
        )

    task_id = str(uuid.uuid4())

    try:
        payload = await file.read()
        stored_path = file_service.save_uploaded_file(
            payload,
            file.filename,
            subfolder=file_ext
        )

        # Defer the heavy work (parse/store/index) to a background task.
        background_tasks.add_task(
            process_document,
            task_id=task_id,
            file_path=stored_path,
            original_filename=file.filename,
            doc_type=file_ext,
            parse_options={
                "parse_all_sheets": parse_all_sheets,
                "sheet_name": sheet_name,
                "header_row": header_row
            }
        )

        return UploadResponse(
            task_id=task_id,
            file_count=1,
            message=f"文档 {file.filename} 已提交处理",
            status_url=f"/api/v1/tasks/{task_id}"
        )
    except Exception as exc:
        logger.error(f"上传文档失败: {str(exc)}")
        raise HTTPException(status_code=500, detail=f"上传失败: {str(exc)}")
@router.post("/documents", response_model=UploadResponse)
async def upload_documents(
    background_tasks: BackgroundTasks,
    files: List[UploadFile] = File(...),
    doc_type: Optional[str] = Query(None, description="文档类型")
):
    """Accept a batch of documents and process them under one task id.

    Files without a filename are silently skipped; progress for the whole
    batch is tracked via the single returned task id.
    """
    if not files:
        raise HTTPException(status_code=400, detail="没有上传文件")

    task_id = str(uuid.uuid4())
    stored = []

    try:
        for upload in files:
            if not upload.filename:
                continue  # cannot derive type/target path without a name
            payload = await upload.read()
            path = file_service.save_uploaded_file(
                payload,
                upload.filename,
                subfolder="batch"
            )
            stored.append({
                "path": path,
                "filename": upload.filename,
                "ext": upload.filename.split('.')[-1].lower()
            })

        # Process the whole batch asynchronously.
        background_tasks.add_task(
            process_documents_batch,
            task_id=task_id,
            files=stored
        )

        return UploadResponse(
            task_id=task_id,
            file_count=len(stored),
            message=f"已提交 {len(stored)} 个文档处理",
            status_url=f"/api/v1/tasks/{task_id}"
        )
    except Exception as exc:
        logger.error(f"批量上传失败: {str(exc)}")
        raise HTTPException(status_code=500, detail=f"批量上传失败: {str(exc)}")
# ==================== 任务处理函数 ====================
async def process_document(
    task_id: str,
    file_path: str,
    original_filename: str,
    doc_type: str,
    parse_options: dict
):
    """Background pipeline for one document: parse -> store -> index.

    Progress is written to Redis at each stage (10/40/70/100); any failure
    marks the task as "failure" with the error text.

    Args:
        task_id: Redis task key for progress reporting.
        file_path: path of the file already saved on disk.
        original_filename: client-supplied filename (kept in metadata).
        doc_type: file extension, e.g. "xlsx" / "docx".
        parse_options: Excel parse options. NOTE(review): currently not
            forwarded to parser.parse() — confirm whether this is intended.
    """
    # Local import, presumably to avoid a circular import at module load
    # time — TODO confirm (mirrors process_documents_batch).
    from app.core.database import redis_db
    try:
        # Stage 1: parsing.
        await redis_db.set_task_status(
            task_id,
            status="processing",
            meta={"progress": 10, "message": "正在解析文档"}
        )
        # Pick the parser by file extension and run it.
        parser = ParserFactory.get_parser(file_path)
        result = parser.parse(file_path)
        if not result.success:
            raise Exception(result.error or "解析失败")
        # Stage 2: persistence.
        await redis_db.set_task_status(
            task_id,
            status="processing",
            meta={"progress": 40, "message": "正在存储数据"}
        )
        # Raw content + metadata always go to MongoDB.
        doc_id = await mongodb.insert_document(
            doc_type=doc_type,
            content=result.data.get("content", ""),
            metadata={
                **result.metadata,
                "original_filename": original_filename,
                "file_path": file_path
            },
            structured_data=result.data.get("structured_data")
        )
        # Excel additionally gets structured storage in MySQL.
        if doc_type in ["xlsx", "xls"]:
            await store_excel_to_mysql(file_path, original_filename, result)
        # Stage 3: RAG indexing.
        await redis_db.set_task_status(
            task_id,
            status="processing",
            meta={"progress": 70, "message": "正在建立索引"}
        )
        await index_document_to_rag(doc_id, original_filename, result, doc_type)
        # Stage 4: done — publish the final result payload.
        await redis_db.set_task_status(
            task_id,
            status="success",
            meta={
                "progress": 100,
                "message": "处理完成",
                "doc_id": doc_id,
                "result": {
                    "doc_id": doc_id,
                    "doc_type": doc_type,
                    "filename": original_filename
                }
            }
        )
        logger.info(f"文档处理完成: {original_filename}, doc_id: {doc_id}")
    except Exception as e:
        # Any stage failure marks the whole task as failed.
        logger.error(f"文档处理失败: {str(e)}")
        await redis_db.set_task_status(
            task_id,
            status="failure",
            meta={"error": str(e)}
        )
async def process_documents_batch(task_id: str, files: List[dict]):
    """Parse and store a batch of documents, reporting per-file progress.

    Each file is handled independently: a failure is recorded in the results
    list and does not stop the batch. The final status carries all outcomes.
    """
    from app.core.database import redis_db
    try:
        await redis_db.set_task_status(
            task_id,
            status="processing",
            meta={"progress": 0, "message": "开始批量处理"}
        )

        outcomes = []
        total = len(files)
        for index, item in enumerate(files):
            try:
                parse_result = ParserFactory.get_parser(item["path"]).parse(item["path"])
                if parse_result.success:
                    doc_id = await mongodb.insert_document(
                        doc_type=item["ext"],
                        content=parse_result.data.get("content", ""),
                        metadata={
                            **parse_result.metadata,
                            "original_filename": item["filename"],
                            "file_path": item["path"]
                        },
                        structured_data=parse_result.data.get("structured_data")
                    )
                    outcomes.append({"filename": item["filename"], "doc_id": doc_id, "success": True})
                else:
                    outcomes.append({"filename": item["filename"], "success": False, "error": parse_result.error})
            except Exception as exc:
                # One bad file must not abort the rest of the batch.
                outcomes.append({"filename": item["filename"], "success": False, "error": str(exc)})

            # Publish progress after every file.
            await redis_db.set_task_status(
                task_id,
                status="processing",
                meta={"progress": int((index + 1) / total * 100), "message": f"已处理 {index + 1}/{total}"}
            )

        await redis_db.set_task_status(
            task_id,
            status="success",
            meta={"progress": 100, "message": "批量处理完成", "results": outcomes}
        )
    except Exception as exc:
        logger.error(f"批量处理失败: {str(exc)}")
        await redis_db.set_task_status(
            task_id,
            status="failure",
            meta={"error": str(exc)}
        )
async def store_excel_to_mysql(file_path: str, filename: str, result: ParseResult):
    """Persist parsed Excel data into MySQL (structured storage).

    Currently a stub: process_document() already calls it for xlsx/xls
    uploads, but no data is written yet.

    Args:
        file_path: path of the saved Excel file.
        filename: original filename.
        result: parse result carrying the tabular data.
    """
    # TODO: implement the Excel -> MySQL conversion and storage;
    # table schemas must be created dynamically from the header row.
    pass
async def index_document_to_rag(doc_id: str, filename: str, result: ParseResult, doc_type: str):
    """Index a parsed document into the RAG vector store.

    Excel files are indexed per column (field-level entries); every other
    type is indexed by its (truncated) text content. Indexing failures are
    logged and swallowed so they never fail the upload pipeline.

    Args:
        doc_id: MongoDB id of the stored document.
        filename: original filename, used as the logical table name.
        result: parse result carrying content/metadata.
        doc_type: file extension ("xlsx", "md", ...).
    """
    try:
        if doc_type in ["xlsx", "xls"]:
            # Excel: index each column as a field entry.
            columns = result.metadata.get("columns", [])
            for col in columns:
                rag_service.index_field(
                    table_name=filename,
                    field_name=col,
                    # FIX: the description previously hard-coded the literal
                    # placeholder "(unknown)" instead of naming the source
                    # table, making field descriptions useless for retrieval.
                    field_description=f"Excel表格 {filename} 的列 {col}",
                    sample_values=None
                )
        else:
            # Other documents: index the text content, capped in length.
            content = result.data.get("content", "")
            if content:
                rag_service.index_document_content(
                    doc_id=doc_id,
                    content=content[:5000],  # limit indexed text length
                    metadata={
                        "filename": filename,
                        "doc_type": doc_type
                    }
                )
    except Exception as e:
        # Best-effort: indexing problems must not fail the upload pipeline.
        logger.warning(f"RAG 索引失败: {str(e)}")
# ==================== 文档解析接口 ====================
@router.post("/document/parse")
async def parse_uploaded_document(
    file_path: str = Query(..., description="文件路径")
):
    """Parse an already-uploaded document and return the parse result.

    Raises:
        400: unsupported file type (ValueError from ParserFactory) or a
             parser-reported failure.
        500: unexpected parsing error.
    """
    try:
        parser = ParserFactory.get_parser(file_path)
        result = parser.parse(file_path)
        if result.success:
            return result.to_dict()
        else:
            raise HTTPException(status_code=400, detail=result.error)
    except HTTPException:
        # FIX: without this re-raise, the 400 raised above fell through to
        # the generic handler below and was converted into a 500.
        raise
    except ValueError as e:
        # ParserFactory raises ValueError for unsupported extensions.
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"解析文档失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"解析失败: {str(e)}")
# 需要添加 import
import logging

View File

@@ -0,0 +1,76 @@
"""
健康检查接口
"""
from datetime import datetime
from typing import Any, Dict
from fastapi import APIRouter
from app.core.database import mysql_db, mongodb, redis_db
router = APIRouter(tags=["健康检查"])
@router.get("/health")
async def health_check() -> Dict[str, Any]:
    """Report overall health plus the state of each backing service.

    Each service is "connected", "disconnected" (client missing/not
    connected) or "error" (the check itself raised). Overall status is
    "healthy" only when all three are connected, otherwise "degraded".
    """
    def probe(check) -> str:
        # Run one connectivity predicate, mapping exceptions to "error".
        try:
            return "connected" if check() else "disconnected"
        except Exception:
            return "error"

    services = {
        "mysql": probe(lambda: mysql_db.async_engine is not None),
        "mongodb": probe(lambda: mongodb.client is not None),
        "redis": probe(lambda: redis_db.is_connected),
    }
    overall = "healthy" if all(state == "connected" for state in services.values()) else "degraded"

    return {
        "status": overall,
        "timestamp": datetime.utcnow().isoformat(),
        "services": services,
    }
@router.get("/health/ready")
async def readiness_check() -> Dict[str, str]:
    """Readiness probe for Kubernetes / load balancers — always ready."""
    payload = {"status": "ready"}
    return payload
@router.get("/health/live")
async def liveness_check() -> Dict[str, str]:
    """Liveness probe for Kubernetes / load balancers — always alive."""
    payload = {"status": "alive"}
    return payload

View File

@@ -0,0 +1,139 @@
"""
文档库管理 API 接口
提供文档列表、详情查询和删除功能
"""
from typing import Optional, List
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
from app.core.database import mongodb
router = APIRouter(prefix="/documents", tags=["文档库"])
class DocumentItem(BaseModel):
    """One entry in a document listing.

    NOTE(review): declared but not used as a response_model in this file —
    presumably intended for /documents; confirm before relying on it.
    """
    doc_id: str
    filename: str
    original_filename: str
    doc_type: str
    file_size: int
    created_at: str                   # ISO-8601 string, "" when unknown
    metadata: Optional[dict] = None   # summary metadata (row/column counts)
@router.get("")
async def get_documents(
    doc_type: Optional[str] = Query(None, description="文档类型过滤"),
    limit: int = Query(50, ge=1, le=100, description="返回数量")
):
    """List stored documents, newest first, with summary metadata.

    Returns:
        {"success": True, "documents": [...], "total": <count>}
    """
    try:
        criteria = {"doc_type": doc_type} if doc_type else {}
        cursor = mongodb.documents.find(criteria).sort("created_at", -1).limit(limit)

        items = []
        async for record in cursor:
            meta = record.get("metadata", {})
            created = record.get("created_at")
            items.append({
                "doc_id": str(record["_id"]),
                "filename": meta.get("filename", ""),
                "original_filename": meta.get("original_filename", ""),
                "doc_type": record.get("doc_type", ""),
                "file_size": meta.get("file_size", 0),
                "created_at": created.isoformat() if created else "",
                "metadata": {
                    "row_count": meta.get("row_count"),
                    "column_count": meta.get("column_count"),
                    "columns": meta.get("columns", [])[:10]  # first 10 columns only
                }
            })

        return {
            "success": True,
            "documents": items,
            "total": len(items)
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"获取文档列表失败: {str(exc)}")
@router.get("/{doc_id}")
async def get_document(doc_id: str):
    """Fetch one document with its full content and structured data.

    Raises 404 when the id does not match any stored document.
    """
    try:
        record = await mongodb.get_document(doc_id)
        if not record:
            raise HTTPException(status_code=404, detail="文档不存在")

        meta = record.get("metadata", {})
        created = record.get("created_at")
        return {
            "success": True,
            "document": {
                "doc_id": str(record["_id"]),
                "filename": meta.get("filename", ""),
                "original_filename": meta.get("original_filename", ""),
                "doc_type": record.get("doc_type", ""),
                "file_size": meta.get("file_size", 0),
                "created_at": created.isoformat() if created else "",
                "content": record.get("content", ""),           # raw text
                "structured_data": record.get("structured_data"),  # tables, if any
                "metadata": meta
            }
        }
    except HTTPException:
        raise  # keep the 404 instead of mapping it to 500
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"获取文档详情失败: {str(exc)}")
@router.delete("/{doc_id}")
async def delete_document(doc_id: str):
    """Delete a document from MongoDB by id (404 when absent)."""
    try:
        removed = await mongodb.delete_document(doc_id)
        if not removed:
            raise HTTPException(status_code=404, detail="文档不存在")

        # TODO: also remove derived MySQL rows (for Excel documents)
        # TODO: also remove the document's RAG index entries
        return {
            "success": True,
            "message": "文档已删除"
        }
    except HTTPException:
        raise  # keep the 404 instead of mapping it to 500
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"删除失败: {str(exc)}")

View File

@@ -0,0 +1,116 @@
"""
RAG 检索 API 接口
提供向量检索功能
"""
from typing import Optional
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
from app.services.rag_service import rag_service
router = APIRouter(prefix="/rag", tags=["RAG检索"])
class SearchRequest(BaseModel):
    """Body of a RAG search call."""
    query: str      # natural-language query text
    top_k: int = 5  # number of results to return
class SearchResult(BaseModel):
    """One retrieval hit.

    NOTE(review): not referenced by any endpoint in this file — confirm
    whether /search should declare it as its response model.
    """
    content: str
    metadata: dict
    score: float
    doc_id: str
@router.post("/search")
async def search_rag(
    request: SearchRequest
):
    """Semantic retrieval over the RAG index.

    Args:
        request.query: query text.
        request.top_k: number of hits to return.

    Returns:
        {"success": True, "results": [...]}.
    """
    try:
        hits = rag_service.retrieve(
            query=request.query,
            top_k=request.top_k
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"检索失败: {str(exc)}")
    return {
        "success": True,
        "results": hits
    }
@router.get("/status")
async def get_rag_status():
    """Return RAG index statistics (total vector count)."""
    try:
        total = rag_service.get_vector_count()
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"获取状态失败: {str(exc)}")
    return {
        "success": True,
        "vector_count": total,
        "collections": ["document_fields", "document_content"]  # reserved
    }
@router.post("/rebuild")
async def rebuild_rag_index():
    """Rebuild the vector index from every document stored in MongoDB.

    Clears the current index, then re-indexes the (truncated) content of
    each document that has any text.
    """
    from app.core.database import mongodb
    try:
        rag_service.clear()  # drop the existing index first

        indexed = 0
        async for record in mongodb.documents.find({}):
            text_content = record.get("content", "")
            if not text_content:
                continue  # nothing to index for this document
            rag_service.index_document_content(
                doc_id=str(record["_id"]),
                content=text_content[:5000],  # cap indexed text length
                metadata={
                    "filename": record.get("metadata", {}).get("filename"),
                    "doc_type": record.get("doc_type")
                }
            )
            indexed += 1

        return {
            "success": True,
            "message": f"已重建索引,共处理 {indexed} 个文档"
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"重建索引失败: {str(e)}")

View File

@@ -0,0 +1,38 @@
"""
任务管理 API 接口
提供异步任务状态查询
"""
from typing import Optional
from fastapi import APIRouter, HTTPException
from app.core.database import redis_db
router = APIRouter(prefix="/tasks", tags=["任务管理"])
@router.get("/{task_id}")
async def get_task_status(task_id: str):
    """Return the status/progress record of one async task (404 if unknown)."""
    record = await redis_db.get_task_status(task_id)
    if not record:
        raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在")

    meta = record.get("meta", {})
    return {
        "task_id": task_id,
        "status": record.get("status", "unknown"),
        "progress": meta.get("progress", 0),
        "message": meta.get("message"),
        "result": meta.get("result"),
        "error": meta.get("error")
    }

View File

@@ -7,20 +7,35 @@ class Settings(BaseSettings):
DEBUG: bool = True
API_V1_STR: str = "/api/v1"
# 数据库
MONGODB_URL: str
MONGODB_DB_NAME: str
REDIS_URL: str
# ==================== 数据库配置 ====================
# AI 相关
LLM_API_KEY: str
LLM_BASE_URL: str
LLM_MODEL_NAME: str
# MongoDB 配置 (非结构化数据存储)
MONGODB_URL: str = "mongodb://localhost:27017"
MONGODB_DB_NAME: str = "document_system"
# 文件路径
# MySQL 配置 (结构化数据存储)
MYSQL_HOST: str = "localhost"
MYSQL_PORT: int = 3306
MYSQL_USER: str = "root"
MYSQL_PASSWORD: str = ""
MYSQL_DATABASE: str = "document_system"
MYSQL_CHARSET: str = "utf8mb4"
# Redis 配置 (缓存/任务队列)
REDIS_URL: str = "redis://localhost:6379/0"
# ==================== AI 相关配置 ====================
LLM_API_KEY: str = ""
LLM_BASE_URL: str = "https://api.minimax.chat"
LLM_MODEL_NAME: str = "MiniMax-Text-01"
# ==================== 文件路径配置 ====================
BASE_DIR: Path = Path(__file__).resolve().parent.parent.parent
UPLOAD_DIR: str = "data/uploads"
# ==================== RAG/向量数据库配置 ====================
CHROMADB_PERSIST_DIR: str = "data/chromadb"
# 允许 Pydantic 从 .env 文件读取
model_config = SettingsConfigDict(
env_file=Path(__file__).parent.parent / ".env",
@@ -28,4 +43,13 @@ class Settings(BaseSettings):
extra='ignore'
)
@property
def mysql_url(self) -> str:
"""生成MySQL连接URL"""
return (
f"mysql+pymysql://{self.MYSQL_USER}:{self.MYSQL_PASSWORD}"
f"@{self.MYSQL_HOST}:{self.MYSQL_PORT}/{self.MYSQL_DATABASE}"
f"?charset={self.MYSQL_CHARSET}"
)
settings = Settings()

View File

@@ -0,0 +1,18 @@
"""
数据库连接管理模块
提供 MySQL、MongoDB、Redis 的连接管理
"""
from app.core.database.mysql import MySQLDB, mysql_db, Base
from app.core.database.mongodb import MongoDB, mongodb
from app.core.database.redis_db import RedisDB, redis_db
__all__ = [
"MySQLDB",
"mysql_db",
"MongoDB",
"mongodb",
"RedisDB",
"redis_db",
"Base",
]

View File

@@ -0,0 +1,247 @@
"""
MongoDB 数据库连接管理模块
提供非结构化数据的存储和查询功能
"""
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
from app.config import settings
logger = logging.getLogger(__name__)
class MongoDB:
    """MongoDB manager for unstructured documents and RAG index entries."""

    def __init__(self):
        # Both set by connect(); None until then.
        self.client: Optional[AsyncIOMotorClient] = None
        self.db: Optional[AsyncIOMotorDatabase] = None

    async def connect(self):
        """Open the MongoDB connection and verify it with a ping."""
        try:
            self.client = AsyncIOMotorClient(
                settings.MONGODB_URL,
                serverSelectionTimeoutMS=5000,  # fail fast when unreachable
            )
            self.db = self.client[settings.MONGODB_DB_NAME]
            # Verify the server is actually reachable.
            await self.client.admin.command('ping')
            logger.info(f"MongoDB 连接成功: {settings.MONGODB_DB_NAME}")
        except Exception as e:
            logger.error(f"MongoDB 连接失败: {e}")
            raise

    async def close(self):
        """Close the client (safe to call when never connected)."""
        if self.client:
            self.client.close()
            logger.info("MongoDB 连接已关闭")

    @property
    def documents(self):
        """Collection of raw documents and their parse results."""
        return self.db["documents"]

    @property
    def embeddings(self):
        """Collection of text embedding vectors."""
        return self.db["embeddings"]

    @property
    def rag_index(self):
        """Collection of field-level semantic index entries."""
        return self.db["rag_index"]

    # ==================== document operations ====================
    async def insert_document(
        self,
        doc_type: str,
        content: str,
        metadata: Dict[str, Any],
        structured_data: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Insert one document.

        Args:
            doc_type: document type (docx/xlsx/md/txt).
            content: raw text content.
            metadata: arbitrary metadata dict.
            structured_data: optional structured payload (tables etc.).

        Returns:
            The inserted document's id as a string.
        """
        # Timestamps are naive UTC (datetime.utcnow) — consistent across
        # this class, but callers must not mix in aware datetimes.
        document = {
            "doc_type": doc_type,
            "content": content,
            "metadata": metadata,
            "structured_data": structured_data,
            "created_at": datetime.utcnow(),
            "updated_at": datetime.utcnow(),
        }
        result = await self.documents.insert_one(document)
        logger.info(f"文档已插入MongoDB: {result.inserted_id}")
        return str(result.inserted_id)

    async def get_document(self, doc_id: str) -> Optional[Dict[str, Any]]:
        """Fetch one document by id; returns None when not found.

        The returned dict's "_id" is converted to str for JSON friendliness.
        """
        from bson import ObjectId
        doc = await self.documents.find_one({"_id": ObjectId(doc_id)})
        if doc:
            doc["_id"] = str(doc["_id"])
        return doc

    async def search_documents(
        self,
        query: str,
        doc_type: Optional[str] = None,
        limit: int = 10,
    ) -> List[Dict[str, Any]]:
        """Search document content by regex substring match.

        Args:
            query: pattern matched against "content" (treated as a regex).
            doc_type: optional document-type filter.
            limit: maximum number of results.

        Returns:
            Matching documents with stringified "_id".
        """
        filter_query = {"content": {"$regex": query}}
        if doc_type:
            filter_query["doc_type"] = doc_type
        cursor = self.documents.find(filter_query).limit(limit)
        documents = []
        async for doc in cursor:
            doc["_id"] = str(doc["_id"])
            documents.append(doc)
        return documents

    async def delete_document(self, doc_id: str) -> bool:
        """Delete one document by id; True when something was removed."""
        from bson import ObjectId
        result = await self.documents.delete_one({"_id": ObjectId(doc_id)})
        return result.deleted_count > 0

    # ==================== RAG index operations ====================
    async def insert_rag_entry(
        self,
        table_name: str,
        field_name: str,
        field_description: str,
        embedding: List[float],
        metadata: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Insert one RAG index entry.

        Args:
            table_name: logical table name.
            field_name: field name.
            field_description: field semantics description.
            embedding: embedding vector.
            metadata: extra metadata.

        Returns:
            The inserted entry's id as a string.
        """
        entry = {
            "table_name": table_name,
            "field_name": field_name,
            "field_description": field_description,
            "embedding": embedding,
            "metadata": metadata or {},
            "created_at": datetime.utcnow(),
        }
        result = await self.rag_index.insert_one(entry)
        return str(result.inserted_id)

    async def search_rag(
        self,
        query_embedding: List[float],
        top_k: int = 5,
        table_name: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Nearest-neighbour search over RAG entries.

        Computes squared Euclidean distance between each stored embedding
        and the query inside an aggregation pipeline, then returns the
        top_k closest entries. This is an exhaustive O(n) scan — fine for
        small collections, not a substitute for a real vector index.

        Args:
            query_embedding: query vector (must match stored dimension).
            top_k: number of entries to return.
            table_name: optional table-name pre-filter.

        Returns:
            The closest entries, each with a "distance" field added.
        """
        pipeline = [
            {
                "$addFields": {
                    # Sum of squared per-component differences.
                    "distance": {
                        "$reduce": {
                            "input": {"$range": [0, {"$size": "$embedding"}]},
                            "initialValue": 0,
                            "in": {
                                "$add": [
                                    "$$value",
                                    {
                                        "$pow": [
                                            {
                                                "$subtract": [
                                                    {"$arrayElemAt": ["$embedding", "$$this"]},
                                                    {"$arrayElemAt": [query_embedding, "$$this"]},
                                                ]
                                            },
                                            2,
                                        ]
                                    },
                                ]
                            },
                        }
                    }
                }
            },
            {"$sort": {"distance": 1}},
            {"$limit": top_k},
        ]
        if table_name:
            # Filter first so the distance computation only touches one table.
            pipeline.insert(0, {"$match": {"table_name": table_name}})
        results = []
        async for doc in self.rag_index.aggregate(pipeline):
            doc["_id"] = str(doc["_id"])
            results.append(doc)
        return results

    # ==================== collection management ====================
    async def create_indexes(self):
        """Create query-optimization indexes (idempotent)."""
        # Document collection indexes.
        await self.documents.create_index("doc_type")
        await self.documents.create_index("created_at")
        await self.documents.create_index([("content", "text")])
        # RAG index collection indexes.
        await self.rag_index.create_index("table_name")
        await self.rag_index.create_index("field_name")
        # FIX: the previous 3-tuple spec ("embedding", "hnsw",
        # {"type": "knnVector"}) is not a valid pymongo create_index() key
        # specification and raised at runtime. Real vector indexes require
        # Atlas Search (create_search_index); search_rag() already computes
        # distances in the aggregation pipeline, so no index is created on
        # the raw embedding array here.
        logger.info("MongoDB 索引创建完成")


# ==================== global singleton ====================
mongodb = MongoDB()

View File

@@ -0,0 +1,193 @@
"""
MySQL 数据库连接管理模块
提供结构化数据的存储和查询功能
"""
import logging
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, List, Optional

from sqlalchemy import (
    Column,
    DateTime,
    Enum as SQLEnum,
    Float,
    Integer,
    String,
    Text,
    create_engine,
    text,
)
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, sessionmaker
from sqlalchemy.sql import select

from app.config import settings
logger = logging.getLogger(__name__)
class Base(DeclarativeBase):
"""SQLAlchemy 声明基类"""
pass
class MySQLDB:
    """MySQL manager exposing an async engine (FastAPI) and a sync engine (Celery)."""

    def __init__(self):
        # Async engine for FastAPI request handlers.
        # FIX: create_async_engine() requires an asyncio driver; the
        # configured "mysql+pymysql" URL is synchronous and makes SQLAlchemy
        # raise at engine creation, so it is rewritten to the aiomysql
        # dialect here. TODO confirm aiomysql (vs asyncmy) is the intended
        # async driver and that it is installed.
        async_url = settings.mysql_url.replace("mysql+pymysql", "mysql+aiomysql")
        self.async_engine = create_async_engine(
            async_url,
            echo=settings.DEBUG,      # SQL logging in debug mode
            pool_pre_ping=True,       # validate connections before use
            pool_size=10,
            max_overflow=20,
        )
        # Async session factory.
        self.async_session_factory = async_sessionmaker(
            bind=self.async_engine,
            class_=AsyncSession,
            expire_on_commit=False,
            autocommit=False,
            autoflush=False,
        )
        # Sync engine for Celery tasks.
        # FIX: the previous code stripped the URL to the bare "mysql"
        # dialect, which selects the MySQLdb driver instead of the
        # configured pymysql one; use the URL as configured.
        self.sync_engine = create_engine(
            settings.mysql_url,
            echo=settings.DEBUG,
            pool_pre_ping=True,
            pool_size=5,
            max_overflow=10,
        )
        # Sync session factory.
        self.sync_session_factory = sessionmaker(
            bind=self.sync_engine,
            autocommit=False,
            autoflush=False,
        )

    async def init_db(self):
        """Create all tables declared on Base (idempotent)."""
        try:
            async with self.async_engine.begin() as conn:
                await conn.run_sync(Base.metadata.create_all)
            logger.info("MySQL 数据库表初始化完成")
        except Exception as e:
            logger.error(f"MySQL 数据库初始化失败: {e}")
            raise

    async def close(self):
        """Dispose both engines and their connection pools."""
        await self.async_engine.dispose()
        self.sync_engine.dispose()
        logger.info("MySQL 数据库连接已关闭")

    @asynccontextmanager
    async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield an async session: commit on success, rollback on error."""
        session = self.async_session_factory()
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()

    async def execute_query(
        self,
        query: str,
        params: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """Run a raw SQL query and return rows as plain dicts.

        FIX: the original wrapped the statement in select(text(query)),
        which produces "SELECT (<sql>)" — invalid SQL; text() must be
        executed directly.

        Args:
            query: SQL text (use :name placeholders for parameters).
            params: bound parameters.

        Returns:
            One dict per result row.
        """
        async with self.get_session() as session:
            result = await session.execute(text(query), params or {})
            rows = result.fetchall()
            return [dict(row._mapping) for row in rows]

    async def execute_raw_sql(
        self,
        sql: str,
        params: Optional[Dict[str, Any]] = None
    ) -> Any:
        """Run a raw INSERT/UPDATE/DELETE statement.

        Args:
            sql: SQL text (use :name placeholders for parameters).
            params: bound parameters.

        Returns:
            lastrowid for inserts (when the driver reports one),
            otherwise the affected row count.
        """
        async with self.get_session() as session:
            result = await session.execute(text(sql), params or {})
            await session.commit()
            # lastrowid is 0/None for non-insert statements with pymysql;
            # fall back to rowcount in that case.
            return result.lastrowid if result.lastrowid else result.rowcount
# ==================== 预定义的数据模型 ====================
class DocumentTable(Base):
    """Document metadata table: one row per parsed tabular document."""
    __tablename__ = "document_tables"

    id = Column(Integer, primary_key=True, autoincrement=True)
    table_name = Column(String(255), unique=True, nullable=False, comment="表名")
    display_name = Column(String(255), comment="显示名称")
    description = Column(Text, comment="表描述")
    source_file = Column(String(512), comment="来源文件")
    column_count = Column(Integer, default=0, comment="列数")
    row_count = Column(Integer, default=0, comment="行数")
    file_size = Column(Integer, comment="文件大小(字节)")
    # NOTE(review): created_at/updated_at have no default/onupdate —
    # callers must set them explicitly; confirm this is intended.
    created_at = Column(DateTime, comment="创建时间")
    updated_at = Column(DateTime, comment="更新时间")
class DocumentField(Base):
    """Field table: one row per column of a stored document table."""
    __tablename__ = "document_fields"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # NOTE(review): plain Integer, no ForeignKey to document_tables.id —
    # referential integrity is not enforced by the database; confirm.
    table_id = Column(Integer, nullable=False, comment="所属表ID")
    field_name = Column(String(255), nullable=False, comment="字段名")
    field_type = Column(String(50), comment="字段类型")
    field_description = Column(Text, comment="字段描述/语义")
    is_key_field = Column(Integer, default=0, comment="是否主键")
    is_nullable = Column(Integer, default=1, comment="是否可空")
    sample_values = Column(Text, comment="示例值(逗号分隔)")
    created_at = Column(DateTime, comment="创建时间")
class TaskRecord(Base):
    """Async-task record: persisted status and result of each task."""
    __tablename__ = "task_records"

    id = Column(Integer, primary_key=True, autoincrement=True)
    task_id = Column(String(255), unique=True, nullable=False, comment="Celery任务ID")
    task_type = Column(String(50), comment="任务类型")
    status = Column(String(50), default="pending", comment="任务状态")
    input_params = Column(Text, comment="输入参数JSON")
    result_data = Column(Text, comment="结果数据JSON")
    error_message = Column(Text, comment="错误信息")
    started_at = Column(DateTime, comment="开始时间")
    completed_at = Column(DateTime, comment="完成时间")
    created_at = Column(DateTime, comment="创建时间")


# ==================== global singleton ====================
mysql_db = MySQLDB()

View File

@@ -0,0 +1,287 @@
"""
Redis 数据库连接管理模块
提供缓存和任务队列功能
"""
import json
import logging
from datetime import timedelta
from typing import Any, Dict, Optional

import redis.asyncio as redis

from app.config import settings
logger = logging.getLogger(__name__)
class RedisDB:
"""Redis 数据库管理类"""
def __init__(self):
self.client: Optional[redis.Redis] = None
self._connected = False
async def connect(self):
"""建立 Redis 连接"""
try:
self.client = redis.from_url(
settings.REDIS_URL,
encoding="utf-8",
decode_responses=True,
)
# 验证连接
await self.client.ping()
self._connected = True
logger.info(f"Redis 连接成功: {settings.REDIS_URL}")
except Exception as e:
logger.error(f"Redis 连接失败: {e}")
raise
async def close(self):
"""关闭 Redis 连接"""
if self.client:
await self.client.close()
self._connected = False
logger.info("Redis 连接已关闭")
@property
def is_connected(self) -> bool:
"""检查连接状态"""
return self._connected
# ==================== 基础操作 ====================
async def get(self, key: str) -> Optional[str]:
"""获取值"""
return await self.client.get(key)
async def set(
self,
key: str,
value: str,
expire: Optional[int] = None,
) -> bool:
"""
设置值
Args:
key: 键
value: 值
expire: 过期时间(秒)
Returns:
是否成功
"""
return await self.client.set(key, value, ex=expire)
async def delete(self, key: str) -> int:
"""删除键"""
return await self.client.delete(key)
async def exists(self, key: str) -> bool:
"""检查键是否存在"""
return await self.client.exists(key) > 0
# ==================== JSON 操作 ====================
async def set_json(
self,
key: str,
data: Dict[str, Any],
expire: Optional[int] = None,
) -> bool:
"""
设置 JSON 数据
Args:
key: 键
data: 数据字典
expire: 过期时间(秒)
Returns:
是否成功
"""
json_str = json.dumps(data, ensure_ascii=False, default=str)
return await self.set(key, json_str, expire)
async def get_json(self, key: str) -> Optional[Dict[str, Any]]:
"""
获取 JSON 数据
Args:
key: 键
Returns:
数据字典,不存在返回 None
"""
value = await self.get(key)
if value:
try:
return json.loads(value)
except json.JSONDecodeError:
return None
return None
# ==================== 任务状态管理 ====================
async def set_task_status(
self,
task_id: str,
status: str,
meta: Optional[Dict[str, Any]] = None,
expire: int = 86400, # 默认24小时过期
) -> bool:
"""
设置任务状态
Args:
task_id: 任务ID
status: 状态 (pending/processing/success/failure)
meta: 附加信息
expire: 过期时间(秒)
Returns:
是否成功
"""
key = f"task:{task_id}"
data = {
"status": status,
"meta": meta or {},
}
return await self.set_json(key, data, expire)
async def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
"""
获取任务状态
Args:
task_id: 任务ID
Returns:
状态信息
"""
key = f"task:{task_id}"
return await self.get_json(key)
async def update_task_progress(
self,
task_id: str,
progress: int,
message: Optional[str] = None,
) -> bool:
"""
更新任务进度
Args:
task_id: 任务ID
progress: 进度值 (0-100)
message: 进度消息
Returns:
是否成功
"""
data = await self.get_task_status(task_id)
if data:
data["meta"]["progress"] = progress
if message:
data["meta"]["message"] = message
key = f"task:{task_id}"
return await self.set_json(key, data, expire=86400)
return False
# ==================== 缓存操作 ====================
async def cache_document(
self,
doc_id: str,
data: Dict[str, Any],
expire: int = 3600, # 默认1小时
) -> bool:
"""
缓存文档数据
Args:
doc_id: 文档ID
data: 文档数据
expire: 过期时间(秒)
Returns:
是否成功
"""
key = f"doc:{doc_id}"
return await self.set_json(key, data, expire)
async def get_cached_document(self, doc_id: str) -> Optional[Dict[str, Any]]:
"""
获取缓存的文档
Args:
doc_id: 文档ID
Returns:
文档数据
"""
key = f"doc:{doc_id}"
return await self.get_json(key)
# ==================== 分布式锁 ====================
async def acquire_lock(
    self,
    lock_name: str,
    expire: int = 30,
) -> bool:
    """Try to acquire a distributed lock.

    Args:
        lock_name: Lock name (stored under ``lock:<name>``).
        expire: TTL in seconds, after which the lock auto-releases.

    Returns:
        True if this caller won the lock.
    """
    # SET with nx=True/ex=... is a single atomic command: it only
    # succeeds when the key does not exist, so at most one caller wins.
    acquired = await self.client.set(f"lock:{lock_name}", "1", nx=True, ex=expire)
    return acquired is not None
async def release_lock(self, lock_name: str) -> bool:
    """Release a distributed lock.

    Args:
        lock_name: Lock name.

    Returns:
        True if the lock key existed and was deleted.
    """
    # DEL returns the number of keys removed; 0 means the lock had
    # already expired or was never held.
    removed = await self.client.delete(f"lock:{lock_name}")
    return removed > 0
# ==================== 计数器 ====================
async def incr(self, key: str, amount: int = 1) -> int:
    """Atomically increment the counter at *key* by *amount*; returns the new value."""
    return await self.client.incrby(key, amount)
async def decr(self, key: str, amount: int = 1) -> int:
    """Atomically decrement the counter at *key* by *amount*; returns the new value."""
    return await self.client.decrby(key, amount)
# ==================== 过期时间管理 ====================
async def expire(self, key: str, seconds: int) -> bool:
    """Set a TTL on *key*; returns True if the key exists and the TTL was set."""
    return await self.client.expire(key, seconds)
async def ttl(self, key: str) -> int:
    """Return the remaining TTL of *key* in seconds (negative per Redis TTL semantics when absent/persistent)."""
    return await self.client.ttl(key)
# ==================== Global singleton ====================
# Shared RedisDB instance used across the application.
redis_db = RedisDB()

View File

@@ -1,7 +1,48 @@
"""
文档解析模块 - 支持多种文件格式的解析
"""
from .base import BaseParser
from pathlib import Path
from typing import Dict, Optional
from .base import BaseParser, ParseResult
from .xlsx_parser import XlsxParser
__all__ = ['BaseParser', 'XlsxParser']
# 导入其他解析器 (需要先实现)
# from .docx_parser import DocxParser
# from .md_parser import MarkdownParser
# from .txt_parser import TxtParser
class ParserFactory:
    """Factory that maps a file extension to its parser instance."""

    # Shared parser instances, keyed by lower-case extension.
    _parsers: Dict[str, BaseParser] = {
        '.xlsx': XlsxParser(),
        '.xls': XlsxParser(),
        # '.docx': DocxParser(),  # TODO: not implemented yet
        # '.md': MarkdownParser(),  # TODO: not implemented yet
        # '.txt': TxtParser(),  # TODO: not implemented yet
    }

    @classmethod
    def get_parser(cls, file_path: str) -> BaseParser:
        """Return the parser registered for *file_path*'s extension."""
        suffix = Path(file_path).suffix.lower()
        parser = cls._parsers.get(suffix)
        if parser is None:
            raise ValueError(f"不支持的文件格式: {suffix},支持的格式: {list(cls._parsers.keys())}")
        return parser

    @classmethod
    def parse(cls, file_path: str, **kwargs) -> ParseResult:
        """Look up the parser for *file_path* and run it in one call."""
        return cls.get_parser(file_path).parse(file_path, **kwargs)

    @classmethod
    def register_parser(cls, ext: str, parser: BaseParser):
        """Register (or replace) the parser for extension *ext*."""
        cls._parsers[ext.lower()] = parser
__all__ = ['BaseParser', 'ParseResult', 'XlsxParser', 'ParserFactory']

View File

@@ -1,10 +1,67 @@
"""
FastAPI 应用主入口
"""
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.config import settings
from app.api import api_router
from app.core.database import mysql_db, mongodb, redis_db
# Logging configuration: verbose (INFO) in DEBUG mode, warnings-only otherwise.
logging.basicConfig(
    level=logging.INFO if settings.DEBUG else logging.WARNING,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager.

    On startup: initialise the MySQL, MongoDB and Redis connections.
    On shutdown: close them again.
    """
    logger.info("正在初始化数据库连接...")

    async def _init_mysql():
        await mysql_db.init_db()

    async def _init_mongodb():
        await mongodb.connect()
        await mongodb.create_indexes()

    async def _init_redis():
        await redis_db.connect()

    # Each backend is initialised independently: a failure is logged
    # but does not prevent the application from starting.
    for label, initializer in (
        ("MySQL", _init_mysql),
        ("MongoDB", _init_mongodb),
        ("Redis", _init_redis),
    ):
        try:
            await initializer()
            logger.info(f"✓ {label} 初始化成功")
        except Exception as e:
            logger.error(f"✗ {label} 初始化失败: {e}")

    logger.info("数据库初始化完成")

    yield

    # Shutdown: release every connection.
    logger.info("正在关闭数据库连接...")
    await mysql_db.close()
    await mongodb.close()
    await redis_db.close()
    logger.info("数据库连接已关闭")
# 创建 FastAPI 应用实例
app = FastAPI(
@@ -13,7 +70,8 @@ app = FastAPI(
version="1.0.0",
openapi_url=f"{settings.API_V1_STR}/openapi.json",
docs_url=f"{settings.API_V1_STR}/docs",
redoc_url=f"{settings.API_V1_STR}/redoc"
redoc_url=f"{settings.API_V1_STR}/redoc",
lifespan=lifespan, # 添加生命周期管理
)
# 配置 CORS 中间件
@@ -43,10 +101,24 @@ async def root():
@app.get("/health")
async def health_check():
"""健康检查接口"""
"""
健康检查接口
返回各数据库连接状态
"""
# 检查各数据库连接状态
mysql_status = "connected" if mysql_db.async_engine else "disconnected"
mongodb_status = "connected" if mongodb.client else "disconnected"
redis_status = "connected" if redis_db.is_connected else "disconnected"
return {
"status": "healthy",
"service": settings.APP_NAME
"service": settings.APP_NAME,
"databases": {
"mysql": mysql_status,
"mongodb": mongodb_status,
"redis": redis_status,
}
}

View File

@@ -0,0 +1,18 @@
"""
数据模型模块
定义数据库表结构和数据模型
"""
from app.core.database.mysql import (
Base,
DocumentField,
DocumentTable,
TaskRecord,
)
__all__ = [
"Base",
"DocumentTable",
"DocumentField",
"TaskRecord",
]

View File

@@ -0,0 +1,172 @@
"""
文档数据模型
定义文档相关的 Pydantic 模型
"""
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
class DocumentType(str, Enum):
    """Supported document types."""
    DOCX = "docx"
    XLSX = "xlsx"
    MD = "md"
    TXT = "txt"
class TaskStatus(str, Enum):
    """Lifecycle states of an asynchronous task."""
    PENDING = "pending"
    PROCESSING = "processing"
    SUCCESS = "success"
    FAILURE = "failure"
# ==================== 解析结果模型 ====================
class DocumentMetadata(BaseModel):
    """Metadata attached to a parsed document.

    The optional fields appear format-specific (sheet_*/row/column for
    spreadsheets, `encoding` for text files) — inferred from names;
    confirm against the parser implementations.
    """
    filename: str                              # original file name
    extension: str                             # file extension, e.g. ".xlsx"
    file_size: int = 0                         # size in bytes
    doc_type: Optional[str] = None             # document type label
    sheet_count: Optional[int] = None          # spreadsheets: number of sheets
    sheet_names: Optional[List[str]] = None    # spreadsheets: all sheet names
    current_sheet: Optional[str] = None        # spreadsheets: sheet that was parsed
    row_count: Optional[int] = None            # tabular data: number of rows
    column_count: Optional[int] = None         # tabular data: number of columns
    columns: Optional[List[str]] = None        # tabular data: column names
    encoding: Optional[str] = None             # text files: detected encoding
class ParseResultData(BaseModel):
    """Tabular payload extracted from a document."""
    columns: List[str] = Field(default_factory=list)          # column names
    rows: List[Dict[str, Any]] = Field(default_factory=list)  # one dict per row
    row_count: int = 0
    column_count: int = 0
class ParseResult(BaseModel):
    """Outcome of parsing a document: payload on success, message on failure."""
    success: bool
    data: Optional[ParseResultData] = None
    metadata: Optional[DocumentMetadata] = None
    error: Optional[str] = None
# ==================== 存储模型 ====================
class DocumentStore(BaseModel):
    """Persisted document record (content plus metadata)."""
    doc_id: str
    doc_type: DocumentType
    content: str                                      # extracted document text
    metadata: DocumentMetadata
    structured_data: Optional[Dict[str, Any]] = None  # optional structured payload
    # NOTE(review): datetime.utcnow is deprecated since Python 3.12 and
    # returns a naive datetime; consider datetime.now(timezone.utc).
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
class RAGEntry(BaseModel):
    """A single RAG index entry describing one table field."""
    table_name: str
    field_name: str
    field_description: str
    embedding: List[float]          # embedding vector for the description
    metadata: Optional[Dict[str, Any]] = None
# ==================== 任务模型 ====================
class TaskCreate(BaseModel):
    """Request body for creating an asynchronous task."""
    task_type: str
    input_params: Dict[str, Any]
class TaskStatusResponse(BaseModel):
    """API response describing the current state of a task."""
    task_id: str
    status: TaskStatus
    progress: int = 0               # 0-100
    message: Optional[str] = None
    result: Optional[Any] = None    # presumably set on success — confirm
    error: Optional[str] = None     # presumably set on failure — confirm
# ==================== 模板填写模型 ====================
class TemplateField(BaseModel):
    """A fillable field inside a template (cell position, name, type)."""
    cell: str = Field(description="单元格位置, 如 A1")
    name: str = Field(description="字段名称")
    field_type: str = Field(default="text", description="字段类型: text/number/date")
    required: bool = Field(default=True, description="是否必填")
class TemplateSheet(BaseModel):
    """One worksheet of a template and its fillable fields."""
    name: str
    fields: List[TemplateField]
class TemplateInfo(BaseModel):
    """Description of a template file and its sheets."""
    file_path: str
    file_type: str  # xlsx/docx
    sheets: List[TemplateSheet]
class FillRequest(BaseModel):
    """Request to fill a template from source documents."""
    template_path: str
    template_fields: List[TemplateField]
    source_doc_ids: Optional[List[str]] = None  # None presumably means all sources — confirm
class FillResult(BaseModel):
    """Result of a template-fill operation."""
    success: bool
    filled_data: Dict[str, Any]               # values written per field
    fill_details: List[Dict[str, Any]]        # per-field fill records
    source_documents: List[str] = Field(default_factory=list)
# ==================== API 响应模型 ====================
class UploadResponse(BaseModel):
    """Response returned after a file upload is accepted."""
    task_id: str
    file_count: int
    message: str
    status_url: str  # URL to poll for task status
class AnalyzeResponse(BaseModel):
    """Response of an LLM analysis request."""
    success: bool
    analysis: Optional[str] = None
    structured_data: Optional[Dict[str, Any]] = None
    model: Optional[str] = None  # name of the model that produced the analysis
    error: Optional[str] = None
class QueryRequest(BaseModel):
    """Natural-language query request."""
    user_intent: str
    table_name: Optional[str] = None
    top_k: int = Field(default=5, ge=1, le=20)  # number of RAG hits to retrieve
class QueryResponse(BaseModel):
    """Response of a natural-language query."""
    success: bool
    sql_query: Optional[str] = None                  # generated SQL, if any
    results: Optional[List[Dict[str, Any]]] = None   # query result rows
    rag_context: Optional[List[str]] = None          # retrieved context snippets
    error: Optional[str] = None

View File

@@ -0,0 +1,233 @@
"""
RAG 服务模块 - 检索增强生成
使用 LangChain + Faiss 实现向量检索
"""
import logging
import os
from typing import Any, Dict, List, Optional
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document as LangchainDocument
from langchain.vectorstores import FAISS
from app.config import settings
logger = logging.getLogger(__name__)
class RAGService:
    """Retrieval-augmented generation service.

    Wraps a HuggingFace embedding model and a FAISS vector store to
    index table-field descriptions / document contents and retrieve
    them by similarity. All components are created lazily on first use.
    """

    def __init__(self):
        # Embedding model and vector store are created lazily.
        self.embeddings: Optional[HuggingFaceEmbeddings] = None
        self.vector_store: Optional[FAISS] = None
        self._initialized = False

    def _init_embeddings(self):
        """Create the embedding model on first use (CPU only)."""
        if self.embeddings is None:
            self.embeddings = HuggingFaceEmbeddings(
                model_name=settings.EMBEDDING_MODEL,
                model_kwargs={'device': 'cpu'}
            )
            logger.info(f"RAG 嵌入模型初始化完成: {settings.EMBEDDING_MODEL}")

    def _init_vector_store(self):
        """Create an empty FAISS vector store on first use."""
        if self.vector_store is None:
            self._init_embeddings()
            # NOTE(review): langchain's FAISS wrapper normally expects a
            # concrete faiss.Index and an InMemoryDocstore; index=None plus a
            # plain dict may fail on the first add_documents call — verify
            # against the pinned langchain version.
            self.vector_store = FAISS(
                embedding_function=self.embeddings,
                index=None,  # start with an empty index
                docstore={},
                index_to_docstore_id={}
            )
            logger.info("Faiss 向量存储初始化完成")

    async def initialize(self):
        """Asynchronous initialisation entry point.

        Raises:
            Exception: re-raised if the vector store cannot be created.
        """
        try:
            self._init_vector_store()
            self._initialized = True
            logger.info("RAG 服务初始化成功")
        except Exception as e:
            logger.error(f"RAG 服务初始化失败: {e}")
            raise

    def index_field(
        self,
        table_name: str,
        field_name: str,
        field_description: str,
        sample_values: Optional[List[str]] = None
    ):
        """Index a table field's semantic description into the vector store.

        Args:
            table_name: Table the field belongs to.
            field_name: Field (column) name.
            field_description: Natural-language description of the field.
            sample_values: Optional example values appended to the text.
        """
        if not self._initialized:
            self._init_vector_store()
        # Compose a single retrievable text from table, field and description.
        text = f"表名: {table_name}, 字段: {field_name}, 描述: {field_description}"
        if sample_values:
            text += f", 示例值: {', '.join(sample_values)}"
        # "table.field" doubles as the stable vector id.
        doc_id = f"{table_name}.{field_name}"
        doc = LangchainDocument(
            page_content=text,
            metadata={
                "table_name": table_name,
                "field_name": field_name,
                "doc_id": doc_id
            }
        )
        if self.vector_store is None:
            self._init_vector_store()
        self.vector_store.add_documents([doc], ids=[doc_id])
        logger.debug(f"已索引字段: {doc_id}")

    def index_document_content(
        self,
        doc_id: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Index raw document content into the vector store.

        Args:
            doc_id: Document id (also used as the vector id).
            content: Document text.
            metadata: Optional metadata; defaults to {"doc_id": doc_id}.
        """
        if not self._initialized:
            self._init_vector_store()
        doc = LangchainDocument(
            page_content=content,
            metadata=metadata or {"doc_id": doc_id}
        )
        if self.vector_store is None:
            self._init_vector_store()
        self.vector_store.add_documents([doc], ids=[doc_id])
        logger.debug(f"已索引文档: {doc_id}")

    def retrieve(
        self,
        query: str,
        top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """Similarity-search the store for *query*.

        Args:
            query: User query text.
            top_k: Maximum number of hits to return.

        Returns:
            Dicts with content, metadata, score and doc_id; a lower
            score means more similar (distance metric).
        """
        if not self._initialized:
            self._init_vector_store()
        if self.vector_store is None:
            return []
        docs_and_scores = self.vector_store.similarity_search_with_score(
            query,
            k=top_k
        )
        results = []
        for doc, score in docs_and_scores:
            results.append({
                "content": doc.page_content,
                "metadata": doc.metadata,
                "score": float(score),  # distance: smaller is more similar
                "doc_id": doc.metadata.get("doc_id", "")
            })
        logger.debug(f"检索到 {len(results)} 条相关文档")
        return results

    def retrieve_by_table(self, table_name: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Retrieve entries indexed for a given table name.

        Args:
            table_name: Table name to search for.
            top_k: Maximum number of hits to return.
        """
        return self.retrieve(f"表名: {table_name}", top_k)

    def get_vector_count(self) -> int:
        """Return the number of stored vectors (0 when uninitialised)."""
        if self.vector_store is None:
            return 0
        # NOTE(review): reaches into the private docstore._dict; assumes an
        # InMemoryDocstore-like object, not the plain dict created above.
        return len(self.vector_store.docstore._dict)

    def save_index(self, persist_path: str):
        """Persist the FAISS index to disk (no-op when uninitialised).

        Args:
            persist_path: Target directory/path.
        """
        if self.vector_store is not None:
            self.vector_store.save_local(persist_path)
            logger.info(f"向量索引已保存到: {persist_path}")

    def load_index(self, persist_path: str):
        """Load a FAISS index from disk; warns and returns if the path is missing.

        Args:
            persist_path: Source directory/path.
        """
        if not os.path.exists(persist_path):
            logger.warning(f"向量索引文件不存在: {persist_path}")
            return
        self._init_embeddings()
        # allow_dangerous_deserialization: the persisted index is
        # pickle-based, so only ever load it from a trusted local path.
        self.vector_store = FAISS.load_local(
            persist_path,
            self.embeddings,
            allow_dangerous_deserialization=True
        )
        self._initialized = True
        logger.info(f"向量索引已从 {persist_path} 加载")

    def delete_by_doc_id(self, doc_id: str):
        """Delete a single vector by its document id."""
        if self.vector_store is not None:
            self.vector_store.delete(ids=[doc_id])
            logger.debug(f"已删除索引: {doc_id}")

    def clear(self):
        """Delete every vector currently in the store."""
        self._init_vector_store()
        if self.vector_store is not None:
            self.vector_store.delete(ids=list(self.vector_store.docstore._dict.keys()))
            logger.info("已清空所有向量索引")
# ==================== Global singleton ====================
# Shared RAGService instance used by the API layer.
rag_service = RAGService()

View File

@@ -1,24 +1,55 @@
# ============================================================
# 基于大语言模型的文档理解与多源数据融合系统
# Python 依赖清单
# ============================================================
# ==================== Web 框架 ====================
fastapi[all]==0.104.1
uvicorn[standard]==0.24.0
python-multipart==0.0.6
# ==================== 数据验证与配置 ====================
pydantic==2.5.0
pydantic-settings==2.1.0
python-multipart==0.0.6
python-dotenv==1.0.0
# ==================== 数据库 - MySQL (结构化数据) ====================
pymysql==1.1.0
sqlalchemy==2.0.25
# ==================== 数据库 - MongoDB (非结构化数据) ====================
motor==3.3.2
pymongo==4.5.0
# ==================== 数据库 - Redis (缓存/队列) ====================
redis==5.0.0
# ==================== 异步任务 ====================
celery==5.3.4
# ==================== RAG / 向量数据库 ====================
# chromadb==0.4.22 # Windows 需要 C++ 编译环境,如需安装请使用预编译版本或 WSL
sentence-transformers==2.2.2
faiss-cpu==1.8.0
python-docx==0.8.11
# ==================== 文档解析 ====================
pandas==2.1.4
openpyxl==3.1.2
matplotlib==3.8.2
numpy==1.26.2
markdown==3.5.1
python-docx==0.8.11
markdown-it-py==3.0.0
chardet==5.2.0
# ==================== AI / LLM ====================
langchain==0.1.0
langchain-community==0.0.10
requests==2.31.0
httpx==0.25.2
python-dotenv==1.0.0
# ==================== 数据处理与可视化 ====================
matplotlib==3.8.2
numpy==1.26.2
# ==================== 工具库 ====================
requests==2.31.0
loguru==0.7.2
tqdm==4.66.1
PyYAML==6.0.1

File diff suppressed because it is too large Load Diff