Fix/memory bug fix (#171)

This commit is contained in:
lixinyue11
2026-01-26 11:53:34 +08:00
committed by GitHub
parent 714c624dc6
commit 3601737869
119 changed files with 1711 additions and 1695 deletions

View File

@@ -9,6 +9,7 @@ import os
import re
import time
import uuid
from uuid import UUID
from typing import Any, AsyncGenerator, Dict, List, Optional
import redis
@@ -27,6 +28,7 @@ from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.memory_short_repository import ShortTermMemoryRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_agent_schema import Write_UserInput
from app.schemas.memory_config_schema import ConfigurationError
@@ -35,6 +37,7 @@ from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import (
write_rag,
)
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field
from sqlalchemy import func
@@ -54,25 +57,24 @@ _neo4j_connector = Neo4jConnector()
class MemoryAgentService:
"""Service for memory agent operations"""
def writer_messages_deal(self, messages, start_time, group_id, config_id, message, context):
def writer_messages_deal(self, messages, start_time, end_user_id, config_id, message, context):
duration = time.time() - start_time
if str(messages) == 'success':
logger.info(f"Write operation successful for group {group_id} with config_id {config_id}")
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
# 记录成功的操作
if audit_logger:
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=True,
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=True,
duration=duration, details={"message_length": len(message)})
return context
else:
logger.warning(f"Write operation failed for group {group_id}")
logger.warning(f"Write operation failed for group {end_user_id}")
# 记录失败的操作
if audit_logger:
audit_logger.log_operation(
operation="WRITE",
config_id=config_id,
group_id=group_id,
end_user_id=end_user_id,
success=False,
duration=duration,
error=f"写入失败: {messages[:100]}"
@@ -263,13 +265,13 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, group_id: str, messages: list[dict], config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
"""
Process write operation with config_id
Args:
group_id: Group identifier (also used as end_user_id)
messages: Structured message list [{"role": "user", "content": "..."}, ...]
end_user_id: Group identifier (also used as end_user_id)
message: Message to write
config_id: Configuration ID from database
db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag)
@@ -284,15 +286,15 @@ class MemoryAgentService:
# Resolve config_id if None using end_user's connected config
if config_id is None:
try:
connected_config = get_end_user_connected_config(group_id, db)
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is None:
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
raise # Re-raise our specific error
logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
import time
start_time = time.time()
@@ -312,7 +314,7 @@ class MemoryAgentService:
# Log failed operation
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
@@ -320,24 +322,23 @@ class MemoryAgentService:
if storage_type == "rag":
# For RAG storage, convert messages to single string
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
result = await write_rag(group_id, message_text, user_rag_memory_id)
result = await write_rag(end_user_id, message_text, user_rag_memory_id)
return result
else:
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": group_id}}
config = {"configurable": {"thread_id": end_user_id}}
# Convert structured messages to LangChain messages
langchain_messages = []
for msg in messages:
if msg['role'] == 'user':
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
from langchain_core.messages import AIMessage
langchain_messages.append(AIMessage(content=msg['content']))
# 初始状态 - 包含所有必要字段
initial_state = {
"messages": langchain_messages,
"group_id": group_id,
"end_user_id": end_user_id,
"memory_config": memory_config
}
@@ -354,14 +355,14 @@ class MemoryAgentService:
contents = massages.get('write_result')
# Convert messages back to string for logging
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message_text, contents)
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
except Exception as e:
# Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}"
logger.error(error_msg)
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
@@ -369,15 +370,14 @@ class MemoryAgentService:
async def read_memory(
self,
group_id: str,
end_user_id: str,
message: str,
history: List[Dict],
search_switch: str,
config_id: Optional[str],
config_id: Optional[UUID],
db: Session,
storage_type: str,
user_rag_memory_id: str
) -> Dict:
user_rag_memory_id: str) -> Dict:
"""
Process read operation with config_id
@@ -387,7 +387,7 @@ class MemoryAgentService:
- "2": Direct answer based on context
Args:
group_id: Group identifier (also used as end_user_id)
end_user_id: Group identifier (also used as end_user_id)
message: User message
history: Conversation history
search_switch: Search mode switch
@@ -405,22 +405,22 @@ class MemoryAgentService:
import time
start_time = time.time()
logger.info(f"[PERF] read_memory started for group_id={group_id}, search_switch={search_switch}")
ori_message= message
# Resolve config_id if None using end_user's connected config
if config_id is None:
try:
config_id = get_end_user_connected_config(group_id, db)
config_id=config_id.get('memory_config_id')
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is None:
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
logger.info(f"Read operation for group {group_id} with config_id {config_id}")
logger.info(f"Read operation for group {end_user_id} with config_id {config_id}")
# 导入审计日志记录器
try:
@@ -448,7 +448,7 @@ class MemoryAgentService:
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
end_user_id=end_user_id,
success=False,
duration=duration,
error=error_msg
@@ -458,16 +458,16 @@ class MemoryAgentService:
# Step 2: Prepare history
history.append({"role": "user", "content": message})
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
logger.debug(f"Group ID:{end_user_id}, Message:{message}, History:{history}, Config ID:{config_id}")
# Step 3: Initialize MCP client and execute read workflow
graph_exec_start = time.time()
try:
async with make_read_graph() as graph:
config = {"configurable": {"thread_id": group_id}}
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
"group_id": group_id
"end_user_id": end_user_id
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
"memory_config": memory_config}
# 获取节点更新信息
@@ -562,13 +562,13 @@ class MemoryAgentService:
if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2":
# 使用 upsert 方法
repo.upsert(
end_user_id=group_id,
messages=message,
end_user_id=end_user_id,
messages=ori_message,
aimessages=summary,
retrieved_content=retrieved_content,
search_switch=str(search_switch)
)
logger.info(f"成功保存短期记忆: group_id={group_id}, search_switch={search_switch}")
logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}")
else:
logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
@@ -584,7 +584,7 @@ class MemoryAgentService:
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
end_user_id=end_user_id,
success=True,
duration=duration
)
@@ -596,20 +596,20 @@ class MemoryAgentService:
except Exception as e:
# Ensure proper error handling and logging
error_msg = f"Read operation failed: {str(e)}"
total_time = time.time() - start_time
logger.error(f"[PERF] read_memory failed after {total_time:.4f}s: {error_msg}")
logger.error(error_msg)
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
end_user_id=end_user_id,
success=False,
duration=duration,
error=error_msg
)
raise ValueError(error_msg)
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
"""
Get standardized message list from user input.
@@ -654,7 +654,7 @@ class MemoryAgentService:
logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
return user_input.messages
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
async def classify_message_type(self, message: str, config_id: UUID, db: Session) -> Dict:
"""
Determine the type of user message (read or write)
Updated to eliminate global variables in favor of explicit parameters.
@@ -681,10 +681,9 @@ class MemoryAgentService:
status = await status_typle(message, memory_config.llm_model_id)
logger.debug(f"Message type: {status}")
return status
async def generate_summary_from_retrieve(
self,
group_id: str,
end_user_id: str,
retrieve_info: str,
history: List[Dict],
query: str,
@@ -708,16 +707,16 @@ class MemoryAgentService:
"""
if config_id is None:
try:
config_id = get_end_user_connected_config(group_id, db)
config_id = get_end_user_connected_config(end_user_id, db)
config_id = config_id.get('memory_config_id')
if config_id is None:
raise ValueError(
f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
logger.info(f"Generating summary from retrieve info for query: {query[:50]}...")
try:
@@ -727,6 +726,7 @@ class MemoryAgentService:
config_id=config_id,
service_name="MemoryAgentService"
)
# 导入必要的模块
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import summary_llm
from app.core.memory.agent.models.summary_models import RetrieveSummaryResponse
@@ -766,7 +766,7 @@ class MemoryAgentService:
"""
统计知识库类型分布,包含:
1. PostgreSQL 中的知识库类型General, Web, Third-party, Folder根据 workspace_id 过滤)
2. Neo4j 中的 memory 类型(仅统计 Chunk 数量,根据 end_user_id/group_id 过滤)
2. Neo4j 中的 memory 类型(仅统计 Chunk 数量,根据 end_user_id/end_user_id 过滤)
3. total: 所有类型的总和
参数:
@@ -852,11 +852,11 @@ class MemoryAgentService:
for end_user in end_users:
end_user_id_str = str(end_user.id)
memory_query = """
MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN count(n) AS Count
MATCH (n:Chunk) WHERE n.end_user_id = $end_user_id RETURN count(n) AS Count
"""
neo4j_result = await _neo4j_connector.execute_query(
memory_query,
group_id=end_user_id_str,
end_user_id=end_user_id_str,
)
chunk_count = neo4j_result[0]["Count"] if neo4j_result else 0
total_chunks += chunk_count
@@ -896,7 +896,7 @@ class MemoryAgentService:
获取指定用户的热门记忆标签
参数:
- end_user_id: 用户ID可选对应Neo4j中的group_id字段
- end_user_id: 用户ID可选对应Neo4j中的end_user_id字段
- limit: 返回标签数量限制
返回格式:
@@ -906,7 +906,7 @@ class MemoryAgentService:
]
"""
try:
# by_user=False 表示按 group_id 查询在Neo4j中group_id就是用户维度
# by_user=False 表示按 end_user_id 查询在Neo4j中end_user_id就是用户维度
tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False)
payload=[]
for tag, freq in tags:
@@ -981,21 +981,21 @@ class MemoryAgentService:
# 查询该用户的语句
query = (
"MATCH (s:Statement) "
"WHERE ($group_id IS NULL OR s.group_id = $group_id) AND s.statement IS NOT NULL "
"WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id) AND s.statement IS NOT NULL "
"RETURN s.statement AS statement "
"ORDER BY s.created_at DESC LIMIT 100"
)
rows = await connector.execute_query(query, group_id=end_user_id)
rows = await connector.execute_query(query, end_user_id=end_user_id)
statements = [r.get("statement", "") for r in rows if r.get("statement")]
# 查询该用户的热门实体
entity_query = (
"MATCH (e:ExtractedEntity) "
"WHERE ($group_id IS NULL OR e.group_id = $group_id) AND e.entity_type <> '人物' AND e.name IS NOT NULL "
"WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) AND e.entity_type <> '人物' AND e.name IS NOT NULL "
"RETURN e.name AS name, count(e) AS frequency "
"ORDER BY frequency DESC LIMIT 20"
)
entity_rows = await connector.execute_query(entity_query, group_id=end_user_id)
entity_rows = await connector.execute_query(entity_query, end_user_id=end_user_id)
entities = [f"{r['name']} ({r['frequency']})" for r in entity_rows]
await connector.close()
@@ -1048,14 +1048,14 @@ class MemoryAgentService:
names_to_exclude = ['AI', 'Caroline', 'Melanie', 'Jon', 'Gina', '用户', 'AI助手', 'John', 'Maria']
hot_tag_query = (
"MATCH (e:ExtractedEntity) "
"WHERE ($group_id IS NULL OR e.group_id = $group_id) AND e.entity_type <> '人物' "
"WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) AND e.entity_type <> '人物' "
"AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
"RETURN e.name AS name, count(e) AS frequency "
"ORDER BY frequency DESC LIMIT 4"
)
hot_tag_rows = await connector.execute_query(
hot_tag_query,
group_id=end_user_id,
end_user_id=end_user_id,
names_to_exclude=names_to_exclude
)
await connector.close()
@@ -1189,6 +1189,16 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
# 3. 从 config 中提取 memory_config_id
config = latest_release.config or {}
# 如果 config 是字符串,解析为字典
if isinstance(config, str):
import json
try:
config = json.loads(config)
except json.JSONDecodeError:
logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
config = {}
memory_obj = config.get('memory', {})
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
@@ -1227,7 +1237,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
"""
from app.models.app_release_model import AppRelease
from app.models.end_user_model import EndUser
from app.models.data_config_model import DataConfig
from app.models.memory_config_model import MemoryConfig
from sqlalchemy import select
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users")
@@ -1240,10 +1250,10 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 1. 批量查询所有 end_user 及其 app_id
end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()
# 创建 end_user_id -> app_id 的映射
user_to_app = {str(eu.id): eu.app_id for eu in end_users}
# 记录未找到的用户
found_user_ids = set(user_to_app.keys())
missing_user_ids = set(end_user_ids) - found_user_ids
@@ -1285,13 +1295,13 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 批量查询 memory_config_name
config_id_to_name = {}
if memory_config_ids:
memory_configs = db.query(DataConfig).filter(DataConfig.config_id.in_(memory_config_ids)).all()
config_id_to_name = {str(mc.config_id): mc.config_name for mc in memory_configs}
memory_configs = db.query(MemoryConfig).filter(MemoryConfig.id.in_(memory_config_ids)).all()
config_id_to_name = {str(mc.id): mc.config_name for mc in memory_configs}
# 4. 构建最终结果
for end_user_id, app_id in user_to_app.items():
release = app_to_release.get(app_id)
if not release:
logger.warning(f"No active release found for app: {app_id} (end_user: {end_user_id})")
result[end_user_id] = {"memory_config_id": None, "memory_config_name": None}
@@ -1303,7 +1313,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
# 获取配置名称
memory_config_name = config_id_to_name.get(str(memory_config_id)) if memory_config_id else None
memory_config_name = config_id_to_name.get(memory_config_id) if memory_config_id else None
result[end_user_id] = {
"memory_config_id": memory_config_id,