From 78ce2a9a8b557082c454fba875165db39ae05749 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Thu, 5 Mar 2026 14:07:27 +0800 Subject: [PATCH] feat(workflow): support multimodal input --- api/app/core/workflow/engine/variable_pool.py | 34 ++++++++++----- api/app/core/workflow/nodes/base_node.py | 41 ++++++++++++++----- .../workflow/nodes/cycle_graph/iteration.py | 4 +- .../core/workflow/nodes/cycle_graph/loop.py | 2 +- api/app/core/workflow/nodes/llm/node.py | 6 +-- .../core/workflow/variable/base_variable.py | 12 ++++-- .../workflow/variable/variable_objects.py | 7 +++- api/app/services/draft_run_service.py | 2 +- api/app/services/workflow_service.py | 14 ++++--- 9 files changed, 84 insertions(+), 38 deletions(-) diff --git a/api/app/core/workflow/engine/variable_pool.py b/api/app/core/workflow/engine/variable_pool.py index d08f47e5..bc88df19 100644 --- a/api/app/core/workflow/engine/variable_pool.py +++ b/api/app/core/workflow/engine/variable_pool.py @@ -303,38 +303,52 @@ class VariablePool: """ return self._get_variable_struct(selector) is not None - def get_all_system_vars(self) -> dict[str, Any]: + def get_all_system_vars(self, literal=False) -> dict[str, Any]: """获取所有系统变量 Returns: 系统变量字典 """ sys_namespace = self.variables.get("sys", {}) + if literal: + return {k: v.instance.to_literal() for k, v in sys_namespace.items()} return {k: v.instance.get_value() for k, v in sys_namespace.items()} - def get_all_conversation_vars(self) -> dict[str, Any]: + def get_all_conversation_vars(self, literal=False) -> dict[str, Any]: """获取所有会话变量 Returns: 会话变量字典 """ conv_namespace = self.variables.get("conv", {}) + if literal: + return {k: v.instance.to_literal() for k, v in conv_namespace.items()} return {k: v.instance.get_value() for k, v in conv_namespace.items()} - def get_all_node_outputs(self) -> dict[str, Any]: + def get_all_node_outputs(self, literal=False) -> dict[str, Any]: """获取所有节点输出(运行时变量) Returns: 节点输出字典,键为节点 ID """ - runtime_vars = { - namespace: { - k: v.instance.get_value() - for k, v in vars_dict.items() + if literal: + runtime_vars = { + namespace: { + k: v.instance.to_literal() + for k, v in vars_dict.items() + } + for namespace, vars_dict in self.variables.items() + if namespace not in ("sys", "conv") + } + else: + runtime_vars = { + namespace: { + k: v.instance.get_value() + for k, v in vars_dict.items() + } + for namespace, vars_dict in self.variables.items() + if namespace not in ("sys", "conv") } - for namespace, vars_dict in self.variables.items() - if namespace not in ("sys", "conv") - } return runtime_vars def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None: diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 3e30c00e..3f30718c 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -1,5 +1,6 @@ import asyncio import logging +import uuid from abc import ABC, abstractmethod from functools import cached_property from typing import Any, AsyncGenerator @@ -10,8 +11,10 @@ from app.core.config import settings from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.enums import BRANCH_NODES -from app.core.workflow.variable.base_variable import VariableType -from app.services.multimodal_service import PROVIDER_STRATEGIES +from app.core.workflow.variable.base_variable import VariableType, FileObject +from app.db import get_db_read +from app.schemas import FileInput +from app.services.multimodal_service import MultimodalService logger = logging.getLogger(__name__) @@ -548,9 +551,9 @@ class BaseNode(ABC): return render_template( template=template, - conv_vars=variable_pool.get_all_conversation_vars(), - node_outputs=variable_pool.get_all_node_outputs(), - system_vars=variable_pool.get_all_system_vars(), + conv_vars=variable_pool.get_all_conversation_vars(literal=True), + node_outputs=variable_pool.get_all_node_outputs(literal=True), + system_vars=variable_pool.get_all_system_vars(literal=True), strict=strict ) @@ -614,16 +617,32 @@ class BaseNode(ABC): return variable_pool.has(selector) @staticmethod - async def process_message(provider, content, enable_file=False) -> dict | str | None: + async def process_message(provider: str, content: str | FileObject, enable_file=False) -> dict | str | None: if isinstance(content, str): if enable_file: return {"text": content} return content - elif isinstance(content, dict): - trans_tool = PROVIDER_STRATEGIES[provider]() - result = await trans_tool.format_image(content["url"]) - return result - raise TypeError('Unexpect input value type') + + elif isinstance(content, FileObject): + if content.content_cache.get(provider): + return content.content_cache[provider] + with get_db_read() as db: + multimodel_service = MultimodalService(db, provider) + message = await multimodel_service.process_files( + [FileInput.model_construct( + type=content.type, + url=content.url, + transfer_method=content.transfer_method, + file_type=content.origin_file_type, + upload_file_id=content.file_id + )] + ) + + if message: + content.content_cache[provider] = message[0] + return message[0] + return None + raise TypeError(f'Unexpect input value type - {type(content)}') @staticmethod def process_model_output(content) -> str: diff --git a/api/app/core/workflow/nodes/cycle_graph/iteration.py b/api/app/core/workflow/nodes/cycle_graph/iteration.py index e4026f2d..cf7ac976 100644 --- a/api/app/core/workflow/nodes/cycle_graph/iteration.py +++ b/api/app/core/workflow/nodes/cycle_graph/iteration.py @@ -91,8 +91,8 @@ class IterationRuntime: return loopstate def merge_conv_vars(self): - self.variable_pool.get_all_conversation_vars().update( - self.child_variable_pool.get_all_conversation_vars() + self.variable_pool.variables["conv"].update( + self.child_variable_pool.variables["conv"] ) async def run_task(self, item, idx): diff --git a/api/app/core/workflow/nodes/cycle_graph/loop.py b/api/app/core/workflow/nodes/cycle_graph/loop.py index cebadfdc..d3ada1ec 100644 --- a/api/app/core/workflow/nodes/cycle_graph/loop.py +++ b/api/app/core/workflow/nodes/cycle_graph/loop.py @@ -156,7 +156,7 @@ class LoopRuntime: def merge_conv_vars(self, loopstate): self.variable_pool.variables["conv"].update( - self.child_variable_pool.variables.get("conv", {}) + self.child_variable_pool.variables["conv"] ) loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False) loopstate["node_outputs"][self.node_id] = loop_vars diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index fdd5df58..c109d59b 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -172,9 +172,9 @@ class LLMNode(BaseNode): if self.typed_config.vision_input and self.typed_config.vision: file_content = [] - files = variable_pool.get_value(self.typed_config.vision_input) - for file in files: - content = await self.process_message(provider, file, self.typed_config.vision) + files = variable_pool.get_instance(self.typed_config.vision_input) + for file in files.value: + content = await self.process_message(provider, file.value, self.typed_config.vision) if content: file_content.append(content) if messages and messages[-1]["role"] == 'user': diff --git a/api/app/core/workflow/variable/base_variable.py b/api/app/core/workflow/variable/base_variable.py index 19cbdc74..dd821ea7 100644 --- a/api/app/core/workflow/variable/base_variable.py +++ b/api/app/core/workflow/variable/base_variable.py @@ -2,7 +2,7 @@ from enum import StrEnum from abc import abstractmethod, ABC from typing import Any -from pydantic import BaseModel +from pydantic import BaseModel, Field from app.schemas import FileType @@ -45,7 +45,7 @@ class VariableType(StrEnum): return cls.NUMBER elif isinstance(var, bool): return cls.BOOLEAN - elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('__file')): + elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')): return cls.FILE elif isinstance(var, dict): return cls.OBJECT @@ -109,7 +109,13 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any: class FileObject(BaseModel): type: FileType url: str - __file: bool + transfer_method: str + origin_file_type: str + file_id: str | None + + content_cache: dict = Field(default_factory=dict) + + is_file: bool class BaseVariable(ABC): diff --git a/api/app/core/workflow/variable/variable_objects.py b/api/app/core/workflow/variable/variable_objects.py index 49541afc..63437fd9 100644 --- a/api/app/core/workflow/variable/variable_objects.py +++ b/api/app/core/workflow/variable/variable_objects.py @@ -63,13 +63,16 @@ class FileVariable(BaseVariable): def valid_value(self, value) -> FileObject: if isinstance(value, dict): - if not value.get("__file"): + if not value.get("is_file"): raise TypeError(f"Value must be a FileObject - {type(value)}:{value}") return FileObject( **{ "type": str(value.get('type')), + "transfer_method": value.get("transfer_method"), "url": value.get('url'), - "__file": True + "file_id": value.get("file_id"), + "origin_file_type": value.get("origin_file_type"), + "is_file": True } ) if isinstance(value, FileObject): diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 0cf68be2..bb68c815 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -255,7 +255,7 @@ class AgentRunService: @staticmethod def prepare_variables( input_vars: dict | None, - variables_config: dict | None + variables_config: dict ) -> dict: input_vars = input_vars or {} for variable in variables_config: diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 02819efb..d13e3454 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -16,6 +16,7 @@ from app.core.workflow.adapters.registry import PlatformAdapterRegistry from app.core.workflow.executor import execute_workflow, execute_workflow_stream from app.core.workflow.nodes.enums import NodeType from app.core.workflow.validator import validate_workflow_config +from app.core.workflow.variable.base_variable import FileObject from app.db import get_db from app.models import App from app.models.workflow_model import WorkflowConfig, WorkflowExecution @@ -453,11 +454,14 @@ class WorkflowService: files_struct = [] for file in files: files_struct.append( - { - "type": file.type, - "url": await self.multimodal_service.get_file_url(file), - "__file": True - } + FileObject( + type=file.type, + url=await self.multimodal_service.get_file_url(file), + transfer_method=file.transfer_method, + file_id=str(file.upload_file_id), + origin_file_type=file.file_type, + is_file=True + ).model_dump() ) return files_struct