Files
MemoryBear/api/app/services/user_memory_service.py
lanceyq e3ab19dd4f feat(memory): sync user entity aliases and metadata to PostgreSQL
- Add `aliases` and `end_user_id` fields to user entity dicts in
  `collect_user_entities_for_metadata` so downstream tasks can write
  them to PostgreSQL
- Add `update_aliases_and_metadata` method to `EndUserInfoRepository`
  for incremental, case-insensitive dedup merge of aliases and
  structured metadata fields
- Add `_sync_end_user_info_pg` helper in tasks.py that writes aliases
  and extracted metadata to `end_user_info`, and back-fills
  `end_user.other_name` when empty
- Call `_sync_end_user_info_pg` from `extract_metadata_batch_task`
  after Neo4j write, and also when no new metadata but aliases exist
- Filter `meta_data` response in `UserMemoryService.get_end_user_info`
  to expose only four core fields: goals, traits, interests, core_facts
2026-05-08 11:28:44 +08:00

2090 lines
84 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
User Memory Service
处理用户记忆相关的业务逻辑,包括记忆洞察、用户摘要、节点统计和图数据等。
"""
import os
import uuid
from collections import Counter
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from app.core.logging_config import get_logger
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import _USER_PLACEHOLDER_NAMES
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.conversation_repository import ConversationRepository
from app.repositories.end_user_repository import EndUserRepository
from app.repositories.neo4j.cypher_queries import Graph_Node_query
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
from app.services.memory_base_service import MemoryBaseService, MIN_MEMORY_SUMMARY_COUNT
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_perceptual_service import MemoryPerceptualService
from app.services.memory_short_service import ShortService
logger = get_logger(__name__)
# Neo4j connector instance for analytics functions
_neo4j_connector = Neo4jConnector()
# Default LLM ID for fallback
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
# ============================================================================
# Internal Helper Classes
# ============================================================================
class TagClassification(BaseModel):
    """Structured LLM output: the single domain a memory tag belongs to.

    Used as the ``response_model`` in
    ``MemoryInsightHelper.get_domain_distribution``; ``domain`` must be one
    of the ten predefined Chinese domain names listed in ``examples``.
    """
    domain: str = Field(
        ...,
        description="The domain the tag belongs to, chosen from the predefined list.",
        examples=["教育", "学习", "工作", "旅行", "家庭", "运动", "社交", "娱乐", "健康", "其他"],
    )
def _get_llm_client_for_user(user_id: str):
    """
    Build an LLM client for a specific end user.

    Resolves the user's connected memory config and uses its LLM model;
    falls back to ``DEFAULT_LLM_ID`` when no config is connected or when
    the lookup fails for any reason.

    Args:
        user_id: End-user ID whose connected config is consulted.

    Returns:
        LLM client instance.
    """
    with get_db_context() as db:
        try:
            from app.services.memory_agent_service import get_end_user_connected_config

            linked = get_end_user_connected_config(user_id, db)
            cfg_id = linked.get("memory_config_id")
            ws_id = linked.get("workspace_id")
            if cfg_id or ws_id:
                # A connected config exists — honor its configured model.
                memory_config = MemoryConfigService(db).load_memory_config(
                    config_id=cfg_id,
                    workspace_id=ws_id,
                )
                return MemoryClientFactory(db).get_llm_client(memory_config.llm_model_id)
            # No connected config: use the environment default.
            return MemoryClientFactory(db).get_llm_client(DEFAULT_LLM_ID)
        except Exception as e:
            # Best-effort fallback — never fail caller construction over config lookup.
            logger.warning(f"Failed to get user connected config, using default LLM: {e}")
            return MemoryClientFactory(db).get_llm_client(DEFAULT_LLM_ID)
class MemoryInsightHelper:
    """
    Internal helper class for memory insight analysis.

    Wraps a Neo4j connection and a per-user LLM client to derive three
    signals for the insight report: domain distribution of memories,
    most active months, and the strongest social connection.
    """
    def __init__(self, user_id: str):
        # End-user whose graph data is analyzed.
        self.user_id = user_id
        # Dedicated connector; caller is expected to await close().
        self.neo4j_connector = Neo4jConnector()
        # LLM client resolved from the user's connected memory config.
        self.llm_client = _get_llm_client_for_user(user_id)

    async def close(self):
        """Close database connection."""
        await self.neo4j_connector.close()

    async def get_domain_distribution(self) -> dict[str, float]:
        """Calculate the distribution of memory domains based on hot tags.

        Each hot tag is classified by the LLM into one of ten predefined
        domains. Returns a ``{domain: share}`` mapping (shares sum to 1.0),
        sorted by share descending; ``{}`` when there are no hot tags or no
        tag could be classified.
        """
        from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
        hot_tags = await get_hot_memory_tags(self.user_id)
        if not hot_tags:
            return {}
        domain_counts = Counter()
        # NOTE(review): one sequential LLM call per tag — may be slow for many tags.
        for tag, _ in hot_tags:
            prompt = f"""请将以下标签归类到最合适的领域中。
可选领域及其关键词:
- 教育:学校、课程、考试、培训、教学、学科、教师、学生、班级、作业、成绩、毕业、入学、校园、大学、中学、小学、教材、学位等
- 学习:自学、阅读、书籍、技能提升、知识积累、笔记、复习、练习、研究、历史知识、科学知识、文化知识、学术讨论、知识问答等
- 工作:职业、项目、会议、同事、业务、公司、办公、任务、客户、合同、职场、工作计划等
- 旅行:旅游、景点、出行、度假、酒店、机票、导游、风景、旅行计划等
- 家庭:亲人、父母、子女、配偶、家事、家庭活动、亲情、家庭聚会等
- 运动:健身、体育、锻炼、跑步、游泳、球类、瑜伽、运动计划等
- 社交:朋友、聚会、社交活动、派对、聊天、交友、社交网络等
- 娱乐:游戏、电影、音乐、休闲、综艺、动漫、小说、娱乐活动等
- 健康:医疗、养生、心理健康、体检、药物、疾病、保健、健康管理等
- 其他:确实无法归入以上任何类别的内容
标签: {tag}
分析步骤:
1. 仔细理解标签的核心含义和使用场景
2. 对比各个领域的关键词,找到最匹配的领域
3. 特别注意:
- 历史、科学、文化等知识性内容应归类为"学习"
- 学校、课程、考试等正式教育场景应归类为"教育"
- 只有在标签完全不属于上述9个具体领域时才选择"其他"
4. 如果标签与某个领域有任何相关性,就选择该领域,不要选"其他"
请直接返回最合适的领域名称。"""
            messages = [
                {"role": "system", "content": "你是一个专业的标签分类助手。你必须仔细分析标签的实际含义和使用场景优先选择9个具体领域之一。'其他'类别只用于完全无法归类的极少数情况。特别注意:历史、科学、文化等知识性对话应归类为'学习'领域;学校、课程、考试等正式教育场景应归类为'教育'领域。"},
                {"role": "user", "content": prompt}
            ]
            classification = await self.llm_client.response_structured(
                messages=messages,
                response_model=TagClassification,
            )
            # Count only well-formed classifications with a non-empty domain.
            if classification and hasattr(classification, 'domain') and classification.domain:
                domain_counts[classification.domain] += 1
        total_tags = sum(domain_counts.values())
        if total_tags == 0:
            return {}
        domain_distribution = {
            domain: count / total_tags for domain, count in domain_counts.items()
        }
        # Largest share first.
        return dict(sorted(domain_distribution.items(), key=lambda item: item[1], reverse=True))

    async def get_active_periods(self) -> list[int]:
        """
        Identify the top 2 most active months for the user.

        Only returns months if there is valid and diverse time data;
        heavily concentrated timestamps (likely batch-imported data) yield [].
        """
        query = """
            MATCH (d:Dialogue)
            WHERE d.end_user_id = $end_user_id AND d.created_at IS NOT NULL AND d.created_at <> ''
            RETURN d.created_at AS creation_time
        """
        records = await self.neo4j_connector.execute_query(query, end_user_id=self.user_id)
        if not records:
            return []
        month_counts = Counter()
        valid_dates_count = 0
        for record in records:
            creation_time = record.get("creation_time")
            if not creation_time:
                continue
            try:
                # Handle Neo4j DateTime objects as well as ISO-format strings.
                if hasattr(creation_time, 'to_native'):
                    dt_object = creation_time.to_native()
                elif isinstance(creation_time, str):
                    dt_object = datetime.fromisoformat(creation_time.replace("Z", "+00:00"))
                elif isinstance(creation_time, datetime):
                    dt_object = creation_time
                else:
                    dt_object = datetime.fromisoformat(str(creation_time).replace("Z", "+00:00"))
                month_counts[dt_object.month] += 1
                valid_dates_count += 1
            except (ValueError, TypeError, AttributeError):
                # Skip unparseable timestamps rather than failing the whole scan.
                continue
        if not month_counts or valid_dates_count == 0:
            return []
        # Check if time distribution is too concentrated (likely batch imported data).
        unique_months = len(month_counts)
        if unique_months <= 2:
            most_common_count = month_counts.most_common(1)[0][1]
            # >80% of records in a single month: treat as non-signal.
            if most_common_count / valid_dates_count > 0.8:
                return []
        if unique_months >= 3:
            most_common_months = month_counts.most_common(2)
            return [month for month, _ in most_common_months]
        if unique_months == 2:
            counts = list(month_counts.values())
            # Require the two months to be reasonably balanced (ratio > 0.3).
            ratio = min(counts) / max(counts)
            if ratio > 0.3:
                most_common_months = month_counts.most_common(2)
                return [month for month, _ in most_common_months]
        return []

    async def get_social_connections(self) -> dict | None:
        """Find the user with whom the most memories are shared.

        Returns:
            {"user_id", "common_memories_count", "time_range"} for the most
            connected user, or None when no other user shares any statements.
        """
        query = """
            MATCH (c1:Chunk {end_user_id: $end_user_id})
            OPTIONAL MATCH (c1)-[:CONTAINS]->(s:Statement)
            OPTIONAL MATCH (s)<-[:CONTAINS]-(c2:Chunk)
            WHERE c1.end_user_id <> c2.end_user_id AND s IS NOT NULL AND c2 IS NOT NULL
            WITH c2.end_user_id AS other_user_id, COUNT(DISTINCT s) AS common_statements
            WHERE common_statements > 0
            RETURN other_user_id, common_statements
            ORDER BY common_statements DESC
            LIMIT 1
        """
        records = await self.neo4j_connector.execute_query(query, end_user_id=self.user_id)
        if not records or not records[0].get("other_user_id"):
            return None
        most_connected_user = records[0]["other_user_id"]
        common_memories_count = records[0]["common_statements"]
        # Second query: overall time span covering both users' chunks.
        time_range_query = """
            MATCH (c:Chunk)
            WHERE c.end_user_id IN [$user_id, $other_user_id]
            RETURN min(c.created_at) AS start_time, max(c.created_at) AS end_time
        """
        time_records = await self.neo4j_connector.execute_query(
            time_range_query,
            user_id=self.user_id,
            other_user_id=most_connected_user
        )
        start_year, end_year = "N/A", "N/A"
        if time_records and time_records[0]["start_time"]:
            start_time = time_records[0]["start_time"]
            end_time = time_records[0]["end_time"]
            # Handle Neo4j DateTime objects as well as ISO-format strings;
            # fall back to "N/A" on any parse failure.
            try:
                if hasattr(start_time, 'to_native'):
                    start_year = start_time.to_native().year
                elif isinstance(start_time, str):
                    start_year = datetime.fromisoformat(start_time.replace("Z", "+00:00")).year
                elif isinstance(start_time, datetime):
                    start_year = start_time.year
                else:
                    start_year = datetime.fromisoformat(str(start_time).replace("Z", "+00:00")).year
            except Exception:
                start_year = "N/A"
            try:
                if hasattr(end_time, 'to_native'):
                    end_year = end_time.to_native().year
                elif isinstance(end_time, str):
                    end_year = datetime.fromisoformat(end_time.replace("Z", "+00:00")).year
                elif isinstance(end_time, datetime):
                    end_year = end_time.year
                else:
                    end_year = datetime.fromisoformat(str(end_time).replace("Z", "+00:00")).year
            except Exception:
                end_year = "N/A"
        return {
            "user_id": most_connected_user,
            "common_memories_count": common_memories_count,
            "time_range": f"{start_year}-{end_year}",
        }
class UserSummaryHelper:
    """
    Internal helper class for user summary generation.
    Provides data retrieval functionality for user summary analysis.
    """
    def __init__(self, user_id: str):
        # End-user whose statements/entities feed the summary.
        self.user_id = user_id
        # Dedicated connector; caller is expected to await close().
        self.connector = Neo4jConnector()
        # LLM client resolved from the user's connected memory config.
        self.llm = _get_llm_client_for_user(user_id)

    async def close(self):
        """Close database connection."""
        await self.connector.close()

    async def get_recent_statements(self, limit: int = 80) -> List[Dict[str, Any]]:
        """Fetch recent statements authored by the user/group for context.

        Args:
            limit: Maximum number of statements to return (newest first).

        Returns:
            List of {"statement": str, "created_at": Any} dicts.
        """
        query = (
            "MATCH (s:Statement) "
            "WHERE s.end_user_id = $end_user_id AND s.statement IS NOT NULL "
            "RETURN s.statement AS statement, s.created_at AS created_at "
            "ORDER BY created_at DESC LIMIT $limit"
        )
        rows = await self.connector.execute_query(query, end_user_id=self.user_id, limit=limit)
        records = []
        for r in rows:
            try:
                records.append({
                    "statement": r.get("statement", ""),
                    "created_at": r.get("created_at")
                })
            except Exception:
                # Best-effort: skip malformed rows.
                continue
        return records

    async def get_top_entities(self, limit: int = 30) -> List[Tuple[str, int]]:
        """Get meaningful entities and their frequencies using hot tag logic."""
        from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
        return await get_hot_memory_tags(self.user_id, limit=limit)
# ============================================================================
# Service Class
# ============================================================================
class UserMemoryService:
    """User memory service: end-user info/aliases management, cached memory
    insight and user summary retrieval/generation."""
    def __init__(self):
        logger.info("UserMemoryService initialized")
        # Shared Neo4j connector for graph queries issued by this service.
        self.neo4j_connector = Neo4jConnector()
@staticmethod
def _datetime_to_timestamp(dt: Optional[Any]) -> Optional[int]:
"""将 DateTime 对象转换为时间戳(毫秒)"""
if dt is None:
return None
if hasattr(dt, 'timestamp'):
return int(dt.timestamp() * 1000)
return None
@staticmethod
def convert_profile_to_dict_with_timestamp(profile_data: Any) -> dict:
"""
将 Pydantic 模型转换为字典,自动转换所有 DateTime 字段为时间戳(毫秒)
Args:
profile_data: Pydantic 模型对象
Returns:
包含时间戳的字典
"""
data = profile_data.model_dump()
# 自动转换所有 datetime 类型的字段
for key, value in data.items():
if hasattr(profile_data, key):
original_value = getattr(profile_data, key)
if hasattr(original_value, 'timestamp'):
data[key] = UserMemoryService._datetime_to_timestamp(original_value)
return data
    # ======================== End-user aliases & info ========================
    def get_end_user_info(
        self,
        db: Session,
        end_user_id: str
    ) -> Dict[str, Any]:
        """
        Fetch a single end-user info record.

        The returned ``meta_data`` exposes only the four core fields
        (goals, traits, interests, core_facts); timestamps are epoch ms.

        Args:
            db: Database session.
            end_user_id: End-user ID (UUID string).

        Returns:
            {
                "success": bool,
                "data": dict,
                "error": Optional[str]
            }
        """
        try:
            from app.repositories.end_user_info_repository import EndUserInfoRepository
            from app.core.api_key_utils import datetime_to_timestamp
            # Parse the UUID and look the record up.
            user_uuid = uuid.UUID(end_user_id)
            end_user_info_record = EndUserInfoRepository(db).get_by_end_user_id(user_uuid)
            if not end_user_info_record:
                logger.warning(f"终端用户信息记录不存在: end_user_id={end_user_id}")
                return {
                    "success": False,
                    "data": None,
                    "error": "终端用户信息记录不存在"
                }
            # meta_data: expose only the four core fields.
            _META_FIELDS = ("goals", "traits", "interests", "core_facts")
            raw_meta = end_user_info_record.meta_data or {}
            filtered_meta = {k: raw_meta[k] for k in _META_FIELDS if k in raw_meta}
            # Build the response payload (timestamps converted to epoch ms).
            response_data = {
                "end_user_info_id": str(end_user_info_record.id),
                "end_user_id": str(end_user_info_record.end_user_id),
                "other_name": end_user_info_record.other_name,
                "aliases": end_user_info_record.aliases,
                "meta_data": filtered_meta,
                "created_at": datetime_to_timestamp(end_user_info_record.created_at),
                "updated_at": datetime_to_timestamp(end_user_info_record.updated_at)
            }
            logger.info(f"成功查询终端用户信息记录: end_user_id={end_user_id}")
            return {
                "success": True,
                "data": response_data,
                "error": None
            }
        except ValueError:
            # uuid.UUID() rejects malformed IDs with ValueError.
            logger.error(f"无效的 end_user_id 格式: {end_user_id}")
            return {
                "success": False,
                "data": None,
                "error": "无效的终端用户ID格式"
            }
        except Exception as e:
            logger.error(f"查询终端用户信息记录失败: end_user_id={end_user_id}, error={str(e)}")
            return {
                "success": False,
                "data": None,
                "error": str(e)
            }
    def update_end_user_info(
        self,
        db: Session,
        end_user_id: str,
        update_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Update an end-user info record.

        Only ``other_name``, ``aliases`` and ``meta_data`` are updatable.
        Placeholder user names are rejected for ``other_name`` and filtered
        out of ``aliases``. A changed ``other_name`` is mirrored into the
        ``end_user`` table; changed ``aliases`` are synced to Neo4j on a
        best-effort basis (sync errors are logged, not raised).

        Args:
            db: Database session.
            end_user_id: End-user ID (UUID string).
            update_data: Fields to update; non-whitelisted keys are ignored.

        Returns:
            {
                "success": bool,
                "data": dict,
                "error": Optional[str]
            }
        """
        try:
            from app.repositories.end_user_info_repository import EndUserInfoRepository
            from app.repositories.end_user_repository import EndUserRepository
            from app.core.api_key_utils import datetime_to_timestamp
            # Parse the UUID and fetch the record.
            user_uuid = uuid.UUID(end_user_id)
            end_user_info_record = EndUserInfoRepository(db).get_by_end_user_id(user_uuid)
            if not end_user_info_record:
                logger.warning(f"终端用户信息记录不存在: end_user_id={end_user_id}")
                return {
                    "success": False,
                    "data": None,
                    "error": "终端用户信息记录不存在"
                }
            # Whitelist of updatable fields.
            allowed_fields = {'other_name', 'aliases', 'meta_data'}
            # Placeholder user names must not become other_name or appear in aliases.
            _user_placeholder_names = _USER_PLACEHOLDER_NAMES
            # Drop other_name when it is a placeholder name.
            if 'other_name' in update_data and update_data['other_name'] and update_data['other_name'].strip() in _user_placeholder_names:
                logger.warning(f"拒绝将占位名称 '{update_data['other_name']}' 设置为 other_name")
                del update_data['other_name']
            # Filter aliases: keep only non-empty strings that are not placeholders.
            # NOTE(review): an explicitly empty aliases list bypasses this filter
            # and is stored as-is — confirm that clearing aliases is intended.
            if 'aliases' in update_data and update_data['aliases']:
                update_data['aliases'] = [
                    a for a in update_data['aliases']
                    if isinstance(a, str) and a.strip() and a.strip() not in _user_placeholder_names
                ]
            # Detect changes BEFORE applying them, for the follow-up syncs below.
            aliases_updated = 'aliases' in update_data and update_data['aliases'] != end_user_info_record.aliases
            other_name_updated = 'other_name' in update_data and update_data['other_name'] != end_user_info_record.other_name
            # Apply whitelisted fields only.
            for field, value in update_data.items():
                if field in allowed_fields:
                    setattr(end_user_info_record, field, value)
            # Touch the update timestamp.
            end_user_info_record.updated_at = datetime.now()
            # Mirror other_name into the end_user table when it changed.
            if other_name_updated:
                end_user_record = EndUserRepository(db).get_by_id(user_uuid)
                if end_user_record:
                    end_user_record.other_name = update_data['other_name']
                    end_user_record.updated_at = datetime.now()
                    logger.info(f"同步更新 end_user 表的 other_name: end_user_id={end_user_id}, other_name={update_data['other_name']}")
                else:
                    logger.warning(f"未找到对应的 end_user 记录: end_user_id={end_user_id}")
            # Commit changes.
            db.commit()
            db.refresh(end_user_info_record)
            # Best-effort sync of aliases to Neo4j when they changed.
            if aliases_updated:
                try:
                    import asyncio
                    # NOTE(review): asyncio.run raises RuntimeError if an event
                    # loop is already running — confirm this method is only
                    # invoked from synchronous contexts.
                    asyncio.run(self._sync_aliases_to_neo4j(end_user_id, update_data['aliases']))
                    logger.info(f"已触发 aliases 同步到 Neo4j: end_user_id={end_user_id}, aliases={update_data['aliases']}")
                except Exception as sync_error:
                    # The DB update already committed; only log the sync failure.
                    logger.error(f"触发同步 aliases 到 Neo4j 失败: {sync_error}", exc_info=True)
            # Build the response payload (timestamps converted to epoch ms).
            response_data = {
                "end_user_info_id": str(end_user_info_record.id),
                "end_user_id": str(end_user_info_record.end_user_id),
                "other_name": end_user_info_record.other_name,
                "aliases": end_user_info_record.aliases,
                "meta_data": end_user_info_record.meta_data,
                "created_at": datetime_to_timestamp(end_user_info_record.created_at),
                "updated_at": datetime_to_timestamp(end_user_info_record.updated_at)
            }
            logger.info(f"成功更新终端用户信息记录: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}")
            return {
                "success": True,
                "data": response_data,
                "error": None
            }
        except ValueError:
            # uuid.UUID() rejects malformed IDs with ValueError.
            logger.error(f"无效的 end_user_id 格式: {end_user_id}")
            return {
                "success": False,
                "data": None,
                "error": "无效的终端用户ID格式"
            }
        except Exception as e:
            db.rollback()
            logger.error(f"更新终端用户信息记录失败: end_user_id={end_user_id}, error={str(e)}")
            return {
                "success": False,
                "data": None,
                "error": str(e)
            }
    async def _sync_aliases_to_neo4j(self, end_user_id: str, aliases: List[str]) -> None:
        """
        Push aliases onto the user's entity nodes in Neo4j.

        Overwrites the ``aliases`` property on every ExtractedEntity whose
        name is a user placeholder for this end user.

        Args:
            end_user_id: End-user ID.
            aliases: Alias list to store on matching entity nodes.

        Raises:
            Exception: re-raised after logging when the Neo4j write fails.
        """
        from app.repositories.neo4j.neo4j_connector import Neo4jConnector
        # Cypher: set aliases on the user's placeholder-named entities.
        # NOTE(review): the IN list contains an empty-string entry — a
        # character (e.g. '我') may have been lost here; compare with the
        # imported _USER_PLACEHOLDER_NAMES and confirm.
        cypher_query = """
            MATCH (e:ExtractedEntity)
            WHERE e.end_user_id = $end_user_id
            AND e.name IN ['用户', '', 'User', 'I']
            SET e.aliases = $aliases
            RETURN e.id AS entity_id, e.name AS entity_name, e.aliases AS updated_aliases
        """
        connector = Neo4jConnector()
        try:
            result = await connector.execute_query(
                cypher_query,
                end_user_id=end_user_id,
                aliases=aliases
            )
            if result:
                logger.info(f"成功同步 aliases 到 Neo4j: end_user_id={end_user_id}, 更新了 {len(result)} 个实体节点")
            else:
                logger.warning(f"未找到需要更新的用户实体节点: end_user_id={end_user_id}")
        except Exception as e:
            logger.error(f"同步 aliases 到 Neo4j 失败: {e}", exc_info=True)
            raise
async def get_cached_memory_insight(
self,
db: Session,
end_user_id: str
) -> Dict[str, Any]:
"""
从数据库获取缓存的记忆洞察(四个维度)
Args:
db: 数据库会话
end_user_id: 终端用户ID (UUID)
Returns:
{
"memory_insight": str, # 总体概述
"behavior_pattern": str, # 行为模式
"key_findings": List[str], # 关键发现(数组)
"growth_trajectory": str, # 成长轨迹
"updated_at": int, # 时间戳(毫秒)
"is_cached": bool
}
"""
try:
# 转换为UUID并查询用户
user_uuid = uuid.UUID(end_user_id)
repo = EndUserRepository(db)
end_user = repo.get_by_id(user_uuid)
if not end_user:
logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户")
return {
"memory_insight": None,
"behavior_pattern": None,
"key_findings": None,
"growth_trajectory": None,
"updated_at": None,
"is_cached": False,
"message": "用户不存在"
}
# 检查是否有缓存数据(至少有一个字段不为空)
has_cache = any([
end_user.memory_insight,
end_user.behavior_pattern,
end_user.key_findings,
end_user.growth_trajectory
])
if has_cache:
# 反序列化 key_findings从 JSON 字符串转为数组)
key_findings_value = end_user.key_findings
if key_findings_value:
try:
import json
key_findings_array = json.loads(key_findings_value)
except (json.JSONDecodeError, TypeError):
# 如果解析失败,尝试按 • 分割(兼容旧数据)
key_findings_array = [item.strip() for item in key_findings_value.split('') if item.strip()]
else:
key_findings_array = []
logger.info(f"成功获取 end_user_id {end_user_id} 的缓存记忆洞察(四维度)")
memory_insight=end_user.memory_insight
behavior_pattern=end_user.behavior_pattern
growth_trajectory=end_user.growth_trajectory
return {
"memory_insight":memory_insight, # 总体概述存储在 memory_insight
"behavior_pattern":behavior_pattern,
"key_findings": key_findings_array, # 返回数组
"growth_trajectory": growth_trajectory,
"updated_at": self._datetime_to_timestamp(end_user.memory_insight_updated_at),
"is_cached": True
}
else:
logger.info(f"end_user_id {end_user_id} 的记忆洞察缓存为空")
return {
"memory_insight": None,
"behavior_pattern": None,
"key_findings": None,
"growth_trajectory": None,
"updated_at": None,
"is_cached": False,
"message": "数据尚未生成,请稍后重试或联系管理员"
}
except ValueError:
logger.error(f"无效的 end_user_id 格式: {end_user_id}")
return {
"memory_insight": None,
"behavior_pattern": None,
"key_findings": None,
"growth_trajectory": None,
"updated_at": None,
"is_cached": False,
"message": "无效的用户ID格式"
}
except Exception as e:
logger.error(f"获取缓存记忆洞察时出错: {str(e)}")
raise
    async def get_cached_user_summary(
        self,
        db: Session,
        end_user_id: str,
        model_id: str,
        language_type: str = "zh"
    ) -> Dict[str, Any]:
        """
        Read the cached user summary (four parts) from the database.

        Args:
            db: Database session.
            end_user_id: End-user ID (UUID string).
            model_id: Model ID for translation.
                NOTE(review): currently unused — data is returned as stored.
            language_type: Language type ("zh" Chinese, "en" English).
                NOTE(review): currently unused; language is fixed at
                generation time (X-Language-Type), not on read.

        Returns:
            {
                "user_summary": str,
                "personality": str,
                "core_values": str,
                "one_sentence": str,
                "updated_at": int,  # epoch ms
                "is_cached": bool
            }
            Plus a "message" key on the not-cached / error payloads.

        Raises:
            Exception: unexpected errors are logged and re-raised.
        """
        try:
            # Parse the UUID and look the user up.
            user_uuid = uuid.UUID(end_user_id)
            repo = EndUserRepository(db)
            end_user = repo.get_by_id(user_uuid)
            if not end_user:
                logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户")
                return {
                    "user_summary": None,
                    "personality": None,
                    "core_values": None,
                    "one_sentence": None,
                    "updated_at": None,
                    "is_cached": False,
                    "message": "用户不存在"
                }
            user_summary = end_user.user_summary
            personality_traits = end_user.personality_traits
            core_values = end_user.core_values
            one_sentence_summary = end_user.one_sentence_summary
            # Return DB contents verbatim — no re-translation on read.
            # Cached when at least one of the four parts is non-empty.
            has_cache = any([
                user_summary,
                personality_traits,
                core_values,
                one_sentence_summary
            ])
            if has_cache:
                logger.info(f"成功获取 end_user_id {end_user_id} 的缓存用户摘要")
                return {
                    "user_summary": user_summary,
                    "personality": personality_traits,
                    "core_values": core_values,
                    "one_sentence": one_sentence_summary,
                    "updated_at": self._datetime_to_timestamp(end_user.user_summary_updated_at),
                    "is_cached": True
                }
            else:
                logger.info(f"end_user_id {end_user_id} 的用户摘要缓存为空")
                return {
                    "user_summary": None,
                    "personality": None,
                    "core_values": None,
                    "one_sentence": None,
                    "updated_at": None,
                    "is_cached": False,
                    "message": "数据尚未生成,请稍后重试或联系管理员"
                }
        except ValueError:
            # uuid.UUID() rejects malformed IDs with ValueError.
            logger.error(f"无效的 end_user_id 格式: {end_user_id}")
            return {
                "user_summary": None,
                "personality": None,
                "core_values": None,
                "one_sentence": None,
                "updated_at": None,
                "is_cached": False,
                "message": "无效的用户ID格式"
            }
        except Exception as e:
            logger.error(f"获取缓存用户摘要时出错: {str(e)}")
            raise
    # ---- Per-user generation ----
    async def generate_and_cache_insight(
        self,
        db: Session,
        end_user_id: str,
        workspace_id: Optional[uuid.UUID] = None,
        language: str = "zh"
    ) -> Dict[str, Any]:
        """
        Generate the memory insight (four dimensions) and cache it in the DB.

        Delegates to the module-level ``analytics_memory_insight_report``
        (Neo4j + LLM), then persists the result via
        ``EndUserRepository.update_memory_insight``.

        Args:
            db: Database session.
            end_user_id: End-user ID (UUID string).
            workspace_id: Workspace ID (optional).
                NOTE(review): currently unused in this method.
            language: Language type ("zh" Chinese, "en" English); defaults to Chinese.

        Returns:
            {
                "success": bool,
                "memory_insight": str,
                "behavior_pattern": str,
                "key_findings": List[str],  # array form
                "growth_trajectory": str,
                "error": Optional[str]
            }
        """
        try:
            logger.info(f"开始为 end_user_id {end_user_id} 生成记忆洞察, language={language}")
            # Parse the UUID and verify the user exists.
            user_uuid = uuid.UUID(end_user_id)
            repo = EndUserRepository(db)
            end_user = repo.get_by_id(user_uuid)
            if not end_user:
                logger.error(f"end_user_id {end_user_id} 不存在")
                return {
                    "success": False,
                    "memory_insight": None,
                    "behavior_pattern": None,
                    "key_findings": None,
                    "growth_trajectory": None,
                    "error": "用户不存在"
                }
            # Run the analytics pipeline using end_user_id.
            try:
                logger.info(f"使用 end_user_id={end_user_id} 生成记忆洞察")
                result = await analytics_memory_insight_report(end_user_id, language=language)
                memory_insight = result.get("memory_insight", "")
                behavior_pattern = result.get("behavior_pattern", "")
                key_findings_array = result.get("key_findings", [])  # array form
                growth_trajectory = result.get("growth_trajectory", "")
                # key_findings is persisted as a JSON string.
                import json
                key_findings_json = json.dumps(key_findings_array, ensure_ascii=False) if key_findings_array else ""
                if not any([memory_insight, behavior_pattern, key_findings_array, growth_trajectory]):
                    logger.warning(f"end_user_id {end_user_id} 的记忆洞察生成结果为空")
                    return {
                        "success": False,
                        "memory_insight": None,
                        "behavior_pattern": None,
                        "key_findings": None,
                        "growth_trajectory": None,
                        "error": "生成的洞察报告为空,可能Neo4j中没有该用户的数据"
                    }
                # Cache all four dimensions.
                # Note: key_findings is stored as a JSON string.
                success = repo.update_memory_insight(
                    user_uuid,
                    memory_insight,
                    behavior_pattern,
                    key_findings_json,  # JSON string for storage
                    growth_trajectory
                )
                if success:
                    logger.info(f"成功为 end_user_id {end_user_id} 生成并缓存记忆洞察(四维度)")
                    return {
                        "success": True,
                        "memory_insight": memory_insight,
                        "behavior_pattern": behavior_pattern,
                        "key_findings": key_findings_array,  # return the array
                        "growth_trajectory": growth_trajectory,
                        "error": None
                    }
                else:
                    logger.error(f"更新 end_user_id {end_user_id} 的记忆洞察缓存失败")
                    return {
                        "success": False,
                        "memory_insight": memory_insight,
                        "behavior_pattern": behavior_pattern,
                        "key_findings": key_findings_array,  # return the array
                        "growth_trajectory": growth_trajectory,
                        "error": "数据库更新失败"
                    }
            except Exception as e:
                # Analytics (Neo4j/LLM) failure — reported, not raised.
                logger.error(f"调用分析函数生成记忆洞察时出错: {str(e)}")
                return {
                    "success": False,
                    "memory_insight": None,
                    "behavior_pattern": None,
                    "key_findings": None,
                    "growth_trajectory": None,
                    "error": f"Neo4j或LLM服务不可用: {str(e)}"
                }
        except ValueError:
            # uuid.UUID() rejects malformed IDs with ValueError.
            logger.error(f"无效的 end_user_id 格式: {end_user_id}")
            return {
                "success": False,
                "memory_insight": None,
                "behavior_pattern": None,
                "key_findings": None,
                "growth_trajectory": None,
                "error": "无效的用户ID格式"
            }
        except Exception as e:
            logger.error(f"生成并缓存记忆洞察时出错: {str(e)}")
            return {
                "success": False,
                "memory_insight": None,
                "behavior_pattern": None,
                "key_findings": None,
                "growth_trajectory": None,
                "error": str(e)
            }
    async def generate_and_cache_summary(
        self,
        db: Session,
        end_user_id: str,
        workspace_id: Optional[uuid.UUID] = None,
        language: str = "zh"
    ) -> Dict[str, Any]:
        """
        Generate the user summary (four parts) and cache it in the DB.

        Delegates to the module-level ``analytics_user_summary``
        (Neo4j + LLM), then persists via
        ``EndUserRepository.update_user_summary``.

        Args:
            db: Database session.
            end_user_id: End-user ID (UUID string).
            workspace_id: Workspace ID (optional).
                NOTE(review): currently unused in this method.
            language: Language type ("zh" Chinese, "en" English); defaults to Chinese.

        Returns:
            {
                "success": bool,
                "user_summary": str,
                "personality": str,
                "core_values": str,
                "one_sentence": str,
                "error": Optional[str]
            }
        """
        try:
            logger.info(f"开始为 end_user_id {end_user_id} 生成用户摘要, language={language}")
            # Parse the UUID and verify the user exists.
            user_uuid = uuid.UUID(end_user_id)
            repo = EndUserRepository(db)
            end_user = repo.get_by_id(user_uuid)
            if not end_user:
                logger.error(f"end_user_id {end_user_id} 不存在")
                return {
                    "success": False,
                    "user_summary": None,
                    "personality": None,
                    "core_values": None,
                    "one_sentence": None,
                    "error": "用户不存在"
                }
            # Run the analytics pipeline using end_user_id.
            try:
                logger.info(f"使用 end_user_id={end_user_id} 生成用户摘要")
                result = await analytics_user_summary(end_user_id, language=language)
                user_summary = result.get("user_summary", "")
                personality = result.get("personality", "")
                core_values = result.get("core_values", "")
                one_sentence = result.get("one_sentence", "")
                if not any([user_summary, personality, core_values, one_sentence]):
                    logger.warning(f"end_user_id {end_user_id} 的用户摘要生成结果为空")
                    return {
                        "success": False,
                        "user_summary": None,
                        "personality": None,
                        "core_values": None,
                        "one_sentence": None,
                        "error": "生成的用户摘要为空,可能Neo4j中没有该用户的数据"
                    }
                # Cache all four parts.
                success = repo.update_user_summary(
                    user_uuid,
                    user_summary,
                    personality,
                    core_values,
                    one_sentence
                )
                if success:
                    logger.info(f"成功为 end_user_id {end_user_id} 生成并缓存用户摘要")
                    return {
                        "success": True,
                        "user_summary": user_summary,
                        "personality": personality,
                        "core_values": core_values,
                        "one_sentence": one_sentence,
                        "error": None
                    }
                else:
                    logger.error(f"更新 end_user_id {end_user_id} 的用户摘要缓存失败")
                    return {
                        "success": False,
                        "user_summary": user_summary,
                        "personality": personality,
                        "core_values": core_values,
                        "one_sentence": one_sentence,
                        "error": "数据库更新失败"
                    }
            except Exception as e:
                # Analytics (Neo4j/LLM) failure — reported, not raised.
                logger.error(f"调用分析函数生成用户摘要时出错: {str(e)}")
                return {
                    "success": False,
                    "user_summary": None,
                    "personality": None,
                    "core_values": None,
                    "one_sentence": None,
                    "error": f"Neo4j或LLM服务不可用: {str(e)}"
                }
        except ValueError:
            # uuid.UUID() rejects malformed IDs with ValueError.
            logger.error(f"无效的 end_user_id 格式: {end_user_id}")
            return {
                "success": False,
                "user_summary": None,
                "personality": None,
                "core_values": None,
                "one_sentence": None,
                "error": "无效的用户ID格式"
            }
        except Exception as e:
            logger.error(f"生成并缓存用户摘要时出错: {str(e)}")
            return {
                "success": False,
                "user_summary": None,
                "personality": None,
                "core_values": None,
                "one_sentence": None,
                "error": str(e)
            }
# for workspace
async def generate_cache_for_workspace(
self,
db: Session,
workspace_id: uuid.UUID,
language: str = "zh"
) -> Dict[str, Any]:
"""
为整个工作空间生成缓存
Args:
db: 数据库会话
workspace_id: 工作空间ID
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
Returns:
{
"total_users": int,
"successful": int,
"failed": int,
"errors": List[Dict]
}
"""
logger.info(f"开始为工作空间 {workspace_id} 批量生成缓存, language={language}")
total_users = 0
successful = 0
failed = 0
errors = []
try:
# 获取工作空间的所有终端用户
repo = EndUserRepository(db)
end_users = repo.get_all_by_workspace(workspace_id)
total_users = len(end_users)
logger.info(f"工作空间 {workspace_id} 共有 {total_users} 个终端用户")
# 遍历每个用户并生成缓存
for end_user in end_users:
end_user_id = str(end_user.id)
try:
# 生成记忆洞察
insight_result = await self.generate_and_cache_insight(db, end_user_id, language=language)
# 生成用户摘要
summary_result = await self.generate_and_cache_summary(db, end_user_id, language=language)
# 检查是否都成功
if insight_result["success"] and summary_result["success"]:
successful += 1
logger.info(f"成功为终端用户 {end_user_id} 生成缓存")
else:
failed += 1
error_info = {
"end_user_id": end_user_id,
"insight_error": insight_result.get("error"),
"summary_error": summary_result.get("error")
}
errors.append(error_info)
logger.warning(f"终端用户 {end_user_id} 的缓存生成部分失败: {error_info}")
except Exception as e:
# 单个用户失败不影响其他用户
failed += 1
error_info = {
"end_user_id": end_user_id,
"error": str(e)
}
errors.append(error_info)
logger.error(f"为终端用户 {end_user_id} 生成缓存时出错: {str(e)}")
# 记录统计信息
logger.info(
f"工作空间 {workspace_id} 批量生成完成: "
f"总数={total_users}, 成功={successful}, 失败={failed}"
)
return {
"total_users": total_users,
"successful": successful,
"failed": failed,
"errors": errors
}
except Exception as e:
logger.error(f"为工作空间 {workspace_id} 批量生成缓存时出错: {str(e)}")
return {
"total_users": total_users,
"successful": successful,
"failed": failed,
"errors": errors + [{"error": f"批量处理失败: {str(e)}"}]
}
# 独立的分析函数
async def analytics_memory_insight_report(end_user_id: Optional[str] = None, language: str = "zh") -> Dict[str, Any]:
    """Generate the four-part memory insight report.

    Workflow:
    1. Use ``MemoryInsightHelper`` to gather base data (domain distribution,
       active periods, social connections).
    2. Render the prompt with a Jinja2 template.
    3. Ask the LLM for a natural-language report covering four dimensions.
    4. Parse the response into its four sections.

    Args:
        end_user_id: Optional end-user id used to scope the analysis.
        language: Report language ("zh" Chinese, "en" English); defaults to Chinese.

    Returns:
        {
            "memory_insight": str,      # overview
            "behavior_pattern": str,    # behavior pattern
            "key_findings": List[str],  # key findings (always a list)
            "growth_trajectory": str    # growth trajectory
        }
    """
    import asyncio
    import re

    from app.core.language_utils import validate_language
    from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt

    # Normalize/validate the language parameter.
    language = validate_language(language)
    insight = MemoryInsightHelper(end_user_id)
    try:
        # 1. Fetch the three data dimensions concurrently.
        domain_dist, active_periods, social_conn = await asyncio.gather(
            insight.get_domain_distribution(),
            insight.get_active_periods(),
            insight.get_social_connections(),
        )
        # 2. Build human-readable data strings for the prompt.
        domain_distribution_str = None
        if domain_dist:
            top_domains = ", ".join([f"{k}({v:.0%})" for k, v in list(domain_dist.items())[:3]])
            domain_distribution_str = f"用户的记忆主要集中在 {top_domains}"
        active_periods_str = None
        if active_periods:
            # Bug fix: join months with an enumeration separator so that
            # [1, 3, 5] reads as "1、3、5" instead of being fused into "135".
            months_str = "、".join(map(str, active_periods))
            active_periods_str = f"用户在每年的 {months_str} 月最为活跃"
        social_connections_str = None
        if social_conn:
            social_connections_str = f"与用户\"{social_conn['user_id']}\"拥有最多共同记忆({social_conn['common_memories_count']}条),时间范围主要在 {social_conn['time_range']}"
        # 3. Without any usable data, return a default "not enough data" payload.
        if not any([domain_distribution_str, active_periods_str, social_connections_str]):
            return {
                "memory_insight": "暂无足够数据生成洞察报告。",
                "behavior_pattern": "",
                # Bug fix: the documented contract for key_findings is List[str];
                # this branch previously returned an empty string.
                "key_findings": [],
                "growth_trajectory": ""
            }
        # 4. Render the prompt via the Jinja2 template helper.
        user_prompt = await render_memory_insight_prompt(
            domain_distribution=domain_distribution_str,
            active_periods=active_periods_str,
            social_connections=social_connections_str,
            language=language
        )
        messages = [
            {"role": "user", "content": user_prompt}
        ]
        # 5. Ask the LLM to generate the report.
        response = await insight.llm_client.chat(messages=messages)
        # 6. Normalize the LLM response to a plain string (providers may return
        #    str, dict, or a list of content parts).
        content = response.content
        if isinstance(content, list):
            if len(content) > 0:
                if isinstance(content[0], dict):
                    text = content[0].get('text', content[0].get('content', str(content[0])))
                    full_response = str(text)
                else:
                    full_response = str(content[0])
            else:
                full_response = ""
        elif isinstance(content, dict):
            full_response = str(content.get('text', content.get('content', str(content))))
        else:
            full_response = str(content) if content is not None else ""
        # 7. Extract the four sections (titles may be Chinese or English).
        memory_insight_match = re.search(r'【(?:总体概述|Overview)】\s*\n(.*?)(?=\n【|$)', full_response, re.DOTALL)
        behavior_match = re.search(r'【(?:行为模式|Behavior Pattern)】\s*\n(.*?)(?=\n【|$)', full_response, re.DOTALL)
        findings_match = re.search(r'【(?:关键发现|Key Findings)】\s*\n(.*?)(?=\n【|$)', full_response, re.DOTALL)
        trajectory_match = re.search(r'【(?:成长轨迹|Growth Trajectory)】\s*\n(.*?)(?=\n【|$)', full_response, re.DOTALL)
        memory_insight = memory_insight_match.group(1).strip() if memory_insight_match else ""
        behavior_pattern = behavior_match.group(1).strip() if behavior_match else ""
        key_findings_text = findings_match.group(1).strip() if findings_match else ""
        growth_trajectory = trajectory_match.group(1).strip() if trajectory_match else ""
        # Convert key findings to a list by splitting on the "•" bullet marker.
        key_findings_array = []
        if key_findings_text:
            key_findings_array = [item.strip() for item in key_findings_text.split('•') if item.strip()]
        return {
            "memory_insight": memory_insight,
            "behavior_pattern": behavior_pattern,
            "key_findings": key_findings_array,  # always a list, never a string
            "growth_trajectory": growth_trajectory
        }
    finally:
        # Always release the helper's underlying connections.
        await insight.close()
async def analytics_user_summary(end_user_id: Optional[str] = None, language: str = "zh") -> Dict[str, Any]:
    """Generate a four-part user summary via the LLM.

    Workflow:
    1. Use ``UserSummaryHelper`` to gather base data (top entities, recent statements).
    2. Render the prompt with ``prompt_utils``.
    3. Ask the LLM for four sections: basic introduction, personality traits,
       core values, and a one-sentence summary.

    Args:
        end_user_id: Optional end-user id; falls back to the ``SELECTED_end_user_id``
            env var (default "group_123") when absent.
        language: Language code ("zh" Chinese, "en" English); defaults to Chinese.

    Returns:
        {
            "user_summary": str,
            "personality": str,
            "core_values": str,
            "one_sentence": str
        }
    """
    import re
    from app.core.language_utils import validate_language
    from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
    from app.repositories.end_user_repository import EndUserRepository
    # Validate the language parameter.
    language = validate_language(language)
    # Resolve the user's display name from end_user.other_name when available;
    # otherwise fall back to a generic placeholder in the requested language.
    user_display_name = "该用户" if language == "zh" else "the user"
    if end_user_id:
        try:
            # Open a DB session and look up the end user record.
            with get_db_context() as db:
                repo = EndUserRepository(db)
                end_user = repo.get_by_id(uuid.UUID(end_user_id))
                if end_user and end_user.other_name:
                    user_display_name = end_user.other_name
                    logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}")
                else:
                    logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}")
        except Exception as e:
            # Best-effort: a lookup failure only degrades the display name.
            logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}")
    # Build the helper bound to the resolved user id.
    user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_end_user_id", "group_123"))
    try:
        # 1) Collect context data for the prompt.
        entities = await user_summary_tool.get_top_entities(limit=40)
        statements = await user_summary_tool.get_recent_statements(limit=100)
        entity_lines = [f"{name} ({freq})" for name, freq in entities][:20]
        statement_samples = [s["statement"].strip() for s in statements if s.get("statement", "").strip()][:20]
        # 2) Render the prompt via prompt_utils.
        user_prompt = await render_user_summary_prompt(
            user_id=user_summary_tool.user_id,
            entities=", ".join(entity_lines) if entity_lines else "(空)" if language == "zh" else "(empty)",
            statements=" | ".join(statement_samples) if statement_samples else "(空)" if language == "zh" else "(empty)",
            language=language,
            user_display_name=user_display_name
        )
        messages = [
            {"role": "user", "content": user_prompt},
        ]
        # 3) Ask the LLM for the summary.
        response = await user_summary_tool.llm.chat(messages=messages)
        # 4) Normalize the LLM response to a plain string (providers may return
        #    str, dict, or a list of content parts).
        content = response.content
        if isinstance(content, list):
            if len(content) > 0:
                if isinstance(content[0], dict):
                    text = content[0].get('text', content[0].get('content', str(content[0])))
                    full_response = str(text)
                else:
                    full_response = str(content[0])
            else:
                full_response = ""
        elif isinstance(content, dict):
            full_response = str(content.get('text', content.get('content', str(content))))
        else:
            full_response = str(content) if content is not None else ""
        # 5) Extract the four sections (titles may be Chinese or English).
        user_summary_match = re.search(r'【(?:基本介绍|Basic Introduction)】\s*\n(.*?)(?=\n【|$)', full_response, re.DOTALL)
        personality_match = re.search(r'【(?:性格特点|Personality Traits)】\s*\n(.*?)(?=\n【|$)', full_response, re.DOTALL)
        core_values_match = re.search(r'【(?:核心价值观|Core Values)】\s*\n(.*?)(?=\n【|$)', full_response, re.DOTALL)
        one_sentence_match = re.search(r'【(?:一句话总结|One-Sentence Summary)】\s*\n(.*?)(?=\n【|$)', full_response, re.DOTALL)
        user_summary = user_summary_match.group(1).strip() if user_summary_match else ""
        personality = personality_match.group(1).strip() if personality_match else ""
        core_values = core_values_match.group(1).strip() if core_values_match else ""
        one_sentence = one_sentence_match.group(1).strip() if one_sentence_match else ""
        # 6) Defensive cleanup: strip any self-reflection content the LLM may
        #    have emitted despite prompt instructions.
        def clean_reflection_content(text: str) -> str:
            """Remove trailing self-reflection content from a section."""
            if not text:
                return text
            # Drop everything after "---" (usually the start of a reflection part).
            if '---' in text:
                text = text.split('---')[0].strip()
            # Drop content starting at "**Step".
            if '**Step' in text:
                text = text.split('**Step')[0].strip()
            # Drop "Self-Review" sections.
            if 'Self-Review' in text or 'self-review' in text:
                text = re.sub(r'[\-\*]*\s*Self-Review.*$', '', text, flags=re.IGNORECASE | re.DOTALL).strip()
            return text
        user_summary = clean_reflection_content(user_summary)
        personality = clean_reflection_content(personality)
        core_values = clean_reflection_content(core_values)
        one_sentence = clean_reflection_content(one_sentence)
        return {
            "user_summary": user_summary,
            "personality": personality,
            "core_values": core_values,
            "one_sentence": one_sentence
        }
    finally:
        # Always release the helper's underlying connections.
        await user_summary_tool.close()
async def analytics_node_statistics(
    db: Session,
    end_user_id: Optional[str] = None
) -> Dict[str, Any]:
    """Count the four Neo4j node types and compute each type's percentage.

    Args:
        db: Database session (kept for signature compatibility).
        end_user_id: Optional end-user id (UUID string); when given, only that
            user's nodes are counted.

    Returns:
        {
            "total": int,               # total node count across all four types
            "nodes": [
                {"type": str, "count": int, "percentage": float},
                ...
            ]
        }
    """
    labels = ("Chunk", "MemorySummary", "Statement", "ExtractedEntity")
    counts: Dict[str, int] = {}

    # Issue one count query per label, scoped by end_user_id when provided.
    for label in labels:
        if end_user_id:
            cypher = f"""
            MATCH (n:{label})
            WHERE n.end_user_id = $end_user_id
            RETURN count(n) as count
            """
            rows = await _neo4j_connector.execute_query(cypher, end_user_id=end_user_id)
        else:
            cypher = f"""
            MATCH (n:{label})
            RETURN count(n) as count
            """
            rows = await _neo4j_connector.execute_query(cypher)
        counts[label] = rows[0]["count"] if rows else 0

    grand_total = sum(counts.values())
    # Percentages are rounded to 2 decimals; 0.0 when there are no nodes.
    nodes = [
        {
            "type": label,
            "count": counts[label],
            "percentage": round(counts[label] / grand_total * 100, 2) if grand_total > 0 else 0.0,
        }
        for label in labels
    ]
    return {"total": grand_total, "nodes": nodes}
async def analytics_memory_types(
    db: Session,
    end_user_id: Optional[str] = None
) -> List[Dict[str, Any]]:
    """Compute counts and percentages for the eight memory types.

    Counting rules:
    1. PERCEPTUAL_MEMORY: total from ``MemoryPerceptualService.get_memory_count``.
    2. WORKING_MEMORY: number of active conversations
       (``ConversationRepository.get_conversation_by_user_id``).
    3. SHORT_TERM_MEMORY: number of Q/A pairs returned by the /short_term endpoint.
    4. EXPLICIT_MEMORY: episodic + semantic
       (``MemoryBaseService.get_explicit_memory_count``).
    5. IMPLICIT_MEMORY: MemorySummary node count (must be
       >= MIN_MEMORY_SUMMARY_COUNT to show; otherwise 0).
    6. EMOTIONAL_MEMORY: total emotion-tag count
       (``MemoryBaseService.get_emotional_memory_count``).
    7. EPISODIC_MEMORY: memory_summary count
       (``MemoryBaseService.get_episodic_memory_count``).
    8. FORGET_MEMORY: nodes whose activation is below the forgetting threshold
       (``MemoryBaseService.get_forget_memory_count``).

    Args:
        db: Database session.
        end_user_id: Optional end-user id (UUID string) used to scope the counts.

    Returns:
        A list of dicts, one per memory type:
        ``[{"type": str, "count": int, "percentage": float}, ...]``
        where ``type`` is one of PERCEPTUAL_MEMORY, WORKING_MEMORY,
        SHORT_TERM_MEMORY, EXPLICIT_MEMORY, IMPLICIT_MEMORY, EMOTIONAL_MEMORY,
        EPISODIC_MEMORY, FORGET_MEMORY.
    """
    # Shared base service used by the special memory-type counters below.
    base_service = MemoryBaseService()
    # Perceptual-memory service (DB backed).
    perceptual_service = MemoryPerceptualService(db)
    # Perceptual memory count; without a user id there is nothing to scope by.
    if end_user_id:
        perceptual_stats = perceptual_service.get_memory_count(uuid.UUID(end_user_id))
        perceptual_count = perceptual_stats.get("total", 0)
    else:
        perceptual_count = 0
    # Working memory = number of active conversations for the user.
    work_count = 0
    if end_user_id:
        try:
            conversation_repo = ConversationRepository(db)
            conversations, total = conversation_repo.get_conversation_by_user_id(
                user_id=uuid.UUID(end_user_id),
                is_activate=True
            )
            work_count = total
            logger.debug(f"工作记忆数量(会话数): {work_count} (end_user_id={end_user_id})")
        except Exception as e:
            # Best-effort: a lookup failure degrades the count to 0.
            logger.warning(f"获取会话数量失败工作记忆数量设为0: {str(e)}")
            work_count = 0
    # Implicit memory = count of linked MemorySummary nodes; only counted when
    # it reaches MIN_MEMORY_SUMMARY_COUNT.
    implicit_count = 0
    if end_user_id:
        try:
            memory_summary_count = await base_service.get_valid_memory_summary_count(end_user_id)
            implicit_count = memory_summary_count if memory_summary_count >= MIN_MEMORY_SUMMARY_COUNT else 0
            logger.debug(f"隐性记忆数量有效MemorySummary节点数: {implicit_count} (有效MemorySummary总数={memory_summary_count}, end_user_id={end_user_id})")
        except Exception as e:
            logger.warning(f"获取MemorySummary数量失败隐性记忆数量设为0: {str(e)}")
            implicit_count = 0
    # Legacy behavior-habit-based counting approach (kept for reference):
    # implicit_count = 0
    # if end_user_id:
    #     try:
    #         implicit_service = ImplicitMemoryService(db, end_user_id)
    #         behavior_habits = await implicit_service.get_behavior_habits(
    #             user_id=end_user_id
    #         )
    #         implicit_count = len(behavior_habits)
    #         logger.debug(f"隐性记忆数量(行为习惯数): {implicit_count} (end_user_id={end_user_id})")
    #     except Exception as e:
    #         logger.warning(f"获取行为习惯数量失败隐性记忆数量设为0: {str(e)}")
    #         implicit_count = 0
    # Short-term memory = number of Q/A pairs from the /short_term data source.
    short_term_count = 0
    if end_user_id:
        try:
            short_term_service = ShortService(end_user_id, db)
            short_term_data = short_term_service.get_short_databasets()
            # The count is simply the length of the short_term array.
            if short_term_data:
                short_term_count = len(short_term_data)
            logger.debug(f"短期记忆数量(问答对数): {short_term_count} (end_user_id={end_user_id})")
        except Exception as e:
            logger.warning(f"获取短期记忆数量失败短期记忆数量设为0: {str(e)}")
            short_term_count = 0
    # Resolve the user's configured forgetting threshold (defaults to 0.3).
    forgetting_threshold = 0.3  # default value
    if end_user_id:
        try:
            from app.core.memory.storage_services.forgetting_engine.config_utils import (
                load_actr_config_from_db,
            )
            from app.services.memory_agent_service import get_end_user_connected_config
            # Resolve the config_id linked to this end user.
            connected_config = get_end_user_connected_config(end_user_id, db)
            config_id = connected_config.get('memory_config_id')
            if config_id:
                # Load the ACT-R config from the database.
                config = load_actr_config_from_db(db, config_id)
                forgetting_threshold = config.get('forgetting_threshold', 0.3)
                logger.debug(f"使用用户配置的遗忘阈值: {forgetting_threshold} (end_user_id={end_user_id}, config_id={config_id})")
            else:
                logger.debug(f"用户未关联配置,使用默认遗忘阈值: {forgetting_threshold} (end_user_id={end_user_id})")
        except Exception as e:
            logger.warning(f"获取用户遗忘阈值配置失败,使用默认值 {forgetting_threshold}: {str(e)}")
    # Fetch the special memory-type counts via the shared base-service helpers.
    episodic_count = await base_service.get_episodic_memory_count(end_user_id)
    explicit_count = await base_service.get_explicit_memory_count(end_user_id)
    emotion_count = await base_service.get_emotional_memory_count(end_user_id, perceptual_count)
    forget_count = await base_service.get_forget_memory_count(end_user_id, forgetting_threshold)
    # Assemble the eight memory-type counts, keyed by their English enum names.
    memory_counts = {
        "PERCEPTUAL_MEMORY": perceptual_count,  # perceptual memory
        "WORKING_MEMORY": work_count,  # working memory (conversation count)
        "SHORT_TERM_MEMORY": short_term_count,  # short-term memory (Q/A pair count)
        "EXPLICIT_MEMORY": explicit_count,  # explicit memory (episodic + semantic)
        "IMPLICIT_MEMORY": implicit_count,  # implicit memory (MemorySummary nodes, gated by MIN_MEMORY_SUMMARY_COUNT)
        "EMOTIONAL_MEMORY": emotion_count,  # emotional memory (emotion-tag stats)
        "EPISODIC_MEMORY": episodic_count,  # episodic memory
        "FORGET_MEMORY": forget_count  # forgotten memory (activation below threshold)
    }
    # Grand total across all eight types.
    total = sum(memory_counts.values())
    # Build the response items with type, count, and percentage.
    memory_types = []
    for memory_type, count in memory_counts.items():
        percentage = round((count / total * 100), 2) if total > 0 else 0.0
        memory_types.append({
            "type": memory_type,
            "count": count,
            "percentage": percentage
        })
    return memory_types
async def analytics_graph_data(
    db: Session,
    end_user_id: str,
    node_types: Optional[List[str]] = None,
    limit: int = 130,
    depth: int = 1,
    center_node_id: Optional[str] = None
) -> Dict[str, Any]:
    """Fetch Neo4j graph data for front-end visualization.

    Args:
        db: Database session (used to validate the end user exists).
        end_user_id: End-user id (UUID string).
        node_types: Optional list of node labels to filter by.
        limit: Maximum number of nodes to return.
        depth: Traversal depth when expanding from a center node.
        center_node_id: Optional center-node elementId for neighborhood queries.

    Returns:
        A dict with ``nodes``, ``edges``, and ``statistics`` (plus ``message``
        on the "user not found" / "bad id" paths).
    """
    try:
        # 1. Validate the end user exists.
        user_uuid = uuid.UUID(end_user_id)
        repo = EndUserRepository(db)
        end_user = repo.get_by_id(user_uuid)
        if not end_user:
            logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户")
            return {
                "nodes": [],
                "edges": [],
                "statistics": {
                    "total_nodes": 0,
                    "total_edges": 0,
                    "node_types": {},
                    "edge_types": {}
                },
                "message": "用户不存在"
            }
        # 2. Build the node query for the requested mode.
        if center_node_id:
            # Neighborhood expansion around a center node, up to `depth` hops.
            node_query = f"""
            MATCH path = (center)-[*1..{depth}]-(connected)
            WHERE center.end_user_id = $end_user_id
            AND elementId(center) = $center_node_id
            WITH collect(DISTINCT center) + collect(DISTINCT connected) as all_nodes
            UNWIND all_nodes as n
            RETURN DISTINCT
            elementId(n) as id,
            labels(n)[0] as label,
            properties(n) as properties
            LIMIT $limit
            """
            node_params = {
                "end_user_id": end_user_id,
                "center_node_id": center_node_id,
                "limit": limit
            }
        elif node_types:
            # Filter by node label.
            node_query = """
            MATCH (n)
            WHERE n.end_user_id = $end_user_id
            AND labels(n)[0] IN $node_types
            RETURN
            elementId(n) as id,
            labels(n)[0] as label,
            properties(n) as properties
            LIMIT $limit
            """
            node_params = {
                "end_user_id": end_user_id,
                "node_types": node_types,
                "limit": limit
            }
        else:
            # Default: all of the user's nodes (shared query constant).
            node_query = Graph_Node_query
            node_params = {
                "end_user_id": end_user_id,
                "limit": limit
            }
        # Run the node query.
        node_results = await _neo4j_connector.execute_query(node_query, **node_params)
        # 3. Format node records.
        nodes = []
        node_ids = []
        node_type_counts = {}
        for record in node_results:
            node_id = record["id"]
            # Bug fix: the queries above RETURN `label` (singular), but this
            # loop previously read a non-existent `labels` key, so every node
            # was reported as "Unknown". Read `label` first, then fall back to
            # a `labels` list for queries that return one.
            node_label = record.get("label")
            if not node_label:
                label_list = record.get("labels") or []
                node_label = label_list[0] if label_list else "Unknown"
            node_props = record["properties"]
            # Keep only the whitelisted properties for this node type.
            filtered_props = await _extract_node_properties(node_label, node_props, node_id)
            # Prefer the stored caption; fall back to the node label.
            caption = filtered_props.get("caption", node_label)
            nodes.append({
                "id": node_id,
                "label": node_label,
                "properties": filtered_props,
                "caption": caption
            })
            node_ids.append(node_id)
            node_type_counts[node_label] = node_type_counts.get(node_label, 0) + 1
        # 4. Fetch relationships among the returned nodes only.
        if len(node_ids) > 0:
            edge_query = """
            MATCH (n)-[r]->(m)
            WHERE elementId(n) IN $node_ids
            AND elementId(m) IN $node_ids
            RETURN
            elementId(r) as id,
            elementId(n) as source,
            elementId(m) as target,
            type(r) as rel_type,
            properties(r) as properties
            """
            edge_results = await _neo4j_connector.execute_query(
                edge_query,
                node_ids=node_ids
            )
        else:
            edge_results = []
        # 5. Format edge records.
        edges = []
        edge_type_counts = {}
        for record in edge_results:
            edge_id = record["id"]
            source = record["source"]
            target = record["target"]
            rel_type = record["rel_type"]
            edge_props = record["properties"]
            # Keep all edge properties but convert Neo4j-specific value types.
            cleaned_edge_props = {}
            if edge_props:
                for key, value in edge_props.items():
                    cleaned_edge_props[key] = _clean_neo4j_value(value)
            # Prefer an explicit caption property; fall back to the relationship type.
            caption = cleaned_edge_props.get("caption", rel_type)
            edges.append({
                "id": edge_id,
                "source": source,
                "target": target,
                "type": rel_type,
                "properties": cleaned_edge_props,
                "caption": caption
            })
            edge_type_counts[rel_type] = edge_type_counts.get(rel_type, 0) + 1
        # 6. Assemble statistics.
        statistics = {
            "total_nodes": len(nodes),
            "total_edges": len(edges),
            "node_types": node_type_counts,
            "edge_types": edge_type_counts
        }
        logger.info(
            f"成功获取图数据: end_user_id={end_user_id}, "
            f"nodes={len(nodes)}, edges={len(edges)}"
        )
        return {
            "nodes": nodes,
            "edges": edges,
            "statistics": statistics
        }
    except ValueError:
        # end_user_id was not a valid UUID string.
        logger.error(f"无效的 end_user_id 格式: {end_user_id}")
        return {
            "nodes": [],
            "edges": [],
            "statistics": {
                "total_nodes": 0,
                "total_edges": 0,
                "node_types": {},
                "edge_types": {}
            },
            "message": "无效的用户ID格式"
        }
    except Exception as e:
        logger.error(f"获取图数据失败: {str(e)}", exc_info=True)
        raise
# 辅助函数
async def analytics_community_graph_data(
    db: Session,
    end_user_id: str,
) -> Dict[str, Any]:
    """Fetch the community graph: Community nodes, ExtractedEntity nodes, and
    their relationships.

    Args:
        db: Database session (used to validate the end user exists).
        end_user_id: End-user id (UUID string).

    Returns:
        A dict with ``nodes``, ``edges``, ``statistics`` in the same shape as
        ``analytics_graph_data``.
    """
    try:
        user_uuid = uuid.UUID(end_user_id)
        repo = EndUserRepository(db)
        end_user = repo.get_by_id(user_uuid)
        if not end_user:
            return {
                "nodes": [], "edges": [],
                "statistics": {"total_nodes": 0, "total_edges": 0, "node_types": {}, "edge_types": {}},
                "message": "用户不存在"
            }
        # One query returns community nodes, entity nodes, BELONGS_TO_COMMUNITY
        # edges, and entity-to-entity relationships.
        from app.repositories.neo4j.cypher_queries import GET_COMMUNITY_GRAPH_DATA
        rows = await _neo4j_connector.execute_query(GET_COMMUNITY_GRAPH_DATA, end_user_id=end_user_id)
        nodes_map: Dict[str, dict] = {}
        edges_map: Dict[str, dict] = {}
        # Entity elementIds grouped by their Community elementId.
        community_members: Dict[str, list] = {}
        for row in rows:
            # Community node.
            c_id = row["c_id"]
            if c_id and c_id not in nodes_map:
                raw = row["c_props"] or {}
                props = {k: _clean_neo4j_value(raw.get(k)) for k in (
                    "community_id", "end_user_id", "member_count", "updated_at",
                    "name", "summary", "core_entities",
                ) if k in raw}
                nodes_map[c_id] = {
                    "id": c_id,
                    "label": "Community",
                    "properties": props,
                }
            # ExtractedEntity node (e).
            e_id = row["e_id"]
            if e_id and e_id not in nodes_map:
                raw = row["e_props"] or {}
                props = {k: _clean_neo4j_value(raw.get(k)) for k in (
                    "name", "end_user_id", "description", "created_at", "entity_type",
                ) if k in raw}
                # Inject the owning community's name (c is e's direct community).
                c_raw = row["c_props"] or {}
                props["community_name"] = _clean_neo4j_value(c_raw.get("name")) or ""
                nodes_map[e_id] = {
                    "id": e_id,
                    "label": "ExtractedEntity",
                    "properties": props,
                }
            # ExtractedEntity node (e2, optional).
            e2_id = row.get("e2_id")
            if e2_id and e2_id not in nodes_map:
                raw = row["e2_props"] or {}
                props = {k: _clean_neo4j_value(raw.get(k)) for k in (
                    "name", "end_user_id", "description", "created_at", "entity_type",
                ) if k in raw}
                # e2's community membership is back-filled later via community_members.
                props["community_name"] = ""
                nodes_map[e2_id] = {
                    "id": e2_id,
                    "label": "ExtractedEntity",
                    "properties": props,
                }
            # BELONGS_TO_COMMUNITY edge.
            b_id = row["b_id"]
            if b_id and b_id not in edges_map:
                edges_map[b_id] = {
                    "id": b_id,
                    "source": e_id,
                    "target": c_id,
                    # Bug fix: edges previously lacked type/properties although
                    # the docstring promises the analytics_graph_data shape.
                    "type": "BELONGS_TO_COMMUNITY",
                    "properties": {},
                }
            # Collect community member ids.
            if c_id and e_id:
                community_members.setdefault(c_id, [])
                if e_id not in community_members[c_id]:
                    community_members[c_id].append(e_id)
            # Entity-to-entity relationship edge (optional).
            r_id = row.get("r_id")
            if r_id and r_id not in edges_map and e2_id:
                # Bug fix: r_props was computed but never attached to the edge.
                r_props = {k: _clean_neo4j_value(v) for k, v in (row["r_props"] or {}).items()}
                source = e_id if row.get("r_from_e") else e2_id
                target = e2_id if row.get("r_from_e") else e_id
                edges_map[r_id] = {
                    "id": r_id,
                    "source": source,
                    "target": target,
                    "type": row.get("r_type") or "EXTRACTED_RELATIONSHIP",
                    "properties": r_props,
                }
        nodes = list(nodes_map.values())
        edges = list(edges_map.values())
        # Inject member_entity_ids into each Community node and back-fill the
        # community_name of entities first seen as e2.
        for c_id, member_ids in community_members.items():
            c_node = nodes_map.get(c_id)
            if c_node:
                c_node["properties"]["member_entity_ids"] = member_ids
                c_name = c_node["properties"].get("name") or ""
                # Fill in entities that belong to this community but whose
                # community_name is still empty (i.e. e2 nodes).
                for eid in member_ids:
                    e_node = nodes_map.get(eid)
                    if e_node and e_node["label"] == "ExtractedEntity":
                        if not e_node["properties"].get("community_name"):
                            e_node["properties"]["community_name"] = c_name
        node_type_counts: Dict[str, int] = {}
        for n in nodes:
            node_type_counts[n["label"]] = node_type_counts.get(n["label"], 0) + 1
        # Bug fix: the success path previously omitted edge_types while both
        # error paths included it.
        edge_type_counts: Dict[str, int] = {}
        for edge in edges:
            etype = edge.get("type", "")
            edge_type_counts[etype] = edge_type_counts.get(etype, 0) + 1
        return {
            "nodes": nodes,
            "edges": edges,
            "statistics": {
                "total_nodes": len(nodes),
                "total_edges": len(edges),
                "node_types": node_type_counts,
                "edge_types": edge_type_counts,
            }
        }
    except ValueError:
        # end_user_id was not a valid UUID string.
        logger.error(f"无效的 end_user_id 格式: {end_user_id}")
        return {
            "nodes": [], "edges": [],
            "statistics": {"total_nodes": 0, "total_edges": 0, "node_types": {}, "edge_types": {}},
            "message": "无效的用户ID格式"
        }
    except Exception as e:
        logger.error(f"获取社区图谱数据失败: {str(e)}", exc_info=True)
        raise
async def _extract_node_properties(label: str, properties: Dict[str, Any], node_id: str) -> Dict[str, Any]:
    """Filter a node's properties down to the whitelist for its label.

    Also attaches ``associative_memory``: the number of relationships the node
    participates in, queried from Neo4j by elementId.

    Args:
        label: Node label (e.g. "Statement", "ExtractedEntity").
        properties: All raw node properties.
        node_id: The node's Neo4j elementId, used for the relationship count.

    Returns:
        The filtered property dict with Neo4j-specific values cleaned.
    """
    # Whitelisted fields per node label.
    field_whitelist = {
        "Dialogue": ["content", "created_at"],
        "Chunk": ["content", "created_at"],
        "Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption", "emotion_keywords", "emotion_type", "emotion_subject"],
        "ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption", "aliases", "connect_strength"],
        "MemorySummary": ["summary", "content", "created_at", "caption"],  # includes the content field
        "Perceptual": ["file_name", "file_path", "file_type", "domain", "topic", "keywords", "summary"]
    }
    # Unknown labels get an empty whitelist (only associative_memory is returned).
    allowed_fields = field_whitelist.get(label, [])
    # Count this node's relationships. Bug fix: pass node_id as a query
    # parameter instead of interpolating it into the Cypher string, which
    # would break (or be injectable) if the id contained a quote.
    count_neo4j = "MATCH (n)-[r]-(m) WHERE elementId(n) = $node_id RETURN count(r) AS rel_count"
    node_results = await _neo4j_connector.execute_query(count_neo4j, node_id=node_id)
    # Keep only whitelisted fields, mapping enum codes to display values.
    filtered_props = {}
    for field in allowed_fields:
        if field in properties:
            value = properties[field]
            if field == 'entity_type':
                value = type_mapping.get(value, '')
            if field == "emotion_type":
                value = EmotionType.EMOTION_MAPPING.get(value)
            if field == "emotion_subject":
                value = EmotionSubject.SUBJECT_MAPPING.get(value)
            filtered_props[field] = _clean_neo4j_value(value)
    # Bug fix: guard against an empty result set instead of raising IndexError.
    filtered_props['associative_memory'] = node_results[0]['rel_count'] if node_results else 0
    return filtered_props
def _clean_neo4j_value(value: Any) -> Any:
"""
清理单个值的 Neo4j 特殊类型
Args:
value: 需要清理的值
Returns:
清理后的值
"""
if value is None:
return None
# 处理列表
if isinstance(value, list):
return [_clean_neo4j_value(item) for item in value]
# 处理字典
if isinstance(value, dict):
return {k: _clean_neo4j_value(v) for k, v in value.items()}
# 处理 Neo4j DateTime 类型
if hasattr(value, '__class__') and 'neo4j.time' in str(type(value)):
try:
if hasattr(value, 'to_native'):
native_dt = value.to_native()
return native_dt.isoformat()
return str(value)
except Exception:
return str(value)
# 处理其他 Neo4j 特殊类型
if hasattr(value, '__class__') and 'neo4j' in str(type(value)):
try:
return str(value)
except Exception:
return None
# 返回原始值
return value