Fix/memory insights (#30)

* [fix]fix memory insights

* [fix]fix memory insights

* [fix] Apply the code corrections suggested by sourcery-ai
This commit is contained in:
乐力齐
2026-01-06 14:05:15 +08:00
committed by GitHub
parent 85c7e531e4
commit a0f19ace92
5 changed files with 294 additions and 516 deletions

View File

@@ -7,7 +7,6 @@ User Memory Service
import os
import uuid
from collections import Counter
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
@@ -22,7 +21,269 @@ from sqlalchemy.orm import Session
# Logger for this module, obtained from the project's get_logger helper.
logger = get_logger(__name__)
# Neo4j connector instance for analytics functions
_neo4j_connector = Neo4jConnector()
# Default LLM ID for fallback: read from the SELECTED_LLM_ID environment
# variable, or "openai/qwen-plus" when that variable is not set.
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
# ============================================================================
# Internal Helper Classes
# ============================================================================
class TagClassification(BaseModel):
    """Structured response model holding the single domain assigned to a tag."""

    # `default=...` (Ellipsis) marks the field as required. The examples list
    # the domain names the classification prompt asks the model to pick from.
    domain: str = Field(
        default=...,
        description="The domain the tag belongs to, chosen from the predefined list.",
        examples=["教育", "学习", "工作", "旅行", "家庭", "运动", "社交", "娱乐", "健康", "其他"],
    )
def _get_llm_client_for_user(user_id: str):
    """
    Build the LLM client to use for a given user.

    Inside a database context, the user's connected config is looked up. When
    it carries a ``memory_config_id``, that memory config is loaded and its
    ``llm_model_id`` selects the client; otherwise ``DEFAULT_LLM_ID`` is used.
    Any exception raised along the way is logged as a warning and the default
    client is returned instead.

    Args:
        user_id: User ID to get config for

    Returns:
        LLM client instance
    """
    with get_db_context() as db:
        try:
            # Imported at call time, as the original code does.
            from app.services.memory_agent_service import get_end_user_connected_config

            config_id = get_end_user_connected_config(user_id, db).get("memory_config_id")
            if not config_id:
                # No memory config is linked to this user: use the default model.
                return MemoryClientFactory(db).get_llm_client(DEFAULT_LLM_ID)

            memory_config = MemoryConfigService(db).load_memory_config(config_id)
            return MemoryClientFactory(db).get_llm_client(memory_config.llm_model_id)
        except Exception as e:
            logger.warning(f"Failed to get user connected config, using default LLM: {e}")
            return MemoryClientFactory(db).get_llm_client(DEFAULT_LLM_ID)
class MemoryInsightHelper:
"""
Internal helper class for memory insight analysis.
Provides basic data retrieval and analysis functionality.
"""
def __init__(self, user_id: str):
self.user_id = user_id
self.neo4j_connector = Neo4jConnector()
self.llm_client = _get_llm_client_for_user(user_id)
async def close(self):
"""Close database connection."""
await self.neo4j_connector.close()
async def get_domain_distribution(self) -> dict[str, float]:
"""Calculate the distribution of memory domains based on hot tags."""
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
hot_tags = await get_hot_memory_tags(self.user_id)
if not hot_tags:
return {}
domain_counts = Counter()
for tag, _ in hot_tags:
prompt = f"""请将以下标签归类到最合适的领域中。
可选领域及其关键词:
- 教育:学校、课程、考试、培训、教学、学科、教师、学生、班级、作业、成绩、毕业、入学、校园、大学、中学、小学、教材、学位等
- 学习:自学、阅读、书籍、技能提升、知识积累、笔记、复习、练习、研究、历史知识、科学知识、文化知识、学术讨论、知识问答等
- 工作:职业、项目、会议、同事、业务、公司、办公、任务、客户、合同、职场、工作计划等
- 旅行:旅游、景点、出行、度假、酒店、机票、导游、风景、旅行计划等
- 家庭:亲人、父母、子女、配偶、家事、家庭活动、亲情、家庭聚会等
- 运动:健身、体育、锻炼、跑步、游泳、球类、瑜伽、运动计划等
- 社交:朋友、聚会、社交活动、派对、聊天、交友、社交网络等
- 娱乐:游戏、电影、音乐、休闲、综艺、动漫、小说、娱乐活动等
- 健康:医疗、养生、心理健康、体检、药物、疾病、保健、健康管理等
- 其他:确实无法归入以上任何类别的内容
标签: {tag}
分析步骤:
1. 仔细理解标签的核心含义和使用场景
2. 对比各个领域的关键词,找到最匹配的领域
3. 特别注意:
- 历史、科学、文化等知识性内容应归类为"学习"
- 学校、课程、考试等正式教育场景应归类为"教育"
- 只有在标签完全不属于上述9个具体领域时才选择"其他"
4. 如果标签与某个领域有任何相关性,就选择该领域,不要选"其他"
请直接返回最合适的领域名称。"""
messages = [
{"role": "system", "content": "你是一个专业的标签分类助手。你必须仔细分析标签的实际含义和使用场景优先选择9个具体领域之一。'其他'类别只用于完全无法归类的极少数情况。特别注意:历史、科学、文化等知识性对话应归类为'学习'领域;学校、课程、考试等正式教育场景应归类为'教育'领域。"},
{"role": "user", "content": prompt}
]
classification = await self.llm_client.response_structured(
messages=messages,
response_model=TagClassification,
)
if classification and hasattr(classification, 'domain') and classification.domain:
domain_counts[classification.domain] += 1
total_tags = sum(domain_counts.values())
if total_tags == 0:
return {}
domain_distribution = {
domain: count / total_tags for domain, count in domain_counts.items()
}
return dict(sorted(domain_distribution.items(), key=lambda item: item[1], reverse=True))
async def get_active_periods(self) -> list[int]:
"""
Identify the top 2 most active months for the user.
Only returns months if there is valid and diverse time data.
"""
query = """
MATCH (d:Dialogue)
WHERE d.group_id = $group_id AND d.created_at IS NOT NULL AND d.created_at <> ''
RETURN d.created_at AS creation_time
"""
records = await self.neo4j_connector.execute_query(query, group_id=self.user_id)
if not records:
return []
month_counts = Counter()
valid_dates_count = 0
for record in records:
creation_time_str = record.get("creation_time")
if not creation_time_str:
continue
try:
dt_object = datetime.fromisoformat(creation_time_str.replace("Z", "+00:00"))
month_counts[dt_object.month] += 1
valid_dates_count += 1
except (ValueError, TypeError, AttributeError):
continue
if not month_counts or valid_dates_count == 0:
return []
# Check if time distribution is too concentrated (likely batch imported data)
unique_months = len(month_counts)
if unique_months <= 2:
most_common_count = month_counts.most_common(1)[0][1]
if most_common_count / valid_dates_count > 0.8:
return []
if unique_months >= 3:
most_common_months = month_counts.most_common(2)
return [month for month, _ in most_common_months]
if unique_months == 2:
counts = list(month_counts.values())
ratio = min(counts) / max(counts)
if ratio > 0.3:
most_common_months = month_counts.most_common(2)
return [month for month, _ in most_common_months]
return []
async def get_social_connections(self) -> dict | None:
"""Find the user with whom the most memories are shared."""
query = """
MATCH (c1:Chunk {group_id: $group_id})
OPTIONAL MATCH (c1)-[:CONTAINS]->(s:Statement)
OPTIONAL MATCH (s)<-[:CONTAINS]-(c2:Chunk)
WHERE c1.group_id <> c2.group_id AND s IS NOT NULL AND c2 IS NOT NULL
WITH c2.group_id AS other_user_id, COUNT(DISTINCT s) AS common_statements
WHERE common_statements > 0
RETURN other_user_id, common_statements
ORDER BY common_statements DESC
LIMIT 1
"""
records = await self.neo4j_connector.execute_query(query, group_id=self.user_id)
if not records or not records[0].get("other_user_id"):
return None
most_connected_user = records[0]["other_user_id"]
common_memories_count = records[0]["common_statements"]
time_range_query = """
MATCH (c:Chunk)
WHERE c.group_id IN [$user_id, $other_user_id]
RETURN min(c.created_at) AS start_time, max(c.created_at) AS end_time
"""
time_records = await self.neo4j_connector.execute_query(
time_range_query,
user_id=self.user_id,
other_user_id=most_connected_user
)
start_year, end_year = "N/A", "N/A"
if time_records and time_records[0]["start_time"]:
start_year = datetime.fromisoformat(time_records[0]["start_time"].replace("Z", "+00:00")).year
end_year = datetime.fromisoformat(time_records[0]["end_time"].replace("Z", "+00:00")).year
return {
"user_id": most_connected_user,
"common_memories_count": common_memories_count,
"time_range": f"{start_year}-{end_year}",
}
class UserSummaryHelper:
    """
    Internal helper that fetches the raw material for a user summary.

    It exposes the user's most recent statements (read from Neo4j) and the
    user's top entities (delegated to the hot-memory-tag analytics).
    """

    def __init__(self, user_id: str):
        self.user_id = user_id
        self.connector = Neo4jConnector()
        self.llm = _get_llm_client_for_user(user_id)

    async def close(self):
        """Release the Neo4j connection held by this helper."""
        await self.connector.close()

    async def get_recent_statements(self, limit: int = 80) -> List[Dict[str, Any]]:
        """Return up to ``limit`` of the newest statements stored for this user/group.

        Each item is a dict with the keys ``statement`` (empty string when the
        row has none) and ``created_at``. Rows that cannot be read are skipped.
        """
        cypher = " ".join((
            "MATCH (s:Statement)",
            "WHERE s.group_id = $group_id AND s.statement IS NOT NULL",
            "RETURN s.statement AS statement, s.created_at AS created_at",
            "ORDER BY created_at DESC LIMIT $limit",
        ))
        rows = await self.connector.execute_query(cypher, group_id=self.user_id, limit=limit)

        collected = []
        for row in rows:
            try:
                entry = {
                    "statement": row.get("statement", ""),
                    "created_at": row.get("created_at"),
                }
            except Exception:
                # Best effort: a row that cannot be read is dropped.
                continue
            collected.append(entry)
        return collected

    async def get_top_entities(self, limit: int = 30) -> List[Tuple[str, int]]:
        """Return the user's hot memory tags as ``(name, frequency)`` pairs."""
        from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags

        top_entities = await get_hot_memory_tags(self.user_id, limit=limit)
        return top_entities
# ============================================================================
# Service Class
# ============================================================================
# ============================================================================
# Service Class
# ============================================================================
class UserMemoryService:
@@ -601,7 +862,7 @@ async def analytics_memory_insight_report(end_user_id: Optional[str] = None) ->
生成记忆洞察报告(四个维度)
这个函数包含完整的业务逻辑:
1. 使用 MemoryInsight 工具类获取基础数据(领域分布、活跃时段、社交关联)
1. 使用 MemoryInsightHelper 工具类获取基础数据(领域分布、活跃时段、社交关联)
2. 使用 Jinja2 模板渲染提示词
3. 调用 LLM 生成四个维度的自然语言报告
4. 解析并返回四个部分
@@ -620,7 +881,7 @@ async def analytics_memory_insight_report(end_user_id: Optional[str] = None) ->
from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt
import re
insight = MemoryInsight(end_user_id)
insight = MemoryInsightHelper(end_user_id)
try:
# 1. 并行获取三个维度的数据
@@ -722,7 +983,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str,
生成用户摘要(包含四个部分)
这个函数包含完整的业务逻辑:
1. 使用 UserSummary 工具类获取基础数据(实体、语句)
1. 使用 UserSummaryHelper 工具类获取基础数据(实体、语句)
2. 使用 prompt_utils 渲染提示词
3. 调用 LLM 生成四部分内容:基本介绍、性格特点、核心价值观、一句话总结
@@ -737,20 +998,19 @@ async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str,
"one_sentence": str
}
"""
from app.core.memory.analytics.user_summary import UserSummary
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
import re
# 创建 UserSummary 实例
user_summary_tool = UserSummary(end_user_id or os.getenv("SELECTED_GROUP_ID", "group_123"))
# 创建 UserSummaryHelper 实例
user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_GROUP_ID", "group_123"))
try:
# 1) 收集上下文数据
entities = await user_summary_tool._get_top_entities(limit=40)
statements = await user_summary_tool._get_recent_statements(limit=100)
entities = await user_summary_tool.get_top_entities(limit=40)
statements = await user_summary_tool.get_recent_statements(limit=100)
entity_lines = [f"{name} ({freq})" for name, freq in entities][:20]
statement_samples = [s.statement.strip() for s in statements if (s.statement or '').strip()][:20]
statement_samples = [s["statement"].strip() for s in statements if s.get("statement", "").strip()][:20]
# 2) 使用 prompt_utils 渲染提示词
user_prompt = await render_user_summary_prompt(