读取的接口,去掉全局锁

This commit is contained in:
lixinyue
2026-01-15 16:47:55 +08:00
parent dda61679bd
commit d9fb8edaa9

View File

@@ -9,7 +9,7 @@ import os
import re import re
import time import time
import uuid import uuid
from threading import Lock
from typing import Any, AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
import redis import redis
@@ -51,9 +51,7 @@ _neo4j_connector = Neo4jConnector()
class MemoryAgentService: class MemoryAgentService:
"""Service for memory agent operations""" """Service for memory agent operations"""
def __init__(self):
self.user_locks: Dict[str, Lock] = {}
self.locks_lock = Lock()
def writer_messages_deal(self,messages,start_time,group_id,config_id,message): def writer_messages_deal(self,messages,start_time,group_id,config_id,message):
messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '') messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '')
@@ -83,12 +81,7 @@ class MemoryAgentService:
raise ValueError(f"写入失败: {messages}") raise ValueError(f"写入失败: {messages}")
def get_group_lock(self, group_id: str) -> Lock:
"""Get lock for specific group to prevent concurrent processing"""
with self.locks_lock:
if group_id not in self.user_locks:
self.user_locks[group_id] = Lock()
return self.user_locks[group_id]
def extract_tool_call_info(self, event: Dict) -> bool: def extract_tool_call_info(self, event: Dict) -> bool:
"""Extract tool call information from event""" """Extract tool call information from event"""
@@ -417,241 +410,236 @@ class MemoryAgentService:
except ImportError: except ImportError:
audit_logger = None audit_logger = None
# Get group lock to prevent concurrent processing try:
group_lock = self.get_group_lock(group_id) config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg)
with group_lock: # Log failed operation
# Step 1: Load configuration from database only if audit_logger:
try: duration = time.time() - start_time
config_service = MemoryConfigService(db) audit_logger.log_operation(
memory_config = config_service.load_memory_config( operation="READ",
config_id=config_id, config_id=config_id,
service_name="MemoryAgentService" group_id=group_id,
success=False,
duration=duration,
error=error_msg
) )
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg)
# Log failed operation raise ValueError(error_msg)
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
success=False,
duration=duration,
error=error_msg
)
raise ValueError(error_msg) # Step 2: Prepare history
history.append({"role": "user", "content": message})
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
# Step 2: Prepare history # Step 3: Initialize MCP client and execute read workflow
history.append({"role": "user", "content": message}) mcp_config = get_mcp_server_config()
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}") client = MultiServerMCPClient(mcp_config)
# Step 3: Initialize MCP client and execute read workflow async with client.session('data_flow') as session:
mcp_config = get_mcp_server_config() session_start = time.time()
client = MultiServerMCPClient(mcp_config) logger.debug("Connected to MCP Server: data_flow")
async with client.session('data_flow') as session: tools_start = time.time()
session_start = time.time() tools = await load_mcp_tools(session)
logger.debug("Connected to MCP Server: data_flow") tools_time = time.time() - tools_start
logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s")
tools_start = time.time()
tools = await load_mcp_tools(session)
tools_time = time.time() - tools_start
logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s")
outputs = []
intermediate_outputs = []
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
# Pass memory_config to the graph workflow outputs = []
graph_start = time.time() intermediate_outputs = []
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph: seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
graph_init_time = time.time() - graph_start
logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s")
start = time.time()
config = {"configurable": {"thread_id": group_id}}
workflow_errors = [] # Track errors from workflow
event_count = 0
async for event in graph.astream(
{"messages": history, "memory_config": memory_config, "errors": []},
stream_mode="values",
config=config
):
event_count += 1
event_start = time.time()
messages = event.get('messages')
# Capture any errors from the state
if event.get('errors'):
workflow_errors.extend(event.get('errors', []))
for msg in messages: # Pass memory_config to the graph workflow
msg_content = msg.content graph_start = time.time()
msg_role = msg.__class__.__name__.lower().replace("message", "") async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
outputs.append({ graph_init_time = time.time() - graph_start
"role": msg_role, logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s")
"content": msg_content
})
# Extract intermediate outputs start = time.time()
if hasattr(msg, 'content'): config = {"configurable": {"thread_id": group_id}}
try: workflow_errors = [] # Track errors from workflow
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
content_to_parse = msg_content
if isinstance(msg_content, list):
for block in msg_content:
if isinstance(block, dict) and block.get('type') == 'text':
content_to_parse = block.get('text', '')
break
else:
continue # No text block found
# Try to parse content as JSON event_count = 0
if isinstance(content_to_parse, str): async for event in graph.astream(
try: {"messages": history, "memory_config": memory_config, "errors": []},
parsed = json.loads(content_to_parse) stream_mode="values",
if isinstance(parsed, dict): config=config
# Check for single intermediate output ):
if '_intermediate' in parsed: event_count += 1
intermediate_data = parsed['_intermediate'] event_start = time.time()
messages = event.get('messages')
# Capture any errors from the state
if event.get('errors'):
workflow_errors.extend(event.get('errors', []))
for msg in messages:
msg_content = msg.content
msg_role = msg.__class__.__name__.lower().replace("message", "")
outputs.append({
"role": msg_role,
"content": msg_content
})
# Extract intermediate outputs
if hasattr(msg, 'content'):
try:
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
content_to_parse = msg_content
if isinstance(msg_content, list):
for block in msg_content:
if isinstance(block, dict) and block.get('type') == 'text':
content_to_parse = block.get('text', '')
break
else:
continue # No text block found
# Try to parse content as JSON
if isinstance(content_to_parse, str):
try:
parsed = json.loads(content_to_parse)
if isinstance(parsed, dict):
# Check for single intermediate output
if '_intermediate' in parsed:
intermediate_data = parsed['_intermediate']
output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
# Check for multiple intermediate outputs (from Retrieve)
if '_intermediates' in parsed:
for intermediate_data in parsed['_intermediates']:
output_key = self._create_intermediate_key(intermediate_data) output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates: if output_key not in seen_intermediates:
seen_intermediates.add(output_key) seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
except (json.JSONDecodeError, ValueError):
pass
except Exception as e:
logger.debug(f"Failed to extract intermediate output: {e}")
# Check for multiple intermediate outputs (from Retrieve) event_time = time.time() - event_start
if '_intermediates' in parsed: logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s")
for intermediate_data in parsed['_intermediates']:
output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates: workflow_duration = time.time() - start
seen_intermediates.add(output_key) session_duration = time.time() - session_start
intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s")
except (json.JSONDecodeError, ValueError): logger.info(f"[PERF] Total session duration: {session_duration:.4f}s")
pass logger.info(f"[PERF] Total events processed: {event_count}")
except Exception as e: # Extract final answer
logger.debug(f"Failed to extract intermediate output: {e}") final_answer = ""
for messages in outputs:
event_time = time.time() - event_start if messages['role'] == 'tool':
logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s") message = messages['content']
workflow_duration = time.time() - start # Handle MCP content format: [{'type': 'text', 'text': '...'}]
session_duration = time.time() - session_start if isinstance(message, list):
logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s") # Extract text from MCP content blocks
logger.info(f"[PERF] Total session duration: {session_duration:.4f}s") for block in message:
logger.info(f"[PERF] Total events processed: {event_count}") if isinstance(block, dict) and block.get('type') == 'text':
# Extract final answer message = block.get('text', '')
final_answer = "" break
for messages in outputs: else:
if messages['role'] == 'tool': continue # No text block found
message = messages['content']
# Handle MCP content format: [{'type': 'text', 'text': '...'}] try:
if isinstance(message, list): parsed = json.loads(message) if isinstance(message, str) else message
# Extract text from MCP content blocks if isinstance(parsed, dict):
for block in message: if parsed.get('status') == 'success':
if isinstance(block, dict) and block.get('type') == 'text': summary_result = parsed.get('summary_result')
message = block.get('text', '') if summary_result:
break final_answer = summary_result
else: except (json.JSONDecodeError, ValueError):
continue # No text block found pass
try: # 记录成功的操作
parsed = json.loads(message) if isinstance(message, str) else message total_duration = time.time() - start_time
if isinstance(parsed, dict):
if parsed.get('status') == 'success':
summary_result = parsed.get('summary_result')
if summary_result:
final_answer = summary_result
except (json.JSONDecodeError, ValueError):
pass
# 记录成功的操作 # Check for workflow errors
total_duration = time.time() - start_time if workflow_errors:
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
logger.warning(f"Read workflow completed with errors: {error_details}")
# Check for workflow errors if audit_logger:
if workflow_errors:
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
logger.warning(f"Read workflow completed with errors: {error_details}")
if audit_logger:
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
success=False,
duration=total_duration,
error=error_details,
details={
"search_switch": search_switch,
"history_length": len(history),
"intermediate_outputs_count": len(intermediate_outputs),
"has_answer": bool(final_answer),
"errors": workflow_errors
}
)
# Raise error if no answer was produced
if not final_answer:
raise ValueError(f"Read workflow failed: {error_details}")
if audit_logger and not workflow_errors:
audit_logger.log_operation( audit_logger.log_operation(
operation="READ", operation="READ",
config_id=config_id, config_id=config_id,
group_id=group_id, group_id=group_id,
success=True, success=False,
duration=total_duration, duration=total_duration,
error=error_details,
details={ details={
"search_switch": search_switch, "search_switch": search_switch,
"history_length": len(history), "history_length": len(history),
"intermediate_outputs_count": len(intermediate_outputs), "intermediate_outputs_count": len(intermediate_outputs),
"has_answer": bool(final_answer) "has_answer": bool(final_answer),
"errors": workflow_errors
} }
) )
retrieved_content=[]
repo = ShortTermMemoryRepository(db) # Raise error if no answer was produced
if str(search_switch)!="2": if not final_answer:
for intermediate in intermediate_outputs: raise ValueError(f"Read workflow failed: {error_details}")
print(intermediate)
intermediate_type=intermediate['type'] if audit_logger and not workflow_errors:
if intermediate_type=="search_result": audit_logger.log_operation(
query=intermediate['query'] operation="READ",
raw_results=intermediate['raw_results'] config_id=config_id,
reranked_results=raw_results.get('reranked_results',[]) group_id=group_id,
try: success=True,
statements=[statement['statement'] for statement in reranked_results.get('statements', [])] duration=total_duration,
except Exception: details={
statements=[] "search_switch": search_switch,
statements=list(set(statements)) "history_length": len(history),
retrieved_content.append({query:statements}) "intermediate_outputs_count": len(intermediate_outputs),
if retrieved_content==[]: "has_answer": bool(final_answer)
retrieved_content='' }
if '信息不足,无法回答。' != str(final_answer) and str(search_switch).strip() != "2":#and retrieved_content!=[] )
# 使用 upsert 方法 retrieved_content=[]
repo.upsert( repo = ShortTermMemoryRepository(db)
end_user_id=end_user_id, # 确保这个变量在作用域内 if str(search_switch)!="2":
messages=ori_message, for intermediate in intermediate_outputs:
aimessages=final_answer, print(intermediate)
retrieved_content=retrieved_content, intermediate_type=intermediate['type']
search_switch=str(search_switch) if intermediate_type=="search_result":
) query=intermediate['query']
print("写入成功") raw_results=intermediate['raw_results']
reranked_results=raw_results.get('reranked_results',[])
try:
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
except Exception:
statements=[]
statements=list(set(statements))
retrieved_content.append({query:statements})
if retrieved_content==[]:
retrieved_content=''
if '信息不足,无法回答。' != str(final_answer) and str(search_switch).strip() != "2":#and retrieved_content!=[]
# 使用 upsert 方法
repo.upsert(
end_user_id=end_user_id, # 确保这个变量在作用域内
messages=ori_message,
aimessages=final_answer,
retrieved_content=retrieved_content,
search_switch=str(search_switch)
)
print("写入成功")
return { return {
"answer": final_answer, "answer": final_answer,
"intermediate_outputs": intermediate_outputs "intermediate_outputs": intermediate_outputs
} }
def _create_intermediate_key(self, output: Dict) -> str: def _create_intermediate_key(self, output: Dict) -> str:
""" """
Create a unique key for an intermediate output to detect duplicates. Create a unique key for an intermediate output to detect duplicates.