Files
MemoryBear/api/app/services/model_service.py
Timebomb2018 af6fde414f fix(agent and model):
1. From the model square to the model list, the added models are initially set to be inactive. When manually activating them, a mandatory API key configuration is required.
2. Copying of applications (agent, workflow, multi_agent)
2026-03-06 14:40:07 +08:00

791 lines
34 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from datetime import datetime
from sqlalchemy.orm import Session
from typing import List, Optional, Dict, Any
import uuid
import math
import time
import asyncio
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, LoadBalanceStrategy, ModelProvider
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository, ModelBaseRepository
from app.schemas import model_schema
from app.schemas.model_schema import (
ModelConfigCreate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
ModelConfigQuery, ModelStats, ModelConfigQueryNew
)
from app.core.logging_config import get_business_logger
from app.schemas.response_schema import PageData, PageMeta
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
logger = get_business_logger()
class ModelConfigService:
"""模型配置服务"""
@staticmethod
def get_model_by_id(db: Session, model_id: uuid.UUID, tenant_id: uuid.UUID | None = None) -> ModelConfig:
"""根据ID获取模型配置"""
model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
if not model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
return model
@staticmethod
def get_model_list(db: Session, query: ModelConfigQuery, tenant_id: uuid.UUID | None = None) -> PageData:
"""获取模型配置列表"""
models, total = ModelConfigRepository.get_list(db, query, tenant_id=tenant_id)
pages = math.ceil(total / query.pagesize) if total > 0 else 0
return PageData(
page=PageMeta(
page=query.page,
pagesize=query.pagesize,
total=total,
hasnext=query.page < pages
),
items=[model_schema.ModelConfig.model_validate(model) for model in models]
)
@staticmethod
def get_model_list_new(db: Session, query: ModelConfigQueryNew, tenant_id: uuid.UUID | None = None) -> List[dict]:
"""获取模型配置列表"""
provider_groups, total = ModelConfigRepository.get_list_new(db, query, tenant_id=tenant_id)
items = []
for provider, models in provider_groups.items():
# 验证每个模型并封装分组信息
validated_models = [model_schema.ModelConfig.model_validate(model) for model in models]
tags = list({model.type for model in validated_models})
group_item = {
"provider": provider, # 服务商名称
"logo": validated_models[0].logo,
"tags": tags,
"models": validated_models # 该服务商下的所有模型
}
items.append(group_item)
return items
@staticmethod
def get_model_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> ModelConfig:
"""根据名称获取模型配置"""
model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id)
if not model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
return model
@staticmethod
def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ModelConfig]:
"""按名称模糊匹配获取模型配置列表"""
return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit)
@staticmethod
async def validate_model_config(
db: Session,
*,
model_name: str,
provider: str,
api_key: str,
api_base: Optional[str] = None,
model_type: str = "llm",
test_message: str = "Hello",
is_omni: bool = False
) -> Dict[str, Any]:
"""验证模型配置是否有效
Args:
db: 数据库会话
model_name: 模型名称
provider: 提供商
api_key: API密钥
api_base: API基础URL
model_type: 模型类型 (llm/chat/embedding/rerank)
test_message: 测试消息
is_omni: 是否为Omni模型
Returns:
Dict: 验证结果
"""
from app.core.models import RedBearLLM, RedBearRerank
from app.core.models.base import RedBearModelConfig
from app.core.models.embedding import RedBearEmbeddings
import traceback
try:
start_time = time.time()
model_config = RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base,
is_omni=is_omni,
temperature=0.7,
max_tokens=100
)
# 根据模型类型选择不同的验证方式
model_type_lower = model_type.lower()
if model_type_lower in ["llm", "chat"]:
# LLM/Chat 模型验证 - 统一使用字符串输入
llm = RedBearLLM(model_config, type=ModelType.LLM if model_type_lower == "llm" else ModelType.CHAT)
response = await llm.ainvoke(test_message)
elapsed_time = time.time() - start_time
content = response.content if hasattr(response, 'content') else str(response)
usage = None
if hasattr(response, 'usage_metadata'):
usage = {
"input_tokens": getattr(response.usage_metadata, 'input_tokens', 0),
"output_tokens": getattr(response.usage_metadata, 'output_tokens', 0),
"total_tokens": getattr(response.usage_metadata, 'total_tokens', 0)
}
return {
"valid": True,
"message": f"{model_type.upper()} 模型配置验证成功",
"response": content,
"elapsed_time": elapsed_time,
"usage": usage,
"error": None
}
elif model_type_lower == "embedding":
# Embedding 模型验证(在线程中运行同步方法)
embedding = RedBearEmbeddings(model_config)
test_texts = [test_message, "测试文本"]
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
elapsed_time = time.time() - start_time
return {
"valid": True,
"message": "Embedding 模型配置验证成功",
"response": f"成功生成 {len(vectors)} 个向量,维度: {len(vectors[0]) if vectors else 0}",
"elapsed_time": elapsed_time,
"usage": {
"input_tokens": len(test_message),
"vector_count": len(vectors),
"vector_dimension": len(vectors[0]) if vectors else 0
},
"error": None
}
elif model_type_lower == "rerank":
# Rerank 模型验证(在线程中运行同步方法)
rerank = RedBearRerank(model_config)
query = test_message
documents = ["这是第一个文档", "这是第二个文档", "这是第三个文档"]
results = await asyncio.to_thread(rerank.rerank, query=query, documents=documents, top_n=3)
elapsed_time = time.time() - start_time
return {
"valid": True,
"message": "Rerank 模型配置验证成功",
"response": f"成功对 {len(documents)} 个文档进行重排序,返回 top {len(results) if results else 0} 结果",
"elapsed_time": elapsed_time,
"usage": {
"query_length": len(query),
"document_count": len(documents),
"result_count": len(results) if results else 0
},
"error": None
}
else:
return {
"valid": False,
"message": "不支持的模型类型",
"response": None,
"elapsed_time": None,
"usage": None,
"error": f"不支持的模型类型: {model_type}"
}
except Exception as e:
# 提取详细的错误信息
error_message = str(e)
error_type = type(e).__name__
print("=========error_message:",error_message.lower())
# 特殊处理常见的错误类型
if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower():
# 区域/国家限制(适用于所有提供商)
error_message = "区域限制: 该模型在当前区域或国家/地区不可用,请检查提供商的服务区域限制"
elif "ValidationException" in error_type or "ValidationException" in error_message:
# 其他验证错误
if "access denied" in error_message.lower():
error_message = "访问被拒绝: 请检查 API 凭证和权限配置"
else:
error_message = f"验证失败: {error_message}"
elif "AuthenticationError" in error_type or "authentication" in error_message.lower():
error_message = "认证失败: API Key 无效或已过期"
elif "RateLimitError" in error_type or "rate limit" in error_message.lower():
error_message = "请求频率限制: 已超过 API 调用限制"
elif "InvalidRequestError" in error_type or "invalid request" in error_message.lower():
error_message = f"无效请求: {error_message}"
elif "model_copy" in error_message:
error_message = "模型消息格式错误: 请确保使用正确的模型类型LLM/Chat"
# 记录详细错误日志
logger.error(f"模型验证失败 - 类型: {error_type}, 模型: {model_name}, 提供商: {provider}")
logger.error(f"错误详情: {error_message}")
logger.debug(f"完整堆栈: {traceback.format_exc()}")
return {
"valid": False,
"message": f"{model_type.upper()} 模型配置验证失败",
"response": None,
"elapsed_time": None,
"usage": None,
"error": error_message,
"error_type": error_type
}
@staticmethod
async def create_model(db: Session, model_data: ModelConfigCreate, tenant_id: uuid.UUID) -> ModelConfig:
"""创建模型配置"""
# 检查名称是否已存在(同租户内)
if ModelConfigRepository.get_by_name(db, model_data.name, provider=model_data.provider, tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
# 验证配置
if not model_data.skip_validation and model_data.api_keys:
api_key_data_list = model_data.api_keys
for api_key_data in api_key_data_list:
validation_result = await ModelConfigService.validate_model_config(
db=db,
model_name=model_data.name,
provider=model_data.provider,
api_key=api_key_data.api_key,
api_base=api_key_data.api_base,
model_type=model_data.type,
test_message="Hello",
is_omni=model_data.is_omni
)
if not validation_result["valid"]:
raise BusinessException(
f"模型配置验证失败: {validation_result['error']}",
BizCode.INVALID_PARAMETER
)
# 事务处理
api_key_datas = model_data.api_keys
model_config_data = model_data.model_dump(exclude={"api_keys", "skip_validation"})
# 添加租户ID
model_config_data["tenant_id"] = tenant_id
model = ModelConfigRepository.create(db, model_config_data)
db.flush() # 获取生成的 ID
if api_key_datas:
for api_key_data in api_key_datas:
api_key_data.model_name = model_data.name
api_key_data.provider = model_data.provider
# 同步capability和is_omni
api_key_data.capability = model_data.capability
api_key_data.is_omni = model_data.is_omni
api_key_create_schema = ModelApiKeyCreate(
model_config_ids=[model.id],
**api_key_data.model_dump()
)
ModelApiKeyRepository.create(db, api_key_create_schema)
db.commit()
db.refresh(model)
return model
@staticmethod
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, tenant_id: uuid.UUID | None = None) -> ModelConfig:
"""更新模型配置"""
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
if not existing_model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
if model_data.name and model_data.name != existing_model.name:
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
db.commit()
db.refresh(model)
return model
@staticmethod
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
"""创建组合模型"""
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
# 验证所有 API Key 存在且类型匹配
for api_key_id in model_data.api_key_ids:
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if not api_key:
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
# 检查 API Key 关联的模型配置类型
for model_config in api_key.model_configs:
# chat 和 llm 类型可以兼容
compatible_types = {ModelType.LLM, ModelType.CHAT}
config_type = model_config.type
request_type = model_data.type
if not (config_type == request_type or
(config_type in compatible_types and request_type in compatible_types)):
raise BusinessException(
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
BizCode.INVALID_PARAMETER
)
# if model_config.is_composite:
# raise BusinessException(
# f"API Key {api_key_id} 关联的模型是组合模型,不能用于创建新的组合模型",
# BizCode.INVALID_PARAMETER
# )
# 创建组合模型
model_config_data = {
"tenant_id": tenant_id,
"name": model_data.name,
"type": model_data.type,
"logo": model_data.logo,
"description": model_data.description,
"provider": ModelProvider.COMPOSITE,
"config": model_data.config,
"is_active": model_data.is_active,
"is_public": model_data.is_public,
"is_composite": True
}
if "load_balance_strategy" in model_data.model_fields_set:
model_config_data["load_balance_strategy"] = model_data.load_balance_strategy
model = ModelConfigRepository.create(db, model_config_data)
db.flush()
# 关联 API Keys
for api_key_id in model_data.api_key_ids:
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if api_key:
model.api_keys.append(api_key)
db.commit()
db.refresh(model)
return model
@staticmethod
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
"""更新组合模型"""
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
if not existing_model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
if model_data.name and model_data.name != existing_model.name:
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
if not existing_model.is_composite:
raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER)
# 验证所有 API Key 存在且类型匹配
for api_key_id in model_data.api_key_ids:
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if not api_key:
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
for model_config in api_key.model_configs:
compatible_types = {ModelType.LLM, ModelType.CHAT}
config_type = model_config.type
request_type = existing_model.type
if not (config_type == request_type or
(config_type in compatible_types and request_type in compatible_types)):
raise BusinessException(
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
BizCode.INVALID_PARAMETER
)
# 更新基本信息
existing_model.name = model_data.name
# existing_model.type = model_data.type
existing_model.logo = model_data.logo
existing_model.description = model_data.description
existing_model.config = model_data.config
existing_model.is_active = model_data.is_active
existing_model.is_public = model_data.is_public
if "load_balance_strategy" in model_data.model_fields_set:
existing_model.load_balance_strategy = model_data.load_balance_strategy
# 更新 API Keys 关联
existing_model.api_keys.clear()
for api_key_id in model_data.api_key_ids:
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if api_key:
existing_model.api_keys.append(api_key)
db.commit()
db.refresh(existing_model)
return existing_model
@staticmethod
def delete_model(db: Session, model_id: uuid.UUID, tenant_id: uuid.UUID | None = None) -> bool:
"""删除模型配置"""
if not ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id):
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
success = ModelConfigRepository.delete(db, model_id, tenant_id=tenant_id)
db.commit()
return success
@staticmethod
def get_model_stats(db: Session) -> ModelStats:
"""获取模型统计信息"""
stats_data = ModelConfigRepository.get_stats(db)
return ModelStats(
total_models=stats_data["total_models"],
active_models=stats_data["active_models"],
llm_count=stats_data["llm_count"],
embedding_count=stats_data["embedding_count"],
rerank_count=stats_data["rerank_count"],
provider_stats=stats_data["provider_stats"]
)
class ModelApiKeyService:
"""模型API Key服务"""
@staticmethod
def get_api_key_by_id(db: Session, api_key_id: uuid.UUID) -> ModelApiKey:
"""根据ID获取API Key"""
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if not api_key:
raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
return api_key
@staticmethod
def get_api_keys_by_model(db: Session, model_config_id: uuid.UUID, is_active: bool = True) -> list[ModelApiKey]:
"""根据模型配置ID获取API Key列表"""
if not ModelConfigRepository.get_by_id(db, model_config_id):
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
return ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active)
@staticmethod
async def create_api_key_by_provider(db: Session, data: model_schema.ModelApiKeyCreateByProvider) -> tuple[
list[Any], list[Any]]:
"""根据provider为多个ModelConfig创建API Key"""
created_keys = []
failed_models = [] # 记录验证失败的模型
for model_config_id in data.model_config_ids:
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
if not model_config:
continue
data.is_omni = model_config.is_omni
data.capability = model_config.capability
# 从ModelBase获取model_name
model_name = model_config.model_base.name if model_config.model_base else model_config.name
# 检查是否存在API Key包括软删除需要考虑tenant_id
existing_key = db.query(ModelApiKey).join(
ModelApiKey.model_configs
).filter(
ModelApiKey.api_key == data.api_key,
ModelApiKey.provider == data.provider,
ModelApiKey.model_name == model_name,
ModelConfig.tenant_id == model_config.tenant_id
).first()
if existing_key:
# 如果已存在,重新激活并更新
if existing_key.is_active:
continue
existing_key.is_active = True
existing_key.api_base = data.api_base
existing_key.description = data.description
existing_key.config = data.config
existing_key.priority = data.priority
existing_key.model_name = model_name
existing_key.capability = data.capability
existing_key.is_omni = data.is_omni
# 检查是否已关联该模型配置
if model_config not in existing_key.model_configs:
existing_key.model_configs.append(model_config)
created_keys.append(existing_key)
continue
# 验证配置
validation_result = await ModelConfigService.validate_model_config(
db=db,
model_name=model_name,
provider=data.provider,
api_key=data.api_key,
api_base=data.api_base,
model_type=model_config.type,
test_message="Hello",
is_omni=data.is_omni
)
if not validation_result["valid"]:
# 记录验证失败的模型,但不抛出异常
failed_models.append(model_name)
continue
# 创建API Key
api_key_data = ModelApiKeyCreate(
model_config_ids=[model_config_id],
model_name=model_name,
description=data.description,
provider=data.provider,
api_key=data.api_key,
api_base=data.api_base,
capability=data.capability,
is_omni=data.is_omni,
config=data.config,
is_active=data.is_active,
priority=data.priority
)
api_key_obj = ModelApiKeyRepository.create(db, api_key_data)
created_keys.append(api_key_obj)
if created_keys:
db.commit()
for key in created_keys:
db.refresh(key)
return created_keys, failed_models
@staticmethod
async def create_api_key(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey:
# 验证所有关联的模型配置是否存在
if api_key_data.model_config_ids:
for model_config_id in api_key_data.model_config_ids:
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
if not model_config:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
if api_key_data.is_omni is None:
api_key_data.is_omni = model_config.is_omni
if api_key_data.capability is None:
api_key_data.capability = model_config.capability
# 检查API Key是否已存在(包括软删除)需要考虑tenant_id
existing_key = db.query(ModelApiKey).join(
ModelApiKey.model_configs
).filter(
ModelApiKey.api_key == api_key_data.api_key,
ModelApiKey.provider == api_key_data.provider,
ModelApiKey.model_name == api_key_data.model_name,
ModelConfig.tenant_id == model_config.tenant_id
).first()
if existing_key:
if existing_key.is_active:
# 如果已激活,跳过
raise BusinessException("该API Key已存在", BizCode.DUPLICATE_NAME)
# 如果已存在,重新激活并更新
existing_key.is_active = True
existing_key.api_base = api_key_data.api_base
existing_key.description = api_key_data.description
existing_key.config = api_key_data.config
existing_key.priority = api_key_data.priority
existing_key.model_name = api_key_data.model_name
existing_key.capability = api_key_data.capability
existing_key.is_omni = api_key_data.is_omni
# 检查是否已关联该模型配置
if model_config not in existing_key.model_configs:
existing_key.model_configs.append(model_config)
db.commit()
db.refresh(existing_key)
return existing_key
# 验证配置
validation_result = await ModelConfigService.validate_model_config(
db=db,
model_name=api_key_data.model_name,
provider=api_key_data.provider,
api_key=api_key_data.api_key,
api_base=api_key_data.api_base,
model_type=model_config.type,
test_message="Hello",
is_omni=api_key_data.is_omni
)
if not validation_result["valid"]:
raise BusinessException(
f"模型配置验证失败: {validation_result['error']}",
BizCode.INVALID_PARAMETER
)
api_key = ModelApiKeyRepository.create(db, api_key_data)
db.commit()
db.refresh(api_key)
return api_key
@staticmethod
async def update_api_key(db: Session, api_key_id: uuid.UUID, api_key_data: ModelApiKeyUpdate) -> ModelApiKey:
"""更新API Key"""
existing_api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if not existing_api_key:
raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
# 获取关联的模型配置以获取模型类型
if existing_api_key.model_configs:
model_config = existing_api_key.model_configs[0]
validation_result = await ModelConfigService.validate_model_config(
db=db,
model_name=api_key_data.model_name or existing_api_key.model_name,
provider=api_key_data.provider or existing_api_key.provider,
api_key=api_key_data.api_key or existing_api_key.api_key,
api_base=api_key_data.api_base or existing_api_key.api_base,
model_type=model_config.type,
test_message="Hello",
is_omni=model_config.is_omni
)
if not validation_result["valid"]:
raise BusinessException(
f"模型配置验证失败: {validation_result['error']}",
BizCode.INVALID_PARAMETER
)
api_key = ModelApiKeyRepository.update(db, api_key_id, api_key_data)
db.commit()
db.refresh(api_key)
return api_key
@staticmethod
def delete_api_key(db: Session, api_key_id: uuid.UUID) -> bool:
"""删除API Key"""
if not ModelApiKeyRepository.get_by_id(db, api_key_id):
raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
success = ModelApiKeyRepository.delete(db, api_key_id)
db.commit()
return success
@staticmethod
def get_available_api_key(db: Session, model_config_id: uuid.UUID) -> Optional[ModelApiKey]:
"""获取可用的API Key根据负载均衡策略"""
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
if not model_config:
return None
api_keys = [key for key in model_config.api_keys if key.is_active]
if not api_keys:
return None
# 如果是轮询策略,按使用次数最少,次数相同则选最早使用的
if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min))
# 否则返回第一个
return api_keys[0]
@staticmethod
def record_api_key_usage(db: Session, api_key_id: uuid.UUID | None) -> bool:
"""记录API Key使用"""
if api_key_id:
success = ModelApiKeyRepository.update_usage(db, api_key_id)
if success:
db.commit()
return success
return False
@staticmethod
def get_a_api_key(db: Session, model_config_id: uuid.UUID) -> ModelApiKey:
api_kes = ModelApiKeyService.get_api_keys_by_model(db, model_config_id)
if api_kes and len(api_kes) > 0:
return api_kes[0]
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
class ModelBaseService:
"""基础模型服务"""
@staticmethod
def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List:
models = ModelBaseRepository.get_list(db, query)
provider_groups = {}
for m in models:
model_dict = model_schema.ModelBase.model_validate(m).model_dump()
if tenant_id:
model_dict['is_added'] = ModelBaseRepository.check_added_by_tenant(db, m.id, tenant_id)
provider = m.provider
if provider not in provider_groups:
provider_groups[provider] = {
"provider": provider,
"models": []
}
provider_groups[provider]["models"].append(model_dict)
return list(provider_groups.values())
@staticmethod
def get_model_base_by_id(db: Session, model_base_id: uuid.UUID):
model = ModelBaseRepository.get_by_id(db, model_base_id)
if not model:
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
return model
@staticmethod
def create_model_base(db: Session, data: model_schema.ModelBaseCreate):
existing = ModelBaseRepository.get_by_name_and_provider(db, data.name, data.provider)
if existing:
raise BusinessException("模型已存在", BizCode.DUPLICATE_NAME)
model_base = ModelBaseRepository.create(db, data.model_dump())
db.commit()
db.refresh(model_base)
return model_base
@staticmethod
def update_model_base(db: Session, model_base_id: uuid.UUID, data: model_schema.ModelBaseUpdate):
model_base = ModelBaseRepository.update(db, model_base_id, data.model_dump(exclude_unset=True))
if not model_base:
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
db.commit()
db.refresh(model_base)
return model_base
@staticmethod
def delete_model_base(db: Session, model_base_id: uuid.UUID) -> bool:
success = ModelBaseRepository.delete(db, model_base_id)
if not success:
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
db.commit()
return success
@staticmethod
def add_model_from_plaza(db: Session, model_base_id: uuid.UUID, tenant_id: uuid.UUID) -> ModelConfig:
model_base = ModelBaseRepository.get_by_id(db, model_base_id)
if not model_base:
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
if ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id):
raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME)
model_config_data = {
"model_id": model_base_id,
"tenant_id": tenant_id,
"name": model_base.name,
"provider": model_base.provider,
"type": model_base.type,
"logo": model_base.logo,
"description": model_base.description,
"capability": model_base.capability,
"is_omni": model_base.is_omni,
"is_active": False,
"is_composite": False
}
model_config = ModelConfigRepository.create(db, model_config_data)
ModelBaseRepository.increment_add_count(db, model_base_id)
db.commit()
db.refresh(model_config)
return model_config