[modify] model list Types support separation by comma (,)

This commit is contained in:
Mark
2025-12-18 14:08:40 +08:00
parent 2a7199f593
commit d229733dee
5 changed files with 1389 additions and 1364 deletions

View File

@@ -1,13 +1,9 @@
from fastapi import APIRouter, Depends, status, Query
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from sqlalchemy.orm import Session
from typing import List, Optional
from typing import Optional
import uuid
from app.core.models import RedBearLLM
from app.core.models.base import RedBearModelConfig
from app.db import get_db
from app.dependencies import get_current_user
from app.models.models_model import ModelProvider, ModelType
@@ -39,7 +35,7 @@ def get_model_providers():
@router.get("", response_model=ApiResponse)
def get_model_list(
type: Optional[List[model_schema.ModelType]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM&type=EMBEDDING"),
type: Optional[str] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING"),
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
@@ -54,13 +50,21 @@ def get_model_list(
支持多个 type 参数:
- 单个:?type=LLM
- 多个:?type=LLM&type=EMBEDDING
- 多个(逗号分隔)?type=LLM,EMBEDDING
- 多个(重复参数):?type=LLM&type=EMBEDDING
"""
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
try:
# 解析 type 参数(支持逗号分隔)
type_list = None
if type:
type_values = [t.strip() for t in type.split(',')]
type_list = [model_schema.ModelType(t.lower()) for t in type_values if t]
api_logger.error(f"获取模型type_list: {type_list}")
query = model_schema.ModelConfigQuery(
type=type,
type=type_list,
provider=provider,
is_active=is_active,
is_public=is_public,

View File

@@ -3,9 +3,9 @@ from sqlalchemy import and_, or_, func, desc
from typing import List, Optional, Dict, Any, Tuple
import uuid
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelProvider
from app.models.models_model import ModelConfig, ModelApiKey, ModelType
from app.schemas.model_schema import (
ModelConfigCreate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
ModelConfigQuery
)
from app.core.logging_config import get_db_logger
@@ -32,7 +32,7 @@ class ModelConfigRepository:
query = query.filter(
or_(
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_public == True
ModelConfig.is_public
)
)
@@ -60,7 +60,7 @@ class ModelConfigRepository:
query = query.filter(
or_(
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_public == True
ModelConfig.is_public
)
)
@@ -92,7 +92,7 @@ class ModelConfigRepository:
query = query.filter(
or_(
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_public == True
ModelConfig.is_public
)
)
@@ -117,13 +117,21 @@ class ModelConfigRepository:
filters.append(
or_(
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_public == True
ModelConfig.is_public
)
)
# 支持多个 type 值(使用 IN 查询)
# 兼容 chat 和 llm 类型:如果查询包含其中一个,则同时匹配两者
if query.type:
filters.append(ModelConfig.type.in_(query.type))
type_values = list(query.type)
# 如果包含 chat 或 llm则同时包含两者
if ModelType.CHAT in type_values or ModelType.LLM in type_values:
if ModelType.CHAT not in type_values:
type_values.append(ModelType.CHAT)
if ModelType.LLM not in type_values:
type_values.append(ModelType.LLM)
filters.append(ModelConfig.type.in_(type_values))
if query.is_active is not None:
filters.append(ModelConfig.is_active == query.is_active)
@@ -183,12 +191,12 @@ class ModelConfigRepository:
query = query.filter(
or_(
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_public == True
ModelConfig.is_public
)
)
if is_active:
query = query.filter(ModelConfig.is_active == True)
query = query.filter(ModelConfig.is_active)
models = query.order_by(ModelConfig.name).all()
db_logger.debug(f"根据类型查询模型配置成功: 数量={len(models)}")
@@ -285,7 +293,7 @@ class ModelConfigRepository:
try:
# 总数统计
total_models = db.query(ModelConfig).count()
active_models = db.query(ModelConfig).filter(ModelConfig.is_active == True).count()
active_models = db.query(ModelConfig).filter(ModelConfig.is_active).count()
# 按类型统计
llm_count = db.query(ModelConfig).filter(ModelConfig.type == ModelType.LLM).count()
@@ -344,7 +352,7 @@ class ModelApiKeyRepository:
query = db.query(ModelApiKey).filter(ModelApiKey.model_config_id == model_config_id)
if is_active:
query = query.filter(ModelApiKey.is_active == True)
query = query.filter(ModelApiKey.is_active)
api_keys = query.order_by(ModelApiKey.priority, ModelApiKey.created_at).all()
db_logger.debug(f"API Key列表查询成功: 数量={len(api_keys)}")