Files
MemoryBear/api/app/schemas/model_schema.py

162 lines
6.7 KiB
Python

from pydantic import BaseModel, Field, field_serializer, ConfigDict
from typing import Optional, List, Dict, Any
import datetime
import uuid
from app.models.models_model import ModelProvider, ModelType
# ModelConfig Schemas
class ModelConfigBase(BaseModel):
"""模型配置基础Schema"""
name: str = Field(..., description="模型显示名称", max_length=255)
type: ModelType = Field(..., description="模型类型")
description: Optional[str] = Field(None, description="模型描述")
config: Optional[Dict[str, Any]] = Field({}, description="模型配置参数")
is_active: bool = Field(True, description="是否激活")
is_public: bool = Field(False, description="是否公开")
class ApiKeyCreateNested(BaseModel):
"""用于在创建模型时内嵌创建API Key的Schema"""
model_name: str = Field(..., description="模型实际名称", max_length=255)
provider: ModelProvider = Field(..., description="API Key提供商")
api_key: str = Field(..., description="API密钥", max_length=500)
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
priority: str = Field("1", description="优先级", max_length=10)
class ModelConfigCreate(ModelConfigBase):
"""创建模型配置Schema"""
api_keys: Optional[ApiKeyCreateNested] = Field(None, description="同时创建的API Key配置")
skip_validation: Optional[bool] = Field(False, description="是否跳过配置验证")
class ModelConfigUpdate(BaseModel):
"""更新模型配置Schema"""
name: Optional[str] = Field(None, description="模型显示名称", max_length=255)
type: Optional[ModelType] = Field(None, description="模型类型")
description: Optional[str] = Field(None, description="模型描述")
config: Optional[Dict[str, Any]] = Field(None, description="模型配置参数")
is_active: Optional[bool] = Field(None, description="是否激活")
is_public: Optional[bool] = Field(None, description="是否公开")
class ModelConfig(ModelConfigBase):
"""模型配置Schema"""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
created_at: datetime.datetime
updated_at: datetime.datetime
api_keys: List["ModelApiKey"] = []
# ModelApiKey Schemas
class ModelApiKeyBase(BaseModel):
"""API Key基础Schema"""
model_name: str = Field(..., description="模型实际名称", max_length=255)
provider: ModelProvider = Field(..., description="API Key提供商")
api_key: str = Field(..., description="API密钥", max_length=500)
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
config: Optional[Dict[str, Any]] = Field(None, description="API Key特定配置")
is_active: bool = Field(True, description="是否激活")
priority: str = Field("1", description="优先级", max_length=10)
class ModelApiKeyCreate(ModelApiKeyBase):
"""创建API Key Schema"""
model_config_id: uuid.UUID = Field(..., description="模型配置ID")
class ModelApiKeyUpdate(BaseModel):
"""更新API Key Schema"""
model_name: Optional[str] = Field(None, description="模型实际名称", max_length=255)
provider: Optional[ModelProvider] = Field(None, description="API Key提供商")
api_key: Optional[str] = Field(None, description="API密钥", max_length=500)
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
config: Optional[Dict[str, Any]] = Field(None, description="API Key特定配置")
is_active: Optional[bool] = Field(None, description="是否激活")
priority: Optional[str] = Field(None, description="优先级", max_length=10)
class ModelApiKey(ModelApiKeyBase):
"""API Key Schema"""
id: uuid.UUID
model_config_id: uuid.UUID
usage_count: str
last_used_at: Optional[datetime.datetime]
created_at: datetime.datetime
updated_at: datetime.datetime
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("updated_at", when_used="json")
def _serialize_updated_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
model_config = ConfigDict(from_attributes=True)
@field_serializer("last_used_at", when_used="json")
def _serialize_last_used_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
# 查询和响应Schemas
class ModelConfigQuery(BaseModel):
"""模型配置查询Schema"""
type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)")
provider: Optional[ModelProvider] = Field(None, description="提供商筛选(通过API Key)")
is_active: Optional[bool] = Field(None, description="激活状态筛选")
is_public: Optional[bool] = Field(None, description="公开状态筛选")
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
page: int = Field(1, description="页码", ge=1)
pagesize: int = Field(10, description="每页数量", ge=1, le=100)
class ModelMarketplace(BaseModel):
"""模型广场响应Schema"""
llm_models: List[ModelConfig] = []
embedding_models: List[ModelConfig] = []
rerank_models: List[ModelConfig] = []
total_count: int
active_count: int
# 统计信息Schema
class ModelStats(BaseModel):
"""模型统计信息Schema"""
total_models: int
active_models: int
llm_count: int
embedding_count: int
rerank_count: int
provider_stats: Dict[str, int]
# 验证模型配置Schema
class ModelValidateRequest(BaseModel):
"""验证模型配置请求"""
model_name: str = Field(..., description="模型实际名称")
provider: ModelProvider = Field(..., description="API Key提供商")
api_key: str = Field(..., description="API密钥")
api_base: Optional[str] = Field(None, description="API基础URL")
model_type: Optional[ModelType] = Field(ModelType.LLM, description="模型类型")
test_message: Optional[str] = Field("Hello", description="测试消息")
class ModelValidateResponse(BaseModel):
"""验证模型配置响应"""
valid: bool = Field(..., description="是否有效")
message: str = Field(..., description="验证消息")
response: Optional[str] = Field(None, description="模型响应内容")
elapsed_time: Optional[float] = Field(None, description="响应时间(秒)")
error: Optional[str] = Field(None, description="错误信息")
usage: Optional[Dict[str, Any]] = Field(None, description="Token使用情况")
# 更新前向引用
ModelConfig.model_rebuild()