diff --git a/api/app/schemas/model_schema.py b/api/app/schemas/model_schema.py index f25d9408..ea4183a5 100644 --- a/api/app/schemas/model_schema.py +++ b/api/app/schemas/model_schema.py @@ -116,8 +116,8 @@ class ModelApiKeyBase(BaseModel): provider: ModelProvider = Field(..., description="API Key提供商") api_key: str = Field(..., description="API密钥", max_length=500) api_base: Optional[str] = Field(None, description="API基础URL", max_length=500) - capability: List[str] = Field(default_factory=list, description="模型能力列表") - is_omni: bool = Field(False, description="是否为Omni模型") + capability: Optional[List[str]] = Field(None, description="模型能力列表") + is_omni: Optional[bool] = Field(None, description="是否为Omni模型") config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置") is_active: bool = Field(True, description="是否激活") priority: str = Field("1", description="优先级", max_length=10) diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index 2337427a..cba25f32 100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -116,27 +116,15 @@ class ModelConfigService: try: start_time = time.time() - # dashscope 的 omni 模型需要使用 compatible-mode - if provider.lower() == ModelProvider.DASHSCOPE and is_omni: - if not api_base: - api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1" - model_config = RedBearModelConfig( - model_name=model_name, - provider=ModelProvider.OPENAI, - api_key=api_key, - base_url=api_base, - temperature=0.7, - max_tokens=100 - ) - else: - model_config = RedBearModelConfig( - model_name=model_name, - provider=provider, - api_key=api_key, - base_url=api_base, - temperature=0.7, - max_tokens=100 - ) + model_config = RedBearModelConfig( + model_name=model_name, + provider=provider, + api_key=api_key, + base_url=api_base, + is_omni=is_omni, + temperature=0.7, + max_tokens=100 + ) # 根据模型类型选择不同的验证方式 model_type_lower = model_type.lower() @@ -492,6 +480,9 @@ class ModelApiKeyService: model_config = ModelConfigRepository.get_by_id(db, model_config_id) if not model_config: continue + + data.is_omni = model_config.is_omni + data.capability = model_config.capability # 从ModelBase获取model_name model_name = model_config.model_base.name if model_config.model_base else model_config.name @@ -550,8 +541,8 @@ class ModelApiKeyService: provider=data.provider, api_key=data.api_key, api_base=data.api_base, - capability=data.capability if data.capability is not None else model_config.capability, - is_omni=data.is_omni if data.is_omni is not None else model_config.is_omni, + capability=data.capability, + is_omni=data.is_omni, config=data.config, is_active=data.is_active, priority=data.priority @@ -574,6 +565,10 @@ class ModelApiKeyService: model_config = ModelConfigRepository.get_by_id(db, model_config_id) if not model_config: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) + if api_key_data.is_omni is None: + api_key_data.is_omni = model_config.is_omni + if api_key_data.capability is None: + api_key_data.capability = model_config.capability # 检查API Key是否已存在(包括软删除),需要考虑tenant_id existing_key = db.query(ModelApiKey).join( @@ -616,7 +611,7 @@ class ModelApiKeyService: api_base=api_key_data.api_base, model_type=model_config.type, test_message="Hello", - is_omni=model_config.is_omni + is_omni=api_key_data.is_omni ) if not validation_result["valid"]: raise BusinessException(