Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop

This commit is contained in:
yujiangping
2026-01-15 17:34:24 +08:00
3 changed files with 211 additions and 214 deletions

View File

@@ -516,8 +516,16 @@ class ConversationService:
conversation_messages = self.get_conversation_history(
conversation_id=conversation_id,
max_history=30
max_history=20
)
if len(conversation_messages) == 0:
return ConversationOut(
theme="",
question=[],
summary="",
takeaways=[],
info_score=0,
)
with open('app/services/prompt/conversation_summary_system.jinja2', 'r', encoding='utf-8') as f:
system_prompt = f.read()
@@ -536,6 +544,7 @@ class ConversationService:
]
logger.info(f"Invoking LLM for conversation_id={conversation_id}")
model_resp = await llm.ainvoke(messages)
try:
if isinstance(model_resp.content, str):
result = json_repair.repair_json(model_resp.content, return_objects=True)

View File

@@ -9,7 +9,7 @@ import os
import re
import time
import uuid
from threading import Lock
from typing import Any, AsyncGenerator, Dict, List, Optional
import redis
@@ -51,9 +51,7 @@ _neo4j_connector = Neo4jConnector()
class MemoryAgentService:
"""Service for memory agent operations"""
def __init__(self):
self.user_locks: Dict[str, Lock] = {}
self.locks_lock = Lock()
def writer_messages_deal(self,messages,start_time,group_id,config_id,message):
messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '')
@@ -83,12 +81,7 @@ class MemoryAgentService:
raise ValueError(f"写入失败: {messages}")
def get_group_lock(self, group_id: str) -> Lock:
"""Get lock for specific group to prevent concurrent processing"""
with self.locks_lock:
if group_id not in self.user_locks:
self.user_locks[group_id] = Lock()
return self.user_locks[group_id]
def extract_tool_call_info(self, event: Dict) -> bool:
"""Extract tool call information from event"""
@@ -417,241 +410,236 @@ class MemoryAgentService:
except ImportError:
audit_logger = None
# Get group lock to prevent concurrent processing
group_lock = self.get_group_lock(group_id)
try:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg)
with group_lock:
# Step 1: Load configuration from database only
try:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
# Log failed operation
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
operation="READ",
config_id=config_id,
service_name="MemoryAgentService"
group_id=group_id,
success=False,
duration=duration,
error=error_msg
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg)
# Log failed operation
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
success=False,
duration=duration,
error=error_msg
)
raise ValueError(error_msg)
raise ValueError(error_msg)
# Step 2: Prepare history
history.append({"role": "user", "content": message})
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
# Step 2: Prepare history
history.append({"role": "user", "content": message})
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
# Step 3: Initialize MCP client and execute read workflow
mcp_config = get_mcp_server_config()
client = MultiServerMCPClient(mcp_config)
# Step 3: Initialize MCP client and execute read workflow
mcp_config = get_mcp_server_config()
client = MultiServerMCPClient(mcp_config)
async with client.session('data_flow') as session:
session_start = time.time()
logger.debug("Connected to MCP Server: data_flow")
async with client.session('data_flow') as session:
session_start = time.time()
logger.debug("Connected to MCP Server: data_flow")
tools_start = time.time()
tools = await load_mcp_tools(session)
tools_time = time.time() - tools_start
logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s")
outputs = []
intermediate_outputs = []
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
tools_start = time.time()
tools = await load_mcp_tools(session)
tools_time = time.time() - tools_start
logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s")
# Pass memory_config to the graph workflow
graph_start = time.time()
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
graph_init_time = time.time() - graph_start
logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s")
start = time.time()
config = {"configurable": {"thread_id": group_id}}
workflow_errors = [] # Track errors from workflow
event_count = 0
async for event in graph.astream(
{"messages": history, "memory_config": memory_config, "errors": []},
stream_mode="values",
config=config
):
event_count += 1
event_start = time.time()
messages = event.get('messages')
# Capture any errors from the state
if event.get('errors'):
workflow_errors.extend(event.get('errors', []))
outputs = []
intermediate_outputs = []
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
for msg in messages:
msg_content = msg.content
msg_role = msg.__class__.__name__.lower().replace("message", "")
outputs.append({
"role": msg_role,
"content": msg_content
})
# Pass memory_config to the graph workflow
graph_start = time.time()
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
graph_init_time = time.time() - graph_start
logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s")
# Extract intermediate outputs
if hasattr(msg, 'content'):
try:
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
content_to_parse = msg_content
if isinstance(msg_content, list):
for block in msg_content:
if isinstance(block, dict) and block.get('type') == 'text':
content_to_parse = block.get('text', '')
break
else:
continue # No text block found
start = time.time()
config = {"configurable": {"thread_id": group_id}}
workflow_errors = [] # Track errors from workflow
# Try to parse content as JSON
if isinstance(content_to_parse, str):
try:
parsed = json.loads(content_to_parse)
if isinstance(parsed, dict):
# Check for single intermediate output
if '_intermediate' in parsed:
intermediate_data = parsed['_intermediate']
event_count = 0
async for event in graph.astream(
{"messages": history, "memory_config": memory_config, "errors": []},
stream_mode="values",
config=config
):
event_count += 1
event_start = time.time()
messages = event.get('messages')
# Capture any errors from the state
if event.get('errors'):
workflow_errors.extend(event.get('errors', []))
for msg in messages:
msg_content = msg.content
msg_role = msg.__class__.__name__.lower().replace("message", "")
outputs.append({
"role": msg_role,
"content": msg_content
})
# Extract intermediate outputs
if hasattr(msg, 'content'):
try:
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
content_to_parse = msg_content
if isinstance(msg_content, list):
for block in msg_content:
if isinstance(block, dict) and block.get('type') == 'text':
content_to_parse = block.get('text', '')
break
else:
continue # No text block found
# Try to parse content as JSON
if isinstance(content_to_parse, str):
try:
parsed = json.loads(content_to_parse)
if isinstance(parsed, dict):
# Check for single intermediate output
if '_intermediate' in parsed:
intermediate_data = parsed['_intermediate']
output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
# Check for multiple intermediate outputs (from Retrieve)
if '_intermediates' in parsed:
for intermediate_data in parsed['_intermediates']:
output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
except (json.JSONDecodeError, ValueError):
pass
except Exception as e:
logger.debug(f"Failed to extract intermediate output: {e}")
# Check for multiple intermediate outputs (from Retrieve)
if '_intermediates' in parsed:
for intermediate_data in parsed['_intermediates']:
output_key = self._create_intermediate_key(intermediate_data)
event_time = time.time() - event_start
logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s")
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
except (json.JSONDecodeError, ValueError):
pass
except Exception as e:
logger.debug(f"Failed to extract intermediate output: {e}")
event_time = time.time() - event_start
logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s")
workflow_duration = time.time() - start
session_duration = time.time() - session_start
logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s")
logger.info(f"[PERF] Total session duration: {session_duration:.4f}s")
logger.info(f"[PERF] Total events processed: {event_count}")
# Extract final answer
final_answer = ""
for messages in outputs:
if messages['role'] == 'tool':
message = messages['content']
workflow_duration = time.time() - start
session_duration = time.time() - session_start
logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s")
logger.info(f"[PERF] Total session duration: {session_duration:.4f}s")
logger.info(f"[PERF] Total events processed: {event_count}")
# Extract final answer
final_answer = ""
for messages in outputs:
if messages['role'] == 'tool':
message = messages['content']
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
if isinstance(message, list):
# Extract text from MCP content blocks
for block in message:
if isinstance(block, dict) and block.get('type') == 'text':
message = block.get('text', '')
break
else:
continue # No text block found
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
if isinstance(message, list):
# Extract text from MCP content blocks
for block in message:
if isinstance(block, dict) and block.get('type') == 'text':
message = block.get('text', '')
break
else:
continue # No text block found
try:
parsed = json.loads(message) if isinstance(message, str) else message
if isinstance(parsed, dict):
if parsed.get('status') == 'success':
summary_result = parsed.get('summary_result')
if summary_result:
final_answer = summary_result
except (json.JSONDecodeError, ValueError):
pass
try:
parsed = json.loads(message) if isinstance(message, str) else message
if isinstance(parsed, dict):
if parsed.get('status') == 'success':
summary_result = parsed.get('summary_result')
if summary_result:
final_answer = summary_result
except (json.JSONDecodeError, ValueError):
pass
# 记录成功的操作
total_duration = time.time() - start_time
# 记录成功的操作
total_duration = time.time() - start_time
# Check for workflow errors
if workflow_errors:
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
logger.warning(f"Read workflow completed with errors: {error_details}")
# Check for workflow errors
if workflow_errors:
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
logger.warning(f"Read workflow completed with errors: {error_details}")
if audit_logger:
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
success=False,
duration=total_duration,
error=error_details,
details={
"search_switch": search_switch,
"history_length": len(history),
"intermediate_outputs_count": len(intermediate_outputs),
"has_answer": bool(final_answer),
"errors": workflow_errors
}
)
# Raise error if no answer was produced
if not final_answer:
raise ValueError(f"Read workflow failed: {error_details}")
if audit_logger and not workflow_errors:
if audit_logger:
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
success=True,
success=False,
duration=total_duration,
error=error_details,
details={
"search_switch": search_switch,
"history_length": len(history),
"intermediate_outputs_count": len(intermediate_outputs),
"has_answer": bool(final_answer)
"has_answer": bool(final_answer),
"errors": workflow_errors
}
)
retrieved_content=[]
repo = ShortTermMemoryRepository(db)
if str(search_switch)!="2":
for intermediate in intermediate_outputs:
print(intermediate)
intermediate_type=intermediate['type']
if intermediate_type=="search_result":
query=intermediate['query']
raw_results=intermediate['raw_results']
reranked_results=raw_results.get('reranked_results',[])
try:
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
except Exception:
statements=[]
statements=list(set(statements))
retrieved_content.append({query:statements})
if retrieved_content==[]:
retrieved_content=''
if '信息不足,无法回答。' != str(final_answer) and str(search_switch).strip() != "2":#and retrieved_content!=[]
# 使用 upsert 方法
repo.upsert(
end_user_id=end_user_id, # 确保这个变量在作用域内
messages=ori_message,
aimessages=final_answer,
retrieved_content=retrieved_content,
search_switch=str(search_switch)
)
print("写入成功")
# Raise error if no answer was produced
if not final_answer:
raise ValueError(f"Read workflow failed: {error_details}")
if audit_logger and not workflow_errors:
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
success=True,
duration=total_duration,
details={
"search_switch": search_switch,
"history_length": len(history),
"intermediate_outputs_count": len(intermediate_outputs),
"has_answer": bool(final_answer)
}
)
retrieved_content=[]
repo = ShortTermMemoryRepository(db)
if str(search_switch)!="2":
for intermediate in intermediate_outputs:
print(intermediate)
intermediate_type=intermediate['type']
if intermediate_type=="search_result":
query=intermediate['query']
raw_results=intermediate['raw_results']
reranked_results=raw_results.get('reranked_results',[])
try:
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
except Exception:
statements=[]
statements=list(set(statements))
retrieved_content.append({query:statements})
if retrieved_content==[]:
retrieved_content=''
if '信息不足,无法回答。' != str(final_answer) and str(search_switch).strip() != "2":#and retrieved_content!=[]
# 使用 upsert 方法
repo.upsert(
end_user_id=end_user_id, # 确保这个变量在作用域内
messages=ori_message,
aimessages=final_answer,
retrieved_content=retrieved_content,
search_switch=str(search_switch)
)
print("写入成功")
return {
"answer": final_answer,
"intermediate_outputs": intermediate_outputs
}
return {
"answer": final_answer,
"intermediate_outputs": intermediate_outputs
}
def _create_intermediate_key(self, output: Dict) -> str:
"""
Create a unique key for an intermediate output to detect duplicates.

View File

@@ -267,14 +267,14 @@ class MemoryForgetService:
elif node_type_label == 'memorysummary':
node_type_label = 'summary'
# 将 Neo4j DateTime 对象转换为时间戳
# 将 Neo4j DateTime 对象转换为时间戳(毫秒)
last_access_time = result['last_access_time']
last_access_dt = convert_neo4j_datetime_to_python(last_access_time)
# 确保 datetime 带有时区信息(假定为 UTC),避免 naive datetime 导致的时区偏差
if last_access_dt:
if last_access_dt.tzinfo is None:
last_access_dt = last_access_dt.replace(tzinfo=timezone.utc)
last_access_timestamp = int(last_access_dt.timestamp())
last_access_timestamp = int(last_access_dt.timestamp() * 1000)
else:
last_access_timestamp = 0
@@ -520,7 +520,7 @@ class MemoryForgetService:
'average_activation_value': result['average_activation'],
'low_activation_nodes': result['low_activation_nodes'] or 0,
'forgetting_threshold': forgetting_threshold,
'timestamp': int(datetime.now().timestamp())
'timestamp': int(datetime.now().timestamp() * 1000)
}
else:
activation_metrics = {
@@ -530,7 +530,7 @@ class MemoryForgetService:
'average_activation_value': None,
'low_activation_nodes': 0,
'forgetting_threshold': forgetting_threshold,
'timestamp': int(datetime.now().timestamp())
'timestamp': int(datetime.now().timestamp() * 1000)
}
# 收集节点类型分布
@@ -620,7 +620,7 @@ class MemoryForgetService:
'merged_count': record.merged_count,
'average_activation': record.average_activation_value,
'total_nodes': record.total_nodes,
'execution_time': int(record.execution_time.timestamp())
'execution_time': int(record.execution_time.timestamp() * 1000)
})
api_logger.info(f"成功获取最近 {len(recent_trends)} 个日期的历史趋势数据")
@@ -661,7 +661,7 @@ class MemoryForgetService:
'node_distribution': node_distribution,
'recent_trends': recent_trends,
'pending_nodes': pending_nodes,
'timestamp': int(datetime.now().timestamp())
'timestamp': int(datetime.now().timestamp() * 1000)
}
api_logger.info(