refactor(memory): integrate unified memory service into agent controller

- Replace direct memory agent service calls with unified MemoryService in read endpoint
- Update query preprocessor to use new prompt format and return structured queries
- Enhance MemorySearchResult model with filtering, merging, and ID tracking capabilities
- Add intermediate outputs display for problem split, perceptual retrieval, and search results
- Fix parameter alignment and remove unused history parameter in memory agent service
This commit is contained in:
Eternity
2026-04-20 17:43:52 +08:00
parent 749cf79581
commit 688503a1ca
14 changed files with 372 additions and 319 deletions

View File

@@ -15,13 +15,14 @@ from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.agent.agent_middleware import AgentMiddleware
from app.core.agent.langchain_agent import LangChainAgent
from app.core.config import settings
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.memory.enums import SearchStrategy
from app.core.memory.memory_service import MemoryService
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.models import AgentConfig, ModelConfig
@@ -29,10 +30,8 @@ from app.repositories.tool_repository import ToolRepository
from app.schemas.app_schema import FileInput, Citation
from app.schemas.model_schema import ModelInfo
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services import task_service
from app.services.conversation_service import ConversationService
from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger
from app.services.model_service import ModelApiKeyService
from app.services.multimodal_service import MultimodalService
@@ -107,38 +106,41 @@ def create_long_term_memory_tool(
logger.info(f" 长期记忆工具被调用question={question}, user={end_user_id}")
try:
with get_db_context() as db:
memory_content = asyncio.run(
MemoryAgentService().read_memory(
end_user_id=end_user_id,
message=question,
history=[],
search_switch="2",
config_id=config_id,
db=db,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
)
task = celery_app.send_task(
"app.core.memory.agent.read_message",
args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id]
)
result = task_service.get_task_memory_read_result(task.id)
status = result.get("status")
logger.info(f"读取任务状态:{status}")
if memory_content:
memory_content = memory_content['answer']
logger.info(f'用户IDAgent:{end_user_id}')
logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})
memory_service = MemoryService(db, config_id, end_user_id)
search_result = asyncio.run(memory_service.read(question, SearchStrategy.QUICK))
logger.info(
"长期记忆检索成功",
extra={
"end_user_id": end_user_id,
"content_length": len(str(memory_content))
}
)
return f"检索到以下历史记忆:\n\n{memory_content}"
# memory_content = asyncio.run(
# MemoryAgentService().read_memory(
# end_user_id=end_user_id,
# message=question,
# history=[],
# search_switch="2",
# config_id=config_id,
# db=db,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# )
# )
# task = celery_app.send_task(
# "app.core.memory.agent.read_message",
# args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id]
# )
# result = task_service.get_task_memory_read_result(task.id)
# status = result.get("status")
# logger.info(f"读取任务状态:{status}")
# if memory_content:
# memory_content = memory_content['answer']
# logger.info(f'用户IDAgent:{end_user_id}')
# logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})
#
# logger.info(
# "长期记忆检索成功",
# extra={
# "end_user_id": end_user_id,
# "content_length": len(str(memory_content))
# }
# )
return f"检索到以下历史记忆:\n\n{search_result.content}"
except Exception as e:
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})
return f"记忆检索失败: {str(e)}"

View File

@@ -405,7 +405,7 @@ class MemoryAgentService:
self,
end_user_id: str,
message: str,
history: List[Dict],
history: List[Dict], # FIXME: unused parameter
search_switch: str,
config_id: Optional[uuid.UUID] | int,
db: Session,
@@ -505,8 +505,8 @@ class MemoryAgentService:
initial_state = {
"messages": [HumanMessage(content=message)],
"search_switch": search_switch,
"end_user_id": end_user_id
, "storage_type": storage_type,
"end_user_id": end_user_id,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"memory_config": memory_config}
# 获取节点更新信息
@@ -642,6 +642,8 @@ class MemoryAgentService:
"answer": summary,
"intermediate_outputs": result
}
# TODO: redis search -> answer
except Exception as e:
# Ensure proper error handling and logging
error_msg = f"Read operation failed: {str(e)}"

View File

@@ -163,7 +163,7 @@ class MemoryConfigService:
def load_memory_config(
self,
config_id: Optional[UUID] = None,
config_id: UUID | str | int | None = None,
workspace_id: Optional[UUID] = None,
service_name: str = "MemoryConfigService",
) -> MemoryConfig:
@@ -187,16 +187,6 @@ class MemoryConfigService:
"""
start_time = time.time()
config_logger.info(
"Starting memory configuration loading",
extra={
"operation": "load_memory_config",
"service": service_name,
"config_id": str(config_id) if config_id else None,
"workspace_id": str(workspace_id) if workspace_id else None,
},
)
logger.info(f"Loading memory configuration from database: config_id={config_id}, workspace_id={workspace_id}")
try:
@@ -236,11 +226,7 @@ class MemoryConfigService:
f"Configuration not found: config_id={config_id}, workspace_id={workspace_id}"
)
# Get workspace for the config
db_query_start = time.time()
result = MemoryConfigRepository.get_config_with_workspace(self.db, memory_config.config_id)
db_query_time = time.time() - db_query_start
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
if not result:
raise ConfigurationError(