Merge remote-tracking branch 'origin/release/v0.3.1' into fix/Timebomb_031

This commit is contained in:
Timebomb2018
2026-04-20 20:48:46 +08:00
23 changed files with 520 additions and 343 deletions

View File

@@ -47,7 +47,7 @@ async def create_end_user(
request: Request, request: Request,
api_key_auth: ApiKeyAuth = None, api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db), db: Session = Depends(get_db),
message: str = Body(..., description="Request body"), message: str = Body(None, description="Request body"),
): ):
""" """
Create or retrieve an end user for the workspace. Create or retrieve an end user for the workspace.

View File

@@ -96,40 +96,8 @@ def require_api_key(
resource_id=api_key_obj.resource_id, resource_id=api_key_obj.resource_id,
) )
# ── Tenant 级别限速(来自套餐配额 api_ops_rate_limit──────────
try:
from app.models.workspace_model import Workspace
from premium.platform_admin.package_plan_service import TenantSubscriptionService
workspace = db.query(Workspace).filter(
Workspace.id == api_key_obj.workspace_id
).first()
if workspace:
quota = TenantSubscriptionService(db).get_effective_quota(workspace.tenant_id)
tenant_qps_limit = quota.get("api_ops_rate_limit") if quota else None
if tenant_qps_limit:
rate_limiter = RateLimiterService()
tenant_ok, tenant_info = await rate_limiter.check_tenant_rate_limit(
workspace.tenant_id, tenant_qps_limit
)
if not tenant_ok:
raise RateLimitException(
"租户 API 调用速率超限",
BizCode.API_KEY_QPS_LIMIT_EXCEEDED,
rate_headers={
"X-RateLimit-Tenant-Limit": str(tenant_info["limit"]),
"X-RateLimit-Tenant-Remaining": str(tenant_info["remaining"]),
"X-RateLimit-Tenant-Reset": str(tenant_info["reset"]),
}
)
except RateLimitException:
raise
except Exception as e:
logger.warning(f"Tenant 限速检查异常,跳过: {e}")
# ─────────────────────────────────────────────────────────────
rate_limiter = RateLimiterService() rate_limiter = RateLimiterService()
is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj) is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj, db=db)
if not is_allowed: if not is_allowed:
logger.warning("API Key 限流触发", extra={ logger.warning("API Key 限流触发", extra={
"api_key_id": str(api_key_obj.id), "api_key_id": str(api_key_obj.id),
@@ -138,10 +106,12 @@ def require_api_key(
"error_msg": error_msg "error_msg": error_msg
}) })
# 根据错误消息判断限流类型 # 根据错误消息判断限流类型
if "QPS" in error_msg: if "Daily" in error_msg:
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED
elif "Daily" in error_msg:
code = BizCode.API_KEY_DAILY_LIMIT_EXCEEDED code = BizCode.API_KEY_DAILY_LIMIT_EXCEEDED
elif "Tenant" in error_msg:
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED # 租户套餐速率超限,同属 QPS 类
elif "QPS" in error_msg:
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED
else: else:
code = BizCode.API_KEY_QUOTA_EXCEEDED code = BizCode.API_KEY_QUOTA_EXCEEDED

View File

@@ -31,6 +31,7 @@ class BizCode(IntEnum):
API_KEY_QPS_LIMIT_EXCEEDED = 3014 API_KEY_QPS_LIMIT_EXCEEDED = 3014
API_KEY_DAILY_LIMIT_EXCEEDED = 3015 API_KEY_DAILY_LIMIT_EXCEEDED = 3015
API_KEY_QUOTA_EXCEEDED = 3016 API_KEY_QUOTA_EXCEEDED = 3016
API_KEY_RATE_LIMIT_EXCEEDED = 3017
# 资源4xxx # 资源4xxx
NOT_FOUND = 4000 NOT_FOUND = 4000
USER_NOT_FOUND = 4001 USER_NOT_FOUND = 4001

View File

@@ -6,7 +6,6 @@
2. 降级到 default_free_plan.py 配置文件(社区版兜底) 2. 降级到 default_free_plan.py 配置文件(社区版兜底)
""" """
import asyncio import asyncio
import time
from functools import wraps from functools import wraps
from typing import Optional, Callable, Dict, Any from typing import Optional, Callable, Dict, Any
from uuid import UUID from uuid import UUID
@@ -15,10 +14,13 @@ from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.logging_config import get_auth_logger from app.core.logging_config import get_auth_logger
from app.i18n.exceptions import QuotaExceededError from app.i18n.exceptions import QuotaExceededError, InternalServerError
logger = get_auth_logger() logger = get_auth_logger()
# Redis key 格式常量,与 RateLimiterService.check_qps 保持一致per api_key 独立计数)
API_KEY_QPS_REDIS_KEY = "rate_limit:qps:{api_key_id}"
def _get_user_from_kwargs(kwargs: dict): def _get_user_from_kwargs(kwargs: dict):
"""从 kwargs 中获取 user 对象""" """从 kwargs 中获取 user 对象"""
@@ -65,7 +67,9 @@ def _get_tenant_id_from_kwargs(db: Session, kwargs: dict):
if share_record: if share_record:
app = db.query(App).filter(App.id == share_record.app_id, App.is_active.is_(True)).first() app = db.query(App).filter(App.id == share_record.app_id, App.is_active.is_(True)).first()
if app: if app:
return app.workspace.tenant_id workspace = db.query(Workspace).filter(Workspace.id == app.workspace_id).first()
if workspace:
return workspace.tenant_id
return None return None
@@ -78,26 +82,47 @@ def _get_quota_config(db: Session, tenant_id: UUID) -> Optional[Dict[str, Any]]:
1. premium 模块的 tenant_subscriptionsSaaS 版) 1. premium 模块的 tenant_subscriptionsSaaS 版)
2. default_free_plan.py 配置文件(社区版兜底) 2. default_free_plan.py 配置文件(社区版兜底)
""" """
# 尝试从 premium 模块获取 # 尝试从 premium 模块获取SaaS 版)
try: try:
from premium.platform_admin.package_plan_service import TenantSubscriptionService from premium.platform_admin.package_plan_service import TenantSubscriptionService
# premium 模块存在,运行时错误不应被静默降级,直接抛出
quota_config = TenantSubscriptionService(db).get_effective_quota(tenant_id) quota_config = TenantSubscriptionService(db).get_effective_quota(tenant_id)
if quota_config: if quota_config:
logger.debug(f"从 premium 模块获取租户 {tenant_id} 配额配置") logger.debug(f"从 premium 模块获取租户 {tenant_id} 配额配置")
return quota_config return quota_config
except (ModuleNotFoundError, ImportError, Exception) as e: # premium 存在但该租户无订阅记录,降级到免费套餐
logger.debug(f"无法从 premium 模块获取配额配置: {e}") logger.debug(f"租户 {tenant_id} 无 premium 订阅,降级到免费套餐")
except (ModuleNotFoundError, ImportError):
# 社区版premium 包不存在,正常降级
logger.debug("premium 模块不存在,使用社区版免费套餐配额")
# 降级到配置文件 # 降级到社区版配置文件
try: try:
from app.config.default_free_plan import DEFAULT_FREE_PLAN from app.config.default_free_plan import DEFAULT_FREE_PLAN
logger.info(f"使用配置文件中的免费套餐配额: tenant={tenant_id}") logger.debug(f"使用社区版免费套餐配额: tenant={tenant_id}")
return DEFAULT_FREE_PLAN.get("quotas") return DEFAULT_FREE_PLAN.get("quotas")
except Exception as e: except Exception as e:
logger.error(f"无法从配置文件获取配额: {e}") logger.error(f"无法从配置文件获取配额: {e}")
return None return None
def get_api_ops_rate_limit(db: Session, tenant_id: UUID) -> Optional[int]:
"""
获取租户套餐的 API 操作速率限制QPS 上限)
该函数兼容社区版和 SaaS 版:
- SaaS 版:从 premium 模块的套餐配额读取
- 社区版:从 default_free_plan.py 配置文件读取
Returns:
int: api_ops_rate_limit 值,如果未配置则返回 None
"""
quota_config = _get_quota_config(db, tenant_id)
if quota_config:
return quota_config.get("api_ops_rate_limit")
return None
class QuotaUsageRepository: class QuotaUsageRepository:
"""配额使用量数据访问层""" """配额使用量数据访问层"""
@@ -247,41 +272,74 @@ def _check_quota(
def check_workspace_quota(func: Callable) -> Callable: def check_workspace_quota(func: Callable) -> Callable:
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs) user = _get_user_from_kwargs(kwargs)
if not db or not user: if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
return func(*args, **kwargs) raise InternalServerError()
_check_quota(db, user.tenant_id, "workspace_quota", "workspace")
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "workspace_quota", "workspace") _check_quota(db, user.tenant_id, "workspace_quota", "workspace")
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
def check_skill_quota(func: Callable) -> Callable: def check_skill_quota(func: Callable) -> Callable:
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs) user = _get_user_from_kwargs(kwargs)
if not db or not user: if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
return func(*args, **kwargs) raise InternalServerError()
_check_quota(db, user.tenant_id, "skill_quota", "skill")
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "skill_quota", "skill") _check_quota(db, user.tenant_id, "skill_quota", "skill")
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
def check_app_quota(func: Callable) -> Callable: def check_app_quota(func: Callable) -> Callable:
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs) user = _get_user_from_kwargs(kwargs)
if not db or not user: if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
return func(*args, **kwargs) raise InternalServerError()
_check_quota(db, user.tenant_id, "app_quota", "app")
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "app_quota", "app") _check_quota(db, user.tenant_id, "app_quota", "app")
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
def check_knowledge_capacity_quota(func: Callable) -> Callable: def check_knowledge_capacity_quota(func: Callable) -> Callable:
@@ -289,12 +347,12 @@ def check_knowledge_capacity_quota(func: Callable) -> Callable:
async def async_wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
if not db: if not db:
logger.warning("配额检查失败:缺少 db 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 参数,拒绝请求")
return await func(*args, **kwargs) raise InternalServerError()
tenant_id = _get_tenant_id_from_kwargs(db, kwargs) tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
if not tenant_id: if not tenant_id:
logger.warning("配额检查失败:无法获取 tenant_id") logger.error(f"配额检查失败:{func.__name__} 无法获取 tenant_id,拒绝请求")
return await func(*args, **kwargs) raise InternalServerError()
_check_quota(db, tenant_id, "knowledge_capacity_quota", "knowledge_capacity") _check_quota(db, tenant_id, "knowledge_capacity_quota", "knowledge_capacity")
return await func(*args, **kwargs) return await func(*args, **kwargs)
@@ -303,8 +361,8 @@ def check_knowledge_capacity_quota(func: Callable) -> Callable:
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs) user = _get_user_from_kwargs(kwargs)
if not db or not user: if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
return func(*args, **kwargs) raise InternalServerError()
_check_quota(db, user.tenant_id, "knowledge_capacity_quota", "knowledge_capacity") _check_quota(db, user.tenant_id, "knowledge_capacity_quota", "knowledge_capacity")
return func(*args, **kwargs) return func(*args, **kwargs)
@@ -313,15 +371,26 @@ def check_knowledge_capacity_quota(func: Callable) -> Callable:
def check_memory_engine_quota(func: Callable) -> Callable: def check_memory_engine_quota(func: Callable) -> Callable:
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs) user = _get_user_from_kwargs(kwargs)
if not db or not user: if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
return func(*args, **kwargs) raise InternalServerError()
_check_quota(db, user.tenant_id, "memory_engine_quota", "memory_engine")
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "memory_engine_quota", "memory_engine") _check_quota(db, user.tenant_id, "memory_engine_quota", "memory_engine")
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
def check_end_user_quota(func: Callable) -> Callable: def check_end_user_quota(func: Callable) -> Callable:
@@ -329,12 +398,12 @@ def check_end_user_quota(func: Callable) -> Callable:
async def async_wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
if not db: if not db:
logger.warning("配额检查失败:缺少 db 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 参数,拒绝请求")
return await func(*args, **kwargs) raise InternalServerError()
tenant_id = _get_tenant_id_from_kwargs(db, kwargs) tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
if not tenant_id: if not tenant_id:
logger.warning("配额检查失败:无法获取 tenant_id") logger.error(f"配额检查失败:{func.__name__} 无法获取 tenant_id,拒绝请求")
return await func(*args, **kwargs) raise InternalServerError()
_check_quota(db, tenant_id, "end_user_quota", "end_user") _check_quota(db, tenant_id, "end_user_quota", "end_user")
return await func(*args, **kwargs) return await func(*args, **kwargs)
@@ -342,12 +411,12 @@ def check_end_user_quota(func: Callable) -> Callable:
def sync_wrapper(*args, **kwargs): def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
if not db: if not db:
logger.warning("配额检查失败:缺少 db 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 参数,拒绝请求")
return func(*args, **kwargs) raise InternalServerError()
tenant_id = _get_tenant_id_from_kwargs(db, kwargs) tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
if not tenant_id: if not tenant_id:
logger.warning("配额检查失败:无法获取 tenant_id") logger.error(f"配额检查失败:{func.__name__} 无法获取 tenant_id,拒绝请求")
return func(*args, **kwargs) raise InternalServerError()
_check_quota(db, tenant_id, "end_user_quota", "end_user") _check_quota(db, tenant_id, "end_user_quota", "end_user")
return func(*args, **kwargs) return func(*args, **kwargs)
@@ -356,39 +425,95 @@ def check_end_user_quota(func: Callable) -> Callable:
def check_ontology_project_quota(func: Callable) -> Callable: def check_ontology_project_quota(func: Callable) -> Callable:
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs) user = _get_user_from_kwargs(kwargs)
if not db or not user: if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
return func(*args, **kwargs) raise InternalServerError()
_check_quota(db, user.tenant_id, "ontology_project_quota", "ontology_project")
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "ontology_project_quota", "ontology_project") _check_quota(db, user.tenant_id, "ontology_project_quota", "ontology_project")
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
def check_model_quota(func: Callable) -> Callable: def check_model_quota(func: Callable) -> Callable:
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs) user = _get_user_from_kwargs(kwargs)
if not db or not user: if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
return func(*args, **kwargs) raise InternalServerError()
_check_quota(db, user.tenant_id, "model_quota", "model")
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "model_quota", "model") _check_quota(db, user.tenant_id, "model_quota", "model")
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
def check_model_activation_quota(func: Callable) -> Callable: def check_model_activation_quota(func: Callable) -> Callable:
"""模型激活时的配额检查装饰器""" """模型激活时的配额检查装饰器"""
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs) user = _get_user_from_kwargs(kwargs)
if not db or not user: if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
return func(*args, **kwargs) raise InternalServerError()
model_id = kwargs.get("model_id") or (args[1] if len(args) > 1 else None)
model_data = kwargs.get("model_data")
if not model_id or not model_data:
logger.warning("模型激活配额检查失败:缺少 model_id 或 model_data 参数")
return await func(*args, **kwargs)
if model_data.is_active:
try:
from app.services.model_service import ModelConfigService
existing_model = ModelConfigService.get_model_by_id(
db=db,
model_id=model_id,
tenant_id=user.tenant_id
)
if not existing_model.is_active:
logger.info(f"模型激活操作,检查配额: model_id={model_id}, tenant_id={user.tenant_id}")
_check_quota(db, user.tenant_id, "model_quota", "model")
except Exception as e:
logger.error(f"模型激活配额检查异常: model_id={model_id}, error={str(e)}")
raise
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
model_id = kwargs.get("model_id") or (args[1] if len(args) > 1 else None) model_id = kwargs.get("model_id") or (args[1] if len(args) > 1 else None)
model_data = kwargs.get("model_data") model_data = kwargs.get("model_data")
@@ -397,9 +522,8 @@ def check_model_activation_quota(func: Callable) -> Callable:
logger.warning("模型激活配额检查失败:缺少 model_id 或 model_data 参数") logger.warning("模型激活配额检查失败:缺少 model_id 或 model_data 参数")
return func(*args, **kwargs) return func(*args, **kwargs)
if model_data.is_active is True: if model_data.is_active:
try: try:
from app.models.models_model import ModelConfig
from app.services.model_service import ModelConfigService from app.services.model_service import ModelConfigService
existing_model = ModelConfigService.get_model_by_id( existing_model = ModelConfigService.get_model_by_id(
@@ -416,28 +540,40 @@ def check_model_activation_quota(func: Callable) -> Callable:
raise raise
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
def check_quota(quota_type: str, resource_name: str, usage_func: Optional[Callable] = None): def check_quota(quota_type: str, resource_name: str, usage_func: Optional[Callable] = None):
"""通用配额检查装饰器,支持自定义使用量获取函数""" """通用配额检查装饰器,支持自定义使用量获取函数"""
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs) user = _get_user_from_kwargs(kwargs)
if not db or not user: if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
return func(*args, **kwargs) raise InternalServerError()
_check_quota(db, user.tenant_id, quota_type, resource_name, usage_func)
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, quota_type, resource_name, usage_func) _check_quota(db, user.tenant_id, quota_type, resource_name, usage_func)
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
return decorator return decorator
# ─── 配额使用统计 ──────────────────────────────────────────────────────────── # ─── 配额使用统计 ────────────────────────────────────────────────────────────
def get_quota_usage(db: Session, tenant_id: UUID) -> dict: async def get_quota_usage(db: Session, tenant_id: UUID) -> dict:
"""获取租户所有配额的使用情况""" """获取租户所有配额的使用情况"""
quota_config = _get_quota_config(db, tenant_id) quota_config = _get_quota_config(db, tenant_id)
if not quota_config: if not quota_config:
@@ -459,18 +595,25 @@ def get_quota_usage(db: Session, tenant_id: UUID) -> dict:
api_ops_current = 0 api_ops_current = 0
try: try:
from app.core.config import settings from app.aioRedis import aio_redis as _aio_redis
import redis from app.models.api_key_model import ApiKey
_now = time.time() from app.models.workspace_model import Workspace
_rk = f"rate_limit:tenant_qps:{tenant_id}" # api_ops_rate_limit 限的是每个 api_key 每秒最高限额
_r = redis.StrictRedis( # 展示当前最接近触发限流的 key 的 QPS取最大值
host=settings.REDIS_HOST, port=settings.REDIS_PORT, api_key_ids = db.query(ApiKey.id).join(
db=settings.REDIS_DB, password=settings.REDIS_PASSWORD, Workspace, ApiKey.workspace_id == Workspace.id
decode_responses=True ).filter(
) Workspace.tenant_id == tenant_id,
api_ops_current = int(_r.zcount(_rk, _now - 1, "+inf")) ApiKey.is_active.is_(True)
except Exception: ).all()
pass for (key_id,) in api_key_ids:
_rk = API_KEY_QPS_REDIS_KEY.format(api_key_id=key_id)
val = await _aio_redis.get(_rk)
count = int(val) if val else 0
if count > api_ops_current:
api_ops_current = count
except Exception as e:
logger.warning(f"获取 api_ops_current 失败,返回 0: {type(e).__name__}: {e}")
return { return {
"workspace": {"used": workspace_count, "limit": quota_config.get("workspace_quota"), "percentage": pct(workspace_count, quota_config.get("workspace_quota"))}, "workspace": {"used": workspace_count, "limit": quota_config.get("workspace_quota"), "percentage": pct(workspace_count, quota_config.get("workspace_quota"))},

View File

@@ -18,6 +18,7 @@ from app.core.quota_manager import (
get_quota_usage, get_quota_usage,
_check_quota, _check_quota,
QuotaUsageRepository, QuotaUsageRepository,
API_KEY_QPS_REDIS_KEY,
) )
__all__ = [ __all__ = [
@@ -33,4 +34,5 @@ __all__ = [
"get_quota_usage", "get_quota_usage",
"_check_quota", "_check_quota",
"QuotaUsageRepository", "QuotaUsageRepository",
"API_KEY_QPS_REDIS_KEY",
] ]

View File

@@ -482,14 +482,39 @@ class RateLimitExceededError(I18nException):
) )
class QuotaExceededError(ForbiddenError): class QuotaExceededError(I18nException):
"""Quota exceeded error.""" """Quota exceeded error (402)."""
# resource key -> i18n display key
_RESOURCE_KEY_MAP = {
"workspace": "errors.quota_resources.workspace",
"app": "errors.quota_resources.app",
"skill": "errors.quota_resources.skill",
"knowledge_capacity": "errors.quota_resources.knowledge_capacity",
"memory_engine": "errors.quota_resources.memory_engine",
"end_user": "errors.quota_resources.end_user",
"model": "errors.quota_resources.model",
"ontology_project": "errors.quota_resources.ontology_project",
"api_ops_rate_limit": "errors.quota_resources.api_ops_rate_limit",
}
def __init__(self, resource: Optional[str] = None, **params): def __init__(self, resource: Optional[str] = None, **params):
# Translate resource key to a localized display name before calling super()
if resource: if resource:
params["resource"] = resource resource_i18n_key = self._RESOURCE_KEY_MAP.get(resource)
if resource_i18n_key:
try:
from app.i18n.service import get_translation_service
from app.core.config import settings
_locale = _current_locale.get() or settings.I18N_DEFAULT_LANGUAGE
params["resource"] = get_translation_service().translate(resource_i18n_key, _locale)
except Exception:
params["resource"] = resource
else:
params["resource"] = resource
super().__init__( super().__init__(
error_key="errors.api.quota_exceeded", error_key="errors.api.quota_exceeded",
status_code=402,
error_code="QUOTA_EXCEEDED", error_code="QUOTA_EXCEEDED",
**params **params
) )

View File

@@ -106,7 +106,7 @@
}, },
"api": { "api": {
"rate_limit_exceeded": "API rate limit exceeded", "rate_limit_exceeded": "API rate limit exceeded",
"quota_exceeded": "API quota exceeded", "quota_exceeded": "{resource} quota exceeded",
"invalid_api_key": "Invalid API key", "invalid_api_key": "Invalid API key",
"api_key_expired": "API key has expired", "api_key_expired": "API key has expired",
"api_key_revoked": "API key has been revoked", "api_key_revoked": "API key has been revoked",
@@ -114,7 +114,8 @@
"method_not_allowed": "Method not allowed", "method_not_allowed": "Method not allowed",
"invalid_request": "Invalid request", "invalid_request": "Invalid request",
"missing_parameter": "Missing required parameter: {param}", "missing_parameter": "Missing required parameter: {param}",
"invalid_parameter": "Invalid parameter: {param}" "invalid_parameter": "Invalid parameter: {param}",
"api_key_rate_limit_exceeded": "API Key rate limit ({rate_limit}) exceeds tenant plan limit ({limit})"
}, },
"database": { "database": {
"connection_failed": "Database connection failed", "connection_failed": "Database connection failed",
@@ -134,5 +135,16 @@
"invalid_format": "Invalid format: {field}", "invalid_format": "Invalid format: {field}",
"invalid_value": "Invalid value: {field}", "invalid_value": "Invalid value: {field}",
"out_of_range": "Value out of range: {field}" "out_of_range": "Value out of range: {field}"
},
"quota_resources": {
"workspace": "Workspace",
"app": "App",
"skill": "Skill",
"knowledge_capacity": "Knowledge capacity",
"memory_engine": "Memory engine",
"end_user": "End user",
"model": "Model",
"ontology_project": "Ontology project",
"api_ops_rate_limit": "API ops rate limit"
} }
} }

View File

@@ -106,7 +106,7 @@
}, },
"api": { "api": {
"rate_limit_exceeded": "API调用频率超限", "rate_limit_exceeded": "API调用频率超限",
"quota_exceeded": "API调用配额已用完", "quota_exceeded": "{resource} 配额已超限",
"invalid_api_key": "无效的API密钥", "invalid_api_key": "无效的API密钥",
"api_key_expired": "API密钥已过期", "api_key_expired": "API密钥已过期",
"api_key_revoked": "API密钥已被撤销", "api_key_revoked": "API密钥已被撤销",
@@ -114,7 +114,8 @@
"method_not_allowed": "不支持的请求方法", "method_not_allowed": "不支持的请求方法",
"invalid_request": "无效的请求", "invalid_request": "无效的请求",
"missing_parameter": "缺少必需参数:{param}", "missing_parameter": "缺少必需参数:{param}",
"invalid_parameter": "参数无效:{param}" "invalid_parameter": "参数无效:{param}",
"api_key_rate_limit_exceeded": "API Key 的 QPS 限制({rate_limit})超过租户套餐上限({limit}"
}, },
"database": { "database": {
"connection_failed": "数据库连接失败", "connection_failed": "数据库连接失败",
@@ -134,5 +135,16 @@
"invalid_format": "格式不正确:{field}", "invalid_format": "格式不正确:{field}",
"invalid_value": "值无效:{field}", "invalid_value": "值无效:{field}",
"out_of_range": "值超出范围:{field}" "out_of_range": "值超出范围:{field}"
},
"quota_resources": {
"workspace": "工作空间",
"app": "应用",
"skill": "技能",
"knowledge_capacity": "知识库容量",
"memory_engine": "记忆引擎",
"end_user": "终端用户",
"model": "模型",
"ontology_project": "本体工程",
"api_ops_rate_limit": "API 操作速率"
} }
} }

View File

@@ -51,6 +51,16 @@ class ApiKeyService:
if existing: if existing:
raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME) raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME)
# 若 rate_limit 超过租户套餐的 api_ops_rate_limit自动截断到套餐上限
from app.models.workspace_model import Workspace
from app.core.quota_manager import get_api_ops_rate_limit
workspace = db.query(Workspace).filter(Workspace.id == workspace_id).first()
if workspace:
tenant_api_ops_limit = get_api_ops_rate_limit(db, workspace.tenant_id)
if tenant_api_ops_limit and data.rate_limit > tenant_api_ops_limit:
data.rate_limit = tenant_api_ops_limit
# 生成 API Key # 生成 API Key
api_key = generate_api_key(data.type) api_key = generate_api_key(data.type)
@@ -152,6 +162,17 @@ class ApiKeyService:
if existing: if existing:
raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME) raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME)
# 若 rate_limit 超过租户套餐的 api_ops_rate_limit自动截断到套餐上限
if data.rate_limit is not None:
from app.models.workspace_model import Workspace
from app.core.quota_manager import get_api_ops_rate_limit
workspace = db.query(Workspace).filter(Workspace.id == workspace_id).first()
if workspace:
tenant_api_ops_limit = get_api_ops_rate_limit(db, workspace.tenant_id)
if tenant_api_ops_limit and data.rate_limit > tenant_api_ops_limit:
data.rate_limit = tenant_api_ops_limit
update_data = data.model_dump(exclude_unset=True) update_data = data.model_dump(exclude_unset=True)
ApiKeyRepository.update(db, api_key_id, update_data) ApiKeyRepository.update(db, api_key_id, update_data)
db.commit() db.commit()
@@ -248,42 +269,14 @@ class RateLimiterService:
def __init__(self): def __init__(self):
self.redis = aio_redis self.redis = aio_redis
async def check_tenant_rate_limit(self, tenant_id: uuid.UUID, limit: int) -> Tuple[bool, dict]:
"""
按 tenant_id 做 1 秒滑动窗口限速,限制值来自套餐配额 api_ops_rate_limit
"""
now = time.time()
window_start = now - 1 # 1 秒窗口
key = f"rate_limit:tenant_qps:{tenant_id}"
async with self.redis.pipeline() as pipe:
# 清理 1 秒前的旧记录
pipe.zremrangebyscore(key, 0, window_start)
# 加入当前请求score=时间戳member=时间戳+随机数保证唯一)
pipe.zadd(key, {f"{now}:{uuid.uuid4().hex}": now})
# 统计窗口内请求数
pipe.zcard(key)
# 设置 key 过期2 秒后自动清理)
pipe.expire(key, 2)
results = await pipe.execute()
current = results[2]
remaining = max(0, limit - current)
reset_time = int(now) + 1
return current <= limit, {
"limit": limit,
"remaining": remaining,
"reset": reset_time,
}
async def check_qps(self, api_key_id: uuid.UUID, limit: int) -> Tuple[bool, dict]: async def check_qps(self, api_key_id: uuid.UUID, limit: int) -> Tuple[bool, dict]:
""" """检查QPS限制
检查QPS限制
Returns: Returns:
(is_allowed, rate_limit_info) (is_allowed, rate_limit_info)
""" """
key = f"rate_limit:qps:{api_key_id}" key = f"rate_limit:qps:{api_key_id}"
async with self.redis.pipeline() as pipe: async with self.redis.pipeline() as pipe:
pipe.incr(key) pipe.incr(key)
pipe.expire(key, 1, nx=True) # 1 秒过期 pipe.expire(key, 1, nx=True) # 1 秒过期
@@ -295,8 +288,9 @@ class RateLimiterService:
return current <= limit, { return current <= limit, {
"limit": limit, "limit": limit,
"current": current,
"remaining": remaining, "remaining": remaining,
"reset": reset_time "reset": reset_time,
} }
async def check_daily_requests( async def check_daily_requests(
@@ -304,7 +298,9 @@ class RateLimiterService:
api_key_id: uuid.UUID, api_key_id: uuid.UUID,
limit: int limit: int
) -> Tuple[bool, dict]: ) -> Tuple[bool, dict]:
"""检查日调用量限制""" """检查日调用量限制
使用原子 INCR先写后判断极低概率下允许轻微超限并发场景下可接受
"""
today = datetime.now().strftime("%Y%m%d") today = datetime.now().strftime("%Y%m%d")
key = f"rate_limit:daily:{api_key_id}:{today}" key = f"rate_limit:daily:{api_key_id}:{today}"
@@ -313,6 +309,7 @@ class RateLimiterService:
hour=0, minute=0, second=0, microsecond=0 hour=0, minute=0, second=0, microsecond=0
) )
expire_seconds = int((tomorrow_0 - now).total_seconds()) expire_seconds = int((tomorrow_0 - now).total_seconds())
reset_time = int(tomorrow_0.timestamp())
async with self.redis.pipeline() as pipe: async with self.redis.pipeline() as pipe:
pipe.incr(key) pipe.incr(key)
@@ -320,36 +317,74 @@ class RateLimiterService:
results = await pipe.execute() results = await pipe.execute()
current = results[0] current = results[0]
remaining = max(0, limit - current)
reset_time = int(tomorrow_0.timestamp())
return current <= limit, { if current > limit:
return False, {
"limit": limit,
"remaining": 0,
"reset": reset_time,
}
return True, {
"limit": limit, "limit": limit,
"remaining": remaining, "remaining": max(0, limit - current),
"reset": reset_time "reset": reset_time,
} }
async def check_all_limits( async def check_all_limits(
self, self,
api_key: ApiKey api_key: ApiKey,
db: Optional[Session] = None,
) -> Tuple[bool, str, dict]: ) -> Tuple[bool, str, dict]:
""" """
检查所有限制 检查所有限制,按以下顺序:
Returns: 1. API Key QPS取 api_key.rate_limit 与套餐 api_ops_rate_limit 的最小值作为限额
(is_allowed, error_message, rate_limit_headers) 2. API Key 日调用量
""" """
# Check QPS # 1. 取套餐限额与 api_key 自身限额的最小值
qps_ok, qps_info = await self.check_qps( effective_limit = api_key.rate_limit
api_key.id, if db is not None:
api_key.rate_limit try:
) from app.models.workspace_model import Workspace
from app.core.quota_manager import get_api_ops_rate_limit
cache_key = f"tenant_api_ops_limit:{api_key.workspace_id}"
cached = await self.redis.get(cache_key)
if cached is not None:
try:
tenant_limit = int(cached) if cached != "0" else None
except (ValueError, TypeError):
cached = None
tenant_limit = None
if cached is None:
workspace = db.query(Workspace).filter(Workspace.id == api_key.workspace_id).first()
if workspace:
tenant_limit = get_api_ops_rate_limit(db, workspace.tenant_id)
await self.redis.set(cache_key, str(tenant_limit) if tenant_limit else "0", ex=60)
else:
tenant_limit = None
if tenant_limit:
effective_limit = min(api_key.rate_limit, tenant_limit)
except Exception as e:
logger.warning(f"获取套餐限额失败,使用 api_key 自身限额: {e}")
# 用最终有效限额做 QPS 检查
qps_ok, qps_info = await self.check_qps(api_key.id, effective_limit)
if not qps_ok: if not qps_ok:
return False, "QPS limit exceeded", { # 判断是套餐限额触发还是 api_key 自身限额触发
if tenant_limit and effective_limit == tenant_limit and api_key.rate_limit > tenant_limit:
error_msg = "Tenant limit exceeded"
else:
error_msg = "QPS limit exceeded"
return False, error_msg, {
"X-RateLimit-Limit-QPS": str(qps_info["limit"]), "X-RateLimit-Limit-QPS": str(qps_info["limit"]),
"X-RateLimit-Remaining-QPS": str(qps_info["remaining"]), "X-RateLimit-Remaining-QPS": str(qps_info["remaining"]),
"X-RateLimit-Reset": str(qps_info["reset"]) "X-RateLimit-Reset": str(qps_info["reset"])
} }
# 2. 检查日调用量
daily_ok, daily_info = await self.check_daily_requests( daily_ok, daily_info = await self.check_daily_requests(
api_key.id, api_key.id,
api_key.daily_request_limit api_key.daily_request_limit
@@ -361,14 +396,13 @@ class RateLimiterService:
"X-RateLimit-Reset": str(daily_info["reset"]) "X-RateLimit-Reset": str(daily_info["reset"])
} }
headers = { return True, "", {
"X-RateLimit-Limit-QPS": str(qps_info["limit"]), "X-RateLimit-Limit-QPS": str(qps_info["limit"]),
"X-RateLimit-Remaining-QPS": str(qps_info["remaining"]), "X-RateLimit-Remaining-QPS": str(qps_info["remaining"]),
"X-RateLimit-Limit-Day": str(daily_info["limit"]), "X-RateLimit-Limit-Day": str(daily_info["limit"]),
"X-RateLimit-Remaining-Day": str(daily_info["remaining"]), "X-RateLimit-Remaining-Day": str(daily_info["remaining"]),
"X-RateLimit-Reset": str(daily_info["reset"]) "X-RateLimit-Reset": str(daily_info["reset"]),
} }
return True, "", headers
class ApiKeyAuthService: class ApiKeyAuthService:

View File

@@ -1280,7 +1280,7 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
} }
logger.info( logger.info(
f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}") f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={end_user.workspace_id}")
return result return result

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-03-07 16:49:59 * @Date: 2026-03-07 16:49:59
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-17 10:11:54 * @Last Modified time: 2026-04-20 18:14:34
*/ */
import { type FC, useEffect, useState } from 'react'; import { type FC, useEffect, useState } from 'react';
import { Select, Flex, Space } from 'antd'; import { Select, Flex, Space } from 'antd';
@@ -56,7 +56,7 @@ const ModelSelect: FC<ModelSelectProps> = ({ params, placeholder, fontClassName,
useEffect(() => { useEffect(() => {
if (updateOptions) updateOptions([...options, ...initialData]); if (updateOptions) updateOptions([...options, ...initialData]);
}, [options, initialData]) }, [JSON.stringify(options), JSON.stringify(initialData)])
return ( return (
<Select <Select

View File

@@ -82,7 +82,7 @@ const SubscriptionDetailModal = forwardRef<SubscriptionDetailModalRef>((_props,
{/* Features */} {/* Features */}
<Flex gap={12} vertical className="rb:space-y-3 rb:mb-4 rb:h-[calc(100vh-341px)]! rb:overflow-y-auto"> <Flex gap={12} vertical className="rb:space-y-3 rb:mb-4 rb:h-[calc(100vh-341px)]! rb:overflow-y-auto">
{billingUnits.map(({ key, unit, icon }) => { {billingUnits.map(({ key, unit, icon }) => {
const value = detail?.quota[key as keyof Subscription['quota']]; const value = detail?.quotas[key as keyof Subscription['quotas']];
if (value === undefined || value === null) return null; if (value === undefined || value === null) return null;
return ( return (
<UnitWrapper <UnitWrapper

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-02 15:25:31 * @Date: 2026-02-02 15:25:31
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-16 17:35:38 * @Last Modified time: 2026-04-20 10:15:20
*/ */
/** /**
* SiderMenu Component * SiderMenu Component
@@ -114,7 +114,7 @@ export interface Subscription {
started_at: number | null started_at: number | null
expired_at: number | null expired_at: number | null
status: string status: string
quota: SubscriptionQuota quotas: SubscriptionQuota
created_at: number created_at: number
updated_at: number updated_at: number
} }
@@ -417,7 +417,7 @@ const Menu: FC<{
<div className="rb:grid rb:grid-cols-4 rb:mt-4"> <div className="rb:grid rb:grid-cols-4 rb:mt-4">
{['workspace_quota', 'skill_quota', 'app_quota', 'model_quota'].map(key => ( {['workspace_quota', 'skill_quota', 'app_quota', 'model_quota'].map(key => (
<div key={key} className="rb:text-center"> <div key={key} className="rb:text-center">
<div className="rb:text-[13px] rb:font-[MiSans-Semibold] rb:font-semibold">{subscription.quota?.[key as keyof typeof subscription.quota]}</div> <div className="rb:text-[13px] rb:font-[MiSans-Semibold] rb:font-semibold">{subscription.quotas?.[key as keyof typeof subscription.quotas]}</div>
<div className="rb:mt-1 rb:text-[#5B6167] rb:text-[10px] rb:leading-3.5">{t(`index.${key}`)}</div> <div className="rb:mt-1 rb:text-[#5B6167] rb:text-[10px] rb:leading-3.5">{t(`index.${key}`)}</div>
</div> </div>
))} ))}

View File

@@ -2537,6 +2537,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
checkListErrors: { checkListErrors: {
'llm.model_id': 'Model', 'llm.model_id': 'Model',
'llm.messages': 'Messages', 'llm.messages': 'Messages',
'llm.vision_input': 'Vision Variable',
'end.output': 'Output', 'end.output': 'Output',
'knowledge-retrieval.knowledge_retrieval': 'Knowledge bases', 'knowledge-retrieval.knowledge_retrieval': 'Knowledge bases',
'parameter-extractor.model_id': 'Model', 'parameter-extractor.model_id': 'Model',
@@ -2564,6 +2565,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
'jinja-render.template': 'Template', 'jinja-render.template': 'Template',
'document-extractor.file_selector': 'File variable', 'document-extractor.file_selector': 'File variable',
'list-operator.input_list': 'Input list', 'list-operator.input_list': 'Input list',
'tool.tool_id': 'Tool',
}, },
checkListHasErrors: 'Please resolve all issues in the checklist before publishing', checkListHasErrors: 'Please resolve all issues in the checklist before publishing',
variableSelect: { variableSelect: {

View File

@@ -2501,6 +2501,7 @@ export const zh = {
checkListErrors: { checkListErrors: {
'llm.model_id': '模型', 'llm.model_id': '模型',
'llm.messages': '提示词', 'llm.messages': '提示词',
'llm.vision_input': '视觉变量',
'end.output': '回复', 'end.output': '回复',
'knowledge-retrieval.knowledge_retrieval': '知识库', 'knowledge-retrieval.knowledge_retrieval': '知识库',
'parameter-extractor.model_id': '模型', 'parameter-extractor.model_id': '模型',
@@ -2528,6 +2529,7 @@ export const zh = {
'jinja-render.template': '模板', 'jinja-render.template': '模板',
'document-extractor.file_selector': '文件变量', 'document-extractor.file_selector': '文件变量',
'list-operator.input_list': '输入变量', 'list-operator.input_list': '输入变量',
'tool.tool_id': '工具',
}, },
checkListHasErrors: '发布前确认检查清单中所有问题均已解决', checkListHasErrors: '发布前确认检查清单中所有问题均已解决',
variableSelect: { variableSelect: {

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-28 14:08:14 * @Date: 2026-02-28 14:08:14
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-13 18:17:32 * @Last Modified time: 2026-04-20 16:52:32
*/ */
/** /**
* UploadModal Component * UploadModal Component
@@ -16,6 +16,7 @@
import { forwardRef, useImperativeHandle, useState, useMemo } from 'react'; import { forwardRef, useImperativeHandle, useState, useMemo } from 'react';
import { Form, Steps, Flex, Alert, Button, Result, message } from 'antd'; import { Form, Steps, Flex, Alert, Button, Result, message } from 'antd';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useNavigate } from 'react-router-dom';
import type { Application, UploadModalRef } from '../types' import type { Application, UploadModalRef } from '../types'
import RbModal from '@/components/RbModal' import RbModal from '@/components/RbModal'
@@ -51,6 +52,7 @@ const UploadModal = forwardRef<UploadModalRef, UploadModalProps>(({
id id
}, ref) => { }, ref) => {
const { t } = useTranslation(); const { t } = useTranslation();
const navigate = useNavigate();
// State management // State management
const [visible, setVisible] = useState(false); // Modal visibility const [visible, setVisible] = useState(false); // Modal visibility
@@ -146,6 +148,10 @@ const UploadModal = forwardRef<UploadModalRef, UploadModalProps>(({
window.open(`/#/application/config/${appId}`, '_blank'); window.open(`/#/application/config/${appId}`, '_blank');
} }
break; break;
case 'list':
if (id) {
navigate('/application')
}
} }
}, 100) }, 100)
}; };

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 17:09:03 * @Date: 2026-02-03 17:09:03
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-31 12:21:56 * @Last Modified time: 2026-04-20 16:59:25
*/ */
/** /**
* Memory Conversation Page * Memory Conversation Page
@@ -78,8 +78,8 @@ interface DataItem {
id: string; id: string;
question: string; question: string;
type: string; type: string;
reason: string; reason?: string;
} }
/** /**
* Log item for conversation analysis * Log item for conversation analysis
*/ */
@@ -88,13 +88,15 @@ export interface LogItem {
title: string; title: string;
data?: DataItem[] | AnyObject; data?: DataItem[] | AnyObject;
raw_results?: string | Record<string, AnyObject>; raw_results?: string | Record<string, AnyObject>;
raw_result?: Array<AnyObject>;
summary?: string; summary?: string;
query?: string; query?: string;
reason?: string; reason?: string;
result?: string; result?: string;
original_query: string; original_query?: string;
index?: number; index?: number;
result_count?: number; result_count?: number;
total?: number;
} }
/** /**
@@ -242,7 +244,6 @@ const MemoryConversation: FC = () => {
<ContentWrapper key={vo.id}> <ContentWrapper key={vo.id}>
<> <>
<div className="rb:font-medium rb:text-[#212332]">{vo.id}. {vo.question}</div> <div className="rb:font-medium rb:text-[#212332]">{vo.id}. {vo.question}</div>
<div className="rb:mt-2 rb:text-[#5B6167]">{vo.reason}</div>
</> </>
</ContentWrapper> </ContentWrapper>
))} ))}
@@ -260,25 +261,9 @@ const MemoryConversation: FC = () => {
</ContentWrapper> </ContentWrapper>
))} ))}
</Flex> </Flex>
: log.type === 'search_result' && log.raw_results && typeof log.raw_results !== 'string' : log.type === 'search_result' && log.result
? <ContentWrapper> ? <ContentWrapper>
<div className="rb:font-medium rb:text-[#212332] rb:mb-2">{log.query}</div> <Markdown content={log.result} />
{(log.raw_results.reranked_results as AnyObject)?.communities?.length > 0 && <>
<div className="rb:font-medium rb:text-[#212332]">{t('memoryConversation.communities')}</div>
<ul className='rb:mt-2 rb:text-[#5B6167] rb:list-disc rb:pl-4'>
{((log.raw_results.reranked_results as AnyObject)?.communities as { content: string }[]).map((item, index: number) => (
<li key={index}>{item.content}</li>
))}
</ul>
</>}
{(log.raw_results.reranked_results as AnyObject)?.summaries?.length > 0 && <>
<div className="rb:font-medium rb:text-[#212332]">{t('memoryConversation.summaries')}</div>
<ul className='rb:mt-2 rb:text-[#5B6167] rb:list-disc rb:pl-4'>
{((log.raw_results.reranked_results as AnyObject)?.summaries as { content: string }[]).map((item, index: number) => (
<li key={index}>{item.content}</li>
))}
</ul>
</>}
</ContentWrapper> </ContentWrapper>
: log.type === 'retrieval_summary' && log.summary : log.type === 'retrieval_summary' && log.summary
? <ContentWrapper> ? <ContentWrapper>

View File

@@ -1,3 +1,9 @@
/*
* @Author: ZhaoYing
* @Date: 2026-04-09 18:58:21
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-20 10:39:17
*/
import { useState, useCallback, useEffect, useRef, type FC } from 'react' import { useState, useCallback, useEffect, useRef, type FC } from 'react'
import { Popover, Flex } from 'antd' import { Popover, Flex } from 'antd'
import { WarningFilled } from '@ant-design/icons' import { WarningFilled } from '@ant-design/icons'
@@ -49,7 +55,7 @@ const specialValidators: Record<string, (val: any) => boolean> = {
if (expr?.sub_variable_condition?.conditions?.length > 0) return expr.sub_variable_condition?.conditions.every(isSubExprSet) if (expr?.sub_variable_condition?.conditions?.length > 0) return expr.sub_variable_condition?.conditions.every(isSubExprSet)
if (!expr.left) return false if (!expr.left) return false
if (['not_empty', 'empty'].includes(expr.operator)) return true if (['not_empty', 'empty'].includes(expr.operator)) return true
return !!expr.left && (!!expr.right || typeof expr.right === 'boolean' || typeof expr.right === 'number') return !!expr.left && (expr?.sub_variable_condition || !!expr.right || typeof expr.right === 'boolean' || typeof expr.right === 'number')
} }
return val.some(c => !c?.expressions?.length || c.expressions.some((expr: any) => !isExprSet(expr))) return val.some(c => !c?.expressions?.length || c.expressions.some((expr: any) => !isExprSet(expr)))
}, },
@@ -100,6 +106,18 @@ function validateNode(type: string, config: Record<string, any>): CheckError[] {
if (isInvalid) errors.push({ key: specialKey, message: '' }) if (isInvalid) errors.push({ key: specialKey, message: '' })
}) })
// llm: vision_input required when vision is enabled
if (type === 'llm') {
const vision = get('vision')
if (vision === true || vision === 'true') {
const visionInput = get('vision_input')
console.log('vision', vision, isEmpty(visionInput))
if (isEmpty(visionInput)) {
errors.push({ key: 'llm.vision_input', message: '' })
}
}
}
// http-request body.data (binary) — not a top-level required field, check separately // http-request body.data (binary) — not a top-level required field, check separately
if (type === 'http-request') { if (type === 'http-request') {
const body = get('body') const body = get('body')

View File

@@ -9,7 +9,7 @@ import { useVariableList } from '../Properties/hooks/useVariableList'
import { isSubExprSet } from '../../utils' import { isSubExprSet } from '../../utils'
import { fileSubFieldOperators } from '../Properties/CaseList' import { fileSubFieldOperators } from '../Properties/CaseList'
const caculateIsSet = (item: any, type: string) => { const calculateIsSet = (item: any, type: string) => {
switch (type) { switch (type) {
case 'categories': case 'categories':
return typeof item?.class_name === 'string' && item?.class_name !== '' return typeof item?.class_name === 'string' && item?.class_name !== ''
@@ -79,7 +79,7 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => {
<div key={index} className="rb:bg-[#F0F3F8] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)] rb:rounded-md rb:py-1 rb:px-1.5 rb:text-[10px] rb:text-[#5B6167] rb:font-medium rb:leading-3.5"> <div key={index} className="rb:bg-[#F0F3F8] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)] rb:rounded-md rb:py-1 rb:px-1.5 rb:text-[10px] rb:text-[#5B6167] rb:font-medium rb:leading-3.5">
<Flex justify="space-between"> <Flex justify="space-between">
<span>{t('workflow.config.question-classifier.class_name')} {index + 1}</span> <span>{t('workflow.config.question-classifier.class_name')} {index + 1}</span>
{caculateIsSet(item, 'categories') ? t(`workflow.config.${data.type}.set`) : t(`workflow.config.${data.type}.unset`)} {calculateIsSet(item, 'categories') ? t(`workflow.config.${data.type}.set`) : t(`workflow.config.${data.type}.unset`)}
</Flex> </Flex>
</div> </div>
))} ))}
@@ -89,17 +89,24 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => {
<Flex vertical gap={4} className="rb:mt-3!"> <Flex vertical gap={4} className="rb:mt-3!">
{data.config?.cases?.defaultValue.map((item: any, index: number) => ( {data.config?.cases?.defaultValue.map((item: any, index: number) => (
<div key={index} className={item.expressions.length > 0 ? '' : 'rb:mb-1'}> <div key={index} className={item.expressions.length > 0 ? '' : 'rb:mb-1'}>
<Flex justify={item.expressions.length > 0 ? "space-between" : 'end'} className="rb:mb-1"> <Flex justify={item.expressions.length > 0 ? "space-between" : 'end'} className="rb:mb-1! rb:leading-4">
{item.expressions.length > 0 && <span className="rb:text-[#5B6167] rb:text-[10px]">CASE{index + 1}</span>} {item.expressions.length > 0 && <span className="rb:text-[#5B6167] rb:text-[10px] rb:pl-1">CASE{index + 1}</span>}
<span className="rb:text-[#212332] rb:font-medium rb:text-[12px]">{index === 0 ? 'IF' : `ELIF`}</span> <span className="rb:text-[#212332] rb:font-medium rb:text-[12px]">{index === 0 ? 'IF' : `ELIF`}</span>
</Flex> </Flex>
{item.expressions.length > 0 && <Flex vertical gap={2}> {item.expressions.length > 0 && <Flex vertical gap={2}>
{item.expressions.map((expression: any, eIndex: number) => ( {item.expressions.map((expression: any, eIndex: number) => (
<div key={eIndex} className="rb:relative"> <div key={eIndex} className="rb:relative">
{item.expressions.length > 1 && eIndex > 0 && <div className="rb:absolute rb:-top-2 rb:right-2 rb:text-[10px] rb:text-[#155EEF] rb:font-medium rb:leading-3.5 rb:text-right rb:pr-0.5">{item.logical_operator?.toLocaleUpperCase()}</div>} {item.expressions.length > 1 && eIndex > 0 &&
<Flex vertical gap={2} className="rb:bg-[#F0F3F8] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)] rb:rounded-md rb:py-1! rb:px-1.5! rb:text-[10px] rb:text-[#5B6167] rb:font-medium rb:leading-3.5"> <div className="rb:absolute rb:-top-2 rb:right-2 rb:text-[10px] rb:text-[#155EEF] rb:font-medium rb:leading-3.5 rb:text-right rb:pr-0.5">{item.logical_operator?.toLocaleUpperCase()}</div>
}
<Flex vertical gap={2}
className={clsx("rb:bg-[#F0F3F8] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)] rb:rounded-md rb:px-1.5! rb:text-[10px] rb:text-[#5B6167] rb:font-medium rb:leading-4", {
'rb:pt-1!': expression.sub_variable_condition?.conditions?.length > 0,
'rb:py-1!': !expression.sub_variable_condition?.conditions || !expression.sub_variable_condition?.conditions?.length
})}
>
<Flex align="center"> <Flex align="center">
{caculateIsSet(expression, 'cases') {calculateIsSet(expression, 'cases')
? <> ? <>
{labelRender(expression.left)} {labelRender(expression.left)}
<span className="rb:mx-1">{getLocaleField(expression.operator, typeof expression.right)}</span> <span className="rb:mx-1">{getLocaleField(expression.operator, typeof expression.right)}</span>
@@ -109,11 +116,16 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => {
} }
</Flex> </Flex>
{expression.sub_variable_condition?.conditions?.length > 0 && expression.sub_variable_condition?.conditions.every(isSubExprSet) {expression.sub_variable_condition?.conditions?.length > 0 && expression.sub_variable_condition?.conditions.every(isSubExprSet)
? <div className="rb-border-l rb:ml-2 rb:mt-1.5"> ? <div className="rb-border-l rb:ml-2 rb:mt-1">
{expression.sub_variable_condition?.conditions.map((sub: any, sIndex: number) => ( {expression.sub_variable_condition?.conditions.map((sub: any, sIndex: number) => (
<div key={sIndex} className="rb:relative"> <div key={sIndex} className="rb:relative">
{expression.sub_variable_condition?.conditions.length > 1 && sIndex > 0 && <div className="rb:absolute rb:-top-2 rb:right-2 rb:text-[10px] rb:text-[#155EEF] rb:font-medium rb:leading-3.5 rb:text-right rb:pr-0.5">{expression.sub_variable_condition?.logical_operator?.toLocaleUpperCase()}</div>} {expression.sub_variable_condition?.conditions.length > 1 && sIndex > 0 && <div className="rb:absolute rb:-top-2 rb:right-2 rb:text-[10px] rb:text-[#155EEF] rb:font-medium rb:leading-3.5 rb:text-right rb:pr-0.5">{expression.sub_variable_condition?.logical_operator?.toLocaleUpperCase()}</div>}
<Flex align="center" className=" rb:py-1! rb:px-1.5! rb:text-[10px] rb:text-[#5B6167] rb:font-medium rb:leading-3.5"> <Flex align="center"
className={clsx("rb:px-1.5! rb:text-[10px] rb:text-[#5B6167] rb:font-medium rb:leading-3.5", {
'rb:py-1!': sIndex !== 0,
'rb:pb-1': sIndex === 0
})}
>
<span className="rb:text-[#155EEF]">{sub.key}</span> <span className="rb:text-[#155EEF]">{sub.key}</span>
<span className="rb:mx-1">{getSubLocaleField(sub.operator, sub.key)}</span> <span className="rb:mx-1">{getSubLocaleField(sub.operator, sub.key)}</span>
<span className="rb:break-all rb:line-clamp-1"> <span className="rb:break-all rb:line-clamp-1">
@@ -129,7 +141,7 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => {
))} ))}
</div> </div>
: expression.sub_variable_condition?.conditions?.length > 0 : expression.sub_variable_condition?.conditions?.length > 0
? <Flex align="center" className="rb:mt-1! rb:pl-2! rb:rounded-md rb:py-1! rb:px-1.5! rb:text-[10px] rb:text-[#5B6167] rb:font-medium rb:leading-3.5"> ? <Flex align="center" className="rb:pl-2! rb:rounded-md rb:pb-1! rb:px-1.5! rb:text-[10px] rb:text-[#5B6167] rb:font-medium rb:leading-4">
{t(`workflow.config.${data.type}.unset`)} {t(`workflow.config.${data.type}.unset`)}
</Flex> </Flex>
: null : null

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-09 18:24:53 * @Date: 2026-02-09 18:24:53
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-17 20:47:49 * @Last Modified time: 2026-04-20 10:46:05
*/ */
import { useEffect, useMemo, type FC } from 'react' import { useEffect, useMemo, type FC } from 'react'
import clsx from 'clsx' import clsx from 'clsx'
@@ -39,7 +39,7 @@ interface Expression {
sub_variable_condition?: SubVariableCondition; sub_variable_condition?: SubVariableCondition;
} }
interface CaseItem { export interface CaseItem {
logical_operator: 'and' | 'or'; logical_operator: 'and' | 'or';
expressions: Expression[]; expressions: Expression[];
} }
@@ -274,7 +274,9 @@ const ArrayFileSubConditions: FC<ArrayFileSubConditionsProps> = ({ conditionFiel
className="rb:w-full!" className="rb:w-full!"
suffix="Byte" suffix="Byte"
size="small" size="small"
onChange={(value) => { form.setFieldValue([name, caseIndex, 'expressions', conditionIndex, 'right'], value); }} onChange={(value) => {
form.setFieldValue([name, caseIndex, 'expressions', conditionIndex, 'sub_variable_condition', 'conditions', subIndex, 'value'], value);
}}
/> />
} }
</Form.Item> </Form.Item>
@@ -483,13 +485,24 @@ const CaseList: FC<CaseListProps> = ({
form.setFieldValue([name, index, 'logical_operator'], currentValue === 'and' ? 'or' : 'and'); form.setFieldValue([name, index, 'logical_operator'], currentValue === 'and' ? 'or' : 'and');
}; };
const handleLeftFieldChange = (caseIndex: number, conditionIndex: number, newValue: string) => { const handleLeftFieldChange = (caseIndex: number, conditionIndex: number, newValue: string, option?: Suggestion | undefined) => {
form.setFieldValue([name, caseIndex, 'expressions', conditionIndex], { if (option?.dataType === 'array[file]') {
left: newValue, form.setFieldValue([name, caseIndex, 'expressions', conditionIndex], {
operator: undefined, left: newValue,
right: undefined, operator: undefined,
input_type: 'constant' sub_variable_condition: {
}); conditions: [],
logical_operator: 'and'
}
});
} else {
form.setFieldValue([name, caseIndex, 'expressions', conditionIndex], {
left: newValue,
operator: undefined,
right: undefined,
input_type: 'constant'
});
}
}; };
const handleAddCase = (addCaseFunc: Function) => { const handleAddCase = (addCaseFunc: Function) => {
@@ -590,7 +603,7 @@ const CaseList: FC<CaseListProps> = ({
options={options} options={options}
size="small" size="small"
allowClear={false} allowClear={false}
onChange={(val) => handleLeftFieldChange(caseIndex, conditionIndex, val as string)} onChange={(val, option) => handleLeftFieldChange(caseIndex, conditionIndex, val as string, option as unknown as Suggestion)}
variant="borderless" variant="borderless"
className="rb:w-36!" className="rb:w-36!"
/> />

View File

@@ -147,6 +147,11 @@ const ToolConfig: FC<{ options: Suggestion[]; }> = ({
}; };
const handleChange: CascaderProps<Option>['onChange'] = (value, selectedOptions) => { const handleChange: CascaderProps<Option>['onChange'] = (value, selectedOptions) => {
if (!value) {
setParameters([])
form.resetFields()
return
}
const targetOption = selectedOptions[selectedOptions.length - 1]; const targetOption = selectedOptions[selectedOptions.length - 1];
const curParameters = [...(targetOption.parameters ?? [])] const curParameters = [...(targetOption.parameters ?? [])]
setParameters([...curParameters]) setParameters([...curParameters])

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 15:06:18 * @Date: 2026-02-03 15:06:18
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-16 17:52:30 * @Last Modified time: 2026-04-20 14:36:41
*/ */
import LoopNode from './components/Nodes/LoopNode'; import LoopNode from './components/Nodes/LoopNode';
import NormalNode from './components/Nodes/NormalNode'; import NormalNode from './components/Nodes/NormalNode';
@@ -428,7 +428,8 @@ export const nodeLibrary: NodeLibrary[] = [
{ type: "tool", icon: 'rb:bg-[url("@/assets/images/workflow/tools.svg")]', { type: "tool", icon: 'rb:bg-[url("@/assets/images/workflow/tools.svg")]',
config: { config: {
tool_id: { tool_id: {
type: 'cascader' type: 'cascader',
required: true
}, },
tool_parameters: { tool_parameters: {
type: 'define' type: 'define'
@@ -734,7 +735,7 @@ export const portTextAttrs = { fontSize: 12, fill: '#5B6167' }
/** /**
* Port position arguments * Port position arguments
*/ */
export const portItemArgsY = 26.5; export const portItemArgsY = 27.5;
export const portArgs = { x: nodeWidth, y: portItemArgsY } export const portArgs = { x: nodeWidth, y: portItemArgsY }
const defaultPortGroup = { const defaultPortGroup = {

View File

@@ -2,136 +2,70 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-03-24 15:07:49 * @Date: 2026-03-24 15:07:49
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-17 20:40:47 * @Last Modified time: 2026-04-20 14:20:34
*/ */
import { portItemArgsY, conditionNodePortItemArgsY, conditionNodeHeight } from './constant' import { conditionNodePortItemArgsY, conditionNodeHeight } from './constant'
/**
* Calculate the total height of a condition (if-else) node based on its cases.
*
* The height is composed of:
* - `conditionNodeHeight`: the base height of the node (header + padding).
* - `(cases.length - 1) * 26`: vertical spacing added for each additional case
* beyond the first (each case separator row is 26px).
* - `exprCount * 20`: each individual expression row occupies 20px.
* - `hasMultiExprCount * 3`: a small extra padding (3px per expression) is added
* for cases that contain more than one expression, to account for the logical
* operator indicator (AND/OR) between expressions.
*
* @param cases - Array of case objects, each containing an `expressions` array.
* @returns The total pixel height for the condition node.
*/
export const isSubExprSet = (sub: any) => { export const isSubExprSet = (sub: any) => {
if (!sub?.key) return false; if (!sub?.key) return false;
if (['not_empty', 'empty'].includes(sub?.operator)) return true; if (['not_empty', 'empty'].includes(sub?.operator)) return true;
return !!sub.value || typeof sub.value === 'boolean' || typeof sub.value === 'number'; return !!sub.value || typeof sub.value === 'boolean' || typeof sub.value === 'number';
}; };
/**
const getEffectiveExprCount = (expr: any): number => { * Calculate the total height of a condition (if-else) node based on its cases.
const subs = expr?.sub_variable_condition?.conditions; * Uses the same per-expression height logic as getConditionNodeCasePortY.
if (subs?.length && subs.every(isSubExprSet)) return 1 + subs.length; */
if (subs?.length > 0) { export const calcConditionNodeTotalHeight = (cases: any[]) => {
return 2 const casesHeight = cases.reduce((acc: number, c: any) => {
} const exprs = c?.expressions ?? [];
return 1; const n = exprs.length;
const exprsHeight = n === 0 ? 0 : exprs.reduce((s: number, e: any) => s + calcExpressionHeight(e), 0) + 2 * (n - 1);
return acc + 20 + exprsHeight;
}, 0);
return conditionNodeHeight + casesHeight + (cases.length - 1) * 4 - 27.5;
}; };
export const calcConditionNodeTotalHeight = (cases: any[]) => { /**
// Total number of effective expression rows (sub_variable_condition expand height when all set) * Height of a single expression block in ConditionNode (px).
const exprCount = cases.reduce((acc: number, c: any) => *
acc + (c?.expressions?.reduce((s: number, e: any) => s + getEffectiveExprCount(e), 0) || 0), 0); * expression outer Flex padding:
// Sum of effective expression counts only for cases that have more than one expression * - has sub conditions (length > 0): pt-1 (4px top only)
const hasMultiExprCount = cases.reduce((acc: number, c: any) => { * - no sub conditions: py-1 (4px top + 4px bottom)
if (!c?.expressions?.length || c.expressions.length <= 1) return acc; * expression main row: leading-4 = 16px
const effectiveCount = c.expressions.reduce((s: number, e: any) => s + getEffectiveExprCount(e), 0); * sub_variable_condition block (mt-1 = 4px gap):
return acc + effectiveCount; * - all isSet, m subs: sub[0] = leading-3.5(14) + pb-1(4) = 18px;
}, 0); * sub[k>0] = py-1(8) + leading-3.5(14) = 22px
* total = 18 + 22*(m-1)
return conditionNodeHeight + (cases.length - 1) * 26 + exprCount * 20 + hasMultiExprCount * 3; * - exists but not all isSet: pb-1(4) + leading-4(16) = 20px
*/
const calcExpressionHeight = (expression: any): number => {
const subs = expression?.sub_variable_condition?.conditions;
if (!subs?.length) return 24; // py-1(8) + leading-4(16)
const subBlockHeight = subs.every(isSubExprSet)
? 18 + 22 * (subs.length - 1)
: 20;
return 4 + 16 + 4 + subBlockHeight - 2; // pt-1 + main row + mt-1 + sub block (-2 rendering correction)
}; };
/** /**
* Calculate the Y-coordinate of the right-side output port for a specific case * Calculate the Y-coordinate of the right-side output port for a specific case
* in a condition (if-else) node. * in a condition (if-else) node, aligned with the IF/ELIF label in ConditionNode.
* *
* The port position is determined by iterating through all preceding cases * Layout (from node top):
* (index 0 to caseIndex - 1) and accumulating their visual heights. Several * - 12px padding-top + 24px header + 12px mt-3 = 48px to cases area
* pixel-level corrections are applied to align ports with the rendered UI: * - Each IF/ELIF label row: leading-4 (16px), center at +8px → first port Y = 56.5
* * - Each case: IF/ELIF row (leading-4=16) + mb-1(4) + expressions (gap={2}=2px between)
* 1. **Base offset**: starts at `conditionNodePortItemArgsY`, which is the Y * - Gap between cases (Flex gap={4}): 4px
* position of the first case port relative to the node top.
*
* 2. **Per-case accumulation**: for each preceding case with `n` expressions,
* add `portItemArgsY * (n + 1)` — this accounts for `n` expression rows
* plus one case header/separator row.
*
* 3. **Single-expression correction**: cases with exactly 1 expression render
* slightly shorter than the generic formula predicts. Subtract
* `singleExprCount * 7 + 2` to compensate for the reduced row height when
* no logical operator row is shown.
*
* 4. **Multi-expression correction**: cases with 2+ expressions have a compact
* logical operator row. Subtract `multiExprCount * 9` to offset the
* over-estimated spacing.
*
* 5. **Extra expression correction**: for cases with more than 2 expressions,
* each additional expression beyond the second introduces a minor spacing
* discrepancy. Subtract `(extraExprs + 1) * 2` to fine-tune alignment.
*
* @param cases - Array of case objects, each containing an `expressions` array.
* @param caseIndex - The zero-based index of the target case whose port Y is needed.
* @returns The Y-coordinate (in pixels) for the output port of the given case.
*/ */
export const getConditionNodeCasePortY = (cases: any[], caseIndex: number) => { export const getConditionNodeCasePortY = (cases: any[], caseIndex: number) => {
let y = conditionNodePortItemArgsY; let y = conditionNodePortItemArgsY; // 56.5, center of first IF label
let singleExprCount = 0;
let multiExprCount = 0;
let extraExprs = 0;
let portItemArgsYNum = 0;
for (let i = 0; i < caseIndex; i++) { for (let i = 0; i < caseIndex; i++) {
const notHasSub = cases[i]?.expressions?.filter((e: any) => !e?.sub_variable_condition?.conditions || e?.sub_variable_condition?.conditions.length <1).length const exprs = cases[i]?.expressions ?? [];
const n = cases[i]?.expressions?.length || 0; const n = exprs.length;
let casePortItemArgsYNum = n + 1; // IF/ELIF row (16) + mb-1 (4) = 20px base; expressions: sum of heights + 2px gap between
// Add extra y for expressions with all sub_variable_condition set const exprsHeight = n === 0 ? 0 : exprs.reduce((acc: number, e: any) => acc + calcExpressionHeight(e), 0) + 2 * (n - 1);
cases[i]?.expressions?.forEach((expr: any) => { y += 20 + exprsHeight + 4; // case height + Flex gap between cases
const subs = expr?.sub_variable_condition?.conditions;
if (subs?.length && subs.every(isSubExprSet)) {
casePortItemArgsYNum += subs.length;
} else if (subs?.length) {
casePortItemArgsYNum += 1
}
});
portItemArgsYNum += casePortItemArgsYNum;
if (n === 1 && !cases[i]?.expressions?.some((e: any) => e?.sub_variable_condition?.conditions?.length > 0)) {
singleExprCount++
} else if (n >= 2 || cases[i]?.expressions?.some((e: any) => e?.sub_variable_condition?.conditions?.length > 0)) {
multiExprCount++;
cases[i]?.expressions?.forEach((e: any) => {
const subs = e?.sub_variable_condition?.conditions;
if (subs?.length && subs.every(isSubExprSet) && subs.length > 1) {
extraExprs += subs.length + 2;
}
});
console.log('extraExprs notHasSub', notHasSub)
if (notHasSub > 3) {
extraExprs += n - 2 + notHasSub/4;
} else {
extraExprs += n - 2 + notHasSub/4
}
}
} }
console.log('singleExprCount', singleExprCount, 'multiExprCount', multiExprCount, 'extraExprs', extraExprs)
y += portItemArgsY * portItemArgsYNum
// Correction for single-expression cases (slightly shorter rendered height)
if (singleExprCount > 0) y -= singleExprCount * 7 + 2;
// Correction for multi-expression cases (compact logical operator row)
y -= multiExprCount * 9;
// Correction for cases with more than 2 expressions (minor spacing drift)
if (extraExprs > 0) y -= (extraExprs + 1) * 2;
return y; return y;
}; };