from datetime import datetime from sqlalchemy.orm import Session from typing import List, Optional, Dict, Any import uuid import math import time import asyncio from app.models.models_model import ModelConfig, ModelApiKey, ModelType, LoadBalanceStrategy, ModelProvider from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository, ModelBaseRepository from app.schemas import model_schema from app.schemas.model_schema import ( ModelConfigCreate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate, ModelConfigQuery, ModelStats, ModelConfigQueryNew ) from app.core.logging_config import get_business_logger from app.schemas.response_schema import PageData, PageMeta from app.core.exceptions import BusinessException from app.core.error_codes import BizCode logger = get_business_logger() class ModelConfigService: """模型配置服务""" @staticmethod def get_model_by_id(db: Session, model_id: uuid.UUID, tenant_id: uuid.UUID | None = None) -> ModelConfig: """根据ID获取模型配置""" model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id) if not model: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) return model @staticmethod def get_model_list(db: Session, query: ModelConfigQuery, tenant_id: uuid.UUID | None = None) -> PageData: """获取模型配置列表""" models, total = ModelConfigRepository.get_list(db, query, tenant_id=tenant_id) pages = math.ceil(total / query.pagesize) if total > 0 else 0 return PageData( page=PageMeta( page=query.page, pagesize=query.pagesize, total=total, hasnext=query.page < pages ), items=[model_schema.ModelConfig.model_validate(model) for model in models] ) @staticmethod def get_model_list_new(db: Session, query: ModelConfigQueryNew, tenant_id: uuid.UUID | None = None) -> List[dict]: """获取模型配置列表""" provider_groups, total = ModelConfigRepository.get_list_new(db, query, tenant_id=tenant_id) items = [] for provider, models in provider_groups.items(): # 验证每个模型并封装分组信息 validated_models = [model_schema.ModelConfig.model_validate(model) for model in models] tags = list({model.type for model in validated_models}) group_item = { "provider": provider, # 服务商名称 "logo": validated_models[0].logo, "tags": tags, "models": validated_models # 该服务商下的所有模型 } items.append(group_item) return items @staticmethod def get_model_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> ModelConfig: """根据名称获取模型配置""" model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id) if not model: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) return model @staticmethod def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ ModelConfig]: """按名称模糊匹配获取模型配置列表""" return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit) @staticmethod async def validate_model_config( db: Session, *, model_name: str, provider: str, api_key: str, api_base: Optional[str] = None, model_type: str = "llm", test_message: str = "Hello", is_omni: bool = False, capability: Optional[list] = None ) -> Dict[str, Any]: """验证模型配置是否有效 Args: db: 数据库会话 model_name: 模型名称 provider: 提供商 api_key: API密钥 api_base: API基础URL model_type: 模型类型 (llm/chat/embedding/rerank) test_message: 测试消息 is_omni: 是否为Omni模型 Returns: Dict: 验证结果 """ from app.core.models import RedBearLLM, RedBearRerank from app.core.models.base import RedBearModelConfig from app.core.models.embedding import RedBearEmbeddings import traceback try: start_time = time.time() model_config = RedBearModelConfig( model_name=model_name, provider=provider, api_key=api_key, base_url=api_base, is_omni=is_omni, capability=capability, extra_params={ "temperature": 0.7, "max_tokens": 100 } ) # 根据模型类型选择不同的验证方式 model_type_lower = model_type.lower() if model_type_lower in ["llm", "chat"]: # LLM/Chat 模型验证 - 统一使用字符串输入 llm = RedBearLLM(model_config, type=ModelType.LLM if model_type_lower == "llm" else ModelType.CHAT) response = await llm.ainvoke(test_message) elapsed_time = time.time() - start_time content = response.content if hasattr(response, 'content') else str(response) usage = None if hasattr(response, 'usage_metadata'): usage = { "input_tokens": getattr(response.usage_metadata, 'input_tokens', 0), "output_tokens": getattr(response.usage_metadata, 'output_tokens', 0), "total_tokens": getattr(response.usage_metadata, 'total_tokens', 0) } return { "valid": True, "message": f"{model_type.upper()} 模型配置验证成功", "response": content, "elapsed_time": elapsed_time, "usage": usage, "error": None } elif model_type_lower == "embedding": # Embedding 模型验证 # 统一使用 RedBearEmbeddings(自动支持火山引擎多模态) embedding = RedBearEmbeddings(model_config) test_texts = [test_message, "测试文本"] # 火山引擎使用 embed_batch,其他使用 embed_documents if provider.lower() == "volcano": vectors = await asyncio.to_thread(embedding.embed_batch, test_texts) else: vectors = await asyncio.to_thread(embedding.embed_documents, test_texts) elapsed_time = time.time() - start_time return { "valid": True, "message": "Embedding 模型配置验证成功", "response": f"成功生成 {len(vectors)} 个向量,维度: {len(vectors[0]) if vectors else 0}", "elapsed_time": elapsed_time, "usage": { "input_tokens": len(test_message), "vector_count": len(vectors), "vector_dimension": len(vectors[0]) if vectors else 0 }, "error": None } elif model_type_lower == "rerank": # Rerank 模型验证(在线程中运行同步方法) rerank = RedBearRerank(model_config) query = test_message documents = ["这是第一个文档", "这是第二个文档", "这是第三个文档"] results = await asyncio.to_thread(rerank.rerank, query=query, documents=documents, top_n=3) elapsed_time = time.time() - start_time return { "valid": True, "message": "Rerank 模型配置验证成功", "response": f"成功对 {len(documents)} 个文档进行重排序,返回 top {len(results) if results else 0} 结果", "elapsed_time": elapsed_time, "usage": { "query_length": len(query), "document_count": len(documents), "result_count": len(results) if results else 0 }, "error": None } elif model_type_lower == "image": # 图片生成模型验证 from app.core.models.generation import RedBearImageGenerator generator = RedBearImageGenerator(model_config) result = await generator.agenerate( prompt="a cute panda", size="2K" ) elapsed_time = time.time() - start_time logger.info(f"成功生成图片,结果: {result}") return { "valid": True, "message": "图片生成模型配置验证成功", "response": f"成功生成图片,结果: {result}", "elapsed_time": elapsed_time, "usage": { "prompt_length": len("a cute panda"), "image_count": 1 }, "error": None } elif model_type_lower == "video": # 视频生成模型验证 from app.core.models.generation import RedBearVideoGenerator generator = RedBearVideoGenerator(model_config) result = await generator.agenerate( prompt="a cute panda playing in bamboo forest", duration=5 ) elapsed_time = time.time() - start_time # 视频生成是异步任务,返回任务ID task_id = result.get("task_id") if isinstance(result, dict) else None return { "valid": True, "message": "视频生成模型配置验证成功", "response": f"成功创建视频生成任务,任务ID: {task_id}", "elapsed_time": elapsed_time, "usage": { "prompt_length": len("a cute panda playing in bamboo forest"), "task_id": task_id }, "error": None } else: return { "valid": False, "message": "不支持的模型类型", "response": None, "elapsed_time": None, "usage": None, "error": f"不支持的模型类型: {model_type}" } except Exception as e: # 提取详细的错误信息 error_message = str(e) error_type = type(e).__name__ # 特殊处理常见的错误类型 if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower(): # 区域/国家限制(适用于所有提供商) error_message = "区域限制: 该模型在当前区域或国家/地区不可用,请检查提供商的服务区域限制" elif "ValidationException" in error_type or "ValidationException" in error_message: # 其他验证错误 if "access denied" in error_message.lower(): error_message = "访问被拒绝: 请检查 API 凭证和权限配置" else: error_message = f"验证失败: {error_message}" elif "AuthenticationError" in error_type or "authentication" in error_message.lower(): error_message = "认证失败: API Key 无效或已过期" elif "RateLimitError" in error_type or "rate limit" in error_message.lower(): error_message = "请求频率限制: 已超过 API 调用限制" elif "InvalidRequestError" in error_type or "invalid request" in error_message.lower(): error_message = f"无效请求: {error_message}" elif "model_copy" in error_message: error_message = "模型消息格式错误: 请确保使用正确的模型类型(LLM/Chat)" # 记录详细错误日志 logger.error(f"模型验证失败 - 类型: {error_type}, 模型: {model_name}, 提供商: {provider}") logger.error(f"错误详情: {error_message}") logger.debug(f"完整堆栈: {traceback.format_exc()}") return { "valid": False, "message": f"{model_type.upper()} 模型配置验证失败", "response": None, "elapsed_time": None, "usage": None, "error": error_message, "error_type": error_type } @staticmethod async def create_model(db: Session, model_data: ModelConfigCreate, tenant_id: uuid.UUID) -> ModelConfig: """创建模型配置""" # 检查名称是否已存在(同租户内) if ModelConfigRepository.get_by_name(db, model_data.name, provider=model_data.provider, tenant_id=tenant_id): raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) # 验证配置 if not model_data.skip_validation and model_data.api_keys: api_key_data_list = model_data.api_keys for api_key_data in api_key_data_list: validation_result = await ModelConfigService.validate_model_config( db=db, model_name=model_data.name, provider=model_data.provider, api_key=api_key_data.api_key, api_base=api_key_data.api_base, model_type=model_data.type, test_message="Hello", is_omni=model_data.is_omni, capability=model_data.capability ) if not validation_result["valid"]: raise BusinessException( f"模型配置验证失败: {validation_result['error']}", BizCode.INVALID_PARAMETER ) # 事务处理 api_key_datas = model_data.api_keys model_config_data = model_data.model_dump(exclude={"api_keys", "skip_validation"}) # 添加租户ID model_config_data["tenant_id"] = tenant_id model = ModelConfigRepository.create(db, model_config_data) db.flush() # 获取生成的 ID if api_key_datas: for api_key_data in api_key_datas: api_key_data.model_name = model_data.name api_key_data.provider = model_data.provider # 同步capability和is_omni api_key_data.capability = model_data.capability api_key_data.is_omni = model_data.is_omni api_key_create_schema = ModelApiKeyCreate( model_config_ids=[model.id], **api_key_data.model_dump() ) ModelApiKeyRepository.create(db, api_key_create_schema) db.commit() db.refresh(model) return model @staticmethod def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, tenant_id: uuid.UUID | None = None) -> ModelConfig: """更新模型配置""" existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id) if not existing_model: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) if model_data.name and model_data.name != existing_model.name: if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id): raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id) db.commit() db.refresh(model) return model @staticmethod async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig: """创建组合模型""" if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, tenant_id=tenant_id): raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) # 验证所有 API Key 存在且类型匹配 for api_key_id in model_data.api_key_ids: api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if not api_key: raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND) # 检查 API Key 关联的模型配置类型 for model_config in api_key.model_configs: # chat 和 llm 类型可以兼容 compatible_types = {ModelType.LLM, ModelType.CHAT} config_type = model_config.type request_type = model_data.type if not (config_type == request_type or (config_type in compatible_types and request_type in compatible_types)): raise BusinessException( f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配", BizCode.INVALID_PARAMETER ) # if model_config.is_composite: # raise BusinessException( # f"API Key {api_key_id} 关联的模型是组合模型,不能用于创建新的组合模型", # BizCode.INVALID_PARAMETER # ) # 创建组合模型 model_config_data = { "tenant_id": tenant_id, "name": model_data.name, "type": model_data.type, "logo": model_data.logo, "description": model_data.description, "provider": ModelProvider.COMPOSITE, "config": model_data.config, "is_active": model_data.is_active, "is_public": model_data.is_public, "is_composite": True } if "load_balance_strategy" in model_data.model_fields_set: model_config_data["load_balance_strategy"] = model_data.load_balance_strategy model = ModelConfigRepository.create(db, model_config_data) db.flush() # 关联 API Keys for api_key_id in model_data.api_key_ids: api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if api_key: model.api_keys.append(api_key) db.commit() db.refresh(model) return model @staticmethod async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig: """更新组合模型""" existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id) if not existing_model: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) if model_data.name and model_data.name != existing_model.name: if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id): raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) if not existing_model.is_composite: raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER) # 验证所有 API Key 存在且类型匹配 for api_key_id in model_data.api_key_ids: api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if not api_key: raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND) for model_config in api_key.model_configs: compatible_types = {ModelType.LLM, ModelType.CHAT} config_type = model_config.type request_type = existing_model.type if not (config_type == request_type or (config_type in compatible_types and request_type in compatible_types)): raise BusinessException( f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配", BizCode.INVALID_PARAMETER ) # 更新基本信息 existing_model.name = model_data.name # existing_model.type = model_data.type existing_model.logo = model_data.logo existing_model.description = model_data.description existing_model.config = model_data.config existing_model.is_active = model_data.is_active existing_model.is_public = model_data.is_public if "load_balance_strategy" in model_data.model_fields_set: existing_model.load_balance_strategy = model_data.load_balance_strategy # 更新 API Keys 关联 existing_model.api_keys.clear() for api_key_id in model_data.api_key_ids: api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if api_key: existing_model.api_keys.append(api_key) db.commit() db.refresh(existing_model) return existing_model @staticmethod def delete_model(db: Session, model_id: uuid.UUID, tenant_id: uuid.UUID | None = None) -> bool: """删除模型配置""" if not ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id): raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) success = ModelConfigRepository.delete(db, model_id, tenant_id=tenant_id) db.commit() return success @staticmethod def get_model_stats(db: Session) -> ModelStats: """获取模型统计信息""" stats_data = ModelConfigRepository.get_stats(db) return ModelStats( total_models=stats_data["total_models"], active_models=stats_data["active_models"], llm_count=stats_data["llm_count"], embedding_count=stats_data["embedding_count"], rerank_count=stats_data["rerank_count"], provider_stats=stats_data["provider_stats"] ) class ModelApiKeyService: """模型API Key服务""" @staticmethod def get_api_key_by_id(db: Session, api_key_id: uuid.UUID) -> ModelApiKey: """根据ID获取API Key""" api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if not api_key: raise BusinessException("API Key不存在", BizCode.NOT_FOUND) return api_key @staticmethod def get_api_keys_by_model(db: Session, model_config_id: uuid.UUID, is_active: bool = True) -> list[ModelApiKey]: """根据模型配置ID获取API Key列表""" if not ModelConfigRepository.get_by_id(db, model_config_id): raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) return ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active) @staticmethod async def create_api_key_by_provider(db: Session, data: model_schema.ModelApiKeyCreateByProvider) -> tuple[ list[Any], list[Any]]: """根据provider为多个ModelConfig创建API Key""" created_keys = [] failed_models = [] # 记录验证失败的模型 for model_config_id in data.model_config_ids: model_config = ModelConfigRepository.get_by_id(db, model_config_id) if not model_config: continue data.is_omni = model_config.is_omni data.capability = model_config.capability # 从ModelBase获取model_name model_name = model_config.model_base.name if model_config.model_base else model_config.name # 检查是否存在API Key(包括软删除),需要考虑tenant_id existing_key = db.query(ModelApiKey).join( ModelApiKey.model_configs ).filter( ModelApiKey.api_key == data.api_key, ModelApiKey.provider == data.provider, ModelApiKey.model_name == model_name, ModelConfig.tenant_id == model_config.tenant_id ).first() if existing_key: # 如果已存在,重新激活并更新 if existing_key.is_active: continue existing_key.is_active = True existing_key.api_base = data.api_base existing_key.description = data.description existing_key.config = data.config existing_key.priority = data.priority existing_key.model_name = model_name existing_key.capability = data.capability existing_key.is_omni = data.is_omni # 检查是否已关联该模型配置 if model_config not in existing_key.model_configs: existing_key.model_configs.append(model_config) created_keys.append(existing_key) continue # 验证配置 validation_result = await ModelConfigService.validate_model_config( db=db, model_name=model_name, provider=data.provider, api_key=data.api_key, api_base=data.api_base, model_type=model_config.type, test_message="Hello", is_omni=data.is_omni, capability=model_config.capability ) if not validation_result["valid"]: # 记录验证失败的模型,但不抛出异常 failed_models.append(model_name) continue # 创建API Key api_key_data = ModelApiKeyCreate( model_config_ids=[model_config_id], model_name=model_name, description=data.description, provider=data.provider, api_key=data.api_key, api_base=data.api_base, capability=data.capability, is_omni=data.is_omni, config=data.config, is_active=data.is_active, priority=data.priority ) api_key_obj = ModelApiKeyRepository.create(db, api_key_data) created_keys.append(api_key_obj) if created_keys: db.commit() for key in created_keys: db.refresh(key) return created_keys, failed_models @staticmethod async def create_api_key(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey: # 验证所有关联的模型配置是否存在 if api_key_data.model_config_ids: for model_config_id in api_key_data.model_config_ids: model_config = ModelConfigRepository.get_by_id(db, model_config_id) if not model_config: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) if api_key_data.is_omni is None: api_key_data.is_omni = model_config.is_omni if api_key_data.capability is None: api_key_data.capability = model_config.capability # 检查API Key是否已存在(包括软删除),需要考虑tenant_id existing_key = db.query(ModelApiKey).join( ModelApiKey.model_configs ).filter( ModelApiKey.api_key == api_key_data.api_key, ModelApiKey.provider == api_key_data.provider, ModelApiKey.model_name == api_key_data.model_name, ModelConfig.tenant_id == model_config.tenant_id ).first() if existing_key: if existing_key.is_active: # 如果已激活,跳过 raise BusinessException("该API Key已存在", BizCode.DUPLICATE_NAME) # 如果已存在,重新激活并更新 existing_key.is_active = True existing_key.api_base = api_key_data.api_base existing_key.description = api_key_data.description existing_key.config = api_key_data.config existing_key.priority = api_key_data.priority existing_key.model_name = api_key_data.model_name existing_key.capability = api_key_data.capability existing_key.is_omni = api_key_data.is_omni # 检查是否已关联该模型配置 if model_config not in existing_key.model_configs: existing_key.model_configs.append(model_config) db.commit() db.refresh(existing_key) return existing_key # 验证配置 validation_result = await ModelConfigService.validate_model_config( db=db, model_name=api_key_data.model_name, provider=api_key_data.provider, api_key=api_key_data.api_key, api_base=api_key_data.api_base, model_type=model_config.type, test_message="Hello", is_omni=api_key_data.is_omni, capability=model_config.capability ) if not validation_result["valid"]: raise BusinessException( f"模型配置验证失败: {validation_result['error']}", BizCode.INVALID_PARAMETER ) api_key = ModelApiKeyRepository.create(db, api_key_data) db.commit() db.refresh(api_key) return api_key @staticmethod async def update_api_key(db: Session, api_key_id: uuid.UUID, api_key_data: ModelApiKeyUpdate) -> ModelApiKey: """更新API Key""" existing_api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if not existing_api_key: raise BusinessException("API Key不存在", BizCode.NOT_FOUND) # 获取关联的模型配置以获取模型类型 if existing_api_key.model_configs: model_config = existing_api_key.model_configs[0] validation_result = await ModelConfigService.validate_model_config( db=db, model_name=api_key_data.model_name or existing_api_key.model_name, provider=api_key_data.provider or existing_api_key.provider, api_key=api_key_data.api_key or existing_api_key.api_key, api_base=api_key_data.api_base or existing_api_key.api_base, model_type=model_config.type, test_message="Hello", is_omni=model_config.is_omni, capability=model_config.capability ) if not validation_result["valid"]: raise BusinessException( f"模型配置验证失败: {validation_result['error']}", BizCode.INVALID_PARAMETER ) api_key = ModelApiKeyRepository.update(db, api_key_id, api_key_data) db.commit() db.refresh(api_key) return api_key @staticmethod def delete_api_key(db: Session, api_key_id: uuid.UUID) -> bool: """删除API Key""" api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if not api_key: raise BusinessException("API Key不存在", BizCode.NOT_FOUND) model_config_ids = [mc.id for mc in api_key.model_configs] success = ModelApiKeyRepository.delete(db, api_key_id) for model_config_id in model_config_ids: model_config = ModelConfigRepository.get_by_id(db, model_config_id) if model_config: has_active_key = any(key.is_active for key in model_config.api_keys) if not has_active_key and model_config.is_active: model_config.is_active = False db.commit() return success @staticmethod def get_available_api_key(db: Session, model_config_id: uuid.UUID) -> Optional[ModelApiKey]: """获取可用的API Key(根据负载均衡策略)""" model_config = ModelConfigRepository.get_by_id(db, model_config_id) if not model_config: return None api_keys = [key for key in model_config.api_keys if key.is_active] if not api_keys: return None # 如果是轮询策略,按使用次数最少,次数相同则选最早使用的 if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min)) # 否则返回第一个 return api_keys[0] @staticmethod def record_api_key_usage(db: Session, api_key_id: uuid.UUID | None) -> bool: """记录API Key使用""" if api_key_id: success = ModelApiKeyRepository.update_usage(db, api_key_id) if success: db.commit() return success return False @staticmethod def get_a_api_key(db: Session, model_config_id: uuid.UUID) -> ModelApiKey: api_kes = ModelApiKeyService.get_api_keys_by_model(db, model_config_id) if api_kes and len(api_kes) > 0: return api_kes[0] raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) class ModelBaseService: """基础模型服务""" @staticmethod def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List: models = ModelBaseRepository.get_list(db, query) provider_groups = {} for m in models: model_dict = model_schema.ModelBase.model_validate(m).model_dump() if tenant_id: model_dict['is_added'] = ModelBaseRepository.check_added_by_tenant(db, m.id, tenant_id) provider = m.provider if provider not in provider_groups: provider_groups[provider] = { "provider": provider, "models": [] } provider_groups[provider]["models"].append(model_dict) return list(provider_groups.values()) @staticmethod def get_model_base_by_id(db: Session, model_base_id: uuid.UUID): model = ModelBaseRepository.get_by_id(db, model_base_id) if not model: raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND) return model @staticmethod def create_model_base(db: Session, data: model_schema.ModelBaseCreate): existing = ModelBaseRepository.get_by_name_and_provider(db, data.name, data.provider) if existing: raise BusinessException("模型已存在", BizCode.DUPLICATE_NAME) model_base = ModelBaseRepository.create(db, data.model_dump()) db.commit() db.refresh(model_base) return model_base @staticmethod def update_model_base(db: Session, model_base_id: uuid.UUID, data: model_schema.ModelBaseUpdate): model_base = ModelBaseRepository.update(db, model_base_id, data.model_dump(exclude_unset=True)) if not model_base: raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND) db.commit() db.refresh(model_base) return model_base @staticmethod def delete_model_base(db: Session, model_base_id: uuid.UUID) -> bool: success = ModelBaseRepository.delete(db, model_base_id) if not success: raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND) db.commit() return success @staticmethod def add_model_from_plaza(db: Session, model_base_id: uuid.UUID, tenant_id: uuid.UUID) -> ModelConfig: model_base = ModelBaseRepository.get_by_id(db, model_base_id) if not model_base: raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND) if ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id): raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME) model_config_data = { "model_id": model_base_id, "tenant_id": tenant_id, "name": model_base.name, "provider": model_base.provider, "type": model_base.type, "logo": model_base.logo, "description": model_base.description, "capability": model_base.capability, "is_omni": model_base.is_omni, "is_active": False, "is_composite": False } model_config = ModelConfigRepository.create(db, model_config_data) ModelBaseRepository.increment_add_count(db, model_base_id) db.commit() db.refresh(model_config) return model_config