diff --git a/backend/app/api/endpoints/ai_analyze.py b/backend/app/api/endpoints/ai_analyze.py index 7cbc83d..a8f49e1 100644 --- a/backend/app/api/endpoints/ai_analyze.py +++ b/backend/app/api/endpoints/ai_analyze.py @@ -216,9 +216,12 @@ async def analyze_markdown( return result finally: - # 清理临时文件 - if os.path.exists(tmp_path): - os.unlink(tmp_path) + # 清理临时文件,确保在所有情况下都能清理 + try: + if tmp_path and os.path.exists(tmp_path): + os.unlink(tmp_path) + except Exception as cleanup_error: + logger.warning(f"临时文件清理失败: {tmp_path}, error: {cleanup_error}") except HTTPException: raise @@ -280,8 +283,12 @@ async def analyze_markdown_stream( ) finally: - if os.path.exists(tmp_path): - os.unlink(tmp_path) + # 清理临时文件,确保在所有情况下都能清理 + try: + if tmp_path and os.path.exists(tmp_path): + os.unlink(tmp_path) + except Exception as cleanup_error: + logger.warning(f"临时文件清理失败: {tmp_path}, error: {cleanup_error}") except HTTPException: raise @@ -290,7 +297,7 @@ async def analyze_markdown_stream( raise HTTPException(status_code=500, detail=f"流式分析失败: {str(e)}") -@router.get("/analyze/md/outline") +@router.post("/analyze/md/outline") async def get_markdown_outline( file: UploadFile = File(...) ): @@ -324,8 +331,12 @@ async def get_markdown_outline( result = await markdown_ai_service.extract_outline(tmp_path) return result finally: - if os.path.exists(tmp_path): - os.unlink(tmp_path) + # 清理临时文件,确保在所有情况下都能清理 + try: + if tmp_path and os.path.exists(tmp_path): + os.unlink(tmp_path) + except Exception as cleanup_error: + logger.warning(f"临时文件清理失败: {tmp_path}, error: {cleanup_error}") except Exception as e: logger.error(f"获取 Markdown 大纲失败: {str(e)}") diff --git a/backend/app/api/endpoints/documents.py b/backend/app/api/endpoints/documents.py index 82d8551..3c682c2 100644 --- a/backend/app/api/endpoints/documents.py +++ b/backend/app/api/endpoints/documents.py @@ -23,6 +23,52 @@ logger = logging.getLogger(__name__) router = APIRouter(prefix="/upload", tags=["文档上传"]) +# ==================== 辅助函数 ==================== + +async def update_task_status( + task_id: str, + status: str, + progress: int = 0, + message: str = "", + result: dict = None, + error: str = None +): + """ + 更新任务状态,同时写入 Redis 和 MongoDB + + Args: + task_id: 任务ID + status: 状态 + progress: 进度 + message: 消息 + result: 结果 + error: 错误信息 + """ + meta = {"progress": progress, "message": message} + if result: + meta["result"] = result + if error: + meta["error"] = error + + # 尝试写入 Redis + try: + await redis_db.set_task_status(task_id, status, meta) + except Exception as e: + logger.warning(f"Redis 任务状态更新失败: {e}") + + # 尝试写入 MongoDB(作为备用) + try: + await mongodb.update_task( + task_id=task_id, + status=status, + message=message, + result=result, + error=error + ) + except Exception as e: + logger.warning(f"MongoDB 任务状态更新失败: {e}") + + # ==================== 请求/响应模型 ==================== class UploadResponse(BaseModel): @@ -77,6 +123,17 @@ async def upload_document( task_id = str(uuid.uuid4()) try: + # 保存任务记录到 MongoDB(如果 Redis 不可用时仍能查询) + try: + await mongodb.insert_task( + task_id=task_id, + task_type="document_parse", + status="pending", + message=f"文档 {file.filename} 已提交处理" + ) + except Exception as mongo_err: + logger.warning(f"MongoDB 保存任务记录失败: {mongo_err}") + content = await file.read() saved_path = file_service.save_uploaded_file( content, @@ -122,6 +179,17 @@ async def upload_documents( saved_paths = [] try: + # 保存任务记录到 MongoDB + try: + await mongodb.insert_task( + task_id=task_id, + task_type="batch_parse", + status="pending", + message=f"已提交 {len(files)} 个文档处理" + ) + except Exception as mongo_err: + logger.warning(f"MongoDB 保存批量任务记录失败: {mongo_err}") + for file in files: if not file.filename: continue @@ -159,9 +227,9 @@ async def process_document( """处理单个文档""" try: # 状态: 解析中 - await redis_db.set_task_status( + await update_task_status( task_id, status="processing", - meta={"progress": 10, "message": "正在解析文档"} + progress=10, message="正在解析文档" ) # 解析文档 @@ -172,9 +240,9 @@ async def process_document( raise Exception(result.error or "解析失败") # 状态: 存储中 - await redis_db.set_task_status( + await update_task_status( task_id, status="processing", - meta={"progress": 30, "message": "正在存储数据"} + progress=30, message="正在存储数据" ) # 存储到 MongoDB @@ -235,9 +303,9 @@ async def process_document( # 如果是 Excel,存储到 MySQL + AI生成描述 + RAG索引 if doc_type in ["xlsx", "xls"]: - await redis_db.set_task_status( + await update_task_status( task_id, status="processing", - meta={"progress": 50, "message": "正在存储到MySQL并生成字段描述"} + progress=50, message="正在存储到MySQL并生成字段描述" ) try: @@ -259,9 +327,9 @@ async def process_document( else: # 非结构化文档 - await redis_db.set_task_status( + await update_task_status( task_id, status="processing", - meta={"progress": 60, "message": "正在建立索引"} + progress=60, message="正在建立索引" ) # 如果文档中有表格数据,提取并存储到 MySQL + RAG @@ -282,17 +350,13 @@ async def process_document( await index_document_to_rag(doc_id, original_filename, result, doc_type) # 完成 - await redis_db.set_task_status( + await update_task_status( task_id, status="success", - meta={ - "progress": 100, - "message": "处理完成", + progress=100, message="处理完成", + result={ "doc_id": doc_id, - "result": { - "doc_id": doc_id, - "doc_type": doc_type, - "filename": original_filename - } + "doc_type": doc_type, + "filename": original_filename } ) @@ -300,18 +364,19 @@ async def process_document( except Exception as e: logger.error(f"文档处理失败: {str(e)}") - await redis_db.set_task_status( + await update_task_status( task_id, status="failure", - meta={"error": str(e)} + progress=0, message="处理失败", + error=str(e) ) async def process_documents_batch(task_id: str, files: List[dict]): """批量处理文档""" try: - await redis_db.set_task_status( + await update_task_status( task_id, status="processing", - meta={"progress": 0, "message": "开始批量处理"} + progress=0, message="开始批量处理" ) results = [] @@ -362,21 +427,23 @@ async def process_documents_batch(task_id: str, files: List[dict]): results.append({"filename": file_info["filename"], "success": False, "error": str(e)}) progress = int((i + 1) / len(files) * 100) - await redis_db.set_task_status( + await update_task_status( task_id, status="processing", - meta={"progress": progress, "message": f"已处理 {i+1}/{len(files)}"} + progress=progress, message=f"已处理 {i+1}/{len(files)}" ) - await redis_db.set_task_status( + await update_task_status( task_id, status="success", - meta={"progress": 100, "message": "批量处理完成", "results": results} + progress=100, message="批量处理完成", + result={"results": results} ) except Exception as e: logger.error(f"批量处理失败: {str(e)}") - await redis_db.set_task_status( + await update_task_status( task_id, status="failure", - meta={"error": str(e)} + progress=0, message="批量处理失败", + error=str(e) ) diff --git a/backend/app/api/endpoints/health.py b/backend/app/api/endpoints/health.py index 2f239be..00f2049 100644 --- a/backend/app/api/endpoints/health.py +++ b/backend/app/api/endpoints/health.py @@ -19,26 +19,43 @@ async def health_check() -> Dict[str, Any]: 返回各数据库连接状态和应用信息 """ # 检查各数据库连接状态 - mysql_status = "connected" - mongodb_status = "connected" - redis_status = "connected" + mysql_status = "unknown" + mongodb_status = "unknown" + redis_status = "unknown" try: if mysql_db.async_engine is None: mysql_status = "disconnected" - except Exception: + else: + # 实际执行一次查询验证连接 + from sqlalchemy import text + async with mysql_db.async_engine.connect() as conn: + await conn.execute(text("SELECT 1")) + mysql_status = "connected" + except Exception as e: + logger.warning(f"MySQL 健康检查失败: {e}") mysql_status = "error" try: if mongodb.client is None: mongodb_status = "disconnected" - except Exception: + else: + # 实际 ping 验证 + await mongodb.client.admin.command('ping') + mongodb_status = "connected" + except Exception as e: + logger.warning(f"MongoDB 健康检查失败: {e}") mongodb_status = "error" try: - if not redis_db.is_connected: + if not redis_db.is_connected or redis_db.client is None: redis_status = "disconnected" - except Exception: + else: + # 实际执行 ping 验证 + await redis_db.client.ping() + redis_status = "connected" + except Exception as e: + logger.warning(f"Redis 健康检查失败: {e}") redis_status = "error" return { diff --git a/backend/app/api/endpoints/tasks.py b/backend/app/api/endpoints/tasks.py index aeea884..1df7a44 100644 --- a/backend/app/api/endpoints/tasks.py +++ b/backend/app/api/endpoints/tasks.py @@ -1,13 +1,13 @@ """ 任务管理 API 接口 -提供异步任务状态查询 +提供异步任务状态查询和历史记录 """ from typing import Optional from fastapi import APIRouter, HTTPException -from app.core.database import redis_db +from app.core.database import redis_db, mongodb router = APIRouter(prefix="/tasks", tags=["任务管理"]) @@ -23,25 +23,94 @@ async def get_task_status(task_id: str): Returns: 任务状态信息 """ + # 优先从 Redis 获取 status = await redis_db.get_task_status(task_id) - if not status: - # Redis不可用时,假设任务已完成(文档已成功处理) - # 前端轮询时会得到这个响应 + if status: return { "task_id": task_id, - "status": "success", - "progress": 100, - "message": "任务处理完成", - "result": None, - "error": None + "status": status.get("status", "unknown"), + "progress": status.get("meta", {}).get("progress", 0), + "message": status.get("meta", {}).get("message"), + "result": status.get("meta", {}).get("result"), + "error": status.get("meta", {}).get("error") } + # Redis 不可用时,尝试从 MongoDB 获取 + mongo_task = await mongodb.get_task(task_id) + if mongo_task: + return { + "task_id": mongo_task.get("task_id"), + "status": mongo_task.get("status", "unknown"), + "progress": 100 if mongo_task.get("status") == "success" else 0, + "message": mongo_task.get("message"), + "result": mongo_task.get("result"), + "error": mongo_task.get("error") + } + + # 任务不存在或状态未知 return { "task_id": task_id, - "status": status.get("status", "unknown"), - "progress": status.get("meta", {}).get("progress", 0), - "message": status.get("meta", {}).get("message"), - "result": status.get("meta", {}).get("result"), - "error": status.get("meta", {}).get("error") + "status": "unknown", + "progress": 0, + "message": "无法获取任务状态(Redis和MongoDB均不可用)", + "result": None, + "error": None } + + +@router.get("/") +async def list_tasks(limit: int = 50, skip: int = 0): + """ + 获取任务历史列表 + + Args: + limit: 返回数量限制 + skip: 跳过数量 + + Returns: + 任务列表 + """ + try: + tasks = await mongodb.list_tasks(limit=limit, skip=skip) + return { + "success": True, + "tasks": tasks, + "count": len(tasks) + } + except Exception as e: + # MongoDB 不可用时返回空列表 + return { + "success": False, + "tasks": [], + "count": 0, + "error": str(e) + } + + +@router.delete("/{task_id}") +async def delete_task(task_id: str): + """ + 删除任务 + + Args: + task_id: 任务ID + + Returns: + 是否删除成功 + """ + try: + # 从 Redis 删除 + if redis_db._connected and redis_db.client: + key = f"task:{task_id}" + await redis_db.client.delete(key) + + # 从 MongoDB 删除 + deleted = await mongodb.delete_task(task_id) + + return { + "success": True, + "deleted": deleted + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"删除任务失败: {str(e)}") diff --git a/backend/app/api/endpoints/templates.py b/backend/app/api/endpoints/templates.py index 3803196..bdbb3f0 100644 --- a/backend/app/api/endpoints/templates.py +++ b/backend/app/api/endpoints/templates.py @@ -23,6 +23,44 @@ logger = logging.getLogger(__name__) router = APIRouter(prefix="/templates", tags=["表格模板"]) +# ==================== 辅助函数 ==================== + +async def update_task_status( + task_id: str, + status: str, + progress: int = 0, + message: str = "", + result: dict = None, + error: str = None +): + """ + 更新任务状态,同时写入 Redis 和 MongoDB + """ + from app.core.database import redis_db + + meta = {"progress": progress, "message": message} + if result: + meta["result"] = result + if error: + meta["error"] = error + + try: + await redis_db.set_task_status(task_id, status, meta) + except Exception as e: + logger.warning(f"Redis 任务状态更新失败: {e}") + + try: + await mongodb.update_task( + task_id=task_id, + status=status, + message=message, + result=result, + error=error + ) + except Exception as e: + logger.warning(f"MongoDB 任务状态更新失败: {e}") + + # ==================== 请求/响应模型 ==================== class TemplateFieldRequest(BaseModel): @@ -41,6 +79,7 @@ class FillRequest(BaseModel): source_doc_ids: Optional[List[str]] = None # MongoDB 文档 ID 列表 source_file_paths: Optional[List[str]] = None # 源文档文件路径列表 user_hint: Optional[str] = None + task_id: Optional[str] = None # 可选的任务ID,用于任务历史跟踪 class ExportRequest(BaseModel): @@ -162,20 +201,17 @@ async def upload_joint_template( ) try: - # 1. 保存模板文件并提取字段 + # 1. 保存模板文件 template_content = await template_file.read() template_path = file_service.save_uploaded_file( template_content, template_file.filename, subfolder="templates" ) - template_fields = await template_fill_service.get_template_fields_from_file( - template_path, - template_ext - ) - # 2. 处理源文档 - 保存文件 + # 2. 保存并解析源文档 - 提取内容用于生成表头 source_file_info = [] + source_contents = [] for sf in source_files: if sf.filename: sf_content = await sf.read() @@ -190,10 +226,81 @@ async def upload_joint_template( "filename": sf.filename, "ext": sf_ext }) + # 解析源文档获取内容(用于 AI 生成表头) + try: + from app.core.document_parser import ParserFactory + parser = ParserFactory.get_parser(sf_path) + parse_result = parser.parse(sf_path) + if parse_result.success and parse_result.data: + # 获取原始内容 + content = parse_result.data.get("content", "")[:5000] if parse_result.data.get("content") else "" + + # 获取标题(可能在顶层或structured_data内) + titles = parse_result.data.get("titles", []) + if not titles and parse_result.data.get("structured_data"): + titles = parse_result.data.get("structured_data", {}).get("titles", []) + titles = titles[:10] if titles else [] + + # 获取表格数量(可能在顶层或structured_data内) + tables = parse_result.data.get("tables", []) + if not tables and parse_result.data.get("structured_data"): + tables = parse_result.data.get("structured_data", {}).get("tables", []) + tables_count = len(tables) if tables else 0 + + # 获取表格内容摘要(用于 AI 理解源文档结构) + tables_summary = "" + if tables: + tables_summary = "\n【文档中的表格】:\n" + for idx, table in enumerate(tables[:5]): # 最多5个表格 + if isinstance(table, dict): + headers = table.get("headers", []) + rows = table.get("rows", []) + if headers: + tables_summary += f"表格{idx+1}表头: {', '.join(str(h) for h in headers)}\n" + if rows: + tables_summary += f"表格{idx+1}前3行: " + for row_idx, row in enumerate(rows[:3]): + if isinstance(row, list): + tables_summary += " | ".join(str(c) for c in row) + "; " + elif isinstance(row, dict): + tables_summary += " | ".join(str(row.get(h, "")) for h in headers if headers) + "; " + tables_summary += "\n" + + source_contents.append({ + "filename": sf.filename, + "doc_type": sf_ext, + "content": content, + "titles": titles, + "tables_count": tables_count, + "tables_summary": tables_summary + }) + logger.info(f"[DEBUG] source_contents built: filename={sf.filename}, content_len={len(content)}, titles_count={len(titles)}, tables_count={tables_count}") + if tables_summary: + logger.info(f"[DEBUG] tables_summary preview: {tables_summary[:300]}") + except Exception as e: + logger.warning(f"解析源文档失败 {sf.filename}: {e}") + + # 3. 根据源文档内容生成表头 + template_fields = await template_fill_service.get_template_fields_from_file( + template_path, + template_ext, + source_contents=source_contents # 传递源文档内容 + ) # 3. 异步处理源文档到MongoDB task_id = str(uuid.uuid4()) if source_file_info: + # 保存任务记录到 MongoDB + try: + await mongodb.insert_task( + task_id=task_id, + task_type="source_process", + status="pending", + message=f"开始处理 {len(source_file_info)} 个源文档" + ) + except Exception as mongo_err: + logger.warning(f"MongoDB 保存任务记录失败: {mongo_err}") + background_tasks.add_task( process_source_documents, task_id=task_id, @@ -232,12 +339,10 @@ async def upload_joint_template( async def process_source_documents(task_id: str, files: List[dict]): """异步处理源文档,存入MongoDB""" - from app.core.database import redis_db - try: - await redis_db.set_task_status( + await update_task_status( task_id, status="processing", - meta={"progress": 0, "message": "开始处理源文档"} + progress=0, message="开始处理源文档" ) doc_ids = [] @@ -266,22 +371,24 @@ async def process_source_documents(task_id: str, files: List[dict]): logger.error(f"源文档处理异常: {file_info['filename']}, error: {str(e)}") progress = int((i + 1) / len(files) * 100) - await redis_db.set_task_status( + await update_task_status( task_id, status="processing", - meta={"progress": progress, "message": f"已处理 {i+1}/{len(files)}"} + progress=progress, message=f"已处理 {i+1}/{len(files)}" ) - await redis_db.set_task_status( + await update_task_status( task_id, status="success", - meta={"progress": 100, "message": "源文档处理完成", "doc_ids": doc_ids} + progress=100, message="源文档处理完成", + result={"doc_ids": doc_ids} ) logger.info(f"所有源文档处理完成: {len(doc_ids)}个") except Exception as e: logger.error(f"源文档批量处理失败: {str(e)}") - await redis_db.set_task_status( + await update_task_status( task_id, status="failure", - meta={"error": str(e)} + progress=0, message="源文档处理失败", + error=str(e) ) @@ -340,7 +447,27 @@ async def fill_template( Returns: 填写结果 """ + # 生成或使用传入的 task_id + task_id = request.task_id or str(uuid.uuid4()) + try: + # 创建任务记录到 MongoDB + try: + await mongodb.insert_task( + task_id=task_id, + task_type="template_fill", + status="processing", + message=f"开始填表任务: {len(request.template_fields)} 个字段" + ) + except Exception as mongo_err: + logger.warning(f"MongoDB 创建任务记录失败: {mongo_err}") + + # 更新进度 - 开始 + await update_task_status( + task_id, "processing", + progress=0, message="开始处理..." + ) + # 转换字段 fields = [ TemplateField( @@ -353,17 +480,51 @@ async def fill_template( for f in request.template_fields ] + # 从 template_id 提取文件类型 + template_file_type = "xlsx" # 默认类型 + if request.template_id: + ext = request.template_id.split('.')[-1].lower() + if ext in ["xlsx", "xls"]: + template_file_type = "xlsx" + elif ext == "docx": + template_file_type = "docx" + + # 更新进度 - 准备开始填写 + await update_task_status( + task_id, "processing", + progress=10, message=f"准备填写 {len(fields)} 个字段..." + ) + # 执行填写 result = await template_fill_service.fill_template( template_fields=fields, source_doc_ids=request.source_doc_ids, source_file_paths=request.source_file_paths, - user_hint=request.user_hint + user_hint=request.user_hint, + template_id=request.template_id, + template_file_type=template_file_type, + task_id=task_id ) - return result + # 更新为成功 + await update_task_status( + task_id, "success", + progress=100, message="填表完成", + result={ + "field_count": len(fields), + "max_rows": result.get("max_rows", 0) + } + ) + + return {**result, "task_id": task_id} except Exception as e: + # 更新为失败 + await update_task_status( + task_id, "failure", + progress=0, message="填表失败", + error=str(e) + ) logger.error(f"填写表格失败: {str(e)}") raise HTTPException(status_code=500, detail=f"填写失败: {str(e)}") diff --git a/backend/app/api/endpoints/upload.py b/backend/app/api/endpoints/upload.py index d9d9ada..ca9c8df 100644 --- a/backend/app/api/endpoints/upload.py +++ b/backend/app/api/endpoints/upload.py @@ -5,6 +5,7 @@ from fastapi import APIRouter, UploadFile, File, HTTPException, Query from fastapi.responses import StreamingResponse from typing import Optional import logging +import os import pandas as pd import io @@ -126,7 +127,7 @@ async def upload_excel( content += f"... (共 {len(sheet_data['rows'])} 行)\n\n" doc_metadata = { - "filename": saved_path.split("/")[-1] if "/" in saved_path else saved_path.split("\\")[-1], + "filename": os.path.basename(saved_path), "original_filename": file.filename, "saved_path": saved_path, "file_size": len(content), @@ -253,7 +254,7 @@ async def export_excel( output.seek(0) # 生成文件名 - original_name = file_path.split('/')[-1] if '/' in file_path else file_path + original_name = os.path.basename(file_path) if columns: export_name = f"export_{sheet_name or 'data'}_{len(column_list) if columns else 'all'}_cols.xlsx" else: diff --git a/backend/app/core/database/mongodb.py b/backend/app/core/database/mongodb.py index 01626bc..0a20cd2 100644 --- a/backend/app/core/database/mongodb.py +++ b/backend/app/core/database/mongodb.py @@ -59,6 +59,11 @@ class MongoDB: """RAG索引集合 - 存储字段语义索引""" return self.db["rag_index"] + @property + def tasks(self): + """任务集合 - 存储任务历史记录""" + return self.db["tasks"] + # ==================== 文档操作 ==================== async def insert_document( @@ -264,8 +269,128 @@ class MongoDB: await self.rag_index.create_index("table_name") await self.rag_index.create_index("field_name") + # 任务集合索引 + await self.tasks.create_index("task_id", unique=True) + await self.tasks.create_index("created_at") + logger.info("MongoDB 索引创建完成") + # ==================== 任务历史操作 ==================== + + async def insert_task( + self, + task_id: str, + task_type: str, + status: str = "pending", + message: str = "", + result: Optional[Dict[str, Any]] = None, + error: Optional[str] = None, + ) -> str: + """ + 插入任务记录 + + Args: + task_id: 任务ID + task_type: 任务类型 + status: 任务状态 + message: 任务消息 + result: 任务结果 + error: 错误信息 + + Returns: + 插入文档的ID + """ + task = { + "task_id": task_id, + "task_type": task_type, + "status": status, + "message": message, + "result": result, + "error": error, + "created_at": datetime.utcnow(), + "updated_at": datetime.utcnow(), + } + result_obj = await self.tasks.insert_one(task) + return str(result_obj.inserted_id) + + async def update_task( + self, + task_id: str, + status: Optional[str] = None, + message: Optional[str] = None, + result: Optional[Dict[str, Any]] = None, + error: Optional[str] = None, + ) -> bool: + """ + 更新任务状态 + + Args: + task_id: 任务ID + status: 任务状态 + message: 任务消息 + result: 任务结果 + error: 错误信息 + + Returns: + 是否更新成功 + """ + from bson import ObjectId + + update_data = {"updated_at": datetime.utcnow()} + if status is not None: + update_data["status"] = status + if message is not None: + update_data["message"] = message + if result is not None: + update_data["result"] = result + if error is not None: + update_data["error"] = error + + update_result = await self.tasks.update_one( + {"task_id": task_id}, + {"$set": update_data} + ) + return update_result.modified_count > 0 + + async def get_task(self, task_id: str) -> Optional[Dict[str, Any]]: + """根据task_id获取任务""" + task = await self.tasks.find_one({"task_id": task_id}) + if task: + task["_id"] = str(task["_id"]) + return task + + async def list_tasks( + self, + limit: int = 50, + skip: int = 0, + ) -> List[Dict[str, Any]]: + """ + 获取任务列表 + + Args: + limit: 返回数量 + skip: 跳过数量 + + Returns: + 任务列表 + """ + cursor = self.tasks.find().sort("created_at", -1).skip(skip).limit(limit) + tasks = [] + async for task in cursor: + task["_id"] = str(task["_id"]) + # 转换 datetime 为字符串 + if task.get("created_at"): + task["created_at"] = task["created_at"].isoformat() + if task.get("updated_at"): + task["updated_at"] = task["updated_at"].isoformat() + tasks.append(task) + return tasks + + async def delete_task(self, task_id: str) -> bool: + """删除任务""" + result = await self.tasks.delete_one({"task_id": task_id}) + return result.deleted_count > 0 + # ==================== 全局单例 ==================== diff --git a/backend/app/core/document_parser/xlsx_parser.py b/backend/app/core/document_parser/xlsx_parser.py index 47cd232..a0216a1 100644 --- a/backend/app/core/document_parser/xlsx_parser.py +++ b/backend/app/core/document_parser/xlsx_parser.py @@ -317,24 +317,70 @@ class XlsxParser(BaseParser): import zipfile from xml.etree import ElementTree as ET + # 常见的命名空间 + COMMON_NAMESPACES = [ + 'http://schemas.openxmlformats.org/spreadsheetml/2006/main', + 'http://schemas.openxmlformats.org/spreadsheetml/2005/main', + 'http://schemas.openxmlformats.org/spreadsheetml/2004/main', + 'http://schemas.openxmlformats.org/spreadsheetml/2003/main', + ] + try: with zipfile.ZipFile(file_path, 'r') as z: - if 'xl/workbook.xml' not in z.namelist(): + # 尝试多种可能的 workbook.xml 路径 + possible_paths = ['xl/workbook.xml', 'xl\\workbook.xml', 'workbook.xml'] + content = None + for path in possible_paths: + if path in z.namelist(): + content = z.read(path) + logger.info(f"找到 workbook.xml at: {path}") + break + + if content is None: + logger.warning(f"未找到 workbook.xml,文件列表: {z.namelist()[:10]}") return [] - content = z.read('xl/workbook.xml') + root = ET.fromstring(content) - # 命名空间 - ns = {'main': 'http://schemas.openxmlformats.org/spreadsheetml/2006/main'} - sheet_names = [] - for sheet in root.findall('.//main:sheet', ns): - name = sheet.get('name') - if name: - sheet_names.append(name) + + # 方法1:尝试带命名空间的查找 + for ns in COMMON_NAMESPACES: + sheet_elements = root.findall(f'.//{{{ns}}}sheet') + if sheet_elements: + for sheet in sheet_elements: + name = sheet.get('name') + if name: + sheet_names.append(name) + if sheet_names: + logger.info(f"使用命名空间 {ns} 提取工作表: {sheet_names}") + return sheet_names + + # 方法2:不使用命名空间,直接查找所有 sheet 元素 + if not sheet_names: + for elem in root.iter(): + if elem.tag.endswith('sheet') and elem.tag != 'sheets': + name = elem.get('name') + if name: + sheet_names.append(name) + for child in elem: + if child.tag.endswith('sheet') or child.tag == 'sheet': + name = child.get('name') + if name and name not in sheet_names: + sheet_names.append(name) + + # 方法3:直接从 XML 文本中正则匹配 sheet name + if not sheet_names: + import re + xml_str = content.decode('utf-8', errors='ignore') + matches = re.findall(r']*name=["\']([^"\']+)["\']', xml_str, re.IGNORECASE) + if matches: + sheet_names = matches + logger.info(f"使用正则提取工作表: {sheet_names}") logger.info(f"从 XML 提取工作表: {sheet_names}") return sheet_names + except Exception as e: logger.error(f"从 XML 提取工作表名称失败: {e}") return [] @@ -356,6 +402,32 @@ class XlsxParser(BaseParser): import zipfile from xml.etree import ElementTree as ET + # 常见的命名空间 + COMMON_NAMESPACES = [ + 'http://schemas.openxmlformats.org/spreadsheetml/2006/main', + 'http://schemas.openxmlformats.org/spreadsheetml/2005/main', + 'http://schemas.openxmlformats.org/spreadsheetml/2004/main', + 'http://schemas.openxmlformats.org/spreadsheetml/2003/main', + ] + + def find_elements_with_ns(root, tag_name): + """灵活查找元素,支持任意命名空间""" + results = [] + # 方法1:用固定命名空间 + for ns in COMMON_NAMESPACES: + try: + elems = root.findall(f'.//{{{ns}}}{tag_name}') + if elems: + results.extend(elems) + except: + pass + # 方法2:不带命名空间查找 + if not results: + for elem in root.iter(): + if elem.tag.endswith('}' + tag_name): + results.append(elem) + return results + with zipfile.ZipFile(file_path, 'r') as z: # 获取工作表名称 sheet_names = self._extract_sheet_names_from_xml(file_path) @@ -366,57 +438,68 @@ class XlsxParser(BaseParser): target_sheet = sheet_name if sheet_name and sheet_name in sheet_names else sheet_names[0] sheet_index = sheet_names.index(target_sheet) + 1 # sheet1.xml, sheet2.xml, ... - # 读取 shared strings + # 读取 shared strings - 尝试多种路径 shared_strings = [] - if 'xl/sharedStrings.xml' in z.namelist(): - ss_content = z.read('xl/sharedStrings.xml') - ss_root = ET.fromstring(ss_content) - ns = {'main': 'http://schemas.openxmlformats.org/spreadsheetml/2006/main'} - for si in ss_root.findall('.//main:si', ns): - t = si.find('.//main:t', ns) - if t is not None: - shared_strings.append(t.text or '') - else: - shared_strings.append('') + ss_paths = ['xl/sharedStrings.xml', 'xl\\sharedStrings.xml', 'sharedStrings.xml'] + for ss_path in ss_paths: + if ss_path in z.namelist(): + try: + ss_content = z.read(ss_path) + ss_root = ET.fromstring(ss_content) + for si in find_elements_with_ns(ss_root, 'si'): + t_elements = [c for c in si if c.tag.endswith('}t') or c.tag == 't'] + if t_elements: + shared_strings.append(t_elements[0].text or '') + else: + shared_strings.append('') + break + except Exception as e: + logger.warning(f"读取 sharedStrings 失败: {e}") - # 读取工作表 - sheet_file = f'xl/worksheets/sheet{sheet_index}.xml' - if sheet_file not in z.namelist(): - raise ValueError(f"工作表文件 {sheet_file} 不存在") + # 读取工作表 - 尝试多种可能的路径 + sheet_content = None + sheet_paths = [ + f'xl/worksheets/sheet{sheet_index}.xml', + f'xl\\worksheets\\sheet{sheet_index}.xml', + f'worksheets/sheet{sheet_index}.xml', + ] + for sp in sheet_paths: + if sp in z.namelist(): + sheet_content = z.read(sp) + break + + if sheet_content is None: + raise ValueError(f"工作表文件 sheet{sheet_index}.xml 不存在") - sheet_content = z.read(sheet_file) root = ET.fromstring(sheet_content) - ns = {'main': 'http://schemas.openxmlformats.org/spreadsheetml/2006/main'} # 收集所有行数据 all_rows = [] headers = {} - for row in root.findall('.//main:row', ns): + for row in find_elements_with_ns(root, 'row'): row_idx = int(row.get('r', 0)) row_cells = {} - for cell in row.findall('main:c', ns): + for cell in find_elements_with_ns(row, 'c'): cell_ref = cell.get('r', '') col_letters = ''.join(filter(str.isalpha, cell_ref)) cell_type = cell.get('t', 'n') - v = cell.find('main:v', ns) + v_elements = find_elements_with_ns(cell, 'v') + v = v_elements[0] if v_elements else None if v is not None and v.text: if cell_type == 's': - # shared string try: row_cells[col_letters] = shared_strings[int(v.text)] except (ValueError, IndexError): row_cells[col_letters] = v.text elif cell_type == 'b': - # boolean row_cells[col_letters] = v.text == '1' else: row_cells[col_letters] = v.text else: row_cells[col_letters] = None - # 处理表头行 if row_idx == header_row + 1: headers = {**row_cells} elif row_idx > header_row + 1: @@ -424,7 +507,6 @@ class XlsxParser(BaseParser): # 构建 DataFrame if headers: - # 按原始列顺序排列 col_order = list(headers.keys()) df = pd.DataFrame(all_rows) if not df.empty: diff --git a/backend/app/instruction/__init__.py b/backend/app/instruction/__init__.py index e69de29..1386f3d 100644 --- a/backend/app/instruction/__init__.py +++ b/backend/app/instruction/__init__.py @@ -0,0 +1,15 @@ +""" +指令执行模块 + +注意: 此模块为可选功能,当前尚未实现。 +如需启用,请实现 intent_parser.py 和 executor.py +""" +from .intent_parser import IntentParser, DefaultIntentParser +from .executor import InstructionExecutor, DefaultInstructionExecutor + +__all__ = [ + "IntentParser", + "DefaultIntentParser", + "InstructionExecutor", + "DefaultInstructionExecutor", +] diff --git a/backend/app/instruction/executor.py b/backend/app/instruction/executor.py index e69de29..36292ce 100644 --- a/backend/app/instruction/executor.py +++ b/backend/app/instruction/executor.py @@ -0,0 +1,35 @@ +""" +指令执行器模块 + +将自然语言指令转换为可执行操作 + +注意: 此模块为可选功能,当前尚未实现。 +""" +from abc import ABC, abstractmethod +from typing import Any, Dict + + +class InstructionExecutor(ABC): + """指令执行器抽象基类""" + + @abstractmethod + async def execute(self, instruction: str, context: Dict[str, Any]) -> Dict[str, Any]: + """ + 执行指令 + + Args: + instruction: 解析后的指令 + context: 执行上下文 + + Returns: + 执行结果 + """ + pass + + +class DefaultInstructionExecutor(InstructionExecutor): + """默认指令执行器""" + + async def execute(self, instruction: str, context: Dict[str, Any]) -> Dict[str, Any]: + """暂未实现""" + raise NotImplementedError("指令执行功能暂未实现") diff --git a/backend/app/instruction/intent_parser.py b/backend/app/instruction/intent_parser.py index e69de29..49df250 100644 --- a/backend/app/instruction/intent_parser.py +++ b/backend/app/instruction/intent_parser.py @@ -0,0 +1,34 @@ +""" +意图解析器模块 + +解析用户自然语言指令,识别意图和参数 + +注意: 此模块为可选功能,当前尚未实现。 +""" +from abc import ABC, abstractmethod +from typing import Any, Dict, Tuple + + +class IntentParser(ABC): + """意图解析器抽象基类""" + + @abstractmethod + async def parse(self, text: str) -> Tuple[str, Dict[str, Any]]: + """ + 解析自然语言指令 + + Args: + text: 用户输入的自然语言 + + Returns: + (意图类型, 参数字典) + """ + pass + + +class DefaultIntentParser(IntentParser): + """默认意图解析器""" + + async def parse(self, text: str) -> Tuple[str, Dict[str, Any]]: + """暂未实现""" + raise NotImplementedError("意图解析功能暂未实现") diff --git a/backend/app/services/template_fill_service.py b/backend/app/services/template_fill_service.py index cbaaea1..76c254f 100644 --- a/backend/app/services/template_fill_service.py +++ b/backend/app/services/template_fill_service.py @@ -61,7 +61,10 @@ class TemplateFillService: template_fields: List[TemplateField], source_doc_ids: Optional[List[str]] = None, source_file_paths: Optional[List[str]] = None, - user_hint: Optional[str] = None + user_hint: Optional[str] = None, + template_id: Optional[str] = None, + template_file_type: Optional[str] = "xlsx", + task_id: Optional[str] = None ) -> Dict[str, Any]: """ 填写表格模板 @@ -71,6 +74,9 @@ class TemplateFillService: source_doc_ids: 源文档 MongoDB ID 列表 source_file_paths: 源文档文件路径列表 user_hint: 用户提示(如"请从合同文档中提取") + template_id: 模板文件路径(用于重新生成表头) + template_file_type: 模板文件类型 + task_id: 可选的任务ID,用于任务进度跟踪 Returns: 填写结果 @@ -79,15 +85,94 @@ class TemplateFillService: fill_details = [] logger.info(f"开始填表: {len(template_fields)} 个字段, {len(source_doc_ids or [])} 个源文档") + logger.info(f"source_doc_ids: {source_doc_ids}") + logger.info(f"source_file_paths: {source_file_paths}") # 1. 加载源文档内容 source_docs = await self._load_source_documents(source_doc_ids, source_file_paths) logger.info(f"加载了 {len(source_docs)} 个源文档") + # 打印每个加载的文档的详细信息 + for i, doc in enumerate(source_docs): + logger.info(f" 文档[{i}]: id={doc.doc_id}, filename={doc.filename}, doc_type={doc.doc_type}") + logger.info(f" content长度: {len(doc.content)}, structured_data keys: {list(doc.structured_data.keys()) if doc.structured_data else 'None'}") + if not source_docs: logger.warning("没有找到源文档,填表结果将全部为空") + # 3. 检查是否需要使用源文档重新生成表头 + # 条件:源文档已加载 AND 现有字段看起来是自动生成的(如"字段1"、"字段2") + needs_regenerate_headers = ( + len(source_docs) > 0 and + len(template_fields) > 0 and + all(self._is_auto_generated_field(f.name) for f in template_fields) + ) + + if needs_regenerate_headers: + logger.info(f"检测到自动生成表头,尝试使用源文档重新生成... (当前字段: {[f.name for f in template_fields]})") + + # 将 SourceDocument 转换为 source_contents 格式 + source_contents = [] + for doc in source_docs: + structured = doc.structured_data if doc.structured_data else {} + + # 获取标题 + titles = structured.get("titles", []) + if not titles: + titles = [] + + # 获取表格 + tables = structured.get("tables", []) + tables_count = len(tables) if tables else 0 + + # 生成表格摘要 + tables_summary = "" + if tables: + tables_summary = "\n【文档中的表格】:\n" + for idx, table in enumerate(tables[:5]): + if isinstance(table, dict): + headers = table.get("headers", []) + rows = table.get("rows", []) + if headers: + tables_summary += f"表格{idx+1}表头: {', '.join(str(h) for h in headers)}\n" + if rows: + tables_summary += f"表格{idx+1}前3行: " + for row_idx, row in enumerate(rows[:3]): + if isinstance(row, list): + tables_summary += " | ".join(str(c) for c in row) + "; " + elif isinstance(row, dict): + tables_summary += " | ".join(str(row.get(h, "")) for h in headers if headers) + "; " + tables_summary += "\n" + + source_contents.append({ + "filename": doc.filename, + "doc_type": doc.doc_type, + "content": doc.content[:5000] if doc.content else "", + "titles": titles[:10] if titles else [], + "tables_count": tables_count, + "tables_summary": tables_summary + }) + + # 使用源文档内容重新生成表头 + if template_id and template_file_type: + logger.info(f"使用源文档重新生成表头: template_id={template_id}, template_file_type={template_file_type}") + new_fields = await self.get_template_fields_from_file( + template_id, + template_file_type, + source_contents=source_contents + ) + if new_fields and len(new_fields) > 0: + logger.info(f"成功重新生成表头: {[f.name for f in new_fields]}") + template_fields = new_fields + else: + logger.warning("重新生成表头返回空结果,使用原始字段") + else: + logger.warning("无法重新生成表头:缺少 template_id 或 template_file_type") + else: + if source_docs and template_fields: + logger.info(f"表头看起来正常(非自动生成),无需重新生成: {[f.name for f in template_fields[:5]]}") + # 2. 对每个字段进行提取 for idx, field in enumerate(template_fields): try: @@ -99,6 +184,22 @@ class TemplateFillService: user_hint=user_hint ) + # AI审核:验证提取的值是否合理 + if result.values and result.values[0]: + logger.info(f"字段 {field.name} 进入AI审核阶段...") + verified_result = await self._verify_field_value( + field=field, + extracted_values=result.values, + source_docs=source_docs, + user_hint=user_hint + ) + if verified_result: + # 审核给出了修正结果 + result = verified_result + logger.info(f"字段 {field.name} 审核后修正值: {result.values[:3]}") + else: + logger.info(f"字段 {field.name} 审核通过,使用原提取结果") + # 存储结果 - 使用 values 数组 filled_data[field.name] = result.values if result.values else [""] fill_details.append({ @@ -159,14 +260,49 @@ class TemplateFillService: try: doc = await mongodb.get_document(doc_id) if doc: + sd = doc.get("structured_data", {}) + sd_keys = list(sd.keys()) if sd else [] + logger.info(f"从MongoDB加载文档: {doc_id}, doc_type={doc.get('doc_type')}, structured_data keys={sd_keys}") + + # 如果 structured_data 为空,但有 file_path,尝试重新解析文件 + doc_content = doc.get("content", "") + if not sd or (not sd.get("tables") and not sd.get("headers") and not sd.get("rows")): + file_path = doc.get("metadata", {}).get("file_path") + if file_path: + logger.info(f" structured_data 为空,尝试重新解析文件: {file_path}") + try: + parser = ParserFactory.get_parser(file_path) + result = parser.parse(file_path) + if result.success and result.data: + if result.data.get("structured_data"): + sd = result.data.get("structured_data") + logger.info(f" 重新解析成功,structured_data keys: {list(sd.keys())}") + elif result.data.get("tables"): + sd = {"tables": result.data.get("tables", [])} + logger.info(f" 使用 data.tables,tables数量: {len(sd.get('tables', []))}") + elif result.data.get("rows"): + sd = result.data + logger.info(f" 使用 data.rows 格式") + if result.data.get("content"): + doc_content = result.data.get("content", "") + else: + logger.warning(f" 重新解析失败: {result.error if result else 'unknown'}") + except Exception as parse_err: + logger.error(f" 重新解析文件异常: {str(parse_err)}") + + if sd.get("tables"): + logger.info(f" tables数量: {len(sd.get('tables', []))}") + if sd["tables"]: + first_table = sd["tables"][0] + logger.info(f" 第一表格: headers={first_table.get('headers', [])[:3]}..., rows数量={len(first_table.get('rows', []))}") + source_docs.append(SourceDocument( doc_id=doc_id, filename=doc.get("metadata", {}).get("original_filename", "unknown"), doc_type=doc.get("doc_type", "unknown"), - content=doc.get("content", ""), - structured_data=doc.get("structured_data", {}) + content=doc_content, + structured_data=sd )) - logger.info(f"从MongoDB加载文档: {doc_id}") except Exception as e: logger.error(f"从MongoDB加载文档失败 {doc_id}: {str(e)}") @@ -370,7 +506,7 @@ class TemplateFillService: response = await self.llm.chat( messages=messages, temperature=0.1, - max_tokens=50000 + max_tokens=4000 ) content = self.llm.extract_message_content(response) @@ -476,6 +612,137 @@ class TemplateFillService: confidence=0.0 ) + async def _verify_field_value( + self, + field: TemplateField, + extracted_values: List[str], + source_docs: List[SourceDocument], + user_hint: Optional[str] = None + ) -> Optional[FillResult]: + """ + 验证并修正提取的字段值 + + Args: + field: 字段定义 + extracted_values: 已提取的值 + source_docs: 源文档列表 + user_hint: 用户提示 + + Returns: + 验证后的结果,如果验证通过返回None(使用原结果) + """ + if not extracted_values or not extracted_values[0]: + return None + + if not source_docs: + return None + + try: + # 构建验证上下文 + context_text = self._build_context_text(source_docs, field_name=field.name, max_length=15000) + + hint_text = field.hint if field.hint else f"请理解{field.name}字段的含义" + if user_hint: + hint_text = f"{user_hint}。{hint_text}" + + prompt = f"""你是一个数据质量审核专家。请审核以下提取的数据是否合理。 + +【待审核字段】 +字段名:{field.name} +字段说明:{hint_text} + +【已提取的值】 +{extracted_values[:10]} # 最多审核前10个值 + +【源文档上下文】 +{context_text[:8000]} + +【审核要求】 +1. 这些值是否符合字段的含义? +2. 值在原文中的原始含义是什么?检查是否有误解或误提取 +3. 是否存在明显错误、空值或不合理的数据? +4. 如果表格有多个列,请确认提取的是正确的列 + +请严格按照以下 JSON 格式输出(只需输出 JSON,不要其他内容): +{{ + "is_valid": true或false, + "corrected_values": ["修正后的值列表"] 或 null(如果无需修正), + "reason": "审核说明,解释判断理由", + "original_meaning": "值在原文中的原始含义描述" +}} +""" + + messages = [ + {"role": "system", "content": "你是一个严格的数据质量审核专家。请仔细核对原文和提取的值是否匹配。"}, + {"role": "user", "content": prompt} + ] + + response = await self.llm.chat( + messages=messages, + temperature=0.2, + max_tokens=3000 + ) + + content = self.llm.extract_message_content(response) + logger.info(f"字段 {field.name} 审核返回: {content[:300]}") + + # 解析 JSON + import json + import re + + cleaned = content.strip() + cleaned = re.sub(r'^```json\s*', '', cleaned, flags=re.MULTILINE) + cleaned = re.sub(r'^```\s*', '', cleaned, flags=re.MULTILINE) + cleaned = cleaned.strip() + + json_start = -1 + for i, c in enumerate(cleaned): + if c == '{': + json_start = i + break + + if json_start == -1: + logger.warning(f"字段 {field.name} 审核:无法找到 JSON") + return None + + json_text = cleaned[json_start:] + result = json.loads(json_text) + + is_valid = result.get("is_valid", True) + corrected_values = result.get("corrected_values") + reason = result.get("reason", "") + original_meaning = result.get("original_meaning", "") + + logger.info(f"字段 {field.name} 审核结果: is_valid={is_valid}, reason={reason[:100]}") + + if not is_valid and corrected_values: + # 值有问题且有修正建议,使用修正后的值 + logger.info(f"字段 {field.name} 使用修正后的值: {corrected_values[:5]}") + return FillResult( + field=field.name, + values=corrected_values, + value=corrected_values[0] if corrected_values else "", + source=f"AI审核修正: {reason[:100]}", + confidence=0.7 + ) + elif not is_valid and original_meaning: + # 值有问题但无修正,记录原始含义供用户参考 + logger.info(f"字段 {field.name} 审核发现问题: {original_meaning}") + return FillResult( + field=field.name, + values=extracted_values, + value=extracted_values[0] if extracted_values else "", + source=f"AI审核疑问: {original_meaning[:100]}", + confidence=0.5 + ) + + # 验证通过,返回 None 表示使用原结果 + return None + + except Exception as e: + logger.error(f"字段 {field.name} 审核失败: {str(e)}") + return None + def _build_context_text(self, source_docs: List[SourceDocument], field_name: str = None, max_length: int = 8000) -> str: """ 构建上下文文本 @@ -625,7 +892,8 @@ class TemplateFillService: async def get_template_fields_from_file( self, file_path: str, - file_type: str = "xlsx" + file_type: str = "xlsx", + source_contents: List[dict] = None ) -> List[TemplateField]: """ 从模板文件提取字段定义 @@ -633,11 +901,14 @@ class TemplateFillService: Args: file_path: 模板文件路径 file_type: 文件类型 (xlsx/xls/docx) + source_contents: 源文档内容列表(用于 AI 生成表头) Returns: 字段列表 """ fields = [] + if source_contents is None: + source_contents = [] try: if file_type in ["xlsx", "xls"]: @@ -653,8 +924,8 @@ class TemplateFillService: ) if needs_ai_generation: - logger.info(f"模板表头为空或自动生成,尝试 AI 生成表头... (fields={len(fields)})") - ai_fields = await self._generate_fields_with_ai(file_path, file_type) + logger.info(f"模板表头为空或自动生成,尝试 AI 生成表头... (fields={len(fields)}, source_docs={len(source_contents)})") + ai_fields = await self._generate_fields_with_ai(file_path, file_type, source_contents) if ai_fields: fields = ai_fields logger.info(f"AI 生成表头成功: {len(fields)} 个字段") @@ -857,7 +1128,7 @@ class TemplateFillService: def _extract_values_from_structured_data(self, source_docs: List[SourceDocument], field_name: str) -> List[str]: """ - 从结构化数据(Excel rows)中直接提取指定列的值 + 从结构化数据(Excel rows 或 Markdown tables)中直接提取指定列的值 适用于有 rows 结构的文档数据,无需 LLM 即可提取 @@ -869,10 +1140,15 @@ class TemplateFillService: 值列表,如果无法提取则返回空列表 """ all_values = [] + logger.info(f"[_extract_values_from_structured_data] 开始提取字段: {field_name}") + logger.info(f" source_docs 数量: {len(source_docs)}") - for doc in source_docs: + for doc_idx, doc in enumerate(source_docs): # 尝试从 structured_data 中提取 structured = doc.structured_data + logger.info(f" 文档[{doc_idx}]: {doc.filename}, structured类型: {type(structured)}, 是否为空: {not bool(structured)}") + if structured: + logger.info(f" structured_data keys: {list(structured.keys())}") if not structured: continue @@ -892,6 +1168,33 @@ class TemplateFillService: if all_values: break + # 处理 Markdown 表格格式: {headers: [...], rows: [...], ...} + elif structured.get("headers") and structured.get("rows"): + headers = structured.get("headers", []) + rows = structured.get("rows", []) + values = self._extract_values_from_markdown_table(headers, rows, field_name) + if values: + all_values.extend(values) + logger.info(f"从 Markdown 文档 {doc.filename} 提取到 {len(values)} 个值") + break + + # 处理 MongoDB 存储的 tables 格式: {tables: [{headers, rows, ...}, ...]} + elif structured.get("tables") and isinstance(structured.get("tables"), list): + tables = structured.get("tables", []) + logger.info(f" 检测到 tables 格式,共 {len(tables)} 个表") + for table_idx, table in enumerate(tables): + if isinstance(table, dict): + headers = table.get("headers", []) + rows = table.get("rows", []) + logger.info(f" 表格[{table_idx}]: headers={headers[:3]}..., rows数量={len(rows)}") + values = self._extract_values_from_markdown_table(headers, rows, field_name) + if values: + all_values.extend(values) + logger.info(f"从表格[{table_idx}] 提取到 {len(values)} 个值") + break + if all_values: + break + # 处理单 sheet 格式: {columns: [...], rows: [...]} elif structured.get("rows"): columns = structured.get("columns", []) @@ -945,16 +1248,18 @@ class TemplateFillService: if not table_rows or len(table_rows) < 2: return [] - # 第一步:尝试在 header(第一行)中查找匹配列 - target_col_idx = None - for col_idx, col_name in enumerate(header): - col_str = str(col_name).strip() - if field_name.lower() in col_str.lower() or col_str.lower() in field_name.lower(): - target_col_idx = col_idx - break + # 使用增强的匹配算法查找最佳匹配的列索引 + target_col_idx = self._find_best_matching_column(header, field_name) + + # 如果增强匹配没找到,尝试在 header(第一行)中查找 + if target_col_idx is None: + for col_idx, col_name in enumerate(header): + col_str = str(col_name).strip() + if field_name.lower() in col_str.lower() or col_str.lower() in field_name.lower(): + target_col_idx = col_idx + break # 如果 header 中没找到,尝试在 table_rows[1](第二行)中查找 - # 这是因为有时第一行是数据而不是表头 if target_col_idx is None and len(table_rows) > 1: second_row = table_rows[1] if isinstance(second_row, list): @@ -970,33 +1275,112 @@ class TemplateFillService: return [] # 确定从哪一行开始提取数据 - # 如果 header 是表头(包含 field_name),则从 table_rows[1] 开始提取 - # 如果 header 是数据(不包含 field_name),则从 table_rows[2] 开始提取 header_contains_field = any( field_name.lower() in str(col).strip().lower() or str(col).strip().lower() in field_name.lower() for col in header ) if header_contains_field: - # header 是表头,从第二行开始提取 data_start_idx = 1 else: - # header 是数据,从第三行开始提取(跳过表头和第一行数据) data_start_idx = 2 # 提取值 values = [] for row_idx, row in enumerate(table_rows[data_start_idx:], start=data_start_idx): if isinstance(row, list) and target_col_idx < len(row): - val = str(row[target_col_idx]).strip() if row[target_col_idx] else "" - values.append(val) + val = row[target_col_idx] + values.append(self._format_value(val)) elif isinstance(row, dict): - val = str(row.get(target_col_idx, "")).strip() - values.append(val) + val = row.get(target_col_idx, "") + values.append(self._format_value(val)) logger.info(f"从 Word 表格列 {target_col_idx} 提取到 {len(values)} 个值: {values[:3]}") return values + def _format_value(self, val: Any) -> str: + """ + 格式化值为字符串,保持原始格式 + + Args: + val: 原始值 + + Returns: + 格式化后的字符串 + """ + if val is None: + return "" + + if isinstance(val, str): + return val.strip() + + if isinstance(val, bool): + return "true" if val else "false" + + if isinstance(val, (int, float)): + if isinstance(val, float): + if val == int(val): + return str(int(val)) + else: + formatted = f"{val:.10f}".rstrip('0').rstrip('.') + return formatted + else: + return str(val) + + return str(val) + + def _find_best_matching_column(self, headers: List, field_name: str) -> Optional[int]: + """ + 查找最佳匹配的列索引 + + 使用多层匹配策略: + 1. 精确匹配(忽略大小写) + 2. 子字符串匹配(字段名在表头中,或表头在字段名中) + 3. 关键词重叠匹配(中文字符串分割后比对) + + Args: + headers: 表头列表 + field_name: 要匹配的字段名 + + Returns: + 匹配的列索引,找不到返回 None + """ + field_lower = field_name.lower().strip() + field_keywords = set(field_lower.replace(" ", "").split()) + + best_match_idx = None + best_match_score = 0 + + for idx, header in enumerate(headers): + header_str = str(header).strip() + header_lower = header_str.lower() + + # 策略1: 精确匹配(忽略大小写) + if header_lower == field_lower: + return idx + + # 策略2: 子字符串匹配 + if field_lower in header_lower or header_lower in field_lower: + score = max(len(field_lower), len(header_lower)) / min(len(field_lower) + 1, len(header_lower) + 1) + if score > best_match_score: + best_match_score = score + best_match_idx = idx + continue + + # 策略3: 关键词重叠匹配(适用于中文) + header_keywords = set(header_lower.replace(" ", "").split()) + overlap = field_keywords & header_keywords + if overlap and len(overlap) > 0: + score = len(overlap) / max(len(field_keywords), len(header_keywords), 1) + if score > best_match_score: + best_match_score = score + best_match_idx = idx + + if best_match_score >= 0.3: + logger.info(f"模糊匹配: {field_name} -> {headers[best_match_idx]} (分数: {best_match_score:.2f})") + return best_match_idx + + return None def _extract_column_values(self, rows: List, columns: List, field_name: str) -> List[str]: """ 从 rows 和 columns 中提取指定列的值 @@ -1012,27 +1396,25 @@ class TemplateFillService: if not rows or not columns: return [] - # 查找匹配的列(模糊匹配) - target_col = None - for col in columns: - col_str = str(col) - if field_name.lower() in col_str.lower() or col_str.lower() in field_name.lower(): - target_col = col - break + # 使用增强的匹配算法查找最佳匹配的列索引 + target_idx = self._find_best_matching_column(columns, field_name) - if not target_col: + if target_idx is None: logger.warning(f"未找到匹配列: {field_name}, 可用列: {columns}") return [] + target_col = columns[target_idx] + logger.info(f"列匹配成功: {field_name} -> {target_col} (索引: {target_idx})") + values = [] for row in rows: if isinstance(row, dict): val = row.get(target_col, "") - elif isinstance(row, list) and target_col in columns: - val = row[columns.index(target_col)] + elif isinstance(row, list) and target_idx < len(row): + val = row[target_idx] else: val = "" - values.append(str(val) if val is not None else "") + values.append(self._format_value(val)) return values @@ -1046,7 +1428,6 @@ class TemplateFillService: Returns: (值列表, 置信度) 元组 """ - # 提取置信度 confidence = 0.5 if isinstance(result, dict) and "confidence" in result: try: @@ -1057,28 +1438,25 @@ class TemplateFillService: pass if isinstance(result, dict): - # 优先找 values 数组 if "values" in result and isinstance(result["values"], list): - vals = [str(v).strip() for v in result["values"] if v and str(v).strip()] + vals = [self._format_value(v).strip() for v in result["values"] if self._format_value(v).strip()] if vals: return vals, confidence - # 尝试找 value 字段 if "value" in result: - val = str(result["value"]).strip() + val = self._format_value(result["value"]).strip() if val: return [val], confidence - # 尝试找任何数组类型的键 for key in result.keys(): val = result[key] if isinstance(val, list) and len(val) > 0: if all(isinstance(v, (str, int, float, bool)) or v is None for v in val): - vals = [str(v).strip() for v in val if v is not None and str(v).strip()] + vals = [self._format_value(v).strip() for v in val if v is not None and self._format_value(v).strip()] if vals: return vals, confidence elif isinstance(val, (str, int, float, bool)): - return [str(val).strip()], confidence + return [self._format_value(val).strip()], confidence elif isinstance(result, list): - vals = [str(v).strip() for v in result if v is not None and str(v).strip()] + vals = [self._format_value(v).strip() for v in result if v is not None and self._format_value(v).strip()] if vals: return vals, confidence return [], confidence @@ -1215,15 +1593,15 @@ class TemplateFillService: if isinstance(parsed, dict): # 如果是 {"values": [...]} 格式,提取 values if "values" in parsed and isinstance(parsed["values"], list): - return [str(v).strip() for v in parsed["values"] if v and str(v).strip()] + return [self._format_value(v).strip() for v in parsed["values"] if self._format_value(v).strip()] # 如果是其他 dict 格式,尝试找 values 键 for key in ["values", "value", "data", "result"]: if key in parsed and isinstance(parsed[key], list): - return [str(v).strip() for v in parsed[key] if v and str(v).strip()] + return [self._format_value(v).strip() for v in parsed[key] if self._format_value(v).strip()] elif key in parsed: - return [str(parsed[key]).strip()] + return [self._format_value(parsed[key]).strip()] elif isinstance(parsed, list): - return [str(v).strip() for v in parsed if v and str(v).strip()] + return [self._format_value(v).strip() for v in parsed if self._format_value(v).strip()] except (json.JSONDecodeError, TypeError): pass @@ -1239,14 +1617,14 @@ class TemplateFillService: result = [] for item in arr: if isinstance(item, dict) and "values" in item and isinstance(item["values"], list): - result.extend([str(v).strip() for v in item["values"] if v and str(v).strip()]) + result.extend([self._format_value(v).strip() for v in item["values"] if self._format_value(v).strip()]) elif isinstance(item, dict): result.append(str(item)) else: - result.append(str(item)) + result.append(self._format_value(item)) if result: return result - return [str(v).strip() for v in arr if v and str(v).strip()] + return [self._format_value(v).strip() for v in arr if self._format_value(v).strip()] except: pass @@ -1337,27 +1715,37 @@ class TemplateFillService: hint_text = f"{user_hint}。{hint_text}" # 构建针对字段提取的提示词 - prompt = f"""你是一个专业的数据提取专家。请从以下文档内容中提取与"{field.name}"相关的所有数据。 + prompt = f"""你是一个专业的数据提取专家。请从以下文档内容中提取与"{field.name}"完全匹配的数据。 -字段提示: {hint_text} +【重要】字段名: "{field.name}" +【重要】字段提示: {hint_text} + +请严格按照以下步骤操作: +1. 在文档中搜索与"{field.name}"完全相同或高度相关的关键词 +2. 找到后,提取该关键词后的数值(注意:只要数值,不要单位) +3. 如果是表格中的数据,直接提取该单元格的数值 +4. 如果是段落描述,在关键词附近找数值 + +【重要】返回值规则: +- 只返回纯数值,不要单位(如 "4.9" 而不是 "4.9万亿元") +- 如果原文是"4.9万亿元",返回 "4.9" +- 如果原文是"144000万册",返回 "144000" +- 如果是百分比如"增长7.7%",返回 "7.7" +- 如果没有找到完全匹配的数据,返回空数组 文档内容: -{doc.content[:8000] if doc.content else ""} - -请完成以下任务: -1. 仔细阅读文档,找出所有与"{field.name}"相关的数据 -2. 如果文档中有表格数据,提取表格中的对应列值 -3. 如果文档中是段落描述,提取其中的关键数值或结论 -4. 返回提取的所有值(可能多个,用数组存储) +{doc.content[:10000] if doc.content else ""} 请用严格的 JSON 格式返回: {{ - "values": ["值1", "值2", ...], + "values": ["值1", "值2", ...], // 只填数值,不要单位 "source": "数据来源说明", "confidence": 0.0到1.0之间的置信度 }} -如果没有找到相关数据,返回空数组 values: []""" +示例: +- 如果字段是"图书馆总藏量(万册)"且文档说"图书总藏量14.4亿册",返回 values: ["144000"] +- 如果字段是"国内旅游收入(亿元)"且文档说"国内旅游收入4.9万亿元",返回 values: ["49000"]""" messages = [ {"role": "system", "content": "你是一个专业的数据提取助手,擅长从政府统计公报等文档中提取数据。请严格按JSON格式输出。"}, @@ -1367,7 +1755,7 @@ class TemplateFillService: response = await self.llm.chat( messages=messages, temperature=0.1, - max_tokens=5000 + max_tokens=4000 ) content = self.llm.extract_message_content(response) @@ -1434,7 +1822,8 @@ class TemplateFillService: async def _generate_fields_with_ai( self, file_path: str, - file_type: str + file_type: str, + source_contents: List[dict] = None ) -> Optional[List[TemplateField]]: """ 使用 AI 为空表生成表头字段 @@ -1454,28 +1843,35 @@ class TemplateFillService: content_sample = "" # 读取 Excel 内容检查是否为空 + content_sample = "" if file_type in ["xlsx", "xls"]: df = pd.read_excel(file_path, header=None) if df.shape[0] == 0 or df.shape[1] == 0: logger.info("Excel 表格为空") - # 生成默认字段 - return [TemplateField( - cell=self._column_to_cell(i), - name=f"字段{i+1}", - field_type="text", - required=False, - hint="请填写此字段" - ) for i in range(5)] - - # 表格有数据但没有表头 - if df.shape[1] > 0: - # 读取第一行作为参考,看是否为空 - first_row = df.iloc[0].tolist() if len(df) > 0 else [] - if not any(pd.notna(v) and str(v).strip() != '' for v in first_row): - # 第一行为空,AI 生成表头 - content_sample = df.iloc[:10].to_string() if len(df) >= 10 else df.to_string() + # 即使 Excel 为空,如果有源文档,仍然尝试使用 AI 生成表头 + if not source_contents: + logger.info("Excel 为空且没有源文档,使用默认字段名") + return [TemplateField( + cell=self._column_to_cell(i), + name=f"字段{i+1}", + field_type="text", + required=False, + hint="请填写此字段" + ) for i in range(5)] + # 有源文档,继续调用 AI 生成表头 + logger.info("Excel 为空但有源文档,使用源文档内容生成表头...") + else: + # 表格有数据但没有表头 + if df.shape[1] > 0: + # 读取第一行作为参考,看是否为空 + first_row = df.iloc[0].tolist() if len(df) > 0 else [] + if not any(pd.notna(v) and str(v).strip() != '' for v in first_row): + # 第一行为空,AI 生成表头 + content_sample = df.iloc[:10].to_string() if len(df) >= 10 else df.to_string() + else: + content_sample = df.to_string() else: - content_sample = df.to_string() + content_sample = "" elif file_type == "docx": # Word 文档:尝试使用 docx_parser 提取内容 @@ -1506,21 +1902,56 @@ class TemplateFillService: return None # 调用 AI 生成表头 - prompt = f"""你是一个专业的表格设计助手。请为以下空白表格生成合适的表头字段。 + # 根据源文档内容生成表头 + source_info = "" + logger.info(f"[DEBUG] _generate_fields_with_ai received source_contents: {len(source_contents) if source_contents else 0} items") + if source_contents: + for sc in source_contents: + logger.info(f"[DEBUG] source doc: filename={sc.get('filename')}, content_len={len(sc.get('content', ''))}, titles={len(sc.get('titles', []))}, tables_count={sc.get('tables_count', 0)}, has_tables_summary={bool(sc.get('tables_summary'))}") + source_info = "\n\n【源文档内容摘要】(根据以下文档内容生成表头):\n" + for idx, src in enumerate(source_contents[:5]): # 最多5个源文档 + filename = src.get("filename", f"文档{idx+1}") + doc_type = src.get("doc_type", "unknown") + content = src.get("content", "")[:3000] # 限制内容长度 + titles = src.get("titles", [])[:10] # 最多10个标题 + tables_count = src.get("tables_count", 0) + tables_summary = src.get("tables_summary", "") -表格内容预览: -{content_sample[:2000] if content_sample else "空白表格"} + source_info += f"\n--- 文档 {idx+1}: {filename} ({doc_type}) ---\n" + # 处理 titles(可能是字符串列表或字典列表) + if titles: + title_texts = [] + for t in titles[:5]: + if isinstance(t, dict): + title_texts.append(t.get('text', '')) + else: + title_texts.append(str(t)) + if title_texts: + source_info += f"【章节标题】: {', '.join(title_texts)}\n" + if tables_count > 0: + source_info += f"【包含表格数】: {tables_count}\n" + if tables_summary: + source_info += f"{tables_summary}\n" + elif content: + source_info += f"【内容预览】: {content[:1500]}...\n" -请生成5-10个简洁的表头字段名,这些字段应该: -1. 简洁明了,易于理解 -2. 适合作为表格列标题 -3. 之间有明显的区分度 + prompt = f"""你是一个专业的表格设计助手。请根据源文档内容生成合适的表格表头字段。 + +任务:用户有一些源文档(包含表格数据),需要填写到空白表格模板中。源文档中的表格如下: + +{source_info} + +【重要要求】 +1. 请仔细阅读上面的源文档表格,找出所有不同的列名(如"产品名称"、"1995年产量"、"按资产总额计算(%)"等) +2. 直接使用这些实际的列名作为表头字段名,不要生成新的或同义词 +3. 如果一个源文档有多个表格,请为每个表格选择合适的列名 +4. 生成3-8个表头字段,优先选择数据量大的表格的列 请严格按照以下 JSON 格式输出(只需输出 JSON,不要其他内容): {{ "fields": [ - {{"name": "字段名1", "hint": "字段说明提示1"}}, - {{"name": "字段名2", "hint": "字段说明提示2"}} + {{"name": "实际列名1", "hint": "对该列的说明"}}, + {{"name": "实际列名2", "hint": "对该列的说明"}} ] }} """ diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index e764335..44ccbb5 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,5 +1,5 @@ import { RouterProvider } from 'react-router-dom'; -import { AuthProvider } from '@/context/AuthContext'; +import { AuthProvider } from '@/contexts/AuthContext'; import { TemplateFillProvider } from '@/context/TemplateFillContext'; import { router } from '@/routes'; import { Toaster } from 'sonner'; diff --git a/frontend/src/components/common/RouteGuard.tsx b/frontend/src/components/common/RouteGuard.tsx index 0b691e0..8a4288b 100644 --- a/frontend/src/components/common/RouteGuard.tsx +++ b/frontend/src/components/common/RouteGuard.tsx @@ -1,6 +1,6 @@ import React from 'react'; import { Navigate, useLocation } from 'react-router-dom'; -import { useAuth } from '@/context/AuthContext'; +import { useAuth } from '@/contexts/AuthContext'; export const RouteGuard: React.FC<{ children: React.ReactNode }> = ({ children }) => { const { user, loading } = useAuth(); diff --git a/frontend/src/context/AuthContext.tsx b/frontend/src/context/AuthContext.tsx deleted file mode 100644 index 524dc8d..0000000 --- a/frontend/src/context/AuthContext.tsx +++ /dev/null @@ -1,85 +0,0 @@ -import React, { createContext, useContext, useEffect, useState } from 'react'; -import { supabase } from '@/db/supabase'; -import { User } from '@supabase/supabase-js'; -import { Profile } from '@/types/types'; - -interface AuthContextType { - user: User | null; - profile: Profile | null; - signIn: (email: string, password: string) => Promise<{ error: any }>; - signUp: (email: string, password: string) => Promise<{ error: any }>; - signOut: () => Promise<{ error: any }>; - loading: boolean; -} - -const AuthContext = createContext(undefined); - -export const AuthProvider: React.FC<{ children: React.ReactNode }> = ({ children }) => { - const [user, setUser] = useState(null); - const [profile, setProfile] = useState(null); - const [loading, setLoading] = useState(true); - - useEffect(() => { - // Check active sessions and sets the user - supabase.auth.getSession().then(({ data: { session } }) => { - setUser(session?.user ?? null); - if (session?.user) fetchProfile(session.user.id); - else setLoading(false); - }); - - // Listen for changes on auth state (sign in, sign out, etc.) - const { data: { subscription } } = supabase.auth.onAuthStateChange((_event, session) => { - setUser(session?.user ?? null); - if (session?.user) fetchProfile(session.user.id); - else { - setProfile(null); - setLoading(false); - } - }); - - return () => subscription.unsubscribe(); - }, []); - - const fetchProfile = async (uid: string) => { - try { - const { data, error } = await supabase - .from('profiles') - .select('*') - .eq('id', uid) - .maybeSingle(); - - if (error) throw error; - setProfile(data); - } catch (err) { - console.error('Error fetching profile:', err); - } finally { - setLoading(false); - } - }; - - const signIn = async (email: string, password: string) => { - return await supabase.auth.signInWithPassword({ email, password }); - }; - - const signUp = async (email: string, password: string) => { - return await supabase.auth.signUp({ email, password }); - }; - - const signOut = async () => { - return await supabase.auth.signOut(); - }; - - return ( - - {children} - - ); -}; - -export const useAuth = () => { - const context = useContext(AuthContext); - if (context === undefined) { - throw new Error('useAuth must be used within an AuthProvider'); - } - return context; -}; diff --git a/frontend/src/context/TemplateFillContext.tsx b/frontend/src/context/TemplateFillContext.tsx index 76ba073..61ef55d 100644 --- a/frontend/src/context/TemplateFillContext.tsx +++ b/frontend/src/context/TemplateFillContext.tsx @@ -21,6 +21,7 @@ interface TemplateFillState { templateFields: TemplateField[]; sourceFiles: SourceFile[]; sourceFilePaths: string[]; + sourceDocIds: string[]; templateId: string; filledResult: any; setStep: (step: Step) => void; @@ -30,6 +31,9 @@ interface TemplateFillState { addSourceFiles: (files: SourceFile[]) => void; removeSourceFile: (index: number) => void; setSourceFilePaths: (paths: string[]) => void; + setSourceDocIds: (ids: string[]) => void; + addSourceDocId: (id: string) => void; + removeSourceDocId: (id: string) => void; setTemplateId: (id: string) => void; setFilledResult: (result: any) => void; reset: () => void; @@ -41,6 +45,7 @@ const initialState = { templateFields: [], sourceFiles: [], sourceFilePaths: [], + sourceDocIds: [], templateId: '', filledResult: null, setStep: () => {}, @@ -50,6 +55,9 @@ const initialState = { addSourceFiles: () => {}, removeSourceFile: () => {}, setSourceFilePaths: () => {}, + setSourceDocIds: () => {}, + addSourceDocId: () => {}, + removeSourceDocId: () => {}, setTemplateId: () => {}, setFilledResult: () => {}, reset: () => {}, @@ -63,6 +71,7 @@ export const TemplateFillProvider: React.FC<{ children: ReactNode }> = ({ childr const [templateFields, setTemplateFields] = useState([]); const [sourceFiles, setSourceFiles] = useState([]); const [sourceFilePaths, setSourceFilePaths] = useState([]); + const [sourceDocIds, setSourceDocIds] = useState([]); const [templateId, setTemplateId] = useState(''); const [filledResult, setFilledResult] = useState(null); @@ -74,12 +83,21 @@ export const TemplateFillProvider: React.FC<{ children: ReactNode }> = ({ childr setSourceFiles(prev => prev.filter((_, i) => i !== index)); }; + const addSourceDocId = (id: string) => { + setSourceDocIds(prev => prev.includes(id) ? prev : [...prev, id]); + }; + + const removeSourceDocId = (id: string) => { + setSourceDocIds(prev => prev.filter(docId => docId !== id)); + }; + const reset = () => { setStep('upload'); setTemplateFile(null); setTemplateFields([]); setSourceFiles([]); setSourceFilePaths([]); + setSourceDocIds([]); setTemplateId(''); setFilledResult(null); }; @@ -92,6 +110,7 @@ export const TemplateFillProvider: React.FC<{ children: ReactNode }> = ({ childr templateFields, sourceFiles, sourceFilePaths, + sourceDocIds, templateId, filledResult, setStep, @@ -101,6 +120,9 @@ export const TemplateFillProvider: React.FC<{ children: ReactNode }> = ({ childr addSourceFiles, removeSourceFile, setSourceFilePaths, + setSourceDocIds, + addSourceDocId, + removeSourceDocId, setTemplateId, setFilledResult, reset, diff --git a/frontend/src/db/backend-api.ts b/frontend/src/db/backend-api.ts index 6f218dd..75f9c68 100644 --- a/frontend/src/db/backend-api.ts +++ b/frontend/src/db/backend-api.ts @@ -400,6 +400,49 @@ export const backendApi = { } }, + /** + * 获取任务历史列表 + */ + async getTasks( + limit: number = 50, + skip: number = 0 + ): Promise<{ success: boolean; tasks: any[]; count: number }> { + const url = `${BACKEND_BASE_URL}/tasks?limit=${limit}&skip=${skip}`; + + try { + const response = await fetch(url); + if (!response.ok) { + const error = await response.json(); + throw new Error(error.detail || '获取任务列表失败'); + } + return await response.json(); + } catch (error) { + console.error('获取任务列表失败:', error); + throw error; + } + }, + + /** + * 删除任务 + */ + async deleteTask(taskId: string): Promise<{ success: boolean; deleted: boolean }> { + const url = `${BACKEND_BASE_URL}/tasks/${taskId}`; + + try { + const response = await fetch(url, { + method: 'DELETE' + }); + if (!response.ok) { + const error = await response.json(); + throw new Error(error.detail || '删除任务失败'); + } + return await response.json(); + } catch (error) { + console.error('删除任务失败:', error); + throw error; + } + }, + /** * 轮询任务状态直到完成 */ @@ -1180,7 +1223,7 @@ export const aiApi = { try { const response = await fetch(url, { - method: 'GET', + method: 'POST', body: formData, }); diff --git a/frontend/src/pages/Documents.tsx b/frontend/src/pages/Documents.tsx index d0d9c2e..afeb54d 100644 --- a/frontend/src/pages/Documents.tsx +++ b/frontend/src/pages/Documents.tsx @@ -1,4 +1,4 @@ -import React, { useState, useEffect, useCallback } from 'react'; +import React, { useState, useEffect, useCallback, useRef } from 'react'; import { useDropzone } from 'react-dropzone'; import { FileText, @@ -23,7 +23,8 @@ import { List, MessageSquareCode, Tag, - HelpCircle + HelpCircle, + Plus } from 'lucide-react'; import { Button } from '@/components/ui/button'; import { Input } from '@/components/ui/input'; @@ -72,8 +73,10 @@ const Documents: React.FC = () => { // 上传相关状态 const [uploading, setUploading] = useState(false); const [uploadedFile, setUploadedFile] = useState(null); + const [uploadedFiles, setUploadedFiles] = useState([]); const [parseResult, setParseResult] = useState(null); const [expandedSheet, setExpandedSheet] = useState(null); + const [uploadExpanded, setUploadExpanded] = useState(false); // AI 分析相关状态 const [analyzing, setAnalyzing] = useState(false); @@ -210,75 +213,119 @@ const Documents: React.FC = () => { // 文件上传处理 const onDrop = async (acceptedFiles: File[]) => { - const file = acceptedFiles[0]; - if (!file) return; + if (acceptedFiles.length === 0) return; - setUploadedFile(file); setUploading(true); - setParseResult(null); - setAiAnalysis(null); - setAnalysisCharts(null); - setExpandedSheet(null); - setMdAnalysis(null); - setMdSections([]); - setMdStreamingContent(''); + let successCount = 0; + let failCount = 0; + const successfulFiles: File[] = []; - const ext = file.name.split('.').pop()?.toLowerCase(); + // 逐个上传文件 + for (const file of acceptedFiles) { + const ext = file.name.split('.').pop()?.toLowerCase(); - try { - // Excel 文件使用专门的上传接口 - if (ext === 'xlsx' || ext === 'xls') { - const result = await backendApi.uploadExcel(file, { - parseAllSheets: parseOptions.parseAllSheets, - headerRow: parseOptions.headerRow - }); - if (result.success) { - toast.success(`解析成功: ${file.name}`); - setParseResult(result); - loadDocuments(); // 刷新文档列表 - if (result.metadata?.sheet_count === 1) { - setExpandedSheet(Object.keys(result.data?.sheets || {})[0] || null); + try { + if (ext === 'xlsx' || ext === 'xls') { + const result = await backendApi.uploadExcel(file, { + parseAllSheets: parseOptions.parseAllSheets, + headerRow: parseOptions.headerRow + }); + if (result.success) { + successCount++; + successfulFiles.push(file); + // 第一个Excel文件设置解析结果供预览 + if (successCount === 1) { + setUploadedFile(file); + setParseResult(result); + if (result.metadata?.sheet_count === 1) { + setExpandedSheet(Object.keys(result.data?.sheets || {})[0] || null); + } + } + loadDocuments(); + } else { + failCount++; + toast.error(`${file.name}: ${result.error || '解析失败'}`); + } + } else if (ext === 'md' || ext === 'markdown') { + const result = await backendApi.uploadDocument(file); + if (result.task_id) { + successCount++; + successfulFiles.push(file); + if (successCount === 1) { + setUploadedFile(file); + } + // 轮询任务状态 + let attempts = 0; + const checkStatus = async () => { + while (attempts < 30) { + try { + const status = await backendApi.getTaskStatus(result.task_id); + if (status.status === 'success') { + loadDocuments(); + return; + } else if (status.status === 'failure') { + return; + } + } catch (e) { + console.error('检查状态失败', e); + } + await new Promise(resolve => setTimeout(resolve, 2000)); + attempts++; + } + }; + checkStatus(); + } else { + failCount++; } } else { - toast.error(result.error || '解析失败'); - } - } else if (ext === 'md' || ext === 'markdown') { - // Markdown 文件:获取大纲 - await fetchMdOutline(); - } else { - // 其他文档使用通用上传接口 - const result = await backendApi.uploadDocument(file); - if (result.task_id) { - toast.success(`文件 ${file.name} 已提交处理`); - // 轮询任务状态 - let attempts = 0; - const checkStatus = async () => { - while (attempts < 30) { - try { - const status = await backendApi.getTaskStatus(result.task_id); - if (status.status === 'success') { - toast.success(`文件 ${file.name} 处理完成`); - loadDocuments(); - return; - } else if (status.status === 'failure') { - toast.error(`文件 ${file.name} 处理失败`); - return; - } - } catch (e) { - console.error('检查状态失败', e); - } - await new Promise(resolve => setTimeout(resolve, 2000)); - attempts++; + // 其他文档使用通用上传接口 + const result = await backendApi.uploadDocument(file); + if (result.task_id) { + successCount++; + successfulFiles.push(file); + if (successCount === 1) { + setUploadedFile(file); } - toast.error(`文件 ${file.name} 处理超时`); - }; - checkStatus(); + // 轮询任务状态 + let attempts = 0; + const checkStatus = async () => { + while (attempts < 30) { + try { + const status = await backendApi.getTaskStatus(result.task_id); + if (status.status === 'success') { + loadDocuments(); + return; + } else if (status.status === 'failure') { + return; + } + } catch (e) { + console.error('检查状态失败', e); + } + await new Promise(resolve => setTimeout(resolve, 2000)); + attempts++; + } + }; + checkStatus(); + } else { + failCount++; + } } + } catch (error: any) { + failCount++; + toast.error(`${file.name}: ${error.message || '上传失败'}`); } - } catch (error: any) { - toast.error(error.message || '上传失败'); - } finally { - setUploading(false); + } + + setUploading(false); + loadDocuments(); + + if (successCount > 0) { + toast.success(`成功上传 ${successCount} 个文件`); + setUploadedFiles(prev => [...prev, ...successfulFiles]); + setUploadExpanded(true); + } + if (failCount > 0) { + toast.error(`${failCount} 个文件上传失败`); } }; @@ -291,7 +338,7 @@ const Documents: React.FC = () => { 'text/markdown': ['.md'], 'text/plain': ['.txt'] }, - maxFiles: 1 + multiple: true }); // AI 分析处理 @@ -449,6 +496,7 @@ const Documents: React.FC = () => { const handleDeleteFile = () => { setUploadedFile(null); + setUploadedFiles([]); setParseResult(null); setAiAnalysis(null); setAnalysisCharts(null); @@ -456,6 +504,17 @@ const Documents: React.FC = () => { toast.success('文件已清除'); }; + const handleRemoveUploadedFile = (index: number) => { + setUploadedFiles(prev => { + const newFiles = prev.filter((_, i) => i !== index); + if (newFiles.length === 0) { + setUploadedFile(null); + } + return newFiles; + }); + toast.success('文件已从列表移除'); + }; + const handleDelete = async (docId: string) => { try { const result = await backendApi.deleteDocument(docId); @@ -615,7 +674,7 @@ const Documents: React.FC = () => {

文档中心

上传文档,自动解析并使用 AI 进行深度分析

- @@ -640,7 +699,82 @@ const Documents: React.FC = () => { {uploadPanelOpen && ( - {!uploadedFile ? ( + {uploadedFiles.length > 0 || uploadedFile ? ( +
+ {/* 文件列表头部 */} +
setUploadExpanded(!uploadExpanded)} + > +
+
+ +
+
+

+ 已上传 {(uploadedFiles.length > 0 ? uploadedFiles : [uploadedFile]).length} 个文件 +

+

+ {uploadExpanded ? '点击收起' : '点击展开查看'} +

+
+
+
+ + {uploadExpanded ? : } +
+
+ + {/* 展开的文件列表 */} + {uploadExpanded && ( +
+ {(uploadedFiles.length > 0 ? uploadedFiles : [uploadedFile]).filter(Boolean).map((file, index) => ( +
+
+ {isExcelFile(file?.name || '') ? : } +
+
+

{file?.name}

+

{formatFileSize(file?.size || 0)}

+
+ +
+ ))} + + {/* 继续添加按钮 */} +
+ + + 继续添加更多文件 +
+
+ )} +
+ ) : (
{ uploading && "opacity-50 pointer-events-none" )} > - +
{uploading ? : }
@@ -671,30 +805,6 @@ const Documents: React.FC = () => {
- ) : ( -
-
-
- {isExcelFile(uploadedFile.name) ? : } -
-
-

{uploadedFile.name}

-

{formatFileSize(uploadedFile.size)}

-
- -
- - {isExcelFile(uploadedFile.name) && ( - - )} -
)}
)} diff --git a/frontend/src/pages/ExcelParse.tsx b/frontend/src/pages/ExcelParse.tsx deleted file mode 100644 index 8556025..0000000 --- a/frontend/src/pages/ExcelParse.tsx +++ /dev/null @@ -1,1015 +0,0 @@ -import React, { useState, useEffect } from 'react'; -import { useDropzone } from 'react-dropzone'; -import { - FileSpreadsheet, - Upload, - Trash2, - ChevronDown, - ChevronUp, - Table, - Info, - CheckCircle, - AlertCircle, - Loader2, - Sparkles, - FileText, - TrendingUp, - Download, - Brain, - Check, - X -} from 'lucide-react'; -import { Button } from '@/components/ui/button'; -import { Card, CardContent, CardHeader, CardTitle, CardDescription } from '@/components/ui/card'; -import { Badge } from '@/components/ui/badge'; -import { Input } from '@/components/ui/input'; -import { Label } from '@/components/ui/label'; -import { Switch } from '@/components/ui/switch'; -import { Textarea } from '@/components/ui/textarea'; -import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/select'; -import { Checkbox } from '@/components/ui/checkbox'; -import { toast } from 'sonner'; -import { cn } from '@/lib/utils'; -import { backendApi, type ExcelParseResult, type ExcelUploadOptions, aiApi } from '@/db/backend-api'; -import { - Table as TableComponent, - TableBody, - TableCell, - TableHead, - TableHeader, - TableRow, -} from '@/components/ui/table'; -import { Markdown } from '@/components/ui/markdown'; -import { AIChartDisplay } from '@/components/ui/ai-chart-display'; -import { Dialog, DialogContent, DialogHeader, DialogTitle, DialogFooter } from '@/components/ui/dialog'; - -const ExcelParse: React.FC = () => { - const [loading, setLoading] = useState(false); - const [analyzing, setAnalyzing] = useState(false); - const [analyzingForCharts, setAnalyzingForCharts] = useState(false); - const [exporting, setExporting] = useState(false); - const [parseResult, setParseResult] = useState(null); - const [aiAnalysis, setAiAnalysis] = useState(null); - const [analysisCharts, setAnalysisCharts] = useState(null); - const [uploadedFile, setUploadedFile] = useState(null); - const [expandedSheet, setExpandedSheet] = useState(null); - const [parseOptions, setParseOptions] = useState({ - parseAllSheets: false, - headerRow: 0 - }); - const [aiOptions, setAiOptions] = useState({ - userPrompt: '', - analysisType: 'general' as 'general' | 'summary' | 'statistics' | 'insights', - parseAllSheetsForAI: false - }); - const [analysisTypes, setAnalysisTypes] = useState>([]); - - // 导出相关状态 - const [exportDialogOpen, setExportDialogOpen] = useState(false); - const [selectedSheet, setSelectedSheet] = useState(''); - const [selectedColumns, setSelectedColumns] = useState>(new Set()); - const [selectAll, setSelectAll] = useState(false); - - // 获取支持的分析类型 - useEffect(() => { - aiApi.getAnalysisTypes() - .then(data => setAnalysisTypes(data.types)) - .catch(() => { - setAnalysisTypes([ - { value: 'general', label: '综合分析', description: '提供数据概览、关键发现、质量评估和建议' }, - { value: 'summary', label: '数据摘要', description: '快速了解数据的结构、范围和主要内容' }, - { value: 'statistics', label: '统计分析', description: '数值型列的统计信息和分类列的分布' }, - { value: 'insights', label: '深度洞察', description: '深入挖掘数据,提供异常值和业务建议' } - ]); - }); - }, []); - - const onDrop = async (acceptedFiles: File[]) => { - const file = acceptedFiles[0]; - if (!file) return; - - if (!file.name.match(/\.(xlsx|xls)$/i)) { - toast.error('仅支持 .xlsx 和 .xls 格式的 Excel 文件'); - return; - } - - setUploadedFile(file); - setLoading(true); - setParseResult(null); - setAiAnalysis(null); - setAnalysisCharts(null); - setExpandedSheet(null); - - try { - const result = await backendApi.uploadExcel(file, parseOptions); - - if (result.success) { - toast.success(`解析成功: ${file.name}`); - setParseResult(result); - // 自动展开第一个工作表 - if (result.metadata?.sheet_count === 1) { - setExpandedSheet(null); - } - } else { - toast.error(result.error || '解析失败'); - } - } catch (error: any) { - toast.error(error.message || '上传失败'); - } finally { - setLoading(false); - } - }; - - const handleAnalyze = async () => { - if (!uploadedFile || !parseResult?.success) { - toast.error('请先上传并解析 Excel 文件'); - return; - } - - setAnalyzing(true); - setAiAnalysis(null); - setAnalysisCharts(null); - - try { - const result = await aiApi.analyzeExcel(uploadedFile, { - userPrompt: aiOptions.userPrompt, - analysisType: aiOptions.analysisType, - parseAllSheets: aiOptions.parseAllSheetsForAI - }); - - if (result.success) { - toast.success('AI 分析完成'); - setAiAnalysis(result); - } else { - toast.error(result.error || 'AI 分析失败'); - } - } catch (error: any) { - toast.error(error.message || 'AI 分析失败'); - } finally { - setAnalyzing(false); - } - }; - - const handleGenerateChartsFromAnalysis = async () => { - if (!aiAnalysis || !aiAnalysis.success) { - toast.error('请先进行 AI 分析'); - return; - } - - // 提取 AI 分析文本 - let analysisText = ''; - - if (aiAnalysis.analysis?.analysis) { - analysisText = aiAnalysis.analysis.analysis; - } else if (aiAnalysis.analysis?.sheets) { - // 多工作表模式,合并所有工作表的分析结果 - const sheetAnalyses = aiAnalysis.analysis.sheets; - if (sheetAnalyses && Object.keys(sheetAnalyses).length > 0) { - const firstSheet = Object.keys(sheetAnalyses)[0]; - analysisText = sheetAnalyses[firstSheet]?.analysis || ''; - } - } - - if (!analysisText || !analysisText.trim()) { - toast.error('无法获取 AI 分析结果'); - return; - } - - setAnalyzingForCharts(true); - setAnalysisCharts(null); - - try { - const result = await aiApi.extractAndGenerateCharts({ - analysis_text: analysisText, - original_filename: uploadedFile?.name || 'unknown', - file_type: 'excel' - }); - - if (result.success) { - toast.success('基于 AI 分析的图表生成完成'); - setAnalysisCharts(result); - } else { - toast.error(result.error || '图表生成失败'); - } - } catch (error: any) { - toast.error(error.message || '图表生成失败'); - } finally { - setAnalyzingForCharts(false); - } - }; - - // 获取工作表数据 - const getSheetData = (sheetName: string) => { - if (!parseResult?.success || !parseResult.data) return null; - - const data = parseResult.data; - - // 多工作表模式 - if (data.sheets && data.sheets[sheetName]) { - return data.sheets[sheetName]; - } - - // 单工作表模式 - if (!data.sheets && data.columns && data.rows) { - return data; - } - - return null; - }; - - // 打开导出对话框 - const openExportDialog = () => { - if (!parseResult?.success || !parseResult.data) { - toast.error('请先上传并解析 Excel 文件'); - return; - } - - const data = parseResult.data; - - // 获取所有工作表 - let sheets: string[] = []; - if (data.sheets) { - sheets = Object.keys(data.sheets); - } else { - sheets = ['默认工作表']; - } - - setSelectedSheet(sheets[0]); - const sheetColumns = getSheetData(sheets[0])?.columns || []; - setSelectedColumns(new Set(sheetColumns)); - setSelectAll(true); - setExportDialogOpen(true); - }; - - // 处理列选择 - const toggleColumn = (column: string) => { - const newSelected = new Set(selectedColumns); - if (newSelected.has(column)) { - newSelected.delete(column); - } else { - newSelected.add(column); - } - setSelectedColumns(newSelected); - setSelectAll(newSelected.size === (getSheetData(selectedSheet)?.columns || []).length); - }; - - // 全选/取消全选 - const toggleSelectAll = () => { - const sheetColumns = getSheetData(selectedSheet)?.columns || []; - if (selectAll) { - setSelectedColumns(new Set()); - } else { - setSelectedColumns(new Set(sheetColumns)); - } - setSelectAll(!selectAll); - }; - - // 执行导出 - const handleExport = async () => { - if (selectedColumns.size === 0) { - toast.error('请至少选择一列'); - return; - } - - if (!parseResult?.metadata?.saved_path) { - toast.error('无法获取文件路径'); - return; - } - - setExporting(true); - - try { - const blob = await backendApi.exportExcel( - parseResult.metadata.saved_path, - { - columns: Array.from(selectedColumns), - sheetName: selectedSheet === '默认工作表' ? undefined : selectedSheet - } - ); - - const url = URL.createObjectURL(blob); - const link = document.createElement('a'); - link.href = url; - link.download = `export_${selectedSheet}_${uploadedFile?.name || 'data.xlsx'}`; - document.body.appendChild(link); - link.click(); - document.body.removeChild(link); - URL.revokeObjectURL(url); - - toast.success('导出成功'); - setExportDialogOpen(false); - } catch (error: any) { - toast.error(error.message || '导出失败'); - } finally { - setExporting(false); - } - }; - - const { getRootProps, getInputProps, isDragActive } = useDropzone({ - onDrop, - accept: { - 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': ['.xlsx'], - 'application/vnd.ms-excel': ['.xls'] - }, - maxFiles: 1 - }); - - const handleDeleteFile = () => { - setUploadedFile(null); - setParseResult(null); - setAiAnalysis(null); - setAnalysisCharts(null); - setExpandedSheet(null); - toast.success('文件已清除'); - }; - - const formatFileSize = (bytes: number): string => { - if (bytes === 0) return '0 B'; - const k = 1024; - const sizes = ['B', 'KB', 'MB', 'GB']; - const i = Math.floor(Math.log(bytes) / Math.log(k)); - return `${(bytes / Math.pow(k, i)).toFixed(2)} ${sizes[i]}`; - }; - - const getAnalysisIcon = (type: string) => { - switch (type) { - case 'general': - return ; - case 'summary': - return ; - case 'statistics': - return ; - case 'insights': - return ; - default: - return ; - } - }; - - const downloadAnalysis = () => { - if (!aiAnalysis?.analysis?.analysis) return; - - const content = aiAnalysis.analysis.analysis; - const blob = new Blob([content], { type: 'text/plain;charset=utf-8' }); - const url = URL.createObjectURL(blob); - const link = document.createElement('a'); - link.href = url; - link.download = `AI分析结果_${uploadedFile?.name || 'excel'}.txt`; - link.click(); - URL.revokeObjectURL(url); - toast.success('分析结果已下载'); - }; - - return ( -
-
-
-

- - Excel 智能分析工具 -

-

上传 Excel 文件,使用 AI 进行深度数据分析。

-
-
- -
- {/* 左侧:上传区域 */} -
- {/* 上传卡片 */} - - - - - 文件上传 - - - 拖拽或点击上传 Excel 文件 - - - - {!uploadedFile ? ( -
- -
- {loading ? : } -
-

- {isDragActive ? '释放以开始上传' : '点击或拖拽文件到这里'} -

-

支持 .xlsx 和 .xls 格式

-
- ) : ( -
-
-
- -
-
-

{uploadedFile.name}

-

{formatFileSize(uploadedFile.size)}

-
- -
- -
- )} -
-
- - {/* 解析选项卡片 */} - - - - - 解析选项 - - - 配置 Excel 文件的解析方式 - - - -
- - setParseOptions({ ...parseOptions, parseAllSheets: checked })} - /> -
-
- - setParseOptions({ ...parseOptions, headerRow: parseInt(e.target.value) || 0 })} - className="bg-background" - /> -

- 从 0 开始,0 表示第一行 -

-
-
-
- - {/* AI 分析选项卡片 */} - - - - - AI 分析选项 - - - 配置 AI 分析的方式 - - - -
- - -
-
- -