Merge pull request #325 from SuanmoSuanyangTechnology/feature/workflow-file

feat(workflow, skill): add multimodal image support to workflows and skill prompt generation
This commit is contained in:
Mark
2026-02-05 12:29:07 +08:00
committed by GitHub
15 changed files with 283 additions and 169 deletions

View File

@@ -120,7 +120,8 @@ async def get_prompt_opt(
session_id=session_id, session_id=session_id,
user_id=current_user.id, user_id=current_user.id,
current_prompt=data.current_prompt, current_prompt=data.current_prompt,
user_require=data.message user_require=data.message,
skill=data.skill
): ):
# chunk 是 prompt 的增量内容 # chunk 是 prompt 的增量内容
yield f"event:message\ndata: {json.dumps(chunk)}\n\n" yield f"event:message\ndata: {json.dumps(chunk)}\n\n"

View File

@@ -92,7 +92,7 @@ class WorkflowExecutor:
- "conversation_id": conversation identifier - "conversation_id": conversation identifier
""" """
user_message = input_data.get("message") or "" user_message = input_data.get("message") or ""
user_file = input_data.get("file") or [] user_files = input_data.get("files") or []
config_variables_list = self.workflow_config.get("variables") or [] config_variables_list = self.workflow_config.get("variables") or []
conv_vars = input_data.get("conv", {}) conv_vars = input_data.get("conv", {})
@@ -119,12 +119,12 @@ class WorkflowExecutor:
input_variables = input_data.get("variables") or {} input_variables = input_data.get("variables") or {}
sys_vars = { sys_vars = {
"message": (user_message, VariableType.STRING), "message": (user_message, VariableType.STRING),
"file": (user_file, VariableType.ARRAY_FILE),
"conversation_id": (input_data.get("conversation_id"), VariableType.STRING), "conversation_id": (input_data.get("conversation_id"), VariableType.STRING),
"execution_id": (self.execution_id, VariableType.STRING), "execution_id": (self.execution_id, VariableType.STRING),
"workspace_id": (self.workspace_id, VariableType.STRING), "workspace_id": (self.workspace_id, VariableType.STRING),
"user_id": (self.user_id, VariableType.STRING), "user_id": (self.user_id, VariableType.STRING),
"input_variables": (input_variables, VariableType.OBJECT), "input_variables": (input_variables, VariableType.OBJECT),
"files": (user_files, VariableType.ARRAY_FILE)
} }
for key, var_def in sys_vars.items(): for key, var_def in sys_vars.items():
value = var_def[0] value = var_def[0]
@@ -564,6 +564,7 @@ class WorkflowExecutor:
"input": result.get("node_outputs", {}).get(node_name, {}).get("input"), "input": result.get("node_outputs", {}).get(node_name, {}).get("input"),
"output": result.get("node_outputs", {}).get(node_name, {}).get("output"), "output": result.get("node_outputs", {}).get(node_name, {}).get("output"),
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"), "elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
} }
} }

View File

@@ -11,6 +11,7 @@ from app.core.config import settings
from app.core.workflow.nodes.enums import BRANCH_NODES from app.core.workflow.nodes.enums import BRANCH_NODES
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool from app.core.workflow.variable_pool import VariablePool
from app.services.multimodal_service import PROVIDER_STRATEGIES
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -651,3 +652,21 @@ class BaseNode(ABC):
True if the variable exists in the pool, False otherwise. True if the variable exists in the pool, False otherwise.
""" """
return variable_pool.has(selector) return variable_pool.has(selector)
@staticmethod
async def process_message(provider, content, enable_file=False) -> dict | str | None:
if isinstance(content, str):
if enable_file:
return {"text": content}
return content
elif isinstance(content, dict):
trans_tool = PROVIDER_STRATEGIES[provider]()
result = await trans_tool.format_image(content["url"])
return result
raise TypeError('Unexpect input value type')
@staticmethod
def process_model_output(content) -> str:
if isinstance(content, dict):
return content.get("text")
return content

View File

@@ -71,6 +71,16 @@ class LLMNodeConfig(BaseNodeConfig):
description="对话上下文窗口" description="对话上下文窗口"
) )
vision: bool = Field(
default=False,
description="是否启用视觉模型"
)
vision_input: str = Field(
default=None,
description="视觉输入"
)
# 简单模式 # 简单模式
prompt: str | None = Field( prompt: str | None = Field(
default=None, default=None,

View File

@@ -79,12 +79,12 @@ class LLMNode(BaseNode):
context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>" context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
return re.sub(r"{{context}}", context, message) return re.sub(r"{{context}}", context, message)
def _prepare_llm( async def _prepare_llm(
self, self,
state: WorkflowState, state: WorkflowState,
variable_pool: VariablePool, variable_pool: VariablePool,
stream: bool = False stream: bool = False
) -> tuple[RedBearLLM, list | str]: ) -> RedBearLLM:
"""准备 LLM 实例(公共逻辑) """准备 LLM 实例(公共逻辑)
Args: Args:
@@ -93,42 +93,9 @@ class LLMNode(BaseNode):
Returns: Returns:
(llm, messages_or_prompt): LLM 实例和消息列表或 prompt 字符串 (llm, messages_or_prompt): LLM 实例和消息列表或 prompt 字符串
""" """
# 1. 处理消息格式(优先使用 messages
self.typed_config = LLMNodeConfig(**self.config) self.typed_config = LLMNodeConfig(**self.config)
messages_config = self.typed_config.messages
if messages_config: model_id = self.typed_config.model_id
# 使用 LangChain 消息格式
messages = []
for msg_config in messages_config:
role = msg_config.role.lower()
content_template = msg_config.content
content_template = self._render_context(content_template, variable_pool)
content = self._render_template(content_template, variable_pool)
# 根据角色创建对应的消息对象
if role == "system":
messages.append({"role": "system", "content": content})
elif role in ["user", "human"]:
messages.append({"role": "user", "content": content})
elif role in ["ai", "assistant"]:
messages.append({"role": "assistant", "content": content})
else:
logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append({"role": "user", "content": content})
if self.typed_config.memory.enable:
# if self.typed_config.memory.enable_window:
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
prompt_or_messages = messages
else:
# 使用简单的 prompt 格式(向后兼容)
prompt_template = self.config.get("prompt", "")
prompt_or_messages = self._render_template(prompt_template, variable_pool)
# 2. 获取模型配置
model_id = self.config.get("model_id")
if not model_id: if not model_id:
raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置") raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置")
@@ -167,7 +134,61 @@ class LLMNode(BaseNode):
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}") logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
return llm, prompt_or_messages messages_config = self.typed_config.messages
if messages_config:
# 使用 LangChain 消息格式
messages = []
for msg_config in messages_config:
role = msg_config.role.lower()
content_template = msg_config.content
content_template = self._render_context(content_template, variable_pool)
content = self._render_template(content_template, variable_pool)
# 根据角色创建对应的消息对象
if role == "system":
messages.append({
"role": "system",
"content": content
})
elif role in ["user", "human"]:
messages.append({
"role": "user",
"content": content
})
elif role in ["ai", "assistant"]:
messages.append({
"role": "assistant",
"content": content
})
else:
logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append({
"role": "user",
"content": content
})
if self.typed_config.vision_input and self.typed_config.vision:
file_content = []
files = variable_pool.get_value(self.typed_config.vision_input)
for file in files:
content = await self.process_message(provider, file, self.typed_config.vision)
if content:
file_content.append(content)
if messages and messages[-1]["role"] == 'user':
messages[-1]['content'] = [messages[-1]["content"]] + file_content
else:
messages.append({"role": "user", "content": file_content})
if self.typed_config.memory.enable:
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
self.messages = messages
else:
# 使用简单的 prompt 格式(向后兼容)
prompt_template = self.config.get("prompt", "")
self.messages = self._render_template(prompt_template, variable_pool)
return llm
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage:
"""非流式执行 LLM 调用 """非流式执行 LLM 调用
@@ -180,15 +201,15 @@ class LLMNode(BaseNode):
LLM 响应消息 LLM 响应消息
""" """
# self.typed_config = LLMNodeConfig(**self.config) # self.typed_config = LLMNodeConfig(**self.config)
llm, prompt_or_messages = self._prepare_llm(state, variable_pool, False) llm = await self._prepare_llm(state, variable_pool, False)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)") logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
# 调用 LLM支持字符串或消息列表 # 调用 LLM支持字符串或消息列表
response = await llm.ainvoke(prompt_or_messages) response = await llm.ainvoke(self.messages)
# 提取内容 # 提取内容
if hasattr(response, 'content'): if hasattr(response, 'content'):
content = response.content content = self.process_model_output(response.content)
else: else:
content = str(response) content = str(response)
@@ -199,14 +220,13 @@ class LLMNode(BaseNode):
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""提取输入数据(用于记录)""" """提取输入数据(用于记录)"""
_, prompt_or_messages = self._prepare_llm(state, variable_pool)
return { return {
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None, "prompt": self.messages if isinstance(self.messages, str) else None,
"messages": [ "messages": [
{"role": msg.get("role"), "content": msg.get("content", "")} {"role": msg.get("role"), "content": msg.get("content", "")}
for msg in prompt_or_messages for msg in self.messages
] if isinstance(prompt_or_messages, list) else None, ] if isinstance(self.messages, list) else None,
"config": { "config": {
"model_id": self.config.get("model_id"), "model_id": self.config.get("model_id"),
"temperature": self.config.get("temperature"), "temperature": self.config.get("temperature"),
@@ -226,8 +246,8 @@ class LLMNode(BaseNode):
usage = business_result.response_metadata.get('token_usage') usage = business_result.response_metadata.get('token_usage')
if usage: if usage:
return { return {
"prompt_tokens": usage.get('prompt_tokens', 0), "prompt_tokens": usage.get('input_tokens', 0),
"completion_tokens": usage.get('completion_tokens', 0), "completion_tokens": usage.get('output_tokens', 0),
"total_tokens": usage.get('total_tokens', 0) "total_tokens": usage.get('total_tokens', 0)
} }
return None return None
@@ -244,7 +264,7 @@ class LLMNode(BaseNode):
""" """
self.typed_config = LLMNodeConfig(**self.config) self.typed_config = LLMNodeConfig(**self.config)
llm, prompt_or_messages = self._prepare_llm(state, variable_pool, True) llm = await self._prepare_llm(state, variable_pool, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}") logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
@@ -255,10 +275,10 @@ class LLMNode(BaseNode):
# 调用 LLM流式支持字符串或消息列表 # 调用 LLM流式支持字符串或消息列表
last_meta_data = {} last_meta_data = {}
async for chunk in llm.astream(prompt_or_messages, stream_usage=True): async for chunk in llm.astream(self.messages, stream_usage=True):
# 提取内容 # 提取内容
if hasattr(chunk, 'content'): if hasattr(chunk, 'content'):
content = chunk.content content = self.process_model_output(chunk.content)
else: else:
content = str(chunk) content = str(chunk)
if hasattr(chunk, 'response_metadata'): if hasattr(chunk, 'response_metadata'):

View File

@@ -2,6 +2,10 @@ from enum import StrEnum
from abc import abstractmethod, ABC from abc import abstractmethod, ABC
from typing import Any from typing import Any
from pydantic import BaseModel
from app.schemas import FileType
class VariableType(StrEnum): class VariableType(StrEnum):
"""Enumeration of supported variable types in the workflow.""" """Enumeration of supported variable types in the workflow."""
@@ -42,7 +46,7 @@ class VariableType(StrEnum):
return cls.NUMBER return cls.NUMBER
elif isinstance(var_type, bool): elif isinstance(var_type, bool):
return cls.BOOLEAN return cls.BOOLEAN
elif isinstance(var_type, FileObj): elif isinstance(var_type, FileObject) or (isinstance(var, dict) and var.get('__file')):
return cls.FILE return cls.FILE
elif isinstance(var_type, dict): elif isinstance(var_type, dict):
return cls.OBJECT return cls.OBJECT
@@ -103,8 +107,10 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any:
raise TypeError(f"Invalid type - {type}") raise TypeError(f"Invalid type - {type}")
class FileObj: class FileObject(BaseModel):
pass type: FileType
url: str
__file: bool
class BaseVariable(ABC): class BaseVariable(ABC):

View File

@@ -1,66 +1,84 @@
from typing import Any, TypeVar, Type, Generic from typing import Any, TypeVar, Type, Generic
from app.core.workflow.variable.base_variable import BaseVariable, VariableType from deprecated import deprecated
from app.core.workflow.variable.base_variable import BaseVariable, VariableType, FileObject
T = TypeVar("T", bound=BaseVariable) T = TypeVar("T", bound=BaseVariable)
class StringObject(BaseVariable): class StringVariable(BaseVariable):
type = 'str' type = 'str'
def valid_value(self, value) -> str: def valid_value(self, value) -> str:
if not isinstance(value, str): if not isinstance(value, str):
raise TypeError("Value must be a string") raise TypeError(f"Value must be a string - {type(value)}:{value}")
return value return value
def to_literal(self) -> str: def to_literal(self) -> str:
return self.value return self.value
class NumberObject(BaseVariable): class NumberVariable(BaseVariable):
type = 'number' type = 'number'
def valid_value(self, value) -> int | float: def valid_value(self, value) -> int | float:
if not isinstance(value, (int, float)): if not isinstance(value, (int, float)):
raise TypeError("Value must be a number") raise TypeError(f"Value must be a number - {type(value)}:{value}")
return value return value
def to_literal(self) -> str: def to_literal(self) -> str:
return str(self.value) return str(self.value)
class BooleanObject(BaseVariable): class BooleanVariable(BaseVariable):
type = 'boolean' type = 'boolean'
def valid_value(self, value) -> bool: def valid_value(self, value) -> bool:
if not isinstance(value, bool): if not isinstance(value, bool):
raise TypeError("Value must be a boolean") raise TypeError(f"Value must be a boolean - {type(value)}:{value}")
return value return value
def to_literal(self) -> str: def to_literal(self) -> str:
return str(self.value).lower() return str(self.value).lower()
class DictObject(BaseVariable): class DictVariable(BaseVariable):
type = 'object' type = 'object'
def valid_value(self, value) -> dict: def valid_value(self, value) -> dict:
if not isinstance(value, dict): if not isinstance(value, dict):
raise TypeError("Value must be a dict") raise TypeError(f"Value must be a dict - {type(value)}:{value}")
return value return value
def to_literal(self) -> str: def to_literal(self) -> str:
return str(self.value) return str(self.value)
class FileObject(BaseVariable): class FileVariable(BaseVariable):
type = 'file' type = 'file'
def valid_value(self, value) -> Any: def valid_value(self, value) -> FileObject:
pass
if isinstance(value, dict):
if not value.get("__file"):
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
return FileObject(
**{
"type": str(value.get('type')),
"url": value.get('url'),
"__file": True
}
)
if isinstance(value, FileObject):
return value
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
def to_literal(self) -> str: def to_literal(self) -> str:
pass return str(self.value.model_dump())
def get_value(self) -> Any:
return self.value.model_dump()
class ArrayObject(BaseVariable, Generic[T]): class ArrayObject(BaseVariable, Generic[T]):
@@ -74,7 +92,7 @@ class ArrayObject(BaseVariable, Generic[T]):
def valid_value(self, value: list[Any]) -> list[T]: def valid_value(self, value: list[Any]) -> list[T]:
if not isinstance(value, list): if not isinstance(value, list):
raise TypeError("Value must be a list") raise TypeError(f"Value must be a list - {type(value)}:{value}")
final_value = [] final_value = []
for v in value: for v in value:
try: try:
@@ -86,13 +104,16 @@ class ArrayObject(BaseVariable, Generic[T]):
def to_literal(self) -> str: def to_literal(self) -> str:
return "\n".join([v.to_literal() for v in self.value]) return "\n".join([v.to_literal() for v in self.value])
def get_value(self) -> Any:
return [v.get_value() for v in self.value]
class NestedArrayObject(BaseVariable): class NestedArrayObject(BaseVariable):
type = 'array_nest' type = 'array_nest'
def valid_value(self, value: list[T]) -> list[T]: def valid_value(self, value: list[T]) -> list[T]:
if not isinstance(value, list): if not isinstance(value, list):
raise TypeError("Value must be a list") raise TypeError(f"Value must be a list - {type(value)}:{value}")
final_value = [] final_value = []
for v in value: for v in value:
if not isinstance(v, ArrayObject): if not isinstance(v, ArrayObject):
@@ -107,6 +128,10 @@ class NestedArrayObject(BaseVariable):
return [[item.get_value() for item in row] for row in self.value] return [[item.get_value() for item in row] for row in self.value]
@deprecated(
reason="Using arbitrary-type values may cause unexpected errors; please switch to strongly-typed values.",
category=RuntimeWarning
)
class AnyObject(BaseVariable): class AnyObject(BaseVariable):
type = 'any' type = 'any'
@@ -126,23 +151,23 @@ def make_array(child_type: Type[T], value: list[Any]) -> ArrayObject[T]:
def create_variable_instance(var_type: VariableType, value: Any) -> T: def create_variable_instance(var_type: VariableType, value: Any) -> T:
match var_type: match var_type:
case VariableType.STRING: case VariableType.STRING:
return StringObject(value) return StringVariable(value)
case VariableType.NUMBER: case VariableType.NUMBER:
return NumberObject(value) return NumberVariable(value)
case VariableType.BOOLEAN: case VariableType.BOOLEAN:
return BooleanObject(value) return BooleanVariable(value)
case VariableType.OBJECT: case VariableType.OBJECT:
return DictObject(value) return DictVariable(value)
case VariableType.ARRAY_STRING: case VariableType.ARRAY_STRING:
return make_array(StringObject, value) return make_array(StringVariable, value)
case VariableType.ARRAY_NUMBER: case VariableType.ARRAY_NUMBER:
return make_array(NumberObject, value) return make_array(NumberVariable, value)
case VariableType.ARRAY_BOOLEAN: case VariableType.ARRAY_BOOLEAN:
return make_array(BooleanObject, value) return make_array(BooleanVariable, value)
case VariableType.ARRAY_OBJECT: case VariableType.ARRAY_OBJECT:
return make_array(DictObject, value) return make_array(DictVariable, value)
case VariableType.ARRAY_FILE: case VariableType.ARRAY_FILE:
return make_array(FileObject, value) return make_array(FileVariable, value)
case VariableType.ANY: case VariableType.ANY:
return AnyObject(value) return AnyObject(value)
case _: case _:

View File

@@ -78,8 +78,8 @@ class VariableStruct(BaseModel, Generic[T]):
serialization, and workflow type checking. serialization, and workflow type checking.
instance: instance:
The concrete variable object. The actual Python type is The concrete variable object. The actual Python type is
represented by the generic parameter ``T`` (e.g. StringObject, represented by the generic parameter ``T`` (e.g. StringVariable,
NumberObject, ArrayObject[StringObject]). NumberVariable, ArrayObject[StringVariable]).
mut: mut:
Whether the variable is mutable. Whether the variable is mutable.
""" """
@@ -286,7 +286,7 @@ class VariablePool:
系统变量字典 系统变量字典
""" """
sys_namespace = self.variables.get("sys", {}) sys_namespace = self.variables.get("sys", {})
return {k: v.instance.value for k, v in sys_namespace.items()} return {k: v.instance.get_value() for k, v in sys_namespace.items()}
def get_all_conversation_vars(self) -> dict[str, Any]: def get_all_conversation_vars(self) -> dict[str, Any]:
"""获取所有会话变量 """获取所有会话变量
@@ -295,7 +295,7 @@ class VariablePool:
会话变量字典 会话变量字典
""" """
conv_namespace = self.variables.get("conv", {}) conv_namespace = self.variables.get("conv", {})
return {k: v.instance.value for k, v in conv_namespace.items()} return {k: v.instance.get_value() for k, v in conv_namespace.items()}
def get_all_node_outputs(self) -> dict[str, Any]: def get_all_node_outputs(self) -> dict[str, Any]:
"""获取所有节点输出(运行时变量) """获取所有节点输出(运行时变量)
@@ -305,7 +305,7 @@ class VariablePool:
""" """
runtime_vars = { runtime_vars = {
namespace: { namespace: {
k: v.instance.value k: v.instance.get_value()
for k, v in vars_dict.items() for k, v in vars_dict.items()
} }
for namespace, vars_dict in self.variables.items() for namespace, vars_dict in self.variables.items()
@@ -326,7 +326,7 @@ class VariablePool:
""" """
node_namespace = self.variables.get(node_id) node_namespace = self.variables.get(node_id)
if node_namespace: if node_namespace:
return {k: v.instance.value for k, v in node_namespace.items()} return {k: v.instance.get_value() for k, v in node_namespace.items()}
if strict: if strict:
raise KeyError(f"node {node_id} output not exist") raise KeyError(f"node {node_id} output not exist")
else: else:

View File

@@ -1,14 +1,14 @@
import datetime import datetime
import uuid import uuid
from typing import Optional, Any, List, Dict, Union from typing import Optional, Any, List, Dict, Union
from enum import Enum from enum import Enum, StrEnum
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
# ---------- Multimodal File Support ---------- # ---------- Multimodal File Support ----------
class FileType(str, Enum): class FileType(StrEnum):
"""文件类型枚举""" """文件类型枚举"""
IMAGE = "image" IMAGE = "image"
DOCUMENT = "document" DOCUMENT = "document"

View File

@@ -21,6 +21,11 @@ class PromptOptMessage(BaseModel):
description="currently optimized prompt" description="currently optimized prompt"
) )
skill: bool = Field(
default=False,
description="Enable variable output"
)
class PromptSaveRequest(BaseModel): class PromptSaveRequest(BaseModel):
session_id: UUID = Field( session_id: UUID = Field(

View File

@@ -1089,7 +1089,7 @@ class DraftRunService:
except Exception as e: except Exception as e:
# 对于多 Agent 应用,没有直接的 AgentConfig 是正常的 # 对于多 Agent 应用,没有直接的 AgentConfig 是正常的
logger.debug("获取配置快照失败(可能是多 Agent 应用)", extra={"error": str(e)}) logger.debug("获取配置快照失败(可能是多 Agent 应用)", exc_info=True, extra={"error": str(e)})
return {} return {}
def _replace_variables( def _replace_variables(

View File

@@ -23,7 +23,7 @@ logger = get_business_logger()
class ImageFormatStrategy(Protocol): class ImageFormatStrategy(Protocol):
"""图片格式策略接口""" """图片格式策略接口"""
async def format_image(self, url: str) -> Dict[str, Any]: async def format_image(self, url: str) -> Dict[str, Any]:
"""将图片 URL 转换为特定 provider 的格式""" """将图片 URL 转换为特定 provider 的格式"""
... ...
@@ -31,7 +31,7 @@ class ImageFormatStrategy(Protocol):
class DashScopeImageStrategy: class DashScopeImageStrategy:
"""通义千问图片格式策略""" """通义千问图片格式策略"""
async def format_image(self, url: str) -> Dict[str, Any]: async def format_image(self, url: str) -> Dict[str, Any]:
"""通义千问格式: {"type": "image", "image": "url"}""" """通义千问格式: {"type": "image", "image": "url"}"""
return { return {
@@ -42,7 +42,7 @@ class DashScopeImageStrategy:
class BedrockImageStrategy: class BedrockImageStrategy:
"""Bedrock/Anthropic 图片格式策略""" """Bedrock/Anthropic 图片格式策略"""
async def format_image(self, url: str) -> Dict[str, Any]: async def format_image(self, url: str) -> Dict[str, Any]:
""" """
Bedrock/Anthropic 格式: base64 编码 Bedrock/Anthropic 格式: base64 编码
@@ -51,17 +51,17 @@ class BedrockImageStrategy:
import httpx import httpx
import base64 import base64
from mimetypes import guess_type from mimetypes import guess_type
logger.info(f"下载并编码图片: {url}") logger.info(f"下载并编码图片: {url}")
# 下载图片 # 下载图片
async with httpx.AsyncClient(timeout=30.0) as client: async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url) response = await client.get(url)
response.raise_for_status() response.raise_for_status()
# 获取图片数据 # 获取图片数据
image_data = response.content image_data = response.content
# 确定 media type # 确定 media type
content_type = response.headers.get("content-type") content_type = response.headers.get("content-type")
if content_type and content_type.startswith("image/"): if content_type and content_type.startswith("image/"):
@@ -69,12 +69,12 @@ class BedrockImageStrategy:
else: else:
guessed_type, _ = guess_type(url) guessed_type, _ = guess_type(url)
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg" media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
# 转换为 base64 # 转换为 base64
base64_data = base64.b64encode(image_data).decode("utf-8") base64_data = base64.b64encode(image_data).decode("utf-8")
logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}") logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
return { return {
"type": "image", "type": "image",
"source": { "source": {
@@ -87,7 +87,7 @@ class BedrockImageStrategy:
class OpenAIImageStrategy: class OpenAIImageStrategy:
"""OpenAI 图片格式策略""" """OpenAI 图片格式策略"""
async def format_image(self, url: str) -> Dict[str, Any]: async def format_image(self, url: str) -> Dict[str, Any]:
"""OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}""" """OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}"""
return { return {
@@ -109,7 +109,7 @@ PROVIDER_STRATEGIES = {
class MultimodalService: class MultimodalService:
"""多模态文件处理服务""" """多模态文件处理服务"""
def __init__(self, db: Session, provider: str = "dashscope"): def __init__(self, db: Session, provider: str = "dashscope"):
""" """
初始化多模态服务 初始化多模态服务
@@ -120,10 +120,10 @@ class MultimodalService:
""" """
self.db = db self.db = db
self.provider = provider.lower() self.provider = provider.lower()
async def process_files( async def process_files(
self, self,
files: Optional[List[FileInput]] files: Optional[List[FileInput]]
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
处理文件列表,返回 LLM 可用的格式 处理文件列表,返回 LLM 可用的格式
@@ -136,7 +136,7 @@ class MultimodalService:
""" """
if not files: if not files:
return [] return []
result = [] result = []
for idx, file in enumerate(files): for idx, file in enumerate(files):
try: try:
@@ -168,10 +168,10 @@ class MultimodalService:
"type": "text", "type": "text",
"text": f"[文件处理失败: {str(e)}]" "text": f"[文件处理失败: {str(e)}]"
}) })
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件provider={self.provider}") logger.info(f"成功处理 {len(result)}/{len(files)} 个文件provider={self.provider}")
return result return result
async def _process_image(self, file: FileInput) -> Dict[str, Any]: async def _process_image(self, file: FileInput) -> Dict[str, Any]:
""" """
处理图片文件 处理图片文件
@@ -184,14 +184,10 @@ class MultimodalService:
- Anthropic/Bedrock: {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}} - Anthropic/Bedrock: {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
- 通义千问: {"type": "image", "image": "url"} - 通义千问: {"type": "image", "image": "url"}
""" """
if file.transfer_method == TransferMethod.REMOTE_URL: url = await self.get_file_url(file)
url = file.url
else:
# 本地文件,获取访问 URL
url = await self._get_file_url(file.upload_file_id)
logger.debug(f"处理图片: {url}, provider={self.provider}") logger.debug(f"处理图片: {url}, provider={self.provider}")
# 根据 provider 返回不同格式 # 根据 provider 返回不同格式
if self.provider in ["bedrock", "anthropic"]: if self.provider in ["bedrock", "anthropic"]:
# Anthropic/Bedrock 只支持 base64 格式,需要下载并转换 # Anthropic/Bedrock 只支持 base64 格式,需要下载并转换
@@ -223,7 +219,7 @@ class MultimodalService:
"type": "image", "type": "image",
"image": url "image": url
} }
async def _download_and_encode_image(self, url: str) -> tuple[str, str]: async def _download_and_encode_image(self, url: str) -> tuple[str, str]:
""" """
下载图片并转换为 base64 下载图片并转换为 base64
@@ -237,15 +233,15 @@ class MultimodalService:
import httpx import httpx
import base64 import base64
from mimetypes import guess_type from mimetypes import guess_type
# 下载图片 # 下载图片
async with httpx.AsyncClient(timeout=30.0) as client: async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url) response = await client.get(url)
response.raise_for_status() response.raise_for_status()
# 获取图片数据 # 获取图片数据
image_data = response.content image_data = response.content
# 确定 media type # 确定 media type
content_type = response.headers.get("content-type") content_type = response.headers.get("content-type")
if content_type and content_type.startswith("image/"): if content_type and content_type.startswith("image/"):
@@ -254,14 +250,14 @@ class MultimodalService:
# 从 URL 推断 # 从 URL 推断
guessed_type, _ = guess_type(url) guessed_type, _ = guess_type(url)
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg" media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
# 转换为 base64 # 转换为 base64
base64_data = base64.b64encode(image_data).decode("utf-8") base64_data = base64.b64encode(image_data).decode("utf-8")
logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}") logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
return base64_data, media_type return base64_data, media_type
async def _process_document(self, file: FileInput) -> Dict[str, Any]: async def _process_document(self, file: FileInput) -> Dict[str, Any]:
""" """
处理文档文件PDF、Word 等) 处理文档文件PDF、Word 等)
@@ -284,14 +280,14 @@ class MultimodalService:
generic_file = self.db.query(GenericFile).filter( generic_file = self.db.query(GenericFile).filter(
GenericFile.id == file.upload_file_id GenericFile.id == file.upload_file_id
).first() ).first()
file_name = generic_file.file_name if generic_file else "unknown" file_name = generic_file.file_name if generic_file else "unknown"
return { return {
"type": "text", "type": "text",
"text": f"<document name=\"{file_name}\">\n{text}\n</document>" "text": f"<document name=\"{file_name}\">\n{text}\n</document>"
} }
async def _process_audio(self, file: FileInput) -> Dict[str, Any]: async def _process_audio(self, file: FileInput) -> Dict[str, Any]:
""" """
处理音频文件 处理音频文件
@@ -307,7 +303,7 @@ class MultimodalService:
"type": "text", "type": "text",
"text": "[音频文件,暂不支持处理]" "text": "[音频文件,暂不支持处理]"
} }
async def _process_video(self, file: FileInput) -> Dict[str, Any]: async def _process_video(self, file: FileInput) -> Dict[str, Any]:
""" """
处理视频文件 处理视频文件
@@ -323,13 +319,13 @@ class MultimodalService:
"type": "text", "type": "text",
"text": "[视频文件,暂不支持处理]" "text": "[视频文件,暂不支持处理]"
} }
async def _get_file_url(self, file_id: uuid.UUID) -> str: async def get_file_url(self, file: FileInput) -> str:
""" """
获取文件的访问 URL 获取文件的访问 URL
Args: Args:
file_id: 文件ID file: File Input Struct
Returns: Returns:
str: 文件访问 URL str: 文件访问 URL
@@ -337,26 +333,31 @@ class MultimodalService:
Raises: Raises:
BusinessException: 文件不存在 BusinessException: 文件不存在
""" """
generic_file = self.db.query(GenericFile).filter( if file.transfer_method == TransferMethod.REMOTE_URL:
GenericFile.id == file_id, return file.url
GenericFile.status == "active" else:
).first() # 本地文件,获取访问 URL
file_id = file.upload_file_id
if not generic_file: generic_file = self.db.query(GenericFile).filter(
raise BusinessException( GenericFile.id == file.upload_file_id,
f"文件不存在或已删除: {file_id}", GenericFile.status == "active"
BizCode.NOT_FOUND ).first()
)
if not generic_file:
# 如果有 access_url直接返回 raise BusinessException(
if generic_file.access_url: f"文件不存在或已删除: {file.upload_file_id}",
return generic_file.access_url BizCode.NOT_FOUND
)
# 否则,根据 storage_path 生成 URL
# TODO: 根据实际存储方式生成 URL本地存储、OSS 等) # 如果有 access_url直接返回
# 这里暂时返回一个占位 URL if generic_file.access_url:
return f"/api/files/{file_id}/download" return generic_file.access_url
# 否则,根据 storage_path 生成 URL
# TODO: 根据实际存储方式生成 URL本地存储、OSS 等)
# 这里暂时返回一个占位 URL
return f"/api/files/{file_id}/download"
async def _extract_document_text(self, file_id: uuid.UUID) -> str: async def _extract_document_text(self, file_id: uuid.UUID) -> str:
""" """
提取文档文本内容 提取文档文本内容
@@ -371,20 +372,20 @@ class MultimodalService:
GenericFile.id == file_id, GenericFile.id == file_id,
GenericFile.status == "active" GenericFile.status == "active"
).first() ).first()
if not generic_file: if not generic_file:
raise BusinessException( raise BusinessException(
f"文件不存在或已删除: {file_id}", f"文件不存在或已删除: {file_id}",
BizCode.NOT_FOUND BizCode.NOT_FOUND
) )
# TODO: 根据文件类型提取文本 # TODO: 根据文件类型提取文本
# - PDF: 使用 PyPDF2 或 pdfplumber # - PDF: 使用 PyPDF2 或 pdfplumber
# - Word: 使用 python-docx # - Word: 使用 python-docx
# - TXT/MD: 直接读取 # - TXT/MD: 直接读取
file_ext = generic_file.file_ext.lower() file_ext = generic_file.file_ext.lower()
if file_ext in ['.txt', '.md', '.markdown']: if file_ext in ['.txt', '.md', '.markdown']:
return await self._read_text_file(generic_file.storage_path) return await self._read_text_file(generic_file.storage_path)
elif file_ext == '.pdf': elif file_ext == '.pdf':
@@ -393,7 +394,7 @@ class MultimodalService:
return await self._extract_word_text(generic_file.storage_path) return await self._extract_word_text(generic_file.storage_path)
else: else:
return f"[不支持的文档格式: {file_ext}]" return f"[不支持的文档格式: {file_ext}]"
async def _read_text_file(self, storage_path: str) -> str: async def _read_text_file(self, storage_path: str) -> str:
"""读取纯文本文件""" """读取纯文本文件"""
try: try:
@@ -402,7 +403,7 @@ class MultimodalService:
except Exception as e: except Exception as e:
logger.error(f"读取文本文件失败: {e}") logger.error(f"读取文本文件失败: {e}")
return f"[文件读取失败: {str(e)}]" return f"[文件读取失败: {str(e)}]"
async def _extract_pdf_text(self, storage_path: str) -> str: async def _extract_pdf_text(self, storage_path: str) -> str:
"""提取 PDF 文本""" """提取 PDF 文本"""
try: try:
@@ -412,7 +413,7 @@ class MultimodalService:
except Exception as e: except Exception as e:
logger.error(f"提取 PDF 文本失败: {e}") logger.error(f"提取 PDF 文本失败: {e}")
return f"[PDF 提取失败: {str(e)}]" return f"[PDF 提取失败: {str(e)}]"
async def _extract_word_text(self, storage_path: str) -> str: async def _extract_word_text(self, storage_path: str) -> str:
"""提取 Word 文档文本""" """提取 Word 文档文本"""
try: try:

View File

@@ -1,4 +1,3 @@
{% raw %}
Role: AI Prompt Optimization Expert Role: AI Prompt Optimization Expert
Profile Profile
@@ -12,11 +11,11 @@ Skills
Core Optimization Skills Core Optimization Skills
Requirement Analysis: Accurately understand the relationship between the users current needs and the original prompt. Requirement Analysis: Accurately understand the relationship between the users current needs and the original prompt.
Structural Reconstruction: Transform vague requirements into clear, block-structured instructions. Structural Reconstruction: Transform vague requirements into clear, block-structured instructions.
Variable Handling: Identify and standardize dynamic variables in prompts. {% if skill != true %}Variable Handling: Identify and standardize dynamic variables in prompts.{% endif %}
Conflict Resolution: Prioritize current requirements when historical requirements conflict with current needs. Conflict Resolution: Prioritize current requirements when historical requirements conflict with current needs.
Auxiliary Generation Skills Auxiliary Generation Skills
Completeness Check: Ensure all necessary elements (input, output, constraints, etc.) are explicitly defined. {% if skill != true %}Completeness Check: Ensure all necessary elements (input, output, constraints, etc.) are explicitly defined.{% endif %}
Language Consistency: Maintain consistency between label language and user input language. Language Consistency: Maintain consistency between label language and user input language.
Executability Verification: Ensure optimized prompts can be directly used in AI tools. Executability Verification: Ensure optimized prompts can be directly used in AI tools.
Format Standardization: Strictly adhere to specified output format requirements. Format Standardization: Strictly adhere to specified output format requirements.
@@ -25,30 +24,30 @@ Rules
Basic Principles Basic Principles
Priority Rule: When historical requirements conflict with current requirements, unconditionally prioritize current requirements. Priority Rule: When historical requirements conflict with current requirements, unconditionally prioritize current requirements.
Completeness Rule: If the original prompt is empty, generate a complete prompt based on the current requirements. Completeness Rule: If the original prompt is empty, generate a complete prompt based on the current requirements.
Structure Rule: Use a clear block structure, and the contents of each block are roles, tasks, requirements, inputs, outputs, and constraints {% if skill != true %}Structure Rule: Use a clear block structure, and the contents of each block are roles, tasks, requirements, inputs, outputs, and constraints{% endif %}
Language Rule: All label languages must fully match the user input language. Language Rule: All label languages must fully match the user input language.
Behavior Guidelines Behavior Guidelines
Precision Guideline: All instructions must be precise and directly executable, avoiding ambiguity. Precision Guideline: All instructions must be precise and directly executable, avoiding ambiguity.
Readability Guideline: Ensure optimized prompts have good readability and logical flow. Readability Guideline: Ensure optimized prompts have good readability and logical flow.
Variable Handling Guideline: Use lowercase English variable names wrapped in {{}} when variables are needed. {% if skill != true %}{% raw %}Variable Handling Guideline: Use lowercase English variable names wrapped in {{}} when variables are needed.
Constraint Handling Guideline: Do not mention variable-related limitations under the [Constraints] label. Constraint Handling Guideline: Do not mention variable-related limitations under the [Constraints] label.{% endraw %}{% endif %}
Constraints Constraints
Output Constraint: Must output in JSON format including the fields "prompt" and "desc". Output Constraint: Must output in JSON format including the fields "prompt" and "desc".
Content Constraint: Must not include any explanations, analyses, or additional comments. Content Constraint: Must not include any explanations, analyses, or additional comments.
Language Constraint: Must use clear and concise language. Language Constraint: Must use clear and concise language.
Completeness Constraint: Must fully define all missing elements (input details, output format, constraints, etc.). {% if skill != true %}Completeness Constraint: Must fully define all missing elements (input details, output format, constraints, etc.).{% endif %}
Workflows Workflows
Goal: Optimize or generate AI prompts that can be directly used according to user requirements. Goal: Optimize or generate AI prompts that can be directly used according to user requirements.
Step 1: Receive the users current requirement description {{user_require}} and the original prompt {{original_prompt}}. Step 1: Receive the users current requirement description {{user_require}} and the original prompt {{original_prompt}}.
Step 2: Analyze requirements, identify conflicts, and prioritize current requirements. Step 2: Analyze requirements, identify conflicts, and prioritize current requirements.
Step 3: Optimize or generate the prompt in a block-structured format, ensuring all elements are fully defined. {% if skill != true %}Step 3: Optimize or generate the prompt in a block-structured format, ensuring all elements are fully defined.
Step 4: Generate a JSON output containing the optimized prompt and its description. Step 4: Generate a JSON output containing the optimized prompt and its description.
{% else %}Step 3: Generate a JSON output containing the optimized prompt and its description.{% endif %}
Expected Outcome: Obtain a clear, directly executable AI prompt accompanied by an optimization description. Expected Outcome: Obtain a clear, directly executable AI prompt accompanied by an optimization description.
Initialization Initialization
As an AI Prompt Optimization Expert, you must follow the above Rules and execute tasks according to the Workflows. As an AI Prompt Optimization Expert, you must follow the above Rules and execute tasks according to the Workflows.
{% endraw %}

View File

@@ -128,7 +128,8 @@ class PromptOptimizerService:
session_id: uuid.UUID, session_id: uuid.UUID,
user_id: uuid.UUID, user_id: uuid.UUID,
current_prompt: str, current_prompt: str,
user_require: str user_require: str,
skill: bool = False
) -> AsyncGenerator[dict[str, str | Any], Any]: ) -> AsyncGenerator[dict[str, str | Any], Any]:
""" """
Optimize a user-provided prompt using a configured prompt optimizer LLM. Optimize a user-provided prompt using a configured prompt optimizer LLM.
@@ -157,6 +158,7 @@ class PromptOptimizerService:
user_id (uuid.UUID): Identifier of the user associated with the session. user_id (uuid.UUID): Identifier of the user associated with the session.
current_prompt (str): Original prompt to optimize. current_prompt (str): Original prompt to optimize.
user_require (str): User's requirements or instructions for optimization. user_require (str): User's requirements or instructions for optimization.
skill(bool): Is skill required
Returns: Returns:
OptimizePromptResult: An object containing: OptimizePromptResult: An object containing:
@@ -186,7 +188,7 @@ class PromptOptimizerService:
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt') prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
with open(os.path.join(prompt_path, 'prompt_optimizer_system.jinja2'), 'r', encoding='utf-8') as f: with open(os.path.join(prompt_path, 'prompt_optimizer_system.jinja2'), 'r', encoding='utf-8') as f:
opt_system_prompt = f.read() opt_system_prompt = f.read()
rendered_system_message = Template(opt_system_prompt).render() rendered_system_message = Template(opt_system_prompt).render(skill=skill)
with open(os.path.join(prompt_path, 'prompt_optimizer_user.jinja2'), 'r', encoding='utf-8') as f: with open(os.path.join(prompt_path, 'prompt_optimizer_user.jinja2'), 'r', encoding='utf-8') as f:
opt_user_prompt = f.read() opt_user_prompt = f.read()

View File

@@ -22,6 +22,7 @@ from app.repositories.workflow_repository import (
from app.schemas import DraftRunRequest from app.schemas import DraftRunRequest
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
from app.services.multi_agent_service import convert_uuids_to_str from app.services.multi_agent_service import convert_uuids_to_str
from app.services.multimodal_service import MultimodalService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -35,6 +36,7 @@ class WorkflowService:
self.execution_repo = WorkflowExecutionRepository(db) self.execution_repo = WorkflowExecutionRepository(db)
self.node_execution_repo = WorkflowNodeExecutionRepository(db) self.node_execution_repo = WorkflowNodeExecutionRepository(db)
self.conversation_service = ConversationService(db) self.conversation_service = ConversationService(db)
self.multimodal_service = MultimodalService(db)
# ==================== 配置管理 ==================== # ==================== 配置管理 ====================
@@ -444,8 +446,19 @@ class WorkflowService:
code=BizCode.CONFIG_MISSING, code=BizCode.CONFIG_MISSING,
message=f"工作流配置不存在: app_id={app_id}" message=f"工作流配置不存在: app_id={app_id}"
) )
files = []
if payload.files:
for file in payload.files:
files.append(
{
"type": file.type,
"url": await self.multimodal_service.get_file_url(file),
"__file": True
}
)
input_data = {"message": payload.message, "variables": payload.variables, input_data = {"message": payload.message, "variables": payload.variables,
"conversation_id": payload.conversation_id} "conversation_id": payload.conversation_id, "files": files}
# 转换 user_id 为 UUID # 转换 user_id 为 UUID
triggered_by_uuid = None triggered_by_uuid = None
@@ -633,8 +646,20 @@ class WorkflowService:
code=BizCode.CONFIG_MISSING, code=BizCode.CONFIG_MISSING,
message=f"工作流配置不存在: app_id={app_id}" message=f"工作流配置不存在: app_id={app_id}"
) )
files = []
if payload.files:
for file in payload.files:
files.append(
{
"type": file.type,
"url": await self.multimodal_service.get_file_url(file),
"__file": True
}
)
input_data = {"message": payload.message, "variables": payload.variables, input_data = {"message": payload.message, "variables": payload.variables,
"conversation_id": payload.conversation_id} "conversation_id": payload.conversation_id, "files": files}
# 转换 user_id 为 UUID # 转换 user_id 为 UUID
triggered_by_uuid = None triggered_by_uuid = None