[modify] model list Types support separation by comma (,)
This commit is contained in:
@@ -1,13 +1,9 @@
|
|||||||
from fastapi import APIRouter, Depends, status, Query
|
from fastapi import APIRouter, Depends, status, Query
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.prompts import ChatPromptTemplate
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
from app.core.models import RedBearLLM
|
|
||||||
from app.core.models.base import RedBearModelConfig
|
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.models.models_model import ModelProvider, ModelType
|
from app.models.models_model import ModelProvider, ModelType
|
||||||
@@ -39,7 +35,7 @@ def get_model_providers():
|
|||||||
|
|
||||||
@router.get("", response_model=ApiResponse)
|
@router.get("", response_model=ApiResponse)
|
||||||
def get_model_list(
|
def get_model_list(
|
||||||
type: Optional[List[model_schema.ModelType]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM&type=EMBEDDING)"),
|
type: Optional[str] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||||
@@ -54,13 +50,21 @@ def get_model_list(
|
|||||||
|
|
||||||
支持多个 type 参数:
|
支持多个 type 参数:
|
||||||
- 单个:?type=LLM
|
- 单个:?type=LLM
|
||||||
- 多个:?type=LLM&type=EMBEDDING
|
- 多个(逗号分隔):?type=LLM,EMBEDDING
|
||||||
|
- 多个(重复参数):?type=LLM&type=EMBEDDING
|
||||||
"""
|
"""
|
||||||
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
|
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 解析 type 参数(支持逗号分隔)
|
||||||
|
type_list = None
|
||||||
|
if type:
|
||||||
|
type_values = [t.strip() for t in type.split(',')]
|
||||||
|
type_list = [model_schema.ModelType(t.lower()) for t in type_values if t]
|
||||||
|
|
||||||
|
api_logger.error(f"获取模型type_list: {type_list}")
|
||||||
query = model_schema.ModelConfigQuery(
|
query = model_schema.ModelConfigQuery(
|
||||||
type=type,
|
type=type_list,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
is_active=is_active,
|
is_active=is_active,
|
||||||
is_public=is_public,
|
is_public=is_public,
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ from sqlalchemy import and_, or_, func, desc
|
|||||||
from typing import List, Optional, Dict, Any, Tuple
|
from typing import List, Optional, Dict, Any, Tuple
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelProvider
|
from app.models.models_model import ModelConfig, ModelApiKey, ModelType
|
||||||
from app.schemas.model_schema import (
|
from app.schemas.model_schema import (
|
||||||
ModelConfigCreate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
||||||
ModelConfigQuery
|
ModelConfigQuery
|
||||||
)
|
)
|
||||||
from app.core.logging_config import get_db_logger
|
from app.core.logging_config import get_db_logger
|
||||||
@@ -32,7 +32,7 @@ class ModelConfigRepository:
|
|||||||
query = query.filter(
|
query = query.filter(
|
||||||
or_(
|
or_(
|
||||||
ModelConfig.tenant_id == tenant_id,
|
ModelConfig.tenant_id == tenant_id,
|
||||||
ModelConfig.is_public == True
|
ModelConfig.is_public
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -60,7 +60,7 @@ class ModelConfigRepository:
|
|||||||
query = query.filter(
|
query = query.filter(
|
||||||
or_(
|
or_(
|
||||||
ModelConfig.tenant_id == tenant_id,
|
ModelConfig.tenant_id == tenant_id,
|
||||||
ModelConfig.is_public == True
|
ModelConfig.is_public
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -92,7 +92,7 @@ class ModelConfigRepository:
|
|||||||
query = query.filter(
|
query = query.filter(
|
||||||
or_(
|
or_(
|
||||||
ModelConfig.tenant_id == tenant_id,
|
ModelConfig.tenant_id == tenant_id,
|
||||||
ModelConfig.is_public == True
|
ModelConfig.is_public
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -117,13 +117,21 @@ class ModelConfigRepository:
|
|||||||
filters.append(
|
filters.append(
|
||||||
or_(
|
or_(
|
||||||
ModelConfig.tenant_id == tenant_id,
|
ModelConfig.tenant_id == tenant_id,
|
||||||
ModelConfig.is_public == True
|
ModelConfig.is_public
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 支持多个 type 值(使用 IN 查询)
|
# 支持多个 type 值(使用 IN 查询)
|
||||||
|
# 兼容 chat 和 llm 类型:如果查询包含其中一个,则同时匹配两者
|
||||||
if query.type:
|
if query.type:
|
||||||
filters.append(ModelConfig.type.in_(query.type))
|
type_values = list(query.type)
|
||||||
|
# 如果包含 chat 或 llm,则同时包含两者
|
||||||
|
if ModelType.CHAT in type_values or ModelType.LLM in type_values:
|
||||||
|
if ModelType.CHAT not in type_values:
|
||||||
|
type_values.append(ModelType.CHAT)
|
||||||
|
if ModelType.LLM not in type_values:
|
||||||
|
type_values.append(ModelType.LLM)
|
||||||
|
filters.append(ModelConfig.type.in_(type_values))
|
||||||
|
|
||||||
if query.is_active is not None:
|
if query.is_active is not None:
|
||||||
filters.append(ModelConfig.is_active == query.is_active)
|
filters.append(ModelConfig.is_active == query.is_active)
|
||||||
@@ -183,12 +191,12 @@ class ModelConfigRepository:
|
|||||||
query = query.filter(
|
query = query.filter(
|
||||||
or_(
|
or_(
|
||||||
ModelConfig.tenant_id == tenant_id,
|
ModelConfig.tenant_id == tenant_id,
|
||||||
ModelConfig.is_public == True
|
ModelConfig.is_public
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_active:
|
if is_active:
|
||||||
query = query.filter(ModelConfig.is_active == True)
|
query = query.filter(ModelConfig.is_active)
|
||||||
|
|
||||||
models = query.order_by(ModelConfig.name).all()
|
models = query.order_by(ModelConfig.name).all()
|
||||||
db_logger.debug(f"根据类型查询模型配置成功: 数量={len(models)}")
|
db_logger.debug(f"根据类型查询模型配置成功: 数量={len(models)}")
|
||||||
@@ -285,7 +293,7 @@ class ModelConfigRepository:
|
|||||||
try:
|
try:
|
||||||
# 总数统计
|
# 总数统计
|
||||||
total_models = db.query(ModelConfig).count()
|
total_models = db.query(ModelConfig).count()
|
||||||
active_models = db.query(ModelConfig).filter(ModelConfig.is_active == True).count()
|
active_models = db.query(ModelConfig).filter(ModelConfig.is_active).count()
|
||||||
|
|
||||||
# 按类型统计
|
# 按类型统计
|
||||||
llm_count = db.query(ModelConfig).filter(ModelConfig.type == ModelType.LLM).count()
|
llm_count = db.query(ModelConfig).filter(ModelConfig.type == ModelType.LLM).count()
|
||||||
@@ -344,7 +352,7 @@ class ModelApiKeyRepository:
|
|||||||
query = db.query(ModelApiKey).filter(ModelApiKey.model_config_id == model_config_id)
|
query = db.query(ModelApiKey).filter(ModelApiKey.model_config_id == model_config_id)
|
||||||
|
|
||||||
if is_active:
|
if is_active:
|
||||||
query = query.filter(ModelApiKey.is_active == True)
|
query = query.filter(ModelApiKey.is_active)
|
||||||
|
|
||||||
api_keys = query.order_by(ModelApiKey.priority, ModelApiKey.created_at).all()
|
api_keys = query.order_by(ModelApiKey.priority, ModelApiKey.created_at).all()
|
||||||
db_logger.debug(f"API Key列表查询成功: 数量={len(api_keys)}")
|
db_logger.debug(f"API Key列表查询成功: 数量={len(api_keys)}")
|
||||||
|
|||||||
@@ -126,6 +126,7 @@ dependencies = [
|
|||||||
"pytest-asyncio>=1.3.0",
|
"pytest-asyncio>=1.3.0",
|
||||||
"uvicorn>=0.34.0",
|
"uvicorn>=0.34.0",
|
||||||
"celery>=5.5.2",
|
"celery>=5.5.2",
|
||||||
|
"simpleeval>=1.0.3",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
|
|||||||
@@ -121,3 +121,4 @@ fastmcp>=2.13.1
|
|||||||
pytest-asyncio>=1.3.0
|
pytest-asyncio>=1.3.0
|
||||||
uvicorn>=0.34.0
|
uvicorn>=0.34.0
|
||||||
celery>=5.5.2
|
celery>=5.5.2
|
||||||
|
simpleeval>=1.0.3
|
||||||
|
|||||||
2701
api/uv.lock
generated
2701
api/uv.lock
generated
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user