Merge pull request #117 from SuanmoSuanyangTechnology/feature/workflow-memory-write

feat(workflow): support async memory writes via Celery
This commit is contained in:
Mark
2026-01-14 18:23:29 +08:00
committed by GitHub

View File

@@ -3,8 +3,9 @@ from typing import Any
from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
from app.db import get_db_read, get_db_context from app.db import get_db_read
from app.services.memory_agent_service import MemoryAgentService from app.services.memory_agent_service import MemoryAgentService
from app.tasks import write_message_task
class MemoryReadNode(BaseNode): class MemoryReadNode(BaseNode):
@@ -15,11 +16,8 @@ class MemoryReadNode(BaseNode):
async def execute(self, state: WorkflowState) -> Any: async def execute(self, state: WorkflowState) -> Any:
self.typed_config = MemoryReadNodeConfig(**self.config) self.typed_config = MemoryReadNodeConfig(**self.config)
with get_db_read() as db: with get_db_read() as db:
workspace_id = self.get_variable('sys.workspace_id', state)
end_user_id = self.get_variable("sys.user_id", state) end_user_id = self.get_variable("sys.user_id", state)
if not workspace_id:
raise RuntimeError("Workspace id is required")
if not end_user_id: if not end_user_id:
raise RuntimeError("End user id is required") raise RuntimeError("End user id is required")
@@ -41,20 +39,17 @@ class MemoryWriteNode(BaseNode):
self.typed_config = MemoryWriteNodeConfig(**self.config) self.typed_config = MemoryWriteNodeConfig(**self.config)
async def execute(self, state: WorkflowState) -> Any: async def execute(self, state: WorkflowState) -> Any:
with get_db_context() as db: end_user_id = self.get_variable("sys.user_id", state)
workspace_id = self.get_variable('sys.workspace_id', state)
end_user_id = self.get_variable("sys.user_id", state)
if not workspace_id: if not end_user_id:
raise RuntimeError("Workspace id is required") raise RuntimeError("End user id is required")
if not end_user_id:
raise RuntimeError("End user id is required")
return await MemoryAgentService().write_memory( write_message_task.delay(
group_id=end_user_id, end_user_id,
message=self._render_template(self.typed_config.message, state), self._render_template(self.typed_config.message, state),
config_id=str(self.typed_config.config_id), str(self.typed_config.config_id),
db=db, "neo4j",
storage_type="neo4j", ""
user_rag_memory_id="" )
)
return "success"