fix(model):

1. when adding a model API key to the model list, a tenant_id uniqueness check needs to be added;
2.the Model Square has cancelled custom models;
3. optimization of the interface logic for customizing model configurations in the model list
This commit is contained in:
Timebomb2018
2026-02-09 10:02:34 +08:00
parent 4d98bace87
commit ebad5e00a3
4 changed files with 43 additions and 22 deletions

View File

@@ -328,7 +328,7 @@ async def update_composite_model(
try:
if model_data.type is not None:
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
raise BusinessException("不允许更改模型类型", BizCode.INVALID_PARAMETER)
result_orm = await ModelConfigService.update_composite_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id)
api_logger.info(f"组合模型更新成功: {result_orm.name} (ID: {model_id})")
@@ -368,6 +368,9 @@ def update_model(
更新模型配置
"""
api_logger.info(f"更新模型配置请求: model_id={model_id}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
if model_data.type is not None or model_data.provider is not None:
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
try:
api_logger.debug(f"开始更新模型配置: model_id={model_id}")

View File

@@ -48,13 +48,17 @@ class ModelConfigRepository:
raise
@staticmethod
def get_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None) -> Optional[ModelConfig]:
"""根据名称获取模型配置"""
db_logger.debug(f"根据名称查询模型配置: name={name}, tenant_id={tenant_id}")
def get_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> Optional[ModelConfig]:
"""根据名称和供应商获取模型配置"""
db_logger.debug(f"根据名称查询模型配置: name={name}, provider={provider}, tenant_id={tenant_id}")
try:
query = db.query(ModelConfig).filter(ModelConfig.name == name)
# 添加供应商过滤
if provider:
query = query.filter(ModelConfig.provider == provider)
# 添加租户过滤
if tenant_id:
query = query.filter(
@@ -69,7 +73,7 @@ class ModelConfigRepository:
db_logger.debug(f"模型配置查询成功: {model.name}")
return model
except Exception as e:
db_logger.error(f"根据名称查询模型配置失败: name={name} - {str(e)}")
db_logger.error(f"根据名称查询模型配置失败: name={name}, provider={provider} - {str(e)}")
raise
@staticmethod

View File

@@ -25,9 +25,9 @@ class ModelConfigBase(BaseModel):
class ApiKeyCreateNested(BaseModel):
"""用于在创建模型时内嵌创建API Key的Schema"""
model_name: str = Field(..., description="模型实际名称", max_length=255)
model_name: Optional[str] = Field(None, description="模型实际名称", max_length=255)
description: Optional[str] = Field(None, description="备注")
provider: ModelProvider = Field(..., description="API Key提供商")
provider: Optional[str] = Field(None, description="API Key提供商")
api_key: str = Field(..., description="API密钥", max_length=500)
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
@@ -57,6 +57,8 @@ class ModelConfigUpdate(BaseModel):
"""更新模型配置Schema"""
name: Optional[str] = Field(None, description="模型显示名称", max_length=255)
type: Optional[ModelType] = Field(None, description="模型类型")
provider: Optional[str] = Field(None, description="供应商")
logo: Optional[str] = Field(None, description="模型logo图片URL", max_length=255)
description: Optional[str] = Field(None, description="模型描述")
config: Optional[Dict[str, Any]] = Field(None, description="模型配置参数")
is_active: Optional[bool] = Field(None, description="是否激活")

View File

@@ -6,7 +6,7 @@ import math
import time
import asyncio
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, LoadBalanceStrategy
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, LoadBalanceStrategy, ModelProvider
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository, ModelBaseRepository
from app.schemas import model_schema
from app.schemas.model_schema import (
@@ -69,9 +69,9 @@ class ModelConfigService:
return items
@staticmethod
def get_model_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None) -> ModelConfig:
def get_model_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> ModelConfig:
"""根据名称获取模型配置"""
model = ModelConfigRepository.get_by_name(db, name, tenant_id=tenant_id)
model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id)
if not model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
return model
@@ -244,7 +244,7 @@ class ModelConfigService:
async def create_model(db: Session, model_data: ModelConfigCreate, tenant_id: uuid.UUID) -> ModelConfig:
"""创建模型配置"""
# 检查名称是否已存在(同租户内)
if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
if ModelConfigRepository.get_by_name(db, model_data.name, provider=model_data.provider, tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
# 验证配置
@@ -253,8 +253,8 @@ class ModelConfigService:
for api_key_data in api_key_data_list:
validation_result = await ModelConfigService.validate_model_config(
db=db,
model_name=api_key_data.model_name,
provider=api_key_data.provider,
model_name=model_data.name,
provider=model_data.provider,
api_key=api_key_data.api_key,
api_base=api_key_data.api_base,
model_type=model_data.type, # 传递模型类型
@@ -277,6 +277,8 @@ class ModelConfigService:
if api_key_datas:
for api_key_data in api_key_datas:
api_key_data.model_name = model_data.name
api_key_data.provider = model_data.provider
api_key_create_schema = ModelApiKeyCreate(
model_config_ids=[model.id],
**api_key_data.model_dump()
@@ -295,7 +297,7 @@ class ModelConfigService:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
if model_data.name and model_data.name != existing_model.name:
if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
@@ -306,7 +308,7 @@ class ModelConfigService:
@staticmethod
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
"""创建组合模型"""
if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
# 验证所有 API Key 存在且类型匹配
@@ -341,7 +343,7 @@ class ModelConfigService:
"type": model_data.type,
"logo": model_data.logo,
"description": model_data.description,
"provider": "composite",
"provider": ModelProvider.COMPOSITE,
"config": model_data.config,
"is_active": model_data.is_active,
"is_public": model_data.is_public,
@@ -369,6 +371,10 @@ class ModelConfigService:
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
if not existing_model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
if model_data.name and model_data.name != existing_model.name:
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
if not existing_model.is_composite:
raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER)
@@ -471,11 +477,14 @@ class ModelApiKeyService:
# 从ModelBase获取model_name
model_name = model_config.model_base.name if model_config.model_base else model_config.name
# 检查是否存在API Key包括软删除
existing_key = db.query(ModelApiKey).filter(
# 检查是否存在API Key包括软删除需要考虑tenant_id
existing_key = db.query(ModelApiKey).join(
ModelApiKey.model_configs
).filter(
ModelApiKey.api_key == data.api_key,
ModelApiKey.provider == data.provider,
ModelApiKey.model_name == model_name
ModelApiKey.model_name == model_name,
ModelConfig.tenant_id == model_config.tenant_id
).first()
if existing_key:
@@ -542,11 +551,14 @@ class ModelApiKeyService:
if not model_config:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
# 检查API Key是否已存在(包括软删除)
existing_key = db.query(ModelApiKey).filter(
# 检查API Key是否已存在(包括软删除)需要考虑tenant_id
existing_key = db.query(ModelApiKey).join(
ModelApiKey.model_configs
).filter(
ModelApiKey.api_key == api_key_data.api_key,
ModelApiKey.provider == api_key_data.provider,
ModelApiKey.model_name == api_key_data.model_name
ModelApiKey.model_name == api_key_data.model_name,
ModelConfig.tenant_id == model_config.tenant_id
).first()
if existing_key: