fix(workflow): fix memory write behavior in RAG workspace
This commit is contained in:
@@ -12,14 +12,26 @@ class ExecutionContext(BaseModel):
|
||||
execution_id: str
|
||||
workspace_id: str
|
||||
user_id: str
|
||||
memory_storage_type: str
|
||||
user_rag_memory_id: str
|
||||
checkpoint_config: RunnableConfig
|
||||
|
||||
@classmethod
|
||||
def create(cls, execution_id: str, workspace_id: str, user_id: str):
|
||||
def create(
|
||||
cls,
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str,
|
||||
memory_storage_type: str,
|
||||
user_rag_memory_id: str
|
||||
):
|
||||
return cls(
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id,
|
||||
memory_storage_type=memory_storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
|
||||
checkpoint_config=RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": uuid.uuid4(),
|
||||
|
||||
@@ -33,6 +33,8 @@ class WorkflowState(dict):
|
||||
"workspace_id",
|
||||
"user_id",
|
||||
"activate",
|
||||
"memory_storage_type",
|
||||
"user_rag_memory_id"
|
||||
})
|
||||
__optional_keys__ = frozenset({
|
||||
"error",
|
||||
@@ -62,6 +64,9 @@ class WorkflowState(dict):
|
||||
# node activate status
|
||||
activate: Annotated[dict[str, bool], merge_activate_state]
|
||||
|
||||
memory_storage_type: str
|
||||
user_rag_memory_id: str
|
||||
|
||||
|
||||
class WorkflowStateManager:
|
||||
def create_initial_state(
|
||||
@@ -85,7 +90,9 @@ class WorkflowStateManager:
|
||||
looping=0,
|
||||
activate={
|
||||
start_node_id: True
|
||||
}
|
||||
},
|
||||
memory_storage_type=execution_context.memory_storage_type,
|
||||
user_rag_memory_id=execution_context.user_rag_memory_id
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -13,7 +13,7 @@ from pydantic import BaseModel
|
||||
|
||||
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.variable.variable_objects import T, create_variable_instance
|
||||
from app.core.workflow.variable.variable_objects import T, create_variable_instance, ArrayVariable, FileVariable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -373,6 +373,14 @@ class VariablePool:
|
||||
def copy(self, pool: 'VariablePool'):
|
||||
self.variables = deepcopy(pool.variables)
|
||||
|
||||
def is_file_variable(self, selector):
|
||||
variable_struct = self._get_variable_struct(selector)
|
||||
if isinstance(variable_struct, FileVariable):
|
||||
return True
|
||||
elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable:
|
||||
return True
|
||||
return False
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""导出为字典
|
||||
|
||||
|
||||
@@ -409,7 +409,9 @@ async def execute_workflow(
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
user_id: str,
|
||||
memory_storage_type: str,
|
||||
user_rag_memory_id: str
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Execute a workflow (convenience function, non-streaming).
|
||||
@@ -420,6 +422,8 @@ async def execute_workflow(
|
||||
execution_id (str): Execution ID.
|
||||
workspace_id (str): Workspace ID.
|
||||
user_id (str): User ID.
|
||||
user_rag_memory_id: rag knowledge db id
|
||||
memory_storage_type: neo4j / rag
|
||||
|
||||
Returns:
|
||||
dict: Workflow execution result.
|
||||
@@ -427,7 +431,9 @@ async def execute_workflow(
|
||||
execution_context = ExecutionContext.create(
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
memory_storage_type=memory_storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
executor = WorkflowExecutor(
|
||||
workflow_config=workflow_config,
|
||||
@@ -441,7 +447,9 @@ async def execute_workflow_stream(
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
user_id: str,
|
||||
memory_storage_type: str,
|
||||
user_rag_memory_id: str
|
||||
):
|
||||
"""
|
||||
Execute a workflow in streaming mode (convenience function).
|
||||
@@ -452,6 +460,8 @@ async def execute_workflow_stream(
|
||||
execution_id (str): Execution ID.
|
||||
workspace_id (str): Workspace ID.
|
||||
user_id (str): User ID.
|
||||
user_rag_memory_id: rag knowledge db id
|
||||
memory_storage_type: neo4j / rag
|
||||
|
||||
Yields:
|
||||
dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end.
|
||||
@@ -459,7 +469,9 @@ async def execute_workflow_stream(
|
||||
execution_context = ExecutionContext.create(
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
memory_storage_type=memory_storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
executor = WorkflowExecutor(
|
||||
workflow_config=workflow_config,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
@@ -5,7 +6,9 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
||||
from app.db import get_db_read
|
||||
from app.schemas import FileInput
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.tasks import write_message_task
|
||||
|
||||
@@ -36,8 +39,8 @@ class MemoryReadNode(BaseNode):
|
||||
search_switch=self.typed_config.search_switch,
|
||||
history=[],
|
||||
db=db,
|
||||
storage_type="neo4j",
|
||||
user_rag_memory_id=""
|
||||
storage_type=state["memory_storage_type"],
|
||||
user_rag_memory_id=state["user_rag_memory_id"]
|
||||
)
|
||||
|
||||
|
||||
@@ -49,6 +52,19 @@ class MemoryWriteNode(BaseNode):
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {"output": VariableType.STRING}
|
||||
|
||||
@staticmethod
|
||||
def _extract_multimodal_memory_variables(content: str, variable_pool: VariablePool) -> tuple[list[str], str]:
|
||||
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
|
||||
variable_pattern = re.compile(variable_pattern_string)
|
||||
variables = variable_pattern.findall(content)
|
||||
file_variables = []
|
||||
for variable in variables:
|
||||
if variable_pool.is_file_variable(variable):
|
||||
file_variables.append(variable)
|
||||
for var in file_variables:
|
||||
content = content.replace(var, "")
|
||||
return file_variables, content
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
||||
end_user_id = self.get_variable("sys.user_id", variable_pool)
|
||||
@@ -56,6 +72,7 @@ class MemoryWriteNode(BaseNode):
|
||||
if not end_user_id:
|
||||
raise RuntimeError("End user id is required")
|
||||
messages = []
|
||||
multimodal_memories = []
|
||||
if self.typed_config.message:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
@@ -63,17 +80,45 @@ class MemoryWriteNode(BaseNode):
|
||||
})
|
||||
|
||||
for message in self.typed_config.messages:
|
||||
file_variables, content = self._extract_multimodal_memory_variables(
|
||||
message.content,
|
||||
variable_pool
|
||||
)
|
||||
file_info = []
|
||||
for var in file_variables:
|
||||
instence: FileVariable | ArrayVariable[FileVariable] = variable_pool.get_instance(var)
|
||||
if isinstance(instence, FileVariable):
|
||||
file_info.append(FileInput(
|
||||
type=instence.value.type,
|
||||
transfer_method=instence.value.transfer_method,
|
||||
upload_file_id=instence.value.file_id,
|
||||
url=instence.value.url,
|
||||
file_type=instence.value.origin_file_type
|
||||
).model_dump())
|
||||
elif isinstance(instence, ArrayVariable) and instence.child_type == FileVariable:
|
||||
for file_instence in instence.value:
|
||||
file_info.append(FileInput(
|
||||
type=file_instence.value.type,
|
||||
transfer_method=file_instence.value.transfer_method,
|
||||
upload_file_id=file_instence.value.file_id,
|
||||
url=file_instence.value.url,
|
||||
file_type=file_instence.value.origin_file_type
|
||||
).model_dump())
|
||||
multimodal_memories.append({
|
||||
"role": message.role,
|
||||
"files": file_info
|
||||
})
|
||||
messages.append({
|
||||
"role": message.role,
|
||||
"content": self._render_template(message.content, variable_pool)
|
||||
"content": self._render_template(content, variable_pool)
|
||||
})
|
||||
|
||||
write_message_task.delay(
|
||||
end_user_id,
|
||||
messages,
|
||||
str(self.typed_config.config_id),
|
||||
"neo4j",
|
||||
""
|
||||
end_user_id=end_user_id,
|
||||
message=messages,
|
||||
config_id=str(self.typed_config.config_id),
|
||||
storage_type=state["memory_storage_type"],
|
||||
user_rag_memory_id=state["user_rag_memory_id"]
|
||||
)
|
||||
|
||||
return "success"
|
||||
|
||||
@@ -20,6 +20,7 @@ from app.core.workflow.variable.base_variable import FileObject
|
||||
from app.db import get_db
|
||||
from app.models import App
|
||||
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
||||
from app.repositories import knowledge_repository
|
||||
from app.repositories.workflow_repository import (
|
||||
WorkflowConfigRepository,
|
||||
WorkflowExecutionRepository,
|
||||
@@ -29,6 +30,7 @@ from app.schemas import DraftRunRequest, FileInput
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.multi_agent_service import convert_uuids_to_str
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
from app.services.workspace_service import get_workspace_storage_type_without_auth
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -536,6 +538,25 @@ class WorkflowService:
|
||||
mapped = internal_event
|
||||
return mapped
|
||||
|
||||
def _get_memory_store_info(self, workspace_id: uuid.UUID) -> tuple[str, str]:
|
||||
storage_type = get_workspace_storage_type_without_auth(self.db, workspace_id)
|
||||
user_rag_memory_id = ""
|
||||
if storage_type == "rag":
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=self.db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
else:
|
||||
logger.warning(
|
||||
f"No knowledge base named 'USER_RAG_MEMORY' found, "
|
||||
f"workspace_id: {workspace_id}, will use neo4j storage"
|
||||
)
|
||||
storage_type = 'neo4j'
|
||||
return storage_type, user_rag_memory_id
|
||||
|
||||
# ==================== 工作流执行 ====================
|
||||
|
||||
async def run(
|
||||
@@ -603,6 +624,7 @@ class WorkflowService:
|
||||
|
||||
try:
|
||||
files = await self._handle_file_input(payload.files)
|
||||
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
|
||||
input_data["files"] = files
|
||||
message_id = uuid.uuid4()
|
||||
# 更新状态为运行中
|
||||
@@ -627,7 +649,9 @@ class WorkflowService:
|
||||
input_data=input_data,
|
||||
execution_id=execution.execution_id,
|
||||
workspace_id=str(workspace_id),
|
||||
user_id=payload.user_id
|
||||
user_id=payload.user_id,
|
||||
memory_storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
# 更新执行结果
|
||||
if result.get("status") == "completed":
|
||||
@@ -776,6 +800,7 @@ class WorkflowService:
|
||||
|
||||
try:
|
||||
files = await self._handle_file_input(payload.files)
|
||||
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
|
||||
input_data["files"] = files
|
||||
self.update_execution_status(execution.execution_id, "running")
|
||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
||||
@@ -797,6 +822,8 @@ class WorkflowService:
|
||||
execution_id=execution.execution_id,
|
||||
workspace_id=str(workspace_id),
|
||||
user_id=payload.user_id,
|
||||
memory_storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
if event.get("event") == "workflow_end":
|
||||
status = event.get("data", {}).get("status")
|
||||
|
||||
@@ -863,7 +863,7 @@ def get_workspace_storage_type(
|
||||
def get_workspace_storage_type_without_auth(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
) -> Optional[str]:
|
||||
) -> str:
|
||||
"""获取工作空间的存储类型(无需权限验证,用于公开分享等场景)
|
||||
|
||||
Args:
|
||||
|
||||
Reference in New Issue
Block a user