fix(workflow): fix memory write behavior in RAG workspace
This commit is contained in:
@@ -12,14 +12,26 @@ class ExecutionContext(BaseModel):
|
||||
execution_id: str
|
||||
workspace_id: str
|
||||
user_id: str
|
||||
memory_storage_type: str
|
||||
user_rag_memory_id: str
|
||||
checkpoint_config: RunnableConfig
|
||||
|
||||
@classmethod
|
||||
def create(cls, execution_id: str, workspace_id: str, user_id: str):
|
||||
def create(
|
||||
cls,
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str,
|
||||
memory_storage_type: str,
|
||||
user_rag_memory_id: str
|
||||
):
|
||||
return cls(
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id,
|
||||
memory_storage_type=memory_storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
|
||||
checkpoint_config=RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": uuid.uuid4(),
|
||||
|
||||
@@ -33,6 +33,8 @@ class WorkflowState(dict):
|
||||
"workspace_id",
|
||||
"user_id",
|
||||
"activate",
|
||||
"memory_storage_type",
|
||||
"user_rag_memory_id"
|
||||
})
|
||||
__optional_keys__ = frozenset({
|
||||
"error",
|
||||
@@ -62,6 +64,9 @@ class WorkflowState(dict):
|
||||
# node activate status
|
||||
activate: Annotated[dict[str, bool], merge_activate_state]
|
||||
|
||||
memory_storage_type: str
|
||||
user_rag_memory_id: str
|
||||
|
||||
|
||||
class WorkflowStateManager:
|
||||
def create_initial_state(
|
||||
@@ -85,7 +90,9 @@ class WorkflowStateManager:
|
||||
looping=0,
|
||||
activate={
|
||||
start_node_id: True
|
||||
}
|
||||
},
|
||||
memory_storage_type=execution_context.memory_storage_type,
|
||||
user_rag_memory_id=execution_context.user_rag_memory_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -13,7 +13,7 @@ from pydantic import BaseModel
|
||||
|
||||
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.variable.variable_objects import T, create_variable_instance
|
||||
from app.core.workflow.variable.variable_objects import T, create_variable_instance, ArrayVariable, FileVariable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -373,6 +373,14 @@ class VariablePool:
|
||||
def copy(self, pool: 'VariablePool'):
    """Replace this pool's variables with a deep copy of another pool's.

    Args:
        pool: The pool whose ``variables`` attribute is copied.
    """
    # deepcopy keeps later mutations of either pool from leaking into the other.
    self.variables = deepcopy(pool.variables)
|
||||
|
||||
def is_file_variable(self, selector) -> bool:
    """Report whether ``selector`` resolves to a file or an array of files.

    Args:
        selector: Variable selector forwarded unchanged to
            ``self._get_variable_struct``.

    Returns:
        True when the resolved struct is a ``FileVariable``, or an
        ``ArrayVariable`` whose ``child_type`` is ``FileVariable``;
        False otherwise.
    """
    variable_struct = self._get_variable_struct(selector)
    if isinstance(variable_struct, FileVariable):
        return True
    # An array only counts when its declared element type is FileVariable.
    return (
        isinstance(variable_struct, ArrayVariable)
        and variable_struct.child_type == FileVariable
    )
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""导出为字典
|
||||
|
||||
|
||||
@@ -409,7 +409,9 @@ async def execute_workflow(
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
user_id: str,
|
||||
memory_storage_type: str,
|
||||
user_rag_memory_id: str
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Execute a workflow (convenience function, non-streaming).
|
||||
@@ -420,6 +422,8 @@ async def execute_workflow(
|
||||
execution_id (str): Execution ID.
|
||||
workspace_id (str): Workspace ID.
|
||||
user_id (str): User ID.
|
||||
user_rag_memory_id (str): RAG knowledge base ID.
|
||||
memory_storage_type (str): Memory storage type ("neo4j" or "rag").
|
||||
|
||||
Returns:
|
||||
dict: Workflow execution result.
|
||||
@@ -427,7 +431,9 @@ async def execute_workflow(
|
||||
execution_context = ExecutionContext.create(
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
memory_storage_type=memory_storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
executor = WorkflowExecutor(
|
||||
workflow_config=workflow_config,
|
||||
@@ -441,7 +447,9 @@ async def execute_workflow_stream(
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
user_id: str,
|
||||
memory_storage_type: str,
|
||||
user_rag_memory_id: str
|
||||
):
|
||||
"""
|
||||
Execute a workflow in streaming mode (convenience function).
|
||||
@@ -452,6 +460,8 @@ async def execute_workflow_stream(
|
||||
execution_id (str): Execution ID.
|
||||
workspace_id (str): Workspace ID.
|
||||
user_id (str): User ID.
|
||||
user_rag_memory_id (str): RAG knowledge base ID.
|
||||
memory_storage_type (str): Memory storage type ("neo4j" or "rag").
|
||||
|
||||
Yields:
|
||||
dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end.
|
||||
@@ -459,7 +469,9 @@ async def execute_workflow_stream(
|
||||
execution_context = ExecutionContext.create(
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
memory_storage_type=memory_storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
executor = WorkflowExecutor(
|
||||
workflow_config=workflow_config,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
@@ -5,7 +6,9 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
||||
from app.db import get_db_read
|
||||
from app.schemas import FileInput
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.tasks import write_message_task
|
||||
|
||||
@@ -36,8 +39,8 @@ class MemoryReadNode(BaseNode):
|
||||
search_switch=self.typed_config.search_switch,
|
||||
history=[],
|
||||
db=db,
|
||||
storage_type="neo4j",
|
||||
user_rag_memory_id=""
|
||||
storage_type=state["memory_storage_type"],
|
||||
user_rag_memory_id=state["user_rag_memory_id"]
|
||||
)
|
||||
|
||||
|
||||
@@ -49,6 +52,19 @@ class MemoryWriteNode(BaseNode):
|
||||
def _output_types(self) -> dict[str, VariableType]:
    """Declare the node's outputs: a single ``output`` entry typed as STRING."""
    return {"output": VariableType.STRING}
|
||||
|
||||
@staticmethod
|
||||
def _extract_multimodal_memory_variables(content: str, variable_pool: VariablePool) -> tuple[list[str], str]:
|
||||
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
|
||||
variable_pattern = re.compile(variable_pattern_string)
|
||||
variables = variable_pattern.findall(content)
|
||||
file_variables = []
|
||||
for variable in variables:
|
||||
if variable_pool.is_file_variable(variable):
|
||||
file_variables.append(variable)
|
||||
for var in file_variables:
|
||||
content = content.replace(var, "")
|
||||
return file_variables, content
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
||||
end_user_id = self.get_variable("sys.user_id", variable_pool)
|
||||
@@ -56,6 +72,7 @@ class MemoryWriteNode(BaseNode):
|
||||
if not end_user_id:
|
||||
raise RuntimeError("End user id is required")
|
||||
messages = []
|
||||
multimodal_memories = []
|
||||
if self.typed_config.message:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
@@ -63,17 +80,45 @@ class MemoryWriteNode(BaseNode):
|
||||
})
|
||||
|
||||
for message in self.typed_config.messages:
|
||||
file_variables, content = self._extract_multimodal_memory_variables(
|
||||
message.content,
|
||||
variable_pool
|
||||
)
|
||||
file_info = []
|
||||
for var in file_variables:
|
||||
instence: FileVariable | ArrayVariable[FileVariable] = variable_pool.get_instance(var)
|
||||
if isinstance(instence, FileVariable):
|
||||
file_info.append(FileInput(
|
||||
type=instence.value.type,
|
||||
transfer_method=instence.value.transfer_method,
|
||||
upload_file_id=instence.value.file_id,
|
||||
url=instence.value.url,
|
||||
file_type=instence.value.origin_file_type
|
||||
).model_dump())
|
||||
elif isinstance(instence, ArrayVariable) and instence.child_type == FileVariable:
|
||||
for file_instence in instence.value:
|
||||
file_info.append(FileInput(
|
||||
type=file_instence.value.type,
|
||||
transfer_method=file_instence.value.transfer_method,
|
||||
upload_file_id=file_instence.value.file_id,
|
||||
url=file_instence.value.url,
|
||||
file_type=file_instence.value.origin_file_type
|
||||
).model_dump())
|
||||
multimodal_memories.append({
|
||||
"role": message.role,
|
||||
"files": file_info
|
||||
})
|
||||
messages.append({
|
||||
"role": message.role,
|
||||
"content": self._render_template(message.content, variable_pool)
|
||||
"content": self._render_template(content, variable_pool)
|
||||
})
|
||||
|
||||
write_message_task.delay(
|
||||
end_user_id,
|
||||
messages,
|
||||
str(self.typed_config.config_id),
|
||||
"neo4j",
|
||||
""
|
||||
end_user_id=end_user_id,
|
||||
message=messages,
|
||||
config_id=str(self.typed_config.config_id),
|
||||
storage_type=state["memory_storage_type"],
|
||||
user_rag_memory_id=state["user_rag_memory_id"]
|
||||
)
|
||||
|
||||
return "success"
|
||||
|
||||
Reference in New Issue
Block a user