feat(workflow, skill): add multimodal image support to workflows and skill prompt generation
This commit is contained in:
@@ -92,7 +92,7 @@ class WorkflowExecutor:
|
||||
- "conversation_id": conversation identifier
|
||||
"""
|
||||
user_message = input_data.get("message") or ""
|
||||
user_file = input_data.get("file") or []
|
||||
user_files = input_data.get("files") or []
|
||||
|
||||
config_variables_list = self.workflow_config.get("variables") or []
|
||||
conv_vars = input_data.get("conv", {})
|
||||
@@ -119,12 +119,12 @@ class WorkflowExecutor:
|
||||
input_variables = input_data.get("variables") or {}
|
||||
sys_vars = {
|
||||
"message": (user_message, VariableType.STRING),
|
||||
"file": (user_file, VariableType.ARRAY_FILE),
|
||||
"conversation_id": (input_data.get("conversation_id"), VariableType.STRING),
|
||||
"execution_id": (self.execution_id, VariableType.STRING),
|
||||
"workspace_id": (self.workspace_id, VariableType.STRING),
|
||||
"user_id": (self.user_id, VariableType.STRING),
|
||||
"input_variables": (input_variables, VariableType.OBJECT),
|
||||
"files": (user_files, VariableType.ARRAY_FILE)
|
||||
}
|
||||
for key, var_def in sys_vars.items():
|
||||
value = var_def[0]
|
||||
@@ -564,6 +564,7 @@ class WorkflowExecutor:
|
||||
"input": result.get("node_outputs", {}).get(node_name, {}).get("input"),
|
||||
"output": result.get("node_outputs", {}).get(node_name, {}).get("output"),
|
||||
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
|
||||
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from app.core.config import settings
|
||||
from app.core.workflow.nodes.enums import BRANCH_NODES
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.services.multimodal_service import PROVIDER_STRATEGIES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -651,3 +652,21 @@ class BaseNode(ABC):
|
||||
True if the variable exists in the pool, False otherwise.
|
||||
"""
|
||||
return variable_pool.has(selector)
|
||||
|
||||
@staticmethod
|
||||
async def process_message(provider, content, enable_file=False) -> dict | str | None:
|
||||
if isinstance(content, str):
|
||||
if enable_file:
|
||||
return {"text": content}
|
||||
return content
|
||||
elif isinstance(content, dict):
|
||||
trans_tool = PROVIDER_STRATEGIES[provider]()
|
||||
result = await trans_tool.format_image(content["url"])
|
||||
return result
|
||||
raise TypeError('Unexpect input value type')
|
||||
|
||||
@staticmethod
|
||||
def process_model_output(content) -> str:
|
||||
if isinstance(content, dict):
|
||||
return content.get("text")
|
||||
return content
|
||||
|
||||
@@ -71,6 +71,16 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
description="对话上下文窗口"
|
||||
)
|
||||
|
||||
# Whether image inputs are attached to the prompt for this node.
vision: bool = Field(
    default=False,
    description="是否启用视觉模型"
)

# Selector handed to variable_pool.get_value() to fetch the files to attach.
# Optional: the default is None, so the annotation must admit None too -
# otherwise an explicit null in the node config fails validation.
vision_input: str | None = Field(
    default=None,
    description="视觉输入"
)
|
||||
|
||||
# 简单模式
|
||||
prompt: str | None = Field(
|
||||
default=None,
|
||||
|
||||
@@ -79,12 +79,12 @@ class LLMNode(BaseNode):
|
||||
context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
|
||||
return re.sub(r"{{context}}", context, message)
|
||||
|
||||
def _prepare_llm(
|
||||
async def _prepare_llm(
|
||||
self,
|
||||
state: WorkflowState,
|
||||
variable_pool: VariablePool,
|
||||
stream: bool = False
|
||||
) -> tuple[RedBearLLM, list | str]:
|
||||
) -> RedBearLLM:
|
||||
"""准备 LLM 实例(公共逻辑)
|
||||
|
||||
Args:
|
||||
@@ -93,42 +93,9 @@ class LLMNode(BaseNode):
|
||||
Returns:
|
||||
(llm, messages_or_prompt): LLM 实例和消息列表或 prompt 字符串
|
||||
"""
|
||||
|
||||
# 1. 处理消息格式(优先使用 messages)
|
||||
self.typed_config = LLMNodeConfig(**self.config)
|
||||
messages_config = self.typed_config.messages
|
||||
|
||||
if messages_config:
|
||||
# 使用 LangChain 消息格式
|
||||
messages = []
|
||||
for msg_config in messages_config:
|
||||
role = msg_config.role.lower()
|
||||
content_template = msg_config.content
|
||||
content_template = self._render_context(content_template, variable_pool)
|
||||
content = self._render_template(content_template, variable_pool)
|
||||
|
||||
# 根据角色创建对应的消息对象
|
||||
if role == "system":
|
||||
messages.append({"role": "system", "content": content})
|
||||
elif role in ["user", "human"]:
|
||||
messages.append({"role": "user", "content": content})
|
||||
elif role in ["ai", "assistant"]:
|
||||
messages.append({"role": "assistant", "content": content})
|
||||
else:
|
||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||
messages.append({"role": "user", "content": content})
|
||||
|
||||
if self.typed_config.memory.enable:
|
||||
# if self.typed_config.memory.enable_window:
|
||||
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
|
||||
prompt_or_messages = messages
|
||||
else:
|
||||
# 使用简单的 prompt 格式(向后兼容)
|
||||
prompt_template = self.config.get("prompt", "")
|
||||
prompt_or_messages = self._render_template(prompt_template, variable_pool)
|
||||
|
||||
# 2. 获取模型配置
|
||||
model_id = self.config.get("model_id")
|
||||
model_id = self.typed_config.model_id
|
||||
if not model_id:
|
||||
raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置")
|
||||
|
||||
@@ -167,7 +134,61 @@ class LLMNode(BaseNode):
|
||||
|
||||
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
|
||||
|
||||
return llm, prompt_or_messages
|
||||
messages_config = self.typed_config.messages
|
||||
|
||||
if messages_config:
|
||||
# 使用 LangChain 消息格式
|
||||
messages = []
|
||||
for msg_config in messages_config:
|
||||
role = msg_config.role.lower()
|
||||
content_template = msg_config.content
|
||||
content_template = self._render_context(content_template, variable_pool)
|
||||
content = self._render_template(content_template, variable_pool)
|
||||
|
||||
# 根据角色创建对应的消息对象
|
||||
if role == "system":
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": content
|
||||
})
|
||||
elif role in ["user", "human"]:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": content
|
||||
})
|
||||
elif role in ["ai", "assistant"]:
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": content
|
||||
})
|
||||
else:
|
||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": content
|
||||
})
|
||||
|
||||
if self.typed_config.vision_input and self.typed_config.vision:
|
||||
file_content = []
|
||||
files = variable_pool.get_value(self.typed_config.vision_input)
|
||||
for file in files:
|
||||
content = await self.process_message(provider, file, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.append(content)
|
||||
if messages and messages[-1]["role"] == 'user':
|
||||
messages[-1]['content'] = [messages[-1]["content"]] + file_content
|
||||
else:
|
||||
messages.append({"role": "user", "content": file_content})
|
||||
|
||||
if self.typed_config.memory.enable:
|
||||
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
|
||||
self.messages = messages
|
||||
else:
|
||||
# 使用简单的 prompt 格式(向后兼容)
|
||||
prompt_template = self.config.get("prompt", "")
|
||||
self.messages = self._render_template(prompt_template, variable_pool)
|
||||
|
||||
return llm
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage:
|
||||
"""非流式执行 LLM 调用
|
||||
@@ -180,15 +201,15 @@ class LLMNode(BaseNode):
|
||||
LLM 响应消息
|
||||
"""
|
||||
# self.typed_config = LLMNodeConfig(**self.config)
|
||||
llm, prompt_or_messages = self._prepare_llm(state, variable_pool, False)
|
||||
llm = await self._prepare_llm(state, variable_pool, False)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||
|
||||
# 调用 LLM(支持字符串或消息列表)
|
||||
response = await llm.ainvoke(prompt_or_messages)
|
||||
response = await llm.ainvoke(self.messages)
|
||||
# 提取内容
|
||||
if hasattr(response, 'content'):
|
||||
content = response.content
|
||||
content = self.process_model_output(response.content)
|
||||
else:
|
||||
content = str(response)
|
||||
|
||||
@@ -199,14 +220,13 @@ class LLMNode(BaseNode):
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""提取输入数据(用于记录)"""
|
||||
_, prompt_or_messages = self._prepare_llm(state, variable_pool)
|
||||
|
||||
return {
|
||||
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
|
||||
"prompt": self.messages if isinstance(self.messages, str) else None,
|
||||
"messages": [
|
||||
{"role": msg.get("role"), "content": msg.get("content", "")}
|
||||
for msg in prompt_or_messages
|
||||
] if isinstance(prompt_or_messages, list) else None,
|
||||
for msg in self.messages
|
||||
] if isinstance(self.messages, list) else None,
|
||||
"config": {
|
||||
"model_id": self.config.get("model_id"),
|
||||
"temperature": self.config.get("temperature"),
|
||||
@@ -226,8 +246,8 @@ class LLMNode(BaseNode):
|
||||
usage = business_result.response_metadata.get('token_usage')
|
||||
if usage:
|
||||
return {
|
||||
"prompt_tokens": usage.get('prompt_tokens', 0),
|
||||
"completion_tokens": usage.get('completion_tokens', 0),
|
||||
"prompt_tokens": usage.get('input_tokens', 0),
|
||||
"completion_tokens": usage.get('output_tokens', 0),
|
||||
"total_tokens": usage.get('total_tokens', 0)
|
||||
}
|
||||
return None
|
||||
@@ -244,7 +264,7 @@ class LLMNode(BaseNode):
|
||||
"""
|
||||
self.typed_config = LLMNodeConfig(**self.config)
|
||||
|
||||
llm, prompt_or_messages = self._prepare_llm(state, variable_pool, True)
|
||||
llm = await self._prepare_llm(state, variable_pool, True)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
||||
@@ -255,10 +275,10 @@ class LLMNode(BaseNode):
|
||||
|
||||
# 调用 LLM(流式,支持字符串或消息列表)
|
||||
last_meta_data = {}
|
||||
async for chunk in llm.astream(prompt_or_messages, stream_usage=True):
|
||||
async for chunk in llm.astream(self.messages, stream_usage=True):
|
||||
# 提取内容
|
||||
if hasattr(chunk, 'content'):
|
||||
content = chunk.content
|
||||
content = self.process_model_output(chunk.content)
|
||||
else:
|
||||
content = str(chunk)
|
||||
if hasattr(chunk, 'response_metadata'):
|
||||
|
||||
@@ -2,6 +2,10 @@ from enum import StrEnum
|
||||
from abc import abstractmethod, ABC
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.schemas import FileType
|
||||
|
||||
|
||||
class VariableType(StrEnum):
|
||||
"""Enumeration of supported variable types in the workflow."""
|
||||
@@ -42,7 +46,7 @@ class VariableType(StrEnum):
|
||||
return cls.NUMBER
|
||||
elif isinstance(var_type, bool):
|
||||
return cls.BOOLEAN
|
||||
elif isinstance(var_type, FileObj):
|
||||
elif isinstance(var_type, FileObject) or (isinstance(var, dict) and var.get('__file')):
|
||||
return cls.FILE
|
||||
elif isinstance(var_type, dict):
|
||||
return cls.OBJECT
|
||||
@@ -103,8 +107,10 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any:
|
||||
raise TypeError(f"Invalid type - {type}")
|
||||
|
||||
|
||||
class FileObj:
|
||||
pass
|
||||
class FileObject(BaseModel):
    """A file reference carried through workflow variables.

    Serialized form is a dict with ``type``, ``url`` and a ``"__file"``
    marker that lets plain dicts be recognised as files again.
    """

    type: FileType
    url: str
    # A field literally named ``__file`` is name-mangled inside a class body
    # and pydantic treats underscore-prefixed names as private attributes, so
    # it would never be validated or dumped. Expose the marker via an alias.
    is_file: bool = Field(default=True, alias="__file")

    def model_dump(self, **kwargs) -> dict[str, Any]:
        """Dump using aliases by default so the ``"__file"`` marker survives."""
        kwargs.setdefault("by_alias", True)
        return super().model_dump(**kwargs)
|
||||
|
||||
|
||||
class BaseVariable(ABC):
|
||||
|
||||
@@ -1,66 +1,84 @@
|
||||
from typing import Any, TypeVar, Type, Generic
|
||||
|
||||
from app.core.workflow.variable.base_variable import BaseVariable, VariableType
|
||||
from deprecated import deprecated
|
||||
|
||||
from app.core.workflow.variable.base_variable import BaseVariable, VariableType, FileObject
|
||||
|
||||
T = TypeVar("T", bound=BaseVariable)
|
||||
|
||||
|
||||
class StringObject(BaseVariable):
|
||||
class StringVariable(BaseVariable):
|
||||
type = 'str'
|
||||
|
||||
def valid_value(self, value) -> str:
|
||||
if not isinstance(value, str):
|
||||
raise TypeError("Value must be a string")
|
||||
raise TypeError(f"Value must be a string - {type(value)}:{value}")
|
||||
return value
|
||||
|
||||
def to_literal(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
class NumberObject(BaseVariable):
|
||||
class NumberVariable(BaseVariable):
|
||||
type = 'number'
|
||||
|
||||
def valid_value(self, value) -> int | float:
|
||||
if not isinstance(value, (int, float)):
|
||||
raise TypeError("Value must be a number")
|
||||
raise TypeError(f"Value must be a number - {type(value)}:{value}")
|
||||
return value
|
||||
|
||||
def to_literal(self) -> str:
|
||||
return str(self.value)
|
||||
|
||||
|
||||
class BooleanObject(BaseVariable):
|
||||
class BooleanVariable(BaseVariable):
|
||||
type = 'boolean'
|
||||
|
||||
def valid_value(self, value) -> bool:
|
||||
if not isinstance(value, bool):
|
||||
raise TypeError("Value must be a boolean")
|
||||
raise TypeError(f"Value must be a boolean - {type(value)}:{value}")
|
||||
return value
|
||||
|
||||
def to_literal(self) -> str:
|
||||
return str(self.value).lower()
|
||||
|
||||
|
||||
class DictObject(BaseVariable):
|
||||
class DictVariable(BaseVariable):
|
||||
type = 'object'
|
||||
|
||||
def valid_value(self, value) -> dict:
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError("Value must be a dict")
|
||||
raise TypeError(f"Value must be a dict - {type(value)}:{value}")
|
||||
return value
|
||||
|
||||
def to_literal(self) -> str:
|
||||
return str(self.value)
|
||||
|
||||
|
||||
class FileObject(BaseVariable):
|
||||
class FileVariable(BaseVariable):
|
||||
type = 'file'
|
||||
|
||||
def valid_value(self, value) -> Any:
|
||||
pass
|
||||
def valid_value(self, value) -> FileObject:
|
||||
|
||||
if isinstance(value, dict):
|
||||
if not value.get("__file"):
|
||||
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
|
||||
return FileObject(
|
||||
**{
|
||||
"type": str(value.get('type')),
|
||||
"url": value.get('url'),
|
||||
"__file": True
|
||||
}
|
||||
)
|
||||
if isinstance(value, FileObject):
|
||||
return value
|
||||
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
|
||||
|
||||
def to_literal(self) -> str:
|
||||
pass
|
||||
return str(self.value.model_dump())
|
||||
|
||||
def get_value(self) -> Any:
|
||||
return self.value.model_dump()
|
||||
|
||||
|
||||
class ArrayObject(BaseVariable, Generic[T]):
|
||||
@@ -74,7 +92,7 @@ class ArrayObject(BaseVariable, Generic[T]):
|
||||
|
||||
def valid_value(self, value: list[Any]) -> list[T]:
|
||||
if not isinstance(value, list):
|
||||
raise TypeError("Value must be a list")
|
||||
raise TypeError(f"Value must be a list - {type(value)}:{value}")
|
||||
final_value = []
|
||||
for v in value:
|
||||
try:
|
||||
@@ -86,13 +104,16 @@ class ArrayObject(BaseVariable, Generic[T]):
|
||||
def to_literal(self) -> str:
|
||||
return "\n".join([v.to_literal() for v in self.value])
|
||||
|
||||
def get_value(self) -> Any:
|
||||
return [v.get_value() for v in self.value]
|
||||
|
||||
|
||||
class NestedArrayObject(BaseVariable):
|
||||
type = 'array_nest'
|
||||
|
||||
def valid_value(self, value: list[T]) -> list[T]:
|
||||
if not isinstance(value, list):
|
||||
raise TypeError("Value must be a list")
|
||||
raise TypeError(f"Value must be a list - {type(value)}:{value}")
|
||||
final_value = []
|
||||
for v in value:
|
||||
if not isinstance(v, ArrayObject):
|
||||
@@ -107,6 +128,10 @@ class NestedArrayObject(BaseVariable):
|
||||
return [[item.get_value() for item in row] for row in self.value]
|
||||
|
||||
|
||||
@deprecated(
|
||||
reason="Using arbitrary-type values may cause unexpected errors; please switch to strongly-typed values.",
|
||||
category=RuntimeWarning
|
||||
)
|
||||
class AnyObject(BaseVariable):
|
||||
type = 'any'
|
||||
|
||||
@@ -126,23 +151,23 @@ def make_array(child_type: Type[T], value: list[Any]) -> ArrayObject[T]:
|
||||
def create_variable_instance(var_type: VariableType, value: Any) -> T:
|
||||
match var_type:
|
||||
case VariableType.STRING:
|
||||
return StringObject(value)
|
||||
return StringVariable(value)
|
||||
case VariableType.NUMBER:
|
||||
return NumberObject(value)
|
||||
return NumberVariable(value)
|
||||
case VariableType.BOOLEAN:
|
||||
return BooleanObject(value)
|
||||
return BooleanVariable(value)
|
||||
case VariableType.OBJECT:
|
||||
return DictObject(value)
|
||||
return DictVariable(value)
|
||||
case VariableType.ARRAY_STRING:
|
||||
return make_array(StringObject, value)
|
||||
return make_array(StringVariable, value)
|
||||
case VariableType.ARRAY_NUMBER:
|
||||
return make_array(NumberObject, value)
|
||||
return make_array(NumberVariable, value)
|
||||
case VariableType.ARRAY_BOOLEAN:
|
||||
return make_array(BooleanObject, value)
|
||||
return make_array(BooleanVariable, value)
|
||||
case VariableType.ARRAY_OBJECT:
|
||||
return make_array(DictObject, value)
|
||||
return make_array(DictVariable, value)
|
||||
case VariableType.ARRAY_FILE:
|
||||
return make_array(FileObject, value)
|
||||
return make_array(FileVariable, value)
|
||||
case VariableType.ANY:
|
||||
return AnyObject(value)
|
||||
case _:
|
||||
|
||||
@@ -78,8 +78,8 @@ class VariableStruct(BaseModel, Generic[T]):
|
||||
serialization, and workflow type checking.
|
||||
instance:
|
||||
The concrete variable object. The actual Python type is
|
||||
represented by the generic parameter ``T`` (e.g. StringObject,
|
||||
NumberObject, ArrayObject[StringObject]).
|
||||
represented by the generic parameter ``T`` (e.g. StringVariable,
|
||||
NumberVariable, ArrayObject[StringVariable]).
|
||||
mut:
|
||||
Whether the variable is mutable.
|
||||
"""
|
||||
@@ -286,7 +286,7 @@ class VariablePool:
|
||||
系统变量字典
|
||||
"""
|
||||
sys_namespace = self.variables.get("sys", {})
|
||||
return {k: v.instance.value for k, v in sys_namespace.items()}
|
||||
return {k: v.instance.get_value() for k, v in sys_namespace.items()}
|
||||
|
||||
def get_all_conversation_vars(self) -> dict[str, Any]:
|
||||
"""获取所有会话变量
|
||||
@@ -295,7 +295,7 @@ class VariablePool:
|
||||
会话变量字典
|
||||
"""
|
||||
conv_namespace = self.variables.get("conv", {})
|
||||
return {k: v.instance.value for k, v in conv_namespace.items()}
|
||||
return {k: v.instance.get_value() for k, v in conv_namespace.items()}
|
||||
|
||||
def get_all_node_outputs(self) -> dict[str, Any]:
|
||||
"""获取所有节点输出(运行时变量)
|
||||
@@ -305,7 +305,7 @@ class VariablePool:
|
||||
"""
|
||||
runtime_vars = {
|
||||
namespace: {
|
||||
k: v.instance.value
|
||||
k: v.instance.get_value()
|
||||
for k, v in vars_dict.items()
|
||||
}
|
||||
for namespace, vars_dict in self.variables.items()
|
||||
@@ -326,7 +326,7 @@ class VariablePool:
|
||||
"""
|
||||
node_namespace = self.variables.get(node_id)
|
||||
if node_namespace:
|
||||
return {k: v.instance.value for k, v in node_namespace.items()}
|
||||
return {k: v.instance.get_value() for k, v in node_namespace.items()}
|
||||
if strict:
|
||||
raise KeyError(f"node {node_id} output not exist")
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user