去掉MCP框架,重构
This commit is contained in:
@@ -9,30 +9,28 @@ import os
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||||
|
|
||||||
import redis
|
import redis
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.logging_config import get_config_logger, get_logger
|
from app.core.logging_config import get_config_logger, get_logger
|
||||||
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
||||||
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||||||
from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config
|
from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results
|
||||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||||
from app.repositories.memory_short_repository import ShortTermMemoryRepository
|
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas.memory_config_schema import ConfigurationError
|
from app.schemas.memory_config_schema import ConfigurationError
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
from app.services.memory_konwledges_server import (
|
from app.services.memory_konwledges_server import (
|
||||||
write_rag,
|
write_rag,
|
||||||
)
|
)
|
||||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
|
||||||
from langchain_mcp_adapters.tools import load_mcp_tools
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -51,20 +49,16 @@ _neo4j_connector = Neo4jConnector()
|
|||||||
class MemoryAgentService:
|
class MemoryAgentService:
|
||||||
"""Service for memory agent operations"""
|
"""Service for memory agent operations"""
|
||||||
|
|
||||||
|
def writer_messages_deal(self, messages, start_time, group_id, config_id, message, context):
|
||||||
|
|
||||||
def writer_messages_deal(self,messages,start_time,group_id,config_id,message):
|
|
||||||
messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '')
|
|
||||||
countext = re.findall(r'"status": "(.*?)",', messages)[0]
|
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
|
|
||||||
if countext == 'success':
|
if str(messages) == 'success':
|
||||||
logger.info(f"Write operation successful for group {group_id} with config_id {config_id}")
|
logger.info(f"Write operation successful for group {group_id} with config_id {config_id}")
|
||||||
# 记录成功的操作
|
# 记录成功的操作
|
||||||
if audit_logger:
|
if audit_logger:
|
||||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=True,
|
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=True,
|
||||||
duration=duration, details={"message_length": len(message)})
|
duration=duration, details={"message_length": len(message)})
|
||||||
return countext
|
return context
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Write operation failed for group {group_id}")
|
logger.warning(f"Write operation failed for group {group_id}")
|
||||||
|
|
||||||
@@ -143,8 +137,26 @@ class MemoryAgentService:
|
|||||||
else:
|
else:
|
||||||
status = "unknown"
|
status = "unknown"
|
||||||
|
|
||||||
|
# Add database connection pool status
|
||||||
|
try:
|
||||||
|
from app.db import get_pool_status
|
||||||
|
pool_status = get_pool_status()
|
||||||
|
logger.info(f"Database pool status: {pool_status}")
|
||||||
|
|
||||||
|
# Check if pool usage is too high
|
||||||
|
if pool_status.get("usage_percent", 0) > 80:
|
||||||
|
logger.warning(f"High database pool usage: {pool_status['usage_percent']}%")
|
||||||
|
status = "warning"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get pool status: {e}")
|
||||||
|
pool_status = {"error": str(e)}
|
||||||
|
|
||||||
logger.info(f"Health status: {status}")
|
logger.info(f"Health status: {status}")
|
||||||
return {"status": status}
|
return {
|
||||||
|
"status": status,
|
||||||
|
"database_pool": pool_status
|
||||||
|
}
|
||||||
|
|
||||||
def get_log_content(self) -> str:
|
def get_log_content(self) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -156,8 +168,7 @@ class MemoryAgentService:
|
|||||||
"""
|
"""
|
||||||
logger.info("Reading log file")
|
logger.info("Reading log file")
|
||||||
|
|
||||||
# Use project root directory for logs
|
|
||||||
# Get the project root (redbear-mem directory)
|
|
||||||
current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py
|
current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py
|
||||||
app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory
|
app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory
|
||||||
project_root = os.path.dirname(app_dir) # redbear-mem directory
|
project_root = os.path.dirname(app_dir) # redbear-mem directory
|
||||||
@@ -301,53 +312,41 @@ class MemoryAgentService:
|
|||||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
|
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
|
||||||
|
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
mcp_config = get_mcp_server_config()
|
|
||||||
client = MultiServerMCPClient(mcp_config)
|
|
||||||
|
|
||||||
if storage_type == "rag":
|
|
||||||
result = await write_rag(group_id, message, user_rag_memory_id)
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
async with client.session("data_flow") as session:
|
|
||||||
logger.debug("Connected to MCP Server: data_flow")
|
|
||||||
tools = await load_mcp_tools(session)
|
|
||||||
workflow_errors = [] # Track errors from workflow
|
|
||||||
|
|
||||||
# Pass memory_config to the graph workflow
|
|
||||||
async with make_write_graph(group_id, tools, group_id, group_id, memory_config=memory_config) as graph:
|
|
||||||
logger.debug("Write graph created successfully")
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
if storage_type == "rag":
|
||||||
|
result = await write_rag(group_id, message, user_rag_memory_id)
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
async with make_write_graph() as graph:
|
||||||
config = {"configurable": {"thread_id": group_id}}
|
config = {"configurable": {"thread_id": group_id}}
|
||||||
|
# 初始状态 - 包含所有必要字段
|
||||||
|
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id,
|
||||||
|
"memory_config": memory_config}
|
||||||
|
|
||||||
async for event in graph.astream(
|
# 获取节点更新信息
|
||||||
{"messages": message, "memory_config": memory_config, "errors": []},
|
async for update_event in graph.astream(
|
||||||
stream_mode="values",
|
initial_state,
|
||||||
|
stream_mode="updates",
|
||||||
config=config
|
config=config
|
||||||
):
|
):
|
||||||
messages = event.get('messages')
|
for node_name, node_data in update_event.items():
|
||||||
# Capture any errors from the state
|
if 'save_neo4j' == node_name:
|
||||||
if event.get('errors'):
|
massages = node_data
|
||||||
workflow_errors.extend(event.get('errors', []))
|
massagesstatus = massages.get('write_result')['status']
|
||||||
|
contents = massages.get('write_result')
|
||||||
|
return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message, contents)
|
||||||
|
except Exception as e:
|
||||||
|
# Ensure proper error handling and logging
|
||||||
|
error_msg = f"Write operation failed: {str(e)}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
if audit_logger:
|
||||||
|
duration = time.time() - start_time
|
||||||
|
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
# Check for workflow errors
|
|
||||||
if workflow_errors:
|
|
||||||
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
|
|
||||||
logger.error(f"Write workflow failed with errors: {error_details}")
|
|
||||||
|
|
||||||
if audit_logger:
|
|
||||||
duration = time.time() - start_time
|
|
||||||
audit_logger.log_operation(
|
|
||||||
operation="WRITE",
|
|
||||||
config_id=config_id,
|
|
||||||
group_id=group_id,
|
|
||||||
success=False,
|
|
||||||
duration=duration,
|
|
||||||
error=error_details
|
|
||||||
)
|
|
||||||
|
|
||||||
raise ValueError(f"Write workflow failed: {error_details}")
|
|
||||||
|
|
||||||
return self.writer_messages_deal(messages, start_time, group_id, config_id, message)
|
|
||||||
|
|
||||||
async def read_memory(
|
async def read_memory(
|
||||||
self,
|
self,
|
||||||
@@ -387,8 +386,7 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
ori_message=message
|
|
||||||
end_user_id=group_id
|
|
||||||
# Resolve config_id if None using end_user's connected config
|
# Resolve config_id if None using end_user's connected config
|
||||||
if config_id is None:
|
if config_id is None:
|
||||||
try:
|
try:
|
||||||
@@ -410,6 +408,7 @@ class MemoryAgentService:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
audit_logger = None
|
audit_logger = None
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
config_service = MemoryConfigService(db)
|
config_service = MemoryConfigService(db)
|
||||||
memory_config = config_service.load_memory_config(
|
memory_config = config_service.load_memory_config(
|
||||||
@@ -440,298 +439,98 @@ class MemoryAgentService:
|
|||||||
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
|
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
|
||||||
|
|
||||||
# Step 3: Initialize MCP client and execute read workflow
|
# Step 3: Initialize MCP client and execute read workflow
|
||||||
mcp_config = get_mcp_server_config()
|
try:
|
||||||
client = MultiServerMCPClient(mcp_config)
|
async with make_read_graph() as graph:
|
||||||
|
|
||||||
async with client.session('data_flow') as session:
|
|
||||||
session_start = time.time()
|
|
||||||
logger.debug("Connected to MCP Server: data_flow")
|
|
||||||
|
|
||||||
tools_start = time.time()
|
|
||||||
tools = await load_mcp_tools(session)
|
|
||||||
tools_time = time.time() - tools_start
|
|
||||||
logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s")
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
intermediate_outputs = []
|
|
||||||
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
|
|
||||||
|
|
||||||
# Pass memory_config to the graph workflow
|
|
||||||
graph_start = time.time()
|
|
||||||
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
|
|
||||||
graph_init_time = time.time() - graph_start
|
|
||||||
logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s")
|
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
config = {"configurable": {"thread_id": group_id}}
|
config = {"configurable": {"thread_id": group_id}}
|
||||||
workflow_errors = [] # Track errors from workflow
|
# 初始状态 - 包含所有必要字段
|
||||||
|
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
|
||||||
event_count = 0
|
"group_id": group_id
|
||||||
async for event in graph.astream(
|
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
|
||||||
{"messages": history, "memory_config": memory_config, "errors": []},
|
"memory_config": memory_config}
|
||||||
stream_mode="values",
|
# 获取节点更新信息
|
||||||
|
_intermediate_outputs = []
|
||||||
|
summary = ''
|
||||||
|
async for update_event in graph.astream(
|
||||||
|
initial_state,
|
||||||
|
stream_mode="updates",
|
||||||
config=config
|
config=config
|
||||||
):
|
):
|
||||||
event_count += 1
|
for node_name, node_data in update_event.items():
|
||||||
event_start = time.time()
|
print(f"处理节点: {node_name}")
|
||||||
messages = event.get('messages')
|
|
||||||
# Capture any errors from the state
|
|
||||||
if event.get('errors'):
|
|
||||||
workflow_errors.extend(event.get('errors', []))
|
|
||||||
|
|
||||||
for msg in messages:
|
# 处理不同Summary节点的返回结构
|
||||||
msg_content = msg.content
|
if 'Summary' in node_name:
|
||||||
msg_role = msg.__class__.__name__.lower().replace("message", "")
|
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
|
||||||
outputs.append({
|
summary = node_data['InputSummary']['summary_result']
|
||||||
"role": msg_role,
|
elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']:
|
||||||
"content": msg_content
|
summary = node_data['RetrieveSummary']['summary_result']
|
||||||
})
|
elif 'summary' in node_data and 'summary_result' in node_data['summary']:
|
||||||
|
summary = node_data['summary']['summary_result']
|
||||||
|
elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']:
|
||||||
|
summary = node_data['SummaryFails']['summary_result']
|
||||||
|
|
||||||
# Extract intermediate outputs
|
spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
|
||||||
if hasattr(msg, 'content'):
|
if spit_data and spit_data != [] and spit_data != {}:
|
||||||
try:
|
_intermediate_outputs.append(spit_data)
|
||||||
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
|
||||||
content_to_parse = msg_content
|
|
||||||
if isinstance(msg_content, list):
|
|
||||||
for block in msg_content:
|
|
||||||
if isinstance(block, dict) and block.get('type') == 'text':
|
|
||||||
content_to_parse = block.get('text', '')
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
continue # No text block found
|
|
||||||
|
|
||||||
# Try to parse content as JSON
|
# Problem_Extension 节点
|
||||||
if isinstance(content_to_parse, str):
|
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
|
||||||
try:
|
if problem_extension and problem_extension != [] and problem_extension != {}:
|
||||||
parsed = json.loads(content_to_parse)
|
_intermediate_outputs.append(problem_extension)
|
||||||
if isinstance(parsed, dict):
|
|
||||||
# Check for single intermediate output
|
|
||||||
if '_intermediate' in parsed:
|
|
||||||
intermediate_data = parsed['_intermediate']
|
|
||||||
output_key = self._create_intermediate_key(intermediate_data)
|
|
||||||
|
|
||||||
if output_key not in seen_intermediates:
|
# Retrieve 节点
|
||||||
seen_intermediates.add(output_key)
|
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
|
||||||
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
|
||||||
|
_intermediate_outputs.extend(retrieve_node)
|
||||||
|
|
||||||
# Check for multiple intermediate outputs (from Retrieve)
|
# Verify 节点
|
||||||
if '_intermediates' in parsed:
|
verify_n = node_data.get('verify', {}).get('_intermediate', None)
|
||||||
for intermediate_data in parsed['_intermediates']:
|
if verify_n and verify_n != [] and verify_n != {}:
|
||||||
output_key = self._create_intermediate_key(intermediate_data)
|
_intermediate_outputs.append(verify_n)
|
||||||
|
|
||||||
if output_key not in seen_intermediates:
|
# Summary 节点
|
||||||
seen_intermediates.add(output_key)
|
summary_n = node_data.get('summary', {}).get('_intermediate', None)
|
||||||
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
if summary_n and summary_n != [] and summary_n != {}:
|
||||||
except (json.JSONDecodeError, ValueError):
|
_intermediate_outputs.append(summary_n)
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"Failed to extract intermediate output: {e}")
|
|
||||||
|
|
||||||
event_time = time.time() - event_start
|
_intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
|
||||||
logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s")
|
|
||||||
|
|
||||||
workflow_duration = time.time() - start
|
optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
|
||||||
session_duration = time.time() - session_start
|
result = reorder_output_results(optimized_outputs)
|
||||||
logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s")
|
|
||||||
logger.info(f"[PERF] Total session duration: {session_duration:.4f}s")
|
|
||||||
logger.info(f"[PERF] Total events processed: {event_count}")
|
|
||||||
# Extract final answer
|
|
||||||
final_answer = ""
|
|
||||||
for messages in outputs:
|
|
||||||
if messages['role'] == 'tool':
|
|
||||||
message = messages['content']
|
|
||||||
|
|
||||||
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
# Log successful operation
|
||||||
if isinstance(message, list):
|
if audit_logger:
|
||||||
# Extract text from MCP content blocks
|
duration = time.time() - start_time
|
||||||
for block in message:
|
audit_logger.log_operation(
|
||||||
if isinstance(block, dict) and block.get('type') == 'text':
|
operation="READ",
|
||||||
message = block.get('text', '')
|
config_id=config_id,
|
||||||
break
|
group_id=group_id,
|
||||||
else:
|
success=True,
|
||||||
continue # No text block found
|
duration=duration
|
||||||
|
)
|
||||||
try:
|
|
||||||
parsed = json.loads(message) if isinstance(message, str) else message
|
|
||||||
if isinstance(parsed, dict):
|
|
||||||
if parsed.get('status') == 'success':
|
|
||||||
summary_result = parsed.get('summary_result')
|
|
||||||
if summary_result:
|
|
||||||
final_answer = summary_result
|
|
||||||
except (json.JSONDecodeError, ValueError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 记录成功的操作
|
|
||||||
total_duration = time.time() - start_time
|
|
||||||
|
|
||||||
# Check for workflow errors
|
|
||||||
if workflow_errors:
|
|
||||||
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
|
|
||||||
logger.warning(f"Read workflow completed with errors: {error_details}")
|
|
||||||
|
|
||||||
|
return {
|
||||||
|
"answer": summary,
|
||||||
|
"intermediate_outputs": result
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
# Ensure proper error handling and logging
|
||||||
|
error_msg = f"Read operation failed: {str(e)}"
|
||||||
|
logger.error(error_msg)
|
||||||
if audit_logger:
|
if audit_logger:
|
||||||
|
duration = time.time() - start_time
|
||||||
audit_logger.log_operation(
|
audit_logger.log_operation(
|
||||||
operation="READ",
|
operation="READ",
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
success=False,
|
success=False,
|
||||||
duration=total_duration,
|
duration=duration,
|
||||||
error=error_details,
|
error=error_msg
|
||||||
details={
|
|
||||||
"search_switch": search_switch,
|
|
||||||
"history_length": len(history),
|
|
||||||
"intermediate_outputs_count": len(intermediate_outputs),
|
|
||||||
"has_answer": bool(final_answer),
|
|
||||||
"errors": workflow_errors
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
raise ValueError(error_msg)
|
||||||
# Raise error if no answer was produced
|
|
||||||
if not final_answer:
|
|
||||||
raise ValueError(f"Read workflow failed: {error_details}")
|
|
||||||
|
|
||||||
if audit_logger and not workflow_errors:
|
|
||||||
audit_logger.log_operation(
|
|
||||||
operation="READ",
|
|
||||||
config_id=config_id,
|
|
||||||
group_id=group_id,
|
|
||||||
success=True,
|
|
||||||
duration=total_duration,
|
|
||||||
details={
|
|
||||||
"search_switch": search_switch,
|
|
||||||
"history_length": len(history),
|
|
||||||
"intermediate_outputs_count": len(intermediate_outputs),
|
|
||||||
"has_answer": bool(final_answer)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
retrieved_content=[]
|
|
||||||
repo = ShortTermMemoryRepository(db)
|
|
||||||
if str(search_switch)!="2":
|
|
||||||
for intermediate in intermediate_outputs:
|
|
||||||
print(intermediate)
|
|
||||||
intermediate_type=intermediate['type']
|
|
||||||
if intermediate_type=="search_result":
|
|
||||||
query=intermediate['query']
|
|
||||||
raw_results=intermediate['raw_results']
|
|
||||||
reranked_results=raw_results.get('reranked_results',[])
|
|
||||||
try:
|
|
||||||
statements=[statement['statement'] for statement in reranked_results.get('statements', [])]
|
|
||||||
except Exception:
|
|
||||||
statements=[]
|
|
||||||
statements=list(set(statements))
|
|
||||||
retrieved_content.append({query:statements})
|
|
||||||
if retrieved_content==[]:
|
|
||||||
retrieved_content=''
|
|
||||||
if '信息不足,无法回答。' != str(final_answer) and str(search_switch).strip() != "2":#and retrieved_content!=[]
|
|
||||||
# 使用 upsert 方法
|
|
||||||
repo.upsert(
|
|
||||||
end_user_id=end_user_id, # 确保这个变量在作用域内
|
|
||||||
messages=ori_message,
|
|
||||||
aimessages=final_answer,
|
|
||||||
retrieved_content=retrieved_content,
|
|
||||||
search_switch=str(search_switch)
|
|
||||||
)
|
|
||||||
print("写入成功")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return {
|
|
||||||
"answer": final_answer,
|
|
||||||
"intermediate_outputs": intermediate_outputs
|
|
||||||
}
|
|
||||||
|
|
||||||
def _create_intermediate_key(self, output: Dict) -> str:
    """
    Create a unique key for an intermediate output to detect duplicates.

    Known output types are reduced to a short ``<prefix>:<value>`` key built
    from the field that identifies them. Search results additionally carry
    their ``index`` so several results for the same query stay distinct.
    Any other type falls back to a JSON dump of the whole dictionary with
    sorted keys.

    Args:
        output: Intermediate output dictionary

    Returns:
        Unique string key for this output
    """
    # Function-scope import, as in the original implementation.
    import json

    output_type = output.get('type', 'unknown')

    # Search results are keyed by query *and* position.
    if output_type == 'search_result':
        return f"search:{output.get('query', '')}:{output.get('index', 0)}"

    # output type -> (key prefix, field whose value identifies the output)
    key_layout = {
        'problem_split': ('split', 'original_query'),
        'problem_extension': ('extension', 'original_query'),
        'retrieval_summary': ('summary', 'query'),
        'verification': ('verification', 'query'),
        'input_summary': ('input_summary', 'query'),
    }
    # Guard the lookup: a non-string (possibly unhashable) type must fall
    # through to the JSON fallback instead of failing inside dict.get().
    layout = key_layout.get(output_type) if isinstance(output_type, str) else None
    if layout is not None:
        prefix, field = layout
        return f"{prefix}:{output.get(field, '')}"

    # Fallback: use the JSON representation. default=str is deliberate: a
    # value JSON cannot encode natively is stringified so that building a
    # de-duplication key never raises TypeError. Output for inputs that were
    # already JSON-encodable is unchanged.
    return json.dumps(output, sort_keys=True, default=str)
|
|
||||||
|
|
||||||
def _format_intermediate_output(self, output: Dict) -> Dict:
    """
    Format intermediate output for frontend display.

    For each known output type a new dictionary is built that contains the
    type, a display title and a fixed set of fields copied from ``output``
    (with a default for each missing field). Keys of ``output`` that are not
    part of that set are left out. An unknown type is returned unchanged.

    Args:
        output: Intermediate output dictionary

    Returns:
        The formatted dictionary, or ``output`` itself when its type is not
        one of the known ones
    """
    output_type = output.get('type', 'unknown')

    # Search results are the only type whose title is computed.
    if output_type == 'search_result':
        index = output.get('index', 0)
        total = output.get('total', 0)
        return {
            'type': 'search_result',
            'title': f'检索结果 ({index}/{total})',
            'query': output.get('query', ''),
            'raw_results': output.get('raw_results', ''),
            'index': index,
            'total': total,
        }

    # output type -> (display title, ((field, default), ...)).
    # Built on every call so the mutable defaults ([] and {}) are fresh
    # objects, exactly like the literals in the original .get() calls.
    layouts = {
        'problem_split': (
            '问题拆分',
            (('data', []), ('original_query', '')),
        ),
        'problem_extension': (
            '问题扩展',
            (('data', {}), ('original_query', '')),
        ),
        'retrieval_summary': (
            '检索总结',
            (('summary', ''), ('query', ''), ('raw_results', None)),
        ),
        'verification': (
            '数据验证',
            (('result', 'unknown'), ('reason', ''), ('query', ''), ('verified_count', 0)),
        ),
        'input_summary': (
            '快速答案',
            (('summary', ''), ('query', ''), ('raw_results', None)),
        ),
    }
    # A non-string (possibly unhashable) type cannot be a table key; treat it
    # as unknown instead of failing inside dict.get().
    layout = layouts.get(output_type) if isinstance(output_type, str) else None
    if layout is None:
        return output

    title, fields = layout
    formatted = {'type': output_type, 'title': title}
    for field, default in fields:
        formatted[field] = output.get(field, default)
    return formatted
|
|
||||||
|
|
||||||
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
|
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
|
||||||
"""
|
"""
|
||||||
Determine the type of user message (read or write)
|
Determine the type of user message (read or write)
|
||||||
@@ -838,6 +637,7 @@ class MemoryAgentService:
|
|||||||
# 获取当前空间下的所有宿主
|
# 获取当前空间下的所有宿主
|
||||||
from app.repositories import app_repository, end_user_repository
|
from app.repositories import app_repository, end_user_repository
|
||||||
from app.schemas.app_schema import App as AppSchema
|
from app.schemas.app_schema import App as AppSchema
|
||||||
|
from app.schemas.end_user_schema import EndUser as EndUserSchema
|
||||||
|
|
||||||
# 查询应用并转换为 Pydantic 模型
|
# 查询应用并转换为 Pydantic 模型
|
||||||
apps_orm = app_repository.get_apps_by_workspace_id(db, current_workspace_id)
|
apps_orm = app_repository.get_apps_by_workspace_id(db, current_workspace_id)
|
||||||
@@ -1135,43 +935,6 @@ class MemoryAgentService:
|
|||||||
logger.info("Log streaming completed, cleaning up resources")
|
logger.info("Log streaming completed, cleaning up resources")
|
||||||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||||||
|
|
||||||
# async def get_api_docs(self, file_path: Optional[str] = None) -> Dict[str, Any]:
|
|
||||||
# """
|
|
||||||
# Parse and return API documentation
|
|
||||||
|
|
||||||
# Args:
|
|
||||||
# file_path: Optional path to API docs file. If None, uses default path.
|
|
||||||
|
|
||||||
# Returns:
|
|
||||||
# Dict containing parsed API documentation or error information
|
|
||||||
# """
|
|
||||||
# try:
|
|
||||||
# target = file_path or get_default_docs_path()
|
|
||||||
|
|
||||||
# if not os.path.isfile(target):
|
|
||||||
# return {
|
|
||||||
# "success": False,
|
|
||||||
# "msg": "API文档文件不存在",
|
|
||||||
# "error_code": "DOC_NOT_FOUND",
|
|
||||||
# "data": {"path": target}
|
|
||||||
# }
|
|
||||||
|
|
||||||
# data = parse_api_docs(target)
|
|
||||||
# return {
|
|
||||||
# "success": True,
|
|
||||||
# "msg": "解析成功",
|
|
||||||
# "data": data
|
|
||||||
# }
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error(f"Failed to parse API docs: {e}")
|
|
||||||
# return {
|
|
||||||
# "success": False,
|
|
||||||
# "msg": "解析失败",
|
|
||||||
# "error_code": "DOC_PARSE_ERROR",
|
|
||||||
# "data": {"error": str(e)}
|
|
||||||
# }
|
|
||||||
|
|
||||||
|
|
||||||
def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]:
|
def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取终端用户关联的记忆配置
|
获取终端用户关联的记忆配置
|
||||||
@@ -1180,20 +943,18 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
|||||||
1. 根据 end_user_id 获取用户的 app_id
|
1. 根据 end_user_id 获取用户的 app_id
|
||||||
2. 获取该应用的最新发布版本
|
2. 获取该应用的最新发布版本
|
||||||
3. 从发布版本的 config 字段中提取 memory_config_id
|
3. 从发布版本的 config 字段中提取 memory_config_id
|
||||||
4. 根据 memory_config_id 查询配置名称
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id: 终端用户ID
|
end_user_id: 终端用户ID
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
包含 memory_config_id、config_name 和相关信息的字典
|
包含 memory_config_id 和相关信息的字典
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: 当终端用户不存在或应用未发布时
|
ValueError: 当终端用户不存在或应用未发布时
|
||||||
"""
|
"""
|
||||||
from app.models.app_release_model import AppRelease
|
from app.models.app_release_model import AppRelease
|
||||||
from app.models.data_config_model import DataConfig
|
|
||||||
from app.models.end_user_model import EndUser
|
from app.models.end_user_model import EndUser
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
@@ -1227,158 +988,13 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
|||||||
memory_obj = config.get('memory', {})
|
memory_obj = config.get('memory', {})
|
||||||
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||||
|
|
||||||
# 4. 根据 memory_config_id 查询配置名称
|
|
||||||
config_name = None
|
|
||||||
if memory_config_id:
|
|
||||||
try:
|
|
||||||
# memory_config_id 可能是整数或字符串,需要转换
|
|
||||||
config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id
|
|
||||||
data_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first()
|
|
||||||
if data_config:
|
|
||||||
config_name = data_config.config_name
|
|
||||||
logger.debug(f"Found config_name: {config_name} for config_id: {config_id}")
|
|
||||||
else:
|
|
||||||
logger.warning(f"DataConfig not found for config_id: {config_id}")
|
|
||||||
except (ValueError, TypeError) as e:
|
|
||||||
logger.warning(f"Invalid memory_config_id format: {memory_config_id}, error: {str(e)}")
|
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"end_user_id": str(end_user_id),
|
"end_user_id": str(end_user_id),
|
||||||
"app_id": str(app_id),
|
"app_id": str(app_id),
|
||||||
"release_id": str(latest_release.id),
|
"release_id": str(latest_release.id),
|
||||||
"release_version": latest_release.version,
|
"release_version": latest_release.version,
|
||||||
"memory_config_id": memory_config_id,
|
"memory_config_id": memory_config_id
|
||||||
"memory_config_name": config_name
|
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, config_name={config_name}")
|
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_memory_config_id(release: Any) -> Any:
    """Return the raw ``memory.memory_content`` value of a release config.

    Returns None when the release has no config, or when its ``memory``
    entry is not a dict.
    """
    config = release.config or {}
    memory_obj = config.get('memory', {})
    if not isinstance(memory_obj, dict):
        return None
    return memory_obj.get('memory_content')


def _coerce_memory_config_id(memory_config_id: Any) -> Any:
    """Turn a raw memory_config_id into the value used to look up DataConfig.

    Strings are converted with ``int()``; any other truthy value is passed
    through unchanged. Falsy values and non-numeric strings yield None; the
    latter is logged instead of being silently dropped.
    """
    if not memory_config_id:
        return None
    if not isinstance(memory_config_id, str):
        return memory_config_id
    try:
        return int(memory_config_id)
    except ValueError as e:
        logger.warning(f"Invalid memory_config_id format: {memory_config_id}, error: {str(e)}")
        return None


def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) -> Dict[str, Dict[str, Any]]:
    """
    Fetch the memory configuration linked to each of several end users.

    Database round-trips are kept to a fixed number regardless of how many
    users are requested:
    1. one query for all end users (and their app_id)
    2. one query for all active releases of the related apps
    3. one query for all referenced data configs

    Args:
        end_user_ids: list of end user IDs
        db: database session

    Returns:
        Dict keyed by end_user_id (as a string). Each value holds
        ``end_user_id``, ``memory_config_id`` and ``memory_config_name``.
        For users that could not be resolved (unknown user, app without an
        active release) the value additionally contains an ``error`` field.
    """
    from app.models.app_release_model import AppRelease
    from app.models.data_config_model import DataConfig
    from app.models.end_user_model import EndUser

    logger.info(f"Batch getting connected configs for {len(end_user_ids)} end users")

    result: Dict[str, Dict[str, Any]] = {}

    # Nothing requested: skip the database entirely instead of issuing an
    # empty IN (...) query.
    if not end_user_ids:
        return result

    # 1. Load every requested end user in a single query.
    end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()

    # end_user_id (str) -> end_user
    end_user_map = {str(user.id): user for user in end_users}

    # Record the users that do not exist. The map is keyed by str(user.id),
    # so the requested id is normalized the same way before comparing.
    for requested_id in end_user_ids:
        user_id = str(requested_id)
        if user_id not in end_user_map:
            result[user_id] = {
                "end_user_id": user_id,
                "memory_config_id": None,
                "memory_config_name": None,
                "error": f"终端用户不存在: {user_id}"
            }

    if not end_users:
        logger.warning("No valid end users found")
        return result

    # 2. Load the active releases of all related apps in a single query.
    # Several users may share an app, so de-duplicate; a missing app_id can
    # never match a release and is left out of the IN list.
    app_ids = list({user.app_id for user in end_users if user.app_id is not None})

    releases = []
    if app_ids:
        # Multiple criteria passed to filter() are joined with AND.
        # NOTE(review): "latest" relies on AppRelease.version sorting
        # correctly in the database — confirm the column type.
        releases = db.query(AppRelease).filter(
            AppRelease.app_id.in_(app_ids),
            AppRelease.is_active.is_(True)
        ).order_by(AppRelease.app_id, AppRelease.version.desc()).all()

    # app_id (str) -> latest release. Rows arrive newest-first per app, so
    # only the first row seen for each app is kept.
    app_release_map: Dict[str, Any] = {}
    for release in releases:
        app_release_map.setdefault(str(release.app_id), release)

    # 3. Resolve each release's memory_config_id exactly once:
    # app_id (str) -> (raw value from the release config, lookup id or None)
    release_config_ids = {}
    for app_id_str, release in app_release_map.items():
        raw_config_id = _extract_memory_config_id(release)
        release_config_ids[app_id_str] = (raw_config_id, _coerce_memory_config_id(raw_config_id))

    # 4. Load every referenced data config in a single query.
    lookup_ids = {lookup_id for _, lookup_id in release_config_ids.values() if lookup_id is not None}
    config_name_map = {}
    if lookup_ids:
        data_configs = db.query(DataConfig).filter(
            DataConfig.config_id.in_(lookup_ids)
        ).all()
        config_name_map = {config.config_id: config.config_name for config in data_configs}

    # 5. Assemble the per-user result.
    for user in end_users:
        user_id = str(user.id)
        app_id = str(user.app_id)

        # The app has no active release.
        if app_id not in app_release_map:
            result[user_id] = {
                "end_user_id": user_id,
                "memory_config_id": None,
                "memory_config_name": None,
                "error": f"应用未发布: {app_id}"
            }
            continue

        memory_config_id, lookup_id = release_config_ids[app_id]
        config_name = config_name_map.get(lookup_id) if lookup_id is not None else None

        result[user_id] = {
            "end_user_id": user_id,
            "memory_config_id": memory_config_id,
            "memory_config_name": config_name
        }

    logger.info(f"Successfully retrieved batch configs: total={len(result)}, with_config={sum(1 for v in result.values() if v.get('memory_config_id'))}")
    return result
|
return result
|
||||||
Reference in New Issue
Block a user