feat(workflow): add memory read and write node (#24)
This commit is contained in:
@@ -22,6 +22,7 @@ from app.core.workflow.nodes.variable_aggregator.config import VariableAggregato
|
||||
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
|
||||
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
|
||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||
|
||||
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
|
||||
__all__ = [
|
||||
@@ -45,6 +46,8 @@ __all__ = [
|
||||
"ParameterExtractorNodeConfig",
|
||||
"LoopNodeConfig",
|
||||
"IterationNodeConfig",
|
||||
"QuestionClassifierNodeConfig"
|
||||
"ToolNodeConfig"
|
||||
"QuestionClassifierNodeConfig",
|
||||
"ToolNodeConfig",
|
||||
"MemoryReadNodeConfig",
|
||||
"MemoryWriteNodeConfig"
|
||||
]
|
||||
|
||||
@@ -22,6 +22,8 @@ class NodeType(StrEnum):
|
||||
ITERATION = "iteration"
|
||||
CYCLE_START = "cycle-start"
|
||||
BREAK = "break"
|
||||
MEMORY_READ = "memory-read"
|
||||
MEMORY_WRITE = "memory-write"
|
||||
|
||||
|
||||
class ComparisonOperator(StrEnum):
|
||||
|
||||
4
api/app/core/workflow/nodes/memory/__init__.py
Normal file
4
api/app/core/workflow/nodes/memory/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||
from app.core.workflow.nodes.memory.node import MemoryReadNode, MemoryWriteNode
|
||||
|
||||
__all__ = ["MemoryReadNodeConfig", "MemoryReadNode", "MemoryWriteNodeConfig", "MemoryWriteNode"]
|
||||
31
api/app/core/workflow/nodes/memory/config.py
Normal file
31
api/app/core/workflow/nodes/memory/config.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import uuid
|
||||
|
||||
from pydantic import Field
|
||||
from typing import Literal
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
|
||||
|
||||
class MemoryReadNodeConfig(BaseNodeConfig):
|
||||
message: str = Field(
|
||||
...
|
||||
)
|
||||
|
||||
config_id: str = Field(
|
||||
...
|
||||
)
|
||||
|
||||
search_switch: str = Field(
|
||||
"0",
|
||||
description="Search mode: 0=verify, 1=direct, 2=context"
|
||||
)
|
||||
|
||||
|
||||
class MemoryWriteNodeConfig(BaseNodeConfig):
|
||||
message: str = Field(
|
||||
...
|
||||
)
|
||||
|
||||
config_id: str = Field(
|
||||
...
|
||||
)
|
||||
59
api/app/core/workflow/nodes/memory/node.py
Normal file
59
api/app/core/workflow/nodes/memory/node.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||
from app.db import get_db_read, get_db_context
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
|
||||
|
||||
class MemoryReadNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = MemoryReadNodeConfig(**self.config)
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
with get_db_read() as db:
|
||||
workspace_id = self.get_variable('sys.workspace_id', state)
|
||||
end_user_id = self.get_variable("sys.user_id", state)
|
||||
|
||||
if not workspace_id:
|
||||
raise RuntimeError("Workspace id is required")
|
||||
if not end_user_id:
|
||||
raise RuntimeError("End user id is required")
|
||||
|
||||
return await MemoryAgentService().read_memory(
|
||||
group_id=end_user_id,
|
||||
message=self.typed_config.message,
|
||||
config_id=self.typed_config.config_id,
|
||||
search_switch=self.typed_config.search_switch,
|
||||
history=[],
|
||||
db=db,
|
||||
storage_type="neo4j",
|
||||
user_rag_memory_id=""
|
||||
)
|
||||
|
||||
|
||||
class MemoryWriteNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
with get_db_context() as db:
|
||||
workspace_id = self.get_variable('sys.workspace_id', state)
|
||||
end_user_id = self.get_variable("sys.user_id", state)
|
||||
|
||||
if not workspace_id:
|
||||
raise RuntimeError("Workspace id is required")
|
||||
if not end_user_id:
|
||||
raise RuntimeError("End user id is required")
|
||||
|
||||
return await MemoryAgentService().write_memory(
|
||||
group_id=end_user_id,
|
||||
message=self.typed_config.message,
|
||||
config_id=self.typed_config.config_id,
|
||||
db=db,
|
||||
storage_type="neo4j",
|
||||
user_rag_memory_id=""
|
||||
)
|
||||
@@ -18,6 +18,7 @@ from app.core.workflow.nodes.if_else import IfElseNode
|
||||
from app.core.workflow.nodes.jinja_render import JinjaRenderNode
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.memory import MemoryReadNode, MemoryWriteNode
|
||||
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
@@ -46,7 +47,9 @@ WorkflowNode = Union[
|
||||
BreakNode,
|
||||
ParameterExtractorNode,
|
||||
QuestionClassifierNode,
|
||||
ToolNode
|
||||
ToolNode,
|
||||
MemoryReadNode,
|
||||
MemoryWriteNode
|
||||
]
|
||||
|
||||
|
||||
@@ -76,6 +79,8 @@ class NodeFactory:
|
||||
NodeType.BREAK: BreakNode,
|
||||
NodeType.CYCLE_START: StartNode,
|
||||
NodeType.TOOL: ToolNode,
|
||||
NodeType.MEMORY_READ: MemoryReadNode,
|
||||
NodeType.MEMORY_WRITE: MemoryWriteNode,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user