Merge branch 'release/v0.3.1' into develop

This commit is contained in:
Ke Sun
2026-04-23 12:16:57 +08:00
54 changed files with 1240 additions and 707 deletions

View File

@@ -2,6 +2,8 @@
Celery Worker 入口点
用于启动 Celery Worker: celery -A app.celery_worker worker --loglevel=info
"""
from celery.signals import worker_process_init
from app.celery_app import celery_app
from app.core.logging_config import LoggingConfig, get_logger
@@ -13,4 +15,39 @@ logger.info("Celery worker logging initialized")
# 导入任务模块以注册任务
import app.tasks
@worker_process_init.connect
def _reinit_db_pool(**kwargs):
    """Rebuild fork-tainted resources when a prefork worker child starts.

    After fork() the child process inherits from its parent:
      1. The SQLAlchemy connection pool - TCP sockets shared across
         processes lead to corrupted DB connections.
      2. Module-level ThreadPoolExecutor objects - thread state is undefined
         after fork, and the second task would deadlock.

    Recreating the executors is best-effort: a failure is logged as a warning
    and does not abort worker start-up.

    Args:
        **kwargs: Signal payload passed by Celery; not used.
    """
    import os
    from concurrent.futures import ThreadPoolExecutor

    # Reset the DB connection pool so this child opens its own connections.
    from app.db import engine
    try:
        # close=False discards the inherited pool without closing the
        # connections themselves, whose sockets the parent still shares.
        engine.dispose(close=False)
    except TypeError:
        # Older SQLAlchemy releases have no ``close`` parameter.
        engine.dispose()
    logger.info("DB connection pool disposed for forked worker process")

    # Recreate module-level ThreadPoolExecutors (unusable after fork).
    try:
        from app.core.rag.deepdoc.parser import figure_parser
        figure_parser.shared_executor = ThreadPoolExecutor(max_workers=10)
        logger.info("figure_parser.shared_executor recreated")
    except Exception as e:
        logger.warning(f"Failed to recreate figure_parser.shared_executor: {e}")

    try:
        from app.core.rag.utils import libre_office
        cpu_total = os.cpu_count()  # may be None when undeterminable
        max_workers = cpu_total * 2 if cpu_total else 4
        libre_office.executor = ThreadPoolExecutor(max_workers=max_workers)
        logger.info("libre_office.executor recreated")
    except Exception as e:
        logger.warning(f"Failed to recreate libre_office.executor: {e}")
__all__ = ['celery_app']

View File

@@ -60,7 +60,7 @@ def _build_default_free_plan():
"app_quota": 2,
"knowledge_capacity_quota": 0.3,
"memory_engine_quota": 1,
"end_user_quota": 1,
"end_user_quota": 10,
"ontology_project_quota": 3,
"model_quota": 1,
"api_ops_rate_limit": 50,

View File

@@ -167,6 +167,8 @@ def update_api_key(
return success(data=api_key_schema.ApiKey.model_validate(api_key), msg="API Key 更新成功")
except BusinessException:
raise
except Exception as e:
logger.error(f"未知错误: {str(e)}", extra={
"api_key_id": str(api_key_id),

View File

@@ -219,6 +219,7 @@ def delete_app(
@router.post("/{app_id}/copy", summary="复制应用")
@cur_workspace_access_guard()
@check_app_quota
def copy_app(
app_id: uuid.UUID,
new_name: Optional[str] = None,
@@ -1144,6 +1145,7 @@ async def import_workflow_config(
@router.post("/workflow/import/save")
@cur_workspace_access_guard()
@check_app_quota
async def save_workflow_import(
data: WorkflowImportSave,
db: Session = Depends(get_db),
@@ -1281,6 +1283,10 @@ async def import_app(
return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST)
target_app_id = uuid.UUID(app_id) if app_id else None
# 仅新建应用时检查配额,覆盖已有应用时跳过
if target_app_id is None:
from app.core.quota_manager import _check_quota
_check_quota(db, current_user.tenant_id, "app_quota", "app", workspace_id=current_user.current_workspace_id)
result_app, warnings = AppDslService(db).import_dsl(
dsl=dsl,
workspace_id=current_user.current_workspace_id,

View File

@@ -457,7 +457,7 @@ async def retrieve_chunks(
if doc.metadata["doc_id"] not in seen_ids:
seen_ids.add(doc.metadata["doc_id"])
unique_rs.append(doc)
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k)
rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k) if unique_rs else []
if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph:
kb_ids = [str(kb_id) for kb_id in private_kb_ids]
workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids]

View File

@@ -237,7 +237,6 @@ def delete_model_base(
@router.post("/model_plaza/{model_base_id}/add", response_model=ApiResponse)
@check_model_quota
def add_model_from_plaza(
model_base_id: uuid.UUID,
db: Session = Depends(get_db),
@@ -275,7 +274,6 @@ def get_model_by_id(
@router.post("", response_model=ApiResponse)
@check_model_quota
async def create_model(
model_data: model_schema.ModelConfigCreate,
db: Session = Depends(get_db),

View File

@@ -219,9 +219,20 @@ def list_conversations(
end_user_repo = EndUserRepository(db)
app_service = AppService(db)
app = app_service._get_app_or_404(share.app_id)
workspace_id = app.workspace_id
# 仅在新建终端用户时检查配额
existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id)
if existing_end_user is None:
from app.core.quota_manager import _check_quota
from app.models.workspace_model import Workspace
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
if ws:
_check_quota(db, ws.tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=share.app_id,
workspace_id=app.workspace_id,
workspace_id=workspace_id,
other_id=other_id
)
logger.debug(new_end_user.id)
@@ -309,7 +320,6 @@ def get_conversation(
"/chat",
summary="发送消息(支持流式和非流式)"
)
@check_end_user_quota
async def chat(
payload: conversation_schema.ChatRequest,
share_data: ShareTokenData = Depends(get_share_user_id),
@@ -350,6 +360,18 @@ async def chat(
app_service = AppService(db)
app = app_service._get_app_or_404(share.app_id)
workspace_id = app.workspace_id
# 仅在新建终端用户时检查配额,已有用户复用不受限制
existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id)
logger.info(f"终端用户配额检查: workspace_id={workspace_id}, other_id={other_id}, existing={existing_end_user is not None}")
if existing_end_user is None:
from app.core.quota_manager import _check_quota
from app.models.workspace_model import Workspace
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
if ws:
logger.info(f"新终端用户,执行配额检查: tenant_id={ws.tenant_id}")
_check_quota(db, ws.tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=share.app_id,
workspace_id=workspace_id,

View File

@@ -106,6 +106,16 @@ async def chat(
other_id = payload.user_id
workspace_id = api_key_auth.workspace_id
end_user_repo = EndUserRepository(db)
# 仅在新建终端用户时检查配额,已有用户复用不受限制
existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id)
if existing_end_user is None:
from app.core.quota_manager import _check_quota
from app.models.workspace_model import Workspace
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
if ws:
_check_quota(db, ws.tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=app.id,
workspace_id=workspace_id,

View File

@@ -47,7 +47,7 @@ async def create_end_user(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
message: str = Body(..., description="Request body"),
message: str = Body(None, description="Request body"),
):
"""
Create or retrieve an end user for the workspace.

View File

@@ -96,40 +96,8 @@ def require_api_key(
resource_id=api_key_obj.resource_id,
)
# ── Tenant 级别限速(来自套餐配额 api_ops_rate_limit──────────
try:
from app.models.workspace_model import Workspace
from premium.platform_admin.package_plan_service import TenantSubscriptionService
workspace = db.query(Workspace).filter(
Workspace.id == api_key_obj.workspace_id
).first()
if workspace:
quota = TenantSubscriptionService(db).get_effective_quota(workspace.tenant_id)
tenant_qps_limit = quota.get("api_ops_rate_limit") if quota else None
if tenant_qps_limit:
rate_limiter = RateLimiterService()
tenant_ok, tenant_info = await rate_limiter.check_tenant_rate_limit(
workspace.tenant_id, tenant_qps_limit
)
if not tenant_ok:
raise RateLimitException(
"租户 API 调用速率超限",
BizCode.API_KEY_QPS_LIMIT_EXCEEDED,
rate_headers={
"X-RateLimit-Tenant-Limit": str(tenant_info["limit"]),
"X-RateLimit-Tenant-Remaining": str(tenant_info["remaining"]),
"X-RateLimit-Tenant-Reset": str(tenant_info["reset"]),
}
)
except RateLimitException:
raise
except Exception as e:
logger.warning(f"Tenant 限速检查异常,跳过: {e}")
# ─────────────────────────────────────────────────────────────
rate_limiter = RateLimiterService()
is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj)
is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj, db=db)
if not is_allowed:
logger.warning("API Key 限流触发", extra={
"api_key_id": str(api_key_obj.id),
@@ -138,10 +106,12 @@ def require_api_key(
"error_msg": error_msg
})
# 根据错误消息判断限流类型
if "QPS" in error_msg:
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED
elif "Daily" in error_msg:
if "Daily" in error_msg:
code = BizCode.API_KEY_DAILY_LIMIT_EXCEEDED
elif "Tenant" in error_msg:
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED # 租户套餐速率超限,同属 QPS 类
elif "QPS" in error_msg:
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED
else:
code = BizCode.API_KEY_QUOTA_EXCEEDED

View File

@@ -31,6 +31,9 @@ class BizCode(IntEnum):
API_KEY_QPS_LIMIT_EXCEEDED = 3014
API_KEY_DAILY_LIMIT_EXCEEDED = 3015
API_KEY_QUOTA_EXCEEDED = 3016
API_KEY_RATE_LIMIT_EXCEEDED = 3017
QUOTA_EXCEEDED = 3018
RATE_LIMIT_EXCEEDED = 3019
# 资源4xxx
NOT_FOUND = 4000
USER_NOT_FOUND = 4001
@@ -155,7 +158,8 @@ HTTP_MAPPING = {
BizCode.API_KEY_QPS_LIMIT_EXCEEDED: 429,
BizCode.API_KEY_DAILY_LIMIT_EXCEEDED: 429,
BizCode.API_KEY_QUOTA_EXCEEDED: 429,
BizCode.QUOTA_EXCEEDED: 402,
BizCode.MODEL_CONFIG_INVALID: 400,
BizCode.API_KEY_MISSING: 400,
BizCode.PROVIDER_NOT_SUPPORTED: 400,
@@ -184,4 +188,21 @@ HTTP_MAPPING = {
BizCode.DB_ERROR: 500,
BizCode.SERVICE_UNAVAILABLE: 503,
BizCode.RATE_LIMITED: 429,
BizCode.RATE_LIMIT_EXCEEDED: 429,
}
# Lookup table from string error-code names to the corresponding BizCode member.
ERROR_CODE_TO_BIZ_CODE = {
"QUOTA_EXCEEDED": BizCode.QUOTA_EXCEEDED,
"RATE_LIMIT_EXCEEDED": BizCode.RATE_LIMIT_EXCEEDED,
"API_KEY_NOT_FOUND": BizCode.API_KEY_NOT_FOUND,
"API_KEY_INVALID": BizCode.API_KEY_INVALID,
"API_KEY_EXPIRED": BizCode.API_KEY_EXPIRED,
"WORKSPACE_NOT_FOUND": BizCode.WORKSPACE_NOT_FOUND,
"WORKSPACE_NO_ACCESS": BizCode.WORKSPACE_NO_ACCESS,
"PERMISSION_DENIED": BizCode.PERMISSION_DENIED,
"TOKEN_EXPIRED": BizCode.TOKEN_EXPIRED,
"TOKEN_INVALID": BizCode.TOKEN_INVALID,
"VALIDATION_FAILED": BizCode.VALIDATION_FAILED,
"INVALID_PARAMETER": BizCode.INVALID_PARAMETER,
"MISSING_PARAMETER": BizCode.MISSING_PARAMETER,
}

View File

@@ -6,7 +6,6 @@
2. 降级到 default_free_plan.py 配置文件(社区版兜底)
"""
import asyncio
import time
from functools import wraps
from typing import Optional, Callable, Dict, Any
from uuid import UUID
@@ -15,10 +14,13 @@ from sqlalchemy import func
from sqlalchemy.orm import Session
from app.core.logging_config import get_auth_logger
from app.i18n.exceptions import QuotaExceededError
from app.i18n.exceptions import QuotaExceededError, InternalServerError
logger = get_auth_logger()
# Redis key 格式常量,与 RateLimiterService.check_qps 保持一致(per api_key 独立计数)
API_KEY_QPS_REDIS_KEY = "rate_limit:qps:{api_key_id}"
def _get_user_from_kwargs(kwargs: dict):
"""从 kwargs 中获取 user 对象"""
@@ -28,6 +30,29 @@ def _get_user_from_kwargs(kwargs: dict):
return None
def _get_workspace_id_from_kwargs(kwargs: dict):
"""从 kwargs 中获取 workspace_id"""
# 优先从 kwargs['workspace_id'] 获取
workspace_id = kwargs.get("workspace_id")
if workspace_id:
return workspace_id
# 从 api_key_auth.workspace_id 获取API Key 认证场景)
api_key_auth = kwargs.get("api_key_auth")
if api_key_auth and hasattr(api_key_auth, 'workspace_id'):
return api_key_auth.workspace_id
# 从 user.current_workspace_id 获取
user = _get_user_from_kwargs(kwargs)
if user:
ws_id = getattr(user, 'current_workspace_id', None)
if ws_id:
return ws_id
logger.warning(f"无法获取 workspace_id, kwargs keys: {list(kwargs.keys())}")
return None
def _get_tenant_id_from_kwargs(db: Session, kwargs: dict):
"""从 kwargs 中获取 tenant_id"""
user = _get_user_from_kwargs(kwargs)
@@ -65,7 +90,9 @@ def _get_tenant_id_from_kwargs(db: Session, kwargs: dict):
if share_record:
app = db.query(App).filter(App.id == share_record.app_id, App.is_active.is_(True)).first()
if app:
return app.workspace.tenant_id
workspace = db.query(Workspace).filter(Workspace.id == app.workspace_id).first()
if workspace:
return workspace.tenant_id
return None
@@ -73,31 +100,52 @@ def _get_tenant_id_from_kwargs(db: Session, kwargs: dict):
def _get_quota_config(db: Session, tenant_id: UUID) -> Optional[Dict[str, Any]]:
"""
获取租户的配额配置
优先级:
1. premium 模块的 tenant_subscriptionsSaaS 版)
2. default_free_plan.py 配置文件(社区版兜底)
"""
# 尝试从 premium 模块获取
# 尝试从 premium 模块获取SaaS 版)
try:
from premium.platform_admin.package_plan_service import TenantSubscriptionService
# premium 模块存在,运行时错误不应被静默降级,直接抛出
quota_config = TenantSubscriptionService(db).get_effective_quota(tenant_id)
if quota_config:
logger.debug(f"从 premium 模块获取租户 {tenant_id} 配额配置")
return quota_config
except (ModuleNotFoundError, ImportError, Exception) as e:
logger.debug(f"无法从 premium 模块获取配额配置: {e}")
# premium 存在但该租户无订阅记录,降级到免费套餐
logger.debug(f"租户 {tenant_id} 无 premium 订阅,降级到免费套餐")
except (ModuleNotFoundError, ImportError):
# 社区版premium 包不存在,正常降级
logger.debug("premium 模块不存在,使用社区版免费套餐配额")
# 降级到配置文件
# 降级到社区版配置文件
try:
from app.config.default_free_plan import DEFAULT_FREE_PLAN
logger.info(f"使用配置文件中的免费套餐配额: tenant={tenant_id}")
logger.debug(f"使用社区版免费套餐配额: tenant={tenant_id}")
return DEFAULT_FREE_PLAN.get("quotas")
except Exception as e:
logger.error(f"无法从配置文件获取配额: {e}")
return None
def get_api_ops_rate_limit(db: Session, tenant_id: UUID) -> Optional[int]:
    """Return the API operation rate limit (QPS ceiling) of a tenant's plan.

    The value comes from the quota configuration resolved by
    ``_get_quota_config``, so it works for both editions:
      - SaaS edition: read from the premium module's plan quota.
      - Community edition: read from the default_free_plan.py config file.

    Args:
        db: Database session forwarded to ``_get_quota_config``.
        tenant_id: Tenant whose plan is looked up.

    Returns:
        The ``api_ops_rate_limit`` entry of the quota configuration, or None
        when no configuration is available or the entry is not set.
    """
    plan_quotas = _get_quota_config(db, tenant_id)
    return plan_quotas.get("api_ops_rate_limit") if plan_quotas else None
class QuotaUsageRepository:
"""配额使用量数据访问层"""
@@ -111,15 +159,19 @@ class QuotaUsageRepository:
Workspace.is_active.is_(True)
).count()
def count_apps(self, tenant_id: UUID) -> int:
def count_apps(self, tenant_id: UUID, workspace_id: Optional[UUID] = None) -> int:
from app.models.app_model import App
from app.models.workspace_model import Workspace
return self.db.query(App).join(
query = self.db.query(App).join(
Workspace, App.workspace_id == Workspace.id
).filter(
Workspace.tenant_id == tenant_id,
App.is_active.is_(True)
).count()
)
if workspace_id:
query = query.filter(App.workspace_id == workspace_id)
else:
query = query.filter(Workspace.tenant_id == tenant_id)
return query.count()
def count_skills(self, tenant_id: UUID) -> int:
from app.models.skill_model import Skill
@@ -128,55 +180,76 @@ class QuotaUsageRepository:
Skill.is_active.is_(True)
).count()
def sum_knowledge_capacity_gb(self, tenant_id: UUID) -> float:
def sum_knowledge_capacity_gb(self, tenant_id: UUID, workspace_id: Optional[UUID] = None) -> float:
from app.models.document_model import Document
from app.models.knowledge_model import Knowledge
from app.models.workspace_model import Workspace
result = self.db.query(func.coalesce(func.sum(Document.file_size), 0)).join(
query = self.db.query(func.coalesce(func.sum(Document.file_size), 0)).join(
Knowledge, Document.kb_id == Knowledge.id
).join(
Workspace, Knowledge.workspace_id == Workspace.id
).filter(
Workspace.tenant_id == tenant_id,
Document.status == 1,
).scalar()
)
if workspace_id:
query = query.filter(Knowledge.workspace_id == workspace_id)
else:
query = query.filter(Workspace.tenant_id == tenant_id)
result = query.scalar()
return float(result) / (1024 ** 3) if result else 0.0
def count_memory_engines(self, tenant_id: UUID) -> int:
def count_memory_engines(self, tenant_id: UUID, workspace_id: Optional[UUID] = None) -> int:
from app.models.memory_config_model import MemoryConfig
from app.models.workspace_model import Workspace
return self.db.query(MemoryConfig).join(
query = self.db.query(MemoryConfig).join(
Workspace, MemoryConfig.workspace_id == Workspace.id
).filter(
Workspace.tenant_id == tenant_id
).count()
)
if workspace_id:
query = query.filter(MemoryConfig.workspace_id == workspace_id)
else:
query = query.filter(Workspace.tenant_id == tenant_id)
return query.count()
def count_end_users(self, tenant_id: UUID) -> int:
def count_end_users(self, tenant_id: UUID, workspace_id: Optional[UUID] = None) -> int:
from app.models.end_user_model import EndUser
from app.models.workspace_model import Workspace
return self.db.query(EndUser).join(
from app.models.user_model import User
query = self.db.query(EndUser).join(
Workspace, EndUser.workspace_id == Workspace.id
).filter(
Workspace.tenant_id == tenant_id
).count()
)
if workspace_id:
query = query.filter(EndUser.workspace_id == workspace_id)
else:
query = query.filter(Workspace.tenant_id == tenant_id)
trial_user_ids = [
str(u.id) for u in self.db.query(User.id).filter(User.tenant_id == tenant_id).all()
]
if trial_user_ids:
query = query.filter(~EndUser.other_id.in_(trial_user_ids))
return query.count()
def count_models(self, tenant_id: UUID) -> int:
from app.models.models_model import ModelConfig
return self.db.query(ModelConfig).filter(
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_active == True
ModelConfig.is_active == True,
ModelConfig.is_composite == True
).count()
def count_ontology_projects(self, tenant_id: UUID) -> int:
def count_ontology_projects(self, tenant_id: UUID, workspace_id: Optional[UUID] = None) -> int:
from app.models.ontology_scene import OntologyScene
from app.models.workspace_model import Workspace
if workspace_id:
return self.db.query(OntologyScene).filter(
OntologyScene.workspace_id == workspace_id
).count()
return self.db.query(OntologyScene).join(
Workspace, OntologyScene.workspace_id == Workspace.id
).filter(
Workspace.tenant_id == tenant_id
).count()
def get_usage_by_quota_type(self, tenant_id: UUID, quota_type: str):
def get_usage_by_quota_type(self, tenant_id: UUID, quota_type: str, workspace_id: Optional[UUID] = None):
"""按配额类型分发,返回当前使用量"""
dispatch = {
"workspace_quota": self.count_workspaces,
@@ -189,6 +262,8 @@ class QuotaUsageRepository:
"ontology_project_quota": self.count_ontology_projects,
}
fn = dispatch.get(quota_type)
if workspace_id:
return fn(tenant_id, workspace_id) if fn else 0
return fn(tenant_id) if fn else 0
@@ -198,6 +273,7 @@ def _check_quota(
quota_type: str,
resource_name: str,
usage_func: Optional[Callable] = None,
workspace_id: Optional[UUID] = None,
) -> None:
"""核心配额检查逻辑:对比使用量和配额限制"""
try:
@@ -212,13 +288,13 @@ def _check_quota(
return
if usage_func:
current_usage = usage_func(db, tenant_id)
current_usage = usage_func(db, tenant_id, workspace_id) if workspace_id else usage_func(db, tenant_id)
else:
current_usage = QuotaUsageRepository(db).get_usage_by_quota_type(tenant_id, quota_type)
current_usage = QuotaUsageRepository(db).get_usage_by_quota_type(tenant_id, quota_type, workspace_id)
if current_usage >= quota_limit:
logger.warning(
f"配额不足: tenant={tenant_id}, type={quota_type}, "
f"配额不足: tenant={tenant_id}, workspace={workspace_id}, type={quota_type}, "
f"usage={current_usage}, limit={quota_limit}"
)
raise QuotaExceededError(
@@ -228,7 +304,7 @@ def _check_quota(
)
logger.debug(
f"配额检查通过: tenant={tenant_id}, type={quota_type}, "
f"配额检查通过: tenant={tenant_id}, workspace={workspace_id}, type={quota_type}, "
f"usage={current_usage}, limit={quota_limit}"
)
@@ -236,7 +312,7 @@ def _check_quota(
raise
except Exception as e:
logger.error(
f"配额检查异常: tenant={tenant_id}, type={quota_type}, "
f"配额检查异常: tenant={tenant_id}, workspace={workspace_id}, type={quota_type}, "
f"error_type={type(e).__name__}, error={str(e)}",
exc_info=True,
)
@@ -247,41 +323,82 @@ def _check_quota(
def check_workspace_quota(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs):
async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数")
return func(*args, **kwargs)
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "workspace_quota", "workspace")
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "workspace_quota", "workspace")
return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
def check_skill_quota(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs):
async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数")
return func(*args, **kwargs)
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "skill_quota", "skill")
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "skill_quota", "skill")
return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
def check_app_quota(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs):
async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数")
return func(*args, **kwargs)
_check_quota(db, user.tenant_id, "app_quota", "app")
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
workspace_id = _get_workspace_id_from_kwargs(kwargs)
if not workspace_id:
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "app_quota", "app", workspace_id=workspace_id)
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
workspace_id = _get_workspace_id_from_kwargs(kwargs)
if not workspace_id:
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "app_quota", "app", workspace_id=workspace_id)
return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
def check_knowledge_capacity_quota(func: Callable) -> Callable:
@@ -289,13 +406,17 @@ def check_knowledge_capacity_quota(func: Callable) -> Callable:
async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
if not db:
logger.warning("配额检查失败:缺少 db 参数")
return await func(*args, **kwargs)
logger.error(f"配额检查失败:{func.__name__} 缺少 db 参数,拒绝请求")
raise InternalServerError()
tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
if not tenant_id:
logger.warning("配额检查失败:无法获取 tenant_id")
return await func(*args, **kwargs)
_check_quota(db, tenant_id, "knowledge_capacity_quota", "knowledge_capacity")
logger.error(f"配额检查失败:{func.__name__} 无法获取 tenant_id,拒绝请求")
raise InternalServerError()
workspace_id = _get_workspace_id_from_kwargs(kwargs)
if not workspace_id:
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id拒绝请求")
raise InternalServerError()
_check_quota(db, tenant_id, "knowledge_capacity_quota", "knowledge_capacity", workspace_id=workspace_id)
return await func(*args, **kwargs)
@wraps(func)
@@ -303,9 +424,13 @@ def check_knowledge_capacity_quota(func: Callable) -> Callable:
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数")
return func(*args, **kwargs)
_check_quota(db, user.tenant_id, "knowledge_capacity_quota", "knowledge_capacity")
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
workspace_id = _get_workspace_id_from_kwargs(kwargs)
if not workspace_id:
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "knowledge_capacity_quota", "knowledge_capacity", workspace_id=workspace_id)
return func(*args, **kwargs)
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
@@ -313,15 +438,36 @@ def check_knowledge_capacity_quota(func: Callable) -> Callable:
def check_memory_engine_quota(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs):
async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
logger.debug(f"check_memory_engine_quota async_wrapper: db={db is not None}, user={user}, kwargs_keys={list(kwargs.keys())}")
if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数")
return func(*args, **kwargs)
_check_quota(db, user.tenant_id, "memory_engine_quota", "memory_engine")
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
workspace_id = _get_workspace_id_from_kwargs(kwargs)
if not workspace_id:
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "memory_engine_quota", "memory_engine", workspace_id=workspace_id)
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
logger.debug(f"check_memory_engine_quota sync_wrapper: db={db is not None}, user={user}, kwargs_keys={list(kwargs.keys())}")
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
workspace_id = _get_workspace_id_from_kwargs(kwargs)
if not workspace_id:
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "memory_engine_quota", "memory_engine", workspace_id=workspace_id)
return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
def check_end_user_quota(func: Callable) -> Callable:
@@ -329,26 +475,34 @@ def check_end_user_quota(func: Callable) -> Callable:
async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
if not db:
logger.warning("配额检查失败:缺少 db 参数")
return await func(*args, **kwargs)
logger.error(f"配额检查失败:{func.__name__} 缺少 db 参数,拒绝请求")
raise InternalServerError()
tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
if not tenant_id:
logger.warning("配额检查失败:无法获取 tenant_id")
return await func(*args, **kwargs)
_check_quota(db, tenant_id, "end_user_quota", "end_user")
logger.error(f"配额检查失败:{func.__name__} 无法获取 tenant_id,拒绝请求")
raise InternalServerError()
workspace_id = _get_workspace_id_from_kwargs(kwargs)
if not workspace_id:
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id拒绝请求")
raise InternalServerError()
_check_quota(db, tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
if not db:
logger.warning("配额检查失败:缺少 db 参数")
return func(*args, **kwargs)
logger.error(f"配额检查失败:{func.__name__} 缺少 db 参数,拒绝请求")
raise InternalServerError()
tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
if not tenant_id:
logger.warning("配额检查失败:无法获取 tenant_id")
return func(*args, **kwargs)
_check_quota(db, tenant_id, "end_user_quota", "end_user")
logger.error(f"配额检查失败:{func.__name__} 无法获取 tenant_id,拒绝请求")
raise InternalServerError()
workspace_id = _get_workspace_id_from_kwargs(kwargs)
if not workspace_id:
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id拒绝请求")
raise InternalServerError()
_check_quota(db, tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id)
return func(*args, **kwargs)
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
@@ -356,89 +510,171 @@ def check_end_user_quota(func: Callable) -> Callable:
def check_ontology_project_quota(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs):
async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数")
return func(*args, **kwargs)
_check_quota(db, user.tenant_id, "ontology_project_quota", "ontology_project")
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
workspace_id = _get_workspace_id_from_kwargs(kwargs)
if not workspace_id:
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "ontology_project_quota", "ontology_project", workspace_id=workspace_id)
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
workspace_id = _get_workspace_id_from_kwargs(kwargs)
if not workspace_id:
logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "ontology_project_quota", "ontology_project", workspace_id=workspace_id)
return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
def check_model_quota(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs):
async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数")
return func(*args, **kwargs)
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "model_quota", "model")
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "model_quota", "model")
return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
def check_model_activation_quota(func: Callable) -> Callable:
"""模型激活时的配额检查装饰器"""
@wraps(func)
def wrapper(*args, **kwargs):
async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数")
return func(*args, **kwargs)
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
model_id = kwargs.get("model_id") or (args[1] if len(args) > 1 else None)
model_data = kwargs.get("model_data")
if not model_id or not model_data:
logger.warning("模型激活配额检查失败:缺少 model_id 或 model_data 参数")
return func(*args, **kwargs)
if model_data.is_active is True:
return await func(*args, **kwargs)
if model_data.is_active:
try:
from app.models.models_model import ModelConfig
from app.services.model_service import ModelConfigService
existing_model = ModelConfigService.get_model_by_id(
db=db,
model_id=model_id,
db=db,
model_id=model_id,
tenant_id=user.tenant_id
)
if not existing_model.is_active:
logger.info(f"模型激活操作,检查配额: model_id={model_id}, tenant_id={user.tenant_id}")
_check_quota(db, user.tenant_id, "model_quota", "model")
except Exception as e:
logger.error(f"模型激活配额检查异常: model_id={model_id}, error={str(e)}")
raise
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
model_id = kwargs.get("model_id") or (args[1] if len(args) > 1 else None)
model_data = kwargs.get("model_data")
if not model_id or not model_data:
logger.warning("模型激活配额检查失败:缺少 model_id 或 model_data 参数")
return func(*args, **kwargs)
if model_data.is_active:
try:
from app.services.model_service import ModelConfigService
existing_model = ModelConfigService.get_model_by_id(
db=db,
model_id=model_id,
tenant_id=user.tenant_id
)
if not existing_model.is_active:
logger.info(f"模型激活操作,检查配额: model_id={model_id}, tenant_id={user.tenant_id}")
_check_quota(db, user.tenant_id, "model_quota", "model")
except Exception as e:
logger.error(f"模型激活配额检查异常: model_id={model_id}, error={str(e)}")
raise
return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
def check_quota(quota_type: str, resource_name: str, usage_func: Optional[Callable] = None):
"""通用配额检查装饰器,支持自定义使用量获取函数"""
def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs):
async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数")
return func(*args, **kwargs)
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, quota_type, resource_name, usage_func)
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, quota_type, resource_name, usage_func)
return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
return decorator
# ─── 配额使用统计 ────────────────────────────────────────────────────────────
def get_quota_usage(db: Session, tenant_id: UUID) -> dict:
"""获取租户所有配额的使用情况"""
async def get_quota_usage(db: Session, tenant_id: UUID) -> dict:
"""获取租户所有配额的使用情况
对于 workspace 级别的配额app/knowledge_capacity/memory_engine/end_user
- used: 租户汇总(所有空间加总)
- limit: quota × 活跃工作区数(有效总限额,使汇总数据自洽)
- per_workspace: 各空间明细,包含 workspace_id、workspace_name、used、limit、percentage
- 配额检查逻辑不变:仍按单个空间独立检查
"""
quota_config = _get_quota_config(db, tenant_id)
if not quota_config:
return {}
@@ -457,29 +693,99 @@ def get_quota_usage(db: Session, tenant_id: UUID) -> dict:
model_count = repo.count_models(tenant_id)
ontology_count = repo.count_ontology_projects(tenant_id)
# 获取租户下所有活跃工作区,用于按空间拆分明细
from app.models.workspace_model import Workspace
active_workspaces = db.query(Workspace).filter(
Workspace.tenant_id == tenant_id,
Workspace.is_active.is_(True)
).all()
# 构建各空间的 workspace 级配额明细
def _build_per_workspace_detail(count_func, per_unit_limit):
"""为 workspace 级配额构建 per_workspace 明细列表"""
if not per_unit_limit or not active_workspaces:
return []
details = []
for ws in active_workspaces:
ws_used = count_func(tenant_id, ws.id)
details.append({
"workspace_id": str(ws.id),
"workspace_name": ws.name,
"used": ws_used,
"limit": per_unit_limit,
"percentage": pct(ws_used, per_unit_limit),
})
return details
# workspace 级配额的每空间限额
app_quota_per_ws = quota_config.get("app_quota")
knowledge_quota_per_ws = quota_config.get("knowledge_capacity_quota")
memory_quota_per_ws = quota_config.get("memory_engine_quota")
end_user_quota_per_ws = quota_config.get("end_user_quota")
ontology_quota_per_ws = quota_config.get("ontology_project_quota")
# workspace 级配额的有效总限额 = 每空间限额 × 活跃工作区数
app_effective_limit = app_quota_per_ws * workspace_count if app_quota_per_ws is not None and workspace_count > 0 else app_quota_per_ws
knowledge_effective_limit = knowledge_quota_per_ws * workspace_count if knowledge_quota_per_ws is not None and workspace_count > 0 else knowledge_quota_per_ws
memory_effective_limit = memory_quota_per_ws * workspace_count if memory_quota_per_ws is not None and workspace_count > 0 else memory_quota_per_ws
end_user_effective_limit = end_user_quota_per_ws * workspace_count if end_user_quota_per_ws is not None and workspace_count > 0 else end_user_quota_per_ws
ontology_effective_limit = ontology_quota_per_ws * workspace_count if ontology_quota_per_ws is not None and workspace_count > 0 else ontology_quota_per_ws
api_ops_current = 0
try:
from app.core.config import settings
import redis
_now = time.time()
_rk = f"rate_limit:tenant_qps:{tenant_id}"
_r = redis.StrictRedis(
host=settings.REDIS_HOST, port=settings.REDIS_PORT,
db=settings.REDIS_DB, password=settings.REDIS_PASSWORD,
decode_responses=True
)
api_ops_current = int(_r.zcount(_rk, _now - 1, "+inf"))
except Exception:
pass
from app.aioRedis import aio_redis as _aio_redis
from app.models.api_key_model import ApiKey
# api_ops_rate_limit 限的是每个 api_key 每秒最高限额
# 展示当前最接近触发限流的 key 的 QPS取最大值
api_key_ids = db.query(ApiKey.id).join(
Workspace, ApiKey.workspace_id == Workspace.id
).filter(
Workspace.tenant_id == tenant_id,
ApiKey.is_active.is_(True)
).all()
for (key_id,) in api_key_ids:
_rk = API_KEY_QPS_REDIS_KEY.format(api_key_id=key_id)
val = await _aio_redis.get(_rk)
count = int(val) if val else 0
if count > api_ops_current:
api_ops_current = count
except Exception as e:
logger.warning(f"获取 api_ops_current 失败,返回 0: {type(e).__name__}: {e}")
return {
"workspace": {"used": workspace_count, "limit": quota_config.get("workspace_quota"), "percentage": pct(workspace_count, quota_config.get("workspace_quota"))},
"skill": {"used": skill_count, "limit": quota_config.get("skill_quota"), "percentage": pct(skill_count, quota_config.get("skill_quota"))},
"app": {"used": app_count, "limit": quota_config.get("app_quota"), "percentage": pct(app_count, quota_config.get("app_quota"))},
"knowledge_capacity": {"used": round(knowledge_gb, 2), "limit": quota_config.get("knowledge_capacity_quota"), "percentage": pct(knowledge_gb, quota_config.get("knowledge_capacity_quota")), "unit": "GB"},
"memory_engine": {"used": memory_count, "limit": quota_config.get("memory_engine_quota"), "percentage": pct(memory_count, quota_config.get("memory_engine_quota"))},
"end_user": {"used": end_user_count, "limit": quota_config.get("end_user_quota"), "percentage": pct(end_user_count, quota_config.get("end_user_quota"))},
"ontology_project": {"used": ontology_count, "limit": quota_config.get("ontology_project_quota"), "percentage": pct(ontology_count, quota_config.get("ontology_project_quota"))},
"app": {
"used": app_count,
"limit": app_effective_limit,
"percentage": pct(app_count, app_effective_limit),
"per_workspace": _build_per_workspace_detail(repo.count_apps, app_quota_per_ws),
},
"knowledge_capacity": {
"used": round(knowledge_gb, 2),
"limit": knowledge_effective_limit,
"percentage": pct(knowledge_gb, knowledge_effective_limit),
"unit": "GB",
"per_workspace": _build_per_workspace_detail(repo.sum_knowledge_capacity_gb, knowledge_quota_per_ws),
},
"memory_engine": {
"used": memory_count,
"limit": memory_effective_limit,
"percentage": pct(memory_count, memory_effective_limit),
"per_workspace": _build_per_workspace_detail(repo.count_memory_engines, memory_quota_per_ws),
},
"end_user": {
"used": end_user_count,
"limit": end_user_effective_limit,
"percentage": pct(end_user_count, end_user_effective_limit),
"per_workspace": _build_per_workspace_detail(repo.count_end_users, end_user_quota_per_ws),
},
"ontology_project": {
"used": ontology_count,
"limit": ontology_effective_limit,
"percentage": pct(ontology_count, ontology_effective_limit),
"per_workspace": _build_per_workspace_detail(repo.count_ontology_projects, ontology_quota_per_ws),
},
"model": {"used": model_count, "limit": quota_config.get("model_quota"), "percentage": pct(model_count, quota_config.get("model_quota"))},
"api_ops_rate_limit": {"current": api_ops_current, "limit": quota_config.get("api_ops_rate_limit"), "percentage": None, "unit": "次/秒"},
}

View File

@@ -18,6 +18,7 @@ from app.core.quota_manager import (
get_quota_usage,
_check_quota,
QuotaUsageRepository,
API_KEY_QPS_REDIS_KEY,
)
__all__ = [
@@ -33,4 +34,5 @@ __all__ = [
"get_quota_usage",
"_check_quota",
"QuotaUsageRepository",
"API_KEY_QPS_REDIS_KEY",
]

View File

@@ -33,18 +33,16 @@ def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception:
thread.daemon = True
thread.start()
effective_timeout = seconds if seconds else 120 # 默认 120 秒超时
for a in range(attempts):
try:
if os.environ.get("ENABLE_TIMEOUT_ASSERTION"):
result = result_queue.get(timeout=seconds)
else:
result = result_queue.get()
result = result_queue.get(timeout=effective_timeout)
if isinstance(result, Exception):
raise result
return result
except queue.Empty:
pass
raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds and {attempts} attempts.")
raise TimeoutError(f"Function '{func.__name__}' timed out after {effective_timeout} seconds and {attempts} attempts.")
@wraps(func)
async def async_wrapper(*args, **kwargs) -> Any:

View File

@@ -113,7 +113,7 @@ def knowledge_retrieval(
continue
# Use the specified reranker for re-ranking
if reranker_id:
if reranker_id and all_results:
try:
all_results = rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k)
except Exception as rerank_error:

View File

@@ -68,9 +68,9 @@ class ESConnection(DocStoreConnection):
client_config = {
"hosts": [hosts],
"basic_auth": (os.getenv("ELASTICSEARCH_USERNAME", "elastic"), os.getenv("ELASTICSEARCH_PASSWORD", "elastic")),
"request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 100000)),
"request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 30)),
"retry_on_timeout": os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", True) == "true",
"max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 10000)),
"max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 3)),
}
# Only add SSL settings if using HTTPS

View File

@@ -1,25 +1,22 @@
import os
import logging
from typing import Any, cast
import threading
from typing import Any
from urllib.parse import urlparse
import uuid
import requests
from elasticsearch import Elasticsearch, helpers
from elasticsearch.helpers import BulkIndexError
from packaging.version import parse as parse_version
from pydantic import BaseModel, model_validator
from abc import ABC
# langchain-community
# langchain-xinference
# from langchain_community.embeddings import XinferenceEmbeddings
# from langchain_xinference import XinferenceRerank
from langchain_core.documents import Document
from app.core.models.base import RedBearModelConfig
from app.core.models import RedBearLLM, RedBearRerank
from app.core.models import RedBearRerank
from app.core.models.embedding import RedBearEmbeddings
from app.models.models_model import ModelConfig, ModelApiKey
from app.services.model_service import ModelConfigService
from app.models.models_model import ModelApiKey
from app.models.knowledge_model import Knowledge
from app.core.rag.vdb.field import Field
@@ -29,37 +26,9 @@ from app.core.rag.models.chunk import DocumentChunk
logger = logging.getLogger(__name__)
class ElasticSearchConfig(BaseModel):
# Regular Elasticsearch config
host: str | None = None
port: int | None = None
username: str | None = None
password: str | None = None
# Common config
ca_certs: str | None = None
verify_certs: bool = False
request_timeout: int = 100000
retry_on_timeout: bool = True
max_retries: int = 10000
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
# Regular Elasticsearch validation
if not values.get("host"):
raise ValueError("config HOST is required for regular Elasticsearch")
if not values.get("port"):
raise ValueError("config PORT is required for regular Elasticsearch")
if not values.get("username"):
raise ValueError("config USERNAME is required for regular Elasticsearch")
if not values.get("password"):
raise ValueError("config PASSWORD is required for regular Elasticsearch")
return values
class ElasticSearchVector(BaseVector):
def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey):
def __init__(self, index_name: str, client: Elasticsearch,
embedding_config: ModelApiKey, reranker_config: ModelApiKey):
super().__init__(index_name.lower())
# 初始化 Embedding 模型(自动支持火山引擎多模态)
@@ -77,58 +46,8 @@ class ElasticSearchVector(BaseVector):
api_key=reranker_config.api_key,
base_url=reranker_config.api_base
))
self._client = self._init_client(config)
self._version = self._get_version()
self._check_version()
def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
"""
Initialize Elasticsearch client for regular Elasticsearch.
"""
try:
# Regular Elasticsearch configuration
parsed_url = urlparse(config.host or "")
if parsed_url.scheme in {"http", "https"}:
hosts = f"{config.host}:{config.port}"
use_https = parsed_url.scheme == "https"
else:
hosts = f"https://{config.host}:{config.port}"
use_https = False
client_config = {
"hosts": [hosts],
"basic_auth": (config.username, config.password),
"request_timeout": config.request_timeout,
"retry_on_timeout": config.retry_on_timeout,
"max_retries": config.max_retries,
}
# Only add SSL settings if using HTTPS
if use_https:
client_config["verify_certs"] = config.verify_certs
if config.ca_certs:
client_config["ca_certs"] = config.ca_certs
client = Elasticsearch(**client_config)
# Test connection
if not client.ping():
raise ConnectionError("Failed to connect to Elasticsearch")
except requests.ConnectionError as e:
raise ConnectionError(f"Vector database connection error: {str(e)}")
except Exception as e:
raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}")
return client
def _get_version(self) -> str:
info = self._client.info()
return cast(str, info["version"]["number"])
def _check_version(self):
if parse_version(self._version) < parse_version("8.0.0"):
raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
# 使用外部传入的共享客户端
self._client = client
def get_type(self) -> str:
return "elasticsearch"
@@ -745,29 +664,79 @@ class ElasticSearchVector(BaseVector):
class ElasticSearchVectorFactory:
@staticmethod
def init_vector(knowledge: Knowledge) -> ElasticSearchVector:
"""ES 向量服务工厂 - 单例共享连接"""
_client: Elasticsearch | None = None
_lock = threading.Lock()
_version_checked = False
@classmethod
def _get_shared_client(cls) -> Elasticsearch:
    """Return the process-wide Elasticsearch client, creating it on first use.

    The client is built lazily from ``ELASTICSEARCH_*`` environment
    variables and cached on the class. Creation happens under ``cls._lock``
    with a second ``cls._client`` check inside the lock, so concurrent first
    callers do not each build a client.

    Returns:
        The shared ``Elasticsearch`` client.

    Raises:
        ConnectionError: If the client cannot be built, the ping fails, or
            the server version is below 8.0.0 (every failure inside the
            initialization is re-raised as ``ConnectionError``).
    """
    # Fast path: already initialized, no locking needed.
    if cls._client is not None:
        return cls._client

    with cls._lock:
        # Re-check under the lock: another thread may have finished
        # initialization while this one was waiting.
        if cls._client is not None:
            return cls._client

        try:
            # Read the endpoint settings once so every use sees the same value.
            host = os.getenv("ELASTICSEARCH_HOST", "127.0.0.1") or ""
            port = os.getenv("ELASTICSEARCH_PORT", 9200)

            parsed_url = urlparse(host)
            if parsed_url.scheme in {"http", "https"}:
                hosts = f'{host}:{port}'
                use_https = parsed_url.scheme == "https"
            else:
                # NOTE(review): a scheme-less host is prefixed with https://
                # while use_https stays False, so the verify_certs/ca_certs
                # settings below are not applied to it — confirm intended.
                hosts = f'https://{host}:{port}'
                use_https = False

            client_config = {
                "hosts": [hosts],
                "basic_auth": (
                    os.getenv("ELASTICSEARCH_USERNAME", "elastic"),
                    os.getenv("ELASTICSEARCH_PASSWORD", "elastic"),
                ),
                "request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 30)),
                "retry_on_timeout": True,
                "max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 3)),
                "connections_per_node": int(os.getenv("ELASTICSEARCH_CONNECTIONS_PER_NODE", 10)),
            }

            # TLS options are only passed when the configured URL is https.
            if use_https:
                client_config["verify_certs"] = os.getenv("ELASTICSEARCH_VERIFY_CERTS", "false") == "true"
                ca_certs = os.getenv("ELASTICSEARCH_CA_CERTS")
                if ca_certs:
                    client_config["ca_certs"] = str(ca_certs)

            client = Elasticsearch(**client_config)
            try:
                if not client.ping():
                    raise ConnectionError("Failed to connect to Elasticsearch")

                # The version check runs only once per process.
                if not cls._version_checked:
                    info = client.info()
                    version = info["version"]["number"]
                    if parse_version(version) < parse_version("8.0.0"):
                        raise ValueError(f"Elasticsearch version must be >= 8.0.0, got {version}")
                    cls._version_checked = True
                    logger.info(f"Elasticsearch shared client initialized, version: {version}")
            except Exception:
                # The client is never published on failure, so release its
                # connection pool here instead of leaking it on every retry.
                try:
                    client.close()
                except Exception:
                    # Best-effort cleanup; the original error is what matters.
                    logger.debug("Failed to close Elasticsearch client after init error", exc_info=True)
                raise

            # Publish only a client that passed every check.
            cls._client = client
        except requests.ConnectionError as e:
            raise ConnectionError(f"Vector database connection error: {str(e)}") from e
        except Exception as e:
            raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}") from e

    return cls._client
@classmethod
def init_vector(cls, knowledge: Knowledge) -> ElasticSearchVector:
"""创建向量服务实例(共享 ES 连接)"""
client = cls._get_shared_client()
collection_name = f"Vector_index_{knowledge.id}_Node"
# Use regular Elasticsearch with config values
config_dict = {
"host": os.getenv("ELASTICSEARCH_HOST", "127.0.0.1"),
"port": os.getenv("ELASTICSEARCH_PORT", 9200),
"username": os.getenv("ELASTICSEARCH_USERNAME", "elastic"),
"password": os.getenv("ELASTICSEARCH_PASSWORD", "elastic"),
}
# Common configuration
config_dict.update(
{
"ca_certs": str(os.getenv("ELASTICSEARCH_CA_CERTS")) if os.getenv("ELASTICSEARCH_CA_CERTS") else None,
"verify_certs": os.getenv("ELASTICSEARCH_VERIFY_CERTS", False) == "true",
"request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 100000)),
"retry_on_timeout": os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", True) == "true",
"max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 10000)),
}
)
if knowledge.embedding is None:
raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}")
if knowledge.reranker is None:
@@ -775,9 +744,9 @@ class ElasticSearchVectorFactory:
return ElasticSearchVector(
index_name=collection_name,
config=ElasticSearchConfig(**config_dict),
client=client,
embedding_config=knowledge.embedding.api_keys[0],
reranker_config=knowledge.reranker.api_keys[0]
reranker_config=knowledge.reranker.api_keys[0],
)

View File

@@ -11,6 +11,7 @@ from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db_read
from app.services.tool_service import ToolService
from app.models.tool_model import ToolType
logger = logging.getLogger(__name__)
@@ -76,6 +77,18 @@ class ToolNode(BaseNode):
# 执行工具
with get_db_read() as db:
tool_service = ToolService(db)
# MCP 工具:将 operation 映射为 tool_name其余参数包装进 arguments
tool_instance = tool_service.get_tool_instance(self.typed_config.tool_id, tenant_id)
if tool_instance and tool_instance.tool_type == ToolType.MCP:
operation = rendered_parameters.pop("operation", None)
if operation:
old_params = rendered_parameters
rendered_parameters = {
"tool_name": operation,
"arguments": old_params
}
result = await tool_service.execute_tool(
tool_id=self.typed_config.tool_id,
parameters=rendered_parameters,

View File

@@ -6,12 +6,14 @@ error messages based on the current request's language.
"""
import logging
import time
from contextvars import ContextVar
from typing import Any, Dict, Optional
from fastapi import HTTPException, Request
from app.i18n.service import get_translation_service
from app.core.error_codes import ERROR_CODE_TO_BIZ_CODE, BizCode
logger = logging.getLogger(__name__)
@@ -118,15 +120,24 @@ class I18nException(HTTPException):
**params
)
# Build error detail
detail = {
"error_code": self.error_code,
"message": message,
}
# Convert error_code string to BizCode value
biz_code = ERROR_CODE_TO_BIZ_CODE.get(
self.error_code,
BizCode.BAD_REQUEST
)
# Add parameters to detail if provided
if params:
detail["params"] = params
# Build error detail in standard format for compatibility
# main.py handler expects "message" and "error_code" fields for filtering
# but we also include standard format fields
detail = {
"code": biz_code.value,
"msg": message,
"message": message,
"error_code": self.error_code,
"data": params if params else {},
"error": message,
"time": int(time.time() * 1000),
}
# Initialize HTTPException
super().__init__(
@@ -482,14 +493,39 @@ class RateLimitExceededError(I18nException):
)
class QuotaExceededError(ForbiddenError):
"""Quota exceeded error."""
class QuotaExceededError(I18nException):
"""Quota exceeded error (402)."""
# resource key -> i18n display key
_RESOURCE_KEY_MAP = {
"workspace": "errors.quota_resources.workspace",
"app": "errors.quota_resources.app",
"skill": "errors.quota_resources.skill",
"knowledge_capacity": "errors.quota_resources.knowledge_capacity",
"memory_engine": "errors.quota_resources.memory_engine",
"end_user": "errors.quota_resources.end_user",
"model": "errors.quota_resources.model",
"ontology_project": "errors.quota_resources.ontology_project",
"api_ops_rate_limit": "errors.quota_resources.api_ops_rate_limit",
}
def __init__(self, resource: Optional[str] = None, **params):
# Translate resource key to a localized display name before calling super()
if resource:
params["resource"] = resource
resource_i18n_key = self._RESOURCE_KEY_MAP.get(resource)
if resource_i18n_key:
try:
from app.i18n.service import get_translation_service
from app.core.config import settings
_locale = _current_locale.get() or settings.I18N_DEFAULT_LANGUAGE
params["resource"] = get_translation_service().translate(resource_i18n_key, _locale)
except Exception:
params["resource"] = resource
else:
params["resource"] = resource
super().__init__(
error_key="errors.api.quota_exceeded",
status_code=402,
error_code="QUOTA_EXCEEDED",
**params
)

View File

@@ -106,7 +106,7 @@
},
"api": {
"rate_limit_exceeded": "API rate limit exceeded",
"quota_exceeded": "API quota exceeded",
"quota_exceeded": "{resource} quota exceeded",
"invalid_api_key": "Invalid API key",
"api_key_expired": "API key has expired",
"api_key_revoked": "API key has been revoked",
@@ -114,7 +114,8 @@
"method_not_allowed": "Method not allowed",
"invalid_request": "Invalid request",
"missing_parameter": "Missing required parameter: {param}",
"invalid_parameter": "Invalid parameter: {param}"
"invalid_parameter": "Invalid parameter: {param}",
"api_key_rate_limit_exceeded": "API Key rate limit ({rate_limit}) exceeds tenant plan limit ({limit})"
},
"database": {
"connection_failed": "Database connection failed",
@@ -134,5 +135,16 @@
"invalid_format": "Invalid format: {field}",
"invalid_value": "Invalid value: {field}",
"out_of_range": "Value out of range: {field}"
},
"quota_resources": {
"workspace": "Workspace",
"app": "App",
"skill": "Skill",
"knowledge_capacity": "Knowledge capacity",
"memory_engine": "Memory engine",
"end_user": "End user",
"model": "Model",
"ontology_project": "Ontology project",
"api_ops_rate_limit": "API ops rate limit"
}
}

View File

@@ -106,7 +106,7 @@
},
"api": {
"rate_limit_exceeded": "API调用频率超限",
"quota_exceeded": "API调用配额已用完",
"quota_exceeded": "{resource} 配额已超限",
"invalid_api_key": "无效的API密钥",
"api_key_expired": "API密钥已过期",
"api_key_revoked": "API密钥已被撤销",
@@ -114,7 +114,8 @@
"method_not_allowed": "不支持的请求方法",
"invalid_request": "无效的请求",
"missing_parameter": "缺少必需参数:{param}",
"invalid_parameter": "参数无效:{param}"
"invalid_parameter": "参数无效:{param}",
"api_key_rate_limit_exceeded": "API Key 的 QPS 限制({rate_limit})超过租户套餐上限({limit}"
},
"database": {
"connection_failed": "数据库连接失败",
@@ -134,5 +135,16 @@
"invalid_format": "格式不正确:{field}",
"invalid_value": "值无效:{field}",
"out_of_range": "值超出范围:{field}"
},
"quota_resources": {
"workspace": "工作空间",
"app": "应用",
"skill": "技能",
"knowledge_capacity": "知识库容量",
"memory_engine": "记忆引擎",
"end_user": "终端用户",
"model": "模型",
"ontology_project": "本体工程",
"api_ops_rate_limit": "API 操作速率"
}
}

View File

@@ -66,6 +66,17 @@ class EndUserRepository:
db_logger.error(f"查询宿主 {end_user_id} 时出错: {str(e)}")
raise
def get_end_user_by_other_id(self, workspace_id: uuid.UUID, other_id: str) -> Optional["EndUser"]:
    """Find an end user by ``workspace_id`` and ``other_id``.

    Args:
        workspace_id: Workspace the end user must belong to.
        other_id: Value matched against ``EndUser.other_id``.

    Returns:
        The first matching ``EndUser``, or ``None`` when there is no match.
    """
    end_users = self.db.query(EndUser)
    # Both conditions must hold for a row to match.
    matching = end_users.filter(
        EndUser.workspace_id == workspace_id,
        EndUser.other_id == other_id,
    )
    return matching.first()
def get_or_create_end_user(
self,
app_id: uuid.UUID,

View File

@@ -15,8 +15,8 @@ class ApiKeyCreate(BaseModel):
type: ApiKeyType = Field(..., description="API Key 类型")
scopes: List[str] = Field(default_factory=list, description="权限范围列表")
resource_id: Optional[uuid.UUID] = Field(None, description="关联资源ID")
rate_limit: Optional[int] = Field(100, ge=1, le=1000, description="QPS限制请求/秒)")
daily_request_limit: Optional[int] = Field(10000, description="日请求限制", ge=1)
rate_limit: Optional[int] = Field(50, ge=1, le=1000, description="QPS限制请求/秒)")
daily_request_limit: Optional[int] = Field(100000, description="日请求限制", ge=1)
quota_limit: Optional[int] = Field(None, description="配额限制(总请求数)", ge=1)
expires_at: Optional[datetime.datetime] = Field(None, description="过期时间")
@@ -55,7 +55,7 @@ class ApiKeyUpdate(BaseModel):
description: Optional[str] = Field(None, description="描述")
scopes: Optional[List[str]] = Field(None, description="权限范围列表")
rate_limit: Optional[int] = Field(None, description="速率限制(请求/分钟)", ge=1)
daily_request_limit: Optional[int] = Field(10000, description="每日请求数限制", ge=1)
daily_request_limit: Optional[int] = Field(100000, description="每日请求数限制", ge=1)
quota_limit: Optional[int] = Field(None, description="配额限制(总请求数)", ge=1)
is_active: Optional[bool] = Field(None, description="是否激活")
expires_at: Optional[datetime.datetime] = Field(None, description="过期时间")

View File

@@ -51,6 +51,19 @@ class ApiKeyService:
if existing:
raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME)
# 若 rate_limit 超过租户套餐的 api_ops_rate_limit直接报错
from app.models.workspace_model import Workspace
from app.core.quota_manager import get_api_ops_rate_limit
workspace = db.query(Workspace).filter(Workspace.id == workspace_id).first()
if workspace:
tenant_api_ops_limit = get_api_ops_rate_limit(db, workspace.tenant_id)
if tenant_api_ops_limit and data.rate_limit > tenant_api_ops_limit:
raise BusinessException(
f"API Key QPS 不能超过套餐上限 {tenant_api_ops_limit}",
BizCode.BAD_REQUEST
)
# 生成 API Key
api_key = generate_api_key(data.type)
@@ -152,6 +165,20 @@ class ApiKeyService:
if existing:
raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME)
# 若 rate_limit 超过租户套餐的 api_ops_rate_limit直接报错
if data.rate_limit is not None:
from app.models.workspace_model import Workspace
from app.core.quota_manager import get_api_ops_rate_limit
workspace = db.query(Workspace).filter(Workspace.id == workspace_id).first()
if workspace:
tenant_api_ops_limit = get_api_ops_rate_limit(db, workspace.tenant_id)
if tenant_api_ops_limit and data.rate_limit > tenant_api_ops_limit:
raise BusinessException(
f"API Key QPS 不能超过套餐上限 {tenant_api_ops_limit}",
BizCode.BAD_REQUEST
)
update_data = data.model_dump(exclude_unset=True)
ApiKeyRepository.update(db, api_key_id, update_data)
db.commit()
@@ -248,42 +275,14 @@ class RateLimiterService:
def __init__(self):
self.redis = aio_redis
async def check_tenant_rate_limit(self, tenant_id: uuid.UUID, limit: int) -> Tuple[bool, dict]:
"""
按 tenant_id 做 1 秒滑动窗口限速,限制值来自套餐配额 api_ops_rate_limit
"""
now = time.time()
window_start = now - 1 # 1 秒窗口
key = f"rate_limit:tenant_qps:{tenant_id}"
async with self.redis.pipeline() as pipe:
# 清理 1 秒前的旧记录
pipe.zremrangebyscore(key, 0, window_start)
# 加入当前请求score=时间戳member=时间戳+随机数保证唯一)
pipe.zadd(key, {f"{now}:{uuid.uuid4().hex}": now})
# 统计窗口内请求数
pipe.zcard(key)
# 设置 key 过期2 秒后自动清理)
pipe.expire(key, 2)
results = await pipe.execute()
current = results[2]
remaining = max(0, limit - current)
reset_time = int(now) + 1
return current <= limit, {
"limit": limit,
"remaining": remaining,
"reset": reset_time,
}
async def check_qps(self, api_key_id: uuid.UUID, limit: int) -> Tuple[bool, dict]:
"""
检查QPS限制
"""检查QPS限制
Returns:
(is_allowed, rate_limit_info)
"""
key = f"rate_limit:qps:{api_key_id}"
async with self.redis.pipeline() as pipe:
pipe.incr(key)
pipe.expire(key, 1, nx=True) # 1 秒过期
@@ -295,8 +294,9 @@ class RateLimiterService:
return current <= limit, {
"limit": limit,
"current": current,
"remaining": remaining,
"reset": reset_time
"reset": reset_time,
}
async def check_daily_requests(
@@ -304,7 +304,9 @@ class RateLimiterService:
api_key_id: uuid.UUID,
limit: int
) -> Tuple[bool, dict]:
"""检查日调用量限制"""
"""检查日调用量限制
使用原子 INCR先写后判断极低概率下允许轻微超限并发场景下可接受
"""
today = datetime.now().strftime("%Y%m%d")
key = f"rate_limit:daily:{api_key_id}:{today}"
@@ -313,6 +315,7 @@ class RateLimiterService:
hour=0, minute=0, second=0, microsecond=0
)
expire_seconds = int((tomorrow_0 - now).total_seconds())
reset_time = int(tomorrow_0.timestamp())
async with self.redis.pipeline() as pipe:
pipe.incr(key)
@@ -320,36 +323,74 @@ class RateLimiterService:
results = await pipe.execute()
current = results[0]
remaining = max(0, limit - current)
reset_time = int(tomorrow_0.timestamp())
return current <= limit, {
if current > limit:
return False, {
"limit": limit,
"remaining": 0,
"reset": reset_time,
}
return True, {
"limit": limit,
"remaining": remaining,
"reset": reset_time
"remaining": max(0, limit - current),
"reset": reset_time,
}
async def check_all_limits(
self,
api_key: ApiKey
api_key: ApiKey,
db: Optional[Session] = None,
) -> Tuple[bool, str, dict]:
"""
检查所有限制
Returns:
(is_allowed, error_message, rate_limit_headers)
检查所有限制,按以下顺序:
1. API Key QPS取 api_key.rate_limit 与套餐 api_ops_rate_limit 的最小值作为限额
2. API Key 日调用量
"""
# Check QPS
qps_ok, qps_info = await self.check_qps(
api_key.id,
api_key.rate_limit
)
# 1. 取套餐限额与 api_key 自身限额的最小值
effective_limit = api_key.rate_limit
if db is not None:
try:
from app.models.workspace_model import Workspace
from app.core.quota_manager import get_api_ops_rate_limit
cache_key = f"tenant_api_ops_limit:{api_key.workspace_id}"
cached = await self.redis.get(cache_key)
if cached is not None:
try:
tenant_limit = int(cached) if cached != "0" else None
except (ValueError, TypeError):
cached = None
tenant_limit = None
if cached is None:
workspace = db.query(Workspace).filter(Workspace.id == api_key.workspace_id).first()
if workspace:
tenant_limit = get_api_ops_rate_limit(db, workspace.tenant_id)
await self.redis.set(cache_key, str(tenant_limit) if tenant_limit else "0", ex=60)
else:
tenant_limit = None
if tenant_limit:
effective_limit = min(api_key.rate_limit, tenant_limit)
except Exception as e:
logger.warning(f"获取套餐限额失败,使用 api_key 自身限额: {e}")
# 用最终有效限额做 QPS 检查
qps_ok, qps_info = await self.check_qps(api_key.id, effective_limit)
if not qps_ok:
return False, "QPS limit exceeded", {
# 判断是套餐限额触发还是 api_key 自身限额触发
if tenant_limit and effective_limit == tenant_limit and api_key.rate_limit > tenant_limit:
error_msg = "Tenant limit exceeded"
else:
error_msg = "QPS limit exceeded"
return False, error_msg, {
"X-RateLimit-Limit-QPS": str(qps_info["limit"]),
"X-RateLimit-Remaining-QPS": str(qps_info["remaining"]),
"X-RateLimit-Reset": str(qps_info["reset"])
}
# 2. 检查日调用量
daily_ok, daily_info = await self.check_daily_requests(
api_key.id,
api_key.daily_request_limit
@@ -361,14 +402,13 @@ class RateLimiterService:
"X-RateLimit-Reset": str(daily_info["reset"])
}
headers = {
return True, "", {
"X-RateLimit-Limit-QPS": str(qps_info["limit"]),
"X-RateLimit-Remaining-QPS": str(qps_info["remaining"]),
"X-RateLimit-Limit-Day": str(daily_info["limit"]),
"X-RateLimit-Remaining-Day": str(daily_info["remaining"]),
"X-RateLimit-Reset": str(daily_info["reset"])
"X-RateLimit-Reset": str(daily_info["reset"]),
}
return True, "", headers
class ApiKeyAuthService:

View File

@@ -434,19 +434,37 @@ class AppDslService:
def _resolve_model(self, ref: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> Optional[uuid.UUID]:
if not ref:
return None
q = self.db.query(ModelConfig).filter(
ModelConfig.tenant_id == tenant_id,
ModelConfig.name == ref.get("name"),
ModelConfig.is_active.is_(True)
)
if ref.get("provider"):
q = q.filter(ModelConfig.provider == ref["provider"])
if ref.get("type"):
q = q.filter(ModelConfig.type == ref["type"])
m = q.first()
if not m:
warnings.append(f"模型 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置")
return m.id if m else None
model_id = ref.get("id")
if model_id:
try:
model_uuid = uuid.UUID(str(model_id))
m = self.db.query(ModelConfig).filter(
ModelConfig.id == model_uuid,
ModelConfig.tenant_id == tenant_id,
ModelConfig.is_active.is_(True)
).first()
if m:
return str(m.id)
except (ValueError, AttributeError):
pass
model_name = ref.get("name")
if model_name:
q = self.db.query(ModelConfig).filter(
ModelConfig.tenant_id == tenant_id,
ModelConfig.name == model_name,
ModelConfig.is_active.is_(True)
)
if ref.get("provider"):
q = q.filter(ModelConfig.provider == ref["provider"])
if ref.get("type"):
q = q.filter(ModelConfig.type == ref["type"])
m = q.first()
if m:
return str(m.id)
warnings.append(f"模型 '{model_name}' 未匹配,已置空,请导入后手动配置")
else:
warnings.append(f"模型 ID '{model_id}' 未匹配,已置空,请导入后手动配置")
return None
def _resolve_kb(self, ref: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[str]:
if not ref:
@@ -587,7 +605,7 @@ class AppDslService:
if not kb_id:
continue
kb_ref = {}
if isinstance(kb_id, str) and len(kb_id) >= 36:
if isinstance(kb_id, str):
try:
uuid.UUID(kb_id)
kb_ref["id"] = kb_id
@@ -601,6 +619,33 @@ class AppDslService:
else:
warnings.append(f"[{node_label}] 知识库 '{kb_id}' 未匹配,已移除,请导入后手动配置")
config["knowledge_bases"] = resolved_kbs
elif node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
model_ref = config.get("model_id")
if model_ref:
ref_dict = None
if isinstance(model_ref, dict):
ref_id = model_ref.get("id")
ref_name = model_ref.get("name")
if ref_id:
ref_dict = {"id": ref_id}
elif ref_name is not None:
ref_dict = {"name": ref_name, "provider": model_ref.get("provider"), "type": model_ref.get("type")}
elif isinstance(model_ref, str):
try:
uuid.UUID(model_ref)
ref_dict = {"id": model_ref}
except ValueError:
ref_dict = {"name": model_ref}
if ref_dict:
resolved_model_id = self._resolve_model(ref_dict, tenant_id, warnings)
if resolved_model_id:
config["model_id"] = resolved_model_id
else:
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
config["model_id"] = None
else:
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
config["model_id"] = None
resolved_nodes.append({**node, "config": config})
return resolved_nodes

View File

@@ -7,7 +7,6 @@ from app.models.models_model import ModelConfig
from app.schemas.knowledge_schema import KnowledgeCreate, KnowledgeUpdate
from app.repositories import knowledge_repository
from app.core.logging_config import get_business_logger
from app.repositories.model_repository import ModelConfigRepository
from app.models.models_model import ModelType
business_logger = get_business_logger()
@@ -78,41 +77,31 @@ def create_knowledge(
tenant_id = workspace.tenant_id
if not knowledge.embedding_id:
embedding_models = ModelConfigRepository.get_by_type(
db=db, model_types=[ModelType.EMBEDDING], tenant_id=tenant_id, is_active=True
)
if embedding_models:
knowledge.embedding_id = embedding_models[0].id
business_logger.debug(f"Auto-bind embedding model: {embedding_models[0].id}")
if not workspace.embedding:
raise Exception("工作空间未配置 Embedding 模型,请先完善工作空间配置后重试")
knowledge.embedding_id = workspace.embedding
if not knowledge.reranker_id:
rerank_models = ModelConfigRepository.get_by_type(
db=db, model_types=[ModelType.RERANK], tenant_id=tenant_id, is_active=True
)
if rerank_models:
knowledge.reranker_id = rerank_models[0].id
business_logger.debug(f"Auto-bind rerank model: {rerank_models[0].id}")
if not workspace.rerank:
raise Exception("工作空间未配置 Rerank 模型,请先完善工作空间配置后重试")
knowledge.reranker_id = workspace.rerank
if not knowledge.llm_id:
llm_models = ModelConfigRepository.get_by_type(
db=db, model_types=[ModelType.LLM, ModelType.CHAT], tenant_id=tenant_id, is_active=True
)
if llm_models:
knowledge.llm_id = llm_models[0].id
business_logger.debug(f"Auto-bind llm model: {llm_models[0].id}")
if not workspace.llm:
raise Exception("工作空间未配置 LLM 模型,请先完善工作空间配置后重试")
knowledge.llm_id = workspace.llm
if not knowledge.image2text_id:
image2text_models = db.query(ModelConfig).filter(
model = db.query(ModelConfig).filter(
ModelConfig.tenant_id == tenant_id,
ModelConfig.type.in_([ModelType.CHAT.value]),
ModelConfig.type.in_([ModelType.CHAT.value, ModelType.LLM.value]),
ModelConfig.capability.contains(["vision"]),
ModelConfig.is_active == True,
ModelConfig.is_composite == False
).order_by(ModelConfig.created_at.desc()).all()
if not image2text_models:
).order_by(ModelConfig.created_at.desc()).first()
if not model:
raise Exception("租户下没有可用的视觉模型,创建知识库失败")
knowledge.image2text_id = image2text_models[0].id
business_logger.debug(f"Auto-bind image2text model: {image2text_models[0].id}")
knowledge.image2text_id = model.id
business_logger.debug(f"Auto-bind image2text model: {model.id}")
business_logger.debug(f"Start creating the knowledge base: {knowledge.name}")
db_knowledge = knowledge_repository.create_knowledge(

View File

@@ -1282,7 +1282,7 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
}
logger.info(
f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={end_user.workspace_id}")
return result

View File

@@ -125,11 +125,7 @@ class ModelConfigService:
api_key=api_key,
base_url=api_base,
is_omni=is_omni,
capability=capability,
extra_params={
"temperature": 0.7,
"max_tokens": 100
}
capability=capability
)
# 根据模型类型选择不同的验证方式
@@ -373,6 +369,15 @@ class ModelConfigService:
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
# 同步更新关联 api_keys 的 capability 和 is_omni
if model_data.capability is not None or model_data.is_omni is not None:
for api_key in model.api_keys:
if model_data.capability is not None:
api_key.capability = model_data.capability
if model_data.is_omni is not None:
api_key.is_omni = model_data.is_omni
db.commit()
db.refresh(model)
return model

View File

@@ -251,8 +251,40 @@ def parse_document(file_path: str, document_id: uuid.UUID):
# Prepare vision_model for parsing
vision_model = _build_vision_model(file_path, db_knowledge)
# 先将文件读入内存,避免解析过程中依赖 NFS 文件持续可访问
# python-docx 等库在 binary=None 时会用路径直接打开文件,
# 在 NFS/共享存储上可能因缓存失效导致 "Package not found"
max_wait_seconds = 30
wait_interval = 2
waited = 0
file_binary = None
while waited <= max_wait_seconds:
# os.listdir 强制 NFS 客户端刷新目录缓存
parent_dir = os.path.dirname(file_path)
try:
os.listdir(parent_dir)
except OSError:
pass
try:
with open(file_path, "rb") as f:
file_binary = f.read()
if not file_binary:
# NFS 上文件存在但内容为空(可能还在同步中)
raise IOError(f"File is empty (0 bytes), NFS may still be syncing: {file_path}")
break
except (FileNotFoundError, IOError) as e:
if waited >= max_wait_seconds:
raise type(e)(
f"File not accessible at '{file_path}' after waiting {max_wait_seconds}s: {e}"
)
logger.warning(f"File not ready on this node, retrying in {wait_interval}s: {file_path} ({e})")
time.sleep(wait_interval)
waited += wait_interval
from app.core.rag.app.naive import chunk
logger.info(f"[ParseDoc] file_binary size={len(file_binary)} bytes, type={type(file_binary).__name__}, bool={bool(file_binary)}")
res = chunk(filename=file_path,
binary=file_binary,
from_page=0,
to_page=DEFAULT_PARSE_TO_PAGE,
callback=progress_callback,

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-03-07 16:49:59
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-17 10:11:54
* @Last Modified time: 2026-04-20 18:14:34
*/
import { type FC, useEffect, useState } from 'react';
import { Select, Flex, Space } from 'antd';
@@ -56,7 +56,7 @@ const ModelSelect: FC<ModelSelectProps> = ({ params, placeholder, fontClassName,
useEffect(() => {
if (updateOptions) updateOptions([...options, ...initialData]);
}, [options, initialData])
}, [JSON.stringify(options), JSON.stringify(initialData)])
return (
<Select

View File

@@ -2,11 +2,11 @@
* @Author: ZhaoYing
* @Date: 2026-04-14 12:28:23
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-16 17:34:02
* @Last Modified time: 2026-04-21 15:46:35
*/
import { useState, forwardRef, useImperativeHandle } from 'react';
import { Flex, Tooltip, Divider } from 'antd';
import { Flex, Divider } from 'antd';
import { useTranslation } from 'react-i18next';
import clsx from 'clsx';
@@ -82,8 +82,7 @@ const SubscriptionDetailModal = forwardRef<SubscriptionDetailModalRef>((_props,
{/* Features */}
<Flex gap={12} vertical className="rb:space-y-3 rb:mb-4 rb:h-[calc(100vh-341px)]! rb:overflow-y-auto">
{billingUnits.map(({ key, unit, icon }) => {
const value = detail?.quota[key as keyof Subscription['quota']];
if (value === undefined || value === null) return null;
const value = detail?.quotas?.[key as keyof Subscription['quotas']];
return (
<UnitWrapper
key={key}
@@ -95,7 +94,7 @@ const SubscriptionDetailModal = forwardRef<SubscriptionDetailModalRef>((_props,
/>
)
})}
{detail?.package_plan?.tech_support && (
{detail?.package_plan?.tech_support && detail?.package_plan?.[getKeyWithLanguage('tech_support')] && (
<UnitWrapper
titleKey="tech_support"
value={String(detail?.package_plan?.[getKeyWithLanguage('tech_support')] ?? '')}
@@ -103,7 +102,7 @@ const SubscriptionDetailModal = forwardRef<SubscriptionDetailModalRef>((_props,
theme_color={detail?.package_plan?.theme_color}
/>
)}
{detail?.package_plan?.sla_compliance && (
{detail?.package_plan?.sla_compliance && detail?.package_plan?.[getKeyWithLanguage('sla_compliance')] && (
<UnitWrapper
titleKey="sla"
value={String(detail?.package_plan?.[getKeyWithLanguage('sla_compliance')] ?? '')}

View File

@@ -18,61 +18,61 @@
* @component
*/
import { useState, useEffect, useRef, type FC } from 'react';
import { Menu as AntMenu, Layout, Flex, Divider } from 'antd';
import { UserOutlined } from '@ant-design/icons';
import type { MenuProps } from 'antd';
import { useNavigate, useLocation } from 'react-router-dom';
import { useTranslation } from 'react-i18next';
import { Menu as AntMenu, Divider, Flex, Layout } from 'antd';
import clsx from 'clsx';
import { useEffect, useRef, useState, type FC } from 'react';
import { useTranslation } from 'react-i18next';
import { useLocation, useNavigate } from 'react-router-dom';
import { useMenu, type MenuItem } from '@/store/menu';
import styles from './index.module.css'
import logo from '@/assets/images/logo.png'
import { useUser } from '@/store/user';
import { getTenantSubscription } from '@/api/user';
import { useI18n } from '@/store/locale'
import logo from '@/assets/images/logo.png';
import { useI18n } from '@/store/locale';
import { useMenu, type MenuItem } from '@/store/menu';
import { useUser } from '@/store/user';
import styles from './index.module.css';
import SubscriptionDetailModal, { type SubscriptionDetailModalRef } from './SubscriptionDetailModal';
import SwitchSpaceModal, { type SwitchSpaceModalRef } from './SwitchSpaceModal';
// Import SVG files
// space
import dashboardIcon from '@/assets/images/menuNew/dashboard.svg';
import dashboardActiveIcon from '@/assets/images/menuNew/dashboard_active.svg';
import applicationIcon from '@/assets/images/menuNew/application.svg';
import applicationActiveIcon from '@/assets/images/menuNew/application_active.svg';
import knowledgeIcon from '@/assets/images/menuNew/knowledge.svg';
import knowledgeActiveIcon from '@/assets/images/menuNew/knowledge_active.svg';
import memoryIcon from '@/assets/images/menuNew/memory.svg';
import memoryActiveIcon from '@/assets/images/menuNew/memory_active.svg';
import userMemoryIcon from '@/assets/images/menuNew/userMemory.svg';
import userMemoryActiveIcon from '@/assets/images/menuNew/userMemory_active.svg';
import memoryConversationIcon from '@/assets/images/menuNew/memoryConversation.svg';
import memoryConversationActiveIcon from '@/assets/images/menuNew/memoryConversation_active.svg';
import apiKeyIcon from '@/assets/images/menuNew/apiKey.svg';
import apiKeyActiveIcon from '@/assets/images/menuNew/apiKey_active.svg';
import applicationIcon from '@/assets/images/menuNew/application.svg';
import applicationActiveIcon from '@/assets/images/menuNew/application_active.svg';
import dashboardIcon from '@/assets/images/menuNew/dashboard.svg';
import dashboardActiveIcon from '@/assets/images/menuNew/dashboard_active.svg';
import knowledgeIcon from '@/assets/images/menuNew/knowledge.svg';
import knowledgeActiveIcon from '@/assets/images/menuNew/knowledge_active.svg';
import memberIcon from '@/assets/images/menuNew/member.svg';
import memberActiveIcon from '@/assets/images/menuNew/member_active.svg';
import ontologyIcon from '@/assets/images/menuNew/ontology.svg'
import ontologyActiveIcon from '@/assets/images/menuNew/ontology_active.svg'
import spaceConfigIcon from '@/assets/images/menuNew/spaceConfig.svg'
import spaceConfigActiveIcon from '@/assets/images/menuNew/spaceConfig_active.svg'
import promptIcon from '@/assets/images/menuNew/prompt.svg'
import promptActiveIcon from '@/assets/images/menuNew/prompt_active.svg'
import memoryIcon from '@/assets/images/menuNew/memory.svg';
import memoryActiveIcon from '@/assets/images/menuNew/memory_active.svg';
import memoryConversationIcon from '@/assets/images/menuNew/memoryConversation.svg';
import memoryConversationActiveIcon from '@/assets/images/menuNew/memoryConversation_active.svg';
import ontologyIcon from '@/assets/images/menuNew/ontology.svg';
import ontologyActiveIcon from '@/assets/images/menuNew/ontology_active.svg';
import promptIcon from '@/assets/images/menuNew/prompt.svg';
import promptActiveIcon from '@/assets/images/menuNew/prompt_active.svg';
import spaceConfigIcon from '@/assets/images/menuNew/spaceConfig.svg';
import spaceConfigActiveIcon from '@/assets/images/menuNew/spaceConfig_active.svg';
import userMemoryIcon from '@/assets/images/menuNew/userMemory.svg';
import userMemoryActiveIcon from '@/assets/images/menuNew/userMemory_active.svg';
// manage
import modelIcon from '@/assets/images/menuNew/model.svg';
import modelActiveIcon from '@/assets/images/menuNew/model_active.svg';
import pricingIcon from '@/assets/images/menuNew/pricing.svg';
import pricingActiveIcon from '@/assets/images/menuNew/pricing_active.svg';
import skillsIcon from '@/assets/images/menuNew/skills.svg';
import skillsActiveIcon from '@/assets/images/menuNew/skills_active.svg';
import spaceIcon from '@/assets/images/menuNew/space.svg';
import spaceActiveIcon from '@/assets/images/menuNew/space_active.svg';
import userIcon from '@/assets/images/menuNew/user.svg';
import userActiveIcon from '@/assets/images/menuNew/user_active.svg';
import toolIcon from '@/assets/images/menuNew/tool.svg';
import toolActiveIcon from '@/assets/images/menuNew/tool_active.svg';
import pricingIcon from '@/assets/images/menuNew/pricing.svg'
import pricingActiveIcon from '@/assets/images/menuNew/pricing_active.svg'
import skillsIcon from '@/assets/images/menuNew/skills.svg'
import skillsActiveIcon from '@/assets/images/menuNew/skills_active.svg'
import userIcon from '@/assets/images/menuNew/user.svg';
import userActiveIcon from '@/assets/images/menuNew/user_active.svg';
export interface PackagePlan {
id: string
@@ -115,7 +115,7 @@ export interface Subscription {
started_at: number | null
expired_at: number | null
status: string
quota: SubscriptionQuota
quotas: SubscriptionQuota
created_at: number
updated_at: number
}
@@ -431,7 +431,7 @@ const Menu: FC<{
<div className="rb:grid rb:grid-cols-4 rb:mt-4">
{['workspace_quota', 'skill_quota', 'app_quota', 'model_quota'].map(key => (
<div key={key} className="rb:text-center">
<div className="rb:text-[13px] rb:font-[MiSans-Semibold] rb:font-semibold">{subscription.quota?.[key as keyof typeof subscription.quota]}</div>
<div className="rb:text-[13px] rb:font-[MiSans-Semibold] rb:font-semibold">{subscription.quotas?.[key as keyof typeof subscription.quotas] ?? t('package.noLimit')}</div>
<div className="rb:mt-1 rb:text-[#5B6167] rb:text-[10px] rb:leading-3.5">{t(`index.${key}`)}</div>
</div>
))}

View File

@@ -451,6 +451,9 @@ export const en = {
logoutApiCannotRefreshToken: 'Logout API cannot refresh token',
publicApiCannotRefreshToken: 'Public API cannot refresh token',
refreshTokenNotExist: 'Refresh token does not exist',
SYSTEM_DEFAULT_SCENE_CANNOT_DELETE: 'This is a system preset scene and cannot be deleted',
SYSTEM_DEFAULT_CLASS_CANNOT_DELETE: 'This scene is a system preset scene and cannot be deleted',
SYSTEM_DEFAULT_SCENE_CANNOT_UPDATE: 'This scene is a system preset scene and cannot be modified',
reset: 'Reset',
refresh: 'Refresh',
return: 'Return',
@@ -2543,6 +2546,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
checkListErrors: {
'llm.model_id': 'Model',
'llm.messages': 'Messages',
'llm.vision_input': 'Vision Variable',
'end.output': 'Output',
'knowledge-retrieval.knowledge_retrieval': 'Knowledge bases',
'parameter-extractor.model_id': 'Model',
@@ -2571,6 +2575,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
'document-extractor.file_selector': 'File variable',
'list-operator.input_list': 'Input list',
'output.outputs': 'Output Variable',
'tool.tool_id': 'Tool',
},
checkListHasErrors: 'Please resolve all issues in the checklist before publishing',
variableSelect: {
@@ -3104,6 +3109,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
editPackage: 'Edit Package',
viewDetail: 'View full package details',
noLimit: 'Infinite',
},
},
};

View File

@@ -66,13 +66,13 @@ export const zh = {
goConfig: '去配置',
},
indexTour:{
startTitle:'欢迎来到 Memory Bear 👋',
startDescription:'不知道从哪里开始?不妨先去 Model Management 看看,先把模型准备好,后面的操作会更顺畅。👉 点击左侧 Model Management 开始吧。',
stepOne: '这里是 Model Management',
stepOneDescription: '你可以在这里查看和配置可用的模型,为后续应用做好准备。模型准备好后,下一步去 Space Management 创建空间并开始使用吧。👉 点击左侧 Space Management 继续。',
stepTwo: '这里是 Space Management',
stepTwoDescription: '你可以在这里创建和管理不同的空间,把模型和数据组织到具体的使用场景中。空间创建完成后,可以去 User Management 邀请成员、分配权限,一起协作使用。👉 点击左侧 User Management 继续。',
stepThree: '这里是用户管理',
startTitle:'欢迎来到 记忆熊 👋',
startDescription:'不知道从哪里开始?不妨先去 模型管理 看看,先把模型准备好,后面的操作会更顺畅。👉 点击左侧 模型管理 开始吧。',
stepOne: '这里是 模型管理',
stepOneDescription: '你可以在这里查看和配置可用的模型,为后续应用做好准备。模型准备好后,下一步去 空间管理 创建空间并开始使用吧。👉 点击左侧 空间管理 继续。',
stepTwo: '这里是 空间管理',
stepTwoDescription: '你可以在这里创建和管理不同的空间,把模型和数据组织到具体的使用场景中。空间创建完成后,可以去 用户管理 邀请成员、分配权限,一起协作使用。👉 点击左侧 用户管理 继续。',
stepThree: '这里是 用户管理',
stepThreeDescription: '你可以在这里创建用户、分配角色,并管理团队成员的访问权限。完成用户设置后,基础配置就准备好了,可以开始实际使用平台的各项功能了 🎉',
finishButtonText: '开始使用',
},
@@ -1130,6 +1130,9 @@ export const zh = {
logoutApiCannotRefreshToken: '退出登录接口不能刷新token',
publicApiCannotRefreshToken: '公共接口不能刷新token',
refreshTokenNotExist: '刷新token不存在',
SYSTEM_DEFAULT_SCENE_CANNOT_DELETE: '该场景为系统预设场景,不允许删除',
SYSTEM_DEFAULT_CLASS_CANNOT_DELETE: '该场景为系统预设场景,不允许删除',
SYSTEM_DEFAULT_SCENE_CANNOT_UPDATE: '该场景为系统预设场景,不允许修改',
reset: '重置',
refresh: '刷新',
return: '返回',
@@ -2507,6 +2510,7 @@ export const zh = {
checkListErrors: {
'llm.model_id': '模型',
'llm.messages': '提示词',
'llm.vision_input': '视觉变量',
'end.output': '回复',
'knowledge-retrieval.knowledge_retrieval': '知识库',
'parameter-extractor.model_id': '模型',
@@ -2535,6 +2539,7 @@ export const zh = {
'document-extractor.file_selector': '文件变量',
'list-operator.input_list': '输入变量',
'output.outputs': '输出变量',
'tool.tool_id': '工具',
},
checkListHasErrors: '发布前确认检查清单中所有问题均已解决',
variableSelect: {
@@ -3068,6 +3073,7 @@ export const zh = {
editPackage: '编辑套餐',
viewDetail: '查看完整套餐详情',
noLimit: '无限',
},
},
}

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-02 16:35:43
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-21 14:20:39
* @Last Modified time: 2026-04-22 10:16:43
*/
/**
* Server-Sent Events (SSE) Stream Utility Module
@@ -16,11 +16,11 @@
* @module stream
*/
import { refreshToken } from '@/api/user';
import i18n from '@/i18n';
import { message } from 'antd';
import i18n from '@/i18n'
import { cookieUtils } from './request'
import { refreshToken } from '@/api/user'
import { clearAuthData } from './auth'
import { clearAuthData } from './auth';
import { cookieUtils } from './request';
const API_PREFIX = '/api'
// Token refresh state
@@ -181,12 +181,12 @@ export const handleSSE = async (url: string, data: any, onMessage?: (data: SSEMe
case 500:
case 502:
const errorData = await response.json();
const errorInfo = errorData.error || i18n.t('common.serviceUpgrading');
const errorInfo = errorData.error || errorData.msg || i18n.t('common.serviceUpgrading');
message.warning(errorInfo);
throw new Error(errorData);
case 400:
const error = await response.json();
const error400 = error.error || 'Bad Request';
const error400 = error.error || error.msg || 'Bad Request';
message.warning(error400);
throw new Error(error);
case 403:
@@ -195,7 +195,7 @@ export const handleSSE = async (url: string, data: any, onMessage?: (data: SSEMe
throw new Error(errors);
case 504:
const errorJson = await response.json();
const errorMsg = errorJson.error || i18n.t('common.serverError');
const errorMsg = errorJson.error || errorJson.msg || i18n.t('common.serverError');
message.warning(errorMsg);
throw new Error(errorJson);
case 401:
@@ -209,6 +209,13 @@ export const handleSSE = async (url: string, data: any, onMessage?: (data: SSEMe
return;
}
break;
default:
if (!response.ok) {
const defaultData = await response.json().catch(() => ({}));
const defaultMsg = defaultData.error || defaultData.msg;
if (defaultMsg) message.warning(defaultMsg);
throw new Error(defaultMsg || `HTTP ${response.status}`);
}
}
if (!response.body) throw new Error('No response body');

View File

@@ -62,7 +62,6 @@ const Agent = forwardRef<AgentRef, { onFeaturesLoad?: (features: FeaturesConfigF
const { id } = useParams();
const { message } = App.useApp()
const [form] = Form.useForm()
const [loading, setLoading] = useState(false)
const [data, setData] = useState<Config | null>(null);
const modelConfigModalRef = useRef<ModelConfigModalRef>(null)
const [modelList, setModelList] = useState<Model[]>([])
@@ -94,7 +93,6 @@ const Agent = forwardRef<AgentRef, { onFeaturesLoad?: (features: FeaturesConfigF
* Fetch agent configuration data
*/
const getData = () => {
setLoading(true)
getApplicationConfig(id as string).then(res => {
const response = res as Config
const { skills, variables } = response
@@ -127,8 +125,6 @@ const Agent = forwardRef<AgentRef, { onFeaturesLoad?: (features: FeaturesConfigF
tools: allTools
})
onFeaturesLoad?.(response.features)
}).finally(() => {
setLoading(false)
})
}
@@ -421,7 +417,6 @@ const Agent = forwardRef<AgentRef, { onFeaturesLoad?: (features: FeaturesConfigF
console.log('agent values', values)
return (
<>
{loading && <Spin fullscreen></Spin>}
<Row className="rb:h-full!" gutter={12}>
<Col span={12} className="rb:h-full!">
<Form form={form}>

View File

@@ -68,7 +68,7 @@ const Chat: FC<ChatProps> = ({
const [loading, setLoading] = useState(false)
const [isCluster, setIsCluster] = useState(source === 'multi_agent')
const [conversationId, setConversationId] = useState<string | null>(null)
const [compareLoading, setCompareLoading] = useState(false)
const compareLoadingRef = useRef(false)
const [fileList, setFileList] = useState<any[]>([])
const [message, setMessage] = useState<string | undefined>(undefined)
const [features, setFeatures] = useState<FeaturesConfigForm>({} as FeaturesConfigForm)
@@ -76,7 +76,7 @@ const Chat: FC<ChatProps> = ({
const abortRef = useRef<(() => void) | null>(null)
useEffect(() => {
setCompareLoading(false)
compareLoadingRef.current = false
setLoading(false)
return () => {
abortRef.current?.()
@@ -259,7 +259,7 @@ const Chat: FC<ChatProps> = ({
const handleSend = (msg?: string) => {
if (loading || !id) return
setLoading(true)
setCompareLoading(true)
compareLoadingRef.current = true
const files = (fileList || []).filter(item => !['uploading', 'error'].includes(item.status))
handleSave(false)
.then(() => {
@@ -285,7 +285,7 @@ const Chat: FC<ChatProps> = ({
}
if (!isCanSend) {
setLoading(false)
setCompareLoading(false)
compareLoadingRef.current = false
return
}
@@ -310,20 +310,20 @@ const Chat: FC<ChatProps> = ({
switch (item.event) {
case 'model_reasoning':
if (compareLoading) {
setCompareLoading(false)
if (compareLoadingRef.current) {
compareLoadingRef.current = false
}
updateAssistantReasoningMessage(content, model_config_id, conversation_id)
break;
case 'model_message':
if (compareLoading) {
setCompareLoading(false)
if (compareLoadingRef.current) {
compareLoadingRef.current = false
}
updateAssistantMessage(content, model_config_id, conversation_id, audio_url)
break;
case 'model_end':
if (compareLoading) {
setCompareLoading(false)
if (compareLoadingRef.current) {
compareLoadingRef.current = false
}
const idToPoll = `${model_config_id}_${audio_url}`
if (audio_url && !audioStatusMap[idToPoll]) {
@@ -365,8 +365,8 @@ const Chat: FC<ChatProps> = ({
updateErrorAssistantMessage(message_length, model_config_id)
break;
case 'compare_end':
if (compareLoading) {
setCompareLoading(false)
if (compareLoadingRef.current) {
compareLoadingRef.current = false
}
setLoading(false);
break;
@@ -401,18 +401,18 @@ const Chat: FC<ChatProps> = ({
}, handleStreamMessage, (abort) => { abortRef.current = abort })
.catch(() => {
setLoading(false)
setCompareLoading(false)
compareLoadingRef.current = false
updateClusterErrorAssistantMessage(0)
})
.finally(() => {
setLoading(false)
setCompareLoading(false)
compareLoadingRef.current = false
})
}, 0)
})
.catch(() => {
setLoading(false)
setCompareLoading(false)
compareLoadingRef.current = false
})
}
@@ -476,7 +476,7 @@ const Chat: FC<ChatProps> = ({
const handleClusterSend = (msg?: string) => {
if (loading || !id) return
setLoading(true)
setCompareLoading(true)
compareLoadingRef.current = true
const files = (fileList || []).filter(item => !['uploading', 'error'].includes(item.status))
handleSave(false)
.then(() => {
@@ -500,8 +500,8 @@ const Chat: FC<ChatProps> = ({
}
break
case 'message':
if (compareLoading) {
setCompareLoading(false)
if (compareLoadingRef.current) {
compareLoadingRef.current = false
}
updateClusterAssistantMessage(content)
if (conversation_id && conversationId !== conversation_id) {
@@ -509,14 +509,14 @@ const Chat: FC<ChatProps> = ({
}
break;
case 'model_end':
if (compareLoading) {
setCompareLoading(false)
if (compareLoadingRef.current) {
compareLoadingRef.current = false
}
updateClusterErrorAssistantMessage(message_length)
break;
case 'compare_end':
if (compareLoading) {
setCompareLoading(false)
if (compareLoadingRef.current) {
compareLoadingRef.current = false
}
setLoading(false);
break;
@@ -547,18 +547,18 @@ const Chat: FC<ChatProps> = ({
)
.catch(() => {
setLoading(false)
setCompareLoading(false)
compareLoadingRef.current = false
updateClusterErrorAssistantMessage(0)
})
.finally(() => {
setLoading(false)
setCompareLoading(false)
compareLoadingRef.current = false
})
}, 0)
})
.catch(() => {
setLoading(false)
setCompareLoading(false)
compareLoadingRef.current = false
})
}
@@ -628,7 +628,7 @@ const Chat: FC<ChatProps> = ({
/>}
onSend={isCluster ? handleClusterSend : handleSend}
data={chat.list || []}
streamLoading={compareLoading}
streamLoading={compareLoadingRef.current}
labelPosition="top"
labelFormat={(item) => item.role === 'user' ? t('application.you') : chat.label || t(`application.ai`)}
errorDesc={t('application.ReplyException')}

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-03 16:25:32
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-04 10:34:43
* @Last Modified time: 2026-04-21 13:34:52
*/
/**
* Knowledge Base Component
@@ -54,7 +54,7 @@ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfi
const basesWithoutName = knowledge_bases.filter(base => !base.name)
if (basesWithoutName.length > 0) {
// Call API to get complete knowledge base information
getKnowledgeBaseList().then(res => {
getKnowledgeBaseList(undefined, { kb_ids: basesWithoutName.map(vo => vo.kb_id).join(',') }).then(res => {
const fullBases = knowledge_bases.map(base => {
if (!base.name) {
const fullBase = res.items.find((item: any) => item.id === base.kb_id)

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-28 14:08:14
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-13 18:17:32
* @Last Modified time: 2026-04-20 16:52:32
*/
/**
* UploadModal Component
@@ -16,6 +16,7 @@
import { forwardRef, useImperativeHandle, useState, useMemo } from 'react';
import { Form, Steps, Flex, Alert, Button, Result, message } from 'antd';
import { useTranslation } from 'react-i18next';
import { useNavigate } from 'react-router-dom';
import type { Application, UploadModalRef } from '../types'
import RbModal from '@/components/RbModal'
@@ -51,6 +52,7 @@ const UploadModal = forwardRef<UploadModalRef, UploadModalProps>(({
id
}, ref) => {
const { t } = useTranslation();
const navigate = useNavigate();
// State management
const [visible, setVisible] = useState(false); // Modal visibility
@@ -146,6 +148,10 @@ const UploadModal = forwardRef<UploadModalRef, UploadModalProps>(({
window.open(`/#/application/config/${appId}`, '_blank');
}
break;
case 'list':
if (id) {
navigate('/application')
}
}
}, 100)
};

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-03 17:09:03
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-31 12:21:56
* @Last Modified time: 2026-04-20 16:59:25
*/
/**
* Memory Conversation Page
@@ -78,8 +78,8 @@ interface DataItem {
id: string;
question: string;
type: string;
reason: string;
}
reason?: string;
}
/**
* Log item for conversation analysis
*/
@@ -88,13 +88,15 @@ export interface LogItem {
title: string;
data?: DataItem[] | AnyObject;
raw_results?: string | Record<string, AnyObject>;
raw_result?: Array<AnyObject>;
summary?: string;
query?: string;
reason?: string;
result?: string;
original_query: string;
original_query?: string;
index?: number;
result_count?: number;
total?: number;
}
/**
@@ -242,7 +244,6 @@ const MemoryConversation: FC = () => {
<ContentWrapper key={vo.id}>
<>
<div className="rb:font-medium rb:text-[#212332]">{vo.id}. {vo.question}</div>
<div className="rb:mt-2 rb:text-[#5B6167]">{vo.reason}</div>
</>
</ContentWrapper>
))}
@@ -260,25 +261,9 @@ const MemoryConversation: FC = () => {
</ContentWrapper>
))}
</Flex>
: log.type === 'search_result' && log.raw_results && typeof log.raw_results !== 'string'
: log.type === 'search_result' && log.result
? <ContentWrapper>
<div className="rb:font-medium rb:text-[#212332] rb:mb-2">{log.query}</div>
{(log.raw_results.reranked_results as AnyObject)?.communities?.length > 0 && <>
<div className="rb:font-medium rb:text-[#212332]">{t('memoryConversation.communities')}</div>
<ul className='rb:mt-2 rb:text-[#5B6167] rb:list-disc rb:pl-4'>
{((log.raw_results.reranked_results as AnyObject)?.communities as { content: string }[]).map((item, index: number) => (
<li key={index}>{item.content}</li>
))}
</ul>
</>}
{(log.raw_results.reranked_results as AnyObject)?.summaries?.length > 0 && <>
<div className="rb:font-medium rb:text-[#212332]">{t('memoryConversation.summaries')}</div>
<ul className='rb:mt-2 rb:text-[#5B6167] rb:list-disc rb:pl-4'>
{((log.raw_results.reranked_results as AnyObject)?.summaries as { content: string }[]).map((item, index: number) => (
<li key={index}>{item.content}</li>
))}
</ul>
</>}
<Markdown content={log.result} />
</ContentWrapper>
: log.type === 'retrieval_summary' && log.summary
? <ContentWrapper>

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-03 16:49:28
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-16 18:03:53
* @Last Modified time: 2026-04-21 15:02:53
*/
/**
* Custom Model Modal
@@ -230,21 +230,23 @@ const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(
<Input.TextArea placeholder={t('common.pleaseEnter')} />
</Form.Item>
<Form.Item
name={["api_keys", 0, "api_key"]}
label={t('modelNew.api_key')}
rules={[{ required: true, message: t('common.inputPlaceholder', { title: t('modelNew.api_key') }) }]}
>
<Input.Password placeholder={t('common.pleaseEnter')} />
</Form.Item>
{!isEdit && <>
<Form.Item
name={["api_keys", 0, "api_key"]}
label={t('modelNew.api_key')}
rules={[{ required: true, message: t('common.inputPlaceholder', { title: t('modelNew.api_key') }) }]}
>
<Input.Password placeholder={t('common.pleaseEnter')} />
</Form.Item>
<Form.Item
name={["api_keys", 0, "api_base"]}
label={t('modelNew.api_base')}
rules={[{ required: true, message: t('common.inputPlaceholder', { title: t('modelNew.api_base') }) }]}
>
<Input placeholder="https://api.example.com/v1" />
</Form.Item>
<Form.Item
name={["api_keys", 0, "api_base"]}
label={t('modelNew.api_base')}
rules={[{ required: true, message: t('common.inputPlaceholder', { title: t('modelNew.api_base') }) }]}
>
<Input placeholder="https://api.example.com/v1" />
</Form.Item>
</>}
{['llm', 'chat'].includes(modelType as string) &&
<Row gutter={16}>

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-04-14 11:43:57
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-14 11:44:40
* @Last Modified time: 2026-04-21 15:44:13
*/
export const billingUnits = [
{
@@ -42,7 +42,7 @@ export const billingUnits = [
},
{
key: 'model_quota',
unit: 'ops', placeholder: 'numberPlaceholder',
unit: 'pcs', placeholder: 'numberPlaceholder',
icon: 'model',
},
{

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-04-14 11:34:42
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-16 17:23:49
* @Last Modified time: 2026-04-21 15:45:30
*/
/**
* Package Component
@@ -60,7 +60,7 @@ const btnClassNames = {
default: 'rb:h-10! rb:rounded-[8px]! rb:bg-[#212332]! rb:text-white! rb:border-0! rb:hover:border-0! rb:hover:opacity-[0.8]',
}
export const UnitWrapper = ({ titleKey, value, icon, unit, theme_color = '#171719' }: { titleKey: string; value: number | string; icon: string; unit?: string; theme_color?: string; }) => {
export const UnitWrapper = ({ titleKey, value, icon, unit, theme_color = '#171719' }: { titleKey: string; value?: number | string | null; icon: string; unit?: string; theme_color?: string; }) => {
const { t } = useTranslation();
const renderFeatureIcon = (iconKey: string, color: string) => {
@@ -78,7 +78,7 @@ export const UnitWrapper = ({ titleKey, value, icon, unit, theme_color = '#17171
>{renderFeatureIcon(icon, theme_color)}</Flex>
<div className="rb:text-[13px] rb:leading-4.5">
<div className="rb:text-[#5F6266]">{t(`package.${titleKey}`)}</div>
<div>{value} {unit ? t(`package.${unit}`) : ''}</div>
{value ? <div>{value} {unit ? t(`package.${unit}`) : ''}</div> : <div>{t('package.noLimit')}</div>}
</div>
</Flex>
)
@@ -252,7 +252,6 @@ const Package: FC = () => {
>
{billingUnits.map(({ key, unit, icon }) => {
const value = pkg?.quotas?.[key as keyof Package['quotas']];
if (value === undefined || value === null) return null;
return (
<UnitWrapper
key={key}
@@ -264,7 +263,7 @@ const Package: FC = () => {
/>
)
})}
{pkg.tech_support && (
{pkg.tech_support && pkg[getKeyWithLanguage('tech_support')] && (
<UnitWrapper
titleKey="tech_support"
value={String(pkg[getKeyWithLanguage('tech_support')] ?? '')}
@@ -272,7 +271,7 @@ const Package: FC = () => {
theme_color={pkg.theme_color}
/>
)}
{pkg.sla_compliance && (
{pkg.sla_compliance && pkg[getKeyWithLanguage('sla_compliance')] && (
<UnitWrapper
titleKey="sla"
value={String(pkg[getKeyWithLanguage('sla_compliance')] ?? '')}

View File

@@ -31,6 +31,13 @@ interface EndUserProfileProps {
className?: string;
}
/**
 * Render a profile field for display.
 *
 * - A list is joined with ' | '; an empty list yields the '-' placeholder.
 * - A non-empty string is returned unchanged.
 * - null, undefined and the empty string yield the '-' placeholder.
 */
const formatValue = (value: string | string[] | null | undefined) => {
  const placeholder = '-'
  if (Array.isArray(value)) {
    return value.length > 0 ? value.join(' | ') : placeholder
  }
  return value ? value : placeholder
}
const EndUserProfile = forwardRef<EndUserProfileRef, EndUserProfileProps>(({ className, onDataLoaded }, ref) => {
const { t } = useTranslation()
const { id } = useParams()
@@ -89,19 +96,19 @@ const EndUserProfile = forwardRef<EndUserProfileRef, EndUserProfileProps>(({ cla
</div>
<div>
<div className="rb:text-[#7B8085]">{t('userMemory.role')}</div>
<div className="rb:mt-0.5">{data?.profile?.role?.join(' | ') || '-'}</div>
<div className="rb:mt-0.5">{formatValue(data?.profile?.role)}</div>
</div>
<div>
<div className="rb:text-[#7B8085]">{t('userMemory.domain')}</div>
<div className="rb:mt-0.5">{data?.profile?.domain?.join(' | ') || '-'}</div>
<div className="rb:mt-0.5">{formatValue(data?.profile?.domain)}</div>
</div>
<div>
<div className="rb:text-[#7B8085]">{t('userMemory.expertise')}</div>
<div className="rb:mt-0.5">{data?.profile?.expertise?.join(' | ') || '-'}</div>
<div className="rb:mt-0.5">{formatValue(data?.profile?.expertise)}</div>
</div>
<div>
<div className="rb:text-[#7B8085]">{t('userMemory.interests')}</div>
<div className="rb:mt-0.5">{data?.profile?.interests?.join(' | ') || '-'}</div>
<div className="rb:mt-0.5">{formatValue(data?.profile?.interests)}</div>
</div>
<div className="rb:text-[#7B8085] rb:text-[12px] rb:leading-4.5">

View File

@@ -178,7 +178,7 @@ export interface EndUser {
created_at: string;
updated_at: string;
profile: {
role: string[];
role: string[] | string;
domain: string[];
expertise: string[];
interests: string[];

View File

@@ -1,3 +1,9 @@
/*
* @Author: ZhaoYing
* @Date: 2026-04-09 18:58:21
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-20 10:39:17
*/
import { useState, useCallback, useEffect, useRef, type FC } from 'react'
import { Popover, Flex } from 'antd'
import { WarningFilled } from '@ant-design/icons'
@@ -49,7 +55,7 @@ const specialValidators: Record<string, (val: any) => boolean> = {
if (expr?.sub_variable_condition?.conditions?.length > 0) return expr.sub_variable_condition?.conditions.every(isSubExprSet)
if (!expr.left) return false
if (['not_empty', 'empty'].includes(expr.operator)) return true
return !!expr.left && (!!expr.right || typeof expr.right === 'boolean' || typeof expr.right === 'number')
return !!expr.left && (expr?.sub_variable_condition || !!expr.right || typeof expr.right === 'boolean' || typeof expr.right === 'number')
}
return val.some(c => !c?.expressions?.length || c.expressions.some((expr: any) => !isExprSet(expr)))
},
@@ -100,6 +106,18 @@ function validateNode(type: string, config: Record<string, any>): CheckError[] {
if (isInvalid) errors.push({ key: specialKey, message: '' })
})
// llm: vision_input required when vision is enabled
if (type === 'llm') {
const vision = get('vision')
if (vision === true || vision === 'true') {
const visionInput = get('vision_input')
console.log('vision', vision, isEmpty(visionInput))
if (isEmpty(visionInput)) {
errors.push({ key: 'llm.vision_input', message: '' })
}
}
}
// http-request body.data (binary) — not a top-level required field, check separately
if (type === 'http-request') {
const body = get('body')

View File

@@ -10,7 +10,7 @@ import { useVariableList } from '../Properties/hooks/useVariableList'
import { isSubExprSet } from '../../utils'
import { fileSubFieldOperators } from '../Properties/CaseList'
const caculateIsSet = (item: any, type: string) => {
const calculateIsSet = (item: any, type: string) => {
switch (type) {
case 'categories':
return typeof item?.class_name === 'string' && item?.class_name !== ''
@@ -90,7 +90,7 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => {
<div key={index} className="rb:bg-[#F0F3F8] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)] rb:rounded-md rb:py-1 rb:px-1.5 rb:text-[10px] rb:text-[#5B6167] rb:font-medium rb:leading-3.5">
<Flex justify="space-between">
<span>{t('workflow.config.question-classifier.class_name')} {index + 1}</span>
{caculateIsSet(item, 'categories') ? t(`workflow.config.${data.type}.set`) : t(`workflow.config.${data.type}.unset`)}
{calculateIsSet(item, 'categories') ? t(`workflow.config.${data.type}.set`) : t(`workflow.config.${data.type}.unset`)}
</Flex>
</div>
))}
@@ -100,17 +100,24 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => {
<Flex vertical gap={4} className="rb:mt-3!">
{data.config?.cases?.defaultValue.map((item: any, index: number) => (
<div key={index} className={item.expressions.length > 0 ? '' : 'rb:mb-1'}>
<Flex justify={item.expressions.length > 0 ? "space-between" : 'end'} className="rb:mb-1">
{item.expressions.length > 0 && <span className="rb:text-[#5B6167] rb:text-[10px]">CASE{index + 1}</span>}
<Flex justify={item.expressions.length > 0 ? "space-between" : 'end'} className="rb:mb-1! rb:leading-4">
{item.expressions.length > 0 && <span className="rb:text-[#5B6167] rb:text-[10px] rb:pl-1">CASE{index + 1}</span>}
<span className="rb:text-[#212332] rb:font-medium rb:text-[12px]">{index === 0 ? 'IF' : `ELIF`}</span>
</Flex>
{item.expressions.length > 0 && <Flex vertical gap={2}>
{item.expressions.map((expression: any, eIndex: number) => (
<div key={eIndex} className="rb:relative">
{item.expressions.length > 1 && eIndex > 0 && <div className="rb:absolute rb:-top-2 rb:right-2 rb:text-[10px] rb:text-[#155EEF] rb:font-medium rb:leading-3.5 rb:text-right rb:pr-0.5">{item.logical_operator?.toLocaleUpperCase()}</div>}
<Flex vertical gap={2} className="rb:bg-[#F0F3F8] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)] rb:rounded-md rb:py-1! rb:px-1.5! rb:text-[10px] rb:text-[#5B6167] rb:font-medium rb:leading-3.5">
{item.expressions.length > 1 && eIndex > 0 &&
<div className="rb:absolute rb:-top-2 rb:right-2 rb:text-[10px] rb:text-[#155EEF] rb:font-medium rb:leading-3.5 rb:text-right rb:pr-0.5">{item.logical_operator?.toLocaleUpperCase()}</div>
}
<Flex vertical gap={2}
className={clsx("rb:bg-[#F0F3F8] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)] rb:rounded-md rb:px-1.5! rb:text-[10px] rb:text-[#5B6167] rb:font-medium rb:leading-4", {
'rb:pt-1!': expression.sub_variable_condition?.conditions?.length > 0,
'rb:py-1!': !expression.sub_variable_condition?.conditions || !expression.sub_variable_condition?.conditions?.length
})}
>
<Flex align="center">
{caculateIsSet(expression, 'cases')
{calculateIsSet(expression, 'cases')
? <>
{labelRender(expression.left)}
<span className="rb:mx-1">{getLocaleField(expression.operator, typeof expression.right)}</span>
@@ -120,11 +127,16 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => {
}
</Flex>
{expression.sub_variable_condition?.conditions?.length > 0 && expression.sub_variable_condition?.conditions.every(isSubExprSet)
? <div className="rb-border-l rb:ml-2 rb:mt-1.5">
? <div className="rb-border-l rb:ml-2 rb:mt-1">
{expression.sub_variable_condition?.conditions.map((sub: any, sIndex: number) => (
<div key={sIndex} className="rb:relative">
{expression.sub_variable_condition?.conditions.length > 1 && sIndex > 0 && <div className="rb:absolute rb:-top-2 rb:right-2 rb:text-[10px] rb:text-[#155EEF] rb:font-medium rb:leading-3.5 rb:text-right rb:pr-0.5">{expression.sub_variable_condition?.logical_operator?.toLocaleUpperCase()}</div>}
<Flex align="center" className=" rb:py-1! rb:px-1.5! rb:text-[10px] rb:text-[#5B6167] rb:font-medium rb:leading-3.5">
<Flex align="center"
className={clsx("rb:px-1.5! rb:text-[10px] rb:text-[#5B6167] rb:font-medium rb:leading-3.5", {
'rb:py-1!': sIndex !== 0,
'rb:pb-1': sIndex === 0
})}
>
<span className="rb:text-[#155EEF]">{sub.key}</span>
<span className="rb:mx-1">{getSubLocaleField(sub.operator, sub.key)}</span>
<span className="rb:break-all rb:line-clamp-1">
@@ -140,7 +152,7 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => {
))}
</div>
: expression.sub_variable_condition?.conditions?.length > 0
? <Flex align="center" className="rb:mt-1! rb:pl-2! rb:rounded-md rb:py-1! rb:px-1.5! rb:text-[10px] rb:text-[#5B6167] rb:font-medium rb:leading-3.5">
? <Flex align="center" className="rb:pl-2! rb:rounded-md rb:pb-1! rb:px-1.5! rb:text-[10px] rb:text-[#5B6167] rb:font-medium rb:leading-4">
{t(`workflow.config.${data.type}.unset`)}
</Flex>
: null

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-09 18:24:53
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-17 20:47:49
* @Last Modified time: 2026-04-20 10:46:05
*/
import { useEffect, useMemo, type FC } from 'react'
import clsx from 'clsx'
@@ -39,7 +39,7 @@ interface Expression {
sub_variable_condition?: SubVariableCondition;
}
interface CaseItem {
export interface CaseItem {
logical_operator: 'and' | 'or';
expressions: Expression[];
}
@@ -274,7 +274,9 @@ const ArrayFileSubConditions: FC<ArrayFileSubConditionsProps> = ({ conditionFiel
className="rb:w-full!"
suffix="Byte"
size="small"
onChange={(value) => { form.setFieldValue([name, caseIndex, 'expressions', conditionIndex, 'right'], value); }}
onChange={(value) => {
form.setFieldValue([name, caseIndex, 'expressions', conditionIndex, 'sub_variable_condition', 'conditions', subIndex, 'value'], value);
}}
/>
}
</Form.Item>
@@ -483,13 +485,24 @@ const CaseList: FC<CaseListProps> = ({
form.setFieldValue([name, index, 'logical_operator'], currentValue === 'and' ? 'or' : 'and');
};
const handleLeftFieldChange = (caseIndex: number, conditionIndex: number, newValue: string) => {
form.setFieldValue([name, caseIndex, 'expressions', conditionIndex], {
left: newValue,
operator: undefined,
right: undefined,
input_type: 'constant'
});
const handleLeftFieldChange = (caseIndex: number, conditionIndex: number, newValue: string, option?: Suggestion | undefined) => {
if (option?.dataType === 'array[file]') {
form.setFieldValue([name, caseIndex, 'expressions', conditionIndex], {
left: newValue,
operator: undefined,
sub_variable_condition: {
conditions: [],
logical_operator: 'and'
}
});
} else {
form.setFieldValue([name, caseIndex, 'expressions', conditionIndex], {
left: newValue,
operator: undefined,
right: undefined,
input_type: 'constant'
});
}
};
const handleAddCase = (addCaseFunc: Function) => {
@@ -590,7 +603,7 @@ const CaseList: FC<CaseListProps> = ({
options={options}
size="small"
allowClear={false}
onChange={(val) => handleLeftFieldChange(caseIndex, conditionIndex, val as string)}
onChange={(val, option) => handleLeftFieldChange(caseIndex, conditionIndex, val as string, option as unknown as Suggestion)}
variant="borderless"
className="rb:w-36!"
/>

View File

@@ -29,12 +29,13 @@ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfi
if (value && JSON.stringify(value) !== JSON.stringify(editConfig)) {
setEditConfig({ ...(value || {}) })
const knowledge_bases = [...(value.knowledge_bases || [])]
setKnowledgeList(knowledge_bases)
// 检查是否有knowledge_bases缺少name字段
const basesWithoutName = knowledge_bases.filter(base => !base.name)
if (basesWithoutName.length > 0) {
// 调用接口获取完整的知识库信息
getKnowledgeBaseList().then(res => {
getKnowledgeBaseList(undefined, { kb_ids: basesWithoutName.map(vo => vo.kb_id).join(',') }).then(res => {
const fullBases = knowledge_bases.map(base => {
if (!base.name) {
const fullBase = res.items.find((item: any) => item.id === base.kb_id)

View File

@@ -1,4 +1,4 @@
import { type FC, useEffect, useState, useMemo } from "react";
import { type FC, useEffect, useState } from "react";
import { useTranslation } from 'react-i18next'
import { Form, Select, Switch, Cascader, type CascaderProps, Tooltip } from 'antd'
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
@@ -45,15 +45,15 @@ const ToolConfig: FC<{ options: Suggestion[]; }> = ({
getToolDetail(values.tool_id)
.then(res => {
const detail = res as { tool_type: ToolType; }
getTools({ tool_type: detail.tool_type })
.then(toolsRes => {
const tools = toolsRes as ToolItem[]
getToolMethods(values.tool_id)
.then(methodsRes => {
const response = methodsRes as Array<{ method_id: string; name: string; parameters: Parameter[] }>
setOptionList(prevList => {
return prevList.map(item => {
if (item.value === detail.tool_type) {
@@ -76,7 +76,7 @@ const ToolConfig: FC<{ options: Suggestion[]; }> = ({
return item
})
})
if (response.length > 1) {
const filterTarget = response.find(vo => vo.name === values.tool_parameters?.operation)
if (filterTarget) {
@@ -98,7 +98,7 @@ const ToolConfig: FC<{ options: Suggestion[]; }> = ({
useEffect(() => {
if (values.tools && values.tools.length === 3) {
const [toolType, toolId, operation] = values.tools
// 从 optionList 中查找对应的参数
const typeOption = optionList.find(opt => opt.value === toolType)
if (typeOption?.children) {
@@ -147,21 +147,26 @@ const ToolConfig: FC<{ options: Suggestion[]; }> = ({
};
const handleChange: CascaderProps<Option>['onChange'] = (value, selectedOptions) => {
if (!value) {
setParameters([])
form.resetFields()
return
}
const targetOption = selectedOptions[selectedOptions.length - 1];
const curParameters = [...(targetOption.parameters ?? [])]
setParameters([...curParameters])
const inititalValue: any = { tool_id: selectedOptions[1].value, tool_parameters: {} }
const initialValue: any = { tool_id: selectedOptions[1].value, tool_parameters: { operation: undefined } }
if (value[0] === 'mcp' || (value[0] === 'builtin' && selectedOptions[1]?.children && selectedOptions[1].children.length > 1)) {
inititalValue.tool_parameters.operation = value?.[2]
initialValue.tool_parameters.operation = value?.[2]
} else if (value[0] === 'custom') {
inititalValue.tool_parameters.operation = selectedOptions?.[2].method_id
initialValue.tool_parameters.operation = selectedOptions?.[2].method_id
}
curParameters.forEach(vo => {
inititalValue.tool_parameters[vo.name] = vo.default
initialValue.tool_parameters[vo.name] = vo.default
})
form.setFieldsValue(inititalValue)
form.setFieldsValue(initialValue)
}
// string -> string
@@ -209,9 +214,9 @@ const ToolConfig: FC<{ options: Suggestion[]; }> = ({
name="tools"
label={t('workflow.config.tool.tool_id')}
>
<Cascader
<Cascader
placeholder={t('common.pleaseSelect')}
options={optionList}
options={optionList}
loadData={loadData}
onChange={handleChange}
changeOnSelect={false}
@@ -239,8 +244,8 @@ const ToolConfig: FC<{ options: Suggestion[]; }> = ({
{parameter.type === 'string' && parameter.enum && parameter.enum.length > 0
? <Select size="small" options={parameter.enum.map(vo => ({ value: vo, label: vo }))} placeholder={t('common.pleaseSelect')} />
: parameter.type === 'boolean'
? <Switch size="small" />
: <Editor
? <Switch size="small" />
: <Editor
variant="outlined"
type="input"
size="small"

View File

@@ -393,18 +393,19 @@ export const useVariableList = (
// Add chat variables
chatVariables?.forEach(v => addVariable(list, keys, `CONVERSATION_${v.name}`, v.name, v.type, `conv.${v.name}`, { type: 'CONVERSATION', name: 'CONVERSATION', icon: '' }, { group: 'CONVERSATION' }));
// Process each relevant node: non-list-operator first, then list-operator
const listOperatorIds: string[] = [];
// Process each relevant node: deferred types last (they depend on prior variables)
const deferredIds: string[] = [];
relevantIds.forEach(id => {
const node = nodes.find(n => n.id === id);
if (!node) return;
if (node.getData()?.type === 'list-operator') {
listOperatorIds.push(id);
const t = node.getData()?.type;
if (['var-aggregator', 'list-operator', 'iteration'].includes(t)) {
deferredIds.push(id);
} else {
processNodeVariables(node.getData(), node.getData().id, list, keys);
}
});
listOperatorIds.forEach(id => {
deferredIds.forEach(id => {
const node = nodes.find(n => n.id === id);
if (node) processNodeVariables(node.getData(), node.getData().id, list, keys);
});

View File

@@ -4,17 +4,17 @@
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-21 18:23:31
*/
import LoopNode from './components/Nodes/LoopNode';
import NormalNode from './components/Nodes/NormalNode';
import type { ReactShapeConfig } from '@antv/x6-react-shape';
import type { GroupMetadata, PortMetadata } from '@antv/x6/lib/model/port';
import AddNode from './components/Nodes/AddNode';
import ConditionNode from './components/Nodes/ConditionNode';
import GroupStartNode from './components/Nodes/GroupStartNode';
import AddNode from './components/Nodes/AddNode'
import LoopNode from './components/Nodes/LoopNode';
import NormalNode from './components/Nodes/NormalNode';
import NoteNode from './components/Nodes/NoteNode';
import type { PortMetadata, GroupMetadata } from '@antv/x6/lib/model/port';
import type { ReactShapeConfig } from '@antv/x6-react-shape';
import { memoryConfigListUrl } from '@/api/memory'
import type { NodeLibrary } from './types'
import { memoryConfigListUrl } from '@/api/memory';
import type { NodeLibrary } from './types';
/**
* Workflow node library configuration
@@ -143,7 +143,7 @@ export const nodeLibrary: NodeLibrary[] = [
},
vision_input: {
type: 'variableList',
onFilterVariableType: ['array[file]']
onFilterVariableType: ['array[file]', 'file']
}
}
},
@@ -437,7 +437,8 @@ export const nodeLibrary: NodeLibrary[] = [
{ type: "tool", icon: 'rb:bg-[url("@/assets/images/workflow/tools.svg")]',
config: {
tool_id: {
type: 'cascader'
type: 'cascader',
required: true
},
tool_parameters: {
type: 'define'
@@ -743,7 +744,7 @@ export const portTextAttrs = { fontSize: 12, fill: '#5B6167' }
/**
* Port position arguments
*/
export const portItemArgsY = 26.5;
export const portItemArgsY = 27.5;
export const portArgs = { x: nodeWidth, y: portItemArgsY }
const defaultPortGroup = {

View File

@@ -465,11 +465,11 @@ export const useWorkflowGraph = ({
graphRef.current.addEdges(edgeList.filter(vo => vo !== null))
}
graphRef.current.centerContent()
// Initialize after completion, display nodes in visible area
if (nodes.length > 0 || edges.length > 0) {
setTimeout(() => {
if (graphRef.current) {
graphRef.current.centerContent()
graphRef.current.getNodes().forEach(node => {
if (!node.getData()?.cycle) node.toFront();
});

View File

@@ -2,136 +2,70 @@
* @Author: ZhaoYing
* @Date: 2026-03-24 15:07:49
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-17 20:40:47
* @Last Modified time: 2026-04-20 14:20:34
*/
import { portItemArgsY, conditionNodePortItemArgsY, conditionNodeHeight } from './constant'
import { conditionNodePortItemArgsY, conditionNodeHeight } from './constant'
/**
* Calculate the total height of a condition (if-else) node based on its cases.
*
* The height is composed of:
* - `conditionNodeHeight`: the base height of the node (header + padding).
* - `(cases.length - 1) * 26`: vertical spacing added for each additional case
* beyond the first (each case separator row is 26px).
* - `exprCount * 20`: each individual expression row occupies 20px.
* - `hasMultiExprCount * 3`: a small extra padding (3px per expression) is added
* for cases that contain more than one expression, to account for the logical
* operator indicator (AND/OR) between expressions.
*
* @param cases - Array of case objects, each containing an `expressions` array.
* @returns The total pixel height for the condition node.
*/
/**
 * Tell whether a sub-condition is fully configured.
 *
 * A sub-condition without a `key` is never set. With a key, it is set when its
 * operator is 'empty' / 'not_empty' (no value needed), or when `value` is a
 * boolean, a number (including 0), or any other truthy value.
 */
export const isSubExprSet = (sub: any) => {
  const hasKey = Boolean(sub?.key);
  if (!hasKey) return false;

  const valueFreeOperators = ['not_empty', 'empty'];
  if (valueFreeOperators.includes(sub?.operator)) return true;

  const { value } = sub;
  switch (typeof value) {
    case 'boolean':
    case 'number':
      return true;
    default:
      return Boolean(value);
  }
};
const getEffectiveExprCount = (expr: any): number => {
const subs = expr?.sub_variable_condition?.conditions;
if (subs?.length && subs.every(isSubExprSet)) return 1 + subs.length;
if (subs?.length > 0) {
return 2
}
return 1;
/**
 * Compute the total pixel height of a condition (if-else) node.
 *
 * Each case contributes 20px plus the summed heights of its expressions
 * (via calcExpressionHeight) with a 2px gap between consecutive expressions.
 * The result is conditionNodeHeight + all case heights + 4px per gap between
 * cases, minus a fixed 27.5px offset.
 *
 * @param cases - Case objects, each optionally carrying an `expressions` array.
 * @returns The node height in pixels.
 */
export const calcConditionNodeTotalHeight = (cases: any[]) => {
  let casesHeight = 0;

  cases.forEach((caseItem: any) => {
    const expressions = caseItem?.expressions ?? [];
    const count = expressions.length;

    let expressionsHeight = 0;
    if (count !== 0) {
      expressions.forEach((expression: any) => {
        expressionsHeight = expressionsHeight + calcExpressionHeight(expression);
      });
      expressionsHeight = expressionsHeight + 2 * (count - 1);
    }

    casesHeight = casesHeight + 20 + expressionsHeight;
  });

  const gapsBetweenCases = (cases.length - 1) * 4;
  return conditionNodeHeight + casesHeight + gapsBetweenCases - 27.5;
};
export const calcConditionNodeTotalHeight = (cases: any[]) => {
// Total number of effective expression rows (sub_variable_condition expand height when all set)
const exprCount = cases.reduce((acc: number, c: any) =>
acc + (c?.expressions?.reduce((s: number, e: any) => s + getEffectiveExprCount(e), 0) || 0), 0);
// Sum of effective expression counts only for cases that have more than one expression
const hasMultiExprCount = cases.reduce((acc: number, c: any) => {
if (!c?.expressions?.length || c.expressions.length <= 1) return acc;
const effectiveCount = c.expressions.reduce((s: number, e: any) => s + getEffectiveExprCount(e), 0);
return acc + effectiveCount;
}, 0);
return conditionNodeHeight + (cases.length - 1) * 26 + exprCount * 20 + hasMultiExprCount * 3;
/**
 * Height in pixels of one expression block inside a condition node.
 *
 * - No sub conditions: 24 (per the original notes: py-1 = 8 plus a 16px row).
 * - With sub conditions: 4 + 16 + 4 + subBlock - 2, where subBlock is
 *     - 18 + 22 * (m - 1) when all m sub conditions are set (isSubExprSet),
 *     - 20 otherwise.
 *   Per the original notes the terms are top padding, main row, the gap above
 *   the sub block, and a -2 rendering correction.
 *   NOTE(review): these constants mirror ConditionNode's CSS classes; confirm
 *   against the rendered layout whenever that markup changes.
 */
const calcExpressionHeight = (expression: any): number => {
  const subConditions = expression?.sub_variable_condition?.conditions;
  const subCount = subConditions?.length;
  if (!subCount) return 24; // py-1(8) + leading-4(16)

  let subBlockHeight = 20;
  if (subConditions.every(isSubExprSet)) {
    subBlockHeight = 18 + 22 * (subCount - 1);
  }
  return 4 + 16 + 4 + subBlockHeight - 2; // pt-1 + main row + mt-1 + sub block (-2 rendering correction)
};
/**
* Calculate the Y-coordinate of the right-side output port for a specific case
* in a condition (if-else) node.
* in a condition (if-else) node, aligned with the IF/ELIF label in ConditionNode.
*
* The port position is determined by iterating through all preceding cases
* (index 0 to caseIndex - 1) and accumulating their visual heights. Several
* pixel-level corrections are applied to align ports with the rendered UI:
*
* 1. **Base offset**: starts at `conditionNodePortItemArgsY`, which is the Y
* position of the first case port relative to the node top.
*
* 2. **Per-case accumulation**: for each preceding case with `n` expressions,
* add `portItemArgsY * (n + 1)` — this accounts for `n` expression rows
* plus one case header/separator row.
*
* 3. **Single-expression correction**: cases with exactly 1 expression render
* slightly shorter than the generic formula predicts. Subtract
* `singleExprCount * 7 + 2` to compensate for the reduced row height when
* no logical operator row is shown.
*
* 4. **Multi-expression correction**: cases with 2+ expressions have a compact
* logical operator row. Subtract `multiExprCount * 9` to offset the
* over-estimated spacing.
*
* 5. **Extra expression correction**: for cases with more than 2 expressions,
* each additional expression beyond the second introduces a minor spacing
* discrepancy. Subtract `(extraExprs + 1) * 2` to fine-tune alignment.
*
* @param cases - Array of case objects, each containing an `expressions` array.
* @param caseIndex - The zero-based index of the target case whose port Y is needed.
* @returns The Y-coordinate (in pixels) for the output port of the given case.
* Layout (from node top):
* - 12px padding-top + 24px header + 12px mt-3 = 48px to cases area
* - Each IF/ELIF label row: leading-4 (16px), center at +8px → first port Y = 56.5
* - Each case: IF/ELIF row (leading-4=16) + mb-1(4) + expressions (gap={2}=2px between)
* - Gap between cases (Flex gap={4}): 4px
*/
export const getConditionNodeCasePortY = (cases: any[], caseIndex: number) => {
let y = conditionNodePortItemArgsY;
let singleExprCount = 0;
let multiExprCount = 0;
let extraExprs = 0;
let portItemArgsYNum = 0;
let y = conditionNodePortItemArgsY; // 56.5, center of first IF label
for (let i = 0; i < caseIndex; i++) {
const notHasSub = cases[i]?.expressions?.filter((e: any) => !e?.sub_variable_condition?.conditions || e?.sub_variable_condition?.conditions.length <1).length
const n = cases[i]?.expressions?.length || 0;
let casePortItemArgsYNum = n + 1;
// Add extra y for expressions with all sub_variable_condition set
cases[i]?.expressions?.forEach((expr: any) => {
const subs = expr?.sub_variable_condition?.conditions;
if (subs?.length && subs.every(isSubExprSet)) {
casePortItemArgsYNum += subs.length;
} else if (subs?.length) {
casePortItemArgsYNum += 1
}
});
portItemArgsYNum += casePortItemArgsYNum;
if (n === 1 && !cases[i]?.expressions?.some((e: any) => e?.sub_variable_condition?.conditions?.length > 0)) {
singleExprCount++
} else if (n >= 2 || cases[i]?.expressions?.some((e: any) => e?.sub_variable_condition?.conditions?.length > 0)) {
multiExprCount++;
cases[i]?.expressions?.forEach((e: any) => {
const subs = e?.sub_variable_condition?.conditions;
if (subs?.length && subs.every(isSubExprSet) && subs.length > 1) {
extraExprs += subs.length + 2;
}
});
console.log('extraExprs notHasSub', notHasSub)
if (notHasSub > 3) {
extraExprs += n - 2 + notHasSub/4;
} else {
extraExprs += n - 2 + notHasSub/4
}
}
const exprs = cases[i]?.expressions ?? [];
const n = exprs.length;
// IF/ELIF row (16) + mb-1 (4) = 20px base; expressions: sum of heights + 2px gap between
const exprsHeight = n === 0 ? 0 : exprs.reduce((acc: number, e: any) => acc + calcExpressionHeight(e), 0) + 2 * (n - 1);
y += 20 + exprsHeight + 4; // case height + Flex gap between cases
}
console.log('singleExprCount', singleExprCount, 'multiExprCount', multiExprCount, 'extraExprs', extraExprs)
y += portItemArgsY * portItemArgsYNum
// Correction for single-expression cases (slightly shorter rendered height)
if (singleExprCount > 0) y -= singleExprCount * 7 + 2;
// Correction for multi-expression cases (compact logical operator row)
y -= multiExprCount * 9;
// Correction for cases with more than 2 expressions (minor spacing drift)
if (extraExprs > 0) y -= (extraExprs + 1) * 2;
return y;
};
};