feat(workflow): add memory read and write node (#24)
This commit is contained in:
@@ -75,6 +75,7 @@ def list_apps(
|
|||||||
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
||||||
return success(data=PageData(page=meta, items=items))
|
return success(data=PageData(page=meta, items=items))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{app_id}", summary="获取应用详情")
|
@router.get("/{app_id}", summary="获取应用详情")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
def get_app(
|
def get_app(
|
||||||
@@ -337,6 +338,7 @@ def list_app_shares(
|
|||||||
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
||||||
return success(data=data)
|
return success(data=data)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{app_id}/draft/run", summary="试运行 Agent(使用当前草稿配置)")
|
@router.post("/{app_id}/draft/run", summary="试运行 Agent(使用当前草稿配置)")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def draft_run(
|
async def draft_run(
|
||||||
@@ -374,7 +376,6 @@ async def draft_run(
|
|||||||
if knowledge:
|
if knowledge:
|
||||||
user_rag_memory_id = str(knowledge.id)
|
user_rag_memory_id = str(knowledge.id)
|
||||||
|
|
||||||
|
|
||||||
# 提前验证和准备(在流式响应开始前完成)
|
# 提前验证和准备(在流式响应开始前完成)
|
||||||
from app.services.app_service import AppService
|
from app.services.app_service import AppService
|
||||||
from app.services.multi_agent_service import MultiAgentService
|
from app.services.multi_agent_service import MultiAgentService
|
||||||
@@ -424,7 +425,6 @@ async def draft_run(
|
|||||||
if payload.stream:
|
if payload.stream:
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
|
|
||||||
|
|
||||||
async for event in draft_service.run_stream(
|
async for event in draft_service.run_stream(
|
||||||
agent_config=agent_cfg,
|
agent_config=agent_cfg,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
@@ -643,7 +643,6 @@ async def draft_run(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{app_id}/draft/run/compare", summary="多模型对比试运行")
|
@router.post("/{app_id}/draft/run/compare", summary="多模型对比试运行")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def draft_run_compare(
|
async def draft_run_compare(
|
||||||
@@ -822,6 +821,7 @@ async def get_workflow_config(
|
|||||||
# 配置总是存在(不存在时返回默认模板)
|
# 配置总是存在(不存在时返回默认模板)
|
||||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{app_id}/workflow", summary="更新 Workflow 配置")
|
@router.put("/{app_id}/workflow", summary="更新 Workflow 配置")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def update_workflow_config(
|
async def update_workflow_config(
|
||||||
@@ -833,4 +833,3 @@ async def update_workflow_config(
|
|||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from app.core.workflow.nodes.variable_aggregator.config import VariableAggregato
|
|||||||
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
|
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
|
||||||
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
|
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
|
||||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||||
|
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||||
|
|
||||||
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
|
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -45,6 +46,8 @@ __all__ = [
|
|||||||
"ParameterExtractorNodeConfig",
|
"ParameterExtractorNodeConfig",
|
||||||
"LoopNodeConfig",
|
"LoopNodeConfig",
|
||||||
"IterationNodeConfig",
|
"IterationNodeConfig",
|
||||||
"QuestionClassifierNodeConfig"
|
"QuestionClassifierNodeConfig",
|
||||||
"ToolNodeConfig"
|
"ToolNodeConfig",
|
||||||
|
"MemoryReadNodeConfig",
|
||||||
|
"MemoryWriteNodeConfig"
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ class NodeType(StrEnum):
|
|||||||
ITERATION = "iteration"
|
ITERATION = "iteration"
|
||||||
CYCLE_START = "cycle-start"
|
CYCLE_START = "cycle-start"
|
||||||
BREAK = "break"
|
BREAK = "break"
|
||||||
|
MEMORY_READ = "memory-read"
|
||||||
|
MEMORY_WRITE = "memory-write"
|
||||||
|
|
||||||
|
|
||||||
class ComparisonOperator(StrEnum):
|
class ComparisonOperator(StrEnum):
|
||||||
|
|||||||
4
api/app/core/workflow/nodes/memory/__init__.py
Normal file
4
api/app/core/workflow/nodes/memory/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||||
|
from app.core.workflow.nodes.memory.node import MemoryReadNode, MemoryWriteNode
|
||||||
|
|
||||||
|
__all__ = ["MemoryReadNodeConfig", "MemoryReadNode", "MemoryWriteNodeConfig", "MemoryWriteNode"]
|
||||||
31
api/app/core/workflow/nodes/memory/config.py
Normal file
31
api/app/core/workflow/nodes/memory/config.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import uuid
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryReadNodeConfig(BaseNodeConfig):
|
||||||
|
message: str = Field(
|
||||||
|
...
|
||||||
|
)
|
||||||
|
|
||||||
|
config_id: str = Field(
|
||||||
|
...
|
||||||
|
)
|
||||||
|
|
||||||
|
search_switch: str = Field(
|
||||||
|
"0",
|
||||||
|
description="Search mode: 0=verify, 1=direct, 2=context"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryWriteNodeConfig(BaseNodeConfig):
|
||||||
|
message: str = Field(
|
||||||
|
...
|
||||||
|
)
|
||||||
|
|
||||||
|
config_id: str = Field(
|
||||||
|
...
|
||||||
|
)
|
||||||
59
api/app/core/workflow/nodes/memory/node.py
Normal file
59
api/app/core/workflow/nodes/memory/node.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.workflow.nodes import WorkflowState
|
||||||
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
|
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||||
|
from app.db import get_db_read, get_db_context
|
||||||
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryReadNode(BaseNode):
|
||||||
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
|
super().__init__(node_config, workflow_config)
|
||||||
|
self.typed_config = MemoryReadNodeConfig(**self.config)
|
||||||
|
|
||||||
|
async def execute(self, state: WorkflowState) -> Any:
|
||||||
|
with get_db_read() as db:
|
||||||
|
workspace_id = self.get_variable('sys.workspace_id', state)
|
||||||
|
end_user_id = self.get_variable("sys.user_id", state)
|
||||||
|
|
||||||
|
if not workspace_id:
|
||||||
|
raise RuntimeError("Workspace id is required")
|
||||||
|
if not end_user_id:
|
||||||
|
raise RuntimeError("End user id is required")
|
||||||
|
|
||||||
|
return await MemoryAgentService().read_memory(
|
||||||
|
group_id=end_user_id,
|
||||||
|
message=self.typed_config.message,
|
||||||
|
config_id=self.typed_config.config_id,
|
||||||
|
search_switch=self.typed_config.search_switch,
|
||||||
|
history=[],
|
||||||
|
db=db,
|
||||||
|
storage_type="neo4j",
|
||||||
|
user_rag_memory_id=""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryWriteNode(BaseNode):
|
||||||
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
|
super().__init__(node_config, workflow_config)
|
||||||
|
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
||||||
|
|
||||||
|
async def execute(self, state: WorkflowState) -> Any:
|
||||||
|
with get_db_context() as db:
|
||||||
|
workspace_id = self.get_variable('sys.workspace_id', state)
|
||||||
|
end_user_id = self.get_variable("sys.user_id", state)
|
||||||
|
|
||||||
|
if not workspace_id:
|
||||||
|
raise RuntimeError("Workspace id is required")
|
||||||
|
if not end_user_id:
|
||||||
|
raise RuntimeError("End user id is required")
|
||||||
|
|
||||||
|
return await MemoryAgentService().write_memory(
|
||||||
|
group_id=end_user_id,
|
||||||
|
message=self.typed_config.message,
|
||||||
|
config_id=self.typed_config.config_id,
|
||||||
|
db=db,
|
||||||
|
storage_type="neo4j",
|
||||||
|
user_rag_memory_id=""
|
||||||
|
)
|
||||||
@@ -18,6 +18,7 @@ from app.core.workflow.nodes.if_else import IfElseNode
|
|||||||
from app.core.workflow.nodes.jinja_render import JinjaRenderNode
|
from app.core.workflow.nodes.jinja_render import JinjaRenderNode
|
||||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||||
from app.core.workflow.nodes.llm import LLMNode
|
from app.core.workflow.nodes.llm import LLMNode
|
||||||
|
from app.core.workflow.nodes.memory import MemoryReadNode, MemoryWriteNode
|
||||||
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||||
from app.core.workflow.nodes.start import StartNode
|
from app.core.workflow.nodes.start import StartNode
|
||||||
from app.core.workflow.nodes.transform import TransformNode
|
from app.core.workflow.nodes.transform import TransformNode
|
||||||
@@ -46,7 +47,9 @@ WorkflowNode = Union[
|
|||||||
BreakNode,
|
BreakNode,
|
||||||
ParameterExtractorNode,
|
ParameterExtractorNode,
|
||||||
QuestionClassifierNode,
|
QuestionClassifierNode,
|
||||||
ToolNode
|
ToolNode,
|
||||||
|
MemoryReadNode,
|
||||||
|
MemoryWriteNode
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -76,6 +79,8 @@ class NodeFactory:
|
|||||||
NodeType.BREAK: BreakNode,
|
NodeType.BREAK: BreakNode,
|
||||||
NodeType.CYCLE_START: StartNode,
|
NodeType.CYCLE_START: StartNode,
|
||||||
NodeType.TOOL: ToolNode,
|
NodeType.TOOL: ToolNode,
|
||||||
|
NodeType.MEMORY_READ: MemoryReadNode,
|
||||||
|
NodeType.MEMORY_WRITE: MemoryWriteNode,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -14,8 +14,9 @@ from sqlalchemy.orm import Session
|
|||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.workflow.validator import validate_workflow_config
|
from app.core.workflow.validator import validate_workflow_config
|
||||||
from app.db import get_db
|
from app.db import get_db, get_db_context
|
||||||
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
||||||
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
from app.repositories.workflow_repository import (
|
from app.repositories.workflow_repository import (
|
||||||
WorkflowConfigRepository,
|
WorkflowConfigRepository,
|
||||||
WorkflowExecutionRepository,
|
WorkflowExecutionRepository,
|
||||||
@@ -480,13 +481,21 @@ class WorkflowService:
|
|||||||
try:
|
try:
|
||||||
# 更新状态为运行中
|
# 更新状态为运行中
|
||||||
self.update_execution_status(execution.execution_id, "running")
|
self.update_execution_status(execution.execution_id, "running")
|
||||||
|
with get_db_context() as db:
|
||||||
|
end_user_repo = EndUserRepository(db)
|
||||||
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
|
app_id=app_id,
|
||||||
|
other_id=payload.user_id,
|
||||||
|
original_user_id=payload.user_id # Save original user_id to other_id
|
||||||
|
)
|
||||||
|
end_user_id = str(new_end_user.id)
|
||||||
|
|
||||||
result = await execute_workflow(
|
result = await execute_workflow(
|
||||||
workflow_config=workflow_config_dict,
|
workflow_config=workflow_config_dict,
|
||||||
input_data=input_data,
|
input_data=input_data,
|
||||||
execution_id=execution.execution_id,
|
execution_id=execution.execution_id,
|
||||||
workspace_id=str(workspace_id),
|
workspace_id=str(workspace_id),
|
||||||
user_id=payload.user_id
|
user_id=end_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# 更新执行结果
|
# 更新执行结果
|
||||||
@@ -599,6 +608,14 @@ class WorkflowService:
|
|||||||
try:
|
try:
|
||||||
# 更新状态为运行中
|
# 更新状态为运行中
|
||||||
self.update_execution_status(execution.execution_id, "running")
|
self.update_execution_status(execution.execution_id, "running")
|
||||||
|
with get_db_context() as db:
|
||||||
|
end_user_repo = EndUserRepository(db)
|
||||||
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
|
app_id=app_id,
|
||||||
|
other_id=payload.user_id,
|
||||||
|
original_user_id=payload.user_id # Save original user_id to other_id
|
||||||
|
)
|
||||||
|
end_user_id = str(new_end_user.id)
|
||||||
|
|
||||||
# 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件)
|
# 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件)
|
||||||
async for event in self._run_workflow_stream(
|
async for event in self._run_workflow_stream(
|
||||||
@@ -606,7 +623,7 @@ class WorkflowService:
|
|||||||
input_data=input_data,
|
input_data=input_data,
|
||||||
execution_id=execution.execution_id,
|
execution_id=execution.execution_id,
|
||||||
workspace_id=str(workspace_id),
|
workspace_id=str(workspace_id),
|
||||||
user_id=payload.user_id
|
user_id=end_user_id
|
||||||
):
|
):
|
||||||
# 直接转发 executor 的事件(已经是正确的格式)
|
# 直接转发 executor 的事件(已经是正确的格式)
|
||||||
yield event
|
yield event
|
||||||
|
|||||||
Reference in New Issue
Block a user