Merge branch 'release/v0.3.1' of github.com:SuanmoSuanyangTechnology/MemoryBear into release/v0.3.1

* 'release/v0.3.1' of github.com:SuanmoSuanyangTechnology/MemoryBear:
  fix(web): stream add default error message
  fix(quota): restrict quota check to new terminal user creation only
  fix(api): fix API Key rate limiting and terminal user quota checks
  feat(exception): enhance I18nException response format and add error code mapping
  feat(quota): add quota checks during app duplication and import operations
  fix(知识服务): 添加工作空间模型配置的校验
  refactor(knowledge_service): 简化模型绑定逻辑,直接使用工作区配置
  fix(知识服务): 修复创建知识库时未检查视觉模型存在的错误
  refactor(knowledge_service): 优化模型绑定逻辑,使用ID查询并简化回退机制
This commit is contained in:
Mark
2026-04-22 11:47:47 +08:00
10 changed files with 129 additions and 56 deletions

View File

@@ -167,6 +167,8 @@ def update_api_key(
return success(data=api_key_schema.ApiKey.model_validate(api_key), msg="API Key 更新成功")
except BusinessException:
raise
except Exception as e:
logger.error(f"未知错误: {str(e)}", extra={
"api_key_id": str(api_key_id),

View File

@@ -219,6 +219,7 @@ def delete_app(
@router.post("/{app_id}/copy", summary="复制应用")
@cur_workspace_access_guard()
@check_app_quota
def copy_app(
app_id: uuid.UUID,
new_name: Optional[str] = None,
@@ -1144,6 +1145,7 @@ async def import_workflow_config(
@router.post("/workflow/import/save")
@cur_workspace_access_guard()
@check_app_quota
async def save_workflow_import(
data: WorkflowImportSave,
db: Session = Depends(get_db),
@@ -1281,6 +1283,10 @@ async def import_app(
return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST)
target_app_id = uuid.UUID(app_id) if app_id else None
# 仅新建应用时检查配额,覆盖已有应用时跳过
if target_app_id is None:
from app.core.quota_manager import _check_quota
_check_quota(db, current_user.tenant_id, "app_quota", "app")
result_app, warnings = AppDslService(db).import_dsl(
dsl=dsl,
workspace_id=current_user.current_workspace_id,

View File

@@ -219,9 +219,20 @@ def list_conversations(
end_user_repo = EndUserRepository(db)
app_service = AppService(db)
app = app_service._get_app_or_404(share.app_id)
workspace_id = app.workspace_id
# 仅在新建终端用户时检查配额
existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id)
if existing_end_user is None:
from app.core.quota_manager import _check_quota
from app.models.workspace_model import Workspace
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
if ws:
_check_quota(db, ws.tenant_id, "end_user_quota", "end_user")
new_end_user = end_user_repo.get_or_create_end_user(
app_id=share.app_id,
workspace_id=app.workspace_id,
workspace_id=workspace_id,
other_id=other_id
)
logger.debug(new_end_user.id)
@@ -309,7 +320,6 @@ def get_conversation(
"/chat",
summary="发送消息(支持流式和非流式)"
)
@check_end_user_quota
async def chat(
payload: conversation_schema.ChatRequest,
share_data: ShareTokenData = Depends(get_share_user_id),
@@ -350,6 +360,18 @@ async def chat(
app_service = AppService(db)
app = app_service._get_app_or_404(share.app_id)
workspace_id = app.workspace_id
# 仅在新建终端用户时检查配额,已有用户复用不受限制
existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id)
logger.info(f"终端用户配额检查: workspace_id={workspace_id}, other_id={other_id}, existing={existing_end_user is not None}")
if existing_end_user is None:
from app.core.quota_manager import _check_quota
from app.models.workspace_model import Workspace
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
if ws:
logger.info(f"新终端用户,执行配额检查: tenant_id={ws.tenant_id}")
_check_quota(db, ws.tenant_id, "end_user_quota", "end_user")
new_end_user = end_user_repo.get_or_create_end_user(
app_id=share.app_id,
workspace_id=workspace_id,

View File

@@ -106,6 +106,16 @@ async def chat(
other_id = payload.user_id
workspace_id = api_key_auth.workspace_id
end_user_repo = EndUserRepository(db)
# 仅在新建终端用户时检查配额,已有用户复用不受限制
existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id)
if existing_end_user is None:
from app.core.quota_manager import _check_quota
from app.models.workspace_model import Workspace
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
if ws:
_check_quota(db, ws.tenant_id, "end_user_quota", "end_user")
new_end_user = end_user_repo.get_or_create_end_user(
app_id=app.id,
workspace_id=workspace_id,

View File

@@ -32,6 +32,8 @@ class BizCode(IntEnum):
API_KEY_DAILY_LIMIT_EXCEEDED = 3015
API_KEY_QUOTA_EXCEEDED = 3016
API_KEY_RATE_LIMIT_EXCEEDED = 3017
QUOTA_EXCEEDED = 3018
RATE_LIMIT_EXCEEDED = 3019
# 资源4xxx
NOT_FOUND = 4000
USER_NOT_FOUND = 4001
@@ -156,7 +158,8 @@ HTTP_MAPPING = {
BizCode.API_KEY_QPS_LIMIT_EXCEEDED: 429,
BizCode.API_KEY_DAILY_LIMIT_EXCEEDED: 429,
BizCode.API_KEY_QUOTA_EXCEEDED: 429,
BizCode.QUOTA_EXCEEDED: 402,
BizCode.MODEL_CONFIG_INVALID: 400,
BizCode.API_KEY_MISSING: 400,
BizCode.PROVIDER_NOT_SUPPORTED: 400,
@@ -185,4 +188,21 @@ HTTP_MAPPING = {
BizCode.DB_ERROR: 500,
BizCode.SERVICE_UNAVAILABLE: 503,
BizCode.RATE_LIMITED: 429,
BizCode.RATE_LIMIT_EXCEEDED: 429,
}
ERROR_CODE_TO_BIZ_CODE = {
"QUOTA_EXCEEDED": BizCode.QUOTA_EXCEEDED,
"RATE_LIMIT_EXCEEDED": BizCode.RATE_LIMIT_EXCEEDED,
"API_KEY_NOT_FOUND": BizCode.API_KEY_NOT_FOUND,
"API_KEY_INVALID": BizCode.API_KEY_INVALID,
"API_KEY_EXPIRED": BizCode.API_KEY_EXPIRED,
"WORKSPACE_NOT_FOUND": BizCode.WORKSPACE_NOT_FOUND,
"WORKSPACE_NO_ACCESS": BizCode.WORKSPACE_NO_ACCESS,
"PERMISSION_DENIED": BizCode.PERMISSION_DENIED,
"TOKEN_EXPIRED": BizCode.TOKEN_EXPIRED,
"TOKEN_INVALID": BizCode.TOKEN_INVALID,
"VALIDATION_FAILED": BizCode.VALIDATION_FAILED,
"INVALID_PARAMETER": BizCode.INVALID_PARAMETER,
"MISSING_PARAMETER": BizCode.MISSING_PARAMETER,
}

View File

@@ -6,12 +6,14 @@ error messages based on the current request's language.
"""
import logging
import time
from contextvars import ContextVar
from typing import Any, Dict, Optional
from fastapi import HTTPException, Request
from app.i18n.service import get_translation_service
from app.core.error_codes import ERROR_CODE_TO_BIZ_CODE, BizCode
logger = logging.getLogger(__name__)
@@ -118,15 +120,24 @@ class I18nException(HTTPException):
**params
)
# Build error detail
detail = {
"error_code": self.error_code,
"message": message,
}
# Convert error_code string to BizCode value
biz_code = ERROR_CODE_TO_BIZ_CODE.get(
self.error_code,
BizCode.BAD_REQUEST
)
# Add parameters to detail if provided
if params:
detail["params"] = params
# Build error detail in standard format for compatibility
# main.py handler expects "message" and "error_code" fields for filtering
# but we also include standard format fields
detail = {
"code": biz_code.value,
"msg": message,
"message": message,
"error_code": self.error_code,
"data": params if params else {},
"error": message,
"time": int(time.time() * 1000),
}
# Initialize HTTPException
super().__init__(

View File

@@ -66,6 +66,17 @@ class EndUserRepository:
db_logger.error(f"查询宿主 {end_user_id} 时出错: {str(e)}")
raise
def get_end_user_by_other_id(self, workspace_id: uuid.UUID, other_id: str) -> Optional["EndUser"]:
"""按 workspace_id + other_id 查找终端用户,不存在返回 None"""
return (
self.db.query(EndUser)
.filter(
EndUser.workspace_id == workspace_id,
EndUser.other_id == other_id
)
.first()
)
def get_or_create_end_user(
self,
app_id: uuid.UUID,

View File

@@ -51,7 +51,7 @@ class ApiKeyService:
if existing:
raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME)
# 若 rate_limit 超过租户套餐的 api_ops_rate_limit自动截断到套餐上限
# 若 rate_limit 超过租户套餐的 api_ops_rate_limit直接报错
from app.models.workspace_model import Workspace
from app.core.quota_manager import get_api_ops_rate_limit
@@ -59,7 +59,10 @@ class ApiKeyService:
if workspace:
tenant_api_ops_limit = get_api_ops_rate_limit(db, workspace.tenant_id)
if tenant_api_ops_limit and data.rate_limit > tenant_api_ops_limit:
data.rate_limit = tenant_api_ops_limit
raise BusinessException(
f"API Key QPS 不能超过套餐上限 {tenant_api_ops_limit}",
BizCode.BAD_REQUEST
)
# 生成 API Key
api_key = generate_api_key(data.type)
@@ -162,7 +165,7 @@ class ApiKeyService:
if existing:
raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME)
# 若 rate_limit 超过租户套餐的 api_ops_rate_limit自动截断到套餐上限
# 若 rate_limit 超过租户套餐的 api_ops_rate_limit直接报错
if data.rate_limit is not None:
from app.models.workspace_model import Workspace
from app.core.quota_manager import get_api_ops_rate_limit
@@ -171,7 +174,10 @@ class ApiKeyService:
if workspace:
tenant_api_ops_limit = get_api_ops_rate_limit(db, workspace.tenant_id)
if tenant_api_ops_limit and data.rate_limit > tenant_api_ops_limit:
data.rate_limit = tenant_api_ops_limit
raise BusinessException(
f"API Key QPS 不能超过套餐上限 {tenant_api_ops_limit}",
BizCode.BAD_REQUEST
)
update_data = data.model_dump(exclude_unset=True)
ApiKeyRepository.update(db, api_key_id, update_data)

View File

@@ -7,7 +7,6 @@ from app.models.models_model import ModelConfig
from app.schemas.knowledge_schema import KnowledgeCreate, KnowledgeUpdate
from app.repositories import knowledge_repository
from app.core.logging_config import get_business_logger
from app.repositories.model_repository import ModelConfigRepository
from app.models.models_model import ModelType
business_logger = get_business_logger()
@@ -77,53 +76,32 @@ def create_knowledge(
tenant_id = workspace.tenant_id
def _get_model_by_name_or_fallback(model_name: str | None, model_types: list, label: str):
"""优先按 workspace 指定的 model name 查,找不到再 fallback 到 tenant 下第一个"""
if model_name:
model = db.query(ModelConfig).filter(
ModelConfig.tenant_id == tenant_id,
ModelConfig.name == model_name,
ModelConfig.type.in_([t.value for t in model_types]),
ModelConfig.is_active == True,
ModelConfig.is_composite == False
).first()
if model:
business_logger.debug(f"Auto-bind {label} model from workspace default: {model.id} ({model_name})")
return model
business_logger.debug(f"Workspace default {label} model '{model_name}' not found, falling back to tenant")
models = ModelConfigRepository.get_by_type(db=db, model_types=model_types, tenant_id=tenant_id, is_active=True)
if models:
business_logger.debug(f"Auto-bind {label} model from tenant fallback: {models[0].id}")
return models[0]
return None
if not knowledge.embedding_id:
model = _get_model_by_name_or_fallback(workspace.embedding, [ModelType.EMBEDDING], "embedding")
if model:
knowledge.embedding_id = model.id
if not workspace.embedding:
raise Exception("工作空间未配置 Embedding 模型,请先完善工作空间配置后重试")
knowledge.embedding_id = workspace.embedding
if not knowledge.reranker_id:
model = _get_model_by_name_or_fallback(workspace.rerank, [ModelType.RERANK], "rerank")
if model:
knowledge.reranker_id = model.id
if not workspace.rerank:
raise Exception("工作空间未配置 Rerank 模型,请先完善工作空间配置后重试")
knowledge.reranker_id = workspace.rerank
if not knowledge.llm_id:
model = _get_model_by_name_or_fallback(workspace.llm, [ModelType.LLM, ModelType.CHAT], "llm")
if model:
knowledge.llm_id = model.id
if not workspace.llm:
raise Exception("工作空间未配置 LLM 模型,请先完善工作空间配置后重试")
knowledge.llm_id = workspace.llm
if not knowledge.image2text_id:
image2text_models = db.query(ModelConfig).filter(
model = db.query(ModelConfig).filter(
ModelConfig.tenant_id == tenant_id,
ModelConfig.type.in_([ModelType.CHAT.value]),
ModelConfig.type.in_([ModelType.CHAT.value, ModelType.LLM.value]),
ModelConfig.capability.contains(["vision"]),
ModelConfig.is_active == True,
ModelConfig.is_composite == False
).order_by(ModelConfig.created_at.desc()).all()
if not image2text_models:
).order_by(ModelConfig.created_at.desc()).first()
if not model:
raise Exception("租户下没有可用的视觉模型,创建知识库失败")
knowledge.image2text_id = image2text_models[0].id
business_logger.debug(f"Auto-bind image2text model: {image2text_models[0].id}")
knowledge.image2text_id = model.id
business_logger.debug(f"Auto-bind image2text model: {model.id}")
business_logger.debug(f"Start creating the knowledge base: {knowledge.name}")
db_knowledge = knowledge_repository.create_knowledge(

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-02 16:35:43
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-18 14:32:40
* @Last Modified time: 2026-04-22 10:16:43
*/
/**
* Server-Sent Events (SSE) Stream Utility Module
@@ -176,12 +176,12 @@ export const handleSSE = async (url: string, data: any, onMessage?: (data: SSEMe
case 500:
case 502:
const errorData = await response.json();
const errorInfo = errorData.error || i18n.t('common.serviceUpgrading');
const errorInfo = errorData.error || errorData.msg || i18n.t('common.serviceUpgrading');
message.warning(errorInfo);
throw new Error(errorData);
case 400:
const error = await response.json();
const error400 = error.error || 'Bad Request';
const error400 = error.error || error.msg || 'Bad Request';
message.warning(error400);
throw new Error(error);
case 403:
@@ -190,7 +190,7 @@ export const handleSSE = async (url: string, data: any, onMessage?: (data: SSEMe
throw new Error(errors);
case 504:
const errorJson = await response.json();
const errorMsg = errorJson.error || i18n.t('common.serverError');
const errorMsg = errorJson.error || errorJson.msg || i18n.t('common.serverError');
message.warning(errorMsg);
throw new Error(errorJson);
case 401:
@@ -204,6 +204,13 @@ export const handleSSE = async (url: string, data: any, onMessage?: (data: SSEMe
return;
}
break;
default:
if (!response.ok) {
const defaultData = await response.json().catch(() => ({}));
const defaultMsg = defaultData.error || defaultData.msg;
if (defaultMsg) message.warning(defaultMsg);
throw new Error(defaultMsg || `HTTP ${response.status}`);
}
}
if (!response.body) throw new Error('No response body');