diff --git a/api/app/controllers/model_controller.py b/api/app/controllers/model_controller.py index 481c520e..509f7cad 100644 --- a/api/app/controllers/model_controller.py +++ b/api/app/controllers/model_controller.py @@ -36,7 +36,7 @@ def get_model_providers(): @router.get("", response_model=ApiResponse) def get_model_list( - type: Optional[str | list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"), + type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"), provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"), is_active: Optional[bool] = Query(None, description="激活状态筛选"), is_public: Optional[bool] = Query(None, description="公开状态筛选"), @@ -60,11 +60,14 @@ def get_model_list( try: # 解析 type 参数(支持逗号分隔) type_list = [] - if isinstance(type, str): - type_values = [t.strip() for t in type.split(',')] - type_list = [model_schema.ModelType(t.lower()) for t in type_values if t] - elif isinstance(type, list): - type_list = type + if type is not None: + flat_type = [] + for item in type: + split_items = [t.strip() for t in item.split(',') if t.strip()] + flat_type.extend(split_items) + + unique_flat_type = list(dict.fromkeys(flat_type)) + type_list = [ModelType(t.lower()) for t in unique_flat_type] api_logger.error(f"获取模型type_list: {type_list}") query = model_schema.ModelConfigQuery( @@ -89,7 +92,7 @@ def get_model_list( @router.get("/new", response_model=ApiResponse) def get_model_list( - type: Optional[str | list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"), + type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"), provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于ModelConfig)"), is_active: Optional[bool] = Query(None, description="激活状态筛选"), is_public: Optional[bool] = Query(None, description="公开状态筛选"), @@ -111,11 +114,14 @@ def get_model_list( try: # 解析 type 参数(支持逗号分隔) type_list = [] - if isinstance(type, str): - type_values = [t.strip() for t in type.split(',')] - type_list = [model_schema.ModelType(t.lower()) for t in type_values if t] - elif isinstance(type, list): - type_list = type + if type is not None: + flat_type = [] + for item in type: + split_items = [t.strip() for t in item.split(',') if t.strip()] + flat_type.extend(split_items) + + unique_flat_type = list(dict.fromkeys(flat_type)) + type_list = [ModelType(t.lower()) for t in unique_flat_type] api_logger.info(f"获取模型type_list: {type_list}") query = model_schema.ModelConfigQueryNew(