feat(workflow): support multimodal input
This commit is contained in:
@@ -303,30 +303,44 @@ class VariablePool:
|
|||||||
"""
|
"""
|
||||||
return self._get_variable_struct(selector) is not None
|
return self._get_variable_struct(selector) is not None
|
||||||
|
|
||||||
def get_all_system_vars(self) -> dict[str, Any]:
|
def get_all_system_vars(self, literal=False) -> dict[str, Any]:
|
||||||
"""获取所有系统变量
|
"""获取所有系统变量
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
系统变量字典
|
系统变量字典
|
||||||
"""
|
"""
|
||||||
sys_namespace = self.variables.get("sys", {})
|
sys_namespace = self.variables.get("sys", {})
|
||||||
|
if literal:
|
||||||
|
return {k: v.instance.to_literal() for k, v in sys_namespace.items()}
|
||||||
return {k: v.instance.get_value() for k, v in sys_namespace.items()}
|
return {k: v.instance.get_value() for k, v in sys_namespace.items()}
|
||||||
|
|
||||||
def get_all_conversation_vars(self) -> dict[str, Any]:
|
def get_all_conversation_vars(self, literal=False) -> dict[str, Any]:
|
||||||
"""获取所有会话变量
|
"""获取所有会话变量
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
会话变量字典
|
会话变量字典
|
||||||
"""
|
"""
|
||||||
conv_namespace = self.variables.get("conv", {})
|
conv_namespace = self.variables.get("conv", {})
|
||||||
|
if literal:
|
||||||
|
return {k: v.instance.to_literal() for k, v in conv_namespace.items()}
|
||||||
return {k: v.instance.get_value() for k, v in conv_namespace.items()}
|
return {k: v.instance.get_value() for k, v in conv_namespace.items()}
|
||||||
|
|
||||||
def get_all_node_outputs(self) -> dict[str, Any]:
|
def get_all_node_outputs(self, literal=False) -> dict[str, Any]:
|
||||||
"""获取所有节点输出(运行时变量)
|
"""获取所有节点输出(运行时变量)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
节点输出字典,键为节点 ID
|
节点输出字典,键为节点 ID
|
||||||
"""
|
"""
|
||||||
|
if literal:
|
||||||
|
runtime_vars = {
|
||||||
|
namespace: {
|
||||||
|
k: v.instance.to_literal()
|
||||||
|
for k, v in vars_dict.items()
|
||||||
|
}
|
||||||
|
for namespace, vars_dict in self.variables.items()
|
||||||
|
if namespace not in ("sys", "conv")
|
||||||
|
}
|
||||||
|
else:
|
||||||
runtime_vars = {
|
runtime_vars = {
|
||||||
namespace: {
|
namespace: {
|
||||||
k: v.instance.get_value()
|
k: v.instance.get_value()
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Any, AsyncGenerator
|
from typing import Any, AsyncGenerator
|
||||||
@@ -10,8 +11,10 @@ from app.core.config import settings
|
|||||||
from app.core.workflow.engine.state_manager import WorkflowState
|
from app.core.workflow.engine.state_manager import WorkflowState
|
||||||
from app.core.workflow.engine.variable_pool import VariablePool
|
from app.core.workflow.engine.variable_pool import VariablePool
|
||||||
from app.core.workflow.nodes.enums import BRANCH_NODES
|
from app.core.workflow.nodes.enums import BRANCH_NODES
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
from app.core.workflow.variable.base_variable import VariableType, FileObject
|
||||||
from app.services.multimodal_service import PROVIDER_STRATEGIES
|
from app.db import get_db_read
|
||||||
|
from app.schemas import FileInput
|
||||||
|
from app.services.multimodal_service import MultimodalService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -548,9 +551,9 @@ class BaseNode(ABC):
|
|||||||
|
|
||||||
return render_template(
|
return render_template(
|
||||||
template=template,
|
template=template,
|
||||||
conv_vars=variable_pool.get_all_conversation_vars(),
|
conv_vars=variable_pool.get_all_conversation_vars(literal=True),
|
||||||
node_outputs=variable_pool.get_all_node_outputs(),
|
node_outputs=variable_pool.get_all_node_outputs(literal=True),
|
||||||
system_vars=variable_pool.get_all_system_vars(),
|
system_vars=variable_pool.get_all_system_vars(literal=True),
|
||||||
strict=strict
|
strict=strict
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -614,16 +617,32 @@ class BaseNode(ABC):
|
|||||||
return variable_pool.has(selector)
|
return variable_pool.has(selector)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def process_message(provider, content, enable_file=False) -> dict | str | None:
|
async def process_message(provider: str, content: str | FileObject, enable_file=False) -> dict | str | None:
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
if enable_file:
|
if enable_file:
|
||||||
return {"text": content}
|
return {"text": content}
|
||||||
return content
|
return content
|
||||||
elif isinstance(content, dict):
|
|
||||||
trans_tool = PROVIDER_STRATEGIES[provider]()
|
elif isinstance(content, FileObject):
|
||||||
result = await trans_tool.format_image(content["url"])
|
if content.content_cache.get(provider):
|
||||||
return result
|
return content.content_cache[provider]
|
||||||
raise TypeError('Unexpect input value type')
|
with get_db_read() as db:
|
||||||
|
multimodel_service = MultimodalService(db, provider)
|
||||||
|
message = await multimodel_service.process_files(
|
||||||
|
[FileInput.model_construct(
|
||||||
|
type=content.type,
|
||||||
|
url=content.url,
|
||||||
|
transfer_method=content.transfer_method,
|
||||||
|
file_type=content.origin_file_type,
|
||||||
|
upload_file_id=content.file_id
|
||||||
|
)]
|
||||||
|
)
|
||||||
|
|
||||||
|
if message:
|
||||||
|
content.content_cache[provider] = message[0]
|
||||||
|
return message[0]
|
||||||
|
return None
|
||||||
|
raise TypeError(f'Unexpect input value type - {type(content)}')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def process_model_output(content) -> str:
|
def process_model_output(content) -> str:
|
||||||
|
|||||||
@@ -91,8 +91,8 @@ class IterationRuntime:
|
|||||||
return loopstate
|
return loopstate
|
||||||
|
|
||||||
def merge_conv_vars(self):
|
def merge_conv_vars(self):
|
||||||
self.variable_pool.get_all_conversation_vars().update(
|
self.variable_pool.variables["conv"].update(
|
||||||
self.child_variable_pool.get_all_conversation_vars()
|
self.child_variable_pool.variables["conv"]
|
||||||
)
|
)
|
||||||
|
|
||||||
async def run_task(self, item, idx):
|
async def run_task(self, item, idx):
|
||||||
|
|||||||
@@ -156,7 +156,7 @@ class LoopRuntime:
|
|||||||
|
|
||||||
def merge_conv_vars(self, loopstate):
|
def merge_conv_vars(self, loopstate):
|
||||||
self.variable_pool.variables["conv"].update(
|
self.variable_pool.variables["conv"].update(
|
||||||
self.child_variable_pool.variables.get("conv", {})
|
self.child_variable_pool.variables["conv"]
|
||||||
)
|
)
|
||||||
loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
|
loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
|
||||||
loopstate["node_outputs"][self.node_id] = loop_vars
|
loopstate["node_outputs"][self.node_id] = loop_vars
|
||||||
|
|||||||
@@ -172,9 +172,9 @@ class LLMNode(BaseNode):
|
|||||||
|
|
||||||
if self.typed_config.vision_input and self.typed_config.vision:
|
if self.typed_config.vision_input and self.typed_config.vision:
|
||||||
file_content = []
|
file_content = []
|
||||||
files = variable_pool.get_value(self.typed_config.vision_input)
|
files = variable_pool.get_instance(self.typed_config.vision_input)
|
||||||
for file in files:
|
for file in files.value:
|
||||||
content = await self.process_message(provider, file, self.typed_config.vision)
|
content = await self.process_message(provider, file.value, self.typed_config.vision)
|
||||||
if content:
|
if content:
|
||||||
file_content.append(content)
|
file_content.append(content)
|
||||||
if messages and messages[-1]["role"] == 'user':
|
if messages and messages[-1]["role"] == 'user':
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from enum import StrEnum
|
|||||||
from abc import abstractmethod, ABC
|
from abc import abstractmethod, ABC
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from app.schemas import FileType
|
from app.schemas import FileType
|
||||||
|
|
||||||
@@ -45,7 +45,7 @@ class VariableType(StrEnum):
|
|||||||
return cls.NUMBER
|
return cls.NUMBER
|
||||||
elif isinstance(var, bool):
|
elif isinstance(var, bool):
|
||||||
return cls.BOOLEAN
|
return cls.BOOLEAN
|
||||||
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('__file')):
|
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')):
|
||||||
return cls.FILE
|
return cls.FILE
|
||||||
elif isinstance(var, dict):
|
elif isinstance(var, dict):
|
||||||
return cls.OBJECT
|
return cls.OBJECT
|
||||||
@@ -109,7 +109,13 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any:
|
|||||||
class FileObject(BaseModel):
|
class FileObject(BaseModel):
|
||||||
type: FileType
|
type: FileType
|
||||||
url: str
|
url: str
|
||||||
__file: bool
|
transfer_method: str
|
||||||
|
origin_file_type: str
|
||||||
|
file_id: str | None
|
||||||
|
|
||||||
|
content_cache: dict = Field(default_factory=dict)
|
||||||
|
|
||||||
|
is_file: bool
|
||||||
|
|
||||||
|
|
||||||
class BaseVariable(ABC):
|
class BaseVariable(ABC):
|
||||||
|
|||||||
@@ -63,13 +63,16 @@ class FileVariable(BaseVariable):
|
|||||||
def valid_value(self, value) -> FileObject:
|
def valid_value(self, value) -> FileObject:
|
||||||
|
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
if not value.get("__file"):
|
if not value.get("is_file"):
|
||||||
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
|
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
|
||||||
return FileObject(
|
return FileObject(
|
||||||
**{
|
**{
|
||||||
"type": str(value.get('type')),
|
"type": str(value.get('type')),
|
||||||
|
"transfer_method": value.get("transfer_method"),
|
||||||
"url": value.get('url'),
|
"url": value.get('url'),
|
||||||
"__file": True
|
"file_id": value.get("file_id"),
|
||||||
|
"origin_file_type": value.get("origin_file_type"),
|
||||||
|
"is_file": True
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if isinstance(value, FileObject):
|
if isinstance(value, FileObject):
|
||||||
|
|||||||
@@ -255,7 +255,7 @@ class AgentRunService:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def prepare_variables(
|
def prepare_variables(
|
||||||
input_vars: dict | None,
|
input_vars: dict | None,
|
||||||
variables_config: dict | None
|
variables_config: dict
|
||||||
) -> dict:
|
) -> dict:
|
||||||
input_vars = input_vars or {}
|
input_vars = input_vars or {}
|
||||||
for variable in variables_config:
|
for variable in variables_config:
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from app.core.workflow.adapters.registry import PlatformAdapterRegistry
|
|||||||
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
|
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
|
||||||
from app.core.workflow.nodes.enums import NodeType
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
from app.core.workflow.validator import validate_workflow_config
|
from app.core.workflow.validator import validate_workflow_config
|
||||||
|
from app.core.workflow.variable.base_variable import FileObject
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.models import App
|
from app.models import App
|
||||||
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
||||||
@@ -453,11 +454,14 @@ class WorkflowService:
|
|||||||
files_struct = []
|
files_struct = []
|
||||||
for file in files:
|
for file in files:
|
||||||
files_struct.append(
|
files_struct.append(
|
||||||
{
|
FileObject(
|
||||||
"type": file.type,
|
type=file.type,
|
||||||
"url": await self.multimodal_service.get_file_url(file),
|
url=await self.multimodal_service.get_file_url(file),
|
||||||
"__file": True
|
transfer_method=file.transfer_method,
|
||||||
}
|
file_id=str(file.upload_file_id),
|
||||||
|
origin_file_type=file.file_type,
|
||||||
|
is_file=True
|
||||||
|
).model_dump()
|
||||||
)
|
)
|
||||||
return files_struct
|
return files_struct
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user