"""
|
||
Memory Agent Service
|
||
|
||
Handles business logic for memory agent operations including read/write services,
|
||
health checks, and message type classification.
|
||
|
||
TODO: Refactor get_end_user_connected_config
|
||
----------------------------------------------
|
||
1. Move get_end_user_connected_config to memory_config_service.py
|
||
2. Change return type from Dict[str, Any] (with config_id string) to full MemoryConfig model
|
||
3. This will eliminate the need for callers to call load_memory_config separately
|
||
4. Update all callers to use the new unified function
|
||
"""
|
||
import json
|
||
import os
|
||
import time
|
||
import uuid
|
||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||
from uuid import UUID
|
||
|
||
import redis
|
||
from langchain_core.messages import HumanMessage
|
||
from pydantic import BaseModel, Field
|
||
from sqlalchemy import func
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.cache import InterestMemoryCache
|
||
from app.core.config import settings
|
||
from app.core.logging_config import get_config_logger, get_logger
|
||
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||
from app.core.memory.agent.utils.messages_tools import (
|
||
merge_multiple_search_results,
|
||
reorder_output_results,
|
||
)
|
||
from app.core.memory.agent.utils.type_classifier import status_typle
|
||
from app.core.memory.agent.utils.write_tools import write as write_neo4j
|
||
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
|
||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||
from app.db import get_db_context
|
||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||
from app.schemas import FileInput
|
||
from app.schemas.memory_agent_schema import Write_UserInput
|
||
from app.schemas.memory_config_schema import ConfigurationError
|
||
from app.services.memory_config_service import MemoryConfigService
|
||
from app.services.memory_konwledges_server import (
|
||
write_rag,
|
||
)
|
||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||
|
||
logger = get_logger(__name__)
|
||
config_logger = get_config_logger()
|
||
|
||
# Initialize Neo4j connector for analytics functions
|
||
_neo4j_connector = Neo4jConnector()
|
||
|
||
|
||
class MemoryAgentService:
|
||
"""Service for memory agent operations"""
|
||
|
||
    def writer_messages_deal(self, messages, start_time, end_user_id, config_id, message, context):
        """Audit-log the outcome of a write operation; return the context on success, raise on failure."""
        duration = time.time() - start_time
        if str(messages) == 'success':
            logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
            # Log the successful operation
            audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
                                       success=True,
                                       duration=duration, details={"message_length": len(message)})
            return context
        else:
            logger.warning(f"Write operation failed for group {end_user_id}")

            # Log the failed operation
            audit_logger.log_operation(
                operation="WRITE",
                config_id=config_id,
                end_user_id=end_user_id,
                success=False,
                duration=duration,
                error=f"写入失败: {messages[:100]}"
            )

            raise ValueError(f"写入失败: {messages}")

    def extract_tool_call_info(self, event: Dict) -> bool:
        """Extract tool call information from an event; return True if any was found."""
        last_message = event["messages"][-1]

        # Check if the AI message contains tool calls
        if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
            tool_calls = last_message.tool_calls
            for i, tool_call in enumerate(tool_calls):
                if isinstance(tool_call, dict):
                    tool_call_id = tool_call.get('id')
                    tool_name = tool_call.get('name')
                    tool_args = tool_call.get('args', {})
                else:
                    tool_call_id = getattr(tool_call, 'id', None)
                    tool_name = getattr(tool_call, 'name', None)
                    tool_args = getattr(tool_call, 'args', {})

                logger.debug(f"Tool Call {i + 1}: ID={tool_call_id}, Name={tool_name}, Args={tool_args}")
            return True

        # Check if this is a tool message
        elif hasattr(last_message, 'tool_call_id'):
            tool_call_id = getattr(last_message, 'tool_call_id', None)
            if hasattr(last_message, 'name') and hasattr(last_message, 'content'):
                tool_name = getattr(last_message, 'name', None)
                try:
                    content = json.loads(getattr(last_message, 'content', '{}'))
                    tool_args = content.get('args', {})
                    logger.debug(f"Tool Call 1: ID={tool_call_id}, Name={tool_name}, Args={tool_args}")
                except (json.JSONDecodeError, TypeError, AttributeError):
                    logger.debug(f"Tool Response ID: {tool_call_id}")
            else:
                logger.debug(f"Tool Response ID: {tool_call_id}")
            return True

        return False

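    # Usage sketch (hypothetical event, mirroring the dict-shaped tool calls
    # LangChain attaches to an AIMessage; not part of this service's public API):
    #
    #     event = {"messages": [ai_message_with_tool_calls]}
    #     has_tool_activity = MemoryAgentService().extract_tool_call_info(event)
    #     # True when the last message carries tool calls or is a tool response
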
    async def get_health_status(self) -> Dict:
        """
        Get latest health status from Redis cache

        Returns health status information written by Celery periodic task
        """
        logger.info("Checking health status")

        client = redis.Redis(
            host=settings.REDIS_HOST,
            port=settings.REDIS_PORT,
            db=settings.REDIS_DB,
            password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None
        )
        payload = client.hgetall("memsci:health:read_service") or {}

        if payload:
            # Decode bytes to str
            decoded = {k.decode("utf-8"): v.decode("utf-8") for k, v in payload.items()}
            status = decoded.get("status", "unknown")
        else:
            status = "unknown"

        # Add database connection pool status
        try:
            from app.db import get_pool_status
            pool_status = get_pool_status()
            logger.info(f"Database pool status: {pool_status}")

            # Check if pool usage is too high
            if pool_status.get("usage_percent", 0) > 80:
                logger.warning(f"High database pool usage: {pool_status['usage_percent']}%")
                status = "warning"

        except Exception as e:
            logger.error(f"Failed to get pool status: {e}")
            pool_status = {"error": str(e)}

        logger.info(f"Health status: {status}")
        return {
            "status": status,
            "database_pool": pool_status
        }

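    # Producer-side sketch (assumed shape of the hash written by the Celery
    # periodic task; this method only ever reads it):
    #
    #     client.hset("memsci:health:read_service",
    #                 mapping={"status": "ok", "checked_at": str(time.time())})
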
    def get_log_content(self) -> str:
        """
        Read and return agent service log file content

        Returns cleaned log content using the same cleaning logic as transmission mode
        """
        logger.info("Reading log file")

        # Get log file path - use project root directory
        from pathlib import Path
        project_root = str(Path(__file__).resolve().parents[2])  # api directory
        log_path = os.path.join(project_root, "logs", "agent_service.log")

        content = ''

        with open(log_path, "r", encoding="utf-8") as infile:
            for line in infile:
                # Use the same cleaning logic as LogStreamer for consistency
                cleaned = LogStreamer.clean_log_line(line)
                content += cleaned

        if len(content) < 10:
            raise ValueError("NO LOGS")

        logger.info(f"Log content retrieved, size: {len(content)} characters")
        return content

    async def stream_log_content(self) -> AsyncGenerator[str, None]:
        """
        Stream log content in real-time using Server-Sent Events (SSE)

        This method establishes a streaming connection and transmits log entries
        as they are written to the log file. It uses the LogStreamer to watch
        the file and yields SSE-formatted messages.

        Yields:
            SSE-formatted strings with the following event types:
            - log: Contains log content and timestamp
            - keepalive: Periodic keepalive messages to maintain connection
            - error: Error information if streaming fails
            - done: Indicates streaming has completed

        Raises:
            FileNotFoundError: If log file doesn't exist at stream start
            Exception: For other unexpected errors during streaming
        """
        logger.info("Starting log content streaming")

        # Get log file path - use project root directory
        from pathlib import Path
        project_root = str(Path(__file__).resolve().parents[2])  # api directory
        log_path = os.path.join(project_root, "logs", "agent_service.log")

        # Check if file exists before starting stream
        if not os.path.exists(log_path):
            logger.error(f"Log file not found: {log_path}")
            # Send error event in SSE format
            yield f"event: error\ndata: {json.dumps({'code': 4006, 'message': '日志文件不存在', 'error': f'File not found: {log_path}'})}\n\n"
            return

        streamer = None
        try:
            # Initialize LogStreamer with keepalive interval from settings (default 300 seconds)
            keepalive_interval = getattr(settings, 'LOG_STREAM_KEEPALIVE_INTERVAL', 300)
            streamer = LogStreamer(log_path, keepalive_interval=keepalive_interval)

            logger.info(f"LogStreamer initialized for {log_path}")

            # Stream log content using read_existing_and_stream to get all existing content first
            async for message in streamer.read_existing_and_stream():
                event_type = message.get("event")
                data = message.get("data")

                # Format as SSE message
                # SSE format: "event: <type>\ndata: <json_data>\n\n"
                sse_message = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"

                logger.debug(f"Streaming event: {event_type}")
                yield sse_message

                # If error or done event, stop streaming
                if event_type in ["error", "done"]:
                    logger.info(f"Stream ended with event: {event_type}")
                    break

        except FileNotFoundError as e:
            logger.error(f"Log file not found during streaming: {e}")
            yield f"event: error\ndata: {json.dumps({'code': 4006, 'message': '日志文件在流式传输期间变得不可用', 'error': str(e)})}\n\n"

        except Exception as e:
            logger.error(f"Unexpected error during log streaming: {e}", exc_info=True)
            yield f"event: error\ndata: {json.dumps({'code': 8001, 'message': '流式传输期间发生错误', 'error': str(e)})}\n\n"

        finally:
            # Resource cleanup
            logger.info("Log streaming completed, cleaning up resources")
            # LogStreamer uses context manager for file handling, so cleanup is automatic

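    # Endpoint wiring sketch (hypothetical route; StreamingResponse is the
    # usual FastAPI way to expose an async generator as SSE):
    #
    #     @router.get("/logs/stream")
    #     async def stream_logs():
    #         return StreamingResponse(MemoryAgentService().stream_log_content(),
    #                                  media_type="text/event-stream")
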
    async def write_memory(
            self,
            end_user_id: str,
            messages: list[dict],
            config_id: Optional[uuid.UUID | int],
            db: Session,
            storage_type: str,
            user_rag_memory_id: str,
            language: str = "zh"
    ) -> str:
        """
        Process write operation with config_id

        Args:
            end_user_id: Group identifier (also used as end_user_id)
            messages: Messages to write
            config_id: Configuration ID from database
            db: SQLAlchemy database session
            storage_type: Storage type (neo4j or rag)
            user_rag_memory_id: User RAG memory ID
            language: Output language ("zh" for Chinese, "en" for English)

        Returns:
            Write operation result: "success" for RAG writes, or a result
            context dict for graph writes

        Raises:
            ValueError: If config loading fails or write operation fails
        """
        # Resolve config_id and workspace_id
        # Always get workspace_id from end_user for fallback, even if config_id is provided
        workspace_id = None
        try:
            connected_config = get_end_user_connected_config(end_user_id, db)
            workspace_id = connected_config.get("workspace_id")
            if config_id is None:
                config_id = connected_config.get("memory_config_id")
                logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
            if config_id is None and workspace_id is None:
                raise ValueError(f"No memory configuration found for end_user {end_user_id}. "
                                 f"Please ensure the user has a connected memory configuration.")
        except Exception as e:
            if "No memory configuration found" in str(e):
                raise  # Re-raise our specific error
            logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
            if config_id is None:
                raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
            # If config_id was provided, continue without workspace_id fallback

        start_time = time.time()

        # Load configuration from database with workspace fallback
        # Use a separate database session to avoid transaction failures
        try:
            with get_db_context() as config_db:
                config_service = MemoryConfigService(config_db)
                memory_config = config_service.load_memory_config(
                    config_id=config_id,
                    workspace_id=workspace_id,
                    service_name="MemoryAgentService"
                )
            logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
        except ConfigurationError as e:
            error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
            logger.error(error_msg)

            # Log failed operation
            duration = time.time() - start_time
            audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
                                       success=False, duration=duration, error=error_msg)

            raise ValueError(error_msg)

        perceptual_service = MemoryPerceptualService(db)
        for message in messages:
            message["file_content"] = []
            for file in (message.get("files") or []):
                file_object = await perceptual_service.generate_perceptual_memory(
                    end_user_id=end_user_id,
                    memory_config=memory_config,
                    file=FileInput(**file)
                )
                if file_object is None:
                    continue
                message["file_content"].append((file_object, file["type"]))
        logger.info(messages)

        message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
        try:
            if storage_type == "rag":
                # For RAG storage, convert messages to single string
                await write_rag(end_user_id, message_text, user_rag_memory_id)
                return "success"
            else:
                await write_neo4j(
                    end_user_id=end_user_id,
                    messages=messages,
                    memory_config=memory_config,
                    ref_id='',
                    language=language
                )
                for lang in ["zh", "en"]:
                    deleted = await InterestMemoryCache.delete_interest_distribution(
                        end_user_id, lang
                    )
                    if deleted:
                        logger.info(
                            f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
                for message in messages:
                    message["file_content"] = [
                        perceptual[0].file_path for perceptual in message["file_content"]
                    ]
                return self.writer_messages_deal(
                    "success",
                    start_time,
                    end_user_id,
                    config_id,
                    message_text,
                    {
                        "status": "success",
                        "data": messages,
                        "config_id": memory_config.config_id,
                        "config_name": memory_config.config_name
                    }
                )
        except Exception as e:
            # Ensure proper error handling and logging
            error_msg = f"Write operation failed: {str(e)}"
            logger.error(error_msg)

            duration = time.time() - start_time
            audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
                                       success=False, duration=duration, error=error_msg)
            raise ValueError(error_msg)

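    # Usage sketch (hypothetical values; normally driven by the API layer):
    #
    #     with get_db_context() as db:
    #         result = await MemoryAgentService().write_memory(
    #             end_user_id="eu_123",
    #             messages=[{"role": "user", "content": "I moved to Berlin"}],
    #             config_id=None,  # resolved from the end user's connected config
    #             db=db, storage_type="neo4j", user_rag_memory_id="",
    #             language="en")
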
    async def read_memory(
            self,
            end_user_id: str,
            message: str,
            history: List[Dict],  # FIXME: appended to and logged, but not passed to the read graph
            search_switch: str,
            config_id: Optional[uuid.UUID | int],
            db: Session,
            storage_type: str,
            user_rag_memory_id: str) -> Dict:
        """
        Process read operation with config_id

        search_switch values:
        - "0": Requires verification
        - "1": No verification, direct split
        - "2": Direct answer based on context

        Args:
            end_user_id: Group identifier (also used as end_user_id)
            message: User message
            history: Conversation history
            search_switch: Search mode switch
            config_id: Configuration ID from database
            db: SQLAlchemy database session
            storage_type: Storage type (neo4j or rag)
            user_rag_memory_id: User RAG memory ID

        Returns:
            Dict with 'answer' and 'intermediate_outputs' keys

        Raises:
            ValueError: If config loading fails
        """
        start_time = time.time()
        ori_message = message

        # Resolve config_id and workspace_id
        # Always get workspace_id from end_user for fallback, even if config_id is provided
        workspace_id = None
        try:
            connected_config = get_end_user_connected_config(end_user_id, db)
            workspace_id = connected_config.get("workspace_id")
            if config_id is None:
                config_id = connected_config.get("memory_config_id")
                logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
            if config_id is None and workspace_id is None:
                raise ValueError(
                    f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
        except Exception as e:
            if "No memory configuration found" in str(e):
                raise  # Re-raise our specific error
            logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
            if config_id is None:
                raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
            # If config_id was provided, continue without workspace_id fallback

        logger.info(f"Read operation for group {end_user_id} with config_id {config_id}")

        config_load_start = time.time()
        try:
            # Use a separate database session to avoid transaction failures
            with get_db_context() as config_db:
                config_service = MemoryConfigService(config_db)
                memory_config = config_service.load_memory_config(
                    config_id=config_id,
                    workspace_id=workspace_id,
                    service_name="MemoryAgentService"
                )
            config_load_time = time.time() - config_load_start
            logger.info(f"[PERF] Configuration loaded in {config_load_time:.4f}s: {memory_config.config_name}")
        except ConfigurationError as e:
            error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
            logger.error(error_msg)

            # Log failed operation
            duration = time.time() - start_time
            audit_logger.log_operation(
                operation="READ",
                config_id=config_id,
                end_user_id=end_user_id,
                success=False,
                duration=duration,
                error=error_msg
            )

            raise ValueError(error_msg)

        # Step 2: Prepare history
        history.append({"role": "user", "content": message})
        logger.debug(f"Group ID:{end_user_id}, Message:{message}, History:{history}, Config ID:{config_id}")

        # Step 3: Initialize MCP client and execute read workflow
        graph_exec_start = time.time()
        try:
            async with make_read_graph() as graph:
                config = {"configurable": {"thread_id": end_user_id}}
                # Initial state - contains all required fields
                initial_state = {
                    "messages": [HumanMessage(content=message)],
                    "search_switch": search_switch,
                    "end_user_id": end_user_id,
                    "storage_type": storage_type,
                    "user_rag_memory_id": user_rag_memory_id,
                    "memory_config": memory_config}
                # Collect per-node update information
                _intermediate_outputs = []
                summary = ''
                async for update_event in graph.astream(
                        initial_state,
                        stream_mode="updates",
                        config=config
                ):
                    for node_name, node_data in update_event.items():
                        # if 'save_neo4j' == node_name:
                        #     massages = node_data
                        logger.info(f"处理节点: {node_name}")

                        # Handle the return structures of the different Summary nodes
                        if 'Summary' in node_name:
                            if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
                                summary = node_data['InputSummary']['summary_result']
                            elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']:
                                summary = node_data['RetrieveSummary']['summary_result']
                            elif 'summary' in node_data and 'summary_result' in node_data['summary']:
                                summary = node_data['summary']['summary_result']
                            elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']:
                                summary = node_data['SummaryFails']['summary_result']

                        spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
                        if spit_data:
                            _intermediate_outputs.append(spit_data)

                        # Problem_Extension node
                        problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
                        if problem_extension:
                            _intermediate_outputs.append(problem_extension)

                        # Retrieve node
                        retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
                        if retrieve_node:
                            _intermediate_outputs.extend(retrieve_node)

                        # Perceptual_Retrieve node
                        perceptual_node = node_data.get('perceptual_data', {}).get('_intermediate', None)
                        if perceptual_node:
                            _intermediate_outputs.append(perceptual_node)

                        # Verify node
                        verify_n = node_data.get('verify', {}).get('_intermediate', None)
                        if verify_n:
                            _intermediate_outputs.append(verify_n)

                        # Summary node
                        summary_n = node_data.get('summary', {}).get('_intermediate', None)
                        if summary_n:
                            _intermediate_outputs.append(summary_n)

            graph_exec_time = time.time() - graph_exec_start
            logger.info(f"[PERF] Graph execution completed in {graph_exec_time:.4f}s")

            _intermediate_outputs = [item for item in _intermediate_outputs if item]

            optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
            result = reorder_output_results(optimized_outputs)

            # Persist short-term memory to the database
            # Only saved when search_switch is not "2" (fast retrieval)
            try:
                from app.repositories.memory_short_repository import (
                    ShortTermMemoryRepository,
                )

                retrieved_content = []
                repo = ShortTermMemoryRepository(db)

                if str(search_switch) != "2":
                    for intermediate in _intermediate_outputs:
                        logger.debug(f"处理中间结果: {intermediate}")
                        intermediate_type = intermediate.get('type', '')

                        if intermediate_type == "search_result":
                            query = intermediate.get('query', '')
                            raw_results = intermediate.get('raw_results', {})
                            try:
                                reranked_results = raw_results.get('reranked_results', [])
                                statements = [statement['statement'] for statement in
                                              reranked_results.get('statements', [])]
                            except Exception:
                                statements = []

                            # De-duplicate
                            statements = list(set(statements))

                            if query and statements:
                                retrieved_content.append({query: statements})

                # If retrieved_content is empty, store an empty string instead
                if not retrieved_content:
                    retrieved_content = ''

                # Only save when the answer is not the "insufficient information"
                # sentinel and this is not a fast retrieval
                if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2":
                    # Upsert the short-term memory record
                    repo.upsert(
                        end_user_id=end_user_id,
                        messages=ori_message,
                        aimessages=summary,
                        retrieved_content=retrieved_content,
                        search_switch=str(search_switch)
                    )
                    logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}")
                else:
                    logger.debug(
                        f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")

            except Exception as save_error:
                # A save failure must not break the main flow; just log it
                logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)

            # Log successful operation
            total_time = time.time() - start_time
            logger.info(
                f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")

            duration = time.time() - start_time
            audit_logger.log_operation(
                operation="READ",
                config_id=config_id,
                end_user_id=end_user_id,
                success=True,
                duration=duration
            )

            return {
                "answer": summary,
                "intermediate_outputs": result
            }

        # TODO: redis search -> answer
        except Exception as e:
            # Ensure proper error handling and logging
            error_msg = f"Read operation failed: {str(e)}"
            logger.error(error_msg)

            duration = time.time() - start_time
            audit_logger.log_operation(
                operation="READ",
                config_id=config_id,
                end_user_id=end_user_id,
                success=False,
                duration=duration,
                error=error_msg
            )
            raise ValueError(error_msg)

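    # Usage sketch (hypothetical; mirrors the documented return shape):
    #
    #     result = await MemoryAgentService().read_memory(
    #         end_user_id="eu_123", message="Where do I live?", history=[],
    #         search_switch="1", config_id=None, db=db,
    #         storage_type="neo4j", user_rag_memory_id="")
    #     result["answer"]                 # generated summary text
    #     result["intermediate_outputs"]   # merged, reordered retrieval traces
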
    def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
        """
        Get standardized message list from user input.

        Args:
            user_input: Write_UserInput object

        Returns:
            list[dict]: Message list, each message contains role and content

        Raises:
            ValueError: If messages is empty or format is incorrect
        """
        from app.core.logging_config import get_api_logger
        logger = get_api_logger()

        if len(user_input.messages) == 0:
            logger.error("Validation failed: Message list cannot be empty")
            raise ValueError("Message list cannot be empty")

        for idx, msg in enumerate(user_input.messages):
            if not isinstance(msg, dict):
                logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
                raise ValueError(
                    f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")

            if 'role' not in msg:
                logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
                raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}")

            if 'content' not in msg:
                logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
                raise ValueError(
                    f"Message format error: Message must contain 'content' field. Error message index: {idx}")

            if msg['role'] not in ['user', 'assistant']:
                logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
                raise ValueError(f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}")

            if not msg['content'] or not msg['content'].strip():
                logger.error(f"Validation failed: Message {idx} content is empty")
                raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}")

        logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
        return user_input.messages

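    # A minimal payload that passes this validation (illustrative; other
    # Write_UserInput fields are omitted here):
    #
    #     Write_UserInput(messages=[
    #         {"role": "user", "content": "I adopted a cat last week"},
    #         {"role": "assistant", "content": "Noted!"},
    #     ])
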
    async def classify_message_type(
            self,
            message: str,
            config_id: UUID,
            db: Session,
            workspace_id: Optional[UUID] = None
    ) -> Dict:
        """
        Determine the type of user message (read or write)

        Updated to eliminate global variables in favor of explicit parameters.

        Args:
            message: User message to classify
            config_id: Configuration ID to load LLM model from database
            db: Database session
            workspace_id: Workspace ID for fallback lookup (optional)

        Returns:
            Type classification result
        """
        logger.info("Classifying message type")

        # Load configuration to get LLM model ID
        config_service = MemoryConfigService(db)
        memory_config = config_service.load_memory_config(
            config_id=config_id,
            workspace_id=workspace_id,
            service_name="MemoryAgentService"
        )

        status = await status_typle(message, memory_config.llm_model_id)
        logger.debug(f"Message type: {status}")
        return status

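    # Routing sketch (hypothetical; the exact result shape is defined by
    # status_typle in type_classifier):
    #
    #     status = await service.classify_message_type(message, config_id, db)
    #     # e.g. dispatch to write_memory or read_memory based on the result
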
    async def generate_summary_from_retrieve(
            self,
            end_user_id: str,
            retrieve_info: str,
            history: List[Dict],
            query: str,
            config_id: str,
            db: Session
    ) -> str:
        """
        Generate a final answer from retrieved information, conversation history and the query.

        Calls the LLM via summary_llm using the direct_summary_prompt.jinja2 template.

        Args:
            end_user_id: End-user ID
            retrieve_info: Retrieved information
            history: Conversation history
            query: User query
            config_id: Configuration ID
            db: Database session

        Returns:
            The generated answer text
        """
        # Always get workspace_id from end_user for fallback, even if config_id is provided
        workspace_id = None
        try:
            connected_config = get_end_user_connected_config(end_user_id, db)
            workspace_id = connected_config.get('workspace_id')
            if config_id is None:
                config_id = connected_config.get('memory_config_id')
                logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
            if config_id is None and workspace_id is None:
                raise ValueError(
                    f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
        except Exception as e:
            if "No memory configuration found" in str(e):
                raise  # Re-raise our specific error
            logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
            if config_id is None:
                raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
            # If config_id was provided, continue without workspace_id fallback

        logger.info(f"Generating summary from retrieve info for query: {query[:50]}...")

        try:
            # Load configuration
            config_service = MemoryConfigService(db)
            memory_config = config_service.load_memory_config(
                config_id=config_id,
                workspace_id=workspace_id,
                service_name="MemoryAgentService"
            )

            # Import the required modules
            from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
                summary_llm,
            )
            from app.core.memory.agent.models.summary_models import (
                RetrieveSummaryResponse,
            )

            # Build the state object
            state = {
                "data": query,
                "memory_config": memory_config
            }

            # Call summary_llm directly
            answer = await summary_llm(
                state=state,
                history=history,
                retrieve_info=retrieve_info,
                template_name='direct_summary_prompt.jinja2',
                operation_name='retrieve_summary',
                response_model=RetrieveSummaryResponse,
                search_mode="1"
            )

            logger.info(f"Successfully generated summary: {answer[:100] if answer else 'None'}...")
            return answer if answer else "信息不足,无法回答。"

        except Exception as e:
            logger.error(f"生成摘要失败: {str(e)}", exc_info=True)
            return "信息不足,无法回答。"

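    # Usage sketch (hypothetical; any failure falls back to the
    # insufficient-information sentinel string):
    #
    #     answer = await service.generate_summary_from_retrieve(
    #         end_user_id="eu_123",
    #         retrieve_info="user lives in Berlin; has a cat",
    #         history=[], query="Where do I live?",
    #         config_id=None, db=db)
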
    async def get_knowledge_type_stats(
            self,
            db: Session,
            end_user_id: Optional[str] = None,
            only_active: bool = True,
            current_workspace_id: Optional[uuid.UUID] = None
    ) -> Dict[str, Any]:
        """
        Aggregate the knowledge-base type distribution, covering:
        1. Knowledge-base types in PostgreSQL: General, Web, Third-party, Folder (filtered by workspace_id)
        2. total: the sum across all types

        Args:
            end_user_id: End-user group ID (optional; kept for interface compatibility)
            only_active: Whether to count only active records
            current_workspace_id: Current workspace ID (optional; all counts are 0 when omitted)
            db: Database session

        Return format:
        {
            "General": count,
            "Web": count,
            "Third-party": count,
            "Folder": count,
            "total": sum_of_all
        }
        """
        result = {}

        # 1. Count knowledge-base types stored in PostgreSQL
        try:
            # Initialize every standard type to 0
            for kb_type in KnowledgeType:
                result[kb_type.value] = 0

            # Filter by workspace_id when one is provided
            if current_workspace_id:
                # Build the query
                query = db.query(
                    Knowledge.type,
                    func.count(Knowledge.id).label('count')
                ).filter(Knowledge.workspace_id == current_workspace_id)

                # Only filter on status if the Knowledge model has that column
                if only_active and hasattr(Knowledge, 'status'):
                    query = query.filter(Knowledge.status == 1)

                # Group by type
                type_counts = query.group_by(Knowledge.type).all()

                # Fill in counts for standard types only; ignore anything else
                valid_types = {kb_type.value for kb_type in KnowledgeType}
                for type_name, count in type_counts:
                    if type_name in valid_types:
                        result[type_name] = count

                logger.info(f"知识库类型统计成功 (workspace_id={current_workspace_id}): {result}")
            else:
                # No workspace_id provided: all knowledge-base types stay at 0
                logger.info("未提供 workspace_id,知识库类型统计全部为 0")

        except Exception as e:
            logger.error(f"知识库类型统计失败: {e}")
            raise Exception(f"知识库类型统计失败: {e}")

        # 2. The Neo4j memory total has been removed
        # The memory field is no longer returned

        # 3. Sum the knowledge-base types (excluding memory)
        result["total"] = (
            result.get("General", 0) +
            result.get("Web", 0) +
            result.get("Third-party", 0) +
            result.get("Folder", 0)
        )

        return result

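    # Example return value (illustrative counts):
    #
    #     {"General": 12, "Web": 3, "Third-party": 0, "Folder": 5, "total": 20}
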
    async def get_interest_distribution_by_user(
            self,
            end_user_id: Optional[str] = None,
            limit: int = 5,
            language: str = "zh"
    ) -> List[Dict[str, Any]]:
        """
        Get the interest distribution tags for a given user.

        Unlike the hot-tag statistics, this focuses on identifying the user's
        interest activities (sports, hobbies, studying, ...) and filters out
        plain objects, tools and places that do not represent activities the
        user actively takes part in.

        Args:
            end_user_id: End-user ID (optional; when None, aggregate across all users)
            limit: Maximum number of tags to return
            language: Output language ("zh" Chinese, "en" English)

        Return format:
        [
            {"name": "interest activity name", "frequency": count},
            ...
        ]
        """
        try:
            tags = await get_interest_distribution(end_user_id, limit=limit, by_user=False, language=language)
            return [{"name": tag, "frequency": freq} for tag, freq in tags]
        except Exception as e:
            logger.error(f"兴趣分布标签查询失败: {e}")
            raise Exception(f"兴趣分布标签查询失败: {e}")

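    # Example return value (illustrative):
    #
    #     [{"name": "徒步", "frequency": 14}, {"name": "摄影", "frequency": 9}]
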
    async def get_user_profile(
            self,
            end_user_id: Optional[str] = None,
            current_user_id: Optional[str] = None,
            llm_id: Optional[str] = None,
            db: Session = None
    ) -> Dict[str, Any]:
        """
        Get user profile details, consisting of:
        1. The user's name (taken directly from end_user_name)
        2. User tags (3 tags summarized from statements by the LLM)
        3. Hot memory tags (top 4 from hot_memory_tags)

        Args:
            end_user_id: End-user ID (optional)
            current_user_id: ID of the currently logged-in user (reserved)
            llm_id: LLM model ID (used for tag generation; optional, tag generation is skipped without it)
            db: Database session (optional)

        Return format:
        {
            "name": "user name",
            "tags": ["产品设计师", "旅行爱好者", "摄影发烧友"],
            "hot_tags": [
                {"name": "tag1", "frequency": 10},
                {"name": "tag2", "frequency": 8},
                ...
            ]
        }
        """
        result = {}

        # 1. Resolve end_user_name from end_user_id
        try:
            if end_user_id and db:
                from app.repositories import end_user_repository
                from app.schemas.end_user_schema import EndUser as EndUserSchema

                end_user_orm = end_user_repository.get_end_user_by_id(db, end_user_id)
                if end_user_orm:
                    end_user = EndUserSchema.model_validate(end_user_orm)
                    end_user_name = end_user.other_name
                else:
                    end_user_name = "默认用户"
            else:
                end_user_name = "默认用户"
        except Exception as e:
            logger.error(f"Failed to get end_user_name: {e}")
            end_user_name = "默认用户"

        result["name"] = end_user_name
        logger.debug(f"The end_user is: {end_user_name}")

        # 2. Use the LLM to extract tags from statements and entities
        try:
            connector = Neo4jConnector()

            # Query this user's statements
            query = (
                "MATCH (s:Statement) "
                "WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id) AND s.statement IS NOT NULL "
                "RETURN s.statement AS statement "
                "ORDER BY s.created_at DESC LIMIT 100"
            )
            rows = await connector.execute_query(query, end_user_id=end_user_id)
            statements = [r.get("statement", "") for r in rows if r.get("statement")]

            # Query this user's most frequent entities
            entity_query = (
                "MATCH (e:ExtractedEntity) "
                "WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) AND e.entity_type <> '人物' AND e.name IS NOT NULL "
                "RETURN e.name AS name, count(e) AS frequency "
                "ORDER BY frequency DESC LIMIT 20"
            )
            entity_rows = await connector.execute_query(entity_query, end_user_id=end_user_id)
            entities = [f"{r['name']} ({r['frequency']})" for r in entity_rows]

            await connector.close()

            if not statements or not llm_id:
                result["tags"] = []
                if not llm_id and statements:
                    logger.warning("llm_id not provided, skipping tag generation")
            else:
                # Build the summary text
                summary_text = f"用户语句样本:{' | '.join(statements[:20])}\n核心实体:{', '.join(entities)}"
                logger.debug(f"User data found: {len(statements)} statements, {len(entities)} entities")

                # Extract tags with the LLM (fresh session so we don't reuse the
                # possibly-None db parameter)
                with get_db_context() as llm_db:
                    factory = MemoryClientFactory(llm_db)
                    llm_client = factory.get_llm_client(llm_id)

                    # Structured output schema for tag extraction
                    class UserTags(BaseModel):
                        tags: list[str] = Field(...,
                                                description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友")

                    messages = [
                        {
                            "role": "system",
                            "content": "你是一个信息提取助手。从用户的语句和实体中提取3个最能代表用户特征的标签。标签应该简洁(2-6个字),描述用户的职业、兴趣或特点。"
                        },
                        {
                            "role": "user",
                            "content": f"请从以下用户信息中提取3个标签:\n\n{summary_text}"
                        }
                    ]

                    user_tags = await llm_client.response_structured(
                        messages=messages,
                        response_model=UserTags
                    )

                    result["tags"] = user_tags.tags
                    logger.debug(f"Extracted tags: {user_tags.tags}")

        except Exception as e:
            # Fall back to an empty tag list if extraction fails
            logger.error(f"Failed to extract user tags: {e}")
            result["tags"] = []

        try:
            # 3. Get the top 4 hot memory tags
            connector = Neo4jConnector()
            names_to_exclude = ['AI', 'Caroline', 'Melanie', 'Jon', 'Gina', '用户', 'AI助手', 'John', 'Maria']
            hot_tag_query = (
                "MATCH (e:ExtractedEntity) "
                "WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) AND e.entity_type <> '人物' "
                "AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
                "RETURN e.name AS name, count(e) AS frequency "
                "ORDER BY frequency DESC LIMIT 4"
            )
            hot_tag_rows = await connector.execute_query(
                hot_tag_query,
                end_user_id=end_user_id,
                names_to_exclude=names_to_exclude
            )
            await connector.close()

            result["hot_tags"] = [{"name": r["name"], "frequency": r["frequency"]} for r in hot_tag_rows]
            logger.debug(f"Hot tags found: {len(result['hot_tags'])} tags")
        except Exception as e:
            logger.error(f"Failed to get hot tags: {e}")
            result["hot_tags"] = []

        return result


# TODO: move to memory_config_service.py
def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]:
    """
    Get the memory configuration connected to an end user.

    Backward compatibility: if end_user.memory_config_id is empty, the ID is
    taken from AppRelease.config and written back to end_user.memory_config_id
    (lazy migration).

    Args:
        end_user_id: End-user ID
        db: Database session

    Returns:
        Dict containing memory_config_id, workspace_id and related information

    Raises:
        ValueError: If the end user does not exist or the app is not released
    """
    import json as json_module

    from sqlalchemy import select

    from app.models.app_model import App
    from app.models.app_release_model import AppRelease
    from app.models.end_user_model import EndUser
    from app.services.memory_config_service import MemoryConfigService

    logger.info(f"Getting connected config for end_user: {end_user_id}")

    # TODO: check sources for end_user_id; should be one of: chat, draft, apikey
    # 1. Fetch the end_user and its app_id
    end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
    if not end_user:
        logger.warning(f"End user not found: {end_user_id}")
        raise ValueError(f"终端用户不存在: {end_user_id}")

    app_id = end_user.app_id
    logger.debug(f"Found end_user app_id: {app_id}")

    # 2. Fetch the app to determine the workspace_id
    app = db.query(App).filter(App.id == app_id).first()
    if not app:
        logger.warning(f"App not found: {app_id}")
        # raise ValueError(f"应用不存在: {app_id}")
        # TODO: temp fix for draft run
        # if not app.current_release_id:
        #     logger.warning(f"No current release for app: {app_id}")
        #     raise ValueError(f"应用未发布: {app_id}")

    # 3. Backward compatibility: if memory_config_id is empty, read it from
    #    AppRelease.config and backfill it
    memory_config_id_to_use = end_user.memory_config_id

    # If a memory_config_id already exists, use it directly.
    # For a freshly created end user, end_user.memory_config_id is always None;
    # taking the memory_config_id from the release is then the expected
    # behavior, and the value is backfilled into end_user.memory_config_id.
    # The migration branch needs the app row for its type, so it is skipped
    # when the app lookup above failed.
    if not memory_config_id_to_use and app is not None:
        logger.info("end_user.memory_config_id is None, migrating from AppRelease.config")

        # Fetch the latest released version
        stmt = (
            select(AppRelease)
            .where(AppRelease.app_id == app_id, AppRelease.is_active.is_(True))
            .order_by(AppRelease.version.desc())
        )
        # TODO: change to current_release_id
        latest_release = db.scalars(stmt).first()

        if latest_release:
            config = latest_release.config or {}

            # If config is a string, parse it into a dict
            if isinstance(config, str):
                try:
                    config = json_module.loads(config)
                except json_module.JSONDecodeError:
                    logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
                    config = {}

            # Use MemoryConfigService's extraction helper
            memory_config_service = MemoryConfigService(db)
            legacy_config_id, is_legacy_int = memory_config_service.extract_memory_config_id(
                app_type=app.type,
                config=config
            )

            if legacy_config_id:
                # Verify that the extracted config_id exists in the database
                from app.models.memory_config_model import (
                    MemoryConfig as MemoryConfigModel,
                )
                existing_config = db.get(MemoryConfigModel, legacy_config_id)

                if existing_config:
                    memory_config_id_to_use = legacy_config_id

                    # Backfill into the end_user table (lazy update)
                    end_user.memory_config_id = memory_config_id_to_use
                    db.commit()
                    logger.info(
                        f"Migrated memory_config_id for end_user {end_user_id}: {memory_config_id_to_use}"
                    )
                else:
                    logger.warning(
                        f"Extracted memory_config_id does not exist, skipping backfill: "
                        f"end_user_id={end_user_id}, config_id={legacy_config_id}"
                    )
            elif is_legacy_int:
                logger.info(
                    f"Legacy int config detected for end_user {end_user_id}, will use workspace default"
                )

    # 4. Resolve the memory configuration via get_config_with_fallback
    memory_config_service = MemoryConfigService(db)
    memory_config = memory_config_service.get_config_with_fallback(
        memory_config_id=memory_config_id_to_use,
        workspace_id=end_user.workspace_id
    )

    memory_config_id = str(memory_config.config_id) if memory_config else None

    result = {
        "end_user_id": str(end_user_id),
        "memory_config_id": memory_config_id,
        "workspace_id": str(end_user.workspace_id)
    }

    logger.info(
        f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={end_user.workspace_id}")
    return result


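# Example of the returned mapping (illustrative values):
#
#     get_end_user_connected_config("eu_123", db)
#     # -> {"end_user_id": "eu_123",
#     #     "memory_config_id": "8a6f...",
#     #     "workspace_id": "3b9c..."}
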
def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) -> Dict[str, Dict[str, Any]]:
    """
    Batch-fetch the memory configurations connected to multiple end users
    (an optimized variant that reduces the number of database queries).

    Uses the same logic as get_end_user_connected_config:
    1. Prefer end_user.memory_config_id
    2. Otherwise, try to extract it from AppRelease.config and backfill it
    3. Otherwise, fall back to the workspace default configuration

    Args:
        end_user_ids: List of end-user IDs
        db: Database session

    Returns:
        Dict keyed by end_user_id whose values contain memory_config_id and memory_config_name.
        Format: {
            "user_id_1": {"memory_config_id": "xxx", "memory_config_name": "xxx"},
            "user_id_2": {"memory_config_id": None, "memory_config_name": None},
            ...
        }
    """
    import json as json_module

    from sqlalchemy import select

    from app.models.app_model import App
    from app.models.app_release_model import AppRelease
    from app.models.end_user_model import EndUser
    from app.models.memory_config_model import MemoryConfig
    from app.services.memory_config_service import MemoryConfigService

    logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users")

    result = {}

    if not end_user_ids:
        return result

    # 1. Batch-query all end_users along with their app_id and memory_config_id
    end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()

    # Build maps - keep the EndUser object references for backfilling
    end_user_map = {str(eu.id): eu for eu in end_users}
    user_data = {str(eu.id): {"app_id": eu.app_id, "memory_config_id": eu.memory_config_id} for eu in end_users}

    # Record users that were not found
    found_user_ids = set(user_data.keys())
    missing_user_ids = set(end_user_ids) - found_user_ids
    if missing_user_ids:
        logger.warning(f"End users not found: {missing_user_ids}")
        for user_id in missing_user_ids:
            result[user_id] = {"memory_config_id": None, "memory_config_name": None}

    # 2. Batch-fetch all related apps for their workspace_id and type
    app_ids = list(set(data["app_id"] for data in user_data.values()))
    if not app_ids:
        return result

    apps = db.query(App).filter(App.id.in_(app_ids)).all()
    app_map = {app.id: app for app in apps}
    app_to_workspace = {app.id: app.workspace_id for app in apps}

    # 3. For users without a memory_config_id, try to extract one from AppRelease.config
    users_needing_migration = [
        (end_user_id, data["app_id"])
        for end_user_id, data in user_data.items()
        if not data["memory_config_id"]
    ]

    if users_needing_migration:
        # Batch-fetch the latest releases of the relevant apps
        migration_app_ids = list(set(app_id for _, app_id in users_needing_migration))

        # Query the latest active release of each app
        app_latest_releases = {}
        for app_id in migration_app_ids:
            stmt = (
                select(AppRelease)
                .where(AppRelease.app_id == app_id, AppRelease.is_active.is_(True))
                .order_by(AppRelease.version.desc())
                .limit(1)
            )
            latest_release = db.scalars(stmt).first()
            if latest_release:
                app_latest_releases[app_id] = latest_release

        # Extract a memory_config_id for each user that needs migration
        config_service = MemoryConfigService(db)
        users_to_backfill = []  # [(end_user, memory_config_id), ...]

        for end_user_id, app_id in users_needing_migration:
            latest_release = app_latest_releases.get(app_id)
            if not latest_release:
                continue

            config = latest_release.config or {}

            # If config is a string, parse it into a dict
            if isinstance(config, str):
                try:
                    config = json_module.loads(config)
                except json_module.JSONDecodeError:
                    logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
                    continue

            # Use MemoryConfigService's extraction helper
            app = app_map.get(app_id)
            if not app:
                continue

            legacy_config_id, is_legacy_int = config_service.extract_memory_config_id(
                app_type=app.type,
                config=config
            )

            if legacy_config_id:
                # Update the memory_config_id in user_data
                user_data[end_user_id]["memory_config_id"] = legacy_config_id

                # Remember users to backfill (done after validating that the config exists)
                end_user = end_user_map.get(end_user_id)
                if end_user:
                    users_to_backfill.append((end_user, legacy_config_id))
            elif is_legacy_int:
                logger.info(
                    f"Legacy int config detected for end_user {end_user_id}, will use workspace default"
                )

        # Verify that the extracted config_ids exist in the database
        if users_to_backfill:
            config_ids_to_validate = list(set(cid for _, cid in users_to_backfill))
            existing_configs = db.query(MemoryConfig).filter(
                MemoryConfig.config_id.in_(config_ids_to_validate)
            ).all()
            valid_config_ids = {mc.config_id for mc in existing_configs}

            # Only backfill configs that actually exist
            valid_backfills = [
                (eu, cid) for eu, cid in users_to_backfill
                if cid in valid_config_ids
            ]
            invalid_backfills = [
                (eu, cid) for eu, cid in users_to_backfill
                if cid not in valid_config_ids
            ]

            if invalid_backfills:
                invalid_ids = [str(cid) for _, cid in invalid_backfills]
                logger.warning(
                    f"Skipping backfill for non-existent memory_config_ids: {invalid_ids}"
                )
                # Clear the invalid config_ids from user_data
                for eu, cid in invalid_backfills:
                    user_data[str(eu.id)]["memory_config_id"] = None

            # Batch-backfill end_user.memory_config_id
            if valid_backfills:
                for end_user, memory_config_id in valid_backfills:
                    end_user.memory_config_id = memory_config_id
                db.commit()
                logger.info(f"Migrated memory_config_id for {len(valid_backfills)} end_users")

    # 4. Collect the memory_config_ids to query directly and the workspace_ids to fall back to
    direct_config_ids = []
    workspace_fallback_users = []  # [(end_user_id, workspace_id), ...]

    for end_user_id, data in user_data.items():
        if data["memory_config_id"]:
            direct_config_ids.append(data["memory_config_id"])
        else:
            workspace_id = app_to_workspace.get(data["app_id"])
            if workspace_id:
                workspace_fallback_users.append((end_user_id, workspace_id))

    # 5. Batch-query the directly assigned configurations
    config_id_to_config = {}
    if direct_config_ids:
        configs = db.query(MemoryConfig).filter(MemoryConfig.config_id.in_(direct_config_ids)).all()
        config_id_to_config = {mc.config_id: mc for mc in configs}

    # 6. Fetch workspace default configurations (queried one by one, since
    #    get_workspace_default_config has its own non-trivial logic)
    workspace_default_configs = {}
    unique_workspace_ids = list(set(ws_id for _, ws_id in workspace_fallback_users))

    if unique_workspace_ids:
        config_service = MemoryConfigService(db)
        for workspace_id in unique_workspace_ids:
            default_config = config_service.get_workspace_default_config(workspace_id)
            if default_config:
                workspace_default_configs[workspace_id] = default_config

    # 7. Build the final result
    for end_user_id, data in user_data.items():
        memory_config = None

        # Prefer the configuration directly assigned to the end_user
        if data["memory_config_id"]:
            memory_config = config_id_to_config.get(data["memory_config_id"])

        # Fall back to the workspace default configuration
        if not memory_config:
            workspace_id = app_to_workspace.get(data["app_id"])
            if workspace_id:
                memory_config = workspace_default_configs.get(workspace_id)

        if memory_config:
            result[end_user_id] = {
                "memory_config_id": str(memory_config.config_id),
                "memory_config_name": memory_config.config_name
            }
        else:
            result[end_user_id] = {"memory_config_id": None, "memory_config_name": None}

    logger.info(f"Successfully retrieved {len(result)} connected configs")
    return result
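

if __name__ == "__main__":
    # Smoke-test sketch (hypothetical; requires a configured database and a
    # real end-user id). The batch variant degrades gracefully: unknown ids
    # map to {"memory_config_id": None, "memory_config_name": None}.
    with get_db_context() as session:
        print(get_end_users_connected_configs_batch(["eu_example"], session))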