[modify] model list Types support separation by comma (,)

This commit is contained in:
Mark
2025-12-18 14:08:40 +08:00
parent 2a7199f593
commit d229733dee
5 changed files with 1389 additions and 1364 deletions

View File

@@ -1,13 +1,9 @@
from fastapi import APIRouter, Depends, status, Query from fastapi import APIRouter, Depends, status, Query
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import List, Optional from typing import Optional
import uuid import uuid
from app.core.models import RedBearLLM
from app.core.models.base import RedBearModelConfig
from app.db import get_db from app.db import get_db
from app.dependencies import get_current_user from app.dependencies import get_current_user
from app.models.models_model import ModelProvider, ModelType from app.models.models_model import ModelProvider, ModelType
@@ -39,7 +35,7 @@ def get_model_providers():
@router.get("", response_model=ApiResponse) @router.get("", response_model=ApiResponse)
def get_model_list( def get_model_list(
type: Optional[List[model_schema.ModelType]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM&type=EMBEDDING"), type: Optional[str] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING"),
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"), provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
is_active: Optional[bool] = Query(None, description="激活状态筛选"), is_active: Optional[bool] = Query(None, description="激活状态筛选"),
is_public: Optional[bool] = Query(None, description="公开状态筛选"), is_public: Optional[bool] = Query(None, description="公开状态筛选"),
@@ -54,13 +50,21 @@ def get_model_list(
支持多个 type 参数: 支持多个 type 参数:
- 单个:?type=LLM - 单个:?type=LLM
- 多个:?type=LLM&type=EMBEDDING - 多个(逗号分隔)?type=LLM,EMBEDDING
- 多个(重复参数):?type=LLM&type=EMBEDDING
""" """
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}") api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
try: try:
# 解析 type 参数(支持逗号分隔)
type_list = None
if type:
type_values = [t.strip() for t in type.split(',')]
type_list = [model_schema.ModelType(t.lower()) for t in type_values if t]
api_logger.error(f"获取模型type_list: {type_list}")
query = model_schema.ModelConfigQuery( query = model_schema.ModelConfigQuery(
type=type, type=type_list,
provider=provider, provider=provider,
is_active=is_active, is_active=is_active,
is_public=is_public, is_public=is_public,

View File

@@ -3,9 +3,9 @@ from sqlalchemy import and_, or_, func, desc
from typing import List, Optional, Dict, Any, Tuple from typing import List, Optional, Dict, Any, Tuple
import uuid import uuid
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelProvider from app.models.models_model import ModelConfig, ModelApiKey, ModelType
from app.schemas.model_schema import ( from app.schemas.model_schema import (
ModelConfigCreate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
ModelConfigQuery ModelConfigQuery
) )
from app.core.logging_config import get_db_logger from app.core.logging_config import get_db_logger
@@ -32,7 +32,7 @@ class ModelConfigRepository:
query = query.filter( query = query.filter(
or_( or_(
ModelConfig.tenant_id == tenant_id, ModelConfig.tenant_id == tenant_id,
ModelConfig.is_public == True ModelConfig.is_public
) )
) )
@@ -60,7 +60,7 @@ class ModelConfigRepository:
query = query.filter( query = query.filter(
or_( or_(
ModelConfig.tenant_id == tenant_id, ModelConfig.tenant_id == tenant_id,
ModelConfig.is_public == True ModelConfig.is_public
) )
) )
@@ -92,7 +92,7 @@ class ModelConfigRepository:
query = query.filter( query = query.filter(
or_( or_(
ModelConfig.tenant_id == tenant_id, ModelConfig.tenant_id == tenant_id,
ModelConfig.is_public == True ModelConfig.is_public
) )
) )
@@ -117,13 +117,21 @@ class ModelConfigRepository:
filters.append( filters.append(
or_( or_(
ModelConfig.tenant_id == tenant_id, ModelConfig.tenant_id == tenant_id,
ModelConfig.is_public == True ModelConfig.is_public
) )
) )
# 支持多个 type 值(使用 IN 查询) # 支持多个 type 值(使用 IN 查询)
# 兼容 chat 和 llm 类型:如果查询包含其中一个,则同时匹配两者
if query.type: if query.type:
filters.append(ModelConfig.type.in_(query.type)) type_values = list(query.type)
# 如果包含 chat 或 llm则同时包含两者
if ModelType.CHAT in type_values or ModelType.LLM in type_values:
if ModelType.CHAT not in type_values:
type_values.append(ModelType.CHAT)
if ModelType.LLM not in type_values:
type_values.append(ModelType.LLM)
filters.append(ModelConfig.type.in_(type_values))
if query.is_active is not None: if query.is_active is not None:
filters.append(ModelConfig.is_active == query.is_active) filters.append(ModelConfig.is_active == query.is_active)
@@ -183,12 +191,12 @@ class ModelConfigRepository:
query = query.filter( query = query.filter(
or_( or_(
ModelConfig.tenant_id == tenant_id, ModelConfig.tenant_id == tenant_id,
ModelConfig.is_public == True ModelConfig.is_public
) )
) )
if is_active: if is_active:
query = query.filter(ModelConfig.is_active == True) query = query.filter(ModelConfig.is_active)
models = query.order_by(ModelConfig.name).all() models = query.order_by(ModelConfig.name).all()
db_logger.debug(f"根据类型查询模型配置成功: 数量={len(models)}") db_logger.debug(f"根据类型查询模型配置成功: 数量={len(models)}")
@@ -285,7 +293,7 @@ class ModelConfigRepository:
try: try:
# 总数统计 # 总数统计
total_models = db.query(ModelConfig).count() total_models = db.query(ModelConfig).count()
active_models = db.query(ModelConfig).filter(ModelConfig.is_active == True).count() active_models = db.query(ModelConfig).filter(ModelConfig.is_active).count()
# 按类型统计 # 按类型统计
llm_count = db.query(ModelConfig).filter(ModelConfig.type == ModelType.LLM).count() llm_count = db.query(ModelConfig).filter(ModelConfig.type == ModelType.LLM).count()
@@ -344,7 +352,7 @@ class ModelApiKeyRepository:
query = db.query(ModelApiKey).filter(ModelApiKey.model_config_id == model_config_id) query = db.query(ModelApiKey).filter(ModelApiKey.model_config_id == model_config_id)
if is_active: if is_active:
query = query.filter(ModelApiKey.is_active == True) query = query.filter(ModelApiKey.is_active)
api_keys = query.order_by(ModelApiKey.priority, ModelApiKey.created_at).all() api_keys = query.order_by(ModelApiKey.priority, ModelApiKey.created_at).all()
db_logger.debug(f"API Key列表查询成功: 数量={len(api_keys)}") db_logger.debug(f"API Key列表查询成功: 数量={len(api_keys)}")

View File

@@ -126,6 +126,7 @@ dependencies = [
"pytest-asyncio>=1.3.0", "pytest-asyncio>=1.3.0",
"uvicorn>=0.34.0", "uvicorn>=0.34.0",
"celery>=5.5.2", "celery>=5.5.2",
"simpleeval>=1.0.3",
] ]
[tool.pytest.ini_options] [tool.pytest.ini_options]

View File

@@ -121,3 +121,4 @@ fastmcp>=2.13.1
pytest-asyncio>=1.3.0 pytest-asyncio>=1.3.0
uvicorn>=0.34.0 uvicorn>=0.34.0
celery>=5.5.2 celery>=5.5.2
simpleeval>=1.0.3

2701
api/uv.lock generated

File diff suppressed because it is too large Load Diff