Merge pull request #325 from SuanmoSuanyangTechnology/feature/workflow-file
feat(workflow, skill): add multimodal image support to workflows and skill prompt generation
This commit is contained in:
@@ -120,7 +120,8 @@ async def get_prompt_opt(
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
current_prompt=data.current_prompt,
|
current_prompt=data.current_prompt,
|
||||||
user_require=data.message
|
user_require=data.message,
|
||||||
|
skill=data.skill
|
||||||
):
|
):
|
||||||
# chunk 是 prompt 的增量内容
|
# chunk 是 prompt 的增量内容
|
||||||
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
|
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ class WorkflowExecutor:
|
|||||||
- "conversation_id": conversation identifier
|
- "conversation_id": conversation identifier
|
||||||
"""
|
"""
|
||||||
user_message = input_data.get("message") or ""
|
user_message = input_data.get("message") or ""
|
||||||
user_file = input_data.get("file") or []
|
user_files = input_data.get("files") or []
|
||||||
|
|
||||||
config_variables_list = self.workflow_config.get("variables") or []
|
config_variables_list = self.workflow_config.get("variables") or []
|
||||||
conv_vars = input_data.get("conv", {})
|
conv_vars = input_data.get("conv", {})
|
||||||
@@ -119,12 +119,12 @@ class WorkflowExecutor:
|
|||||||
input_variables = input_data.get("variables") or {}
|
input_variables = input_data.get("variables") or {}
|
||||||
sys_vars = {
|
sys_vars = {
|
||||||
"message": (user_message, VariableType.STRING),
|
"message": (user_message, VariableType.STRING),
|
||||||
"file": (user_file, VariableType.ARRAY_FILE),
|
|
||||||
"conversation_id": (input_data.get("conversation_id"), VariableType.STRING),
|
"conversation_id": (input_data.get("conversation_id"), VariableType.STRING),
|
||||||
"execution_id": (self.execution_id, VariableType.STRING),
|
"execution_id": (self.execution_id, VariableType.STRING),
|
||||||
"workspace_id": (self.workspace_id, VariableType.STRING),
|
"workspace_id": (self.workspace_id, VariableType.STRING),
|
||||||
"user_id": (self.user_id, VariableType.STRING),
|
"user_id": (self.user_id, VariableType.STRING),
|
||||||
"input_variables": (input_variables, VariableType.OBJECT),
|
"input_variables": (input_variables, VariableType.OBJECT),
|
||||||
|
"files": (user_files, VariableType.ARRAY_FILE)
|
||||||
}
|
}
|
||||||
for key, var_def in sys_vars.items():
|
for key, var_def in sys_vars.items():
|
||||||
value = var_def[0]
|
value = var_def[0]
|
||||||
@@ -564,6 +564,7 @@ class WorkflowExecutor:
|
|||||||
"input": result.get("node_outputs", {}).get(node_name, {}).get("input"),
|
"input": result.get("node_outputs", {}).get(node_name, {}).get("input"),
|
||||||
"output": result.get("node_outputs", {}).get(node_name, {}).get("output"),
|
"output": result.get("node_outputs", {}).get(node_name, {}).get("output"),
|
||||||
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
|
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
|
||||||
|
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from app.core.config import settings
|
|||||||
from app.core.workflow.nodes.enums import BRANCH_NODES
|
from app.core.workflow.nodes.enums import BRANCH_NODES
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
from app.core.workflow.variable_pool import VariablePool
|
from app.core.workflow.variable_pool import VariablePool
|
||||||
|
from app.services.multimodal_service import PROVIDER_STRATEGIES
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -651,3 +652,21 @@ class BaseNode(ABC):
|
|||||||
True if the variable exists in the pool, False otherwise.
|
True if the variable exists in the pool, False otherwise.
|
||||||
"""
|
"""
|
||||||
return variable_pool.has(selector)
|
return variable_pool.has(selector)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def process_message(provider, content, enable_file=False) -> dict | str | None:
|
||||||
|
if isinstance(content, str):
|
||||||
|
if enable_file:
|
||||||
|
return {"text": content}
|
||||||
|
return content
|
||||||
|
elif isinstance(content, dict):
|
||||||
|
trans_tool = PROVIDER_STRATEGIES[provider]()
|
||||||
|
result = await trans_tool.format_image(content["url"])
|
||||||
|
return result
|
||||||
|
raise TypeError('Unexpect input value type')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def process_model_output(content) -> str:
|
||||||
|
if isinstance(content, dict):
|
||||||
|
return content.get("text")
|
||||||
|
return content
|
||||||
|
|||||||
@@ -71,6 +71,16 @@ class LLMNodeConfig(BaseNodeConfig):
|
|||||||
description="对话上下文窗口"
|
description="对话上下文窗口"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
vision: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="是否启用视觉模型"
|
||||||
|
)
|
||||||
|
|
||||||
|
vision_input: str = Field(
|
||||||
|
default=None,
|
||||||
|
description="视觉输入"
|
||||||
|
)
|
||||||
|
|
||||||
# 简单模式
|
# 简单模式
|
||||||
prompt: str | None = Field(
|
prompt: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
|
|||||||
@@ -79,12 +79,12 @@ class LLMNode(BaseNode):
|
|||||||
context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
|
context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
|
||||||
return re.sub(r"{{context}}", context, message)
|
return re.sub(r"{{context}}", context, message)
|
||||||
|
|
||||||
def _prepare_llm(
|
async def _prepare_llm(
|
||||||
self,
|
self,
|
||||||
state: WorkflowState,
|
state: WorkflowState,
|
||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool,
|
||||||
stream: bool = False
|
stream: bool = False
|
||||||
) -> tuple[RedBearLLM, list | str]:
|
) -> RedBearLLM:
|
||||||
"""准备 LLM 实例(公共逻辑)
|
"""准备 LLM 实例(公共逻辑)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -93,42 +93,9 @@ class LLMNode(BaseNode):
|
|||||||
Returns:
|
Returns:
|
||||||
(llm, messages_or_prompt): LLM 实例和消息列表或 prompt 字符串
|
(llm, messages_or_prompt): LLM 实例和消息列表或 prompt 字符串
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 1. 处理消息格式(优先使用 messages)
|
|
||||||
self.typed_config = LLMNodeConfig(**self.config)
|
self.typed_config = LLMNodeConfig(**self.config)
|
||||||
messages_config = self.typed_config.messages
|
|
||||||
|
|
||||||
if messages_config:
|
model_id = self.typed_config.model_id
|
||||||
# 使用 LangChain 消息格式
|
|
||||||
messages = []
|
|
||||||
for msg_config in messages_config:
|
|
||||||
role = msg_config.role.lower()
|
|
||||||
content_template = msg_config.content
|
|
||||||
content_template = self._render_context(content_template, variable_pool)
|
|
||||||
content = self._render_template(content_template, variable_pool)
|
|
||||||
|
|
||||||
# 根据角色创建对应的消息对象
|
|
||||||
if role == "system":
|
|
||||||
messages.append({"role": "system", "content": content})
|
|
||||||
elif role in ["user", "human"]:
|
|
||||||
messages.append({"role": "user", "content": content})
|
|
||||||
elif role in ["ai", "assistant"]:
|
|
||||||
messages.append({"role": "assistant", "content": content})
|
|
||||||
else:
|
|
||||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
|
||||||
messages.append({"role": "user", "content": content})
|
|
||||||
|
|
||||||
if self.typed_config.memory.enable:
|
|
||||||
# if self.typed_config.memory.enable_window:
|
|
||||||
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
|
|
||||||
prompt_or_messages = messages
|
|
||||||
else:
|
|
||||||
# 使用简单的 prompt 格式(向后兼容)
|
|
||||||
prompt_template = self.config.get("prompt", "")
|
|
||||||
prompt_or_messages = self._render_template(prompt_template, variable_pool)
|
|
||||||
|
|
||||||
# 2. 获取模型配置
|
|
||||||
model_id = self.config.get("model_id")
|
|
||||||
if not model_id:
|
if not model_id:
|
||||||
raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置")
|
raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置")
|
||||||
|
|
||||||
@@ -167,7 +134,61 @@ class LLMNode(BaseNode):
|
|||||||
|
|
||||||
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
|
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
|
||||||
|
|
||||||
return llm, prompt_or_messages
|
messages_config = self.typed_config.messages
|
||||||
|
|
||||||
|
if messages_config:
|
||||||
|
# 使用 LangChain 消息格式
|
||||||
|
messages = []
|
||||||
|
for msg_config in messages_config:
|
||||||
|
role = msg_config.role.lower()
|
||||||
|
content_template = msg_config.content
|
||||||
|
content_template = self._render_context(content_template, variable_pool)
|
||||||
|
content = self._render_template(content_template, variable_pool)
|
||||||
|
|
||||||
|
# 根据角色创建对应的消息对象
|
||||||
|
if role == "system":
|
||||||
|
messages.append({
|
||||||
|
"role": "system",
|
||||||
|
"content": content
|
||||||
|
})
|
||||||
|
elif role in ["user", "human"]:
|
||||||
|
messages.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": content
|
||||||
|
})
|
||||||
|
elif role in ["ai", "assistant"]:
|
||||||
|
messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": content
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||||
|
messages.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": content
|
||||||
|
})
|
||||||
|
|
||||||
|
if self.typed_config.vision_input and self.typed_config.vision:
|
||||||
|
file_content = []
|
||||||
|
files = variable_pool.get_value(self.typed_config.vision_input)
|
||||||
|
for file in files:
|
||||||
|
content = await self.process_message(provider, file, self.typed_config.vision)
|
||||||
|
if content:
|
||||||
|
file_content.append(content)
|
||||||
|
if messages and messages[-1]["role"] == 'user':
|
||||||
|
messages[-1]['content'] = [messages[-1]["content"]] + file_content
|
||||||
|
else:
|
||||||
|
messages.append({"role": "user", "content": file_content})
|
||||||
|
|
||||||
|
if self.typed_config.memory.enable:
|
||||||
|
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
|
||||||
|
self.messages = messages
|
||||||
|
else:
|
||||||
|
# 使用简单的 prompt 格式(向后兼容)
|
||||||
|
prompt_template = self.config.get("prompt", "")
|
||||||
|
self.messages = self._render_template(prompt_template, variable_pool)
|
||||||
|
|
||||||
|
return llm
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage:
|
||||||
"""非流式执行 LLM 调用
|
"""非流式执行 LLM 调用
|
||||||
@@ -180,15 +201,15 @@ class LLMNode(BaseNode):
|
|||||||
LLM 响应消息
|
LLM 响应消息
|
||||||
"""
|
"""
|
||||||
# self.typed_config = LLMNodeConfig(**self.config)
|
# self.typed_config = LLMNodeConfig(**self.config)
|
||||||
llm, prompt_or_messages = self._prepare_llm(state, variable_pool, False)
|
llm = await self._prepare_llm(state, variable_pool, False)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||||
|
|
||||||
# 调用 LLM(支持字符串或消息列表)
|
# 调用 LLM(支持字符串或消息列表)
|
||||||
response = await llm.ainvoke(prompt_or_messages)
|
response = await llm.ainvoke(self.messages)
|
||||||
# 提取内容
|
# 提取内容
|
||||||
if hasattr(response, 'content'):
|
if hasattr(response, 'content'):
|
||||||
content = response.content
|
content = self.process_model_output(response.content)
|
||||||
else:
|
else:
|
||||||
content = str(response)
|
content = str(response)
|
||||||
|
|
||||||
@@ -199,14 +220,13 @@ class LLMNode(BaseNode):
|
|||||||
|
|
||||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
"""提取输入数据(用于记录)"""
|
"""提取输入数据(用于记录)"""
|
||||||
_, prompt_or_messages = self._prepare_llm(state, variable_pool)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
|
"prompt": self.messages if isinstance(self.messages, str) else None,
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": msg.get("role"), "content": msg.get("content", "")}
|
{"role": msg.get("role"), "content": msg.get("content", "")}
|
||||||
for msg in prompt_or_messages
|
for msg in self.messages
|
||||||
] if isinstance(prompt_or_messages, list) else None,
|
] if isinstance(self.messages, list) else None,
|
||||||
"config": {
|
"config": {
|
||||||
"model_id": self.config.get("model_id"),
|
"model_id": self.config.get("model_id"),
|
||||||
"temperature": self.config.get("temperature"),
|
"temperature": self.config.get("temperature"),
|
||||||
@@ -226,8 +246,8 @@ class LLMNode(BaseNode):
|
|||||||
usage = business_result.response_metadata.get('token_usage')
|
usage = business_result.response_metadata.get('token_usage')
|
||||||
if usage:
|
if usage:
|
||||||
return {
|
return {
|
||||||
"prompt_tokens": usage.get('prompt_tokens', 0),
|
"prompt_tokens": usage.get('input_tokens', 0),
|
||||||
"completion_tokens": usage.get('completion_tokens', 0),
|
"completion_tokens": usage.get('output_tokens', 0),
|
||||||
"total_tokens": usage.get('total_tokens', 0)
|
"total_tokens": usage.get('total_tokens', 0)
|
||||||
}
|
}
|
||||||
return None
|
return None
|
||||||
@@ -244,7 +264,7 @@ class LLMNode(BaseNode):
|
|||||||
"""
|
"""
|
||||||
self.typed_config = LLMNodeConfig(**self.config)
|
self.typed_config = LLMNodeConfig(**self.config)
|
||||||
|
|
||||||
llm, prompt_or_messages = self._prepare_llm(state, variable_pool, True)
|
llm = await self._prepare_llm(state, variable_pool, True)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||||
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
||||||
@@ -255,10 +275,10 @@ class LLMNode(BaseNode):
|
|||||||
|
|
||||||
# 调用 LLM(流式,支持字符串或消息列表)
|
# 调用 LLM(流式,支持字符串或消息列表)
|
||||||
last_meta_data = {}
|
last_meta_data = {}
|
||||||
async for chunk in llm.astream(prompt_or_messages, stream_usage=True):
|
async for chunk in llm.astream(self.messages, stream_usage=True):
|
||||||
# 提取内容
|
# 提取内容
|
||||||
if hasattr(chunk, 'content'):
|
if hasattr(chunk, 'content'):
|
||||||
content = chunk.content
|
content = self.process_model_output(chunk.content)
|
||||||
else:
|
else:
|
||||||
content = str(chunk)
|
content = str(chunk)
|
||||||
if hasattr(chunk, 'response_metadata'):
|
if hasattr(chunk, 'response_metadata'):
|
||||||
|
|||||||
@@ -2,6 +2,10 @@ from enum import StrEnum
|
|||||||
from abc import abstractmethod, ABC
|
from abc import abstractmethod, ABC
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.schemas import FileType
|
||||||
|
|
||||||
|
|
||||||
class VariableType(StrEnum):
|
class VariableType(StrEnum):
|
||||||
"""Enumeration of supported variable types in the workflow."""
|
"""Enumeration of supported variable types in the workflow."""
|
||||||
@@ -42,7 +46,7 @@ class VariableType(StrEnum):
|
|||||||
return cls.NUMBER
|
return cls.NUMBER
|
||||||
elif isinstance(var_type, bool):
|
elif isinstance(var_type, bool):
|
||||||
return cls.BOOLEAN
|
return cls.BOOLEAN
|
||||||
elif isinstance(var_type, FileObj):
|
elif isinstance(var_type, FileObject) or (isinstance(var, dict) and var.get('__file')):
|
||||||
return cls.FILE
|
return cls.FILE
|
||||||
elif isinstance(var_type, dict):
|
elif isinstance(var_type, dict):
|
||||||
return cls.OBJECT
|
return cls.OBJECT
|
||||||
@@ -103,8 +107,10 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any:
|
|||||||
raise TypeError(f"Invalid type - {type}")
|
raise TypeError(f"Invalid type - {type}")
|
||||||
|
|
||||||
|
|
||||||
class FileObj:
|
class FileObject(BaseModel):
|
||||||
pass
|
type: FileType
|
||||||
|
url: str
|
||||||
|
__file: bool
|
||||||
|
|
||||||
|
|
||||||
class BaseVariable(ABC):
|
class BaseVariable(ABC):
|
||||||
|
|||||||
@@ -1,66 +1,84 @@
|
|||||||
from typing import Any, TypeVar, Type, Generic
|
from typing import Any, TypeVar, Type, Generic
|
||||||
|
|
||||||
from app.core.workflow.variable.base_variable import BaseVariable, VariableType
|
from deprecated import deprecated
|
||||||
|
|
||||||
|
from app.core.workflow.variable.base_variable import BaseVariable, VariableType, FileObject
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseVariable)
|
T = TypeVar("T", bound=BaseVariable)
|
||||||
|
|
||||||
|
|
||||||
class StringObject(BaseVariable):
|
class StringVariable(BaseVariable):
|
||||||
type = 'str'
|
type = 'str'
|
||||||
|
|
||||||
def valid_value(self, value) -> str:
|
def valid_value(self, value) -> str:
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str):
|
||||||
raise TypeError("Value must be a string")
|
raise TypeError(f"Value must be a string - {type(value)}:{value}")
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def to_literal(self) -> str:
|
def to_literal(self) -> str:
|
||||||
return self.value
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
class NumberObject(BaseVariable):
|
class NumberVariable(BaseVariable):
|
||||||
type = 'number'
|
type = 'number'
|
||||||
|
|
||||||
def valid_value(self, value) -> int | float:
|
def valid_value(self, value) -> int | float:
|
||||||
if not isinstance(value, (int, float)):
|
if not isinstance(value, (int, float)):
|
||||||
raise TypeError("Value must be a number")
|
raise TypeError(f"Value must be a number - {type(value)}:{value}")
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def to_literal(self) -> str:
|
def to_literal(self) -> str:
|
||||||
return str(self.value)
|
return str(self.value)
|
||||||
|
|
||||||
|
|
||||||
class BooleanObject(BaseVariable):
|
class BooleanVariable(BaseVariable):
|
||||||
type = 'boolean'
|
type = 'boolean'
|
||||||
|
|
||||||
def valid_value(self, value) -> bool:
|
def valid_value(self, value) -> bool:
|
||||||
if not isinstance(value, bool):
|
if not isinstance(value, bool):
|
||||||
raise TypeError("Value must be a boolean")
|
raise TypeError(f"Value must be a boolean - {type(value)}:{value}")
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def to_literal(self) -> str:
|
def to_literal(self) -> str:
|
||||||
return str(self.value).lower()
|
return str(self.value).lower()
|
||||||
|
|
||||||
|
|
||||||
class DictObject(BaseVariable):
|
class DictVariable(BaseVariable):
|
||||||
type = 'object'
|
type = 'object'
|
||||||
|
|
||||||
def valid_value(self, value) -> dict:
|
def valid_value(self, value) -> dict:
|
||||||
if not isinstance(value, dict):
|
if not isinstance(value, dict):
|
||||||
raise TypeError("Value must be a dict")
|
raise TypeError(f"Value must be a dict - {type(value)}:{value}")
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def to_literal(self) -> str:
|
def to_literal(self) -> str:
|
||||||
return str(self.value)
|
return str(self.value)
|
||||||
|
|
||||||
|
|
||||||
class FileObject(BaseVariable):
|
class FileVariable(BaseVariable):
|
||||||
type = 'file'
|
type = 'file'
|
||||||
|
|
||||||
def valid_value(self, value) -> Any:
|
def valid_value(self, value) -> FileObject:
|
||||||
pass
|
|
||||||
|
if isinstance(value, dict):
|
||||||
|
if not value.get("__file"):
|
||||||
|
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
|
||||||
|
return FileObject(
|
||||||
|
**{
|
||||||
|
"type": str(value.get('type')),
|
||||||
|
"url": value.get('url'),
|
||||||
|
"__file": True
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if isinstance(value, FileObject):
|
||||||
|
return value
|
||||||
|
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
|
||||||
|
|
||||||
def to_literal(self) -> str:
|
def to_literal(self) -> str:
|
||||||
pass
|
return str(self.value.model_dump())
|
||||||
|
|
||||||
|
def get_value(self) -> Any:
|
||||||
|
return self.value.model_dump()
|
||||||
|
|
||||||
|
|
||||||
class ArrayObject(BaseVariable, Generic[T]):
|
class ArrayObject(BaseVariable, Generic[T]):
|
||||||
@@ -74,7 +92,7 @@ class ArrayObject(BaseVariable, Generic[T]):
|
|||||||
|
|
||||||
def valid_value(self, value: list[Any]) -> list[T]:
|
def valid_value(self, value: list[Any]) -> list[T]:
|
||||||
if not isinstance(value, list):
|
if not isinstance(value, list):
|
||||||
raise TypeError("Value must be a list")
|
raise TypeError(f"Value must be a list - {type(value)}:{value}")
|
||||||
final_value = []
|
final_value = []
|
||||||
for v in value:
|
for v in value:
|
||||||
try:
|
try:
|
||||||
@@ -86,13 +104,16 @@ class ArrayObject(BaseVariable, Generic[T]):
|
|||||||
def to_literal(self) -> str:
|
def to_literal(self) -> str:
|
||||||
return "\n".join([v.to_literal() for v in self.value])
|
return "\n".join([v.to_literal() for v in self.value])
|
||||||
|
|
||||||
|
def get_value(self) -> Any:
|
||||||
|
return [v.get_value() for v in self.value]
|
||||||
|
|
||||||
|
|
||||||
class NestedArrayObject(BaseVariable):
|
class NestedArrayObject(BaseVariable):
|
||||||
type = 'array_nest'
|
type = 'array_nest'
|
||||||
|
|
||||||
def valid_value(self, value: list[T]) -> list[T]:
|
def valid_value(self, value: list[T]) -> list[T]:
|
||||||
if not isinstance(value, list):
|
if not isinstance(value, list):
|
||||||
raise TypeError("Value must be a list")
|
raise TypeError(f"Value must be a list - {type(value)}:{value}")
|
||||||
final_value = []
|
final_value = []
|
||||||
for v in value:
|
for v in value:
|
||||||
if not isinstance(v, ArrayObject):
|
if not isinstance(v, ArrayObject):
|
||||||
@@ -107,6 +128,10 @@ class NestedArrayObject(BaseVariable):
|
|||||||
return [[item.get_value() for item in row] for row in self.value]
|
return [[item.get_value() for item in row] for row in self.value]
|
||||||
|
|
||||||
|
|
||||||
|
@deprecated(
|
||||||
|
reason="Using arbitrary-type values may cause unexpected errors; please switch to strongly-typed values.",
|
||||||
|
category=RuntimeWarning
|
||||||
|
)
|
||||||
class AnyObject(BaseVariable):
|
class AnyObject(BaseVariable):
|
||||||
type = 'any'
|
type = 'any'
|
||||||
|
|
||||||
@@ -126,23 +151,23 @@ def make_array(child_type: Type[T], value: list[Any]) -> ArrayObject[T]:
|
|||||||
def create_variable_instance(var_type: VariableType, value: Any) -> T:
|
def create_variable_instance(var_type: VariableType, value: Any) -> T:
|
||||||
match var_type:
|
match var_type:
|
||||||
case VariableType.STRING:
|
case VariableType.STRING:
|
||||||
return StringObject(value)
|
return StringVariable(value)
|
||||||
case VariableType.NUMBER:
|
case VariableType.NUMBER:
|
||||||
return NumberObject(value)
|
return NumberVariable(value)
|
||||||
case VariableType.BOOLEAN:
|
case VariableType.BOOLEAN:
|
||||||
return BooleanObject(value)
|
return BooleanVariable(value)
|
||||||
case VariableType.OBJECT:
|
case VariableType.OBJECT:
|
||||||
return DictObject(value)
|
return DictVariable(value)
|
||||||
case VariableType.ARRAY_STRING:
|
case VariableType.ARRAY_STRING:
|
||||||
return make_array(StringObject, value)
|
return make_array(StringVariable, value)
|
||||||
case VariableType.ARRAY_NUMBER:
|
case VariableType.ARRAY_NUMBER:
|
||||||
return make_array(NumberObject, value)
|
return make_array(NumberVariable, value)
|
||||||
case VariableType.ARRAY_BOOLEAN:
|
case VariableType.ARRAY_BOOLEAN:
|
||||||
return make_array(BooleanObject, value)
|
return make_array(BooleanVariable, value)
|
||||||
case VariableType.ARRAY_OBJECT:
|
case VariableType.ARRAY_OBJECT:
|
||||||
return make_array(DictObject, value)
|
return make_array(DictVariable, value)
|
||||||
case VariableType.ARRAY_FILE:
|
case VariableType.ARRAY_FILE:
|
||||||
return make_array(FileObject, value)
|
return make_array(FileVariable, value)
|
||||||
case VariableType.ANY:
|
case VariableType.ANY:
|
||||||
return AnyObject(value)
|
return AnyObject(value)
|
||||||
case _:
|
case _:
|
||||||
|
|||||||
@@ -78,8 +78,8 @@ class VariableStruct(BaseModel, Generic[T]):
|
|||||||
serialization, and workflow type checking.
|
serialization, and workflow type checking.
|
||||||
instance:
|
instance:
|
||||||
The concrete variable object. The actual Python type is
|
The concrete variable object. The actual Python type is
|
||||||
represented by the generic parameter ``T`` (e.g. StringObject,
|
represented by the generic parameter ``T`` (e.g. StringVariable,
|
||||||
NumberObject, ArrayObject[StringObject]).
|
NumberVariable, ArrayObject[StringVariable]).
|
||||||
mut:
|
mut:
|
||||||
Whether the variable is mutable.
|
Whether the variable is mutable.
|
||||||
"""
|
"""
|
||||||
@@ -286,7 +286,7 @@ class VariablePool:
|
|||||||
系统变量字典
|
系统变量字典
|
||||||
"""
|
"""
|
||||||
sys_namespace = self.variables.get("sys", {})
|
sys_namespace = self.variables.get("sys", {})
|
||||||
return {k: v.instance.value for k, v in sys_namespace.items()}
|
return {k: v.instance.get_value() for k, v in sys_namespace.items()}
|
||||||
|
|
||||||
def get_all_conversation_vars(self) -> dict[str, Any]:
|
def get_all_conversation_vars(self) -> dict[str, Any]:
|
||||||
"""获取所有会话变量
|
"""获取所有会话变量
|
||||||
@@ -295,7 +295,7 @@ class VariablePool:
|
|||||||
会话变量字典
|
会话变量字典
|
||||||
"""
|
"""
|
||||||
conv_namespace = self.variables.get("conv", {})
|
conv_namespace = self.variables.get("conv", {})
|
||||||
return {k: v.instance.value for k, v in conv_namespace.items()}
|
return {k: v.instance.get_value() for k, v in conv_namespace.items()}
|
||||||
|
|
||||||
def get_all_node_outputs(self) -> dict[str, Any]:
|
def get_all_node_outputs(self) -> dict[str, Any]:
|
||||||
"""获取所有节点输出(运行时变量)
|
"""获取所有节点输出(运行时变量)
|
||||||
@@ -305,7 +305,7 @@ class VariablePool:
|
|||||||
"""
|
"""
|
||||||
runtime_vars = {
|
runtime_vars = {
|
||||||
namespace: {
|
namespace: {
|
||||||
k: v.instance.value
|
k: v.instance.get_value()
|
||||||
for k, v in vars_dict.items()
|
for k, v in vars_dict.items()
|
||||||
}
|
}
|
||||||
for namespace, vars_dict in self.variables.items()
|
for namespace, vars_dict in self.variables.items()
|
||||||
@@ -326,7 +326,7 @@ class VariablePool:
|
|||||||
"""
|
"""
|
||||||
node_namespace = self.variables.get(node_id)
|
node_namespace = self.variables.get(node_id)
|
||||||
if node_namespace:
|
if node_namespace:
|
||||||
return {k: v.instance.value for k, v in node_namespace.items()}
|
return {k: v.instance.get_value() for k, v in node_namespace.items()}
|
||||||
if strict:
|
if strict:
|
||||||
raise KeyError(f"node {node_id} output not exist")
|
raise KeyError(f"node {node_id} output not exist")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Optional, Any, List, Dict, Union
|
from typing import Optional, Any, List, Dict, Union
|
||||||
from enum import Enum
|
from enum import Enum, StrEnum
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
||||||
|
|
||||||
|
|
||||||
# ---------- Multimodal File Support ----------
|
# ---------- Multimodal File Support ----------
|
||||||
|
|
||||||
class FileType(str, Enum):
|
class FileType(StrEnum):
|
||||||
"""文件类型枚举"""
|
"""文件类型枚举"""
|
||||||
IMAGE = "image"
|
IMAGE = "image"
|
||||||
DOCUMENT = "document"
|
DOCUMENT = "document"
|
||||||
|
|||||||
@@ -21,6 +21,11 @@ class PromptOptMessage(BaseModel):
|
|||||||
description="currently optimized prompt"
|
description="currently optimized prompt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
skill: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Enable variable output"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PromptSaveRequest(BaseModel):
|
class PromptSaveRequest(BaseModel):
|
||||||
session_id: UUID = Field(
|
session_id: UUID = Field(
|
||||||
|
|||||||
@@ -1089,7 +1089,7 @@ class DraftRunService:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 对于多 Agent 应用,没有直接的 AgentConfig 是正常的
|
# 对于多 Agent 应用,没有直接的 AgentConfig 是正常的
|
||||||
logger.debug("获取配置快照失败(可能是多 Agent 应用)", extra={"error": str(e)})
|
logger.debug("获取配置快照失败(可能是多 Agent 应用)", exc_info=True, extra={"error": str(e)})
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _replace_variables(
|
def _replace_variables(
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ logger = get_business_logger()
|
|||||||
|
|
||||||
class ImageFormatStrategy(Protocol):
|
class ImageFormatStrategy(Protocol):
|
||||||
"""图片格式策略接口"""
|
"""图片格式策略接口"""
|
||||||
|
|
||||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||||
"""将图片 URL 转换为特定 provider 的格式"""
|
"""将图片 URL 转换为特定 provider 的格式"""
|
||||||
...
|
...
|
||||||
@@ -31,7 +31,7 @@ class ImageFormatStrategy(Protocol):
|
|||||||
|
|
||||||
class DashScopeImageStrategy:
|
class DashScopeImageStrategy:
|
||||||
"""通义千问图片格式策略"""
|
"""通义千问图片格式策略"""
|
||||||
|
|
||||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||||
"""通义千问格式: {"type": "image", "image": "url"}"""
|
"""通义千问格式: {"type": "image", "image": "url"}"""
|
||||||
return {
|
return {
|
||||||
@@ -42,7 +42,7 @@ class DashScopeImageStrategy:
|
|||||||
|
|
||||||
class BedrockImageStrategy:
|
class BedrockImageStrategy:
|
||||||
"""Bedrock/Anthropic 图片格式策略"""
|
"""Bedrock/Anthropic 图片格式策略"""
|
||||||
|
|
||||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Bedrock/Anthropic 格式: base64 编码
|
Bedrock/Anthropic 格式: base64 编码
|
||||||
@@ -51,17 +51,17 @@ class BedrockImageStrategy:
|
|||||||
import httpx
|
import httpx
|
||||||
import base64
|
import base64
|
||||||
from mimetypes import guess_type
|
from mimetypes import guess_type
|
||||||
|
|
||||||
logger.info(f"下载并编码图片: {url}")
|
logger.info(f"下载并编码图片: {url}")
|
||||||
|
|
||||||
# 下载图片
|
# 下载图片
|
||||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
response = await client.get(url)
|
response = await client.get(url)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
# 获取图片数据
|
# 获取图片数据
|
||||||
image_data = response.content
|
image_data = response.content
|
||||||
|
|
||||||
# 确定 media type
|
# 确定 media type
|
||||||
content_type = response.headers.get("content-type")
|
content_type = response.headers.get("content-type")
|
||||||
if content_type and content_type.startswith("image/"):
|
if content_type and content_type.startswith("image/"):
|
||||||
@@ -69,12 +69,12 @@ class BedrockImageStrategy:
|
|||||||
else:
|
else:
|
||||||
guessed_type, _ = guess_type(url)
|
guessed_type, _ = guess_type(url)
|
||||||
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
|
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
|
||||||
|
|
||||||
# 转换为 base64
|
# 转换为 base64
|
||||||
base64_data = base64.b64encode(image_data).decode("utf-8")
|
base64_data = base64.b64encode(image_data).decode("utf-8")
|
||||||
|
|
||||||
logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
|
logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"type": "image",
|
"type": "image",
|
||||||
"source": {
|
"source": {
|
||||||
@@ -87,7 +87,7 @@ class BedrockImageStrategy:
|
|||||||
|
|
||||||
class OpenAIImageStrategy:
|
class OpenAIImageStrategy:
|
||||||
"""OpenAI 图片格式策略"""
|
"""OpenAI 图片格式策略"""
|
||||||
|
|
||||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||||
"""OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}"""
|
"""OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}"""
|
||||||
return {
|
return {
|
||||||
@@ -109,7 +109,7 @@ PROVIDER_STRATEGIES = {
|
|||||||
|
|
||||||
class MultimodalService:
|
class MultimodalService:
|
||||||
"""多模态文件处理服务"""
|
"""多模态文件处理服务"""
|
||||||
|
|
||||||
def __init__(self, db: Session, provider: str = "dashscope"):
|
def __init__(self, db: Session, provider: str = "dashscope"):
|
||||||
"""
|
"""
|
||||||
初始化多模态服务
|
初始化多模态服务
|
||||||
@@ -120,10 +120,10 @@ class MultimodalService:
|
|||||||
"""
|
"""
|
||||||
self.db = db
|
self.db = db
|
||||||
self.provider = provider.lower()
|
self.provider = provider.lower()
|
||||||
|
|
||||||
async def process_files(
|
async def process_files(
|
||||||
self,
|
self,
|
||||||
files: Optional[List[FileInput]]
|
files: Optional[List[FileInput]]
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
处理文件列表,返回 LLM 可用的格式
|
处理文件列表,返回 LLM 可用的格式
|
||||||
@@ -136,7 +136,7 @@ class MultimodalService:
|
|||||||
"""
|
"""
|
||||||
if not files:
|
if not files:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
for idx, file in enumerate(files):
|
for idx, file in enumerate(files):
|
||||||
try:
|
try:
|
||||||
@@ -168,10 +168,10 @@ class MultimodalService:
|
|||||||
"type": "text",
|
"type": "text",
|
||||||
"text": f"[文件处理失败: {str(e)}]"
|
"text": f"[文件处理失败: {str(e)}]"
|
||||||
})
|
})
|
||||||
|
|
||||||
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def _process_image(self, file: FileInput) -> Dict[str, Any]:
|
async def _process_image(self, file: FileInput) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
处理图片文件
|
处理图片文件
|
||||||
@@ -184,14 +184,10 @@ class MultimodalService:
|
|||||||
- Anthropic/Bedrock: {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
|
- Anthropic/Bedrock: {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
|
||||||
- 通义千问: {"type": "image", "image": "url"}
|
- 通义千问: {"type": "image", "image": "url"}
|
||||||
"""
|
"""
|
||||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
url = await self.get_file_url(file)
|
||||||
url = file.url
|
|
||||||
else:
|
|
||||||
# 本地文件,获取访问 URL
|
|
||||||
url = await self._get_file_url(file.upload_file_id)
|
|
||||||
|
|
||||||
logger.debug(f"处理图片: {url}, provider={self.provider}")
|
logger.debug(f"处理图片: {url}, provider={self.provider}")
|
||||||
|
|
||||||
# 根据 provider 返回不同格式
|
# 根据 provider 返回不同格式
|
||||||
if self.provider in ["bedrock", "anthropic"]:
|
if self.provider in ["bedrock", "anthropic"]:
|
||||||
# Anthropic/Bedrock 只支持 base64 格式,需要下载并转换
|
# Anthropic/Bedrock 只支持 base64 格式,需要下载并转换
|
||||||
@@ -223,7 +219,7 @@ class MultimodalService:
|
|||||||
"type": "image",
|
"type": "image",
|
||||||
"image": url
|
"image": url
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _download_and_encode_image(self, url: str) -> tuple[str, str]:
|
async def _download_and_encode_image(self, url: str) -> tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
下载图片并转换为 base64
|
下载图片并转换为 base64
|
||||||
@@ -237,15 +233,15 @@ class MultimodalService:
|
|||||||
import httpx
|
import httpx
|
||||||
import base64
|
import base64
|
||||||
from mimetypes import guess_type
|
from mimetypes import guess_type
|
||||||
|
|
||||||
# 下载图片
|
# 下载图片
|
||||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
response = await client.get(url)
|
response = await client.get(url)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
# 获取图片数据
|
# 获取图片数据
|
||||||
image_data = response.content
|
image_data = response.content
|
||||||
|
|
||||||
# 确定 media type
|
# 确定 media type
|
||||||
content_type = response.headers.get("content-type")
|
content_type = response.headers.get("content-type")
|
||||||
if content_type and content_type.startswith("image/"):
|
if content_type and content_type.startswith("image/"):
|
||||||
@@ -254,14 +250,14 @@ class MultimodalService:
|
|||||||
# 从 URL 推断
|
# 从 URL 推断
|
||||||
guessed_type, _ = guess_type(url)
|
guessed_type, _ = guess_type(url)
|
||||||
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
|
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
|
||||||
|
|
||||||
# 转换为 base64
|
# 转换为 base64
|
||||||
base64_data = base64.b64encode(image_data).decode("utf-8")
|
base64_data = base64.b64encode(image_data).decode("utf-8")
|
||||||
|
|
||||||
logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
|
logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
|
||||||
|
|
||||||
return base64_data, media_type
|
return base64_data, media_type
|
||||||
|
|
||||||
async def _process_document(self, file: FileInput) -> Dict[str, Any]:
|
async def _process_document(self, file: FileInput) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
处理文档文件(PDF、Word 等)
|
处理文档文件(PDF、Word 等)
|
||||||
@@ -284,14 +280,14 @@ class MultimodalService:
|
|||||||
generic_file = self.db.query(GenericFile).filter(
|
generic_file = self.db.query(GenericFile).filter(
|
||||||
GenericFile.id == file.upload_file_id
|
GenericFile.id == file.upload_file_id
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
file_name = generic_file.file_name if generic_file else "unknown"
|
file_name = generic_file.file_name if generic_file else "unknown"
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _process_audio(self, file: FileInput) -> Dict[str, Any]:
|
async def _process_audio(self, file: FileInput) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
处理音频文件
|
处理音频文件
|
||||||
@@ -307,7 +303,7 @@ class MultimodalService:
|
|||||||
"type": "text",
|
"type": "text",
|
||||||
"text": "[音频文件,暂不支持处理]"
|
"text": "[音频文件,暂不支持处理]"
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _process_video(self, file: FileInput) -> Dict[str, Any]:
|
async def _process_video(self, file: FileInput) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
处理视频文件
|
处理视频文件
|
||||||
@@ -323,13 +319,13 @@ class MultimodalService:
|
|||||||
"type": "text",
|
"type": "text",
|
||||||
"text": "[视频文件,暂不支持处理]"
|
"text": "[视频文件,暂不支持处理]"
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _get_file_url(self, file_id: uuid.UUID) -> str:
|
async def get_file_url(self, file: FileInput) -> str:
|
||||||
"""
|
"""
|
||||||
获取文件的访问 URL
|
获取文件的访问 URL
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_id: 文件ID
|
file: File Input Struct
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 文件访问 URL
|
str: 文件访问 URL
|
||||||
@@ -337,26 +333,31 @@ class MultimodalService:
|
|||||||
Raises:
|
Raises:
|
||||||
BusinessException: 文件不存在
|
BusinessException: 文件不存在
|
||||||
"""
|
"""
|
||||||
generic_file = self.db.query(GenericFile).filter(
|
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||||
GenericFile.id == file_id,
|
return file.url
|
||||||
GenericFile.status == "active"
|
else:
|
||||||
).first()
|
# 本地文件,获取访问 URL
|
||||||
|
file_id = file.upload_file_id
|
||||||
if not generic_file:
|
generic_file = self.db.query(GenericFile).filter(
|
||||||
raise BusinessException(
|
GenericFile.id == file.upload_file_id,
|
||||||
f"文件不存在或已删除: {file_id}",
|
GenericFile.status == "active"
|
||||||
BizCode.NOT_FOUND
|
).first()
|
||||||
)
|
|
||||||
|
if not generic_file:
|
||||||
# 如果有 access_url,直接返回
|
raise BusinessException(
|
||||||
if generic_file.access_url:
|
f"文件不存在或已删除: {file.upload_file_id}",
|
||||||
return generic_file.access_url
|
BizCode.NOT_FOUND
|
||||||
|
)
|
||||||
# 否则,根据 storage_path 生成 URL
|
|
||||||
# TODO: 根据实际存储方式生成 URL(本地存储、OSS 等)
|
# 如果有 access_url,直接返回
|
||||||
# 这里暂时返回一个占位 URL
|
if generic_file.access_url:
|
||||||
return f"/api/files/{file_id}/download"
|
return generic_file.access_url
|
||||||
|
|
||||||
|
# 否则,根据 storage_path 生成 URL
|
||||||
|
# TODO: 根据实际存储方式生成 URL(本地存储、OSS 等)
|
||||||
|
# 这里暂时返回一个占位 URL
|
||||||
|
return f"/api/files/{file_id}/download"
|
||||||
|
|
||||||
async def _extract_document_text(self, file_id: uuid.UUID) -> str:
|
async def _extract_document_text(self, file_id: uuid.UUID) -> str:
|
||||||
"""
|
"""
|
||||||
提取文档文本内容
|
提取文档文本内容
|
||||||
@@ -371,20 +372,20 @@ class MultimodalService:
|
|||||||
GenericFile.id == file_id,
|
GenericFile.id == file_id,
|
||||||
GenericFile.status == "active"
|
GenericFile.status == "active"
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not generic_file:
|
if not generic_file:
|
||||||
raise BusinessException(
|
raise BusinessException(
|
||||||
f"文件不存在或已删除: {file_id}",
|
f"文件不存在或已删除: {file_id}",
|
||||||
BizCode.NOT_FOUND
|
BizCode.NOT_FOUND
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: 根据文件类型提取文本
|
# TODO: 根据文件类型提取文本
|
||||||
# - PDF: 使用 PyPDF2 或 pdfplumber
|
# - PDF: 使用 PyPDF2 或 pdfplumber
|
||||||
# - Word: 使用 python-docx
|
# - Word: 使用 python-docx
|
||||||
# - TXT/MD: 直接读取
|
# - TXT/MD: 直接读取
|
||||||
|
|
||||||
file_ext = generic_file.file_ext.lower()
|
file_ext = generic_file.file_ext.lower()
|
||||||
|
|
||||||
if file_ext in ['.txt', '.md', '.markdown']:
|
if file_ext in ['.txt', '.md', '.markdown']:
|
||||||
return await self._read_text_file(generic_file.storage_path)
|
return await self._read_text_file(generic_file.storage_path)
|
||||||
elif file_ext == '.pdf':
|
elif file_ext == '.pdf':
|
||||||
@@ -393,7 +394,7 @@ class MultimodalService:
|
|||||||
return await self._extract_word_text(generic_file.storage_path)
|
return await self._extract_word_text(generic_file.storage_path)
|
||||||
else:
|
else:
|
||||||
return f"[不支持的文档格式: {file_ext}]"
|
return f"[不支持的文档格式: {file_ext}]"
|
||||||
|
|
||||||
async def _read_text_file(self, storage_path: str) -> str:
|
async def _read_text_file(self, storage_path: str) -> str:
|
||||||
"""读取纯文本文件"""
|
"""读取纯文本文件"""
|
||||||
try:
|
try:
|
||||||
@@ -402,7 +403,7 @@ class MultimodalService:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"读取文本文件失败: {e}")
|
logger.error(f"读取文本文件失败: {e}")
|
||||||
return f"[文件读取失败: {str(e)}]"
|
return f"[文件读取失败: {str(e)}]"
|
||||||
|
|
||||||
async def _extract_pdf_text(self, storage_path: str) -> str:
|
async def _extract_pdf_text(self, storage_path: str) -> str:
|
||||||
"""提取 PDF 文本"""
|
"""提取 PDF 文本"""
|
||||||
try:
|
try:
|
||||||
@@ -412,7 +413,7 @@ class MultimodalService:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"提取 PDF 文本失败: {e}")
|
logger.error(f"提取 PDF 文本失败: {e}")
|
||||||
return f"[PDF 提取失败: {str(e)}]"
|
return f"[PDF 提取失败: {str(e)}]"
|
||||||
|
|
||||||
async def _extract_word_text(self, storage_path: str) -> str:
|
async def _extract_word_text(self, storage_path: str) -> str:
|
||||||
"""提取 Word 文档文本"""
|
"""提取 Word 文档文本"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
{% raw %}
|
|
||||||
Role: AI Prompt Optimization Expert
|
Role: AI Prompt Optimization Expert
|
||||||
|
|
||||||
Profile
|
Profile
|
||||||
@@ -12,11 +11,11 @@ Skills
|
|||||||
Core Optimization Skills
|
Core Optimization Skills
|
||||||
Requirement Analysis: Accurately understand the relationship between the user’s current needs and the original prompt.
|
Requirement Analysis: Accurately understand the relationship between the user’s current needs and the original prompt.
|
||||||
Structural Reconstruction: Transform vague requirements into clear, block-structured instructions.
|
Structural Reconstruction: Transform vague requirements into clear, block-structured instructions.
|
||||||
Variable Handling: Identify and standardize dynamic variables in prompts.
|
{% if skill != true %}Variable Handling: Identify and standardize dynamic variables in prompts.{% endif %}
|
||||||
Conflict Resolution: Prioritize current requirements when historical requirements conflict with current needs.
|
Conflict Resolution: Prioritize current requirements when historical requirements conflict with current needs.
|
||||||
|
|
||||||
Auxiliary Generation Skills
|
Auxiliary Generation Skills
|
||||||
Completeness Check: Ensure all necessary elements (input, output, constraints, etc.) are explicitly defined.
|
{% if skill != true %}Completeness Check: Ensure all necessary elements (input, output, constraints, etc.) are explicitly defined.{% endif %}
|
||||||
Language Consistency: Maintain consistency between label language and user input language.
|
Language Consistency: Maintain consistency between label language and user input language.
|
||||||
Executability Verification: Ensure optimized prompts can be directly used in AI tools.
|
Executability Verification: Ensure optimized prompts can be directly used in AI tools.
|
||||||
Format Standardization: Strictly adhere to specified output format requirements.
|
Format Standardization: Strictly adhere to specified output format requirements.
|
||||||
@@ -25,30 +24,30 @@ Rules
|
|||||||
Basic Principles
|
Basic Principles
|
||||||
Priority Rule: When historical requirements conflict with current requirements, unconditionally prioritize current requirements.
|
Priority Rule: When historical requirements conflict with current requirements, unconditionally prioritize current requirements.
|
||||||
Completeness Rule: If the original prompt is empty, generate a complete prompt based on the current requirements.
|
Completeness Rule: If the original prompt is empty, generate a complete prompt based on the current requirements.
|
||||||
Structure Rule: Use a clear block structure, and the contents of each block are roles, tasks, requirements, inputs, outputs, and constraints
|
{% if skill != true %}Structure Rule: Use a clear block structure, and the contents of each block are roles, tasks, requirements, inputs, outputs, and constraints{% endif %}
|
||||||
Language Rule: All label languages must fully match the user input language.
|
Language Rule: All label languages must fully match the user input language.
|
||||||
|
|
||||||
Behavior Guidelines
|
Behavior Guidelines
|
||||||
Precision Guideline: All instructions must be precise and directly executable, avoiding ambiguity.
|
Precision Guideline: All instructions must be precise and directly executable, avoiding ambiguity.
|
||||||
Readability Guideline: Ensure optimized prompts have good readability and logical flow.
|
Readability Guideline: Ensure optimized prompts have good readability and logical flow.
|
||||||
Variable Handling Guideline: Use lowercase English variable names wrapped in {{}} when variables are needed.
|
{% if skill != true %}{% raw %}Variable Handling Guideline: Use lowercase English variable names wrapped in {{}} when variables are needed.
|
||||||
Constraint Handling Guideline: Do not mention variable-related limitations under the [Constraints] label.
|
Constraint Handling Guideline: Do not mention variable-related limitations under the [Constraints] label.{% endraw %}{% endif %}
|
||||||
|
|
||||||
Constraints
|
Constraints
|
||||||
Output Constraint: Must output in JSON format including the fields "prompt" and "desc".
|
Output Constraint: Must output in JSON format including the fields "prompt" and "desc".
|
||||||
Content Constraint: Must not include any explanations, analyses, or additional comments.
|
Content Constraint: Must not include any explanations, analyses, or additional comments.
|
||||||
Language Constraint: Must use clear and concise language.
|
Language Constraint: Must use clear and concise language.
|
||||||
Completeness Constraint: Must fully define all missing elements (input details, output format, constraints, etc.).
|
{% if skill != true %}Completeness Constraint: Must fully define all missing elements (input details, output format, constraints, etc.).{% endif %}
|
||||||
|
|
||||||
Workflows
|
Workflows
|
||||||
Goal: Optimize or generate AI prompts that can be directly used according to user requirements.
|
Goal: Optimize or generate AI prompts that can be directly used according to user requirements.
|
||||||
Step 1: Receive the user’s current requirement description {{user_require}} and the original prompt {{original_prompt}}.
|
Step 1: Receive the user’s current requirement description {{user_require}} and the original prompt {{original_prompt}}.
|
||||||
Step 2: Analyze requirements, identify conflicts, and prioritize current requirements.
|
Step 2: Analyze requirements, identify conflicts, and prioritize current requirements.
|
||||||
Step 3: Optimize or generate the prompt in a block-structured format, ensuring all elements are fully defined.
|
{% if skill != true %}Step 3: Optimize or generate the prompt in a block-structured format, ensuring all elements are fully defined.
|
||||||
Step 4: Generate a JSON output containing the optimized prompt and its description.
|
Step 4: Generate a JSON output containing the optimized prompt and its description.
|
||||||
|
{% else %}Step 3: Generate a JSON output containing the optimized prompt and its description.{% endif %}
|
||||||
|
|
||||||
Expected Outcome: Obtain a clear, directly executable AI prompt accompanied by an optimization description.
|
Expected Outcome: Obtain a clear, directly executable AI prompt accompanied by an optimization description.
|
||||||
|
|
||||||
Initialization
|
Initialization
|
||||||
As an AI Prompt Optimization Expert, you must follow the above Rules and execute tasks according to the Workflows.
|
As an AI Prompt Optimization Expert, you must follow the above Rules and execute tasks according to the Workflows.
|
||||||
{% endraw %}
|
|
||||||
@@ -128,7 +128,8 @@ class PromptOptimizerService:
|
|||||||
session_id: uuid.UUID,
|
session_id: uuid.UUID,
|
||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
current_prompt: str,
|
current_prompt: str,
|
||||||
user_require: str
|
user_require: str,
|
||||||
|
skill: bool = False
|
||||||
) -> AsyncGenerator[dict[str, str | Any], Any]:
|
) -> AsyncGenerator[dict[str, str | Any], Any]:
|
||||||
"""
|
"""
|
||||||
Optimize a user-provided prompt using a configured prompt optimizer LLM.
|
Optimize a user-provided prompt using a configured prompt optimizer LLM.
|
||||||
@@ -157,6 +158,7 @@ class PromptOptimizerService:
|
|||||||
user_id (uuid.UUID): Identifier of the user associated with the session.
|
user_id (uuid.UUID): Identifier of the user associated with the session.
|
||||||
current_prompt (str): Original prompt to optimize.
|
current_prompt (str): Original prompt to optimize.
|
||||||
user_require (str): User's requirements or instructions for optimization.
|
user_require (str): User's requirements or instructions for optimization.
|
||||||
|
skill(bool): Is skill required
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
OptimizePromptResult: An object containing:
|
OptimizePromptResult: An object containing:
|
||||||
@@ -186,7 +188,7 @@ class PromptOptimizerService:
|
|||||||
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
||||||
with open(os.path.join(prompt_path, 'prompt_optimizer_system.jinja2'), 'r', encoding='utf-8') as f:
|
with open(os.path.join(prompt_path, 'prompt_optimizer_system.jinja2'), 'r', encoding='utf-8') as f:
|
||||||
opt_system_prompt = f.read()
|
opt_system_prompt = f.read()
|
||||||
rendered_system_message = Template(opt_system_prompt).render()
|
rendered_system_message = Template(opt_system_prompt).render(skill=skill)
|
||||||
|
|
||||||
with open(os.path.join(prompt_path, 'prompt_optimizer_user.jinja2'), 'r', encoding='utf-8') as f:
|
with open(os.path.join(prompt_path, 'prompt_optimizer_user.jinja2'), 'r', encoding='utf-8') as f:
|
||||||
opt_user_prompt = f.read()
|
opt_user_prompt = f.read()
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from app.repositories.workflow_repository import (
|
|||||||
from app.schemas import DraftRunRequest
|
from app.schemas import DraftRunRequest
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
from app.services.multi_agent_service import convert_uuids_to_str
|
from app.services.multi_agent_service import convert_uuids_to_str
|
||||||
|
from app.services.multimodal_service import MultimodalService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -35,6 +36,7 @@ class WorkflowService:
|
|||||||
self.execution_repo = WorkflowExecutionRepository(db)
|
self.execution_repo = WorkflowExecutionRepository(db)
|
||||||
self.node_execution_repo = WorkflowNodeExecutionRepository(db)
|
self.node_execution_repo = WorkflowNodeExecutionRepository(db)
|
||||||
self.conversation_service = ConversationService(db)
|
self.conversation_service = ConversationService(db)
|
||||||
|
self.multimodal_service = MultimodalService(db)
|
||||||
|
|
||||||
# ==================== 配置管理 ====================
|
# ==================== 配置管理 ====================
|
||||||
|
|
||||||
@@ -444,8 +446,19 @@ class WorkflowService:
|
|||||||
code=BizCode.CONFIG_MISSING,
|
code=BizCode.CONFIG_MISSING,
|
||||||
message=f"工作流配置不存在: app_id={app_id}"
|
message=f"工作流配置不存在: app_id={app_id}"
|
||||||
)
|
)
|
||||||
|
files = []
|
||||||
|
if payload.files:
|
||||||
|
for file in payload.files:
|
||||||
|
files.append(
|
||||||
|
{
|
||||||
|
"type": file.type,
|
||||||
|
"url": await self.multimodal_service.get_file_url(file),
|
||||||
|
"__file": True
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
input_data = {"message": payload.message, "variables": payload.variables,
|
input_data = {"message": payload.message, "variables": payload.variables,
|
||||||
"conversation_id": payload.conversation_id}
|
"conversation_id": payload.conversation_id, "files": files}
|
||||||
|
|
||||||
# 转换 user_id 为 UUID
|
# 转换 user_id 为 UUID
|
||||||
triggered_by_uuid = None
|
triggered_by_uuid = None
|
||||||
@@ -633,8 +646,20 @@ class WorkflowService:
|
|||||||
code=BizCode.CONFIG_MISSING,
|
code=BizCode.CONFIG_MISSING,
|
||||||
message=f"工作流配置不存在: app_id={app_id}"
|
message=f"工作流配置不存在: app_id={app_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
files = []
|
||||||
|
if payload.files:
|
||||||
|
for file in payload.files:
|
||||||
|
files.append(
|
||||||
|
{
|
||||||
|
"type": file.type,
|
||||||
|
"url": await self.multimodal_service.get_file_url(file),
|
||||||
|
"__file": True
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
input_data = {"message": payload.message, "variables": payload.variables,
|
input_data = {"message": payload.message, "variables": payload.variables,
|
||||||
"conversation_id": payload.conversation_id}
|
"conversation_id": payload.conversation_id, "files": files}
|
||||||
|
|
||||||
# 转换 user_id 为 UUID
|
# 转换 user_id 为 UUID
|
||||||
triggered_by_uuid = None
|
triggered_by_uuid = None
|
||||||
|
|||||||
Reference in New Issue
Block a user