fix(workflow): fix memory write behavior in RAG workspace
This commit is contained in:
@@ -20,6 +20,7 @@ from app.core.workflow.variable.base_variable import FileObject
|
||||
from app.db import get_db
|
||||
from app.models import App
|
||||
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
||||
from app.repositories import knowledge_repository
|
||||
from app.repositories.workflow_repository import (
|
||||
WorkflowConfigRepository,
|
||||
WorkflowExecutionRepository,
|
||||
@@ -29,6 +30,7 @@ from app.schemas import DraftRunRequest, FileInput
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.multi_agent_service import convert_uuids_to_str
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
from app.services.workspace_service import get_workspace_storage_type_without_auth
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -536,6 +538,25 @@ class WorkflowService:
|
||||
mapped = internal_event
|
||||
return mapped
|
||||
|
||||
def _get_memory_store_info(self, workspace_id: uuid.UUID) -> tuple[str, str]:
|
||||
storage_type = get_workspace_storage_type_without_auth(self.db, workspace_id)
|
||||
user_rag_memory_id = ""
|
||||
if storage_type == "rag":
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=self.db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
else:
|
||||
logger.warning(
|
||||
f"No knowledge base named 'USER_RAG_MEMORY' found, "
|
||||
f"workspace_id: {workspace_id}, will use neo4j storage"
|
||||
)
|
||||
storage_type = 'neo4j'
|
||||
return storage_type, user_rag_memory_id
|
||||
|
||||
# ==================== 工作流执行 ====================
|
||||
|
||||
async def run(
|
||||
@@ -603,6 +624,7 @@ class WorkflowService:
|
||||
|
||||
try:
|
||||
files = await self._handle_file_input(payload.files)
|
||||
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
|
||||
input_data["files"] = files
|
||||
message_id = uuid.uuid4()
|
||||
# 更新状态为运行中
|
||||
@@ -627,7 +649,9 @@ class WorkflowService:
|
||||
input_data=input_data,
|
||||
execution_id=execution.execution_id,
|
||||
workspace_id=str(workspace_id),
|
||||
user_id=payload.user_id
|
||||
user_id=payload.user_id,
|
||||
memory_storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
# 更新执行结果
|
||||
if result.get("status") == "completed":
|
||||
@@ -776,6 +800,7 @@ class WorkflowService:
|
||||
|
||||
try:
|
||||
files = await self._handle_file_input(payload.files)
|
||||
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
|
||||
input_data["files"] = files
|
||||
self.update_execution_status(execution.execution_id, "running")
|
||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
||||
@@ -797,6 +822,8 @@ class WorkflowService:
|
||||
execution_id=execution.execution_id,
|
||||
workspace_id=str(workspace_id),
|
||||
user_id=payload.user_id,
|
||||
memory_storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
if event.get("event") == "workflow_end":
|
||||
status = event.get("data", {}).get("status")
|
||||
|
||||
@@ -863,7 +863,7 @@ def get_workspace_storage_type(
|
||||
def get_workspace_storage_type_without_auth(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
) -> Optional[str]:
|
||||
) -> str:
|
||||
"""获取工作空间的存储类型(无需权限验证,用于公开分享等场景)
|
||||
|
||||
Args:
|
||||
|
||||
Reference in New Issue
Block a user