diff --git a/api/app/celery_worker.py b/api/app/celery_worker.py index 4ea4fee1..9fabe15b 100644 --- a/api/app/celery_worker.py +++ b/api/app/celery_worker.py @@ -2,6 +2,8 @@ Celery Worker 入口点 用于启动 Celery Worker: celery -A app.celery_worker worker --loglevel=info """ +from celery.signals import worker_process_init + from app.celery_app import celery_app from app.core.logging_config import LoggingConfig, get_logger @@ -13,4 +15,39 @@ logger.info("Celery worker logging initialized") # 导入任务模块以注册任务 import app.tasks + +@worker_process_init.connect +def _reinit_db_pool(**kwargs): + """ + prefork 子进程启动时重建被 fork 污染的资源。 + + fork() 后子进程继承了父进程的: + 1. SQLAlchemy 连接池 — 多进程共享 TCP socket 导致 DB 连接损坏 + 2. ThreadPoolExecutor — fork 后线程状态不确定,第二个任务会死锁 + """ + # 重建 DB 连接池 + from app.db import engine + engine.dispose() + logger.info("DB connection pool disposed for forked worker process") + + # 重建模块级 ThreadPoolExecutor(fork 后线程池不可用) + try: + from app.core.rag.deepdoc.parser import figure_parser + from concurrent.futures import ThreadPoolExecutor + figure_parser.shared_executor = ThreadPoolExecutor(max_workers=10) + logger.info("figure_parser.shared_executor recreated") + except Exception as e: + logger.warning(f"Failed to recreate figure_parser.shared_executor: {e}") + + try: + from app.core.rag.utils import libre_office + from concurrent.futures import ThreadPoolExecutor + import os + max_workers = os.cpu_count() * 2 if os.cpu_count() else 4 + libre_office.executor = ThreadPoolExecutor(max_workers=max_workers) + logger.info("libre_office.executor recreated") + except Exception as e: + logger.warning(f"Failed to recreate libre_office.executor: {e}") + + __all__ = ['celery_app'] diff --git a/api/app/config/default_free_plan.py b/api/app/config/default_free_plan.py index 409b4f7b..3ecc0498 100644 --- a/api/app/config/default_free_plan.py +++ b/api/app/config/default_free_plan.py @@ -60,7 +60,7 @@ def _build_default_free_plan(): "app_quota": 2, "knowledge_capacity_quota": 0.3, "memory_engine_quota": 1, - "end_user_quota": 1, + "end_user_quota": 10, "ontology_project_quota": 3, "model_quota": 1, "api_ops_rate_limit": 50, diff --git a/api/app/controllers/api_key_controller.py b/api/app/controllers/api_key_controller.py index dce8450d..6e414276 100644 --- a/api/app/controllers/api_key_controller.py +++ b/api/app/controllers/api_key_controller.py @@ -167,6 +167,8 @@ def update_api_key( return success(data=api_key_schema.ApiKey.model_validate(api_key), msg="API Key 更新成功") + except BusinessException: + raise except Exception as e: logger.error(f"未知错误: {str(e)}", extra={ "api_key_id": str(api_key_id), diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 3d97f2a2..eda5e76a 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -219,6 +219,7 @@ def delete_app( @router.post("/{app_id}/copy", summary="复制应用") @cur_workspace_access_guard() +@check_app_quota def copy_app( app_id: uuid.UUID, new_name: Optional[str] = None, @@ -1144,6 +1145,7 @@ async def import_workflow_config( @router.post("/workflow/import/save") @cur_workspace_access_guard() +@check_app_quota async def save_workflow_import( data: WorkflowImportSave, db: Session = Depends(get_db), @@ -1281,6 +1283,10 @@ async def import_app( return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST) target_app_id = uuid.UUID(app_id) if app_id else None + # 仅新建应用时检查配额,覆盖已有应用时跳过 + if target_app_id is None: + from app.core.quota_manager import _check_quota + _check_quota(db, current_user.tenant_id, "app_quota", "app", workspace_id=current_user.current_workspace_id) result_app, warnings = AppDslService(db).import_dsl( dsl=dsl, workspace_id=current_user.current_workspace_id, diff --git a/api/app/controllers/chunk_controller.py b/api/app/controllers/chunk_controller.py index cc1f8c98..3012d159 100644 --- a/api/app/controllers/chunk_controller.py +++ b/api/app/controllers/chunk_controller.py @@ -457,7 +457,7 @@ async def retrieve_chunks( if doc.metadata["doc_id"] not in seen_ids: seen_ids.add(doc.metadata["doc_id"]) unique_rs.append(doc) - rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k) + rs = vector_service.rerank(query=retrieve_data.query, docs=unique_rs, top_k=retrieve_data.top_k) if unique_rs else [] if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph: kb_ids = [str(kb_id) for kb_id in private_kb_ids] workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids] diff --git a/api/app/controllers/model_controller.py b/api/app/controllers/model_controller.py index 6105c3d8..57c22337 100644 --- a/api/app/controllers/model_controller.py +++ b/api/app/controllers/model_controller.py @@ -237,7 +237,6 @@ def delete_model_base( @router.post("/model_plaza/{model_base_id}/add", response_model=ApiResponse) -@check_model_quota def add_model_from_plaza( model_base_id: uuid.UUID, db: Session = Depends(get_db), @@ -275,7 +274,6 @@ def get_model_by_id( @router.post("", response_model=ApiResponse) -@check_model_quota async def create_model( model_data: model_schema.ModelConfigCreate, db: Session = Depends(get_db), diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index 049535b5..97b500fa 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -219,9 +219,20 @@ def list_conversations( end_user_repo = EndUserRepository(db) app_service = AppService(db) app = app_service._get_app_or_404(share.app_id) + workspace_id = app.workspace_id + + # 仅在新建终端用户时检查配额 + existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id) + if existing_end_user is None: + from app.core.quota_manager import _check_quota + from app.models.workspace_model import Workspace + ws = db.query(Workspace).filter(Workspace.id == workspace_id).first() + if ws: + _check_quota(db, ws.tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id) + new_end_user = end_user_repo.get_or_create_end_user( app_id=share.app_id, - workspace_id=app.workspace_id, + workspace_id=workspace_id, other_id=other_id ) logger.debug(new_end_user.id) @@ -309,7 +320,6 @@ def get_conversation( "/chat", summary="发送消息(支持流式和非流式)" ) -@check_end_user_quota async def chat( payload: conversation_schema.ChatRequest, share_data: ShareTokenData = Depends(get_share_user_id), @@ -350,6 +360,18 @@ async def chat( app_service = AppService(db) app = app_service._get_app_or_404(share.app_id) workspace_id = app.workspace_id + + # 仅在新建终端用户时检查配额,已有用户复用不受限制 + existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id) + logger.info(f"终端用户配额检查: workspace_id={workspace_id}, other_id={other_id}, existing={existing_end_user is not None}") + if existing_end_user is None: + from app.core.quota_manager import _check_quota + from app.models.workspace_model import Workspace + ws = db.query(Workspace).filter(Workspace.id == workspace_id).first() + if ws: + logger.info(f"新终端用户,执行配额检查: tenant_id={ws.tenant_id}") + _check_quota(db, ws.tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id) + new_end_user = end_user_repo.get_or_create_end_user( app_id=share.app_id, workspace_id=workspace_id, diff --git a/api/app/controllers/service/app_api_controller.py b/api/app/controllers/service/app_api_controller.py index a78fd842..93e88dc5 100644 --- a/api/app/controllers/service/app_api_controller.py +++ b/api/app/controllers/service/app_api_controller.py @@ -106,6 +106,16 @@ async def chat( other_id = payload.user_id workspace_id = api_key_auth.workspace_id end_user_repo = EndUserRepository(db) + + # 仅在新建终端用户时检查配额,已有用户复用不受限制 + existing_end_user = end_user_repo.get_end_user_by_other_id(workspace_id=workspace_id, other_id=other_id) + if existing_end_user is None: + from app.core.quota_manager import _check_quota + from app.models.workspace_model import Workspace + ws = db.query(Workspace).filter(Workspace.id == workspace_id).first() + if ws: + _check_quota(db, ws.tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id) + new_end_user = end_user_repo.get_or_create_end_user( app_id=app.id, workspace_id=workspace_id, diff --git a/api/app/controllers/service/end_user_api_controller.py b/api/app/controllers/service/end_user_api_controller.py index 1faea6ef..572f4aab 100644 --- a/api/app/controllers/service/end_user_api_controller.py +++ b/api/app/controllers/service/end_user_api_controller.py @@ -47,7 +47,7 @@ async def create_end_user( request: Request, api_key_auth: ApiKeyAuth = None, db: Session = Depends(get_db), - message: str = Body(..., description="Request body"), + message: str = Body(None, description="Request body"), ): """ Create or retrieve an end user for the workspace. diff --git a/api/app/core/api_key_auth.py b/api/app/core/api_key_auth.py index 91d6bd8a..05bca945 100644 --- a/api/app/core/api_key_auth.py +++ b/api/app/core/api_key_auth.py @@ -96,40 +96,8 @@ def require_api_key( resource_id=api_key_obj.resource_id, ) - # ── Tenant 级别限速(来自套餐配额 api_ops_rate_limit)────────── - try: - from app.models.workspace_model import Workspace - from premium.platform_admin.package_plan_service import TenantSubscriptionService - - workspace = db.query(Workspace).filter( - Workspace.id == api_key_obj.workspace_id - ).first() - if workspace: - quota = TenantSubscriptionService(db).get_effective_quota(workspace.tenant_id) - tenant_qps_limit = quota.get("api_ops_rate_limit") if quota else None - if tenant_qps_limit: - rate_limiter = RateLimiterService() - tenant_ok, tenant_info = await rate_limiter.check_tenant_rate_limit( - workspace.tenant_id, tenant_qps_limit - ) - if not tenant_ok: - raise RateLimitException( - "租户 API 调用速率超限", - BizCode.API_KEY_QPS_LIMIT_EXCEEDED, - rate_headers={ - "X-RateLimit-Tenant-Limit": str(tenant_info["limit"]), - "X-RateLimit-Tenant-Remaining": str(tenant_info["remaining"]), - "X-RateLimit-Tenant-Reset": str(tenant_info["reset"]), - } - ) - except RateLimitException: - raise - except Exception as e: - logger.warning(f"Tenant 限速检查异常,跳过: {e}") - # ───────────────────────────────────────────────────────────── - rate_limiter = RateLimiterService() - is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj) + is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj, db=db) if not is_allowed: logger.warning("API Key 限流触发", extra={ "api_key_id": str(api_key_obj.id), @@ -138,10 +106,12 @@ def require_api_key( "error_msg": error_msg }) # 根据错误消息判断限流类型 - if "QPS" in error_msg: - code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED - elif "Daily" in error_msg: + if "Daily" in error_msg: code = BizCode.API_KEY_DAILY_LIMIT_EXCEEDED + elif "Tenant" in error_msg: + code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED # 租户套餐速率超限,同属 QPS 类 + elif "QPS" in error_msg: + code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED else: code = BizCode.API_KEY_QUOTA_EXCEEDED diff --git a/api/app/core/error_codes.py b/api/app/core/error_codes.py index 01b6115d..77bce6b4 100644 --- a/api/app/core/error_codes.py +++ b/api/app/core/error_codes.py @@ -31,6 +31,9 @@ class BizCode(IntEnum): API_KEY_QPS_LIMIT_EXCEEDED = 3014 API_KEY_DAILY_LIMIT_EXCEEDED = 3015 API_KEY_QUOTA_EXCEEDED = 3016 + API_KEY_RATE_LIMIT_EXCEEDED = 3017 + QUOTA_EXCEEDED = 3018 + RATE_LIMIT_EXCEEDED = 3019 # 资源(4xxx) NOT_FOUND = 4000 USER_NOT_FOUND = 4001 @@ -155,7 +158,8 @@ HTTP_MAPPING = { BizCode.API_KEY_QPS_LIMIT_EXCEEDED: 429, BizCode.API_KEY_DAILY_LIMIT_EXCEEDED: 429, BizCode.API_KEY_QUOTA_EXCEEDED: 429, - + BizCode.QUOTA_EXCEEDED: 402, + BizCode.MODEL_CONFIG_INVALID: 400, BizCode.API_KEY_MISSING: 400, BizCode.PROVIDER_NOT_SUPPORTED: 400, @@ -184,4 +188,21 @@ HTTP_MAPPING = { BizCode.DB_ERROR: 500, BizCode.SERVICE_UNAVAILABLE: 503, BizCode.RATE_LIMITED: 429, + BizCode.RATE_LIMIT_EXCEEDED: 429, +} + +ERROR_CODE_TO_BIZ_CODE = { + "QUOTA_EXCEEDED": BizCode.QUOTA_EXCEEDED, + "RATE_LIMIT_EXCEEDED": BizCode.RATE_LIMIT_EXCEEDED, + "API_KEY_NOT_FOUND": BizCode.API_KEY_NOT_FOUND, + "API_KEY_INVALID": BizCode.API_KEY_INVALID, + "API_KEY_EXPIRED": BizCode.API_KEY_EXPIRED, + "WORKSPACE_NOT_FOUND": BizCode.WORKSPACE_NOT_FOUND, + "WORKSPACE_NO_ACCESS": BizCode.WORKSPACE_NO_ACCESS, + "PERMISSION_DENIED": BizCode.PERMISSION_DENIED, + "TOKEN_EXPIRED": BizCode.TOKEN_EXPIRED, + "TOKEN_INVALID": BizCode.TOKEN_INVALID, + "VALIDATION_FAILED": BizCode.VALIDATION_FAILED, + "INVALID_PARAMETER": BizCode.INVALID_PARAMETER, + "MISSING_PARAMETER": BizCode.MISSING_PARAMETER, } diff --git a/api/app/core/quota_manager.py b/api/app/core/quota_manager.py index 0e0053a0..d59c42e0 100644 --- a/api/app/core/quota_manager.py +++ b/api/app/core/quota_manager.py @@ -6,7 +6,6 @@ 2. 降级到 default_free_plan.py 配置文件(社区版兜底) """ import asyncio -import time from functools import wraps from typing import Optional, Callable, Dict, Any from uuid import UUID @@ -15,10 +14,13 @@ from sqlalchemy import func from sqlalchemy.orm import Session from app.core.logging_config import get_auth_logger -from app.i18n.exceptions import QuotaExceededError +from app.i18n.exceptions import QuotaExceededError, InternalServerError logger = get_auth_logger() +# Redis key 格式常量,与 RateLimiterService.check_qps 保持一致(per api_key 独立计数) +API_KEY_QPS_REDIS_KEY = "rate_limit:qps:{api_key_id}" + def _get_user_from_kwargs(kwargs: dict): """从 kwargs 中获取 user 对象""" @@ -28,6 +30,29 @@ def _get_user_from_kwargs(kwargs: dict): return None +def _get_workspace_id_from_kwargs(kwargs: dict): + """从 kwargs 中获取 workspace_id""" + # 优先从 kwargs['workspace_id'] 获取 + workspace_id = kwargs.get("workspace_id") + if workspace_id: + return workspace_id + + # 从 api_key_auth.workspace_id 获取(API Key 认证场景) + api_key_auth = kwargs.get("api_key_auth") + if api_key_auth and hasattr(api_key_auth, 'workspace_id'): + return api_key_auth.workspace_id + + # 从 user.current_workspace_id 获取 + user = _get_user_from_kwargs(kwargs) + if user: + ws_id = getattr(user, 'current_workspace_id', None) + if ws_id: + return ws_id + + logger.warning(f"无法获取 workspace_id, kwargs keys: {list(kwargs.keys())}") + return None + + def _get_tenant_id_from_kwargs(db: Session, kwargs: dict): """从 kwargs 中获取 tenant_id""" user = _get_user_from_kwargs(kwargs) @@ -65,7 +90,9 @@ def _get_tenant_id_from_kwargs(db: Session, kwargs: dict): if share_record: app = db.query(App).filter(App.id == share_record.app_id, App.is_active.is_(True)).first() if app: - return app.workspace.tenant_id + workspace = db.query(Workspace).filter(Workspace.id == app.workspace_id).first() + if workspace: + return workspace.tenant_id return None @@ -73,31 +100,52 @@ def _get_tenant_id_from_kwargs(db: Session, kwargs: dict): def _get_quota_config(db: Session, tenant_id: UUID) -> Optional[Dict[str, Any]]: """ 获取租户的配额配置 - + 优先级: 1. premium 模块的 tenant_subscriptions(SaaS 版) 2. default_free_plan.py 配置文件(社区版兜底) """ - # 尝试从 premium 模块获取 + # 尝试从 premium 模块获取(SaaS 版) try: from premium.platform_admin.package_plan_service import TenantSubscriptionService + # premium 模块存在,运行时错误不应被静默降级,直接抛出 quota_config = TenantSubscriptionService(db).get_effective_quota(tenant_id) if quota_config: logger.debug(f"从 premium 模块获取租户 {tenant_id} 配额配置") return quota_config - except (ModuleNotFoundError, ImportError, Exception) as e: - logger.debug(f"无法从 premium 模块获取配额配置: {e}") + # premium 存在但该租户无订阅记录,降级到免费套餐 + logger.debug(f"租户 {tenant_id} 无 premium 订阅,降级到免费套餐") + except (ModuleNotFoundError, ImportError): + # 社区版:premium 包不存在,正常降级 + logger.debug("premium 模块不存在,使用社区版免费套餐配额") - # 降级到配置文件 + # 降级到社区版配置文件 try: from app.config.default_free_plan import DEFAULT_FREE_PLAN - logger.info(f"使用配置文件中的免费套餐配额: tenant={tenant_id}") + logger.debug(f"使用社区版免费套餐配额: tenant={tenant_id}") return DEFAULT_FREE_PLAN.get("quotas") except Exception as e: logger.error(f"无法从配置文件获取配额: {e}") return None +def get_api_ops_rate_limit(db: Session, tenant_id: UUID) -> Optional[int]: + """ + 获取租户套餐的 API 操作速率限制(QPS 上限) + + 该函数兼容社区版和 SaaS 版: + - SaaS 版:从 premium 模块的套餐配额读取 + - 社区版:从 default_free_plan.py 配置文件读取 + + Returns: + int: api_ops_rate_limit 值,如果未配置则返回 None + """ + quota_config = _get_quota_config(db, tenant_id) + if quota_config: + return quota_config.get("api_ops_rate_limit") + return None + + class QuotaUsageRepository: """配额使用量数据访问层""" @@ -111,15 +159,19 @@ class QuotaUsageRepository: Workspace.is_active.is_(True) ).count() - def count_apps(self, tenant_id: UUID) -> int: + def count_apps(self, tenant_id: UUID, workspace_id: Optional[UUID] = None) -> int: from app.models.app_model import App from app.models.workspace_model import Workspace - return self.db.query(App).join( + query = self.db.query(App).join( Workspace, App.workspace_id == Workspace.id ).filter( - Workspace.tenant_id == tenant_id, App.is_active.is_(True) - ).count() + ) + if workspace_id: + query = query.filter(App.workspace_id == workspace_id) + else: + query = query.filter(Workspace.tenant_id == tenant_id) + return query.count() def count_skills(self, tenant_id: UUID) -> int: from app.models.skill_model import Skill @@ -128,55 +180,76 @@ class QuotaUsageRepository: Skill.is_active.is_(True) ).count() - def sum_knowledge_capacity_gb(self, tenant_id: UUID) -> float: + def sum_knowledge_capacity_gb(self, tenant_id: UUID, workspace_id: Optional[UUID] = None) -> float: from app.models.document_model import Document from app.models.knowledge_model import Knowledge from app.models.workspace_model import Workspace - result = self.db.query(func.coalesce(func.sum(Document.file_size), 0)).join( + query = self.db.query(func.coalesce(func.sum(Document.file_size), 0)).join( Knowledge, Document.kb_id == Knowledge.id ).join( Workspace, Knowledge.workspace_id == Workspace.id ).filter( - Workspace.tenant_id == tenant_id, Document.status == 1, - ).scalar() + ) + if workspace_id: + query = query.filter(Knowledge.workspace_id == workspace_id) + else: + query = query.filter(Workspace.tenant_id == tenant_id) + result = query.scalar() return float(result) / (1024 ** 3) if result else 0.0 - def count_memory_engines(self, tenant_id: UUID) -> int: + def count_memory_engines(self, tenant_id: UUID, workspace_id: Optional[UUID] = None) -> int: from app.models.memory_config_model import MemoryConfig from app.models.workspace_model import Workspace - return self.db.query(MemoryConfig).join( + query = self.db.query(MemoryConfig).join( Workspace, MemoryConfig.workspace_id == Workspace.id - ).filter( - Workspace.tenant_id == tenant_id - ).count() + ) + if workspace_id: + query = query.filter(MemoryConfig.workspace_id == workspace_id) + else: + query = query.filter(Workspace.tenant_id == tenant_id) + return query.count() - def count_end_users(self, tenant_id: UUID) -> int: + def count_end_users(self, tenant_id: UUID, workspace_id: Optional[UUID] = None) -> int: from app.models.end_user_model import EndUser from app.models.workspace_model import Workspace - return self.db.query(EndUser).join( + from app.models.user_model import User + query = self.db.query(EndUser).join( Workspace, EndUser.workspace_id == Workspace.id - ).filter( - Workspace.tenant_id == tenant_id - ).count() + ) + if workspace_id: + query = query.filter(EndUser.workspace_id == workspace_id) + else: + query = query.filter(Workspace.tenant_id == tenant_id) + trial_user_ids = [ + str(u.id) for u in self.db.query(User.id).filter(User.tenant_id == tenant_id).all() + ] + if trial_user_ids: + query = query.filter(~EndUser.other_id.in_(trial_user_ids)) + return query.count() def count_models(self, tenant_id: UUID) -> int: from app.models.models_model import ModelConfig return self.db.query(ModelConfig).filter( ModelConfig.tenant_id == tenant_id, - ModelConfig.is_active == True + ModelConfig.is_active == True, + ModelConfig.is_composite == True ).count() - def count_ontology_projects(self, tenant_id: UUID) -> int: + def count_ontology_projects(self, tenant_id: UUID, workspace_id: Optional[UUID] = None) -> int: from app.models.ontology_scene import OntologyScene from app.models.workspace_model import Workspace + if workspace_id: + return self.db.query(OntologyScene).filter( + OntologyScene.workspace_id == workspace_id + ).count() return self.db.query(OntologyScene).join( Workspace, OntologyScene.workspace_id == Workspace.id ).filter( Workspace.tenant_id == tenant_id ).count() - def get_usage_by_quota_type(self, tenant_id: UUID, quota_type: str): + def get_usage_by_quota_type(self, tenant_id: UUID, quota_type: str, workspace_id: Optional[UUID] = None): """按配额类型分发,返回当前使用量""" dispatch = { "workspace_quota": self.count_workspaces, @@ -189,6 +262,8 @@ class QuotaUsageRepository: "ontology_project_quota": self.count_ontology_projects, } fn = dispatch.get(quota_type) + if workspace_id: + return fn(tenant_id, workspace_id) if fn else 0 return fn(tenant_id) if fn else 0 @@ -198,6 +273,7 @@ def _check_quota( quota_type: str, resource_name: str, usage_func: Optional[Callable] = None, + workspace_id: Optional[UUID] = None, ) -> None: """核心配额检查逻辑:对比使用量和配额限制""" try: @@ -212,13 +288,13 @@ def _check_quota( return if usage_func: - current_usage = usage_func(db, tenant_id) + current_usage = usage_func(db, tenant_id, workspace_id) if workspace_id else usage_func(db, tenant_id) else: - current_usage = QuotaUsageRepository(db).get_usage_by_quota_type(tenant_id, quota_type) + current_usage = QuotaUsageRepository(db).get_usage_by_quota_type(tenant_id, quota_type, workspace_id) if current_usage >= quota_limit: logger.warning( - f"配额不足: tenant={tenant_id}, type={quota_type}, " + f"配额不足: tenant={tenant_id}, workspace={workspace_id}, type={quota_type}, " f"usage={current_usage}, limit={quota_limit}" ) raise QuotaExceededError( @@ -228,7 +304,7 @@ def _check_quota( ) logger.debug( - f"配额检查通过: tenant={tenant_id}, type={quota_type}, " + f"配额检查通过: tenant={tenant_id}, workspace={workspace_id}, type={quota_type}, " f"usage={current_usage}, limit={quota_limit}" ) @@ -236,7 +312,7 @@ def _check_quota( raise except Exception as e: logger.error( - f"配额检查异常: tenant={tenant_id}, type={quota_type}, " + f"配额检查异常: tenant={tenant_id}, workspace={workspace_id}, type={quota_type}, " f"error_type={type(e).__name__}, error={str(e)}", exc_info=True, ) @@ -247,41 +323,82 @@ def _check_quota( def check_workspace_quota(func: Callable) -> Callable: @wraps(func) - def wrapper(*args, **kwargs): + async def async_wrapper(*args, **kwargs): db: Session = kwargs.get("db") user = _get_user_from_kwargs(kwargs) if not db or not user: - logger.warning("配额检查失败:缺少 db 或 user 参数") - return func(*args, **kwargs) + logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求") + raise InternalServerError() + _check_quota(db, user.tenant_id, "workspace_quota", "workspace") + return await func(*args, **kwargs) + + @wraps(func) + def sync_wrapper(*args, **kwargs): + db: Session = kwargs.get("db") + user = _get_user_from_kwargs(kwargs) + if not db or not user: + logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求") + raise InternalServerError() _check_quota(db, user.tenant_id, "workspace_quota", "workspace") return func(*args, **kwargs) - return wrapper + + return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper def check_skill_quota(func: Callable) -> Callable: @wraps(func) - def wrapper(*args, **kwargs): + async def async_wrapper(*args, **kwargs): db: Session = kwargs.get("db") user = _get_user_from_kwargs(kwargs) if not db or not user: - logger.warning("配额检查失败:缺少 db 或 user 参数") - return func(*args, **kwargs) + logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求") + raise InternalServerError() + _check_quota(db, user.tenant_id, "skill_quota", "skill") + return await func(*args, **kwargs) + + @wraps(func) + def sync_wrapper(*args, **kwargs): + db: Session = kwargs.get("db") + user = _get_user_from_kwargs(kwargs) + if not db or not user: + logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求") + raise InternalServerError() _check_quota(db, user.tenant_id, "skill_quota", "skill") return func(*args, **kwargs) - return wrapper + + return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper def check_app_quota(func: Callable) -> Callable: @wraps(func) - def wrapper(*args, **kwargs): + async def async_wrapper(*args, **kwargs): db: Session = kwargs.get("db") user = _get_user_from_kwargs(kwargs) if not db or not user: - logger.warning("配额检查失败:缺少 db 或 user 参数") - return func(*args, **kwargs) - _check_quota(db, user.tenant_id, "app_quota", "app") + logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求") + raise InternalServerError() + workspace_id = _get_workspace_id_from_kwargs(kwargs) + if not workspace_id: + logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求") + raise InternalServerError() + _check_quota(db, user.tenant_id, "app_quota", "app", workspace_id=workspace_id) + return await func(*args, **kwargs) + + @wraps(func) + def sync_wrapper(*args, **kwargs): + db: Session = kwargs.get("db") + user = _get_user_from_kwargs(kwargs) + if not db or not user: + logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求") + raise InternalServerError() + workspace_id = _get_workspace_id_from_kwargs(kwargs) + if not workspace_id: + logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求") + raise InternalServerError() + _check_quota(db, user.tenant_id, "app_quota", "app", workspace_id=workspace_id) return func(*args, **kwargs) - return wrapper + + return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper def check_knowledge_capacity_quota(func: Callable) -> Callable: @@ -289,13 +406,17 @@ def check_knowledge_capacity_quota(func: Callable) -> Callable: async def async_wrapper(*args, **kwargs): db: Session = kwargs.get("db") if not db: - logger.warning("配额检查失败:缺少 db 参数") - return await func(*args, **kwargs) + logger.error(f"配额检查失败:{func.__name__} 缺少 db 参数,拒绝请求") + raise InternalServerError() tenant_id = _get_tenant_id_from_kwargs(db, kwargs) if not tenant_id: - logger.warning("配额检查失败:无法获取 tenant_id") - return await func(*args, **kwargs) - _check_quota(db, tenant_id, "knowledge_capacity_quota", "knowledge_capacity") + logger.error(f"配额检查失败:{func.__name__} 无法获取 tenant_id,拒绝请求") + raise InternalServerError() + workspace_id = _get_workspace_id_from_kwargs(kwargs) + if not workspace_id: + logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求") + raise InternalServerError() + _check_quota(db, tenant_id, "knowledge_capacity_quota", "knowledge_capacity", workspace_id=workspace_id) return await func(*args, **kwargs) @wraps(func) @@ -303,9 +424,13 @@ def check_knowledge_capacity_quota(func: Callable) -> Callable: db: Session = kwargs.get("db") user = _get_user_from_kwargs(kwargs) if not db or not user: - logger.warning("配额检查失败:缺少 db 或 user 参数") - return func(*args, **kwargs) - _check_quota(db, user.tenant_id, "knowledge_capacity_quota", "knowledge_capacity") + logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求") + raise InternalServerError() + workspace_id = _get_workspace_id_from_kwargs(kwargs) + if not workspace_id: + logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求") + raise InternalServerError() + _check_quota(db, user.tenant_id, "knowledge_capacity_quota", "knowledge_capacity", workspace_id=workspace_id) return func(*args, **kwargs) return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper @@ -313,15 +438,36 @@ def check_knowledge_capacity_quota(func: Callable) -> Callable: def check_memory_engine_quota(func: Callable) -> Callable: @wraps(func) - def wrapper(*args, **kwargs): + async def async_wrapper(*args, **kwargs): db: Session = kwargs.get("db") user = _get_user_from_kwargs(kwargs) + logger.debug(f"check_memory_engine_quota async_wrapper: db={db is not None}, user={user}, kwargs_keys={list(kwargs.keys())}") if not db or not user: - logger.warning("配额检查失败:缺少 db 或 user 参数") - return func(*args, **kwargs) - _check_quota(db, user.tenant_id, "memory_engine_quota", "memory_engine") + logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求") + raise InternalServerError() + workspace_id = _get_workspace_id_from_kwargs(kwargs) + if not workspace_id: + logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求") + raise InternalServerError() + _check_quota(db, user.tenant_id, "memory_engine_quota", "memory_engine", workspace_id=workspace_id) + return await func(*args, **kwargs) + + @wraps(func) + def sync_wrapper(*args, **kwargs): + db: Session = kwargs.get("db") + user = _get_user_from_kwargs(kwargs) + logger.debug(f"check_memory_engine_quota sync_wrapper: db={db is not None}, user={user}, kwargs_keys={list(kwargs.keys())}") + if not db or not user: + logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求") + raise InternalServerError() + workspace_id = _get_workspace_id_from_kwargs(kwargs) + if not workspace_id: + logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求") + raise InternalServerError() + _check_quota(db, user.tenant_id, "memory_engine_quota", "memory_engine", workspace_id=workspace_id) return func(*args, **kwargs) - return wrapper + + return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper def check_end_user_quota(func: Callable) -> Callable: @@ -329,26 +475,34 @@ def check_end_user_quota(func: Callable) -> Callable: async def async_wrapper(*args, **kwargs): db: Session = kwargs.get("db") if not db: - logger.warning("配额检查失败:缺少 db 参数") - return await func(*args, **kwargs) + logger.error(f"配额检查失败:{func.__name__} 缺少 db 参数,拒绝请求") + raise InternalServerError() tenant_id = _get_tenant_id_from_kwargs(db, kwargs) if not tenant_id: - logger.warning("配额检查失败:无法获取 tenant_id") - return await func(*args, **kwargs) - _check_quota(db, tenant_id, "end_user_quota", "end_user") + logger.error(f"配额检查失败:{func.__name__} 无法获取 tenant_id,拒绝请求") + raise InternalServerError() + workspace_id = _get_workspace_id_from_kwargs(kwargs) + if not workspace_id: + logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求") + raise InternalServerError() + _check_quota(db, tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id) return await func(*args, **kwargs) @wraps(func) def sync_wrapper(*args, **kwargs): db: Session = kwargs.get("db") if not db: - logger.warning("配额检查失败:缺少 db 参数") - return func(*args, **kwargs) + logger.error(f"配额检查失败:{func.__name__} 缺少 db 参数,拒绝请求") + raise InternalServerError() tenant_id = _get_tenant_id_from_kwargs(db, kwargs) if not tenant_id: - logger.warning("配额检查失败:无法获取 tenant_id") - return func(*args, **kwargs) - _check_quota(db, tenant_id, "end_user_quota", "end_user") + logger.error(f"配额检查失败:{func.__name__} 无法获取 tenant_id,拒绝请求") + raise InternalServerError() + workspace_id = _get_workspace_id_from_kwargs(kwargs) + if not workspace_id: + logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求") + raise InternalServerError() + _check_quota(db, tenant_id, "end_user_quota", "end_user", workspace_id=workspace_id) return func(*args, **kwargs) return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper @@ -356,89 +510,171 @@ def check_end_user_quota(func: Callable) -> Callable: def check_ontology_project_quota(func: Callable) -> Callable: @wraps(func) - def wrapper(*args, **kwargs): + async def async_wrapper(*args, **kwargs): db: Session = kwargs.get("db") user = _get_user_from_kwargs(kwargs) if not db or not user: - logger.warning("配额检查失败:缺少 db 或 user 参数") - return func(*args, **kwargs) - _check_quota(db, user.tenant_id, "ontology_project_quota", "ontology_project") + logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求") + raise InternalServerError() + workspace_id = _get_workspace_id_from_kwargs(kwargs) + if not workspace_id: + logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求") + raise InternalServerError() + _check_quota(db, user.tenant_id, "ontology_project_quota", "ontology_project", workspace_id=workspace_id) + return await func(*args, **kwargs) + + @wraps(func) + def sync_wrapper(*args, **kwargs): + db: Session = kwargs.get("db") + user = _get_user_from_kwargs(kwargs) + if not db or not user: + logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求") + raise InternalServerError() + workspace_id = _get_workspace_id_from_kwargs(kwargs) + if not workspace_id: + logger.error(f"配额检查失败:{func.__name__} 无法获取 workspace_id,拒绝请求") + raise InternalServerError() + _check_quota(db, user.tenant_id, "ontology_project_quota", "ontology_project", workspace_id=workspace_id) return func(*args, **kwargs) - return wrapper + + return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper def check_model_quota(func: Callable) -> Callable: @wraps(func) - def wrapper(*args, **kwargs): + async def async_wrapper(*args, **kwargs): db: Session = kwargs.get("db") user = _get_user_from_kwargs(kwargs) if not db or not user: - logger.warning("配额检查失败:缺少 db 或 user 参数") - return func(*args, **kwargs) + logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求") + raise InternalServerError() + _check_quota(db, user.tenant_id, "model_quota", "model") + return await func(*args, **kwargs) + + @wraps(func) + def sync_wrapper(*args, **kwargs): + db: Session = kwargs.get("db") + user = _get_user_from_kwargs(kwargs) + if not db or not user: + logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求") + raise InternalServerError() _check_quota(db, user.tenant_id, "model_quota", "model") return func(*args, **kwargs) - return wrapper + + return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper def check_model_activation_quota(func: Callable) -> Callable: """模型激活时的配额检查装饰器""" @wraps(func) - def wrapper(*args, **kwargs): + async def async_wrapper(*args, **kwargs): db: Session = kwargs.get("db") user = _get_user_from_kwargs(kwargs) if not db or not user: - logger.warning("配额检查失败:缺少 db 或 user 参数") - return func(*args, **kwargs) - + logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求") + raise InternalServerError() + model_id = kwargs.get("model_id") or (args[1] if len(args) > 1 else None) model_data = kwargs.get("model_data") - + if not model_id or not model_data: logger.warning("模型激活配额检查失败:缺少 model_id 或 model_data 参数") - return func(*args, **kwargs) - - if model_data.is_active is True: + return await func(*args, **kwargs) + + if model_data.is_active: try: - from app.models.models_model import ModelConfig from app.services.model_service import ModelConfigService - + existing_model = ModelConfigService.get_model_by_id( - db=db, - model_id=model_id, + db=db, + model_id=model_id, tenant_id=user.tenant_id ) - + if not existing_model.is_active: logger.info(f"模型激活操作,检查配额: model_id={model_id}, tenant_id={user.tenant_id}") _check_quota(db, user.tenant_id, "model_quota", "model") except Exception as e: logger.error(f"模型激活配额检查异常: model_id={model_id}, error={str(e)}") raise - + + return await func(*args, **kwargs) + + @wraps(func) + def sync_wrapper(*args, **kwargs): + db: Session = kwargs.get("db") + user = _get_user_from_kwargs(kwargs) + if not db or not user: + logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求") + raise InternalServerError() + + model_id = kwargs.get("model_id") or (args[1] if len(args) > 1 else None) + model_data = kwargs.get("model_data") + + if not model_id or not model_data: + logger.warning("模型激活配额检查失败:缺少 model_id 或 model_data 参数") + return func(*args, **kwargs) + + if model_data.is_active: + try: + from app.services.model_service import ModelConfigService + + existing_model = ModelConfigService.get_model_by_id( + db=db, + model_id=model_id, + tenant_id=user.tenant_id + ) + + if not existing_model.is_active: + logger.info(f"模型激活操作,检查配额: model_id={model_id}, tenant_id={user.tenant_id}") + _check_quota(db, user.tenant_id, "model_quota", "model") + except Exception as e: + logger.error(f"模型激活配额检查异常: model_id={model_id}, error={str(e)}") + raise + return func(*args, **kwargs) - return wrapper + + return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper def check_quota(quota_type: str, resource_name: str, usage_func: Optional[Callable] = None): """通用配额检查装饰器,支持自定义使用量获取函数""" def decorator(func: Callable) -> Callable: @wraps(func) - def wrapper(*args, **kwargs): + async def async_wrapper(*args, **kwargs): db: Session = kwargs.get("db") user = _get_user_from_kwargs(kwargs) if not db or not user: - logger.warning("配额检查失败:缺少 db 或 user 参数") - return func(*args, **kwargs) + logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求") + raise InternalServerError() + _check_quota(db, user.tenant_id, quota_type, resource_name, usage_func) + return await func(*args, **kwargs) + + @wraps(func) + def sync_wrapper(*args, **kwargs): + db: Session = kwargs.get("db") + user = _get_user_from_kwargs(kwargs) + if not db or not user: + logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求") + raise InternalServerError() _check_quota(db, user.tenant_id, quota_type, resource_name, usage_func) return func(*args, **kwargs) - return wrapper + + return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper return decorator # ─── 配额使用统计 ──────────────────────────────────────────────────────────── -def get_quota_usage(db: Session, tenant_id: UUID) -> dict: - """获取租户所有配额的使用情况""" +async def get_quota_usage(db: Session, tenant_id: UUID) -> dict: + """获取租户所有配额的使用情况 + + 对于 workspace 级别的配额(app/knowledge_capacity/memory_engine/end_user): + - used: 租户汇总(所有空间加总) + - limit: quota × 活跃工作区数(有效总限额,使汇总数据自洽) + - per_workspace: 各空间明细,包含 workspace_id、workspace_name、used、limit、percentage + - 配额检查逻辑不变:仍按单个空间独立检查 + """ quota_config = _get_quota_config(db, tenant_id) if not quota_config: return {} @@ -457,29 +693,99 @@ def get_quota_usage(db: Session, tenant_id: UUID) -> dict: model_count = repo.count_models(tenant_id) ontology_count = repo.count_ontology_projects(tenant_id) + # 获取租户下所有活跃工作区,用于按空间拆分明细 + from app.models.workspace_model import Workspace + active_workspaces = db.query(Workspace).filter( + Workspace.tenant_id == tenant_id, + Workspace.is_active.is_(True) + ).all() + + # 构建各空间的 workspace 级配额明细 + def _build_per_workspace_detail(count_func, per_unit_limit): + """为 workspace 级配额构建 per_workspace 明细列表""" + if not per_unit_limit or not active_workspaces: + return [] + details = [] + for ws in active_workspaces: + ws_used = count_func(tenant_id, ws.id) + details.append({ + "workspace_id": str(ws.id), + "workspace_name": ws.name, + "used": ws_used, + "limit": per_unit_limit, + "percentage": pct(ws_used, per_unit_limit), + }) + return details + + # workspace 级配额的每空间限额 + app_quota_per_ws = quota_config.get("app_quota") + knowledge_quota_per_ws = quota_config.get("knowledge_capacity_quota") + memory_quota_per_ws = quota_config.get("memory_engine_quota") + end_user_quota_per_ws = quota_config.get("end_user_quota") + ontology_quota_per_ws = quota_config.get("ontology_project_quota") + + # workspace 级配额的有效总限额 = 每空间限额 × 活跃工作区数 + app_effective_limit = app_quota_per_ws * workspace_count if app_quota_per_ws is not None and workspace_count > 0 else app_quota_per_ws + knowledge_effective_limit = knowledge_quota_per_ws * workspace_count if knowledge_quota_per_ws is not None and workspace_count > 0 else knowledge_quota_per_ws + memory_effective_limit = memory_quota_per_ws * workspace_count if memory_quota_per_ws is not None and workspace_count > 0 else memory_quota_per_ws + end_user_effective_limit = end_user_quota_per_ws * workspace_count if end_user_quota_per_ws is not None and workspace_count > 0 else end_user_quota_per_ws + ontology_effective_limit = ontology_quota_per_ws * workspace_count if ontology_quota_per_ws is not None and workspace_count > 0 else ontology_quota_per_ws + api_ops_current = 0 try: - from app.core.config import settings - import redis - _now = time.time() - _rk = f"rate_limit:tenant_qps:{tenant_id}" - _r = redis.StrictRedis( - host=settings.REDIS_HOST, port=settings.REDIS_PORT, - db=settings.REDIS_DB, password=settings.REDIS_PASSWORD, - decode_responses=True - ) - api_ops_current = int(_r.zcount(_rk, _now - 1, "+inf")) - except Exception: - pass + from app.aioRedis import aio_redis as _aio_redis + from app.models.api_key_model import ApiKey + # api_ops_rate_limit 限的是每个 api_key 每秒最高限额 + # 展示当前最接近触发限流的 key 的 QPS(取最大值) + api_key_ids = db.query(ApiKey.id).join( + Workspace, ApiKey.workspace_id == Workspace.id + ).filter( + Workspace.tenant_id == tenant_id, + ApiKey.is_active.is_(True) + ).all() + for (key_id,) in api_key_ids: + _rk = API_KEY_QPS_REDIS_KEY.format(api_key_id=key_id) + val = await _aio_redis.get(_rk) + count = int(val) if val else 0 + if count > api_ops_current: + api_ops_current = count + except Exception as e: + logger.warning(f"获取 api_ops_current 失败,返回 0: {type(e).__name__}: {e}") return { "workspace": {"used": workspace_count, "limit": quota_config.get("workspace_quota"), "percentage": pct(workspace_count, quota_config.get("workspace_quota"))}, "skill": {"used": skill_count, "limit": quota_config.get("skill_quota"), "percentage": pct(skill_count, quota_config.get("skill_quota"))}, - "app": {"used": app_count, "limit": quota_config.get("app_quota"), "percentage": pct(app_count, quota_config.get("app_quota"))}, - "knowledge_capacity": {"used": round(knowledge_gb, 2), "limit": quota_config.get("knowledge_capacity_quota"), "percentage": pct(knowledge_gb, quota_config.get("knowledge_capacity_quota")), "unit": "GB"}, - "memory_engine": {"used": memory_count, "limit": quota_config.get("memory_engine_quota"), "percentage": pct(memory_count, quota_config.get("memory_engine_quota"))}, - "end_user": {"used": end_user_count, "limit": quota_config.get("end_user_quota"), "percentage": pct(end_user_count, quota_config.get("end_user_quota"))}, - "ontology_project": {"used": ontology_count, "limit": quota_config.get("ontology_project_quota"), "percentage": pct(ontology_count, quota_config.get("ontology_project_quota"))}, + "app": { + "used": app_count, + "limit": app_effective_limit, + "percentage": pct(app_count, app_effective_limit), + "per_workspace": _build_per_workspace_detail(repo.count_apps, app_quota_per_ws), + }, + "knowledge_capacity": { + "used": round(knowledge_gb, 2), + "limit": knowledge_effective_limit, + "percentage": pct(knowledge_gb, knowledge_effective_limit), + "unit": "GB", + "per_workspace": _build_per_workspace_detail(repo.sum_knowledge_capacity_gb, knowledge_quota_per_ws), + }, + "memory_engine": { + "used": memory_count, + "limit": memory_effective_limit, + "percentage": pct(memory_count, memory_effective_limit), + "per_workspace": _build_per_workspace_detail(repo.count_memory_engines, memory_quota_per_ws), + }, + "end_user": { + "used": end_user_count, + "limit": end_user_effective_limit, + "percentage": pct(end_user_count, end_user_effective_limit), + "per_workspace": _build_per_workspace_detail(repo.count_end_users, end_user_quota_per_ws), + }, + "ontology_project": { + "used": ontology_count, + "limit": ontology_effective_limit, + "percentage": pct(ontology_count, ontology_effective_limit), + "per_workspace": _build_per_workspace_detail(repo.count_ontology_projects, ontology_quota_per_ws), + }, "model": {"used": model_count, "limit": quota_config.get("model_quota"), "percentage": pct(model_count, quota_config.get("model_quota"))}, "api_ops_rate_limit": {"current": api_ops_current, "limit": quota_config.get("api_ops_rate_limit"), "percentage": None, "unit": "次/秒"}, } diff --git a/api/app/core/quota_stub.py b/api/app/core/quota_stub.py index 577dfadb..248d0875 100644 --- a/api/app/core/quota_stub.py +++ b/api/app/core/quota_stub.py @@ -18,6 +18,7 @@ from app.core.quota_manager import ( get_quota_usage, _check_quota, QuotaUsageRepository, + API_KEY_QPS_REDIS_KEY, ) __all__ = [ @@ -33,4 +34,5 @@ __all__ = [ "get_quota_usage", "_check_quota", "QuotaUsageRepository", + "API_KEY_QPS_REDIS_KEY", ] diff --git a/api/app/core/rag/common/connection_utils.py b/api/app/core/rag/common/connection_utils.py index 349caa27..d5d0dc2a 100644 --- a/api/app/core/rag/common/connection_utils.py +++ b/api/app/core/rag/common/connection_utils.py @@ -33,18 +33,16 @@ def timeout(seconds: float | int | str = None, attempts: int = 2, *, exception: thread.daemon = True thread.start() + effective_timeout = seconds if seconds else 120 # 默认 120 秒超时 for a in range(attempts): try: - if os.environ.get("ENABLE_TIMEOUT_ASSERTION"): - result = result_queue.get(timeout=seconds) - else: - result = result_queue.get() + result = result_queue.get(timeout=effective_timeout) if isinstance(result, Exception): raise result return result except queue.Empty: pass - raise TimeoutError(f"Function '{func.__name__}' timed out after {seconds} seconds and {attempts} attempts.") + raise TimeoutError(f"Function '{func.__name__}' timed out after {effective_timeout} seconds and {attempts} attempts.") @wraps(func) async def async_wrapper(*args, **kwargs) -> Any: diff --git a/api/app/core/rag/nlp/search.py b/api/app/core/rag/nlp/search.py index 61540ee4..4b99986b 100644 --- a/api/app/core/rag/nlp/search.py +++ b/api/app/core/rag/nlp/search.py @@ -113,7 +113,7 @@ def knowledge_retrieval( continue # Use the specified reranker for re-ranking - if reranker_id: + if reranker_id and all_results: try: all_results = rerank(db=db, reranker_id=reranker_id, query=query, docs=all_results, top_k=reranker_top_k) except Exception as rerank_error: diff --git a/api/app/core/rag/utils/es_conn.py b/api/app/core/rag/utils/es_conn.py index 7fbf0e38..9a0edd24 100644 --- a/api/app/core/rag/utils/es_conn.py +++ b/api/app/core/rag/utils/es_conn.py @@ -68,9 +68,9 @@ class ESConnection(DocStoreConnection): client_config = { "hosts": [hosts], "basic_auth": (os.getenv("ELASTICSEARCH_USERNAME", "elastic"), os.getenv("ELASTICSEARCH_PASSWORD", "elastic")), - "request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 100000)), + "request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 30)), "retry_on_timeout": os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", True) == "true", - "max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 10000)), + "max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 3)), } # Only add SSL settings if using HTTPS diff --git a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py index 386920e0..cc9ec120 100644 --- a/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/app/core/rag/vdb/elasticsearch/elasticsearch_vector.py @@ -1,25 +1,22 @@ import os import logging -from typing import Any, cast +import threading +from typing import Any from urllib.parse import urlparse -import uuid import requests from elasticsearch import Elasticsearch, helpers from elasticsearch.helpers import BulkIndexError from packaging.version import parse as parse_version -from pydantic import BaseModel, model_validator -from abc import ABC # langchain-community # langchain-xinference # from langchain_community.embeddings import XinferenceEmbeddings # from langchain_xinference import XinferenceRerank from langchain_core.documents import Document from app.core.models.base import RedBearModelConfig -from app.core.models import RedBearLLM, RedBearRerank +from app.core.models import RedBearRerank from app.core.models.embedding import RedBearEmbeddings -from app.models.models_model import ModelConfig, ModelApiKey -from app.services.model_service import ModelConfigService +from app.models.models_model import ModelApiKey from app.models.knowledge_model import Knowledge from app.core.rag.vdb.field import Field @@ -29,37 +26,9 @@ from app.core.rag.models.chunk import DocumentChunk logger = logging.getLogger(__name__) -class ElasticSearchConfig(BaseModel): - # Regular Elasticsearch config - host: str | None = None - port: int | None = None - username: str | None = None - password: str | None = None - - # Common config - ca_certs: str | None = None - verify_certs: bool = False - request_timeout: int = 100000 - retry_on_timeout: bool = True - max_retries: int = 10000 - - @model_validator(mode="before") - @classmethod - def validate_config(cls, values: dict): - # Regular Elasticsearch validation - if not values.get("host"): - raise ValueError("config HOST is required for regular Elasticsearch") - if not values.get("port"): - raise ValueError("config PORT is required for regular Elasticsearch") - if not values.get("username"): - raise ValueError("config USERNAME is required for regular Elasticsearch") - if not values.get("password"): - raise ValueError("config PASSWORD is required for regular Elasticsearch") - return values - - class ElasticSearchVector(BaseVector): - def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey): + def __init__(self, index_name: str, client: Elasticsearch, + embedding_config: ModelApiKey, reranker_config: ModelApiKey): super().__init__(index_name.lower()) # 初始化 Embedding 模型(自动支持火山引擎多模态) @@ -77,58 +46,8 @@ class ElasticSearchVector(BaseVector): api_key=reranker_config.api_key, base_url=reranker_config.api_base )) - self._client = self._init_client(config) - self._version = self._get_version() - self._check_version() - - def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: - """ - Initialize Elasticsearch client for regular Elasticsearch. - """ - try: - # Regular Elasticsearch configuration - parsed_url = urlparse(config.host or "") - if parsed_url.scheme in {"http", "https"}: - hosts = f"{config.host}:{config.port}" - use_https = parsed_url.scheme == "https" - else: - hosts = f"https://{config.host}:{config.port}" - use_https = False - - client_config = { - "hosts": [hosts], - "basic_auth": (config.username, config.password), - "request_timeout": config.request_timeout, - "retry_on_timeout": config.retry_on_timeout, - "max_retries": config.max_retries, - } - - # Only add SSL settings if using HTTPS - if use_https: - client_config["verify_certs"] = config.verify_certs - if config.ca_certs: - client_config["ca_certs"] = config.ca_certs - - client = Elasticsearch(**client_config) - - # Test connection - if not client.ping(): - raise ConnectionError("Failed to connect to Elasticsearch") - - except requests.ConnectionError as e: - raise ConnectionError(f"Vector database connection error: {str(e)}") - except Exception as e: - raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}") - - return client - - def _get_version(self) -> str: - info = self._client.info() - return cast(str, info["version"]["number"]) - - def _check_version(self): - if parse_version(self._version) < parse_version("8.0.0"): - raise ValueError("Elasticsearch vector database version must be greater than 8.0.0") + # 使用外部传入的共享客户端 + self._client = client def get_type(self) -> str: return "elasticsearch" @@ -745,29 +664,79 @@ class ElasticSearchVector(BaseVector): class ElasticSearchVectorFactory: - @staticmethod - def init_vector(knowledge: Knowledge) -> ElasticSearchVector: + """ES 向量服务工厂 - 单例共享连接""" + + _client: Elasticsearch | None = None + _lock = threading.Lock() + _version_checked = False + + @classmethod + def _get_shared_client(cls) -> Elasticsearch: + """获取共享的 ES 客户端(线程安全的懒加载单例)""" + if cls._client is not None: + return cls._client + + with cls._lock: + # 双重检查,防止并发时重复创建 + if cls._client is not None: + return cls._client + + try: + parsed_url = urlparse(os.getenv("ELASTICSEARCH_HOST", "127.0.0.1") or "") + if parsed_url.scheme in {"http", "https"}: + hosts = f'{os.getenv("ELASTICSEARCH_HOST")}:{os.getenv("ELASTICSEARCH_PORT", 9200)}' + use_https = parsed_url.scheme == "https" + else: + hosts = f'https://{os.getenv("ELASTICSEARCH_HOST", "127.0.0.1")}:{os.getenv("ELASTICSEARCH_PORT", 9200)}' + use_https = False + + client_config = { + "hosts": [hosts], + "basic_auth": ( + os.getenv("ELASTICSEARCH_USERNAME", "elastic"), + os.getenv("ELASTICSEARCH_PASSWORD", "elastic"), + ), + "request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 30)), + "retry_on_timeout": True, + "max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 3)), + "connections_per_node": int(os.getenv("ELASTICSEARCH_CONNECTIONS_PER_NODE", 10)), + } + + if use_https: + client_config["verify_certs"] = os.getenv("ELASTICSEARCH_VERIFY_CERTS", "false") == "true" + ca_certs = os.getenv("ELASTICSEARCH_CA_CERTS") + if ca_certs: + client_config["ca_certs"] = str(ca_certs) + + client = Elasticsearch(**client_config) + + if not client.ping(): + raise ConnectionError("Failed to connect to Elasticsearch") + + # 版本检查只做一次 + if not cls._version_checked: + info = client.info() + version = info["version"]["number"] + if parse_version(version) < parse_version("8.0.0"): + raise ValueError(f"Elasticsearch version must be >= 8.0.0, got {version}") + cls._version_checked = True + logger.info(f"Elasticsearch shared client initialized, version: {version}") + + cls._client = client + + except requests.ConnectionError as e: + raise ConnectionError(f"Vector database connection error: {str(e)}") + except Exception as e: + raise ConnectionError(f"Elasticsearch client initialization failed: {str(e)}") + + return cls._client + + @classmethod + def init_vector(cls, knowledge: Knowledge) -> ElasticSearchVector: + """创建向量服务实例(共享 ES 连接)""" + client = cls._get_shared_client() collection_name = f"Vector_index_{knowledge.id}_Node" - # Use regular Elasticsearch with config values - config_dict = { - "host": os.getenv("ELASTICSEARCH_HOST", "127.0.0.1"), - "port": os.getenv("ELASTICSEARCH_PORT", 9200), - "username": os.getenv("ELASTICSEARCH_USERNAME", "elastic"), - "password": os.getenv("ELASTICSEARCH_PASSWORD", "elastic"), - } - - # Common configuration - config_dict.update( - { - "ca_certs": str(os.getenv("ELASTICSEARCH_CA_CERTS")) if os.getenv("ELASTICSEARCH_CA_CERTS") else None, - "verify_certs": os.getenv("ELASTICSEARCH_VERIFY_CERTS", False) == "true", - "request_timeout": int(os.getenv("ELASTICSEARCH_REQUEST_TIMEOUT", 100000)), - "retry_on_timeout": os.getenv("ELASTICSEARCH_RETRY_ON_TIMEOUT", True) == "true", - "max_retries": int(os.getenv("ELASTICSEARCH_MAX_RETRIES", 10000)), - } - ) - if knowledge.embedding is None: raise ValueError(f"embedding_id config error: {str(knowledge.embedding_id)}") if knowledge.reranker is None: @@ -775,9 +744,9 @@ class ElasticSearchVectorFactory: return ElasticSearchVector( index_name=collection_name, - config=ElasticSearchConfig(**config_dict), + client=client, embedding_config=knowledge.embedding.api_keys[0], - reranker_config=knowledge.reranker.api_keys[0] + reranker_config=knowledge.reranker.api_keys[0], ) diff --git a/api/app/core/workflow/nodes/tool/node.py b/api/app/core/workflow/nodes/tool/node.py index 410f64c3..07c384c1 100644 --- a/api/app/core/workflow/nodes/tool/node.py +++ b/api/app/core/workflow/nodes/tool/node.py @@ -11,6 +11,7 @@ from app.core.workflow.nodes.tool.config import ToolNodeConfig from app.core.workflow.variable.base_variable import VariableType from app.db import get_db_read from app.services.tool_service import ToolService +from app.models.tool_model import ToolType logger = logging.getLogger(__name__) @@ -76,6 +77,18 @@ class ToolNode(BaseNode): # 执行工具 with get_db_read() as db: tool_service = ToolService(db) + + # MCP 工具:将 operation 映射为 tool_name,其余参数包装进 arguments + tool_instance = tool_service.get_tool_instance(self.typed_config.tool_id, tenant_id) + if tool_instance and tool_instance.tool_type == ToolType.MCP: + operation = rendered_parameters.pop("operation", None) + if operation: + old_params = rendered_parameters + rendered_parameters = { + "tool_name": operation, + "arguments": old_params + } + result = await tool_service.execute_tool( tool_id=self.typed_config.tool_id, parameters=rendered_parameters, diff --git a/api/app/i18n/exceptions.py b/api/app/i18n/exceptions.py index b81369ed..93794c39 100644 --- a/api/app/i18n/exceptions.py +++ b/api/app/i18n/exceptions.py @@ -6,12 +6,14 @@ error messages based on the current request's language. """ import logging +import time from contextvars import ContextVar from typing import Any, Dict, Optional from fastapi import HTTPException, Request from app.i18n.service import get_translation_service +from app.core.error_codes import ERROR_CODE_TO_BIZ_CODE, BizCode logger = logging.getLogger(__name__) @@ -118,15 +120,24 @@ class I18nException(HTTPException): **params ) - # Build error detail - detail = { - "error_code": self.error_code, - "message": message, - } + # Convert error_code string to BizCode value + biz_code = ERROR_CODE_TO_BIZ_CODE.get( + self.error_code, + BizCode.BAD_REQUEST + ) - # Add parameters to detail if provided - if params: - detail["params"] = params + # Build error detail in standard format for compatibility + # main.py handler expects "message" and "error_code" fields for filtering + # but we also include standard format fields + detail = { + "code": biz_code.value, + "msg": message, + "message": message, + "error_code": self.error_code, + "data": params if params else {}, + "error": message, + "time": int(time.time() * 1000), + } # Initialize HTTPException super().__init__( @@ -482,14 +493,39 @@ class RateLimitExceededError(I18nException): ) -class QuotaExceededError(ForbiddenError): - """Quota exceeded error.""" +class QuotaExceededError(I18nException): + """Quota exceeded error (402).""" + + # resource key -> i18n display key + _RESOURCE_KEY_MAP = { + "workspace": "errors.quota_resources.workspace", + "app": "errors.quota_resources.app", + "skill": "errors.quota_resources.skill", + "knowledge_capacity": "errors.quota_resources.knowledge_capacity", + "memory_engine": "errors.quota_resources.memory_engine", + "end_user": "errors.quota_resources.end_user", + "model": "errors.quota_resources.model", + "ontology_project": "errors.quota_resources.ontology_project", + "api_ops_rate_limit": "errors.quota_resources.api_ops_rate_limit", + } def __init__(self, resource: Optional[str] = None, **params): + # Translate resource key to a localized display name before calling super() if resource: - params["resource"] = resource + resource_i18n_key = self._RESOURCE_KEY_MAP.get(resource) + if resource_i18n_key: + try: + from app.i18n.service import get_translation_service + from app.core.config import settings + _locale = _current_locale.get() or settings.I18N_DEFAULT_LANGUAGE + params["resource"] = get_translation_service().translate(resource_i18n_key, _locale) + except Exception: + params["resource"] = resource + else: + params["resource"] = resource super().__init__( error_key="errors.api.quota_exceeded", + status_code=402, error_code="QUOTA_EXCEEDED", **params ) diff --git a/api/app/locales/en/errors.json b/api/app/locales/en/errors.json index d0276dc9..2355954c 100644 --- a/api/app/locales/en/errors.json +++ b/api/app/locales/en/errors.json @@ -106,7 +106,7 @@ }, "api": { "rate_limit_exceeded": "API rate limit exceeded", - "quota_exceeded": "API quota exceeded", + "quota_exceeded": "{resource} quota exceeded", "invalid_api_key": "Invalid API key", "api_key_expired": "API key has expired", "api_key_revoked": "API key has been revoked", @@ -114,7 +114,8 @@ "method_not_allowed": "Method not allowed", "invalid_request": "Invalid request", "missing_parameter": "Missing required parameter: {param}", - "invalid_parameter": "Invalid parameter: {param}" + "invalid_parameter": "Invalid parameter: {param}", + "api_key_rate_limit_exceeded": "API Key rate limit ({rate_limit}) exceeds tenant plan limit ({limit})" }, "database": { "connection_failed": "Database connection failed", @@ -134,5 +135,16 @@ "invalid_format": "Invalid format: {field}", "invalid_value": "Invalid value: {field}", "out_of_range": "Value out of range: {field}" + }, + "quota_resources": { + "workspace": "Workspace", + "app": "App", + "skill": "Skill", + "knowledge_capacity": "Knowledge capacity", + "memory_engine": "Memory engine", + "end_user": "End user", + "model": "Model", + "ontology_project": "Ontology project", + "api_ops_rate_limit": "API ops rate limit" } } diff --git a/api/app/locales/zh/errors.json b/api/app/locales/zh/errors.json index eafadad4..8b7fdec0 100644 --- a/api/app/locales/zh/errors.json +++ b/api/app/locales/zh/errors.json @@ -106,7 +106,7 @@ }, "api": { "rate_limit_exceeded": "API调用频率超限", - "quota_exceeded": "API调用配额已用完", + "quota_exceeded": "{resource} 配额已超限", "invalid_api_key": "无效的API密钥", "api_key_expired": "API密钥已过期", "api_key_revoked": "API密钥已被撤销", @@ -114,7 +114,8 @@ "method_not_allowed": "不支持的请求方法", "invalid_request": "无效的请求", "missing_parameter": "缺少必需参数:{param}", - "invalid_parameter": "参数无效:{param}" + "invalid_parameter": "参数无效:{param}", + "api_key_rate_limit_exceeded": "API Key 的 QPS 限制({rate_limit})超过租户套餐上限({limit})" }, "database": { "connection_failed": "数据库连接失败", @@ -134,5 +135,16 @@ "invalid_format": "格式不正确:{field}", "invalid_value": "值无效:{field}", "out_of_range": "值超出范围:{field}" + }, + "quota_resources": { + "workspace": "工作空间", + "app": "应用", + "skill": "技能", + "knowledge_capacity": "知识库容量", + "memory_engine": "记忆引擎", + "end_user": "终端用户", + "model": "模型", + "ontology_project": "本体工程", + "api_ops_rate_limit": "API 操作速率" } } diff --git a/api/app/repositories/end_user_repository.py b/api/app/repositories/end_user_repository.py index aad80707..aba4034f 100644 --- a/api/app/repositories/end_user_repository.py +++ b/api/app/repositories/end_user_repository.py @@ -66,6 +66,17 @@ class EndUserRepository: db_logger.error(f"查询宿主 {end_user_id} 时出错: {str(e)}") raise + def get_end_user_by_other_id(self, workspace_id: uuid.UUID, other_id: str) -> Optional["EndUser"]: + """按 workspace_id + other_id 查找终端用户,不存在返回 None""" + return ( + self.db.query(EndUser) + .filter( + EndUser.workspace_id == workspace_id, + EndUser.other_id == other_id + ) + .first() + ) + def get_or_create_end_user( self, app_id: uuid.UUID, diff --git a/api/app/schemas/api_key_schema.py b/api/app/schemas/api_key_schema.py index c7ca1e55..37245aa6 100644 --- a/api/app/schemas/api_key_schema.py +++ b/api/app/schemas/api_key_schema.py @@ -15,8 +15,8 @@ class ApiKeyCreate(BaseModel): type: ApiKeyType = Field(..., description="API Key 类型") scopes: List[str] = Field(default_factory=list, description="权限范围列表") resource_id: Optional[uuid.UUID] = Field(None, description="关联资源ID") - rate_limit: Optional[int] = Field(100, ge=1, le=1000, description="QPS限制(请求/秒)") - daily_request_limit: Optional[int] = Field(10000, description="日请求限制", ge=1) + rate_limit: Optional[int] = Field(50, ge=1, le=1000, description="QPS限制(请求/秒)") + daily_request_limit: Optional[int] = Field(100000, description="日请求限制", ge=1) quota_limit: Optional[int] = Field(None, description="配额限制(总请求数)", ge=1) expires_at: Optional[datetime.datetime] = Field(None, description="过期时间") @@ -55,7 +55,7 @@ class ApiKeyUpdate(BaseModel): description: Optional[str] = Field(None, description="描述") scopes: Optional[List[str]] = Field(None, description="权限范围列表") rate_limit: Optional[int] = Field(None, description="速率限制(请求/分钟)", ge=1) - daily_request_limit: Optional[int] = Field(10000, description="每日请求数限制", ge=1) + daily_request_limit: Optional[int] = Field(100000, description="每日请求数限制", ge=1) quota_limit: Optional[int] = Field(None, description="配额限制(总请求数)", ge=1) is_active: Optional[bool] = Field(None, description="是否激活") expires_at: Optional[datetime.datetime] = Field(None, description="过期时间") diff --git a/api/app/services/api_key_service.py b/api/app/services/api_key_service.py index 07d55198..4856365a 100644 --- a/api/app/services/api_key_service.py +++ b/api/app/services/api_key_service.py @@ -51,6 +51,19 @@ class ApiKeyService: if existing: raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME) + # 若 rate_limit 超过租户套餐的 api_ops_rate_limit,直接报错 + from app.models.workspace_model import Workspace + from app.core.quota_manager import get_api_ops_rate_limit + + workspace = db.query(Workspace).filter(Workspace.id == workspace_id).first() + if workspace: + tenant_api_ops_limit = get_api_ops_rate_limit(db, workspace.tenant_id) + if tenant_api_ops_limit and data.rate_limit > tenant_api_ops_limit: + raise BusinessException( + f"API Key QPS 不能超过套餐上限 {tenant_api_ops_limit}", + BizCode.BAD_REQUEST + ) + # 生成 API Key api_key = generate_api_key(data.type) @@ -152,6 +165,20 @@ class ApiKeyService: if existing: raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME) + # 若 rate_limit 超过租户套餐的 api_ops_rate_limit,直接报错 + if data.rate_limit is not None: + from app.models.workspace_model import Workspace + from app.core.quota_manager import get_api_ops_rate_limit + + workspace = db.query(Workspace).filter(Workspace.id == workspace_id).first() + if workspace: + tenant_api_ops_limit = get_api_ops_rate_limit(db, workspace.tenant_id) + if tenant_api_ops_limit and data.rate_limit > tenant_api_ops_limit: + raise BusinessException( + f"API Key QPS 不能超过套餐上限 {tenant_api_ops_limit}", + BizCode.BAD_REQUEST + ) + update_data = data.model_dump(exclude_unset=True) ApiKeyRepository.update(db, api_key_id, update_data) db.commit() @@ -248,42 +275,14 @@ class RateLimiterService: def __init__(self): self.redis = aio_redis - async def check_tenant_rate_limit(self, tenant_id: uuid.UUID, limit: int) -> Tuple[bool, dict]: - """ - 按 tenant_id 做 1 秒滑动窗口限速,限制值来自套餐配额 api_ops_rate_limit - """ - now = time.time() - window_start = now - 1 # 1 秒窗口 - key = f"rate_limit:tenant_qps:{tenant_id}" - - async with self.redis.pipeline() as pipe: - # 清理 1 秒前的旧记录 - pipe.zremrangebyscore(key, 0, window_start) - # 加入当前请求(score=时间戳,member=时间戳+随机数保证唯一) - pipe.zadd(key, {f"{now}:{uuid.uuid4().hex}": now}) - # 统计窗口内请求数 - pipe.zcard(key) - # 设置 key 过期(2 秒后自动清理) - pipe.expire(key, 2) - results = await pipe.execute() - - current = results[2] - remaining = max(0, limit - current) - reset_time = int(now) + 1 - - return current <= limit, { - "limit": limit, - "remaining": remaining, - "reset": reset_time, - } - async def check_qps(self, api_key_id: uuid.UUID, limit: int) -> Tuple[bool, dict]: - """ - 检查QPS限制 + """检查QPS限制 + Returns: (is_allowed, rate_limit_info) """ key = f"rate_limit:qps:{api_key_id}" + async with self.redis.pipeline() as pipe: pipe.incr(key) pipe.expire(key, 1, nx=True) # 1 秒过期 @@ -295,8 +294,9 @@ class RateLimiterService: return current <= limit, { "limit": limit, + "current": current, "remaining": remaining, - "reset": reset_time + "reset": reset_time, } async def check_daily_requests( @@ -304,7 +304,9 @@ class RateLimiterService: api_key_id: uuid.UUID, limit: int ) -> Tuple[bool, dict]: - """检查日调用量限制""" + """检查日调用量限制。 + 使用原子 INCR,先写后判断,极低概率下允许轻微超限(并发场景下可接受)。 + """ today = datetime.now().strftime("%Y%m%d") key = f"rate_limit:daily:{api_key_id}:{today}" @@ -313,6 +315,7 @@ class RateLimiterService: hour=0, minute=0, second=0, microsecond=0 ) expire_seconds = int((tomorrow_0 - now).total_seconds()) + reset_time = int(tomorrow_0.timestamp()) async with self.redis.pipeline() as pipe: pipe.incr(key) @@ -320,36 +323,74 @@ class RateLimiterService: results = await pipe.execute() current = results[0] - remaining = max(0, limit - current) - reset_time = int(tomorrow_0.timestamp()) - return current <= limit, { + if current > limit: + return False, { + "limit": limit, + "remaining": 0, + "reset": reset_time, + } + + return True, { "limit": limit, - "remaining": remaining, - "reset": reset_time + "remaining": max(0, limit - current), + "reset": reset_time, } async def check_all_limits( self, - api_key: ApiKey + api_key: ApiKey, + db: Optional[Session] = None, ) -> Tuple[bool, str, dict]: """ - 检查所有限制 - Returns: - (is_allowed, error_message, rate_limit_headers) + 检查所有限制,按以下顺序: + 1. API Key QPS:取 api_key.rate_limit 与套餐 api_ops_rate_limit 的最小值作为限额 + 2. API Key 日调用量 """ - # Check QPS - qps_ok, qps_info = await self.check_qps( - api_key.id, - api_key.rate_limit - ) + # 1. 取套餐限额与 api_key 自身限额的最小值 + effective_limit = api_key.rate_limit + if db is not None: + try: + from app.models.workspace_model import Workspace + from app.core.quota_manager import get_api_ops_rate_limit + + cache_key = f"tenant_api_ops_limit:{api_key.workspace_id}" + cached = await self.redis.get(cache_key) + if cached is not None: + try: + tenant_limit = int(cached) if cached != "0" else None + except (ValueError, TypeError): + cached = None + tenant_limit = None + + if cached is None: + workspace = db.query(Workspace).filter(Workspace.id == api_key.workspace_id).first() + if workspace: + tenant_limit = get_api_ops_rate_limit(db, workspace.tenant_id) + await self.redis.set(cache_key, str(tenant_limit) if tenant_limit else "0", ex=60) + else: + tenant_limit = None + + if tenant_limit: + effective_limit = min(api_key.rate_limit, tenant_limit) + except Exception as e: + logger.warning(f"获取套餐限额失败,使用 api_key 自身限额: {e}") + + # 用最终有效限额做 QPS 检查 + qps_ok, qps_info = await self.check_qps(api_key.id, effective_limit) if not qps_ok: - return False, "QPS limit exceeded", { + # 判断是套餐限额触发还是 api_key 自身限额触发 + if tenant_limit and effective_limit == tenant_limit and api_key.rate_limit > tenant_limit: + error_msg = "Tenant limit exceeded" + else: + error_msg = "QPS limit exceeded" + return False, error_msg, { "X-RateLimit-Limit-QPS": str(qps_info["limit"]), "X-RateLimit-Remaining-QPS": str(qps_info["remaining"]), "X-RateLimit-Reset": str(qps_info["reset"]) } + # 2. 检查日调用量 daily_ok, daily_info = await self.check_daily_requests( api_key.id, api_key.daily_request_limit @@ -361,14 +402,13 @@ class RateLimiterService: "X-RateLimit-Reset": str(daily_info["reset"]) } - headers = { + return True, "", { "X-RateLimit-Limit-QPS": str(qps_info["limit"]), "X-RateLimit-Remaining-QPS": str(qps_info["remaining"]), "X-RateLimit-Limit-Day": str(daily_info["limit"]), "X-RateLimit-Remaining-Day": str(daily_info["remaining"]), - "X-RateLimit-Reset": str(daily_info["reset"]) + "X-RateLimit-Reset": str(daily_info["reset"]), } - return True, "", headers class ApiKeyAuthService: diff --git a/api/app/services/app_dsl_service.py b/api/app/services/app_dsl_service.py index 26e4098c..63279d2c 100644 --- a/api/app/services/app_dsl_service.py +++ b/api/app/services/app_dsl_service.py @@ -434,19 +434,37 @@ class AppDslService: def _resolve_model(self, ref: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> Optional[uuid.UUID]: if not ref: return None - q = self.db.query(ModelConfig).filter( - ModelConfig.tenant_id == tenant_id, - ModelConfig.name == ref.get("name"), - ModelConfig.is_active.is_(True) - ) - if ref.get("provider"): - q = q.filter(ModelConfig.provider == ref["provider"]) - if ref.get("type"): - q = q.filter(ModelConfig.type == ref["type"]) - m = q.first() - if not m: - warnings.append(f"模型 '{ref.get('name')}' 未匹配,已置空,请导入后手动配置") - return m.id if m else None + model_id = ref.get("id") + if model_id: + try: + model_uuid = uuid.UUID(str(model_id)) + m = self.db.query(ModelConfig).filter( + ModelConfig.id == model_uuid, + ModelConfig.tenant_id == tenant_id, + ModelConfig.is_active.is_(True) + ).first() + if m: + return str(m.id) + except (ValueError, AttributeError): + pass + model_name = ref.get("name") + if model_name: + q = self.db.query(ModelConfig).filter( + ModelConfig.tenant_id == tenant_id, + ModelConfig.name == model_name, + ModelConfig.is_active.is_(True) + ) + if ref.get("provider"): + q = q.filter(ModelConfig.provider == ref["provider"]) + if ref.get("type"): + q = q.filter(ModelConfig.type == ref["type"]) + m = q.first() + if m: + return str(m.id) + warnings.append(f"模型 '{model_name}' 未匹配,已置空,请导入后手动配置") + else: + warnings.append(f"模型 ID '{model_id}' 未匹配,已置空,请导入后手动配置") + return None def _resolve_kb(self, ref: Optional[dict], workspace_id: uuid.UUID, warnings: list) -> Optional[str]: if not ref: @@ -587,7 +605,7 @@ class AppDslService: if not kb_id: continue kb_ref = {} - if isinstance(kb_id, str) and len(kb_id) >= 36: + if isinstance(kb_id, str): try: uuid.UUID(kb_id) kb_ref["id"] = kb_id @@ -601,6 +619,33 @@ class AppDslService: else: warnings.append(f"[{node_label}] 知识库 '{kb_id}' 未匹配,已移除,请导入后手动配置") config["knowledge_bases"] = resolved_kbs + elif node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value): + model_ref = config.get("model_id") + if model_ref: + ref_dict = None + if isinstance(model_ref, dict): + ref_id = model_ref.get("id") + ref_name = model_ref.get("name") + if ref_id: + ref_dict = {"id": ref_id} + elif ref_name is not None: + ref_dict = {"name": ref_name, "provider": model_ref.get("provider"), "type": model_ref.get("type")} + elif isinstance(model_ref, str): + try: + uuid.UUID(model_ref) + ref_dict = {"id": model_ref} + except ValueError: + ref_dict = {"name": model_ref} + if ref_dict: + resolved_model_id = self._resolve_model(ref_dict, tenant_id, warnings) + if resolved_model_id: + config["model_id"] = resolved_model_id + else: + warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置") + config["model_id"] = None + else: + warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置") + config["model_id"] = None resolved_nodes.append({**node, "config": config}) return resolved_nodes diff --git a/api/app/services/knowledge_service.py b/api/app/services/knowledge_service.py index 94653db8..20757307 100644 --- a/api/app/services/knowledge_service.py +++ b/api/app/services/knowledge_service.py @@ -7,7 +7,6 @@ from app.models.models_model import ModelConfig from app.schemas.knowledge_schema import KnowledgeCreate, KnowledgeUpdate from app.repositories import knowledge_repository from app.core.logging_config import get_business_logger -from app.repositories.model_repository import ModelConfigRepository from app.models.models_model import ModelType business_logger = get_business_logger() @@ -78,41 +77,31 @@ def create_knowledge( tenant_id = workspace.tenant_id if not knowledge.embedding_id: - embedding_models = ModelConfigRepository.get_by_type( - db=db, model_types=[ModelType.EMBEDDING], tenant_id=tenant_id, is_active=True - ) - if embedding_models: - knowledge.embedding_id = embedding_models[0].id - business_logger.debug(f"Auto-bind embedding model: {embedding_models[0].id}") + if not workspace.embedding: + raise Exception("工作空间未配置 Embedding 模型,请先完善工作空间配置后重试") + knowledge.embedding_id = workspace.embedding if not knowledge.reranker_id: - rerank_models = ModelConfigRepository.get_by_type( - db=db, model_types=[ModelType.RERANK], tenant_id=tenant_id, is_active=True - ) - if rerank_models: - knowledge.reranker_id = rerank_models[0].id - business_logger.debug(f"Auto-bind rerank model: {rerank_models[0].id}") + if not workspace.rerank: + raise Exception("工作空间未配置 Rerank 模型,请先完善工作空间配置后重试") + knowledge.reranker_id = workspace.rerank if not knowledge.llm_id: - llm_models = ModelConfigRepository.get_by_type( - db=db, model_types=[ModelType.LLM, ModelType.CHAT], tenant_id=tenant_id, is_active=True - ) - if llm_models: - knowledge.llm_id = llm_models[0].id - business_logger.debug(f"Auto-bind llm model: {llm_models[0].id}") + if not workspace.llm: + raise Exception("工作空间未配置 LLM 模型,请先完善工作空间配置后重试") + knowledge.llm_id = workspace.llm if not knowledge.image2text_id: - image2text_models = db.query(ModelConfig).filter( + model = db.query(ModelConfig).filter( ModelConfig.tenant_id == tenant_id, - ModelConfig.type.in_([ModelType.CHAT.value]), + ModelConfig.type.in_([ModelType.CHAT.value, ModelType.LLM.value]), ModelConfig.capability.contains(["vision"]), ModelConfig.is_active == True, - ModelConfig.is_composite == False - ).order_by(ModelConfig.created_at.desc()).all() - if not image2text_models: + ).order_by(ModelConfig.created_at.desc()).first() + if not model: raise Exception("租户下没有可用的视觉模型,创建知识库失败") - knowledge.image2text_id = image2text_models[0].id - business_logger.debug(f"Auto-bind image2text model: {image2text_models[0].id}") + knowledge.image2text_id = model.id + business_logger.debug(f"Auto-bind image2text model: {model.id}") business_logger.debug(f"Start creating the knowledge base: {knowledge.name}") db_knowledge = knowledge_repository.create_knowledge( diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 8a221094..4ccb6bcd 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -1282,7 +1282,7 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An } logger.info( - f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}") + f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={end_user.workspace_id}") return result diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index 8807020b..72e46f4a 100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -125,11 +125,7 @@ class ModelConfigService: api_key=api_key, base_url=api_base, is_omni=is_omni, - capability=capability, - extra_params={ - "temperature": 0.7, - "max_tokens": 100 - } + capability=capability ) # 根据模型类型选择不同的验证方式 @@ -373,6 +369,15 @@ class ModelConfigService: raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id) + + # 同步更新关联 api_keys 的 capability 和 is_omni + if model_data.capability is not None or model_data.is_omni is not None: + for api_key in model.api_keys: + if model_data.capability is not None: + api_key.capability = model_data.capability + if model_data.is_omni is not None: + api_key.is_omni = model_data.is_omni + db.commit() db.refresh(model) return model diff --git a/api/app/tasks.py b/api/app/tasks.py index 8bbbdc6e..92843175 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -251,8 +251,40 @@ def parse_document(file_path: str, document_id: uuid.UUID): # Prepare vision_model for parsing vision_model = _build_vision_model(file_path, db_knowledge) + # 先将文件读入内存,避免解析过程中依赖 NFS 文件持续可访问 + # python-docx 等库在 binary=None 时会用路径直接打开文件, + # 在 NFS/共享存储上可能因缓存失效导致 "Package not found" + max_wait_seconds = 30 + wait_interval = 2 + waited = 0 + file_binary = None + while waited <= max_wait_seconds: + # os.listdir 强制 NFS 客户端刷新目录缓存 + parent_dir = os.path.dirname(file_path) + try: + os.listdir(parent_dir) + except OSError: + pass + try: + with open(file_path, "rb") as f: + file_binary = f.read() + if not file_binary: + # NFS 上文件存在但内容为空(可能还在同步中) + raise IOError(f"File is empty (0 bytes), NFS may still be syncing: {file_path}") + break + except (FileNotFoundError, IOError) as e: + if waited >= max_wait_seconds: + raise type(e)( + f"File not accessible at '{file_path}' after waiting {max_wait_seconds}s: {e}" + ) + logger.warning(f"File not ready on this node, retrying in {wait_interval}s: {file_path} ({e})") + time.sleep(wait_interval) + waited += wait_interval + from app.core.rag.app.naive import chunk + logger.info(f"[ParseDoc] file_binary size={len(file_binary)} bytes, type={type(file_binary).__name__}, bool={bool(file_binary)}") res = chunk(filename=file_path, + binary=file_binary, from_page=0, to_page=DEFAULT_PARSE_TO_PAGE, callback=progress_callback, diff --git a/web/src/components/ModelSelect/index.tsx b/web/src/components/ModelSelect/index.tsx index 85977376..4c59c87b 100644 --- a/web/src/components/ModelSelect/index.tsx +++ b/web/src/components/ModelSelect/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-03-07 16:49:59 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-04-17 10:11:54 + * @Last Modified time: 2026-04-20 18:14:34 */ import { type FC, useEffect, useState } from 'react'; import { Select, Flex, Space } from 'antd'; @@ -56,7 +56,7 @@ const ModelSelect: FC = ({ params, placeholder, fontClassName, useEffect(() => { if (updateOptions) updateOptions([...options, ...initialData]); - }, [options, initialData]) + }, [JSON.stringify(options), JSON.stringify(initialData)]) return ( - + + + + } {['llm', 'chat'].includes(modelType as string) && diff --git a/web/src/views/Package/constant.ts b/web/src/views/Package/constant.ts index 7fc69969..168b65f8 100644 --- a/web/src/views/Package/constant.ts +++ b/web/src/views/Package/constant.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-04-14 11:43:57 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-04-14 11:44:40 + * @Last Modified time: 2026-04-21 15:44:13 */ export const billingUnits = [ { @@ -42,7 +42,7 @@ export const billingUnits = [ }, { key: 'model_quota', - unit: 'ops', placeholder: 'numberPlaceholder', + unit: 'pcs', placeholder: 'numberPlaceholder', icon: 'model', }, { diff --git a/web/src/views/Package/index.tsx b/web/src/views/Package/index.tsx index b3aed9d5..ea9f1ec5 100644 --- a/web/src/views/Package/index.tsx +++ b/web/src/views/Package/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-04-14 11:34:42 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-04-16 17:23:49 + * @Last Modified time: 2026-04-21 15:45:30 */ /** * Package Component @@ -60,7 +60,7 @@ const btnClassNames = { default: 'rb:h-10! rb:rounded-[8px]! rb:bg-[#212332]! rb:text-white! rb:border-0! rb:hover:border-0! rb:hover:opacity-[0.8]', } -export const UnitWrapper = ({ titleKey, value, icon, unit, theme_color = '#171719' }: { titleKey: string; value: number | string; icon: string; unit?: string; theme_color?: string; }) => { +export const UnitWrapper = ({ titleKey, value, icon, unit, theme_color = '#171719' }: { titleKey: string; value?: number | string | null; icon: string; unit?: string; theme_color?: string; }) => { const { t } = useTranslation(); const renderFeatureIcon = (iconKey: string, color: string) => { @@ -78,7 +78,7 @@ export const UnitWrapper = ({ titleKey, value, icon, unit, theme_color = '#17171 >{renderFeatureIcon(icon, theme_color)}
{t(`package.${titleKey}`)}
-
{value} {unit ? t(`package.${unit}`) : ''}
+ {value ?
{value} {unit ? t(`package.${unit}`) : ''}
:
{t('package.noLimit')}
}
) @@ -252,7 +252,6 @@ const Package: FC = () => { > {billingUnits.map(({ key, unit, icon }) => { const value = pkg?.quotas?.[key as keyof Package['quotas']]; - if (value === undefined || value === null) return null; return ( { /> ) })} - {pkg.tech_support && ( + {pkg.tech_support && pkg[getKeyWithLanguage('tech_support')] && ( { theme_color={pkg.theme_color} /> )} - {pkg.sla_compliance && ( + {pkg.sla_compliance && pkg[getKeyWithLanguage('sla_compliance')] && ( { + if (!value) return '-' + if (Array.isArray(value)) { + return value.length ? value.join(' | ') : '-' + } + return value +} const EndUserProfile = forwardRef(({ className, onDataLoaded }, ref) => { const { t } = useTranslation() const { id } = useParams() @@ -89,19 +96,19 @@ const EndUserProfile = forwardRef(({ cla
{t('userMemory.role')}
-
{data?.profile?.role?.join(' | ') || '-'}
+
{formatValue(data?.profile?.role)}
{t('userMemory.domain')}
-
{data?.profile?.domain?.join(' | ') || '-'}
+
{formatValue(data?.profile?.domain)}
{t('userMemory.expertise')}
-
{data?.profile?.expertise?.join(' | ') || '-'}
+
{formatValue(data?.profile?.expertise)}
{t('userMemory.interests')}
-
{data?.profile?.interests?.join(' | ') || '-'}
+
{formatValue(data?.profile?.interests)}
diff --git a/web/src/views/UserMemoryDetail/types.ts b/web/src/views/UserMemoryDetail/types.ts index 667d8272..8efbd890 100644 --- a/web/src/views/UserMemoryDetail/types.ts +++ b/web/src/views/UserMemoryDetail/types.ts @@ -178,7 +178,7 @@ export interface EndUser { created_at: string; updated_at: string; profile: { - role: string[]; + role: string[] | string; domain: string[]; expertise: string[]; interests: string[]; diff --git a/web/src/views/Workflow/components/CheckList/index.tsx b/web/src/views/Workflow/components/CheckList/index.tsx index 4afdb863..5ba13212 100644 --- a/web/src/views/Workflow/components/CheckList/index.tsx +++ b/web/src/views/Workflow/components/CheckList/index.tsx @@ -1,3 +1,9 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-04-09 18:58:21 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-04-20 10:39:17 + */ import { useState, useCallback, useEffect, useRef, type FC } from 'react' import { Popover, Flex } from 'antd' import { WarningFilled } from '@ant-design/icons' @@ -49,7 +55,7 @@ const specialValidators: Record boolean> = { if (expr?.sub_variable_condition?.conditions?.length > 0) return expr.sub_variable_condition?.conditions.every(isSubExprSet) if (!expr.left) return false if (['not_empty', 'empty'].includes(expr.operator)) return true - return !!expr.left && (!!expr.right || typeof expr.right === 'boolean' || typeof expr.right === 'number') + return !!expr.left && (expr?.sub_variable_condition || !!expr.right || typeof expr.right === 'boolean' || typeof expr.right === 'number') } return val.some(c => !c?.expressions?.length || c.expressions.some((expr: any) => !isExprSet(expr))) }, @@ -100,6 +106,18 @@ function validateNode(type: string, config: Record): CheckError[] { if (isInvalid) errors.push({ key: specialKey, message: '' }) }) + // llm: vision_input required when vision is enabled + if (type === 'llm') { + const vision = get('vision') + if (vision === true || vision === 'true') { + const visionInput = get('vision_input') + console.log('vision', vision, isEmpty(visionInput)) + if (isEmpty(visionInput)) { + errors.push({ key: 'llm.vision_input', message: '' }) + } + } + } + // http-request body.data (binary) — not a top-level required field, check separately if (type === 'http-request') { const body = get('body') diff --git a/web/src/views/Workflow/components/Nodes/ConditionNode.tsx b/web/src/views/Workflow/components/Nodes/ConditionNode.tsx index d2264901..19966823 100644 --- a/web/src/views/Workflow/components/Nodes/ConditionNode.tsx +++ b/web/src/views/Workflow/components/Nodes/ConditionNode.tsx @@ -10,7 +10,7 @@ import { useVariableList } from '../Properties/hooks/useVariableList' import { isSubExprSet } from '../../utils' import { fileSubFieldOperators } from '../Properties/CaseList' -const caculateIsSet = (item: any, type: string) => { +const calculateIsSet = (item: any, type: string) => { switch (type) { case 'categories': return typeof item?.class_name === 'string' && item?.class_name !== '' @@ -90,7 +90,7 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => {
{t('workflow.config.question-classifier.class_name')} {index + 1} - {caculateIsSet(item, 'categories') ? t(`workflow.config.${data.type}.set`) : t(`workflow.config.${data.type}.unset`)} + {calculateIsSet(item, 'categories') ? t(`workflow.config.${data.type}.set`) : t(`workflow.config.${data.type}.unset`)}
))} @@ -100,17 +100,24 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => { {data.config?.cases?.defaultValue.map((item: any, index: number) => (
0 ? '' : 'rb:mb-1'}> - 0 ? "space-between" : 'end'} className="rb:mb-1"> - {item.expressions.length > 0 && CASE{index + 1}} + 0 ? "space-between" : 'end'} className="rb:mb-1! rb:leading-4"> + {item.expressions.length > 0 && CASE{index + 1}} {index === 0 ? 'IF' : `ELIF`} {item.expressions.length > 0 && {item.expressions.map((expression: any, eIndex: number) => (
- {item.expressions.length > 1 && eIndex > 0 &&
{item.logical_operator?.toLocaleUpperCase()}
} - + {item.expressions.length > 1 && eIndex > 0 && +
{item.logical_operator?.toLocaleUpperCase()}
+ } + 0, + 'rb:py-1!': !expression.sub_variable_condition?.conditions || !expression.sub_variable_condition?.conditions?.length + })} + > - {caculateIsSet(expression, 'cases') + {calculateIsSet(expression, 'cases') ? <> {labelRender(expression.left)} {getLocaleField(expression.operator, typeof expression.right)} @@ -120,11 +127,16 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => { } {expression.sub_variable_condition?.conditions?.length > 0 && expression.sub_variable_condition?.conditions.every(isSubExprSet) - ?
+ ?
{expression.sub_variable_condition?.conditions.map((sub: any, sIndex: number) => (
{expression.sub_variable_condition?.conditions.length > 1 && sIndex > 0 &&
{expression.sub_variable_condition?.logical_operator?.toLocaleUpperCase()}
} - + {sub.key} {getSubLocaleField(sub.operator, sub.key)} @@ -140,7 +152,7 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => { ))}
: expression.sub_variable_condition?.conditions?.length > 0 - ? + ? {t(`workflow.config.${data.type}.unset`)} : null diff --git a/web/src/views/Workflow/components/Properties/CaseList/index.tsx b/web/src/views/Workflow/components/Properties/CaseList/index.tsx index 2fd24628..a9da1457 100644 --- a/web/src/views/Workflow/components/Properties/CaseList/index.tsx +++ b/web/src/views/Workflow/components/Properties/CaseList/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-09 18:24:53 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-04-17 20:47:49 + * @Last Modified time: 2026-04-20 10:46:05 */ import { useEffect, useMemo, type FC } from 'react' import clsx from 'clsx' @@ -39,7 +39,7 @@ interface Expression { sub_variable_condition?: SubVariableCondition; } -interface CaseItem { +export interface CaseItem { logical_operator: 'and' | 'or'; expressions: Expression[]; } @@ -274,7 +274,9 @@ const ArrayFileSubConditions: FC = ({ conditionFiel className="rb:w-full!" suffix="Byte" size="small" - onChange={(value) => { form.setFieldValue([name, caseIndex, 'expressions', conditionIndex, 'right'], value); }} + onChange={(value) => { + form.setFieldValue([name, caseIndex, 'expressions', conditionIndex, 'sub_variable_condition', 'conditions', subIndex, 'value'], value); + }} /> } @@ -483,13 +485,24 @@ const CaseList: FC = ({ form.setFieldValue([name, index, 'logical_operator'], currentValue === 'and' ? 'or' : 'and'); }; - const handleLeftFieldChange = (caseIndex: number, conditionIndex: number, newValue: string) => { - form.setFieldValue([name, caseIndex, 'expressions', conditionIndex], { - left: newValue, - operator: undefined, - right: undefined, - input_type: 'constant' - }); + const handleLeftFieldChange = (caseIndex: number, conditionIndex: number, newValue: string, option?: Suggestion | undefined) => { + if (option?.dataType === 'array[file]') { + form.setFieldValue([name, caseIndex, 'expressions', conditionIndex], { + left: newValue, + operator: undefined, + sub_variable_condition: { + conditions: [], + logical_operator: 'and' + } + }); + } else { + form.setFieldValue([name, caseIndex, 'expressions', conditionIndex], { + left: newValue, + operator: undefined, + right: undefined, + input_type: 'constant' + }); + } }; const handleAddCase = (addCaseFunc: Function) => { @@ -590,7 +603,7 @@ const CaseList: FC = ({ options={options} size="small" allowClear={false} - onChange={(val) => handleLeftFieldChange(caseIndex, conditionIndex, val as string)} + onChange={(val, option) => handleLeftFieldChange(caseIndex, conditionIndex, val as string, option as unknown as Suggestion)} variant="borderless" className="rb:w-36!" /> diff --git a/web/src/views/Workflow/components/Properties/Knowledge/Knowledge.tsx b/web/src/views/Workflow/components/Properties/Knowledge/Knowledge.tsx index a1973beb..b9c45eef 100644 --- a/web/src/views/Workflow/components/Properties/Knowledge/Knowledge.tsx +++ b/web/src/views/Workflow/components/Properties/Knowledge/Knowledge.tsx @@ -29,12 +29,13 @@ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfi if (value && JSON.stringify(value) !== JSON.stringify(editConfig)) { setEditConfig({ ...(value || {}) }) const knowledge_bases = [...(value.knowledge_bases || [])] + setKnowledgeList(knowledge_bases) // 检查是否有knowledge_bases缺少name字段 const basesWithoutName = knowledge_bases.filter(base => !base.name) if (basesWithoutName.length > 0) { // 调用接口获取完整的知识库信息 - getKnowledgeBaseList().then(res => { + getKnowledgeBaseList(undefined, { kb_ids: basesWithoutName.map(vo => vo.kb_id).join(',') }).then(res => { const fullBases = knowledge_bases.map(base => { if (!base.name) { const fullBase = res.items.find((item: any) => item.id === base.kb_id) diff --git a/web/src/views/Workflow/components/Properties/ToolConfig/index.tsx b/web/src/views/Workflow/components/Properties/ToolConfig/index.tsx index ce30ee8f..6e8bd0c0 100644 --- a/web/src/views/Workflow/components/Properties/ToolConfig/index.tsx +++ b/web/src/views/Workflow/components/Properties/ToolConfig/index.tsx @@ -1,4 +1,4 @@ -import { type FC, useEffect, useState, useMemo } from "react"; +import { type FC, useEffect, useState } from "react"; import { useTranslation } from 'react-i18next' import { Form, Select, Switch, Cascader, type CascaderProps, Tooltip } from 'antd' import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin' @@ -45,15 +45,15 @@ const ToolConfig: FC<{ options: Suggestion[]; }> = ({ getToolDetail(values.tool_id) .then(res => { const detail = res as { tool_type: ToolType; } - + getTools({ tool_type: detail.tool_type }) .then(toolsRes => { const tools = toolsRes as ToolItem[] - + getToolMethods(values.tool_id) .then(methodsRes => { const response = methodsRes as Array<{ method_id: string; name: string; parameters: Parameter[] }> - + setOptionList(prevList => { return prevList.map(item => { if (item.value === detail.tool_type) { @@ -76,7 +76,7 @@ const ToolConfig: FC<{ options: Suggestion[]; }> = ({ return item }) }) - + if (response.length > 1) { const filterTarget = response.find(vo => vo.name === values.tool_parameters?.operation) if (filterTarget) { @@ -98,7 +98,7 @@ const ToolConfig: FC<{ options: Suggestion[]; }> = ({ useEffect(() => { if (values.tools && values.tools.length === 3) { const [toolType, toolId, operation] = values.tools - + // 从 optionList 中查找对应的参数 const typeOption = optionList.find(opt => opt.value === toolType) if (typeOption?.children) { @@ -147,21 +147,26 @@ const ToolConfig: FC<{ options: Suggestion[]; }> = ({ }; const handleChange: CascaderProps