Fix/memory bug fix (#171)
This commit is contained in:
164
api/app/tasks.py
164
api/app/tasks.py
@@ -4,6 +4,7 @@ import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
from datetime import datetime, timezone
|
||||
from math import ceil
|
||||
from typing import Any, Dict, List, Optional
|
||||
@@ -382,16 +383,16 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.memory.agent.read_message", bind=True)
|
||||
def read_message_task(self, group_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]:
|
||||
def read_message_task(self, end_user_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: str, storage_type:str, user_rag_memory_id:str) -> Dict[str, Any]:
|
||||
|
||||
"""Celery task to process a read message via MemoryAgentService.
|
||||
|
||||
Args:
|
||||
group_id: Group ID for the memory agent (also used as end_user_id)
|
||||
end_user_id: End user ID for the memory agent (previously passed as group_id)
|
||||
message: User message to process
|
||||
history: Conversation history
|
||||
search_switch: Search switch parameter
|
||||
config_id: Optional configuration ID
|
||||
config_id: Configuration ID as string (will be converted to UUID)
|
||||
|
||||
Returns:
|
||||
Dict containing the result and metadata
|
||||
@@ -401,14 +402,22 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Convert config_id string to UUID
|
||||
actual_config_id = None
|
||||
if config_id:
|
||||
try:
|
||||
actual_config_id = uuid.UUID(config_id) if isinstance(config_id, str) else config_id
|
||||
except (ValueError, AttributeError):
|
||||
# If conversion fails, leave as None and try to resolve
|
||||
pass
|
||||
|
||||
# Resolve config_id if None
|
||||
actual_config_id = config_id
|
||||
if actual_config_id is None:
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
finally:
|
||||
db.close()
|
||||
@@ -420,24 +429,42 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
|
||||
db = next(get_db())
|
||||
try:
|
||||
service = MemoryAgentService()
|
||||
return await service.read_memory(group_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id)
|
||||
return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
try:
|
||||
result = asyncio.run(_run())
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"result": result,
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"config_id": config_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
except BaseException as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
# Handle ExceptionGroup from TaskGroup
|
||||
if hasattr(e, 'exceptions'):
|
||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
detailed_error = "; ".join(error_messages)
|
||||
@@ -446,7 +473,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": detailed_error,
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"config_id": config_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
@@ -454,19 +481,13 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.memory.agent.write_message", bind=True)
|
||||
def write_message_task(self, group_id: str, message, config_id: str, storage_type: str, user_rag_memory_id: str) -> Dict[str, Any]:
|
||||
def write_message_task(self, end_user_id: str, message: str, config_id: str, storage_type:str, user_rag_memory_id:str) -> Dict[str, Any]:
|
||||
"""Celery task to process a write message via MemoryAgentService.
|
||||
|
||||
支持两种消息格式:
|
||||
1. 字符串格式(向后兼容):message="user: xxx\nassistant: yyy"
|
||||
2. 结构化消息列表(推荐):message=[{"role": "user", "content": "xxx"}, {"role": "assistant", "content": "yyy"}]
|
||||
|
||||
Args:
|
||||
group_id: Group ID for the memory agent (also used as end_user_id)
|
||||
message: Message to write (str or list[dict])
|
||||
config_id: Optional configuration ID
|
||||
storage_type: Storage type (neo4j/rag)
|
||||
user_rag_memory_id: RAG memory ID
|
||||
end_user_id: End user ID for the memory agent (previously passed as group_id)
|
||||
message: Message to write
|
||||
config_id: Configuration ID as string (will be converted to UUID)
|
||||
|
||||
Returns:
|
||||
Dict containing the result and metadata
|
||||
@@ -477,30 +498,46 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ
|
||||
from app.core.logging_config import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
logger.info(f"[CELERY WRITE] Starting write task - group_id={group_id}, config_id={config_id}, storage_type={storage_type}")
|
||||
logger.info(f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}")
|
||||
start_time = time.time()
|
||||
|
||||
# Convert config_id string to UUID
|
||||
actual_config_id = None
|
||||
if config_id:
|
||||
try:
|
||||
actual_config_id = uuid.UUID(config_id) if isinstance(config_id, str) else config_id
|
||||
logger.info(f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})")
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id}, error: {e}")
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": f"Invalid config_id format: {config_id}",
|
||||
"end_user_id": end_user_id,
|
||||
"config_id": config_id,
|
||||
"elapsed_time": 0.0,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
|
||||
# Resolve config_id if None
|
||||
actual_config_id = config_id
|
||||
if actual_config_id is None:
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception:
|
||||
# Log but continue - will fail later with proper error
|
||||
pass
|
||||
|
||||
|
||||
async def _run() -> str:
|
||||
db = next(get_db())
|
||||
try:
|
||||
logger.info(f"[CELERY WRITE] Executing MemoryAgentService.write_memory")
|
||||
logger.info(f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__})")
|
||||
service = MemoryAgentService()
|
||||
result = await service.write_memory(group_id, message, actual_config_id, db, storage_type, user_rag_memory_id)
|
||||
result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type, user_rag_memory_id)
|
||||
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
@@ -510,7 +547,24 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ
|
||||
db.close()
|
||||
|
||||
try:
|
||||
result = asyncio.run(_run())
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
logger.info(f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
|
||||
@@ -518,13 +572,14 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"result": result,
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"config_id": config_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
except BaseException as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
# Handle ExceptionGroup from TaskGroup
|
||||
if hasattr(e, 'exceptions'):
|
||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
detailed_error = "; ".join(error_messages)
|
||||
@@ -536,7 +591,7 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": detailed_error,
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"config_id": config_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
@@ -878,7 +933,24 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
try:
|
||||
result = asyncio.run(_run())
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
result["elapsed_time"] = elapsed_time
|
||||
result["task_id"] = self.request.id
|
||||
@@ -951,7 +1023,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
end_users = data['end_users']
|
||||
|
||||
for base, config, user in zip(releases, data_configs, end_users):
|
||||
if int(base['config']) == int(config['config_id']) and base['app_id'] == user['app_id']:
|
||||
if str(base['config']) == str(config['config_id']) and str(base['app_id']) == str(user['app_id']):
|
||||
# 调用反思服务
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
|
||||
|
||||
@@ -1005,7 +1077,24 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
try:
|
||||
result = asyncio.run(_run())
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
result["elapsed_time"] = elapsed_time
|
||||
result["task_id"] = self.request.id
|
||||
@@ -1023,7 +1112,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
|
||||
|
||||
@celery_app.task(name="app.tasks.run_forgetting_cycle_task", bind=True)
|
||||
def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str, Any]:
|
||||
def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
|
||||
"""定时任务:运行遗忘周期
|
||||
|
||||
定期执行遗忘周期,识别并融合低激活值的知识节点。
|
||||
@@ -1051,7 +1140,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str
|
||||
# 运行遗忘周期
|
||||
report = await forget_service.trigger_forgetting(
|
||||
db=db,
|
||||
group_id=None, # 处理所有组
|
||||
end_user_id=None, # 处理所有组
|
||||
config_id=config_id
|
||||
)
|
||||
|
||||
@@ -1081,4 +1170,11 @@ def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str
|
||||
"duration_seconds": duration
|
||||
}
|
||||
|
||||
return asyncio.run(_run())
|
||||
# 运行异步函数
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
result = loop.run_until_complete(_run())
|
||||
return result
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
Reference in New Issue
Block a user