fix(workflow): fix memory write behavior in RAG workspace
This commit is contained in:
@@ -12,14 +12,26 @@ class ExecutionContext(BaseModel):
|
|||||||
execution_id: str
|
execution_id: str
|
||||||
workspace_id: str
|
workspace_id: str
|
||||||
user_id: str
|
user_id: str
|
||||||
|
memory_storage_type: str
|
||||||
|
user_rag_memory_id: str
|
||||||
checkpoint_config: RunnableConfig
|
checkpoint_config: RunnableConfig
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, execution_id: str, workspace_id: str, user_id: str):
|
def create(
|
||||||
|
cls,
|
||||||
|
execution_id: str,
|
||||||
|
workspace_id: str,
|
||||||
|
user_id: str,
|
||||||
|
memory_storage_type: str,
|
||||||
|
user_rag_memory_id: str
|
||||||
|
):
|
||||||
return cls(
|
return cls(
|
||||||
execution_id=execution_id,
|
execution_id=execution_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
memory_storage_type=memory_storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
|
|
||||||
checkpoint_config=RunnableConfig(
|
checkpoint_config=RunnableConfig(
|
||||||
configurable={
|
configurable={
|
||||||
"thread_id": uuid.uuid4(),
|
"thread_id": uuid.uuid4(),
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ class WorkflowState(dict):
|
|||||||
"workspace_id",
|
"workspace_id",
|
||||||
"user_id",
|
"user_id",
|
||||||
"activate",
|
"activate",
|
||||||
|
"memory_storage_type",
|
||||||
|
"user_rag_memory_id"
|
||||||
})
|
})
|
||||||
__optional_keys__ = frozenset({
|
__optional_keys__ = frozenset({
|
||||||
"error",
|
"error",
|
||||||
@@ -62,6 +64,9 @@ class WorkflowState(dict):
|
|||||||
# node activate status
|
# node activate status
|
||||||
activate: Annotated[dict[str, bool], merge_activate_state]
|
activate: Annotated[dict[str, bool], merge_activate_state]
|
||||||
|
|
||||||
|
memory_storage_type: str
|
||||||
|
user_rag_memory_id: str
|
||||||
|
|
||||||
|
|
||||||
class WorkflowStateManager:
|
class WorkflowStateManager:
|
||||||
def create_initial_state(
|
def create_initial_state(
|
||||||
@@ -85,7 +90,9 @@ class WorkflowStateManager:
|
|||||||
looping=0,
|
looping=0,
|
||||||
activate={
|
activate={
|
||||||
start_node_id: True
|
start_node_id: True
|
||||||
}
|
},
|
||||||
|
memory_storage_type=execution_context.memory_storage_type,
|
||||||
|
user_rag_memory_id=execution_context.user_rag_memory_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
||||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||||
from app.core.workflow.variable.variable_objects import T, create_variable_instance
|
from app.core.workflow.variable.variable_objects import T, create_variable_instance, ArrayVariable, FileVariable
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -373,6 +373,14 @@ class VariablePool:
|
|||||||
def copy(self, pool: 'VariablePool'):
|
def copy(self, pool: 'VariablePool'):
|
||||||
self.variables = deepcopy(pool.variables)
|
self.variables = deepcopy(pool.variables)
|
||||||
|
|
||||||
|
def is_file_variable(self, selector):
|
||||||
|
variable_struct = self._get_variable_struct(selector)
|
||||||
|
if isinstance(variable_struct, FileVariable):
|
||||||
|
return True
|
||||||
|
elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""导出为字典
|
"""导出为字典
|
||||||
|
|
||||||
|
|||||||
@@ -409,7 +409,9 @@ async def execute_workflow(
|
|||||||
input_data: dict[str, Any],
|
input_data: dict[str, Any],
|
||||||
execution_id: str,
|
execution_id: str,
|
||||||
workspace_id: str,
|
workspace_id: str,
|
||||||
user_id: str
|
user_id: str,
|
||||||
|
memory_storage_type: str,
|
||||||
|
user_rag_memory_id: str
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Execute a workflow (convenience function, non-streaming).
|
Execute a workflow (convenience function, non-streaming).
|
||||||
@@ -420,6 +422,8 @@ async def execute_workflow(
|
|||||||
execution_id (str): Execution ID.
|
execution_id (str): Execution ID.
|
||||||
workspace_id (str): Workspace ID.
|
workspace_id (str): Workspace ID.
|
||||||
user_id (str): User ID.
|
user_id (str): User ID.
|
||||||
|
user_rag_memory_id: rag knowledge db id
|
||||||
|
memory_storage_type: neo4j / rag
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Workflow execution result.
|
dict: Workflow execution result.
|
||||||
@@ -427,7 +431,9 @@ async def execute_workflow(
|
|||||||
execution_context = ExecutionContext.create(
|
execution_context = ExecutionContext.create(
|
||||||
execution_id=execution_id,
|
execution_id=execution_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
user_id=user_id
|
user_id=user_id,
|
||||||
|
memory_storage_type=memory_storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id
|
||||||
)
|
)
|
||||||
executor = WorkflowExecutor(
|
executor = WorkflowExecutor(
|
||||||
workflow_config=workflow_config,
|
workflow_config=workflow_config,
|
||||||
@@ -441,7 +447,9 @@ async def execute_workflow_stream(
|
|||||||
input_data: dict[str, Any],
|
input_data: dict[str, Any],
|
||||||
execution_id: str,
|
execution_id: str,
|
||||||
workspace_id: str,
|
workspace_id: str,
|
||||||
user_id: str
|
user_id: str,
|
||||||
|
memory_storage_type: str,
|
||||||
|
user_rag_memory_id: str
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Execute a workflow in streaming mode (convenience function).
|
Execute a workflow in streaming mode (convenience function).
|
||||||
@@ -452,6 +460,8 @@ async def execute_workflow_stream(
|
|||||||
execution_id (str): Execution ID.
|
execution_id (str): Execution ID.
|
||||||
workspace_id (str): Workspace ID.
|
workspace_id (str): Workspace ID.
|
||||||
user_id (str): User ID.
|
user_id (str): User ID.
|
||||||
|
user_rag_memory_id: rag knowledge db id
|
||||||
|
memory_storage_type: neo4j / rag
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end.
|
dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end.
|
||||||
@@ -459,7 +469,9 @@ async def execute_workflow_stream(
|
|||||||
execution_context = ExecutionContext.create(
|
execution_context = ExecutionContext.create(
|
||||||
execution_id=execution_id,
|
execution_id=execution_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
user_id=user_id
|
user_id=user_id,
|
||||||
|
memory_storage_type=memory_storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id
|
||||||
)
|
)
|
||||||
executor = WorkflowExecutor(
|
executor = WorkflowExecutor(
|
||||||
workflow_config=workflow_config,
|
workflow_config=workflow_config,
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.core.workflow.engine.state_manager import WorkflowState
|
from app.core.workflow.engine.state_manager import WorkflowState
|
||||||
@@ -5,7 +6,9 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
|||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
|
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
||||||
from app.db import get_db_read
|
from app.db import get_db_read
|
||||||
|
from app.schemas import FileInput
|
||||||
from app.services.memory_agent_service import MemoryAgentService
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
from app.tasks import write_message_task
|
from app.tasks import write_message_task
|
||||||
|
|
||||||
@@ -36,8 +39,8 @@ class MemoryReadNode(BaseNode):
|
|||||||
search_switch=self.typed_config.search_switch,
|
search_switch=self.typed_config.search_switch,
|
||||||
history=[],
|
history=[],
|
||||||
db=db,
|
db=db,
|
||||||
storage_type="neo4j",
|
storage_type=state["memory_storage_type"],
|
||||||
user_rag_memory_id=""
|
user_rag_memory_id=state["user_rag_memory_id"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -49,6 +52,19 @@ class MemoryWriteNode(BaseNode):
|
|||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {"output": VariableType.STRING}
|
return {"output": VariableType.STRING}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_multimodal_memory_variables(content: str, variable_pool: VariablePool) -> tuple[list[str], str]:
|
||||||
|
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
|
||||||
|
variable_pattern = re.compile(variable_pattern_string)
|
||||||
|
variables = variable_pattern.findall(content)
|
||||||
|
file_variables = []
|
||||||
|
for variable in variables:
|
||||||
|
if variable_pool.is_file_variable(variable):
|
||||||
|
file_variables.append(variable)
|
||||||
|
for var in file_variables:
|
||||||
|
content = content.replace(var, "")
|
||||||
|
return file_variables, content
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||||
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
||||||
end_user_id = self.get_variable("sys.user_id", variable_pool)
|
end_user_id = self.get_variable("sys.user_id", variable_pool)
|
||||||
@@ -56,6 +72,7 @@ class MemoryWriteNode(BaseNode):
|
|||||||
if not end_user_id:
|
if not end_user_id:
|
||||||
raise RuntimeError("End user id is required")
|
raise RuntimeError("End user id is required")
|
||||||
messages = []
|
messages = []
|
||||||
|
multimodal_memories = []
|
||||||
if self.typed_config.message:
|
if self.typed_config.message:
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
@@ -63,17 +80,45 @@ class MemoryWriteNode(BaseNode):
|
|||||||
})
|
})
|
||||||
|
|
||||||
for message in self.typed_config.messages:
|
for message in self.typed_config.messages:
|
||||||
|
file_variables, content = self._extract_multimodal_memory_variables(
|
||||||
|
message.content,
|
||||||
|
variable_pool
|
||||||
|
)
|
||||||
|
file_info = []
|
||||||
|
for var in file_variables:
|
||||||
|
instence: FileVariable | ArrayVariable[FileVariable] = variable_pool.get_instance(var)
|
||||||
|
if isinstance(instence, FileVariable):
|
||||||
|
file_info.append(FileInput(
|
||||||
|
type=instence.value.type,
|
||||||
|
transfer_method=instence.value.transfer_method,
|
||||||
|
upload_file_id=instence.value.file_id,
|
||||||
|
url=instence.value.url,
|
||||||
|
file_type=instence.value.origin_file_type
|
||||||
|
).model_dump())
|
||||||
|
elif isinstance(instence, ArrayVariable) and instence.child_type == FileVariable:
|
||||||
|
for file_instence in instence.value:
|
||||||
|
file_info.append(FileInput(
|
||||||
|
type=file_instence.value.type,
|
||||||
|
transfer_method=file_instence.value.transfer_method,
|
||||||
|
upload_file_id=file_instence.value.file_id,
|
||||||
|
url=file_instence.value.url,
|
||||||
|
file_type=file_instence.value.origin_file_type
|
||||||
|
).model_dump())
|
||||||
|
multimodal_memories.append({
|
||||||
|
"role": message.role,
|
||||||
|
"files": file_info
|
||||||
|
})
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": message.role,
|
"role": message.role,
|
||||||
"content": self._render_template(message.content, variable_pool)
|
"content": self._render_template(content, variable_pool)
|
||||||
})
|
})
|
||||||
|
|
||||||
write_message_task.delay(
|
write_message_task.delay(
|
||||||
end_user_id,
|
end_user_id=end_user_id,
|
||||||
messages,
|
message=messages,
|
||||||
str(self.typed_config.config_id),
|
config_id=str(self.typed_config.config_id),
|
||||||
"neo4j",
|
storage_type=state["memory_storage_type"],
|
||||||
""
|
user_rag_memory_id=state["user_rag_memory_id"]
|
||||||
)
|
)
|
||||||
|
|
||||||
return "success"
|
return "success"
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from app.core.workflow.variable.base_variable import FileObject
|
|||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.models import App
|
from app.models import App
|
||||||
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
||||||
|
from app.repositories import knowledge_repository
|
||||||
from app.repositories.workflow_repository import (
|
from app.repositories.workflow_repository import (
|
||||||
WorkflowConfigRepository,
|
WorkflowConfigRepository,
|
||||||
WorkflowExecutionRepository,
|
WorkflowExecutionRepository,
|
||||||
@@ -29,6 +30,7 @@ from app.schemas import DraftRunRequest, FileInput
|
|||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
from app.services.multi_agent_service import convert_uuids_to_str
|
from app.services.multi_agent_service import convert_uuids_to_str
|
||||||
from app.services.multimodal_service import MultimodalService
|
from app.services.multimodal_service import MultimodalService
|
||||||
|
from app.services.workspace_service import get_workspace_storage_type_without_auth
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -536,6 +538,25 @@ class WorkflowService:
|
|||||||
mapped = internal_event
|
mapped = internal_event
|
||||||
return mapped
|
return mapped
|
||||||
|
|
||||||
|
def _get_memory_store_info(self, workspace_id: uuid.UUID) -> tuple[str, str]:
|
||||||
|
storage_type = get_workspace_storage_type_without_auth(self.db, workspace_id)
|
||||||
|
user_rag_memory_id = ""
|
||||||
|
if storage_type == "rag":
|
||||||
|
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||||
|
db=self.db,
|
||||||
|
name="USER_RAG_MERORY",
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
if knowledge:
|
||||||
|
user_rag_memory_id = str(knowledge.id)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"No knowledge base named 'USER_RAG_MEMORY' found, "
|
||||||
|
f"workspace_id: {workspace_id}, will use neo4j storage"
|
||||||
|
)
|
||||||
|
storage_type = 'neo4j'
|
||||||
|
return storage_type, user_rag_memory_id
|
||||||
|
|
||||||
# ==================== 工作流执行 ====================
|
# ==================== 工作流执行 ====================
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
@@ -603,6 +624,7 @@ class WorkflowService:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
files = await self._handle_file_input(payload.files)
|
files = await self._handle_file_input(payload.files)
|
||||||
|
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
|
||||||
input_data["files"] = files
|
input_data["files"] = files
|
||||||
message_id = uuid.uuid4()
|
message_id = uuid.uuid4()
|
||||||
# 更新状态为运行中
|
# 更新状态为运行中
|
||||||
@@ -627,7 +649,9 @@ class WorkflowService:
|
|||||||
input_data=input_data,
|
input_data=input_data,
|
||||||
execution_id=execution.execution_id,
|
execution_id=execution.execution_id,
|
||||||
workspace_id=str(workspace_id),
|
workspace_id=str(workspace_id),
|
||||||
user_id=payload.user_id
|
user_id=payload.user_id,
|
||||||
|
memory_storage_type=storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id
|
||||||
)
|
)
|
||||||
# 更新执行结果
|
# 更新执行结果
|
||||||
if result.get("status") == "completed":
|
if result.get("status") == "completed":
|
||||||
@@ -776,6 +800,7 @@ class WorkflowService:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
files = await self._handle_file_input(payload.files)
|
files = await self._handle_file_input(payload.files)
|
||||||
|
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
|
||||||
input_data["files"] = files
|
input_data["files"] = files
|
||||||
self.update_execution_status(execution.execution_id, "running")
|
self.update_execution_status(execution.execution_id, "running")
|
||||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
||||||
@@ -797,6 +822,8 @@ class WorkflowService:
|
|||||||
execution_id=execution.execution_id,
|
execution_id=execution.execution_id,
|
||||||
workspace_id=str(workspace_id),
|
workspace_id=str(workspace_id),
|
||||||
user_id=payload.user_id,
|
user_id=payload.user_id,
|
||||||
|
memory_storage_type=storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id
|
||||||
):
|
):
|
||||||
if event.get("event") == "workflow_end":
|
if event.get("event") == "workflow_end":
|
||||||
status = event.get("data", {}).get("status")
|
status = event.get("data", {}).get("status")
|
||||||
|
|||||||
@@ -863,7 +863,7 @@ def get_workspace_storage_type(
|
|||||||
def get_workspace_storage_type_without_auth(
|
def get_workspace_storage_type_without_auth(
|
||||||
db: Session,
|
db: Session,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
) -> Optional[str]:
|
) -> str:
|
||||||
"""获取工作空间的存储类型(无需权限验证,用于公开分享等场景)
|
"""获取工作空间的存储类型(无需权限验证,用于公开分享等场景)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
Reference in New Issue
Block a user