[modify] model list Types support separation by comma (,)
This commit is contained in:
@@ -1,13 +1,9 @@
|
||||
from fastapi import APIRouter, Depends, status, Query
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
|
||||
from app.core.models import RedBearLLM
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.models_model import ModelProvider, ModelType
|
||||
@@ -39,7 +35,7 @@ def get_model_providers():
|
||||
|
||||
@router.get("", response_model=ApiResponse)
|
||||
def get_model_list(
|
||||
type: Optional[List[model_schema.ModelType]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM&type=EMBEDDING)"),
|
||||
type: Optional[str] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||
@@ -54,13 +50,21 @@ def get_model_list(
|
||||
|
||||
支持多个 type 参数:
|
||||
- 单个:?type=LLM
|
||||
- 多个:?type=LLM&type=EMBEDDING
|
||||
- 多个(逗号分隔):?type=LLM,EMBEDDING
|
||||
- 多个(重复参数):?type=LLM&type=EMBEDDING
|
||||
"""
|
||||
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
# 解析 type 参数(支持逗号分隔)
|
||||
type_list = None
|
||||
if type:
|
||||
type_values = [t.strip() for t in type.split(',')]
|
||||
type_list = [model_schema.ModelType(t.lower()) for t in type_values if t]
|
||||
|
||||
api_logger.error(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQuery(
|
||||
type=type,
|
||||
type=type_list,
|
||||
provider=provider,
|
||||
is_active=is_active,
|
||||
is_public=is_public,
|
||||
|
||||
@@ -3,9 +3,9 @@ from sqlalchemy import and_, or_, func, desc
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
import uuid
|
||||
|
||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelProvider
|
||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType
|
||||
from app.schemas.model_schema import (
|
||||
ModelConfigCreate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
||||
ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
||||
ModelConfigQuery
|
||||
)
|
||||
from app.core.logging_config import get_db_logger
|
||||
@@ -32,7 +32,7 @@ class ModelConfigRepository:
|
||||
query = query.filter(
|
||||
or_(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_public == True
|
||||
ModelConfig.is_public
|
||||
)
|
||||
)
|
||||
|
||||
@@ -60,7 +60,7 @@ class ModelConfigRepository:
|
||||
query = query.filter(
|
||||
or_(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_public == True
|
||||
ModelConfig.is_public
|
||||
)
|
||||
)
|
||||
|
||||
@@ -92,7 +92,7 @@ class ModelConfigRepository:
|
||||
query = query.filter(
|
||||
or_(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_public == True
|
||||
ModelConfig.is_public
|
||||
)
|
||||
)
|
||||
|
||||
@@ -117,13 +117,21 @@ class ModelConfigRepository:
|
||||
filters.append(
|
||||
or_(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_public == True
|
||||
ModelConfig.is_public
|
||||
)
|
||||
)
|
||||
|
||||
# 支持多个 type 值(使用 IN 查询)
|
||||
# 兼容 chat 和 llm 类型:如果查询包含其中一个,则同时匹配两者
|
||||
if query.type:
|
||||
filters.append(ModelConfig.type.in_(query.type))
|
||||
type_values = list(query.type)
|
||||
# 如果包含 chat 或 llm,则同时包含两者
|
||||
if ModelType.CHAT in type_values or ModelType.LLM in type_values:
|
||||
if ModelType.CHAT not in type_values:
|
||||
type_values.append(ModelType.CHAT)
|
||||
if ModelType.LLM not in type_values:
|
||||
type_values.append(ModelType.LLM)
|
||||
filters.append(ModelConfig.type.in_(type_values))
|
||||
|
||||
if query.is_active is not None:
|
||||
filters.append(ModelConfig.is_active == query.is_active)
|
||||
@@ -183,12 +191,12 @@ class ModelConfigRepository:
|
||||
query = query.filter(
|
||||
or_(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_public == True
|
||||
ModelConfig.is_public
|
||||
)
|
||||
)
|
||||
|
||||
if is_active:
|
||||
query = query.filter(ModelConfig.is_active == True)
|
||||
query = query.filter(ModelConfig.is_active)
|
||||
|
||||
models = query.order_by(ModelConfig.name).all()
|
||||
db_logger.debug(f"根据类型查询模型配置成功: 数量={len(models)}")
|
||||
@@ -285,7 +293,7 @@ class ModelConfigRepository:
|
||||
try:
|
||||
# 总数统计
|
||||
total_models = db.query(ModelConfig).count()
|
||||
active_models = db.query(ModelConfig).filter(ModelConfig.is_active == True).count()
|
||||
active_models = db.query(ModelConfig).filter(ModelConfig.is_active).count()
|
||||
|
||||
# 按类型统计
|
||||
llm_count = db.query(ModelConfig).filter(ModelConfig.type == ModelType.LLM).count()
|
||||
@@ -344,7 +352,7 @@ class ModelApiKeyRepository:
|
||||
query = db.query(ModelApiKey).filter(ModelApiKey.model_config_id == model_config_id)
|
||||
|
||||
if is_active:
|
||||
query = query.filter(ModelApiKey.is_active == True)
|
||||
query = query.filter(ModelApiKey.is_active)
|
||||
|
||||
api_keys = query.order_by(ModelApiKey.priority, ModelApiKey.created_at).all()
|
||||
db_logger.debug(f"API Key列表查询成功: 数量={len(api_keys)}")
|
||||
|
||||
@@ -126,6 +126,7 @@ dependencies = [
|
||||
"pytest-asyncio>=1.3.0",
|
||||
"uvicorn>=0.34.0",
|
||||
"celery>=5.5.2",
|
||||
"simpleeval>=1.0.3",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
@@ -121,3 +121,4 @@ fastmcp>=2.13.1
|
||||
pytest-asyncio>=1.3.0
|
||||
uvicorn>=0.34.0
|
||||
celery>=5.5.2
|
||||
simpleeval>=1.0.3
|
||||
|
||||
2701
api/uv.lock
generated
2701
api/uv.lock
generated
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user