Merge pull request #916 from SuanmoSuanyangTechnology/refactor/memory_search

refactor(memory): consolidate search services and unify model client initialization
This commit is contained in:
Ke Sun
2026-04-21 19:01:22 +08:00
committed by GitHub
39 changed files with 1637 additions and 1577 deletions

View File

@@ -1,6 +1,8 @@
import re
from typing import Any
from app.core.memory.enums import SearchStrategy
from app.core.memory.memory_service import MemoryService
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
@@ -9,7 +11,6 @@ from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
from app.db import get_db_read
from app.schemas import FileInput
from app.services.memory_agent_service import MemoryAgentService
from app.tasks import write_message_task
@@ -32,16 +33,32 @@ class MemoryReadNode(BaseNode):
if not end_user_id:
raise RuntimeError("End user id is required")
return await MemoryAgentService().read_memory(
end_user_id=end_user_id,
message=self._render_template(self.typed_config.message, variable_pool),
config_id=self.typed_config.config_id,
search_switch=self.typed_config.search_switch,
history=[],
memory_service = MemoryService(
db=db,
storage_type=state["memory_storage_type"],
user_rag_memory_id=state["user_rag_memory_id"]
config_id=str(self.typed_config.config_id),
end_user_id=end_user_id,
user_rag_memory_id=state["user_rag_memory_id"],
)
search_result = await memory_service.read(
self._render_template(self.typed_config.message, variable_pool),
search_switch=SearchStrategy(self.typed_config.search_switch)
)
return {
"answer": search_result.content,
"intermediate_outputs": [_.model_dump() for _ in search_result.memories]
}
# return await MemoryAgentService().read_memory(
# end_user_id=end_user_id,
# message=self._render_template(self.typed_config.message, variable_pool),
# config_id=self.typed_config.config_id,
# search_switch=self.typed_config.search_switch,
# history=[],
# db=db,
# storage_type=state["memory_storage_type"],
# user_rag_memory_id=state["user_rag_memory_id"]
# )
class MemoryWriteNode(BaseNode):