Merge pull request #318 from SuanmoSuanyangTechnology/fix/release_memory_bug

Fix/release memory bug
This commit is contained in:
Mark
2026-02-04 20:29:12 +08:00
committed by GitHub
3 changed files with 155 additions and 50 deletions

View File

@@ -1,8 +1,6 @@
import json import json
from langchain_core.messages import HumanMessage, AIMessage from langchain_core.messages import HumanMessage, AIMessage
async def format_parsing(messages: list,type:str='string'): async def format_parsing(messages: list,type:str='string'):
""" """
格式化解析消息列表 格式化解析消息列表

View File

@@ -4,6 +4,7 @@ Write Tools for Memory Knowledge Extraction Pipeline
This module provides the main write function for executing the knowledge extraction This module provides the main write function for executing the knowledge extraction
pipeline. Only MemoryConfig is needed - clients are constructed internally. pipeline. Only MemoryConfig is needed - clients are constructed internally.
""" """
import asyncio
import time import time
from datetime import datetime from datetime import datetime
@@ -123,23 +124,48 @@ async def write(
except Exception as e: except Exception as e:
logger.error(f"Error creating indexes: {e}", exc_info=True) logger.error(f"Error creating indexes: {e}", exc_info=True)
# 添加死锁重试机制
max_retries = 3
retry_delay = 1 # 秒
for attempt in range(max_retries):
try:
success = await save_dialog_and_statements_to_neo4j(
dialogue_nodes=all_dialogue_nodes,
chunk_nodes=all_chunk_nodes,
statement_nodes=all_statement_nodes,
entity_nodes=all_entity_nodes,
statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges,
connector=neo4j_connector
)
if success:
logger.info("Successfully saved all data to Neo4j")
break
else:
logger.warning("Failed to save some data to Neo4j")
if attempt < max_retries - 1:
logger.info(f"Retrying... (attempt {attempt + 2}/{max_retries})")
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
except Exception as e:
error_msg = str(e)
# 检查是否是死锁错误
if "DeadlockDetected" in error_msg or "deadlock" in error_msg.lower():
if attempt < max_retries - 1:
logger.warning(f"Deadlock detected, retrying... (attempt {attempt + 2}/{max_retries})")
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
else:
logger.error(f"Failed after {max_retries} attempts due to deadlock: {e}")
raise
else:
# 非死锁错误,直接抛出
raise
try: try:
success = await save_dialog_and_statements_to_neo4j(
dialogue_nodes=all_dialogue_nodes,
chunk_nodes=all_chunk_nodes,
statement_nodes=all_statement_nodes,
entity_nodes=all_entity_nodes,
statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges,
connector=neo4j_connector
)
if success:
logger.info("Successfully saved all data to Neo4j")
else:
logger.warning("Failed to save some data to Neo4j")
finally:
await neo4j_connector.close() await neo4j_connector.close()
except Exception as e:
logger.error(f"Error closing Neo4j connector: {e}")
log_time("Neo4j Database Save", time.time() - step_start, log_file) log_time("Neo4j Database Save", time.time() - step_start, log_file)

View File

@@ -21,7 +21,8 @@ from app.core.memory.models.graph_models import (
ExtractedEntityNode, ExtractedEntityNode,
EntityEntityEdge, EntityEntityEdge,
) )
import logging
logger = logging.getLogger(__name__)
async def save_entities_and_relationships( async def save_entities_and_relationships(
entity_nodes: List[ExtractedEntityNode], entity_nodes: List[ExtractedEntityNode],
entity_entity_edges: List[EntityEntityEdge], entity_entity_edges: List[EntityEntityEdge],
@@ -147,14 +148,14 @@ async def save_statement_entity_edges(
async def save_dialog_and_statements_to_neo4j( async def save_dialog_and_statements_to_neo4j(
dialogue_nodes: List[DialogueNode], dialogue_nodes: List[DialogueNode],
chunk_nodes: List[ChunkNode], chunk_nodes: List[ChunkNode],
statement_nodes: List[StatementNode], statement_nodes: List[StatementNode],
entity_nodes: List[ExtractedEntityNode], entity_nodes: List[ExtractedEntityNode],
entity_edges: List[EntityEntityEdge], entity_edges: List[EntityEntityEdge],
statement_chunk_edges: List[StatementChunkEdge], statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge], statement_entity_edges: List[StatementEntityEdge],
connector: Neo4jConnector connector: Neo4jConnector
) -> bool: ) -> bool:
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
@@ -171,40 +172,120 @@ async def save_dialog_and_statements_to_neo4j(
Returns: Returns:
bool: True if successful, False otherwise bool: True if successful, False otherwise
""" """
try:
# Save all dialogue nodes in batch # 定义事务函数,将所有写操作放在一个事务中
dialogue_uuids = await add_dialogue_nodes(dialogue_nodes, connector) async def _save_all_in_transaction(tx):
if dialogue_uuids: """在单个事务中执行所有保存操作,避免死锁"""
results = {}
# 1. Save all dialogue nodes in batch
if dialogue_nodes:
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE
dialogue_data = [node.model_dump() for node in dialogue_nodes]
result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data)
dialogue_uuids = [record["uuid"] async for record in result]
results['dialogues'] = dialogue_uuids
print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}") print(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")
else:
print("Failed to save dialogues to Neo4j")
return False
# Save all chunk nodes in batch # 2. Save all chunk nodes in batch
await save_chunk_nodes(chunk_nodes, connector) if chunk_nodes:
from app.repositories.neo4j.cypher_queries import CHUNK_NODE_SAVE
chunk_data = [node.model_dump() for node in chunk_nodes]
result = await tx.run(CHUNK_NODE_SAVE, chunks=chunk_data)
chunk_uuids = [record["uuid"] async for record in result]
results['chunks'] = chunk_uuids
logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j")
# Save all statement nodes in batch # 3. Save all statement nodes in batch
if statement_nodes: if statement_nodes:
statement_uuids = await add_statement_nodes(statement_nodes, connector) from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE
if statement_uuids: statement_data = [node.model_dump() for node in statement_nodes]
print(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j") result = await tx.run(STATEMENT_NODE_SAVE, statements=statement_data)
else: statement_uuids = [record["uuid"] async for record in result]
print("Failed to save statement nodes to Neo4j") results['statements'] = statement_uuids
return False logger.info(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j")
else:
print("No statement nodes to save")
# Save entities and relationships # 4. Save entities
await save_entities_and_relationships(entity_nodes, entity_edges, connector) if entity_nodes:
print("Successfully saved entities and relationships to Neo4j") from app.repositories.neo4j.cypher_queries import EXTRACTED_ENTITY_NODE_SAVE
entity_data = [entity.model_dump() for entity in entity_nodes]
result = await tx.run(EXTRACTED_ENTITY_NODE_SAVE, entities=entity_data)
entity_uuids = [record["uuid"] async for record in result]
results['entities'] = entity_uuids
logger.info(f"Successfully saved {len(entity_uuids)} entity nodes to Neo4j")
# Save new edges # 5. Create entity relationships
await save_statement_chunk_edges(statement_chunk_edges, connector) if entity_edges:
await save_statement_entity_edges(statement_entity_edges, connector) from app.repositories.neo4j.cypher_queries import ENTITY_RELATIONSHIP_SAVE
relationship_data = []
for edge in entity_edges:
relationship_data.append({
'source_id': edge.source,
'target_id': edge.target,
'predicate': edge.relation_type,
'statement_id': edge.source_statement_id,
'value': edge.relation_value,
'statement': edge.statement,
'valid_at': edge.valid_at.isoformat() if edge.valid_at else None,
'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None,
'created_at': edge.created_at.isoformat(),
'expired_at': edge.expired_at.isoformat(),
'run_id': edge.run_id,
'end_user_id': edge.end_user_id,
})
result = await tx.run(ENTITY_RELATIONSHIP_SAVE, relationships=relationship_data)
rel_uuids = [record["uuid"] async for record in result]
results['entity_relationships'] = rel_uuids
logger.info(f"Successfully saved {len(rel_uuids)} entity relationships to Neo4j")
# 6. Save statement-chunk edges
if statement_chunk_edges:
from app.repositories.neo4j.cypher_queries import STATEMENT_CHUNK_EDGE_SAVE
sc_edge_data = []
for edge in statement_chunk_edges:
sc_edge_data.append({
"id": edge.id,
"source": edge.source,
"target": edge.target,
"created_at": edge.created_at.isoformat(),
"expired_at": edge.expired_at.isoformat(),
"run_id": edge.run_id,
"end_user_id": edge.end_user_id,
})
result = await tx.run(STATEMENT_CHUNK_EDGE_SAVE, edges=sc_edge_data)
sc_uuids = [record["uuid"] async for record in result]
results['statement_chunk_edges'] = sc_uuids
logger.info(f"Successfully saved {len(sc_uuids)} statement-chunk edges to Neo4j")
# 7. Save statement-entity edges
if statement_entity_edges:
from app.repositories.neo4j.cypher_queries import STATEMENT_ENTITY_EDGE_SAVE
se_edge_data = []
for edge in statement_entity_edges:
se_edge_data.append({
"id": edge.id,
"source": edge.source,
"target": edge.target,
"created_at": edge.created_at.isoformat(),
"expired_at": edge.expired_at.isoformat(),
"run_id": edge.run_id,
"end_user_id": edge.end_user_id,
})
result = await tx.run(STATEMENT_ENTITY_EDGE_SAVE, edges=se_edge_data)
se_uuids = [record["uuid"] async for record in result]
results['statement_entity_edges'] = se_uuids
logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j")
return results
try:
# 使用显式写事务执行所有操作,避免死锁
results = await connector.execute_write_transaction(_save_all_in_transaction)
print("Successfully saved all data to Neo4j in a single transaction")
return True return True
except Exception as e: except Exception as e:
print(f"Neo4j integration error: {e}") print(f"Neo4j integration error: {e}")
print("Continuing without database storage...") print("Continuing without database storage...")
return False return False