Merge pull request #117 from SuanmoSuanyangTechnology/feature/workflow-memory-write
feat(workflow): support async memory writes via Celery
This commit is contained in:
@@ -3,8 +3,9 @@ from typing import Any
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||
from app.db import get_db_read, get_db_context
|
||||
from app.db import get_db_read
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.tasks import write_message_task
|
||||
|
||||
|
||||
class MemoryReadNode(BaseNode):
|
||||
@@ -15,11 +16,8 @@ class MemoryReadNode(BaseNode):
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
self.typed_config = MemoryReadNodeConfig(**self.config)
|
||||
with get_db_read() as db:
|
||||
workspace_id = self.get_variable('sys.workspace_id', state)
|
||||
end_user_id = self.get_variable("sys.user_id", state)
|
||||
|
||||
if not workspace_id:
|
||||
raise RuntimeError("Workspace id is required")
|
||||
if not end_user_id:
|
||||
raise RuntimeError("End user id is required")
|
||||
|
||||
@@ -41,20 +39,17 @@ class MemoryWriteNode(BaseNode):
|
||||
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
with get_db_context() as db:
|
||||
workspace_id = self.get_variable('sys.workspace_id', state)
|
||||
end_user_id = self.get_variable("sys.user_id", state)
|
||||
end_user_id = self.get_variable("sys.user_id", state)
|
||||
|
||||
if not workspace_id:
|
||||
raise RuntimeError("Workspace id is required")
|
||||
if not end_user_id:
|
||||
raise RuntimeError("End user id is required")
|
||||
if not end_user_id:
|
||||
raise RuntimeError("End user id is required")
|
||||
|
||||
return await MemoryAgentService().write_memory(
|
||||
group_id=end_user_id,
|
||||
message=self._render_template(self.typed_config.message, state),
|
||||
config_id=str(self.typed_config.config_id),
|
||||
db=db,
|
||||
storage_type="neo4j",
|
||||
user_rag_memory_id=""
|
||||
)
|
||||
write_message_task.delay(
|
||||
end_user_id,
|
||||
self._render_template(self.typed_config.message, state),
|
||||
str(self.typed_config.config_id),
|
||||
"neo4j",
|
||||
""
|
||||
)
|
||||
|
||||
return "success"
|
||||
|
||||
Reference in New Issue
Block a user