1. When creating a base model, check whether a model with the same name and provider already exists. 2. Creating an API key by provider creates keys for all matching models, so the result also reports the models that failed validation.
741 lines
32 KiB
Python
741 lines
32 KiB
Python
from datetime import datetime
|
||
from sqlalchemy.orm import Session
|
||
from typing import List, Optional, Dict, Any
|
||
import uuid
|
||
import math
|
||
import time
|
||
import asyncio
|
||
|
||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType
|
||
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository, ModelBaseRepository
|
||
from app.schemas import model_schema
|
||
from app.schemas.model_schema import (
|
||
ModelConfigCreate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
||
ModelConfigQuery, ModelStats, ModelConfigQueryNew
|
||
)
|
||
from app.core.logging_config import get_business_logger
|
||
from app.schemas.response_schema import PageData, PageMeta
|
||
from app.core.exceptions import BusinessException
|
||
from app.core.error_codes import BizCode
|
||
|
||
# Module-level business logger; used below to report model validation failures.
logger = get_business_logger()
|
||
|
||
|
||
class ModelConfigService:
    """Service layer for model configurations.

    Every method is a ``@staticmethod`` that receives the SQLAlchemy session
    explicitly. Persistence goes through ``ModelConfigRepository`` and
    ``ModelApiKeyRepository``; this class owns validation, error mapping and
    the commit/refresh step.
    """

    @staticmethod
    def get_model_by_id(db: Session, model_id: uuid.UUID, tenant_id: uuid.UUID | None = None) -> ModelConfig:
        """Return the model configuration identified by ``model_id``.

        Raises:
            BusinessException: ``BizCode.MODEL_NOT_FOUND`` when the repository
                returns nothing for the given id / tenant.
        """
        model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
        if not model:
            raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
        return model

    @staticmethod
    def get_model_list(db: Session, query: ModelConfigQuery, tenant_id: uuid.UUID | None = None) -> PageData:
        """Return one page of model configurations plus paging metadata."""
        models, total = ModelConfigRepository.get_list(db, query, tenant_id=tenant_id)
        # Zero rows means zero pages; otherwise round up to a whole page count.
        pages = math.ceil(total / query.pagesize) if total > 0 else 0

        return PageData(
            page=PageMeta(
                page=query.page,
                pagesize=query.pagesize,
                total=total,
                hasnext=query.page < pages
            ),
            items=[model_schema.ModelConfig.model_validate(model) for model in models]
        )

    @staticmethod
    def get_model_list_new(db: Session, query: ModelConfigQueryNew, tenant_id: uuid.UUID | None = None) -> List[dict]:
        """Return model configurations grouped by provider.

        Each item is a dict with the keys ``provider``, ``logo`` (taken from
        the first model of the group), ``tags`` (the distinct model types in
        the group) and ``models`` (the validated schema objects).
        """
        # The repository also returns a total count, which this view does not use.
        provider_groups, _total = ModelConfigRepository.get_list_new(db, query, tenant_id=tenant_id)

        items = []
        for provider, models in provider_groups.items():
            validated_models = [model_schema.ModelConfig.model_validate(model) for model in models]
            if not validated_models:
                # An empty group has no logo to show; skip it instead of
                # failing on validated_models[0].
                continue
            tags = list({model.type for model in validated_models})
            items.append({
                "provider": provider,
                "logo": validated_models[0].logo,
                "tags": tags,
                "models": validated_models
            })

        return items

    @staticmethod
    def get_model_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None) -> ModelConfig:
        """Return the model configuration with the given name.

        Raises:
            BusinessException: ``BizCode.MODEL_NOT_FOUND`` when no match exists.
        """
        model = ModelConfigRepository.get_by_name(db, name, tenant_id=tenant_id)
        if not model:
            raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
        return model

    @staticmethod
    def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ModelConfig]:
        """Delegate a name search (at most ``limit`` results) to the repository."""
        return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit)

    @staticmethod
    def _extract_usage(response: Any) -> Optional[Dict[str, Any]]:
        """Build the token-usage dict from an LLM response.

        Returns ``None`` when the response has no ``usage_metadata`` attribute.
        The metadata may be a mapping or an attribute-style object; a missing
        counter (or ``None`` metadata) is reported as 0.
        """
        if not hasattr(response, 'usage_metadata'):
            return None

        metadata = response.usage_metadata
        if isinstance(metadata, dict):
            # getattr() on a dict never sees its keys, so read them by key.
            read = metadata.get
        else:
            def read(key, default):
                return getattr(metadata, key, default)

        return {
            "input_tokens": read('input_tokens', 0),
            "output_tokens": read('output_tokens', 0),
            "total_tokens": read('total_tokens', 0)
        }

    @staticmethod
    def _describe_validation_error(error_type: str, error_message: str) -> str:
        """Map a raw provider exception to the user-facing error text.

        The checks run in priority order and the first match wins; an
        unrecognised error is returned unchanged.
        """
        lowered = error_message.lower()

        if "unsupported countries" in lowered or "unsupported region" in lowered:
            # Region / country restriction.
            return "区域限制: 该模型在当前区域或国家/地区不可用,请检查提供商的服务区域限制"
        if "ValidationException" in error_type or "ValidationException" in error_message:
            if "access denied" in lowered:
                return "访问被拒绝: 请检查 API 凭证和权限配置"
            return f"验证失败: {error_message}"
        if "AuthenticationError" in error_type or "authentication" in lowered:
            return "认证失败: API Key 无效或已过期"
        if "RateLimitError" in error_type or "rate limit" in lowered:
            return "请求频率限制: 已超过 API 调用限制"
        if "InvalidRequestError" in error_type or "invalid request" in lowered:
            return f"无效请求: {error_message}"
        if "model_copy" in error_message:
            return "模型消息格式错误: 请确保使用正确的模型类型(LLM/Chat)"
        return error_message

    @staticmethod
    async def validate_model_config(
        db: Session,
        *,
        model_name: str,
        provider: str,
        api_key: str,
        api_base: Optional[str] = None,
        model_type: str = "llm",
        test_message: str = "Hello"
    ) -> Dict[str, Any]:
        """Check a model configuration by issuing one real request with it.

        Args:
            db: Database session (not used by the validation itself).
            model_name: Model name sent to the provider.
            provider: Provider identifier.
            api_key: API key to test.
            api_base: Optional base URL of the provider API.
            model_type: One of llm / chat / embedding / rerank (case-insensitive).
            test_message: Prompt / text / query used for the test request.

        Returns:
            A dict with the keys ``valid``, ``message``, ``response``,
            ``elapsed_time``, ``usage`` and ``error``. When the request raised,
            the dict additionally carries ``error_type``. Exceptions raised by
            the request are caught and reported through the dict.
        """
        from app.core.models import RedBearLLM, RedBearRerank
        from app.core.models.base import RedBearModelConfig
        from app.core.models.embedding import RedBearEmbeddings
        import traceback

        try:
            # perf_counter is monotonic, so the measured duration cannot be
            # distorted by wall-clock adjustments.
            start_time = time.perf_counter()

            model_config = RedBearModelConfig(
                model_name=model_name,
                provider=provider,
                api_key=api_key,
                base_url=api_base,
                temperature=0.7,
                max_tokens=100
            )

            model_type_lower = model_type.lower()

            if model_type_lower in ["llm", "chat"]:
                # LLM and chat models are both exercised with a plain string prompt.
                llm = RedBearLLM(model_config, type=ModelType.LLM if model_type_lower == "llm" else ModelType.CHAT)
                response = await llm.ainvoke(test_message)
                elapsed_time = time.perf_counter() - start_time

                content = response.content if hasattr(response, 'content') else str(response)
                usage = ModelConfigService._extract_usage(response)

                return {
                    "valid": True,
                    "message": f"{model_type.upper()} 模型配置验证成功",
                    "response": content,
                    "elapsed_time": elapsed_time,
                    "usage": usage,
                    "error": None
                }

            elif model_type_lower == "embedding":
                # embed_documents is synchronous; run it in a worker thread so
                # the event loop is not blocked.
                embedding = RedBearEmbeddings(model_config)
                test_texts = [test_message, "测试文本"]
                vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
                elapsed_time = time.perf_counter() - start_time

                return {
                    "valid": True,
                    "message": "Embedding 模型配置验证成功",
                    "response": f"成功生成 {len(vectors)} 个向量,维度: {len(vectors[0]) if vectors else 0}",
                    "elapsed_time": elapsed_time,
                    "usage": {
                        # NOTE: this is the character length of the test
                        # message, not a tokenizer count.
                        "input_tokens": len(test_message),
                        "vector_count": len(vectors),
                        "vector_dimension": len(vectors[0]) if vectors else 0
                    },
                    "error": None
                }

            elif model_type_lower == "rerank":
                # rerank is synchronous as well; run it in a worker thread.
                rerank = RedBearRerank(model_config)
                query = test_message
                documents = ["这是第一个文档", "这是第二个文档", "这是第三个文档"]
                results = await asyncio.to_thread(rerank.rerank, query=query, documents=documents, top_n=3)
                elapsed_time = time.perf_counter() - start_time

                return {
                    "valid": True,
                    "message": "Rerank 模型配置验证成功",
                    "response": f"成功对 {len(documents)} 个文档进行重排序,返回 top {len(results) if results else 0} 结果",
                    "elapsed_time": elapsed_time,
                    "usage": {
                        "query_length": len(query),
                        "document_count": len(documents),
                        "result_count": len(results) if results else 0
                    },
                    "error": None
                }

            else:
                return {
                    "valid": False,
                    "message": "不支持的模型类型",
                    "response": None,
                    "elapsed_time": None,
                    "usage": None,
                    "error": f"不支持的模型类型: {model_type}"
                }

        except Exception as e:
            # Deliberately broad: any provider/client failure is turned into a
            # "not valid" result rather than propagated to the caller.
            raw_message = str(e)
            error_type = type(e).__name__
            # Keep the raw provider message in the debug log rather than
            # printing it to stdout.
            logger.debug(f"模型验证原始错误信息: {raw_message}")
            error_message = ModelConfigService._describe_validation_error(error_type, raw_message)

            logger.error(f"模型验证失败 - 类型: {error_type}, 模型: {model_name}, 提供商: {provider}")
            logger.error(f"错误详情: {error_message}")
            logger.debug(f"完整堆栈: {traceback.format_exc()}")

            return {
                "valid": False,
                "message": f"{model_type.upper()} 模型配置验证失败",
                "response": None,
                "elapsed_time": None,
                "usage": None,
                "error": error_message,
                "error_type": error_type
            }

    @staticmethod
    async def create_model(db: Session, model_data: ModelConfigCreate, tenant_id: uuid.UUID) -> ModelConfig:
        """Create a model configuration together with its API keys.

        Unless ``skip_validation`` is set, every supplied API key is validated
        with a live request before anything is written.

        Raises:
            BusinessException: ``BizCode.DUPLICATE_NAME`` when the name is
                already used within the tenant; ``BizCode.INVALID_PARAMETER``
                when one of the API keys fails validation.
        """
        # Names must be unique within a tenant.
        if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
            raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)

        if not model_data.skip_validation and model_data.api_keys:
            for api_key_data in model_data.api_keys:
                validation_result = await ModelConfigService.validate_model_config(
                    db=db,
                    model_name=api_key_data.model_name,
                    provider=api_key_data.provider,
                    api_key=api_key_data.api_key,
                    api_base=api_key_data.api_base,
                    model_type=model_data.type,
                    test_message="Hello"
                )
                if not validation_result["valid"]:
                    raise BusinessException(
                        f"模型配置验证失败: {validation_result['error']}",
                        BizCode.INVALID_PARAMETER
                    )

        api_key_datas = model_data.api_keys
        model_config_data = model_data.model_dump(exclude={"api_keys", "skip_validation"})
        model_config_data["tenant_id"] = tenant_id

        model = ModelConfigRepository.create(db, model_config_data)
        db.flush()  # Flush so model.id is populated before the keys reference it.

        if api_key_datas:
            for api_key_data in api_key_datas:
                api_key_create_schema = ModelApiKeyCreate(
                    model_config_ids=[model.id],
                    **api_key_data.model_dump()
                )
                ModelApiKeyRepository.create(db, api_key_create_schema)

        db.commit()
        db.refresh(model)
        return model

    @staticmethod
    def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, tenant_id: uuid.UUID | None = None) -> ModelConfig:
        """Update a model configuration.

        Raises:
            BusinessException: ``BizCode.MODEL_NOT_FOUND`` when the model does
                not exist; ``BizCode.DUPLICATE_NAME`` when it is renamed to a
                name that is already taken.
        """
        existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
        if not existing_model:
            raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)

        # Only a real rename needs the uniqueness check.
        if model_data.name and model_data.name != existing_model.name:
            if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
                raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)

        model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
        db.commit()
        db.refresh(model)
        return model

    @staticmethod
    def _load_compatible_api_keys(db: Session, api_key_ids, request_type) -> list:
        """Load the given API keys and check them against ``request_type``.

        Every model configuration linked to a key must have the same type as
        ``request_type``; LLM and CHAT are treated as interchangeable.

        Raises:
            BusinessException: ``BizCode.NOT_FOUND`` for an unknown key id;
                ``BizCode.INVALID_PARAMETER`` for a type mismatch.
        """
        compatible_types = {ModelType.LLM, ModelType.CHAT}
        api_keys = []
        for api_key_id in api_key_ids:
            api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
            if not api_key:
                raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)

            for model_config in api_key.model_configs:
                config_type = model_config.type
                if not (config_type == request_type or
                        (config_type in compatible_types and request_type in compatible_types)):
                    raise BusinessException(
                        f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({request_type}) 不匹配",
                        BizCode.INVALID_PARAMETER
                    )
            # NOTE(review): a check rejecting keys whose linked model is itself
            # composite existed here only as commented-out code; it stays disabled.
            api_keys.append(api_key)
        return api_keys

    @staticmethod
    async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
        """Create a composite model that reuses existing API keys.

        Raises:
            BusinessException: ``BizCode.DUPLICATE_NAME`` for a taken name;
                ``BizCode.NOT_FOUND`` for an unknown API key;
                ``BizCode.INVALID_PARAMETER`` for a model type mismatch.
        """
        if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
            raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)

        # Validate every key before anything is written.
        api_keys = ModelConfigService._load_compatible_api_keys(db, model_data.api_key_ids, model_data.type)

        model_config_data = {
            "tenant_id": tenant_id,
            "name": model_data.name,
            "type": model_data.type,
            "logo": model_data.logo,
            "description": model_data.description,
            "provider": "composite",
            "config": model_data.config,
            "is_active": model_data.is_active,
            "is_public": model_data.is_public,
            "is_composite": True
        }
        # Only forward the strategy when the caller set it explicitly.
        if "load_balance_strategy" in model_data.model_fields_set:
            model_config_data["load_balance_strategy"] = model_data.load_balance_strategy

        model = ModelConfigRepository.create(db, model_config_data)
        db.flush()

        for api_key in api_keys:
            model.api_keys.append(api_key)

        db.commit()
        db.refresh(model)
        return model

    @staticmethod
    async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
        """Update a composite model and replace its API key associations.

        The model's type is not changed; the supplied API keys are checked
        against the existing type.

        Raises:
            BusinessException: ``BizCode.MODEL_NOT_FOUND`` when the model does
                not exist; ``BizCode.INVALID_PARAMETER`` when it is not a
                composite model or a key's model type does not match;
                ``BizCode.DUPLICATE_NAME`` when it is renamed to a taken name;
                ``BizCode.NOT_FOUND`` for an unknown API key.
        """
        existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
        if not existing_model:
            raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)

        if not existing_model.is_composite:
            raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER)

        # Same uniqueness rule as update_model / create_composite_model: a
        # rename must not collide with another model of the tenant.
        if model_data.name != existing_model.name:
            if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
                raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)

        api_keys = ModelConfigService._load_compatible_api_keys(db, model_data.api_key_ids, existing_model.type)

        existing_model.name = model_data.name
        existing_model.logo = model_data.logo
        existing_model.description = model_data.description
        existing_model.config = model_data.config
        existing_model.is_active = model_data.is_active
        existing_model.is_public = model_data.is_public
        if "load_balance_strategy" in model_data.model_fields_set:
            existing_model.load_balance_strategy = model_data.load_balance_strategy

        # Replace the whole association set with the validated keys.
        existing_model.api_keys.clear()
        for api_key in api_keys:
            existing_model.api_keys.append(api_key)

        db.commit()
        db.refresh(existing_model)
        return existing_model

    @staticmethod
    def delete_model(db: Session, model_id: uuid.UUID, tenant_id: uuid.UUID | None = None) -> bool:
        """Delete a model configuration and return the repository's result.

        Raises:
            BusinessException: ``BizCode.MODEL_NOT_FOUND`` when the model does
                not exist.
        """
        if not ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id):
            raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)

        success = ModelConfigRepository.delete(db, model_id, tenant_id=tenant_id)
        db.commit()
        return success

    @staticmethod
    def get_model_stats(db: Session) -> ModelStats:
        """Wrap the repository's statistics dict in a ``ModelStats`` schema."""
        stats_data = ModelConfigRepository.get_stats(db)
        return ModelStats(
            total_models=stats_data["total_models"],
            active_models=stats_data["active_models"],
            llm_count=stats_data["llm_count"],
            embedding_count=stats_data["embedding_count"],
            rerank_count=stats_data["rerank_count"],
            provider_stats=stats_data["provider_stats"]
        )
|
||
|
||
|
||
class ModelApiKeyService:
    """Service layer for model API keys.

    Every method is a ``@staticmethod`` that receives the SQLAlchemy session
    explicitly. New or changed credentials are validated with a live request
    through ``ModelConfigService.validate_model_config``.
    """

    @staticmethod
    def get_api_key_by_id(db: Session, api_key_id: uuid.UUID) -> ModelApiKey:
        """Return the API key identified by ``api_key_id``.

        Raises:
            BusinessException: ``BizCode.NOT_FOUND`` when it does not exist.
        """
        api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
        if not api_key:
            raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
        return api_key

    @staticmethod
    def get_api_keys_by_model(db: Session, model_config_id: uuid.UUID, is_active: bool = True) -> list[ModelApiKey]:
        """Return the API keys linked to a model configuration.

        Raises:
            BusinessException: ``BizCode.MODEL_NOT_FOUND`` when the model
                configuration does not exist.
        """
        if not ModelConfigRepository.get_by_id(db, model_config_id):
            raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)

        return ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active)

    @staticmethod
    def _find_existing_key(db: Session, *, api_key: str, provider: str, model_name: str):
        """Return the stored key matching api_key + provider + model_name, if any.

        The query has no ``is_active`` filter, so deactivated rows are matched too.
        """
        return db.query(ModelApiKey).filter(
            ModelApiKey.api_key == api_key,
            ModelApiKey.provider == provider,
            ModelApiKey.model_name == model_name
        ).first()

    @staticmethod
    def _reactivate_existing_key(existing_key, *, api_base, description, config, priority, model_name, model_configs) -> None:
        """Re-enable a deactivated key, refresh its settings and link model configs.

        Model configurations already linked to the key are not added twice.
        """
        existing_key.is_active = True
        existing_key.api_base = api_base
        existing_key.description = description
        existing_key.config = config
        existing_key.priority = priority
        existing_key.model_name = model_name

        for model_config in model_configs:
            if model_config not in existing_key.model_configs:
                existing_key.model_configs.append(model_config)

    @staticmethod
    async def create_api_key_by_provider(db: Session, data: model_schema.ModelApiKeyCreateByProvider) -> tuple[
        list[Any], list[Any]]:
        """Create one API key per listed model configuration for a provider.

        For each model configuration id:
          * an unknown id is skipped;
          * a key that already exists and is active is skipped;
          * a key that exists but is inactive is reactivated (without
            re-validation) and reported as created;
          * otherwise the credentials are validated and a new key is created.
            A failed validation does not abort the batch.

        Returns:
            ``(created_keys, failed_models)``: the created / reactivated key
            objects and the model names whose validation failed.
        """
        created_keys = []
        failed_models = []  # Model names whose validation failed.

        for model_config_id in data.model_config_ids:
            model_config = ModelConfigRepository.get_by_id(db, model_config_id)
            if not model_config:
                continue

            # Prefer the name of the linked base model; fall back to the
            # configuration's own name.
            model_name = model_config.model_base.name if model_config.model_base else model_config.name

            existing_key = ModelApiKeyService._find_existing_key(
                db, api_key=data.api_key, provider=data.provider, model_name=model_name
            )

            if existing_key:
                if existing_key.is_active:
                    continue
                ModelApiKeyService._reactivate_existing_key(
                    existing_key,
                    api_base=data.api_base,
                    description=data.description,
                    config=data.config,
                    priority=data.priority,
                    model_name=model_name,
                    model_configs=[model_config]
                )
                created_keys.append(existing_key)
                continue

            validation_result = await ModelConfigService.validate_model_config(
                db=db,
                model_name=model_name,
                provider=data.provider,
                api_key=data.api_key,
                api_base=data.api_base,
                model_type=model_config.type,
                test_message="Hello"
            )
            if not validation_result["valid"]:
                # Record the failure and carry on with the remaining models.
                failed_models.append(model_name)
                continue

            api_key_data = ModelApiKeyCreate(
                model_config_ids=[model_config_id],
                model_name=model_name,
                description=data.description,
                provider=data.provider,
                api_key=data.api_key,
                api_base=data.api_base,
                config=data.config,
                is_active=data.is_active,
                priority=data.priority
            )
            api_key_obj = ModelApiKeyRepository.create(db, api_key_data)
            created_keys.append(api_key_obj)

        if created_keys:
            db.commit()
            for key in created_keys:
                db.refresh(key)

        return created_keys, failed_models

    @staticmethod
    async def create_api_key(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey:
        """Create an API key, or reactivate a matching deactivated one.

        Raises:
            BusinessException: ``BizCode.MODEL_NOT_FOUND`` when a referenced
                model configuration does not exist; ``BizCode.DUPLICATE_NAME``
                when an active key with the same api_key / provider /
                model_name already exists; ``BizCode.INVALID_PARAMETER`` when
                no model configuration is given for a new key or the
                credentials fail validation.
        """
        # Resolve every referenced model configuration up front, so later
        # steps never depend on a leftover loop variable.
        model_configs = []
        for model_config_id in api_key_data.model_config_ids or []:
            model_config = ModelConfigRepository.get_by_id(db, model_config_id)
            if not model_config:
                raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
            model_configs.append(model_config)

        existing_key = ModelApiKeyService._find_existing_key(
            db,
            api_key=api_key_data.api_key,
            provider=api_key_data.provider,
            model_name=api_key_data.model_name
        )

        if existing_key:
            if existing_key.is_active:
                raise BusinessException("该API Key已存在", BizCode.DUPLICATE_NAME)
            # Reactivate and link all requested model configurations, not
            # just the last one.
            ModelApiKeyService._reactivate_existing_key(
                existing_key,
                api_base=api_key_data.api_base,
                description=api_key_data.description,
                config=api_key_data.config,
                priority=api_key_data.priority,
                model_name=api_key_data.model_name,
                model_configs=model_configs
            )
            db.commit()
            db.refresh(existing_key)
            return existing_key

        if not model_configs:
            # Validation needs a model type, which comes from a model configuration.
            raise BusinessException("至少需要关联一个模型配置", BizCode.INVALID_PARAMETER)

        validation_result = await ModelConfigService.validate_model_config(
            db=db,
            model_name=api_key_data.model_name,
            provider=api_key_data.provider,
            api_key=api_key_data.api_key,
            api_base=api_key_data.api_base,
            # The type of the last listed configuration drives the validation.
            model_type=model_configs[-1].type,
            test_message="Hello"
        )
        if not validation_result["valid"]:
            raise BusinessException(
                f"模型配置验证失败: {validation_result['error']}",
                BizCode.INVALID_PARAMETER
            )

        api_key = ModelApiKeyRepository.create(db, api_key_data)
        db.commit()
        db.refresh(api_key)
        return api_key

    @staticmethod
    async def update_api_key(db: Session, api_key_id: uuid.UUID, api_key_data: ModelApiKeyUpdate) -> ModelApiKey:
        """Update an API key after re-validating the resulting credentials.

        Fields missing from ``api_key_data`` fall back to the stored values
        for the validation request. Validation is skipped when the key is not
        linked to any model configuration.

        Raises:
            BusinessException: ``BizCode.NOT_FOUND`` when the key does not
                exist; ``BizCode.INVALID_PARAMETER`` when validation fails.
        """
        existing_api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
        if not existing_api_key:
            raise BusinessException("API Key不存在", BizCode.NOT_FOUND)

        if existing_api_key.model_configs:
            # The first linked configuration supplies the model type.
            model_config = existing_api_key.model_configs[0]

            validation_result = await ModelConfigService.validate_model_config(
                db=db,
                model_name=api_key_data.model_name or existing_api_key.model_name,
                provider=api_key_data.provider or existing_api_key.provider,
                api_key=api_key_data.api_key or existing_api_key.api_key,
                api_base=api_key_data.api_base or existing_api_key.api_base,
                model_type=model_config.type,
                test_message="Hello"
            )
            if not validation_result["valid"]:
                raise BusinessException(
                    f"模型配置验证失败: {validation_result['error']}",
                    BizCode.INVALID_PARAMETER
                )

        api_key = ModelApiKeyRepository.update(db, api_key_id, api_key_data)
        db.commit()
        db.refresh(api_key)
        return api_key

    @staticmethod
    def delete_api_key(db: Session, api_key_id: uuid.UUID) -> bool:
        """Delete an API key and return the repository's result.

        Raises:
            BusinessException: ``BizCode.NOT_FOUND`` when the key does not exist.
        """
        if not ModelApiKeyRepository.get_by_id(db, api_key_id):
            raise BusinessException("API Key不存在", BizCode.NOT_FOUND)

        success = ModelApiKeyRepository.delete(db, api_key_id)
        db.commit()
        return success

    @staticmethod
    def get_available_api_key(db: Session, model_config_id: uuid.UUID) -> Optional[ModelApiKey]:
        """Return the active key with the lowest usage count, or ``None``.

        Only ``usage_count`` is compared here; a missing or empty count is
        treated as 0.
        """
        api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active=True)
        if not api_keys:
            return None
        return min(api_keys, key=lambda x: int(x.usage_count or "0"))

    @staticmethod
    def record_api_key_usage(db: Session, api_key_id: uuid.UUID) -> bool:
        """Record one use of the key; commit only when the repository reports success."""
        success = ModelApiKeyRepository.update_usage(db, api_key_id)
        if success:
            db.commit()
        return success

    @staticmethod
    def get_a_api_key(db: Session, model_config_id: uuid.UUID) -> ModelApiKey:
        """Return the first active API key of a model configuration.

        Raises:
            BusinessException: ``BizCode.MODEL_NOT_FOUND`` when the model
                configuration does not exist;
                ``BizCode.AGENT_CONFIG_MISSING`` when it has no active key.
        """
        api_keys = ModelApiKeyService.get_api_keys_by_model(db, model_config_id)
        if api_keys:
            return api_keys[0]
        raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||
|
||
|
||
|
||
class ModelBaseService:
    """Service layer for base models and for adding them to a tenant."""

    @staticmethod
    def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List:
        """Return base models grouped by provider.

        Each group is ``{"provider": ..., "models": [...]}``. When ``tenant_id``
        is given, every model dict additionally carries ``is_added``.
        """
        grouped = {}
        for base in ModelBaseRepository.get_list(db, query):
            entry = model_schema.ModelBase.model_validate(base).model_dump()
            if tenant_id:
                entry['is_added'] = ModelBaseRepository.check_added_by_tenant(db, base.id, tenant_id)

            provider_name = base.provider
            # Create the provider's bucket on first sight, then append to it.
            bucket = grouped.setdefault(provider_name, {"provider": provider_name, "models": []})
            bucket["models"].append(entry)

        return list(grouped.values())

    @staticmethod
    def get_model_base_by_id(db: Session, model_base_id: uuid.UUID):
        """Return the base model with the given id.

        Raises:
            BusinessException: ``BizCode.MODEL_NOT_FOUND`` when it does not exist.
        """
        found = ModelBaseRepository.get_by_id(db, model_base_id)
        if found:
            return found
        raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)

    @staticmethod
    def create_model_base(db: Session, data: model_schema.ModelBaseCreate):
        """Create a base model unless one with the same name and provider exists.

        Raises:
            BusinessException: ``BizCode.DUPLICATE_NAME`` for a duplicate.
        """
        if ModelBaseRepository.get_by_name_and_provider(db, data.name, data.provider):
            raise BusinessException("模型已存在", BizCode.DUPLICATE_NAME)

        created = ModelBaseRepository.create(db, data.model_dump())
        db.commit()
        db.refresh(created)
        return created

    @staticmethod
    def update_model_base(db: Session, model_base_id: uuid.UUID, data: model_schema.ModelBaseUpdate):
        """Apply the explicitly set fields of ``data`` to a base model.

        Raises:
            BusinessException: ``BizCode.MODEL_NOT_FOUND`` when the repository
                returns nothing for the id.
        """
        changes = data.model_dump(exclude_unset=True)
        updated = ModelBaseRepository.update(db, model_base_id, changes)
        if not updated:
            raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)

        db.commit()
        db.refresh(updated)
        return updated

    @staticmethod
    def delete_model_base(db: Session, model_base_id: uuid.UUID) -> bool:
        """Delete a base model and return the repository's result.

        Raises:
            BusinessException: ``BizCode.MODEL_NOT_FOUND`` when the repository
                reports that nothing was deleted.
        """
        deleted = ModelBaseRepository.delete(db, model_base_id)
        if not deleted:
            raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)

        db.commit()
        return deleted

    @staticmethod
    def add_model_from_plaza(db: Session, model_base_id: uuid.UUID, tenant_id: uuid.UUID) -> ModelConfig:
        """Create a tenant model configuration copied from a base model.

        Raises:
            BusinessException: ``BizCode.MODEL_NOT_FOUND`` when the base model
                does not exist; ``BizCode.DUPLICATE_NAME`` when the tenant has
                already added it.
        """
        source = ModelBaseRepository.get_by_id(db, model_base_id)
        if not source:
            raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)

        already_added = ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id)
        if already_added:
            raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME)

        # Copy the descriptive fields from the base model into the new configuration.
        new_config = ModelConfigRepository.create(db, dict(
            model_id=model_base_id,
            tenant_id=tenant_id,
            name=source.name,
            provider=source.provider,
            type=source.type,
            logo=source.logo,
            description=source.description,
            is_composite=False,
        ))
        ModelBaseRepository.increment_add_count(db, model_base_id)
        db.commit()
        db.refresh(new_config)
        return new_config
|