From 0ea47ce8901bf05e2776893c24823e1881a7c2a4 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Wed, 18 Mar 2026 16:20:18 +0800 Subject: [PATCH] feat(workflow): add configurable workflow feature options --- .../engine/stream_output_coordinator.py | 13 ++-- api/app/models/workflow_model.py | 1 + api/app/schemas/workflow_schema.py | 7 ++ api/app/services/app_service.py | 5 +- api/app/services/workflow_service.py | 76 ++++++++++++++++++- 5 files changed, 95 insertions(+), 7 deletions(-) diff --git a/api/app/core/workflow/engine/stream_output_coordinator.py b/api/app/core/workflow/engine/stream_output_coordinator.py index c2885ab0..ddee9adc 100644 --- a/api/app/core/workflow/engine/stream_output_coordinator.py +++ b/api/app/core/workflow/engine/stream_output_coordinator.py @@ -5,7 +5,7 @@ import re from typing import AsyncGenerator -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, PrivateAttr from app.core.logging_config import get_logger from app.core.workflow.engine.variable_pool import VariablePool @@ -52,10 +52,11 @@ class OutputContent(BaseModel): ) ) - _SCOPE: str | None = None + _SCOPE: str | None = PrivateAttr(default=None) - def get_scope(self) -> str: - self._SCOPE = SCOPE_PATTERN.findall(self.literal)[0] + def get_scope(self) -> str | None: + matches = SCOPE_PATTERN.findall(self.literal) + self._SCOPE = matches[0] if matches else None return self._SCOPE def depends_on_scope(self, scope: str) -> bool: @@ -68,6 +69,8 @@ class OutputContent(BaseModel): Returns: bool: True if this segment references the given scope. """ + if not self.is_variable: + return False if self._SCOPE: return self._SCOPE == scope return self.get_scope() == scope @@ -152,7 +155,7 @@ class StreamOutputConfig(BaseModel): """ # Case 1: resolve control branch dependency - if scope in self.control_nodes.keys(): + if scope in self.control_nodes: if status is None: raise RuntimeError("[Stream Output] Control node activation status not provided") if status in self.control_nodes[scope]: diff --git a/api/app/models/workflow_model.py b/api/app/models/workflow_model.py index 4f9ffe68..29fe5369 100644 --- a/api/app/models/workflow_model.py +++ b/api/app/models/workflow_model.py @@ -35,6 +35,7 @@ class WorkflowConfig(Base): # 执行配置 execution_config = Column(JSONB, nullable=False, default=dict) + features = Column(JSONB, nullable=True, default=dict) # 触发器配置(可选) triggers = Column(JSONB, default=list) diff --git a/api/app/schemas/workflow_schema.py b/api/app/schemas/workflow_schema.py index e580833f..d878d97c 100644 --- a/api/app/schemas/workflow_schema.py +++ b/api/app/schemas/workflow_schema.py @@ -80,6 +80,7 @@ class WorkflowConfigCreate(BaseModel): variables: list[VariableDefinition] = Field(default_factory=list, description="变量列表") execution_config: ExecutionConfig = Field(default_factory=ExecutionConfig, description="执行配置") triggers: list[TriggerConfig] = Field(default_factory=list, description="触发器列表") + features: dict = Field(default_factory=dict, description="功能特性配置") class WorkflowConfigUpdate(BaseModel): @@ -87,6 +88,7 @@ class WorkflowConfigUpdate(BaseModel): nodes: list[NodeDefinition] | None = None edges: list[EdgeDefinition] | None = None variables: list[VariableDefinition] | None = None + features: dict | None = None execution_config: ExecutionConfig | None = None triggers: list[TriggerConfig] | None = None @@ -102,6 +104,7 @@ class WorkflowConfig(BaseModel): variables: list[dict[str, Any]] execution_config: dict[str, Any] triggers: list[dict[str, Any]] + features: dict | None is_active: bool created_at: datetime.datetime updated_at: datetime.datetime @@ -114,6 +117,10 @@ class WorkflowConfig(BaseModel): def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None + @field_serializer("features", when_used="json") + def _serialize_features(self, features: dict | None): + return features or {} + # ==================== 工作流执行 ==================== diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 5ef34da8..98fdf6c9 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -1609,6 +1609,7 @@ class AppService: variables=[var.model_dump() for var in data.variables] if data.variables else [], execution_config=data.execution_config.model_dump() if data.execution_config else {}, triggers=[trigger.model_dump() for trigger in data.triggers] if data.triggers else [], + features=data.features or {}, is_active=True, created_at=now, updated_at=now @@ -1622,6 +1623,7 @@ class AppService: workflow_cfg.variables = [var.model_dump() for var in data.variables] if data.variables else [] workflow_cfg.execution_config = data.execution_config.model_dump() if data.execution_config else {} workflow_cfg.triggers = [trigger.model_dump() for trigger in data.triggers] if data.triggers else [] + workflow_cfg.features = data.features or {} workflow_cfg.updated_at = now self.db.commit() @@ -1875,7 +1877,8 @@ class AppService: "edges": workflow_cfg.edges, "variables": workflow_cfg.variables, "execution_config": workflow_cfg.execution_config, - "triggers": workflow_cfg.triggers + "triggers": workflow_cfg.triggers, + "features": workflow_cfg.features or {} } is_valid, errors = WorkflowValidator.validate_for_publish(config) diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 7aca3c2f..9f421976 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -570,6 +570,9 @@ class WorkflowService: message=f"工作流配置不存在: app_id={app_id}" ) + feature_configs = config.features or {} + self._validate_file_upload(feature_configs, payload.files) + input_data = { "message": payload.message, "variables": payload.variables, "conversation_id": payload.conversation_id, @@ -737,6 +740,8 @@ class WorkflowService: code=BizCode.CONFIG_MISSING, message=f"工作流配置不存在: app_id={app_id}" ) + feature_configs = config.features or {} + self._validate_file_upload(feature_configs, payload.files) input_data = { "message": payload.message, "variables": payload.variables, @@ -845,7 +850,10 @@ class WorkflowService: yield event except Exception as e: - logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True) + logger.error( + f"Workflow streaming execution failed: execution_id={execution.execution_id}, error={e}", + exc_info=True + ) self.update_execution_status( execution.execution_id, "failed", @@ -876,6 +884,72 @@ class WorkflowService: return True return False + @staticmethod + def _validate_file_upload( + features_config: dict[str, Any], + files: Optional[list[FileInput]] + ) -> None: + """校验上传文件是否符合 file_upload 配置""" + if not files: + return + fu = features_config.get("file_upload") + if fu is None: + return + if not (isinstance(fu, dict) and fu.get("enabled")): + raise BusinessException( + "The application does not have file upload functionality enabled", + BizCode.BAD_REQUEST + ) + max_count = fu.get("max_file_count", 5) + if len(files) > max_count: + raise BusinessException( + f"File count exceeds limit (maximum {max_count} files)", + BizCode.BAD_REQUEST + ) + + # 校验传输方式 + allowed_methods = fu.get("allowed_transfer_methods", ["local_file", "remote_url"]) + for f in files: + if f.transfer_method.value not in allowed_methods: + raise BusinessException( + f"Unsupport file transfer method:{f.transfer_method.value}," + f"allowed method:{', '.join(allowed_methods)}", + BizCode.BAD_REQUEST + ) + + # 各类型对应的开关和大小限制配置键 + type_cfg = { + "image": ("image_enabled", "image_max_size_mb", 20, "image"), + "audio": ("audio_enabled", "audio_max_size_mb", 50, "audio"), + "document": ("document_enabled", "document_max_size_mb", 100, "document"), + "video": ("video_enabled", "video_max_size_mb", 500, "video"), + } + + for f in files: + ftype = str(f.type) # 如 "image", "audio", "document", "video" + cfg = type_cfg.get(ftype) + if cfg is None: + continue + enabled_key, size_key, default_max_mb, label = cfg + + # 校验类型开关 + if not fu.get(enabled_key): + raise BusinessException( + f"The application has not enabled {label} file upload", + BizCode.BAD_REQUEST + ) + + # 校验文件大小(仅当内容已加载时) + content = f.get_content() + if content is not None: + max_mb = fu.get(size_key, default_max_mb) + size_mb = len(content) / (1024 * 1024) + if size_mb > max_mb: + raise BusinessException( + f"{label} File size exceeds the limit (maximum {max_mb} MB, current {size_mb:.1f} MB)", + BizCode.BAD_REQUEST + ) + # ==================== 依赖注入函数 ====================