feat: Add base project structure with API and web components
This commit is contained in:
340
api/app/services/llm_client.py
Normal file
340
api/app/services/llm_client.py
Normal file
@@ -0,0 +1,340 @@
|
||||
"""LLM 客户端适配器 - 支持多种 LLM 提供商"""
|
||||
import os
|
||||
import json
|
||||
from typing import Optional, Dict, Any
|
||||
from abc import ABC, abstractmethod
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class BaseLLMClient(ABC):
|
||||
"""LLM 客户端基类"""
|
||||
|
||||
@abstractmethod
|
||||
async def chat(self, prompt: str, **kwargs) -> str:
|
||||
"""发送聊天请求
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
LLM 响应文本
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIClient(BaseLLMClient):
|
||||
"""OpenAI 客户端"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
model: str = "gpt-3.5-turbo",
|
||||
base_url: Optional[str] = None
|
||||
):
|
||||
"""初始化 OpenAI 客户端
|
||||
|
||||
Args:
|
||||
api_key: API 密钥
|
||||
model: 模型名称
|
||||
base_url: API 基础 URL(可选,用于兼容其他服务)
|
||||
"""
|
||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
self.model = model
|
||||
self.base_url = base_url
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("OpenAI API key 未配置")
|
||||
|
||||
try:
|
||||
from openai import AsyncOpenAI
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError("请安装 openai 库: pip install openai")
|
||||
|
||||
async def chat(self, prompt: str, **kwargs) -> str:
|
||||
"""发送聊天请求
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
**kwargs: 其他参数(temperature, max_tokens 等)
|
||||
|
||||
Returns:
|
||||
LLM 响应文本
|
||||
"""
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=kwargs.get("temperature", 0.3),
|
||||
max_tokens=kwargs.get("max_tokens", 500)
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API 调用失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
class AzureOpenAIClient(BaseLLMClient):
|
||||
"""Azure OpenAI 客户端"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
endpoint: Optional[str] = None,
|
||||
deployment_name: Optional[str] = None,
|
||||
api_version: str = "2024-02-15-preview"
|
||||
):
|
||||
"""初始化 Azure OpenAI 客户端
|
||||
|
||||
Args:
|
||||
api_key: API 密钥
|
||||
endpoint: Azure 端点
|
||||
deployment_name: 部署名称
|
||||
api_version: API 版本
|
||||
"""
|
||||
self.api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
|
||||
self.endpoint = endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
|
||||
self.deployment_name = deployment_name or os.getenv("AZURE_OPENAI_DEPLOYMENT")
|
||||
self.api_version = api_version
|
||||
|
||||
if not all([self.api_key, self.endpoint, self.deployment_name]):
|
||||
raise ValueError("Azure OpenAI 配置不完整")
|
||||
|
||||
try:
|
||||
from openai import AsyncAzureOpenAI
|
||||
self.client = AsyncAzureOpenAI(
|
||||
api_key=self.api_key,
|
||||
azure_endpoint=self.endpoint,
|
||||
api_version=self.api_version
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError("请安装 openai 库: pip install openai")
|
||||
|
||||
async def chat(self, prompt: str, **kwargs) -> str:
|
||||
"""发送聊天请求"""
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.deployment_name,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=kwargs.get("temperature", 0.3),
|
||||
max_tokens=kwargs.get("max_tokens", 500)
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Azure OpenAI API 调用失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
class AnthropicClient(BaseLLMClient):
|
||||
"""Anthropic Claude 客户端"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
model: str = "claude-3-sonnet-20240229"
|
||||
):
|
||||
"""初始化 Anthropic 客户端
|
||||
|
||||
Args:
|
||||
api_key: API 密钥
|
||||
model: 模型名称
|
||||
"""
|
||||
self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
|
||||
self.model = model
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("Anthropic API key 未配置")
|
||||
|
||||
try:
|
||||
from anthropic import AsyncAnthropic
|
||||
self.client = AsyncAnthropic(api_key=self.api_key)
|
||||
except ImportError:
|
||||
raise ImportError("请安装 anthropic 库: pip install anthropic")
|
||||
|
||||
async def chat(self, prompt: str, **kwargs) -> str:
|
||||
"""发送聊天请求"""
|
||||
try:
|
||||
response = await self.client.messages.create(
|
||||
model=self.model,
|
||||
max_tokens=kwargs.get("max_tokens", 500),
|
||||
temperature=kwargs.get("temperature", 0.3),
|
||||
messages=[{"role": "user", "content": prompt}]
|
||||
)
|
||||
|
||||
return response.content[0].text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic API 调用失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
class LocalLLMClient(BaseLLMClient):
|
||||
"""本地 LLM 客户端(通过 HTTP API)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str = "http://localhost:8000",
|
||||
model: str = "local-model"
|
||||
):
|
||||
"""初始化本地 LLM 客户端
|
||||
|
||||
Args:
|
||||
base_url: API 基础 URL
|
||||
model: 模型名称
|
||||
"""
|
||||
self.base_url = base_url
|
||||
self.model = model
|
||||
|
||||
try:
|
||||
import httpx
|
||||
self.client = httpx.AsyncClient(timeout=30.0)
|
||||
except ImportError:
|
||||
raise ImportError("请安装 httpx 库: pip install httpx")
|
||||
|
||||
async def chat(self, prompt: str, **kwargs) -> str:
|
||||
"""发送聊天请求"""
|
||||
try:
|
||||
response = await self.client.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
json={
|
||||
"model": self.model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": kwargs.get("temperature", 0.3),
|
||||
"max_tokens": kwargs.get("max_tokens", 500)
|
||||
}
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
return data["choices"][0]["message"]["content"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"本地 LLM API 调用失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
class MockLLMClient(BaseLLMClient):
|
||||
"""模拟 LLM 客户端(用于测试)"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化模拟客户端"""
|
||||
self.call_count = 0
|
||||
|
||||
async def chat(self, prompt: str, **kwargs) -> str:
|
||||
"""发送聊天请求(返回模拟结果)"""
|
||||
self.call_count += 1
|
||||
|
||||
logger.info(f"模拟 LLM 调用 (第 {self.call_count} 次)")
|
||||
|
||||
# 简单的规则匹配
|
||||
prompt_lower = prompt.lower()
|
||||
|
||||
if "数学" in prompt_lower or "方程" in prompt_lower or "计算" in prompt_lower:
|
||||
return json.dumps({
|
||||
"agent_id": "math-agent",
|
||||
"confidence": 0.9,
|
||||
"reason": "消息包含数学相关内容"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
elif "化学" in prompt_lower or "反应" in prompt_lower or "元素" in prompt_lower:
|
||||
return json.dumps({
|
||||
"agent_id": "chemistry-agent",
|
||||
"confidence": 0.85,
|
||||
"reason": "消息包含化学相关内容"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
elif "物理" in prompt_lower or "力" in prompt_lower or "速度" in prompt_lower:
|
||||
return json.dumps({
|
||||
"agent_id": "physics-agent",
|
||||
"confidence": 0.88,
|
||||
"reason": "消息包含物理相关内容"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
elif "语文" in prompt_lower or "古诗" in prompt_lower or "作文" in prompt_lower:
|
||||
return json.dumps({
|
||||
"agent_id": "chinese-agent",
|
||||
"confidence": 0.87,
|
||||
"reason": "消息包含语文相关内容"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
elif "英语" in prompt_lower or "单词" in prompt_lower or "语法" in prompt_lower:
|
||||
return json.dumps({
|
||||
"agent_id": "english-agent",
|
||||
"confidence": 0.86,
|
||||
"reason": "消息包含英语相关内容"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
else:
|
||||
return json.dumps({
|
||||
"agent_id": "math-agent",
|
||||
"confidence": 0.5,
|
||||
"reason": "无法明确判断,使用默认 Agent"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
class LLMClientFactory:
|
||||
"""LLM 客户端工厂"""
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
provider: str = "mock",
|
||||
**kwargs
|
||||
) -> BaseLLMClient:
|
||||
"""创建 LLM 客户端
|
||||
|
||||
Args:
|
||||
provider: 提供商名称 (openai, azure, anthropic, local, mock)
|
||||
**kwargs: 客户端配置参数
|
||||
|
||||
Returns:
|
||||
LLM 客户端实例
|
||||
"""
|
||||
provider = provider.lower()
|
||||
|
||||
if provider == "openai":
|
||||
return OpenAIClient(**kwargs)
|
||||
|
||||
elif provider == "azure":
|
||||
return AzureOpenAIClient(**kwargs)
|
||||
|
||||
elif provider == "anthropic":
|
||||
return AnthropicClient(**kwargs)
|
||||
|
||||
elif provider == "local":
|
||||
return LocalLLMClient(**kwargs)
|
||||
|
||||
elif provider == "mock":
|
||||
return MockLLMClient()
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的 LLM 提供商: {provider}")
|
||||
|
||||
@staticmethod
|
||||
def create_from_env() -> BaseLLMClient:
|
||||
"""从环境变量创建 LLM 客户端
|
||||
|
||||
环境变量:
|
||||
- LLM_PROVIDER: 提供商名称
|
||||
- OPENAI_API_KEY: OpenAI API 密钥
|
||||
- AZURE_OPENAI_API_KEY: Azure OpenAI API 密钥
|
||||
- ANTHROPIC_API_KEY: Anthropic API 密钥
|
||||
|
||||
Returns:
|
||||
LLM 客户端实例
|
||||
"""
|
||||
provider = os.getenv("LLM_PROVIDER", "mock")
|
||||
|
||||
logger.info(f"从环境变量创建 LLM 客户端: {provider}")
|
||||
|
||||
return LLMClientFactory.create(provider)
|
||||
Reference in New Issue
Block a user