feat(workflow): add memory read and write node (#24)

This commit is contained in:
Eternity
2026-01-05 15:57:04 +08:00
committed by GitHub
parent ab0e465760
commit 78207aca34
8 changed files with 240 additions and 120 deletions

View File

@@ -22,6 +22,7 @@ from app.core.workflow.nodes.variable_aggregator.config import VariableAggregato
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
__all__ = [
@@ -45,6 +46,8 @@ __all__ = [
"ParameterExtractorNodeConfig",
"LoopNodeConfig",
"IterationNodeConfig",
"QuestionClassifierNodeConfig"
"ToolNodeConfig"
"QuestionClassifierNodeConfig",
"ToolNodeConfig",
"MemoryReadNodeConfig",
"MemoryWriteNodeConfig"
]

View File

@@ -22,6 +22,8 @@ class NodeType(StrEnum):
ITERATION = "iteration"
CYCLE_START = "cycle-start"
BREAK = "break"
MEMORY_READ = "memory-read"
MEMORY_WRITE = "memory-write"
class ComparisonOperator(StrEnum):

View File

@@ -0,0 +1,4 @@
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
from app.core.workflow.nodes.memory.node import MemoryReadNode, MemoryWriteNode
__all__ = ["MemoryReadNodeConfig", "MemoryReadNode", "MemoryWriteNodeConfig", "MemoryWriteNode"]

View File

@@ -0,0 +1,31 @@
import uuid
from pydantic import Field
from typing import Literal
from app.core.workflow.nodes.base_config import BaseNodeConfig
class MemoryReadNodeConfig(BaseNodeConfig):
message: str = Field(
...
)
config_id: str = Field(
...
)
search_switch: str = Field(
"0",
description="Search mode: 0=verify, 1=direct, 2=context"
)
class MemoryWriteNodeConfig(BaseNodeConfig):
message: str = Field(
...
)
config_id: str = Field(
...
)

View File

@@ -0,0 +1,59 @@
from typing import Any
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
from app.db import get_db_read, get_db_context
from app.services.memory_agent_service import MemoryAgentService
class MemoryReadNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = MemoryReadNodeConfig(**self.config)
async def execute(self, state: WorkflowState) -> Any:
with get_db_read() as db:
workspace_id = self.get_variable('sys.workspace_id', state)
end_user_id = self.get_variable("sys.user_id", state)
if not workspace_id:
raise RuntimeError("Workspace id is required")
if not end_user_id:
raise RuntimeError("End user id is required")
return await MemoryAgentService().read_memory(
group_id=end_user_id,
message=self.typed_config.message,
config_id=self.typed_config.config_id,
search_switch=self.typed_config.search_switch,
history=[],
db=db,
storage_type="neo4j",
user_rag_memory_id=""
)
class MemoryWriteNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = MemoryWriteNodeConfig(**self.config)
async def execute(self, state: WorkflowState) -> Any:
with get_db_context() as db:
workspace_id = self.get_variable('sys.workspace_id', state)
end_user_id = self.get_variable("sys.user_id", state)
if not workspace_id:
raise RuntimeError("Workspace id is required")
if not end_user_id:
raise RuntimeError("End user id is required")
return await MemoryAgentService().write_memory(
group_id=end_user_id,
message=self.typed_config.message,
config_id=self.typed_config.config_id,
db=db,
storage_type="neo4j",
user_rag_memory_id=""
)

View File

@@ -18,6 +18,7 @@ from app.core.workflow.nodes.if_else import IfElseNode
from app.core.workflow.nodes.jinja_render import JinjaRenderNode
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.memory import MemoryReadNode, MemoryWriteNode
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode
@@ -46,7 +47,9 @@ WorkflowNode = Union[
BreakNode,
ParameterExtractorNode,
QuestionClassifierNode,
ToolNode
ToolNode,
MemoryReadNode,
MemoryWriteNode
]
@@ -76,6 +79,8 @@ class NodeFactory:
NodeType.BREAK: BreakNode,
NodeType.CYCLE_START: StartNode,
NodeType.TOOL: ToolNode,
NodeType.MEMORY_READ: MemoryReadNode,
NodeType.MEMORY_WRITE: MemoryWriteNode,
}
@classmethod