读取的接口,去掉全局锁

This commit is contained in:
lixinyue
2026-01-15 16:47:55 +08:00
parent dda61679bd
commit d9fb8edaa9

View File

@@ -9,7 +9,7 @@ import os
import re import re
import time import time
import uuid import uuid
from threading import Lock
from typing import Any, AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
import redis import redis
@@ -51,9 +51,7 @@ _neo4j_connector = Neo4jConnector()
class MemoryAgentService: class MemoryAgentService:
"""Service for memory agent operations""" """Service for memory agent operations"""
def __init__(self):
self.user_locks: Dict[str, Lock] = {}
self.locks_lock = Lock()
def writer_messages_deal(self,messages,start_time,group_id,config_id,message): def writer_messages_deal(self,messages,start_time,group_id,config_id,message):
messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '') messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '')
@@ -83,12 +81,7 @@ class MemoryAgentService:
raise ValueError(f"写入失败: {messages}") raise ValueError(f"写入失败: {messages}")
def get_group_lock(self, group_id: str) -> Lock:
"""Get lock for specific group to prevent concurrent processing"""
with self.locks_lock:
if group_id not in self.user_locks:
self.user_locks[group_id] = Lock()
return self.user_locks[group_id]
def extract_tool_call_info(self, event: Dict) -> bool: def extract_tool_call_info(self, event: Dict) -> bool:
"""Extract tool call information from event""" """Extract tool call information from event"""
@@ -417,240 +410,235 @@ class MemoryAgentService:
except ImportError: except ImportError:
audit_logger = None audit_logger = None
# Get group lock to prevent concurrent processing try:
group_lock = self.get_group_lock(group_id) config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg)
with group_lock: # Log failed operation
# Step 1: Load configuration from database only if audit_logger:
try: duration = time.time() - start_time
config_service = MemoryConfigService(db) audit_logger.log_operation(
memory_config = config_service.load_memory_config( operation="READ",
config_id=config_id, config_id=config_id,
service_name="MemoryAgentService" group_id=group_id,
success=False,
duration=duration,
error=error_msg
) )
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg)
# Log failed operation raise ValueError(error_msg)
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
success=False,
duration=duration,
error=error_msg
)
raise ValueError(error_msg) # Step 2: Prepare history
history.append({"role": "user", "content": message})
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
# Step 2: Prepare history # Step 3: Initialize MCP client and execute read workflow
history.append({"role": "user", "content": message}) mcp_config = get_mcp_server_config()
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}") client = MultiServerMCPClient(mcp_config)
# Step 3: Initialize MCP client and execute read workflow async with client.session('data_flow') as session:
mcp_config = get_mcp_server_config() session_start = time.time()
client = MultiServerMCPClient(mcp_config) logger.debug("Connected to MCP Server: data_flow")
async with client.session('data_flow') as session: tools_start = time.time()
session_start = time.time() tools = await load_mcp_tools(session)
logger.debug("Connected to MCP Server: data_flow") tools_time = time.time() - tools_start
logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s")
tools_start = time.time() outputs = []
tools = await load_mcp_tools(session) intermediate_outputs = []
tools_time = time.time() - tools_start seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s")
outputs = [] # Pass memory_config to the graph workflow
intermediate_outputs = [] graph_start = time.time()
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
graph_init_time = time.time() - graph_start
logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s")
# Pass memory_config to the graph workflow start = time.time()
graph_start = time.time() config = {"configurable": {"thread_id": group_id}}
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph: workflow_errors = [] # Track errors from workflow
graph_init_time = time.time() - graph_start
logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s")
start = time.time() event_count = 0
config = {"configurable": {"thread_id": group_id}} async for event in graph.astream(
workflow_errors = [] # Track errors from workflow {"messages": history, "memory_config": memory_config, "errors": []},
stream_mode="values",
config=config
):
event_count += 1
event_start = time.time()
messages = event.get('messages')
# Capture any errors from the state
if event.get('errors'):
workflow_errors.extend(event.get('errors', []))
event_count = 0 for msg in messages:
async for event in graph.astream( msg_content = msg.content
{"messages": history, "memory_config": memory_config, "errors": []}, msg_role = msg.__class__.__name__.lower().replace("message", "")
stream_mode="values", outputs.append({
config=config "role": msg_role,
): "content": msg_content
event_count += 1 })
event_start = time.time()
messages = event.get('messages')
# Capture any errors from the state
if event.get('errors'):
workflow_errors.extend(event.get('errors', []))
for msg in messages: # Extract intermediate outputs
msg_content = msg.content if hasattr(msg, 'content'):
msg_role = msg.__class__.__name__.lower().replace("message", "") try:
outputs.append({ # Handle MCP content format: [{'type': 'text', 'text': '...'}]
"role": msg_role, content_to_parse = msg_content
"content": msg_content if isinstance(msg_content, list):
}) for block in msg_content:
if isinstance(block, dict) and block.get('type') == 'text':
content_to_parse = block.get('text', '')
break
else:
continue # No text block found
# Extract intermediate outputs # Try to parse content as JSON
if hasattr(msg, 'content'): if isinstance(content_to_parse, str):
try: try:
# Handle MCP content format: [{'type': 'text', 'text': '...'}] parsed = json.loads(content_to_parse)
content_to_parse = msg_content if isinstance(parsed, dict):
if isinstance(msg_content, list): # Check for single intermediate output
for block in msg_content: if '_intermediate' in parsed:
if isinstance(block, dict) and block.get('type') == 'text': intermediate_data = parsed['_intermediate']
content_to_parse = block.get('text', '') output_key = self._create_intermediate_key(intermediate_data)
break
else:
continue # No text block found
# Try to parse content as JSON if output_key not in seen_intermediates:
if isinstance(content_to_parse, str): seen_intermediates.add(output_key)
try: intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
parsed = json.loads(content_to_parse)
if isinstance(parsed, dict): # Check for multiple intermediate outputs (from Retrieve)
# Check for single intermediate output if '_intermediates' in parsed:
if '_intermediate' in parsed: for intermediate_data in parsed['_intermediates']:
intermediate_data = parsed['_intermediate']
output_key = self._create_intermediate_key(intermediate_data) output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates: if output_key not in seen_intermediates:
seen_intermediates.add(output_key) seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
except (json.JSONDecodeError, ValueError):
pass
except Exception as e:
logger.debug(f"Failed to extract intermediate output: {e}")
# Check for multiple intermediate outputs (from Retrieve) event_time = time.time() - event_start
if '_intermediates' in parsed: logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s")
for intermediate_data in parsed['_intermediates']:
output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates: workflow_duration = time.time() - start
seen_intermediates.add(output_key) session_duration = time.time() - session_start
intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s")
except (json.JSONDecodeError, ValueError): logger.info(f"[PERF] Total session duration: {session_duration:.4f}s")
pass logger.info(f"[PERF] Total events processed: {event_count}")
except Exception as e: # Extract final answer
logger.debug(f"Failed to extract intermediate output: {e}") final_answer = ""
for messages in outputs:
if messages['role'] == 'tool':
message = messages['content']
event_time = time.time() - event_start # Handle MCP content format: [{'type': 'text', 'text': '...'}]
logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s") if isinstance(message, list):
# Extract text from MCP content blocks
for block in message:
if isinstance(block, dict) and block.get('type') == 'text':
message = block.get('text', '')
break
else:
continue # No text block found
workflow_duration = time.time() - start try:
session_duration = time.time() - session_start parsed = json.loads(message) if isinstance(message, str) else message
logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s") if isinstance(parsed, dict):
logger.info(f"[PERF] Total session duration: {session_duration:.4f}s") if parsed.get('status') == 'success':
logger.info(f"[PERF] Total events processed: {event_count}") summary_result = parsed.get('summary_result')
# Extract final answer if summary_result:
final_answer = "" final_answer = summary_result
for messages in outputs: except (json.JSONDecodeError, ValueError):
if messages['role'] == 'tool': pass
message = messages['content']
# Handle MCP content format: [{'type': 'text', 'text': '...'}] # 记录成功的操作
if isinstance(message, list): total_duration = time.time() - start_time
# Extract text from MCP content blocks
for block in message:
if isinstance(block, dict) and block.get('type') == 'text':
message = block.get('text', '')
break
else:
continue # No text block found
try: # Check for workflow errors
parsed = json.loads(message) if isinstance(message, str) else message if workflow_errors:
if isinstance(parsed, dict): error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
if parsed.get('status') == 'success': logger.warning(f"Read workflow completed with errors: {error_details}")
summary_result = parsed.get('summary_result')
if summary_result:
final_answer = summary_result
except (json.JSONDecodeError, ValueError):
pass
# 记录成功的操作 if audit_logger:
total_duration = time.time() - start_time
# Check for workflow errors
if workflow_errors:
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
logger.warning(f"Read workflow completed with errors: {error_details}")
if audit_logger:
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
success=False,
duration=total_duration,
error=error_details,
details={
"search_switch": search_switch,
"history_length": len(history),
"intermediate_outputs_count": len(intermediate_outputs),
"has_answer": bool(final_answer),
"errors": workflow_errors
}
)
# Raise error if no answer was produced
if not final_answer:
raise ValueError(f"Read workflow failed: {error_details}")
if audit_logger and not workflow_errors:
audit_logger.log_operation( audit_logger.log_operation(
operation="READ", operation="READ",
config_id=config_id, config_id=config_id,
group_id=group_id, group_id=group_id,
success=True, success=False,
duration=total_duration, duration=total_duration,
error=error_details,
details={ details={
"search_switch": search_switch, "search_switch": search_switch,
"history_length": len(history), "history_length": len(history),
"intermediate_outputs_count": len(intermediate_outputs), "intermediate_outputs_count": len(intermediate_outputs),
"has_answer": bool(final_answer) "has_answer": bool(final_answer),
"errors": workflow_errors
} }
) )
retrieved_content=[]
repo = ShortTermMemoryRepository(db) # Raise error if no answer was produced
if str(search_switch)!="2": if not final_answer:
for intermediate in intermediate_outputs: raise ValueError(f"Read workflow failed: {error_details}")
print(intermediate)
intermediate_type=intermediate['type'] if audit_logger and not workflow_errors:
if intermediate_type=="search_result": audit_logger.log_operation(
query=intermediate['query'] operation="READ",
raw_results=intermediate['raw_results'] config_id=config_id,
reranked_results=raw_results.get('reranked_results',[]) group_id=group_id,
try: success=True,
statements=[statement['statement'] for statement in reranked_results.get('statements', [])] duration=total_duration,
except Exception: details={
statements=[] "search_switch": search_switch,
statements=list(set(statements)) "history_length": len(history),
retrieved_content.append({query:statements}) "intermediate_outputs_count": len(intermediate_outputs),
if retrieved_content==[]: "has_answer": bool(final_answer)
retrieved_content='' }
if '信息不足,无法回答。' != str(final_answer) and str(search_switch).strip() != "2":#and retrieved_content!=[] )
# 使用 upsert 方法 retrieved_content=[]
repo.upsert( repo = ShortTermMemoryRepository(db)
end_user_id=end_user_id, # 确保这个变量在作用域内 if str(search_switch)!="2":
messages=ori_message, for intermediate in intermediate_outputs:
aimessages=final_answer, print(intermediate)
retrieved_content=retrieved_content, intermediate_type=intermediate['type']
search_switch=str(search_switch) if intermediate_type=="search_result":
) query=intermediate['query']
print("写入成功") raw_results=intermediate['raw_results']
reranked_results=raw_results.get('reranked_results',[])
try:
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
except Exception:
statements=[]
statements=list(set(statements))
retrieved_content.append({query:statements})
if retrieved_content==[]:
retrieved_content=''
if '信息不足,无法回答。' != str(final_answer) and str(search_switch).strip() != "2":#and retrieved_content!=[]
# 使用 upsert 方法
repo.upsert(
end_user_id=end_user_id, # 确保这个变量在作用域内
messages=ori_message,
aimessages=final_answer,
retrieved_content=retrieved_content,
search_switch=str(search_switch)
)
print("写入成功")
return { return {
"answer": final_answer, "answer": final_answer,
"intermediate_outputs": intermediate_outputs "intermediate_outputs": intermediate_outputs
} }
def _create_intermediate_key(self, output: Dict) -> str: def _create_intermediate_key(self, output: Dict) -> str:
""" """