读取的接口,去掉全局锁
This commit is contained in:
@@ -9,7 +9,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from threading import Lock
|
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
|
|
||||||
import redis
|
import redis
|
||||||
@@ -51,9 +51,7 @@ _neo4j_connector = Neo4jConnector()
|
|||||||
class MemoryAgentService:
|
class MemoryAgentService:
|
||||||
"""Service for memory agent operations"""
|
"""Service for memory agent operations"""
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.user_locks: Dict[str, Lock] = {}
|
|
||||||
self.locks_lock = Lock()
|
|
||||||
|
|
||||||
def writer_messages_deal(self,messages,start_time,group_id,config_id,message):
|
def writer_messages_deal(self,messages,start_time,group_id,config_id,message):
|
||||||
messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '')
|
messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '')
|
||||||
@@ -83,12 +81,7 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
raise ValueError(f"写入失败: {messages}")
|
raise ValueError(f"写入失败: {messages}")
|
||||||
|
|
||||||
def get_group_lock(self, group_id: str) -> Lock:
|
|
||||||
"""Get lock for specific group to prevent concurrent processing"""
|
|
||||||
with self.locks_lock:
|
|
||||||
if group_id not in self.user_locks:
|
|
||||||
self.user_locks[group_id] = Lock()
|
|
||||||
return self.user_locks[group_id]
|
|
||||||
|
|
||||||
def extract_tool_call_info(self, event: Dict) -> bool:
|
def extract_tool_call_info(self, event: Dict) -> bool:
|
||||||
"""Extract tool call information from event"""
|
"""Extract tool call information from event"""
|
||||||
@@ -417,241 +410,236 @@ class MemoryAgentService:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
audit_logger = None
|
audit_logger = None
|
||||||
|
|
||||||
# Get group lock to prevent concurrent processing
|
try:
|
||||||
group_lock = self.get_group_lock(group_id)
|
config_service = MemoryConfigService(db)
|
||||||
|
memory_config = config_service.load_memory_config(
|
||||||
|
config_id=config_id,
|
||||||
|
service_name="MemoryAgentService"
|
||||||
|
)
|
||||||
|
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||||||
|
except ConfigurationError as e:
|
||||||
|
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
|
||||||
with group_lock:
|
# Log failed operation
|
||||||
# Step 1: Load configuration from database only
|
if audit_logger:
|
||||||
try:
|
duration = time.time() - start_time
|
||||||
config_service = MemoryConfigService(db)
|
audit_logger.log_operation(
|
||||||
memory_config = config_service.load_memory_config(
|
operation="READ",
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
service_name="MemoryAgentService"
|
group_id=group_id,
|
||||||
|
success=False,
|
||||||
|
duration=duration,
|
||||||
|
error=error_msg
|
||||||
)
|
)
|
||||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
|
||||||
except ConfigurationError as e:
|
|
||||||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
|
|
||||||
# Log failed operation
|
raise ValueError(error_msg)
|
||||||
if audit_logger:
|
|
||||||
duration = time.time() - start_time
|
|
||||||
audit_logger.log_operation(
|
|
||||||
operation="READ",
|
|
||||||
config_id=config_id,
|
|
||||||
group_id=group_id,
|
|
||||||
success=False,
|
|
||||||
duration=duration,
|
|
||||||
error=error_msg
|
|
||||||
)
|
|
||||||
|
|
||||||
raise ValueError(error_msg)
|
# Step 2: Prepare history
|
||||||
|
history.append({"role": "user", "content": message})
|
||||||
|
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
|
||||||
|
|
||||||
# Step 2: Prepare history
|
# Step 3: Initialize MCP client and execute read workflow
|
||||||
history.append({"role": "user", "content": message})
|
mcp_config = get_mcp_server_config()
|
||||||
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
|
client = MultiServerMCPClient(mcp_config)
|
||||||
|
|
||||||
# Step 3: Initialize MCP client and execute read workflow
|
async with client.session('data_flow') as session:
|
||||||
mcp_config = get_mcp_server_config()
|
session_start = time.time()
|
||||||
client = MultiServerMCPClient(mcp_config)
|
logger.debug("Connected to MCP Server: data_flow")
|
||||||
|
|
||||||
async with client.session('data_flow') as session:
|
tools_start = time.time()
|
||||||
session_start = time.time()
|
tools = await load_mcp_tools(session)
|
||||||
logger.debug("Connected to MCP Server: data_flow")
|
tools_time = time.time() - tools_start
|
||||||
|
logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s")
|
||||||
tools_start = time.time()
|
|
||||||
tools = await load_mcp_tools(session)
|
|
||||||
tools_time = time.time() - tools_start
|
|
||||||
logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s")
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
intermediate_outputs = []
|
|
||||||
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
|
|
||||||
|
|
||||||
# Pass memory_config to the graph workflow
|
outputs = []
|
||||||
graph_start = time.time()
|
intermediate_outputs = []
|
||||||
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
|
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
|
||||||
graph_init_time = time.time() - graph_start
|
|
||||||
logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s")
|
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
config = {"configurable": {"thread_id": group_id}}
|
|
||||||
workflow_errors = [] # Track errors from workflow
|
|
||||||
|
|
||||||
event_count = 0
|
|
||||||
async for event in graph.astream(
|
|
||||||
{"messages": history, "memory_config": memory_config, "errors": []},
|
|
||||||
stream_mode="values",
|
|
||||||
config=config
|
|
||||||
):
|
|
||||||
event_count += 1
|
|
||||||
event_start = time.time()
|
|
||||||
messages = event.get('messages')
|
|
||||||
# Capture any errors from the state
|
|
||||||
if event.get('errors'):
|
|
||||||
workflow_errors.extend(event.get('errors', []))
|
|
||||||
|
|
||||||
for msg in messages:
|
# Pass memory_config to the graph workflow
|
||||||
msg_content = msg.content
|
graph_start = time.time()
|
||||||
msg_role = msg.__class__.__name__.lower().replace("message", "")
|
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
|
||||||
outputs.append({
|
graph_init_time = time.time() - graph_start
|
||||||
"role": msg_role,
|
logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s")
|
||||||
"content": msg_content
|
|
||||||
})
|
|
||||||
|
|
||||||
# Extract intermediate outputs
|
start = time.time()
|
||||||
if hasattr(msg, 'content'):
|
config = {"configurable": {"thread_id": group_id}}
|
||||||
try:
|
workflow_errors = [] # Track errors from workflow
|
||||||
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
|
||||||
content_to_parse = msg_content
|
|
||||||
if isinstance(msg_content, list):
|
|
||||||
for block in msg_content:
|
|
||||||
if isinstance(block, dict) and block.get('type') == 'text':
|
|
||||||
content_to_parse = block.get('text', '')
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
continue # No text block found
|
|
||||||
|
|
||||||
# Try to parse content as JSON
|
event_count = 0
|
||||||
if isinstance(content_to_parse, str):
|
async for event in graph.astream(
|
||||||
try:
|
{"messages": history, "memory_config": memory_config, "errors": []},
|
||||||
parsed = json.loads(content_to_parse)
|
stream_mode="values",
|
||||||
if isinstance(parsed, dict):
|
config=config
|
||||||
# Check for single intermediate output
|
):
|
||||||
if '_intermediate' in parsed:
|
event_count += 1
|
||||||
intermediate_data = parsed['_intermediate']
|
event_start = time.time()
|
||||||
|
messages = event.get('messages')
|
||||||
|
# Capture any errors from the state
|
||||||
|
if event.get('errors'):
|
||||||
|
workflow_errors.extend(event.get('errors', []))
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
msg_content = msg.content
|
||||||
|
msg_role = msg.__class__.__name__.lower().replace("message", "")
|
||||||
|
outputs.append({
|
||||||
|
"role": msg_role,
|
||||||
|
"content": msg_content
|
||||||
|
})
|
||||||
|
|
||||||
|
# Extract intermediate outputs
|
||||||
|
if hasattr(msg, 'content'):
|
||||||
|
try:
|
||||||
|
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
||||||
|
content_to_parse = msg_content
|
||||||
|
if isinstance(msg_content, list):
|
||||||
|
for block in msg_content:
|
||||||
|
if isinstance(block, dict) and block.get('type') == 'text':
|
||||||
|
content_to_parse = block.get('text', '')
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
continue # No text block found
|
||||||
|
|
||||||
|
# Try to parse content as JSON
|
||||||
|
if isinstance(content_to_parse, str):
|
||||||
|
try:
|
||||||
|
parsed = json.loads(content_to_parse)
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
# Check for single intermediate output
|
||||||
|
if '_intermediate' in parsed:
|
||||||
|
intermediate_data = parsed['_intermediate']
|
||||||
|
output_key = self._create_intermediate_key(intermediate_data)
|
||||||
|
|
||||||
|
if output_key not in seen_intermediates:
|
||||||
|
seen_intermediates.add(output_key)
|
||||||
|
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
||||||
|
|
||||||
|
# Check for multiple intermediate outputs (from Retrieve)
|
||||||
|
if '_intermediates' in parsed:
|
||||||
|
for intermediate_data in parsed['_intermediates']:
|
||||||
output_key = self._create_intermediate_key(intermediate_data)
|
output_key = self._create_intermediate_key(intermediate_data)
|
||||||
|
|
||||||
if output_key not in seen_intermediates:
|
if output_key not in seen_intermediates:
|
||||||
seen_intermediates.add(output_key)
|
seen_intermediates.add(output_key)
|
||||||
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Failed to extract intermediate output: {e}")
|
||||||
|
|
||||||
# Check for multiple intermediate outputs (from Retrieve)
|
event_time = time.time() - event_start
|
||||||
if '_intermediates' in parsed:
|
logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s")
|
||||||
for intermediate_data in parsed['_intermediates']:
|
|
||||||
output_key = self._create_intermediate_key(intermediate_data)
|
|
||||||
|
|
||||||
if output_key not in seen_intermediates:
|
workflow_duration = time.time() - start
|
||||||
seen_intermediates.add(output_key)
|
session_duration = time.time() - session_start
|
||||||
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s")
|
||||||
except (json.JSONDecodeError, ValueError):
|
logger.info(f"[PERF] Total session duration: {session_duration:.4f}s")
|
||||||
pass
|
logger.info(f"[PERF] Total events processed: {event_count}")
|
||||||
except Exception as e:
|
# Extract final answer
|
||||||
logger.debug(f"Failed to extract intermediate output: {e}")
|
final_answer = ""
|
||||||
|
for messages in outputs:
|
||||||
event_time = time.time() - event_start
|
if messages['role'] == 'tool':
|
||||||
logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s")
|
message = messages['content']
|
||||||
|
|
||||||
workflow_duration = time.time() - start
|
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
||||||
session_duration = time.time() - session_start
|
if isinstance(message, list):
|
||||||
logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s")
|
# Extract text from MCP content blocks
|
||||||
logger.info(f"[PERF] Total session duration: {session_duration:.4f}s")
|
for block in message:
|
||||||
logger.info(f"[PERF] Total events processed: {event_count}")
|
if isinstance(block, dict) and block.get('type') == 'text':
|
||||||
# Extract final answer
|
message = block.get('text', '')
|
||||||
final_answer = ""
|
break
|
||||||
for messages in outputs:
|
else:
|
||||||
if messages['role'] == 'tool':
|
continue # No text block found
|
||||||
message = messages['content']
|
|
||||||
|
|
||||||
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
try:
|
||||||
if isinstance(message, list):
|
parsed = json.loads(message) if isinstance(message, str) else message
|
||||||
# Extract text from MCP content blocks
|
if isinstance(parsed, dict):
|
||||||
for block in message:
|
if parsed.get('status') == 'success':
|
||||||
if isinstance(block, dict) and block.get('type') == 'text':
|
summary_result = parsed.get('summary_result')
|
||||||
message = block.get('text', '')
|
if summary_result:
|
||||||
break
|
final_answer = summary_result
|
||||||
else:
|
except (json.JSONDecodeError, ValueError):
|
||||||
continue # No text block found
|
pass
|
||||||
|
|
||||||
try:
|
# 记录成功的操作
|
||||||
parsed = json.loads(message) if isinstance(message, str) else message
|
total_duration = time.time() - start_time
|
||||||
if isinstance(parsed, dict):
|
|
||||||
if parsed.get('status') == 'success':
|
|
||||||
summary_result = parsed.get('summary_result')
|
|
||||||
if summary_result:
|
|
||||||
final_answer = summary_result
|
|
||||||
except (json.JSONDecodeError, ValueError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 记录成功的操作
|
# Check for workflow errors
|
||||||
total_duration = time.time() - start_time
|
if workflow_errors:
|
||||||
|
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
|
||||||
|
logger.warning(f"Read workflow completed with errors: {error_details}")
|
||||||
|
|
||||||
# Check for workflow errors
|
if audit_logger:
|
||||||
if workflow_errors:
|
|
||||||
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
|
|
||||||
logger.warning(f"Read workflow completed with errors: {error_details}")
|
|
||||||
|
|
||||||
if audit_logger:
|
|
||||||
audit_logger.log_operation(
|
|
||||||
operation="READ",
|
|
||||||
config_id=config_id,
|
|
||||||
group_id=group_id,
|
|
||||||
success=False,
|
|
||||||
duration=total_duration,
|
|
||||||
error=error_details,
|
|
||||||
details={
|
|
||||||
"search_switch": search_switch,
|
|
||||||
"history_length": len(history),
|
|
||||||
"intermediate_outputs_count": len(intermediate_outputs),
|
|
||||||
"has_answer": bool(final_answer),
|
|
||||||
"errors": workflow_errors
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Raise error if no answer was produced
|
|
||||||
if not final_answer:
|
|
||||||
raise ValueError(f"Read workflow failed: {error_details}")
|
|
||||||
|
|
||||||
if audit_logger and not workflow_errors:
|
|
||||||
audit_logger.log_operation(
|
audit_logger.log_operation(
|
||||||
operation="READ",
|
operation="READ",
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
success=True,
|
success=False,
|
||||||
duration=total_duration,
|
duration=total_duration,
|
||||||
|
error=error_details,
|
||||||
details={
|
details={
|
||||||
"search_switch": search_switch,
|
"search_switch": search_switch,
|
||||||
"history_length": len(history),
|
"history_length": len(history),
|
||||||
"intermediate_outputs_count": len(intermediate_outputs),
|
"intermediate_outputs_count": len(intermediate_outputs),
|
||||||
"has_answer": bool(final_answer)
|
"has_answer": bool(final_answer),
|
||||||
|
"errors": workflow_errors
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
retrieved_content=[]
|
|
||||||
repo = ShortTermMemoryRepository(db)
|
# Raise error if no answer was produced
|
||||||
if str(search_switch)!="2":
|
if not final_answer:
|
||||||
for intermediate in intermediate_outputs:
|
raise ValueError(f"Read workflow failed: {error_details}")
|
||||||
print(intermediate)
|
|
||||||
intermediate_type=intermediate['type']
|
if audit_logger and not workflow_errors:
|
||||||
if intermediate_type=="search_result":
|
audit_logger.log_operation(
|
||||||
query=intermediate['query']
|
operation="READ",
|
||||||
raw_results=intermediate['raw_results']
|
config_id=config_id,
|
||||||
reranked_results=raw_results.get('reranked_results',[])
|
group_id=group_id,
|
||||||
try:
|
success=True,
|
||||||
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
|
duration=total_duration,
|
||||||
except Exception:
|
details={
|
||||||
statements=[]
|
"search_switch": search_switch,
|
||||||
statements=list(set(statements))
|
"history_length": len(history),
|
||||||
retrieved_content.append({query:statements})
|
"intermediate_outputs_count": len(intermediate_outputs),
|
||||||
if retrieved_content==[]:
|
"has_answer": bool(final_answer)
|
||||||
retrieved_content=''
|
}
|
||||||
if '信息不足,无法回答。' != str(final_answer) and str(search_switch).strip() != "2":#and retrieved_content!=[]
|
)
|
||||||
# 使用 upsert 方法
|
retrieved_content=[]
|
||||||
repo.upsert(
|
repo = ShortTermMemoryRepository(db)
|
||||||
end_user_id=end_user_id, # 确保这个变量在作用域内
|
if str(search_switch)!="2":
|
||||||
messages=ori_message,
|
for intermediate in intermediate_outputs:
|
||||||
aimessages=final_answer,
|
print(intermediate)
|
||||||
retrieved_content=retrieved_content,
|
intermediate_type=intermediate['type']
|
||||||
search_switch=str(search_switch)
|
if intermediate_type=="search_result":
|
||||||
)
|
query=intermediate['query']
|
||||||
print("写入成功")
|
raw_results=intermediate['raw_results']
|
||||||
|
reranked_results=raw_results.get('reranked_results',[])
|
||||||
|
try:
|
||||||
|
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
|
||||||
|
except Exception:
|
||||||
|
statements=[]
|
||||||
|
statements=list(set(statements))
|
||||||
|
retrieved_content.append({query:statements})
|
||||||
|
if retrieved_content==[]:
|
||||||
|
retrieved_content=''
|
||||||
|
if '信息不足,无法回答。' != str(final_answer) and str(search_switch).strip() != "2":#and retrieved_content!=[]
|
||||||
|
# 使用 upsert 方法
|
||||||
|
repo.upsert(
|
||||||
|
end_user_id=end_user_id, # 确保这个变量在作用域内
|
||||||
|
messages=ori_message,
|
||||||
|
aimessages=final_answer,
|
||||||
|
retrieved_content=retrieved_content,
|
||||||
|
search_switch=str(search_switch)
|
||||||
|
)
|
||||||
|
print("写入成功")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"answer": final_answer,
|
"answer": final_answer,
|
||||||
"intermediate_outputs": intermediate_outputs
|
"intermediate_outputs": intermediate_outputs
|
||||||
}
|
}
|
||||||
|
|
||||||
def _create_intermediate_key(self, output: Dict) -> str:
|
def _create_intermediate_key(self, output: Dict) -> str:
|
||||||
"""
|
"""
|
||||||
Create a unique key for an intermediate output to detect duplicates.
|
Create a unique key for an intermediate output to detect duplicates.
|
||||||
|
|||||||
Reference in New Issue
Block a user