Fix/develop memory bug (#350)
* 遗漏的历史映射 * 遗漏的历史映射 * fix_timeline_memories * fix_timeline_memories * write_gragp/bug_fix * write_gragp/bug_fix * write_gragp/bug_fix * write_gragp/bug_fix * Multiple independent transactions - single transaction * memory_content ->memory_config_id * memory_content ->memory_config_id * memory_content ->memory_config_id * memory_content ->memory_config_id * memory_content ->memory_config_id * memory_content ->memory_config_id * memory_content ->memory_config_id * tasks/bug_fix/long * tasks_reflection/bug/fix * tasks_reflection/bug/fix * tasks_reflection/bug/fix * tasks_reflection/bug/fix
This commit is contained in:
@@ -364,6 +364,13 @@ class MemoryReflectionService:
|
|||||||
reflexion_range_value = config_data.get("reflexion_range")
|
reflexion_range_value = config_data.get("reflexion_range")
|
||||||
if reflexion_range_value is None or reflexion_range_value == "":
|
if reflexion_range_value is None or reflexion_range_value == "":
|
||||||
reflexion_range_value = "partial"
|
reflexion_range_value = "partial"
|
||||||
|
# Map legacy/invalid values to valid enum values
|
||||||
|
reflexion_range_mapping = {
|
||||||
|
"retrieval": "partial", # Map old 'retrieval' to 'partial'
|
||||||
|
"partial": "partial",
|
||||||
|
"all": "all"
|
||||||
|
}
|
||||||
|
reflexion_range_value = reflexion_range_mapping.get(reflexion_range_value, "partial")
|
||||||
reflexion_range = ReflectionRange(reflexion_range_value)
|
reflexion_range = ReflectionRange(reflexion_range_value)
|
||||||
|
|
||||||
baseline_value = config_data.get("baseline")
|
baseline_value = config_data.get("baseline")
|
||||||
|
|||||||
344
api/app/tasks.py
344
api/app/tasks.py
@@ -405,7 +405,7 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
|||||||
|
|
||||||
# 2. sync data
|
# 2. sync data
|
||||||
match db_knowledge.type:
|
match db_knowledge.type:
|
||||||
case "Web": # Crawl webpages in batches through a web crawler
|
case "Web": # Crawl webpages in batches through a web crawler
|
||||||
entry_url = db_knowledge.parser_config.get("entry_url", "")
|
entry_url = db_knowledge.parser_config.get("entry_url", "")
|
||||||
max_pages = db_knowledge.parser_config.get("max_pages", 20)
|
max_pages = db_knowledge.parser_config.get("max_pages", 20)
|
||||||
delay_seconds = db_knowledge.parser_config.get("delay_seconds", 1.0)
|
delay_seconds = db_knowledge.parser_config.get("delay_seconds", 1.0)
|
||||||
@@ -428,19 +428,21 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
|||||||
db_file = db.query(File).filter(File.kb_id == db_knowledge.id,
|
db_file = db.query(File).filter(File.kb_id == db_knowledge.id,
|
||||||
File.file_url == crawled_document.url).first()
|
File.file_url == crawled_document.url).first()
|
||||||
if db_file:
|
if db_file:
|
||||||
if db_file.file_size == crawled_document.content_length: # same
|
if db_file.file_size == crawled_document.content_length: # same
|
||||||
continue
|
continue
|
||||||
else: # --update
|
else: # --update
|
||||||
if crawled_document.content_length:
|
if crawled_document.content_length:
|
||||||
# 1. update file
|
# 1. update file
|
||||||
db_file.file_name = f"{crawled_document.title}.txt"
|
db_file.file_name = f"{crawled_document.title}.txt"
|
||||||
db_file.file_ext=".txt"
|
db_file.file_ext = ".txt"
|
||||||
db_file.file_size=crawled_document.content_length
|
db_file.file_size = crawled_document.content_length
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(db_file)
|
db.refresh(db_file)
|
||||||
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
||||||
save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id))
|
save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id),
|
||||||
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
|
str(db_knowledge.parent_id))
|
||||||
|
Path(save_dir).mkdir(parents=True,
|
||||||
|
exist_ok=True) # Ensure that the directory exists
|
||||||
save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}")
|
save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}")
|
||||||
# update file
|
# update file
|
||||||
if os.path.exists(save_path):
|
if os.path.exists(save_path):
|
||||||
@@ -460,7 +462,7 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
|||||||
db.refresh(db_document)
|
db.refresh(db_document)
|
||||||
# 3. Document parsing, vectorization, and storage
|
# 3. Document parsing, vectorization, and storage
|
||||||
parse_document(file_path=save_path, document_id=db_document.id)
|
parse_document(file_path=save_path, document_id=db_document.id)
|
||||||
else: # --add
|
else: # --add
|
||||||
if crawled_document.content_length:
|
if crawled_document.content_length:
|
||||||
# 1. upload file
|
# 1. upload file
|
||||||
upload_file = file_schema.FileCreate(
|
upload_file = file_schema.FileCreate(
|
||||||
@@ -507,8 +509,9 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
|||||||
db.commit()
|
db.commit()
|
||||||
# 3. Document parsing, vectorization, and storage
|
# 3. Document parsing, vectorization, and storage
|
||||||
parse_document(file_path=save_path, document_id=db_document.id)
|
parse_document(file_path=save_path, document_id=db_document.id)
|
||||||
db_files = db.query(File).filter(File.kb_id == db_knowledge.id, File.file_url.notin_(file_urls)).all()
|
db_files = db.query(File).filter(File.kb_id == db_knowledge.id,
|
||||||
if db_files: # --delete
|
File.file_url.notin_(file_urls)).all()
|
||||||
|
if db_files: # --delete
|
||||||
for db_file in db_files:
|
for db_file in db_files:
|
||||||
db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id,
|
db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id,
|
||||||
Document.file_id == db_file.id).first()
|
Document.file_id == db_file.id).first()
|
||||||
@@ -535,7 +538,7 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
|||||||
case "Third-party": # Integration of knowledge bases from three parties
|
case "Third-party": # Integration of knowledge bases from three parties
|
||||||
yuque_user_id = db_knowledge.parser_config.get("yuque_user_id", "")
|
yuque_user_id = db_knowledge.parser_config.get("yuque_user_id", "")
|
||||||
feishu_app_id = db_knowledge.parser_config.get("feishu_app_id", "")
|
feishu_app_id = db_knowledge.parser_config.get("feishu_app_id", "")
|
||||||
if yuque_user_id: # Yuque Knowledge Base
|
if yuque_user_id: # Yuque Knowledge Base
|
||||||
yuque_token = db_knowledge.parser_config.get("yuque_token", "")
|
yuque_token = db_knowledge.parser_config.get("yuque_token", "")
|
||||||
# Create yuqueAPIClient
|
# Create yuqueAPIClient
|
||||||
api_client = YuqueAPIClient(
|
api_client = YuqueAPIClient(
|
||||||
@@ -571,11 +574,14 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
|||||||
else: # --update
|
else: # --update
|
||||||
# 1. update file
|
# 1. update file
|
||||||
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
||||||
save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id))
|
save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id),
|
||||||
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
|
str(db_knowledge.parent_id))
|
||||||
|
Path(save_dir).mkdir(parents=True,
|
||||||
|
exist_ok=True) # Ensure that the directory exists
|
||||||
|
|
||||||
# download document from Feishu FileInfo
|
# download document from Feishu FileInfo
|
||||||
async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo, save_dir: str):
|
async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo,
|
||||||
|
save_dir: str):
|
||||||
async with api_client as client:
|
async with api_client as client:
|
||||||
file_path = await client.download_document(doc, save_dir)
|
file_path = await client.download_document(doc, save_dir)
|
||||||
return file_path
|
return file_path
|
||||||
@@ -613,11 +619,13 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
|||||||
else: # --add
|
else: # --add
|
||||||
# 1. update file
|
# 1. update file
|
||||||
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
||||||
save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id))
|
save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id),
|
||||||
|
str(db_knowledge.parent_id))
|
||||||
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
|
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
|
||||||
|
|
||||||
# download document from Feishu FileInfo
|
# download document from Feishu FileInfo
|
||||||
async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo, save_dir: str):
|
async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo,
|
||||||
|
save_dir: str):
|
||||||
async with api_client as client:
|
async with api_client as client:
|
||||||
file_path = await client.download_document(doc, save_dir)
|
file_path = await client.download_document(doc, save_dir)
|
||||||
return file_path
|
return file_path
|
||||||
@@ -697,7 +705,7 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n\nError during fetch feishu: {e}")
|
print(f"\n\nError during fetch feishu: {e}")
|
||||||
if feishu_app_id: # Feishu Knowledge Base
|
if feishu_app_id: # Feishu Knowledge Base
|
||||||
feishu_app_secret = db_knowledge.parser_config.get("feishu_app_secret", "")
|
feishu_app_secret = db_knowledge.parser_config.get("feishu_app_secret", "")
|
||||||
feishu_folder_token = db_knowledge.parser_config.get("feishu_folder_token", "")
|
feishu_folder_token = db_knowledge.parser_config.get("feishu_folder_token", "")
|
||||||
# Create feishuAPIClient
|
# Create feishuAPIClient
|
||||||
@@ -708,11 +716,13 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
|||||||
try:
|
try:
|
||||||
# 初始化存储获取飞书 URLs 的集合
|
# 初始化存储获取飞书 URLs 的集合
|
||||||
file_urls = set()
|
file_urls = set()
|
||||||
|
|
||||||
# Get all files from folder
|
# Get all files from folder
|
||||||
async def async_get_files(api_client: FeishuAPIClient, feishu_folder_token: str):
|
async def async_get_files(api_client: FeishuAPIClient, feishu_folder_token: str):
|
||||||
async with api_client as client:
|
async with api_client as client:
|
||||||
files = await client.list_all_folder_files(feishu_folder_token, recursive=True)
|
files = await client.list_all_folder_files(feishu_folder_token, recursive=True)
|
||||||
return files
|
return files
|
||||||
|
|
||||||
files = asyncio.run(async_get_files(api_client, feishu_folder_token))
|
files = asyncio.run(async_get_files(api_client, feishu_folder_token))
|
||||||
# Filter out folders, only sync documents
|
# Filter out folders, only sync documents
|
||||||
documents = [f for f in files if f.type in ["doc", "docx", "sheet", "bitable", "file"]]
|
documents = [f for f in files if f.type in ["doc", "docx", "sheet", "bitable", "file"]]
|
||||||
@@ -728,12 +738,16 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
|||||||
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
||||||
save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id),
|
save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id),
|
||||||
str(db_knowledge.parent_id))
|
str(db_knowledge.parent_id))
|
||||||
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
|
Path(save_dir).mkdir(parents=True,
|
||||||
|
exist_ok=True) # Ensure that the directory exists
|
||||||
|
|
||||||
# download document from Feishu FileInfo
|
# download document from Feishu FileInfo
|
||||||
async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo, save_dir: str):
|
async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo,
|
||||||
|
save_dir: str):
|
||||||
async with api_client as client:
|
async with api_client as client:
|
||||||
file_path = await client.download_document(document=doc, save_dir=save_dir)
|
file_path = await client.download_document(document=doc, save_dir=save_dir)
|
||||||
return file_path
|
return file_path
|
||||||
|
|
||||||
file_path = asyncio.run(async_download_document(api_client, doc, save_dir))
|
file_path = asyncio.run(async_download_document(api_client, doc, save_dir))
|
||||||
|
|
||||||
save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}")
|
save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}")
|
||||||
@@ -770,11 +784,14 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
|||||||
save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id),
|
save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id),
|
||||||
str(db_knowledge.parent_id))
|
str(db_knowledge.parent_id))
|
||||||
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
|
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
|
||||||
|
|
||||||
# download document from Feishu FileInfo
|
# download document from Feishu FileInfo
|
||||||
async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo, save_dir: str):
|
async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo,
|
||||||
|
save_dir: str):
|
||||||
async with api_client as client:
|
async with api_client as client:
|
||||||
file_path = await client.download_document(document=doc, save_dir=save_dir)
|
file_path = await client.download_document(document=doc, save_dir=save_dir)
|
||||||
return file_path
|
return file_path
|
||||||
|
|
||||||
file_path = asyncio.run(async_download_document(api_client, doc, save_dir))
|
file_path = asyncio.run(async_download_document(api_client, doc, save_dir))
|
||||||
# add db_file
|
# add db_file
|
||||||
file_name = os.path.basename(file_path)
|
file_name = os.path.basename(file_path)
|
||||||
@@ -788,7 +805,7 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
|||||||
file_ext=file_extension.lower(),
|
file_ext=file_extension.lower(),
|
||||||
file_size=file_size,
|
file_size=file_size,
|
||||||
file_url=doc.url,
|
file_url=doc.url,
|
||||||
created_at = doc.modified_time
|
created_at=doc.modified_time
|
||||||
)
|
)
|
||||||
db_file = File(**upload_file.model_dump())
|
db_file = File(**upload_file.model_dump())
|
||||||
db.add(db_file)
|
db.add(db_file)
|
||||||
@@ -853,7 +870,6 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
|||||||
case _: # General
|
case _: # General
|
||||||
print(f"General: No synchronization needed\n")
|
print(f"General: No synchronization needed\n")
|
||||||
|
|
||||||
|
|
||||||
result = f"sync knowledge '{db_knowledge.name}' processed successfully."
|
result = f"sync knowledge '{db_knowledge.name}' processed successfully."
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -866,8 +882,8 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
|||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="app.core.memory.agent.read_message", bind=True)
|
@celery_app.task(name="app.core.memory.agent.read_message", bind=True)
|
||||||
def read_message_task(self, end_user_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: str, storage_type:str, user_rag_memory_id:str) -> Dict[str, Any]:
|
def read_message_task(self, end_user_id: str, message: str, history: List[Dict[str, Any]], search_switch: str,
|
||||||
|
config_id: str, storage_type: str, user_rag_memory_id: str) -> Dict[str, Any]:
|
||||||
"""Celery task to process a read message via MemoryAgentService.
|
"""Celery task to process a read message via MemoryAgentService.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -876,15 +892,15 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
|
|||||||
history: Conversation history
|
history: Conversation history
|
||||||
search_switch: Search switch parameter
|
search_switch: Search switch parameter
|
||||||
config_id: Configuration ID as string (will be converted to UUID)
|
config_id: Configuration ID as string (will be converted to UUID)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict containing the result and metadata
|
Dict containing the result and metadata
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
Exception on failure
|
Exception on failure
|
||||||
"""
|
"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# Convert config_id string to UUID
|
# Convert config_id string to UUID
|
||||||
actual_config_id = None
|
actual_config_id = None
|
||||||
if config_id:
|
if config_id:
|
||||||
@@ -893,7 +909,7 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
|
|||||||
except (ValueError, AttributeError):
|
except (ValueError, AttributeError):
|
||||||
# If conversion fails, leave as None and try to resolve
|
# If conversion fails, leave as None and try to resolve
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Resolve config_id if None
|
# Resolve config_id if None
|
||||||
if actual_config_id is None:
|
if actual_config_id is None:
|
||||||
try:
|
try:
|
||||||
@@ -907,12 +923,13 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
|
|||||||
except Exception:
|
except Exception:
|
||||||
# Log but continue - will fail later with proper error
|
# Log but continue - will fail later with proper error
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def _run() -> str:
|
async def _run() -> str:
|
||||||
db = next(get_db())
|
db = next(get_db())
|
||||||
try:
|
try:
|
||||||
service = MemoryAgentService()
|
service = MemoryAgentService()
|
||||||
return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id)
|
return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db,
|
||||||
|
storage_type, user_rag_memory_id)
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
@@ -923,7 +940,7 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
|
|||||||
nest_asyncio.apply()
|
nest_asyncio.apply()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
@@ -933,10 +950,10 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
|
|||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
result = loop.run_until_complete(_run())
|
result = loop.run_until_complete(_run())
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "SUCCESS",
|
"status": "SUCCESS",
|
||||||
"result": result,
|
"result": result,
|
||||||
@@ -964,9 +981,10 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
|
|||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="app.core.memory.agent.write_message", bind=True)
|
@celery_app.task(name="app.core.memory.agent.write_message", bind=True)
|
||||||
def write_message_task(self, end_user_id: str, message: str, config_id: str, storage_type:str, user_rag_memory_id:str, language: str = "zh") -> Dict[str, Any]:
|
def write_message_task(self, end_user_id: str, message: str, config_id: str, storage_type: str, user_rag_memory_id: str,
|
||||||
|
language: str = "zh") -> Dict[str, Any]:
|
||||||
"""Celery task to process a write message via MemoryAgentService.
|
"""Celery task to process a write message via MemoryAgentService.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id: Group ID for the memory agent (also used as end_user_id)
|
end_user_id: Group ID for the memory agent (also used as end_user_id)
|
||||||
message: Message to write
|
message: Message to write
|
||||||
@@ -974,25 +992,27 @@ def write_message_task(self, end_user_id: str, message: str, config_id: str, sto
|
|||||||
storage_type: Storage type (neo4j or rag)
|
storage_type: Storage type (neo4j or rag)
|
||||||
user_rag_memory_id: User RAG memory ID
|
user_rag_memory_id: User RAG memory ID
|
||||||
language: 语言类型 ("zh" 中文, "en" 英文)
|
language: 语言类型 ("zh" 中文, "en" 英文)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict containing the result and metadata
|
Dict containing the result and metadata
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
Exception on failure
|
Exception on failure
|
||||||
"""
|
"""
|
||||||
from app.core.logging_config import get_logger
|
from app.core.logging_config import get_logger
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
logger.info(f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}, language={language}")
|
logger.info(
|
||||||
|
f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}, language={language}")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# Convert config_id string to UUID
|
# Convert config_id string to UUID
|
||||||
actual_config_id = None
|
actual_config_id = None
|
||||||
if config_id:
|
if config_id:
|
||||||
try:
|
try:
|
||||||
actual_config_id = uuid.UUID(config_id) if isinstance(config_id, str) else config_id
|
actual_config_id = uuid.UUID(config_id) if isinstance(config_id, str) else config_id
|
||||||
logger.info(f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})")
|
logger.info(
|
||||||
|
f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})")
|
||||||
except (ValueError, AttributeError) as e:
|
except (ValueError, AttributeError) as e:
|
||||||
logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id}, error: {e}")
|
logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id}, error: {e}")
|
||||||
return {
|
return {
|
||||||
@@ -1003,7 +1023,7 @@ def write_message_task(self, end_user_id: str, message: str, config_id: str, sto
|
|||||||
"elapsed_time": 0.0,
|
"elapsed_time": 0.0,
|
||||||
"task_id": self.request.id
|
"task_id": self.request.id
|
||||||
}
|
}
|
||||||
|
|
||||||
# Resolve config_id if None
|
# Resolve config_id if None
|
||||||
if actual_config_id is None:
|
if actual_config_id is None:
|
||||||
try:
|
try:
|
||||||
@@ -1021,9 +1041,11 @@ def write_message_task(self, end_user_id: str, message: str, config_id: str, sto
|
|||||||
async def _run() -> str:
|
async def _run() -> str:
|
||||||
db = next(get_db())
|
db = next(get_db())
|
||||||
try:
|
try:
|
||||||
logger.info(f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
|
logger.info(
|
||||||
|
f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
|
||||||
service = MemoryAgentService()
|
service = MemoryAgentService()
|
||||||
result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type, user_rag_memory_id, language)
|
result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type,
|
||||||
|
user_rag_memory_id, language)
|
||||||
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
|
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1039,7 +1061,7 @@ def write_message_task(self, end_user_id: str, message: str, config_id: str, sto
|
|||||||
nest_asyncio.apply()
|
nest_asyncio.apply()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
@@ -1049,12 +1071,13 @@ def write_message_task(self, end_user_id: str, message: str, config_id: str, sto
|
|||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
result = loop.run_until_complete(_run())
|
result = loop.run_until_complete(_run())
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
logger.info(f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
|
logger.info(
|
||||||
|
f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "SUCCESS",
|
"status": "SUCCESS",
|
||||||
"result": result,
|
"result": result,
|
||||||
@@ -1071,9 +1094,10 @@ def write_message_task(self, end_user_id: str, message: str, config_id: str, sto
|
|||||||
detailed_error = "; ".join(error_messages)
|
detailed_error = "; ".join(error_messages)
|
||||||
else:
|
else:
|
||||||
detailed_error = str(e)
|
detailed_error = str(e)
|
||||||
|
|
||||||
logger.error(f"[CELERY WRITE] Task failed - elapsed_time={elapsed_time:.2f}s, error={detailed_error}", exc_info=True)
|
logger.error(f"[CELERY WRITE] Task failed - elapsed_time={elapsed_time:.2f}s, error={detailed_error}",
|
||||||
|
exc_info=True)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "FAILURE",
|
"status": "FAILURE",
|
||||||
"error": detailed_error,
|
"error": detailed_error,
|
||||||
@@ -1100,16 +1124,17 @@ def reflection_engine() -> None:
|
|||||||
@celery_app.task(name="app.core.memory.agent.reflection.timer")
|
@celery_app.task(name="app.core.memory.agent.reflection.timer")
|
||||||
def reflection_timer_task() -> None:
|
def reflection_timer_task() -> None:
|
||||||
"""Periodic Celery task that invokes reflection_engine.
|
"""Periodic Celery task that invokes reflection_engine.
|
||||||
|
|
||||||
Raises an exception on failure.
|
Raises an exception on failure.
|
||||||
"""
|
"""
|
||||||
reflection_engine()
|
reflection_engine()
|
||||||
|
|
||||||
|
|
||||||
# unused task
|
# unused task
|
||||||
# @celery_app.task(name="app.core.memory.agent.health.check_read_service")
|
# @celery_app.task(name="app.core.memory.agent.health.check_read_service")
|
||||||
# def check_read_service_task() -> Dict[str, str]:
|
# def check_read_service_task() -> Dict[str, str]:
|
||||||
# """Call read_service and write latest status to Redis.
|
# """Call read_service and write latest status to Redis.
|
||||||
|
|
||||||
# Returns status data dict that gets written to Redis.
|
# Returns status data dict that gets written to Redis.
|
||||||
# """
|
# """
|
||||||
# client = redis.Redis(
|
# client = redis.Redis(
|
||||||
@@ -1157,31 +1182,31 @@ def reflection_timer_task() -> None:
|
|||||||
@celery_app.task(name="app.controllers.memory_storage_controller.search_all")
|
@celery_app.task(name="app.controllers.memory_storage_controller.search_all")
|
||||||
def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
|
def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
|
||||||
"""定时任务:查询工作空间下所有宿主的记忆总量并写入数据库
|
"""定时任务:查询工作空间下所有宿主的记忆总量并写入数据库
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
workspace_id: 工作空间ID
|
workspace_id: 工作空间ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
包含任务执行结果的字典
|
包含任务执行结果的字典
|
||||||
"""
|
"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
async def _run() -> Dict[str, Any]:
|
async def _run() -> Dict[str, Any]:
|
||||||
from app.models.app_model import App
|
from app.models.app_model import App
|
||||||
from app.models.end_user_model import EndUser
|
from app.models.end_user_model import EndUser
|
||||||
from app.repositories.memory_increment_repository import write_memory_increment
|
from app.repositories.memory_increment_repository import write_memory_increment
|
||||||
from app.services.memory_storage_service import search_all
|
from app.services.memory_storage_service import search_all
|
||||||
|
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
try:
|
try:
|
||||||
workspace_uuid = uuid.UUID(workspace_id)
|
workspace_uuid = uuid.UUID(workspace_id)
|
||||||
|
|
||||||
# 1. 查询当前workspace下的所有app(仅未删除的)
|
# 1. 查询当前workspace下的所有app(仅未删除的)
|
||||||
apps = db.query(App).filter(
|
apps = db.query(App).filter(
|
||||||
App.workspace_id == workspace_uuid,
|
App.workspace_id == workspace_uuid,
|
||||||
App.is_active.is_(True)
|
App.is_active.is_(True)
|
||||||
).all()
|
).all()
|
||||||
|
|
||||||
if not apps:
|
if not apps:
|
||||||
# 如果没有app,总量为0
|
# 如果没有app,总量为0
|
||||||
memory_increment = write_memory_increment(
|
memory_increment = write_memory_increment(
|
||||||
@@ -1197,17 +1222,17 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
|
|||||||
"memory_increment_id": str(memory_increment.id),
|
"memory_increment_id": str(memory_increment.id),
|
||||||
"created_at": memory_increment.created_at.isoformat(),
|
"created_at": memory_increment.created_at.isoformat(),
|
||||||
}
|
}
|
||||||
|
|
||||||
# 2. 查询所有app下的end_user_id(去重)
|
# 2. 查询所有app下的end_user_id(去重)
|
||||||
app_ids = [app.id for app in apps]
|
app_ids = [app.id for app in apps]
|
||||||
end_users = db.query(EndUser.id).filter(
|
end_users = db.query(EndUser.id).filter(
|
||||||
EndUser.app_id.in_(app_ids)
|
EndUser.app_id.in_(app_ids)
|
||||||
).distinct().all()
|
).distinct().all()
|
||||||
|
|
||||||
# 3. 遍历所有end_user,查询每个宿主的记忆总量并累加
|
# 3. 遍历所有end_user,查询每个宿主的记忆总量并累加
|
||||||
total_num = 0
|
total_num = 0
|
||||||
end_user_details = []
|
end_user_details = []
|
||||||
|
|
||||||
for (end_user_id,) in end_users:
|
for (end_user_id,) in end_users:
|
||||||
try:
|
try:
|
||||||
# 调用 search_all 接口查询该宿主的总量
|
# 调用 search_all 接口查询该宿主的总量
|
||||||
@@ -1225,14 +1250,14 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
|
|||||||
"total": 0,
|
"total": 0,
|
||||||
"error": str(e)
|
"error": str(e)
|
||||||
})
|
})
|
||||||
|
|
||||||
# 4. 写入数据库
|
# 4. 写入数据库
|
||||||
memory_increment = write_memory_increment(
|
memory_increment = write_memory_increment(
|
||||||
db=db,
|
db=db,
|
||||||
workspace_id=workspace_uuid,
|
workspace_id=workspace_uuid,
|
||||||
total_num=total_num
|
total_num=total_num
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "SUCCESS",
|
"status": "SUCCESS",
|
||||||
"workspace_id": workspace_id,
|
"workspace_id": workspace_id,
|
||||||
@@ -1244,7 +1269,7 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
|
|||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = asyncio.run(_run())
|
result = asyncio.run(_run())
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
@@ -1263,18 +1288,18 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
|
|||||||
@celery_app.task(
|
@celery_app.task(
|
||||||
name="app.tasks.regenerate_memory_cache",
|
name="app.tasks.regenerate_memory_cache",
|
||||||
bind=True,
|
bind=True,
|
||||||
ignore_result=True,
|
ignore_result=True,
|
||||||
max_retries=0,
|
max_retries=0,
|
||||||
acks_late=False,
|
acks_late=False,
|
||||||
time_limit=3600,
|
time_limit=3600,
|
||||||
soft_time_limit=3300,
|
soft_time_limit=3300,
|
||||||
)
|
)
|
||||||
def regenerate_memory_cache(self) -> Dict[str, Any]:
|
def regenerate_memory_cache(self) -> Dict[str, Any]:
|
||||||
"""定时任务:为所有用户重新生成记忆洞察和用户摘要缓存
|
"""定时任务:为所有用户重新生成记忆洞察和用户摘要缓存
|
||||||
|
|
||||||
遍历所有活动工作空间的所有终端用户,为每个用户重新生成记忆洞察和用户摘要。
|
遍历所有活动工作空间的所有终端用户,为每个用户重新生成记忆洞察和用户摘要。
|
||||||
实现错误隔离,单个用户失败不影响其他用户的处理。
|
实现错误隔离,单个用户失败不影响其他用户的处理。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
包含任务执行结果的字典,包括:
|
包含任务执行结果的字典,包括:
|
||||||
- status: 任务状态 (SUCCESS/FAILURE)
|
- status: 任务状态 (SUCCESS/FAILURE)
|
||||||
@@ -1288,57 +1313,57 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
|
|||||||
- task_id: 任务ID
|
- task_id: 任务ID
|
||||||
"""
|
"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
async def _run() -> Dict[str, Any]:
|
async def _run() -> Dict[str, Any]:
|
||||||
from app.core.logging_config import get_logger
|
from app.core.logging_config import get_logger
|
||||||
from app.repositories.end_user_repository import EndUserRepository
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
from app.services.user_memory_service import UserMemoryService
|
from app.services.user_memory_service import UserMemoryService
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
logger.info("开始执行记忆缓存重新生成定时任务")
|
logger.info("开始执行记忆缓存重新生成定时任务")
|
||||||
|
|
||||||
service = UserMemoryService()
|
service = UserMemoryService()
|
||||||
|
|
||||||
total_users = 0
|
total_users = 0
|
||||||
successful = 0
|
successful = 0
|
||||||
failed = 0
|
failed = 0
|
||||||
workspace_results = []
|
workspace_results = []
|
||||||
|
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
try:
|
try:
|
||||||
# 获取所有活动工作空间
|
# 获取所有活动工作空间
|
||||||
repo = EndUserRepository(db)
|
repo = EndUserRepository(db)
|
||||||
workspaces = repo.get_all_active_workspaces()
|
workspaces = repo.get_all_active_workspaces()
|
||||||
logger.info(f"找到 {len(workspaces)} 个活动工作空间")
|
logger.info(f"找到 {len(workspaces)} 个活动工作空间")
|
||||||
|
|
||||||
# 遍历每个工作空间
|
# 遍历每个工作空间
|
||||||
for workspace_id in workspaces:
|
for workspace_id in workspaces:
|
||||||
logger.info(f"开始处理工作空间: {workspace_id}")
|
logger.info(f"开始处理工作空间: {workspace_id}")
|
||||||
workspace_start_time = time.time()
|
workspace_start_time = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 获取工作空间的所有终端用户
|
# 获取工作空间的所有终端用户
|
||||||
end_users = repo.get_all_by_workspace(workspace_id)
|
end_users = repo.get_all_by_workspace(workspace_id)
|
||||||
workspace_user_count = len(end_users)
|
workspace_user_count = len(end_users)
|
||||||
total_users += workspace_user_count
|
total_users += workspace_user_count
|
||||||
|
|
||||||
logger.info(f"工作空间 {workspace_id} 有 {workspace_user_count} 个终端用户")
|
logger.info(f"工作空间 {workspace_id} 有 {workspace_user_count} 个终端用户")
|
||||||
|
|
||||||
workspace_successful = 0
|
workspace_successful = 0
|
||||||
workspace_failed = 0
|
workspace_failed = 0
|
||||||
workspace_errors = []
|
workspace_errors = []
|
||||||
|
|
||||||
# 遍历每个用户并生成缓存
|
# 遍历每个用户并生成缓存
|
||||||
for end_user in end_users:
|
for end_user in end_users:
|
||||||
end_user_id = str(end_user.id)
|
end_user_id = str(end_user.id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 生成记忆洞察
|
# 生成记忆洞察
|
||||||
insight_result = await service.generate_and_cache_insight(db, end_user_id)
|
insight_result = await service.generate_and_cache_insight(db, end_user_id)
|
||||||
|
|
||||||
# 生成用户摘要
|
# 生成用户摘要
|
||||||
summary_result = await service.generate_and_cache_summary(db, end_user_id)
|
summary_result = await service.generate_and_cache_summary(db, end_user_id)
|
||||||
|
|
||||||
# 检查是否都成功
|
# 检查是否都成功
|
||||||
if insight_result["success"] and summary_result["success"]:
|
if insight_result["success"] and summary_result["success"]:
|
||||||
workspace_successful += 1
|
workspace_successful += 1
|
||||||
@@ -1354,7 +1379,7 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
|
|||||||
}
|
}
|
||||||
workspace_errors.append(error_info)
|
workspace_errors.append(error_info)
|
||||||
logger.warning(f"终端用户 {end_user_id} 的缓存重新生成部分失败: {error_info}")
|
logger.warning(f"终端用户 {end_user_id} 的缓存重新生成部分失败: {error_info}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 单个用户失败不影响其他用户(错误隔离)
|
# 单个用户失败不影响其他用户(错误隔离)
|
||||||
workspace_failed += 1
|
workspace_failed += 1
|
||||||
@@ -1365,9 +1390,9 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
|
|||||||
}
|
}
|
||||||
workspace_errors.append(error_info)
|
workspace_errors.append(error_info)
|
||||||
logger.error(f"为终端用户 {end_user_id} 重新生成缓存时出错: {str(e)}")
|
logger.error(f"为终端用户 {end_user_id} 重新生成缓存时出错: {str(e)}")
|
||||||
|
|
||||||
workspace_elapsed = time.time() - workspace_start_time
|
workspace_elapsed = time.time() - workspace_start_time
|
||||||
|
|
||||||
# 记录工作空间处理结果
|
# 记录工作空间处理结果
|
||||||
workspace_result = {
|
workspace_result = {
|
||||||
"workspace_id": str(workspace_id),
|
"workspace_id": str(workspace_id),
|
||||||
@@ -1378,13 +1403,13 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
|
|||||||
"elapsed_time": workspace_elapsed
|
"elapsed_time": workspace_elapsed
|
||||||
}
|
}
|
||||||
workspace_results.append(workspace_result)
|
workspace_results.append(workspace_result)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"工作空间 {workspace_id} 处理完成: "
|
f"工作空间 {workspace_id} 处理完成: "
|
||||||
f"总数={workspace_user_count}, 成功={workspace_successful}, "
|
f"总数={workspace_user_count}, 成功={workspace_successful}, "
|
||||||
f"失败={workspace_failed}, 耗时={workspace_elapsed:.2f}秒"
|
f"失败={workspace_failed}, 耗时={workspace_elapsed:.2f}秒"
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 工作空间处理失败,记录错误并继续处理下一个
|
# 工作空间处理失败,记录错误并继续处理下一个
|
||||||
logger.error(f"处理工作空间 {workspace_id} 时出错: {str(e)}")
|
logger.error(f"处理工作空间 {workspace_id} 时出错: {str(e)}")
|
||||||
@@ -1396,14 +1421,14 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
|
|||||||
"failed": 0,
|
"failed": 0,
|
||||||
"errors": []
|
"errors": []
|
||||||
})
|
})
|
||||||
|
|
||||||
# 记录总体统计信息
|
# 记录总体统计信息
|
||||||
logger.info(
|
logger.info(
|
||||||
f"记忆缓存重新生成定时任务完成: "
|
f"记忆缓存重新生成定时任务完成: "
|
||||||
f"工作空间数={len(workspaces)}, 总用户数={total_users}, "
|
f"工作空间数={len(workspaces)}, 总用户数={total_users}, "
|
||||||
f"成功={successful}, 失败={failed}"
|
f"成功={successful}, 失败={failed}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "SUCCESS",
|
"status": "SUCCESS",
|
||||||
"message": f"成功处理 {len(workspaces)} 个工作空间,总共 {successful}/{total_users} 个用户缓存重新生成成功",
|
"message": f"成功处理 {len(workspaces)} 个工作空间,总共 {successful}/{total_users} 个用户缓存重新生成成功",
|
||||||
@@ -1413,7 +1438,7 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
|
|||||||
"failed": failed,
|
"failed": failed,
|
||||||
"workspace_results": workspace_results
|
"workspace_results": workspace_results
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"记忆缓存重新生成定时任务执行失败: {str(e)}")
|
logger.error(f"记忆缓存重新生成定时任务执行失败: {str(e)}")
|
||||||
return {
|
return {
|
||||||
@@ -1425,7 +1450,7 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
|
|||||||
"failed": failed,
|
"failed": failed,
|
||||||
"workspace_results": workspace_results
|
"workspace_results": workspace_results
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用 nest_asyncio 来避免事件循环冲突
|
# 使用 nest_asyncio 来避免事件循环冲突
|
||||||
try:
|
try:
|
||||||
@@ -1433,7 +1458,7 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
|
|||||||
nest_asyncio.apply()
|
nest_asyncio.apply()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
@@ -1443,12 +1468,12 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
|
|||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
result = loop.run_until_complete(_run())
|
result = loop.run_until_complete(_run())
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
result["elapsed_time"] = elapsed_time
|
result["elapsed_time"] = elapsed_time
|
||||||
result["task_id"] = self.request.id
|
result["task_id"] = self.request.id
|
||||||
|
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
@@ -1460,15 +1485,14 @@ def regenerate_memory_cache(self) -> Dict[str, Any]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(
|
@celery_app.task(
|
||||||
name="app.tasks.workspace_reflection_task",
|
name="app.tasks.workspace_reflection_task",
|
||||||
bind=True,
|
bind=True,
|
||||||
ignore_result=True,
|
ignore_result=True,
|
||||||
max_retries=0,
|
max_retries=0,
|
||||||
acks_late=False,
|
acks_late=False,
|
||||||
time_limit=300,
|
time_limit=300,
|
||||||
soft_time_limit=240,
|
soft_time_limit=240,
|
||||||
)
|
)
|
||||||
def workspace_reflection_task(self) -> Dict[str, Any]:
|
def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||||
"""定时任务:每30秒运行工作空间反思功能
|
"""定时任务:每30秒运行工作空间反思功能
|
||||||
@@ -1487,7 +1511,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
try:
|
try:
|
||||||
# 获取所有工作空间
|
# 获取所有工作空间
|
||||||
@@ -1518,15 +1542,16 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
|||||||
workspace_reflection_results = []
|
workspace_reflection_results = []
|
||||||
|
|
||||||
for data in result['apps_detailed_info']:
|
for data in result['apps_detailed_info']:
|
||||||
if data['data_configs'] == []:
|
if data['memory_configs'] == []:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
releases = data['releases']
|
releases = data['releases']
|
||||||
data_configs = data['data_configs']
|
memory_configs = data['memory_configs']
|
||||||
end_users = data['end_users']
|
end_users = data['end_users']
|
||||||
|
|
||||||
for base, config, user in zip(releases, data_configs, end_users):
|
for base, config, user in zip(releases, memory_configs, end_users):
|
||||||
if str(base['config']) == str(config['config_id']) and str(base['app_id']) == str(user['app_id']):
|
if str(base['config']) == str(config['config_id']) and str(base['app_id']) == str(
|
||||||
|
user['app_id']):
|
||||||
# 调用反思服务
|
# 调用反思服务
|
||||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
|
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
|
||||||
|
|
||||||
@@ -1614,75 +1639,73 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(
|
@celery_app.task(
|
||||||
name="app.tasks.run_forgetting_cycle_task",
|
name="app.tasks.run_forgetting_cycle_task",
|
||||||
bind=True,
|
bind=True,
|
||||||
ignore_result=True,
|
ignore_result=True,
|
||||||
max_retries=0,
|
max_retries=0,
|
||||||
acks_late=False,
|
acks_late=False,
|
||||||
time_limit=7200,
|
time_limit=7200,
|
||||||
soft_time_limit=7000,
|
soft_time_limit=7000,
|
||||||
)
|
)
|
||||||
def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
|
def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
|
||||||
"""定时任务:运行遗忘周期
|
"""定时任务:运行遗忘周期
|
||||||
|
|
||||||
定期执行遗忘周期,识别并融合低激活值的知识节点。
|
定期执行遗忘周期,识别并融合低激活值的知识节点。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config_id: 配置ID(可选,如果为None则使用默认配置)
|
config_id: 配置ID(可选,如果为None则使用默认配置)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
包含任务执行结果的字典
|
包含任务执行结果的字典
|
||||||
"""
|
"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
async def _run() -> Dict[str, Any]:
|
async def _run() -> Dict[str, Any]:
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.services.memory_forget_service import MemoryForgetService
|
from app.services.memory_forget_service import MemoryForgetService
|
||||||
|
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
try:
|
try:
|
||||||
api_logger.info(f"开始执行遗忘周期定时任务,config_id: {config_id}")
|
api_logger.info(f"开始执行遗忘周期定时任务,config_id: {config_id}")
|
||||||
|
|
||||||
forget_service = MemoryForgetService()
|
forget_service = MemoryForgetService()
|
||||||
|
|
||||||
# 运行遗忘周期
|
# 运行遗忘周期
|
||||||
report = await forget_service.trigger_forgetting(
|
report = await forget_service.trigger_forgetting(
|
||||||
db=db,
|
db=db,
|
||||||
end_user_id=None, # 处理所有组
|
end_user_id=None, # 处理所有组
|
||||||
config_id=config_id
|
config_id=config_id
|
||||||
)
|
)
|
||||||
|
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"遗忘周期定时任务完成: "
|
f"遗忘周期定时任务完成: "
|
||||||
f"融合 {report['merged_count']} 对节点, "
|
f"融合 {report['merged_count']} 对节点, "
|
||||||
f"失败 {report['failed_count']} 对, "
|
f"失败 {report['failed_count']} 对, "
|
||||||
f"耗时 {duration:.2f} 秒"
|
f"耗时 {duration:.2f} 秒"
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "SUCCESS",
|
"status": "SUCCESS",
|
||||||
"message": "遗忘周期执行成功",
|
"message": "遗忘周期执行成功",
|
||||||
"report": report,
|
"report": report,
|
||||||
"duration_seconds": duration
|
"duration_seconds": duration
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
api_logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True)
|
api_logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "FAILED",
|
"status": "FAILED",
|
||||||
"message": f"遗忘周期执行失败: {str(e)}",
|
"message": f"遗忘周期执行失败: {str(e)}",
|
||||||
"duration_seconds": duration
|
"duration_seconds": duration
|
||||||
}
|
}
|
||||||
|
|
||||||
# 运行异步函数
|
# 运行异步函数
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
@@ -1692,7 +1715,6 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
|||||||
finally:
|
finally:
|
||||||
loop.close()
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Long-term Memory Storage Tasks (Batched Write Strategies)
|
# Long-term Memory Storage Tasks (Batched Write Strategies)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -1705,27 +1727,27 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
|||||||
# time_window: int = 5
|
# time_window: int = 5
|
||||||
# ) -> Dict[str, Any]:
|
# ) -> Dict[str, Any]:
|
||||||
# """Celery task for time-based long-term memory storage.
|
# """Celery task for time-based long-term memory storage.
|
||||||
|
|
||||||
# Retrieves recent sessions from Redis within time window and writes to Neo4j.
|
# Retrieves recent sessions from Redis within time window and writes to Neo4j.
|
||||||
|
|
||||||
# Args:
|
# Args:
|
||||||
# end_user_id: End user identifier
|
# end_user_id: End user identifier
|
||||||
# config_id: Memory configuration ID
|
# config_id: Memory configuration ID
|
||||||
# time_window: Time window in minutes for retrieving recent sessions
|
# time_window: Time window in minutes for retrieving recent sessions
|
||||||
|
|
||||||
# Returns:
|
# Returns:
|
||||||
# Dict containing task status and metadata
|
# Dict containing task status and metadata
|
||||||
# """
|
# """
|
||||||
# from app.core.logging_config import get_logger
|
# from app.core.logging_config import get_logger
|
||||||
# logger = get_logger(__name__)
|
# logger = get_logger(__name__)
|
||||||
|
|
||||||
# logger.info(f"[LONG_TERM_TIME] Starting task - end_user_id={end_user_id}, time_window={time_window}")
|
# logger.info(f"[LONG_TERM_TIME] Starting task - end_user_id={end_user_id}, time_window={time_window}")
|
||||||
# start_time = time.time()
|
# start_time = time.time()
|
||||||
|
|
||||||
# async def _run() -> Dict[str, Any]:
|
# async def _run() -> Dict[str, Any]:
|
||||||
# from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage
|
# from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage
|
||||||
# from app.services.memory_config_service import MemoryConfigService
|
# from app.services.memory_config_service import MemoryConfigService
|
||||||
|
|
||||||
# db = next(get_db())
|
# db = next(get_db())
|
||||||
# try:
|
# try:
|
||||||
# # Load memory config
|
# # Load memory config
|
||||||
@@ -1734,20 +1756,20 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
|||||||
# config_id=config_id,
|
# config_id=config_id,
|
||||||
# service_name="LongTermStorageTask"
|
# service_name="LongTermStorageTask"
|
||||||
# )
|
# )
|
||||||
|
|
||||||
# # Execute time-based storage
|
# # Execute time-based storage
|
||||||
# await memory_long_term_storage(end_user_id, memory_config, time_window)
|
# await memory_long_term_storage(end_user_id, memory_config, time_window)
|
||||||
|
|
||||||
# return {"status": "SUCCESS", "strategy": "time", "time_window": time_window}
|
# return {"status": "SUCCESS", "strategy": "time", "time_window": time_window}
|
||||||
# finally:
|
# finally:
|
||||||
# db.close()
|
# db.close()
|
||||||
|
|
||||||
# try:
|
# try:
|
||||||
# import nest_asyncio
|
# import nest_asyncio
|
||||||
# nest_asyncio.apply()
|
# nest_asyncio.apply()
|
||||||
# except ImportError:
|
# except ImportError:
|
||||||
# pass
|
# pass
|
||||||
|
|
||||||
# try:
|
# try:
|
||||||
# loop = asyncio.get_event_loop()
|
# loop = asyncio.get_event_loop()
|
||||||
# if loop.is_closed():
|
# if loop.is_closed():
|
||||||
@@ -1756,13 +1778,13 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
|||||||
# except RuntimeError:
|
# except RuntimeError:
|
||||||
# loop = asyncio.new_event_loop()
|
# loop = asyncio.new_event_loop()
|
||||||
# asyncio.set_event_loop(loop)
|
# asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
# try:
|
# try:
|
||||||
# result = loop.run_until_complete(_run())
|
# result = loop.run_until_complete(_run())
|
||||||
# elapsed_time = time.time() - start_time
|
# elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
# logger.info(f"[LONG_TERM_TIME] Task completed - elapsed_time={elapsed_time:.2f}s")
|
# logger.info(f"[LONG_TERM_TIME] Task completed - elapsed_time={elapsed_time:.2f}s")
|
||||||
|
|
||||||
# return {
|
# return {
|
||||||
# **result,
|
# **result,
|
||||||
# "end_user_id": end_user_id,
|
# "end_user_id": end_user_id,
|
||||||
@@ -1773,7 +1795,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
|||||||
# except Exception as e:
|
# except Exception as e:
|
||||||
# elapsed_time = time.time() - start_time
|
# elapsed_time = time.time() - start_time
|
||||||
# logger.error(f"[LONG_TERM_TIME] Task failed - error={str(e)}", exc_info=True)
|
# logger.error(f"[LONG_TERM_TIME] Task failed - error={str(e)}", exc_info=True)
|
||||||
|
|
||||||
# return {
|
# return {
|
||||||
# "status": "FAILURE",
|
# "status": "FAILURE",
|
||||||
# "strategy": "time",
|
# "strategy": "time",
|
||||||
@@ -1793,45 +1815,45 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
|||||||
# config_id: str
|
# config_id: str
|
||||||
# ) -> Dict[str, Any]:
|
# ) -> Dict[str, Any]:
|
||||||
# """Celery task for aggregate-based long-term memory storage.
|
# """Celery task for aggregate-based long-term memory storage.
|
||||||
|
|
||||||
# Uses LLM to determine if new messages describe the same event as history.
|
# Uses LLM to determine if new messages describe the same event as history.
|
||||||
# Only writes to Neo4j if messages represent new information (not duplicates).
|
# Only writes to Neo4j if messages represent new information (not duplicates).
|
||||||
|
|
||||||
# Args:
|
# Args:
|
||||||
# end_user_id: End user identifier
|
# end_user_id: End user identifier
|
||||||
# langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}]
|
# langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}]
|
||||||
# config_id: Memory configuration ID
|
# config_id: Memory configuration ID
|
||||||
|
|
||||||
# Returns:
|
# Returns:
|
||||||
# Dict containing task status, is_same_event flag, and metadata
|
# Dict containing task status, is_same_event flag, and metadata
|
||||||
# """
|
# """
|
||||||
# from app.core.logging_config import get_logger
|
# from app.core.logging_config import get_logger
|
||||||
# logger = get_logger(__name__)
|
# logger = get_logger(__name__)
|
||||||
|
|
||||||
# logger.info(f"[LONG_TERM_AGGREGATE] Starting task - end_user_id={end_user_id}")
|
# logger.info(f"[LONG_TERM_AGGREGATE] Starting task - end_user_id={end_user_id}")
|
||||||
# start_time = time.time()
|
# start_time = time.time()
|
||||||
|
|
||||||
# async def _run() -> Dict[str, Any]:
|
# async def _run() -> Dict[str, Any]:
|
||||||
# from app.core.memory.agent.langgraph_graph.routing.write_router import aggregate_judgment
|
# from app.core.memory.agent.langgraph_graph.routing.write_router import aggregate_judgment
|
||||||
# from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
|
# from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
|
||||||
# from app.core.memory.agent.utils.redis_tool import write_store
|
# from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
# from app.services.memory_config_service import MemoryConfigService
|
# from app.services.memory_config_service import MemoryConfigService
|
||||||
|
|
||||||
# db = next(get_db())
|
# db = next(get_db())
|
||||||
# try:
|
# try:
|
||||||
# # Save to Redis buffer first
|
# # Save to Redis buffer first
|
||||||
# write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
|
# write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
|
||||||
|
|
||||||
# # Load memory config
|
# # Load memory config
|
||||||
# config_service = MemoryConfigService(db)
|
# config_service = MemoryConfigService(db)
|
||||||
# memory_config = config_service.load_memory_config(
|
# memory_config = config_service.load_memory_config(
|
||||||
# config_id=config_id,
|
# config_id=config_id,
|
||||||
# service_name="LongTermStorageTask"
|
# service_name="LongTermStorageTask"
|
||||||
# )
|
# )
|
||||||
|
|
||||||
# # Execute aggregate judgment
|
# # Execute aggregate judgment
|
||||||
# result = await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
# result = await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||||
|
|
||||||
# return {
|
# return {
|
||||||
# "status": "SUCCESS",
|
# "status": "SUCCESS",
|
||||||
# "strategy": "aggregate",
|
# "strategy": "aggregate",
|
||||||
@@ -1840,13 +1862,13 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
|||||||
# }
|
# }
|
||||||
# finally:
|
# finally:
|
||||||
# db.close()
|
# db.close()
|
||||||
|
|
||||||
# try:
|
# try:
|
||||||
# import nest_asyncio
|
# import nest_asyncio
|
||||||
# nest_asyncio.apply()
|
# nest_asyncio.apply()
|
||||||
# except ImportError:
|
# except ImportError:
|
||||||
# pass
|
# pass
|
||||||
|
|
||||||
# try:
|
# try:
|
||||||
# loop = asyncio.get_event_loop()
|
# loop = asyncio.get_event_loop()
|
||||||
# if loop.is_closed():
|
# if loop.is_closed():
|
||||||
@@ -1855,13 +1877,13 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
|||||||
# except RuntimeError:
|
# except RuntimeError:
|
||||||
# loop = asyncio.new_event_loop()
|
# loop = asyncio.new_event_loop()
|
||||||
# asyncio.set_event_loop(loop)
|
# asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
# try:
|
# try:
|
||||||
# result = loop.run_until_complete(_run())
|
# result = loop.run_until_complete(_run())
|
||||||
# elapsed_time = time.time() - start_time
|
# elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
# logger.info(f"[LONG_TERM_AGGREGATE] Task completed - is_same_event={result.get('is_same_event')}, elapsed_time={elapsed_time:.2f}s")
|
# logger.info(f"[LONG_TERM_AGGREGATE] Task completed - is_same_event={result.get('is_same_event')}, elapsed_time={elapsed_time:.2f}s")
|
||||||
|
|
||||||
# return {
|
# return {
|
||||||
# **result,
|
# **result,
|
||||||
# "end_user_id": end_user_id,
|
# "end_user_id": end_user_id,
|
||||||
@@ -1872,7 +1894,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
|||||||
# except Exception as e:
|
# except Exception as e:
|
||||||
# elapsed_time = time.time() - start_time
|
# elapsed_time = time.time() - start_time
|
||||||
# logger.error(f"[LONG_TERM_AGGREGATE] Task failed - error={str(e)}", exc_info=True)
|
# logger.error(f"[LONG_TERM_AGGREGATE] Task failed - error={str(e)}", exc_info=True)
|
||||||
|
|
||||||
# return {
|
# return {
|
||||||
# "status": "FAILURE",
|
# "status": "FAILURE",
|
||||||
# "strategy": "aggregate",
|
# "strategy": "aggregate",
|
||||||
@@ -1881,4 +1903,4 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
|||||||
# "config_id": config_id,
|
# "config_id": config_id,
|
||||||
# "elapsed_time": elapsed_time,
|
# "elapsed_time": elapsed_time,
|
||||||
# "task_id": self.request.id
|
# "task_id": self.request.id
|
||||||
# }
|
# }
|
||||||
Reference in New Issue
Block a user