去掉MCP框架,重构
This commit is contained in:
@@ -13,26 +13,26 @@ from threading import Lock
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
import redis
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_config_logger, get_logger
|
||||
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
||||
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||||
from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config
|
||||
from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results
|
||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.repositories.memory_short_repository import ShortTermMemoryRepository
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||
from app.repositories.memory_short_repository import ShortTermMemoryRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import ConfigurationError
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_konwledges_server import (
|
||||
write_rag,
|
||||
)
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from langchain_mcp_adapters.tools import load_mcp_tools
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -55,18 +55,16 @@ class MemoryAgentService:
|
||||
self.user_locks: Dict[str, Lock] = {}
|
||||
self.locks_lock = Lock()
|
||||
|
||||
def writer_messages_deal(self,messages,start_time,group_id,config_id,message):
|
||||
messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '')
|
||||
countext = re.findall(r'"status": "(.*?)",', messages)[0]
|
||||
def writer_messages_deal(self, messages, start_time, group_id, config_id, message, context):
|
||||
duration = time.time() - start_time
|
||||
|
||||
if countext == 'success':
|
||||
if str(messages) == 'success':
|
||||
logger.info(f"Write operation successful for group {group_id} with config_id {config_id}")
|
||||
# 记录成功的操作
|
||||
if audit_logger:
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=True,
|
||||
duration=duration, details={"message_length": len(message)})
|
||||
return countext
|
||||
return context
|
||||
else:
|
||||
logger.warning(f"Write operation failed for group {group_id}")
|
||||
|
||||
@@ -150,8 +148,26 @@ class MemoryAgentService:
|
||||
else:
|
||||
status = "unknown"
|
||||
|
||||
# Add database connection pool status
|
||||
try:
|
||||
from app.db import get_pool_status
|
||||
pool_status = get_pool_status()
|
||||
logger.info(f"Database pool status: {pool_status}")
|
||||
|
||||
# Check if pool usage is too high
|
||||
if pool_status.get("usage_percent", 0) > 80:
|
||||
logger.warning(f"High database pool usage: {pool_status['usage_percent']}%")
|
||||
status = "warning"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get pool status: {e}")
|
||||
pool_status = {"error": str(e)}
|
||||
|
||||
logger.info(f"Health status: {status}")
|
||||
return {"status": status}
|
||||
return {
|
||||
"status": status,
|
||||
"database_pool": pool_status
|
||||
}
|
||||
|
||||
def get_log_content(self) -> str:
|
||||
"""
|
||||
@@ -308,54 +324,42 @@ class MemoryAgentService:
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
|
||||
|
||||
raise ValueError(error_msg)
|
||||
mcp_config = get_mcp_server_config()
|
||||
client = MultiServerMCPClient(mcp_config)
|
||||
|
||||
if storage_type == "rag":
|
||||
result = await write_rag(group_id, message, user_rag_memory_id)
|
||||
return result
|
||||
else:
|
||||
async with client.session("data_flow") as session:
|
||||
logger.debug("Connected to MCP Server: data_flow")
|
||||
tools = await load_mcp_tools(session)
|
||||
workflow_errors = [] # Track errors from workflow
|
||||
|
||||
# Pass memory_config to the graph workflow
|
||||
async with make_write_graph(group_id, tools, group_id, group_id, memory_config=memory_config) as graph:
|
||||
logger.debug("Write graph created successfully")
|
||||
|
||||
try:
|
||||
if storage_type == "rag":
|
||||
result = await write_rag(group_id, message, user_rag_memory_id)
|
||||
return result
|
||||
else:
|
||||
async with make_write_graph() as graph:
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id,
|
||||
"memory_config": memory_config}
|
||||
|
||||
async for event in graph.astream(
|
||||
{"messages": message, "memory_config": memory_config, "errors": []},
|
||||
stream_mode="values",
|
||||
# 获取节点更新信息
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
messages = event.get('messages')
|
||||
# Capture any errors from the state
|
||||
if event.get('errors'):
|
||||
workflow_errors.extend(event.get('errors', []))
|
||||
|
||||
# Check for workflow errors
|
||||
if workflow_errors:
|
||||
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
|
||||
logger.error(f"Write workflow failed with errors: {error_details}")
|
||||
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="WRITE",
|
||||
config_id=config_id,
|
||||
group_id=group_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=error_details
|
||||
)
|
||||
|
||||
raise ValueError(f"Write workflow failed: {error_details}")
|
||||
|
||||
return self.writer_messages_deal(messages, start_time, group_id, config_id, message)
|
||||
|
||||
for node_name, node_data in update_event.items():
|
||||
if 'save_neo4j' == node_name:
|
||||
massages = node_data
|
||||
massagesstatus = massages.get('write_result')['status']
|
||||
contents = massages.get('write_result')
|
||||
return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message, contents)
|
||||
except Exception as e:
|
||||
# Ensure proper error handling and logging
|
||||
error_msg = f"Write operation failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
|
||||
|
||||
async def read_memory(
|
||||
self,
|
||||
group_id: str,
|
||||
@@ -394,8 +398,9 @@ class MemoryAgentService:
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
ori_message=message
|
||||
end_user_id=group_id
|
||||
ori_message=message
|
||||
|
||||
# Resolve config_id if None using end_user's connected config
|
||||
if config_id is None:
|
||||
try:
|
||||
@@ -408,15 +413,15 @@ class MemoryAgentService:
|
||||
raise # Re-raise our specific error
|
||||
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
|
||||
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
|
||||
|
||||
|
||||
logger.info(f"Read operation for group {group_id} with config_id {config_id}")
|
||||
|
||||
|
||||
# 导入审计日志记录器
|
||||
try:
|
||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||
except ImportError:
|
||||
audit_logger = None
|
||||
|
||||
|
||||
# Get group lock to prevent concurrent processing
|
||||
group_lock = self.get_group_lock(group_id)
|
||||
|
||||
@@ -432,7 +437,7 @@ class MemoryAgentService:
|
||||
except ConfigurationError as e:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||||
logger.error(error_msg)
|
||||
|
||||
|
||||
# Log failed operation
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
@@ -444,305 +449,133 @@ class MemoryAgentService:
|
||||
duration=duration,
|
||||
error=error_msg
|
||||
)
|
||||
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
# Step 2: Prepare history
|
||||
history.append({"role": "user", "content": message})
|
||||
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
|
||||
|
||||
|
||||
# Step 3: Initialize MCP client and execute read workflow
|
||||
mcp_config = get_mcp_server_config()
|
||||
client = MultiServerMCPClient(mcp_config)
|
||||
|
||||
async with client.session('data_flow') as session:
|
||||
session_start = time.time()
|
||||
logger.debug("Connected to MCP Server: data_flow")
|
||||
|
||||
tools_start = time.time()
|
||||
tools = await load_mcp_tools(session)
|
||||
tools_time = time.time() - tools_start
|
||||
logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s")
|
||||
|
||||
outputs = []
|
||||
intermediate_outputs = []
|
||||
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
|
||||
|
||||
# Pass memory_config to the graph workflow
|
||||
graph_start = time.time()
|
||||
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
|
||||
graph_init_time = time.time() - graph_start
|
||||
logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s")
|
||||
|
||||
start = time.time()
|
||||
try:
|
||||
async with make_read_graph() as graph:
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
workflow_errors = [] # Track errors from workflow
|
||||
|
||||
event_count = 0
|
||||
async for event in graph.astream(
|
||||
{"messages": history, "memory_config": memory_config, "errors": []},
|
||||
stream_mode="values",
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
|
||||
"group_id": group_id
|
||||
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
|
||||
"memory_config": memory_config}
|
||||
# 获取节点更新信息
|
||||
_intermediate_outputs = []
|
||||
summary = ''
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
event_count += 1
|
||||
event_start = time.time()
|
||||
messages = event.get('messages')
|
||||
# Capture any errors from the state
|
||||
if event.get('errors'):
|
||||
workflow_errors.extend(event.get('errors', []))
|
||||
for node_name, node_data in update_event.items():
|
||||
|
||||
for msg in messages:
|
||||
msg_content = msg.content
|
||||
msg_role = msg.__class__.__name__.lower().replace("message", "")
|
||||
outputs.append({
|
||||
"role": msg_role,
|
||||
"content": msg_content
|
||||
})
|
||||
# 处理不同Summary节点的返回结构
|
||||
if 'Summary' in node_name:
|
||||
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
|
||||
summary = node_data['InputSummary']['summary_result']
|
||||
elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']:
|
||||
summary = node_data['RetrieveSummary']['summary_result']
|
||||
elif 'summary' in node_data and 'summary_result' in node_data['summary']:
|
||||
summary = node_data['summary']['summary_result']
|
||||
elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']:
|
||||
summary = node_data['SummaryFails']['summary_result']
|
||||
|
||||
# Extract intermediate outputs
|
||||
if hasattr(msg, 'content'):
|
||||
try:
|
||||
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
||||
content_to_parse = msg_content
|
||||
if isinstance(msg_content, list):
|
||||
for block in msg_content:
|
||||
if isinstance(block, dict) and block.get('type') == 'text':
|
||||
content_to_parse = block.get('text', '')
|
||||
break
|
||||
else:
|
||||
continue # No text block found
|
||||
spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
|
||||
if spit_data and spit_data != [] and spit_data != {}:
|
||||
_intermediate_outputs.append(spit_data)
|
||||
|
||||
# Try to parse content as JSON
|
||||
if isinstance(content_to_parse, str):
|
||||
try:
|
||||
parsed = json.loads(content_to_parse)
|
||||
if isinstance(parsed, dict):
|
||||
# Check for single intermediate output
|
||||
if '_intermediate' in parsed:
|
||||
intermediate_data = parsed['_intermediate']
|
||||
output_key = self._create_intermediate_key(intermediate_data)
|
||||
# Problem_Extension 节点
|
||||
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
|
||||
if problem_extension and problem_extension != [] and problem_extension != {}:
|
||||
_intermediate_outputs.append(problem_extension)
|
||||
|
||||
if output_key not in seen_intermediates:
|
||||
seen_intermediates.add(output_key)
|
||||
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
||||
# Retrieve 节点
|
||||
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
|
||||
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
|
||||
_intermediate_outputs.extend(retrieve_node)
|
||||
|
||||
# Check for multiple intermediate outputs (from Retrieve)
|
||||
if '_intermediates' in parsed:
|
||||
for intermediate_data in parsed['_intermediates']:
|
||||
output_key = self._create_intermediate_key(intermediate_data)
|
||||
# Verify 节点
|
||||
verify_n = node_data.get('verify', {}).get('_intermediate', None)
|
||||
if verify_n and verify_n != [] and verify_n != {}:
|
||||
_intermediate_outputs.append(verify_n)
|
||||
|
||||
if output_key not in seen_intermediates:
|
||||
seen_intermediates.add(output_key)
|
||||
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to extract intermediate output: {e}")
|
||||
|
||||
event_time = time.time() - event_start
|
||||
logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s")
|
||||
# Summary 节点
|
||||
summary_n = node_data.get('summary', {}).get('_intermediate', None)
|
||||
if summary_n and summary_n != [] and summary_n != {}:
|
||||
_intermediate_outputs.append(summary_n)
|
||||
|
||||
workflow_duration = time.time() - start
|
||||
session_duration = time.time() - session_start
|
||||
logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s")
|
||||
logger.info(f"[PERF] Total session duration: {session_duration:.4f}s")
|
||||
logger.info(f"[PERF] Total events processed: {event_count}")
|
||||
# Extract final answer
|
||||
final_answer = ""
|
||||
for messages in outputs:
|
||||
if messages['role'] == 'tool':
|
||||
message = messages['content']
|
||||
_intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
|
||||
|
||||
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
||||
if isinstance(message, list):
|
||||
# Extract text from MCP content blocks
|
||||
for block in message:
|
||||
if isinstance(block, dict) and block.get('type') == 'text':
|
||||
message = block.get('text', '')
|
||||
break
|
||||
else:
|
||||
continue # No text block found
|
||||
optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
|
||||
result = reorder_output_results(optimized_outputs)
|
||||
|
||||
try:
|
||||
parsed = json.loads(message) if isinstance(message, str) else message
|
||||
if isinstance(parsed, dict):
|
||||
if parsed.get('status') == 'success':
|
||||
summary_result = parsed.get('summary_result')
|
||||
if summary_result:
|
||||
final_answer = summary_result
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
# Log successful operation
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
group_id=group_id,
|
||||
success=True,
|
||||
duration=duration
|
||||
)
|
||||
|
||||
# 记录成功的操作
|
||||
total_duration = time.time() - start_time
|
||||
|
||||
# Check for workflow errors
|
||||
if workflow_errors:
|
||||
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
|
||||
logger.warning(f"Read workflow completed with errors: {error_details}")
|
||||
retrieved_content = []
|
||||
repo = ShortTermMemoryRepository(db)
|
||||
if str(search_switch).strip() != "2":
|
||||
for intermediate in result:
|
||||
intermediate_type = intermediate['type']
|
||||
if intermediate_type == "search_result":
|
||||
query = intermediate['query']
|
||||
raw_results = intermediate['raw_results']
|
||||
reranked_results = raw_results.get('reranked_results', [])
|
||||
try:
|
||||
statements = [statement['statement'] for statement in
|
||||
reranked_results.get('statements', [])]
|
||||
except Exception:
|
||||
statements = []
|
||||
statements = list(set(statements))
|
||||
retrieved_content.append({query: statements})
|
||||
if retrieved_content == []:
|
||||
retrieved_content = ''
|
||||
if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2": # and retrieved_content!=[]
|
||||
# 使用 upsert 方法
|
||||
repo.upsert(
|
||||
end_user_id=end_user_id, # 确保这个变量在作用域内
|
||||
messages=ori_message,
|
||||
aimessages=summary,
|
||||
retrieved_content=retrieved_content,
|
||||
search_switch=str(search_switch)
|
||||
)
|
||||
print("写入成功")
|
||||
|
||||
return {
|
||||
"answer": summary,
|
||||
"intermediate_outputs": result
|
||||
}
|
||||
except Exception as e:
|
||||
# Ensure proper error handling and logging
|
||||
error_msg = f"Read operation failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
group_id=group_id,
|
||||
success=False,
|
||||
duration=total_duration,
|
||||
error=error_details,
|
||||
details={
|
||||
"search_switch": search_switch,
|
||||
"history_length": len(history),
|
||||
"intermediate_outputs_count": len(intermediate_outputs),
|
||||
"has_answer": bool(final_answer),
|
||||
"errors": workflow_errors
|
||||
}
|
||||
duration=duration,
|
||||
error=error_msg
|
||||
)
|
||||
|
||||
# Raise error if no answer was produced
|
||||
if not final_answer:
|
||||
raise ValueError(f"Read workflow failed: {error_details}")
|
||||
|
||||
if audit_logger and not workflow_errors:
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
group_id=group_id,
|
||||
success=True,
|
||||
duration=total_duration,
|
||||
details={
|
||||
"search_switch": search_switch,
|
||||
"history_length": len(history),
|
||||
"intermediate_outputs_count": len(intermediate_outputs),
|
||||
"has_answer": bool(final_answer)
|
||||
}
|
||||
)
|
||||
retrieved_content=[]
|
||||
repo = ShortTermMemoryRepository(db)
|
||||
if str(search_switch)!="2":
|
||||
for intermediate in intermediate_outputs:
|
||||
print(intermediate)
|
||||
intermediate_type=intermediate['type']
|
||||
if intermediate_type=="search_result":
|
||||
query=intermediate['query']
|
||||
raw_results=intermediate['raw_results']
|
||||
reranked_results=raw_results.get('reranked_results',[])
|
||||
try:
|
||||
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
|
||||
except Exception:
|
||||
statements=[]
|
||||
statements=list(set(statements))
|
||||
retrieved_content.append({query:statements})
|
||||
if retrieved_content==[]:
|
||||
retrieved_content=''
|
||||
if '信息不足,无法回答。' != str(final_answer) :#and retrieved_content!=[]
|
||||
# 使用 upsert 方法
|
||||
repo.upsert(
|
||||
end_user_id=end_user_id, # 确保这个变量在作用域内
|
||||
messages=ori_message,
|
||||
aimessages=final_answer,
|
||||
retrieved_content=retrieved_content,
|
||||
search_switch=str(search_switch)
|
||||
)
|
||||
print("写入成功")
|
||||
|
||||
|
||||
|
||||
return {
|
||||
"answer": final_answer,
|
||||
"intermediate_outputs": intermediate_outputs
|
||||
}
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def _create_intermediate_key(self, output: Dict) -> str:
|
||||
"""
|
||||
Create a unique key for an intermediate output to detect duplicates.
|
||||
|
||||
Args:
|
||||
output: Intermediate output dictionary
|
||||
|
||||
Returns:
|
||||
Unique string key for this output
|
||||
"""
|
||||
output_type = output.get('type', 'unknown')
|
||||
|
||||
if output_type == 'problem_split':
|
||||
# Use type + original query as key
|
||||
return f"split:{output.get('original_query', '')}"
|
||||
elif output_type == 'problem_extension':
|
||||
# Use type + original query as key
|
||||
return f"extension:{output.get('original_query', '')}"
|
||||
elif output_type == 'search_result':
|
||||
# Use type + query + index as key
|
||||
return f"search:{output.get('query', '')}:{output.get('index', 0)}"
|
||||
elif output_type == 'retrieval_summary':
|
||||
# Use type + query as key
|
||||
return f"summary:{output.get('query', '')}"
|
||||
elif output_type == 'verification':
|
||||
# Use type + query as key
|
||||
return f"verification:{output.get('query', '')}"
|
||||
elif output_type == 'input_summary':
|
||||
# Use type + query as key
|
||||
return f"input_summary:{output.get('query', '')}"
|
||||
else:
|
||||
# Fallback: use JSON representation
|
||||
import json
|
||||
return json.dumps(output, sort_keys=True)
|
||||
|
||||
def _format_intermediate_output(self, output: Dict) -> Dict:
|
||||
"""Format intermediate output for frontend display."""
|
||||
output_type = output.get('type', 'unknown')
|
||||
|
||||
if output_type == 'problem_split':
|
||||
return {
|
||||
'type': 'problem_split',
|
||||
'title': '问题拆分',
|
||||
'data': output.get('data', []),
|
||||
'original_query': output.get('original_query', '')
|
||||
}
|
||||
elif output_type == 'problem_extension':
|
||||
return {
|
||||
'type': 'problem_extension',
|
||||
'title': '问题扩展',
|
||||
'data': output.get('data', {}),
|
||||
'original_query': output.get('original_query', '')
|
||||
}
|
||||
elif output_type == 'search_result':
|
||||
return {
|
||||
'type': 'search_result',
|
||||
'title': f'检索结果 ({output.get("index", 0)}/{output.get("total", 0)})',
|
||||
'query': output.get('query', ''),
|
||||
'raw_results': output.get('raw_results', ''),
|
||||
'index': output.get('index', 0),
|
||||
'total': output.get('total', 0)
|
||||
}
|
||||
elif output_type == 'retrieval_summary':
|
||||
return {
|
||||
'type': 'retrieval_summary',
|
||||
'title': '检索总结',
|
||||
'summary': output.get('summary', ''),
|
||||
'query': output.get('query', ''),
|
||||
'raw_results': output.get('raw_results'),
|
||||
|
||||
}
|
||||
elif output_type == 'verification':
|
||||
return {
|
||||
'type': 'verification',
|
||||
'title': '数据验证',
|
||||
'result': output.get('result', 'unknown'),
|
||||
'reason': output.get('reason', ''),
|
||||
'query': output.get('query', ''),
|
||||
'verified_count': output.get('verified_count', 0)
|
||||
}
|
||||
elif output_type == 'input_summary':
|
||||
return {
|
||||
'type': 'input_summary',
|
||||
'title': '快速答案',
|
||||
'summary': output.get('summary', ''),
|
||||
'query': output.get('query', ''),
|
||||
'raw_results': output.get('raw_results'),
|
||||
|
||||
}
|
||||
else:
|
||||
return output
|
||||
|
||||
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
|
||||
"""
|
||||
@@ -850,6 +683,7 @@ class MemoryAgentService:
|
||||
# 获取当前空间下的所有宿主
|
||||
from app.repositories import app_repository, end_user_repository
|
||||
from app.schemas.app_schema import App as AppSchema
|
||||
from app.schemas.end_user_schema import EndUser as EndUserSchema
|
||||
|
||||
# 查询应用并转换为 Pydantic 模型
|
||||
apps_orm = app_repository.get_apps_by_workspace_id(db, current_workspace_id)
|
||||
@@ -1147,43 +981,6 @@ class MemoryAgentService:
|
||||
logger.info("Log streaming completed, cleaning up resources")
|
||||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||||
|
||||
# async def get_api_docs(self, file_path: Optional[str] = None) -> Dict[str, Any]:
|
||||
# """
|
||||
# Parse and return API documentation
|
||||
|
||||
# Args:
|
||||
# file_path: Optional path to API docs file. If None, uses default path.
|
||||
|
||||
# Returns:
|
||||
# Dict containing parsed API documentation or error information
|
||||
# """
|
||||
# try:
|
||||
# target = file_path or get_default_docs_path()
|
||||
|
||||
# if not os.path.isfile(target):
|
||||
# return {
|
||||
# "success": False,
|
||||
# "msg": "API文档文件不存在",
|
||||
# "error_code": "DOC_NOT_FOUND",
|
||||
# "data": {"path": target}
|
||||
# }
|
||||
|
||||
# data = parse_api_docs(target)
|
||||
# return {
|
||||
# "success": True,
|
||||
# "msg": "解析成功",
|
||||
# "data": data
|
||||
# }
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to parse API docs: {e}")
|
||||
# return {
|
||||
# "success": False,
|
||||
# "msg": "解析失败",
|
||||
# "error_code": "DOC_PARSE_ERROR",
|
||||
# "data": {"error": str(e)}
|
||||
# }
|
||||
|
||||
|
||||
def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]:
|
||||
"""
|
||||
获取终端用户关联的记忆配置
|
||||
@@ -1192,20 +989,18 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
1. 根据 end_user_id 获取用户的 app_id
|
||||
2. 获取该应用的最新发布版本
|
||||
3. 从发布版本的 config 字段中提取 memory_config_id
|
||||
4. 根据 memory_config_id 查询配置名称
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
包含 memory_config_id、config_name 和相关信息的字典
|
||||
包含 memory_config_id 和相关信息的字典
|
||||
|
||||
Raises:
|
||||
ValueError: 当终端用户不存在或应用未发布时
|
||||
"""
|
||||
from app.models.app_release_model import AppRelease
|
||||
from app.models.data_config_model import DataConfig
|
||||
from app.models.end_user_model import EndUser
|
||||
from sqlalchemy import select
|
||||
|
||||
@@ -1239,31 +1034,15 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
memory_obj = config.get('memory', {})
|
||||
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||
|
||||
# 4. 根据 memory_config_id 查询配置名称
|
||||
config_name = None
|
||||
if memory_config_id:
|
||||
try:
|
||||
# memory_config_id 可能是整数或字符串,需要转换
|
||||
config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id
|
||||
data_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
|
||||
if data_config:
|
||||
config_name = data_config.config_name
|
||||
logger.debug(f"Found config_name: {config_name} for config_id: {config_id}")
|
||||
else:
|
||||
logger.warning(f"DataConfig not found for config_id: {config_id}")
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning(f"Invalid memory_config_id format: {memory_config_id}, error: {str(e)}")
|
||||
|
||||
result = {
|
||||
"end_user_id": str(end_user_id),
|
||||
"app_id": str(app_id),
|
||||
"release_id": str(latest_release.id),
|
||||
"release_version": latest_release.version,
|
||||
"memory_config_id": memory_config_id,
|
||||
"memory_config_name": config_name
|
||||
"memory_config_id": memory_config_id
|
||||
}
|
||||
|
||||
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, config_name={config_name}")
|
||||
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
|
||||
return result
|
||||
|
||||
|
||||
@@ -1271,126 +1050,112 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
"""
|
||||
批量获取多个终端用户关联的记忆配置
|
||||
|
||||
通过优化的查询减少数据库往返次数:
|
||||
1. 一次性查询所有 end_user 及其 app_id
|
||||
2. 批量查询所有相关的 app_release
|
||||
3. 批量查询所有相关的 data_config
|
||||
通过以下流程获取配置:
|
||||
1. 批量查询所有 end_user 及其 app_id
|
||||
2. 批量获取所有应用的最新发布版本
|
||||
3. 从发布版本的 config 字段中提取 memory_config_id 和 memory_config_name
|
||||
|
||||
Args:
|
||||
end_user_ids: 终端用户ID列表
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
字典,key 为 end_user_id,value 为配置信息字典
|
||||
对于查询失败的用户,value 包含 error 字段
|
||||
字典,key 为 end_user_id,value 为包含 memory_config_id 和 memory_config_name 的字典
|
||||
格式: {
|
||||
"user_id_1": {"memory_config_id": "xxx", "memory_config_name": "xxx"},
|
||||
"user_id_2": {"memory_config_id": None, "memory_config_name": None},
|
||||
...
|
||||
}
|
||||
"""
|
||||
from app.models.app_release_model import AppRelease
|
||||
from app.models.data_config_model import DataConfig
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
from sqlalchemy import select
|
||||
|
||||
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end users")
|
||||
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users")
|
||||
|
||||
result = {}
|
||||
|
||||
# 1. 批量查询所有 end_user 及其 app_id
|
||||
end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()
|
||||
|
||||
# 构建 end_user_id -> end_user 的映射
|
||||
end_user_map = {str(user.id): user for user in end_users}
|
||||
# 创建 end_user_id 到 app_id 的映射
|
||||
user_to_app = {str(eu.id): eu.app_id for eu in end_users}
|
||||
|
||||
# 记录不存在的用户
|
||||
for user_id in end_user_ids:
|
||||
if user_id not in end_user_map:
|
||||
result[user_id] = {
|
||||
"end_user_id": user_id,
|
||||
"memory_config_id": None,
|
||||
"memory_config_name": None,
|
||||
"error": f"终端用户不存在: {user_id}"
|
||||
}
|
||||
# 获取所有相关的 app_id
|
||||
app_ids = list(set(user_to_app.values()))
|
||||
|
||||
if not end_users:
|
||||
logger.warning("No valid end users found")
|
||||
if not app_ids:
|
||||
logger.warning("No valid app_ids found for the provided end_user_ids")
|
||||
# 返回空配置
|
||||
for user_id in end_user_ids:
|
||||
result[user_id] = {"memory_config_id": None, "memory_config_name": None}
|
||||
return result
|
||||
|
||||
# 2. 批量查询所有相关应用的最新发布版本
|
||||
app_ids = [user.app_id for user in end_users]
|
||||
|
||||
# 2. 批量获取所有应用的最新发布版本
|
||||
# 使用子查询找到每个 app 的最新版本
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import func
|
||||
|
||||
# 查询所有相关的活跃发布版本
|
||||
releases = db.query(AppRelease).filter(
|
||||
and_(
|
||||
AppRelease.app_id.in_(app_ids),
|
||||
AppRelease.is_active.is_(True)
|
||||
subq = (
|
||||
select(
|
||||
AppRelease.app_id,
|
||||
func.max(AppRelease.version).label('max_version')
|
||||
)
|
||||
).order_by(AppRelease.app_id, AppRelease.version.desc()).all()
|
||||
.where(AppRelease.app_id.in_(app_ids), AppRelease.is_active.is_(True))
|
||||
.group_by(AppRelease.app_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# 构建 app_id -> latest_release 的映射(每个 app 只保留最新版本)
|
||||
app_release_map = {}
|
||||
for release in releases:
|
||||
app_id_str = str(release.app_id)
|
||||
if app_id_str not in app_release_map:
|
||||
app_release_map[app_id_str] = release
|
||||
stmt = (
|
||||
select(AppRelease)
|
||||
.join(
|
||||
subq,
|
||||
(AppRelease.app_id == subq.c.app_id) & (AppRelease.version == subq.c.max_version)
|
||||
)
|
||||
.where(AppRelease.is_active.is_(True))
|
||||
)
|
||||
|
||||
# 3. 收集所有 memory_config_id
|
||||
latest_releases = db.scalars(stmt).all()
|
||||
|
||||
# 创建 app_id 到 release 的映射
|
||||
app_to_release = {str(release.app_id): release for release in latest_releases}
|
||||
|
||||
# 3. 提取所有 memory_config_id
|
||||
memory_config_ids = []
|
||||
for release in app_release_map.values():
|
||||
for release in latest_releases:
|
||||
config = release.config or {}
|
||||
memory_obj = config.get('memory', {})
|
||||
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||
if memory_config_id:
|
||||
try:
|
||||
config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id
|
||||
memory_config_ids.append(config_id)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
memory_config_ids.append(memory_config_id)
|
||||
|
||||
# 4. 批量查询所有 data_config
|
||||
config_name_map = {}
|
||||
# 4. 批量查询 memory_config_name
|
||||
memory_configs = {}
|
||||
if memory_config_ids:
|
||||
data_configs = db.query(DataConfig).filter(
|
||||
DataConfig.config_id.in_(memory_config_ids)
|
||||
).all()
|
||||
config_name_map = {config.config_id: config.config_name for config in data_configs}
|
||||
configs = db.query(MemoryConfig).filter(MemoryConfig.id.in_(memory_config_ids)).all()
|
||||
memory_configs = {str(cfg.id): cfg.config_name for cfg in configs}
|
||||
|
||||
# 5. 组装结果
|
||||
for user in end_users:
|
||||
user_id = str(user.id)
|
||||
app_id = str(user.app_id)
|
||||
|
||||
# 检查是否有发布版本
|
||||
if app_id not in app_release_map:
|
||||
result[user_id] = {
|
||||
"end_user_id": user_id,
|
||||
"memory_config_id": None,
|
||||
"memory_config_name": None,
|
||||
"error": f"应用未发布: {app_id}"
|
||||
}
|
||||
for user_id in end_user_ids:
|
||||
app_id = user_to_app.get(user_id)
|
||||
if not app_id:
|
||||
result[user_id] = {"memory_config_id": None, "memory_config_name": None}
|
||||
continue
|
||||
|
||||
release = app_release_map[app_id]
|
||||
release = app_to_release.get(str(app_id))
|
||||
if not release:
|
||||
result[user_id] = {"memory_config_id": None, "memory_config_name": None}
|
||||
continue
|
||||
|
||||
# 提取 memory_config_id
|
||||
config = release.config or {}
|
||||
memory_obj = config.get('memory', {})
|
||||
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||
|
||||
# 获取 config_name
|
||||
config_name = None
|
||||
if memory_config_id:
|
||||
try:
|
||||
config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id
|
||||
config_name = config_name_map.get(config_id)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
memory_config_name = memory_configs.get(memory_config_id) if memory_config_id else None
|
||||
|
||||
result[user_id] = {
|
||||
"end_user_id": user_id,
|
||||
"memory_config_id": memory_config_id,
|
||||
"memory_config_name": config_name
|
||||
"memory_config_name": memory_config_name
|
||||
}
|
||||
|
||||
logger.info(f"Successfully retrieved batch configs: total={len(result)}, with_config={sum(1 for v in result.values() if v.get('memory_config_id'))}")
|
||||
logger.info(f"Successfully retrieved {len(result)} connected configs")
|
||||
return result
|
||||
Reference in New Issue
Block a user