Merge branch 'release/v0.2.3' into develop

This commit is contained in:
Mark
2026-02-04 13:52:45 +08:00
24 changed files with 527 additions and 62 deletions

View File

@@ -291,8 +291,10 @@ class LangChainAgent:
return messages
# TODO: 移到memory module
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
db = next(get_db())
#TODO: 魔法数字
scope=6
try:
@@ -302,6 +304,12 @@ class LangChainAgent:
from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id)
# Handle case where no session exists in Redis (returns False)
if not result or result is False:
logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update")
return
if type=="chunk" or type=="aggregate":
data = await format_parsing(result, "dict")
chunk_data = data[:scope]
@@ -309,7 +317,14 @@ class LangChainAgent:
repo.upsert(end_user_id, chunk_data)
logger.info(f'写入短长期:')
else:
# TODO: This branch handles type="time" strategy, currently unused.
# Will be activated when time-based long-term storage is implemented.
# TODO: 魔法数字 - extract 5 to a constant
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
# Handle case where no session exists in Redis (returns False or empty)
if not long_time_data or long_time_data is False:
logger.debug(f"No recent sessions in Redis for user {end_user_id}")
return
long_messages = await messages_parse(long_time_data)
repo.upsert(end_user_id, long_messages)
logger.info(f'写入短长期:')
@@ -509,9 +524,12 @@ class LangChainAgent:
elapsed_time = time.time() - start_time
if memory_flag:
long_term_messages=await agent_chat_messages(message_chat,content)
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
'''长期'''
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk")
response = {
"content": content,
@@ -695,9 +713,13 @@ class LangChainAgent:
yield total_tokens
break
if memory_flag:
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
long_term_messages = await agent_chat_messages(message_chat, full_content)
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk")
except Exception as e:

View File

@@ -43,6 +43,7 @@ async def write_messages(end_user_id,langchain_messages,memory_config):
for node_name, node_data in update_event.items():
if 'save_neo4j' == node_name:
massages = node_data
# TODO删除
massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result')
print(contents)
@@ -60,6 +61,7 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
scope窗口大小
'''
scope=scope
redis_messages = []
is_end_user_id = count_store.get_sessions_count(end_user_id)
if is_end_user_id is not False:
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
@@ -91,6 +93,9 @@ async def memory_long_term_storage(end_user_id,memory_config,time):
memory_config: 内存配置对象
'''
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
# Handle case where no session exists in Redis (returns False or empty)
if not long_time_data or long_time_data is False:
return
format_messages = await chat_data_format(long_time_data)
if format_messages!=[]:
await write_messages(end_user_id, format_messages, memory_config)
@@ -108,8 +113,9 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
try:
# 1. 获取历史会话数据(使用新方法)
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
history = await format_parsing(result)
if not result:
# Handle case where no session exists in Redis (returns False or empty)
if not result or result is False:
history = []
else:
history = await format_parsing(result)

View File

@@ -1,18 +1,14 @@
import asyncio
import json
import sys
import warnings
from contextlib import asynccontextmanager
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, chat_data_format, messages_parse
from app.db import get_db
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__)
@@ -40,27 +36,55 @@ async def make_write_graph():
yield graph
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
from app.core.memory.agent.utils.redis_tool import write_store
write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
# 获取数据库会话
db_session = next(get_db())
config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config(
config_id=memory_config, # 改为整数
service_name="MemoryAgentService"
"""Dispatch long-term memory storage to Celery background tasks.
Args:
long_term_type: Storage strategy - 'chunk' (window), 'time', or 'aggregate'
langchain_messages: List of messages to store
memory_config: Memory configuration ID (string)
end_user_id: End user identifier
scope: Window size for 'chunk' strategy (default: 6)
"""
from app.tasks import (
long_term_storage_window_task,
# TODO: Uncomment when implemented
# long_term_storage_time_task,
# long_term_storage_aggregate_task,
)
if long_term_type=='chunk':
'''方案一:对话窗口6轮对话'''
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
if long_term_type=='time':
"""时间"""
await memory_long_term_storage(end_user_id, memory_config,5)
if long_term_type=='aggregate':
"""方案三:聚合判断"""
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
from app.core.logging_config import get_logger
logger = get_logger(__name__)
# Convert config to string if needed
config_id = str(memory_config) if memory_config else ''
if long_term_type == 'chunk':
# Strategy 1: Window-based batching (6 rounds of dialogue)
logger.info(f"[LONG_TERM] Dispatching window task - end_user_id={end_user_id}, scope={scope}")
long_term_storage_window_task.delay(
end_user_id=end_user_id,
langchain_messages=langchain_messages,
config_id=config_id,
scope=scope
)
# TODO: Uncomment when time-based strategy is fully implemented
# elif long_term_type == 'time':
# # Strategy 2: Time-based retrieval
# logger.info(f"[LONG_TERM] Dispatching time task - end_user_id={end_user_id}")
# long_term_storage_time_task.delay(
# end_user_id=end_user_id,
# config_id=config_id,
# time_window=5
# )
# TODO: Uncomment when aggregate strategy is fully implemented
# elif long_term_type == 'aggregate':
# # Strategy 3: Aggregate judgment (deduplication)
# logger.info(f"[LONG_TERM] Dispatching aggregate task - end_user_id={end_user_id}")
# long_term_storage_aggregate_task.delay(
# end_user_id=end_user_id,
# langchain_messages=langchain_messages,
# config_id=config_id
# )
# async def main():

View File

@@ -2,6 +2,7 @@ import base64
import json
import logging
import re
import urllib.parse
from string import Template
from textwrap import dedent
from typing import Any