Fix/memory bug fix (#171)

This commit is contained in:
lixinyue11
2026-01-26 11:53:34 +08:00
committed by GitHub
parent 714c624dc6
commit 3601737869
119 changed files with 1711 additions and 1695 deletions

View File

@@ -4,6 +4,7 @@ import os
import re
import time
import uuid
from uuid import UUID
from datetime import datetime, timezone
from math import ceil
from typing import Any, Dict, List, Optional
@@ -382,16 +383,16 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
@celery_app.task(name="app.core.memory.agent.read_message", bind=True)
def read_message_task(self, group_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]:
def read_message_task(self, end_user_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: str, storage_type:str, user_rag_memory_id:str) -> Dict[str, Any]:
"""Celery task to process a read message via MemoryAgentService.
Args:
group_id: Group ID for the memory agent (also used as end_user_id)
end_user_id: End user ID for the memory agent (replaces the former group_id)
message: User message to process
history: Conversation history
search_switch: Search switch parameter
config_id: Optional configuration ID
config_id: Configuration ID as string (will be converted to UUID)
Returns:
Dict containing the result and metadata
@@ -401,14 +402,22 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
"""
start_time = time.time()
# Convert config_id string to UUID
actual_config_id = None
if config_id:
try:
actual_config_id = uuid.UUID(config_id) if isinstance(config_id, str) else config_id
except (ValueError, AttributeError):
# If conversion fails, leave as None and try to resolve
pass
# Resolve config_id if None
actual_config_id = config_id
if actual_config_id is None:
try:
from app.services.memory_agent_service import get_end_user_connected_config
db = next(get_db())
try:
connected_config = get_end_user_connected_config(group_id, db)
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
finally:
db.close()
@@ -420,24 +429,42 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
db = next(get_db())
try:
service = MemoryAgentService()
return await service.read_memory(group_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id)
return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id)
finally:
db.close()
try:
result = asyncio.run(_run())
# 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time
return {
"status": "SUCCESS",
"result": result,
"group_id": group_id,
"end_user_id": end_user_id,
"config_id": config_id,
"elapsed_time": elapsed_time,
"task_id": self.request.id
}
except BaseException as e:
elapsed_time = time.time() - start_time
# Handle ExceptionGroup from TaskGroup
if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages)
@@ -446,7 +473,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
return {
"status": "FAILURE",
"error": detailed_error,
"group_id": group_id,
"end_user_id": end_user_id,
"config_id": config_id,
"elapsed_time": elapsed_time,
"task_id": self.request.id
@@ -454,19 +481,13 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
@celery_app.task(name="app.core.memory.agent.write_message", bind=True)
def write_message_task(self, group_id: str, message, config_id: str, storage_type: str, user_rag_memory_id: str) -> Dict[str, Any]:
def write_message_task(self, end_user_id: str, message: str, config_id: str, storage_type:str, user_rag_memory_id:str) -> Dict[str, Any]:
"""Celery task to process a write message via MemoryAgentService.
支持两种消息格式:
1. 字符串格式向后兼容message="user: xxx\nassistant: yyy"
2. 结构化消息列表推荐message=[{"role": "user", "content": "xxx"}, {"role": "assistant", "content": "yyy"}]
Args:
group_id: Group ID for the memory agent (also used as end_user_id)
message: Message to write (str or list[dict])
config_id: Optional configuration ID
storage_type: Storage type (neo4j/rag)
user_rag_memory_id: RAG memory ID
end_user_id: End user ID for the memory agent (replaces the former group_id)
message: Message to write
config_id: Configuration ID as string (will be converted to UUID)
Returns:
Dict containing the result and metadata
@@ -477,30 +498,46 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ
from app.core.logging_config import get_logger
logger = get_logger(__name__)
logger.info(f"[CELERY WRITE] Starting write task - group_id={group_id}, config_id={config_id}, storage_type={storage_type}")
logger.info(f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}")
start_time = time.time()
# Convert config_id string to UUID
actual_config_id = None
if config_id:
try:
actual_config_id = uuid.UUID(config_id) if isinstance(config_id, str) else config_id
logger.info(f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})")
except (ValueError, AttributeError) as e:
logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id}, error: {e}")
return {
"status": "FAILURE",
"error": f"Invalid config_id format: {config_id}",
"end_user_id": end_user_id,
"config_id": config_id,
"elapsed_time": 0.0,
"task_id": self.request.id
}
# Resolve config_id if None
actual_config_id = config_id
if actual_config_id is None:
try:
from app.services.memory_agent_service import get_end_user_connected_config
db = next(get_db())
try:
connected_config = get_end_user_connected_config(group_id, db)
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
finally:
db.close()
except Exception:
# Best-effort: swallow the error (nothing is logged here) and continue - will fail later with proper error
pass
async def _run() -> str:
db = next(get_db())
try:
logger.info(f"[CELERY WRITE] Executing MemoryAgentService.write_memory")
logger.info(f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__})")
service = MemoryAgentService()
result = await service.write_memory(group_id, message, actual_config_id, db, storage_type, user_rag_memory_id)
result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type, user_rag_memory_id)
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
return result
except Exception as e:
@@ -510,7 +547,24 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ
db.close()
try:
result = asyncio.run(_run())
# 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time
logger.info(f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
@@ -518,13 +572,14 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ
return {
"status": "SUCCESS",
"result": result,
"group_id": group_id,
"end_user_id": end_user_id,
"config_id": config_id,
"elapsed_time": elapsed_time,
"task_id": self.request.id
}
except BaseException as e:
elapsed_time = time.time() - start_time
# Handle ExceptionGroup from TaskGroup
if hasattr(e, 'exceptions'):
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
detailed_error = "; ".join(error_messages)
@@ -536,7 +591,7 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ
return {
"status": "FAILURE",
"error": detailed_error,
"group_id": group_id,
"end_user_id": end_user_id,
"config_id": config_id,
"elapsed_time": elapsed_time,
"task_id": self.request.id
@@ -878,7 +933,24 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
}
try:
result = asyncio.run(_run())
# 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time
result["elapsed_time"] = elapsed_time
result["task_id"] = self.request.id
@@ -951,7 +1023,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
end_users = data['end_users']
for base, config, user in zip(releases, data_configs, end_users):
if int(base['config']) == int(config['config_id']) and base['app_id'] == user['app_id']:
if str(base['config']) == str(config['config_id']) and str(base['app_id']) == str(user['app_id']):
# 调用反思服务
api_logger.info(f"为用户 {user['id']} 启动反思config_id: {config['config_id']}")
@@ -1005,7 +1077,24 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
}
try:
result = asyncio.run(_run())
# 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time
result["elapsed_time"] = elapsed_time
result["task_id"] = self.request.id
@@ -1023,7 +1112,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
@celery_app.task(name="app.tasks.run_forgetting_cycle_task", bind=True)
def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str, Any]:
def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
"""定时任务:运行遗忘周期
定期执行遗忘周期,识别并融合低激活值的知识节点。
@@ -1051,7 +1140,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str
# 运行遗忘周期
report = await forget_service.trigger_forgetting(
db=db,
group_id=None, # 处理所有组
end_user_id=None, # 处理所有组
config_id=config_id
)
@@ -1081,4 +1170,11 @@ def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str
"duration_seconds": duration
}
return asyncio.run(_run())
# 运行异步函数
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(_run())
return result
finally:
loop.close()