feat(model and app statistic): 1. Optimize the model list; 2. Increase the model combination; 3. Add a model square; 4. Add application management statistics
This commit is contained in:
193
api/app/services/app_statistics_service.py
Normal file
193
api/app/services/app_statistics_service.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""应用统计服务"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List
|
||||
import uuid
|
||||
from sqlalchemy import func, and_, cast, Date
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.conversation_model import Conversation, Message
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.api_key_model import ApiKey, ApiKeyLog
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
|
||||
|
||||
class AppStatisticsService:
|
||||
"""应用统计服务"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_app_statistics(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
start_date: int,
|
||||
end_date: int
|
||||
) -> Dict[str, Any]:
|
||||
"""获取应用统计数据
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
workspace_id: 工作空间ID
|
||||
start_date: 开始时间戳(毫秒)
|
||||
end_date: 结束时间戳(毫秒)
|
||||
|
||||
Returns:
|
||||
统计数据字典
|
||||
"""
|
||||
# 将毫秒时间戳转换为 datetime
|
||||
start_dt = datetime.fromtimestamp(start_date / 1000)
|
||||
end_dt = datetime.fromtimestamp(end_date / 1000) + timedelta(days=1)
|
||||
|
||||
# 1. 会话统计
|
||||
conversations_stats = self._get_conversations_statistics(app_id, workspace_id, start_dt, end_dt)
|
||||
|
||||
# 2. 新增用户统计
|
||||
users_stats = self._get_new_users_statistics(app_id, start_dt, end_dt)
|
||||
|
||||
# 3. API调用统计
|
||||
api_stats = self._get_api_calls_statistics(app_id, start_dt, end_dt)
|
||||
|
||||
# 4. Token消耗统计
|
||||
token_stats = self._get_token_statistics(app_id, start_dt, end_dt)
|
||||
|
||||
return {
|
||||
"daily_conversations": conversations_stats["daily"],
|
||||
"total_conversations": conversations_stats["total"],
|
||||
"daily_new_users": users_stats["daily"],
|
||||
"total_new_users": users_stats["total"],
|
||||
"daily_api_calls": api_stats["daily"],
|
||||
"total_api_calls": api_stats["total"],
|
||||
"daily_tokens": token_stats["daily"],
|
||||
"total_tokens": token_stats["total"]
|
||||
}
|
||||
|
||||
def _get_conversations_statistics(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
start_dt: datetime,
|
||||
end_dt: datetime
|
||||
) -> Dict[str, Any]:
|
||||
"""获取会话统计"""
|
||||
# 每日会话数
|
||||
daily_query = self.db.query(
|
||||
cast(Conversation.created_at, Date).label('date'),
|
||||
func.count(Conversation.id).label('count')
|
||||
).filter(
|
||||
and_(
|
||||
Conversation.app_id == app_id,
|
||||
Conversation.workspace_id == workspace_id,
|
||||
Conversation.created_at >= start_dt,
|
||||
Conversation.created_at < end_dt
|
||||
)
|
||||
).group_by(cast(Conversation.created_at, Date)).all()
|
||||
|
||||
daily_data = [{"date": str(row.date), "count": row.count} for row in daily_query]
|
||||
total = sum(row["count"] for row in daily_data)
|
||||
|
||||
return {"daily": daily_data, "total": total}
|
||||
|
||||
def _get_new_users_statistics(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
start_dt: datetime,
|
||||
end_dt: datetime
|
||||
) -> Dict[str, Any]:
|
||||
"""获取新增用户统计"""
|
||||
# 每日新增用户数
|
||||
daily_query = self.db.query(
|
||||
cast(EndUser.created_at, Date).label('date'),
|
||||
func.count(EndUser.id).label('count')
|
||||
).filter(
|
||||
and_(
|
||||
EndUser.app_id == app_id,
|
||||
EndUser.created_at >= start_dt,
|
||||
EndUser.created_at < end_dt
|
||||
)
|
||||
).group_by(cast(EndUser.created_at, Date)).all()
|
||||
|
||||
daily_data = [{"date": str(row.date), "count": row.count} for row in daily_query]
|
||||
total = sum(row["count"] for row in daily_data)
|
||||
|
||||
return {"daily": daily_data, "total": total}
|
||||
|
||||
def _get_api_calls_statistics(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
start_dt: datetime,
|
||||
end_dt: datetime
|
||||
) -> Dict[str, Any]:
|
||||
"""获取API调用统计"""
|
||||
# 每日API调用次数
|
||||
daily_query = self.db.query(
|
||||
cast(ApiKeyLog.created_at, Date).label('date'),
|
||||
func.count(ApiKeyLog.id).label('count')
|
||||
).join(
|
||||
ApiKey, ApiKeyLog.api_key_id == ApiKey.id
|
||||
).filter(
|
||||
and_(
|
||||
ApiKey.resource_id == app_id,
|
||||
ApiKeyLog.created_at >= start_dt,
|
||||
ApiKeyLog.created_at < end_dt
|
||||
)
|
||||
).group_by(cast(ApiKeyLog.created_at, Date)).all()
|
||||
|
||||
daily_data = [{"date": str(row.date), "count": row.count} for row in daily_query]
|
||||
total = sum(row["count"] for row in daily_data)
|
||||
|
||||
return {"daily": daily_data, "total": total}
|
||||
|
||||
def _get_token_statistics(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
start_dt: datetime,
|
||||
end_dt: datetime
|
||||
) -> Dict[str, Any]:
|
||||
"""获取Token消耗统计(从Message的meta_data中提取)"""
|
||||
from sqlalchemy import text
|
||||
|
||||
# 查询所有相关消息的token使用情况
|
||||
# meta_data中可能包含: {"usage": {"total_tokens": 100}} 或 {"tokens": 100}
|
||||
daily_query = self.db.query(
|
||||
cast(Message.created_at, Date).label('date'),
|
||||
Message.meta_data
|
||||
).join(
|
||||
Conversation, Message.conversation_id == Conversation.id
|
||||
).filter(
|
||||
and_(
|
||||
Conversation.app_id == app_id,
|
||||
Message.created_at >= start_dt,
|
||||
Message.created_at < end_dt,
|
||||
Message.meta_data.isnot(None)
|
||||
)
|
||||
).all()
|
||||
|
||||
# 按日期聚合token
|
||||
daily_tokens = {}
|
||||
for row in daily_query:
|
||||
date_str = str(row.date)
|
||||
meta = row.meta_data or {}
|
||||
|
||||
# 提取token数量(支持多种格式)
|
||||
tokens = 0
|
||||
if isinstance(meta, dict):
|
||||
# 格式1: {"usage": {"total_tokens": 100}}
|
||||
if "usage" in meta and isinstance(meta["usage"], dict):
|
||||
tokens = meta["usage"].get("total_tokens", 0)
|
||||
# 格式2: {"tokens": 100}
|
||||
elif "tokens" in meta:
|
||||
tokens = meta.get("tokens", 0)
|
||||
# 格式3: {"total_tokens": 100}
|
||||
elif "total_tokens" in meta:
|
||||
tokens = meta.get("total_tokens", 0)
|
||||
|
||||
if date_str not in daily_tokens:
|
||||
daily_tokens[date_str] = 0
|
||||
daily_tokens[date_str] += int(tokens)
|
||||
|
||||
daily_data = [{"date": date, "tokens": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
|
||||
total = sum(row["tokens"] for row in daily_data)
|
||||
|
||||
return {"daily": daily_data, "total": total}
|
||||
@@ -16,6 +16,7 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.models import AgentConfig, ModelApiKey, ModelConfig
|
||||
from app.repositories.model_repository import ModelApiKeyRepository
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||
from app.services import task_service
|
||||
@@ -724,17 +725,21 @@ class DraftRunService:
|
||||
Raises:
|
||||
BusinessException: 当没有可用的 API Key 时
|
||||
"""
|
||||
stmt = (
|
||||
select(ModelApiKey)
|
||||
.where(
|
||||
ModelApiKey.model_config_id == model_config_id,
|
||||
ModelApiKey.is_active.is_(True)
|
||||
)
|
||||
.order_by(ModelApiKey.priority.desc())
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
api_key = self.db.scalars(stmt).first()
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
|
||||
# stmt = (
|
||||
# select(ModelApiKey).join(
|
||||
# ModelConfig, ModelApiKey.model_configs
|
||||
# )
|
||||
# .where(
|
||||
# ModelConfig.id == model_config_id,
|
||||
# ModelApiKey.is_active.is_(True)
|
||||
# )
|
||||
# .order_by(ModelApiKey.priority.desc())
|
||||
# .limit(1)
|
||||
# )
|
||||
#
|
||||
# api_key = self.db.scalars(stmt).first()
|
||||
api_key = api_keys[0] if api_keys else None
|
||||
|
||||
if not api_key:
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
@@ -5,6 +5,7 @@ import uuid
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.repositories.model_repository import ModelApiKeyRepository
|
||||
from app.services.conversation_state_manager import ConversationStateManager
|
||||
from app.models import ModelConfig, AgentConfig
|
||||
from app.core.logging_config import get_business_logger
|
||||
@@ -382,11 +383,14 @@ class LLMRouter:
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.models import ModelApiKey, ModelType
|
||||
|
||||
# 获取 API Key 配置
|
||||
api_key_config = self.db.query(ModelApiKey).filter(
|
||||
ModelApiKey.model_config_id == self.routing_model_config.id,
|
||||
ModelApiKey.is_active
|
||||
).first()
|
||||
# 获取 API Key 配置(通过关联关系)
|
||||
# api_key_config = self.db.query(ModelApiKey).join(
|
||||
# ModelConfig, ModelApiKey.model_configs
|
||||
# ).filter(ModelConfig.id == self.routing_model_config.id,
|
||||
# ModelApiKey.is_active == True
|
||||
# ).first()
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, self.routing_model_config.id)
|
||||
api_key_config = api_keys[0] if api_keys else None
|
||||
|
||||
if not api_key_config:
|
||||
raise Exception("路由模型没有可用的 API Key")
|
||||
@@ -419,6 +423,9 @@ class LLMRouter:
|
||||
|
||||
# 调用模型
|
||||
response = await llm.ainvoke(prompt)
|
||||
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
|
||||
|
||||
# 提取响应内容
|
||||
if hasattr(response, 'content'):
|
||||
|
||||
@@ -338,7 +338,7 @@ class MemoryConfigService:
|
||||
"provider": api_config.provider,
|
||||
"api_key": api_config.api_key,
|
||||
"base_url": api_config.api_base,
|
||||
"model_config_id": api_config.model_config_id,
|
||||
"model_config_id": str(config.id),
|
||||
"type": config.type,
|
||||
"timeout": settings.LLM_TIMEOUT,
|
||||
"max_retries": settings.LLM_MAX_RETRIES,
|
||||
@@ -370,7 +370,7 @@ class MemoryConfigService:
|
||||
"provider": api_config.provider,
|
||||
"api_key": api_config.api_key,
|
||||
"base_url": api_config.api_base,
|
||||
"model_config_id": api_config.model_config_id,
|
||||
"model_config_id": str(config.id),
|
||||
"type": config.type,
|
||||
"timeout": 120.0,
|
||||
"max_retries": 5,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional, Dict, Any
|
||||
import uuid
|
||||
@@ -6,11 +7,11 @@ import time
|
||||
import asyncio
|
||||
|
||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType
|
||||
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository
|
||||
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository, ModelBaseRepository
|
||||
from app.schemas import model_schema
|
||||
from app.schemas.model_schema import (
|
||||
ModelConfigCreate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
||||
ModelConfigQuery, ModelStats
|
||||
ModelConfigQuery, ModelStats, ModelConfigQueryNew
|
||||
)
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
@@ -47,6 +48,26 @@ class ModelConfigService:
|
||||
items=[model_schema.ModelConfig.model_validate(model) for model in models]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_model_list_new(db: Session, query: ModelConfigQueryNew, tenant_id: uuid.UUID | None = None) -> List[dict]:
|
||||
"""获取模型配置列表"""
|
||||
provider_groups, total = ModelConfigRepository.get_list_new(db, query, tenant_id=tenant_id)
|
||||
|
||||
items = []
|
||||
for provider, models in provider_groups.items():
|
||||
# 验证每个模型并封装分组信息
|
||||
validated_models = [model_schema.ModelConfig.model_validate(model) for model in models]
|
||||
tags = list({model.type for model in validated_models})
|
||||
group_item = {
|
||||
"provider": provider, # 服务商名称
|
||||
"logo": validated_models[0].logo,
|
||||
"tags": tags,
|
||||
"models": validated_models # 该服务商下的所有模型
|
||||
}
|
||||
items.append(group_item)
|
||||
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
def get_model_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
"""根据名称获取模型配置"""
|
||||
@@ -228,37 +249,39 @@ class ModelConfigService:
|
||||
|
||||
# 验证配置
|
||||
if not model_data.skip_validation and model_data.api_keys:
|
||||
api_key_data = model_data.api_keys
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
model_name=api_key_data.model_name,
|
||||
provider=api_key_data.provider,
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
model_type=model_data.type, # 传递模型类型
|
||||
test_message="Hello"
|
||||
)
|
||||
if not validation_result["valid"]:
|
||||
raise BusinessException(
|
||||
f"模型配置验证失败: {validation_result['error']}",
|
||||
BizCode.INVALID_PARAMETER
|
||||
api_key_data_list = model_data.api_keys
|
||||
for api_key_data in api_key_data_list:
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
model_name=api_key_data.model_name,
|
||||
provider=api_key_data.provider,
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
model_type=model_data.type, # 传递模型类型
|
||||
test_message="Hello"
|
||||
)
|
||||
if not validation_result["valid"]:
|
||||
raise BusinessException(
|
||||
f"模型配置验证失败: {validation_result['error']}",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
# 事务处理
|
||||
api_key_data = model_data.api_keys
|
||||
model_config_data = model_data.dict(exclude={"api_keys", "skip_validation"})
|
||||
api_key_datas = model_data.api_keys
|
||||
model_config_data = model_data.model_dump(exclude={"api_keys", "skip_validation"})
|
||||
# 添加租户ID
|
||||
model_config_data["tenant_id"] = tenant_id
|
||||
|
||||
model = ModelConfigRepository.create(db, model_config_data)
|
||||
db.flush() # 获取生成的 ID
|
||||
|
||||
if api_key_data:
|
||||
api_key_create_schema = ModelApiKeyCreate(
|
||||
model_config_id=model.id,
|
||||
**api_key_data.dict()
|
||||
)
|
||||
ModelApiKeyRepository.create(db, api_key_create_schema)
|
||||
if api_key_datas:
|
||||
for api_key_data in api_key_datas:
|
||||
api_key_create_schema = ModelApiKeyCreate(
|
||||
model_config_ids=[model.id],
|
||||
**api_key_data.model_dump()
|
||||
)
|
||||
ModelApiKeyRepository.create(db, api_key_create_schema)
|
||||
|
||||
db.commit()
|
||||
db.refresh(model)
|
||||
@@ -280,6 +303,112 @@ class ModelConfigService:
|
||||
db.refresh(model)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
||||
"""创建组合模型"""
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
# 验证所有 API Key 存在且类型匹配
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if not api_key:
|
||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
|
||||
|
||||
# 检查 API Key 关联的模型配置类型
|
||||
for model_config in api_key.model_configs:
|
||||
# chat 和 llm 类型可以兼容
|
||||
compatible_types = {ModelType.LLM, ModelType.CHAT}
|
||||
config_type = model_config.type
|
||||
request_type = model_data.type
|
||||
|
||||
if not (config_type == request_type or
|
||||
(config_type in compatible_types and request_type in compatible_types)):
|
||||
raise BusinessException(
|
||||
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
# if model_config.is_composite:
|
||||
# raise BusinessException(
|
||||
# f"API Key {api_key_id} 关联的模型是组合模型,不能用于创建新的组合模型",
|
||||
# BizCode.INVALID_PARAMETER
|
||||
# )
|
||||
|
||||
# 创建组合模型
|
||||
model_config_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"name": model_data.name,
|
||||
"type": model_data.type,
|
||||
"logo": model_data.logo,
|
||||
"description": model_data.description,
|
||||
"provider": "composite",
|
||||
"config": model_data.config,
|
||||
"is_active": model_data.is_active,
|
||||
"is_public": model_data.is_public,
|
||||
"is_composite": True
|
||||
}
|
||||
|
||||
model = ModelConfigRepository.create(db, model_config_data)
|
||||
db.flush()
|
||||
|
||||
# 关联 API Keys
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if api_key:
|
||||
model.api_keys.append(api_key)
|
||||
|
||||
db.commit()
|
||||
db.refresh(model)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
||||
"""更新组合模型"""
|
||||
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
||||
if not existing_model:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
if not existing_model.is_composite:
|
||||
raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER)
|
||||
|
||||
# 验证所有 API Key 存在且类型匹配
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if not api_key:
|
||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
|
||||
|
||||
for model_config in api_key.model_configs:
|
||||
compatible_types = {ModelType.LLM, ModelType.CHAT}
|
||||
config_type = model_config.type
|
||||
request_type = model_data.type
|
||||
|
||||
if not (config_type == request_type or
|
||||
(config_type in compatible_types and request_type in compatible_types)):
|
||||
raise BusinessException(
|
||||
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
# 更新基本信息
|
||||
existing_model.name = model_data.name
|
||||
existing_model.type = model_data.type
|
||||
existing_model.logo = model_data.logo
|
||||
existing_model.description = model_data.description
|
||||
existing_model.config = model_data.config
|
||||
existing_model.is_active = model_data.is_active
|
||||
existing_model.is_public = model_data.is_public
|
||||
|
||||
# 更新 API Keys 关联
|
||||
existing_model.api_keys.clear()
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if api_key:
|
||||
existing_model.api_keys.append(api_key)
|
||||
|
||||
db.commit()
|
||||
db.refresh(existing_model)
|
||||
return existing_model
|
||||
|
||||
@staticmethod
|
||||
def delete_model(db: Session, model_id: uuid.UUID, tenant_id: uuid.UUID | None = None) -> bool:
|
||||
"""删除模型配置"""
|
||||
@@ -324,27 +453,132 @@ class ModelApiKeyService:
|
||||
return ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active)
|
||||
|
||||
@staticmethod
|
||||
async def create_api_key(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey:
|
||||
"""创建API Key"""
|
||||
model_config = ModelConfigRepository.get_by_id(db, api_key_data.model_config_id)
|
||||
if not model_config:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
async def create_api_key_by_provider(db: Session, data: model_schema.ModelApiKeyCreateByProvider) -> List[ModelApiKey]:
|
||||
"""根据provider为多个ModelConfig创建API Key"""
|
||||
created_keys = []
|
||||
|
||||
for model_config_id in data.model_config_ids:
|
||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||
if not model_config:
|
||||
continue
|
||||
|
||||
# 从ModelBase获取model_name
|
||||
model_name = model_config.model_base.name if model_config.model_base else model_config.name
|
||||
|
||||
# 检查是否存在API Key(包括软删除)
|
||||
existing_key = db.query(ModelApiKey).filter(
|
||||
ModelApiKey.api_key == data.api_key,
|
||||
ModelApiKey.provider == data.provider,
|
||||
ModelApiKey.model_name == model_name
|
||||
).first()
|
||||
|
||||
if existing_key:
|
||||
# 如果已存在,重新激活并更新
|
||||
if existing_key.is_active:
|
||||
continue
|
||||
existing_key.is_active = True
|
||||
existing_key.api_base = data.api_base
|
||||
existing_key.description = data.description
|
||||
existing_key.config = data.config
|
||||
existing_key.priority = data.priority
|
||||
existing_key.model_name = model_name
|
||||
|
||||
# 检查是否已关联该模型配置
|
||||
if model_config not in existing_key.model_configs:
|
||||
existing_key.model_configs.append(model_config)
|
||||
|
||||
created_keys.append(existing_key)
|
||||
continue
|
||||
|
||||
# 验证配置
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
model_name=api_key_data.model_name,
|
||||
provider=api_key_data.provider,
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
model_type=model_config.type, # 传递模型类型
|
||||
model_name=model_name,
|
||||
provider=data.provider,
|
||||
api_key=data.api_key,
|
||||
api_base=data.api_base,
|
||||
model_type=model_config.type,
|
||||
test_message="Hello"
|
||||
)
|
||||
print(validation_result)
|
||||
if not validation_result["valid"]:
|
||||
if not validation_result["valid"]:
|
||||
raise BusinessException(
|
||||
f"模型配置验证失败: {validation_result['error']}",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
# 创建API Key
|
||||
api_key_data = ModelApiKeyCreate(
|
||||
model_config_ids=[model_config_id],
|
||||
model_name=model_name,
|
||||
description=data.description,
|
||||
provider=data.provider,
|
||||
api_key=data.api_key,
|
||||
api_base=data.api_base,
|
||||
config=data.config,
|
||||
is_active=data.is_active,
|
||||
priority=data.priority
|
||||
)
|
||||
api_key_obj = ModelApiKeyRepository.create(db, api_key_data)
|
||||
created_keys.append(api_key_obj)
|
||||
|
||||
if created_keys:
|
||||
db.commit()
|
||||
for key in created_keys:
|
||||
db.refresh(key)
|
||||
|
||||
return created_keys
|
||||
|
||||
@staticmethod
|
||||
async def create_api_key(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey:
|
||||
# 验证所有关联的模型配置是否存在
|
||||
if api_key_data.model_config_ids:
|
||||
for model_config_id in api_key_data.model_config_ids:
|
||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||
if not model_config:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
# 检查API Key是否已存在(包括软删除)
|
||||
existing_key = db.query(ModelApiKey).filter(
|
||||
ModelApiKey.api_key == api_key_data.api_key,
|
||||
ModelApiKey.provider == api_key_data.provider,
|
||||
ModelApiKey.model_name == api_key_data.model_name
|
||||
).first()
|
||||
|
||||
if existing_key:
|
||||
if existing_key.is_active:
|
||||
# 如果已激活,跳过
|
||||
raise BusinessException("该API Key已存在", BizCode.DUPLICATE_NAME)
|
||||
# 如果已存在,重新激活并更新
|
||||
existing_key.is_active = True
|
||||
existing_key.api_base = api_key_data.api_base
|
||||
existing_key.description = api_key_data.description
|
||||
existing_key.config = api_key_data.config
|
||||
existing_key.priority = api_key_data.priority
|
||||
existing_key.model_name = api_key_data.model_name
|
||||
|
||||
# 检查是否已关联该模型配置
|
||||
if model_config not in existing_key.model_configs:
|
||||
existing_key.model_configs.append(model_config)
|
||||
|
||||
db.commit()
|
||||
db.refresh(existing_key)
|
||||
return existing_key
|
||||
|
||||
# 验证配置
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
model_name=api_key_data.model_name,
|
||||
provider=api_key_data.provider,
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
model_type=model_config.type,
|
||||
test_message="Hello"
|
||||
)
|
||||
if not validation_result["valid"]:
|
||||
raise BusinessException(
|
||||
f"模型配置验证失败: {validation_result['error']}",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
api_key = ModelApiKeyRepository.create(db, api_key_data)
|
||||
db.commit()
|
||||
@@ -359,21 +593,19 @@ class ModelApiKeyService:
|
||||
raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
|
||||
|
||||
# 获取关联的模型配置以获取模型类型
|
||||
model_config = ModelConfigRepository.get_by_id(db, existing_api_key.model_config_id)
|
||||
if not model_config:
|
||||
raise BusinessException("关联的模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
if existing_api_key.model_configs:
|
||||
model_config = existing_api_key.model_configs[0]
|
||||
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
model_name=api_key_data.model_name,
|
||||
provider=api_key_data.provider,
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
model_type=model_config.type, # 传递模型类型
|
||||
model_name=api_key_data.model_name or existing_api_key.model_name,
|
||||
provider=api_key_data.provider or existing_api_key.provider,
|
||||
api_key=api_key_data.api_key or existing_api_key.api_key,
|
||||
api_base=api_key_data.api_base or existing_api_key.api_base,
|
||||
model_type=model_config.type,
|
||||
test_message="Hello"
|
||||
)
|
||||
print(validation_result)
|
||||
if not validation_result["valid"]:
|
||||
if not validation_result["valid"]:
|
||||
raise BusinessException(
|
||||
f"模型配置验证失败: {validation_result['error']}",
|
||||
BizCode.INVALID_PARAMETER
|
||||
@@ -417,3 +649,84 @@ class ModelApiKeyService:
|
||||
if api_kes and len(api_kes) > 0:
|
||||
return api_kes[0]
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
|
||||
|
||||
class ModelBaseService:
|
||||
"""基础模型服务"""
|
||||
|
||||
@staticmethod
|
||||
def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List:
|
||||
models = ModelBaseRepository.get_list(db, query)
|
||||
|
||||
provider_groups = {}
|
||||
for m in models:
|
||||
model_dict = model_schema.ModelBase.model_validate(m).model_dump()
|
||||
if tenant_id:
|
||||
model_dict['is_added'] = ModelBaseRepository.check_added_by_tenant(db, m.id, tenant_id)
|
||||
|
||||
provider = m.provider
|
||||
if provider not in provider_groups:
|
||||
provider_groups[provider] = {
|
||||
"provider": provider,
|
||||
"models": []
|
||||
}
|
||||
provider_groups[provider]["models"].append(model_dict)
|
||||
|
||||
return list(provider_groups.values())
|
||||
|
||||
@staticmethod
|
||||
def get_model_base_by_id(db: Session, model_base_id: uuid.UUID):
|
||||
model = ModelBaseRepository.get_by_id(db, model_base_id)
|
||||
if not model:
|
||||
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def create_model_base(db: Session, data: model_schema.ModelBaseCreate):
|
||||
model_base = ModelBaseRepository.create(db, data.model_dump())
|
||||
db.commit()
|
||||
db.refresh(model_base)
|
||||
return model_base
|
||||
|
||||
@staticmethod
|
||||
def update_model_base(db: Session, model_base_id: uuid.UUID, data: model_schema.ModelBaseUpdate):
|
||||
model_base = ModelBaseRepository.update(db, model_base_id, data.model_dump(exclude_unset=True))
|
||||
if not model_base:
|
||||
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
|
||||
db.commit()
|
||||
db.refresh(model_base)
|
||||
return model_base
|
||||
|
||||
@staticmethod
|
||||
def delete_model_base(db: Session, model_base_id: uuid.UUID) -> bool:
|
||||
success = ModelBaseRepository.delete(db, model_base_id)
|
||||
if not success:
|
||||
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
|
||||
db.commit()
|
||||
return success
|
||||
|
||||
@staticmethod
|
||||
def add_model_from_plaza(db: Session, model_base_id: uuid.UUID, tenant_id: uuid.UUID) -> ModelConfig:
|
||||
model_base = ModelBaseRepository.get_by_id(db, model_base_id)
|
||||
if not model_base:
|
||||
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
if ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id):
|
||||
raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME)
|
||||
|
||||
model_config_data = {
|
||||
"model_id": model_base_id,
|
||||
"tenant_id": tenant_id,
|
||||
"name": model_base.name,
|
||||
"provider": model_base.provider,
|
||||
"type": model_base.type,
|
||||
"logo": model_base.logo,
|
||||
"description": model_base.description,
|
||||
"is_composite": False
|
||||
}
|
||||
model_config = ModelConfigRepository.create(db, model_config_data)
|
||||
ModelBaseRepository.increment_add_count(db, model_base_id)
|
||||
db.commit()
|
||||
db.refresh(model_config)
|
||||
return model_config
|
||||
|
||||
@@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import MultiAgentConfig, AgentConfig, ModelConfig
|
||||
from app.models.multi_agent_model import AggregationStrategy, OrchestrationMode
|
||||
from app.repositories.model_repository import ModelApiKeyRepository
|
||||
from app.services.agent_registry import AgentRegistry
|
||||
from app.services.master_agent_router import MasterAgentRouter
|
||||
from app.services.conversation_state_manager import ConversationStateManager
|
||||
@@ -2546,10 +2547,14 @@ class MultiAgentOrchestrator:
|
||||
return self._smart_merge_results(results, strategy)
|
||||
|
||||
# 获取 API Key 配置
|
||||
api_key_config = self.db.query(ModelApiKey).filter(
|
||||
ModelApiKey.model_config_id == default_model_config_id,
|
||||
ModelApiKey.is_active.is_(True)
|
||||
).first()
|
||||
# api_key_config = self.db.query(ModelApiKey).join(
|
||||
# ModelConfig, ModelApiKey.model_configs
|
||||
# ).filter(
|
||||
# ModelConfig.id == default_model_config_id,
|
||||
# ModelApiKey.is_active.is_(True)
|
||||
# ).first()
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id)
|
||||
api_key_config = api_keys[0] if api_keys else None
|
||||
|
||||
if not api_key_config:
|
||||
logger.warning("Master Agent 没有可用的 API Key,使用简单整合")
|
||||
@@ -2703,10 +2708,14 @@ class MultiAgentOrchestrator:
|
||||
return
|
||||
|
||||
# 获取 API Key 配置
|
||||
api_key_config = self.db.query(ModelApiKey).filter(
|
||||
ModelApiKey.model_config_id == default_model_config_id,
|
||||
ModelApiKey.is_active.is_(True)
|
||||
).first()
|
||||
# api_key_config = self.db.query(ModelApiKey).join(
|
||||
# ModelConfig, ModelApiKey.model_configs
|
||||
# ).filter(
|
||||
# ModelConfig.id == default_model_config_id,
|
||||
# ModelApiKey.is_active.is_(True)
|
||||
# ).first()
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id)
|
||||
api_key_config = api_keys[0] if api_keys else None
|
||||
|
||||
if not api_key_config:
|
||||
logger.warning("Master Agent 没有可用的 API Key,使用简单整合")
|
||||
|
||||
@@ -4,6 +4,8 @@ import time
|
||||
import asyncio
|
||||
from typing import Optional, Dict, Any, AsyncGenerator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.repositories.model_repository import ModelApiKeyRepository
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.models import ReleaseShare, AppRelease, Conversation
|
||||
from app.services.conversation_service import ConversationService
|
||||
@@ -164,16 +166,20 @@ class SharedChatService:
|
||||
raise ResourceNotFoundException("模型配置", str(model_config_id))
|
||||
|
||||
# 获取 API Key
|
||||
stmt = (
|
||||
select(ModelApiKey)
|
||||
.where(
|
||||
ModelApiKey.model_config_id == model_config_id,
|
||||
ModelApiKey.is_active.is_(True)
|
||||
)
|
||||
.order_by(ModelApiKey.priority.desc())
|
||||
.limit(1)
|
||||
)
|
||||
api_key_obj = self.db.scalars(stmt).first()
|
||||
# stmt = (
|
||||
# select(ModelApiKey).join(
|
||||
# ModelConfig, ModelApiKey.model_configs
|
||||
# )
|
||||
# .where(
|
||||
# ModelConfig.id == model_config_id,
|
||||
# ModelApiKey.is_active.is_(True)
|
||||
# )
|
||||
# .order_by(ModelApiKey.priority.desc())
|
||||
# .limit(1)
|
||||
# )
|
||||
# api_key_obj = self.db.scalars(stmt).first()
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
|
||||
api_key_obj = api_keys[0] if api_keys else None
|
||||
if not api_key_obj:
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
@@ -358,16 +364,20 @@ class SharedChatService:
|
||||
raise ResourceNotFoundException("模型配置", str(model_config_id))
|
||||
|
||||
# 获取 API Key
|
||||
stmt = (
|
||||
select(ModelApiKey)
|
||||
.where(
|
||||
ModelApiKey.model_config_id == model_config_id,
|
||||
ModelApiKey.is_active.is_(True)
|
||||
)
|
||||
.order_by(ModelApiKey.priority.desc())
|
||||
.limit(1)
|
||||
)
|
||||
api_key_obj = self.db.scalars(stmt).first()
|
||||
# stmt = (
|
||||
# select(ModelApiKey).join(
|
||||
# ModelConfig, ModelApiKey.model_configs
|
||||
# )
|
||||
# .where(
|
||||
# ModelConfig.id == model_config_id,
|
||||
# ModelApiKey.is_active.is_(True)
|
||||
# )
|
||||
# .order_by(ModelApiKey.priority.desc())
|
||||
# .limit(1)
|
||||
# )
|
||||
# api_key_obj = self.db.scalars(stmt).first()
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
|
||||
api_key_obj = api_keys[0] if api_keys else None
|
||||
if not api_key_obj:
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user