diff --git a/backend/app/api/__init__.py b/backend/app/api/__init__.py index c393c2a..1a7ced4 100644 --- a/backend/app/api/__init__.py +++ b/backend/app/api/__init__.py @@ -13,6 +13,7 @@ from app.api.endpoints import ( visualization, analysis_charts, health, + instruction, # 智能指令 ) # 创建主路由 @@ -29,3 +30,4 @@ api_router.include_router(templates.router) # 表格模板 api_router.include_router(ai_analyze.router) # AI分析 api_router.include_router(visualization.router) # 可视化 api_router.include_router(analysis_charts.router) # 分析图表 +api_router.include_router(instruction.router) # 智能指令 diff --git a/backend/app/api/endpoints/instruction.py b/backend/app/api/endpoints/instruction.py new file mode 100644 index 0000000..751e518 --- /dev/null +++ b/backend/app/api/endpoints/instruction.py @@ -0,0 +1,439 @@ +""" +智能指令 API 接口 + +支持自然语言指令解析和执行 +""" +import logging +import uuid +from typing import Any, Dict, List, Optional + +from fastapi import APIRouter, HTTPException, Query, BackgroundTasks +from pydantic import BaseModel + +from app.instruction.intent_parser import intent_parser +from app.instruction.executor import instruction_executor +from app.core.database import mongodb + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/instruction", tags=["智能指令"]) + + +# ==================== 请求/响应模型 ==================== + +class InstructionRequest(BaseModel): + instruction: str + doc_ids: Optional[List[str]] = None # 关联的文档 ID 列表 + context: Optional[Dict[str, Any]] = None # 额外上下文 + + +class IntentRecognitionResponse(BaseModel): + success: bool + intent: str + params: Dict[str, Any] + message: str + + +class InstructionExecutionResponse(BaseModel): + success: bool + intent: str + result: Dict[str, Any] + message: str + + +# ==================== 接口 ==================== + +@router.post("/recognize", response_model=IntentRecognitionResponse) +async def recognize_intent(request: InstructionRequest): + """ + 意图识别接口 + + 将自然语言指令解析为结构化的意图和参数 + + 示例指令: + - "提取文档中的医院数量和床位数" + - "根据这些数据填表" 
+ - "总结一下这份文档" + - "对比这两个文档的差异" + """ + try: + intent, params = await intent_parser.parse(request.instruction) + + # 添加文档关联信息 + if request.doc_ids: + params["document_refs"] = [f"doc_{doc_id}" for doc_id in request.doc_ids] + + intent_names = { + "extract": "信息提取", + "fill_table": "表格填写", + "summarize": "摘要总结", + "question": "智能问答", + "search": "文档搜索", + "compare": "对比分析", + "transform": "格式转换", + "edit": "文档编辑", + "unknown": "未知" + } + + return IntentRecognitionResponse( + success=True, + intent=intent, + params=params, + message=f"识别到意图: {intent_names.get(intent, intent)}" + ) + + except Exception as e: + logger.error(f"意图识别失败: {e}") + return IntentRecognitionResponse( + success=False, + intent="error", + params={}, + message=f"意图识别失败: {str(e)}" + ) + + +@router.post("/execute") +async def execute_instruction( + background_tasks: BackgroundTasks, + request: InstructionRequest, + async_execute: bool = Query(False, description="是否异步执行(仅返回任务ID)") +): + """ + 指令执行接口 + + 解析并执行自然语言指令 + + 示例: + - 指令: "提取文档1中的医院数量" + 返回: {"extracted_data": {"医院数量": ["38710个"]}} + + - 指令: "填表" + 返回: {"filled_data": {...}} + + 设置 async_execute=true 可异步执行,返回任务ID用于查询进度 + """ + task_id = str(uuid.uuid4()) + + if async_execute: + # 异步模式:立即返回任务ID,后台执行 + background_tasks.add_task( + _execute_instruction_task, + task_id=task_id, + instruction=request.instruction, + doc_ids=request.doc_ids, + context=request.context + ) + + return { + "success": True, + "task_id": task_id, + "message": "指令已提交执行", + "status_url": f"/api/v1/tasks/{task_id}" + } + + # 同步模式:等待执行完成 + return await _execute_instruction_task(task_id, request.instruction, request.doc_ids, request.context) + + +async def _execute_instruction_task( + task_id: str, + instruction: str, + doc_ids: Optional[List[str]], + context: Optional[Dict[str, Any]] +) -> InstructionExecutionResponse: + """执行指令的后台任务""" + from app.core.database import redis_db, mongodb as mongo_client + + try: + # 记录任务 + try: + await mongo_client.insert_task( + 
task_id=task_id, + task_type="instruction_execute", + status="processing", + message="正在执行指令" + ) + except Exception: + pass + + # 构建执行上下文 + ctx: Dict[str, Any] = context or {} + + # 如果提供了文档 ID,获取文档内容 + if doc_ids: + docs = [] + for doc_id in doc_ids: + doc = await mongo_client.get_document(doc_id) + if doc: + docs.append(doc) + + if docs: + ctx["source_docs"] = docs + logger.info(f"指令执行上下文: 关联了 {len(docs)} 个文档") + + # 执行指令 + result = await instruction_executor.execute(instruction, ctx) + + # 更新任务状态 + try: + await mongo_client.update_task( + task_id=task_id, + status="success", + message="执行完成", + result=result + ) + except Exception: + pass + + return InstructionExecutionResponse( + success=result.get("success", False), + intent=result.get("intent", "unknown"), + result=result, + message=result.get("message", "执行完成") + ) + + except Exception as e: + logger.error(f"指令执行失败: {e}") + try: + await mongo_client.update_task( + task_id=task_id, + status="failure", + message="执行失败", + error=str(e) + ) + except Exception: + pass + + return InstructionExecutionResponse( + success=False, + intent="error", + result={"error": str(e)}, + message=f"指令执行失败: {str(e)}" + ) + + +@router.post("/chat") +async def instruction_chat( + background_tasks: BackgroundTasks, + request: InstructionRequest, + async_execute: bool = Query(False, description="是否异步执行(仅返回任务ID)") +): + """ + 指令对话接口 + + 支持多轮对话的指令执行 + + 示例对话流程: + 1. 用户: "上传一些文档" + 2. 系统: "请上传文档" + 3. 用户: "提取其中的医院数量" + 4. 
系统: 返回提取结果 + + 设置 async_execute=true 可异步执行,返回任务ID用于查询进度 + """ + task_id = str(uuid.uuid4()) + + if async_execute: + # 异步模式:立即返回任务ID,后台执行 + background_tasks.add_task( + _execute_chat_task, + task_id=task_id, + instruction=request.instruction, + doc_ids=request.doc_ids, + context=request.context + ) + + return { + "success": True, + "task_id": task_id, + "message": "指令已提交执行", + "status_url": f"/api/v1/tasks/{task_id}" + } + + # 同步模式:等待执行完成 + return await _execute_chat_task(task_id, request.instruction, request.doc_ids, request.context) + + +async def _execute_chat_task( + task_id: str, + instruction: str, + doc_ids: Optional[List[str]], + context: Optional[Dict[str, Any]] +): + """执行指令对话的后台任务""" + from app.core.database import mongodb as mongo_client + + try: + # 记录任务 + try: + await mongo_client.insert_task( + task_id=task_id, + task_type="instruction_chat", + status="processing", + message="正在处理对话" + ) + except Exception: + pass + + # 构建上下文 + ctx: Dict[str, Any] = context or {} + + # 获取关联文档 + if doc_ids: + docs = [] + for doc_id in doc_ids: + doc = await mongo_client.get_document(doc_id) + if doc: + docs.append(doc) + if docs: + ctx["source_docs"] = docs + + # 执行指令 + result = await instruction_executor.execute(instruction, ctx) + + # 根据意图类型添加友好的响应消息 + response_messages = { + "extract": f"已提取 {len(result.get('extracted_data', {}))} 个字段的数据", + "fill_table": f"填表完成,填写了 {len(result.get('result', {}).get('filled_data', {}))} 个字段", + "summarize": "已生成文档摘要", + "question": "已找到相关答案", + "search": f"找到 {len(result.get('results', []))} 条相关内容", + "compare": f"对比了 {len(result.get('comparison', []))} 个文档", + "edit": "编辑操作已完成", + "transform": "格式转换已完成", + "unknown": "无法理解该指令,请尝试更明确的描述" + } + + response = { + "success": result.get("success", False), + "intent": result.get("intent", "unknown"), + "result": result, + "message": response_messages.get(result.get("intent", ""), result.get("message", "")), + "hint": _get_intent_hint(result.get("intent", "")) + } + + # 更新任务状态 + try: + 
await mongo_client.update_task( + task_id=task_id, + status="success", + message="处理完成", + result=response + ) + except Exception: + pass + + return response + + except Exception as e: + logger.error(f"指令对话失败: {e}") + try: + await mongo_client.update_task( + task_id=task_id, + status="failure", + message="处理失败", + error=str(e) + ) + except Exception: + pass + + return { + "success": False, + "error": str(e), + "message": f"处理失败: {str(e)}" + } + + +def _get_intent_hint(intent: str) -> Optional[str]: + """根据意图返回下一步提示""" + hints = { + "extract": "您可以继续说 '提取更多字段' 或 '将数据填入表格'", + "fill_table": "您可以提供表格模板或说 '帮我创建一个表格'", + "question": "您可以继续提问或说 '总结一下这些内容'", + "search": "您可以查看搜索结果或说 '对比这些内容'", + "unknown": "您可以尝试: '提取数据'、'填表'、'总结'、'问答' 等指令" + } + return hints.get(intent) + + +@router.get("/intents") +async def list_supported_intents(): + """ + 获取支持的意图类型列表 + + 返回所有可用的自然语言指令类型 + """ + return { + "intents": [ + { + "intent": "extract", + "name": "信息提取", + "examples": [ + "提取文档中的医院数量", + "抽取所有机构的名称", + "找出表格中的数据" + ], + "params": ["field_refs", "document_refs"] + }, + { + "intent": "fill_table", + "name": "表格填写", + "examples": [ + "填表", + "根据这些数据填写表格", + "帮我填到Excel里" + ], + "params": ["template", "document_refs"] + }, + { + "intent": "summarize", + "name": "摘要总结", + "examples": [ + "总结一下这份文档", + "生成摘要", + "概括主要内容" + ], + "params": ["document_refs"] + }, + { + "intent": "question", + "name": "智能问答", + "examples": [ + "这段话说的是什么?", + "有多少家医院?", + "解释一下这个概念" + ], + "params": ["question", "focus"] + }, + { + "intent": "search", + "name": "文档搜索", + "examples": [ + "搜索相关内容", + "找找看有哪些机构", + "查询医院相关的数据" + ], + "params": ["field_refs", "question"] + }, + { + "intent": "compare", + "name": "对比分析", + "examples": [ + "对比这两个文档", + "比较一下差异", + "找出不同点" + ], + "params": ["document_refs"] + }, + { + "intent": "edit", + "name": "文档编辑", + "examples": [ + "润色这段文字", + "修改格式", + "添加注释" + ], + "params": [] + } + ] + } diff --git a/backend/app/api/endpoints/templates.py 
b/backend/app/api/endpoints/templates.py index 625b274..54b3a73 100644 --- a/backend/app/api/endpoints/templates.py +++ b/backend/app/api/endpoints/templates.py @@ -610,51 +610,79 @@ async def _export_to_excel(filled_data: dict, template_id: str) -> StreamingResp async def _export_to_word(filled_data: dict, template_id: str) -> StreamingResponse: """导出为 Word 格式""" + import re + import tempfile + import os from docx import Document from docx.shared import Pt, RGBColor from docx.enum.text import WD_ALIGN_PARAGRAPH - doc = Document() + def clean_text(text: str) -> str: + """清理文本,移除可能导致Word问题的非法字符""" + if not text: + return "" + # 移除控制字符 + text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text) + return text.strip() - # 添加标题 - title = doc.add_heading('填写结果', level=1) - title.alignment = WD_ALIGN_PARAGRAPH.CENTER + try: + # 先保存到临时文件,再读取到内存,确保文档完整性 + with tempfile.NamedTemporaryFile(delete=False, suffix='.docx') as tmp_file: + tmp_path = tmp_file.name - # 添加填写时间和模板信息 - from datetime import datetime - info_para = doc.add_paragraph() - info_para.add_run(f"模板ID: {template_id}\n").bold = True - info_para.add_run(f"导出时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + doc = Document() + doc.add_heading('填写结果', level=1) - doc.add_paragraph() # 空行 + from datetime import datetime + info_para = doc.add_paragraph() + template_filename = template_id.split('/')[-1].split('\\')[-1] if template_id else '未知' + info_para.add_run(f"模板文件: {clean_text(template_filename)}\n").bold = True + info_para.add_run(f"导出时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + doc.add_paragraph() - # 添加字段表格 - table = doc.add_table(rows=1, cols=3) - table.style = 'Light Grid Accent 1' + table = doc.add_table(rows=1, cols=3) + table.style = 'Table Grid' - # 表头 - header_cells = table.rows[0].cells - header_cells[0].text = '字段名' - header_cells[1].text = '填写值' - header_cells[2].text = '状态' + header_cells = table.rows[0].cells + header_cells[0].text = '字段名' + header_cells[1].text = '填写值' + 
header_cells[2].text = '状态' - for field_name, field_value in filled_data.items(): - row_cells = table.add_row().cells - row_cells[0].text = field_name - row_cells[1].text = str(field_value) if field_value else '' - row_cells[2].text = '已填写' if field_value else '为空' + for field_name, field_value in filled_data.items(): + row_cells = table.add_row().cells + row_cells[0].text = clean_text(str(field_name)) - # 保存到 BytesIO - output = io.BytesIO() - doc.save(output) - output.seek(0) + if isinstance(field_value, list): + clean_values = [clean_text(str(v)) for v in field_value if v] + display_value = ', '.join(clean_values) if clean_values else '' + else: + display_value = clean_text(str(field_value)) if field_value else '' - filename = f"filled_template.docx" + row_cells[1].text = display_value + row_cells[2].text = '已填写' if display_value else '为空' + + # 保存到临时文件 + doc.save(tmp_path) + + # 读取文件内容 + with open(tmp_path, 'rb') as f: + file_content = f.read() + + finally: + # 清理临时文件 + if os.path.exists(tmp_path): + try: + os.unlink(tmp_path) + except: + pass + + output = io.BytesIO(file_content) + filename = "filled_template.docx" return StreamingResponse( - io.BytesIO(output.getvalue()), + output, media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document", - headers={"Content-Disposition": f"attachment; filename={filename}"} + headers={"Content-Disposition": f"attachment; filename*=UTF-8''{filename}"} ) diff --git a/backend/app/instruction/__init__.py b/backend/app/instruction/__init__.py index 1386f3d..981f8b4 100644 --- a/backend/app/instruction/__init__.py +++ b/backend/app/instruction/__init__.py @@ -1,15 +1,14 @@ """ 指令执行模块 -注意: 此模块为可选功能,当前尚未实现。 -如需启用,请实现 intent_parser.py 和 executor.py +支持文档智能操作交互,包括意图解析和指令执行 """ -from .intent_parser import IntentParser, DefaultIntentParser -from .executor import InstructionExecutor, DefaultInstructionExecutor +from .intent_parser import IntentParser, intent_parser +from .executor import InstructionExecutor, 
instruction_executor __all__ = [ "IntentParser", - "DefaultIntentParser", + "intent_parser", "InstructionExecutor", - "DefaultInstructionExecutor", + "instruction_executor", ] diff --git a/backend/app/instruction/executor.py b/backend/app/instruction/executor.py index 36292ce..c7a05c7 100644 --- a/backend/app/instruction/executor.py +++ b/backend/app/instruction/executor.py @@ -2,34 +2,571 @@ 指令执行器模块 将自然语言指令转换为可执行操作 - -注意: 此模块为可选功能,当前尚未实现。 """ -from abc import ABC, abstractmethod -from typing import Any, Dict +import logging +import json +from typing import Any, Dict, List, Optional + +from app.services.template_fill_service import template_fill_service +from app.services.rag_service import rag_service +from app.services.markdown_ai_service import markdown_ai_service +from app.core.database import mongodb + +logger = logging.getLogger(__name__) -class InstructionExecutor(ABC): - """指令执行器抽象基类""" +class InstructionExecutor: + """指令执行器""" - @abstractmethod - async def execute(self, instruction: str, context: Dict[str, Any]) -> Dict[str, Any]: + def __init__(self): + self.intent_parser = None # 将通过 set_intent_parser 设置 + + def set_intent_parser(self, intent_parser): + """设置意图解析器""" + self.intent_parser = intent_parser + + async def execute(self, instruction: str, context: Dict[str, Any] = None) -> Dict[str, Any]: """ 执行指令 Args: - instruction: 解析后的指令 - context: 执行上下文 + instruction: 自然语言指令 + context: 执行上下文(包含文档信息等) Returns: 执行结果 """ - pass + if self.intent_parser is None: + from app.instruction.intent_parser import intent_parser + self.intent_parser = intent_parser + + context = context or {} + + # 解析意图 + intent, params = await self.intent_parser.parse(instruction) + + # 根据意图类型执行相应操作 + if intent == "extract": + return await self._execute_extract(params, context) + elif intent == "fill_table": + return await self._execute_fill_table(params, context) + elif intent == "summarize": + return await self._execute_summarize(params, context) + elif intent == "question": + return 
await self._execute_question(params, context) + elif intent == "search": + return await self._execute_search(params, context) + elif intent == "compare": + return await self._execute_compare(params, context) + elif intent == "edit": + return await self._execute_edit(params, context) + elif intent == "transform": + return await self._execute_transform(params, context) + else: + return { + "success": False, + "error": f"未知意图类型: {intent}", + "message": "无法理解该指令,请尝试更明确的描述" + } + + async def _execute_extract(self, params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]: + """执行信息提取""" + try: + target_fields = params.get("field_refs", []) + doc_ids = params.get("document_refs", []) + + if not target_fields: + return { + "success": False, + "error": "未指定要提取的字段", + "message": "请明确说明要提取哪些字段,如:'提取医院数量和床位数'" + } + + # 如果指定了文档,验证文档存在 + if doc_ids and "all_docs" not in doc_ids: + valid_docs = [] + for doc_ref in doc_ids: + doc_id = doc_ref.replace("doc_", "") + doc = await mongodb.get_document(doc_id) + if doc: + valid_docs.append(doc) + if not valid_docs: + return { + "success": False, + "error": "指定的文档不存在", + "message": "请检查文档编号是否正确" + } + context["source_docs"] = valid_docs + + # 构建字段列表 + fields = [] + for i, field_name in enumerate(target_fields): + fields.append({ + "name": field_name, + "cell": f"A{i+1}", + "field_type": "text", + "required": False + }) + + # 调用填表服务 + result = await template_fill_service.fill_template( + template_fields=fields, + source_doc_ids=[doc.get("_id") for doc in context.get("source_docs", [])] if context.get("source_docs") else None, + user_hint=f"请提取字段: {', '.join(target_fields)}" + ) + + return { + "success": True, + "intent": "extract", + "extracted_data": result.get("filled_data", {}), + "fields": target_fields, + "message": f"成功提取 {len(result.get('filled_data', {}))} 个字段" + } + + except Exception as e: + logger.error(f"提取执行失败: {e}") + return { + "success": False, + "error": str(e), + "message": f"提取失败: {str(e)}" + } + + async def 
_execute_fill_table(self, params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]: + """执行填表操作""" + try: + template_file = context.get("template_file") + if not template_file: + return { + "success": False, + "error": "未提供表格模板", + "message": "请先上传要填写的表格模板" + } + + # 获取源文档 + source_docs = context.get("source_docs", []) + source_doc_ids = [doc.get("_id") for doc in source_docs if doc.get("_id")] + + # 获取字段 + fields = context.get("template_fields", []) + + # 调用填表服务 + result = await template_fill_service.fill_template( + template_fields=fields, + source_doc_ids=source_doc_ids if source_doc_ids else None, + source_file_paths=context.get("source_file_paths"), + user_hint=params.get("user_hint"), + template_id=template_file if isinstance(template_file, str) else None, + template_file_type=params.get("template", {}).get("type", "xlsx") + ) + + return { + "success": True, + "intent": "fill_table", + "result": result, + "message": f"填表完成,成功填写 {len(result.get('filled_data', {}))} 个字段" + } + + except Exception as e: + logger.error(f"填表执行失败: {e}") + return { + "success": False, + "error": str(e), + "message": f"填表失败: {str(e)}" + } + + async def _execute_summarize(self, params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]: + """执行摘要总结""" + try: + docs = context.get("source_docs", []) + if not docs: + return { + "success": False, + "error": "没有可用的文档", + "message": "请先上传要总结的文档" + } + + summaries = [] + for doc in docs[:5]: # 最多处理5个文档 + content = doc.get("content", "")[:5000] # 限制内容长度 + if content: + summaries.append({ + "filename": doc.get("metadata", {}).get("original_filename", "未知"), + "content_preview": content[:500] + "..." 
if len(content) > 500 else content + }) + + return { + "success": True, + "intent": "summarize", + "summaries": summaries, + "message": f"找到 {len(summaries)} 个文档可供参考" + } + + except Exception as e: + logger.error(f"摘要执行失败: {e}") + return { + "success": False, + "error": str(e), + "message": f"摘要生成失败: {str(e)}" + } + + async def _execute_question(self, params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]: + """执行问答""" + try: + question = params.get("question", "") + if not question: + return { + "success": False, + "error": "未提供问题", + "message": "请输入要回答的问题" + } + + # 使用 RAG 检索相关文档 + docs = context.get("source_docs", []) + rag_results = [] + + for doc in docs: + doc_id = doc.get("_id", "") + if doc_id: + results = rag_service.retrieve_by_doc_id(doc_id, top_k=3) + rag_results.extend(results) + + # 构建上下文 + context_text = "\n\n".join([ + r.get("content", "") for r in rag_results[:5] + ]) if rag_results else "" + + # 如果没有 RAG 结果,使用文档内容 + if not context_text: + context_text = "\n\n".join([ + doc.get("content", "")[:3000] for doc in docs[:3] if doc.get("content") + ]) + + return { + "success": True, + "intent": "question", + "question": question, + "context_preview": context_text[:500] + "..." 
if len(context_text) > 500 else context_text, + "message": "已找到相关上下文,可进行问答" + } + + except Exception as e: + logger.error(f"问答执行失败: {e}") + return { + "success": False, + "error": str(e), + "message": f"问答处理失败: {str(e)}" + } + + async def _execute_search(self, params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]: + """执行搜索""" + try: + field_refs = params.get("field_refs", []) + query = " ".join(field_refs) if field_refs else params.get("question", "") + + if not query: + return { + "success": False, + "error": "未提供搜索关键词", + "message": "请输入要搜索的关键词" + } + + # 使用 RAG 检索 + results = rag_service.retrieve(query, top_k=10, min_score=0.3) + + return { + "success": True, + "intent": "search", + "query": query, + "results": [ + { + "content": r.get("content", "")[:200], + "score": r.get("score", 0), + "doc_id": r.get("doc_id", "") + } + for r in results[:10] + ], + "message": f"找到 {len(results)} 条相关结果" + } + + except Exception as e: + logger.error(f"搜索执行失败: {e}") + return { + "success": False, + "error": str(e), + "message": f"搜索失败: {str(e)}" + } + + async def _execute_compare(self, params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]: + """执行对比分析""" + try: + docs = context.get("source_docs", []) + if len(docs) < 2: + return { + "success": False, + "error": "对比需要至少2个文档", + "message": "请上传至少2个文档进行对比" + } + + # 提取文档基本信息 + comparison = [] + for i, doc in enumerate(docs[:5]): + comparison.append({ + "index": i + 1, + "filename": doc.get("metadata", {}).get("original_filename", "未知"), + "doc_type": doc.get("doc_type", "未知"), + "content_length": len(doc.get("content", "")), + "has_tables": bool(doc.get("structured_data", {}).get("tables")), + }) + + return { + "success": True, + "intent": "compare", + "comparison": comparison, + "message": f"对比了 {len(comparison)} 个文档的基本信息" + } + + except Exception as e: + logger.error(f"对比执行失败: {e}") + return { + "success": False, + "error": str(e), + "message": f"对比分析失败: {str(e)}" + } + + async def _execute_edit(self, 
params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]: + """执行文档编辑操作""" + try: + docs = context.get("source_docs", []) + if not docs: + return { + "success": False, + "error": "没有可用的文档", + "message": "请先上传要编辑的文档" + } + + doc = docs[0] # 默认编辑第一个文档 + content = doc.get("content", "") + original_filename = doc.get("metadata", {}).get("original_filename", "未知文档") + + if not content: + return { + "success": False, + "error": "文档内容为空", + "message": "该文档没有可编辑的内容" + } + + # 使用 LLM 进行文本润色/编辑 + prompt = f"""请对以下文档内容进行编辑处理。 + +原文内容: +{content[:8000]} + +编辑要求: +- 润色表述,使其更加专业流畅 +- 修正明显的语法错误 +- 保持原意不变 +- 只返回编辑后的内容,不要解释 + +请直接输出编辑后的内容:""" + + messages = [ + {"role": "system", "content": "你是一个专业的文本编辑助手。请直接输出编辑后的内容。"}, + {"role": "user", "content": prompt} + ] + + from app.services.llm_service import llm_service + response = await llm_service.chat(messages=messages, temperature=0.3, max_tokens=8000) + edited_content = llm_service.extract_message_content(response) + + return { + "success": True, + "intent": "edit", + "edited_content": edited_content, + "original_filename": original_filename, + "message": "文档编辑完成,内容已返回" + } + + except Exception as e: + logger.error(f"编辑执行失败: {e}") + return { + "success": False, + "error": str(e), + "message": f"编辑处理失败: {str(e)}" + } + + async def _execute_transform(self, params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]: + """ + 执行格式转换操作 + + 支持: + - Word -> Excel + - Excel -> Word + - Markdown -> Word + - Word -> Markdown + """ + try: + docs = context.get("source_docs", []) + if not docs: + return { + "success": False, + "error": "没有可用的文档", + "message": "请先上传要转换的文档" + } + + # 获取目标格式 + template_info = params.get("template", {}) + target_type = template_info.get("type", "") + + if not target_type: + # 尝试从指令中推断 + instruction = params.get("instruction", "") + if "excel" in instruction.lower() or "xlsx" in instruction.lower(): + target_type = "xlsx" + elif "word" in instruction.lower() or "docx" in instruction.lower(): + 
target_type = "docx" + elif "markdown" in instruction.lower() or "md" in instruction.lower(): + target_type = "md" + + if not target_type: + return { + "success": False, + "error": "未指定目标格式", + "message": "请说明要转换成什么格式(如:转成Excel、转成Word)" + } + + doc = docs[0] + content = doc.get("content", "") + structured_data = doc.get("structured_data", {}) + original_filename = doc.get("metadata", {}).get("original_filename", "未知文档") + + # 构建转换内容 + if structured_data.get("tables"): + # 有表格数据,生成表格格式的内容 + tables = structured_data.get("tables", []) + table_content = [] + for i, table in enumerate(tables[:3]): # 最多处理3个表格 + headers = table.get("headers", []) + rows = table.get("rows", [])[:20] # 最多20行 + if headers: + table_content.append(f"【表格 {i+1}】") + table_content.append(" | ".join(str(h) for h in headers)) + table_content.append(" | ".join(["---"] * len(headers))) + for row in rows: + if isinstance(row, list): + table_content.append(" | ".join(str(c) for c in row)) + elif isinstance(row, dict): + table_content.append(" | ".join(str(row.get(h, "")) for h in headers)) + table_content.append("") + + if target_type == "xlsx": + # 生成 Excel 格式的数据(JSON) + excel_data = [] + for table in tables[:1]: # 只处理第一个表格 + headers = table.get("headers", []) + rows = table.get("rows", [])[:100] + for row in rows: + if isinstance(row, list): + excel_data.append(dict(zip(headers, row))) + elif isinstance(row, dict): + excel_data.append(row) + + return { + "success": True, + "intent": "transform", + "transform_type": "to_excel", + "target_format": "xlsx", + "excel_data": excel_data, + "headers": headers, + "message": f"已转换为 Excel 格式,包含 {len(excel_data)} 行数据" + } + elif target_type in ["docx", "word"]: + # 生成 Word 格式的文本 + word_content = f"# {original_filename}\n\n" + word_content += "\n".join(table_content) + + return { + "success": True, + "intent": "transform", + "transform_type": "to_word", + "target_format": "docx", + "content": word_content, + "message": "已转换为 Word 格式" + } + elif target_type == 
"md": + # 生成 Markdown 格式 + md_content = f"# {original_filename}\n\n" + md_content += "\n".join(table_content) + + return { + "success": True, + "intent": "transform", + "transform_type": "to_markdown", + "target_format": "md", + "content": md_content, + "message": "已转换为 Markdown 格式" + } + + # 无表格数据,使用纯文本内容转换 + if target_type == "xlsx": + # 将文本内容转为 Excel 格式(每行作为一列) + lines = [line.strip() for line in content.split("\n") if line.strip()][:100] + excel_data = [{"行号": i+1, "内容": line} for i, line in enumerate(lines)] + + return { + "success": True, + "intent": "transform", + "transform_type": "to_excel", + "target_format": "xlsx", + "excel_data": excel_data, + "headers": ["行号", "内容"], + "message": f"已将文本内容转换为 Excel,包含 {len(excel_data)} 行" + } + elif target_type in ["docx", "word"]: + return { + "success": True, + "intent": "transform", + "transform_type": "to_word", + "target_format": "docx", + "content": content, + "message": "文档内容已准备好,可下载为 Word 格式" + } + elif target_type == "md": + # 简单的文本转 Markdown + md_lines = [] + for line in content.split("\n"): + line = line.strip() + if line: + # 简单处理:如果行不长且不是列表格式,作为段落 + if len(line) < 100 and not line.startswith(("-", "*", "1.", "2.", "3.")): + md_lines.append(line) + else: + md_lines.append(line) + else: + md_lines.append("") + + return { + "success": True, + "intent": "transform", + "transform_type": "to_markdown", + "target_format": "md", + "content": "\n".join(md_lines), + "message": "已转换为 Markdown 格式" + } + + return { + "success": False, + "error": "不支持的目标格式", + "message": f"暂不支持转换为 {target_type} 格式" + } + + except Exception as e: + logger.error(f"格式转换失败: {e}") + return { + "success": False, + "error": str(e), + "message": f"格式转换失败: {str(e)}" + } -class DefaultInstructionExecutor(InstructionExecutor): - """默认指令执行器""" - - async def execute(self, instruction: str, context: Dict[str, Any]) -> Dict[str, Any]: - """暂未实现""" - raise NotImplementedError("指令执行功能暂未实现") +# 全局单例 +instruction_executor = InstructionExecutor() diff 
--git a/backend/app/instruction/intent_parser.py b/backend/app/instruction/intent_parser.py index 49df250..b53c034 100644 --- a/backend/app/instruction/intent_parser.py +++ b/backend/app/instruction/intent_parser.py @@ -2,17 +2,51 @@ 意图解析器模块 解析用户自然语言指令,识别意图和参数 - -注意: 此模块为可选功能,当前尚未实现。 """ -from abc import ABC, abstractmethod -from typing import Any, Dict, Tuple +import re +import logging +from typing import Any, Dict, List, Optional, Tuple + +logger = logging.getLogger(__name__) -class IntentParser(ABC): - """意图解析器抽象基类""" +class IntentParser: + """意图解析器""" + + # 意图类型定义 + INTENT_EXTRACT = "extract" # 信息提取 + INTENT_FILL_TABLE = "fill_table" # 填表 + INTENT_SUMMARIZE = "summarize" # 摘要总结 + INTENT_QUESTION = "question" # 问答 + INTENT_SEARCH = "search" # 搜索 + INTENT_COMPARE = "compare" # 对比分析 + INTENT_TRANSFORM = "transform" # 格式转换 + INTENT_EDIT = "edit" # 编辑文档 + INTENT_UNKNOWN = "unknown" # 未知 + + # 意图关键词映射 + INTENT_KEYWORDS = { + INTENT_EXTRACT: ["提取", "抽取", "获取", "找出", "查找", "识别", "找到"], + INTENT_FILL_TABLE: ["填表", "填写", "填充", "录入", "导入到表格", "填写到"], + INTENT_SUMMARIZE: ["总结", "摘要", "概括", "概述", "归纳", "提炼"], + INTENT_QUESTION: ["问答", "回答", "解释", "什么是", "为什么", "如何", "怎样", "多少", "几个"], + INTENT_SEARCH: ["搜索", "查找", "检索", "查询", "找"], + INTENT_COMPARE: ["对比", "比较", "差异", "区别", "不同"], + INTENT_TRANSFORM: ["转换", "转化", "变成", "转为", "导出"], + INTENT_EDIT: ["修改", "编辑", "调整", "改写", "润色", "优化"], + } + + # 实体模式定义 + ENTITY_PATTERNS = { + "number": [r"\d+", r"[一二三四五六七八九十百千万]+"], + "date": [r"\d{4}年", r"\d{1,2}月", r"\d{1,2}日"], + "percentage": [r"\d+(\.\d+)?%", r"\d+(\.\d+)?‰"], + "currency": [r"\d+(\.\d+)?万元", r"\d+(\.\d+)?亿元", r"\d+(\.\d+)?元"], + } + + def __init__(self): + self.intent_history: List[Dict[str, Any]] = [] - @abstractmethod async def parse(self, text: str) -> Tuple[str, Dict[str, Any]]: """ 解析自然语言指令 @@ -23,12 +57,186 @@ class IntentParser(ABC): Returns: (意图类型, 参数字典) """ - pass + text = text.strip() + if not text: + return self.INTENT_UNKNOWN, {} + + # 记录历史 + 
self.intent_history.append({"text": text, "intent": None}) + + # 识别意图 + intent = self._recognize_intent(text) + + # 提取参数 + params = self._extract_params(text, intent) + + # 更新历史 + if self.intent_history: + self.intent_history[-1]["intent"] = intent + + logger.info(f"意图解析: text={text[:50]}..., intent={intent}, params={params}") + + return intent, params + + def _recognize_intent(self, text: str) -> str: + """识别意图类型""" + intent_scores: Dict[str, float] = {} + + for intent, keywords in self.INTENT_KEYWORDS.items(): + score = 0 + for keyword in keywords: + if keyword in text: + score += 1 + if score > 0: + intent_scores[intent] = score + + if not intent_scores: + return self.INTENT_UNKNOWN + + # 返回得分最高的意图 + return max(intent_scores, key=intent_scores.get) + + def _extract_params(self, text: str, intent: str) -> Dict[str, Any]: + """提取参数""" + params: Dict[str, Any] = { + "entities": self._extract_entities(text), + "document_refs": self._extract_document_refs(text), + "field_refs": self._extract_field_refs(text), + "template_refs": self._extract_template_refs(text), + } + + # 根据意图类型提取特定参数 + if intent == self.INTENT_QUESTION: + params["question"] = text + params["focus"] = self._extract_question_focus(text) + elif intent == self.INTENT_FILL_TABLE: + params["template"] = self._extract_template_info(text) + elif intent == self.INTENT_EXTRACT: + params["target_fields"] = self._extract_target_fields(text) + + return params + + def _extract_entities(self, text: str) -> Dict[str, List[str]]: + """提取实体""" + entities: Dict[str, List[str]] = {} + + for entity_type, patterns in self.ENTITY_PATTERNS.items(): + matches = [] + for pattern in patterns: + found = re.findall(pattern, text) + matches.extend(found) + if matches: + entities[entity_type] = list(set(matches)) + + return entities + + def _extract_document_refs(self, text: str) -> List[str]: + """提取文档引用""" + # 匹配 "文档1"、"doc1"、"第一个文档" 等 + refs = [] + + # 数字索引: 文档1, doc1, 第1个文档 + num_patterns = [ + r"[文档doc]+(\d+)", + 
r"第(\d+)个文档", + r"第(\d+)份", + ] + for pattern in num_patterns: + matches = re.findall(pattern, text.lower()) + refs.extend([f"doc_{m}" for m in matches]) + + # "所有文档"、"全部文档" + if any(kw in text for kw in ["所有", "全部", "整个"]): + refs.append("all_docs") + + return refs + + def _extract_field_refs(self, text: str) -> List[str]: + """提取字段引用""" + fields = [] + + # 匹配引号内的字段名 + quoted = re.findall(r"['\"『「]([^'\"』」]+)['\"』」]", text) + fields.extend(quoted) + + # 匹配 "xxx字段"、"xxx列" 等 + field_patterns = [ + r"([^\s]+)字段", + r"([^\s]+)列", + r"([^\s]+)数据", + ] + for pattern in field_patterns: + matches = re.findall(pattern, text) + fields.extend(matches) + + return list(set(fields)) + + def _extract_template_refs(self, text: str) -> List[str]: + """提取模板引用""" + templates = [] + + # 匹配 "表格模板"、"Excel模板"、"表1" 等 + template_patterns = [ + r"([^\s]+模板)", + r"表(\d+)", + r"([^\s]+表格)", + ] + for pattern in template_patterns: + matches = re.findall(pattern, text) + templates.extend(matches) + + return list(set(templates)) + + def _extract_question_focus(self, text: str) -> Optional[str]: + """提取问题焦点""" + # "什么是XXX"、"XXX是什么" + match = re.search(r"[什么是]([^?]+)", text) + if match: + return match.group(1).strip() + + # "XXX有多少" + match = re.search(r"([^?]+)有多少", text) + if match: + return match.group(1).strip() + + return None + + def _extract_template_info(self, text: str) -> Optional[Dict[str, str]]: + """提取模板信息""" + template_info: Dict[str, str] = {} + + # 提取模板类型 + if "excel" in text.lower() or "xlsx" in text.lower() or "电子表格" in text: + template_info["type"] = "xlsx" + elif "word" in text.lower() or "docx" in text.lower() or "文档" in text: + template_info["type"] = "docx" + + return template_info if template_info else None + + def _extract_target_fields(self, text: str) -> List[str]: + """提取目标字段""" + fields = [] + + # 匹配 "提取XXX和YYY"、"抽取XXX、YYY" + patterns = [ + r"提取([^(and|,|,)+]+?)(?:和|与|、|,|plus)", + r"抽取([^(and|,|,)+]+?)(?:和|与|、|,|plus)", + ] + + for pattern in patterns: + matches = 
re.findall(pattern, text) + fields.extend([m.strip() for m in matches if m.strip()]) + + return list(set(fields)) + + def get_intent_history(self) -> List[Dict[str, Any]]: + """获取意图历史""" + return self.intent_history + + def clear_history(self): + """清空历史""" + self.intent_history = [] -class DefaultIntentParser(IntentParser): - """默认意图解析器""" - - async def parse(self, text: str) -> Tuple[str, Dict[str, Any]]: - """暂未实现""" - raise NotImplementedError("意图解析功能暂未实现") +# 全局单例 +intent_parser = IntentParser() diff --git a/backend/app/services/multi_doc_reasoning_service.py b/backend/app/services/multi_doc_reasoning_service.py new file mode 100644 index 0000000..e4021f1 --- /dev/null +++ b/backend/app/services/multi_doc_reasoning_service.py @@ -0,0 +1,446 @@ +""" +多文档关联推理服务 + +跨文档信息关联和推理 +""" +import logging +import re +from typing import Any, Dict, List, Optional, Set, Tuple +from collections import defaultdict + +from app.services.llm_service import llm_service +from app.services.rag_service import rag_service + +logger = logging.getLogger(__name__) + + +class MultiDocReasoningService: + """ + 多文档关联推理服务 + + 功能: + 1. 实体跨文档追踪 - 追踪同一实体在不同文档中的描述 + 2. 关系抽取与推理 - 抽取实体间关系并进行推理 + 3. 信息补全 - 根据多个文档的信息互补填充缺失数据 + 4. 冲突检测 - 检测不同文档间的信息冲突 + """ + + def __init__(self): + self.llm = llm_service + + async def analyze_cross_documents( + self, + documents: List[Dict[str, Any]], + query: Optional[str] = None, + entity_types: Optional[List[str]] = None + ) -> Dict[str, Any]: + """ + 跨文档分析 + + Args: + documents: 文档列表 + query: 查询意图(可选) + entity_types: 要追踪的实体类型列表,如 ["机构", "人物", "地点", "数量"] + + Returns: + 跨文档分析结果 + """ + if not documents: + return {"success": False, "error": "没有可用的文档"} + + entity_types = entity_types or ["机构", "数量", "时间", "地点"] + + try: + # 1. 提取各文档中的实体 + entities_per_doc = await self._extract_entities_from_docs(documents, entity_types) + + # 2. 跨文档实体对齐 + aligned_entities = self._align_entities_across_docs(entities_per_doc) + + # 3. 
关系抽取 + relations = await self._extract_relations(documents) + + # 4. 构建知识图谱 + knowledge_graph = self._build_knowledge_graph(aligned_entities, relations) + + # 5. 信息补全 + completed_info = await self._complete_missing_info(knowledge_graph, documents) + + # 6. 冲突检测 + conflicts = self._detect_conflicts(aligned_entities) + + return { + "success": True, + "entities": aligned_entities, + "relations": relations, + "knowledge_graph": knowledge_graph, + "completed_info": completed_info, + "conflicts": conflicts, + "summary": self._generate_summary(aligned_entities, conflicts) + } + + except Exception as e: + logger.error(f"跨文档分析失败: {e}") + return {"success": False, "error": str(e)} + + async def _extract_entities_from_docs( + self, + documents: List[Dict[str, Any]], + entity_types: List[str] + ) -> List[Dict[str, Any]]: + """从各文档中提取实体""" + entities_per_doc = [] + + for idx, doc in enumerate(documents): + doc_id = doc.get("_id", f"doc_{idx}") + content = doc.get("content", "")[:8000] # 限制长度 + + # 使用 LLM 提取实体 + prompt = f"""从以下文档中提取指定的实体类型信息。 + +实体类型: {', '.join(entity_types)} + +文档内容: +{content} + +请按以下 JSON 格式输出(只需输出 JSON): +{{ + "entities": [ + {{"type": "机构", "name": "实体名称", "value": "相关数值(如有)", "context": "上下文描述"}}, + ... 
+ ] +}} + +只提取在文档中明确提到的实体,不要推测。""" + + messages = [ + {"role": "system", "content": "你是一个实体提取专家。请严格按JSON格式输出。"}, + {"role": "user", "content": prompt} + ] + + try: + response = await self.llm.chat(messages=messages, temperature=0.1, max_tokens=3000) + content_response = self.llm.extract_message_content(response) + + # 解析 JSON + import json + import re + cleaned = content_response.strip() + json_match = re.search(r'\{[\s\S]*\}', cleaned) + if json_match: + result = json.loads(json_match.group()) + entities = result.get("entities", []) + entities_per_doc.append({ + "doc_id": doc_id, + "doc_name": doc.get("metadata", {}).get("original_filename", f"文档{idx+1}"), + "entities": entities + }) + logger.info(f"文档 {doc_id} 提取到 {len(entities)} 个实体") + except Exception as e: + logger.warning(f"文档 {doc_id} 实体提取失败: {e}") + + return entities_per_doc + + def _align_entities_across_docs( + self, + entities_per_doc: List[Dict[str, Any]] + ) -> Dict[str, List[Dict[str, Any]]]: + """ + 跨文档实体对齐 + + 将同一实体在不同文档中的描述进行关联 + """ + aligned: Dict[str, List[Dict[str, Any]]] = defaultdict(list) + + for doc_data in entities_per_doc: + doc_id = doc_data["doc_id"] + doc_name = doc_data["doc_name"] + + for entity in doc_data.get("entities", []): + entity_name = entity.get("name", "") + if not entity_name: + continue + + # 标准化实体名(去除空格和括号内容) + normalized = self._normalize_entity_name(entity_name) + + aligned[normalized].append({ + "original_name": entity_name, + "type": entity.get("type", "未知"), + "value": entity.get("value", ""), + "context": entity.get("context", ""), + "source_doc": doc_name, + "source_doc_id": doc_id + }) + + # 合并相同实体 + result = {} + for normalized, appearances in aligned.items(): + if len(appearances) > 1: + result[normalized] = appearances + logger.info(f"实体对齐: {normalized} 在 {len(appearances)} 个文档中出现") + + return result + + def _normalize_entity_name(self, name: str) -> str: + """标准化实体名称""" + # 去除空格 + name = name.strip() + # 去除括号内容 + name = re.sub(r'[((].*?[))]', '', name) + # 
去除"第X名"等 + name = re.sub(r'^第\d+[名位个]', '', name) + return name.strip() + + async def _extract_relations( + self, + documents: List[Dict[str, Any]] + ) -> List[Dict[str, str]]: + """从文档中抽取关系""" + relations = [] + + # 合并所有文档内容 + combined_content = "\n\n".join([ + f"【{doc.get('metadata', {}).get('original_filename', f'文档{i}')}】\n{doc.get('content', '')[:3000]}" + for i, doc in enumerate(documents) + ]) + + prompt = f"""从以下文档内容中抽取实体之间的关系。 + +文档内容: +{combined_content[:8000]} + +请识别以下类型的关系: +- 包含关系 (A包含B) +- 隶属关系 (A隶属于B) +- 合作关系 (A与B合作) +- 对比关系 (A vs B) +- 时序关系 (A先于B发生) + +请按以下 JSON 格式输出(只需输出 JSON): +{{ + "relations": [ + {{"entity1": "实体1", "entity2": "实体2", "relation": "关系类型", "description": "关系描述"}}, + ... + ] +}} + +如果没有找到明确的关系,返回空数组。""" + + messages = [ + {"role": "system", "content": "你是一个关系抽取专家。请严格按JSON格式输出。"}, + {"role": "user", "content": prompt} + ] + + try: + response = await self.llm.chat(messages=messages, temperature=0.1, max_tokens=3000) + content_response = self.llm.extract_message_content(response) + + import json + import re + cleaned = content_response.strip() + json_match = re.search(r'\{{[\s\S]*\}}', cleaned) + if json_match: + result = json.loads(json_match.group()) + relations = result.get("relations", []) + logger.info(f"抽取到 {len(relations)} 个关系") + except Exception as e: + logger.warning(f"关系抽取失败: {e}") + + return relations + + def _build_knowledge_graph( + self, + aligned_entities: Dict[str, List[Dict[str, Any]]], + relations: List[Dict[str, str]] + ) -> Dict[str, Any]: + """构建知识图谱""" + nodes = [] + edges = [] + node_ids = set() + + # 添加实体节点 + for entity_name, appearances in aligned_entities.items(): + if len(appearances) < 1: + continue + + first_appearance = appearances[0] + node_id = f"entity_{len(nodes)}" + + # 收集该实体在所有文档中的值 + values = [a.get("value", "") for a in appearances if a.get("value")] + primary_value = values[0] if values else "" + + nodes.append({ + "id": node_id, + "name": entity_name, + "type": first_appearance.get("type", 
"未知"), + "value": primary_value, + "occurrence_count": len(appearances), + "sources": [a.get("source_doc", "") for a in appearances] + }) + node_ids.add(entity_name) + + # 添加关系边 + for relation in relations: + entity1 = self._normalize_entity_name(relation.get("entity1", "")) + entity2 = self._normalize_entity_name(relation.get("entity2", "")) + + if entity1 in node_ids and entity2 in node_ids: + edges.append({ + "source": entity1, + "target": entity2, + "relation": relation.get("relation", "相关"), + "description": relation.get("description", "") + }) + + return { + "nodes": nodes, + "edges": edges, + "stats": { + "entity_count": len(nodes), + "relation_count": len(edges) + } + } + + async def _complete_missing_info( + self, + knowledge_graph: Dict[str, Any], + documents: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """根据多个文档补全信息""" + completed = [] + + for node in knowledge_graph.get("nodes", []): + if not node.get("value") and node.get("occurrence_count", 0) > 1: + # 实体在多个文档中出现但没有数值,尝试从 RAG 检索补充 + query = f"{node['name']} 数值 数据" + results = rag_service.retrieve(query, top_k=3, min_score=0.3) + + if results: + completed.append({ + "entity": node["name"], + "type": node.get("type", "未知"), + "source": "rag_inference", + "context": results[0].get("content", "")[:200], + "confidence": results[0].get("score", 0) + }) + + return completed + + def _detect_conflicts( + self, + aligned_entities: Dict[str, List[Dict[str, Any]]] + ) -> List[Dict[str, Any]]: + """检测不同文档间的信息冲突""" + conflicts = [] + + for entity_name, appearances in aligned_entities.items(): + if len(appearances) < 2: + continue + + # 检查数值冲突 + values = {} + for appearance in appearances: + val = appearance.get("value", "") + if val: + source = appearance.get("source_doc", "未知来源") + values[source] = val + + if len(values) > 1: + unique_values = set(values.values()) + if len(unique_values) > 1: + conflicts.append({ + "entity": entity_name, + "type": "value_conflict", + "details": values, + "description": 
f"实体 '{entity_name}' 在不同文档中有不同数值: {values}" + }) + + return conflicts + + def _generate_summary( + self, + aligned_entities: Dict[str, List[Dict[str, Any]]], + conflicts: List[Dict[str, Any]] + ) -> str: + """生成摘要""" + summary_parts = [] + + total_entities = sum(len(appearances) for appearances in aligned_entities.values()) + multi_doc_entities = sum(1 for appearances in aligned_entities.values() if len(appearances) > 1) + + summary_parts.append(f"跨文档分析完成:发现 {total_entities} 个实体") + summary_parts.append(f"其中 {multi_doc_entities} 个实体在多个文档中被提及") + + if conflicts: + summary_parts.append(f"检测到 {len(conflicts)} 个潜在冲突") + + return "; ".join(summary_parts) + + async def answer_cross_doc_question( + self, + question: str, + documents: List[Dict[str, Any]] + ) -> Dict[str, Any]: + """ + 跨文档问答 + + Args: + question: 问题 + documents: 文档列表 + + Returns: + 答案结果 + """ + # 先进行跨文档分析 + analysis_result = await self.analyze_cross_documents(documents, query=question) + + # 构建上下文 + context_parts = [] + + # 添加实体信息 + for entity_name, appearances in analysis_result.get("entities", {}).items(): + contexts = [f"{a.get('source_doc')}: {a.get('context', '')}" for a in appearances[:2]] + if contexts: + context_parts.append(f"【{entity_name}】{' | '.join(contexts)}") + + # 添加关系信息 + for relation in analysis_result.get("relations", [])[:5]: + context_parts.append(f"【关系】{relation.get('entity1')} {relation.get('relation')} {relation.get('entity2')}: {relation.get('description', '')}") + + context_text = "\n\n".join(context_parts) if context_parts else "未找到相关实体和关系" + + # 使用 LLM 生成答案 + prompt = f"""基于以下跨文档分析结果,回答用户问题。 + +问题: {question} + +分析结果: +{context_text} + +请直接回答问题,如果分析结果中没有相关信息,请说明"根据提供的文档无法回答该问题"。""" + + messages = [ + {"role": "system", "content": "你是一个基于文档的问答助手。请根据提供的信息回答问题。"}, + {"role": "user", "content": prompt} + ] + + try: + response = await self.llm.chat(messages=messages, temperature=0.2, max_tokens=2000) + answer = self.llm.extract_message_content(response) + + return { + "success": 
True, + "question": question, + "answer": answer, + "supporting_entities": list(analysis_result.get("entities", {}).keys())[:10], + "relations_count": len(analysis_result.get("relations", [])) + } + except Exception as e: + logger.error(f"跨文档问答失败: {e}") + return {"success": False, "error": str(e)} + + +# 全局单例 +multi_doc_reasoning_service = MultiDocReasoningService() diff --git a/backend/app/services/rag_service.py b/backend/app/services/rag_service.py index b6e905b..50c2607 100644 --- a/backend/app/services/rag_service.py +++ b/backend/app/services/rag_service.py @@ -2,11 +2,15 @@ RAG 服务模块 - 检索增强生成 使用 sentence-transformers + Faiss 实现向量检索 +支持 BM25 关键词检索 + 向量检索混合融合 """ import logging import os import pickle -from typing import Any, Dict, List, Optional +import re +import math +from typing import Any, Dict, List, Optional, Tuple +from collections import Counter, defaultdict import faiss import numpy as np @@ -32,6 +36,132 @@ class SimpleDocument: self.metadata = metadata +class BM25: + """ + BM25 关键词检索算法 + + 一种基于词频和文档频率的信息检索算法,比纯向量搜索更适合关键词精确匹配 + """ + + def __init__(self, k1: float = 1.5, b: float = 0.75): + self.k1 = k1 # 词频饱和参数 + self.b = b # 文档长度归一化参数 + self.documents: List[str] = [] + self.doc_ids: List[str] = [] + self.avg_doc_length = 0 + self.doc_freqs: Dict[str, int] = {} # 词 -> 包含该词的文档数 + self.idf: Dict[str, float] = {} # 词 -> IDF 值 + self.doc_lengths: List[int] = [] + self.doc_term_freqs: List[Dict[str, int]] = [] # 每个文档的词频 + + def _tokenize(self, text: str) -> List[str]: + """分词(简单的中文分词)""" + if not text: + return [] + # 简单分词:按标点和空格分割 + tokens = re.findall(r'[\u4e00-\u9fff]+|[a-zA-Z0-9]+', text.lower()) + # 过滤单字符 + return [t for t in tokens if len(t) > 1] + + def fit(self, documents: List[str], doc_ids: List[str]): + """ + 构建 BM25 索引 + + Args: + documents: 文档内容列表 + doc_ids: 文档 ID 列表 + """ + self.documents = documents + self.doc_ids = doc_ids + n = len(documents) + + # 统计文档频率 + self.doc_freqs = defaultdict(int) + self.doc_lengths = [] + self.doc_term_freqs = 
[] + + for doc in documents: + tokens = self._tokenize(doc) + self.doc_lengths.append(len(tokens)) + doc_tf = Counter(tokens) + self.doc_term_freqs.append(doc_tf) + + for term in doc_tf: + self.doc_freqs[term] += 1 + + # 计算平均文档长度 + self.avg_doc_length = sum(self.doc_lengths) / n if n > 0 else 0 + + # 计算 IDF + for term, df in self.doc_freqs.items(): + # IDF = log((n - df + 0.5) / (df + 0.5)) + self.idf[term] = math.log((n - df + 0.5) / (df + 0.5) + 1) + + logger.info(f"BM25 索引构建完成: {n} 个文档, {len(self.idf)} 个词项") + + def search(self, query: str, top_k: int = 10) -> List[Tuple[int, float]]: + """ + 搜索相关文档 + + Args: + query: 查询文本 + top_k: 返回前 k 个结果 + + Returns: + [(文档索引, BM25分数), ...] + """ + if not self.documents: + return [] + + query_tokens = self._tokenize(query) + if not query_tokens: + return [] + + scores = [] + n = len(self.documents) + + for idx in range(n): + score = self._calculate_score(query_tokens, idx) + scores.append((idx, score)) + + # 按分数降序排序 + scores.sort(key=lambda x: x[1], reverse=True) + + return scores[:top_k] + + def _calculate_score(self, query_tokens: List[str], doc_idx: int) -> float: + """计算单个文档的 BM25 分数""" + doc_tf = self.doc_term_freqs[doc_idx] + doc_len = self.doc_lengths[doc_idx] + score = 0.0 + + for term in query_tokens: + if term not in self.idf: + continue + + tf = doc_tf.get(term, 0) + idf = self.idf[term] + + # BM25 公式 + numerator = tf * (self.k1 + 1) + denominator = tf + self.k1 * (1 - self.b + self.b * doc_len / self.avg_doc_length) + + score += idf * numerator / denominator + + return score + + def get_scores(self, query: str) -> List[float]: + """获取所有文档的 BM25 分数""" + if not self.documents: + return [] + + query_tokens = self._tokenize(query) + if not query_tokens: + return [0.0] * len(self.documents) + + return [self._calculate_score(query_tokens, idx) for idx in range(len(self.documents))] + + class RAGService: """RAG 检索增强服务""" @@ -47,12 +177,15 @@ class RAGService: self._dimension: int = 384 # 默认维度 self._initialized = False 
self._persist_dir = settings.FAISS_INDEX_DIR + # BM25 索引 + self.bm25: Optional[BM25] = None + self._bm25_enabled = True # 始终启用 BM25 # 检查是否可用 self._disabled = not SENTENCE_TRANSFORMERS_AVAILABLE if self._disabled: - logger.warning("RAG 服务已禁用(sentence-transformers 不可用),将使用关键词匹配作为后备") + logger.warning("RAG 服务已禁用(sentence-transformers 不可用),将使用 BM25 关键词检索") else: - logger.info("RAG 服务已启用") + logger.info("RAG 服务已启用(向量检索 + BM25 混合检索)") def _init_embeddings(self): """初始化嵌入模型""" @@ -261,11 +394,25 @@ class RAGService: if not documents: return - # 总是将文档存储在内存中(用于关键词搜索后备) + # 总是将文档存储在内存中(用于 BM25 和关键词搜索) for doc, did in zip(documents, doc_ids): self.documents.append({"id": did, "content": doc.page_content, "metadata": doc.metadata}) self.doc_ids.append(did) + # 构建 BM25 索引 + if self._bm25_enabled and documents: + bm25_texts = [doc.page_content for doc in documents] + if self.bm25 is None: + self.bm25 = BM25() + self.bm25.fit(bm25_texts, doc_ids) + else: + # 增量添加:重新构建(BM25 不支持增量) + all_texts = [d["content"] for d in self.documents] + all_ids = self.doc_ids.copy() + self.bm25 = BM25() + self.bm25.fit(all_texts, all_ids) + logger.debug(f"BM25 索引更新: {len(documents)} 个文档") + # 如果没有嵌入模型,跳过向量索引 if self.embedding_model is None: logger.debug(f"文档跳过向量索引 (无嵌入模型): {len(documents)} 个文档") @@ -284,7 +431,7 @@ class RAGService: def retrieve(self, query: str, top_k: int = 5, min_score: float = 0.3) -> List[Dict[str, Any]]: """ - 根据查询检索相关文档块 + 根据查询检索相关文档块(混合检索:向量 + BM25) Args: query: 查询文本 @@ -301,39 +448,167 @@ class RAGService: if not self._initialized: self._init_vector_store() - # 优先使用向量检索 - if self.index is not None and self.index.ntotal > 0 and self.embedding_model is not None: - try: - query_embedding = self.embedding_model.encode([query], convert_to_numpy=True) - query_embedding = self._normalize_vectors(query_embedding).astype('float32') + # 获取向量检索结果 + vector_results = self._vector_search(query, top_k * 2, min_score) - scores, indices = self.index.search(query_embedding, min(top_k, 
self.index.ntotal)) + # 获取 BM25 检索结果 + bm25_results = self._bm25_search(query, top_k * 2) - results = [] - for score, idx in zip(scores[0], indices[0]): - if idx < 0: - continue - if score < min_score: - continue - doc = self.documents[idx] - results.append({ - "content": doc["content"], - "metadata": doc["metadata"], - "score": float(score), - "doc_id": doc["id"], - "chunk_index": doc["metadata"].get("chunk_index", 0) - }) + # 混合融合 + hybrid_results = self._hybrid_fusion(vector_results, bm25_results, top_k) - if results: - logger.debug(f"向量检索到 {len(results)} 条相关文档块") - return results - except Exception as e: - logger.warning(f"向量检索失败,使用关键词搜索后备: {e}") + if hybrid_results: + logger.info(f"混合检索到 {len(hybrid_results)} 条相关文档块 (向量:{len(vector_results)}, BM25:{len(bm25_results)})") + return hybrid_results - # 后备:使用关键词搜索 - logger.debug("使用关键词搜索后备方案") + # 降级:只使用 BM25 + if bm25_results: + logger.info(f"降级到 BM25 检索: {len(bm25_results)} 条") + return bm25_results + + # 降级:使用关键词搜索 + logger.info("降级到关键词搜索") return self._keyword_search(query, top_k) + def _vector_search(self, query: str, top_k: int, min_score: float) -> List[Dict[str, Any]]: + """向量检索""" + if self.index is None or self.index.ntotal == 0 or self.embedding_model is None: + return [] + + try: + query_embedding = self.embedding_model.encode([query], convert_to_numpy=True) + query_embedding = self._normalize_vectors(query_embedding).astype('float32') + + scores, indices = self.index.search(query_embedding, min(top_k * 2, self.index.ntotal)) + + results = [] + for score, idx in zip(scores[0], indices[0]): + if idx < 0: + continue + if score < min_score: + continue + doc = self.documents[idx] + results.append({ + "content": doc["content"], + "metadata": doc["metadata"], + "score": float(score), + "doc_id": doc["id"], + "chunk_index": doc["metadata"].get("chunk_index", 0), + "search_type": "vector" + }) + + return results + except Exception as e: + logger.warning(f"向量检索失败: {e}") + return [] + + def _bm25_search(self, 
query: str, top_k: int) -> List[Dict[str, Any]]: + """BM25 检索""" + if not self.bm25 or not self.documents: + return [] + + try: + bm25_scores = self.bm25.get_scores(query) + if not bm25_scores: + return [] + + # 归一化 BM25 分数到 [0, 1] + max_score = max(bm25_scores) if bm25_scores else 1 + min_score_bm = min(bm25_scores) if bm25_scores else 0 + score_range = max_score - min_score_bm if max_score != min_score_bm else 1 + + results = [] + for idx, score in enumerate(bm25_scores): + if score <= 0: + continue + # 归一化 + normalized_score = (score - min_score_bm) / score_range if score_range > 0 else 0 + doc = self.documents[idx] + results.append({ + "content": doc["content"], + "metadata": doc["metadata"], + "score": float(normalized_score), + "doc_id": doc["id"], + "chunk_index": doc["metadata"].get("chunk_index", 0), + "search_type": "bm25" + }) + + # 按分数降序 + results.sort(key=lambda x: x["score"], reverse=True) + return results[:top_k] + + except Exception as e: + logger.warning(f"BM25 检索失败: {e}") + return [] + + def _hybrid_fusion( + self, + vector_results: List[Dict[str, Any]], + bm25_results: List[Dict[str, Any]], + top_k: int + ) -> List[Dict[str, Any]]: + """ + 混合融合向量和 BM25 检索结果 + + 使用 RRFR (Reciprocal Rank Fusion) 算法: + Score = weight_vector * (1 / rank_vector) + weight_bm25 * (1 / rank_bm25) + + Args: + vector_results: 向量检索结果 + bm25_results: BM25 检索结果 + top_k: 返回数量 + + Returns: + 融合后的结果 + """ + if not vector_results and not bm25_results: + return [] + + # 融合权重 + weight_vector = 0.6 + weight_bm25 = 0.4 + + # 构建文档分数映射 + doc_scores: Dict[str, Dict[str, float]] = {} + + # 添加向量检索结果 + for rank, result in enumerate(vector_results): + doc_id = result["doc_id"] + if doc_id not in doc_scores: + doc_scores[doc_id] = {"vector": 0, "bm25": 0, "content": result["content"], "metadata": result["metadata"]} + # 使用倒数排名 (Reciprocal Rank) + doc_scores[doc_id]["vector"] = weight_vector / (rank + 1) + + # 添加 BM25 检索结果 + for rank, result in enumerate(bm25_results): + doc_id = 
result["doc_id"] + if doc_id not in doc_scores: + doc_scores[doc_id] = {"vector": 0, "bm25": 0, "content": result["content"], "metadata": result["metadata"]} + doc_scores[doc_id]["bm25"] = weight_bm25 / (rank + 1) + + # 计算融合分数 + fused_results = [] + for doc_id, scores in doc_scores.items(): + fused_score = scores["vector"] + scores["bm25"] + # 使用向量检索结果的原始分数作为参考 + vector_score = next((r["score"] for r in vector_results if r["doc_id"] == doc_id), 0.5) + fused_results.append({ + "content": scores["content"], + "metadata": scores["metadata"], + "score": fused_score, + "doc_id": doc_id, + "vector_score": vector_score, + "bm25_score": scores["bm25"], + "search_type": "hybrid" + }) + + # 按融合分数降序排序 + fused_results.sort(key=lambda x: x["score"], reverse=True) + + logger.debug(f"混合融合: {len(fused_results)} 个文档, 向量:{len(vector_results)}, BM25:{len(bm25_results)}") + + return fused_results[:top_k] + def _keyword_search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: """ 关键词搜索后备方案 diff --git a/backend/app/services/template_fill_service.py b/backend/app/services/template_fill_service.py index 9465d35..1486b8b 100644 --- a/backend/app/services/template_fill_service.py +++ b/backend/app/services/template_fill_service.py @@ -13,6 +13,7 @@ from app.services.llm_service import llm_service from app.core.document_parser import ParserFactory from app.services.markdown_ai_service import markdown_ai_service from app.services.rag_service import rag_service +from app.services.word_ai_service import word_ai_service logger = logging.getLogger(__name__) @@ -55,6 +56,249 @@ class FillResult: class TemplateFillService: """表格填写服务""" + # 通用表头语义扩展字典 + GENERIC_HEADER_EXPANSION = { + "机构": ["医院", "学校", "企业", "机关", "团体", "协会", "基金会", "研究所", "医院数量", "学校数量", "企业数量"], + "名称": ["医院名称", "学校名称", "企业名称", "机构名称", "单位名称", "名称"], + "类型": ["医院类型", "学校类型", "企业类型", "机构类型", "类型分类"], + "数量": ["医院数量", "学校数量", "企业数量", "机构数量", "个数", "总数", "人员数量"], + "金额": ["金额", "收入", "支出", "产值", "销售额", "利润", "税收"], + "比率": 
["增长率", "占比", "比重", "比率", "百分比", "使用率", "就业率"], + "面积": ["占地面积", "建筑面积", "用地面积", "耕地面积", "绿化面积"], + "人口": ["常住人口", "户籍人口", "流动人口", "城镇人口", "农村人口"], + "价格": ["价格", "物价", "CPI", "涨幅", "指数"], + "增长": ["增速", "增长率", "增幅", "增长", "上涨", "下降"], + } + + # 模板表头到源文档表头的映射缓存 + _header_mapping_cache: Dict[str, Dict[str, str]] = {} + + def _analyze_source_table_structure(self, source_docs: List["SourceDocument"]) -> Dict[str, Any]: + """ + 分析源文档的表格结构 + + Args: + source_docs: 源文档列表 + + Returns: + 表格结构分析结果,包含所有表头和样本数据 + """ + table_structures = {} + + for doc_idx, doc in enumerate(source_docs): + structured = doc.structured_data if doc.structured_data else {} + + # 处理多 sheet 格式 + if structured.get("sheets"): + for sheet_name, sheet_data in structured.get("sheets", {}).items(): + if isinstance(sheet_data, dict): + columns = sheet_data.get("columns", []) + rows = sheet_data.get("rows", [])[:10] # 只取前10行作为样本 + key = f"doc{doc_idx}_{sheet_name}" + table_structures[key] = { + "doc_idx": doc_idx, + "sheet_name": sheet_name, + "columns": columns, + "sample_rows": rows, + "column_count": len(columns), + "row_count": len(sheet_data.get("rows", [])) + } + + # 处理 tables 格式 + elif structured.get("tables"): + for table_idx, table in enumerate(structured.get("tables", [])[:5]): + if isinstance(table, dict): + headers = table.get("headers", []) + rows = table.get("rows", [])[:10] + key = f"doc{doc_idx}_table{table_idx}" + table_structures[key] = { + "doc_idx": doc_idx, + "table_idx": table_idx, + "columns": headers, + "sample_rows": rows, + "column_count": len(headers), + "row_count": len(table.get("rows", [])) + } + + # 处理单 sheet 格式 + elif structured.get("columns") and structured.get("rows"): + columns = structured.get("columns", []) + rows = structured.get("rows", [])[:10] + key = f"doc{doc_idx}_default" + table_structures[key] = { + "doc_idx": doc_idx, + "columns": columns, + "sample_rows": rows, + "column_count": len(columns), + "row_count": len(structured.get("rows", [])) + } + + 
logger.info(f"分析源文档表格结构: {len(table_structures)} 个表格") + return table_structures + + def _build_adaptive_header_mapping( + self, + template_fields: List["TemplateField"], + source_table_structures: Dict[str, Any] + ) -> Dict[str, Dict[str, Any]]: + """ + 自适应构建模板表头到源文档表头的映射 + + Args: + template_fields: 模板字段列表 + source_table_structures: 源文档表格结构 + + Returns: + 映射字典: {field_name: {source_table_key: {column: idx, match_score: score}}} + """ + mappings = {} + + for field in template_fields: + field_name = field.name + field_lower = field_name.lower() + field_keywords = set(field_lower.replace(" ", "").split()) + + best_matches = {} + + for table_key, table_info in source_table_structures.items(): + columns = table_info.get("columns", []) + if not columns: + continue + + best_col_idx = None + best_col_name = None + best_score = 0 + + for col_idx, col in enumerate(columns): + col_str = str(col).strip() + col_lower = col_str.lower() + col_keywords = set(col_lower.replace(" ", "").split()) + + score = 0 + + # 1. 精确匹配 + if col_lower == field_lower: + score = 1.0 + # 2. 子字符串匹配 + elif field_lower in col_lower or col_lower in field_lower: + score = 0.8 * max(len(field_lower), len(col_lower)) / min(len(field_lower) + 1, len(col_lower) + 1) + # 3. 关键词重叠 + else: + overlap = field_keywords & col_keywords + if overlap: + score = 0.6 * len(overlap) / max(len(field_keywords), len(col_keywords), 1) + + # 4. 
检查通用表头扩展 + if score < 0.5: + for generic, specifics in self.GENERIC_HEADER_EXPANSION.items(): + if generic in field_lower: + for specific in specifics: + if specific in col_lower or col_lower in specific: + score = 0.7 + break + if score >= 0.5: + break + + if score > best_score: + best_score = score + best_col_idx = col_idx + best_col_name = col_str + + if best_score >= 0.3 and best_col_idx is not None: + best_matches[table_key] = { + "column_index": best_col_idx, + "column_name": best_col_name, + "match_score": best_score, + "table_info": table_info + } + + if best_matches: + mappings[field_name] = best_matches + logger.info(f"字段 '{field_name}' 匹配到 {len(best_matches)} 个源表头,最佳匹配: {list(best_matches.values())[0].get('column_name')}") + + return mappings + + def _extract_with_adaptive_mapping( + self, + source_docs: List["SourceDocument"], + field_name: str, + mapping: Dict[str, Dict[str, Any]] + ) -> List[str]: + """ + 使用自适应映射提取字段值 + + Args: + source_docs: 源文档列表 + field_name: 字段名 + mapping: 字段到源表头的映射 + + Returns: + 提取的值列表 + """ + values = [] + + if field_name not in mapping: + return values + + best_matches = mapping[field_name] + + for table_key, match_info in best_matches.items(): + table_info = match_info.get("table_info", {}) + col_idx = match_info.get("column_index", 0) + doc_idx = table_info.get("doc_idx", 0) + + if doc_idx >= len(source_docs): + continue + + doc = source_docs[doc_idx] + structured = doc.structured_data if doc.structured_data else {} + + # 根据表格类型提取值 + rows = [] + + # 多 sheet 格式 + if structured.get("sheets"): + sheet_name = table_info.get("sheet_name") + if sheet_name: + sheet_data = structured.get("sheets", {}).get(sheet_name, {}) + rows = sheet_data.get("rows", []) + + # tables 格式 + elif structured.get("tables"): + table_idx = table_info.get("table_idx", 0) + tables = structured.get("tables", []) + if table_idx < len(tables): + rows = tables[table_idx].get("rows", []) + + # 单 sheet 格式 + elif structured.get("rows"): + rows = 
structured.get("rows", []) + + # 提取指定列的值 + for row in rows: + if isinstance(row, list) and col_idx < len(row): + val = self._format_value(row[col_idx]) + if val and self._is_valid_data_value(val): + values.append(val) + elif isinstance(row, dict): + # 对于 dict 格式的行 + columns = table_info.get("columns", []) + if col_idx < len(columns): + col_name = columns[col_idx] + val = self._format_value(row.get(col_name, "")) + if val and self._is_valid_data_value(val): + values.append(val) + + # 过滤和去重 + seen = set() + unique_values = [] + for v in values: + if v not in seen: + seen.add(v) + unique_values.append(v) + + return unique_values + def __init__(self): self.llm = llm_service @@ -305,6 +549,62 @@ class TemplateFillService: if source_file_paths: for file_path in source_file_paths: try: + file_ext = file_path.lower().split('.')[-1] + + # 对于 Word 文档,优先使用 AI 解析 + if file_ext == 'docx': + # 使用 AI 深度解析 Word 文档 + ai_result = await word_ai_service.parse_word_with_ai( + file_path=file_path, + user_hint="请提取文档中的所有结构化数据,包括表格、键值对等" + ) + + if ai_result.get("success"): + # AI 解析成功,转换为 SourceDocument 格式 + parse_type = ai_result.get("type", "unknown") + + # 构建 structured_data + doc_structured = { + "ai_parsed": True, + "parse_type": parse_type, + "tables": [], + "key_values": ai_result.get("key_values", {}) if "key_values" in ai_result else {}, + "list_items": ai_result.get("list_items", []) if "list_items" in ai_result else [], + "summary": ai_result.get("summary", "") if "summary" in ai_result else "" + } + + # 如果 AI 返回了表格数据 + if parse_type == "table_data": + headers = ai_result.get("headers", []) + rows = ai_result.get("rows", []) + if headers and rows: + doc_structured["tables"] = [{ + "headers": headers, + "rows": rows + }] + doc_structured["columns"] = headers + doc_structured["rows"] = rows + logger.info(f"AI 表格数据: {len(headers)} 列, {len(rows)} 行") + elif parse_type == "structured_text": + tables = ai_result.get("tables", []) + if tables: + doc_structured["tables"] = tables + 
logger.info(f"AI 结构化文本提取到 {len(tables)} 个表格") + + # 获取摘要内容 + content_text = doc_structured.get("summary", "") or ai_result.get("description", "") + + source_docs.append(SourceDocument( + doc_id=file_path, + filename=file_path.split("/")[-1] if "/" in file_path else file_path.split("\\")[-1], + doc_type="docx", + content=content_text, + structured_data=doc_structured + )) + logger.info(f"AI 解析 Word 文档: {file_path}, type={parse_type}, tables={len(doc_structured.get('tables', []))}") + continue # 跳后续的基础解析 + + # 基础解析(Excel 或非 AI 解析的 Word) parser = ParserFactory.get_parser(file_path) result = parser.parse(file_path) if result.success: @@ -1351,6 +1651,36 @@ class TemplateFillService: if all_values: break + # 处理 AI 解析的 Word 文档键值对格式: {key_values: {"键": "值"}, ...} + if structured.get("key_values") and isinstance(structured.get("key_values"), dict): + key_values = structured.get("key_values", {}) + logger.info(f" 检测到 AI 解析键值对格式,共 {len(key_values)} 个键值对") + values = self._extract_from_key_values(key_values, field_name) + if values: + all_values.extend(values) + logger.info(f"从 Word AI 键值对提取到 {len(values)} 个值: {values}") + break + + # 处理 AI 解析的 list_items 格式 + if structured.get("list_items") and isinstance(structured.get("list_items"), list): + list_items = structured.get("list_items", []) + logger.info(f" 检测到 AI 解析列表格式,共 {len(list_items)} 个列表项") + values = self._extract_from_list_items(list_items, field_name) + if values: + all_values.extend(values) + logger.info(f"从 Word AI 列表提取到 {len(values)} 个值") + break + + # 如果从结构化数据中没有提取到值,且字段是通用表头,搜索文本内容 + if not all_values and field_name in self.GENERIC_HEADER_EXPANSION: + for doc in source_docs: + if doc.content: + text_values = self._search_generic_header_in_text(doc.content, field_name) + if text_values: + all_values.extend(text_values) + logger.info(f"从文本内容通过通用表头匹配提取到 {len(text_values)} 个值") + break + return all_values def _extract_values_from_markdown_table(self, headers: List, rows: List, field_name: str) -> List[str]: @@ 
def _extract_from_key_values(self, key_values: Dict[str, str], field_name: str) -> List[str]:
    """
    Pick the value whose key best matches ``field_name`` from a key/value dict.

    Every key is scored with the following strategies; the best score wins:
      1. Exact match (case-insensitive): returned immediately.
      2. Substring containment in either direction: score in (1, 2], higher
         when the two strings are closer in length, so containment always
         outranks the weaker strategies and the tightest key wins.
      3. Whitespace-token overlap: score in (0, 1].
      4. Character-set overlap (useful for short CJK names): score in (0, 1].

    Args:
        key_values: Mapping such as {"医院数量": "38710个", "床位总数": "456789张"}.
        field_name: Template field name to look up.

    Returns:
        A single-element list holding the best matching value as ``str``, or an
        empty list when no key reaches the 0.2 threshold.
    """
    if not key_values:
        return []

    field_lower = field_name.lower().strip()
    if not field_lower:
        # An empty name is a substring of every key; never match on it.
        return []

    field_chars = set(field_lower.replace(" ", ""))
    # Tokenise on whitespace *before* stripping spaces; stripping first always
    # collapses the name into a single token and makes overlap meaningless.
    field_keywords = set(field_lower.split())

    best_match_key = None
    best_match_score = 0.0

    for key, value in key_values.items():
        key_str = str(key).strip()
        # Skip blank keys and missing/blank values, but keep falsy data such as 0.
        if not key_str or value is None or str(value).strip() == "":
            continue

        key_lower = key_str.lower()

        # Strategy 1: exact match (case-insensitive)
        if key_lower == field_lower:
            logger.info(f"键值对精确匹配: {field_name} -> {key_str}: {value}")
            return [str(value)]

        # Strategy 2: substring containment, ranked by length similarity
        if field_lower in key_lower or key_lower in field_lower:
            shorter, longer = sorted((len(field_lower), len(key_lower)))
            score = 1.0 + shorter / longer
            if score > best_match_score:
                best_match_score = score
                best_match_key = (key_str, value)

        # Strategy 3: whitespace-token overlap
        key_keywords = set(key_lower.split())
        overlap = field_keywords & key_keywords
        if overlap:
            score = len(overlap) / max(len(field_keywords), len(key_keywords), 1)
            if score > best_match_score:
                best_match_score = score
                best_match_key = (key_str, value)

        # Strategy 4: character-set overlap
        key_chars = set(key_lower.replace(" ", ""))
        char_overlap = field_chars & key_chars
        if char_overlap:
            char_score = len(char_overlap) / max(len(field_chars), len(key_chars), 1)
            if len(field_chars) <= 4 and char_score >= 0.5:
                # Short names: half of the characters in common is enough.
                if char_score > best_match_score:
                    best_match_score = char_score
                    best_match_key = (key_str, value)
            elif char_score > best_match_score and len(char_overlap) >= 2:
                # Otherwise demand at least two shared characters.
                best_match_score = char_score
                best_match_key = (key_str, value)

    # Deliberately low threshold (0.2) to allow loose fuzzy matches.
    if best_match_score >= 0.2 and best_match_key:
        logger.info(f"键值对模糊匹配: {field_name} -> {best_match_key[0]}: {best_match_key[1]} (分数: {best_match_score:.2f})")
        return [str(best_match_key[1])]

    logger.warning(f"键值对未匹配到: {field_name}, 可用键: {list(key_values.keys())}")
    return []


def _extract_from_list_items(self, list_items: List[str], field_name: str) -> List[str]:
    """
    Collect values for ``field_name`` from free-form list items.

    Each item is handled by the first strategy that applies:
      1. ``key: value`` items (ASCII or full-width colon). An exact key match
         returns that value immediately; a substring or token-overlap key match
         contributes the value part only.
      2. The field name and the whole item contain one another: the whole item
         is contributed.
      3. At least two whitespace tokens are shared with the item: the whole
         item is contributed.

    Args:
        list_items: Items such as ["医院数量: 38710个", "床位总数: 456789张", ...].
        field_name: Template field name to look up.

    Returns:
        Matched values in item order (one entry per matching item), or an
        empty list.
    """
    if not list_items:
        return []

    field_lower = field_name.lower().strip()
    if not field_lower:
        # An empty name is contained in every item; never match on it.
        return []
    field_keywords = set(field_lower.split())

    matched_values = []

    for item in list_items:
        item_str = str(item).strip()
        if not item_str:
            continue

        item_lower = item_str.lower()

        # Strategy 1: "key: value" shaped items. \uff1a is the full-width colon.
        key, sep, value = item_str.replace("\uff1a", ":").partition(":")
        if sep:
            key = key.strip()
            value = value.strip()
            key_lower = key.lower()

            if key_lower and value:
                if key_lower == field_lower:
                    logger.info(f"列表项键值精确匹配: {field_name} -> {value}")
                    return [value]

                key_keywords = set(key_lower.split())
                overlap = field_keywords & key_keywords
                score = len(overlap) / max(len(field_keywords), len(key_keywords), 1)
                if field_lower in key_lower or key_lower in field_lower:
                    shorter, longer = sorted((len(field_lower), len(key_lower)))
                    score = 1.0 + shorter / longer

                if score >= 0.2:
                    logger.info(f"列表项键值模糊匹配: {field_name} -> {key}: {value} (分数: {score:.2f})")
                    matched_values.append(value)
                    # The value is recorded; do not also append the raw item.
                    continue

        # Strategy 2: containment against the whole item
        if field_lower in item_lower or item_lower in field_lower:
            matched_values.append(item_str)
            continue

        # Strategy 3: at least two shared whitespace tokens
        item_keywords = set(item_lower.split())
        overlap = field_keywords & item_keywords
        if len(overlap) >= 2:
            score = len(overlap) / max(len(field_keywords), len(item_keywords), 1)
            if score >= 0.2:
                matched_values.append(item_str)

    if matched_values:
        logger.info(f"列表项匹配到 {len(matched_values)} 个: {matched_values[:5]}")

    return matched_values
def _search_row_in_first_column(self, rows: List, field_name: str) -> Optional[int]:
    """
    Find the row whose first cell names ``field_name`` (transposed tables).

    Some statistical documents lay their tables out transposed: the first
    column holds the indicator name (e.g. "医院数量") and the remaining
    columns hold years or figures. This scans the first cell of every
    list-shaped row and returns the first row that matches.

    Args:
        rows: Table rows; rows that are not non-empty lists are ignored.
        field_name: Indicator name to search for.

    Returns:
        Index into ``rows`` of the first matching row, or ``None`` when no
        row matches.
    """
    if not rows or not field_name:
        return None

    field_lower = field_name.lower().strip()
    if not field_lower:
        # A blank name is contained in every cell; never match on it.
        return None

    field_chars = set(field_lower.replace(" ", ""))
    # Tokenise on whitespace before stripping spaces, otherwise the set always
    # has a single element and the two-token overlap rule can never fire.
    field_keywords = set(field_lower.split())

    for row_idx, row in enumerate(rows):
        if not isinstance(row, list) or len(row) == 0:
            continue

        first_cell = str(row[0]).strip()
        if not first_cell:
            continue

        first_cell_lower = first_cell.lower()

        # Exact match
        if first_cell_lower == field_lower:
            logger.info(f"第一列精确匹配字段: {field_name} -> {first_cell} (行{row_idx})")
            return row_idx

        # Substring containment in either direction is accepted as-is; the
        # logged score is the length ratio of the two strings.
        if field_lower in first_cell_lower or first_cell_lower in field_lower:
            shorter, longer = sorted((len(field_lower), len(first_cell_lower)))
            score = shorter / longer
            logger.info(f"第一列模糊匹配字段: {field_name} -> {first_cell} (行{row_idx}, 分数:{score:.2f})")
            return row_idx

        # Whitespace-token overlap: at least two shared tokens
        first_keywords = set(first_cell_lower.split())
        overlap = field_keywords & first_keywords
        if len(overlap) >= 2:
            score = len(overlap) / max(len(field_keywords), len(first_keywords), 1)
            if score >= 0.3:
                logger.info(f"第一列关键词匹配: {field_name} -> {first_cell} (行{row_idx}, 分数:{score:.2f})")
                return row_idx

        # Character-set overlap, only for short field names
        first_chars = set(first_cell_lower.replace(" ", ""))
        char_overlap = field_chars & first_chars
        if char_overlap and len(field_chars) <= 4:
            char_score = len(char_overlap) / max(len(field_chars), len(first_chars), 1)
            if char_score >= 0.5:
                logger.info(f"第一列字符匹配: {field_name} -> {first_cell} (行{row_idx}, 分数:{char_score:.2f})")
                return row_idx

    return None


def _search_generic_header_in_text(self, text: str, field_name: str) -> List[str]:
    """
    Search free text for figures attached to the specific terms of a generic header.

    ``self.GENERIC_HEADER_EXPANSION`` maps a generic header (e.g. "机构") to
    specific terms (e.g. "医院", "学校", "企业"). For every term three
    patterns are tried: the term followed by digits/unit characters, the term
    followed by a colon or whitespace and a number, and a number followed
    later by the term.

    Args:
        text: Document text to search.
        field_name: Field name; only keys of ``GENERIC_HEADER_EXPANSION``
            produce results.

    Returns:
        Distinct matches (shorter than 100 characters) in first-seen order.
    """
    import re

    generic_terms = self.GENERIC_HEADER_EXPANSION.get(field_name, [])
    if not generic_terms or not text:
        return []

    # A number with optional decimal/thousands groups. A separator must be
    # followed by digits, so "12.5" is kept whole and sentence punctuation
    # after a number is not captured. \uff0c is the full-width comma.
    number = r'\d+(?:[.,\uff0c]\d+)*'

    matched_values = []

    for term in generic_terms:
        escaped = re.escape(term)
        patterns = [
            rf'{escaped}[\s\d所个家级人万元亿元%‰]+',   # e.g. 医院100所, 企业50家
            rf'{escaped}[:\uff1a\s]+({number})',        # e.g. 医院:100 (\uff1a = full-width colon)
            rf'({number})[^\d]*{escaped}',              # e.g. 100家医院
        ]
        for pattern in patterns:
            for match in re.findall(pattern, text, re.IGNORECASE):
                val = match.strip()
                if val and len(val) < 100:
                    matched_values.append(val)

    # Order-preserving dedup.
    unique_values = list(dict.fromkeys(matched_values))

    if unique_values:
        logger.info(f"通用表头 '{field_name}' 匹配到值: {unique_values[:10]}")

    return unique_values
表格提取成功: {len(result.get('rows', []))} 行数据, key_values={len(result.get('key_values', {}))}, list_items={len(result.get('list_items', []))}") return { "success": True, "type": "table_data", "headers": result.get("headers", []), "rows": result.get("rows", []), - "description": result.get("description", "") + "description": result.get("description", ""), + "key_values": result.get("key_values", {}), + "list_items": result.get("list_items", []) } else: # 如果 AI 返回格式不对,尝试直接解析表格 diff --git a/frontend/src/db/backend-api.ts b/frontend/src/db/backend-api.ts index 41286a8..24973be 100644 --- a/frontend/src/db/backend-api.ts +++ b/frontend/src/db/backend-api.ts @@ -1459,4 +1459,131 @@ export const aiApi = { throw error; } }, + + // ==================== 智能指令 ==================== + + /** + * 识别自然语言指令的意图 + */ + async recognizeIntent( + instruction: string, + docIds?: string[] + ): Promise<{ + success: boolean; + intent: string; + params: Record; + message: string; + }> { + const url = `${BACKEND_BASE_URL}/instruction/recognize`; + + try { + const response = await fetch(url, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ instruction, doc_ids: docIds }), + }); + + if (!response.ok) { + const error = await response.json(); + throw new Error(error.detail || '意图识别失败'); + } + + return await response.json(); + } catch (error) { + console.error('意图识别失败:', error); + throw error; + } + }, + + /** + * 执行自然语言指令 + */ + async executeInstruction( + instruction: string, + docIds?: string[], + context?: Record + ): Promise<{ + success: boolean; + intent: string; + result: Record; + message: string; + }> { + const url = `${BACKEND_BASE_URL}/instruction/execute`; + + try { + const response = await fetch(url, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ instruction, doc_ids: docIds, context }), + }); + + if (!response.ok) { + const error = await response.json(); + throw new Error(error.detail || 
'指令执行失败'); + } + + return await response.json(); + } catch (error) { + console.error('指令执行失败:', error); + throw error; + } + }, + + /** + * 智能对话(支持多轮对话的指令执行) + */ + async instructionChat( + instruction: string, + docIds?: string[], + context?: Record + ): Promise<{ + success: boolean; + intent: string; + result: Record; + message: string; + hint?: string; + }> { + const url = `${BACKEND_BASE_URL}/instruction/chat`; + + try { + const response = await fetch(url, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ instruction, doc_ids: docIds, context }), + }); + + if (!response.ok) { + const error = await response.json(); + throw new Error(error.detail || '对话处理失败'); + } + + return await response.json(); + } catch (error) { + console.error('对话处理失败:', error); + throw error; + } + }, + + /** + * 获取支持的指令类型列表 + */ + async getSupportedIntents(): Promise<{ + intents: Array<{ + intent: string; + name: string; + examples: string[]; + params: string[]; + }>; + }> { + const url = `${BACKEND_BASE_URL}/instruction/intents`; + + try { + const response = await fetch(url); + if (!response.ok) throw new Error('获取指令列表失败'); + return await response.json(); + } catch (error) { + console.error('获取指令列表失败:', error); + throw error; + } + }, }; diff --git a/frontend/src/pages/InstructionChat.tsx b/frontend/src/pages/InstructionChat.tsx index cd6bdab..d1a14b6 100644 --- a/frontend/src/pages/InstructionChat.tsx +++ b/frontend/src/pages/InstructionChat.tsx @@ -10,7 +10,11 @@ import { TableProperties, ChevronRight, ArrowRight, - Loader2 + Loader2, + Download, + Search, + MessageSquare, + CheckCircle } from 'lucide-react'; import { Button } from '@/components/ui/button'; import { Input } from '@/components/ui/input'; @@ -26,12 +30,15 @@ type ChatMessage = { role: 'user' | 'assistant'; content: string; created_at: string; + intent?: string; + result?: any; }; const InstructionChat: React.FC = () => { const [messages, setMessages] = useState([]); const 
[input, setInput] = useState(''); const [loading, setLoading] = useState(false); + const [currentDocIds, setCurrentDocIds] = useState([]); const scrollAreaRef = useRef(null); useEffect(() => { @@ -43,27 +50,47 @@ const InstructionChat: React.FC = () => { role: 'assistant', content: `您好!我是智联文档 AI 助手。 -我可以帮您完成以下操作: +**📄 文档智能操作** +- "提取文档中的医院数量和床位数" +- "帮我找出所有机构的名称" -📄 **文档管理** -- "帮我列出最近上传的所有文档" -- "删除三天前的 docx 文档" +**📊 数据填表** +- "根据这些数据填表" +- "将提取的信息填写到Excel模板" -📊 **Excel 分析** -- "分析一下最近上传的 Excel 文件" -- "帮我统计销售报表中的数据" +**📝 内容处理** +- "总结一下这份文档" +- "对比这两个文档的差异" -📝 **智能填表** -- "根据员工信息表创建一个考勤汇总表" -- "用财务文档填充报销模板" +**🔍 智能问答** +- "文档里说了些什么?" +- "有多少家医院?" 请告诉我您想做什么?`, created_at: new Date().toISOString() } ]); + + // 获取已上传的文档ID列表 + loadDocuments(); } }, []); + const loadDocuments = async () => { + try { + const result = await backendApi.getDocuments(undefined, 50); + if (result.success && result.documents) { + const docIds = result.documents.map((d: any) => d.doc_id); + setCurrentDocIds(docIds); + if (docIds.length > 0) { + console.log(`已加载 ${docIds.length} 个文档`); + } + } + } catch (err) { + console.error('获取文档列表失败:', err); + } + }; + useEffect(() => { // Scroll to bottom if (scrollAreaRef.current) { @@ -89,95 +116,126 @@ const InstructionChat: React.FC = () => { setLoading(true); try { - // TODO: 后端对话接口,暂用模拟响应 - await new Promise(resolve => setTimeout(resolve, 1500)); + // 使用真实的智能指令 API + const response = await backendApi.instructionChat( + input.trim(), + currentDocIds.length > 0 ? 
currentDocIds : undefined + ); - // 简单的命令解析演示 - const userInput = userMessage.content.toLowerCase(); - let response = ''; + // 根据意图类型生成友好响应 + let responseContent = ''; + const resultData = response.result; - if (userInput.includes('列出') || userInput.includes('列表')) { - const result = await backendApi.getDocuments(undefined, 10); - if (result.success && result.documents && result.documents.length > 0) { - response = `已为您找到 ${result.documents.length} 个文档:\n\n`; - result.documents.slice(0, 5).forEach((doc: any, idx: number) => { - response += `${idx + 1}. **${doc.original_filename}** (${doc.doc_type.toUpperCase()})\n`; - response += ` - 大小: ${(doc.file_size / 1024).toFixed(1)} KB\n`; - response += ` - 时间: ${new Date(doc.created_at).toLocaleDateString()}\n\n`; - }); - if (result.documents.length > 5) { - response += `...还有 ${result.documents.length - 5} 个文档`; + switch (response.intent) { + case 'extract': + // 信息提取结果 + const extracted = resultData?.extracted_data || {}; + const keys = Object.keys(extracted); + if (keys.length > 0) { + responseContent = `✅ 已提取到 ${keys.length} 个字段的数据:\n\n`; + for (const [key, value] of Object.entries(extracted)) { + const values = Array.isArray(value) ? value : [value]; + responseContent += `**${key}**: ${values.slice(0, 3).join(', ')}${values.length > 3 ? '...' : ''}\n`; + } + responseContent += `\n💡 您可以将这些数据填入表格。`; + } else { + responseContent = '未能从文档中提取到相关数据。请尝试更明确的字段名称。'; } - } else { - response = '暂未找到已上传的文档,您可以先上传一些文档试试。'; - } - } else if (userInput.includes('分析') || userInput.includes('excel') || userInput.includes('报表')) { - response = `好的,我可以帮您分析 Excel 文件。 + break; -请告诉我: -1. 您想分析哪个 Excel 文件? -2. 
需要什么样的分析?(数据摘要/统计分析/图表生成) + case 'fill_table': + // 填表结果 + const filled = resultData?.result?.filled_data || {}; + const filledKeys = Object.keys(filled); + if (filledKeys.length > 0) { + responseContent = `✅ 填表完成!成功填写 ${filledKeys.length} 个字段:\n\n`; + for (const [key, value] of Object.entries(filled)) { + const values = Array.isArray(value) ? value : [value]; + responseContent += `**${key}**: ${values.slice(0, 3).join(', ')}\n`; + } + responseContent += `\n📋 请到【智能填表】页面查看或导出结果。`; + } else { + responseContent = '填表未能提取到数据。请检查模板表头和数据源内容。'; + } + break; -或者您可以直接告诉我您想从数据中了解什么,我来为您生成分析。`; - } else if (userInput.includes('填表') || userInput.includes('模板')) { - response = `好的,要进行智能填表,我需要: + case 'summarize': + // 摘要结果 + const summaries = resultData?.summaries || []; + if (summaries.length > 0) { + responseContent = `📄 找到 ${summaries.length} 个文档的摘要:\n\n`; + summaries.forEach((s: any, idx: number) => { + responseContent += `**${idx + 1}. ${s.filename}**\n${s.content_preview}\n\n`; + }); + } else { + responseContent = '未能生成摘要。请确保已上传文档。'; + } + break; -1. **上传表格模板** - 您要填写的表格模板文件(Excel 或 Word 格式) -2. **选择数据源** - 包含要填写内容的源文档 + case 'question': + // 问答结果 + if (resultData?.answer) { + responseContent = `**问题**: ${resultData.question}\n\n**答案**: ${resultData.answer}`; + } else { + responseContent = resultData?.message || '我找到了相关信息,请查看上文。'; + } + break; -您可以去【智能填表】页面完成这些操作,或者告诉我您具体想填什么类型的表格,我来指导您操作。`; - } else if (userInput.includes('删除')) { - response = `要删除文档,请告诉我: + case 'search': + // 搜索结果 + const searchResults = resultData?.results || []; + if (searchResults.length > 0) { + responseContent = `🔍 找到 ${searchResults.length} 条相关内容:\n\n`; + searchResults.slice(0, 5).forEach((r: any, idx: number) => { + responseContent += `**${idx + 1}.** ${r.content?.substring(0, 100)}...\n\n`; + }); + } else { + responseContent = '未找到相关内容。请尝试其他关键词。'; + } + break; -- 要删除的文件名是什么? 
-- 或者您可以到【文档中心】页面手动选择并删除文档 + case 'compare': + // 对比结果 + const comparison = resultData?.comparison || []; + if (comparison.length > 0) { + responseContent = `📊 对比了 ${comparison.length} 个文档:\n\n`; + comparison.forEach((c: any) => { + responseContent += `- **${c.filename}**: ${c.doc_type}, ${c.content_length} 字\n`; + }); + } else { + responseContent = '需要至少2个文档才能进行对比。'; + } + break; -⚠️ 删除操作不可恢复,请确认后再操作。`; - } else if (userInput.includes('帮助') || userInput.includes('help')) { - response = `**我可以帮您完成以下操作:** + case 'unknown': + responseContent = `我理解您想要: "${input.trim()}"\n\n但我目前无法完成此操作。您可以尝试:\n\n1. **提取数据**: "提取医院数量和床位数"\n2. **填表**: "根据这些数据填表"\n3. **总结**: "总结这份文档"\n4. **问答**: "文档里说了什么?"\n5. **搜索**: "搜索相关内容"`; + break; -📄 **文档管理** -- 列出/搜索已上传的文档 -- 查看文档详情和元数据 -- 删除不需要的文档 - -📊 **Excel 处理** -- 分析 Excel 文件内容 -- 生成数据统计和图表 -- 导出处理后的数据 - -📝 **智能填表** -- 上传表格模板 -- 从文档中提取信息填入模板 -- 导出填写完成的表格 - -📋 **任务历史** -- 查看历史处理任务 -- 重新执行或导出结果 - -请直接告诉我您想做什么!`; - } else { - response = `我理解您想要: "${input.trim()}" - -目前我还在学习如何更好地理解您的需求。您可以尝试: - -1. **上传文档** - 去【文档中心】上传 docx/md/txt 文件 -2. **分析 Excel** - 去【Excel解析】上传并分析 Excel 文件 -3. 
**智能填表** - 去【智能填表】创建填表任务 - -或者您可以更具体地描述您想做的事情,我会尽力帮助您!`; + default: + responseContent = response.message || resultData?.message || '已完成您的请求。'; } const assistantMessage: ChatMessage = { id: Math.random().toString(36).substring(7), role: 'assistant', - content: response, - created_at: new Date().toISOString() + content: responseContent, + created_at: new Date().toISOString(), + intent: response.intent, + result: resultData }; setMessages(prev => [...prev, assistantMessage]); } catch (err: any) { - toast.error('请求失败,请重试'); + console.error('指令执行失败:', err); + toast.error(err.message || '请求失败,请重试'); + + const errorMessage: ChatMessage = { + id: Math.random().toString(36).substring(7), + role: 'assistant', + content: `抱歉,处理您的请求时遇到了问题:${err.message}\n\n请稍后重试,或尝试更简单的指令。`, + created_at: new Date().toISOString() + }; + setMessages(prev => [...prev, errorMessage]); } finally { setLoading(false); } @@ -189,10 +247,10 @@ const InstructionChat: React.FC = () => { }; const quickActions = [ - { label: '列出所有文档', icon: FileText, action: () => setInput('列出所有已上传的文档') }, - { label: '分析 Excel 数据', icon: TableProperties, action: () => setInput('分析一下 Excel 文件') }, - { label: '智能填表', icon: Sparkles, action: () => setInput('我想进行智能填表') }, - { label: '帮助', icon: Sparkles, action: () => setInput('帮助') } + { label: '提取医院数量', icon: Search, action: () => setInput('提取文档中的医院数量和床位数') }, + { label: '智能填表', icon: TableProperties, action: () => setInput('根据这些数据填表') }, + { label: '总结文档', icon: MessageSquare, action: () => setInput('总结一下这份文档') }, + { label: '智能问答', icon: Bot, action: () => setInput('文档里说了些什么?') } ]; return (