feat(memory): add long-term storage task routing and batching
This commit is contained in:
288
api/app/tasks.py
288
api/app/tasks.py
@@ -1066,6 +1066,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务")
|
||||
|
||||
except Exception as e:
|
||||
db.rollback() # Rollback failed transaction to allow next query
|
||||
api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}")
|
||||
all_reflection_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
@@ -1204,3 +1205,290 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
||||
return result
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Long-term Memory Storage Tasks (Batched Write Strategies)
|
||||
# =============================================================================
|
||||
|
||||
@celery_app.task(name="app.core.memory.agent.long_term_storage.window", bind=True)
|
||||
def long_term_storage_window_task(
|
||||
self,
|
||||
end_user_id: str,
|
||||
langchain_messages: List[Dict[str, Any]],
|
||||
config_id: str,
|
||||
scope: int = 6
|
||||
) -> Dict[str, Any]:
|
||||
"""Celery task for window-based long-term memory storage.
|
||||
|
||||
Accumulates messages in Redis buffer until window size (scope) is reached,
|
||||
then writes batched messages to Neo4j.
|
||||
|
||||
Args:
|
||||
end_user_id: End user identifier
|
||||
langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}]
|
||||
config_id: Memory configuration ID
|
||||
scope: Window size (number of messages before triggering write)
|
||||
|
||||
Returns:
|
||||
Dict containing task status and metadata
|
||||
"""
|
||||
from app.core.logging_config import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
logger.info(f"[LONG_TERM_WINDOW] Starting task - end_user_id={end_user_id}, scope={scope}")
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import window_dialogue
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
# Save to Redis buffer first
|
||||
write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
|
||||
|
||||
# Load memory config
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="LongTermStorageTask"
|
||||
)
|
||||
|
||||
# Execute window-based dialogue storage
|
||||
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||
|
||||
return {"status": "SUCCESS", "strategy": "window", "scope": scope}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
logger.info(f"[LONG_TERM_WINDOW] Task completed - elapsed_time={elapsed_time:.2f}s")
|
||||
|
||||
return {
|
||||
**result,
|
||||
"end_user_id": end_user_id,
|
||||
"config_id": config_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"[LONG_TERM_WINDOW] Task failed - error={str(e)}", exc_info=True)
|
||||
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"strategy": "window",
|
||||
"error": str(e),
|
||||
"end_user_id": end_user_id,
|
||||
"config_id": config_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
|
||||
|
||||
# @celery_app.task(name="app.core.memory.agent.long_term_storage.time", bind=True)
|
||||
# def long_term_storage_time_task(
|
||||
# self,
|
||||
# end_user_id: str,
|
||||
# config_id: str,
|
||||
# time_window: int = 5
|
||||
# ) -> Dict[str, Any]:
|
||||
# """Celery task for time-based long-term memory storage.
|
||||
|
||||
# Retrieves recent sessions from Redis within time window and writes to Neo4j.
|
||||
|
||||
# Args:
|
||||
# end_user_id: End user identifier
|
||||
# config_id: Memory configuration ID
|
||||
# time_window: Time window in minutes for retrieving recent sessions
|
||||
|
||||
# Returns:
|
||||
# Dict containing task status and metadata
|
||||
# """
|
||||
# from app.core.logging_config import get_logger
|
||||
# logger = get_logger(__name__)
|
||||
|
||||
# logger.info(f"[LONG_TERM_TIME] Starting task - end_user_id={end_user_id}, time_window={time_window}")
|
||||
# start_time = time.time()
|
||||
|
||||
# async def _run() -> Dict[str, Any]:
|
||||
# from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage
|
||||
# from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# db = next(get_db())
|
||||
# try:
|
||||
# # Load memory config
|
||||
# config_service = MemoryConfigService(db)
|
||||
# memory_config = config_service.load_memory_config(
|
||||
# config_id=config_id,
|
||||
# service_name="LongTermStorageTask"
|
||||
# )
|
||||
|
||||
# # Execute time-based storage
|
||||
# await memory_long_term_storage(end_user_id, memory_config, time_window)
|
||||
|
||||
# return {"status": "SUCCESS", "strategy": "time", "time_window": time_window}
|
||||
# finally:
|
||||
# db.close()
|
||||
|
||||
# try:
|
||||
# import nest_asyncio
|
||||
# nest_asyncio.apply()
|
||||
# except ImportError:
|
||||
# pass
|
||||
|
||||
# try:
|
||||
# loop = asyncio.get_event_loop()
|
||||
# if loop.is_closed():
|
||||
# loop = asyncio.new_event_loop()
|
||||
# asyncio.set_event_loop(loop)
|
||||
# except RuntimeError:
|
||||
# loop = asyncio.new_event_loop()
|
||||
# asyncio.set_event_loop(loop)
|
||||
|
||||
# try:
|
||||
# result = loop.run_until_complete(_run())
|
||||
# elapsed_time = time.time() - start_time
|
||||
|
||||
# logger.info(f"[LONG_TERM_TIME] Task completed - elapsed_time={elapsed_time:.2f}s")
|
||||
|
||||
# return {
|
||||
# **result,
|
||||
# "end_user_id": end_user_id,
|
||||
# "config_id": config_id,
|
||||
# "elapsed_time": elapsed_time,
|
||||
# "task_id": self.request.id
|
||||
# }
|
||||
# except Exception as e:
|
||||
# elapsed_time = time.time() - start_time
|
||||
# logger.error(f"[LONG_TERM_TIME] Task failed - error={str(e)}", exc_info=True)
|
||||
|
||||
# return {
|
||||
# "status": "FAILURE",
|
||||
# "strategy": "time",
|
||||
# "error": str(e),
|
||||
# "end_user_id": end_user_id,
|
||||
# "config_id": config_id,
|
||||
# "elapsed_time": elapsed_time,
|
||||
# "task_id": self.request.id
|
||||
# }
|
||||
|
||||
|
||||
# @celery_app.task(name="app.core.memory.agent.long_term_storage.aggregate", bind=True)
|
||||
# def long_term_storage_aggregate_task(
|
||||
# self,
|
||||
# end_user_id: str,
|
||||
# langchain_messages: List[Dict[str, Any]],
|
||||
# config_id: str
|
||||
# ) -> Dict[str, Any]:
|
||||
# """Celery task for aggregate-based long-term memory storage.
|
||||
|
||||
# Uses LLM to determine if new messages describe the same event as history.
|
||||
# Only writes to Neo4j if messages represent new information (not duplicates).
|
||||
|
||||
# Args:
|
||||
# end_user_id: End user identifier
|
||||
# langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}]
|
||||
# config_id: Memory configuration ID
|
||||
|
||||
# Returns:
|
||||
# Dict containing task status, is_same_event flag, and metadata
|
||||
# """
|
||||
# from app.core.logging_config import get_logger
|
||||
# logger = get_logger(__name__)
|
||||
|
||||
# logger.info(f"[LONG_TERM_AGGREGATE] Starting task - end_user_id={end_user_id}")
|
||||
# start_time = time.time()
|
||||
|
||||
# async def _run() -> Dict[str, Any]:
|
||||
# from app.core.memory.agent.langgraph_graph.routing.write_router import aggregate_judgment
|
||||
# from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
|
||||
# from app.core.memory.agent.utils.redis_tool import write_store
|
||||
# from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# db = next(get_db())
|
||||
# try:
|
||||
# # Save to Redis buffer first
|
||||
# write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
|
||||
|
||||
# # Load memory config
|
||||
# config_service = MemoryConfigService(db)
|
||||
# memory_config = config_service.load_memory_config(
|
||||
# config_id=config_id,
|
||||
# service_name="LongTermStorageTask"
|
||||
# )
|
||||
|
||||
# # Execute aggregate judgment
|
||||
# result = await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||
|
||||
# return {
|
||||
# "status": "SUCCESS",
|
||||
# "strategy": "aggregate",
|
||||
# "is_same_event": result.get("is_same_event", False),
|
||||
# "wrote_to_neo4j": not result.get("is_same_event", False)
|
||||
# }
|
||||
# finally:
|
||||
# db.close()
|
||||
|
||||
# try:
|
||||
# import nest_asyncio
|
||||
# nest_asyncio.apply()
|
||||
# except ImportError:
|
||||
# pass
|
||||
|
||||
# try:
|
||||
# loop = asyncio.get_event_loop()
|
||||
# if loop.is_closed():
|
||||
# loop = asyncio.new_event_loop()
|
||||
# asyncio.set_event_loop(loop)
|
||||
# except RuntimeError:
|
||||
# loop = asyncio.new_event_loop()
|
||||
# asyncio.set_event_loop(loop)
|
||||
|
||||
# try:
|
||||
# result = loop.run_until_complete(_run())
|
||||
# elapsed_time = time.time() - start_time
|
||||
|
||||
# logger.info(f"[LONG_TERM_AGGREGATE] Task completed - is_same_event={result.get('is_same_event')}, elapsed_time={elapsed_time:.2f}s")
|
||||
|
||||
# return {
|
||||
# **result,
|
||||
# "end_user_id": end_user_id,
|
||||
# "config_id": config_id,
|
||||
# "elapsed_time": elapsed_time,
|
||||
# "task_id": self.request.id
|
||||
# }
|
||||
# except Exception as e:
|
||||
# elapsed_time = time.time() - start_time
|
||||
# logger.error(f"[LONG_TERM_AGGREGATE] Task failed - error={str(e)}", exc_info=True)
|
||||
|
||||
# return {
|
||||
# "status": "FAILURE",
|
||||
# "strategy": "aggregate",
|
||||
# "error": str(e),
|
||||
# "end_user_id": end_user_id,
|
||||
# "config_id": config_id,
|
||||
# "elapsed_time": elapsed_time,
|
||||
# "task_id": self.request.id
|
||||
# }
|
||||
|
||||
Reference in New Issue
Block a user