Merge branch 'refs/heads/develop' into fix/memory_bug_fix
This commit is contained in:
@@ -872,3 +872,44 @@ async def update_workflow_config(
|
||||
workspace_id = current_user.current_workspace_id
|
||||
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||
|
||||
|
||||
@router.get("/{app_id}/statistics", summary="应用统计数据")
|
||||
@cur_workspace_access_guard()
|
||||
def get_app_statistics(
|
||||
app_id: uuid.UUID,
|
||||
start_date: int,
|
||||
end_date: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""获取应用统计数据
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
start_date: 开始时间戳(毫秒)
|
||||
end_date: 结束时间戳(毫秒)
|
||||
|
||||
Returns:
|
||||
- daily_conversations: 每日会话数统计
|
||||
- total_conversations: 总会话数
|
||||
- daily_new_users: 每日新增用户数
|
||||
- total_new_users: 总新增用户数
|
||||
- daily_api_calls: 每日API调用次数
|
||||
- total_api_calls: 总API调用次数
|
||||
- daily_tokens: 每日token消耗
|
||||
- total_tokens: 总token消耗
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
stats_service = AppStatisticsService(db)
|
||||
|
||||
result = stats_service.get_app_statistics(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
return success(data=result)
|
||||
|
||||
@@ -3,15 +3,17 @@ from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.models_model import ModelProvider, ModelType
|
||||
from app.models.user_model import User
|
||||
from app.repositories.model_repository import ModelConfigRepository
|
||||
from app.schemas import model_schema
|
||||
from app.core.response_utils import success
|
||||
from app.schemas.response_schema import ApiResponse, PageData
|
||||
from app.services.model_service import ModelConfigService, ModelApiKeyService
|
||||
from app.services.model_service import ModelConfigService, ModelApiKeyService, ModelBaseService
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# 获取API专用日志器
|
||||
@@ -24,7 +26,6 @@ router = APIRouter(
|
||||
|
||||
@router.get("/type", response_model=ApiResponse)
|
||||
def get_model_types():
|
||||
|
||||
return success(msg="获取模型类型成功", data=list(ModelType))
|
||||
|
||||
|
||||
@@ -35,13 +36,68 @@ def get_model_providers():
|
||||
|
||||
@router.get("", response_model=ApiResponse)
|
||||
def get_model_list(
|
||||
type: Optional[str] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取模型配置列表
|
||||
|
||||
支持多个 type 参数:
|
||||
- 单个:?type=LLM
|
||||
- 多个(逗号分隔):?type=LLM,EMBEDDING
|
||||
- 多个(重复参数):?type=LLM&type=EMBEDDING
|
||||
"""
|
||||
api_logger.info(
|
||||
f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
# 解析 type 参数(支持逗号分隔)
|
||||
type_list = []
|
||||
if type is not None:
|
||||
flat_type = []
|
||||
for item in type:
|
||||
split_items = [t.strip() for t in item.split(',') if t.strip()]
|
||||
flat_type.extend(split_items)
|
||||
|
||||
unique_flat_type = list(dict.fromkeys(flat_type))
|
||||
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
||||
|
||||
api_logger.error(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQuery(
|
||||
type=type_list,
|
||||
provider=provider,
|
||||
is_active=is_active,
|
||||
is_public=is_public,
|
||||
search=search,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
api_logger.debug(f"开始获取模型配置列表: {query.dict()}")
|
||||
result_orm = ModelConfigService.get_model_list(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||
result = PageData.model_validate(result_orm)
|
||||
api_logger.info(f"模型配置列表获取成功: 总数={result.page.total}, 当前页={len(result.items)}")
|
||||
return success(data=result, msg="模型配置列表获取成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取模型配置列表失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/new", response_model=ApiResponse)
|
||||
def get_model_list(
|
||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于ModelConfig)"),
|
||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
is_composite: Optional[bool] = Query(None, description="组合模型筛选"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
@@ -53,36 +109,123 @@ def get_model_list(
|
||||
- 多个(逗号分隔):?type=LLM,EMBEDDING
|
||||
- 多个(重复参数):?type=LLM&type=EMBEDDING
|
||||
"""
|
||||
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
|
||||
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
# 解析 type 参数(支持逗号分隔)
|
||||
type_list = None
|
||||
if type:
|
||||
type_values = [t.strip() for t in type.split(',')]
|
||||
type_list = [model_schema.ModelType(t.lower()) for t in type_values if t]
|
||||
type_list = []
|
||||
if type is not None:
|
||||
flat_type = []
|
||||
for item in type:
|
||||
split_items = [t.strip() for t in item.split(',') if t.strip()]
|
||||
flat_type.extend(split_items)
|
||||
|
||||
unique_flat_type = list(dict.fromkeys(flat_type))
|
||||
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
||||
|
||||
api_logger.error(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQuery(
|
||||
api_logger.info(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQueryNew(
|
||||
type=type_list,
|
||||
provider=provider,
|
||||
is_active=is_active,
|
||||
is_public=is_public,
|
||||
search=search,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
is_composite=is_composite,
|
||||
search=search
|
||||
)
|
||||
|
||||
api_logger.debug(f"开始获取模型配置列表: {query.dict()}")
|
||||
result_orm = ModelConfigService.get_model_list(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||
result = PageData.model_validate(result_orm)
|
||||
api_logger.info(f"模型配置列表获取成功: 总数={result.page.total}, 当前页={len(result.items)}")
|
||||
api_logger.debug(f"开始获取模型配置列表: {query.model_dump()}")
|
||||
result = ModelConfigService.get_model_list_new(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"模型配置列表获取成功: 分组数={len(result)}, 总模型数={sum(len(item['models']) for item in result)}")
|
||||
return success(data=result, msg="模型配置列表获取成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取模型配置列表失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/model_plaza", response_model=ApiResponse)
|
||||
def get_model_plaza_list(
|
||||
type: Optional[ModelType] = Query(None, description="模型类型"),
|
||||
provider: Optional[ModelProvider] = Query(None, description="供应商"),
|
||||
is_official: Optional[bool] = Query(None, description="是否官方模型"),
|
||||
is_deprecated: Optional[bool] = Query(False, description="是否弃用"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""模型广场查询接口(按供应商分组)"""
|
||||
|
||||
query = model_schema.ModelBaseQuery(
|
||||
type=type,
|
||||
provider=provider,
|
||||
is_official=is_official,
|
||||
is_deprecated=is_deprecated,
|
||||
search=search
|
||||
)
|
||||
result = ModelBaseService.get_model_base_list(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||
return success(data=result, msg="模型广场列表获取成功")
|
||||
|
||||
|
||||
@router.get("/model_plaza/{model_base_id}", response_model=ApiResponse)
|
||||
def get_model_base_by_id(
|
||||
model_base_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取基础模型详情"""
|
||||
|
||||
result = ModelBaseService.get_model_base_by_id(db=db, model_base_id=model_base_id)
|
||||
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型获取成功")
|
||||
|
||||
|
||||
@router.post("/model_plaza", response_model=ApiResponse)
|
||||
def create_model_base(
|
||||
data: model_schema.ModelBaseCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""创建基础模型"""
|
||||
|
||||
result = ModelBaseService.create_model_base(db=db, data=data)
|
||||
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型创建成功")
|
||||
|
||||
|
||||
@router.put("/model_plaza/{model_base_id}", response_model=ApiResponse)
|
||||
def update_model_base(
|
||||
model_base_id: uuid.UUID,
|
||||
data: model_schema.ModelBaseUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新基础模型"""
|
||||
|
||||
result = ModelBaseService.update_model_base(db=db, model_base_id=model_base_id, data=data)
|
||||
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型更新成功")
|
||||
|
||||
|
||||
@router.delete("/model_plaza/{model_base_id}", response_model=ApiResponse)
|
||||
def delete_model_base(
|
||||
model_base_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除基础模型"""
|
||||
|
||||
ModelBaseService.delete_model_base(db=db, model_base_id=model_base_id)
|
||||
return success(msg="基础模型删除成功")
|
||||
|
||||
|
||||
@router.post("/model_plaza/{model_base_id}/add", response_model=ApiResponse)
|
||||
def add_model_from_plaza(
|
||||
model_base_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""从模型广场添加模型到模型列表"""
|
||||
|
||||
result = ModelBaseService.add_model_from_plaza(db=db, model_base_id=model_base_id, tenant_id=current_user.tenant_id)
|
||||
return success(data=model_schema.ModelConfig.model_validate(result), msg="模型添加成功")
|
||||
|
||||
|
||||
@router.get("/{model_id}", response_model=ApiResponse)
|
||||
def get_model_by_id(
|
||||
model_id: uuid.UUID,
|
||||
@@ -138,6 +281,71 @@ async def create_model(
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/composite", response_model=ApiResponse)
|
||||
async def create_composite_model(
|
||||
model_data: model_schema.CompositeModelCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
创建组合模型
|
||||
|
||||
- 绑定一个或多个现有的 API Key
|
||||
- 所有 API Key 必须来自非组合模型
|
||||
- 所有 API Key 关联的模型类型必须与组合模型类型一致
|
||||
"""
|
||||
api_logger.info(f"创建组合模型请求: {model_data.name}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
result_orm = await ModelConfigService.create_composite_model(db=db, model_data=model_data, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"组合模型创建成功: {result_orm.name} (ID: {result_orm.id})")
|
||||
|
||||
result = model_schema.ModelConfig.model_validate(result_orm)
|
||||
return success(data=result, msg="组合模型创建成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"创建组合模型失败: {model_data.name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/composite/{model_id}", response_model=ApiResponse)
|
||||
async def update_composite_model(
|
||||
model_id: uuid.UUID,
|
||||
model_data: model_schema.CompositeModelCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新组合模型"""
|
||||
api_logger.info(f"更新组合模型请求: model_id={model_id}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
result_orm = await ModelConfigService.update_composite_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"组合模型更新成功: {result_orm.name} (ID: {model_id})")
|
||||
|
||||
result = model_schema.ModelConfig.model_validate(result_orm)
|
||||
return success(data=result, msg="组合模型更新成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"更新组合模型失败: model_id={model_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.delete("/composite/{model_id}", response_model=ApiResponse)
|
||||
def delete_composite_model(
|
||||
model_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除组合模型"""
|
||||
api_logger.info(f"删除组合模型请求: model_id={model_id}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
ModelConfigService.delete_model(db=db, model_id=model_id, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"组合模型删除成功: model_id={model_id}")
|
||||
return success(msg="组合模型删除成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"删除组合模型失败: model_id={model_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/{model_id}", response_model=ApiResponse)
|
||||
def update_model(
|
||||
model_id: uuid.UUID,
|
||||
@@ -214,6 +422,51 @@ def get_model_api_keys(
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/provider/apikeys", response_model=ApiResponse)
|
||||
async def create_model_api_key_by_provider(
|
||||
api_key_data: model_schema.ModelApiKeyCreateByProvider,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
根据供应商为所有匹配的模型创建API Key
|
||||
"""
|
||||
api_logger.info(f"创建API Key请求: provider={api_key_data.provider}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 根据tenant_id和provider筛选model_config_id列表
|
||||
model_config_ids = api_key_data.model_config_ids
|
||||
if not model_config_ids:
|
||||
model_config_ids = ModelConfigRepository.get_model_config_ids_by_provider(
|
||||
db=db,
|
||||
tenant_id=current_user.tenant_id,
|
||||
provider=api_key_data.provider
|
||||
)
|
||||
|
||||
if not model_config_ids:
|
||||
raise BusinessException(f"未找到供应商 {api_key_data.provider} 的模型配置", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
# 构造schema并调用service
|
||||
create_data = model_schema.ModelApiKeyCreateByProvider(
|
||||
provider=api_key_data.provider,
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
description=api_key_data.description,
|
||||
config=api_key_data.config,
|
||||
is_active=api_key_data.is_active,
|
||||
priority=api_key_data.priority,
|
||||
model_config_ids=model_config_ids
|
||||
)
|
||||
created_keys = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)
|
||||
|
||||
api_logger.info(f"API Key创建成功: 关联{len(created_keys)}个模型")
|
||||
result_list = [model_schema.ModelApiKey.model_validate(key) for key in created_keys]
|
||||
return success(data=result_list, msg=f"成功为 {len(created_keys)} 个模型创建API Key")
|
||||
except Exception as e:
|
||||
api_logger.error(f"创建API Key失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/{model_id}/apikeys", response_model=ApiResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_model_api_key(
|
||||
model_id: uuid.UUID,
|
||||
@@ -228,11 +481,12 @@ async def create_model_api_key(
|
||||
|
||||
try:
|
||||
# 设置模型配置ID
|
||||
api_key_data.model_config_id = model_id
|
||||
api_key_data.model_config_ids = [model_id]
|
||||
|
||||
api_logger.debug(f"开始创建模型API Key: {api_key_data.model_name}")
|
||||
result = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data)
|
||||
api_logger.info(f"模型API Key创建成功: {result.model_name} (ID: {result.id})")
|
||||
result_orm = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data)
|
||||
api_logger.info(f"模型API Key创建成功: {result_orm.model_name} (ID: {result_orm.id})")
|
||||
result = model_schema.ModelApiKey.model_validate(result_orm)
|
||||
return success(data=result, msg="模型API Key创建成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"创建模型API Key失败: {api_key_data.model_name} - {str(e)}")
|
||||
@@ -334,5 +588,3 @@ async def validate_model_config(
|
||||
return success(data=model_schema.ModelValidateResponse(**result), msg="验证完成")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ from app.core.workflow.graph_builder import GraphBuilder, StreamOutputConfig
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.base_config import VariableType
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.template_renderer import render_template
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -157,12 +156,137 @@ class WorkflowExecutor:
|
||||
"error": result.get("error"),
|
||||
}
|
||||
|
||||
def _update_end_activate(self, node_id):
|
||||
def _update_scope_activate(self, scope, status=None):
|
||||
"""
|
||||
Update the activation state of all End nodes based on a completed scope (node or variable).
|
||||
|
||||
Iterates over all End nodes in `self.end_outputs` and calls
|
||||
`update_activate` on each, which may:
|
||||
- Activate variable segments that depend on the completed node/scope.
|
||||
- Activate the entire End node output if all control conditions are met.
|
||||
|
||||
If any End node becomes active and `self.activate_end` is not yet set,
|
||||
this node will be marked as the currently active End node.
|
||||
|
||||
Args:
|
||||
scope (str): The node ID or scope that has completed execution.
|
||||
status (str | None): Optional status of the node (used for branch/control nodes).
|
||||
"""
|
||||
for node in self.end_outputs.keys():
|
||||
self.end_outputs[node].update_activate(node_id)
|
||||
self.end_outputs[node].update_activate(scope, status)
|
||||
if self.end_outputs[node].activate and self.activate_end is None:
|
||||
self.activate_end = node
|
||||
|
||||
def _update_stream_output_status(self, activate, data):
|
||||
"""
|
||||
Update the stream output state of End nodes based on workflow state updates.
|
||||
|
||||
This method checks which nodes/scopes are activated and propagates
|
||||
activation to End nodes accordingly.
|
||||
|
||||
Args:
|
||||
activate (dict): Mapping of node_id -> bool indicating which nodes/scopes are activated.
|
||||
data (dict): Mapping of node_id -> node runtime data, including outputs.
|
||||
|
||||
Behavior:
|
||||
For each node in `data`:
|
||||
1. If the node is activated (`activate[node_id]` is True),
|
||||
retrieve its output status from `runtime_vars`.
|
||||
2. Call `_update_scope_activate` to propagate the activation
|
||||
to all relevant End nodes and update `self.activate_end`.
|
||||
"""
|
||||
for node_id in data.keys():
|
||||
if activate.get(node_id):
|
||||
node_output_status = (
|
||||
data[node_id]
|
||||
.get('runtime_vars', {})
|
||||
.get(node_id)
|
||||
.get("output")
|
||||
)
|
||||
self._update_scope_activate(node_id, status=node_output_status)
|
||||
|
||||
async def _emit_active_chunks(
|
||||
self,
|
||||
node_outputs: dict,
|
||||
variables: dict,
|
||||
force=False
|
||||
):
|
||||
"""
|
||||
Process and yield all currently active output segments for the currently active End node.
|
||||
|
||||
This method handles stream-mode output for an End node by iterating through its output segments
|
||||
(`OutputContent`). Only segments marked as active (`activate=True`) are processed, unless
|
||||
`force=True`, which allows all segments to be processed regardless of their activation state.
|
||||
|
||||
Behavior:
|
||||
1. Iterates from the current `cursor` position to the end of the outputs list.
|
||||
2. For each segment:
|
||||
- If the segment is literal text (`is_variable=False`), append it directly.
|
||||
- If the segment is a variable (`is_variable=True`), evaluate it using
|
||||
`evaluate_expression` with the given `node_outputs` and `variables`,
|
||||
then transform the result with `_trans_output_string`.
|
||||
3. Yield a stream event of type "message" containing the processed chunk.
|
||||
4. Move the `cursor` forward after processing each segment.
|
||||
5. When all segments have been processed, remove this End node from `end_outputs`
|
||||
and reset `activate_end` to None.
|
||||
|
||||
Args:
|
||||
node_outputs (dict): Current runtime node outputs, used for variable evaluation.
|
||||
variables (dict): Current runtime variables, used for variable evaluation.
|
||||
force (bool, default=False): If True, process segments even if `activate=False`.
|
||||
|
||||
Yields:
|
||||
dict: A stream event of type "message" containing the processed chunk.
|
||||
|
||||
Notes:
|
||||
- Segments that fail evaluation (ValueError) are skipped with a warning logged.
|
||||
- This method only processes the currently active End node (`self.activate_end`).
|
||||
- Use `force=True` for final emission regardless of activation state.
|
||||
"""
|
||||
|
||||
end_info = self.end_outputs[self.activate_end]
|
||||
|
||||
while end_info.cursor < len(end_info.outputs):
|
||||
final_chunk = ''
|
||||
current_segment = end_info.outputs[end_info.cursor]
|
||||
|
||||
if not current_segment.activate and not force:
|
||||
# Stop processing until this segment becomes active
|
||||
break
|
||||
|
||||
# Literal segment
|
||||
if not current_segment.is_variable:
|
||||
final_chunk += current_segment.literal
|
||||
else:
|
||||
# Variable segment: evaluate and transform
|
||||
try:
|
||||
chunk = evaluate_expression(
|
||||
current_segment.literal,
|
||||
variables=variables,
|
||||
node_outputs=node_outputs
|
||||
)
|
||||
chunk = self._trans_output_string(chunk)
|
||||
final_chunk += chunk
|
||||
except ValueError:
|
||||
# Log failed evaluation but continue streaming
|
||||
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}")
|
||||
|
||||
if final_chunk:
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": final_chunk
|
||||
}
|
||||
}
|
||||
|
||||
# Advance cursor after processing
|
||||
end_info.cursor += 1
|
||||
|
||||
# Remove End node from active tracking if all segments have been processed
|
||||
if end_info.cursor >= len(end_info.outputs):
|
||||
self.end_outputs.pop(self.activate_end)
|
||||
self.activate_end = None
|
||||
|
||||
@staticmethod
|
||||
def _trans_output_string(content):
|
||||
if isinstance(content, str):
|
||||
@@ -218,14 +342,8 @@ class WorkflowExecutor:
|
||||
|
||||
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
|
||||
full_content = ''
|
||||
for end_info in self.end_outputs.values():
|
||||
output_template = "".join([output.literal for output in end_info.outputs])
|
||||
full_content += render_template(
|
||||
output_template,
|
||||
result.get("variables", {}),
|
||||
result.get("runtime_vars", {}),
|
||||
strict=False
|
||||
)
|
||||
for end_id in self.end_outputs.keys():
|
||||
full_content += result.get('runtime_vars', {}).get(end_id, {}).get('output', '')
|
||||
result["messages"].extend(
|
||||
[
|
||||
{
|
||||
@@ -306,7 +424,7 @@ class WorkflowExecutor:
|
||||
try:
|
||||
chunk_count = 0
|
||||
full_content = ''
|
||||
|
||||
self._update_scope_activate("sys")
|
||||
async for event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
|
||||
@@ -333,9 +451,12 @@ class WorkflowExecutor:
|
||||
if not end_info or end_info.cursor >= len(end_info.outputs):
|
||||
continue
|
||||
current_output = end_info.outputs[end_info.cursor]
|
||||
if current_output.is_variable and current_output.depends_on_node(node_id):
|
||||
if current_output.is_variable and current_output.depends_on_scope(node_id):
|
||||
if data.get("done"):
|
||||
end_info.cursor += 1
|
||||
if end_info.cursor >= len(end_info.outputs):
|
||||
self.end_outputs.pop(self.activate_end)
|
||||
self.activate_end = None
|
||||
else:
|
||||
full_content += data.get("chunk")
|
||||
yield {
|
||||
@@ -415,91 +536,53 @@ class WorkflowExecutor:
|
||||
|
||||
elif mode == "updates":
|
||||
# Handle state updates - store final state
|
||||
for node_id in data.keys():
|
||||
self._update_end_activate(node_id)
|
||||
wait = False
|
||||
state = graph.get_state(config=self.checkpoint_config)
|
||||
node_outputs = state.values.get("runtime_vars", {})
|
||||
for _ in data.keys():
|
||||
node_outputs = node_outputs | data.get(_).get("runtime_vars", {})
|
||||
state = graph.get_state(config=self.checkpoint_config).values
|
||||
node_outputs = state.get("runtime_vars", {})
|
||||
variables = state.get("variables", {})
|
||||
activate = state.get("activate", {})
|
||||
for _, node_data in data.items():
|
||||
node_outputs |= node_data.get("runtime_vars", {})
|
||||
variables |= node_data.get("variables", {})
|
||||
|
||||
self._update_stream_output_status(activate, data)
|
||||
wait = False
|
||||
while self.activate_end and not wait:
|
||||
message = ''
|
||||
logger.info(self.activate_end)
|
||||
end_info = self.end_outputs[self.activate_end]
|
||||
content = end_info.outputs[end_info.cursor]
|
||||
while content.activate:
|
||||
if not content.is_variable:
|
||||
full_content += content.literal
|
||||
message += content.literal
|
||||
else:
|
||||
try:
|
||||
chunk = evaluate_expression(
|
||||
content.literal,
|
||||
variables={},
|
||||
node_outputs=node_outputs
|
||||
)
|
||||
chunk = self._trans_output_string(chunk)
|
||||
message += chunk
|
||||
full_content += chunk
|
||||
except ValueError:
|
||||
pass
|
||||
end_info.cursor += 1
|
||||
if end_info.cursor == len(end_info.outputs):
|
||||
break
|
||||
content = end_info.outputs[end_info.cursor]
|
||||
if end_info.cursor != len(end_info.outputs):
|
||||
async for msg_event in self._emit_active_chunks(
|
||||
node_outputs=node_outputs,
|
||||
variables=variables
|
||||
):
|
||||
full_content += msg_event["data"]['chunk']
|
||||
yield msg_event
|
||||
|
||||
if self.activate_end:
|
||||
wait = True
|
||||
else:
|
||||
self.end_outputs.pop(self.activate_end)
|
||||
self.activate_end = None
|
||||
for node_id in data.keys():
|
||||
self._update_end_activate(node_id)
|
||||
if message:
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": message
|
||||
}
|
||||
}
|
||||
self._update_stream_output_status(activate, data)
|
||||
|
||||
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
|
||||
f"- execution_id: {self.execution_id}")
|
||||
|
||||
result = graph.get_state(self.checkpoint_config).values
|
||||
while self.activate_end:
|
||||
message = ''
|
||||
end_info = self.end_outputs[self.activate_end]
|
||||
content = end_info.outputs[end_info.cursor]
|
||||
if not content.is_variable:
|
||||
message += content.literal
|
||||
else:
|
||||
node_outputs = result.get("runtime_vars", {})
|
||||
variables = result.get("variables", {})
|
||||
try:
|
||||
chunk = evaluate_expression(
|
||||
content.literal,
|
||||
node_outputs = result.get("runtime_vars", {})
|
||||
variables = result.get("variables", {})
|
||||
self.end_outputs = {
|
||||
node_id: node_info
|
||||
for node_id, node_info in self.end_outputs.items()
|
||||
if node_info.activate
|
||||
}
|
||||
|
||||
if self.end_outputs or self.activate_end:
|
||||
while self.activate_end:
|
||||
async for msg_event in self._emit_active_chunks(
|
||||
node_outputs=node_outputs,
|
||||
variables=variables,
|
||||
node_outputs=node_outputs
|
||||
)
|
||||
chunk = self._trans_output_string(chunk)
|
||||
message += chunk
|
||||
full_content += chunk
|
||||
except ValueError:
|
||||
pass
|
||||
end_info.cursor += 1
|
||||
if end_info.cursor == len(end_info.outputs):
|
||||
self.end_outputs.pop(self.activate_end)
|
||||
self.activate_end = None
|
||||
if self.end_outputs:
|
||||
force=True
|
||||
):
|
||||
full_content += msg_event["data"]['chunk']
|
||||
yield msg_event
|
||||
|
||||
if not self.activate_end and self.end_outputs:
|
||||
self.activate_end = list(self.end_outputs.keys())[0]
|
||||
if message:
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": message
|
||||
}
|
||||
}
|
||||
|
||||
# 计算耗时
|
||||
end_time = datetime.datetime.now()
|
||||
|
||||
@@ -53,114 +53,110 @@ class OutputContent(BaseModel):
|
||||
)
|
||||
)
|
||||
|
||||
def depends_on_node(self, node_id: str) -> bool:
|
||||
def depends_on_scope(self, scope: str) -> bool:
|
||||
"""
|
||||
Check if this output segment depends on a specific node's variable.
|
||||
|
||||
This method examines the `literal` of the output segment to see if it
|
||||
contains a variable placeholder referencing the given node in the form:
|
||||
|
||||
{{ node_id.field_name }}
|
||||
|
||||
It uses a regular expression to match the exact node ID, avoiding
|
||||
false positives from substring matches (e.g., 'node1' should not match 'node10').
|
||||
Check if this segment depends on a given scope.
|
||||
|
||||
Args:
|
||||
node_id (str): The ID of the node to check for in this segment's variable placeholders.
|
||||
scope (str): Node ID or special variable prefix (e.g., "sys").
|
||||
|
||||
Returns:
|
||||
bool:
|
||||
- True if the segment contains a variable referencing the given node.
|
||||
- False otherwise.
|
||||
|
||||
Example:
|
||||
literal = "{{node1.name}}"
|
||||
|
||||
depends_on_node("node1") -> True
|
||||
depends_on_node("node2") -> False
|
||||
|
||||
Usage:
|
||||
This method is primarily used in stream mode to determine whether
|
||||
a particular variable output segment should be activated when a
|
||||
specific upstream node completes execution.
|
||||
bool: True if this segment references the given scope.
|
||||
"""
|
||||
variable_pattern = rf"\{{\{{\s*{re.escape(node_id)}\.[a-zA-Z0-9_]+\s*\}}\}}"
|
||||
pattern = re.compile(variable_pattern)
|
||||
match = pattern.search(self.literal)
|
||||
if match:
|
||||
return True
|
||||
return False
|
||||
pattern = rf"\{{\{{\s*{re.escape(scope)}\.[a-zA-Z0-9_]+\s*\}}\}}"
|
||||
return bool(re.search(pattern, self.literal))
|
||||
|
||||
|
||||
class StreamOutputConfig(BaseModel):
|
||||
"""
|
||||
Streaming output configuration for an End node.
|
||||
|
||||
This structure controls:
|
||||
- whether the End node output is globally active
|
||||
- which upstream branch nodes are responsible for activation
|
||||
- how each output segment behaves in streaming mode
|
||||
This configuration describes how the End node output behaves in streaming mode,
|
||||
including:
|
||||
- whether output emission is globally activated
|
||||
- which upstream branch/control nodes gate the activation
|
||||
- how each parsed output segment is streamed and activated
|
||||
"""
|
||||
|
||||
activate: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Global activation state of the End node output.\n"
|
||||
"If False, no output should be emitted until all control nodes are resolved."
|
||||
"Global activation flag for the End node output.\n"
|
||||
"When False, output segments should not be emitted even if available.\n"
|
||||
"This flag typically becomes True once required control branch conditions "
|
||||
"are satisfied."
|
||||
)
|
||||
)
|
||||
|
||||
control_nodes: list[str] = Field(
|
||||
control_nodes: dict[str, str] = Field(
|
||||
...,
|
||||
description=(
|
||||
"List of upstream branch node IDs that control this End node.\n"
|
||||
"Each node must signal completion before output becomes active."
|
||||
"Control branch conditions for this End node output.\n"
|
||||
"Mapping of `branch_node_id -> expected_branch_label`.\n"
|
||||
"The End node output becomes globally active when a controlling branch node "
|
||||
"reports a matching completion status."
|
||||
)
|
||||
)
|
||||
|
||||
outputs: list[OutputContent] = Field(
|
||||
...,
|
||||
description="Ordered list of output segments parsed from the output template."
|
||||
description=(
|
||||
"Ordered list of output segments parsed from the output template.\n"
|
||||
"Each segment represents either a literal text block or a variable placeholder "
|
||||
"that may be activated independently."
|
||||
)
|
||||
)
|
||||
|
||||
cursor: int = Field(
|
||||
...,
|
||||
description=(
|
||||
"Streaming cursor index.\n"
|
||||
"Indicates how many output segments have already been emitted."
|
||||
"Indicates the next output segment index to be emitted.\n"
|
||||
"Segments with index < cursor are considered already streamed."
|
||||
)
|
||||
)
|
||||
|
||||
def update_activate(self, node_id):
|
||||
def update_activate(self, scope: str, status=None):
|
||||
"""
|
||||
Update activation state based on an upstream node completion.
|
||||
Update streaming activation state based on an upstream node or special variable.
|
||||
|
||||
This method is typically called when a branch/control node finishes execution.
|
||||
Args:
|
||||
scope (str):
|
||||
Identifier of the completed upstream entity.
|
||||
- If a control branch node, it should match a key in `control_nodes`.
|
||||
- If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
|
||||
status (optional):
|
||||
Completion status of the control branch node.
|
||||
Required when `scope` refers to a control node.
|
||||
|
||||
Behavior:
|
||||
1. If the node is a control node:
|
||||
- Remove it from `control_nodes`
|
||||
- If all control nodes are resolved, activate the entire output
|
||||
1. Control branch nodes:
|
||||
- If `scope` matches a key in `control_nodes` and `status` matches the expected
|
||||
branch label, the End node output becomes globally active (`activate = True`).
|
||||
|
||||
2. Activate variable output segments that depend on this node:
|
||||
- If an output segment is a variable
|
||||
- And its literal references the completed node_id
|
||||
- Mark that segment as active
|
||||
2. Variable output segments:
|
||||
- For each segment that is a variable (`is_variable=True`):
|
||||
- If the segment literal references `scope`, mark the segment as active.
|
||||
- This applies both to regular node variables (e.g., "node_id.field")
|
||||
and special system variables (e.g., "sys.xxx").
|
||||
|
||||
Notes:
|
||||
- This method does not emit output or advance the streaming cursor.
|
||||
- It only updates activation flags based on upstream events or special variables.
|
||||
"""
|
||||
|
||||
# Case 1: resolve control branch dependency
|
||||
if node_id in self.control_nodes:
|
||||
self.control_nodes.remove(node_id)
|
||||
|
||||
# All branch constraints resolved → enable output
|
||||
if not self.control_nodes:
|
||||
if scope in self.control_nodes.keys():
|
||||
if status is None:
|
||||
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
||||
if status == self.control_nodes[scope]:
|
||||
self.activate = True
|
||||
|
||||
# Case 2: activate variable segments related to this node
|
||||
for i in range(len(self.outputs)):
|
||||
if (
|
||||
self.outputs[i].is_variable
|
||||
and self.outputs[i].depends_on_node(node_id)
|
||||
and self.outputs[i].depends_on_scope(scope)
|
||||
):
|
||||
self.outputs[i].activate = True
|
||||
|
||||
@@ -184,11 +180,11 @@ class GraphBuilder:
|
||||
self._find_upstream_branch_node = lru_cache(
|
||||
maxsize=len(self.nodes) * 2
|
||||
)(self._find_upstream_branch_node)
|
||||
self._analyze_end_node_output()
|
||||
|
||||
self.graph = StateGraph(WorkflowState)
|
||||
self.add_nodes()
|
||||
self.add_edges()
|
||||
self._analyze_end_node_output()
|
||||
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
||||
|
||||
@property
|
||||
@@ -216,30 +212,53 @@ class GraphBuilder:
|
||||
except KeyError:
|
||||
raise RuntimeError(f"Node not found: Id={node_id}")
|
||||
|
||||
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[str]]:
|
||||
"""Find upstream branch nodes for a given target node in the workflow graph.
|
||||
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[tuple[str, str]]]:
|
||||
"""
|
||||
Recursively find all upstream branch (control) nodes that influence the execution
|
||||
of the given target node.
|
||||
|
||||
This method identifies all upstream control (branch) nodes that can affect
|
||||
the execution of `target_node`. If `target_node` is reachable from a start
|
||||
node (i.e., a node with no upstream nodes), the method returns an empty tuple.
|
||||
This method walks upstream along the workflow graph starting from `target_node`.
|
||||
It distinguishes between:
|
||||
- branch nodes (node types listed in `BRANCH_NODES`)
|
||||
- non-branch nodes (ordinary processing nodes)
|
||||
|
||||
The function distinguishes between branch nodes (defined in `BRANCH_NODES`)
|
||||
and non-branch nodes, recursively traversing upstream through non-branch
|
||||
nodes. If any non-branch upstream path does not lead to a branch node,
|
||||
the result will indicate that no valid upstream branch node exists.
|
||||
Traversal rules:
|
||||
1. For each immediate upstream node:
|
||||
- If it is a branch node, it is recorded as an affecting control node.
|
||||
- If it is a non-branch node, the traversal continues recursively upstream.
|
||||
2. If ANY upstream path reaches a START / CYCLE_START node without encountering
|
||||
a branch node, the traversal is considered invalid:
|
||||
- `has_branch` will be False
|
||||
- no branch nodes are returned.
|
||||
3. Only when ALL upstream non-branch paths eventually lead to at least one
|
||||
branch node will `has_branch` be True.
|
||||
|
||||
Special case:
|
||||
- If `target_node` has no upstream nodes AND its type is START or CYCLE_START,
|
||||
it is considered directly reachable from the workflow entry, and therefore
|
||||
has no controlling branch nodes.
|
||||
|
||||
Args:
|
||||
target_node (str): The identifier of the target node.
|
||||
target_node (str):
|
||||
The identifier of the node whose upstream control branches
|
||||
are to be resolved.
|
||||
|
||||
Returns:
|
||||
tuple[bool, tuple[str]]:
|
||||
- has_branch (bool): True if all upstream non-branch paths lead to at least
|
||||
one branch node; False if any path reaches a start node without a branch.
|
||||
- branch_nodes (tuple[str]): A deduplicated tuple of upstream branch node IDs
|
||||
affecting `target_node`. Returns an empty tuple if `has_branch` is False.
|
||||
tuple[bool, tuple[tuple[str, str]]]:
|
||||
- has_branch (bool):
|
||||
True if every upstream path from `target_node` encounters
|
||||
at least one branch node.
|
||||
False if any path reaches a start node without a branch.
|
||||
- branch_nodes (tuple[tuple[str, str]]):
|
||||
A deduplicated tuple of `(branch_node_id, branch_label)` pairs
|
||||
representing all branch nodes that can influence `target_node`.
|
||||
Returns an empty tuple if `has_branch` is False.
|
||||
"""
|
||||
source_nodes = [
|
||||
edge.get("source")
|
||||
{
|
||||
"id": edge.get("source"),
|
||||
"branch": edge.get("label")
|
||||
}
|
||||
for edge in self.edges
|
||||
if edge.get("target") == target_node
|
||||
]
|
||||
@@ -249,11 +268,13 @@ class GraphBuilder:
|
||||
branch_nodes = []
|
||||
non_branch_nodes = []
|
||||
|
||||
for node_id in source_nodes:
|
||||
if self.get_node_type(node_id) in BRANCH_NODES:
|
||||
branch_nodes.append(node_id)
|
||||
for node_info in source_nodes:
|
||||
if self.get_node_type(node_info["id"]) in BRANCH_NODES:
|
||||
branch_nodes.append(
|
||||
(node_info["id"], node_info["branch"])
|
||||
)
|
||||
else:
|
||||
non_branch_nodes.append(node_id)
|
||||
non_branch_nodes.append(node_info["id"])
|
||||
|
||||
has_branch = True
|
||||
for node_id in non_branch_nodes:
|
||||
@@ -334,7 +355,7 @@ class GraphBuilder:
|
||||
activate=not has_branch,
|
||||
|
||||
# Branch nodes that control activation of this End node
|
||||
control_nodes=list(control_nodes),
|
||||
control_nodes=dict(control_nodes),
|
||||
|
||||
# Convert output segments into OutputContent objects
|
||||
outputs=list(
|
||||
@@ -362,7 +383,7 @@ class GraphBuilder:
|
||||
else:
|
||||
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||
activate=True,
|
||||
control_nodes=[],
|
||||
control_nodes={},
|
||||
outputs=list(
|
||||
[
|
||||
OutputContent(
|
||||
|
||||
@@ -25,6 +25,6 @@ class MemoryWriteNodeConfig(BaseNodeConfig):
|
||||
...
|
||||
)
|
||||
|
||||
config_id: UUID = Field(
|
||||
config_id: UUID | int = Field(
|
||||
...
|
||||
)
|
||||
|
||||
@@ -36,9 +36,10 @@ class MemoryReadNode(BaseNode):
|
||||
class MemoryWriteNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
||||
self.typed_config: MemoryWriteNodeConfig | None = None
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
||||
end_user_id = self.get_variable("sys.user_id", state)
|
||||
|
||||
if not end_user_id:
|
||||
|
||||
@@ -6,7 +6,7 @@ from .document_model import Document
|
||||
from .file_model import File
|
||||
from .file_metadata_model import FileMetadata
|
||||
from .generic_file_model import GenericFile
|
||||
from .models_model import ModelConfig, ModelProvider, ModelType, ModelApiKey
|
||||
from .models_model import ModelConfig, ModelProvider, ModelType, ModelApiKey, ModelBase, LoadBalanceStrategy
|
||||
from .memory_short_model import ShortTermMemory, LongTermMemory
|
||||
from .knowledgeshare_model import KnowledgeShare
|
||||
from .app_model import App
|
||||
@@ -79,4 +79,6 @@ __all__ = [
|
||||
"AuthType",
|
||||
"ExecutionStatus",
|
||||
"MemoryPerceptualModel",
|
||||
"ModelBase",
|
||||
"LoadBalanceStrategy"
|
||||
]
|
||||
|
||||
@@ -1,19 +1,31 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from enum import StrEnum
|
||||
from typing import Optional, List
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum
|
||||
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db import Base
|
||||
|
||||
|
||||
class BaseModel(Base):
|
||||
"""基础模型(抽象类,提取公共字段)"""
|
||||
__abstract__ = True # 标记为抽象类,不生成表
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
|
||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
|
||||
is_active = Column(Boolean, default=True, nullable=False, comment="是否激活")
|
||||
|
||||
|
||||
class ModelType(StrEnum):
|
||||
"""模型类型枚举"""
|
||||
LLM = "llm"
|
||||
CHAT = "chat"
|
||||
EMBEDDING = "embedding"
|
||||
RERANK = "rerank"
|
||||
# IMAGE = "image"
|
||||
# AUDIO = "audio"
|
||||
# VISION = "vision"
|
||||
|
||||
|
||||
class ModelProvider(StrEnum):
|
||||
@@ -30,16 +42,37 @@ class ModelProvider(StrEnum):
|
||||
XINFERENCE = "xinference"
|
||||
GPUSTACK = "gpustack"
|
||||
BEDROCK = "bedrock"
|
||||
COMPOSITE = "composite"
|
||||
|
||||
|
||||
class ModelConfig(Base):
|
||||
class LoadBalanceStrategy(StrEnum):
|
||||
"""API Key负载均衡策略枚举"""
|
||||
ROUND_ROBIN = "round_robin" # 轮询
|
||||
WEIGHTED_ROUND_ROBIN = "weighted_round_robin" # 加权轮询
|
||||
RANDOM = "random" # 随机
|
||||
|
||||
|
||||
# 多对多关联表
|
||||
model_config_api_key_association = Table(
|
||||
'model_config_api_key_association',
|
||||
Base.metadata,
|
||||
Column('model_config_id', UUID(as_uuid=True), ForeignKey('model_configs.id'), primary_key=True),
|
||||
Column('api_key_id', UUID(as_uuid=True), ForeignKey('model_api_keys.id'), primary_key=True),
|
||||
Column('created_at', DateTime, default=datetime.datetime.now)
|
||||
)
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
"""模型配置表"""
|
||||
__tablename__ = "model_configs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||
model_id = Column(UUID(as_uuid=True), ForeignKey("model_bases.id"), nullable=True, index=True, comment="基础模型ID")
|
||||
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, index=True, comment="租户ID")
|
||||
logo = Column(String(255), nullable=True, comment="模型logo图片URL")
|
||||
name = Column(String, nullable=False, comment="模型显示名称")
|
||||
provider = Column(String, nullable=False, comment="供应商", server_default=ModelProvider.COMPOSITE)
|
||||
type = Column(String, nullable=False, index=True, comment="模型类型")
|
||||
is_composite = Column(Boolean, default=False, server_default="true", nullable=False, comment="是否为组合模型")
|
||||
description = Column(String, comment="模型描述")
|
||||
|
||||
# 模型配置参数
|
||||
@@ -56,29 +89,28 @@ class ModelConfig(Base):
|
||||
# context_length = Column(String, comment="上下文长度")
|
||||
|
||||
# 状态管理
|
||||
is_active = Column(Boolean, default=True, nullable=False, comment="是否激活")
|
||||
is_public = Column(Boolean, default=False, nullable=False, comment="是否公开")
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
|
||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
|
||||
load_balance_strategy = Column(String, nullable=True, comment="负载均衡策略")
|
||||
|
||||
# 关联关系
|
||||
api_keys = relationship("ModelApiKey", back_populates="model_config", cascade="all, delete-orphan")
|
||||
model_base = relationship("ModelBase", back_populates="configs")
|
||||
api_keys = relationship(
|
||||
"ModelApiKey",
|
||||
secondary=model_config_api_key_association,
|
||||
back_populates="model_configs"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ModelConfig(id={self.id}, name={self.name}, type={self.type})>"
|
||||
|
||||
|
||||
class ModelApiKey(Base):
|
||||
class ModelApiKey(BaseModel):
|
||||
"""模型API密钥表"""
|
||||
__tablename__ = "model_api_keys"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||
model_config_id = Column(UUID(as_uuid=True), ForeignKey("model_configs.id"), nullable=False, comment="模型配置ID")
|
||||
|
||||
# API Key 信息
|
||||
model_name = Column(String, nullable=False, comment="模型实际名称")
|
||||
description = Column(String, comment="备注")
|
||||
provider = Column(String, nullable=False, comment="API Key提供商")
|
||||
api_key = Column(String, nullable=False, comment="API密钥")
|
||||
api_base = Column(String, comment="API基础URL")
|
||||
@@ -91,15 +123,41 @@ class ModelApiKey(Base):
|
||||
last_used_at = Column(DateTime, comment="最后使用时间")
|
||||
|
||||
# 状态管理
|
||||
is_active = Column(Boolean, default=True, nullable=False, comment="是否激活")
|
||||
priority = Column(String, default="1", comment="优先级")
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
|
||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
|
||||
|
||||
|
||||
# 关联关系
|
||||
model_config = relationship("ModelConfig", back_populates="api_keys")
|
||||
model_configs = relationship(
|
||||
"ModelConfig",
|
||||
secondary=model_config_api_key_association,
|
||||
back_populates="api_keys"
|
||||
)
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ModelApiKey(id={self.id}, model_name={self.model_name}, provider={self.provider}, model_config_id={self.model_config_id})>"
|
||||
return f"<ModelApiKey(id={self.id}, model_name={self.model_name}, provider={self.provider})>"
|
||||
|
||||
|
||||
class ModelBase(Base):
|
||||
"""基础模型信息表(模型广场)"""
|
||||
__tablename__ = "model_bases"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||
logo = Column(String(255), nullable=True, comment="模型logo图片URL")
|
||||
name = Column(String, nullable=False, comment="模型唯一标识(如gpt-3.5-turbo)")
|
||||
type = Column(String, nullable=False, index=True, comment="模型类型")
|
||||
provider = Column(String, nullable=False, index=True)
|
||||
description = Column(Text, comment="模型描述")
|
||||
is_deprecated = Column(Boolean, default=False, nullable=False, comment="是否弃用")
|
||||
is_official = Column(Boolean, default=True, comment="是否供应商官方模型(区分自定义)")
|
||||
tags = Column(ARRAY(String), default=list, nullable=False, comment="模型标签(如['聊天', '创作'])")
|
||||
add_count = Column(Integer, default=0, nullable=False, comment="模型被用户添加的次数")
|
||||
|
||||
# 关联关系
|
||||
configs = relationship("ModelConfig", back_populates="model_base", cascade="all, delete-orphan")
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("name", "provider", name="uk_model_name_provider"),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ModelBase(name={self.name}, provider={self.provider}, type={self.type})>"
|
||||
@@ -1,12 +1,12 @@
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy import and_, or_, func, desc
|
||||
from sqlalchemy.orm import Session, joinedload, selectinload
|
||||
from sqlalchemy import and_, or_, func, desc, select
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
import uuid
|
||||
|
||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType
|
||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, ModelBase, model_config_api_key_association
|
||||
from app.schemas.model_schema import (
|
||||
ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
||||
ModelConfigQuery
|
||||
ModelConfigQuery, ModelConfigQueryNew
|
||||
)
|
||||
from app.core.logging_config import get_db_logger
|
||||
|
||||
@@ -107,6 +107,80 @@ class ModelConfigRepository:
|
||||
def get_list(db: Session, query: ModelConfigQuery, tenant_id: uuid.UUID | None = None) -> Tuple[List[ModelConfig], int]:
|
||||
"""获取模型配置列表"""
|
||||
db_logger.debug(f"查询模型配置列表: {query.dict()}, tenant_id={tenant_id}")
|
||||
|
||||
try:
|
||||
# 构建查询条件
|
||||
filters = []
|
||||
|
||||
# 添加租户过滤(查询本租户的模型或公开模型)
|
||||
if tenant_id:
|
||||
filters.append(
|
||||
or_(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_public
|
||||
)
|
||||
)
|
||||
|
||||
# 支持多个 type 值(使用 IN 查询)
|
||||
# 兼容 chat 和 llm 类型:如果查询包含其中一个,则同时匹配两者
|
||||
if query.type:
|
||||
type_values = list(query.type)
|
||||
# 如果包含 chat 或 llm,则同时包含两者
|
||||
if ModelType.CHAT in type_values or ModelType.LLM in type_values:
|
||||
if ModelType.CHAT not in type_values:
|
||||
type_values.append(ModelType.CHAT)
|
||||
if ModelType.LLM not in type_values:
|
||||
type_values.append(ModelType.LLM)
|
||||
filters.append(ModelConfig.type.in_(type_values))
|
||||
|
||||
if query.is_active is not None:
|
||||
filters.append(ModelConfig.is_active == query.is_active)
|
||||
|
||||
if query.is_public is not None:
|
||||
filters.append(ModelConfig.is_public == query.is_public)
|
||||
|
||||
if query.search:
|
||||
# 搜索逻辑需要join ModelApiKey表来搜索model_name
|
||||
search_filter = or_(
|
||||
ModelConfig.name.ilike(f"%{query.search}%"),
|
||||
# ModelConfig.description.ilike(f"%{query.search}%")
|
||||
)
|
||||
filters.append(search_filter)
|
||||
|
||||
# 构建基础查询
|
||||
base_query = db.query(ModelConfig).options(
|
||||
joinedload(ModelConfig.api_keys)
|
||||
)
|
||||
|
||||
# 如果需要按provider筛选,需要join ModelApiKey表
|
||||
if query.provider:
|
||||
base_query = base_query.join(ModelApiKey).filter(
|
||||
ModelApiKey.provider == query.provider
|
||||
).distinct()
|
||||
|
||||
if filters:
|
||||
base_query = base_query.filter(and_(*filters))
|
||||
|
||||
# 获取总数
|
||||
total = base_query.count()
|
||||
|
||||
# 分页查询
|
||||
models = base_query.order_by(desc(ModelConfig.updated_at)).offset(
|
||||
(query.page - 1) * query.pagesize
|
||||
).limit(query.pagesize).all()
|
||||
|
||||
db_logger.debug(f"模型配置列表查询成功: 总数={total}, 当前页={len(models)}, type筛选={query.type}")
|
||||
return models, total
|
||||
|
||||
except Exception as e:
|
||||
db_logger.error(f"查询模型配置列表失败: {str(e)}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def get_list_new(db: Session, query: ModelConfigQueryNew, tenant_id: uuid.UUID | None = None) -> tuple[
|
||||
dict[str, list[ModelConfig]], Any]:
|
||||
"""获取模型配置列表"""
|
||||
db_logger.debug(f"查询模型配置列表: {query.model_dump()}, tenant_id={tenant_id}")
|
||||
|
||||
try:
|
||||
# 构建查询条件
|
||||
@@ -138,13 +212,15 @@ class ModelConfigRepository:
|
||||
|
||||
if query.is_public is not None:
|
||||
filters.append(ModelConfig.is_public == query.is_public)
|
||||
|
||||
if query.is_composite is not None:
|
||||
filters.append(ModelConfig.is_composite == query.is_composite)
|
||||
|
||||
if query.provider:
|
||||
filters.append(ModelConfig.provider == query.provider)
|
||||
|
||||
if query.search:
|
||||
# 搜索逻辑需要join ModelApiKey表来搜索model_name
|
||||
search_filter = or_(
|
||||
ModelConfig.name.ilike(f"%{query.search}%"),
|
||||
# ModelConfig.description.ilike(f"%{query.search}%")
|
||||
)
|
||||
search_filter = ModelConfig.name.ilike(f"%{query.search}%")
|
||||
filters.append(search_filter)
|
||||
|
||||
# 构建基础查询
|
||||
@@ -152,28 +228,30 @@ class ModelConfigRepository:
|
||||
joinedload(ModelConfig.api_keys)
|
||||
)
|
||||
|
||||
# 如果需要按provider筛选,需要join ModelApiKey表
|
||||
if query.provider:
|
||||
base_query = base_query.join(ModelApiKey).filter(
|
||||
ModelApiKey.provider == query.provider
|
||||
).distinct()
|
||||
|
||||
if filters:
|
||||
base_query = base_query.filter(and_(*filters))
|
||||
|
||||
# 获取总数
|
||||
total = base_query.count()
|
||||
|
||||
query_results = base_query.order_by(desc(ModelConfig.updated_at)).all()
|
||||
|
||||
provider_groups: Dict[str, List[ModelConfig]] = {}
|
||||
for model_config in query_results:
|
||||
provider = model_config.provider
|
||||
if provider not in provider_groups:
|
||||
provider_groups[provider] = []
|
||||
provider_groups[provider].append(model_config)
|
||||
|
||||
# 分页查询
|
||||
models = base_query.order_by(desc(ModelConfig.updated_at)).offset(
|
||||
(query.page - 1) * query.pagesize
|
||||
).limit(query.pagesize).all()
|
||||
|
||||
db_logger.debug(f"模型配置列表查询成功: 总数={total}, 当前页={len(models)}, type筛选={query.type}")
|
||||
return models, total
|
||||
db_logger.debug(
|
||||
f"模型配置列表查询成功: 总数={total}, "
|
||||
f"分组数={len(provider_groups)}, "
|
||||
f"各分组模型数={[len(v) for v in provider_groups.values()]}, "
|
||||
f"type筛选={query.type}")
|
||||
return provider_groups, total
|
||||
|
||||
except Exception as e:
|
||||
db_logger.error(f"查询模型配置列表失败: {str(e)}")
|
||||
db_logger.error(f"查询模型配置列表失败(按provider分组/无分页): {str(e)}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
@@ -241,7 +319,7 @@ class ModelConfigRepository:
|
||||
return None
|
||||
|
||||
# 更新字段
|
||||
update_data = model_data.dict(exclude_unset=True)
|
||||
update_data = model_data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(db_model, field, value)
|
||||
|
||||
@@ -303,8 +381,18 @@ class ModelConfigRepository:
|
||||
# 按提供商统计 - 现在从ModelApiKey表获取
|
||||
provider_stats = {}
|
||||
provider_results = db.query(
|
||||
ModelApiKey.provider, func.count(func.distinct(ModelApiKey.model_config_id))
|
||||
).group_by(ModelApiKey.provider).all()
|
||||
# 保留 provider 字段
|
||||
ModelApiKey.provider,
|
||||
# 统计中间表中 唯一的 model_config_id 数量(替换原 ModelApiKey.model_config_id)
|
||||
func.count(func.distinct(model_config_api_key_association.c.model_config_id))
|
||||
).join(
|
||||
# 联表:ModelApiKey <-> 中间表(多对多关联)
|
||||
model_config_api_key_association,
|
||||
ModelApiKey.id == model_config_api_key_association.c.api_key_id
|
||||
).group_by(
|
||||
# 按 provider 分组(保留原有逻辑)
|
||||
ModelApiKey.provider
|
||||
).all()
|
||||
|
||||
for provider, count in provider_results:
|
||||
provider_stats[provider.value] = count
|
||||
@@ -325,6 +413,37 @@ class ModelConfigRepository:
|
||||
db_logger.error(f"获取模型统计信息失败: {str(e)}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def get_model_config_ids_by_provider(
|
||||
db: Session,
|
||||
tenant_id: uuid.UUID,
|
||||
provider: Any
|
||||
) -> List[uuid.UUID]:
|
||||
"""根据tenant_id和provider获取model_config_id列表"""
|
||||
db_logger.debug(f"查询model_config_id列表: tenant_id={tenant_id}, provider={provider}")
|
||||
|
||||
try:
|
||||
# 查询ModelConfig关联的ModelApiKey,筛选出匹配的model_config_id
|
||||
model_config_ids = db.query(ModelConfig.id).join(
|
||||
ModelBase, ModelConfig.model_id == ModelBase.id
|
||||
).filter(
|
||||
and_(
|
||||
or_(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_public
|
||||
),
|
||||
ModelBase.provider == provider,
|
||||
~ModelConfig.is_composite
|
||||
)
|
||||
).distinct().all()
|
||||
|
||||
db_logger.debug(f"查询成功: 数量={len(model_config_ids)}")
|
||||
return [row[0] for row in model_config_ids]
|
||||
|
||||
except Exception as e:
|
||||
db_logger.error(f"查询model_config_id列表失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
class ModelApiKeyRepository:
|
||||
"""模型API Key Repository"""
|
||||
@@ -349,7 +468,14 @@ class ModelApiKeyRepository:
|
||||
db_logger.debug(f"根据模型配置ID查询API Key: model_config_id={model_config_id}")
|
||||
|
||||
try:
|
||||
query = db.query(ModelApiKey).filter(ModelApiKey.model_config_id == model_config_id)
|
||||
from app.models.models_model import ModelConfig, model_config_api_key_association
|
||||
|
||||
query = db.query(ModelApiKey).join(
|
||||
model_config_api_key_association,
|
||||
ModelApiKey.id == model_config_api_key_association.c.api_key_id
|
||||
).filter(
|
||||
model_config_api_key_association.c.model_config_id == model_config_id
|
||||
)
|
||||
|
||||
if is_active:
|
||||
query = query.filter(ModelApiKey.is_active)
|
||||
@@ -368,8 +494,20 @@ class ModelApiKeyRepository:
|
||||
db_logger.debug(f"创建API Key: {api_key_data.provider}")
|
||||
|
||||
try:
|
||||
db_api_key = ModelApiKey(**api_key_data.dict())
|
||||
from app.models.models_model import ModelConfig
|
||||
|
||||
# 创建API Key,不包含model_config_ids
|
||||
api_key_dict = api_key_data.model_dump(exclude={"model_config_ids"})
|
||||
db_api_key = ModelApiKey(**api_key_dict)
|
||||
db.add(db_api_key)
|
||||
db.flush() # 获取生成的ID
|
||||
|
||||
# 关联ModelConfig
|
||||
if api_key_data.model_config_ids:
|
||||
for model_config_id in api_key_data.model_config_ids:
|
||||
model_config = db.query(ModelConfig).filter(ModelConfig.id == model_config_id).first()
|
||||
if model_config:
|
||||
db_api_key.model_configs.append(model_config)
|
||||
|
||||
db_logger.info(f"API Key已添加到会话: {db_api_key.provider}")
|
||||
return db_api_key
|
||||
@@ -391,7 +529,7 @@ class ModelApiKeyRepository:
|
||||
return None
|
||||
|
||||
# 更新字段
|
||||
update_data = api_key_data.dict(exclude_unset=True)
|
||||
update_data = api_key_data.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(db_api_key, field, value)
|
||||
|
||||
@@ -451,4 +589,74 @@ class ModelApiKeyRepository:
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
db_logger.error(f"更新API Key使用统计失败: api_key_id={api_key_id} - {str(e)}")
|
||||
raise
|
||||
raise
|
||||
|
||||
|
||||
class ModelBaseRepository:
|
||||
"""基础模型Repository"""
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(db: Session, model_base_id: uuid.UUID) -> Optional['ModelBase']:
|
||||
return db.query(ModelBase).filter(ModelBase.id == model_base_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_list(db: Session, query: 'ModelBaseQuery') -> List['ModelBase']:
|
||||
|
||||
filters = []
|
||||
if query.type:
|
||||
filters.append(ModelBase.type == query.type)
|
||||
if query.provider:
|
||||
filters.append(ModelBase.provider == query.provider)
|
||||
if query.is_official is not None:
|
||||
filters.append(ModelBase.is_official == query.is_official)
|
||||
if query.is_deprecated is not None:
|
||||
filters.append(ModelBase.is_deprecated == query.is_deprecated)
|
||||
if query.search:
|
||||
filters.append(or_(
|
||||
ModelBase.name.ilike(f"%{query.search}%"),
|
||||
# ModelBase.description.ilike(f"%{query.search}%")
|
||||
))
|
||||
|
||||
q = db.query(ModelBase)
|
||||
if filters:
|
||||
q = q.filter(and_(*filters))
|
||||
|
||||
return q.order_by(ModelBase.add_count.desc()).all()
|
||||
|
||||
@staticmethod
|
||||
def create(db: Session, data: dict) -> 'ModelBase':
|
||||
model_base = ModelBase(**data)
|
||||
db.add(model_base)
|
||||
return model_base
|
||||
|
||||
@staticmethod
|
||||
def update(db: Session, model_base_id: uuid.UUID, data: dict) -> Optional['ModelBase']:
|
||||
model_base = db.query(ModelBase).filter(ModelBase.id == model_base_id).first()
|
||||
if not model_base:
|
||||
return None
|
||||
for key, value in data.items():
|
||||
setattr(model_base, key, value)
|
||||
return model_base
|
||||
|
||||
@staticmethod
|
||||
def delete(db: Session, model_base_id: uuid.UUID) -> bool:
|
||||
model_base = db.query(ModelBase).filter(ModelBase.id == model_base_id).first()
|
||||
if not model_base:
|
||||
return False
|
||||
db.delete(model_base)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def increment_add_count(db: Session, model_base_id: uuid.UUID) -> bool:
|
||||
model_base = db.query(ModelBase).filter(ModelBase.id == model_base_id).first()
|
||||
if not model_base:
|
||||
return False
|
||||
model_base.add_count += 1
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def check_added_by_tenant(db: Session, model_base_id: uuid.UUID, tenant_id: uuid.UUID) -> bool:
|
||||
return db.query(ModelConfig).filter(
|
||||
ModelConfig.model_id == model_base_id,
|
||||
ModelConfig.tenant_id == tenant_id
|
||||
).first() is not None
|
||||
|
||||
@@ -4,6 +4,10 @@ import datetime
|
||||
import uuid
|
||||
|
||||
from app.models.models_model import ModelProvider, ModelType
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
schema_logger = get_business_logger()
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +16,9 @@ class ModelConfigBase(BaseModel):
|
||||
"""模型配置基础Schema"""
|
||||
name: str = Field(..., description="模型显示名称", max_length=255)
|
||||
type: ModelType = Field(..., description="模型类型")
|
||||
logo: Optional[str] = Field(None, description="模型logo图片URL", max_length=255)
|
||||
description: Optional[str] = Field(None, description="模型描述")
|
||||
provider: str = Field(..., description="供应商")
|
||||
config: Optional[Dict[str, Any]] = Field({}, description="模型配置参数")
|
||||
is_active: bool = Field(True, description="是否激活")
|
||||
is_public: bool = Field(False, description="是否公开")
|
||||
@@ -21,6 +27,7 @@ class ModelConfigBase(BaseModel):
|
||||
class ApiKeyCreateNested(BaseModel):
|
||||
"""用于在创建模型时内嵌创建API Key的Schema"""
|
||||
model_name: str = Field(..., description="模型实际名称", max_length=255)
|
||||
description: Optional[str] = Field(None, description="备注")
|
||||
provider: ModelProvider = Field(..., description="API Key提供商")
|
||||
api_key: str = Field(..., description="API密钥", max_length=500)
|
||||
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
|
||||
@@ -30,10 +37,22 @@ class ApiKeyCreateNested(BaseModel):
|
||||
|
||||
class ModelConfigCreate(ModelConfigBase):
|
||||
"""创建模型配置Schema"""
|
||||
api_keys: Optional[ApiKeyCreateNested] = Field(None, description="同时创建的API Key配置")
|
||||
api_keys: Optional[List[ApiKeyCreateNested]] = Field(None, description="同时创建的API Key配置")
|
||||
skip_validation: Optional[bool] = Field(False, description="是否跳过配置验证")
|
||||
|
||||
|
||||
class CompositeModelCreate(BaseModel):
|
||||
"""创建组合模型Schema"""
|
||||
name: str = Field(..., description="组合模型名称", max_length=255)
|
||||
type: ModelType = Field(..., description="模型类型")
|
||||
logo: Optional[str] = Field(None, description="模型logo图片URL", max_length=255)
|
||||
description: Optional[str] = Field(None, description="模型描述")
|
||||
config: Optional[Dict[str, Any]] = Field({}, description="模型配置参数")
|
||||
is_active: bool = Field(True, description="是否激活")
|
||||
is_public: bool = Field(False, description="是否公开")
|
||||
api_key_ids: List[uuid.UUID] = Field(..., description="绑定的API Key ID列表")
|
||||
|
||||
|
||||
class ModelConfigUpdate(BaseModel):
|
||||
"""更新模型配置Schema"""
|
||||
name: Optional[str] = Field(None, description="模型显示名称", max_length=255)
|
||||
@@ -53,22 +72,48 @@ class ModelConfig(ModelConfigBase):
|
||||
updated_at: datetime.datetime
|
||||
api_keys: List["ModelApiKey"] = []
|
||||
|
||||
@field_validator("api_keys", mode="after")
|
||||
@classmethod
|
||||
def filter_active_api_keys(cls, api_keys: List["ModelApiKey"]) -> List["ModelApiKey"]:
|
||||
return [key for key in api_keys if key.is_active]
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime | None):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
# ModelApiKey Schemas
|
||||
class ModelApiKeyBase(BaseModel):
|
||||
"""API Key基础Schema"""
|
||||
model_name: str = Field(..., description="模型实际名称", max_length=255)
|
||||
class ModelApiKeyCreateByProvider(BaseModel):
|
||||
"""基于供应商创建API Key Schema"""
|
||||
provider: ModelProvider = Field(..., description="API Key提供商")
|
||||
api_key: str = Field(..., description="API密钥", max_length=500)
|
||||
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
|
||||
config: Optional[Dict[str, Any]] = Field(None, description="API Key特定配置")
|
||||
description: Optional[str] = Field(None, description="备注")
|
||||
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
|
||||
is_active: bool = Field(True, description="是否激活")
|
||||
priority: str = Field("1", description="优先级", max_length=10)
|
||||
model_config_ids: Optional[List[uuid.UUID]] = Field(None, description="关联的模型配置ID列表")
|
||||
|
||||
|
||||
class ModelApiKeyBase(BaseModel):
|
||||
"""API Key基础Schema"""
|
||||
model_name: str = Field(..., description="模型实际名称", max_length=255)
|
||||
description: Optional[str] = Field(None, description="备注")
|
||||
provider: ModelProvider = Field(..., description="API Key提供商")
|
||||
api_key: str = Field(..., description="API密钥", max_length=500)
|
||||
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
|
||||
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
|
||||
is_active: bool = Field(True, description="是否激活")
|
||||
priority: str = Field("1", description="优先级", max_length=10)
|
||||
|
||||
|
||||
class ModelApiKeyCreate(ModelApiKeyBase):
|
||||
"""创建API Key Schema"""
|
||||
model_config_id: uuid.UUID = Field(..., description="模型配置ID")
|
||||
model_config_ids: Optional[List[uuid.UUID]] = Field(None, description="关联的模型配置ID列表")
|
||||
|
||||
|
||||
class ModelApiKeyUpdate(BaseModel):
|
||||
@@ -85,23 +130,54 @@ class ModelApiKeyUpdate(BaseModel):
|
||||
class ModelApiKey(ModelApiKeyBase):
|
||||
"""API Key Schema"""
|
||||
id: uuid.UUID
|
||||
model_config_id: uuid.UUID
|
||||
usage_count: str
|
||||
last_used_at: Optional[datetime.datetime]
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
model_configs: Any = Field(default=None, exclude=True)
|
||||
model_config_ids: List[uuid.UUID] = Field(default_factory=list, description="关联的模型配置ID列表")
|
||||
|
||||
@field_validator("config", mode="before")
|
||||
@classmethod
|
||||
def parse_config(cls, v):
|
||||
"""处理 config 字段,如果是字符串则解析为字典"""
|
||||
if isinstance(v, str):
|
||||
import json
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""实例化后强制提取 model_configs 的ID到 model_config_ids"""
|
||||
# 如果手动传入了 model_config_ids,不覆盖
|
||||
if self.model_config_ids and len(self.model_config_ids) > 0:
|
||||
return
|
||||
|
||||
# 从 model_configs 提取ID(只提取与 model_name 相同的非组合模型)
|
||||
if self.model_configs is not None:
|
||||
try:
|
||||
return json.loads(v)
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return v
|
||||
# 情况1:ORM 对象列表(SQLAlchemy 关联)
|
||||
if hasattr(self.model_configs, '__iter__') and not isinstance(self.model_configs, dict):
|
||||
self.model_config_ids = [
|
||||
mc.id for mc in self.model_configs
|
||||
if hasattr(mc, 'id')
|
||||
and not getattr(mc, 'is_composite', False)
|
||||
and getattr(mc, 'name', None) == self.model_name
|
||||
]
|
||||
# 情况2:字典列表
|
||||
elif isinstance(self.model_configs, list):
|
||||
self.model_config_ids = [
|
||||
mc['id'] if isinstance(mc, dict) else mc.id
|
||||
for mc in self.model_configs
|
||||
if ((isinstance(mc, dict)
|
||||
and 'id' in mc
|
||||
and not mc.get('is_composite', False)
|
||||
and mc.get('name') == self.model_name) or
|
||||
(hasattr(mc, 'id')
|
||||
and not getattr(mc, 'is_composite', False)
|
||||
and getattr(mc, 'name', None) == self.model_name))
|
||||
]
|
||||
except Exception as e:
|
||||
schema_logger.warning(f"提取 model_config_ids 失败:{e}")
|
||||
self.model_config_ids = []
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True, # 支持从 ORM 解析
|
||||
arbitrary_types_allowed=True, # 允许任意类型(ORM 对象)
|
||||
populate_by_name=True, # 按属性名匹配字段
|
||||
validate_assignment=True # 确保赋值触发校验
|
||||
)
|
||||
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
@@ -110,15 +186,12 @@ class ModelApiKey(ModelApiKeyBase):
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("last_used_at", when_used="json")
|
||||
def _serialize_last_used_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
# 查询和响应Schemas
|
||||
class ModelConfigQuery(BaseModel):
|
||||
"""模型配置查询Schema"""
|
||||
type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)")
|
||||
@@ -129,6 +202,17 @@ class ModelConfigQuery(BaseModel):
|
||||
page: int = Field(1, description="页码", ge=1)
|
||||
pagesize: int = Field(10, description="每页数量", ge=1, le=100)
|
||||
|
||||
|
||||
# 查询和响应Schemas
|
||||
class ModelConfigQueryNew(BaseModel):
|
||||
"""模型配置查询Schema"""
|
||||
type: Optional[List[ModelType]] = Field(None, description="模型类型筛选(支持多个)")
|
||||
provider: Optional[ModelProvider] = Field(None, description="提供商筛选(通过API Key)")
|
||||
is_active: Optional[bool] = Field(None, description="激活状态筛选")
|
||||
is_public: Optional[bool] = Field(None, description="公开状态筛选")
|
||||
is_composite: Optional[bool] = Field(None, description="组合模型筛选")
|
||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
||||
|
||||
class ModelMarketplace(BaseModel):
|
||||
"""模型广场响应Schema"""
|
||||
llm_models: List[ModelConfig] = []
|
||||
@@ -171,4 +255,53 @@ class ModelValidateResponse(BaseModel):
|
||||
|
||||
|
||||
# 更新前向引用
|
||||
ModelConfig.model_rebuild()
|
||||
ModelConfig.model_rebuild()
|
||||
|
||||
|
||||
# ModelBase Schemas
|
||||
class ModelBaseCreate(BaseModel):
|
||||
"""创建基础模型Schema"""
|
||||
name: str = Field(..., description="模型唯一标识", max_length=255)
|
||||
type: ModelType = Field(..., description="模型类型")
|
||||
provider: ModelProvider = Field(..., description="提供商")
|
||||
logo: Optional[str] = Field(None, description="模型logo图片URL", max_length=255)
|
||||
description: Optional[str] = Field(None, description="模型描述")
|
||||
is_official: bool = Field(True, description="是否供应商官方模型")
|
||||
tags: List[str] = Field(default_factory=list, description="模型标签")
|
||||
|
||||
|
||||
class ModelBaseUpdate(BaseModel):
|
||||
"""更新基础模型Schema"""
|
||||
name: Optional[str] = Field(None, description="模型唯一标识", max_length=255)
|
||||
type: Optional[ModelType] = Field(None, description="模型类型")
|
||||
provider: Optional[ModelProvider] = Field(None, description="提供商")
|
||||
logo: Optional[str] = Field(None, description="模型logo图片URL", max_length=255)
|
||||
description: Optional[str] = Field(None, description="模型描述")
|
||||
is_deprecated: Optional[bool] = Field(None, description="是否弃用")
|
||||
is_official: Optional[bool] = Field(None, description="是否供应商官方模型")
|
||||
tags: Optional[List[str]] = Field(None, description="模型标签")
|
||||
|
||||
|
||||
class ModelBase(BaseModel):
|
||||
"""基础模型Schema"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
type: str
|
||||
provider: str
|
||||
logo: Optional[str]
|
||||
description: Optional[str]
|
||||
is_deprecated: bool
|
||||
is_official: bool
|
||||
tags: List[str]
|
||||
add_count: int
|
||||
|
||||
|
||||
class ModelBaseQuery(BaseModel):
|
||||
"""基础模型查询Schema"""
|
||||
type: Optional[ModelType] = Field(None, description="模型类型")
|
||||
provider: Optional[ModelProvider] = Field(None, description="提供商")
|
||||
is_official: Optional[bool] = Field(None, description="是否官方模型")
|
||||
is_deprecated: Optional[bool] = Field(None, description="是否弃用")
|
||||
search: Optional[str] = Field(None, description="搜索关键词", max_length=255)
|
||||
|
||||
193
api/app/services/app_statistics_service.py
Normal file
193
api/app/services/app_statistics_service.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""应用统计服务"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List
|
||||
import uuid
|
||||
from sqlalchemy import func, and_, cast, Date
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.conversation_model import Conversation, Message
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.api_key_model import ApiKey, ApiKeyLog
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
|
||||
|
||||
class AppStatisticsService:
|
||||
"""应用统计服务"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_app_statistics(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
start_date: int,
|
||||
end_date: int
|
||||
) -> Dict[str, Any]:
|
||||
"""获取应用统计数据
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
workspace_id: 工作空间ID
|
||||
start_date: 开始时间戳(毫秒)
|
||||
end_date: 结束时间戳(毫秒)
|
||||
|
||||
Returns:
|
||||
统计数据字典
|
||||
"""
|
||||
# 将毫秒时间戳转换为 datetime
|
||||
start_dt = datetime.fromtimestamp(start_date / 1000)
|
||||
end_dt = datetime.fromtimestamp(end_date / 1000) + timedelta(days=1)
|
||||
|
||||
# 1. 会话统计
|
||||
conversations_stats = self._get_conversations_statistics(app_id, workspace_id, start_dt, end_dt)
|
||||
|
||||
# 2. 新增用户统计
|
||||
users_stats = self._get_new_users_statistics(app_id, start_dt, end_dt)
|
||||
|
||||
# 3. API调用统计
|
||||
api_stats = self._get_api_calls_statistics(app_id, start_dt, end_dt)
|
||||
|
||||
# 4. Token消耗统计
|
||||
token_stats = self._get_token_statistics(app_id, start_dt, end_dt)
|
||||
|
||||
return {
|
||||
"daily_conversations": conversations_stats["daily"],
|
||||
"total_conversations": conversations_stats["total"],
|
||||
"daily_new_users": users_stats["daily"],
|
||||
"total_new_users": users_stats["total"],
|
||||
"daily_api_calls": api_stats["daily"],
|
||||
"total_api_calls": api_stats["total"],
|
||||
"daily_tokens": token_stats["daily"],
|
||||
"total_tokens": token_stats["total"]
|
||||
}
|
||||
|
||||
def _get_conversations_statistics(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
start_dt: datetime,
|
||||
end_dt: datetime
|
||||
) -> Dict[str, Any]:
|
||||
"""获取会话统计"""
|
||||
# 每日会话数
|
||||
daily_query = self.db.query(
|
||||
cast(Conversation.created_at, Date).label('date'),
|
||||
func.count(Conversation.id).label('count')
|
||||
).filter(
|
||||
and_(
|
||||
Conversation.app_id == app_id,
|
||||
Conversation.workspace_id == workspace_id,
|
||||
Conversation.created_at >= start_dt,
|
||||
Conversation.created_at < end_dt
|
||||
)
|
||||
).group_by(cast(Conversation.created_at, Date)).all()
|
||||
|
||||
daily_data = [{"date": str(row.date), "count": row.count} for row in daily_query]
|
||||
total = sum(row["count"] for row in daily_data)
|
||||
|
||||
return {"daily": daily_data, "total": total}
|
||||
|
||||
def _get_new_users_statistics(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
start_dt: datetime,
|
||||
end_dt: datetime
|
||||
) -> Dict[str, Any]:
|
||||
"""获取新增用户统计"""
|
||||
# 每日新增用户数
|
||||
daily_query = self.db.query(
|
||||
cast(EndUser.created_at, Date).label('date'),
|
||||
func.count(EndUser.id).label('count')
|
||||
).filter(
|
||||
and_(
|
||||
EndUser.app_id == app_id,
|
||||
EndUser.created_at >= start_dt,
|
||||
EndUser.created_at < end_dt
|
||||
)
|
||||
).group_by(cast(EndUser.created_at, Date)).all()
|
||||
|
||||
daily_data = [{"date": str(row.date), "count": row.count} for row in daily_query]
|
||||
total = sum(row["count"] for row in daily_data)
|
||||
|
||||
return {"daily": daily_data, "total": total}
|
||||
|
||||
def _get_api_calls_statistics(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
start_dt: datetime,
|
||||
end_dt: datetime
|
||||
) -> Dict[str, Any]:
|
||||
"""获取API调用统计"""
|
||||
# 每日API调用次数
|
||||
daily_query = self.db.query(
|
||||
cast(ApiKeyLog.created_at, Date).label('date'),
|
||||
func.count(ApiKeyLog.id).label('count')
|
||||
).join(
|
||||
ApiKey, ApiKeyLog.api_key_id == ApiKey.id
|
||||
).filter(
|
||||
and_(
|
||||
ApiKey.resource_id == app_id,
|
||||
ApiKeyLog.created_at >= start_dt,
|
||||
ApiKeyLog.created_at < end_dt
|
||||
)
|
||||
).group_by(cast(ApiKeyLog.created_at, Date)).all()
|
||||
|
||||
daily_data = [{"date": str(row.date), "count": row.count} for row in daily_query]
|
||||
total = sum(row["count"] for row in daily_data)
|
||||
|
||||
return {"daily": daily_data, "total": total}
|
||||
|
||||
def _get_token_statistics(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
start_dt: datetime,
|
||||
end_dt: datetime
|
||||
) -> Dict[str, Any]:
|
||||
"""获取Token消耗统计(从Message的meta_data中提取)"""
|
||||
from sqlalchemy import text
|
||||
|
||||
# 查询所有相关消息的token使用情况
|
||||
# meta_data中可能包含: {"usage": {"total_tokens": 100}} 或 {"tokens": 100}
|
||||
daily_query = self.db.query(
|
||||
cast(Message.created_at, Date).label('date'),
|
||||
Message.meta_data
|
||||
).join(
|
||||
Conversation, Message.conversation_id == Conversation.id
|
||||
).filter(
|
||||
and_(
|
||||
Conversation.app_id == app_id,
|
||||
Message.created_at >= start_dt,
|
||||
Message.created_at < end_dt,
|
||||
Message.meta_data.isnot(None)
|
||||
)
|
||||
).all()
|
||||
|
||||
# 按日期聚合token
|
||||
daily_tokens = {}
|
||||
for row in daily_query:
|
||||
date_str = str(row.date)
|
||||
meta = row.meta_data or {}
|
||||
|
||||
# 提取token数量(支持多种格式)
|
||||
tokens = 0
|
||||
if isinstance(meta, dict):
|
||||
# 格式1: {"usage": {"total_tokens": 100}}
|
||||
if "usage" in meta and isinstance(meta["usage"], dict):
|
||||
tokens = meta["usage"].get("total_tokens", 0)
|
||||
# 格式2: {"tokens": 100}
|
||||
elif "tokens" in meta:
|
||||
tokens = meta.get("tokens", 0)
|
||||
# 格式3: {"total_tokens": 100}
|
||||
elif "total_tokens" in meta:
|
||||
tokens = meta.get("total_tokens", 0)
|
||||
|
||||
if date_str not in daily_tokens:
|
||||
daily_tokens[date_str] = 0
|
||||
daily_tokens[date_str] += int(tokens)
|
||||
|
||||
daily_data = [{"date": date, "tokens": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
|
||||
total = sum(row["tokens"] for row in daily_data)
|
||||
|
||||
return {"daily": daily_data, "total": total}
|
||||
@@ -16,6 +16,7 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.models import AgentConfig, ModelApiKey, ModelConfig
|
||||
from app.repositories.model_repository import ModelApiKeyRepository
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||
from app.services import task_service
|
||||
@@ -724,17 +725,21 @@ class DraftRunService:
|
||||
Raises:
|
||||
BusinessException: 当没有可用的 API Key 时
|
||||
"""
|
||||
stmt = (
|
||||
select(ModelApiKey)
|
||||
.where(
|
||||
ModelApiKey.model_config_id == model_config_id,
|
||||
ModelApiKey.is_active.is_(True)
|
||||
)
|
||||
.order_by(ModelApiKey.priority.desc())
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
api_key = self.db.scalars(stmt).first()
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
|
||||
# stmt = (
|
||||
# select(ModelApiKey).join(
|
||||
# ModelConfig, ModelApiKey.model_configs
|
||||
# )
|
||||
# .where(
|
||||
# ModelConfig.id == model_config_id,
|
||||
# ModelApiKey.is_active.is_(True)
|
||||
# )
|
||||
# .order_by(ModelApiKey.priority.desc())
|
||||
# .limit(1)
|
||||
# )
|
||||
#
|
||||
# api_key = self.db.scalars(stmt).first()
|
||||
api_key = api_keys[0] if api_keys else None
|
||||
|
||||
if not api_key:
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
@@ -5,6 +5,7 @@ import uuid
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.repositories.model_repository import ModelApiKeyRepository
|
||||
from app.services.conversation_state_manager import ConversationStateManager
|
||||
from app.models import ModelConfig, AgentConfig
|
||||
from app.core.logging_config import get_business_logger
|
||||
@@ -382,11 +383,14 @@ class LLMRouter:
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.models import ModelApiKey, ModelType
|
||||
|
||||
# 获取 API Key 配置
|
||||
api_key_config = self.db.query(ModelApiKey).filter(
|
||||
ModelApiKey.model_config_id == self.routing_model_config.id,
|
||||
ModelApiKey.is_active
|
||||
).first()
|
||||
# 获取 API Key 配置(通过关联关系)
|
||||
# api_key_config = self.db.query(ModelApiKey).join(
|
||||
# ModelConfig, ModelApiKey.model_configs
|
||||
# ).filter(ModelConfig.id == self.routing_model_config.id,
|
||||
# ModelApiKey.is_active == True
|
||||
# ).first()
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, self.routing_model_config.id)
|
||||
api_key_config = api_keys[0] if api_keys else None
|
||||
|
||||
if not api_key_config:
|
||||
raise Exception("路由模型没有可用的 API Key")
|
||||
@@ -419,6 +423,9 @@ class LLMRouter:
|
||||
|
||||
# 调用模型
|
||||
response = await llm.ainvoke(prompt)
|
||||
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
|
||||
|
||||
# 提取响应内容
|
||||
if hasattr(response, 'content'):
|
||||
|
||||
@@ -338,7 +338,7 @@ class MemoryConfigService:
|
||||
"provider": api_config.provider,
|
||||
"api_key": api_config.api_key,
|
||||
"base_url": api_config.api_base,
|
||||
"model_config_id": api_config.model_config_id,
|
||||
"model_config_id": str(config.id),
|
||||
"type": config.type,
|
||||
"timeout": settings.LLM_TIMEOUT,
|
||||
"max_retries": settings.LLM_MAX_RETRIES,
|
||||
@@ -370,7 +370,7 @@ class MemoryConfigService:
|
||||
"provider": api_config.provider,
|
||||
"api_key": api_config.api_key,
|
||||
"base_url": api_config.api_base,
|
||||
"model_config_id": api_config.model_config_id,
|
||||
"model_config_id": str(config.id),
|
||||
"type": config.type,
|
||||
"timeout": 120.0,
|
||||
"max_retries": 5,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional, Dict, Any
|
||||
import uuid
|
||||
@@ -6,11 +7,11 @@ import time
|
||||
import asyncio
|
||||
|
||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType
|
||||
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository
|
||||
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository, ModelBaseRepository
|
||||
from app.schemas import model_schema
|
||||
from app.schemas.model_schema import (
|
||||
ModelConfigCreate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
||||
ModelConfigQuery, ModelStats
|
||||
ModelConfigQuery, ModelStats, ModelConfigQueryNew
|
||||
)
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
@@ -47,6 +48,26 @@ class ModelConfigService:
|
||||
items=[model_schema.ModelConfig.model_validate(model) for model in models]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_model_list_new(db: Session, query: ModelConfigQueryNew, tenant_id: uuid.UUID | None = None) -> List[dict]:
|
||||
"""获取模型配置列表"""
|
||||
provider_groups, total = ModelConfigRepository.get_list_new(db, query, tenant_id=tenant_id)
|
||||
|
||||
items = []
|
||||
for provider, models in provider_groups.items():
|
||||
# 验证每个模型并封装分组信息
|
||||
validated_models = [model_schema.ModelConfig.model_validate(model) for model in models]
|
||||
tags = list({model.type for model in validated_models})
|
||||
group_item = {
|
||||
"provider": provider, # 服务商名称
|
||||
"logo": validated_models[0].logo,
|
||||
"tags": tags,
|
||||
"models": validated_models # 该服务商下的所有模型
|
||||
}
|
||||
items.append(group_item)
|
||||
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
def get_model_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
"""根据名称获取模型配置"""
|
||||
@@ -228,37 +249,39 @@ class ModelConfigService:
|
||||
|
||||
# 验证配置
|
||||
if not model_data.skip_validation and model_data.api_keys:
|
||||
api_key_data = model_data.api_keys
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
model_name=api_key_data.model_name,
|
||||
provider=api_key_data.provider,
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
model_type=model_data.type, # 传递模型类型
|
||||
test_message="Hello"
|
||||
)
|
||||
if not validation_result["valid"]:
|
||||
raise BusinessException(
|
||||
f"模型配置验证失败: {validation_result['error']}",
|
||||
BizCode.INVALID_PARAMETER
|
||||
api_key_data_list = model_data.api_keys
|
||||
for api_key_data in api_key_data_list:
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
model_name=api_key_data.model_name,
|
||||
provider=api_key_data.provider,
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
model_type=model_data.type, # 传递模型类型
|
||||
test_message="Hello"
|
||||
)
|
||||
if not validation_result["valid"]:
|
||||
raise BusinessException(
|
||||
f"模型配置验证失败: {validation_result['error']}",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
# 事务处理
|
||||
api_key_data = model_data.api_keys
|
||||
model_config_data = model_data.dict(exclude={"api_keys", "skip_validation"})
|
||||
api_key_datas = model_data.api_keys
|
||||
model_config_data = model_data.model_dump(exclude={"api_keys", "skip_validation"})
|
||||
# 添加租户ID
|
||||
model_config_data["tenant_id"] = tenant_id
|
||||
|
||||
model = ModelConfigRepository.create(db, model_config_data)
|
||||
db.flush() # 获取生成的 ID
|
||||
|
||||
if api_key_data:
|
||||
api_key_create_schema = ModelApiKeyCreate(
|
||||
model_config_id=model.id,
|
||||
**api_key_data.dict()
|
||||
)
|
||||
ModelApiKeyRepository.create(db, api_key_create_schema)
|
||||
if api_key_datas:
|
||||
for api_key_data in api_key_datas:
|
||||
api_key_create_schema = ModelApiKeyCreate(
|
||||
model_config_ids=[model.id],
|
||||
**api_key_data.model_dump()
|
||||
)
|
||||
ModelApiKeyRepository.create(db, api_key_create_schema)
|
||||
|
||||
db.commit()
|
||||
db.refresh(model)
|
||||
@@ -280,6 +303,112 @@ class ModelConfigService:
|
||||
db.refresh(model)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
||||
"""创建组合模型"""
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
# 验证所有 API Key 存在且类型匹配
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if not api_key:
|
||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
|
||||
|
||||
# 检查 API Key 关联的模型配置类型
|
||||
for model_config in api_key.model_configs:
|
||||
# chat 和 llm 类型可以兼容
|
||||
compatible_types = {ModelType.LLM, ModelType.CHAT}
|
||||
config_type = model_config.type
|
||||
request_type = model_data.type
|
||||
|
||||
if not (config_type == request_type or
|
||||
(config_type in compatible_types and request_type in compatible_types)):
|
||||
raise BusinessException(
|
||||
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
# if model_config.is_composite:
|
||||
# raise BusinessException(
|
||||
# f"API Key {api_key_id} 关联的模型是组合模型,不能用于创建新的组合模型",
|
||||
# BizCode.INVALID_PARAMETER
|
||||
# )
|
||||
|
||||
# 创建组合模型
|
||||
model_config_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"name": model_data.name,
|
||||
"type": model_data.type,
|
||||
"logo": model_data.logo,
|
||||
"description": model_data.description,
|
||||
"provider": "composite",
|
||||
"config": model_data.config,
|
||||
"is_active": model_data.is_active,
|
||||
"is_public": model_data.is_public,
|
||||
"is_composite": True
|
||||
}
|
||||
|
||||
model = ModelConfigRepository.create(db, model_config_data)
|
||||
db.flush()
|
||||
|
||||
# 关联 API Keys
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if api_key:
|
||||
model.api_keys.append(api_key)
|
||||
|
||||
db.commit()
|
||||
db.refresh(model)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
||||
"""更新组合模型"""
|
||||
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
||||
if not existing_model:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
if not existing_model.is_composite:
|
||||
raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER)
|
||||
|
||||
# 验证所有 API Key 存在且类型匹配
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if not api_key:
|
||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
|
||||
|
||||
for model_config in api_key.model_configs:
|
||||
compatible_types = {ModelType.LLM, ModelType.CHAT}
|
||||
config_type = model_config.type
|
||||
request_type = model_data.type
|
||||
|
||||
if not (config_type == request_type or
|
||||
(config_type in compatible_types and request_type in compatible_types)):
|
||||
raise BusinessException(
|
||||
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
# 更新基本信息
|
||||
existing_model.name = model_data.name
|
||||
existing_model.type = model_data.type
|
||||
existing_model.logo = model_data.logo
|
||||
existing_model.description = model_data.description
|
||||
existing_model.config = model_data.config
|
||||
existing_model.is_active = model_data.is_active
|
||||
existing_model.is_public = model_data.is_public
|
||||
|
||||
# 更新 API Keys 关联
|
||||
existing_model.api_keys.clear()
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if api_key:
|
||||
existing_model.api_keys.append(api_key)
|
||||
|
||||
db.commit()
|
||||
db.refresh(existing_model)
|
||||
return existing_model
|
||||
|
||||
@staticmethod
|
||||
def delete_model(db: Session, model_id: uuid.UUID, tenant_id: uuid.UUID | None = None) -> bool:
|
||||
"""删除模型配置"""
|
||||
@@ -324,27 +453,132 @@ class ModelApiKeyService:
|
||||
return ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active)
|
||||
|
||||
@staticmethod
|
||||
async def create_api_key(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey:
|
||||
"""创建API Key"""
|
||||
model_config = ModelConfigRepository.get_by_id(db, api_key_data.model_config_id)
|
||||
if not model_config:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
async def create_api_key_by_provider(db: Session, data: model_schema.ModelApiKeyCreateByProvider) -> List[ModelApiKey]:
|
||||
"""根据provider为多个ModelConfig创建API Key"""
|
||||
created_keys = []
|
||||
|
||||
for model_config_id in data.model_config_ids:
|
||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||
if not model_config:
|
||||
continue
|
||||
|
||||
# 从ModelBase获取model_name
|
||||
model_name = model_config.model_base.name if model_config.model_base else model_config.name
|
||||
|
||||
# 检查是否存在API Key(包括软删除)
|
||||
existing_key = db.query(ModelApiKey).filter(
|
||||
ModelApiKey.api_key == data.api_key,
|
||||
ModelApiKey.provider == data.provider,
|
||||
ModelApiKey.model_name == model_name
|
||||
).first()
|
||||
|
||||
if existing_key:
|
||||
# 如果已存在,重新激活并更新
|
||||
if existing_key.is_active:
|
||||
continue
|
||||
existing_key.is_active = True
|
||||
existing_key.api_base = data.api_base
|
||||
existing_key.description = data.description
|
||||
existing_key.config = data.config
|
||||
existing_key.priority = data.priority
|
||||
existing_key.model_name = model_name
|
||||
|
||||
# 检查是否已关联该模型配置
|
||||
if model_config not in existing_key.model_configs:
|
||||
existing_key.model_configs.append(model_config)
|
||||
|
||||
created_keys.append(existing_key)
|
||||
continue
|
||||
|
||||
# 验证配置
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
model_name=api_key_data.model_name,
|
||||
provider=api_key_data.provider,
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
model_type=model_config.type, # 传递模型类型
|
||||
model_name=model_name,
|
||||
provider=data.provider,
|
||||
api_key=data.api_key,
|
||||
api_base=data.api_base,
|
||||
model_type=model_config.type,
|
||||
test_message="Hello"
|
||||
)
|
||||
print(validation_result)
|
||||
if not validation_result["valid"]:
|
||||
if not validation_result["valid"]:
|
||||
raise BusinessException(
|
||||
f"模型配置验证失败: {validation_result['error']}",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
# 创建API Key
|
||||
api_key_data = ModelApiKeyCreate(
|
||||
model_config_ids=[model_config_id],
|
||||
model_name=model_name,
|
||||
description=data.description,
|
||||
provider=data.provider,
|
||||
api_key=data.api_key,
|
||||
api_base=data.api_base,
|
||||
config=data.config,
|
||||
is_active=data.is_active,
|
||||
priority=data.priority
|
||||
)
|
||||
api_key_obj = ModelApiKeyRepository.create(db, api_key_data)
|
||||
created_keys.append(api_key_obj)
|
||||
|
||||
if created_keys:
|
||||
db.commit()
|
||||
for key in created_keys:
|
||||
db.refresh(key)
|
||||
|
||||
return created_keys
|
||||
|
||||
@staticmethod
|
||||
async def create_api_key(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey:
|
||||
# 验证所有关联的模型配置是否存在
|
||||
if api_key_data.model_config_ids:
|
||||
for model_config_id in api_key_data.model_config_ids:
|
||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||
if not model_config:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
# 检查API Key是否已存在(包括软删除)
|
||||
existing_key = db.query(ModelApiKey).filter(
|
||||
ModelApiKey.api_key == api_key_data.api_key,
|
||||
ModelApiKey.provider == api_key_data.provider,
|
||||
ModelApiKey.model_name == api_key_data.model_name
|
||||
).first()
|
||||
|
||||
if existing_key:
|
||||
if existing_key.is_active:
|
||||
# 如果已激活,跳过
|
||||
raise BusinessException("该API Key已存在", BizCode.DUPLICATE_NAME)
|
||||
# 如果已存在,重新激活并更新
|
||||
existing_key.is_active = True
|
||||
existing_key.api_base = api_key_data.api_base
|
||||
existing_key.description = api_key_data.description
|
||||
existing_key.config = api_key_data.config
|
||||
existing_key.priority = api_key_data.priority
|
||||
existing_key.model_name = api_key_data.model_name
|
||||
|
||||
# 检查是否已关联该模型配置
|
||||
if model_config not in existing_key.model_configs:
|
||||
existing_key.model_configs.append(model_config)
|
||||
|
||||
db.commit()
|
||||
db.refresh(existing_key)
|
||||
return existing_key
|
||||
|
||||
# 验证配置
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
model_name=api_key_data.model_name,
|
||||
provider=api_key_data.provider,
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
model_type=model_config.type,
|
||||
test_message="Hello"
|
||||
)
|
||||
if not validation_result["valid"]:
|
||||
raise BusinessException(
|
||||
f"模型配置验证失败: {validation_result['error']}",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
api_key = ModelApiKeyRepository.create(db, api_key_data)
|
||||
db.commit()
|
||||
@@ -359,21 +593,19 @@ class ModelApiKeyService:
|
||||
raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
|
||||
|
||||
# 获取关联的模型配置以获取模型类型
|
||||
model_config = ModelConfigRepository.get_by_id(db, existing_api_key.model_config_id)
|
||||
if not model_config:
|
||||
raise BusinessException("关联的模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
if existing_api_key.model_configs:
|
||||
model_config = existing_api_key.model_configs[0]
|
||||
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
model_name=api_key_data.model_name,
|
||||
provider=api_key_data.provider,
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
model_type=model_config.type, # 传递模型类型
|
||||
model_name=api_key_data.model_name or existing_api_key.model_name,
|
||||
provider=api_key_data.provider or existing_api_key.provider,
|
||||
api_key=api_key_data.api_key or existing_api_key.api_key,
|
||||
api_base=api_key_data.api_base or existing_api_key.api_base,
|
||||
model_type=model_config.type,
|
||||
test_message="Hello"
|
||||
)
|
||||
print(validation_result)
|
||||
if not validation_result["valid"]:
|
||||
if not validation_result["valid"]:
|
||||
raise BusinessException(
|
||||
f"模型配置验证失败: {validation_result['error']}",
|
||||
BizCode.INVALID_PARAMETER
|
||||
@@ -417,3 +649,84 @@ class ModelApiKeyService:
|
||||
if api_kes and len(api_kes) > 0:
|
||||
return api_kes[0]
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
|
||||
|
||||
class ModelBaseService:
|
||||
"""基础模型服务"""
|
||||
|
||||
@staticmethod
|
||||
def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List:
|
||||
models = ModelBaseRepository.get_list(db, query)
|
||||
|
||||
provider_groups = {}
|
||||
for m in models:
|
||||
model_dict = model_schema.ModelBase.model_validate(m).model_dump()
|
||||
if tenant_id:
|
||||
model_dict['is_added'] = ModelBaseRepository.check_added_by_tenant(db, m.id, tenant_id)
|
||||
|
||||
provider = m.provider
|
||||
if provider not in provider_groups:
|
||||
provider_groups[provider] = {
|
||||
"provider": provider,
|
||||
"models": []
|
||||
}
|
||||
provider_groups[provider]["models"].append(model_dict)
|
||||
|
||||
return list(provider_groups.values())
|
||||
|
||||
@staticmethod
|
||||
def get_model_base_by_id(db: Session, model_base_id: uuid.UUID):
|
||||
model = ModelBaseRepository.get_by_id(db, model_base_id)
|
||||
if not model:
|
||||
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def create_model_base(db: Session, data: model_schema.ModelBaseCreate):
|
||||
model_base = ModelBaseRepository.create(db, data.model_dump())
|
||||
db.commit()
|
||||
db.refresh(model_base)
|
||||
return model_base
|
||||
|
||||
@staticmethod
|
||||
def update_model_base(db: Session, model_base_id: uuid.UUID, data: model_schema.ModelBaseUpdate):
|
||||
model_base = ModelBaseRepository.update(db, model_base_id, data.model_dump(exclude_unset=True))
|
||||
if not model_base:
|
||||
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
|
||||
db.commit()
|
||||
db.refresh(model_base)
|
||||
return model_base
|
||||
|
||||
@staticmethod
|
||||
def delete_model_base(db: Session, model_base_id: uuid.UUID) -> bool:
|
||||
success = ModelBaseRepository.delete(db, model_base_id)
|
||||
if not success:
|
||||
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
|
||||
db.commit()
|
||||
return success
|
||||
|
||||
@staticmethod
|
||||
def add_model_from_plaza(db: Session, model_base_id: uuid.UUID, tenant_id: uuid.UUID) -> ModelConfig:
|
||||
model_base = ModelBaseRepository.get_by_id(db, model_base_id)
|
||||
if not model_base:
|
||||
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
if ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id):
|
||||
raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME)
|
||||
|
||||
model_config_data = {
|
||||
"model_id": model_base_id,
|
||||
"tenant_id": tenant_id,
|
||||
"name": model_base.name,
|
||||
"provider": model_base.provider,
|
||||
"type": model_base.type,
|
||||
"logo": model_base.logo,
|
||||
"description": model_base.description,
|
||||
"is_composite": False
|
||||
}
|
||||
model_config = ModelConfigRepository.create(db, model_config_data)
|
||||
ModelBaseRepository.increment_add_count(db, model_base_id)
|
||||
db.commit()
|
||||
db.refresh(model_config)
|
||||
return model_config
|
||||
|
||||
@@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import MultiAgentConfig, AgentConfig, ModelConfig
|
||||
from app.models.multi_agent_model import AggregationStrategy, OrchestrationMode
|
||||
from app.repositories.model_repository import ModelApiKeyRepository
|
||||
from app.services.agent_registry import AgentRegistry
|
||||
from app.services.master_agent_router import MasterAgentRouter
|
||||
from app.services.conversation_state_manager import ConversationStateManager
|
||||
@@ -2546,10 +2547,14 @@ class MultiAgentOrchestrator:
|
||||
return self._smart_merge_results(results, strategy)
|
||||
|
||||
# 获取 API Key 配置
|
||||
api_key_config = self.db.query(ModelApiKey).filter(
|
||||
ModelApiKey.model_config_id == default_model_config_id,
|
||||
ModelApiKey.is_active.is_(True)
|
||||
).first()
|
||||
# api_key_config = self.db.query(ModelApiKey).join(
|
||||
# ModelConfig, ModelApiKey.model_configs
|
||||
# ).filter(
|
||||
# ModelConfig.id == default_model_config_id,
|
||||
# ModelApiKey.is_active.is_(True)
|
||||
# ).first()
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id)
|
||||
api_key_config = api_keys[0] if api_keys else None
|
||||
|
||||
if not api_key_config:
|
||||
logger.warning("Master Agent 没有可用的 API Key,使用简单整合")
|
||||
@@ -2703,10 +2708,14 @@ class MultiAgentOrchestrator:
|
||||
return
|
||||
|
||||
# 获取 API Key 配置
|
||||
api_key_config = self.db.query(ModelApiKey).filter(
|
||||
ModelApiKey.model_config_id == default_model_config_id,
|
||||
ModelApiKey.is_active.is_(True)
|
||||
).first()
|
||||
# api_key_config = self.db.query(ModelApiKey).join(
|
||||
# ModelConfig, ModelApiKey.model_configs
|
||||
# ).filter(
|
||||
# ModelConfig.id == default_model_config_id,
|
||||
# ModelApiKey.is_active.is_(True)
|
||||
# ).first()
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id)
|
||||
api_key_config = api_keys[0] if api_keys else None
|
||||
|
||||
if not api_key_config:
|
||||
logger.warning("Master Agent 没有可用的 API Key,使用简单整合")
|
||||
|
||||
@@ -4,6 +4,8 @@ import time
|
||||
import asyncio
|
||||
from typing import Optional, Dict, Any, AsyncGenerator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.repositories.model_repository import ModelApiKeyRepository
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.models import ReleaseShare, AppRelease, Conversation
|
||||
from app.services.conversation_service import ConversationService
|
||||
@@ -164,16 +166,20 @@ class SharedChatService:
|
||||
raise ResourceNotFoundException("模型配置", str(model_config_id))
|
||||
|
||||
# 获取 API Key
|
||||
stmt = (
|
||||
select(ModelApiKey)
|
||||
.where(
|
||||
ModelApiKey.model_config_id == model_config_id,
|
||||
ModelApiKey.is_active.is_(True)
|
||||
)
|
||||
.order_by(ModelApiKey.priority.desc())
|
||||
.limit(1)
|
||||
)
|
||||
api_key_obj = self.db.scalars(stmt).first()
|
||||
# stmt = (
|
||||
# select(ModelApiKey).join(
|
||||
# ModelConfig, ModelApiKey.model_configs
|
||||
# )
|
||||
# .where(
|
||||
# ModelConfig.id == model_config_id,
|
||||
# ModelApiKey.is_active.is_(True)
|
||||
# )
|
||||
# .order_by(ModelApiKey.priority.desc())
|
||||
# .limit(1)
|
||||
# )
|
||||
# api_key_obj = self.db.scalars(stmt).first()
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
|
||||
api_key_obj = api_keys[0] if api_keys else None
|
||||
if not api_key_obj:
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
@@ -358,16 +364,20 @@ class SharedChatService:
|
||||
raise ResourceNotFoundException("模型配置", str(model_config_id))
|
||||
|
||||
# 获取 API Key
|
||||
stmt = (
|
||||
select(ModelApiKey)
|
||||
.where(
|
||||
ModelApiKey.model_config_id == model_config_id,
|
||||
ModelApiKey.is_active.is_(True)
|
||||
)
|
||||
.order_by(ModelApiKey.priority.desc())
|
||||
.limit(1)
|
||||
)
|
||||
api_key_obj = self.db.scalars(stmt).first()
|
||||
# stmt = (
|
||||
# select(ModelApiKey).join(
|
||||
# ModelConfig, ModelApiKey.model_configs
|
||||
# )
|
||||
# .where(
|
||||
# ModelConfig.id == model_config_id,
|
||||
# ModelApiKey.is_active.is_(True)
|
||||
# )
|
||||
# .order_by(ModelApiKey.priority.desc())
|
||||
# .limit(1)
|
||||
# )
|
||||
# api_key_obj = self.db.scalars(stmt).first()
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
|
||||
api_key_obj = api_keys[0] if api_keys else None
|
||||
if not api_key_obj:
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
|
||||
@@ -5,11 +5,12 @@
|
||||
"releaseDate": "2026-1-23",
|
||||
"upgradePosition": "\uD83D\uDC3B 本次更新主要优化使用体验和修复已知问题,让系统更稳定、更好用。",
|
||||
"coreUpgrades": [
|
||||
"1. 工作流更好用了\n* 界面更清晰,一眼看懂怎么配置\n* 新增节点输出变量展示,方便其他节点引用\n* 修复了几个影响体验的bug",
|
||||
"2. 智能体配置更简单\n* 提示词和变量联动更顺畅\n* 配置界面重新整理,找功能更方便",
|
||||
"3. 记忆系统更稳定\n* 优化了情绪记忆和隐性记忆的缓存更新\n* 修复了记忆配置页面的报错问题\n* 现在能自动识别用户和AI的身份了",
|
||||
"4. 知识库体验提升\n* 修复了文档解析异常的问题\n* 上传文档时能看到处理进度了\n* 取消了操作也不会报错了",
|
||||
"5. 系统整体更可靠\n* 修复了新用户访问跳转问题\n* 流式接口更稳定,长对话不断线\n* 调整了菜单顺序,操作更顺手\n",
|
||||
"1. 工作流更好用了<br>* 界面更清晰,一眼看懂怎么配置<br>* 新增节点输出变量展示,方便其他节点引用<br>* 修复了几个影响体验的bug",
|
||||
"2. 智能体配置更简单<br>* 提示词和变量联动更顺畅<br>* 配置界面重新整理,找功能更方便",
|
||||
"3. 记忆系统更稳定<br>* 优化了情绪记忆和隐性记忆的缓存更新<br>* 修复了记忆配置页面的报错问题<br>* 现在能自动识别用户和AI的身份了",
|
||||
"4. 知识库体验提升<br>* 修复了文档解析异常的问题<br>* 上传文档时能看到处理进度了<br>* 取消了操作也不会报错了",
|
||||
"5. 系统整体更可靠<br>* 修复了新用户访问跳转问题<br>* 流式接口更稳定,长对话不断线<br>* 调整了菜单顺序,操作更顺手",
|
||||
"<br>",
|
||||
"这次更新虽然不大,但让记忆熊的基础更扎实、体验更流畅。我们继续努力,让AI记忆更好用!",
|
||||
"记忆熊,记得更牢,用得更好。\uD83D\uDC3B✨"
|
||||
]
|
||||
@@ -19,12 +20,13 @@
|
||||
"releaseDate": "2026-1-23",
|
||||
"upgradePosition": "\uD83D\uDC3B This update focuses on improving usability and fixing known issues, making the system more stable and easier to use overall.",
|
||||
"coreUpgrades": [
|
||||
"1. Improved Workflow Experience\nCleaner, more intuitive UI for easier configuration at a glance\nAdded visibility of node output variables, making them easier to reference in downstream nodes\nFixed several usability-related bugs that affected the workflow experience",
|
||||
"2. Simpler Agent Configuration\nSmoother linkage between prompts and variables\nReorganized configuration layout for easier navigation and better clarity",
|
||||
"3. More Stable Memory System\nOptimized cache refresh for emotional memory and implicit memory\nFixed error issues on the memory configuration page\nThe system can now automatically distinguish between user and AI roles",
|
||||
"4. Enhanced Knowledge Base Experience\nFixed issues with document parsing failures\nUpload progress is now displayed during document processing\nCanceling an upload no longer triggers errors",
|
||||
"5. Overall System Reliability Improvements\nFixed redirect issues affecting new users\nImproved stability of streaming APIs to prevent interruptions during long conversations\nAdjusted menu ordering for a smoother and more intuitive workflow\n",
|
||||
"Although this is a relatively small update, it strengthens MemoryBear’s foundation and delivers a noticeably smoother experience.\nWe’ll keep refining the system to make AI memory more powerful and easier to use.",
|
||||
"1. Improved Workflow Experience<br>* Cleaner, more intuitive UI for easier configuration at a glance<br>* Added visibility of node output variables, making them easier to reference in downstream nodes<br>* Fixed several usability-related bugs that affected the workflow experience",
|
||||
"2. Simpler Agent Configuration<br>* Smoother linkage between prompts and variables<br>* Reorganized configuration layout for easier navigation and better clarity",
|
||||
"3. More Stable Memory System<br>* Optimized cache refresh for emotional memory and implicit memory<br>* Fixed error issues on the memory configuration page<br>* The system can now automatically distinguish between user and AI roles",
|
||||
"4. Enhanced Knowledge Base Experience<br>* Fixed issues with document parsing failures<br>* Upload progress is now displayed during document processing<br>* Canceling an upload no longer triggers errors",
|
||||
"5. Overall System Reliability Improvements<br>* Fixed redirect issues affecting new users<br>* Improved stability of streaming APIs to prevent interruptions during long conversations<br>* Adjusted menu ordering for a smoother and more intuitive workflow",
|
||||
"<br>",
|
||||
"Although this is a relatively small update, it strengthens MemoryBear’s foundation and delivers a noticeably smoother experience. We’ll keep refining the system to make AI memory more powerful and easier to use.",
|
||||
"MemoryBear — remember better, work smarter. \uD83D\uDC3B✨"
|
||||
]
|
||||
}
|
||||
@@ -35,10 +37,10 @@
|
||||
"releaseDate": "2026-1-16",
|
||||
"upgradePosition": "本次为架构升级,核心目标是把\"被动存储\"升级为\"主动认知\",让系统具备情绪感知、情景理解与类人记忆机制,为后续多智能体协作与专业场景落地奠定底座。",
|
||||
"coreUpgrades": [
|
||||
"记忆详情:拟人记忆——情绪引擎、情景记忆、短期记忆、工作记忆、感知记忆、显性记忆、隐性记忆,并配套类脑遗忘机制,实现从感知→情绪→情景→长期沉淀的完整人类记忆闭环",
|
||||
"可视化工作流:拖拽式节点编排(LLM、知识库、逻辑、工具),业务落地周期由天缩至小时。",
|
||||
"多模态知识处理:PDF、PPT、MP3、MP4 一键解析,时间感知检索准确率 94.3%,问答对数据即插即用。",
|
||||
"Agent集群内置\"记忆-知识-工具-审核\"四类角色模板,用户一键生成;主控Agent把复杂任务拆为子任务并行分发,再靠情景记忆统一消解冲突、校验一致性,输出完整报告。"
|
||||
"1. 记忆详情:拟人记忆——情绪引擎、情景记忆、短期记忆、工作记忆、感知记忆、显性记忆、隐性记忆,并配套类脑遗忘机制,实现从感知→情绪→情景→长期沉淀的完整人类记忆闭环",
|
||||
"2. 可视化工作流:拖拽式节点编排(LLM、知识库、逻辑、工具),业务落地周期由天缩至小时。",
|
||||
"3. 多模态知识处理:PDF、PPT、MP3、MP4 一键解析,时间感知检索准确率 94.3%,问答对数据即插即用。",
|
||||
"4. Agent集群内置\"记忆-知识-工具-审核\"四类角色模板,用户一键生成;主控Agent把复杂任务拆为子任务并行分发,再靠情景记忆统一消解冲突、校验一致性,输出完整报告。"
|
||||
]
|
||||
},
|
||||
"introduction_en": {
|
||||
@@ -46,10 +48,10 @@
|
||||
"releaseDate": "2026-1-16",
|
||||
"upgradePosition": "This release marks a foundational upgrade to the system’s cognitive architecture. The core objective is to evolve the platform from passive information storage into active cognitive intelligence—enabling emotional awareness, situational understanding, and human-like memory mechanisms. This upgrade lays the groundwork for future multi-agent collaboration and domain-specific, production-grade AI applications.",
|
||||
"coreUpgrades": [
|
||||
"Human-Like Memory Architecture: A comprehensive, human-inspired memory system is introduced, encompassing emotional processing, situational memory, short-term and working memory, perceptual memory, as well as explicit and implicit memory. Combined with brain-inspired forgetting mechanisms, the system now supports a complete cognitive loop—from perception → emotion → context → long-term consolidation, closely mirroring human memory formation.",
|
||||
"Visual Workflow Orchestration: A fully visual, drag-and-drop workflow enables modular composition of LLMs, knowledge bases, logic, and tools. This dramatically reduces the time required to move from experimentation to production—from days to hours.",
|
||||
"Multimodal Knowledge Processing: The system now supports one-click parsing and ingestion of PDF, PPT, MP3, and MP4 content. With time-aware retrieval accuracy reaching 94.3%, structured Q&A data becomes instantly usable for downstream reasoning and generation.",
|
||||
"Built-in Agent Clusters: Predefined role templates across four categories—Memory, Knowledge, Tools, and Review—can be generated with a single click. A Coordinator Agent decomposes complex tasks into parallel subtasks, while situational memory is used to resolve conflicts, validate consistency, and synthesize outputs into a coherent, end-to-end report."
|
||||
"1. Human-Like Memory Architecture: A comprehensive, human-inspired memory system is introduced, encompassing emotional processing, situational memory, short-term and working memory, perceptual memory, as well as explicit and implicit memory. Combined with brain-inspired forgetting mechanisms, the system now supports a complete cognitive loop—from perception → emotion → context → long-term consolidation, closely mirroring human memory formation.",
|
||||
"2. Visual Workflow Orchestration: A fully visual, drag-and-drop workflow enables modular composition of LLMs, knowledge bases, logic, and tools. This dramatically reduces the time required to move from experimentation to production—from days to hours.",
|
||||
"3. Multimodal Knowledge Processing: The system now supports one-click parsing and ingestion of PDF, PPT, MP3, and MP4 content. With time-aware retrieval accuracy reaching 94.3%, structured Q&A data becomes instantly usable for downstream reasoning and generation.",
|
||||
"4. Built-in Agent Clusters: Predefined role templates across four categories—Memory, Knowledge, Tools, and Review—can be generated with a single click. A Coordinator Agent decomposes complex tasks into parallel subtasks, while situational memory is used to resolve conflicts, validate consistency, and synthesize outputs into a coherent, end-to-end report."
|
||||
]
|
||||
}
|
||||
},
|
||||
@@ -59,16 +61,17 @@
|
||||
"releaseDate": "2025-12-01",
|
||||
"upgradePosition": "这是一款专注于管理和利用AI记忆的工具,支持RAG和知识图谱两种主流存储方式,旨在为AI应用提供持久化、结构化的\"记忆\"能力。",
|
||||
"coreUpgrades": [
|
||||
"记忆空间:用户可以创建独立的空间来隔离不同记忆,并灵活选择存储方式。",
|
||||
"记忆配置:简化了配置流程,内置自动提取关键信息的\"记忆萃取\"和管理生命周期的\"遗忘\"引擎。",
|
||||
"知识检索:提供语义、分词和混合三种检索模式,并支持多种参数微调和结果重排序,以提升召回效果。",
|
||||
"全局管理:支持统一设置默认检索参数,并可一键应用到所有知识库。",
|
||||
"测试与调试:内置\"召回测试\"功能,方便用户实时验证检索效果并调整参数,支持通过分享码与他人协作。",
|
||||
"记忆洞察:可查看详细的对话记录、用户画像和分析报告,帮助理解AI的\"记忆\"内容。",
|
||||
"集成与管理:提供API Key用于系统集成,并包含基本的用户管理功能。",
|
||||
"界面与体验:采用现代化的卡片式布局和渐变色设计,注重交互的流畅性和视觉美感。",
|
||||
"起步与使用:文档中提供了清晰的基础使用流程,引导用户从创建空间、配置记忆到测试检索快速上手。",
|
||||
"版本说明与限制: 记忆熊 v0.1.0 版本\"初心\"囊括智能记忆管理的核心思路和基础能力,为后续开发奠定了基础。",
|
||||
"1. 记忆空间:用户可以创建独立的空间来隔离不同记忆,并灵活选择存储方式。",
|
||||
"2. 记忆配置:简化了配置流程,内置自动提取关键信息的\"记忆萃取\"和管理生命周期的\"遗忘\"引擎。",
|
||||
"3. 知识检索:提供语义、分词和混合三种检索模式,并支持多种参数微调和结果重排序,以提升召回效果。",
|
||||
"4. 全局管理:支持统一设置默认检索参数,并可一键应用到所有知识库。",
|
||||
"5. 测试与调试:内置\"召回测试\"功能,方便用户实时验证检索效果并调整参数,支持通过分享码与他人协作。",
|
||||
"6. 记忆洞察:可查看详细的对话记录、用户画像和分析报告,帮助理解AI的\"记忆\"内容。",
|
||||
"7. 集成与管理:提供API Key用于系统集成,并包含基本的用户管理功能。",
|
||||
"8. 界面与体验:采用现代化的卡片式布局和渐变色设计,注重交互的流畅性和视觉美感。",
|
||||
"9. 起步与使用:文档中提供了清晰的基础使用流程,引导用户从创建空间、配置记忆到测试检索快速上手。",
|
||||
"10. 版本说明与限制: 记忆熊 v0.1.0 版本\"初心\"囊括智能记忆管理的核心思路和基础能力,为后续开发奠定了基础。",
|
||||
"<br>",
|
||||
"文档资源:用户手册、API文档、FAQ",
|
||||
"问题反馈:GitHub Issues、邮件支持",
|
||||
"致谢:感谢所有参与测试和提供反馈的用户!"
|
||||
@@ -79,16 +82,17 @@
|
||||
"releaseDate": "2025-12-01",
|
||||
"upgradePosition": "A tool focused on managing and utilizing AI memory, supporting both RAG and knowledge graph storage methods, aiming to provide persistent and structured 'memory' capabilities for AI applications.",
|
||||
"coreUpgrades": [
|
||||
"Memory Space: Users can create independent spaces to isolate different memories and flexibly choose storage methods.",
|
||||
"Memory Configuration: Simplified configuration process with built-in 'memory extraction' for automatic key information extraction and 'forgetting' engine for lifecycle management.",
|
||||
"Knowledge Retrieval: Provides semantic, tokenization, and hybrid retrieval modes with various parameter tuning and result reranking to improve recall.",
|
||||
"Global Management: Supports unified default retrieval parameter settings with one-click application to all knowledge bases.",
|
||||
"Testing & Debugging: Built-in 'recall testing' for real-time verification of retrieval effects and parameter adjustment, with sharing code support for collaboration.",
|
||||
"Memory Insights: View detailed conversation records, user profiles, and analysis reports to understand AI 'memory' content.",
|
||||
"Integration & Management: Provides API Key for system integration with basic user management features.",
|
||||
"Interface & Experience: Modern card-based layout with gradient design, focusing on interaction fluidity and visual aesthetics.",
|
||||
"Getting Started: Documentation provides clear basic usage flow, guiding users from creating spaces, configuring memory to testing retrieval.",
|
||||
"Version Notes: MemoryBear v0.1.0 'Original Intent' encompasses core concepts and basic capabilities of intelligent memory management, laying foundation for future development.",
|
||||
"1. Memory Space: Users can create independent spaces to isolate different memories and flexibly choose storage methods.",
|
||||
"2. Memory Configuration: Simplified configuration process with built-in 'memory extraction' for automatic key information extraction and 'forgetting' engine for lifecycle management.",
|
||||
"3. Knowledge Retrieval: Provides semantic, tokenization, and hybrid retrieval modes with various parameter tuning and result reranking to improve recall.",
|
||||
"4. Global Management: Supports unified default retrieval parameter settings with one-click application to all knowledge bases.",
|
||||
"5. Testing & Debugging: Built-in 'recall testing' for real-time verification of retrieval effects and parameter adjustment, with sharing code support for collaboration.",
|
||||
"6. Memory Insights: View detailed conversation records, user profiles, and analysis reports to understand AI 'memory' content.",
|
||||
"7. Integration & Management: Provides API Key for system integration with basic user management features.",
|
||||
"8. Interface & Experience: Modern card-based layout with gradient design, focusing on interaction fluidity and visual aesthetics.",
|
||||
"9. Getting Started: Documentation provides clear basic usage flow, guiding users from creating spaces, configuring memory to testing retrieval.",
|
||||
"10. Version Notes: MemoryBear v0.1.0 'Original Intent' encompasses core concepts and basic capabilities of intelligent memory management, laying foundation for future development.",
|
||||
"<br>",
|
||||
"Documentation: User Manual, API Documentation, FAQ",
|
||||
"Feedback: GitHub Issues, Email Support",
|
||||
"Acknowledgments: Thanks to all users who participated in testing and provided feedback!"
|
||||
|
||||
224
api/migrations/versions/915bed077f8d_202601281340.py
Normal file
224
api/migrations/versions/915bed077f8d_202601281340.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""202601281340
|
||||
|
||||
Revision ID: 915bed077f8d
|
||||
Revises: 75f0ec80e50b
|
||||
Create Date: 2026-01-28 13:38:49.471560
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '915bed077f8d'
|
||||
down_revision: Union[str, None] = '75f0ec80e50b'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
BACKUP_TABLE_NAME = 'model_api_keys_backup_20260123'
|
||||
|
||||
def get_temp_models():
|
||||
"""创建临时模型,用于迁移过程中查询数据"""
|
||||
metadata = sa.MetaData()
|
||||
|
||||
# 临时ModelApiKey表(仅包含需要的字段)
|
||||
ModelApiKey = sa.Table(
|
||||
'model_api_keys', metadata,
|
||||
sa.Column('id', sa.UUID(), primary_key=True),
|
||||
sa.Column('model_config_id', sa.UUID(), nullable=True),
|
||||
)
|
||||
|
||||
# 临时关联表(和升级脚本创建的表结构一致)
|
||||
ModelConfigApiKeyAssociation = sa.Table(
|
||||
'model_config_api_key_association', metadata,
|
||||
sa.Column('model_config_id', sa.UUID(), nullable=False),
|
||||
sa.Column('api_key_id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
)
|
||||
|
||||
ModelApiKeyBackup = sa.Table(
|
||||
BACKUP_TABLE_NAME, metadata,
|
||||
sa.Column('id', sa.UUID(), primary_key=True),
|
||||
sa.Column('model_name', sa.String(), nullable=False),
|
||||
sa.Column('description', sa.String(), nullable=True),
|
||||
sa.Column('provider', sa.String(), nullable=False),
|
||||
sa.Column('api_key', sa.String(), nullable=False),
|
||||
sa.Column('api_base', sa.String(), nullable=True),
|
||||
sa.Column('config', sa.JSON(), nullable=True),
|
||||
sa.Column('usage_count', sa.String(), default="0"),
|
||||
sa.Column('last_used_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('priority', sa.String(), default="1"),
|
||||
sa.Column('model_config_id', sa.UUID(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), default=True),
|
||||
)
|
||||
|
||||
return ModelApiKey, ModelConfigApiKeyAssociation, ModelApiKeyBackup
|
||||
|
||||
|
||||
def backup_model_api_keys():
|
||||
"""备份model_api_keys表的结构和数据"""
|
||||
connection = op.get_bind()
|
||||
|
||||
# 检查备份表是否已存在
|
||||
result = connection.execute(sa.text(f"""
|
||||
SELECT EXISTS (
|
||||
SELECT FROM information_schema.tables
|
||||
WHERE table_name = '{BACKUP_TABLE_NAME}'
|
||||
);
|
||||
""")).scalar()
|
||||
|
||||
if result:
|
||||
# 备份表已存在,先删除再重建(确保结构一致)
|
||||
op.execute(f"DROP TABLE IF EXISTS {BACKUP_TABLE_NAME};")
|
||||
|
||||
# 直接复制表结构和数据(PostgreSQL专用,一步完成)
|
||||
op.execute(f"""
|
||||
CREATE TABLE {BACKUP_TABLE_NAME} AS
|
||||
SELECT * FROM model_api_keys;
|
||||
""")
|
||||
|
||||
# 统计行数
|
||||
backup_count = connection.execute(sa.text(f"SELECT COUNT(*) FROM {BACKUP_TABLE_NAME}")).scalar()
|
||||
original_count = connection.execute(sa.text("SELECT COUNT(*) FROM model_api_keys")).scalar()
|
||||
|
||||
print(
|
||||
f"已备份model_api_keys表到 {BACKUP_TABLE_NAME} \n"
|
||||
f" 原表数据行数:{original_count} | 备份表数据行数:{backup_count}"
|
||||
)
|
||||
|
||||
# def restore_model_api_keys_from_backup():
|
||||
# """从备份表恢复model_api_keys数据(可选,用于回滚失败时手动恢复)"""
|
||||
# # 1. 清空原表(谨慎使用!)
|
||||
# # op.execute("TRUNCATE TABLE model_api_keys;")
|
||||
#
|
||||
# # 2. 从备份表恢复数据
|
||||
# op.execute(f"""
|
||||
# INSERT INTO model_api_keys
|
||||
# SELECT * FROM {BACKUP_TABLE_NAME}
|
||||
# ON CONFLICT (id) DO UPDATE SET
|
||||
# model_name = EXCLUDED.model_name,
|
||||
# description = EXCLUDED.description,
|
||||
# provider = EXCLUDED.provider,
|
||||
# api_key = EXCLUDED.api_key,
|
||||
# api_base = EXCLUDED.api_base,
|
||||
# config = EXCLUDED.config,
|
||||
# usage_count = EXCLUDED.usage_count,
|
||||
# last_used_at = EXCLUDED.last_used_at,
|
||||
# priority = EXCLUDED.priority,
|
||||
# model_config_id = EXCLUDED.model_config_id,
|
||||
# created_at = EXCLUDED.created_at,
|
||||
# updated_at = EXCLUDED.updated_at,
|
||||
# is_active = EXCLUDED.is_active;
|
||||
# """)
|
||||
# print(f"✅ 已从 {BACKUP_TABLE_NAME} 恢复model_api_keys表数据")
|
||||
|
||||
def upgrade() -> None:
|
||||
backup_model_api_keys()
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('model_bases',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('logo', sa.String(length=255), nullable=True, comment='模型logo图片URL'),
|
||||
sa.Column('name', sa.String(), nullable=False, comment='模型唯一标识(如gpt-3.5-turbo)'),
|
||||
sa.Column('type', sa.String(), nullable=False, comment='模型类型'),
|
||||
sa.Column('provider', sa.String(), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True, comment='模型描述'),
|
||||
sa.Column('is_deprecated', sa.Boolean(), nullable=False, comment='是否弃用'),
|
||||
sa.Column('is_official', sa.Boolean(), nullable=True, comment='是否供应商官方模型(区分自定义)'),
|
||||
sa.Column('tags', sa.ARRAY(sa.String()), nullable=False, comment="模型标签(如['聊天', '创作'])"),
|
||||
sa.Column('add_count', sa.Integer(), nullable=False, comment='模型被用户添加的次数'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('name', 'provider', name='uk_model_name_provider')
|
||||
)
|
||||
op.create_index(op.f('ix_model_bases_id'), 'model_bases', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_model_bases_provider'), 'model_bases', ['provider'], unique=False)
|
||||
op.create_index(op.f('ix_model_bases_type'), 'model_bases', ['type'], unique=False)
|
||||
op.create_table('model_config_api_key_association',
|
||||
sa.Column('model_config_id', sa.UUID(), nullable=False),
|
||||
sa.Column('api_key_id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['api_key_id'], ['model_api_keys.id'], ),
|
||||
sa.ForeignKeyConstraint(['model_config_id'], ['model_configs.id'], ),
|
||||
sa.PrimaryKeyConstraint('model_config_id', 'api_key_id')
|
||||
)
|
||||
op.add_column('model_api_keys', sa.Column('description', sa.String(), nullable=True, comment='备注'))
|
||||
op.add_column('model_configs', sa.Column('model_id', sa.UUID(), nullable=True, comment='基础模型ID'))
|
||||
op.add_column('model_configs', sa.Column('logo', sa.String(length=255), nullable=True, comment='模型logo图片URL'))
|
||||
op.add_column('model_configs', sa.Column('provider', sa.String(), server_default='composite', nullable=False, comment='供应商'))
|
||||
op.add_column('model_configs', sa.Column('is_composite', sa.Boolean(), server_default='true', nullable=False, comment='是否为组合模型'))
|
||||
op.add_column('model_configs', sa.Column('load_balance_strategy', sa.String(), nullable=True, comment='负载均衡策略'))
|
||||
op.create_index(op.f('ix_model_configs_model_id'), 'model_configs', ['model_id'], unique=False)
|
||||
op.create_foreign_key("model_configs_model_id_fkey", 'model_configs', 'model_bases', ['model_id'], ['id'])
|
||||
connection = op.get_bind()
|
||||
ModelApiKey, ModelConfigApiKeyAssociation, _ = get_temp_models()
|
||||
|
||||
# 查询所有有model_config_id的API Key
|
||||
api_keys = connection.execute(
|
||||
sa.select(ModelApiKey.c.id, ModelApiKey.c.model_config_id)
|
||||
.where(ModelApiKey.c.model_config_id.isnot(None))
|
||||
).fetchall()
|
||||
|
||||
# 批量插入到多对多表
|
||||
if api_keys:
|
||||
association_data = [
|
||||
{
|
||||
'model_config_id': row.model_config_id,
|
||||
'api_key_id': row.id
|
||||
}
|
||||
for row in api_keys
|
||||
]
|
||||
connection.execute(ModelConfigApiKeyAssociation.insert(), association_data)
|
||||
op.drop_constraint(op.f('model_api_keys_model_config_id_fkey'), 'model_api_keys', type_='foreignkey')
|
||||
op.drop_column('model_api_keys', 'model_config_id')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint("model_configs_model_id_fkey", 'model_configs', type_='foreignkey')
|
||||
op.drop_index(op.f('ix_model_configs_model_id'), table_name='model_configs')
|
||||
op.drop_column('model_configs', 'load_balance_strategy')
|
||||
op.drop_column('model_configs', 'is_composite')
|
||||
op.drop_column('model_configs', 'provider')
|
||||
op.drop_column('model_configs', 'logo')
|
||||
op.drop_column('model_configs', 'model_id')
|
||||
op.add_column('model_api_keys', sa.Column('model_config_id', sa.UUID(), autoincrement=False, nullable=True, comment='模型配置ID'))
|
||||
connection = op.get_bind()
|
||||
ModelApiKey, ModelConfigApiKeyAssociation, _ = get_temp_models()
|
||||
|
||||
# 查询多对多表中的关联数据(取每个API Key的第一个关联的model_config_id)
|
||||
association_data = connection.execute(
|
||||
sa.select(
|
||||
ModelConfigApiKeyAssociation.c.api_key_id,
|
||||
ModelConfigApiKeyAssociation.c.model_config_id
|
||||
).distinct(ModelConfigApiKeyAssociation.c.api_key_id)
|
||||
).fetchall()
|
||||
|
||||
# 批量更新model_api_keys表
|
||||
if association_data:
|
||||
for api_key_id, model_config_id in association_data:
|
||||
connection.execute(
|
||||
sa.update(ModelApiKey)
|
||||
.where(ModelApiKey.c.id == api_key_id)
|
||||
.values(model_config_id=model_config_id)
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"UPDATE model_api_keys SET model_config_id = '00000000-0000-0000-0000-000000000000' WHERE model_config_id IS NULL")
|
||||
op.alter_column('model_api_keys', 'model_config_id', nullable=False)
|
||||
op.create_foreign_key(op.f('model_api_keys_model_config_id_fkey'), 'model_api_keys', 'model_configs', ['model_config_id'], ['id'])
|
||||
op.drop_column('model_api_keys', 'description')
|
||||
op.drop_table('model_config_api_key_association')
|
||||
# ### 可选:回滚时恢复备份(如需)###
|
||||
# restore_model_api_keys_from_backup()
|
||||
|
||||
print(
|
||||
f"回滚完成!备份表 {BACKUP_TABLE_NAME} 仍保留,如需手动恢复可执行 restore_model_api_keys_from_backup() 函数")
|
||||
op.drop_index(op.f('ix_model_bases_type'), table_name='model_bases')
|
||||
op.drop_index(op.f('ix_model_bases_provider'), table_name='model_bases')
|
||||
op.drop_index(op.f('ix_model_bases_id'), table_name='model_bases')
|
||||
op.drop_table('model_bases')
|
||||
# ### end Alembic commands ###
|
||||
@@ -108,4 +108,8 @@ export const getShareToken = (share_token: string, user_id: string) => {
|
||||
// 复制应用
|
||||
export const copyApplication = (app_id: string, new_name: string) => {
|
||||
return request.post(`/apps/${app_id}/copy?new_name=${new_name}`)
|
||||
}
|
||||
}
|
||||
// 数据统计
|
||||
export const getAppStatistics = (app_id: string, data: { start_date: number; end_date: number; }) => {
|
||||
return request.get(`/apps/${app_id}/statistics`, data)
|
||||
}
|
||||
|
||||
25
web/src/api/fileStorage.ts
Normal file
25
web/src/api/fileStorage.ts
Normal file
@@ -0,0 +1,25 @@
|
||||
import { request, API_PREFIX } from '@/utils/request'
|
||||
|
||||
// Upload file,file storage has expiration period
|
||||
export const fileUploadUrl = `${API_PREFIX}/storage/files`
|
||||
export const fileUpload = (formData?: unknown) => {
|
||||
return request.uploadFile('/storage/files', formData)
|
||||
}
|
||||
|
||||
// Get file access URL (no token required)
|
||||
export const getFileUrl = (file_id: string) => `/storage/files/${file_id}/url`
|
||||
export const getFileLink = (fileId: string, data: { permanent?: boolean } = { permanent: true }) => {
|
||||
return request.get(getFileUrl(fileId), data)
|
||||
}
|
||||
|
||||
// Get file internally
|
||||
export const getInternalFileUrl = (file_id: string) => `/storage/files/${file_id}`
|
||||
export const getInternalFile = (fileId: string) => {
|
||||
return request.get(getInternalFileUrl(fileId))
|
||||
}
|
||||
|
||||
// Delete file
|
||||
export const deleteFileUrl = (file_id: string) => `/storage/files/${file_id}`
|
||||
export const deleteFile = (fileId: string) => {
|
||||
return request.delete(deleteFileUrl(fileId))
|
||||
}
|
||||
@@ -65,7 +65,7 @@ export const getModelTypeList = async () => {
|
||||
};
|
||||
// 获取模型列表
|
||||
export const getModelList = async (pageInfo: PageRequest) => {
|
||||
const response = await request.get(`${apiPrefix}/models`, pageInfo);
|
||||
const response = await request.get(`${apiPrefix}/models`, { ...pageInfo, is_active: true });
|
||||
return response as any;
|
||||
};
|
||||
//获取模型提供者
|
||||
|
||||
@@ -1,23 +1,68 @@
|
||||
import { request } from '@/utils/request'
|
||||
import type { ModelFormData } from '@/views/ModelManagement/types'
|
||||
import type { MultiKeyForm, Query, KeyConfigModalForm, CompositeModelForm, CustomModelForm } from '@/views/ModelManagement/types'
|
||||
|
||||
// 模型列表
|
||||
// Model list
|
||||
export const getModelListUrl = '/models'
|
||||
export const getModelList = (data: { type: string; pagesize: number; page: number; }) => {
|
||||
export const getModelList = (data: Query) => {
|
||||
return request.get(getModelListUrl, data)
|
||||
}
|
||||
// 创建模型
|
||||
export const addModel = (data: ModelFormData) => {
|
||||
return request.post('/models', data)
|
||||
}
|
||||
// 更新模型
|
||||
export const updateModel = (apiKeyId: string, data: ModelFormData) => {
|
||||
return request.put(`/models/apikeys/${apiKeyId}`, data)
|
||||
}
|
||||
// 模型类型列表
|
||||
// Model type list
|
||||
export const modelTypeUrl = '/models/type'
|
||||
// 模型供应商列表
|
||||
// Model provider list
|
||||
export const modelProviderUrl = '/models/provider'
|
||||
export const getModelProviderList = () => {
|
||||
return request.get(modelProviderUrl)
|
||||
}
|
||||
// New model list
|
||||
export const getModelNewListUrl = '/models/new'
|
||||
export const getModelNewList = (data: Query) => {
|
||||
return request.get(getModelNewListUrl, data)
|
||||
}
|
||||
// Get model information
|
||||
export const getModelInfo = (model_id: string) => {
|
||||
return request.get(`/models/${model_id}`)
|
||||
}
|
||||
// Create composite model
|
||||
export const addCompositeModel = (data: CompositeModelForm) => {
|
||||
return request.post('/models/composite', data)
|
||||
}
|
||||
// Update composite model
|
||||
export const updateCompositeModel = (model_id: string, data: CompositeModelForm) => {
|
||||
return request.put(`/models/composite/${model_id}`, data)
|
||||
}
|
||||
// Delete composite model
|
||||
export const deleteCompositeModel = (model_id: string) => {
|
||||
return request.delete(`/models/composite/${model_id}`)
|
||||
}
|
||||
// Create API keys for all matching models by provider
|
||||
export const updateProviderApiKeys = (data: KeyConfigModalForm) => {
|
||||
return request.post('/models/provider/apikeys', data)
|
||||
}
|
||||
// Create model API key
|
||||
export const addModelApiKey = (model_id: string, data: MultiKeyForm) => {
|
||||
return request.post(`/models/${model_id}/apikeys`, data)
|
||||
}
|
||||
// Delete model API key
|
||||
export const deleteModelApiKey = (api_key_id: string) => {
|
||||
return request.delete(`/models/apikeys/${api_key_id}`)
|
||||
}
|
||||
// Update model status
|
||||
export const updateModelStatus = (model_id: string, data: { is_active: boolean; }) => {
|
||||
return request.put(`/models/${model_id}`, data)
|
||||
}
|
||||
// Model plaza list
|
||||
export const getModelPlaza = (data: { search?: string; provider?: string; }) => {
|
||||
return request.get('/models/model_plaza', data)
|
||||
}
|
||||
// Add model to plaza
|
||||
export const addModelPlaza = (model_base_id: string) => {
|
||||
return request.post(`/models/model_plaza/${model_base_id}/add`)
|
||||
}
|
||||
// Create custom model
|
||||
export const addCustomModel = (data: CustomModelForm) => {
|
||||
return request.post('/models/model_plaza', data)
|
||||
}
|
||||
// Update custom model
|
||||
export const updateCustomModel = (model_base_id: string, data: CustomModelForm) => {
|
||||
return request.put(`/models/model_plaza/${model_base_id}`, data)
|
||||
}
|
||||
BIN
web/src/assets/images/empty/pageEmpty.png
Normal file
BIN
web/src/assets/images/empty/pageEmpty.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 157 KiB |
@@ -15,7 +15,7 @@ interface ApiResponse<T> {
|
||||
interface CustomSelectProps extends Omit<SelectProps, 'filterOption'> {
|
||||
url: string;
|
||||
params?: Record<string, unknown>;
|
||||
valueKey?: string | string[];
|
||||
valueKey?: string;
|
||||
labelKey?: string;
|
||||
placeholder?: string;
|
||||
hasAll?: boolean;
|
||||
@@ -66,18 +66,11 @@ const CustomSelect: FC<CustomSelectProps> = ({
|
||||
{...props}
|
||||
>
|
||||
{hasAll && <Select.Option value={null}>{allTitle || t('common.all')}</Select.Option>}
|
||||
{displayOptions.map((option) => {
|
||||
const getValue = () => {
|
||||
if (typeof valueKey === 'string') return option[valueKey];
|
||||
return valueKey.find(key => option[key] != null) ? option[valueKey.find(key => option[key] != null)!] : undefined;
|
||||
};
|
||||
const value = getValue();
|
||||
return (
|
||||
<Select.Option key={value} value={value}>
|
||||
{String(option[labelKey])}
|
||||
</Select.Option>
|
||||
);
|
||||
})}
|
||||
{displayOptions.map((option) => (
|
||||
<Select.Option key={option[valueKey]} value={option[valueKey]}>
|
||||
{String(option[labelKey])}
|
||||
</Select.Option>
|
||||
))}
|
||||
</Select>
|
||||
);
|
||||
};
|
||||
|
||||
16
web/src/components/Empty/PageEmpty.tsx
Normal file
16
web/src/components/Empty/PageEmpty.tsx
Normal file
@@ -0,0 +1,16 @@
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import pageEmptyIcon from '@/assets/images/empty/pageEmpty.png'
|
||||
import Empty from './index'
|
||||
const PageEmpty = ({ size = [240, 210] }: { size?: number | number[] }) => {
|
||||
const { t } = useTranslation()
|
||||
return (
|
||||
<Empty
|
||||
url={pageEmptyIcon}
|
||||
title={t('empty.pageEmpty')}
|
||||
subTitle={t('empty.pageEmptyDesc')}
|
||||
size={size}
|
||||
className="rb:h-full"
|
||||
/>
|
||||
)
|
||||
}
|
||||
export default PageEmpty;
|
||||
13
web/src/components/PageTabs/index.module.css
Normal file
13
web/src/components/PageTabs/index.module.css
Normal file
@@ -0,0 +1,13 @@
|
||||
.page-tabs:global(.ant-segmented) {
|
||||
background-color: rgba(91, 97, 103, 0.08);
|
||||
padding: 4px;
|
||||
}
|
||||
.page-tabs:global(.ant-segmented .ant-segmented-item-label) {
|
||||
line-height: 24px;
|
||||
min-height: 24px;
|
||||
padding: 0 12px;
|
||||
}
|
||||
|
||||
.page-tabs:global(.ant-segmented .ant-segmented-item-selected) {
|
||||
box-shadow: 0px 2px 4px 0px rgba(33, 35, 50, 0.16);
|
||||
}
|
||||
18
web/src/components/PageTabs/index.tsx
Normal file
18
web/src/components/PageTabs/index.tsx
Normal file
@@ -0,0 +1,18 @@
|
||||
import { type FC } from 'react';
|
||||
import { Segmented, type SegmentedProps } from 'antd';
|
||||
import styles from './index.module.css';
|
||||
|
||||
const PageTabs: FC<SegmentedProps> = ({
|
||||
value,
|
||||
options,
|
||||
onChange
|
||||
}) => {
|
||||
return <Segmented
|
||||
value={value}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
className={styles.pageTabs}
|
||||
/>;
|
||||
};
|
||||
|
||||
export default PageTabs;
|
||||
@@ -1,5 +1,5 @@
|
||||
import { type FC, type ReactNode } from 'react'
|
||||
import { Card } from 'antd';
|
||||
import { Card, Tooltip } from 'antd';
|
||||
import clsx from 'clsx';
|
||||
|
||||
interface RbCardProps {
|
||||
@@ -9,7 +9,7 @@ interface RbCardProps {
|
||||
extra?: ReactNode;
|
||||
children?: ReactNode;
|
||||
avatar?: ReactNode;
|
||||
avatarUrl?: string;
|
||||
avatarUrl?: string | null;
|
||||
bodyPadding?: string;
|
||||
bodyClassName?: string;
|
||||
headerType?: 'border' | 'borderless' | 'borderBL' | 'borderL';
|
||||
@@ -63,7 +63,7 @@ const RbCard: FC<RbCardProps> = ({
|
||||
}
|
||||
)
|
||||
}>
|
||||
<div className="rb:w-full rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{title}</div>
|
||||
<Tooltip title={title}><div className="rb:w-full rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{title}</div></Tooltip>
|
||||
{subTitle && <div className="rb:text-[#5B6167] rb:text-[12px]">{subTitle}</div>}
|
||||
</div>
|
||||
</div> : null
|
||||
|
||||
@@ -1,23 +1,23 @@
|
||||
import { useState, useEffect, forwardRef, useImperativeHandle } from 'react';
|
||||
import { Upload, Modal, Image, App } from 'antd';
|
||||
import { Upload, Image, App } from 'antd';
|
||||
import type { GetProp, UploadFile, UploadProps } from 'antd';
|
||||
// import { UploadOutlined, } from '@ant-design/icons';
|
||||
import type { UploadProps as RcUploadProps } from 'antd/es/upload/interface';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import PlusIcon from '@/assets/images/plus.svg'
|
||||
import { cookieUtils } from '@/utils/request'
|
||||
import { fileUploadUrl } from '@/api/fileStorage'
|
||||
import styles from './index.module.less'
|
||||
|
||||
const { confirm } = Modal;
|
||||
|
||||
interface UploadImagesProps extends Omit<UploadProps, 'onChange'> {
|
||||
interface UploadImagesProps extends Omit<UploadProps, 'onChange' | 'fileList'> {
|
||||
/** 上传接口地址 */
|
||||
action?: string;
|
||||
/** 是否支持多选 */
|
||||
multiple?: boolean;
|
||||
/** 已上传的文件列表 */
|
||||
fileList?: UploadFile[];
|
||||
fileList?: UploadFile[] | UploadFile;
|
||||
/** 文件列表变化回调 */
|
||||
onChange?: (fileList: UploadFile[]) => void;
|
||||
onChange?: (fileList?: UploadFile[] | UploadFile) => void;
|
||||
/** 禁用上传 */
|
||||
disabled?: boolean;
|
||||
/** 文件大小限制(MB) */
|
||||
@@ -28,6 +28,7 @@ interface UploadImagesProps extends Omit<UploadProps, 'onChange'> {
|
||||
isAutoUpload?: boolean;
|
||||
/** 最大上传文件数 */
|
||||
maxCount?: number;
|
||||
className?: string;
|
||||
}
|
||||
const ALL_FILE_TYPE: {
|
||||
[key: string]: string;
|
||||
@@ -59,7 +60,7 @@ const getBase64 = (file: FileType): Promise<string> => {
|
||||
* 支持单文件/多文件上传、拖拽上传、文件验证、预览等功能
|
||||
*/
|
||||
const UploadImages = forwardRef<UploadImagesRef, UploadImagesProps>(({
|
||||
action = '/api/upload',
|
||||
action = fileUploadUrl,
|
||||
multiple = false,
|
||||
fileList: propFileList = [],
|
||||
onChange,
|
||||
@@ -68,27 +69,42 @@ const UploadImages = forwardRef<UploadImagesRef, UploadImagesProps>(({
|
||||
fileType = ['png', 'jpg', 'gif'],
|
||||
isAutoUpload = true,
|
||||
maxCount = 1,
|
||||
className = 'rb:size-24! rb:leading-1!',
|
||||
...props
|
||||
}, ref) => {
|
||||
const { t } = useTranslation();
|
||||
const { message } = App.useApp()
|
||||
const [fileList, setFileList] = useState<UploadFile[]>(propFileList);
|
||||
const { message, modal } = App.useApp()
|
||||
const [fileList, setFileList] = useState<UploadFile[]>([]);
|
||||
const [accept, setAccept] = useState<string | undefined>();
|
||||
// const [loading, setLoading] = useState(false);
|
||||
const [previewOpen, setPreviewOpen] = useState(false);
|
||||
const [previewImage, setPreviewImage] = useState('');
|
||||
|
||||
useEffect(() => {
|
||||
if (!Array.isArray(propFileList) && typeof propFileList === 'object') {
|
||||
setFileList([propFileList]);
|
||||
}
|
||||
}, [propFileList])
|
||||
|
||||
const updateValue = (list: UploadFile[]) => {
|
||||
if (maxCount === 1) {
|
||||
onChange?.(list[0])
|
||||
} else {
|
||||
onChange?.(list)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理文件移除
|
||||
const handleRemove = (file: UploadFile) => {
|
||||
confirm({
|
||||
title: '确定要删除此文件吗?',
|
||||
okText: '确定',
|
||||
modal.confirm({
|
||||
title: t('common.confirmRemoveFile'),
|
||||
okText: `${t('common.confirm')}`,
|
||||
okType: 'danger',
|
||||
cancelText: '取消',
|
||||
cancelText: `${t('common.cancel')}`,
|
||||
onOk: () => {
|
||||
const newFileList = fileList.filter((item) => item.uid !== file.uid);
|
||||
setFileList(newFileList);
|
||||
onChange?.(newFileList);
|
||||
updateValue(newFileList)
|
||||
},
|
||||
});
|
||||
return false; // 阻止默认删除行为,由confirm控制
|
||||
@@ -100,7 +116,7 @@ const UploadImages = forwardRef<UploadImagesRef, UploadImagesProps>(({
|
||||
if (fileSize && file.size) {
|
||||
const isLtMaxSize = (file.size / 1024 / 1024) < fileSize;
|
||||
if (!isLtMaxSize) {
|
||||
message.error(`文件大小不能超过 ${fileSize}MB`);
|
||||
message.error(t('common.fileSizeTip', { size: fileSize }));
|
||||
return Upload.LIST_IGNORE;
|
||||
}
|
||||
}
|
||||
@@ -108,7 +124,7 @@ const UploadImages = forwardRef<UploadImagesRef, UploadImagesProps>(({
|
||||
if (accept && accept.length > 0 && file.type) {
|
||||
const isAccept = accept.includes(file.type);
|
||||
if (!isAccept) {
|
||||
message.error(`不支持的文件类型: ${file.type}`);
|
||||
message.error(`${t('common.fileAcceptTip')}${file.type}`);
|
||||
return Upload.LIST_IGNORE;
|
||||
}
|
||||
}
|
||||
@@ -119,7 +135,7 @@ const UploadImages = forwardRef<UploadImagesRef, UploadImagesProps>(({
|
||||
}
|
||||
const newFileList = [...fileList, file];
|
||||
setFileList(newFileList);
|
||||
onChange?.(newFileList);
|
||||
updateValue(newFileList);
|
||||
return Upload.LIST_IGNORE; // 阻止自动上传
|
||||
}
|
||||
|
||||
@@ -129,17 +145,13 @@ const UploadImages = forwardRef<UploadImagesRef, UploadImagesProps>(({
|
||||
// 处理上传状态变化
|
||||
const handleChange: UploadProps['onChange'] = ({ fileList: newFileList }) => {
|
||||
setFileList(newFileList);
|
||||
if (onChange) {
|
||||
onChange(newFileList);
|
||||
}
|
||||
updateValue(newFileList);
|
||||
};
|
||||
|
||||
// 清空已上传文件
|
||||
const clearFiles = () => {
|
||||
setFileList([]);
|
||||
if (onChange) {
|
||||
onChange([]);
|
||||
}
|
||||
updateValue([]);
|
||||
}
|
||||
|
||||
const handlePreview = async (file: UploadFile) => {
|
||||
@@ -167,7 +179,7 @@ const UploadImages = forwardRef<UploadImagesRef, UploadImagesProps>(({
|
||||
fileList,
|
||||
beforeUpload,
|
||||
headers: {
|
||||
authorization: cookieUtils.get('authToken') || '',
|
||||
authorization: `Bearer ${cookieUtils.get('authToken') }`,
|
||||
},
|
||||
onPreview: handlePreview,
|
||||
onRemove: handleRemove,
|
||||
@@ -180,6 +192,7 @@ const UploadImages = forwardRef<UploadImagesRef, UploadImagesProps>(({
|
||||
showRemoveIcon: true,
|
||||
showDownloadIcon: false,
|
||||
},
|
||||
className: `${styles.imageUpload} ${className}`,
|
||||
...props,
|
||||
};
|
||||
|
||||
@@ -193,16 +206,9 @@ const UploadImages = forwardRef<UploadImagesRef, UploadImagesProps>(({
|
||||
<>
|
||||
<Upload
|
||||
{...uploadProps}
|
||||
style={{
|
||||
width: '136px',
|
||||
height: '136px',
|
||||
}}
|
||||
>
|
||||
{fileList.length < maxCount && (
|
||||
<div className="rb:flex rb:flex-wrap rb:items-center rb:justify-center">
|
||||
<img src={PlusIcon} className="rb:w-[32px] rb:h-[32px]" />
|
||||
<div className="rb:mt-[12px] rb:text-[12px] rb:text-[#5B6167] rb:leading-[16px]">{t('common.clickUploadIcon')}</div>
|
||||
</div>
|
||||
<img src={PlusIcon} className="rb:size-7" />
|
||||
)}
|
||||
</Upload>
|
||||
{previewImage && (
|
||||
|
||||
7
web/src/components/Upload/index.module.less
Normal file
7
web/src/components/Upload/index.module.less
Normal file
@@ -0,0 +1,7 @@
|
||||
.image-upload:global(.ant-upload-wrapper.ant-upload-picture-card-wrapper .ant-upload-list.ant-upload-list-picture-card .ant-upload-list-item-container),
|
||||
.image-upload:global(.ant-upload-wrapper.ant-upload-picture-circle-wrapper .ant-upload-list.ant-upload-list-picture-card .ant-upload-list-item-container),
|
||||
.image-upload:global(.ant-upload-wrapper.ant-upload-picture-card-wrapper .ant-upload-list.ant-upload-list-picture-circle .ant-upload-list-item-container),
|
||||
.image-upload:global(.ant-upload-wrapper.ant-upload-picture-circle-wrapper .ant-upload-list.ant-upload-list-picture-circle .ant-upload-list-item-container) {
|
||||
width: 96px;
|
||||
height: 96px;
|
||||
}
|
||||
@@ -419,6 +419,9 @@ export const en = {
|
||||
statusEnabled: 'Available',
|
||||
statusDisabled: 'Unavailable',
|
||||
remove: 'Remove',
|
||||
|
||||
fileSizeTip: 'File size cannot exceed {{size}}MB',
|
||||
fileAcceptTip: 'Unsupported file type:'
|
||||
},
|
||||
model: {
|
||||
searchPlaceholder: 'search model…',
|
||||
@@ -510,6 +513,59 @@ export const en = {
|
||||
gpustack: "Gpustack",
|
||||
bedrock: "Bedrock"
|
||||
},
|
||||
modelNew: {
|
||||
group: 'Model Group',
|
||||
list: 'Model List',
|
||||
square: 'Model Plaza',
|
||||
createGroupModel: 'Create Model Group',
|
||||
groupSearchPlaceholder: 'Search model groups',
|
||||
listSearchPlaceholder: 'Search available models',
|
||||
squareSearchPlaceholder: 'Search platform models',
|
||||
status: 'Model Status',
|
||||
created_at: 'Created At',
|
||||
configureBtn: 'Click to Configure',
|
||||
showModel: 'Show Model',
|
||||
keyConfig: 'Configure KEY',
|
||||
|
||||
modelConfiguration: 'Model Configuration',
|
||||
logo: 'Model LOGO',
|
||||
name: 'Model Name',
|
||||
type: 'Model Type',
|
||||
modelImplement: 'Model Implementation',
|
||||
addImplement: 'Add Implementation',
|
||||
noAuth: 'Unauthorized (Limited to 1 implementation)',
|
||||
implementConfig: 'Configure Model Implementation',
|
||||
provider: 'Model Provider',
|
||||
api_key_ids: 'Select Model',
|
||||
viewAll: 'More',
|
||||
modelCount: 'Total {{count}} models',
|
||||
modelList: 'Model List',
|
||||
added: ' Added',
|
||||
addSuccess: 'Added successfully',
|
||||
model_name: 'Model Name',
|
||||
tags: 'Tags',
|
||||
createCustomModel: 'Add Custom Model',
|
||||
edit: 'Edit',
|
||||
selectOneTip: 'Model API KEY not configured, please configure in Model Plaza first',
|
||||
|
||||
api_key: 'API KEY',
|
||||
api_base: 'API Base URL',
|
||||
description: 'Description',
|
||||
add: 'Add',
|
||||
item: 'item',
|
||||
apiKeyNum: ' API Keys',
|
||||
|
||||
llm: 'LLM',
|
||||
chat: 'Chat',
|
||||
embedding: 'Embedding',
|
||||
rerank: 'Rerank',
|
||||
openai: "Openai",
|
||||
dashscope: "Dashscope",
|
||||
ollama: "Ollama",
|
||||
xinference: "Xinference",
|
||||
gpustack: "Gpustack",
|
||||
bedrock: "Bedrock"
|
||||
},
|
||||
knowledgeBase: {
|
||||
pleaseUploadFileFirst: 'Please upload file first',
|
||||
shareSuccess: 'Share successfully',
|
||||
@@ -1175,6 +1231,12 @@ export const en = {
|
||||
priority: 'Structured Integration',
|
||||
addTool: 'Add Tool',
|
||||
tool: 'Tool',
|
||||
|
||||
statistics: 'Data Statistics',
|
||||
daily_conversations: 'Daily Conversations',
|
||||
daily_new_users: 'Daily New Users',
|
||||
daily_api_calls: 'Daily API Calls',
|
||||
daily_tokens: 'Token Consumption',
|
||||
},
|
||||
userMemory: {
|
||||
userMemory: 'User Memory',
|
||||
@@ -1534,7 +1596,9 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
||||
noPermissionDesc: ' Please contact the administrator to grant permission',
|
||||
tableEmpty: 'No data available.',
|
||||
loadingEmpty: 'The content is loading…',
|
||||
loadingEmptyDesc: 'Your content is on its way by rocket! It will soon land on your screen'
|
||||
loadingEmptyDesc: 'Your content is on its way by rocket! It will soon land on your screen',
|
||||
pageEmpty: 'Oops! No search results available at the moment',
|
||||
pageEmptyDesc: "Red Bear tilts its head and waits for you to change a new keyword, let's explore together.",
|
||||
},
|
||||
apiKey: {
|
||||
name: 'Project Name',
|
||||
|
||||
@@ -658,7 +658,13 @@ export const zh = {
|
||||
priority: '结构化整合',
|
||||
addTool: '添加工具',
|
||||
tool: '工具',
|
||||
variableConfig: '配置变量'
|
||||
variableConfig: '配置变量',
|
||||
|
||||
statistics: '数据统计',
|
||||
daily_conversations: '消息会话数',
|
||||
daily_new_users: '新增用户数',
|
||||
daily_api_calls: '调用次数',
|
||||
daily_tokens: 'Token消耗',
|
||||
},
|
||||
role: {
|
||||
roleManagement: '角色管理',
|
||||
@@ -967,6 +973,9 @@ export const zh = {
|
||||
statusEnabled: '可用',
|
||||
statusDisabled: '不可用',
|
||||
remove: '删除',
|
||||
|
||||
fileSizeTip: '文件大小不能超过 {{size}}MB',
|
||||
fileAcceptTip: '不支持的文件类型:'
|
||||
},
|
||||
product: {
|
||||
applicationManagement: '应用管理',
|
||||
@@ -1076,6 +1085,59 @@ export const zh = {
|
||||
gpustack: "Gpustack",
|
||||
bedrock: "Bedrock"
|
||||
},
|
||||
modelNew: {
|
||||
group: '模型组合',
|
||||
list: '模型列表',
|
||||
square: '模型广场',
|
||||
createGroupModel: '创建模型组合',
|
||||
groupSearchPlaceholder: '搜索模型组合',
|
||||
listSearchPlaceholder: '搜索可用模型',
|
||||
squareSearchPlaceholder: '搜索平台模型',
|
||||
status: '模型状态',
|
||||
created_at: '创建时间',
|
||||
configureBtn: '点击配置',
|
||||
showModel: '显示模型',
|
||||
keyConfig: '配置 KEY',
|
||||
|
||||
modelConfiguration: '模型配置',
|
||||
logo: '模型LOGO',
|
||||
name: '模型名称',
|
||||
type: '模型类型',
|
||||
modelImplement: '模型实现',
|
||||
addImplement: '添加实现',
|
||||
noAuth: '未授权(限1个实现)',
|
||||
implementConfig: '配置模型实现',
|
||||
provider: '模型供应商',
|
||||
api_key_ids: '选择模型',
|
||||
viewAll: '更多',
|
||||
modelCount: '共 {{count}} 个模型',
|
||||
modelList: '模型列表',
|
||||
added: ' 已添加',
|
||||
addSuccess: '添加成功',
|
||||
model_name: '模型名称',
|
||||
tags: '标签',
|
||||
createCustomModel: '添加自定义模型',
|
||||
edit: '编辑',
|
||||
selectOneTip: '模型未配置API KEY,请先在模型广场配置',
|
||||
|
||||
api_key: 'API KEY',
|
||||
api_base: 'API Base URL',
|
||||
description: '描述',
|
||||
add: '添加',
|
||||
item: '个',
|
||||
apiKeyNum: '个 API Key',
|
||||
|
||||
llm: 'LLM',
|
||||
chat: 'Chat',
|
||||
embedding: 'Embedding',
|
||||
rerank: 'Rerank',
|
||||
openai: "Openai",
|
||||
dashscope: "Dashscope",
|
||||
ollama: "Ollama",
|
||||
xinference: "Xinference",
|
||||
gpustack: "Gpustack",
|
||||
bedrock: "Bedrock"
|
||||
},
|
||||
timezones: {
|
||||
'Asia/Shanghai': '中国标准时间 (UTC+8)',
|
||||
'Asia/Kolkata': '印度标准时间 (UTC+5:30)',
|
||||
@@ -1607,7 +1669,9 @@ export const zh = {
|
||||
noPermissionDesc: '请联系管理员授予权限',
|
||||
tableEmpty: '目前没有数据',
|
||||
loadingEmpty: '内容正在加载中…',
|
||||
loadingEmptyDesc: '您的内容正在火箭运输中!很快就会降落在您的屏幕上'
|
||||
loadingEmptyDesc: '您的内容正在火箭运输中!很快就会降落在您的屏幕上',
|
||||
pageEmpty: '哎呀!暂无搜索结果',
|
||||
pageEmptyDesc: '红熊歪着头等待您更换新的关键词,让我们一起探索吧。',
|
||||
},
|
||||
|
||||
home: {
|
||||
|
||||
@@ -22,7 +22,7 @@ export const lightTheme: ThemeConfig = {
|
||||
// colorBgContainer: '#FBFDFF',
|
||||
colorError: '#FF5D34',
|
||||
sizeSM: 12,
|
||||
fontSizeSM: 12,
|
||||
fontSizeSM: 12,
|
||||
},
|
||||
components: {
|
||||
Layout: {
|
||||
@@ -105,6 +105,9 @@ export const lightTheme: ThemeConfig = {
|
||||
},
|
||||
Select: {
|
||||
lineHeightSM: 26
|
||||
},
|
||||
Upload: {
|
||||
pictureCardSize: 96,
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -23,9 +23,10 @@ interface data {
|
||||
}
|
||||
|
||||
|
||||
export const API_PREFIX = '/api'
|
||||
// 创建axios实例
|
||||
const service = axios.create({
|
||||
baseURL: '/api', // 与vite.config.ts中的代理配置对应
|
||||
baseURL: API_PREFIX, // 与vite.config.ts中的代理配置对应
|
||||
// timeout: 10000, // 请求超时时间
|
||||
withCredentials: false,
|
||||
headers: {
|
||||
@@ -126,7 +127,7 @@ service.interceptors.response.use(
|
||||
if (axios.isCancel(error) || error.name === 'AbortError' || error.code === 'ERR_CANCELED') {
|
||||
return Promise.reject(error);
|
||||
}
|
||||
|
||||
|
||||
// 处理网络错误、超时等
|
||||
let msg = error.response?.data?.error || error.response?.error;
|
||||
const status = error?.response ? error.response.status : error;
|
||||
|
||||
@@ -20,7 +20,7 @@ import type {
|
||||
} from './types'
|
||||
import type { Variable } from './components/VariableList/types'
|
||||
import type { KnowledgeConfig } from './components/Knowledge/types'
|
||||
import type { Model } from '@/views/ModelManagement/types'
|
||||
import type { ModelListItem } from '@/views/ModelManagement/types'
|
||||
import { getModelList } from '@/api/models';
|
||||
import { saveAgentConfig } from '@/api/application'
|
||||
import Knowledge from './components/Knowledge/Knowledge'
|
||||
@@ -79,7 +79,7 @@ const SelectWrapper: FC<{ title: string, desc: string, name: string | string[],
|
||||
placeholder={t('common.pleaseSelect')}
|
||||
url={url}
|
||||
hasAll={false}
|
||||
valueKey={['config_id_old', 'config_id']}
|
||||
valueKey='config_id'
|
||||
labelKey="config_name"
|
||||
/>
|
||||
</Form.Item>
|
||||
@@ -96,8 +96,8 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [data, setData] = useState<Config | null>(null);
|
||||
const modelConfigModalRef = useRef<ModelConfigModalRef>(null)
|
||||
const [modelList, setModelList] = useState<Model[]>([])
|
||||
const [defaultModel, setDefaultModel] = useState<Model | null>(null)
|
||||
const [modelList, setModelList] = useState<ModelListItem[]>([])
|
||||
const [defaultModel, setDefaultModel] = useState<ModelListItem | null>(null)
|
||||
const [chatList, setChatList] = useState<ChatData[]>([])
|
||||
const values = Form.useWatch<Config>([], form)
|
||||
const [isSave, setIsSave] = useState(false)
|
||||
@@ -126,14 +126,12 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
|
||||
getApplicationConfig(id as string).then(res => {
|
||||
const response = res as Config
|
||||
let allTools = Array.isArray(response.tools) ? response.tools : []
|
||||
const memoryContent = response.memory?.memory_content
|
||||
const convertedMemoryContent = memoryContent && !isNaN(Number(memoryContent)) ? Number(memoryContent) : memoryContent
|
||||
form.setFieldsValue({
|
||||
...response,
|
||||
tools: allTools,
|
||||
memory: {
|
||||
...response.memory,
|
||||
memory_content: convertedMemoryContent
|
||||
memory_content: response.memory?.memory_content ? Number(response.memory?.memory_content) : undefined
|
||||
}
|
||||
})
|
||||
setData({
|
||||
@@ -214,7 +212,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
|
||||
...data.knowledge_retrieval,
|
||||
...knowledgeRest,
|
||||
knowledge_bases: knowledge_bases.map(item => ({
|
||||
kb_id: item.id,
|
||||
kb_id: item.kb_id || item.id,
|
||||
...(item.config || {})
|
||||
}))
|
||||
} as KnowledgeConfig : null,
|
||||
@@ -239,9 +237,9 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
|
||||
})
|
||||
}
|
||||
const getModels = () => {
|
||||
getModelList({ type: 'llm,chat', pagesize: 100, page: 1 })
|
||||
getModelList({ type: 'llm,chat', pagesize: 100, page: 1, is_active: true })
|
||||
.then(res => {
|
||||
const response = res as { items: Model[] }
|
||||
const response = res as { items: ModelListItem[] }
|
||||
setModelList(response.items)
|
||||
})
|
||||
}
|
||||
@@ -251,7 +249,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
|
||||
useEffect(() => {
|
||||
if (values?.default_model_config_id && modelList.length > 0) {
|
||||
const filterValue = modelList.find(item => item.id === values.default_model_config_id)
|
||||
setDefaultModel(filterValue as Model | null)
|
||||
setDefaultModel(filterValue as ModelListItem | null)
|
||||
setChatList([{
|
||||
label: filterValue?.name || '',
|
||||
model_config_id: filterValue?.id || '',
|
||||
|
||||
@@ -225,7 +225,7 @@ const Cluster = forwardRef<ClusterRef>((_props, ref) => {
|
||||
<Form.Item name="default_model_config_id" noStyle>
|
||||
<CustomSelect
|
||||
url={getModelListUrl}
|
||||
params={{ type: 'llm,chat', pagesize: 100 }}
|
||||
params={{ type: 'llm,chat', pagesize: 100, is_active: true }}
|
||||
valueKey="id"
|
||||
labelKey="name"
|
||||
hasAll={false}
|
||||
|
||||
86
web/src/views/ApplicationConfig/Statistics.tsx
Normal file
86
web/src/views/ApplicationConfig/Statistics.tsx
Normal file
@@ -0,0 +1,86 @@
|
||||
import { type FC, useState, useEffect } from 'react';
|
||||
import { Row, Col, Flex, DatePicker } from 'antd';
|
||||
import type { Dayjs } from 'dayjs'
|
||||
import dayjs from 'dayjs';
|
||||
|
||||
const { RangePicker } = DatePicker;
|
||||
|
||||
import type { Application } from '@/views/ApplicationManagement/types'
|
||||
import { getAppStatistics } from '@/api/application';
|
||||
import LineCard from './components/LineCard'
|
||||
import type { StatisticsData, StatisticsItem } from './types'
|
||||
|
||||
const TotalObj: Record<string, keyof StatisticsData> = {
|
||||
daily_conversations: 'total_conversations',
|
||||
daily_new_users: 'total_new_users',
|
||||
daily_api_calls: 'total_api_calls',
|
||||
daily_tokens: 'total_tokens',
|
||||
}
|
||||
const Statistics: FC<{ application: Application | null }> = ({ application }) => {
|
||||
const [data, setData] = useState<StatisticsData>({
|
||||
daily_conversations: [],
|
||||
total_conversations: 0,
|
||||
daily_new_users: [],
|
||||
total_new_users: 0,
|
||||
daily_api_calls: [],
|
||||
total_api_calls: 0,
|
||||
daily_tokens: [],
|
||||
total_tokens: 0
|
||||
})
|
||||
const [query, setQuery] = useState({
|
||||
start_date: dayjs().subtract(6, 'd'),
|
||||
end_date: dayjs().subtract(0, 'd'),
|
||||
})
|
||||
|
||||
useEffect(() => {
|
||||
getData()
|
||||
}, [application, query])
|
||||
const getData = () => {
|
||||
if (!application?.id) {
|
||||
return
|
||||
}
|
||||
const params = {
|
||||
start_date: query.start_date.startOf('d').valueOf(),
|
||||
end_date: query.end_date.endOf('d').valueOf(),
|
||||
}
|
||||
|
||||
getAppStatistics(application.id, params)
|
||||
.then(res => {
|
||||
setData(res as StatisticsData)
|
||||
})
|
||||
}
|
||||
const handleChange = (date: [Dayjs | null, Dayjs | null] | null) => {
|
||||
if (!date || !date[0] || !date[1]) return
|
||||
setQuery({
|
||||
start_date: date[0],
|
||||
end_date: date[1],
|
||||
})
|
||||
}
|
||||
return (
|
||||
<div className="rb:w-250 rb:mt-5 rb:pb-5 rb:mx-auto">
|
||||
<Row gutter={[16, 16]}>
|
||||
<Col span={24}>
|
||||
<Flex justify="end">
|
||||
<RangePicker defaultValue={[query.start_date, query.end_date]} onChange={handleChange} />
|
||||
</Flex>
|
||||
</Col>
|
||||
{Object.entries(data).map(([key, value]) => {
|
||||
if (key.includes('total')) {
|
||||
return null
|
||||
}
|
||||
const totalKey = TotalObj[key];
|
||||
return (
|
||||
<Col span={12} key={key}>
|
||||
<LineCard
|
||||
type={key}
|
||||
total={totalKey ? (data[totalKey] as number) : 0}
|
||||
chartData={value as StatisticsItem[]}
|
||||
/>
|
||||
</Col>
|
||||
)
|
||||
})}
|
||||
</Row>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
export default Statistics;
|
||||
@@ -181,7 +181,7 @@ const AiPromptModal = forwardRef<AiPromptModalRef, AiPromptModalProps>(({
|
||||
>
|
||||
<CustomSelect
|
||||
url={getModelListUrl}
|
||||
params={{ type: 'llm,chat', pagesize: 100 }}
|
||||
params={{ type: 'llm,chat', pagesize: 100, is_active: true }}
|
||||
valueKey="id"
|
||||
labelKey="name"
|
||||
hasAll={false}
|
||||
|
||||
@@ -17,7 +17,7 @@ import CopyModal from './CopyModal'
|
||||
|
||||
const { Header } = Layout;
|
||||
|
||||
const tabKeys = ['arrangement', 'api', 'release']
|
||||
const tabKeys = ['arrangement', 'api', 'release', 'statistics']
|
||||
const menuIcons: Record<string, string> = {
|
||||
edit: editIcon,
|
||||
copy: copyIcon,
|
||||
|
||||
@@ -97,7 +97,7 @@ const KnowledgeGlobalConfigModal = forwardRef<KnowledgeGlobalConfigModalRef, Kno
|
||||
>
|
||||
<CustomSelect
|
||||
url={getModelListUrl}
|
||||
params={{ type: 'rerank', pagesize: 100 }}
|
||||
params={{ type: 'rerank', pagesize: 100, is_active: true }}
|
||||
valueKey="id"
|
||||
labelKey="name"
|
||||
hasAll={false}
|
||||
|
||||
127
web/src/views/ApplicationConfig/components/LineCard.tsx
Normal file
127
web/src/views/ApplicationConfig/components/LineCard.tsx
Normal file
@@ -0,0 +1,127 @@
|
||||
import { type FC, useEffect, useRef } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import ReactEcharts from 'echarts-for-react';
|
||||
import * as echarts from 'echarts';
|
||||
import Empty from '@/components/Empty'
|
||||
|
||||
import Card from './Card'
|
||||
import type { StatisticsItem } from '../types'
|
||||
|
||||
interface LineCardProps {
|
||||
chartData: StatisticsItem[];
|
||||
type: string;
|
||||
total: number;
|
||||
}
|
||||
|
||||
const SeriesConfig = {
|
||||
type: 'line',
|
||||
stack: 'Total',
|
||||
smooth: true,
|
||||
lineStyle: {
|
||||
width: 3
|
||||
},
|
||||
showSymbol: true,
|
||||
label: {
|
||||
show: false,
|
||||
position: 'top'
|
||||
},
|
||||
emphasis: {
|
||||
focus: 'series'
|
||||
},
|
||||
}
|
||||
|
||||
const ColorObj: Record<string, string> = {
|
||||
daily_conversations: '#FFB048',
|
||||
daily_new_users: '#4DA8FF',
|
||||
daily_api_calls: '#155EEF',
|
||||
daily_tokens: '#AD88FF'
|
||||
}
|
||||
|
||||
const LineCard: FC<LineCardProps> = ({ chartData, type, total }) => {
|
||||
const { t } = useTranslation()
|
||||
const chartRef = useRef<ReactEcharts>(null);
|
||||
|
||||
useEffect(() => {
|
||||
|
||||
}, [chartData])
|
||||
|
||||
const getSeries = () => {
|
||||
return [{
|
||||
...SeriesConfig,
|
||||
name: t(`application.${type}`),
|
||||
data: chartData.map(vo => vo.count),
|
||||
areaStyle: {
|
||||
opacity: 0.8,
|
||||
color: new echarts.graphic.LinearGradient(0, 0, 0, 1, [
|
||||
{ offset: 0, color: ColorObj[type] },
|
||||
{ offset: 1, color: '#FFFFFF' }
|
||||
])
|
||||
},
|
||||
}]
|
||||
}
|
||||
|
||||
return (
|
||||
<Card
|
||||
title={<div>{t(`application.${type}`)} <span className="rb:text-[#155EEF] rb:font-medium rb:text-[18px]">{total}</span></div>}
|
||||
>
|
||||
{chartData && chartData.length > 0 ? (
|
||||
<ReactEcharts
|
||||
ref={chartRef}
|
||||
option={{
|
||||
color: [ColorObj[type]],
|
||||
tooltip: {
|
||||
trigger: 'axis',
|
||||
extraCssText: 'box-shadow: 0px 2px 6px 0px rgba(33,35,50,0.16); border-radius: 8px;',
|
||||
axisPointer: {
|
||||
type: 'line',
|
||||
crossStyle: {
|
||||
color: '#5F6266',
|
||||
},
|
||||
lineStyle: {
|
||||
color: '#5F6266',
|
||||
},
|
||||
label: {
|
||||
show: false
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
grid: {
|
||||
top: 10,
|
||||
left: 15,
|
||||
right: 40,
|
||||
bottom: 0,
|
||||
containLabel: true
|
||||
},
|
||||
xAxis: {
|
||||
type: 'category',
|
||||
data: chartData.map(item => item.date),
|
||||
boundaryGap: false,
|
||||
},
|
||||
yAxis: {
|
||||
type: 'value',
|
||||
axisLabel: {
|
||||
color: '#A8A9AA',
|
||||
fontFamily: 'PingFangSC, PingFang SC',
|
||||
align: 'right',
|
||||
lineHeight: 17,
|
||||
},
|
||||
axisLine: {
|
||||
lineStyle: {
|
||||
color: '#EBEBEB',
|
||||
}
|
||||
},
|
||||
},
|
||||
series: getSeries()
|
||||
}}
|
||||
style={{ height: '265px', width: '100%', minWidth: '100%', boxSizing: 'border-box' }}
|
||||
opts={{ renderer: 'canvas' }}
|
||||
notMerge={true}
|
||||
lazyUpdate={true}
|
||||
/>
|
||||
) : <Empty size={120} className="rb:mt-12 rb:mb-20.25" />}
|
||||
</Card>
|
||||
)
|
||||
}
|
||||
|
||||
export default LineCard
|
||||
@@ -9,6 +9,7 @@ import ReleasePage from './ReleasePage'
|
||||
import Cluster from './Cluster'
|
||||
import { getApplication } from '@/api/application'
|
||||
import Workflow from '@/views/Workflow';
|
||||
import Statistics from './Statistics'
|
||||
|
||||
const ApplicationConfig: React.FC = () => {
|
||||
const { id } = useParams();
|
||||
@@ -68,6 +69,7 @@ const ApplicationConfig: React.FC = () => {
|
||||
{activeTab === 'arrangement' && application?.type === 'workflow' && <Workflow ref={workflowRef} />}
|
||||
{activeTab === 'api' && <Api application={application} />}
|
||||
{activeTab === 'release' && <ReleasePage data={application as Application} refresh={getApplicationInfo} />}
|
||||
{activeTab === 'statistics' && <Statistics application={application} />}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -150,4 +150,19 @@ export interface AiPromptForm {
|
||||
}
|
||||
export interface ChatVariableConfigModalRef {
|
||||
handleOpen: (values: Variable[]) => void;
|
||||
}
|
||||
|
||||
export interface StatisticsItem {
|
||||
count: number;
|
||||
date: string;
|
||||
}
|
||||
export interface StatisticsData {
|
||||
daily_conversations: StatisticsItem[];
|
||||
daily_new_users: StatisticsItem[];
|
||||
daily_api_calls: StatisticsItem[];
|
||||
daily_tokens: StatisticsItem[];
|
||||
total_conversations: number;
|
||||
total_new_users: number;
|
||||
total_api_calls: number;
|
||||
total_tokens: number;
|
||||
}
|
||||
@@ -20,7 +20,7 @@ const configList = [
|
||||
key: 'emotion_model_id',
|
||||
type: 'customSelect',
|
||||
url: getModelListUrl,
|
||||
params: { type: 'chat,llm', page: 1, pagesize: 100 }, // chat,llm
|
||||
params: { type: 'chat,llm', page: 1, pagesize: 100, is_active: true }, // chat,llm
|
||||
},
|
||||
{
|
||||
key: 'emotion_min_intensity',
|
||||
|
||||
@@ -39,7 +39,7 @@ const MemberManagement: React.FC = () => {
|
||||
onOk: () => {
|
||||
deleteMember(member.id)
|
||||
.then(() => {
|
||||
message.success(t('member.deleteSuccess'));
|
||||
message.success(t('common.deleteSuccess'));
|
||||
refreshTable();
|
||||
})
|
||||
}
|
||||
@@ -93,7 +93,7 @@ const MemberManagement: React.FC = () => {
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="rb:flex rb:justify-end rb:mb-[12px]">
|
||||
<div className="rb:flex rb:justify-end rb:mb-3">
|
||||
<Button type="primary" onClick={() => handleEdit()}>
|
||||
{t('member.createMember')}
|
||||
</Button>
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import { type FC, useState, useEffect } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useParams } from 'react-router-dom'
|
||||
import { Row, Col, Space, Switch, Select, InputNumber, Slider, App, Form } from 'antd'
|
||||
import { Row, Col, Space, Select, InputNumber, Slider, App, Form } from 'antd'
|
||||
import clsx from 'clsx'
|
||||
import Card from './components/Card'
|
||||
import type { ConfigForm, Variable } from './types'
|
||||
import { getMemoryExtractionConfig, updateMemoryExtractionConfig } from '@/api/memory'
|
||||
import Markdown from '@/components/Markdown'
|
||||
import { getModelList } from '@/api/models';
|
||||
import type { Model } from '@/views/ModelManagement/types'
|
||||
import type { ModelListItem } from '@/views/ModelManagement/types'
|
||||
import { configList } from './constant'
|
||||
import Result from './components/Result'
|
||||
import SwitchFormItem from '@/components/FormItem/SwitchFormItem'
|
||||
@@ -43,7 +43,7 @@ const MemoryExtractionEngine: FC = () => {
|
||||
const values = Form.useWatch<ConfigForm>([], form)
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [iterationPeriodDisabled, setIterationPeriodDisabled] = useState(false)
|
||||
const [modelList, setModelList] = useState<Model[]>([])
|
||||
const [modelList, setModelList] = useState<ModelListItem[]>([])
|
||||
|
||||
useEffect(() => {
|
||||
if (values?.reflexion_range === 'database') {
|
||||
@@ -55,9 +55,9 @@ const MemoryExtractionEngine: FC = () => {
|
||||
}, [values])
|
||||
|
||||
const getModels = () => {
|
||||
getModelList({ type: 'llm,chat', pagesize: 100, page: 1 })
|
||||
getModelList({ type: 'llm,chat', pagesize: 100, page: 1, is_active: true })
|
||||
.then(res => {
|
||||
const response = res as { items: Model[] }
|
||||
const response = res as { items: ModelListItem[] }
|
||||
setModelList(response.items)
|
||||
})
|
||||
}
|
||||
|
||||
97
web/src/views/ModelManagement/Group.tsx
Normal file
97
web/src/views/ModelManagement/Group.tsx
Normal file
@@ -0,0 +1,97 @@
|
||||
import { useState, useEffect, forwardRef, useImperativeHandle } from 'react';
|
||||
import clsx from 'clsx'
|
||||
import { Button } from 'antd'
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { ProviderModelItem, ModelListItem, DescriptionItem, BaseRef } from './types'
|
||||
import RbCard from '@/components/RbCard/Card'
|
||||
import { getModelNewList } from '@/api/models'
|
||||
import PageEmpty from '@/components/Empty/PageEmpty';
|
||||
import { formatDateTime } from '@/utils/format';
|
||||
|
||||
const Group = forwardRef <BaseRef,{ query: any; handleEdit: (data: ModelListItem) => void; }>(({ query, handleEdit }, ref) => {
|
||||
const { t } = useTranslation();
|
||||
const [list, setList] = useState<ModelListItem[]>([])
|
||||
useEffect(() => {
|
||||
getList()
|
||||
}, [query])
|
||||
const getList = () => {
|
||||
getModelNewList({
|
||||
...query,
|
||||
is_composite: true,
|
||||
is_active: true,
|
||||
})
|
||||
.then(res => {
|
||||
const response = res as ProviderModelItem[]
|
||||
setList(response[0]?.models || [])
|
||||
})
|
||||
}
|
||||
const formatData = (data: ModelListItem) => {
|
||||
return [
|
||||
{
|
||||
key: 'type',
|
||||
label: t(`modelNew.type`),
|
||||
children: data.type || '-',
|
||||
},
|
||||
{
|
||||
key: 'provider',
|
||||
label: t(`modelNew.provider`),
|
||||
children: data.provider || '-',
|
||||
},
|
||||
{
|
||||
key: 'is_active',
|
||||
label: t(`modelNew.status`),
|
||||
children: data.is_active ? t(`common.statusEnabled`) : t(`common.statusDisabled`),
|
||||
},
|
||||
{
|
||||
key: 'created_at',
|
||||
label: t(`modelNew.created_at`),
|
||||
children: data.created_at ? formatDateTime(data.created_at, 'YYYY-MM-DD HH:mm:ss') : '-',
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
useImperativeHandle(ref, () => ({
|
||||
getList,
|
||||
}));
|
||||
|
||||
return (
|
||||
<>
|
||||
{list.length === 0
|
||||
? <PageEmpty />
|
||||
:(
|
||||
<div className="rb:grid rb:grid-cols-4 rb:gap-4">
|
||||
{list.map(item => (
|
||||
<RbCard
|
||||
key={item.id}
|
||||
title={item.name}
|
||||
avatarUrl={item.logo}
|
||||
avatar={
|
||||
<div className="rb:w-12 rb:h-12 rb:rounded-lg rb:mr-3.25 rb:bg-[#155eef] rb:flex rb:items-center rb:justify-center rb:text-[28px] rb:text-[#ffffff]">
|
||||
{item.name[0]}
|
||||
</div>
|
||||
}
|
||||
>
|
||||
{formatData(item)?.map((description: DescriptionItem) => (
|
||||
<div
|
||||
key={description.key}
|
||||
className="rb:flex rb:justify-between rb:text-[#5B6167] rb:text-[14px] rb:leading-5 rb:mb-3"
|
||||
>
|
||||
<span className="rb:whitespace-nowrap">{(description.label as string)}</span>
|
||||
<span className={clsx({
|
||||
"rb:text-[#212332]": description.key !== 'is_active',
|
||||
"rb:text-[#369F21] rb:font-medium": description.key === 'is_active' && item.is_active,
|
||||
})}>{(description.children as string)}</span>
|
||||
</div>
|
||||
))}
|
||||
<Button className="rb:mt-2" type="primary" ghost block onClick={() => handleEdit(item)}>{t('modelNew.configureBtn')}</Button>
|
||||
</RbCard>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
</>
|
||||
)
|
||||
})
|
||||
|
||||
export default Group
|
||||
83
web/src/views/ModelManagement/List.tsx
Normal file
83
web/src/views/ModelManagement/List.tsx
Normal file
@@ -0,0 +1,83 @@
|
||||
import { useRef, useState, useEffect, type FC } from 'react';
|
||||
import { Button, Space, Row, Col } from 'antd'
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { ProviderModelItem, KeyConfigModalRef, ModelListDetailRef } from './types'
|
||||
import RbCard from '@/components/RbCard/Card'
|
||||
import { getModelNewList } from '@/api/models'
|
||||
import PageEmpty from '@/components/Empty/PageEmpty';
|
||||
import Tag from '@/components/Tag';
|
||||
import KeyConfigModal from './components/KeyConfigModal'
|
||||
import ModelListDetail from './components/ModelListDetail'
|
||||
|
||||
const ModelList: FC<{ query: any }> = ({ query }) => {
|
||||
const { t } = useTranslation();
|
||||
const keyConfigModalRef = useRef<KeyConfigModalRef>(null)
|
||||
const modelListDetailRef = useRef<ModelListDetailRef>(null)
|
||||
const [list, setList] = useState<ProviderModelItem[]>([])
|
||||
useEffect(() => {
|
||||
getList()
|
||||
}, [query])
|
||||
const getList = () => {
|
||||
getModelNewList({
|
||||
...query,
|
||||
is_composite: false,
|
||||
is_active: true,
|
||||
})
|
||||
.then(res => {
|
||||
setList((res || []) as ProviderModelItem[])
|
||||
})
|
||||
}
|
||||
|
||||
const handleShowModel = (vo: ProviderModelItem) => {
|
||||
modelListDetailRef.current?.handleOpen(vo)
|
||||
}
|
||||
const handleKeyConfig = (vo: ProviderModelItem) => {
|
||||
keyConfigModalRef.current?.handleOpen(vo)
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
{list.length === 0
|
||||
? <PageEmpty />
|
||||
:(
|
||||
<div className="rb:grid rb:grid-cols-4 rb:gap-4">
|
||||
{list.map(item => (
|
||||
<RbCard
|
||||
key={item.provider}
|
||||
title={item.provider}
|
||||
avatarUrl={item.logo}
|
||||
avatar={
|
||||
<div className="rb:w-12 rb:h-12 rb:rounded-lg rb:mr-3.25 rb:bg-[#155eef] rb:flex rb:items-center rb:justify-center rb:text-[28px] rb:text-[#ffffff]">
|
||||
{item.provider[0]}
|
||||
</div>
|
||||
}
|
||||
>
|
||||
<Space>{item.tags.map(tag => <Tag key={tag}>{t(`modelNew.${tag}`)}</Tag>)}</Space>
|
||||
<Row gutter={12} className="rb:mt-4">
|
||||
<Col span={12}>
|
||||
<Button block onClick={() => handleShowModel(item)}>{t('modelNew.showModel')}</Button>
|
||||
</Col>
|
||||
<Col span={12}>
|
||||
<Button type="primary" ghost block onClick={() => handleKeyConfig(item)}>{t('modelNew.keyConfig')}</Button>
|
||||
</Col>
|
||||
</Row>
|
||||
</RbCard>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
<KeyConfigModal
|
||||
ref={keyConfigModalRef}
|
||||
refresh={getList}
|
||||
/>
|
||||
<ModelListDetail
|
||||
ref={modelListDetailRef}
|
||||
refresh={getList}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export default ModelList
|
||||
95
web/src/views/ModelManagement/Square.tsx
Normal file
95
web/src/views/ModelManagement/Square.tsx
Normal file
@@ -0,0 +1,95 @@
|
||||
import { useRef, useState, useEffect, forwardRef, useImperativeHandle } from 'react';
|
||||
import { Button, Space, App, Divider, Flex } from 'antd'
|
||||
import { UsergroupAddOutlined } from '@ant-design/icons';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { ModelPlaza, ModelPlazaItem, ModelSquareDetailRef, BaseRef } from './types'
|
||||
import RbCard from '@/components/RbCard/Card'
|
||||
import { getModelPlaza, addModelPlaza } from '@/api/models'
|
||||
import PageEmpty from '@/components/Empty/PageEmpty';
|
||||
import Tag from '@/components/Tag';
|
||||
import ModelSquareDetail from './components/ModelSquareDetail'
|
||||
|
||||
const ModelSquare = forwardRef <BaseRef, { query: any; handleEdit: (vo?: ModelPlazaItem) => void; }>(({ query, handleEdit }, ref) => {
|
||||
const { t } = useTranslation();
|
||||
const { message } = App.useApp()
|
||||
const modelSquareDetailRef = useRef<ModelSquareDetailRef>(null)
|
||||
const [list, setList] = useState<ModelPlaza[]>([])
|
||||
useEffect(() => {
|
||||
getList()
|
||||
}, [query])
|
||||
const getList = () => {
|
||||
getModelPlaza(query)
|
||||
.then(res => {
|
||||
setList((res as ModelPlaza[]) || [])
|
||||
})
|
||||
}
|
||||
|
||||
const handleMore = (vo: ModelPlaza) => {
|
||||
modelSquareDetailRef.current?.handleOpen(vo)
|
||||
}
|
||||
const handleAdd = (item: ModelPlazaItem) => {
|
||||
addModelPlaza(item.id)
|
||||
.then(() => {
|
||||
message.success(`${item.name}${t('modelNew.addSuccess')}`)
|
||||
getList()
|
||||
})
|
||||
}
|
||||
|
||||
useImperativeHandle(ref, () => ({
|
||||
getList,
|
||||
}));
|
||||
return (
|
||||
<>
|
||||
{list.length === 0
|
||||
? <PageEmpty />
|
||||
: list.map(vo => (
|
||||
<div key={vo.provider}>
|
||||
<div className="rb:flex rb:justify-between rb:items-center rb:bg-[rgba(21,94,239,0.12)] rb:px-4 rb:py-2.5 rb:leading-5 rb:mb-4 rb:mt-6 rb:rounded-md">
|
||||
<div className="rb:font-medium">{vo.provider}</div>
|
||||
<Button type="link" onClick={() => handleMore(vo)}>{t('modelNew.viewAll')}({t(`modelNew.modelCount`, { count: vo.models.length })})></Button>
|
||||
</div>
|
||||
|
||||
<div className="rb:grid rb:grid-cols-3 rb:gap-4">
|
||||
{vo.models.slice(0, 6).map(item => (
|
||||
<RbCard
|
||||
key={item.id}
|
||||
title={item.name}
|
||||
avatarUrl={item.logo}
|
||||
avatar={
|
||||
<div className="rb:w-12 rb:h-12 rb:rounded-lg rb:mr-3.25 rb:bg-[#155eef] rb:flex rb:items-center rb:justify-center rb:text-[28px] rb:text-[#ffffff]">
|
||||
{item.name[0]}
|
||||
</div>
|
||||
}
|
||||
>
|
||||
<Tag>{t(`modelNew.${item.type}`)}</Tag>
|
||||
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.5 rb:mt-3 rb:h-9">{item.description}</div>
|
||||
<Space size={8} className="rb:mt-3">{item.tags.map((tag, tagIndex) => <Tag key={tagIndex}>{tag}</Tag>)}</Space>
|
||||
<Divider size="middle" />
|
||||
<Flex justify="space-between">
|
||||
<Space size={8}><UsergroupAddOutlined /> {item.add_count}</Space>
|
||||
<Space>
|
||||
{!item.is_official && <Button type="primary" disabled={item.is_deprecated} onClick={() => handleEdit(item)}>{t('modelNew.edit')}</Button>}
|
||||
{item.is_added
|
||||
? <Button type="primary" disabled>{t('modelNew.added')}</Button>
|
||||
: <Button type="primary" ghost disabled={item.is_deprecated} onClick={() => handleAdd(item)}>+ {t('common.add')}</Button>
|
||||
}
|
||||
</Space>
|
||||
</Flex>
|
||||
</RbCard>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
))
|
||||
}
|
||||
|
||||
<ModelSquareDetail
|
||||
ref={modelSquareDetailRef}
|
||||
refresh={getList}
|
||||
handleEdit={handleEdit}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
})
|
||||
|
||||
export default ModelSquare
|
||||
@@ -1,171 +0,0 @@
|
||||
import { forwardRef, useImperativeHandle, useState } from 'react';
|
||||
import { Form, Input, App } from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { ModelFormData, Model, ConfigModalRef, ConfigModalProps } from '../types';
|
||||
import RbModal from '@/components/RbModal'
|
||||
import CustomSelect from '@/components/CustomSelect'
|
||||
import { updateModel, addModel, modelTypeUrl, modelProviderUrl } from '@/api/models'
|
||||
|
||||
const ConfigModal = forwardRef<ConfigModalRef, ConfigModalProps>(({
|
||||
refresh
|
||||
}, ref) => {
|
||||
const { t } = useTranslation();
|
||||
const { message } = App.useApp();
|
||||
const [visible, setVisible] = useState(false);
|
||||
const [model, setModel] = useState<Model>({} as Model);
|
||||
const [isEdit, setIsEdit] = useState(false);
|
||||
const [form] = Form.useForm<ModelFormData>();
|
||||
const [loading, setLoading] = useState(false)
|
||||
|
||||
const values = Form.useWatch<ModelFormData>([], form);
|
||||
|
||||
// 封装取消方法,添加关闭弹窗逻辑
|
||||
const handleClose = () => {
|
||||
setModel({} as Model);
|
||||
form.resetFields();
|
||||
setLoading(false)
|
||||
setVisible(false);
|
||||
};
|
||||
|
||||
const handleOpen = (model?: Model) => {
|
||||
if (model) {
|
||||
setIsEdit(true);
|
||||
setModel(model);
|
||||
// 设置表单值
|
||||
const apiKeyInfo = model.api_keys[0]
|
||||
form.setFieldsValue({
|
||||
provider: apiKeyInfo.provider,
|
||||
model_name: apiKeyInfo.model_name,
|
||||
api_key: apiKeyInfo.api_key,
|
||||
api_base: apiKeyInfo.api_base
|
||||
});
|
||||
} else {
|
||||
setIsEdit(false);
|
||||
form.resetFields();
|
||||
}
|
||||
setVisible(true);
|
||||
};
|
||||
// 封装保存方法,添加提交逻辑
|
||||
const handleSave = () => {
|
||||
form
|
||||
.validateFields()
|
||||
.then(() => {
|
||||
const data = {
|
||||
name: values.name,
|
||||
type: values.type,
|
||||
api_keys: {
|
||||
provider: values.provider,
|
||||
model_name: values.model_name,
|
||||
api_key: values.api_key,
|
||||
api_base: values.api_base
|
||||
},
|
||||
}
|
||||
setLoading(true)
|
||||
const res = isEdit
|
||||
? updateModel(model.api_keys[0].id, {
|
||||
provider: values.provider,
|
||||
model_name: values.model_name,
|
||||
api_key: values.api_key,
|
||||
api_base: values.api_base
|
||||
} as ModelFormData)
|
||||
: addModel(data as ModelFormData)
|
||||
|
||||
res.then(() => {
|
||||
if (refresh) {
|
||||
refresh();
|
||||
}
|
||||
handleClose()
|
||||
message.success(isEdit ? t('common.updateSuccess') : t('common.createSuccess'))
|
||||
})
|
||||
.catch(() => {
|
||||
setLoading(false)
|
||||
});
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log('err', err)
|
||||
});
|
||||
}
|
||||
|
||||
// 暴露给父组件的方法
|
||||
useImperativeHandle(ref, () => ({
|
||||
handleOpen,
|
||||
handleClose
|
||||
}));
|
||||
|
||||
return (
|
||||
<RbModal
|
||||
title={isEdit ? `${model.name} - ${t('model.modelConfiguration')}` : t('model.createModel')}
|
||||
open={visible}
|
||||
onCancel={handleClose}
|
||||
okText={t(`common.${isEdit ? 'save' : 'create'}`)}
|
||||
onOk={handleSave}
|
||||
confirmLoading={loading}
|
||||
>
|
||||
<Form
|
||||
form={form}
|
||||
layout="vertical"
|
||||
initialValues={{}}
|
||||
>
|
||||
{!isEdit && (
|
||||
<>
|
||||
<Form.Item
|
||||
name="name"
|
||||
label={t('model.displayName')}
|
||||
rules={[{ required: true, message: t('common.inputPlaceholder', { title: t('model.displayName') }) }]}
|
||||
>
|
||||
<Input placeholder={t('common.pleaseEnter')} />
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
name="type"
|
||||
label={t('model.type')}
|
||||
rules={[{ required: true, message: t('common.selectPlaceholder', { title: t('model.type') }) }]}
|
||||
>
|
||||
<CustomSelect
|
||||
url={modelTypeUrl}
|
||||
hasAll={false}
|
||||
format={(items) => items.map((item) => ({ label: t(`model.${item}`), value: item }))}
|
||||
/>
|
||||
</Form.Item>
|
||||
</>
|
||||
)}
|
||||
|
||||
|
||||
<Form.Item
|
||||
name="provider"
|
||||
label={t('model.provider')}
|
||||
rules={[{ required: true, message: t('common.selectPlaceholder', { title: t('model.provider') }) }]}
|
||||
>
|
||||
<CustomSelect
|
||||
url={modelProviderUrl}
|
||||
hasAll={false}
|
||||
format={(items) => items.map((item) => ({ label: t(`model.${item}`), value: item }))}
|
||||
/>
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
name="model_name"
|
||||
label={t('model.modelName')}
|
||||
rules={[{ required: true, message: t('common.inputPlaceholder', { title: t('model.modelName') }) }]}
|
||||
>
|
||||
<Input placeholder={t('common.pleaseEnter')} />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="api_key"
|
||||
label={t('model.apiKey')}
|
||||
rules={[{ required: true, message: t('common.inputPlaceholder', { title: t('model.apiKey') }) }]}
|
||||
>
|
||||
<Input.Password placeholder={t('common.pleaseEnter')} />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="api_base"
|
||||
label={t('model.apiEndpoint')}
|
||||
>
|
||||
<Input placeholder="https://api.example.com/v1" />
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</RbModal>
|
||||
);
|
||||
});
|
||||
|
||||
export default ConfigModal;
|
||||
165
web/src/views/ModelManagement/components/CustomModelModal.tsx
Normal file
165
web/src/views/ModelManagement/components/CustomModelModal.tsx
Normal file
@@ -0,0 +1,165 @@
|
||||
import { forwardRef, useImperativeHandle, useState } from 'react';
|
||||
import { Form, Input, App, Select } from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { CustomModelForm, ModelPlazaItem, CustomModelModalRef, CustomModelModalProps } from '../types';
|
||||
import RbModal from '@/components/RbModal'
|
||||
import CustomSelect from '@/components/CustomSelect'
|
||||
import UploadImages from '@/components/Upload/UploadImages'
|
||||
import { updateCustomModel, addCustomModel, modelTypeUrl, modelProviderUrl } from '@/api/models'
|
||||
import { getFileLink } from '@/api/fileStorage'
|
||||
|
||||
const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(({
|
||||
refresh
|
||||
}, ref) => {
|
||||
const { t } = useTranslation();
|
||||
const { message } = App.useApp();
|
||||
const [visible, setVisible] = useState(false);
|
||||
const [model, setModel] = useState<ModelPlazaItem>({} as ModelPlazaItem);
|
||||
const [isEdit, setIsEdit] = useState(false);
|
||||
const [form] = Form.useForm<CustomModelForm>();
|
||||
const [loading, setLoading] = useState(false)
|
||||
const formValues = Form.useWatch([], form)
|
||||
|
||||
const handleClose = () => {
|
||||
setModel({} as ModelPlazaItem);
|
||||
form.resetFields();
|
||||
setLoading(false)
|
||||
setVisible(false);
|
||||
};
|
||||
|
||||
const handleOpen = (model?: ModelPlazaItem) => {
|
||||
if (model) {
|
||||
setIsEdit(true);
|
||||
setModel(model);
|
||||
form.setFieldsValue({
|
||||
...model,
|
||||
logo: model.logo ? { url: model.logo, uid: model.logo, status: 'done', name: 'logo' } : undefined
|
||||
});
|
||||
} else {
|
||||
setIsEdit(false);
|
||||
form.resetFields();
|
||||
}
|
||||
setVisible(true);
|
||||
};
|
||||
const handleUpdate = (data: CustomModelForm) => {
|
||||
setLoading(true)
|
||||
const res = isEdit ? updateCustomModel(model.id, data) : addCustomModel(data)
|
||||
|
||||
res.then(() => {
|
||||
refresh && refresh()
|
||||
handleClose()
|
||||
message.success(isEdit ? t('common.updateSuccess') : t('common.createSuccess'))
|
||||
})
|
||||
.catch(() => {
|
||||
setLoading(false)
|
||||
});
|
||||
}
|
||||
const handleSave = () => {
|
||||
form
|
||||
.validateFields()
|
||||
.then((values) => {
|
||||
setLoading(true)
|
||||
const { logo, ...rest } = values;
|
||||
let formData: CustomModelForm = {
|
||||
...rest
|
||||
}
|
||||
formData.is_official = false;
|
||||
|
||||
if (typeof logo === 'object' && logo?.response?.data.file_id) {
|
||||
getFileLink(logo?.response?.data.file_id)
|
||||
.then(res => {
|
||||
const logoRes = res as { url: string }
|
||||
formData.logo = logoRes.url
|
||||
handleUpdate(formData)
|
||||
})
|
||||
.catch(() => {
|
||||
handleUpdate(formData)
|
||||
})
|
||||
} else {
|
||||
formData.logo = typeof logo === 'string' ? logo : logo.url
|
||||
handleUpdate(formData)
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log('err', err)
|
||||
});
|
||||
}
|
||||
|
||||
useImperativeHandle(ref, () => ({
|
||||
handleOpen,
|
||||
}));
|
||||
|
||||
console.log('formValues', formValues)
|
||||
|
||||
return (
|
||||
<RbModal
|
||||
title={isEdit ? `${model.name} - ${t('modelNew.modelConfiguration')}` : t('modelNew.createCustomModel')}
|
||||
open={visible}
|
||||
onCancel={handleClose}
|
||||
okText={t(`common.${isEdit ? 'save' : 'create'}`)}
|
||||
onOk={handleSave}
|
||||
confirmLoading={loading}
|
||||
>
|
||||
<Form
|
||||
form={form}
|
||||
layout="vertical"
|
||||
>
|
||||
<Form.Item
|
||||
name="logo"
|
||||
label={t('modelNew.logo')}
|
||||
valuePropName="fileList"
|
||||
rules={[{ required: true, message: t('common.pleaseSelect') }]}
|
||||
>
|
||||
<UploadImages />
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
name="name"
|
||||
label={t('modelNew.name')}
|
||||
rules={[{ required: true, message: t('common.inputPlaceholder', { title: t('modelNew.name') }) }]}
|
||||
>
|
||||
<Input placeholder={t('common.pleaseEnter')} />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="type"
|
||||
label={t('modelNew.type')}
|
||||
rules={[{ required: true, message: t('common.selectPlaceholder', { title: t('modelNew.type') }) }]}
|
||||
>
|
||||
<CustomSelect
|
||||
url={modelTypeUrl}
|
||||
hasAll={false}
|
||||
format={(items) => items.map((item) => ({ label: t(`modelNew.${item}`), value: String(item) }))}
|
||||
/>
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="provider"
|
||||
label={t('modelNew.provider')}
|
||||
rules={[{ required: true, message: t('common.selectPlaceholder', { title: t('modelNew.provider') }) }]}
|
||||
>
|
||||
<CustomSelect
|
||||
url={modelProviderUrl}
|
||||
hasAll={false}
|
||||
format={(items) => items.map((item) => ({ label: t(`modelNew.${item}`), value: String(item) }))}
|
||||
/>
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="description"
|
||||
label={t('modelNew.description')}
|
||||
>
|
||||
<Input.TextArea placeholder={t('common.pleaseEnter')} />
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
name="tags"
|
||||
label={t('modelNew.tags')}
|
||||
>
|
||||
<Select mode="tags" placeholder={t('common.pleaseEnter')} />
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</RbModal>
|
||||
);
|
||||
});
|
||||
|
||||
export default CustomModelModal;
|
||||
158
web/src/views/ModelManagement/components/GroupModelModal.tsx
Normal file
158
web/src/views/ModelManagement/components/GroupModelModal.tsx
Normal file
@@ -0,0 +1,158 @@
|
||||
import { forwardRef, useImperativeHandle, useState } from 'react';
|
||||
import { Form, Input, App } from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { ModelListItem, CompositeModelForm, GroupModelModalRef, GroupModelModalProps, ModelApiKey } from '../types';
|
||||
import RbModal from '@/components/RbModal'
|
||||
import CustomSelect from '@/components/CustomSelect'
|
||||
import { updateCompositeModel, modelTypeUrl, addCompositeModel } from '@/api/models'
|
||||
import UploadImages from '@/components/Upload/UploadImages'
|
||||
import ModelImplement from './ModelImplement'
|
||||
import { getFileLink } from '@/api/fileStorage'
|
||||
|
||||
const GroupModelModal = forwardRef<GroupModelModalRef, GroupModelModalProps>(({
|
||||
refresh
|
||||
}, ref) => {
|
||||
const { t } = useTranslation();
|
||||
const { message } = App.useApp();
|
||||
const [visible, setVisible] = useState(false);
|
||||
const [model, setModel] = useState<ModelListItem>({} as ModelListItem);
|
||||
const [isEdit, setIsEdit] = useState(false);
|
||||
const [form] = Form.useForm<CompositeModelForm>();
|
||||
const [loading, setLoading] = useState(false)
|
||||
const type = Form.useWatch(['type'], form)
|
||||
|
||||
const handleClose = () => {
|
||||
setModel({} as ModelListItem);
|
||||
form.resetFields();
|
||||
setLoading(false)
|
||||
setVisible(false);
|
||||
};
|
||||
|
||||
const handleOpen = (model?: ModelListItem) => {
|
||||
if (model) {
|
||||
setIsEdit(true);
|
||||
setModel(model);
|
||||
form.setFieldsValue({
|
||||
...model,
|
||||
api_key_ids: model.api_keys,
|
||||
logo: model.logo ? { url: model.logo, uid: model.logo, status: 'done', name: 'logo' } : undefined
|
||||
})
|
||||
} else {
|
||||
setIsEdit(false);
|
||||
form.resetFields();
|
||||
}
|
||||
setVisible(true);
|
||||
};
|
||||
const handleSave = () => {
|
||||
form
|
||||
.validateFields()
|
||||
.then((values) => {
|
||||
const { api_key_ids = [], logo, ...rest } = values
|
||||
|
||||
const formData: CompositeModelForm = {
|
||||
...rest,
|
||||
api_key_ids: api_key_ids.map(vo => (vo as ModelApiKey).id)
|
||||
}
|
||||
|
||||
if (logo?.response?.data.file_id) {
|
||||
getFileLink(logo?.response?.data.file_id).then(res => {
|
||||
const logoRes = res as { url: string }
|
||||
formData.logo = logoRes.url
|
||||
handleUpdate(formData)
|
||||
}).catch(() => {
|
||||
handleUpdate(formData)
|
||||
})
|
||||
} else {
|
||||
formData.logo = typeof logo === 'string' ? logo : logo.url
|
||||
handleUpdate(formData)
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log('err', err)
|
||||
});
|
||||
}
|
||||
|
||||
const handleUpdate = (data: CompositeModelForm) => {
|
||||
setLoading(true)
|
||||
const res = isEdit
|
||||
? updateCompositeModel(model.id, data)
|
||||
: addCompositeModel(data)
|
||||
|
||||
res.then(() => {
|
||||
refresh?.();
|
||||
handleClose()
|
||||
message.success(isEdit ? t('common.updateSuccess') : t('common.createSuccess'))
|
||||
})
|
||||
.catch(() => {
|
||||
setLoading(false)
|
||||
});
|
||||
}
|
||||
|
||||
useImperativeHandle(ref, () => ({
|
||||
handleOpen,
|
||||
handleClose
|
||||
}));
|
||||
|
||||
return (
|
||||
<RbModal
|
||||
title={isEdit ? `${model.name} - ${t('modelNew.modelConfiguration')}` : t('modelNew.createGroupModel')}
|
||||
open={visible}
|
||||
onCancel={handleClose}
|
||||
okText={t(`common.${isEdit ? 'save' : 'create'}`)}
|
||||
onOk={handleSave}
|
||||
confirmLoading={loading}
|
||||
>
|
||||
<Form
|
||||
form={form}
|
||||
layout="vertical"
|
||||
>
|
||||
<Form.Item
|
||||
name="logo"
|
||||
label={t('modelNew.logo')}
|
||||
valuePropName="fileList"
|
||||
rules={[{ required: true, message: t('common.pleaseSelect') }]}
|
||||
>
|
||||
<UploadImages />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="name"
|
||||
label={t('modelNew.name')}
|
||||
rules={[{ required: true, message: t('common.pleaseEnter') }]}
|
||||
>
|
||||
<Input placeholder={t('common.pleaseEnter')} />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="type"
|
||||
label={t('modelNew.type')}
|
||||
rules={[{ required: true, message: t('common.selectPlaceholder', { title: t('modelNew.type') }) }]}
|
||||
>
|
||||
<CustomSelect
|
||||
url={modelTypeUrl}
|
||||
hasAll={false}
|
||||
format={(items) => items.map((item) => ({
|
||||
label: t(`modelNew.${typeof item === 'object' ? item.value : item}`),
|
||||
value: typeof item === 'object' ? item.value : item
|
||||
}))}
|
||||
disabled={isEdit}
|
||||
/>
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="description"
|
||||
label={t('modelNew.description')}
|
||||
>
|
||||
<Input.TextArea placeholder={t('common.pleaseEnter')} />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item name="api_key_ids">
|
||||
<ModelImplement type={type} />
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</RbModal>
|
||||
);
|
||||
});
|
||||
|
||||
export default GroupModelModal;
|
||||
92
web/src/views/ModelManagement/components/KeyConfigModal.tsx
Normal file
92
web/src/views/ModelManagement/components/KeyConfigModal.tsx
Normal file
@@ -0,0 +1,92 @@
|
||||
import { forwardRef, useImperativeHandle, useState } from 'react';
|
||||
import { Form, Input, App } from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { KeyConfigModalForm, ProviderModelItem, KeyConfigModalRef, KeyConfigModalProps } from '../types';
|
||||
import RbModal from '@/components/RbModal'
|
||||
import { updateProviderApiKeys } from '@/api/models'
|
||||
|
||||
const KeyConfigModal = forwardRef<KeyConfigModalRef, KeyConfigModalProps>(({
|
||||
refresh
|
||||
}, ref) => {
|
||||
const { t } = useTranslation();
|
||||
const { message } = App.useApp();
|
||||
const [visible, setVisible] = useState(false);
|
||||
const [model, setModel] = useState<ProviderModelItem>({} as ProviderModelItem);
|
||||
const [form] = Form.useForm<KeyConfigModalForm>();
|
||||
const [loading, setLoading] = useState(false)
|
||||
|
||||
const handleClose = () => {
|
||||
setModel({} as ProviderModelItem);
|
||||
form.resetFields();
|
||||
setLoading(false)
|
||||
setVisible(false);
|
||||
};
|
||||
|
||||
const handleOpen = (vo: ProviderModelItem) => {
|
||||
setVisible(true);
|
||||
setModel(vo);
|
||||
};
|
||||
const handleSave = () => {
|
||||
form
|
||||
.validateFields()
|
||||
.then((values) => {
|
||||
setLoading(true)
|
||||
|
||||
updateProviderApiKeys({
|
||||
...values,
|
||||
provider: model.provider
|
||||
}).then(() => {
|
||||
if (refresh) {
|
||||
refresh();
|
||||
}
|
||||
handleClose()
|
||||
message.success(t('common.updateSuccess'))
|
||||
})
|
||||
.catch(() => {
|
||||
setLoading(false)
|
||||
});
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log('err', err)
|
||||
});
|
||||
}
|
||||
|
||||
useImperativeHandle(ref, () => ({
|
||||
handleOpen,
|
||||
handleClose
|
||||
}));
|
||||
|
||||
return (
|
||||
<RbModal
|
||||
title={`${model.provider} - ${t('modelNew.keyConfig')}`}
|
||||
open={visible}
|
||||
onCancel={handleClose}
|
||||
okText={t(`common.save`)}
|
||||
onOk={handleSave}
|
||||
confirmLoading={loading}
|
||||
>
|
||||
<Form
|
||||
form={form}
|
||||
layout="vertical"
|
||||
>
|
||||
<Form.Item
|
||||
name="api_key"
|
||||
label={t('modelNew.api_key')}
|
||||
rules={[{ required: true, message: t('common.inputPlaceholder', { title: t('modelNew.apiKey') }) }]}
|
||||
>
|
||||
<Input.Password placeholder={t('common.pleaseEnter')} />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="api_base"
|
||||
label={t('modelNew.api_base')}
|
||||
rules={[{ required: true, message: t('common.inputPlaceholder', { title: t('modelNew.api_base') }) }]}
|
||||
>
|
||||
<Input placeholder="https://api.example.com/v1" />
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</RbModal>
|
||||
);
|
||||
});
|
||||
|
||||
export default KeyConfigModal;
|
||||
@@ -0,0 +1,164 @@
|
||||
import { forwardRef, useImperativeHandle, useState } from 'react';
|
||||
import { Form, Cascader, App } from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { SubModelModalForm, SubModelModalRef, SubModelModalProps, ModelList } from './types';
|
||||
import RbModal from '@/components/RbModal'
|
||||
import CustomSelect from '@/components/CustomSelect'
|
||||
import { modelProviderUrl, getModelNewList } from '@/api/models'
|
||||
import type { ProviderModelItem } from '../../types'
|
||||
|
||||
const { SHOW_CHILD } = Cascader;
|
||||
|
||||
interface Option {
|
||||
value: string | number;
|
||||
label: string;
|
||||
children?: Option[];
|
||||
[key: string]: any;
|
||||
}
|
||||
const SubModelModal = forwardRef<SubModelModalRef, SubModelModalProps>(({
|
||||
refresh,
|
||||
type
|
||||
}, ref) => {
|
||||
const { t } = useTranslation();
|
||||
const { message } = App.useApp()
|
||||
const [visible, setVisible] = useState(false);
|
||||
const [form] = Form.useForm<SubModelModalForm>();
|
||||
const [selecteds, setSelecteds] = useState<any[]>([])
|
||||
const [modelList, setModelList] = useState<Option[]>([])
|
||||
|
||||
// 封装取消方法,添加关闭弹窗逻辑
|
||||
const handleClose = () => {
|
||||
form.resetFields();
|
||||
setVisible(false);
|
||||
setSelecteds([])
|
||||
};
|
||||
|
||||
const handleOpen = (list?: ModelList[], provider?: string) => {
|
||||
if (list?.length && provider) {
|
||||
const initialValue: SubModelModalForm = {
|
||||
provider,
|
||||
api_key_ids: list.map(vo => {
|
||||
return [vo.model_config_ids[0], vo.id]
|
||||
})
|
||||
}
|
||||
|
||||
form.setFieldsValue(initialValue);
|
||||
handleChangeProvider(provider, initialValue.api_key_ids)
|
||||
} else {
|
||||
form.resetFields()
|
||||
}
|
||||
setVisible(true);
|
||||
};
|
||||
// 封装保存方法,添加提交逻辑
|
||||
const handleSave = () => {
|
||||
form
|
||||
.validateFields()
|
||||
.then(() => {
|
||||
refresh?.(selecteds.map(vo => ({
|
||||
...vo[0],
|
||||
model_name: vo[0].name,
|
||||
model_config_ids: [vo[0].id],
|
||||
id: vo[1].value
|
||||
})))
|
||||
handleClose()
|
||||
})
|
||||
}
|
||||
const handleChange = (value: (string | number)[][], selectedOptions: Option[][]) => {
|
||||
const filterList = selectedOptions.filter(vo => vo.length === 1).map(item => item[0])
|
||||
const lastFilterLit = value.filter(vo => vo.length !== 1)
|
||||
console.log('onchange', value, lastFilterLit, selectedOptions, filterList)
|
||||
if (filterList.length) {
|
||||
message.warning(`【${filterList.map(vo => vo.label)}】${t('modelNew.selectOneTip')}`)
|
||||
form.setFieldValue('api_key_ids', lastFilterLit)
|
||||
}
|
||||
setSelecteds(selectedOptions)
|
||||
}
|
||||
|
||||
const handleChangeProvider = (provider: string, api_key_ids?: any[]) => {
|
||||
form.setFieldValue('api_key_ids', undefined)
|
||||
getModelNewList({
|
||||
provider: provider,
|
||||
is_composite: false,
|
||||
is_active: true,
|
||||
type
|
||||
})
|
||||
.then(res => {
|
||||
const response = res as ProviderModelItem[]
|
||||
const list = response[0]?.models || []
|
||||
setModelList(list.map(vo => {
|
||||
const children = vo.api_keys.map(item => ({
|
||||
label: item.api_key,
|
||||
value: item.id,
|
||||
}))
|
||||
return {
|
||||
...vo,
|
||||
label: vo.name,
|
||||
value: vo.id,
|
||||
children: children
|
||||
}
|
||||
}))
|
||||
|
||||
if (api_key_ids?.length) {
|
||||
form.setFieldsValue({
|
||||
api_key_ids: api_key_ids
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// 暴露给父组件的方法
|
||||
useImperativeHandle(ref, () => ({
|
||||
handleOpen,
|
||||
}));
|
||||
|
||||
return (
|
||||
<RbModal
|
||||
title={t('modelNew.implementConfig')}
|
||||
open={visible}
|
||||
onCancel={handleClose}
|
||||
okText={t('common.save')}
|
||||
onOk={handleSave}
|
||||
>
|
||||
<Form
|
||||
form={form}
|
||||
layout="vertical"
|
||||
>
|
||||
<Form.Item
|
||||
name="provider"
|
||||
label={t('modelNew.provider')}
|
||||
rules={[{ required: true, message: t('common.selectPlaceholder', { title: t('modelNew.provider') }) }]}
|
||||
>
|
||||
<CustomSelect
|
||||
placeholder={t('common.pleaseSelect')}
|
||||
url={modelProviderUrl}
|
||||
hasAll={false}
|
||||
format={(items) => items.map((item) => ({
|
||||
label: t(`modelNew.${typeof item === 'object' ? item.value : item}`),
|
||||
value: typeof item === 'object' ? item.value : item
|
||||
}))}
|
||||
onChange={(value) => handleChangeProvider(value)}
|
||||
/>
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
name="api_key_ids"
|
||||
label={t('modelNew.api_key_ids')}
|
||||
rules={[{ required: true, message: t('common.selectPlaceholder', { title: t('modelNew.api_key_ids') }) }]}
|
||||
>
|
||||
<Cascader
|
||||
placeholder={t('common.pleaseSelect')}
|
||||
options={modelList}
|
||||
onChange={handleChange}
|
||||
multiple
|
||||
autoClearSearchValue
|
||||
className="rb:w-full!"
|
||||
showCheckedStrategy={SHOW_CHILD}
|
||||
changeOnSelect
|
||||
/>
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</RbModal>
|
||||
);
|
||||
});
|
||||
|
||||
export default SubModelModal;
|
||||
@@ -0,0 +1,106 @@
|
||||
import { type FC, useRef } from "react";
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { Flex, Button, Space, App } from 'antd'
|
||||
|
||||
import type { SubModelModalRef, ModelList } from './types'
|
||||
import SubModelModal from './SubModelModal'
|
||||
import Empty from '@/components/Empty'
|
||||
import Tag from '@/components/Tag'
|
||||
|
||||
interface ModelImplementProps {
|
||||
type?: string;
|
||||
value?: any;
|
||||
onChange?: (value: any) => void;
|
||||
}
|
||||
const ModelImplement: FC<ModelImplementProps> = ({ type, value, onChange }) => {
|
||||
const { t } = useTranslation();
|
||||
const { modal, message } = App.useApp();
|
||||
const subModelModalRef = useRef<SubModelModalRef>(null)
|
||||
|
||||
const handleAdd = () => {
|
||||
if (!type || type.trim() === '') {
|
||||
message.warning(t('common.selectPlaceholder', { title: t('modelNew.type') }))
|
||||
return
|
||||
}
|
||||
subModelModalRef.current?.handleOpen()
|
||||
}
|
||||
const handleEdit = (list: ModelList[], provider: string ) => {
|
||||
subModelModalRef.current?.handleOpen(list, provider)
|
||||
}
|
||||
const handleDelete = (provider: string) => {
|
||||
modal.confirm({
|
||||
title: t('common.confirmDeleteDesc', { name: provider }),
|
||||
content: t('application.apiKeyDeleteContent'),
|
||||
okText: t('common.delete'),
|
||||
cancelText: t('common.cancel'),
|
||||
okType: 'danger',
|
||||
onOk: () => {
|
||||
onChange?.(value?.filter((item: any) => item.provider !== provider))
|
||||
}
|
||||
})
|
||||
}
|
||||
const handleRefresh = (list: ModelList[]) => {
|
||||
const existingModels = value || [];
|
||||
let updatedModels = [...existingModels];
|
||||
|
||||
const provider = list[0].provider
|
||||
|
||||
updatedModels = updatedModels.filter(item => item.provider !== provider)
|
||||
updatedModels = [...updatedModels, ...list]
|
||||
|
||||
onChange?.([...updatedModels]);
|
||||
}
|
||||
|
||||
const groupedByProvider: Record<string, ModelList[]> = (value || []).reduce((acc: Record<string, ModelList[]>, item: ModelList) => {
|
||||
const provider = item.provider || 'unknown';
|
||||
if (!acc[provider]) acc[provider] = [];
|
||||
acc[provider].push(item);
|
||||
return acc;
|
||||
}, {} as Record<string, ModelList[]>);
|
||||
|
||||
return (
|
||||
<div>
|
||||
<Flex justify="space-between" align="center">
|
||||
{t('modelNew.modelImplement')}
|
||||
|
||||
<Space>
|
||||
<Button type="primary" onClick={handleAdd} className="rb:px-2! rb:h-6!">+ {t('modelNew.addImplement')}</Button>
|
||||
<Button size="small" className="rb:px-2! rb:h-6!">{t('modelNew.noAuth')}</Button>
|
||||
</Space>
|
||||
</Flex>
|
||||
|
||||
|
||||
<div className="rb:bg-[#F5F6F7] rb:rounded-lg rb:p-3 rb:mt-2">
|
||||
{!value || value.length === 0
|
||||
? <Empty size={88} />
|
||||
: Object.entries(groupedByProvider).map(([provider, items]: [string, ModelList[]]) => {
|
||||
return (
|
||||
<div key={provider} className="rb:mb-4 last:rb:mb-0">
|
||||
<Flex justify="space-between" align="center" className="rb:mb-2 last:rb:mb-0">
|
||||
<div className="rb:font-medium">{[...new Set(items?.map((vo) => vo.model_name))].join(', ')}</div>
|
||||
<Space>
|
||||
<div
|
||||
className="rb:w-6 rb:h-6 rb:cursor-pointer rb:bg-[url('@/assets/images/editBorder.svg')] rb:hover:bg-[url('@/assets/images/editBg.svg')]"
|
||||
onClick={() => handleEdit(items, provider)}
|
||||
></div>
|
||||
<div
|
||||
className="rb:w-6 rb:h-6 rb:cursor-pointer rb:bg-[url('@/assets/images/deleteBorder.svg')] rb:hover:bg-[url('@/assets/images/deleteBg.svg')]"
|
||||
onClick={() => handleDelete(provider)}
|
||||
></div>
|
||||
</Space>
|
||||
</Flex>
|
||||
<Tag className="rb:mb-2">{provider}</Tag>
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
<SubModelModal
|
||||
ref={subModelModalRef}
|
||||
refresh={handleRefresh}
|
||||
type={type}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default ModelImplement
|
||||
@@ -0,0 +1,16 @@
|
||||
import type { ModelListItem } from '../../types'
|
||||
|
||||
export interface ModelList extends ModelListItem {
|
||||
api_key_id: string;
|
||||
}
|
||||
export interface SubModelModalForm {
|
||||
provider: string;
|
||||
api_key_ids: string[][];
|
||||
}
|
||||
export interface SubModelModalRef {
|
||||
handleOpen: (list?: ModelList[], provider?: string) => void;
|
||||
}
|
||||
export interface SubModelModalProps {
|
||||
type?: string;
|
||||
refresh?: (vo: ModelList[]) => void;
|
||||
}
|
||||
111
web/src/views/ModelManagement/components/ModelListDetail.tsx
Normal file
111
web/src/views/ModelManagement/components/ModelListDetail.tsx
Normal file
@@ -0,0 +1,111 @@
|
||||
import { useState, useImperativeHandle, forwardRef, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { Button, Switch, Row, Col, Space } from 'antd'
|
||||
|
||||
import type { ProviderModelItem, ModelListItem, ModelListDetailRef, MultiKeyConfigModalRef } from '../types';
|
||||
import RbDrawer from '@/components/RbDrawer';
|
||||
import RbCard from '@/components/RbCard/Card'
|
||||
import Tag from '@/components/Tag';
|
||||
import PageEmpty from '@/components/Empty/PageEmpty';
|
||||
import MultiKeyConfigModal from './MultiKeyConfigModal'
|
||||
import { getModelNewList, updateModelStatus } from '@/api/models'
|
||||
|
||||
interface ModelListDetailProps {
|
||||
refresh?: () => void;
|
||||
}
|
||||
|
||||
const ModelListDetail = forwardRef<ModelListDetailRef, ModelListDetailProps>(({ refresh }, ref) => {
|
||||
const { t } = useTranslation();
|
||||
const [open, setOpen] = useState(false);
|
||||
const [data, setData] = useState<ProviderModelItem>({} as ProviderModelItem)
|
||||
const [list, setList] = useState<ModelListItem[]>([])
|
||||
const multiKeyConfigModalRef = useRef<MultiKeyConfigModalRef>(null)
|
||||
const [loading, setLoading] = useState(false)
|
||||
|
||||
const handleOpen = (vo: ProviderModelItem) => {
|
||||
setOpen(true)
|
||||
getData(vo)
|
||||
}
|
||||
|
||||
const getData = (vo: ProviderModelItem) => {
|
||||
if (!vo.provider) return
|
||||
|
||||
getModelNewList({
|
||||
provider: vo.provider
|
||||
})
|
||||
.then(res => {
|
||||
const response = res as ProviderModelItem[]
|
||||
setData(response[0])
|
||||
setList(response[0].models)
|
||||
})
|
||||
}
|
||||
const handleKeyConfig = (vo: ModelListItem) => {
|
||||
multiKeyConfigModalRef.current?.handleOpen(vo, data.provider)
|
||||
}
|
||||
const handleChange = (vo: ModelListItem) => {
|
||||
setLoading(true)
|
||||
updateModelStatus(vo.id, { is_active: !vo.is_active })
|
||||
.finally(() => {
|
||||
getData(data)
|
||||
setLoading(false)
|
||||
})
|
||||
}
|
||||
|
||||
const handleClose = () => {
|
||||
setOpen(false)
|
||||
refresh?.()
|
||||
}
|
||||
const handleRefresh = () => {
|
||||
getData(data)
|
||||
}
|
||||
|
||||
useImperativeHandle(ref, () => ({
|
||||
handleOpen,
|
||||
}));
|
||||
|
||||
return (
|
||||
<RbDrawer
|
||||
title={<>{data.provider} {t('modelNew.modelList')} ({list.length}{t('modelNew.item')})</>}
|
||||
open={open}
|
||||
onClose={handleClose}
|
||||
>
|
||||
{list.length === 0
|
||||
? <PageEmpty />
|
||||
: <div className="rb:grid rb:grid-cols-2 rb:gap-4">
|
||||
{list.map(item => (
|
||||
<RbCard
|
||||
key={item.id}
|
||||
title={item.name}
|
||||
subTitle={<Space>
|
||||
<Tag>{t(`modelNew.${item.type}`)}</Tag>
|
||||
<Tag color="warning">{item.api_keys.length}{t('modelNew.apiKeyNum')}</Tag>
|
||||
</Space>}
|
||||
avatarUrl={item.logo}
|
||||
avatar={
|
||||
<div className="rb:w-12 rb:h-12 rb:rounded-lg rb:mr-3.25 rb:bg-[#155eef] rb:flex rb:items-center rb:justify-center rb:text-[28px] rb:text-[#ffffff]">
|
||||
{item.name[0]}
|
||||
</div>
|
||||
}
|
||||
extra={<Switch defaultChecked={item.is_active} disabled={loading} onChange={() => handleChange(item)} />}
|
||||
>
|
||||
|
||||
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.5 rb:mt-3">{item.description}</div>
|
||||
<Row gutter={12} className="rb:mt-4">
|
||||
<Col span={24}>
|
||||
<Button type="primary" ghost block onClick={() => handleKeyConfig(item)}>{t('modelNew.keyConfig')}</Button>
|
||||
</Col>
|
||||
</Row>
|
||||
</RbCard>
|
||||
))}
|
||||
</div>
|
||||
}
|
||||
|
||||
<MultiKeyConfigModal
|
||||
ref={multiKeyConfigModalRef}
|
||||
refresh={handleRefresh}
|
||||
/>
|
||||
</RbDrawer>
|
||||
);
|
||||
});
|
||||
|
||||
export default ModelListDetail;
|
||||
@@ -0,0 +1,95 @@
|
||||
import { useState, useImperativeHandle, forwardRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { Button, Space, App, Flex } from 'antd'
|
||||
import { UsergroupAddOutlined } from '@ant-design/icons';
|
||||
|
||||
import type { ModelPlaza, ModelPlazaItem, ModelSquareDetailRef } from '../types';
|
||||
import RbDrawer from '@/components/RbDrawer';
|
||||
import { getModelPlaza, addModelPlaza } from '@/api/models'
|
||||
import RbCard from '@/components/RbCard/Card'
|
||||
import Tag from '@/components/Tag';
|
||||
import PageEmpty from '@/components/Empty/PageEmpty';
|
||||
|
||||
interface ModelSquareDetailProps {
|
||||
refresh: () => void;
|
||||
handleEdit: (vo: ModelPlazaItem) => void;
|
||||
}
|
||||
const ModelSquareDetail = forwardRef<ModelSquareDetailRef, ModelSquareDetailProps>(({ refresh, handleEdit }, ref) => {
|
||||
const { t } = useTranslation();
|
||||
const { message } = App.useApp()
|
||||
const [model, setModel] = useState<ModelPlaza>({} as ModelPlaza)
|
||||
const [open, setOpen] = useState(false);
|
||||
|
||||
const [list, setList] = useState<ModelPlazaItem[]>([])
|
||||
|
||||
const handleOpen = (vo: ModelPlaza) => {
|
||||
setModel(vo)
|
||||
setOpen(true)
|
||||
getList(vo)
|
||||
}
|
||||
const handleClose = () => {
|
||||
setOpen(false)
|
||||
refresh()
|
||||
}
|
||||
const getList = (vo: ModelPlaza) => {
|
||||
getModelPlaza({ provider: vo.provider })
|
||||
.then(res => {
|
||||
const response = res as ModelPlaza[]
|
||||
setList(response.length > 0 ? response[0].models : [])
|
||||
})
|
||||
}
|
||||
const handleAdd = (item: ModelPlazaItem) => {
|
||||
addModelPlaza(item.id)
|
||||
.then(() => {
|
||||
message.success(`${item.name}${t('modelNew.addSuccess')}`)
|
||||
getList(model)
|
||||
})
|
||||
}
|
||||
|
||||
useImperativeHandle(ref, () => ({
|
||||
handleOpen,
|
||||
}));
|
||||
|
||||
return (
|
||||
<RbDrawer
|
||||
title={<>{model.provider} {t('modelNew.modelList')} ({list.length}{t('modelNew.item')})</>}
|
||||
open={open}
|
||||
onClose={handleClose}
|
||||
>
|
||||
{list.length === 0
|
||||
? <PageEmpty />
|
||||
: <div className="rb:grid rb:grid-cols-2 rb:gap-4">
|
||||
{list.map(item => (
|
||||
<RbCard
|
||||
key={item.id}
|
||||
title={item.name}
|
||||
avatarUrl={item.logo}
|
||||
avatar={
|
||||
<div className="rb:w-12 rb:h-12 rb:rounded-lg rb:mr-3.25 rb:bg-[#155eef] rb:flex rb:items-center rb:justify-center rb:text-[28px] rb:text-[#ffffff]">
|
||||
{item.name[0]}
|
||||
</div>
|
||||
}
|
||||
>
|
||||
<Tag>{t(`modelNew.${item.type}`)}</Tag>
|
||||
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.5 rb:mt-3 rb:h-9">{item.description}</div>
|
||||
<Space size={8} className="rb:mt-3">{item.tags.map((tag, tagIndex) => <Tag key={tagIndex}>{tag}</Tag>)}</Space>
|
||||
|
||||
<Flex justify="space-between">
|
||||
<Space size={8}><UsergroupAddOutlined /> {item.add_count}</Space>
|
||||
<Space>
|
||||
{!item.is_official && <Button type="primary" disabled={item.is_deprecated} onClick={() => handleEdit(item)}>{t('modelNew.edit')}</Button>}
|
||||
{item.is_added
|
||||
? <Button type="primary" disabled>{t('modelNew.added')}</Button>
|
||||
: <Button type="primary" ghost disabled={item.is_deprecated} onClick={() => handleAdd(item)}>+ {t('common.add')}</Button>
|
||||
}
|
||||
</Space>
|
||||
</Flex>
|
||||
</RbCard>
|
||||
))}
|
||||
</div>
|
||||
}
|
||||
</RbDrawer>
|
||||
);
|
||||
});
|
||||
|
||||
export default ModelSquareDetail;
|
||||
122
web/src/views/ModelManagement/components/MultiKeyConfigModal.tsx
Normal file
122
web/src/views/ModelManagement/components/MultiKeyConfigModal.tsx
Normal file
@@ -0,0 +1,122 @@
|
||||
import { forwardRef, useImperativeHandle, useState } from 'react';
|
||||
import { Form, Input, App, Button } from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { ModelListItem, MultiKeyForm, MultiKeyConfigModalRef, MultiKeyConfigModalProps } from '../types';
|
||||
import RbModal from '@/components/RbModal'
|
||||
import { addModelApiKey, deleteModelApiKey, getModelInfo } from '@/api/models'
|
||||
|
||||
const MultiKeyConfigModal = forwardRef<MultiKeyConfigModalRef, MultiKeyConfigModalProps>(({ refresh }, ref) => {
|
||||
const { t } = useTranslation();
|
||||
const { message } = App.useApp();
|
||||
const [visible, setVisible] = useState(false);
|
||||
const [model, setModel] = useState<ModelListItem>({} as ModelListItem);
|
||||
const [form] = Form.useForm<MultiKeyForm>();
|
||||
const [loading, setLoading] = useState(false)
|
||||
|
||||
const handleClose = () => {
|
||||
setModel({} as ModelListItem);
|
||||
refresh?.()
|
||||
|
||||
form.resetFields();
|
||||
setLoading(false)
|
||||
setVisible(false);
|
||||
};
|
||||
|
||||
const handleOpen = (vo: ModelListItem) => {
|
||||
setVisible(true);
|
||||
getData(vo)
|
||||
};
|
||||
|
||||
const getData = (vo: ModelListItem) => {
|
||||
if (!vo.id) return
|
||||
|
||||
getModelInfo(vo?.id)
|
||||
.then(res => {
|
||||
setModel(res as ModelListItem)
|
||||
})
|
||||
}
|
||||
const handleSave = () => {
|
||||
form
|
||||
.validateFields()
|
||||
.then((values) => {
|
||||
setLoading(true)
|
||||
addModelApiKey(model.id, {
|
||||
...values,
|
||||
model_config_id: model.id,
|
||||
model_name: model.name,
|
||||
provider: model.provider,
|
||||
}).then(() => {
|
||||
message.success(t('common.saveSuccess'))
|
||||
form.resetFields();
|
||||
getData(model)
|
||||
})
|
||||
.catch(() => {
|
||||
setLoading(false)
|
||||
});
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log('err', err)
|
||||
});
|
||||
}
|
||||
const handleDelete = (api_key_id: string) => {
|
||||
deleteModelApiKey(api_key_id)
|
||||
.then(() => {
|
||||
message.success(t('common.deleteSuccess'))
|
||||
getData(model)
|
||||
})
|
||||
}
|
||||
|
||||
useImperativeHandle(ref, () => ({
|
||||
handleOpen,
|
||||
}));
|
||||
|
||||
return (
|
||||
<RbModal
|
||||
title={`${model.name} - ${t('modelNew.keyConfig')}`}
|
||||
open={visible}
|
||||
onCancel={handleClose}
|
||||
footer={null}
|
||||
confirmLoading={loading}
|
||||
>
|
||||
{model.api_keys && model.api_keys.length > 0 && (
|
||||
<div className="rb:mb-4">
|
||||
{model.api_keys.map((key) => (
|
||||
<div key={key.id} className="rb:flex rb:items-center rb:justify-between rb:p-3 rb:bg-[#F5F6F7] rb:rounded-lg rb:mb-2">
|
||||
<div>
|
||||
<div className="rb:text-[#1D2129] rb:text-[14px] rb:font-medium">{key.api_key}</div>
|
||||
<div className="rb:text-[#5B6167] rb:text-[12px] rb:mt-1">{key.api_base}</div>
|
||||
</div>
|
||||
<Button type="primary" danger ghost onClick={() => handleDelete(key.id)}>{t('common.remove')}</Button>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
<Form
|
||||
form={form}
|
||||
layout="vertical"
|
||||
>
|
||||
<Form.Item
|
||||
name="api_key"
|
||||
label={t('modelNew.api_key')}
|
||||
rules={[{ required: true, message: t('common.inputPlaceholder', { title: t('modelNew.api_key') }) }]}
|
||||
>
|
||||
<Input.Password placeholder={t('common.pleaseEnter')} />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="api_base"
|
||||
label={t('modelNew.api_base')}
|
||||
rules={[{ required: true, message: t('common.inputPlaceholder', { title: t('modelNew.api_base') }) }]}
|
||||
>
|
||||
<Input placeholder="https://api.example.com/v1" />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item>
|
||||
<Button type="primary" block onClick={handleSave} loading={loading}>+ {t('modelNew.add')}</Button>
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</RbModal>
|
||||
);
|
||||
});
|
||||
|
||||
export default MultiKeyConfigModal;
|
||||
@@ -1,99 +1,123 @@
|
||||
import { useState, useRef, type FC } from 'react';
|
||||
import { Row, Col, Button } from 'antd'
|
||||
import { Button, Flex, Space, type SegmentedProps } from 'antd'
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import clsx from 'clsx';
|
||||
|
||||
import ConfigModal from './components/ConfigModal'
|
||||
import type { Model, DescriptionItem, ConfigModalRef } from './types'
|
||||
import RbCard from '@/components/RbCard/Card'
|
||||
import GroupModelModal from './components/GroupModelModal'
|
||||
import type { ModelListItem, GroupModelModalRef, CustomModelModalRef, ModelPlazaItem, BaseRef } from './types'
|
||||
import SearchInput from '@/components/SearchInput'
|
||||
import PageScrollList, { type PageScrollListRef } from '@/components/PageScrollList'
|
||||
import { getModelListUrl } from '@/api/models'
|
||||
import { formatDateTime } from '@/utils/format';
|
||||
import PageTabs from '@/components/PageTabs'
|
||||
import GroupModel from './Group'
|
||||
import ModelList from './List'
|
||||
import ModelSquare from './Square'
|
||||
import CustomModelModal from './components/CustomModelModal'
|
||||
import CustomSelect from '@/components/CustomSelect'
|
||||
import { modelTypeUrl, modelProviderUrl } from '@/api/models'
|
||||
|
||||
const tabKeys = ['group', 'list', 'square']
|
||||
const ModelManagement: FC = () => {
|
||||
const { t } = useTranslation();
|
||||
const [activeTab, setActiveTab] = useState('group');
|
||||
const [query, setQuery] = useState({})
|
||||
const configModalRef = useRef<ConfigModalRef>(null)
|
||||
const scrollListRef = useRef<PageScrollListRef>(null)
|
||||
const configModalRef = useRef<GroupModelModalRef>(null)
|
||||
const customModelModalRef = useRef<CustomModelModalRef>(null)
|
||||
const groupRef = useRef<BaseRef>(null)
|
||||
const squareRef = useRef<BaseRef>(null)
|
||||
|
||||
const formatData = (data: Model) => {
|
||||
return [
|
||||
{
|
||||
key: 'type',
|
||||
label: t(`model.type`),
|
||||
children: data.type || '-',
|
||||
},
|
||||
{
|
||||
key: 'provider',
|
||||
label: t(`model.provider`),
|
||||
children: data.api_keys[0].provider || '-',
|
||||
},
|
||||
{
|
||||
key: 'is_active',
|
||||
label: t(`model.status`),
|
||||
children: data.is_active ? t(`common.statusEnabled`) : t(`common.statusDisabled`),
|
||||
},
|
||||
{
|
||||
key: 'created',
|
||||
label: t(`model.created`),
|
||||
children: data.created_at ? formatDateTime(data.created_at, 'YYYY-MM-DD HH:mm:ss') : '-',
|
||||
},
|
||||
]
|
||||
const formatTabItems = () => {
|
||||
return tabKeys.map(value => ({
|
||||
value,
|
||||
label: t(`modelNew.${value}`),
|
||||
}))
|
||||
}
|
||||
const handleChangeTab = (value: SegmentedProps['value']) => {
|
||||
setActiveTab(value as string);
|
||||
setQuery({})
|
||||
}
|
||||
|
||||
const handleEdit = (model?: Model) => {
|
||||
configModalRef?.current?.handleOpen(model)
|
||||
const handleEdit = (vo?: ModelListItem | ModelPlazaItem) => {
|
||||
switch(activeTab) {
|
||||
case 'group':
|
||||
configModalRef?.current?.handleOpen(vo as ModelListItem)
|
||||
break
|
||||
case 'square':
|
||||
customModelModalRef?.current?.handleOpen(vo as ModelPlazaItem)
|
||||
break
|
||||
}
|
||||
}
|
||||
const handleRefresh = () => {
|
||||
switch (activeTab) {
|
||||
case 'group':
|
||||
groupRef.current?.getList()
|
||||
break
|
||||
case 'square':
|
||||
squareRef.current?.getList()
|
||||
break
|
||||
}
|
||||
}
|
||||
const handleSearch = (value?: string) => {
|
||||
setQuery({ search: value })
|
||||
}
|
||||
const handleTypeChange = (value: string) => {
|
||||
setQuery(pre => ({ ...pre, type: value }))
|
||||
}
|
||||
const handleProviderChange = (value: string) => {
|
||||
setQuery(pre => ({ ...pre, provider: value }))
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rb:w-full">
|
||||
<Row className='rb:mb-[16px] rb:w-full'>
|
||||
<Col span={6}>
|
||||
<SearchInput
|
||||
placeholder={t('model.searchPlaceholder')}
|
||||
<>
|
||||
<Flex justify="space-between" align="center">
|
||||
<PageTabs
|
||||
value={activeTab}
|
||||
options={formatTabItems()}
|
||||
onChange={handleChangeTab}
|
||||
/>
|
||||
|
||||
<Space size={12}>
|
||||
{activeTab === 'list' ? <>
|
||||
<CustomSelect
|
||||
url={modelTypeUrl}
|
||||
hasAll={false}
|
||||
format={(items) => items.map((item) => ({ label: t(`modelNew.${item}`), value: String(item) }))}
|
||||
onChange={handleTypeChange}
|
||||
className="rb:w-30"
|
||||
allowClear={true}
|
||||
placeholder={t('modelNew.type')}
|
||||
/>
|
||||
<CustomSelect
|
||||
url={modelProviderUrl}
|
||||
hasAll={false}
|
||||
format={(items) => items.map((item) => ({ label: t(`modelNew.${item}`), value: String(item) }))}
|
||||
onChange={handleProviderChange}
|
||||
className="rb:w-30"
|
||||
allowClear={true}
|
||||
placeholder={t('modelNew.provider')}
|
||||
/>
|
||||
</>
|
||||
: <SearchInput
|
||||
placeholder={t(`modelNew.${activeTab}SearchPlaceholder`)}
|
||||
onSearch={handleSearch}
|
||||
style={{width: '100%'}}
|
||||
/>
|
||||
</Col>
|
||||
<Col span={18} className="rb:text-right">
|
||||
<Button type="primary" onClick={() => handleEdit()}>{t('model.createModel')}</Button>
|
||||
</Col>
|
||||
</Row>
|
||||
className="rb:w-70!"
|
||||
/>}
|
||||
{activeTab === 'group' && <Button type="primary" onClick={() => handleEdit()}>+ {t('modelNew.createGroupModel')}</Button>}
|
||||
{activeTab === 'square' && <Button type="primary" onClick={() => handleEdit()}>+ {t('modelNew.createCustomModel')}</Button>}
|
||||
</Space>
|
||||
</Flex>
|
||||
|
||||
<PageScrollList
|
||||
ref={scrollListRef}
|
||||
url={getModelListUrl}
|
||||
query={query}
|
||||
renderItem={(item: Model) => (
|
||||
<RbCard
|
||||
title={item.name}
|
||||
>
|
||||
{formatData(item)?.map((description: DescriptionItem) => (
|
||||
<div
|
||||
key={description.key}
|
||||
className="rb:flex rb:justify-between rb:text-[#5B6167] rb:text-[14px] rb:leading-[20px] rb:mb-[12px]"
|
||||
>
|
||||
<span className="rb:whitespace-nowrap">{(description.label as string)}</span>
|
||||
<span className={clsx({
|
||||
"rb:text-[#212332]": description.key !== 'is_active',
|
||||
"rb:text-[#369F21] rb:font-medium": description.key === 'is_active' && item.is_active,
|
||||
})}>{(description.children as string)}</span>
|
||||
</div>
|
||||
))}
|
||||
<Button className="rb:mt-[8px]" type="primary" ghost block onClick={() => handleEdit(item)}>{t('model.configureBtn')}</Button>
|
||||
</RbCard>
|
||||
)}
|
||||
/>
|
||||
|
||||
<ConfigModal
|
||||
<div className="rb:w-full rb:h-[calc(100%-48px)] rb:my-4">
|
||||
{activeTab === 'group' && <GroupModel ref={groupRef} query={query} handleEdit={handleEdit} />}
|
||||
{activeTab === 'list' && <ModelList query={query} />}
|
||||
{activeTab === 'square' && <ModelSquare ref={squareRef} query={query} handleEdit={handleEdit} />}
|
||||
</div>
|
||||
<GroupModelModal
|
||||
ref={configModalRef}
|
||||
refresh={() => scrollListRef?.current?.refresh()}
|
||||
refresh={handleRefresh}
|
||||
/>
|
||||
</div>
|
||||
<CustomModelModal
|
||||
ref={customModelModalRef}
|
||||
refresh={handleRefresh}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,70 +1,139 @@
|
||||
// 模型表单数据类型
|
||||
export interface ModelFormData extends ApiKey {
|
||||
name: string;
|
||||
type: string;
|
||||
api_keys: ApiKey;
|
||||
}
|
||||
export interface Query {
|
||||
type?: string;
|
||||
provider?: string;
|
||||
is_active?: boolean;
|
||||
is_public?: boolean;
|
||||
is_composite?: boolean;
|
||||
search?: string;
|
||||
|
||||
pagesize?: number;
|
||||
page?: number;
|
||||
}
|
||||
export interface DescriptionItem {
|
||||
key: string;
|
||||
label: string;
|
||||
children: string;
|
||||
}
|
||||
export interface CompositeModelForm {
|
||||
logo?: any;
|
||||
name: string;
|
||||
type: string;
|
||||
description: string;
|
||||
api_key_ids: ModelApiKey[] | string[];
|
||||
}
|
||||
export interface GroupModelModalRef {
|
||||
handleOpen: (model?: ModelListItem) => void;
|
||||
}
|
||||
export interface GroupModelModalProps {
|
||||
refresh?: () => void;
|
||||
}
|
||||
export interface ModelListDetailRef {
|
||||
handleOpen: (vo: ProviderModelItem) => void;
|
||||
}
|
||||
|
||||
// 模型类型定义
|
||||
export interface Model {
|
||||
|
||||
export interface ModelApiKey {
|
||||
model_name: string;
|
||||
description: string | null;
|
||||
provider: string;
|
||||
api_key: string;
|
||||
api_base: string;
|
||||
config: any;
|
||||
is_active: boolean;
|
||||
priority: string;
|
||||
id: string;
|
||||
usage_count: string;
|
||||
last_used_at: number;
|
||||
created_at: number;
|
||||
updated_at: number;
|
||||
model_config_ids: string[];
|
||||
}
|
||||
export interface ModelListItem {
|
||||
model_name?: string;
|
||||
model_config_ids: string[];
|
||||
name: string;
|
||||
type: string;
|
||||
logo: string;
|
||||
description: string;
|
||||
provider: string;
|
||||
config: any;
|
||||
is_active: boolean;
|
||||
is_public: boolean;
|
||||
id: string;
|
||||
created_at: number;
|
||||
updated_at: number;
|
||||
api_keys: ModelApiKey[]
|
||||
}
|
||||
export interface ProviderModelItem {
|
||||
provider: string;
|
||||
logo?: string;
|
||||
tags: string[];
|
||||
models: ModelListItem[];
|
||||
}
|
||||
export interface KeyConfigModalForm {
|
||||
provider: string;
|
||||
api_key: string;
|
||||
api_base: string;
|
||||
}
|
||||
export interface KeyConfigModalRef {
|
||||
handleOpen: (vo: ProviderModelItem) => void;
|
||||
}
|
||||
export interface KeyConfigModalProps {
|
||||
refresh?: () => void;
|
||||
}
|
||||
export interface MultiKeyForm {
|
||||
model_config_id?: string;
|
||||
model_name: string;
|
||||
provider: string;
|
||||
api_key: string;
|
||||
api_base: string;
|
||||
}
|
||||
|
||||
export interface MultiKeyConfigModalRef {
|
||||
handleOpen: (vo: ModelListItem, provider?: string) => void;
|
||||
}
|
||||
export interface MultiKeyConfigModalProps {
|
||||
refresh?: () => void;
|
||||
}
|
||||
|
||||
|
||||
export interface ModelPlaza {
|
||||
provider: string;
|
||||
models: ModelPlazaItem[];
|
||||
}
|
||||
export interface ModelPlazaItem {
|
||||
id: string;
|
||||
name: string;
|
||||
type: string;
|
||||
description?: string;
|
||||
config: Record<string, unknown>;
|
||||
is_active: boolean;
|
||||
is_public: boolean;
|
||||
created_at: string | number;
|
||||
updated_at: string | number;
|
||||
api_keys: ApiKey[];
|
||||
|
||||
// provider: string;
|
||||
// temperature: number,
|
||||
// topP: number,
|
||||
// status: string;
|
||||
// vectorDimension: number;
|
||||
// batchSize: number;
|
||||
// truncateStrategy: string;
|
||||
// created: string;
|
||||
// updatedAt: string;
|
||||
// descriptionItems?: Record<string, unknown>[];
|
||||
// basicParameters?: string;
|
||||
// normalization?: string;
|
||||
// maxInputLength?: number;
|
||||
// encodingFormat?: string;
|
||||
// enablePooling?: boolean;
|
||||
// poolingStrategy?: string;
|
||||
// apiKey?: string;
|
||||
// apiEndpoint?: string;
|
||||
// timeout?: number;
|
||||
// autoRetry?: boolean;
|
||||
// retryCount?: number;
|
||||
}
|
||||
interface ApiKey {
|
||||
model_name?: string;
|
||||
provider: string;
|
||||
api_key?: string;
|
||||
api_base?: string;
|
||||
config?: Record<string, unknown>;
|
||||
is_active?: boolean;
|
||||
priority?: string;
|
||||
id: string;
|
||||
model_config_id?: string;
|
||||
usage_count?: string;
|
||||
last_used_at?: string | null;
|
||||
created_at?: string;
|
||||
updated_at?: string;
|
||||
logo: string;
|
||||
description: string;
|
||||
is_deprecated: boolean;
|
||||
is_official: boolean;
|
||||
tags: string[];
|
||||
add_count: number;
|
||||
is_added: boolean;
|
||||
}
|
||||
// 定义组件暴露的方法接口
|
||||
export interface ConfigModalRef {
|
||||
handleOpen: (model?: Model) => void;
|
||||
export interface ModelSquareDetailRef {
|
||||
handleOpen: (vo: ModelPlaza) => void;
|
||||
}
|
||||
export interface ConfigModalProps {
|
||||
export interface CustomModelForm {
|
||||
name: string;
|
||||
type: string;
|
||||
provider: string;
|
||||
logo?: any;
|
||||
description: string;
|
||||
is_official: boolean;
|
||||
tags: string[];
|
||||
}
|
||||
export interface CustomModelModalRef {
|
||||
handleOpen: (vo?: ModelPlazaItem) => void;
|
||||
}
|
||||
export interface CustomModelModalProps {
|
||||
refresh?: () => void;
|
||||
}
|
||||
|
||||
|
||||
export interface BaseRef {
|
||||
getList: () => void;
|
||||
}
|
||||
@@ -24,7 +24,7 @@ const configList = [
|
||||
key: 'reflection_model_id',
|
||||
type: 'customSelect',
|
||||
url: getModelListUrl,
|
||||
params: { type: 'chat,llm', page: 1, pagesize: 100 }, // chat,llm
|
||||
params: { type: 'chat,llm', page: 1, pagesize: 100, is_active: true }, // chat,llm
|
||||
},
|
||||
// 迭代周期
|
||||
{
|
||||
|
||||
@@ -66,7 +66,7 @@ const SpaceConfig: FC = () => {
|
||||
>
|
||||
<CustomSelect
|
||||
url={getModelListUrl}
|
||||
params={{ type: 'llm', pagesize: 100 }}
|
||||
params={{ type: 'llm', pagesize: 100, is_active: true }}
|
||||
valueKey="id"
|
||||
labelKey="name"
|
||||
hasAll={false}
|
||||
@@ -80,7 +80,7 @@ const SpaceConfig: FC = () => {
|
||||
>
|
||||
<CustomSelect
|
||||
url={getModelListUrl}
|
||||
params={{ type: 'embedding', pagesize: 100 }}
|
||||
params={{ type: 'embedding', pagesize: 100, is_active: true }}
|
||||
valueKey="id"
|
||||
labelKey="name"
|
||||
hasAll={false}
|
||||
@@ -94,7 +94,7 @@ const SpaceConfig: FC = () => {
|
||||
>
|
||||
<CustomSelect
|
||||
url={getModelListUrl}
|
||||
params={{ type: 'rerank', pagesize: 100 }}
|
||||
params={{ type: 'rerank', pagesize: 100, is_active: true }}
|
||||
valueKey="id"
|
||||
labelKey="name"
|
||||
hasAll={false}
|
||||
|
||||
@@ -8,7 +8,7 @@ import { createWorkspace } from '@/api/workspaces'
|
||||
import RadioGroupCard from '@/components/RadioGroupCard'
|
||||
import { getModelListUrl, getModelList } from '@/api/models'
|
||||
import CustomSelect from '@/components/CustomSelect'
|
||||
import type { Model } from '@/views/ModelManagement/types'
|
||||
import type { ModelListItem } from '@/views/ModelManagement/types'
|
||||
|
||||
const FormItem = Form.Item;
|
||||
|
||||
@@ -29,7 +29,7 @@ const SpaceModal = forwardRef<SpaceModalRef, SpaceModalProps>(({
|
||||
const [form] = Form.useForm<SpaceModalData>();
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [editVo, setEditVo] = useState<Space | null>(null)
|
||||
const [modelList, setModelList] = useState<Model[]>([])
|
||||
const [modelList, setModelList] = useState<ModelListItem[]>([])
|
||||
|
||||
const values = Form.useWatch([], form);
|
||||
|
||||
@@ -80,9 +80,9 @@ const SpaceModal = forwardRef<SpaceModalRef, SpaceModalProps>(({
|
||||
}, [])
|
||||
|
||||
const getModels = () => {
|
||||
getModelList({ type: 'llm,chat', pagesize: 100, page: 1 })
|
||||
getModelList({ type: 'llm,chat', pagesize: 100, page: 1, is_active: true })
|
||||
.then(res => {
|
||||
const response = res as { items: Model[] }
|
||||
const response = res as { items: ModelListItem[] }
|
||||
setModelList(response.items)
|
||||
})
|
||||
}
|
||||
@@ -134,7 +134,7 @@ const SpaceModal = forwardRef<SpaceModalRef, SpaceModalProps>(({
|
||||
>
|
||||
<CustomSelect
|
||||
url={getModelListUrl}
|
||||
params={{ type: 'embedding', pagesize: 100 }}
|
||||
params={{ type: 'embedding', pagesize: 100, is_active: true }}
|
||||
valueKey="id"
|
||||
labelKey="name"
|
||||
hasAll={false}
|
||||
@@ -148,7 +148,7 @@ const SpaceModal = forwardRef<SpaceModalRef, SpaceModalProps>(({
|
||||
>
|
||||
<CustomSelect
|
||||
url={getModelListUrl}
|
||||
params={{ type: 'rerank', pagesize: 100 }}
|
||||
params={{ type: 'rerank', pagesize: 100, is_active: true }}
|
||||
valueKey="id"
|
||||
labelKey="name"
|
||||
hasAll={false}
|
||||
|
||||
@@ -98,7 +98,7 @@ const KnowledgeGlobalConfigModal = forwardRef<KnowledgeGlobalConfigModalRef, Kno
|
||||
>
|
||||
<CustomSelect
|
||||
url={getModelListUrl}
|
||||
params={{ type: 'rerank', pagesize: 100 }}
|
||||
params={{ type: 'rerank', pagesize: 100, is_active: true }}
|
||||
valueKey="id"
|
||||
labelKey="name"
|
||||
hasAll={false}
|
||||
|
||||
@@ -105,7 +105,7 @@ export const nodeLibrary: NodeLibrary[] = [
|
||||
model_id: {
|
||||
type: 'customSelect',
|
||||
url: getModelListUrl,
|
||||
params: { type: 'llm,chat' }, // llm/chat
|
||||
params: { type: 'llm,chat', is_active: true }, // llm/chat
|
||||
valueKey: 'id',
|
||||
labelKey: 'name',
|
||||
},
|
||||
@@ -166,7 +166,7 @@ export const nodeLibrary: NodeLibrary[] = [
|
||||
model_id: {
|
||||
type: 'customSelect',
|
||||
url: getModelListUrl,
|
||||
params: { type: 'llm,chat' }, // llm/chat
|
||||
params: { type: 'llm,chat', is_active: true }, // llm/chat
|
||||
valueKey: 'id',
|
||||
labelKey: 'name',
|
||||
},
|
||||
@@ -200,7 +200,7 @@ export const nodeLibrary: NodeLibrary[] = [
|
||||
config_id: {
|
||||
type: 'customSelect',
|
||||
url: memoryConfigListUrl,
|
||||
valueKey: ['config_id_old', 'config_id'],
|
||||
valueKey: 'config_id',
|
||||
labelKey: 'config_name'
|
||||
},
|
||||
search_switch: {
|
||||
@@ -223,7 +223,7 @@ export const nodeLibrary: NodeLibrary[] = [
|
||||
config_id: {
|
||||
type: 'customSelect',
|
||||
url: memoryConfigListUrl,
|
||||
valueKey: ['config_id_old', 'config_id'],
|
||||
valueKey: 'config_id',
|
||||
labelKey: 'config_name'
|
||||
}
|
||||
}
|
||||
@@ -259,7 +259,7 @@ export const nodeLibrary: NodeLibrary[] = [
|
||||
model_id: {
|
||||
type: 'customSelect',
|
||||
url: getModelListUrl,
|
||||
params: { type: 'llm,chat' }, // llm/chat
|
||||
params: { type: 'llm,chat', is_active: true }, // llm/chat
|
||||
valueKey: 'id',
|
||||
labelKey: 'name',
|
||||
},
|
||||
|
||||
@@ -14,7 +14,7 @@ export interface NodeConfig {
|
||||
|
||||
url?: string;
|
||||
params?: { [key: string]: unknown; }
|
||||
valueKey?: string | string[];
|
||||
valueKey?: string;
|
||||
labelKey?: string;
|
||||
|
||||
defaultValue?: any;
|
||||
|
||||
Reference in New Issue
Block a user