feat(workflow): support multimodal input

This commit is contained in:
Eternity
2026-03-05 14:07:27 +08:00
parent a72d5d2c77
commit 78ce2a9a8b
9 changed files with 84 additions and 38 deletions

View File

@@ -303,38 +303,52 @@ class VariablePool:
"""
return self._get_variable_struct(selector) is not None
def get_all_system_vars(self, literal: bool = False) -> dict[str, Any]:
    """Return every variable stored in the ``sys`` namespace.

    Args:
        literal: When true, each entry is produced by the variable
            instance's ``to_literal()``; otherwise by ``get_value()``.

    Returns:
        A new dict mapping variable name to the extracted value. Empty when
        the pool has no ``sys`` namespace.
    """
    sys_namespace = self.variables.get("sys", {})
    if literal:
        return {name: var.instance.to_literal() for name, var in sys_namespace.items()}
    return {name: var.instance.get_value() for name, var in sys_namespace.items()}
def get_all_conversation_vars(self, literal: bool = False) -> dict[str, Any]:
    """Return every variable stored in the ``conv`` (conversation) namespace.

    Args:
        literal: When true, each entry is produced by the variable
            instance's ``to_literal()``; otherwise by ``get_value()``.

    Returns:
        A new dict mapping variable name to the extracted value. Empty when
        the pool has no ``conv`` namespace. Because the dict is freshly
        built, mutating it does not affect the pool.
    """
    conv_namespace = self.variables.get("conv", {})
    if literal:
        return {name: var.instance.to_literal() for name, var in conv_namespace.items()}
    return {name: var.instance.get_value() for name, var in conv_namespace.items()}
def get_all_node_outputs(self, literal: bool = False) -> dict[str, Any]:
    """Return the node outputs (runtime variables) held by the pool.

    Every namespace other than ``sys`` and ``conv`` is treated as a node
    namespace.

    Args:
        literal: When true, each entry is produced by the variable
            instance's ``to_literal()``; otherwise by ``get_value()``.

    Returns:
        A new dict keyed by node ID (the namespace name); each value is a
        dict mapping variable name to the extracted value.
    """
    # Choose the accessor once so the nested comprehension is written a
    # single time instead of being duplicated per branch.
    accessor = "to_literal" if literal else "get_value"
    return {
        namespace: {
            name: getattr(var.instance, accessor)()
            for name, var in vars_dict.items()
        }
        for namespace, vars_dict in self.variables.items()
        if namespace not in ("sys", "conv")
    }
def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None:

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
import uuid
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Any, AsyncGenerator
@@ -10,8 +11,10 @@ from app.core.config import settings
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.enums import BRANCH_NODES
from app.core.workflow.variable.base_variable import VariableType
from app.services.multimodal_service import PROVIDER_STRATEGIES
from app.core.workflow.variable.base_variable import VariableType, FileObject
from app.db import get_db_read
from app.schemas import FileInput
from app.services.multimodal_service import MultimodalService
logger = logging.getLogger(__name__)
@@ -548,9 +551,9 @@ class BaseNode(ABC):
return render_template(
template=template,
conv_vars=variable_pool.get_all_conversation_vars(),
node_outputs=variable_pool.get_all_node_outputs(),
system_vars=variable_pool.get_all_system_vars(),
conv_vars=variable_pool.get_all_conversation_vars(literal=True),
node_outputs=variable_pool.get_all_node_outputs(literal=True),
system_vars=variable_pool.get_all_system_vars(literal=True),
strict=strict
)
@@ -614,16 +617,32 @@ class BaseNode(ABC):
return variable_pool.has(selector)
@staticmethod
async def process_message(
        provider: str,
        content: str | FileObject,
        enable_file: bool = False
) -> dict | str | None:
    """Convert one piece of input content into a message part for a provider.

    Args:
        provider: Provider name; passed to ``MultimodalService`` and used as
            the key into the file's ``content_cache``.
        content: Plain text, or a ``FileObject`` describing an input file.
        enable_file: Applies to text content only. When true the text is
            wrapped as ``{"text": content}`` instead of being returned as a
            bare string.

    Returns:
        The text (bare or wrapped); for a file, the first message produced
        by ``MultimodalService.process_files`` (served from the file's cache
        when present); ``None`` when that call yields nothing.

    Raises:
        TypeError: ``content`` is neither ``str`` nor ``FileObject``.
    """
    if isinstance(content, str):
        return {"text": content} if enable_file else content

    if isinstance(content, FileObject):
        # Reuse the message already built for this provider, if any.
        cached = content.content_cache.get(provider)
        if cached:
            return cached

        with get_db_read() as db:
            multimodal_service = MultimodalService(db, provider)
            messages = await multimodal_service.process_files(
                [FileInput.model_construct(
                    type=content.type,
                    url=content.url,
                    transfer_method=content.transfer_method,
                    file_type=content.origin_file_type,
                    upload_file_id=content.file_id
                )]
            )
            if messages:
                # Exactly one file was submitted, so only the first message
                # is kept and cached on the file object.
                content.content_cache[provider] = messages[0]
                return messages[0]
            return None

    raise TypeError(f'Unexpect input value type - {type(content)}')
@staticmethod
def process_model_output(content) -> str:

View File

@@ -91,8 +91,8 @@ class IterationRuntime:
return loopstate
def merge_conv_vars(self) -> None:
    """Copy the child pool's conversation variables into the parent pool.

    The merge is done on the underlying ``variables["conv"]`` mapping so the
    parent pool itself is updated; entries from the child overwrite entries
    with the same name in the parent.

    A missing ``conv`` namespace on either side is tolerated: nothing is
    merged when the child has none, and the parent's namespace is created on
    demand.
    NOTE(review): assumes ``variables`` is a dict-like mapping supporting
    ``get``/``setdefault`` — confirm against ``VariablePool``.
    """
    child_conv = self.child_variable_pool.variables.get("conv")
    if not child_conv:
        return
    self.variable_pool.variables.setdefault("conv", {}).update(child_conv)
async def run_task(self, item, idx):

View File

@@ -156,7 +156,7 @@ class LoopRuntime:
def merge_conv_vars(self, loopstate):
self.variable_pool.variables["conv"].update(
self.child_variable_pool.variables.get("conv", {})
self.child_variable_pool.variables["conv"]
)
loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
loopstate["node_outputs"][self.node_id] = loop_vars

View File

@@ -172,9 +172,9 @@ class LLMNode(BaseNode):
if self.typed_config.vision_input and self.typed_config.vision:
file_content = []
files = variable_pool.get_value(self.typed_config.vision_input)
for file in files:
content = await self.process_message(provider, file, self.typed_config.vision)
files = variable_pool.get_instance(self.typed_config.vision_input)
for file in files.value:
content = await self.process_message(provider, file.value, self.typed_config.vision)
if content:
file_content.append(content)
if messages and messages[-1]["role"] == 'user':

View File

@@ -2,7 +2,7 @@ from enum import StrEnum
from abc import abstractmethod, ABC
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field
from app.schemas import FileType
@@ -45,7 +45,7 @@ class VariableType(StrEnum):
return cls.NUMBER
elif isinstance(var, bool):
return cls.BOOLEAN
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('__file')):
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')):
return cls.FILE
elif isinstance(var, dict):
return cls.OBJECT
@@ -109,7 +109,13 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any:
class FileObject(BaseModel):
    """Description of a file passed through the workflow as a variable value."""

    # Category of the file, as a ``FileType`` enum member.
    type: FileType
    # Location the file content can be fetched from.
    url: str
    # How the file was supplied; forwarded unchanged to ``FileInput``.
    transfer_method: str
    # File type reported with the original input; forwarded to ``FileInput``
    # as ``file_type``.
    origin_file_type: str
    # Identifier of the uploaded file; forwarded to ``FileInput`` as
    # ``upload_file_id``. May be ``None``.
    file_id: str | None
    # Per-provider cache of the already-processed message for this file,
    # keyed by provider name.
    content_cache: dict = Field(default_factory=dict)
    # Marker used to recognise a dict as a serialized file object.
    is_file: bool
class BaseVariable(ABC):

View File

@@ -63,13 +63,16 @@ class FileVariable(BaseVariable):
def valid_value(self, value) -> FileObject:
if isinstance(value, dict):
if not value.get("__file"):
if not value.get("is_file"):
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
return FileObject(
**{
"type": str(value.get('type')),
"transfer_method": value.get("transfer_method"),
"url": value.get('url'),
"__file": True
"file_id": value.get("file_id"),
"origin_file_type": value.get("origin_file_type"),
"is_file": True
}
)
if isinstance(value, FileObject):

View File

@@ -255,7 +255,7 @@ class AgentRunService:
@staticmethod
def prepare_variables(
input_vars: dict | None,
variables_config: dict | None
variables_config: dict
) -> dict:
input_vars = input_vars or {}
for variable in variables_config:

View File

@@ -16,6 +16,7 @@ from app.core.workflow.adapters.registry import PlatformAdapterRegistry
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.validator import validate_workflow_config
from app.core.workflow.variable.base_variable import FileObject
from app.db import get_db
from app.models import App
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
@@ -453,11 +454,14 @@ class WorkflowService:
files_struct = []
for file in files:
files_struct.append(
{
"type": file.type,
"url": await self.multimodal_service.get_file_url(file),
"__file": True
}
FileObject(
type=file.type,
url=await self.multimodal_service.get_file_url(file),
transfer_method=file.transfer_method,
file_id=str(file.upload_file_id),
origin_file_type=file.file_type,
is_file=True
).model_dump()
)
return files_struct