diff --git a/api/app/controllers/model_controller.py b/api/app/controllers/model_controller.py index 83753744..bb1ba526 100644 --- a/api/app/controllers/model_controller.py +++ b/api/app/controllers/model_controller.py @@ -328,7 +328,7 @@ async def update_composite_model( try: if model_data.type is not None: - raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER) + raise BusinessException("不允许更改模型类型", BizCode.INVALID_PARAMETER) result_orm = await ModelConfigService.update_composite_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id) api_logger.info(f"组合模型更新成功: {result_orm.name} (ID: {model_id})") @@ -368,6 +368,9 @@ def update_model( 更新模型配置 """ api_logger.info(f"更新模型配置请求: model_id={model_id}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}") + + if model_data.type is not None or model_data.provider is not None: + raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER) try: api_logger.debug(f"开始更新模型配置: model_id={model_id}") diff --git a/api/app/repositories/model_repository.py b/api/app/repositories/model_repository.py index f323b30c..2c513e82 100644 --- a/api/app/repositories/model_repository.py +++ b/api/app/repositories/model_repository.py @@ -48,13 +48,17 @@ class ModelConfigRepository: raise @staticmethod - def get_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None) -> Optional[ModelConfig]: - """根据名称获取模型配置""" - db_logger.debug(f"根据名称查询模型配置: name={name}, tenant_id={tenant_id}") + def get_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> Optional[ModelConfig]: + """根据名称和供应商获取模型配置""" + db_logger.debug(f"根据名称查询模型配置: name={name}, provider={provider}, tenant_id={tenant_id}") try: query = db.query(ModelConfig).filter(ModelConfig.name == name) + # 添加供应商过滤 + if provider: + query = query.filter(ModelConfig.provider == provider) + # 添加租户过滤 if tenant_id: query = query.filter( @@ -69,7 +73,7 @@ class ModelConfigRepository: db_logger.debug(f"模型配置查询成功: {model.name}") return model except Exception as e: - db_logger.error(f"根据名称查询模型配置失败: name={name} - {str(e)}") + db_logger.error(f"根据名称查询模型配置失败: name={name}, provider={provider} - {str(e)}") raise @staticmethod diff --git a/api/app/schemas/model_schema.py b/api/app/schemas/model_schema.py index a2d3650a..0c0bbeed 100644 --- a/api/app/schemas/model_schema.py +++ b/api/app/schemas/model_schema.py @@ -25,9 +25,9 @@ class ModelConfigBase(BaseModel): class ApiKeyCreateNested(BaseModel): """用于在创建模型时内嵌创建API Key的Schema""" - model_name: str = Field(..., description="模型实际名称", max_length=255) + model_name: Optional[str] = Field(None, description="模型实际名称", max_length=255) description: Optional[str] = Field(None, description="备注") - provider: ModelProvider = Field(..., description="API Key提供商") + provider: Optional[str] = Field(None, description="API Key提供商") api_key: str = Field(..., description="API密钥", max_length=500) api_base: Optional[str] = Field(None, description="API基础URL", max_length=500) config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置") @@ -57,6 +57,8 @@ class ModelConfigUpdate(BaseModel): """更新模型配置Schema""" name: Optional[str] = Field(None, description="模型显示名称", max_length=255) type: Optional[ModelType] = Field(None, description="模型类型") + provider: Optional[str] = Field(None, description="供应商") + logo: Optional[str] = Field(None, description="模型logo图片URL", max_length=255) description: Optional[str] = Field(None, description="模型描述") config: Optional[Dict[str, Any]] = Field(None, description="模型配置参数") is_active: Optional[bool] = Field(None, description="是否激活") diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index d382b1b1..aa8cfbac 100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -6,7 +6,7 @@ import math import time import asyncio -from app.models.models_model import ModelConfig, ModelApiKey, ModelType, LoadBalanceStrategy +from app.models.models_model import ModelConfig, ModelApiKey, ModelType, LoadBalanceStrategy, ModelProvider from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository, ModelBaseRepository from app.schemas import model_schema from app.schemas.model_schema import ( @@ -69,9 +69,9 @@ class ModelConfigService: return items @staticmethod - def get_model_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None) -> ModelConfig: + def get_model_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> ModelConfig: """根据名称获取模型配置""" - model = ModelConfigRepository.get_by_name(db, name, tenant_id=tenant_id) + model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id) if not model: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) return model @@ -244,7 +244,7 @@ class ModelConfigService: async def create_model(db: Session, model_data: ModelConfigCreate, tenant_id: uuid.UUID) -> ModelConfig: """创建模型配置""" # 检查名称是否已存在(同租户内) - if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id): + if ModelConfigRepository.get_by_name(db, model_data.name, provider=model_data.provider, tenant_id=tenant_id): raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) # 验证配置 @@ -253,8 +253,8 @@ class ModelConfigService: for api_key_data in api_key_data_list: validation_result = await ModelConfigService.validate_model_config( db=db, - model_name=api_key_data.model_name, - provider=api_key_data.provider, + model_name=model_data.name, + provider=model_data.provider, api_key=api_key_data.api_key, api_base=api_key_data.api_base, model_type=model_data.type, # 传递模型类型 @@ -277,6 +277,8 @@ class ModelConfigService: if api_key_datas: for api_key_data in api_key_datas: + api_key_data.model_name = model_data.name + api_key_data.provider = model_data.provider api_key_create_schema = ModelApiKeyCreate( model_config_ids=[model.id], **api_key_data.model_dump() @@ -295,7 +297,7 @@ class ModelConfigService: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) if model_data.name and model_data.name != existing_model.name: - if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id): + if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id): raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id) @@ -306,7 +308,7 @@ class ModelConfigService: @staticmethod async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig: """创建组合模型""" - if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id): + if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, tenant_id=tenant_id): raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) # 验证所有 API Key 存在且类型匹配 @@ -341,7 +343,7 @@ class ModelConfigService: "type": model_data.type, "logo": model_data.logo, "description": model_data.description, - "provider": "composite", + "provider": ModelProvider.COMPOSITE, "config": model_data.config, "is_active": model_data.is_active, "is_public": model_data.is_public, @@ -369,6 +371,10 @@ class ModelConfigService: existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id) if not existing_model: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) + + if model_data.name and model_data.name != existing_model.name: + if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id): + raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) if not existing_model.is_composite: raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER) @@ -471,11 +477,14 @@ class ModelApiKeyService: # 从ModelBase获取model_name model_name = model_config.model_base.name if model_config.model_base else model_config.name - # 检查是否存在API Key(包括软删除) - existing_key = db.query(ModelApiKey).filter( + # 检查是否存在API Key(包括软删除),需要考虑tenant_id + existing_key = db.query(ModelApiKey).join( + ModelApiKey.model_configs + ).filter( ModelApiKey.api_key == data.api_key, ModelApiKey.provider == data.provider, - ModelApiKey.model_name == model_name + ModelApiKey.model_name == model_name, + ModelConfig.tenant_id == model_config.tenant_id ).first() if existing_key: @@ -542,11 +551,14 @@ class ModelApiKeyService: if not model_config: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) - # 检查API Key是否已存在(包括软删除) - existing_key = db.query(ModelApiKey).filter( + # 检查API Key是否已存在(包括软删除),需要考虑tenant_id + existing_key = db.query(ModelApiKey).join( + ModelApiKey.model_configs + ).filter( ModelApiKey.api_key == api_key_data.api_key, ModelApiKey.provider == api_key_data.provider, - ModelApiKey.model_name == api_key_data.model_name + ModelApiKey.model_name == api_key_data.model_name, + ModelConfig.tenant_id == model_config.tenant_id ).first() if existing_key: