fix(workflow): fix memory write behavior in RAG workspace

This commit is contained in:
Eternity
2026-03-20 18:31:17 +08:00
parent dce7206c44
commit 31085ed678
7 changed files with 128 additions and 17 deletions

View File

@@ -12,14 +12,26 @@ class ExecutionContext(BaseModel):
execution_id: str execution_id: str
workspace_id: str workspace_id: str
user_id: str user_id: str
memory_storage_type: str
user_rag_memory_id: str
checkpoint_config: RunnableConfig checkpoint_config: RunnableConfig
@classmethod @classmethod
def create(cls, execution_id: str, workspace_id: str, user_id: str): def create(
cls,
execution_id: str,
workspace_id: str,
user_id: str,
memory_storage_type: str,
user_rag_memory_id: str
):
return cls( return cls(
execution_id=execution_id, execution_id=execution_id,
workspace_id=workspace_id, workspace_id=workspace_id,
user_id=user_id, user_id=user_id,
memory_storage_type=memory_storage_type,
user_rag_memory_id=user_rag_memory_id,
checkpoint_config=RunnableConfig( checkpoint_config=RunnableConfig(
configurable={ configurable={
"thread_id": uuid.uuid4(), "thread_id": uuid.uuid4(),

View File

@@ -33,6 +33,8 @@ class WorkflowState(dict):
"workspace_id", "workspace_id",
"user_id", "user_id",
"activate", "activate",
"memory_storage_type",
"user_rag_memory_id"
}) })
__optional_keys__ = frozenset({ __optional_keys__ = frozenset({
"error", "error",
@@ -62,6 +64,9 @@ class WorkflowState(dict):
# node activate status # node activate status
activate: Annotated[dict[str, bool], merge_activate_state] activate: Annotated[dict[str, bool], merge_activate_state]
memory_storage_type: str
user_rag_memory_id: str
class WorkflowStateManager: class WorkflowStateManager:
def create_initial_state( def create_initial_state(
@@ -85,7 +90,9 @@ class WorkflowStateManager:
looping=0, looping=0,
activate={ activate={
start_node_id: True start_node_id: True
} },
memory_storage_type=execution_context.memory_storage_type,
user_rag_memory_id=execution_context.user_rag_memory_id
) )
@staticmethod @staticmethod

View File

@@ -13,7 +13,7 @@ from pydantic import BaseModel
from app.core.workflow.engine.runtime_schema import ExecutionContext from app.core.workflow.engine.runtime_schema import ExecutionContext
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
from app.core.workflow.variable.variable_objects import T, create_variable_instance from app.core.workflow.variable.variable_objects import T, create_variable_instance, ArrayVariable, FileVariable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -373,6 +373,14 @@ class VariablePool:
def copy(self, pool: 'VariablePool'): def copy(self, pool: 'VariablePool'):
self.variables = deepcopy(pool.variables) self.variables = deepcopy(pool.variables)
def is_file_variable(self, selector):
    """Tell whether ``selector`` refers to a file variable.

    Args:
        selector: Variable selector, resolved through ``_get_variable_struct``.

    Returns:
        bool: ``True`` when the resolved variable is a ``FileVariable`` or an
        ``ArrayVariable`` whose ``child_type`` is ``FileVariable``; ``False``
        otherwise.
    """
    resolved = self._get_variable_struct(selector)
    # A single file variable qualifies on its own.
    if isinstance(resolved, FileVariable):
        return True
    # Otherwise only an array whose elements are files qualifies.
    return bool(
        isinstance(resolved, ArrayVariable)
        and resolved.child_type == FileVariable
    )
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
"""导出为字典 """导出为字典

View File

@@ -409,7 +409,9 @@ async def execute_workflow(
input_data: dict[str, Any], input_data: dict[str, Any],
execution_id: str, execution_id: str,
workspace_id: str, workspace_id: str,
user_id: str user_id: str,
memory_storage_type: str,
user_rag_memory_id: str
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
Execute a workflow (convenience function, non-streaming). Execute a workflow (convenience function, non-streaming).
@@ -420,6 +422,8 @@ async def execute_workflow(
execution_id (str): Execution ID. execution_id (str): Execution ID.
workspace_id (str): Workspace ID. workspace_id (str): Workspace ID.
user_id (str): User ID. user_id (str): User ID.
user_rag_memory_id (str): ID of the RAG knowledge base.
memory_storage_type (str): Memory storage type, "neo4j" or "rag".
Returns: Returns:
dict: Workflow execution result. dict: Workflow execution result.
@@ -427,7 +431,9 @@ async def execute_workflow(
execution_context = ExecutionContext.create( execution_context = ExecutionContext.create(
execution_id=execution_id, execution_id=execution_id,
workspace_id=workspace_id, workspace_id=workspace_id,
user_id=user_id user_id=user_id,
memory_storage_type=memory_storage_type,
user_rag_memory_id=user_rag_memory_id
) )
executor = WorkflowExecutor( executor = WorkflowExecutor(
workflow_config=workflow_config, workflow_config=workflow_config,
@@ -441,7 +447,9 @@ async def execute_workflow_stream(
input_data: dict[str, Any], input_data: dict[str, Any],
execution_id: str, execution_id: str,
workspace_id: str, workspace_id: str,
user_id: str user_id: str,
memory_storage_type: str,
user_rag_memory_id: str
): ):
""" """
Execute a workflow in streaming mode (convenience function). Execute a workflow in streaming mode (convenience function).
@@ -452,6 +460,8 @@ async def execute_workflow_stream(
execution_id (str): Execution ID. execution_id (str): Execution ID.
workspace_id (str): Workspace ID. workspace_id (str): Workspace ID.
user_id (str): User ID. user_id (str): User ID.
user_rag_memory_id (str): ID of the RAG knowledge base.
memory_storage_type (str): Memory storage type, "neo4j" or "rag".
Yields: Yields:
dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end. dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end.
@@ -459,7 +469,9 @@ async def execute_workflow_stream(
execution_context = ExecutionContext.create( execution_context = ExecutionContext.create(
execution_id=execution_id, execution_id=execution_id,
workspace_id=workspace_id, workspace_id=workspace_id,
user_id=user_id user_id=user_id,
memory_storage_type=memory_storage_type,
user_rag_memory_id=user_rag_memory_id
) )
executor = WorkflowExecutor( executor = WorkflowExecutor(
workflow_config=workflow_config, workflow_config=workflow_config,

View File

@@ -1,3 +1,4 @@
import re
from typing import Any from typing import Any
from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.state_manager import WorkflowState
@@ -5,7 +6,9 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
from app.db import get_db_read from app.db import get_db_read
from app.schemas import FileInput
from app.services.memory_agent_service import MemoryAgentService from app.services.memory_agent_service import MemoryAgentService
from app.tasks import write_message_task from app.tasks import write_message_task
@@ -36,8 +39,8 @@ class MemoryReadNode(BaseNode):
search_switch=self.typed_config.search_switch, search_switch=self.typed_config.search_switch,
history=[], history=[],
db=db, db=db,
storage_type="neo4j", storage_type=state["memory_storage_type"],
user_rag_memory_id="" user_rag_memory_id=state["user_rag_memory_id"]
) )
@@ -49,6 +52,19 @@ class MemoryWriteNode(BaseNode):
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
return {"output": VariableType.STRING} return {"output": VariableType.STRING}
@staticmethod
def _extract_multimodal_memory_variables(content: str, variable_pool: VariablePool) -> tuple[list[str], str]:
    """Split file-variable placeholders out of a message template.

    Scans ``content`` for ``{{ node.name }}`` style placeholders, keeps the
    ones that ``variable_pool.is_file_variable`` reports as file variables,
    and removes those placeholders from the text.

    Args:
        content: Message template that may contain variable placeholders.
        variable_pool: Pool used to decide whether a placeholder refers to a
            file variable.

    Returns:
        A ``(file_variables, content)`` tuple: the distinct file placeholders
        in first-appearance order (braces included, exactly as written in
        the template), and the template with every such placeholder removed.
    """
    placeholder_pattern = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}')
    # dict.fromkeys de-duplicates while keeping first-seen order, so a
    # placeholder repeated in the template is checked and reported once
    # instead of yielding the same file several times.
    placeholders = dict.fromkeys(placeholder_pattern.findall(content))
    file_variables = [
        placeholder
        for placeholder in placeholders
        if variable_pool.is_file_variable(placeholder)
    ]
    for placeholder in file_variables:
        # str.replace removes every occurrence of the placeholder at once.
        content = content.replace(placeholder, "")
    return file_variables, content
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
self.typed_config = MemoryWriteNodeConfig(**self.config) self.typed_config = MemoryWriteNodeConfig(**self.config)
end_user_id = self.get_variable("sys.user_id", variable_pool) end_user_id = self.get_variable("sys.user_id", variable_pool)
@@ -56,6 +72,7 @@ class MemoryWriteNode(BaseNode):
if not end_user_id: if not end_user_id:
raise RuntimeError("End user id is required") raise RuntimeError("End user id is required")
messages = [] messages = []
multimodal_memories = []
if self.typed_config.message: if self.typed_config.message:
messages.append({ messages.append({
"role": "user", "role": "user",
@@ -63,17 +80,45 @@ class MemoryWriteNode(BaseNode):
}) })
for message in self.typed_config.messages: for message in self.typed_config.messages:
file_variables, content = self._extract_multimodal_memory_variables(
message.content,
variable_pool
)
file_info = []
for var in file_variables:
instence: FileVariable | ArrayVariable[FileVariable] = variable_pool.get_instance(var)
if isinstance(instence, FileVariable):
file_info.append(FileInput(
type=instence.value.type,
transfer_method=instence.value.transfer_method,
upload_file_id=instence.value.file_id,
url=instence.value.url,
file_type=instence.value.origin_file_type
).model_dump())
elif isinstance(instence, ArrayVariable) and instence.child_type == FileVariable:
for file_instence in instence.value:
file_info.append(FileInput(
type=file_instence.value.type,
transfer_method=file_instence.value.transfer_method,
upload_file_id=file_instence.value.file_id,
url=file_instence.value.url,
file_type=file_instence.value.origin_file_type
).model_dump())
multimodal_memories.append({
"role": message.role,
"files": file_info
})
messages.append({ messages.append({
"role": message.role, "role": message.role,
"content": self._render_template(message.content, variable_pool) "content": self._render_template(content, variable_pool)
}) })
write_message_task.delay( write_message_task.delay(
end_user_id, end_user_id=end_user_id,
messages, message=messages,
str(self.typed_config.config_id), config_id=str(self.typed_config.config_id),
"neo4j", storage_type=state["memory_storage_type"],
"" user_rag_memory_id=state["user_rag_memory_id"]
) )
return "success" return "success"

View File

@@ -20,6 +20,7 @@ from app.core.workflow.variable.base_variable import FileObject
from app.db import get_db from app.db import get_db
from app.models import App from app.models import App
from app.models.workflow_model import WorkflowConfig, WorkflowExecution from app.models.workflow_model import WorkflowConfig, WorkflowExecution
from app.repositories import knowledge_repository
from app.repositories.workflow_repository import ( from app.repositories.workflow_repository import (
WorkflowConfigRepository, WorkflowConfigRepository,
WorkflowExecutionRepository, WorkflowExecutionRepository,
@@ -29,6 +30,7 @@ from app.schemas import DraftRunRequest, FileInput
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
from app.services.multi_agent_service import convert_uuids_to_str from app.services.multi_agent_service import convert_uuids_to_str
from app.services.multimodal_service import MultimodalService from app.services.multimodal_service import MultimodalService
from app.services.workspace_service import get_workspace_storage_type_without_auth
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -536,6 +538,25 @@ class WorkflowService:
mapped = internal_event mapped = internal_event
return mapped return mapped
def _get_memory_store_info(self, workspace_id: uuid.UUID) -> tuple[str, str]:
    """Resolve the memory storage type and RAG knowledge base id for a workspace.

    Args:
        workspace_id: Workspace whose storage type is looked up.

    Returns:
        A ``(storage_type, user_rag_memory_id)`` tuple. ``user_rag_memory_id``
        is the knowledge base id as a string when the storage type is
        ``"rag"`` and the knowledge base exists, otherwise ``""``. When the
        storage type is ``"rag"`` but no knowledge base is found, the
        storage type falls back to ``"neo4j"``.
    """
    # NOTE(review): "MERORY" looks like a misspelling of "MEMORY", but the
    # lookup has to match the name the knowledge base was created with --
    # confirm against the creation code before renaming it.
    rag_knowledge_name = "USER_RAG_MERORY"

    storage_type = get_workspace_storage_type_without_auth(self.db, workspace_id)
    if storage_type != "rag":
        return storage_type, ""

    knowledge = knowledge_repository.get_knowledge_by_name(
        db=self.db,
        name=rag_knowledge_name,
        workspace_id=workspace_id
    )
    if knowledge:
        return storage_type, str(knowledge.id)

    # Log the exact name that was queried so the warning cannot drift from
    # the lookup above.
    logger.warning(
        "No knowledge base named '%s' found, workspace_id: %s, will use neo4j storage",
        rag_knowledge_name,
        workspace_id,
    )
    return "neo4j", ""
# ==================== 工作流执行 ==================== # ==================== 工作流执行 ====================
async def run( async def run(
@@ -603,6 +624,7 @@ class WorkflowService:
try: try:
files = await self._handle_file_input(payload.files) files = await self._handle_file_input(payload.files)
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
input_data["files"] = files input_data["files"] = files
message_id = uuid.uuid4() message_id = uuid.uuid4()
# 更新状态为运行中 # 更新状态为运行中
@@ -627,7 +649,9 @@ class WorkflowService:
input_data=input_data, input_data=input_data,
execution_id=execution.execution_id, execution_id=execution.execution_id,
workspace_id=str(workspace_id), workspace_id=str(workspace_id),
user_id=payload.user_id user_id=payload.user_id,
memory_storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
) )
# 更新执行结果 # 更新执行结果
if result.get("status") == "completed": if result.get("status") == "completed":
@@ -776,6 +800,7 @@ class WorkflowService:
try: try:
files = await self._handle_file_input(payload.files) files = await self._handle_file_input(payload.files)
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
input_data["files"] = files input_data["files"] = files
self.update_execution_status(execution.execution_id, "running") self.update_execution_status(execution.execution_id, "running")
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
@@ -797,6 +822,8 @@ class WorkflowService:
execution_id=execution.execution_id, execution_id=execution.execution_id,
workspace_id=str(workspace_id), workspace_id=str(workspace_id),
user_id=payload.user_id, user_id=payload.user_id,
memory_storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
): ):
if event.get("event") == "workflow_end": if event.get("event") == "workflow_end":
status = event.get("data", {}).get("status") status = event.get("data", {}).get("status")

View File

@@ -863,7 +863,7 @@ def get_workspace_storage_type(
def get_workspace_storage_type_without_auth( def get_workspace_storage_type_without_auth(
db: Session, db: Session,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
) -> Optional[str]: ) -> str:
"""获取工作空间的存储类型(无需权限验证,用于公开分享等场景) """获取工作空间的存储类型(无需权限验证,用于公开分享等场景)
Args: Args: