From 31085ed678ef29d77ba9a7feaa59338a7201d195 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Fri, 20 Mar 2026 18:31:17 +0800 Subject: [PATCH] fix(workflow): fix memory write behavior in RAG workspace --- .../core/workflow/engine/runtime_schema.py | 14 ++++- api/app/core/workflow/engine/state_manager.py | 9 ++- api/app/core/workflow/engine/variable_pool.py | 10 ++- api/app/core/workflow/executor.py | 20 ++++-- api/app/core/workflow/nodes/memory/node.py | 61 ++++++++++++++++--- api/app/services/workflow_service.py | 29 ++++++++- api/app/services/workspace_service.py | 2 +- 7 files changed, 128 insertions(+), 17 deletions(-) diff --git a/api/app/core/workflow/engine/runtime_schema.py b/api/app/core/workflow/engine/runtime_schema.py index e4bf65af..48eafaa9 100644 --- a/api/app/core/workflow/engine/runtime_schema.py +++ b/api/app/core/workflow/engine/runtime_schema.py @@ -12,14 +12,26 @@ class ExecutionContext(BaseModel): execution_id: str workspace_id: str user_id: str + memory_storage_type: str + user_rag_memory_id: str checkpoint_config: RunnableConfig @classmethod - def create(cls, execution_id: str, workspace_id: str, user_id: str): + def create( + cls, + execution_id: str, + workspace_id: str, + user_id: str, + memory_storage_type: str, + user_rag_memory_id: str + ): return cls( execution_id=execution_id, workspace_id=workspace_id, user_id=user_id, + memory_storage_type=memory_storage_type, + user_rag_memory_id=user_rag_memory_id, + checkpoint_config=RunnableConfig( configurable={ "thread_id": uuid.uuid4(), diff --git a/api/app/core/workflow/engine/state_manager.py b/api/app/core/workflow/engine/state_manager.py index 0a4a1463..2da0d3a8 100644 --- a/api/app/core/workflow/engine/state_manager.py +++ b/api/app/core/workflow/engine/state_manager.py @@ -33,6 +33,8 @@ class WorkflowState(dict): "workspace_id", "user_id", "activate", + "memory_storage_type", + "user_rag_memory_id" }) __optional_keys__ = frozenset({ "error", @@ -62,6 +64,9 @@ class 
WorkflowState(dict): # node activate status activate: Annotated[dict[str, bool], merge_activate_state] + memory_storage_type: str + user_rag_memory_id: str + class WorkflowStateManager: def create_initial_state( @@ -85,7 +90,9 @@ class WorkflowStateManager: looping=0, activate={ start_node_id: True - } + }, + memory_storage_type=execution_context.memory_storage_type, + user_rag_memory_id=execution_context.user_rag_memory_id ) @staticmethod diff --git a/api/app/core/workflow/engine/variable_pool.py b/api/app/core/workflow/engine/variable_pool.py index cf6f4a7b..d4e1b488 100644 --- a/api/app/core/workflow/engine/variable_pool.py +++ b/api/app/core/workflow/engine/variable_pool.py @@ -13,7 +13,7 @@ from pydantic import BaseModel from app.core.workflow.engine.runtime_schema import ExecutionContext from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE -from app.core.workflow.variable.variable_objects import T, create_variable_instance +from app.core.workflow.variable.variable_objects import T, create_variable_instance, ArrayVariable, FileVariable logger = logging.getLogger(__name__) @@ -373,6 +373,14 @@ class VariablePool: def copy(self, pool: 'VariablePool'): self.variables = deepcopy(pool.variables) + def is_file_variable(self, selector): + variable_struct = self._get_variable_struct(selector) + if isinstance(variable_struct, FileVariable): + return True + elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable: + return True + return False + def to_dict(self) -> dict[str, Any]: """导出为字典 diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index c9ed6e65..6a127e96 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -409,7 +409,9 @@ async def execute_workflow( input_data: dict[str, Any], execution_id: str, workspace_id: str, - user_id: str + user_id: str, + memory_storage_type: str, + user_rag_memory_id: str ) -> dict[str, Any]: """ 
Execute a workflow (convenience function, non-streaming). @@ -420,6 +422,8 @@ async def execute_workflow( execution_id (str): Execution ID. workspace_id (str): Workspace ID. user_id (str): User ID. + user_rag_memory_id (str): ID of the RAG knowledge base used for memory. + memory_storage_type (str): Memory storage backend, "neo4j" or "rag". Returns: dict: Workflow execution result. @@ -427,7 +431,9 @@ async def execute_workflow( execution_context = ExecutionContext.create( execution_id=execution_id, workspace_id=workspace_id, - user_id=user_id + user_id=user_id, + memory_storage_type=memory_storage_type, + user_rag_memory_id=user_rag_memory_id ) executor = WorkflowExecutor( workflow_config=workflow_config, @@ -441,7 +447,9 @@ async def execute_workflow_stream( input_data: dict[str, Any], execution_id: str, workspace_id: str, - user_id: str + user_id: str, + memory_storage_type: str, + user_rag_memory_id: str ): """ Execute a workflow in streaming mode (convenience function). @@ -452,6 +460,8 @@ async def execute_workflow_stream( execution_id (str): Execution ID. workspace_id (str): Workspace ID. user_id (str): User ID. + user_rag_memory_id (str): ID of the RAG knowledge base used for memory. + memory_storage_type (str): Memory storage backend, "neo4j" or "rag". Yields: dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end.
@@ -459,7 +469,9 @@ async def execute_workflow_stream( execution_context = ExecutionContext.create( execution_id=execution_id, workspace_id=workspace_id, - user_id=user_id + user_id=user_id, + memory_storage_type=memory_storage_type, + user_rag_memory_id=user_rag_memory_id ) executor = WorkflowExecutor( workflow_config=workflow_config, diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index 1d42e82e..82363056 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -1,3 +1,4 @@ +import re from typing import Any from app.core.workflow.engine.state_manager import WorkflowState @@ -5,7 +6,9 @@ from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable from app.db import get_db_read +from app.schemas import FileInput from app.services.memory_agent_service import MemoryAgentService from app.tasks import write_message_task @@ -36,8 +39,8 @@ class MemoryReadNode(BaseNode): search_switch=self.typed_config.search_switch, history=[], db=db, - storage_type="neo4j", - user_rag_memory_id="" + storage_type=state["memory_storage_type"], + user_rag_memory_id=state["user_rag_memory_id"] ) @@ -49,6 +52,19 @@ class MemoryWriteNode(BaseNode): def _output_types(self) -> dict[str, VariableType]: return {"output": VariableType.STRING} + @staticmethod + def _extract_multimodal_memory_variables(content: str, variable_pool: VariablePool) -> tuple[list[str], str]: + variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}' + variable_pattern = re.compile(variable_pattern_string) + variables = variable_pattern.findall(content) + file_variables = [] + for variable in variables: + 
if variable_pool.is_file_variable(variable): + file_variables.append(variable) + for var in file_variables: + content = content.replace(var, "") + return file_variables, content + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: self.typed_config = MemoryWriteNodeConfig(**self.config) end_user_id = self.get_variable("sys.user_id", variable_pool) @@ -56,6 +72,7 @@ class MemoryWriteNode(BaseNode): if not end_user_id: raise RuntimeError("End user id is required") messages = [] + multimodal_memories = [] if self.typed_config.message: messages.append({ "role": "user", @@ -63,17 +80,45 @@ class MemoryWriteNode(BaseNode): }) for message in self.typed_config.messages: + file_variables, content = self._extract_multimodal_memory_variables( + message.content, + variable_pool + ) + file_info = [] + for var in file_variables: + instence: FileVariable | ArrayVariable[FileVariable] = variable_pool.get_instance(var) + if isinstance(instence, FileVariable): + file_info.append(FileInput( + type=instence.value.type, + transfer_method=instence.value.transfer_method, + upload_file_id=instence.value.file_id, + url=instence.value.url, + file_type=instence.value.origin_file_type + ).model_dump()) + elif isinstance(instence, ArrayVariable) and instence.child_type == FileVariable: + for file_instence in instence.value: + file_info.append(FileInput( + type=file_instence.value.type, + transfer_method=file_instence.value.transfer_method, + upload_file_id=file_instence.value.file_id, + url=file_instence.value.url, + file_type=file_instence.value.origin_file_type + ).model_dump()) + multimodal_memories.append({ + "role": message.role, + "files": file_info + }) messages.append({ "role": message.role, - "content": self._render_template(message.content, variable_pool) + "content": self._render_template(content, variable_pool) }) write_message_task.delay( - end_user_id, - messages, - str(self.typed_config.config_id), - "neo4j", - "" + end_user_id=end_user_id, + 
message=messages, + config_id=str(self.typed_config.config_id), + storage_type=state["memory_storage_type"], + user_rag_memory_id=state["user_rag_memory_id"] ) return "success" diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 56f34496..db659268 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -20,6 +20,7 @@ from app.core.workflow.variable.base_variable import FileObject from app.db import get_db from app.models import App from app.models.workflow_model import WorkflowConfig, WorkflowExecution +from app.repositories import knowledge_repository from app.repositories.workflow_repository import ( WorkflowConfigRepository, WorkflowExecutionRepository, @@ -29,6 +30,7 @@ from app.schemas import DraftRunRequest, FileInput from app.services.conversation_service import ConversationService from app.services.multi_agent_service import convert_uuids_to_str from app.services.multimodal_service import MultimodalService +from app.services.workspace_service import get_workspace_storage_type_without_auth logger = logging.getLogger(__name__) @@ -536,6 +538,25 @@ class WorkflowService: mapped = internal_event return mapped + def _get_memory_store_info(self, workspace_id: uuid.UUID) -> tuple[str, str]: + storage_type = get_workspace_storage_type_without_auth(self.db, workspace_id) + user_rag_memory_id = "" + if storage_type == "rag": + knowledge = knowledge_repository.get_knowledge_by_name( + db=self.db, + name="USER_RAG_MERORY", + workspace_id=workspace_id + ) + if knowledge: + user_rag_memory_id = str(knowledge.id) + else: + logger.warning( + f"No knowledge base named 'USER_RAG_MERORY' found, " + f"workspace_id: {workspace_id}, will use neo4j storage" + ) + storage_type = 'neo4j' + return storage_type, user_rag_memory_id + # ==================== 工作流执行 ==================== async def run( @@ -603,6 +624,7 @@ class WorkflowService: try: files = await self._handle_file_input(payload.files) +
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id) input_data["files"] = files message_id = uuid.uuid4() # 更新状态为运行中 @@ -627,7 +649,9 @@ class WorkflowService: input_data=input_data, execution_id=execution.execution_id, workspace_id=str(workspace_id), - user_id=payload.user_id + user_id=payload.user_id, + memory_storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id ) # 更新执行结果 if result.get("status") == "completed": @@ -776,6 +800,7 @@ class WorkflowService: try: files = await self._handle_file_input(payload.files) + storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id) input_data["files"] = files self.update_execution_status(execution.execution_id, "running") executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) @@ -797,6 +822,8 @@ class WorkflowService: execution_id=execution.execution_id, workspace_id=str(workspace_id), user_id=payload.user_id, + memory_storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id ): if event.get("event") == "workflow_end": status = event.get("data", {}).get("status") diff --git a/api/app/services/workspace_service.py b/api/app/services/workspace_service.py index cefb8380..90b5cf65 100644 --- a/api/app/services/workspace_service.py +++ b/api/app/services/workspace_service.py @@ -863,7 +863,7 @@ def get_workspace_storage_type( def get_workspace_storage_type_without_auth( db: Session, workspace_id: uuid.UUID, -) -> Optional[str]: +) -> str: """获取工作空间的存储类型(无需权限验证,用于公开分享等场景) Args: