From 5694bc0230e6b8b1661cffc5d0df2505e163a5b3 Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Wed, 4 Feb 2026 12:27:14 +0800 Subject: [PATCH 1/7] fix(fix the key of the app's token): --- api/app/services/app_statistics_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/app/services/app_statistics_service.py b/api/app/services/app_statistics_service.py index 1b6bc3b8..5cfa3229 100644 --- a/api/app/services/app_statistics_service.py +++ b/api/app/services/app_statistics_service.py @@ -188,6 +188,6 @@ class AppStatisticsService: daily_tokens[date_str] += int(tokens) daily_data = [{"date": date, "count": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0] - total = sum(row["tokens"] for row in daily_data) + total = sum(row["count"] for row in daily_data) return {"daily": daily_data, "total": total} From bc36b791055b8892cb50c4c55c14144c8c295cac Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Wed, 4 Feb 2026 12:28:28 +0800 Subject: [PATCH 2/7] fix(workflow): switch code input encoding to base64+URL encoding --- api/app/core/workflow/nodes/code/node.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/api/app/core/workflow/nodes/code/node.py b/api/app/core/workflow/nodes/code/node.py index 019fec84..daee1e78 100644 --- a/api/app/core/workflow/nodes/code/node.py +++ b/api/app/core/workflow/nodes/code/node.py @@ -2,6 +2,7 @@ import base64 import json import logging import re +import urllib.parse from string import Template from textwrap import dedent from typing import Any @@ -101,6 +102,7 @@ class CodeNode(BaseNode): code = base64.b64decode( self.typed_config.code ).decode("utf-8") + code = urllib.parse.unquote(code, encoding='utf-8') input_variable_dict = base64.b64encode( json.dumps(input_variable_dict).encode("utf-8") From 88c95db8d0e6518de634700a0b91137a2a9671da Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 4 Feb 2026 12:19:00 +0800 Subject: [PATCH 3/7] [add]The main project adds multi-API Key load balancing. --- api/app/controllers/ontology_controller.py | 35 ++++++++++++++++++++-- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/api/app/controllers/ontology_controller.py b/api/app/controllers/ontology_controller.py index 43d3b1d2..895b6d40 100644 --- a/api/app/controllers/ontology_controller.py +++ b/api/app/controllers/ontology_controller.py @@ -190,19 +190,48 @@ def _get_ontology_service( detail="指定的LLM模型没有配置API密钥" ) - api_key_config = model_config.api_keys[0] + # 获取可用的 API Key(只选择激活状态的) + active_api_keys = [ak for ak in model_config.api_keys if ak.is_active] + if not active_api_keys: + logger.error(f"Model {llm_id} has no active API key") + raise HTTPException( + status_code=400, + detail="指定的LLM模型没有可用的API密钥" + ) + + # 对于组合模型,根据负载均衡策略选择 API Key + if model_config.is_composite and len(active_api_keys) > 1: + from app.models.models_model import LoadBalanceStrategy + if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: + # 轮询策略:选择使用次数最少的 API Key + api_key_config = min(active_api_keys, key=lambda x: int(x.usage_count or "0")) + else: + # 默认策略:按优先级选择 + api_key_config = min(active_api_keys, key=lambda x: int(x.priority or "1")) + logger.info( + f"Composite model using load balance strategy: {model_config.load_balance_strategy}, " + f"selected API Key: {api_key_config.id}, provider: {api_key_config.provider}" + ) + else: + api_key_config = active_api_keys[0] logger.info( f"Using specified model - user: {current_user.id}, " - f"model_id: {llm_id}, model_name: {api_key_config.model_name}" + f"model_id: {llm_id}, model_name: {api_key_config.model_name}, " + f"is_composite: {model_config.is_composite}" ) # 创建模型配置对象 from app.core.models.base import RedBearModelConfig + # 对于组合模型,使用 API Key 的 provider;否则使用 model_config 的 provider + actual_provider = api_key_config.provider if model_config.is_composite else ( + model_config.provider if hasattr(model_config, 'provider') else "openai" + ) + llm_model_config = RedBearModelConfig( model_name=api_key_config.model_name, - provider=model_config.provider if hasattr(model_config, 'provider') else "openai", + provider=actual_provider, api_key=api_key_config.api_key, base_url=api_key_config.api_base, max_retries=3, From ffff138a6ff4714e01feec563a23b631d016eec2 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 4 Feb 2026 12:28:05 +0800 Subject: [PATCH 4/7] [changes]Attribute security access, secure numerical conversion, unified use of local variables --- api/app/controllers/ontology_controller.py | 32 +++++++++++++++------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/api/app/controllers/ontology_controller.py b/api/app/controllers/ontology_controller.py index 895b6d40..1c22529b 100644 --- a/api/app/controllers/ontology_controller.py +++ b/api/app/controllers/ontology_controller.py @@ -191,7 +191,7 @@ def _get_ontology_service( ) # 获取可用的 API Key(只选择激活状态的) - active_api_keys = [ak for ak in model_config.api_keys if ak.is_active] + active_api_keys = [ak for ak in model_config.api_keys if getattr(ak, 'is_active', True)] if not active_api_keys: logger.error(f"Model {llm_id} has no active API key") raise HTTPException( @@ -199,17 +199,29 @@ def _get_ontology_service( detail="指定的LLM模型没有可用的API密钥" ) + # 安全的数值转换辅助函数 + def safe_int(value, default: int = 0) -> int: + """安全地将值转换为整数,异常时返回默认值""" + if value is None: + return default + try: + return int(value) + except (ValueError, TypeError): + return default + # 对于组合模型,根据负载均衡策略选择 API Key - if model_config.is_composite and len(active_api_keys) > 1: + is_composite = getattr(model_config, 'is_composite', False) + if is_composite and len(active_api_keys) > 1: from app.models.models_model import LoadBalanceStrategy - if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: + load_balance_strategy = getattr(model_config, 'load_balance_strategy', None) + if load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: # 轮询策略:选择使用次数最少的 API Key - api_key_config = min(active_api_keys, key=lambda x: int(x.usage_count or "0")) + api_key_config = min(active_api_keys, key=lambda x: safe_int(x.usage_count, 0)) else: - # 默认策略:按优先级选择 - api_key_config = min(active_api_keys, key=lambda x: int(x.priority or "1")) + # 默认策略:按优先级选择(优先级数值越小越优先) + api_key_config = min(active_api_keys, key=lambda x: safe_int(x.priority, 1)) logger.info( - f"Composite model using load balance strategy: {model_config.load_balance_strategy}, " + f"Composite model using load balance strategy: {load_balance_strategy}, " f"selected API Key: {api_key_config.id}, provider: {api_key_config.provider}" ) else: @@ -218,15 +230,15 @@ def _get_ontology_service( logger.info( f"Using specified model - user: {current_user.id}, " f"model_id: {llm_id}, model_name: {api_key_config.model_name}, " - f"is_composite: {model_config.is_composite}" + f"is_composite: {is_composite}" ) # 创建模型配置对象 from app.core.models.base import RedBearModelConfig # 对于组合模型,使用 API Key 的 provider;否则使用 model_config 的 provider - actual_provider = api_key_config.provider if model_config.is_composite else ( - model_config.provider if hasattr(model_config, 'provider') else "openai" + actual_provider = api_key_config.provider if is_composite else ( + getattr(model_config, 'provider', None) or "openai" ) llm_model_config = RedBearModelConfig( From 34f0c3b90c06227d9220e322df41a1e7d5b5feb2 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 4 Feb 2026 13:44:07 +0800 Subject: [PATCH 5/7] [changes]Active status filtering logic, API Key selection strategy --- api/app/controllers/ontology_controller.py | 25 +++++++++++----------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/api/app/controllers/ontology_controller.py b/api/app/controllers/ontology_controller.py index 1c22529b..8ad47bc1 100644 --- a/api/app/controllers/ontology_controller.py +++ b/api/app/controllers/ontology_controller.py @@ -190,15 +190,6 @@ def _get_ontology_service( detail="指定的LLM模型没有配置API密钥" ) - # 获取可用的 API Key(只选择激活状态的) - active_api_keys = [ak for ak in model_config.api_keys if getattr(ak, 'is_active', True)] - if not active_api_keys: - logger.error(f"Model {llm_id} has no active API key") - raise HTTPException( - status_code=400, - detail="指定的LLM模型没有可用的API密钥" - ) - # 安全的数值转换辅助函数 def safe_int(value, default: int = 0) -> int: """安全地将值转换为整数,异常时返回默认值""" @@ -209,9 +200,19 @@ def _get_ontology_service( except (ValueError, TypeError): return default - # 对于组合模型,根据负载均衡策略选择 API Key + # 获取可用的 API Key(只选择激活状态的) + # 注意:is_active 为 None 时视为非激活状态,避免旧记录被错误地当作激活 + active_api_keys = [ak for ak in model_config.api_keys if getattr(ak, 'is_active', None) is True] + if not active_api_keys: + logger.error(f"Model {llm_id} has no active API key") + raise HTTPException( + status_code=400, + detail="指定的LLM模型没有可用的API密钥" + ) + + # 根据负载均衡策略选择 API Key(组合模型和非组合模型统一处理) is_composite = getattr(model_config, 'is_composite', False) - if is_composite and len(active_api_keys) > 1: + if len(active_api_keys) > 1: from app.models.models_model import LoadBalanceStrategy load_balance_strategy = getattr(model_config, 'load_balance_strategy', None) if load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: @@ -221,7 +222,7 @@ def _get_ontology_service( # 默认策略:按优先级选择(优先级数值越小越优先) api_key_config = min(active_api_keys, key=lambda x: safe_int(x.priority, 1)) logger.info( - f"Composite model using load balance strategy: {load_balance_strategy}, " + f"Model (is_composite={is_composite}) using load balance strategy: {load_balance_strategy}, " f"selected API Key: {api_key_config.id}, provider: {api_key_config.provider}" ) else: From 333836f5e796cfdca9977fe19c8790bbf5ebc768 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 4 Feb 2026 14:08:09 +0800 Subject: [PATCH 6/7] [changes] --- api/app/controllers/ontology_controller.py | 46 +++------------------- 1 file changed, 6 insertions(+), 40 deletions(-) diff --git a/api/app/controllers/ontology_controller.py b/api/app/controllers/ontology_controller.py index 8ad47bc1..6520d835 100644 --- a/api/app/controllers/ontology_controller.py +++ b/api/app/controllers/ontology_controller.py @@ -182,56 +182,22 @@ def _get_ontology_service( detail=f"找不到指定的LLM模型: {llm_id}" ) - # 验证模型配置了API密钥 - if not model_config.api_keys: - logger.error(f"Model {llm_id} has no API key configuration") - raise HTTPException( - status_code=400, - detail="指定的LLM模型没有配置API密钥" - ) - - # 安全的数值转换辅助函数 - def safe_int(value, default: int = 0) -> int: - """安全地将值转换为整数,异常时返回默认值""" - if value is None: - return default - try: - return int(value) - except (ValueError, TypeError): - return default - - # 获取可用的 API Key(只选择激活状态的) - # 注意:is_active 为 None 时视为非激活状态,避免旧记录被错误地当作激活 - active_api_keys = [ak for ak in model_config.api_keys if getattr(ak, 'is_active', None) is True] - if not active_api_keys: + # 通过 Repository 获取可用的 API Key(负载均衡逻辑由 Repository 处理) + from app.repositories.model_repository import ModelApiKeyRepository + api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id) + if not api_keys: logger.error(f"Model {llm_id} has no active API key") raise HTTPException( status_code=400, detail="指定的LLM模型没有可用的API密钥" ) + api_key_config = api_keys[0] - # 根据负载均衡策略选择 API Key(组合模型和非组合模型统一处理) is_composite = getattr(model_config, 'is_composite', False) - if len(active_api_keys) > 1: - from app.models.models_model import LoadBalanceStrategy - load_balance_strategy = getattr(model_config, 'load_balance_strategy', None) - if load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: - # 轮询策略:选择使用次数最少的 API Key - api_key_config = min(active_api_keys, key=lambda x: safe_int(x.usage_count, 0)) - else: - # 默认策略:按优先级选择(优先级数值越小越优先) - api_key_config = min(active_api_keys, key=lambda x: safe_int(x.priority, 1)) - logger.info( - f"Model (is_composite={is_composite}) using load balance strategy: {load_balance_strategy}, " - f"selected API Key: {api_key_config.id}, provider: {api_key_config.provider}" - ) - else: - api_key_config = active_api_keys[0] - logger.info( f"Using specified model - user: {current_user.id}, " f"model_id: {llm_id}, model_name: {api_key_config.model_name}, " - f"is_composite: {is_composite}" + f"is_composite: {is_composite}, api_key_id: {api_key_config.id}" ) # 创建模型配置对象 From 8f0a1d9c6e18349e1d819c4d60b5861492f622de Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Wed, 4 Feb 2026 14:34:00 +0800 Subject: [PATCH 7/7] Fix/release memory bug (#306) * memory_BUG_fix * memory_BUG * memory_BUG_long_term * memory_BUG_long_term * memory_BUG_long_term --- api/app/core/agent/langchain_agent.py | 132 +------------ .../langgraph_graph/routing/write_router.py | 182 ++++++++++++------ .../agent/langgraph_graph/tools/write_tool.py | 34 +--- .../agent/langgraph_graph/write_graph.py | 105 +++++----- api/app/schemas/memory_agent_schema.py | 14 +- api/app/services/draft_run_service.py | 3 +- 6 files changed, 202 insertions(+), 268 deletions(-) diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 7e0015ae..e519ea53 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -7,30 +7,21 @@ LangChain Agent 封装 - 支持流式输出 - 使用 RedBearLLM 支持多提供商 """ -import os + import time from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence -from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages, format_parsing, messages_parse -from app.core.memory.agent.langgraph_graph.write_graph import long_term_storage +from app.core.memory.agent.langgraph_graph.write_graph import write_long_term from app.db import get_db from app.core.logging_config import get_business_logger -from app.core.memory.agent.utils.redis_tool import store from app.core.models import RedBearLLM, RedBearModelConfig from app.models.models_model import ModelType -from app.repositories.memory_short_repository import LongTermMemoryRepository from app.services.memory_agent_service import ( get_end_user_connected_config, ) -from app.services.memory_konwledges_server import write_rag -from app.services.task_service import get_task_memory_write_result -from app.tasks import write_message_task from langchain.agents import create_agent from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.tools import BaseTool - -from app.utils.config_utils import resolve_config_id - logger = get_business_logger() @@ -148,106 +139,6 @@ class LangChainAgent: messages.append(HumanMessage(content=user_content)) return messages - # TODO: 移到memory module - async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type): - db = next(get_db()) - #TODO: 魔法数字 - scope=6 - - try: - repo = LongTermMemoryRepository(db) - await long_term_storage(long_term_type="chunk", langchain_messages=long_term_messages, - memory_config=actual_config_id, end_user_id=end_user_id, scope=scope) - - from app.core.memory.agent.utils.redis_tool import write_store - result = write_store.get_session_by_userid(end_user_id) - - # Handle case where no session exists in Redis (returns False) - if not result or result is False: - logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update") - return - - if type=="chunk" or type=="aggregate": - data = await format_parsing(result, "dict") - chunk_data = data[:scope] - if len(chunk_data)==scope: - repo.upsert(end_user_id, chunk_data) - logger.info(f'写入短长期:') - else: - # TODO: This branch handles type="time" strategy, currently unused. - # Will be activated when time-based long-term storage is implemented. - # TODO: 魔法数字 - extract 5 to a constant - long_time_data = write_store.find_user_recent_sessions(end_user_id, 5) - # Handle case where no session exists in Redis (returns False or empty) - if not long_time_data or long_time_data is False: - logger.debug(f"No recent sessions in Redis for user {end_user_id}") - return - long_messages = await messages_parse(long_time_data) - repo.upsert(end_user_id, long_messages) - logger.info(f'写入短长期:') - finally: - db.close() - - async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id): - """ - 写入记忆(支持结构化消息) - - Args: - storage_type: 存储类型 (neo4j/rag) - end_user_id: 终端用户ID - user_message: 用户消息内容 - ai_message: AI 回复内容 - user_rag_memory_id: RAG 记忆ID - actual_end_user_id: 实际用户ID - actual_config_id: 配置ID - - 逻辑说明: - - RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变 - - Neo4j 模式:使用结构化消息列表 - 1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant] - 2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景) - 3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段 - """ - - db = next(get_db()) - try: - actual_config_id=resolve_config_id(actual_config_id, db) - - if storage_type == "rag": - # RAG 模式:组合消息为字符串格式(保持原有逻辑) - combined_message = f"user: {user_message}\nassistant: {ai_message}" - await write_rag(end_user_id, combined_message, user_rag_memory_id) - logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') - else: - # Neo4j 模式:使用结构化消息列表 - structured_messages = [] - - # 始终添加用户消息(如果不为空) - if user_message: - structured_messages.append({"role": "user", "content": user_message}) - - # 只有当 AI 回复不为空时才添加 assistant 消息 - if ai_message: - structured_messages.append({"role": "assistant", "content": ai_message}) - - # 如果没有消息,直接返回 - if not structured_messages: - logger.warning(f"No messages to write for user {actual_end_user_id}") - return - - logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") - write_id = write_message_task.delay( - actual_end_user_id, # end_user_id: 用户ID - structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] - actual_config_id, # config_id: 配置ID - storage_type, # storage_type: "neo4j" - user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用) - ) - logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") - write_status = get_task_memory_write_result(str(write_id)) - logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') - finally: - db.close() async def chat( self, message: str, @@ -321,14 +212,7 @@ class LangChainAgent: elapsed_time = time.time() - start_time if memory_flag: - long_term_messages=await agent_chat_messages(message_chat,content) - # TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable. - # This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j - # when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph. - # Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time. - await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id) - # Batched long-term memory storage (Redis buffer + Neo4j when window full) - await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk") + await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id) response = { "content": content, "model": self.model_name, @@ -459,15 +343,7 @@ class LangChainAgent: yield total_tokens break if memory_flag: - # TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable. - # This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j - # when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph. - # Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time. - long_term_messages = await agent_chat_messages(message_chat, full_content) - await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id) - # Batched long-term memory storage (Redis buffer + Neo4j when window full) - await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk") - + await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id) except Exception as e: logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) raise diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index e9de02b6..29257e88 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -1,8 +1,9 @@ +import json import os from app.core.logging_config import get_agent_logger -from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format, format_parsing -from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph +from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse +from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ @@ -10,46 +11,115 @@ from app.core.memory.agent.utils.redis_tool import write_store from app.core.memory.agent.utils.redis_tool import count_store from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.db import get_db_context +from app.db import get_db_context, get_db +from app.repositories.memory_short_repository import LongTermMemoryRepository +from app.schemas.memory_agent_schema import AgentMemory_Long_Term +from app.services.memory_konwledges_server import write_rag +from app.services.task_service import get_task_memory_write_result +from app.tasks import write_message_task +from app.utils.config_utils import resolve_config_id + logger = get_agent_logger(__name__) template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') +async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id): + # RAG 模式:组合消息为字符串格式(保持原有逻辑) + combined_message = f"user: {user_message}\nassistant: {ai_message}" + await write_rag(end_user_id, combined_message, user_rag_memory_id) + logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') +async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, + actual_config_id, long_term_messages=[]): + """ + 写入记忆(支持结构化消息) -async def write_messages(end_user_id,langchain_messages,memory_config): - ''' - 写入数据到neo4j: - Args: + Args: + storage_type: 存储类型 (neo4j/rag) end_user_id: 终端用户ID - memory_config: 内存配置对象 - langchain_messages:原始数据LIST - ''' + user_message: 用户消息内容 + ai_message: AI 回复内容 + user_rag_memory_id: RAG 记忆ID + actual_end_user_id: 实际用户ID + actual_config_id: 配置ID + + 逻辑说明: + - RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变 + - Neo4j 模式:使用结构化消息列表 + 1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant] + 2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景) + 3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段 + """ + + db = next(get_db()) try: + actual_config_id = resolve_config_id(actual_config_id, db) + # Neo4j 模式:使用结构化消息列表 + structured_messages = [] + + # 始终添加用户消息(如果不为空) + if isinstance(user_message, str) and user_message.strip() != "": + structured_messages.append({"role": "user", "content": user_message}) + + # 只有当 AI 回复不为空时才添加 assistant 消息 + if isinstance(ai_message, str) and ai_message.strip() != "": + structured_messages.append({"role": "assistant", "content": ai_message}) + + # 如果提供了 long_term_messages,使用它替代 structured_messages + if long_term_messages and isinstance(long_term_messages, list): + structured_messages = long_term_messages + elif long_term_messages and isinstance(long_term_messages, str): + # 如果是 JSON 字符串,先解析 + try: + structured_messages = json.loads(long_term_messages) + except json.JSONDecodeError: + logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}") + + # 如果没有消息,直接返回 + if not structured_messages: + logger.warning(f"No messages to write for user {actual_end_user_id}") + return + + logger.info( + f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") + write_id = write_message_task.delay( + actual_end_user_id, # end_user_id: 用户ID + structured_messages, # message: JSON 字符串格式的消息列表 + str(actual_config_id), # config_id: 配置ID字符串 + storage_type, # storage_type: "neo4j" + user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用) + ) + logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") + write_status = get_task_memory_write_result(str(write_id)) + logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') + finally: + db.close() + +async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope): + with get_db_context() as db_session: + try: + repo = LongTermMemoryRepository(db_session) + await long_term_storage(long_term_type=AgentMemory_Long_Term.STRATEGY_CHUNK, langchain_messages=long_term_messages, + memory_config=actual_config_id, end_user_id=end_user_id, scope=scope) + + from app.core.memory.agent.utils.redis_tool import write_store + result = write_store.get_session_by_userid(end_user_id) + if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: + data = await format_parsing(result, "dict") + chunk_data = data[:scope] + if len(chunk_data)==scope: + repo.upsert(end_user_id, chunk_data) + logger.info(f'---------写入短长期-----------') + else: + long_time_data = write_store.find_user_recent_sessions(end_user_id, 5) + long_messages = await messages_parse(long_time_data) + repo.upsert(end_user_id, long_messages) + logger.info(f'写入短长期:') + # yield db_session + finally: + if db_session.in_transaction(): + db_session.rollback() + db_session.close() - async with make_write_graph() as graph: - config = {"configurable": {"thread_id": end_user_id}} - # 初始状态 - 包含所有必要字段 - initial_state = { - "messages": langchain_messages, - "end_user_id": end_user_id, - "memory_config": memory_config - } - # 获取节点更新信息 - async for update_event in graph.astream( - initial_state, - stream_mode="updates", - config=config - ): - for node_name, node_data in update_event.items(): - if 'save_neo4j' == node_name: - massages = node_data - # TODO:删除 - massagesstatus = massages.get('write_result')['status'] - contents = massages.get('write_result') - print(contents) - except Exception as e: - import traceback - traceback.print_exc() '''根据窗口''' async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): ''' @@ -61,25 +131,26 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): scope:窗口大小 ''' scope=scope - redis_messages = [] is_end_user_id = count_store.get_sessions_count(end_user_id) if is_end_user_id is not False: is_end_user_id = count_store.get_sessions_count(end_user_id)[0] redis_messages = count_store.get_sessions_count(end_user_id)[1] if is_end_user_id and int(is_end_user_id) != int(scope): - print(is_end_user_id) is_end_user_id += 1 langchain_messages += redis_messages count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages) elif int(is_end_user_id) == int(scope): - print('写入长期记忆,并且设置为0') - print(is_end_user_id) - formatted_messages = await chat_data_format(redis_messages) - print(100*'-') - print(formatted_messages) - print(100*'-') - await write_messages(end_user_id, formatted_messages, memory_config) - count_store.update_sessions_count(end_user_id, 0, '') + logger.info('写入长期记忆NEO4J') + formatted_messages = (redis_messages) + # 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用) + if hasattr(memory_config, 'config_id'): + config_id = memory_config.config_id + else: + config_id = memory_config + + await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id, + config_id, formatted_messages) + count_store.update_sessions_count(end_user_id, 1, langchain_messages) else: count_store.save_sessions_count(end_user_id, 1, langchain_messages) @@ -93,12 +164,15 @@ async def memory_long_term_storage(end_user_id,memory_config,time): memory_config: 内存配置对象 ''' long_time_data = write_store.find_user_recent_sessions(end_user_id, time) - # Handle case where no session exists in Redis (returns False or empty) - if not long_time_data or long_time_data is False: - return - format_messages = await chat_data_format(long_time_data) + format_messages = (long_time_data) + messages=[] + memory_config=memory_config.config_id + for i in format_messages: + message=json.loads(i['Query']) + messages+= message if format_messages!=[]: - await write_messages(end_user_id, format_messages, memory_config) + await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id, + memory_config, messages) '''聚合判断''' async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict: """ @@ -109,13 +183,12 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] memory_config: 内存配置对象 """ - + try: # 1. 获取历史会话数据(使用新方法) result = write_store.get_all_sessions_by_end_user_id(end_user_id) - - # Handle case where no session exists in Redis (returns False or empty) - if not result or result is False: + history = await format_parsing(result) + if not result: history = [] else: history = await format_parsing(result) @@ -154,7 +227,8 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config } if not structured.is_same_event: logger.info(result_dict) - await write_messages(end_user_id, output_value, memory_config) + await write("neo4j", end_user_id, "", "", None, end_user_id, + memory_config.config_id, output_value) return result_dict except Exception as e: diff --git a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py index a1fb8226..d0be8e5c 100644 --- a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py +++ b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py @@ -26,13 +26,13 @@ async def format_parsing(messages: list,type:str='string'): role = content['role'] content = content['content'] if type == "string": - if role == 'human': + if role == 'human' or role=="user": content = '用户:' + content else: content = 'AI:' + content result.append(content) - if type == "dict": - if role == 'human': + if type == "dict" : + if role == 'human' or role=="user": user.append( content) else: ai.append(content) @@ -57,33 +57,7 @@ async def messages_parse(messages: list | dict): for key, values in zip(user, ai): database.append({key, values}) return database -async def chat_data_format(messages: list | dict): - """ - 将消息格式化为 LangChain 消息格式 - - Args: - messages: 消息列表或字典 - - Returns: - LangChain 消息列表 - """ - langchain_messages = [] - if isinstance(messages, list): - for msg in messages: - if 'role' in msg.keys(): - if msg['role'] == 'user': - langchain_messages.append(HumanMessage(content=msg['content'])) - elif msg['role'] == 'assistant': - langchain_messages.append(AIMessage(content=msg['content'])) - if "Query" in msg.keys(): - langchain_messages.append(HumanMessage(content=msg['Query'])) - langchain_messages.append(AIMessage(content=msg['Answer'])) - if isinstance(messages, dict): - if messages['type'] == 'human': - langchain_messages.append(HumanMessage(content=messages['content'])) - elif messages['type'] == 'ai': - langchain_messages.append(AIMessage(content=messages['content'])) - return langchain_messages + async def agent_chat_messages(user_content,ai_content): messages = [ diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index 84ea9381..c0e6f86e 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -1,14 +1,19 @@ import asyncio +import json import sys import warnings from contextlib import asynccontextmanager from langgraph.constants import END, START from langgraph.graph import StateGraph +from app.db import get_db, get_db_context from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node +from app.schemas.memory_agent_schema import AgentMemory_Long_Term +from app.services.memory_config_service import MemoryConfigService + warnings.filterwarnings("ignore", category=RuntimeWarning) logger = get_agent_logger(__name__) @@ -35,75 +40,67 @@ async def make_write_graph(): graph = workflow.compile() yield graph -async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6): - """Dispatch long-term memory storage to Celery background tasks. - - Args: - long_term_type: Storage strategy - 'chunk' (window), 'time', or 'aggregate' - langchain_messages: List of messages to store - memory_config: Memory configuration ID (string) - end_user_id: End user identifier - scope: Window size for 'chunk' strategy (default: 6) - """ - from app.tasks import ( - long_term_storage_window_task, - # TODO: Uncomment when implemented - # long_term_storage_time_task, - # long_term_storage_aggregate_task, - ) - from app.core.logging_config import get_logger - - logger = get_logger(__name__) - - # Convert config to string if needed - config_id = str(memory_config) if memory_config else '' - - if long_term_type == 'chunk': - # Strategy 1: Window-based batching (6 rounds of dialogue) - logger.info(f"[LONG_TERM] Dispatching window task - end_user_id={end_user_id}, scope={scope}") - long_term_storage_window_task.delay( - end_user_id=end_user_id, - langchain_messages=langchain_messages, - config_id=config_id, - scope=scope - ) - # TODO: Uncomment when time-based strategy is fully implemented - # elif long_term_type == 'time': - # # Strategy 2: Time-based retrieval - # logger.info(f"[LONG_TERM] Dispatching time task - end_user_id={end_user_id}") - # long_term_storage_time_task.delay( - # end_user_id=end_user_id, - # config_id=config_id, - # time_window=5 - # ) - # TODO: Uncomment when aggregate strategy is fully implemented - # elif long_term_type == 'aggregate': - # # Strategy 3: Aggregate judgment (deduplication) - # logger.info(f"[LONG_TERM] Dispatching aggregate task - end_user_id={end_user_id}") - # long_term_storage_aggregate_task.delay( - # end_user_id=end_user_id, - # langchain_messages=langchain_messages, - # config_id=config_id - # ) +async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6): + from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment + from app.core.memory.agent.utils.redis_tool import write_store + write_store.save_session_write(end_user_id, (langchain_messages)) + # 获取数据库会话 + with get_db_context() as db_session: + try: + config_service = MemoryConfigService(db_session) + memory_config = config_service.load_memory_config( + config_id=memory_config, # 改为整数 + service_name="MemoryAgentService" + ) + if long_term_type=='chunk': + '''方案一:对话窗口6轮对话''' + await window_dialogue(end_user_id,langchain_messages,memory_config,scope) + if long_term_type=='time': + """时间""" + await memory_long_term_storage(end_user_id, memory_config,5) + if long_term_type=='aggregate': + """方案三:聚合判断""" + await aggregate_judgment(end_user_id, langchain_messages, memory_config) + finally: + if db_session.in_transaction(): + db_session.rollback() + db_session.close() + + + +async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id): + from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent + from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save + from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages + if storage_type == AgentMemory_Long_Term.STORAGE_RAG: + await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id) + else: + # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK + SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE + long_term_messages = await agent_chat_messages(message_chat, aimessages) + await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages, + memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE) + await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE) # async def main(): # """主函数 - 运行工作流""" # langchain_messages = [ # { # "role": "user", -# "content": "今天周五好开心啊" +# "content": "今天周五去爬山" # }, # { # "role": "assistant", -# "content": "你也这么觉得,我也是耶" +# "content": "好耶" # } # # ] # end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID # memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4" -# # await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) -# result=await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) +# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) +# # # # if __name__ == "__main__": diff --git a/api/app/schemas/memory_agent_schema.py b/api/app/schemas/memory_agent_schema.py index b6f50dd7..1a5017eb 100644 --- a/api/app/schemas/memory_agent_schema.py +++ b/api/app/schemas/memory_agent_schema.py @@ -1,3 +1,4 @@ +from abc import ABC from typing import Optional from pydantic import BaseModel @@ -14,4 +15,15 @@ class UserInput(BaseModel): class Write_UserInput(BaseModel): messages: list[dict] end_user_id: str - config_id: Optional[str] = None \ No newline at end of file + config_id: Optional[str] = None + +class AgentMemory_Long_Term(ABC): + """长期记忆配置常量""" + STORAGE_NEO4J = "neo4j" + STORAGE_RAG = "rag" + STRATEGY_AGGREGATE = "aggregate" + STRATEGY_CHUNK = "chunk" + STRATEGY_TIME = "time" + DEFAULT_SCOPE = 6 + + diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 9a3e1d37..43073555 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -110,6 +110,8 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str result = task_service.get_task_memory_read_result(task.id) status = result.get("status") logger.info(f"读取任务状态:{status}") + if memory_content: + memory_content = memory_content['answer'] finally: db.close() @@ -123,7 +125,6 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str "content_length": len(str(memory_content)) } ) - return f"检索到以下历史记忆:\n\n{memory_content}" except Exception as e: logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})