# Source: MemoryBear/api/app/services/memory_agent_service.py
# (repository-browser header and Unicode-warning banner removed from paste)
"""
Memory Agent Service
Handles business logic for memory agent operations including read/write services,
health checks, and message type classification.
TODO: Refactor get_end_user_connected_config
----------------------------------------------
1. Move get_end_user_connected_config to memory_config_service.py
2. Change return type from Dict[str, Any] (with config_id string) to full MemoryConfig model
3. This will eliminate the need for callers to call load_memory_config separately
4. Update all callers to use the new unified function
"""
import json
import os
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
from uuid import UUID
import redis
from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.cache import InterestMemoryCache
from app.core.config import settings
from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
from app.core.memory.agent.utils.messages_tools import (
merge_multiple_search_results,
reorder_output_results,
)
from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write as write_neo4j
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.log.audit_logger import audit_logger
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas import FileInput
from app.schemas.memory_agent_schema import Write_UserInput
from app.schemas.memory_config_schema import ConfigurationError
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import (
write_rag,
)
from app.services.memory_perceptual_service import MemoryPerceptualService
logger = get_logger(__name__)
config_logger = get_config_logger()
# Initialize Neo4j connector for analytics functions
_neo4j_connector = Neo4jConnector()
class MemoryAgentService:
"""Service for memory agent operations"""
def writer_messages_deal(self, messages, start_time, end_user_id, config_id, message, context):
    """
    Finalize a write operation: audit-log the outcome, then return or raise.

    Args:
        messages: Write result marker; the string 'success' indicates success.
        start_time: time.time() captured when the write began (for duration).
        end_user_id: Group/user identifier the write belongs to.
        config_id: Memory configuration ID used for the write.
        message: The written message text (only its length is audited).
        context: Payload handed back unchanged on success.

    Returns:
        context, when the write succeeded.

    Raises:
        ValueError: When the write result is anything other than 'success'.
    """
    elapsed = time.time() - start_time
    if str(messages) != 'success':
        # Failure path: audit with a truncated error preview, then raise.
        logger.warning(f"Write operation failed for group {end_user_id}")
        audit_logger.log_operation(
            operation="WRITE",
            config_id=config_id,
            end_user_id=end_user_id,
            success=False,
            duration=elapsed,
            error=f"写入失败: {messages[:100]}",
        )
        raise ValueError(f"写入失败: {messages}")
    # Success path: record the audit entry and hand the context back.
    logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
    audit_logger.log_operation(
        operation="WRITE",
        config_id=config_id,
        end_user_id=end_user_id,
        success=True,
        duration=elapsed,
        details={"message_length": len(message)},
    )
    return context
def extract_tool_call_info(self, event: Dict) -> bool:
    """
    Inspect the last message of a graph event and log any tool-call info.

    Handles two message shapes:
      * an AI message carrying ``tool_calls`` (a list of dicts or objects),
      * a tool-response message carrying ``tool_call_id`` (and optionally
        ``name``/``content``).

    Args:
        event: Event dict whose "messages" list ends with the message to inspect.

    Returns:
        True if the last message contained tool-call information, else False.
    """
    last_message = event["messages"][-1]
    # Case 1: AI message that requested one or more tool calls.
    if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
        tool_calls = last_message.tool_calls
        for i, tool_call in enumerate(tool_calls):
            # Tool calls may be plain dicts or objects with attributes.
            if isinstance(tool_call, dict):
                tool_call_id = tool_call.get('id')
                tool_name = tool_call.get('name')
                tool_args = tool_call.get('args', {})
            else:
                tool_call_id = getattr(tool_call, 'id', None)
                tool_name = getattr(tool_call, 'name', None)
                tool_args = getattr(tool_call, 'args', {})
            logger.debug(f"Tool Call {i + 1}: ID={tool_call_id}, Name={tool_name}, Args={tool_args}")
        return True
    # Case 2: tool-response message, identified by tool_call_id.
    elif hasattr(last_message, 'tool_call_id'):
        tool_call_id = getattr(last_message, 'tool_call_id', None)
        if hasattr(last_message, 'name') and hasattr(last_message, 'content'):
            tool_name = getattr(last_message, 'name', None)
            try:
                content = json.loads(getattr(last_message, 'content', '{}'))
                tool_args = content.get('args', {})
                logger.debug(f"Tool Call 1: ID={tool_call_id}, Name={tool_name}, Args={tool_args}")
            # Fix: was a bare `except:`; narrow to what this block can raise —
            # json.JSONDecodeError (a ValueError) for bad JSON, TypeError for
            # non-str content, AttributeError when the payload is not a dict.
            except (ValueError, TypeError, AttributeError):
                logger.debug(f"Tool Response ID: {tool_call_id}")
        else:
            logger.debug(f"Tool Response ID: {tool_call_id}")
        return True
    return False
async def get_health_status(self) -> Dict:
    """
    Get latest health status from the Redis cache.

    Reads the health payload written by a Celery periodic task, then augments
    it with the database connection-pool status. Pool usage above 80% demotes
    the overall status to "warning".

    Returns:
        Dict with "status" and "database_pool" keys.
    """
    logger.info("Checking health status")
    client = redis.Redis(
        host=settings.REDIS_HOST,
        port=settings.REDIS_PORT,
        db=settings.REDIS_DB,
        password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None
    )
    try:
        payload = client.hgetall("memsci:health:read_service") or {}
    finally:
        # Fix: release the Redis connection — the original never closed it.
        client.close()
    if payload:
        # hgetall returns bytes keys/values; decode before looking up "status".
        decoded = {k.decode("utf-8"): v.decode("utf-8") for k, v in payload.items()}
        status = decoded.get("status", "unknown")
    else:
        status = "unknown"
    # Add database connection pool status
    try:
        from app.db import get_pool_status
        pool_status = get_pool_status()
        logger.info(f"Database pool status: {pool_status}")
        # Check if pool usage is too high
        if pool_status.get("usage_percent", 0) > 80:
            logger.warning(f"High database pool usage: {pool_status['usage_percent']}%")
            status = "warning"
    except Exception as e:
        logger.error(f"Failed to get pool status: {e}")
        pool_status = {"error": str(e)}
    logger.info(f"Health status: {status}")
    return {
        "status": status,
        "database_pool": pool_status
    }
def get_log_content(self) -> str:
    """
    Read and return the agent service log file content.

    Returns cleaned log content using the same cleaning logic as
    transmission (streaming) mode.

    Returns:
        The cleaned log content as a single string.

    Raises:
        ValueError: If the cleaned content is shorter than 10 characters.
        OSError: If the log file cannot be opened.
    """
    logger.info("Reading log file")
    # Get log file path - use project root directory
    from pathlib import Path
    project_root = str(Path(__file__).resolve().parents[2])  # api directory
    log_path = os.path.join(project_root, "logs", "agent_service.log")
    # Fix: accumulate cleaned lines and join once — the original built the
    # string with `+=` in a loop, which is quadratic on large log files.
    cleaned_lines = []
    with open(log_path, "r", encoding="utf-8") as infile:
        for line in infile:
            # Use the same cleaning logic as LogStreamer for consistency
            cleaned_lines.append(LogStreamer.clean_log_line(line))
    content = "".join(cleaned_lines)
    if len(content) < 10:
        raise ValueError("NO LOGS")
    logger.info(f"Log content retrieved, size: {len(content)} bytes")
    return content
async def stream_log_content(self) -> AsyncGenerator[str, None]:
    """
    Stream log content in real-time using Server-Sent Events (SSE).

    Establishes a streaming connection and transmits log entries as they are
    written to the log file. It uses the LogStreamer to watch the file and
    yields SSE-formatted messages.

    NOTE(review): a second definition of this method appears later in this
    class and will shadow this one at class-creation time — confirm which
    copy is intended and remove the other.

    Yields:
        SSE-formatted strings with the following event types:
        - log: Contains log content and timestamp
        - keepalive: Periodic keepalive messages to maintain connection
        - error: Error information if streaming fails
        - done: Indicates streaming has completed

    Raises:
        FileNotFoundError: If log file doesn't exist at stream start
        Exception: For other unexpected errors during streaming
    """
    logger.info("Starting log content streaming")
    # Get log file path - use project root directory
    from pathlib import Path
    project_root = str(Path(__file__).resolve().parents[2])  # api directory
    log_path = os.path.join(project_root, "logs", "agent_service.log")
    # Check if file exists before starting stream
    if not os.path.exists(log_path):
        logger.error(f"Log file not found: {log_path}")
        # Send error event in SSE format instead of raising, so the client
        # receives a structured error over the stream.
        yield f"event: error\ndata: {json.dumps({'code': 4006, 'message': '日志文件不存在', 'error': f'File not found: {log_path}'})}\n\n"
        return
    streamer = None
    try:
        # Initialize LogStreamer with keepalive interval from settings (default 300 seconds)
        keepalive_interval = getattr(settings, 'LOG_STREAM_KEEPALIVE_INTERVAL', 300)
        streamer = LogStreamer(log_path, keepalive_interval=keepalive_interval)
        logger.info(f"LogStreamer initialized for {log_path}")
        # Stream log content using read_existing_and_stream to get all existing content first
        async for message in streamer.read_existing_and_stream():
            event_type = message.get("event")
            data = message.get("data")
            # Format as SSE message: "event: <type>\ndata: <json_data>\n\n"
            sse_message = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
            logger.debug(f"Streaming event: {event_type}")
            yield sse_message
            # If error or done event, stop streaming
            if event_type in ["error", "done"]:
                logger.info(f"Stream ended with event: {event_type}")
                break
    except FileNotFoundError as e:
        # File disappeared mid-stream (e.g. log rotation).
        logger.error(f"Log file not found during streaming: {e}")
        yield f"event: error\ndata: {json.dumps({'code': 4006, 'message': '日志文件在流式传输期间变得不可用', 'error': str(e)})}\n\n"
    except Exception as e:
        logger.error(f"Unexpected error during log streaming: {e}", exc_info=True)
        yield f"event: error\ndata: {json.dumps({'code': 8001, 'message': '流式传输期间发生错误', 'error': str(e)})}\n\n"
    finally:
        # Resource cleanup
        logger.info("Log streaming completed, cleaning up resources")
        # LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(
    self,
    end_user_id: str,
    messages: list[dict],
    config_id: Optional[uuid.UUID] | int,
    db: Session,
    storage_type: str,
    user_rag_memory_id: str,
    language: str = "zh"
) -> str:
    """
    Process a write operation with config_id.

    Resolves the memory configuration (explicit config_id, or the end user's
    connected config as fallback), runs perceptual-memory extraction for any
    attached files, then writes to RAG storage or to Neo4j (invalidating the
    interest-distribution cache afterwards).

    Args:
        end_user_id: Group identifier (also used as end_user_id)
        messages: Messages to write; each dict is mutated in place (a
            "file_content" key is added)
        config_id: Configuration ID from database (optional)
        db: SQLAlchemy database session
        storage_type: Storage type (neo4j or rag)
        user_rag_memory_id: User RAG memory ID
        language: Output language ("zh" Chinese, "en" English)

    Returns:
        Write operation result status ("success" for RAG storage; for Neo4j
        the context dict built below is returned via writer_messages_deal —
        NOTE(review): that conflicts with the declared -> str; confirm).

    Raises:
        ValueError: If config loading fails or write operation fails
    """
    # Resolve config_id and workspace_id
    # Always get workspace_id from end_user for fallback, even if config_id is provided
    workspace_id = None
    try:
        connected_config = get_end_user_connected_config(end_user_id, db)
        workspace_id = connected_config.get("workspace_id")
        if config_id is None:
            config_id = connected_config.get("memory_config_id")
            logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
        if config_id is None and workspace_id is None:
            raise ValueError(f"No memory configuration found for end_user {end_user_id}. "
                             f"Please ensure the user has a connected memory configuration.")
    except Exception as e:
        if "No memory configuration found" in str(e):
            raise  # Re-raise our specific error
        logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
        if config_id is None:
            raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
        # If config_id was provided, continue without workspace_id fallback
    import time
    start_time = time.time()
    # Load configuration from database with workspace fallback.
    # Use a separate database session to avoid transaction failures.
    try:
        from app.db import get_db_context
        with get_db_context() as config_db:
            config_service = MemoryConfigService(config_db)
            memory_config = config_service.load_memory_config(
                config_id=config_id,
                workspace_id=workspace_id,
                service_name="MemoryAgentService"
            )
            logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
    except ConfigurationError as e:
        error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
        logger.error(error_msg)
        # Record the failed operation in the audit log before raising.
        duration = time.time() - start_time
        audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
                                   success=False, duration=duration, error=error_msg)
        raise ValueError(error_msg)
    # Extract perceptual memories for every attached file; each message's
    # "file_content" collects (file_object, file_type) pairs.
    perceptual_serivce = MemoryPerceptualService(db)
    for message in messages:
        message["file_content"] = []
        for file in (message.get("files") or []):
            file_object = await perceptual_serivce.generate_perceptual_memory(
                end_user_id=end_user_id,
                memory_config=memory_config,
                file=FileInput(**file)
            )
            if file_object is None:
                continue
            message["file_content"].append((file_object, file["type"]))
    logger.info(messages)
    # Flatten the conversation into "role: content" lines for RAG / auditing.
    message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
    try:
        if storage_type == "rag":
            # For RAG storage, convert messages to single string
            await write_rag(end_user_id, message_text, user_rag_memory_id)
            return "success"
        else:
            await write_neo4j(
                end_user_id=end_user_id,
                messages=messages,
                memory_config=memory_config,
                ref_id='',
                language=language
            )
            # New memories change the interest distribution; drop both
            # language variants of the cached distribution.
            for lang in ["zh", "en"]:
                deleted = await InterestMemoryCache.delete_interest_distribution(
                    end_user_id, lang
                )
                if deleted:
                    logger.info(
                        f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
            # Replace (file_object, type) tuples with just the file paths
            # before returning the messages to the caller.
            for message in messages:
                message["file_content"] = [
                    perceptual[0].file_path for perceptual in message["file_content"]
                ]
            return self.writer_messages_deal(
                "success",
                start_time,
                end_user_id,
                config_id,
                message_text,
                {
                    "status": "success",
                    "data": messages,
                    "config_id": memory_config.config_id,
                    "config_name": memory_config.config_name
                }
            )
    except Exception as e:
        # Ensure proper error handling and logging
        error_msg = f"Write operation failed: {str(e)}"
        logger.error(error_msg)
        duration = time.time() - start_time
        audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
                                   success=False, duration=duration, error=error_msg)
        raise ValueError(error_msg)
async def read_memory(
    self,
    end_user_id: str,
    message: str,
    history: List[Dict],  # FIXME: appended to and logged below, but never passed into the graph
    search_switch: str,
    config_id: Optional[uuid.UUID] | int,
    db: Session,
    storage_type: str,
    user_rag_memory_id: str) -> Dict:
    """
    Process read operation with config_id.

    search_switch values:
    - "0": Requires verification
    - "1": No verification, direct split
    - "2": Direct answer based on context (fast retrieval; skips
      short-term-memory persistence)

    Args:
        end_user_id: Group identifier (also used as end_user_id)
        message: User message
        history: Conversation history
        search_switch: Search mode switch
        config_id: Configuration ID from database
        db: SQLAlchemy database session
        storage_type: Storage type (neo4j or rag)
        user_rag_memory_id: User RAG memory ID

    Returns:
        Dict with 'answer' and 'intermediate_outputs' keys

    Raises:
        ValueError: If config loading fails or the graph execution fails
    """
    import time
    start_time = time.time()
    # Keep the original query text; `message` stays as-is but the short-term
    # memory upsert below persists this original value.
    ori_message = message
    # Resolve config_id and workspace_id
    # Always get workspace_id from end_user for fallback, even if config_id is provided
    workspace_id = None
    try:
        connected_config = get_end_user_connected_config(end_user_id, db)
        workspace_id = connected_config.get("workspace_id")
        if config_id is None:
            config_id = connected_config.get("memory_config_id")
            logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
        if config_id is None and workspace_id is None:
            raise ValueError(
                f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
    except Exception as e:
        if "No memory configuration found" in str(e):
            raise  # Re-raise our specific error
        logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
        if config_id is None:
            raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
        # If config_id was provided, continue without workspace_id fallback
    logger.info(f"Read operation for group {end_user_id} with config_id {config_id}")
    config_load_start = time.time()
    try:
        # Use a separate database session to avoid transaction failures
        from app.db import get_db_context
        with get_db_context() as config_db:
            config_service = MemoryConfigService(config_db)
            memory_config = config_service.load_memory_config(
                config_id=config_id,
                workspace_id=workspace_id,
                service_name="MemoryAgentService"
            )
        config_load_time = time.time() - config_load_start
        logger.info(f"[PERF] Configuration loaded in {config_load_time:.4f}s: {memory_config.config_name}")
    except ConfigurationError as e:
        error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
        logger.error(error_msg)
        # Log failed operation
        duration = time.time() - start_time
        audit_logger.log_operation(
            operation="READ",
            config_id=config_id,
            end_user_id=end_user_id,
            success=False,
            duration=duration,
            error=error_msg
        )
        raise ValueError(error_msg)
    # Step 2: Prepare history
    history.append({"role": "user", "content": message})
    logger.debug(f"Group ID:{end_user_id}, Message:{message}, History:{history}, Config ID:{config_id}")
    # Step 3: Initialize MCP client and execute read workflow
    graph_exec_start = time.time()
    try:
        async with make_read_graph() as graph:
            config = {"configurable": {"thread_id": end_user_id}}
            # Initial state — contains every field the read graph expects.
            initial_state = {
                "messages": [HumanMessage(content=message)],
                "search_switch": search_switch,
                "end_user_id": end_user_id,
                "storage_type": storage_type,
                "user_rag_memory_id": user_rag_memory_id,
                "memory_config": memory_config}
            # Collect per-node intermediate outputs as the graph streams updates.
            _intermediate_outputs = []
            summary = ''
            async for update_event in graph.astream(
                initial_state,
                stream_mode="updates",
                config=config
            ):
                for node_name, node_data in update_event.items():
                    logger.info(f"处理节点: {node_name}")
                    # The final answer may come from several Summary-type
                    # nodes, each nesting summary_result under its own key.
                    if 'Summary' in node_name:
                        if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
                            summary = node_data['InputSummary']['summary_result']
                        elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']:
                            summary = node_data['RetrieveSummary']['summary_result']
                        elif 'summary' in node_data and 'summary_result' in node_data['summary']:
                            summary = node_data['summary']['summary_result']
                        elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']:
                            summary = node_data['SummaryFails']['summary_result']
                    # spit_data node
                    spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
                    if spit_data and spit_data != [] and spit_data != {}:
                        _intermediate_outputs.append(spit_data)
                    # Problem_Extension node
                    problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
                    if problem_extension and problem_extension != [] and problem_extension != {}:
                        _intermediate_outputs.append(problem_extension)
                    # Retrieve node — contributes a list, hence extend()
                    retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
                    if retrieve_node and retrieve_node != [] and retrieve_node != {}:
                        _intermediate_outputs.extend(retrieve_node)
                    # Perceptual_Retrieve node
                    perceptual_node = node_data.get('perceptual_data', {}).get('_intermediate', None)
                    if perceptual_node and perceptual_node != [] and perceptual_node != {}:
                        _intermediate_outputs.append(perceptual_node)
                    # Verify node
                    verify_n = node_data.get('verify', {}).get('_intermediate', None)
                    if verify_n and verify_n != [] and verify_n != {}:
                        _intermediate_outputs.append(verify_n)
                    # Summary node
                    summary_n = node_data.get('summary', {}).get('_intermediate', None)
                    if summary_n and summary_n != [] and summary_n != {}:
                        _intermediate_outputs.append(summary_n)
        graph_exec_time = time.time() - graph_exec_start
        logger.info(f"[PERF] Graph execution completed in {graph_exec_time:.4f}s")
        # Drop empty entries, then merge and reorder the search results.
        _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
        optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
        result = reorder_output_results(optimized_outputs)
        # Persist short-term memory to the database.
        # Only persisted when search_switch is not "2" (fast retrieval).
        try:
            from app.repositories.memory_short_repository import (
                ShortTermMemoryRepository,
            )
            retrieved_content = []
            repo = ShortTermMemoryRepository(db)
            if str(search_switch) != "2":
                for intermediate in _intermediate_outputs:
                    logger.debug(f"处理中间结果: {intermediate}")
                    intermediate_type = intermediate.get('type', '')
                    if intermediate_type == "search_result":
                        query = intermediate.get('query', '')
                        raw_results = intermediate.get('raw_results', {})
                        try:
                            reranked_results = raw_results.get('reranked_results', [])
                            statements = [statement['statement'] for statement in
                                          reranked_results.get('statements', [])]
                        except Exception:
                            # reranked_results may be a list (the [] default),
                            # whose .get raises — fall back to no statements.
                            statements = []
                        # De-duplicate (note: set() does not preserve order)
                        statements = list(set(statements))
                        if query and statements:
                            retrieved_content.append({query: statements})
            # If nothing was retrieved, persist an empty string instead.
            if not retrieved_content:
                retrieved_content = ''
            # Persist only when the answer is informative and this is not fast retrieval.
            if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2":
                # upsert: update the existing row or insert a new one
                repo.upsert(
                    end_user_id=end_user_id,
                    messages=ori_message,
                    aimessages=summary,
                    retrieved_content=retrieved_content,
                    search_switch=str(search_switch)
                )
                logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}")
            else:
                logger.debug(
                    f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
        except Exception as save_error:
            # Persistence failures must not break the main flow; log only.
            logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
        # Log successful operation
        total_time = time.time() - start_time
        logger.info(
            f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
        duration = time.time() - start_time
        audit_logger.log_operation(
            operation="READ",
            config_id=config_id,
            end_user_id=end_user_id,
            success=True,
            duration=duration
        )
        return {
            "answer": summary,
            "intermediate_outputs": result
        }
        # TODO: redis search -> answer
    except Exception as e:
        # Ensure proper error handling and logging
        error_msg = f"Read operation failed: {str(e)}"
        logger.error(error_msg)
        duration = time.time() - start_time
        audit_logger.log_operation(
            operation="READ",
            config_id=config_id,
            end_user_id=end_user_id,
            success=False,
            duration=duration,
            error=error_msg
        )
        raise ValueError(error_msg)
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
    """
    Validate and return the standardized message list from user input.

    Every message must be a dict with a non-empty 'content' field and a
    'role' of either 'user' or 'assistant'.

    Args:
        user_input: Write_UserInput object carrying the raw messages.

    Returns:
        list[dict]: The validated message list (role + content per message).

    Raises:
        ValueError: If the list is empty or any message is malformed.
    """
    from app.core.logging_config import get_api_logger
    logger = get_api_logger()

    def fail(log_text: str, error_text: str) -> None:
        # Log the validation failure, then surface it to the caller.
        logger.error(log_text)
        raise ValueError(error_text)

    msgs = user_input.messages
    if len(msgs) == 0:
        fail("Validation failed: Message list cannot be empty",
             "Message list cannot be empty")
    for idx, msg in enumerate(msgs):
        if not isinstance(msg, dict):
            fail(f"Validation failed: Message {idx} is not a dict: {type(msg)}",
                 f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
        if 'role' not in msg:
            fail(f"Validation failed: Message {idx} missing 'role' field: {msg}",
                 f"Message format error: Message must contain 'role' field. Error message index: {idx}")
        if 'content' not in msg:
            fail(f"Validation failed: Message {idx} missing 'content' field: {msg}",
                 f"Message format error: Message must contain 'content' field. Error message index: {idx}")
        if msg['role'] not in ('user', 'assistant'):
            fail(f"Validation failed: Message {idx} invalid role: {msg['role']}",
                 f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}")
        if not msg['content'] or not msg['content'].strip():
            fail(f"Validation failed: Message {idx} content is empty",
                 f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}")
    logger.info(f"Validation successful: Structured message list, count: {len(msgs)}")
    return msgs
async def classify_message_type(
    self,
    message: str,
    config_id: UUID,
    db: Session,
    workspace_id: Optional[UUID] = None
) -> Dict:
    """
    Determine whether a user message is a read or a write operation.

    Loads the memory configuration (to obtain the LLM model ID) and then
    delegates the classification itself to ``status_typle``.

    Args:
        message: User message to classify.
        config_id: Configuration ID used to load the LLM model from the database.
        db: Database session.
        workspace_id: Optional workspace ID for fallback configuration lookup.

    Returns:
        The type classification result.
    """
    logger.info("Classifying message type")
    # Resolve the LLM model through the stored memory configuration.
    memory_config = MemoryConfigService(db).load_memory_config(
        config_id=config_id,
        workspace_id=workspace_id,
        service_name="MemoryAgentService",
    )
    status = await status_typle(message, memory_config.llm_model_id)
    logger.debug(f"Message type: {status}")
    return status
async def generate_summary_from_retrieve(
    self,
    end_user_id: str,
    retrieve_info: str,
    history: List[Dict],
    query: str,
    config_id: str,
    db: Session
) -> str:
    """
    Generate a final answer from retrieved info, chat history and the query.

    Calls the summary LLM with the direct_summary_prompt.jinja2 template.

    Args:
        end_user_id: User/group identifier used to resolve the memory config.
        retrieve_info: Retrieved information text.
        history: Conversation history.
        query: The user's query.
        config_id: Configuration ID (resolved from the end user when None).
        db: Database session.

    Returns:
        The generated answer text; falls back to "信息不足,无法回答。" when
        the answer is empty or generation fails.

    Raises:
        ValueError: If no memory configuration can be resolved.
    """
    # Always get workspace_id from end_user for fallback, even if config_id is provided
    workspace_id = None
    try:
        connected_config = get_end_user_connected_config(end_user_id, db)
        workspace_id = connected_config.get('workspace_id')
        if config_id is None:
            config_id = connected_config.get('memory_config_id')
            logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
        if config_id is None and workspace_id is None:
            raise ValueError(
                f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
    except Exception as e:
        if "No memory configuration found" in str(e):
            raise  # Re-raise our specific error
        logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
        if config_id is None:
            raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
        # If config_id was provided, continue without workspace_id fallback
    logger.info(f"Generating summary from retrieve info for query: {query[:50]}...")
    try:
        # Load the memory configuration (with workspace fallback).
        config_service = MemoryConfigService(db)
        memory_config = config_service.load_memory_config(
            config_id=config_id,
            workspace_id=workspace_id,
            service_name="MemoryAgentService"
        )
        # Imported lazily — presumably to avoid a circular import; confirm.
        from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
            summary_llm,
        )
        from app.core.memory.agent.models.summary_models import (
            RetrieveSummaryResponse,
        )
        # Build the state object expected by summary_llm.
        state = {
            "data": query,
            "memory_config": memory_config
        }
        # Invoke the summary LLM directly (outside the graph).
        answer = await summary_llm(
            state=state,
            history=history,
            retrieve_info=retrieve_info,
            template_name='direct_summary_prompt.jinja2',
            operation_name='retrieve_summary',
            response_model=RetrieveSummaryResponse,
            search_mode="1"
        )
        logger.info(f"Successfully generated summary: {answer[:100] if answer else 'None'}...")
        return answer if answer else "信息不足,无法回答。"
    except Exception as e:
        # Any failure degrades to the "insufficient information" answer.
        logger.error(f"生成摘要失败: {str(e)}", exc_info=True)
        return "信息不足,无法回答。"
async def get_knowledge_type_stats(
    self,
    db: Session,
    end_user_id: Optional[str] = None,
    only_active: bool = True,
    current_workspace_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
    """
    Count knowledge-base entries per type for a workspace.

    Counts PostgreSQL knowledge records per type (General, Web, Third-party,
    Folder), filtered by workspace_id, and adds a "total" key summing those
    four types. (The former Neo4j memory count has been removed.)

    Args:
        db: Database session.
        end_user_id: Kept for interface compatibility; not used by the query.
        only_active: When True (and the Knowledge model has a status column),
            count only records with status == 1.
        current_workspace_id: Workspace to count; when omitted, every type
            count stays 0.

    Returns:
        {"General": n, "Web": n, "Third-party": n, "Folder": n, "total": sum}

    Raises:
        Exception: If the database query fails (original error chained).
    """
    # Initialize every standard type to 0 so the response shape is stable.
    result = {kb_type.value: 0 for kb_type in KnowledgeType}
    try:
        if current_workspace_id:
            # Group counts by type within the workspace.
            query = db.query(
                Knowledge.type,
                func.count(Knowledge.id).label('count')
            ).filter(Knowledge.workspace_id == current_workspace_id)
            # The status column may not exist on every deployment of the model.
            if only_active and hasattr(Knowledge, 'status'):
                query = query.filter(Knowledge.status == 1)
            type_counts = query.group_by(Knowledge.type).all()
            # Only fill in the standard types; ignore anything unexpected.
            valid_types = {kb_type.value for kb_type in KnowledgeType}
            for type_name, count in type_counts:
                if type_name in valid_types:
                    result[type_name] = count
            logger.info(f"知识库类型统计成功 (workspace_id={current_workspace_id}): {result}")
        else:
            # No workspace given: all type counts remain 0.
            logger.info("未提供 workspace_id知识库类型统计全部为 0")
    except Exception as e:
        logger.error(f"知识库类型统计失败: {e}")
        # Fix: chain the original exception instead of discarding it.
        raise Exception(f"知识库类型统计失败: {e}") from e
    # Total covers only the four knowledge types.
    result["total"] = sum(
        result.get(k, 0) for k in ("General", "Web", "Third-party", "Folder")
    )
    return result
async def get_interest_distribution_by_user(
    self,
    end_user_id: Optional[str] = None,
    limit: int = 5,
    language: str = "zh"
) -> List[Dict[str, Any]]:
    """
    Get interest-distribution tags for the given user.

    Unlike hot tags, this focuses on interest activities (sports, hobbies,
    study, ...) and filters out plain objects, tools and places that do not
    represent the user actively participating in an activity.

    Args:
        end_user_id: User ID.
        limit: Maximum number of tags to return.
        language: Output language ("zh" Chinese, "en" English).

    Returns:
        List of {"name": <activity>, "frequency": <count>} dicts.

    Raises:
        Exception: If the underlying analytics query fails (cause chained).
    """
    try:
        # NOTE(review): by_user=False looks inconsistent with a per-user
        # query — confirm against get_interest_distribution's contract.
        tags = await get_interest_distribution(end_user_id, limit=limit, by_user=False, language=language)
        return [{"name": tag, "frequency": freq} for tag, freq in tags]
    except Exception as e:
        logger.error(f"兴趣分布标签查询失败: {e}")
        # Fix: chain the original exception for easier debugging.
        raise Exception(f"兴趣分布标签查询失败: {e}") from e
async def get_user_profile(
    self,
    end_user_id: Optional[str] = None,
    current_user_id: Optional[str] = None,
    llm_id: Optional[str] = None,
    db: Session = None
) -> Dict[str, Any]:
    """
    Build a user profile containing:
      1. the user's display name (end_user.other_name, default "默认用户"),
      2. three LLM-summarized characteristic tags,
      3. the top-4 hot memory tags from Neo4j.

    Args:
        end_user_id: User ID (optional).
        current_user_id: ID of the currently logged-in user (kept for
            interface compatibility; not read in this method).
        llm_id: LLM model ID used for tag generation; when omitted, tag
            generation is skipped.
        db: Database session (optional).

    Returns:
        {"name": str, "tags": list[str],
         "hot_tags": [{"name": str, "frequency": int}, ...]}
    """
    result = {}
    # 1. Resolve the display name from the end_user record.
    try:
        if end_user_id and db:
            from app.repositories import end_user_repository
            from app.schemas.end_user_schema import EndUser as EndUserSchema
            end_user_orm = end_user_repository.get_end_user_by_id(db, end_user_id)
            if end_user_orm:
                end_user = EndUserSchema.model_validate(end_user_orm)
                end_user_name = end_user.other_name
            else:
                end_user_name = "默认用户"
        else:
            end_user_name = "默认用户"
    except Exception as e:
        logger.error(f"Failed to get end_user_name: {e}")
        end_user_name = "默认用户"
    result["name"] = end_user_name
    logger.debug(f"The end_user is: {end_user_name}")
    # 2. Extract characteristic tags from statements and entities via the LLM.
    try:
        connector = Neo4jConnector()
        # NOTE(review): if a query below raises, connector.close() is never
        # awaited — possible connection leak; confirm connector cleanup.
        # Fetch this user's most recent statements (up to 100).
        query = (
            "MATCH (s:Statement) "
            "WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id) AND s.statement IS NOT NULL "
            "RETURN s.statement AS statement "
            "ORDER BY s.created_at DESC LIMIT 100"
        )
        rows = await connector.execute_query(query, end_user_id=end_user_id)
        statements = [r.get("statement", "") for r in rows if r.get("statement")]
        # Fetch this user's most frequent non-person entities (top 20).
        entity_query = (
            "MATCH (e:ExtractedEntity) "
            "WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) AND e.entity_type <> '人物' AND e.name IS NOT NULL "
            "RETURN e.name AS name, count(e) AS frequency "
            "ORDER BY frequency DESC LIMIT 20"
        )
        entity_rows = await connector.execute_query(entity_query, end_user_id=end_user_id)
        entities = [f"{r['name']} ({r['frequency']})" for r in entity_rows]
        await connector.close()
        if not statements or not llm_id:
            # Nothing to summarize, or no model configured.
            result["tags"] = []
            if not llm_id and statements:
                logger.warning("llm_id not provided, skipping tag generation")
        else:
            # Build the summary text fed to the LLM.
            summary_text = f"用户语句样本:{' | '.join(statements[:20])}\n核心实体:{', '.join(entities)}"
            logger.debug(f"User data found: {len(statements)} statements, {len(entities)} entities")
            # Extract tags via a structured LLM call.
            # NOTE(review): this `db` shadows the method parameter — confirm
            # the fresh session (rather than the caller's) is intended.
            with get_db_context() as db:
                factory = MemoryClientFactory(db)
                llm_client = factory.get_llm_client(llm_id)
                # Structured-output schema for the tag extraction.
                class UserTags(BaseModel):
                    tags: list[str] = Field(...,
                                            description="3个描述用户特征的标签产品设计师、旅行爱好者、摄影发烧友")
                messages = [
                    {
                        "role": "system",
                        "content": "你是一个信息提取助手。从用户的语句和实体中提取3个最能代表用户特征的标签。标签应该简洁2-6个字描述用户的职业、兴趣或特点。"
                    },
                    {
                        "role": "user",
                        "content": f"请从以下用户信息中提取3个标签\n\n{summary_text}"
                    }
                ]
                user_tags = await llm_client.response_structured(
                    messages=messages,
                    response_model=UserTags
                )
                result["tags"] = user_tags.tags
                logger.debug(f"Extracted tags: {user_tags.tags}")
    except Exception as e:
        # On any failure, fall back to an empty tag list.
        logger.error(f"Failed to extract user tags: {e}")
        result["tags"] = []
    try:
        # 3. Fetch the top-4 hot memory tags, excluding known person names.
        connector = Neo4jConnector()
        names_to_exclude = ['AI', 'Caroline', 'Melanie', 'Jon', 'Gina', '用户', 'AI助手', 'John', 'Maria']
        hot_tag_query = (
            "MATCH (e:ExtractedEntity) "
            "WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) AND e.entity_type <> '人物' "
            "AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
            "RETURN e.name AS name, count(e) AS frequency "
            "ORDER BY frequency DESC LIMIT 4"
        )
        hot_tag_rows = await connector.execute_query(
            hot_tag_query,
            end_user_id=end_user_id,
            names_to_exclude=names_to_exclude
        )
        await connector.close()
        result["hot_tags"] = [{"name": r["name"], "frequency": r["frequency"]} for r in hot_tag_rows]
        logger.debug(f"Hot tags found: {len(result['hot_tags'])} tags")
    except Exception as e:
        logger.error(f"Failed to get hot tags: {e}")
        result["hot_tags"] = []
    return result
async def stream_log_content(self) -> AsyncGenerator[str, None]:
    """
    Stream log content in real-time using Server-Sent Events (SSE).

    Emits all existing content of ``logs/agent_service.log`` first, then
    follows the file and yields new entries as they are appended, using
    LogStreamer to watch the file.

    Yields:
        SSE-formatted strings (``event: <type>\\ndata: <json>\\n\\n``) with
        the following event types:
            - log: contains log content and timestamp
            - keepalive: periodic keepalive messages to maintain connection
            - error: error information if streaming fails
            - done: indicates streaming has completed
    """
    logger.info("Starting log content streaming")
    # Resolve <api root>/logs/agent_service.log relative to this file.
    from pathlib import Path
    project_root = Path(__file__).resolve().parents[2]  # api directory
    log_path = str(project_root / "logs" / "agent_service.log")
    # Fail fast with an SSE error event if the log file is absent.
    if not os.path.exists(log_path):
        logger.error(f"Log file not found: {log_path}")
        yield f"event: error\ndata: {json.dumps({'code': 4006, 'message': '日志文件不存在', 'error': f'File not found: {log_path}'})}\n\n"
        return
    try:
        # Keepalive interval is configurable via settings; defaults to 300 seconds.
        keepalive_interval = getattr(settings, 'LOG_STREAM_KEEPALIVE_INTERVAL', 300)
        streamer = LogStreamer(log_path, keepalive_interval=keepalive_interval)
        logger.info(f"LogStreamer initialized for {log_path}")
        # Emit all existing content first, then tail new entries.
        async for message in streamer.read_existing_and_stream():
            event_type = message.get("event")
            data = message.get("data")
            # SSE format: "event: <type>\ndata: <json_data>\n\n"
            sse_message = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
            logger.debug(f"Streaming event: {event_type}")
            yield sse_message
            # Terminal events end the stream.
            if event_type in ["error", "done"]:
                logger.info(f"Stream ended with event: {event_type}")
                break
    except FileNotFoundError as e:
        logger.error(f"Log file not found during streaming: {e}")
        yield f"event: error\ndata: {json.dumps({'code': 4006, 'message': '日志文件在流式传输期间变得不可用', 'error': str(e)})}\n\n"
    except Exception as e:
        logger.error(f"Unexpected error during log streaming: {e}", exc_info=True)
        yield f"event: error\ndata: {json.dumps({'code': 8001, 'message': '流式传输期间发生错误', 'error': str(e)})}\n\n"
    finally:
        # LogStreamer manages its file handle via a context manager,
        # so there is nothing to close here explicitly.
        logger.info("Log streaming completed, cleaning up resources")
# TODO: move to memory_config_service.py
def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]:
    """
    Resolve the memory config linked to an end user (获取终端用户关联的记忆配置).

    Backward compatibility / lazy migration: when ``end_user.memory_config_id``
    is empty, the id is extracted from the latest active ``AppRelease.config``
    and backfilled onto the end_user row.

    Args:
        end_user_id: End-user id (终端用户ID).
        db: Database session (数据库会话).

    Returns:
        Dict containing ``end_user_id``, ``memory_config_id`` and ``workspace_id``.

    Raises:
        ValueError: When the end user does not exist (终端用户不存在).
    """
    import json as json_module

    from sqlalchemy import select

    from app.models.app_model import App
    from app.models.app_release_model import AppRelease
    from app.models.end_user_model import EndUser
    from app.services.memory_config_service import MemoryConfigService

    logger.info(f"Getting connected config for end_user: {end_user_id}")
    # TODO: check sources for enduserid, should be one of these three: chat, draft, apikey
    # 1. Load the end user and its app id.
    end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
    if not end_user:
        logger.warning(f"End user not found: {end_user_id}")
        raise ValueError(f"终端用户不存在: {end_user_id}")
    app_id = end_user.app_id
    logger.debug(f"Found end_user app_id: {app_id}")
    # 2. Load the app (needed for app.type during legacy extraction).
    app = db.query(App).filter(App.id == app_id).first()
    if not app:
        logger.warning(f"App not found: {app_id}")
        # NOTE: intentionally not raising here (temp fix for draft run).
        # The legacy-migration step below is guarded on app being present,
        # since it dereferences app.type.
    memory_config_service = MemoryConfigService(db)
    # 3. Legacy data: if memory_config_id is empty (always the case for a
    #    freshly created end_user), extract it from the latest active
    #    AppRelease.config and backfill end_user.memory_config_id (lazy migration).
    memory_config_id_to_use = end_user.memory_config_id
    if not memory_config_id_to_use and app is not None:
        logger.info("end_user.memory_config_id is None, migrating from AppRelease.config")
        # Latest active release for the app.
        stmt = (
            select(AppRelease)
            .where(AppRelease.app_id == app_id, AppRelease.is_active.is_(True))
            .order_by(AppRelease.version.desc())
        )
        # TODO: change to current_release_id
        latest_release = db.scalars(stmt).first()
        if latest_release:
            config = latest_release.config or {}
            # config may be stored as a JSON string; parse it defensively.
            if isinstance(config, str):
                try:
                    config = json_module.loads(config)
                except json_module.JSONDecodeError:
                    logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
                    config = {}
            legacy_config_id, is_legacy_int = memory_config_service.extract_memory_config_id(
                app_type=app.type,
                config=config
            )
            if legacy_config_id:
                from app.models.memory_config_model import (
                    MemoryConfig as MemoryConfigModel,
                )

                # Only backfill ids that actually exist in the database.
                existing_config = db.get(MemoryConfigModel, legacy_config_id)
                if existing_config:
                    memory_config_id_to_use = legacy_config_id
                    # Backfill to the end_user row (lazy update).
                    end_user.memory_config_id = memory_config_id_to_use
                    db.commit()
                    logger.info(
                        f"Migrated memory_config_id for end_user {end_user_id}: {memory_config_id_to_use}"
                    )
                else:
                    logger.warning(
                        f"Extracted memory_config_id does not exist, skipping backfill: "
                        f"end_user_id={end_user_id}, config_id={legacy_config_id}"
                    )
            elif is_legacy_int:
                logger.info(
                    f"Legacy int config detected for end_user {end_user_id}, will use workspace default"
                )
    # 4. Resolve the config, falling back to the workspace default when needed.
    memory_config = memory_config_service.get_config_with_fallback(
        memory_config_id=memory_config_id_to_use,
        workspace_id=end_user.workspace_id
    )
    memory_config_id = str(memory_config.config_id) if memory_config else None
    result = {
        "end_user_id": str(end_user_id),
        "memory_config_id": memory_config_id,
        "workspace_id": str(end_user.workspace_id)
    }
    logger.info(
        f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={end_user.workspace_id}")
    return result
def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) -> Dict[str, Dict[str, Any]]:
    """
    Batch-resolve the memory configs linked to multiple end users
    (批量获取多个终端用户关联的记忆配置,减少数据库查询次数).

    Mirrors the logic of get_end_user_connected_config():
      1. Prefer end_user.memory_config_id.
      2. Otherwise try to extract the id from the latest active
         AppRelease.config and backfill it (lazy migration).
      3. Otherwise — including stale ids that no longer resolve to a row —
         fall back to the workspace default config.

    Args:
        end_user_ids: End-user id list (终端用户ID列表).
        db: Database session (数据库会话).

    Returns:
        Dict keyed by end_user_id, e.g.:
        {
            "user_id_1": {"memory_config_id": "xxx", "memory_config_name": "xxx"},
            "user_id_2": {"memory_config_id": None, "memory_config_name": None},
            ...
        }
    """
    import json as json_module

    from sqlalchemy import select

    from app.models.app_model import App
    from app.models.app_release_model import AppRelease
    from app.models.end_user_model import EndUser
    from app.models.memory_config_model import MemoryConfig
    from app.services.memory_config_service import MemoryConfigService

    logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users")
    result = {}
    if not end_user_ids:
        return result
    # 1. Load all end users; keep the ORM objects around for the backfill step.
    end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()
    end_user_map = {str(eu.id): eu for eu in end_users}
    user_data = {str(eu.id): {"app_id": eu.app_id, "memory_config_id": eu.memory_config_id} for eu in end_users}
    # Users that don't exist resolve to empty entries.
    found_user_ids = set(user_data.keys())
    missing_user_ids = set(end_user_ids) - found_user_ids
    if missing_user_ids:
        logger.warning(f"End users not found: {missing_user_ids}")
        for user_id in missing_user_ids:
            result[user_id] = {"memory_config_id": None, "memory_config_name": None}
    # 2. Load the related apps (for workspace_id and app.type).
    app_ids = list(set(data["app_id"] for data in user_data.values()))
    if not app_ids:
        return result
    apps = db.query(App).filter(App.id.in_(app_ids)).all()
    app_map = {app.id: app for app in apps}
    app_to_workspace = {app.id: app.workspace_id for app in apps}
    # 3. Lazy migration: for users without memory_config_id, try to extract it
    #    from the latest active AppRelease.config.
    users_needing_migration = [
        (end_user_id, data["app_id"])
        for end_user_id, data in user_data.items()
        if not data["memory_config_id"]
    ]
    if users_needing_migration:
        # Latest active release per app involved in the migration.
        migration_app_ids = list(set(app_id for _, app_id in users_needing_migration))
        app_latest_releases = {}
        for app_id in migration_app_ids:
            stmt = (
                select(AppRelease)
                .where(AppRelease.app_id == app_id, AppRelease.is_active.is_(True))
                .order_by(AppRelease.version.desc())
                .limit(1)
            )
            latest_release = db.scalars(stmt).first()
            if latest_release:
                app_latest_releases[app_id] = latest_release
        # Extract a legacy memory_config_id for each user needing migration.
        config_service = MemoryConfigService(db)
        users_to_backfill = []  # [(end_user, memory_config_id), ...]
        for end_user_id, app_id in users_needing_migration:
            latest_release = app_latest_releases.get(app_id)
            if not latest_release:
                continue
            config = latest_release.config or {}
            # config may be stored as a JSON string; parse it defensively.
            if isinstance(config, str):
                try:
                    config = json_module.loads(config)
                except json_module.JSONDecodeError:
                    logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
                    continue
            app = app_map.get(app_id)
            if not app:
                continue
            legacy_config_id, is_legacy_int = config_service.extract_memory_config_id(
                app_type=app.type,
                config=config
            )
            if legacy_config_id:
                user_data[end_user_id]["memory_config_id"] = legacy_config_id
                # Backfill later, after validating the config actually exists.
                end_user = end_user_map.get(end_user_id)
                if end_user:
                    users_to_backfill.append((end_user, legacy_config_id))
            elif is_legacy_int:
                logger.info(
                    f"Legacy int config detected for end_user {end_user_id}, will use workspace default"
                )
        # Validate extracted config ids against the database before backfilling.
        if users_to_backfill:
            config_ids_to_validate = list(set(cid for _, cid in users_to_backfill))
            existing_configs = db.query(MemoryConfig).filter(
                MemoryConfig.config_id.in_(config_ids_to_validate)
            ).all()
            valid_config_ids = {mc.config_id for mc in existing_configs}
            valid_backfills = [
                (eu, cid) for eu, cid in users_to_backfill
                if cid in valid_config_ids
            ]
            invalid_backfills = [
                (eu, cid) for eu, cid in users_to_backfill
                if cid not in valid_config_ids
            ]
            if invalid_backfills:
                invalid_ids = [str(cid) for _, cid in invalid_backfills]
                logger.warning(
                    f"Skipping backfill for non-existent memory_config_ids: {invalid_ids}"
                )
                # Drop invalid ids so these users fall through to the workspace default.
                for eu, cid in invalid_backfills:
                    user_data[str(eu.id)]["memory_config_id"] = None
            if valid_backfills:
                for end_user, memory_config_id in valid_backfills:
                    end_user.memory_config_id = memory_config_id
                db.commit()
                logger.info(f"Migrated memory_config_id for {len(valid_backfills)} end_users")
    # 4. Batch-load the directly assigned configs.
    direct_config_ids = [data["memory_config_id"] for data in user_data.values() if data["memory_config_id"]]
    config_id_to_config = {}
    if direct_config_ids:
        configs = db.query(MemoryConfig).filter(MemoryConfig.config_id.in_(direct_config_ids)).all()
        config_id_to_config = {mc.config_id: mc for mc in configs}
    # 5. Determine which workspaces need their default config: users with no
    #    direct config id AND users whose stored id did not resolve to a row.
    #    (Keeps batch behavior consistent with get_config_with_fallback in the
    #    single-user path.)
    workspace_ids_needing_default = set()
    for end_user_id, data in user_data.items():
        config_id = data["memory_config_id"]
        if not config_id or config_id not in config_id_to_config:
            workspace_id = app_to_workspace.get(data["app_id"])
            if workspace_id:
                workspace_ids_needing_default.add(workspace_id)
    # 6. Load workspace defaults one by one (get_workspace_default_config has
    #    non-trivial internal logic, so no bulk query here).
    workspace_default_configs = {}
    if workspace_ids_needing_default:
        config_service = MemoryConfigService(db)
        for workspace_id in workspace_ids_needing_default:
            default_config = config_service.get_workspace_default_config(workspace_id)
            if default_config:
                workspace_default_configs[workspace_id] = default_config
    # 7. Assemble the final result.
    for end_user_id, data in user_data.items():
        memory_config = None
        # Prefer the directly assigned config.
        if data["memory_config_id"]:
            memory_config = config_id_to_config.get(data["memory_config_id"])
        # Fall back to the workspace default config.
        if not memory_config:
            workspace_id = app_to_workspace.get(data["app_id"])
            if workspace_id:
                memory_config = workspace_default_configs.get(workspace_id)
        if memory_config:
            result[end_user_id] = {
                "memory_config_id": str(memory_config.config_id),
                "memory_config_name": memory_config.config_name
            }
        else:
            result[end_user_id] = {"memory_config_id": None, "memory_config_name": None}
    logger.info(f"Successfully retrieved {len(result)} connected configs")
    return result