feat(model and app statistic): 1. Optimize the model list; 2. Increase the model combination; 3. Add a model square; 4. Add application management statistics

This commit is contained in:
Timebomb2018
2026-01-28 10:15:51 +08:00
parent bf3e30dac0
commit 2862db3534
14 changed files with 1458 additions and 233 deletions

View File

@@ -1,12 +1,12 @@
from sqlalchemy.orm import Session, joinedload
from sqlalchemy import and_, or_, func, desc
from sqlalchemy.orm import Session, joinedload, selectinload
from sqlalchemy import and_, or_, func, desc, select
from typing import List, Optional, Dict, Any, Tuple
import uuid
from app.models.models_model import ModelConfig, ModelApiKey, ModelType
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelBase, model_config_api_key_association
from app.schemas.model_schema import (
ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
ModelConfigQuery
ModelConfigQuery, ModelConfigQueryNew
)
from app.core.logging_config import get_db_logger
@@ -107,6 +107,80 @@ class ModelConfigRepository:
def get_list(db: Session, query: ModelConfigQuery, tenant_id: uuid.UUID | None = None) -> Tuple[List[ModelConfig], int]:
"""获取模型配置列表"""
db_logger.debug(f"查询模型配置列表: {query.dict()}, tenant_id={tenant_id}")
try:
# 构建查询条件
filters = []
# 添加租户过滤(查询本租户的模型或公开模型)
if tenant_id:
filters.append(
or_(
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_public
)
)
# 支持多个 type 值(使用 IN 查询)
# 兼容 chat 和 llm 类型:如果查询包含其中一个,则同时匹配两者
if query.type:
type_values = list(query.type)
# 如果包含 chat 或 llm则同时包含两者
if ModelType.CHAT in type_values or ModelType.LLM in type_values:
if ModelType.CHAT not in type_values:
type_values.append(ModelType.CHAT)
if ModelType.LLM not in type_values:
type_values.append(ModelType.LLM)
filters.append(ModelConfig.type.in_(type_values))
if query.is_active is not None:
filters.append(ModelConfig.is_active == query.is_active)
if query.is_public is not None:
filters.append(ModelConfig.is_public == query.is_public)
if query.search:
# 搜索逻辑需要join ModelApiKey表来搜索model_name
search_filter = or_(
ModelConfig.name.ilike(f"%{query.search}%"),
# ModelConfig.description.ilike(f"%{query.search}%")
)
filters.append(search_filter)
# 构建基础查询
base_query = db.query(ModelConfig).options(
joinedload(ModelConfig.api_keys)
)
# 如果需要按provider筛选需要join ModelApiKey表
if query.provider:
base_query = base_query.join(ModelApiKey).filter(
ModelApiKey.provider == query.provider
).distinct()
if filters:
base_query = base_query.filter(and_(*filters))
# 获取总数
total = base_query.count()
# 分页查询
models = base_query.order_by(desc(ModelConfig.updated_at)).offset(
(query.page - 1) * query.pagesize
).limit(query.pagesize).all()
db_logger.debug(f"模型配置列表查询成功: 总数={total}, 当前页={len(models)}, type筛选={query.type}")
return models, total
except Exception as e:
db_logger.error(f"查询模型配置列表失败: {str(e)}")
raise
@staticmethod
def get_list_new(db: Session, query: ModelConfigQueryNew, tenant_id: uuid.UUID | None = None) -> tuple[
dict[str, list[ModelConfig]], Any]:
"""获取模型配置列表"""
db_logger.debug(f"查询模型配置列表: {query.model_dump()}, tenant_id={tenant_id}")
try:
# 构建查询条件
@@ -138,13 +212,15 @@ class ModelConfigRepository:
if query.is_public is not None:
filters.append(ModelConfig.is_public == query.is_public)
if query.is_composite is not None:
filters.append(ModelConfig.is_composite == query.is_composite)
if query.provider:
filters.append(ModelConfig.provider == query.provider)
if query.search:
# 搜索逻辑需要join ModelApiKey表来搜索model_name
search_filter = or_(
ModelConfig.name.ilike(f"%{query.search}%"),
# ModelConfig.description.ilike(f"%{query.search}%")
)
search_filter = ModelConfig.name.ilike(f"%{query.search}%")
filters.append(search_filter)
# 构建基础查询
@@ -152,28 +228,30 @@ class ModelConfigRepository:
joinedload(ModelConfig.api_keys)
)
# 如果需要按provider筛选需要join ModelApiKey表
if query.provider:
base_query = base_query.join(ModelApiKey).filter(
ModelApiKey.provider == query.provider
).distinct()
if filters:
base_query = base_query.filter(and_(*filters))
# 获取总数
total = base_query.count()
query_results = base_query.order_by(desc(ModelConfig.updated_at)).all()
provider_groups: Dict[str, List[ModelConfig]] = {}
for model_config in query_results:
provider = model_config.provider
if provider not in provider_groups:
provider_groups[provider] = []
provider_groups[provider].append(model_config)
# 分页查询
models = base_query.order_by(desc(ModelConfig.updated_at)).offset(
(query.page - 1) * query.pagesize
).limit(query.pagesize).all()
db_logger.debug(f"模型配置列表查询成功: 总数={total}, 当前页={len(models)}, type筛选={query.type}")
return models, total
db_logger.debug(
f"模型配置列表查询成功: 总数={total}, "
f"分组数={len(provider_groups)}, "
f"各分组模型数={[len(v) for v in provider_groups.values()]}, "
f"type筛选={query.type}")
return provider_groups, total
except Exception as e:
db_logger.error(f"查询模型配置列表失败: {str(e)}")
db_logger.error(f"查询模型配置列表失败(按provider分组/无分页): {str(e)}")
raise
@staticmethod
@@ -241,7 +319,7 @@ class ModelConfigRepository:
return None
# 更新字段
update_data = model_data.dict(exclude_unset=True)
update_data = model_data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(db_model, field, value)
@@ -303,8 +381,18 @@ class ModelConfigRepository:
# 按提供商统计 - 现在从ModelApiKey表获取
provider_stats = {}
provider_results = db.query(
ModelApiKey.provider, func.count(func.distinct(ModelApiKey.model_config_id))
).group_by(ModelApiKey.provider).all()
# 保留 provider 字段
ModelApiKey.provider,
# 统计中间表中 唯一的 model_config_id 数量(替换原 ModelApiKey.model_config_id
func.count(func.distinct(model_config_api_key_association.c.model_config_id))
).join(
# 联表ModelApiKey <-> 中间表(多对多关联)
model_config_api_key_association,
ModelApiKey.id == model_config_api_key_association.c.api_key_id
).group_by(
# 按 provider 分组(保留原有逻辑)
ModelApiKey.provider
).all()
for provider, count in provider_results:
provider_stats[provider.value] = count
@@ -325,6 +413,37 @@ class ModelConfigRepository:
db_logger.error(f"获取模型统计信息失败: {str(e)}")
raise
@staticmethod
def get_model_config_ids_by_provider(
db: Session,
tenant_id: uuid.UUID,
provider: Any
) -> List[uuid.UUID]:
"""根据tenant_id和provider获取model_config_id列表"""
db_logger.debug(f"查询model_config_id列表: tenant_id={tenant_id}, provider={provider}")
try:
# 查询ModelConfig关联的ModelApiKey筛选出匹配的model_config_id
model_config_ids = db.query(ModelConfig.id).join(
ModelBase, ModelConfig.model_id == ModelBase.id
).filter(
and_(
or_(
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_public
),
ModelBase.provider == provider,
~ModelConfig.is_composite
)
).distinct().all()
db_logger.debug(f"查询成功: 数量={len(model_config_ids)}")
return [row[0] for row in model_config_ids]
except Exception as e:
db_logger.error(f"查询model_config_id列表失败: {str(e)}")
raise
class ModelApiKeyRepository:
"""模型API Key Repository"""
@@ -349,7 +468,14 @@ class ModelApiKeyRepository:
db_logger.debug(f"根据模型配置ID查询API Key: model_config_id={model_config_id}")
try:
query = db.query(ModelApiKey).filter(ModelApiKey.model_config_id == model_config_id)
from app.models.models_model import ModelConfig, model_config_api_key_association
query = db.query(ModelApiKey).join(
model_config_api_key_association,
ModelApiKey.id == model_config_api_key_association.c.api_key_id
).filter(
model_config_api_key_association.c.model_config_id == model_config_id
)
if is_active:
query = query.filter(ModelApiKey.is_active)
@@ -368,8 +494,20 @@ class ModelApiKeyRepository:
db_logger.debug(f"创建API Key: {api_key_data.provider}")
try:
db_api_key = ModelApiKey(**api_key_data.dict())
from app.models.models_model import ModelConfig
# 创建API Key不包含model_config_ids
api_key_dict = api_key_data.model_dump(exclude={"model_config_ids"})
db_api_key = ModelApiKey(**api_key_dict)
db.add(db_api_key)
db.flush() # 获取生成的ID
# 关联ModelConfig
if api_key_data.model_config_ids:
for model_config_id in api_key_data.model_config_ids:
model_config = db.query(ModelConfig).filter(ModelConfig.id == model_config_id).first()
if model_config:
db_api_key.model_configs.append(model_config)
db_logger.info(f"API Key已添加到会话: {db_api_key.provider}")
return db_api_key
@@ -391,7 +529,7 @@ class ModelApiKeyRepository:
return None
# 更新字段
update_data = api_key_data.dict(exclude_unset=True)
update_data = api_key_data.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(db_api_key, field, value)
@@ -451,4 +589,74 @@ class ModelApiKeyRepository:
except Exception as e:
db.rollback()
db_logger.error(f"更新API Key使用统计失败: api_key_id={api_key_id} - {str(e)}")
raise
raise
class ModelBaseRepository:
"""基础模型Repository"""
@staticmethod
def get_by_id(db: Session, model_base_id: uuid.UUID) -> Optional['ModelBase']:
return db.query(ModelBase).filter(ModelBase.id == model_base_id).first()
@staticmethod
def get_list(db: Session, query: 'ModelBaseQuery') -> List['ModelBase']:
filters = []
if query.type:
filters.append(ModelBase.type == query.type)
if query.provider:
filters.append(ModelBase.provider == query.provider)
if query.is_official is not None:
filters.append(ModelBase.is_official == query.is_official)
if query.is_deprecated is not None:
filters.append(ModelBase.is_deprecated == query.is_deprecated)
if query.search:
filters.append(or_(
ModelBase.name.ilike(f"%{query.search}%"),
# ModelBase.description.ilike(f"%{query.search}%")
))
q = db.query(ModelBase)
if filters:
q = q.filter(and_(*filters))
return q.order_by(ModelBase.add_count.desc()).all()
@staticmethod
def create(db: Session, data: dict) -> 'ModelBase':
model_base = ModelBase(**data)
db.add(model_base)
return model_base
@staticmethod
def update(db: Session, model_base_id: uuid.UUID, data: dict) -> Optional['ModelBase']:
model_base = db.query(ModelBase).filter(ModelBase.id == model_base_id).first()
if not model_base:
return None
for key, value in data.items():
setattr(model_base, key, value)
return model_base
@staticmethod
def delete(db: Session, model_base_id: uuid.UUID) -> bool:
model_base = db.query(ModelBase).filter(ModelBase.id == model_base_id).first()
if not model_base:
return False
db.delete(model_base)
return True
@staticmethod
def increment_add_count(db: Session, model_base_id: uuid.UUID) -> bool:
model_base = db.query(ModelBase).filter(ModelBase.id == model_base_id).first()
if not model_base:
return False
model_base.add_count += 1
return True
@staticmethod
def check_added_by_tenant(db: Session, model_base_id: uuid.UUID, tenant_id: uuid.UUID) -> bool:
return db.query(ModelConfig).filter(
ModelConfig.model_id == model_base_id,
ModelConfig.tenant_id == tenant_id
).first() is not None