From c1941809e95bd05c5c5b0bc50492facc8b885b08 Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Fri, 6 Feb 2026 11:42:02 +0800 Subject: [PATCH] Fix/develop memory bug (#336) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 遗漏的历史映射 * 遗漏的历史映射 * fix_timeline_memories * fix_timeline_memories * write_gragp/bug_fix * write_gragp/bug_fix * write_gragp/bug_fix * write_gragp/bug_fix * Multiple independent transactions - single transaction * memory_content ->memory_config_id * memory_content ->memory_config_id --- .../memory/agent/langgraph_graph/write_graph.py | 17 ++++++++++------- api/app/core/memory/agent/utils/write_tools.py | 2 +- api/app/repositories/neo4j/graph_saver.py | 2 +- api/app/schemas/app_schema.py | 2 +- api/app/services/app_service.py | 2 +- api/app/services/draft_run_service.py | 9 ++++++--- api/app/services/memory_agent_service.py | 9 ++++++--- api/app/services/memory_reflection_service.py | 5 +++-- api/app/services/shared_chat_service.py | 3 ++- 9 files changed, 31 insertions(+), 20 deletions(-) diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index 84ea9381..9b858f47 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -1,4 +1,3 @@ - import asyncio import sys import warnings @@ -15,6 +14,8 @@ logger = get_agent_logger(__name__) if sys.platform.startswith("win"): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + @asynccontextmanager async def make_write_graph(): """ @@ -35,9 +36,12 @@ async def make_write_graph(): graph = workflow.compile() yield graph -async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6): + + +async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '', + end_user_id: str = '', scope: int = 6): """Dispatch long-term memory storage to Celery background tasks. - + Args: long_term_type: Storage strategy - 'chunk' (window), 'time', or 'aggregate' langchain_messages: List of messages to store @@ -52,12 +56,12 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[ # long_term_storage_aggregate_task, ) from app.core.logging_config import get_logger - + logger = get_logger(__name__) - + # Convert config to string if needed config_id = str(memory_config) if memory_config else '' - + if long_term_type == 'chunk': # Strategy 1: Window-based batching (6 rounds of dialogue) logger.info(f"[LONG_TERM] Dispatching window task - end_user_id={end_user_id}, scope={scope}") @@ -86,7 +90,6 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[ # config_id=config_id # ) - # async def main(): # """主函数 - 运行工作流""" # langchain_messages = [ diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index e135d980..76a28156 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -174,4 +174,4 @@ async def write( f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n") logger.info("=== Pipeline Complete ===") - logger.info(f"Total execution time: {total_time:.2f} seconds") + logger.info(f"Total execution time: {total_time:.2f} seconds") \ No newline at end of file diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index 1575315f..fc32ca9a 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -207,4 +207,4 @@ async def save_dialog_and_statements_to_neo4j( except Exception as e: print(f"Neo4j integration error: {e}") print("Continuing without database storage...") - return False + return False \ No newline at end of file diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 02d897c5..2f94b69d 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -98,7 +98,7 @@ class ToolOldConfig(BaseModel): class MemoryConfig(BaseModel): """记忆配置""" enabled: bool = Field(default=True, description="是否启用对话历史记忆") - memory_content: Optional[str] = Field(default=None, description="选择记忆的内容类型") + memory_config_id: Optional[str] = Field(default=None, description="选择记忆的内容类型") max_history: int = Field(default=10, ge=0, le=100, description="最大保留的历史对话轮数") diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 5541ec80..8a2a0428 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -996,7 +996,7 @@ class AppService: }, memory={ "enabled": True, - "memory_content": None, + "memory_config_id": None, "max_history": 10 }, variables=[], diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 40ef4971..31662769 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -63,7 +63,8 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str 长期记忆工具 """ # search_switch = memory_config.get("search_switch", "2") - config_id= memory_config.get("memory_content") or memory_config.get("memory_config",None) + # 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content + config_id = memory_config.get("memory_config_id") or memory_config.get("memory_content", None) logger.info(f"创建长期记忆工具,配置: end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}") @tool(args_schema=LongTermMemoryInput) def long_term_memory(question: str) -> str: @@ -455,7 +456,8 @@ class DraftRunService: ) memory_config_= agent_config.memory - config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None) + # 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content + config_id = memory_config_.get("memory_config_id") or memory_config_.get("memory_content", None) # 8. 调用 Agent(支持多模态) result = await agent.chat( @@ -718,7 +720,8 @@ class DraftRunService: }) memory_config_ = agent_config.memory - config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None) + # 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content + config_id = memory_config_.get("memory_config_id") or memory_config_.get("memory_content", None) # 9. 流式调用 Agent(支持多模态) full_content = "" diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index a48d0072..9628950b 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -1199,7 +1199,8 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An config = {} memory_obj = config.get('memory', {}) - memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None + # 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content + memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None result = { "end_user_id": str(end_user_id), @@ -1289,7 +1290,8 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) if release: config = release.config or {} memory_obj = config.get('memory', {}) - memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None + # 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content + memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None if memory_config_id: # 判断是否为UUID格式 if len(str(memory_config_id))>=5: @@ -1335,7 +1337,8 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 从 config 中提取 memory_config_id config = release.config or {} memory_obj = config.get('memory', {}) - memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None + # 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content + memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None # 获取配置名称(使用字符串形式的ID进行查找,兼容新旧格式) memory_config_name = config_id_to_name.get(str(memory_config_id)) if memory_config_id else None diff --git a/api/app/services/memory_reflection_service.py b/api/app/services/memory_reflection_service.py index e025c1b3..0e542ff0 100644 --- a/api/app/services/memory_reflection_service.py +++ b/api/app/services/memory_reflection_service.py @@ -108,13 +108,14 @@ class WorkspaceAppService: app_info["releases"].append(release_info) def _extract_memory_content(self, config: Any) -> str: - """Extract memory_comtent from config""" + """Extract memory_config_id from config (兼容新旧字段名)""" if not config or not isinstance(config, dict): return None memory_obj = config.get('memory') if memory_obj and isinstance(memory_obj, dict): - return memory_obj.get('memory_content') + # 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content + return memory_obj.get('memory_config_id') or memory_obj.get('memory_content') return None diff --git a/api/app/services/shared_chat_service.py b/api/app/services/shared_chat_service.py index 6fa5961c..c7b81999 100644 --- a/api/app/services/shared_chat_service.py +++ b/api/app/services/shared_chat_service.py @@ -353,7 +353,8 @@ class SharedChatService: if variables is None: variables = {} - memory_config = {"enabled": memory, "memory_content": "17", "max_history": 10} + # 兼容新旧字段名:使用 memory_config_id + memory_config = {"enabled": memory, "memory_config_id": "17", "max_history": 10} try: # 获取发布版本和配置