ReadSystem/backend/app/services/llm_service.py

"""
LLM 服务模块 - 封装大模型 API 调用
"""
import logging
from typing import Dict, Any, List, Optional, AsyncGenerator
import httpx
from app.config import settings
logger = logging.getLogger(__name__)
class LLMService:
"""大语言模型服务类"""
def __init__(self):
self.api_key = settings.LLM_API_KEY
self.base_url = settings.LLM_BASE_URL
self.model_name = settings.LLM_MODEL_NAME

    async def chat(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Call the chat completions API.

        Args:
            messages: List of messages in the form [{"role": "user", "content": "..."}]
            temperature: Sampling temperature; controls randomness
            max_tokens: Maximum number of tokens to generate
            **kwargs: Additional request parameters

        Returns:
            Dict[str, Any]: Raw API response
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        # DeepSeek API accepts temperature in (0, 2]; clamp out-of-range values.
        if temperature < 0.01:
            temperature = 0.01
        elif temperature > 2.0:
            temperature = 2.0

        payload = {
            "model": self.model_name,
            "messages": messages,
            "temperature": temperature
        }

        # The DeepSeek API caps max_tokens at 8192.
        if max_tokens:
            if max_tokens > 8192:
                max_tokens = 8192
            payload["max_tokens"] = max_tokens

        # Drop parameters this endpoint does not support.
        for key in ["stream", "stop", "presence_penalty", "frequency_penalty", "logit_bias"]:
            kwargs.pop(key, None)

        # Merge any remaining extra parameters.
        payload.update(kwargs)

        # Validate message format.
        validated_messages = []
        for i, msg in enumerate(messages):
            role = msg.get("role", "")
            content = msg.get("content", "")
            # Ensure content is a string.
            if not isinstance(content, str):
                logger.warning(f"Message[{i}] content is not a string ({type(content)}); converting to str")
                content = str(content)
            # Ensure the role is valid.
            if role not in ["system", "user", "assistant"]:
                logger.warning(f"Message[{i}] has invalid role: {role}; skipping")
                continue
            validated_messages.append({"role": role, "content": content})
        payload["messages"] = validated_messages
        logger.info(f"Message count after validation: {len(validated_messages)}")

        try:
            logger.info(
                f"LLM API request: model={self.model_name}, base_url={self.base_url}, "
                f"temperature={temperature}, max_tokens={max_tokens}"
            )
            logger.info(f"Message count: {len(messages)}")
            total_content_len = sum(len(msg.get('content', '')) for msg in messages)
            logger.info(f"Total content length: {total_content_len}")

            async with httpx.AsyncClient(timeout=120.0) as client:
                response = await client.post(
                    f"{self.base_url}/chat/completions",
                    headers=headers,
                    json=payload
                )
                logger.info(f"LLM API response status: {response.status_code}")
                if response.status_code != 200:
                    error_text = response.text
                    logger.error(f"LLM API error response: {error_text}")
                    # Try to extract error details from the response body.
                    try:
                        error_json = response.json()
                        error_msg = error_json.get("error", {}).get("message", error_text)
                        logger.error(f"Error detail: {error_msg}")
                    except Exception:
                        pass
                response.raise_for_status()
                return response.json()
        except httpx.HTTPStatusError as e:
            logger.error(f"LLM API HTTP error: {e.response.status_code} - {e.response.text}")
            raise
        except Exception as e:
            logger.error(f"LLM API call failed: {str(e)}", exc_info=True)
            raise

    def extract_message_content(self, response: Dict[str, Any]) -> str:
        """
        Extract the message content from an API response.

        Args:
            response: API response

        Returns:
            str: Message content
        """
        try:
            return response["choices"][0]["message"]["content"]
        except (KeyError, IndexError) as e:
            logger.error(f"Failed to parse API response: {str(e)}")
            raise
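
    # For reference, extract_message_content assumes the OpenAI-compatible
    # /chat/completions response shape (a sketch, not the full schema):
    #
    #   {
    #       "choices": [
    #           {"message": {"role": "assistant", "content": "..."}}
    #       ]
    #   }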

    async def chat_stream(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Call the chat completions API in streaming mode.

        Args:
            messages: List of messages
            temperature: Sampling temperature
            max_tokens: Maximum number of tokens to generate
            **kwargs: Additional request parameters

        Yields:
            Dict[str, Any]: Chunks containing the delta content
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        # The DeepSeek API caps max_tokens at 8192.
        if max_tokens and max_tokens > 8192:
            max_tokens = 8192

        payload = {
            "model": self.model_name,
            "messages": messages,
            "temperature": temperature,
            "stream": True
        }
        if max_tokens:
            payload["max_tokens"] = max_tokens

        # Drop parameters this endpoint does not support.
        for key in ["stop", "presence_penalty", "frequency_penalty", "logit_bias"]:
            kwargs.pop(key, None)
        payload.update(kwargs)

        try:
            logger.info(f"LLM streaming API request: model={self.model_name}, max_tokens={max_tokens}")
            async with httpx.AsyncClient(timeout=120.0) as client:
                async with client.stream(
                    "POST",
                    f"{self.base_url}/chat/completions",
                    headers=headers,
                    json=payload
                ) as response:
                    if response.status_code != 200:
                        error_text = await response.aread()
                        logger.error(f"LLM streaming API error: {response.status_code} - {error_text}")
                        response.raise_for_status()
                    # Server-sent events: each "data: " line carries one JSON chunk.
                    async for line in response.aiter_lines():
                        if line.startswith("data: "):
                            data = line[6:]
                            if data == "[DONE]":
                                break
                            try:
                                chunk = json.loads(data)
                                delta = chunk.get("choices", [{}])[0].get("delta", {}).get("content", "")
                                if delta:
                                    yield {"content": delta}
                            except json.JSONDecodeError:
                                continue
        except httpx.HTTPStatusError as e:
            logger.error(f"LLM streaming API request failed: {e.response.status_code}")
            raise
        except Exception as e:
            logger.error(f"LLM streaming API call failed: {str(e)}", exc_info=True)
            raise

    async def analyze_excel_data(
        self,
        excel_data: Dict[str, Any],
        user_prompt: str,
        analysis_type: str = "general"
    ) -> Dict[str, Any]:
        """
        Analyze parsed Excel data.

        Args:
            excel_data: Parsed Excel data
            user_prompt: User prompt
            analysis_type: Analysis type (general, summary, statistics, insights)

        Returns:
            Dict[str, Any]: Analysis result
        """
        # Build the prompt.
        system_prompt = self._get_system_prompt(analysis_type)
        user_message = self._format_user_message(excel_data, user_prompt)
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message}
        ]

        try:
            response = await self.chat(
                messages=messages,
                temperature=0.3,  # Lower temperature for more stable output
                max_tokens=2000
            )
            content = self.extract_message_content(response)
            return {
                "success": True,
                "analysis": content,
                "model": self.model_name,
                "analysis_type": analysis_type
            }
        except Exception as e:
            logger.error(f"Excel data analysis failed: {str(e)}")
            return {
                "success": False,
                "error": str(e),
                "analysis": None
            }

    def _get_system_prompt(self, analysis_type: str) -> str:
        """Return the system prompt for the given analysis type."""
        # Prompt templates are sent directly to the model and stay in Chinese
        # (they require Chinese output).
        prompts = {
            "general": """你是一个专业的数据分析师。请分析用户提供的 Excel 数据,提供有价值的见解和建议。
请按照以下格式输出:
1. 数据概览
2. 关键发现
3. 数据质量评估
4. 建议
输出语言:中文""",
            "summary": """你是一个专业的数据分析师。请对用户提供的 Excel 数据进行简洁的总结。
输出格式:
- 数据行数和列数
- 主要列的说明
- 数据范围概述
输出语言:中文""",
            "statistics": """你是一个专业的数据分析师。请对用户提供的 Excel 数据进行统计分析。
请分析:
- 数值型列的统计信息(平均值、中位数、最大值、最小值)
- 分类列的分布情况
- 数据相关性
输出语言:中文,使用表格或结构化格式展示""",
            "insights": """你是一个专业的数据分析师。请深入挖掘用户提供的 Excel 数据,提供有价值的洞察。
请分析:
1. 数据中的异常值或特殊模式
2. 数据之间的潜在关联
3. 基于数据的业务建议
4. 数据趋势分析(如适用)
输出语言:中文,提供详细且可操作的建议"""
        }
        return prompts.get(analysis_type, prompts["general"])

    def _format_user_message(self, excel_data: Dict[str, Any], user_prompt: str) -> str:
        """Build the user message from parsed Excel data (the message text itself is in Chinese)."""
        columns = excel_data.get("columns", [])
        rows = excel_data.get("rows", [])
        row_count = excel_data.get("row_count", 0)
        column_count = excel_data.get("column_count", 0)

        # Build the data overview.
        data_info = f"""
Excel 数据概览:
- 行数: {row_count}
- 列数: {column_count}
- 列名: {', '.join(columns)}
数据样例(前 5 行):
"""
        # Append sample rows (first 5 only).
        for i, row in enumerate(rows[:5], 1):
            row_str = " | ".join([f"{col}: {row.get(col, '')}" for col in columns])
            data_info += f"{i} 行: {row_str}\n"
        if row_count > 5:
            data_info += f"\n(还有 {row_count - 5} 行数据...)\n"

        # Append the user's custom prompt, or a default analysis request.
        if user_prompt and user_prompt.strip():
            data_info += f"\n用户需求:\n{user_prompt}"
        else:
            data_info += "\n用户需求: 请对上述数据进行分析"
        return data_info
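
    # A sketch of the excel_data shape these helpers assume, inferred from the
    # .get() lookups above (hypothetical example values):
    #
    #   {
    #       "columns": ["name", "amount"],
    #       "rows": [{"name": "A", "amount": 10}, {"name": "B", "amount": 20}],
    #       "row_count": 2,
    #       "column_count": 2
    #   }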

    async def analyze_with_template(
        self,
        excel_data: Dict[str, Any],
        template_prompt: str
    ) -> Dict[str, Any]:
        """
        Analyze Excel data using a custom prompt template.

        Args:
            excel_data: Parsed Excel data
            template_prompt: Custom prompt template

        Returns:
            Dict[str, Any]: Analysis result
        """
        system_prompt = """你是一个专业的数据分析师。请根据用户提供的自定义提示词分析 Excel 数据。
请严格按照用户的要求进行分析,输出清晰、有条理的结果。
输出语言:中文"""
        user_message = self._format_user_message(excel_data, template_prompt)
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message}
        ]

        try:
            response = await self.chat(
                messages=messages,
                temperature=0.5,
                max_tokens=3000
            )
            content = self.extract_message_content(response)
            return {
                "success": True,
                "analysis": content,
                "model": self.model_name,
                "is_template": True
            }
        except Exception as e:
            logger.error(f"Custom template analysis failed: {str(e)}")
            return {
                "success": False,
                "error": str(e),
                "analysis": None
            }


# Module-level singleton instance
llm_service = LLMService()
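

# A minimal usage sketch (not part of the service itself), assuming
# settings.LLM_API_KEY, settings.LLM_BASE_URL and settings.LLM_MODEL_NAME point
# at a reachable OpenAI-compatible endpoint. Run the module directly to try a
# plain call followed by a streaming call.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        messages = [{"role": "user", "content": "Introduce yourself in one sentence."}]
        # Non-streaming call: fetch the full response, then extract the text.
        response = await llm_service.chat(messages, temperature=0.7, max_tokens=256)
        print(llm_service.extract_message_content(response))
        # Streaming call: print delta chunks as they arrive.
        async for chunk in llm_service.chat_stream(messages, max_tokens=256):
            print(chunk["content"], end="", flush=True)
        print()

    asyncio.run(_demo())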