diff --git a/api/app/controllers/prompt_optimizer_controller.py b/api/app/controllers/prompt_optimizer_controller.py index 61195deb..80f14cd3 100644 --- a/api/app/controllers/prompt_optimizer_controller.py +++ b/api/app/controllers/prompt_optimizer_controller.py @@ -120,7 +120,8 @@ async def get_prompt_opt( session_id=session_id, user_id=current_user.id, current_prompt=data.current_prompt, - user_require=data.message + user_require=data.message, + skill=data.skill ): # chunk 是 prompt 的增量内容 yield f"event:message\ndata: {json.dumps(chunk)}\n\n" diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index f3763955..537058a0 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -92,7 +92,7 @@ class WorkflowExecutor: - "conversation_id": conversation identifier """ user_message = input_data.get("message") or "" - user_file = input_data.get("file") or [] + user_files = input_data.get("files") or [] config_variables_list = self.workflow_config.get("variables") or [] conv_vars = input_data.get("conv", {}) @@ -119,12 +119,12 @@ class WorkflowExecutor: input_variables = input_data.get("variables") or {} sys_vars = { "message": (user_message, VariableType.STRING), - "file": (user_file, VariableType.ARRAY_FILE), "conversation_id": (input_data.get("conversation_id"), VariableType.STRING), "execution_id": (self.execution_id, VariableType.STRING), "workspace_id": (self.workspace_id, VariableType.STRING), "user_id": (self.user_id, VariableType.STRING), "input_variables": (input_variables, VariableType.OBJECT), + "files": (user_files, VariableType.ARRAY_FILE) } for key, var_def in sys_vars.items(): value = var_def[0] @@ -564,6 +564,7 @@ class WorkflowExecutor: "input": result.get("node_outputs", {}).get(node_name, {}).get("input"), "output": result.get("node_outputs", {}).get(node_name, {}).get("output"), "elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"), + "token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage") } } diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 2bf748f2..e370dbeb 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -11,6 +11,7 @@ from app.core.config import settings from app.core.workflow.nodes.enums import BRANCH_NODES from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable_pool import VariablePool +from app.services.multimodal_service import PROVIDER_STRATEGIES logger = logging.getLogger(__name__) @@ -651,3 +652,21 @@ class BaseNode(ABC): True if the variable exists in the pool, False otherwise. """ return variable_pool.has(selector) + + @staticmethod + async def process_message(provider, content, enable_file=False) -> dict | str | None: + if isinstance(content, str): + if enable_file: + return {"text": content} + return content + elif isinstance(content, dict): + trans_tool = PROVIDER_STRATEGIES[provider]() + result = await trans_tool.format_image(content["url"]) + return result + raise TypeError('Unexpect input value type') + + @staticmethod + def process_model_output(content) -> str: + if isinstance(content, dict): + return content.get("text") + return content diff --git a/api/app/core/workflow/nodes/llm/config.py b/api/app/core/workflow/nodes/llm/config.py index 48c51aa1..1229450f 100644 --- a/api/app/core/workflow/nodes/llm/config.py +++ b/api/app/core/workflow/nodes/llm/config.py @@ -71,6 +71,16 @@ class LLMNodeConfig(BaseNodeConfig): description="对话上下文窗口" ) + vision: bool = Field( + default=False, + description="是否启用视觉模型" + ) + + vision_input: str = Field( + default=None, + description="视觉输入" + ) + # 简单模式 prompt: str | None = Field( default=None, diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index 1246324d..4393e1ed 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -79,12 +79,12 @@ class LLMNode(BaseNode): context = f"{self._render_template(self.typed_config.context, variable_pool)}" return re.sub(r"{{context}}", context, message) - def _prepare_llm( + async def _prepare_llm( self, state: WorkflowState, variable_pool: VariablePool, stream: bool = False - ) -> tuple[RedBearLLM, list | str]: + ) -> RedBearLLM: """准备 LLM 实例(公共逻辑) Args: @@ -93,42 +93,9 @@ class LLMNode(BaseNode): Returns: (llm, messages_or_prompt): LLM 实例和消息列表或 prompt 字符串 """ - - # 1. 处理消息格式(优先使用 messages) self.typed_config = LLMNodeConfig(**self.config) - messages_config = self.typed_config.messages - if messages_config: - # 使用 LangChain 消息格式 - messages = [] - for msg_config in messages_config: - role = msg_config.role.lower() - content_template = msg_config.content - content_template = self._render_context(content_template, variable_pool) - content = self._render_template(content_template, variable_pool) - - # 根据角色创建对应的消息对象 - if role == "system": - messages.append({"role": "system", "content": content}) - elif role in ["user", "human"]: - messages.append({"role": "user", "content": content}) - elif role in ["ai", "assistant"]: - messages.append({"role": "assistant", "content": content}) - else: - logger.warning(f"未知的消息角色: {role},默认使用 user") - messages.append({"role": "user", "content": content}) - - if self.typed_config.memory.enable: - # if self.typed_config.memory.enable_window: - messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:] - prompt_or_messages = messages - else: - # 使用简单的 prompt 格式(向后兼容) - prompt_template = self.config.get("prompt", "") - prompt_or_messages = self._render_template(prompt_template, variable_pool) - - # 2. 获取模型配置 - model_id = self.config.get("model_id") + model_id = self.typed_config.model_id if not model_id: raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置") @@ -167,7 +134,61 @@ class LLMNode(BaseNode): logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}") - return llm, prompt_or_messages + messages_config = self.typed_config.messages + + if messages_config: + # 使用 LangChain 消息格式 + messages = [] + for msg_config in messages_config: + role = msg_config.role.lower() + content_template = msg_config.content + content_template = self._render_context(content_template, variable_pool) + content = self._render_template(content_template, variable_pool) + + # 根据角色创建对应的消息对象 + if role == "system": + messages.append({ + "role": "system", + "content": content + }) + elif role in ["user", "human"]: + messages.append({ + "role": "user", + "content": content + }) + elif role in ["ai", "assistant"]: + messages.append({ + "role": "assistant", + "content": content + }) + else: + logger.warning(f"未知的消息角色: {role},默认使用 user") + messages.append({ + "role": "user", + "content": content + }) + + if self.typed_config.vision_input and self.typed_config.vision: + file_content = [] + files = variable_pool.get_value(self.typed_config.vision_input) + for file in files: + content = await self.process_message(provider, file, self.typed_config.vision) + if content: + file_content.append(content) + if messages and messages[-1]["role"] == 'user': + messages[-1]['content'] = [messages[-1]["content"]] + file_content + else: + messages.append({"role": "user", "content": file_content}) + + if self.typed_config.memory.enable: + messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:] + self.messages = messages + else: + # 使用简单的 prompt 格式(向后兼容) + prompt_template = self.config.get("prompt", "") + self.messages = self._render_template(prompt_template, variable_pool) + + return llm async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage: """非流式执行 LLM 调用 @@ -180,15 +201,15 @@ class LLMNode(BaseNode): LLM 响应消息 """ # self.typed_config = LLMNodeConfig(**self.config) - llm, prompt_or_messages = self._prepare_llm(state, variable_pool, False) + llm = await self._prepare_llm(state, variable_pool, False) logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)") # 调用 LLM(支持字符串或消息列表) - response = await llm.ainvoke(prompt_or_messages) + response = await llm.ainvoke(self.messages) # 提取内容 if hasattr(response, 'content'): - content = response.content + content = self.process_model_output(response.content) else: content = str(response) @@ -199,14 +220,13 @@ class LLMNode(BaseNode): def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: """提取输入数据(用于记录)""" - _, prompt_or_messages = self._prepare_llm(state, variable_pool) return { - "prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None, + "prompt": self.messages if isinstance(self.messages, str) else None, "messages": [ {"role": msg.get("role"), "content": msg.get("content", "")} - for msg in prompt_or_messages - ] if isinstance(prompt_or_messages, list) else None, + for msg in self.messages + ] if isinstance(self.messages, list) else None, "config": { "model_id": self.config.get("model_id"), "temperature": self.config.get("temperature"), @@ -226,8 +246,8 @@ class LLMNode(BaseNode): usage = business_result.response_metadata.get('token_usage') if usage: return { - "prompt_tokens": usage.get('prompt_tokens', 0), - "completion_tokens": usage.get('completion_tokens', 0), + "prompt_tokens": usage.get('input_tokens', 0), + "completion_tokens": usage.get('output_tokens', 0), "total_tokens": usage.get('total_tokens', 0) } return None @@ -244,7 +264,7 @@ class LLMNode(BaseNode): """ self.typed_config = LLMNodeConfig(**self.config) - llm, prompt_or_messages = self._prepare_llm(state, variable_pool, True) + llm = await self._prepare_llm(state, variable_pool, True) logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}") @@ -255,10 +275,10 @@ class LLMNode(BaseNode): # 调用 LLM(流式,支持字符串或消息列表) last_meta_data = {} - async for chunk in llm.astream(prompt_or_messages, stream_usage=True): + async for chunk in llm.astream(self.messages, stream_usage=True): # 提取内容 if hasattr(chunk, 'content'): - content = chunk.content + content = self.process_model_output(chunk.content) else: content = str(chunk) if hasattr(chunk, 'response_metadata'): diff --git a/api/app/core/workflow/variable/base_variable.py b/api/app/core/workflow/variable/base_variable.py index d7c96fab..6a2e84d2 100644 --- a/api/app/core/workflow/variable/base_variable.py +++ b/api/app/core/workflow/variable/base_variable.py @@ -2,6 +2,10 @@ from enum import StrEnum from abc import abstractmethod, ABC from typing import Any +from pydantic import BaseModel + +from app.schemas import FileType + class VariableType(StrEnum): """Enumeration of supported variable types in the workflow.""" @@ -42,7 +46,7 @@ class VariableType(StrEnum): return cls.NUMBER elif isinstance(var_type, bool): return cls.BOOLEAN - elif isinstance(var_type, FileObj): + elif isinstance(var_type, FileObject) or (isinstance(var, dict) and var.get('__file')): return cls.FILE elif isinstance(var_type, dict): return cls.OBJECT @@ -103,8 +107,10 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any: raise TypeError(f"Invalid type - {type}") -class FileObj: - pass +class FileObject(BaseModel): + type: FileType + url: str + __file: bool class BaseVariable(ABC): diff --git a/api/app/core/workflow/variable/variable_objects.py b/api/app/core/workflow/variable/variable_objects.py index 35a30418..83fb06f3 100644 --- a/api/app/core/workflow/variable/variable_objects.py +++ b/api/app/core/workflow/variable/variable_objects.py @@ -1,66 +1,84 @@ from typing import Any, TypeVar, Type, Generic -from app.core.workflow.variable.base_variable import BaseVariable, VariableType +from deprecated import deprecated + +from app.core.workflow.variable.base_variable import BaseVariable, VariableType, FileObject T = TypeVar("T", bound=BaseVariable) -class StringObject(BaseVariable): +class StringVariable(BaseVariable): type = 'str' def valid_value(self, value) -> str: if not isinstance(value, str): - raise TypeError("Value must be a string") + raise TypeError(f"Value must be a string - {type(value)}:{value}") return value def to_literal(self) -> str: return self.value -class NumberObject(BaseVariable): +class NumberVariable(BaseVariable): type = 'number' def valid_value(self, value) -> int | float: if not isinstance(value, (int, float)): - raise TypeError("Value must be a number") + raise TypeError(f"Value must be a number - {type(value)}:{value}") return value def to_literal(self) -> str: return str(self.value) -class BooleanObject(BaseVariable): +class BooleanVariable(BaseVariable): type = 'boolean' def valid_value(self, value) -> bool: if not isinstance(value, bool): - raise TypeError("Value must be a boolean") + raise TypeError(f"Value must be a boolean - {type(value)}:{value}") return value def to_literal(self) -> str: return str(self.value).lower() -class DictObject(BaseVariable): +class DictVariable(BaseVariable): type = 'object' def valid_value(self, value) -> dict: if not isinstance(value, dict): - raise TypeError("Value must be a dict") + raise TypeError(f"Value must be a dict - {type(value)}:{value}") return value def to_literal(self) -> str: return str(self.value) -class FileObject(BaseVariable): +class FileVariable(BaseVariable): type = 'file' - def valid_value(self, value) -> Any: - pass + def valid_value(self, value) -> FileObject: + + if isinstance(value, dict): + if not value.get("__file"): + raise TypeError(f"Value must be a FileObject - {type(value)}:{value}") + return FileObject( + **{ + "type": str(value.get('type')), + "url": value.get('url'), + "__file": True + } + ) + if isinstance(value, FileObject): + return value + raise TypeError(f"Value must be a FileObject - {type(value)}:{value}") def to_literal(self) -> str: - pass + return str(self.value.model_dump()) + + def get_value(self) -> Any: + return self.value.model_dump() class ArrayObject(BaseVariable, Generic[T]): @@ -74,7 +92,7 @@ class ArrayObject(BaseVariable, Generic[T]): def valid_value(self, value: list[Any]) -> list[T]: if not isinstance(value, list): - raise TypeError("Value must be a list") + raise TypeError(f"Value must be a list - {type(value)}:{value}") final_value = [] for v in value: try: @@ -86,13 +104,16 @@ class ArrayObject(BaseVariable, Generic[T]): def to_literal(self) -> str: return "\n".join([v.to_literal() for v in self.value]) + def get_value(self) -> Any: + return [v.get_value() for v in self.value] + class NestedArrayObject(BaseVariable): type = 'array_nest' def valid_value(self, value: list[T]) -> list[T]: if not isinstance(value, list): - raise TypeError("Value must be a list") + raise TypeError(f"Value must be a list - {type(value)}:{value}") final_value = [] for v in value: if not isinstance(v, ArrayObject): @@ -107,6 +128,10 @@ class NestedArrayObject(BaseVariable): return [[item.get_value() for item in row] for row in self.value] +@deprecated( + reason="Using arbitrary-type values may cause unexpected errors; please switch to strongly-typed values.", + category=RuntimeWarning +) class AnyObject(BaseVariable): type = 'any' @@ -126,23 +151,23 @@ def make_array(child_type: Type[T], value: list[Any]) -> ArrayObject[T]: def create_variable_instance(var_type: VariableType, value: Any) -> T: match var_type: case VariableType.STRING: - return StringObject(value) + return StringVariable(value) case VariableType.NUMBER: - return NumberObject(value) + return NumberVariable(value) case VariableType.BOOLEAN: - return BooleanObject(value) + return BooleanVariable(value) case VariableType.OBJECT: - return DictObject(value) + return DictVariable(value) case VariableType.ARRAY_STRING: - return make_array(StringObject, value) + return make_array(StringVariable, value) case VariableType.ARRAY_NUMBER: - return make_array(NumberObject, value) + return make_array(NumberVariable, value) case VariableType.ARRAY_BOOLEAN: - return make_array(BooleanObject, value) + return make_array(BooleanVariable, value) case VariableType.ARRAY_OBJECT: - return make_array(DictObject, value) + return make_array(DictVariable, value) case VariableType.ARRAY_FILE: - return make_array(FileObject, value) + return make_array(FileVariable, value) case VariableType.ANY: return AnyObject(value) case _: diff --git a/api/app/core/workflow/variable_pool.py b/api/app/core/workflow/variable_pool.py index 32bfc5e1..96495ce8 100644 --- a/api/app/core/workflow/variable_pool.py +++ b/api/app/core/workflow/variable_pool.py @@ -78,8 +78,8 @@ class VariableStruct(BaseModel, Generic[T]): serialization, and workflow type checking. instance: The concrete variable object. The actual Python type is - represented by the generic parameter ``T`` (e.g. StringObject, - NumberObject, ArrayObject[StringObject]). + represented by the generic parameter ``T`` (e.g. StringVariable, + NumberVariable, ArrayObject[StringVariable]). mut: Whether the variable is mutable. """ @@ -286,7 +286,7 @@ class VariablePool: 系统变量字典 """ sys_namespace = self.variables.get("sys", {}) - return {k: v.instance.value for k, v in sys_namespace.items()} + return {k: v.instance.get_value() for k, v in sys_namespace.items()} def get_all_conversation_vars(self) -> dict[str, Any]: """获取所有会话变量 @@ -295,7 +295,7 @@ class VariablePool: 会话变量字典 """ conv_namespace = self.variables.get("conv", {}) - return {k: v.instance.value for k, v in conv_namespace.items()} + return {k: v.instance.get_value() for k, v in conv_namespace.items()} def get_all_node_outputs(self) -> dict[str, Any]: """获取所有节点输出(运行时变量) @@ -305,7 +305,7 @@ class VariablePool: """ runtime_vars = { namespace: { - k: v.instance.value + k: v.instance.get_value() for k, v in vars_dict.items() } for namespace, vars_dict in self.variables.items() @@ -326,7 +326,7 @@ class VariablePool: """ node_namespace = self.variables.get(node_id) if node_namespace: - return {k: v.instance.value for k, v in node_namespace.items()} + return {k: v.instance.get_value() for k, v in node_namespace.items()} if strict: raise KeyError(f"node {node_id} output not exist") else: diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 2ad27ace..02d897c5 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -1,14 +1,14 @@ import datetime import uuid from typing import Optional, Any, List, Dict, Union -from enum import Enum +from enum import Enum, StrEnum from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator # ---------- Multimodal File Support ---------- -class FileType(str, Enum): +class FileType(StrEnum): """文件类型枚举""" IMAGE = "image" DOCUMENT = "document" diff --git a/api/app/schemas/prompt_optimizer_schema.py b/api/app/schemas/prompt_optimizer_schema.py index 08a11317..96a46742 100644 --- a/api/app/schemas/prompt_optimizer_schema.py +++ b/api/app/schemas/prompt_optimizer_schema.py @@ -21,6 +21,11 @@ class PromptOptMessage(BaseModel): description="currently optimized prompt" ) + skill: bool = Field( + default=False, + description="Enable variable output" + ) + class PromptSaveRequest(BaseModel): session_id: UUID = Field( diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 34e9f865..61b41b6c 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -1089,7 +1089,7 @@ class DraftRunService: except Exception as e: # 对于多 Agent 应用,没有直接的 AgentConfig 是正常的 - logger.debug("获取配置快照失败(可能是多 Agent 应用)", extra={"error": str(e)}) + logger.debug("获取配置快照失败(可能是多 Agent 应用)", exc_info=True, extra={"error": str(e)}) return {} def _replace_variables( diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index a460a7ba..02636c27 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -23,7 +23,7 @@ logger = get_business_logger() class ImageFormatStrategy(Protocol): """图片格式策略接口""" - + async def format_image(self, url: str) -> Dict[str, Any]: """将图片 URL 转换为特定 provider 的格式""" ... @@ -31,7 +31,7 @@ class ImageFormatStrategy(Protocol): class DashScopeImageStrategy: """通义千问图片格式策略""" - + async def format_image(self, url: str) -> Dict[str, Any]: """通义千问格式: {"type": "image", "image": "url"}""" return { @@ -42,7 +42,7 @@ class DashScopeImageStrategy: class BedrockImageStrategy: """Bedrock/Anthropic 图片格式策略""" - + async def format_image(self, url: str) -> Dict[str, Any]: """ Bedrock/Anthropic 格式: base64 编码 @@ -51,17 +51,17 @@ class BedrockImageStrategy: import httpx import base64 from mimetypes import guess_type - + logger.info(f"下载并编码图片: {url}") - + # 下载图片 async with httpx.AsyncClient(timeout=30.0) as client: response = await client.get(url) response.raise_for_status() - + # 获取图片数据 image_data = response.content - + # 确定 media type content_type = response.headers.get("content-type") if content_type and content_type.startswith("image/"): @@ -69,12 +69,12 @@ class BedrockImageStrategy: else: guessed_type, _ = guess_type(url) media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg" - + # 转换为 base64 base64_data = base64.b64encode(image_data).decode("utf-8") - + logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}") - + return { "type": "image", "source": { @@ -87,7 +87,7 @@ class BedrockImageStrategy: class OpenAIImageStrategy: """OpenAI 图片格式策略""" - + async def format_image(self, url: str) -> Dict[str, Any]: """OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}""" return { @@ -109,7 +109,7 @@ PROVIDER_STRATEGIES = { class MultimodalService: """多模态文件处理服务""" - + def __init__(self, db: Session, provider: str = "dashscope"): """ 初始化多模态服务 @@ -120,10 +120,10 @@ class MultimodalService: """ self.db = db self.provider = provider.lower() - + async def process_files( - self, - files: Optional[List[FileInput]] + self, + files: Optional[List[FileInput]] ) -> List[Dict[str, Any]]: """ 处理文件列表,返回 LLM 可用的格式 @@ -136,7 +136,7 @@ class MultimodalService: """ if not files: return [] - + result = [] for idx, file in enumerate(files): try: @@ -168,10 +168,10 @@ class MultimodalService: "type": "text", "text": f"[文件处理失败: {str(e)}]" }) - + logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}") return result - + async def _process_image(self, file: FileInput) -> Dict[str, Any]: """ 处理图片文件 @@ -184,14 +184,10 @@ class MultimodalService: - Anthropic/Bedrock: {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}} - 通义千问: {"type": "image", "image": "url"} """ - if file.transfer_method == TransferMethod.REMOTE_URL: - url = file.url - else: - # 本地文件,获取访问 URL - url = await self._get_file_url(file.upload_file_id) - + url = await self.get_file_url(file) + logger.debug(f"处理图片: {url}, provider={self.provider}") - + # 根据 provider 返回不同格式 if self.provider in ["bedrock", "anthropic"]: # Anthropic/Bedrock 只支持 base64 格式,需要下载并转换 @@ -223,7 +219,7 @@ class MultimodalService: "type": "image", "image": url } - + async def _download_and_encode_image(self, url: str) -> tuple[str, str]: """ 下载图片并转换为 base64 @@ -237,15 +233,15 @@ class MultimodalService: import httpx import base64 from mimetypes import guess_type - + # 下载图片 async with httpx.AsyncClient(timeout=30.0) as client: response = await client.get(url) response.raise_for_status() - + # 获取图片数据 image_data = response.content - + # 确定 media type content_type = response.headers.get("content-type") if content_type and content_type.startswith("image/"): @@ -254,14 +250,14 @@ class MultimodalService: # 从 URL 推断 guessed_type, _ = guess_type(url) media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg" - + # 转换为 base64 base64_data = base64.b64encode(image_data).decode("utf-8") - + logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}") - + return base64_data, media_type - + async def _process_document(self, file: FileInput) -> Dict[str, Any]: """ 处理文档文件(PDF、Word 等) @@ -284,14 +280,14 @@ class MultimodalService: generic_file = self.db.query(GenericFile).filter( GenericFile.id == file.upload_file_id ).first() - + file_name = generic_file.file_name if generic_file else "unknown" - + return { "type": "text", "text": f"\n{text}\n" } - + async def _process_audio(self, file: FileInput) -> Dict[str, Any]: """ 处理音频文件 @@ -307,7 +303,7 @@ class MultimodalService: "type": "text", "text": "[音频文件,暂不支持处理]" } - + async def _process_video(self, file: FileInput) -> Dict[str, Any]: """ 处理视频文件 @@ -323,13 +319,13 @@ class MultimodalService: "type": "text", "text": "[视频文件,暂不支持处理]" } - - async def _get_file_url(self, file_id: uuid.UUID) -> str: + + async def get_file_url(self, file: FileInput) -> str: """ 获取文件的访问 URL Args: - file_id: 文件ID + file: File Input Struct Returns: str: 文件访问 URL @@ -337,26 +333,31 @@ class MultimodalService: Raises: BusinessException: 文件不存在 """ - generic_file = self.db.query(GenericFile).filter( - GenericFile.id == file_id, - GenericFile.status == "active" - ).first() - - if not generic_file: - raise BusinessException( - f"文件不存在或已删除: {file_id}", - BizCode.NOT_FOUND - ) - - # 如果有 access_url,直接返回 - if generic_file.access_url: - return generic_file.access_url - - # 否则,根据 storage_path 生成 URL - # TODO: 根据实际存储方式生成 URL(本地存储、OSS 等) - # 这里暂时返回一个占位 URL - return f"/api/files/{file_id}/download" - + if file.transfer_method == TransferMethod.REMOTE_URL: + return file.url + else: + # 本地文件,获取访问 URL + file_id = file.upload_file_id + generic_file = self.db.query(GenericFile).filter( + GenericFile.id == file.upload_file_id, + GenericFile.status == "active" + ).first() + + if not generic_file: + raise BusinessException( + f"文件不存在或已删除: {file.upload_file_id}", + BizCode.NOT_FOUND + ) + + # 如果有 access_url,直接返回 + if generic_file.access_url: + return generic_file.access_url + + # 否则,根据 storage_path 生成 URL + # TODO: 根据实际存储方式生成 URL(本地存储、OSS 等) + # 这里暂时返回一个占位 URL + return f"/api/files/{file_id}/download" + async def _extract_document_text(self, file_id: uuid.UUID) -> str: """ 提取文档文本内容 @@ -371,20 +372,20 @@ class MultimodalService: GenericFile.id == file_id, GenericFile.status == "active" ).first() - + if not generic_file: raise BusinessException( f"文件不存在或已删除: {file_id}", BizCode.NOT_FOUND ) - + # TODO: 根据文件类型提取文本 # - PDF: 使用 PyPDF2 或 pdfplumber # - Word: 使用 python-docx # - TXT/MD: 直接读取 - + file_ext = generic_file.file_ext.lower() - + if file_ext in ['.txt', '.md', '.markdown']: return await self._read_text_file(generic_file.storage_path) elif file_ext == '.pdf': @@ -393,7 +394,7 @@ class MultimodalService: return await self._extract_word_text(generic_file.storage_path) else: return f"[不支持的文档格式: {file_ext}]" - + async def _read_text_file(self, storage_path: str) -> str: """读取纯文本文件""" try: @@ -402,7 +403,7 @@ class MultimodalService: except Exception as e: logger.error(f"读取文本文件失败: {e}") return f"[文件读取失败: {str(e)}]" - + async def _extract_pdf_text(self, storage_path: str) -> str: """提取 PDF 文本""" try: @@ -412,7 +413,7 @@ class MultimodalService: except Exception as e: logger.error(f"提取 PDF 文本失败: {e}") return f"[PDF 提取失败: {str(e)}]" - + async def _extract_word_text(self, storage_path: str) -> str: """提取 Word 文档文本""" try: diff --git a/api/app/services/prompt/prompt_optimizer_system.jinja2 b/api/app/services/prompt/prompt_optimizer_system.jinja2 index b9060f68..39a4ba68 100644 --- a/api/app/services/prompt/prompt_optimizer_system.jinja2 +++ b/api/app/services/prompt/prompt_optimizer_system.jinja2 @@ -1,4 +1,3 @@ -{% raw %} Role: AI Prompt Optimization Expert Profile @@ -12,11 +11,11 @@ Skills Core Optimization Skills Requirement Analysis: Accurately understand the relationship between the user’s current needs and the original prompt. Structural Reconstruction: Transform vague requirements into clear, block-structured instructions. -Variable Handling: Identify and standardize dynamic variables in prompts. +{% if skill != true %}Variable Handling: Identify and standardize dynamic variables in prompts.{% endif %} Conflict Resolution: Prioritize current requirements when historical requirements conflict with current needs. Auxiliary Generation Skills -Completeness Check: Ensure all necessary elements (input, output, constraints, etc.) are explicitly defined. +{% if skill != true %}Completeness Check: Ensure all necessary elements (input, output, constraints, etc.) are explicitly defined.{% endif %} Language Consistency: Maintain consistency between label language and user input language. Executability Verification: Ensure optimized prompts can be directly used in AI tools. Format Standardization: Strictly adhere to specified output format requirements. @@ -25,30 +24,30 @@ Rules Basic Principles Priority Rule: When historical requirements conflict with current requirements, unconditionally prioritize current requirements. Completeness Rule: If the original prompt is empty, generate a complete prompt based on the current requirements. -Structure Rule: Use a clear block structure, and the contents of each block are roles, tasks, requirements, inputs, outputs, and constraints +{% if skill != true %}Structure Rule: Use a clear block structure, and the contents of each block are roles, tasks, requirements, inputs, outputs, and constraints{% endif %} Language Rule: All label languages must fully match the user input language. Behavior Guidelines Precision Guideline: All instructions must be precise and directly executable, avoiding ambiguity. Readability Guideline: Ensure optimized prompts have good readability and logical flow. -Variable Handling Guideline: Use lowercase English variable names wrapped in {{}} when variables are needed. -Constraint Handling Guideline: Do not mention variable-related limitations under the [Constraints] label. +{% if skill != true %}{% raw %}Variable Handling Guideline: Use lowercase English variable names wrapped in {{}} when variables are needed. +Constraint Handling Guideline: Do not mention variable-related limitations under the [Constraints] label.{% endraw %}{% endif %} Constraints Output Constraint: Must output in JSON format including the fields "prompt" and "desc". Content Constraint: Must not include any explanations, analyses, or additional comments. Language Constraint: Must use clear and concise language. -Completeness Constraint: Must fully define all missing elements (input details, output format, constraints, etc.). +{% if skill != true %}Completeness Constraint: Must fully define all missing elements (input details, output format, constraints, etc.).{% endif %} Workflows Goal: Optimize or generate AI prompts that can be directly used according to user requirements. Step 1: Receive the user’s current requirement description {{user_require}} and the original prompt {{original_prompt}}. Step 2: Analyze requirements, identify conflicts, and prioritize current requirements. -Step 3: Optimize or generate the prompt in a block-structured format, ensuring all elements are fully defined. +{% if skill != true %}Step 3: Optimize or generate the prompt in a block-structured format, ensuring all elements are fully defined. Step 4: Generate a JSON output containing the optimized prompt and its description. +{% else %}Step 3: Generate a JSON output containing the optimized prompt and its description.{% endif %} Expected Outcome: Obtain a clear, directly executable AI prompt accompanied by an optimization description. Initialization -As an AI Prompt Optimization Expert, you must follow the above Rules and execute tasks according to the Workflows. -{% endraw %} \ No newline at end of file +As an AI Prompt Optimization Expert, you must follow the above Rules and execute tasks according to the Workflows. \ No newline at end of file diff --git a/api/app/services/prompt_optimizer_service.py b/api/app/services/prompt_optimizer_service.py index 2c0b57ac..966ac6e0 100644 --- a/api/app/services/prompt_optimizer_service.py +++ b/api/app/services/prompt_optimizer_service.py @@ -128,7 +128,8 @@ class PromptOptimizerService: session_id: uuid.UUID, user_id: uuid.UUID, current_prompt: str, - user_require: str + user_require: str, + skill: bool = False ) -> AsyncGenerator[dict[str, str | Any], Any]: """ Optimize a user-provided prompt using a configured prompt optimizer LLM. @@ -157,6 +158,7 @@ class PromptOptimizerService: user_id (uuid.UUID): Identifier of the user associated with the session. current_prompt (str): Original prompt to optimize. user_require (str): User's requirements or instructions for optimization. + skill(bool): Is skill required Returns: OptimizePromptResult: An object containing: @@ -186,7 +188,7 @@ class PromptOptimizerService: prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt') with open(os.path.join(prompt_path, 'prompt_optimizer_system.jinja2'), 'r', encoding='utf-8') as f: opt_system_prompt = f.read() - rendered_system_message = Template(opt_system_prompt).render() + rendered_system_message = Template(opt_system_prompt).render(skill=skill) with open(os.path.join(prompt_path, 'prompt_optimizer_user.jinja2'), 'r', encoding='utf-8') as f: opt_user_prompt = f.read() diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index f19e2d41..508dd8ac 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -22,6 +22,7 @@ from app.repositories.workflow_repository import ( from app.schemas import DraftRunRequest from app.services.conversation_service import ConversationService from app.services.multi_agent_service import convert_uuids_to_str +from app.services.multimodal_service import MultimodalService logger = logging.getLogger(__name__) @@ -35,6 +36,7 @@ class WorkflowService: self.execution_repo = WorkflowExecutionRepository(db) self.node_execution_repo = WorkflowNodeExecutionRepository(db) self.conversation_service = ConversationService(db) + self.multimodal_service = MultimodalService(db) # ==================== 配置管理 ==================== @@ -444,8 +446,19 @@ class WorkflowService: code=BizCode.CONFIG_MISSING, message=f"工作流配置不存在: app_id={app_id}" ) + files = [] + if payload.files: + for file in payload.files: + files.append( + { + "type": file.type, + "url": await self.multimodal_service.get_file_url(file), + "__file": True + } + ) + input_data = {"message": payload.message, "variables": payload.variables, - "conversation_id": payload.conversation_id} + "conversation_id": payload.conversation_id, "files": files} # 转换 user_id 为 UUID triggered_by_uuid = None @@ -633,8 +646,20 @@ class WorkflowService: code=BizCode.CONFIG_MISSING, message=f"工作流配置不存在: app_id={app_id}" ) + + files = [] + if payload.files: + for file in payload.files: + files.append( + { + "type": file.type, + "url": await self.multimodal_service.get_file_url(file), + "__file": True + } + ) + input_data = {"message": payload.message, "variables": payload.variables, - "conversation_id": payload.conversation_id} + "conversation_id": payload.conversation_id, "files": files} # 转换 user_id 为 UUID triggered_by_uuid = None