fix(workflow): fix memory write behavior in RAG workspace

This commit is contained in:
Eternity
2026-03-20 18:31:17 +08:00
parent dce7206c44
commit 31085ed678
7 changed files with 128 additions and 17 deletions

View File

@@ -20,6 +20,7 @@ from app.core.workflow.variable.base_variable import FileObject
from app.db import get_db
from app.models import App
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
from app.repositories import knowledge_repository
from app.repositories.workflow_repository import (
WorkflowConfigRepository,
WorkflowExecutionRepository,
@@ -29,6 +30,7 @@ from app.schemas import DraftRunRequest, FileInput
from app.services.conversation_service import ConversationService
from app.services.multi_agent_service import convert_uuids_to_str
from app.services.multimodal_service import MultimodalService
from app.services.workspace_service import get_workspace_storage_type_without_auth
logger = logging.getLogger(__name__)
@@ -536,6 +538,25 @@ class WorkflowService:
mapped = internal_event
return mapped
def _get_memory_store_info(self, workspace_id: uuid.UUID) -> tuple[str, str]:
storage_type = get_workspace_storage_type_without_auth(self.db, workspace_id)
user_rag_memory_id = ""
if storage_type == "rag":
knowledge = knowledge_repository.get_knowledge_by_name(
db=self.db,
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge:
user_rag_memory_id = str(knowledge.id)
else:
logger.warning(
f"No knowledge base named 'USER_RAG_MEMORY' found, "
f"workspace_id: {workspace_id}, will use neo4j storage"
)
storage_type = 'neo4j'
return storage_type, user_rag_memory_id
# ==================== 工作流执行 ====================
async def run(
@@ -603,6 +624,7 @@ class WorkflowService:
try:
files = await self._handle_file_input(payload.files)
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
input_data["files"] = files
message_id = uuid.uuid4()
# 更新状态为运行中
@@ -627,7 +649,9 @@ class WorkflowService:
input_data=input_data,
execution_id=execution.execution_id,
workspace_id=str(workspace_id),
user_id=payload.user_id
user_id=payload.user_id,
memory_storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
# 更新执行结果
if result.get("status") == "completed":
@@ -776,6 +800,7 @@ class WorkflowService:
try:
files = await self._handle_file_input(payload.files)
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
input_data["files"] = files
self.update_execution_status(execution.execution_id, "running")
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
@@ -797,6 +822,8 @@ class WorkflowService:
execution_id=execution.execution_id,
workspace_id=str(workspace_id),
user_id=payload.user_id,
memory_storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
if event.get("event") == "workflow_end":
status = event.get("data", {}).get("status")

View File

@@ -863,7 +863,7 @@ def get_workspace_storage_type(
def get_workspace_storage_type_without_auth(
db: Session,
workspace_id: uuid.UUID,
) -> Optional[str]:
) -> str:
"""获取工作空间的存储类型(无需权限验证,用于公开分享等场景)
Args: