Fix/memory insights (#30)
* [fix]fix memory insights * [fix]fix memory insights * [fix]Based on the correction of the code by sourcery-ai
This commit is contained in:
@@ -13,7 +13,6 @@ from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
from app.services.user_memory_service import (
|
||||
UserMemoryService,
|
||||
analytics_node_statistics,
|
||||
analytics_memory_types,
|
||||
analytics_graph_data,
|
||||
)
|
||||
@@ -41,24 +40,27 @@ router = APIRouter(
|
||||
|
||||
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
||||
async def get_memory_insight_report_api(
|
||||
end_user_id: str, # 使用 end_user_id
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""获取缓存的记忆洞察报告"""
|
||||
api_logger.info(f"记忆洞察报告请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
) -> dict:
|
||||
"""
|
||||
获取缓存的记忆洞察报告
|
||||
|
||||
此接口仅查询数据库中已缓存的记忆洞察数据,不执行生成操作。
|
||||
如需生成新的洞察报告,请使用专门的生成接口。
|
||||
"""
|
||||
api_logger.info(f"记忆洞察报告查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_memory_insight(db, end_user_id)
|
||||
|
||||
if result["is_cached"]:
|
||||
# 缓存存在,返回缓存数据
|
||||
api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}")
|
||||
return success(data=result, msg="查询成功")
|
||||
else:
|
||||
# 缓存不存在,返回提示消息
|
||||
api_logger.info(f"记忆洞察报告缓存不存在: end_user_id={end_user_id}")
|
||||
return success(data=result, msg="查询成功")
|
||||
return success(data=result, msg="数据尚未生成")
|
||||
except Exception as e:
|
||||
api_logger.error(f"记忆洞察报告查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告查询失败", str(e))
|
||||
@@ -66,24 +68,27 @@ async def get_memory_insight_report_api(
|
||||
|
||||
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
||||
async def get_user_summary_api(
|
||||
end_user_id: str, # 使用 end_user_id
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""获取缓存的用户摘要"""
|
||||
api_logger.info(f"用户摘要请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
) -> dict:
|
||||
"""
|
||||
获取缓存的用户摘要
|
||||
|
||||
此接口仅查询数据库中已缓存的用户摘要数据,不执行生成操作。
|
||||
如需生成新的用户摘要,请使用专门的生成接口。
|
||||
"""
|
||||
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id)
|
||||
|
||||
if result["is_cached"]:
|
||||
# 缓存存在,返回缓存数据
|
||||
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
||||
return success(data=result, msg="查询成功")
|
||||
else:
|
||||
# 缓存不存在,返回提示消息
|
||||
api_logger.info(f"用户摘要缓存不存在: end_user_id={end_user_id}")
|
||||
return success(data=result, msg="查询成功")
|
||||
return success(data=result, msg="数据尚未生成")
|
||||
except Exception as e:
|
||||
api_logger.error(f"用户摘要查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户摘要查询失败", str(e))
|
||||
|
||||
@@ -5,19 +5,16 @@ This module provides analytics and insights for the memory system.
|
||||
|
||||
Available functions:
|
||||
- get_hot_memory_tags: Get hot memory tags by frequency
|
||||
- MemoryInsight: Generate memory insight reports
|
||||
- get_recent_activity_stats: Get recent activity statistics
|
||||
- generate_user_summary: Generate user summary
|
||||
|
||||
Note: MemoryInsight and generate_user_summary have been moved to
|
||||
app.services.user_memory_service for better architecture.
|
||||
"""
|
||||
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.analytics.memory_insight import MemoryInsight
|
||||
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
|
||||
from app.core.memory.analytics.user_summary import generate_user_summary
|
||||
|
||||
__all__ = [
|
||||
"get_hot_memory_tags",
|
||||
"MemoryInsight",
|
||||
"get_recent_activity_stats",
|
||||
"generate_user_summary",
|
||||
]
|
||||
|
||||
@@ -1,327 +0,0 @@
|
||||
"""
|
||||
This module provides the MemoryInsight class for analyzing user memory data.
|
||||
|
||||
MemoryInsight 是一个工具类,提供基础的数据获取和分析功能:
|
||||
- get_domain_distribution(): 获取记忆领域分布
|
||||
- get_active_periods(): 获取活跃时段
|
||||
- get_social_connections(): 获取社交关联
|
||||
|
||||
业务逻辑(如生成洞察报告)应该在服务层(user_memory_service.py)中实现。
|
||||
|
||||
This script can be executed directly to test the memory insight generation for a test user.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from collections import Counter
|
||||
from datetime import datetime
|
||||
|
||||
# To run this script directly, we need to add the src directory to the Python path
|
||||
# to resolve the inconsistent imports in other modules.
|
||||
src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if src_path not in sys.path:
|
||||
sys.path.insert(0, src_path)
|
||||
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
#TODO: Fix this
|
||||
|
||||
# Default values (previously from definitions.py)
|
||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
|
||||
|
||||
# 定义用于LLM结构化输出的Pydantic模型
|
||||
class TagClassification(BaseModel):
|
||||
"""
|
||||
Represents the classification of a tag into a specific domain.
|
||||
"""
|
||||
|
||||
domain: str = Field(
|
||||
...,
|
||||
description="The domain the tag belongs to, chosen from the predefined list.",
|
||||
examples=["教育", "学习", "工作", "旅行", "家庭", "运动", "社交", "娱乐", "健康", "其他"],
|
||||
)
|
||||
|
||||
class InsightReport(BaseModel):
|
||||
"""
|
||||
Represents the final insight report generated by the LLM.
|
||||
"""
|
||||
|
||||
report: str = Field(
|
||||
...,
|
||||
description="A comprehensive insight report in Chinese, summarizing the user's memory patterns.",
|
||||
)
|
||||
|
||||
|
||||
class MemoryInsight:
|
||||
"""
|
||||
Provides insights into user memories by analyzing various aspects of their data.
|
||||
"""
|
||||
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
self.neo4j_connector = Neo4jConnector()
|
||||
|
||||
# Get config_id using get_end_user_connected_config
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
connected_config = get_end_user_connected_config(user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id:
|
||||
# Use the config_id to get the proper LLM client
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||
else:
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM if no config found
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"Failed to get user connected config, using default LLM: {e}")
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
async def close(self):
|
||||
"""关闭数据库连接。"""
|
||||
await self.neo4j_connector.close()
|
||||
|
||||
async def get_domain_distribution(self) -> dict[str, float]:
|
||||
"""
|
||||
Calculates the distribution of memory domains based on hot tags.
|
||||
"""
|
||||
hot_tags = await get_hot_memory_tags(self.user_id)
|
||||
if not hot_tags:
|
||||
return {}
|
||||
|
||||
domain_counts = Counter()
|
||||
for tag, _ in hot_tags:
|
||||
prompt = f"""请将以下标签归类到最合适的领域中。
|
||||
|
||||
可选领域及其关键词:
|
||||
- 教育:学校、课程、考试、培训、教学、学科、教师、学生、班级、作业、成绩、毕业、入学、校园、大学、中学、小学、教材、学位等
|
||||
- 学习:自学、阅读、书籍、技能提升、知识积累、笔记、复习、练习、研究、历史知识、科学知识、文化知识、学术讨论、知识问答等
|
||||
- 工作:职业、项目、会议、同事、业务、公司、办公、任务、客户、合同、职场、工作计划等
|
||||
- 旅行:旅游、景点、出行、度假、酒店、机票、导游、风景、旅行计划等
|
||||
- 家庭:亲人、父母、子女、配偶、家事、家庭活动、亲情、家庭聚会等
|
||||
- 运动:健身、体育、锻炼、跑步、游泳、球类、瑜伽、运动计划等
|
||||
- 社交:朋友、聚会、社交活动、派对、聊天、交友、社交网络等
|
||||
- 娱乐:游戏、电影、音乐、休闲、综艺、动漫、小说、娱乐活动等
|
||||
- 健康:医疗、养生、心理健康、体检、药物、疾病、保健、健康管理等
|
||||
- 其他:确实无法归入以上任何类别的内容
|
||||
|
||||
标签: {tag}
|
||||
|
||||
分析步骤:
|
||||
1. 仔细理解标签的核心含义和使用场景
|
||||
2. 对比各个领域的关键词,找到最匹配的领域
|
||||
3. 特别注意:
|
||||
- 历史、科学、文化等知识性内容应归类为"学习"
|
||||
- 学校、课程、考试等正式教育场景应归类为"教育"
|
||||
- 只有在标签完全不属于上述9个具体领域时,才选择"其他"
|
||||
4. 如果标签与某个领域有任何相关性,就选择该领域,不要选"其他"
|
||||
|
||||
请直接返回最合适的领域名称。"""
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个专业的标签分类助手。你必须仔细分析标签的实际含义和使用场景,优先选择9个具体领域之一。'其他'类别只用于完全无法归类的极少数情况。特别注意:历史、科学、文化等知识性对话应归类为'学习'领域;学校、课程、考试等正式教育场景应归类为'教育'领域。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
# 直接调用并等待结果
|
||||
classification = await self.llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=TagClassification,
|
||||
)
|
||||
if classification and hasattr(classification, 'domain') and classification.domain:
|
||||
domain_counts[classification.domain] += 1
|
||||
|
||||
total_tags = sum(domain_counts.values())
|
||||
if total_tags == 0:
|
||||
return {}
|
||||
|
||||
domain_distribution = {
|
||||
domain: count / total_tags for domain, count in domain_counts.items()
|
||||
}
|
||||
return dict(
|
||||
sorted(domain_distribution.items(), key=lambda item: item[1], reverse=True)
|
||||
)
|
||||
|
||||
async def get_active_periods(self) -> list[int]:
|
||||
"""
|
||||
Identifies the top 2 most active months for the user.
|
||||
Only returns months if there is valid and diverse time data.
|
||||
|
||||
This method checks if the time data represents real user memory timestamps
|
||||
rather than auto-generated system timestamps by verifying:
|
||||
1. Time data exists and is parseable
|
||||
2. Time data is distributed across multiple months (not concentrated in 1-2 months)
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (d:Dialogue)
|
||||
WHERE d.group_id = '{self.user_id}' AND d.created_at IS NOT NULL AND d.created_at <> ''
|
||||
RETURN d.created_at AS creation_time
|
||||
"""
|
||||
records = await self.neo4j_connector.execute_query(query)
|
||||
|
||||
if not records:
|
||||
return []
|
||||
|
||||
month_counts = Counter()
|
||||
valid_dates_count = 0
|
||||
for record in records:
|
||||
creation_time_str = record.get("creation_time")
|
||||
if not creation_time_str:
|
||||
continue
|
||||
try:
|
||||
# 尝试解析时间字符串
|
||||
dt_object = datetime.fromisoformat(creation_time_str.replace("Z", "+00:00"))
|
||||
month_counts[dt_object.month] += 1
|
||||
valid_dates_count += 1
|
||||
except (ValueError, TypeError, AttributeError):
|
||||
# 如果解析失败,跳过这条记录
|
||||
continue
|
||||
|
||||
# 如果没有有效的时间数据,返回空列表
|
||||
if not month_counts or valid_dates_count == 0:
|
||||
return []
|
||||
|
||||
# 检查时间分布是否过于集中(可能是批量导入的数据)
|
||||
# 如果超过80%的数据集中在1-2个月,认为这是系统时间戳而非真实时间
|
||||
unique_months = len(month_counts)
|
||||
if unique_months <= 2:
|
||||
# 只有1-2个月有数据,很可能是批量导入
|
||||
most_common_count = month_counts.most_common(1)[0][1]
|
||||
if most_common_count / valid_dates_count > 0.8:
|
||||
# 超过80%集中在一个月,认为是系统时间戳
|
||||
return []
|
||||
|
||||
# 如果时间分布较为分散(3个月以上),认为是真实时间数据
|
||||
if unique_months >= 3:
|
||||
most_common_months = month_counts.most_common(2)
|
||||
return [month for month, _ in most_common_months]
|
||||
|
||||
# 2个月的情况,检查是否分布均匀
|
||||
if unique_months == 2:
|
||||
counts = list(month_counts.values())
|
||||
# 如果两个月的数据量相差不大(比例在0.3-3之间),认为是真实数据
|
||||
ratio = min(counts) / max(counts)
|
||||
if ratio > 0.3:
|
||||
most_common_months = month_counts.most_common(2)
|
||||
return [month for month, _ in most_common_months]
|
||||
|
||||
# 其他情况返回空列表
|
||||
return []
|
||||
|
||||
async def get_social_connections(self) -> dict | None:
|
||||
"""
|
||||
Finds the user with whom the most memories are shared.
|
||||
使用 Chunk-Statement 的 CONTAINS 关系,因为系统中不创建 Dialogue-Statement 的 MENTIONS 关系。
|
||||
"""
|
||||
# 通过 Chunk 和 Statement 的 CONTAINS 关系来查找共同记忆
|
||||
query = f"""
|
||||
MATCH (c1:Chunk {{group_id: '{self.user_id}'}})
|
||||
OPTIONAL MATCH (c1)-[:CONTAINS]->(s:Statement)
|
||||
OPTIONAL MATCH (s)<-[:CONTAINS]-(c2:Chunk)
|
||||
WHERE c1.group_id <> c2.group_id AND s IS NOT NULL AND c2 IS NOT NULL
|
||||
WITH c2.group_id AS other_user_id, COUNT(DISTINCT s) AS common_statements
|
||||
WHERE common_statements > 0
|
||||
RETURN other_user_id, common_statements
|
||||
ORDER BY common_statements DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
records = await self.neo4j_connector.execute_query(query)
|
||||
if not records or not records[0].get("other_user_id"):
|
||||
return None
|
||||
|
||||
most_connected_user = records[0]["other_user_id"]
|
||||
common_memories_count = records[0]["common_statements"]
|
||||
|
||||
# 使用 Chunk 的时间范围
|
||||
time_range_query = f"""
|
||||
MATCH (c:Chunk)
|
||||
WHERE c.group_id IN ['{self.user_id}', '{most_connected_user}']
|
||||
RETURN min(c.created_at) AS start_time, max(c.created_at) AS end_time
|
||||
"""
|
||||
time_records = await self.neo4j_connector.execute_query(time_range_query)
|
||||
start_year, end_year = "N/A", "N/A"
|
||||
if time_records and time_records[0]["start_time"]:
|
||||
start_year = datetime.fromisoformat(time_records[0]["start_time"].replace("Z", "+00:00")).year
|
||||
end_year = datetime.fromisoformat(time_records[0]["end_time"].replace("Z", "+00:00")).year
|
||||
|
||||
return {
|
||||
"user_id": most_connected_user,
|
||||
"common_memories_count": common_memories_count,
|
||||
"time_range": f"{start_year}-{end_year}",
|
||||
}
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Closes the database connection.
|
||||
"""
|
||||
await self.neo4j_connector.close()
|
||||
|
||||
|
||||
async def main():
|
||||
"""
|
||||
Initializes and runs the memory insight analysis for a test user.
|
||||
"""
|
||||
# 默认从环境变量读取
|
||||
test_user_id = DEFAULT_GROUP_ID
|
||||
print(f"正在为用户 {test_user_id} 生成记忆洞察报告...\n")
|
||||
|
||||
try:
|
||||
# 使用服务层函数生成报告
|
||||
from app.services.user_memory_service import analytics_memory_insight_report
|
||||
|
||||
result = await analytics_memory_insight_report(end_user_id=test_user_id)
|
||||
report = result.get("report", "")
|
||||
|
||||
print("--- 记忆洞察报告 ---")
|
||||
print(report)
|
||||
print("---------------------")
|
||||
|
||||
# 将结果写入统一的 User-Dashboard.json,使用全局配置路径
|
||||
try:
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
output_dir = settings.MEMORY_OUTPUT_DIR
|
||||
try:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
dashboard_path = os.path.join(output_dir, "User-Dashboard.json")
|
||||
existing = {}
|
||||
if os.path.exists(dashboard_path):
|
||||
with open(dashboard_path, "r", encoding="utf-8") as rf:
|
||||
existing = json.load(rf)
|
||||
existing["memory_insight"] = {
|
||||
"group_id": test_user_id,
|
||||
"report": report
|
||||
}
|
||||
with open(dashboard_path, "w", encoding="utf-8") as wf:
|
||||
json.dump(existing, wf, ensure_ascii=False, indent=2)
|
||||
print(f"已写入 {dashboard_path} -> memory_insight")
|
||||
except Exception as e:
|
||||
print(f"写入 User-Dashboard.json 失败: {e}")
|
||||
except Exception as e:
|
||||
print(f"生成报告时出错: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# This setup allows running the async main function
|
||||
if sys.platform.startswith('win') and sys.version_info >= (3, 8):
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
asyncio.run(main())
|
||||
@@ -1,157 +0,0 @@
|
||||
"""
|
||||
Generate a concise "关于我" style user summary using data from Neo4j
|
||||
and the existing LLM configuration (mirrors hot_memory_tags.py setup).
|
||||
|
||||
Usage:
|
||||
python -m analytics.user_summary --user_id <group_id>
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple
|
||||
|
||||
# Ensure absolute imports work whether executed directly or via module
|
||||
try:
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
|
||||
src_path = os.path.join(project_root, 'src')
|
||||
if src_path not in sys.path:
|
||||
sys.path.insert(0, src_path)
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
#TODO: Fix this
|
||||
|
||||
# Default values (previously from definitions.py)
|
||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
|
||||
|
||||
|
||||
@dataclass
|
||||
class StatementRecord:
|
||||
statement: str
|
||||
created_at: str | None
|
||||
|
||||
|
||||
class UserSummary:
|
||||
"""Builds a textual user summary for a given user/group id."""
|
||||
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
self.connector = Neo4jConnector()
|
||||
|
||||
# Get config_id using get_end_user_connected_config
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
connected_config = get_end_user_connected_config(user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id:
|
||||
# Use the config_id to get the proper LLM client
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm = factory.get_llm_client(memory_config.llm_model_id)
|
||||
else:
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM if no config found
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"Failed to get user connected config, using default LLM: {e}")
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
async def close(self):
|
||||
await self.connector.close()
|
||||
|
||||
async def _get_recent_statements(self, limit: int = 80) -> List[StatementRecord]: # TODO Used by user_memory_service
|
||||
"""Fetch recent statements authored by the user/group for context."""
|
||||
query = (
|
||||
"MATCH (s:Statement) "
|
||||
"WHERE s.group_id = $group_id AND s.statement IS NOT NULL "
|
||||
"RETURN s.statement AS statement, s.created_at AS created_at "
|
||||
"ORDER BY created_at DESC LIMIT $limit"
|
||||
)
|
||||
rows = await self.connector.execute_query(query, group_id=self.user_id, limit=limit)
|
||||
records: List[StatementRecord] = []
|
||||
for r in rows:
|
||||
try:
|
||||
records.append(StatementRecord(statement=r.get("statement", ""), created_at=r.get("created_at")))
|
||||
except Exception:
|
||||
continue
|
||||
return records
|
||||
|
||||
async def _get_top_entities(self, limit: int = 30) -> List[Tuple[str, int]]:
|
||||
"""Reuse hot tag logic to get meaningful entities and their frequencies."""
|
||||
# get_hot_memory_tags internally filters out non-meaningful nouns with LLM
|
||||
return await get_hot_memory_tags(self.user_id, limit=limit) # TODO Used by user_memory_service
|
||||
|
||||
|
||||
async def generate_user_summary(user_id: str | None = None) -> str: # TODO useless
|
||||
"""
|
||||
生成用户摘要的便捷函数
|
||||
|
||||
Args:
|
||||
user_id: 可选的用户ID
|
||||
|
||||
Returns:
|
||||
用户摘要字符串
|
||||
"""
|
||||
# 导入服务层函数
|
||||
from app.services.user_memory_service import analytics_user_summary
|
||||
|
||||
# 调用服务层函数
|
||||
result = await analytics_user_summary(user_id)
|
||||
return result.get("summary", "")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("开始生成用户摘要…")
|
||||
try:
|
||||
# 直接使用 runtime.json 中的 group_id
|
||||
summary = asyncio.run(generate_user_summary())
|
||||
print("\n— 用户摘要 —\n")
|
||||
print(summary)
|
||||
|
||||
# 将结果写入统一的 User-Dashboard.json
|
||||
try:
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
output_dir = settings.MEMORY_OUTPUT_DIR
|
||||
try:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
dashboard_path = os.path.join(output_dir, "User-Dashboard.json")
|
||||
existing = {}
|
||||
if os.path.exists(dashboard_path):
|
||||
with open(dashboard_path, "r", encoding="utf-8") as rf:
|
||||
existing = json.load(rf)
|
||||
existing["user_summary"] = {
|
||||
"group_id": DEFAULT_GROUP_ID,
|
||||
"summary": summary
|
||||
}
|
||||
with open(dashboard_path, "w", encoding="utf-8") as wf:
|
||||
json.dump(existing, wf, ensure_ascii=False, indent=2)
|
||||
print(f"已写入 {dashboard_path} -> user_summary")
|
||||
except Exception as e:
|
||||
print(f"写入 User-Dashboard.json 失败: {e}")
|
||||
except Exception as e:
|
||||
print(f"生成摘要失败: {e}")
|
||||
print("请检查: 1) Neo4j 是否可用;2) config.json 与 .env 的 LLM/Neo4j 配置是否正确;3) 数据是否包含该用户的内容。")
|
||||
@@ -7,7 +7,6 @@ User Memory Service
|
||||
import os
|
||||
import uuid
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
@@ -22,7 +21,269 @@ from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Neo4j connector instan
|
||||
# Neo4j connector instance for analytics functions
|
||||
_neo4j_connector = Neo4jConnector()
|
||||
|
||||
# Default LLM ID for fallback
|
||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Internal Helper Classes
|
||||
# ============================================================================
|
||||
|
||||
class TagClassification(BaseModel):
|
||||
"""Represents the classification of a tag into a specific domain."""
|
||||
domain: str = Field(
|
||||
...,
|
||||
description="The domain the tag belongs to, chosen from the predefined list.",
|
||||
examples=["教育", "学习", "工作", "旅行", "家庭", "运动", "社交", "娱乐", "健康", "其他"],
|
||||
)
|
||||
|
||||
|
||||
def _get_llm_client_for_user(user_id: str):
|
||||
"""
|
||||
Get LLM client for a specific user based on their config.
|
||||
|
||||
Args:
|
||||
user_id: User ID to get config for
|
||||
|
||||
Returns:
|
||||
LLM client instance
|
||||
"""
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
connected_config = get_end_user_connected_config(user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(memory_config.llm_model_id)
|
||||
else:
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get user connected config, using default LLM: {e}")
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
|
||||
class MemoryInsightHelper:
|
||||
"""
|
||||
Internal helper class for memory insight analysis.
|
||||
Provides basic data retrieval and analysis functionality.
|
||||
"""
|
||||
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
self.neo4j_connector = Neo4jConnector()
|
||||
self.llm_client = _get_llm_client_for_user(user_id)
|
||||
|
||||
async def close(self):
|
||||
"""Close database connection."""
|
||||
await self.neo4j_connector.close()
|
||||
|
||||
async def get_domain_distribution(self) -> dict[str, float]:
|
||||
"""Calculate the distribution of memory domains based on hot tags."""
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
|
||||
hot_tags = await get_hot_memory_tags(self.user_id)
|
||||
if not hot_tags:
|
||||
return {}
|
||||
|
||||
domain_counts = Counter()
|
||||
for tag, _ in hot_tags:
|
||||
prompt = f"""请将以下标签归类到最合适的领域中。
|
||||
|
||||
可选领域及其关键词:
|
||||
- 教育:学校、课程、考试、培训、教学、学科、教师、学生、班级、作业、成绩、毕业、入学、校园、大学、中学、小学、教材、学位等
|
||||
- 学习:自学、阅读、书籍、技能提升、知识积累、笔记、复习、练习、研究、历史知识、科学知识、文化知识、学术讨论、知识问答等
|
||||
- 工作:职业、项目、会议、同事、业务、公司、办公、任务、客户、合同、职场、工作计划等
|
||||
- 旅行:旅游、景点、出行、度假、酒店、机票、导游、风景、旅行计划等
|
||||
- 家庭:亲人、父母、子女、配偶、家事、家庭活动、亲情、家庭聚会等
|
||||
- 运动:健身、体育、锻炼、跑步、游泳、球类、瑜伽、运动计划等
|
||||
- 社交:朋友、聚会、社交活动、派对、聊天、交友、社交网络等
|
||||
- 娱乐:游戏、电影、音乐、休闲、综艺、动漫、小说、娱乐活动等
|
||||
- 健康:医疗、养生、心理健康、体检、药物、疾病、保健、健康管理等
|
||||
- 其他:确实无法归入以上任何类别的内容
|
||||
|
||||
标签: {tag}
|
||||
|
||||
分析步骤:
|
||||
1. 仔细理解标签的核心含义和使用场景
|
||||
2. 对比各个领域的关键词,找到最匹配的领域
|
||||
3. 特别注意:
|
||||
- 历史、科学、文化等知识性内容应归类为"学习"
|
||||
- 学校、课程、考试等正式教育场景应归类为"教育"
|
||||
- 只有在标签完全不属于上述9个具体领域时,才选择"其他"
|
||||
4. 如果标签与某个领域有任何相关性,就选择该领域,不要选"其他"
|
||||
|
||||
请直接返回最合适的领域名称。"""
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个专业的标签分类助手。你必须仔细分析标签的实际含义和使用场景,优先选择9个具体领域之一。'其他'类别只用于完全无法归类的极少数情况。特别注意:历史、科学、文化等知识性对话应归类为'学习'领域;学校、课程、考试等正式教育场景应归类为'教育'领域。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
classification = await self.llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=TagClassification,
|
||||
)
|
||||
if classification and hasattr(classification, 'domain') and classification.domain:
|
||||
domain_counts[classification.domain] += 1
|
||||
|
||||
total_tags = sum(domain_counts.values())
|
||||
if total_tags == 0:
|
||||
return {}
|
||||
|
||||
domain_distribution = {
|
||||
domain: count / total_tags for domain, count in domain_counts.items()
|
||||
}
|
||||
return dict(sorted(domain_distribution.items(), key=lambda item: item[1], reverse=True))
|
||||
|
||||
async def get_active_periods(self) -> list[int]:
|
||||
"""
|
||||
Identify the top 2 most active months for the user.
|
||||
Only returns months if there is valid and diverse time data.
|
||||
"""
|
||||
query = """
|
||||
MATCH (d:Dialogue)
|
||||
WHERE d.group_id = $group_id AND d.created_at IS NOT NULL AND d.created_at <> ''
|
||||
RETURN d.created_at AS creation_time
|
||||
"""
|
||||
records = await self.neo4j_connector.execute_query(query, group_id=self.user_id)
|
||||
|
||||
if not records:
|
||||
return []
|
||||
|
||||
month_counts = Counter()
|
||||
valid_dates_count = 0
|
||||
for record in records:
|
||||
creation_time_str = record.get("creation_time")
|
||||
if not creation_time_str:
|
||||
continue
|
||||
try:
|
||||
dt_object = datetime.fromisoformat(creation_time_str.replace("Z", "+00:00"))
|
||||
month_counts[dt_object.month] += 1
|
||||
valid_dates_count += 1
|
||||
except (ValueError, TypeError, AttributeError):
|
||||
continue
|
||||
|
||||
if not month_counts or valid_dates_count == 0:
|
||||
return []
|
||||
|
||||
# Check if time distribution is too concentrated (likely batch imported data)
|
||||
unique_months = len(month_counts)
|
||||
if unique_months <= 2:
|
||||
most_common_count = month_counts.most_common(1)[0][1]
|
||||
if most_common_count / valid_dates_count > 0.8:
|
||||
return []
|
||||
|
||||
if unique_months >= 3:
|
||||
most_common_months = month_counts.most_common(2)
|
||||
return [month for month, _ in most_common_months]
|
||||
|
||||
if unique_months == 2:
|
||||
counts = list(month_counts.values())
|
||||
ratio = min(counts) / max(counts)
|
||||
if ratio > 0.3:
|
||||
most_common_months = month_counts.most_common(2)
|
||||
return [month for month, _ in most_common_months]
|
||||
|
||||
return []
|
||||
|
||||
async def get_social_connections(self) -> dict | None:
|
||||
"""Find the user with whom the most memories are shared."""
|
||||
query = """
|
||||
MATCH (c1:Chunk {group_id: $group_id})
|
||||
OPTIONAL MATCH (c1)-[:CONTAINS]->(s:Statement)
|
||||
OPTIONAL MATCH (s)<-[:CONTAINS]-(c2:Chunk)
|
||||
WHERE c1.group_id <> c2.group_id AND s IS NOT NULL AND c2 IS NOT NULL
|
||||
WITH c2.group_id AS other_user_id, COUNT(DISTINCT s) AS common_statements
|
||||
WHERE common_statements > 0
|
||||
RETURN other_user_id, common_statements
|
||||
ORDER BY common_statements DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
records = await self.neo4j_connector.execute_query(query, group_id=self.user_id)
|
||||
if not records or not records[0].get("other_user_id"):
|
||||
return None
|
||||
|
||||
most_connected_user = records[0]["other_user_id"]
|
||||
common_memories_count = records[0]["common_statements"]
|
||||
|
||||
time_range_query = """
|
||||
MATCH (c:Chunk)
|
||||
WHERE c.group_id IN [$user_id, $other_user_id]
|
||||
RETURN min(c.created_at) AS start_time, max(c.created_at) AS end_time
|
||||
"""
|
||||
time_records = await self.neo4j_connector.execute_query(
|
||||
time_range_query,
|
||||
user_id=self.user_id,
|
||||
other_user_id=most_connected_user
|
||||
)
|
||||
start_year, end_year = "N/A", "N/A"
|
||||
if time_records and time_records[0]["start_time"]:
|
||||
start_year = datetime.fromisoformat(time_records[0]["start_time"].replace("Z", "+00:00")).year
|
||||
end_year = datetime.fromisoformat(time_records[0]["end_time"].replace("Z", "+00:00")).year
|
||||
|
||||
return {
|
||||
"user_id": most_connected_user,
|
||||
"common_memories_count": common_memories_count,
|
||||
"time_range": f"{start_year}-{end_year}",
|
||||
}
|
||||
|
||||
|
||||
class UserSummaryHelper:
|
||||
"""
|
||||
Internal helper class for user summary generation.
|
||||
Provides data retrieval functionality for user summary analysis.
|
||||
"""
|
||||
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
self.connector = Neo4jConnector()
|
||||
self.llm = _get_llm_client_for_user(user_id)
|
||||
|
||||
async def close(self):
|
||||
"""Close database connection."""
|
||||
await self.connector.close()
|
||||
|
||||
async def get_recent_statements(self, limit: int = 80) -> List[Dict[str, Any]]:
|
||||
"""Fetch recent statements authored by the user/group for context."""
|
||||
query = (
|
||||
"MATCH (s:Statement) "
|
||||
"WHERE s.group_id = $group_id AND s.statement IS NOT NULL "
|
||||
"RETURN s.statement AS statement, s.created_at AS created_at "
|
||||
"ORDER BY created_at DESC LIMIT $limit"
|
||||
)
|
||||
rows = await self.connector.execute_query(query, group_id=self.user_id, limit=limit)
|
||||
records = []
|
||||
for r in rows:
|
||||
try:
|
||||
records.append({
|
||||
"statement": r.get("statement", ""),
|
||||
"created_at": r.get("created_at")
|
||||
})
|
||||
except Exception:
|
||||
continue
|
||||
return records
|
||||
|
||||
async def get_top_entities(self, limit: int = 30) -> List[Tuple[str, int]]:
|
||||
"""Get meaningful entities and their frequencies using hot tag logic."""
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
return await get_hot_memory_tags(self.user_id, limit=limit)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Service Class
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Service Class
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class UserMemoryService:
|
||||
@@ -601,7 +862,7 @@ async def analytics_memory_insight_report(end_user_id: Optional[str] = None) ->
|
||||
生成记忆洞察报告(四个维度)
|
||||
|
||||
这个函数包含完整的业务逻辑:
|
||||
1. 使用 MemoryInsight 工具类获取基础数据(领域分布、活跃时段、社交关联)
|
||||
1. 使用 MemoryInsightHelper 工具类获取基础数据(领域分布、活跃时段、社交关联)
|
||||
2. 使用 Jinja2 模板渲染提示词
|
||||
3. 调用 LLM 生成四个维度的自然语言报告
|
||||
4. 解析并返回四个部分
|
||||
@@ -620,7 +881,7 @@ async def analytics_memory_insight_report(end_user_id: Optional[str] = None) ->
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt
|
||||
import re
|
||||
|
||||
insight = MemoryInsight(end_user_id)
|
||||
insight = MemoryInsightHelper(end_user_id)
|
||||
|
||||
try:
|
||||
# 1. 并行获取三个维度的数据
|
||||
@@ -722,7 +983,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str,
|
||||
生成用户摘要(包含四个部分)
|
||||
|
||||
这个函数包含完整的业务逻辑:
|
||||
1. 使用 UserSummary 工具类获取基础数据(实体、语句)
|
||||
1. 使用 UserSummaryHelper 工具类获取基础数据(实体、语句)
|
||||
2. 使用 prompt_utils 渲染提示词
|
||||
3. 调用 LLM 生成四部分内容:基本介绍、性格特点、核心价值观、一句话总结
|
||||
|
||||
@@ -737,20 +998,19 @@ async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str,
|
||||
"one_sentence": str
|
||||
}
|
||||
"""
|
||||
from app.core.memory.analytics.user_summary import UserSummary
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
|
||||
import re
|
||||
|
||||
# 创建 UserSummary 实例
|
||||
user_summary_tool = UserSummary(end_user_id or os.getenv("SELECTED_GROUP_ID", "group_123"))
|
||||
# 创建 UserSummaryHelper 实例
|
||||
user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_GROUP_ID", "group_123"))
|
||||
|
||||
try:
|
||||
# 1) 收集上下文数据
|
||||
entities = await user_summary_tool._get_top_entities(limit=40)
|
||||
statements = await user_summary_tool._get_recent_statements(limit=100)
|
||||
entities = await user_summary_tool.get_top_entities(limit=40)
|
||||
statements = await user_summary_tool.get_recent_statements(limit=100)
|
||||
|
||||
entity_lines = [f"{name} ({freq})" for name, freq in entities][:20]
|
||||
statement_samples = [s.statement.strip() for s in statements if (s.statement or '').strip()][:20]
|
||||
statement_samples = [s["statement"].strip() for s in statements if s.get("statement", "").strip()][:20]
|
||||
|
||||
# 2) 使用 prompt_utils 渲染提示词
|
||||
user_prompt = await render_user_summary_prompt(
|
||||
|
||||
Reference in New Issue
Block a user