From 7870c6c33f8048df8cd7351aec7904f3c62bbdd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= <162269739+lanceyq@users.noreply.github.com> Date: Fri, 23 Jan 2026 10:50:24 +0800 Subject: [PATCH 01/28] Fix/interface home (#182) * [fix]Fix the interface for statistics of recent activities and applications * [changes]Modify the code based on the AI review 1.Use the boolean auxiliary methods provided by SQLAlchemy instead of using == True in the is_active filter. 2.The calculation of the "PROJECT_ROOT" has now been hardcoded with five levels of nested os.path.dirname calls. * [fix]Fix the interface for statistics of recent activities and applications * [changes]Modify the code based on the AI review 1.Use the boolean auxiliary methods provided by SQLAlchemy instead of using == True in the is_active filter. 2.The calculation of the "PROJECT_ROOT" has now been hardcoded with five levels of nested os.path.dirname calls. --- .../controllers/public_share_controller.py | 7 +- api/app/controllers/workflow_controller.py | 14 +-- api/app/core/memory/agent/utils/llm_tools.py | 3 +- .../core/memory/analytics/api_docs_parser.py | 3 +- .../memory/analytics/recent_activity_stats.py | 88 +++++++++---------- .../memory/evaluation/locomo/locomo_test.py | 5 +- .../longmemeval/qwen_search_eval.py | 5 +- .../evaluation/memsciqa/memsciqa-test.py | 5 +- api/app/repositories/app_repository.py | 10 ++- api/app/repositories/home_page_repository.py | 30 +++---- api/app/repositories/user_repository.py | 4 +- api/app/repositories/workflow_repository.py | 2 +- api/app/repositories/workspace_repository.py | 24 ++--- api/app/services/agent_registry.py | 4 +- api/app/services/app_service.py | 10 +-- api/app/services/draft_run_service.py | 2 +- api/app/services/memory_agent_service.py | 17 ++-- api/app/services/memory_api_service.py | 5 +- api/app/services/memory_reflection_service.py | 5 +- api/app/services/memory_storage_service.py | 3 +- 
api/app/services/multi_agent_orchestrator.py | 4 +- api/app/services/multi_agent_service.py | 4 +- api/app/services/shared_chat_service.py | 8 +- api/app/services/workflow_service.py | 5 +- api/app/tasks.py | 7 +- api/migrations/env.py | 3 +- 26 files changed, 148 insertions(+), 129 deletions(-) diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index 17ad70a7..6e2d383c 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -317,9 +317,12 @@ async def chat( appid = share.app_id """获取存储类型和工作空间的ID""" - # 直接通过 SQLAlchemy 查询 app + # 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用) from app.models.app_model import App - app = db.query(App).filter(App.id == appid).first() + app = db.query(App).filter( + App.id == appid, + App.is_active.is_(True) + ).first() if not app: raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND) diff --git a/api/app/controllers/workflow_controller.py b/api/app/controllers/workflow_controller.py index c6d9ddab..8a15f717 100644 --- a/api/app/controllers/workflow_controller.py +++ b/api/app/controllers/workflow_controller.py @@ -54,7 +54,7 @@ async def create_workflow_config( app = db.query(App).filter( App.id == app_id, App.workspace_id == current_user.current_workspace_id, - App.is_active == True + App.is_active.is_(True) ).first() if not app: @@ -214,7 +214,7 @@ async def delete_workflow_config( app = db.query(App).filter( App.id == app_id, App.workspace_id == current_user.current_workspace_id, - App.is_active == True + App.is_active.is_(True) ).first() if not app: @@ -259,7 +259,7 @@ async def validate_workflow_config( app = db.query(App).filter( App.id == app_id, App.workspace_id == current_user.current_workspace_id, - App.is_active == True + App.is_active.is_(True) ).first() if not app: @@ -329,7 +329,7 @@ async def get_workflow_executions( app = db.query(App).filter( App.id == app_id, App.workspace_id == 
current_user.current_workspace_id, - App.is_active == True + App.is_active.is_(True) ).first() if not app: @@ -389,7 +389,7 @@ async def get_workflow_execution( app = db.query(App).filter( App.id == execution.app_id, App.workspace_id == current_user.current_workspace_id, - App.is_active == True + App.is_active.is_(True) ).first() if not app: @@ -440,7 +440,7 @@ async def run_workflow( app = db.query(App).filter( App.id == app_id, App.workspace_id == current_user.current_workspace_id, - App.is_active == True + App.is_active.is_(True) ).first() if not app: @@ -578,7 +578,7 @@ async def cancel_workflow_execution( app = db.query(App).filter( App.id == execution.app_id, App.workspace_id == current_user.current_workspace_id, - App.is_active == True + App.is_active.is_(True) ).first() if not app: diff --git a/api/app/core/memory/agent/utils/llm_tools.py b/api/app/core/memory/agent/utils/llm_tools.py index 8dd2f1d3..e73d5653 100644 --- a/api/app/core/memory/agent/utils/llm_tools.py +++ b/api/app/core/memory/agent/utils/llm_tools.py @@ -1,11 +1,12 @@ import os from collections import defaultdict +from pathlib import Path from typing import Annotated, TypedDict from langchain_core.messages import AnyMessage from langgraph.graph import add_messages -PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3]) class WriteState(TypedDict): ''' diff --git a/api/app/core/memory/analytics/api_docs_parser.py b/api/app/core/memory/analytics/api_docs_parser.py index 94ed0f00..4a116520 100644 --- a/api/app/core/memory/analytics/api_docs_parser.py +++ b/api/app/core/memory/analytics/api_docs_parser.py @@ -139,7 +139,8 @@ def parse_api_docs(file_path: str) -> Dict[str, Any]: def get_default_docs_path() -> str: - project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + from pathlib import Path + project_root = str(Path(__file__).resolve().parents[2]) return 
os.path.join(project_root, "src", "analytics", "API接口.md") diff --git a/api/app/core/memory/analytics/recent_activity_stats.py b/api/app/core/memory/analytics/recent_activity_stats.py index c41f4208..71f70c09 100644 --- a/api/app/core/memory/analytics/recent_activity_stats.py +++ b/api/app/core/memory/analytics/recent_activity_stats.py @@ -2,13 +2,16 @@ import os import re import glob import json +from pathlib import Path from typing import Tuple try: from app.core.memory.utils.config.definitions import PROJECT_ROOT except Exception: # Fallback: derive project root from this file location - PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + # 当前文件在 api/app/core/memory/analytics/recent_activity_stats.py + # 需要向上 5 级到达 api/ 目录 + PROJECT_ROOT = str(Path(__file__).resolve().parents[4]) def _get_latest_prompt_log_path() -> str | None: @@ -67,44 +70,43 @@ def parse_stats_from_log(log_path: str) -> dict: triplet_relations_count = 0 temporal_count = 0 - # Patterns + # 正则表达式模式 - 匹配当前日志格式 pat_chunk_render = re.compile(r"===\s*RENDERED\s*STATEMENT\s*EXTRACTION\s*PROMPT\s*===") - pat_triplet_start = re.compile(r"\[Triplet\].*statements_to_process\s*=\s*(\d+)") - pat_triplet_done = re.compile( - r"\[Triplet\].*completed,\s*total_triplets\s*=\s*(\d+),\s*total_entities\s*=\s*(\d+)" + pat_triplet_started = re.compile(r"\[Triplet\]\s+Started\s+-\s+statement_id=") + pat_triplet_completed = re.compile( + r"\[Triplet\]\s+Completed\s+-\s+statement_id=[^,]+,\s+triplets=(\d+),\s+entities=(\d+)" ) - pat_temporal_done = re.compile( - r"\[Temporal\].*completed,\s*extracted_valid_ranges\s*=\s*(\d+)" + pat_temporal_completed = re.compile( + r"\[Temporal\]\s+Completed\s+-\s+statement_id=[^,]+,\s+valid_ranges=(\d+)" ) with open(log_path, "r", encoding="utf-8", errors="ignore") as f: for line in f: - # Chunk prompts count (each chunk triggers one statement-extraction prompt render) + # 文本块数量(每个块触发一次陈述提取提示) if pat_chunk_render.search(line): chunk_count 
+= 1 continue - m1 = pat_triplet_start.search(line) - if m1: + # 陈述数量(每个 Triplet Started 代表一个陈述被处理) + if pat_triplet_started.search(line): + statements_count += 1 + continue + + # 三元组完成:[Triplet] Completed - statement_id=xxx, triplets=X, entities=Y + m_triplet = pat_triplet_completed.search(line) + if m_triplet: try: - statements_count += int(m1.group(1)) + triplet_relations_count += int(m_triplet.group(1)) + triplet_entities_count += int(m_triplet.group(2)) except Exception: pass continue - m2 = pat_triplet_done.search(line) - if m2: + # 时间信息完成:[Temporal] Completed - statement_id=xxx, valid_ranges=X + m_temporal = pat_temporal_completed.search(line) + if m_temporal: try: - triplet_relations_count += int(m2.group(1)) - triplet_entities_count += int(m2.group(2)) - except Exception: - pass - continue - - m3 = pat_temporal_done.search(line) - if m3: - try: - temporal_count += int(m3.group(1)) + temporal_count += int(m_temporal.group(1)) except Exception: pass continue @@ -120,15 +122,20 @@ def parse_stats_from_log(log_path: str) -> dict: def get_recent_activity_stats() -> Tuple[dict, str]: - """Get aggregated stats from all prompt logs in logs/. + """Get stats from the latest prompt log file only. Returns (stats_dict, message). 
""" - all_logs = _get_all_prompt_logs() - # Fallback to recursive search if none found in logs/ - if not all_logs: + # 获取最新的日志文件 + latest_log = _get_latest_prompt_log_path() + + # 如果没有找到,尝试递归搜索 + if not latest_log: all_logs = _get_any_logs_recursive() - if not all_logs: + if all_logs: + latest_log = all_logs[-1] # 取最新的 + + if not latest_log: return ( { "chunk_count": 0, @@ -141,24 +148,13 @@ def get_recent_activity_stats() -> Tuple[dict, str]: "未找到日志文件,请确认已运行过提取流程。", ) - agg = { - "chunk_count": 0, - "statements_count": 0, - "triplet_entities_count": 0, - "triplet_relations_count": 0, - "temporal_count": 0, - } - for path in all_logs: - s = parse_stats_from_log(path) - agg["chunk_count"] += s.get("chunk_count", 0) - agg["statements_count"] += s.get("statements_count", 0) - agg["triplet_entities_count"] += s.get("triplet_entities_count", 0) - agg["triplet_relations_count"] += s.get("triplet_relations_count", 0) - agg["temporal_count"] += s.get("temporal_count", 0) - - # Attach a summary of files combined - agg["log_path"] = f"{len(all_logs)} 个日志文件,最新:{all_logs[-1]}" - return agg, "成功汇总 logs 目录中所有提示日志。" + # 只解析最新的日志文件 + stats = parse_stats_from_log(latest_log) + + # 添加日志文件路径信息 + stats["log_path"] = f"最新:{latest_log}" + + return stats, "成功读取最近一次记忆活动统计。" def _format_summary(stats: dict) -> str: diff --git a/api/app/core/memory/evaluation/locomo/locomo_test.py b/api/app/core/memory/evaluation/locomo/locomo_test.py index b5ad5820..affedd0f 100644 --- a/api/app/core/memory/evaluation/locomo/locomo_test.py +++ b/api/app/core/memory/evaluation/locomo/locomo_test.py @@ -8,13 +8,14 @@ import sys import time from datetime import datetime, timedelta from typing import Any, Dict, List +from pathlib import Path from dotenv import load_dotenv # 1 # 添加项目根目录到路径 -current_dir = os.path.dirname(os.path.abspath(__file__)) -project_root = os.path.dirname(current_dir) +current_dir = Path(__file__).resolve().parent +project_root = str(current_dir.parent) if project_root not in sys.path: 
sys.path.insert(0, project_root) # 关键:将 src 目录置于最前,确保从当前仓库加载模块 diff --git a/api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py b/api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py index 53c5ce19..292e7288 100644 --- a/api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py +++ b/api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py @@ -16,9 +16,10 @@ except Exception: # 确保可以找到 src 及项目根路径 import sys +from pathlib import Path -_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) -_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(_THIS_DIR))) +_THIS_DIR = Path(__file__).resolve().parent +_PROJECT_ROOT = str(_THIS_DIR.parents[2]) _SRC_DIR = os.path.join(_PROJECT_ROOT, "src") for _p in (_SRC_DIR, _PROJECT_ROOT): if _p not in sys.path: diff --git a/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py b/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py index 279f4042..900cda9d 100644 --- a/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py +++ b/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py @@ -15,9 +15,10 @@ except Exception: # 路径与模块导入保持与现有评估脚本一致 import sys +from pathlib import Path -_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) -_PROJECT_ROOT = os.path.dirname(os.path.dirname(_THIS_DIR)) +_THIS_DIR = Path(__file__).resolve().parent +_PROJECT_ROOT = str(_THIS_DIR.parents[1]) _SRC_DIR = os.path.join(_PROJECT_ROOT, "src") for _p in (_SRC_DIR, _PROJECT_ROOT): if _p not in sys.path: diff --git a/api/app/repositories/app_repository.py b/api/app/repositories/app_repository.py index 11a2ea3e..0c7ba6a4 100644 --- a/api/app/repositories/app_repository.py +++ b/api/app/repositories/app_repository.py @@ -15,9 +15,13 @@ class AppRepository: self.db = db def get_apps_by_workspace_id(self, workspace_id: uuid.UUID) -> list[App]: - """根据工作空间ID查询应用""" + """根据工作空间ID查询应用(仅返回未删除的应用)""" try: - apps = self.db.query(App).filter(App.workspace_id == workspace_id).all() + apps = ( + self.db.query(App) + 
.filter(App.workspace_id == workspace_id, App.is_active.is_(True)) + .all() + ) db_logger.info(f"成功查询工作空间 {workspace_id} 下的 {len(apps)} 个应用") return apps except Exception as e: @@ -26,7 +30,7 @@ class AppRepository: def get_apps_by_id(self, app_id: uuid.UUID) -> App: try: - app = self.db.query(App).filter(App.id == app_id, App.is_active == True).first() + app = self.db.query(App).filter(App.id == app_id, App.is_active.is_(True)).first() return app except Exception as e: raise diff --git a/api/app/repositories/home_page_repository.py b/api/app/repositories/home_page_repository.py index 888071ac..bcb3b622 100644 --- a/api/app/repositories/home_page_repository.py +++ b/api/app/repositories/home_page_repository.py @@ -17,24 +17,24 @@ class HomePageRepository: """获取模型统计数据""" total_models = db.query(ModelConfig).filter( ModelConfig.tenant_id == tenant_id, - ModelConfig.is_active == True + ModelConfig.is_active.is_(True) ).count() total_llm = db.query(ModelConfig).filter( ModelConfig.tenant_id == tenant_id, - ModelConfig.is_active == True, + ModelConfig.is_active.is_(True), ModelConfig.type == "llm" ).count() total_embedding = db.query(ModelConfig).filter( ModelConfig.tenant_id == tenant_id, - ModelConfig.is_active == True, + ModelConfig.is_active.is_(True), ModelConfig.type == "embedding" ).count() new_models_this_week = db.query(ModelConfig).filter( ModelConfig.tenant_id == tenant_id, - ModelConfig.is_active == True, + ModelConfig.is_active.is_(True), ModelConfig.created_at >= week_start ).count() @@ -56,12 +56,12 @@ class HomePageRepository: """获取工作空间统计数据""" active_workspaces = db.query(Workspace).filter( Workspace.tenant_id == tenant_id, - Workspace.is_active == True + Workspace.is_active.is_(True) ).count() new_workspaces_this_week = db.query(Workspace).filter( Workspace.tenant_id == tenant_id, - Workspace.is_active == True, + Workspace.is_active.is_(True), Workspace.created_at >= week_start ).count() @@ -83,7 +83,7 @@ class HomePageRepository: """获取用户统计数据""" 
workspace_ids = db.query(Workspace.id).filter( Workspace.tenant_id == tenant_id, - Workspace.is_active == True + Workspace.is_active.is_(True) ).subquery() total_users = db.query(EndUser).join( @@ -91,7 +91,7 @@ class HomePageRepository: EndUser.app_id == App.id ).filter( App.workspace_id.in_(workspace_ids), - App.is_active == True, + App.is_active.is_(True), App.status == "active" ).count() @@ -100,7 +100,7 @@ class HomePageRepository: EndUser.app_id == App.id ).filter( App.workspace_id.in_(workspace_ids), - App.is_active == True, + App.is_active.is_(True), App.status == "active", EndUser.created_at >= week_start ).count() @@ -123,18 +123,18 @@ class HomePageRepository: """获取应用统计数据""" workspace_ids = db.query(Workspace.id).filter( Workspace.tenant_id == tenant_id, - Workspace.is_active == True + Workspace.is_active.is_(True) ).subquery() running_apps = db.query(App).filter( App.workspace_id.in_(workspace_ids), - App.is_active == True, + App.is_active.is_(True), App.status == "active" ).count() new_apps_this_week = db.query(App).filter( App.workspace_id.in_(workspace_ids), - App.is_active == True, + App.is_active.is_(True), App.status == "active", App.created_at >= week_start ).count() @@ -158,7 +158,7 @@ class HomePageRepository: # 获取工作空间列表 workspaces = db.query(Workspace).filter( Workspace.tenant_id == tenant_id, - Workspace.is_active == True + Workspace.is_active.is_(True) ).all() workspace_ids = [ws.id for ws in workspaces] @@ -169,7 +169,7 @@ class HomePageRepository: func.count(App.id).label('count') ).filter( App.workspace_id.in_(workspace_ids), - App.is_active, + App.is_active.is_(True), App.status == "active" ).group_by(App.workspace_id).all() @@ -184,7 +184,7 @@ class HomePageRepository: EndUser.app_id == App.id ).filter( App.workspace_id.in_(workspace_ids), - App.is_active, + App.is_active.is_(True), App.status == "active" ).group_by(App.workspace_id).all() diff --git a/api/app/repositories/user_repository.py b/api/app/repositories/user_repository.py 
index a43c5869..b4c11aa4 100644 --- a/api/app/repositories/user_repository.py +++ b/api/app/repositories/user_repository.py @@ -68,7 +68,7 @@ class UserRepository: db_logger.debug("查询超级用户") try: - user = self.db.query(User).options(joinedload(User.tenant)).filter(User.is_active == True).filter(User.is_superuser == True).first() + user = self.db.query(User).options(joinedload(User.tenant)).filter(User.is_active.is_(True)).filter(User.is_superuser.is_(True)).first() if user: db_logger.debug(f"超级用户查询成功: {user.username}") else: @@ -82,7 +82,7 @@ class UserRepository: db_logger.debug("检查是否只有一个超级用户") try: - count = self.db.query(User).options(joinedload(User.tenant)).filter(User.is_active == True).filter(User.is_superuser == True).count() + count = self.db.query(User).options(joinedload(User.tenant)).filter(User.is_active.is_(True)).filter(User.is_superuser.is_(True)).count() return count == 1 except Exception as e: db_logger.error(f"检查超级用户数量失败: {str(e)}") diff --git a/api/app/repositories/workflow_repository.py b/api/app/repositories/workflow_repository.py index 04734640..b22673e6 100644 --- a/api/app/repositories/workflow_repository.py +++ b/api/app/repositories/workflow_repository.py @@ -33,7 +33,7 @@ class WorkflowConfigRepository: """ return self.db.query(WorkflowConfig).filter( WorkflowConfig.app_id == app_id, - WorkflowConfig.is_active == True + WorkflowConfig.is_active.is_(True) ).first() def create_or_update( diff --git a/api/app/repositories/workspace_repository.py b/api/app/repositories/workspace_repository.py index 106830be..70ed7521 100644 --- a/api/app/repositories/workspace_repository.py +++ b/api/app/repositories/workspace_repository.py @@ -103,7 +103,7 @@ class WorkspaceRepository: workspaces = ( self.db.query(Workspace) .filter(Workspace.tenant_id == user.tenant_id) - .filter(Workspace.is_active == True) + .filter(Workspace.is_active.is_(True)) .order_by(Workspace.updated_at.desc()) .all() ) @@ -115,7 +115,7 @@ class WorkspaceRepository: 
self.db.query(Workspace) .join(WorkspaceMember, Workspace.id == WorkspaceMember.workspace_id) .filter(WorkspaceMember.user_id == user_id) - .filter(Workspace.is_active == True) + .filter(Workspace.is_active.is_(True)) .order_by(Workspace.updated_at.desc()) .all() ) @@ -134,7 +134,7 @@ class WorkspaceRepository: workspaces = ( self.db.query(Workspace) .filter(Workspace.tenant_id == tenant_id) - .filter(Workspace.is_active == True) + .filter(Workspace.is_active.is_(True)) .all() ) db_logger.debug(f"租户工作空间查询成功: tenant_id={tenant_id}, 数量={len(workspaces)}") @@ -169,7 +169,7 @@ class WorkspaceRepository: member = self.db.query(WorkspaceMember).filter( WorkspaceMember.user_id == user_id, WorkspaceMember.workspace_id == workspace_id, - WorkspaceMember.is_active == True, + WorkspaceMember.is_active.is_(True), ).first() if member: db_logger.debug(f"工作空间成员查询成功: user_id={user_id}, workspace_id={workspace_id}, role={member.role}") @@ -189,8 +189,8 @@ class WorkspaceRepository: .join(User, WorkspaceMember.user_id == User.id) .options(joinedload(WorkspaceMember.user), joinedload(WorkspaceMember.workspace)) .filter(WorkspaceMember.workspace_id == workspace_id) - .filter(WorkspaceMember.is_active == True) - .filter(User.is_active == True) + .filter(WorkspaceMember.is_active.is_(True)) + .filter(User.is_active.is_(True)) .all() ) db_logger.debug(f"成员列表查询成功: workspace_id={workspace_id}, 数量={len(members)}") @@ -208,8 +208,8 @@ class WorkspaceRepository: .join(User, WorkspaceMember.user_id == User.id) .options(joinedload(WorkspaceMember.user), joinedload(WorkspaceMember.workspace)) .filter(WorkspaceMember.id == member_id) - .filter(WorkspaceMember.is_active == True) - .filter(User.is_active == True) + .filter(WorkspaceMember.is_active.is_(True)) + .filter(User.is_active.is_(True)) .first() ) if member: @@ -226,7 +226,7 @@ class WorkspaceRepository: member = self.db.query(WorkspaceMember).filter( WorkspaceMember.workspace_id == workspace_id, WorkspaceMember.user_id == user_id, - 
WorkspaceMember.is_active == True, + WorkspaceMember.is_active.is_(True), ).first() if not member: return None @@ -243,7 +243,7 @@ class WorkspaceRepository: member = self.db.query(WorkspaceMember).filter( WorkspaceMember.workspace_id == workspace_id, WorkspaceMember.user_id == user_id, - WorkspaceMember.is_active == True, + WorkspaceMember.is_active.is_(True), ).first() if not member: return None @@ -259,7 +259,7 @@ class WorkspaceRepository: try: member = self.db.query(WorkspaceMember).filter( WorkspaceMember.id == member_id, - WorkspaceMember.is_active == True, + WorkspaceMember.is_active.is_(True), ).first() if not member: return None @@ -275,7 +275,7 @@ class WorkspaceRepository: try: member = self.db.query(WorkspaceMember).filter( WorkspaceMember.id == id, - WorkspaceMember.is_active == True, + WorkspaceMember.is_active.is_(True), ).first() if not member: return None diff --git a/api/app/services/agent_registry.py b/api/app/services/agent_registry.py index 2b6d92e3..d221bbf5 100644 --- a/api/app/services/agent_registry.py +++ b/api/app/services/agent_registry.py @@ -55,8 +55,8 @@ class AgentRegistry: """ # 构建查询 stmt = select(AgentConfig).join(App).where( - AgentConfig.is_active == True, - App.is_active == True + AgentConfig.is_active.is_(True), + App.is_active.is_(True) ) # 工作空间过滤(同工作空间或公开) diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 68acab1d..7ec4bc0e 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -758,7 +758,7 @@ class AppService: ) # 构建查询条件 - filters = [App.is_active == True] + filters = [App.is_active.is_(True)] if type: filters.append(App.type == type) if visibility: @@ -873,7 +873,7 @@ class AppService: self._validate_workspace_access(app, workspace_id) - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active == True).order_by( + stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active.is_(True)).order_by( 
AgentConfig.updated_at.desc()) agent_cfg: Optional[AgentConfig] = self.db.scalars(stmt).first() now = datetime.datetime.now() @@ -1204,7 +1204,7 @@ class AppService: default_model_config_id = None if app.type == AppType.AGENT: - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active == True).order_by( + stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active.is_(True)).order_by( AgentConfig.updated_at.desc()) agent_cfg = self.db.scalars(stmt).first() if not agent_cfg: @@ -1226,7 +1226,7 @@ class AppService: select(MultiAgentConfig) .where( MultiAgentConfig.app_id == app_id, - MultiAgentConfig.is_active == True + MultiAgentConfig.is_active.is_(True) ) .order_by(MultiAgentConfig.updated_at.desc()) ) @@ -1380,7 +1380,7 @@ class AppService: stmt = ( select(AppRelease) - .where(AppRelease.app_id == app_id, AppRelease.is_active == True) + .where(AppRelease.app_id == app_id, AppRelease.is_active.is_(True)) .order_by(AppRelease.version.desc()) ) return list(self.db.scalars(stmt).all()) diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 46bda5f6..4f20f6d9 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -728,7 +728,7 @@ class DraftRunService: select(ModelApiKey) .where( ModelApiKey.model_config_id == model_config_id, - ModelApiKey.is_active == True + ModelApiKey.is_active.is_(True) ) .order_by(ModelApiKey.priority.desc()) .limit(1) diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 8170bdd8..7c8ee9ac 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -173,10 +173,9 @@ class MemoryAgentService: """ logger.info("Reading log file") - - current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py - app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory - project_root = 
os.path.dirname(app_dir) # redbear-mem directory + # Get log file path - use project root directory + from pathlib import Path + project_root = str(Path(__file__).resolve().parents[2]) # api directory log_path = os.path.join(project_root, "logs", "agent_service.log") summer = '' @@ -215,9 +214,8 @@ class MemoryAgentService: logger.info("Starting log content streaming") # Get log file path - use project root directory - current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py - app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory - project_root = os.path.dirname(app_dir) # redbear-mem directory + from pathlib import Path + project_root = str(Path(__file__).resolve().parents[2]) # api directory log_path = os.path.join(project_root, "logs", "agent_service.log") # Check if file exists before starting stream @@ -1079,9 +1077,8 @@ class MemoryAgentService: logger.info("Starting log content streaming") # Get log file path - use project root directory - current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py - app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory - project_root = os.path.dirname(app_dir) # redbear-mem directory + from pathlib import Path + project_root = str(Path(__file__).resolve().parents[2]) # api directory log_path = os.path.join(project_root, "logs", "agent_service.log") # Check if file exists before starting stream diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index 0ae2b965..2d3d047e 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -77,7 +77,10 @@ class MemoryAPIService: ) # Verify end_user belongs to the workspace via App relationship - app = self.db.query(App).filter(App.id == end_user.app_id).first() + app = self.db.query(App).filter( + App.id == end_user.app_id, + App.is_active.is_(True) + ).first() if not app: logger.warning(f"App not found for end_user: 
{end_user_id}") diff --git a/api/app/services/memory_reflection_service.py b/api/app/services/memory_reflection_service.py index 46e42b46..af72e3cc 100644 --- a/api/app/services/memory_reflection_service.py +++ b/api/app/services/memory_reflection_service.py @@ -38,7 +38,10 @@ class WorkspaceAppService: Returns: Dictionary containing detailed application information """ - apps = self.db.query(App).filter(App.workspace_id == workspace_id).all() + apps = self.db.query(App).filter( + App.workspace_id == workspace_id, + App.is_active.is_(True) + ).all() app_ids = [str(app.id) for app in apps] apps_detailed_info = [] diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 83d5923d..48c3abf1 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -237,7 +237,8 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) ValueError: 当配置无效或参数缺失时 RuntimeError: 当管线执行失败时 """ - project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from pathlib import Path + project_root = str(Path(__file__).resolve().parents[2]) try: # 发出初始进度事件 diff --git a/api/app/services/multi_agent_orchestrator.py b/api/app/services/multi_agent_orchestrator.py index 1972f344..4bcd28cd 100644 --- a/api/app/services/multi_agent_orchestrator.py +++ b/api/app/services/multi_agent_orchestrator.py @@ -2548,7 +2548,7 @@ class MultiAgentOrchestrator: # 获取 API Key 配置 api_key_config = self.db.query(ModelApiKey).filter( ModelApiKey.model_config_id == default_model_config_id, - ModelApiKey.is_active == True + ModelApiKey.is_active.is_(True) ).first() if not api_key_config: @@ -2705,7 +2705,7 @@ class MultiAgentOrchestrator: # 获取 API Key 配置 api_key_config = self.db.query(ModelApiKey).filter( ModelApiKey.model_config_id == default_model_config_id, - ModelApiKey.is_active == True + ModelApiKey.is_active.is_(True) ).first() if not api_key_config: diff --git 
a/api/app/services/multi_agent_service.py b/api/app/services/multi_agent_service.py index 1a08a5af..da984d16 100644 --- a/api/app/services/multi_agent_service.py +++ b/api/app/services/multi_agent_service.py @@ -74,7 +74,7 @@ class MultiAgentService: select(MultiAgentConfig) .where( MultiAgentConfig.app_id == app_id, - MultiAgentConfig.is_active == True + MultiAgentConfig.is_active.is_(True) ) .order_by(MultiAgentConfig.updated_at.desc()) ).first() @@ -144,7 +144,7 @@ class MultiAgentService: select(MultiAgentConfig) .where( MultiAgentConfig.app_id == app_id, - MultiAgentConfig.is_active == True + MultiAgentConfig.is_active.is_(True) ) .order_by(MultiAgentConfig.updated_at.desc()) ).first() diff --git a/api/app/services/shared_chat_service.py b/api/app/services/shared_chat_service.py index e5247e5e..5eee5edc 100644 --- a/api/app/services/shared_chat_service.py +++ b/api/app/services/shared_chat_service.py @@ -168,7 +168,7 @@ class SharedChatService: select(ModelApiKey) .where( ModelApiKey.model_config_id == model_config_id, - ModelApiKey.is_active == True + ModelApiKey.is_active.is_(True) ) .order_by(ModelApiKey.priority.desc()) .limit(1) @@ -362,7 +362,7 @@ class SharedChatService: select(ModelApiKey) .where( ModelApiKey.model_config_id == model_config_id, - ModelApiKey.is_active == True + ModelApiKey.is_active.is_(True) ) .order_by(ModelApiKey.priority.desc()) .limit(1) @@ -598,7 +598,7 @@ class SharedChatService: # 获取多 Agent 配置 multi_agent_config = self.db.query(MultiAgentConfig).filter( MultiAgentConfig.app_id == release.app_id, - MultiAgentConfig.is_active == True + MultiAgentConfig.is_active.is_(True) ).first() if not multi_agent_config: @@ -695,7 +695,7 @@ class SharedChatService: # 获取多 Agent 配置 multi_agent_config = self.db.query(MultiAgentConfig).filter( MultiAgentConfig.app_id == release.app_id, - MultiAgentConfig.is_active == True + MultiAgentConfig.is_active.is_(True) ).first() if not multi_agent_config: diff --git a/api/app/services/workflow_service.py 
b/api/app/services/workflow_service.py index b7d5df02..f9426c87 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -761,7 +761,10 @@ class WorkflowService: # 4. 获取工作空间 ID(从 app 获取) from app.models import App - app = self.db.query(App).filter(App.id == app_id).first() + app = self.db.query(App).filter( + App.id == app_id, + App.is_active.is_(True) + ).first() if not app: raise BusinessException( code=BizCode.NOT_FOUND, diff --git a/api/app/tasks.py b/api/app/tasks.py index fa9d1fdf..5f2b1ef5 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -635,8 +635,11 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]: try: workspace_uuid = uuid.UUID(workspace_id) - # 1. 查询当前workspace下的所有app - apps = db.query(App).filter(App.workspace_id == workspace_uuid).all() + # 1. 查询当前workspace下的所有app(仅未删除的) + apps = db.query(App).filter( + App.workspace_id == workspace_uuid, + App.is_active.is_(True) + ).all() if not apps: # 如果没有app,总量为0 diff --git a/api/migrations/env.py b/api/migrations/env.py index 95d74019..e4cd6dfb 100644 --- a/api/migrations/env.py +++ b/api/migrations/env.py @@ -46,7 +46,8 @@ def import_all_models_from_package(package_name: str): # Add the project root to sys.path if not already there # This is crucial for relative imports like 'app.db' to work - project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) + from pathlib import Path + project_root = str(Path(__file__).resolve().parent.parent) if project_root not in sys.path: sys.path.insert(0, project_root) From 6e18c92a130ceeb4c158bbeb9a267376d88ab74c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= <162269739+lanceyq@users.noreply.github.com> Date: Fri, 23 Jan 2026 12:21:28 +0800 Subject: [PATCH 02/28] Fix/optimize interface (#183) * [changes]Optimize the time
consumption of the "/end_users" interface * [fix]Optimize the time consumption of the "/hot_memory_tags" interface * [changes]Improve the code based on AI review --- .../memory_dashboard_controller.py | 146 +++++++++++++----- .../controllers/memory_storage_controller.py | 82 +++++++++- api/app/services/memory_dashboard_service.py | 87 ++++++++++- api/app/services/memory_storage_service.py | 86 +++++++++-- 4 files changed, 340 insertions(+), 61 deletions(-) diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index e03c1846..6181c319 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -49,63 +49,135 @@ async def get_workspace_end_users( current_user: User = Depends(get_current_user), ): """ - 获取工作空间的宿主列表 + 获取工作空间的宿主列表(高性能优化版本 v2) - 返回格式与原 memory_list 接口中的 end_users 字段相同, - 并包含每个用户的记忆配置信息(memory_config_id 和 memory_config_name) + 优化策略: + 1. 批量查询 end_users(一次查询而非循环) + 2. 并发查询所有用户的记忆数量(Neo4j) + 3. RAG 模式使用批量查询(一次 SQL) + 4. 只返回必要字段减少数据传输 + 5. 添加短期缓存减少重复查询 + 6. 
并发执行配置查询和记忆数量查询 + + 返回格式: + { + "end_user": {"id": "uuid", "other_name": "名称"}, + "memory_num": {"total": 数量}, + "memory_config": {"memory_config_id": "id", "memory_config_name": "名称"} + } """ + import asyncio + import json + from app.aioRedis import aio_redis_get, aio_redis_set + workspace_id = current_user.current_workspace_id + + # 尝试从缓存获取(30秒缓存) + cache_key = f"end_users:workspace:{workspace_id}" + try: + cached_data = await aio_redis_get(cache_key) + if cached_data: + api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}") + return success(data=json.loads(cached_data), msg="宿主列表获取成功") + except Exception as e: + api_logger.warning(f"Redis 缓存读取失败: {str(e)}") + # 获取当前空间类型 current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user) api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表") + + # 获取 end_users(已优化为批量查询) end_users = memory_dashboard_service.get_workspace_end_users( db=db, workspace_id=workspace_id, current_user=current_user ) - # 批量获取所有用户的记忆配置信息(优化:一次查询而非 N 次) - end_user_ids = [str(user.id) for user in end_users] - memory_configs_map = {} - if end_user_ids: + if not end_users: + api_logger.info("工作空间下没有宿主") + # 缓存空结果,避免重复查询 try: - memory_configs_map = get_end_users_connected_configs_batch(end_user_ids, db) + await aio_redis_set(cache_key, json.dumps([]), expire=30) + except Exception as e: + api_logger.warning(f"Redis 缓存写入失败: {str(e)}") + return success(data=[], msg="宿主列表获取成功") + + end_user_ids = [str(user.id) for user in end_users] + + # 并发执行两个独立的查询任务 + async def get_memory_configs(): + """获取记忆配置(在线程池中执行同步查询)""" + try: + return await asyncio.to_thread( + get_end_users_connected_configs_batch, + end_user_ids, db + ) except Exception as e: api_logger.error(f"批量获取记忆配置失败: {str(e)}") - # 失败时使用空字典,不影响其他数据返回 + return {} + async def get_memory_nums(): + """获取记忆数量""" + if current_workspace_type == "rag": + # RAG 模式:批量查询 + try: + chunk_map = await asyncio.to_thread( + 
memory_dashboard_service.get_users_total_chunk_batch, + end_user_ids, db, current_user + ) + return {uid: {"total": count} for uid, count in chunk_map.items()} + except Exception as e: + api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}") + return {uid: {"total": 0} for uid in end_user_ids} + + elif current_workspace_type == "neo4j": + # Neo4j 模式:并发查询(带并发限制) + # 使用信号量限制并发数,避免大量用户时压垮 Neo4j + MAX_CONCURRENT_QUERIES = 10 + semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES) + + async def get_neo4j_memory_num(end_user_id: str): + async with semaphore: + try: + return await memory_storage_service.search_all(end_user_id) + except Exception as e: + api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}") + return {"total": 0} + + memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids]) + return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))} + + return {uid: {"total": 0} for uid in end_user_ids} + + # 并发执行配置查询和记忆数量查询 + memory_configs_map, memory_nums_map = await asyncio.gather( + get_memory_configs(), + get_memory_nums() + ) + + # 构建结果(优化:使用列表推导式) result = [] for end_user in end_users: - memory_num = {} - if current_workspace_type == "neo4j": - # EndUser 是 Pydantic 模型,直接访问属性而不是使用 .get() - memory_num = await memory_storage_service.search_all(str(end_user.id)) - elif current_workspace_type == "rag": - memory_num = { - "total":memory_dashboard_service.get_current_user_total_chunk(str(end_user.id), db, current_user) - } - - # 从批量查询结果中获取配置信息 user_id = str(end_user.id) - memory_config_info = memory_configs_map.get(user_id, { - "memory_config_id": None, - "memory_config_name": None - }) - - # 只保留需要的字段,移除 error 字段(如果有) - memory_config = { - "memory_config_id": memory_config_info.get("memory_config_id"), - "memory_config_name": memory_config_info.get("memory_config_name") - } - - result.append( - { - 'end_user': end_user, - 'memory_num': memory_num, - 'memory_config': memory_config + config_info = 
memory_configs_map.get(user_id, {}) + result.append({ + 'end_user': { + 'id': user_id, + 'other_name': end_user.other_name + }, + 'memory_num': memory_nums_map.get(user_id, {"total": 0}), + 'memory_config': { + "memory_config_id": config_info.get("memory_config_id"), + "memory_config_name": config_info.get("memory_config_name") } - ) - + }) + + # 写入缓存(30秒过期) + try: + await aio_redis_set(cache_key, json.dumps(result), expire=30) + except Exception as e: + api_logger.warning(f"Redis 缓存写入失败: {str(e)}") + api_logger.info(f"成功获取 {len(end_users)} 个宿主记录") return success(data=result, msg="宿主列表获取成功") diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index f4175923..3722be3d 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -420,15 +420,95 @@ async def get_hot_memory_tags_api( db: Session = Depends(get_db), current_user: User = Depends(get_current_user), ) -> dict: - api_logger.info(f"Hot memory tags requested for current_user: {current_user.id}") + """ + 获取热门记忆标签(带Redis缓存) + + 缓存策略: + - 缓存键:workspace_id + limit + - 过期时间:5分钟(300秒) + - 缓存命中:~50ms + - 缓存未命中:~600-800ms(取决于LLM速度) + """ + workspace_id = current_user.current_workspace_id + + # 构建缓存键 + cache_key = f"hot_memory_tags:{workspace_id}:{limit}" + + api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}") + try: + # 尝试从Redis缓存获取 + from app.aioRedis import aio_redis_get, aio_redis_set + import json + + cached_result = await aio_redis_get(cache_key) + if cached_result: + api_logger.info(f"Cache hit for key: {cache_key}") + try: + data = json.loads(cached_result) + return success(data=data, msg="查询成功(缓存)") + except json.JSONDecodeError: + api_logger.warning(f"Failed to parse cached data, will refresh") + + # 缓存未命中,执行查询 + api_logger.info(f"Cache miss for key: {cache_key}, executing query") result = await analytics_hot_memory_tags(db, current_user, limit) + + # 
写入缓存(过期时间:5分钟) + # 注意:result是列表,需要转换为JSON字符串 + try: + cache_data = json.dumps(result, ensure_ascii=False) + await aio_redis_set(cache_key, cache_data, expire=300) + api_logger.info(f"Cached result for key: {cache_key}") + except Exception as cache_error: + # 缓存写入失败不影响主流程 + api_logger.warning(f"Failed to cache result: {str(cache_error)}") + return success(data=result, msg="查询成功") + except Exception as e: api_logger.error(f"Hot memory tags failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e)) +@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse) +async def clear_hot_memory_tags_cache( + current_user: User = Depends(get_current_user), + ) -> dict: + """ + 清除热门标签缓存 + + 用于: + - 手动刷新数据 + - 调试和测试 + - 数据更新后立即生效 + """ + workspace_id = current_user.current_workspace_id + + api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}") + + try: + from app.aioRedis import aio_redis_delete + + # 清除所有limit的缓存(常见的limit值) + cleared_count = 0 + for limit in [5, 10, 15, 20, 30, 50]: + cache_key = f"hot_memory_tags:{workspace_id}:{limit}" + result = await aio_redis_delete(cache_key) + if result: + cleared_count += 1 + api_logger.info(f"Cleared cache for key: {cache_key}") + + return success( + data={"cleared_count": cleared_count}, + msg=f"成功清除 {cleared_count} 个缓存" + ) + + except Exception as e: + api_logger.error(f"Clear cache failed: {str(e)}") + return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e)) + + @router.get("/analytics/recent_activity_stats", response_model=ApiResponse) async def get_recent_activity_stats_api( current_user: User = Depends(get_current_user), diff --git a/api/app/services/memory_dashboard_service.py b/api/app/services/memory_dashboard_service.py index a774647e..06a94060 100644 --- a/api/app/services/memory_dashboard_service.py +++ b/api/app/services/memory_dashboard_service.py @@ -53,18 +53,28 @@ def get_workspace_end_users( workspace_id: uuid.UUID, current_user: User ) -> List[EndUser]: 
- """获取工作空间的所有宿主""" + """获取工作空间的所有宿主(优化版本:减少数据库查询次数)""" business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}") try: - # 查询应用(ORM)并转换为 Pydantic 模型 + # 查询应用(ORM) apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id) - apps = [AppSchema.model_validate(h) for h in apps_orm] - app_ids = [app.id for app in apps] - end_users = [] - for app_id in app_ids: - end_user_orm_list = end_user_repository.get_end_users_by_app_id(db, app_id) - end_users.extend([EndUserSchema.model_validate(h) for h in end_user_orm_list]) + + if not apps_orm: + business_logger.info("工作空间下没有应用") + return [] + + # 提取所有 app_id + app_ids = [app.id for app in apps_orm] + + # 批量查询所有 end_users(一次查询而非循环查询) + from app.models.end_user_model import EndUser as EndUserModel + end_users_orm = db.query(EndUserModel).filter( + EndUserModel.app_id.in_(app_ids) + ).all() + + # 转换为 Pydantic 模型(只在需要时转换) + end_users = [EndUserSchema.model_validate(eu) for eu in end_users_orm] business_logger.info(f"成功获取 {len(end_users)} 个宿主记录") return end_users @@ -414,6 +424,67 @@ def get_current_user_total_chunk( business_logger.error(f"获取用户总chunk数失败: end_user_id={end_user_id} - {str(e)}") raise + +def get_users_total_chunk_batch( + end_user_ids: List[str], + db: Session, + current_user: User +) -> dict: + """ + 批量获取多个用户的总chunk数(性能优化版本) + + Args: + end_user_ids: 用户ID列表 + db: 数据库会话 + current_user: 当前用户 + + Returns: + 字典,key为end_user_id,value为chunk总数 + 格式: {"user_id_1": 100, "user_id_2": 50, ...} + """ + business_logger.info(f"批量获取 {len(end_user_ids)} 个用户的总chunk数, 操作者: {current_user.username}") + + try: + from app.models.document_model import Document + from sqlalchemy import func, case + + if not end_user_ids: + return {} + + # 构造所有文件名 + file_names = [f"{user_id}.txt" for user_id in end_user_ids] + + # 一次查询获取所有用户的chunk总数 + # 使用 GROUP BY file_name 来分组统计 + results = db.query( + Document.file_name, + func.sum(Document.chunk_num).label('total_chunk') + ).filter( + 
Document.file_name.in_(file_names) + ).group_by( + Document.file_name + ).all() + + # 构建结果字典 + chunk_map = {} + for file_name, total_chunk in results: + # 从文件名中提取 end_user_id (去掉 .txt 后缀) + user_id = file_name.replace('.txt', '') + chunk_map[user_id] = int(total_chunk or 0) + + # 对于没有记录的用户,设置为0 + for user_id in end_user_ids: + if user_id not in chunk_map: + chunk_map[user_id] = 0 + + business_logger.info(f"成功批量获取 {len(chunk_map)} 个用户的总chunk数") + return chunk_map + + except Exception as e: + business_logger.error(f"批量获取用户总chunk数失败: {str(e)}") + raise + + def get_rag_content( end_user_id: str, limit: int, diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 48c3abf1..c276f337 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -12,7 +12,11 @@ from datetime import datetime from typing import Any, AsyncGenerator, Dict, List, Optional from app.core.logging_config import get_config_logger, get_logger -from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags +from app.core.memory.analytics.hot_memory_tags import ( + get_hot_memory_tags, + get_raw_tags_from_db, + filter_tags_with_llm, +) from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats from app.models.user_model import User from app.repositories.data_config_repository import DataConfigRepository @@ -515,27 +519,79 @@ async def analytics_hot_memory_tags( ) -> List[Dict[str, Any]]: """ 获取热门记忆标签,按数量排序并返回前N个 + + 优化策略: + 1. 先从所有用户收集原始标签(不调用LLM) + 2. 聚合并合并相同标签的频率 + 3. 排序后取前N个 + 4. 
只调用一次LLM进行筛选 """ workspace_id = current_user.current_workspace_id # 获取更多标签供LLM筛选(获取limit*4个标签) raw_limit = limit * 4 from app.services.memory_dashboard_service import get_workspace_end_users - end_users = get_workspace_end_users(db, workspace_id, current_user) + # 使用 asyncio.to_thread 避免阻塞事件循环 + end_users = await asyncio.to_thread(get_workspace_end_users, db, workspace_id, current_user) - tags = [] - for end_user in end_users: - tag = await get_hot_memory_tags(str(end_user.id), limit=raw_limit) - if tag: - # 将每个用户的标签列表展平到总列表中 - tags.extend(tag) - - # 按频率降序排序(虽然数据库已经排序,但为了确保正确性再次排序) - sorted_tags = sorted(tags, key=lambda x: x[1], reverse=True) + if not end_users: + return [] - # 只返回前limit个 - top_tags = sorted_tags[:limit] - - return [{"name": t, "frequency": f} for t, f in top_tags] + # 步骤1: 收集所有用户的原始标签(不调用LLM) + connector = Neo4jConnector() + try: + all_raw_tags = [] + for end_user in end_users: + raw_tags = await get_raw_tags_from_db( + connector, + str(end_user.id), + limit=raw_limit, + by_user=False + ) + if raw_tags: + all_raw_tags.extend(raw_tags) + + if not all_raw_tags: + return [] + + # 步骤2: 聚合相同标签的频率 + tag_frequency_map = {} + for tag_name, frequency in all_raw_tags: + if tag_name in tag_frequency_map: + tag_frequency_map[tag_name] += frequency + else: + tag_frequency_map[tag_name] = frequency + + # 步骤3: 按频率降序排序,取前raw_limit个 + sorted_tags = sorted( + tag_frequency_map.items(), + key=lambda x: x[1], + reverse=True + )[:raw_limit] + + if not sorted_tags: + return [] + + # 步骤4: 只调用一次LLM进行筛选 + tag_names = [tag for tag, _ in sorted_tags] + + # 使用第一个用户的group_id来获取LLM配置 + # 因为同一工作空间下的用户应该使用相同的配置 + first_end_user_id = str(end_users[0].id) + filtered_tag_names = await filter_tags_with_llm(tag_names, first_end_user_id) + + # 步骤5: 根据LLM筛选结果构建最终列表(保留频率) + final_tags = [] + for tag, freq in sorted_tags: + if tag in filtered_tag_names: + final_tags.append((tag, freq)) + + # 步骤6: 只返回前limit个 + top_tags = final_tags[:limit] + + return [{"name": t, "frequency": f} for t, f 
in top_tags] + + finally: + await connector.close() async def analytics_recent_activity_stats() -> Dict[str, Any]: From 15f9c49418c93bd1521fd64baddc11f63b3ee6c9 Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Fri, 23 Jan 2026 12:21:54 +0800 Subject: [PATCH 03/28] Fix/memory mcp2 1 (#184) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 优化快速检索的回复内容 * 优化快速检索的回复内容 --- .../controllers/memory_agent_controller.py | 2 + .../utils/prompt/direct_summary_prompt.jinja2 | 61 +++++++++++++++++++ api/app/services/memory_agent_service.py | 2 +- 3 files changed, 64 insertions(+), 1 deletion(-) create mode 100644 api/app/core/memory/agent/utils/prompt/direct_summary_prompt.jinja2 diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index 78a5771f..8b5a55b9 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -306,6 +306,8 @@ async def read_server( config_id=config_id, db=db ) + if "信息不足,无法回答" in result['answer']: + result['answer']=retrieve_info return success(data=result, msg="回复对话消息成功") except BaseException as e: # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup diff --git a/api/app/core/memory/agent/utils/prompt/direct_summary_prompt.jinja2 b/api/app/core/memory/agent/utils/prompt/direct_summary_prompt.jinja2 new file mode 100644 index 00000000..1e0690bf --- /dev/null +++ b/api/app/core/memory/agent/utils/prompt/direct_summary_prompt.jinja2 @@ -0,0 +1,61 @@ +# 角色 +你是一个智能问答助手,基于检索信息和历史对话回答用户问题。 +# 任务 +根据提供的上下文信息回答用户的问题。 +# 输入信息 +- 历史对话:{{history}} +- 检索信息:{{retrieve_info}} +# 用户问题 +{{query}} +# 回答指南 +## 1. 仔细阅读检索信息 +- 答案可能直接或间接地出现在检索信息中 +- 如果检索信息中提到"小曼会使用Python",说明用户名是"小曼" +- 第三人称描述的偏好、行为通常指用户本人 + +## 2. 
判断信息相关性 +**情况A:信息匹配问题** +- 直接回答,像自然对话一样 +- 例:检索到"小曼会使用Python" → 问"我叫什么" → 答"你叫小曼" + +**情况B:信息部分相关** +- 先回答已知部分,再自然地询问更多信息 +- 例:检索到"用户去过上海的面包店" → 问"我吃过哪家面包" → 答"我记得你去过上海的面包店,但具体是哪家我不太清楚,是哪家呢?" + +**情况C:信息完全不相关** +- 自然地表达不知道,但可以提及检索到的相关信息,让对话更连贯 +- 使用友好的表达: + - "你好像没和我说过...,但是我知道你[检索到的相关信息]" + - "关于这个我不太清楚,不过我记得你[检索到的相关信息],能告诉我更多吗?" + - "我不记得你提到过...,但你[检索到的相关信息]" +- 即使检索信息不直接回答问题,也可以自然地融入对话中 +- 避免僵硬的"信息不足,无法回答" +## 3. 回答要求 +- 像人类对话一样自然流畅 +- 不要提及"检索信息"、"搜索结果"、"根据资料"等技术术语 +- 不要解释推理过程或引用信息来源 +- 保持友好、乐于助人的语气 +- 使用与问题相同的语言回答 +# 关键示例 +**示例1 - 直接匹配:** +- 检索信息:"小曼会使用Python..." +- 问题:"我叫什么" +- ✓ 正确:"你叫小曼" +- ✗ 错误:"你没有告诉我你的名字" +**示例2 - 间接匹配:** +- 检索信息:"用户很喜欢吃星巴克的甜品" +- 问题:"我喜欢什么" +- ✓ 正确:"你很喜欢吃星巴克的甜品" +- ✗ 错误:"信息不足" +**示例3 - 信息不匹配(推荐做法):** +- 检索信息:"用户只喝拿铁咖啡,认为美式咖啡太苦" +- 问题:"我吃过哪家面包" +- ✓ 最佳:"你好像没和我说过吃过哪家面包,但是我知道你喜欢喝拿铁,能跟我分享一下吗?" +- ✓ 可以:"你好像没和我说过吃过哪家面包,能跟我分享一下吗?" +- ✗ 错误:"用户只喝拿铁咖啡,认为美式咖啡太苦。"(答非所问) +- ✗ 错误:"信息不足,无法回答。"(太僵硬) +# 重要提醒 +- 检索信息中描述用户行为/偏好时提到的名字,就是用户的名字 +- 信息不匹配时,不要强行回答无关内容,但可以自然地提及检索到的信息,让对话更有温度 +- 用对话式语言表达"不知道",而非机械模板 +- 检索信息代表你对用户的了解,即使不直接回答问题,也能体现你对用户的记忆 diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 7c8ee9ac..a24456d2 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -729,7 +729,7 @@ class MemoryAgentService: state=state, history=history, retrieve_info=retrieve_info, - template_name='Retrieve_Summary_prompt.jinja2', + template_name='direct_summary_prompt.jinja2', operation_name='retrieve_summary', response_model=RetrieveSummaryResponse, search_mode="1" From 86812b34d1813373ea04fa50ffc0f79bc1d82c9c Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Fri, 23 Jan 2026 13:57:27 +0800 Subject: [PATCH 04/28] Fix/memory mcp2 1 (#185) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 优化快速检索的回复内容 * 优化快速检索的回复内容 * 路径的BUG修复 * 路径的BUG修复 * 路径的BUG修复 * 路径的BUG修复 * 路径的BUG修复 --- 
.../core/memory/agent/langgraph_graph/nodes/problem_nodes.py | 2 +- .../core/memory/agent/langgraph_graph/nodes/summary_nodes.py | 2 +- .../memory/agent/langgraph_graph/nodes/verification_nodes.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py index 697a13bd..2bad650a 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py @@ -14,7 +14,7 @@ from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin -template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt') +template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') db_session = next(get_db()) logger = get_agent_logger(__name__) diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py index 44f89c6a..f05a5ae1 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py @@ -19,7 +19,7 @@ from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.template_tools import TemplateService from app.db import get_db -template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt') +template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') logger = get_agent_logger(__name__) db_session = next(get_db()) diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py index dac7ea14..10ce8db4 100644 --- 
a/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py @@ -12,7 +12,7 @@ from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin -template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt') +template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') db_session = next(get_db()) logger = get_agent_logger(__name__) From 313f19eba4940c381b36cde243846254c356f19a Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Fri, 23 Jan 2026 14:49:44 +0800 Subject: [PATCH 05/28] Fix/memory mcp2 1 (#188) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 优化快速检索的回复内容 * 优化快速检索的回复内容 * 路径的BUG修复 * 路径的BUG修复 * 路径的BUG修复 * 路径的BUG修复 * 路径的BUG修复 * LLM生存缺少config_id认证,修复BUG * LLM生存缺少config_id认证,修复BUG * LLM生存缺少config_id认证,修复BUG --- .../controllers/memory_agent_controller.py | 3 +-- api/app/services/memory_agent_service.py | 20 ++++++++++++++++--- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index 8b5a55b9..c54fb02b 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -262,9 +262,7 @@ async def read_server( """ config_id = user_input.config_id workspace_id = current_user.current_workspace_id - api_logger.info(f"Read service: workspace_id={workspace_id}, config_id={config_id}") - # 获取 storage_type,如果为 None 则使用默认值 storage_type = workspace_service.get_workspace_storage_type( db=db, workspace_id=workspace_id, @@ -300,6 +298,7 @@ async def read_server( # 调用 memory_agent_service 的方法生成最终答案 result['answer'] = await 
memory_agent_service.generate_summary_from_retrieve( + group_id=user_input.group_id, retrieve_info=retrieve_info, history=history, query=query, diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index a24456d2..83b6bdd7 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -410,8 +410,8 @@ class MemoryAgentService: # Resolve config_id if None using end_user's connected config if config_id is None: try: - connected_config = get_end_user_connected_config(group_id, db) - config_id = connected_config.get("memory_config_id") + config_id = get_end_user_connected_config(group_id, db) + config_id=config_id.get('memory_config_id') if config_id is None: raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.") except Exception as e: @@ -670,6 +670,8 @@ class MemoryAgentService: """ logger.info("Classifying message type") + + # Load configuration to get LLM model ID config_service = MemoryConfigService(db) memory_config = config_service.load_memory_config( @@ -683,6 +685,7 @@ class MemoryAgentService: async def generate_summary_from_retrieve( self, + group_id: str, retrieve_info: str, history: List[Dict], query: str, @@ -704,6 +707,18 @@ class MemoryAgentService: Returns: 生成的答案文本 """ + if config_id is None: + try: + config_id = get_end_user_connected_config(group_id, db) + config_id = config_id.get('memory_config_id') + if config_id is None: + raise ValueError( + f"No memory configuration found for end_user {group_id}. 
Please ensure the user has a connected memory configuration.") + except Exception as e: + if "No memory configuration found" in str(e): + raise # Re-raise our specific error + logger.error(f"Failed to get connected config for end_user {group_id}: {e}") + raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}") logger.info(f"Generating summary from retrieve info for query: {query[:50]}...") try: @@ -713,7 +728,6 @@ class MemoryAgentService: config_id=config_id, service_name="MemoryAgentService" ) - # 导入必要的模块 from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import summary_llm from app.core.memory.agent.models.summary_models import RetrieveSummaryResponse From c115bcde545c03f61d9b2c6853b8e017a010e80d Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Fri, 23 Jan 2026 16:58:55 +0800 Subject: [PATCH 06/28] feat(home page): version description update --- api/app/core/config.py | 2 +- .../core/tools/builtin/baidu_search_tool.py | 4 +-- api/app/version_info.json | 30 +++++++++++++++++++ 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/api/app/core/config.py b/api/app/core/config.py index 59c6ff5f..3be6f849 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -184,7 +184,7 @@ class Settings: ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true" # official environment system version - SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0") + SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1") # workflow config WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600)) diff --git a/api/app/core/tools/builtin/baidu_search_tool.py b/api/app/core/tools/builtin/baidu_search_tool.py index 02431aed..45d4c359 100644 --- a/api/app/core/tools/builtin/baidu_search_tool.py +++ b/api/app/core/tools/builtin/baidu_search_tool.py @@ -16,7 +16,7 @@ class BaiduSearchTool(BuiltinTool): @property def description(self) -> str: - 
return "百度搜索 - 搜索引擎服务:网页搜索、新闻搜索、图片搜索、实时结果" + return "百度搜索 - 搜索引擎服务:网页搜索、新闻搜索、图片搜索、视频搜索" def get_required_config_parameters(self) -> List[str]: return ["api_key"] @@ -33,7 +33,7 @@ class BaiduSearchTool(BuiltinTool): ToolParameter( name="search_type", type=ParameterType.STRING, - description="搜索类型", + description="搜索类型, web: 网页搜索;news:新闻搜索;image:图片搜索;video视频搜索", required=False, default="web", enum=["web", "news", "image", "video"] diff --git a/api/app/version_info.json b/api/app/version_info.json index 20896845..bee52989 100644 --- a/api/app/version_info.json +++ b/api/app/version_info.json @@ -1,4 +1,34 @@ { + "v0.2.1": { + "introduction": { + "codeName": "启知", + "releaseDate": "2026-1-23", + "upgradePosition": "\uD83D\uDC3B 本次更新主要优化使用体验和修复已知问题,让系统更稳定、更好用。", + "coreUpgrades": [ + "1. 工作流更好用了\n* 界面更清晰,一眼看懂怎么配置\n* 新增节点输出变量展示,方便其他节点引用\n* 修复了几个影响体验的bug", + "2. 智能体配置更简单\n* 提示词和变量联动更顺畅\n* 配置界面重新整理,找功能更方便", + "3. 记忆系统更稳定\n* 优化了情绪记忆和隐性记忆的缓存更新\n* 修复了记忆配置页面的报错问题\n* 现在能自动识别用户和AI的身份了", + "4. 知识库体验提升\n* 修复了文档解析异常的问题\n* 上传文档时能看到处理进度了\n* 取消了操作也不会报错了", + "5. 系统整体更可靠\n* 修复了新用户访问跳转问题\n* 流式接口更稳定,长对话不断线\n* 调整了菜单顺序,操作更顺手\n", + "这次更新虽然不大,但让记忆熊的基础更扎实、体验更流畅。我们继续努力,让AI记忆更好用!", + "记忆熊,记得更牢,用得更好。\uD83D\uDC3B✨" + ] + }, + "introduction_en": { + "codeName": "Qizhi", + "releaseDate": "2026-1-23", + "upgradePosition": "\uD83D\uDC3B This update focuses on improving usability and fixing known issues, making the system more stable and easier to use overall.", + "coreUpgrades": [ + "1. Improved Workflow Experience\nCleaner, more intuitive UI for easier configuration at a glance\nAdded visibility of node output variables, making them easier to reference in downstream nodes\nFixed several usability-related bugs that affected the workflow experience", + "2. Simpler Agent Configuration\nSmoother linkage between prompts and variables\nReorganized configuration layout for easier navigation and better clarity", + "3. 
More Stable Memory System\nOptimized cache refresh for emotional memory and implicit memory\nFixed error issues on the memory configuration page\nThe system can now automatically distinguish between user and AI roles", + "4. Enhanced Knowledge Base Experience\nFixed issues with document parsing failures\nUpload progress is now displayed during document processing\nCanceling an upload no longer triggers errors", + "5. Overall System Reliability Improvements\nFixed redirect issues affecting new users\nImproved stability of streaming APIs to prevent interruptions during long conversations\nAdjusted menu ordering for a smoother and more intuitive workflow\n", + "Although this is a relatively small update, it strengthens MemoryBear’s foundation and delivers a noticeably smoother experience.\nWe’ll keep refining the system to make AI memory more powerful and easier to use.", + "MemoryBear — remember better, work smarter. \uD83D\uDC3B✨" + ] + } + }, "v0.2.0": { "introduction": { "codeName": "启知", From 191958075922b4c3db665a45c39b2683a93718f9 Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Fri, 23 Jan 2026 17:12:21 +0800 Subject: [PATCH 07/28] Fix/memory mcp2 1 (#190) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 优化快速检索的回复内容 * 优化快速检索的回复内容 * 路径的BUG修复 * 路径的BUG修复 * 路径的BUG修复 * 路径的BUG修复 * 路径的BUG修复 * LLM生存缺少config_id认证,修复BUG * LLM生存缺少config_id认证,修复BUG * LLM生存缺少config_id认证,修复BUG * 深度检索优化,搜索不到数据/提问的概念过于蘑菇,以引导的方式继续提问 * 深度检索优化,搜索不到数据/提问的概念过于蘑菇,以引导的方式继续提问 * 深度检索优化,搜索不到数据/提问的概念过于蘑菇,以引导的方式继续提问 --- .../langgraph_graph/nodes/summary_nodes.py | 22 ++++++++-- .../utils/prompt/fail_summary_prompt.jinja2 | 43 +++++++++++++++++++ api/app/services/memory_agent_service.py | 3 +- 3 files changed, 63 insertions(+), 5 deletions(-) create mode 100644 api/app/core/memory/agent/utils/prompt/fail_summary_prompt.jinja2 diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py 
b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py index f05a5ae1..fb0484d2 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py @@ -236,7 +236,7 @@ async def Retrieve_Summary(state: ReadState)-> ReadState: retrieve_info_str='\n'.join(retrieve_info_str) aimessages=await summary_llm(state,history,retrieve_info_str, - 'Retrieve_Summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1") + 'direct_summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1") if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": await summary_redis_save(state, aimessages) if aimessages == '': @@ -276,7 +276,6 @@ async def Summary(state: ReadState)-> ReadState: aimessages=await summary_llm(state,history,data, 'summary_prompt.jinja2','summary',SummaryResponse,0) - if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": await summary_redis_save(state, aimessages) if aimessages == '': @@ -295,9 +294,26 @@ async def Summary(state: ReadState)-> ReadState: async def Summary_fails(state: ReadState)-> ReadState: storage_type=state.get("storage_type", '') user_rag_memory_id=state.get("user_rag_memory_id", '') + history = await summary_history(state) + query = state.get("data", '') + verify = state.get("verify", '') + verify_expansion_issue = verify.get("verified_data", '') + retrieve_info_str = '' + for data in verify_expansion_issue: + for key, value in data.items(): + if key == 'answer_small': + for i in value: + retrieve_info_str += i + '\n' + data = { + "query": query, + "history": history, + "retrieve_info": retrieve_info_str + } + aimessages = await summary_llm(state, history, data, + 'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0) result= { "status": "success", - "summary_result": "没有相关数据", + "summary_result": aimessages, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id } diff --git 
a/api/app/core/memory/agent/utils/prompt/fail_summary_prompt.jinja2 b/api/app/core/memory/agent/utils/prompt/fail_summary_prompt.jinja2 new file mode 100644 index 00000000..3744f99b --- /dev/null +++ b/api/app/core/memory/agent/utils/prompt/fail_summary_prompt.jinja2 @@ -0,0 +1,43 @@ +{# 角色定义 #} +你是专业的问题解答专家+引导学者 + +{# 输入数据展示 #} +{% if data %} +## 输入数据 +上下文信息: +{% for item in data.history %} +- {{ item }} +{% endfor %} +检索到的所有信息: +{% for item in data.retrieve_info %} +- {{ item }} +{% endfor %} +{% endif %} + +## User Query +{{ query }} + +{# 问题回答标准 #} +## 问题回答核心标准 +根据上下文信息(history)和检索到的所有信息(retrieve_info)准确回答用户的问题(query)。 +注意,仔细阅读检索信息,答案可能直接或间接地出现在检索信息中或者历史上下文消息中,同时需要 判断信息相关性 +**情况A:信息匹配问题** +- 直接回答,像自然对话一样 +- 例:检索到"小曼会使用Python" → 问"我叫什么" → 答"你叫小曼" + +**情况B:信息部分相关** +- 先回答已知部分,再自然地询问更多信息 +- 例:检索到"用户去过上海的面包店" → 问"我吃过哪家面包" → 答"我记得你去过上海的面包店,但具体是哪家我不太清楚,是哪家呢?" + +**情况C:信息完全不相关** +- 自然地表达不知道,但可以提及检索到的相关信息,让对话更连贯 +- 使用友好的表达: + - "你好像没和我说过...,但是我知道你[检索到的相关信息]" + - "关于这个我不太清楚,不过我记得你[检索到的相关信息],能告诉我更多吗?" 
+ - "我不记得你提到过...,但你[检索到的相关信息]" +- 即使检索信息不直接回答问题,也可以自然地融入对话中 +- 避免僵硬的"信息不足,无法回答" + +{# 重要提醒 #} +当检索以及上下文的历史信息都无法回答的时候,可引导对方进行提问/回答,或者进行其他引导 +当检索或者上下文中出现了,相似的问题,可以委婉,提醒对方,我记得刚刚提过这个问题,但是我自己不记得了,能在描述一次吗~以此为例 diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 83b6bdd7..1e1cde89 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -542,9 +542,8 @@ class MemoryAgentService: if intermediate_type == "search_result": query = intermediate.get('query', '') raw_results = intermediate.get('raw_results', {}) - reranked_results = raw_results.get('reranked_results', []) - try: + reranked_results = raw_results.get('reranked_results', []) statements = [statement['statement'] for statement in reranked_results.get('statements', [])] except Exception: statements = [] From 4f4f55d67fa6ff71bcd2b0272cbacf63c8dbac8c Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 26 Jan 2026 11:04:30 +0800 Subject: [PATCH 08/28] feat(web): memory related interface parameter transfer adjustment --- web/src/api/memory.ts | 48 +- web/src/views/MemoryConversation/index.tsx | 6 +- .../views/MemoryExtractionEngine/constant.ts | 604 +----------------- web/src/views/MemoryManagement/types.ts | 1 - .../components/PerceptualLastInfo.tsx | 11 +- 5 files changed, 36 insertions(+), 634 deletions(-) diff --git a/web/src/api/memory.ts b/web/src/api/memory.ts index bbd9f6b0..ff8e0435 100644 --- a/web/src/api/memory.ts +++ b/web/src/api/memory.ts @@ -116,20 +116,20 @@ export const getRagContent = (end_user_id: string) => { return request.get(`/dashboard/rag_content`, { end_user_id, limit: 20 }) } // Emotion distribution analysis -export const getWordCloud = (group_id: string) => { - return request.post(`/memory/emotion-memory/wordcloud`, { group_id, limit: 20 }) +export const getWordCloud = (end_user_id: string) => { + return request.post(`/memory/emotion-memory/wordcloud`, { end_user_id, limit: 20 }) } // 
High-frequency emotion keywords -export const getEmotionTags = (group_id: string) => { - return request.post(`/memory/emotion-memory/tags`, { group_id, limit: 20 }) +export const getEmotionTags = (end_user_id: string) => { + return request.post(`/memory/emotion-memory/tags`, { end_user_id, limit: 20 }) } // Emotion health index -export const getEmotionHealth = (group_id: string) => { - return request.post(`/memory/emotion-memory/health`, { group_id, limit: 20 }) +export const getEmotionHealth = (end_user_id: string) => { + return request.post(`/memory/emotion-memory/health`, { end_user_id }) } // Personalized suggestions -export const getEmotionSuggestions = (group_id: string) => { - return request.post(`/memory/emotion-memory/suggestions`, { group_id, limit: 20 }) +export const getEmotionSuggestions = (end_user_id: string) => { + return request.post(`/memory/emotion-memory/suggestions`, { end_user_id }) } export const generateSuggestions = (end_user_id: string) => { return request.post(`/memory/emotion-memory/generate_suggestions`, { end_user_id }) @@ -138,8 +138,8 @@ export const analyticsRefresh = (end_user_id: string) => { return request.post('/memory-storage/analytics/generate_cache', { end_user_id }) } // Forgetting stats -export const getForgetStats = (group_id: string) => { - return request.get(`/memory/forget-memory/stats`, { group_id }) +export const getForgetStats = (end_user_id: string) => { + return request.get(`/memory/forget-memory/stats`, { end_user_id }) } // Implicit Memory - Preferences export const getImplicitPreferences = (end_user_id: string) => { @@ -165,20 +165,20 @@ export const getShortTerm = (end_user_id: string) => { return request.get(`/memory/short/short_term`, { end_user_id }) } // Perceptual Memory - Visual memory -export const getPerceptualLastVisual = (end_user: string) => { - return request.get(`/memory/perceptual/${end_user}/last_visual`) +export const getPerceptualLastVisual = (end_user_id: string) => { + return 
request.get(`/memory/perceptual/${end_user_id}/last_visual`) } // Perceptual Memory - Audio memory -export const getPerceptualLastListen = (end_user: string) => { - return request.get(`/memory/perceptual/${end_user}/last_listen`) +export const getPerceptualLastListen = (end_user_id: string) => { + return request.get(`/memory/perceptual/${end_user_id}/last_listen`) } // Perceptual Memory - Text memory -export const getPerceptualLastText = (end_user: string) => { - return request.get(`/memory/perceptual/${end_user}/last_text`) +export const getPerceptualLastText = (end_user_id: string) => { + return request.get(`/memory/perceptual/${end_user_id}/last_text`) } // Perceptual Memory - Perceptual memory timeline -export const getPerceptualTimeline = (end_user: string) => { - return request.get(`/memory/perceptual/${end_user}/timeline`) +export const getPerceptualTimeline = (end_user_id: string) => { + return request.get(`/memory/perceptual/${end_user_id}/timeline`) } // Episodic Memory - Overview export const getEpisodicOverview = (data: { end_user_id: string; time_range: string; episodic_type: string; } ) => { @@ -201,14 +201,14 @@ export const getExplicitMemory = (end_user_id: string) => { export const getExplicitMemoryDetails = (data: { end_user_id: string, memory_id: string; }) => { return request.post(`/memory/explicit-memory/details`, data) } -export const getConversations = (end_user: string) => { - return request.get(`/memory/work/${end_user}/conversations`) +export const getConversations = (end_user_id: string) => { + return request.get(`/memory/work/${end_user_id}/conversations`) } -export const getConversationMessages = (end_user: string, conversation_id: string) => { - return request.get(`/memory/work/${end_user}/messages`, { conversation_id }) +export const getConversationMessages = (end_user_id: string, conversation_id: string) => { + return request.get(`/memory/work/${end_user_id}/messages`, { conversation_id }) } -export const getConversationDetail = 
(end_user: string, conversation_id: string) => { - return request.get(`/memory/work/${end_user}/detail`, { conversation_id }) +export const getConversationDetail = (end_user_id: string, conversation_id: string) => { + return request.get(`/memory/work/${end_user_id}/detail`, { conversation_id }) } export const forgetTrigger = (data: { max_merge_batch_size: number; min_days_since_access: number; end_user_id: string;}) => { return request.post(`/memory/forget-memory/trigger`, data) diff --git a/web/src/views/MemoryConversation/index.tsx b/web/src/views/MemoryConversation/index.tsx index 424b9878..66a66779 100644 --- a/web/src/views/MemoryConversation/index.tsx +++ b/web/src/views/MemoryConversation/index.tsx @@ -45,7 +45,7 @@ const searchSwitchList = [ ] export interface TestParams { - group_id: string; + end_user_id: string; message: string; search_switch: string; history: { role: string; content: string }[]; @@ -107,7 +107,7 @@ const MemoryConversation: FC = () => { setLoading(true) readService({ message: msg, - group_id: userId, + end_user_id: userId, search_switch: search_switch, history: [], }) @@ -204,7 +204,7 @@ const MemoryConversation: FC = () => { } )} > -
{log.title}
+
{log.title}
{log.type === 'problem_split' && Array.isArray(log.data) && log.data.length > 0 ? {log.data.map(vo => ( diff --git a/web/src/views/MemoryExtractionEngine/constant.ts b/web/src/views/MemoryExtractionEngine/constant.ts index d1b7b757..5939a1bc 100644 --- a/web/src/views/MemoryExtractionEngine/constant.ts +++ b/web/src/views/MemoryExtractionEngine/constant.ts @@ -1093,606 +1093,4 @@ export const groupDataByType = (data: any[], groupKey: string) => { }) return grouped -} - -export const mockTestResult = { - "generated_at": "2025-12-12T09:48:43.389893", - "entities": { - "extracted_count": 148 - }, - "dedup": { - "total_merged_count": 39, - "breakdown": { - "exact": 30, - "fuzzy": 0, - "llm": 9 - }, - "impact": [ - { - "name": "记忆熊", - "type": "Person", - "appear_count": 9, - "merge_count": 8 - }, - { - "name": "宋朝", - "type": "Organization", - "appear_count": 5, - "merge_count": 2 - }, - { - "name": "军费", - "type": "EconomicMetric", - "appear_count": 2, - "merge_count": 1 - }, - { - "name": "学生", - "type": "Person", - "appear_count": 6, - "merge_count": 5 - }, - { - "name": "废除丞相制度", - "type": "Event", - "appear_count": 6, - "merge_count": 3 - }, - { - "name": "六部", - "type": "Organization", - "appear_count": 4, - "merge_count": 3 - }, - { - "name": "六部缺乏协调机制", - "type": "Concept", - "appear_count": 2, - "merge_count": 1 - }, - { - "name": "丞相", - "type": "Position", - "appear_count": 4, - "merge_count": 1 - }, - { - "name": "总理", - "type": "Position", - "appear_count": 2, - "merge_count": 1 - }, - { - "name": "各部委", - "type": "Organization", - "appear_count": 2, - "merge_count": 1 - }, - { - "name": "六部直接对皇帝负责", - "type": "AdministrativeStructure", - "appear_count": 2, - "merge_count": 1 - }, - { - "name": "秦国", - "type": "Organization", - "appear_count": 5, - "merge_count": 2 - }, - { - "name": "文官集团", - "type": "Organization", - "appear_count": 2, - "merge_count": 1 - } - ] - }, - "disambiguation": { - "block_count": 1, - "effects": [ - { - "left": { - "name": 
"节度使", - "type": "Role" - }, - "right": { - "name": "节度使", - "type": "Person" - }, - "result": "成功区分" - } - ] - }, - "memory": { - "chunks": 2 - }, - "triplets": { - "count": 88 - }, - "core_entities": [ - { - "type": "Organization", - "type_cn": "组织", - "count": 16, - "entities": [ - "厂卫机构", - "西厂", - "东厂", - "工部", - "地方军阀" - ] - }, - { - "type": "Event", - "type_cn": "事件", - "count": 12, - "entities": [ - "均田制瓦解", - "无法批阅完所有政务", - "废除丞相制度", - "持续战争", - "政令执行困难" - ] - }, - { - "type": "Condition", - "type_cn": "Condition", - "count": 9, - "entities": [ - "缺乏协作机制", - "作战效率低下", - "厢军装备不足", - "军权分散", - "军事专业化难以提升" - ] - }, - { - "type": "Person", - "type_cn": "人物", - "count": 8, - "entities": [ - "官员", - "宦官", - "节度使", - "皇帝", - "文士" - ] - }, - { - "type": "Concept", - "type_cn": "Concept", - "count": 8, - "entities": [ - "行政紧张", - "军力不足", - "秦国统一六国的原因", - "六部缺乏协调机制", - "专业分工" - ] - }, - { - "type": "Action", - "type_cn": "Action", - "count": 6, - "entities": [ - "再花钱募兵", - "建立军功爵制度", - "裁撤兵员", - "削减装备", - "建立法律制度" - ] - }, - { - "type": "Outcome", - "type_cn": "Outcome", - "count": 5, - "entities": [ - "打仗更吃亏", - "提升国家组织能力", - "降低行政效率", - "士兵效忠个人而非国家", - "政令推行困难" - ] - }, - { - "type": "EconomicMetric", - "type_cn": "EconomicMetric", - "count": 4, - "entities": [ - "财政", - "财政支出", - "支出", - "军费" - ] - }, - { - "type": "Statement", - "type_cn": "Statement", - "count": 3, - "entities": [ - "没有银子", - "禁军由文官控制导致作战效率低下", - "武器没材料" - ] - }, - { - "type": "State", - "type_cn": "State", - "count": 3, - "entities": [ - "军队更弱", - "理解不足", - "不足" - ] - }, - { - "type": "HistoricalPeriod", - "type_cn": "HistoricalPeriod", - "count": 3, - "entities": [ - "春秋战国史", - "唐朝史", - "宋朝" - ] - }, - { - "type": "Attribute", - "type_cn": "Attribute", - "count": 3, - "entities": [ - "资源丰富", - "易守难攻", - "政策连续性强" - ] - }, - { - "type": "Right", - "type_cn": "Right", - "count": 3, - "entities": [ - "军事指挥权", - "财政调度权", - "募兵权" - ] - }, - { - "type": "Policy", - "type_cn": "Policy", - "count": 2, 
- "entities": [ - "商鞅变法", - "禁军由文官控制" - ] - }, - { - "type": "MilitaryCondition", - "type_cn": "MilitaryCondition", - "count": 2, - "entities": [ - "军力不足", - "缺乏战略纵深" - ] - }, - { - "type": "Role", - "type_cn": "Role", - "count": 2, - "entities": [ - "节度使", - "协调中枢" - ] - }, - { - "type": "Position", - "type_cn": "Position", - "count": 2, - "entities": [ - "总理", - "丞相" - ] - }, - { - "type": "PoliticalCharacteristic", - "type_cn": "PoliticalCharacteristic", - "count": 2, - "entities": [ - "旧贵族势力弱", - "中央集权程度高" - ] - }, - { - "type": "Phenomenon", - "type_cn": "Phenomenon", - "count": 1, - "entities": [ - "宋朝军事弱势" - ] - }, - { - "type": "Factor", - "type_cn": "Factor", - "count": 1, - "entities": [ - "制度性因素" - ] - }, - { - "type": "EconomicFactor", - "type_cn": "EconomicFactor", - "count": 1, - "entities": [ - "财政压力" - ] - }, - { - "type": "EconomicIndicator", - "type_cn": "EconomicIndicator", - "count": 1, - "entities": [ - "财政支出" - ] - }, - { - "type": "MilitaryStrategy", - "type_cn": "MilitaryStrategy", - "count": 1, - "entities": [ - "对外战略被动" - ] - }, - { - "type": "MilitaryCapability", - "type_cn": "MilitaryCapability", - "count": 1, - "entities": [ - "机动能力弱" - ] - }, - { - "type": "PersonGroup", - "type_cn": "PersonGroup", - "count": 1, - "entities": [ - "武将" - ] - }, - { - "type": "EconomicCondition", - "type_cn": "EconomicCondition", - "count": 1, - "entities": [ - "财政压力" - ] - }, - { - "type": "InstitutionalPolicy", - "type_cn": "InstitutionalPolicy", - "count": 1, - "entities": [ - "废除丞相制度" - ] - }, - { - "type": "StateOfAffairs", - "type_cn": "StateOfAffairs", - "count": 1, - "entities": [ - "中央决策高度集中于皇帝" - ] - }, - { - "type": "Institution", - "type_cn": "Institution", - "count": 1, - "entities": [ - "科举" - ] - }, - { - "type": "Function", - "type_cn": "Function", - "count": 1, - "entities": [ - "统筹大事小情" - ] - }, - { - "type": "AdministrativeStructure", - "type_cn": "AdministrativeStructure", - "count": 1, - "entities": [ - "六部直接对皇帝负责" - ] - }, - { - 
"type": "AdministrativeProblem", - "type_cn": "AdministrativeProblem", - "count": 1, - "entities": [ - "皇帝一人批不完政务" - ] - }, - { - "type": "Behavior", - "type_cn": "Behavior", - "count": 1, - "entities": [ - "互相推诿责任" - ] - }, - { - "type": "Resource", - "type_cn": "Resource", - "count": 1, - "entities": [ - "银子" - ] - }, - { - "type": "Situation", - "type_cn": "Situation", - "count": 1, - "entities": [ - "没人拍板" - ] - }, - { - "type": "HistoricalState", - "type_cn": "HistoricalState", - "count": 1, - "entities": [ - "秦国" - ] - }, - { - "type": "Location", - "type_cn": "地点", - "count": 1, - "entities": [ - "关中" - ] - }, - { - "type": "HistoricalEvent", - "type_cn": "HistoricalEvent", - "count": 1, - "entities": [ - "安史之乱" - ] - }, - { - "type": "PoliticalAction", - "type_cn": "PoliticalAction", - "count": 1, - "entities": [ - "中央整顿" - ] - }, - { - "type": "PoliticalPhenomenon", - "type_cn": "PoliticalPhenomenon", - "count": 1, - "entities": [ - "藩镇割据加剧" - ] - }, - { - "type": "EconomicEntity", - "type_cn": "EconomicEntity", - "count": 1, - "entities": [ - "中央财政" - ] - }, - { - "type": "System", - "type_cn": "System", - "count": 1, - "entities": [ - "募兵制" - ] - }, - { - "type": "WorkRole", - "type_cn": "WorkRole", - "count": 1, - "entities": [ - "掌控禁军" - ] - } - ], - "triplet_samples": [ - { - "subject": "记忆熊", - "predicate": "MENTIONS", - "predicate_cn": "提到", - "object": "宋朝军事弱势" - }, - { - "subject": "宋朝军事弱势", - "predicate": "RESULTED_IN", - "predicate_cn": "resulted in", - "object": "制度性因素" - }, - { - "subject": "记忆熊", - "predicate": "MENTIONS", - "predicate_cn": "提到", - "object": "禁军由文官控制导致作战效率低下" - }, - { - "subject": "禁军由文官控制", - "predicate": "RESULTED_IN", - "predicate_cn": "resulted in", - "object": "作战效率低下" - }, - { - "subject": "记忆熊", - "predicate": "MENTIONS", - "predicate_cn": "提到", - "object": "厢军装备不足" - }, - { - "subject": "记忆熊", - "predicate": "MENTIONS", - "predicate_cn": "提到", - "object": "宋朝" - }, - { - "subject": "记忆熊", - "predicate": "MENTIONS", - 
"predicate_cn": "提到", - "object": "军费" - } - ], - "self_reflexion": [ - { - "conflict": { - "data": [ - { - "id": "76be6d82d8804beda6baa3d3447d6cbc", - "statement": "学生对\"六部缺乏协调机制\"的具体影响表示理解不足。", - "group_id": "group_123", - "chunk_id": "4a0804127d35456f86d4f06e1fa458f7", - "created_at": "2025-12-12 09:48:00.166068", - "expired_at": null, - "valid_at": null, - "invalid_at": null, - "entity_ids": [] - } - ], - "conflict": true, - "conflict_memory": { - "id": "e268a6fff35543fab471986c188e023e", - "statement": "学生对\"六部缺乏协调机制\"的具体影响表示理解不足。", - "group_id": "group_123", - "chunk_id": "e6cb5f56020e4a8d925d148e1d2fbda0", - "created_at": "2025-12-12 09:48:00.166068", - "expired_at": null, - "valid_at": null, - "invalid_at": null, - "entity_ids": [] - } - }, - "reflexion": { - "reason": "同一学生在不同时间点重复提出对'六部缺乏协调机制'具体影响的理解困难,表明原有解释未能有效解决其认知障碍,存在记忆冗余与教学反馈失效的冲突。", - "solution": "保留后出现的记忆记录(chunk_id为4a0804127d35456f86d4f06e1fa458f7)作为最新学习状态,将其设为有效;将前次相同内容的记忆(id为e268a6fff35543fab471986c188e023e)标记为失效,避免重复干预,并基于后续完整解释优化知识呈现逻辑。" - }, - "resolved": { - "original_memory_id": "e268a6fff35543fab471986c188e023e", - "resolved_memory": { - "id": "e268a6fff35543fab471986c188e023e", - "statement": "学生对\"六部缺乏协调机制\"的具体影响表示理解不足。", - "group_id": "group_123", - "chunk_id": "e6cb5f56020e4a8d925d148e1d2fbda0", - "created_at": "2025-12-12 09:48:00.166068", - "expired_at": null, - "valid_at": null, - "invalid_at": "2025-12-12 09:48:00.166068", - "entity_ids": [] - } - } - } - ] - } \ No newline at end of file +} \ No newline at end of file diff --git a/web/src/views/MemoryManagement/types.ts b/web/src/views/MemoryManagement/types.ts index f926c6c8..55524462 100644 --- a/web/src/views/MemoryManagement/types.ts +++ b/web/src/views/MemoryManagement/types.ts @@ -23,7 +23,6 @@ export interface Memory { include_dialogue_context: boolean; max_context: string; lambda_mem: string; - lambda_mem: string; offset: string; state: boolean; created_at: string; diff --git 
a/web/src/views/UserMemoryDetail/components/PerceptualLastInfo.tsx b/web/src/views/UserMemoryDetail/components/PerceptualLastInfo.tsx index d3788a74..ef547742 100644 --- a/web/src/views/UserMemoryDetail/components/PerceptualLastInfo.tsx +++ b/web/src/views/UserMemoryDetail/components/PerceptualLastInfo.tsx @@ -59,6 +59,11 @@ const PerceptualLastInfo: FC<{ type: 'last_visual' | 'last_listen' | 'last_text' }) } + const handleDownload = () => { + if (!data.file_path) return + window.open(data.file_path, '_blank') + } + return ( // {data.file_name} ) : ( -
{data.file_name}
+
{data.file_name}
) ) : type === 'last_listen' && /\.(mp3|wav|ogg|m4a|aac)$/i.test(data.file_name) ? ( ) : ( -
{data.file_name}
+
{data.file_name}
) ) : ( -
No file
+
{t('empty.tableEmpty')}
)} From 36017378693906f1c89744ef9284671a72821e15 Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Mon, 26 Jan 2026 11:53:34 +0800 Subject: [PATCH 09/28] Fix/memory bug fix (#171) --- api/app/__init__.py | 0 .../controllers/emotion_config_controller.py | 7 +- api/app/controllers/emotion_controller.py | 38 +-- .../controllers/implicit_memory_controller.py | 80 +++--- .../controllers/memory_agent_controller.py | 52 ++-- .../controllers/memory_forget_controller.py | 25 +- .../memory_perceptual_controller.py | 66 ++--- .../memory_reflection_controller.py | 41 +-- .../controllers/memory_storage_controller.py | 5 +- .../controllers/memory_working_controller.py | 16 +- .../service/memory_api_controller.py | 2 +- .../controllers/user_memory_controllers.py | 18 +- api/app/core/agent/langchain_agent.py | 18 +- .../langgraph_graph/nodes/problem_nodes.py | 8 +- .../langgraph_graph/nodes/retrieve_nodes.py | 18 +- .../langgraph_graph/nodes/summary_nodes.py | 16 +- .../nodes/verification_nodes.py | 6 +- .../langgraph_graph/nodes/write_nodes.py | 17 +- .../agent/langgraph_graph/read_graph.py | 6 +- .../agent/langgraph_graph/tools/tool.py | 30 +-- .../agent/langgraph_graph/write_graph.py | 23 +- .../agent/services/parameter_builder.py | 6 +- .../memory/agent/services/search_service.py | 8 +- .../memory/agent/services/session_service.py | 18 +- .../core/memory/agent/utils/get_dialogs.py | 32 +-- api/app/core/memory/agent/utils/llm_tools.py | 10 +- api/app/core/memory/agent/utils/redis_tool.py | 26 +- .../core/memory/agent/utils/session_tools.py | 18 +- .../core/memory/agent/utils/write_tools.py | 16 +- .../core/memory/analytics/hot_memory_tags.py | 36 +-- .../analytics/implicit_memory/data_source.py | 4 +- .../memory/evaluation/dialogue_queries.py | 4 +- .../memory/evaluation/extraction_utils.py | 12 +- .../evaluation/locomo/locomo_benchmark.py | 26 +- .../memory/evaluation/locomo/locomo_test.py | 2 +- 
.../memory/evaluation/locomo/locomo_utils.py | 18 +- .../evaluation/locomo/qwen_search_eval.py | 22 +- .../longmemeval/qwen_search_eval.py | 58 ++--- .../evaluation/longmemeval/test_eval.py | 58 ++--- .../memory/evaluation/memsciqa/evaluate_qa.py | 12 +- .../evaluation/memsciqa/memsciqa-test.py | 12 +- api/app/core/memory/evaluation/run_eval.py | 18 +- .../core/memory/llm_tools/chunker_client.py | 24 +- api/app/core/memory/models/config_models.py | 4 +- api/app/core/memory/models/graph_models.py | 16 +- api/app/core/memory/models/message_models.py | 20 +- api/app/core/memory/src/search.py | 45 ++-- .../data_preprocessing/data_preprocessor.py | 10 +- .../deduplication/deduped_and_disamb.py | 18 +- .../deduplication/entity_dedup_llm.py | 18 +- .../deduplication/second_layer_dedup.py | 8 +- .../deduplication/two_stage_dedup.py | 14 +- .../extraction_orchestrator.py | 68 ++--- .../knowledge_extraction/memory_summary.py | 6 +- .../statement_extraction.py | 16 +- .../temporal_extraction.py | 2 +- .../triplet_extraction.py | 2 +- .../access_history_manager.py | 86 +++---- .../forgetting_engine/config_utils.py | 16 +- .../forgetting_engine/forgetting_scheduler.py | 31 +-- .../forgetting_engine/forgetting_strategy.py | 31 +-- .../storage_services/search/__init__.py | 6 +- .../storage_services/search/hybrid_search.py | 14 +- .../storage_services/search/keyword_search.py | 12 +- .../search/search_strategy.py | 10 +- .../search/semantic_search.py | 12 +- api/app/core/memory/utils/config/get_data.py | 4 +- api/app/core/memory/utils/log/audit_logger.py | 12 +- api/app/core/rag/vdb/field.py | 2 +- .../validators/memory_config_validators.py | 10 +- api/app/core/workflow/nodes/memory/config.py | 5 +- api/app/core/workflow/nodes/memory/node.py | 2 +- api/app/models/__init__.py | 4 +- api/app/models/data_config_model.py | 88 ------- api/app/models/memory_config_model.py | 119 ++++++--- api/app/models/memory_perceptual_model.py | 2 +- ...ository.py => memory_config_repository.py} | 
210 ++++++++-------- .../memory_perceptual_repository.py | 4 +- api/app/repositories/neo4j/add_edges.py | 4 +- api/app/repositories/neo4j/add_nodes.py | 22 +- .../neo4j/base_neo4j_repository.py | 2 +- api/app/repositories/neo4j/cypher_queries.py | 171 +++++-------- .../repositories/neo4j/dialog_repository.py | 34 +-- .../repositories/neo4j/emotion_repository.py | 24 +- api/app/repositories/neo4j/graph_saver.py | 12 +- api/app/repositories/neo4j/graph_search.py | 236 +++++++++--------- .../neo4j/memory_summary_repository.py | 48 ++-- api/app/repositories/neo4j/neo4j_connector.py | 48 +--- .../neo4j/statement_repository.py | 2 +- api/app/schemas/app_schema.py | 12 + api/app/schemas/emotion_schema.py | 11 +- api/app/schemas/memory_agent_schema.py | 6 +- api/app/schemas/memory_config_schema.py | 20 +- api/app/schemas/memory_perceptual_schema.py | 8 +- api/app/schemas/memory_reflection_schemas.py | 3 +- api/app/schemas/memory_storage_schema.py | 34 +-- api/app/schemas/model_schema.py | 14 +- api/app/schemas/release_share_schema.py | 14 +- api/app/services/draft_run_service.py | 2 +- api/app/services/emotion_analytics_service.py | 14 +- api/app/services/emotion_config_service.py | 16 +- .../services/emotion_extraction_service.py | 4 +- api/app/services/memory_agent_service.py | 150 +++++------ api/app/services/memory_api_service.py | 21 +- api/app/services/memory_base_service.py | 18 +- api/app/services/memory_config_service.py | 91 ++++--- .../memory_entity_relationship_service.py | 4 +- api/app/services/memory_episodic_service.py | 30 +-- api/app/services/memory_explicit_service.py | 16 +- api/app/services/memory_forget_service.py | 81 +++--- api/app/services/memory_konwledges_server.py | 14 +- api/app/services/memory_perceptual_service.py | 26 +- api/app/services/memory_reflection_service.py | 50 ++-- api/app/services/memory_storage_service.py | 52 ++-- api/app/services/pilot_run_service.py | 2 +- api/app/services/user_memory_service.py | 42 ++-- api/app/tasks.py | 
164 +++++++++--- api/app/utils/app_config_utils.py | 23 ++ api/uv.lock | 2 +- 119 files changed, 1711 insertions(+), 1695 deletions(-) create mode 100644 api/app/__init__.py delete mode 100644 api/app/models/data_config_model.py rename api/app/repositories/{data_config_repository.py => memory_config_repository.py} (73%) diff --git a/api/app/__init__.py b/api/app/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/app/controllers/emotion_config_controller.py b/api/app/controllers/emotion_config_controller.py index 76450d8a..b0015bc2 100644 --- a/api/app/controllers/emotion_config_controller.py +++ b/api/app/controllers/emotion_config_controller.py @@ -12,6 +12,7 @@ from fastapi import APIRouter, Depends, Query, HTTPException, status from pydantic import BaseModel, Field from typing import Optional from sqlalchemy.orm import Session +from uuid import UUID from app.core.response_utils import success from app.dependencies import get_current_user @@ -32,11 +33,11 @@ router = APIRouter( class EmotionConfigQuery(BaseModel): """情绪配置查询请求模型""" - config_id: int = Field(..., description="配置ID") + config_id: UUID = Field(..., description="配置ID") class EmotionConfigUpdate(BaseModel): """情绪配置更新请求模型""" - config_id: int = Field(..., description="配置ID") + config_id: UUID = Field(..., description="配置ID") emotion_enabled: bool = Field(..., description="是否启用情绪提取") emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID") emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词") @@ -45,7 +46,7 @@ class EmotionConfigUpdate(BaseModel): @router.get("/read_config", response_model=ApiResponse) def get_emotion_config( - config_id: int = Query(..., description="配置ID"), + config_id: UUID = Query(..., description="配置ID"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), ): diff --git a/api/app/controllers/emotion_controller.py b/api/app/controllers/emotion_controller.py index 154a3928..cd199aa7 100644 --- 
a/api/app/controllers/emotion_controller.py +++ b/api/app/controllers/emotion_controller.py @@ -53,7 +53,7 @@ async def get_emotion_tags( api_logger.info( f"用户 {current_user.username} 请求获取情绪标签统计", extra={ - "group_id": request.group_id, + "end_user_id": request.end_user_id, "emotion_type": request.emotion_type, "start_date": request.start_date, "end_date": request.end_date, @@ -63,7 +63,7 @@ async def get_emotion_tags( # 调用服务层 data = await emotion_service.get_emotion_tags( - end_user_id=request.group_id, + end_user_id=request.end_user_id, emotion_type=request.emotion_type, start_date=request.start_date, end_date=request.end_date, @@ -73,7 +73,7 @@ async def get_emotion_tags( api_logger.info( "情绪标签统计获取成功", extra={ - "group_id": request.group_id, + "end_user_id": request.end_user_id, "total_count": data.get("total_count", 0), "tags_count": len(data.get("tags", [])) } @@ -84,7 +84,7 @@ async def get_emotion_tags( except Exception as e: api_logger.error( f"获取情绪标签统计失败: {str(e)}", - extra={"group_id": request.group_id}, + extra={"end_user_id": request.end_user_id}, exc_info=True ) raise HTTPException( @@ -105,7 +105,7 @@ async def get_emotion_wordcloud( api_logger.info( f"用户 {current_user.username} 请求获取情绪词云数据", extra={ - "group_id": request.group_id, + "end_user_id": request.end_user_id, "emotion_type": request.emotion_type, "limit": request.limit } @@ -113,7 +113,7 @@ async def get_emotion_wordcloud( # 调用服务层 data = await emotion_service.get_emotion_wordcloud( - end_user_id=request.group_id, + end_user_id=request.end_user_id, emotion_type=request.emotion_type, limit=request.limit ) @@ -121,7 +121,7 @@ async def get_emotion_wordcloud( api_logger.info( "情绪词云数据获取成功", extra={ - "group_id": request.group_id, + "end_user_id": request.end_user_id, "total_keywords": data.get("total_keywords", 0) } ) @@ -131,7 +131,7 @@ async def get_emotion_wordcloud( except Exception as e: api_logger.error( f"获取情绪词云数据失败: {str(e)}", - extra={"group_id": request.group_id}, + extra={"end_user_id": 
request.end_user_id}, exc_info=True ) raise HTTPException( @@ -159,21 +159,21 @@ async def get_emotion_health( api_logger.info( f"用户 {current_user.username} 请求获取情绪健康指数", extra={ - "group_id": request.group_id, + "end_user_id": request.end_user_id, "time_range": request.time_range } ) # 调用服务层 data = await emotion_service.calculate_emotion_health_index( - end_user_id=request.group_id, + end_user_id=request.end_user_id, time_range=request.time_range ) api_logger.info( "情绪健康指数获取成功", extra={ - "group_id": request.group_id, + "end_user_id": request.end_user_id, "health_score": data.get("health_score", 0), "level": data.get("level", "未知") } @@ -186,7 +186,7 @@ async def get_emotion_health( except Exception as e: api_logger.error( f"获取情绪健康指数失败: {str(e)}", - extra={"group_id": request.group_id}, + extra={"end_user_id": request.end_user_id}, exc_info=True ) raise HTTPException( @@ -206,7 +206,7 @@ async def get_emotion_suggestions( """获取个性化情绪建议(从缓存读取) Args: - request: 包含 group_id 和可选的 config_id + request: 包含 end_user_id 和可选的 config_id db: 数据库会话 current_user: 当前用户 @@ -217,22 +217,22 @@ async def get_emotion_suggestions( api_logger.info( f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)", extra={ - "group_id": request.group_id, + "end_user_id": request.end_user_id, "config_id": request.config_id } ) # 从缓存获取建议 data = await emotion_service.get_cached_suggestions( - end_user_id=request.group_id, + end_user_id=request.end_user_id, db=db ) if data is None: # 缓存不存在或已过期 api_logger.info( - f"用户 {request.group_id} 的建议缓存不存在或已过期", - extra={"group_id": request.group_id} + f"用户 {request.end_user_id} 的建议缓存不存在或已过期", + extra={"end_user_id": request.end_user_id} ) return fail( BizCode.NOT_FOUND, @@ -243,7 +243,7 @@ async def get_emotion_suggestions( api_logger.info( "个性化建议获取成功(缓存)", extra={ - "group_id": request.group_id, + "end_user_id": request.end_user_id, "suggestions_count": len(data.get("suggestions", [])) } ) @@ -253,7 +253,7 @@ async def get_emotion_suggestions( except Exception as e: 
api_logger.error( f"获取个性化建议失败: {str(e)}", - extra={"group_id": request.group_id}, + extra={"end_user_id": request.end_user_id}, exc_info=True ) raise HTTPException( diff --git a/api/app/controllers/implicit_memory_controller.py b/api/app/controllers/implicit_memory_controller.py index a53290e2..96e437d6 100644 --- a/api/app/controllers/implicit_memory_controller.py +++ b/api/app/controllers/implicit_memory_controller.py @@ -122,10 +122,10 @@ def validate_confidence_threshold(threshold: float) -> None: raise ValueError("confidence_threshold must be between 0.0 and 1.0") -@router.get("/preferences/{user_id}", response_model=ApiResponse) +@router.get("/preferences/{end_user_id}", response_model=ApiResponse) @cur_workspace_access_guard() async def get_preference_tags( - user_id: str, + end_user_id: str, confidence_threshold: float = Query(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"), tag_category: Optional[str] = Query(None, description="Filter by tag category"), start_date: Optional[datetime] = Query(None, description="Filter start date"), @@ -137,7 +137,7 @@ async def get_preference_tags( Get user preference tags from cache. 
Args: - user_id: Target user ID + end_user_id: Target end user ID confidence_threshold: Minimum confidence score (0.0-1.0) tag_category: Optional category filter start_date: Optional start date filter @@ -146,20 +146,20 @@ async def get_preference_tags( Returns: List of preference tags from cache """ - api_logger.info(f"Preference tags requested for user: {user_id} (from cache)") + api_logger.info(f"Preference tags requested for user: {end_user_id} (from cache)") try: # Validate inputs - validate_user_id(user_id) + validate_user_id(end_user_id) # Create service with user-specific config - service = ImplicitMemoryService(db=db, end_user_id=user_id) + service = ImplicitMemoryService(db=db, end_user_id=end_user_id) # Get cached profile - cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db) + cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db) if cached_profile is None: - api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期") + api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期") return fail( BizCode.NOT_FOUND, "画像缓存不存在或已过期,请右上角刷新生成新画像", @@ -192,17 +192,17 @@ async def get_preference_tags( filtered_preferences.append(pref) - api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {user_id} (from cache)") + api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {end_user_id} (from cache)") return success(data=filtered_preferences, msg="偏好标签获取成功(缓存)") except Exception as e: - return handle_implicit_memory_error(e, "偏好标签获取", user_id) + return handle_implicit_memory_error(e, "偏好标签获取", end_user_id) -@router.get("/portrait/{user_id}", response_model=ApiResponse) +@router.get("/portrait/{end_user_id}", response_model=ApiResponse) @cur_workspace_access_guard() async def get_dimension_portrait( - user_id: str, + end_user_id: str, include_history: bool = Query(False, description="Include historical trends"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user) 
@@ -211,26 +211,26 @@ async def get_dimension_portrait( Get user's four-dimension personality portrait from cache. Args: - user_id: Target user ID + end_user_id: Target end user ID include_history: Whether to include historical trend data (ignored for cached data) Returns: Four-dimension personality portrait from cache """ - api_logger.info(f"Dimension portrait requested for user: {user_id} (from cache)") + api_logger.info(f"Dimension portrait requested for user: {end_user_id} (from cache)") try: # Validate inputs - validate_user_id(user_id) + validate_user_id(end_user_id) # Create service with user-specific config - service = ImplicitMemoryService(db=db, end_user_id=user_id) + service = ImplicitMemoryService(db=db, end_user_id=end_user_id) # Get cached profile - cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db) + cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db) if cached_profile is None: - api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期") + api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期") return fail( BizCode.NOT_FOUND, "画像缓存不存在或已过期,请右上角刷新生成新画像", @@ -240,17 +240,17 @@ async def get_dimension_portrait( # Extract portrait from cache portrait = cached_profile.get("portrait", {}) - api_logger.info(f"Dimension portrait retrieved for user: {user_id} (from cache)") + api_logger.info(f"Dimension portrait retrieved for user: {end_user_id} (from cache)") return success(data=portrait, msg="四维画像获取成功(缓存)") except Exception as e: - return handle_implicit_memory_error(e, "四维画像获取", user_id) + return handle_implicit_memory_error(e, "四维画像获取", end_user_id) -@router.get("/interest-areas/{user_id}", response_model=ApiResponse) +@router.get("/interest-areas/{end_user_id}", response_model=ApiResponse) @cur_workspace_access_guard() async def get_interest_area_distribution( - user_id: str, + end_user_id: str, include_trends: bool = Query(False, description="Include trend analysis"), db: Session = Depends(get_db), current_user: User 
= Depends(get_current_user) @@ -259,26 +259,26 @@ async def get_interest_area_distribution( Get user's interest area distribution from cache. Args: - user_id: Target user ID + end_user_id: Target end user ID include_trends: Whether to include trend analysis data (ignored for cached data) Returns: Interest area distribution from cache """ - api_logger.info(f"Interest area distribution requested for user: {user_id} (from cache)") + api_logger.info(f"Interest area distribution requested for user: {end_user_id} (from cache)") try: # Validate inputs - validate_user_id(user_id) + validate_user_id(end_user_id) # Create service with user-specific config - service = ImplicitMemoryService(db=db, end_user_id=user_id) + service = ImplicitMemoryService(db=db, end_user_id=end_user_id) # Get cached profile - cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db) + cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db) if cached_profile is None: - api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期") + api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期") return fail( BizCode.NOT_FOUND, "画像缓存不存在或已过期,请右上角刷新生成新画像", @@ -288,17 +288,17 @@ async def get_interest_area_distribution( # Extract interest areas from cache interest_areas = cached_profile.get("interest_areas", {}) - api_logger.info(f"Interest area distribution retrieved for user: {user_id} (from cache)") + api_logger.info(f"Interest area distribution retrieved for user: {end_user_id} (from cache)") return success(data=interest_areas, msg="兴趣领域分布获取成功(缓存)") except Exception as e: - return handle_implicit_memory_error(e, "兴趣领域分布获取", user_id) + return handle_implicit_memory_error(e, "兴趣领域分布获取", end_user_id) -@router.get("/habits/{user_id}", response_model=ApiResponse) +@router.get("/habits/{end_user_id}", response_model=ApiResponse) @cur_workspace_access_guard() async def get_behavior_habits( - user_id: str, + end_user_id: str, confidence_level: Optional[str] = Query(None, 
regex="^(high|medium|low)$", description="Filter by confidence level"), frequency_pattern: Optional[str] = Query(None, regex="^(daily|weekly|monthly|seasonal|occasional|event_triggered)$", description="Filter by frequency pattern"), time_period: Optional[str] = Query(None, regex="^(current|past)$", description="Filter by time period"), @@ -309,7 +309,7 @@ async def get_behavior_habits( Get user's behavioral habits from cache. Args: - user_id: Target user ID + end_user_id: Target end user ID confidence_level: Filter by confidence level (high, medium, low) frequency_pattern: Filter by frequency pattern (daily, weekly, monthly, seasonal, occasional, event_triggered) time_period: Filter by time period (current, past) @@ -317,20 +317,20 @@ async def get_behavior_habits( Returns: List of behavioral habits from cache """ - api_logger.info(f"Behavior habits requested for user: {user_id} (from cache)") + api_logger.info(f"Behavior habits requested for user: {end_user_id} (from cache)") try: # Validate inputs - validate_user_id(user_id) + validate_user_id(end_user_id) # Create service with user-specific config - service = ImplicitMemoryService(db=db, end_user_id=user_id) + service = ImplicitMemoryService(db=db, end_user_id=end_user_id) # Get cached profile - cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db) + cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db) if cached_profile is None: - api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期") + api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期") return fail( BizCode.NOT_FOUND, "画像缓存不存在或已过期,请右上角刷新生成新画像", @@ -368,11 +368,11 @@ async def get_behavior_habits( filtered_habits.append(habit) - api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {user_id} (from cache)") + api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {end_user_id} (from cache)") return success(data=filtered_habits, msg="行为习惯获取成功(缓存)") except Exception as e: - 
return handle_implicit_memory_error(e, "行为习惯获取", user_id) + return handle_implicit_memory_error(e, "行为习惯获取", end_user_id) diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index c54fb02b..61b16d9e 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -125,7 +125,7 @@ async def write_server( Write service endpoint - processes write operations synchronously Args: - user_input: Write request containing message and group_id + user_input: Write request containing message and end_user_id Returns: Response with write operation status @@ -160,19 +160,18 @@ async def write_server( api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") storage_type = 'neo4j' - api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") + api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") try: - # 获取标准化的消息列表 messages_list = memory_agent_service.get_messages_list(user_input) - result = await memory_agent_service.write_memory( - user_input.group_id, - messages_list, # 传递结构化消息列表 + user_input.end_user_id, + messages_list, config_id, db, storage_type, user_rag_memory_id ) + return success(data=result, msg="写入成功") except BaseException as e: # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup @@ -196,7 +195,7 @@ async def write_server_async( Async write service endpoint - enqueues write processing to Celery Args: - user_input: Write request containing message and group_id + user_input: Write request containing message and end_user_id Returns: Task ID for tracking async operation @@ -226,10 +225,10 @@ async def write_server_async( try: # 获取标准化的消息列表 messages_list = memory_agent_service.get_messages_list(user_input) - + task = celery_app.send_task( 
"app.core.memory.agent.write_message", - args=[user_input.group_id, messages_list, config_id, storage_type, user_rag_memory_id] + args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id] ) api_logger.info(f"Write task queued: {task.id}") @@ -255,7 +254,7 @@ async def read_server( - "2": Direct answer based on context Args: - user_input: Read request with message, history, search_switch, and group_id + user_input: Read request with message, history, search_switch, and end_user_id Returns: Response with query answer @@ -277,12 +276,13 @@ async def read_server( name="USER_RAG_MERORY", workspace_id=workspace_id ) - if knowledge: user_rag_memory_id = str(knowledge.id) + if knowledge: + user_rag_memory_id = str(knowledge.id) - api_logger.info(f"Read service: group={user_input.group_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}") + api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}") try: result = await memory_agent_service.read_memory( - user_input.group_id, + user_input.end_user_id, user_input.message, user_input.history, user_input.search_switch, @@ -293,12 +293,12 @@ async def read_server( ) if str(user_input.search_switch) == "2": retrieve_info = result['answer'] - history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id) + history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id) query = user_input.message - + # 调用 memory_agent_service 的方法生成最终答案 result['answer'] = await memory_agent_service.generate_summary_from_retrieve( - group_id=user_input.group_id, + end_user_id=user_input.end_user_id, retrieve_info=retrieve_info, history=history, query=query, @@ -404,7 +404,7 @@ async def read_server_async( try: task = celery_app.send_task( 
"app.core.memory.agent.read_message", - args=[user_input.group_id, user_input.message, user_input.history, user_input.search_switch, + args=[user_input.end_user_id, user_input.message, user_input.history, user_input.search_switch, config_id, storage_type, user_rag_memory_id] ) api_logger.info(f"Read task queued: {task.id}") @@ -448,7 +448,7 @@ async def get_read_task_result( return success( data={ "result": task_result.get("result"), - "group_id": task_result.get("group_id"), + "end_user_id": task_result.get("end_user_id"), "elapsed_time": task_result.get("elapsed_time"), "task_id": task_id }, @@ -525,7 +525,7 @@ async def get_write_task_result( return success( data={ "result": task_result.get("result"), - "group_id": task_result.get("group_id"), + "end_user_id": task_result.get("end_user_id"), "elapsed_time": task_result.get("elapsed_time"), "task_id": task_id }, @@ -579,16 +579,16 @@ async def status_type( Determine the type of user message (read or write) Args: - user_input: Request containing user message and group_id + user_input: Request containing user message and end_user_id Returns: Type classification result """ - api_logger.info(f"Status type check requested for group {user_input.group_id}") + api_logger.info(f"Status type check requested for group {user_input.end_user_id}") try: # 获取标准化的消息列表 messages_list = memory_agent_service.get_messages_list(user_input) - + # 将消息列表转换为字符串用于分类 # 只取最后一条用户消息进行分类 last_user_message = "" @@ -596,11 +596,11 @@ async def status_type( if msg.get('role') == 'user': last_user_message = msg.get('content', '') break - + if not last_user_message: # 如果没有用户消息,使用所有消息的内容 last_user_message = " ".join([msg.get('content', '') for msg in messages_list]) - + result = await memory_agent_service.classify_message_type( last_user_message, user_input.config_id, @@ -625,7 +625,7 @@ async def get_knowledge_type_stats_api( 会对缺失类型补 0,返回字典形式。 可选按状态过滤。 - 知识库类型根据当前用户的 current_workspace_id 过滤 - - memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (group_id) 过滤 
+ - memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤 - 如果用户没有当前工作空间或未提供 end_user_id,对应的统计返回 0 """ api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}") @@ -698,7 +698,7 @@ async def get_user_profile_api( current_user: User = Depends(get_current_user) ): """ - 获取工作空间下Popular Memory Tags,包含: + 获取用户详情,包含: - name: 用户名字(直接使用 end_user_id) - tags: 3个用户特征标签(从语句和实体中LLM总结) - hot_tags: 4个热门记忆标签 diff --git a/api/app/controllers/memory_forget_controller.py b/api/app/controllers/memory_forget_controller.py index ca628d0c..a6b6028f 100644 --- a/api/app/controllers/memory_forget_controller.py +++ b/api/app/controllers/memory_forget_controller.py @@ -11,6 +11,7 @@ """ from typing import Optional +from uuid import UUID from fastapi import APIRouter, Depends from sqlalchemy.orm import Session @@ -106,7 +107,7 @@ async def trigger_forgetting_cycle( # 调用服务层执行遗忘周期 report = await forget_service.trigger_forgetting_cycle( db=db, - group_id=end_user_id, # 服务层方法的参数名是 group_id + end_user_id=end_user_id, # 服务层方法的参数名是 end_user_id max_merge_batch_size=payload.max_merge_batch_size, min_days_since_access=payload.min_days_since_access, config_id=config_id @@ -128,7 +129,7 @@ async def trigger_forgetting_cycle( @router.get("/read_config", response_model=ApiResponse) async def read_forgetting_config( - config_id: int, + config_id: UUID, current_user: User = Depends(get_current_user), db: Session = Depends(get_db) ): @@ -236,7 +237,7 @@ async def update_forgetting_config( @router.get("/stats", response_model=ApiResponse) async def get_forgetting_stats( - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, current_user: User = Depends(get_current_user), db: Session = Depends(get_db) ): @@ -246,7 +247,7 @@ async def get_forgetting_stats( 返回知识层节点统计、激活值分布等信息。 Args: - group_id: 组ID(即 end_user_id,可选) + end_user_id: 组ID(即 end_user_id,可选) current_user: 当前用户 db: 数据库会话 @@ -260,20 +261,20 @@ async def 
get_forgetting_stats( api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - # 如果提供了 group_id,通过它获取 config_id + # 如果提供了 end_user_id,通过它获取 config_id config_id = None - if group_id: + if end_user_id: try: from app.services.memory_agent_service import get_end_user_connected_config - connected_config = get_end_user_connected_config(group_id, db) + connected_config = get_end_user_connected_config(end_user_id, db) config_id = connected_config.get("memory_config_id") if config_id is None: - api_logger.warning(f"终端用户 {group_id} 未关联记忆配置") - return fail(BizCode.INVALID_PARAMETER, f"终端用户 {group_id} 未关联记忆配置", "memory_config_id is None") + api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置") + return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None") - api_logger.debug(f"通过 group_id={group_id} 获取到 config_id={config_id}") + api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}") except ValueError as e: api_logger.warning(f"获取终端用户配置失败: {str(e)}") return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError") @@ -283,14 +284,14 @@ async def get_forgetting_stats( api_logger.info( f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: " - f"group_id={group_id}, config_id={config_id}" + f"end_user_id={end_user_id}, config_id={config_id}" ) try: # 调用服务层获取统计信息 stats = await forget_service.get_forgetting_stats( db=db, - group_id=group_id, + end_user_id=end_user_id, config_id=config_id ) diff --git a/api/app/controllers/memory_perceptual_controller.py b/api/app/controllers/memory_perceptual_controller.py index 5154c763..44750808 100644 --- a/api/app/controllers/memory_perceptual_controller.py +++ b/api/app/controllers/memory_perceptual_controller.py @@ -27,27 +27,27 @@ router = APIRouter( ) -@router.get("/{group_id}/count", response_model=ApiResponse) +@router.get("/{end_user_id}/count", response_model=ApiResponse) def 
get_memory_count( - group_id: uuid.UUID, + end_user_id: uuid.UUID, current_user: User = Depends(get_current_user), db: Session = Depends(get_db) ): """Retrieve perceptual memory statistics for a user group. Args: - group_id: ID of the user group (usually end_user_id in this context) + end_user_id: ID of the user group (usually end_user_id in this context) current_user: Current authenticated user db: Database session Returns: ApiResponse: Response containing memory count statistics """ - api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, group_id={group_id}") + api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, end_user_id={end_user_id}") try: service = MemoryPerceptualService(db) - count_stats = service.get_memory_count(group_id) + count_stats = service.get_memory_count(end_user_id) api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}") @@ -57,37 +57,37 @@ def get_memory_count( ) except Exception as e: - api_logger.error(f"Failed to fetch memory statistics: group_id={group_id}, error={str(e)}") + api_logger.error(f"Failed to fetch memory statistics: end_user_id={end_user_id}, error={str(e)}") return fail( code=BizCode.INTERNAL_ERROR, msg="Failed to fetch memory statistics", ) -@router.get("/{group_id}/last_visual", response_model=ApiResponse) +@router.get("/{end_user_id}/last_visual", response_model=ApiResponse) def get_last_visual_memory( - group_id: uuid.UUID, + end_user_id: uuid.UUID, current_user: User = Depends(get_current_user), db: Session = Depends(get_db) ): """Retrieve the most recent VISION-type memory for a user. 
Args: - group_id: ID of the user group + end_user_id: ID of the user group current_user: Current authenticated user db: Database session Returns: ApiResponse: Metadata of the latest visual memory """ - api_logger.info(f"Fetching latest visual memory: user={current_user.username}, group_id={group_id}") + api_logger.info(f"Fetching latest visual memory: user={current_user.username}, end_user_id={end_user_id}") try: service = MemoryPerceptualService(db) - visual_memory = service.get_latest_visual_memory(group_id) + visual_memory = service.get_latest_visual_memory(end_user_id) if visual_memory is None: - api_logger.info(f"No visual memory found: group_id={group_id}") + api_logger.info(f"No visual memory found: end_user_id={end_user_id}") return success( data=None, msg="No visual memory available" @@ -101,37 +101,37 @@ def get_last_visual_memory( ) except Exception as e: - api_logger.error(f"Failed to fetch latest visual memory: group_id={group_id}, error={str(e)}") + api_logger.error(f"Failed to fetch latest visual memory: end_user_id={end_user_id}, error={str(e)}") return fail( code=BizCode.INTERNAL_ERROR, msg="Failed to fetch latest visual memory", ) -@router.get("/{group_id}/last_listen", response_model=ApiResponse) +@router.get("/{end_user_id}/last_listen", response_model=ApiResponse) def get_last_memory_listen( - group_id: uuid.UUID, + end_user_id: uuid.UUID, current_user: User = Depends(get_current_user), db: Session = Depends(get_db) ): """Retrieve the most recent AUDIO-type memory for a user. 
Args: - group_id: ID of the user group + end_user_id: ID of the user group current_user: Current authenticated user db: Database session Returns: ApiResponse: Metadata of the latest audio memory """ - api_logger.info(f"Fetching latest audio memory: user={current_user.username}, group_id={group_id}") + api_logger.info(f"Fetching latest audio memory: user={current_user.username}, end_user_id={end_user_id}") try: service = MemoryPerceptualService(db) - audio_memory = service.get_latest_audio_memory(group_id) + audio_memory = service.get_latest_audio_memory(end_user_id) if audio_memory is None: - api_logger.info(f"No audio memory found: group_id={group_id}") + api_logger.info(f"No audio memory found: end_user_id={end_user_id}") return success( data=None, msg="No audio memory available" @@ -145,38 +145,38 @@ def get_last_memory_listen( ) except Exception as e: - api_logger.error(f"Failed to fetch latest audio memory: group_id={group_id}, error={str(e)}") + api_logger.error(f"Failed to fetch latest audio memory: end_user_id={end_user_id}, error={str(e)}") return fail( code=BizCode.INTERNAL_ERROR, msg="Failed to fetch latest audio memory", ) -@router.get("/{group_id}/last_text", response_model=ApiResponse) +@router.get("/{end_user_id}/last_text", response_model=ApiResponse) def get_last_text_memory( - group_id: uuid.UUID, + end_user_id: uuid.UUID, current_user: User = Depends(get_current_user), db: Session = Depends(get_db) ): """Retrieve the most recent TEXT-type memory for a user. 
Args: - group_id: ID of the user group + end_user_id: ID of the user group current_user: Current authenticated user db: Database session Returns: ApiResponse: Metadata of the latest text memory """ - api_logger.info(f"Fetching latest text memory: user={current_user.username}, group_id={group_id}") + api_logger.info(f"Fetching latest text memory: user={current_user.username}, end_user_id={end_user_id}") try: # 调用服务层获取最近的文本记忆 service = MemoryPerceptualService(db) - text_memory = service.get_latest_text_memory(group_id) + text_memory = service.get_latest_text_memory(end_user_id) if text_memory is None: - api_logger.info(f"No text memory found: group_id={group_id}") + api_logger.info(f"No text memory found: end_user_id={end_user_id}") return success( data=None, msg="No text memory available" @@ -190,16 +190,16 @@ def get_last_text_memory( ) except Exception as e: - api_logger.error(f"Failed to fetch latest text memory: group_id={group_id}, error={str(e)}") + api_logger.error(f"Failed to fetch latest text memory: end_user_id={end_user_id}, error={str(e)}") return fail( code=BizCode.INTERNAL_ERROR, msg="Failed to fetch latest text memory", ) -@router.get("/{group_id}/timeline", response_model=ApiResponse) +@router.get("/{end_user_id}/timeline", response_model=ApiResponse) def get_memory_time_line( - group_id: uuid.UUID, + end_user_id: uuid.UUID, perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"), page: int = Query(1, ge=1, description="页码"), page_size: int = Query(10, ge=1, le=100, description="每页大小"), @@ -209,7 +209,7 @@ def get_memory_time_line( """Retrieve a timeline of perceptual memories for a user group. 
Args: - group_id: ID of the user group + end_user_id: ID of the user group perceptual_type: Optional filter for perceptual type page: Page number for pagination page_size: Number of items per page @@ -221,7 +221,7 @@ def get_memory_time_line( """ api_logger.info( f"Fetching perceptual memory timeline: user={current_user.username}, " - f"group_id={group_id}, type={perceptual_type}, page={page}" + f"end_user_id={end_user_id}, type={perceptual_type}, page={page}" ) try: @@ -232,7 +232,7 @@ def get_memory_time_line( ) service = MemoryPerceptualService(db) - timeline_data = service.get_time_line(group_id, query) + timeline_data = service.get_time_line(end_user_id, query) api_logger.info( f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, " @@ -246,7 +246,7 @@ def get_memory_time_line( except Exception as e: api_logger.error( - f"Failed to fetch perceptual memory timeline: group_id={group_id}, " + f"Failed to fetch perceptual memory timeline: end_user_id={end_user_id}, " f"error={str(e)}" ) return fail( diff --git a/api/app/controllers/memory_reflection_controller.py b/api/app/controllers/memory_reflection_controller.py index abd50a33..ccf9485f 100644 --- a/api/app/controllers/memory_reflection_controller.py +++ b/api/app/controllers/memory_reflection_controller.py @@ -1,6 +1,7 @@ import asyncio import time import uuid +from uuid import UUID from app.core.logging_config import get_api_logger from app.core.memory.storage_services.reflection_engine.self_reflexion import ( @@ -11,7 +12,7 @@ from app.core.response_utils import success from app.db import get_db from app.dependencies import get_current_user from app.models.user_model import User -from app.repositories.data_config_repository import DataConfigRepository +from app.repositories.memory_config_repository import MemoryConfigRepository from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_reflection_schemas import Memory_Reflection from 
app.services.memory_reflection_service import ( @@ -50,7 +51,7 @@ async def save_reflection_config( api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}") - data_config = DataConfigRepository.update_reflection_config( + memory_config = MemoryConfigRepository.update_reflection_config( db, config_id=config_id, enable_self_reflexion=request.reflection_enabled, @@ -63,17 +64,17 @@ async def save_reflection_config( ) db.commit() - db.refresh(data_config) + db.refresh(memory_config) reflection_result={ - "config_id": data_config.config_id, - "enable_self_reflexion": data_config.enable_self_reflexion, - "iteration_period": data_config.iteration_period, - "reflexion_range": data_config.reflexion_range, - "baseline": data_config.baseline, - "reflection_model_id": data_config.reflection_model_id, - "memory_verify": data_config.memory_verify, - "quality_assessment": data_config.quality_assessment} + "config_id": memory_config.config_id, + "enable_self_reflexion": memory_config.enable_self_reflexion, + "iteration_period": memory_config.iteration_period, + "reflexion_range": memory_config.reflexion_range, + "baseline": memory_config.baseline, + "reflection_model_id": memory_config.reflection_model_id, + "memory_verify": memory_config.memory_verify, + "quality_assessment": memory_config.quality_assessment} return success(data=reflection_result, msg="反思配置成功") @@ -111,14 +112,14 @@ async def start_workspace_reflection( reflection_results = [] for data in result['apps_detailed_info']: - if data['data_configs'] == []: + if data['memory_configs'] == []: continue releases = data['releases'] - data_configs = data['data_configs'] + memory_configs = data['memory_configs'] end_users = data['end_users'] - for base, config, user in zip(releases, data_configs, end_users): + for base, config, user in zip(releases, memory_configs, end_users): # 安全地转换为整数,处理空字符串和None的情况 print(base['config']) try: @@ -156,14 +157,14 @@ async def start_workspace_reflection( 
@router.get("/reflection/configs") async def start_reflection_configs( - config_id: int, + config_id: uuid.UUID, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: - """通过config_id查询data_config表中的反思配置信息""" + """通过config_id查询memory_config表中的反思配置信息""" try: api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}") - result = DataConfigRepository.query_reflection_config_by_id(db, config_id) + result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id) # 构建返回数据 reflection_config = { "config_id": result.config_id, @@ -191,7 +192,7 @@ async def start_reflection_configs( @router.get("/reflection/run") async def reflection_run( - config_id: int, + config_id: UUID, language_type: str = Header(default="zh", alias="X-Language-Type"), current_user: User = Depends(get_current_user), db: Session = Depends(get_db), @@ -200,8 +201,8 @@ async def reflection_run( api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}") - # 使用DataConfigRepository查询反思配置 - result = DataConfigRepository.query_reflection_config_by_id(db, config_id) + # 使用MemoryConfigRepository查询反思配置 + result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id) if not result: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index 3722be3d..fb0ebc14 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -1,5 +1,6 @@ import os from typing import Optional +from uuid import UUID from app.core.error_codes import BizCode from app.core.logging_config import get_api_logger @@ -160,7 +161,7 @@ def create_config( @router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称) def delete_config( - config_id: str, + config_id: UUID, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: @@ 
-232,7 +233,7 @@ def update_config_extracted( @router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除 def read_config_extracted( - config_id: str, + config_id: UUID, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ) -> dict: diff --git a/api/app/controllers/memory_working_controller.py b/api/app/controllers/memory_working_controller.py index dfd64044..e5de3c04 100644 --- a/api/app/controllers/memory_working_controller.py +++ b/api/app/controllers/memory_working_controller.py @@ -20,18 +20,18 @@ router = APIRouter( ) -@router.get("/{group_id}/count", response_model=ApiResponse) +@router.get("/{end_user_id}/count", response_model=ApiResponse) def get_memory_count( - group_id: uuid.UUID, + end_user_id: uuid.UUID, current_user: User = Depends(get_current_user), db: Session = Depends(get_db) ): pass -@router.get("/{group_id}/conversations", response_model=ApiResponse) +@router.get("/{end_user_id}/conversations", response_model=ApiResponse) def get_conversations( - group_id: uuid.UUID, + end_user_id: uuid.UUID, current_user: User = Depends(get_current_user), db: Session = Depends(get_db) ): @@ -39,7 +39,7 @@ def get_conversations( Retrieve all conversations for the current user in a specific group. Args: - group_id (UUID): The group identifier. + end_user_id (UUID): The group identifier. current_user (User, optional): The authenticated user. db (Session, optional): SQLAlchemy session. 
@@ -53,7 +53,7 @@ def get_conversations( """ conversation_service = ConversationService(db) conversations = conversation_service.get_user_conversations( - group_id + end_user_id ) return success(data=[ { @@ -63,7 +63,7 @@ def get_conversations( ], msg="get conversations success") -@router.get("/{group_id}/messages", response_model=ApiResponse) +@router.get("/{end_user_id}/messages", response_model=ApiResponse) def get_messages( conversation_id: uuid.UUID, current_user: User = Depends(get_current_user), @@ -100,7 +100,7 @@ def get_messages( return success(data=messages, msg="get conversation history success") -@router.get("/{group_id}/detail", response_model=ApiResponse) +@router.get("/{end_user_id}/detail", response_model=ApiResponse) async def get_conversation_detail( conversation_id: uuid.UUID, current_user: User = Depends(get_current_user), diff --git a/api/app/controllers/service/memory_api_controller.py b/api/app/controllers/service/memory_api_controller.py index 30ca1306..accd749e 100644 --- a/api/app/controllers/service/memory_api_controller.py +++ b/api/app/controllers/service/memory_api_controller.py @@ -39,7 +39,7 @@ async def write_memory_api_service( Stores memory content for the specified end user using the Memory API Service. 
""" - logger.info(f"Memory write request - end_user_id: {payload.end_user_id}") + logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}") memory_api_service = MemoryAPIService(db) diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py index 6f02f8f9..39cbe523 100644 --- a/api/app/controllers/user_memory_controllers.py +++ b/api/app/controllers/user_memory_controllers.py @@ -135,27 +135,27 @@ async def generate_cache_api( api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - group_id = request.end_user_id + end_user_id = request.end_user_id api_logger.info( f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, " - f"end_user_id={group_id if group_id else '全部用户'}" + f"end_user_id={end_user_id if end_user_id else '全部用户'}" ) try: - if group_id: + if end_user_id: # 为单个用户生成 - api_logger.info(f"开始为单个用户生成缓存: end_user_id={group_id}") + api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}") # 生成记忆洞察 - insight_result = await user_memory_service.generate_and_cache_insight(db, group_id, workspace_id) + insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id) # 生成用户摘要 - summary_result = await user_memory_service.generate_and_cache_summary(db, group_id, workspace_id) + summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id) # 构建响应 result = { - "end_user_id": group_id, + "end_user_id": end_user_id, "insight_success": insight_result["success"], "summary_success": summary_result["success"], "errors": [] @@ -175,9 +175,9 @@ async def generate_cache_api( # 记录结果 if result["insight_success"] and result["summary_success"]: - api_logger.info(f"成功为用户 {group_id} 生成缓存") + api_logger.info(f"成功为用户 {end_user_id} 生成缓存") else: - api_logger.warning(f"用户 {group_id} 的缓存生成部分失败: 
{result['errors']}") + api_logger.warning(f"用户 {end_user_id} 的缓存生成部分失败: {result['errors']}") return success(data=result, msg="生成完成") diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 87b46e6f..ddacb094 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -155,13 +155,13 @@ class LangChainAgent: # userid=end_user_end, # messages=messages, # apply_id=end_user_end, - # group_id=end_user_end, + # end_user_id=end_user_end, # aimessages=aimessages # ) # store.delete_duplicate_sessions() # # logger.info(f'Redis_Agent:{end_user_end};{session_id}') # return session_id - + # TODO 乐力齐 - 累积多组对话批量写入功能已禁用 # async def term_memory_redis_read(self,end_user_end): # end_user_end = f"Term_{end_user_end}" @@ -179,7 +179,7 @@ class LangChainAgent: async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id): """ 写入记忆(支持结构化消息) - + Args: storage_type: 存储类型 (neo4j/rag) end_user_id: 终端用户ID @@ -188,7 +188,7 @@ class LangChainAgent: user_rag_memory_id: RAG 记忆ID actual_end_user_id: 实际用户ID actual_config_id: 配置ID - + 逻辑说明: - RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变 - Neo4j 模式:使用结构化消息列表 @@ -204,20 +204,20 @@ class LangChainAgent: else: # Neo4j 模式:使用结构化消息列表 structured_messages = [] - + # 始终添加用户消息(如果不为空) if user_message: structured_messages.append({"role": "user", "content": user_message}) - + # 只有当 AI 回复不为空时才添加 assistant 消息 if ai_message: structured_messages.append({"role": "assistant", "content": ai_message}) - + # 如果没有消息,直接返回 if not structured_messages: logger.warning(f"No messages to write for user {actual_end_user_id}") return - + # 调用 Celery 任务,传递结构化消息列表 # 数据流: # 1. structured_messages 传递给 write_message_task @@ -228,7 +228,7 @@ class LangChainAgent: # 6. 
每个 Chunk 保存到 Neo4j,包含 speaker 字段 logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") write_id = write_message_task.delay( - actual_end_user_id, # group_id: 用户ID + actual_end_user_id, # end_user_id: 用户ID structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] actual_config_id, # config_id: 配置ID storage_type, # storage_type: "neo4j" diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py index 2bad650a..ac1fb9a6 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py @@ -35,10 +35,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState: """问题分解节点""" # 从状态中获取数据 content = state.get('data', '') - group_id = state.get('group_id', '') + end_user_id = state.get('end_user_id', '') memory_config = state.get('memory_config', None) - history = await SessionService(store).get_history(group_id, group_id, group_id) + history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id) # 生成 JSON schema 以指导 LLM 输出正确格式 json_schema = ProblemExtensionResponse.model_json_schema() @@ -140,7 +140,7 @@ async def Problem_Extension(state: ReadState) -> ReadState: start = time.time() content = state.get('data', '') data = state.get('spit_data', '')['context'] - group_id = state.get('group_id', '') + end_user_id = state.get('end_user_id', '') storage_type = state.get('storage_type', '') user_rag_memory_id = state.get('user_rag_memory_id', '') memory_config = state.get('memory_config', None) @@ -156,7 +156,7 @@ async def Problem_Extension(state: ReadState) -> ReadState: databasets = {} data = [] - history = await SessionService(store).get_history(group_id, group_id, group_id) + history = await SessionService(store).get_history(end_user_id, 
end_user_id, end_user_id) # 生成 JSON schema 以指导 LLM 输出正确格式 json_schema = ProblemExtensionResponse.model_json_schema() diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py index 14f8fa8b..1880357c 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py @@ -52,9 +52,9 @@ async def rag_config(state): return kb_config async def rag_knowledge(state,question): kb_config = await rag_config(state) - group_id = state.get('group_id', '') + end_user_id = state.get('end_user_id', '') user_rag_memory_id=state.get("user_rag_memory_id",'') - retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(group_id)]) + retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)]) try: retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] clean_content = '\n\n'.join(retrieval_knowledge) @@ -159,7 +159,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState: problem_extension=state.get('problem_extension', '')['context'] storage_type=state.get('storage_type', '') user_rag_memory_id=state.get('user_rag_memory_id', '') - group_id=state.get('group_id', '') + end_user_id=state.get('end_user_id', '') memory_config = state.get('memory_config', None) original=state.get('data', '') problem_list=[] @@ -172,7 +172,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState: try: # Prepare search parameters based on storage type search_params = { - "group_id": group_id, + "end_user_id": end_user_id, "question": question, "return_raw_results": True } @@ -263,13 +263,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState: async def retrieve(state: ReadState) -> ReadState: - # 从state中获取group_id + # 从state中获取end_user_id import time start=time.time() problem_extension = state.get('problem_extension', '')['context'] storage_type = 
state.get('storage_type', '') user_rag_memory_id = state.get('user_rag_memory_id', '') - group_id = state.get('group_id', '') + end_user_id = state.get('end_user_id', '') memory_config = state.get('memory_config', None) original = state.get('data', '') problem_list = [] @@ -295,13 +295,13 @@ async def retrieve(state: ReadState) -> ReadState: temperature=0.2, ) - time_retrieval_tool = create_time_retrieval_tool(group_id) - search_params = { "group_id": group_id, "return_raw_results": True } + time_retrieval_tool = create_time_retrieval_tool(end_user_id) + search_params = { "end_user_id": end_user_id, "return_raw_results": True } hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params) agent = create_agent( llm, tools=[time_retrieval_tool,hybrid_retrieval], - system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的group_id是: {group_id}" + system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}" ) # 创建异步任务处理单个问题 diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py index fb0484d2..0144c0e9 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py @@ -34,8 +34,8 @@ class SummaryNodeService(LLMServiceMixin): summary_service = SummaryNodeService() async def summary_history(state: ReadState) -> ReadState: - group_id = state.get("group_id", '') - history = await SessionService(store).get_history(group_id, group_id, group_id) + end_user_id = state.get("end_user_id", '') + history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id) return history async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str: @@ -122,12 +122,12 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o async def summary_redis_save(state: ReadState,aimessages) -> 
ReadState: data = state.get("data", '') - group_id = state.get("group_id", '') + end_user_id = state.get("end_user_id", '') await SessionService(store).save_session( - user_id=group_id, + user_id=end_user_id, query=data, - apply_id=group_id, - group_id=group_id, + apply_id=end_user_id, + end_user_id=end_user_id, ai_response=aimessages ) await SessionService(store).cleanup_duplicates() @@ -175,11 +175,11 @@ async def Input_Summary(state: ReadState) -> ReadState: memory_config = state.get('memory_config', None) user_rag_memory_id=state.get("user_rag_memory_id",'') data=state.get("data", '') - group_id=state.get("group_id", '') + end_user_id=state.get("end_user_id", '') logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") history = await summary_history( state) search_params = { - "group_id": group_id, + "end_user_id": end_user_id, "question": data, "return_raw_results": True, "include": ["summaries"] # Only search summary nodes for faster performance diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py index 10ce8db4..b809faf2 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py @@ -62,12 +62,12 @@ async def Verify(state: ReadState): logger.info("=== Verify 节点开始执行 ===") try: content = state.get('data', '') - group_id = state.get('group_id', '') + end_user_id = state.get('end_user_id', '') memory_config = state.get('memory_config', None) - logger.info(f"Verify: content={content[:50] if content else 'empty'}..., group_id={group_id}") + logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}") - history = await SessionService(store).get_history(group_id, group_id, group_id) + history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id) 
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}") retrieve = state.get("retrieve", {}) diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py index 6af313c3..b85130ad 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py @@ -1,23 +1,24 @@ - -from app.core.memory.agent.utils.llm_tools import WriteState +from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.utils.write_tools import write from app.core.logging_config import get_agent_logger logger = get_agent_logger(__name__) + + async def write_node(state: WriteState) -> WriteState: """ Write data to the database/file system. Args: - state: WriteState containing messages, group_id, and memory_config + state: WriteState containing messages, end_user_id, and memory_config Returns: dict: Contains 'write_result' with status and data fields """ messages = state.get('messages', []) - group_id = state.get('group_id', '') + end_user_id = state.get('end_user_id', '') memory_config = state.get('memory_config', '') - + # Convert LangChain messages to structured format expected by write() structured_messages = [] for msg in messages: @@ -28,13 +29,11 @@ async def write_node(state: WriteState) -> WriteState: "role": role, "content": msg.content # content is now guaranteed to be a string }) - + try: result = await write( messages=structured_messages, - user_id=group_id, - apply_id=group_id, - group_id=group_id, + end_user_id=end_user_id, memory_config=memory_config, ) logger.info(f"Write completed successfully! 
Config: {memory_config.config_name}") diff --git a/api/app/core/memory/agent/langgraph_graph/read_graph.py b/api/app/core/memory/agent/langgraph_graph/read_graph.py index 19011a5f..3476d0ec 100644 --- a/api/app/core/memory/agent/langgraph_graph/read_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/read_graph.py @@ -79,7 +79,7 @@ async def make_read_graph(): async def main(): """主函数 - 运行工作流""" message = "昨天有什么好看的电影" - group_id = '88a459f5_text09' # 组ID + end_user_id = '88a459f5_text09' # 组ID storage_type = 'neo4j' # 存储类型 search_switch = '1' # 搜索开关 user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID @@ -95,9 +95,9 @@ async def main(): start=time.time() try: async with make_read_graph() as graph: - config = {"configurable": {"thread_id": group_id}} + config = {"configurable": {"thread_id": end_user_id}} # 初始状态 - 包含所有必要字段 - initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"group_id":group_id + initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id ,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config} # 获取节点更新信息 _intermediate_outputs = [] diff --git a/api/app/core/memory/agent/langgraph_graph/tools/tool.py b/api/app/core/memory/agent/langgraph_graph/tools/tool.py index ce6d5dd4..c4814de1 100644 --- a/api/app/core/memory/agent/langgraph_graph/tools/tool.py +++ b/api/app/core/memory/agent/langgraph_graph/tools/tool.py @@ -48,11 +48,11 @@ def extract_tool_message_content(response): class TimeRetrievalInput(BaseModel): """时间检索工具的输入模式""" context: str = Field(description="用户输入的查询内容") - group_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果") + end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果") -def create_time_retrieval_tool(group_id: str): +def create_time_retrieval_tool(end_user_id: str): """ - 创建一个带有特定group_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements) + 
创建一个带有特定end_user_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements) """ def clean_temporal_result_fields(data): @@ -93,26 +93,26 @@ def create_time_retrieval_tool(group_id: str): return data @tool - def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, group_id_param: str = None, clean_output: bool = True) -> str: + def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str: """ 优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段 显式接收参数: - context: 查询上下文内容 - start_date: 开始时间(可选,格式:YYYY-MM-DD) - end_date: 结束时间(可选,格式:YYYY-MM-DD) - - group_id_param: 组ID(可选,用于覆盖默认组ID) + - end_user_id_param: 组ID(可选,用于覆盖默认组ID) - clean_output: 是否清理输出中的元数据字段 -end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d") """ async def _async_search(): # 使用传入的参数或默认值 - actual_group_id = group_id_param or group_id + actual_end_user_id = end_user_id_param or end_user_id actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d") actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d") # 基本时间搜索 results = await search_by_temporal( - group_id=actual_group_id, + end_user_id=actual_end_user_id, start_date=actual_start_date, end_date=actual_end_date, limit=10 @@ -147,7 +147,7 @@ def create_time_retrieval_tool(group_id: str): # 关键词时间搜索 results = await search_by_keyword_temporal( query_text=context, - group_id=group_id, + end_user_id=end_user_id, start_date=actual_start_date, end_date=actual_end_date, limit=15 @@ -172,7 +172,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): Args: memory_config: 内存配置对象 - **search_params: 搜索参数,包含group_id, limit, include等 + **search_params: 搜索参数,包含end_user_id, limit, include等 """ def clean_result_fields(data): @@ -211,7 +211,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): context: str, search_type: str = "hybrid", limit: int = 10, - group_id: str = None, + end_user_id: str 
= None, rerank_alpha: float = 0.6, use_forgetting_rerank: bool = False, use_llm_rerank: bool = False, @@ -224,7 +224,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): context: 查询内容 search_type: 搜索类型 ('keyword', 'embedding', 'hybrid') limit: 结果数量限制 - group_id: 组ID,用于过滤搜索结果 + end_user_id: 组ID,用于过滤搜索结果 rerank_alpha: 重排序权重参数 use_forgetting_rerank: 是否使用遗忘重排序 use_llm_rerank: 是否使用LLM重排序 @@ -238,7 +238,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): final_params = { "query_text": context, "search_type": search_type, - "group_id": group_id or search_params.get("group_id"), + "end_user_id": end_user_id or search_params.get("end_user_id"), "limit": limit or search_params.get("limit", 10), "include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]), "output_path": None, # 不保存到文件 @@ -291,7 +291,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params): context: str, search_type: str = "hybrid", limit: int = 10, - group_id: str = None, + end_user_id: str = None, clean_output: bool = True ) -> str: """ @@ -301,7 +301,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params): context: 查询内容 search_type: 搜索类型 ('keyword', 'embedding', 'hybrid') limit: 结果数量限制 - group_id: 组ID,用于过滤搜索结果 + end_user_id: 组ID,用于过滤搜索结果 clean_output: 是否清理输出中的元数据字段 """ async def _async_search(): @@ -311,7 +311,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params): "context": context, "search_type": search_type, "limit": limit, - "group_id": group_id, + "end_user_id": end_user_id, "clean_output": clean_output }) diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index fe281a23..8b5de444 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -14,6 +14,7 @@ from app.db import get_db from app.core.logging_config import 
get_agent_logger from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node +from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write from app.services.memory_config_service import MemoryConfigService warnings.filterwarnings("ignore", category=RuntimeWarning) @@ -26,9 +27,21 @@ async def make_write_graph(): """ Create a write graph workflow for memory operations. - The workflow directly processes messages from the initial state - and saves them to Neo4j storage. + Args: + user_id: User identifier + tools: MCP tools loaded from session + apply_id: Application identifier + end_user_id: Group identifier + memory_config: MemoryConfig object containing all configuration """ + # workflow = StateGraph(WriteState) + # workflow.add_node("content_input", content_input_write) + # workflow.add_node("save_neo4j", write_node) + # workflow.add_edge(START, "content_input") + # workflow.add_edge("content_input", "save_neo4j") + # workflow.add_edge("save_neo4j", END) + # + # graph = workflow.compile() workflow = StateGraph(WriteState) workflow.add_node("save_neo4j", write_node) workflow.add_edge(START, "save_neo4j") @@ -42,7 +55,7 @@ async def make_write_graph(): async def main(): """主函数 - 运行工作流""" message = "今天周一" - group_id = 'new_2025test1103' # 组ID + end_user_id = 'new_2025test1103' # 组ID # 获取数据库会话 @@ -54,9 +67,9 @@ async def main(): ) try: async with make_write_graph() as graph: - config = {"configurable": {"thread_id": group_id}} + config = {"configurable": {"thread_id": end_user_id}} # 初始状态 - 包含所有必要字段 - initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, "memory_config": memory_config} + initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config} # 获取节点更新信息 async for update_event in graph.astream( diff --git a/api/app/core/memory/agent/services/parameter_builder.py 
b/api/app/core/memory/agent/services/parameter_builder.py index a58fcf1a..74382ade 100644 --- a/api/app/core/memory/agent/services/parameter_builder.py +++ b/api/app/core/memory/agent/services/parameter_builder.py @@ -24,7 +24,7 @@ class ParameterBuilder: tool_call_id: str, search_switch: str, apply_id: str, - group_id: str, + end_user_id: str, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None ) -> Dict[str, Any]: @@ -44,7 +44,7 @@ class ParameterBuilder: tool_call_id: Extracted tool call identifier search_switch: Search routing parameter apply_id: Application identifier - group_id: Group identifier + end_user_id: Group identifier storage_type: Storage type for the workspace (optional) user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional) @@ -55,7 +55,7 @@ class ParameterBuilder: base_args = { "usermessages": tool_call_id, "apply_id": apply_id, - "group_id": group_id + "end_user_id": end_user_id } # Always add storage_type and user_rag_memory_id (with defaults if None) diff --git a/api/app/core/memory/agent/services/search_service.py b/api/app/core/memory/agent/services/search_service.py index 8a2e7cfe..4fc4256e 100644 --- a/api/app/core/memory/agent/services/search_service.py +++ b/api/app/core/memory/agent/services/search_service.py @@ -91,7 +91,7 @@ class SearchService: async def execute_hybrid_search( self, - group_id: str, + end_user_id: str, question: str, limit: int = 5, search_type: str = "hybrid", @@ -105,7 +105,7 @@ class SearchService: Execute hybrid search and return clean content. 
Args: - group_id: Group identifier for filtering results + end_user_id: Group identifier for filtering results question: Search query text limit: Maximum number of results to return (default: 5) search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid") @@ -130,7 +130,7 @@ class SearchService: answer = await run_hybrid_search( query_text=cleaned_query, search_type=search_type, - group_id=group_id, + end_user_id=end_user_id, limit=limit, include=include, output_path=output_path, @@ -186,7 +186,7 @@ class SearchService: except Exception as e: logger.error( - f"Search failed for query '{question}' in group '{group_id}': {e}", + f"Search failed for query '{question}' in group '{end_user_id}': {e}", exc_info=True ) # Return empty results on failure diff --git a/api/app/core/memory/agent/services/session_service.py b/api/app/core/memory/agent/services/session_service.py index b2d4f0ff..f7389984 100644 --- a/api/app/core/memory/agent/services/session_service.py +++ b/api/app/core/memory/agent/services/session_service.py @@ -59,7 +59,7 @@ class SessionService: self, user_id: str, apply_id: str, - group_id: str + end_user_id: str ) -> List[dict]: """ Retrieve conversation history from Redis. 
@@ -67,20 +67,20 @@ class SessionService: Args: user_id: User identifier apply_id: Application identifier - group_id: Group identifier + end_user_id: Group identifier Returns: List of conversation history items with Query and Answer keys Returns empty list if no history found or on error """ try: - history = self.store.find_user_apply_group(user_id, apply_id, group_id) + history = self.store.find_user_apply_group(user_id, apply_id, end_user_id) # Validate history structure if not isinstance(history, list): logger.warning( f"Invalid history format for user {user_id}, " - f"apply {apply_id}, group {group_id}: expected list, got {type(history)}" + f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}" ) return [] @@ -89,7 +89,7 @@ class SessionService: except Exception as e: logger.error( f"Failed to retrieve history for user {user_id}, " - f"apply {apply_id}, group {group_id}: {e}", + f"apply {apply_id}, group {end_user_id}: {e}", exc_info=True ) # Return empty list on error to allow execution to continue @@ -100,7 +100,7 @@ class SessionService: user_id: str, query: str, apply_id: str, - group_id: str, + end_user_id: str, ai_response: str ) -> Optional[str]: """ @@ -110,7 +110,7 @@ class SessionService: user_id: User identifier query: User query/message apply_id: Application identifier - group_id: Group identifier + end_user_id: Group identifier ai_response: AI response/answer Returns: @@ -131,7 +131,7 @@ class SessionService: userid=user_id, messages=query, apply_id=apply_id, - group_id=group_id, + end_user_id=end_user_id, aimessages=ai_response ) @@ -152,7 +152,7 @@ class SessionService: Duplicates are identified by matching: - sessionid - user_id (id field) - - group_id + - end_user_id - messages - aimessages diff --git a/api/app/core/memory/agent/utils/get_dialogs.py b/api/app/core/memory/agent/utils/get_dialogs.py index 82a41773..bfb0f675 100644 --- a/api/app/core/memory/agent/utils/get_dialogs.py +++ 
b/api/app/core/memory/agent/utils/get_dialogs.py @@ -9,9 +9,7 @@ from app.core.memory.models.message_models import DialogData, ConversationContex async def get_chunked_dialogs( chunker_strategy: str = "RecursiveChunker", - group_id: str = "group_1", - user_id: str = "user1", - apply_id: str = "applyid", + end_user_id: str = "group_1", messages: list = None, ref_id: str = "wyl_20251027", config_id: str = None @@ -20,9 +18,7 @@ async def get_chunked_dialogs( Args: chunker_strategy: The chunking strategy to use (default: RecursiveChunker) - group_id: Group identifier - user_id: User identifier - apply_id: Application identifier + end_user_id: Group identifier messages: Structured message list [{"role": "user", "content": "..."}, ...] ref_id: Reference identifier config_id: Configuration ID for processing @@ -32,42 +28,40 @@ async def get_chunked_dialogs( """ from app.core.logging_config import get_agent_logger logger = get_agent_logger(__name__) - + if not messages or not isinstance(messages, list) or len(messages) == 0: raise ValueError("messages parameter must be a non-empty list") - + conversation_messages = [] - + for idx, msg in enumerate(messages): if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg: raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields") - + role = msg['role'] content = msg['content'] - + if role not in ['user', 'assistant']: raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}") - + if content.strip(): conversation_messages.append(ConversationMessage(role=role, msg=content.strip())) - + if not conversation_messages: raise ValueError("Message list cannot be empty after filtering") - + conversation_context = ConversationContext(msgs=conversation_messages) dialog_data = DialogData( context=conversation_context, ref_id=ref_id, - group_id=group_id, - user_id=user_id, - apply_id=apply_id, + end_user_id=end_user_id, config_id=config_id ) - + chunker = 
DialogueChunker(chunker_strategy) extracted_chunks = await chunker.process_dialogue(dialog_data) dialog_data.chunks = extracted_chunks - + logger.info(f"DialogData created with {len(extracted_chunks)} chunks") return [dialog_data] diff --git a/api/app/core/memory/agent/utils/llm_tools.py b/api/app/core/memory/agent/utils/llm_tools.py index e73d5653..7f1041cb 100644 --- a/api/app/core/memory/agent/utils/llm_tools.py +++ b/api/app/core/memory/agent/utils/llm_tools.py @@ -13,13 +13,11 @@ class WriteState(TypedDict): Langgrapg Writing TypedDict ''' messages: Annotated[list[AnyMessage], add_messages] - user_id:str - apply_id:str - group_id:str + end_user_id: str errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}] memory_config: object write_result: dict - data:str + data: str class ReadState(TypedDict): """ @@ -29,7 +27,7 @@ class ReadState(TypedDict): messages: 消息列表,支持自动追加 loop_count: 遍历次数 search_switch: 搜索类型开关 - group_id: 组标识 + end_user_id: 组标识 config_id: 配置ID,用于过滤结果 data: 从content_input_node传递的内容数据 spit_data: 从Split_The_Problem传递的分解结果 @@ -40,7 +38,7 @@ class ReadState(TypedDict): messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式 loop_count: int search_switch: str - group_id: str + end_user_id: str config_id: str data: str # 新增字段用于传递内容 spit_data: dict # 新增字段用于传递问题分解结果 diff --git a/api/app/core/memory/agent/utils/redis_tool.py b/api/app/core/memory/agent/utils/redis_tool.py index 31a76a11..505545b3 100644 --- a/api/app/core/memory/agent/utils/redis_tool.py +++ b/api/app/core/memory/agent/utils/redis_tool.py @@ -28,7 +28,7 @@ class RedisSessionStore: return text # 修改后的 save_session 方法 - def save_session(self, userid, messages, aimessages, apply_id, group_id): + def save_session(self, userid, messages, aimessages, apply_id, end_user_id): """ 写入一条会话数据,返回 session_id 优化版本:确保写入时间不超过1秒 @@ -46,7 +46,7 @@ class RedisSessionStore: "id": self.uudi, "sessionid": userid, "apply_id": apply_id, - "group_id": group_id, + "end_user_id": 
end_user_id, "messages": messages, "aimessages": aimessages, "starttime": starttime @@ -67,7 +67,7 @@ class RedisSessionStore: def save_sessions_batch(self, sessions_data): """ 批量写入多条会话数据,返回 session_id 列表 - sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, group_id + sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id 优化版本:批量操作,大幅提升性能 """ try: @@ -83,7 +83,7 @@ class RedisSessionStore: "id": self.uudi, "sessionid": session.get('userid'), "apply_id": session.get('apply_id'), - "group_id": session.get('group_id'), + "end_user_id": session.get('end_user_id'), "messages": session.get('messages'), "aimessages": session.get('aimessages'), "starttime": starttime @@ -108,9 +108,9 @@ class RedisSessionStore: data = self.r.hgetall(key) return data if data else None - def get_session_apply_group(self, sessionid, apply_id, group_id): + def get_session_apply_group(self, sessionid, apply_id, end_user_id): """ - 根据 sessionid、apply_id 和 group_id 三个条件查询会话数据 + 根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据 """ result_items = [] @@ -124,7 +124,7 @@ class RedisSessionStore: # 检查三个条件是否都匹配 if (data.get('sessionid') == sessionid and data.get('apply_id') == apply_id and - data.get('group_id') == group_id): + data.get('end_user_id') == end_user_id): result_items.append(data) return result_items @@ -172,7 +172,7 @@ class RedisSessionStore: def delete_duplicate_sessions(self): """ 删除重复会话数据,条件: - "sessionid"、"user_id"、"group_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除 + "sessionid"、"user_id"、"end_user_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除 优化版本:使用 pipeline 批量操作,确保在1秒内完成 """ import time @@ -202,12 +202,12 @@ class RedisSessionStore: # 获取五个字段的值 sessionid = data.get('sessionid', '') user_id = data.get('id', '') - group_id = data.get('group_id', '') + end_user_id = data.get('end_user_id', '') messages = data.get('messages', '') aimessages = data.get('aimessages', '') # 用五元组作为唯一标识 - identifier = (sessionid, user_id, 
group_id, messages, aimessages) + identifier = (sessionid, user_id, end_user_id, messages, aimessages) if identifier in seen: # 重复,标记为待删除 @@ -248,9 +248,9 @@ class RedisSessionStore: result_items = [] return (result_items) - def find_user_apply_group(self, sessionid, apply_id, group_id): + def find_user_apply_group(self, sessionid, apply_id, end_user_id): """ - 根据 sessionid、apply_id 和 group_id 三个条件查询会话数据,返回最新的6条 + 根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据,返回最新的6条 """ import time start_time = time.time() @@ -276,7 +276,7 @@ class RedisSessionStore: # 检查是否符合三个条件 if (data.get('apply_id') == apply_id and - data.get('group_id') == group_id): + data.get('end_user_id') == end_user_id): # 支持模糊匹配 sessionid 或者完全匹配 if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid: matched_items.append({ diff --git a/api/app/core/memory/agent/utils/session_tools.py b/api/app/core/memory/agent/utils/session_tools.py index b2d4f0ff..f7389984 100644 --- a/api/app/core/memory/agent/utils/session_tools.py +++ b/api/app/core/memory/agent/utils/session_tools.py @@ -59,7 +59,7 @@ class SessionService: self, user_id: str, apply_id: str, - group_id: str + end_user_id: str ) -> List[dict]: """ Retrieve conversation history from Redis. 
@@ -67,20 +67,20 @@ class SessionService: Args: user_id: User identifier apply_id: Application identifier - group_id: Group identifier + end_user_id: Group identifier Returns: List of conversation history items with Query and Answer keys Returns empty list if no history found or on error """ try: - history = self.store.find_user_apply_group(user_id, apply_id, group_id) + history = self.store.find_user_apply_group(user_id, apply_id, end_user_id) # Validate history structure if not isinstance(history, list): logger.warning( f"Invalid history format for user {user_id}, " - f"apply {apply_id}, group {group_id}: expected list, got {type(history)}" + f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}" ) return [] @@ -89,7 +89,7 @@ class SessionService: except Exception as e: logger.error( f"Failed to retrieve history for user {user_id}, " - f"apply {apply_id}, group {group_id}: {e}", + f"apply {apply_id}, group {end_user_id}: {e}", exc_info=True ) # Return empty list on error to allow execution to continue @@ -100,7 +100,7 @@ class SessionService: user_id: str, query: str, apply_id: str, - group_id: str, + end_user_id: str, ai_response: str ) -> Optional[str]: """ @@ -110,7 +110,7 @@ class SessionService: user_id: User identifier query: User query/message apply_id: Application identifier - group_id: Group identifier + end_user_id: Group identifier ai_response: AI response/answer Returns: @@ -131,7 +131,7 @@ class SessionService: userid=user_id, messages=query, apply_id=apply_id, - group_id=group_id, + end_user_id=end_user_id, aimessages=ai_response ) @@ -152,7 +152,7 @@ class SessionService: Duplicates are identified by matching: - sessionid - user_id (id field) - - group_id + - end_user_id - messages - aimessages diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 1df0b336..446ab86a 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ 
b/api/app/core/memory/agent/utils/write_tools.py @@ -29,20 +29,18 @@ logger = get_agent_logger(__name__) async def write( - user_id: str, - apply_id: str, - group_id: str, + end_user_id: str, memory_config: MemoryConfig, messages: list, ref_id: str = "wyl20251027", ) -> None: """ Execute the complete knowledge extraction pipeline. - + Args: user_id: User identifier apply_id: Application identifier - group_id: Group identifier + end_user_id: Group identifier memory_config: MemoryConfig object containing all configuration messages: Structured message list [{"role": "user", "content": "..."}, ...] ref_id: Reference ID, defaults to "wyl20251027" @@ -51,14 +49,14 @@ async def write( embedding_model_id = str(memory_config.embedding_model_id) chunker_strategy = memory_config.chunker_strategy config_id = str(memory_config.config_id) - + logger.info("=== MemSci Knowledge Extraction Pipeline ===") logger.info(f"Config: {memory_config.config_name} (ID: {config_id})") logger.info(f"Workspace: {memory_config.workspace_name}") logger.info(f"LLM model: {memory_config.llm_model_name}") logger.info(f"Embedding model: {memory_config.embedding_model_name}") logger.info(f"Chunker strategy: {chunker_strategy}") - logger.info(f"Group ID: {group_id}") + logger.info(f"end_user_id ID: {end_user_id}") # Construct clients from memory_config using factory pattern with db session with get_db_context() as db: @@ -83,9 +81,7 @@ async def write( step_start = time.time() chunked_dialogs = await get_chunked_dialogs( chunker_strategy=chunker_strategy, - group_id=group_id, - user_id=user_id, - apply_id=apply_id, + end_user_id=end_user_id, messages=messages, ref_id=ref_id, config_id=config_id, diff --git a/api/app/core/memory/analytics/hot_memory_tags.py b/api/app/core/memory/analytics/hot_memory_tags.py index cab6cacd..95302726 100644 --- a/api/app/core/memory/analytics/hot_memory_tags.py +++ b/api/app/core/memory/analytics/hot_memory_tags.py @@ -16,13 +16,13 @@ class FilteredTags(BaseModel): 
"""用于接收LLM筛选后的核心标签列表的模型。""" meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。") -async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]: +async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]: """ 使用LLM筛选标签列表,仅保留具有代表性的核心名词。 Args: tags: 原始标签列表 - group_id: 用户组ID,用于获取配置 + end_user_id: 用户组ID,用于获取配置 Returns: 筛选后的标签列表 @@ -37,12 +37,12 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]: get_end_user_connected_config, ) - connected_config = get_end_user_connected_config(group_id, db) + connected_config = get_end_user_connected_config(end_user_id, db) config_id = connected_config.get("memory_config_id") if not config_id: raise ValueError( - f"No memory_config_id found for group_id: {group_id}. " + f"No memory_config_id found for end_user_id: {end_user_id}. " "Please ensure the user has a valid memory configuration." ) @@ -87,7 +87,7 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]: async def get_raw_tags_from_db( connector: Neo4jConnector, - group_id: str, + end_user_id: str, limit: int, by_user: bool = False ) -> List[Tuple[str, int]]: @@ -99,9 +99,9 @@ async def get_raw_tags_from_db( Args: connector: Neo4j连接器实例 - group_id: 如果by_user=False,则为group_id;如果by_user=True,则为user_id + end_user_id: 如果by_user=False,则为end_user_id;如果by_user=True,则为user_id limit: 返回的标签数量限制 - by_user: 是否按user_id查询(默认False,按group_id查询) + by_user: 是否按user_id查询(默认False,按end_user_id查询) Returns: List[Tuple[str, int]]: 标签名称和频率的元组列表 @@ -119,7 +119,7 @@ async def get_raw_tags_from_db( else: query = ( "MATCH (e:ExtractedEntity) " - "WHERE e.group_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude " + "WHERE e.end_user_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude " "RETURN e.name AS name, count(e) AS frequency " "ORDER BY frequency DESC " "LIMIT $limit" @@ -128,44 +128,44 @@ async def 
get_raw_tags_from_db( # 使用项目的Neo4jConnector执行查询 results = await connector.execute_query( query, - id=group_id, + id=end_user_id, limit=limit, names_to_exclude=names_to_exclude ) return [(record["name"], record["frequency"]) for record in results] -async def get_hot_memory_tags(group_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]: +async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]: """ 获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。 查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。 Args: - group_id: 必需参数。如果by_user=False,则为group_id;如果by_user=True,则为user_id + end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id limit: 返回的标签数量限制 - by_user: 是否按user_id查询(默认False,按group_id查询) + by_user: 是否按user_id查询(默认False,按end_user_id查询) Raises: - ValueError: 如果group_id未提供或为空 + ValueError: 如果end_user_id未提供或为空 """ - # 验证group_id必须提供且不为空 - if not group_id or not group_id.strip(): + # 验证end_user_id必须提供且不为空 + if not end_user_id or not end_user_id.strip(): raise ValueError( - "group_id is required. Please provide a valid group_id or user_id." + "end_user_id is required. Please provide a valid end_user_id or user_id." ) # 使用项目的Neo4jConnector connector = Neo4jConnector() try: # 1. 从数据库获取原始排名靠前的标签 - raw_tags_with_freq = await get_raw_tags_from_db(connector, group_id, limit, by_user=by_user) + raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user) if not raw_tags_with_freq: return [] raw_tag_names = [tag for tag, freq in raw_tags_with_freq] # 2. 初始化LLM客户端并使用LLM筛选出有意义的标签 - meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id) + meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, end_user_id) # 3. 
根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序) final_tags = [] diff --git a/api/app/core/memory/analytics/implicit_memory/data_source.py b/api/app/core/memory/analytics/implicit_memory/data_source.py index d277a05e..18678a55 100644 --- a/api/app/core/memory/analytics/implicit_memory/data_source.py +++ b/api/app/core/memory/analytics/implicit_memory/data_source.py @@ -75,8 +75,8 @@ class MemoryDataSource: start_date = time_range.start_date if time_range else None end_date = time_range.end_date if time_range else None - summary_dicts = await self.memory_summary_repo.find_by_group_id( - group_id=user_id, + summary_dicts = await self.memory_summary_repo.find_by_end_user_id( + end_user_id=user_id, limit=limit, start_date=start_date, end_date=end_date diff --git a/api/app/core/memory/evaluation/dialogue_queries.py b/api/app/core/memory/evaluation/dialogue_queries.py index fd7fa671..25abe64e 100644 --- a/api/app/core/memory/evaluation/dialogue_queries.py +++ b/api/app/core/memory/evaluation/dialogue_queries.py @@ -41,7 +41,7 @@ DIALOGUE_EMBEDDING_SEARCH = """ WITH $embedding AS q MATCH (d:Dialogue) WHERE d.dialog_embedding IS NOT NULL - AND ($group_id IS NULL OR d.group_id = $group_id) + AND ($end_user_id IS NULL OR d.end_user_id = $end_user_id) WITH d, q, d.dialog_embedding AS v WITH d, reduce(dot = 0.0, i IN range(0, size(q)-1) | dot + toFloat(q[i]) * toFloat(v[i])) AS dot, @@ -50,7 +50,7 @@ WITH d, WITH d, CASE WHEN qnorm = 0 OR vnorm = 0 THEN 0.0 ELSE dot / (qnorm * vnorm) END AS score WHERE score > $threshold RETURN d.id AS dialog_id, - d.group_id AS group_id, + d.end_user_id AS end_user_id, d.content AS content, d.created_at AS created_at, d.expired_at AS expired_at, diff --git a/api/app/core/memory/evaluation/extraction_utils.py b/api/app/core/memory/evaluation/extraction_utils.py index 9afa228c..9e70bc28 100644 --- a/api/app/core/memory/evaluation/extraction_utils.py +++ b/api/app/core/memory/evaluation/extraction_utils.py @@ -36,7 +36,7 @@ from 
app.repositories.neo4j.neo4j_connector import Neo4jConnector async def ingest_contexts_via_full_pipeline( contexts: List[str], - group_id: str, + end_user_id: str, chunker_strategy: str | None = None, embedding_name: str | None = None, save_chunk_output: bool = False, @@ -48,7 +48,7 @@ async def ingest_contexts_via_full_pipeline( This function mirrors the steps in main(), but starts from raw text contexts. Args: contexts: List of dialogue texts, each containing lines like "role: message". - group_id: Group ID to assign to generated DialogData and graph nodes. + end_user_id: Group ID to assign to generated DialogData and graph nodes. chunker_strategy: Optional chunker strategy; defaults to SELECTED_CHUNKER_STRATEGY. embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID. save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging. @@ -109,7 +109,7 @@ async def ingest_contexts_via_full_pipeline( dialog = DialogData( context=context_model, ref_id=f"pipeline_item_{idx}", - group_id=group_id, + end_user_id=end_user_id, user_id="default_user", apply_id="default_application", ) @@ -318,16 +318,16 @@ async def handle_context_processing(args): print("No contexts provided for processing.") return False - return await main_from_contexts(contexts, args.context_group_id) + return await main_from_contexts(contexts, args.context_end_user_id) -async def main_from_contexts(contexts: List[str], group_id: str): +async def main_from_contexts(contexts: List[str], end_user_id: str): """Run the pipeline from provided dialogue contexts instead of test data.""" print("=== Running pipeline from provided contexts ===") success = await ingest_contexts_via_full_pipeline( contexts=contexts, - group_id=group_id, + end_user_id=end_user_id, chunker_strategy=SELECTED_CHUNKER_STRATEGY, embedding_name=SELECTED_EMBEDDING_ID, save_chunk_output=True diff --git a/api/app/core/memory/evaluation/locomo/locomo_benchmark.py 
b/api/app/core/memory/evaluation/locomo/locomo_benchmark.py index b7d988c5..1c70c28e 100644 --- a/api/app/core/memory/evaluation/locomo/locomo_benchmark.py +++ b/api/app/core/memory/evaluation/locomo/locomo_benchmark.py @@ -47,7 +47,7 @@ from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient from app.core.memory.utils.definitions import ( PROJECT_ROOT, SELECTED_EMBEDDING_ID, - SELECTED_GROUP_ID, + SELECTED_end_user_id, SELECTED_LLM_ID, ) from app.core.memory.utils.llm.llm_utils import MemoryClientFactory @@ -59,7 +59,7 @@ from app.services.memory_config_service import MemoryConfigService async def run_locomo_benchmark( sample_size: int = 20, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, search_type: str = "hybrid", search_limit: int = 12, context_char_budget: int = 8000, @@ -85,7 +85,7 @@ async def run_locomo_benchmark( Args: sample_size: Number of QA pairs to evaluate (from first conversation) - group_id: Database group ID for retrieval (uses default if None) + end_user_id: Database group ID for retrieval (uses default if None) search_type: "keyword", "embedding", or "hybrid" search_limit: Max documents to retrieve per query context_char_budget: Max characters for context @@ -96,8 +96,8 @@ async def run_locomo_benchmark( Returns: Dictionary with evaluation results including metrics, timing, and samples """ - # Use default group_id if not provided - group_id = group_id or SELECTED_GROUP_ID + # Use default end_user_id if not provided + end_user_id = end_user_id or SELECTED_end_user_id # Determine data path data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json") @@ -110,7 +110,7 @@ async def run_locomo_benchmark( print(f"{'='*60}") print("📊 Configuration:") print(f" Sample size: {sample_size}") - print(f" Group ID: {group_id}") + print(f" Group ID: {end_user_id}") print(f" Search type: {search_type}") print(f" Search limit: {search_limit}") print(f" Context budget: {context_char_budget} chars") @@ -134,7 +134,7 @@ 
async def run_locomo_benchmark( # Step 2: Extract conversations and ingest if needed if skip_ingest: print("⏭️ Skipping data ingestion (using existing data in Neo4j)") - print(f" Group ID: {group_id}\n") + print(f" Group ID: {end_user_id}\n") else: print("💾 Checking database ingestion...") try: @@ -142,10 +142,10 @@ async def run_locomo_benchmark( print(f"📝 Extracted {len(conversations)} conversations") # Always ingest for now (ingestion check not implemented) - print(f"🔄 Ingesting conversations into group '{group_id}'...") + print(f"🔄 Ingesting conversations into group '{end_user_id}'...") success = await ingest_conversations_if_needed( conversations=conversations, - group_id=group_id, + end_user_id=end_user_id, reset=reset_group ) @@ -224,7 +224,7 @@ async def run_locomo_benchmark( try: retrieved_info = await retrieve_relevant_information( question=question, - group_id=group_id, + end_user_id=end_user_id, search_type=search_type, search_limit=search_limit, connector=connector, @@ -409,7 +409,7 @@ async def run_locomo_benchmark( "sample_size": len(qa_items), "timestamp": datetime.now().isoformat(), "params": { - "group_id": group_id, + "end_user_id": end_user_id, "search_type": search_type, "search_limit": search_limit, "context_char_budget": context_char_budget, @@ -467,7 +467,7 @@ def main(): help="Number of QA pairs to evaluate" ) parser.add_argument( - "--group_id", + "--end_user_id", type=str, default=None, help="Database group ID for retrieval (uses default if not specified)" @@ -516,7 +516,7 @@ def main(): # Run benchmark result = asyncio.run(run_locomo_benchmark( sample_size=args.sample_size, - group_id=args.group_id, + end_user_id=args.end_user_id, search_type=args.search_type, search_limit=args.search_limit, context_char_budget=args.context_char_budget, diff --git a/api/app/core/memory/evaluation/locomo/locomo_test.py b/api/app/core/memory/evaluation/locomo/locomo_test.py index affedd0f..01c45123 100644 --- 
a/api/app/core/memory/evaluation/locomo/locomo_test.py +++ b/api/app/core/memory/evaluation/locomo/locomo_test.py @@ -556,7 +556,7 @@ async def run_enhanced_evaluation(): search_results = await run_hybrid_search( query_text=q, search_type="hybrid", - group_id="locomo_sk", + end_user_id="locomo_sk", limit=20, include=["statements", "chunks", "entities", "summaries"], alpha=0.6, # BM25权重 diff --git a/api/app/core/memory/evaluation/locomo/locomo_utils.py b/api/app/core/memory/evaluation/locomo/locomo_utils.py index 69be5da9..d3b74947 100644 --- a/api/app/core/memory/evaluation/locomo/locomo_utils.py +++ b/api/app/core/memory/evaluation/locomo/locomo_utils.py @@ -348,7 +348,7 @@ def select_and_format_information( async def retrieve_relevant_information( question: str, - group_id: str, + end_user_id: str, search_type: str, search_limit: int, connector: Any, @@ -368,7 +368,7 @@ async def retrieve_relevant_information( Args: question: Question to search for - group_id: Database group ID (identifies which conversation memory to search) + end_user_id: Database group ID (identifies which conversation memory to search) search_type: "keyword", "embedding", or "hybrid" search_limit: Max memory pieces to retrieve connector: Neo4j connector instance @@ -396,7 +396,7 @@ async def retrieve_relevant_information( connector=connector, embedder_client=embedder, query_text=question, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, include=["chunks", "statements", "entities", "summaries"], ) @@ -455,7 +455,7 @@ async def retrieve_relevant_information( search_results = await search_graph( connector=connector, q=question, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit ) @@ -491,7 +491,7 @@ async def retrieve_relevant_information( search_results = await run_hybrid_search( query_text=question, search_type=search_type, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, include=["chunks", "statements", "entities", "summaries"], 
output_path=None, @@ -524,7 +524,7 @@ async def retrieve_relevant_information( connector=connector, embedder_client=embedder, query_text=question, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, include=["chunks", "statements", "entities", "summaries"], ) @@ -584,7 +584,7 @@ async def retrieve_relevant_information( async def ingest_conversations_if_needed( conversations: List[str], - group_id: str, + end_user_id: str, reset: bool = False ) -> bool: """ @@ -603,7 +603,7 @@ async def ingest_conversations_if_needed( Args: conversations: List of raw conversation texts from LoCoMo dataset Example: ["User: I went to Paris. AI: When was that?", ...] - group_id: Target group ID for database storage + end_user_id: Target group ID for database storage reset: Whether to clear existing data first (not implemented in wrapper) Returns: @@ -617,7 +617,7 @@ async def ingest_conversations_if_needed( try: success = await ingest_contexts_via_full_pipeline( contexts=conversations, - group_id=group_id, + end_user_id=end_user_id, save_chunk_output=True ) return success diff --git a/api/app/core/memory/evaluation/locomo/qwen_search_eval.py b/api/app/core/memory/evaluation/locomo/qwen_search_eval.py index 87a70a29..6a5caa0c 100644 --- a/api/app/core/memory/evaluation/locomo/qwen_search_eval.py +++ b/api/app/core/memory/evaluation/locomo/qwen_search_eval.py @@ -249,7 +249,7 @@ def get_search_params_by_category(category: str): async def run_locomo_eval( sample_size: int = 1, - group_id: str | None = None, + end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, # 保持默认值不变 llm_temperature: float = 0.0, @@ -262,7 +262,7 @@ async def run_locomo_eval( ) -> Dict[str, Any]: # 函数内部使用三路检索逻辑,但保持参数签名不变 - group_id = group_id or SELECTED_GROUP_ID + end_user_id = end_user_id or SELECTED_end_user_id data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json") if not os.path.exists(data_path): data_path = os.path.join(os.getcwd(), "data", 
"locomo10.json") @@ -340,7 +340,7 @@ async def run_locomo_eval( # 关键修复:强制重新摄入纯净的对话数据 print("🔄 强制重新摄入纯净的对话数据...") - await ingest_contexts_via_full_pipeline(contents, group_id, save_chunk_output=True) + await ingest_contexts_via_full_pipeline(contents, end_user_id, save_chunk_output=True) # 使用异步LLM客户端 with get_db_context() as db: @@ -405,7 +405,7 @@ async def run_locomo_eval( connector=connector, embedder_client=embedder, query_text=q, - group_id=group_id, + end_user_id=end_user_id, limit=adjusted_limit, include=["chunks", "statements", "entities", "summaries"], # 修复:使用正确的类型 ) @@ -456,7 +456,7 @@ async def run_locomo_eval( search_results = await search_graph( connector=connector, q=q, - group_id=group_id, + end_user_id=end_user_id, limit=adjusted_limit ) dialogs = search_results.get("dialogues", []) @@ -486,7 +486,7 @@ async def run_locomo_eval( search_results = await run_hybrid_search( query_text=q, search_type=search_type, - group_id=group_id, + end_user_id=end_user_id, limit=adjusted_limit, include=["chunks", "statements", "entities", "summaries"], output_path=None, @@ -524,7 +524,7 @@ async def run_locomo_eval( connector=connector, embedder_client=embedder, query_text=q, - group_id=group_id, + end_user_id=end_user_id, limit=adjusted_limit, include=["chunks", "statements", "entities", "summaries"], ) @@ -597,7 +597,7 @@ async def run_locomo_eval( "dialogues": [ { "uuid": d.get("uuid", ""), - "group_id": d.get("group_id", ""), + "end_user_id": d.get("end_user_id", ""), "content": d.get("content", "")[:200] + "..." 
if len(d.get("content", "")) > 200 else d.get("content", ""), "score": d.get("score", 0.0) } @@ -795,7 +795,7 @@ async def run_locomo_eval( }, "samples": samples, "params": { - "group_id": group_id, + "end_user_id": end_user_id, "search_limit": search_limit, "context_char_budget": context_char_budget, "search_type": search_type, @@ -825,7 +825,7 @@ async def run_locomo_eval( def main(): parser = argparse.ArgumentParser(description="Run LoCoMo evaluation with Qwen search") parser.add_argument("--sample_size", type=int, default=1, help="Number of samples to evaluate") - parser.add_argument("--group_id", type=str, default=None, help="Group ID for retrieval") + parser.add_argument("--end_user_id", type=str, default=None, help="Group ID for retrieval") parser.add_argument("--search_limit", type=int, default=8, help="Search limit per query") parser.add_argument("--context_char_budget", type=int, default=12000, help="Max characters for context") parser.add_argument("--llm_temperature", type=float, default=0.0, help="LLM temperature") @@ -841,7 +841,7 @@ def main(): result = asyncio.run(run_locomo_eval( sample_size=args.sample_size, - group_id=args.group_id, + end_user_id=args.end_user_id, search_limit=args.search_limit, context_char_budget=args.context_char_budget, llm_temperature=args.llm_temperature, diff --git a/api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py b/api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py index 292e7288..8710a504 100644 --- a/api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py +++ b/api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py @@ -524,11 +524,11 @@ def generate_query_keywords_cn(question: str) -> List[str]: # 通过别名匹配进行实体关键词检索(多token合并) -async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], group_id: str | None, limit: int) -> List[Dict[str, Any]]: +async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], end_user_id: str | None, limit: 
int) -> List[Dict[str, Any]]: results: List[Dict[str, Any]] = [] try: for tok in tokens: - rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, group_id=group_id, limit=limit) + rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, end_user_id=end_user_id, limit=limit) if rows: results.extend(rows) except Exception: @@ -548,15 +548,15 @@ async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[st # 通过对话/陈述中的entity_ids反查实体名称 _FETCH_ENTITIES_BY_IDS = """ MATCH (e:ExtractedEntity) -WHERE e.id IN $ids AND ($group_id IS NULL OR e.group_id = $group_id) -RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type +WHERE e.id IN $ids AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id) +RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type """ -async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], group_id: str | None) -> List[Dict[str, Any]]: +async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], end_user_id: str | None) -> List[Dict[str, Any]]: if not ids: return [] try: - rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), group_id=group_id) + rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), end_user_id=end_user_id) return rows or [] except Exception: return [] @@ -566,18 +566,18 @@ async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], grou _TIME_ENTITY_SEARCH = """ MATCH (e:ExtractedEntity) WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern -AND ($group_id IS NULL OR e.group_id = $group_id) -RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type +AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id) +RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type LIMIT $limit """ 
-async def _search_time_entities(connector: Neo4jConnector, group_id: str | None, limit: int = 5) -> List[Dict[str, Any]]: +async def _search_time_entities(connector: Neo4jConnector, end_user_id: str | None, limit: int = 5) -> List[Dict[str, Any]]: """专门搜索时间相关的实体""" try: date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*" rows = await connector.execute_query(_TIME_ENTITY_SEARCH, date_pattern=date_pattern, - group_id=group_id, + end_user_id=end_user_id, limit=limit) return rows or [] except Exception: @@ -624,7 +624,7 @@ def _resolve_relative_times_cn_en(text: str, anchor: datetime) -> str: async def run_longmemeval_test( sample_size: int = 3, - group_id: str = "longmemeval_zh_bak_3", + end_user_id: str = "longmemeval_zh_bak_3", search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, @@ -678,13 +678,13 @@ async def run_longmemeval_test( contexts.extend(selected) print(f"📥 摄入 {len(contexts)} 个上下文到数据库") - if reset_group_before_ingest and group_id: + if reset_group_before_ingest and end_user_id: try: _tmp_conn = Neo4jConnector() - await _tmp_conn.delete_group(group_id) - print(f"🧹 已清空组 {group_id} 的历史图数据") + await _tmp_conn.delete_group(end_user_id) + print(f"🧹 已清空组 {end_user_id} 的历史图数据") except Exception as _e: - print(f"⚠️ 清空组数据失败(忽略继续): {group_id} - {_e}") + print(f"⚠️ 清空组数据失败(忽略继续): {end_user_id} - {_e}") finally: try: await _tmp_conn.close() @@ -696,7 +696,7 @@ async def run_longmemeval_test( else: await _ingest_fn( contexts, - group_id, + end_user_id, save_chunk_output=save_chunk_output, save_chunk_output_path=save_chunk_output_path, ) @@ -751,7 +751,7 @@ async def run_longmemeval_test( connector=connector, embedder_client=embedder, query_text=question, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, include=["chunks", "statements", "entities", "summaries"], ) @@ -796,7 +796,7 @@ async def run_longmemeval_test( search_results = await search_graph( connector=connector, q=question, - group_id=group_id, + 
end_user_id=end_user_id, limit=search_limit, ) chunks = search_results.get("chunks", []) @@ -831,7 +831,7 @@ async def run_longmemeval_test( connector=connector, embedder_client=embedder, query_text=question, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, include=["chunks", "statements", "entities", "summaries"], ) @@ -849,7 +849,7 @@ async def run_longmemeval_test( kw_res = await search_graph( connector=connector, q=question, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, ) if isinstance(kw_res, dict): @@ -860,7 +860,7 @@ async def run_longmemeval_test( # 时间推理问题的特殊处理 if is_temporal: # 专门搜索时间实体 - time_entities = await _search_time_entities(connector, group_id, search_limit//2) + time_entities = await _search_time_entities(connector, end_user_id, search_limit//2) if time_entities: kw_entities.extend(time_entities) # 添加时间相关关键词检索 @@ -870,7 +870,7 @@ async def run_longmemeval_test( time_res = await search_graph( connector=connector, q=tk, - group_id=group_id, + end_user_id=end_user_id, limit=2, ) if isinstance(time_res, dict): @@ -881,7 +881,7 @@ async def run_longmemeval_test( # 中文关键词拆分后做别名匹配 cn_tokens = _extract_cn_tokens(question) - alias_entities = await _search_entities_by_aliases(connector, cn_tokens, group_id, search_limit) + alias_entities = await _search_entities_by_aliases(connector, cn_tokens, end_user_id, search_limit) if alias_entities: kw_entities.extend(alias_entities) @@ -895,7 +895,7 @@ async def run_longmemeval_test( except Exception: pass if ids: - id_entities = await _fetch_entities_by_ids(connector, ids, group_id) + id_entities = await _fetch_entities_by_ids(connector, ids, end_user_id) if id_entities: kw_entities.extend(id_entities) @@ -909,7 +909,7 @@ async def run_longmemeval_test( sub_res = await search_graph( connector=connector, q=str(kw), - group_id=group_id, + end_user_id=end_user_id, limit=max(3, search_limit // 2), ) if isinstance(sub_res, dict): @@ -928,7 +928,7 @@ async def run_longmemeval_test( 
opt_res = await search_graph( connector=connector, q=str(opt), - group_id=group_id, + end_user_id=end_user_id, limit=max(3, search_limit // 2), ) if isinstance(opt_res, dict): @@ -1010,7 +1010,7 @@ async def run_longmemeval_test( kw_fallback = await search_graph( connector=connector, q=question, - group_id=group_id, + end_user_id=end_user_id, limit=max(search_limit, 5), ) fb_dialogs = kw_fallback.get("dialogues", []) or [] @@ -1224,7 +1224,7 @@ async def run_longmemeval_test( "count_avg": statistics.mean(per_query_context_counts) if per_query_context_counts else 0.0, }, "params": { - "group_id": group_id, + "end_user_id": end_user_id, "search_limit": search_limit, "context_char_budget": context_char_budget, "search_type": search_type, @@ -1307,7 +1307,7 @@ def main(): result = asyncio.run( run_longmemeval_test( sample_size=sample_size, - group_id=args.group_id, + end_user_id=args.end_user_id, search_limit=args.search_limit, context_char_budget=args.context_char_budget, llm_temperature=args.llm_temperature, diff --git a/api/app/core/memory/evaluation/longmemeval/test_eval.py b/api/app/core/memory/evaluation/longmemeval/test_eval.py index 08a763e3..67bd6ec2 100644 --- a/api/app/core/memory/evaluation/longmemeval/test_eval.py +++ b/api/app/core/memory/evaluation/longmemeval/test_eval.py @@ -498,11 +498,11 @@ def smart_context_selection(contexts: List[str], question: str, max_chars: int = # 通过别名匹配进行实体关键词检索(多token合并) -async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], group_id: str | None, limit: int) -> List[Dict[str, Any]]: +async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], end_user_id: str | None, limit: int) -> List[Dict[str, Any]]: results: List[Dict[str, Any]] = [] try: for tok in tokens: - rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, group_id=group_id, limit=limit) + rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, end_user_id=end_user_id, limit=limit) 
if rows: results.extend(rows) except Exception: @@ -522,15 +522,15 @@ async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[st # 通过对话/陈述中的entity_ids反查实体名称 _FETCH_ENTITIES_BY_IDS = """ MATCH (e:ExtractedEntity) -WHERE e.id IN $ids AND ($group_id IS NULL OR e.group_id = $group_id) -RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type +WHERE e.id IN $ids AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id) +RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type """ -async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], group_id: str | None) -> List[Dict[str, Any]]: +async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], end_user_id: str | None) -> List[Dict[str, Any]]: if not ids: return [] try: - rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), group_id=group_id) + rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), end_user_id=end_user_id) return rows or [] except Exception: return [] @@ -540,18 +540,18 @@ async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], grou _TIME_ENTITY_SEARCH = """ MATCH (e:ExtractedEntity) WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern -AND ($group_id IS NULL OR e.group_id = $group_id) -RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type +AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id) +RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type LIMIT $limit """ -async def _search_time_entities(connector: Neo4jConnector, group_id: str | None, limit: int = 5) -> List[Dict[str, Any]]: +async def _search_time_entities(connector: Neo4jConnector, end_user_id: str | None, limit: int = 5) -> List[Dict[str, Any]]: """专门搜索时间相关的实体""" try: date_pattern = 
r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*" rows = await connector.execute_query(_TIME_ENTITY_SEARCH, date_pattern=date_pattern, - group_id=group_id, + end_user_id=end_user_id, limit=limit) return rows or [] except Exception: @@ -559,25 +559,25 @@ async def _search_time_entities(connector: Neo4jConnector, group_id: str | None, # 技术术语专门检索 -async def _search_tech_terms(connector: Neo4jConnector, question: str, group_id: str | None, limit: int = 3) -> List[Dict[str, Any]]: +async def _search_tech_terms(connector: Neo4jConnector, question: str, end_user_id: str | None, limit: int = 3) -> List[Dict[str, Any]]: """专门搜索技术术语相关的实体""" tech_entities = [] try: # GPS相关 if any(term in question for term in ["GPS", "导航", "定位系统"]): - gps_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="GPS", group_id=group_id, limit=limit) + gps_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="GPS", end_user_id=end_user_id, limit=limit) if gps_rows: tech_entities.extend(gps_rows) # 活动相关 if any(term in question for term in ["工作坊", "研讨会", "网络研讨会"]): - workshop_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="工作坊", group_id=group_id, limit=limit) + workshop_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="工作坊", end_user_id=end_user_id, limit=limit) if workshop_rows: tech_entities.extend(workshop_rows) # 时间顺序相关 if any(term in question for term in ["先", "后", "第一个"]): - time_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="第一次", group_id=group_id, limit=limit) + time_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="第一次", end_user_id=end_user_id, limit=limit) if time_rows: tech_entities.extend(time_rows) @@ -627,7 +627,7 @@ def _resolve_relative_times_cn_en(text: str, anchor: datetime) -> str: async def run_longmemeval_test( sample_size: int = 3, - group_id: str = "longmemeval_zh_bak_2", + end_user_id: str = "longmemeval_zh_bak_2", search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, @@ 
-707,7 +707,7 @@ async def run_longmemeval_test( connector=connector, embedder_client=embedder, query_text=question, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, include=["dialogues", "statements", "entities"], ) @@ -746,7 +746,7 @@ async def run_longmemeval_test( search_results = await search_graph( connector=connector, q=question, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, ) dialogs = search_results.get("dialogues", []) @@ -776,7 +776,7 @@ async def run_longmemeval_test( connector=connector, embedder_client=embedder, query_text=question, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, include=["dialogues", "statements", "entities"], ) @@ -792,7 +792,7 @@ async def run_longmemeval_test( kw_res = await search_graph( connector=connector, q=question, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, ) if isinstance(kw_res, dict): @@ -801,14 +801,14 @@ async def run_longmemeval_test( kw_entities = kw_res.get("entities", []) or [] # 技术术语专门检索 - tech_entities = await _search_tech_terms(connector, question, group_id, search_limit//2) + tech_entities = await _search_tech_terms(connector, question, end_user_id, search_limit//2) if tech_entities: kw_entities.extend(tech_entities) # 时间推理问题的特殊处理 if is_temporal: # 专门搜索时间实体 - time_entities = await _search_time_entities(connector, group_id, search_limit//2) + time_entities = await _search_time_entities(connector, end_user_id, search_limit//2) if time_entities: kw_entities.extend(time_entities) # 添加时间相关关键词检索 @@ -818,7 +818,7 @@ async def run_longmemeval_test( time_res = await search_graph( connector=connector, q=tk, - group_id=group_id, + end_user_id=end_user_id, limit=2, ) if isinstance(time_res, dict): @@ -829,7 +829,7 @@ async def run_longmemeval_test( # 中文关键词拆分后做别名匹配 cn_tokens = generate_query_keywords_cn(question) # 使用增强版关键词提取 - alias_entities = await _search_entities_by_aliases(connector, cn_tokens, group_id, search_limit) + alias_entities 
= await _search_entities_by_aliases(connector, cn_tokens, end_user_id, search_limit) if alias_entities: kw_entities.extend(alias_entities) @@ -843,7 +843,7 @@ async def run_longmemeval_test( except Exception: pass if ids: - id_entities = await _fetch_entities_by_ids(connector, ids, group_id) + id_entities = await _fetch_entities_by_ids(connector, ids, end_user_id) if id_entities: kw_entities.extend(id_entities) @@ -857,7 +857,7 @@ async def run_longmemeval_test( sub_res = await search_graph( connector=connector, q=str(kw), - group_id=group_id, + end_user_id=end_user_id, limit=max(3, search_limit // 2), ) if isinstance(sub_res, dict): @@ -876,7 +876,7 @@ async def run_longmemeval_test( opt_res = await search_graph( connector=connector, q=str(opt), - group_id=group_id, + end_user_id=end_user_id, limit=max(3, search_limit // 2), ) if isinstance(opt_res, dict): @@ -971,7 +971,7 @@ async def run_longmemeval_test( kw_fallback = await search_graph( connector=connector, q=question, - group_id=group_id, + end_user_id=end_user_id, limit=max(search_limit, 5), ) fb_dialogs = kw_fallback.get("dialogues", []) or [] @@ -1199,7 +1199,7 @@ async def run_longmemeval_test( "count_avg": statistics.mean(per_query_context_counts) if per_query_context_counts else 0.0, }, "params": { - "group_id": group_id, + "end_user_id": end_user_id, "search_limit": search_limit, "context_char_budget": context_char_budget, "search_type": search_type, @@ -1278,7 +1278,7 @@ def main(): result = asyncio.run( run_longmemeval_test( sample_size=sample_size, - group_id=args.group_id, + end_user_id=args.end_user_id, search_limit=args.search_limit, context_char_budget=args.context_char_budget, llm_temperature=args.llm_temperature, diff --git a/api/app/core/memory/evaluation/memsciqa/evaluate_qa.py b/api/app/core/memory/evaluation/memsciqa/evaluate_qa.py index 6efb66ff..869fdb60 100644 --- a/api/app/core/memory/evaluation/memsciqa/evaluate_qa.py +++ b/api/app/core/memory/evaluation/memsciqa/evaluate_qa.py @@ 
-135,8 +135,8 @@ def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any return merged -async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]: - group_id = group_id or SELECTED_GROUP_ID +async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]: + end_user_id = end_user_id or SELECTED_GROUP_ID # Load data data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl") if not os.path.exists(data_path): @@ -147,7 +147,7 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s # 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入 # 说明:memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略 contexts: List[str] = [build_context_from_dialog(item) for item in items] - await ingest_contexts_via_full_pipeline(contexts, group_id) + await ingest_contexts_via_full_pipeline(contexts, end_user_id) # LLM client (使用异步调用) with get_db_context() as db: @@ -173,7 +173,7 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s results = await run_hybrid_search( query_text=question, search_type=search_type, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, include=["dialogues", "statements", "entities"], output_path=None, @@ -298,7 +298,7 @@ def main(): load_dotenv() parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen") parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量") - parser.add_argument("--group-id", type=str, default=None, help="可选 group_id,默认取 runtime.json") + parser.add_argument("--group-id", type=str, 
default=None, help="可选 end_user_id,默认取 runtime.json") parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数") parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算") parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度") @@ -309,7 +309,7 @@ def main(): result = asyncio.run( run_memsciqa_eval( sample_size=args.sample_size, - group_id=args.group_id, + end_user_id=args.end_user_id, search_limit=args.search_limit, context_char_budget=args.context_char_budget, llm_temperature=args.llm_temperature, diff --git a/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py b/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py index 900cda9d..3023020a 100644 --- a/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py +++ b/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py @@ -199,7 +199,7 @@ def load_dataset_memsciqa(data_path: str) -> List[Dict[str, Any]]: async def run_memsciqa_test( sample_size: int = 3, - group_id: str | None = None, + end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, @@ -217,7 +217,7 @@ async def run_memsciqa_test( """ # 默认使用指定的 memsci 组 ID - group_id = group_id or "group_memsci" + end_user_id = end_user_id or "group_memsci" # 数据路径解析(项目根与当前工作目录兜底) if not data_path: @@ -283,7 +283,7 @@ async def run_memsciqa_test( connector=connector, embedder_client=embedder, query_text=question, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues ) @@ -292,7 +292,7 @@ async def run_memsciqa_test( results = await search_graph( connector=connector, q=question, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues ) @@ -500,7 +500,7 @@ async def run_memsciqa_test( }, "samples": samples, "params": { - "group_id": group_id, 
+ "end_user_id": end_user_id, "search_limit": search_limit, "context_char_budget": context_char_budget, "llm_temperature": llm_temperature, @@ -543,7 +543,7 @@ def main(): result = asyncio.run( run_memsciqa_test( sample_size=sample_size, - group_id=args.group_id, + end_user_id=args.end_user_id, search_limit=args.search_limit, context_char_budget=args.context_char_budget, llm_temperature=args.llm_temperature, diff --git a/api/app/core/memory/evaluation/run_eval.py b/api/app/core/memory/evaluation/run_eval.py index 1de3de89..c5aacb2f 100644 --- a/api/app/core/memory/evaluation/run_eval.py +++ b/api/app/core/memory/evaluation/run_eval.py @@ -26,7 +26,7 @@ async def run( dataset: str, sample_size: int, reset_group: bool, - group_id: str | None, + end_user_id: str | None, judge_model: str | None = None, search_limit: int | None = None, context_char_budget: int | None = None, @@ -37,17 +37,17 @@ async def run( max_contexts_per_item: int | None = None, ) -> Dict[str, Any]: # 恢复原始风格:统一入口做路由,并沿用各数据集既有默认 - group_id = group_id or SELECTED_GROUP_ID + end_user_id = end_user_id or SELECTED_GROUP_ID if reset_group: connector = Neo4jConnector() try: - await connector.delete_group(group_id) + await connector.delete_group(end_user_id) finally: await connector.close() if dataset == "locomo": - kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id} + kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id} if search_limit is not None: kwargs["search_limit"] = search_limit if context_char_budget is not None: @@ -61,7 +61,7 @@ async def run( return await run_locomo_eval(**kwargs) if dataset == "memsciqa": - kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id} + kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id} if search_limit is not None: kwargs["search_limit"] = search_limit if context_char_budget is not None: @@ -75,7 +75,7 @@ async def run( return await run_memsciqa_eval(**kwargs) if 
dataset == "longmemeval": - kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id} + kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id} if search_limit is not None: kwargs["search_limit"] = search_limit if context_char_budget is not None: @@ -99,8 +99,8 @@ def main(): parser = argparse.ArgumentParser(description="统一评估入口:memsciqa / longmemeval / locomo") parser.add_argument("--dataset", choices=["memsciqa", "longmemeval", "locomo"], required=True) parser.add_argument("--sample-size", type=int, default=1, help="先用一条数据跑通") - parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 group_id 的图数据") - parser.add_argument("--group-id", type=str, default=None, help="可选 group_id,默认取 runtime.json") + parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 end_user_id 的图数据") + parser.add_argument("--group-id", type=str, default=None, help="可选 end_user_id,默认取 runtime.json") parser.add_argument("--judge-model", type=str, default=None, help="可选:longmemeval 判别式评测模型名") parser.add_argument("--search-limit", type=int, default=None, help="检索返回的对话节点数量上限(不提供则使用各脚本默认)") parser.add_argument("--context-char-budget", type=int, default=None, help="上下文字符预算(不提供则使用各脚本默认)") @@ -117,7 +117,7 @@ def main(): args.dataset, args.sample_size, args.reset_group, - args.group_id, + args.end_user_id, args.judge_model, args.search_limit, args.context_char_budget, diff --git a/api/app/core/memory/llm_tools/chunker_client.py b/api/app/core/memory/llm_tools/chunker_client.py index 87cdb9f4..93a2df82 100644 --- a/api/app/core/memory/llm_tools/chunker_client.py +++ b/api/app/core/memory/llm_tools/chunker_client.py @@ -187,11 +187,11 @@ class ChunkerClient: async def generate_chunks(self, dialogue: DialogData): """ Generate chunks following 1 Message = 1 Chunk strategy. - + Each message creates one chunk, directly inheriting role information. 
If a message is too long, it will be split into multiple sub-chunks, each maintaining the same speaker. - + Raises: ValueError: If dialogue has no messages or chunking fails """ @@ -201,9 +201,9 @@ class ChunkerClient: f"Dialogue {dialogue.ref_id} has no messages. " f"Cannot generate chunks from empty dialogue." ) - + dialogue.chunks = [] - + # 按消息分块:每个消息创建一个或多个 chunk,直接继承角色 for msg_idx, msg in enumerate(dialogue.context.msgs): # Validate message has required attributes @@ -212,13 +212,13 @@ class ChunkerClient: f"Message {msg_idx} in dialogue {dialogue.ref_id} " f"missing 'role' or 'msg' attribute" ) - + msg_content = msg.msg.strip() - + # Skip empty messages if not msg_content: continue - + # 如果消息太长,可以进一步分块 if len(msg_content) > self.chunk_size: # 对单个消息的内容进行分块 @@ -228,14 +228,14 @@ class ChunkerClient: raise ValueError( f"Failed to chunk long message {msg_idx} in dialogue {dialogue.ref_id}: {e}" ) - + for idx, sub_chunk in enumerate(sub_chunks): sub_chunk_text = sub_chunk.text if hasattr(sub_chunk, 'text') else str(sub_chunk) sub_chunk_text = sub_chunk_text.strip() - + if len(sub_chunk_text) < (self.min_characters_per_chunk or 50): continue - + chunk = Chunk( content=f"{msg.role}: {sub_chunk_text}", speaker=msg.role, # 直接继承角色 @@ -260,7 +260,7 @@ class ChunkerClient: }, ) dialogue.chunks.append(chunk) - + # Validate we generated at least one chunk if not dialogue.chunks: raise ValueError( @@ -268,7 +268,7 @@ class ChunkerClient: f"All messages were either empty or too short. " f"Messages count: {len(dialogue.context.msgs)}" ) - + return dialogue def evaluate_chunking(self, dialogue: DialogData) -> dict: diff --git a/api/app/core/memory/models/config_models.py b/api/app/core/memory/models/config_models.py index f3341cc5..ca1780aa 100644 --- a/api/app/core/memory/models/config_models.py +++ b/api/app/core/memory/models/config_models.py @@ -72,7 +72,7 @@ class TemporalSearchParams(BaseModel): """Parameters for temporal search queries in the knowledge graph. 
Attributes: - group_id: Group ID to filter search results (default: 'test') + end_user_id: Group ID to filter search results (default: 'test') apply_id: Application ID to filter search results user_id: User ID to filter search results start_date: Start date for temporal filtering (format: 'YYYY-MM-DD') @@ -81,7 +81,7 @@ class TemporalSearchParams(BaseModel): invalid_date: Date when memory should be invalid (format: 'YYYY-MM-DD') limit: Maximum number of results to return (default: 3) """ - group_id: Optional[str] = Field("test", description="The group ID to filter the search.") + end_user_id: Optional[str] = Field("test", description="The group ID to filter the search.") apply_id: Optional[str] = Field(None, description="The apply ID to filter the search.") user_id: Optional[str] = Field(None, description="The user ID to filter the search.") start_date: Optional[str] = Field(None, description="The start date for the search.") diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index 7a48d6cb..79b88fdc 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -103,9 +103,7 @@ class Edge(BaseModel): id: Unique identifier for the edge source: ID of the source node target: ID of the target node - group_id: Group ID for multi-tenancy - user_id: User ID for user-specific data - apply_id: Application ID for application-specific data + end_user_id: End user ID for multi-tenancy run_id: Unique identifier for the pipeline run that created this edge created_at: Timestamp when the edge was created (system perspective) expired_at: Optional timestamp when the edge expires (system perspective) @@ -113,9 +111,7 @@ class Edge(BaseModel): id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.") source: str = Field(..., description="The ID of the source node.") target: str = Field(..., description="The ID of the target node.") - group_id: str 
= Field(..., description="The group ID of the edge.") - user_id: str = Field(..., description="The user ID of the edge.") - apply_id: str = Field(..., description="The apply ID of the edge.") + end_user_id: str = Field(..., description="The end user ID of the edge.") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") created_at: datetime = Field(..., description="The valid time of the edge from system perspective.") expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.") @@ -185,18 +181,14 @@ class Node(BaseModel): Attributes: id: Unique identifier for the node name: Name of the node - group_id: Group ID for multi-tenancy - user_id: User ID for user-specific data - apply_id: Application ID for application-specific data + end_user_id: End user ID for multi-tenancy run_id: Unique identifier for the pipeline run that created this node created_at: Timestamp when the node was created (system perspective) expired_at: Optional timestamp when the node expires (system perspective) """ id: str = Field(..., description="The unique identifier for the node.") name: str = Field(..., description="The name of the node.") - group_id: str = Field(..., description="The group ID of the node.") - user_id: str = Field(..., description="The user ID of the edge.") - apply_id: str = Field(..., description="The apply ID of the edge.") + end_user_id: str = Field(..., description="The end user ID of the node.") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") created_at: datetime = Field(..., description="The valid time of the node from system perspective.") expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.") diff --git a/api/app/core/memory/models/message_models.py b/api/app/core/memory/models/message_models.py index bcf08999..2f8660af 100644 --- 
a/api/app/core/memory/models/message_models.py +++ b/api/app/core/memory/models/message_models.py @@ -55,7 +55,7 @@ class Statement(BaseModel): Attributes: id: Unique identifier for the statement chunk_id: ID of the parent chunk this statement belongs to - group_id: Optional group ID for multi-tenancy + end_user_id: Optional group ID for multi-tenancy statement: The actual statement text content speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses) statement_embedding: Optional embedding vector for the statement @@ -73,7 +73,7 @@ class Statement(BaseModel): """ id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.") chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.") - group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.") + end_user_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.") statement: str = Field(..., description="The text content of the statement.") speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses") statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.") @@ -159,9 +159,7 @@ class DialogData(BaseModel): context: Full conversation context dialog_embedding: Optional embedding vector for the entire dialog ref_id: Reference ID linking to external dialog system - group_id: Group ID for multi-tenancy - user_id: User ID for user-specific data - apply_id: Application ID for application-specific data + end_user_id: End user ID for multi-tenancy created_at: Timestamp when the dialog was created expired_at: Timestamp when the dialog expires (default: far future) metadata: Additional metadata as key-value pairs @@ -175,9 +173,7 @@ class DialogData(BaseModel): context: ConversationContext = Field(..., description="The full conversation context as a 
single string.") dialog_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the dialog.") ref_id: str = Field(..., description="Refer to external dialog id. This is used to link to the original dialog.") - group_id: str = Field(default=..., description="Group ID of dialogue data") - user_id: str = Field(..., description="USER ID of dialogue data") - apply_id: str = Field(..., description="APPLY ID of dialogue data") + end_user_id: str = Field(default=..., description="End user ID of dialogue data") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.") expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.") @@ -250,11 +246,11 @@ class DialogData(BaseModel): return [] def assign_group_id_to_statements(self) -> None: - """Assign this dialog's group_id to all statements in all chunks. + """Assign this dialog's end_user_id to all statements in all chunks. - This method updates statements that don't have a group_id set. + This method updates statements that don't have a end_user_id set. 
""" for chunk in self.chunks: for statement in chunk.statements: - if statement.group_id is None: - statement.group_id = self.group_id + if statement.end_user_id is None: + statement.end_user_id = self.end_user_id diff --git a/api/app/core/memory/src/search.py b/api/app/core/memory/src/search.py index 91e47eae..0e1d8424 100644 --- a/api/app/core/memory/src/search.py +++ b/api/app/core/memory/src/search.py @@ -6,6 +6,7 @@ import os import time from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, List, Optional +from uuid import UUID if TYPE_CHECKING: from app.schemas.memory_config_schema import MemoryConfig @@ -396,13 +397,13 @@ def rerank_with_activation( return reranked -def log_search_query(query_text: str, search_type: str, group_id: str | None, limit: int, include: List[str], log_file: str = None): +def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None): """Log search query information using the logger. 
Args: query_text: The search query text search_type: Type of search (keyword, embedding, hybrid) - group_id: Group identifier for filtering + end_user_id: Group identifier for filtering limit: Maximum number of results include: List of result types to include log_file: Deprecated parameter, kept for backward compatibility @@ -413,7 +414,7 @@ def log_search_query(query_text: str, search_type: str, group_id: str | None, li # Log using the standard logger logger.info( f"Search query: query='{cleaned_query}', type={search_type}, " - f"group_id={group_id}, limit={limit}, include={include}" + f"end_user_id={end_user_id}, limit={limit}, include={include}" ) @@ -672,7 +673,7 @@ def apply_reranker_placeholder( async def run_hybrid_search( query_text: str, search_type: str, - group_id: str | None, + end_user_id: str | None, limit: int, include: List[str], output_path: str | None, @@ -715,7 +716,7 @@ async def run_hybrid_search( } # Log the search query - log_search_query(query_text, search_type, group_id, limit, include) + log_search_query(query_text, search_type, end_user_id, limit, include) connector = Neo4jConnector() results = {} @@ -732,7 +733,7 @@ async def run_hybrid_search( search_graph( connector=connector, q=query_text, - group_id=group_id, + end_user_id=end_user_id, limit=limit, include=include ) @@ -769,7 +770,7 @@ async def run_hybrid_search( connector=connector, embedder_client=embedder, query_text=query_text, - group_id=group_id, + end_user_id=end_user_id, limit=limit, include=include, ) @@ -916,9 +917,7 @@ async def run_hybrid_search( async def search_by_temporal( - group_id: Optional[str] = "test", - apply_id: Optional[str] = None, - user_id: Optional[str] = None, + end_user_id: Optional[str] = "test", start_date: Optional[str] = None, end_date: Optional[str] = None, valid_date: Optional[str] = None, @@ -929,7 +928,7 @@ async def search_by_temporal( Temporal search across Statements. 
- Matches statements created between start_date and end_date - - Optionally filters by group_id + - Optionally filters by end_user_id - Returns up to 'limit' statements """ connector = Neo4jConnector() @@ -939,9 +938,7 @@ async def search_by_temporal( end_date = normalize_date_safe(end_date) params = TemporalSearchParams.model_validate({ - "group_id": group_id, - "apply_id": apply_id, - "user_id": user_id, + "end_user_id": end_user_id, "start_date": start_date, "end_date": end_date, "valid_date": valid_date, @@ -950,9 +947,7 @@ async def search_by_temporal( }) statements = await search_graph_by_temporal( connector=connector, - group_id=params.group_id, - apply_id=params.apply_id, - user_id=params.user_id, + end_user_id=params.end_user_id, start_date=params.start_date, end_date=params.end_date, valid_date=params.valid_date, @@ -964,9 +959,7 @@ async def search_by_temporal( async def search_by_keyword_temporal( query_text: str, - group_id: Optional[str] = "test", - apply_id: Optional[str] = None, - user_id: Optional[str] = None, + end_user_id: Optional[str] = "test", start_date: Optional[str] = None, end_date: Optional[str] = None, valid_date: Optional[str] = None, @@ -987,9 +980,7 @@ async def search_by_keyword_temporal( invalid_date = normalize_date_safe(invalid_date) params = TemporalSearchParams.model_validate({ - "group_id": group_id, - "apply_id": apply_id, - "user_id": user_id, + "end_user_id": end_user_id, "start_date": start_date, "end_date": end_date, "valid_date": valid_date, @@ -999,9 +990,7 @@ async def search_by_keyword_temporal( statements = await search_graph_by_keyword_temporal( connector=connector, query_text=query_text, - group_id=params.group_id, - apply_id=params.apply_id, - user_id=params.user_id, + end_user_id=params.end_user_id, start_date=params.start_date, end_date=params.end_date, valid_date=params.valid_date, @@ -1013,7 +1002,7 @@ async def search_by_keyword_temporal( async def search_chunk_by_chunk_id( chunk_id: str, - group_id: 
Optional[str] = "test", + end_user_id: Optional[str] = "test", limit: int = 1, ): """ @@ -1023,7 +1012,7 @@ async def search_chunk_by_chunk_id( chunks = await search_graph_by_chunk_id( connector=connector, chunk_id=chunk_id, - group_id=group_id, + end_user_id=end_user_id, limit=limit ) return {"chunks": chunks} diff --git a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_preprocessor.py b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_preprocessor.py index f5e72517..4dafd3ed 100644 --- a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_preprocessor.py +++ b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_preprocessor.py @@ -555,8 +555,8 @@ class DataPreprocessor: dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}'))) - # 获取group_id,如果不存在则生成默认值 - group_id = item.get('group_id', f'group_default_{i}') + # 获取end_user_id,如果不存在则生成默认值 + end_user_id = item.get('end_user_id', f'group_default_{i}') user_id = item.get('user_id', f'user_default_{i}') apply_id = item.get('apply_id', f'apply_default_{i}') @@ -574,7 +574,7 @@ class DataPreprocessor: dialog_data = DialogData( context=context, ref_id=dialog_id, - group_id=group_id, + end_user_id=end_user_id, user_id=user_id, apply_id=apply_id, metadata=metadata @@ -644,7 +644,7 @@ class DataPreprocessor: context = ConversationContext(msgs=messages) dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}'))) - group_id = item.get('group_id', f'group_default_{i}') + end_user_id = item.get('end_user_id', f'group_default_{i}') user_id = item.get('user_id', f'user_default_{i}') apply_id = item.get('apply_id', f'apply_default_{i}') @@ -657,7 +657,7 @@ class DataPreprocessor: dialog_data = DialogData( context=context, ref_id=dialog_id, - group_id=group_id, + end_user_id=end_user_id, user_id=user_id, apply_id=apply_id, metadata=metadata diff --git 
a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py index 62b656b0..a425e0ed 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py @@ -199,7 +199,7 @@ def accurate_match( entity_nodes: List[ExtractedEntityNode] ) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]: """ - 精确匹配:按 (group_id, name, entity_type) 合并实体并建立重定向与合并记录。 + 精确匹配:按 (end_user_id, name, entity_type) 合并实体并建立重定向与合并记录。 返回: (deduped_entities, id_redirect, exact_merge_map) """ exact_merge_map: Dict[str, Dict] = {} @@ -210,8 +210,8 @@ def accurate_match( for ent in entity_nodes: name_norm = (getattr(ent, "name", "") or "").strip() type_norm = (getattr(ent, "entity_type", "") or "").strip() - key = f"{getattr(ent, 'group_id', None)}|{name_norm}|{type_norm}" - # 为避免跨业务组误并,明确以 group_id 为范围边界 + key = f"{getattr(ent, 'end_user_id', None)}|{name_norm}|{type_norm}" + # 为避免跨业务组误并,明确以 end_user_id 为范围边界 if key not in canonical_map: canonical_map[key] = ent id_redirect[ent.id] = ent.id @@ -223,11 +223,11 @@ def accurate_match( id_redirect[ent.id] = canonical.id # 记录精确匹配的合并项(使用规范化键,避免外层变量误用) try: - k = f"{canonical.group_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}" + k = f"{canonical.end_user_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}" if k not in exact_merge_map: exact_merge_map[k] = { "canonical_id": canonical.id, - "group_id": canonical.group_id, + "end_user_id": canonical.end_user_id, "name": canonical.name, "entity_type": canonical.entity_type, "merged_ids": set(), @@ -596,7 +596,7 @@ def fuzzy_match( b = deduped_entities[j] # 跳过不同业务组的实体 - if getattr(a, "group_id", None) != getattr(b, "group_id", None): + if getattr(a, "end_user_id", None) != getattr(b, 
"end_user_id", None): j += 1 continue @@ -671,7 +671,7 @@ def fuzzy_match( merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]" merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]" fuzzy_merge_records.append( - f"{merge_reason} 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type}) | " + f"{merge_reason} 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type}) | " f"s_name={s_name:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, exact_alias={has_exact_match}" ) except Exception: @@ -779,7 +779,7 @@ async def LLM_decision( # 决策中包含去重和消歧的功能 # 记录 LLM 融合日志 try: llm_records.append( - f"[LLM融合] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})" + f"[LLM融合] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})" ) # 详细的“同类名称相似”记录改由 LLM 去重模块统一生成以携带 conf/reason except Exception: @@ -847,7 +847,7 @@ async def LLM_disamb_decision( id_redirect[k] = a.id try: disamb_records.append( - f"[DISAMB合并应用] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})" + f"[DISAMB合并应用] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})" ) except Exception: pass diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py index 734f7b69..0249ac1f 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py @@ -174,7 +174,7 @@ async def _judge_pair( pass # 3. 
构建LLM判断的“上下文信息”(规则层计算的所有特征) 判断上下文特征有助于实体消歧首先判断的类型关系 ctx = { - "same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None), + "same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None), "type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)), "type_similarity": _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None)), "name_text_sim": name_text_sim, @@ -235,7 +235,7 @@ async def _judge_pair_disamb( except Exception: pass ctx = { - "same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None), + "same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None), "type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)), "name_text_sim": name_text_sim, "name_embed_sim": name_embed_sim, @@ -317,8 +317,8 @@ async def llm_dedup_entities( # 保留对偶判断作为子流程,是为了 a = entity_nodes[i] for j in range(i + 1, len(entity_nodes)): b = entity_nodes[j] - # 规则1:必须属于同一组(group_id相同,不同组的实体不重复) - if getattr(a, "group_id", None) != getattr(b, "group_id", None): + # 规则1:必须属于同一组(end_user_id相同,不同组的实体不重复) + if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None): continue # 规则2:类型必须兼容(调用_simple_type_ok判断) if not _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)): @@ -474,7 +474,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重 - max_rounds: upper bound for iterative passes (default 3) - auto_merge_threshold: decision confidence for auto-merge when no co-occurrence (default 0.90) - co_ctx_threshold: lower threshold when co-occurrence is detected (default 0.83) - - shuffle_each_round: whether to shuffle entities within group_id each round to vary block composition + - shuffle_each_round: whether to shuffle entities within end_user_id each round to vary block composition Returns: - global_redirect: dict losing_id -> canonical_id accumulated across rounds @@ -509,7 +509,7 @@ async 
def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重 def _partition_blocks(nodes: List[ExtractedEntityNode]) -> List[List[ExtractedEntityNode]]: """ - 按 group_id 分块,避免跨组实体在同一块,减少无效候选对 + 按 end_user_id 分块,避免跨组实体在同一块,减少无效候选对 Args: nodes: 实体节点列表 @@ -519,7 +519,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重 """ groups: Dict[str, List[ExtractedEntityNode]] = {} for e in nodes: - gid = getattr(e, "group_id", None) + gid = getattr(e, "end_user_id", None) groups.setdefault(str(gid), []).append(e) blocks: List[List[ExtractedEntityNode]] = [] for gid, arr in groups.items(): @@ -559,7 +559,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重 # Collapse nodes to canonical reps before each round to avoid redundant comparisons # 步骤1:折叠实体(合并已确定的重复实体,减少后续计算量) current_nodes = _collapse_nodes(current_nodes) - # 步骤2:分块(按group_id分块,避免跨组处理) + # 步骤2:分块(按end_user_id分块,避免跨组处理) blocks = _partition_blocks(current_nodes) if not blocks: # 无块可处理(实体已全部折叠),退出循环 break @@ -645,7 +645,7 @@ async def llm_disambiguate_pairs_iterative( a = entity_nodes[i] b = entity_nodes[j] # 必须同组 - if getattr(a, "group_id", None) != getattr(b, "group_id", None): + if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None): continue ta = getattr(a, "entity_type", None) tb = getattr(b, "entity_type", None) diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py index b41f35a4..dbc697d9 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py @@ -61,7 +61,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode: return ExtractedEntityNode( id=row.get("id"), name=row.get("name") or "", - group_id=row.get("group_id") or "", + end_user_id=row.get("end_user_id") or "", 
user_id=row.get("user_id") or "", apply_id=row.get("apply_id") or "", created_at=_parse_dt(row.get("created_at")), @@ -79,7 +79,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode: async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑,与 Neo4j 中同组实体联合去重 connector: Neo4jConnector, - group_id: str, # 用于定位neo4j中同一组的实体,确保只在同组内去重 + end_user_id: str, # 用于定位neo4j中同一组的实体,确保只在同组内去重 entity_nodes: List[ExtractedEntityNode], # 输入的实体节点列表,包含待去重的实体 statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系 entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系 @@ -88,7 +88,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑 ) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]: """ 第二层去重消歧: - - 以第一层结果为索引,检索相同 group_id 下的 DB 候选实体 + - 以第一层结果为索引,检索相同 end_user_id 下的 DB 候选实体 - 将 DB 候选与当前实体集合联合,按既有精确/模糊/LLM 决策进行融合 - 返回融合后的实体与重定向后的边(边已指向规范 ID,优先 DB ID) """ @@ -102,7 +102,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑 ] candidates_map = await get_dedup_candidates_for_entities( # 从 Neo4j 中查询候选实体,并将结果赋值给candidates_map(等待异步操作完成)。 - connector=connector, group_id=group_id, + connector=connector, end_user_id=end_user_id, entities=incoming_rows, # 传入参数:第一层实体的核心信息(作为查询索引) use_contains_fallback=True # 传入参数:启用 “包含关系” 作为匹配失败的降级策略(若精确匹配无结果,用包含关系召回候选),与src\database\cypher_queries.py的307产生联动 ) diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py index 11845d7d..f28b8a5f 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py @@ -57,11 +57,11 @@ async def dedup_layers_and_merge_and_return( if pipeline_config is None: raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return") - # 
先探测 group_id,决定报告写入策略 - group_id: Optional[str] = None + # 先探测 end_user_id,决定报告写入策略 + end_user_id: Optional[str] = None for dd in dialog_data_list: - group_id = getattr(dd, "group_id", None) - if group_id: + end_user_id = getattr(dd, "end_user_id", None) + if end_user_id: break # 第一层去重消歧 @@ -82,11 +82,11 @@ async def dedup_layers_and_merge_and_return( # 第二层去重消歧:与 Neo4j 中同组实体联合融合 try: - if group_id: + if end_user_id: if connector: fused_entity_nodes, fused_statement_entity_edges, fused_entity_entity_edges = await second_layer_dedup_and_merge_with_neo4j( connector=connector, - group_id=group_id, + end_user_id=end_user_id, entity_nodes=dedup_entity_nodes, statement_entity_edges=dedup_statement_entity_edges, entity_entity_edges=dedup_entity_entity_edges, @@ -96,7 +96,7 @@ async def dedup_layers_and_merge_and_return( else: print("Skip second-layer dedup: missing connector") else: - print("Skip second-layer dedup: missing group_id") + print("Skip second-layer dedup: missing end_user_id") except Exception as e: print(f"Second-layer dedup failed: {e}") diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 46ba1dde..8c69c7cf 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -287,7 +287,7 @@ class ExtractionOrchestrator: for d_idx, dialog in enumerate(dialog_data_list): dialogue_content = dialog.content if self.config.statement_extraction.include_dialogue_context else None for c_idx, chunk in enumerate(dialog.chunks): - all_chunks.append((chunk, dialog.group_id, dialogue_content)) + all_chunks.append((chunk, dialog.end_user_id, dialogue_content)) chunk_metadata.append((d_idx, c_idx)) logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取") @@ -299,9 +299,9 @@ class ExtractionOrchestrator: # 全局并行处理所有分块 async def 
extract_for_chunk(chunk_data, chunk_index): nonlocal completed_chunks - chunk, group_id, dialogue_content = chunk_data + chunk, end_user_id, dialogue_content = chunk_data try: - statements = await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content) + statements = await self.statement_extractor._extract_statements(chunk, end_user_id, dialogue_content) # 流式输出:每提取完一个分块的陈述句,立即发送进度 # 注意:只在试运行模式下发送陈述句详情,正式模式不发送 @@ -569,32 +569,32 @@ class ExtractionOrchestrator: if dialog_data_list and hasattr(dialog_data_list[0], 'config_id'): config_id = dialog_data_list[0].config_id - # 加载DataConfig - data_config = None + # 加载MemoryConfig + memory_config = None if config_id: try: from app.db import SessionLocal - from app.repositories.data_config_repository import DataConfigRepository + from app.repositories.memory_config_repository import MemoryConfigRepository db = SessionLocal() try: - data_config = DataConfigRepository.get_by_id(db, config_id) + memory_config = MemoryConfigRepository.get_by_id(db, config_id) finally: db.close() - if data_config and not data_config.emotion_enabled: + if memory_config and not memory_config.emotion_enabled: logger.info("情绪提取已在配置中禁用,跳过情绪提取") return [{} for _ in dialog_data_list] except Exception as e: - logger.warning(f"加载DataConfig失败: {e},将跳过情绪提取") + logger.warning(f"加载MemoryConfig失败: {e},将跳过情绪提取") return [{} for _ in dialog_data_list] else: logger.info("未找到config_id,跳过情绪提取") return [{} for _ in dialog_data_list] # 如果配置未启用情绪提取,直接返回空映射 - if not data_config or not data_config.emotion_enabled: + if not memory_config or not memory_config.emotion_enabled: logger.info("情绪提取未启用,跳过") return [{} for _ in dialog_data_list] @@ -608,7 +608,7 @@ class ExtractionOrchestrator: total_statements += 1 # 只处理用户的陈述句 (role 为 "user") if hasattr(statement, 'speaker') and statement.speaker == "user": - all_statements.append((statement, data_config)) + all_statements.append((statement, memory_config)) statement_metadata.append((d_idx, 
statement.id)) filtered_statements += 1 @@ -617,7 +617,7 @@ class ExtractionOrchestrator: # 初始化情绪提取服务 from app.services.emotion_extraction_service import EmotionExtractionService emotion_service = EmotionExtractionService( - llm_id=data_config.emotion_model_id if data_config.emotion_model_id else None + llm_id=memory_config.emotion_model_id if memory_config.emotion_model_id else None ) # 全局并行处理所有陈述句 @@ -992,9 +992,7 @@ class ExtractionOrchestrator: id=dialog_data.id, name=f"Dialog_{dialog_data.id}", # 添加必需的 name 字段 ref_id=dialog_data.ref_id, - group_id=dialog_data.group_id, - user_id=dialog_data.user_id, - apply_id=dialog_data.apply_id, + end_user_id=dialog_data.end_user_id, run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id content=dialog_data.context.content if dialog_data.context else "", dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None, @@ -1012,9 +1010,7 @@ class ExtractionOrchestrator: id=chunk.id, name=f"Chunk_{chunk.id}", # 添加必需的 name 字段 dialog_id=dialog_data.id, - group_id=dialog_data.group_id, - user_id=dialog_data.user_id, - apply_id=dialog_data.apply_id, + end_user_id=dialog_data.end_user_id, run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id content=chunk.content, chunk_embedding=chunk.chunk_embedding, @@ -1035,9 +1031,7 @@ class ExtractionOrchestrator: stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段 temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段 connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 - group_id=dialog_data.group_id, - user_id=dialog_data.user_id, - apply_id=dialog_data.apply_id, + end_user_id=dialog_data.end_user_id, run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id statement=statement.statement, speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段 @@ -1060,9 +1054,7 @@ class ExtractionOrchestrator: statement_chunk_edge 
= StatementChunkEdge( source=statement.id, target=chunk.id, - group_id=dialog_data.group_id, - user_id=dialog_data.user_id, - apply_id=dialog_data.apply_id, + end_user_id=dialog_data.end_user_id, run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id created_at=dialog_data.created_at, ) @@ -1095,9 +1087,7 @@ class ExtractionOrchestrator: aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases name_embedding=getattr(entity, 'name_embedding', None), is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记 - group_id=dialog_data.group_id, - user_id=dialog_data.user_id, - apply_id=dialog_data.apply_id, + end_user_id=dialog_data.end_user_id, run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id created_at=dialog_data.created_at, expired_at=dialog_data.expired_at, @@ -1112,9 +1102,7 @@ class ExtractionOrchestrator: source=statement.id, target=entity.id, connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', - group_id=dialog_data.group_id, - user_id=dialog_data.user_id, - apply_id=dialog_data.apply_id, + end_user_id=dialog_data.end_user_id, run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id created_at=dialog_data.created_at, ) @@ -1134,9 +1122,7 @@ class ExtractionOrchestrator: relation_type=triplet.predicate, statement=statement.statement, source_statement_id=statement.id, - group_id=dialog_data.group_id, - user_id=dialog_data.user_id, - apply_id=dialog_data.apply_id, + end_user_id=dialog_data.end_user_id, run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id created_at=dialog_data.created_at, expired_at=dialog_data.expired_at, @@ -1763,14 +1749,14 @@ class ExtractionOrchestrator: async def get_chunked_dialogs( chunker_strategy: str = "RecursiveChunker", - group_id: str = "group_1", + end_user_id: str = "group_1", indices: Optional[List[int]] = None, ) -> List[DialogData]: """从测试数据生成分块对话 Args: chunker_strategy: 分块策略(默认: RecursiveChunker) - group_id: 组ID + end_user_id: 组ID indices: 要处理的数据索引列表(可选) 
Returns: @@ -1834,7 +1820,7 @@ async def get_chunked_dialogs( dialog_data = DialogData( context=conversation_context, ref_id=data['id'], - group_id=group_id, + end_user_id=end_user_id, metadata=dialog_metadata, ) @@ -1936,7 +1922,7 @@ async def get_chunked_dialogs_from_preprocessed( async def get_chunked_dialogs_with_preprocessing( chunker_strategy: str = "RecursiveChunker", - group_id: str = "default", + end_user_id: str = "default", user_id: str = "default", apply_id: str = "default", indices: Optional[List[int]] = None, @@ -1948,7 +1934,7 @@ async def get_chunked_dialogs_with_preprocessing( Args: chunker_strategy: 分块策略 - group_id: 组ID + end_user_id: 组ID user_id: 用户ID apply_id: 应用ID indices: 要处理的数据索引列表 @@ -1976,11 +1962,9 @@ async def get_chunked_dialogs_with_preprocessing( indices=indices, ) - # 设置 group_id, user_id, apply_id + # 设置 end_user_id for dd in preprocessed_data: - dd.group_id = group_id - dd.user_id = user_id - dd.apply_id = apply_id + dd.end_user_id = end_user_id # 步骤2: 语义剪枝 try: diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py index 7e75fd2d..f39313a8 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py @@ -193,9 +193,9 @@ async def _process_chunk_summary( node = MemorySummaryNode( id=uuid4().hex, name=title if title else f"MemorySummaryChunk_{chunk.id}", - group_id=dialog.group_id, - user_id=dialog.user_id, - apply_id=dialog.apply_id, + end_user_id=dialog.end_user_id, + user_id=dialog.end_user_id, + apply_id=dialog.end_user_id, run_id=dialog.run_id, # 使用 dialog 的 run_id created_at=datetime.now(), expired_at=datetime(9999, 12, 31), diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py 
b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py index fb1b539a..b06bd70f 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py @@ -82,12 +82,12 @@ class StatementExtractor: logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty") return None - async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]: + async def _extract_statements(self, chunk, end_user_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]: """Process a single chunk and return extracted statements Args: chunk: Chunk object to process - group_id: Group ID to assign to all statements in this chunk + end_user_id: Group ID to assign to all statements in this chunk dialogue_content: Full dialogue content to provide as context Returns: @@ -158,7 +158,7 @@ class StatementExtractor: temporal_info=temporal_type, relevence_info=relevence_info, chunk_id=chunk.id, - group_id=group_id, + end_user_id=end_user_id, speaker=chunk_speaker, ) @@ -184,10 +184,10 @@ class StatementExtractor: logger.info(f"Processing {len(chunks_to_process)} chunks for statement extraction") - # Process all chunks concurrently, passing the group_id and dialogue content from dialog_data + # Process all chunks concurrently, passing the end_user_id and dialogue content from dialog_data dialogue_content = dialog_data.content if self.config.include_dialogue_context else None results = await asyncio.gather( - *[self._extract_statements(chunk, dialog_data.group_id, dialogue_content) for chunk in chunks_to_process], + *[self._extract_statements(chunk, dialog_data.end_user_id, dialogue_content) for chunk in chunks_to_process], return_exceptions=True ) @@ -225,7 +225,7 @@ class StatementExtractor: for i, 
statement in enumerate(statements, 1): f.write(f"Statement {i}:\n") f.write(f"Id: {statement.id}\n") - f.write(f"Group Id: {statement.group_id}\n") + f.write(f"Group Id: {statement.end_user_id}\n") f.write(f"Content: {statement.statement}\n") f.write(f"Type: {statement.stmt_type.value}\n") f.write(f"Temporal Info: {statement.temporal_info.value}\n") @@ -298,7 +298,7 @@ class StatementExtractor: dialog_sections.append({ "dialog_id": dialog.ref_id, - "group_id": dialog.group_id, + "end_user_id": dialog.end_user_id, "content": dialog.content if getattr(dialog, "content", None) else "", "strong": strong_relations, "weak": weak_relations, @@ -312,7 +312,7 @@ class StatementExtractor: for idx, section in enumerate(dialog_sections, 1): f.write(f"Dialog {idx}:\n") f.write(f"Dialog ID: {section.get('dialog_id', '')}\n") - f.write(f"Group ID: {section.get('group_id', '')}\n") + f.write(f"Group ID: {section.get('end_user_id', '')}\n") f.write("Content:\n") f.write(f"{section.get('content', '')}\n") f.write("-" * 40 + "\n\n") diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/temporal_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/temporal_extraction.py index 9528e638..499027a4 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/temporal_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/temporal_extraction.py @@ -132,7 +132,7 @@ class TemporalExtractor: prompt_logger.info("") prompt_logger.info("=== TEMPORAL EXTRACTION RESULTS ===") prompt_logger.info( - f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}" + f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}" ) except Exception: pass diff --git 
a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py index d3d059b0..bfc0bc88 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py @@ -116,7 +116,7 @@ class TripletExtractor: logger.info(f"Processing {len(all_statements)} statements for triplet extraction...") try: prompt_logger.info( - f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}, statements_to_process={len(all_statements)}" + f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}, statements_to_process={len(all_statements)}" ) except Exception: pass diff --git a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py index 5722769a..a71c0957 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py +++ b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py @@ -75,7 +75,7 @@ class AccessHistoryManager: self, node_id: str, node_label: str, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, current_time: Optional[datetime] = None ) -> Dict[str, Any]: """ @@ -91,7 +91,7 @@ class AccessHistoryManager: Args: node_id: 节点ID node_label: 节点标签(Statement, ExtractedEntity, MemorySummary) - group_id: 组ID(可选,用于过滤) + end_user_id: 组ID(可选,用于过滤) current_time: 当前时间(可选,默认使用系统时间) Returns: @@ -123,7 +123,7 @@ class AccessHistoryManager: for attempt in range(self.max_retries): try: # 步骤1:读取当前节点状态 - node_data = await self._fetch_node(node_id, node_label, group_id) + node_data = await self._fetch_node(node_id, 
node_label, end_user_id) if not node_data: raise ValueError( @@ -142,7 +142,7 @@ class AccessHistoryManager: node_id=node_id, node_label=node_label, update_data=update_data, - group_id=group_id + end_user_id=end_user_id ) logger.info( @@ -172,7 +172,7 @@ class AccessHistoryManager: self, node_ids: List[str], node_label: str, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, current_time: Optional[datetime] = None ) -> List[Dict[str, Any]]: """ @@ -184,7 +184,7 @@ class AccessHistoryManager: Args: node_ids: 节点ID列表 node_label: 节点标签(所有节点必须是同一类型) - group_id: 组ID(可选) + end_user_id: 组ID(可选) current_time: 当前时间(可选) Returns: @@ -202,7 +202,7 @@ class AccessHistoryManager: task = self.record_access( node_id=node_id, node_label=node_label, - group_id=group_id, + end_user_id=end_user_id, current_time=current_time ) tasks.append(task) @@ -235,7 +235,7 @@ class AccessHistoryManager: self, node_id: str, node_label: str, - group_id: Optional[str] = None + end_user_id: Optional[str] = None ) -> Tuple[ConsistencyCheckResult, Optional[str]]: """ 检查节点数据的一致性 @@ -249,14 +249,14 @@ class AccessHistoryManager: Args: node_id: 节点ID node_label: 节点标签 - group_id: 组ID(可选) + end_user_id: 组ID(可选) Returns: Tuple[ConsistencyCheckResult, Optional[str]]: - 一致性检查结果枚举 - 错误描述(如果不一致) """ - node_data = await self._fetch_node(node_id, node_label, group_id) + node_data = await self._fetch_node(node_id, node_label, end_user_id) if not node_data: return ConsistencyCheckResult.CONSISTENT, None @@ -305,7 +305,7 @@ class AccessHistoryManager: async def check_batch_consistency( self, node_label: str, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, limit: int = 1000 ) -> Dict[str, Any]: """ @@ -313,7 +313,7 @@ class AccessHistoryManager: Args: node_label: 节点标签 - group_id: 组ID(可选) + end_user_id: 组ID(可选) limit: 检查的最大节点数 Returns: @@ -329,16 +329,16 @@ class AccessHistoryManager: MATCH (n:{node_label}) WHERE n.access_history IS NOT NULL """ - if group_id: - query += " AND 
n.group_id = $group_id" + if end_user_id: + query += " AND n.end_user_id = $end_user_id" query += """ RETURN n.id as id LIMIT $limit """ params = {"limit": limit} - if group_id: - params["group_id"] = group_id + if end_user_id: + params["end_user_id"] = end_user_id results = await self.connector.execute_query(query, **params) node_ids = [r['id'] for r in results] @@ -351,7 +351,7 @@ class AccessHistoryManager: result, message = await self.check_consistency( node_id=node_id, node_label=node_label, - group_id=group_id + end_user_id=end_user_id ) if result == ConsistencyCheckResult.CONSISTENT: @@ -387,7 +387,7 @@ class AccessHistoryManager: self, node_id: str, node_label: str, - group_id: Optional[str] = None + end_user_id: Optional[str] = None ) -> bool: """ 自动修复节点的数据不一致问题 @@ -401,7 +401,7 @@ class AccessHistoryManager: Args: node_id: 节点ID node_label: 节点标签 - group_id: 组ID(可选) + end_user_id: 组ID(可选) Returns: bool: 修复成功返回True,否则返回False @@ -411,7 +411,7 @@ class AccessHistoryManager: result, message = await self.check_consistency( node_id=node_id, node_label=node_label, - group_id=group_id + end_user_id=end_user_id ) if result == ConsistencyCheckResult.CONSISTENT: @@ -419,7 +419,7 @@ class AccessHistoryManager: return True # 获取节点数据 - node_data = await self._fetch_node(node_id, node_label, group_id) + node_data = await self._fetch_node(node_id, node_label, end_user_id) if not node_data: logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]") return False @@ -457,8 +457,8 @@ class AccessHistoryManager: query = f""" MATCH (n:{node_label} {{id: $node_id}}) """ - if group_id: - query += " WHERE n.group_id = $group_id" + if end_user_id: + query += " WHERE n.end_user_id = $end_user_id" query += """ SET n += $repair_data RETURN n @@ -468,8 +468,8 @@ class AccessHistoryManager: 'node_id': node_id, 'repair_data': repair_data } - if group_id: - params['group_id'] = group_id + if end_user_id: + params['end_user_id'] = end_user_id await self.connector.execute_query(query, **params) @@ 
-491,7 +491,7 @@ class AccessHistoryManager: self, node_id: str, node_label: str, - group_id: Optional[str] = None + end_user_id: Optional[str] = None ) -> Optional[Dict[str, Any]]: """ 获取节点数据 @@ -499,7 +499,7 @@ class AccessHistoryManager: Args: node_id: 节点ID node_label: 节点标签 - group_id: 组ID(可选) + end_user_id: 组ID(可选) Returns: Optional[Dict[str, Any]]: 节点数据,如果不存在返回None @@ -507,8 +507,8 @@ class AccessHistoryManager: query = f""" MATCH (n:{node_label} {{id: $node_id}}) """ - if group_id: - query += " WHERE n.group_id = $group_id" + if end_user_id: + query += " WHERE n.end_user_id = $end_user_id" query += """ RETURN n.id as id, n.importance_score as importance_score, @@ -519,8 +519,8 @@ class AccessHistoryManager: """ params = {'node_id': node_id} - if group_id: - params['group_id'] = group_id + if end_user_id: + params['end_user_id'] = end_user_id results = await self.connector.execute_query(query, **params) @@ -585,7 +585,7 @@ class AccessHistoryManager: node_id: str, node_label: str, update_data: Dict[str, Any], - group_id: Optional[str] = None + end_user_id: Optional[str] = None ) -> Dict[str, Any]: """ 原子性更新节点(使用乐观锁) @@ -597,7 +597,7 @@ class AccessHistoryManager: node_id: 节点ID node_label: 节点标签 update_data: 更新数据 - group_id: 组ID(可选) + end_user_id: 组ID(可选) Returns: Dict[str, Any]: 更新后的节点数据 @@ -606,13 +606,13 @@ class AccessHistoryManager: RuntimeError: 如果更新失败或发生版本冲突 """ # 定义事务函数 - async def update_transaction(tx, node_id, node_label, update_data, group_id): + async def update_transaction(tx, node_id, node_label, update_data, end_user_id): # 步骤1:读取当前节点并获取版本号 read_query = f""" MATCH (n:{node_label} {{id: $node_id}}) """ - if group_id: - read_query += " WHERE n.group_id = $group_id" + if end_user_id: + read_query += " WHERE n.end_user_id = $end_user_id" read_query += """ RETURN n.id as id, n.version as version, @@ -624,8 +624,8 @@ class AccessHistoryManager: """ read_params = {'node_id': node_id} - if group_id: - read_params['group_id'] = group_id + if end_user_id: 
+ read_params['end_user_id'] = end_user_id read_result = await tx.run(read_query, **read_params) current_node = await read_result.single() @@ -656,8 +656,8 @@ class AccessHistoryManager: # 构建 WHERE 子句 where_conditions = [] - if group_id: - where_conditions.append("n.group_id = $group_id") + if end_user_id: + where_conditions.append("n.end_user_id = $end_user_id") # 添加版本检查 if current_version > 0: @@ -695,8 +695,8 @@ class AccessHistoryManager: 'last_access_time': update_data['last_access_time'], 'access_count': update_data['access_count'] } - if group_id: - update_params['group_id'] = group_id + if end_user_id: + update_params['end_user_id'] = end_user_id update_result = await tx.run(update_query, **update_params) updated_node = await update_result.single() @@ -720,7 +720,7 @@ class AccessHistoryManager: node_id=node_id, node_label=node_label, update_data=update_data, - group_id=group_id + end_user_id=end_user_id ) return result except Exception as e: diff --git a/api/app/core/memory/storage_services/forgetting_engine/config_utils.py b/api/app/core/memory/storage_services/forgetting_engine/config_utils.py index ea9a6358..25daa968 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/config_utils.py +++ b/api/app/core/memory/storage_services/forgetting_engine/config_utils.py @@ -11,9 +11,10 @@ Functions: import logging from typing import Optional, Dict, Any +from uuid import UUID from sqlalchemy.orm import Session -from app.repositories.data_config_repository import DataConfigRepository +from app.repositories.memory_config_repository import MemoryConfigRepository from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator @@ -61,12 +62,12 @@ def calculate_forgetting_rate(lambda_time: float, lambda_mem: float) -> float: def load_actr_config_from_db( db: Session, - config_id: Optional[int] = None + config_id: Optional[UUID] = None ) -> Dict[str, Any]: """ 从数据库加载 ACT-R 配置参数 - 从 PostgreSQL 的 data_config 表读取配置参数, + 从 
PostgreSQL 的 memory_config 表读取配置参数, 并计算派生参数(如 forgetting_rate)。 Args: @@ -99,7 +100,7 @@ def load_actr_config_from_db( # 从数据库加载配置 try: - repository = DataConfigRepository() + repository = MemoryConfigRepository() db_config = repository.get_by_id(db, config_id) if db_config is None: @@ -150,7 +151,7 @@ def load_actr_config_from_db( def create_actr_calculator_from_config( db: Session, - config_id: Optional[int] = None + config_id: Optional[UUID] = None ) -> ACTRCalculator: """ 从数据库配置创建 ACTRCalculator 实例 @@ -168,11 +169,6 @@ def create_actr_calculator_from_config( ValueError: 如果指定的 config_id 不存在 Examples: - >>> from sqlalchemy.orm import Session - >>> db = Session() - >>> calculator = create_actr_calculator_from_config(db, config_id=1) - >>> # 使用计算器 - >>> activation = calculator.calculate_memory_activation(...) """ # 加载配置 config = load_actr_config_from_db(db, config_id) diff --git a/api/app/core/memory/storage_services/forgetting_engine/forgetting_scheduler.py b/api/app/core/memory/storage_services/forgetting_engine/forgetting_scheduler.py index 6d42af53..5a178fc2 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/forgetting_scheduler.py +++ b/api/app/core/memory/storage_services/forgetting_engine/forgetting_scheduler.py @@ -16,6 +16,7 @@ Classes: import logging from typing import Dict, Any, Optional +from uuid import UUID from datetime import datetime from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy @@ -66,10 +67,10 @@ class ForgettingScheduler: async def run_forgetting_cycle( self, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, max_merge_batch_size: int = 100, min_days_since_access: int = 30, - config_id: Optional[int] = None, + config_id: Optional[UUID] = None, db = None ) -> Dict[str, Any]: """ @@ -77,7 +78,7 @@ class ForgettingScheduler: Args: - group_id: 组 ID(可选,用于过滤特定组的节点) + end_user_id: 组 ID(可选,用于过滤特定组的节点) max_merge_batch_size: 单次最大融合节点对数(默认 100) 
min_days_since_access: 最小未访问天数(默认 30 天) config_id: 配置ID(可选,用于获取 llm_id) @@ -107,19 +108,19 @@ class ForgettingScheduler: start_time_iso = start_time.isoformat() logger.info( - f"开始遗忘周期: group_id={group_id}, " + f"开始遗忘周期: end_user_id={end_user_id}, " f"max_batch={max_merge_batch_size}, " f"min_days={min_days_since_access}" ) try: # 步骤1:统计遗忘前的节点数量 - nodes_before = await self._count_knowledge_nodes(group_id) + nodes_before = await self._count_knowledge_nodes(end_user_id) logger.info(f"遗忘前节点总数: {nodes_before}") # 步骤2:识别可遗忘的节点对 forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes( - group_id=group_id, + end_user_id=end_user_id, min_days_since_access=min_days_since_access ) @@ -213,7 +214,7 @@ class ForgettingScheduler: 'statement_text': pair['statement_text'], 'statement_activation': pair['statement_activation'], 'statement_importance': pair['statement_importance'], - 'group_id': group_id + 'end_user_id': end_user_id } entity_node = { @@ -222,7 +223,7 @@ class ForgettingScheduler: 'entity_type': pair['entity_type'], 'entity_activation': pair['entity_activation'], 'entity_importance': pair['entity_importance'], - 'group_id': group_id + 'end_user_id': end_user_id } # 融合节点 @@ -262,7 +263,7 @@ class ForgettingScheduler: continue # 步骤6:统计遗忘后的节点数量 - nodes_after = await self._count_knowledge_nodes(group_id) + nodes_after = await self._count_knowledge_nodes(end_user_id) logger.info(f"遗忘后节点总数: {nodes_after}") # 步骤7:生成遗忘报告 @@ -315,7 +316,7 @@ class ForgettingScheduler: async def _count_knowledge_nodes( self, - group_id: Optional[str] = None + end_user_id: Optional[str] = None ) -> int: """ 统计知识层节点总数 @@ -323,7 +324,7 @@ class ForgettingScheduler: 统计 Statement、ExtractedEntity 和 MemorySummary 节点的总数。 Args: - group_id: 组 ID(可选,用于过滤特定组的节点) + end_user_id: 组 ID(可选,用于过滤特定组的节点) Returns: int: 知识层节点总数 @@ -333,16 +334,16 @@ class ForgettingScheduler: WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary) """ - if group_id: - query += " AND n.group_id = $group_id" + 
if end_user_id: + query += " AND n.end_user_id = $end_user_id" query += """ RETURN count(n) as total """ params = {} - if group_id: - params['group_id'] = group_id + if end_user_id: + params['end_user_id'] = end_user_id results = await self.connector.execute_query(query, **params) diff --git a/api/app/core/memory/storage_services/forgetting_engine/forgetting_strategy.py b/api/app/core/memory/storage_services/forgetting_engine/forgetting_strategy.py index ccd8d2ca..a8c62dd4 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/forgetting_strategy.py +++ b/api/app/core/memory/storage_services/forgetting_engine/forgetting_strategy.py @@ -13,6 +13,7 @@ Classes: import logging from typing import List, Dict, Any, Optional +from uuid import UUID from datetime import datetime, timedelta from app.repositories.neo4j.neo4j_connector import Neo4jConnector @@ -90,7 +91,7 @@ class ForgettingStrategy: async def find_forgettable_nodes( self, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, min_days_since_access: int = 30 ) -> List[Dict[str, Any]]: """ @@ -102,7 +103,7 @@ class ForgettingStrategy: 3.
Statement 和 Entity 之间存在关系边 Args: - group_id: 组 ID(可选,用于过滤特定组的节点) + end_user_id: 组 ID(可选,用于过滤特定组的节点) min_days_since_access: 最小未访问天数(默认 30 天) Returns: @@ -136,8 +137,8 @@ class ForgettingStrategy: AND (e.entity_type IS NULL OR e.entity_type <> 'Person') """ - if group_id: - query += " AND s.group_id = $group_id AND e.group_id = $group_id" + if end_user_id: + query += " AND s.end_user_id = $end_user_id AND e.end_user_id = $end_user_id" query += """ RETURN s.id as statement_id, @@ -159,8 +160,8 @@ class ForgettingStrategy: 'threshold': self.forgetting_threshold, 'cutoff_time': cutoff_time_iso } - if group_id: - params['group_id'] = group_id + if end_user_id: + params['end_user_id'] = end_user_id results = await self.connector.execute_query(query, **params) @@ -176,7 +177,7 @@ class ForgettingStrategy: self, statement_node: Dict[str, Any], entity_node: Dict[str, Any], - config_id: Optional[int] = None, + config_id: Optional[UUID] = None, db = None ) -> str: """ @@ -247,8 +248,8 @@ class ForgettingStrategy: entity_activation = entity_node['entity_activation'] entity_importance = entity_node['entity_importance'] - # 获取 group_id(从 statement 或 entity 节点) - group_id = statement_node.get('group_id') or entity_node.get('group_id') + # 获取 end_user_id(从 statement 或 entity 节点) + end_user_id = statement_node.get('end_user_id') or entity_node.get('end_user_id') # 生成摘要内容 summary_text = await self._generate_summary( @@ -325,7 +326,7 @@ class ForgettingStrategy: last_access_time: $current_time, access_count: 1, version: 1, - group_id: $group_id, + end_user_id: $end_user_id, created_at: datetime($current_time), merged_at: datetime($current_time) }) @@ -423,7 +424,7 @@ class ForgettingStrategy: 'inherited_activation': inherited_activation, 'inherited_importance': inherited_importance, 'current_time': current_time_iso, - 'group_id': group_id + 'end_user_id': end_user_id } try: @@ -462,7 +463,7 @@ class ForgettingStrategy: statement_text: str, entity_name: str, entity_type: str, - 
config_id: Optional[int] = None, + config_id: Optional[UUID] = None, db = None ) -> str: """ @@ -527,7 +528,7 @@ class ForgettingStrategy: statement_text, entity_name, entity_type ) - async def _get_llm_client(self, db, config_id: int): + async def _get_llm_client(self, db, config_id: UUID): """ 从数据库获取 LLM 客户端 @@ -539,11 +540,11 @@ class ForgettingStrategy: LLM 客户端实例,如果无法获取则返回 None """ try: - from app.repositories.data_config_repository import DataConfigRepository + from app.repositories.memory_config_repository import MemoryConfigRepository from app.core.memory.utils.llm.llm_utils import MemoryClientFactory # 从数据库读取配置 - repository = DataConfigRepository() + repository = MemoryConfigRepository() db_config = repository.get_by_id(db, config_id) if db_config is None or db_config.llm_id is None: diff --git a/api/app/core/memory/storage_services/search/__init__.py b/api/app/core/memory/storage_services/search/__init__.py index 2bec5bf1..c12c39b0 100644 --- a/api/app/core/memory/storage_services/search/__init__.py +++ b/api/app/core/memory/storage_services/search/__init__.py @@ -37,7 +37,7 @@ __all__ = [ async def run_hybrid_search( query_text: str, search_type: str = "hybrid", - group_id: str | None = None, + end_user_id: str | None = None, apply_id: str | None = None, user_id: str | None = None, limit: int = 50, @@ -54,7 +54,7 @@ async def run_hybrid_search( Args: query_text: 查询文本 search_type: 搜索类型("hybrid", "keyword", "semantic") - group_id: 组ID过滤 + end_user_id: 组ID过滤 apply_id: 应用ID过滤 user_id: 用户ID过滤 limit: 每个类别的最大结果数 @@ -104,7 +104,7 @@ async def run_hybrid_search( # 执行搜索 result = await strategy.search( query_text=query_text, - group_id=group_id, + end_user_id=end_user_id, limit=limit, include=include, alpha=alpha, diff --git a/api/app/core/memory/storage_services/search/hybrid_search.py b/api/app/core/memory/storage_services/search/hybrid_search.py index 43215df5..4111b09c 100644 --- a/api/app/core/memory/storage_services/search/hybrid_search.py +++ 
b/api/app/core/memory/storage_services/search/hybrid_search.py @@ -77,7 +77,7 @@ # async def search( # self, # query_text: str, -# group_id: Optional[str] = None, +# end_user_id: Optional[str] = None, # limit: int = 50, # include: Optional[List[str]] = None, # **kwargs @@ -86,7 +86,7 @@ # Args: # query_text: 查询文本 -# group_id: 可选的组ID过滤 +# end_user_id: 可选的组ID过滤 # limit: 每个类别的最大结果数 # include: 要包含的搜索类别列表 # **kwargs: 其他搜索参数(如alpha, use_forgetting_curve) @@ -94,7 +94,7 @@ # Returns: # SearchResult: 搜索结果对象 # """ -# logger.info(f"执行混合搜索: query='{query_text}', group_id={group_id}, limit={limit}") +# logger.info(f"执行混合搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}") # # 从kwargs中获取参数 # alpha = kwargs.get("alpha", self.alpha) @@ -107,14 +107,14 @@ # # 并行执行关键词搜索和语义搜索 # keyword_result = await self.keyword_strategy.search( # query_text=query_text, -# group_id=group_id, +# end_user_id=end_user_id, # limit=limit, # include=include_list # ) # semantic_result = await self.semantic_strategy.search( # query_text=query_text, -# group_id=group_id, +# end_user_id=end_user_id, # limit=limit, # include=include_list # ) @@ -139,7 +139,7 @@ # metadata = self._create_metadata( # query_text=query_text, # search_type="hybrid", -# group_id=group_id, +# end_user_id=end_user_id, # limit=limit, # include=include_list, # alpha=alpha, @@ -165,7 +165,7 @@ # metadata=self._create_metadata( # query_text=query_text, # search_type="hybrid", -# group_id=group_id, +# end_user_id=end_user_id, # limit=limit, # error=str(e) # ) diff --git a/api/app/core/memory/storage_services/search/keyword_search.py b/api/app/core/memory/storage_services/search/keyword_search.py index 95dd0581..d2591945 100644 --- a/api/app/core/memory/storage_services/search/keyword_search.py +++ b/api/app/core/memory/storage_services/search/keyword_search.py @@ -44,7 +44,7 @@ class KeywordSearchStrategy(SearchStrategy): async def search( self, query_text: str, - group_id: Optional[str] = None, + end_user_id: 
Optional[str] = None, limit: int = 50, include: Optional[List[str]] = None, **kwargs @@ -53,7 +53,7 @@ class KeywordSearchStrategy(SearchStrategy): Args: query_text: 查询文本 - group_id: 可选的组ID过滤 + end_user_id: 可选的组ID过滤 limit: 每个类别的最大结果数 include: 要包含的搜索类别列表 **kwargs: 其他搜索参数 @@ -61,7 +61,7 @@ class KeywordSearchStrategy(SearchStrategy): Returns: SearchResult: 搜索结果对象 """ - logger.info(f"执行关键词搜索: query='{query_text}', group_id={group_id}, limit={limit}") + logger.info(f"执行关键词搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}") # 获取有效的搜索类别 include_list = self._get_include_list(include) @@ -75,7 +75,7 @@ class KeywordSearchStrategy(SearchStrategy): results_dict = await search_graph( connector=self.connector, q=query_text, - group_id=group_id, + end_user_id=end_user_id, limit=limit, include=include_list ) @@ -84,7 +84,7 @@ class KeywordSearchStrategy(SearchStrategy): metadata = self._create_metadata( query_text=query_text, search_type="keyword", - group_id=group_id, + end_user_id=end_user_id, limit=limit, include=include_list ) @@ -115,7 +115,7 @@ class KeywordSearchStrategy(SearchStrategy): metadata=self._create_metadata( query_text=query_text, search_type="keyword", - group_id=group_id, + end_user_id=end_user_id, limit=limit, error=str(e) ) diff --git a/api/app/core/memory/storage_services/search/search_strategy.py b/api/app/core/memory/storage_services/search/search_strategy.py index 27c02c89..3a670dd6 100644 --- a/api/app/core/memory/storage_services/search/search_strategy.py +++ b/api/app/core/memory/storage_services/search/search_strategy.py @@ -58,7 +58,7 @@ class SearchStrategy(ABC): async def search( self, query_text: str, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, limit: int = 50, include: Optional[List[str]] = None, **kwargs @@ -67,7 +67,7 @@ class SearchStrategy(ABC): Args: query_text: 查询文本 - group_id: 可选的组ID过滤 + end_user_id: 可选的组ID过滤 limit: 每个类别的最大结果数 include: 要包含的搜索类别列表(statements, chunks, entities, summaries) 
**kwargs: 其他搜索参数 @@ -81,7 +81,7 @@ class SearchStrategy(ABC): self, query_text: str, search_type: str, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, limit: int = 50, **kwargs ) -> Dict[str, Any]: @@ -90,7 +90,7 @@ class SearchStrategy(ABC): Args: query_text: 查询文本 search_type: 搜索类型 - group_id: 组ID + end_user_id: 组ID limit: 结果限制 **kwargs: 其他元数据 @@ -100,7 +100,7 @@ class SearchStrategy(ABC): metadata = { "query": query_text, "search_type": search_type, - "group_id": group_id, + "end_user_id": end_user_id, "limit": limit, "timestamp": datetime.now().isoformat() } diff --git a/api/app/core/memory/storage_services/search/semantic_search.py b/api/app/core/memory/storage_services/search/semantic_search.py index b20f90a5..8d4eb05f 100644 --- a/api/app/core/memory/storage_services/search/semantic_search.py +++ b/api/app/core/memory/storage_services/search/semantic_search.py @@ -85,7 +85,7 @@ class SemanticSearchStrategy(SearchStrategy): async def search( self, query_text: str, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, limit: int = 50, include: Optional[List[str]] = None, **kwargs @@ -94,7 +94,7 @@ class SemanticSearchStrategy(SearchStrategy): Args: query_text: 查询文本 - group_id: 可选的组ID过滤 + end_user_id: 可选的组ID过滤 limit: 每个类别的最大结果数 include: 要包含的搜索类别列表 **kwargs: 其他搜索参数 @@ -102,7 +102,7 @@ class SemanticSearchStrategy(SearchStrategy): Returns: SearchResult: 搜索结果对象 """ - logger.info(f"执行语义搜索: query='{query_text}', group_id={group_id}, limit={limit}") + logger.info(f"执行语义搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}") # 获取有效的搜索类别 include_list = self._get_include_list(include) @@ -119,7 +119,7 @@ class SemanticSearchStrategy(SearchStrategy): connector=self.connector, embedder_client=self.embedder_client, query_text=query_text, - group_id=group_id, + end_user_id=end_user_id, limit=limit, include=include_list ) @@ -128,7 +128,7 @@ class SemanticSearchStrategy(SearchStrategy): metadata = self._create_metadata( 
query_text=query_text, search_type="semantic", - group_id=group_id, + end_user_id=end_user_id, limit=limit, include=include_list ) @@ -159,7 +159,7 @@ class SemanticSearchStrategy(SearchStrategy): metadata=self._create_metadata( query_text=query_text, search_type="semantic", - group_id=group_id, + end_user_id=end_user_id, limit=limit, error=str(e) ) diff --git a/api/app/core/memory/utils/config/get_data.py b/api/app/core/memory/utils/config/get_data.py index 1de6f6aa..e37ad723 100644 --- a/api/app/core/memory/utils/config/get_data.py +++ b/api/app/core/memory/utils/config/get_data.py @@ -23,7 +23,7 @@ async def _load_(data: List[Any]) -> List[Dict]: target_keys = [ "id", "statement", - "group_id", + "end_user_id", "chunk_id", "created_at", "expired_at", @@ -75,7 +75,7 @@ async def get_data(result): """ EXCLUDE_FIELDS = { "user_id", - "group_id", + "end_user_id", "entity_type", "connect_strength", "relationship_type", diff --git a/api/app/core/memory/utils/log/audit_logger.py b/api/app/core/memory/utils/log/audit_logger.py index 9010aad5..f80ad4d5 100644 --- a/api/app/core/memory/utils/log/audit_logger.py +++ b/api/app/core/memory/utils/log/audit_logger.py @@ -62,7 +62,7 @@ class ConfigAuditLogger: self, config_id: str, user_id: Optional[str] = None, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, success: bool = True, details: Optional[Dict[str, Any]] = None ): @@ -72,14 +72,14 @@ class ConfigAuditLogger: Args: config_id: 配置 ID user_id: 用户 ID(可选) - group_id: 组 ID(可选) + end_user_id: 组 ID(可选) success: 是否成功 details: 详细信息(可选) """ result = "SUCCESS" if success else "FAILED" msg = ( f"CONFIG_LOAD config_id={config_id} " - f"user={user_id or 'N/A'} group={group_id or 'N/A'} " + f"user={user_id or 'N/A'} group={end_user_id or 'N/A'} " f"result={result}" ) if details: @@ -121,7 +121,7 @@ class ConfigAuditLogger: self, operation: str, config_id: str, - group_id: str, + end_user_id: str, success: bool = True, duration: Optional[float] = None, error: 
Optional[str] = None, @@ -133,7 +133,7 @@ class ConfigAuditLogger: Args: operation: 操作类型(WRITE, READ 等) config_id: 配置 ID - group_id: 组 ID + end_user_id: 组 ID success: 是否成功 duration: 操作耗时(秒) error: 错误信息(可选) @@ -142,7 +142,7 @@ class ConfigAuditLogger: result = "SUCCESS" if success else "FAILED" msg = ( f"{operation.upper()} config_id={config_id} " - f"group={group_id} result={result}" + f"group={end_user_id} result={result}" ) if duration is not None: msg += f" duration={duration:.2f}s" diff --git a/api/app/core/rag/vdb/field.py b/api/app/core/rag/vdb/field.py index 86d39060..99d872c2 100644 --- a/api/app/core/rag/vdb/field.py +++ b/api/app/core/rag/vdb/field.py @@ -4,7 +4,7 @@ from enum import StrEnum, auto class Field(StrEnum): CONTENT_KEY = "page_content" METADATA_KEY = "metadata" - GROUP_KEY = "group_id" + GROUP_KEY = "end_user_id" VECTOR = auto() # Sparse Vector aims to support full text search SPARSE_VECTOR = auto() diff --git a/api/app/core/validators/memory_config_validators.py b/api/app/core/validators/memory_config_validators.py index 333572e6..ba26c5f2 100644 --- a/api/app/core/validators/memory_config_validators.py +++ b/api/app/core/validators/memory_config_validators.py @@ -26,7 +26,7 @@ logger = get_config_logger() def _parse_model_id(model_id: Union[str, UUID, None], model_type: str, - config_id: Optional[int] = None, workspace_id: Optional[UUID] = None) -> Optional[UUID]: + config_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None) -> Optional[UUID]: """Parse model ID from string or UUID.""" if model_id is None: return None @@ -59,7 +59,7 @@ def validate_model_exists_and_active( model_type: str, db: Session, tenant_id: Optional[UUID] = None, - config_id: Optional[int] = None, + config_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None ) -> tuple[str, bool]: """Validate that a model exists and is active. 
@@ -166,7 +166,7 @@ def validate_and_resolve_model_id( db: Session, tenant_id: Optional[UUID] = None, required: bool = False, - config_id: Optional[int] = None, + config_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None ) -> tuple[Optional[UUID], Optional[str]]: """Validate and resolve a model ID, checking existence and active status. @@ -204,7 +204,7 @@ def validate_and_resolve_model_id( def validate_embedding_model( - config_id: int, + config_id: UUID, embedding_id: Union[str, UUID, None], db: Session, tenant_id: Optional[UUID] = None, @@ -256,7 +256,7 @@ def validate_embedding_model( def validate_llm_model( - config_id: int, + config_id: UUID, llm_id: Union[str, UUID, None], db: Session, tenant_id: Optional[UUID] = None, diff --git a/api/app/core/workflow/nodes/memory/config.py b/api/app/core/workflow/nodes/memory/config.py index 987230c1..4c8c43eb 100644 --- a/api/app/core/workflow/nodes/memory/config.py +++ b/api/app/core/workflow/nodes/memory/config.py @@ -1,4 +1,5 @@ import uuid +from uuid import UUID from pydantic import Field from typing import Literal @@ -11,7 +12,7 @@ class MemoryReadNodeConfig(BaseNodeConfig): ... ) - config_id: int = Field( + config_id: UUID = Field( ... ) @@ -26,6 +27,6 @@ class MemoryWriteNodeConfig(BaseNodeConfig): ... ) - config_id: int = Field( + config_id: UUID = Field( ... 
) diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index 08a2b280..0589cc82 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -22,7 +22,7 @@ class MemoryReadNode(BaseNode): raise RuntimeError("End user id is required") return await MemoryAgentService().read_memory( - group_id=end_user_id, + end_user_id=end_user_id, message=self._render_template(self.typed_config.message, state), config_id=str(self.typed_config.config_id), search_switch=self.typed_config.search_switch, diff --git a/api/app/models/__init__.py b/api/app/models/__init__.py index bf3a1b3d..e069b40d 100644 --- a/api/app/models/__init__.py +++ b/api/app/models/__init__.py @@ -18,7 +18,7 @@ from .appshare_model import AppShare from .release_share_model import ReleaseShare from .conversation_model import Conversation, Message from .api_key_model import ApiKey, ApiKeyLog, ApiKeyType -from .data_config_model import DataConfig +from .memory_config_model import MemoryConfig from .multi_agent_model import MultiAgentConfig, AgentInvocation from .workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution from .retrieval_info import RetrievalInfo @@ -57,7 +57,7 @@ __all__ = [ "ApiKey", "ApiKeyLog", "ApiKeyType", - "DataConfig", + "MemoryConfig", "MultiAgentConfig", "AgentInvocation", "WorkflowConfig", diff --git a/api/app/models/data_config_model.py b/api/app/models/data_config_model.py deleted file mode 100644 index 06f87cb2..00000000 --- a/api/app/models/data_config_model.py +++ /dev/null @@ -1,88 +0,0 @@ -import datetime -from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float -from sqlalchemy.dialects.postgresql import UUID -from app.db import Base - - -class DataConfig(Base): - """数据配置表 - 用于存储记忆系统的配置参数""" - __tablename__ = "data_config" - - # 主键 - config_id = Column(Integer, primary_key=True, autoincrement=True, comment="配置ID") - - # 基本信息 - config_name = 
Column(String, nullable=False, comment="配置名称") - config_desc = Column(String, nullable=True, comment="配置描述") - - # 组织信息 - workspace_id = Column(UUID(as_uuid=True), nullable=True, comment="工作空间ID") - group_id = Column(String, nullable=True, comment="组ID") - user_id = Column(String, nullable=True, comment="用户ID") - apply_id = Column(String, nullable=True, comment="应用ID") - - # 模型选择(从workspace继承) - llm_id = Column(String, nullable=True, comment="LLM模型配置ID") - embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID") - rerank_id = Column(String, nullable=True, comment="重排序模型配置ID") - - # 记忆萃取引擎配置 - enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重") - enable_llm_disambiguation = Column(Boolean, default=True, comment="启用LLM决策消歧") - deep_retrieval = Column(Boolean, default=True, comment="深度检索开关") - - # 阈值配置 (0-1 之间的浮点数) - t_type_strict = Column(Float, default=0.8, comment="类型严格阈值") - t_name_strict = Column(Float, default=0.8, comment="名称严格阈值") - t_overall = Column(Float, default=0.8, comment="综合阈值") - - # 状态配置 - state = Column(Boolean, default=False, comment="配置使用状态") - - # 分块策略 - chunker_strategy = Column(String, default="RecursiveChunker", comment="分块策略") - - # 剪枝配置 - pruning_enabled = Column(Boolean, default=False, comment="是否启动智能语义剪枝") - pruning_scene = Column(String, nullable=True, comment="智能剪枝场景:education/online_service/outbound") - pruning_threshold = Column(Float, nullable=True, comment="智能语义剪枝阈值(0-0.9)") - - # 自我反思配置 - enable_self_reflexion = Column(Boolean, default=False, comment="是否启用自我反思") - iteration_period = Column(String, default="3", comment="反思迭代周期") - reflexion_range = Column(String, default="partial", comment="反思范围:部分/全部") - baseline = Column(String, default="TIME", comment="基线:时间/事实/时间和事实") - reflection_model_id = Column(String, nullable=True, comment="反思模型ID") - memory_verify = Column(Boolean, default=True, comment="记忆验证") - quality_assessment = Column(Boolean, default=True, comment="质量评估") - - # 遗忘引擎配置 - 
statement_granularity = Column(Integer, default=2, comment="陈述提取颗粒度,挡位 1/2/3") - include_dialogue_context = Column(Boolean, default=False, comment="是否包含对话上下文") - max_context = Column(Integer, default=1000, comment="对话语境中包含字符的最大数量") - lambda_time = Column("lambda_time", Float, default=0.5, comment="最低保持度,0-1 小数") - lambda_mem = Column("lambda_mem", Float, default=0.5, comment="遗忘率,0-1 小数") - offset = Column("offset", Float, default=0.0, comment="偏移度,0-1 小数") - - # ACT-R 遗忘引擎配置 - decay_constant = Column(Float, default=0.5, comment="ACT-R衰减常数d,默认0.5") - forgetting_threshold = Column(Float, default=0.3, comment="遗忘阈值,默认0.3") - forgetting_interval_hours = Column(Integer, default=24, comment="遗忘周期间隔(小时),默认24") - enable_llm_summary = Column(Boolean, default=True, comment="是否使用LLM生成摘要,默认True") - max_merge_batch_size = Column(Integer, default=100, comment="单次最大融合节点对数,默认100") - max_history_length = Column(Integer, default=100, comment="访问历史最大长度,默认100") - min_days_since_access = Column(Integer, default=30, comment="最小未访问天数,默认30") - - # 情绪引擎配置 - emotion_enabled = Column(Boolean, default=True, comment="是否启用情绪提取") - emotion_model_id = Column(String, nullable=True, comment="情绪分析专用模型ID") - emotion_extract_keywords = Column(Boolean, default=True, comment="是否提取情绪关键词") - emotion_min_intensity = Column(Float, default=0.1, comment="最小情绪强度阈值") - emotion_enable_subject = Column(Boolean, default=True, comment="是否启用主体分类") - - # 时间戳 - created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间") - updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") - - def __repr__(self): - return f"" diff --git a/api/app/models/memory_config_model.py b/api/app/models/memory_config_model.py index d47c3b52..b468e2a2 100644 --- a/api/app/models/memory_config_model.py +++ b/api/app/models/memory_config_model.py @@ -1,39 +1,88 @@ -# -*- coding: utf-8 -*- -"""Memory Configuration Model - Backward Compatibility +import datetime +from 
sqlalchemy import Column, String, Boolean, DateTime, Integer, Float +from sqlalchemy.dialects.postgresql import UUID +from app.db import Base -This module provides backward compatibility for imports. -All classes have been moved to app.schemas.memory_config_schema. -DEPRECATED: Import from app.schemas.memory_config_schema instead. -""" +class MemoryConfig(Base): + """记忆配置表 - 用于存储记忆系统的配置参数""" + __tablename__ = "memory_config" -# Re-export for backward compatibility -from app.schemas.memory_config_schema import ( - ConfigurationError, - InvalidConfigError, - MemoryConfig, - MemoryConfigValidation, - ModelInactiveError, - ModelNotFoundError, - ModelValidation, - WorkspaceNotFoundError, - WorkspaceValidation, - validate_memory_config_data, - validate_model_data, - validate_workspace_data, -) + # 主键 + config_id = Column(UUID(as_uuid=True), primary_key=True, comment="配置ID") -__all__ = [ - "ConfigurationError", - "InvalidConfigError", - "MemoryConfig", - "MemoryConfigValidation", - "ModelInactiveError", - "ModelNotFoundError", - "ModelValidation", - "WorkspaceNotFoundError", - "WorkspaceValidation", - "validate_memory_config_data", - "validate_model_data", - "validate_workspace_data", -] + # 基本信息 + config_name = Column(String, nullable=False, comment="配置名称") + config_desc = Column(String, nullable=True, comment="配置描述") + + # 组织信息 + workspace_id = Column(UUID(as_uuid=True), nullable=True, comment="工作空间ID") + end_user_id = Column(String, nullable=True, comment="组ID") + user_id = Column(String, nullable=True, comment="用户ID") + apply_id = Column(String, nullable=True, comment="应用ID") + + # 模型选择(从workspace继承) + llm_id = Column(String, nullable=True, comment="LLM模型配置ID") + embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID") + rerank_id = Column(String, nullable=True, comment="重排序模型配置ID") + + # 记忆萃取引擎配置 + enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重") + enable_llm_disambiguation = Column(Boolean, default=True, comment="启用LLM决策消歧") + 
deep_retrieval = Column(Boolean, default=True, comment="深度检索开关") + + # 阈值配置 (0-1 之间的浮点数) + t_type_strict = Column(Float, default=0.8, comment="类型严格阈值") + t_name_strict = Column(Float, default=0.8, comment="名称严格阈值") + t_overall = Column(Float, default=0.8, comment="综合阈值") + + # 状态配置 + state = Column(Boolean, default=False, comment="配置使用状态") + + # 分块策略 + chunker_strategy = Column(String, default="RecursiveChunker", comment="分块策略") + + # 剪枝配置 + pruning_enabled = Column(Boolean, default=False, comment="是否启动智能语义剪枝") + pruning_scene = Column(String, nullable=True, comment="智能剪枝场景:education/online_service/outbound") + pruning_threshold = Column(Float, nullable=True, comment="智能语义剪枝阈值(0-0.9)") + + # 自我反思配置 + enable_self_reflexion = Column(Boolean, default=False, comment="是否启用自我反思") + iteration_period = Column(String, default="3", comment="反思迭代周期") + reflexion_range = Column(String, default="partial", comment="反思范围:部分/全部") + baseline = Column(String, default="TIME", comment="基线:时间/事实/时间和事实") + reflection_model_id = Column(String, nullable=True, comment="反思模型ID") + memory_verify = Column(Boolean, default=True, comment="记忆验证") + quality_assessment = Column(Boolean, default=True, comment="质量评估") + + # 遗忘引擎配置 + statement_granularity = Column(Integer, default=2, comment="陈述提取颗粒度,挡位 1/2/3") + include_dialogue_context = Column(Boolean, default=False, comment="是否包含对话上下文") + max_context = Column(Integer, default=1000, comment="对话语境中包含字符的最大数量") + lambda_time = Column("lambda_time", Float, default=0.5, comment="最低保持度,0-1 小数") + lambda_mem = Column("lambda_mem", Float, default=0.5, comment="遗忘率,0-1 小数") + offset = Column("offset", Float, default=0.0, comment="偏移度,0-1 小数") + + # ACT-R 遗忘引擎配置 + decay_constant = Column(Float, default=0.5, comment="ACT-R衰减常数d,默认0.5") + forgetting_threshold = Column(Float, default=0.3, comment="遗忘阈值,默认0.3") + forgetting_interval_hours = Column(Integer, default=24, comment="遗忘周期间隔(小时),默认24") + enable_llm_summary = Column(Boolean, default=True, 
comment="是否使用LLM生成摘要,默认True") + max_merge_batch_size = Column(Integer, default=100, comment="单次最大融合节点对数,默认100") + max_history_length = Column(Integer, default=100, comment="访问历史最大长度,默认100") + min_days_since_access = Column(Integer, default=30, comment="最小未访问天数,默认30") + + # 情绪引擎配置 + emotion_enabled = Column(Boolean, default=True, comment="是否启用情绪提取") + emotion_model_id = Column(String, nullable=True, comment="情绪分析专用模型ID") + emotion_extract_keywords = Column(Boolean, default=True, comment="是否提取情绪关键词") + emotion_min_intensity = Column(Float, default=0.1, comment="最小情绪强度阈值") + emotion_enable_subject = Column(Boolean, default=True, comment="是否启用主体分类") + + # 时间戳 + created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间") + updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") + + def __repr__(self): + return f"" diff --git a/api/app/models/memory_perceptual_model.py b/api/app/models/memory_perceptual_model.py index 59eb0222..cafb18d4 100644 --- a/api/app/models/memory_perceptual_model.py +++ b/api/app/models/memory_perceptual_model.py @@ -16,7 +16,7 @@ class PerceptualType(IntEnum): CONVERSATION = 4 -class FileStorageType(IntEnum): +class FileStorageService(IntEnum): LOCAL = 1 REMOTE = 2 diff --git a/api/app/repositories/data_config_repository.py b/api/app/repositories/memory_config_repository.py similarity index 73% rename from api/app/repositories/data_config_repository.py rename to api/app/repositories/memory_config_repository.py index 3df7f800..12e564e2 100644 --- a/api/app/repositories/data_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -1,18 +1,19 @@ # -*- coding: utf-8 -*- -"""数据配置Repository模块 +"""记忆配置Repository模块 -本模块提供data_config表的数据访问层,使用SQLAlchemy ORM进行数据库操作。 +本模块提供memory_config表的数据访问层,使用SQLAlchemy ORM进行数据库操作。 包括CRUD操作和Neo4j Cypher查询常量。 Classes: - DataConfigRepository: 数据配置仓储类,提供CRUD操作 + MemoryConfigRepository: 记忆配置仓储类,提供CRUD操作 """ import uuid +from uuid 
import UUID from typing import Dict, List, Optional, Tuple from app.core.exceptions import BusinessException from app.core.logging_config import get_config_logger, get_db_logger -from app.models.data_config_model import DataConfig +from app.models.memory_config_model import MemoryConfig from app.schemas.memory_storage_schema import ( ConfigKey, ConfigParamsCreate, @@ -28,11 +29,11 @@ db_logger = get_db_logger() # 获取配置专用日志器 config_logger = get_config_logger() -TABLE_NAME = "data_config" -class DataConfigRepository: - """数据配置Repository +TABLE_NAME = "memory_config" +class MemoryConfigRepository: + """记忆配置Repository - 提供data_config表的数据访问方法,包括: + 提供memory_config表的数据访问方法,包括: - SQLAlchemy ORM 数据库操作 - Neo4j Cypher查询常量 """ @@ -41,48 +42,48 @@ class DataConfigRepository: # Dialogue count by group SEARCH_FOR_DIALOGUE = """ - MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN COUNT(n) AS num + MATCH (n:Dialogue) WHERE n.end_user_id = $end_user_id RETURN COUNT(n) AS num """ # Chunk count by group SEARCH_FOR_CHUNK = """ - MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN COUNT(n) AS num + MATCH (n:Chunk) WHERE n.end_user_id = $end_user_id RETURN COUNT(n) AS num """ # Statement count by group SEARCH_FOR_STATEMENT = """ - MATCH (n:Statement) WHERE n.group_id = $group_id RETURN COUNT(n) AS num + MATCH (n:Statement) WHERE n.end_user_id = $end_user_id RETURN COUNT(n) AS num """ # ExtractedEntity count by group SEARCH_FOR_ENTITY = """ - MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN COUNT(n) AS num + MATCH (n:ExtractedEntity) WHERE n.end_user_id = $end_user_id RETURN COUNT(n) AS num """ # All counts by label and total SEARCH_FOR_ALL = """ - OPTIONAL MATCH (n:Dialogue) WHERE n.group_id = $group_id RETURN 'Dialogue' AS Label, COUNT(n) AS Count + OPTIONAL MATCH (n:Dialogue) WHERE n.end_user_id = $end_user_id RETURN 'Dialogue' AS Label, COUNT(n) AS Count UNION ALL - OPTIONAL MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN 'Chunk' AS Label, COUNT(n) AS Count + 
OPTIONAL MATCH (n:Chunk) WHERE n.end_user_id = $end_user_id RETURN 'Chunk' AS Label, COUNT(n) AS Count UNION ALL - OPTIONAL MATCH (n:Statement) WHERE n.group_id = $group_id RETURN 'Statement' AS Label, COUNT(n) AS Count + OPTIONAL MATCH (n:Statement) WHERE n.end_user_id = $end_user_id RETURN 'Statement' AS Label, COUNT(n) AS Count UNION ALL - OPTIONAL MATCH (n:ExtractedEntity) WHERE n.group_id = $group_id RETURN 'ExtractedEntity' AS Label, COUNT(n) AS Count + OPTIONAL MATCH (n:ExtractedEntity) WHERE n.end_user_id = $end_user_id RETURN 'ExtractedEntity' AS Label, COUNT(n) AS Count UNION ALL - OPTIONAL MATCH (n) WHERE n.group_id = $group_id RETURN 'ALL' AS Label, COUNT(n) AS Count + OPTIONAL MATCH (n) WHERE n.end_user_id = $end_user_id RETURN 'ALL' AS Label, COUNT(n) AS Count """ # Extracted entity details within group/app/user SEARCH_FOR_DETIALS = """ MATCH (n:ExtractedEntity) - WHERE n.group_id = $group_id + WHERE n.end_user_id = $end_user_id RETURN n.entity_idx AS entity_idx, n.connect_strength AS connect_strength, n.description AS description, n.entity_type AS entity_type, n.name AS name, COALESCE(n.fact_summary, '') AS fact_summary, - n.group_id AS group_id, + n.end_user_id AS end_user_id, n.apply_id AS apply_id, n.user_id AS user_id, n.id AS id @@ -91,9 +92,9 @@ class DataConfigRepository: # Edges between extracted entities within group/app/user SEARCH_FOR_EDGES = """ MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity) - WHERE n.group_id = $group_id + WHERE n.end_user_id = $end_user_id RETURN - r.group_id AS group_id, + r.end_user_id AS end_user_id, r.apply_id AS apply_id, r.user_id AS user_id, elementId(r) AS rel_id, @@ -107,7 +108,7 @@ class DataConfigRepository: @staticmethod def update_reflection_config( db: Session, - config_id: int, + config_id: uuid.UUID, enable_self_reflexion: bool, iteration_period: str, reflexion_range: str, @@ -115,7 +116,7 @@ class DataConfigRepository: reflection_model_id: str, memory_verify: bool, quality_assessment: bool - ) -> 
DataConfig: + ) -> MemoryConfig: """构建反思配置更新语句(SQLAlchemy text() 命名参数) Args: @@ -130,28 +131,28 @@ class DataConfigRepository: config_id: 配置ID Returns: - Data + MemoryConfig Raises: ValueError: 没有字段需要更新时抛出 """ db_logger.debug(f"构建反思配置更新语句: config_id={config_id}") - stmt = select(DataConfig).where(DataConfig.config_id == config_id) - data_config_obj = db.scalars(stmt).first() - if not data_config_obj: + stmt = select(MemoryConfig).where(MemoryConfig.config_id == config_id) + memory_config_obj = db.scalars(stmt).first() + if not memory_config_obj: raise BusinessException - data_config_obj.enable_self_reflexion = enable_self_reflexion - data_config_obj.iteration_period = iteration_period - data_config_obj.reflexion_range = reflexion_range - data_config_obj.baseline = baseline - data_config_obj.reflection_model_id = reflection_model_id - data_config_obj.memory_verify = memory_verify - data_config_obj.quality_assessment = quality_assessment + memory_config_obj.enable_self_reflexion = enable_self_reflexion + memory_config_obj.iteration_period = iteration_period + memory_config_obj.reflexion_range = reflexion_range + memory_config_obj.baseline = baseline + memory_config_obj.reflection_model_id = reflection_model_id + memory_config_obj.memory_verify = memory_verify + memory_config_obj.quality_assessment = quality_assessment - return data_config_obj + return memory_config_obj @staticmethod - def query_reflection_config_by_id(db: Session, config_id: int) -> DataConfig: + def query_reflection_config_by_id(db: Session, config_id: uuid.UUID) -> MemoryConfig: """构建反思配置查询语句,通过config_id查询反思配置(SQLAlchemy text() 命名参数) Args: @@ -162,13 +163,13 @@ class DataConfigRepository: Tuple[str, Dict]: (SQL查询字符串, 参数字典) """ db_logger.debug(f"构建反思配置查询语句: config_id={config_id}") - stmt = select(DataConfig).where(DataConfig.config_id == config_id) - data_config = db.scalars(stmt).first() - if not data_config: + stmt = select(MemoryConfig).where(MemoryConfig.config_id == config_id) + memory_config = 
db.scalars(stmt).first() + if not memory_config: raise RuntimeError("reflection config not found") - return data_config + return memory_config @staticmethod - def query_reflection_config_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> DataConfig: + def query_reflection_config_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> MemoryConfig: """构建查询所有配置的语句(SQLAlchemy text() 命名参数) Args: @@ -180,11 +181,11 @@ class DataConfigRepository: """ db_logger.debug(f"构建查询所有配置语句: workspace_id={workspace_id}") - stmt = select(DataConfig).where(DataConfig.workspace_id == workspace_id) - data_config = db.scalars(stmt).first() - if not data_config: + stmt = select(MemoryConfig).where(MemoryConfig.workspace_id == workspace_id) + memory_config = db.scalars(stmt).first() + if not memory_config: raise RuntimeError("reflection config not found") - return data_config + return memory_config @staticmethod @@ -208,20 +209,21 @@ class DataConfigRepository: return query, params @staticmethod - def create(db: Session, params: ConfigParamsCreate) -> DataConfig: - """创建数据配置 + def create(db: Session, params: ConfigParamsCreate) -> MemoryConfig: + """创建记忆配置 Args: db: 数据库会话 params: 配置参数创建模型 Returns: - DataConfig: 创建的配置对象 + MemoryConfig: 创建的配置对象 """ - db_logger.debug(f"创建数据配置: config_name={params.config_name}, workspace_id={params.workspace_id}") + db_logger.debug(f"创建记忆配置: config_name={params.config_name}, workspace_id={params.workspace_id}") try: - db_config = DataConfig( + db_config = MemoryConfig( + config_id=uuid.uuid4(), config_name=params.config_name, config_desc=params.config_desc, workspace_id=params.workspace_id, @@ -232,16 +234,16 @@ class DataConfigRepository: db.add(db_config) db.flush() # 获取自增ID但不提交事务 - db_logger.info(f"数据配置已添加到会话: {db_config.config_name} (ID: {db_config.config_id})") + db_logger.info(f"记忆配置已添加到会话: {db_config.config_name} (ID: {db_config.config_id})") return db_config except Exception as e: db.rollback() - db_logger.error(f"创建数据配置失败: {params.config_name} 
- {str(e)}") + db_logger.error(f"创建记忆配置失败: {params.config_name} - {str(e)}") raise @staticmethod - def update(db: Session, update: ConfigUpdate) -> Optional[DataConfig]: + def update(db: Session, update: ConfigUpdate) -> Optional[MemoryConfig]: """更新基础配置 Args: @@ -249,17 +251,17 @@ class DataConfigRepository: update: 配置更新模型 Returns: - Optional[DataConfig]: 更新后的配置对象,不存在则返回None + Optional[MemoryConfig]: 更新后的配置对象,不存在则返回None Raises: ValueError: 没有字段需要更新时抛出 """ - db_logger.debug(f"更新数据配置: config_id={update.config_id}") + db_logger.debug(f"更新记忆配置: config_id={update.config_id}") try: - db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first() + db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first() if not db_config: - db_logger.warning(f"数据配置不存在: config_id={update.config_id}") + db_logger.warning(f"记忆配置不存在: config_id={update.config_id}") return None # 更新字段 @@ -277,17 +279,17 @@ class DataConfigRepository: db.commit() db.refresh(db_config) - db_logger.info(f"数据配置更新成功: {db_config.config_name} (ID: {update.config_id})") + db_logger.info(f"记忆配置更新成功: {db_config.config_name} (ID: {update.config_id})") return db_config except Exception as e: db.rollback() - db_logger.error(f"更新数据配置失败: config_id={update.config_id} - {str(e)}") + db_logger.error(f"更新记忆配置失败: config_id={update.config_id} - {str(e)}") raise @staticmethod - def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[DataConfig]: + def update_extracted(db: Session, update: ConfigUpdateExtracted) -> Optional[MemoryConfig]: """更新记忆萃取引擎配置 Args: @@ -295,7 +297,7 @@ class DataConfigRepository: update: 萃取配置更新模型 Returns: - Optional[DataConfig]: 更新后的配置对象,不存在则返回None + Optional[MemoryConfig]: 更新后的配置对象,不存在则返回None Raises: ValueError: 没有字段需要更新时抛出 @@ -303,9 +305,9 @@ class DataConfigRepository: db_logger.debug(f"更新萃取配置: config_id={update.config_id}") try: - db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first() 
+ db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first() if not db_config: - db_logger.warning(f"数据配置不存在: config_id={update.config_id}") + db_logger.warning(f"记忆配置不存在: config_id={update.config_id}") return None # 更新字段映射 @@ -360,7 +362,7 @@ class DataConfigRepository: raise @staticmethod - def update_forget(db: Session, update: ConfigUpdateForget) -> Optional[DataConfig]: + def update_forget(db: Session, update: ConfigUpdateForget) -> Optional[MemoryConfig]: """更新遗忘引擎配置 Args: @@ -368,7 +370,7 @@ class DataConfigRepository: update: 遗忘配置更新模型 Returns: - Optional[DataConfig]: 更新后的配置对象,不存在则返回None + Optional[MemoryConfig]: 更新后的配置对象,不存在则返回None Raises: ValueError: 没有字段需要更新时抛出 @@ -376,9 +378,9 @@ class DataConfigRepository: db_logger.debug(f"更新遗忘配置: config_id={update.config_id}") try: - db_config = db.query(DataConfig).filter(DataConfig.config_id == update.config_id).first() + db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == update.config_id).first() if not db_config: - db_logger.warning(f"数据配置不存在: config_id={update.config_id}") + db_logger.warning(f"记忆配置不存在: config_id={update.config_id}") return None # 更新字段 @@ -408,7 +410,7 @@ class DataConfigRepository: raise @staticmethod - def get_extracted_config(db: Session, config_id: int) -> Optional[Dict]: + def get_extracted_config(db: Session, config_id: UUID) -> Optional[Dict]: """获取萃取配置,通过主键查询某条配置 Args: @@ -421,7 +423,7 @@ class DataConfigRepository: db_logger.debug(f"查询萃取配置: config_id={config_id}") try: - db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first() + db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first() if not db_config: db_logger.debug(f"萃取配置不存在: config_id={config_id}") return None @@ -457,7 +459,7 @@ class DataConfigRepository: raise @staticmethod - def get_forget_config(db: Session, config_id: int) -> Optional[Dict]: + def get_forget_config(db: Session, config_id: UUID) -> Optional[Dict]: 
"""获取遗忘配置,通过主键查询某条配置 Args: @@ -470,7 +472,7 @@ class DataConfigRepository: db_logger.debug(f"查询遗忘配置: config_id={config_id}") try: - db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first() + db_config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first() if not db_config: db_logger.debug(f"遗忘配置不存在: config_id={config_id}") return None @@ -489,39 +491,39 @@ class DataConfigRepository: raise @staticmethod - def get_by_id(db: Session, config_id: int) -> Optional[DataConfig]: - """根据ID获取数据配置 + def get_by_id(db: Session, config_id: uuid.UUID) -> Optional[MemoryConfig]: + """根据ID获取记忆配置 Args: db: 数据库会话 config_id: 配置ID Returns: - Optional[DataConfig]: 配置对象,不存在则返回None + Optional[MemoryConfig]: 配置对象,不存在则返回None """ - db_logger.debug(f"根据ID查询数据配置: config_id={config_id}") + db_logger.debug(f"根据ID查询记忆配置: config_id={config_id}") try: - config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first() + config = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first() if config: - db_logger.debug(f"数据配置查询成功: {config.config_name} (ID: {config_id})") + db_logger.debug(f"记忆配置查询成功: {config.config_name} (ID: {config_id})") else: - db_logger.debug(f"数据配置不存在: config_id={config_id}") + db_logger.debug(f"记忆配置不存在: config_id={config_id}") return config except Exception as e: - db_logger.error(f"根据ID查询数据配置失败: config_id={config_id} - {str(e)}") + db_logger.error(f"根据ID查询记忆配置失败: config_id={config_id} - {str(e)}") raise @staticmethod - def get_config_with_workspace(db: Session, config_id: int) -> Optional[tuple]: - """Get data config and its associated workspace information + def get_config_with_workspace(db: Session, config_id: uuid.UUID) -> Optional[tuple]: + """Get memory config and its associated workspace information Args: db: Database session config_id: Configuration ID Returns: - Optional[tuple]: (DataConfig, Workspace) tuple, None if not found + Optional[tuple]: (MemoryConfig, Workspace) tuple, None if not 
found Raises: ValueError: Raised when config exists but workspace doesn't @@ -541,19 +543,19 @@ class DataConfigRepository: } ) - db_logger.debug(f"Querying data config and workspace: config_id={config_id}") + db_logger.debug(f"Querying memory config and workspace: config_id={config_id}") try: # Use join query to get both config and workspace - result = db.query(DataConfig, Workspace).join( - Workspace, DataConfig.workspace_id == Workspace.id - ).filter(DataConfig.config_id == config_id).first() + result = db.query(MemoryConfig, Workspace).join( + Workspace, MemoryConfig.workspace_id == Workspace.id + ).filter(MemoryConfig.config_id == config_id).first() elapsed_ms = (time.time() - start_time) * 1000 if not result: # Check if config exists but workspace is missing - config_only = db.query(DataConfig).filter(DataConfig.config_id == config_id).first() + config_only = db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first() if config_only: if config_only.workspace_id is None: config_logger.error( @@ -566,7 +568,7 @@ class DataConfigRepository: "elapsed_ms": elapsed_ms } ) - db_logger.error(f"Data config {config_id} has no associated workspace ID") + db_logger.error(f"Memory config {config_id} has no associated workspace ID") raise ValueError(f"Configuration {config_id} has no associated workspace") else: config_logger.error( @@ -579,7 +581,7 @@ class DataConfigRepository: "elapsed_ms": elapsed_ms } ) - db_logger.error(f"Data config {config_id} references non-existent workspace {config_only.workspace_id}") + db_logger.error(f"Memory config {config_id} references non-existent workspace {config_only.workspace_id}") raise ValueError(f"Workspace {config_only.workspace_id} not found for configuration {config_id}") config_logger.debug( @@ -591,7 +593,7 @@ class DataConfigRepository: "elapsed_ms": elapsed_ms } ) - db_logger.debug(f"Data config not found: config_id={config_id}") + db_logger.debug(f"Memory config not found: config_id={config_id}") return None 
config, workspace = result @@ -611,7 +613,7 @@ class DataConfigRepository: } ) - db_logger.debug(f"Data config and workspace query successful: config={config.config_name}, workspace={workspace.name}") + db_logger.debug(f"Memory config and workspace query successful: config={config.config_name}, workspace={workspace.name}") return (config, workspace) except ValueError: @@ -633,10 +635,10 @@ class DataConfigRepository: exc_info=True ) - db_logger.error(f"Failed to query data config and workspace: config_id={config_id} - {str(e)}") + db_logger.error(f"Failed to query memory config and workspace: config_id={config_id} - {str(e)}") raise @staticmethod - def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[DataConfig]: + def get_all(db: Session, workspace_id: Optional[uuid.UUID] = None) -> List[MemoryConfig]: """获取所有配置参数 Args: @@ -644,17 +646,17 @@ class DataConfigRepository: workspace_id: 工作空间ID,用于过滤查询结果 Returns: - List[DataConfig]: 配置列表 + List[MemoryConfig]: 配置列表 """ db_logger.debug(f"查询所有配置: workspace_id={workspace_id}") try: - query = db.query(DataConfig) + query = db.query(MemoryConfig) if workspace_id: - query = query.filter(DataConfig.workspace_id == workspace_id) + query = query.filter(MemoryConfig.workspace_id == workspace_id) - configs = query.order_by(desc(DataConfig.updated_at)).all() + configs = query.order_by(desc(MemoryConfig.updated_at)).all() db_logger.debug(f"配置列表查询成功: 数量={len(configs)}") return configs @@ -664,8 +666,8 @@ class DataConfigRepository: raise @staticmethod - def delete(db: Session, config_id: int) -> bool: - """删除数据配置 + def delete(db: Session, config_id: uuid.UUID) -> bool: + """删除记忆配置 Args: db: 数据库会话 @@ -674,22 +676,22 @@ class DataConfigRepository: Returns: bool: 删除成功返回True,配置不存在返回False """ - db_logger.debug(f"删除数据配置: config_id={config_id}") + db_logger.debug(f"删除记忆配置: config_id={config_id}") try: - db_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first() + db_config = 
db.query(MemoryConfig).filter(MemoryConfig.config_id == config_id).first() if not db_config: - db_logger.warning(f"数据配置不存在: config_id={config_id}") + db_logger.warning(f"记忆配置不存在: config_id={config_id}") return False db.delete(db_config) db.commit() - db_logger.info(f"数据配置删除成功: config_id={config_id}") + db_logger.info(f"记忆配置删除成功: config_id={config_id}") return True except Exception as e: db.rollback() - db_logger.error(f"删除数据配置失败: config_id={config_id} - {str(e)}") + db_logger.error(f"删除记忆配置失败: config_id={config_id} - {str(e)}") raise diff --git a/api/app/repositories/memory_perceptual_repository.py b/api/app/repositories/memory_perceptual_repository.py index 8415c2d0..9fa9536e 100644 --- a/api/app/repositories/memory_perceptual_repository.py +++ b/api/app/repositories/memory_perceptual_repository.py @@ -6,7 +6,7 @@ from sqlalchemy import and_, desc from sqlalchemy.orm import Session from app.core.logging_config import get_db_logger -from app.models.memory_perceptual_model import MemoryPerceptualModel, PerceptualType, FileStorageType +from app.models.memory_perceptual_model import MemoryPerceptualModel, PerceptualType, FileStorageService from app.schemas.memory_perceptual_schema import PerceptualQuerySchema db_logger = get_db_logger() @@ -28,7 +28,7 @@ class MemoryPerceptualRepository: file_ext: str, summary: Optional[str] = None, meta_data: Optional[dict] = None, - storage_service: FileStorageType = FileStorageType.LOCAL + storage_service: FileStorageService = FileStorageService.LOCAL ) -> MemoryPerceptualModel: diff --git a/api/app/repositories/neo4j/add_edges.py b/api/app/repositories/neo4j/add_edges.py index 3b45867e..162bf411 100644 --- a/api/app/repositories/neo4j/add_edges.py +++ b/api/app/repositories/neo4j/add_edges.py @@ -32,7 +32,7 @@ async def add_chunk_statement_edges(chunks: List[Chunk], connector: Neo4jConnect "id": stable_edge_id, "source": chunk.id, "target": stmt.id, - "group_id": getattr(stmt, 'group_id', None), + "end_user_id": getattr(stmt, 
'end_user_id', None), "user_id":getattr(stmt, 'user_id', None), "apply_id": getattr(stmt, 'apply_id', None), "run_id": getattr(stmt, 'run_id', None) or getattr(chunk, 'run_id', None), @@ -83,7 +83,7 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode], edges.append({ "summary_id": s.id, "chunk_id": chunk_id, - "group_id": s.group_id, + "end_user_id": s.end_user_id, "run_id": s.run_id, "created_at": s.created_at.isoformat() if s.created_at else None, "expired_at": s.expired_at.isoformat() if s.expired_at else None, diff --git a/api/app/repositories/neo4j/add_nodes.py b/api/app/repositories/neo4j/add_nodes.py index cf60a773..fcf700b5 100644 --- a/api/app/repositories/neo4j/add_nodes.py +++ b/api/app/repositories/neo4j/add_nodes.py @@ -6,10 +6,10 @@ from app.core.memory.models.graph_models import DialogueNode, StatementNode, Chu from app.repositories.neo4j.neo4j_connector import Neo4jConnector -async def delete_all_nodes(group_id: str, connector: Neo4jConnector): +async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector): """Delete all nodes in the database.""" - result = await connector.execute_query(f"MATCH (n {{group_id: '{group_id}'}}) DETACH DELETE n") - print(f"All group_id: {group_id} node and edge deleted successfully") + result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n") + print(f"All end_user_id: {end_user_id} node and edge deleted successfully") return result async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]: @@ -32,9 +32,7 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn for dialogue in dialogues: flattened_dialogues.append({ "id": dialogue.id, - "group_id": dialogue.group_id, - "user_id": dialogue.user_id, - "apply_id": dialogue.apply_id, + "end_user_id": dialogue.end_user_id, "run_id": dialogue.run_id, "ref_id": dialogue.ref_id, "name": dialogue.name, @@ -79,9 +77,7 
@@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC flattened_statement = { "id": statement.id, "name": statement.name, - "group_id": statement.group_id, - "user_id": statement.user_id, - "apply_id": statement.apply_id, + "end_user_id": statement.end_user_id, "run_id": statement.run_id, "chunk_id": statement.chunk_id, # "created_at": statement.created_at.isoformat(), @@ -154,9 +150,7 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> flattened_chunk = { "id": chunk.id, "name": chunk.name, - "group_id": chunk.group_id, - "user_id": chunk.user_id, - "apply_id": chunk.apply_id, + "end_user_id": chunk.end_user_id, "run_id": chunk.run_id, "created_at": chunk.created_at.isoformat() if chunk.created_at else None, "expired_at": chunk.expired_at.isoformat() if chunk.expired_at else None, @@ -206,9 +200,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector flattened.append({ "id": s.id, "name": s.name, - "group_id": s.group_id, - "user_id": s.user_id, - "apply_id": s.apply_id, + "end_user_id": s.end_user_id, "run_id": s.run_id, "created_at": s.created_at.isoformat() if s.created_at else None, "expired_at": s.expired_at.isoformat() if s.expired_at else None, diff --git a/api/app/repositories/neo4j/base_neo4j_repository.py b/api/app/repositories/neo4j/base_neo4j_repository.py index 959a1e68..df953eb9 100644 --- a/api/app/repositories/neo4j/base_neo4j_repository.py +++ b/api/app/repositories/neo4j/base_neo4j_repository.py @@ -152,7 +152,7 @@ class BaseNeo4jRepository(BaseRepository[T]): Example: >>> results = await repository.find( - ... {"group_id": "group_123", "user_id": "user_456"}, + ... {"end_user_id": "group_123", "user_id": "user_456"}, ... limit=50 ... 
) """ diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index cd3cbed7..c93e75b3 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -3,9 +3,7 @@ DIALOGUE_NODE_SAVE = """ UNWIND $dialogues AS dialogue MERGE (n:Dialogue {id: dialogue.id}) SET n.uuid = coalesce(n.uuid, dialogue.id), - n.group_id = dialogue.group_id, - n.user_id = dialogue.user_id, - n.apply_id = dialogue.apply_id, + n.end_user_id = dialogue.end_user_id, n.run_id = dialogue.run_id, n.ref_id = dialogue.ref_id, n.created_at = dialogue.created_at, @@ -22,9 +20,7 @@ SET s += { id: statement.id, run_id: statement.run_id, chunk_id: statement.chunk_id, - group_id: statement.group_id, - user_id: statement.user_id, - apply_id: statement.apply_id, + end_user_id: statement.end_user_id, stmt_type: statement.stmt_type, statement: statement.statement, emotion_intensity: statement.emotion_intensity, @@ -54,9 +50,7 @@ MERGE (c:Chunk {id: chunk.id}) SET c += { id: chunk.id, name: chunk.name, - group_id: chunk.group_id, - user_id: chunk.user_id, - apply_id: chunk.apply_id, + end_user_id: chunk.end_user_id, run_id: chunk.run_id, created_at: chunk.created_at, expired_at: chunk.expired_at, @@ -76,9 +70,7 @@ EXTRACTED_ENTITY_NODE_SAVE = """ UNWIND $entities AS entity MERGE (e:ExtractedEntity {id: entity.id}) SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity.name ELSE e.name END, - e.group_id = CASE WHEN entity.group_id IS NOT NULL AND entity.group_id <> '' THEN entity.group_id ELSE e.group_id END, - e.user_id = CASE WHEN entity.user_id IS NOT NULL AND entity.user_id <> '' THEN entity.user_id ELSE e.user_id END, - e.apply_id = CASE WHEN entity.apply_id IS NOT NULL AND entity.apply_id <> '' THEN entity.apply_id ELSE e.apply_id END, + e.end_user_id = CASE WHEN entity.end_user_id IS NOT NULL AND entity.end_user_id <> '' THEN entity.end_user_id ELSE e.end_user_id END, e.run_id = CASE 
WHEN entity.run_id IS NOT NULL AND entity.run_id <> '' THEN entity.run_id ELSE e.run_id END, e.created_at = CASE WHEN entity.created_at IS NOT NULL AND (e.created_at IS NULL OR entity.created_at < e.created_at) @@ -134,9 +126,9 @@ RETURN e.id AS uuid # Add back ENTITY_RELATIONSHIP_SAVE to be used by graph_saver.save_entities_and_relationships ENTITY_RELATIONSHIP_SAVE = """ UNWIND $relationships AS rel -// Match entities by stable id within group, do not constrain by run_id -MATCH (subject:ExtractedEntity {id: rel.source_id, group_id: rel.group_id}) -MATCH (object:ExtractedEntity {id: rel.target_id, group_id: rel.group_id}) +// Match entities by stable id within end_user_id, do not constrain by run_id +MATCH (subject:ExtractedEntity {id: rel.source_id, end_user_id: rel.end_user_id}) +MATCH (object:ExtractedEntity {id: rel.target_id, end_user_id: rel.end_user_id}) // Avoid duplicate edges across runs for the same endpoints MERGE (subject)-[r:EXTRACTED_RELATIONSHIP]->(object) SET r.predicate = rel.predicate, @@ -148,7 +140,7 @@ SET r.predicate = rel.predicate, r.created_at = rel.created_at, r.expired_at = rel.expired_at, r.run_id = rel.run_id, - r.group_id = rel.group_id + r.end_user_id = rel.end_user_id RETURN elementId(r) AS uuid """ @@ -160,7 +152,7 @@ UNWIND $weak_entities AS entity MERGE (e:ExtractedEntity {id: entity.id, run_id: entity.run_id}) SET e += { name: entity.name, - group_id: entity.group_id, + end_user_id: entity.end_user_id, run_id: entity.run_id, description: entity.description, chunk_id: entity.chunk_id, @@ -175,11 +167,11 @@ RETURN e.id AS id SAVE_STRONG_TRIPLE_ENTITIES = """ UNWIND $items AS item MERGE (s:ExtractedEntity {id: item.source_id, run_id: item.run_id}) -SET s += {name: item.subject, group_id: item.group_id, run_id: item.run_id} +SET s += {name: item.subject, end_user_id: item.end_user_id, run_id: item.run_id} // Independent strong flag SET s.is_strong = true MERGE (o:ExtractedEntity {id: item.target_id, run_id: item.run_id}) -SET o += 
{name: item.object, group_id: item.group_id, run_id: item.run_id} +SET o += {name: item.object, end_user_id: item.end_user_id, run_id: item.run_id} // Independent strong flag SET o.is_strong = true """ @@ -194,7 +186,7 @@ DIALOGUE_STATEMENT_EDGE_SAVE = """ // 仅按端点去重,关系属性可更新 MERGE (dialogue)-[e:MENTIONS]->(statement) SET e.uuid = edge.id, - e.group_id = edge.group_id, + e.end_user_id = edge.end_user_id, e.created_at = edge.created_at, e.expired_at = edge.expired_at RETURN e.uuid AS uuid @@ -208,7 +200,7 @@ CHUNK_STATEMENT_EDGE_SAVE = """ MATCH (statement:Statement {id: edge.source, run_id: edge.run_id}) MATCH (chunk:Chunk {id: edge.target, run_id: edge.run_id}) MERGE (chunk)-[e:CONTAINS {id: edge.id}]->(statement) - SET e.group_id = edge.group_id, + SET e.end_user_id = edge.end_user_id, e.run_id = edge.run_id, e.created_at = edge.created_at, e.expired_at = edge.expired_at @@ -218,13 +210,12 @@ CHUNK_STATEMENT_EDGE_SAVE = """ STATEMENT_ENTITY_EDGE_SAVE = """ UNWIND $relationships AS rel // Statement nodes are per-run; keep run_id constraint on statements -// Statement nodes are per-run; keep run_id constraint on statements MATCH (statement:Statement {id: rel.source, run_id: rel.run_id}) -// Entities are shared across runs within a group; do not constrain by run_id -MATCH (entity:ExtractedEntity {id: rel.target, group_id: rel.group_id}) +// Entities are shared across runs within end_user_id; do not constrain by run_id +MATCH (entity:ExtractedEntity {id: rel.target, end_user_id: rel.end_user_id}) // Avoid duplicate edges across runs for same endpoints MERGE (statement)-[r:REFERENCES_ENTITY]->(entity) -SET r.group_id = rel.group_id, +SET r.end_user_id = rel.end_user_id, r.run_id = rel.run_id, r.created_at = rel.created_at, r.expired_at = rel.expired_at, @@ -236,10 +227,10 @@ ENTITY_EMBEDDING_SEARCH = """ CALL db.index.vector.queryNodes('entity_embedding_index', $limit * 100, $embedding) YIELD node AS e, score WHERE e.name_embedding IS NOT NULL - AND ($group_id IS NULL 
OR e.group_id = $group_id) + AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id) RETURN e.id AS id, e.name AS name, - e.group_id AS group_id, + e.end_user_id AS end_user_id, e.entity_type AS entity_type, COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value, COALESCE(e.importance_score, 0.5) AS importance_score, @@ -254,10 +245,10 @@ STATEMENT_EMBEDDING_SEARCH = """ CALL db.index.vector.queryNodes('statement_embedding_index', $limit * 100, $embedding) YIELD node AS s, score WHERE s.statement_embedding IS NOT NULL - AND ($group_id IS NULL OR s.group_id = $group_id) + AND ($end_user_id IS NULL OR s.end_user_id = $end_user_id) RETURN s.id AS id, s.statement AS statement, - s.group_id AS group_id, + s.end_user_id AS end_user_id, s.chunk_id AS chunk_id, s.created_at AS created_at, s.expired_at AS expired_at, @@ -277,9 +268,9 @@ CHUNK_EMBEDDING_SEARCH = """ CALL db.index.vector.queryNodes('chunk_embedding_index', $limit * 100, $embedding) YIELD node AS c, score WHERE c.chunk_embedding IS NOT NULL - AND ($group_id IS NULL OR c.group_id = $group_id) + AND ($end_user_id IS NULL OR c.end_user_id = $end_user_id) RETURN c.id AS chunk_id, - c.group_id AS group_id, + c.end_user_id AS end_user_id, c.content AS content, c.dialog_id AS dialog_id, COALESCE(c.activation_value, 0.5) AS activation_value, @@ -292,12 +283,12 @@ LIMIT $limit SEARCH_STATEMENTS_BY_KEYWORD = """ CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score -WHERE ($group_id IS NULL OR s.group_id = $group_id) +WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id) OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity) RETURN s.id AS id, s.statement AS statement, - s.group_id AS group_id, + s.end_user_id AS end_user_id, s.chunk_id AS chunk_id, s.created_at AS created_at, s.expired_at AS expired_at, @@ -316,15 +307,13 @@ LIMIT $limit # 查询实体名称包含指定字符串的实体 SEARCH_ENTITIES_BY_NAME = """ CALL 
db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score -WHERE ($group_id IS NULL OR e.group_id = $group_id) +WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) RETURN e.id AS id, e.name AS name, - e.group_id AS group_id, + e.end_user_id AS end_user_id, e.entity_type AS entity_type, - e.apply_id AS apply_id, - e.user_id AS user_id, e.created_at AS created_at, e.expired_at AS expired_at, e.entity_idx AS entity_idx, @@ -347,11 +336,11 @@ LIMIT $limit SEARCH_CHUNKS_BY_CONTENT = """ CALL db.index.fulltext.queryNodes("chunksFulltext", $q) YIELD node AS c, score -WHERE ($group_id IS NULL OR c.group_id = $group_id) +WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id) OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement) OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity) RETURN c.id AS chunk_id, - c.group_id AS group_id, + c.end_user_id AS end_user_id, c.content AS content, c.dialog_id AS dialog_id, c.sequence_number AS sequence_number, @@ -413,10 +402,10 @@ LIMIT $limit SEARCH_DIALOGUE_BY_DIALOG_ID = """ MATCH (d:Dialogue) -WHERE ($group_id IS NULL OR d.group_id = $group_id) +WHERE ($end_user_id IS NULL OR d.end_user_id = $end_user_id) AND d.id = $dialog_id RETURN d.id AS dialog_id, - d.group_id AS group_id, + d.end_user_id AS end_user_id, d.content AS content, d.created_at AS created_at, d.expired_at AS expired_at @@ -426,10 +415,10 @@ LIMIT $limit SEARCH_CHUNK_BY_CHUNK_ID = """ MATCH (c:Chunk) -WHERE ($group_id IS NULL OR c.group_id = $group_id) +WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id) AND c.id = $chunk_id RETURN c.id AS chunk_id, - c.group_id AS group_id, + c.end_user_id AS end_user_id, c.content AS content, c.dialog_id AS dialog_id, c.created_at AS created_at, @@ -441,18 +430,14 @@ LIMIT $limit SEARCH_STATEMENTS_BY_TEMPORAL = """ MATCH (s:Statement) -WHERE ($group_id IS NULL OR s.group_id = $group_id) - AND 
($apply_id IS NULL OR s.apply_id = $apply_id) - AND ($user_id IS NULL OR s.user_id = $user_id) +WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id) AND ((($start_date IS NULL OR datetime(s.created_at) >= datetime($start_date)) AND ($end_date IS NULL OR datetime(s.created_at) <= datetime($end_date))) OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date))) AND ($invalid_date IS NULL OR (s.invalid_at IS NOT NULL AND datetime(s.invalid_at) <= datetime($invalid_date))))) RETURN s.id AS id, s.statement AS statement, - s.group_id AS group_id, - s.apply_id AS apply_id, - s.user_id AS user_id, + s.end_user_id AS end_user_id, s.chunk_id AS chunk_id, s.created_at AS created_at, s.valid_at AS valid_at, @@ -468,9 +453,7 @@ LIMIT $limit SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL = """ CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score -WHERE ($group_id IS NULL OR s.group_id = $group_id) - AND ($apply_id IS NULL OR s.apply_id = $apply_id) - AND ($user_id IS NULL OR s.user_id = $user_id) +WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id) AND ((($start_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) >= datetime($start_date))) AND ($end_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) <= datetime($end_date)))) OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date))) @@ -479,9 +462,7 @@ OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity) RETURN s.id AS id, s.statement AS statement, - s.group_id AS group_id, - s.apply_id AS apply_id, - s.user_id AS user_id, + s.end_user_id AS end_user_id, s.chunk_id AS chunk_id, s.created_at AS created_at, s.valid_at AS valid_at, @@ -499,15 +480,11 @@ LIMIT $limit SEARCH_STATEMENTS_BY_CREATED_AT = """ MATCH (n:Statement) -WHERE ($group_id IS NULL OR n.group_id = $group_id) - AND ($apply_id IS NULL OR n.apply_id = 
$apply_id) - AND ($user_id IS NULL OR n.user_id = $user_id) +WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id) AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 10)) = date($created_at)) RETURN n.id AS id, n.statement AS statement, - n.group_id AS group_id, - n.apply_id AS apply_id, - n.user_id AS user_id, + n.end_user_id AS end_user_id, n.chunk_id AS chunk_id, n.created_at AS created_at, n.valid_at AS valid_at, @@ -519,15 +496,11 @@ LIMIT $limit SEARCH_STATEMENTS_BY_VALID_AT = """ MATCH (n:Statement) -WHERE ($group_id IS NULL OR n.group_id = $group_id) - AND ($apply_id IS NULL OR n.apply_id = $apply_id) - AND ($user_id IS NULL OR n.user_id = $user_id) +WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id) AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) = date($valid_at)) RETURN n.id AS id, n.statement AS statement, - n.group_id AS group_id, - n.apply_id AS apply_id, - n.user_id AS user_id, + n.end_user_id AS end_user_id, n.chunk_id AS chunk_id, n.created_at AS created_at, n.valid_at AS valid_at, @@ -539,15 +512,11 @@ LIMIT $limit SEARCH_STATEMENTS_G_CREATED_AT = """ MATCH (n:Statement) -WHERE ($group_id IS NULL OR n.group_id = $group_id) - AND ($apply_id IS NULL OR n.apply_id = $apply_id) - AND ($user_id IS NULL OR n.user_id = $user_id) +WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id) AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) = date($created_at)) RETURN n.id AS id, n.statement AS statement, - n.group_id AS group_id, - n.apply_id AS apply_id, - n.user_id AS user_id, + n.end_user_id AS end_user_id, n.chunk_id AS chunk_id, n.created_at AS created_at, n.valid_at AS valid_at, @@ -559,15 +528,11 @@ LIMIT $limit SEARCH_STATEMENTS_L_CREATED_AT = """ MATCH (n:Statement) -WHERE ($group_id IS NULL OR n.group_id = $group_id) - AND ($apply_id IS NULL OR n.apply_id = $apply_id) - AND ($user_id IS NULL OR n.user_id = $user_id) +WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id) 
AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) < date($created_at)) RETURN n.id AS id, n.statement AS statement, - n.group_id AS group_id, - n.apply_id AS apply_id, - n.user_id AS user_id, + n.end_user_id AS end_user_id, n.chunk_id AS chunk_id, n.created_at AS created_at, n.valid_at AS valid_at, @@ -579,15 +544,11 @@ LIMIT $limit SEARCH_STATEMENTS_G_VALID_AT = """ MATCH (n:Statement) -WHERE ($group_id IS NULL OR n.group_id = $group_id) - AND ($apply_id IS NULL OR n.apply_id = $apply_id) - AND ($user_id IS NULL OR n.user_id = $user_id) +WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id) AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) > date($valid_at)) RETURN n.id AS id, n.statement AS statement, - n.group_id AS group_id, - n.apply_id AS apply_id, - n.user_id AS user_id, + n.end_user_id AS end_user_id, n.chunk_id AS chunk_id, n.created_at AS created_at, n.valid_at AS valid_at, @@ -599,15 +560,11 @@ LIMIT $limit SEARCH_STATEMENTS_L_VALID_AT = """ MATCH (n:Statement) -WHERE ($group_id IS NULL OR n.group_id = $group_id) - AND ($apply_id IS NULL OR n.apply_id = $apply_id) - AND ($user_id IS NULL OR n.user_id = $user_id) +WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id) AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) < date($valid_at)) RETURN n.id AS id, n.statement AS statement, - n.group_id AS group_id, - n.apply_id AS apply_id, - n.user_id AS user_id, + n.end_user_id AS end_user_id, n.chunk_id AS chunk_id, n.created_at AS created_at, n.valid_at AS valid_at, @@ -665,18 +622,18 @@ LIMIT $limit # 根据id修改句子的invalid_at的值 UPDATE_STATEMENT_INVALID_AT = """ -MATCH (n:Statement {group_id: $group_id, id: $id}) +MATCH (n:Statement {end_user_id: $end_user_id, id: $id}) SET n.invalid_at = $new_invalid_at """ # MemorySummary keyword search using fulltext index SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """ CALL db.index.fulltext.queryNodes("summariesFulltext", $q) YIELD node AS m, score -WHERE ($group_id IS 
NULL OR m.group_id = $group_id) +WHERE ($end_user_id IS NULL OR m.end_user_id = $end_user_id) OPTIONAL MATCH (m)-[:DERIVED_FROM_STATEMENT]->(s:Statement) RETURN m.id AS id, m.name AS name, - m.group_id AS group_id, + m.end_user_id AS end_user_id, m.dialog_id AS dialog_id, m.chunk_ids AS chunk_ids, m.content AS content, @@ -695,10 +652,10 @@ MEMORY_SUMMARY_EMBEDDING_SEARCH = """ CALL db.index.vector.queryNodes('summary_embedding_index', $limit * 100, $embedding) YIELD node AS m, score WHERE m.summary_embedding IS NOT NULL - AND ($group_id IS NULL OR m.group_id = $group_id) + AND ($end_user_id IS NULL OR m.end_user_id = $end_user_id) RETURN m.id AS id, m.name AS name, - m.group_id AS group_id, + m.end_user_id AS end_user_id, m.dialog_id AS dialog_id, m.chunk_ids AS chunk_ids, m.content AS content, @@ -718,9 +675,7 @@ MERGE (m:MemorySummary {id: summary.id}) SET m += { id: summary.id, name: summary.name, - group_id: summary.group_id, - user_id: summary.user_id, - apply_id: summary.apply_id, + end_user_id: summary.end_user_id, run_id: summary.run_id, created_at: summary.created_at, expired_at: summary.expired_at, @@ -745,7 +700,7 @@ MATCH (ms:MemorySummary {id: e.summary_id, run_id: e.run_id}) MATCH (c:Chunk {id: e.chunk_id, run_id: e.run_id}) MATCH (c)-[:CONTAINS]->(s:Statement {run_id: e.run_id}) MERGE (ms)-[r:DERIVED_FROM_STATEMENT]->(s) -SET r.group_id = e.group_id, +SET r.end_user_id = e.end_user_id, r.run_id = e.run_id, r.created_at = e.created_at, r.expired_at = e.expired_at @@ -774,7 +729,7 @@ FOREACH (rel IN CASE WHEN r IS NOT NULL THEN [r] ELSE [] END | source_statement_id: rel.source_statement_id, valid_at: rel.valid_at, invalid_at: rel.invalid_at, - group_id: rel.group_id, + end_user_id: rel.end_user_id, user_id: rel.user_id, apply_id: rel.apply_id, run_id: rel.run_id, @@ -796,7 +751,7 @@ FOREACH (rel IN CASE WHEN r IS NOT NULL THEN [r] ELSE [] END | source_statement_id: rel.source_statement_id, valid_at: rel.valid_at, invalid_at: rel.invalid_at, - 
group_id: rel.group_id, + end_user_id: rel.end_user_id, user_id: rel.user_id, apply_id: rel.apply_id, run_id: rel.run_id, @@ -814,7 +769,7 @@ RETURN count(losing) as deleted neo4j_statement_part = ''' MATCH (n:Statement) -WHERE n.group_id = "{}" +WHERE n.end_user_id = "{}" AND datetime(n.created_at) >= datetime() - duration('P3D') RETURN n.statement as statement_name, @@ -824,7 +779,7 @@ RETURN ''' neo4j_statement_all = ''' MATCH (n:Statement) -WHERE n.group_id = "{}" +WHERE n.end_user_id = "{}" RETURN n.statement as statement_name, n.id as statement_id @@ -832,7 +787,7 @@ RETURN ''' neo4j_query_part = """ MATCH (n)-[r]-(m:ExtractedEntity) - WHERE n.group_id = "{}" + WHERE n.end_user_id = "{}" AND datetime(n.created_at) >= datetime() - duration('P3D') WITH DISTINCT m OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity) @@ -853,7 +808,7 @@ neo4j_query_part = """ """ neo4j_query_all = """ MATCH (n)-[r]-(m:ExtractedEntity) - WHERE n.group_id = "{}" + WHERE n.end_user_id = "{}" WITH DISTINCT m OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity) RETURN @@ -1027,14 +982,14 @@ RETURN DISTINCT Memory_Space_User=""" MATCH (n)-[r]->(m) -WHERE n.group_id = $group_id AND m.name="用户" +WHERE n.end_user_id = $end_user_id AND m.name="用户" return DISTINCT elementId(m) as id """ Memory_Space_Entity=""" MATCH (n)-[]-(m) WHERE elementId(m) = $id AND m.entity_type = "Person" RETURN -DISTINCT m.name as name,m.group_id as group_id +DISTINCT m.name as name,m.end_user_id as end_user_id """ Memory_Space_Associative=""" MATCH (u)-[]-(x)-[]-(h) diff --git a/api/app/repositories/neo4j/dialog_repository.py b/api/app/repositories/neo4j/dialog_repository.py index ccb3d94c..020e7346 100644 --- a/api/app/repositories/neo4j/dialog_repository.py +++ b/api/app/repositories/neo4j/dialog_repository.py @@ -19,7 +19,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]): """对话仓储 管理对话节点的创建、查询、更新和删除操作。 - 提供按group_id、user_id、ref_id等条件查询对话的方法。 + 提供按end_user_id、user_id、ref_id等条件查询对话的方法。 Attributes: connector: 
Neo4j连接器实例 @@ -54,17 +54,17 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]): return DialogueNode(**n) - async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[DialogueNode]: - """根据group_id查询对话 + async def find_by_end_user_id(self, end_user_id: str, limit: int = 100) -> List[DialogueNode]: + """根据end_user_id查询对话 Args: - group_id: 组ID + end_user_id: 组ID limit: 返回结果的最大数量 Returns: List[DialogueNode]: 对话列表 """ - return await self.find({"group_id": group_id}, limit=limit) + return await self.find({"end_user_id": end_user_id}, limit=limit) async def find_by_user_id(self, user_id: str, limit: int = 100) -> List[DialogueNode]: """根据user_id查询对话 @@ -94,14 +94,14 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]): async def find_by_group_and_user( self, - group_id: str, + end_user_id: str, user_id: str, limit: int = 100 ) -> List[DialogueNode]: - """根据group_id和user_id查询对话 + """根据end_user_id和user_id查询对话 Args: - group_id: 组ID + end_user_id: 组ID user_id: 用户ID limit: 返回结果的最大数量 @@ -109,20 +109,20 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]): List[DialogueNode]: 对话列表 """ return await self.find( - {"group_id": group_id, "user_id": user_id}, + {"end_user_id": end_user_id, "user_id": user_id}, limit=limit ) async def find_recent_dialogs( self, - group_id: str, + end_user_id: str, days: int = 7, limit: int = 100 ) -> List[DialogueNode]: """查询最近的对话 Args: - group_id: 组ID + end_user_id: 组ID days: 查询最近多少天的对话 limit: 返回结果的最大数量 @@ -131,7 +131,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]): """ query = f""" MATCH (n:{self.node_label}) - WHERE n.group_id = $group_id + WHERE n.end_user_id = $end_user_id AND n.created_at >= datetime() - duration({{days: $days}}) RETURN n ORDER BY n.created_at DESC @@ -139,7 +139,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]): """ results = await self.connector.execute_query( query, - group_id=group_id, + end_user_id=end_user_id, days=days, limit=limit ) @@ -164,22 +164,22 @@ 
class DialogRepository(BaseNeo4jRepository[DialogueNode]): async def find_by_config_and_group( self, config_id: str, - group_id: str, + end_user_id: str, limit: int = 100 ) -> List[DialogueNode]: - """根据config_id和group_id查询对话 + """根据config_id和end_user_id查询对话 支持按配置ID和组ID同时过滤,确保只返回使用特定配置处理的对话。 Args: config_id: 配置ID - group_id: 组ID + end_user_id: 组ID limit: 返回结果的最大数量 Returns: List[DialogueNode]: 对话列表 """ return await self.find( - {"config_id": config_id, "group_id": group_id}, + {"config_id": config_id, "end_user_id": end_user_id}, limit=limit ) diff --git a/api/app/repositories/neo4j/emotion_repository.py b/api/app/repositories/neo4j/emotion_repository.py index d445c8d4..e39968ac 100644 --- a/api/app/repositories/neo4j/emotion_repository.py +++ b/api/app/repositories/neo4j/emotion_repository.py @@ -40,7 +40,7 @@ class EmotionRepository: async def get_emotion_tags( self, - group_id: str, + end_user_id: str, emotion_type: Optional[str] = None, start_date: Optional[str] = None, end_date: Optional[str] = None, @@ -51,7 +51,7 @@ class EmotionRepository: 查询指定用户的情绪类型分布,包括计数、百分比和平均强度。 Args: - group_id: 用户组ID(宿主ID) + end_user_id: 用户组ID(宿主ID) emotion_type: 可选的情绪类型过滤(joy/sadness/anger/fear/surprise/neutral) start_date: 可选的开始日期(ISO格式字符串) end_date: 可选的结束日期(ISO格式字符串) @@ -65,8 +65,8 @@ class EmotionRepository: - avg_intensity: 平均强度 """ # 构建查询条件 - where_clauses = ["s.group_id = $group_id", "s.emotion_type IS NOT NULL"] - params = {"group_id": group_id, "limit": limit} + where_clauses = ["s.end_user_id = $end_user_id", "s.emotion_type IS NOT NULL"] + params = {"end_user_id": end_user_id, "limit": limit} if emotion_type: where_clauses.append("s.emotion_type = $emotion_type") @@ -119,7 +119,7 @@ class EmotionRepository: async def get_emotion_wordcloud( self, - group_id: str, + end_user_id: str, emotion_type: Optional[str] = None, limit: int = 50 ) -> List[Dict[str, Any]]: @@ -128,7 +128,7 @@ class EmotionRepository: 查询情绪关键词及其频率,用于生成词云可视化。 Args: - group_id: 用户组ID(宿主ID) + end_user_id: 
用户组ID(宿主ID) emotion_type: 可选的情绪类型过滤 limit: 返回关键词的最大数量 @@ -140,8 +140,8 @@ class EmotionRepository: - avg_intensity: 平均强度 """ # 构建查询条件 - where_clauses = ["s.group_id = $group_id", "s.emotion_keywords IS NOT NULL"] - params = {"group_id": group_id, "limit": limit} + where_clauses = ["s.end_user_id = $end_user_id", "s.emotion_keywords IS NOT NULL"] + params = {"end_user_id": end_user_id, "limit": limit} if emotion_type: where_clauses.append("s.emotion_type = $emotion_type") @@ -186,7 +186,7 @@ class EmotionRepository: async def get_emotions_in_range( self, - group_id: str, + end_user_id: str, time_range: str = "30d" ) -> List[Dict[str, Any]]: """获取时间范围内的情绪数据 @@ -194,7 +194,7 @@ class EmotionRepository: 查询指定时间范围内的所有情绪数据,用于健康指数计算。 Args: - group_id: 用户组ID(宿主ID) + end_user_id: 用户组ID(宿主ID) time_range: 时间范围(7d/30d/90d) Returns: @@ -214,7 +214,7 @@ class EmotionRepository: # 优化的 Cypher 查询:使用字符串比较避免时区问题 query = """ MATCH (s:Statement) - WHERE s.group_id = $group_id + WHERE s.end_user_id = $end_user_id AND s.emotion_type IS NOT NULL AND s.created_at >= $start_date RETURN s.id as statement_id, @@ -227,7 +227,7 @@ class EmotionRepository: try: results = await self.connector.execute_query( query, - group_id=group_id, + end_user_id=end_user_id, start_date=start_date ) formatted_results = [ diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index 13215e0f..1575315f 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -44,9 +44,7 @@ async def save_entities_and_relationships( 'created_at': edge.created_at.isoformat(), 'expired_at': edge.expired_at.isoformat(), 'run_id': edge.run_id, - 'group_id': edge.group_id, - 'user_id': edge.user_id, - 'apply_id': edge.apply_id, + 'end_user_id': edge.end_user_id, } all_relationships.append(relationship) @@ -101,9 +99,7 @@ async def save_statement_chunk_edges( "id": edge.id, "source": edge.source, "target": edge.target, - "group_id": edge.group_id, 
- "user_id": edge.user_id, - "apply_id": edge.apply_id, + "end_user_id": edge.end_user_id, "run_id": edge.run_id, "created_at": edge.created_at.isoformat() if edge.created_at else None, "expired_at": edge.expired_at.isoformat() if edge.expired_at else None, @@ -132,9 +128,7 @@ async def save_statement_entity_edges( edge_data = { "source": edge.source, "target": edge.target, - "group_id": edge.group_id, - "user_id": edge.user_id, - "apply_id": edge.apply_id, + "end_user_id": edge.end_user_id, "run_id": edge.run_id, "connect_strength": edge.connect_strength, "created_at": edge.created_at.isoformat() if edge.created_at else None, diff --git a/api/app/repositories/neo4j/graph_search.py b/api/app/repositories/neo4j/graph_search.py index 6f5764b4..e8f52535 100644 --- a/api/app/repositories/neo4j/graph_search.py +++ b/api/app/repositories/neo4j/graph_search.py @@ -33,7 +33,7 @@ async def _update_activation_values_batch( connector: Neo4jConnector, nodes: List[Dict[str, Any]], node_label: str, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, max_retries: int = 3 ) -> List[Dict[str, Any]]: """ @@ -46,7 +46,7 @@ async def _update_activation_values_batch( connector: Neo4j连接器 nodes: 节点列表,每个节点必须包含 'id' 字段 node_label: 节点标签(Statement, ExtractedEntity, MemorySummary) - group_id: 组ID(可选) + end_user_id: 组ID(可选) max_retries: 最大重试次数 Returns: @@ -97,7 +97,7 @@ async def _update_activation_values_batch( updated_nodes = await access_manager.record_batch_access( node_ids=unique_node_ids, node_label=node_label, - group_id=group_id + end_user_id=end_user_id ) logger.info( @@ -118,7 +118,7 @@ async def _update_activation_values_batch( async def _update_search_results_activation( connector: Neo4jConnector, results: Dict[str, List[Dict[str, Any]]], - group_id: Optional[str] = None + end_user_id: Optional[str] = None ) -> Dict[str, List[Dict[str, Any]]]: """ 更新搜索结果中所有知识节点的激活值 @@ -129,7 +129,7 @@ async def _update_search_results_activation( Args: connector: Neo4j连接器 results: 
搜索结果字典,包含不同类型节点的列表 - group_id: 组ID(可选) + end_user_id: 组ID(可选) Returns: Dict[str, List[Dict[str, Any]]]: 更新后的搜索结果 @@ -152,7 +152,7 @@ async def _update_search_results_activation( connector=connector, nodes=results[key], node_label=label, - group_id=group_id + end_user_id=end_user_id ) ) update_keys.append(key) @@ -218,7 +218,7 @@ async def _update_search_results_activation( async def search_graph( connector: Neo4jConnector, q: str, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, limit: int = 50, include: List[str] = None, ) -> Dict[str, List[Dict[str, Any]]]: @@ -236,7 +236,7 @@ async def search_graph( Args: connector: Neo4j connector q: Query text - group_id: Optional group filter + end_user_id: Optional group filter limit: Max results per category include: List of categories to search (default: all) @@ -254,7 +254,7 @@ async def search_graph( tasks.append(connector.execute_query( SEARCH_STATEMENTS_BY_KEYWORD, q=q, - group_id=group_id, + end_user_id=end_user_id, limit=limit, )) task_keys.append("statements") @@ -263,7 +263,7 @@ async def search_graph( tasks.append(connector.execute_query( SEARCH_ENTITIES_BY_NAME, q=q, - group_id=group_id, + end_user_id=end_user_id, limit=limit, )) task_keys.append("entities") @@ -272,7 +272,7 @@ async def search_graph( tasks.append(connector.execute_query( SEARCH_CHUNKS_BY_CONTENT, q=q, - group_id=group_id, + end_user_id=end_user_id, limit=limit, )) task_keys.append("chunks") @@ -281,7 +281,7 @@ async def search_graph( tasks.append(connector.execute_query( SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, q=q, - group_id=group_id, + end_user_id=end_user_id, limit=limit, )) task_keys.append("summaries") @@ -310,12 +310,12 @@ async def search_graph( key in include and key in results and results[key] for key in ['statements', 'entities', 'chunks'] ) - + if needs_activation_update: results = await _update_search_results_activation( connector=connector, results=results, - group_id=group_id + end_user_id=end_user_id ) return 
results @@ -325,7 +325,7 @@ async def search_graph_by_embedding( connector: Neo4jConnector, embedder_client, query_text: str, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, limit: int = 50, include: List[str] = ["statements", "chunks", "entities","summaries"], ) -> Dict[str, List[Dict[str, Any]]]: @@ -337,7 +337,7 @@ async def search_graph_by_embedding( - Computes query embedding with the provided embedder_client - Ranks by cosine similarity in Cypher - - Filters by group_id if provided + - Filters by end_user_id if provided - Returns up to 'limit' per included type """ import time @@ -346,7 +346,7 @@ async def search_graph_by_embedding( embed_start = time.time() embeddings = await embedder_client.response([query_text]) embed_time = time.time() - embed_start - logger.info(f"[PERF] Embedding generation took: {embed_time:.4f}s") + print(f"[PERF] Embedding generation took: {embed_time:.4f}s") if not embeddings or not embeddings[0]: return {"statements": [], "chunks": [], "entities": [], "summaries": []} @@ -361,7 +361,7 @@ async def search_graph_by_embedding( tasks.append(connector.execute_query( STATEMENT_EMBEDDING_SEARCH, embedding=embedding, - group_id=group_id, + end_user_id=end_user_id, limit=limit, )) task_keys.append("statements") @@ -371,7 +371,7 @@ async def search_graph_by_embedding( tasks.append(connector.execute_query( CHUNK_EMBEDDING_SEARCH, embedding=embedding, - group_id=group_id, + end_user_id=end_user_id, limit=limit, )) task_keys.append("chunks") @@ -381,7 +381,7 @@ async def search_graph_by_embedding( tasks.append(connector.execute_query( ENTITY_EMBEDDING_SEARCH, embedding=embedding, - group_id=group_id, + end_user_id=end_user_id, limit=limit, )) task_keys.append("entities") @@ -391,7 +391,7 @@ async def search_graph_by_embedding( tasks.append(connector.execute_query( MEMORY_SUMMARY_EMBEDDING_SEARCH, embedding=embedding, - group_id=group_id, + end_user_id=end_user_id, limit=limit, )) task_keys.append("summaries") @@ -400,7 
+400,7 @@ async def search_graph_by_embedding( query_start = time.time() task_results = await asyncio.gather(*tasks, return_exceptions=True) query_time = time.time() - query_start - logger.info(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s") + print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s") # Build results dictionary results: Dict[str, List[Dict[str, Any]]] = { @@ -429,13 +429,13 @@ async def search_graph_by_embedding( key in include and key in results and results[key] for key in ['statements', 'entities', 'chunks'] ) - + if needs_activation_update: update_start = time.time() results = await _update_search_results_activation( connector=connector, results=results, - group_id=group_id + end_user_id=end_user_id ) update_time = time.time() - update_start logger.info(f"[PERF] Activation value updates took: {update_time:.4f}s") @@ -445,7 +445,7 @@ async def search_graph_by_embedding( return results async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体 connector: Neo4jConnector, - group_id: str, + end_user_id: str, entities: List[Dict[str, Any]], use_contains_fallback: bool = True, batch_size: int = 500, @@ -453,7 +453,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全 ) -> Dict[str, List[Dict[str, Any]]]: """ 为第二层去重消歧批量检索候选实体(适配新版 cypher_queries): - - 使用全文索引查询 `SEARCH_ENTITIES_BY_NAME` 按 (group_id, name) 检索候选; + - 使用全文索引查询 `SEARCH_ENTITIES_BY_NAME` 按 (end_user_id, name) 检索候选; - 保留并发控制与返回结构(incoming_id -> [db_entity_props...]); - 若提供 `entity_type`,在本地对返回结果做类型过滤; - `use_contains_fallback` 保留形参以兼容,必要时可扩展二次查询策略。 @@ -477,7 +477,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全 rows = await connector.execute_query( SEARCH_ENTITIES_BY_NAME, q=name, - group_id=group_id, + end_user_id=end_user_id, limit=100, ) except Exception: @@ -501,7 +501,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全 rows = await connector.execute_query( SEARCH_ENTITIES_BY_NAME, q=name.lower(), - group_id=group_id, + 
end_user_id=end_user_id, limit=100, ) for r in rows: @@ -532,9 +532,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全 async def search_graph_by_keyword_temporal( connector: Neo4jConnector, query_text: str, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, + end_user_id: Optional[str] = None, start_date: Optional[str] = None, end_date: Optional[str] = None, valid_date: Optional[str] = None, @@ -547,32 +545,30 @@ async def search_graph_by_keyword_temporal( INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements containing query_text created between start_date and end_date - - Optionally filters by group_id, apply_id, user_id + - Optionally filters by end_user_id, apply_id, user_id - Returns up to 'limit' statements """ if not query_text: - logger.warning(f"query_text cannot be empty") + print(f"query_text不能为空") return {"statements": []} statements = await connector.execute_query( SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, q=query_text, - group_id=group_id, - apply_id=apply_id, - user_id=user_id, + end_user_id=end_user_id, start_date=start_date, end_date=end_date, valid_date=valid_date, invalid_date=invalid_date, limit=limit, ) - logger.debug(f"Temporal keyword search results: {len(statements)} statements found") + print(f"查询结果为:\n{statements}") # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( connector=connector, results=results, - group_id=group_id + end_user_id=end_user_id ) return results @@ -580,9 +576,7 @@ async def search_graph_by_keyword_temporal( async def search_graph_by_temporal( connector: Neo4jConnector, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, + end_user_id: Optional[str] = None, start_date: Optional[str] = None, end_date: Optional[str] = None, valid_date: Optional[str] = None, @@ -595,14 +589,12 @@ async def 
search_graph_by_temporal( INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements created between start_date and end_date - - Optionally filters by group_id, apply_id, user_id + - Optionally filters by end_user_id - Returns up to 'limit' statements """ statements = await connector.execute_query( SEARCH_STATEMENTS_BY_TEMPORAL, - group_id=group_id, - apply_id=apply_id, - user_id=user_id, + end_user_id=end_user_id, start_date=start_date, end_date=end_date, valid_date=valid_date, @@ -610,16 +602,16 @@ async def search_graph_by_temporal( limit=limit, ) - logger.debug(f"Temporal search query: {SEARCH_STATEMENTS_BY_TEMPORAL}") - logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, start_date={start_date}, end_date={end_date}, valid_date={valid_date}, invalid_date={invalid_date}, limit={limit}") - logger.debug(f"Temporal search results: {len(statements)} statements found") + print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}") + print(f"查询参数为:\n{{end_user_id: {end_user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}") + print(f"查询结果为:\n{statements}") # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( connector=connector, results=results, - group_id=group_id + end_user_id=end_user_id ) return results @@ -628,23 +620,23 @@ async def search_graph_by_temporal( async def search_graph_by_dialog_id( connector: Neo4jConnector, dialog_id: str, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Dialogues. 
- Matches dialogues with dialog_id - - Optionally filters by group_id + - Optionally filters by end_user_id - Returns up to 'limit' dialogues """ if not dialog_id: - logger.warning(f"dialog_id cannot be empty") + print(f"dialog_id不能为空") return {"dialogues": []} dialogues = await connector.execute_query( SEARCH_DIALOGUE_BY_DIALOG_ID, - group_id=group_id, + end_user_id=end_user_id, dialog_id=dialog_id, limit=limit, ) @@ -654,15 +646,15 @@ async def search_graph_by_dialog_id( async def search_graph_by_chunk_id( connector: Neo4jConnector, chunk_id : str, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: if not chunk_id: - logger.warning(f"chunk_id cannot be empty") + print(f"chunk_id不能为空") return {"chunks": []} chunks = await connector.execute_query( SEARCH_CHUNK_BY_CHUNK_ID, - group_id=group_id, + end_user_id=end_user_id, chunk_id=chunk_id, limit=limit, ) @@ -671,9 +663,9 @@ async def search_graph_by_chunk_id( async def search_graph_by_created_at( connector: Neo4jConnector, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, + end_user_id: Optional[str] = None, + + created_at: Optional[str] = None, limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: @@ -683,37 +675,37 @@ async def search_graph_by_created_at( INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements created at created_at - - Optionally filters by group_id, apply_id, user_id + - Optionally filters by end_user_id, apply_id, user_id - Returns up to 'limit' statements """ statements = await connector.execute_query( SEARCH_STATEMENTS_BY_CREATED_AT, - group_id=group_id, - apply_id=apply_id, - user_id=user_id, + end_user_id=end_user_id, + + created_at=created_at, limit=limit, ) - logger.debug(f"Search by created_at query: {SEARCH_STATEMENTS_BY_CREATED_AT}") - logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, 
user_id={user_id}, created_at={created_at}, limit={limit}") - logger.debug(f"Search results: {len(statements)} statements found") + print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}") + print(f"查询参数为:\n{{end_user_id: {end_user_id} created_at: {created_at}, limit: {limit}}}") + print(f"查询结果为:\n{statements}") # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( connector=connector, results=results, - group_id=group_id + end_user_id=end_user_id ) return results async def search_graph_by_valid_at( connector: Neo4jConnector, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, + end_user_id: Optional[str] = None, + + valid_at: Optional[str] = None, limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: @@ -723,37 +715,37 @@ async def search_graph_by_valid_at( INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements valid at valid_at - - Optionally filters by group_id, apply_id, user_id + - Optionally filters by end_user_id, apply_id, user_id - Returns up to 'limit' statements """ statements = await connector.execute_query( SEARCH_STATEMENTS_BY_VALID_AT, - group_id=group_id, - apply_id=apply_id, - user_id=user_id, + end_user_id=end_user_id, + + valid_at=valid_at, limit=limit, ) - logger.debug(f"Search by valid_at query: {SEARCH_STATEMENTS_BY_VALID_AT}") - logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}") - logger.debug(f"Search results: {len(statements)} statements found") + print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}") + print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}") + print(f"查询结果为:\n{statements}") # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( connector=connector, results=results, - group_id=group_id + end_user_id=end_user_id ) return 
results async def search_graph_g_created_at( connector: Neo4jConnector, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, + end_user_id: Optional[str] = None, + + created_at: Optional[str] = None, limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: @@ -763,37 +755,37 @@ async def search_graph_g_created_at( INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements created at created_at - - Optionally filters by group_id, apply_id, user_id + - Optionally filters by end_user_id, apply_id, user_id - Returns up to 'limit' statements """ statements = await connector.execute_query( SEARCH_STATEMENTS_G_CREATED_AT, - group_id=group_id, - apply_id=apply_id, - user_id=user_id, + end_user_id=end_user_id, + + created_at=created_at, limit=limit, ) - logger.debug(f"Search greater than created_at query: {SEARCH_STATEMENTS_G_CREATED_AT}") - logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}") - logger.debug(f"Search results: {len(statements)} statements found") + print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}") + print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}") + print(f"查询结果为:\n{statements}") # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( connector=connector, results=results, - group_id=group_id + end_user_id=end_user_id ) return results async def search_graph_g_valid_at( connector: Neo4jConnector, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, + end_user_id: Optional[str] = None, + + valid_at: Optional[str] = None, limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: @@ -803,37 +795,37 @@ async def search_graph_g_valid_at( INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements valid at valid_at - - 
Optionally filters by group_id, apply_id, user_id + - Optionally filters by end_user_id, apply_id, user_id - Returns up to 'limit' statements """ statements = await connector.execute_query( SEARCH_STATEMENTS_G_VALID_AT, - group_id=group_id, - apply_id=apply_id, - user_id=user_id, + end_user_id=end_user_id, + + valid_at=valid_at, limit=limit, ) - logger.debug(f"Search greater than valid_at query: {SEARCH_STATEMENTS_G_VALID_AT}") - logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}") - logger.debug(f"Search results: {len(statements)} statements found") + print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}") + print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}") + print(f"查询结果为:\n{statements}") # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( connector=connector, results=results, - group_id=group_id + end_user_id=end_user_id ) return results async def search_graph_l_created_at( connector: Neo4jConnector, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, + end_user_id: Optional[str] = None, + + created_at: Optional[str] = None, limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: @@ -843,37 +835,37 @@ async def search_graph_l_created_at( INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements created at created_at - - Optionally filters by group_id, apply_id, user_id + - Optionally filters by end_user_id, apply_id, user_id - Returns up to 'limit' statements """ statements = await connector.execute_query( SEARCH_STATEMENTS_L_CREATED_AT, - group_id=group_id, - apply_id=apply_id, - user_id=user_id, + end_user_id=end_user_id, + + created_at=created_at, limit=limit, ) - logger.debug(f"Search less than created_at query: {SEARCH_STATEMENTS_L_CREATED_AT}") - logger.debug(f"Query params: group_id={group_id}, 
apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}") - logger.debug(f"Search results: {len(statements)} statements found") + print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}") + print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}") + print(f"查询结果为:\n{statements}") # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( connector=connector, results=results, - group_id=group_id + end_user_id=end_user_id ) return results async def search_graph_l_valid_at( connector: Neo4jConnector, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, + end_user_id: Optional[str] = None, + + valid_at: Optional[str] = None, limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: @@ -883,28 +875,28 @@ async def search_graph_l_valid_at( INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements valid at valid_at - - Optionally filters by group_id, apply_id, user_id + - Optionally filters by end_user_id, apply_id, user_id - Returns up to 'limit' statements """ statements = await connector.execute_query( SEARCH_STATEMENTS_L_VALID_AT, - group_id=group_id, - apply_id=apply_id, - user_id=user_id, + end_user_id=end_user_id, + + valid_at=valid_at, limit=limit, ) - logger.debug(f"Search less than valid_at query: {SEARCH_STATEMENTS_L_VALID_AT}") - logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}") - logger.debug(f"Search results: {len(statements)} statements found") + print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}") + print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}") + print(f"查询结果为:\n{statements}") # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( connector=connector, results=results, - group_id=group_id + 
end_user_id=end_user_id ) return results diff --git a/api/app/repositories/neo4j/memory_summary_repository.py b/api/app/repositories/neo4j/memory_summary_repository.py index fc743f33..d7cd4fd4 100644 --- a/api/app/repositories/neo4j/memory_summary_repository.py +++ b/api/app/repositories/neo4j/memory_summary_repository.py @@ -18,7 +18,7 @@ class MemorySummaryRepository(BaseNeo4jRepository): """Memory Summary Repository Manages CRUD operations for MemorySummary nodes. - Provides methods to query summaries by group_id, user_id, and time ranges. + Provides methods to query summaries by end_user_id, user_id, and time ranges. Attributes: connector: Neo4j connector instance @@ -51,17 +51,17 @@ class MemorySummaryRepository(BaseNeo4jRepository): return dict(n) - async def find_by_group_id( + async def find_by_end_user_id( self, - group_id: str, + end_user_id: str, limit: int = 1000, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None ) -> List[Dict[str, Any]]: - """Query memory summaries by group_id + """Query memory summaries by end_user_id Args: - group_id: Group ID to filter by + end_user_id: Group ID to filter by limit: Maximum number of results to return start_date: Optional start date filter end_date: Optional end date filter @@ -71,10 +71,10 @@ class MemorySummaryRepository(BaseNeo4jRepository): """ query = f""" MATCH (n:{self.node_label}) - WHERE n.group_id = $group_id + WHERE n.end_user_id = $end_user_id """ - params = {"group_id": group_id, "limit": limit} + params = {"end_user_id": end_user_id, "limit": limit} # Add date range filters if provided if start_date: @@ -139,16 +139,16 @@ class MemorySummaryRepository(BaseNeo4jRepository): async def find_by_group_and_user( self, - group_id: str, + end_user_id: str, user_id: str, limit: int = 1000, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None ) -> List[Dict[str, Any]]: - """Query memory summaries by both group_id and user_id + """Query memory summaries by both 
end_user_id and user_id Args: - group_id: Group ID to filter by + end_user_id: Group ID to filter by user_id: User ID to filter by limit: Maximum number of results to return start_date: Optional start date filter @@ -159,10 +159,10 @@ class MemorySummaryRepository(BaseNeo4jRepository): """ query = f""" MATCH (n:{self.node_label}) - WHERE n.group_id = $group_id AND n.user_id = $user_id + WHERE n.end_user_id = $end_user_id AND n.user_id = $user_id """ - params = {"group_id": group_id, "user_id": user_id, "limit": limit} + params = {"end_user_id": end_user_id, "user_id": user_id, "limit": limit} # Add date range filters if provided if start_date: @@ -184,14 +184,14 @@ class MemorySummaryRepository(BaseNeo4jRepository): async def find_recent_summaries( self, - group_id: str, + end_user_id: str, days: int = 7, limit: int = 1000 ) -> List[Dict[str, Any]]: """Query recent memory summaries Args: - group_id: Group ID to filter by + end_user_id: Group ID to filter by days: Number of recent days to query limit: Maximum number of results to return @@ -200,7 +200,7 @@ class MemorySummaryRepository(BaseNeo4jRepository): """ query = f""" MATCH (n:{self.node_label}) - WHERE n.group_id = $group_id + WHERE n.end_user_id = $end_user_id AND n.created_at >= datetime() - duration({{days: $days}}) RETURN n ORDER BY n.created_at DESC @@ -209,7 +209,7 @@ class MemorySummaryRepository(BaseNeo4jRepository): results = await self.connector.execute_query( query, - group_id=group_id, + end_user_id=end_user_id, days=days, limit=limit ) @@ -217,14 +217,14 @@ class MemorySummaryRepository(BaseNeo4jRepository): async def find_by_content_keywords( self, - group_id: str, + end_user_id: str, keywords: List[str], limit: int = 100 ) -> List[Dict[str, Any]]: """Query memory summaries by content keywords Args: - group_id: Group ID to filter by + end_user_id: Group ID to filter by keywords: List of keywords to search for in content limit: Maximum number of results to return @@ -233,7 +233,7 @@ class 
MemorySummaryRepository(BaseNeo4jRepository): """ # Build keyword search conditions keyword_conditions = [] - params = {"group_id": group_id, "limit": limit} + params = {"end_user_id": end_user_id, "limit": limit} for i, keyword in enumerate(keywords): keyword_conditions.append(f"toLower(n.content) CONTAINS toLower($keyword_{i})") @@ -243,7 +243,7 @@ class MemorySummaryRepository(BaseNeo4jRepository): query = f""" MATCH (n:{self.node_label}) - WHERE n.group_id = $group_id + WHERE n.end_user_id = $end_user_id AND ({keyword_filter}) RETURN n ORDER BY n.created_at DESC @@ -253,21 +253,21 @@ class MemorySummaryRepository(BaseNeo4jRepository): results = await self.connector.execute_query(query, **params) return [self._map_to_dict(r) for r in results] - async def get_summary_count_by_group(self, group_id: str) -> int: + async def get_summary_count_by_group(self, end_user_id: str) -> int: """Get count of memory summaries for a group Args: - group_id: Group ID to count summaries for + end_user_id: Group ID to count summaries for Returns: int: Number of memory summaries """ query = f""" MATCH (n:{self.node_label}) - WHERE n.group_id = $group_id + WHERE n.end_user_id = $end_user_id RETURN count(n) as count """ - results = await self.connector.execute_query(query, group_id=group_id) + results = await self.connector.execute_query(query, end_user_id=end_user_id) return results[0]['count'] if results else 0 \ No newline at end of file diff --git a/api/app/repositories/neo4j/neo4j_connector.py b/api/app/repositories/neo4j/neo4j_connector.py index 7c4b43b5..d96e4431 100644 --- a/api/app/repositories/neo4j/neo4j_connector.py +++ b/api/app/repositories/neo4j/neo4j_connector.py @@ -70,11 +70,7 @@ class Neo4jConnector: List[Dict[str, Any]]: 查询结果列表,每个元素是一个字典 Example: - >>> connector = Neo4jConnector() - >>> results = await connector.execute_query( - ... "MATCH (n:Person {name: $name}) RETURN n", - ... name="Alice" - ... 
) + """ result = await self.driver.execute_query( query, @@ -98,17 +94,7 @@ class Neo4jConnector: Any: 事务函数的返回值 Example: - >>> async def create_node(tx, name): - ... result = await tx.run( - ... "CREATE (n:Person {name: $name}) RETURN n", - ... name=name - ... ) - ... return await result.single() - >>> - >>> connector = Neo4jConnector() - >>> result = await connector.execute_write_transaction( - ... create_node, name="Alice" - ... ) + """ async with self.driver.session(database="neo4j") as session: return await session.execute_write(transaction_func, **kwargs) @@ -126,45 +112,33 @@ class Neo4jConnector: Any: 事务函数的返回值 Example: - >>> async def get_node(tx, name): - ... result = await tx.run( - ... "MATCH (n:Person {name: $name}) RETURN n", - ... name=name - ... ) - ... return await result.single() - >>> - >>> connector = Neo4jConnector() - >>> result = await connector.execute_read_transaction( - ... get_node, name="Alice" - ... ) + """ async with self.driver.session(database="neo4j") as session: return await session.execute_read(transaction_func, **kwargs) - async def delete_group(self, group_id: str): + async def delete_group(self, end_user_id: str): """删除指定组的所有数据 - 删除所有属于指定group_id的节点和边。 + 删除所有属于指定end_user_id的节点和边。 这是一个危险操作,会永久删除数据。 Args: - group_id: 要删除的组ID + end_user_id: 要删除的组ID Example: - >>> connector = Neo4jConnector() - >>> await connector.delete_group("group_123") Group group_123 deleted. 
""" # 删除节点(DETACH DELETE会同时删除相关的边) await self.driver.execute_query( - "MATCH (n) WHERE n.group_id = $group_id DETACH DELETE n", + "MATCH (n) WHERE n.end_user_id = $end_user_id DETACH DELETE n", database="neo4j", - group_id=group_id + end_user_id=end_user_id ) # 删除独立的边(如果有的话) await self.driver.execute_query( - "MATCH ()-[r]->() WHERE r.group_id = $group_id DELETE r", + "MATCH ()-[r]->() WHERE r.end_user_id = $end_user_id DELETE r", database="neo4j", - group_id=group_id + end_user_id=end_user_id ) - print(f"Group {group_id} deleted.") + print(f"Group {end_user_id} deleted.") diff --git a/api/app/repositories/neo4j/statement_repository.py b/api/app/repositories/neo4j/statement_repository.py index cd9f2fac..4f12af83 100644 --- a/api/app/repositories/neo4j/statement_repository.py +++ b/api/app/repositories/neo4j/statement_repository.py @@ -20,7 +20,7 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]): """陈述句仓储 管理陈述句节点的创建、查询、更新和删除操作。 - 提供按chunk_id、group_id、向量相似度等条件查询陈述句的方法。 + 提供按chunk_id、end_user_id、向量相似度等条件查询陈述句的方法。 Attributes: connector: Neo4j连接器实例 diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 35d2e424..09410091 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -299,6 +299,18 @@ class AppRelease(BaseModel): created_at: datetime.datetime updated_at: datetime.datetime + @field_validator("config", mode="before") + @classmethod + def parse_config(cls, v): + """处理 config 字段,如果是字符串则解析为字典""" + if isinstance(v, str): + import json + try: + return json.loads(v) + except json.JSONDecodeError: + return {} + return v if v is not None else {} + @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None diff --git a/api/app/schemas/emotion_schema.py b/api/app/schemas/emotion_schema.py index c48fbd41..13c802b5 100644 --- a/api/app/schemas/emotion_schema.py +++ b/api/app/schemas/emotion_schema.py @@ -1,11 
+1,12 @@ """情绪分析相关的请求和响应模型""" from typing import Optional +from uuid import UUID from pydantic import BaseModel, Field class EmotionTagsRequest(BaseModel): """获取情绪标签统计请求""" - group_id: str = Field(..., description="组ID") + end_user_id: str = Field(..., description="组ID") emotion_type: Optional[str] = Field(None, description="情绪类型过滤(joy/sadness/anger/fear/surprise/neutral)") start_date: Optional[str] = Field(None, description="开始日期(ISO格式,如:2024-01-01)") end_date: Optional[str] = Field(None, description="结束日期(ISO格式,如:2024-12-31)") @@ -14,14 +15,14 @@ class EmotionTagsRequest(BaseModel): class EmotionWordcloudRequest(BaseModel): """获取情绪词云数据请求""" - group_id: str = Field(..., description="组ID") + end_user_id: str = Field(..., description="组ID") emotion_type: Optional[str] = Field(None, description="情绪类型过滤(joy/sadness/anger/fear/surprise/neutral)") limit: int = Field(50, ge=1, le=200, description="返回词语数量") class EmotionHealthRequest(BaseModel): """获取情绪健康指数请求""" - group_id: str = Field(..., description="组ID") + end_user_id: str = Field(..., description="组ID") time_range: str = Field("30d", description="时间范围(7d/30d/90d)") @@ -29,8 +30,8 @@ class EmotionHealthRequest(BaseModel): class EmotionSuggestionsRequest(BaseModel): """获取个性化情绪建议请求""" - group_id: str = Field(..., description="组ID") - config_id: Optional[int] = Field(None, description="配置ID(用于指定LLM模型)") + end_user_id: str = Field(..., description="组ID") + config_id: Optional[UUID] = Field(None, description="配置ID(用于指定LLM模型)") class EmotionGenerateSuggestionsRequest(BaseModel): diff --git a/api/app/schemas/memory_agent_schema.py b/api/app/schemas/memory_agent_schema.py index d4354c40..b6f50dd7 100644 --- a/api/app/schemas/memory_agent_schema.py +++ b/api/app/schemas/memory_agent_schema.py @@ -7,11 +7,11 @@ class UserInput(BaseModel): message: str history: list[dict] search_switch: str - group_id: str + end_user_id: str config_id: Optional[str] = None class Write_UserInput(BaseModel): messages: list[dict] - group_id: str - 
config_id: Optional[str] = None + end_user_id: str + config_id: Optional[str] = None \ No newline at end of file diff --git a/api/app/schemas/memory_config_schema.py b/api/app/schemas/memory_config_schema.py index 0443dcc4..76acee5c 100644 --- a/api/app/schemas/memory_config_schema.py +++ b/api/app/schemas/memory_config_schema.py @@ -35,7 +35,7 @@ class ConfigurationError(Exception): def __init__( self, message: str, - config_id: Optional[int] = None, + config_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None, context: Optional[Dict[str, Any]] = None, ): @@ -72,7 +72,7 @@ class WorkspaceNotFoundError(ConfigurationError): def __init__( self, workspace_id: UUID, - config_id: Optional[int] = None, + config_id: Optional[UUID] = None, message: Optional[str] = None, ): if message is None: @@ -89,7 +89,7 @@ class ModelNotFoundError(ConfigurationError): self, model_id: Union[str, UUID], model_type: str, - config_id: Optional[int] = None, + config_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None, message: Optional[str] = None, ): @@ -112,7 +112,7 @@ class ModelInactiveError(ConfigurationError): model_id: Union[str, UUID], model_name: str, model_type: str, - config_id: Optional[int] = None, + config_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None, message: Optional[str] = None, ): @@ -136,7 +136,7 @@ class InvalidConfigError(ConfigurationError): message: str, field_name: Optional[str] = None, invalid_value: Optional[Any] = None, - config_id: Optional[int] = None, + config_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None, ): context = {} @@ -155,7 +155,7 @@ class InvalidConfigError(ConfigurationError): class MemoryConfigValidation(BaseModel): """Pydantic model for validating memory configuration data from database.""" - config_id: int = Field(..., gt=0, description="Configuration ID must be positive") + config_id: UUID = Field(..., description="Configuration ID (UUID)") config_name: str = Field(..., min_length=1, 
max_length=255) workspace_id: UUID = Field(..., description="Workspace UUID") workspace_name: str = Field(..., min_length=1, max_length=255) @@ -275,7 +275,7 @@ class ModelValidation(BaseModel): def validate_memory_config_data( - config_data: Dict[str, Any], config_id: Optional[int] = None + config_data: Dict[str, Any], config_id: Optional[UUID] = None ) -> MemoryConfigValidation: """Validate memory configuration data using Pydantic model.""" try: @@ -302,7 +302,7 @@ def validate_memory_config_data( def validate_workspace_data( - workspace_data: Dict[str, Any], config_id: Optional[int] = None + workspace_data: Dict[str, Any], config_id: Optional[UUID] = None ) -> WorkspaceValidation: """Validate workspace data using Pydantic model.""" try: @@ -331,7 +331,7 @@ def validate_workspace_data( def validate_model_data( - model_data: Dict[str, Any], config_id: Optional[int] = None + model_data: Dict[str, Any], config_id: Optional[UUID] = None ) -> ModelValidation: """Validate model data using Pydantic model.""" try: @@ -364,7 +364,7 @@ def validate_model_data( class MemoryConfig: """Immutable memory configuration loaded from database.""" - config_id: int + config_id: UUID config_name: str workspace_id: UUID workspace_name: str diff --git a/api/app/schemas/memory_perceptual_schema.py b/api/app/schemas/memory_perceptual_schema.py index 05e01d2a..7dfefe01 100644 --- a/api/app/schemas/memory_perceptual_schema.py +++ b/api/app/schemas/memory_perceptual_schema.py @@ -4,7 +4,7 @@ from typing import Optional from pydantic import BaseModel, Field -from app.models.memory_perceptual_model import PerceptualType, FileStorageType +from app.models.memory_perceptual_model import PerceptualType, FileStorageService class PerceptualFilter(BaseModel): @@ -38,12 +38,14 @@ class PerceptualMemoryItem(BaseModel): """感知记忆项""" id: uuid.UUID = Field(..., description="Unique memory ID") perceptual_type: PerceptualType = Field(..., description="Type of perception, e.g., text, audio, or video") + 
storage_service: FileStorageService = Field(..., description="Storage service for file") file_path: str = Field(..., description="File path in the storage service") - file_ext: str = Field(..., description="File extension") file_name: str = Field(..., description="File name") + file_ext: str = Field(..., description="File extension") summary: Optional[str] = Field(None, description="summary") - storage_type: FileStorageType = Field(..., description="Storage type for file") + meta_data: Optional[dict] = Field(None, description="Metadata information") created_time: int = Field(..., description="create time") + topic: str = Field(..., description="topic") domain: str = Field(..., description="domain") keywords: list[str] = Field(..., description="keywords") diff --git a/api/app/schemas/memory_reflection_schemas.py b/api/app/schemas/memory_reflection_schemas.py index 860f1ef1..df841fb1 100644 --- a/api/app/schemas/memory_reflection_schemas.py +++ b/api/app/schemas/memory_reflection_schemas.py @@ -1,5 +1,6 @@ from pydantic import BaseModel, Field from typing import Optional +from uuid import UUID from enum import Enum @@ -9,7 +10,7 @@ class OptimizationStrategy(str, Enum): ACCURACY_FIRST = "accuracy_first" BALANCED = "balanced" class Memory_Reflection(BaseModel): - config_id: Optional[int] = None + config_id: Optional[UUID] = None reflection_enabled: bool reflection_period_in_hours: str reflexion_range: Optional[str] = "partial" diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index d17a9f2c..d9c04f8f 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -1,5 +1,5 @@ """ -所有的内容是放错误地方了,应该放在models + """ from typing import Any, Optional, List, Dict, Literal, Union @@ -8,20 +8,8 @@ import uuid from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator -# ============================================================================ -# 原 UserInput 相关 
Schema (保留原有功能) -# ============================================================================ -class UserInput(BaseModel): - message: str - history: list[dict] - search_switch: str - group_id: str -class Write_UserInput(BaseModel): - message: str - group_id: str - # ============================================================================ # 从 json_schema.py 迁移的 Schema @@ -159,7 +147,7 @@ class ReflexionResultSchema(BaseModel): # Composite key identifying a config row class ConfigKey(BaseModel): # 配置参数键模型 model_config = ConfigDict(populate_by_name=True, extra="forbid") - config_id: int = Field("config_id", description="配置唯一标识(字符串)") + config_id: uuid.UUID = Field("config_id", description="配置唯一标识(UUID)") user_id: str = Field("user_id", description="用户标识(字符串)") apply_id: str = Field("apply_id", description="应用或场景标识(字符串)") @@ -250,17 +238,17 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body, class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体) model_config = ConfigDict(populate_by_name=True, extra="forbid") # config_name: str = Field("配置名称", description="配置名称(字符串)") - config_id: int = Field("配置ID", description="配置ID(字符串)") + config_id: uuid.UUID = Field("配置ID", description="配置ID(UUID)") class ConfigUpdate(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 - config_id: Optional[int] = None + config_id: Optional[uuid.UUID] = None config_name: str = Field("配置名称", description="配置名称(字符串)") config_desc: str = Field("配置描述", description="配置描述(字符串)") class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数时使用的模型 - config_id: Optional[int] = None + config_id: Optional[uuid.UUID] = None llm_id: Optional[str] = Field(None, description="LLM模型配置ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") @@ -327,14 +315,14 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数 class ConfigUpdateForget(BaseModel): # 更新遗忘引擎配置参数时使用的模型 # 遗忘引擎配置参数更新模型 - config_id: Optional[int] = None + config_id: Optional[uuid.UUID] = 
None lambda_time: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="最低保持度,0-1 小数;默认 0.5") lambda_mem: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="遗忘率,0-1 小数;默认 0.5") offset: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="偏移度,0-1 小数;默认 0.0") class ConfigPilotRun(BaseModel): # 试运行触发请求模型 - config_id: int = Field(..., description="配置ID(唯一)") + config_id: uuid.UUID = Field(..., description="配置ID(唯一)") dialogue_text: str = Field(..., description="前端传入的对话文本,格式如 '用户: ...\nAI: ...' 可多行,试运行必填") model_config = ConfigDict(populate_by_name=True, extra="forbid") @@ -342,7 +330,7 @@ class ConfigPilotRun(BaseModel): # 试运行触发请求模型 class ConfigFilter(BaseModel): # 查询配置参数时使用的模型 model_config = ConfigDict(populate_by_name=True, extra="forbid") - config_id: Optional[int] = None + config_id: Optional[uuid.UUID] = None user_id: Optional[str] = None apply_id: Optional[str] = None @@ -418,7 +406,7 @@ class ForgettingConfigResponse(BaseModel): """遗忘引擎配置响应模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - config_id: int = Field(..., description="配置ID") + config_id: uuid.UUID = Field(..., description="配置ID") decay_constant: float = Field(..., description="衰减常数 d") lambda_time: float = Field(..., description="时间衰减参数") lambda_mem: float = Field(..., description="记忆衰减参数") @@ -436,7 +424,7 @@ class ForgettingConfigUpdateRequest(BaseModel): """遗忘引擎配置更新请求模型""" model_config = ConfigDict(populate_by_name=True, extra="forbid") - config_id: int = Field(..., description="配置ID") + config_id: uuid.UUID = Field(..., description="配置ID") decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d") lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数") lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数") @@ -511,7 +499,7 @@ class ForgettingCurveRequest(BaseModel): importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数(0-1)") days: int = Field(60, ge=1, le=365, 
description="模拟天数(默认60天)") - config_id: Optional[int] = Field(None, description="配置ID(可选,如果为None则使用默认配置)") + config_id: Optional[uuid.UUID] = Field(None, description="配置ID(可选,如果为None则使用默认配置)") class ForgettingCurveResponse(BaseModel): diff --git a/api/app/schemas/model_schema.py b/api/app/schemas/model_schema.py index 5b1fe6d9..68f15115 100644 --- a/api/app/schemas/model_schema.py +++ b/api/app/schemas/model_schema.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field, field_serializer, ConfigDict +from pydantic import BaseModel, Field, field_serializer, field_validator, ConfigDict from typing import Optional, List, Dict, Any import datetime import uuid @@ -91,6 +91,18 @@ class ModelApiKey(ModelApiKeyBase): created_at: datetime.datetime updated_at: datetime.datetime + @field_validator("config", mode="before") + @classmethod + def parse_config(cls, v): + """处理 config 字段,如果是字符串则解析为字典""" + if isinstance(v, str): + import json + try: + return json.loads(v) + except json.JSONDecodeError: + return {} + return v + @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None diff --git a/api/app/schemas/release_share_schema.py b/api/app/schemas/release_share_schema.py index 069b78a9..47897847 100644 --- a/api/app/schemas/release_share_schema.py +++ b/api/app/schemas/release_share_schema.py @@ -1,7 +1,7 @@ import uuid import datetime from typing import Optional, List, Dict, Any -from pydantic import BaseModel, Field, ConfigDict, field_serializer +from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator # ---------- Input Schemas ---------- @@ -88,6 +88,18 @@ class SharedReleaseInfo(BaseModel): # 嵌入配置 allow_embed: bool + @field_validator("config", mode="before") + @classmethod + def parse_config(cls, v): + """处理 config 字段,如果是字符串则解析为字典""" + if isinstance(v, str): + import json + try: + return json.loads(v) + except json.JSONDecodeError: + return {} + 
return v if v is not None else {} + class EmbedCode(BaseModel): """嵌入代码""" diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 4f20f6d9..9766eec0 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -92,7 +92,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str try: memory_content = asyncio.run( MemoryAgentService().read_memory( - group_id=end_user_id, + end_user_id=end_user_id, message=question, history=[], search_switch="2", diff --git a/api/app/services/emotion_analytics_service.py b/api/app/services/emotion_analytics_service.py index 601d2921..af98fb52 100644 --- a/api/app/services/emotion_analytics_service.py +++ b/api/app/services/emotion_analytics_service.py @@ -75,7 +75,7 @@ class EmotionAnalyticsService: # 调用仓储层查询 tags = await self.emotion_repo.get_emotion_tags( - group_id=end_user_id, + end_user_id=end_user_id, emotion_type=emotion_type, start_date=start_date, end_date=end_date, @@ -157,7 +157,7 @@ class EmotionAnalyticsService: # 调用仓储层查询 keywords = await self.emotion_repo.get_emotion_wordcloud( - group_id=end_user_id, + end_user_id=end_user_id, emotion_type=emotion_type, limit=limit ) @@ -339,7 +339,7 @@ class EmotionAnalyticsService: # 获取时间范围内的情绪数据 emotions = await self.emotion_repo.get_emotions_in_range( - group_id=end_user_id, + end_user_id=end_user_id, time_range=time_range ) @@ -505,7 +505,7 @@ class EmotionAnalyticsService: ) config_service = MemoryConfigService(db) memory_config = config_service.load_memory_config( - config_id=int(config_id), + config_id=(config_id), service_name="EmotionAnalyticsService.generate_emotion_suggestions" ) from app.core.memory.utils.llm.llm_utils import MemoryClientFactory @@ -519,7 +519,7 @@ class EmotionAnalyticsService: # 3. 
获取情绪数据用于模式分析 emotions = await self.emotion_repo.get_emotions_in_range( - group_id=end_user_id, + end_user_id=end_user_id, time_range="30d" ) @@ -598,13 +598,13 @@ class EmotionAnalyticsService: # 查询用户的实体和标签 query = """ MATCH (e:Entity) - WHERE e.group_id = $group_id + WHERE e.end_user_id = $end_user_id RETURN e.name as name, e.type as type ORDER BY e.created_at DESC LIMIT 20 """ - entities = await connector.execute_query(query, group_id=end_user_id) + entities = await connector.execute_query(query, end_user_id=end_user_id) # 提取兴趣标签 interests = [e["name"] for e in entities if e.get("type") in ["INTEREST", "HOBBY"]][:5] diff --git a/api/app/services/emotion_config_service.py b/api/app/services/emotion_config_service.py index 37171640..9880d4e1 100644 --- a/api/app/services/emotion_config_service.py +++ b/api/app/services/emotion_config_service.py @@ -8,9 +8,11 @@ Classes: """ from typing import Dict, Any +from uuid import UUID + from sqlalchemy.orm import Session -from app.models.data_config_model import DataConfig +from app.models.memory_config_model import MemoryConfig from app.core.logging_config import get_business_logger logger = get_business_logger() @@ -37,7 +39,7 @@ class EmotionConfigService: self.db = db logger.info("情绪配置服务初始化完成") - def get_emotion_config(self, config_id: int) -> Dict[str, Any]: + def get_emotion_config(self, config_id: UUID) -> Dict[str, Any]: """获取情绪引擎配置 查询指定配置ID的情绪相关配置字段。 @@ -61,8 +63,8 @@ class EmotionConfigService: logger.info(f"获取情绪配置: config_id={config_id}") # 查询配置 - config = self.db.query(DataConfig).filter( - DataConfig.config_id == config_id + config = self.db.query(MemoryConfig).filter( + MemoryConfig.config_id == config_id ).first() if not config: @@ -144,7 +146,7 @@ class EmotionConfigService: def update_emotion_config( self, - config_id: int, + config_id: UUID, config_data: Dict[str, Any] ) -> Dict[str, Any]: """更新情绪引擎配置 @@ -173,8 +175,8 @@ class EmotionConfigService: self.validate_emotion_config(config_data) # 查询配置 - config = 
self.db.query(DataConfig).filter( - DataConfig.config_id == config_id + config = self.db.query(MemoryConfig).filter( + MemoryConfig.config_id == config_id ).first() if not config: diff --git a/api/app/services/emotion_extraction_service.py b/api/app/services/emotion_extraction_service.py index d134251d..6b596a80 100644 --- a/api/app/services/emotion_extraction_service.py +++ b/api/app/services/emotion_extraction_service.py @@ -14,7 +14,7 @@ from app.core.memory.llm_tools.llm_client import LLMClientException from app.core.memory.models.emotion_models import EmotionExtraction from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context -from app.models.data_config_model import DataConfig +from app.models.memory_config_model import MemoryConfig logger = logging.getLogger(__name__) @@ -60,7 +60,7 @@ class EmotionExtractionService: async def extract_emotion( self, statement: str, - config: DataConfig + config: MemoryConfig ) -> Optional[EmotionExtraction]: """Extract emotion information from a statement. 
diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 1e1cde89..6e72a53f 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -9,6 +9,7 @@ import os import re import time import uuid +from uuid import UUID from typing import Any, AsyncGenerator, Dict, List, Optional import redis @@ -27,6 +28,7 @@ from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType +from app.repositories.memory_short_repository import ShortTermMemoryRepository from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_agent_schema import Write_UserInput from app.schemas.memory_config_schema import ConfigurationError @@ -35,6 +37,7 @@ from app.services.memory_config_service import MemoryConfigService from app.services.memory_konwledges_server import ( write_rag, ) +from langchain_core.messages import AIMessage from langchain_core.messages import HumanMessage from pydantic import BaseModel, Field from sqlalchemy import func @@ -54,25 +57,24 @@ _neo4j_connector = Neo4jConnector() class MemoryAgentService: """Service for memory agent operations""" - def writer_messages_deal(self, messages, start_time, group_id, config_id, message, context): + def writer_messages_deal(self, messages, start_time, end_user_id, config_id, message, context): duration = time.time() - start_time - if str(messages) == 'success': - logger.info(f"Write operation successful for group {group_id} with config_id {config_id}") + logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}") # 记录成功的操作 if audit_logger: - audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=True, + audit_logger.log_operation(operation="WRITE", config_id=config_id, 
end_user_id=end_user_id, success=True, duration=duration, details={"message_length": len(message)}) return context else: - logger.warning(f"Write operation failed for group {group_id}") + logger.warning(f"Write operation failed for group {end_user_id}") # 记录失败的操作 if audit_logger: audit_logger.log_operation( operation="WRITE", config_id=config_id, - group_id=group_id, + end_user_id=end_user_id, success=False, duration=duration, error=f"写入失败: {messages[:100]}" @@ -263,13 +265,13 @@ class MemoryAgentService: logger.info("Log streaming completed, cleaning up resources") # LogStreamer uses context manager for file handling, so cleanup is automatic - async def write_memory(self, group_id: str, messages: list[dict], config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str: + async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID], db: Session, storage_type: str, user_rag_memory_id: str) -> str: """ Process write operation with config_id Args: - group_id: Group identifier (also used as end_user_id) - messages: Structured message list [{"role": "user", "content": "..."}, ...] + end_user_id: Group identifier (also used as end_user_id) + message: Message to write config_id: Configuration ID from database db: SQLAlchemy database session storage_type: Storage type (neo4j or rag) @@ -284,15 +286,15 @@ class MemoryAgentService: # Resolve config_id if None using end_user's connected config if config_id is None: try: - connected_config = get_end_user_connected_config(group_id, db) + connected_config = get_end_user_connected_config(end_user_id, db) config_id = connected_config.get("memory_config_id") if config_id is None: - raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.") + raise ValueError(f"No memory configuration found for end_user {end_user_id}. 
Please ensure the user has a connected memory configuration.") except Exception as e: if "No memory configuration found" in str(e): - raise - logger.error(f"Failed to get connected config for end_user {group_id}: {e}") - raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}") + raise # Re-raise our specific error + logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}") + raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}") import time start_time = time.time() @@ -312,7 +314,7 @@ class MemoryAgentService: # Log failed operation if audit_logger: duration = time.time() - start_time - audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg) + audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg) raise ValueError(error_msg) @@ -320,24 +322,23 @@ class MemoryAgentService: if storage_type == "rag": # For RAG storage, convert messages to single string message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) - result = await write_rag(group_id, message_text, user_rag_memory_id) + result = await write_rag(end_user_id, message_text, user_rag_memory_id) return result else: async with make_write_graph() as graph: - config = {"configurable": {"thread_id": group_id}} + config = {"configurable": {"thread_id": end_user_id}} # Convert structured messages to LangChain messages langchain_messages = [] for msg in messages: if msg['role'] == 'user': langchain_messages.append(HumanMessage(content=msg['content'])) elif msg['role'] == 'assistant': - from langchain_core.messages import AIMessage langchain_messages.append(AIMessage(content=msg['content'])) - + # 初始状态 - 包含所有必要字段 initial_state = { "messages": langchain_messages, - "group_id": group_id, + "end_user_id": end_user_id, "memory_config": 
memory_config } @@ -354,14 +355,14 @@ class MemoryAgentService: contents = massages.get('write_result') # Convert messages back to string for logging message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) - return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message_text, contents) + return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents) except Exception as e: # Ensure proper error handling and logging error_msg = f"Write operation failed: {str(e)}" logger.error(error_msg) if audit_logger: duration = time.time() - start_time - audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg) + audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg) raise ValueError(error_msg) @@ -369,15 +370,14 @@ class MemoryAgentService: async def read_memory( self, - group_id: str, + end_user_id: str, message: str, history: List[Dict], search_switch: str, - config_id: Optional[str], + config_id: Optional[UUID], db: Session, storage_type: str, - user_rag_memory_id: str - ) -> Dict: + user_rag_memory_id: str) -> Dict: """ Process read operation with config_id @@ -387,7 +387,7 @@ class MemoryAgentService: - "2": Direct answer based on context Args: - group_id: Group identifier (also used as end_user_id) + end_user_id: Group identifier (also used as end_user_id) message: User message history: Conversation history search_switch: Search mode switch @@ -405,22 +405,22 @@ class MemoryAgentService: import time start_time = time.time() - logger.info(f"[PERF] read_memory started for group_id={group_id}, search_switch={search_switch}") + ori_message= message # Resolve config_id if None using end_user's connected config if config_id is None: try: - config_id = get_end_user_connected_config(group_id, db) - 
config_id=config_id.get('memory_config_id') + connected_config = get_end_user_connected_config(end_user_id, db) + config_id = connected_config.get("memory_config_id") if config_id is None: - raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.") + raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") except Exception as e: if "No memory configuration found" in str(e): raise # Re-raise our specific error - logger.error(f"Failed to get connected config for end_user {group_id}: {e}") - raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}") + logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}") + raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}") - logger.info(f"Read operation for group {group_id} with config_id {config_id}") + logger.info(f"Read operation for group {end_user_id} with config_id {config_id}") # 导入审计日志记录器 try: @@ -448,7 +448,7 @@ class MemoryAgentService: audit_logger.log_operation( operation="READ", config_id=config_id, - group_id=group_id, + end_user_id=end_user_id, success=False, duration=duration, error=error_msg @@ -458,16 +458,16 @@ class MemoryAgentService: # Step 2: Prepare history history.append({"role": "user", "content": message}) - logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}") + logger.debug(f"Group ID:{end_user_id}, Message:{message}, History:{history}, Config ID:{config_id}") # Step 3: Initialize MCP client and execute read workflow graph_exec_start = time.time() try: async with make_read_graph() as graph: - config = {"configurable": {"thread_id": group_id}} + config = {"configurable": {"thread_id": end_user_id}} # 初始状态 - 包含所有必要字段 initial_state = {"messages": [HumanMessage(content=message)], "search_switch": 
search_switch, - "group_id": group_id + "end_user_id": end_user_id , "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id, "memory_config": memory_config} # 获取节点更新信息 @@ -562,13 +562,13 @@ class MemoryAgentService: if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2": # 使用 upsert 方法 repo.upsert( - end_user_id=group_id, - messages=message, + end_user_id=end_user_id, + messages=ori_message, aimessages=summary, retrieved_content=retrieved_content, search_switch=str(search_switch) ) - logger.info(f"成功保存短期记忆: group_id={group_id}, search_switch={search_switch}") + logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}") else: logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}") @@ -584,7 +584,7 @@ class MemoryAgentService: audit_logger.log_operation( operation="READ", config_id=config_id, - group_id=group_id, + end_user_id=end_user_id, success=True, duration=duration ) @@ -596,20 +596,20 @@ class MemoryAgentService: except Exception as e: # Ensure proper error handling and logging error_msg = f"Read operation failed: {str(e)}" - total_time = time.time() - start_time - logger.error(f"[PERF] read_memory failed after {total_time:.4f}s: {error_msg}") + logger.error(error_msg) if audit_logger: duration = time.time() - start_time audit_logger.log_operation( operation="READ", config_id=config_id, - group_id=group_id, + end_user_id=end_user_id, success=False, duration=duration, error=error_msg ) raise ValueError(error_msg) + def get_messages_list(self, user_input: Write_UserInput) -> list[dict]: """ Get standardized message list from user input. 
@@ -654,7 +654,7 @@ class MemoryAgentService: logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}") return user_input.messages - async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict: + async def classify_message_type(self, message: str, config_id: UUID, db: Session) -> Dict: """ Determine the type of user message (read or write) Updated to eliminate global variables in favor of explicit parameters. @@ -681,10 +681,9 @@ class MemoryAgentService: status = await status_typle(message, memory_config.llm_model_id) logger.debug(f"Message type: {status}") return status - async def generate_summary_from_retrieve( self, - group_id: str, + end_user_id: str, retrieve_info: str, history: List[Dict], query: str, @@ -708,16 +707,16 @@ class MemoryAgentService: """ if config_id is None: try: - config_id = get_end_user_connected_config(group_id, db) + config_id = get_end_user_connected_config(end_user_id, db) config_id = config_id.get('memory_config_id') if config_id is None: raise ValueError( - f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.") + f"No memory configuration found for end_user {end_user_id}. 
Please ensure the user has a connected memory configuration.") except Exception as e: if "No memory configuration found" in str(e): raise # Re-raise our specific error - logger.error(f"Failed to get connected config for end_user {group_id}: {e}") - raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}") + logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}") + raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}") logger.info(f"Generating summary from retrieve info for query: {query[:50]}...") try: @@ -727,6 +726,7 @@ class MemoryAgentService: config_id=config_id, service_name="MemoryAgentService" ) + # 导入必要的模块 from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import summary_llm from app.core.memory.agent.models.summary_models import RetrieveSummaryResponse @@ -766,7 +766,7 @@ class MemoryAgentService: """ 统计知识库类型分布,包含: 1. PostgreSQL 中的知识库类型:General, Web, Third-party, Folder(根据 workspace_id 过滤) - 2. Neo4j 中的 memory 类型(仅统计 Chunk 数量,根据 end_user_id/group_id 过滤) + 2. Neo4j 中的 memory 类型(仅统计 Chunk 数量,根据 end_user_id/end_user_id 过滤) 3. 
total: 所有类型的总和 参数: @@ -852,11 +852,11 @@ class MemoryAgentService: for end_user in end_users: end_user_id_str = str(end_user.id) memory_query = """ - MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN count(n) AS Count + MATCH (n:Chunk) WHERE n.end_user_id = $end_user_id RETURN count(n) AS Count """ neo4j_result = await _neo4j_connector.execute_query( memory_query, - group_id=end_user_id_str, + end_user_id=end_user_id_str, ) chunk_count = neo4j_result[0]["Count"] if neo4j_result else 0 total_chunks += chunk_count @@ -896,7 +896,7 @@ class MemoryAgentService: 获取指定用户的热门记忆标签 参数: - - end_user_id: 用户ID(可选),对应Neo4j中的group_id字段 + - end_user_id: 用户ID(可选),对应Neo4j中的end_user_id字段 - limit: 返回标签数量限制 返回格式: @@ -906,7 +906,7 @@ class MemoryAgentService: ] """ try: - # by_user=False 表示按 group_id 查询(在Neo4j中,group_id就是用户维度) + # by_user=False 表示按 end_user_id 查询(在Neo4j中,end_user_id就是用户维度) tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False) payload=[] for tag, freq in tags: @@ -981,21 +981,21 @@ class MemoryAgentService: # 查询该用户的语句 query = ( "MATCH (s:Statement) " - "WHERE ($group_id IS NULL OR s.group_id = $group_id) AND s.statement IS NOT NULL " + "WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id) AND s.statement IS NOT NULL " "RETURN s.statement AS statement " "ORDER BY s.created_at DESC LIMIT 100" ) - rows = await connector.execute_query(query, group_id=end_user_id) + rows = await connector.execute_query(query, end_user_id=end_user_id) statements = [r.get("statement", "") for r in rows if r.get("statement")] # 查询该用户的热门实体 entity_query = ( "MATCH (e:ExtractedEntity) " - "WHERE ($group_id IS NULL OR e.group_id = $group_id) AND e.entity_type <> '人物' AND e.name IS NOT NULL " + "WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) AND e.entity_type <> '人物' AND e.name IS NOT NULL " "RETURN e.name AS name, count(e) AS frequency " "ORDER BY frequency DESC LIMIT 20" ) - entity_rows = await connector.execute_query(entity_query, group_id=end_user_id) + 
entity_rows = await connector.execute_query(entity_query, end_user_id=end_user_id) entities = [f"{r['name']} ({r['frequency']})" for r in entity_rows] await connector.close() @@ -1048,14 +1048,14 @@ class MemoryAgentService: names_to_exclude = ['AI', 'Caroline', 'Melanie', 'Jon', 'Gina', '用户', 'AI助手', 'John', 'Maria'] hot_tag_query = ( "MATCH (e:ExtractedEntity) " - "WHERE ($group_id IS NULL OR e.group_id = $group_id) AND e.entity_type <> '人物' " + "WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) AND e.entity_type <> '人物' " "AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude " "RETURN e.name AS name, count(e) AS frequency " "ORDER BY frequency DESC LIMIT 4" ) hot_tag_rows = await connector.execute_query( hot_tag_query, - group_id=end_user_id, + end_user_id=end_user_id, names_to_exclude=names_to_exclude ) await connector.close() @@ -1189,6 +1189,16 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An # 3. 从 config 中提取 memory_config_id config = latest_release.config or {} + + # 如果 config 是字符串,解析为字典 + if isinstance(config, str): + import json + try: + config = json.loads(config) + except json.JSONDecodeError: + logger.warning(f"Failed to parse config JSON for release {latest_release.id}") + config = {} + memory_obj = config.get('memory', {}) memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None @@ -1227,7 +1237,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) """ from app.models.app_release_model import AppRelease from app.models.end_user_model import EndUser - from app.models.data_config_model import DataConfig + from app.models.memory_config_model import MemoryConfig from sqlalchemy import select logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users") @@ -1240,10 +1250,10 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 1. 
批量查询所有 end_user 及其 app_id end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all() - + # 创建 end_user_id -> app_id 的映射 user_to_app = {str(eu.id): eu.app_id for eu in end_users} - + # 记录未找到的用户 found_user_ids = set(user_to_app.keys()) missing_user_ids = set(end_user_ids) - found_user_ids @@ -1285,13 +1295,13 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 批量查询 memory_config_name config_id_to_name = {} if memory_config_ids: - memory_configs = db.query(DataConfig).filter(DataConfig.config_id.in_(memory_config_ids)).all() - config_id_to_name = {str(mc.config_id): mc.config_name for mc in memory_configs} + memory_configs = db.query(MemoryConfig).filter(MemoryConfig.id.in_(memory_config_ids)).all() + config_id_to_name = {str(mc.id): mc.config_name for mc in memory_configs} # 4. 构建最终结果 for end_user_id, app_id in user_to_app.items(): release = app_to_release.get(app_id) - + if not release: logger.warning(f"No active release found for app: {app_id} (end_user: {end_user_id})") result[end_user_id] = {"memory_config_id": None, "memory_config_name": None} @@ -1303,7 +1313,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None # 获取配置名称 - memory_config_name = config_id_to_name.get(str(memory_config_id)) if memory_config_id else None + memory_config_name = config_id_to_name.get(memory_config_id) if memory_config_id else None result[end_user_id] = { "memory_config_id": memory_config_id, diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index 2d3d047e..a8c39a5a 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -25,7 +25,7 @@ class MemoryAPIService: This service provides a thin layer that: 1. Validates end_user exists and belongs to the authorized workspace - 2. Maps end_user_id to group_id for memory operations + 2. 
Maps end_user_id to end_user_id for memory operations 3. Delegates to MemoryAgentService for actual memory read/write operations """ @@ -68,7 +68,7 @@ class MemoryAPIService: ) end_user = self.db.query(EndUser).filter(EndUser.id == end_user_uuid).first() - + if not end_user: logger.warning(f"End user not found: {end_user_id}") raise ResourceNotFoundException( @@ -118,7 +118,7 @@ class MemoryAPIService: Args: workspace_id: Workspace ID for resource validation - end_user_id: End user identifier (used as group_id) + end_user_id: End user identifier (used as end_user_id) message: Message content to store config_id: Optional memory configuration ID storage_type: Storage backend (neo4j or rag) @@ -136,14 +136,13 @@ class MemoryAPIService: # Validate end_user exists and belongs to workspace self.validate_end_user(end_user_id, workspace_id) - # Use end_user_id as group_id for memory operations - group_id = end_user_id + # Use end_user_id as end_user_id for memory operations try: # Delegate to MemoryAgentService result = await MemoryAgentService().write_memory( - group_id=group_id, - message=message, + end_user_id=end_user_id, + messages=message, config_id=config_id, db=self.db, storage_type=storage_type, @@ -189,7 +188,7 @@ class MemoryAPIService: Args: workspace_id: Workspace ID for resource validation - end_user_id: End user identifier (used as group_id) + end_user_id: End user identifier (used as end_user_id) message: Query message search_switch: Search mode (0=deep search with verification, 1=deep search, 2=fast search) config_id: Optional memory configuration ID @@ -208,13 +207,13 @@ class MemoryAPIService: # Validate end_user exists and belongs to workspace self.validate_end_user(end_user_id, workspace_id) - # Use end_user_id as group_id for memory operations - group_id = end_user_id + # Use end_user_id as end_user_id for memory operations + try: # Delegate to MemoryAgentService result = await MemoryAgentService().read_memory( - group_id=group_id, + 
end_user_id=end_user_id, message=message, history=[], search_switch=search_switch, diff --git a/api/app/services/memory_base_service.py b/api/app/services/memory_base_service.py index 25a8281d..bc647752 100644 --- a/api/app/services/memory_base_service.py +++ b/api/app/services/memory_base_service.py @@ -326,7 +326,7 @@ class MemoryBaseService: Args: summary_id: Summary节点的ID - end_user_id: 终端用户ID (group_id) + end_user_id: 终端用户ID (end_user_id) Returns: 最大emotion_intensity对应的emotion_type,如果没有则返回None @@ -334,7 +334,7 @@ class MemoryBaseService: try: query = """ MATCH (s:MemorySummary) - WHERE elementId(s) = $summary_id AND s.group_id = $group_id + WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement) WHERE stmt.emotion_type IS NOT NULL AND stmt.emotion_intensity IS NOT NULL @@ -347,7 +347,7 @@ class MemoryBaseService: result = await self.neo4j_connector.execute_query( query, summary_id=summary_id, - group_id=end_user_id + end_user_id=end_user_id ) if result and len(result) > 0: @@ -381,10 +381,10 @@ class MemoryBaseService: if end_user_id: query = """ MATCH (n:MemorySummary) - WHERE n.group_id = $group_id + WHERE n.end_user_id = $end_user_id RETURN count(n) as count """ - result = await self.neo4j_connector.execute_query(query, group_id=end_user_id) + result = await self.neo4j_connector.execute_query(query, end_user_id=end_user_id) else: query = """ MATCH (n:MemorySummary) @@ -423,12 +423,12 @@ class MemoryBaseService: if end_user_id: semantic_query = """ MATCH (e:ExtractedEntity) - WHERE e.group_id = $group_id AND e.is_explicit_memory = true + WHERE e.end_user_id = $end_user_id AND e.is_explicit_memory = true RETURN count(e) as count """ semantic_result = await self.neo4j_connector.execute_query( semantic_query, - group_id=end_user_id + end_user_id=end_user_id ) else: semantic_query = """ @@ -519,7 +519,7 @@ class MemoryBaseService: """ if end_user_id: - query += " AND n.group_id = $group_id" + query 
+= " AND n.end_user_id = $end_user_id" query += """ RETURN sum(CASE WHEN n.activation_value IS NOT NULL AND n.activation_value < $threshold THEN 1 ELSE 0 END) as low_activation_nodes @@ -528,7 +528,7 @@ class MemoryBaseService: # 设置查询参数 params = {'threshold': forgetting_threshold} if end_user_id: - params['group_id'] = end_user_id + params['end_user_id'] = end_user_id # 执行查询 result = await self.neo4j_connector.execute_query(query, **params) diff --git a/api/app/services/memory_config_service.py b/api/app/services/memory_config_service.py index 0099eb18..e901d65d 100644 --- a/api/app/services/memory_config_service.py +++ b/api/app/services/memory_config_service.py @@ -14,7 +14,7 @@ from app.core.validators.memory_config_validators import ( validate_embedding_model, validate_model_exists_and_active, ) -from app.repositories.data_config_repository import DataConfigRepository +from app.repositories.memory_config_repository import MemoryConfigRepository from app.schemas.memory_config_schema import ( ConfigurationError, InvalidConfigError, @@ -23,20 +23,24 @@ from app.schemas.memory_config_schema import ( ModelNotFoundError, ) from sqlalchemy.orm import Session +from uuid import UUID logger = get_logger(__name__) config_logger = get_config_logger() - +import uuid def _validate_config_id(config_id): - """Validate configuration ID format.""" + """Validate configuration ID format (supports both UUID and integer).""" + if isinstance(config_id, uuid.UUID): + return config_id + if config_id is None: raise InvalidConfigError( "Configuration ID cannot be None", field_name="config_id", invalid_value=config_id, ) - + if isinstance(config_id, int): if config_id <= 0: raise InvalidConfigError( @@ -45,10 +49,19 @@ def _validate_config_id(config_id): invalid_value=config_id, ) return config_id - + if isinstance(config_id, str): + config_id_stripped = config_id.strip() + + # Try parsing as UUID first try: - parsed_id = int(config_id.strip()) + return uuid.UUID(config_id_stripped) + 
except ValueError: + pass + + # Fall back to integer parsing + try: + parsed_id = config_id_stripped if parsed_id <= 0: raise InvalidConfigError( f"Configuration ID must be positive: {parsed_id}", @@ -58,13 +71,13 @@ def _validate_config_id(config_id): return parsed_id except ValueError: raise InvalidConfigError( - f"Invalid configuration ID format: '{config_id}'", + f"Invalid configuration ID format: '{config_id}' (must be UUID or positive integer)", field_name="config_id", invalid_value=config_id, ) - + raise InvalidConfigError( - f"Invalid type for configuration ID: expected int or str, got {type(config_id).__name__}", + f"Invalid type for configuration ID: expected UUID, int or str, got {type(config_id).__name__}", field_name="config_id", invalid_value=config_id, ) @@ -73,61 +86,61 @@ def _validate_config_id(config_id): class MemoryConfigService: """ Centralized service for memory configuration loading and validation. - + This class provides a single implementation of configuration loading logic that can be shared across multiple services, eliminating code duplication. - + Usage: config_service = MemoryConfigService(db) memory_config = config_service.load_memory_config(config_id) model_config = config_service.get_model_config(model_id) """ - + def __init__(self, db: Session): """Initialize the service with a database session. - + Args: db: SQLAlchemy database session """ self.db = db - + def load_memory_config( self, - config_id: int, + config_id: UUID, service_name: str = "MemoryConfigService", ) -> MemoryConfig: """ Load memory configuration from database by config_id. 
- + Args: - config_id: Configuration ID from database + config_id: Configuration ID (UUID) from database service_name: Name of the calling service (for logging purposes) - + Returns: MemoryConfig: Immutable configuration object - + Raises: ConfigurationError: If validation fails """ start_time = time.time() - + config_logger.info( "Starting memory configuration loading", extra={ "operation": "load_memory_config", "service": service_name, - "config_id": config_id, + "config_id": str(config_id), }, ) - + logger.info(f"Loading memory configuration from database: config_id={config_id}") - + try: validated_config_id = _validate_config_id(config_id) - + # Step 1: Get config and workspace db_query_start = time.time() - result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id) + result = MemoryConfigRepository.get_config_with_workspace(self.db, validated_config_id) db_query_time = time.time() - db_query_start logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s") if not result: @@ -136,18 +149,18 @@ class MemoryConfigService: "Configuration not found in database", extra={ "operation": "load_memory_config", - "config_id": validated_config_id, + "config_id": str(config_id), "load_result": "not_found", "elapsed_ms": elapsed_ms, "service": service_name, }, ) raise ConfigurationError( - f"Configuration {validated_config_id} not found in database" + f"Configuration {config_id} not found in database" ) - + memory_config, workspace = result - + # Step 2: Validate embedding model (returns both UUID and name) embed_start = time.time() embedding_uuid, embedding_name = validate_embedding_model( @@ -159,7 +172,7 @@ class MemoryConfigService: ) embed_time = time.time() - embed_start logger.info(f"[PERF] Embedding validation: {embed_time:.4f}s") - + # Step 3: Resolve LLM model llm_start = time.time() llm_uuid, llm_name = validate_and_resolve_model_id( @@ -173,7 +186,7 @@ class MemoryConfigService: ) llm_time = time.time() - llm_start 
logger.info(f"[PERF] LLM validation: {llm_time:.4f}s") - + # Step 4: Resolve optional rerank model rerank_start = time.time() rerank_uuid = None @@ -191,10 +204,10 @@ class MemoryConfigService: rerank_time = time.time() - rerank_start if memory_config.rerank_id: logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s") - + # Note: embedding_name is now returned from validate_embedding_model above # No need for redundant query! - + # Create immutable MemoryConfig object config = MemoryConfig( config_id=memory_config.config_id, @@ -235,9 +248,9 @@ class MemoryConfigService: pruning_scene=memory_config.pruning_scene or "education", pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5, ) - + elapsed_ms = (time.time() - start_time) * 1000 - + config_logger.info( "Memory configuration loaded successfully", extra={ @@ -250,13 +263,13 @@ class MemoryConfigService: "elapsed_ms": elapsed_ms, }, ) - + logger.info(f"Memory configuration loaded successfully: {config.config_name}") return config - + except Exception as e: elapsed_ms = (time.time() - start_time) * 1000 - + config_logger.error( "Failed to load memory configuration", extra={ @@ -270,7 +283,7 @@ class MemoryConfigService: }, exc_info=True, ) - + logger.error(f"Failed to load memory configuration {config_id}: {e}") if isinstance(e, (ConfigurationError, ValueError)): raise diff --git a/api/app/services/memory_entity_relationship_service.py b/api/app/services/memory_entity_relationship_service.py index 9b5f3c99..7081d28b 100644 --- a/api/app/services/memory_entity_relationship_service.py +++ b/api/app/services/memory_entity_relationship_service.py @@ -717,8 +717,8 @@ class MemoryInteraction: ori_data= await self.connector.execute_query(Memory_Space_Entity, id=self.id) if ori_data!=[]: # name = ori_data[0]['name'] - group_id = [i['group_id'] for i in ori_data][0] - Space_User = await self.connector.execute_query(Memory_Space_User, group_id=group_id) + 
end_user_id = [i['end_user_id'] for i in ori_data][0] + Space_User = await self.connector.execute_query(Memory_Space_User, end_user_id=end_user_id) if not Space_User: return [] user_id=Space_User[0]['id'] diff --git a/api/app/services/memory_episodic_service.py b/api/app/services/memory_episodic_service.py index 12eeff6e..08751fd1 100644 --- a/api/app/services/memory_episodic_service.py +++ b/api/app/services/memory_episodic_service.py @@ -34,7 +34,7 @@ class MemoryEpisodicService(MemoryBaseService): Args: summary_id: Summary节点的ID - end_user_id: 终端用户ID (group_id) + end_user_id: 终端用户ID (end_user_id) Returns: (标题, 类型)元组,如果不存在则返回默认值 @@ -43,14 +43,14 @@ class MemoryEpisodicService(MemoryBaseService): # 查询Summary节点的name(作为title)和memory_type(作为type) query = """ MATCH (s:MemorySummary) - WHERE elementId(s) = $summary_id AND s.group_id = $group_id + WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id RETURN s.name AS title, s.memory_type AS type """ result = await self.neo4j_connector.execute_query( query, summary_id=summary_id, - group_id=end_user_id + end_user_id=end_user_id ) if not result or len(result) == 0: @@ -77,7 +77,7 @@ class MemoryEpisodicService(MemoryBaseService): Args: summary_id: Summary节点的ID - end_user_id: 终端用户ID (group_id) + end_user_id: 终端用户ID (end_user_id) Returns: 前3个实体的name属性列表 @@ -87,7 +87,7 @@ class MemoryEpisodicService(MemoryBaseService): # 按activation_value降序排序,返回前3个 query = """ MATCH (s:MemorySummary) - WHERE elementId(s) = $summary_id AND s.group_id = $group_id + WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement) MATCH (stmt)-[:REFERENCES_ENTITY]->(entity:ExtractedEntity) WHERE entity.activation_value IS NOT NULL @@ -99,7 +99,7 @@ class MemoryEpisodicService(MemoryBaseService): result = await self.neo4j_connector.execute_query( query, summary_id=summary_id, - group_id=end_user_id + end_user_id=end_user_id ) # 提取实体名称 @@ -123,7 +123,7 @@ class 
MemoryEpisodicService(MemoryBaseService): Args: summary_id: Summary节点的ID - end_user_id: 终端用户ID (group_id) + end_user_id: 终端用户ID (end_user_id) Returns: 所有Statement节点的statement属性内容列表 @@ -132,7 +132,7 @@ class MemoryEpisodicService(MemoryBaseService): # 查询Summary节点指向的所有Statement节点 query = """ MATCH (s:MemorySummary) - WHERE elementId(s) = $summary_id AND s.group_id = $group_id + WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement) WHERE stmt.statement IS NOT NULL AND stmt.statement <> '' RETURN stmt.statement AS statement @@ -141,7 +141,7 @@ class MemoryEpisodicService(MemoryBaseService): result = await self.neo4j_connector.execute_query( query, summary_id=summary_id, - group_id=end_user_id + end_user_id=end_user_id ) # 提取statement内容 @@ -214,12 +214,12 @@ class MemoryEpisodicService(MemoryBaseService): # 1. 先查询所有情景记忆的总数(不受筛选条件限制) total_all_query = """ MATCH (s:MemorySummary) - WHERE s.group_id = $group_id + WHERE s.end_user_id = $end_user_id RETURN count(s) AS total_all """ total_all_result = await self.neo4j_connector.execute_query( total_all_query, - group_id=end_user_id + end_user_id=end_user_id ) total_all = total_all_result[0]["total_all"] if total_all_result else 0 @@ -229,7 +229,7 @@ class MemoryEpisodicService(MemoryBaseService): # 3. 构建Cypher查询 query = """ MATCH (s:MemorySummary) - WHERE s.group_id = $group_id + WHERE s.end_user_id = $end_user_id """ # 添加时间范围过滤 @@ -248,7 +248,7 @@ class MemoryEpisodicService(MemoryBaseService): ORDER BY s.created_at DESC """ - params = {"group_id": end_user_id} + params = {"end_user_id": end_user_id} if time_filter: params["time_filter"] = time_filter if title_keyword: @@ -333,14 +333,14 @@ class MemoryEpisodicService(MemoryBaseService): # 1. 
查询指定的MemorySummary节点 query = """ MATCH (s:MemorySummary) - WHERE elementId(s) = $summary_id AND s.group_id = $group_id + WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id RETURN elementId(s) AS id, s.created_at AS created_at """ result = await self.neo4j_connector.execute_query( query, summary_id=summary_id, - group_id=end_user_id + end_user_id=end_user_id ) # 2. 如果节点不存在,返回错误 diff --git a/api/app/services/memory_explicit_service.py b/api/app/services/memory_explicit_service.py index 713215c3..f8d39ae8 100644 --- a/api/app/services/memory_explicit_service.py +++ b/api/app/services/memory_explicit_service.py @@ -60,7 +60,7 @@ class MemoryExplicitService(MemoryBaseService): # ========== 1. 查询情景记忆(MemorySummary节点) ========== episodic_query = """ MATCH (s:MemorySummary) - WHERE s.group_id = $group_id + WHERE s.end_user_id = $end_user_id RETURN elementId(s) AS id, s.name AS title, s.content AS content, @@ -70,7 +70,7 @@ class MemoryExplicitService(MemoryBaseService): episodic_result = await self.neo4j_connector.execute_query( episodic_query, - group_id=end_user_id + end_user_id=end_user_id ) # 处理情景记忆数据 @@ -96,7 +96,7 @@ class MemoryExplicitService(MemoryBaseService): # ========== 2. 查询语义记忆(ExtractedEntity节点) ========== semantic_query = """ MATCH (e:ExtractedEntity) - WHERE e.group_id = $group_id + WHERE e.end_user_id = $end_user_id AND e.is_explicit_memory = true RETURN elementId(e) AS id, e.name AS name, @@ -107,7 +107,7 @@ class MemoryExplicitService(MemoryBaseService): semantic_result = await self.neo4j_connector.execute_query( semantic_query, - group_id=end_user_id + end_user_id=end_user_id ) # 处理语义记忆数据 @@ -189,7 +189,7 @@ class MemoryExplicitService(MemoryBaseService): # ========== 1. 
先尝试查询情景记忆 ========== episodic_query = """ MATCH (s:MemorySummary) - WHERE elementId(s) = $memory_id AND s.group_id = $group_id + WHERE elementId(s) = $memory_id AND s.end_user_id = $end_user_id RETURN s.name AS title, s.content AS content, s.created_at AS created_at @@ -198,7 +198,7 @@ class MemoryExplicitService(MemoryBaseService): episodic_result = await self.neo4j_connector.execute_query( episodic_query, memory_id=memory_id, - group_id=end_user_id + end_user_id=end_user_id ) if episodic_result and len(episodic_result) > 0: @@ -229,7 +229,7 @@ class MemoryExplicitService(MemoryBaseService): semantic_query = """ MATCH (e:ExtractedEntity) WHERE elementId(e) = $memory_id - AND e.group_id = $group_id + AND e.end_user_id = $end_user_id AND e.is_explicit_memory = true RETURN e.name AS name, e.description AS core_definition, @@ -240,7 +240,7 @@ class MemoryExplicitService(MemoryBaseService): semantic_result = await self.neo4j_connector.execute_query( semantic_query, memory_id=memory_id, - group_id=end_user_id + end_user_id=end_user_id ) if semantic_result and len(semantic_result) > 0: diff --git a/api/app/services/memory_forget_service.py b/api/app/services/memory_forget_service.py index 2db4cdc7..e1030b24 100644 --- a/api/app/services/memory_forget_service.py +++ b/api/app/services/memory_forget_service.py @@ -12,6 +12,7 @@ from typing import Optional, Dict, Any, Tuple from datetime import datetime, timezone +from uuid import UUID from sqlalchemy.orm import Session @@ -23,7 +24,7 @@ from app.core.memory.storage_services.forgetting_engine.config_utils import ( load_actr_config_from_db, ) from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.repositories.data_config_repository import DataConfigRepository +from app.repositories.memory_config_repository import MemoryConfigRepository from app.repositories.forgetting_cycle_history_repository import ForgettingCycleHistoryRepository @@ -70,7 +71,7 @@ class MemoryForgetService: def __init__(self): 
"""初始化服务""" - self.config_repository = DataConfigRepository() + self.config_repository = MemoryConfigRepository() self.history_repository = ForgettingCycleHistoryRepository() def _get_neo4j_connector(self) -> Neo4jConnector: @@ -87,7 +88,7 @@ class MemoryForgetService: async def _get_forgetting_components( self, db: Session, - config_id: Optional[int] = None + config_id: Optional[UUID] = None ) -> Tuple[ACTRCalculator, ForgettingStrategy, ForgettingScheduler, Dict[str, Any]]: """ 获取遗忘引擎组件(计算器、策略、调度器) @@ -132,7 +133,7 @@ class MemoryForgetService: async def _get_knowledge_stats( self, connector: Neo4jConnector, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, forgetting_threshold: float = 0.3 ) -> Dict[str, Any]: """ @@ -140,7 +141,7 @@ class MemoryForgetService: Args: connector: Neo4j 连接器 - group_id: 组ID(可选) + end_user_id: 组ID(可选) forgetting_threshold: 遗忘阈值 Returns: @@ -152,8 +153,8 @@ class MemoryForgetService: WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary) """ - if group_id: - query += " AND n.group_id = $group_id" + if end_user_id: + query += " AND n.end_user_id = $end_user_id" query += """ WITH n, @@ -172,8 +173,8 @@ class MemoryForgetService: """ params = {'threshold': forgetting_threshold} - if group_id: - params['group_id'] = group_id + if end_user_id: + params['end_user_id'] = end_user_id results = await connector.execute_query(query, **params) @@ -200,7 +201,7 @@ class MemoryForgetService: async def _get_pending_forgetting_nodes( self, connector: Neo4jConnector, - group_id: str, + end_user_id: str, forgetting_threshold: float, min_days_since_access: int, limit: int = 20 @@ -212,7 +213,7 @@ class MemoryForgetService: Args: connector: Neo4j 连接器 - group_id: 组ID + end_user_id: 组ID forgetting_threshold: 遗忘阈值 min_days_since_access: 最小未访问天数 limit: 返回节点数量限制 @@ -229,7 +230,7 @@ class MemoryForgetService: query = """ MATCH (n) WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary) - AND n.group_id = $group_id + AND 
n.end_user_id = $end_user_id AND n.activation_value IS NOT NULL AND n.activation_value < $threshold AND n.last_access_time IS NOT NULL @@ -250,7 +251,7 @@ class MemoryForgetService: """ params = { - 'group_id': group_id, + 'end_user_id': end_user_id, 'threshold': forgetting_threshold, 'min_access_time_str': min_access_time_str, 'limit': limit @@ -291,10 +292,10 @@ class MemoryForgetService: async def trigger_forgetting_cycle( self, db: Session, - group_id: str, + end_user_id: str, max_merge_batch_size: Optional[int] = None, min_days_since_access: Optional[int] = None, - config_id: Optional[int] = None + config_id: Optional[UUID] = None ) -> Dict[str, Any]: """ 手动触发遗忘周期 @@ -303,10 +304,10 @@ class MemoryForgetService: Args: db: 数据库会话 - group_id: 组ID(即终端用户ID,必填) + end_user_id: 组ID(即终端用户ID,必填) max_merge_batch_size: 最大融合批次大小(可选) min_days_since_access: 最小未访问天数(可选) - config_id: 配置ID(必填,由控制器层通过 group_id 获取) + config_id: 配置ID(必填,由控制器层通过 end_user_id 获取) Returns: dict: 遗忘报告 @@ -319,7 +320,7 @@ class MemoryForgetService: # 运行遗忘周期(LLM 客户端将在需要时由 forgetting_strategy 内部获取) report = await forgetting_scheduler.run_forgetting_cycle( - group_id=group_id, + end_user_id=end_user_id, max_merge_batch_size=max_merge_batch_size, min_days_since_access=min_days_since_access, config_id=config_id, @@ -338,7 +339,7 @@ class MemoryForgetService: stats_query = """ MATCH (n) WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk) - AND n.group_id = $group_id + AND n.end_user_id = $end_user_id RETURN count(n) as total_nodes, avg(n.activation_value) as average_activation, @@ -347,7 +348,7 @@ class MemoryForgetService: stats_results = await connector.execute_query( stats_query, - group_id=group_id, + end_user_id=end_user_id, threshold=config['forgetting_threshold'] ) @@ -364,7 +365,7 @@ class MemoryForgetService: # 保存历史记录到数据库 self.history_repository.create( db=db, - end_user_id=group_id, + end_user_id=end_user_id, execution_time=execution_time, merged_count=report['merged_count'], 
failed_count=report['failed_count'], @@ -376,7 +377,7 @@ class MemoryForgetService: ) api_logger.info( - f"已保存遗忘周期历史记录: end_user_id={group_id}, " + f"已保存遗忘周期历史记录: end_user_id={end_user_id}, " f"merged_count={report['merged_count']}" ) @@ -389,7 +390,7 @@ class MemoryForgetService: def read_forgetting_config( self, db: Session, - config_id: int + config_id: UUID ) -> Dict[str, Any]: """ 获取遗忘引擎配置 @@ -416,7 +417,7 @@ class MemoryForgetService: def update_forgetting_config( self, db: Session, - config_id: int, + config_id: UUID, update_fields: Dict[str, Any] ) -> Dict[str, Any]: """ @@ -465,8 +466,8 @@ class MemoryForgetService: async def get_forgetting_stats( self, db: Session, - group_id: Optional[str] = None, - config_id: Optional[int] = None + end_user_id: Optional[str] = None, + config_id: Optional[UUID] = None ) -> Dict[str, Any]: """ 获取遗忘引擎统计信息 @@ -475,7 +476,7 @@ class MemoryForgetService: Args: db: 数据库会话 - group_id: 组ID(可选) + end_user_id: 组ID(可选) config_id: 配置ID(可选,用于获取遗忘阈值) Returns: @@ -493,8 +494,8 @@ class MemoryForgetService: WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk) """ - if group_id: - activation_query += " AND n.group_id = $group_id" + if end_user_id: + activation_query += " AND n.end_user_id = $end_user_id" activation_query += """ RETURN @@ -506,8 +507,8 @@ class MemoryForgetService: """ params = {'threshold': forgetting_threshold} - if group_id: - params['group_id'] = group_id + if end_user_id: + params['end_user_id'] = end_user_id activation_results = await connector.execute_query(activation_query, **params) @@ -539,8 +540,8 @@ class MemoryForgetService: WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk) """ - if group_id: - distribution_query += " AND n.group_id = $group_id" + if end_user_id: + distribution_query += " AND n.end_user_id = $end_user_id" distribution_query += """ WITH n, @@ -558,8 +559,8 @@ class MemoryForgetService: """ dist_params = {} - if group_id: - dist_params['group_id'] = 
group_id + if end_user_id: + dist_params['end_user_id'] = end_user_id distribution_results = await connector.execute_query(distribution_query, **dist_params) @@ -582,11 +583,11 @@ class MemoryForgetService: # 获取最近7个日期的历史趋势数据(每天取最后一次执行) recent_trends = [] try: - if group_id: + if end_user_id: # 查询所有历史记录 history_records = self.history_repository.get_recent_by_end_user( db=db, - end_user_id=group_id + end_user_id=end_user_id ) # 按日期分组(一天可能有多次执行,取最后一次) @@ -632,7 +633,7 @@ class MemoryForgetService: # 获取待遗忘节点列表(前20个满足遗忘条件的节点) pending_nodes = [] try: - if group_id: + if end_user_id: # 验证 min_days_since_access 配置值 min_days = config.get('min_days_since_access') if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0: @@ -643,7 +644,7 @@ class MemoryForgetService: pending_nodes = await self._get_pending_forgetting_nodes( connector=connector, - group_id=group_id, + end_user_id=end_user_id, forgetting_threshold=forgetting_threshold, min_days_since_access=int(min_days), limit=20 @@ -677,7 +678,7 @@ class MemoryForgetService: db: Session, importance_score: float, days: int, - config_id: Optional[int] = None + config_id: Optional[UUID] = None ) -> Dict[str, Any]: """ 获取遗忘曲线数据 diff --git a/api/app/services/memory_konwledges_server.py b/api/app/services/memory_konwledges_server.py index c6297e12..420f7ca1 100644 --- a/api/app/services/memory_konwledges_server.py +++ b/api/app/services/memory_konwledges_server.py @@ -450,12 +450,12 @@ async def create_document_chunk( return success(data=chunk, msg="文档块创建成功") -async def write_rag(group_id, message, user_rag_memory_id): +async def write_rag(end_user_id, message, user_rag_memory_id): """ 将消息写入 RAG 知识库 Args: - group_id: 组ID,用作文件标题 + end_user_id: 组ID,用作文件标题 message: 消息内容 user_rag_memory_id: 知识库ID(必须是有效的UUID) @@ -487,10 +487,10 @@ async def write_rag(group_id, message, user_rag_memory_id): db = next(db_gen) try: - create_data = CustomTextFileCreate(title=group_id, content=message) + create_data = 
CustomTextFileCreate(title=end_user_id, content=message) current_user = SimpleUser(user_rag_memory_id) # 检查文档是否已存在 - document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{group_id}.txt") + document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt") print('======',document) api_logger.info(f"查找文档结果: document_id={document}") if document is not None: @@ -508,7 +508,7 @@ async def write_rag(group_id, message, user_rag_memory_id): return result else: # 文档不存在,创建新文档 - api_logger.info(f"文档不存在,创建新文档: group_id={group_id}") + api_logger.info(f"文档不存在,创建新文档: end_user_id={end_user_id}") result = await memory_konwledges_up( kb_id=user_rag_memory_id, parent_id=user_rag_memory_id, @@ -520,13 +520,13 @@ async def write_rag(group_id, message, user_rag_memory_id): new_document_id = find_document_id_by_kb_and_filename( db=db, kb_id=user_rag_memory_id, - file_name=f"{group_id}.txt" + file_name=f"{end_user_id}.txt" ) if new_document_id: await parse_document_by_id(new_document_id, db=db, current_user=current_user) else: - api_logger.error(f"创建文档后无法找到文档ID: group_id={group_id}") + api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}") return result finally: # 确保数据库会话被关闭 diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py index d257e80f..b9d96a0b 100644 --- a/api/app/services/memory_perceptual_service.py +++ b/api/app/services/memory_perceptual_service.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger -from app.models.memory_perceptual_model import PerceptualType, FileStorageType +from app.models.memory_perceptual_model import PerceptualType, FileStorageService from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository from app.schemas.memory_perceptual_schema import 
( PerceptualQuerySchema, @@ -137,8 +137,19 @@ class MemoryPerceptualService: memory_items = [] for memory in memories: meta_data = memory.meta_data or {} - content = meta_data.get("content") - content = Content(**content) + content = meta_data.get("content", {}) + + # 安全地提取 content 字段,提供默认值 + if content: + content_obj = Content(**content) + topic = content_obj.topic + domain = content_obj.domain + keywords = content_obj.keywords + else: + topic = "Unknown" + domain = "Unknown" + keywords = [] + memory_item = PerceptualMemoryItem( id=memory.id, perceptual_type=PerceptualType(memory.perceptual_type), @@ -146,11 +157,12 @@ class MemoryPerceptualService: file_name=memory.file_name, file_ext=memory.file_ext, summary=memory.summary, - topic=content.topic, - domain=content.domain, - keywords=content.keywords, + meta_data=meta_data, + topic=topic, + domain=domain, + keywords=keywords, created_time=int(memory.created_time.timestamp()*1000), - storage_type=FileStorageType(memory.storage_service), + storage_service=FileStorageService(memory.storage_service), ) memory_items.append(memory_item) diff --git a/api/app/services/memory_reflection_service.py b/api/app/services/memory_reflection_service.py index af72e3cc..402a40a1 100644 --- a/api/app/services/memory_reflection_service.py +++ b/api/app/services/memory_reflection_service.py @@ -13,7 +13,7 @@ from app.db import get_db from app.core.logging_config import get_api_logger from app.core.memory.storage_services.reflection_engine import ReflectionConfig, ReflectionEngine from app.core.memory.storage_services.reflection_engine.self_reflexion import ReflectionRange, ReflectionBaseline -from app.repositories.data_config_repository import DataConfigRepository +from app.repositories.memory_config_repository import MemoryConfigRepository from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.models.app_model import App from app.models.app_release_model import AppRelease @@ -73,7 +73,7 @@ class 
WorkspaceAppService: "created_at": app.created_at.isoformat() if app.created_at else None, "updated_at": app.updated_at.isoformat() if app.updated_at else None, "releases": [], - "data_configs": [], + "memory_configs": [], "end_users": [] } @@ -101,11 +101,11 @@ class WorkspaceAppService: if memory_content: processed_configs.add(memory_content) - data_config_info = self._get_data_config(memory_content) + memory_config_info = self._get_memory_config(memory_content) - if data_config_info: - if not any(dc["config_id"] == data_config_info["config_id"] for dc in app_info["data_configs"]): - app_info["data_configs"].append(data_config_info) + if memory_config_info: + if not any(dc["config_id"] == memory_config_info["config_id"] for dc in app_info["memory_configs"]): + app_info["memory_configs"].append(memory_config_info) app_info["releases"].append(release_info) @@ -120,30 +120,30 @@ class WorkspaceAppService: return None - def _get_data_config(self, memory_content: str) -> Dict[str, Any]: - """Retrieve data_comfig information based on memory_comtent""" + def _get_memory_config(self, memory_content: str) -> Dict[str, Any]: + """Retrieve memory_config information based on memory_content""" try: - data_config_result = DataConfigRepository.query_reflection_config_by_id(self.db, int(memory_content)) + memory_config_result = MemoryConfigRepository.query_reflection_config_by_id(self.db, int(memory_content)) - # data_config_query, data_config_params = DataConfigRepository.build_select_reflection(memory_content) - # data_config_result = self.db.execute(text(data_config_query), data_config_params).fetchone() - # if data_config_result is None: + # memory_config_query, memory_config_params = MemoryConfigRepository.build_select_reflection(memory_content) + # memory_config_result = self.db.execute(text(memory_config_query), memory_config_params).fetchone() + # if memory_config_result is None: # return None - if data_config_result: + if memory_config_result: return { - "config_id": 
data_config_result.config_id, - "enable_self_reflexion": data_config_result.enable_self_reflexion, - "iteration_period": data_config_result.iteration_period, - "reflexion_range": data_config_result.reflexion_range, - "baseline": data_config_result.baseline, - "reflection_model_id": data_config_result.reflection_model_id, - "memory_verify": data_config_result.memory_verify, - "quality_assessment": data_config_result.quality_assessment, - "user_id": data_config_result.user_id + "config_id": memory_config_result.config_id, + "enable_self_reflexion": memory_config_result.enable_self_reflexion, + "iteration_period": memory_config_result.iteration_period, + "reflexion_range": memory_config_result.reflexion_range, + "baseline": memory_config_result.baseline, + "reflection_model_id": memory_config_result.reflection_model_id, + "memory_verify": memory_config_result.memory_verify, + "quality_assessment": memory_config_result.quality_assessment, + "user_id": memory_config_result.user_id } except Exception as e: - api_logger.warning(f"查询data_config失败,memory_content: {memory_content}, 错误: {str(e)}") + api_logger.warning(f"查询memory_config失败,memory_content: {memory_content}, 错误: {str(e)}") return None @@ -226,7 +226,7 @@ class MemoryReflectionService: } config_data_id = config_data['config_id'] - reflection_config = WorkspaceAppService(self.db)._get_data_config(config_data_id) + reflection_config = WorkspaceAppService(self.db)._get_memory_config(config_data_id) if reflection_config is not None and reflection_config['enable_self_reflexion']: reflection_config = self._create_reflection_config_from_data(reflection_config) # 3. 
执行反思引擎 @@ -280,7 +280,7 @@ class MemoryReflectionService: config_data_id=config_data['config_id'] - reflection_config=WorkspaceAppService(self.db)._get_data_config(config_data_id) + reflection_config=WorkspaceAppService(self.db)._get_memory_config(config_data_id) if reflection_config is not None and reflection_config['enable_self_reflexion']: reflection_config= self._create_reflection_config_from_data(reflection_config) iteration_period = int(reflection_config.iteration_period) diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index c276f337..80d8c717 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -19,7 +19,7 @@ from app.core.memory.analytics.hot_memory_tags import ( ) from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats from app.models.user_model import User -from app.repositories.data_config_repository import DataConfigRepository +from app.repositories.memory_config_repository import MemoryConfigRepository from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import ConfigurationError from app.schemas.memory_storage_schema import ( @@ -129,7 +129,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) if not params.rerank_id: params.rerank_id = configs.get('rerank') - config = DataConfigRepository.create(self.db, params) + config = MemoryConfigRepository.create(self.db, params) self.db.commit() return {"affected": 1, "config_id": config.config_id} @@ -146,20 +146,20 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # --- Delete --- def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID) - success = DataConfigRepository.delete(self.db, key.config_id) + success = MemoryConfigRepository.delete(self.db, key.config_id) if not success: raise ValueError("未找到配置") return {"affected": 1} # --- Update --- def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 
部分更新配置参数 - config = DataConfigRepository.update(self.db, update) + config = MemoryConfigRepository.update(self.db, update) if not config: raise ValueError("未找到配置") return {"affected": 1} def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数 - config = DataConfigRepository.update_extracted(self.db, update) + config = MemoryConfigRepository.update_extracted(self.db, update) if not config: raise ValueError("未找到配置") return {"affected": 1} @@ -170,14 +170,14 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # --- Read --- def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数 - result = DataConfigRepository.get_extracted_config(self.db, key.config_id) + result = MemoryConfigRepository.get_extracted_config(self.db, key.config_id) if not result: raise ValueError("未找到配置") return result # --- Read All --- def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数 - configs = DataConfigRepository.get_all(self.db, workspace_id) + configs = MemoryConfigRepository.get_all(self.db, workspace_id) # 将 ORM 对象转换为字典列表 data_list = [] @@ -187,7 +187,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) "config_name": config.config_name, "config_desc": config.config_desc, "workspace_id": str(config.workspace_id) if config.workspace_id else None, - "group_id": config.group_id, + "end_user_id": config.end_user_id, "user_id": config.user_id, "apply_id": config.apply_id, "llm_id": config.llm_id, @@ -395,8 +395,8 @@ _neo4j_connector = Neo4jConnector() async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]: result = await _neo4j_connector.execute_query( - DataConfigRepository.SEARCH_FOR_DIALOGUE, - group_id=end_user_id, + MemoryConfigRepository.SEARCH_FOR_DIALOGUE, + end_user_id=end_user_id, ) data = {"search_for": "dialogue", "num": result[0]["num"]} return data @@ -404,8 +404,8 @@ async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]: async def search_chunk(end_user_id: 
Optional[str] = None) -> Dict[str, Any]: result = await _neo4j_connector.execute_query( - DataConfigRepository.SEARCH_FOR_CHUNK, - group_id=end_user_id, + MemoryConfigRepository.SEARCH_FOR_CHUNK, + end_user_id=end_user_id, ) data = {"search_for": "chunk", "num": result[0]["num"]} return data @@ -413,8 +413,8 @@ async def search_chunk(end_user_id: Optional[str] = None) -> Dict[str, Any]: async def search_statement(end_user_id: Optional[str] = None) -> Dict[str, Any]: result = await _neo4j_connector.execute_query( - DataConfigRepository.SEARCH_FOR_STATEMENT, - group_id=end_user_id, + MemoryConfigRepository.SEARCH_FOR_STATEMENT, + end_user_id=end_user_id, ) data = {"search_for": "statement", "num": result[0]["num"]} return data @@ -422,8 +422,8 @@ async def search_statement(end_user_id: Optional[str] = None) -> Dict[str, Any]: async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]: result = await _neo4j_connector.execute_query( - DataConfigRepository.SEARCH_FOR_ENTITY, - group_id=end_user_id, + MemoryConfigRepository.SEARCH_FOR_ENTITY, + end_user_id=end_user_id, ) data = {"search_for": "entity", "num": result[0]["num"]} return data @@ -431,8 +431,8 @@ async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]: async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]: result = await _neo4j_connector.execute_query( - DataConfigRepository.SEARCH_FOR_ALL, - group_id=end_user_id, + MemoryConfigRepository.SEARCH_FOR_ALL, + end_user_id=end_user_id, ) # 检查结果是否为空或长度不足 @@ -466,8 +466,8 @@ async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, A 聚合 dialogue/chunk/statement/entity 四类计数,返回统一的分布结构,便于前端一次性消费。 """ result = await _neo4j_connector.execute_query( - DataConfigRepository.SEARCH_FOR_ALL, - group_id=end_user_id, + MemoryConfigRepository.SEARCH_FOR_ALL, + end_user_id=end_user_id, ) # 检查结果是否为空或长度不足 @@ -497,21 +497,19 @@ async def kb_type_distribution(end_user_id: Optional[str] = None) -> 
Dict[str, A async def search_detials(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]: result = await _neo4j_connector.execute_query( - DataConfigRepository.SEARCH_FOR_DETIALS, - group_id=end_user_id, + MemoryConfigRepository.SEARCH_FOR_DETIALS, + end_user_id=end_user_id, ) return result async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]: result = await _neo4j_connector.execute_query( - DataConfigRepository.SEARCH_FOR_EDGES, - group_id=end_user_id, + MemoryConfigRepository.SEARCH_FOR_EDGES, + end_user_id=end_user_id, ) return result - - async def analytics_hot_memory_tags( db: Session, current_user: User, @@ -574,7 +572,7 @@ async def analytics_hot_memory_tags( # 步骤4: 只调用一次LLM进行筛选 tag_names = [tag for tag, _ in sorted_tags] - # 使用第一个用户的group_id来获取LLM配置 + # 使用第一个用户的end_user_id来获取LLM配置 # 因为同一工作空间下的用户应该使用相同的配置 first_end_user_id = str(end_users[0].id) filtered_tag_names = await filter_tags_with_llm(tag_names, first_end_user_id) diff --git a/api/app/services/pilot_run_service.py b/api/app/services/pilot_run_service.py index 17dfd7eb..755dda14 100644 --- a/api/app/services/pilot_run_service.py +++ b/api/app/services/pilot_run_service.py @@ -91,7 +91,7 @@ async def run_pilot_extraction( dialog = DialogData( context=context, ref_id="pilot_dialog_1", - group_id=str(memory_config.workspace_id), + end_user_id=str(memory_config.workspace_id), user_id=str(memory_config.tenant_id), apply_id=str(memory_config.config_id), metadata={"source": "pilot_run", "input_type": "frontend_text"}, diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 863bccb0..3a90a821 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -155,10 +155,10 @@ class MemoryInsightHelper: """ query = """ MATCH (d:Dialogue) - WHERE d.group_id = $group_id AND d.created_at IS NOT NULL AND d.created_at <> '' + WHERE d.end_user_id = $end_user_id AND d.created_at IS NOT NULL AND 
d.created_at <> '' RETURN d.created_at AS creation_time """ - records = await self.neo4j_connector.execute_query(query, group_id=self.user_id) + records = await self.neo4j_connector.execute_query(query, end_user_id=self.user_id) if not records: return [] @@ -211,17 +211,17 @@ class MemoryInsightHelper: async def get_social_connections(self) -> dict | None: """Find the user with whom the most memories are shared.""" query = """ - MATCH (c1:Chunk {group_id: $group_id}) + MATCH (c1:Chunk {end_user_id: $end_user_id}) OPTIONAL MATCH (c1)-[:CONTAINS]->(s:Statement) OPTIONAL MATCH (s)<-[:CONTAINS]-(c2:Chunk) - WHERE c1.group_id <> c2.group_id AND s IS NOT NULL AND c2 IS NOT NULL - WITH c2.group_id AS other_user_id, COUNT(DISTINCT s) AS common_statements + WHERE c1.end_user_id <> c2.end_user_id AND s IS NOT NULL AND c2 IS NOT NULL + WITH c2.end_user_id AS other_user_id, COUNT(DISTINCT s) AS common_statements WHERE common_statements > 0 RETURN other_user_id, common_statements ORDER BY common_statements DESC LIMIT 1 """ - records = await self.neo4j_connector.execute_query(query, group_id=self.user_id) + records = await self.neo4j_connector.execute_query(query, end_user_id=self.user_id) if not records or not records[0].get("other_user_id"): return None @@ -230,7 +230,7 @@ class MemoryInsightHelper: time_range_query = """ MATCH (c:Chunk) - WHERE c.group_id IN [$user_id, $other_user_id] + WHERE c.end_user_id IN [$user_id, $other_user_id] RETURN min(c.created_at) AS start_time, max(c.created_at) AS end_time """ time_records = await self.neo4j_connector.execute_query( @@ -294,11 +294,11 @@ class UserSummaryHelper: """Fetch recent statements authored by the user/group for context.""" query = ( "MATCH (s:Statement) " - "WHERE s.group_id = $group_id AND s.statement IS NOT NULL " + "WHERE s.end_user_id = $end_user_id AND s.statement IS NOT NULL " "RETURN s.statement AS statement, s.created_at AS created_at " "ORDER BY created_at DESC LIMIT $limit" ) - rows = await 
self.connector.execute_query(query, group_id=self.user_id, limit=limit) + rows = await self.connector.execute_query(query, end_user_id=self.user_id, limit=limit) records = [] for r in rows: try: @@ -1152,7 +1152,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str, import re # 创建 UserSummaryHelper 实例 - user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_GROUP_ID", "group_123")) + user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_end_user_id", "group_123")) try: # 1) 收集上下文数据 @@ -1273,10 +1273,10 @@ async def analytics_node_statistics( if end_user_id: query = f""" MATCH (n:{node_type}) - WHERE n.group_id = $group_id + WHERE n.end_user_id = $end_user_id RETURN count(n) as count """ - result = await _neo4j_connector.execute_query(query, group_id=end_user_id) + result = await _neo4j_connector.execute_query(query, end_user_id=end_user_id) else: query = f""" MATCH (n:{node_type}) @@ -1387,10 +1387,10 @@ async def analytics_memory_types( # 查询 Statement 节点数量 query = """ MATCH (n:Statement) - WHERE n.group_id = $group_id + WHERE n.end_user_id = $end_user_id RETURN count(n) as count """ - result = await _neo4j_connector.execute_query(query, group_id=end_user_id) + result = await _neo4j_connector.execute_query(query, end_user_id=end_user_id) statement_count = result[0]["count"] if result and len(result) > 0 else 0 # 取三分之一作为隐性记忆数量 implicit_count = round(statement_count / 3) @@ -1504,7 +1504,7 @@ async def analytics_graph_data( 包含节点、边和统计信息的字典 """ try: - # 1. 获取 group_id + # 1. 
获取 end_user_id user_uuid = uuid.UUID(end_user_id) repo = EndUserRepository(db) end_user = repo.get_by_id(user_uuid) @@ -1528,7 +1528,7 @@ async def analytics_graph_data( # 基于中心节点的扩展查询 node_query = f""" MATCH path = (center)-[*1..{depth}]-(connected) - WHERE center.group_id = $group_id + WHERE center.end_user_id = $end_user_id AND elementId(center) = $center_node_id WITH collect(DISTINCT center) + collect(DISTINCT connected) as all_nodes UNWIND all_nodes as n @@ -1539,7 +1539,7 @@ async def analytics_graph_data( LIMIT $limit """ node_params = { - "group_id": end_user_id, + "end_user_id": end_user_id, "center_node_id": center_node_id, "limit": limit } @@ -1547,7 +1547,7 @@ async def analytics_graph_data( # 按节点类型过滤查询 node_query = """ MATCH (n) - WHERE n.group_id = $group_id + WHERE n.end_user_id = $end_user_id AND labels(n)[0] IN $node_types RETURN elementId(n) as id, @@ -1556,7 +1556,7 @@ async def analytics_graph_data( LIMIT $limit """ node_params = { - "group_id": end_user_id, + "end_user_id": end_user_id, "node_types": node_types, "limit": limit } @@ -1564,7 +1564,7 @@ async def analytics_graph_data( # 查询所有节点 node_query = """ MATCH (n) - WHERE n.group_id = $group_id + WHERE n.end_user_id = $end_user_id RETURN elementId(n) as id, labels(n)[0] as label, @@ -1572,7 +1572,7 @@ async def analytics_graph_data( LIMIT $limit """ node_params = { - "group_id": end_user_id, + "end_user_id": end_user_id, "limit": limit } diff --git a/api/app/tasks.py b/api/app/tasks.py index 5f2b1ef5..cdd7945e 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -4,6 +4,7 @@ import os import re import time import uuid +from uuid import UUID from datetime import datetime, timezone from math import ceil from typing import Any, Dict, List, Optional @@ -382,16 +383,16 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): @celery_app.task(name="app.core.memory.agent.read_message", bind=True) -def read_message_task(self, group_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, 
config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]: +def read_message_task(self, end_user_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: str, storage_type:str, user_rag_memory_id:str) -> Dict[str, Any]: """Celery task to process a read message via MemoryAgentService. Args: - group_id: Group ID for the memory agent (also used as end_user_id) + end_user_id: Group ID for the memory agent (also used as end_user_id) message: User message to process history: Conversation history search_switch: Search switch parameter - config_id: Optional configuration ID + config_id: Configuration ID as string (will be converted to UUID) Returns: Dict containing the result and metadata @@ -401,14 +402,22 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str, """ start_time = time.time() + # Convert config_id string to UUID + actual_config_id = None + if config_id: + try: + actual_config_id = uuid.UUID(config_id) if isinstance(config_id, str) else config_id + except (ValueError, AttributeError): + # If conversion fails, leave as None and try to resolve + pass + # Resolve config_id if None - actual_config_id = config_id if actual_config_id is None: try: from app.services.memory_agent_service import get_end_user_connected_config db = next(get_db()) try: - connected_config = get_end_user_connected_config(group_id, db) + connected_config = get_end_user_connected_config(end_user_id, db) actual_config_id = connected_config.get("memory_config_id") finally: db.close() @@ -420,24 +429,42 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str, db = next(get_db()) try: service = MemoryAgentService() - return await service.read_memory(group_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id) + return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id) finally: 
db.close() try: - result = asyncio.run(_run()) + # 使用 nest_asyncio 来避免事件循环冲突 + try: + import nest_asyncio + nest_asyncio.apply() + except ImportError: + pass + + # 尝试获取现有事件循环,如果不存在则创建新的 + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time return { "status": "SUCCESS", "result": result, - "group_id": group_id, + "end_user_id": end_user_id, "config_id": config_id, "elapsed_time": elapsed_time, "task_id": self.request.id } except BaseException as e: elapsed_time = time.time() - start_time + # Handle ExceptionGroup from TaskGroup if hasattr(e, 'exceptions'): error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] detailed_error = "; ".join(error_messages) @@ -446,7 +473,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str, return { "status": "FAILURE", "error": detailed_error, - "group_id": group_id, + "end_user_id": end_user_id, "config_id": config_id, "elapsed_time": elapsed_time, "task_id": self.request.id @@ -454,19 +481,13 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str, @celery_app.task(name="app.core.memory.agent.write_message", bind=True) -def write_message_task(self, group_id: str, message, config_id: str, storage_type: str, user_rag_memory_id: str) -> Dict[str, Any]: +def write_message_task(self, end_user_id: str, message: str, config_id: str, storage_type:str, user_rag_memory_id:str) -> Dict[str, Any]: """Celery task to process a write message via MemoryAgentService. - 支持两种消息格式: - 1. 字符串格式(向后兼容):message="user: xxx\nassistant: yyy" - 2. 
结构化消息列表(推荐):message=[{"role": "user", "content": "xxx"}, {"role": "assistant", "content": "yyy"}] - Args: - group_id: Group ID for the memory agent (also used as end_user_id) - message: Message to write (str or list[dict]) - config_id: Optional configuration ID - storage_type: Storage type (neo4j/rag) - user_rag_memory_id: RAG memory ID + end_user_id: Group ID for the memory agent (also used as end_user_id) + message: Message to write + config_id: Configuration ID as string (will be converted to UUID) Returns: Dict containing the result and metadata @@ -477,30 +498,46 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ from app.core.logging_config import get_logger logger = get_logger(__name__) - logger.info(f"[CELERY WRITE] Starting write task - group_id={group_id}, config_id={config_id}, storage_type={storage_type}") + logger.info(f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}") start_time = time.time() + # Convert config_id string to UUID + actual_config_id = None + if config_id: + try: + actual_config_id = uuid.UUID(config_id) if isinstance(config_id, str) else config_id + logger.info(f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})") + except (ValueError, AttributeError) as e: + logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id}, error: {e}") + return { + "status": "FAILURE", + "error": f"Invalid config_id format: {config_id}", + "end_user_id": end_user_id, + "config_id": config_id, + "elapsed_time": 0.0, + "task_id": self.request.id + } + # Resolve config_id if None - actual_config_id = config_id if actual_config_id is None: try: from app.services.memory_agent_service import get_end_user_connected_config db = next(get_db()) try: - connected_config = get_end_user_connected_config(group_id, db) + connected_config = get_end_user_connected_config(end_user_id, db) actual_config_id = 
connected_config.get("memory_config_id") finally: db.close() except Exception: # Log but continue - will fail later with proper error pass - + async def _run() -> str: db = next(get_db()) try: - logger.info(f"[CELERY WRITE] Executing MemoryAgentService.write_memory") + logger.info(f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__})") service = MemoryAgentService() - result = await service.write_memory(group_id, message, actual_config_id, db, storage_type, user_rag_memory_id) + result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type, user_rag_memory_id) logger.info(f"[CELERY WRITE] Write completed successfully: {result}") return result except Exception as e: @@ -510,7 +547,24 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ db.close() try: - result = asyncio.run(_run()) + # 使用 nest_asyncio 来避免事件循环冲突 + try: + import nest_asyncio + nest_asyncio.apply() + except ImportError: + pass + + # 尝试获取现有事件循环,如果不存在则创建新的 + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time logger.info(f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") @@ -518,13 +572,14 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ return { "status": "SUCCESS", "result": result, - "group_id": group_id, + "end_user_id": end_user_id, "config_id": config_id, "elapsed_time": elapsed_time, "task_id": self.request.id } except BaseException as e: elapsed_time = time.time() - start_time + # Handle ExceptionGroup from TaskGroup if hasattr(e, 'exceptions'): error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in 
e.exceptions] detailed_error = "; ".join(error_messages) @@ -536,7 +591,7 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ return { "status": "FAILURE", "error": detailed_error, - "group_id": group_id, + "end_user_id": end_user_id, "config_id": config_id, "elapsed_time": elapsed_time, "task_id": self.request.id @@ -878,7 +933,24 @@ def regenerate_memory_cache(self) -> Dict[str, Any]: } try: - result = asyncio.run(_run()) + # 使用 nest_asyncio 来避免事件循环冲突 + try: + import nest_asyncio + nest_asyncio.apply() + except ImportError: + pass + + # 尝试获取现有事件循环,如果不存在则创建新的 + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time result["elapsed_time"] = elapsed_time result["task_id"] = self.request.id @@ -951,7 +1023,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]: end_users = data['end_users'] for base, config, user in zip(releases, data_configs, end_users): - if int(base['config']) == int(config['config_id']) and base['app_id'] == user['app_id']: + if str(base['config']) == str(config['config_id']) and str(base['app_id']) == str(user['app_id']): # 调用反思服务 api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}") @@ -1005,7 +1077,24 @@ def workspace_reflection_task(self) -> Dict[str, Any]: } try: - result = asyncio.run(_run()) + # 使用 nest_asyncio 来避免事件循环冲突 + try: + import nest_asyncio + nest_asyncio.apply() + except ImportError: + pass + + # 尝试获取现有事件循环,如果不存在则创建新的 + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time 
result["elapsed_time"] = elapsed_time result["task_id"] = self.request.id @@ -1023,7 +1112,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]: @celery_app.task(name="app.tasks.run_forgetting_cycle_task", bind=True) -def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str, Any]: +def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Dict[str, Any]: """定时任务:运行遗忘周期 定期执行遗忘周期,识别并融合低激活值的知识节点。 @@ -1051,7 +1140,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str # 运行遗忘周期 report = await forget_service.trigger_forgetting( db=db, - group_id=None, # 处理所有组 + end_user_id=None, # 处理所有组 config_id=config_id ) @@ -1081,4 +1170,11 @@ def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str "duration_seconds": duration } - return asyncio.run(_run()) + # 运行异步函数 + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + result = loop.run_until_complete(_run()) + return result + finally: + loop.close() diff --git a/api/app/utils/app_config_utils.py b/api/app/utils/app_config_utils.py index 514e4565..ae41d8bf 100644 --- a/api/app/utils/app_config_utils.py +++ b/api/app/utils/app_config_utils.py @@ -83,6 +83,13 @@ class AgentConfigProxy: def agent_config_4_app_release(release: AppRelease) -> AgentConfig: config_dict = release.config + # 如果 config 是字符串,解析为字典 + if isinstance(config_dict, str): + import json + try: + config_dict = json.loads(config_dict) + except json.JSONDecodeError: + config_dict = {} agent_config = AgentConfig( app_id=release.app_id, @@ -100,6 +107,14 @@ def agent_config_4_app_release(release: AppRelease) -> AgentConfig: def multi_agent_config_4_app_release(release: AppRelease) -> MultiAgentConfig: config_dict = release.config + + # 如果 config 是字符串,解析为字典 + if isinstance(config_dict, str): + import json + try: + config_dict = json.loads(config_dict) + except json.JSONDecodeError: + config_dict = {} agent_config = MultiAgentConfig( 
app_id=release.app_id, @@ -120,6 +135,14 @@ def multi_agent_config_4_app_release(release: AppRelease) -> MultiAgentConfig: def workflow_config_4_app_release(release: AppRelease) -> WorkflowConfig: config_dict = release.config + + # 如果 config 是字符串,解析为字典 + if isinstance(config_dict, str): + import json + try: + config_dict = json.loads(config_dict) + except json.JSONDecodeError: + config_dict = {} config = WorkflowConfig( id=config_dict.get("id"), diff --git a/api/uv.lock b/api/uv.lock index bccaef2c..f3b23325 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -4462,4 +4462,4 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ff/8d/0309daffea4fcac7981021dbf21cdb2e3427a9e76bafbcdbdf5392ff99a4/zstandard-0.25.0-cp312-cp312-win32.whl", hash = "sha256:23ebc8f17a03133b4426bcc04aabd68f8236eb78c3760f12783385171b0fd8bd", size = 436922, upload-time = "2025-09-14T22:17:24.398Z" }, { url = "https://files.pythonhosted.org/packages/79/3b/fa54d9015f945330510cb5d0b0501e8253c127cca7ebe8ba46a965df18c5/zstandard-0.25.0-cp312-cp312-win_amd64.whl", hash = "sha256:ffef5a74088f1e09947aecf91011136665152e0b4b359c42be3373897fb39b01", size = 506276, upload-time = "2025-09-14T22:17:21.429Z" }, { url = "https://files.pythonhosted.org/packages/ea/6b/8b51697e5319b1f9ac71087b0af9a40d8a6288ff8025c36486e0c12abcc4/zstandard-0.25.0-cp312-cp312-win_arm64.whl", hash = "sha256:181eb40e0b6a29b3cd2849f825e0fa34397f649170673d385f3598ae17cca2e9", size = 462679, upload-time = "2025-09-14T22:17:23.147Z" }, -] +] \ No newline at end of file From e3b6ede99240faab54f41dc07aea72971791d385 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Mon, 26 Jan 2026 11:54:38 +0800 Subject: [PATCH 10/28] feat(sandbox): add Python 3 code execution sandbox support --- .gitignore | 3 + sandbox/Dockerfile | 42 ++++ sandbox/app/config.py | 134 ++++++++++++ sandbox/app/controllers/__init__.py | 8 + sandbox/app/controllers/health_controller.py | 12 ++ sandbox/app/controllers/sandbox_controller.py | 59 ++++++ 
sandbox/app/core/__init__.py | 1 + sandbox/app/core/encryption.py | 32 +++ sandbox/app/core/executor.py | 48 +++++ sandbox/app/core/runners/__init__.py | 1 + sandbox/app/core/runners/python/__init__.py | 4 + sandbox/app/core/runners/python/env.py | 50 +++++ sandbox/app/core/runners/python/prescript.py | 56 +++++ .../app/core/runners/python/python_runner.py | 151 ++++++++++++++ sandbox/app/core/runners/python/settings.py | 62 ++++++ sandbox/app/dependencies.py | 161 +++++++++++++++ sandbox/app/logger.py | 42 ++++ sandbox/app/middleware/__init__.py | 1 + sandbox/app/middleware/auth.py | 15 ++ sandbox/app/middleware/concurrency.py | 48 +++++ sandbox/app/models.py | 80 +++++++ sandbox/app/services/__init__.py | 1 + sandbox/app/services/python_service.py | 80 +++++++ sandbox/config.yaml | 20 ++ sandbox/dependencies/python-requirements.txt | 4 + sandbox/lib/seccomp_nodejs/Cargo.lock | 7 + sandbox/lib/seccomp_nodejs/Cargo.toml | 6 + sandbox/lib/seccomp_nodejs/src/lib.rs | 0 sandbox/lib/seccomp_python/Cargo.lock | 23 +++ sandbox/lib/seccomp_python/Cargo.toml | 12 ++ sandbox/lib/seccomp_python/src/lib.rs | 195 ++++++++++++++++++ sandbox/lib/seccomp_python/src/syscalls.rs | 85 ++++++++ sandbox/main.py | 97 +++++++++ sandbox/requirements.txt | 20 ++ sandbox/script/env.sh | 53 +++++ 35 files changed, 1613 insertions(+) create mode 100644 sandbox/Dockerfile create mode 100644 sandbox/app/config.py create mode 100644 sandbox/app/controllers/__init__.py create mode 100644 sandbox/app/controllers/health_controller.py create mode 100644 sandbox/app/controllers/sandbox_controller.py create mode 100644 sandbox/app/core/__init__.py create mode 100644 sandbox/app/core/encryption.py create mode 100644 sandbox/app/core/executor.py create mode 100644 sandbox/app/core/runners/__init__.py create mode 100644 sandbox/app/core/runners/python/__init__.py create mode 100644 sandbox/app/core/runners/python/env.py create mode 100644 sandbox/app/core/runners/python/prescript.py create mode 100644 
sandbox/app/core/runners/python/python_runner.py create mode 100644 sandbox/app/core/runners/python/settings.py create mode 100644 sandbox/app/dependencies.py create mode 100644 sandbox/app/logger.py create mode 100644 sandbox/app/middleware/__init__.py create mode 100644 sandbox/app/middleware/auth.py create mode 100644 sandbox/app/middleware/concurrency.py create mode 100644 sandbox/app/models.py create mode 100644 sandbox/app/services/__init__.py create mode 100644 sandbox/app/services/python_service.py create mode 100644 sandbox/config.yaml create mode 100644 sandbox/dependencies/python-requirements.txt create mode 100644 sandbox/lib/seccomp_nodejs/Cargo.lock create mode 100644 sandbox/lib/seccomp_nodejs/Cargo.toml create mode 100644 sandbox/lib/seccomp_nodejs/src/lib.rs create mode 100644 sandbox/lib/seccomp_python/Cargo.lock create mode 100644 sandbox/lib/seccomp_python/Cargo.toml create mode 100644 sandbox/lib/seccomp_python/src/lib.rs create mode 100644 sandbox/lib/seccomp_python/src/syscalls.rs create mode 100644 sandbox/main.py create mode 100644 sandbox/requirements.txt create mode 100644 sandbox/script/env.sh diff --git a/.gitignore b/.gitignore index c2648945..de160688 100644 --- a/.gitignore +++ b/.gitignore @@ -35,3 +35,6 @@ nltk_data/ tika-server*.jar* cl100k_base.tiktoken libssl*.deb + +sandbox/lib/seccomp_python/target +sandbox/lib/seccomp_nodejs/target diff --git a/sandbox/Dockerfile b/sandbox/Dockerfile new file mode 100644 index 00000000..677b991c --- /dev/null +++ b/sandbox/Dockerfile @@ -0,0 +1,42 @@ +FROM python:3.12-slim +USER root +WORKDIR /code +LABEL authors="Eterntiy" + +ARG NEED_MIRROR=0 + +RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \ + if [ "$NEED_MIRROR" == "1" ]; then \ + sed -i 's|https://ports.ubuntu.com|https://mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list; \ + sed -i 's|https://archive.ubuntu.com|https://mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list; \ + fi; \ + rm -f 
/etc/apt/apt.conf.d/docker-clean && \ + echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache && \ + chmod 1777 /tmp && \ + apt update && \ + apt --no-install-recommends install -y ca-certificates && \ + apt update && \ + apt install -y python3-pip pipx nginx unzip curl wget git vim less && \ + apt-get install -y --no-install-recommends tzdata libseccomp2 libseccomp-dev && \ + ln -snf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \ + echo "Asia/Shanghai" > /etc/timezone && \ + apt install -y cargo + +COPY ./app /code/app +COPY ./dependencies /code/dependencies +COPY ./lib /code/lib +COPY ./script /code/script +COPY ./config.yaml /code/config.yaml +COPY ./main.py /code/main.py +COPY ./requirements.txt /code/requirements.txt + +RUN python -m venv .venv +RUN .venv/bin/python3 -m pip install -r requirements.txt + +RUN cargo build --release --manifest-path lib/seccomp_python/Cargo.toml + +HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \ + CMD curl 127.0.0.1:8194/health + + +CMD [".venv/bin/python3", "main.py"] \ No newline at end of file diff --git a/sandbox/app/config.py b/sandbox/app/config.py new file mode 100644 index 00000000..3fa4cab5 --- /dev/null +++ b/sandbox/app/config.py @@ -0,0 +1,134 @@ +"""Configuration management""" +import os +from typing import List, Optional +from pydantic import BaseModel, Field +import yaml + +SANDBOX_USER_ID = 1000 +SANDBOX_GROUP_ID = 1000 + +DEFAULT_PYTHON_LIB_REQUIREMENTS_AMD = [ + "/usr/local/lib/python3.12", + "/usr/lib/python3", + "/usr/lib/x86_64-linux-gnu", + "/etc/ssl/certs/ca-certificates.crt", + "/etc/nsswitch.conf", + "/etc/hosts", + "/etc/resolv.conf", + "/run/systemd/resolve/stub-resolv.conf", + "/run/resolvconf/resolv.conf", + "/etc/localtime", + "/usr/share/zoneinfo", + "/etc/timezone", +] + + +class AppConfig(BaseModel): + """Application configuration""" + port: int = 8194 + debug: bool = True + key: str = "redbear-sandbox" + + +class 
ProxyConfig(BaseModel): + """Proxy configuration""" + socks5: str = "" + http: str = "" + https: str = "" + + +class Config(BaseModel): + """Global configuration""" + app: AppConfig = Field(default_factory=AppConfig) + max_workers: int = 4 + max_requests: int = 50 + worker_timeout: int = 30 + nodejs_path: str = "node" + enable_network: bool = True + enable_preload: bool = False + + python_path: str = "" + python_lib_paths: list = Field(default=DEFAULT_PYTHON_LIB_REQUIREMENTS_AMD) + python_deps_update_interval: str = "30m" + allowed_syscalls: List[int] = Field(default_factory=list) + proxy: ProxyConfig = Field(default_factory=ProxyConfig) + + +# Global configuration instance +_config: Optional[Config] = None + + +def load_config(config_path: str) -> Config: + """Load configuration from YAML file""" + global _config + + # Load from file + if os.path.exists(config_path): + with open(config_path, 'r') as f: + data = yaml.safe_load(f) + _config = Config(**data) + else: + _config = Config() + + # Override with environment variables + if os.getenv("DEBUG"): + _config.app.debug = os.getenv("DEBUG").lower() in ("true", "1", "yes") + + if os.getenv("MAX_WORKERS"): + _config.max_workers = int(os.getenv("MAX_WORKERS")) + + if os.getenv("MAX_REQUESTS"): + _config.max_requests = int(os.getenv("MAX_REQUESTS")) + + if os.getenv("SANDBOX_PORT"): + _config.app.port = int(os.getenv("SANDBOX_PORT")) + + if os.getenv("WORKER_TIMEOUT"): + _config.worker_timeout = int(os.getenv("WORKER_TIMEOUT")) + + if os.getenv("API_KEY"): + _config.app.key = os.getenv("API_KEY") + + if os.getenv("NODEJS_PATH"): + _config.nodejs_path = os.getenv("NODEJS_PATH") + + if os.getenv("ENABLE_NETWORK"): + _config.enable_network = os.getenv("ENABLE_NETWORK").lower() in ("true", "1", "yes") + + if os.getenv("ENABLE_PRELOAD"): + _config.enable_preload = os.getenv("ENABLE_PRELOAD").lower() in ("true", "1", "yes") + + if os.getenv("ALLOWED_SYSCALLS"): + _config.allowed_syscalls = [int(x) for x in 
os.getenv("ALLOWED_SYSCALLS").split(",")] + + if os.getenv("SOCKS5_PROXY"): + _config.proxy.socks5 = os.getenv("SOCKS5_PROXY") + + if os.getenv("HTTP_PROXY"): + _config.proxy.http = os.getenv("HTTP_PROXY") + + if os.getenv("HTTPS_PROXY"): + _config.proxy.https = os.getenv("HTTPS_PROXY") + + # python + if os.getenv("PYTHON_PATH"): + _config.python_path = os.getenv("PYTHON_PATH") + + if os.getenv("PYTHON_LIB_PATH"): + _config.python_lib_paths = os.getenv("PYTHON_LIB_PATH").split(',') + + if os.getenv("PYTHON_DEPS_UPDATE_INTERVAL"): + _config.python_deps_update_interval = os.getenv("PYTHON_DEPS_UPDATE_INTERVAL") + + return _config + + +config_path = os.getenv("CONFIG_PATH", "config.yaml") +load_config(config_path) + + +def get_config() -> Config: + """Get global configuration""" + if _config is None: + raise RuntimeError("Configuration not loaded. Call load_config() first.") + return _config diff --git a/sandbox/app/controllers/__init__.py b/sandbox/app/controllers/__init__.py new file mode 100644 index 00000000..b1d965ae --- /dev/null +++ b/sandbox/app/controllers/__init__.py @@ -0,0 +1,8 @@ +from fastapi import APIRouter + +from . 
import health_controller, sandbox_controller + +manager_router = APIRouter() + +manager_router.include_router(health_controller.router) +manager_router.include_router(sandbox_controller.router) diff --git a/sandbox/app/controllers/health_controller.py b/sandbox/app/controllers/health_controller.py new file mode 100644 index 00000000..4d872e58 --- /dev/null +++ b/sandbox/app/controllers/health_controller.py @@ -0,0 +1,12 @@ +"""Health check endpoint""" +from fastapi import APIRouter + +from app.models import HealthResponse + +router = APIRouter() + + +@router.get("/health", response_model=HealthResponse) +async def health_check(): + """Health check endpoint""" + return HealthResponse(status="healthy", version="2.0.0") diff --git a/sandbox/app/controllers/sandbox_controller.py b/sandbox/app/controllers/sandbox_controller.py new file mode 100644 index 00000000..1a713f52 --- /dev/null +++ b/sandbox/app/controllers/sandbox_controller.py @@ -0,0 +1,59 @@ +"""Sandbox API endpoints""" +from fastapi import APIRouter, Depends + +from app.middleware.auth import verify_api_key +from app.middleware.concurrency import check_max_requests, acquire_worker +from app.models import ( + RunCodeRequest, + ApiResponse, + UpdateDependencyRequest, + error_response +) +from app.services.python_service import ( + run_python_code, + list_python_dependencies, + update_python_dependencies +) + +router = APIRouter( + prefix="/v1/sandbox", + tags=["sandbox"], + dependencies=[Depends(verify_api_key)] +) + + +@router.post( + "/run", + response_model=ApiResponse, + dependencies=[Depends(check_max_requests), + Depends(acquire_worker)] +) +async def run_code(request: RunCodeRequest): + """Execute code in sandbox""" + if request.language == "python3": + return await run_python_code(request.code, request.preload, request.options) + elif request.language == "nodejs": + # TODO + return error_response(-400, "TODO") + else: + return error_response(-400, "unsupported language") + + 
+@router.get("/dependencies", response_model=ApiResponse) +async def get_dependencies(language: str): + """Get installed dependencies""" + if language == "python3": + return await list_python_dependencies() + else: + return error_response(-400, "unsupported language") + + +@router.post("/dependencies/update", response_model=ApiResponse) +async def update_dependencies(request: UpdateDependencyRequest): + """Update dependencies""" + if request.language == "python3": + return await update_python_dependencies() + else: + return error_response(-400, "unsupported language") + + diff --git a/sandbox/app/core/__init__.py b/sandbox/app/core/__init__.py new file mode 100644 index 00000000..e1abba12 --- /dev/null +++ b/sandbox/app/core/__init__.py @@ -0,0 +1 @@ +"""Core functionality package""" diff --git a/sandbox/app/core/encryption.py b/sandbox/app/core/encryption.py new file mode 100644 index 00000000..5e0855c9 --- /dev/null +++ b/sandbox/app/core/encryption.py @@ -0,0 +1,32 @@ +"""Code encryption utilities""" +import base64 + + +def encrypt_code(code: bytes, key: bytes) -> str: + """Encrypt code using XOR cipher with base64 encoding + + Args: + code: Plain code string + key: Encryption key bytes + + Returns: + Base64 encoded encrypted code + """ + encrypted_code = bytearray(len(code)) + for i in range(len(code)): + encrypted_code[i] = code[i] ^ key[i % 64] + encoded_code = base64.b64encode(encrypted_code).decode("utf-8") + return encoded_code + + +def generate_key(length: int = 64) -> bytes: + """Generate random encryption key + + Args: + length: Key length in bytes (default 64 for 512 bits) + + Returns: + Random key bytes + """ + import secrets + return secrets.token_bytes(length) diff --git a/sandbox/app/core/executor.py b/sandbox/app/core/executor.py new file mode 100644 index 00000000..6edc48c0 --- /dev/null +++ b/sandbox/app/core/executor.py @@ -0,0 +1,48 @@ +"""Code execution engine""" +import os +from typing import Optional +from abc import ABC, abstractmethod + 
+from app.config import get_config +from app.logger import get_logger +from app.models import RunnerOptions + + +class ExecutionResult: + """Result of code execution""" + + def __init__(self, stdout: str = "", stderr: str = "", exit_code: int = 0, error: Optional[str] = None): + self.stdout = stdout + self.stderr = stderr + self.exit_code = exit_code + self.error = error + + +class CodeExecutor(ABC): + """Base code executor""" + + def __init__(self): + self.logger = get_logger() + self.config = get_config() + + @abstractmethod + async def run( + self, + code: str, + options: RunnerOptions, + preload: str = "", + timeout: Optional[int] = None + ) -> ExecutionResult: + pass + + def cleanup_temp_file(self, file_path: str) -> None: + """Remove temporary file + + Args: + file_path: Path to file to remove + """ + try: + if os.path.exists(file_path): + os.remove(file_path) + except Exception as e: + self.logger.warning(f"Failed to cleanup temp file {file_path}: {e}") diff --git a/sandbox/app/core/runners/__init__.py b/sandbox/app/core/runners/__init__.py new file mode 100644 index 00000000..96c5e380 --- /dev/null +++ b/sandbox/app/core/runners/__init__.py @@ -0,0 +1 @@ +"""Code runners package""" diff --git a/sandbox/app/core/runners/python/__init__.py b/sandbox/app/core/runners/python/__init__.py new file mode 100644 index 00000000..99a56ef7 --- /dev/null +++ b/sandbox/app/core/runners/python/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/1/23 11:27 diff --git a/sandbox/app/core/runners/python/env.py b/sandbox/app/core/runners/python/env.py new file mode 100644 index 00000000..d82b0522 --- /dev/null +++ b/sandbox/app/core/runners/python/env.py @@ -0,0 +1,50 @@ +import asyncio +import tempfile +import stat +from pathlib import Path + +from app.config import get_config +from app.core.runners.python.settings import LIB_PATH +from app.logger import get_logger + +logger = get_logger() + + +async def 
prepare_python_dependencies_env(): + config = get_config() + + with tempfile.TemporaryDirectory(dir="/") as root_path: + root = Path(root_path) + + env_sh = root / "env.sh" + with open("script/env.sh") as f: + env_sh.write_text(f.read()) + env_sh.chmod(env_sh.stat().st_mode | stat.S_IXUSR) + + for lib_path in config.python_lib_paths: + lib_path = Path(lib_path) + + if not lib_path.exists(): + logger.warning("python lib path %s is not available", lib_path) + continue + + cmd = [ + "bash", + str(env_sh), + str(lib_path), + str(LIB_PATH), + ] + + process = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + + stdout, stderr = await process.communicate() + retcode = process.returncode + + if retcode != 0: + logger.error( + f"create env error for file {lib_path}: retcode={retcode}, stderr={stderr.decode()}" + ) diff --git a/sandbox/app/core/runners/python/prescript.py b/sandbox/app/core/runners/python/prescript.py new file mode 100644 index 00000000..4790be73 --- /dev/null +++ b/sandbox/app/core/runners/python/prescript.py @@ -0,0 +1,56 @@ +import ctypes +import os +import sys +import traceback +from base64 import b64decode + + +# Setup exception hook +def excepthook(etype, value, tb): + sys.stderr.write("".join(traceback.format_exception(etype, value, tb))) + sys.stderr.flush() + sys.exit(-1) + + +sys.excepthook = excepthook + +# Load security library if available +lib = ctypes.CDLL("./libpython.so") +lib.init_seccomp.argtypes = [ctypes.c_uint32, ctypes.c_uint32, ctypes.c_bool] +lib.init_seccomp.restype = None + +# Get running path +running_path = sys.argv[1] +if not running_path: + exit(-1) + +# Get decrypt key +key = sys.argv[2] +if not key: + exit(-1) + +key = b64decode(key) + +os.chdir(running_path) + +# Preload code +{{preload}} + +# Apply security if library is available +lib.init_seccomp({{uid}}, {{gid}}, {{enable_network}}) + +# Decrypt and execute code +code = b64decode("{{code}}") + + +def 
decrypt(code, key): + key_len = len(key) + code_len = len(code) + code = bytearray(code) + for i in range(code_len): + code[i] = code[i] ^ key[i % key_len] + return bytes(code) + + +code = decrypt(code, key) +exec(code) diff --git a/sandbox/app/core/runners/python/python_runner.py b/sandbox/app/core/runners/python/python_runner.py new file mode 100644 index 00000000..faac5f0c --- /dev/null +++ b/sandbox/app/core/runners/python/python_runner.py @@ -0,0 +1,151 @@ +"""Python code runner""" +import asyncio +import base64 +import os +import uuid +from typing import Optional + +from app.config import SANDBOX_USER_ID, SANDBOX_GROUP_ID, get_config +from app.core.encryption import generate_key, encrypt_code +from app.core.executor import CodeExecutor, ExecutionResult +from app.core.runners.python.settings import check_lib_avaiable, release_lib_binary, LIB_PATH +from app.models import RunnerOptions + +# Python sandbox prescript template +with open("app/core/runners/python/prescript.py") as f: + PYTHON_PRESCRIPT = f.read() + + +class PythonRunner(CodeExecutor): + """Python code runner with security isolation""" + + def __init__(self): + super().__init__() + + @staticmethod + def init_enviroment(code: bytes, preload, options: RunnerOptions) -> tuple[str, str]: + if not check_lib_avaiable(): + release_lib_binary(False) + config = get_config() + code_file_name = uuid.uuid4().hex.replace("-", "_") + + script = PYTHON_PRESCRIPT.replace("{{uid}}", str(SANDBOX_USER_ID), 1) + script = script.replace("{{gid}}", str(SANDBOX_GROUP_ID), 1) + script = script.replace( + "{{enable_network}}", + str(int(options.enable_network and config.enable_network) + ), + 1 + ) + script = script.replace("{{preload}}", f"{preload}\n", 1) + + key = generate_key(64) + + encoded_code = encrypt_code(code, key) + encoded_key = base64.b64encode(key).decode("utf-8") + + script = script.replace("{{code}}", encoded_code, 1) + + code_path = f"{LIB_PATH}/tmp/{code_file_name}.py" + try: + 
os.makedirs(os.path.dirname(code_path), mode=0o755, exist_ok=True) + with open(code_path, "w", encoding="utf-8") as f: + f.write(script) + os.chmod(code_path, 0o755) + + except OSError as e: + raise RuntimeError(f"Failed to write {code_path}") from e + + return code_path, encoded_key + + async def run( + self, + code: str, + options: RunnerOptions, + preload: str = "", + timeout: Optional[int] = None + ) -> ExecutionResult: + """Run Python code in sandbox + + Args: + options: + code: Base64 encoded encrypted code + preload: Preload code to execute before main code + timeout: Execution timeout in seconds + + Returns: + ExecutionResult with stdout, stderr, and exit code + """ + config = self.config + + if timeout is None: + timeout = config.worker_timeout + + # Check if preload is allowed + if not config.enable_preload: + preload = "" + code = base64.b64decode(code) + script_path, encoded_key = self.init_enviroment(code, preload, options=options) + + try: + # Setup environment + env = {} + + # Add proxy settings if configured + if config.proxy.socks5: + env["HTTPS_PROXY"] = config.proxy.socks5 + env["HTTP_PROXY"] = config.proxy.socks5 + elif config.proxy.https or config.proxy.http: + if config.proxy.https: + env["HTTPS_PROXY"] = config.proxy.https + if config.proxy.http: + env["HTTP_PROXY"] = config.proxy.http + + # Add allowed syscalls if configured + if config.allowed_syscalls: + env["ALLOWED_SYSCALLS"] = ",".join(map(str, config.allowed_syscalls)) + + # Execute with Python interpreter + + process = await asyncio.create_subprocess_exec( + config.python_path, + script_path, + LIB_PATH, + encoded_key, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env, + cwd=LIB_PATH + ) + + # Wait for completion with timeout + try: + stdout, stderr = await asyncio.wait_for( + process.communicate(), + timeout=timeout + ) + + return ExecutionResult( + stdout=stdout.decode('utf-8', errors='replace'), + stderr=stderr.decode('utf-8', errors='replace'), + 
exit_code=process.returncode + ) + + except asyncio.TimeoutError: + # Kill process on timeout + try: + process.kill() + await process.wait() + except: + pass + + return ExecutionResult( + stdout="", + stderr="Execution timeout", + exit_code=-1, + error="Execution timeout" + ) + + finally: + # Cleanup temporary file + self.cleanup_temp_file(script_path) diff --git a/sandbox/app/core/runners/python/settings.py b/sandbox/app/core/runners/python/settings.py new file mode 100644 index 00000000..aee8827b --- /dev/null +++ b/sandbox/app/core/runners/python/settings.py @@ -0,0 +1,62 @@ +import os + +from app.logger import get_logger + +logger = get_logger() + +RELEASE_LIB_PATH = "./lib/seccomp_python/target/release/libpython.so" +LIB_PATH = "/var/sandbox/sandbox-python" +LIB_NAME = "libpython.so" + +try: + with open(RELEASE_LIB_PATH, "rb") as f: + _PYTHON_LIB = f.read() +except: + logger.critical("failed to load python lib") + raise + + +def check_lib_avaiable(): + return os.path.exists(os.path.join(LIB_PATH, LIB_NAME)) + + +def release_lib_binary(force_remove: bool): + logger.info("init runtime enviroment") + lib_file = os.path.join(LIB_PATH, LIB_NAME) + if os.path.exists(lib_file): + if force_remove: + try: + os.remove(lib_file) + except OSError: + logger.critical(f"failed to remove {os.path.join(LIB_PATH, LIB_NAME)}") + raise + + try: + os.makedirs(LIB_PATH, mode=0o755, exist_ok=True) + except OSError: + logger.critical(f"failed to create {LIB_PATH}") + raise + + try: + with open(lib_file, "wb") as f: + f.write(_PYTHON_LIB) + os.chmod(lib_file, 0o755) + except OSError: + logger.critical(f"failed to write {lib_file}") + raise + else: + try: + os.makedirs(LIB_PATH, mode=0o755, exist_ok=True) + except OSError: + logger.critical(f"failed to create {LIB_PATH}") + raise + + try: + with open(lib_file, "wb") as f: + f.write(_PYTHON_LIB) + os.chmod(lib_file, 0o755) + except OSError: + logger.critical(f"failed to write {lib_file}") + raise + + logger.info("python runner 
environment initialized") diff --git a/sandbox/app/dependencies.py b/sandbox/app/dependencies.py new file mode 100644 index 00000000..6e88aaf2 --- /dev/null +++ b/sandbox/app/dependencies.py @@ -0,0 +1,161 @@ +"""Dependency management""" +import asyncio +from pathlib import Path +from typing import List, Dict + +from app.config import get_config +from app.core.runners.python.env import prepare_python_dependencies_env +from app.logger import get_logger + + +async def setup_dependencies(): + """Setup initial dependencies""" + logger = get_logger() + + try: + logger.info("Installing Python dependencies...") + await install_python_dependencies() + logger.info("Python dependencies installed") + + logger.info("Preparing Python dependencies environment...") + await prepare_python_dependencies_env() + logger.info("Python dependencies environment ready") + + except Exception as e: + logger.error(f"Failed to setup dependencies: {e}") + + +async def update_dependencies(): + # TODO + return + + +async def install_python_dependencies(): + """Install Python dependencies from requirements file""" + logger = get_logger() + config = get_config() + + # Check if requirements file exists + req_file = Path("dependencies/python-requirements.txt") + if not req_file.exists(): + logger.warning("Python requirements file not found, skipping installation") + return + + # Read requirements + requirements = req_file.read_text().strip() + if not requirements: + logger.info("No Python requirements to install") + return + + # Install using pip + cmd = [ + config.python_path, + "-m", + "pip", + "install", + "--upgrade" + ] + + # Add packages from requirements + for line in requirements.split("\n"): + line = line.strip() + if line and not line.startswith("#"): + cmd.append(line) + + try: + process = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + + stdout, stderr = await process.communicate() + + if process.returncode != 0: + 
logger.error(f"Failed to install Python dependencies: {stderr.decode()}") + else: + logger.info("Python dependencies installed successfully") + + except Exception as e: + logger.error(f"Error installing Python dependencies: {e}") + + +async def list_dependencies(language: str) -> List[Dict[str, str]]: + """List installed dependencies + + Args: + language: Language (python or Node.js) + + Returns: + List of dependencies with name and version + """ + if language == "python": + return await list_python_packages() + else: + return [] + + +async def list_python_packages() -> List[Dict[str, str]]: + """List installed Python packages""" + config = get_config() + + try: + process = await asyncio.create_subprocess_exec( + config.python_path, + "-m", + "pip", + "list", + "--format=freeze", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + + stdout, stderr = await process.communicate() + + if process.returncode != 0: + return [] + + # Parse output + packages = [] + for line in stdout.decode().split("\n"): + line = line.strip() + if line and "==" in line: + name, version = line.split("==", 1) + packages.append({"name": name, "version": version}) + + return packages + + except Exception as e: + get_logger().error(f"Failed to list Python packages: {e}") + return [] + + +async def update_dependencies_periodically(): + """Periodically update dependencies""" + logger = get_logger() + config = get_config() + + # Parse interval + interval_str = config.python_deps_update_interval + + # Convert to seconds + if interval_str.endswith("m"): + interval = int(interval_str[:-1]) * 60 + elif interval_str.endswith("h"): + interval = int(interval_str[:-1]) * 3600 + elif interval_str.endswith("s"): + interval = int(interval_str[:-1]) + else: + interval = 1800 # Default 30 minutes + + logger.info(f"Starting periodic dependency updates every {interval} seconds") + + while True: + await asyncio.sleep(interval) + + try: + logger.info("Updating Python dependencies...") + # 
TODO: await update_dependencies("python") + logger.info("Python dependencies updated successfully") + except Exception as e: + logger.error(f"Failed to update Python dependencies: {e}") diff --git a/sandbox/app/logger.py b/sandbox/app/logger.py new file mode 100644 index 00000000..de2ccc9e --- /dev/null +++ b/sandbox/app/logger.py @@ -0,0 +1,42 @@ +"""Logging configuration""" +import logging +import sys +from typing import Optional + +from app.config import get_config + +_logger: Optional[logging.Logger] = None + + +def setup_logger() -> logging.Logger: + """Setup application logger""" + global _logger + + config = get_config() + + # Create logger + _logger = logging.getLogger("sandbox") + _logger.setLevel(logging.DEBUG if config.app.debug else logging.INFO) + + # Create console handler + handler = logging.StreamHandler(sys.stdout) + handler.setLevel(logging.DEBUG if config.app.debug else logging.INFO) + + # Create formatter + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + handler.setFormatter(formatter) + + # Add handler to logger + _logger.addHandler(handler) + + return _logger + + +def get_logger() -> logging.Logger: + """Get application logger""" + if _logger is None: + return setup_logger() + return _logger diff --git a/sandbox/app/middleware/__init__.py b/sandbox/app/middleware/__init__.py new file mode 100644 index 00000000..77d6403c --- /dev/null +++ b/sandbox/app/middleware/__init__.py @@ -0,0 +1 @@ +"""Middleware package""" diff --git a/sandbox/app/middleware/auth.py b/sandbox/app/middleware/auth.py new file mode 100644 index 00000000..8a93a793 --- /dev/null +++ b/sandbox/app/middleware/auth.py @@ -0,0 +1,15 @@ +"""Authentication middleware""" +from fastapi import Header, HTTPException, status + +from app.config import get_config + + +async def verify_api_key(x_api_key: str = Header(..., alias="X-Api-Key")): + """Verify API key from request header""" + config = get_config() + if 
x_api_key != config.app.key: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key" + ) + return x_api_key diff --git a/sandbox/app/middleware/concurrency.py b/sandbox/app/middleware/concurrency.py new file mode 100644 index 00000000..8d8325a4 --- /dev/null +++ b/sandbox/app/middleware/concurrency.py @@ -0,0 +1,48 @@ +"""Concurrency control middleware""" +import asyncio +from fastapi import HTTPException, status + +from app.config import get_config +from app.models import error_response + + +# Global semaphores +_worker_semaphore: None | asyncio.Semaphore = None +_request_counter = 0 +_request_lock = asyncio.Lock() + + +def init_concurrency_control(): + """Initialize concurrency control""" + global _worker_semaphore + config = get_config() + _worker_semaphore = asyncio.Semaphore(config.max_workers) + + +async def check_max_requests(): + """Check if max requests limit is reached""" + global _request_counter + config = get_config() + + async with _request_lock: + if _request_counter >= config.max_requests: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail=error_response(-503, "Too many requests") + ) + _request_counter += 1 + + try: + yield + finally: + async with _request_lock: + _request_counter -= 1 + + +async def acquire_worker(): + """Acquire a worker slot""" + if _worker_semaphore is None: + init_concurrency_control() + + async with _worker_semaphore: + yield diff --git a/sandbox/app/models.py b/sandbox/app/models.py new file mode 100644 index 00000000..e7492b4c --- /dev/null +++ b/sandbox/app/models.py @@ -0,0 +1,80 @@ +"""Data models""" +from typing import Optional, Any + +from pydantic import BaseModel, Field + + +class RunnerOptions(BaseModel): + enable_network: bool = Field(default=False, description="Sandbox network flag") + + +class RunCodeRequest(BaseModel): + """Request model for code execution""" + language: str = Field(..., description="Programming language (python3 or nodejs)") 
+ code: str = Field(..., description="Base64 encoded encrypted code") + preload: Optional[str] = Field(default="", description="Preload code") + options: RunnerOptions = Field(default_factory=RunnerOptions, description="Enable network access") + + +class RunCodeResponse(BaseModel): + """Response model for code execution""" + stdout: str = Field(default="", description="Standard output") + stderr: str = Field(default="", description="Standard error") + + +class DependencyRequest(BaseModel): + """Request model for dependency operations""" + language: str = Field(..., description="Programming language") + + +class UpdateDependencyRequest(BaseModel): + """Request model for updating dependencies""" + language: str = Field(..., description="Programming language") + packages: list[str] = Field(default_factory=list, description="Packages to install") + + +class Dependency(BaseModel): + """Dependency information""" + name: str + version: str + + +class ListDependenciesResponse(BaseModel): + """Response model for listing dependencies""" + dependencies: list[Dependency] = Field(default_factory=list) + + +class RefreshDependenciesResponse(BaseModel): + """Response model for refreshing dependencies""" + dependencies: list[Dependency] = Field(default_factory=list) + + +class UpdateDependenciesResponse(BaseModel): + """Response model for updating dependencies""" + success: bool = True + installed: list[str] = Field(default_factory=list) + + +class HealthResponse(BaseModel): + """Health check response""" + status: str = "healthy" + version: str = "2.0.0" + + +class ApiResponse(BaseModel): + """Standard API response wrapper""" + code: int = Field(default=0, description="Response code (0 for success, negative for error)") + message: str = Field(default="success", description="Response message") + data: Optional[Any] = Field(default=None, description="Response data") + + +def success_response(data: Any) -> ApiResponse: + """Create success response""" + return ApiResponse(code=0, 
message="success", data=data) + + +def error_response(code: int, message: str) -> ApiResponse: + """Create error response""" + if code >= 0: + code = -1 + return ApiResponse(code=code, message=message, data=None) diff --git a/sandbox/app/services/__init__.py b/sandbox/app/services/__init__.py new file mode 100644 index 00000000..e3726046 --- /dev/null +++ b/sandbox/app/services/__init__.py @@ -0,0 +1 @@ +"""Services package""" diff --git a/sandbox/app/services/python_service.py b/sandbox/app/services/python_service.py new file mode 100644 index 00000000..71cfda0d --- /dev/null +++ b/sandbox/app/services/python_service.py @@ -0,0 +1,80 @@ +"""Python execution service""" +import signal + +from app.core.runners.python.python_runner import PythonRunner +from app.dependencies import ( + list_dependencies as list_deps, + update_dependencies as update_deps +) +from app.logger import get_logger +from app.models import ( + success_response, + error_response, + RunCodeResponse, + ListDependenciesResponse, + UpdateDependenciesResponse, + Dependency, + RunnerOptions +) + + +async def run_python_code(code: str, preload: str, options: RunnerOptions): + """Execute Python code in sandbox + + Args: + options: + code: Base64 encoded encrypted code + preload: Preload code + + Returns: + API response with execution result + """ + logger = get_logger() + + try: + runner = PythonRunner() + result = await runner.run(code, options, preload) + if result.exit_code == -signal.SIGSYS: + return error_response(31, "sandbox security policy violation") + + if result.error: + return error_response(-500, result.error) + + return success_response(RunCodeResponse( + stdout=result.stdout, + stderr=result.stderr + )) + + except Exception as e: + logger.error(f"Python execution failed: {e}", exc_info=True) + return error_response(-500, str(e)) + + +async def list_python_dependencies(): + """List installed Python dependencies + + Returns: + API response with dependency list + """ + try: + deps = await 
list_deps("python") + dependencies = [ + Dependency(name=dep["name"], version=dep["version"]) + for dep in deps + ] + return success_response(ListDependenciesResponse(dependencies=dependencies)) + except Exception as e: + return error_response(500, str(e)) + + +async def update_python_dependencies(): + """Update Python dependencies + + Returns: + API response with update result + """ + try: + await update_deps() + return success_response(UpdateDependenciesResponse(success=True)) + except Exception as e: + return error_response(500, str(e)) diff --git a/sandbox/config.yaml b/sandbox/config.yaml new file mode 100644 index 00000000..d9581b34 --- /dev/null +++ b/sandbox/config.yaml @@ -0,0 +1,20 @@ +app: + port: 8194 + debug: true + key: redbear-sandbox + +max_workers: 4 +max_requests: 50 +worker_timeout: 30 +python_path: /usr/local/bin/python +nodejs_path: /usr/local/bin/node +enable_network: true +enable_preload: false +python_deps_update_interval: 30m + +allowed_syscalls: [] + +proxy: + socks5: '' + http: '' + https: '' diff --git a/sandbox/dependencies/python-requirements.txt b/sandbox/dependencies/python-requirements.txt new file mode 100644 index 00000000..1c3c2901 --- /dev/null +++ b/sandbox/dependencies/python-requirements.txt @@ -0,0 +1,4 @@ +requests==2.31.0 +# numpy==1.26.0 +# pandas==2.0.0 +jinja2==3.1.2 \ No newline at end of file diff --git a/sandbox/lib/seccomp_nodejs/Cargo.lock b/sandbox/lib/seccomp_nodejs/Cargo.lock new file mode 100644 index 00000000..b37698ee --- /dev/null +++ b/sandbox/lib/seccomp_nodejs/Cargo.lock @@ -0,0 +1,7 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 4 + +[[package]] +name = "seccomp_nodejs" +version = "0.1.0" diff --git a/sandbox/lib/seccomp_nodejs/Cargo.toml b/sandbox/lib/seccomp_nodejs/Cargo.toml new file mode 100644 index 00000000..a8bd8932 --- /dev/null +++ b/sandbox/lib/seccomp_nodejs/Cargo.toml @@ -0,0 +1,6 @@ +[package] +name = "seccomp_nodejs" +version = "0.1.0" +edition = "2024" + +[dependencies] \ No newline at end of file diff --git a/sandbox/lib/seccomp_nodejs/src/lib.rs b/sandbox/lib/seccomp_nodejs/src/lib.rs new file mode 100644 index 00000000..e69de29b diff --git a/sandbox/lib/seccomp_python/Cargo.lock b/sandbox/lib/seccomp_python/Cargo.lock new file mode 100644 index 00000000..881ad177 --- /dev/null +++ b/sandbox/lib/seccomp_python/Cargo.lock @@ -0,0 +1,23 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "libc" +version = "0.2.180" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc" + +[[package]] +name = "libseccomp-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60276e2d41bbb68b323e566047a1bfbf952050b157d8b5cdc74c07c1bf4ca3b6" + +[[package]] +name = "seccomp_python" +version = "0.1.0" +dependencies = [ + "libc", + "libseccomp-sys", +] diff --git a/sandbox/lib/seccomp_python/Cargo.toml b/sandbox/lib/seccomp_python/Cargo.toml new file mode 100644 index 00000000..07037172 --- /dev/null +++ b/sandbox/lib/seccomp_python/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "seccomp_python" +version = "0.1.0" +edition = "2024" + +[lib] +name = "python" +crate-type = ["cdylib"] + +[dependencies] +libc = "0.2.180" +libseccomp-sys = "0.3.0" diff --git a/sandbox/lib/seccomp_python/src/lib.rs b/sandbox/lib/seccomp_python/src/lib.rs new file mode 100644 index 00000000..08b46c54 --- /dev/null +++ b/sandbox/lib/seccomp_python/src/lib.rs @@ -0,0 +1,195 @@ +mod syscalls; + +use 
crate::syscalls::*; +use libc::{chdir, chroot, gid_t, uid_t, c_int}; +use libseccomp_sys::*; +use std::env; +use std::ffi::CString; +use std::str::FromStr; + + +/* + * get_allowed_syscalls - retrieve allowed syscalls for the sandbox + * @enable_network: enable network-related syscalls if non-zero + * + * Syscall selection order: + * 1. ALLOWED_SYSCALLS environment variable + * 2. Built-in default allowlist + * 3. Optional network syscall extension + * + * Returns: + * (allowed_syscalls, allowed_not_kill_syscalls) + * allowed_syscalls: syscalls fully allowed + * allowed_not_kill_syscalls: syscalls returning EPERM + */ +pub fn get_allowed_syscalls(enable_network: bool) -> (Vec, Vec) { + let mut allowed_syscalls = Vec::new(); + let mut allowed_not_kill_syscalls = Vec::new(); + + /* Syscalls that return error instead of killing */ + allowed_not_kill_syscalls.extend(ALLOW_ERROR_SYSCALLS); + + /* Load from environment variable ALLOWED_SYSCALLS */ + if let Ok(env_val) = env::var("ALLOWED_SYSCALLS") { + if !env_val.is_empty() { + for s in env_val.split(',') { + if let Ok(sc) = i32::from_str(s) { + allowed_syscalls.push(sc); + } + } + } + } + + /* Fallback to default syscalls if env not set */ + if allowed_syscalls.is_empty() { + allowed_syscalls.extend(ALLOW_SYSCALLS); + if enable_network { + allowed_syscalls.extend(ALLOW_NETWORK_SYSCALLS); + } + } + + (allowed_syscalls, allowed_not_kill_syscalls) +} + +/* + * setup_root - setup restricted filesystem root + * + * Perform chroot(".") and change working directory to "/". + * + * Return: + * 0 on success + * negative error code on failure + */ +fn setup_root() -> Result<(), c_int> { + let root = CString::new(".").unwrap(); + if unsafe { chroot(root.as_ptr()) } != 0 { + return Err(-1); + } + + let root_dir = CString::new("/").unwrap(); + if unsafe { chdir(root_dir.as_ptr()) } != 0 { + return Err(-2); + } + + Ok(()) +} + +/* + * set_no_new_privs - enable PR_SET_NO_NEW_PRIVS + * + * Prevent privilege escalation via execve. 
+ * + * Return: + * 0 on success + * negative error code on failure + */ +fn set_no_new_privs() -> Result<(), c_int> { + if unsafe { libc::prctl(libc::PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0) } != 0 { + return Err(-3); + } + Ok(()) +} + +/* + * drop_privileges - drop process privileges + * @uid: target user ID + * @gid: target group ID + * + * Permanently reduce process privileges. + * + * Return: + * 0 on success + * negative error code on failure + */ +fn drop_privileges(uid: uid_t, gid: gid_t) -> Result<(), c_int> { + if unsafe { libc::setgid(gid) } != 0 { + return Err(-4); + } + if unsafe { libc::setuid(uid) } != 0 { + return Err(-5); + } + Ok(()) +} + +/* + * install_seccomp - install seccomp filter + * @enable_network: enable network-related syscalls if non-zero + * + * Default action is SCMP_ACT_KILL_PROCESS. + * Allowed syscalls are explicitly whitelisted. + * + * Return: + * 0 on success + * negative error code on failure + */ +fn install_seccomp(enable_network: bool) -> Result<(), c_int> { + unsafe { + let ctx = seccomp_init(SCMP_ACT_KILL_PROCESS); + if ctx.is_null() { + return Err(-6); /* failed to init seccomp context */ + } + + let (allowed_syscalls, allowed_not_kill_syscalls) = get_allowed_syscalls(enable_network); + + /* add fully allowed syscalls */ + for &sc in &allowed_syscalls { + if seccomp_rule_add(ctx, SCMP_ACT_ALLOW, sc, 0) != 0 { + seccomp_release(ctx); + return Err(-7); + } + } + + /* add syscalls returning EPERM */ + for &sc in &allowed_not_kill_syscalls { + if seccomp_rule_add(ctx, SCMP_ACT_ERRNO(libc::EPERM as u16), sc, 0) != 0 { + seccomp_release(ctx); + return Err(-8); + } + } + + if seccomp_load(ctx) != 0 { + seccomp_release(ctx); + return Err(-9); + } + + seccomp_release(ctx); + Ok(()) + } +} + +/* + * init_seccomp - initialize seccomp sandbox + * @uid: target user ID + * @gid: target group ID + * @enable_network: enable network syscalls if non-zero + * + * Initialize the sandbox and apply privilege restrictions + * in the following order: + 
* 1. setup_root() + * 2. set_no_new_privs() + * 3. drop_privileges() + * 4. install_seccomp() + * + * This function must be called before executing any untrusted code. + * It is not thread-safe and must be invoked once per process. + * + * Return: + * 0 on success + * negative error code on failure + */ +#[unsafe(no_mangle)] +pub unsafe extern "C" fn init_seccomp(uid: uid_t, gid: gid_t, enable_network: i32) -> c_int { + if let Err(code) = setup_root() { + return code; + } + if let Err(code) = set_no_new_privs() { + return code; + } + if let Err(code) = drop_privileges(uid, gid) { + return code; + } + match install_seccomp(enable_network != 0) { + Ok(_) => 0, + Err(code) => code, + } +} diff --git a/sandbox/lib/seccomp_python/src/syscalls.rs b/sandbox/lib/seccomp_python/src/syscalls.rs new file mode 100644 index 00000000..07070d22 --- /dev/null +++ b/sandbox/lib/seccomp_python/src/syscalls.rs @@ -0,0 +1,85 @@ +// src/syscalls.rs + +pub static ALLOW_SYSCALLS: &[i32] = &[ + // file io + libc::SYS_read as i32, + libc::SYS_write as i32, + libc::SYS_openat as i32, + libc::SYS_close as i32, + libc::SYS_newfstatat as i32, + libc::SYS_ioctl as i32, + libc::SYS_lseek as i32, + libc::SYS_getdents64 as i32, + + // thread + libc::SYS_futex as i32, + + // memory + libc::SYS_mmap as i32, + libc::SYS_brk as i32, + libc::SYS_mprotect as i32, + libc::SYS_munmap as i32, + libc::SYS_rt_sigreturn as i32, + libc::SYS_mremap as i32, + + // user / group + libc::SYS_setuid as i32, + libc::SYS_setgid as i32, + libc::SYS_getuid as i32, + + // process + libc::SYS_getpid as i32, + libc::SYS_getppid as i32, + libc::SYS_gettid as i32, + libc::SYS_exit as i32, + libc::SYS_exit_group as i32, + libc::SYS_tgkill as i32, + libc::SYS_rt_sigaction as i32, + libc::SYS_sched_yield as i32, + libc::SYS_set_robust_list as i32, + libc::SYS_get_robust_list as i32, + libc::SYS_rseq as i32, + + // time + libc::SYS_clock_gettime as i32, + libc::SYS_gettimeofday as i32, + libc::SYS_nanosleep as i32, + 
libc::SYS_epoll_create1 as i32, + libc::SYS_epoll_ctl as i32, + libc::SYS_clock_nanosleep as i32, + libc::SYS_pselect6 as i32, + libc::SYS_rt_sigprocmask as i32, + libc::SYS_sigaltstack as i32, + libc::SYS_getrandom as i32, + +]; + +pub static ALLOW_ERROR_SYSCALLS: &[i32] = &[ + libc::SYS_clone as i32, + libc::SYS_mkdirat as i32, + libc::SYS_mkdir as i32, +]; + +pub static ALLOW_NETWORK_SYSCALLS: &[i32] = &[ + libc::SYS_socket as i32, + libc::SYS_connect as i32, + libc::SYS_bind as i32, + libc::SYS_listen as i32, + libc::SYS_accept as i32, + libc::SYS_sendto as i32, + libc::SYS_recvfrom as i32, + libc::SYS_getsockname as i32, + libc::SYS_recvmsg as i32, + libc::SYS_getpeername as i32, + libc::SYS_setsockopt as i32, + libc::SYS_ppoll as i32, + libc::SYS_uname as i32, + libc::SYS_sendmsg as i32, + libc::SYS_sendmmsg as i32, + libc::SYS_getsockopt as i32, + libc::SYS_fstat as i32, + libc::SYS_fcntl as i32, + libc::SYS_fstatfs as i32, + libc::SYS_poll as i32, + libc::SYS_epoll_pwait as i32, +]; diff --git a/sandbox/main.py b/sandbox/main.py new file mode 100644 index 00000000..fc417563 --- /dev/null +++ b/sandbox/main.py @@ -0,0 +1,97 @@ +""" +Redbear Sandbox - Main Entry Point +""" +import asyncio +import os +import sys +from contextlib import asynccontextmanager + +import uvicorn +from fastapi import FastAPI + +from app.config import get_config +from app.controllers import manager_router +from app.dependencies import setup_dependencies, update_dependencies_periodically +from app.logger import setup_logger, get_logger + +logger = get_logger() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan manager""" + logger = get_logger() + + # Startup + logger.info("Starting RedBear Sandbox...") + + # Setup dependencies in background + asyncio.create_task(setup_dependencies()) + + # Start periodic dependency updates + config = get_config() + if config.python_deps_update_interval: + asyncio.create_task(update_dependencies_periodically()) + + 
yield + + # Shutdown + logger.info("Shutting down Redbear Sandbox...") + + +def create_app() -> FastAPI: + """Create FastAPI application""" + config = get_config() + + app = FastAPI( + title="Sandbox", + description="Secure code execution sandbox", + version="2.0.0", + lifespan=lifespan, + debug=config.app.debug + ) + + app.include_router(manager_router) + + return app + + +def check_root_privileges(): + """Check if running with root privileges""" + if os.geteuid() != 0: + logger.info("Error: Sandbox must be run as root for security features (chroot, setuid)") + sys.exit(1) + + +def main(): + """Main entry point""" + # Check root privileges + check_root_privileges() + + # Setup logging + setup_logger() + + config = get_config() + logger = get_logger() + + logger.info(f"Starting server on port {config.app.port}") + logger.info(f"Debug mode: {config.app.debug}") + logger.info(f"Max workers: {config.max_workers}") + logger.info(f"Max requests: {config.max_requests}") + logger.info(f"Network enabled: {config.enable_network}") + + # Create app + app = create_app() + + # Run server + uvicorn.run( + app, + host="0.0.0.0", + port=config.app.port, + log_level="debug" if config.app.debug else "info", + access_log=config.app.debug + ) + + +if __name__ == "__main__": + main() diff --git a/sandbox/requirements.txt b/sandbox/requirements.txt new file mode 100644 index 00000000..0c91018a --- /dev/null +++ b/sandbox/requirements.txt @@ -0,0 +1,20 @@ +# Web Framework +fastapi==0.115.0 +uvicorn[standard]==0.32.0 +pydantic==2.9.0 +pydantic-settings==2.5.0 + +# Configuration +PyYAML==6.0.2 + +# Security +pyseccomp==0.1.2 + + +# Async & Concurrency +aiofiles==24.1.0 + +# Testing +pytest==8.3.0 +pytest-asyncio==0.24.0 +httpx==0.27.0 diff --git a/sandbox/script/env.sh b/sandbox/script/env.sh new file mode 100644 index 00000000..f44f7208 --- /dev/null +++ b/sandbox/script/env.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +# Check if the correct number of arguments are provided +if [ "$#" -ne 2 ]; 
then + echo "Usage: $0 " + exit 1 +fi + +src="$1" +dest="$2" + +# Function to copy and link files +copy_and_link() { + local src_file="$1" + local dest_file="$2" + + if [ -L "$src_file" ]; then + # If src_file is a symbolic link, copy it without changing permissions + cp -P "$src_file" "$dest_file" + elif [ -b "$src_file" ] || [ -c "$src_file" ]; then + # If src_file is a device file, copy it and change permissions + cp "$src_file" "$dest_file" + chmod 444 "$dest_file" + else + # Otherwise, create a hard link and change the permissions to read-only + ln -f "$src_file" "$dest_file" 2>/dev/null || { cp "$src_file" "$dest_file" && chmod 444 "$dest_file"; } + fi +} + +# Check if src is a file or directory +if [ -f "$src" ]; then + # src is a file, create hard link directly in dest + mkdir -p "$(dirname "$dest/$src")" + copy_and_link "$src" "$dest/$src" +elif [ -d "$src" ]; then + # src is a directory, process as before + mkdir -p "$dest/$src" + + # Find all files in the source directory + find "$src" -type f,l | while read -r file; do + # Get the relative path of the file + rel_path="${file#$src/}" + # Get the directory of the relative path + rel_dir=$(dirname "$rel_path") + # Create the same directory structure in the destination + mkdir -p "$dest/$src/$rel_dir" + # Copy and link the file + copy_and_link "$file" "$dest/$src/$rel_path" + done +else + echo "Error: $src is neither a file nor a directory" + exit 1 +fi From 0fd8a122fb01a71e497ba15692b63bd5657b8204 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Mon, 26 Jan 2026 11:59:13 +0800 Subject: [PATCH 11/28] feat(workflow): emit SSE events for node exception output --- api/app/core/workflow/executor.py | 50 ++++++++++++++++-------- api/app/core/workflow/nodes/base_node.py | 5 +++ 2 files changed, 38 insertions(+), 17 deletions(-) diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 6721d7b0..f3feff60 100644 --- a/api/app/core/workflow/executor.py +++ 
b/api/app/core/workflow/executor.py @@ -261,7 +261,7 @@ class WorkflowExecutor: "data": { "execution_id": self.execution_id, "workspace_id": self.workspace_id, - "timestamp": start_time.isoformat() + "timestamp": int(start_time.timestamp() * 1000) } } @@ -293,20 +293,33 @@ class WorkflowExecutor: # Handle custom streaming events (chunks from nodes via stream writer) chunk_count += 1 event_type = data.get("type", "node_chunk") # "message" or "node_chunk" - logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}" - f"- execution_id: {self.execution_id}") - yield { - "event": event_type, # "message" or "node_chunk" - "data": { - "node_id": data.get("node_id"), - "chunk": data.get("chunk"), - "full_content": data.get("full_content"), - "chunk_index": data.get("chunk_index"), - "is_prefix": data.get("is_prefix"), - "is_suffix": data.get("is_suffix"), - "conversation_id": input_data.get("conversation_id"), + if event_type in ("message", "node_chunk"): + logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}" + f"- execution_id: {self.execution_id}") + yield { + "event": event_type, # "message" or "node_chunk" + "data": { + "node_id": data.get("node_id"), + "chunk": data.get("chunk"), + "full_content": data.get("full_content"), + "chunk_index": data.get("chunk_index"), + "is_prefix": data.get("is_prefix"), + "is_suffix": data.get("is_suffix"), + "conversation_id": input_data.get("conversation_id"), + } + } + elif event_type == "node_error": + yield { + "event": event_type, # "message" or "node_chunk" + "data": { + "node_id": data.get("node_id"), + "status": "failed", + "input": data.get("input_data"), + "elapsed_time": data.get("elapsed_time"), + "output": None, + "error": data.get("error") + } } - } elif mode == "debug": # Handle debug information (node execution status) @@ -325,14 +338,15 @@ class WorkflowExecutor: conversation_id = input_data.get("conversation_id") logger.info(f"[NODE-START] Node starts execution: 
{node_name} " f"- execution_id: {self.execution_id}") - yield { "event": "node_start", "data": { "node_id": node_name, "conversation_id": conversation_id, "execution_id": self.execution_id, - "timestamp": data.get("timestamp"), + "timestamp": int(datetime.datetime.fromisoformat( + data.get("timestamp") + ).timestamp() * 1000), } } elif event_type == "task_result": @@ -351,7 +365,9 @@ class WorkflowExecutor: "node_id": node_name, "conversation_id": conversation_id, "execution_id": self.execution_id, - "timestamp": data.get("timestamp"), + "timestamp": int(datetime.datetime.fromisoformat( + data.get("timestamp") + ).timestamp() * 1000), "state": result.get("node_outputs", {}).get(node_name), } } diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 0c015c89..61d5ca1e 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -541,6 +541,11 @@ class BaseNode(ABC): "error_node": self.node_id } else: + writer = get_stream_writer() + writer({ + "type": "node_error", + **node_output + }) # 无错误边:抛出异常停止工作流 logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}") raise Exception(f"节点 {self.node_id} 执行失败: {error_message}") From 1fc04c37d3456790f2bfac96e743eb92629981e9 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Mon, 26 Jan 2026 12:22:54 +0800 Subject: [PATCH 12/28] perf(sandbox): optimize code encryption handling --- sandbox/app/core/encryption.py | 3 ++- sandbox/app/core/runners/python/prescript.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sandbox/app/core/encryption.py b/sandbox/app/core/encryption.py index 5e0855c9..47a756c8 100644 --- a/sandbox/app/core/encryption.py +++ b/sandbox/app/core/encryption.py @@ -12,9 +12,10 @@ def encrypt_code(code: bytes, key: bytes) -> str: Returns: Base64 encoded encrypted code """ + key_length = len(key) encrypted_code = bytearray(len(code)) for i in range(len(code)): - encrypted_code[i] = 
code[i] ^ key[i % 64] + encrypted_code[i] = code[i] ^ key[i % key_length] encoded_code = base64.b64encode(encrypted_code).decode("utf-8") return encoded_code diff --git a/sandbox/app/core/runners/python/prescript.py b/sandbox/app/core/runners/python/prescript.py index 4790be73..950710ea 100644 --- a/sandbox/app/core/runners/python/prescript.py +++ b/sandbox/app/core/runners/python/prescript.py @@ -17,7 +17,7 @@ sys.excepthook = excepthook # Load security library if available lib = ctypes.CDLL("./libpython.so") lib.init_seccomp.argtypes = [ctypes.c_uint32, ctypes.c_uint32, ctypes.c_bool] -lib.init_seccomp.restype = None +lib.init_seccomp.restype = None # TODO: raise error info # Get running path running_path = sys.argv[1] From 85681db7b70017a4f1caf86f7ece84ab29741e13 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Mon, 26 Jan 2026 12:28:40 +0800 Subject: [PATCH 13/28] perf(workflow): update standard node output structure --- api/app/core/workflow/executor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index f3feff60..c4662113 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -368,7 +368,9 @@ class WorkflowExecutor: "timestamp": int(datetime.datetime.fromisoformat( data.get("timestamp") ).timestamp() * 1000), - "state": result.get("node_outputs", {}).get(node_name), + "input": result.get("node_outputs", {}).get(node_name, {}).get("input"), + "output": result.get("node_outputs", {}).get(node_name, {}).get("output"), + "elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"), } } From 8228d38859e6ed1782543363d16b9ee093351f71 Mon Sep 17 00:00:00 2001 From: Mark Date: Mon, 26 Jan 2026 14:26:32 +0800 Subject: [PATCH 14/28] [add] migration script --- .../versions/325b759cd66b_2026011240.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 
api/migrations/versions/325b759cd66b_2026011240.py diff --git a/api/migrations/versions/325b759cd66b_2026011240.py b/api/migrations/versions/325b759cd66b_2026011240.py new file mode 100644 index 00000000..66c8681c --- /dev/null +++ b/api/migrations/versions/325b759cd66b_2026011240.py @@ -0,0 +1,50 @@ +"""2026011240 + +Revision ID: 325b759cd66b +Revises: 9a936a9ebb20 +Create Date: 2026-01-26 12:37:35.946749 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +revision: str = '325b759cd66b' +down_revision: Union[str, None] = '9a936a9ebb20' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # 1. 重命名表 data_config -> memory_config + op.rename_table('data_config', 'memory_config') + + # 2. 重命名列 group_id -> end_user_id + op.alter_column('memory_config', 'group_id', new_column_name='end_user_id') + + # 3. config_id: INTEGER -> UUID(保留旧值以便回滚) + op.alter_column('memory_config', 'config_id', new_column_name='config_id_old') + op.add_column('memory_config', sa.Column('config_id', sa.UUID(), nullable=True)) + op.execute("UPDATE memory_config SET config_id = apply_id::uuid") + op.drop_constraint('data_config_pkey', 'memory_config', type_='primary') + op.alter_column('memory_config', 'config_id', nullable=False) + op.create_primary_key('memory_config_pkey', 'memory_config', ['config_id']) + op.execute("DROP SEQUENCE IF EXISTS data_config_config_id_seq") + + +def downgrade() -> None: + # 1. 
config_id: UUID -> INTEGER(恢复旧值) + op.drop_constraint('memory_config_pkey', 'memory_config', type_='primary') + op.drop_column('memory_config', 'config_id') + op.alter_column('memory_config', 'config_id_old', new_column_name='config_id') + op.create_primary_key('data_config_pkey', 'memory_config', ['config_id']) + op.execute("CREATE SEQUENCE IF NOT EXISTS data_config_config_id_seq OWNED BY memory_config.config_id") + op.execute("SELECT setval('data_config_config_id_seq', COALESCE((SELECT MAX(config_id) FROM memory_config), 1))") + + # 2. 重命名列 end_user_id -> group_id + op.alter_column('memory_config', 'end_user_id', new_column_name='group_id') + + # 3. 重命名表 memory_config -> data_config + op.rename_table('memory_config', 'data_config') From b046411302553158e2af49eeb984ad3e564bf79b Mon Sep 17 00:00:00 2001 From: Mark Date: Mon, 26 Jan 2026 15:39:35 +0800 Subject: [PATCH 15/28] [modify] migration script --- api/migrations/versions/325b759cd66b_2026011240.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/api/migrations/versions/325b759cd66b_2026011240.py b/api/migrations/versions/325b759cd66b_2026011240.py index 66c8681c..763b0289 100644 --- a/api/migrations/versions/325b759cd66b_2026011240.py +++ b/api/migrations/versions/325b759cd66b_2026011240.py @@ -25,22 +25,24 @@ def upgrade() -> None: op.alter_column('memory_config', 'group_id', new_column_name='end_user_id') # 3. 
config_id: INTEGER -> UUID(保留旧值以便回滚) - op.alter_column('memory_config', 'config_id', new_column_name='config_id_old') + op.drop_constraint('data_config_pkey', 'memory_config', type_='primary') + op.alter_column('memory_config', 'config_id', new_column_name='config_id_old', nullable=True) op.add_column('memory_config', sa.Column('config_id', sa.UUID(), nullable=True)) op.execute("UPDATE memory_config SET config_id = apply_id::uuid") - op.drop_constraint('data_config_pkey', 'memory_config', type_='primary') op.alter_column('memory_config', 'config_id', nullable=False) op.create_primary_key('memory_config_pkey', 'memory_config', ['config_id']) op.execute("DROP SEQUENCE IF EXISTS data_config_config_id_seq") def downgrade() -> None: - # 1. config_id: UUID -> INTEGER(恢复旧值) + # 1. config_id: UUID -> INTEGER(恢复旧值,空值生成新ID) + op.execute("CREATE SEQUENCE IF NOT EXISTS data_config_config_id_seq") + op.execute("UPDATE memory_config SET config_id_old = nextval('data_config_config_id_seq') WHERE config_id_old IS NULL") op.drop_constraint('memory_config_pkey', 'memory_config', type_='primary') op.drop_column('memory_config', 'config_id') - op.alter_column('memory_config', 'config_id_old', new_column_name='config_id') + op.alter_column('memory_config', 'config_id_old', new_column_name='config_id', nullable=False) op.create_primary_key('data_config_pkey', 'memory_config', ['config_id']) - op.execute("CREATE SEQUENCE IF NOT EXISTS data_config_config_id_seq OWNED BY memory_config.config_id") + op.execute("ALTER SEQUENCE data_config_config_id_seq OWNED BY memory_config.config_id") op.execute("SELECT setval('data_config_config_id_seq', COALESCE((SELECT MAX(config_id) FROM memory_config), 1))") # 2. 
重命名列 end_user_id -> group_id From 2eff6b2e9da716ba88664ea9d830f8e57705e8e4 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 26 Jan 2026 15:46:28 +0800 Subject: [PATCH 16/28] feat(web): add workflow runtime info --- web/src/components/Chat/ChatContent.tsx | 17 +- web/src/components/Chat/types.ts | 5 +- web/src/components/Markdown/CodeBlock.tsx | 16 +- web/src/i18n/en.ts | 4 + web/src/i18n/zh.ts | 4 + web/src/utils/stream.ts | 14 ++ .../views/Workflow/components/Chat/Chat.tsx | 206 ++++++++++++++++-- .../Workflow/components/Chat/chat.module.css | 45 ++++ 8 files changed, 286 insertions(+), 25 deletions(-) create mode 100644 web/src/views/Workflow/components/Chat/chat.module.css diff --git a/web/src/components/Chat/ChatContent.tsx b/web/src/components/Chat/ChatContent.tsx index c90f9208..a5d02b2b 100644 --- a/web/src/components/Chat/ChatContent.tsx +++ b/web/src/components/Chat/ChatContent.tsx @@ -8,6 +8,7 @@ import { type FC, useRef, useEffect } from 'react' import clsx from 'clsx' import Markdown from '@/components/Markdown' import type { ChatContentProps } from './types' +import { Spin } from 'antd' /** * 聊天内容显示组件 @@ -21,7 +22,8 @@ const ChatContent: FC = ({ empty, labelPosition = 'bottom', labelFormat, - errorDesc + errorDesc, + renderRuntime }) => { // 滚动容器引用,用于控制自动滚动到底部 const scrollContainerRef = useRef<(HTMLDivElement | null)>(null) @@ -45,8 +47,8 @@ const ChatContent: FC = ({ 'rb:left-0 rb:text-left': item.role === 'assistant', // 助手消息左对齐 })}> {/* 流式加载时且内容为空则不显示 */} - {streamLoading && item.content === '' - ? null + {streamLoading && item.content === '' && !renderRuntime + ? : <> {/* 顶部标签(如时间戳、用户名等) */} {labelPosition === 'top' && @@ -55,16 +57,17 @@ const ChatContent: FC = ({ } {/* 消息气泡框 */} -
+ {item.subContent && renderRuntime && renderRuntime(item, index)} {/* 使用Markdown组件渲染消息内容 */} - +
{/* 底部标签(如时间戳、用户名等) */} {labelPosition === 'bottom' && diff --git a/web/src/components/Chat/types.ts b/web/src/components/Chat/types.ts index 851a8ccc..264ce39c 100644 --- a/web/src/components/Chat/types.ts +++ b/web/src/components/Chat/types.ts @@ -19,7 +19,9 @@ export interface ChatItem { /** 消息内容 */ content?: string | null; /** 创建时间 */ - created_at?: number | string + created_at?: number | string; + status?: string; + subContent?: Record[] } /** @@ -81,4 +83,5 @@ export interface ChatContentProps { /** 标签格式化函数 */ labelFormat: (item: ChatItem) => any; errorDesc?: string; + renderRuntime?: (item: ChatItem, index: number) => ReactNode; } \ No newline at end of file diff --git a/web/src/components/Markdown/CodeBlock.tsx b/web/src/components/Markdown/CodeBlock.tsx index 23d54c34..a125a997 100644 --- a/web/src/components/Markdown/CodeBlock.tsx +++ b/web/src/components/Markdown/CodeBlock.tsx @@ -6,6 +6,9 @@ import CopyBtn from './CopyBtn'; type ICodeBlockProps = { value: string; + needCopy?: boolean; + size?: 'small' | 'default'; + showLineNumbers?: boolean; } // enum languageType { @@ -16,6 +19,9 @@ type ICodeBlockProps = { const CodeBlock: FC = ({ value, + needCopy = true, + size = 'default', + showLineNumbers = false }) => { return ( @@ -23,24 +29,26 @@ const CodeBlock: FC = ({ {value} - + />} ) } diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 1df2eb6d..87a95c40 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -1982,6 +1982,10 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re arrange: 'Arrange', redo: 'Redo', undo: 'Undo', + + input: 'Input', + output: 'Output', + error: 'Error Message', }, emotionEngine: { emotionEngineConfig: 'Emotion Engine Configuration', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 39908757..fc683a66 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -2076,6 +2076,10 @@ export const zh = { arrange: '整理', redo: '重做', undo: '撤销', + + input: '输入', + output: '输出', + 
error: '错误信息', }, emotionEngine: { emotionEngineConfig: '情感引擎配置', diff --git a/web/src/utils/stream.ts b/web/src/utils/stream.ts index e4179e25..2501fde5 100644 --- a/web/src/utils/stream.ts +++ b/web/src/utils/stream.ts @@ -123,6 +123,20 @@ export const handleSSE = async (url: string, data: any, onMessage?: (data: SSEMe let response = await makeSSERequest(url, data, token || '', config); switch (response.status) { + case 500: + case 502: + const errorData = await response.json(); + errorData.error || i18n.t('common.serviceUpgrading'); + message.warning(errorData.error || i18n.t('common.serviceUpgrading')); + break + case 400: + const error = await response.json(); + message.warning(errorData.error); + throw error || 'Bad Request'; + case 504: + const errorJson = await response.json(); + message.warning(errorJson.error || i18n.t('common.serverError')); + break case 401: if (url?.includes('/public')) { return message.warning(i18n.t('common.publicApiCannotRefreshToken')); diff --git a/web/src/views/Workflow/components/Chat/Chat.tsx b/web/src/views/Workflow/components/Chat/Chat.tsx index 246c2e4c..4a1ac5a7 100644 --- a/web/src/views/Workflow/components/Chat/Chat.tsx +++ b/web/src/views/Workflow/components/Chat/Chat.tsx @@ -1,8 +1,9 @@ import { forwardRef, useImperativeHandle, useState, useRef } from 'react' import { useTranslation } from 'react-i18next' import clsx from 'clsx' -import { Input, Form, App } from 'antd' -import { Space, Button } from 'antd' +import { Input, Form, App, Space, Button, Collapse } from 'antd' +import { CheckCircleFilled, CloseCircleFilled, LoadingOutlined } from '@ant-design/icons' +import CodeBlock from '@/components/Markdown/CodeBlock' import ChatIcon from '@/assets/images/application/chat.png' import RbDrawer from '@/components/RbDrawer'; @@ -13,8 +14,11 @@ import ChatContent from '@/components/Chat/ChatContent' import type { ChatItem } from '@/components/Chat/types' import ChatSendIcon from '@/assets/images/application/chatSend.svg' 
import dayjs from 'dayjs' -import type { ChatRef, VariableConfigModalRef, StartVariableItem, GraphRef } from '../../types' +import type { ChatRef, VariableConfigModalRef, GraphRef } from '../../types' import { type SSEMessage } from '@/utils/stream' +import type { Variable } from '../Properties/VariableList/types' +import styles from './chat.module.css' +import Markdown from '@/components/Markdown' const Chat = forwardRef(({ appId, graphRef }, ref) => { const { t } = useTranslation() @@ -24,7 +28,7 @@ const Chat = forwardRef(({ appId const [open, setOpen] = useState(false) const [loading, setLoading] = useState(false) const [chatList, setChatList] = useState([]) - const [variables, setVariables] = useState([]) + const [variables, setVariables] = useState([]) const [streamLoading, setStreamLoading] = useState(false) const [conversationId, setConversationId] = useState(null) @@ -39,7 +43,7 @@ const Chat = forwardRef(({ appId if (startNodes.length) { const curVariables = startNodes[0].config.variables?.defaultValue - curVariables.forEach((vo: StartVariableItem) => { + curVariables.forEach((vo: Variable) => { if (typeof vo.default !== 'undefined') { vo.value = vo.default } @@ -60,7 +64,7 @@ const Chat = forwardRef(({ appId const handleEditVariables = () => { variableConfigModalRef.current?.handleOpen(variables) } - const handleSave = (values: StartVariableItem[]) => { + const handleSave = (values: Variable[]) => { setVariables([...values]) } const handleSend = () => { @@ -97,13 +101,28 @@ const Chat = forwardRef(({ appId role: 'assistant', content: '', created_at: Date.now(), + subContent: [], }]) const handleStreamMessage = (data: SSEMessage[]) => { - setStreamLoading(false) - data.forEach(item => { - const { chunk, conversation_id } = item.data as { chunk: string; conversation_id: string | null; }; + const { chunk, conversation_id, node_id, input, output, error, elapsed_time, status } = item.data as { + chunk: string; + conversation_id: string | null; + node_id: 
string; + node_name?: string; + input?: any; + output?: any; + elapsed_time?: string; + error?: any; + state: Record; + status?: 'completed' | 'failed' + }; + + const node = graphRef.current?.getNodes().find(n => n.id === node_id); + const { name, icon } = node?.getData() || {} + + console.log('node', node?.getData()) switch(item.event) { case 'message': @@ -119,6 +138,66 @@ const Chat = forwardRef(({ appId return newList }) break + case 'node_start': + setChatList(prev => { + const newList = [...prev] + const lastIndex = newList.length - 1 + if (lastIndex >= 0) { + const newSubContent = newList[lastIndex].subContent || [] + const filterIndex = newSubContent.findIndex(vo => vo.id === node_id) + if (filterIndex > -1) { + newSubContent[filterIndex] = { + ...newSubContent[filterIndex], + node_id: node_id, + node_name: name, + icon, + content: {}, + } + } else { + newSubContent.push({ + id: node_id, + node_id: node_id, + node_name: name, + icon, + content: {}, + }) + } + newList[lastIndex] = { + ...newList[lastIndex], + subContent: newSubContent + } + } + return newList + }) + break + case 'node_end': + case 'node_error': + setChatList(prev => { + const newList = [...prev] + const lastIndex = newList.length - 1 + if (lastIndex >= 0) { + const newSubContent = newList[lastIndex].subContent || [] + const filterIndex = newSubContent.findIndex(vo => vo.node_id === node_id) + if (filterIndex > -1 && newSubContent[filterIndex].content) { + newSubContent[filterIndex] = { + ...newSubContent[filterIndex], + content: { + input, + output, + error, + }, + status: status || 'completed', + elapsed_time + } + } + newList[lastIndex] = { + ...newList[lastIndex], + subContent: newSubContent + } + } + return newList + }) + break case 'workflow_end': setChatList(prev => { const newList = [...prev] @@ -126,6 +205,7 @@ const Chat = forwardRef(({ appId if (lastIndex >= 0) { newList[lastIndex] = { ...newList[lastIndex], + status, content: newList[lastIndex].content === '' ? 
null : newList[lastIndex].content } } @@ -142,14 +222,31 @@ const Chat = forwardRef(({ appId } form.setFieldValue('message', undefined) + setStreamLoading(true) draftRun(appId, { message: message, variables: params, stream: true, conversation_id: conversationId }, handleStreamMessage) + .catch((error) => { + setChatList(prev => { + const newList = [...prev] + const lastIndex = newList.length - 1 + if (lastIndex >= 0) { + newList[lastIndex] = { + ...newList[lastIndex], + status: 'failed', + content: null, + subContent: error.error + } + } + return newList + }) + }) .finally(() => { setLoading(false) + setStreamLoading(false) }) } // 暴露给父组件的方法 @@ -158,6 +255,11 @@ const Chat = forwardRef(({ appId handleClose })); + const getStatus = (status?: string) => { + return status === 'completed' ? 'rb:text-[#369F21]' : status === 'failed' ? 'rb:text-[#FF5D34]' : 'rb:text-[#5B6167]' + } + + console.log('chatList', chatList) return ( @@ -173,10 +275,7 @@ const Chat = forwardRef(({ appId onClose={handleClose} > } data={chatList} @@ -184,6 +283,87 @@ const Chat = forwardRef(({ appId labelPosition="bottom" labelFormat={(item) => dayjs(item.created_at).locale('en').format('MMMM D, YYYY [at] h:mm A')} errorDesc={t('application.ReplyException')} + renderRuntime={(item, index) => { + return ( +
+ + {item.status === 'completed' ? : item.status === 'failed' ? : } + {t('application.workflow')} +
, + className: styles.collapseItem, + children: ( + Array.isArray(item.subContent) + ? + {item.subContent?.map(vo => ( + +
+ {vo.icon && } +
{vo.node_name || vo.node_id}
+
+ + {typeof vo.elapsed_time == 'number' && <>{vo.elapsed_time?.toFixed(3)}ms} + {vo.status === 'completed' ? : vo.status === 'failed' ? : } + + , + className: styles.collapseItem, + children: ( + + {vo.status === 'failed' && +
+
+ {t(`workflow.error`)} + +
+
+ +
+
+ } + {['input', 'output'].map(key => ( +
+
+ {t(`workflow.${key}`)} + +
+
+ +
+
+ ))} +
+ ) + }]} + /> + ))} +
+ :
+ +
+ ) + }]} + /> + + ) + }} />
diff --git a/web/src/views/Workflow/components/Chat/chat.module.css b/web/src/views/Workflow/components/Chat/chat.module.css new file mode 100644 index 00000000..99fe11f7 --- /dev/null +++ b/web/src/views/Workflow/components/Chat/chat.module.css @@ -0,0 +1,45 @@ +.completed { + background-color: rgba(54, 159, 33, 0.06); + border-color: rgba(54, 159, 33, 0.25); + border-radius: 8px; +} +.failed { + background-color: rgba(255, 138, 76, 0.08); + border-color: rgba(255, 138, 76, 0.20); + border-radius: 8px; +} +.default { + background-color: rgba(91, 97, 103, 0.08); + border-color: rgba(91, 97, 103, 0.30); + border-radius: 8px; +} +.collapse-item { + font-size: 12px; + line-height: 16px; +} +.collapse-item:global(.ant-collapse-item>.ant-collapse-header) { + padding: 8px 12px; +} +.collapse-item:global(.ant-collapse-item>.ant-collapse-header .ant-collapse-expand-icon) { + height: 16px; +} +.completed:global(.ant-collapse .ant-collapse-content), +.failed:global(.ant-collapse .ant-collapse-content) { + background-color: transparent; + border-top: none; +} +:global(.ant-collapse .ant-collapse-content>.ant-collapse-content-box) { + padding-top: 0; +} +.collapse-item :global(.ant-collapse) { + /* background-color: #F0F3F8; */ + background-color: #FBFDFF; + border-radius: 6px; +} +.collapse-item :global(.ant-collapse>.ant-collapse-item:last-child), +.collapse-item :global(.ant-collapse>.ant-collapse-item:last-child>.ant-collapse-header) { + border-radius: 0 0 6px 6px; +} +.collapse-item :global(.ant-collapse .ant-collapse-content>.ant-collapse-content-box) { + padding: 0 4px 4px 4px; +} \ No newline at end of file From 7bfa7b3f029812a8966f639e0de3f6458c84adf1 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 26 Jan 2026 16:00:47 +0800 Subject: [PATCH 17/28] fix(web): handleSSE bugfix --- web/src/utils/stream.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/utils/stream.ts b/web/src/utils/stream.ts index 2501fde5..be2220da 100644 --- 
a/web/src/utils/stream.ts +++ b/web/src/utils/stream.ts @@ -131,7 +131,7 @@ export const handleSSE = async (url: string, data: any, onMessage?: (data: SSEMe break case 400: const error = await response.json(); - message.warning(errorData.error); + message.warning(error.error); throw error || 'Bad Request'; case 504: const errorJson = await response.json(); From 3b4b474ce869e4416db314df8669671a1c931a6f Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Mon, 26 Jan 2026 16:32:58 +0800 Subject: [PATCH 18/28] fix(sandbox): prevent imports from being blocked when network is disabled --- sandbox/lib/seccomp_python/src/syscalls.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sandbox/lib/seccomp_python/src/syscalls.rs b/sandbox/lib/seccomp_python/src/syscalls.rs index 07070d22..961fffac 100644 --- a/sandbox/lib/seccomp_python/src/syscalls.rs +++ b/sandbox/lib/seccomp_python/src/syscalls.rs @@ -10,6 +10,7 @@ pub static ALLOW_SYSCALLS: &[i32] = &[ libc::SYS_ioctl as i32, libc::SYS_lseek as i32, libc::SYS_getdents64 as i32, + libc::SYS_fstat as i32, // thread libc::SYS_futex as i32, @@ -77,7 +78,6 @@ pub static ALLOW_NETWORK_SYSCALLS: &[i32] = &[ libc::SYS_sendmsg as i32, libc::SYS_sendmmsg as i32, libc::SYS_getsockopt as i32, - libc::SYS_fstat as i32, libc::SYS_fcntl as i32, libc::SYS_fstatfs as i32, libc::SYS_poll as i32, From ebc41b2eec3ed5f1dcb46fc5b66e6c87abe1f437 Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Mon, 26 Jan 2026 17:22:48 +0800 Subject: [PATCH 19/28] Fix/memory bug fix (#199) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 图谱数据量限制数量去掉 * 图谱数据量限制数量去掉 * 图谱数据量限制数量去掉 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 读取的接口,去掉全局锁 * 输出数组 * 反思优化1.0(优化隐私输出、时间检索) * 反思优化1.0(优化隐私输出、时间检索) * 反思优化1.0(优化隐私输出、时间检索) * 反思优化测试接口 * 反思优化测试接口 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复 * 
新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察) * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 * 把group_id替换end_user_id * 把group_id替换end_user_id_ * 把group_id替换end_user_id_ * config_config替换成memory_config * config_config替换成memory_config * [fix]Fix the memory interface to use end_user_id. * config_config替换成memory_config * config_config替换成memory_config * config_config替换成memory_config * config_id字段改成UUID * config_id字段改成UUID * config_id字段改成UUID * config_id字段改成UUID,与develop校对恢复 * 检查项目,修复group_id的遗留问题 * 检查项目,修复group_id的遗留问题 * 解决冲突 * 解决冲突 * end_user_id清理干净 * end_user_id清理干净 * 修复遗留合并BUG * 修复遗留合并BUG * 修复遗留合并BUG * 修复遗留合并BUG * 感知meta_data字段BUG修复 * user_id->现实为config_id_old * user_id->显示为config_id_old传输 --------- Co-authored-by: lanceyq <1982376970@qq.com> --- api/app/services/memory_storage_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 80d8c717..1707f8fa 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -188,7 +188,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) "config_desc": config.config_desc, "workspace_id": str(config.workspace_id) if config.workspace_id else None, "end_user_id": config.end_user_id, - "user_id": config.user_id, + "config_id_old": config.config_id_old, "apply_id": config.apply_id, "llm_id": config.llm_id, "embedding_id": config.embedding_id, From 46f0f3cee90f7cf852bf5bcf89866b57448f1ffa Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 26 Jan 2026 17:43:25 +0800 Subject: [PATCH 20/28] feat(web): update read_all_config select valueKey --- web/src/components/CustomSelect/index.tsx | 19 +++++++++++++------ 
web/src/views/ApplicationConfig/Agent.tsx | 6 ++++-- web/src/views/Workflow/constant.ts | 4 ++-- web/src/views/Workflow/types.ts | 2 +- 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/web/src/components/CustomSelect/index.tsx b/web/src/components/CustomSelect/index.tsx index 1887d635..6153a76d 100644 --- a/web/src/components/CustomSelect/index.tsx +++ b/web/src/components/CustomSelect/index.tsx @@ -15,7 +15,7 @@ interface ApiResponse { interface CustomSelectProps extends Omit { url: string; params?: Record; - valueKey?: string; + valueKey?: string | string[]; labelKey?: string; placeholder?: string; hasAll?: boolean; @@ -66,11 +66,18 @@ const CustomSelect: FC = ({ {...props} > {hasAll && {allTitle || t('common.all')}} - {displayOptions.map((option) => ( - - {String(option[labelKey])} - - ))} + {displayOptions.map((option) => { + const getValue = () => { + if (typeof valueKey === 'string') return option[valueKey]; + return valueKey.find(key => option[key] != null) ? option[valueKey.find(key => option[key] != null)!] : undefined; + }; + const value = getValue(); + return ( + + {String(option[labelKey])} + + ); + })} ); }; diff --git a/web/src/views/ApplicationConfig/Agent.tsx b/web/src/views/ApplicationConfig/Agent.tsx index 77e90440..97a622d1 100644 --- a/web/src/views/ApplicationConfig/Agent.tsx +++ b/web/src/views/ApplicationConfig/Agent.tsx @@ -79,7 +79,7 @@ const SelectWrapper: FC<{ title: string, desc: string, name: string | string[], placeholder={t('common.pleaseSelect')} url={url} hasAll={false} - valueKey='config_id' + valueKey={['config_id_old', 'config_id']} labelKey="config_name" /> @@ -126,12 +126,14 @@ const Agent = forwardRef((_props, ref) => { getApplicationConfig(id as string).then(res => { const response = res as Config let allTools = Array.isArray(response.tools) ? response.tools : [] + const memoryContent = response.memory?.memory_content + const convertedMemoryContent = memoryContent && !isNaN(Number(memoryContent)) ? 
Number(memoryContent) : memoryContent form.setFieldsValue({ ...response, tools: allTools, memory: { ...response.memory, - memory_content: response.memory?.memory_content ? Number(response.memory?.memory_content) : undefined + memory_content: convertedMemoryContent } }) setData({ diff --git a/web/src/views/Workflow/constant.ts b/web/src/views/Workflow/constant.ts index e250e184..aab8be7d 100644 --- a/web/src/views/Workflow/constant.ts +++ b/web/src/views/Workflow/constant.ts @@ -200,7 +200,7 @@ export const nodeLibrary: NodeLibrary[] = [ config_id: { type: 'customSelect', url: memoryConfigListUrl, - valueKey: 'config_id', + valueKey: ['config_id_old', 'config_id'], labelKey: 'config_name' }, search_switch: { @@ -223,7 +223,7 @@ export const nodeLibrary: NodeLibrary[] = [ config_id: { type: 'customSelect', url: memoryConfigListUrl, - valueKey: 'config_id', + valueKey: ['config_id_old', 'config_id'], labelKey: 'config_name' } } diff --git a/web/src/views/Workflow/types.ts b/web/src/views/Workflow/types.ts index 909c30e4..31d1f512 100644 --- a/web/src/views/Workflow/types.ts +++ b/web/src/views/Workflow/types.ts @@ -14,7 +14,7 @@ export interface NodeConfig { url?: string; params?: { [key: string]: unknown; } - valueKey?: string; + valueKey?: string | string[]; labelKey?: string; defaultValue?: any; From f1f887faaebc3caa74feb05730a76faae6bb30f3 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Mon, 26 Jan 2026 17:29:44 +0800 Subject: [PATCH 21/28] feat(workflow): Add a new node for executing code --- api/app/core/workflow/nodes/code/__init__.py | 3 + api/app/core/workflow/nodes/code/config.py | 50 +++++++ api/app/core/workflow/nodes/code/node.py | 122 ++++++++++++++++++ api/app/core/workflow/nodes/configs.py | 14 +- api/app/core/workflow/nodes/node_factory.py | 5 +- sandbox/app/core/executor.py | 1 - .../app/core/runners/python/python_runner.py | 5 +- sandbox/app/services/python_service.py | 4 +- 8 files changed, 193 insertions(+), 11 deletions(-) 
create mode 100644 api/app/core/workflow/nodes/code/config.py create mode 100644 api/app/core/workflow/nodes/code/node.py diff --git a/api/app/core/workflow/nodes/code/__init__.py b/api/app/core/workflow/nodes/code/__init__.py index e69de29b..e42af93d 100644 --- a/api/app/core/workflow/nodes/code/__init__.py +++ b/api/app/core/workflow/nodes/code/__init__.py @@ -0,0 +1,3 @@ +from app.core.workflow.nodes.code.node import CodeNode + +__all__ = ["CodeNode"] \ No newline at end of file diff --git a/api/app/core/workflow/nodes/code/config.py b/api/app/core/workflow/nodes/code/config.py new file mode 100644 index 00000000..35b757e9 --- /dev/null +++ b/api/app/core/workflow/nodes/code/config.py @@ -0,0 +1,50 @@ +from typing import Literal +from pydantic import Field, BaseModel + +from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType + + +class InputVariable(BaseModel): + name: str = Field( + ..., + description="variable name" + ) + + variable: str = Field( + ..., + description="variable selector" + ) + + +class OutputVariable(BaseModel): + name: str = Field( + ..., + description="variable name" + ) + + type: VariableType = Field( + ..., + description="variable selector" + ) + + +class CodeNodeConfig(BaseNodeConfig): + input_variables: list[InputVariable] = Field( + default_factory=list, + description="input variables" + ) + + output_variables: list[OutputVariable] = Field( + default_factory=list, + description="output variables" + ) + + code_content: str = Field( + default="", + description="code content" + ) + + language: Literal['python3', 'nodejs'] = Field( + ..., + description="language" + ) diff --git a/api/app/core/workflow/nodes/code/node.py b/api/app/core/workflow/nodes/code/node.py new file mode 100644 index 00000000..3e15089b --- /dev/null +++ b/api/app/core/workflow/nodes/code/node.py @@ -0,0 +1,122 @@ +import base64 +import json +import logging +import re +from string import Template +from textwrap import dedent +from typing import Any + 
+import httpx +from sympy.physics.vector import vlatex + +from app.core.workflow.nodes import BaseNode, WorkflowState +from app.core.workflow.nodes.base_config import VariableType +from app.core.workflow.nodes.code.config import CodeNodeConfig + +logger = logging.getLogger(__name__) + +SCRIPT_TEMPLATE = Template(dedent(""" +$code + +import json +from base64 import b64decode + +# decode and prepare input dict +inputs_obj = json.loads(b64decode('$inputs_variable').decode('utf-8')) + +# execute main function +output_obj = main(**inputs_obj) + +# convert output to json and print +output_json = json.dumps(output_obj, indent=4) +result = "<>" + output_json + "<>" +print(result) +""")) + + +class CodeNode(BaseNode): + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + super().__init__(node_config, workflow_config) + self.typed_config: CodeNodeConfig | None = None + + def extract_result(self, content: str): + match = re.search(r'<>(.*?)<>', content, re.DOTALL) + if match: + extracted = match.group(1) + exec_result = json.loads(extracted) + result = {} + for output in self.typed_config.output_variables: + value = exec_result.get(output.name) + if not value: + raise RuntimeError(f"Return value {output.name} does not exist") + match output.type: + case VariableType.STRING: + if not isinstance(value, str): + raise RuntimeError(f"Return value {output.name} should be a string") + case VariableType.BOOLEAN: + if not isinstance(value, bool): + raise RuntimeError(f"Return value {output.name} should be a boolean") + case VariableType.NUMBER: + if not isinstance(value, (int, float)): + raise RuntimeError(f"Return value {output.name} should be a number") + case VariableType.OBJECT: + if not isinstance(value, dict): + raise RuntimeError(f"Return value {output.name} should be a dictionary") + case VariableType.ARRAY_STRING: + if not isinstance(value, list) or not all(isinstance(v, str) for v in value): + raise RuntimeError(f"Return value {output.name} 
should be a list of strings") + case VariableType.ARRAY_NUMBER: + if not isinstance(value, list) or not all(isinstance(v, (int, float)) for v in value): + raise RuntimeError(f"Return value {output.name} should be a list of numbers") + case VariableType.ARRAY_OBJECT: + if not isinstance(value, list) or not all(isinstance(v, dict) for v in value): + raise RuntimeError(f"Return value {output.name} should be a list of dictionaries") + case VariableType.ARRAY_BOOLEAN: + if not isinstance(value, list) or not all(isinstance(v, bool) for v in value): + raise RuntimeError(f"Return value {output.name} should be a list of booleans") + result[output.name] = value + return result + else: + raise RuntimeError("The output of main must be a dictionary") + + async def execute(self, state: WorkflowState) -> Any: + self.typed_config = CodeNodeConfig(**self.config) + input_variable_dict = {} + for input_variable in self.typed_config.input_variables: + input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, state) + code = base64.b64decode( + self.typed_config.code + ).decode("utf-8") + + input_variable_dict = base64.b64encode( + json.dumps(input_variable_dict).encode("utf-8") + ).decode("utf-8") + + final_script = SCRIPT_TEMPLATE.substitute( + code=code, + inputs_variable=input_variable_dict, + ) + + async with httpx.AsyncClient() as client: + response = await client.post( + "http://sandbox:8194/v1/sandbox/run", + headers={ + "x-api-key": 'redbear-sandbox' + }, + json={ + "language": "python3", + "code": base64.b64encode(final_script.encode("utf-8")).decode("utf-8"), + "options": { + "enable_network": True + } + } + ) + resp = response.json() + + match resp['code']: + case 31: + raise RuntimeError("Operation not permitted") + case 0: + return self.extract_result(resp["data"]["stdout"]) + case _: + raise Exception(resp["message"]) diff --git a/api/app/core/workflow/nodes/configs.py b/api/app/core/workflow/nodes/configs.py index 4d31efaa..d73754f6 100644 
--- a/api/app/core/workflow/nodes/configs.py +++ b/api/app/core/workflow/nodes/configs.py @@ -10,21 +10,22 @@ from app.core.workflow.nodes.base_config import ( VariableDefinition, VariableType, ) +from app.core.workflow.nodes.code.config import CodeNodeConfig +from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig from app.core.workflow.nodes.end.config import EndNodeConfig from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig from app.core.workflow.nodes.if_else.config import IfElseNodeConfig from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig -from app.core.workflow.nodes.start.config import StartNodeConfig -from app.core.workflow.nodes.transform.config import TransformNodeConfig -from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig +from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig +from app.core.workflow.nodes.start.config import StartNodeConfig from app.core.workflow.nodes.tool.config import ToolNodeConfig -from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig +from app.core.workflow.nodes.transform.config import TransformNodeConfig +from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig -from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig __all__ = [ # 基础类 "BaseNodeConfig", @@ -49,5 +50,6 @@ __all__ = [ "QuestionClassifierNodeConfig", "ToolNodeConfig", "MemoryReadNodeConfig", - "MemoryWriteNodeConfig" + "MemoryWriteNodeConfig", + 
"CodeNodeConfig" ] diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index 9fca8d7a..fb2fe00f 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -10,6 +10,7 @@ from typing import Any, Union from app.core.workflow.nodes.agent import AgentNode from app.core.workflow.nodes.assigner import AssignerNode from app.core.workflow.nodes.base_node import BaseNode +from app.core.workflow.nodes.code import CodeNode from app.core.workflow.nodes.cycle_graph.node import CycleGraphNode from app.core.workflow.nodes.end import EndNode from app.core.workflow.nodes.enums import NodeType @@ -49,7 +50,8 @@ WorkflowNode = Union[ QuestionClassifierNode, ToolNode, MemoryReadNode, - MemoryWriteNode + MemoryWriteNode, + CodeNode ] @@ -81,6 +83,7 @@ class NodeFactory: NodeType.TOOL: ToolNode, NodeType.MEMORY_READ: MemoryReadNode, NodeType.MEMORY_WRITE: MemoryWriteNode, + NodeType.CODE: CodeNode, } @classmethod diff --git a/sandbox/app/core/executor.py b/sandbox/app/core/executor.py index 6edc48c0..e87b510c 100644 --- a/sandbox/app/core/executor.py +++ b/sandbox/app/core/executor.py @@ -15,7 +15,6 @@ class ExecutionResult: self.stdout = stdout self.stderr = stderr self.exit_code = exit_code - self.error = error class CodeExecutor(ABC): diff --git a/sandbox/app/core/runners/python/python_runner.py b/sandbox/app/core/runners/python/python_runner.py index faac5f0c..30792b91 100644 --- a/sandbox/app/core/runners/python/python_runner.py +++ b/sandbox/app/core/runners/python/python_runner.py @@ -9,12 +9,15 @@ from app.config import SANDBOX_USER_ID, SANDBOX_GROUP_ID, get_config from app.core.encryption import generate_key, encrypt_code from app.core.executor import CodeExecutor, ExecutionResult from app.core.runners.python.settings import check_lib_avaiable, release_lib_binary, LIB_PATH +from app.logger import get_logger from app.models import RunnerOptions # Python sandbox prescript 
template with open("app/core/runners/python/prescript.py") as f: PYTHON_PRESCRIPT = f.read() +logger = get_logger() + class PythonRunner(CodeExecutor): """Python code runner with security isolation""" @@ -106,6 +109,7 @@ class PythonRunner(CodeExecutor): env["ALLOWED_SYSCALLS"] = ",".join(map(str, config.allowed_syscalls)) # Execute with Python interpreter + logger.info(encoded_key) process = await asyncio.create_subprocess_exec( config.python_path, @@ -143,7 +147,6 @@ class PythonRunner(CodeExecutor): stdout="", stderr="Execution timeout", exit_code=-1, - error="Execution timeout" ) finally: diff --git a/sandbox/app/services/python_service.py b/sandbox/app/services/python_service.py index 71cfda0d..5700841d 100644 --- a/sandbox/app/services/python_service.py +++ b/sandbox/app/services/python_service.py @@ -37,8 +37,8 @@ async def run_python_code(code: str, preload: str, options: RunnerOptions): if result.exit_code == -signal.SIGSYS: return error_response(31, "sandbox security policy violation") - if result.error: - return error_response(-500, result.error) + if result.stderr: + return error_response(500, result.stderr) return success_response(RunCodeResponse( stdout=result.stdout, From f76bffb4823252ed482867c3fe112a1cf09f5a16 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 26 Jan 2026 18:32:18 +0800 Subject: [PATCH 22/28] fix(web): KnowledgeConfigModal bugfix --- .../components/Knowledge/KnowledgeConfigModal.tsx | 2 +- .../Properties/Knowledge/KnowledgeConfigModal.tsx | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/web/src/views/ApplicationConfig/components/Knowledge/KnowledgeConfigModal.tsx b/web/src/views/ApplicationConfig/components/Knowledge/KnowledgeConfigModal.tsx index abf56b18..70b17a11 100644 --- a/web/src/views/ApplicationConfig/components/Knowledge/KnowledgeConfigModal.tsx +++ b/web/src/views/ApplicationConfig/components/Knowledge/KnowledgeConfigModal.tsx @@ -66,7 +66,7 @@ const KnowledgeConfigModal = forwardRef { if 
(values?.retrieve_type) { const fieldsToReset = Object.keys(values).filter(key => - key !== 'kb_id' && key !== 'retrieve_type' + key !== 'kb_id' && key !== 'retrieve_type' && key !== 'top_k' ) as (keyof KnowledgeConfigForm)[]; form.resetFields(fieldsToReset); } diff --git a/web/src/views/Workflow/components/Properties/Knowledge/KnowledgeConfigModal.tsx b/web/src/views/Workflow/components/Properties/Knowledge/KnowledgeConfigModal.tsx index 77ca21a2..196ce8e3 100644 --- a/web/src/views/Workflow/components/Properties/Knowledge/KnowledgeConfigModal.tsx +++ b/web/src/views/Workflow/components/Properties/Knowledge/KnowledgeConfigModal.tsx @@ -66,7 +66,7 @@ const KnowledgeConfigModal = forwardRef { if (values?.retrieve_type) { const fieldsToReset = Object.keys(values).filter(key => - key !== 'kb_id' && key !== 'retrieve_type' + key !== 'kb_id' && key !== 'retrieve_type' && key !== 'top_k' ) as (keyof KnowledgeConfigForm)[]; form.resetFields(fieldsToReset); } @@ -108,6 +108,7 @@ const KnowledgeConfigModal = forwardRef {/* Top K */} @@ -116,13 +117,12 @@ const KnowledgeConfigModal = forwardRef form.setFieldValue('top_k', value)} + // onChange={(value) => form.setFieldValue('top_k', value)} /> {/* 语义相似度阈值 similarity_threshold */} From 5267bd60a566893d9269a20c4a073642b479fa33 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 26 Jan 2026 18:40:28 +0800 Subject: [PATCH 23/28] fix(web): iteration's variable add parameter-extractor node --- web/src/views/Workflow/constant.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/views/Workflow/constant.ts b/web/src/views/Workflow/constant.ts index e250e184..7b15c049 100644 --- a/web/src/views/Workflow/constant.ts +++ b/web/src/views/Workflow/constant.ts @@ -284,7 +284,7 @@ export const nodeLibrary: NodeLibrary[] = [ config: { input: { type: 'variableList', - filterNodeTypes: ['knowledge-retrieval', 'iteration', 'loop'], + filterNodeTypes: ['knowledge-retrieval', 'iteration', 'loop', 'parameter-extractor'], 
filterVariableNames: ['message'] }, parallel: { From 1f615a06add14d193f7d2840bddefb53712ced32 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Mon, 26 Jan 2026 18:50:22 +0800 Subject: [PATCH 24/28] fix(sandbox): treat non-zero exit codes as errors instead of relying only on stderr --- api/app/core/workflow/nodes/code/config.py | 2 +- api/app/core/workflow/nodes/code/node.py | 4 ++-- sandbox/app/services/python_service.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api/app/core/workflow/nodes/code/config.py b/api/app/core/workflow/nodes/code/config.py index 35b757e9..8af13f12 100644 --- a/api/app/core/workflow/nodes/code/config.py +++ b/api/app/core/workflow/nodes/code/config.py @@ -39,7 +39,7 @@ class CodeNodeConfig(BaseNodeConfig): description="output variables" ) - code_content: str = Field( + code: str = Field( default="", description="code content" ) diff --git a/api/app/core/workflow/nodes/code/node.py b/api/app/core/workflow/nodes/code/node.py index 3e15089b..5262a7e2 100644 --- a/api/app/core/workflow/nodes/code/node.py +++ b/api/app/core/workflow/nodes/code/node.py @@ -47,7 +47,7 @@ class CodeNode(BaseNode): result = {} for output in self.typed_config.output_variables: value = exec_result.get(output.name) - if not value: + if value is None: raise RuntimeError(f"Return value {output.name} does not exist") match output.type: case VariableType.STRING: @@ -104,7 +104,7 @@ class CodeNode(BaseNode): "x-api-key": 'redbear-sandbox' }, json={ - "language": "python3", + "language": self.typed_config.language, "code": base64.b64encode(final_script.encode("utf-8")).decode("utf-8"), "options": { "enable_network": True diff --git a/sandbox/app/services/python_service.py b/sandbox/app/services/python_service.py index 5700841d..210b2086 100644 --- a/sandbox/app/services/python_service.py +++ b/sandbox/app/services/python_service.py @@ -37,7 +37,7 @@ async def run_python_code(code: str, preload: str, options: RunnerOptions): if 
result.exit_code == -signal.SIGSYS: return error_response(31, "sandbox security policy violation") - if result.stderr: + if result.stderr and result.exit_code != 0: return error_response(500, result.stderr) return success_response(RunCodeResponse( From a5b8d3afa5ef19723ae5cce57f3fbbee70ff51f8 Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Mon, 26 Jan 2026 19:05:07 +0800 Subject: [PATCH 25/28] Fix/memory bug fix (#200) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 图谱数据量限制数量去掉 * 图谱数据量限制数量去掉 * 图谱数据量限制数量去掉 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 读取的接口,去掉全局锁 * 输出数组 * 反思优化1.0(优化隐私输出、时间检索) * 反思优化1.0(优化隐私输出、时间检索) * 反思优化1.0(优化隐私输出、时间检索) * 反思优化测试接口 * 反思优化测试接口 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复 * 读取接口内层嵌套BUG修复 * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察) * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 * 新增中翻英功能(记忆时间线)(用户摘要)(兴趣分布接口)(查询核心档案)(记忆洞察)-接口添加翻译字段 * 把group_id替换end_user_id * 把group_id替换end_user_id_ * 把group_id替换end_user_id_ * config_config替换成memory_config * config_config替换成memory_config * [fix]Fix the memory interface to use end_user_id. 
* config_config替换成memory_config * config_config替换成memory_config * config_config替换成memory_config * config_id字段改成UUID * config_id字段改成UUID * config_id字段改成UUID * config_id字段改成UUID,与develop校对恢复 * 检查项目,修复group_id的遗留问题 * 检查项目,修复group_id的遗留问题 * 解决冲突 * 解决冲突 * end_user_id清理干净 * end_user_id清理干净 * 修复遗留合并BUG * 修复遗留合并BUG * 修复遗留合并BUG * 修复遗留合并BUG * 感知meta_data字段BUG修复 * user_id->现实为config_id_old * user_id->显示为config_id_old传输 * user_id->显示为config_id_old传输 * user_id->显示为config_id_old传输 --------- Co-authored-by: lanceyq <1982376970@qq.com> --- api/app/services/memory_storage_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 1707f8fa..0ede7bd3 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -188,7 +188,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) "config_desc": config.config_desc, "workspace_id": str(config.workspace_id) if config.workspace_id else None, "end_user_id": config.end_user_id, - "config_id_old": config.config_id_old, + "config_id_old": int(config.user_id), "apply_id": config.apply_id, "llm_id": config.llm_id, "embedding_id": config.embedding_id, From 80ca247435fe9d79a0c0c71fd4b0113284eaf359 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= <162269739+lanceyq@users.noreply.github.com> Date: Mon, 26 Jan 2026 19:05:20 +0800 Subject: [PATCH 26/28] Refactor/benchmark test (#196) * [changes]refactor locomo_test * [fix]Fix the circular import of ModelParameters * [changes]The benchmark test can run stably. * [fix]Complete end-to-end LoCoMo repair * [fix]Complete the end-to-end longmemeval and memsciqa fixes * [changes]Complete the benchmark test description document to ensure that the configuration parameters take effect. * [changes]refactor locomo_test * [fix]Fix the circular import of ModelParameters * [changes]The benchmark test can run stably. 
* [fix]Complete end-to-end LoCoMo repair * [fix]Complete the end-to-end longmemeval and memsciqa fixes * [changes]Complete the benchmark test description document to ensure that the configuration parameters take effect. * [changes]Benchmark test adaptation for end_user_id * [changes]refactor locomo_test * [fix]Fix the circular import of ModelParameters * [changes]The benchmark test can run stably. * [fix]Complete end-to-end LoCoMo repair * [fix]Complete the end-to-end longmemeval and memsciqa fixes * [changes]Complete the benchmark test description document to ensure that the configuration parameters take effect. * [fix]Complete the end-to-end longmemeval and memsciqa fixes * [changes]Complete the benchmark test description document to ensure that the configuration parameters take effect. * [changes]Benchmark test adaptation for end_user_id --- .../memory/evaluation/.env.evaluation.example | 224 +++++ api/app/core/memory/evaluation/.gitignore | 13 + api/app/core/memory/evaluation/benchmark.md | 772 +++++++++++++++- .../memory/evaluation/check_enduser_data.py | 371 ++++++++ .../core/memory/evaluation/common/metrics.py | 2 +- .../memory/evaluation/dialogue_queries.py | 6 +- .../memory/evaluation/extraction_utils.py | 369 +++++--- .../evaluation/locomo/locomo_benchmark.py | 863 +++++++++++------- .../memory/evaluation/locomo/locomo_test.py | 293 +++--- .../memory/evaluation/locomo/locomo_utils.py | 69 +- .../evaluation/locomo/qwen_search_eval.py | 96 +- ...earch_eval.py => longmemeval_benchmark.py} | 181 ++-- .../evaluation/longmemeval/test_eval.py | 132 ++- .../evaluation/memsciqa/memsciqa-test.py | 224 +++-- .../{evaluate_qa.py => memsciqa_benchmark.py} | 123 ++- api/app/core/memory/evaluation/run_eval.py | 21 +- .../extraction_orchestrator.py | 22 +- api/app/models/agent_app_config_model.py | 2 +- api/app/models/multi_agent_model.py | 2 +- api/app/schemas/multi_agent_schema.py | 2 +- api/app/services/master_agent_router.py | 2 +- api/app/utils/app_config_utils.py | 
2 +- 22 files changed, 2760 insertions(+), 1031 deletions(-) create mode 100644 api/app/core/memory/evaluation/.env.evaluation.example create mode 100644 api/app/core/memory/evaluation/.gitignore create mode 100644 api/app/core/memory/evaluation/check_enduser_data.py rename api/app/core/memory/evaluation/longmemeval/{qwen_search_eval.py => longmemeval_benchmark.py} (93%) rename api/app/core/memory/evaluation/memsciqa/{evaluate_qa.py => memsciqa_benchmark.py} (76%) diff --git a/api/app/core/memory/evaluation/.env.evaluation.example b/api/app/core/memory/evaluation/.env.evaluation.example new file mode 100644 index 00000000..be089eb4 --- /dev/null +++ b/api/app/core/memory/evaluation/.env.evaluation.example @@ -0,0 +1,224 @@ +# ============================================================================ +# 基准测试统一配置文件示例 +# ============================================================================ +# 复制此文件为 .env.evaluation 并根据需要修改 +# 支持的基准测试:LoCoMo、LongMemEval、MemSciQA +# ============================================================================ + +# ============================================================================ +# 通用配置(所有基准测试共用) +# ============================================================================ + +# ---------------------------------------------------------------------------- +# Neo4j 配置 +# ---------------------------------------------------------------------------- +# 默认 Group ID(建议各基准测试使用独立的 group) +EVAL_GROUP_ID=benchmark_default + +# ---------------------------------------------------------------------------- +# 模型配置(必需) +# ---------------------------------------------------------------------------- +# ⚠️ 必填:从数据库 models 表中选择有效的模型 ID +# +# 如何获取模型 ID: +# 1. 查询数据库:SELECT id, model_name FROM models WHERE is_active = true; +# 2. 或通过系统管理界面查看 +# 3. 
确保模型可用且配置正确 + +# LLM 模型 ID(必填) +EVAL_LLM_ID=your_llm_model_id_here + +# Embedding 模型 ID(必填) +EVAL_EMBEDDING_ID=your_embedding_model_id_here + +# ---------------------------------------------------------------------------- +# 检索参数 +# ---------------------------------------------------------------------------- +# 检索类型: "keyword", "embedding", "hybrid" +EVAL_SEARCH_TYPE=hybrid + +# 检索结果数量限制(默认值) +EVAL_SEARCH_LIMIT=12 + +# 上下文最大字符数(默认值) +EVAL_MAX_CONTEXT_CHARS=8000 + +# ---------------------------------------------------------------------------- +# LLM 参数 +# ---------------------------------------------------------------------------- +# LLM 温度参数(0.0 = 确定性输出) +EVAL_LLM_TEMPERATURE=0.0 + +# LLM 最大生成 token 数 +EVAL_LLM_MAX_TOKENS=32 + +# LLM 超时时间(秒) +EVAL_LLM_TIMEOUT=10.0 + +# LLM 最大重试次数 +EVAL_LLM_MAX_RETRIES=1 + +# ---------------------------------------------------------------------------- +# 数据处理参数 +# ---------------------------------------------------------------------------- +# Chunker 策略 +EVAL_CHUNKER_STRATEGY=RecursiveChunker + +# 是否在导入前清空现有数据 +EVAL_RESET_ON_INGEST=true + +# 是否保存详细日志 +EVAL_SAVE_DETAILED_LOGS=true + +# ============================================================================ +# LoCoMo 基准测试专用配置 +# ============================================================================ +# 数据集:locomo10.json +# 运行:python locomo_benchmark.py --sample_size 20 +# ---------------------------------------------------------------------------- + +# Group ID(LoCoMo 专用) +LOCOMO_GROUP_ID=locomo_benchmark + +# 测试样本数量 +# 建议值:20(快速测试)、100(中等测试)、1986(完整测试) +LOCOMO_SAMPLE_SIZE=20 + +# 检索结果数量限制 +LOCOMO_SEARCH_LIMIT=12 + +# 上下文最大字符数 +LOCOMO_CONTEXT_CHAR_BUDGET=8000 + +# 导入的对话数量 +LOCOMO_MAX_DIALOGUES=1 + +# 跳过数据摄入(true=跳过,false=摄入) +# 首次运行设置为 false,后续运行可设置为 true 以节省时间 +LOCOMO_SKIP_INGEST=false + +# 结果保存目录 +LOCOMO_OUTPUT_DIR=locomo/results + +# ============================================================================ +# LongMemEval 基准测试专用配置 +# 
============================================================================ +# 数据集:longmemeval_oracle_zh.json +# 运行:python longmemeval_benchmark.py --sample_size 3 +# 特点:支持时间推理问题的增强检索 +# ---------------------------------------------------------------------------- + +# Group ID(LongMemEval 专用) +LONGMEMEVAL_GROUP_ID=longmemeval_zh_bak_3 + +# 测试样本数量(<=0 表示全部样本) +LONGMEMEVAL_SAMPLE_SIZE=3 + +# 起始样本索引 +LONGMEMEVAL_START_INDEX=0 + +# 检索结果数量限制 +LONGMEMEVAL_SEARCH_LIMIT=8 + +# 上下文最大字符数 +LONGMEMEVAL_CONTEXT_CHAR_BUDGET=4000 + +# LLM 最大生成 token 数 +LONGMEMEVAL_LLM_MAX_TOKENS=16 + +# 每条样本最多摄入的上下文段数 +LONGMEMEVAL_MAX_CONTEXTS_PER_ITEM=2 + +# 是否保存分块结果 +LONGMEMEVAL_SAVE_CHUNK_OUTPUT=true + +# 自定义分块输出路径(留空使用默认) +LONGMEMEVAL_SAVE_CHUNK_OUTPUT_PATH= + +# 摄入前是否清空组数据 +LONGMEMEVAL_RESET_GROUP_BEFORE_INGEST=false + +# 是否跳过摄入,仅检索评估 +LONGMEMEVAL_SKIP_INGEST=false + +# 结果保存目录 +LONGMEMEVAL_OUTPUT_DIR=longmemeval/results + +# ============================================================================ +# MemSciQA 基准测试专用配置 +# ============================================================================ +# 数据集:msc_self_instruct.jsonl +# 运行:python memsciqa_benchmark.py --sample_size 1 +# 特点:对话记忆检索评估 +# ---------------------------------------------------------------------------- + +# Group ID(MemSciQA 专用,独立数据集) +MEMSCIQA_GROUP_ID=memsciqa_benchmark + +# 测试样本数量 +MEMSCIQA_SAMPLE_SIZE=1 # 0或者-1标识测试数据集中的所有样本 + +# 检索结果数量限制 +MEMSCIQA_SEARCH_LIMIT=8 + +# 上下文最大字符数 +MEMSCIQA_CONTEXT_CHAR_BUDGET=4000 + +# LLM 最大生成 token 数 +MEMSCIQA_LLM_MAX_TOKENS=64 + +# 跳过数据摄入(true=跳过,false=摄入) +# 首次运行设置为 false,后续运行可设置为 true 以节省时间 +MEMSCIQA_SKIP_INGEST=false + +# 结果保存目录(相对于 memsciqa 脚本所在目录) +# 使用 "results" 会保存到 api/app/core/memory/evaluation/memsciqa/results/ +MEMSCIQA_OUTPUT_DIR=results + +# ============================================================================ +# 高级配置(可选) +# ============================================================================ + +# BM25 权重(用于混合检索,0.0-1.0) +EVAL_RERANK_ALPHA=0.6 + +# 
是否使用遗忘重排序 +EVAL_USE_FORGETTING_RERANK=false + +# 是否使用 LLM 重排序 +EVAL_USE_LLM_RERANK=false + +# 连接重置间隔(每 N 个问题重置一次) +EVAL_RESET_INTERVAL=5 + +# 性能阈值(低于此值触发重置) +EVAL_PERFORMANCE_THRESHOLD=0.6 + +# ============================================================================ +# 快速配置指南 +# ============================================================================ +# 1. 复制此文件为 .env.evaluation +# 2. 修改 EVAL_LLM_ID 和 EVAL_EMBEDDING_ID 为你的模型 ID +# 3. 根据需要修改各基准测试的专用配置 +# 4. 运行测试: +# - LoCoMo: python locomo/locomo_benchmark.py --sample_size 20 +# - LongMemEval: python longmemeval/longmemeval_benchmark.py --sample_size 3 --all +# - MemSciQA: python memsciqa/memsciqa_benchmark.py --sample_size 10 +# 配置优先级: +# 命令行参数 > 特定配置(如 LOCOMO_*)> 通用配置(EVAL_*)> 代码默认值 +# ============================================================================ + + +# 执行LoCoMo测试 +# 只摄入前5条消息,评估3个问题(最小测试) +# python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 3 --max_ingest_messages 5 +# +# 如果数据已经摄入,跳过摄入阶段直接测试 +# python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 5 --skip_ingest + + +# 执行longmemeval测试 +# python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark --sample-size 10 --max-contexts-per-item 3 --reset-group-before-ingest + +# 执行memsciqa测试 +# python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark --sample-size 1 diff --git a/api/app/core/memory/evaluation/.gitignore b/api/app/core/memory/evaluation/.gitignore new file mode 100644 index 00000000..38b1055a --- /dev/null +++ b/api/app/core/memory/evaluation/.gitignore @@ -0,0 +1,13 @@ +# 忽略实际的评估配置文件(包含敏感信息) +.env.evaluation + +# 保留示例文件 +!.env.evaluation.example + +# 忽略测试结果文件 +*/results/*.json +*/results/*.log + +# 忽略数据集文件(文件过大,不应提交到 Git) +dataset/*.json +dataset/*.jsonl diff --git a/api/app/core/memory/evaluation/benchmark.md b/api/app/core/memory/evaluation/benchmark.md index 2853b22b..7c31cccd 100644 --- a/api/app/core/memory/evaluation/benchmark.md +++ 
b/api/app/core/memory/evaluation/benchmark.md @@ -1,30 +1,748 @@ -⏬数据集下载地址: - Locomo10.json:https://github.com/snap-research/locomo/tree/main/data - LongMemEval_oracle.json:https://huggingface.co/datasets/xiaowu0162/longmemeval-cleaned - msc_self_instruct.jsonl:https://huggingface.co/datasets/MemGPT/MSC-Self-Instruct - 上方数据集下载好后全部放入app/core/memory/data文件夹中 +# 1.数据集下载地址 +Locomo10.json : https://github.com/snap-research/locomo/tree/main/data +LongMemEval_oracle.json : https://huggingface.co/datasets/xiaowu0162/longmemeval-cleaned +msc_self_instruct.jsonl : https://huggingface.co/datasets/MemGPT/MSC-Self-Instruct -全流程基准测试运行: - locomo: - python -m app.core.memory.evaluation.run_eval --dataset locomo --sample-size 1 --reset-group --group-id yyw1 --search-type hybrid --search-limit 8 --context-char-budget 12000 --llm-max-tokens 32 - LongMemEval: - python -m app.core.memory.evaluation.run_eval --dataset longmemeval --sample-size 10 --start-index 0 --group-id longmemeval_zh_bak_2 --search-limit 8 --context-char-budget 4000 --search-type hybrid --max-contexts-per-item 2 --reset-group - memsciqa: - python -m app.core.memory.evaluation.run_eval --dataset memsciqa --sample-size 10 --reset-group --group-id group_memsci +数据集下载之后保存至api\app\core\memory\evaluation\dataset目录下 +# 2.配置说明 +文件api\app\core\memory\evaluation\.env.evaluation.example对三个基准测试所需配置有着详细的说明 +**实际配置文件**:api\app\core\memory\evaluation\.env.evaluation +```python +# 当使用不带配置参数的命令行执行基准测试,基准测试所需的配置参数根据.env.evaluation中的参数执行 +python -m app.core.memory.evaluation.locomo.locomo_benchmark +``` +**检查neo4j指定的grou_id是否摄入数据** +```python +# 1. 进入交互模式 +python -m app.core.memory.evaluation.check_enduser_data -单独检索评估运行命令: - python -m app.core.memory.evaluation.locomo.locomo_test - python -m app.core.memory.evaluation.longmemeval.test_eval - python -m app.core.memory.evaluation.memsciqa.memsciqa-test - 需要先在项目中修改需要检测评估的group_id。 +# 2. 选择 "1" 检查指定 group +# 3. 输入 group_id,例如: locomo_benchmark +# 4. 
选择是否显示详细统计 (y/n) +``` +# 3.locomo -参数及解释: - ● --dataset longmemeval - 指定数据集 - ● --sample-size 10 - 评估10个样本 - ● --start-index 0 - 从第0个样本开始 - ● --group-id longmemeval_zh_bak_2 - 使用指定的组ID - ● --search-limit 8 - 检索限制8条 - ● --context-char-budget 4000 - 上下文字符预算4000 - ● --search-type hybrid - 使用混合检索 - ● --max-contexts-per-item 2 - 每个样本最多摄入2个上下文 - ● --reset-group - 运行前清空组数据 \ No newline at end of file +### (1)locomo执行命令 +```python +# 首先进入api目录 +cd api + +# 只摄入前5条消息,评估3个问题(最小测试) +python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 3 --max_ingest_messages 5 + +# 如果数据已经摄入,跳过摄入阶段直接测试(使用skip_ingest参数) +python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 5 --skip_ingest +``` +### (2)locomo结果说明 + +#### 结果示例 +```json +{ + "dataset": "locomo", + "sample_size": 0, + "timestamp": "2026-01-26T11:24:28.239156", + "params": { + "group_id": "locomo_benchmark", + "search_type": "hybrid", + "search_limit": 12, + "context_char_budget": 8000, + "llm_id": "2c9b0782-7a85-4740-ba84-4baf77f256c4", + "embedding_id": "e2a6392d-ca63-4d59-a523-647420b59cb2" + }, + "overall_metrics": { + "f1": 0.0, + "bleu1": 0.0, + "jaccard": 0.0, + "locomo_f1": 0.0 + }, + "by_category": {}, + "latency": { + "search": { + "mean": 0.0, + "p50": 0.0, + "p95": 0.0, + "iqr": 0.0 + }, + "llm": { + "mean": 0.0, + "p50": 0.0, + "p95": 0.0, + "iqr": 0.0 + } + }, + "context_stats": { + "avg_retrieved_docs": 0.0, + "avg_context_chars": 0.0, + "avg_context_tokens": 0.0 + }, + "samples": [] +} +``` + +#### 参数详解 + +##### 1. 
核心评估指标 (overall_metrics) + +**🎯 关键进步指标:** + +- **`f1`** (F1 Score): 精确率和召回率的调和平均值 + - 范围:0.0 - 1.0 + - **越高越好**,衡量检索和生成答案的准确性 + - 这是最重要的综合性能指标 + - 优秀标准:> 0.85 + +- **`bleu1`** (BLEU-1): 单词级别的匹配度 + - 范围:0.0 - 1.0 + - **越高越好**,衡量生成答案与标准答案的词汇重叠度 + - 关注词汇层面的准确性 + +- **`jaccard`** (Jaccard 相似度): 集合相似度 + - 范围:0.0 - 1.0 + - **越高越好**,衡量答案集合的相似性 + - 计算公式:交集大小 / 并集大小 + +- **`locomo_f1`**: Locomo 特定的 F1 分数 + - 范围:0.0 - 1.0 + - **越高越好**,针对 Locomo 数据集优化的评估指标 + - 考虑了长对话记忆的特殊性 + +##### 2. 性能指标 (latency) + +**⚡ 关键效率指标:** + +- **`search`**: 检索延迟统计(单位:毫秒) + - `mean`: 平均延迟 + - `p50`: 中位数延迟(50%的请求在此时间内完成) + - `p95`: 95分位数延迟(95%的请求在此时间内完成) + - `iqr`: 四分位距(Q3-Q1,衡量稳定性) + - **越低越好**,衡量记忆检索速度 + - 优秀标准:p95 < 2000ms + +- **`llm`**: LLM 推理延迟统计(单位:毫秒) + - `mean`: 平均推理时间 + - `p50`: 中位数推理时间 + - `p95`: 95分位数推理时间 + - `iqr`: 四分位距(越小越稳定) + - **越低越好**,衡量答案生成速度 + - 优秀标准:p95 < 3000ms + +##### 3. 上下文统计 (context_stats) + +**📊 资源效率指标:** + +- **`avg_retrieved_docs`**: 平均检索文档数 + - 反映检索策略的广度 + - 需要平衡:太少可能信息不足,太多增加噪音和延迟 + - 建议范围:8-15 个文档 + +- **`avg_context_chars`**: 平均上下文字符数 + - 反映检索内容的总量 + - 应在满足准确性前提下尽量精简 + - 受 `context_char_budget` 参数限制 + +- **`avg_context_tokens`**: 平均上下文 token 数 + - **越低越好**(在保持准确性前提下) + - 直接影响 API 调用成本和推理速度 + - 成本效益比 = f1 / avg_context_tokens + +##### 4. 
分类统计 (by_category) + +- 按问题类型分类的性能指标 +- 帮助识别系统在不同场景下的强弱项 +- 可针对性优化特定类型的问题 + +#### 系统进步衡量标准 + +**一级指标(最重要):** +- `f1` 和 `locomo_f1` 提升 → 核心能力提升 +- 目标:f1 > 0.85 + +**二级指标(重要):** +- `latency.p95` 降低 → 用户体验提升 +- 目标:search.p95 < 2000ms, llm.p95 < 3000ms + +**三级指标(辅助):** +- `avg_context_tokens` 降低(在保持 f1 前提下)→ 成本优化 +- `iqr` 降低 → 性能稳定性提升 +# 4.longmemeval +支持时间推理问题的增强检索 +### (1)执行命令 +```python +# 首先进入api目录 +cd api + +# 不带参数运行 - 使用环境变量 +python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark + +# 命令行参数覆盖环境变量 +python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark --sample-size 2 + +# 如果数据已经摄入,跳过摄入阶段直接测试(使用skip_ingest参数) +python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark --skip_ingest +``` +### (2)结果说明 + +#### 结果示例 +```json +{ + "dataset": "longmemeval", + "items": 1, + "accuracy_by_type": { + "single-session-user": 1.0 + }, + "f1_by_type": { + "single-session-user": 1.0 + }, + "jaccard_by_type": { + "single-session-user": 1.0 + }, + "samples": [ + { + "question": "What degree did I graduate with?", + "prediction": "Business Administration", + "answer": "Business Administration", + "question_type": "single-session-user", + "is_temporal": false, + "question_id": "e47becba", + "options": [], + "context_count": 13, + "context_chars": 1268, + "retrieved_dialogue_count": 0, + "retrieved_statement_count": 12, + "metrics": { + "exact_match": true, + "f1": 1.0, + "jaccard": 1.0 + }, + "timing": { + "search_ms": 1483.100175857544, + "llm_ms": 995.8682060241699 + } + } + ], + "latency": { + "search": { + "mean": 1483.100175857544, + "p50": 1483.100175857544, + "p95": 1483.100175857544, + "iqr": 0.0 + }, + "llm": { + "mean": 995.8682060241699, + "p50": 995.8682060241699, + "p95": 995.8682060241699, + "iqr": 0.0 + } + }, + "context": { + "avg_tokens": 204.0, + "avg_chars": 1268, + "count_avg": 13 + }, + "params": { + "group_id": "longmemeval_zh_bak_3", + "search_limit": 8, + "context_char_budget": 4000, + "search_type": "hybrid", + 
"llm_id": "6dc52e1b-9cec-4194-af66-a74c6307fc3f", + "embedding_id": "e2a6392d-ca63-4d59-a523-647420b59cb2", + "sample_size": 1, + "start_index": 0 + }, + "timestamp": "2026-01-24T21:36:10.818308", + "metric_summary": { + "score_accuracy": 100.0, + "latency_median_s": 2.478968381881714, + "latency_iqr_s": 0.0, + "avg_context_tokens_k": 0.204 + }, + "diagnostics": { + "duplicate_previews_top": [], + "unique_preview_count": 1 + } +} +``` + +#### 参数详解 + +##### 1. 核心评估指标 + +**🎯 关键进步指标:** + +- **`accuracy_by_type`**: 按问题类型分类的准确率 + - 范围:0.0 - 1.0 + - **越高越好**,1.0 表示 100% 准确 + - 问题类型包括: + - `single-session-user`: 单会话用户信息 + - `single-session-event`: 单会话事件信息 + - `multi-session-user`: 多会话用户信息 + - `multi-session-event`: 多会话事件信息 + - 可以识别系统在不同场景下的强弱项 + +- **`f1_by_type`**: 按问题类型的 F1 分数 + - 范围:0.0 - 1.0 + - **越高越好**,综合评估精确率和召回率 + - 比单纯的准确率更全面 + +- **`jaccard_by_type`**: 按问题类型的 Jaccard 相似度 + - 范围:0.0 - 1.0 + - **越高越好**,衡量答案集合匹配度 + - 对于集合类答案特别有用 + +##### 2. 样本级指标 (samples) + +**详细诊断指标:** + +- **`metrics.exact_match`**: 精确匹配(布尔值) + - **true 越多越好**,最严格的评估标准 + - 要求预测答案与标准答案完全一致 + +- **`metrics.f1`**: 单个样本的 F1 分数 + - 范围:0.0 - 1.0 + - **越高越好**,衡量单个问题的回答质量 + +- **`is_temporal`**: 是否为时间推理问题 + - 布尔值,标识问题是否涉及时间推理 + - 时间推理问题通常更具挑战性 + +- **`context_count`**: 检索到的上下文数量 + - 反映检索策略的有效性 + - 建议范围:8-15 个上下文片段 + +- **`retrieved_dialogue_count`**: 检索到的对话数 +- **`retrieved_statement_count`**: 检索到的陈述数 + - 这两个指标帮助理解检索的内容类型分布 + - 可用于优化检索策略 + +- **`timing.search_ms`**: 单个问题的检索延迟(毫秒) +- **`timing.llm_ms`**: 单个问题的 LLM 推理延迟(毫秒) + - **越低越好**,反映单次查询的响应速度 + +##### 3. 汇总指标 (metric_summary) + +**📊 关键 KPI:** + +- **`score_accuracy`**: 总体准确率百分比 + - 范围:0.0 - 100.0 + - **越高越好**,最直观的性能指标 + - 优秀标准:> 90.0 + +- **`latency_median_s`**: 中位延迟(秒) + - **越低越好**,反映真实响应速度 + - 优秀标准:< 3.0 秒 + +- **`latency_iqr_s`**: 延迟四分位距(秒) + - **越低越好**,反映性能稳定性 + - 越小说明响应时间越稳定 + +- **`avg_context_tokens_k`**: 平均上下文 token 数(千) + - **越低越好**(在保持准确性前提下) + - 直接影响 API 调用成本 + - 成本效益比 = score_accuracy / (avg_context_tokens_k * 1000) + +##### 4. 
上下文统计 (context) + +- **`avg_tokens`**: 平均 token 数 +- **`avg_chars`**: 平均字符数 +- **`count_avg`**: 平均上下文片段数 + - 这些指标反映检索内容的规模 + - 需要在准确性和效率之间平衡 + +##### 5. 性能指标 (latency) + +**⚡ 效率指标:** + +- **`search`**: 检索延迟统计(单位:毫秒) + - `mean`: 平均延迟 + - `p50`: 中位数延迟 + - `p95`: 95分位数延迟 + - `iqr`: 四分位距 + - **越低越好**,衡量记忆检索速度 + +- **`llm`**: LLM 推理延迟统计(单位:毫秒) + - `mean`: 平均推理时间 + - `p50`: 中位数推理时间 + - `p95`: 95分位数推理时间 + - `iqr`: 四分位距 + - **越低越好**,衡量答案生成速度 + +##### 6. 诊断信息 (diagnostics) + +- **`duplicate_previews_top`**: 重复预览统计 + - 列出出现频率最高的重复内容 + - 帮助发现检索冗余问题 + - 应该尽量减少重复 + +- **`unique_preview_count`**: 唯一预览数量 + - 反映检索多样性 + - **越高越好**,说明检索到的内容更丰富 + +#### 系统进步衡量标准 + +**一级指标(最重要):** +- `score_accuracy` 提升 → 核心能力提升 +- 目标:> 90.0% +- 各类型的 `accuracy_by_type` 均衡提升 → 全面能力提升 + +**二级指标(重要):** +- `latency_median_s` 降低 → 用户体验提升 +- 目标:< 3.0 秒 +- `exact_match` 比例提升 → 精确度提升 + +**三级指标(辅助):** +- `avg_context_tokens_k` 降低(在保持准确性前提下)→ 成本优化 +- `unique_preview_count` 提升 → 检索多样性提升 +- `latency_iqr_s` 降低 → 性能稳定性提升 + +**特殊关注:** +- 时间推理问题(`is_temporal: true`)的准确率 +- 多会话问题的准确率(通常更具挑战性) +# 5.memsciqa +对话记忆检索评估 +### (1)执行命令 +```python +# 首先进入api目录 +cd api + +# 不带参数运行 - 使用环境变量 +python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark + +# 命令行参数覆盖环境变量 +python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark --sample-size 100 + +# 如果数据已经摄入,跳过摄入阶段直接测试(使用skip_ingest参数) +python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark --skip_ingest +``` +### (2)结果说明 + +#### 结果示例 +```json +{ + "dataset": "memsciqa", + "items": 1, + "metrics": { + "accuracy": 0.0, + "f1": 0.0, + "bleu1": 0.0, + "jaccard": 0.0 + }, + "latency": { + "search": { + "mean": 0.0, + "p50": 0.0, + "p95": 0.0, + "iqr": 0.0 + }, + "llm": { + "mean": 3067.7285194396973, + "p50": 3067.7285194396973, + "p95": 3067.7285194396973, + "iqr": 0.0 + } + }, + "avg_context_tokens": 4.0 +} +``` + +#### 参数详解 + +##### 1. 
核心评估指标 (metrics) + +**🎯 关键进步指标:** + +- **`accuracy`**: 准确率 + - 范围:0.0 - 1.0 + - **越高越好**,最直接的性能指标 + - 衡量系统回答正确的问题比例 + - 优秀标准:> 0.85 + +- **`f1`**: F1 分数 + - 范围:0.0 - 1.0 + - **越高越好**,平衡精确率和召回率 + - 计算公式:2 * (precision * recall) / (precision + recall) + - 比单纯的准确率更全面,特别适合不平衡数据集 + +- **`bleu1`**: BLEU-1 分数 + - 范围:0.0 - 1.0 + - **越高越好**,衡量词汇级别的匹配度 + - 关注生成答案与标准答案的单词重叠 + - 源自机器翻译评估,适用于自然语言生成 + +- **`jaccard`**: Jaccard 相似度 + - 范围:0.0 - 1.0 + - **越高越好**,衡量集合相似性 + - 计算公式:|A ∩ B| / |A ∪ B| + - 对于多答案或集合类问题特别有用 + +##### 2. 性能指标 (latency) + +**⚡ 效率指标:** + +- **`search`**: 检索延迟统计(单位:毫秒) + - `mean`: 平均检索延迟 + - `p50`: 中位数延迟(50%的请求在此时间内完成) + - `p95`: 95分位数延迟(95%的请求在此时间内完成) + - `iqr`: 四分位距(Q3-Q1,衡量稳定性) + - **越低越好**,衡量记忆检索效率 + - 优秀标准:p95 < 2000ms + +- **`llm`**: LLM 推理延迟统计(单位:毫秒) + - `mean`: 平均推理时间 + - `p50`: 中位数推理时间 + - `p95`: 95分位数推理时间 + - `iqr`: 四分位距(越小越稳定) + - **越低越好**,衡量答案生成速度 + - 优秀标准:p95 < 3000ms + - 注意:LLM 延迟通常占总延迟的大部分 + +##### 3. 资源指标 + +- **`avg_context_tokens`**: 平均上下文 token 数 + - **越低越好**(在保持准确性前提下) + - 直接影响: + - API 调用成本(按 token 计费) + - 推理速度(token 越多越慢) + - 上下文窗口占用 + - 成本效益比 = accuracy / avg_context_tokens + - 建议范围:根据模型上下文窗口和成本预算调整 + +##### 4. 
数据集特点 + +- **`items`**: 评估的问题数量 + - 样本量越大,评估结果越可靠 + - 建议至少 100 个样本以获得稳定的评估结果 + +- **对话记忆特性**: + - MemSciQA 专注于对话历史中的记忆检索 + - 评估系统从多轮对话中提取和回忆信息的能力 + - 模拟真实的对话场景 + +#### 系统进步衡量标准 + +**一级指标(最重要):** +- `accuracy` 提升 → 核心能力提升 +- 目标:> 0.85 +- `f1` 提升 → 综合性能提升 +- 目标:> 0.80 + +**二级指标(重要):** +- `latency.p95` 降低 → 用户体验提升 + - search.p95 目标:< 2000ms + - llm.p95 目标:< 3000ms +- `iqr` 降低 → 性能稳定性提升 + +**三级指标(辅助):** +- `avg_context_tokens` 降低(在保持准确性前提下)→ 成本优化 +- `bleu1` 和 `jaccard` 提升 → 答案质量提升 + +**综合评估:** +- 成本效益比 = accuracy / avg_context_tokens + - 该比值越高,说明系统在相同成本下性能越好 +- 总延迟 = search.p95 + llm.p95 + - 应控制在 5 秒以内以保证良好的用户体验 + +#### 优化建议 + +**提升准确性:** +- 优化检索算法(调整 hybrid search 参数) +- 改进 embedding 模型质量 +- 增加检索上下文数量(`search_limit`) +- 优化 prompt 工程 + +**提升效率:** +- 减少不必要的检索文档 +- 使用更快的 LLM 模型或量化版本 +- 实施缓存策略(相似问题复用结果) +- 优化数据库索引 + +**平衡性能:** +- 监控 accuracy vs latency 的权衡 +- 监控 accuracy vs cost (tokens) 的权衡 +- 根据业务需求调整优先级 + + +--- + +# 6. 三个基准测试对比总结 + +## 6.1 测试特点对比 + +| 基准测试 | 主要评估目标 | 数据集特点 | 适用场景 | +|---------|------------|-----------|---------| +| **Locomo** | 长对话记忆检索 | 长对话历史,多轮交互 | 评估长期记忆保持和检索能力 | +| **LongMemEval** | 时间推理和多会话记忆 | 支持时间推理,多会话场景 | 评估时间感知和跨会话记忆能力 | +| **MemSciQA** | 对话记忆问答 | 对话历史问答 | 评估对话上下文理解和记忆提取 | + +## 6.2 核心指标对比 + +### 准确性指标 + +| 指标 | Locomo | LongMemEval | MemSciQA | 说明 | +|-----|--------|-------------|----------|------| +| **F1 Score** | ✅ | ✅ | ✅ | 所有测试都使用,最重要的综合指标 | +| **Accuracy** | ❌ | ✅ | ✅ | 直观的准确率指标 | +| **BLEU-1** | ✅ | ❌ | ✅ | 词汇级别匹配度 | +| **Jaccard** | ✅ | ✅ | ✅ | 集合相似度 | +| **Exact Match** | ❌ | ✅ | ❌ | 最严格的评估标准 | + +### 性能指标 + +所有三个测试都包含: +- **检索延迟** (search latency): mean, p50, p95, iqr +- **LLM 延迟** (llm latency): mean, p50, p95, iqr +- **上下文统计**: token 数、字符数、文档数 + +## 6.3 关键进步指标优先级 + +### 🥇 一级指标(必须关注) + +1. **准确性指标** + - Locomo: `f1`, `locomo_f1` + - LongMemEval: `score_accuracy`, `accuracy_by_type` + - MemSciQA: `accuracy`, `f1` + - **目标**: > 85% 或 > 0.85 + +2. **综合性能** + - 所有测试的 F1 分数应保持一致性 + - 不同类型问题的准确率应均衡 + +### 🥈 二级指标(重要) + +3. 
**响应延迟** + - `latency.p95` (95分位数延迟) + - **目标**: + - search.p95 < 2000ms + - llm.p95 < 3000ms + - 总延迟 < 5000ms + +4. **性能稳定性** + - `iqr` (四分位距) + - **目标**: 越小越好,说明性能稳定 + +### 🥉 三级指标(优化) + +5. **成本效率** + - `avg_context_tokens` + - **目标**: 在保持准确性前提下最小化 + - 成本效益比 = accuracy / avg_context_tokens + +6. **检索质量** + - `avg_retrieved_docs` 的合理性 + - `unique_preview_count` (LongMemEval) + - 检索内容的多样性和相关性 + +## 6.4 系统优化路径 + +### 阶段一:提升准确性(优先级最高) + +**目标**: 所有测试的准确率 > 85% + +**优化方向**: +1. 改进 embedding 模型质量 +2. 优化检索算法(hybrid search 参数) +3. 增加检索上下文数量(`search_limit`) +4. 优化 prompt 工程 +5. 改进记忆存储结构 + +**监控指标**: +- Locomo: `f1`, `locomo_f1` +- LongMemEval: `score_accuracy`, `exact_match` 比例 +- MemSciQA: `accuracy`, `f1` + +### 阶段二:优化性能(准确性达标后) + +**目标**: p95 延迟 < 5 秒,性能稳定 + +**优化方向**: +1. 优化数据库索引和查询 +2. 实施缓存策略 +3. 使用更快的 LLM 模型 +4. 并行化检索和推理 +5. 减少不必要的检索 + +**监控指标**: +- `latency.p50`, `latency.p95` +- `iqr` (稳定性) +- 各阶段耗时分布 + +### 阶段三:降低成本(性能达标后) + +**目标**: 在保持准确性和性能前提下,最小化成本 + +**优化方向**: +1. 精简检索上下文 +2. 优化 context 选择策略 +3. 使用更小的 LLM 模型 +4. 实施智能缓存 +5. 批处理优化 + +**监控指标**: +- `avg_context_tokens` +- 成本效益比 = accuracy / avg_context_tokens +- API 调用成本 + +## 6.5 评估最佳实践 + +### 测试执行建议 + +1. **初始测试**: 使用小样本快速验证 + ```bash + --sample_size 10 + ``` + +2. **完整评估**: 使用足够大的样本量 + ```bash + --sample_size 100 # 或更多 + ``` + +3. **增量测试**: 数据已摄入时跳过摄入阶段 + ```bash + --skip_ingest + ``` + +4. **参数调优**: 系统性地调整参数并记录结果 + - 调整 `search_limit`: 4, 8, 12, 16 + - 调整 `context_char_budget`: 2000, 4000, 8000 + - 尝试不同的 `search_type`: vector, keyword, hybrid + +### 结果分析建议 + +1. **横向对比**: 比较三个测试的结果,识别系统的强弱项 +2. **纵向对比**: 跟踪同一测试在不同版本的表现 +3. **分类分析**: 关注不同问题类型的性能差异 +4. 
**异常诊断**: 分析失败案例,找出根本原因 + +### 持续监控 + +建议建立监控仪表板,跟踪: +- 核心指标趋势(准确率、延迟) +- 成本效益比趋势 +- 不同问题类型的性能分布 +- 异常样本和失败模式 + +## 6.6 性能基准参考 + +### 优秀水平(Production Ready) + +- **准确性**: accuracy/f1 > 0.90 +- **延迟**: p95 < 3 秒 +- **稳定性**: iqr < 500ms +- **成本效益**: accuracy/tokens > 0.0001 + +### 良好水平(Acceptable) + +- **准确性**: accuracy/f1 > 0.85 +- **延迟**: p95 < 5 秒 +- **稳定性**: iqr < 1000ms +- **成本效益**: accuracy/tokens > 0.00005 + +### 需要改进(Below Target) + +- **准确性**: accuracy/f1 < 0.85 +- **延迟**: p95 > 5 秒 +- **稳定性**: iqr > 1000ms +- **成本效益**: accuracy/tokens < 0.00005 + +--- + +**注**: 以上标准仅供参考,实际目标应根据具体业务需求和资源约束调整。 diff --git a/api/app/core/memory/evaluation/check_enduser_data.py b/api/app/core/memory/evaluation/check_enduser_data.py new file mode 100644 index 00000000..18ecbb34 --- /dev/null +++ b/api/app/core/memory/evaluation/check_enduser_data.py @@ -0,0 +1,371 @@ +""" +交互式 Neo4j End User 数据检查工具 + +用于查询指定 end_user_id 在 Neo4j 中是否存在数据,以及数据的详细统计信息。 + +使用方法(需在 api 目录下执行): + python -m app.core.memory.evaluation.check_enduser_data + python -m app.core.memory.evaluation.check_enduser_data --end-user-id locomo_benchmark + python -m app.core.memory.evaluation.check_enduser_data --end-user-id memsciqa_benchmark --detailed +""" + +import asyncio +import argparse +import os +from pathlib import Path +from typing import Dict, Any +from dotenv import load_dotenv + +# Load evaluation config +eval_config_path = Path(__file__).resolve().parent / ".env.evaluation" +if eval_config_path.exists(): + load_dotenv(eval_config_path, override=True) + print(f"✅ 加载评估配置: {eval_config_path}\n") + +from app.repositories.neo4j.neo4j_connector import Neo4jConnector + + +async def check_group_exists(end_user_id: str) -> Dict[str, Any]: + """ + 检查指定 end_user_id 是否存在数据 + + Args: + end_user_id: 要检查的 end_user ID + + Returns: + 包含统计信息的字典 + """ + connector = Neo4jConnector() + + try: + # 查询该 end_user 的节点总数 + query_total = """ + MATCH (n {end_user_id: $end_user_id}) + RETURN count(n) as total_nodes + """ + result_total = await connector.execute_query(query_total, end_user_id=end_user_id) + total_nodes = 
result_total[0]["total_nodes"] if result_total else 0 + + # 查询各类型节点的数量 + query_by_type = """ + MATCH (n {end_user_id: $end_user_id}) + RETURN labels(n) as labels, count(n) as count + ORDER BY count DESC + """ + result_by_type = await connector.execute_query(query_by_type, end_user_id=end_user_id) + + # 查询关系数量 + query_relationships = """ + MATCH (n {end_user_id: $end_user_id})-[r]-() + RETURN count(DISTINCT r) as total_relationships + """ + result_rel = await connector.execute_query(query_relationships, end_user_id=end_user_id) + total_relationships = result_rel[0]["total_relationships"] if result_rel else 0 + + return { + "exists": total_nodes > 0, + "total_nodes": total_nodes, + "total_relationships": total_relationships, + "nodes_by_type": result_by_type + } + + finally: + await connector.close() + + +async def get_detailed_stats(end_user_id: str) -> Dict[str, Any]: + """ + 获取详细的统计信息 + + Args: + end_user_id: 要检查的 end_user ID + + Returns: + 详细统计信息字典 + """ + connector = Neo4jConnector() + + try: + stats = {} + + # Chunk 节点统计 + query_chunks = """ + MATCH (c:Chunk {end_user_id: $end_user_id}) + RETURN count(c) as count, + avg(size(c.content)) as avg_content_length + """ + result_chunks = await connector.execute_query(query_chunks, end_user_id=end_user_id) + if result_chunks and result_chunks[0]["count"] > 0: + stats["chunks"] = { + "count": result_chunks[0]["count"], + "avg_content_length": int(result_chunks[0]["avg_content_length"]) if result_chunks[0]["avg_content_length"] else 0 + } + + # Statement 节点统计 + query_statements = """ + MATCH (s:Statement {end_user_id: $end_user_id}) + RETURN count(s) as count + """ + result_statements = await connector.execute_query(query_statements, end_user_id=end_user_id) + if result_statements and result_statements[0]["count"] > 0: + stats["statements"] = { + "count": result_statements[0]["count"] + } + + # Entity 节点统计 + query_entities = """ + MATCH (e:Entity {end_user_id: $end_user_id}) + RETURN count(e) as count, + count(DISTINCT 
e.entity_type) as unique_types + """ + result_entities = await connector.execute_query(query_entities, end_user_id=end_user_id) + if result_entities and result_entities[0]["count"] > 0: + stats["entities"] = { + "count": result_entities[0]["count"], + "unique_types": result_entities[0]["unique_types"] + } + + # Dialogue 节点统计 + query_dialogues = """ + MATCH (d:Dialogue {end_user_id: $end_user_id}) + RETURN count(d) as count + """ + result_dialogues = await connector.execute_query(query_dialogues, end_user_id=end_user_id) + if result_dialogues and result_dialogues[0]["count"] > 0: + stats["dialogues"] = { + "count": result_dialogues[0]["count"] + } + + # Summary 节点统计 + query_summaries = """ + MATCH (s:Summary {end_user_id: $end_user_id}) + RETURN count(s) as count + """ + result_summaries = await connector.execute_query(query_summaries, end_user_id=end_user_id) + if result_summaries and result_summaries[0]["count"] > 0: + stats["summaries"] = { + "count": result_summaries[0]["count"] + } + + return stats + + finally: + await connector.close() + + +async def list_all_end_users() -> list: + """ + 列出数据库中所有的 end_user_id + + Returns: + end_user_id 列表及其节点数量 + """ + connector = Neo4jConnector() + + try: + query = """ + MATCH (n) + WHERE n.end_user_id IS NOT NULL + RETURN DISTINCT n.end_user_id as end_user_id, count(n) as node_count + ORDER BY node_count DESC + """ + results = await connector.execute_query(query) + return results + + finally: + await connector.close() + + +def print_results(end_user_id: str, stats: Dict[str, Any], detailed_stats: Dict[str, Any] = None): + """ + 打印查询结果 + + Args: + end_user_id: End User ID + stats: 基本统计信息 + detailed_stats: 详细统计信息(可选) + """ + print(f"\n{'='*60}") + print(f"📊 End User ID: {end_user_id}") + print(f"{'='*60}\n") + + if not stats["exists"]: + print("❌ 该 end_user_id 不存在数据") + print("\n💡 提示: 请先运行基准测试以摄入数据") + return + + print(f"✅ 该 end_user_id 存在数据\n") + print(f"📈 基本统计:") + print(f" 总节点数: {stats['total_nodes']}") + print(f" 总关系数: 
{stats['total_relationships']}") + + if stats["nodes_by_type"]: + print(f"\n📋 节点类型分布:") + for item in stats["nodes_by_type"]: + labels = ", ".join(item["labels"]) + count = item["count"] + print(f" {labels}: {count}") + + if detailed_stats: + print(f"\n🔍 详细统计:") + + if "chunks" in detailed_stats: + print(f" Chunks: {detailed_stats['chunks']['count']} 个") + print(f" 平均内容长度: {detailed_stats['chunks']['avg_content_length']} 字符") + + if "statements" in detailed_stats: + print(f" Statements: {detailed_stats['statements']['count']} 个") + + if "entities" in detailed_stats: + print(f" Entities: {detailed_stats['entities']['count']} 个") + print(f" 唯一类型数: {detailed_stats['entities']['unique_types']}") + + if "dialogues" in detailed_stats: + print(f" Dialogues: {detailed_stats['dialogues']['count']} 个") + + if "summaries" in detailed_stats: + print(f" Summaries: {detailed_stats['summaries']['count']} 个") + + print(f"\n{'='*60}\n") + + +async def interactive_mode(): + """ + 交互式模式 + """ + print("\n" + "="*60) + print("🔍 Neo4j End User 数据检查工具 - 交互模式") + print("="*60 + "\n") + + while True: + print("\n请选择操作:") + print(" 1. 检查指定 end_user_id") + print(" 2. 列出所有 end_user_id") + print(" 3. 退出") + + choice = input("\n请输入选项 (1-3): ").strip() + + if choice == "1": + end_user_id = input("\n请输入 end_user_id: ").strip() + if not end_user_id: + print("❌ end_user_id 不能为空") + continue + + detailed = input("是否显示详细统计? 
(y/n, 默认 n): ").strip().lower() == 'y' + + print("\n🔄 正在查询...") + stats = await check_group_exists(end_user_id) + + detailed_stats = None + if detailed and stats["exists"]: + detailed_stats = await get_detailed_stats(end_user_id) + + print_results(end_user_id, stats, detailed_stats) + + elif choice == "2": + print("\n🔄 正在查询所有 end_user_id...") + end_users = await list_all_end_users() + + if not end_users: + print("\n❌ 数据库中没有任何 end_user 数据") + else: + print(f"\n{'='*60}") + print(f"📋 数据库中的所有 End User ID") + print(f"{'='*60}\n") + + for idx, end_user in enumerate(end_users, 1): + print(f" {idx}. {end_user['end_user_id']}") + print(f" 节点数: {end_user['node_count']}") + + print(f"\n{'='*60}\n") + + elif choice == "3": + print("\n👋 再见!") + break + + else: + print("\n❌ 无效的选项,请重新选择") + + +async def main(): + """ + 主函数 + """ + parser = argparse.ArgumentParser( + description="检查 Neo4j 中指定 end_user_id 的数据情况", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +示例: + # 交互模式 + python check_group_data.py + + # 检查指定 end_user + python check_group_data.py --end-user-id locomo_benchmark + + # 检查并显示详细统计 + python check_group_data.py --end-user-id memsciqa_benchmark --detailed + + # 列出所有 end_user + python check_group_data.py --list-all + """ + ) + + parser.add_argument( + "--end-user-id", + type=str, + help="要检查的 end_user ID" + ) + + parser.add_argument( + "--detailed", + action="store_true", + help="显示详细统计信息" + ) + + parser.add_argument( + "--list-all", + action="store_true", + help="列出所有 end_user_id" + ) + + args = parser.parse_args() + + # 如果没有提供任何参数,进入交互模式 + if not args.end_user_id and not args.list_all: + await interactive_mode() + return + + # 列出所有 end_user + if args.list_all: + print("\n🔄 正在查询所有 end_user_id...") + end_users = await list_all_end_users() + + if not end_users: + print("\n❌ 数据库中没有任何 end_user 数据") + else: + print(f"\n{'='*60}") + print(f"📋 数据库中的所有 End User ID") + print(f"{'='*60}\n") + + for idx, end_user in enumerate(end_users, 1): + print(f" {idx}. 
{end_user['end_user_id']}") + print(f" 节点数: {end_user['node_count']}") + + print(f"\n{'='*60}\n") + return + + # 检查指定 end_user + if args.end_user_id: + print(f"\n🔄 正在查询 end_user_id: {args.end_user_id}...") + stats = await check_group_exists(args.end_user_id) + + detailed_stats = None + if args.detailed and stats["exists"]: + print("🔄 正在获取详细统计...") + detailed_stats = await get_detailed_stats(args.end_user_id) + + print_results(args.end_user_id, stats, detailed_stats) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/api/app/core/memory/evaluation/common/metrics.py b/api/app/core/memory/evaluation/common/metrics.py index acc27fb9..961ce7f0 100644 --- a/api/app/core/memory/evaluation/common/metrics.py +++ b/api/app/core/memory/evaluation/common/metrics.py @@ -2,7 +2,7 @@ import math import re from typing import List, Dict - +# 评估指标的实现 def _normalize(text: str) -> List[str]: """Lowercase, strip punctuation, and split into tokens.""" text = text.lower().strip() diff --git a/api/app/core/memory/evaluation/dialogue_queries.py b/api/app/core/memory/evaluation/dialogue_queries.py index 25abe64e..0aace0ec 100644 --- a/api/app/core/memory/evaluation/dialogue_queries.py +++ b/api/app/core/memory/evaluation/dialogue_queries.py @@ -4,15 +4,17 @@ This file contains Cypher queries for searching dialogues, entities, and chunks. Placed in evaluation directory to avoid circular imports with src modules. 
""" +# 应该是neo4j browser的cypher语句,需要修改文件名 + # Entity search queries SEARCH_ENTITIES_BY_NAME = """ -MATCH (e:Entity) +MATCH (e:ExtractedEntity) WHERE e.name = $name RETURN e """ SEARCH_ENTITIES_BY_NAME_FALLBACK = """ -MATCH (e:Entity) +MATCH (e:ExtractedEntity) WHERE e.name CONTAINS $name RETURN e """ diff --git a/api/app/core/memory/evaluation/extraction_utils.py b/api/app/core/memory/evaluation/extraction_utils.py index 9e70bc28..43ef6fe0 100644 --- a/api/app/core/memory/evaluation/extraction_utils.py +++ b/api/app/core/memory/evaluation/extraction_utils.py @@ -1,34 +1,33 @@ +import os import asyncio import json -import os -import re +from typing import List, Dict, Any, Optional from datetime import datetime -from typing import Any, Dict, List, Optional +from uuid import UUID +import re from app.core.memory.llm_tools.openai_client import LLMClient -from app.core.memory.models.message_models import ( - ConversationContext, - ConversationMessage, - DialogData, -) +from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker +from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage +import os +import sys +from pathlib import Path +from dotenv import load_dotenv + +# Load evaluation config +eval_config_path = Path(__file__).resolve().parent / "app" / "core" / "memory" / "evaluation" / ".env.evaluation" +if eval_config_path.exists(): + load_dotenv(eval_config_path, override=True) + print(f"✅ 加载评估配置: {eval_config_path}") + +from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.core.memory.utils.llm.llm_utils import get_llm_client # 使用新的模块化架构 -from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ( - ExtractionOrchestrator, -) -from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import ( - DialogueChunker, -) -from app.core.memory.utils.config.definitions import ( - 
SELECTED_CHUNKER_STRATEGY, - SELECTED_EMBEDDING_ID, -) -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.db import get_db_context +from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator # Import from database module from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j -from app.repositories.neo4j.neo4j_connector import Neo4jConnector # Cypher queries for evaluation # Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py @@ -41,11 +40,14 @@ async def ingest_contexts_via_full_pipeline( embedding_name: str | None = None, save_chunk_output: bool = False, save_chunk_output_path: str | None = None, + reset_group: bool = False, ) -> bool: - """DEPRECATED: 此函数使用旧的流水线架构,建议使用新的 ExtractionOrchestrator + """ + 使用新的 ExtractionOrchestrator 运行完整的提取流水线 Run the full extraction pipeline on provided dialogue contexts and save to Neo4j. - This function mirrors the steps in main(), but starts from raw text contexts. + This function uses the new ExtractionOrchestrator architecture for better maintainability. + Args: contexts: List of dialogue texts, each containing lines like "role: message". end_user_id: Group ID to assign to generated DialogData and graph nodes. @@ -53,25 +55,59 @@ async def ingest_contexts_via_full_pipeline( embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID. save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging. save_chunk_output_path: Optional output path; defaults to src/chunker_test_output.txt. + reset_group: If True, clear existing data for this group before ingestion. Returns: True if data saved successfully, False otherwise. 
""" - chunker_strategy = chunker_strategy or SELECTED_CHUNKER_STRATEGY - embedding_name = embedding_name or SELECTED_EMBEDDING_ID + chunker_strategy = chunker_strategy or os.getenv("EVAL_CHUNKER_STRATEGY", "RecursiveChunker") + embedding_name = embedding_name or os.getenv("EVAL_EMBEDDING_ID") + + # Check if we should reset from environment variable if not explicitly set + if not reset_group: + reset_group = os.getenv("EVAL_RESET_ON_INGEST", "false").lower() in ("true", "1", "yes") + + # Step 0: Reset group if requested + if reset_group: + print(f"[Ingestion] 🗑️ 清空 end_user '{end_user_id}' 的现有数据...") + try: + from app.repositories.neo4j.neo4j_connector import Neo4jConnector + connector = Neo4jConnector() + try: + # 删除该 end_user 的所有节点和关系 + query = """ + MATCH (n {end_user_id: $end_user_id}) + DETACH DELETE n + """ + await connector.execute_query(query, end_user_id=end_user_id) + print(f"[Ingestion] ✅ End User '{end_user_id}' 已清空") + finally: + await connector.close() + except Exception as e: + print(f"[Ingestion] ⚠️ 清空 end_user 失败: {e}") + # 继续执行,不中断摄入流程 - # Initialize llm client with graceful fallback + # Step 1: Initialize LLM client llm_client = None - llm_available = True try: - from app.core.memory.utils.config import definitions as config_defs - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID) + # 使用评估配置中的 LLM ID + llm_id = os.getenv("EVAL_LLM_ID") + if not llm_id: + print("[Ingestion] ❌ EVAL_LLM_ID not set in .env.evaluation") + return False + + from app.db import get_db + + db = next(get_db()) + try: + llm_client = get_llm_client(llm_id, db) + finally: + db.close() except Exception as e: - print(f"[Ingestion] LLM client unavailable, will skip LLM-dependent steps: {e}") - llm_available = False + print(f"[Ingestion] LLM client unavailable: {e}") + return False - # Step A: Build DialogData list from contexts with robust parsing + # Step 2: Parse contexts and create DialogData with 
chunks + print(f"[Ingestion] Parsing {len(contexts)} contexts...") chunker = DialogueChunker(chunker_strategy) dialog_data_list: List[DialogData] = [] @@ -94,7 +130,7 @@ async def ingest_contexts_via_full_pipeline( line = raw.strip() if not line: continue - m = re.match(r'^\s*([^::]+)\s*[::]\s*(.+)$', line) + m = re.match(r'^\s*([^::]+)\s*[::]\s*(.+)', line) if m: role = m.group(1).strip() msg = m.group(2).strip() @@ -118,10 +154,12 @@ async def ingest_contexts_via_full_pipeline( dialog_data_list.append(dialog) if not dialog_data_list: - print("No dialogs to process for ingestion.") + print("[Ingestion] No dialogs to process.") return False - # Optionally save chunking outputs for debugging + print(f"[Ingestion] Parsed {len(dialog_data_list)} dialogs with chunks") + + # Step 3: Optionally save chunking outputs for debugging if save_chunk_output: try: def _serialize_datetime(obj): @@ -137,124 +175,185 @@ async def ingest_contexts_via_full_pipeline( combined_output = [dd.model_dump() for dd in dialog_data_list] with open(out_path, "w", encoding="utf-8") as f: json.dump(combined_output, f, ensure_ascii=False, indent=4, default=_serialize_datetime) - print(f"Saved chunking results to: {out_path}") + print(f"[Ingestion] Saved chunking results to: {out_path}") except Exception as e: - print(f"Failed to save chunking results: {e}") + print(f"[Ingestion] Failed to save chunking results: {e}") - # Step B-G: 使用新的 ExtractionOrchestrator 执行完整的提取流水线 - if not llm_available: - print("[Ingestion] Skipping extraction pipeline (no LLM).") - return False - - # 初始化 embedder 客户端 - from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient + # Step 4: Initialize embedder client from app.core.models.base import RedBearModelConfig - from app.services.memory_config_service import MemoryConfigService + from app.core.memory.utils.config.config_utils import get_embedder_config + from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient + from app.db import 
get_db try: - with get_db_context() as db: - embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID) - embedder_config = RedBearModelConfig(**embedder_config_dict) - embedder_client = OpenAIEmbedderClient(embedder_config) + db = next(get_db()) + try: + embedder_config_dict = get_embedder_config(embedding_name, db) + embedder_config = RedBearModelConfig(**embedder_config_dict) + embedder_client = OpenAIEmbedderClient(embedder_config) + finally: + db.close() except Exception as e: print(f"[Ingestion] Failed to initialize embedder client: {e}") - print("[Ingestion] Skipping extraction pipeline (embedder initialization failed).") return False + # Step 5: Initialize Neo4j connector connector = Neo4jConnector() - # 初始化并运行 ExtractionOrchestrator - from app.core.memory.utils.config.config_utils import get_pipeline_config - config = get_pipeline_config() + # Step 6: 构建 MemoryConfig(从环境变量直接构建,不依赖数据库) + print("[Ingestion] 构建 MemoryConfig from environment variables...") + from app.schemas.memory_config_schema import MemoryConfig + + try: + # 从环境变量获取配置参数 + llm_id = os.getenv("EVAL_LLM_ID") + embedding_id = os.getenv("EVAL_EMBEDDING_ID") + chunker_strategy_env = os.getenv("EVAL_CHUNKER_STRATEGY", "RecursiveChunker") + + if not llm_id or not embedding_id: + print("[Ingestion] ❌ EVAL_LLM_ID or EVAL_EMBEDDING_ID is not set in .env.evaluation") + print("[Ingestion] Please set both EVAL_LLM_ID and EVAL_EMBEDDING_ID") + await connector.close() + return False + + # 从数据库获取模型信息(仅用于显示名称) + from app.db import get_db + db = next(get_db()) + try: + from sqlalchemy import text + # 获取 LLM 模型信息(从 model_configs 表) + llm_result = db.execute( + text("SELECT name FROM model_configs WHERE id = :id"), + {"id": llm_id} + ).fetchone() + llm_model_name = llm_result[0] if llm_result else "Unknown LLM" + + # 获取 Embedding 模型信息(从 model_configs 表) + emb_result = db.execute( + text("SELECT name FROM model_configs WHERE id = :id"), + {"id": embedding_id} + 
).fetchone() + embedding_model_name = emb_result[0] if emb_result else "Unknown Embedding" + except Exception as e: + # 如果查询失败,使用默认名称 + print(f"[Ingestion] Warning: Failed to query model names from database: {e}") + llm_model_name = f"LLM ({llm_id[:8]}...)" + embedding_model_name = f"Embedding ({embedding_id[:8]}...)" + finally: + db.close() + + # 构建 MemoryConfig 对象(使用最小必需配置) + from uuid import uuid4 + memory_config = MemoryConfig( + config_id=0, # 评估环境不需要真实的 config_id + config_name="evaluation_config", + workspace_id=uuid4(), # 临时 workspace_id + workspace_name="evaluation_workspace", + tenant_id=uuid4(), # 临时 tenant_id + llm_model_id=UUID(llm_id), + llm_model_name=llm_model_name, + embedding_model_id=UUID(embedding_id), + embedding_model_name=embedding_model_name, + storage_type="neo4j", + chunker_strategy=chunker_strategy_env, + reflexion_enabled=False, + reflexion_iteration_period=3, + reflexion_range="partial", + reflexion_baseline="TIME", + loaded_at=datetime.now(), + # 可选字段使用默认值 + rerank_model_id=None, + rerank_model_name=None, + llm_params={}, + embedding_params={}, + config_version="2.0", + ) + + print(f"[Ingestion] ✅ 构建 MemoryConfig 成功") + print(f"[Ingestion] LLM: {llm_model_name}") + print(f"[Ingestion] Embedding: {embedding_model_name}") + print(f"[Ingestion] Chunker: {chunker_strategy_env}") + + except Exception as e: + print(f"[Ingestion] ❌ Failed to build MemoryConfig: {e}") + print(f"[Ingestion] Please check:") + print(f"[Ingestion] 1. EVAL_LLM_ID and EVAL_EMBEDDING_ID are set in .env.evaluation") + print(f"[Ingestion] 2. Model IDs exist in the models table") + print(f"[Ingestion] 3. 
Database connection is working") + await connector.close() + return False + + # Step 7: Initialize and run ExtractionOrchestrator + print("[Ingestion] Running extraction pipeline with ExtractionOrchestrator...") + from app.services.memory_config_service import MemoryConfigService + config = MemoryConfigService.get_pipeline_config(memory_config) orchestrator = ExtractionOrchestrator( llm_client=llm_client, embedder_client=embedder_client, connector=connector, config=config, + embedding_id=str(memory_config.embedding_model_id), # 传递 embedding_id ) - # 创建一个包装的 orchestrator 来修复时间提取器的输出 - # 保存原始的 _assign_extracted_data 方法 - original_assign = orchestrator._assign_extracted_data - - def clean_temporal_value(value): - """清理 temporal_validity 字段的值,将无效值转换为 None""" - if value is None: - return None - if isinstance(value, str): - # 处理字符串形式的 'null', 'None', 空字符串等 - if value.lower() in ('null', 'none', '') or value.strip() == '': - return None - return value - - async def patched_assign_extracted_data(*args, **kwargs): - """包装方法:在赋值后清理 temporal_validity 中的无效字符串""" - result = await original_assign(*args, **kwargs) + try: + # Run the complete extraction pipeline + result = await orchestrator.run(dialog_data_list, is_pilot_run=False) - # 清理返回的 dialog_data_list 中的 temporal_validity - for dialog in result: - if hasattr(dialog, 'chunks') and dialog.chunks: - for chunk in dialog.chunks: - if hasattr(chunk, 'statements') and chunk.statements: - for statement in chunk.statements: - if hasattr(statement, 'temporal_validity') and statement.temporal_validity: - tv = statement.temporal_validity - # 清理 valid_at 和 invalid_at - if hasattr(tv, 'valid_at'): - tv.valid_at = clean_temporal_value(tv.valid_at) - if hasattr(tv, 'invalid_at'): - tv.invalid_at = clean_temporal_value(tv.invalid_at) - return result - - # 替换方法 - orchestrator._assign_extracted_data = patched_assign_extracted_data - - # 同时包装 _create_nodes_and_edges 方法,在创建节点前再次清理 - original_create = orchestrator._create_nodes_and_edges - - 
async def patched_create_nodes_and_edges(dialog_data_list_arg): - """包装方法:在创建节点前再次清理 temporal_validity""" - # 最后一次清理,确保万无一失 - for dialog in dialog_data_list_arg: - if hasattr(dialog, 'chunks') and dialog.chunks: - for chunk in dialog.chunks: - if hasattr(chunk, 'statements') and chunk.statements: - for statement in chunk.statements: - if hasattr(statement, 'temporal_validity') and statement.temporal_validity: - tv = statement.temporal_validity - if hasattr(tv, 'valid_at'): - tv.valid_at = clean_temporal_value(tv.valid_at) - if hasattr(tv, 'invalid_at'): - tv.invalid_at = clean_temporal_value(tv.invalid_at) + # Handle different return formats: + # - Pilot mode: 7 values (without dedup_details) + # - Normal mode: 8 values (with dedup_details at the end) + if len(result) == 8: + # Normal mode: includes dedup_details + ( + dialogue_nodes, + chunk_nodes, + statement_nodes, + entity_nodes, + statement_chunk_edges, + statement_entity_edges, + entity_entity_edges, + _, # dedup_details - not needed here + ) = result + elif len(result) == 7: + # Pilot mode or older version: no dedup_details + ( + dialogue_nodes, + chunk_nodes, + statement_nodes, + entity_nodes, + statement_chunk_edges, + statement_entity_edges, + entity_entity_edges, + ) = result + else: + raise ValueError(f"Unexpected number of return values: {len(result)}") - return await original_create(dialog_data_list_arg) - - orchestrator._create_nodes_and_edges = patched_create_nodes_and_edges - - # 运行完整的提取流水线 - # orchestrator.run 返回 7 个元素的元组 - result = await orchestrator.run(dialog_data_list, is_pilot_run=False) - ( - dialogue_nodes, - chunk_nodes, - statement_nodes, - entity_nodes, - statement_chunk_edges, - statement_entity_edges, - entity_entity_edges, - ) = result - - # statement_chunk_edges 已经由 orchestrator 创建,无需重复创建 + print(f"[Ingestion] Extraction completed: {len(statement_nodes)} statements, {len(entity_nodes)} entities") + + except ValueError as e: + # If unpacking fails, provide helpful error message + 
print(f"[Ingestion] Extraction pipeline result unpacking failed: {e}") + print(f"[Ingestion] Result type: {type(result)}, length: {len(result) if hasattr(result, '__len__') else 'N/A'}") + if hasattr(result, '__len__') and len(result) > 0: + print(f"[Ingestion] First element type: {type(result[0])}") + await connector.close() + return False + except Exception as e: + print(f"[Ingestion] Extraction pipeline failed: {e}") + import traceback + traceback.print_exc() + await connector.close() + return False - # Step G: 生成记忆摘要 + # Step 7: Generate memory summaries print("[Ingestion] Generating memory summaries...") try: from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import ( memory_summary_generation, ) - from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges from app.repositories.neo4j.add_nodes import add_memory_summary_nodes + from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges summaries = await memory_summary_generation( chunked_dialogs=dialog_data_list, @@ -266,7 +365,8 @@ async def ingest_contexts_via_full_pipeline( print(f"[Ingestion] Warning: Failed to generate memory summaries: {e}") summaries = [] - # Step H: Save to Neo4j + # Step 8: Save to Neo4j + print("[Ingestion] Saving to Neo4j...") try: success = await save_dialog_and_statements_to_neo4j( dialogue_nodes=dialogue_nodes, @@ -284,18 +384,21 @@ async def ingest_contexts_via_full_pipeline( try: await add_memory_summary_nodes(summaries, connector) await add_memory_summary_statement_edges(summaries, connector) - print(f"Successfully saved {len(summaries)} memory summary nodes to Neo4j") + print(f"[Ingestion] Saved {len(summaries)} memory summary nodes to Neo4j") except Exception as e: - print(f"Warning: Failed to save summary nodes: {e}") + print(f"[Ingestion] Warning: Failed to save summary nodes: {e}") await connector.close() + if success: - print("Successfully saved extracted data to Neo4j!") + 
print("[Ingestion] Successfully saved all data to Neo4j!") else: - print("Failed to save data to Neo4j") + print("[Ingestion] Failed to save data to Neo4j") return success + except Exception as e: - print(f"Failed to save data to Neo4j: {e}") + print(f"[Ingestion] Failed to save data to Neo4j: {e}") + await connector.close() return False diff --git a/api/app/core/memory/evaluation/locomo/locomo_benchmark.py b/api/app/core/memory/evaluation/locomo/locomo_benchmark.py index 1c70c28e..eed75016 100644 --- a/api/app/core/memory/evaluation/locomo/locomo_benchmark.py +++ b/api/app/core/memory/evaluation/locomo/locomo_benchmark.py @@ -15,134 +15,145 @@ import json import os import time from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import List, Dict, Any, Optional +from pathlib import Path +from dotenv import load_dotenv -try: - from dotenv import load_dotenv -except ImportError: - def load_dotenv(): - pass +# Load evaluation config +eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation" +if eval_config_path.exists(): + load_dotenv(eval_config_path, override=True) + print(f"✅ 加载评估配置: {eval_config_path}") +from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient +from app.core.models.base import RedBearModelConfig +from app.core.memory.utils.config.config_utils import get_embedder_config +from app.core.memory.utils.llm.llm_utils import get_llm_client from app.core.memory.evaluation.common.metrics import ( - avg_context_tokens, - bleu1, f1_score, + bleu1, jaccard, latency_stats, + avg_context_tokens ) from app.core.memory.evaluation.locomo.locomo_metrics import ( - get_category_name, locomo_f1_score, locomo_multi_f1, + get_category_name ) from app.core.memory.evaluation.locomo.locomo_utils import ( - extract_conversations, - ingest_conversations_if_needed, load_locomo_data, + extract_conversations, resolve_temporal_references, 
- retrieve_relevant_information, select_and_format_information, -) -from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.memory.utils.definitions import ( - PROJECT_ROOT, - SELECTED_EMBEDDING_ID, - SELECTED_end_user_id, - SELECTED_LLM_ID, + retrieve_relevant_information, ) from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.core.models.base import RedBearModelConfig from app.db import get_db_context -from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.services.memory_config_service import MemoryConfigService +# Get configuration from environment variables +PROJECT_ROOT = str(Path(__file__).resolve().parents[5]) # api directory +SELECTED_EMBEDDING_ID = os.getenv("EVAL_EMBEDDING_ID", "e2a6392d-ca63-4d59-a523-647420b59cb2") +SELECTED_end_user_id = os.getenv("LOCOMO_END_USER_ID") or os.getenv("EVAL_END_USER_ID", "locomo_benchmark") +SELECTED_LLM_ID = os.getenv("EVAL_LLM_ID", "2c9b0782-7a85-4740-ba84-4baf77f256c4") -async def run_locomo_benchmark( - sample_size: int = 20, - end_user_id: Optional[str] = None, - search_type: str = "hybrid", - search_limit: int = 12, - context_char_budget: int = 8000, - reset_group: bool = False, - skip_ingest: bool = False, - output_dir: Optional[str] = None -) -> Dict[str, Any]: + +# ============================================================================ +# Step 1: Data Loading +# ============================================================================ + +def step_load_data(data_path: str, sample_size: int) -> List[Dict[str, Any]]: """ - Run LoCoMo benchmark evaluation. - - This function orchestrates the complete evaluation pipeline: - 1. Load LoCoMo dataset (only QA pairs from first conversation) - 2. Check/ingest conversations into database (only first conversation, unless skip_ingest=True) - 3. For each question: - - Retrieve relevant information - - Generate answer using LLM - - Calculate metrics - 4. 
Aggregate results and save to file - - Note: By default, only the first conversation is ingested into the database, - and only QA pairs from that conversation are evaluated. This ensures that - all questions have corresponding memory in the database for retrieval. + Load QA pairs from LoCoMo dataset. Args: - sample_size: Number of QA pairs to evaluate (from first conversation) - end_user_id: Database group ID for retrieval (uses default if None) - search_type: "keyword", "embedding", or "hybrid" - search_limit: Max documents to retrieve per query - context_char_budget: Max characters for context - reset_group: Whether to clear and re-ingest data (not implemented) - skip_ingest: If True, skip data ingestion and use existing data in Neo4j - output_dir: Directory to save results (uses default if None) + data_path: Path to locomo10.json file + sample_size: Number of QA pairs to load (0 for all) Returns: - Dictionary with evaluation results including metrics, timing, and samples + List of QA items from the first conversation """ - # Use default end_user_id if not provided - end_user_id = end_user_id or SELECTED_end_user_id + print("📂 Loading LoCoMo data...") - # Determine data path - data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json") - if not os.path.exists(data_path): - # Fallback to current directory - data_path = os.path.join(os.getcwd(), "data", "locomo10.json") + # Load the dataset + qa_items = load_locomo_data(data_path, sample_size) - print(f"\n{'='*60}") - print("🚀 Starting LoCoMo Benchmark Evaluation") - print(f"{'='*60}") - print("📊 Configuration:") - print(f" Sample size: {sample_size}") - print(f" Group ID: {end_user_id}") - print(f" Search type: {search_type}") - print(f" Search limit: {search_limit}") - print(f" Context budget: {context_char_budget} chars") - print(f" Data path: {data_path}") - print(f"{'='*60}\n") + print(f"✅ Loaded {len(qa_items)} QA pairs from first conversation\n") + return qa_items + + +# 
============================================================================ +# Step 2: Data Ingestion +# ============================================================================ + +async def ingest_conversations_if_needed( + conversations: List[str], + end_user_id: str, + reset: bool = False +) -> bool: + """ + Ingest conversations into Neo4j database. - # Step 1: Load LoCoMo data - print("📂 Loading LoCoMo dataset...") + Args: + conversations: List of conversation strings (already formatted) + end_user_id: Database end_user ID + reset: Whether to reset the group before ingestion + + Returns: + True if successful, False otherwise + """ try: - # Only load QA pairs from the first conversation (index 0) - # since we only ingest the first conversation into the database - qa_items = load_locomo_data(data_path, sample_size, conversation_index=0) - print(f"✅ Loaded {len(qa_items)} QA pairs from conversation 0\n") + from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline + + # Conversations are already formatted as strings, use them directly + await ingest_contexts_via_full_pipeline(conversations, end_user_id) + return True + except Exception as e: - print(f"❌ Failed to load data: {e}") - return { - "error": f"Data loading failed: {e}", - "timestamp": datetime.now().isoformat() - } + print(f"⚠️ Ingestion error: {e}") + import traceback + traceback.print_exc() + return False + + +async def step_ingest_data( + data_path: str, + end_user_id: str, + skip_ingest: bool, + reset_group: bool, + max_messages: Optional[int] = None +) -> bool: + """ + Ingest conversations into Neo4j database if needed. 
- # Step 2: Extract conversations and ingest if needed + Args: + data_path: Path to locomo10.json file + end_user_id: Database end_user ID + skip_ingest: Whether to skip ingestion + reset_group: Whether to reset the group before ingestion + max_messages: Maximum messages per dialogue to ingest (for testing) + + Returns: + True if ingestion succeeded or was skipped, False otherwise + """ if skip_ingest: print("⏭️ Skipping data ingestion (using existing data in Neo4j)") - print(f" Group ID: {end_user_id}\n") + print(f" End User ID: {end_user_id}\n") else: print("💾 Checking database ingestion...") try: - conversations = extract_conversations(data_path, max_dialogues=1) + # Extract conversations with optional message limit + conversations = extract_conversations( + data_path, + max_dialogues=1, + max_messages_per_dialogue=max_messages + ) print(f"📝 Extracted {len(conversations)} conversations") # Always ingest for now (ingestion check not implemented) - print(f"🔄 Ingesting conversations into group '{end_user_id}'...") + print(f"🔄 Ingesting conversations into end_user '{end_user_id}'...") success = await ingest_conversations_if_needed( conversations=conversations, end_user_id=end_user_id, @@ -156,238 +167,249 @@ async def run_locomo_benchmark( except Exception as e: print(f"❌ Ingestion failed: {e}") + import traceback + traceback.print_exc() print("⚠️ Continuing with evaluation (database may be empty)\n") - # Step 3: Initialize clients + return True + + +# ============================================================================ +# Step 3: Initialize Clients +# ============================================================================ + +def step_initialize_clients(llm_id: str, embedding_id: str): + """ + Initialize Neo4j connector, LLM client, and embedder. 
+ + Args: + llm_id: LLM model ID + embedding_id: Embedding model ID + + Returns: + Tuple of (connector, llm_client, embedder) + """ print("🔧 Initializing clients...") + connector = Neo4jConnector() - # Initialize LLM client with database context - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client(SELECTED_LLM_ID) + # Get database session + from app.db import get_db + db = next(get_db()) + try: + llm_client = get_llm_client(llm_id, db) + cfg_dict = get_embedder_config(embedding_id, db) + embedder = OpenAIEmbedderClient( + model_config=RedBearModelConfig.model_validate(cfg_dict) + ) + finally: + db.close() - # Initialize embedder - with get_db_context() as db: - config_service = MemoryConfigService(db) - cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID) - embedder = OpenAIEmbedderClient( - model_config=RedBearModelConfig.model_validate(cfg_dict) - ) print("✅ Clients initialized\n") - - # Step 4: Process questions + return connector, llm_client, embedder + + +# ============================================================================ +# Step 4: Process Questions +# ============================================================================ + +async def step_process_all_questions( + qa_items: List[Dict[str, Any]], + end_user_id: str, + search_type: str, + search_limit: int, + context_char_budget: int, + connector: Neo4jConnector, + embedder: OpenAIEmbedderClient, + llm_client: Any +) -> List[Dict[str, Any]]: + """Process all QA items: retrieve, generate, and calculate metrics.""" print(f"🔍 Processing {len(qa_items)} questions...") print(f"{'='*60}\n") - # Tracking variables - latencies_search: List[float] = [] - latencies_llm: List[float] = [] - context_counts: List[int] = [] - context_chars: List[int] = [] - context_tokens: List[int] = [] - - # Metric lists - f1_scores: List[float] = [] - bleu1_scores: List[float] = [] - jaccard_scores: List[float] = [] - locomo_f1_scores: List[float] = [] - - 
# Per-category tracking - category_counts: Dict[str, int] = {} - category_f1: Dict[str, List[float]] = {} - category_bleu1: Dict[str, List[float]] = {} - category_jaccard: Dict[str, List[float]] = {} - category_locomo_f1: Dict[str, List[float]] = {} - - # Detailed samples samples: List[Dict[str, Any]] = [] - - # Fixed anchor date for temporal resolution anchor_date = datetime(2023, 5, 8) - try: - for idx, item in enumerate(qa_items, 1): - question = item.get("question", "") - ground_truth = item.get("answer", "") - category = get_category_name(item) - - # Ensure ground truth is a string - ground_truth_str = str(ground_truth) if ground_truth is not None else "" - - print(f"[{idx}/{len(qa_items)}] Category: {category}") - print(f"❓ Question: {question}") - print(f"✅ Ground Truth: {ground_truth_str}") - - # Step 4a: Retrieve relevant information - t_search_start = time.time() - try: - retrieved_info = await retrieve_relevant_information( - question=question, - end_user_id=end_user_id, - search_type=search_type, - search_limit=search_limit, - connector=connector, - embedder=embedder - ) - t_search_end = time.time() - search_latency = (t_search_end - t_search_start) * 1000 - latencies_search.append(search_latency) - - print(f"🔍 Retrieved {len(retrieved_info)} documents ({search_latency:.1f}ms)") - - except Exception as e: - print(f"❌ Retrieval failed: {e}") - retrieved_info = [] - search_latency = 0.0 - latencies_search.append(search_latency) - - # Step 4b: Select and format context - context_text = select_and_format_information( - retrieved_info=retrieved_info, + for idx, item in enumerate(qa_items, 1): + question = item.get("question", "") + ground_truth = item.get("answer", "") + category = get_category_name(item) + ground_truth_str = str(ground_truth) if ground_truth is not None else "" + + print(f"[{idx}/{len(qa_items)}] Category: {category}") + print(f"❓ Question: {question}") + print(f"✅ Ground Truth: {ground_truth_str}") + + # Retrieve + t_search_start = 
time.time() + try: + retrieved_info = await retrieve_relevant_information( question=question, - max_chars=context_char_budget + end_user_id=end_user_id, + search_type=search_type, + search_limit=search_limit, + connector=connector, + embedder=embedder ) - - # Resolve temporal references - context_text = resolve_temporal_references(context_text, anchor_date) - - # Add reference date to context - if context_text: - context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n{context_text}" + search_latency = (time.time() - t_search_start) * 1000 + print(f"🔍 Retrieved {len(retrieved_info)} documents ({search_latency:.1f}ms)") + except Exception as e: + print(f"❌ Retrieval failed: {e}") + retrieved_info = [] + search_latency = 0.0 + + # Format context + context_text = select_and_format_information( + retrieved_info=retrieved_info, + question=question, + max_chars=context_char_budget + ) + context_text = resolve_temporal_references(context_text, anchor_date) + if context_text: + context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n{context_text}" + else: + context_text = "No relevant context found." + + print(f"📝 Context: {len(context_text)} chars, {len(retrieved_info)} docs") + + # Generate answer + messages = [ + { + "role": "system", + "content": ( + "You are a precise QA assistant. 
Answer following these rules:\n" + "1) Extract the EXACT information mentioned in the context\n" + "2) For time questions: calculate actual dates from relative times\n" + "3) Return ONLY the answer text in simplest form\n" + "4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n" + "5) If no clear answer found, respond with 'Unknown'" + ) + }, + { + "role": "user", + "content": f"Question: {question}\n\nContext:\n{context_text}" + } + ] + + t_llm_start = time.time() + try: + response = await llm_client.chat(messages=messages) + llm_latency = (time.time() - t_llm_start) * 1000 + if hasattr(response, 'content'): + prediction = response.content.strip() + elif isinstance(response, dict): + prediction = response["choices"][0]["message"]["content"].strip() else: - context_text = "No relevant context found." - - # Track context statistics - context_counts.append(len(retrieved_info)) - context_chars.append(len(context_text)) - context_tokens.append(len(context_text.split())) - - print(f"📝 Context: {len(context_text)} chars, {len(retrieved_info)} docs") - - # Step 4c: Generate answer with LLM - messages = [ - { - "role": "system", - "content": ( - "You are a precise QA assistant. 
Answer following these rules:\n" - "1) Extract the EXACT information mentioned in the context\n" - "2) For time questions: calculate actual dates from relative times\n" - "3) Return ONLY the answer text in simplest form\n" - "4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n" - "5) If no clear answer found, respond with 'Unknown'" - ) - }, - { - "role": "user", - "content": f"Question: {question}\n\nContext:\n{context_text}" - } - ] - - t_llm_start = time.time() - try: - response = await llm_client.chat(messages=messages) - t_llm_end = time.time() - llm_latency = (t_llm_end - t_llm_start) * 1000 - latencies_llm.append(llm_latency) - - # Extract prediction from response - if hasattr(response, 'content'): - prediction = response.content.strip() - elif isinstance(response, dict): - prediction = response["choices"][0]["message"]["content"].strip() - else: - prediction = "Unknown" - - print(f"🤖 Prediction: {prediction} ({llm_latency:.1f}ms)") - - except Exception as e: - print(f"❌ LLM failed: {e}") prediction = "Unknown" - llm_latency = 0.0 - latencies_llm.append(llm_latency) - - # Step 4d: Calculate metrics - f1_val = f1_score(prediction, ground_truth_str) - bleu1_val = bleu1(prediction, ground_truth_str) - jaccard_val = jaccard(prediction, ground_truth_str) - - # LoCoMo-specific F1: use multi-answer for category 1 (Multi-Hop) - if item.get("category") == 1: - locomo_f1_val = locomo_multi_f1(prediction, ground_truth_str) - else: - locomo_f1_val = locomo_f1_score(prediction, ground_truth_str) - - # Accumulate metrics - f1_scores.append(f1_val) - bleu1_scores.append(bleu1_val) - jaccard_scores.append(jaccard_val) - locomo_f1_scores.append(locomo_f1_val) - - # Track by category - category_counts[category] = category_counts.get(category, 0) + 1 - category_f1.setdefault(category, []).append(f1_val) - category_bleu1.setdefault(category, []).append(bleu1_val) - category_jaccard.setdefault(category, []).append(jaccard_val) - 
category_locomo_f1.setdefault(category, []).append(locomo_f1_val) - - print(f"📊 Metrics - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, " - f"Jaccard: {jaccard_val:.3f}, LoCoMo F1: {locomo_f1_val:.3f}") - print() - - # Save sample details - samples.append({ - "question": question, - "ground_truth": ground_truth_str, - "prediction": prediction, - "category": category, - "metrics": { - "f1": f1_val, - "bleu1": bleu1_val, - "jaccard": jaccard_val, - "locomo_f1": locomo_f1_val - }, - "retrieval": { - "num_docs": len(retrieved_info), - "context_length": len(context_text) - }, - "timing": { - "search_ms": search_latency, - "llm_ms": llm_latency - } - }) + print(f"🤖 Prediction: {prediction} ({llm_latency:.1f}ms)") + except Exception as e: + print(f"❌ LLM failed: {e}") + prediction = "Unknown" + llm_latency = 0.0 + + # Calculate metrics + f1_val = f1_score(prediction, ground_truth_str) + bleu1_val = bleu1(prediction, ground_truth_str) + jaccard_val = jaccard(prediction, ground_truth_str) + if item.get("category") == 1: + locomo_f1_val = locomo_multi_f1(prediction, ground_truth_str) + else: + locomo_f1_val = locomo_f1_score(prediction, ground_truth_str) + + print(f"📊 Metrics - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, " + f"Jaccard: {jaccard_val:.3f}, LoCoMo F1: {locomo_f1_val:.3f}") + print() + + samples.append({ + "question": question, + "ground_truth": ground_truth_str, + "prediction": prediction, + "category": category, + "metrics": { + "f1": f1_val, + "bleu1": bleu1_val, + "jaccard": jaccard_val, + "locomo_f1": locomo_f1_val + }, + "retrieval": { + "num_docs": len(retrieved_info), + "context_length": len(context_text) + }, + "context_tokens": len(context_text.split()), + "timing": { + "search_ms": search_latency, + "llm_ms": llm_latency + } + }) - finally: - # Close connector - await connector.close() - - # Step 5: Aggregate results + return samples + + +# ============================================================================ +# Step 5: Aggregate Results +# 
============================================================================ + +def step_aggregate_results(samples: List[Dict[str, Any]]) -> Dict[str, Any]: + """Aggregate metrics from all samples.""" print(f"\n{'='*60}") print("📊 Aggregating Results") print(f"{'='*60}\n") + if not samples: + return { + "overall_metrics": {}, + "by_category": {}, + "latency": {}, + "context_stats": {} + } + + # Extract metrics + f1_scores = [s["metrics"]["f1"] for s in samples] + bleu1_scores = [s["metrics"]["bleu1"] for s in samples] + jaccard_scores = [s["metrics"]["jaccard"] for s in samples] + locomo_f1_scores = [s["metrics"]["locomo_f1"] for s in samples] + + # Extract timing + latencies_search = [s["timing"]["search_ms"] for s in samples] + latencies_llm = [s["timing"]["llm_ms"] for s in samples] + + # Extract context stats + context_counts = [s["retrieval"]["num_docs"] for s in samples] + context_chars = [s["retrieval"]["context_length"] for s in samples] + context_tokens = [s["context_tokens"] for s in samples] + # Overall metrics overall_metrics = { - "f1": sum(f1_scores) / max(len(f1_scores), 1) if f1_scores else 0.0, - "bleu1": sum(bleu1_scores) / max(len(bleu1_scores), 1) if bleu1_scores else 0.0, - "jaccard": sum(jaccard_scores) / max(len(jaccard_scores), 1) if jaccard_scores else 0.0, - "locomo_f1": sum(locomo_f1_scores) / max(len(locomo_f1_scores), 1) if locomo_f1_scores else 0.0 + "f1": sum(f1_scores) / len(f1_scores) if f1_scores else 0.0, + "bleu1": sum(bleu1_scores) / len(bleu1_scores) if bleu1_scores else 0.0, + "jaccard": sum(jaccard_scores) / len(jaccard_scores) if jaccard_scores else 0.0, + "locomo_f1": sum(locomo_f1_scores) / len(locomo_f1_scores) if locomo_f1_scores else 0.0 } # Per-category metrics + category_data: Dict[str, Dict[str, List[float]]] = {} + for sample in samples: + cat = sample["category"] + if cat not in category_data: + category_data[cat] = { + "f1": [], + "bleu1": [], + "jaccard": [], + "locomo_f1": [] + } + 
category_data[cat]["f1"].append(sample["metrics"]["f1"]) + category_data[cat]["bleu1"].append(sample["metrics"]["bleu1"]) + category_data[cat]["jaccard"].append(sample["metrics"]["jaccard"]) + category_data[cat]["locomo_f1"].append(sample["metrics"]["locomo_f1"]) + by_category: Dict[str, Dict[str, Any]] = {} - for cat in category_counts: - f1_list = category_f1.get(cat, []) - b1_list = category_bleu1.get(cat, []) - j_list = category_jaccard.get(cat, []) - lf_list = category_locomo_f1.get(cat, []) - + for cat, metrics_lists in category_data.items(): by_category[cat] = { - "count": category_counts[cat], - "f1": sum(f1_list) / max(len(f1_list), 1) if f1_list else 0.0, - "bleu1": sum(b1_list) / max(len(b1_list), 1) if b1_list else 0.0, - "jaccard": sum(j_list) / max(len(j_list), 1) if j_list else 0.0, - "locomo_f1": sum(lf_list) / max(len(lf_list), 1) if lf_list else 0.0 + "count": len(metrics_lists["f1"]), + "f1": sum(metrics_lists["f1"]) / len(metrics_lists["f1"]), + "bleu1": sum(metrics_lists["bleu1"]) / len(metrics_lists["bleu1"]), + "jaccard": sum(metrics_lists["jaccard"]) / len(metrics_lists["jaccard"]), + "locomo_f1": sum(metrics_lists["locomo_f1"]) / len(metrics_lists["locomo_f1"]) } # Latency statistics @@ -398,12 +420,181 @@ async def run_locomo_benchmark( # Context statistics context_stats = { - "avg_retrieved_docs": sum(context_counts) / max(len(context_counts), 1) if context_counts else 0.0, - "avg_context_chars": sum(context_chars) / max(len(context_chars), 1) if context_chars else 0.0, - "avg_context_tokens": sum(context_tokens) / max(len(context_tokens), 1) if context_tokens else 0.0 + "avg_retrieved_docs": sum(context_counts) / len(context_counts) if context_counts else 0.0, + "avg_context_chars": sum(context_chars) / len(context_chars) if context_chars else 0.0, + "avg_context_tokens": sum(context_tokens) / len(context_tokens) if context_tokens else 0.0 } - # Build result dictionary + return { + "overall_metrics": overall_metrics, + "by_category": 
by_category, + "latency": latency, + "context_stats": context_stats + } + + +# ============================================================================ +# Step 6: Result Saving +# ============================================================================ + +def step_save_results( + result: Dict[str, Any], + output_dir: Optional[str] +) -> str: + """ + Save evaluation results to JSON file. + + Args: + result: Complete result dictionary + output_dir: Directory to save results (uses default if None) + + Returns: + Path to saved file + """ + if output_dir is None: + # Use absolute path to ensure results are saved in the correct location + script_dir = Path(__file__).resolve().parent + output_dir = script_dir / "results" + else: + # Convert to Path object + output_dir = Path(output_dir) + # If relative path, make it relative to script directory + if not output_dir.is_absolute(): + script_dir = Path(__file__).resolve().parent + output_dir = script_dir / output_dir + + # Create directory if it doesn't exist + output_dir.mkdir(parents=True, exist_ok=True) + + timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S") + output_path = output_dir / f"locomo_{timestamp_str}.json" + + try: + with open(output_path, "w", encoding="utf-8") as f: + json.dump(result, f, ensure_ascii=False, indent=2) + print(f"✅ Results saved to: {output_path}\n") + return str(output_path) + except Exception as e: + print(f"❌ Failed to save results: {e}") + print("📊 Printing results to console instead:\n") + print(json.dumps(result, ensure_ascii=False, indent=2)) + return "" + + +# ============================================================================ +# Main Orchestration Function +# ============================================================================ + + +async def run_locomo_benchmark( + sample_size: int = 20, + end_user_id: Optional[str] = None, + search_type: str = "hybrid", + search_limit: int = 12, + context_char_budget: int = 8000, + reset_group: bool = False, + skip_ingest: 
bool = False, + output_dir: Optional[str] = None, + max_ingest_messages: Optional[int] = None +) -> Dict[str, Any]: + """ + Run LoCoMo benchmark evaluation. + + This function orchestrates the complete evaluation pipeline by calling + well-defined step functions: + 1. Load LoCoMo dataset (only QA pairs from first conversation) + 2. Ingest conversations into database (unless skip_ingest=True) + 3. Initialize clients (Neo4j, LLM, Embedder) + 4. Process all questions (retrieve, generate, calculate metrics) + 5. Aggregate results + 6. Save results to file + + Note: By default, only the first conversation is ingested into the database, + and only QA pairs from that conversation are evaluated. This ensures that + all questions have corresponding memory in the database for retrieval. + + Args: + sample_size: Number of QA pairs to evaluate (from first conversation) + end_user_id: Database end_user ID for retrieval (uses default if None) + search_type: "keyword", "embedding", or "hybrid" + search_limit: Max documents to retrieve per query + context_char_budget: Max characters for context + reset_group: Whether to clear and re-ingest data + skip_ingest: If True, skip data ingestion and use existing data in Neo4j + output_dir: Directory to save results (uses default if None) + max_ingest_messages: Max messages per dialogue to ingest (for testing, None = all) + + Returns: + Dictionary with evaluation results including metrics, timing, and samples + """ + # Use default end_user_id if not provided + # 优先级:命令行参数 > LOCOMO_END_USER_ID > EVAL_END_USER_ID > 默认值 + if end_user_id is None: + end_user_id = os.getenv("LOCOMO_END_USER_ID") or os.getenv("EVAL_END_USER_ID", "locomo_benchmark") + + # Get model IDs from config + llm_id = os.getenv("EVAL_LLM_ID", "6dc52e1b-9cec-4194-af66-a74c6307fc3f") + embedding_id = os.getenv("EVAL_EMBEDDING_ID", "e2a6392d-ca63-4d59-a523-647420b59cb2") + + # Determine data path + dataset_dir = Path(__file__).resolve().parent.parent / "dataset" + data_path = 
dataset_dir / "locomo10.json" + if not os.path.exists(data_path): + raise FileNotFoundError( + f"数据集文件不存在: {data_path}\n" + f"请将 locomo10.json 放置在: {dataset_dir}" + ) + + # Print configuration + print(f"\n{'='*60}") + print("🚀 Starting LoCoMo Benchmark Evaluation") + print(f"{'='*60}") + print("📊 Configuration:") + print(f" Sample size: {sample_size}") + print(f" End User ID: {end_user_id}") + print(f" Search type: {search_type}") + print(f" Search limit: {search_limit}") + print(f" Context budget: {context_char_budget} chars") + print(f" Data path: {data_path}") + if max_ingest_messages: + print(f" Max ingest messages: {max_ingest_messages} (testing mode)") + print(f"{'='*60}\n") + + # Step 1: Load LoCoMo data (加载数据) + try: + qa_items = step_load_data(data_path, sample_size) + except Exception as e: + print(f"❌ Failed to load data: {e}") + return { + "error": f"Data loading failed: {e}", + "timestamp": datetime.now().isoformat() + } + + # Step 2: Ingest data if needed(数据摄入) + await step_ingest_data(data_path, end_user_id, skip_ingest, reset_group, max_ingest_messages) + + # Step 3: Initialize clients (初始化客户端) + connector, llm_client, embedder = step_initialize_clients(llm_id, embedding_id) + + # Step 4: Process all questions (处理所有问题) + try: + samples = await step_process_all_questions( + qa_items=qa_items, + end_user_id=end_user_id, + search_type=search_type, + search_limit=search_limit, + context_char_budget=context_char_budget, + connector=connector, + embedder=embedder, + llm_client=llm_client + ) + finally: + await connector.close() + + # Step 5: Aggregate results (聚合答案) + aggregated = step_aggregate_results(samples) + + # Build final result dictionary result = { "dataset": "locomo", "sample_size": len(qa_items), @@ -413,37 +604,18 @@ async def run_locomo_benchmark( "search_type": search_type, "search_limit": search_limit, "context_char_budget": context_char_budget, - "llm_id": SELECTED_LLM_ID, - "embedding_id": SELECTED_EMBEDDING_ID + "llm_id": llm_id, + 
"embedding_id": embedding_id }, - "overall_metrics": overall_metrics, - "by_category": by_category, - "latency": latency, - "context_stats": context_stats, + "overall_metrics": aggregated["overall_metrics"], + "by_category": aggregated["by_category"], + "latency": aggregated["latency"], + "context_stats": aggregated["context_stats"], "samples": samples } - # Step 6: Save results - if output_dir is None: - output_dir = os.path.join( - os.path.dirname(__file__), - "results" - ) - - os.makedirs(output_dir, exist_ok=True) - - # Generate timestamped filename - timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S") - output_path = os.path.join(output_dir, f"locomo_{timestamp_str}.json") - - try: - with open(output_path, "w", encoding="utf-8") as f: - json.dump(result, f, ensure_ascii=False, indent=2) - print(f"✅ Results saved to: {output_path}\n") - except Exception as e: - print(f"❌ Failed to save results: {e}") - print("📊 Printing results to console instead:\n") - print(json.dumps(result, ensure_ascii=False, indent=2)) + # Step 6: Save results (保存结果) + step_save_results(result, output_dir) return result @@ -454,7 +626,25 @@ def main(): This function provides a CLI interface for running LoCoMo benchmarks with configurable parameters. 
+ + Configuration priority: Command-line args > Environment variables > Code defaults """ + # Load environment variables first + load_dotenv() + + # Get defaults from environment variables + env_sample_size = os.getenv("LOCOMO_SAMPLE_SIZE") + env_search_limit = os.getenv("LOCOMO_SEARCH_LIMIT") + env_context_budget = os.getenv("LOCOMO_CONTEXT_CHAR_BUDGET") + env_output_dir = os.getenv("LOCOMO_OUTPUT_DIR") + env_skip_ingest = os.getenv("LOCOMO_SKIP_INGEST", "false").lower() in ("true", "1", "yes") + + # Convert to appropriate types with fallback to code defaults + default_sample_size = int(env_sample_size) if env_sample_size else 20 + default_search_limit = int(env_search_limit) if env_search_limit else 12 + default_context_budget = int(env_context_budget) if env_context_budget else 8000 + default_output_dir = env_output_dir if env_output_dir else None + parser = argparse.ArgumentParser( description="Run LoCoMo benchmark evaluation", formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -463,14 +653,14 @@ def main(): parser.add_argument( "--sample_size", type=int, - default=20, - help="Number of QA pairs to evaluate" + default=default_sample_size, + help=f"Number of QA pairs to evaluate (env: LOCOMO_SAMPLE_SIZE={env_sample_size or 'not set'}, 0 for all)" ) parser.add_argument( "--end_user_id", type=str, default=None, - help="Database group ID for retrieval (uses default if not specified)" + help="Database end user ID for retrieval (uses LOCOMO_END_USER_ID or EVAL_END_USER_ID if not specified)" ) parser.add_argument( "--search_type", @@ -482,14 +672,14 @@ def main(): parser.add_argument( "--search_limit", type=int, - default=12, - help="Maximum number of documents to retrieve per query" + default=default_search_limit, + help=f"Maximum number of documents to retrieve per query (env: LOCOMO_SEARCH_LIMIT={env_search_limit or 'not set'})" ) parser.add_argument( "--context_char_budget", type=int, - default=8000, - help="Maximum characters for context" + 
default=default_context_budget, + help=f"Maximum characters for context (env: LOCOMO_CONTEXT_CHAR_BUDGET={env_context_budget or 'not set'})" ) parser.add_argument( "--reset_group", @@ -499,20 +689,24 @@ def main(): parser.add_argument( "--skip_ingest", action="store_true", - help="Skip data ingestion and use existing data in Neo4j" + default=env_skip_ingest, + help=f"Skip data ingestion and use existing data in Neo4j (env: LOCOMO_SKIP_INGEST={os.getenv('LOCOMO_SKIP_INGEST', 'false')})" ) parser.add_argument( "--output_dir", type=str, + default=default_output_dir, + help=f"Directory to save results (env: LOCOMO_OUTPUT_DIR={env_output_dir or 'not set'})" + ) + parser.add_argument( + "--max_ingest_messages", + type=int, default=None, - help="Directory to save results (uses default if not specified)" + help="Maximum messages per dialogue to ingest (for testing, default: all messages)" ) args = parser.parse_args() - # Load environment variables - load_dotenv() - # Run benchmark result = asyncio.run(run_locomo_benchmark( sample_size=args.sample_size, @@ -522,7 +716,8 @@ def main(): context_char_budget=args.context_char_budget, reset_group=args.reset_group, skip_ingest=args.skip_ingest, - output_dir=args.output_dir + output_dir=args.output_dir, + max_ingest_messages=args.max_ingest_messages )) # Print summary diff --git a/api/app/core/memory/evaluation/locomo/locomo_test.py b/api/app/core/memory/evaluation/locomo/locomo_test.py index 01c45123..2cb0664c 100644 --- a/api/app/core/memory/evaluation/locomo/locomo_test.py +++ b/api/app/core/memory/evaluation/locomo/locomo_test.py @@ -1,30 +1,29 @@ # file name: check_neo4j_connection_fixed.py import asyncio -import json -import math import os -import re import sys +import json import time +import math +import re from datetime import datetime, timedelta -from typing import Any, Dict, List +from typing import List, Dict, Any from pathlib import Path - from dotenv import load_dotenv -# 1 -# 添加项目根目录到路径 -current_dir = 
Path(__file__).resolve().parent -project_root = str(current_dir.parent) -if project_root not in sys.path: - sys.path.insert(0, project_root) -# 关键:将 src 目录置于最前,确保从当前仓库加载模块 -src_dir = os.path.join(project_root, "src") -if src_dir not in sys.path: - sys.path.insert(0, src_dir) - +# Load main .env load_dotenv() +# Load evaluation config +eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation" +if eval_config_path.exists(): + load_dotenv(eval_config_path, override=True) + print(f"✅ 加载评估配置: {eval_config_path}") + +# Get group_id from config +group_id = os.getenv("EVAL_GROUP_ID", "locomo_test") +print(f"✅ 使用配置的 group_id: {group_id}") + # 首先定义 _loc_normalize 函数,因为其他函数依赖它 def _loc_normalize(text: str) -> str: text = str(text) if text is not None else "" @@ -37,7 +36,7 @@ def _loc_normalize(text: str) -> str: # 尝试从 metrics.py 导入基础指标 try: - from common.metrics import bleu1, f1_score, jaccard + from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard print("✅ 从 metrics.py 导入基础指标成功") except ImportError as e: print(f"❌ 从 metrics.py 导入失败: {e}") @@ -107,23 +106,8 @@ except ImportError as e: # 尝试从 qwen_search_eval.py 导入 LoCoMo 特定指标 try: - # 添加 evaluation 目录路径 - evaluation_dir = os.path.join(project_root, "evaluation") - if evaluation_dir not in sys.path: - sys.path.insert(0, evaluation_dir) - - # 尝试从不同位置导入 - try: - from locomo.qwen_search_eval import ( - _resolve_relative_times, - loc_f1_score, - loc_multi_f1, - ) - print("✅ 从 locomo.qwen_search_eval 导入 LoCoMo 特定指标成功") - except ImportError: - from qwen_search_eval import _resolve_relative_times, loc_f1_score, loc_multi_f1 - print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功") - + from app.core.memory.evaluation.locomo.qwen_search_eval import loc_f1_score, loc_multi_f1, _resolve_relative_times + print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功") except ImportError as e: print(f"❌ 从 qwen_search_eval.py 导入失败: {e}") # 回退到本地实现 LoCoMo 特定函数 @@ -429,31 +413,36 @@ def enhanced_context_selection(contexts: 
List[str], question: str, question_inde async def run_enhanced_evaluation(): """使用增强方法进行完整评估 - 解决中间性能衰减问题""" - try: - from dotenv import load_dotenv - except Exception: - def load_dotenv(): - return None - + from dotenv import load_dotenv + from uuid import UUID + from datetime import datetime + from dataclasses import dataclass + # 修正导入路径:使用 app.core.memory.src 前缀 - from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient - from app.core.memory.utils.config.definitions import ( - SELECTED_EMBEDDING_ID, - SELECTED_LLM_ID, - ) - from app.core.memory.utils.llm.llm_utils import MemoryClientFactory - from app.core.models.base import RedBearModelConfig - from app.db import get_db_context - from app.repositories.neo4j.graph_search import search_graph_by_embedding from app.repositories.neo4j.neo4j_connector import Neo4jConnector + from app.repositories.neo4j.graph_search import search_graph_by_embedding + from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient + from app.core.models.base import RedBearModelConfig + from app.core.memory.utils.llm.llm_utils import get_llm_client + from app.core.memory.utils.config.config_utils import get_embedder_config + from app.schemas.memory_config_schema import MemoryConfig from app.services.memory_config_service import MemoryConfigService + + # Get model IDs from config + llm_id = os.getenv("EVAL_LLM_ID", "6dc52e1b-9cec-4194-af66-a74c6307fc3f") + embedding_id = os.getenv("EVAL_EMBEDDING_ID", "e2a6392d-ca63-4d59-a523-647420b59cb2") - # 加载数据 - # 获取项目根目录 - current_file = os.path.abspath(__file__) - evaluation_dir = os.path.dirname(os.path.dirname(current_file)) # evaluation目录 - memory_dir = os.path.dirname(evaluation_dir) # memory目录 - data_path = os.path.join(memory_dir, "data", "locomo10.json") + # 加载数据 - 使用统一的 dataset 目录 + data_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "dataset", "locomo10.json") + + if not os.path.exists(data_path): + raise FileNotFoundError( + f"数据集文件不存在: 
{data_path}\n" + f"请将 locomo10.json 放置在: api/app/core/memory/evaluation/dataset/" + ) + + print(f"✅ 找到数据文件: {data_path}") + with open(data_path, "r", encoding="utf-8") as f: raw = json.load(f) @@ -463,64 +452,109 @@ async def run_enhanced_evaluation(): qa_items.extend(entry.get("qa", [])) else: qa_items.extend(raw.get("qa", [])) - - items = qa_items[:20] # 测试多少个问题 + + # 测试多少个问题 - 可通过环境变量设置 + sample_size = int(os.getenv("LOCOMO_SAMPLE_SIZE", "20")) + items = qa_items[:sample_size] + print(f"📊 将测试 {len(items)} 个问题(总共 {len(qa_items)} 个可用)") # 初始化增强监控器 monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6) - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm = factory.get_llm_client(SELECTED_LLM_ID) + # 获取数据库会话并初始化 LLM 客户端 + from app.db import get_db + db = next(get_db()) - # 初始化embedder - with get_db_context() as db: - config_service = MemoryConfigService(db) - cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID) - embedder = OpenAIEmbedderClient( - model_config=RedBearModelConfig.model_validate(cfg_dict) - ) - - # 初始化连接器 - connector = Neo4jConnector() - - # 初始化结果字典 - results = { - "questions": [], - "overall_metrics": {"f1": 0.0, "b1": 0.0, "j": 0.0, "loc_f1": 0.0}, - "category_metrics": {}, - "retrieval_stats": {"total_questions": len(items), "avg_context_length": 0, "avg_retrieved_docs": 0}, - "performance_trend": "stable", - "timestamp": datetime.now().isoformat(), - "enhanced_strategy": True - } - - total_f1 = 0.0 - total_bleu1 = 0.0 - total_jaccard = 0.0 - total_loc_f1 = 0.0 - total_context_length = 0 - total_retrieved_docs = 0 - category_stats = {} - try: - for i, item in enumerate(items): - monitor.question_count += 1 + llm = get_llm_client(llm_id, db) + + # 初始化embedder + cfg_dict = get_embedder_config(embedding_id, db) + embedder = OpenAIEmbedderClient( + model_config=RedBearModelConfig.model_validate(cfg_dict) + ) + + # 🔧 创建 MemoryConfig 对象用于搜索 + # 方案1:如果有配置ID,从数据库加载 + config_id = 
os.getenv("EVAL_CONFIG_ID") + if config_id: + print(f"📋 从数据库加载配置 ID: {config_id}") + memory_config_service = MemoryConfigService(db) + memory_config = memory_config_service.load_memory_config(config_id, service_name="locomo_test") + else: + # 方案2:创建临时配置对象用于测试 + print(f"📋 创建临时测试配置") + from uuid import UUID + from datetime import datetime + + # 将字符串 ID 转换为 UUID + try: + embedding_uuid = UUID(embedding_id) + llm_uuid = UUID(llm_id) + except ValueError as e: + raise ValueError(f"无效的 UUID 格式: {e}") + + memory_config = MemoryConfig( + config_id=1, # 临时 ID + config_name="locomo_test_config", + workspace_id=UUID("00000000-0000-0000-0000-000000000000"), # 临时 workspace + workspace_name="test_workspace", + tenant_id=UUID("00000000-0000-0000-0000-000000000000"), # 临时 tenant + embedding_model_id=embedding_uuid, + embedding_model_name="test_embedding", + llm_model_id=llm_uuid, + llm_model_name="test_llm", + storage_type="neo4j", + chunker_strategy="RecursiveChunker", + reflexion_enabled=False, + reflexion_iteration_period=3, + reflexion_range="partial", + reflexion_baseline="Time", + loaded_at=datetime.now() + ) + + print(f"✅ MemoryConfig 已准备: embedding_id={memory_config.embedding_model_id}, llm_id={memory_config.llm_model_id}") + + # 初始化连接器 + connector = Neo4jConnector() - # 获取近期性能用于重置判断 - recent_performance = monitor.get_recent_performance() + # 初始化结果字典 + results = { + "questions": [], + "overall_metrics": {"f1": 0.0, "b1": 0.0, "j": 0.0, "loc_f1": 0.0}, + "category_metrics": {}, + "retrieval_stats": {"total_questions": len(items), "avg_context_length": 0, "avg_retrieved_docs": 0}, + "performance_trend": "stable", + "timestamp": datetime.now().isoformat(), + "enhanced_strategy": True + } - # 增强的重置判断 - should_reset = monitor.should_reset_connections(current_f1=recent_performance) - if should_reset and i > 0: - print(f"🔄 重置Neo4j连接 (问题 {i+1}/{len(items)}, 近期性能: {recent_performance:.3f})...") - await connector.close() - connector = Neo4jConnector() # 创建新连接 - print("✅ 连接重置完成") + 
total_f1 = 0.0 + total_bleu1 = 0.0 + total_jaccard = 0.0 + total_loc_f1 = 0.0 + total_context_length = 0 + total_retrieved_docs = 0 + category_stats = {} - q = item.get("question", "") - ref = item.get("answer", "") - ref_str = str(ref) if ref is not None else "" + try: + for i, item in enumerate(items): + monitor.question_count += 1 + + # 获取近期性能用于重置判断 + recent_performance = monitor.get_recent_performance() + + # 增强的重置判断 + should_reset = monitor.should_reset_connections(current_f1=recent_performance) + if should_reset and i > 0: + print(f"🔄 重置Neo4j连接 (问题 {i+1}/{len(items)}, 近期性能: {recent_performance:.3f})...") + await connector.close() + connector = Neo4jConnector() # 创建新连接 + print("✅ 连接重置完成") + + q = item.get("question", "") + ref = item.get("answer", "") + ref_str = str(ref) if ref is not None else "" print(f"\n🔍 [{i+1}/{len(items)}] 问题: {q}") print(f"✅ 真实答案: {ref_str}") @@ -548,10 +582,12 @@ async def run_enhanced_evaluation(): contexts_all = [] try: - # 使用统一的搜索服务 - from app.core.memory.storage_services.search import run_hybrid_search + # 使用旧版本的搜索服务(重构前的版本) + from app.core.memory.src.search import run_hybrid_search - print("🔀 使用混合搜索服务...") + print(f"🔀 使用混合搜索服务(旧版本)...") + print(f"📍 检索参数: group_id={group_id}, limit=20, search_type=hybrid") + print(f"📍 查询文本: {q}") search_results = await run_hybrid_search( query_text=q, @@ -559,15 +595,27 @@ async def run_enhanced_evaluation(): end_user_id="locomo_sk", limit=20, include=["statements", "chunks", "entities", "summaries"], - alpha=0.6, # BM25权重 - embedding_id=SELECTED_EMBEDDING_ID + output_path=None, + memory_config=memory_config, # 🔧 添加必需的 memory_config 参数 + rerank_alpha=0.6, # BM25权重 + use_forgetting_rerank=False, + use_llm_rerank=False ) - # 处理搜索结果 - 新的搜索服务返回统一的结构 - chunks = search_results.get("chunks", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - summaries = search_results.get("summaries", []) + # 处理搜索结果 - 旧版本返回包含 reranked_results 的结构 + # 对于 hybrid 搜索,使用 
reranked_results + if "reranked_results" in search_results: + reranked = search_results["reranked_results"] + chunks = reranked.get("chunks", []) + statements = reranked.get("statements", []) + entities = reranked.get("entities", []) + summaries = reranked.get("summaries", []) + else: + # 单一搜索类型的结果 + chunks = search_results.get("chunks", []) + statements = search_results.get("statements", []) + entities = search_results.get("entities", []) + summaries = search_results.get("summaries", []) print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要") @@ -609,6 +657,8 @@ async def run_enhanced_evaluation(): print(f"📊 有效上下文数量: {len(contexts_all)}") except Exception as e: print(f"❌ 检索失败: {e}") + import traceback + print(f"详细错误信息:\n{traceback.format_exc()}") contexts_all = [] t1 = time.time() @@ -728,14 +778,17 @@ async def run_enhanced_evaluation(): print("="*60) - except Exception as e: - print(f"❌ 评估过程中发生错误: {e}") - # 即使出错,也返回已有的结果 - import traceback - traceback.print_exc() + except Exception as e: + print(f"❌ 评估过程中发生错误: {e}") + # 即使出错,也返回已有的结果 + import traceback + traceback.print_exc() + finally: + await connector.close() + finally: - await connector.close() + db.close() # 关闭数据库会话 # 计算总体指标 n = len(items) diff --git a/api/app/core/memory/evaluation/locomo/locomo_utils.py b/api/app/core/memory/evaluation/locomo/locomo_utils.py index d3b74947..6ad68470 100644 --- a/api/app/core/memory/evaluation/locomo/locomo_utils.py +++ b/api/app/core/memory/evaluation/locomo/locomo_utils.py @@ -15,8 +15,14 @@ import json import re from datetime import datetime, timedelta from typing import List, Dict, Any, Optional +from pathlib import Path +from dotenv import load_dotenv + +# Load evaluation config +eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation" +if eval_config_path.exists(): + load_dotenv(eval_config_path, override=True) -from app.core.memory.utils.definitions import PROJECT_ROOT from 
app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline @@ -82,7 +88,7 @@ def load_locomo_data( return qa_items[:sample_size] -def extract_conversations(data_path: str, max_dialogues: int = 1) -> List[str]: +def extract_conversations(data_path: str, max_dialogues: int = 1, max_messages_per_dialogue: Optional[int] = None) -> List[str]: """ Extract conversation texts from LoCoMo data for ingestion. @@ -93,6 +99,7 @@ def extract_conversations(data_path: str, max_dialogues: int = 1) -> List[str]: Args: data_path: Path to locomo10.json file max_dialogues: Maximum number of dialogues to extract (default: 1) + max_messages_per_dialogue: Maximum messages per dialogue (default: None = all messages) Returns: List of conversation strings formatted for ingestion. @@ -141,13 +148,21 @@ def extract_conversations(data_path: str, max_dialogues: int = 1) -> List[str]: continue lines.append(f"{role}: {text}") + + # Limit messages if specified + if max_messages_per_dialogue and len(lines) >= max_messages_per_dialogue: + break + + # Break outer loop if we've reached the message limit + if max_messages_per_dialogue and len(lines) >= max_messages_per_dialogue: + break if lines: contents.append("\n".join(lines)) return contents - +# 时间解析:将相对时间表达转换为绝对日期 def resolve_temporal_references(text: str, anchor_date: datetime) -> str: """ Resolve relative temporal references to absolute dates. @@ -225,6 +240,8 @@ def resolve_temporal_references(text: str, anchor_date: datetime) -> str: t, flags=re.IGNORECASE ) + + # 中文支持 t = re.sub( r"\bnext\s+week\b", (anchor_date + timedelta(days=7)).date().isoformat(), @@ -345,6 +362,50 @@ def select_and_format_information( return "\n\n".join(selected) +# 记忆系统核心能力:写入与读取 +async def ingest_conversations_if_needed( + conversations: List[str], + end_user_id: str, + reset: bool = False +) -> bool: + """ + Wrapper for conversation ingestion using external extraction pipeline. 
+ + This function populates the Neo4j database with processed conversation data + (chunks, statements, entities) so that the retrieval system has memory to search. + + The ingestion process: + 1. Parses conversation text into dialogue messages + 2. Chunks the dialogues into semantic units + 3. Extracts statements and entities using LLM + 4. Generates embeddings for all content + 5. Stores everything in Neo4j graph database + + Args: + conversations: List of raw conversation texts from LoCoMo dataset + Example: ["User: I went to Paris. AI: When was that?", ...] + end_user_id: Target end_user ID for database storage + reset: Whether to clear existing data first (not implemented in wrapper) + + Returns: + True if successful, False otherwise + + Note: + The external function uses "contexts" to mean "conversation texts". + This runs the full extraction pipeline: chunking → entity extraction → + statement extraction → embedding → Neo4j storage. + """ + try: + success = await ingest_contexts_via_full_pipeline( + contexts=conversations, + end_user_id=end_user_id, + save_chunk_output=True, + reset_group=reset + ) + return success + except Exception as e: + print(f"[Ingestion] Failed to ingest conversations: {e}") + return False async def retrieve_relevant_information( question: str, @@ -385,7 +446,7 @@ async def retrieve_relevant_information( search_graph, search_graph_by_embedding ) - from app.core.memory.storage_services.search import run_hybrid_search + from app.core.memory.src.search import run_hybrid_search contexts_all: List[str] = [] diff --git a/api/app/core/memory/evaluation/locomo/qwen_search_eval.py b/api/app/core/memory/evaluation/locomo/qwen_search_eval.py index 6a5caa0c..889c5065 100644 --- a/api/app/core/memory/evaluation/locomo/qwen_search_eval.py +++ b/api/app/core/memory/evaluation/locomo/qwen_search_eval.py @@ -2,43 +2,29 @@ import argparse import asyncio import json import os -import statistics import time from datetime import datetime, timedelta -from 
typing import Any, Dict, List - -try: - from dotenv import load_dotenv -except Exception: - def load_dotenv(): - return None - +from typing import List, Dict, Any +import statistics import re +from pathlib import Path +from dotenv import load_dotenv + +# Load evaluation config +eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation" +if eval_config_path.exists(): + load_dotenv(eval_config_path, override=True) + print(f"✅ 加载评估配置: {eval_config_path}") -from app.core.memory.evaluation.common.metrics import ( - avg_context_tokens, - bleu1, - jaccard, - latency_stats, -) -from app.core.memory.evaluation.common.metrics import f1_score as common_f1 -from app.core.memory.evaluation.extraction_utils import ( - ingest_contexts_via_full_pipeline, -) -from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.memory.storage_services.search import run_hybrid_search -from app.core.memory.utils.config.definitions import ( - PROJECT_ROOT, - SELECTED_EMBEDDING_ID, - SELECTED_GROUP_ID, - SELECTED_LLM_ID, -) -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.core.models.base import RedBearModelConfig -from app.db import get_db_context -from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.services.memory_config_service import MemoryConfigService +from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding +from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient +from app.core.models.base import RedBearModelConfig +from app.core.memory.utils.config.config_utils import get_embedder_config +from app.core.memory.src.search import run_hybrid_search # 使用旧版本(重构前) +from app.core.memory.utils.llm.llm_utils import get_llm_client +from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline +from 
app.core.memory.evaluation.common.metrics import f1_score as common_f1, bleu1, jaccard, latency_stats, avg_context_tokens # 参考 evaluation/locomo/evaluation.py 的 F1 计算逻辑(移除外部依赖,内联实现) @@ -265,7 +251,10 @@ async def run_locomo_eval( end_user_id = end_user_id or SELECTED_end_user_id data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json") if not os.path.exists(data_path): - data_path = os.path.join(os.getcwd(), "data", "locomo10.json") + raise FileNotFoundError( + f"数据集文件不存在: {data_path}\n" + f"请将 locomo10.json 放置在: {dataset_dir}" + ) with open(data_path, "r", encoding="utf-8") as f: raw = json.load(f) # LoCoMo 数据结构:顶层为若干对象,每个对象下有 qa 列表 @@ -343,13 +332,9 @@ async def run_locomo_eval( await ingest_contexts_via_full_pipeline(contents, end_user_id, save_chunk_output=True) # 使用异步LLM客户端 - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client(SELECTED_LLM_ID) + llm_client = get_llm_client(llm_id) # 初始化embedder用于直接调用 - with get_db_context() as db: - config_service = MemoryConfigService(db) - cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID) + cfg_dict = get_embedder_config(embedding_id) embedder = OpenAIEmbedderClient( model_config=RedBearModelConfig.model_validate(cfg_dict) ) @@ -480,8 +465,8 @@ async def run_locomo_eval( contexts_all.append(f"EntitySummary: {', '.join(entity_names)}") else: # hybrid - # 🎯 关键修复:混合检索使用更严格的回退机制 - print("🔀 使用混合检索(带回退机制)...") + # 使用旧版本的混合检索(重构前) + print("🔀 使用混合检索(旧版本)...") try: search_results = await run_hybrid_search( query_text=q, @@ -490,16 +475,26 @@ async def run_locomo_eval( limit=adjusted_limit, include=["chunks", "statements", "entities", "summaries"], output_path=None, + rerank_alpha=0.6, + use_forgetting_rerank=False, + use_llm_rerank=False ) - # 🎯 关键修复:正确处理混合检索的扁平结构 - # 新的API返回扁平结构,直接从顶层获取结果 + # 处理旧版本的返回结构(包含 reranked_results) if search_results and isinstance(search_results, dict): - # 新API返回扁平结构:直接从顶层获取 - chunks = search_results.get("chunks", []) - statements 
= search_results.get("statements", []) - entities = search_results.get("entities", []) - summaries = search_results.get("summaries", []) + # 对于 hybrid 搜索,使用 reranked_results + if "reranked_results" in search_results: + reranked = search_results["reranked_results"] + chunks = reranked.get("chunks", []) + statements = reranked.get("statements", []) + entities = reranked.get("entities", []) + summaries = reranked.get("summaries", []) + else: + # 单一搜索类型的结果 + chunks = search_results.get("chunks", []) + statements = search_results.get("statements", []) + entities = search_results.get("entities", []) + summaries = search_results.get("summaries", []) # 检查是否有有效结果 if chunks or statements or entities or summaries: @@ -799,8 +794,9 @@ async def run_locomo_eval( "search_limit": search_limit, "context_char_budget": context_char_budget, "search_type": search_type, - "llm_id": SELECTED_LLM_ID, - "retrieval_embedding_id": SELECTED_EMBEDDING_ID, + "llm_id": llm_id, + "retrieval_embedding_id": embedding_id, + "chunker_strategy": os.getenv("EVAL_CHUNKER_STRATEGY", "RecursiveChunker"), "skip_ingest_if_exists": skip_ingest_if_exists, "llm_timeout": llm_timeout, "llm_max_retries": llm_max_retries, diff --git a/api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py b/api/app/core/memory/evaluation/longmemeval/longmemeval_benchmark.py similarity index 93% rename from api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py rename to api/app/core/memory/evaluation/longmemeval/longmemeval_benchmark.py index 8710a504..aaf46e35 100644 --- a/api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py +++ b/api/app/core/memory/evaluation/longmemeval/longmemeval_benchmark.py @@ -2,100 +2,67 @@ import argparse import asyncio import json import os +import time import re import statistics -import time from datetime import datetime, timedelta -from typing import Any, Dict, List - -try: - from dotenv import load_dotenv -except Exception: - def load_dotenv(): - return None - -# 确保可以找到 
src 及项目根路径 -import sys +from typing import List, Dict, Any from pathlib import Path -_THIS_DIR = Path(__file__).resolve().parent -_PROJECT_ROOT = str(_THIS_DIR.parents[2]) -_SRC_DIR = os.path.join(_PROJECT_ROOT, "src") -for _p in (_SRC_DIR, _PROJECT_ROOT): - if _p not in sys.path: - sys.path.insert(0, _p) +from dotenv import load_dotenv + +# Load evaluation config +eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation" +if eval_config_path.exists(): + load_dotenv(eval_config_path, override=True) + print(f"✅ 加载评估配置: {eval_config_path}") -# 与现有评估脚本保持一致的导入方式 from app.repositories.neo4j.neo4j_connector import Neo4jConnector - -try: - # 优先从 extraction_utils1 导入 - from app.core.memory.evaluation.extraction_utils import ( - ingest_contexts_via_full_pipeline, # type: ignore - ) -except Exception: - ingest_contexts_via_full_pipeline = None # 在运行时做兜底检查 -from app.core.memory.evaluation.common.metrics import ( - avg_context_tokens, - jaccard, - latency_stats, -) -from app.core.memory.evaluation.common.metrics import f1_score as common_f1 -from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME -from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.memory.utils.config.definitions import ( - PROJECT_ROOT, - SELECTED_EMBEDDING_ID, - SELECTED_LLM_ID, -) -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.core.models.base import RedBearModelConfig -from app.db import get_db_context +from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding -from app.services.memory_config_service import MemoryConfigService - -try: - from app.core.memory.evaluation.common.metrics import exact_match -except Exception: - # 兜底:简单的大小写不敏感比较 - def exact_match(pred: str, ref: str) -> bool: - return str(pred).strip().lower() == str(ref).strip().lower() +from 
app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient +from app.core.models.base import RedBearModelConfig +from app.core.memory.utils.config.config_utils import get_embedder_config +from app.core.memory.utils.llm.llm_utils import get_llm_client +from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME +from app.core.memory.evaluation.common.metrics import f1_score as common_f1, jaccard, latency_stats, avg_context_tokens +from app.core.memory.evaluation.common.metrics import exact_match def load_dataset_any(path: str) -> List[Dict[str, Any]]: - """健壮地加载数据集(兼容 list 或多段 JSON)。""" + """健壮地加载数据集,支持三种格式: + 1. 标准 JSON 数组: [{...}, {...}] + 2. 单个 JSON 对象: {...} + 3. JSONL 格式(每行一个 JSON): {...}\n{...}\n{...} + """ with open(path, "r", encoding="utf-8") as f: - s = f.read().strip() + content = f.read().strip() + + # 尝试标准 JSON 解析 try: - obj = json.loads(s) - if isinstance(obj, list): - return obj - elif isinstance(obj, dict): - return [obj] + data = json.loads(content) + if isinstance(data, list): + return [item for item in data if isinstance(item, dict)] + elif isinstance(data, dict): + return [data] except json.JSONDecodeError: pass - dec = json.JSONDecoder() - idx = 0 - items: List[Dict[str, Any]] = [] - while idx < len(s): - while idx < len(s) and s[idx].isspace(): - idx += 1 - if idx >= len(s): - break + + # 尝试 JSONL 格式(每行一个 JSON 对象) + items = [] + for line in content.splitlines(): + line = line.strip() + if not line: + continue try: - obj, end = dec.raw_decode(s, idx) - if isinstance(obj, list): - for it in obj: - if isinstance(it, dict): - items.append(it) - elif isinstance(obj, dict): + obj = json.loads(line) + if isinstance(obj, dict): items.append(obj) - idx = end + elif isinstance(obj, list): + items.extend(item for item in obj if isinstance(item, dict)) except json.JSONDecodeError: - nl = s.find("\n", idx) - if nl == -1: - break - idx = nl + 1 + continue + return items @@ -624,7 +591,7 @@ def 
_resolve_relative_times_cn_en(text: str, anchor: datetime) -> str: async def run_longmemeval_test( sample_size: int = 3, - end_user_id: str = "longmemeval_zh_bak_3", + end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, @@ -639,18 +606,22 @@ async def run_longmemeval_test( skip_ingest: bool = False, ) -> Dict[str, Any]: """LongMemEval 评估测试:增强时间推理能力""" + + # Use environment variable with fallback chain + if end_user_id is None: + end_user_id = os.getenv("LONGMEMEVAL_END_USER_ID") or os.getenv("EVAL_END_USER_ID", "longmemeval_zh_bak_3") # 数据路径 if not data_path: - # 固定使用中文数据集:data/longmemeval_oracle_zh.json - zh_proj = os.path.join(PROJECT_ROOT, "data", "longmemeval_oracle_zh.json") - zh_cwd = os.path.join(os.getcwd(), "data", "longmemeval_oracle_zh.json") - if os.path.exists(zh_proj): - data_path = zh_proj - elif os.path.exists(zh_cwd): - data_path = zh_cwd - else: - raise FileNotFoundError("未找到数据集: data/longmemeval_oracle_zh.json,请确保其存在于项目根目录或当前工作目录的 data 目录下。") + # 固定使用中文数据集:dataset/longmemeval_oracle_zh.json + dataset_dir = Path(__file__).resolve().parent.parent / "dataset" + data_path = str(dataset_dir / "longmemeval_oracle_zh.json") + + if not os.path.exists(data_path): + raise FileNotFoundError( + f"数据集文件不存在: {data_path}\n" + f"请将 longmemeval_oracle_zh.json 放置在: {dataset_dir}" + ) qa_list: List[Dict[str, Any]] = load_dataset_any(data_path) # 支持评估全部样本:当 sample_size <= 0 时,取从 start_index 到末尾 @@ -702,16 +673,19 @@ async def run_longmemeval_test( ) # 初始化组件(摄入后再初始化连接器)- 使用异步LLM客户端 - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client(SELECTED_LLM_ID) + from app.db import get_db + + db = next(get_db()) + try: + llm_client = get_llm_client(os.getenv("EVAL_LLM_ID"), db) + cfg_dict = get_embedder_config(os.getenv("EVAL_EMBEDDING_ID"), db) + embedder = OpenAIEmbedderClient( + model_config=RedBearModelConfig.model_validate(cfg_dict) + ) + finally: + 
db.close() + connector = Neo4jConnector() - with get_db_context() as db: - config_service = MemoryConfigService(db) - cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID) - embedder = OpenAIEmbedderClient( - model_config=RedBearModelConfig.model_validate(cfg_dict) - ) # 指标收集 latencies_llm: List[float] = [] @@ -768,10 +742,10 @@ async def run_longmemeval_test( if stmt_text: contexts_all.append(stmt_text) - # for sm in summaries: - # summary_text = str(sm.get("summary", "")).strip() - # if summary_text: - # contexts_all.append(summary_text) + for sm in summaries: + summary_text = str(sm.get("summary", "")).strip() + if summary_text: + contexts_all.append(summary_text) # 实体摘要(最多3个) scored = [e for e in entities if e.get("score") is not None] @@ -1228,8 +1202,8 @@ async def run_longmemeval_test( "search_limit": search_limit, "context_char_budget": context_char_budget, "search_type": search_type, - "llm_id": SELECTED_LLM_ID, - "embedding_id": SELECTED_EMBEDDING_ID, + "llm_id": os.getenv("EVAL_LLM_ID"), + "embedding_id": os.getenv("EVAL_EMBEDDING_ID"), "sample_size": sample_size, "start_index": start_index, }, @@ -1288,7 +1262,7 @@ def main(): parser.add_argument("--sample-size", type=int, default=3, help="样本数量(<=0 表示全部)") parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size)") parser.add_argument("--start-index", type=int, default=0, help="起始样本索引") - parser.add_argument("--group-id", type=str, default="longmemeval_zh_bak_3", help="图数据库 Group ID") + parser.add_argument("--end-user-id", type=str, default=None, help="图数据库 End User ID,默认使用环境变量") parser.add_argument("--search-limit", type=int, default=8, help="检索条数上限") parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算") parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度") @@ -1349,7 +1323,8 @@ def main(): # 保存结果到文件 try: - out_dir = os.path.join(PROJECT_ROOT, "evaluation", "longmemeval", "results") + # 使用相对路径而不是 
PROJECT_ROOT + out_dir = Path(__file__).resolve().parent / "results" os.makedirs(out_dir, exist_ok=True) ts = datetime.now().strftime("%Y%m%d_%H%M%S") out_path = os.path.join(out_dir, f"longmemeval_{result['params']['search_type']}_{ts}.json") diff --git a/api/app/core/memory/evaluation/longmemeval/test_eval.py b/api/app/core/memory/evaluation/longmemeval/test_eval.py index 67bd6ec2..08daa890 100644 --- a/api/app/core/memory/evaluation/longmemeval/test_eval.py +++ b/api/app/core/memory/evaluation/longmemeval/test_eval.py @@ -2,81 +2,67 @@ import argparse import asyncio import json import os +import time import re import statistics -import time from datetime import datetime, timedelta -from typing import Any, Dict, List +from typing import List, Dict, Any +from pathlib import Path -try: - from dotenv import load_dotenv -except Exception: - def load_dotenv(): - return None +from dotenv import load_dotenv + +# Load evaluation config +eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation" +if eval_config_path.exists(): + load_dotenv(eval_config_path, override=True) + print(f"✅ 加载评估配置: {eval_config_path}") # 与现有评估脚本保持一致的导入方式 -from app.core.memory.evaluation.common.metrics import ( - avg_context_tokens, - jaccard, - latency_stats, -) -from app.core.memory.evaluation.common.metrics import f1_score as common_f1 -from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME -from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.memory.utils.config.definitions import ( - PROJECT_ROOT, - SELECTED_EMBEDDING_ID, - SELECTED_LLM_ID, -) -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.core.models.base import RedBearModelConfig -from app.db import get_db_context -from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.services.memory_config_service import 
MemoryConfigService - -try: - from app.core.memory.evaluation.common.metrics import exact_match -except Exception: - # 兜底:简单的大小写不敏感比较 - def exact_match(pred: str, ref: str) -> bool: - return str(pred).strip().lower() == str(ref).strip().lower() +from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding +from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient +from app.core.models.base import RedBearModelConfig +from app.core.memory.utils.config.config_utils import get_embedder_config +from app.core.memory.utils.llm.llm_utils import get_llm_client +from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME +from app.core.memory.evaluation.common.metrics import f1_score as common_f1, jaccard, latency_stats, avg_context_tokens +from app.core.memory.evaluation.common.metrics import exact_match def load_dataset_any(path: str) -> List[Dict[str, Any]]: - """健壮地加载数据集(兼容 list 或多段 JSON)。""" + """健壮地加载数据集,支持三种格式: + 1. 标准 JSON 数组: [{...}, {...}] + 2. 单个 JSON 对象: {...} + 3. 
JSONL 格式(每行一个 JSON): {...}\n{...}\n{...} + """ with open(path, "r", encoding="utf-8") as f: - s = f.read().strip() + content = f.read().strip() + + # 尝试标准 JSON 解析 try: - obj = json.loads(s) - if isinstance(obj, list): - return obj - elif isinstance(obj, dict): - return [obj] + data = json.loads(content) + if isinstance(data, list): + return [item for item in data if isinstance(item, dict)] + elif isinstance(data, dict): + return [data] except json.JSONDecodeError: pass - dec = json.JSONDecoder() - idx = 0 - items: List[Dict[str, Any]] = [] - while idx < len(s): - while idx < len(s) and s[idx].isspace(): - idx += 1 - if idx >= len(s): - break + + # 尝试 JSONL 格式(每行一个 JSON 对象) + items = [] + for line in content.splitlines(): + line = line.strip() + if not line: + continue try: - obj, end = dec.raw_decode(s, idx) - if isinstance(obj, list): - for it in obj: - if isinstance(it, dict): - items.append(it) - elif isinstance(obj, dict): + obj = json.loads(line) + if isinstance(obj, dict): items.append(obj) - idx = end + elif isinstance(obj, list): + items.extend(item for item in obj if isinstance(item, dict)) except json.JSONDecodeError: - nl = s.find("\n", idx) - if nl == -1: - break - idx = nl + 1 + continue + return items @@ -640,15 +626,15 @@ async def run_longmemeval_test( # 数据路径 if not data_path: - # 固定使用中文数据集:data/longmemeval_oracle_zh.json - zh_proj = os.path.join(PROJECT_ROOT, "data", "longmemeval_oracle_zh.json") - zh_cwd = os.path.join(os.getcwd(), "data", "longmemeval_oracle_zh.json") - if os.path.exists(zh_proj): - data_path = zh_proj - elif os.path.exists(zh_cwd): - data_path = zh_cwd - else: - raise FileNotFoundError("未找到数据集: data/longmemeval_oracle_zh.json,请确保其存在于项目根目录或当前工作目录的 data 目录下。") + # 固定使用中文数据集:dataset/longmemeval_oracle_zh.json + dataset_dir = Path(__file__).resolve().parent.parent / "dataset" + data_path = str(dataset_dir / "longmemeval_oracle_zh.json") + + if not os.path.exists(data_path): + raise FileNotFoundError( + f"数据集文件不存在: {data_path}\n" + 
f"请将 longmemeval_oracle_zh.json 放置在: {dataset_dir}" + ) qa_list: List[Dict[str, Any]] = load_dataset_any(data_path) # 支持评估全部样本:当 sample_size <= 0 时,取从 start_index 到末尾 @@ -658,13 +644,9 @@ async def run_longmemeval_test( items = qa_list[start_index:start_index + sample_size] # 初始化组件 - 使用异步LLM客户端 - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client(SELECTED_LLM_ID) + llm_client = get_llm_client(os.getenv("EVAL_LLM_ID")) connector = Neo4jConnector() - with get_db_context() as db: - config_service = MemoryConfigService(db) - cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID) + cfg_dict = get_embedder_config(os.getenv("EVAL_EMBEDDING_ID")) embedder = OpenAIEmbedderClient( model_config=RedBearModelConfig.model_validate(cfg_dict) ) @@ -1203,8 +1185,8 @@ async def run_longmemeval_test( "search_limit": search_limit, "context_char_budget": context_char_budget, "search_type": search_type, - "llm_id": SELECTED_LLM_ID, - "embedding_id": SELECTED_EMBEDDING_ID, + "llm_id": os.getenv("EVAL_LLM_ID"), + "embedding_id": os.getenv("EVAL_EMBEDDING_ID"), "sample_size": sample_size, "start_index": start_index, }, diff --git a/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py b/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py index 3023020a..e07b0cab 100644 --- a/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py +++ b/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py @@ -2,81 +2,30 @@ import argparse import asyncio import json import os -import re import time from datetime import datetime -from typing import Any, Dict, List - -try: - from dotenv import load_dotenv -except Exception: - def load_dotenv(): - return None - -# 路径与模块导入保持与现有评估脚本一致 -import sys +from typing import List, Dict, Any +import re from pathlib import Path -_THIS_DIR = Path(__file__).resolve().parent -_PROJECT_ROOT = str(_THIS_DIR.parents[1]) -_SRC_DIR = os.path.join(_PROJECT_ROOT, "src") -for _p in (_SRC_DIR, 
_PROJECT_ROOT): - if _p not in sys.path: - sys.path.insert(0, _p) +from dotenv import load_dotenv + +# Load evaluation config +eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation" +if eval_config_path.exists(): + load_dotenv(eval_config_path, override=True) + print(f"✅ 加载评估配置: {eval_config_path}") -# 对齐 locomo_test 的检索逻辑:直接使用 graph_search 与 Neo4jConnector/Embedder1 -from app.core.memory.evaluation.common.metrics import ( - avg_context_tokens, - exact_match, - latency_stats, -) -from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.memory.utils.config.definitions import ( - PROJECT_ROOT, - SELECTED_EMBEDDING_ID, - SELECTED_GROUP_ID, - SELECTED_LLM_ID, -) -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.core.models.base import RedBearModelConfig -from app.db import get_db_context -from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.services.memory_config_service import MemoryConfigService +from app.core.memory.src.search import run_hybrid_search # 使用与 evaluate_qa.py 相同的检索函数 +from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient +from app.core.models.base import RedBearModelConfig +from app.core.memory.utils.config.config_utils import get_embedder_config -try: - from app.core.memory.evaluation.common.metrics import bleu1, f1_score, jaccard -except Exception: - # 兜底:简单实现(必要时) - def f1_score(pred: str, ref: str) -> float: - ps = pred.lower().split() - rs = ref.lower().split() - if not ps or not rs: - return 0.0 - tp = len(set(ps) & set(rs)) - if tp == 0: - return 0.0 - precision = tp / len(ps) - recall = tp / len(rs) - if precision + recall == 0: - return 0.0 - return 2 * precision * recall / (precision + recall) +from app.core.memory.utils.llm.llm_utils import get_llm_client +from app.core.memory.evaluation.common.metrics import exact_match, 
latency_stats, avg_context_tokens - def bleu1(pred: str, ref: str) -> float: - ps = pred.lower().split() - rs = ref.lower().split() - if not ps or not rs: - return 0.0 - overlap = len([w for w in ps if w in rs]) - return overlap / max(len(ps), 1) - - def jaccard(pred: str, ref: str) -> float: - ps = set(pred.lower().split()) - rs = set(ref.lower().split()) - union = len(ps | rs) - if union == 0: - return 0.0 - return len(ps & rs) / union +from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str: @@ -219,16 +168,16 @@ async def run_memsciqa_test( # 默认使用指定的 memsci 组 ID end_user_id = end_user_id or "group_memsci" - # 数据路径解析(项目根与当前工作目录兜底) + # 数据路径解析 if not data_path: - proj_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl") - cwd_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl") - if os.path.exists(proj_path): - data_path = proj_path - elif os.path.exists(cwd_path): - data_path = cwd_path - else: - raise FileNotFoundError("未找到数据集: data/msc_self_instruct.jsonl,请确保其存在于项目根目录或当前工作目录的 data 目录下。") + dataset_dir = Path(__file__).resolve().parent.parent / "dataset" + data_path = str(dataset_dir / "msc_self_instruct.jsonl") + + if not os.path.exists(data_path): + raise FileNotFoundError( + f"数据集文件不存在: {data_path}\n" + f"请将 msc_self_instruct.jsonl 放置在: {dataset_dir}" + ) # 加载数据 all_items = load_dataset_memsciqa(data_path) @@ -238,17 +187,13 @@ async def run_memsciqa_test( items = all_items[start_index:start_index + sample_size] # 初始化 LLM(纯测试:不进行摄入) - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm = factory.get_llm_client(SELECTED_LLM_ID) + llm = get_llm_client(os.getenv("EVAL_LLM_ID")) # 初始化 Neo4j 连接与向量检索 Embedder(对齐 locomo_test) connector = Neo4jConnector() embedder = None if search_type in ("embedding", "hybrid"): - with get_db_context() as db: - config_service = MemoryConfigService(db) - cfg_dict = 
config_service.get_embedder_config(SELECTED_EMBEDDING_ID) + cfg_dict = get_embedder_config(os.getenv("EVAL_EMBEDDING_ID")) embedder = OpenAIEmbedderClient( model_config=RedBearModelConfig.model_validate(cfg_dict) ) @@ -273,7 +218,7 @@ async def run_memsciqa_test( question = item.get("self_instruct", {}).get("B", "") or item.get("question", "") reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "") - # 三路检索:chunks/statements/entities/summaries(对齐 qwen_search_eval.py) + # 检索:使用与 evaluate_qa.py 相同的 run_hybrid_search t0 = time.time() results = None try: @@ -302,57 +247,94 @@ async def run_memsciqa_test( search_ms = (t1 - t0) * 1000 latencies_search.append(search_ms) - # 构建上下文:包含 chunks、陈述、摘要和实体(对齐 qwen_search_eval.py) + # 构建上下文:与 evaluate_qa.py 完全一致的逻辑 contexts_all: List[str] = [] retrieved_counts: Dict[str, int] = {} if results: - chunks = results.get("chunks", []) - statements = results.get("statements", []) - entities = results.get("entities", []) - summaries = results.get("summaries", []) + # 处理 hybrid 搜索结果 + if search_type == "hybrid": + emb = results.get("embedding_search", {}) if isinstance(results.get("embedding_search"), dict) else {} + kw = results.get("keyword_search", {}) if isinstance(results.get("keyword_search"), dict) else {} + emb_dialogs = emb.get("dialogues", []) + emb_statements = emb.get("statements", []) + emb_entities = emb.get("entities", []) + kw_dialogs = kw.get("dialogues", []) + kw_statements = kw.get("statements", []) + kw_entities = kw.get("entities", []) + all_dialogs = emb_dialogs + kw_dialogs + all_statements = emb_statements + kw_statements + all_entities = emb_entities + kw_entities + + # 简单去重 + seen_dialog = set() + dialogues = [] + for d in all_dialogs: + key = (str(d.get("uuid", "")), str(d.get("content", ""))) + if key not in seen_dialog: + dialogues.append(d) + seen_dialog.add(key) + + seen_stmt = set() + statements = [] + for s in all_statements: + key = str(s.get("statement", "")) + if key not in 
seen_stmt: + statements.append(s) + seen_stmt.add(key) + + seen_ent = set() + entities = [] + for e in all_entities: + key = str(e.get("name", "")) + if key not in seen_ent: + entities.append(e) + seen_ent.add(key) + else: + # embedding 或 keyword 单独搜索 + dialogues = results.get("dialogues", []) + statements = results.get("statements", []) + entities = results.get("entities", []) + retrieved_counts = { - "chunks": len(chunks), + "dialogues": len(dialogues), "statements": len(statements), "entities": len(entities), - "summaries": len(summaries), } - # 优先使用 chunks - for c in chunks: - text = str(c.get("content", "")).strip() + + # 构建上下文文本 + for d in dialogues: + text = str(d.get("content", "")).strip() if text: contexts_all.append(text) - # 然后是 statements + for s in statements: text = str(s.get("statement", "")).strip() if text: contexts_all.append(text) - # 然后是 summaries - for sm in summaries: - text = str(sm.get("summary", "")).strip() - if text: - contexts_all.append(text) - # 实体摘要:最多加入前3个高分实体(对齐 qwen_search_eval.py) - scored = [e for e in entities if e.get("score") is not None] - top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3] - if top_entities: - summary_lines = [] - for e in top_entities: - name = str(e.get("name", "")).strip() - etype = str(e.get("entity_type", "")).strip() - score = e.get("score") - if name: - meta = [] - if etype: - meta.append(f"type={etype}") - if isinstance(score, (int, float)): - meta.append(f"score={score:.3f}") - summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}") - if summary_lines: - contexts_all.append("\n".join(summary_lines)) + + # 实体摘要 + if entities: + scored = [e for e in entities if e.get("score") is not None] + top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3] + if top_entities: + summary_lines = [] + for e in top_entities: + name = str(e.get("name", "")).strip() + 
etype = str(e.get("entity_type", "")).strip() + score = e.get("score") + if name: + meta = [] + if etype: + meta.append(f"type={etype}") + if isinstance(score, (int, float)): + meta.append(f"score={score:.3f}") + summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}") + if summary_lines: + contexts_all.append("\n".join(summary_lines)) if verbose: if retrieved_counts: - print(f"✅ 检索成功: {retrieved_counts.get('chunks',0)} chunks, {retrieved_counts.get('statements',0)} 条陈述, {retrieved_counts.get('entities',0)} 个实体, {retrieved_counts.get('summaries',0)} 个摘要") + print(f"✅ 检索成功: {retrieved_counts.get('dialogues',0)} dialogues, {retrieved_counts.get('statements',0)} 条陈述, {retrieved_counts.get('entities',0)} 个实体, {retrieved_counts.get('summaries',0)} 个摘要") print(f"📊 有效上下文数量: {len(contexts_all)}") q_keywords = extract_question_keywords(question, max_keywords=8) if q_keywords: @@ -507,8 +489,8 @@ async def run_memsciqa_test( "llm_max_tokens": llm_max_tokens, "search_type": search_type, "start_index": start_index, - "llm_id": SELECTED_LLM_ID, - "retrieval_embedding_id": SELECTED_EMBEDDING_ID + "llm_id": os.getenv("EVAL_LLM_ID"), + "retrieval_embedding_id": os.getenv("EVAL_EMBEDDING_ID") }, "timestamp": datetime.now().isoformat(), } @@ -522,7 +504,7 @@ async def run_memsciqa_test( def main(): load_dotenv() parser = argparse.ArgumentParser(description="memsciqa 测试脚本(三路检索 + 智能上下文选择)") - parser.add_argument("--sample-size", type=int, default=30, help="样本数量(<=0 表示全部)") + parser.add_argument("--sample-size", type=int, default=10, help="样本数量(<=0 表示全部)") parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size)") parser.add_argument("--start-index", type=int, default=0, help="起始样本索引") parser.add_argument("--group-id", type=str, default="group_memsci", help="图数据库 Group ID(默认 group_memsci)") diff --git a/api/app/core/memory/evaluation/memsciqa/evaluate_qa.py b/api/app/core/memory/evaluation/memsciqa/memsciqa_benchmark.py 
similarity index 76% rename from api/app/core/memory/evaluation/memsciqa/evaluate_qa.py rename to api/app/core/memory/evaluation/memsciqa/memsciqa_benchmark.py index 869fdb60..40684f4c 100644 --- a/api/app/core/memory/evaluation/memsciqa/evaluate_qa.py +++ b/api/app/core/memory/evaluation/memsciqa/memsciqa_benchmark.py @@ -4,35 +4,20 @@ import json import os import time from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List +from typing import List, Dict, Any +from pathlib import Path +from dotenv import load_dotenv -if TYPE_CHECKING: - from app.schemas.memory_config_schema import MemoryConfig +# Load evaluation config +eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation" +if eval_config_path.exists(): + load_dotenv(eval_config_path, override=True) -try: - from dotenv import load_dotenv -except Exception: - def load_dotenv(): - return None - -from app.core.memory.evaluation.common.metrics import ( - avg_context_tokens, - exact_match, - latency_stats, -) -from app.core.memory.evaluation.extraction_utils import ( - ingest_contexts_via_full_pipeline, -) -from app.core.memory.storage_services.search import run_hybrid_search -from app.core.memory.utils.config.definitions import ( - PROJECT_ROOT, - SELECTED_EMBEDDING_ID, - SELECTED_GROUP_ID, - SELECTED_LLM_ID, -) -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.db import get_db_context from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.core.memory.src.search import run_hybrid_search # 使用旧版本(重构前) +from app.core.memory.utils.llm.llm_utils import get_llm_client +from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline +from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str: @@ -135,24 +120,37 @@ def _combine_dialogues_for_hybrid(results: Dict[str, 
Any]) -> List[Dict[str, Any return merged + async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]: end_user_id = end_user_id or SELECTED_GROUP_ID + # Load data - data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl") + dataset_dir = Path(__file__).resolve().parent.parent / "dataset" + data_path = dataset_dir / "msc_self_instruct.jsonl" + if not os.path.exists(data_path): - data_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl") + raise FileNotFoundError( + f"数据集文件不存在: {data_path}\n" + f"请将 msc_self_instruct.jsonl 放置在: {dataset_dir}" + ) with open(data_path, "r", encoding="utf-8") as f: lines = f.readlines() items: List[Dict[str, Any]] = [json.loads(l) for l in lines[:sample_size]] + + # 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入 # 说明:memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略 contexts: List[str] = [build_context_from_dialog(item) for item in items] await ingest_contexts_via_full_pipeline(contexts, end_user_id) # LLM client (使用异步调用) - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client(SELECTED_LLM_ID) + from app.db import get_db + + db = next(get_db()) + try: + llm_client = get_llm_client(os.getenv("EVAL_LLM_ID"), db) + finally: + db.close() # Evaluate each item connector = Neo4jConnector() @@ -177,7 +175,6 @@ async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None limit=search_limit, include=["dialogues", "statements", "entities"], output_path=None, - memory_config=memory_config, ) except Exception: results = None @@ -261,11 +258,7 @@ async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, 
dict) else str(resp).strip()) # Metrics: F1, BLEU-1, Jaccard; keep exact match for reference correct_flags.append(exact_match(pred, reference)) - from app.core.memory.evaluation.common.metrics import ( - bleu1, - f1_score, - jaccard, - ) + from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard f1s.append(f1_score(str(pred), str(reference))) b1s.append(bleu1(str(pred), str(reference))) jss.append(jaccard(str(pred), str(reference))) @@ -295,15 +288,39 @@ async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None def main(): + # Load environment variables first load_dotenv() + + # Get defaults from environment variables + env_sample_size = os.getenv("MEMSCIQA_SAMPLE_SIZE") + env_search_limit = os.getenv("MEMSCIQA_SEARCH_LIMIT") + env_context_budget = os.getenv("MEMSCIQA_CONTEXT_CHAR_BUDGET") + env_llm_max_tokens = os.getenv("MEMSCIQA_LLM_MAX_TOKENS") + env_skip_ingest = os.getenv("MEMSCIQA_SKIP_INGEST", "false").lower() in ("true", "1", "yes") + env_output_dir = os.getenv("MEMSCIQA_OUTPUT_DIR") + + # Convert to appropriate types with fallback to code defaults + default_sample_size = int(env_sample_size) if env_sample_size else 1 + default_search_limit = int(env_search_limit) if env_search_limit else 8 + default_context_budget = int(env_context_budget) if env_context_budget else 4000 + default_llm_max_tokens = int(env_llm_max_tokens) if env_llm_max_tokens else 64 + default_output_dir = env_output_dir if env_output_dir else None + parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen") + parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量") - parser.add_argument("--group-id", type=str, default=None, help="可选 end_user_id,默认取 runtime.json") + parser.add_argument("--end-user-id", type=str, default=None, help="可选 end_user_id,默认使用环境变量") parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数") parser.add_argument("--context-char-budget", type=int, 
default=4000, help="上下文字符预算") + parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度") - parser.add_argument("--llm-max-tokens", type=int, default=64, help="LLM 最大生成长度") + parser.add_argument("--llm-max-tokens", type=int, default=default_llm_max_tokens, + help=f"LLM 最大生成长度 (env: MEMSCIQA_LLM_MAX_TOKENS={env_llm_max_tokens or 'not set'})") parser.add_argument("--search-type", type=str, choices=["keyword","embedding","hybrid"], default="hybrid", help="检索类型") + parser.add_argument("--skip-ingest", action="store_true", default=env_skip_ingest, + help=f"跳过数据摄入,使用 Neo4j 中的现有数据 (env: MEMSCIQA_SKIP_INGEST={os.getenv('MEMSCIQA_SKIP_INGEST', 'false')})") + parser.add_argument("--output-dir", type=str, default=default_output_dir, + help=f"结果保存目录 (env: MEMSCIQA_OUTPUT_DIR={env_output_dir or 'not set'})") args = parser.parse_args() result = asyncio.run( @@ -315,9 +332,37 @@ def main(): llm_temperature=args.llm_temperature, llm_max_tokens=args.llm_max_tokens, search_type=args.search_type, + skip_ingest=args.skip_ingest, ) ) + + # Print results to console print(json.dumps(result, ensure_ascii=False, indent=2)) + + # Save results to file + output_dir = args.output_dir + if output_dir is None: + # Use absolute path to ensure results are saved in the correct location + script_dir = Path(__file__).resolve().parent + output_dir = script_dir / "results" + elif not Path(output_dir).is_absolute(): + # If relative path, make it relative to this script's directory + script_dir = Path(__file__).resolve().parent + output_dir = script_dir / output_dir + else: + output_dir = Path(output_dir) + + output_dir.mkdir(parents=True, exist_ok=True) + + timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S") + output_path = output_dir / f"memsciqa_{timestamp_str}.json" + + try: + with open(output_path, "w", encoding="utf-8") as f: + json.dump(result, f, ensure_ascii=False, indent=2) + print(f"\n✅ 结果已保存到: {output_path}") + except Exception as e: + print(f"\n❌ 保存结果失败: {e}") if 
__name__ == "__main__": diff --git a/api/app/core/memory/evaluation/run_eval.py b/api/app/core/memory/evaluation/run_eval.py index c5aacb2f..56b2e790 100644 --- a/api/app/core/memory/evaluation/run_eval.py +++ b/api/app/core/memory/evaluation/run_eval.py @@ -2,20 +2,16 @@ import argparse import asyncio import json import os -import sys from typing import Any, Dict +from pathlib import Path +from dotenv import load_dotenv -# Add src directory to Python path for proper imports when running from evaluation directory -sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'src')) - -try: - from dotenv import load_dotenv -except Exception: - def load_dotenv(): - return None +# Load evaluation config +eval_config_path = Path(__file__).resolve().parent / ".env.evaluation" +if eval_config_path.exists(): + load_dotenv(eval_config_path, override=True) from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, PROJECT_ROOT from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval from app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test @@ -36,8 +32,9 @@ async def run( start_index: int | None = None, max_contexts_per_item: int | None = None, ) -> Dict[str, Any]: - # 恢复原始风格:统一入口做路由,并沿用各数据集既有默认 - end_user_id = end_user_id or SELECTED_GROUP_ID + # Use environment variable with fallback chain if not provided + if end_user_id is None: + end_user_id = os.getenv("EVAL_END_USER_ID", "benchmark_default") if reset_group: connector = Neo4jConnector() diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 8c69c7cf..7b7e854b 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ 
b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -1064,13 +1064,16 @@ class ExtractionOrchestrator: if statement.triplet_extraction_info: triplet_info = statement.triplet_extraction_info - # 创建实体索引到ID的映射 + # 创建实体索引到ID的映射(支持多种索引方式) entity_idx_to_id = {} # 创建实体节点 for entity_idx, entity in enumerate(triplet_info.entities): - # 映射实体索引到实体ID + # 映射实体索引到实体ID(使用多个键以提高容错性) + # 1. 使用实体自己的 entity_idx entity_idx_to_id[entity.entity_idx] = entity.id + # 2. 使用枚举索引(从0开始) + entity_idx_to_id[entity_idx] = entity.id if entity.id not in entity_id_set: entity_connect_strength = getattr(entity, 'connect_strength', 'Strong') @@ -1149,9 +1152,18 @@ class ExtractionOrchestrator: relationship_result ) else: - logger.warning( - f"跳过三元组 - 无法找到实体ID: subject_id={triplet.subject_id}, " - f"object_id={triplet.object_id}, statement_id={statement.id}" + # 改进的警告信息,包含更多调试信息 + missing_subject = "subject" if not subject_entity_id else "" + missing_object = "object" if not object_entity_id else "" + missing_both = " and " if (not subject_entity_id and not object_entity_id) else "" + + logger.debug( + f"跳过三元组 - 无法找到{missing_subject}{missing_both}{missing_object}实体ID: " + f"subject_id={triplet.subject_id} ({triplet.subject_name}), " + f"object_id={triplet.object_id} ({triplet.object_name}), " + f"predicate={triplet.predicate}, " + f"statement_id={statement.id}, " + f"available_indices={sorted(entity_idx_to_id.keys())}" ) logger.info( diff --git a/api/app/models/agent_app_config_model.py b/api/app/models/agent_app_config_model.py index 0a7a5935..96752c8e 100644 --- a/api/app/models/agent_app_config_model.py +++ b/api/app/models/agent_app_config_model.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import relationship from app.base.type import PydanticType from app.db import Base -from app.schemas import ModelParameters +from app.schemas.app_schema import ModelParameters class AgentConfig(Base): diff --git a/api/app/models/multi_agent_model.py 
b/api/app/models/multi_agent_model.py index 544ddb27..400c05ad 100644 --- a/api/app/models/multi_agent_model.py +++ b/api/app/models/multi_agent_model.py @@ -10,7 +10,7 @@ from sqlalchemy.orm import relationship from app.base.type import PydanticType from app.db import Base -from app.schemas import ModelParameters +from app.schemas.app_schema import ModelParameters class OrchestrationMode(StrEnum): diff --git a/api/app/schemas/multi_agent_schema.py b/api/app/schemas/multi_agent_schema.py index c0d72cdd..8fba2929 100644 --- a/api/app/schemas/multi_agent_schema.py +++ b/api/app/schemas/multi_agent_schema.py @@ -4,7 +4,7 @@ import datetime from typing import Optional, List, Dict, Any, Union from pydantic import BaseModel, Field, ConfigDict, field_serializer -from app.schemas import ModelParameters +from app.schemas.app_schema import ModelParameters # ==================== 子 Agent 配置 ==================== diff --git a/api/app/services/master_agent_router.py b/api/app/services/master_agent_router.py index 3971aab7..87fdb22c 100644 --- a/api/app/services/master_agent_router.py +++ b/api/app/services/master_agent_router.py @@ -5,7 +5,7 @@ import uuid from typing import Dict, Any, List, Optional, Tuple from sqlalchemy.orm import Session -from app.schemas import ModelParameters +from app.schemas.app_schema import ModelParameters from app.services.conversation_state_manager import ConversationStateManager from app.models import ModelConfig, AgentConfig from app.core.logging_config import get_business_logger diff --git a/api/app/utils/app_config_utils.py b/api/app/utils/app_config_utils.py index ae41d8bf..06549989 100644 --- a/api/app/utils/app_config_utils.py +++ b/api/app/utils/app_config_utils.py @@ -57,7 +57,7 @@ def dict_to_model_parameters(data: Optional[Dict[str, Any]]) -> Optional[Any]: if data is None: return None - from app.schemas import ModelParameters + from app.schemas.app_schema import ModelParameters if isinstance(data, ModelParameters): return data From 
87731090cab1097dbb30909e388b62774bb10283 Mon Sep 17 00:00:00 2001 From: Mark Date: Mon, 26 Jan 2026 19:19:41 +0800 Subject: [PATCH 27/28] [modify] migration script --- api/migrations/versions/325b759cd66b_2026011240.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/migrations/versions/325b759cd66b_2026011240.py b/api/migrations/versions/325b759cd66b_2026011240.py index 763b0289..3d7443a8 100644 --- a/api/migrations/versions/325b759cd66b_2026011240.py +++ b/api/migrations/versions/325b759cd66b_2026011240.py @@ -31,6 +31,7 @@ def upgrade() -> None: op.execute("UPDATE memory_config SET config_id = apply_id::uuid") op.alter_column('memory_config', 'config_id', nullable=False) op.create_primary_key('memory_config_pkey', 'memory_config', ['config_id']) + op.execute("ALTER TABLE memory_config ALTER COLUMN config_id_old DROP DEFAULT") op.execute("DROP SEQUENCE IF EXISTS data_config_config_id_seq") From c3ea3b751b07c9325f80000e9a4d93a0a4790fe4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= <162269739+lanceyq@users.noreply.github.com> Date: Mon, 26 Jan 2026 20:30:07 +0800 Subject: [PATCH 28/28] delete benchmark-test (#204) * Refactor: Move evaluation folder to redbear-mem-benchmark submodule * [changes]Restore .gitmodules --- .../memory/evaluation/.env.evaluation.example | 224 --- api/app/core/memory/evaluation/.gitignore | 13 - api/app/core/memory/evaluation/__init__.py | 1 - api/app/core/memory/evaluation/benchmark.md | 748 --------- .../memory/evaluation/check_enduser_data.py | 371 ----- .../core/memory/evaluation/common/metrics.py | 100 -- .../memory/evaluation/dialogue_queries.py | 62 - .../memory/evaluation/extraction_utils.py | 444 ------ .../evaluation/locomo/locomo_benchmark.py | 770 ---------- .../evaluation/locomo/locomo_metrics.py | 225 --- .../memory/evaluation/locomo/locomo_test.py | 864 ----------- .../memory/evaluation/locomo/locomo_utils.py | 687 --------- .../evaluation/locomo/qwen_search_eval.py | 874 ----------- 
.../longmemeval/longmemeval_benchmark.py | 1339 ----------------- .../evaluation/longmemeval/test_eval.py | 1312 ---------------- .../evaluation/memsciqa/memsciqa-test.py | 559 ------- .../evaluation/memsciqa/memsciqa_benchmark.py | 369 ----- api/app/core/memory/evaluation/run_eval.py | 147 -- redbear-mem-benchmark | 2 +- 19 files changed, 1 insertion(+), 9110 deletions(-) delete mode 100644 api/app/core/memory/evaluation/.env.evaluation.example delete mode 100644 api/app/core/memory/evaluation/.gitignore delete mode 100644 api/app/core/memory/evaluation/__init__.py delete mode 100644 api/app/core/memory/evaluation/benchmark.md delete mode 100644 api/app/core/memory/evaluation/check_enduser_data.py delete mode 100644 api/app/core/memory/evaluation/common/metrics.py delete mode 100644 api/app/core/memory/evaluation/dialogue_queries.py delete mode 100644 api/app/core/memory/evaluation/extraction_utils.py delete mode 100644 api/app/core/memory/evaluation/locomo/locomo_benchmark.py delete mode 100644 api/app/core/memory/evaluation/locomo/locomo_metrics.py delete mode 100644 api/app/core/memory/evaluation/locomo/locomo_test.py delete mode 100644 api/app/core/memory/evaluation/locomo/locomo_utils.py delete mode 100644 api/app/core/memory/evaluation/locomo/qwen_search_eval.py delete mode 100644 api/app/core/memory/evaluation/longmemeval/longmemeval_benchmark.py delete mode 100644 api/app/core/memory/evaluation/longmemeval/test_eval.py delete mode 100644 api/app/core/memory/evaluation/memsciqa/memsciqa-test.py delete mode 100644 api/app/core/memory/evaluation/memsciqa/memsciqa_benchmark.py delete mode 100644 api/app/core/memory/evaluation/run_eval.py diff --git a/api/app/core/memory/evaluation/.env.evaluation.example b/api/app/core/memory/evaluation/.env.evaluation.example deleted file mode 100644 index be089eb4..00000000 --- a/api/app/core/memory/evaluation/.env.evaluation.example +++ /dev/null @@ -1,224 +0,0 @@ -# 
============================================================================ -# 基准测试统一配置文件示例 -# ============================================================================ -# 复制此文件为 .env.evaluation 并根据需要修改 -# 支持的基准测试:LoCoMo、LongMemEval、MemSciQA -# ============================================================================ - -# ============================================================================ -# 通用配置(所有基准测试共用) -# ============================================================================ - -# ---------------------------------------------------------------------------- -# Neo4j 配置 -# ---------------------------------------------------------------------------- -# 默认 Group ID(建议各基准测试使用独立的 group) -EVAL_GROUP_ID=benchmark_default - -# ---------------------------------------------------------------------------- -# 模型配置(必需) -# ---------------------------------------------------------------------------- -# ⚠️ 必填:从数据库 models 表中选择有效的模型 ID -# -# 如何获取模型 ID: -# 1. 查询数据库:SELECT id, model_name FROM models WHERE is_active = true; -# 2. 或通过系统管理界面查看 -# 3. 
确保模型可用且配置正确 - -# LLM 模型 ID(必填) -EVAL_LLM_ID=your_llm_model_id_here - -# Embedding 模型 ID(必填) -EVAL_EMBEDDING_ID=your_embedding_model_id_here - -# ---------------------------------------------------------------------------- -# 检索参数 -# ---------------------------------------------------------------------------- -# 检索类型: "keyword", "embedding", "hybrid" -EVAL_SEARCH_TYPE=hybrid - -# 检索结果数量限制(默认值) -EVAL_SEARCH_LIMIT=12 - -# 上下文最大字符数(默认值) -EVAL_MAX_CONTEXT_CHARS=8000 - -# ---------------------------------------------------------------------------- -# LLM 参数 -# ---------------------------------------------------------------------------- -# LLM 温度参数(0.0 = 确定性输出) -EVAL_LLM_TEMPERATURE=0.0 - -# LLM 最大生成 token 数 -EVAL_LLM_MAX_TOKENS=32 - -# LLM 超时时间(秒) -EVAL_LLM_TIMEOUT=10.0 - -# LLM 最大重试次数 -EVAL_LLM_MAX_RETRIES=1 - -# ---------------------------------------------------------------------------- -# 数据处理参数 -# ---------------------------------------------------------------------------- -# Chunker 策略 -EVAL_CHUNKER_STRATEGY=RecursiveChunker - -# 是否在导入前清空现有数据 -EVAL_RESET_ON_INGEST=true - -# 是否保存详细日志 -EVAL_SAVE_DETAILED_LOGS=true - -# ============================================================================ -# LoCoMo 基准测试专用配置 -# ============================================================================ -# 数据集:locomo10.json -# 运行:python locomo_benchmark.py --sample_size 20 -# ---------------------------------------------------------------------------- - -# Group ID(LoCoMo 专用) -LOCOMO_GROUP_ID=locomo_benchmark - -# 测试样本数量 -# 建议值:20(快速测试)、100(中等测试)、1986(完整测试) -LOCOMO_SAMPLE_SIZE=20 - -# 检索结果数量限制 -LOCOMO_SEARCH_LIMIT=12 - -# 上下文最大字符数 -LOCOMO_CONTEXT_CHAR_BUDGET=8000 - -# 导入的对话数量 -LOCOMO_MAX_DIALOGUES=1 - -# 跳过数据摄入(true=跳过,false=摄入) -# 首次运行设置为 false,后续运行可设置为 true 以节省时间 -LOCOMO_SKIP_INGEST=false - -# 结果保存目录 -LOCOMO_OUTPUT_DIR=locomo/results - -# ============================================================================ -# LongMemEval 基准测试专用配置 -# 
============================================================================ -# 数据集:longmemeval_oracle_zh.json -# 运行:python longmemeval_benchmark.py --sample_size 3 -# 特点:支持时间推理问题的增强检索 -# ---------------------------------------------------------------------------- - -# Group ID(LongMemEval 专用) -LONGMEMEVAL_GROUP_ID=longmemeval_zh_bak_3 - -# 测试样本数量(<=0 表示全部样本) -LONGMEMEVAL_SAMPLE_SIZE=3 - -# 起始样本索引 -LONGMEMEVAL_START_INDEX=0 - -# 检索结果数量限制 -LONGMEMEVAL_SEARCH_LIMIT=8 - -# 上下文最大字符数 -LONGMEMEVAL_CONTEXT_CHAR_BUDGET=4000 - -# LLM 最大生成 token 数 -LONGMEMEVAL_LLM_MAX_TOKENS=16 - -# 每条样本最多摄入的上下文段数 -LONGMEMEVAL_MAX_CONTEXTS_PER_ITEM=2 - -# 是否保存分块结果 -LONGMEMEVAL_SAVE_CHUNK_OUTPUT=true - -# 自定义分块输出路径(留空使用默认) -LONGMEMEVAL_SAVE_CHUNK_OUTPUT_PATH= - -# 摄入前是否清空组数据 -LONGMEMEVAL_RESET_GROUP_BEFORE_INGEST=false - -# 是否跳过摄入,仅检索评估 -LONGMEMEVAL_SKIP_INGEST=false - -# 结果保存目录 -LONGMEMEVAL_OUTPUT_DIR=longmemeval/results - -# ============================================================================ -# MemSciQA 基准测试专用配置 -# ============================================================================ -# 数据集:msc_self_instruct.jsonl -# 运行:python memsciqa_benchmark.py --sample_size 1 -# 特点:对话记忆检索评估 -# ---------------------------------------------------------------------------- - -# Group ID(MemSciQA 专用,独立数据集) -MEMSCIQA_GROUP_ID=memsciqa_benchmark - -# 测试样本数量 -MEMSCIQA_SAMPLE_SIZE=1 # 0或者-1标识测试数据集中的所有样本 - -# 检索结果数量限制 -MEMSCIQA_SEARCH_LIMIT=8 - -# 上下文最大字符数 -MEMSCIQA_CONTEXT_CHAR_BUDGET=4000 - -# LLM 最大生成 token 数 -MEMSCIQA_LLM_MAX_TOKENS=64 - -# 跳过数据摄入(true=跳过,false=摄入) -# 首次运行设置为 false,后续运行可设置为 true 以节省时间 -MEMSCIQA_SKIP_INGEST=false - -# 结果保存目录(相对于 memsciqa 脚本所在目录) -# 使用 "results" 会保存到 api/app/core/memory/evaluation/memsciqa/results/ -MEMSCIQA_OUTPUT_DIR=results - -# ============================================================================ -# 高级配置(可选) -# ============================================================================ - -# BM25 权重(用于混合检索,0.0-1.0) -EVAL_RERANK_ALPHA=0.6 - -# 
是否使用遗忘重排序 -EVAL_USE_FORGETTING_RERANK=false - -# 是否使用 LLM 重排序 -EVAL_USE_LLM_RERANK=false - -# 连接重置间隔(每 N 个问题重置一次) -EVAL_RESET_INTERVAL=5 - -# 性能阈值(低于此值触发重置) -EVAL_PERFORMANCE_THRESHOLD=0.6 - -# ============================================================================ -# 快速配置指南 -# ============================================================================ -# 1. 复制此文件为 .env.evaluation -# 2. 修改 EVAL_LLM_ID 和 EVAL_EMBEDDING_ID 为你的模型 ID -# 3. 根据需要修改各基准测试的专用配置 -# 4. 运行测试: -# - LoCoMo: python locomo/locomo_benchmark.py --sample_size 20 -# - LongMemEval: python longmemeval/longmemeval_benchmark.py --sample_size 3 --all -# - MemSciQA: python memsciqa/memsciqa_benchmark.py --sample_size 10 -# 配置优先级: -# 命令行参数 > 特定配置(如 LOCOMO_*)> 通用配置(EVAL_*)> 代码默认值 -# ============================================================================ - - -# 执行LoCoMo测试 -# 只摄入前5条消息,评估3个问题(最小测试) -# python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 3 --max_ingest_messages 5 -# -# 如果数据已经摄入,跳过摄入阶段直接测试 -# python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 5 --skip_ingest - - -# 执行longmemeval测试 -# python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark --sample-size 10 --max-contexts-per-item 3 --reset-group-before-ingest - -# 执行memsciqa测试 -# python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark --sample-size 1 diff --git a/api/app/core/memory/evaluation/.gitignore b/api/app/core/memory/evaluation/.gitignore deleted file mode 100644 index 38b1055a..00000000 --- a/api/app/core/memory/evaluation/.gitignore +++ /dev/null @@ -1,13 +0,0 @@ -# 忽略实际的评估配置文件(包含敏感信息) -.env.evaluation - -# 保留示例文件 -!.env.evaluation.example - -# 忽略测试结果文件 -*/results/*.json -*/results/*.log - -# 忽略数据集文件(文件过大,不应提交到 Git) -dataset/*.json -dataset/*.jsonl diff --git a/api/app/core/memory/evaluation/__init__.py b/api/app/core/memory/evaluation/__init__.py deleted file mode 100644 index e9d6aa6c..00000000 --- a/api/app/core/memory/evaluation/__init__.py +++ /dev/null 
@@ -1 +0,0 @@ -"""Evaluation package with dataset-specific pipelines and a unified runner.""" diff --git a/api/app/core/memory/evaluation/benchmark.md b/api/app/core/memory/evaluation/benchmark.md deleted file mode 100644 index 7c31cccd..00000000 --- a/api/app/core/memory/evaluation/benchmark.md +++ /dev/null @@ -1,748 +0,0 @@ -# 1.数据集下载地址 -Locomo10.json : https://github.com/snap-research/locomo/tree/main/data -LongMemEval_oracle.json : https://huggingface.co/datasets/xiaowu0162/longmemeval-cleaned -msc_self_instruct.jsonl : https://huggingface.co/datasets/MemGPT/MSC-Self-Instruct - -数据集下载之后保存至api\app\core\memory\evaluation\dataset目录下 -# 2.配置说明 -文件api\app\core\memory\evaluation\.env.evaluation.example对三个基准测试所需配置有着详细的说明 -**实际配置文件**:api\app\core\memory\evaluation\.env.evaluation -```python -# 当使用不带配置参数的命令行执行基准测试,基准测试所需的配置参数根据.env.evaluation中的参数执行 -python -m app.core.memory.evaluation.locomo.locomo_benchmark -``` -**检查neo4j指定的grou_id是否摄入数据** -```python -# 1. 进入交互模式 -python -m app.core.memory.evaluation.check_enduser_data - -# 2. 选择 "1" 检查指定 group -# 3. 输入 group_id,例如: locomo_benchmark -# 4. 
选择是否显示详细统计 (y/n) -``` -# 3.locomo - -### (1)locomo执行命令 -```python -# 首先进入api目录 -cd api - -# 只摄入前5条消息,评估3个问题(最小测试) -python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 3 --max_ingest_messages 5 - -# 如果数据已经摄入,跳过摄入阶段直接测试(使用skip_ingest参数) -python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 5 --skip_ingest -``` -### (2)locomo结果说明 - -#### 结果示例 -```json -{ - "dataset": "locomo", - "sample_size": 0, - "timestamp": "2026-01-26T11:24:28.239156", - "params": { - "group_id": "locomo_benchmark", - "search_type": "hybrid", - "search_limit": 12, - "context_char_budget": 8000, - "llm_id": "2c9b0782-7a85-4740-ba84-4baf77f256c4", - "embedding_id": "e2a6392d-ca63-4d59-a523-647420b59cb2" - }, - "overall_metrics": { - "f1": 0.0, - "bleu1": 0.0, - "jaccard": 0.0, - "locomo_f1": 0.0 - }, - "by_category": {}, - "latency": { - "search": { - "mean": 0.0, - "p50": 0.0, - "p95": 0.0, - "iqr": 0.0 - }, - "llm": { - "mean": 0.0, - "p50": 0.0, - "p95": 0.0, - "iqr": 0.0 - } - }, - "context_stats": { - "avg_retrieved_docs": 0.0, - "avg_context_chars": 0.0, - "avg_context_tokens": 0.0 - }, - "samples": [] -} -``` - -#### 参数详解 - -##### 1. 核心评估指标 (overall_metrics) - -**🎯 关键进步指标:** - -- **`f1`** (F1 Score): 精确率和召回率的调和平均值 - - 范围:0.0 - 1.0 - - **越高越好**,衡量检索和生成答案的准确性 - - 这是最重要的综合性能指标 - - 优秀标准:> 0.85 - -- **`bleu1`** (BLEU-1): 单词级别的匹配度 - - 范围:0.0 - 1.0 - - **越高越好**,衡量生成答案与标准答案的词汇重叠度 - - 关注词汇层面的准确性 - -- **`jaccard`** (Jaccard 相似度): 集合相似度 - - 范围:0.0 - 1.0 - - **越高越好**,衡量答案集合的相似性 - - 计算公式:交集大小 / 并集大小 - -- **`locomo_f1`**: Locomo 特定的 F1 分数 - - 范围:0.0 - 1.0 - - **越高越好**,针对 Locomo 数据集优化的评估指标 - - 考虑了长对话记忆的特殊性 - -##### 2. 
性能指标 (latency) - -**⚡ 关键效率指标:** - -- **`search`**: 检索延迟统计(单位:毫秒) - - `mean`: 平均延迟 - - `p50`: 中位数延迟(50%的请求在此时间内完成) - - `p95`: 95分位数延迟(95%的请求在此时间内完成) - - `iqr`: 四分位距(Q3-Q1,衡量稳定性) - - **越低越好**,衡量记忆检索速度 - - 优秀标准:p95 < 2000ms - -- **`llm`**: LLM 推理延迟统计(单位:毫秒) - - `mean`: 平均推理时间 - - `p50`: 中位数推理时间 - - `p95`: 95分位数推理时间 - - `iqr`: 四分位距(越小越稳定) - - **越低越好**,衡量答案生成速度 - - 优秀标准:p95 < 3000ms - -##### 3. 上下文统计 (context_stats) - -**📊 资源效率指标:** - -- **`avg_retrieved_docs`**: 平均检索文档数 - - 反映检索策略的广度 - - 需要平衡:太少可能信息不足,太多增加噪音和延迟 - - 建议范围:8-15 个文档 - -- **`avg_context_chars`**: 平均上下文字符数 - - 反映检索内容的总量 - - 应在满足准确性前提下尽量精简 - - 受 `context_char_budget` 参数限制 - -- **`avg_context_tokens`**: 平均上下文 token 数 - - **越低越好**(在保持准确性前提下) - - 直接影响 API 调用成本和推理速度 - - 成本效益比 = f1 / avg_context_tokens - -##### 4. 分类统计 (by_category) - -- 按问题类型分类的性能指标 -- 帮助识别系统在不同场景下的强弱项 -- 可针对性优化特定类型的问题 - -#### 系统进步衡量标准 - -**一级指标(最重要):** -- `f1` 和 `locomo_f1` 提升 → 核心能力提升 -- 目标:f1 > 0.85 - -**二级指标(重要):** -- `latency.p95` 降低 → 用户体验提升 -- 目标:search.p95 < 2000ms, llm.p95 < 3000ms - -**三级指标(辅助):** -- `avg_context_tokens` 降低(在保持 f1 前提下)→ 成本优化 -- `iqr` 降低 → 性能稳定性提升 -# 4.longmemeval -支持时间推理问题的增强检索 -### (1)执行命令 -```python -# 首先进入api目录 -cd api - -# 不带参数运行 - 使用环境变量 -python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark - -# 命令行参数覆盖环境变量 -python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark --sample-size 2 - -# 如果数据已经摄入,跳过摄入阶段直接测试(使用skip_ingest参数) -python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark --skip_ingest -``` -### (2)结果说明 - -#### 结果示例 -```json -{ - "dataset": "longmemeval", - "items": 1, - "accuracy_by_type": { - "single-session-user": 1.0 - }, - "f1_by_type": { - "single-session-user": 1.0 - }, - "jaccard_by_type": { - "single-session-user": 1.0 - }, - "samples": [ - { - "question": "What degree did I graduate with?", - "prediction": "Business Administration", - "answer": "Business Administration", - "question_type": "single-session-user", - "is_temporal": false, - "question_id": 
"e47becba", - "options": [], - "context_count": 13, - "context_chars": 1268, - "retrieved_dialogue_count": 0, - "retrieved_statement_count": 12, - "metrics": { - "exact_match": true, - "f1": 1.0, - "jaccard": 1.0 - }, - "timing": { - "search_ms": 1483.100175857544, - "llm_ms": 995.8682060241699 - } - } - ], - "latency": { - "search": { - "mean": 1483.100175857544, - "p50": 1483.100175857544, - "p95": 1483.100175857544, - "iqr": 0.0 - }, - "llm": { - "mean": 995.8682060241699, - "p50": 995.8682060241699, - "p95": 995.8682060241699, - "iqr": 0.0 - } - }, - "context": { - "avg_tokens": 204.0, - "avg_chars": 1268, - "count_avg": 13 - }, - "params": { - "group_id": "longmemeval_zh_bak_3", - "search_limit": 8, - "context_char_budget": 4000, - "search_type": "hybrid", - "llm_id": "6dc52e1b-9cec-4194-af66-a74c6307fc3f", - "embedding_id": "e2a6392d-ca63-4d59-a523-647420b59cb2", - "sample_size": 1, - "start_index": 0 - }, - "timestamp": "2026-01-24T21:36:10.818308", - "metric_summary": { - "score_accuracy": 100.0, - "latency_median_s": 2.478968381881714, - "latency_iqr_s": 0.0, - "avg_context_tokens_k": 0.204 - }, - "diagnostics": { - "duplicate_previews_top": [], - "unique_preview_count": 1 - } -} -``` - -#### 参数详解 - -##### 1. 核心评估指标 - -**🎯 关键进步指标:** - -- **`accuracy_by_type`**: 按问题类型分类的准确率 - - 范围:0.0 - 1.0 - - **越高越好**,1.0 表示 100% 准确 - - 问题类型包括: - - `single-session-user`: 单会话用户信息 - - `single-session-event`: 单会话事件信息 - - `multi-session-user`: 多会话用户信息 - - `multi-session-event`: 多会话事件信息 - - 可以识别系统在不同场景下的强弱项 - -- **`f1_by_type`**: 按问题类型的 F1 分数 - - 范围:0.0 - 1.0 - - **越高越好**,综合评估精确率和召回率 - - 比单纯的准确率更全面 - -- **`jaccard_by_type`**: 按问题类型的 Jaccard 相似度 - - 范围:0.0 - 1.0 - - **越高越好**,衡量答案集合匹配度 - - 对于集合类答案特别有用 - -##### 2. 
样本级指标 (samples) - -**详细诊断指标:** - -- **`metrics.exact_match`**: 精确匹配(布尔值) - - **true 越多越好**,最严格的评估标准 - - 要求预测答案与标准答案完全一致 - -- **`metrics.f1`**: 单个样本的 F1 分数 - - 范围:0.0 - 1.0 - - **越高越好**,衡量单个问题的回答质量 - -- **`is_temporal`**: 是否为时间推理问题 - - 布尔值,标识问题是否涉及时间推理 - - 时间推理问题通常更具挑战性 - -- **`context_count`**: 检索到的上下文数量 - - 反映检索策略的有效性 - - 建议范围:8-15 个上下文片段 - -- **`retrieved_dialogue_count`**: 检索到的对话数 -- **`retrieved_statement_count`**: 检索到的陈述数 - - 这两个指标帮助理解检索的内容类型分布 - - 可用于优化检索策略 - -- **`timing.search_ms`**: 单个问题的检索延迟(毫秒) -- **`timing.llm_ms`**: 单个问题的 LLM 推理延迟(毫秒) - - **越低越好**,反映单次查询的响应速度 - -##### 3. 汇总指标 (metric_summary) - -**📊 关键 KPI:** - -- **`score_accuracy`**: 总体准确率百分比 - - 范围:0.0 - 100.0 - - **越高越好**,最直观的性能指标 - - 优秀标准:> 90.0 - -- **`latency_median_s`**: 中位延迟(秒) - - **越低越好**,反映真实响应速度 - - 优秀标准:< 3.0 秒 - -- **`latency_iqr_s`**: 延迟四分位距(秒) - - **越低越好**,反映性能稳定性 - - 越小说明响应时间越稳定 - -- **`avg_context_tokens_k`**: 平均上下文 token 数(千) - - **越低越好**(在保持准确性前提下) - - 直接影响 API 调用成本 - - 成本效益比 = score_accuracy / (avg_context_tokens_k * 1000) - -##### 4. 上下文统计 (context) - -- **`avg_tokens`**: 平均 token 数 -- **`avg_chars`**: 平均字符数 -- **`count_avg`**: 平均上下文片段数 - - 这些指标反映检索内容的规模 - - 需要在准确性和效率之间平衡 - -##### 5. 性能指标 (latency) - -**⚡ 效率指标:** - -- **`search`**: 检索延迟统计(单位:毫秒) - - `mean`: 平均延迟 - - `p50`: 中位数延迟 - - `p95`: 95分位数延迟 - - `iqr`: 四分位距 - - **越低越好**,衡量记忆检索速度 - -- **`llm`**: LLM 推理延迟统计(单位:毫秒) - - `mean`: 平均推理时间 - - `p50`: 中位数推理时间 - - `p95`: 95分位数推理时间 - - `iqr`: 四分位距 - - **越低越好**,衡量答案生成速度 - -##### 6. 
诊断信息 (diagnostics) - -- **`duplicate_previews_top`**: 重复预览统计 - - 列出出现频率最高的重复内容 - - 帮助发现检索冗余问题 - - 应该尽量减少重复 - -- **`unique_preview_count`**: 唯一预览数量 - - 反映检索多样性 - - **越高越好**,说明检索到的内容更丰富 - -#### 系统进步衡量标准 - -**一级指标(最重要):** -- `score_accuracy` 提升 → 核心能力提升 -- 目标:> 90.0% -- 各类型的 `accuracy_by_type` 均衡提升 → 全面能力提升 - -**二级指标(重要):** -- `latency_median_s` 降低 → 用户体验提升 -- 目标:< 3.0 秒 -- `exact_match` 比例提升 → 精确度提升 - -**三级指标(辅助):** -- `avg_context_tokens_k` 降低(在保持准确性前提下)→ 成本优化 -- `unique_preview_count` 提升 → 检索多样性提升 -- `latency_iqr_s` 降低 → 性能稳定性提升 - -**特殊关注:** -- 时间推理问题(`is_temporal: true`)的准确率 -- 多会话问题的准确率(通常更具挑战性) -# 5.memsciqa -对话记忆检索评估 -### (1)执行命令 -```python -# 首先进入api目录 -cd api - -# 不带参数运行 - 使用环境变量 -python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark - -# 命令行参数覆盖环境变量 -python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark --sample-size 100 - -# 如果数据已经摄入,跳过摄入阶段直接测试(使用skip_ingest参数) -python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark --skip_ingest -``` -### (2)结果说明 - -#### 结果示例 -```json -{ - "dataset": "memsciqa", - "items": 1, - "metrics": { - "accuracy": 0.0, - "f1": 0.0, - "bleu1": 0.0, - "jaccard": 0.0 - }, - "latency": { - "search": { - "mean": 0.0, - "p50": 0.0, - "p95": 0.0, - "iqr": 0.0 - }, - "llm": { - "mean": 3067.7285194396973, - "p50": 3067.7285194396973, - "p95": 3067.7285194396973, - "iqr": 0.0 - } - }, - "avg_context_tokens": 4.0 -} -``` - -#### 参数详解 - -##### 1. 核心评估指标 (metrics) - -**🎯 关键进步指标:** - -- **`accuracy`**: 准确率 - - 范围:0.0 - 1.0 - - **越高越好**,最直接的性能指标 - - 衡量系统回答正确的问题比例 - - 优秀标准:> 0.85 - -- **`f1`**: F1 分数 - - 范围:0.0 - 1.0 - - **越高越好**,平衡精确率和召回率 - - 计算公式:2 * (precision * recall) / (precision + recall) - - 比单纯的准确率更全面,特别适合不平衡数据集 - -- **`bleu1`**: BLEU-1 分数 - - 范围:0.0 - 1.0 - - **越高越好**,衡量词汇级别的匹配度 - - 关注生成答案与标准答案的单词重叠 - - 源自机器翻译评估,适用于自然语言生成 - -- **`jaccard`**: Jaccard 相似度 - - 范围:0.0 - 1.0 - - **越高越好**,衡量集合相似性 - - 计算公式:|A ∩ B| / |A ∪ B| - - 对于多答案或集合类问题特别有用 - -##### 2. 
性能指标 (latency) - -**⚡ 效率指标:** - -- **`search`**: 检索延迟统计(单位:毫秒) - - `mean`: 平均检索延迟 - - `p50`: 中位数延迟(50%的请求在此时间内完成) - - `p95`: 95分位数延迟(95%的请求在此时间内完成) - - `iqr`: 四分位距(Q3-Q1,衡量稳定性) - - **越低越好**,衡量记忆检索效率 - - 优秀标准:p95 < 2000ms - -- **`llm`**: LLM 推理延迟统计(单位:毫秒) - - `mean`: 平均推理时间 - - `p50`: 中位数推理时间 - - `p95`: 95分位数推理时间 - - `iqr`: 四分位距(越小越稳定) - - **越低越好**,衡量答案生成速度 - - 优秀标准:p95 < 3000ms - - 注意:LLM 延迟通常占总延迟的大部分 - -##### 3. 资源指标 - -- **`avg_context_tokens`**: 平均上下文 token 数 - - **越低越好**(在保持准确性前提下) - - 直接影响: - - API 调用成本(按 token 计费) - - 推理速度(token 越多越慢) - - 上下文窗口占用 - - 成本效益比 = accuracy / avg_context_tokens - - 建议范围:根据模型上下文窗口和成本预算调整 - -##### 4. 数据集特点 - -- **`items`**: 评估的问题数量 - - 样本量越大,评估结果越可靠 - - 建议至少 100 个样本以获得稳定的评估结果 - -- **对话记忆特性**: - - MemSciQA 专注于对话历史中的记忆检索 - - 评估系统从多轮对话中提取和回忆信息的能力 - - 模拟真实的对话场景 - -#### 系统进步衡量标准 - -**一级指标(最重要):** -- `accuracy` 提升 → 核心能力提升 -- 目标:> 0.85 -- `f1` 提升 → 综合性能提升 -- 目标:> 0.80 - -**二级指标(重要):** -- `latency.p95` 降低 → 用户体验提升 - - search.p95 目标:< 2000ms - - llm.p95 目标:< 3000ms -- `iqr` 降低 → 性能稳定性提升 - -**三级指标(辅助):** -- `avg_context_tokens` 降低(在保持准确性前提下)→ 成本优化 -- `bleu1` 和 `jaccard` 提升 → 答案质量提升 - -**综合评估:** -- 成本效益比 = accuracy / avg_context_tokens - - 该比值越高,说明系统在相同成本下性能越好 -- 总延迟 = search.p95 + llm.p95 - - 应控制在 5 秒以内以保证良好的用户体验 - -#### 优化建议 - -**提升准确性:** -- 优化检索算法(调整 hybrid search 参数) -- 改进 embedding 模型质量 -- 增加检索上下文数量(`search_limit`) -- 优化 prompt 工程 - -**提升效率:** -- 减少不必要的检索文档 -- 使用更快的 LLM 模型或量化版本 -- 实施缓存策略(相似问题复用结果) -- 优化数据库索引 - -**平衡性能:** -- 监控 accuracy vs latency 的权衡 -- 监控 accuracy vs cost (tokens) 的权衡 -- 根据业务需求调整优先级 - - ---- - -# 6. 
三个基准测试对比总结 - -## 6.1 测试特点对比 - -| 基准测试 | 主要评估目标 | 数据集特点 | 适用场景 | -|---------|------------|-----------|---------| -| **Locomo** | 长对话记忆检索 | 长对话历史,多轮交互 | 评估长期记忆保持和检索能力 | -| **LongMemEval** | 时间推理和多会话记忆 | 支持时间推理,多会话场景 | 评估时间感知和跨会话记忆能力 | -| **MemSciQA** | 对话记忆问答 | 对话历史问答 | 评估对话上下文理解和记忆提取 | - -## 6.2 核心指标对比 - -### 准确性指标 - -| 指标 | Locomo | LongMemEval | MemSciQA | 说明 | -|-----|--------|-------------|----------|------| -| **F1 Score** | ✅ | ✅ | ✅ | 所有测试都使用,最重要的综合指标 | -| **Accuracy** | ❌ | ✅ | ✅ | 直观的准确率指标 | -| **BLEU-1** | ✅ | ❌ | ✅ | 词汇级别匹配度 | -| **Jaccard** | ✅ | ✅ | ✅ | 集合相似度 | -| **Exact Match** | ❌ | ✅ | ❌ | 最严格的评估标准 | - -### 性能指标 - -所有三个测试都包含: -- **检索延迟** (search latency): mean, p50, p95, iqr -- **LLM 延迟** (llm latency): mean, p50, p95, iqr -- **上下文统计**: token 数、字符数、文档数 - -## 6.3 关键进步指标优先级 - -### 🥇 一级指标(必须关注) - -1. **准确性指标** - - Locomo: `f1`, `locomo_f1` - - LongMemEval: `score_accuracy`, `accuracy_by_type` - - MemSciQA: `accuracy`, `f1` - - **目标**: > 85% 或 > 0.85 - -2. **综合性能** - - 所有测试的 F1 分数应保持一致性 - - 不同类型问题的准确率应均衡 - -### 🥈 二级指标(重要) - -3. **响应延迟** - - `latency.p95` (95分位数延迟) - - **目标**: - - search.p95 < 2000ms - - llm.p95 < 3000ms - - 总延迟 < 5000ms - -4. **性能稳定性** - - `iqr` (四分位距) - - **目标**: 越小越好,说明性能稳定 - -### 🥉 三级指标(优化) - -5. **成本效率** - - `avg_context_tokens` - - **目标**: 在保持准确性前提下最小化 - - 成本效益比 = accuracy / avg_context_tokens - -6. **检索质量** - - `avg_retrieved_docs` 的合理性 - - `unique_preview_count` (LongMemEval) - - 检索内容的多样性和相关性 - -## 6.4 系统优化路径 - -### 阶段一:提升准确性(优先级最高) - -**目标**: 所有测试的准确率 > 85% - -**优化方向**: -1. 改进 embedding 模型质量 -2. 优化检索算法(hybrid search 参数) -3. 增加检索上下文数量(`search_limit`) -4. 优化 prompt 工程 -5. 改进记忆存储结构 - -**监控指标**: -- Locomo: `f1`, `locomo_f1` -- LongMemEval: `score_accuracy`, `exact_match` 比例 -- MemSciQA: `accuracy`, `f1` - -### 阶段二:优化性能(准确性达标后) - -**目标**: p95 延迟 < 5 秒,性能稳定 - -**优化方向**: -1. 优化数据库索引和查询 -2. 实施缓存策略 -3. 使用更快的 LLM 模型 -4. 并行化检索和推理 -5. 
减少不必要的检索 - -**监控指标**: -- `latency.p50`, `latency.p95` -- `iqr` (稳定性) -- 各阶段耗时分布 - -### 阶段三:降低成本(性能达标后) - -**目标**: 在保持准确性和性能前提下,最小化成本 - -**优化方向**: -1. 精简检索上下文 -2. 优化 context 选择策略 -3. 使用更小的 LLM 模型 -4. 实施智能缓存 -5. 批处理优化 - -**监控指标**: -- `avg_context_tokens` -- 成本效益比 = accuracy / avg_context_tokens -- API 调用成本 - -## 6.5 评估最佳实践 - -### 测试执行建议 - -1. **初始测试**: 使用小样本快速验证 - ```bash - --sample_size 10 - ``` - -2. **完整评估**: 使用足够大的样本量 - ```bash - --sample_size 100 # 或更多 - ``` - -3. **增量测试**: 数据已摄入时跳过摄入阶段 - ```bash - --skip_ingest - ``` - -4. **参数调优**: 系统性地调整参数并记录结果 - - 调整 `search_limit`: 4, 8, 12, 16 - - 调整 `context_char_budget`: 2000, 4000, 8000 - - 尝试不同的 `search_type`: vector, keyword, hybrid - -### 结果分析建议 - -1. **横向对比**: 比较三个测试的结果,识别系统的强弱项 -2. **纵向对比**: 跟踪同一测试在不同版本的表现 -3. **分类分析**: 关注不同问题类型的性能差异 -4. **异常诊断**: 分析失败案例,找出根本原因 - -### 持续监控 - -建议建立监控仪表板,跟踪: -- 核心指标趋势(准确率、延迟) -- 成本效益比趋势 -- 不同问题类型的性能分布 -- 异常样本和失败模式 - -## 6.6 性能基准参考 - -### 优秀水平(Production Ready) - -- **准确性**: accuracy/f1 > 0.90 -- **延迟**: p95 < 3 秒 -- **稳定性**: iqr < 500ms -- **成本效益**: accuracy/tokens > 0.0001 - -### 良好水平(Acceptable) - -- **准确性**: accuracy/f1 > 0.85 -- **延迟**: p95 < 5 秒 -- **稳定性**: iqr < 1000ms -- **成本效益**: accuracy/tokens > 0.00005 - -### 需要改进(Below Target) - -- **准确性**: accuracy/f1 < 0.85 -- **延迟**: p95 > 5 秒 -- **稳定性**: iqr > 1000ms -- **成本效益**: accuracy/tokens < 0.00005 - ---- - -**注**: 以上标准仅供参考,实际目标应根据具体业务需求和资源约束调整。 diff --git a/api/app/core/memory/evaluation/check_enduser_data.py b/api/app/core/memory/evaluation/check_enduser_data.py deleted file mode 100644 index 18ecbb34..00000000 --- a/api/app/core/memory/evaluation/check_enduser_data.py +++ /dev/null @@ -1,371 +0,0 @@ -""" -交互式 Neo4j End User 数据检查工具 - -用于查询指定 end_user_id 在 Neo4j 中是否存在数据,以及数据的详细统计信息。 - -使用方法: - python check_group_data.py - python check_group_data.py --group-id locomo_benchmark - python check_group_data.py --group-id memsciqa_benchmark --detailed -""" - -import asyncio -import argparse -import os -from pathlib import Path -from 
typing import Dict, Any -from dotenv import load_dotenv - -# Load evaluation config -eval_config_path = Path(__file__).resolve().parent / ".env.evaluation" -if eval_config_path.exists(): - load_dotenv(eval_config_path, override=True) - print(f"✅ 加载评估配置: {eval_config_path}\n") - -from app.repositories.neo4j.neo4j_connector import Neo4jConnector - - -async def check_group_exists(end_user_id: str) -> Dict[str, Any]: - """ - 检查指定 end_user_id 是否存在数据 - - Args: - end_user_id: 要检查的 end_user ID - - Returns: - 包含统计信息的字典 - """ - connector = Neo4jConnector() - - try: - # 查询该 end_user 的节点总数 - query_total = """ - MATCH (n {end_user_id: $end_user_id}) - RETURN count(n) as total_nodes - """ - result_total = await connector.execute_query(query_total, end_user_id=end_user_id) - total_nodes = result_total[0]["total_nodes"] if result_total else 0 - - # 查询各类型节点的数量 - query_by_type = """ - MATCH (n {end_user_id: $end_user_id}) - RETURN labels(n) as labels, count(n) as count - ORDER BY count DESC - """ - result_by_type = await connector.execute_query(query_by_type, end_user_id=end_user_id) - - # 查询关系数量 - query_relationships = """ - MATCH (n {end_user_id: $end_user_id})-[r]-() - RETURN count(DISTINCT r) as total_relationships - """ - result_rel = await connector.execute_query(query_relationships, end_user_id=end_user_id) - total_relationships = result_rel[0]["total_relationships"] if result_rel else 0 - - return { - "exists": total_nodes > 0, - "total_nodes": total_nodes, - "total_relationships": total_relationships, - "nodes_by_type": result_by_type - } - - finally: - await connector.close() - - -async def get_detailed_stats(end_user_id: str) -> Dict[str, Any]: - """ - 获取详细的统计信息 - - Args: - end_user_id: 要检查的 end_user ID - - Returns: - 详细统计信息字典 - """ - connector = Neo4jConnector() - - try: - stats = {} - - # Chunk 节点统计 - query_chunks = """ - MATCH (c:Chunk {end_user_id: $end_user_id}) - RETURN count(c) as count, - avg(size(c.content)) as avg_content_length - """ - result_chunks = await 
connector.execute_query(query_chunks, end_user_id=end_user_id) - if result_chunks and result_chunks[0]["count"] > 0: - stats["chunks"] = { - "count": result_chunks[0]["count"], - "avg_content_length": int(result_chunks[0]["avg_content_length"]) if result_chunks[0]["avg_content_length"] else 0 - } - - # Statement 节点统计 - query_statements = """ - MATCH (s:Statement {end_user_id: $end_user_id}) - RETURN count(s) as count - """ - result_statements = await connector.execute_query(query_statements, end_user_id=end_user_id) - if result_statements and result_statements[0]["count"] > 0: - stats["statements"] = { - "count": result_statements[0]["count"] - } - - # Entity 节点统计 - query_entities = """ - MATCH (e:Entity {end_user_id: $end_user_id}) - RETURN count(e) as count, - count(DISTINCT e.entity_type) as unique_types - """ - result_entities = await connector.execute_query(query_entities, end_user_id=end_user_id) - if result_entities and result_entities[0]["count"] > 0: - stats["entities"] = { - "count": result_entities[0]["count"], - "unique_types": result_entities[0]["unique_types"] - } - - # Dialogue 节点统计 - query_dialogues = """ - MATCH (d:Dialogue {end_user_id: $end_user_id}) - RETURN count(d) as count - """ - result_dialogues = await connector.execute_query(query_dialogues, end_user_id=end_user_id) - if result_dialogues and result_dialogues[0]["count"] > 0: - stats["dialogues"] = { - "count": result_dialogues[0]["count"] - } - - # Summary 节点统计 - query_summaries = """ - MATCH (s:Summary {end_user_id: $end_user_id}) - RETURN count(s) as count - """ - result_summaries = await connector.execute_query(query_summaries, end_user_id=end_user_id) - if result_summaries and result_summaries[0]["count"] > 0: - stats["summaries"] = { - "count": result_summaries[0]["count"] - } - - return stats - - finally: - await connector.close() - - -async def list_all_end_users() -> list: - """ - 列出数据库中所有的 end_user_id - - Returns: - end_user_id 列表及其节点数量 - """ - connector = Neo4jConnector() - - 
try: - query = """ - MATCH (n) - WHERE n.end_user_id IS NOT NULL - RETURN DISTINCT n.end_user_id as end_user_id, count(n) as node_count - ORDER BY node_count DESC - """ - results = await connector.execute_query(query) - return results - - finally: - await connector.close() - - -def print_results(end_user_id: str, stats: Dict[str, Any], detailed_stats: Dict[str, Any] = None): - """ - 打印查询结果 - - Args: - end_user_id: End User ID - stats: 基本统计信息 - detailed_stats: 详细统计信息(可选) - """ - print(f"\n{'='*60}") - print(f"📊 End User ID: {end_user_id}") - print(f"{'='*60}\n") - - if not stats["exists"]: - print("❌ 该 end_user_id 不存在数据") - print("\n💡 提示: 请先运行基准测试以摄入数据") - return - - print(f"✅ 该 end_user_id 存在数据\n") - print(f"📈 基本统计:") - print(f" 总节点数: {stats['total_nodes']}") - print(f" 总关系数: {stats['total_relationships']}") - - if stats["nodes_by_type"]: - print(f"\n📋 节点类型分布:") - for item in stats["nodes_by_type"]: - labels = ", ".join(item["labels"]) - count = item["count"] - print(f" {labels}: {count}") - - if detailed_stats: - print(f"\n🔍 详细统计:") - - if "chunks" in detailed_stats: - print(f" Chunks: {detailed_stats['chunks']['count']} 个") - print(f" 平均内容长度: {detailed_stats['chunks']['avg_content_length']} 字符") - - if "statements" in detailed_stats: - print(f" Statements: {detailed_stats['statements']['count']} 个") - - if "entities" in detailed_stats: - print(f" Entities: {detailed_stats['entities']['count']} 个") - print(f" 唯一类型数: {detailed_stats['entities']['unique_types']}") - - if "dialogues" in detailed_stats: - print(f" Dialogues: {detailed_stats['dialogues']['count']} 个") - - if "summaries" in detailed_stats: - print(f" Summaries: {detailed_stats['summaries']['count']} 个") - - print(f"\n{'='*60}\n") - - -async def interactive_mode(): - """ - 交互式模式 - """ - print("\n" + "="*60) - print("🔍 Neo4j End User 数据检查工具 - 交互模式") - print("="*60 + "\n") - - while True: - print("\n请选择操作:") - print(" 1. 检查指定 end_user_id") - print(" 2. 列出所有 end_user_id") - print(" 3. 
退出") - - choice = input("\n请输入选项 (1-3): ").strip() - - if choice == "1": - end_user_id = input("\n请输入 end_user_id: ").strip() - if not end_user_id: - print("❌ end_user_id 不能为空") - continue - - detailed = input("是否显示详细统计? (y/n, 默认 n): ").strip().lower() == 'y' - - print("\n🔄 正在查询...") - stats = await check_group_exists(end_user_id) - - detailed_stats = None - if detailed and stats["exists"]: - detailed_stats = await get_detailed_stats(end_user_id) - - print_results(end_user_id, stats, detailed_stats) - - elif choice == "2": - print("\n🔄 正在查询所有 end_user_id...") - end_users = await list_all_end_users() - - if not end_users: - print("\n❌ 数据库中没有任何 end_user 数据") - else: - print(f"\n{'='*60}") - print(f"📋 数据库中的所有 End User ID") - print(f"{'='*60}\n") - - for idx, end_user in enumerate(end_users, 1): - print(f" {idx}. {end_user['end_user_id']}") - print(f" 节点数: {end_user['node_count']}") - - print(f"\n{'='*60}\n") - - elif choice == "3": - print("\n👋 再见!") - break - - else: - print("\n❌ 无效的选项,请重新选择") - - -async def main(): - """ - 主函数 - """ - parser = argparse.ArgumentParser( - description="检查 Neo4j 中指定 end_user_id 的数据情况", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -示例: - # 交互模式 - python check_group_data.py - - # 检查指定 end_user - python check_group_data.py --end-user-id locomo_benchmark - - # 检查并显示详细统计 - python check_group_data.py --end-user-id memsciqa_benchmark --detailed - - # 列出所有 end_user - python check_group_data.py --list-all - """ - ) - - parser.add_argument( - "--end-user-id", - type=str, - help="要检查的 end_user ID" - ) - - parser.add_argument( - "--detailed", - action="store_true", - help="显示详细统计信息" - ) - - parser.add_argument( - "--list-all", - action="store_true", - help="列出所有 end_user_id" - ) - - args = parser.parse_args() - - # 如果没有提供任何参数,进入交互模式 - if not args.end_user_id and not args.list_all: - await interactive_mode() - return - - # 列出所有 end_user - if args.list_all: - print("\n🔄 正在查询所有 end_user_id...") - end_users = await 
list_all_end_users() - - if not end_users: - print("\n❌ 数据库中没有任何 end_user 数据") - else: - print(f"\n{'='*60}") - print(f"📋 数据库中的所有 End User ID") - print(f"{'='*60}\n") - - for idx, end_user in enumerate(end_users, 1): - print(f" {idx}. {end_user['end_user_id']}") - print(f" 节点数: {end_user['node_count']}") - - print(f"\n{'='*60}\n") - return - - # 检查指定 end_user - if args.end_user_id: - print(f"\n🔄 正在查询 end_user_id: {args.end_user_id}...") - stats = await check_group_exists(args.end_user_id) - - detailed_stats = None - if args.detailed and stats["exists"]: - print("🔄 正在获取详细统计...") - detailed_stats = await get_detailed_stats(args.end_user_id) - - print_results(args.end_user_id, stats, detailed_stats) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/api/app/core/memory/evaluation/common/metrics.py b/api/app/core/memory/evaluation/common/metrics.py deleted file mode 100644 index 961ce7f0..00000000 --- a/api/app/core/memory/evaluation/common/metrics.py +++ /dev/null @@ -1,100 +0,0 @@ -import math -import re -from typing import List, Dict - -# 评估指标的实现 -def _normalize(text: str) -> List[str]: - """Lowercase, strip punctuation, and split into tokens.""" - text = text.lower().strip() - # Python's re doesn't support \p classes; use a simple non-word filter - text = re.sub(r"[^\w\s]", " ", text) - tokens = [t for t in text.split() if t] - return tokens - - -def exact_match(pred: str, ref: str) -> float: - return float(_normalize(pred) == _normalize(ref)) - - -def jaccard(pred: str, ref: str) -> float: - p = set(_normalize(pred)) - r = set(_normalize(ref)) - if not p and not r: - return 1.0 - if not p or not r: - return 0.0 - return len(p & r) / len(p | r) - - -def f1_score(pred: str, ref: str) -> float: - p_tokens = _normalize(pred) - r_tokens = _normalize(ref) - if not p_tokens and not r_tokens: - return 1.0 - if not p_tokens or not r_tokens: - return 0.0 - p_set = set(p_tokens) - r_set = set(r_tokens) - tp = len(p_set & r_set) - precision = tp / len(p_set) 
if p_set else 0.0 - recall = tp / len(r_set) if r_set else 0.0 - if precision + recall == 0: - return 0.0 - return 2 * precision * recall / (precision + recall) - - -def bleu1(pred: str, ref: str) -> float: - """Unigram BLEU (BLEU-1) with clipping and brevity penalty.""" - p_tokens = _normalize(pred) - r_tokens = _normalize(ref) - if not p_tokens: - return 0.0 - # Clipped count - r_counts: Dict[str, int] = {} - for t in r_tokens: - r_counts[t] = r_counts.get(t, 0) + 1 - clipped = 0 - p_counts: Dict[str, int] = {} - for t in p_tokens: - p_counts[t] = p_counts.get(t, 0) + 1 - for t, c in p_counts.items(): - clipped += min(c, r_counts.get(t, 0)) - precision = clipped / max(len(p_tokens), 1) - # Brevity penalty - ref_len = len(r_tokens) - pred_len = len(p_tokens) - if pred_len > ref_len or pred_len == 0: - bp = 1.0 - else: - bp = math.exp(1 - ref_len / max(pred_len, 1)) - return bp * precision - - -def percentile(values: List[float], p: float) -> float: - if not values: - return 0.0 - vals = sorted(values) - k = (len(vals) - 1) * p - f = math.floor(k) - c = math.ceil(k) - if f == c: - return vals[int(k)] - return vals[f] + (k - f) * (vals[c] - vals[f]) - - -def latency_stats(latencies_ms: List[float]) -> Dict[str, float]: - """Return basic latency stats: mean, p50, p95, iqr (p75-p25).""" - if not latencies_ms: - return {"mean": 0.0, "p50": 0.0, "p95": 0.0, "iqr": 0.0} - p25 = percentile(latencies_ms, 0.25) - p50 = percentile(latencies_ms, 0.50) - p75 = percentile(latencies_ms, 0.75) - p95 = percentile(latencies_ms, 0.95) - mean = sum(latencies_ms) / max(len(latencies_ms), 1) - return {"mean": mean, "p50": p50, "p95": p95, "iqr": p75 - p25} - - -def avg_context_tokens(contexts: List[str]) -> float: - if not contexts: - return 0.0 - return sum(len(_normalize(c)) for c in contexts) / len(contexts) diff --git a/api/app/core/memory/evaluation/dialogue_queries.py b/api/app/core/memory/evaluation/dialogue_queries.py deleted file mode 100644 index 0aace0ec..00000000 --- 
a/api/app/core/memory/evaluation/dialogue_queries.py +++ /dev/null @@ -1,62 +0,0 @@ -""" -Dialogue search queries for evaluation purposes. -This file contains Cypher queries for searching dialogues, entities, and chunks. -Placed in evaluation directory to avoid circular imports with src modules. -""" - -# 应该是neo4j browser的cypher语句,需要修改文件名 - -# Entity search queries -SEARCH_ENTITIES_BY_NAME = """ -MATCH (e:ExtractedEntity) -WHERE e.name = $name -RETURN e -""" - -SEARCH_ENTITIES_BY_NAME_FALLBACK = """ -MATCH (e:ExtractedEntity) -WHERE e.name CONTAINS $name -RETURN e -""" - -# Chunk search queries -SEARCH_CHUNKS_BY_CONTENT = """ -MATCH (c:Chunk) -WHERE c.content CONTAINS $content -RETURN c -""" - -# Dialogue search queries -SEARCH_DIALOGUE_BY_DIALOG_ID = """ -MATCH (d:Dialogue) -WHERE d.dialog_id = $dialog_id -RETURN d -""" - -SEARCH_DIALOGUES_BY_CONTENT = """ -MATCH (d:Dialogue) -WHERE d.content CONTAINS $q -RETURN d -""" - -DIALOGUE_EMBEDDING_SEARCH = """ -WITH $embedding AS q -MATCH (d:Dialogue) -WHERE d.dialog_embedding IS NOT NULL - AND ($end_user_id IS NULL OR d.end_user_id = $end_user_id) -WITH d, q, d.dialog_embedding AS v -WITH d, - reduce(dot = 0.0, i IN range(0, size(q)-1) | dot + toFloat(q[i]) * toFloat(v[i])) AS dot, - sqrt(reduce(qs = 0.0, i IN range(0, size(q)-1) | qs + toFloat(q[i]) * toFloat(q[i]))) AS qnorm, - sqrt(reduce(vs = 0.0, i IN range(0, size(v)-1) | vs + toFloat(v[i]) * toFloat(v[i]))) AS vnorm -WITH d, CASE WHEN qnorm = 0 OR vnorm = 0 THEN 0.0 ELSE dot / (qnorm * vnorm) END AS score -WHERE score > $threshold -RETURN d.id AS dialog_id, - d.end_user_id AS end_user_id, - d.content AS content, - d.created_at AS created_at, - d.expired_at AS expired_at, - score -ORDER BY score DESC -LIMIT $limit -""" diff --git a/api/app/core/memory/evaluation/extraction_utils.py b/api/app/core/memory/evaluation/extraction_utils.py deleted file mode 100644 index 43ef6fe0..00000000 --- a/api/app/core/memory/evaluation/extraction_utils.py +++ /dev/null @@ -1,444 
+0,0 @@ -import os -import asyncio -import json -from typing import List, Dict, Any, Optional -from datetime import datetime -from uuid import UUID -import re - -from app.core.memory.llm_tools.openai_client import LLMClient -from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker -from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage -import os -import sys -from pathlib import Path -from dotenv import load_dotenv - -# Load evaluation config -eval_config_path = Path(__file__).resolve().parent / "app" / "core" / "memory" / "evaluation" / ".env.evaluation" -if eval_config_path.exists(): - load_dotenv(eval_config_path, override=True) - print(f"✅ 加载评估配置: {eval_config_path}") - -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.memory.utils.llm.llm_utils import get_llm_client - -# 使用新的模块化架构 -from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator - -# Import from database module -from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j - -# Cypher queries for evaluation -# Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py - - -async def ingest_contexts_via_full_pipeline( - contexts: List[str], - end_user_id: str, - chunker_strategy: str | None = None, - embedding_name: str | None = None, - save_chunk_output: bool = False, - save_chunk_output_path: str | None = None, - reset_group: bool = False, -) -> bool: - """ - 使用新的 ExtractionOrchestrator 运行完整的提取流水线 - - Run the full extraction pipeline on provided dialogue contexts and save to Neo4j. - This function uses the new ExtractionOrchestrator architecture for better maintainability. - - Args: - contexts: List of dialogue texts, each containing lines like "role: message". - end_user_id: Group ID to assign to generated DialogData and graph nodes. 
- chunker_strategy: Optional chunker strategy; defaults to SELECTED_CHUNKER_STRATEGY. - embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID. - save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging. - save_chunk_output_path: Optional output path; defaults to src/chunker_test_output.txt. - reset_group: If True, clear existing data for this group before ingestion. - Returns: - True if data saved successfully, False otherwise. - """ - chunker_strategy = chunker_strategy or os.getenv("EVAL_CHUNKER_STRATEGY", "RecursiveChunker") - embedding_name = embedding_name or os.getenv("EVAL_EMBEDDING_ID") - - # Check if we should reset from environment variable if not explicitly set - if not reset_group: - reset_group = os.getenv("EVAL_RESET_ON_INGEST", "false").lower() in ("true", "1", "yes") - - # Step 0: Reset group if requested - if reset_group: - print(f"[Ingestion] 🗑️ 清空 end_user '{end_user_id}' 的现有数据...") - try: - from app.repositories.neo4j.neo4j_connector import Neo4jConnector - connector = Neo4jConnector() - try: - # 删除该 end_user 的所有节点和关系 - query = """ - MATCH (n {end_user_id: $end_user_id}) - DETACH DELETE n - """ - await connector.execute_query(query, end_user_id=end_user_id) - print(f"[Ingestion] ✅ End User '{end_user_id}' 已清空") - finally: - await connector.close() - except Exception as e: - print(f"[Ingestion] ⚠️ 清空 end_user 失败: {e}") - # 继续执行,不中断摄入流程 - - # Step 1: Initialize LLM client - llm_client = None - try: - # 使用评估配置中的 LLM ID - llm_id = os.getenv("EVAL_LLM_ID") - if not llm_id: - print("[Ingestion] ❌ EVAL_LLM_ID not set in .env.evaluation") - return False - - from app.db import get_db - - db = next(get_db()) - try: - llm_client = get_llm_client(llm_id, db) - finally: - db.close() - except Exception as e: - print(f"[Ingestion] LLM client unavailable: {e}") - return False - - # Step 2: Parse contexts and create DialogData with chunks - print(f"[Ingestion] Parsing {len(contexts)} contexts...") - chunker 
= DialogueChunker(chunker_strategy) - dialog_data_list: List[DialogData] = [] - - for idx, ctx in enumerate(contexts): - messages: List[ConversationMessage] = [] - - # Improved parsing: capture multi-line message blocks, normalize roles - pattern = r"^\s*(用户|AI|assistant|user)\s*[::]\s*(.+?)(?=\n\s*(?:用户|AI|assistant|user)\s*[::]|\Z)" - matches = list(re.finditer(pattern, ctx, flags=re.MULTILINE | re.DOTALL)) - - if matches: - for m in matches: - raw_role = m.group(1).strip() - content = m.group(2).strip() - norm_role = "AI" if raw_role.lower() in ("ai", "assistant") else "用户" - messages.append(ConversationMessage(role=norm_role, msg=content)) - else: - # Fallback: line-by-line parsing - for raw in ctx.split("\n"): - line = raw.strip() - if not line: - continue - m = re.match(r'^\s*([^::]+)\s*[::]\s*(.+)', line) - if m: - role = m.group(1).strip() - msg = m.group(2).strip() - norm_role = "AI" if role.lower() in ("ai", "assistant") else "用户" - messages.append(ConversationMessage(role=norm_role, msg=msg)) - else: - # Final fallback: treat as user message - default_role = "AI" if re.match(r'^\s*(assistant|AI)\b', line, flags=re.IGNORECASE) else "用户" - messages.append(ConversationMessage(role=default_role, msg=line)) - - context_model = ConversationContext(msgs=messages) - dialog = DialogData( - context=context_model, - ref_id=f"pipeline_item_{idx}", - end_user_id=end_user_id, - user_id="default_user", - apply_id="default_application", - ) - # Generate chunks - dialog.chunks = await chunker.process_dialogue(dialog) - dialog_data_list.append(dialog) - - if not dialog_data_list: - print("[Ingestion] No dialogs to process.") - return False - - print(f"[Ingestion] Parsed {len(dialog_data_list)} dialogs with chunks") - - # Step 3: Optionally save chunking outputs for debugging - if save_chunk_output: - try: - def _serialize_datetime(obj): - if isinstance(obj, datetime): - return obj.isoformat() - raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON 
serializable") - - from app.core.config import settings - settings.ensure_memory_output_dir() - default_path = settings.get_memory_output_path("chunker_test_output.txt") - out_path = save_chunk_output_path or default_path - - combined_output = [dd.model_dump() for dd in dialog_data_list] - with open(out_path, "w", encoding="utf-8") as f: - json.dump(combined_output, f, ensure_ascii=False, indent=4, default=_serialize_datetime) - print(f"[Ingestion] Saved chunking results to: {out_path}") - except Exception as e: - print(f"[Ingestion] Failed to save chunking results: {e}") - - # Step 4: Initialize embedder client - from app.core.models.base import RedBearModelConfig - from app.core.memory.utils.config.config_utils import get_embedder_config - from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient - from app.db import get_db - - try: - db = next(get_db()) - try: - embedder_config_dict = get_embedder_config(embedding_name, db) - embedder_config = RedBearModelConfig(**embedder_config_dict) - embedder_client = OpenAIEmbedderClient(embedder_config) - finally: - db.close() - except Exception as e: - print(f"[Ingestion] Failed to initialize embedder client: {e}") - return False - - # Step 5: Initialize Neo4j connector - connector = Neo4jConnector() - - # Step 6: 构建 MemoryConfig(从环境变量直接构建,不依赖数据库) - print("[Ingestion] 构建 MemoryConfig from environment variables...") - from app.schemas.memory_config_schema import MemoryConfig - - try: - # 从环境变量获取配置参数 - llm_id = os.getenv("EVAL_LLM_ID") - embedding_id = os.getenv("EVAL_EMBEDDING_ID") - chunker_strategy_env = os.getenv("EVAL_CHUNKER_STRATEGY", "RecursiveChunker") - - if not llm_id or not embedding_id: - print("[Ingestion] ❌ EVAL_LLM_ID or EVAL_EMBEDDING_ID is not set in .env.evaluation") - print("[Ingestion] Please set both EVAL_LLM_ID and EVAL_EMBEDDING_ID") - await connector.close() - return False - - # 从数据库获取模型信息(仅用于显示名称) - from app.db import get_db - db = next(get_db()) - try: - from sqlalchemy import 
text - # 获取 LLM 模型信息(从 model_configs 表) - llm_result = db.execute( - text("SELECT name FROM model_configs WHERE id = :id"), - {"id": llm_id} - ).fetchone() - llm_model_name = llm_result[0] if llm_result else "Unknown LLM" - - # 获取 Embedding 模型信息(从 model_configs 表) - emb_result = db.execute( - text("SELECT name FROM model_configs WHERE id = :id"), - {"id": embedding_id} - ).fetchone() - embedding_model_name = emb_result[0] if emb_result else "Unknown Embedding" - except Exception as e: - # 如果查询失败,使用默认名称 - print(f"[Ingestion] Warning: Failed to query model names from database: {e}") - llm_model_name = f"LLM ({llm_id[:8]}...)" - embedding_model_name = f"Embedding ({embedding_id[:8]}...)" - finally: - db.close() - - # 构建 MemoryConfig 对象(使用最小必需配置) - from uuid import uuid4 - memory_config = MemoryConfig( - config_id=0, # 评估环境不需要真实的 config_id - config_name="evaluation_config", - workspace_id=uuid4(), # 临时 workspace_id - workspace_name="evaluation_workspace", - tenant_id=uuid4(), # 临时 tenant_id - llm_model_id=UUID(llm_id), - llm_model_name=llm_model_name, - embedding_model_id=UUID(embedding_id), - embedding_model_name=embedding_model_name, - storage_type="neo4j", - chunker_strategy=chunker_strategy_env, - reflexion_enabled=False, - reflexion_iteration_period=3, - reflexion_range="partial", - reflexion_baseline="TIME", - loaded_at=datetime.now(), - # 可选字段使用默认值 - rerank_model_id=None, - rerank_model_name=None, - llm_params={}, - embedding_params={}, - config_version="2.0", - ) - - print(f"[Ingestion] ✅ 构建 MemoryConfig 成功") - print(f"[Ingestion] LLM: {llm_model_name}") - print(f"[Ingestion] Embedding: {embedding_model_name}") - print(f"[Ingestion] Chunker: {chunker_strategy_env}") - - except Exception as e: - print(f"[Ingestion] ❌ Failed to build MemoryConfig: {e}") - print(f"[Ingestion] Please check:") - print(f"[Ingestion] 1. EVAL_LLM_ID and EVAL_EMBEDDING_ID are set in .env.evaluation") - print(f"[Ingestion] 2. Model IDs exist in the models table") - print(f"[Ingestion] 3. 
Database connection is working") - await connector.close() - return False - - # Step 7: Initialize and run ExtractionOrchestrator - print("[Ingestion] Running extraction pipeline with ExtractionOrchestrator...") - from app.services.memory_config_service import MemoryConfigService - config = MemoryConfigService.get_pipeline_config(memory_config) - - orchestrator = ExtractionOrchestrator( - llm_client=llm_client, - embedder_client=embedder_client, - connector=connector, - config=config, - embedding_id=str(memory_config.embedding_model_id), # 传递 embedding_id - ) - - try: - # Run the complete extraction pipeline - result = await orchestrator.run(dialog_data_list, is_pilot_run=False) - - # Handle different return formats: - # - Pilot mode: 7 values (without dedup_details) - # - Normal mode: 8 values (with dedup_details at the end) - if len(result) == 8: - # Normal mode: includes dedup_details - ( - dialogue_nodes, - chunk_nodes, - statement_nodes, - entity_nodes, - statement_chunk_edges, - statement_entity_edges, - entity_entity_edges, - _, # dedup_details - not needed here - ) = result - elif len(result) == 7: - # Pilot mode or older version: no dedup_details - ( - dialogue_nodes, - chunk_nodes, - statement_nodes, - entity_nodes, - statement_chunk_edges, - statement_entity_edges, - entity_entity_edges, - ) = result - else: - raise ValueError(f"Unexpected number of return values: {len(result)}") - - print(f"[Ingestion] Extraction completed: {len(statement_nodes)} statements, {len(entity_nodes)} entities") - - except ValueError as e: - # If unpacking fails, provide helpful error message - print(f"[Ingestion] Extraction pipeline result unpacking failed: {e}") - print(f"[Ingestion] Result type: {type(result)}, length: {len(result) if hasattr(result, '__len__') else 'N/A'}") - if hasattr(result, '__len__') and len(result) > 0: - print(f"[Ingestion] First element type: {type(result[0])}") - await connector.close() - return False - except Exception as e: - print(f"[Ingestion] 
Extraction pipeline failed: {e}") - import traceback - traceback.print_exc() - await connector.close() - return False - - # Step 7: Generate memory summaries - print("[Ingestion] Generating memory summaries...") - try: - from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import ( - memory_summary_generation, - ) - from app.repositories.neo4j.add_nodes import add_memory_summary_nodes - from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges - - summaries = await memory_summary_generation( - chunked_dialogs=dialog_data_list, - llm_client=llm_client, - embedder_client=embedder_client - ) - print(f"[Ingestion] Generated {len(summaries)} memory summaries") - except Exception as e: - print(f"[Ingestion] Warning: Failed to generate memory summaries: {e}") - summaries = [] - - # Step 8: Save to Neo4j - print("[Ingestion] Saving to Neo4j...") - try: - success = await save_dialog_and_statements_to_neo4j( - dialogue_nodes=dialogue_nodes, - chunk_nodes=chunk_nodes, - statement_nodes=statement_nodes, - entity_nodes=entity_nodes, - entity_edges=entity_entity_edges, - statement_chunk_edges=statement_chunk_edges, - statement_entity_edges=statement_entity_edges, - connector=connector - ) - - # Save memory summaries separately - if summaries: - try: - await add_memory_summary_nodes(summaries, connector) - await add_memory_summary_statement_edges(summaries, connector) - print(f"[Ingestion] Saved {len(summaries)} memory summary nodes to Neo4j") - except Exception as e: - print(f"[Ingestion] Warning: Failed to save summary nodes: {e}") - - await connector.close() - - if success: - print("[Ingestion] Successfully saved all data to Neo4j!") - else: - print("[Ingestion] Failed to save data to Neo4j") - return success - - except Exception as e: - print(f"[Ingestion] Failed to save data to Neo4j: {e}") - await connector.close() - return False - - -async def handle_context_processing(args): - """Handle context-based processing 
from command line arguments.""" - contexts = [] - - if args.contexts: - contexts.extend(args.contexts) - - if args.context_file: - try: - with open(args.context_file, 'r', encoding='utf-8') as f: - contexts.extend(line.strip() for line in f if line.strip()) - except Exception as e: - print(f"Error reading context file: {e}") - return False - - if not contexts: - print("No contexts provided for processing.") - return False - - return await main_from_contexts(contexts, args.context_end_user_id) - - -async def main_from_contexts(contexts: List[str], end_user_id: str): - """Run the pipeline from provided dialogue contexts instead of test data.""" - print("=== Running pipeline from provided contexts ===") - - success = await ingest_contexts_via_full_pipeline( - contexts=contexts, - end_user_id=end_user_id, - chunker_strategy=SELECTED_CHUNKER_STRATEGY, - embedding_name=SELECTED_EMBEDDING_ID, - save_chunk_output=True - ) - - if success: - print("Successfully processed and saved contexts to Neo4j!") - else: - print("Failed to process contexts.") - - return success diff --git a/api/app/core/memory/evaluation/locomo/locomo_benchmark.py b/api/app/core/memory/evaluation/locomo/locomo_benchmark.py deleted file mode 100644 index eed75016..00000000 --- a/api/app/core/memory/evaluation/locomo/locomo_benchmark.py +++ /dev/null @@ -1,770 +0,0 @@ -""" -LoCoMo Benchmark Script - -This module provides the main entry point for running LoCoMo benchmark evaluations. -It orchestrates data loading, ingestion, retrieval, LLM inference, and metric calculation -in a clean, maintainable way. 
- -Usage: - python locomo_benchmark.py --sample_size 20 --search_type hybrid -""" - -import argparse -import asyncio -import json -import os -import time -from datetime import datetime -from typing import List, Dict, Any, Optional -from pathlib import Path -from dotenv import load_dotenv - -# Load evaluation config -eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation" -if eval_config_path.exists(): - load_dotenv(eval_config_path, override=True) - print(f"✅ 加载评估配置: {eval_config_path}") - -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.models.base import RedBearModelConfig -from app.core.memory.utils.config.config_utils import get_embedder_config -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.memory.evaluation.common.metrics import ( - f1_score, - bleu1, - jaccard, - latency_stats, - avg_context_tokens -) -from app.core.memory.evaluation.locomo.locomo_metrics import ( - locomo_f1_score, - locomo_multi_f1, - get_category_name -) -from app.core.memory.evaluation.locomo.locomo_utils import ( - load_locomo_data, - extract_conversations, - resolve_temporal_references, - select_and_format_information, - retrieve_relevant_information, -) -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.db import get_db_context -from app.services.memory_config_service import MemoryConfigService - -# Get configuration from environment variables -PROJECT_ROOT = str(Path(__file__).resolve().parents[5]) # api directory -SELECTED_EMBEDDING_ID = os.getenv("EVAL_EMBEDDING_ID", "e2a6392d-ca63-4d59-a523-647420b59cb2") -SELECTED_end_user_id = os.getenv("LOCOMO_END_USER_ID") or os.getenv("EVAL_END_USER_ID", "locomo_benchmark") -SELECTED_LLM_ID = os.getenv("EVAL_LLM_ID", "2c9b0782-7a85-4740-ba84-4baf77f256c4") - - -# ============================================================================ -# Step 1: Data 
Loading -# ============================================================================ - -def step_load_data(data_path: str, sample_size: int) -> List[Dict[str, Any]]: - """ - Load QA pairs from LoCoMo dataset. - - Args: - data_path: Path to locomo10.json file - sample_size: Number of QA pairs to load (0 for all) - - Returns: - List of QA items from the first conversation - """ - print("📂 Loading LoCoMo data...") - - # Load the dataset - qa_items = load_locomo_data(data_path, sample_size) - - print(f"✅ Loaded {len(qa_items)} QA pairs from first conversation\n") - return qa_items - - -# ============================================================================ -# Step 2: Data Ingestion -# ============================================================================ - -async def ingest_conversations_if_needed( - conversations: List[str], - end_user_id: str, - reset: bool = False -) -> bool: - """ - Ingest conversations into Neo4j database. - - Args: - conversations: List of conversation strings (already formatted) - end_user_id: Database end_user ID - reset: Whether to reset the group before ingestion - - Returns: - True if successful, False otherwise - """ - try: - from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline - - # Conversations are already formatted as strings, use them directly - await ingest_contexts_via_full_pipeline(conversations, end_user_id) - return True - - except Exception as e: - print(f"⚠️ Ingestion error: {e}") - import traceback - traceback.print_exc() - return False - - -async def step_ingest_data( - data_path: str, - end_user_id: str, - skip_ingest: bool, - reset_group: bool, - max_messages: Optional[int] = None -) -> bool: - """ - Ingest conversations into Neo4j database if needed. 
- - Args: - data_path: Path to locomo10.json file - end_user_id: Database end_user ID - skip_ingest: Whether to skip ingestion - reset_group: Whether to reset the group before ingestion - max_messages: Maximum messages per dialogue to ingest (for testing) - - Returns: - True if ingestion succeeded or was skipped, False otherwise - """ - if skip_ingest: - print("⏭️ Skipping data ingestion (using existing data in Neo4j)") - print(f" End User ID: {end_user_id}\n") - else: - print("💾 Checking database ingestion...") - try: - # Extract conversations with optional message limit - conversations = extract_conversations( - data_path, - max_dialogues=1, - max_messages_per_dialogue=max_messages - ) - print(f"📝 Extracted {len(conversations)} conversations") - - # Always ingest for now (ingestion check not implemented) - print(f"🔄 Ingesting conversations into end_user '{end_user_id}'...") - success = await ingest_conversations_if_needed( - conversations=conversations, - end_user_id=end_user_id, - reset=reset_group - ) - - if success: - print("✅ Ingestion completed successfully\n") - else: - print("⚠️ Ingestion may have failed, continuing anyway\n") - - except Exception as e: - print(f"❌ Ingestion failed: {e}") - import traceback - traceback.print_exc() - print("⚠️ Continuing with evaluation (database may be empty)\n") - - return True - - -# ============================================================================ -# Step 3: Initialize Clients -# ============================================================================ - -def step_initialize_clients(llm_id: str, embedding_id: str): - """ - Initialize Neo4j connector, LLM client, and embedder. 
- - Args: - llm_id: LLM model ID - embedding_id: Embedding model ID - - Returns: - Tuple of (connector, llm_client, embedder) - """ - print("🔧 Initializing clients...") - - connector = Neo4jConnector() - - # Get database session - from app.db import get_db - db = next(get_db()) - try: - llm_client = get_llm_client(llm_id, db) - cfg_dict = get_embedder_config(embedding_id, db) - embedder = OpenAIEmbedderClient( - model_config=RedBearModelConfig.model_validate(cfg_dict) - ) - finally: - db.close() - - print("✅ Clients initialized\n") - return connector, llm_client, embedder - - -# ============================================================================ -# Step 4: Process Questions -# ============================================================================ - -async def step_process_all_questions( - qa_items: List[Dict[str, Any]], - end_user_id: str, - search_type: str, - search_limit: int, - context_char_budget: int, - connector: Neo4jConnector, - embedder: OpenAIEmbedderClient, - llm_client: Any -) -> List[Dict[str, Any]]: - """Process all QA items: retrieve, generate, and calculate metrics.""" - print(f"🔍 Processing {len(qa_items)} questions...") - print(f"{'='*60}\n") - - samples: List[Dict[str, Any]] = [] - anchor_date = datetime(2023, 5, 8) - - for idx, item in enumerate(qa_items, 1): - question = item.get("question", "") - ground_truth = item.get("answer", "") - category = get_category_name(item) - ground_truth_str = str(ground_truth) if ground_truth is not None else "" - - print(f"[{idx}/{len(qa_items)}] Category: {category}") - print(f"❓ Question: {question}") - print(f"✅ Ground Truth: {ground_truth_str}") - - # Retrieve - t_search_start = time.time() - try: - retrieved_info = await retrieve_relevant_information( - question=question, - end_user_id=end_user_id, - search_type=search_type, - search_limit=search_limit, - connector=connector, - embedder=embedder - ) - search_latency = (time.time() - t_search_start) * 1000 - print(f"🔍 Retrieved 
{len(retrieved_info)} documents ({search_latency:.1f}ms)") - except Exception as e: - print(f"❌ Retrieval failed: {e}") - retrieved_info = [] - search_latency = 0.0 - - # Format context - context_text = select_and_format_information( - retrieved_info=retrieved_info, - question=question, - max_chars=context_char_budget - ) - context_text = resolve_temporal_references(context_text, anchor_date) - if context_text: - context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n{context_text}" - else: - context_text = "No relevant context found." - - print(f"📝 Context: {len(context_text)} chars, {len(retrieved_info)} docs") - - # Generate answer - messages = [ - { - "role": "system", - "content": ( - "You are a precise QA assistant. Answer following these rules:\n" - "1) Extract the EXACT information mentioned in the context\n" - "2) For time questions: calculate actual dates from relative times\n" - "3) Return ONLY the answer text in simplest form\n" - "4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n" - "5) If no clear answer found, respond with 'Unknown'" - ) - }, - { - "role": "user", - "content": f"Question: {question}\n\nContext:\n{context_text}" - } - ] - - t_llm_start = time.time() - try: - response = await llm_client.chat(messages=messages) - llm_latency = (time.time() - t_llm_start) * 1000 - if hasattr(response, 'content'): - prediction = response.content.strip() - elif isinstance(response, dict): - prediction = response["choices"][0]["message"]["content"].strip() - else: - prediction = "Unknown" - print(f"🤖 Prediction: {prediction} ({llm_latency:.1f}ms)") - except Exception as e: - print(f"❌ LLM failed: {e}") - prediction = "Unknown" - llm_latency = 0.0 - - # Calculate metrics - f1_val = f1_score(prediction, ground_truth_str) - bleu1_val = bleu1(prediction, ground_truth_str) - jaccard_val = jaccard(prediction, ground_truth_str) - if item.get("category") == 1: - locomo_f1_val = locomo_multi_f1(prediction, ground_truth_str) - else: - 
locomo_f1_val = locomo_f1_score(prediction, ground_truth_str) - - print(f"📊 Metrics - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, " - f"Jaccard: {jaccard_val:.3f}, LoCoMo F1: {locomo_f1_val:.3f}") - print() - - samples.append({ - "question": question, - "ground_truth": ground_truth_str, - "prediction": prediction, - "category": category, - "metrics": { - "f1": f1_val, - "bleu1": bleu1_val, - "jaccard": jaccard_val, - "locomo_f1": locomo_f1_val - }, - "retrieval": { - "num_docs": len(retrieved_info), - "context_length": len(context_text) - }, - "context_tokens": len(context_text.split()), - "timing": { - "search_ms": search_latency, - "llm_ms": llm_latency - } - }) - - return samples - - -# ============================================================================ -# Step 5: Aggregate Results -# ============================================================================ - -def step_aggregate_results(samples: List[Dict[str, Any]]) -> Dict[str, Any]: - """Aggregate metrics from all samples.""" - print(f"\n{'='*60}") - print("📊 Aggregating Results") - print(f"{'='*60}\n") - - if not samples: - return { - "overall_metrics": {}, - "by_category": {}, - "latency": {}, - "context_stats": {} - } - - # Extract metrics - f1_scores = [s["metrics"]["f1"] for s in samples] - bleu1_scores = [s["metrics"]["bleu1"] for s in samples] - jaccard_scores = [s["metrics"]["jaccard"] for s in samples] - locomo_f1_scores = [s["metrics"]["locomo_f1"] for s in samples] - - # Extract timing - latencies_search = [s["timing"]["search_ms"] for s in samples] - latencies_llm = [s["timing"]["llm_ms"] for s in samples] - - # Extract context stats - context_counts = [s["retrieval"]["num_docs"] for s in samples] - context_chars = [s["retrieval"]["context_length"] for s in samples] - context_tokens = [s["context_tokens"] for s in samples] - - # Overall metrics - overall_metrics = { - "f1": sum(f1_scores) / len(f1_scores) if f1_scores else 0.0, - "bleu1": sum(bleu1_scores) / len(bleu1_scores) if 
bleu1_scores else 0.0, - "jaccard": sum(jaccard_scores) / len(jaccard_scores) if jaccard_scores else 0.0, - "locomo_f1": sum(locomo_f1_scores) / len(locomo_f1_scores) if locomo_f1_scores else 0.0 - } - - # Per-category metrics - category_data: Dict[str, Dict[str, List[float]]] = {} - for sample in samples: - cat = sample["category"] - if cat not in category_data: - category_data[cat] = { - "f1": [], - "bleu1": [], - "jaccard": [], - "locomo_f1": [] - } - category_data[cat]["f1"].append(sample["metrics"]["f1"]) - category_data[cat]["bleu1"].append(sample["metrics"]["bleu1"]) - category_data[cat]["jaccard"].append(sample["metrics"]["jaccard"]) - category_data[cat]["locomo_f1"].append(sample["metrics"]["locomo_f1"]) - - by_category: Dict[str, Dict[str, Any]] = {} - for cat, metrics_lists in category_data.items(): - by_category[cat] = { - "count": len(metrics_lists["f1"]), - "f1": sum(metrics_lists["f1"]) / len(metrics_lists["f1"]), - "bleu1": sum(metrics_lists["bleu1"]) / len(metrics_lists["bleu1"]), - "jaccard": sum(metrics_lists["jaccard"]) / len(metrics_lists["jaccard"]), - "locomo_f1": sum(metrics_lists["locomo_f1"]) / len(metrics_lists["locomo_f1"]) - } - - # Latency statistics - latency = { - "search": latency_stats(latencies_search), - "llm": latency_stats(latencies_llm) - } - - # Context statistics - context_stats = { - "avg_retrieved_docs": sum(context_counts) / len(context_counts) if context_counts else 0.0, - "avg_context_chars": sum(context_chars) / len(context_chars) if context_chars else 0.0, - "avg_context_tokens": sum(context_tokens) / len(context_tokens) if context_tokens else 0.0 - } - - return { - "overall_metrics": overall_metrics, - "by_category": by_category, - "latency": latency, - "context_stats": context_stats - } - - -# ============================================================================ -# Step 6: Result Saving -# ============================================================================ - -def step_save_results( - result: 
Dict[str, Any], - output_dir: Optional[str] -) -> str: - """ - Save evaluation results to JSON file. - - Args: - result: Complete result dictionary - output_dir: Directory to save results (uses default if None) - - Returns: - Path to saved file - """ - if output_dir is None: - # Use absolute path to ensure results are saved in the correct location - script_dir = Path(__file__).resolve().parent - output_dir = script_dir / "results" - else: - # Convert to Path object - output_dir = Path(output_dir) - # If relative path, make it relative to script directory - if not output_dir.is_absolute(): - script_dir = Path(__file__).resolve().parent - output_dir = script_dir / output_dir - - # Create directory if it doesn't exist - output_dir.mkdir(parents=True, exist_ok=True) - - timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S") - output_path = output_dir / f"locomo_{timestamp_str}.json" - - try: - with open(output_path, "w", encoding="utf-8") as f: - json.dump(result, f, ensure_ascii=False, indent=2) - print(f"✅ Results saved to: {output_path}\n") - return str(output_path) - except Exception as e: - print(f"❌ Failed to save results: {e}") - print("📊 Printing results to console instead:\n") - print(json.dumps(result, ensure_ascii=False, indent=2)) - return "" - - -# ============================================================================ -# Main Orchestration Function -# ============================================================================ - - -async def run_locomo_benchmark( - sample_size: int = 20, - end_user_id: Optional[str] = None, - search_type: str = "hybrid", - search_limit: int = 12, - context_char_budget: int = 8000, - reset_group: bool = False, - skip_ingest: bool = False, - output_dir: Optional[str] = None, - max_ingest_messages: Optional[int] = None -) -> Dict[str, Any]: - """ - Run LoCoMo benchmark evaluation. - - This function orchestrates the complete evaluation pipeline by calling - well-defined step functions: - 1. 
Load LoCoMo dataset (only QA pairs from first conversation) - 2. Ingest conversations into database (unless skip_ingest=True) - 3. Initialize clients (Neo4j, LLM, Embedder) - 4. Process all questions (retrieve, generate, calculate metrics) - 5. Aggregate results - 6. Save results to file - - Note: By default, only the first conversation is ingested into the database, - and only QA pairs from that conversation are evaluated. This ensures that - all questions have corresponding memory in the database for retrieval. - - Args: - sample_size: Number of QA pairs to evaluate (from first conversation) - end_user_id: Database end_user ID for retrieval (uses default if None) - search_type: "keyword", "embedding", or "hybrid" - search_limit: Max documents to retrieve per query - context_char_budget: Max characters for context - reset_group: Whether to clear and re-ingest data - skip_ingest: If True, skip data ingestion and use existing data in Neo4j - output_dir: Directory to save results (uses default if None) - max_ingest_messages: Max messages per dialogue to ingest (for testing, None = all) - - Returns: - Dictionary with evaluation results including metrics, timing, and samples - """ - # Use default end_user_id if not provided - # 优先级:命令行参数 > LOCOMO_END_USER_ID > EVAL_END_USER_ID > 默认值 - if end_user_id is None: - end_user_id = os.getenv("LOCOMO_END_USER_ID") or os.getenv("EVAL_END_USER_ID", "locomo_benchmark") - - # Get model IDs from config - llm_id = os.getenv("EVAL_LLM_ID", "6dc52e1b-9cec-4194-af66-a74c6307fc3f") - embedding_id = os.getenv("EVAL_EMBEDDING_ID", "e2a6392d-ca63-4d59-a523-647420b59cb2") - - # Determine data path - dataset_dir = Path(__file__).resolve().parent.parent / "dataset" - data_path = dataset_dir / "locomo10.json" - if not os.path.exists(data_path): - raise FileNotFoundError( - f"数据集文件不存在: {data_path}\n" - f"请将 locomo10.json 放置在: {dataset_dir}" - ) - - # Print configuration - print(f"\n{'='*60}") - print("🚀 Starting LoCoMo Benchmark Evaluation") - 
print(f"{'='*60}") - print("📊 Configuration:") - print(f" Sample size: {sample_size}") - print(f" End User ID: {end_user_id}") - print(f" Search type: {search_type}") - print(f" Search limit: {search_limit}") - print(f" Context budget: {context_char_budget} chars") - print(f" Data path: {data_path}") - if max_ingest_messages: - print(f" Max ingest messages: {max_ingest_messages} (testing mode)") - print(f"{'='*60}\n") - - # Step 1: Load LoCoMo data (加载数据) - try: - qa_items = step_load_data(data_path, sample_size) - except Exception as e: - print(f"❌ Failed to load data: {e}") - return { - "error": f"Data loading failed: {e}", - "timestamp": datetime.now().isoformat() - } - - # Step 2: Ingest data if needed(数据摄入) - await step_ingest_data(data_path, end_user_id, skip_ingest, reset_group, max_ingest_messages) - - # Step 3: Initialize clients (初始化客户端) - connector, llm_client, embedder = step_initialize_clients(llm_id, embedding_id) - - # Step 4: Process all questions (处理所有问题) - try: - samples = await step_process_all_questions( - qa_items=qa_items, - end_user_id=end_user_id, - search_type=search_type, - search_limit=search_limit, - context_char_budget=context_char_budget, - connector=connector, - embedder=embedder, - llm_client=llm_client - ) - finally: - await connector.close() - - # Step 5: Aggregate results (聚合答案) - aggregated = step_aggregate_results(samples) - - # Build final result dictionary - result = { - "dataset": "locomo", - "sample_size": len(qa_items), - "timestamp": datetime.now().isoformat(), - "params": { - "end_user_id": end_user_id, - "search_type": search_type, - "search_limit": search_limit, - "context_char_budget": context_char_budget, - "llm_id": llm_id, - "embedding_id": embedding_id - }, - "overall_metrics": aggregated["overall_metrics"], - "by_category": aggregated["by_category"], - "latency": aggregated["latency"], - "context_stats": aggregated["context_stats"], - "samples": samples - } - - # Step 6: Save results (保存结果) - 
step_save_results(result, output_dir) - - return result - - -def main(): - """ - Parse command-line arguments and run benchmark. - - This function provides a CLI interface for running LoCoMo benchmarks - with configurable parameters. - - Configuration priority: Command-line args > Environment variables > Code defaults - """ - # Load environment variables first - load_dotenv() - - # Get defaults from environment variables - env_sample_size = os.getenv("LOCOMO_SAMPLE_SIZE") - env_search_limit = os.getenv("LOCOMO_SEARCH_LIMIT") - env_context_budget = os.getenv("LOCOMO_CONTEXT_CHAR_BUDGET") - env_output_dir = os.getenv("LOCOMO_OUTPUT_DIR") - env_skip_ingest = os.getenv("LOCOMO_SKIP_INGEST", "false").lower() in ("true", "1", "yes") - - # Convert to appropriate types with fallback to code defaults - default_sample_size = int(env_sample_size) if env_sample_size else 20 - default_search_limit = int(env_search_limit) if env_search_limit else 12 - default_context_budget = int(env_context_budget) if env_context_budget else 8000 - default_output_dir = env_output_dir if env_output_dir else None - - parser = argparse.ArgumentParser( - description="Run LoCoMo benchmark evaluation", - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--sample_size", - type=int, - default=default_sample_size, - help=f"Number of QA pairs to evaluate (env: LOCOMO_SAMPLE_SIZE={env_sample_size or 'not set'}, 0 for all)" - ) - parser.add_argument( - "--end_user_id", - type=str, - default=None, - help="Database end user ID for retrieval (uses LOCOMO_END_USER_ID or EVAL_END_USER_ID if not specified)" - ) - parser.add_argument( - "--search_type", - type=str, - default="hybrid", - choices=["keyword", "embedding", "hybrid"], - help="Search strategy to use" - ) - parser.add_argument( - "--search_limit", - type=int, - default=default_search_limit, - help=f"Maximum number of documents to retrieve per query (env: LOCOMO_SEARCH_LIMIT={env_search_limit or 'not set'})" - ) - 
parser.add_argument( - "--context_char_budget", - type=int, - default=default_context_budget, - help=f"Maximum characters for context (env: LOCOMO_CONTEXT_CHAR_BUDGET={env_context_budget or 'not set'})" - ) - parser.add_argument( - "--reset_group", - action="store_true", - help="Clear and re-ingest data (not implemented)" - ) - parser.add_argument( - "--skip_ingest", - action="store_true", - default=env_skip_ingest, - help=f"Skip data ingestion and use existing data in Neo4j (env: LOCOMO_SKIP_INGEST={os.getenv('LOCOMO_SKIP_INGEST', 'false')})" - ) - parser.add_argument( - "--output_dir", - type=str, - default=default_output_dir, - help=f"Directory to save results (env: LOCOMO_OUTPUT_DIR={env_output_dir or 'not set'})" - ) - parser.add_argument( - "--max_ingest_messages", - type=int, - default=None, - help="Maximum messages per dialogue to ingest (for testing, default: all messages)" - ) - - args = parser.parse_args() - - # Run benchmark - result = asyncio.run(run_locomo_benchmark( - sample_size=args.sample_size, - end_user_id=args.end_user_id, - search_type=args.search_type, - search_limit=args.search_limit, - context_char_budget=args.context_char_budget, - reset_group=args.reset_group, - skip_ingest=args.skip_ingest, - output_dir=args.output_dir, - max_ingest_messages=args.max_ingest_messages - )) - - # Print summary - print(f"\n{'='*60}") - - # Check if there was an error - if 'error' in result: - print("❌ Benchmark Failed!") - print(f"{'='*60}") - print(f"Error: {result['error']}") - return - - print("🎉 Benchmark Complete!") - print(f"{'='*60}") - print("📊 Final Results:") - print(f" Sample size: {result.get('sample_size', 0)}") - print(f" F1: {result['overall_metrics']['f1']:.3f}") - print(f" BLEU-1: {result['overall_metrics']['bleu1']:.3f}") - print(f" Jaccard: {result['overall_metrics']['jaccard']:.3f}") - print(f" LoCoMo F1: {result['overall_metrics']['locomo_f1']:.3f}") - - if result.get('context_stats'): - print("\n📈 Context Statistics:") - print(f" Avg 
retrieved docs: {result['context_stats']['avg_retrieved_docs']:.1f}") - print(f" Avg context chars: {result['context_stats']['avg_context_chars']:.0f}") - print(f" Avg context tokens: {result['context_stats']['avg_context_tokens']:.0f}") - - if result.get('latency'): - print("\n⏱️ Latency Statistics:") - print(f" Search - Mean: {result['latency']['search']['mean']:.1f}ms, " - f"P50: {result['latency']['search']['p50']:.1f}ms, " - f"P95: {result['latency']['search']['p95']:.1f}ms") - print(f" LLM - Mean: {result['latency']['llm']['mean']:.1f}ms, " - f"P50: {result['latency']['llm']['p50']:.1f}ms, " - f"P95: {result['latency']['llm']['p95']:.1f}ms") - - if result.get('by_category'): - print("\n📂 Results by Category:") - for cat, metrics in result['by_category'].items(): - print(f" {cat}:") - print(f" Count: {metrics['count']}") - print(f" F1: {metrics['f1']:.3f}") - print(f" LoCoMo F1: {metrics['locomo_f1']:.3f}") - print(f" Jaccard: {metrics['jaccard']:.3f}") - - print(f"\n{'='*60}\n") - - -if __name__ == "__main__": - main() diff --git a/api/app/core/memory/evaluation/locomo/locomo_metrics.py b/api/app/core/memory/evaluation/locomo/locomo_metrics.py deleted file mode 100644 index 20d5f2b5..00000000 --- a/api/app/core/memory/evaluation/locomo/locomo_metrics.py +++ /dev/null @@ -1,225 +0,0 @@ -""" -LoCoMo-specific metric calculations. - -This module provides clean, simplified implementations of metrics used for -LoCoMo benchmark evaluation, including text normalization and F1 score variants. -""" - -import re -from typing import Dict, Any - - -def normalize_text(text: str) -> str: - """ - Normalize text for LoCoMo evaluation. 
- - Normalization steps: - - Convert to lowercase - - Remove commas - - Remove stop words (a, an, the, and) - - Remove punctuation - - Normalize whitespace - - Args: - text: Input text to normalize - - Returns: - Normalized text string with consistent formatting - - Examples: - >>> normalize_text("The cat, and the dog") - 'cat dog' - >>> normalize_text("Hello, World!") - 'hello world' - """ - # Ensure input is a string - text = str(text) if text is not None else "" - - # Convert to lowercase - text = text.lower() - - # Remove commas - text = re.sub(r"[\,]", " ", text) - - # Remove stop words - text = re.sub(r"\b(a|an|the|and)\b", " ", text) - - # Remove punctuation (keep only word characters and whitespace) - text = re.sub(r"[^\w\s]", " ", text) - - # Normalize whitespace (collapse multiple spaces to single space) - text = " ".join(text.split()) - - return text - - -def locomo_f1_score(prediction: str, ground_truth: str) -> float: - """ - Calculate LoCoMo F1 score for single-answer questions. - - Uses token-level precision and recall based on normalized text. - Treats tokens as sets (no duplicate counting). 
- - Args: - prediction: Model's predicted answer - ground_truth: Correct answer - - Returns: - F1 score between 0.0 and 1.0 - - Examples: - >>> locomo_f1_score("Paris", "Paris") - 1.0 - >>> locomo_f1_score("The cat", "cat") - 1.0 - >>> locomo_f1_score("dog", "cat") - 0.0 - """ - # Ensure inputs are strings - pred_str = str(prediction) if prediction is not None else "" - truth_str = str(ground_truth) if ground_truth is not None else "" - - # Normalize and tokenize - pred_tokens = normalize_text(pred_str).split() - truth_tokens = normalize_text(truth_str).split() - - # Handle empty cases - if not pred_tokens or not truth_tokens: - return 0.0 - - # Convert to sets for comparison - pred_set = set(pred_tokens) - truth_set = set(truth_tokens) - - # Calculate true positives (intersection) - true_positives = len(pred_set & truth_set) - - # Calculate precision and recall - precision = true_positives / len(pred_set) if pred_set else 0.0 - recall = true_positives / len(truth_set) if truth_set else 0.0 - - # Calculate F1 score - if precision + recall == 0: - return 0.0 - - f1 = 2 * precision * recall / (precision + recall) - return f1 - - -def locomo_multi_f1(prediction: str, ground_truth: str) -> float: - """ - Calculate LoCoMo F1 score for multi-answer questions. - - Handles comma-separated answers by: - 1. Splitting both prediction and ground truth by commas - 2. For each ground truth answer, finding the best matching prediction - 3. 
Averaging the F1 scores across all ground truth answers - - Args: - prediction: Model's predicted answer (may contain multiple comma-separated answers) - ground_truth: Correct answer (may contain multiple comma-separated answers) - - Returns: - Average F1 score across all ground truth answers (0.0 to 1.0) - - Examples: - >>> locomo_multi_f1("Paris, London", "Paris, London") - 1.0 - >>> locomo_multi_f1("Paris", "Paris, London") - 0.5 - >>> locomo_multi_f1("Paris, Berlin", "Paris, London") - 0.5 - """ - # Ensure inputs are strings - pred_str = str(prediction) if prediction is not None else "" - truth_str = str(ground_truth) if ground_truth is not None else "" - - # Split by commas and strip whitespace - predictions = [p.strip() for p in pred_str.split(',') if p.strip()] - ground_truths = [g.strip() for g in truth_str.split(',') if g.strip()] - - # Handle empty cases - if not predictions or not ground_truths: - return 0.0 - - # For each ground truth, find the best matching prediction - f1_scores = [] - for gt in ground_truths: - # Calculate F1 with each prediction and take the maximum - best_f1 = max(locomo_f1_score(pred, gt) for pred in predictions) - f1_scores.append(best_f1) - - # Return average F1 across all ground truths - return sum(f1_scores) / len(f1_scores) - - -def get_category_name(item: Dict[str, Any]) -> str: - """ - Extract and normalize category name from QA item. - - Handles both numeric categories (1-4) and string categories with various formats. - Supports multiple field names: "cat", "category", "type". 
- - Category mapping: - - 1 or "multi-hop" -> "Multi-Hop" - - 2 or "temporal" -> "Temporal" - - 3 or "open domain" -> "Open Domain" - - 4 or "single-hop" -> "Single-Hop" - - Args: - item: QA item dictionary containing category information - - Returns: - Standardized category name or "unknown" if not found - - Examples: - >>> get_category_name({"category": 1}) - 'Multi-Hop' - >>> get_category_name({"cat": "temporal"}) - 'Temporal' - >>> get_category_name({"type": "Single-Hop"}) - 'Single-Hop' - """ - # Numeric category mapping - CATEGORY_MAP = { - 1: "Multi-Hop", - 2: "Temporal", - 3: "Open Domain", - 4: "Single-Hop", - } - - # String category aliases (case-insensitive) - TYPE_ALIASES = { - "single-hop": "Single-Hop", - "singlehop": "Single-Hop", - "single hop": "Single-Hop", - "multi-hop": "Multi-Hop", - "multihop": "Multi-Hop", - "multi hop": "Multi-Hop", - "open domain": "Open Domain", - "opendomain": "Open Domain", - "temporal": "Temporal", - } - - # Try "cat" field first (string category) - cat = item.get("cat") - if isinstance(cat, str) and cat.strip(): - name = cat.strip() - lower = name.lower() - return TYPE_ALIASES.get(lower, name) - - # Try "category" field (can be int or string) - cat_num = item.get("category") - if isinstance(cat_num, int): - return CATEGORY_MAP.get(cat_num, "unknown") - elif isinstance(cat_num, str) and cat_num.strip(): - lower = cat_num.strip().lower() - return TYPE_ALIASES.get(lower, cat_num.strip()) - - # Try "type" field as fallback - cat_type = item.get("type") - if isinstance(cat_type, str) and cat_type.strip(): - lower = cat_type.strip().lower() - return TYPE_ALIASES.get(lower, cat_type.strip()) - - return "unknown" diff --git a/api/app/core/memory/evaluation/locomo/locomo_test.py b/api/app/core/memory/evaluation/locomo/locomo_test.py deleted file mode 100644 index 2cb0664c..00000000 --- a/api/app/core/memory/evaluation/locomo/locomo_test.py +++ /dev/null @@ -1,864 +0,0 @@ -# file name: check_neo4j_connection_fixed.py -import 
asyncio -import os -import sys -import json -import time -import math -import re -from datetime import datetime, timedelta -from typing import List, Dict, Any -from pathlib import Path -from dotenv import load_dotenv - -# Load main .env -load_dotenv() - -# Load evaluation config -eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation" -if eval_config_path.exists(): - load_dotenv(eval_config_path, override=True) - print(f"✅ 加载评估配置: {eval_config_path}") - -# Get group_id from config -group_id = os.getenv("EVAL_GROUP_ID", "locomo_test") -print(f"✅ 使用配置的 group_id: {group_id}") - -# 首先定义 _loc_normalize 函数,因为其他函数依赖它 -def _loc_normalize(text: str) -> str: - text = str(text) if text is not None else "" - text = text.lower() - text = re.sub(r"[\,]", " ", text) - text = re.sub(r"\b(a|an|the|and)\b", " ", text) - text = re.sub(r"[^\w\s]", " ", text) - text = " ".join(text.split()) - return text - -# 尝试从 metrics.py 导入基础指标 -try: - from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard - print("✅ 从 metrics.py 导入基础指标成功") -except ImportError as e: - print(f"❌ 从 metrics.py 导入失败: {e}") - # 回退到本地实现 - def f1_score(pred: str, ref: str) -> float: - pred_str = str(pred) if pred is not None else "" - ref_str = str(ref) if ref is not None else "" - - p_tokens = _loc_normalize(pred_str).split() - r_tokens = _loc_normalize(ref_str).split() - if not p_tokens and not r_tokens: - return 1.0 - if not p_tokens or not r_tokens: - return 0.0 - p_set = set(p_tokens) - r_set = set(r_tokens) - tp = len(p_set & r_set) - precision = tp / len(p_set) if p_set else 0.0 - recall = tp / len(r_set) if r_set else 0.0 - if precision + recall == 0: - return 0.0 - return 2 * precision * recall / (precision + recall) - - def bleu1(pred: str, ref: str) -> float: - pred_str = str(pred) if pred is not None else "" - ref_str = str(ref) if ref is not None else "" - - p_tokens = _loc_normalize(pred_str).split() - r_tokens = _loc_normalize(ref_str).split() - if not p_tokens: 
- return 0.0 - - r_counts = {} - for t in r_tokens: - r_counts[t] = r_counts.get(t, 0) + 1 - - clipped = 0 - p_counts = {} - for t in p_tokens: - p_counts[t] = p_counts.get(t, 0) + 1 - - for t, c in p_counts.items(): - clipped += min(c, r_counts.get(t, 0)) - - precision = clipped / max(len(p_tokens), 1) - ref_len = len(r_tokens) - pred_len = len(p_tokens) - - if pred_len > ref_len or pred_len == 0: - bp = 1.0 - else: - bp = math.exp(1 - ref_len / max(pred_len, 1)) - - return bp * precision - - def jaccard(pred: str, ref: str) -> float: - pred_str = str(pred) if pred is not None else "" - ref_str = str(ref) if ref is not None else "" - - p = set(_loc_normalize(pred_str).split()) - r = set(_loc_normalize(ref_str).split()) - if not p and not r: - return 1.0 - if not p or not r: - return 0.0 - return len(p & r) / len(p | r) - -# 尝试从 qwen_search_eval.py 导入 LoCoMo 特定指标 -try: - from app.core.memory.evaluation.locomo.qwen_search_eval import loc_f1_score, loc_multi_f1, _resolve_relative_times - print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功") -except ImportError as e: - print(f"❌ 从 qwen_search_eval.py 导入失败: {e}") - # 回退到本地实现 LoCoMo 特定函数 - def _resolve_relative_times(text: str, anchor: datetime) -> str: - t = str(text) if text is not None else "" - t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE) - - def _ago_repl(m: re.Match[str]) -> str: - n = int(m.group(1)) - return (anchor - timedelta(days=n)).date().isoformat() - def _in_repl(m: re.Match[str]) -> str: - n = int(m.group(1)) - return (anchor + timedelta(days=n)).date().isoformat() - - t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE) - t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE) - t = re.sub(r"\blast\s+week\b", (anchor - 
timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE) - return t - - def loc_f1_score(prediction: str, ground_truth: str) -> float: - p_tokens = _loc_normalize(prediction).split() - g_tokens = _loc_normalize(ground_truth).split() - if not p_tokens or not g_tokens: - return 0.0 - p = set(p_tokens) - g = set(g_tokens) - tp = len(p & g) - precision = tp / len(p) if p else 0.0 - recall = tp / len(g) if g else 0.0 - return (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0 - - def loc_multi_f1(prediction: str, ground_truth: str) -> float: - predictions = [p.strip() for p in str(prediction).split(',') if p.strip()] - ground_truths = [g.strip() for g in str(ground_truth).split(',') if g.strip()] - if not predictions or not ground_truths: - return 0.0 - def _f1(a: str, b: str) -> float: - return loc_f1_score(a, b) - vals = [] - for gt in ground_truths: - vals.append(max(_f1(pred, gt) for pred in predictions)) - return sum(vals) / len(vals) - - -def smart_context_selection(contexts: List[str], question: str, max_chars: int = 8000) -> str: - """基于问题关键词智能选择上下文""" - if not contexts: - return "" - - # 提取问题关键词(只保留有意义的词) - question_lower = question.lower() - stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'} - question_words = set(re.findall(r'\b\w+\b', question_lower)) - question_words = {word for word in question_words if word not in stop_words and len(word) > 2} - - print(f"🔍 问题关键词: {question_words}") - - # 给每个上下文打分 - scored_contexts = [] - for i, context in enumerate(contexts): - context_lower = context.lower() - score = 0 - - # 关键词匹配得分 - keyword_matches = 0 - for word in question_words: - if word in context_lower: - keyword_matches += 1 - # 关键词出现次数越多,得分越高 - score += context_lower.count(word) * 2 - - # 上下文长度得分(适中的长度更好) - context_len = 
len(context) - if 100 < context_len < 2000: # 理想长度范围 - score += 5 - elif context_len >= 2000: # 太长可能包含无关信息 - score += 2 - - # 如果是前几个上下文,给予额外分数(通常相关性更高) - if i < 3: - score += 3 - - scored_contexts.append((score, context, keyword_matches)) - - # 按得分排序 - scored_contexts.sort(key=lambda x: x[0], reverse=True) - - # 选择高得分的上下文,直到达到字符限制 - selected = [] - total_chars = 0 - selected_count = 0 - - print("📊 上下文相关性分析:") - for score, context, matches in scored_contexts[:5]: # 只显示前5个 - print(f" - 得分: {score}, 关键词匹配: {matches}, 长度: {len(context)}") - - for score, context, matches in scored_contexts: - if total_chars + len(context) <= max_chars: - selected.append(context) - total_chars += len(context) - selected_count += 1 - else: - # 如果这个上下文得分很高但放不下,尝试截取 - if score > 10 and total_chars < max_chars - 500: - remaining = max_chars - total_chars - # 找到包含关键词的部分 - lines = context.split('\n') - relevant_lines = [] - current_chars = 0 - - for line in lines: - line_lower = line.lower() - line_relevance = any(word in line_lower for word in question_words) - - if line_relevance and current_chars < remaining - 100: - relevant_lines.append(line) - current_chars += len(line) - - if relevant_lines: - truncated = '\n'.join(relevant_lines) - if len(truncated) > 100: # 确保有足够内容 - selected.append(truncated + "\n[相关内容截断...]") - total_chars += len(truncated) - selected_count += 1 - break # 不再尝试添加更多上下文 - - result = "\n\n".join(selected) - print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {total_chars}字符") - return result - - -def get_dynamic_search_params(question: str, question_index: int, total_questions: int): - """根据问题复杂度和进度动态调整检索参数""" - - # 分析问题复杂度 - word_count = len(question.split()) - has_temporal = any(word in question.lower() for word in ['when', 'date', 'time', 'ago']) - has_multi_hop = any(word in question.lower() for word in ['and', 'both', 'also', 'while']) - - # 根据进度调整 - 后期问题可能需要更精确的检索 - progress_factor = question_index / total_questions - - base_limit = 12 - if has_temporal and has_multi_hop: 
- base_limit = 20 - elif word_count > 8: - base_limit = 16 - - # 随着测试进行,逐渐收紧检索范围 - adjusted_limit = max(8, int(base_limit * (1 - progress_factor * 0.3))) - - # 动态调整最大字符数 - max_chars = 8000 + 4000 * (1 - progress_factor) - - return { - "limit": adjusted_limit, - "max_chars": int(max_chars) - } - - -class EnhancedEvaluationMonitor: - def __init__(self, reset_interval=5, performance_threshold=0.6): - self.question_count = 0 - self.reset_interval = reset_interval - self.performance_threshold = performance_threshold - self.consecutive_low_scores = 0 - self.performance_history = [] - self.recent_f1_scores = [] - - def should_reset_connections(self, current_f1=None): - """基于计数和性能双重判断""" - # 定期重置 - if self.question_count % self.reset_interval == 0: - return True - - # 性能驱动的重置 - if current_f1 is not None and current_f1 < self.performance_threshold: - self.consecutive_low_scores += 1 - if self.consecutive_low_scores >= 2: # 连续2个低分就重置 - print("🚨 连续低分,触发紧急重置") - self.consecutive_low_scores = 0 - return True - else: - self.consecutive_low_scores = 0 - - return False - - def record_performance(self, question_index, metrics, context_length, retrieved_docs): - """记录性能指标,检测衰减""" - self.performance_history.append({ - 'index': question_index, - 'metrics': metrics, - 'context_length': context_length, - 'retrieved_docs': retrieved_docs, - 'timestamp': time.time() - }) - - # 记录最近的F1分数 - self.recent_f1_scores.append(metrics['f1']) - if len(self.recent_f1_scores) > 5: - self.recent_f1_scores.pop(0) - - def get_recent_performance(self): - """获取近期平均性能""" - if not self.recent_f1_scores: - return 0.5 - return sum(self.recent_f1_scores) / len(self.recent_f1_scores) - - def get_performance_trend(self): - """分析性能趋势""" - if len(self.performance_history) < 2: - return "stable" - - recent_metrics = [item['metrics']['f1'] for item in self.performance_history[-5:]] - earlier_metrics = [item['metrics']['f1'] for item in self.performance_history[-10:-5]] - - if len(recent_metrics) < 2 or 
len(earlier_metrics) < 2: - return "stable" - - recent_avg = sum(recent_metrics) / len(recent_metrics) - earlier_avg = sum(earlier_metrics) / len(earlier_metrics) - - if recent_avg < earlier_avg * 0.8: - return "degrading" - elif recent_avg > earlier_avg * 1.1: - return "improving" - else: - return "stable" - - -def get_enhanced_search_params(question: str, question_index: int, total_questions: int, recent_performance: float): - """基于问题复杂度和近期性能动态调整检索参数""" - - # 基础参数 - base_params = get_dynamic_search_params(question, question_index, total_questions) - - # 性能自适应调整 - if recent_performance < 0.5: # 近期表现差 - # 增加检索范围,尝试获取更多上下文 - base_params["limit"] = min(base_params["limit"] + 5, 25) - base_params["max_chars"] = min(base_params["max_chars"] + 2000, 12000) - print(f"📈 性能自适应:增加检索范围 (limit={base_params['limit']}, max_chars={base_params['max_chars']})") - - elif recent_performance > 0.8: # 近期表现好 - # 收紧检索,提高精度 - base_params["limit"] = max(base_params["limit"] - 2, 8) - base_params["max_chars"] = max(base_params["max_chars"] - 1000, 6000) - print(f"🎯 性能自适应:提高检索精度 (limit={base_params['limit']}, max_chars={base_params['max_chars']})") - - # 中间阶段特殊处理 - mid_sequence_factor = abs(question_index / total_questions - 0.5) - if mid_sequence_factor < 0.2: # 在中间30%的问题 - print("🎯 中间阶段:使用更精确的检索策略") - base_params["limit"] = max(base_params["limit"] - 2, 10) # 减少数量,提高质量 - base_params["max_chars"] = max(base_params["max_chars"] - 1000, 7000) - - return base_params - - -def enhanced_context_selection(contexts: List[str], question: str, question_index: int, total_questions: int, max_chars: int = 8000) -> str: - """考虑问题序列位置的智能选择""" - - if not contexts: - return "" - - # 在序列中间阶段使用更严格的筛选 - mid_sequence_factor = abs(question_index / total_questions - 0.5) # 距离中心的距离 - - if mid_sequence_factor < 0.2: # 在中间30%的问题 - print("🎯 中间阶段:使用严格上下文筛选") - - # 提取问题关键词 - question_lower = question.lower() - stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 
'the', 'a', 'an', 'and', 'or', 'but'} - question_words = set(re.findall(r'\b\w+\b', question_lower)) - question_words = {word for word in question_words if word not in stop_words and len(word) > 2} - - # 只保留高度相关的上下文 - filtered_contexts = [] - for context in contexts: - context_lower = context.lower() - relevance_score = sum(3 if word in context_lower else 0 for word in question_words) - - # 额外加分给包含数字、日期的上下文(对事实性问题更重要) - if any(char.isdigit() for char in context): - relevance_score += 2 - - # 提高阈值:只有得分>=3的上下文才保留 - if relevance_score >= 3: - filtered_contexts.append(context) - else: - print(f" - 过滤低分上下文: 得分={relevance_score}") - - contexts = filtered_contexts - print(f"🔍 严格筛选后保留 {len(contexts)} 个上下文") - - # 使用原有的智能选择逻辑 - return smart_context_selection(contexts, question, max_chars) - - -async def run_enhanced_evaluation(): - """使用增强方法进行完整评估 - 解决中间性能衰减问题""" - from dotenv import load_dotenv - from uuid import UUID - from datetime import datetime - from dataclasses import dataclass - - # 修正导入路径:使用 app.core.memory.src 前缀 - from app.repositories.neo4j.neo4j_connector import Neo4jConnector - from app.repositories.neo4j.graph_search import search_graph_by_embedding - from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient - from app.core.models.base import RedBearModelConfig - from app.core.memory.utils.llm.llm_utils import get_llm_client - from app.core.memory.utils.config.config_utils import get_embedder_config - from app.schemas.memory_config_schema import MemoryConfig - from app.services.memory_config_service import MemoryConfigService - - # Get model IDs from config - llm_id = os.getenv("EVAL_LLM_ID", "6dc52e1b-9cec-4194-af66-a74c6307fc3f") - embedding_id = os.getenv("EVAL_EMBEDDING_ID", "e2a6392d-ca63-4d59-a523-647420b59cb2") - - # 加载数据 - 使用统一的 dataset 目录 - data_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "dataset", "locomo10.json") - - if not os.path.exists(data_path): - raise FileNotFoundError( - f"数据集文件不存在: {data_path}\n" - 
f"请将 locomo10.json 放置在: api/app/core/memory/evaluation/dataset/" - ) - - print(f"✅ 找到数据文件: {data_path}") - - with open(data_path, "r", encoding="utf-8") as f: - raw = json.load(f) - - qa_items = [] - if isinstance(raw, list): - for entry in raw: - qa_items.extend(entry.get("qa", [])) - else: - qa_items.extend(raw.get("qa", [])) - - # 测试多少个问题 - 可通过环境变量设置 - sample_size = int(os.getenv("LOCOMO_SAMPLE_SIZE", "20")) - items = qa_items[:sample_size] - print(f"📊 将测试 {len(items)} 个问题(总共 {len(qa_items)} 个可用)") - - # 初始化增强监控器 - monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6) - - # 获取数据库会话并初始化 LLM 客户端 - from app.db import get_db - db = next(get_db()) - - try: - llm = get_llm_client(llm_id, db) - - # 初始化embedder - cfg_dict = get_embedder_config(embedding_id, db) - embedder = OpenAIEmbedderClient( - model_config=RedBearModelConfig.model_validate(cfg_dict) - ) - - # 🔧 创建 MemoryConfig 对象用于搜索 - # 方案1:如果有配置ID,从数据库加载 - config_id = os.getenv("EVAL_CONFIG_ID") - if config_id: - print(f"📋 从数据库加载配置 ID: {config_id}") - memory_config_service = MemoryConfigService(db) - memory_config = memory_config_service.load_memory_config(config_id, service_name="locomo_test") - else: - # 方案2:创建临时配置对象用于测试 - print(f"📋 创建临时测试配置") - from uuid import UUID - from datetime import datetime - - # 将字符串 ID 转换为 UUID - try: - embedding_uuid = UUID(embedding_id) - llm_uuid = UUID(llm_id) - except ValueError as e: - raise ValueError(f"无效的 UUID 格式: {e}") - - memory_config = MemoryConfig( - config_id=1, # 临时 ID - config_name="locomo_test_config", - workspace_id=UUID("00000000-0000-0000-0000-000000000000"), # 临时 workspace - workspace_name="test_workspace", - tenant_id=UUID("00000000-0000-0000-0000-000000000000"), # 临时 tenant - embedding_model_id=embedding_uuid, - embedding_model_name="test_embedding", - llm_model_id=llm_uuid, - llm_model_name="test_llm", - storage_type="neo4j", - chunker_strategy="RecursiveChunker", - reflexion_enabled=False, - reflexion_iteration_period=3, - 
reflexion_range="partial", - reflexion_baseline="Time", - loaded_at=datetime.now() - ) - - print(f"✅ MemoryConfig 已准备: embedding_id={memory_config.embedding_model_id}, llm_id={memory_config.llm_model_id}") - - # 初始化连接器 - connector = Neo4jConnector() - - # 初始化结果字典 - results = { - "questions": [], - "overall_metrics": {"f1": 0.0, "b1": 0.0, "j": 0.0, "loc_f1": 0.0}, - "category_metrics": {}, - "retrieval_stats": {"total_questions": len(items), "avg_context_length": 0, "avg_retrieved_docs": 0}, - "performance_trend": "stable", - "timestamp": datetime.now().isoformat(), - "enhanced_strategy": True - } - - total_f1 = 0.0 - total_bleu1 = 0.0 - total_jaccard = 0.0 - total_loc_f1 = 0.0 - total_context_length = 0 - total_retrieved_docs = 0 - category_stats = {} - - try: - for i, item in enumerate(items): - monitor.question_count += 1 - - # 获取近期性能用于重置判断 - recent_performance = monitor.get_recent_performance() - - # 增强的重置判断 - should_reset = monitor.should_reset_connections(current_f1=recent_performance) - if should_reset and i > 0: - print(f"🔄 重置Neo4j连接 (问题 {i+1}/{len(items)}, 近期性能: {recent_performance:.3f})...") - await connector.close() - connector = Neo4jConnector() # 创建新连接 - print("✅ 连接重置完成") - - q = item.get("question", "") - ref = item.get("answer", "") - ref_str = str(ref) if ref is not None else "" - - print(f"\n🔍 [{i+1}/{len(items)}] 问题: {q}") - print(f"✅ 真实答案: {ref_str}") - - # 分类别统计 - category = "Unknown" - if item.get("category") == 1: - category = "Multi-Hop" - elif item.get("category") == 2: - category = "Temporal" - elif item.get("category") == 3: - category = "Open Domain" - elif item.get("category") == 4: - category = "Single-Hop" - - # 增强的检索参数 - search_params = get_enhanced_search_params(q, i, len(items), recent_performance) - search_limit = search_params["limit"] - max_chars = search_params["max_chars"] - - print(f"🏷️ 类别: {category}, 检索参数: limit={search_limit}, max_chars={max_chars}") - - # 使用项目标准的混合检索方法 - t0 = time.time() - contexts_all = [] - - try: - # 
使用旧版本的搜索服务(重构前的版本) - from app.core.memory.src.search import run_hybrid_search - - print(f"🔀 使用混合搜索服务(旧版本)...") - print(f"📍 检索参数: group_id={group_id}, limit=20, search_type=hybrid") - print(f"📍 查询文本: {q}") - - search_results = await run_hybrid_search( - query_text=q, - search_type="hybrid", - end_user_id="locomo_sk", - limit=20, - include=["statements", "chunks", "entities", "summaries"], - output_path=None, - memory_config=memory_config, # 🔧 添加必需的 memory_config 参数 - rerank_alpha=0.6, # BM25权重 - use_forgetting_rerank=False, - use_llm_rerank=False - ) - - # 处理搜索结果 - 旧版本返回包含 reranked_results 的结构 - # 对于 hybrid 搜索,使用 reranked_results - if "reranked_results" in search_results: - reranked = search_results["reranked_results"] - chunks = reranked.get("chunks", []) - statements = reranked.get("statements", []) - entities = reranked.get("entities", []) - summaries = reranked.get("summaries", []) - else: - # 单一搜索类型的结果 - chunks = search_results.get("chunks", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - summaries = search_results.get("summaries", []) - - print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要") - - # 构建上下文:优先使用 chunks、statements 和 summaries - for c in chunks: - content = str(c.get("content", "")).strip() - if content: - contexts_all.append(content) - - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - - for sm in summaries: - summary_text = str(sm.get("summary", "")).strip() - if summary_text: - contexts_all.append(summary_text) - - # 实体摘要:最多加入前3个高分实体,避免噪声 - scored = [e for e in entities if e.get("score") is not None] - top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3] - if top_entities: - summary_lines = [] - for e in top_entities: - name = str(e.get("name", "")).strip() - etype = str(e.get("entity_type", "")).strip() - 
score = e.get("score") - if name: - meta = [] - if etype: - meta.append(f"type={etype}") - if isinstance(score, (int, float)): - meta.append(f"score={score:.3f}") - summary_lines.append(f"EntitySummary: {name}{(' [' + ' '.join(meta) + ']') if meta else ''}") - if summary_lines: - contexts_all.append("\n".join(summary_lines)) - - print(f"📊 有效上下文数量: {len(contexts_all)}") - except Exception as e: - print(f"❌ 检索失败: {e}") - import traceback - print(f"详细错误信息:\n{traceback.format_exc()}") - contexts_all = [] - - t1 = time.time() - search_time = (t1 - t0) * 1000 - - # 增强的上下文选择 - context_text = "" - if contexts_all: - # 使用增强的上下文选择 - context_text = enhanced_context_selection(contexts_all, q, i, len(items), max_chars=max_chars) - - # 如果智能选择后仍然过长,进行最终保护性截断 - if len(context_text) > max_chars: - print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断") - context_text = context_text[:max_chars] + "\n\n[最终截断...]" - - # 时间解析 - anchor_date = datetime(2023, 5, 8) # 使用固定日期确保一致性 - context_text = _resolve_relative_times(context_text, anchor_date) - - context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text - - print(f"📝 最终上下文长度: {len(context_text)} 字符") - - # 显示不同上下文的预览(不只是第一条) - print("🔍 上下文预览:") - for j, context in enumerate(contexts_all[:3]): # 显示前3个上下文 - preview = context[:150].replace('\n', ' ') - print(f" 上下文{j+1}: {preview}...") - - # 🔍 调试:检查答案是否在上下文中 - if ref_str and ref_str.strip(): - answer_found = any(ref_str.lower() in ctx.lower() for ctx in contexts_all) - print(f"🔍 调试:答案 '{ref_str}' 是否在检索到的上下文中? {'✅ 是' if answer_found else '❌ 否'}") - - else: - print("❌ 没有检索到有效上下文") - context_text = "No relevant context found." - - # LLM 回答 - messages = [ - {"role": "system", "content": ( - "You are a precise QA assistant. 
Answer following these rules:\n" - "1) Extract the EXACT information mentioned in the context\n" - "2) For time questions: calculate actual dates from relative times\n" - "3) Return ONLY the answer text in simplest form\n" - "4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n" - "5) If no clear answer found, respond with 'Unknown'" - )}, - {"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"}, - ] - - t2 = time.time() - try: - # 使用异步调用 - resp = await llm.chat(messages=messages) - # 兼容不同的响应格式 - pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown") - except Exception as e: - print(f"❌ LLM 生成失败: {e}") - pred = "Unknown" - t3 = time.time() - llm_time = (t3 - t2) * 1000 - - # 计算指标 - 使用导入的指标函数 - f1_val = f1_score(pred, ref_str) - bleu1_val = bleu1(pred, ref_str) - jaccard_val = jaccard(pred, ref_str) - loc_f1_val = loc_f1_score(pred, ref_str) - - print(f"🤖 LLM 回答: {pred}") - print(f"📈 指标 - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, Jaccard: {jaccard_val:.3f}, LoCoMo F1: {loc_f1_val:.3f}") - print(f"⏱️ 时间 - 检索: {search_time:.1f}ms, LLM: {llm_time:.1f}ms") - - # 更新统计 - total_f1 += f1_val - total_bleu1 += bleu1_val - total_jaccard += jaccard_val - total_loc_f1 += loc_f1_val - total_context_length += len(context_text) - total_retrieved_docs += len(contexts_all) - - if category not in category_stats: - category_stats[category] = {"count": 0, "f1_sum": 0.0, "b1_sum": 0.0, "j_sum": 0.0, "loc_f1_sum": 0.0} - - category_stats[category]["count"] += 1 - category_stats[category]["f1_sum"] += f1_val - category_stats[category]["b1_sum"] += bleu1_val - category_stats[category]["j_sum"] += jaccard_val - category_stats[category]["loc_f1_sum"] += loc_f1_val - - # 记录性能指标 - metrics = {"f1": f1_val, "bleu1": bleu1_val, "jaccard": jaccard_val, "loc_f1": loc_f1_val} - monitor.record_performance(i, metrics, len(context_text), len(contexts_all)) - - # 保存结果 - 
question_result = { - "question": q, - "ground_truth": ref_str, - "prediction": pred, - "category": category, - "metrics": metrics, - "retrieval": { - "retrieved_documents": len(contexts_all), - "context_length": len(context_text), - "search_limit": search_limit, - "max_chars": max_chars, - "recent_performance": recent_performance - }, - "timing": { - "search_ms": search_time, - "llm_ms": llm_time - } - } - - results["questions"].append(question_result) - - print("="*60) - - except Exception as e: - print(f"❌ 评估过程中发生错误: {e}") - # 即使出错,也返回已有的结果 - import traceback - traceback.print_exc() - - finally: - await connector.close() - - finally: - db.close() # 关闭数据库会话 - - # 计算总体指标 - n = len(items) - if n > 0: - results["overall_metrics"] = { - "f1": total_f1 / n, - "b1": total_bleu1 / n, - "j": total_jaccard / n, - "loc_f1": total_loc_f1 / n - } - - for category, stats in category_stats.items(): - count = stats["count"] - results["category_metrics"][category] = { - "count": count, - "f1": stats["f1_sum"] / count, - "bleu1": stats["b1_sum"] / count, - "jaccard": stats["j_sum"] / count, - "loc_f1": stats["loc_f1_sum"] / count - } - - results["retrieval_stats"]["avg_context_length"] = total_context_length / n - results["retrieval_stats"]["avg_retrieved_docs"] = total_retrieved_docs / n - - # 分析性能趋势 - results["performance_trend"] = monitor.get_performance_trend() - results["reset_interval"] = monitor.reset_interval - results["total_questions_processed"] = monitor.question_count - - return results - - -if __name__ == "__main__": - print("🚀 运行增强版完整评估(解决中间性能衰减问题)...") - print("📋 增强特性:") - print(" - 双重重置策略:定期重置 + 性能驱动重置") - print(" - 动态检索参数:基于近期性能自适应调整") - print(" - 中间阶段严格筛选:提高上下文质量要求") - print(" - 连续性能监控:实时检测性能衰减") - - result = asyncio.run(run_enhanced_evaluation()) - - print("\n📊 最终评估结果:") - print("总体指标:") - print(f" F1: {result['overall_metrics']['f1']:.4f}") - print(f" BLEU-1: {result['overall_metrics']['b1']:.4f}") - print(f" Jaccard: {result['overall_metrics']['j']:.4f}") - 
print(f" LoCoMo F1: {result['overall_metrics']['loc_f1']:.4f}") - - print("\n分类别指标:") - for category, metrics in result['category_metrics'].items(): - print(f" {category}: F1={metrics['f1']:.4f}, BLEU-1={metrics['bleu1']:.4f}, Jaccard={metrics['jaccard']:.4f}, LoCoMo F1={metrics['loc_f1']:.4f} (样本数: {metrics['count']})") - - print("\n检索统计:") - stats = result['retrieval_stats'] - print(f" 平均上下文长度: {stats['avg_context_length']:.0f} 字符") - print(f" 平均检索文档数: {stats['avg_retrieved_docs']:.1f}") - - print(f"\n性能趋势: {result['performance_trend']}") - print(f"重置间隔: 每{result['reset_interval']}个问题") - print(f"处理问题总数: {result['total_questions_processed']}") - print(f"增强策略: {'启用' if result.get('enhanced_strategy', False) else '未启用'}") - - - # 保存结果到指定目录 - # 使用代码文件所在目录的绝对路径 - current_file_dir = os.path.dirname(os.path.abspath(__file__)) - output_dir = os.path.join(current_file_dir, "results") - os.makedirs(output_dir, exist_ok=True) - output_file = os.path.join(output_dir, "enhanced_evaluation_results.json") - with open(output_file, "w", encoding="utf-8") as f: - json.dump(result, f, ensure_ascii=False, indent=2) - print(f"\n详细结果已保存到: {output_file}") diff --git a/api/app/core/memory/evaluation/locomo/locomo_utils.py b/api/app/core/memory/evaluation/locomo/locomo_utils.py deleted file mode 100644 index 6ad68470..00000000 --- a/api/app/core/memory/evaluation/locomo/locomo_utils.py +++ /dev/null @@ -1,687 +0,0 @@ -""" -LoCoMo Utilities Module - -This module provides helper functions for the LoCoMo benchmark evaluation: -- Data loading from JSON files -- Conversation extraction for ingestion -- Temporal reference resolution -- Context selection and formatting -- Retrieval wrapper functions -- Ingestion wrapper functions -""" - -import os -import json -import re -from datetime import datetime, timedelta -from typing import List, Dict, Any, Optional -from pathlib import Path -from dotenv import load_dotenv - -# Load evaluation config -eval_config_path = 
Path(__file__).resolve().parent.parent / ".env.evaluation" -if eval_config_path.exists(): - load_dotenv(eval_config_path, override=True) - -from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline - - -def load_locomo_data( - data_path: str, - sample_size: int, - conversation_index: int = 0 -) -> List[Dict[str, Any]]: - """ - Load LoCoMo dataset from JSON file. - - The LoCoMo dataset structure is a list of conversation objects, where each - object contains a "qa" list of question-answer pairs. - - Args: - data_path: Path to locomo10.json file - sample_size: Number of QA pairs to load (limits total QA items returned) - conversation_index: Which conversation to load QA pairs from (default: 0 for first) - - Returns: - List of QA item dictionaries, each containing: - - question: str - - answer: str - - category: int (1-4) - - evidence: List[str] - - Raises: - FileNotFoundError: If data_path does not exist - json.JSONDecodeError: If file is not valid JSON - IndexError: If conversation_index is out of range - """ - if not os.path.exists(data_path): - raise FileNotFoundError(f"LoCoMo data file not found: {data_path}") - - with open(data_path, "r", encoding="utf-8") as f: - raw = json.load(f) - - # LoCoMo data structure: list of objects, each with a "qa" list - qa_items: List[Dict[str, Any]] = [] - - if isinstance(raw, list): - # Only load QA pairs from the specified conversation - if conversation_index < len(raw): - entry = raw[conversation_index] - if isinstance(entry, dict) and "qa" in entry: - qa_items.extend(entry.get("qa", [])) - else: - raise IndexError( - f"Conversation index {conversation_index} out of range. " - f"Dataset has {len(raw)} conversations." - ) - else: - # Fallback: single object with qa list - if conversation_index == 0: - qa_items.extend(raw.get("qa", [])) - else: - raise IndexError( - f"Conversation index {conversation_index} out of range. " - f"Dataset has only 1 conversation." 
- ) - - # Return only the requested sample size - return qa_items[:sample_size] - - -def extract_conversations(data_path: str, max_dialogues: int = 1, max_messages_per_dialogue: Optional[int] = None) -> List[str]: - """ - Extract conversation texts from LoCoMo data for ingestion. - - This function extracts the raw conversation dialogues from the LoCoMo dataset - so they can be ingested into the memory system. Each conversation is formatted - as a multi-line string with "role: message" format. - - Args: - data_path: Path to locomo10.json file - max_dialogues: Maximum number of dialogues to extract (default: 1) - max_messages_per_dialogue: Maximum messages per dialogue (default: None = all messages) - - Returns: - List of conversation strings formatted for ingestion. - Each string contains multiple lines in format "role: message" - - Example output: - [ - "User: I went to the store yesterday.\\nAI: What did you buy?\\n...", - "User: I love hiking.\\nAI: Where do you like to hike?\\n..." - ] - """ - if not os.path.exists(data_path): - raise FileNotFoundError(f"LoCoMo data file not found: {data_path}") - - with open(data_path, "r", encoding="utf-8") as f: - raw = json.load(f) - - # Ensure we have a list of entries - entries = raw if isinstance(raw, list) else [raw] - - contents: List[str] = [] - - for i, entry in enumerate(entries[:max_dialogues]): - if not isinstance(entry, dict): - continue - - conv = entry.get("conversation", {}) - - if not isinstance(conv, dict): - continue - - lines: List[str] = [] - - # Collect all session_* messages - for key, val in sorted(conv.items()): - if isinstance(val, list) and key.startswith("session_"): - for msg in val: - if not isinstance(msg, dict): - continue - - role = msg.get("speaker") or "User" - text = msg.get("text") or "" - text = str(text).strip() - - if not text: - continue - - lines.append(f"{role}: {text}") - - # Limit messages if specified - if max_messages_per_dialogue and len(lines) >= max_messages_per_dialogue: - 
break - - # Break outer loop if we've reached the message limit - if max_messages_per_dialogue and len(lines) >= max_messages_per_dialogue: - break - - if lines: - contents.append("\n".join(lines)) - - return contents - -# 时间解析:将相对时间表达转换为绝对日期 -def resolve_temporal_references(text: str, anchor_date: datetime) -> str: - """ - Resolve relative temporal references to absolute dates. - - This function converts relative time expressions (like "today", "yesterday", - "3 days ago") into absolute ISO date strings based on an anchor date. - - Supported patterns: - - today, yesterday, tomorrow - - X days ago, in X days - - last week, next week - - Args: - text: Text containing temporal references - anchor_date: Reference date for resolution (datetime object) - - Returns: - Text with temporal references replaced by ISO dates (YYYY-MM-DD format) - - Example: - >>> anchor = datetime(2023, 5, 8) - >>> resolve_temporal_references("I saw him yesterday", anchor) - "I saw him 2023-05-07" - """ - # Ensure input is a string - t = str(text) if text is not None else "" - - # today / yesterday / tomorrow - t = re.sub( - r"\btoday\b", - anchor_date.date().isoformat(), - t, - flags=re.IGNORECASE - ) - t = re.sub( - r"\byesterday\b", - (anchor_date - timedelta(days=1)).date().isoformat(), - t, - flags=re.IGNORECASE - ) - t = re.sub( - r"\btomorrow\b", - (anchor_date + timedelta(days=1)).date().isoformat(), - t, - flags=re.IGNORECASE - ) - - # X days ago - def _ago_repl(m: re.Match[str]) -> str: - n = int(m.group(1)) - return (anchor_date - timedelta(days=n)).date().isoformat() - - # in X days - def _in_repl(m: re.Match[str]) -> str: - n = int(m.group(1)) - return (anchor_date + timedelta(days=n)).date().isoformat() - - t = re.sub( - r"\b(\d+)\s+days?\s+ago\b", - _ago_repl, - t, - flags=re.IGNORECASE - ) - t = re.sub( - r"\bin\s+(\d+)\s+days?\b", - _in_repl, - t, - flags=re.IGNORECASE - ) - - # last week / next week (approximate as 7 days) - t = re.sub( - r"\blast\s+week\b", - (anchor_date - 
timedelta(days=7)).date().isoformat(), - t, - flags=re.IGNORECASE - ) - - # 中文支持 - t = re.sub( - r"\bnext\s+week\b", - (anchor_date + timedelta(days=7)).date().isoformat(), - t, - flags=re.IGNORECASE - ) - - return t - - -def select_and_format_information( - retrieved_info: List[str], - question: str, - max_chars: int = 8000 -) -> str: - """ - Intelligently select and format most relevant retrieved information for LLM prompt. - - This function scores each piece of retrieved information based on keyword matching - with the question, then selects the highest-scoring pieces up to the character limit. - - Scoring criteria: - - Keyword matches (higher weight for multiple occurrences) - - Context length (moderate length preferred) - - Position (earlier contexts get bonus points) - - Args: - retrieved_info: List of retrieved information strings (chunks, statements, entities) - question: Question being answered - max_chars: Maximum total characters to include in final prompt - - Returns: - Formatted string combining the most relevant information for LLM prompt. - Contexts are separated by double newlines. - - Example: - >>> contexts = ["Alice went to Paris", "Bob likes pizza", "Alice visited the Eiffel Tower"] - >>> question = "Where did Alice go?" 
- >>> select_and_format_information(contexts, question, max_chars=100) - "Alice went to Paris\\n\\nAlice visited the Eiffel Tower" - """ - if not retrieved_info: - return "" - - # Extract question keywords (filter out stop words and short words) - question_lower = question.lower() - stop_words = { - 'what', 'when', 'where', 'who', 'why', 'how', - 'did', 'do', 'does', 'is', 'are', 'was', 'were', - 'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at' - } - question_words = set(re.findall(r'\b\w+\b', question_lower)) - question_words = { - word for word in question_words - if word not in stop_words and len(word) > 2 - } - - # Score each context - scored_contexts = [] - for i, context in enumerate(retrieved_info): - context_lower = context.lower() - score = 0 - - # Keyword matching score - keyword_matches = 0 - for word in question_words: - if word in context_lower: - keyword_matches += 1 - # Multiple occurrences increase score - score += context_lower.count(word) * 2 - - # Length score (prefer moderate length) - context_len = len(context) - if 100 < context_len < 2000: - score += 5 - elif context_len >= 2000: - score += 2 - - # Position bonus (earlier contexts often more relevant) - if i < 3: - score += 3 - - scored_contexts.append((score, context, keyword_matches)) - - # Sort by score (descending) - scored_contexts.sort(key=lambda x: x[0], reverse=True) - - # Select contexts up to character limit - selected = [] - total_chars = 0 - - for score, context, matches in scored_contexts: - if total_chars + len(context) <= max_chars: - selected.append(context) - total_chars += len(context) - else: - # Try to include high-scoring context by truncating - if score > 10 and total_chars < max_chars - 500: - remaining = max_chars - total_chars - # Find lines with keywords - lines = context.split('\n') - relevant_lines = [] - current_chars = 0 - - for line in lines: - line_lower = line.lower() - line_relevance = any(word in line_lower for word in question_words) - - if 
line_relevance and current_chars < remaining - 100: - relevant_lines.append(line) - current_chars += len(line) - - if relevant_lines and len('\n'.join(relevant_lines)) > 100: - truncated = '\n'.join(relevant_lines) - selected.append(truncated + "\n[Content truncated...]") - total_chars += len(truncated) - break - - return "\n\n".join(selected) - -# 记忆系统核心能力:写入与读取 -async def ingest_conversations_if_needed( - conversations: List[str], - end_user_id: str, - reset: bool = False -) -> bool: - """ - Wrapper for conversation ingestion using external extraction pipeline. - - This function populates the Neo4j database with processed conversation data - (chunks, statements, entities) so that the retrieval system has memory to search. - - The ingestion process: - 1. Parses conversation text into dialogue messages - 2. Chunks the dialogues into semantic units - 3. Extracts statements and entities using LLM - 4. Generates embeddings for all content - 5. Stores everything in Neo4j graph database - - Args: - conversations: List of raw conversation texts from LoCoMo dataset - Example: ["User: I went to Paris. AI: When was that?", ...] - end_user_id: Target end_user ID for database storage - reset: Whether to clear existing data first (not implemented in wrapper) - - Returns: - True if successful, False otherwise - - Note: - The external function uses "contexts" to mean "conversation texts". - This runs the full extraction pipeline: chunking → entity extraction → - statement extraction → embedding → Neo4j storage. 
- """ - try: - success = await ingest_contexts_via_full_pipeline( - contexts=conversations, - end_user_id=end_user_id, - save_chunk_output=True, - reset_group=reset - ) - return success - except Exception as e: - print(f"[Ingestion] Failed to ingest conversations: {e}") - return False - -async def retrieve_relevant_information( - question: str, - end_user_id: str, - search_type: str, - search_limit: int, - connector: Any, - embedder: Any -) -> List[str]: - """ - Retrieve relevant information from memory graph for a question. - - This function searches the Neo4j memory graph (populated during ingestion) and - returns relevant chunks, statements, and entity information that might help - answer the question. - - The function supports three search types: - - "keyword": Full-text search using Cypher queries - - "embedding": Vector similarity search using embeddings - - "hybrid": Combination of keyword and embedding search with reranking - - Args: - question: Question to search for - end_user_id: Database group ID (identifies which conversation memory to search) - search_type: "keyword", "embedding", or "hybrid" - search_limit: Max memory pieces to retrieve - connector: Neo4j connector instance - embedder: Embedder client instance - - Returns: - List of text strings (chunks, statements, entity summaries) from memory graph. - Each string represents a piece of retrieved information. 
- - Raises: - Exception: If search fails (caught and returns empty list) - """ - from app.repositories.neo4j.graph_search import ( - search_graph, - search_graph_by_embedding - ) - from app.core.memory.src.search import run_hybrid_search - - contexts_all: List[str] = [] - - try: - if search_type == "embedding": - # Embedding-based search - search_results = await search_graph_by_embedding( - connector=connector, - embedder_client=embedder, - query_text=question, - end_user_id=end_user_id, - limit=search_limit, - include=["chunks", "statements", "entities", "summaries"], - ) - - chunks = search_results.get("chunks", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - summaries = search_results.get("summaries", []) - - # Build context from chunks - for c in chunks: - content = str(c.get("content", "")).strip() - if content: - contexts_all.append(content) - - # Add statements - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - - # Add summaries - for sm in summaries: - summary_text = str(sm.get("summary", "")).strip() - if summary_text: - contexts_all.append(summary_text) - - # Add top entities (limit to 3 to avoid noise) - if entities: - scored = [e for e in entities if e.get("score") is not None] - top_entities = ( - sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] - if scored else entities[:3] - ) - if top_entities: - summary_lines = [] - for e in top_entities: - name = str(e.get("name", "")).strip() - etype = str(e.get("entity_type", "")).strip() - score = e.get("score") - if name: - meta = [] - if etype: - meta.append(f"type={etype}") - if isinstance(score, (int, float)): - meta.append(f"score={score:.3f}") - summary_lines.append( - f"EntitySummary: {name}" - f"{(' [' + '; '.join(meta) + ']') if meta else ''}" - ) - if summary_lines: - contexts_all.append("\n".join(summary_lines)) - - elif search_type == "keyword": - # 
Keyword-based search - search_results = await search_graph( - connector=connector, - q=question, - end_user_id=end_user_id, - limit=search_limit - ) - - dialogs = search_results.get("dialogues", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - - # Build context from dialogues - for d in dialogs: - content = str(d.get("content", "")).strip() - if content: - contexts_all.append(content) - - # Add statements - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - - # Add entity names - if entities: - entity_names = [ - str(e.get("name", "")).strip() - for e in entities[:5] - if e.get("name") - ] - if entity_names: - contexts_all.append(f"EntitySummary: {', '.join(entity_names)}") - - else: # hybrid - # Hybrid search with fallback to embedding - try: - search_results = await run_hybrid_search( - query_text=question, - search_type=search_type, - end_user_id=end_user_id, - limit=search_limit, - include=["chunks", "statements", "entities", "summaries"], - output_path=None, - ) - - # Handle flat structure (new API format) - if search_results and isinstance(search_results, dict): - chunks = search_results.get("chunks", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - summaries = search_results.get("summaries", []) - - # Check if we got results - if not (chunks or statements or entities or summaries): - # Try nested structure (backward compatibility) - reranked = search_results.get("reranked_results", {}) - if reranked and isinstance(reranked, dict): - chunks = reranked.get("chunks", []) - statements = reranked.get("statements", []) - entities = reranked.get("entities", []) - summaries = reranked.get("summaries", []) - else: - raise ValueError("Hybrid search returned empty results") - else: - raise ValueError("Hybrid search returned empty results") - - except Exception as e: - # Fallback to 
embedding search - search_results = await search_graph_by_embedding( - connector=connector, - embedder_client=embedder, - query_text=question, - end_user_id=end_user_id, - limit=search_limit, - include=["chunks", "statements", "entities", "summaries"], - ) - chunks = search_results.get("chunks", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - summaries = search_results.get("summaries", []) - - # Build context (same for both hybrid and fallback) - for c in chunks: - content = str(c.get("content", "")).strip() - if content: - contexts_all.append(content) - - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - - for sm in summaries: - summary_text = str(sm.get("summary", "")).strip() - if summary_text: - contexts_all.append(summary_text) - - # Add top entities - if entities: - scored = [e for e in entities if e.get("score") is not None] - top_entities = ( - sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] - if scored else entities[:3] - ) - if top_entities: - summary_lines = [] - for e in top_entities: - name = str(e.get("name", "")).strip() - etype = str(e.get("entity_type", "")).strip() - score = e.get("score") - if name: - meta = [] - if etype: - meta.append(f"type={etype}") - if isinstance(score, (int, float)): - meta.append(f"score={score:.3f}") - summary_lines.append( - f"EntitySummary: {name}" - f"{(' [' + '; '.join(meta) + ']') if meta else ''}" - ) - if summary_lines: - contexts_all.append("\n".join(summary_lines)) - - except Exception as e: - # Return empty list on error - contexts_all = [] - - return contexts_all - - -async def ingest_conversations_if_needed( - conversations: List[str], - end_user_id: str, - reset: bool = False -) -> bool: - """ - Wrapper for conversation ingestion using external extraction pipeline. 
- - This function populates the Neo4j database with processed conversation data - (chunks, statements, entities) so that the retrieval system has memory to search. - - The ingestion process: - 1. Parses conversation text into dialogue messages - 2. Chunks the dialogues into semantic units - 3. Extracts statements and entities using LLM - 4. Generates embeddings for all content - 5. Stores everything in Neo4j graph database - - Args: - conversations: List of raw conversation texts from LoCoMo dataset - Example: ["User: I went to Paris. AI: When was that?", ...] - end_user_id: Target group ID for database storage - reset: Whether to clear existing data first (not implemented in wrapper) - - Returns: - True if successful, False otherwise - - Note: - The external function uses "contexts" to mean "conversation texts". - This runs the full extraction pipeline: chunking → entity extraction → - statement extraction → embedding → Neo4j storage. - """ - try: - success = await ingest_contexts_via_full_pipeline( - contexts=conversations, - end_user_id=end_user_id, - save_chunk_output=True - ) - return success - except Exception as e: - print(f"[Ingestion] Failed to ingest conversations: {e}") - return False diff --git a/api/app/core/memory/evaluation/locomo/qwen_search_eval.py b/api/app/core/memory/evaluation/locomo/qwen_search_eval.py deleted file mode 100644 index 889c5065..00000000 --- a/api/app/core/memory/evaluation/locomo/qwen_search_eval.py +++ /dev/null @@ -1,874 +0,0 @@ -import argparse -import asyncio -import json -import os -import time -from datetime import datetime, timedelta -from typing import List, Dict, Any -import statistics -import re -from pathlib import Path -from dotenv import load_dotenv - -# Load evaluation config -eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation" -if eval_config_path.exists(): - load_dotenv(eval_config_path, override=True) - print(f"✅ 加载评估配置: {eval_config_path}") - -from 
app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding -from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.models.base import RedBearModelConfig -from app.core.memory.utils.config.config_utils import get_embedder_config -from app.core.memory.src.search import run_hybrid_search # 使用旧版本(重构前) -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline -from app.core.memory.evaluation.common.metrics import f1_score as common_f1, bleu1, jaccard, latency_stats, avg_context_tokens - - -# 参考 evaluation/locomo/evaluation.py 的 F1 计算逻辑(移除外部依赖,内联实现) -def _loc_normalize(text: str) -> str: - import re - # 确保输入是字符串 - text = str(text) if text is not None else "" - text = text.lower() - text = re.sub(r"[\,]", " ", text) # 去掉逗号 - text = re.sub(r"\b(a|an|the|and)\b", " ", text) - text = re.sub(r"[^\w\s]", " ", text) - text = " ".join(text.split()) - return text - -# 追加:相对时间归一化为绝对日期(有限支持:today/yesterday/tomorrow/X days ago/in X days/last week/next week) -def _resolve_relative_times(text: str, anchor: datetime) -> str: - import re - # 确保输入是字符串 - t = str(text) if text is not None else "" - # today / yesterday / tomorrow - t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE) - # X days ago / in X days - def _ago_repl(m: re.Match[str]) -> str: - n = int(m.group(1)) - return (anchor - timedelta(days=n)).date().isoformat() - def _in_repl(m: re.Match[str]) -> str: - n = int(m.group(1)) - return (anchor + timedelta(days=n)).date().isoformat() - t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE) - t = 
re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE) - # last week / next week(以7天近似) - t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE) - return t - -def loc_f1_score(prediction: str, ground_truth: str) -> float: - # 单答案 F1:按词集合计算(近似原始实现,去除词干依赖) - # 确保输入是字符串 - pred_str = str(prediction) if prediction is not None else "" - truth_str = str(ground_truth) if ground_truth is not None else "" - - p_tokens = _loc_normalize(pred_str).split() - g_tokens = _loc_normalize(truth_str).split() - if not p_tokens or not g_tokens: - return 0.0 - p = set(p_tokens) - g = set(g_tokens) - tp = len(p & g) - precision = tp / len(p) if p else 0.0 - recall = tp / len(g) if g else 0.0 - return (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0 - -def loc_multi_f1(prediction: str, ground_truth: str) -> float: - # 多答案 F1:prediction 与 ground_truth 以逗号分隔,逐一匹配取最大,再对多个 GT 取平均 - # 确保输入是字符串 - pred_str = str(prediction) if prediction is not None else "" - truth_str = str(ground_truth) if ground_truth is not None else "" - - predictions = [p.strip() for p in str(pred_str).split(',') if p.strip()] - ground_truths = [g.strip() for g in str(truth_str).split(',') if g.strip()] - if not predictions or not ground_truths: - return 0.0 - def _f1(a: str, b: str) -> float: - return loc_f1_score(a, b) - vals = [] - for gt in ground_truths: - vals.append(max(_f1(pred, gt) for pred in predictions)) - return sum(vals) / len(vals) - -# 标准化 LoCoMo 类别名:支持数字 category 与字符串 cat/type -CATEGORY_MAP_NUM_TO_NAME = { - 4: "Single-Hop", - 1: "Multi-Hop", - 3: "Open Domain", - 2: "Temporal", -} - -_TYPE_ALIASES = { - "single-hop": "Single-Hop", - "singlehop": "Single-Hop", - "single hop": "Single-Hop", - "multi-hop": "Multi-Hop", - "multihop": "Multi-Hop", - "multi hop": "Multi-Hop", - "open domain": "Open Domain", - 
"opendomain": "Open Domain", - "temporal": "Temporal", -} - -def get_category_label(item: Dict[str, Any]) -> str: - # 1) 直接用字符串 cat - cat = item.get("cat") - if isinstance(cat, str) and cat.strip(): - name = cat.strip() - lower = name.lower() - return _TYPE_ALIASES.get(lower, name) - # 2) 数字 category 转名称 - cat_num = item.get("category") - if isinstance(cat_num, int): - return CATEGORY_MAP_NUM_TO_NAME.get(cat_num, "unknown") - # 3) 备用 type 字段 - t = item.get("type") - if isinstance(t, str) and t.strip(): - lower = t.strip().lower() - return _TYPE_ALIASES.get(lower, t.strip()) - return "unknown" - - -def smart_context_selection(contexts: List[str], question: str, max_chars: int = 12000) -> str: - """基于问题关键词智能选择上下文""" - if not contexts: - return "" - - # 提取问题关键词(只保留有意义的词) - question_lower = question.lower() - stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'} - question_words = set(re.findall(r'\b\w+\b', question_lower)) - question_words = {word for word in question_words if word not in stop_words and len(word) > 2} - - print(f"🔍 问题关键词: {question_words}") - - # 给每个上下文打分 - scored_contexts = [] - for i, context in enumerate(contexts): - context_lower = context.lower() - score = 0 - - # 关键词匹配得分 - keyword_matches = 0 - for word in question_words: - if word in context_lower: - keyword_matches += 1 - # 关键词出现次数越多,得分越高 - score += context_lower.count(word) * 2 - - # 上下文长度得分(适中的长度更好) - context_len = len(context) - if 100 < context_len < 2000: # 理想长度范围 - score += 5 - elif context_len >= 2000: # 太长可能包含无关信息 - score += 2 - - # 如果是前几个上下文,给予额外分数(通常相关性更高) - if i < 3: - score += 3 - - scored_contexts.append((score, context, keyword_matches)) - - # 按得分排序 - scored_contexts.sort(key=lambda x: x[0], reverse=True) - - # 选择高得分的上下文,直到达到字符限制 - selected = [] - total_chars = 0 - selected_count = 0 - - print("📊 上下文相关性分析:") - for score, context, matches in scored_contexts[:5]: # 只显示前5个 - print(f" - 得分: 
{score}, 关键词匹配: {matches}, 长度: {len(context)}") - - for score, context, matches in scored_contexts: - if total_chars + len(context) <= max_chars: - selected.append(context) - total_chars += len(context) - selected_count += 1 - else: - # 如果这个上下文得分很高但放不下,尝试截取 - if score > 10 and total_chars < max_chars - 500: - remaining = max_chars - total_chars - # 找到包含关键词的部分 - lines = context.split('\n') - relevant_lines = [] - current_chars = 0 - - for line in lines: - line_lower = line.lower() - line_relevance = any(word in line_lower for word in question_words) - - if line_relevance and current_chars < remaining - 100: - relevant_lines.append(line) - current_chars += len(line) - - if relevant_lines: - truncated = '\n'.join(relevant_lines) - if len(truncated) > 100: # 确保有足够内容 - selected.append(truncated + "\n[相关内容截断...]") - total_chars += len(truncated) - selected_count += 1 - break # 不再尝试添加更多上下文 - - result = "\n\n".join(selected) - print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {total_chars}字符") - return result - - -def get_search_params_by_category(category: str): - """根据问题类别调整检索参数""" - params_map = { - "Multi-Hop": {"limit": 20, "max_chars": 15000}, - "Temporal": {"limit": 16, "max_chars": 10000}, - "Open Domain": {"limit": 24, "max_chars": 18000}, - "Single-Hop": {"limit": 12, "max_chars": 8000}, - } - return params_map.get(category, {"limit": 16, "max_chars": 12000}) - - -async def run_locomo_eval( - sample_size: int = 1, - end_user_id: str | None = None, - search_limit: int = 8, - context_char_budget: int = 4000, # 保持默认值不变 - llm_temperature: float = 0.0, - llm_max_tokens: int = 32, - search_type: str = "hybrid", # 保持默认值不变 - output_path: str | None = None, - skip_ingest_if_exists: bool = True, - llm_timeout: float = 10.0, - llm_max_retries: int = 1 -) -> Dict[str, Any]: - - # 函数内部使用三路检索逻辑,但保持参数签名不变 - end_user_id = end_user_id or SELECTED_end_user_id - data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json") - if not os.path.exists(data_path): - raise FileNotFoundError( - 
f"数据集文件不存在: {data_path}\n" - f"请将 locomo10.json 放置在: {dataset_dir}" - ) - with open(data_path, "r", encoding="utf-8") as f: - raw = json.load(f) - # LoCoMo 数据结构:顶层为若干对象,每个对象下有 qa 列表 - qa_items: List[Dict[str, Any]] = [] - if isinstance(raw, list): - for entry in raw: - qa_items.extend(entry.get("qa", [])) - else: - qa_items.extend(raw.get("qa", [])) - items: List[Dict[str, Any]] = qa_items[:sample_size] - - # === 保持原来的数据摄入逻辑 === - entries = raw if isinstance(raw, list) else [raw] - - # 只摄入前1条对话(保持原样) - max_dialogues_to_ingest = 1 - contents: List[str] = [] - print(f"📊 找到 {len(entries)} 个对话对象,只摄入前 {max_dialogues_to_ingest} 条") - - for i, entry in enumerate(entries[:max_dialogues_to_ingest]): - if not isinstance(entry, dict): - continue - - conv = entry.get("conversation", {}) - sample_id = entry.get("sample_id", f"unknown_{i}") - - print(f"🔍 处理对话 {i+1}: {sample_id}") - - lines: List[str] = [] - if isinstance(conv, dict): - # 收集所有 session_* 的消息 - session_count = 0 - for key, val in conv.items(): - if isinstance(val, list) and key.startswith("session_"): - session_count += 1 - for msg in val: - role = msg.get("speaker") or "用户" - text = msg.get("text") or "" - text = str(text).strip() - if not text: - continue - lines.append(f"{role}: {text}") - - print(f" - 包含 {session_count} 个session, {len(lines)} 条消息") - - if not lines: - print(f"⚠️ 警告: 对话 {sample_id} 没有对话内容,跳过摄入") - continue - - contents.append("\n".join(lines)) - - print(f"📥 总共摄入 {len(contents)} 个对话的conversation内容") - - # 选择要评测的QA对(从所有对话中选取) - indexed_items: List[tuple[int, Dict[str, Any]]] = [] - if isinstance(raw, list): - for e_idx, entry in enumerate(raw): - for qa in entry.get("qa", []): - indexed_items.append((e_idx, qa)) - else: - for qa in raw.get("qa", []): - indexed_items.append((0, qa)) - - # 这里使用sample_size来限制评测的QA数量 - selected = indexed_items[:sample_size] - items: List[Dict[str, Any]] = [qa for _, qa in selected] - - print(f"🎯 将评测 {len(items)} 个QA对,数据库中只包含 {len(contents)} 个对话") - # === 修改结束 === - - 
connector = Neo4jConnector() - - # 关键修复:强制重新摄入纯净的对话数据 - print("🔄 强制重新摄入纯净的对话数据...") - await ingest_contexts_via_full_pipeline(contents, end_user_id, save_chunk_output=True) - - # 使用异步LLM客户端 - llm_client = get_llm_client(llm_id) - # 初始化embedder用于直接调用 - cfg_dict = get_embedder_config(embedding_id) - embedder = OpenAIEmbedderClient( - model_config=RedBearModelConfig.model_validate(cfg_dict) - ) - - # connector initialized above - latencies_llm: List[float] = [] - latencies_search: List[float] = [] - # 上下文诊断收集 - per_query_context_counts: List[int] = [] - per_query_context_avg_tokens: List[float] = [] - per_query_context_chars: List[int] = [] - per_query_context_tokens_total: List[int] = [] - # 详细样本调试信息 - samples: List[Dict[str, Any]] = [] - # 通用指标 - f1s: List[float] = [] - b1s: List[float] = [] - jss: List[float] = [] - # 参考 LoCoMo 评测的类别专用 F1(multi-hop 使用多答案 F1) - loc_f1s: List[float] = [] - # Per-category aggregation - cat_counts: Dict[str, int] = {} - cat_f1s: Dict[str, List[float]] = {} - cat_b1s: Dict[str, List[float]] = {} - cat_jss: Dict[str, List[float]] = {} - cat_loc_f1s: Dict[str, List[float]] = {} - try: - for item in items: - q = item.get("question", "") - ref = item.get("answer", "") - # 确保答案是字符串 - ref_str = str(ref) if ref is not None else "" - cat = get_category_label(item) - - print(f"\n=== 处理问题: {q} ===") - - # 根据类别调整检索参数 - search_params = get_search_params_by_category(cat) - adjusted_limit = search_params["limit"] - max_chars = search_params["max_chars"] - - print(f"🏷️ 类别: {cat}, 检索参数: limit={adjusted_limit}, max_chars={max_chars}") - - # 改进的检索逻辑:使用三路检索(statements, dialogues, entities) - t0 = time.time() - contexts_all: List[str] = [] - search_results = None # 保存完整的检索结果 - - try: - if search_type == "embedding": - # 直接调用嵌入检索,包含三路数据 - search_results = await search_graph_by_embedding( - connector=connector, - embedder_client=embedder, - query_text=q, - end_user_id=end_user_id, - limit=adjusted_limit, - include=["chunks", "statements", "entities", 
"summaries"], # 修复:使用正确的类型 - ) - chunks = search_results.get("chunks", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - summaries = search_results.get("summaries", []) - - print(f"✅ 嵌入检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要") - - # 构建上下文:优先使用 chunks、statements 和 summaries - for c in chunks: - content = str(c.get("content", "")).strip() - if content: - contexts_all.append(content) - - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - - for sm in summaries: - summary_text = str(sm.get("summary", "")).strip() - if summary_text: - contexts_all.append(summary_text) - - # 实体摘要:最多加入前3个高分实体,避免噪声 - scored = [e for e in entities if e.get("score") is not None] - top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3] - if top_entities: - summary_lines = [] - for e in top_entities: - name = str(e.get("name", "")).strip() - etype = str(e.get("entity_type", "")).strip() - score = e.get("score") - if name: - meta = [] - if etype: - meta.append(f"type={etype}") - if isinstance(score, (int, float)): - meta.append(f"score={score:.3f}") - summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}") - if summary_lines: - contexts_all.append("\n".join(summary_lines)) - - elif search_type == "keyword": - # 直接调用关键词检索 - search_results = await search_graph( - connector=connector, - q=q, - end_user_id=end_user_id, - limit=adjusted_limit - ) - dialogs = search_results.get("dialogues", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - print(f"🔤 关键词检索找到 {len(dialogs)} 条对话, {len(statements)} 条陈述, {len(entities)} 个实体") - - # 构建上下文 - for d in dialogs: - content = str(d.get("content", "")).strip() - if content: - contexts_all.append(content) - for s in statements: - stmt_text = 
str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - # 实体处理(关键词检索的实体可能没有分数) - if entities: - entity_names = [str(e.get("name", "")).strip() for e in entities[:5] if e.get("name")] - if entity_names: - contexts_all.append(f"EntitySummary: {', '.join(entity_names)}") - - else: # hybrid - # 使用旧版本的混合检索(重构前) - print("🔀 使用混合检索(旧版本)...") - try: - search_results = await run_hybrid_search( - query_text=q, - search_type=search_type, - end_user_id=end_user_id, - limit=adjusted_limit, - include=["chunks", "statements", "entities", "summaries"], - output_path=None, - rerank_alpha=0.6, - use_forgetting_rerank=False, - use_llm_rerank=False - ) - - # 处理旧版本的返回结构(包含 reranked_results) - if search_results and isinstance(search_results, dict): - # 对于 hybrid 搜索,使用 reranked_results - if "reranked_results" in search_results: - reranked = search_results["reranked_results"] - chunks = reranked.get("chunks", []) - statements = reranked.get("statements", []) - entities = reranked.get("entities", []) - summaries = reranked.get("summaries", []) - else: - # 单一搜索类型的结果 - chunks = search_results.get("chunks", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - summaries = search_results.get("summaries", []) - - # 检查是否有有效结果 - if chunks or statements or entities or summaries: - print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 陈述, {len(entities)} 实体, {len(summaries)} 摘要") - else: - # 如果顶层没有结果,尝试旧的嵌套结构(向后兼容) - reranked = search_results.get("reranked_results", {}) - if reranked and isinstance(reranked, dict): - chunks = reranked.get("chunks", []) - statements = reranked.get("statements", []) - entities = reranked.get("entities", []) - summaries = reranked.get("summaries", []) - print(f"✅ 混合检索成功(使用旧格式reranked结果): {len(chunks)} chunks, {len(statements)} 陈述") - else: - raise ValueError("混合检索返回空结果") - else: - raise ValueError("混合检索返回空结果") - - except Exception as e: - print(f"❌ 混合检索失败: {e},回退到嵌入检索") - 
search_results = await search_graph_by_embedding( - connector=connector, - embedder_client=embedder, - query_text=q, - end_user_id=end_user_id, - limit=adjusted_limit, - include=["chunks", "statements", "entities", "summaries"], - ) - chunks = search_results.get("chunks", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - summaries = search_results.get("summaries", []) - print(f"✅ 回退嵌入检索成功: {len(chunks)} chunks, {len(statements)} 陈述") - - # 🎯 统一处理:构建上下文(所有检索类型共用) - for c in chunks: - content = str(c.get("content", "")).strip() - if content: - contexts_all.append(content) - - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - - for sm in summaries: - summary_text = str(sm.get("summary", "")).strip() - if summary_text: - contexts_all.append(summary_text) - - # 实体摘要:最多加入前3个高分实体 - if entities: - scored = [e for e in entities if e.get("score") is not None] - top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3] - if top_entities: - summary_lines = [] - for e in top_entities: - name = str(e.get("name", "")).strip() - etype = str(e.get("entity_type", "")).strip() - score = e.get("score") - if name: - meta = [] - if etype: - meta.append(f"type={etype}") - if isinstance(score, (int, float)): - meta.append(f"score={score:.3f}") - summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}") - if summary_lines: - contexts_all.append("\n".join(summary_lines)) - - # 关键修复:过滤掉包含当前问题答案的上下文 - filtered_contexts = [] - for context in contexts_all: - content = str(context) - # 排除包含当前问题标准答案的上下文 - if ref_str and ref_str.strip() and ref_str.strip() in content: - print("🚫 过滤掉包含标准答案的上下文") - continue - filtered_contexts.append(context) - - print(f"📊 过滤后保留 {len(filtered_contexts)} 个上下文 (原 {len(contexts_all)} 个)") - contexts_all = filtered_contexts - - # 输出完整的检索结果信息 - print("🔍 
检索结果详情:") - if search_results: - output_data = { - "statements": [ - { - "statement": s.get("statement", "")[:200] + "..." if len(s.get("statement", "")) > 200 else s.get("statement", ""), - "score": s.get("score", 0.0) - } - for s in (statements[:2] if 'statements' in locals() else []) - ], - "dialogues": [ - { - "uuid": d.get("uuid", ""), - "end_user_id": d.get("end_user_id", ""), - "content": d.get("content", "")[:200] + "..." if len(d.get("content", "")) > 200 else d.get("content", ""), - "score": d.get("score", 0.0) - } - for d in (dialogs[:2] if 'dialogs' in locals() else []) - ], - "entities": [ - { - "name": e.get("name", ""), - "entity_type": e.get("entity_type", ""), - "score": e.get("score", 0.0) - } - for e in (entities[:2] if 'entities' in locals() else []) - ] - } - print(json.dumps(output_data, ensure_ascii=False, indent=2)) - else: - print(" 无检索结果") - - except Exception as e: - print(f"❌ {search_type}检索失败: {e}") - contexts_all = [] - search_results = None - - t1 = time.time() - latencies_search.append((t1 - t0) * 1000) - - # 使用智能上下文选择 - context_text = "" - if contexts_all: - context_text = smart_context_selection(contexts_all, q, max_chars=max_chars) - - # 如果智能选择后仍然过长,进行最终保护性截断 - if len(context_text) > max_chars: - print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断") - context_text = context_text[:max_chars] + "\n\n[最终截断...]" - - # 时间解析 - anchor_date = datetime(2023, 5, 8) # 使用固定日期确保一致性 - context_text = _resolve_relative_times(context_text, anchor_date) - - context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text - - print(f"📝 最终上下文长度: {len(context_text)} 字符") - - # 显示不同上下文的预览 - print("🔍 上下文预览:") - for j, context in enumerate(contexts_all[:3]): # 显示前3个上下文 - preview = context[:150].replace('\n', ' ') - print(f" 上下文{j+1}: {preview}...") - - else: - print("❌ 没有检索到有效上下文") - context_text = "No relevant context found." 
- - # 记录上下文诊断信息 - per_query_context_counts.append(len(contexts_all)) - per_query_context_avg_tokens.append(avg_context_tokens([context_text])) - per_query_context_chars.append(len(context_text)) - per_query_context_tokens_total.append(len(_loc_normalize(context_text).split())) - - # LLM 提示词 - messages = [ - {"role": "system", "content": ( - "You are a precise QA assistant. Answer following these rules:\n" - "1) Extract the EXACT information mentioned in the context\n" - "2) For time questions: calculate actual dates from relative times\n" - "3) Return ONLY the answer text in simplest form\n" - "4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n" - "5) If no clear answer found, respond with 'Unknown'" - )}, - {"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"}, - ] - - t2 = time.time() - # 使用异步调用 - resp = await llm_client.chat(messages=messages) - t3 = time.time() - latencies_llm.append((t3 - t2) * 1000) - - # 兼容不同的响应格式 - pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown") - - # 计算指标(确保使用字符串) - f1_val = common_f1(str(pred), ref_str) - b1_val = bleu1(str(pred), ref_str) - j_val = jaccard(str(pred), ref_str) - - f1s.append(f1_val) - b1s.append(b1_val) - jss.append(j_val) - - # Accumulate by category - cat_counts[cat] = cat_counts.get(cat, 0) + 1 - cat_f1s.setdefault(cat, []).append(f1_val) - cat_b1s.setdefault(cat, []).append(b1_val) - cat_jss.setdefault(cat, []).append(j_val) - - # LoCoMo 专用 F1:multi-hop(1) 使用多答案 F1,其它(2/3/4)使用单答案 F1 - if item.get("category") in [2, 3, 4]: - loc_val = loc_f1_score(str(pred), ref_str) - elif item.get("category") in [1]: - loc_val = loc_multi_f1(str(pred), ref_str) - else: - loc_val = loc_f1_score(str(pred), ref_str) - loc_f1s.append(loc_val) - cat_loc_f1s.setdefault(cat, []).append(loc_val) - - # 保存完整的检索结果信息 - samples.append({ - "question": q, - "answer": ref_str, - "category": cat, - "prediction": 
pred, - "metrics": { - "f1": f1_val, - "b1": b1_val, - "j": j_val, - "loc_f1": loc_val - }, - "retrieval": { - "retrieved_documents": len(contexts_all), - "context_length": len(context_text), - "search_limit": adjusted_limit, - "max_chars": max_chars - }, - "timing": { - "search_ms": (t1 - t0) * 1000, - "llm_ms": (t3 - t2) * 1000 - } - }) - - print(f"🤖 LLM 回答: {pred}") - print(f"✅ 正确答案: {ref_str}") - print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: {j_val:.3f}, LoCoMo F1: {loc_val:.3f}") - - # Compute per-category averages and dispersion (std, iqr) - def _percentile(sorted_vals: List[float], p: float) -> float: - if not sorted_vals: - return 0.0 - if len(sorted_vals) == 1: - return sorted_vals[0] - k = (len(sorted_vals) - 1) * p - f = int(k) - c = f + 1 if f + 1 < len(sorted_vals) else f - if f == c: - return sorted_vals[f] - return sorted_vals[f] + (sorted_vals[c] - sorted_vals[f]) * (k - f) - - by_category: Dict[str, Dict[str, float | int]] = {} - for c in cat_counts: - f_list = cat_f1s.get(c, []) - b_list = cat_b1s.get(c, []) - j_list = cat_jss.get(c, []) - lf_list = cat_loc_f1s.get(c, []) - j_sorted = sorted(j_list) - j_std = statistics.stdev(j_list) if len(j_list) > 1 else 0.0 - j_q75 = _percentile(j_sorted, 0.75) - j_q25 = _percentile(j_sorted, 0.25) - by_category[c] = { - "count": cat_counts[c], - "f1": (sum(f_list) / max(len(f_list), 1)) if f_list else 0.0, - "b1": (sum(b_list) / max(len(b_list), 1)) if b_list else 0.0, - "j": (sum(j_list) / max(len(j_list), 1)) if j_list else 0.0, - "j_std": j_std, - "j_iqr": (j_q75 - j_q25) if j_list else 0.0, - # 参考 LoCoMo 评测的类别专用 F1 - "loc_f1": (sum(lf_list) / max(len(lf_list), 1)) if lf_list else 0.0, - } - - # 累加命中(cum accuracy by category):与 evaluation_stats.py 输出形式相仿 - cum_accuracy_by_category = {c: sum(cat_loc_f1s.get(c, [])) for c in cat_counts} - - result = { - "dataset": "locomo", - "items": len(items), - "metrics": { - "f1": sum(f1s) / max(len(f1s), 1), - "b1": sum(b1s) / max(len(b1s), 1), - 
"j": sum(jss) / max(len(jss), 1), - # LoCoMo 类别专用 F1 的总体 - "loc_f1": sum(loc_f1s) / max(len(loc_f1s), 1), - }, - "by_category": by_category, - "category_counts": cat_counts, - "cum_accuracy_by_category": cum_accuracy_by_category, - "context": { - "avg_tokens": (sum(per_query_context_avg_tokens) / max(len(per_query_context_avg_tokens), 1)) if per_query_context_avg_tokens else 0.0, - "avg_chars": (sum(per_query_context_chars) / max(len(per_query_context_chars), 1)) if per_query_context_chars else 0.0, - "count_avg": (sum(per_query_context_counts) / max(len(per_query_context_counts), 1)) if per_query_context_counts else 0.0, - "avg_memory_tokens": (sum(per_query_context_tokens_total) / max(len(per_query_context_tokens_total), 1)) if per_query_context_tokens_total else 0.0, - }, - "latency": { - "search": latency_stats(latencies_search), - "llm": latency_stats(latencies_llm), - }, - "samples": samples, - "params": { - "end_user_id": end_user_id, - "search_limit": search_limit, - "context_char_budget": context_char_budget, - "search_type": search_type, - "llm_id": llm_id, - "retrieval_embedding_id": embedding_id, - "chunker_strategy": os.getenv("EVAL_CHUNKER_STRATEGY", "RecursiveChunker"), - "skip_ingest_if_exists": skip_ingest_if_exists, - "llm_timeout": llm_timeout, - "llm_max_retries": llm_max_retries, - "llm_temperature": llm_temperature, - "llm_max_tokens": llm_max_tokens - }, - "timestamp": datetime.now().isoformat() - } - if output_path: - try: - os.makedirs(os.path.dirname(output_path), exist_ok=True) - with open(output_path, "w", encoding="utf-8") as f: - json.dump(result, f, ensure_ascii=False, indent=2) - print(f"✅ 结果已保存到: {output_path}") - except Exception as e: - print(f"❌ 保存结果失败: {e}") - return result - finally: - await connector.close() - - -def main(): - parser = argparse.ArgumentParser(description="Run LoCoMo evaluation with Qwen search") - parser.add_argument("--sample_size", type=int, default=1, help="Number of samples to evaluate") - 
parser.add_argument("--end_user_id", type=str, default=None, help="Group ID for retrieval") - parser.add_argument("--search_limit", type=int, default=8, help="Search limit per query") - parser.add_argument("--context_char_budget", type=int, default=12000, help="Max characters for context") - parser.add_argument("--llm_temperature", type=float, default=0.0, help="LLM temperature") - parser.add_argument("--llm_max_tokens", type=int, default=32, help="LLM max tokens") - parser.add_argument("--search_type", type=str, default="embedding", choices=["keyword", "embedding", "hybrid"], help="Search type") - parser.add_argument("--output_path", type=str, default=None, help="Output path for results") - parser.add_argument("--skip_ingest_if_exists", action="store_true", help="Skip ingest if group exists") - parser.add_argument("--llm_timeout", type=float, default=10.0, help="LLM timeout in seconds") - parser.add_argument("--llm_max_retries", type=int, default=1, help="LLM max retries") - args = parser.parse_args() - - load_dotenv() - - result = asyncio.run(run_locomo_eval( - sample_size=args.sample_size, - end_user_id=args.end_user_id, - search_limit=args.search_limit, - context_char_budget=args.context_char_budget, - llm_temperature=args.llm_temperature, - llm_max_tokens=args.llm_max_tokens, - search_type=args.search_type, - output_path=args.output_path, - skip_ingest_if_exists=args.skip_ingest_if_exists, - llm_timeout=args.llm_timeout, - llm_max_retries=args.llm_max_retries - )) - - print("\n" + "="*50) - print("📊 最终评测结果:") - print(f" 样本数量: {result['items']}") - print(f" F1: {result['metrics']['f1']:.3f}") - print(f" BLEU-1: {result['metrics']['b1']:.3f}") - print(f" Jaccard: {result['metrics']['j']:.3f}") - print(f" LoCoMo F1: {result['metrics']['loc_f1']:.3f}") - print(f" 平均上下文长度: {result['context']['avg_chars']:.0f} 字符") - print(f" 平均检索延迟: {result['latency']['search']['mean']:.1f}ms") - print(f" 平均LLM延迟: {result['latency']['llm']['mean']:.1f}ms") - - if 
result['by_category']: - print("\n📈 按类别细分:") - for cat, metrics in result['by_category'].items(): - print(f" {cat}:") - print(f" 样本数: {metrics['count']}") - print(f" F1: {metrics['f1']:.3f}") - print(f" LoCoMo F1: {metrics['loc_f1']:.3f}") - print(f" Jaccard: {metrics['j']:.3f} (±{metrics['j_std']:.3f}, IQR={metrics['j_iqr']:.3f})") - - -if __name__ == "__main__": - main() diff --git a/api/app/core/memory/evaluation/longmemeval/longmemeval_benchmark.py b/api/app/core/memory/evaluation/longmemeval/longmemeval_benchmark.py deleted file mode 100644 index aaf46e35..00000000 --- a/api/app/core/memory/evaluation/longmemeval/longmemeval_benchmark.py +++ /dev/null @@ -1,1339 +0,0 @@ -import argparse -import asyncio -import json -import os -import time -import re -import statistics -from datetime import datetime, timedelta -from typing import List, Dict, Any -from pathlib import Path - -from dotenv import load_dotenv - -# Load evaluation config -eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation" -if eval_config_path.exists(): - load_dotenv(eval_config_path, override=True) - print(f"✅ 加载评估配置: {eval_config_path}") - -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline -from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding -from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.models.base import RedBearModelConfig -from app.core.memory.utils.config.config_utils import get_embedder_config -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME -from app.core.memory.evaluation.common.metrics import f1_score as common_f1, jaccard, latency_stats, avg_context_tokens -from app.core.memory.evaluation.common.metrics import exact_match - - -def load_dataset_any(path: str) -> List[Dict[str, Any]]: - 
"""健壮地加载数据集,支持三种格式: - 1. 标准 JSON 数组: [{...}, {...}] - 2. 单个 JSON 对象: {...} - 3. JSONL 格式(每行一个 JSON): {...}\n{...}\n{...} - """ - with open(path, "r", encoding="utf-8") as f: - content = f.read().strip() - - # 尝试标准 JSON 解析 - try: - data = json.loads(content) - if isinstance(data, list): - return [item for item in data if isinstance(item, dict)] - elif isinstance(data, dict): - return [data] - except json.JSONDecodeError: - pass - - # 尝试 JSONL 格式(每行一个 JSON 对象) - items = [] - for line in content.splitlines(): - line = line.strip() - if not line: - continue - try: - obj = json.loads(line) - if isinstance(obj, dict): - items.append(obj) - elif isinstance(obj, list): - items.extend(item for item in obj if isinstance(item, dict)) - except json.JSONDecodeError: - continue - - return items - - -def is_chinese_text(s: str) -> bool: - return bool(re.search(r"[\u4e00-\u9fff]", s or "")) - - -def build_context_from_sessions(item: Dict[str, Any]) -> List[str]: - """从数据项的 haystack_sessions 构建上下文片段。 - - 优先返回包含 has_answer 的消息 - - 其次返回拼接后的整段会话 - """ - contexts: List[str] = [] - sessions = item.get("haystack_sessions", []) or item.get("sessions", []) - for session in sessions: - parts: List[str] = [] - if isinstance(session, list): - for msg in session: - role = msg.get("role", "") - content = msg.get("content", "") or msg.get("text", "") - if content: - parts.append(f"{role}: {content}" if role else str(content)) - if msg.get("has_answer", False): - contexts.append(f"{role}: {content}" if role else str(content)) - elif isinstance(session, dict): - role = session.get("role", "") - content = session.get("content", "") or session.get("text", "") - if content: - parts.append(f"{role}: {content}" if role else str(content)) - if session.get("has_answer", False): - contexts.append(f"{role}: {content}" if role else str(content)) - if parts: - contexts.append("\n".join(parts)) - # 兜底:存在单字段上下文 - if not contexts: - single_ctx = item.get("context") or item.get("dialogue") or 
item.get("conversation") - if isinstance(single_ctx, str) and single_ctx.strip(): - contexts.append(single_ctx.strip()) - return contexts - - -def extract_candidate_options(question: str) -> List[str]: - """从问题中提取候选选项(A-or-B 类问题)。""" - q = (question or "").strip() - options: List[str] = [] - - # 1) 引号包裹的片段 - for pat in [r"'([^']+)'", r'\"([^\"]+)\"', r'“([^”]+)”', r'‘([^’]+)’']: - for m in re.findall(pat, q): - val = (m or "").strip() - if val: - options.append(val) - - # 2) or/还是/或者 连接词 - if len(options) < 2: - pats = [ - r"([^,;,;]+?)\s+or\s+([^,;,;\?\.!.。!]+)", - r"([^,;,;]+?)\s+还是\s+([^,;,;\?\.!.。!]+)", - r"([^,;,;]+?)\s+或者\s+([^,;,;\?\.!.。!]+)", - ] - for pat in pats: - matches = list(re.finditer(pat, q, flags=re.IGNORECASE)) - if matches: - m = matches[-1] - cand1 = m.group(1).strip().strip("??.,,;; ") - cand2 = m.group(2).strip().strip("??.,,;; ") - options.extend([cand1, cand2]) - break - - # 去重 - seen = set() - uniq: List[str] = [] - for o in options: - o2 = o.strip() - key = o2.lower() if not is_chinese_text(o2) else o2 - if o2 and key not in seen: - uniq.append(o2) - seen.add(key) - return uniq - - -def extract_time_entities(text: str) -> List[Dict[str, Any]]: - """增强时间实体提取,专门用于时间推理问题""" - time_entities = [] - - # 日期模式 - date_patterns = [ - (r'\b(\d{4})-(\d{1,2})-(\d{1,2})\b', 'date'), # YYYY-MM-DD - (r'\b(\d{1,2})月(\d{1,2})日\b', 'date'), # 中文日期 - (r'\b(January|February|March|April|May|June|July|August|September|October|November|December)\s+(\d{1,2}),?\s+(\d{4})?', 'date'), # 英文月份 - (r'\b(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s+(\d{1,2}),?\s+(\d{4})?', 'date'), # 英文月份缩写 - ] - - # 时间间隔模式 - duration_patterns = [ - (r'(\d+)\s*天', 'days'), - (r'(\d+)\s*周', 'weeks'), - (r'(\d+)\s*个月', 'months'), - (r'(\d+)\s*年', 'years'), - (r'(\d+)\s*days?', 'days'), - (r'(\d+)\s*weeks?', 'weeks'), - (r'(\d+)\s*months?', 'months'), - (r'(\d+)\s*years?', 'years'), - ] - - # 事件时间关系模式 - temporal_relation_patterns = [ - (r'(之前|以前|前)\s*(\d+)\s*天', 'days_before'), - 
(r'(之后|以后|后)\s*(\d+)\s*天', 'days_after'), - (r'(\d+)\s*天\s*(之前|以前|前)', 'days_before'), - (r'(\d+)\s*天\s*(之后|以后|后)', 'days_after'), - (r'(\d+)\s*days?\s*(before|ago)', 'days_before'), - (r'(\d+)\s*days?\s*(after|later)', 'days_after'), - ] - - # 提取日期 - for pattern, entity_type in date_patterns: - matches = re.finditer(pattern, text, re.IGNORECASE) - for match in matches: - time_entities.append({ - 'text': match.group(), - 'type': entity_type, - 'start': match.start(), - 'end': match.end() - }) - - # 提取时间间隔 - for pattern, entity_type in duration_patterns: - matches = re.finditer(pattern, text, re.IGNORECASE) - for match in matches: - time_entities.append({ - 'text': match.group(), - 'type': entity_type, - 'value': int(match.group(1)), - 'start': match.start(), - 'end': match.end() - }) - - # 提取时间关系 - for pattern, entity_type in temporal_relation_patterns: - matches = re.finditer(pattern, text, re.IGNORECASE) - for match in matches: - time_entities.append({ - 'text': match.group(), - 'type': entity_type, - 'value': int(match.group(2)) if match.groups() >= 2 else int(match.group(1)), - 'start': match.start(), - 'end': match.end() - }) - - return time_entities - - -def calculate_time_difference(date1: str, date2: str) -> int: - """计算两个日期之间的天数差""" - try: - # 解析日期格式 - def parse_date(date_str: str) -> datetime: - # 尝试多种日期格式 - formats = [ - '%Y-%m-%d', - '%m月%d日', - '%B %d, %Y', - '%b %d, %Y', - '%Y年%m月%d日' - ] - - for fmt in formats: - try: - return datetime.strptime(date_str, fmt) - except ValueError: - continue - - # 如果都无法解析,返回当前日期 - return datetime.now() - - d1 = parse_date(date1) - d2 = parse_date(date2) - - # 计算天数差(绝对值) - return abs((d2 - d1).days) - except Exception: - return -1 # 表示计算失败 - - -def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str: - """增强版上下文选择:特别优化时间推理问题的处理""" - if not contexts: - return "" - - # 检测是否为时间推理问题 - is_temporal_question = any(keyword in question.lower() for keyword in - ['days', 'day', 'before', 
'after', 'first', '先后', '顺序', '间隔', '多久', '多少天']) - - # 提取时间实体从问题中 - question_time_entities = extract_time_entities(question) - - # 英文关键词(去停用词) - question_lower = question.lower() - stop_words = { - 'what','when','where','who','why','how','did','do','does','is','are','was','were', - 'the','a','an','and','or','but','many','which','first' - } - eng_words = [w for w in set(re.findall(r'\b\w+\b', question_lower)) - if w not in stop_words and len(w) > 2] - - # 中文片段与候选选项 - cn_tokens = generate_query_keywords_cn(question) - options = extract_candidate_options(question) - - # 时间推理问题的特殊处理 - if is_temporal_question: - # 为时间问题添加时间相关关键词 - time_keywords = ['天', '日', '月', '年', 'before', 'after', 'days', 'first', '先后'] - eng_words = [w for w in eng_words if w not in ['days', 'first']] # 避免重复 - cn_tokens.extend([kw for kw in time_keywords if kw not in cn_tokens]) - - # 限制关键词数量,优先时间相关 - tokens = time_keywords[:2] + cn_tokens[:2] + eng_words[:1] + options[:1] - else: - # 常规问题处理 - tokens = cn_tokens[:3] + options[:2] + eng_words[:1] - - # 去重 - seen = set() - final_tokens: List[str] = [] - for t in tokens: - t2 = t.strip() - if t2 and t2 not in seen: - final_tokens.append(t2) - seen.add(t2) - - scored_contexts: List[tuple[float, str]] = [] - - # 时间推理问题的权重映射 - temporal_weight_map = { - "天": 2.0, "日": 2.0, "月": 1.8, "年": 1.8, "days": 2.0, - "before": 1.5, "after": 1.5, "first": 1.5, "先后": 1.5 - } - - # 常规问题的权重映射 - normal_weight_map = { - "问题": 2.0, "故障": 2.0, "异常": 1.8, "不正常": 1.8, "坏了": 1.8, - "系统": 1.3, "GPS": 1.5, "保养": 1.4, "设备": 1.2, "模块": 1.2, "功能": 1.1 - } - - weight_map = temporal_weight_map if is_temporal_question else normal_weight_map - - for i, context in enumerate(contexts): - context_str = str(context) - lines = re.split(r'[\r\n]+', context_str) - hit_lines: List[str] = [] - kw_hits: float = 0.0 - time_entity_count = 0 - - for line in lines: - ln = line.strip() - if not ln: - continue - - has_keyword = False - # 关键词匹配 - for tok in final_tokens: - if tok and tok in ln: - w 
= weight_map.get(tok, 1.0) - kw_hits += ln.count(tok) * w - has_keyword = True - - # 时间实体检测(特别针对时间推理问题) - if is_temporal_question: - time_entities = extract_time_entities(ln) - time_entity_count += len(time_entities) - if time_entities: - has_keyword = True - - if has_keyword: - # 对于时间推理问题,保留包含时间信息的完整行 - hit_lines.append(ln) - - snippet = "\n".join(hit_lines) if hit_lines else context_str.strip() - - # 限制单段长度,但对时间推理问题稍微放宽限制 - max_snippet_len = 600 if is_temporal_question else 500 - if len(snippet) > max_snippet_len: - snippet = snippet[:max_snippet_len] - - # 评分逻辑 - has_number = 1 if re.search(r'\d', snippet) else 0 - has_date = 1 if (re.search(r'\b\d{4}-\d{1,2}-\d{1,2}\b', snippet) or - re.search(r'\d{1,2}月\d{1,2}日', snippet)) else 0 - - # 时间推理问题的特殊评分 - if is_temporal_question: - time_bonus = time_entity_count * 2.0 # 时间实体奖励 - temporal_coherence = 3 if (has_date and time_entity_count >= 2) else 0 - else: - time_bonus = 0 - temporal_coherence = 0 - - length_bonus = 5 if 50 < len(snippet) < 1000 else (2 if len(snippet) >= 1000 else 0) - pos_bonus = 3 if i < 3 else 0 - - score = (kw_hits * 0.8 + (has_number + has_date) * 1.5 + - length_bonus + pos_bonus + time_bonus + temporal_coherence) - - scored_contexts.append((score, snippet)) - - # 选择累计至总字符预算 - scored_contexts.sort(key=lambda x: x[0], reverse=True) - selected: List[str] = [] - total_chars = 0 - - for score, snippet in scored_contexts: - if total_chars + len(snippet) <= max_chars: - selected.append(snippet) - total_chars += len(snippet) - else: - if not selected and len(snippet) > max_chars: - selected.append(snippet[:max_chars]) - break - - final_context = "\n\n".join(selected) - - # 对于时间推理问题,添加时间计算提示 - if is_temporal_question and question_time_entities: - time_prompt = "\n\n[时间推理提示:请仔细分析上述上下文中的日期和时间关系,计算时间间隔或确定事件顺序]" - if total_chars + len(time_prompt) <= max_chars: - final_context += time_prompt - - return final_context - - -# 中文关键词提取(短语级,含数词/日期/常见领域词) -def _extract_cn_tokens(text: str) -> List[str]: - if not 
text: - return [] - t = str(text) - # 去掉常见功能词(粗略,不依赖分词库) - stop_words = [ - "我","我们","你","他","她","它","这","那","哪","一个","一次","一些","什么","怎么","是否","吗","呢", - "很","更","最","已经","正在","将要","马上","尽快","最近","关于","有关","以及","并且","或者","还是", - "因为","所以","如果","但是","而且","然后","之后","之前","同时","另外","并","但","却","被","把","让","给", - "和","与","跟","及","还有","就","都","在","对","对于","的","了","着","过","到","于","从","以","为","向","至","是" - ] - for sw in stop_words: - t = t.replace(sw, " ") - # 去标点 - t = re.sub(r"[,。!?、;:,.!?;:\"'()()[]\[\]\-—…·]", " ", t) - # 基础中文片段(>=2) - base = re.findall(r"[\u4e00-\u9fff]{2,}", t) - # 特殊组合:第X次XXXX - specials = re.findall(r"第[一二三四五六七八九十]+次[\u4e00-\u9fff]{2,6}", text) - # 领域词(简单词典) - # 日期与数字 - dates = re.findall(r"\d{4}年\d{1,2}月\d{1,2}日|\d{1,2}月\d{1,2}日|\d{4}-\d{1,2}-\d{1,2}", text) - numbers = re.findall(r"\b\d+\b", text) - - tokens: List[str] = specials + base + dates + numbers - - generic = {"建议","推荐","帮助","提升","技能","有效","团队","参与度","喜欢","开始"} - tokens: List[str] = specials + base + dates + numbers - uniq: List[str] = [] - seen = set() - for tok in tokens: - tok2 = tok.strip() - if len(tok2) < 2 or len(tok2) > 6: - continue - if tok2 in generic: - continue - if tok2 not in seen: - uniq.append(tok2) - seen.add(tok2) - # 排除常见疑问型短语 - blacklist_exact = {"是什么","多少","多少天","哪个","哪些","之间","先","后","之前","之后"} - uniq2: List[str] = [u for u in uniq if u not in blacklist_exact] - return uniq2[:12] - - -# 面向检索的中文关键词生成:强调"短语、核心名词、问题/故障" -def generate_query_keywords_cn(question: str) -> List[str]: - if not question: - return [] - raw = _extract_cn_tokens(question) - core: List[str] = [] - seen = set() - - def push(x: str): - x2 = x.strip() - if not x2: - return - if 2 <= len(x2) <= 6 and x2 not in seen: - core.append(x2) - seen.add(x2) - - # 检测时间推理问题 - is_temporal = any(keyword in question for keyword in ['天', '日', 'before', 'after', 'first', '先后', '间隔']) - if is_temporal: - push("天") - push("日") - push("先后") - - # 明确优先的核心词 - if "新车" in question: - push("新车") - # 第X次保养/维修 - specials = 
re.findall(r"第[一二三四五六七八九十]+次[\u4e00-\u9fff]{2,6}", question) - for s in specials: - if "保养" in s or "维修" in s: - push(s) - if "保养" in question: - push("保养") - # 问题/故障类词,如题含"问题"则扩展同义词 - if "问题" in question: - for w in ["问题","故障","异常","不正常"]: - push(w) - - # 补充:从原始片段筛更短的名词短语(过滤疑问型词) - blacklist = {"是什么","多少","哪个","还是","或者","之间","先","后","之前","之后"} - for tok in raw: - if tok in blacklist: - continue - push(tok) - - # 限制数量,避免过长列表影响检索稳定性 - return core[:4] # 稍微增加限制 - - -# 通过别名匹配进行实体关键词检索(多token合并) -async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], end_user_id: str | None, limit: int) -> List[Dict[str, Any]]: - results: List[Dict[str, Any]] = [] - try: - for tok in tokens: - rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, end_user_id=end_user_id, limit=limit) - if rows: - results.extend(rows) - except Exception: - pass - - # 按 name 去重 - deduped: List[Dict[str, Any]] = [] - seen = set() - for r in results: - k = str(r.get("name", "")) - if k and k not in seen: - deduped.append(r) - seen.add(k) - return deduped - - -# 通过对话/陈述中的entity_ids反查实体名称 -_FETCH_ENTITIES_BY_IDS = """ -MATCH (e:ExtractedEntity) -WHERE e.id IN $ids AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id) -RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type -""" - -async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], end_user_id: str | None) -> List[Dict[str, Any]]: - if not ids: - return [] - try: - rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), end_user_id=end_user_id) - return rows or [] - except Exception: - return [] - - -# 增强的时间实体检索 -_TIME_ENTITY_SEARCH = """ -MATCH (e:ExtractedEntity) -WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern -AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id) -RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type 
-LIMIT $limit -""" - -async def _search_time_entities(connector: Neo4jConnector, end_user_id: str | None, limit: int = 5) -> List[Dict[str, Any]]: - """专门搜索时间相关的实体""" - try: - date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*" - rows = await connector.execute_query(_TIME_ENTITY_SEARCH, - date_pattern=date_pattern, - end_user_id=end_user_id, - limit=limit) - return rows or [] - except Exception: - return [] - - -# 中英相对时间解析:today/昨天/上周/3天后 等简单归一化为日期 -def _resolve_relative_times_cn_en(text: str, anchor: datetime) -> str: - t = str(text) if text is not None else "" - # 英文 today/yesterday/tomorrow - t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE) - - # 英文 X days ago / in X days - def _ago_repl(m: re.Match[str]) -> str: - n = int(m.group(1)) - return (anchor - timedelta(days=n)).date().isoformat() - def _in_repl(m: re.Match[str]) -> str: - n = int(m.group(1)) - return (anchor + timedelta(days=n)).date().isoformat() - t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE) - t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE) - t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE) - - # 中文 今天/昨天/明天 - t = re.sub(r"今天", anchor.date().isoformat(), t) - t = re.sub(r"昨日|昨天", (anchor - timedelta(days=1)).date().isoformat(), t) - t = re.sub(r"明天", (anchor + timedelta(days=1)).date().isoformat(), t) - # 中文 X天前 / X天后 - t = re.sub(r"(\d+)天前", lambda m: (anchor - timedelta(days=int(m.group(1)))).date().isoformat(), t) - t = re.sub(r"(\d+)天后", lambda m: (anchor + timedelta(days=int(m.group(1)))).date().isoformat(), t) - # 中文 上周 / 下周(近似7天) - t = re.sub(r"上周", (anchor 
- timedelta(days=7)).date().isoformat(), t) - t = re.sub(r"下周", (anchor + timedelta(days=7)).date().isoformat(), t) - # 中文 月日(无年份)补全年份 - def _md_repl(m: re.Match[str]) -> str: - mon = int(m.group(1)); day = int(m.group(2)) - return f"{anchor.year}-{mon:02d}-{day:02d}" - t = re.sub(r"(\d{1,2})月(\d{1,2})日", _md_repl, t) - return t - - -async def run_longmemeval_test( - sample_size: int = 3, - end_user_id: str | None = None, - search_limit: int = 8, - context_char_budget: int = 4000, - llm_temperature: float = 0.0, - llm_max_tokens: int = 16, - search_type: str = "hybrid", - data_path: str | None = None, - start_index: int = 0, - max_contexts_per_item: int = 2, - save_chunk_output: bool = True, - save_chunk_output_path: str | None = None, - reset_group_before_ingest: bool = False, - skip_ingest: bool = False, -) -> Dict[str, Any]: - """LongMemEval 评估测试:增强时间推理能力""" - - # Use environment variable with fallback chain - if end_user_id is None: - end_user_id = os.getenv("LONGMEMEVAL_END_USER_ID") or os.getenv("EVAL_END_USER_ID", "longmemeval_zh_bak_3") - - # 数据路径 - if not data_path: - # 固定使用中文数据集:dataset/longmemeval_oracle_zh.json - dataset_dir = Path(__file__).resolve().parent.parent / "dataset" - data_path = str(dataset_dir / "longmemeval_oracle_zh.json") - - if not os.path.exists(data_path): - raise FileNotFoundError( - f"数据集文件不存在: {data_path}\n" - f"请将 longmemeval_oracle_zh.json 放置在: {dataset_dir}" - ) - - qa_list: List[Dict[str, Any]] = load_dataset_any(data_path) - # 支持评估全部样本:当 sample_size <= 0 时,取从 start_index 到末尾 - if sample_size is None or sample_size <= 0: - items = qa_list[start_index:] - else: - items = qa_list[start_index:start_index + sample_size] - - # 可选:摄入上下文(默认启用) - if not skip_ingest: - # 选择上下文并限量 - contexts: List[str] = [] - for it in items: - built = build_context_from_sessions(it) - full_transcripts = [c for c in built if "\n" in c] - evidence_msgs = [c for c in built if "\n" not in c] - selected: List[str] = [] - take_e = min(len(evidence_msgs), 
max_contexts_per_item) - selected.extend(evidence_msgs[:take_e]) - remain = max_contexts_per_item - len(selected) - if remain > 0 and full_transcripts: - selected.extend(full_transcripts[:remain]) - if not selected and built: - selected.append(built[0]) - contexts.extend(selected) - - print(f"📥 摄入 {len(contexts)} 个上下文到数据库") - if reset_group_before_ingest and end_user_id: - try: - _tmp_conn = Neo4jConnector() - await _tmp_conn.delete_group(end_user_id) - print(f"🧹 已清空组 {end_user_id} 的历史图数据") - except Exception as _e: - print(f"⚠️ 清空组数据失败(忽略继续): {end_user_id} - {_e}") - finally: - try: - await _tmp_conn.close() - except Exception: - pass - _ingest_fn = ingest_contexts_via_full_pipeline - if _ingest_fn is None: - print("⚠️ 摄入函数不可用,已跳过摄入。请确认 PYTHONPATH 包含 'src' 或从项目根运行。") - else: - await _ingest_fn( - contexts, - end_user_id, - save_chunk_output=save_chunk_output, - save_chunk_output_path=save_chunk_output_path, - ) - - # 初始化组件(摄入后再初始化连接器)- 使用异步LLM客户端 - from app.db import get_db - - db = next(get_db()) - try: - llm_client = get_llm_client(os.getenv("EVAL_LLM_ID"), db) - cfg_dict = get_embedder_config(os.getenv("EVAL_EMBEDDING_ID"), db) - embedder = OpenAIEmbedderClient( - model_config=RedBearModelConfig.model_validate(cfg_dict) - ) - finally: - db.close() - - connector = Neo4jConnector() - - # 指标收集 - latencies_llm: List[float] = [] - latencies_search: List[float] = [] - per_query_context_counts: List[int] = [] - per_query_context_avg_tokens: List[float] = [] - per_query_context_chars: List[int] = [] - - type_correct: Dict[str, List[float]] = {} - type_f1: Dict[str, List[float]] = {} - type_jacc: Dict[str, List[float]] = {} - - samples: List[Dict[str, Any]] = [] - # 统计重复的上下文预览(跨样本),便于诊断"相同上下文"问题 - preview_counter: Dict[str, int] = {} - - try: - for item in items: - question = item.get("question", "") - reference = item.get("answer", "") - qtype = item.get("question_type") or item.get("type", "unknown") - - print(f"\n=== 处理问题: {question} ===") - - # 检测问题类型 - is_temporal 
= any(keyword in question.lower() for keyword in - ['days', 'day', 'before', 'after', 'first', '先后', '顺序', '间隔', '多久', '多少天']) - - # 检索 - t0 = time.time() - contexts_all: List[str] = [] - dialogs, statements, entities = [], [], [] - - try: - if search_type == "embedding": - search_results = await search_graph_by_embedding( - connector=connector, - embedder_client=embedder, - query_text=question, - end_user_id=end_user_id, - limit=search_limit, - include=["chunks", "statements", "entities", "summaries"], - ) - chunks = search_results.get("chunks", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - - for d in dialogs: - content = str(d.get("content", "")).strip() - if content: - contexts_all.append(content) - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - - for sm in summaries: - summary_text = str(sm.get("summary", "")).strip() - if summary_text: - contexts_all.append(summary_text) - - # 实体摘要(最多3个) - scored = [e for e in entities if e.get("score") is not None] - top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3] - if top_entities: - summary_lines = [] - for e in top_entities: - name = str(e.get("name", "")).strip() - etype = str(e.get("entity_type", "")).strip() - score = e.get("score") - if name: - meta = [] - if etype: - meta.append(f"type={etype}") - if isinstance(score, (int, float)): - meta.append(f"score={score:.3f}") - summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}") - if summary_lines: - contexts_all.append("\n".join(summary_lines)) - - elif search_type == "keyword": - search_results = await search_graph( - connector=connector, - q=question, - end_user_id=end_user_id, - limit=search_limit, - ) - chunks = search_results.get("chunks", []) - statements = search_results.get("statements", []) - entities = 
search_results.get("entities", []) - summaries = search_results.get("summaries", []) - - for c in chunks: - content = str(c.get("content", "")).strip() - if content: - contexts_all.append(content) - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - for sm in summaries: - summary_text = str(sm.get("summary", "")).strip() - if summary_text: - contexts_all.append(summary_text) - if entities: - entity_names = [str(e.get("name", "")).strip() for e in entities[:5] if e.get("name")] - if entity_names: - contexts_all.append(f"EntitySummary: {', '.join(entity_names)}") - - else: # hybrid(增强版:特别优化时间推理问题) - emb_chunks, emb_statements, emb_entities, emb_summaries, emb_dialogs = [], [], [], [], [] - kw_dialogs, kw_statements, kw_entities = [], [], [] - - # 1) 嵌入检索 - try: - emb_res = await search_graph_by_embedding( - connector=connector, - embedder_client=embedder, - query_text=question, - end_user_id=end_user_id, - limit=search_limit, - include=["chunks", "statements", "entities", "summaries"], - ) - if isinstance(emb_res, dict): - emb_chunks = emb_res.get("chunks", []) or [] - emb_statements = emb_res.get("statements", []) or [] - emb_entities = emb_res.get("entities", []) or [] - emb_summaries = emb_res.get("summaries", []) or [] - emb_dialogs = emb_res.get("dialogues", []) or [] - except Exception as e: - print(f"⚠️ 嵌入检索失败,将继续进行关键词检索: {e}") - - # 2) 关键词检索(增强版) - try: - kw_res = await search_graph( - connector=connector, - q=question, - end_user_id=end_user_id, - limit=search_limit, - ) - if isinstance(kw_res, dict): - kw_dialogs = kw_res.get("dialogues", []) or [] - kw_statements = kw_res.get("statements", []) or [] - kw_entities = kw_res.get("entities", []) or [] - - # 时间推理问题的特殊处理 - if is_temporal: - # 专门搜索时间实体 - time_entities = await _search_time_entities(connector, end_user_id, search_limit//2) - if time_entities: - kw_entities.extend(time_entities) - # 添加时间相关关键词检索 - time_keywords = ['天', '日', '月', 
'年', 'before', 'after', 'first'] - for tk in time_keywords: - try: - time_res = await search_graph( - connector=connector, - q=tk, - end_user_id=end_user_id, - limit=2, - ) - if isinstance(time_res, dict): - kw_dialogs.extend(time_res.get("dialogues", []) or []) - kw_statements.extend(time_res.get("statements", []) or []) - except Exception: - pass - - # 中文关键词拆分后做别名匹配 - cn_tokens = _extract_cn_tokens(question) - alias_entities = await _search_entities_by_aliases(connector, cn_tokens, end_user_id, search_limit) - if alias_entities: - kw_entities.extend(alias_entities) - - # 从对话/陈述中的 entity_ids 反查实体 - ids = [] - try: - for d in kw_dialogs: - ids.extend(d.get("entity_ids", []) or []) - for s in kw_statements: - ids.extend(s.get("entity_ids", []) or []) - except Exception: - pass - if ids: - id_entities = await _fetch_entities_by_ids(connector, ids, end_user_id) - if id_entities: - kw_entities.extend(id_entities) - - # 多关键词检索 - try: - eng_words = [w for w in set(re.findall(r"\b\w+\b", question.lower())) if len(w) > 2] - kw_list = generate_query_keywords_cn(question)[:3] + eng_words[:1] - for kw in kw_list: - if not kw: - continue - sub_res = await search_graph( - connector=connector, - q=str(kw), - end_user_id=end_user_id, - limit=max(3, search_limit // 2), - ) - if isinstance(sub_res, dict): - kw_dialogs.extend(sub_res.get("dialogues", []) or []) - kw_statements.extend(sub_res.get("statements", []) or []) - kw_entities.extend(sub_res.get("entities", []) or []) - except Exception: - pass - - # 选项参与关键词检索 - try: - opt_list = extract_candidate_options(question)[:2] - for opt in opt_list: - if not opt: - continue - opt_res = await search_graph( - connector=connector, - q=str(opt), - end_user_id=end_user_id, - limit=max(3, search_limit // 2), - ) - if isinstance(opt_res, dict): - kw_dialogs.extend(opt_res.get("dialogues", []) or []) - kw_statements.extend(opt_res.get("statements", []) or []) - kw_entities.extend(opt_res.get("entities", []) or []) - except Exception: - pass 
- except Exception as e: - print(f"❌ 关键词检索失败: {e}") - - # 3) 合并、排序并去重 - all_dialogs = emb_dialogs + kw_dialogs - all_statements = emb_statements + kw_statements - all_entities = emb_entities + kw_entities - - def dedup(items: List[Dict[str, Any]], key_field: str = "uuid") -> List[Dict[str, Any]]: - seen = set() - out = [] - for it in items: - key = str(it.get(key_field, "")) + str(it.get("content", "") + str(it.get("statement", ""))) - if key not in seen: - out.append(it) - seen.add(key) - return out - - # 时间推理问题优先排序包含时间信息的文档 - if is_temporal: - def temporal_score(item: Dict[str, Any]) -> float: - base_score = float(item.get("score", 0.0)) - content = str(item.get("content", "") + str(item.get("statement", ""))) - time_entities = extract_time_entities(content) - time_bonus = len(time_entities) * 0.5 - return base_score + time_bonus - - dialogs = dedup(sorted(all_dialogs, key=temporal_score, reverse=True)) - statements = dedup(sorted(all_statements, key=temporal_score, reverse=True)) - else: - dialogs = dedup(sorted(all_dialogs, key=lambda d: float(d.get("score", 0.0)), reverse=True)) - statements = dedup(sorted(all_statements, key=lambda s: float(s.get("score", 0.0)), reverse=True)) - - entities = dedup(all_entities, key_field="name") - - # 4) 构建上下文 - for d in dialogs: - content = str(d.get("content", "")).strip() - if content: - contexts_all.append(content) - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - # 实体摘要 - try: - scored = [e for e in entities if e.get("score") is not None] - top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3] - if top_entities: - summary_lines = [] - for e in top_entities: - name = str(e.get("name", "")).strip() - etype = str(e.get("entity_type", "")).strip() - score = e.get("score") - if name: - meta = [] - if etype: - meta.append(f"type={etype}") - if isinstance(score, (int, float)): - 
meta.append(f"score={score:.3f}") - summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}") - if summary_lines: - contexts_all.append("\n".join(summary_lines)) - except Exception: - pass - - # 全局回退 - if not contexts_all and search_type in ("embedding", "hybrid"): - try: - print("🔁 检索为空,回退到关键词检索...") - kw_fallback = await search_graph( - connector=connector, - q=question, - end_user_id=end_user_id, - limit=max(search_limit, 5), - ) - fb_dialogs = kw_fallback.get("dialogues", []) or [] - fb_statements = kw_fallback.get("statements", []) or [] - fb_entities = kw_fallback.get("entities", []) or [] - - for d in fb_dialogs: - content = str(d.get("content", "")).strip() - if content: - contexts_all.append(content) - for s in fb_statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - if fb_entities: - entity_names = [str(e.get("name", "")).strip() for e in fb_entities[:5] if e.get("name")] - if entity_names: - contexts_all.append(f"EntitySummary: {', '.join(entity_names)}") - - dialogs = fb_dialogs if fb_dialogs else dialogs - statements = fb_statements if fb_statements else statements - entities = fb_entities if fb_entities else entities - print(f"↩️ 回退到关键词检索: {len(fb_dialogs)} 对话, {len(fb_statements)} 条陈述, {len(fb_entities)} 个实体") - except Exception as fe: - print(f"❌ 关键词回退失败: {fe}") - - ent_count = len(entities) if isinstance(entities, list) else 0 - print(f"✅ {search_type}检索成功: {len(dialogs)} 对话, {len(statements)} 条陈述, {ent_count} 个实体") - if is_temporal: - print("⏰ 检测为时间推理问题,已启用时间优化检索") - - except Exception as e: - print(f"❌ {search_type}检索失败: {e}") - contexts_all = [] - - t1 = time.time() - latencies_search.append((t1 - t0) * 1000) - - # 智能上下文选择 - context_text = "" - if contexts_all: - context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) - # 相对时间解析 - try: - context_text = _resolve_relative_times_cn_en(context_text, 
anchor=datetime.now()) - except Exception: - pass - # 诊断信息 - try: - cn_diag = generate_query_keywords_cn(question)[:3] - opts = extract_candidate_options(question)[:2] - qlw = [w for w in set(re.findall(r'\b\w+\b', question.lower())) if len(w) > 2][:1] - diag_tokens: List[str] = [] - for t in cn_diag + opts + qlw: - if t and t not in diag_tokens: - diag_tokens.append(t) - print(f"🔍 关键词/选项: {', '.join(diag_tokens)}") - preview = context_text[:200].replace('\n', ' ') - print(f"🔎 上下文预览: {preview}...") - key_preview = preview.strip() - if key_preview: - preview_counter[key_preview] = preview_counter.get(key_preview, 0) + 1 - except Exception: - pass - else: - print("❌ 没有检索到有效上下文") - context_text = "No relevant context found." - - # 记录上下文诊断信息 - per_query_context_counts.append(len(contexts_all)) - per_query_context_avg_tokens.append(avg_context_tokens([context_text])) - per_query_context_chars.append(len(context_text)) - - # LLM 推理(增强时间推理提示) - options = extract_candidate_options(question) - if len(options) >= 2: - opt_lines = "\n".join(f"- {o}" for o in options) - # 时间推理问题的特殊提示 - if is_temporal: - system_prompt = ( - "You are a QA assistant specializing in temporal reasoning. Analyze the dates and time relationships in the context carefully. " - "Return ONLY one string: exactly one option from the provided candidates. If the context is insufficient, respond with 'Unknown'. " - "Pay special attention to date sequences and time intervals." - ) - else: - system_prompt = ( - "You are a QA assistant. Respond in the same language as the question. Return ONLY one string: exactly one option from the provided candidates. " - "If the context is insufficient, respond with 'Unknown'. If the context expresses a synonym or paraphrase of a candidate, return the closest candidate. " - "Do not include explanations." 
- ) - - messages = [ - {"role": "system", "content": system_prompt}, - { - "role": "user", - "content": ( - f"Question: {question}\n\nCandidates:\n{opt_lines}\n\nContext:\n{context_text}\n\nReturn EXACTLY one candidate string (or 'Unknown')." - ), - }, - ] - else: - # 时间推理问题的特殊提示 - if is_temporal: - system_prompt = ( - "You are a QA assistant specializing in temporal reasoning. Analyze the dates and time relationships in the context carefully. " - "If the context contains the answer, return a concise answer phrase focusing on temporal information. " - "If the answer cannot be determined from the context, respond with 'Unknown'. Return ONLY the final answer string, no explanations." - ) - else: - system_prompt = ( - "You are a QA assistant. Respond in the same language as the question. If the context contains the answer, return a concise answer phrase. " - "If the answer cannot be determined from the context, respond with 'Unknown'. Return ONLY the final answer string, no explanations." - ) - - messages = [ - {"role": "system", "content": system_prompt}, - { - "role": "user", - "content": f"Question: {question}\n\nContext:\n{context_text}\n\nReturn ONLY the answer (or 'Unknown').", - }, - ] - - t2 = time.time() - # 使用异步调用 - resp = await llm_client.chat(messages=messages) - t3 = time.time() - latencies_llm.append((t3 - t2) * 1000) - - # 兼容不同的响应格式 - pred_raw = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown") - - # 选项题输出规范化 - pred = pred_raw - if len(options) >= 2 and not pred_raw.lower().startswith("unknown"): - def _basic_norm(s: str) -> str: - s = s.lower().strip() - return re.sub(r"[^\w\s]", " ", s) - def _jaccard(a: str, b: str) -> float: - ta = set(t for t in _basic_norm(a).split() if t) - tb = set(t for t in _basic_norm(b).split() if t) - if not ta and not tb: - return 1.0 - if not ta or not tb: - return 0.0 - return len(ta & tb) / len(ta | tb) - best = None - 
best_score = -1.0 - for o in options: - score = _jaccard(pred_raw, o) - if score > best_score: - best = o - best_score = score - if best is not None and best_score > 0.0: - pred = best - - # 指标 - flag = exact_match(pred, reference) - f1_val = common_f1(str(pred), str(reference)) - j_val = jaccard(str(pred), str(reference)) - - type_correct.setdefault(qtype, []).append(flag) - type_f1.setdefault(qtype, []).append(f1_val) - type_jacc.setdefault(qtype, []).append(j_val) - - samples.append({ - "question": question, - "prediction": pred, - "answer": reference, - "question_type": qtype, - "is_temporal": is_temporal, - "question_id": item.get("question_id"), - "options": options, - "context_count": len(contexts_all), - "context_chars": len(context_text), - "retrieved_dialogue_count": len(dialogs), - "retrieved_statement_count": len(statements), - "metrics": { - "exact_match": bool(flag), - "f1": f1_val, - "jaccard": j_val - }, - "timing": { - "search_ms": (t1 - t0) * 1000, - "llm_ms": (t3 - t2) * 1000 - } - }) - - print(f"🤖 LLM 回答: {pred}") - print(f"✅ 正确答案: {reference}") - print(f"📈 当前指标 - Exact Match: {flag}, F1: {f1_val:.3f}, Jaccard: {j_val:.3f}") - - # 聚合结果 - type_acc = {t: (sum(v) / max(len(v), 1)) for t, v in type_correct.items()} - f1_by_type = {t: (sum(v) / max(len(v), 1)) for t, v in type_f1.items()} - jacc_by_type = {t: (sum(v) / max(len(v), 1)) for t, v in type_jacc.items()} - - result = { - "dataset": "longmemeval", - "items": len(items), - "accuracy_by_type": type_acc, - "f1_by_type": f1_by_type, - "jaccard_by_type": jacc_by_type, - "samples": samples, - "latency": { - "search": latency_stats(latencies_search), - "llm": latency_stats(latencies_llm), - }, - "context": { - "avg_tokens": statistics.mean(per_query_context_avg_tokens) if per_query_context_avg_tokens else 0.0, - "avg_chars": statistics.mean(per_query_context_chars) if per_query_context_chars else 0.0, - "count_avg": statistics.mean(per_query_context_counts) if per_query_context_counts else 0.0, - 
}, - "params": { - "end_user_id": end_user_id, - "search_limit": search_limit, - "context_char_budget": context_char_budget, - "search_type": search_type, - "llm_id": os.getenv("EVAL_LLM_ID"), - "embedding_id": os.getenv("EVAL_EMBEDDING_ID"), - "sample_size": sample_size, - "start_index": start_index, - }, - "timestamp": datetime.now().isoformat() - } - - # 计算汇总指标 - try: - total_items = max(len(samples), 1) - correct_count = sum(1 for s in samples if s.get("metrics", {}).get("exact_match")) - score_accuracy = (correct_count / total_items) * 100.0 - - total_latencies_ms = [] - for s in samples: - t = s.get("timing", {}) - total_latencies_ms.append(float(t.get("search_ms", 0.0)) + float(t.get("llm_ms", 0.0))) - total_lat_stats = latency_stats(total_latencies_ms) if total_latencies_ms else {"p50": 0.0, "iqr": 0.0} - latency_median_s = total_lat_stats.get("p50", 0.0) / 1000.0 - latency_iqr_s = total_lat_stats.get("iqr", 0.0) / 1000.0 - - avg_ctx_tokens = statistics.mean(per_query_context_avg_tokens) if per_query_context_avg_tokens else 0.0 - avg_ctx_tokens_k = avg_ctx_tokens / 1000.0 - - result["metric_summary"] = { - "score_accuracy": score_accuracy, - "latency_median_s": latency_median_s, - "latency_iqr_s": latency_iqr_s, - "avg_context_tokens_k": avg_ctx_tokens_k, - } - except Exception: - result["metric_summary"] = { - "score_accuracy": 0.0, - "latency_median_s": 0.0, - "latency_iqr_s": 0.0, - "avg_context_tokens_k": 0.0, - } - - # 诊断信息 - try: - dups = sorted([(k, c) for k, c in preview_counter.items() if c > 1], key=lambda x: -x[1])[:5] - result["diagnostics"] = { - "duplicate_previews_top": [{"count": c, "preview": k[:120]} for k, c in dups], - "unique_preview_count": len(preview_counter), - } - except Exception: - pass - - return result - - finally: - await connector.close() - -def main(): - load_dotenv() - parser = argparse.ArgumentParser(description="LongMemEval 评估测试脚本(增强时间推理版)") - parser.add_argument("--sample-size", type=int, default=3, help="样本数量(<=0 
表示全部)") - parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size)") - parser.add_argument("--start-index", type=int, default=0, help="起始样本索引") - parser.add_argument("--end-user-id", type=str, default=None, help="图数据库 End User ID,默认使用环境变量") - parser.add_argument("--search-limit", type=int, default=8, help="检索条数上限") - parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算") - parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度") - parser.add_argument("--llm-max-tokens", type=int, default=16, help="LLM 最大输出 token") - parser.add_argument("--search-type", type=str, default="hybrid", choices=["embedding","keyword","hybrid"], help="检索类型") - parser.add_argument("--data-path", type=str, default=None, help="数据集路径") - parser.add_argument("--max-contexts-per-item", type=int, default=2, help="每条样本最多摄入的上下文段数") - parser.add_argument("--no-save-chunk-output", action="store_true", help="不保存分块结果(默认保存)") - parser.add_argument("--save-chunk-output-path", type=str, default=None, help="自定义分块输出路径") - parser.add_argument("--reset-group-before-ingest", action="store_true", help="摄入前清空该 Group 在图数据库中的历史数据") - parser.add_argument("--skip-ingest", action="store_true", help="跳过摄入,仅检索评估") - args = parser.parse_args() - - sample_size = 0 if args.all else args.sample_size - - result = asyncio.run( - run_longmemeval_test( - sample_size=sample_size, - end_user_id=args.end_user_id, - search_limit=args.search_limit, - context_char_budget=args.context_char_budget, - llm_temperature=args.llm_temperature, - llm_max_tokens=args.llm_max_tokens, - search_type=args.search_type, - data_path=args.data_path, - start_index=args.start_index, - max_contexts_per_item=args.max_contexts_per_item, - save_chunk_output=(not args.no_save_chunk_output), - save_chunk_output_path=args.save_chunk_output_path, - reset_group_before_ingest=args.reset_group_before_ingest, - skip_ingest=args.skip_ingest, - ) - ) - - # 打印结果 - print("\n" + "="*50) - 
print("📊 LongMemEval 测试结果:") - print(f" 样本数量: {result['items']}") - - if result['accuracy_by_type']: - print("\n📈 按问题类型细分:") - for qtype, acc in result['accuracy_by_type'].items(): - print(f" {qtype}:") - print(f" Score (Accuracy): {acc:.3f}") - - print(f"\n📊 指标总览:") - ms = result.get('metric_summary', {}) - print(f" Score (Accuracy): {ms.get('score_accuracy', 0.0):.1f}%") - print(f" Latency (s): median {ms.get('latency_median_s', 0.0):.3f}s") - print(f" Latency IQR (s): {ms.get('latency_iqr_s', 0.0):.3f}s") - print(f" Avg Context Tokens (k): {ms.get('avg_context_tokens_k', 0.0):.3f}k") - - print(f"\n⏱️ 细分性能指标:") - print(f" 检索延迟(均值): {result['latency']['search']['mean']:.1f}ms") - print(f" LLM延迟(均值): {result['latency']['llm']['mean']:.1f}ms") - print(f" 上下文长度(均值): {result['context']['avg_chars']:.0f} 字符") - - - # 保存结果到文件 - try: - # 使用相对路径而不是 PROJECT_ROOT - out_dir = Path(__file__).resolve().parent / "results" - os.makedirs(out_dir, exist_ok=True) - ts = datetime.now().strftime("%Y%m%d_%H%M%S") - out_path = os.path.join(out_dir, f"longmemeval_{result['params']['search_type']}_{ts}.json") - with open(out_path, "w", encoding="utf-8") as f: - json.dump(result, f, ensure_ascii=False, indent=2) - print(f"\n💾 结果已保存: {out_path}") - except Exception as e: - print(f"⚠️ 结果保存失败: {e}") - - -if __name__ == "__main__": - main() diff --git a/api/app/core/memory/evaluation/longmemeval/test_eval.py b/api/app/core/memory/evaluation/longmemeval/test_eval.py deleted file mode 100644 index 08daa890..00000000 --- a/api/app/core/memory/evaluation/longmemeval/test_eval.py +++ /dev/null @@ -1,1312 +0,0 @@ -import argparse -import asyncio -import json -import os -import time -import re -import statistics -from datetime import datetime, timedelta -from typing import List, Dict, Any -from pathlib import Path - -from dotenv import load_dotenv - -# Load evaluation config -eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation" -if eval_config_path.exists(): - 
load_dotenv(eval_config_path, override=True) - print(f"✅ 加载评估配置: {eval_config_path}") - -# 与现有评估脚本保持一致的导入方式 -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding -from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.models.base import RedBearModelConfig -from app.core.memory.utils.config.config_utils import get_embedder_config -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME -from app.core.memory.evaluation.common.metrics import f1_score as common_f1, jaccard, latency_stats, avg_context_tokens -from app.core.memory.evaluation.common.metrics import exact_match - - -def load_dataset_any(path: str) -> List[Dict[str, Any]]: - """健壮地加载数据集,支持三种格式: - 1. 标准 JSON 数组: [{...}, {...}] - 2. 单个 JSON 对象: {...} - 3. JSONL 格式(每行一个 JSON): {...}\n{...}\n{...} - """ - with open(path, "r", encoding="utf-8") as f: - content = f.read().strip() - - # 尝试标准 JSON 解析 - try: - data = json.loads(content) - if isinstance(data, list): - return [item for item in data if isinstance(item, dict)] - elif isinstance(data, dict): - return [data] - except json.JSONDecodeError: - pass - - # 尝试 JSONL 格式(每行一个 JSON 对象) - items = [] - for line in content.splitlines(): - line = line.strip() - if not line: - continue - try: - obj = json.loads(line) - if isinstance(obj, dict): - items.append(obj) - elif isinstance(obj, list): - items.extend(item for item in obj if isinstance(item, dict)) - except json.JSONDecodeError: - continue - - return items - - -def is_chinese_text(s: str) -> bool: - return bool(re.search(r"[\u4e00-\u9fff]", s or "")) - - -def extract_candidate_options(question: str) -> List[str]: - """从问题中提取候选选项(A-or-B 类问题)。""" - q = (question or "").strip() - options: List[str] = [] - - # 1) 引号包裹的片段 - for pat in [r"'([^']+)'", r'\"([^\"]+)\"', r'“([^”]+)”', r'‘([^’]+)’']: - 
for m in re.findall(pat, q): - val = (m or "").strip() - if val: - options.append(val) - - # 2) or/还是/或者 连接词 - if len(options) < 2: - pats = [ - r"([^,;,;]+?)\s+or\s+([^,;,;\?\.!.。!]+)", - r"([^,;,;]+?)\s+还是\s+([^,;,;\?\.!.。!]+)", - r"([^,;,;]+?)\s+或者\s+([^,;,;\?\.!.。!]+)", - ] - for pat in pats: - matches = list(re.finditer(pat, q, flags=re.IGNORECASE)) - if matches: - m = matches[-1] - cand1 = m.group(1).strip().strip("??.,,;; ") - cand2 = m.group(2).strip().strip("??.,,;; ") - options.extend([cand1, cand2]) - break - - # 去重 - seen = set() - uniq: List[str] = [] - for o in options: - o2 = o.strip() - key = o2.lower() if not is_chinese_text(o2) else o2 - if o2 and key not in seen: - uniq.append(o2) - seen.add(key) - return uniq - - -def extract_time_entities(text: str) -> List[Dict[str, Any]]: - """增强时间实体提取,专门用于时间推理问题""" - time_entities = [] - - # 日期模式 - date_patterns = [ - (r'\b(\d{4})-(\d{1,2})-(\d{1,2})\b', 'date'), # YYYY-MM-DD - (r'\b(\d{1,2})月(\d{1,2})日\b', 'date'), # 中文日期 - (r'\b(January|February|March|April|May|June|July|August|September|October|November|December)\s+(\d{1,2}),?\s+(\d{4})?', 'date'), # 英文月份 - (r'\b(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s+(\d{1,2}),?\s+(\d{4})?', 'date'), # 英文月份缩写 - ] - - # 时间间隔模式 - duration_patterns = [ - (r'(\d+)\s*天', 'days'), - (r'(\d+)\s*周', 'weeks'), - (r'(\d+)\s*个月', 'months'), - (r'(\d+)\s*年', 'years'), - (r'(\d+)\s*days?', 'days'), - (r'(\d+)\s*weeks?', 'weeks'), - (r'(\d+)\s*months?', 'months'), - (r'(\d+)\s*years?', 'years'), - ] - - # 事件时间关系模式 - temporal_relation_patterns = [ - (r'(之前|以前|前)\s*(\d+)\s*天', 'days_before'), - (r'(之后|以后|后)\s*(\d+)\s*天', 'days_after'), - (r'(\d+)\s*天\s*(之前|以前|前)', 'days_before'), - (r'(\d+)\s*天\s*(之后|以后|后)', 'days_after'), - (r'(\d+)\s*days?\s*(before|ago)', 'days_before'), - (r'(\d+)\s*days?\s*(after|later)', 'days_after'), - ] - - # 提取日期 - for pattern, entity_type in date_patterns: - matches = re.finditer(pattern, text, re.IGNORECASE) - for match in matches: - 
time_entities.append({ - 'text': match.group(), - 'type': entity_type, - 'start': match.start(), - 'end': match.end() - }) - - # 提取时间间隔 - for pattern, entity_type in duration_patterns: - matches = re.finditer(pattern, text, re.IGNORECASE) - for match in matches: - time_entities.append({ - 'text': match.group(), - 'type': entity_type, - 'value': int(match.group(1)), - 'start': match.start(), - 'end': match.end() - }) - - # 提取时间关系 - for pattern, entity_type in temporal_relation_patterns: - matches = re.finditer(pattern, text, re.IGNORECASE) - for match in matches: - time_entities.append({ - 'text': match.group(), - 'type': entity_type, - 'value': int(match.group(2)) if match.groups() >= 2 else int(match.group(1)), - 'start': match.start(), - 'end': match.end() - }) - - return time_entities - - -def calculate_time_difference(date1: str, date2: str) -> int: - """计算两个日期之间的天数差""" - try: - # 解析日期格式 - def parse_date(date_str: str) -> datetime: - # 尝试多种日期格式 - formats = [ - '%Y-%m-%d', - '%m月%d日', - '%B %d, %Y', - '%b %d, %Y', - '%Y年%m月%d日' - ] - - for fmt in formats: - try: - return datetime.strptime(date_str, fmt) - except ValueError: - continue - - # 如果都无法解析,返回当前日期 - return datetime.now() - - d1 = parse_date(date1) - d2 = parse_date(date2) - - # 计算天数差(绝对值) - return abs((d2 - d1).days) - except Exception: - return -1 # 表示计算失败 - - -def _extract_cn_tokens(text: str) -> List[str]: - """中文关键词提取(短语级,含数词/日期/常见领域词)""" - if not text: - return [] - t = str(text) - # 去掉常见功能词(粗略,不依赖分词库) - stop_words = [ - "我","我们","你","他","她","它","这","那","哪","一个","一次","一些","什么","怎么","是否","吗","呢", - "很","更","最","已经","正在","将要","马上","尽快","最近","关于","有关","以及","并且","或者","还是", - "因为","所以","如果","但是","而且","然后","之后","之前","同时","另外","并","但","却","被","把","让","给", - "和","与","跟","及","还有","就","都","在","对","对于","的","了","着","过","到","于","从","以","为","向","至","是" - ] - for sw in stop_words: - t = t.replace(sw, " ") - # 去标点 - t = re.sub(r"[,。!?、;:,.!?;:\"'()()[]\[\]\-—…·]", " ", t) - # 基础中文片段(>=2) - base = 
re.findall(r"[\u4e00-\u9fff]{2,}", t) - # 特殊组合:第X次XXXX - specials = re.findall(r"第[一二三四五六七八九十]+次[\u4e00-\u9fff]{2,6}", text) - # 日期与数字 - dates = re.findall(r"\d{4}年\d{1,2}月\d{1,2}日|\d{1,2}月\d{1,2}日|\d{4}-\d{1,2}-\d{1,2}", text) - numbers = re.findall(r"\b\d+\b", text) - - generic = {"建议","推荐","帮助","提升","技能","有效","团队","参与度","喜欢","开始"} - tokens: List[str] = specials + base + dates + numbers - uniq: List[str] = [] - seen = set() - for tok in tokens: - tok2 = tok.strip() - if len(tok2) < 2 or len(tok2) > 6: - continue - if tok2 in generic: - continue - if tok2 not in seen: - uniq.append(tok2) - seen.add(tok2) - # 排除常见疑问型短语 - blacklist_exact = {"是什么","多少","多少天","哪个","哪些","之间","先","后","之前","之后"} - uniq2: List[str] = [u for u in uniq if u not in blacklist_exact] - return uniq2[:12] - - -def generate_query_keywords_cn(question: str) -> List[str]: - """增强版关键词提取,特别关注技术术语和专有名词""" - if not question: - return [] - - # 提取专有名词(带引号的内容) - quoted_terms = re.findall(r'["""]([^"""]+)["""]', question) - - # 提取技术术语(中英文混合) - tech_terms = re.findall(r'[A-Z][a-zA-Z]+\s+[A-Z][a-zA-Z]+|[A-Za-z]+[\u4e00-\u9fff]+|[\u4e00-\u9fff]+[A-Za-z]+', question) - - # 提取核心名词短语 - core_nouns = re.findall(r'[\u4e00-\u9fff]{2,5}系统|[\u4e00-\u9fff]{2,5}管理|[\u4e00-\u9fff]{2,5}分析|[\u4e00-\u9fff]{2,5}工作坊|[\u4e00-\u9fff]{2,5}研讨会', question) - - # 基础中文片段 - base_tokens = _extract_cn_tokens(question) - - # 特定领域关键词增强 - domain_keywords = [] - # GPS相关 - if any(term in question for term in ["GPS", "导航", "定位系统", "系统运行"]): - domain_keywords.extend(["GPS", "导航系统", "定位", "系统故障", "功能异常"]) - # 活动相关 - if any(term in question for term in ["工作坊", "研讨会", "网络研讨会", "活动"]): - domain_keywords.extend(["工作坊", "研讨会", "参加", "参与", "活动"]) - # 时间顺序相关 - if any(term in question for term in ["先", "后", "第一个", "之前", "首先"]): - domain_keywords.extend(["先", "后", "之前", "之后", "第一次", "首先"]) - # 设备相关 - if any(term in question for term in ["设备", "手机", "电脑", "笔记本电脑"]): - domain_keywords.extend(["设备", "手机", "电脑", "笔记本电脑", "购买"]) - - # 合并并去重 - all_tokens = 
quoted_terms + tech_terms + core_nouns + base_tokens + domain_keywords - seen = set() - final_tokens = [] - - for token in all_tokens: - token = token.strip() - if len(token) >= 2 and token not in seen: - final_tokens.append(token) - seen.add(token) - - return final_tokens[:8] - - -def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str: - """增强版上下文选择:特别优化技术术语和精确匹配""" - if not contexts: - return "" - - # 检测是否为时间推理问题 - is_temporal_question = any(keyword in question.lower() for keyword in - ['days', 'day', 'before', 'after', 'first', '先后', '顺序', '间隔', '多久', '多少天']) - - # 提取时间实体从问题中 - question_time_entities = extract_time_entities(question) - - # 提取关键技术实体 - key_entities = [] - # GPS相关 - if any(term in question for term in ["GPS", "导航", "定位系统", "系统运行"]): - key_entities.extend(["GPS", "导航", "定位", "系统", "功能", "问题", "故障"]) - # 活动相关 - if any(term in question for term in ["工作坊", "研讨会", "网络研讨会", "活动"]): - key_entities.extend(["工作坊", "研讨会", "参加", "参与", "活动", "时间"]) - # 时间顺序相关 - if any(term in question for term in ["先", "后", "第一个", "之前", "首先"]): - key_entities.extend(["先", "后", "之前", "之后", "第一次", "首先"]) - - # 英文关键词(去停用词) - question_lower = question.lower() - stop_words = { - 'what','when','where','who','why','how','did','do','does','is','are','was','were', - 'the','a','an','and','or','but','many','which','first' - } - eng_words = [w for w in set(re.findall(r'\b\w+\b', question_lower)) - if w not in stop_words and len(w) > 2] - - # 中文片段与候选选项 - cn_tokens = generate_query_keywords_cn(question) - options = extract_candidate_options(question) - - # 时间推理问题的特殊处理 - if is_temporal_question: - # 为时间问题添加时间相关关键词 - time_keywords = ['天', '日', '月', '年', 'before', 'after', 'days', 'first', '先后'] - eng_words = [w for w in eng_words if w not in ['days', 'first']] # 避免重复 - cn_tokens.extend([kw for kw in time_keywords if kw not in cn_tokens]) - - # 限制关键词数量,优先时间相关 - tokens = time_keywords[:2] + key_entities[:3] + cn_tokens[:2] + eng_words[:1] + options[:1] - 
else: - # 常规问题处理,优先关键技术实体 - tokens = key_entities[:4] + cn_tokens[:3] + options[:2] + eng_words[:1] - - # 去重 - seen = set() - final_tokens: List[str] = [] - for t in tokens: - t2 = t.strip() - if t2 and t2 not in seen: - final_tokens.append(t2) - seen.add(t2) - - scored_contexts: List[tuple[float, str]] = [] - - # 关键技术实体权重映射 - key_entity_weights = { - "GPS": 3.0, "导航": 2.5, "系统": 2.0, "功能": 2.0, "问题": 2.0, "故障": 2.5, - "工作坊": 2.5, "研讨会": 2.5, "参加": 2.0, "参与": 2.0, - "先": 2.0, "后": 2.0, "之前": 2.0, "之后": 2.0, "第一次": 2.5 - } - - # 时间推理问题的权重映射 - temporal_weight_map = { - "天": 2.0, "日": 2.0, "月": 1.8, "年": 1.8, "days": 2.0, - "before": 1.5, "after": 1.5, "first": 1.5, "先后": 1.5 - } - - # 常规问题的权重映射 - normal_weight_map = { - "问题": 2.0, "故障": 2.0, "异常": 1.8, "不正常": 1.8, "坏了": 1.8, - "系统": 1.3, "GPS": 1.5, "保养": 1.4, "设备": 1.2, "模块": 1.2, "功能": 1.1 - } - - # 合并权重映射 - weight_map = {**normal_weight_map, **temporal_weight_map, **key_entity_weights} - - for i, context in enumerate(contexts): - context_str = str(context) - lines = re.split(r'[\r\n]+', context_str) - hit_lines: List[str] = [] - kw_hits: float = 0.0 - time_entity_count = 0 - key_entity_hits = 0 - - for line in lines: - ln = line.strip() - if not ln: - continue - - has_keyword = False - # 关键词匹配 - for tok in final_tokens: - if tok and tok in ln: - w = weight_map.get(tok, 1.0) - hit_count = ln.count(tok) - kw_hits += hit_count * w - # 关键技术实体额外奖励 - if tok in key_entity_weights: - key_entity_hits += hit_count - has_keyword = True - - # 时间实体检测(特别针对时间推理问题) - if is_temporal_question: - time_entities = extract_time_entities(ln) - time_entity_count += len(time_entities) - if time_entities: - has_keyword = True - - # 精确匹配奖励(完整问题关键词出现在上下文中) - for q_word in question.split(): - if len(q_word) > 3 and q_word in ln: - kw_hits += 0.5 # 精确匹配奖励 - - if has_keyword: - # 对于包含关键信息的行,保留完整行 - hit_lines.append(ln) - - snippet = "\n".join(hit_lines) if hit_lines else context_str.strip() - - # 限制单段长度,但对包含关键信息的上下文稍微放宽限制 - max_snippet_len = 
600 if (key_entity_hits > 0 or time_entity_count > 0) else 500 - if len(snippet) > max_snippet_len: - snippet = snippet[:max_snippet_len] - - # 评分逻辑 - has_number = 1 if re.search(r'\d', snippet) else 0 - has_date = 1 if (re.search(r'\b\d{4}-\d{1,2}-\d{1,2}\b', snippet) or - re.search(r'\d{1,2}月\d{1,2}日', snippet)) else 0 - - # 关键技术实体奖励 - key_entity_bonus = key_entity_hits * 1.0 - - # 时间推理问题的特殊评分 - if is_temporal_question: - time_bonus = time_entity_count * 2.0 # 时间实体奖励 - temporal_coherence = 3 if (has_date and time_entity_count >= 2) else 0 - else: - time_bonus = 0 - temporal_coherence = 0 - - length_bonus = 5 if 50 < len(snippet) < 1000 else (2 if len(snippet) >= 1000 else 0) - pos_bonus = 3 if i < 3 else 0 - - score = (kw_hits * 0.8 + (has_number + has_date) * 1.5 + - length_bonus + pos_bonus + time_bonus + temporal_coherence + key_entity_bonus) - - scored_contexts.append((score, snippet)) - - # 选择累计至总字符预算 - scored_contexts.sort(key=lambda x: x[0], reverse=True) - selected: List[str] = [] - total_chars = 0 - - for score, snippet in scored_contexts: - if total_chars + len(snippet) <= max_chars: - selected.append(snippet) - total_chars += len(snippet) - else: - if not selected and len(snippet) > max_chars: - selected.append(snippet[:max_chars]) - break - - final_context = "\n\n".join(selected) - - # 对于时间推理问题,添加时间计算提示 - if is_temporal_question and question_time_entities: - time_prompt = "\n\n[时间推理提示:请仔细分析上述上下文中的日期和时间关系,计算时间间隔或确定事件顺序]" - if total_chars + len(time_prompt) <= max_chars: - final_context += time_prompt - - return final_context - - -# 通过别名匹配进行实体关键词检索(多token合并) -async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], end_user_id: str | None, limit: int) -> List[Dict[str, Any]]: - results: List[Dict[str, Any]] = [] - try: - for tok in tokens: - rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, end_user_id=end_user_id, limit=limit) - if rows: - results.extend(rows) - except Exception: - pass - - # 按 name 去重 - 
deduped: List[Dict[str, Any]] = [] - seen = set() - for r in results: - k = str(r.get("name", "")) - if k and k not in seen: - deduped.append(r) - seen.add(k) - return deduped - - -# 通过对话/陈述中的entity_ids反查实体名称 -_FETCH_ENTITIES_BY_IDS = """ -MATCH (e:ExtractedEntity) -WHERE e.id IN $ids AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id) -RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type -""" - -async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], end_user_id: str | None) -> List[Dict[str, Any]]: - if not ids: - return [] - try: - rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), end_user_id=end_user_id) - return rows or [] - except Exception: - return [] - - -# 增强的时间实体检索 -_TIME_ENTITY_SEARCH = """ -MATCH (e:ExtractedEntity) -WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern -AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id) -RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type -LIMIT $limit -""" - -async def _search_time_entities(connector: Neo4jConnector, end_user_id: str | None, limit: int = 5) -> List[Dict[str, Any]]: - """专门搜索时间相关的实体""" - try: - date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*" - rows = await connector.execute_query(_TIME_ENTITY_SEARCH, - date_pattern=date_pattern, - end_user_id=end_user_id, - limit=limit) - return rows or [] - except Exception: - return [] - - -# 技术术语专门检索 -async def _search_tech_terms(connector: Neo4jConnector, question: str, end_user_id: str | None, limit: int = 3) -> List[Dict[str, Any]]: - """专门搜索技术术语相关的实体""" - tech_entities = [] - try: - # GPS相关 - if any(term in question for term in ["GPS", "导航", "定位系统"]): - gps_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="GPS", end_user_id=end_user_id, limit=limit) - if gps_rows: - tech_entities.extend(gps_rows) - - # 活动相关 - if any(term in question for term in 
["工作坊", "研讨会", "网络研讨会"]): - workshop_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="工作坊", end_user_id=end_user_id, limit=limit) - if workshop_rows: - tech_entities.extend(workshop_rows) - - # 时间顺序相关 - if any(term in question for term in ["先", "后", "第一个"]): - time_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="第一次", end_user_id=end_user_id, limit=limit) - if time_rows: - tech_entities.extend(time_rows) - - except Exception: - pass - - return tech_entities - - -# 中英相对时间解析:today/昨天/上周/3天后 等简单归一化为日期 -def _resolve_relative_times_cn_en(text: str, anchor: datetime) -> str: - t = str(text) if text is not None else "" - # 英文 today/yesterday/tomorrow - t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE) - - # 英文 X days ago / in X days - def _ago_repl(m: re.Match[str]) -> str: - n = int(m.group(1)) - return (anchor - timedelta(days=n)).date().isoformat() - def _in_repl(m: re.Match[str]) -> str: - n = int(m.group(1)) - return (anchor + timedelta(days=n)).date().isoformat() - t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE) - t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE) - t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE) - - # 中文 今天/昨天/明天 - t = re.sub(r"今天", anchor.date().isoformat(), t) - t = re.sub(r"昨日|昨天", (anchor - timedelta(days=1)).date().isoformat(), t) - t = re.sub(r"明天", (anchor + timedelta(days=1)).date().isoformat(), t) - # 中文 X天前 / X天后 - t = re.sub(r"(\d+)天前", lambda m: (anchor - timedelta(days=int(m.group(1)))).date().isoformat(), t) - t = re.sub(r"(\d+)天后", lambda m: (anchor + 
timedelta(days=int(m.group(1)))).date().isoformat(), t) - # 中文 上周 / 下周(近似7天) - t = re.sub(r"上周", (anchor - timedelta(days=7)).date().isoformat(), t) - t = re.sub(r"下周", (anchor + timedelta(days=7)).date().isoformat(), t) - # 中文 月日(无年份)补全年份 - def _md_repl(m: re.Match[str]) -> str: - mon = int(m.group(1)); day = int(m.group(2)) - return f"{anchor.year}-{mon:02d}-{day:02d}" - t = re.sub(r"(\d{1,2})月(\d{1,2})日", _md_repl, t) - return t - - -async def run_longmemeval_test( - sample_size: int = 3, - end_user_id: str = "longmemeval_zh_bak_2", - search_limit: int = 8, - context_char_budget: int = 4000, - llm_temperature: float = 0.0, - llm_max_tokens: int = 16, - search_type: str = "hybrid", - data_path: str | None = None, - start_index: int = 0, -) -> Dict[str, Any]: - """LongMemEval 评估测试:增强技术术语检索能力""" - - # 数据路径 - if not data_path: - # 固定使用中文数据集:dataset/longmemeval_oracle_zh.json - dataset_dir = Path(__file__).resolve().parent.parent / "dataset" - data_path = str(dataset_dir / "longmemeval_oracle_zh.json") - - if not os.path.exists(data_path): - raise FileNotFoundError( - f"数据集文件不存在: {data_path}\n" - f"请将 longmemeval_oracle_zh.json 放置在: {dataset_dir}" - ) - - qa_list: List[Dict[str, Any]] = load_dataset_any(data_path) - # 支持评估全部样本:当 sample_size <= 0 时,取从 start_index 到末尾 - if sample_size is None or sample_size <= 0: - items = qa_list[start_index:] - else: - items = qa_list[start_index:start_index + sample_size] - - # 初始化组件 - 使用异步LLM客户端 - llm_client = get_llm_client(os.getenv("EVAL_LLM_ID")) - connector = Neo4jConnector() - cfg_dict = get_embedder_config(os.getenv("EVAL_EMBEDDING_ID")) - embedder = OpenAIEmbedderClient( - model_config=RedBearModelConfig.model_validate(cfg_dict) - ) - - # 指标收集 - latencies_llm: List[float] = [] - latencies_search: List[float] = [] - per_query_context_counts: List[int] = [] - per_query_context_avg_tokens: List[float] = [] - per_query_context_chars: List[int] = [] - - type_correct: Dict[str, List[float]] = {} - type_f1: Dict[str, List[float]] 
= {} - type_jacc: Dict[str, List[float]] = {} - - samples: List[Dict[str, Any]] = [] - # 统计重复的上下文预览(跨样本),便于诊断"相同上下文"问题 - preview_counter: Dict[str, int] = {} - - try: - for item in items: - question = item.get("question", "") - reference = item.get("answer", "") - qtype = item.get("question_type") or item.get("type", "unknown") - - print(f"\n=== 处理问题: {question} ===") - - # 检测问题类型 - is_temporal = any(keyword in question.lower() for keyword in - ['days', 'day', 'before', 'after', 'first', '先后', '顺序', '间隔', '多久', '多少天']) - - # 检索 - t0 = time.time() - contexts_all: List[str] = [] - dialogs, statements, entities = [], [], [] - - try: - if search_type == "embedding": - search_results = await search_graph_by_embedding( - connector=connector, - embedder_client=embedder, - query_text=question, - end_user_id=end_user_id, - limit=search_limit, - include=["dialogues", "statements", "entities"], - ) - dialogs = search_results.get("dialogues", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - - for d in dialogs: - content = str(d.get("content", "")).strip() - if content: - contexts_all.append(content) - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - # 实体摘要(最多3个) - scored = [e for e in entities if e.get("score") is not None] - top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3] - if top_entities: - summary_lines = [] - for e in top_entities: - name = str(e.get("name", "")).strip() - etype = str(e.get("entity_type", "")).strip() - score = e.get("score") - if name: - meta = [] - if etype: - meta.append(f"type={etype}") - if isinstance(score, (int, float)): - meta.append(f"score={score:.3f}") - summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}") - if summary_lines: - contexts_all.append("\n".join(summary_lines)) - - elif search_type == "keyword": - 
search_results = await search_graph( - connector=connector, - q=question, - end_user_id=end_user_id, - limit=search_limit, - ) - dialogs = search_results.get("dialogues", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - - for d in dialogs: - content = str(d.get("content", "")).strip() - if content: - contexts_all.append(content) - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - if entities: - entity_names = [str(e.get("name", "")).strip() for e in entities[:5] if e.get("name")] - if entity_names: - contexts_all.append(f"EntitySummary: {', '.join(entity_names)}") - - else: # hybrid(增强版:特别优化技术术语检索) - emb_dialogs, emb_statements, emb_entities = [], [], [] - kw_dialogs, kw_statements, kw_entities = [], [], [] - - # 1) 嵌入检索 - try: - emb_res = await search_graph_by_embedding( - connector=connector, - embedder_client=embedder, - query_text=question, - end_user_id=end_user_id, - limit=search_limit, - include=["dialogues", "statements", "entities"], - ) - if isinstance(emb_res, dict): - emb_dialogs = emb_res.get("dialogues", []) or [] - emb_statements = emb_res.get("statements", []) or [] - emb_entities = emb_res.get("entities", []) or [] - except Exception as e: - print(f"⚠️ 嵌入检索失败,将继续进行关键词检索: {e}") - - # 2) 关键词检索(增强版) - try: - kw_res = await search_graph( - connector=connector, - q=question, - end_user_id=end_user_id, - limit=search_limit, - ) - if isinstance(kw_res, dict): - kw_dialogs = kw_res.get("dialogues", []) or [] - kw_statements = kw_res.get("statements", []) or [] - kw_entities = kw_res.get("entities", []) or [] - - # 技术术语专门检索 - tech_entities = await _search_tech_terms(connector, question, end_user_id, search_limit//2) - if tech_entities: - kw_entities.extend(tech_entities) - - # 时间推理问题的特殊处理 - if is_temporal: - # 专门搜索时间实体 - time_entities = await _search_time_entities(connector, end_user_id, search_limit//2) - if time_entities: - 
kw_entities.extend(time_entities) - # 添加时间相关关键词检索 - time_keywords = ['天', '日', '月', '年', 'before', 'after', 'first'] - for tk in time_keywords: - try: - time_res = await search_graph( - connector=connector, - q=tk, - end_user_id=end_user_id, - limit=2, - ) - if isinstance(time_res, dict): - kw_dialogs.extend(time_res.get("dialogues", []) or []) - kw_statements.extend(time_res.get("statements", []) or []) - except Exception: - pass - - # 中文关键词拆分后做别名匹配 - cn_tokens = generate_query_keywords_cn(question) # 使用增强版关键词提取 - alias_entities = await _search_entities_by_aliases(connector, cn_tokens, end_user_id, search_limit) - if alias_entities: - kw_entities.extend(alias_entities) - - # 从对话/陈述中的 entity_ids 反查实体 - ids = [] - try: - for d in kw_dialogs: - ids.extend(d.get("entity_ids", []) or []) - for s in kw_statements: - ids.extend(s.get("entity_ids", []) or []) - except Exception: - pass - if ids: - id_entities = await _fetch_entities_by_ids(connector, ids, end_user_id) - if id_entities: - kw_entities.extend(id_entities) - - # 多关键词检索(使用增强版关键词) - try: - eng_words = [w for w in set(re.findall(r"\b\w+\b", question.lower())) if len(w) > 2] - kw_list = generate_query_keywords_cn(question)[:4] # 使用更多关键词 - for kw in kw_list: - if not kw: - continue - sub_res = await search_graph( - connector=connector, - q=str(kw), - end_user_id=end_user_id, - limit=max(3, search_limit // 2), - ) - if isinstance(sub_res, dict): - kw_dialogs.extend(sub_res.get("dialogues", []) or []) - kw_statements.extend(sub_res.get("statements", []) or []) - kw_entities.extend(sub_res.get("entities", []) or []) - except Exception: - pass - - # 选项参与关键词检索 - try: - opt_list = extract_candidate_options(question)[:2] - for opt in opt_list: - if not opt: - continue - opt_res = await search_graph( - connector=connector, - q=str(opt), - end_user_id=end_user_id, - limit=max(3, search_limit // 2), - ) - if isinstance(opt_res, dict): - kw_dialogs.extend(opt_res.get("dialogues", []) or []) - 
kw_statements.extend(opt_res.get("statements", []) or []) - kw_entities.extend(opt_res.get("entities", []) or []) - except Exception: - pass - except Exception as e: - print(f"❌ 关键词检索失败: {e}") - - # 3) 合并、排序并去重 - all_dialogs = emb_dialogs + kw_dialogs - all_statements = emb_statements + kw_statements - all_entities = emb_entities + kw_entities - - def dedup(items: List[Dict[str, Any]], key_field: str = "uuid") -> List[Dict[str, Any]]: - seen = set() - out = [] - for it in items: - key = str(it.get(key_field, "")) + str(it.get("content", "") + str(it.get("statement", ""))) - if key not in seen: - out.append(it) - seen.add(key) - return out - - # 关键技术实体优先排序 - def enhanced_score(item: Dict[str, Any]) -> float: - score_val = item.get("score", 0.0) - base_score = float(score_val) if score_val is not None else 0.0 - content = str(item.get("content", "") + str(item.get("statement", ""))) - - # 关键技术实体奖励 - key_entities = [] - if any(term in question for term in ["GPS", "导航", "系统"]): - key_entities.extend(["GPS", "导航", "系统", "功能"]) - if any(term in question for term in ["工作坊", "研讨会", "活动"]): - key_entities.extend(["工作坊", "研讨会", "参加"]) - - key_bonus = 0 - for key_ent in key_entities: - if key_ent in content: - key_bonus += 1.0 - - # 时间实体奖励 - time_bonus = 0 - if is_temporal: - time_entities = extract_time_entities(content) - time_bonus = len(time_entities) * 0.5 - - return base_score + key_bonus + time_bonus - - dialogs = dedup(sorted(all_dialogs, key=enhanced_score, reverse=True)) - statements = dedup(sorted(all_statements, key=enhanced_score, reverse=True)) - entities = dedup(all_entities, key_field="name") - - # 4) 构建上下文 - for d in dialogs: - content = str(d.get("content", "")).strip() - if content: - contexts_all.append(content) - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - # 实体摘要 - try: - scored = [e for e in entities if e.get("score") is not None] - top_entities = sorted(scored, key=lambda x: 
x.get("score", 0), reverse=True)[:3] if scored else entities[:3] - if top_entities: - summary_lines = [] - for e in top_entities: - name = str(e.get("name", "")).strip() - etype = str(e.get("entity_type", "")).strip() - score = e.get("score") - if name: - meta = [] - if etype: - meta.append(f"type={etype}") - if isinstance(score, (int, float)): - meta.append(f"score={score:.3f}") - summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}") - if summary_lines: - contexts_all.append("\n".join(summary_lines)) - except Exception: - pass - - # 全局回退 - if not contexts_all and search_type in ("embedding", "hybrid"): - try: - print("🔁 检索为空,回退到关键词检索...") - kw_fallback = await search_graph( - connector=connector, - q=question, - end_user_id=end_user_id, - limit=max(search_limit, 5), - ) - fb_dialogs = kw_fallback.get("dialogues", []) or [] - fb_statements = kw_fallback.get("statements", []) or [] - fb_entities = kw_fallback.get("entities", []) or [] - - for d in fb_dialogs: - content = str(d.get("content", "")).strip() - if content: - contexts_all.append(content) - for s in fb_statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - if fb_entities: - entity_names = [str(e.get("name", "")).strip() for e in fb_entities[:5] if e.get("name")] - if entity_names: - contexts_all.append(f"EntitySummary: {', '.join(entity_names)}") - - dialogs = fb_dialogs if fb_dialogs else dialogs - statements = fb_statements if fb_statements else statements - entities = fb_entities if fb_entities else entities - print(f"↩️ 回退到关键词检索: {len(fb_dialogs)} 对话, {len(fb_statements)} 条陈述, {len(fb_entities)} 个实体") - except Exception as fe: - print(f"❌ 关键词回退失败: {fe}") - - ent_count = len(entities) if isinstance(entities, list) else 0 - print(f"✅ {search_type}检索成功: {len(dialogs)} 对话, {len(statements)} 条陈述, {ent_count} 个实体") - if is_temporal: - print("⏰ 检测为时间推理问题,已启用时间优化检索") - - except Exception as e: - print(f"❌ 
{search_type}检索失败: {e}") - contexts_all = [] - - t1 = time.time() - latencies_search.append((t1 - t0) * 1000) - - # 智能上下文选择 - context_text = "" - if contexts_all: - context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) - # 相对时间解析 - try: - context_text = _resolve_relative_times_cn_en(context_text, anchor=datetime.now()) - except Exception: - pass - # 诊断信息 - try: - cn_diag = generate_query_keywords_cn(question)[:4] # 显示更多关键词 - opts = extract_candidate_options(question)[:2] - qlw = [w for w in set(re.findall(r'\b\w+\b', question.lower())) if len(w) > 2][:1] - diag_tokens: List[str] = [] - for t in cn_diag + opts + qlw: - if t and t not in diag_tokens: - diag_tokens.append(t) - print(f"🔍 关键词/选项: {', '.join(diag_tokens)}") - preview = context_text[:200].replace('\n', ' ') - print(f"🔎 上下文预览: {preview}...") - key_preview = preview.strip() - if key_preview: - preview_counter[key_preview] = preview_counter.get(key_preview, 0) + 1 - except Exception: - pass - else: - print("❌ 没有检索到有效上下文") - context_text = "No relevant context found." - - # 记录上下文诊断信息 - per_query_context_counts.append(len(contexts_all)) - per_query_context_avg_tokens.append(avg_context_tokens([context_text])) - per_query_context_chars.append(len(context_text)) - - # LLM 推理(增强技术术语提示) - options = extract_candidate_options(question) - if len(options) >= 2: - opt_lines = "\n".join(f"- {o}" for o in options) - # 技术术语问题的特殊提示 - if any(term in question for term in ["GPS", "系统", "功能", "工作坊", "研讨会"]): - system_prompt = ( - "You are a QA assistant specializing in technical and activity-related questions. " - "Pay special attention to technical terms like GPS, systems, functions, workshops, and seminars. " - "Return ONLY one string: exactly one option from the provided candidates. If the context is insufficient, respond with 'Unknown'. " - "Focus on matching technical details and activity sequences accurately." 
- ) - elif is_temporal: - system_prompt = ( - "You are a QA assistant specializing in temporal reasoning. Analyze the dates and time relationships in the context carefully. " - "Return ONLY one string: exactly one option from the provided candidates. If the context is insufficient, respond with 'Unknown'. " - "Pay special attention to date sequences and time intervals." - ) - else: - system_prompt = ( - "You are a QA assistant. Respond in the same language as the question. Return ONLY one string: exactly one option from the provided candidates. " - "If the context is insufficient, respond with 'Unknown'. If the context expresses a synonym or paraphrase of a candidate, return the closest candidate. " - "Do not include explanations." - ) - - messages = [ - {"role": "system", "content": system_prompt}, - { - "role": "user", - "content": ( - f"Question: {question}\n\nCandidates:\n{opt_lines}\n\nContext:\n{context_text}\n\nReturn EXACTLY one candidate string (or 'Unknown')." - ), - }, - ] - else: - # 技术术语问题的特殊提示 - if any(term in question for term in ["GPS", "系统", "功能", "工作坊", "研讨会"]): - system_prompt = ( - "You are a QA assistant specializing in technical and activity-related questions. " - "Pay special attention to technical terms like GPS, systems, functions, workshops, and seminars. " - "If the context contains the answer, return a concise answer phrase focusing on technical details. " - "If the answer cannot be determined from the context, respond with 'Unknown'. Return ONLY the final answer string, no explanations." - ) - elif is_temporal: - system_prompt = ( - "You are a QA assistant specializing in temporal reasoning. Analyze the dates and time relationships in the context carefully. " - "If the context contains the answer, return a concise answer phrase focusing on temporal information. " - "If the answer cannot be determined from the context, respond with 'Unknown'. Return ONLY the final answer string, no explanations." 
- ) - else: - system_prompt = ( - "You are a QA assistant. Respond in the same language as the question. If the context contains the answer, return a concise answer phrase. " - "If the answer cannot be determined from the context, respond with 'Unknown'. Return ONLY the final answer string, no explanations." - ) - - messages = [ - {"role": "system", "content": system_prompt}, - { - "role": "user", - "content": f"Question: {question}\n\nContext:\n{context_text}\n\nReturn ONLY the answer (or 'Unknown').", - }, - ] - - t2 = time.time() - # 使用异步调用 - resp = await llm_client.chat(messages=messages) - t3 = time.time() - latencies_llm.append((t3 - t2) * 1000) - - # 兼容不同的响应格式 - pred_raw = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown") - - # 选项题输出规范化 - pred = pred_raw - if len(options) >= 2 and not pred_raw.lower().startswith("unknown"): - def _basic_norm(s: str) -> str: - s = s.lower().strip() - return re.sub(r"[^\w\s]", " ", s) - def _jaccard(a: str, b: str) -> float: - ta = set(t for t in _basic_norm(a).split() if t) - tb = set(t for t in _basic_norm(b).split() if t) - if not ta and not tb: - return 1.0 - if not ta or not tb: - return 0.0 - return len(ta & tb) / len(ta | tb) - best = None - best_score = -1.0 - for o in options: - score = _jaccard(pred_raw, o) - if score > best_score: - best = o - best_score = score - if best is not None and best_score > 0.0: - pred = best - - # 指标 - flag = exact_match(pred, reference) - f1_val = common_f1(str(pred), str(reference)) - j_val = jaccard(str(pred), str(reference)) - - type_correct.setdefault(qtype, []).append(flag) - type_f1.setdefault(qtype, []).append(f1_val) - type_jacc.setdefault(qtype, []).append(j_val) - - samples.append({ - "question": question, - "prediction": pred, - "answer": reference, - "question_type": qtype, - "is_temporal": is_temporal, - "question_id": item.get("question_id"), - "options": options, - 
"context_count": len(contexts_all), - "context_chars": len(context_text), - "retrieved_dialogue_count": len(dialogs), - "retrieved_statement_count": len(statements), - "metrics": { - "exact_match": bool(flag), - "f1": f1_val, - "jaccard": j_val - }, - "timing": { - "search_ms": (t1 - t0) * 1000, - "llm_ms": (t3 - t2) * 1000 - } - }) - - print(f"🤖 LLM 回答: {pred}") - print(f"✅ 正确答案: {reference}") - print(f"📈 当前指标 - Exact Match: {flag}, F1: {f1_val:.3f}, Jaccard: {j_val:.3f}") - - # 聚合结果 - type_acc = {t: (sum(v) / max(len(v), 1)) for t, v in type_correct.items()} - f1_by_type = {t: (sum(v) / max(len(v), 1)) for t, v in type_f1.items()} - jacc_by_type = {t: (sum(v) / max(len(v), 1)) for t, v in type_jacc.items()} - - result = { - "dataset": "longmemeval", - "items": len(items), - "accuracy_by_type": type_acc, - "f1_by_type": f1_by_type, - "jaccard_by_type": jacc_by_type, - "samples": samples, - "latency": { - "search": latency_stats(latencies_search), - "llm": latency_stats(latencies_llm), - }, - "context": { - "avg_tokens": statistics.mean(per_query_context_avg_tokens) if per_query_context_avg_tokens else 0.0, - "avg_chars": statistics.mean(per_query_context_chars) if per_query_context_chars else 0.0, - "count_avg": statistics.mean(per_query_context_counts) if per_query_context_counts else 0.0, - }, - "params": { - "end_user_id": end_user_id, - "search_limit": search_limit, - "context_char_budget": context_char_budget, - "search_type": search_type, - "llm_id": os.getenv("EVAL_LLM_ID"), - "embedding_id": os.getenv("EVAL_EMBEDDING_ID"), - "sample_size": sample_size, - "start_index": start_index, - }, - "timestamp": datetime.now().isoformat() - } - - # 计算汇总指标 - try: - total_items = max(len(samples), 1) - correct_count = sum(1 for s in samples if s.get("metrics", {}).get("exact_match")) - score_accuracy = (correct_count / total_items) * 100.0 - - total_latencies_ms = [] - for s in samples: - t = s.get("timing", {}) - total_latencies_ms.append(float(t.get("search_ms", 
0.0)) + float(t.get("llm_ms", 0.0))) - total_lat_stats = latency_stats(total_latencies_ms) if total_latencies_ms else {"p50": 0.0, "iqr": 0.0} - latency_median_s = total_lat_stats.get("p50", 0.0) / 1000.0 - latency_iqr_s = total_lat_stats.get("iqr", 0.0) / 1000.0 - - avg_ctx_tokens = statistics.mean(per_query_context_avg_tokens) if per_query_context_avg_tokens else 0.0 - avg_ctx_tokens_k = avg_ctx_tokens / 1000.0 - - result["metric_summary"] = { - "score_accuracy": score_accuracy, - "latency_median_s": latency_median_s, - "latency_iqr_s": latency_iqr_s, - "avg_context_tokens_k": avg_ctx_tokens_k, - } - except Exception: - result["metric_summary"] = { - "score_accuracy": 0.0, - "latency_median_s": 0.0, - "latency_iqr_s": 0.0, - "avg_context_tokens_k": 0.0, - } - - # 诊断信息 - try: - dups = sorted([(k, c) for k, c in preview_counter.items() if c > 1], key=lambda x: -x[1])[:5] - result["diagnostics"] = { - "duplicate_previews_top": [{"count": c, "preview": k[:120]} for k, c in dups], - "unique_preview_count": len(preview_counter), - } - except Exception: - pass - - return result - - finally: - await connector.close() - - -def main(): - load_dotenv() - parser = argparse.ArgumentParser(description="LongMemEval 评估测试脚本(增强技术术语检索版)") - parser.add_argument("--sample-size", type=int, default=3, help="样本数量(<=0 表示全部)") - parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size)") - parser.add_argument("--start-index", type=int, default=0, help="起始样本索引") - parser.add_argument("--group-id", type=str, default="longmemeval_zh_bak_3", help="图数据库 Group ID") - parser.add_argument("--search-limit", type=int, default=8, help="检索条数上限") - parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算") - parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度") - parser.add_argument("--llm-max-tokens", type=int, default=16, help="LLM 最大输出 token") - parser.add_argument("--search-type", type=str, default="hybrid", 
choices=["embedding","keyword","hybrid"], help="检索类型") - parser.add_argument("--data-path", type=str, default=None, help="数据集路径") - args = parser.parse_args() - - sample_size = 0 if args.all else args.sample_size - - result = asyncio.run( - run_longmemeval_test( - sample_size=sample_size, - end_user_id=args.end_user_id, - search_limit=args.search_limit, - context_char_budget=args.context_char_budget, - llm_temperature=args.llm_temperature, - llm_max_tokens=args.llm_max_tokens, - search_type=args.search_type, - data_path=args.data_path, - start_index=args.start_index, - ) - ) - - # 打印结果 - print("\n" + "="*50) - print("📊 LongMemEval 测试结果:") - print(f" 样本数量: {result['items']}") - - if result['accuracy_by_type']: - print("\n📈 按问题类型细分:") - for qtype, acc in result['accuracy_by_type'].items(): - print(f" {qtype}:") - print(f" Score (Accuracy): {acc:.3f}") - - print(f"\n📊 指标总览:") - ms = result.get('metric_summary', {}) - print(f" Score (Accuracy): {ms.get('score_accuracy', 0.0):.1f}%") - print(f" Latency (s): median {ms.get('latency_median_s', 0.0):.3f}s") - print(f" Latency IQR (s): {ms.get('latency_iqr_s', 0.0):.3f}s") - print(f" Avg Context Tokens (k): {ms.get('avg_context_tokens_k', 0.0):.3f}k") - - print(f"\n⏱️ 细分性能指标:") - print(f" 检索延迟(均值): {result['latency']['search']['mean']:.1f}ms") - print(f" LLM延迟(均值): {result['latency']['llm']['mean']:.1f}ms") - print(f" 上下文长度(均值): {result['context']['avg_chars']:.0f} 字符") - - - # 保存结果到文件 - try: - out_dir = os.path.join(PROJECT_ROOT, "evaluation", "longmemeval", "results") - os.makedirs(out_dir, exist_ok=True) - ts = datetime.now().strftime("%Y%m%d_%H%M%S") - out_path = os.path.join(out_dir, f"longmemeval_{result['params']['search_type']}_{ts}.json") - with open(out_path, "w", encoding="utf-8") as f: - json.dump(result, f, ensure_ascii=False, indent=2) - print(f"\n💾 结果已保存: {out_path}") - except Exception as e: - print(f"⚠️ 结果保存失败: {e}") - - -if __name__ == "__main__": - main() diff --git 
a/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py b/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py deleted file mode 100644 index e07b0cab..00000000 --- a/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py +++ /dev/null @@ -1,559 +0,0 @@ -import argparse -import asyncio -import json -import os -import time -from datetime import datetime -from typing import List, Dict, Any -import re -from pathlib import Path - -from dotenv import load_dotenv - -# Load evaluation config -eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation" -if eval_config_path.exists(): - load_dotenv(eval_config_path, override=True) - print(f"✅ 加载评估配置: {eval_config_path}") - -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.memory.src.search import run_hybrid_search # 使用与 evaluate_qa.py 相同的检索函数 -from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.models.base import RedBearModelConfig -from app.core.memory.utils.config.config_utils import get_embedder_config - -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens - -from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard - - -def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str: - """基于问题关键词对上下文进行评分选择,并在预算内拼接文本。 - - 参考 evaluation/memsciqa/evaluate_qa.py 的实现,避免路径导入带来的不稳定。 - """ - if not contexts: - return "" - question_lower = (question or "").lower() - stop_words = { - 'what','when','where','who','why','how','did','do','does','is','are','was','were', - 'the','a','an','and','or','but' - } - question_words = set(re.findall(r"\b\w+\b", question_lower)) - question_words = {w for w in question_words if w not in stop_words and len(w) > 2} - - scored = [] - for i, ctx in enumerate(contexts): - ctx_lower = (ctx or "").lower() - score = 0 - matches = 0 - for w in 
question_words: - if w in ctx_lower: - matches += 1 - score += ctx_lower.count(w) * 2 - length = len(ctx) - if 100 < length < 2000: - score += 5 - elif length >= 2000: - score += 2 - if i < 3: - score += 3 - scored.append((score, ctx, matches)) - - scored.sort(key=lambda x: x[0], reverse=True) - - selected: List[str] = [] - total = 0 - for score, ctx, _ in scored: - if total + len(ctx) <= max_chars: - selected.append(ctx) - total += len(ctx) - else: - if score > 10 and total < max_chars - 200: - remaining = max_chars - total - lines = ctx.split('\n') - rel_lines: List[str] = [] - cur = 0 - for line in lines: - l = line.lower() - if any(w in l for w in question_words) and cur < remaining - 50: - rel_lines.append(line) - cur += len(line) - if rel_lines: - truncated = '\n'.join(rel_lines) - if len(truncated) > 50: - selected.append(truncated + "\n[相关内容截断...]") - total += len(truncated) - break - return "\n\n".join(selected) - - -def extract_question_keywords(question: str, max_keywords: int = 8) -> List[str]: - """提取问题中的关键词(简单英文分词,去停用词,长度>=3)。""" - ql = (question or "").lower() - stop_words = { - 'what','when','where','who','why','how','did','do','does','is','are','was','were', - 'the','a','an','and','or','but','of','to','in','on','for','with','from','that','this' - } - words = re.findall(r"\b[\w-]+\b", ql) - kws = [w for w in words if w not in stop_words and len(w) >= 3] - # 去重保序 - seen = set() - uniq = [] - for w in kws: - if w not in seen: - uniq.append(w) - seen.add(w) - if len(uniq) >= max_keywords: - break - return uniq - - -def analyze_contexts_simple(contexts: List[str], keywords: List[str], top_n: int = 5) -> List[Dict[str, int | float]]: - """对上下文进行简单相关性打分,仅用于控制台可视化。 - - 评分: score = match_count*200 + min(len(text), 100000)/100 - """ - results = [] - for ctx in contexts: - tl = (ctx or "").lower() - match_count = sum(1 for k in keywords if k in tl) - length = len(ctx) - score = match_count * 200 + min(length, 100000) / 100.0 - results.append({"score": 
float(f"{score:.0f}"), "match": match_count, "length": length}) - results.sort(key=lambda x: (x["score"], x["match"], x["length"]), reverse=True) - return results[:max(top_n, 0)] - - -# 纯测试脚本不进行摄入;若需摄入请使用 evaluate_qa.py - - -def load_dataset_memsciqa(data_path: str) -> List[Dict[str, Any]]: - if not os.path.exists(data_path): - raise FileNotFoundError(f"未找到数据集: {data_path}") - items: List[Dict[str, Any]] = [] - with open(data_path, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if not line: - continue - try: - items.append(json.loads(line)) - except Exception: - # 跳过坏行但不中断 - continue - return items - - -async def run_memsciqa_test( - sample_size: int = 3, - end_user_id: str | None = None, - search_limit: int = 8, - context_char_budget: int = 4000, - llm_temperature: float = 0.0, - llm_max_tokens: int = 64, - search_type: str = "embedding", - data_path: str | None = None, - start_index: int = 0, - verbose: bool = True, -) -> Dict[str, Any]: - """memsciqa 增强测试脚本:结合 evaluate_qa 的三路检索与智能上下文选择。 - - - 支持从指定索引开始与评估全部样本(sample_size<=0) - - 支持在摄入前重置组(清空图)与跳过摄入 - - 支持 keyword / embedding / hybrid 三种检索 - """ - - # 默认使用指定的 memsci 组 ID - end_user_id = end_user_id or "group_memsci" - - # 数据路径解析 - if not data_path: - dataset_dir = Path(__file__).resolve().parent.parent / "dataset" - data_path = str(dataset_dir / "msc_self_instruct.jsonl") - - if not os.path.exists(data_path): - raise FileNotFoundError( - f"数据集文件不存在: {data_path}\n" - f"请将 msc_self_instruct.jsonl 放置在: {dataset_dir}" - ) - - # 加载数据 - all_items = load_dataset_memsciqa(data_path) - if sample_size is None or sample_size <= 0: - items = all_items[start_index:] - else: - items = all_items[start_index:start_index + sample_size] - - # 初始化 LLM(纯测试:不进行摄入) - llm = get_llm_client(os.getenv("EVAL_LLM_ID")) - - # 初始化 Neo4j 连接与向量检索 Embedder(对齐 locomo_test) - connector = Neo4jConnector() - embedder = None - if search_type in ("embedding", "hybrid"): - cfg_dict = 
get_embedder_config(os.getenv("EVAL_EMBEDDING_ID")) - embedder = OpenAIEmbedderClient( - model_config=RedBearModelConfig.model_validate(cfg_dict) - ) - - # 评估循环 - latencies_llm: List[float] = [] - latencies_search: List[float] = [] - # 存储完整上下文文本用于统计 - contexts_used: List[str] = [] - per_query_context_chars: List[int] = [] - per_query_context_counts: List[int] = [] - correct_flags: List[float] = [] - f1s: List[float] = [] - b1s: List[float] = [] - jss: List[float] = [] - samples: List[Dict[str, Any]] = [] - - total_items = len(items) - for idx, item in enumerate(items): - if verbose: - print(f"\n🧪 评估样本: {idx+1}/{total_items}") - question = item.get("self_instruct", {}).get("B", "") or item.get("question", "") - reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "") - - # 检索:使用与 evaluate_qa.py 相同的 run_hybrid_search - t0 = time.time() - results = None - try: - if search_type in ("embedding", "hybrid"): - # 使用嵌入检索(与 qwen_search_eval 对齐) - results = await search_graph_by_embedding( - connector=connector, - embedder_client=embedder, - query_text=question, - end_user_id=end_user_id, - limit=search_limit, - include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues - ) - elif search_type == "keyword": - # 关键词检索(直接调用 graph_search) - results = await search_graph( - connector=connector, - q=question, - end_user_id=end_user_id, - limit=search_limit, - include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues - ) - except Exception: - results = None - t1 = time.time() - search_ms = (t1 - t0) * 1000 - latencies_search.append(search_ms) - - # 构建上下文:与 evaluate_qa.py 完全一致的逻辑 - contexts_all: List[str] = [] - retrieved_counts: Dict[str, int] = {} - if results: - # 处理 hybrid 搜索结果 - if search_type == "hybrid": - emb = results.get("embedding_search", {}) if isinstance(results.get("embedding_search"), dict) else {} - kw = results.get("keyword_search", {}) if isinstance(results.get("keyword_search"), dict) else 
{} - emb_dialogs = emb.get("dialogues", []) - emb_statements = emb.get("statements", []) - emb_entities = emb.get("entities", []) - kw_dialogs = kw.get("dialogues", []) - kw_statements = kw.get("statements", []) - kw_entities = kw.get("entities", []) - all_dialogs = emb_dialogs + kw_dialogs - all_statements = emb_statements + kw_statements - all_entities = emb_entities + kw_entities - - # 简单去重 - seen_dialog = set() - dialogues = [] - for d in all_dialogs: - key = (str(d.get("uuid", "")), str(d.get("content", ""))) - if key not in seen_dialog: - dialogues.append(d) - seen_dialog.add(key) - - seen_stmt = set() - statements = [] - for s in all_statements: - key = str(s.get("statement", "")) - if key not in seen_stmt: - statements.append(s) - seen_stmt.add(key) - - seen_ent = set() - entities = [] - for e in all_entities: - key = str(e.get("name", "")) - if key not in seen_ent: - entities.append(e) - seen_ent.add(key) - else: - # embedding 或 keyword 单独搜索 - dialogues = results.get("dialogues", []) - statements = results.get("statements", []) - entities = results.get("entities", []) - - retrieved_counts = { - "dialogues": len(dialogues), - "statements": len(statements), - "entities": len(entities), - } - - # 构建上下文文本 - for d in dialogues: - text = str(d.get("content", "")).strip() - if text: - contexts_all.append(text) - - for s in statements: - text = str(s.get("statement", "")).strip() - if text: - contexts_all.append(text) - - # 实体摘要 - if entities: - scored = [e for e in entities if e.get("score") is not None] - top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3] - if top_entities: - summary_lines = [] - for e in top_entities: - name = str(e.get("name", "")).strip() - etype = str(e.get("entity_type", "")).strip() - score = e.get("score") - if name: - meta = [] - if etype: - meta.append(f"type={etype}") - if isinstance(score, (int, float)): - meta.append(f"score={score:.3f}") - 
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}") - if summary_lines: - contexts_all.append("\n".join(summary_lines)) - - if verbose: - if retrieved_counts: - print(f"✅ 检索成功: {retrieved_counts.get('dialogues',0)} dialogues, {retrieved_counts.get('statements',0)} 条陈述, {retrieved_counts.get('entities',0)} 个实体, {retrieved_counts.get('summaries',0)} 个摘要") - print(f"📊 有效上下文数量: {len(contexts_all)}") - q_keywords = extract_question_keywords(question, max_keywords=8) - if q_keywords: - print(f"🔍 问题关键词: {set(q_keywords)}") - if contexts_all: - analysis = analyze_contexts_simple(contexts_all, q_keywords, top_n=5) - if analysis: - print("📊 上下文相关性分析:") - for a in analysis: - print(f" - 得分: {int(a['score'])}, 关键词匹配: {a['match']}, 长度: {a['length']}") - # 打印检索到的上下文预览,便于定位为何为 Unknown - print("🔎 上下文预览(最多前10条,每条截断展示):") - for i, ctx in enumerate(contexts_all[:10]): - preview = str(ctx).replace("\n", " ") - if len(preview) > 300: - preview = preview[:300] + "..." - print(f" [{i+1}] 长度: {len(ctx)} | 片段: {preview}") - # 标注参考答案是否出现在任一上下文中 - ref_lower = (str(reference) or "").lower() - if ref_lower: - hits = [] - for i, ctx in enumerate(contexts_all): - if ref_lower in str(ctx).lower(): - hits.append(i+1) - print(f"🔗 参考答案命中上下文条数: {len(hits)}" + (f" | 命中索引: {hits}" if hits else "")) - - context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else "" - if not context_text: - context_text = "No relevant context found." - contexts_used.append(context_text) - per_query_context_chars.append(len(context_text)) - per_query_context_counts.append(len(contexts_all)) - - if verbose: - selected_count = (context_text.count("\n\n") + 1) if context_text else 0 - print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {len(context_text)}字符") - # 展示拼接后的上下文片段,便于核查是否包含答案 - concat_preview = context_text.replace("\n", " ") - if len(concat_preview) > 600: - concat_preview = concat_preview[:600] + "..." 
- print(f"🧵 拼接上下文预览: {concat_preview}") - - messages = [ - { - "role": "system", - "content": ( - "You are a QA assistant. Answer in English. Follow these guidelines:\n" - "1) If the context contains information to answer the question, provide a concise answer based on the context;\n" - "2) If the context does not contain enough information to answer the question, respond with 'Unknown';\n" - "3) Keep your answer brief and to the point;\n" - "4) Do not add explanations or additional text beyond the answer." - ), - }, - {"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"}, - ] - - t2 = time.time() - try: - # 使用异步调用 - resp = await llm.chat(messages=messages) - # 更健壮的响应解析,处理不同的LLM响应格式 - if hasattr(resp, 'content'): - pred = resp.content.strip() - elif isinstance(resp, dict) and "choices" in resp and len(resp["choices"]) > 0: - pred = resp["choices"][0]["message"]["content"].strip() - elif isinstance(resp, dict) and "content" in resp: - pred = resp["content"].strip() - elif isinstance(resp, str): - pred = resp.strip() - else: - pred = "Unknown" - print(f"⚠️ LLM响应格式异常: {type(resp)} - {resp}") - - # 检查预测是否为"Unknown"或空,如果是则检查上下文是否真的没有答案 - if pred.lower() in ["unknown", ""]: - # 如果参考答案在上下文中存在,但LLM返回Unknown,可能是提示词问题 - ref_lower = (str(reference) or "").lower() - if ref_lower and any(ref_lower in ctx.lower() for ctx in contexts_all): - print("⚠️ 参考答案在上下文中存在但LLM返回Unknown,检查提示词") - except Exception as e: - # 更详细的错误处理 - pred = "Unknown" - print(f"⚠️ LLM调用异常: {e}") - t3 = time.time() - llm_ms = (t3 - t2) * 1000 - latencies_llm.append(llm_ms) - - exact = exact_match(pred, reference) - correct_flags.append(exact) - f1_val = f1_score(str(pred), str(reference)) - b1_val = bleu1(str(pred), str(reference)) - j_val = jaccard(str(pred), str(reference)) - f1s.append(f1_val) - b1s.append(b1_val) - jss.append(j_val) - - if verbose: - print(f"🤖 LLM 回答: {pred}") - print(f"✅ 正确答案: {reference}") - print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: 
{j_val:.3f}") - print(f"⏱️ 延迟 - 检索: {search_ms:.0f}ms, LLM: {llm_ms:.0f}ms") - - # 对齐 locomo/qwen_search_eval.py 的样本输出结构 - samples.append({ - "question": str(question), - "answer": str(reference), - "prediction": str(pred), - "metrics": { - "f1": f1_val, - "b1": b1_val, - "j": j_val - }, - "retrieval": { - "retrieved_documents": len(contexts_all), - "context_length": len(context_text), - "search_limit": search_limit, - "max_chars": context_char_budget - }, - "timing": { - "search_ms": search_ms, - "llm_ms": llm_ms - } - }) - - # 计算总体指标与聚合 - acc = sum(correct_flags) / max(len(correct_flags), 1) - ctx_avg_tokens = avg_context_tokens(contexts_used) - result = { - "dataset": "memsciqa", - "items": len(items), - "metrics": { - "f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0, - "b1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0, - "j": (sum(jss) / max(len(jss), 1)) if jss else 0.0, - }, - "context": { - "avg_tokens": ctx_avg_tokens, - "avg_chars": (sum(per_query_context_chars) / max(len(per_query_context_chars), 1)) if per_query_context_chars else 0.0, - "count_avg": (sum(per_query_context_counts) / max(len(per_query_context_counts), 1)) if per_query_context_counts else 0.0, - "avg_memory_tokens": 0.0 - }, - "latency": { - "search": latency_stats(latencies_search), - "llm": latency_stats(latencies_llm), - }, - "samples": samples, - "params": { - "end_user_id": end_user_id, - "search_limit": search_limit, - "context_char_budget": context_char_budget, - "llm_temperature": llm_temperature, - "llm_max_tokens": llm_max_tokens, - "search_type": search_type, - "start_index": start_index, - "llm_id": os.getenv("EVAL_LLM_ID"), - "retrieval_embedding_id": os.getenv("EVAL_EMBEDDING_ID") - }, - "timestamp": datetime.now().isoformat(), - } - try: - await connector.close() - except Exception: - pass - return result - - -def main(): - load_dotenv() - parser = argparse.ArgumentParser(description="memsciqa 测试脚本(三路检索 + 智能上下文选择)") - parser.add_argument("--sample-size", type=int, 
default=10, help="样本数量(<=0 表示全部)") - parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size)") - parser.add_argument("--start-index", type=int, default=0, help="起始样本索引") - parser.add_argument("--group-id", type=str, default="group_memsci", help="图数据库 Group ID(默认 group_memsci)") - parser.add_argument("--search-limit", type=int, default=8, help="检索条数上限") - parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算") - parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度") - parser.add_argument("--llm-max-tokens", type=int, default=64, help="LLM 最大输出 token") - parser.add_argument("--search-type", type=str, default="embedding", choices=["embedding","keyword","hybrid"], help="检索类型(hybrid 等同于 embedding)") - parser.add_argument("--data-path", type=str, default=None, help="数据集路径(默认 data/msc_self_instruct.jsonl)") - parser.add_argument("--output", type=str, default=None, help="将评估结果保存到指定文件路径(JSON)") - parser.add_argument("--verbose", action="store_true", default=True, help="打印过程日志(默认开启)") - parser.add_argument("--quiet", action="store_true", help="关闭过程日志") - args = parser.parse_args() - - sample_size = 0 if args.all else args.sample_size - - verbose_flag = False if args.quiet else args.verbose - result = asyncio.run( - run_memsciqa_test( - sample_size=sample_size, - end_user_id=args.end_user_id, - search_limit=args.search_limit, - context_char_budget=args.context_char_budget, - llm_temperature=args.llm_temperature, - llm_max_tokens=args.llm_max_tokens, - search_type=args.search_type, - data_path=args.data_path, - start_index=args.start_index, - verbose=verbose_flag, - ) - ) - - print(json.dumps(result, ensure_ascii=False, indent=2)) - - # 结果保存 - out_path = args.output - if not out_path: - eval_dir = os.path.dirname(os.path.abspath(__file__)) - dataset_results_dir = os.path.join(eval_dir, "results") - ts = datetime.now().strftime("%Y%m%d_%H%M%S") - out_path = os.path.join(dataset_results_dir, 
f"memsciqa_{result['params']['search_type']}_{ts}.json") - try: - os.makedirs(os.path.dirname(out_path), exist_ok=True) - with open(out_path, "w", encoding="utf-8") as f: - json.dump(result, f, ensure_ascii=False, indent=2) - print(f"\n💾 结果已保存: {out_path}") - except Exception as e: - print(f"⚠️ 结果保存失败: {e}") - - -if __name__ == "__main__": - main() diff --git a/api/app/core/memory/evaluation/memsciqa/memsciqa_benchmark.py b/api/app/core/memory/evaluation/memsciqa/memsciqa_benchmark.py deleted file mode 100644 index 40684f4c..00000000 --- a/api/app/core/memory/evaluation/memsciqa/memsciqa_benchmark.py +++ /dev/null @@ -1,369 +0,0 @@ -import argparse -import asyncio -import json -import os -import time -from datetime import datetime -from typing import List, Dict, Any -from pathlib import Path -from dotenv import load_dotenv - -# Load evaluation config -eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation" -if eval_config_path.exists(): - load_dotenv(eval_config_path, override=True) - -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.memory.src.search import run_hybrid_search # 使用旧版本(重构前) -from app.core.memory.utils.llm.llm_utils import get_llm_client -from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline -from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens - - -def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str: - """基于问题关键词对上下文进行评分选择,并在预算内拼接文本。""" - if not contexts: - return "" - import re - # 提取问题关键词(移除停用词) - question_lower = (question or "").lower() - stop_words = { - 'what','when','where','who','why','how','did','do','does','is','are','was','were', - 'the','a','an','and','or','but' - } - question_words = set(re.findall(r"\b\w+\b", question_lower)) - question_words = {w for w in question_words if w not in stop_words and len(w) > 2} - - # 评分 - scored = [] - for i, ctx in 
enumerate(contexts): - ctx_lower = (ctx or "").lower() - score = 0 - matches = 0 - for w in question_words: - if w in ctx_lower: - matches += 1 - score += ctx_lower.count(w) * 2 - length = len(ctx) - if 100 < length < 2000: - score += 5 - elif length >= 2000: - score += 2 - if i < 3: - score += 3 - scored.append((score, ctx, matches)) - - scored.sort(key=lambda x: x[0], reverse=True) - - # 选择直到达到字符限制,必要时截断包含关键词的段落 - selected: List[str] = [] - total = 0 - for score, ctx, _ in scored: - if total + len(ctx) <= max_chars: - selected.append(ctx) - total += len(ctx) - else: - if score > 10 and total < max_chars - 200: - remaining = max_chars - total - lines = ctx.split('\n') - rel_lines: List[str] = [] - cur = 0 - for line in lines: - l = line.lower() - if any(w in l for w in question_words) and cur < remaining - 50: - rel_lines.append(line) - cur += len(line) - if rel_lines: - truncated = '\n'.join(rel_lines) - if len(truncated) > 50: - selected.append(truncated + "\n[相关内容截断...]") - total += len(truncated) - break - return "\n\n".join(selected) - - -def build_context_from_dialog(dialog_obj: Dict[str, Any]) -> str: - """Compose a text context from `dialog` list in msc_self_instruct item.""" - parts: List[str] = [] - for turn in dialog_obj.get("dialog", []): - speaker = turn.get("speaker", "") - text = turn.get("text", "") - if text: - parts.append(f"{speaker}: {text}") - return "\n".join(parts) - - -def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any]]: - """Combine dialogues from embedding and keyword searches (embedding first).""" - if results is None: - return [] - emb = [] - kw = [] - if isinstance(results.get("embedding_search"), dict): - emb = results.get("embedding_search", {}).get("dialogues", []) or [] - elif isinstance(results.get("dialogues"), list): - emb = results.get("dialogues", []) or [] - if isinstance(results.get("keyword_search"), dict): - kw = results.get("keyword_search", {}).get("dialogues", []) or [] - seen = set() - 
merged: List[Dict[str, Any]] = [] - for d in emb: - k = (str(d.get("uuid", "")), str(d.get("content", ""))) - if k not in seen: - merged.append(d) - seen.add(k) - for d in kw: - k = (str(d.get("uuid", "")), str(d.get("content", ""))) - if k not in seen: - merged.append(d) - seen.add(k) - return merged - - - -async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]: - end_user_id = end_user_id or SELECTED_GROUP_ID - - # Load data - dataset_dir = Path(__file__).resolve().parent.parent / "dataset" - data_path = dataset_dir / "msc_self_instruct.jsonl" - - if not os.path.exists(data_path): - raise FileNotFoundError( - f"数据集文件不存在: {data_path}\n" - f"请将 msc_self_instruct.jsonl 放置在: {dataset_dir}" - ) - with open(data_path, "r", encoding="utf-8") as f: - lines = f.readlines() - items: List[Dict[str, Any]] = [json.loads(l) for l in lines[:sample_size]] - - - # 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入 - # 说明:memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略 - contexts: List[str] = [build_context_from_dialog(item) for item in items] - await ingest_contexts_via_full_pipeline(contexts, end_user_id) - - # LLM client (使用异步调用) - from app.db import get_db - - db = next(get_db()) - try: - llm_client = get_llm_client(os.getenv("EVAL_LLM_ID"), db) - finally: - db.close() - - # Evaluate each item - connector = Neo4jConnector() - latencies_llm: List[float] = [] - latencies_search: List[float] = [] - contexts_used: List[str] = [] - correct_flags: List[float] = [] - f1s: List[float] = [] - b1s: List[float] = [] - jss: List[float] = [] - try: - for item in items: - question = item.get("self_instruct", {}).get("B", "") or item.get("question", "") - reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "") - # 检索:对齐 locomo 的三路检索(dialogues/statements/entities) - t0 = 
time.time() - try: - results = await run_hybrid_search( - query_text=question, - search_type=search_type, - end_user_id=end_user_id, - limit=search_limit, - include=["dialogues", "statements", "entities"], - output_path=None, - ) - except Exception: - results = None - t1 = time.time() - latencies_search.append((t1 - t0) * 1000) - - # 构建上下文:包含对话、陈述和实体摘要,并智能选择 - contexts_all: List[str] = [] - if results: - if search_type == "hybrid": - emb = results.get("embedding_search", {}) if isinstance(results.get("embedding_search"), dict) else {} - kw = results.get("keyword_search", {}) if isinstance(results.get("keyword_search"), dict) else {} - emb_dialogs = emb.get("dialogues", []) - emb_statements = emb.get("statements", []) - emb_entities = emb.get("entities", []) - kw_dialogs = kw.get("dialogues", []) - kw_statements = kw.get("statements", []) - kw_entities = kw.get("entities", []) - all_dialogs = emb_dialogs + kw_dialogs - all_statements = emb_statements + kw_statements - all_entities = emb_entities + kw_entities - - # 简单去重与限制 - seen_texts = set() - for d in all_dialogs: - text = str(d.get("content", "")).strip() - if text and text not in seen_texts: - contexts_all.append(text) - seen_texts.add(text) - if len(contexts_all) >= search_limit: - break - for s in all_statements: - text = str(s.get("statement", "")).strip() - if text and text not in seen_texts: - contexts_all.append(text) - seen_texts.add(text) - if len(contexts_all) >= search_limit: - break - # 实体摘要(最多3个) - names = [] - merged_entities = all_entities[:] - for e in merged_entities: - name = str(e.get("name", "")).strip() - if name and name not in names: - names.append(name) - if len(names) >= 3: - break - if names: - contexts_all.append("EntitySummary: " + ", ".join(names)) - else: - dialogs = results.get("dialogues", []) - statements = results.get("statements", []) - entities = results.get("entities", []) - for d in dialogs: - text = str(d.get("content", "")).strip() - if text: - contexts_all.append(text) - 
for s in statements: - text = str(s.get("statement", "")).strip() - if text: - contexts_all.append(text) - names = [str(e.get("name", "")).strip() for e in entities[:3] if e.get("name")] - if names: - contexts_all.append("EntitySummary: " + ", ".join(names)) - - # 智能选择并截断到预算 - context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else "" - if not context_text: - context_text = "No relevant context found." - contexts_used.append(context_text[:200]) - - # Call LLM (使用异步调用) - messages = [ - {"role": "system", "content": "You are a QA assistant. Answer in English. Strictly follow: 1) If the context contains the answer, copy the shortest exact span from the context as the answer; 2) If the answer cannot be determined from the context, respond with 'Unknown'; 3) Return ONLY the answer text, no explanations."}, - {"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"}, - ] - t2 = time.time() - resp = await llm_client.chat(messages=messages) - t3 = time.time() - latencies_llm.append((t3 - t2) * 1000) - pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else str(resp).strip()) - # Metrics: F1, BLEU-1, Jaccard; keep exact match for reference - correct_flags.append(exact_match(pred, reference)) - from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard - f1s.append(f1_score(str(pred), str(reference))) - b1s.append(bleu1(str(pred), str(reference))) - jss.append(jaccard(str(pred), str(reference))) - - # Aggregate metrics - acc = sum(correct_flags) / max(len(correct_flags), 1) - ctx_avg_tokens = avg_context_tokens(contexts_used) - result = { - "dataset": "memsciqa", - "items": len(items), - "metrics": { - "accuracy": acc, - # Placeholders for extensibility - "f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0, - "bleu1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0, - "jaccard": (sum(jss) / 
max(len(jss), 1)) if jss else 0.0, - }, - "latency": { - "search": latency_stats(latencies_search), - "llm": latency_stats(latencies_llm), - }, - "avg_context_tokens": ctx_avg_tokens, - } - return result - finally: - await connector.close() - - -def main(): - # Load environment variables first - load_dotenv() - - # Get defaults from environment variables - env_sample_size = os.getenv("MEMSCIQA_SAMPLE_SIZE") - env_search_limit = os.getenv("MEMSCIQA_SEARCH_LIMIT") - env_context_budget = os.getenv("MEMSCIQA_CONTEXT_CHAR_BUDGET") - env_llm_max_tokens = os.getenv("MEMSCIQA_LLM_MAX_TOKENS") - env_skip_ingest = os.getenv("MEMSCIQA_SKIP_INGEST", "false").lower() in ("true", "1", "yes") - env_output_dir = os.getenv("MEMSCIQA_OUTPUT_DIR") - - # Convert to appropriate types with fallback to code defaults - default_sample_size = int(env_sample_size) if env_sample_size else 1 - default_search_limit = int(env_search_limit) if env_search_limit else 8 - default_context_budget = int(env_context_budget) if env_context_budget else 4000 - default_llm_max_tokens = int(env_llm_max_tokens) if env_llm_max_tokens else 64 - default_output_dir = env_output_dir if env_output_dir else None - - parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen") - - parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量") - parser.add_argument("--end-user-id", type=str, default=None, help="可选 end_user_id,默认使用环境变量") - parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数") - parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算") - - parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度") - parser.add_argument("--llm-max-tokens", type=int, default=default_llm_max_tokens, - help=f"LLM 最大生成长度 (env: MEMSCIQA_LLM_MAX_TOKENS={env_llm_max_tokens or 'not set'})") - parser.add_argument("--search-type", type=str, choices=["keyword","embedding","hybrid"], default="hybrid", help="检索类型") 
- parser.add_argument("--skip-ingest", action="store_true", default=env_skip_ingest, - help=f"跳过数据摄入,使用 Neo4j 中的现有数据 (env: MEMSCIQA_SKIP_INGEST={os.getenv('MEMSCIQA_SKIP_INGEST', 'false')})") - parser.add_argument("--output-dir", type=str, default=default_output_dir, - help=f"结果保存目录 (env: MEMSCIQA_OUTPUT_DIR={env_output_dir or 'not set'})") - args = parser.parse_args() - - result = asyncio.run( - run_memsciqa_eval( - sample_size=args.sample_size, - end_user_id=args.end_user_id, - search_limit=args.search_limit, - context_char_budget=args.context_char_budget, - llm_temperature=args.llm_temperature, - llm_max_tokens=args.llm_max_tokens, - search_type=args.search_type, - skip_ingest=args.skip_ingest, - ) - ) - - # Print results to console - print(json.dumps(result, ensure_ascii=False, indent=2)) - - # Save results to file - output_dir = args.output_dir - if output_dir is None: - # Use absolute path to ensure results are saved in the correct location - script_dir = Path(__file__).resolve().parent - output_dir = script_dir / "results" - elif not Path(output_dir).is_absolute(): - # If relative path, make it relative to this script's directory - script_dir = Path(__file__).resolve().parent - output_dir = script_dir / output_dir - else: - output_dir = Path(output_dir) - - output_dir.mkdir(parents=True, exist_ok=True) - - timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S") - output_path = output_dir / f"memsciqa_{timestamp_str}.json" - - try: - with open(output_path, "w", encoding="utf-8") as f: - json.dump(result, f, ensure_ascii=False, indent=2) - print(f"\n✅ 结果已保存到: {output_path}") - except Exception as e: - print(f"\n❌ 保存结果失败: {e}") - - -if __name__ == "__main__": - main() diff --git a/api/app/core/memory/evaluation/run_eval.py b/api/app/core/memory/evaluation/run_eval.py deleted file mode 100644 index 56b2e790..00000000 --- a/api/app/core/memory/evaluation/run_eval.py +++ /dev/null @@ -1,147 +0,0 @@ -import argparse -import asyncio -import json -import os -from 
typing import Any, Dict -from pathlib import Path -from dotenv import load_dotenv - -# Load evaluation config -eval_config_path = Path(__file__).resolve().parent / ".env.evaluation" -if eval_config_path.exists(): - load_dotenv(eval_config_path, override=True) - -from app.repositories.neo4j.neo4j_connector import Neo4jConnector - -from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval -from app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test -from app.core.memory.evaluation.locomo.qwen_search_eval import run_locomo_eval - - -async def run( - dataset: str, - sample_size: int, - reset_group: bool, - end_user_id: str | None, - judge_model: str | None = None, - search_limit: int | None = None, - context_char_budget: int | None = None, - llm_temperature: float | None = None, - llm_max_tokens: int | None = None, - search_type: str | None = None, - start_index: int | None = None, - max_contexts_per_item: int | None = None, -) -> Dict[str, Any]: - # Use environment variable with fallback chain if not provided - if end_user_id is None: - end_user_id = os.getenv("EVAL_END_USER_ID", "benchmark_default") - - if reset_group: - connector = Neo4jConnector() - try: - await connector.delete_group(end_user_id) - finally: - await connector.close() - - if dataset == "locomo": - kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id} - if search_limit is not None: - kwargs["search_limit"] = search_limit - if context_char_budget is not None: - kwargs["context_char_budget"] = context_char_budget - if llm_temperature is not None: - kwargs["llm_temperature"] = llm_temperature - if llm_max_tokens is not None: - kwargs["llm_max_tokens"] = llm_max_tokens - if search_type is not None: - kwargs["search_type"] = search_type - return await run_locomo_eval(**kwargs) - - if dataset == "memsciqa": - kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id} - if search_limit is not None: - 
kwargs["search_limit"] = search_limit - if context_char_budget is not None: - kwargs["context_char_budget"] = context_char_budget - if llm_temperature is not None: - kwargs["llm_temperature"] = llm_temperature - if llm_max_tokens is not None: - kwargs["llm_max_tokens"] = llm_max_tokens - if search_type is not None: - kwargs["search_type"] = search_type - return await run_memsciqa_eval(**kwargs) - - if dataset == "longmemeval": - kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id} - if search_limit is not None: - kwargs["search_limit"] = search_limit - if context_char_budget is not None: - kwargs["context_char_budget"] = context_char_budget - if llm_temperature is not None: - kwargs["llm_temperature"] = llm_temperature - if llm_max_tokens is not None: - kwargs["llm_max_tokens"] = llm_max_tokens - if search_type is not None: - kwargs["search_type"] = search_type - if start_index is not None: - kwargs["start_index"] = start_index - if max_contexts_per_item is not None: - kwargs["max_contexts_per_item"] = max_contexts_per_item - return await run_longmemeval_test(**kwargs) - raise ValueError(f"未知数据集: {dataset}") - - -def main(): - load_dotenv() - parser = argparse.ArgumentParser(description="统一评估入口:memsciqa / longmemeval / locomo") - parser.add_argument("--dataset", choices=["memsciqa", "longmemeval", "locomo"], required=True) - parser.add_argument("--sample-size", type=int, default=1, help="先用一条数据跑通") - parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 end_user_id 的图数据") - parser.add_argument("--group-id", type=str, default=None, help="可选 end_user_id,默认取 runtime.json") - parser.add_argument("--judge-model", type=str, default=None, help="可选:longmemeval 判别式评测模型名") - parser.add_argument("--search-limit", type=int, default=None, help="检索返回的对话节点数量上限(不提供则使用各脚本默认)") - parser.add_argument("--context-char-budget", type=int, default=None, help="上下文字符预算(不提供则使用各脚本默认)") - parser.add_argument("--llm-temperature", type=float, 
default=None, help="生成温度(不提供则使用各脚本默认)") - parser.add_argument("--llm-max-tokens", type=int, default=None, help="最大生成 tokens(不提供则使用各脚本默认)") - parser.add_argument("--search-type", type=str, default=None, choices=["keyword", "embedding", "hybrid"], help="检索类型(不提供则使用各脚本默认)") - # 仅透传到 longmemeval;其他数据集忽略 - parser.add_argument("--start-index", type=int, default=None, help="仅 longmemeval:起始样本索引(不提供则用脚本默认)") - parser.add_argument("--max-contexts-per-item", type=int, default=None, help="仅 longmemeval:每条样本摄入的上下文数量上限(不提供则用脚本默认)") - parser.add_argument("--output", type=str, default=None, help="可选:将评估结果保存到指定文件路径(JSON);不提供时默认保存到 evaluation//results 目录") - args = parser.parse_args() - - result = asyncio.run(run( - args.dataset, - args.sample_size, - args.reset_group, - args.end_user_id, - args.judge_model, - args.search_limit, - args.context_char_budget, - args.llm_temperature, - args.llm_max_tokens, - args.search_type, - args.start_index, - args.max_contexts_per_item, - )) - print(json.dumps(result, ensure_ascii=False, indent=2)) - - # 结果输出逻辑保持不变 - if args.output: - out_path = args.output - else: - eval_dir = os.path.dirname(os.path.abspath(__file__)) - dataset_results_dir = os.path.join(eval_dir, args.dataset, "results") - out_filename = f"{args.dataset}_{args.sample_size}.json" - out_path = os.path.join(dataset_results_dir, out_filename) - - out_dir = os.path.dirname(out_path) - if out_dir and not os.path.exists(out_dir): - os.makedirs(out_dir, exist_ok=True) - with open(out_path, "w", encoding="utf-8") as f: - json.dump(result, f, ensure_ascii=False, indent=2) - print(f"\n结果已保存到: {out_path}") - - -if __name__ == "__main__": - main() diff --git a/redbear-mem-benchmark b/redbear-mem-benchmark index d9a00be6..558c023d 160000 --- a/redbear-mem-benchmark +++ b/redbear-mem-benchmark @@ -1 +1 @@ -Subproject commit d9a00be62d974c0ad071c27e86f878b921c675b6 +Subproject commit 558c023dadb5327a05561b22d8fb363c6ee2be29