diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index f1c99ddb..f83ab7e1 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -3,8 +3,9 @@ from typing import Any from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig -from app.db import get_db_read, get_db_context +from app.db import get_db_read from app.services.memory_agent_service import MemoryAgentService +from app.tasks import write_message_task class MemoryReadNode(BaseNode): @@ -41,20 +42,20 @@ class MemoryWriteNode(BaseNode): self.typed_config = MemoryWriteNodeConfig(**self.config) async def execute(self, state: WorkflowState) -> Any: - with get_db_context() as db: - workspace_id = self.get_variable('sys.workspace_id', state) - end_user_id = self.get_variable("sys.user_id", state) + workspace_id = self.get_variable('sys.workspace_id', state) + end_user_id = self.get_variable("sys.user_id", state) - if not workspace_id: - raise RuntimeError("Workspace id is required") - if not end_user_id: - raise RuntimeError("End user id is required") + if not workspace_id: + raise RuntimeError("Workspace id is required") + if not end_user_id: + raise RuntimeError("End user id is required") - return await MemoryAgentService().write_memory( - group_id=end_user_id, - message=self._render_template(self.typed_config.message, state), - config_id=str(self.typed_config.config_id), - db=db, - storage_type="neo4j", - user_rag_memory_id="" - ) + write_message_task.delay( + end_user_id, + self._render_template(self.typed_config.message, state), + str(self.typed_config.config_id), + "neo4j", + "" + ) + + return "success"