- Adjust multi-modal memory write behavior for text and visual data
- Mask API keys in the model list response to prevent exposure
- Add capability-based filtering to the model list API
356 lines
17 KiB
Python
356 lines
17 KiB
Python
from pydantic import BaseModel, Field, field_serializer, field_validator, ConfigDict
|
||
from typing import Optional, List, Dict, Any
|
||
import datetime
|
||
import uuid
|
||
|
||
from app.models.models_model import ModelProvider, ModelType, LoadBalanceStrategy
|
||
from app.core.logging_config import get_business_logger
|
||
|
||
# Business logger shared by the schemas in this module; used to report
# best-effort failures (e.g. when extracting model_config_ids in ModelApiKey).
schema_logger = get_business_logger()
|
||
|
||
|
||
# ModelConfig Schemas
|
||
class ModelConfigBase(BaseModel):
    """Base schema holding the fields shared by model-configuration schemas."""

    name: str = Field(..., description="模型显示名称", max_length=255)
    type: ModelType = Field(..., description="模型类型")
    logo: str | None = Field(default=None, description="模型logo图片URL", max_length=255)
    description: str | None = Field(default=None, description="模型描述")
    provider: str = Field(..., description="供应商")
    config: dict[str, Any] | None = Field(default={}, description="模型配置参数")
    is_active: bool = Field(default=True, description="是否激活")
    is_public: bool = Field(default=False, description="是否公开")
    # Held as the enum's raw value (a plain string), not the enum member.
    load_balance_strategy: str | None = Field(
        default=LoadBalanceStrategy.NONE.value,
        description="负载均衡策略",
    )
    capability: list[str] = Field(default_factory=list, description="模型能力列表")
    is_omni: bool = Field(default=False, description="是否为Omni模型")
    model_id: uuid.UUID | None = Field(default=None, description="基础模型ID")
|
||
|
||
|
||
class ApiKeyCreateNested(BaseModel):
    """API key payload embedded in a model-creation request."""

    model_name: str | None = Field(default=None, description="模型实际名称", max_length=255)
    description: str | None = Field(default=None, description="备注")
    provider: str | None = Field(default=None, description="API Key提供商")
    api_key: str = Field(..., description="API密钥", max_length=500)
    api_base: str | None = Field(default=None, description="API基础URL", max_length=500)
    capability: list[str] | None = Field(default=None, description="模型能力列表")
    is_omni: bool | None = Field(default=None, description="是否为Omni模型")
    config: dict[str, Any] | None = Field(default={}, description="API Key特定配置")
    # Priority is carried as a short string rather than an integer.
    priority: str = Field(default="1", description="优先级", max_length=10)
|
||
|
||
|
||
class ModelConfigCreate(ModelConfigBase):
    """Request body for creating a model configuration, optionally with its API keys."""

    api_keys: list[ApiKeyCreateNested] | None = Field(
        default=None,
        description="同时创建的API Key配置",
    )
    skip_validation: bool | None = Field(default=False, description="是否跳过配置验证")
|
||
|
||
|
||
class CompositeModelCreate(BaseModel):
    """Request body for creating a composite model bound to existing API keys."""

    name: str = Field(..., description="组合模型名称", max_length=255)
    type: ModelType | None = Field(default=None, description="模型类型")
    logo: str | None = Field(default=None, description="模型logo图片URL", max_length=255)
    description: str | None = Field(default=None, description="模型描述")
    config: dict[str, Any] | None = Field(default={}, description="模型配置参数")
    is_active: bool = Field(default=True, description="是否激活")
    is_public: bool = Field(default=False, description="是否公开")
    api_key_ids: list[uuid.UUID] = Field(..., description="绑定的API Key ID列表")
    load_balance_strategy: str | None = Field(
        default=LoadBalanceStrategy.NONE.value,
        description="负载均衡策略",
    )
|
||
|
||
|
||
class ModelConfigUpdate(BaseModel):
    """Partial update for a model configuration; every field is optional."""

    name: str | None = Field(default=None, description="模型显示名称", max_length=255)
    type: ModelType | None = Field(default=None, description="模型类型")
    provider: str | None = Field(default=None, description="供应商")
    logo: str | None = Field(default=None, description="模型logo图片URL", max_length=255)
    description: str | None = Field(default=None, description="模型描述")
    config: dict[str, Any] | None = Field(default=None, description="模型配置参数")
    is_active: bool | None = Field(default=None, description="是否激活")
    is_public: bool | None = Field(default=None, description="是否公开")
    capability: list[str] | None = Field(default=None, description="模型能力列表")
    is_omni: bool | None = Field(default=None, description="是否为Omni模型")
|
||
|
||
|
||
class ModelConfig(ModelConfigBase):
    """Model configuration as returned by the API, including its bound API keys.

    In JSON mode, timestamps are emitted as epoch milliseconds and every bound
    API key has its secret masked.
    """

    model_config = ConfigDict(from_attributes=True)

    id: uuid.UUID
    created_at: datetime.datetime
    updated_at: datetime.datetime
    api_keys: List["ModelApiKey"] = []

    @staticmethod
    def mask_api_key(key: str | None, prefix: int = 4, suffix: int = 4) -> str:
        """Mask all but the first ``prefix`` and last ``suffix`` characters of ``key``.

        Args:
            key: The secret to mask. ``None`` or ``""`` yields ``""``.
            prefix: Number of leading characters left visible (negative -> 0).
            suffix: Number of trailing characters left visible (negative -> 0).

        Returns:
            The masked key. A key too short to keep both ends visible is
            replaced entirely by ``*`` characters.
        """
        if not key:
            return ""
        prefix = max(prefix, 0)
        suffix = max(suffix, 0)
        length = len(key)
        if length <= prefix + suffix:
            return "*" * length
        # Slice the tail by absolute index: key[-0:] is the whole string and
        # would leak the full key when suffix == 0.
        return key[:prefix] + "*" * (length - prefix - suffix) + key[length - suffix:]

    @field_validator("api_keys", mode="after")
    @classmethod
    def filter_active_api_keys(cls, api_keys: List["ModelApiKey"]) -> List["ModelApiKey"]:
        """Keep only API keys whose ``is_active`` flag is truthy."""
        return [key for key in api_keys if key.is_active]

    @field_serializer("created_at", when_used="json")
    def _serialize_created_at(self, dt: datetime.datetime | None):
        """Serialize ``created_at`` as epoch milliseconds (``None`` stays ``None``)."""
        return int(dt.timestamp() * 1000) if dt else None

    @field_serializer("api_keys", when_used="json")
    def _serialize_api_keys(self, api_keys: List["ModelApiKey"]):
        """Serialize the bound API keys with their secret masked.

        Each key is dumped in JSON mode so its own JSON serializers (epoch-ms
        timestamps) still apply. A python-mode dump would hand raw datetimes
        back to the outer encoder, which would emit ISO strings instead.
        """
        masked = []
        for api_key in api_keys:
            data = api_key.model_dump(mode="json")
            data["api_key"] = self.mask_api_key(api_key.api_key)
            masked.append(data)
        return masked

    @field_serializer("updated_at", when_used="json")
    def _serialize_updated_at(self, dt: datetime.datetime | None):
        """Serialize ``updated_at`` as epoch milliseconds (``None`` stays ``None``)."""
        return int(dt.timestamp() * 1000) if dt else None
|
||
|
||
|
||
# ModelApiKey Schemas
|
||
class ModelApiKeyCreateByProvider(BaseModel):
    """Request body for creating an API key by provider, optionally linked to model configs."""

    provider: ModelProvider = Field(..., description="API Key提供商")
    api_key: str = Field(..., description="API密钥", max_length=500)
    api_base: str | None = Field(default=None, description="API基础URL", max_length=500)
    description: str | None = Field(default=None, description="备注")
    capability: list[str] | None = Field(default=None, description="模型能力列表")
    is_omni: bool | None = Field(default=None, description="是否为Omni模型")
    config: dict[str, Any] | None = Field(default={}, description="API Key特定配置")
    is_active: bool = Field(default=True, description="是否激活")
    priority: str = Field(default="1", description="优先级", max_length=10)
    model_config_ids: list[uuid.UUID] | None = Field(
        default=None,
        description="关联的模型配置ID列表",
    )
|
||
|
||
|
||
class ModelApiKeyBase(BaseModel):
    """Fields shared by the API key schemas."""

    model_name: str = Field(..., description="模型实际名称", max_length=255)
    description: str | None = Field(default=None, description="备注")
    provider: ModelProvider = Field(..., description="API Key提供商")
    api_key: str = Field(..., description="API密钥", max_length=500)
    api_base: str | None = Field(default=None, description="API基础URL", max_length=500)
    capability: list[str] | None = Field(default=None, description="模型能力列表")
    is_omni: bool | None = Field(default=None, description="是否为Omni模型")
    config: dict[str, Any] | None = Field(default={}, description="API Key特定配置")
    is_active: bool = Field(default=True, description="是否激活")
    # Priority is carried as a short string rather than an integer.
    priority: str = Field(default="1", description="优先级", max_length=10)
|
||
|
||
|
||
class ModelApiKeyCreate(ModelApiKeyBase):
    """Request body for creating an API key, optionally linked to model configurations."""

    model_config_ids: list[uuid.UUID] | None = Field(
        default=None,
        description="关联的模型配置ID列表",
    )
|
||
|
||
|
||
class ModelApiKeyUpdate(BaseModel):
    """Partial update for an API key; every field is optional."""

    model_name: str | None = Field(default=None, description="模型实际名称", max_length=255)
    provider: ModelProvider | None = Field(default=None, description="API Key提供商")
    api_key: str | None = Field(default=None, description="API密钥", max_length=500)
    api_base: str | None = Field(default=None, description="API基础URL", max_length=500)
    capability: list[str] | None = Field(default=None, description="模型能力列表")
    is_omni: bool | None = Field(default=None, description="是否为Omni模型")
    config: dict[str, Any] | None = Field(default=None, description="API Key特定配置")
    is_active: bool | None = Field(default=None, description="是否激活")
    priority: str | None = Field(default=None, description="优先级", max_length=10)
|
||
|
||
|
||
class ModelApiKey(ModelApiKeyBase):
    """API key as returned by the API, with usage metadata and linked model configs.

    ``model_configs`` receives the raw relation (ORM objects or dicts) and is
    excluded from dumps. ``model_config_ids`` is derived from it after
    instantiation unless the caller supplied IDs explicitly.
    """

    id: uuid.UUID
    usage_count: str
    last_used_at: Optional[datetime.datetime]
    created_at: datetime.datetime
    updated_at: datetime.datetime
    # Raw relation (ORM objects or dicts); exclude=True keeps it out of dumps.
    model_configs: Any = Field(default=None, exclude=True)
    model_config_ids: List[uuid.UUID] = Field(default_factory=list, description="关联的模型配置ID列表")

    def model_post_init(self, __context: Any) -> None:
        """Populate ``model_config_ids`` from ``model_configs`` after instantiation.

        Explicitly supplied IDs are kept. Otherwise each entry of
        ``model_configs`` (an ORM object or a dict) contributes its ``id``
        when it is not a composite model and its ``name`` equals
        ``self.model_name``. Extraction is best-effort: any error is logged
        and leaves ``model_config_ids`` empty.
        """
        # Manually supplied IDs are never overwritten.
        if self.model_config_ids:
            return

        configs = self.model_configs
        if configs is None:
            return

        def _is_own_config(mc: Any) -> bool:
            """Return True if ``mc`` is a non-composite config named like this key's model."""
            if isinstance(mc, dict):
                return (
                    "id" in mc
                    and not mc.get("is_composite", False)
                    and mc.get("name") == self.model_name
                )
            return (
                hasattr(mc, "id")
                and not getattr(mc, "is_composite", False)
                and getattr(mc, "name", None) == self.model_name
            )

        try:
            # A bare mapping or a non-iterable value is not a collection of configs.
            if isinstance(configs, dict) or not hasattr(configs, "__iter__"):
                return
            # Lists also define __iter__, so ORM relations and lists of dicts
            # share this single path; each entry's shape is handled per item.
            self.model_config_ids = [
                mc["id"] if isinstance(mc, dict) else mc.id
                for mc in configs
                if _is_own_config(mc)
            ]
        except Exception as e:
            schema_logger.warning(f"提取 model_config_ids 失败:{e}")
            self.model_config_ids = []

    model_config = ConfigDict(
        from_attributes=True,  # allow building from ORM objects
        arbitrary_types_allowed=True,  # allow arbitrary (ORM) object types
        populate_by_name=True,  # match fields by attribute name
        validate_assignment=True,  # assignments (e.g. derived IDs above) are validated
    )

    @field_serializer("created_at", when_used="json")
    def _serialize_created_at(self, dt: datetime.datetime | None):
        """Serialize ``created_at`` as epoch milliseconds (``None`` stays ``None``)."""
        return int(dt.timestamp() * 1000) if dt else None

    @field_serializer("updated_at", when_used="json")
    def _serialize_updated_at(self, dt: datetime.datetime | None):
        """Serialize ``updated_at`` as epoch milliseconds (``None`` stays ``None``)."""
        return int(dt.timestamp() * 1000) if dt else None

    @field_serializer("last_used_at", when_used="json")
    def _serialize_last_used_at(self, dt: datetime.datetime | None):
        """Serialize ``last_used_at`` as epoch milliseconds; a never-used key yields ``None``."""
        return int(dt.timestamp() * 1000) if dt else None
|
||
|
||
|
||
class ModelConfigQuery(BaseModel):
    """Paginated filter parameters for listing model configurations."""

    type: list[ModelType] | None = Field(default=None, description="模型类型筛选(支持多个)")
    provider: ModelProvider | None = Field(default=None, description="提供商筛选(通过API Key)")
    capability: list[str] | None = Field(default=None, description="能力筛选(支持多个)")
    is_active: bool | None = Field(default=None, description="激活状态筛选")
    is_public: bool | None = Field(default=None, description="公开状态筛选")
    search: str | None = Field(default=None, description="搜索关键词", max_length=255)
    # 1-based page number; page size is capped at 100.
    page: int = Field(default=1, description="页码", ge=1)
    pagesize: int = Field(default=10, description="每页数量", ge=1, le=100)
|
||
|
||
|
||
# 查询和响应Schemas
|
||
class ModelConfigQueryNew(BaseModel):
    """Filter parameters for model-configuration queries, including a composite-model filter."""

    type: list[ModelType] | None = Field(default=None, description="模型类型筛选(支持多个)")
    provider: ModelProvider | None = Field(default=None, description="提供商筛选(通过API Key)")
    is_active: bool | None = Field(default=None, description="激活状态筛选")
    is_public: bool | None = Field(default=None, description="公开状态筛选")
    is_composite: bool | None = Field(default=None, description="组合模型筛选")
    search: str | None = Field(default=None, description="搜索关键词", max_length=255)
|
||
|
||
|
||
class ModelMarketplace(BaseModel):
    """Model marketplace response: configurations grouped by model type, plus counts."""

    llm_models: list[ModelConfig] = []
    embedding_models: list[ModelConfig] = []
    rerank_models: list[ModelConfig] = []
    total_count: int
    active_count: int
|
||
|
||
|
||
# 统计信息Schema
|
||
class ModelStats(BaseModel):
    """Aggregate statistics over model configurations."""

    total_models: int
    active_models: int
    llm_count: int
    embedding_count: int
    rerank_count: int
    # NOTE(review): presumably provider name -> model count; confirm with the producer.
    provider_stats: dict[str, int]
|
||
|
||
|
||
# 验证模型配置Schema
|
||
class ModelValidateRequest(BaseModel):
    """Request to validate a model configuration by sending it a test message."""

    model_name: str = Field(..., description="模型实际名称")
    provider: ModelProvider = Field(..., description="API Key提供商")
    api_key: str = Field(..., description="API密钥")
    api_base: str | None = Field(default=None, description="API基础URL")
    model_type: ModelType | None = Field(default=ModelType.LLM, description="模型类型")
    test_message: str | None = Field(default="Hello", description="测试消息")
|
||
|
||
|
||
class ModelValidateResponse(BaseModel):
    """Outcome of a model-configuration validation request."""

    valid: bool = Field(..., description="是否有效")
    message: str = Field(..., description="验证消息")
    response: str | None = Field(default=None, description="模型响应内容")
    elapsed_time: float | None = Field(default=None, description="响应时间(秒)")
    error: str | None = Field(default=None, description="错误信息")
    usage: dict[str, Any] | None = Field(default=None, description="Token使用情况")
|
||
|
||
|
||
# 更新前向引用
|
||
# Resolve the "ModelApiKey" forward reference in ModelConfig.api_keys, which
# is declared before ModelApiKey is defined.
ModelConfig.model_rebuild()
|
||
|
||
|
||
# ModelBase Schemas
|
||
class ModelBaseCreate(BaseModel):
    """Request body for creating a base model."""

    name: str = Field(..., description="模型唯一标识", max_length=255)
    type: ModelType = Field(..., description="模型类型")
    provider: ModelProvider = Field(..., description="提供商")
    logo: str | None = Field(default=None, description="模型logo图片URL", max_length=255)
    description: str | None = Field(default=None, description="模型描述")
    is_official: bool = Field(default=True, description="是否供应商官方模型")
    tags: list[str] = Field(default_factory=list, description="模型标签")
    capability: list[str] = Field(
        default_factory=list,
        description="模型能力列表(如['vision', 'audio', 'video'])",
    )
    is_omni: bool = Field(default=False, description="是否为Omni模型")
|
||
|
||
|
||
class ModelBaseUpdate(BaseModel):
    """Partial update for a base model; every field is optional."""

    name: str | None = Field(default=None, description="模型唯一标识", max_length=255)
    type: ModelType | None = Field(default=None, description="模型类型")
    provider: ModelProvider | None = Field(default=None, description="提供商")
    logo: str | None = Field(default=None, description="模型logo图片URL", max_length=255)
    description: str | None = Field(default=None, description="模型描述")
    is_deprecated: bool | None = Field(default=None, description="是否弃用")
    is_official: bool | None = Field(default=None, description="是否供应商官方模型")
    tags: list[str] | None = Field(default=None, description="模型标签")
    capability: list[str] | None = Field(default=None, description="模型能力列表")
    is_omni: bool | None = Field(default=None, description="是否为Omni模型")
|
||
|
||
|
||
class ModelBase(BaseModel):
    """Read schema for a base model record; can be built from object attributes."""

    model_config = ConfigDict(from_attributes=True)

    id: uuid.UUID
    name: str
    type: str
    provider: str
    # Nullable but still required: no default is declared for these two.
    logo: str | None
    description: str | None
    is_deprecated: bool
    is_official: bool
    tags: list[str]
    add_count: int
    capability: list[str] = []
    is_omni: bool = False
|
||
|
||
|
||
class ModelBaseQuery(BaseModel):
    """Filter parameters for listing base models."""

    type: ModelType | None = Field(default=None, description="模型类型")
    provider: ModelProvider | None = Field(default=None, description="提供商")
    is_official: bool | None = Field(default=None, description="是否官方模型")
    is_deprecated: bool | None = Field(default=None, description="是否弃用")
    search: str | None = Field(default=None, description="搜索关键词", max_length=255)
|
||
|
||
|
||
class ModelInfo(BaseModel):
    """Name, provider, credentials and capability flags for a single model."""

    model_name: str = Field(..., description="模型名称")
    provider: str = Field(..., description="模型提供商")
    # NOTE(review): api_key is not masked by this schema; avoid returning it
    # in API responses — confirm how callers use it.
    api_key: str = Field(..., description="API密钥")
    api_base: str = Field(..., description="API基础URL")
    is_omni: bool = Field(default=False, description="是否为omni模型")
    model_type: ModelType = Field(..., description="模型类型")
    capability: list[str] = Field(default_factory=list, description="模型能力列表")
|