feat: 实现智能指令的格式转换和文档编辑功能
主要更新:
- 新增 transform 意图:支持 Word/Excel/Markdown 格式互转
- 新增 edit 意图:使用 LLM 润色编辑文档内容
- 智能指令接口增加异步执行模式(async_execute 参数)
- 修复 Word 模板导出文档损坏问题(改用临时文件方式)
- 优化 intent_parser 增加 transform/edit 关键词识别

新增文件:
- app/api/endpoints/instruction.py: 智能指令 API 端点
- app/services/multi_doc_reasoning_service.py: 多文档推理服务

其他优化:
- RAG 服务混合搜索(BM25 + 向量)融合
- 模板填充服务表头匹配增强
- Word AI 解析服务返回结构完善
- 前端 InstructionChat 组件对接真实 API
This commit is contained in:
@@ -13,6 +13,7 @@ from app.api.endpoints import (
|
||||
visualization,
|
||||
analysis_charts,
|
||||
health,
|
||||
instruction, # 智能指令
|
||||
)
|
||||
|
||||
# 创建主路由
|
||||
@@ -29,3 +30,4 @@ api_router.include_router(templates.router) # 表格模板
|
||||
api_router.include_router(ai_analyze.router) # AI分析
|
||||
api_router.include_router(visualization.router) # 可视化
|
||||
api_router.include_router(analysis_charts.router) # 分析图表
|
||||
api_router.include_router(instruction.router) # 智能指令
|
||||
|
||||
439
backend/app/api/endpoints/instruction.py
Normal file
439
backend/app/api/endpoints/instruction.py
Normal file
@@ -0,0 +1,439 @@
|
||||
"""
|
||||
智能指令 API 接口
|
||||
|
||||
支持自然语言指令解析和执行
|
||||
"""
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.instruction.intent_parser import intent_parser
|
||||
from app.instruction.executor import instruction_executor
|
||||
from app.core.database import mongodb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/instruction", tags=["智能指令"])
|
||||
|
||||
|
||||
# ==================== 请求/响应模型 ====================
|
||||
|
||||
class InstructionRequest(BaseModel):
    """Request body shared by the recognize / execute / chat endpoints."""

    # Natural-language instruction text entered by the user.
    instruction: str
    # IDs of the documents the instruction refers to (optional).
    doc_ids: Optional[List[str]] = None
    # Extra context forwarded to the executor (optional).
    context: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class IntentRecognitionResponse(BaseModel):
    """Response body of the intent-recognition endpoint."""

    # False when parsing raised; `intent` is then "error".
    success: bool
    # Intent identifier produced by the parser (e.g. "extract").
    intent: str
    # Structured parameters extracted from the instruction.
    params: Dict[str, Any]
    # Human-readable status message.
    message: str
|
||||
|
||||
|
||||
class InstructionExecutionResponse(BaseModel):
    """Response body returned after an instruction has been executed."""

    # Whether the executor reported success.
    success: bool
    # Intent that was executed, or "error" / "unknown".
    intent: str
    # Raw result dictionary returned by the executor.
    result: Dict[str, Any]
    # Human-readable status message.
    message: str
|
||||
|
||||
|
||||
# ==================== 接口 ====================
|
||||
|
||||
@router.post("/recognize", response_model=IntentRecognitionResponse)
async def recognize_intent(request: InstructionRequest):
    """
    Intent recognition endpoint.

    Parses a natural-language instruction into a structured intent and its
    parameters without executing it.

    Example instructions:
    - "提取文档中的医院数量和床位数"
    - "根据这些数据填表"
    - "总结一下这份文档"
    - "对比这两个文档的差异"
    """
    try:
        intent, params = await intent_parser.parse(request.instruction)

        # Attach the caller-supplied document IDs as "doc_<id>" references.
        if request.doc_ids:
            params["document_refs"] = [f"doc_{doc_id}" for doc_id in request.doc_ids]

        # Display names for the known intents; unknown keys fall back to the raw id.
        display_names = dict(
            extract="信息提取",
            fill_table="表格填写",
            summarize="摘要总结",
            question="智能问答",
            search="文档搜索",
            compare="对比分析",
            transform="格式转换",
            edit="文档编辑",
            unknown="未知",
        )
        label = display_names.get(intent, intent)

        return IntentRecognitionResponse(
            success=True,
            intent=intent,
            params=params,
            message=f"识别到意图: {label}"
        )
    except Exception as e:
        # Failures are reported in the payload rather than as an HTTP error.
        logger.error(f"意图识别失败: {e}")
        failure = IntentRecognitionResponse(
            success=False,
            intent="error",
            params={},
            message=f"意图识别失败: {str(e)}"
        )
        return failure
|
||||
|
||||
|
||||
@router.post("/execute")
async def execute_instruction(
    background_tasks: BackgroundTasks,
    request: InstructionRequest,
    async_execute: bool = Query(False, description="是否异步执行(仅返回任务ID)")
):
    """
    Instruction execution endpoint.

    Parses and executes a natural-language instruction.

    Examples:
    - Instruction: "提取文档1中的医院数量"
      Returns: {"extracted_data": {"医院数量": ["38710个"]}}

    - Instruction: "填表"
      Returns: {"filled_data": {...}}

    Pass async_execute=true to run in the background; the response then only
    carries the task ID and a URL for polling progress.
    """
    task_id = str(uuid.uuid4())

    # Synchronous mode: run the task inline and return its result.
    if not async_execute:
        return await _execute_instruction_task(task_id, request.instruction, request.doc_ids, request.context)

    # Asynchronous mode: schedule the task and answer immediately.
    background_tasks.add_task(
        _execute_instruction_task,
        task_id=task_id,
        instruction=request.instruction,
        doc_ids=request.doc_ids,
        context=request.context
    )
    accepted = {
        "success": True,
        "task_id": task_id,
        "message": "指令已提交执行",
        "status_url": f"/api/v1/tasks/{task_id}"
    }
    return accepted
|
||||
|
||||
|
||||
async def _execute_instruction_task(
    task_id: str,
    instruction: str,
    doc_ids: Optional[List[str]],
    context: Optional[Dict[str, Any]]
) -> InstructionExecutionResponse:
    """
    Run one instruction and record its progress as a task.

    Used both inline (synchronous mode) and as a FastAPI background task.

    Args:
        task_id: Identifier under which the task record is stored.
        instruction: Natural-language instruction to execute.
        doc_ids: Optional IDs of documents to load and pass as ``source_docs``.
        context: Optional extra context; it is copied, never mutated.

    Returns:
        An ``InstructionExecutionResponse``; failures are reported in the
        payload (``success=False``, ``intent="error"``) instead of raising.
    """
    from app.core.database import mongodb as mongo_client

    try:
        # Task bookkeeping is best-effort: a failure here must not block the
        # instruction itself, but it is logged instead of silently dropped.
        try:
            await mongo_client.insert_task(
                task_id=task_id,
                task_type="instruction_execute",
                status="processing",
                message="正在执行指令"
            )
        except Exception as record_err:
            logger.warning("任务记录写入失败 (task_id=%s): %s", task_id, record_err)

        # Work on a shallow copy so the caller's context dict is not modified.
        ctx: Dict[str, Any] = dict(context) if context else {}

        # Load the referenced documents; IDs that resolve to nothing are skipped.
        if doc_ids:
            docs = []
            for doc_id in doc_ids:
                doc = await mongo_client.get_document(doc_id)
                if doc:
                    docs.append(doc)

            if docs:
                ctx["source_docs"] = docs
                logger.info(f"指令执行上下文: 关联了 {len(docs)} 个文档")

        result = await instruction_executor.execute(instruction, ctx)

        try:
            await mongo_client.update_task(
                task_id=task_id,
                status="success",
                message="执行完成",
                result=result
            )
        except Exception as record_err:
            logger.warning("任务状态更新失败 (task_id=%s): %s", task_id, record_err)

        return InstructionExecutionResponse(
            success=result.get("success", False),
            intent=result.get("intent", "unknown"),
            result=result,
            message=result.get("message", "执行完成")
        )

    except Exception as e:
        # logger.exception keeps the traceback for diagnosis.
        logger.exception("指令执行失败: %s", e)
        try:
            await mongo_client.update_task(
                task_id=task_id,
                status="failure",
                message="执行失败",
                error=str(e)
            )
        except Exception as record_err:
            logger.warning("任务状态更新失败 (task_id=%s): %s", task_id, record_err)

        return InstructionExecutionResponse(
            success=False,
            intent="error",
            result={"error": str(e)},
            message=f"指令执行失败: {str(e)}"
        )
|
||||
|
||||
|
||||
@router.post("/chat")
async def instruction_chat(
    background_tasks: BackgroundTasks,
    request: InstructionRequest,
    async_execute: bool = Query(False, description="是否异步执行(仅返回任务ID)")
):
    """
    Instruction chat endpoint.

    Executes instructions as part of a multi-turn conversation.

    Example conversation flow:
    1. User: "上传一些文档"
    2. System: "请上传文档"
    3. User: "提取其中的医院数量"
    4. System: returns the extraction result

    Pass async_execute=true to run in the background; the response then only
    carries the task ID and a URL for polling progress.
    """
    task_id = str(uuid.uuid4())

    # Synchronous mode: run the chat task inline and return its result.
    if not async_execute:
        return await _execute_chat_task(task_id, request.instruction, request.doc_ids, request.context)

    # Asynchronous mode: schedule the task and answer immediately.
    background_tasks.add_task(
        _execute_chat_task,
        task_id=task_id,
        instruction=request.instruction,
        doc_ids=request.doc_ids,
        context=request.context
    )
    accepted = {
        "success": True,
        "task_id": task_id,
        "message": "指令已提交执行",
        "status_url": f"/api/v1/tasks/{task_id}"
    }
    return accepted
|
||||
|
||||
|
||||
async def _execute_chat_task(
    task_id: str,
    instruction: str,
    doc_ids: Optional[List[str]],
    context: Optional[Dict[str, Any]]
):
    """
    Run one chat instruction and record its progress as a task.

    Used both inline (synchronous mode) and as a FastAPI background task.

    Args:
        task_id: Identifier under which the task record is stored.
        instruction: Natural-language instruction to execute.
        doc_ids: Optional IDs of documents to load and pass as ``source_docs``.
        context: Optional extra context; it is copied, never mutated.

    Returns:
        A dict with ``success``, ``intent``, ``result``, ``message`` and
        ``hint`` on the normal path; on an unexpected exception a dict with
        ``success=False``, ``intent="error"``, ``error`` and ``message``.
    """
    from app.core.database import mongodb as mongo_client

    try:
        # Task bookkeeping is best-effort; log failures rather than hide them.
        try:
            await mongo_client.insert_task(
                task_id=task_id,
                task_type="instruction_chat",
                status="processing",
                message="正在处理对话"
            )
        except Exception as record_err:
            logger.warning("任务记录写入失败 (task_id=%s): %s", task_id, record_err)

        # Work on a shallow copy so the caller's context dict is not modified.
        ctx: Dict[str, Any] = dict(context) if context else {}

        # Load the referenced documents; IDs that resolve to nothing are skipped.
        if doc_ids:
            docs = []
            for doc_id in doc_ids:
                doc = await mongo_client.get_document(doc_id)
                if doc:
                    docs.append(doc)
            if docs:
                ctx["source_docs"] = docs

        result = await instruction_executor.execute(instruction, ctx)

        # Friendly message per intent. Built lazily so only the message for
        # the actual intent is evaluated; an unexpected result shape for some
        # other intent can no longer break this response.
        message_builders = {
            "extract": lambda: f"已提取 {len(result.get('extracted_data', {}))} 个字段的数据",
            "fill_table": lambda: f"填表完成,填写了 {len(result.get('result', {}).get('filled_data', {}))} 个字段",
            "summarize": lambda: "已生成文档摘要",
            "question": lambda: "已找到相关答案",
            "search": lambda: f"找到 {len(result.get('results', []))} 条相关内容",
            "compare": lambda: f"对比了 {len(result.get('comparison', []))} 个文档",
            "edit": lambda: "编辑操作已完成",
            "transform": lambda: "格式转换已完成",
            "unknown": lambda: "无法理解该指令,请尝试更明确的描述"
        }
        intent_key = result.get("intent", "")
        build_message = message_builders.get(intent_key)
        message = build_message() if build_message else result.get("message", "")

        response = {
            "success": result.get("success", False),
            "intent": result.get("intent", "unknown"),
            "result": result,
            "message": message,
            "hint": _get_intent_hint(intent_key)
        }

        try:
            await mongo_client.update_task(
                task_id=task_id,
                status="success",
                message="处理完成",
                result=response
            )
        except Exception as record_err:
            logger.warning("任务状态更新失败 (task_id=%s): %s", task_id, record_err)

        return response

    except Exception as e:
        # logger.exception keeps the traceback for diagnosis.
        logger.exception("指令对话失败: %s", e)
        try:
            await mongo_client.update_task(
                task_id=task_id,
                status="failure",
                message="处理失败",
                error=str(e)
            )
        except Exception as record_err:
            logger.warning("任务状态更新失败 (task_id=%s): %s", task_id, record_err)

        # "intent" is included so the error payload has the same key as the
        # success payload and as the /execute error response.
        return {
            "success": False,
            "intent": "error",
            "error": str(e),
            "message": f"处理失败: {str(e)}"
        }
|
||||
|
||||
|
||||
def _get_intent_hint(intent: str) -> Optional[str]:
    """Return a follow-up suggestion for *intent*, or ``None`` if there is none."""
    suggestions = dict([
        ("extract", "您可以继续说 '提取更多字段' 或 '将数据填入表格'"),
        ("fill_table", "您可以提供表格模板或说 '帮我创建一个表格'"),
        ("question", "您可以继续提问或说 '总结一下这些内容'"),
        ("search", "您可以查看搜索结果或说 '对比这些内容'"),
        ("unknown", "您可以尝试: '提取数据'、'填表'、'总结'、'问答' 等指令"),
    ])
    try:
        return suggestions[intent]
    except KeyError:
        # Intents without a registered hint yield no suggestion.
        return None
|
||||
|
||||
|
||||
@router.get("/intents")
async def list_supported_intents():
    """
    List the supported intent types.

    Returns every natural-language instruction type the API can handle,
    with a display name, example instructions and parameter names. The
    ``transform`` intent is listed as well, since it is named in
    ``recognize_intent`` and dispatched by the executor.
    """
    return {
        "intents": [
            {
                "intent": "extract",
                "name": "信息提取",
                "examples": [
                    "提取文档中的医院数量",
                    "抽取所有机构的名称",
                    "找出表格中的数据"
                ],
                "params": ["field_refs", "document_refs"]
            },
            {
                "intent": "fill_table",
                "name": "表格填写",
                "examples": [
                    "填表",
                    "根据这些数据填写表格",
                    "帮我填到Excel里"
                ],
                "params": ["template", "document_refs"]
            },
            {
                "intent": "summarize",
                "name": "摘要总结",
                "examples": [
                    "总结一下这份文档",
                    "生成摘要",
                    "概括主要内容"
                ],
                "params": ["document_refs"]
            },
            {
                "intent": "question",
                "name": "智能问答",
                "examples": [
                    "这段话说的是什么?",
                    "有多少家医院?",
                    "解释一下这个概念"
                ],
                "params": ["question", "focus"]
            },
            {
                "intent": "search",
                "name": "文档搜索",
                "examples": [
                    "搜索相关内容",
                    "找找看有哪些机构",
                    "查询医院相关的数据"
                ],
                "params": ["field_refs", "question"]
            },
            {
                "intent": "compare",
                "name": "对比分析",
                "examples": [
                    "对比这两个文档",
                    "比较一下差异",
                    "找出不同点"
                ],
                "params": ["document_refs"]
            },
            {
                "intent": "edit",
                "name": "文档编辑",
                "examples": [
                    "润色这段文字",
                    "修改格式",
                    "添加注释"
                ],
                "params": []
            },
            {
                # Appended last so the positions of existing entries are unchanged.
                "intent": "transform",
                "name": "格式转换",
                "examples": [
                    "把这份Word转成Excel",
                    "转换为Markdown格式",
                    "导出为Word文档"
                ],
                "params": ["template", "document_refs"]
            }
        ]
    }
|
||||
@@ -610,29 +610,39 @@ async def _export_to_excel(filled_data: dict, template_id: str) -> StreamingResp
|
||||
|
||||
async def _export_to_word(filled_data: dict, template_id: str) -> StreamingResponse:
|
||||
"""导出为 Word 格式"""
|
||||
import re
|
||||
import tempfile
|
||||
import os
|
||||
from docx import Document
|
||||
from docx.shared import Pt, RGBColor
|
||||
from docx.enum.text import WD_ALIGN_PARAGRAPH
|
||||
|
||||
def clean_text(text: str) -> str:
|
||||
"""清理文本,移除可能导致Word问题的非法字符"""
|
||||
if not text:
|
||||
return ""
|
||||
# 移除控制字符
|
||||
text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text)
|
||||
return text.strip()
|
||||
|
||||
try:
|
||||
# 先保存到临时文件,再读取到内存,确保文档完整性
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.docx') as tmp_file:
|
||||
tmp_path = tmp_file.name
|
||||
|
||||
doc = Document()
|
||||
doc.add_heading('填写结果', level=1)
|
||||
|
||||
# 添加标题
|
||||
title = doc.add_heading('填写结果', level=1)
|
||||
title.alignment = WD_ALIGN_PARAGRAPH.CENTER
|
||||
|
||||
# 添加填写时间和模板信息
|
||||
from datetime import datetime
|
||||
info_para = doc.add_paragraph()
|
||||
info_para.add_run(f"模板ID: {template_id}\n").bold = True
|
||||
template_filename = template_id.split('/')[-1].split('\\')[-1] if template_id else '未知'
|
||||
info_para.add_run(f"模板文件: {clean_text(template_filename)}\n").bold = True
|
||||
info_para.add_run(f"导出时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
doc.add_paragraph()
|
||||
|
||||
doc.add_paragraph() # 空行
|
||||
|
||||
# 添加字段表格
|
||||
table = doc.add_table(rows=1, cols=3)
|
||||
table.style = 'Light Grid Accent 1'
|
||||
table.style = 'Table Grid'
|
||||
|
||||
# 表头
|
||||
header_cells = table.rows[0].cells
|
||||
header_cells[0].text = '字段名'
|
||||
header_cells[1].text = '填写值'
|
||||
@@ -640,21 +650,39 @@ async def _export_to_word(filled_data: dict, template_id: str) -> StreamingRespo
|
||||
|
||||
for field_name, field_value in filled_data.items():
|
||||
row_cells = table.add_row().cells
|
||||
row_cells[0].text = field_name
|
||||
row_cells[1].text = str(field_value) if field_value else ''
|
||||
row_cells[2].text = '已填写' if field_value else '为空'
|
||||
row_cells[0].text = clean_text(str(field_name))
|
||||
|
||||
# 保存到 BytesIO
|
||||
output = io.BytesIO()
|
||||
doc.save(output)
|
||||
output.seek(0)
|
||||
if isinstance(field_value, list):
|
||||
clean_values = [clean_text(str(v)) for v in field_value if v]
|
||||
display_value = ', '.join(clean_values) if clean_values else ''
|
||||
else:
|
||||
display_value = clean_text(str(field_value)) if field_value else ''
|
||||
|
||||
filename = f"filled_template.docx"
|
||||
row_cells[1].text = display_value
|
||||
row_cells[2].text = '已填写' if display_value else '为空'
|
||||
|
||||
# 保存到临时文件
|
||||
doc.save(tmp_path)
|
||||
|
||||
# 读取文件内容
|
||||
with open(tmp_path, 'rb') as f:
|
||||
file_content = f.read()
|
||||
|
||||
finally:
|
||||
# 清理临时文件
|
||||
if os.path.exists(tmp_path):
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except:
|
||||
pass
|
||||
|
||||
output = io.BytesIO(file_content)
|
||||
filename = "filled_template.docx"
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(output.getvalue()),
|
||||
output,
|
||||
media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
headers={"Content-Disposition": f"attachment; filename={filename}"}
|
||||
headers={"Content-Disposition": f"attachment; filename*=UTF-8''{filename}"}
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
"""
|
||||
指令执行模块
|
||||
|
||||
注意: 此模块为可选功能,当前尚未实现。
|
||||
如需启用,请实现 intent_parser.py 和 executor.py
|
||||
支持文档智能操作交互,包括意图解析和指令执行
|
||||
"""
|
||||
from .intent_parser import IntentParser, DefaultIntentParser
|
||||
from .executor import InstructionExecutor, DefaultInstructionExecutor
|
||||
from .intent_parser import IntentParser, intent_parser
|
||||
from .executor import InstructionExecutor, instruction_executor
|
||||
|
||||
__all__ = [
|
||||
"IntentParser",
|
||||
"DefaultIntentParser",
|
||||
"intent_parser",
|
||||
"InstructionExecutor",
|
||||
"DefaultInstructionExecutor",
|
||||
"instruction_executor",
|
||||
]
|
||||
|
||||
@@ -2,34 +2,571 @@
|
||||
指令执行器模块
|
||||
|
||||
将自然语言指令转换为可执行操作
|
||||
|
||||
注意: 此模块为可选功能,当前尚未实现。
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict
|
||||
import logging
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.services.template_fill_service import template_fill_service
|
||||
from app.services.rag_service import rag_service
|
||||
from app.services.markdown_ai_service import markdown_ai_service
|
||||
from app.core.database import mongodb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InstructionExecutor(ABC):
|
||||
"""指令执行器抽象基类"""
|
||||
class InstructionExecutor:
|
||||
"""指令执行器"""
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, instruction: str, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def __init__(self):
    """Create an executor with no parser attached yet."""
    # Filled in by set_intent_parser(), or lazily on the first execute() call.
    setattr(self, "intent_parser", None)
|
||||
|
||||
def set_intent_parser(self, intent_parser):
    """Attach the intent parser that execute() will use."""
    setattr(self, "intent_parser", intent_parser)
|
||||
|
||||
async def execute(self, instruction: str, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Parse a natural-language instruction and run the matching handler.

    Args:
        instruction: Natural-language instruction text.
        context: Execution context (document info etc.); ``None`` means empty.

    Returns:
        The handler's result dict. For an unrecognised intent, a failure dict
        with ``intent="unknown"`` so callers can select their "unknown"
        message and hint.
    """
    if self.intent_parser is None:
        # Imported lazily, at first use, rather than at module import time.
        from app.instruction.intent_parser import intent_parser
        self.intent_parser = intent_parser

    context = context or {}

    intent, params = await self.intent_parser.parse(instruction)

    # Make the raw text available to handlers that fall back to it (the
    # transform handler reads params["instruction"] to infer the target
    # format). No-op if the parser already supplied the key.
    # NOTE(review): the parser is not visible here; confirm it does not
    # already populate this key with something else.
    params.setdefault("instruction", instruction)

    # Intent -> handler method name.
    handler_names = {
        "extract": "_execute_extract",
        "fill_table": "_execute_fill_table",
        "summarize": "_execute_summarize",
        "question": "_execute_question",
        "search": "_execute_search",
        "compare": "_execute_compare",
        "edit": "_execute_edit",
        "transform": "_execute_transform",
    }
    handler_name = handler_names.get(intent)
    if handler_name is None:
        return {
            "success": False,
            "intent": "unknown",
            "error": f"未知意图类型: {intent}",
            "message": "无法理解该指令,请尝试更明确的描述"
        }

    handler = getattr(self, handler_name)
    return await handler(params, context)
|
||||
|
||||
async def _execute_extract(self, params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extract the requested fields from the source documents.

    Reads ``field_refs`` and ``document_refs`` from *params*. When explicit
    document references are given (and are not "all_docs"), they are resolved
    and stored in ``context["source_docs"]``, replacing any previous value.

    Returns:
        A result dict with ``extracted_data`` on success, or a failure dict
        with ``error`` and ``message``.
    """
    try:
        target_fields = params.get("field_refs", [])
        doc_refs = params.get("document_refs", [])

        if not target_fields:
            return {
                "success": False,
                "error": "未指定要提取的字段",
                "message": "请明确说明要提取哪些字段,如:'提取医院数量和床位数'"
            }

        # Resolve explicitly referenced documents and make sure at least one exists.
        if doc_refs and "all_docs" not in doc_refs:
            prefix = "doc_"
            valid_docs = []
            for doc_ref in doc_refs:
                # Strip only the leading "doc_" marker; str.replace would also
                # remove the sequence from inside an ID.
                doc_id = doc_ref[len(prefix):] if doc_ref.startswith(prefix) else doc_ref
                doc = await mongodb.get_document(doc_id)
                if doc:
                    valid_docs.append(doc)
            if not valid_docs:
                return {
                    "success": False,
                    "error": "指定的文档不存在",
                    "message": "请检查文档编号是否正确"
                }
            context["source_docs"] = valid_docs

        # One synthetic single-cell text field per requested name.
        fields = [
            {
                "name": field_name,
                "cell": f"A{position}",
                "field_type": "text",
                "required": False
            }
            for position, field_name in enumerate(target_fields, start=1)
        ]

        # Skip documents without an id, as the fill-table handler does.
        source_docs = context.get("source_docs") or []
        source_doc_ids = [doc.get("_id") for doc in source_docs if doc.get("_id")]

        result = await template_fill_service.fill_template(
            template_fields=fields,
            source_doc_ids=source_doc_ids or None,
            user_hint=f"请提取字段: {', '.join(target_fields)}"
        )

        return {
            "success": True,
            "intent": "extract",
            "extracted_data": result.get("filled_data", {}),
            "fields": target_fields,
            "message": f"成功提取 {len(result.get('filled_data', {}))} 个字段"
        }

    except Exception as e:
        logger.exception("提取执行失败: %s", e)
        return {
            "success": False,
            "error": str(e),
            "message": f"提取失败: {str(e)}"
        }
|
||||
|
||||
async def _execute_fill_table(self, params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
    """
    Fill a table template from the source documents.

    Requires ``context["template_file"]``; source documents, template fields
    and source file paths are also taken from *context*, while the user hint
    and template type come from *params*.
    """
    try:
        template_file = context.get("template_file")
        if not template_file:
            missing_template = {
                "success": False,
                "error": "未提供表格模板",
                "message": "请先上传要填写的表格模板"
            }
            return missing_template

        # IDs of the source documents (documents without an id are skipped).
        doc_ids = []
        for source in context.get("source_docs", []):
            if source.get("_id"):
                doc_ids.append(source.get("_id"))

        template_fields = context.get("template_fields", [])

        # Only a string template reference is passed through as template_id.
        template_id = None
        if isinstance(template_file, str):
            template_id = template_file

        result = await template_fill_service.fill_template(
            template_fields=template_fields,
            source_doc_ids=doc_ids if doc_ids else None,
            source_file_paths=context.get("source_file_paths"),
            user_hint=params.get("user_hint"),
            template_id=template_id,
            template_file_type=params.get("template", {}).get("type", "xlsx")
        )

        filled_count = len(result.get('filled_data', {}))
        return {
            "success": True,
            "intent": "fill_table",
            "result": result,
            "message": f"填表完成,成功填写 {filled_count} 个字段"
        }

    except Exception as e:
        logger.error(f"填表执行失败: {e}")
        return {
            "success": False,
            "error": str(e),
            "message": f"填表失败: {str(e)}"
        }
|
||||
|
||||
async def _execute_summarize(self, params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
    """
    Collect content previews of the source documents for summarisation.

    At most five documents are considered. Each contributes its file name and
    a preview of up to 500 characters. Documents whose ``content`` or
    ``metadata`` is missing or ``None`` are handled individually instead of
    failing the whole request.

    Returns:
        A result dict with ``summaries``, or a failure dict when no documents
        are available.
    """
    max_docs = 5          # documents considered per request
    max_content = 5000    # characters of content examined per document
    preview_len = 500     # characters shown in the preview

    try:
        docs = context.get("source_docs") or []
        if not docs:
            return {
                "success": False,
                "error": "没有可用的文档",
                "message": "请先上传要总结的文档"
            }

        summaries = []
        for doc in docs[:max_docs]:
            # `or` also covers keys that are present but set to None.
            content = (doc.get("content") or "")[:max_content]
            if not content:
                continue
            metadata = doc.get("metadata") or {}
            if len(content) > preview_len:
                preview = content[:preview_len] + "..."
            else:
                preview = content
            summaries.append({
                "filename": metadata.get("original_filename", "未知"),
                "content_preview": preview
            })

        return {
            "success": True,
            "intent": "summarize",
            "summaries": summaries,
            "message": f"找到 {len(summaries)} 个文档可供参考"
        }

    except Exception as e:
        logger.exception("摘要执行失败: %s", e)
        return {
            "success": False,
            "error": str(e),
            "message": f"摘要生成失败: {str(e)}"
        }
|
||||
|
||||
async def _execute_question(self, params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
    """
    Gather context for answering the question in ``params["question"]``.

    Retrieval hits are fetched per source document that has an ``_id``. When
    they yield no text, the leading content of the first three documents is
    used instead. Only a preview of the gathered context is returned.
    """
    try:
        question = params.get("question", "")
        if not question:
            no_question = {
                "success": False,
                "error": "未提供问题",
                "message": "请输入要回答的问题"
            }
            return no_question

        docs = context.get("source_docs", [])

        # Retrieve related chunks for every document that has an id.
        hits = []
        for source in docs:
            source_id = source.get("_id", "")
            if not source_id:
                continue
            hits.extend(rag_service.retrieve_by_doc_id(source_id, top_k=3))

        # Prefer retrieved chunks; fall back to raw document content.
        context_text = ""
        if hits:
            context_text = "\n\n".join(hit.get("content", "") for hit in hits[:5])
        if not context_text:
            fallback_parts = [
                source.get("content", "")[:3000]
                for source in docs[:3]
                if source.get("content")
            ]
            context_text = "\n\n".join(fallback_parts)

        if len(context_text) > 500:
            preview = context_text[:500] + "..."
        else:
            preview = context_text

        return {
            "success": True,
            "intent": "question",
            "question": question,
            "context_preview": preview,
            "message": "已找到相关上下文,可进行问答"
        }

    except Exception as e:
        logger.error(f"问答执行失败: {e}")
        return {
            "success": False,
            "error": str(e),
            "message": f"问答处理失败: {str(e)}"
        }
|
||||
|
||||
async def _execute_search(self, params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
    """
    Search via the retrieval service.

    The query is the space-joined ``field_refs`` when present, otherwise
    ``params["question"]``. Up to ten hits are returned, each with its
    content cut to 200 characters.
    """
    try:
        field_refs = params.get("field_refs", [])
        if field_refs:
            query = " ".join(field_refs)
        else:
            query = params.get("question", "")

        if not query:
            no_query = {
                "success": False,
                "error": "未提供搜索关键词",
                "message": "请输入要搜索的关键词"
            }
            return no_query

        results = rag_service.retrieve(query, top_k=10, min_score=0.3)

        # Shape each hit into the trimmed public representation.
        hits = []
        for item in results[:10]:
            hits.append({
                "content": item.get("content", "")[:200],
                "score": item.get("score", 0),
                "doc_id": item.get("doc_id", "")
            })

        return {
            "success": True,
            "intent": "search",
            "query": query,
            "results": hits,
            "message": f"找到 {len(results)} 条相关结果"
        }

    except Exception as e:
        logger.error(f"搜索执行失败: {e}")
        return {
            "success": False,
            "error": str(e),
            "message": f"搜索失败: {str(e)}"
        }
|
||||
|
||||
async def _execute_compare(self, params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
    """
    Compare basic information of the source documents.

    Needs at least two documents; at most five are reported. For each one
    the file name, document type, content length and whether it contains
    tables are listed. ``None`` values for ``content``, ``metadata`` or
    ``structured_data`` are treated as empty instead of raising.
    """
    max_docs = 5  # documents included in the comparison

    try:
        docs = context.get("source_docs") or []
        if len(docs) < 2:
            return {
                "success": False,
                "error": "对比需要至少2个文档",
                "message": "请上传至少2个文档进行对比"
            }

        comparison = []
        for index, doc in enumerate(docs[:max_docs], start=1):
            # `or` also covers keys that are present but set to None.
            metadata = doc.get("metadata") or {}
            structured = doc.get("structured_data") or {}
            comparison.append({
                "index": index,
                "filename": metadata.get("original_filename", "未知"),
                "doc_type": doc.get("doc_type", "未知"),
                "content_length": len(doc.get("content") or ""),
                "has_tables": bool(structured.get("tables")),
            })

        return {
            "success": True,
            "intent": "compare",
            "comparison": comparison,
            "message": f"对比了 {len(comparison)} 个文档的基本信息"
        }

    except Exception as e:
        logger.exception("对比执行失败: %s", e)
        return {
            "success": False,
            "error": str(e),
            "message": f"对比分析失败: {str(e)}"
        }
|
||||
|
||||
async def _execute_edit(self, params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
    """
    Polish the first source document with the LLM.

    Only the first 8000 characters of the document are sent to the model.
    The result carries a ``truncated`` flag so callers can tell when the
    edited text does not cover the whole document.

    Returns:
        A result dict with ``edited_content``, ``original_filename`` and
        ``truncated`` on success, or a failure dict with ``error`` and
        ``message``.
    """
    max_chars = 8000  # characters of the original passed to the LLM

    try:
        docs = context.get("source_docs") or []
        if not docs:
            return {
                "success": False,
                "error": "没有可用的文档",
                "message": "请先上传要编辑的文档"
            }

        doc = docs[0]  # the first document is the one edited
        # `or` also covers keys that are present but set to None.
        content = doc.get("content") or ""
        original_filename = (doc.get("metadata") or {}).get("original_filename", "未知文档")

        if not content:
            return {
                "success": False,
                "error": "文档内容为空",
                "message": "该文档没有可编辑的内容"
            }

        excerpt = content[:max_chars]
        prompt = f"""请对以下文档内容进行编辑处理。

原文内容:
{excerpt}

编辑要求:
- 润色表述,使其更加专业流畅
- 修正明显的语法错误
- 保持原意不变
- 只返回编辑后的内容,不要解释

请直接输出编辑后的内容:"""

        messages = [
            {"role": "system", "content": "你是一个专业的文本编辑助手。请直接输出编辑后的内容。"},
            {"role": "user", "content": prompt}
        ]

        # Imported at call time, so the LLM service is only loaded when an
        # edit is actually requested.
        from app.services.llm_service import llm_service
        response = await llm_service.chat(messages=messages, temperature=0.3, max_tokens=8000)
        edited_content = llm_service.extract_message_content(response)

        return {
            "success": True,
            "intent": "edit",
            "edited_content": edited_content,
            "original_filename": original_filename,
            "truncated": len(content) > max_chars,
            "message": "文档编辑完成,内容已返回"
        }

    except Exception as e:
        logger.exception("编辑执行失败: %s", e)
        return {
            "success": False,
            "error": str(e),
            "message": f"编辑处理失败: {str(e)}"
        }
|
||||
|
||||
async def _execute_transform(self, params: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
    """
    Convert the first source document to another format.

    Supported targets are Excel (``xlsx``), Word (``docx``) and Markdown
    (``md``). The target comes from ``params["template"]["type"]`` and is
    matched case-insensitively, with "excel", "word", "markdown" and a
    leading dot accepted. If it is absent, it is inferred from keywords in
    ``params["instruction"]``.

    When the document has structured tables they drive the output: the first
    table (up to 100 rows) for Excel, or up to three tables of 20 rows each
    rendered as pipe-separated text for Word / Markdown. Otherwise the plain
    text content is converted.

    Returns:
        A result dict describing the converted data, or a failure dict with
        ``error`` and ``message``.
    """
    try:
        docs = context.get("source_docs") or []
        if not docs:
            return {
                "success": False,
                "error": "没有可用的文档",
                "message": "请先上传要转换的文档"
            }

        # Resolve the target format.
        aliases = {"excel": "xlsx", "word": "docx", "markdown": "md"}
        template_info = params.get("template") or {}
        target_type = str(template_info.get("type") or "").strip().lower().lstrip(".")

        if not target_type:
            # Infer from the instruction text; the first matching group wins
            # (Excel, then Word, then Markdown).
            instruction = (params.get("instruction") or "").lower()
            keyword_groups = (
                ("xlsx", ("excel", "xlsx")),
                ("docx", ("word", "docx")),
                ("md", ("markdown", "md")),
            )
            for fmt, keywords in keyword_groups:
                if any(keyword in instruction for keyword in keywords):
                    target_type = fmt
                    break

        target_type = aliases.get(target_type, target_type)

        if not target_type:
            return {
                "success": False,
                "error": "未指定目标格式",
                "message": "请说明要转换成什么格式(如:转成Excel、转成Word)"
            }

        doc = docs[0]
        # `or` also covers keys that are present but set to None.
        content = doc.get("content") or ""
        tables = (doc.get("structured_data") or {}).get("tables") or []
        original_filename = (doc.get("metadata") or {}).get("original_filename", "未知文档")

        # Table-driven conversion.
        if tables:
            if target_type == "xlsx":
                # Only the first table is exported; take its headers
                # explicitly rather than relying on a leaked loop variable.
                first_table = tables[0]
                headers = first_table.get("headers") or []
                excel_data = []
                for row in (first_table.get("rows") or [])[:100]:
                    if isinstance(row, list):
                        excel_data.append(dict(zip(headers, row)))
                    elif isinstance(row, dict):
                        excel_data.append(row)

                return {
                    "success": True,
                    "intent": "transform",
                    "transform_type": "to_excel",
                    "target_format": "xlsx",
                    "excel_data": excel_data,
                    "headers": headers,
                    "message": f"已转换为 Excel 格式,包含 {len(excel_data)} 行数据"
                }

            if target_type in ("docx", "md"):
                # Render up to three tables (20 rows each) as pipe-separated
                # text; tables without headers are skipped.
                table_lines = []
                for number, table in enumerate(tables[:3], start=1):
                    headers = table.get("headers") or []
                    if not headers:
                        continue
                    table_lines.append(f"【表格 {number}】")
                    table_lines.append(" | ".join(str(h) for h in headers))
                    table_lines.append(" | ".join(["---"] * len(headers)))
                    for row in (table.get("rows") or [])[:20]:
                        if isinstance(row, list):
                            table_lines.append(" | ".join(str(c) for c in row))
                        elif isinstance(row, dict):
                            table_lines.append(" | ".join(str(row.get(h, "")) for h in headers))
                    table_lines.append("")

                rendered = f"# {original_filename}\n\n" + "\n".join(table_lines)
                if target_type == "docx":
                    return {
                        "success": True,
                        "intent": "transform",
                        "transform_type": "to_word",
                        "target_format": "docx",
                        "content": rendered,
                        "message": "已转换为 Word 格式"
                    }
                return {
                    "success": True,
                    "intent": "transform",
                    "transform_type": "to_markdown",
                    "target_format": "md",
                    "content": rendered,
                    "message": "已转换为 Markdown 格式"
                }

        # Plain-text conversion (no tables available).
        if target_type == "xlsx":
            # One row per non-empty line, capped at 100 rows.
            lines = [line.strip() for line in content.split("\n") if line.strip()][:100]
            excel_data = [{"行号": number, "内容": line} for number, line in enumerate(lines, start=1)]

            return {
                "success": True,
                "intent": "transform",
                "transform_type": "to_excel",
                "target_format": "xlsx",
                "excel_data": excel_data,
                "headers": ["行号", "内容"],
                "message": f"已将文本内容转换为 Excel,包含 {len(excel_data)} 行"
            }

        if target_type == "docx":
            return {
                "success": True,
                "intent": "transform",
                "transform_type": "to_word",
                "target_format": "docx",
                "content": content,
                "message": "文档内容已准备好,可下载为 Word 格式"
            }

        if target_type == "md":
            # Each line is stripped; blank lines are kept as paragraph breaks.
            markdown_text = "\n".join(line.strip() for line in content.split("\n"))
            return {
                "success": True,
                "intent": "transform",
                "transform_type": "to_markdown",
                "target_format": "md",
                "content": markdown_text,
                "message": "已转换为 Markdown 格式"
            }

        return {
            "success": False,
            "error": "不支持的目标格式",
            "message": f"暂不支持转换为 {target_type} 格式"
        }

    except Exception as e:
        logger.exception("格式转换失败: %s", e)
        return {
            "success": False,
            "error": str(e),
            "message": f"格式转换失败: {str(e)}"
        }
|
||||
|
||||
|
||||
class DefaultInstructionExecutor(InstructionExecutor):
    """Placeholder executor: its ``execute`` is not implemented."""

    async def execute(self, instruction: str, context: Dict[str, Any]) -> Dict[str, Any]:
        """Always raise ``NotImplementedError``; no instruction is executed.

        Args:
            instruction: The instruction text (unused).
            context: Execution context (unused).

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError("指令执行功能暂未实现")
|
||||
# 全局单例
|
||||
instruction_executor = InstructionExecutor()
|
||||
|
||||
@@ -2,17 +2,51 @@
|
||||
意图解析器模块
|
||||
|
||||
解析用户自然语言指令,识别意图和参数
|
||||
|
||||
注意: 此模块为可选功能,当前尚未实现。
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Tuple
|
||||
import re
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IntentParser(ABC):
|
||||
"""意图解析器抽象基类"""
|
||||
class IntentParser:
|
||||
"""意图解析器"""
|
||||
|
||||
# 意图类型定义
|
||||
INTENT_EXTRACT = "extract" # 信息提取
|
||||
INTENT_FILL_TABLE = "fill_table" # 填表
|
||||
INTENT_SUMMARIZE = "summarize" # 摘要总结
|
||||
INTENT_QUESTION = "question" # 问答
|
||||
INTENT_SEARCH = "search" # 搜索
|
||||
INTENT_COMPARE = "compare" # 对比分析
|
||||
INTENT_TRANSFORM = "transform" # 格式转换
|
||||
INTENT_EDIT = "edit" # 编辑文档
|
||||
INTENT_UNKNOWN = "unknown" # 未知
|
||||
|
||||
# 意图关键词映射
|
||||
INTENT_KEYWORDS = {
|
||||
INTENT_EXTRACT: ["提取", "抽取", "获取", "找出", "查找", "识别", "找到"],
|
||||
INTENT_FILL_TABLE: ["填表", "填写", "填充", "录入", "导入到表格", "填写到"],
|
||||
INTENT_SUMMARIZE: ["总结", "摘要", "概括", "概述", "归纳", "提炼"],
|
||||
INTENT_QUESTION: ["问答", "回答", "解释", "什么是", "为什么", "如何", "怎样", "多少", "几个"],
|
||||
INTENT_SEARCH: ["搜索", "查找", "检索", "查询", "找"],
|
||||
INTENT_COMPARE: ["对比", "比较", "差异", "区别", "不同"],
|
||||
INTENT_TRANSFORM: ["转换", "转化", "变成", "转为", "导出"],
|
||||
INTENT_EDIT: ["修改", "编辑", "调整", "改写", "润色", "优化"],
|
||||
}
|
||||
|
||||
# 实体模式定义
|
||||
ENTITY_PATTERNS = {
|
||||
"number": [r"\d+", r"[一二三四五六七八九十百千万]+"],
|
||||
"date": [r"\d{4}年", r"\d{1,2}月", r"\d{1,2}日"],
|
||||
"percentage": [r"\d+(\.\d+)?%", r"\d+(\.\d+)?‰"],
|
||||
"currency": [r"\d+(\.\d+)?万元", r"\d+(\.\d+)?亿元", r"\d+(\.\d+)?元"],
|
||||
}
|
||||
|
||||
def __init__(self):
    """Create a parser whose intent history starts out empty."""
    # History of parsed instructions, stored as dictionaries.
    self.intent_history: List[Dict[str, Any]] = list()
|
||||
|
||||
@abstractmethod
|
||||
async def parse(self, text: str) -> Tuple[str, Dict[str, Any]]:
|
||||
"""
|
||||
解析自然语言指令
|
||||
@@ -23,12 +57,186 @@ class IntentParser(ABC):
|
||||
Returns:
|
||||
(意图类型, 参数字典)
|
||||
"""
|
||||
pass
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return self.INTENT_UNKNOWN, {}
|
||||
|
||||
# 记录历史
|
||||
self.intent_history.append({"text": text, "intent": None})
|
||||
|
||||
# 识别意图
|
||||
intent = self._recognize_intent(text)
|
||||
|
||||
# 提取参数
|
||||
params = self._extract_params(text, intent)
|
||||
|
||||
# 更新历史
|
||||
if self.intent_history:
|
||||
self.intent_history[-1]["intent"] = intent
|
||||
|
||||
logger.info(f"意图解析: text={text[:50]}..., intent={intent}, params={params}")
|
||||
|
||||
return intent, params
|
||||
|
||||
def _recognize_intent(self, text: str) -> str:
    """Return the intent whose keywords occur most often in ``text``.

    Every keyword of ``INTENT_KEYWORDS`` that is a substring of ``text``
    adds one point to its intent. Intents without any hit are ignored.

    Args:
        text: The instruction text to classify.

    Returns:
        The best-scoring intent, or ``INTENT_UNKNOWN`` when no keyword
        matched. On a tie ``max`` keeps the first intent encountered, i.e.
        the one listed earliest in ``INTENT_KEYWORDS``.
    """
    hit_counts: Dict[str, float] = {}

    for name, words in self.INTENT_KEYWORDS.items():
        hits = sum(1 for word in words if word in text)
        if hits > 0:
            hit_counts[name] = hits

    if hit_counts:
        # Highest keyword count wins.
        return max(hit_counts, key=hit_counts.get)
    return self.INTENT_UNKNOWN
|
||||
|
||||
def _extract_params(self, text: str, intent: str) -> Dict[str, Any]:
    """Collect generic and intent-specific parameters from ``text``.

    Args:
        text: The instruction text.
        intent: The intent already recognised for ``text``.

    Returns:
        A dict that always holds ``entities``, ``document_refs``,
        ``field_refs`` and ``template_refs``; plus ``question``/``focus``
        for question intents, ``template`` for fill-table intents and
        ``target_fields`` for extract intents.
    """
    params: Dict[str, Any] = {}
    params["entities"] = self._extract_entities(text)
    params["document_refs"] = self._extract_document_refs(text)
    params["field_refs"] = self._extract_field_refs(text)
    params["template_refs"] = self._extract_template_refs(text)

    # Only three intents carry extra parameters; all others get the base set.
    if intent == self.INTENT_QUESTION:
        params.update(question=text, focus=self._extract_question_focus(text))
    elif intent == self.INTENT_FILL_TABLE:
        params.update(template=self._extract_template_info(text))
    elif intent == self.INTENT_EXTRACT:
        params.update(target_fields=self._extract_target_fields(text))

    return params
|
||||
|
||||
def _extract_entities(self, text: str) -> Dict[str, List[str]]:
    """Extract entity strings from ``text`` using ``ENTITY_PATTERNS``.

    Args:
        text: The instruction text to scan.

    Returns:
        Mapping of entity type to the distinct matched substrings, in order
        of first appearance. Types without any match are omitted.
    """
    entities: Dict[str, List[str]] = {}

    for entity_type, patterns in self.ENTITY_PATTERNS.items():
        # ``re.findall`` returns the capture group (not the whole match) when
        # a pattern contains one, e.g. r"\d+(\.\d+)?%" would yield "" or
        # ".5". Use ``finditer`` + ``group(0)`` to always keep the full match.
        found = [
            match.group(0)
            for pattern in patterns
            for match in re.finditer(pattern, text)
            if match.group(0)
        ]
        if found:
            # ``dict.fromkeys`` de-duplicates while keeping a stable order.
            entities[entity_type] = list(dict.fromkeys(found))

    return entities
|
||||
|
||||
def _extract_document_refs(self, text: str) -> List[str]:
    """Extract references to documents mentioned in ``text``.

    Recognises numbered references such as "文档1", "doc1", "第1个文档" and
    "第1份" (matched case-insensitively), and the keywords "所有" / "全部" /
    "整个" meaning every document.

    Args:
        text: The instruction text.

    Returns:
        ``"doc_<n>"`` for each numbered reference, followed by
        ``"all_docs"`` when one of the "all" keywords is present.
    """
    lowered = text.lower()

    # "(?:文档|doc)" must be an alternation of whole words. A character
    # class such as "[文档doc]+" would also accept any single one of those
    # characters (e.g. the "c" in "abc1").
    index_patterns = (
        r"(?:文档|doc)(\d+)",
        r"第(\d+)个文档",
        r"第(\d+)份",
    )
    refs = [
        f"doc_{number}"
        for pattern in index_patterns
        for number in re.findall(pattern, lowered)
    ]

    if any(kw in text for kw in ("所有", "全部", "整个")):
        refs.append("all_docs")

    return refs
|
||||
|
||||
def _extract_field_refs(self, text: str) -> List[str]:
    """Extract field names referenced in ``text``.

    Two sources are combined: text enclosed in quotes (``'``, ``"``, 『』,
    「」) and the non-whitespace run immediately preceding "字段", "列" or
    "数据".

    Args:
        text: The instruction text.

    Returns:
        Distinct field names in order of first discovery (quoted names
        first), so the result is stable between runs.
    """
    # Field names written inside quotes.
    candidates = re.findall(r"['\"『「]([^'\"』」]+)['\"』」]", text)

    # "xxx字段", "xxx列", "xxx数据": greedy, so everything back to the
    # previous whitespace is captured.
    for pattern in (r"([^\s]+)字段", r"([^\s]+)列", r"([^\s]+)数据"):
        candidates.extend(re.findall(pattern, text))

    # ``list(set(...))`` would reorder results differently on every run
    # because of string hash randomisation; keep first-seen order instead.
    return list(dict.fromkeys(candidates))
|
||||
|
||||
def _extract_template_refs(self, text: str) -> List[str]:
    """Extract template references from ``text``.

    Matches a non-whitespace run ending in "模板", the digits following
    "表", and a non-whitespace run ending in "表格".

    Args:
        text: The instruction text.

    Returns:
        Distinct references in order of first discovery, so the result is
        stable between runs.
    """
    template_patterns = (
        r"([^\s]+模板)",
        # NOTE(review): only the digits are captured, so "表1" yields "1";
        # confirm this is what callers expect.
        r"表(\d+)",
        r"([^\s]+表格)",
    )
    candidates = [
        hit
        for pattern in template_patterns
        for hit in re.findall(pattern, text)
    ]

    # ``list(set(...))`` has a run-dependent order; keep first-seen order.
    return list(dict.fromkeys(candidates))
|
||||
|
||||
def _extract_question_focus(self, text: str) -> Optional[str]:
    """Extract the subject a question is asking about.

    Supported phrasings, tried in this order: "什么是XXX", "XXX是什么" and
    "XXX有多少". The focus never includes a question mark (ASCII or
    full-width).

    Args:
        text: The question text.

    Returns:
        The stripped focus, or ``None`` when no phrasing matches or the
        captured focus is blank.
    """
    # The literal phrase "什么是" must be matched as a whole. Written as the
    # character class "[什么是]" it would match a single one of those
    # characters and leak the rest of the phrase into the focus
    # (e.g. "么是RAG").
    focus_patterns = (
        r"什么是([^?\uff1f]+)",
        r"([^?\uff1f]+?)是什么",
        r"([^?\uff1f]+)有多少",
    )

    for pattern in focus_patterns:
        match = re.search(pattern, text)
        if match:
            focus = match.group(1).strip()
            if focus:
                return focus

    return None
|
||||
|
||||
def _extract_template_info(self, text: str) -> Optional[Dict[str, str]]:
    """Guess the template file type mentioned in ``text``.

    Args:
        text: The instruction text.

    Returns:
        ``{"type": "xlsx"}`` when "excel", "xlsx" (case-insensitive) or
        "电子表格" appears; otherwise ``{"type": "docx"}`` when "word",
        "docx" (case-insensitive) or "文档" appears; otherwise ``None``.
    """
    lowered = text.lower()

    # Spreadsheet hints take precedence over document hints.
    if "excel" in lowered or "xlsx" in lowered or "电子表格" in text:
        return {"type": "xlsx"}
    if "word" in lowered or "docx" in lowered or "文档" in text:
        return {"type": "docx"}
    return None
|
||||
|
||||
def _extract_target_fields(self, text: str) -> List[str]:
    """Extract the field named right after "提取" / "抽取".

    For each occurrence of either verb, the shortest text between the verb
    and the next delimiter (和, 与, 、, ",", full-width comma or "plus") is
    taken as a field name. A verb that is not followed by a delimiter
    contributes nothing.

    Args:
        text: The instruction text.

    Returns:
        Distinct, stripped field names in order of appearance.
    """
    # The capture group accepts any character except a line break or a
    # list separator. A class like "[^(and|,)+]" would not mean "not the
    # word 'and'"; it would forbid the individual letters a/n/d,
    # parentheses, "|" and "+", making Latin field names unmatchable.
    # NOTE(review): only the field before the first delimiter is captured;
    # later items of a list ("提取A和B" -> B) are not returned.
    pattern = r"(?:提取|抽取)([^\n,\uff0c、]+?)(?:和|与|、|,|\uff0c|plus)"

    fields = [hit.strip() for hit in re.findall(pattern, text) if hit.strip()]

    # Keep first-seen order; ``list(set(...))`` order changes between runs.
    return list(dict.fromkeys(fields))
|
||||
|
||||
def get_intent_history(self) -> List[Dict[str, Any]]:
    """Return the stored intent history list itself (not a copy)."""
    history = self.intent_history
    return history
|
||||
|
||||
def clear_history(self):
    """Forget all recorded intents by rebinding the history to a new list.

    The previous list object is left untouched, so references obtained
    earlier keep their contents.
    """
    self.intent_history = list()
|
||||
|
||||
|
||||
class DefaultIntentParser(IntentParser):
    """Placeholder parser: its ``parse`` is not implemented."""

    async def parse(self, text: str) -> Tuple[str, Dict[str, Any]]:
        """Always raise ``NotImplementedError``; ``text`` is not inspected.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError("意图解析功能暂未实现")
|
||||
# 全局单例
|
||||
intent_parser = IntentParser()
|
||||
|
||||
446
backend/app/services/multi_doc_reasoning_service.py
Normal file
446
backend/app/services/multi_doc_reasoning_service.py
Normal file
@@ -0,0 +1,446 @@
|
||||
"""
|
||||
多文档关联推理服务
|
||||
|
||||
跨文档信息关联和推理
|
||||
"""
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
from collections import defaultdict
|
||||
|
||||
from app.services.llm_service import llm_service
|
||||
from app.services.rag_service import rag_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MultiDocReasoningService:
|
||||
"""
|
||||
多文档关联推理服务
|
||||
|
||||
功能:
|
||||
1. 实体跨文档追踪 - 追踪同一实体在不同文档中的描述
|
||||
2. 关系抽取与推理 - 抽取实体间关系并进行推理
|
||||
3. 信息补全 - 根据多个文档的信息互补填充缺失数据
|
||||
4. 冲突检测 - 检测不同文档间的信息冲突
|
||||
"""
|
||||
|
||||
def __init__(self):
    """Bind the module-level ``llm_service`` as this service's LLM client."""
    # Kept on the instance so methods call ``self.llm`` rather than the global.
    self.llm = llm_service
|
||||
|
||||
async def analyze_cross_documents(
    self,
    documents: List[Dict[str, Any]],
    query: Optional[str] = None,
    entity_types: Optional[List[str]] = None
) -> Dict[str, Any]:
    """Run the full cross-document analysis pipeline.

    Steps: per-document entity extraction, entity alignment, relation
    extraction, knowledge-graph construction, missing-info completion and
    conflict detection.

    Args:
        documents: Documents to analyse.
        query: Optional query intent. It is accepted but not used by this
            method.
        entity_types: Entity types to track; defaults to
            ``["机构", "数量", "时间", "地点"]`` when empty or ``None``.

    Returns:
        On success a dict with ``success=True`` and the keys ``entities``,
        ``relations``, ``knowledge_graph``, ``completed_info``,
        ``conflicts`` and ``summary``. When ``documents`` is empty, or any
        step raises, ``{"success": False, "error": <message>}``.
    """
    if not documents:
        return {"success": False, "error": "没有可用的文档"}

    tracked_types = entity_types or ["机构", "数量", "时间", "地点"]

    try:
        # 1. entities found in each document
        per_doc_entities = await self._extract_entities_from_docs(documents, tracked_types)
        # 2. group mentions of the same entity across documents
        aligned = self._align_entities_across_docs(per_doc_entities)
        # 3. relations between entities
        relations = await self._extract_relations(documents)
        # 4. graph of entities and relations
        graph = self._build_knowledge_graph(aligned, relations)
        # 5. fill in entities that carry no value
        completed = await self._complete_missing_info(graph, documents)
        # 6. contradicting values between documents
        conflicts = self._detect_conflicts(aligned)
        summary = self._generate_summary(aligned, conflicts)
    except Exception as e:
        # Any failing step turns the whole analysis into an error result.
        logger.error(f"跨文档分析失败: {e}")
        return {"success": False, "error": str(e)}

    return {
        "success": True,
        "entities": aligned,
        "relations": relations,
        "knowledge_graph": graph,
        "completed_info": completed,
        "conflicts": conflicts,
        "summary": summary
    }
|
||||
|
||||
async def _extract_entities_from_docs(
    self,
    documents: List[Dict[str, Any]],
    entity_types: List[str]
) -> List[Dict[str, Any]]:
    """Ask the LLM to extract entities from every document.

    Each document is processed independently: a failure (LLM error,
    unparsable reply) is logged and that document is skipped, the others
    are still processed.

    Args:
        documents: Documents to scan. ``content`` is truncated to 8000
            characters; a missing or ``None`` ``content``/``metadata`` is
            treated as empty.
        entity_types: Entity type names to list in the prompt.

    Returns:
        One ``{"doc_id", "doc_name", "entities"}`` dict per document whose
        reply contained a JSON object.
    """
    # Imported once here instead of on every loop iteration.
    import json
    import re

    entities_per_doc = []

    for idx, doc in enumerate(documents):
        doc_id = doc.get("_id", f"doc_{idx}")
        # ``or ""`` guards against an explicit ``content: None``, which
        # would otherwise raise outside the try block and abort the whole
        # analysis.
        content = (doc.get("content") or "")[:8000]  # cap the prompt size
        doc_name = (doc.get("metadata") or {}).get("original_filename", f"文档{idx+1}")

        prompt = f"""从以下文档中提取指定的实体类型信息。

实体类型: {', '.join(entity_types)}

文档内容:
{content}

请按以下 JSON 格式输出(只需输出 JSON):
{{
"entities": [
{{"type": "机构", "name": "实体名称", "value": "相关数值(如有)", "context": "上下文描述"}},
...
]
}}

只提取在文档中明确提到的实体,不要推测。"""

        messages = [
            {"role": "system", "content": "你是一个实体提取专家。请严格按JSON格式输出。"},
            {"role": "user", "content": prompt}
        ]

        try:
            response = await self.llm.chat(messages=messages, temperature=0.1, max_tokens=3000)
            content_response = self.llm.extract_message_content(response)

            # Take the outermost {...} span so surrounding prose or code
            # fences in the reply are ignored.
            json_match = re.search(r'\{[\s\S]*\}', content_response.strip())
            if not json_match:
                # Previously this case was dropped without any trace.
                logger.warning(f"文档 {doc_id} 实体提取失败: 响应中未找到 JSON")
                continue

            result = json.loads(json_match.group())
            entities = result.get("entities", [])
            if not isinstance(entities, list):
                entities = []

            entities_per_doc.append({
                "doc_id": doc_id,
                "doc_name": doc_name,
                "entities": entities
            })
            logger.info(f"文档 {doc_id} 提取到 {len(entities)} 个实体")
        except Exception as e:
            # Best effort: one bad document must not fail the others.
            logger.warning(f"文档 {doc_id} 实体提取失败: {e}")

    return entities_per_doc
|
||||
|
||||
def _align_entities_across_docs(
    self,
    entities_per_doc: List[Dict[str, Any]]
) -> Dict[str, List[Dict[str, Any]]]:
    """Group mentions of the same entity collected from several documents.

    Mentions are grouped by their normalised name (see
    ``_normalize_entity_name``); entities without a name are ignored.

    Args:
        entities_per_doc: Per-document records carrying ``doc_id``,
            ``doc_name`` and an optional ``entities`` list.

    Returns:
        Normalised name -> list of mentions, restricted to names that have
        more than one mention. The count is per mention, so two mentions
        inside a single document also qualify.
    """
    grouped: Dict[str, List[Dict[str, Any]]] = defaultdict(list)

    for doc_data in entities_per_doc:
        doc_id = doc_data["doc_id"]
        doc_name = doc_data["doc_name"]

        for entity in doc_data.get("entities", []):
            raw_name = entity.get("name", "")
            if raw_name:
                # Normalisation (trim, drop bracketed parts) is the grouping key.
                grouped[self._normalize_entity_name(raw_name)].append({
                    "original_name": raw_name,
                    "type": entity.get("type", "未知"),
                    "value": entity.get("value", ""),
                    "context": entity.get("context", ""),
                    "source_doc": doc_name,
                    "source_doc_id": doc_id
                })

    # Keep only names that were mentioned more than once.
    shared = {}
    for key, mentions in grouped.items():
        if len(mentions) <= 1:
            continue
        shared[key] = mentions
        logger.info(f"实体对齐: {key} 在 {len(mentions)} 个文档中出现")

    return shared
|
||||
|
||||
def _normalize_entity_name(self, name: str) -> str:
    """Normalise an entity name so that variants compare equal.

    Applied in order: trim surrounding whitespace, remove bracketed
    segments, remove a leading ranking prefix such as "第1名", trim again.

    Args:
        name: The raw entity name.

    Returns:
        The normalised name (possibly empty).
    """
    name = name.strip()
    # Remove bracketed content. Both ASCII "()" and full-width "()"
    # (\uff08 / \uff09) are handled, since Chinese text commonly uses the
    # full-width form.
    name = re.sub(r'[(\uff08].*?[)\uff09]', '', name)
    # Remove a leading "第<digits>名/位/个" ranking prefix.
    name = re.sub(r'^第\d+[名位个]', '', name)
    return name.strip()
|
||||
|
||||
async def _extract_relations(
    self,
    documents: List[Dict[str, Any]]
) -> List[Dict[str, str]]:
    """Ask the LLM for relations between entities across all documents.

    The first 3000 characters of every document are concatenated under a
    "【name】" header, and at most 8000 characters of the result are sent.

    Args:
        documents: Documents to scan. A missing or ``None``
            ``content``/``metadata`` is treated as empty.

    Returns:
        The ``relations`` list from the LLM's JSON reply, or ``[]`` when
        the call fails, the reply has no JSON object, or ``relations`` is
        not a list.
    """
    import json
    import re

    relations: List[Dict[str, str]] = []

    # One labelled section per document. The fallback label is 1-based so
    # it matches the "文档N" names used for entity extraction.
    sections = []
    for i, doc in enumerate(documents):
        doc_name = (doc.get("metadata") or {}).get("original_filename", f"文档{i + 1}")
        doc_text = (doc.get("content") or "")[:3000]
        sections.append(f"【{doc_name}】\n{doc_text}")
    combined_content = "\n\n".join(sections)

    prompt = f"""从以下文档内容中抽取实体之间的关系。

文档内容:
{combined_content[:8000]}

请识别以下类型的关系:
- 包含关系 (A包含B)
- 隶属关系 (A隶属于B)
- 合作关系 (A与B合作)
- 对比关系 (A vs B)
- 时序关系 (A先于B发生)

请按以下 JSON 格式输出(只需输出 JSON):
{{
"relations": [
{{"entity1": "实体1", "entity2": "实体2", "relation": "关系类型", "description": "关系描述"}},
...
]
}}

如果没有找到明确的关系,返回空数组。"""

    messages = [
        {"role": "system", "content": "你是一个关系抽取专家。请严格按JSON格式输出。"},
        {"role": "user", "content": prompt}
    ]

    try:
        response = await self.llm.chat(messages=messages, temperature=0.1, max_tokens=3000)
        content_response = self.llm.extract_message_content(response)

        # Single braces: this is a plain raw string, not an f-string. The
        # doubled form "\{{...\}}" would demand a literal "{{ ... }}" in
        # the reply and therefore never match ordinary JSON.
        json_match = re.search(r'\{[\s\S]*\}', content_response.strip())
        if json_match:
            result = json.loads(json_match.group())
            extracted = result.get("relations", [])
            if isinstance(extracted, list):
                relations = extracted
            logger.info(f"抽取到 {len(relations)} 个关系")
    except Exception as e:
        # Best effort: relation extraction is optional for the analysis.
        logger.warning(f"关系抽取失败: {e}")

    return relations
|
||||
|
||||
def _build_knowledge_graph(
    self,
    aligned_entities: Dict[str, List[Dict[str, Any]]],
    relations: List[Dict[str, str]]
) -> Dict[str, Any]:
    """Assemble nodes and edges from aligned entities and relations.

    Args:
        aligned_entities: Normalised entity name -> list of mentions.
        relations: Relation dicts with ``entity1``, ``entity2`` and
            optionally ``relation`` / ``description``.

    Returns:
        ``{"nodes", "edges", "stats"}``. A node takes its type from the
        first mention and its value from the first mention that has a
        non-empty value. An edge is created only when both normalised
        endpoints are known entity names; edges reference entities by
        name, not by the node ``id``.
    """
    nodes = []
    known_names = set()

    for entity_name, appearances in aligned_entities.items():
        if not appearances:
            continue

        # First non-empty value among the mentions, "" when there is none.
        primary_value = next(
            (a.get("value", "") for a in appearances if a.get("value")),
            ""
        )
        nodes.append({
            "id": f"entity_{len(nodes)}",
            "name": entity_name,
            "type": appearances[0].get("type", "未知"),
            "value": primary_value,
            "occurrence_count": len(appearances),
            "sources": [a.get("source_doc", "") for a in appearances]
        })
        known_names.add(entity_name)

    edges = []
    for relation in relations:
        source = self._normalize_entity_name(relation.get("entity1", ""))
        target = self._normalize_entity_name(relation.get("entity2", ""))

        # Drop relations that point at entities without a node.
        if source not in known_names or target not in known_names:
            continue
        edges.append({
            "source": source,
            "target": target,
            "relation": relation.get("relation", "相关"),
            "description": relation.get("description", "")
        })

    return {
        "nodes": nodes,
        "edges": edges,
        "stats": {
            "entity_count": len(nodes),
            "relation_count": len(edges)
        }
    }
|
||||
|
||||
async def _complete_missing_info(
    self,
    knowledge_graph: Dict[str, Any],
    documents: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Look up supporting text for entities that have no value.

    Only nodes with an empty ``value`` and an ``occurrence_count`` above 1
    are considered. For each, ``rag_service.retrieve`` is queried and the
    top hit (if any) is recorded.

    Args:
        knowledge_graph: Graph dict; only its ``nodes`` list is read.
        documents: Accepted but not used by this method.

    Returns:
        One ``{"entity", "type", "source", "context", "confidence"}`` dict
        per node for which the retrieval returned something.
    """
    completed = []

    for node in knowledge_graph.get("nodes", []):
        # Skip nodes that already have a value or were mentioned only once.
        if node.get("value") or node.get("occurrence_count", 0) <= 1:
            continue

        hits = rag_service.retrieve(f"{node['name']} 数值 数据", top_k=3, min_score=0.3)
        if not hits:
            continue

        best = hits[0]
        completed.append({
            "entity": node["name"],
            "type": node.get("type", "未知"),
            "source": "rag_inference",
            "context": best.get("content", "")[:200],  # short excerpt only
            "confidence": best.get("score", 0)
        })

    return completed
|
||||
|
||||
def _detect_conflicts(
    self,
    aligned_entities: Dict[str, List[Dict[str, Any]]]
) -> List[Dict[str, Any]]:
    """Report entities whose value differs between source documents.

    For every entity with at least two mentions, non-empty values are
    collected per ``source_doc`` (a later mention from the same source
    replaces an earlier one; a missing source is filed under "未知来源").
    A conflict is reported when more than one source has a value and the
    values are not all identical.

    Args:
        aligned_entities: Normalised entity name -> list of mentions.

    Returns:
        A list of ``{"entity", "type", "details", "description"}`` dicts.
    """
    conflicts = []

    for entity_name, appearances in aligned_entities.items():
        if len(appearances) < 2:
            continue

        # source document -> value reported there (empty values skipped)
        values = {
            mention.get("source_doc", "未知来源"): mention.get("value", "")
            for mention in appearances
            if mention.get("value", "")
        }

        if len(values) > 1 and len(set(values.values())) > 1:
            conflicts.append({
                "entity": entity_name,
                "type": "value_conflict",
                "details": values,
                "description": f"实体 '{entity_name}' 在不同文档中有不同数值: {values}"
            })

    return conflicts
|
||||
|
||||
def _generate_summary(
    self,
    aligned_entities: Dict[str, List[Dict[str, Any]]],
    conflicts: List[Dict[str, Any]]
) -> str:
    """Build the one-line summary of an analysis run.

    Args:
        aligned_entities: Normalised entity name -> list of mentions.
        conflicts: Detected conflicts.

    Returns:
        "; "-joined sentences: the total number of mentions, the number of
        entities with more than one mention, and, only when there are
        any, the number of conflicts.
    """
    mention_groups = list(aligned_entities.values())
    # NOTE(review): this counts mentions, not distinct entities, although
    # the message calls them entities; confirm which is intended.
    total_entities = sum(map(len, mention_groups))
    multi_doc_entities = len([group for group in mention_groups if len(group) > 1])

    summary_parts = [
        f"跨文档分析完成:发现 {total_entities} 个实体",
        f"其中 {multi_doc_entities} 个实体在多个文档中被提及",
    ]
    if conflicts:
        summary_parts.append(f"检测到 {len(conflicts)} 个潜在冲突")

    return "; ".join(summary_parts)
|
||||
|
||||
async def answer_cross_doc_question(
    self,
    question: str,
    documents: List[Dict[str, Any]]
) -> Dict[str, Any]:
    """Answer a question using the cross-document analysis as context.

    The analysis is run first; up to two mentions per entity and up to
    five relations are rendered into a context text that the LLM answers
    from.

    Args:
        question: The user's question.
        documents: Documents to analyse.

    Returns:
        ``{"success": True, "question", "answer", "supporting_entities",
        "relations_count"}`` on success, or
        ``{"success": False, "error": <message>}`` when the LLM call fails.
    """
    analysis_result = await self.analyze_cross_documents(documents, query=question)
    entity_map = analysis_result.get("entities", {})
    relation_list = analysis_result.get("relations", [])

    context_parts = []

    # Entity lines: at most two mentions each.
    for entity_name, appearances in entity_map.items():
        snippets = [
            f"{mention.get('source_doc')}: {mention.get('context', '')}"
            for mention in appearances[:2]
        ]
        if snippets:
            context_parts.append(f"【{entity_name}】{' | '.join(snippets)}")

    # Relation lines: at most five.
    context_parts.extend(
        f"【关系】{relation.get('entity1')} {relation.get('relation')} {relation.get('entity2')}: {relation.get('description', '')}"
        for relation in relation_list[:5]
    )

    if context_parts:
        context_text = "\n\n".join(context_parts)
    else:
        context_text = "未找到相关实体和关系"

    prompt = f"""基于以下跨文档分析结果,回答用户问题。

问题: {question}

分析结果:
{context_text}

请直接回答问题,如果分析结果中没有相关信息,请说明"根据提供的文档无法回答该问题"。"""

    messages = [
        {"role": "system", "content": "你是一个基于文档的问答助手。请根据提供的信息回答问题。"},
        {"role": "user", "content": prompt}
    ]

    try:
        response = await self.llm.chat(messages=messages, temperature=0.2, max_tokens=2000)
        answer = self.llm.extract_message_content(response)

        return {
            "success": True,
            "question": question,
            "answer": answer,
            "supporting_entities": list(entity_map.keys())[:10],
            "relations_count": len(relation_list)
        }
    except Exception as e:
        logger.error(f"跨文档问答失败: {e}")
        return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
# 全局单例
|
||||
multi_doc_reasoning_service = MultiDocReasoningService()
|
||||
@@ -2,11 +2,15 @@
|
||||
RAG 服务模块 - 检索增强生成
|
||||
|
||||
使用 sentence-transformers + Faiss 实现向量检索
|
||||
支持 BM25 关键词检索 + 向量检索混合融合
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
from typing import Any, Dict, List, Optional
|
||||
import re
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from collections import Counter, defaultdict
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
@@ -32,6 +36,132 @@ class SimpleDocument:
|
||||
self.metadata = metadata
|
||||
|
||||
|
||||
class BM25:
    """In-memory BM25 keyword ranking.

    Scores documents against a query from term frequency, inverse document
    frequency and document length. Call ``fit`` to (re)build the index,
    then ``search`` or ``get_scores``.
    """

    def __init__(self, k1: float = 1.5, b: float = 0.75):
        """Create an empty index.

        Args:
            k1: Term-frequency saturation parameter.
            b: Document-length normalisation parameter.
        """
        self.k1 = k1  # term-frequency saturation parameter
        self.b = b  # document-length normalisation parameter
        self.documents: List[str] = []
        self.doc_ids: List[str] = []
        self.avg_doc_length = 0
        self.doc_freqs: Dict[str, int] = {}  # term -> number of documents containing it
        self.idf: Dict[str, float] = {}  # term -> IDF weight
        self.doc_lengths: List[int] = []
        self.doc_term_freqs: List[Dict[str, int]] = []  # per-document term counts

    def _tokenize(self, text: str) -> List[str]:
        """Split ``text`` into lower-cased tokens.

        A token is a run of CJK characters (U+4E00..U+9FFF) or a run of
        ASCII letters/digits; single-character tokens are dropped.

        NOTE(review): a whole run of consecutive CJK characters becomes ONE
        token, so a query only matches a document containing the identical
        run. Consider a real segmenter or character n-grams.
        """
        if not text:
            return []
        tokens = re.findall(r'[\u4e00-\u9fff]+|[a-zA-Z0-9]+', text.lower())
        return [t for t in tokens if len(t) > 1]

    def fit(self, documents: List[str], doc_ids: List[str]):
        """Build the index from scratch, discarding any previous one.

        Args:
            documents: Document texts.
            doc_ids: Document ids, parallel to ``documents``.
        """
        self.documents = documents
        self.doc_ids = doc_ids
        n = len(documents)

        # Reset every derived structure. ``idf`` in particular must not
        # keep terms from an earlier ``fit``.
        self.doc_freqs = defaultdict(int)
        self.idf = {}
        self.doc_lengths = []
        self.doc_term_freqs = []

        for doc in documents:
            tokens = self._tokenize(doc)
            self.doc_lengths.append(len(tokens))
            doc_tf = Counter(tokens)
            self.doc_term_freqs.append(doc_tf)

            # Document frequency counts each term once per document.
            for term in doc_tf:
                self.doc_freqs[term] += 1

        self.avg_doc_length = sum(self.doc_lengths) / n if n > 0 else 0

        # IDF = ln((n - df + 0.5) / (df + 0.5) + 1); the "+ 1" keeps the
        # weight positive even for terms present in most documents.
        for term, df in self.doc_freqs.items():
            self.idf[term] = math.log((n - df + 0.5) / (df + 0.5) + 1)

        logger.info(f"BM25 索引构建完成: {n} 个文档, {len(self.idf)} 个词项")

    def search(self, query: str, top_k: int = 10) -> List[Tuple[int, float]]:
        """Return the ``top_k`` best-scoring documents for ``query``.

        Args:
            query: Query text.
            top_k: Maximum number of results.

        Returns:
            ``[(document index, BM25 score), ...]`` sorted by descending
            score (documents scoring 0 are included). Empty when the index
            or the tokenised query is empty.
        """
        if not self.documents:
            return []

        query_tokens = self._tokenize(query)
        if not query_tokens:
            return []

        scores = [
            (idx, self._calculate_score(query_tokens, idx))
            for idx in range(len(self.documents))
        ]
        # Stable sort: equal scores keep their index order.
        scores.sort(key=lambda x: x[1], reverse=True)

        return scores[:top_k]

    def _calculate_score(self, query_tokens: List[str], doc_idx: int) -> float:
        """Compute the BM25 score of one document for the query tokens."""
        doc_tf = self.doc_term_freqs[doc_idx]
        doc_len = self.doc_lengths[doc_idx]
        # Guard against an index whose documents all tokenised to nothing.
        length_ratio = doc_len / self.avg_doc_length if self.avg_doc_length else 0.0
        score = 0.0

        for term in query_tokens:
            if term not in self.idf:
                continue

            tf = doc_tf.get(term, 0)
            idf = self.idf[term]

            # BM25 term contribution.
            numerator = tf * (self.k1 + 1)
            denominator = tf + self.k1 * (1 - self.b + self.b * length_ratio)

            score += idf * numerator / denominator

        return score

    def get_scores(self, query: str) -> List[float]:
        """Return the BM25 score of every indexed document, in index order.

        Returns ``[]`` for an empty index and all zeros when the query has
        no usable token.
        """
        if not self.documents:
            return []

        query_tokens = self._tokenize(query)
        if not query_tokens:
            return [0.0] * len(self.documents)

        return [self._calculate_score(query_tokens, idx) for idx in range(len(self.documents))]
|
||||
|
||||
|
||||
class RAGService:
|
||||
"""RAG 检索增强服务"""
|
||||
|
||||
@@ -47,12 +177,15 @@ class RAGService:
|
||||
self._dimension: int = 384 # 默认维度
|
||||
self._initialized = False
|
||||
self._persist_dir = settings.FAISS_INDEX_DIR
|
||||
# BM25 索引
|
||||
self.bm25: Optional[BM25] = None
|
||||
self._bm25_enabled = True # 始终启用 BM25
|
||||
# 检查是否可用
|
||||
self._disabled = not SENTENCE_TRANSFORMERS_AVAILABLE
|
||||
if self._disabled:
|
||||
logger.warning("RAG 服务已禁用(sentence-transformers 不可用),将使用关键词匹配作为后备")
|
||||
logger.warning("RAG 服务已禁用(sentence-transformers 不可用),将使用 BM25 关键词检索")
|
||||
else:
|
||||
logger.info("RAG 服务已启用")
|
||||
logger.info("RAG 服务已启用(向量检索 + BM25 混合检索)")
|
||||
|
||||
def _init_embeddings(self):
|
||||
"""初始化嵌入模型"""
|
||||
@@ -261,11 +394,25 @@ class RAGService:
|
||||
if not documents:
|
||||
return
|
||||
|
||||
# 总是将文档存储在内存中(用于关键词搜索后备)
|
||||
# 总是将文档存储在内存中(用于 BM25 和关键词搜索)
|
||||
for doc, did in zip(documents, doc_ids):
|
||||
self.documents.append({"id": did, "content": doc.page_content, "metadata": doc.metadata})
|
||||
self.doc_ids.append(did)
|
||||
|
||||
# 构建 BM25 索引
|
||||
if self._bm25_enabled and documents:
|
||||
bm25_texts = [doc.page_content for doc in documents]
|
||||
if self.bm25 is None:
|
||||
self.bm25 = BM25()
|
||||
self.bm25.fit(bm25_texts, doc_ids)
|
||||
else:
|
||||
# 增量添加:重新构建(BM25 不支持增量)
|
||||
all_texts = [d["content"] for d in self.documents]
|
||||
all_ids = self.doc_ids.copy()
|
||||
self.bm25 = BM25()
|
||||
self.bm25.fit(all_texts, all_ids)
|
||||
logger.debug(f"BM25 索引更新: {len(documents)} 个文档")
|
||||
|
||||
# 如果没有嵌入模型,跳过向量索引
|
||||
if self.embedding_model is None:
|
||||
logger.debug(f"文档跳过向量索引 (无嵌入模型): {len(documents)} 个文档")
|
||||
@@ -284,7 +431,7 @@ class RAGService:
|
||||
|
||||
def retrieve(self, query: str, top_k: int = 5, min_score: float = 0.3) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
根据查询检索相关文档块
|
||||
根据查询检索相关文档块(混合检索:向量 + BM25)
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
@@ -301,13 +448,38 @@ class RAGService:
|
||||
if not self._initialized:
|
||||
self._init_vector_store()
|
||||
|
||||
# 优先使用向量检索
|
||||
if self.index is not None and self.index.ntotal > 0 and self.embedding_model is not None:
|
||||
# 获取向量检索结果
|
||||
vector_results = self._vector_search(query, top_k * 2, min_score)
|
||||
|
||||
# 获取 BM25 检索结果
|
||||
bm25_results = self._bm25_search(query, top_k * 2)
|
||||
|
||||
# 混合融合
|
||||
hybrid_results = self._hybrid_fusion(vector_results, bm25_results, top_k)
|
||||
|
||||
if hybrid_results:
|
||||
logger.info(f"混合检索到 {len(hybrid_results)} 条相关文档块 (向量:{len(vector_results)}, BM25:{len(bm25_results)})")
|
||||
return hybrid_results
|
||||
|
||||
# 降级:只使用 BM25
|
||||
if bm25_results:
|
||||
logger.info(f"降级到 BM25 检索: {len(bm25_results)} 条")
|
||||
return bm25_results
|
||||
|
||||
# 降级:使用关键词搜索
|
||||
logger.info("降级到关键词搜索")
|
||||
return self._keyword_search(query, top_k)
|
||||
|
||||
def _vector_search(self, query: str, top_k: int, min_score: float) -> List[Dict[str, Any]]:
|
||||
"""向量检索"""
|
||||
if self.index is None or self.index.ntotal == 0 or self.embedding_model is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
query_embedding = self.embedding_model.encode([query], convert_to_numpy=True)
|
||||
query_embedding = self._normalize_vectors(query_embedding).astype('float32')
|
||||
|
||||
scores, indices = self.index.search(query_embedding, min(top_k, self.index.ntotal))
|
||||
scores, indices = self.index.search(query_embedding, min(top_k * 2, self.index.ntotal))
|
||||
|
||||
results = []
|
||||
for score, idx in zip(scores[0], indices[0]):
|
||||
@@ -321,18 +493,121 @@ class RAGService:
|
||||
"metadata": doc["metadata"],
|
||||
"score": float(score),
|
||||
"doc_id": doc["id"],
|
||||
"chunk_index": doc["metadata"].get("chunk_index", 0)
|
||||
"chunk_index": doc["metadata"].get("chunk_index", 0),
|
||||
"search_type": "vector"
|
||||
})
|
||||
|
||||
if results:
|
||||
logger.debug(f"向量检索到 {len(results)} 条相关文档块")
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.warning(f"向量检索失败,使用关键词搜索后备: {e}")
|
||||
logger.warning(f"向量检索失败: {e}")
|
||||
return []
|
||||
|
||||
# 后备:使用关键词搜索
|
||||
logger.debug("使用关键词搜索后备方案")
|
||||
return self._keyword_search(query, top_k)
|
||||
def _bm25_search(self, query: str, top_k: int) -> List[Dict[str, Any]]:
    """Retrieve documents by BM25 keyword score.

    Args:
        query: Query text.
        top_k: Maximum number of results.

    Returns:
        Result dicts (``content``, ``metadata``, ``score``, ``doc_id``,
        ``chunk_index``, ``search_type="bm25"``) for documents with a
        positive BM25 score, best first. ``score`` is min-max normalised
        over all documents into [0, 1]. Returns ``[]`` when there is no
        index, no document, or on any error.
    """
    if not self.bm25 or not self.documents:
        return []

    try:
        bm25_scores = self.bm25.get_scores(query)
        if not bm25_scores:
            return []

        max_score = max(bm25_scores)
        min_score_bm = min(bm25_scores)
        score_range = max_score - min_score_bm

        results = []
        for idx, score in enumerate(bm25_scores):
            if score <= 0:
                continue
            # When every score is identical (e.g. a single indexed
            # document) the range is 0; such hits are all "the best" and
            # get 1.0 rather than 0.
            if score_range > 0:
                normalized_score = (score - min_score_bm) / score_range
            else:
                normalized_score = 1.0
            doc = self.documents[idx]
            results.append({
                "content": doc["content"],
                "metadata": doc["metadata"],
                "score": float(normalized_score),
                "doc_id": doc["id"],
                "chunk_index": doc["metadata"].get("chunk_index", 0),
                "search_type": "bm25"
            })

        results.sort(key=lambda x: x["score"], reverse=True)
        return results[:top_k]

    except Exception as e:
        # Best effort: the caller falls back to other retrieval methods.
        logger.warning(f"BM25 检索失败: {e}")
        return []
|
||||
|
||||
def _hybrid_fusion(
    self,
    vector_results: List[Dict[str, Any]],
    bm25_results: List[Dict[str, Any]],
    top_k: int
) -> List[Dict[str, Any]]:
    """Fuse vector and BM25 results by weighted reciprocal rank.

    Weighted Reciprocal Rank Fusion (RRF):
        score = 0.6 * (1 / rank_vector) + 0.4 * (1 / rank_bm25)
    where ranks are 1-based positions in the respective input list and a
    document missing from a list contributes 0 for it.

    Args:
        vector_results: Vector-search results, best first.
        bm25_results: BM25 results, best first.
        top_k: Maximum number of results.

    Returns:
        Fused result dicts (``content``, ``metadata``, ``score``,
        ``doc_id``, ``chunk_index``, ``vector_score``, ``bm25_score``,
        ``search_type="hybrid"``), best first. ``vector_score`` is the raw
        vector score, or 0.5 for documents found by BM25 only.
    """
    if not vector_results and not bm25_results:
        return []

    weight_vector = 0.6
    weight_bm25 = 0.4

    # doc_id -> accumulated rank scores plus the payload to return.
    doc_scores: Dict[str, Dict[str, Any]] = {}
    # doc_id -> raw vector score of its first (best-ranked) occurrence.
    raw_vector_scores: Dict[str, float] = {}

    ranked_lists = (
        ("vector", weight_vector, vector_results),
        ("bm25", weight_bm25, bm25_results),
    )
    for channel, weight, ranked in ranked_lists:
        for rank, result in enumerate(ranked):
            doc_id = result["doc_id"]
            entry = doc_scores.setdefault(doc_id, {
                "vector": 0,
                "bm25": 0,
                "content": result["content"],
                "metadata": result["metadata"]
            })
            # If an id occurs more than once in a list, keep its best
            # (earliest) rank instead of letting a later, worse rank
            # overwrite it.
            entry[channel] = max(entry[channel], weight / (rank + 1))
            if channel == "vector":
                raw_vector_scores.setdefault(doc_id, result["score"])

    fused_results = []
    for doc_id, scores in doc_scores.items():
        metadata = scores["metadata"]
        fused_results.append({
            "content": scores["content"],
            "metadata": metadata,
            "score": scores["vector"] + scores["bm25"],
            "doc_id": doc_id,
            # Same key as the vector-only and BM25-only result shapes.
            "chunk_index": (metadata or {}).get("chunk_index", 0),
            "vector_score": raw_vector_scores.get(doc_id, 0.5),
            "bm25_score": scores["bm25"],
            "search_type": "hybrid"
        })

    fused_results.sort(key=lambda x: x["score"], reverse=True)

    logger.debug(f"混合融合: {len(fused_results)} 个文档, 向量:{len(vector_results)}, BM25:{len(bm25_results)}")

    return fused_results[:top_k]
|
||||
|
||||
def _keyword_search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
|
||||
@@ -13,6 +13,7 @@ from app.services.llm_service import llm_service
|
||||
from app.core.document_parser import ParserFactory
|
||||
from app.services.markdown_ai_service import markdown_ai_service
|
||||
from app.services.rag_service import rag_service
|
||||
from app.services.word_ai_service import word_ai_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -55,6 +56,249 @@ class FillResult:
|
||||
class TemplateFillService:
|
||||
"""表格填写服务"""
|
||||
|
||||
# 通用表头语义扩展字典
|
||||
GENERIC_HEADER_EXPANSION = {
|
||||
"机构": ["医院", "学校", "企业", "机关", "团体", "协会", "基金会", "研究所", "医院数量", "学校数量", "企业数量"],
|
||||
"名称": ["医院名称", "学校名称", "企业名称", "机构名称", "单位名称", "名称"],
|
||||
"类型": ["医院类型", "学校类型", "企业类型", "机构类型", "类型分类"],
|
||||
"数量": ["医院数量", "学校数量", "企业数量", "机构数量", "个数", "总数", "人员数量"],
|
||||
"金额": ["金额", "收入", "支出", "产值", "销售额", "利润", "税收"],
|
||||
"比率": ["增长率", "占比", "比重", "比率", "百分比", "使用率", "就业率"],
|
||||
"面积": ["占地面积", "建筑面积", "用地面积", "耕地面积", "绿化面积"],
|
||||
"人口": ["常住人口", "户籍人口", "流动人口", "城镇人口", "农村人口"],
|
||||
"价格": ["价格", "物价", "CPI", "涨幅", "指数"],
|
||||
"增长": ["增速", "增长率", "增幅", "增长", "上涨", "下降"],
|
||||
}
|
||||
|
||||
# 模板表头到源文档表头的映射缓存
|
||||
_header_mapping_cache: Dict[str, Dict[str, str]] = {}
|
||||
|
||||
def _analyze_source_table_structure(self, source_docs: List["SourceDocument"]) -> Dict[str, Any]:
    """
    Summarise every table found in the source documents.

    Each document's ``structured_data`` is checked for exactly one layout,
    in this order: a mapping under ``"sheets"``, a list under ``"tables"``
    (only the first 5 tables are used), or top-level ``"columns"`` plus
    ``"rows"``.

    Args:
        source_docs: Source documents; only ``structured_data`` is read.

    Returns:
        Dict keyed by ``doc{i}_{sheet_name}``, ``doc{i}_table{j}`` or
        ``doc{i}_default``. Every entry holds ``doc_idx``, ``columns``,
        ``sample_rows`` (at most 10 rows), ``column_count`` and
        ``row_count``, plus ``sheet_name`` or ``table_idx`` where relevant.
    """
    max_sample_rows = 10
    max_tables = 5

    def summarise(doc_idx, columns, all_rows, **location):
        # Slice before building the dict so a non-sliceable "rows" value
        # fails before anything is stored.
        sample = all_rows[:max_sample_rows]
        return {
            "doc_idx": doc_idx,
            **location,
            "columns": columns,
            "sample_rows": sample,
            "column_count": len(columns),
            "row_count": len(all_rows),
        }

    table_structures = {}

    for doc_idx, doc in enumerate(source_docs):
        structured = doc.structured_data or {}

        if structured.get("sheets"):
            # Multi-sheet layout: one entry per sheet that is a dict
            for sheet_name, sheet_data in structured.get("sheets", {}).items():
                if not isinstance(sheet_data, dict):
                    continue
                table_structures[f"doc{doc_idx}_{sheet_name}"] = summarise(
                    doc_idx,
                    sheet_data.get("columns", []),
                    sheet_data.get("rows", []),
                    sheet_name=sheet_name,
                )
        elif structured.get("tables"):
            # Table-list layout: headers live under "headers"
            candidates = structured.get("tables", [])[:max_tables]
            for table_idx, table in enumerate(candidates):
                if not isinstance(table, dict):
                    continue
                table_structures[f"doc{doc_idx}_table{table_idx}"] = summarise(
                    doc_idx,
                    table.get("headers", []),
                    table.get("rows", []),
                    table_idx=table_idx,
                )
        elif structured.get("columns") and structured.get("rows"):
            # Single-sheet layout stored at the top level
            table_structures[f"doc{doc_idx}_default"] = summarise(
                doc_idx,
                structured.get("columns", []),
                structured.get("rows", []),
            )

    logger.info(f"分析源文档表格结构: {len(table_structures)} 个表格")
    return table_structures
|
||||
|
||||
def _build_adaptive_header_mapping(
    self,
    template_fields: List["TemplateField"],
    source_table_structures: Dict[str, Any]
) -> Dict[str, Dict[str, Any]]:
    """
    Map every template field to its best-matching column in each source table.

    Per column the score is (highest wins, the first column wins ties):
      1. exact match, ignoring case and surrounding whitespace -> 1.0
      2. one name contains the other -> 0.8 + 0.2 * (shorter / longer),
         which stays in (0.8, 1.0): a containment match can never outrank
         an exact match, and the closer the lengths the higher the score
      3. otherwise, overlap of the whitespace-split keyword sets -> up to 0.6
      4. if the score is still below 0.5 and the field contains a generic
         header from ``GENERIC_HEADER_EXPANSION`` whose specific terms
         overlap the column name -> 0.7

    Blank column names and blank field names are skipped: the empty string
    is a substring of every string and would otherwise match everything.

    Args:
        template_fields: Template fields; only the ``name`` attribute is read.
        source_table_structures: Table summaries keyed by table key; only
            the ``"columns"`` entry of each summary is inspected.

    Returns:
        ``{field_name: {table_key: {"column_index", "column_name",
        "match_score", "table_info"}}}``. A table appears only when its best
        column scores at least 0.3; fields without any match are omitted.
    """
    mappings: Dict[str, Dict[str, Any]] = {}

    for field in template_fields:
        field_name = field.name
        field_lower = field_name.lower().strip()
        if not field_lower:
            continue
        field_keywords = set(field_lower.replace(" ", "").split())

        # Specific terms of every generic header contained in this field.
        # Lower-cased because they are compared against lower-cased headers
        # (the table contains entries such as "CPI").
        expansion_terms = [
            specific.lower()
            for generic, specifics in self.GENERIC_HEADER_EXPANSION.items()
            if generic in field_lower
            for specific in specifics
        ]

        best_matches = {}

        for table_key, table_info in source_table_structures.items():
            columns = table_info.get("columns", [])
            if not columns:
                continue

            best_col_idx = None
            best_col_name = None
            best_score = 0

            for col_idx, col in enumerate(columns):
                col_str = str(col).strip()
                col_lower = col_str.lower()
                if not col_lower:
                    continue

                if col_lower == field_lower:
                    score = 1.0
                elif field_lower in col_lower or col_lower in field_lower:
                    shorter, longer = sorted((len(field_lower), len(col_lower)))
                    score = 0.8 + 0.2 * shorter / longer
                else:
                    col_keywords = set(col_lower.replace(" ", "").split())
                    overlap = field_keywords & col_keywords
                    score = 0
                    if overlap:
                        score = 0.6 * len(overlap) / max(len(field_keywords), len(col_keywords), 1)

                # Generic-header expansion only rescues weak candidates
                if score < 0.5 and any(
                    term in col_lower or col_lower in term for term in expansion_terms
                ):
                    score = 0.7

                if score > best_score:
                    best_score = score
                    best_col_idx = col_idx
                    best_col_name = col_str

            if best_score >= 0.3 and best_col_idx is not None:
                best_matches[table_key] = {
                    "column_index": best_col_idx,
                    "column_name": best_col_name,
                    "match_score": best_score,
                    "table_info": table_info
                }

        if best_matches:
            mappings[field_name] = best_matches
            # Report the highest-scoring table, not merely the first one
            top_match = max(best_matches.values(), key=lambda match: match["match_score"])
            logger.info(f"字段 '{field_name}' 匹配到 {len(best_matches)} 个源表头,最佳匹配: {top_match.get('column_name')}")

    return mappings
|
||||
|
||||
def _extract_with_adaptive_mapping(
    self,
    source_docs: List["SourceDocument"],
    field_name: str,
    mapping: Dict[str, Dict[str, Any]]
) -> List[str]:
    """
    Collect the values of one field through a previously built header mapping.

    For every table matched to ``field_name`` the rows are re-read from the
    owning document's ``structured_data`` (sheet, table-list or top-level
    rows layout) and the mapped column is read from each row. Values pass
    through ``self._format_value`` and are kept only when non-empty and
    accepted by ``self._is_valid_data_value``.

    Args:
        source_docs: Source documents, indexed by the ``doc_idx`` stored in
            each match's ``table_info``.
        field_name: Field whose values are wanted.
        mapping: ``{field_name: {table_key: match_info}}``.

    Returns:
        Distinct values in first-seen order; empty when the field is unmapped.
    """
    if field_name not in mapping:
        return []

    def locate_rows(structured, table_info):
        # Only the first layout present in the document is consulted
        if structured.get("sheets"):
            sheet_name = table_info.get("sheet_name")
            if not sheet_name:
                return []
            return structured.get("sheets", {}).get(sheet_name, {}).get("rows", [])
        if structured.get("tables"):
            position = table_info.get("table_idx", 0)
            tables = structured.get("tables", [])
            return tables[position].get("rows", []) if position < len(tables) else []
        if structured.get("rows"):
            return structured.get("rows", [])
        return []

    collected = []

    for match_info in mapping[field_name].values():
        table_info = match_info.get("table_info", {})
        column_at = match_info.get("column_index", 0)
        owner = table_info.get("doc_idx", 0)

        if owner >= len(source_docs):
            continue

        structured = source_docs[owner].structured_data or {}

        for row in locate_rows(structured, table_info):
            if isinstance(row, list):
                if column_at >= len(row):
                    continue
                raw = row[column_at]
            elif isinstance(row, dict):
                # Dict rows are addressed by the header name at the mapped index
                headers = table_info.get("columns", [])
                if column_at >= len(headers):
                    continue
                raw = row.get(headers[column_at], "")
            else:
                continue

            text = self._format_value(raw)
            if text and self._is_valid_data_value(text):
                collected.append(text)

    # dict preserves insertion order, so this drops duplicates in place
    return list(dict.fromkeys(collected))
|
||||
|
||||
def __init__(self):
|
||||
self.llm = llm_service
|
||||
|
||||
@@ -305,6 +549,62 @@ class TemplateFillService:
|
||||
if source_file_paths:
|
||||
for file_path in source_file_paths:
|
||||
try:
|
||||
file_ext = file_path.lower().split('.')[-1]
|
||||
|
||||
# 对于 Word 文档,优先使用 AI 解析
|
||||
if file_ext == 'docx':
|
||||
# 使用 AI 深度解析 Word 文档
|
||||
ai_result = await word_ai_service.parse_word_with_ai(
|
||||
file_path=file_path,
|
||||
user_hint="请提取文档中的所有结构化数据,包括表格、键值对等"
|
||||
)
|
||||
|
||||
if ai_result.get("success"):
|
||||
# AI 解析成功,转换为 SourceDocument 格式
|
||||
parse_type = ai_result.get("type", "unknown")
|
||||
|
||||
# 构建 structured_data
|
||||
doc_structured = {
|
||||
"ai_parsed": True,
|
||||
"parse_type": parse_type,
|
||||
"tables": [],
|
||||
"key_values": ai_result.get("key_values", {}) if "key_values" in ai_result else {},
|
||||
"list_items": ai_result.get("list_items", []) if "list_items" in ai_result else [],
|
||||
"summary": ai_result.get("summary", "") if "summary" in ai_result else ""
|
||||
}
|
||||
|
||||
# 如果 AI 返回了表格数据
|
||||
if parse_type == "table_data":
|
||||
headers = ai_result.get("headers", [])
|
||||
rows = ai_result.get("rows", [])
|
||||
if headers and rows:
|
||||
doc_structured["tables"] = [{
|
||||
"headers": headers,
|
||||
"rows": rows
|
||||
}]
|
||||
doc_structured["columns"] = headers
|
||||
doc_structured["rows"] = rows
|
||||
logger.info(f"AI 表格数据: {len(headers)} 列, {len(rows)} 行")
|
||||
elif parse_type == "structured_text":
|
||||
tables = ai_result.get("tables", [])
|
||||
if tables:
|
||||
doc_structured["tables"] = tables
|
||||
logger.info(f"AI 结构化文本提取到 {len(tables)} 个表格")
|
||||
|
||||
# 获取摘要内容
|
||||
content_text = doc_structured.get("summary", "") or ai_result.get("description", "")
|
||||
|
||||
source_docs.append(SourceDocument(
|
||||
doc_id=file_path,
|
||||
filename=file_path.split("/")[-1] if "/" in file_path else file_path.split("\\")[-1],
|
||||
doc_type="docx",
|
||||
content=content_text,
|
||||
structured_data=doc_structured
|
||||
))
|
||||
logger.info(f"AI 解析 Word 文档: {file_path}, type={parse_type}, tables={len(doc_structured.get('tables', []))}")
|
||||
continue # 跳后续的基础解析
|
||||
|
||||
# 基础解析(Excel 或非 AI 解析的 Word)
|
||||
parser = ParserFactory.get_parser(file_path)
|
||||
result = parser.parse(file_path)
|
||||
if result.success:
|
||||
@@ -1351,6 +1651,36 @@ class TemplateFillService:
|
||||
if all_values:
|
||||
break
|
||||
|
||||
# 处理 AI 解析的 Word 文档键值对格式: {key_values: {"键": "值"}, ...}
|
||||
if structured.get("key_values") and isinstance(structured.get("key_values"), dict):
|
||||
key_values = structured.get("key_values", {})
|
||||
logger.info(f" 检测到 AI 解析键值对格式,共 {len(key_values)} 个键值对")
|
||||
values = self._extract_from_key_values(key_values, field_name)
|
||||
if values:
|
||||
all_values.extend(values)
|
||||
logger.info(f"从 Word AI 键值对提取到 {len(values)} 个值: {values}")
|
||||
break
|
||||
|
||||
# 处理 AI 解析的 list_items 格式
|
||||
if structured.get("list_items") and isinstance(structured.get("list_items"), list):
|
||||
list_items = structured.get("list_items", [])
|
||||
logger.info(f" 检测到 AI 解析列表格式,共 {len(list_items)} 个列表项")
|
||||
values = self._extract_from_list_items(list_items, field_name)
|
||||
if values:
|
||||
all_values.extend(values)
|
||||
logger.info(f"从 Word AI 列表提取到 {len(values)} 个值")
|
||||
break
|
||||
|
||||
# 如果从结构化数据中没有提取到值,且字段是通用表头,搜索文本内容
|
||||
if not all_values and field_name in self.GENERIC_HEADER_EXPANSION:
|
||||
for doc in source_docs:
|
||||
if doc.content:
|
||||
text_values = self._search_generic_header_in_text(doc.content, field_name)
|
||||
if text_values:
|
||||
all_values.extend(text_values)
|
||||
logger.info(f"从文本内容通过通用表头匹配提取到 {len(text_values)} 个值")
|
||||
break
|
||||
|
||||
return all_values
|
||||
|
||||
def _extract_values_from_markdown_table(self, headers: List, rows: List, field_name: str) -> List[str]:
|
||||
@@ -1376,10 +1706,27 @@ class TemplateFillService:
|
||||
# 查找匹配的列索引 - 使用增强的匹配算法
|
||||
target_idx = self._find_best_matching_column(headers, field_name)
|
||||
|
||||
if target_idx is None:
|
||||
# 如果没有找到列匹配,尝试在第一列中搜索字段名(适用于指标在行的文档)
|
||||
matched_row_idx = None
|
||||
if target_idx is None and rows:
|
||||
matched_row_idx = self._search_row_in_first_column(rows, field_name)
|
||||
if matched_row_idx is not None:
|
||||
logger.info(f"在第一列找到匹配: {field_name} -> 行索引 {matched_row_idx} (转置表格结构)")
|
||||
|
||||
if target_idx is None and matched_row_idx is None:
|
||||
logger.warning(f"未找到匹配列: {field_name}, 可用表头: {headers}")
|
||||
return []
|
||||
|
||||
# 如果在第一列找到匹配(转置表格),提取该行的其他列作为值
|
||||
if matched_row_idx is not None:
|
||||
matched_row = rows[matched_row_idx]
|
||||
if isinstance(matched_row, list):
|
||||
# 跳过第一列(指标名),提取后续列的值
|
||||
for val in matched_row[1:]:
|
||||
values.append(self._format_value(val))
|
||||
logger.info(f"转置表格提取到 {len(values)} 个值: {values[:5]}...")
|
||||
return self._filter_valid_values(values)
|
||||
|
||||
logger.info(f"列匹配成功: {field_name} -> {headers[target_idx]} (索引: {target_idx})")
|
||||
|
||||
values = []
|
||||
@@ -1527,6 +1874,149 @@ class TemplateFillService:
|
||||
valid_values.append(val)
|
||||
return valid_values
|
||||
|
||||
def _extract_from_key_values(self, key_values: Dict[str, str], field_name: str) -> List[str]:
    """
    Pick the value whose key best matches ``field_name``.

    An exact key match (case-insensitive) returns immediately. Otherwise
    every key is scored and the best-scoring key wins (the earliest key on
    ties), provided its score is at least 0.2:

      * containment (one name inside the other): ``1.0 + shorter / longer``,
        i.e. in (1.0, 2.0). It outranks the other strategies, as before, and
        among containment matches the key closest in length wins.
      * keyword-set overlap after removing spaces: at most 1.0.
      * character-set overlap: at most 1.0. Accepted when the field has at
        most 4 distinct characters and the ratio is >= 0.5, or when at least
        2 characters are shared.

    Args:
        key_values: Mapping such as ``{"医院数量": "38710个"}``. Entries with
            an empty key or a falsy value are ignored.
        field_name: Field name to look up. A blank name yields no match,
            since the empty string is contained in every key.

    Returns:
        A single-element list with the matched value as ``str``, or ``[]``.
    """
    if not key_values:
        return []

    field_lower = field_name.lower().strip()
    if not field_lower:
        return []

    compact_field = field_lower.replace(" ", "")
    field_chars = set(compact_field)
    field_keywords = set(compact_field.split())

    best_match = None
    best_score = 0

    for key, value in key_values.items():
        key_str = str(key).strip()
        if not key_str or not value:
            continue

        key_lower = key_str.lower()

        # Strategy 1: exact match (case-insensitive)
        if key_lower == field_lower:
            logger.info(f"键值对精确匹配: {field_name} -> {key_str}: {value}")
            return [str(value)]

        candidate_scores = []

        # Strategy 2: containment, favouring keys of similar length
        if field_lower in key_lower or key_lower in field_lower:
            shorter, longer = sorted((len(field_lower), len(key_lower)))
            candidate_scores.append(1.0 + shorter / longer)

        # Strategy 3: keyword overlap
        compact_key = key_lower.replace(" ", "")
        key_keywords = set(compact_key.split())
        overlap = field_keywords & key_keywords
        if overlap:
            candidate_scores.append(len(overlap) / max(len(field_keywords), len(key_keywords), 1))

        # Strategy 4: character overlap (useful for short Chinese names)
        key_chars = set(compact_key)
        char_overlap = field_chars & key_chars
        if char_overlap:
            char_score = len(char_overlap) / max(len(field_chars), len(key_chars), 1)
            short_field_hit = len(field_chars) <= 4 and char_score >= 0.5
            if short_field_hit or len(char_overlap) >= 2:
                candidate_scores.append(char_score)

        key_score = max(candidate_scores, default=0)
        if key_score > best_score:
            best_score = key_score
            best_match = (key_str, value)

    # Low threshold on purpose: allow loose fuzzy matches
    if best_score >= 0.2 and best_match:
        logger.info(f"键值对模糊匹配: {field_name} -> {best_match[0]}: {best_match[1]} (分数: {best_score:.2f})")
        return [str(best_match[1])]

    logger.warning(f"键值对未匹配到: {field_name}, 可用键: {list(key_values.keys())}")
    return []
|
||||
|
||||
def _extract_from_list_items(self, list_items: List[str], field_name: str) -> List[str]:
    """
    Extract the values of ``field_name`` from free-form list items.

    Each item is tried against these strategies, in order:

      1. ``key: value`` items (ASCII ``:`` or full-width ``\uff1a``):
         an exact key match returns ``[value]`` immediately. A key that
         contains the field, is contained in it, or shares a keyword with it
         contributes its value once, and the item is then finished, so the
         raw item text is not appended as a second "value".
      2. The whole item contains the field name (or vice versa): the item
         text itself is kept.
      3. At least two shared whitespace-separated keywords: the item text
         is kept.

    Args:
        list_items: Items such as ``["医院数量: 38710个", "床位总数: 456789张"]``.
        field_name: Field name to look for. A blank name yields ``[]``,
            because the empty string is contained in every item.

    Returns:
        Matched values in item order; ``[]`` when nothing matches.
    """
    if not list_items:
        return []

    field_lower = field_name.lower().strip()
    if not field_lower:
        return []
    field_keywords = set(field_lower.replace(" ", "").split())

    matched_values = []

    for item in list_items:
        item_str = str(item).strip()
        if not item_str:
            continue

        item_lower = item_str.lower()

        # Strategy 1: "key: value" items; normalise the full-width colon first
        key, separator, value = item_str.replace("\uff1a", ":").partition(":")
        if separator:
            key = key.strip()
            value = value.strip()
            key_lower = key.lower()

            if key_lower == field_lower:
                logger.info(f"列表项键值精确匹配: {field_name} -> {value}")
                return [value]

            if key_lower:
                contained = field_lower in key_lower or key_lower in field_lower
                key_keywords = set(key_lower.replace(" ", "").split())
                if contained or (field_keywords & key_keywords):
                    if contained:
                        shorter, longer = sorted((len(field_lower), len(key_lower)))
                        score = shorter / longer
                        logger.info(f"列表项键值模糊匹配: {field_name} -> {key}: {value} (分数: {score:.2f})")
                    if value:
                        matched_values.append(value)
                    # The key matched: do not also record the whole item below
                    continue

        # Strategy 2: the field name appears in the item text (or vice versa)
        if field_lower in item_lower or item_lower in field_lower:
            matched_values.append(item_str)
            continue

        # Strategy 3: keyword overlap, requiring at least two shared keywords
        item_keywords = set(item_lower.replace(" ", "").split())
        overlap = field_keywords & item_keywords
        if overlap and len(overlap) >= 2:
            score = len(overlap) / max(len(field_keywords), len(item_keywords), 1)
            if score >= 0.2:
                matched_values.append(item_str)

    if matched_values:
        logger.info(f"列表项匹配到 {len(matched_values)} 个: {matched_values[:5]}")

    return matched_values
|
||||
|
||||
def _find_best_matching_column(self, headers: List, field_name: str) -> Optional[int]:
|
||||
"""
|
||||
查找最佳匹配的列索引
|
||||
@@ -1535,6 +2025,7 @@ class TemplateFillService:
|
||||
1. 精确匹配(忽略大小写)
|
||||
2. 子字符串匹配(字段名在表头中,或表头在字段名中)
|
||||
3. 关键词重叠匹配(中文字符串分割后比对)
|
||||
4. 字符级包含匹配(适用于中文短字段)
|
||||
|
||||
Args:
|
||||
headers: 表头列表
|
||||
@@ -1544,6 +2035,8 @@ class TemplateFillService:
|
||||
匹配的列索引,找不到返回 None
|
||||
"""
|
||||
field_lower = field_name.lower().strip()
|
||||
# 对中文进行字符级拆分,增加匹配的灵活性
|
||||
field_chars = set(field_lower.replace(" ", ""))
|
||||
field_keywords = set(field_lower.replace(" ", "").split())
|
||||
|
||||
best_match_idx = None
|
||||
@@ -1560,6 +2053,7 @@ class TemplateFillService:
|
||||
|
||||
# 策略1: 精确匹配(忽略大小写)
|
||||
if header_lower == field_lower:
|
||||
logger.info(f"精确匹配: {field_name} -> {header_str}")
|
||||
return idx
|
||||
|
||||
# 策略2: 子字符串匹配
|
||||
@@ -1580,13 +2074,93 @@ class TemplateFillService:
|
||||
best_match_score = score
|
||||
best_match_idx = idx
|
||||
|
||||
# 只有当匹配分数超过阈值时才返回
|
||||
if best_match_score >= 0.3:
|
||||
# 策略4: 字符级包含匹配(适用于中文短字段,如"医院"匹配"医院数量")
|
||||
header_chars = set(header_lower.replace(" ", ""))
|
||||
char_overlap = field_chars & header_chars
|
||||
if char_overlap:
|
||||
# 计算字符重叠率,但要求至少有一定数量的重叠字符
|
||||
char_score = len(char_overlap) / max(len(field_chars), len(header_chars), 1)
|
||||
# 对于短字段(<=4字符),降低要求,只要有重叠且字符score较高即可
|
||||
if len(field_chars) <= 4 and char_score >= 0.5:
|
||||
if char_score > best_match_score:
|
||||
best_match_score = char_score
|
||||
best_match_idx = idx
|
||||
elif char_score > best_match_score and len(char_overlap) >= 2:
|
||||
# 对于较长字段,要求至少2个字符重叠
|
||||
best_match_score = char_score
|
||||
best_match_idx = idx
|
||||
|
||||
# 降低阈值到 0.2,允许更多模糊匹配
|
||||
if best_match_score >= 0.2:
|
||||
logger.info(f"模糊匹配: {field_name} -> {headers[best_match_idx]} (分数: {best_match_score:.2f})")
|
||||
return best_match_idx
|
||||
|
||||
return None
|
||||
|
||||
def _search_row_in_first_column(self, rows: List, field_name: str) -> Optional[int]:
    """
    Find the row whose first cell names ``field_name`` (transposed tables).

    Some statistical tables are transposed: the first column holds the
    indicator name (e.g. "医院数量") and the other columns hold years or
    values. Rows are scanned top to bottom and the first row whose first
    cell matches is reported. A first cell matches when it

      * equals the field name (case-insensitive), or
      * contains the field name or is contained in it, or
      * shares at least two whitespace-separated keywords with it and the
        overlap ratio is >= 0.3, or
      * for fields with at most 4 distinct characters, shares characters
        with a ratio of >= 0.5.

    Rows that are not non-empty lists, or whose first cell is blank, are
    skipped.

    Args:
        rows: Data rows.
        field_name: Field name to look for.

    Returns:
        The index into ``rows`` of the first matching row, so callers can
        read ``rows[result]``; ``None`` when no row matches or the field
        name is blank.
    """
    if not rows or not field_name:
        return None

    field_lower = field_name.lower().strip()
    if not field_lower:
        # "" is contained in every cell and would match the first row
        return None
    field_chars = set(field_lower.replace(" ", ""))
    field_keywords = set(field_lower.replace(" ", "").split())

    for row_idx, row in enumerate(rows):
        if not isinstance(row, list) or len(row) == 0:
            continue

        first_cell = str(row[0]).strip()
        if not first_cell:
            continue

        first_cell_lower = first_cell.lower()

        # Exact match
        if first_cell_lower == field_lower:
            logger.info(f"第一列精确匹配字段: {field_name} -> {first_cell} (行{row_idx})")
            return row_idx

        # Containment match; the logged score is the length similarity
        if field_lower in first_cell_lower or first_cell_lower in field_lower:
            shorter, longer = sorted((len(field_lower), len(first_cell_lower)))
            score = shorter / longer
            logger.info(f"第一列模糊匹配字段: {field_name} -> {first_cell} (行{row_idx}, 分数:{score:.2f})")
            return row_idx

        # Keyword overlap match
        first_keywords = set(first_cell_lower.replace(" ", "").split())
        overlap = field_keywords & first_keywords
        if overlap and len(overlap) >= 2:
            score = len(overlap) / max(len(field_keywords), len(first_keywords), 1)
            if score >= 0.3:
                logger.info(f"第一列关键词匹配: {field_name} -> {first_cell} (行{row_idx}, 分数:{score:.2f})")
                return row_idx

        # Character-level match, short fields only
        first_chars = set(first_cell_lower.replace(" ", ""))
        char_overlap = field_chars & first_chars
        if char_overlap and len(field_chars) <= 4:
            char_score = len(char_overlap) / max(len(field_chars), len(first_chars), 1)
            if char_score >= 0.5:
                logger.info(f"第一列字符匹配: {field_name} -> {first_cell} (行{row_idx}, 分数:{char_score:.2f})")
                return row_idx

    return None
|
||||
|
||||
def _extract_column_values(self, rows: List, columns: List, field_name: str) -> List[str]:
|
||||
"""
|
||||
从 rows 和 columns 中提取指定列的值
|
||||
@@ -1677,6 +2251,55 @@ class TemplateFillService:
|
||||
|
||||
return str(val)
|
||||
|
||||
def _search_generic_header_in_text(self, text: str, field_name: str) -> List[str]:
    """
    Search free text for values belonging to a generic header.

    ``field_name`` is looked up in ``GENERIC_HEADER_EXPANSION``, e.g. the
    header "机构" expands to "医院", "学校", "企业" and so on. For every
    expanded term three patterns are tried:

      1. the term directly followed by digits / unit characters
         ("医院100所"): the whole matched text is kept;
      2. the term, a colon (ASCII or full-width) or whitespace, then a
         number ("医院: 100"): the number is kept;
      3. a number followed by non-digit text and then the term
         ("100家医院"): the number is kept.

    A number is a run of digits optionally continued by groups of
    ``.``, ``。``, ``,`` or full-width comma plus digits, so "12.5" and
    "1,234.5" are captured whole and a trailing separator is never included.

    Args:
        text: Document text to search.
        field_name: Field name; must be a key of ``GENERIC_HEADER_EXPANSION``.

    Returns:
        Distinct matches shorter than 100 characters, in discovery order;
        ``[]`` when the field is not a generic header or the text is empty.
    """
    import re

    generic_terms = self.GENERIC_HEADER_EXPANSION.get(field_name, [])
    if not generic_terms or not text:
        return []

    # \uff0c is the full-width comma, \uff1a the full-width colon
    number = r'\d+(?:[.。,\uff0c]\d+)*'

    matched_values = []

    for term in generic_terms:
        escaped = re.escape(term)
        patterns = [
            rf'{escaped}[\s\d所个家级人万元亿元%‰]+',   # 医院100所, 企业50家
            rf'{escaped}[:\uff1a\s]+({number})',          # 医院: 100
            rf'({number})[^\d]*{escaped}',                # 100家医院
        ]
        for pattern in patterns:
            # findall yields the whole match for pattern 1 and the captured
            # number for patterns 2 and 3
            for match in re.findall(pattern, text, re.IGNORECASE):
                val = match.strip() if isinstance(match, str) else match
                if val and len(str(val)) < 100:
                    matched_values.append(str(val))

    # Drop duplicates while keeping first-seen order
    unique_values = list(dict.fromkeys(matched_values))

    if unique_values:
        logger.info(f"通用表头 '{field_name}' 匹配到值: {unique_values[:10]}")

    return unique_values
|
||||
|
||||
def _extract_values_from_json(self, result) -> List[str]:
|
||||
"""
|
||||
从解析后的 JSON 对象/数组中提取值数组
|
||||
@@ -2236,29 +2859,32 @@ class TemplateFillService:
|
||||
- 二级分类:如"医院"下分为"公立医院"、"民营医院"
|
||||
|
||||
4. **生成字段**:
|
||||
- 字段名要简洁,如:"医院数量"、"病床使用率"
|
||||
- 优先选择:总数 + 主要分类
|
||||
- 字段名要详细具体,能区分不同数据,如:"医院数量(个)"、"病床使用率(%)"、"公立医院数量"
|
||||
- 优先选择:总数 + 主要分类 + 重要指标
|
||||
|
||||
5. **生成数量**:
|
||||
- 生成5-7个最有代表性的字段
|
||||
- 生成10-15个最有代表性的字段,确保覆盖主要数据指标
|
||||
|
||||
6. **添加字段说明**:
|
||||
- 每个字段可以添加 hint 说明字段的含义和数据来源
|
||||
|
||||
请严格按照以下 JSON 格式输出(只需输出 JSON,不要其他内容):
|
||||
{{
|
||||
"fields": [
|
||||
{{"name": "字段名1"}},
|
||||
{{"name": "字段名2"}}
|
||||
{{"name": "医院数量", "hint": "从文档中提取医院总数,包括公立和民营医院"}},
|
||||
{{"name": "病床使用率", "hint": "提取病床使用率数据"}}
|
||||
]
|
||||
}}
|
||||
"""
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个专业的表格设计助手。请严格按JSON格式输出,只返回纯数据字段名,不要source、备注、说明等辅助字段。"},
|
||||
{"role": "system", "content": "你是一个专业的表格设计助手。请严格按JSON格式输出,为每个字段生成详细名称和hint说明。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
response = await self.llm.chat(
|
||||
messages=messages,
|
||||
temperature=0.3,
|
||||
max_tokens=2000
|
||||
max_tokens=4000
|
||||
)
|
||||
|
||||
content = self.llm.extract_message_content(response)
|
||||
|
||||
@@ -192,13 +192,15 @@ class WordAIService:
|
||||
result = self._parse_json_response(content)
|
||||
|
||||
if result:
|
||||
logger.info(f"AI 表格提取成功: {len(result.get('rows', []))} 行数据")
|
||||
logger.info(f"AI 表格提取成功: {len(result.get('rows', []))} 行数据, key_values={len(result.get('key_values', {}))}, list_items={len(result.get('list_items', []))}")
|
||||
return {
|
||||
"success": True,
|
||||
"type": "table_data",
|
||||
"headers": result.get("headers", []),
|
||||
"rows": result.get("rows", []),
|
||||
"description": result.get("description", "")
|
||||
"description": result.get("description", ""),
|
||||
"key_values": result.get("key_values", {}),
|
||||
"list_items": result.get("list_items", [])
|
||||
}
|
||||
else:
|
||||
# 如果 AI 返回格式不对,尝试直接解析表格
|
||||
|
||||
@@ -1459,4 +1459,131 @@ export const aiApi = {
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
|
||||
// ==================== 智能指令 ====================
|
||||
|
||||
/**
 * Ask the backend to classify the intent of a natural-language instruction.
 *
 * POSTs `{ instruction, doc_ids }` to `/instruction/recognize`.
 *
 * @param instruction - The user's instruction text.
 * @param docIds - Optional document ids sent to the backend as `doc_ids`.
 * @returns The parsed JSON response body.
 * @throws Error with the backend's `detail` message (or a generic fallback)
 *         when the response status is not OK; network errors are rethrown.
 */
async recognizeIntent(
  instruction: string,
  docIds?: string[]
): Promise<{
  success: boolean;
  intent: string;
  params: Record<string, any>;
  message: string;
}> {
  const url = `${BACKEND_BASE_URL}/instruction/recognize`;

  try {
    const response = await fetch(url, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ instruction, doc_ids: docIds }),
    });

    if (!response.ok) {
      // An error body is not guaranteed to be JSON (e.g. a proxy's HTML
      // error page); a parse failure must not hide the real failure.
      const error = await response.json().catch(() => null);
      const detail = error && typeof error.detail === 'string' ? error.detail : '';
      throw new Error(detail || '意图识别失败');
    }

    return await response.json();
  } catch (error) {
    console.error('意图识别失败:', error);
    throw error;
  }
},
|
||||
|
||||
/**
 * Execute a natural-language instruction on the backend.
 *
 * POSTs `{ instruction, doc_ids, context }` to `/instruction/execute`.
 *
 * @param instruction - The user's instruction text.
 * @param docIds - Optional document ids sent to the backend as `doc_ids`.
 * @param context - Optional extra context forwarded unchanged.
 * @returns The parsed JSON response body.
 * @throws Error with the backend's `detail` message (or a generic fallback)
 *         when the response status is not OK; network errors are rethrown.
 */
async executeInstruction(
  instruction: string,
  docIds?: string[],
  context?: Record<string, any>
): Promise<{
  success: boolean;
  intent: string;
  result: Record<string, any>;
  message: string;
}> {
  const url = `${BACKEND_BASE_URL}/instruction/execute`;

  try {
    const response = await fetch(url, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ instruction, doc_ids: docIds, context }),
    });

    if (!response.ok) {
      // An error body is not guaranteed to be JSON (e.g. a proxy's HTML
      // error page); a parse failure must not hide the real failure.
      const error = await response.json().catch(() => null);
      const detail = error && typeof error.detail === 'string' ? error.detail : '';
      throw new Error(detail || '指令执行失败');
    }

    return await response.json();
  } catch (error) {
    console.error('指令执行失败:', error);
    throw error;
  }
},
|
||||
|
||||
/**
 * Send one turn of the instruction chat to the backend.
 *
 * POSTs `{ instruction, doc_ids, context }` to `/instruction/chat`.
 *
 * @param instruction - The user's message.
 * @param docIds - Optional document ids sent to the backend as `doc_ids`.
 * @param context - Optional extra context forwarded unchanged.
 * @returns The parsed JSON response body.
 * @throws Error with the backend's `detail` message (or a generic fallback)
 *         when the response status is not OK; network errors are rethrown.
 */
async instructionChat(
  instruction: string,
  docIds?: string[],
  context?: Record<string, any>
): Promise<{
  success: boolean;
  intent: string;
  result: Record<string, any>;
  message: string;
  hint?: string;
}> {
  const url = `${BACKEND_BASE_URL}/instruction/chat`;

  try {
    const response = await fetch(url, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ instruction, doc_ids: docIds, context }),
    });

    if (!response.ok) {
      // An error body is not guaranteed to be JSON (e.g. a proxy's HTML
      // error page); a parse failure must not hide the real failure.
      const error = await response.json().catch(() => null);
      const detail = error && typeof error.detail === 'string' ? error.detail : '';
      throw new Error(detail || '对话处理失败');
    }

    return await response.json();
  } catch (error) {
    console.error('对话处理失败:', error);
    throw error;
  }
},
|
||||
|
||||
/**
 * Fetch the list of instruction types the backend supports.
 *
 * Issues a GET to `/instruction/intents`.
 *
 * @returns The parsed JSON response body.
 * @throws Error when the response status is not OK; network errors are
 *         logged and rethrown.
 */
async getSupportedIntents(): Promise<{
  intents: Array<{
    intent: string;
    name: string;
    examples: string[];
    params: string[];
  }>;
}> {
  const endpoint = `${BACKEND_BASE_URL}/instruction/intents`;

  try {
    const res = await fetch(endpoint);
    if (res.ok) {
      return await res.json();
    }
    throw new Error('获取指令列表失败');
  } catch (err) {
    console.error('获取指令列表失败:', err);
    throw err;
  }
},
|
||||
};
|
||||
|
||||
@@ -10,7 +10,11 @@ import {
|
||||
TableProperties,
|
||||
ChevronRight,
|
||||
ArrowRight,
|
||||
Loader2
|
||||
Loader2,
|
||||
Download,
|
||||
Search,
|
||||
MessageSquare,
|
||||
CheckCircle
|
||||
} from 'lucide-react';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Input } from '@/components/ui/input';
|
||||
@@ -26,12 +30,15 @@ type ChatMessage = {
|
||||
role: 'user' | 'assistant';
|
||||
content: string;
|
||||
created_at: string;
|
||||
intent?: string;
|
||||
result?: any;
|
||||
};
|
||||
|
||||
const InstructionChat: React.FC = () => {
|
||||
const [messages, setMessages] = useState<ChatMessage[]>([]);
|
||||
const [input, setInput] = useState('');
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [currentDocIds, setCurrentDocIds] = useState<string[]>([]);
|
||||
const scrollAreaRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -43,27 +50,47 @@ const InstructionChat: React.FC = () => {
|
||||
role: 'assistant',
|
||||
content: `您好!我是智联文档 AI 助手。
|
||||
|
||||
我可以帮您完成以下操作:
|
||||
**📄 文档智能操作**
|
||||
- "提取文档中的医院数量和床位数"
|
||||
- "帮我找出所有机构的名称"
|
||||
|
||||
📄 **文档管理**
|
||||
- "帮我列出最近上传的所有文档"
|
||||
- "删除三天前的 docx 文档"
|
||||
**📊 数据填表**
|
||||
- "根据这些数据填表"
|
||||
- "将提取的信息填写到Excel模板"
|
||||
|
||||
📊 **Excel 分析**
|
||||
- "分析一下最近上传的 Excel 文件"
|
||||
- "帮我统计销售报表中的数据"
|
||||
**📝 内容处理**
|
||||
- "总结一下这份文档"
|
||||
- "对比这两个文档的差异"
|
||||
|
||||
📝 **智能填表**
|
||||
- "根据员工信息表创建一个考勤汇总表"
|
||||
- "用财务文档填充报销模板"
|
||||
**🔍 智能问答**
|
||||
- "文档里说了些什么?"
|
||||
- "有多少家医院?"
|
||||
|
||||
请告诉我您想做什么?`,
|
||||
created_at: new Date().toISOString()
|
||||
}
|
||||
]);
|
||||
|
||||
// 获取已上传的文档ID列表
|
||||
loadDocuments();
|
||||
}
|
||||
}, []);
|
||||
|
||||
// Load up to 50 documents and remember their ids for later instruction
// calls. Failures are logged only.
const loadDocuments = async () => {
  try {
    const listing = await backendApi.getDocuments(undefined, 50);
    if (!listing.success || !listing.documents) {
      return;
    }

    const ids = listing.documents.map((doc: any) => doc.doc_id);
    setCurrentDocIds(ids);

    if (ids.length > 0) {
      console.log(`已加载 ${ids.length} 个文档`);
    }
  } catch (err) {
    console.error('获取文档列表失败:', err);
  }
};
|
||||
|
||||
useEffect(() => {
|
||||
// Scroll to bottom
|
||||
if (scrollAreaRef.current) {
|
||||
@@ -89,95 +116,126 @@ const InstructionChat: React.FC = () => {
|
||||
setLoading(true);
|
||||
|
||||
try {
|
||||
// TODO: 后端对话接口,暂用模拟响应
|
||||
await new Promise(resolve => setTimeout(resolve, 1500));
|
||||
// 使用真实的智能指令 API
|
||||
const response = await backendApi.instructionChat(
|
||||
input.trim(),
|
||||
currentDocIds.length > 0 ? currentDocIds : undefined
|
||||
);
|
||||
|
||||
// 简单的命令解析演示
|
||||
const userInput = userMessage.content.toLowerCase();
|
||||
let response = '';
|
||||
// 根据意图类型生成友好响应
|
||||
let responseContent = '';
|
||||
const resultData = response.result;
|
||||
|
||||
if (userInput.includes('列出') || userInput.includes('列表')) {
|
||||
const result = await backendApi.getDocuments(undefined, 10);
|
||||
if (result.success && result.documents && result.documents.length > 0) {
|
||||
response = `已为您找到 ${result.documents.length} 个文档:\n\n`;
|
||||
result.documents.slice(0, 5).forEach((doc: any, idx: number) => {
|
||||
response += `${idx + 1}. **${doc.original_filename}** (${doc.doc_type.toUpperCase()})\n`;
|
||||
response += ` - 大小: ${(doc.file_size / 1024).toFixed(1)} KB\n`;
|
||||
response += ` - 时间: ${new Date(doc.created_at).toLocaleDateString()}\n\n`;
|
||||
switch (response.intent) {
|
||||
case 'extract':
|
||||
// 信息提取结果
|
||||
const extracted = resultData?.extracted_data || {};
|
||||
const keys = Object.keys(extracted);
|
||||
if (keys.length > 0) {
|
||||
responseContent = `✅ 已提取到 ${keys.length} 个字段的数据:\n\n`;
|
||||
for (const [key, value] of Object.entries(extracted)) {
|
||||
const values = Array.isArray(value) ? value : [value];
|
||||
responseContent += `**${key}**: ${values.slice(0, 3).join(', ')}${values.length > 3 ? '...' : ''}\n`;
|
||||
}
|
||||
responseContent += `\n💡 您可以将这些数据填入表格。`;
|
||||
} else {
|
||||
responseContent = '未能从文档中提取到相关数据。请尝试更明确的字段名称。';
|
||||
}
|
||||
break;
|
||||
|
||||
case 'fill_table':
|
||||
// 填表结果
|
||||
const filled = resultData?.result?.filled_data || {};
|
||||
const filledKeys = Object.keys(filled);
|
||||
if (filledKeys.length > 0) {
|
||||
responseContent = `✅ 填表完成!成功填写 ${filledKeys.length} 个字段:\n\n`;
|
||||
for (const [key, value] of Object.entries(filled)) {
|
||||
const values = Array.isArray(value) ? value : [value];
|
||||
responseContent += `**${key}**: ${values.slice(0, 3).join(', ')}\n`;
|
||||
}
|
||||
responseContent += `\n📋 请到【智能填表】页面查看或导出结果。`;
|
||||
} else {
|
||||
responseContent = '填表未能提取到数据。请检查模板表头和数据源内容。';
|
||||
}
|
||||
break;
|
||||
|
||||
case 'summarize':
|
||||
// 摘要结果
|
||||
const summaries = resultData?.summaries || [];
|
||||
if (summaries.length > 0) {
|
||||
responseContent = `📄 找到 ${summaries.length} 个文档的摘要:\n\n`;
|
||||
summaries.forEach((s: any, idx: number) => {
|
||||
responseContent += `**${idx + 1}. ${s.filename}**\n${s.content_preview}\n\n`;
|
||||
});
|
||||
if (result.documents.length > 5) {
|
||||
response += `...还有 ${result.documents.length - 5} 个文档`;
|
||||
}
|
||||
} else {
|
||||
response = '暂未找到已上传的文档,您可以先上传一些文档试试。';
|
||||
responseContent = '未能生成摘要。请确保已上传文档。';
|
||||
}
|
||||
} else if (userInput.includes('分析') || userInput.includes('excel') || userInput.includes('报表')) {
|
||||
response = `好的,我可以帮您分析 Excel 文件。
|
||||
break;
|
||||
|
||||
请告诉我:
|
||||
1. 您想分析哪个 Excel 文件?
|
||||
2. 需要什么样的分析?(数据摘要/统计分析/图表生成)
|
||||
|
||||
或者您可以直接告诉我您想从数据中了解什么,我来为您生成分析。`;
|
||||
} else if (userInput.includes('填表') || userInput.includes('模板')) {
|
||||
response = `好的,要进行智能填表,我需要:
|
||||
|
||||
1. **上传表格模板** - 您要填写的表格模板文件(Excel 或 Word 格式)
|
||||
2. **选择数据源** - 包含要填写内容的源文档
|
||||
|
||||
您可以去【智能填表】页面完成这些操作,或者告诉我您具体想填什么类型的表格,我来指导您操作。`;
|
||||
} else if (userInput.includes('删除')) {
|
||||
response = `要删除文档,请告诉我:
|
||||
|
||||
- 要删除的文件名是什么?
|
||||
- 或者您可以到【文档中心】页面手动选择并删除文档
|
||||
|
||||
⚠️ 删除操作不可恢复,请确认后再操作。`;
|
||||
} else if (userInput.includes('帮助') || userInput.includes('help')) {
|
||||
response = `**我可以帮您完成以下操作:**
|
||||
|
||||
📄 **文档管理**
|
||||
- 列出/搜索已上传的文档
|
||||
- 查看文档详情和元数据
|
||||
- 删除不需要的文档
|
||||
|
||||
📊 **Excel 处理**
|
||||
- 分析 Excel 文件内容
|
||||
- 生成数据统计和图表
|
||||
- 导出处理后的数据
|
||||
|
||||
📝 **智能填表**
|
||||
- 上传表格模板
|
||||
- 从文档中提取信息填入模板
|
||||
- 导出填写完成的表格
|
||||
|
||||
📋 **任务历史**
|
||||
- 查看历史处理任务
|
||||
- 重新执行或导出结果
|
||||
|
||||
请直接告诉我您想做什么!`;
|
||||
case 'question':
|
||||
// 问答结果
|
||||
if (resultData?.answer) {
|
||||
responseContent = `**问题**: ${resultData.question}\n\n**答案**: ${resultData.answer}`;
|
||||
} else {
|
||||
response = `我理解您想要: "${input.trim()}"
|
||||
responseContent = resultData?.message || '我找到了相关信息,请查看上文。';
|
||||
}
|
||||
break;
|
||||
|
||||
目前我还在学习如何更好地理解您的需求。您可以尝试:
|
||||
case 'search':
|
||||
// 搜索结果
|
||||
const searchResults = resultData?.results || [];
|
||||
if (searchResults.length > 0) {
|
||||
responseContent = `🔍 找到 ${searchResults.length} 条相关内容:\n\n`;
|
||||
searchResults.slice(0, 5).forEach((r: any, idx: number) => {
|
||||
responseContent += `**${idx + 1}.** ${r.content?.substring(0, 100)}...\n\n`;
|
||||
});
|
||||
} else {
|
||||
responseContent = '未找到相关内容。请尝试其他关键词。';
|
||||
}
|
||||
break;
|
||||
|
||||
1. **上传文档** - 去【文档中心】上传 docx/md/txt 文件
|
||||
2. **分析 Excel** - 去【Excel解析】上传并分析 Excel 文件
|
||||
3. **智能填表** - 去【智能填表】创建填表任务
|
||||
case 'compare':
|
||||
// 对比结果
|
||||
const comparison = resultData?.comparison || [];
|
||||
if (comparison.length > 0) {
|
||||
responseContent = `📊 对比了 ${comparison.length} 个文档:\n\n`;
|
||||
comparison.forEach((c: any) => {
|
||||
responseContent += `- **${c.filename}**: ${c.doc_type}, ${c.content_length} 字\n`;
|
||||
});
|
||||
} else {
|
||||
responseContent = '需要至少2个文档才能进行对比。';
|
||||
}
|
||||
break;
|
||||
|
||||
或者您可以更具体地描述您想做的事情,我会尽力帮助您!`;
|
||||
case 'unknown':
|
||||
responseContent = `我理解您想要: "${input.trim()}"\n\n但我目前无法完成此操作。您可以尝试:\n\n1. **提取数据**: "提取医院数量和床位数"\n2. **填表**: "根据这些数据填表"\n3. **总结**: "总结这份文档"\n4. **问答**: "文档里说了什么?"\n5. **搜索**: "搜索相关内容"`;
|
||||
break;
|
||||
|
||||
default:
|
||||
responseContent = response.message || resultData?.message || '已完成您的请求。';
|
||||
}
|
||||
|
||||
const assistantMessage: ChatMessage = {
|
||||
id: Math.random().toString(36).substring(7),
|
||||
role: 'assistant',
|
||||
content: response,
|
||||
created_at: new Date().toISOString()
|
||||
content: responseContent,
|
||||
created_at: new Date().toISOString(),
|
||||
intent: response.intent,
|
||||
result: resultData
|
||||
};
|
||||
|
||||
setMessages(prev => [...prev, assistantMessage]);
|
||||
} catch (err: any) {
|
||||
toast.error('请求失败,请重试');
|
||||
console.error('指令执行失败:', err);
|
||||
toast.error(err.message || '请求失败,请重试');
|
||||
|
||||
const errorMessage: ChatMessage = {
|
||||
id: Math.random().toString(36).substring(7),
|
||||
role: 'assistant',
|
||||
content: `抱歉,处理您的请求时遇到了问题:${err.message}\n\n请稍后重试,或尝试更简单的指令。`,
|
||||
created_at: new Date().toISOString()
|
||||
};
|
||||
setMessages(prev => [...prev, errorMessage]);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
@@ -189,10 +247,10 @@ const InstructionChat: React.FC = () => {
|
||||
};
|
||||
|
||||
const quickActions = [
|
||||
{ label: '列出所有文档', icon: FileText, action: () => setInput('列出所有已上传的文档') },
|
||||
{ label: '分析 Excel 数据', icon: TableProperties, action: () => setInput('分析一下 Excel 文件') },
|
||||
{ label: '智能填表', icon: Sparkles, action: () => setInput('我想进行智能填表') },
|
||||
{ label: '帮助', icon: Sparkles, action: () => setInput('帮助') }
|
||||
{ label: '提取医院数量', icon: Search, action: () => setInput('提取文档中的医院数量和床位数') },
|
||||
{ label: '智能填表', icon: TableProperties, action: () => setInput('根据这些数据填表') },
|
||||
{ label: '总结文档', icon: MessageSquare, action: () => setInput('总结一下这份文档') },
|
||||
{ label: '智能问答', icon: Bot, action: () => setInput('文档里说了些什么?') }
|
||||
];
|
||||
|
||||
return (
|
||||
|
||||
Reference in New Issue
Block a user