* refactor(celery): optimize task routing and worker configuration - Simplify Celery queue configuration with single default 'io_tasks' queue - Implement task routing strategy separating IO-bound and CPU-bound tasks - Add Flower monitoring support with task event tracking enabled - Add summary node search optimization to only retrieve summary nodes - Clean up unused imports and reorganize import statements for consistency - Update docker-compose configuration to support multi-queue worker setup * chore(celery): simplify flower configuration and add gevent dependency * chore(dependencies): add gevent dependency to requirements - Add gevent==24.11.1 to api/requirements.txt - Gevent is required for async worker support in Celery - Complements existing flower and celery configuration * refactor(celery): simplify async event loop handling and reorganize task queues - Replace complex nest_asyncio and manual event loop management with asyncio.run() in read_message_task, write_message_task, regenerate_memory_cache, and workspace_reflection_task - Rename task queues from io_tasks/cpu_tasks to memory_tasks/document_tasks for better semantic clarity - Update task routing configuration to reflect new queue names for memory agent tasks and document processing tasks - Remove redundant exception handling comments and simplify error handling logic - Update README with improved community support section including GitHub Issues, Pull Requests, Discussions, and WeChat community links - Simplifies event loop management by leveraging asyncio.run() which handles loop creation and cleanup automatically, reducing code complexity and potential race conditions
1244 lines
53 KiB
Python
1244 lines
53 KiB
Python
"""
|
||
Memory Agent Service
|
||
|
||
Handles business logic for memory agent operations including read/write services,
|
||
health checks, and message type classification.
|
||
"""
|
||
import json
|
||
import os
|
||
import re
|
||
import time
|
||
import uuid
|
||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||
|
||
import redis
|
||
from app.core.config import settings
|
||
from app.core.logging_config import get_config_logger, get_logger
|
||
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
||
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||
from app.core.memory.agent.utils.messages_tools import (
|
||
merge_multiple_search_results,
|
||
reorder_output_results,
|
||
)
|
||
from app.core.memory.agent.utils.type_classifier import status_typle
|
||
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
|
||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||
from app.db import get_db_context
|
||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||
from app.schemas.memory_agent_schema import Write_UserInput
|
||
from app.schemas.memory_config_schema import ConfigurationError
|
||
from app.services.memory_base_service import Translation_English
|
||
from app.services.memory_config_service import MemoryConfigService
|
||
from app.services.memory_konwledges_server import (
|
||
write_rag,
|
||
)
|
||
from langchain_core.messages import HumanMessage
|
||
from pydantic import BaseModel, Field
|
||
from sqlalchemy import func
|
||
from sqlalchemy.orm import Session
|
||
|
||
try:
|
||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||
except ImportError:
|
||
audit_logger = None
|
||
logger = get_logger(__name__)
|
||
config_logger = get_config_logger()
|
||
|
||
# Initialize Neo4j connector for analytics functions
|
||
_neo4j_connector = Neo4jConnector()
|
||
|
||
|
||
class MemoryAgentService:
|
||
"""Service for memory agent operations"""
|
||
|
||
def writer_messages_deal(self, messages, start_time, group_id, config_id, message, context):
|
||
duration = time.time() - start_time
|
||
|
||
if str(messages) == 'success':
|
||
logger.info(f"Write operation successful for group {group_id} with config_id {config_id}")
|
||
# 记录成功的操作
|
||
if audit_logger:
|
||
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=True,
|
||
duration=duration, details={"message_length": len(message)})
|
||
return context
|
||
else:
|
||
logger.warning(f"Write operation failed for group {group_id}")
|
||
|
||
# 记录失败的操作
|
||
if audit_logger:
|
||
audit_logger.log_operation(
|
||
operation="WRITE",
|
||
config_id=config_id,
|
||
group_id=group_id,
|
||
success=False,
|
||
duration=duration,
|
||
error=f"写入失败: {messages[:100]}"
|
||
)
|
||
|
||
raise ValueError(f"写入失败: {messages}")
|
||
|
||
|
||
|
||
def extract_tool_call_info(self, event: Dict) -> bool:
|
||
"""Extract tool call information from event"""
|
||
last_message = event["messages"][-1]
|
||
|
||
# Check if AI message contains tool calls
|
||
if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
|
||
tool_calls = last_message.tool_calls
|
||
for i, tool_call in enumerate(tool_calls):
|
||
if isinstance(tool_call, dict):
|
||
tool_call_id = tool_call.get('id')
|
||
tool_name = tool_call.get('name')
|
||
tool_args = tool_call.get('args', {})
|
||
else:
|
||
tool_call_id = getattr(tool_call, 'id', None)
|
||
tool_name = getattr(tool_call, 'name', None)
|
||
tool_args = getattr(tool_call, 'args', {})
|
||
|
||
logger.debug(f"Tool Call {i + 1}: ID={tool_call_id}, Name={tool_name}, Args={tool_args}")
|
||
return True
|
||
|
||
# Check if tool message
|
||
elif hasattr(last_message, 'tool_call_id'):
|
||
tool_call_id = getattr(last_message, 'tool_call_id', None)
|
||
if hasattr(last_message, 'name') and hasattr(last_message, 'content'):
|
||
tool_name = getattr(last_message, 'name', None)
|
||
try:
|
||
content = json.loads(getattr(last_message, 'content', '{}'))
|
||
tool_args = content.get('args', {})
|
||
logger.debug(f"Tool Call 1: ID={tool_call_id}, Name={tool_name}, Args={tool_args}")
|
||
except:
|
||
logger.debug(f"Tool Response ID: {tool_call_id}")
|
||
else:
|
||
logger.debug(f"Tool Response ID: {tool_call_id}")
|
||
return True
|
||
|
||
return False
|
||
|
||
async def get_health_status(self) -> Dict:
|
||
"""
|
||
Get latest health status from Redis cache
|
||
|
||
Returns health status information written by Celery periodic task
|
||
"""
|
||
logger.info("Checking health status")
|
||
|
||
client = redis.Redis(
|
||
host=settings.REDIS_HOST,
|
||
port=settings.REDIS_PORT,
|
||
db=settings.REDIS_DB,
|
||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None
|
||
)
|
||
payload = client.hgetall("memsci:health:read_service") or {}
|
||
|
||
if payload:
|
||
# decode bytes to str
|
||
decoded = {k.decode("utf-8"): v.decode("utf-8") for k, v in payload.items()}
|
||
status = decoded.get("status", "unknown")
|
||
else:
|
||
status = "unknown"
|
||
|
||
# Add database connection pool status
|
||
try:
|
||
from app.db import get_pool_status
|
||
pool_status = get_pool_status()
|
||
logger.info(f"Database pool status: {pool_status}")
|
||
|
||
# Check if pool usage is too high
|
||
if pool_status.get("usage_percent", 0) > 80:
|
||
logger.warning(f"High database pool usage: {pool_status['usage_percent']}%")
|
||
status = "warning"
|
||
|
||
except Exception as e:
|
||
logger.error(f"Failed to get pool status: {e}")
|
||
pool_status = {"error": str(e)}
|
||
|
||
logger.info(f"Health status: {status}")
|
||
return {
|
||
"status": status,
|
||
"database_pool": pool_status
|
||
}
|
||
|
||
def get_log_content(self) -> str:
|
||
"""
|
||
Read and return agent service log file content
|
||
|
||
Returns cleaned log content using the same cleaning logic as transmission mode
|
||
|
||
Returns cleaned log content using the same cleaning logic as transmission mode
|
||
"""
|
||
logger.info("Reading log file")
|
||
|
||
|
||
current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py
|
||
app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory
|
||
project_root = os.path.dirname(app_dir) # redbear-mem directory
|
||
log_path = os.path.join(project_root, "logs", "agent_service.log")
|
||
|
||
summer = ''
|
||
|
||
with open(log_path, "r", encoding="utf-8") as infile:
|
||
for line in infile:
|
||
# Use the same cleaning logic as LogStreamer for consistency
|
||
cleaned = LogStreamer.clean_log_line(line)
|
||
summer += cleaned
|
||
|
||
if len(summer) < 10:
|
||
raise ValueError("NO LOGS")
|
||
|
||
logger.info(f"Log content retrieved, size: {len(summer)} bytes")
|
||
return summer
|
||
|
||
async def stream_log_content(self) -> AsyncGenerator[str, None]:
|
||
"""
|
||
Stream log content in real-time using Server-Sent Events (SSE)
|
||
|
||
This method establishes a streaming connection and transmits log entries
|
||
as they are written to the log file. It uses the LogStreamer to watch
|
||
the file and yields SSE-formatted messages.
|
||
|
||
Yields:
|
||
SSE-formatted strings with the following event types:
|
||
- log: Contains log content and timestamp
|
||
- keepalive: Periodic keepalive messages to maintain connection
|
||
- error: Error information if streaming fails
|
||
- done: Indicates streaming has completed
|
||
|
||
Raises:
|
||
FileNotFoundError: If log file doesn't exist at stream start
|
||
Exception: For other unexpected errors during streaming
|
||
"""
|
||
logger.info("Starting log content streaming")
|
||
|
||
# Get log file path - use project root directory
|
||
current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py
|
||
app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory
|
||
project_root = os.path.dirname(app_dir) # redbear-mem directory
|
||
log_path = os.path.join(project_root, "logs", "agent_service.log")
|
||
|
||
# Check if file exists before starting stream
|
||
if not os.path.exists(log_path):
|
||
logger.error(f"Log file not found: {log_path}")
|
||
# Send error event in SSE format
|
||
yield f"event: error\ndata: {json.dumps({'code': 4006, 'message': '日志文件不存在', 'error': f'File not found: {log_path}'})}\n\n"
|
||
return
|
||
|
||
streamer = None
|
||
try:
|
||
# Initialize LogStreamer with keepalive interval from settings (default 300 seconds)
|
||
keepalive_interval = getattr(settings, 'LOG_STREAM_KEEPALIVE_INTERVAL', 300)
|
||
streamer = LogStreamer(log_path, keepalive_interval=keepalive_interval)
|
||
|
||
logger.info(f"LogStreamer initialized for {log_path}")
|
||
|
||
# Stream log content using read_existing_and_stream to get all existing content first
|
||
async for message in streamer.read_existing_and_stream():
|
||
event_type = message.get("event")
|
||
data = message.get("data")
|
||
|
||
# Format as SSE message
|
||
# SSE format: "event: <type>\ndata: <json_data>\n\n"
|
||
sse_message = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
|
||
|
||
logger.debug(f"Streaming event: {event_type}")
|
||
yield sse_message
|
||
|
||
# If error or done event, stop streaming
|
||
if event_type in ["error", "done"]:
|
||
logger.info(f"Stream ended with event: {event_type}")
|
||
break
|
||
|
||
except FileNotFoundError as e:
|
||
logger.error(f"Log file not found during streaming: {e}")
|
||
yield f"event: error\ndata: {json.dumps({'code': 4006, 'message': '日志文件在流式传输期间变得不可用', 'error': str(e)})}\n\n"
|
||
|
||
except Exception as e:
|
||
logger.error(f"Unexpected error during log streaming: {e}", exc_info=True)
|
||
yield f"event: error\ndata: {json.dumps({'code': 8001, 'message': '流式传输期间发生错误', 'error': str(e)})}\n\n"
|
||
|
||
finally:
|
||
# Resource cleanup
|
||
logger.info("Log streaming completed, cleaning up resources")
|
||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||
|
||
async def write_memory(self, group_id: str, messages: list[dict], config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
|
||
"""
|
||
Process write operation with config_id
|
||
|
||
Args:
|
||
group_id: Group identifier (also used as end_user_id)
|
||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||
config_id: Configuration ID from database
|
||
db: SQLAlchemy database session
|
||
storage_type: Storage type (neo4j or rag)
|
||
user_rag_memory_id: User RAG memory ID
|
||
|
||
Returns:
|
||
Write operation result status
|
||
|
||
Raises:
|
||
ValueError: If config loading fails or write operation fails
|
||
"""
|
||
# Resolve config_id if None using end_user's connected config
|
||
if config_id is None:
|
||
try:
|
||
connected_config = get_end_user_connected_config(group_id, db)
|
||
config_id = connected_config.get("memory_config_id")
|
||
if config_id is None:
|
||
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
|
||
except Exception as e:
|
||
if "No memory configuration found" in str(e):
|
||
raise
|
||
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
|
||
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
|
||
|
||
import time
|
||
start_time = time.time()
|
||
|
||
# Load configuration from database only
|
||
try:
|
||
config_service = MemoryConfigService(db)
|
||
memory_config = config_service.load_memory_config(
|
||
config_id=config_id,
|
||
service_name="MemoryAgentService"
|
||
)
|
||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||
except ConfigurationError as e:
|
||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||
logger.error(error_msg)
|
||
|
||
# Log failed operation
|
||
if audit_logger:
|
||
duration = time.time() - start_time
|
||
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
|
||
|
||
raise ValueError(error_msg)
|
||
|
||
try:
|
||
if storage_type == "rag":
|
||
# For RAG storage, convert messages to single string
|
||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||
result = await write_rag(group_id, message_text, user_rag_memory_id)
|
||
return result
|
||
else:
|
||
async with make_write_graph() as graph:
|
||
config = {"configurable": {"thread_id": group_id}}
|
||
# Convert structured messages to LangChain messages
|
||
langchain_messages = []
|
||
for msg in messages:
|
||
if msg['role'] == 'user':
|
||
langchain_messages.append(HumanMessage(content=msg['content']))
|
||
elif msg['role'] == 'assistant':
|
||
from langchain_core.messages import AIMessage
|
||
langchain_messages.append(AIMessage(content=msg['content']))
|
||
|
||
# 初始状态 - 包含所有必要字段
|
||
initial_state = {
|
||
"messages": langchain_messages,
|
||
"group_id": group_id,
|
||
"memory_config": memory_config
|
||
}
|
||
|
||
# 获取节点更新信息
|
||
async for update_event in graph.astream(
|
||
initial_state,
|
||
stream_mode="updates",
|
||
config=config
|
||
):
|
||
for node_name, node_data in update_event.items():
|
||
if 'save_neo4j' == node_name:
|
||
massages = node_data
|
||
massagesstatus = massages.get('write_result')['status']
|
||
contents = massages.get('write_result')
|
||
# Convert messages back to string for logging
|
||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||
return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message_text, contents)
|
||
except Exception as e:
|
||
# Ensure proper error handling and logging
|
||
error_msg = f"Write operation failed: {str(e)}"
|
||
logger.error(error_msg)
|
||
if audit_logger:
|
||
duration = time.time() - start_time
|
||
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
|
||
raise ValueError(error_msg)
|
||
|
||
|
||
|
||
|
||
async def read_memory(
|
||
self,
|
||
group_id: str,
|
||
message: str,
|
||
history: List[Dict],
|
||
search_switch: str,
|
||
config_id: Optional[str],
|
||
db: Session,
|
||
storage_type: str,
|
||
user_rag_memory_id: str
|
||
) -> Dict:
|
||
"""
|
||
Process read operation with config_id
|
||
|
||
search_switch values:
|
||
- "0": Requires verification
|
||
- "1": No verification, direct split
|
||
- "2": Direct answer based on context
|
||
|
||
Args:
|
||
group_id: Group identifier (also used as end_user_id)
|
||
message: User message
|
||
history: Conversation history
|
||
search_switch: Search mode switch
|
||
config_id: Configuration ID from database
|
||
db: SQLAlchemy database session
|
||
storage_type: Storage type (neo4j or rag)
|
||
user_rag_memory_id: User RAG memory ID
|
||
|
||
Returns:
|
||
Dict with 'answer' and 'intermediate_outputs' keys
|
||
|
||
Raises:
|
||
ValueError: If config loading fails
|
||
"""
|
||
|
||
import time
|
||
start_time = time.time()
|
||
logger.info(f"[PERF] read_memory started for group_id={group_id}, search_switch={search_switch}")
|
||
|
||
# Resolve config_id if None using end_user's connected config
|
||
if config_id is None:
|
||
try:
|
||
connected_config = get_end_user_connected_config(group_id, db)
|
||
config_id = connected_config.get("memory_config_id")
|
||
if config_id is None:
|
||
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
|
||
except Exception as e:
|
||
if "No memory configuration found" in str(e):
|
||
raise # Re-raise our specific error
|
||
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
|
||
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
|
||
|
||
logger.info(f"Read operation for group {group_id} with config_id {config_id}")
|
||
|
||
# 导入审计日志记录器
|
||
try:
|
||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||
except ImportError:
|
||
audit_logger = None
|
||
|
||
|
||
config_load_start = time.time()
|
||
try:
|
||
config_service = MemoryConfigService(db)
|
||
memory_config = config_service.load_memory_config(
|
||
config_id=config_id,
|
||
service_name="MemoryAgentService"
|
||
)
|
||
config_load_time = time.time() - config_load_start
|
||
logger.info(f"[PERF] Configuration loaded in {config_load_time:.4f}s: {memory_config.config_name}")
|
||
except ConfigurationError as e:
|
||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||
logger.error(error_msg)
|
||
|
||
# Log failed operation
|
||
if audit_logger:
|
||
duration = time.time() - start_time
|
||
audit_logger.log_operation(
|
||
operation="READ",
|
||
config_id=config_id,
|
||
group_id=group_id,
|
||
success=False,
|
||
duration=duration,
|
||
error=error_msg
|
||
)
|
||
|
||
raise ValueError(error_msg)
|
||
|
||
# Step 2: Prepare history
|
||
history.append({"role": "user", "content": message})
|
||
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
|
||
|
||
# Step 3: Initialize MCP client and execute read workflow
|
||
graph_exec_start = time.time()
|
||
try:
|
||
async with make_read_graph() as graph:
|
||
config = {"configurable": {"thread_id": group_id}}
|
||
# 初始状态 - 包含所有必要字段
|
||
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
|
||
"group_id": group_id
|
||
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
|
||
"memory_config": memory_config}
|
||
# 获取节点更新信息
|
||
_intermediate_outputs = []
|
||
summary = ''
|
||
async for update_event in graph.astream(
|
||
initial_state,
|
||
stream_mode="updates",
|
||
config=config
|
||
):
|
||
for node_name, node_data in update_event.items():
|
||
# if 'save_neo4j' == node_name:
|
||
# massages = node_data
|
||
print(f"处理节点: {node_name}")
|
||
|
||
# 处理不同Summary节点的返回结构
|
||
if 'Summary' in node_name:
|
||
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
|
||
summary = node_data['InputSummary']['summary_result']
|
||
elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']:
|
||
summary = node_data['RetrieveSummary']['summary_result']
|
||
elif 'summary' in node_data and 'summary_result' in node_data['summary']:
|
||
summary = node_data['summary']['summary_result']
|
||
elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']:
|
||
summary = node_data['SummaryFails']['summary_result']
|
||
|
||
spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
|
||
if spit_data and spit_data != [] and spit_data != {}:
|
||
_intermediate_outputs.append(spit_data)
|
||
|
||
# Problem_Extension 节点
|
||
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
|
||
if problem_extension and problem_extension != [] and problem_extension != {}:
|
||
_intermediate_outputs.append(problem_extension)
|
||
|
||
# Retrieve 节点
|
||
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
|
||
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
|
||
_intermediate_outputs.extend(retrieve_node)
|
||
|
||
# Verify 节点
|
||
verify_n = node_data.get('verify', {}).get('_intermediate', None)
|
||
if verify_n and verify_n != [] and verify_n != {}:
|
||
_intermediate_outputs.append(verify_n)
|
||
|
||
# Summary 节点
|
||
summary_n = node_data.get('summary', {}).get('_intermediate', None)
|
||
if summary_n and summary_n != [] and summary_n != {}:
|
||
_intermediate_outputs.append(summary_n)
|
||
|
||
graph_exec_time = time.time() - graph_exec_start
|
||
logger.info(f"[PERF] Graph execution completed in {graph_exec_time:.4f}s")
|
||
|
||
_intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
|
||
|
||
optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
|
||
result = reorder_output_results(optimized_outputs)
|
||
|
||
# 保存短期记忆到数据库
|
||
# 只有 search_switch 不为 "2"(快速检索)时才保存
|
||
try:
|
||
from app.repositories.memory_short_repository import ShortTermMemoryRepository
|
||
|
||
retrieved_content = []
|
||
repo = ShortTermMemoryRepository(db)
|
||
|
||
if str(search_switch) != "2":
|
||
for intermediate in _intermediate_outputs:
|
||
logger.debug(f"处理中间结果: {intermediate}")
|
||
intermediate_type = intermediate.get('type', '')
|
||
|
||
if intermediate_type == "search_result":
|
||
query = intermediate.get('query', '')
|
||
raw_results = intermediate.get('raw_results', {})
|
||
reranked_results = raw_results.get('reranked_results', [])
|
||
|
||
try:
|
||
statements = [statement['statement'] for statement in reranked_results.get('statements', [])]
|
||
except Exception:
|
||
statements = []
|
||
|
||
# 去重
|
||
statements = list(set(statements))
|
||
|
||
if query and statements:
|
||
retrieved_content.append({query: statements})
|
||
|
||
# 如果 retrieved_content 为空,设置为空字符串
|
||
if retrieved_content == []:
|
||
retrieved_content = ''
|
||
|
||
# 只有当回答不是"信息不足"且不是快速检索时才保存
|
||
if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2":
|
||
# 使用 upsert 方法
|
||
repo.upsert(
|
||
end_user_id=group_id,
|
||
messages=message,
|
||
aimessages=summary,
|
||
retrieved_content=retrieved_content,
|
||
search_switch=str(search_switch)
|
||
)
|
||
logger.info(f"成功保存短期记忆: group_id={group_id}, search_switch={search_switch}")
|
||
else:
|
||
logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
|
||
|
||
except Exception as save_error:
|
||
# 保存失败不应该影响主流程,只记录错误
|
||
logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
|
||
|
||
# Log successful operation
|
||
total_time = time.time() - start_time
|
||
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
|
||
if audit_logger:
|
||
duration = time.time() - start_time
|
||
audit_logger.log_operation(
|
||
operation="READ",
|
||
config_id=config_id,
|
||
group_id=group_id,
|
||
success=True,
|
||
duration=duration
|
||
)
|
||
|
||
return {
|
||
"answer": summary,
|
||
"intermediate_outputs": result
|
||
}
|
||
except Exception as e:
|
||
# Ensure proper error handling and logging
|
||
error_msg = f"Read operation failed: {str(e)}"
|
||
total_time = time.time() - start_time
|
||
logger.error(f"[PERF] read_memory failed after {total_time:.4f}s: {error_msg}")
|
||
if audit_logger:
|
||
duration = time.time() - start_time
|
||
audit_logger.log_operation(
|
||
operation="READ",
|
||
config_id=config_id,
|
||
group_id=group_id,
|
||
success=False,
|
||
duration=duration,
|
||
error=error_msg
|
||
)
|
||
raise ValueError(error_msg)
|
||
|
||
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
|
||
"""
|
||
Get standardized message list from user input.
|
||
|
||
Args:
|
||
user_input: Write_UserInput object
|
||
|
||
Returns:
|
||
list[dict]: Message list, each message contains role and content
|
||
|
||
Raises:
|
||
ValueError: If messages is empty or format is incorrect
|
||
"""
|
||
from app.core.logging_config import get_api_logger
|
||
logger = get_api_logger()
|
||
|
||
if len(user_input.messages) == 0:
|
||
logger.error("Validation failed: Message list cannot be empty")
|
||
raise ValueError("Message list cannot be empty")
|
||
|
||
for idx, msg in enumerate(user_input.messages):
|
||
if not isinstance(msg, dict):
|
||
logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
|
||
raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
|
||
|
||
if 'role' not in msg:
|
||
logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
|
||
raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}")
|
||
|
||
if 'content' not in msg:
|
||
logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
|
||
raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}")
|
||
|
||
if msg['role'] not in ['user', 'assistant']:
|
||
logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
|
||
raise ValueError(f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}")
|
||
|
||
if not msg['content'] or not msg['content'].strip():
|
||
logger.error(f"Validation failed: Message {idx} content is empty")
|
||
raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}")
|
||
|
||
logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
|
||
return user_input.messages
|
||
|
||
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
|
||
"""
|
||
Determine the type of user message (read or write)
|
||
Updated to eliminate global variables in favor of explicit parameters.
|
||
|
||
Args:
|
||
message: User message to classify
|
||
config_id: Configuration ID to load LLM model from database
|
||
db: Database session
|
||
|
||
Returns:
|
||
Type classification result
|
||
"""
|
||
logger.info("Classifying message type")
|
||
|
||
# Load configuration to get LLM model ID
|
||
config_service = MemoryConfigService(db)
|
||
memory_config = config_service.load_memory_config(
|
||
config_id=config_id,
|
||
service_name="MemoryAgentService"
|
||
)
|
||
|
||
status = await status_typle(message, memory_config.llm_model_id)
|
||
logger.debug(f"Message type: {status}")
|
||
return status
|
||
|
||
# ==================== 新增的三个接口方法 ====================
|
||
|
||
async def get_knowledge_type_stats(
|
||
self,
|
||
end_user_id: Optional[str] = None,
|
||
only_active: bool = True,
|
||
current_workspace_id: Optional[uuid.UUID] = None,
|
||
db: Session = None
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
统计知识库类型分布,包含:
|
||
1. PostgreSQL 中的知识库类型:General, Web, Third-party, Folder(根据 workspace_id 过滤)
|
||
2. Neo4j 中的 memory 类型(仅统计 Chunk 数量,根据 end_user_id/group_id 过滤)
|
||
3. total: 所有类型的总和
|
||
|
||
参数:
|
||
- end_user_id: 用户组ID(可选,未提供时 memory 统计为 0)
|
||
- only_active: 是否仅统计有效记录
|
||
- current_workspace_id: 当前工作空间ID(可选,未提供时知识库统计为 0)
|
||
- db: 数据库会话
|
||
|
||
返回格式:
|
||
{
|
||
"General": count,
|
||
"Web": count,
|
||
"Third-party": count,
|
||
"Folder": count,
|
||
"memory": chunk_count,
|
||
"total": sum_of_all
|
||
}
|
||
"""
|
||
result = {}
|
||
|
||
# 1. 统计 PostgreSQL 中的知识库类型
|
||
try:
|
||
if db is None:
|
||
from app.db import get_db
|
||
db_gen = get_db()
|
||
db = next(db_gen)
|
||
|
||
# 初始化所有标准类型为 0
|
||
for kb_type in KnowledgeType:
|
||
result[kb_type.value] = 0
|
||
|
||
# 如果提供了 workspace_id,则按 workspace_id 过滤
|
||
if current_workspace_id:
|
||
# 构建查询条件
|
||
query = db.query(
|
||
Knowledge.type,
|
||
func.count(Knowledge.id).label('count')
|
||
).filter(Knowledge.workspace_id == current_workspace_id)
|
||
|
||
# 检查 Knowledge 模型是否有 status 字段
|
||
if only_active and hasattr(Knowledge, 'status'):
|
||
query = query.filter(Knowledge.status == 1)
|
||
|
||
# 按类型分组
|
||
type_counts = query.group_by(Knowledge.type).all()
|
||
|
||
# 只填充标准类型的统计值,忽略其他类型
|
||
valid_types = {kb_type.value for kb_type in KnowledgeType}
|
||
for type_name, count in type_counts:
|
||
if type_name in valid_types:
|
||
result[type_name] = count
|
||
|
||
logger.info(f"知识库类型统计成功 (workspace_id={current_workspace_id}): {result}")
|
||
else:
|
||
# 没有提供 workspace_id,所有知识库类型返回 0
|
||
logger.info("未提供 workspace_id,知识库类型统计全部为 0")
|
||
|
||
except Exception as e:
|
||
logger.error(f"知识库类型统计失败: {e}")
|
||
raise Exception(f"知识库类型统计失败: {e}")
|
||
|
||
# 2. 统计 Neo4j 中的 memory 总量(统计当前空间下所有宿主的 Chunk 总数)
|
||
try:
|
||
if current_workspace_id:
|
||
# 获取当前空间下的所有宿主
|
||
from app.repositories import app_repository, end_user_repository
|
||
from app.schemas.app_schema import App as AppSchema
|
||
from app.schemas.end_user_schema import EndUser as EndUserSchema
|
||
|
||
# 查询应用并转换为 Pydantic 模型
|
||
apps_orm = app_repository.get_apps_by_workspace_id(db, current_workspace_id)
|
||
apps = [AppSchema.model_validate(h) for h in apps_orm]
|
||
app_ids = [app.id for app in apps]
|
||
|
||
# 获取所有宿主
|
||
end_users = []
|
||
for app_id in app_ids:
|
||
end_user_orm_list = end_user_repository.get_end_users_by_app_id(db, app_id)
|
||
end_users.extend(h for h in end_user_orm_list)
|
||
|
||
# 统计所有宿主的 Chunk 总数
|
||
total_chunks = 0
|
||
for end_user in end_users:
|
||
end_user_id_str = str(end_user.id)
|
||
memory_query = """
|
||
MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN count(n) AS Count
|
||
"""
|
||
neo4j_result = await _neo4j_connector.execute_query(
|
||
memory_query,
|
||
group_id=end_user_id_str,
|
||
)
|
||
chunk_count = neo4j_result[0]["Count"] if neo4j_result else 0
|
||
total_chunks += chunk_count
|
||
logger.debug(f"EndUser {end_user_id_str} Chunk数量: {chunk_count}")
|
||
|
||
result["memory"] = total_chunks
|
||
logger.info(f"Neo4j memory统计成功: 总Chunk数={total_chunks}, 宿主数={len(end_users)}")
|
||
else:
|
||
# 没有 workspace_id 时,返回 0
|
||
result["memory"] = 0
|
||
logger.info("未提供 workspace_id,memory 统计为 0")
|
||
|
||
except Exception as e:
|
||
logger.error(f"Neo4j memory统计失败: {e}", exc_info=True)
|
||
# 如果 Neo4j 查询失败,memory 设为 0
|
||
result["memory"] = 0
|
||
|
||
# 3. 计算知识库类型总和(不包括 memory)
|
||
result["total"] = (
|
||
result.get("General", 0) +
|
||
result.get("Web", 0) +
|
||
result.get("Third-party", 0) +
|
||
result.get("Folder", 0)
|
||
)
|
||
|
||
return result
|
||
|
||
|
||
async def get_hot_memory_tags_by_user(
|
||
self,
|
||
end_user_id: Optional[str] = None,
|
||
limit: int = 20,
|
||
model_id: Optional[str] = None,
|
||
language_type: Optional[str] = "zh"
|
||
) -> List[Dict[str, Any]]:
|
||
"""
|
||
获取指定用户的热门记忆标签
|
||
|
||
参数:
|
||
- end_user_id: 用户ID(可选),对应Neo4j中的group_id字段
|
||
- limit: 返回标签数量限制
|
||
|
||
返回格式:
|
||
[
|
||
{"name": "标签名", "frequency": 频次},
|
||
...
|
||
]
|
||
"""
|
||
try:
|
||
# by_user=False 表示按 group_id 查询(在Neo4j中,group_id就是用户维度)
|
||
tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False)
|
||
payload=[]
|
||
for tag, freq in tags:
|
||
if language_type!="zh":
|
||
tag=await Translation_English(model_id, tag)
|
||
payload.append({"name": tag, "frequency": freq})
|
||
else:
|
||
payload.append({"name": tag, "frequency": freq})
|
||
return payload
|
||
except Exception as e:
|
||
logger.error(f"热门记忆标签查询失败: {e}")
|
||
raise Exception(f"热门记忆标签查询失败: {e}")
|
||
|
||
|
||
async def get_user_profile(
|
||
self,
|
||
end_user_id: Optional[str] = None,
|
||
current_user_id: Optional[str] = None,
|
||
llm_id: Optional[str] = None,
|
||
db: Session = None
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
获取用户详情,包含:
|
||
1. 用户名字(直接使用 end_user_name)
|
||
2. 用户标签(从摘要中用LLM总结3个标签)
|
||
3. 热门记忆标签(从hot_memory_tags获取前4个)
|
||
|
||
参数:
|
||
- end_user_id: 用户ID(可选)
|
||
- current_user_id: 当前登录用户的ID(保留参数)
|
||
- llm_id: LLM模型ID(用于生成标签,可选,如果不提供则跳过标签生成)
|
||
- db: 数据库会话(可选)
|
||
|
||
返回格式:
|
||
{
|
||
"name": "用户名",
|
||
"tags": ["产品设计师", "旅行爱好者", "摄影发烧友"],
|
||
"hot_tags": [
|
||
{"name": "标签1", "frequency": 10},
|
||
{"name": "标签2", "frequency": 8},
|
||
...
|
||
]
|
||
}
|
||
"""
|
||
result = {}
|
||
|
||
# 1. 根据 end_user_id 获取 end_user_name
|
||
try:
|
||
if end_user_id and db:
|
||
from app.repositories import end_user_repository
|
||
from app.schemas.end_user_schema import EndUser as EndUserSchema
|
||
|
||
end_user_orm = end_user_repository.get_end_user_by_id(db, end_user_id)
|
||
if end_user_orm:
|
||
end_user = EndUserSchema.model_validate(end_user_orm)
|
||
end_user_name = end_user.other_name
|
||
else:
|
||
end_user_name = "默认用户"
|
||
else:
|
||
end_user_name = "默认用户"
|
||
except Exception as e:
|
||
logger.error(f"Failed to get end_user_name: {e}")
|
||
end_user_name = "默认用户"
|
||
|
||
result["name"] = end_user_name
|
||
logger.debug(f"The end_user is: {end_user_name}")
|
||
|
||
# 2. 使用LLM从语句和实体中提取标签
|
||
try:
|
||
connector = Neo4jConnector()
|
||
|
||
# 查询该用户的语句
|
||
query = (
|
||
"MATCH (s:Statement) "
|
||
"WHERE ($group_id IS NULL OR s.group_id = $group_id) AND s.statement IS NOT NULL "
|
||
"RETURN s.statement AS statement "
|
||
"ORDER BY s.created_at DESC LIMIT 100"
|
||
)
|
||
rows = await connector.execute_query(query, group_id=end_user_id)
|
||
statements = [r.get("statement", "") for r in rows if r.get("statement")]
|
||
|
||
# 查询该用户的热门实体
|
||
entity_query = (
|
||
"MATCH (e:ExtractedEntity) "
|
||
"WHERE ($group_id IS NULL OR e.group_id = $group_id) AND e.entity_type <> '人物' AND e.name IS NOT NULL "
|
||
"RETURN e.name AS name, count(e) AS frequency "
|
||
"ORDER BY frequency DESC LIMIT 20"
|
||
)
|
||
entity_rows = await connector.execute_query(entity_query, group_id=end_user_id)
|
||
entities = [f"{r['name']} ({r['frequency']})" for r in entity_rows]
|
||
|
||
await connector.close()
|
||
|
||
if not statements or not llm_id:
|
||
result["tags"] = []
|
||
if not llm_id and statements:
|
||
logger.warning("llm_id not provided, skipping tag generation")
|
||
else:
|
||
# 构建摘要文本
|
||
summary_text = f"用户语句样本:{' | '.join(statements[:20])}\n核心实体:{', '.join(entities)}"
|
||
logger.debug(f"User data found: {len(statements)} statements, {len(entities)} entities")
|
||
|
||
# 使用LLM提取标签
|
||
with get_db_context() as db:
|
||
factory = MemoryClientFactory(db)
|
||
llm_client = factory.get_llm_client(llm_id)
|
||
|
||
# 定义标签提取的结构
|
||
class UserTags(BaseModel):
|
||
tags: list[str] = Field(..., description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友")
|
||
|
||
messages = [
|
||
{
|
||
"role": "system",
|
||
"content": "你是一个信息提取助手。从用户的语句和实体中提取3个最能代表用户特征的标签。标签应该简洁(2-6个字),描述用户的职业、兴趣或特点。"
|
||
},
|
||
{
|
||
"role": "user",
|
||
"content": f"请从以下用户信息中提取3个标签:\n\n{summary_text}"
|
||
}
|
||
]
|
||
|
||
user_tags = await llm_client.response_structured(
|
||
messages=messages,
|
||
response_model=UserTags
|
||
)
|
||
|
||
result["tags"] = user_tags.tags
|
||
logger.debug(f"Extracted tags: {user_tags.tags}")
|
||
|
||
except Exception as e:
|
||
# 如果提取失败,使用默认值
|
||
logger.error(f"Failed to extract user tags: {e}")
|
||
result["tags"] = []
|
||
|
||
try:
|
||
# 3. 获取热门记忆标签(前4个)
|
||
connector = Neo4jConnector()
|
||
names_to_exclude = ['AI', 'Caroline', 'Melanie', 'Jon', 'Gina', '用户', 'AI助手', 'John', 'Maria']
|
||
hot_tag_query = (
|
||
"MATCH (e:ExtractedEntity) "
|
||
"WHERE ($group_id IS NULL OR e.group_id = $group_id) AND e.entity_type <> '人物' "
|
||
"AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
|
||
"RETURN e.name AS name, count(e) AS frequency "
|
||
"ORDER BY frequency DESC LIMIT 4"
|
||
)
|
||
hot_tag_rows = await connector.execute_query(
|
||
hot_tag_query,
|
||
group_id=end_user_id,
|
||
names_to_exclude=names_to_exclude
|
||
)
|
||
await connector.close()
|
||
|
||
result["hot_tags"] = [{"name": r["name"], "frequency": r["frequency"]} for r in hot_tag_rows]
|
||
logger.debug(f"Hot tags found: {len(result['hot_tags'])} tags")
|
||
except Exception as e:
|
||
logger.error(f"Failed to get hot tags: {e}")
|
||
result["hot_tags"] = []
|
||
|
||
return result
|
||
|
||
async def stream_log_content(self) -> AsyncGenerator[str, None]:
|
||
"""
|
||
Stream log content in real-time using Server-Sent Events (SSE)
|
||
|
||
This method establishes a streaming connection and transmits log entries
|
||
as they are written to the log file. It uses the LogStreamer to watch
|
||
the file and yields SSE-formatted messages.
|
||
|
||
Yields:
|
||
SSE-formatted strings with the following event types:
|
||
- log: Contains log content and timestamp
|
||
- keepalive: Periodic keepalive messages to maintain connection
|
||
- error: Error information if streaming fails
|
||
- done: Indicates streaming has completed
|
||
|
||
Raises:
|
||
FileNotFoundError: If log file doesn't exist at stream start
|
||
Exception: For other unexpected errors during streaming
|
||
"""
|
||
logger.info("Starting log content streaming")
|
||
|
||
# Get log file path - use project root directory
|
||
current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py
|
||
app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory
|
||
project_root = os.path.dirname(app_dir) # redbear-mem directory
|
||
log_path = os.path.join(project_root, "logs", "agent_service.log")
|
||
|
||
# Check if file exists before starting stream
|
||
if not os.path.exists(log_path):
|
||
logger.error(f"Log file not found: {log_path}")
|
||
# Send error event in SSE format
|
||
yield f"event: error\ndata: {json.dumps({'code': 4006, 'message': '日志文件不存在', 'error': f'File not found: {log_path}'})}\n\n"
|
||
return
|
||
|
||
streamer = None
|
||
try:
|
||
# Initialize LogStreamer with keepalive interval from settings (default 300 seconds)
|
||
keepalive_interval = getattr(settings, 'LOG_STREAM_KEEPALIVE_INTERVAL', 300)
|
||
streamer = LogStreamer(log_path, keepalive_interval=keepalive_interval)
|
||
|
||
logger.info(f"LogStreamer initialized for {log_path}")
|
||
|
||
# Stream log content using read_existing_and_stream to get all existing content first
|
||
async for message in streamer.read_existing_and_stream():
|
||
event_type = message.get("event")
|
||
data = message.get("data")
|
||
|
||
# Format as SSE message
|
||
# SSE format: "event: <type>\ndata: <json_data>\n\n"
|
||
sse_message = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
|
||
|
||
logger.debug(f"Streaming event: {event_type}")
|
||
yield sse_message
|
||
|
||
# If error or done event, stop streaming
|
||
if event_type in ["error", "done"]:
|
||
logger.info(f"Stream ended with event: {event_type}")
|
||
break
|
||
|
||
except FileNotFoundError as e:
|
||
logger.error(f"Log file not found during streaming: {e}")
|
||
yield f"event: error\ndata: {json.dumps({'code': 4006, 'message': '日志文件在流式传输期间变得不可用', 'error': str(e)})}\n\n"
|
||
|
||
except Exception as e:
|
||
logger.error(f"Unexpected error during log streaming: {e}", exc_info=True)
|
||
yield f"event: error\ndata: {json.dumps({'code': 8001, 'message': '流式传输期间发生错误', 'error': str(e)})}\n\n"
|
||
|
||
finally:
|
||
# Resource cleanup
|
||
logger.info("Log streaming completed, cleaning up resources")
|
||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||
|
||
def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]:
|
||
"""
|
||
获取终端用户关联的记忆配置
|
||
|
||
通过以下流程获取配置:
|
||
1. 根据 end_user_id 获取用户的 app_id
|
||
2. 获取该应用的最新发布版本
|
||
3. 从发布版本的 config 字段中提取 memory_config_id
|
||
|
||
Args:
|
||
end_user_id: 终端用户ID
|
||
db: 数据库会话
|
||
|
||
Returns:
|
||
包含 memory_config_id 和相关信息的字典
|
||
|
||
Raises:
|
||
ValueError: 当终端用户不存在或应用未发布时
|
||
"""
|
||
from app.models.app_release_model import AppRelease
|
||
from app.models.end_user_model import EndUser
|
||
from sqlalchemy import select
|
||
|
||
logger.info(f"Getting connected config for end_user: {end_user_id}")
|
||
|
||
# 1. 获取 end_user 及其 app_id
|
||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||
if not end_user:
|
||
logger.warning(f"End user not found: {end_user_id}")
|
||
raise ValueError(f"终端用户不存在: {end_user_id}")
|
||
|
||
app_id = end_user.app_id
|
||
logger.debug(f"Found end_user app_id: {app_id}")
|
||
|
||
# 2. 获取该应用的最新发布版本
|
||
stmt = (
|
||
select(AppRelease)
|
||
.where(AppRelease.app_id == app_id, AppRelease.is_active.is_(True))
|
||
.order_by(AppRelease.version.desc())
|
||
)
|
||
latest_release = db.scalars(stmt).first()
|
||
|
||
if not latest_release:
|
||
logger.warning(f"No active release found for app: {app_id}")
|
||
raise ValueError(f"应用未发布: {app_id}")
|
||
|
||
logger.debug(f"Found latest release: version={latest_release.version}, id={latest_release.id}")
|
||
|
||
# 3. 从 config 中提取 memory_config_id
|
||
config = latest_release.config or {}
|
||
memory_obj = config.get('memory', {})
|
||
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||
|
||
result = {
|
||
"end_user_id": str(end_user_id),
|
||
"app_id": str(app_id),
|
||
"release_id": str(latest_release.id),
|
||
"release_version": latest_release.version,
|
||
"memory_config_id": memory_config_id
|
||
}
|
||
|
||
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
|
||
return result
|
||
|
||
|
||
def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) -> Dict[str, Dict[str, Any]]:
|
||
"""
|
||
批量获取多个终端用户关联的记忆配置(优化版本,减少数据库查询次数)
|
||
|
||
通过以下流程获取配置:
|
||
1. 批量查询所有 end_user_id 对应的 app_id
|
||
2. 批量获取这些应用的最新发布版本
|
||
3. 从发布版本的 config 字段中提取 memory_config_id
|
||
|
||
Args:
|
||
end_user_ids: 终端用户ID列表
|
||
db: 数据库会话
|
||
|
||
Returns:
|
||
字典,key 为 end_user_id,value 为包含 memory_config_id 和 memory_config_name 的字典
|
||
格式: {
|
||
"user_id_1": {"memory_config_id": "xxx", "memory_config_name": "xxx"},
|
||
"user_id_2": {"memory_config_id": None, "memory_config_name": None},
|
||
...
|
||
}
|
||
"""
|
||
from app.models.app_release_model import AppRelease
|
||
from app.models.end_user_model import EndUser
|
||
from app.models.memory_config_model import MemoryConfig
|
||
from sqlalchemy import select
|
||
|
||
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users")
|
||
|
||
result = {}
|
||
|
||
# 如果列表为空,直接返回空字典
|
||
if not end_user_ids:
|
||
return result
|
||
|
||
# 1. 批量查询所有 end_user 及其 app_id
|
||
end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()
|
||
|
||
# 创建 end_user_id -> app_id 的映射
|
||
user_to_app = {str(eu.id): eu.app_id for eu in end_users}
|
||
|
||
# 记录未找到的用户
|
||
found_user_ids = set(user_to_app.keys())
|
||
missing_user_ids = set(end_user_ids) - found_user_ids
|
||
if missing_user_ids:
|
||
logger.warning(f"End users not found: {missing_user_ids}")
|
||
for user_id in missing_user_ids:
|
||
result[user_id] = {"memory_config_id": None, "memory_config_name": None}
|
||
|
||
# 2. 批量获取所有相关应用的最新发布版本
|
||
app_ids = list(user_to_app.values())
|
||
if not app_ids:
|
||
return result
|
||
|
||
# 查询所有活跃的发布版本
|
||
stmt = (
|
||
select(AppRelease)
|
||
.where(AppRelease.app_id.in_(app_ids), AppRelease.is_active.is_(True))
|
||
.order_by(AppRelease.app_id, AppRelease.version.desc())
|
||
)
|
||
releases = db.scalars(stmt).all()
|
||
|
||
# 创建 app_id -> latest_release 的映射(每个 app 只保留最新版本)
|
||
app_to_release = {}
|
||
for release in releases:
|
||
if release.app_id not in app_to_release:
|
||
app_to_release[release.app_id] = release
|
||
|
||
# 3. 收集所有 memory_config_id 并批量查询配置名称
|
||
memory_config_ids = []
|
||
for end_user_id, app_id in user_to_app.items():
|
||
release = app_to_release.get(app_id)
|
||
if release:
|
||
config = release.config or {}
|
||
memory_obj = config.get('memory', {})
|
||
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||
if memory_config_id:
|
||
memory_config_ids.append(memory_config_id)
|
||
|
||
# 批量查询 memory_config_name
|
||
config_id_to_name = {}
|
||
if memory_config_ids:
|
||
memory_configs = db.query(MemoryConfig).filter(MemoryConfig.id.in_(memory_config_ids)).all()
|
||
config_id_to_name = {str(mc.id): mc.config_name for mc in memory_configs}
|
||
|
||
# 4. 构建最终结果
|
||
for end_user_id, app_id in user_to_app.items():
|
||
release = app_to_release.get(app_id)
|
||
|
||
if not release:
|
||
logger.warning(f"No active release found for app: {app_id} (end_user: {end_user_id})")
|
||
result[end_user_id] = {"memory_config_id": None, "memory_config_name": None}
|
||
continue
|
||
|
||
# 从 config 中提取 memory_config_id
|
||
config = release.config or {}
|
||
memory_obj = config.get('memory', {})
|
||
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||
|
||
# 获取配置名称
|
||
memory_config_name = config_id_to_name.get(memory_config_id) if memory_config_id else None
|
||
|
||
result[end_user_id] = {
|
||
"memory_config_id": memory_config_id,
|
||
"memory_config_name": memory_config_name
|
||
}
|
||
|
||
logger.info(f"Successfully retrieved {len(result)} connected configs")
|
||
return result |