feat(workflow): support multimodal input

This commit is contained in:
Eternity
2026-03-05 14:07:27 +08:00
parent a72d5d2c77
commit 78ce2a9a8b
9 changed files with 84 additions and 38 deletions

View File

@@ -303,30 +303,44 @@ class VariablePool:
""" """
return self._get_variable_struct(selector) is not None return self._get_variable_struct(selector) is not None
def get_all_system_vars(self, literal: bool = False) -> dict[str, Any]:
    """Return every variable stored in the ``sys`` namespace.

    Args:
        literal: When true, each value is produced by the variable
            instance's ``to_literal()``; otherwise by ``get_value()``.

    Returns:
        Mapping of system variable name to its value. Empty when the pool
        has no ``sys`` namespace.
    """
    sys_namespace = self.variables.get("sys", {})
    if literal:
        return {name: var.instance.to_literal() for name, var in sys_namespace.items()}
    return {name: var.instance.get_value() for name, var in sys_namespace.items()}
def get_all_conversation_vars(self, literal: bool = False) -> dict[str, Any]:
    """Return every variable stored in the ``conv`` namespace.

    Args:
        literal: When true, each value is produced by the variable
            instance's ``to_literal()``; otherwise by ``get_value()``.

    Returns:
        Mapping of conversation variable name to its value. This is a newly
        built dict, so mutating it does not change the pool. Empty when the
        pool has no ``conv`` namespace.
    """
    conv_namespace = self.variables.get("conv", {})
    if literal:
        return {name: var.instance.to_literal() for name, var in conv_namespace.items()}
    return {name: var.instance.get_value() for name, var in conv_namespace.items()}
def get_all_node_outputs(self) -> dict[str, Any]: def get_all_node_outputs(self, literal=False) -> dict[str, Any]:
"""获取所有节点输出(运行时变量) """获取所有节点输出(运行时变量)
Returns: Returns:
节点输出字典,键为节点 ID 节点输出字典,键为节点 ID
""" """
if literal:
runtime_vars = {
namespace: {
k: v.instance.to_literal()
for k, v in vars_dict.items()
}
for namespace, vars_dict in self.variables.items()
if namespace not in ("sys", "conv")
}
else:
runtime_vars = { runtime_vars = {
namespace: { namespace: {
k: v.instance.get_value() k: v.instance.get_value()

View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
import logging import logging
import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import cached_property from functools import cached_property
from typing import Any, AsyncGenerator from typing import Any, AsyncGenerator
@@ -10,8 +11,10 @@ from app.core.config import settings
from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.enums import BRANCH_NODES from app.core.workflow.nodes.enums import BRANCH_NODES
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType, FileObject
from app.services.multimodal_service import PROVIDER_STRATEGIES from app.db import get_db_read
from app.schemas import FileInput
from app.services.multimodal_service import MultimodalService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -548,9 +551,9 @@ class BaseNode(ABC):
return render_template( return render_template(
template=template, template=template,
conv_vars=variable_pool.get_all_conversation_vars(), conv_vars=variable_pool.get_all_conversation_vars(literal=True),
node_outputs=variable_pool.get_all_node_outputs(), node_outputs=variable_pool.get_all_node_outputs(literal=True),
system_vars=variable_pool.get_all_system_vars(), system_vars=variable_pool.get_all_system_vars(literal=True),
strict=strict strict=strict
) )
@@ -614,16 +617,32 @@ class BaseNode(ABC):
return variable_pool.has(selector) return variable_pool.has(selector)
@staticmethod @staticmethod
async def process_message(provider, content, enable_file=False) -> dict | str | None: async def process_message(provider: str, content: str | FileObject, enable_file=False) -> dict | str | None:
if isinstance(content, str): if isinstance(content, str):
if enable_file: if enable_file:
return {"text": content} return {"text": content}
return content return content
elif isinstance(content, dict):
trans_tool = PROVIDER_STRATEGIES[provider]() elif isinstance(content, FileObject):
result = await trans_tool.format_image(content["url"]) if content.content_cache.get(provider):
return result return content.content_cache[provider]
raise TypeError('Unexpect input value type') with get_db_read() as db:
multimodel_service = MultimodalService(db, provider)
message = await multimodel_service.process_files(
[FileInput.model_construct(
type=content.type,
url=content.url,
transfer_method=content.transfer_method,
file_type=content.origin_file_type,
upload_file_id=content.file_id
)]
)
if message:
content.content_cache[provider] = message[0]
return message[0]
return None
raise TypeError(f'Unexpect input value type - {type(content)}')
@staticmethod @staticmethod
def process_model_output(content) -> str: def process_model_output(content) -> str:

View File

@@ -91,8 +91,8 @@ class IterationRuntime:
return loopstate return loopstate
def merge_conv_vars(self):
    """Merge the child pool's conversation variables into the parent pool.

    The parent's ``conv`` namespace is updated in place, so entries written
    by the child iteration overwrite same-named entries in the parent.
    A pool without a ``conv`` namespace is tolerated: nothing is merged when
    the child has none, and the parent's namespace is created on demand.
    """
    child_conv = self.child_variable_pool.variables.get("conv")
    if not child_conv:
        # Nothing to merge; leave the parent pool untouched.
        return
    self.variable_pool.variables.setdefault("conv", {}).update(child_conv)
async def run_task(self, item, idx): async def run_task(self, item, idx):

View File

@@ -156,7 +156,7 @@ class LoopRuntime:
def merge_conv_vars(self, loopstate): def merge_conv_vars(self, loopstate):
self.variable_pool.variables["conv"].update( self.variable_pool.variables["conv"].update(
self.child_variable_pool.variables.get("conv", {}) self.child_variable_pool.variables["conv"]
) )
loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False) loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
loopstate["node_outputs"][self.node_id] = loop_vars loopstate["node_outputs"][self.node_id] = loop_vars

View File

@@ -172,9 +172,9 @@ class LLMNode(BaseNode):
if self.typed_config.vision_input and self.typed_config.vision: if self.typed_config.vision_input and self.typed_config.vision:
file_content = [] file_content = []
files = variable_pool.get_value(self.typed_config.vision_input) files = variable_pool.get_instance(self.typed_config.vision_input)
for file in files: for file in files.value:
content = await self.process_message(provider, file, self.typed_config.vision) content = await self.process_message(provider, file.value, self.typed_config.vision)
if content: if content:
file_content.append(content) file_content.append(content)
if messages and messages[-1]["role"] == 'user': if messages and messages[-1]["role"] == 'user':

View File

@@ -2,7 +2,7 @@ from enum import StrEnum
from abc import abstractmethod, ABC from abc import abstractmethod, ABC
from typing import Any from typing import Any
from pydantic import BaseModel from pydantic import BaseModel, Field
from app.schemas import FileType from app.schemas import FileType
@@ -45,7 +45,7 @@ class VariableType(StrEnum):
return cls.NUMBER return cls.NUMBER
elif isinstance(var, bool): elif isinstance(var, bool):
return cls.BOOLEAN return cls.BOOLEAN
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('__file')): elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')):
return cls.FILE return cls.FILE
elif isinstance(var, dict): elif isinstance(var, dict):
return cls.OBJECT return cls.OBJECT
@@ -109,7 +109,13 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any:
class FileObject(BaseModel):
    """Value model for a file carried through the workflow variable system."""

    # Kind of file, as a project-defined FileType.
    type: FileType
    # Location the file content can be fetched from.
    url: str
    # How the file was supplied; copied from the incoming file input.
    transfer_method: str
    # File type reported by the original file input.
    origin_file_type: str
    # Identifier of the uploaded file, when there is one.
    file_id: str | None = None
    # Converted message payloads keyed by provider, so a file is only
    # converted once per provider.
    content_cache: dict = Field(default_factory=dict)
    # Marker that lets a plain dict be recognised as a serialised file.
    is_file: bool
class BaseVariable(ABC): class BaseVariable(ABC):

View File

@@ -63,13 +63,16 @@ class FileVariable(BaseVariable):
def valid_value(self, value) -> FileObject: def valid_value(self, value) -> FileObject:
if isinstance(value, dict): if isinstance(value, dict):
if not value.get("__file"): if not value.get("is_file"):
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}") raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
return FileObject( return FileObject(
**{ **{
"type": str(value.get('type')), "type": str(value.get('type')),
"transfer_method": value.get("transfer_method"),
"url": value.get('url'), "url": value.get('url'),
"__file": True "file_id": value.get("file_id"),
"origin_file_type": value.get("origin_file_type"),
"is_file": True
} }
) )
if isinstance(value, FileObject): if isinstance(value, FileObject):

View File

@@ -255,7 +255,7 @@ class AgentRunService:
@staticmethod @staticmethod
def prepare_variables( def prepare_variables(
input_vars: dict | None, input_vars: dict | None,
variables_config: dict | None variables_config: dict
) -> dict: ) -> dict:
input_vars = input_vars or {} input_vars = input_vars or {}
for variable in variables_config: for variable in variables_config:

View File

@@ -16,6 +16,7 @@ from app.core.workflow.adapters.registry import PlatformAdapterRegistry
from app.core.workflow.executor import execute_workflow, execute_workflow_stream from app.core.workflow.executor import execute_workflow, execute_workflow_stream
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.validator import validate_workflow_config from app.core.workflow.validator import validate_workflow_config
from app.core.workflow.variable.base_variable import FileObject
from app.db import get_db from app.db import get_db
from app.models import App from app.models import App
from app.models.workflow_model import WorkflowConfig, WorkflowExecution from app.models.workflow_model import WorkflowConfig, WorkflowExecution
@@ -453,11 +454,14 @@ class WorkflowService:
files_struct = [] files_struct = []
for file in files: for file in files:
files_struct.append( files_struct.append(
{ FileObject(
"type": file.type, type=file.type,
"url": await self.multimodal_service.get_file_url(file), url=await self.multimodal_service.get_file_url(file),
"__file": True transfer_method=file.transfer_method,
} file_id=str(file.upload_file_id),
origin_file_type=file.file_type,
is_file=True
).model_dump()
) )
return files_struct return files_struct