From 83fe793e72f42b2b37ab13d05a35d7b2a28d294e Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Tue, 20 Jan 2026 15:03:29 +0800 Subject: [PATCH] refactor(memory): clean up deprecated config and self-reflexion utilities - Remove deprecated self_reflexion endpoint from memory_storage_controller - Delete obsolete config modules (config_optimization, definitions, get_example_data, litellm_config) - Remove self_reflexion_utils package and related evaluation/reflexion modules - Refactor hot_memory_tags to use Neo4jConnector instead of direct GraphDatabase connection - Simplify LLM client initialization by removing DEFAULT_LLM_ID fallback logic - Remove unnecessary sys.path manipulation and project root resolution code - Update filter_tags_with_llm to properly handle missing config with clear error messages - Migrate get_raw_tags_from_db to async function using Neo4jConnector - Consolidate imports and remove unused dependencies (uuid, sys) - Improve error handling with explicit ValueError messages for missing configuration --- .../controllers/memory_storage_controller.py | 17 - .../core/memory/analytics/hot_memory_tags.py | 239 +++----- api/app/core/memory/src/search.py | 323 ++--------- .../reflection_engine/self_reflexion.py | 2 - api/app/core/memory/utils/config/__init__.py | 33 -- .../utils/config/config_optimization.py | 398 -------------- .../core/memory/utils/config/definitions.py | 268 --------- .../memory/utils/config/get_example_data.py | 90 --- .../memory/utils/config/litellm_config.py | 516 ------------------ .../utils/self_reflexion_utils/__init__.py | 16 - .../utils/self_reflexion_utils/evaluate.py | 52 -- .../utils/self_reflexion_utils/reflexion.py | 54 -- .../self_reflexion_utils/self_reflexion.py | 254 --------- api/app/repositories/neo4j/graph_search.py | 65 ++- 14 files changed, 190 insertions(+), 2137 deletions(-) delete mode 100644 api/app/core/memory/utils/config/config_optimization.py delete mode 100644 api/app/core/memory/utils/config/definitions.py delete mode 100644 api/app/core/memory/utils/config/get_example_data.py delete mode 100644 api/app/core/memory/utils/config/litellm_config.py delete mode 100644 api/app/core/memory/utils/self_reflexion_utils/__init__.py delete mode 100644 api/app/core/memory/utils/self_reflexion_utils/evaluate.py delete mode 100644 api/app/core/memory/utils/self_reflexion_utils/reflexion.py delete mode 100644 api/app/core/memory/utils/self_reflexion_utils/self_reflexion.py diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index c58ecd6d..63d9078a 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -1,10 +1,8 @@ import os -import uuid from typing import Optional from app.core.error_codes import BizCode from app.core.logging_config import get_api_logger -from app.core.memory.utils.self_reflexion_utils import self_reflexion from app.core.response_utils import fail, success from app.db import get_db from app.dependencies import get_current_user @@ -458,18 +456,3 @@ async def get_recent_activity_stats_api( api_logger.error(f"Recent activity stats failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e)) - - - -@router.get("/self_reflexion") -async def self_reflexion_endpoint(host_id: uuid.UUID) -> str: - """ - 自我反思接口,自动对检索出的信息进行自我反思并返回自我反思结果。 - - Args: - None - Returns: - 自我反思结果。 - """ - return await self_reflexion(host_id) - diff --git a/api/app/core/memory/analytics/hot_memory_tags.py b/api/app/core/memory/analytics/hot_memory_tags.py index 2aa286ba..cab6cacd 100644 --- a/api/app/core/memory/analytics/hot_memory_tags.py +++ b/api/app/core/memory/analytics/hot_memory_tags.py @@ -1,48 +1,15 @@ import asyncio -import os -import sys -from typing import List, Tuple - -from neo4j import GraphDatabase -from pydantic import BaseModel, Field - -# ------------------- 自包含路径解析 ------------------- -# 这个代码块确保脚本可以从任何地方运行,并且仍然可以在项目结构中找到它需要的模块。 -try: - # 假设脚本在 /path/to/project/src/analytics/ - # 上升3个级别以到达项目根目录。 - project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) - src_path = os.path.join(project_root, 'src') - - # 将 'src' 和 'project_root' 都添加到路径中。 - # 'src' 目录对于像 'from utils.config_utils import ...' 这样的导入是必需的。 - # 'project_root' 目录对于像 'from variate_config import ...' 这样的导入是必需的。 - if src_path not in sys.path: - sys.path.insert(0, src_path) - if project_root not in sys.path: - sys.path.insert(0, project_root) -except NameError: - # 为 __file__ 未定义的环境(例如某些交互式解释器)提供回退方案 - project_root = os.path.abspath(os.path.join(os.getcwd())) - src_path = os.path.join(project_root, 'src') - if src_path not in sys.path: - sys.path.insert(0, src_path) - if project_root not in sys.path: - sys.path.insert(0, project_root) -# --------------------------------------------------------------------- - -# 现在路径已经配置好,我们可以使用绝对导入 import json +import os +from typing import List, Tuple from app.core.config import settings from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context +from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.services.memory_config_service import MemoryConfigService +from pydantic import BaseModel, Field -#TODO: Fix this -# Default values (previously from definitions.py) -DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus") -DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123") # 定义用于LLM结构化输出的Pydantic模型 class FilteredTags(BaseModel): @@ -52,34 +19,45 @@ class FilteredTags(BaseModel): async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]: """ 使用LLM筛选标签列表,仅保留具有代表性的核心名词。 + + Args: + tags: 原始标签列表 + group_id: 用户组ID,用于获取配置 + + Returns: + 筛选后的标签列表 + + Raises: + ValueError: 如果无法获取有效的LLM配置 """ try: # Get config_id using get_end_user_connected_config with get_db_context() as db: - try: - from app.services.memory_agent_service import ( - get_end_user_connected_config, + from app.services.memory_agent_service import ( + get_end_user_connected_config, + ) + + connected_config = get_end_user_connected_config(group_id, db) + config_id = connected_config.get("memory_config_id") + + if not config_id: + raise ValueError( + f"No memory_config_id found for group_id: {group_id}. " + "Please ensure the user has a valid memory configuration." ) - connected_config = get_end_user_connected_config(group_id, db) - config_id = connected_config.get("memory_config_id") - - if config_id: - # Use the config_id to get the proper LLM client - config_service = MemoryConfigService(db) - memory_config = config_service.load_memory_config(config_id) - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client(memory_config.llm_model_id) - else: - # TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config - # Fallback to default LLM if no config found - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client(DEFAULT_LLM_ID) - except Exception as e: - print(f"Failed to get user connected config, using default LLM: {e}") - # TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config - # Fallback to default LLM - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client(DEFAULT_LLM_ID) + + # Use the config_id to get the proper LLM client + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config(config_id) + + if not memory_config.llm_model_id: + raise ValueError( + f"No llm_model_id found in memory config {config_id}. " + "Please configure a valid LLM model." + ) + + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(memory_config.llm_model_id) # 3. 构建Prompt tag_list_str = ", ".join(tags) @@ -107,33 +85,26 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]: # 在LLM失败时返回原始标签,确保流程继续 return tags -def get_db_connection(): - """ - 使用项目的标准配置方法建立与Neo4j数据库的连接。 - """ - # 从全局配置获取 Neo4j 连接信息 - uri = settings.NEO4J_URI - user = settings.NEO4J_USERNAME - - # 密码必须为了安全从环境变量加载 - password = os.getenv("NEO4J_PASSWORD") - - if not uri or not user: - raise ValueError("在 config.json 中未找到 Neo4j 的 'uri' 或 'username'。") - if not password: - raise ValueError("NEO4J_PASSWORD 环境变量未设置。") - - # 为此脚本使用同步驱动 - return GraphDatabase.driver(uri, auth=(user, password)) - -def get_raw_tags_from_db(group_id: str, limit: int, by_user: bool = False) -> List[Tuple[str, int]]: +async def get_raw_tags_from_db( + connector: Neo4jConnector, + group_id: str, + limit: int, + by_user: bool = False +) -> List[Tuple[str, int]]: """ + TODO: not accurate tag extraction 从数据库查询原始的、未经过滤的实体标签及其频率。 + + 使用项目的Neo4jConnector进行查询,遵循仓储模式。 Args: + connector: Neo4j连接器实例 group_id: 如果by_user=False,则为group_id;如果by_user=True,则为user_id limit: 返回的标签数量限制 by_user: 是否按user_id查询(默认False,按group_id查询) + + Returns: + List[Tuple[str, int]]: 标签名称和频率的元组列表 """ names_to_exclude = ['AI', 'Caroline', 'Melanie', 'Jon', 'Gina', '用户', 'AI助手', 'John', 'Maria'] @@ -154,83 +125,55 @@ def get_raw_tags_from_db(group_id: str, limit: int, by_user: bool = False) -> Li "LIMIT $limit" ) - driver = None - try: - driver = get_db_connection() - with driver.session() as session: - result = session.run(query, id=group_id, limit=limit, names_to_exclude=names_to_exclude) - return [(record["name"], record["frequency"]) for record in result] - finally: - if driver: - driver.close() + # 使用项目的Neo4jConnector执行查询 + results = await connector.execute_query( + query, + id=group_id, + limit=limit, + names_to_exclude=names_to_exclude + ) + + return [(record["name"], record["frequency"]) for record in results] -async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]: +async def get_hot_memory_tags(group_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]: """ 获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。 查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。 Args: - group_id: 如果by_user=False,则为group_id;如果by_user=True,则为user_id + group_id: 必需参数。如果by_user=False,则为group_id;如果by_user=True,则为user_id limit: 返回的标签数量限制 by_user: 是否按user_id查询(默认False,按group_id查询) + + Raises: + ValueError: 如果group_id未提供或为空 """ - # 默认从环境变量读取 - group_id = group_id or DEFAULT_GROUP_ID - # 1. 从数据库获取原始排名靠前的标签 - raw_tags_with_freq = get_raw_tags_from_db(group_id, limit, by_user=by_user) - if not raw_tags_with_freq: - return [] - - raw_tag_names = [tag for tag, freq in raw_tags_with_freq] - - # 2. 初始化LLM客户端并使用LLM筛选出有意义的标签 - meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id) - - # 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序) - final_tags = [] - for tag, freq in raw_tags_with_freq: - if tag in meaningful_tag_names: - final_tags.append((tag, freq)) - - return final_tags - -if __name__ == "__main__": - print("开始获取热门记忆标签...") + # 验证group_id必须提供且不为空 + if not group_id or not group_id.strip(): + raise ValueError( + "group_id is required. Please provide a valid group_id or user_id." + ) + + # 使用项目的Neo4jConnector + connector = Neo4jConnector() try: - # 直接使用环境变量中的 group_id - group_id_to_query = DEFAULT_GROUP_ID - # 使用 asyncio.run 来执行异步主函数 - top_tags = asyncio.run(get_hot_memory_tags(group_id=group_id_to_query)) + # 1. 从数据库获取原始排名靠前的标签 + raw_tags_with_freq = await get_raw_tags_from_db(connector, group_id, limit, by_user=by_user) + if not raw_tags_with_freq: + return [] - if top_tags: - print(f"热门记忆标签 (Group ID: {group_id_to_query}, 经LLM筛选):") - for tag, frequency in top_tags: - print(f"- {tag} (数量: {frequency})") + raw_tag_names = [tag for tag, freq in raw_tags_with_freq] - # --- 将结果写入统一的 Signboard.json 到 logs/memory-output --- - from app.core.config import settings - settings.ensure_memory_output_dir() - signboard_path = settings.get_memory_output_path("Signboard.json") - payload = { - "group_id": group_id_to_query, - "hot_tags": [{"name": t, "frequency": f} for t, f in top_tags] - } - try: - existing = {} - if os.path.exists(signboard_path): - with open(signboard_path, "r", encoding="utf-8") as rf: - existing = json.load(rf) - existing["hot_memory_tags"] = payload - with open(signboard_path, "w", encoding="utf-8") as wf: - json.dump(existing, wf, ensure_ascii=False, indent=2) - print(f"已写入 {signboard_path} -> hot_memory_tags") - except Exception as e: - print(f"写入 Signboard.json 失败: {e}") - else: - print(f"在 Group ID '{group_id_to_query}' 中没有找到符合条件的实体标签。") - except Exception as e: - print(f"执行过程中发生严重错误: {e}") - print("请检查:") - print("1. Neo4j数据库服务是否正在运行。") - print("2. 'config.json'中的配置是否正确。") - print("3. 相关的环境变量 (如 NEO4J_PASSWORD, DASHSCOPE_API_KEY) 是否已正确设置。") + # 2. 初始化LLM客户端并使用LLM筛选出有意义的标签 + meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id) + + # 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序) + final_tags = [] + for tag, freq in raw_tags_with_freq: + if tag in meaningful_tag_names: + final_tags.append((tag, freq)) + + return final_tags + finally: + # 确保关闭连接 + await connector.close() diff --git a/api/app/core/memory/src/search.py b/api/app/core/memory/src/search.py index ae2b9cfa..91e47eae 100644 --- a/api/app/core/memory/src/search.py +++ b/api/app/core/memory/src/search.py @@ -131,179 +131,60 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score") return results -# ============================================================================ -# 以下函数已被 rerank_with_activation 替代,暂时保留以供参考 -# ============================================================================ -# def rerank_hybrid_results( -# keyword_results: Dict[str, List[Dict[str, Any]]], -# embedding_results: Dict[str, List[Dict[str, Any]]], -# alpha: float = 0.6, -# limit: int = 10 -# ) -> Dict[str, List[Dict[str, Any]]]: -# """ -# Rerank hybrid search results by combining BM25 and embedding scores. -# -# 已废弃:此函数功能已被 rerank_with_activation 完全替代 -# -# Args: -# keyword_results: Results from keyword/BM25 search -# embedding_results: Results from embedding search -# alpha: Weight for BM25 scores (1-alpha for embedding scores) -# limit: Maximum number of results to return per category -# -# Returns: -# Reranked results with combined scores -# """ -# reranked = {} -# -# for category in ["statements", "chunks", "entities","summaries"]: -# keyword_items = keyword_results.get(category, []) -# embedding_items = embedding_results.get(category, []) -# -# # Normalize scores within each search type -# keyword_items = normalize_scores(keyword_items, "score") -# embedding_items = normalize_scores(embedding_items, "score") -# -# # Create a combined pool of unique items -# combined_items = {} -# -# # Add keyword results with BM25 scores -# for item in keyword_items: -# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") -# if item_id: -# combined_items[item_id] = item.copy() -# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0) -# combined_items[item_id]["embedding_score"] = 0 # Default -# -# # Add or update with embedding results -# for item in embedding_items: -# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") -# if item_id: -# if item_id in combined_items: -# # Update existing item with embedding score -# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0) -# else: -# # New item from embedding search only -# combined_items[item_id] = item.copy() -# combined_items[item_id]["bm25_score"] = 0 # Default -# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0) -# -# # Calculate combined scores and rank -# for item_id, item in combined_items.items(): -# bm25_score = item.get("bm25_score", 0) -# embedding_score = item.get("embedding_score", 0) -# -# # Combined score: weighted average of normalized scores -# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score -# item["combined_score"] = combined_score -# -# # Keep original score for reference -# if "score" not in item and bm25_score > 0: -# item["score"] = bm25_score -# elif "score" not in item and embedding_score > 0: -# item["score"] = embedding_score -# -# # Sort by combined score and limit results -# sorted_items = sorted( -# combined_items.values(), -# key=lambda x: x.get("combined_score", 0), -# reverse=True -# )[:limit] -# -# reranked[category] = sorted_items -# -# return reranked - -# def rerank_with_forgetting_curve( -# keyword_results: Dict[str, List[Dict[str, Any]]], -# embedding_results: Dict[str, List[Dict[str, Any]]], -# alpha: float = 0.6, -# limit: int = 10, -# forgetting_config: ForgettingEngineConfig | None = None, -# now: datetime | None = None, -# ) -> Dict[str, List[Dict[str, Any]]]: -# """ -# Rerank hybrid results with a forgetting curve applied to combined scores. -# -# 已废弃:此函数功能已被 rerank_with_activation 完全替代 -# rerank_with_activation 提供了更完整的遗忘曲线支持(结合激活度) -# -# The forgetting curve reduces scores for older memories or weaker connections. -# -# Args: -# keyword_results: Results from keyword/BM25 search -# embedding_results: Results from embedding search -# alpha: Weight for BM25 scores (1-alpha for embedding scores) -# limit: Maximum number of results to return per category -# forgetting_config: Configuration for the forgetting engine -# now: Optional current time override for testing -# -# Returns: -# Reranked results with combined and final scores (after forgetting) -# """ -# engine = ForgettingEngine(forgetting_config or ForgettingEngineConfig()) -# now_dt = now or datetime.now() -# -# reranked: Dict[str, List[Dict[str, Any]]] = {} -# -# for category in ["statements", "chunks", "entities","summaries"]: -# keyword_items = keyword_results.get(category, []) -# embedding_items = embedding_results.get(category, []) -# -# # Normalize scores within each search type -# keyword_items = normalize_scores(keyword_items, "score") -# embedding_items = normalize_scores(embedding_items, "score") -# -# combined_items: Dict[str, Dict[str, Any]] = {} -# -# # Combine two result sets by ID -# for src_items, is_embedding in ( -# (keyword_items, False), (embedding_items, True) -# ): -# for item in src_items: -# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") -# if not item_id: -# continue -# existing = combined_items.get(item_id) -# if not existing: -# combined_items[item_id] = item.copy() -# combined_items[item_id]["bm25_score"] = 0 -# combined_items[item_id]["embedding_score"] = 0 -# # Update normalized score from the right source -# if is_embedding: -# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0) -# else: -# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0) -# -# # Calculate scores and apply forgetting weights -# for item_id, item in combined_items.items(): -# bm25_score = float(item.get("bm25_score", 0) or 0) -# embedding_score = float(item.get("embedding_score", 0) or 0) -# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score -# -# # Estimate time elapsed in days -# dt = _parse_datetime(item.get("created_at")) -# if dt is None: -# time_elapsed_days = 0.0 -# else: -# time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0) -# -# # Memory strength (currently set to default value) -# memory_strength = 1.0 -# forgetting_weight = engine.calculate_weight( -# time_elapsed=time_elapsed_days, memory_strength=memory_strength -# ) -# final_score = combined_score * forgetting_weight -# item["combined_score"] = final_score -# -# sorted_items = sorted( -# combined_items.values(), key=lambda x: x.get("combined_score", 0), reverse=True -# )[:limit] -# -# reranked[category] = sorted_items -# -# return reranked +def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Remove duplicate items from search results based on content. + + Deduplication strategy: + 1. First try to deduplicate by ID (id, uuid, or chunk_id) + 2. Then deduplicate by content hash (text, content, statement, or name fields) + + Args: + items: List of search result items + + Returns: + Deduplicated list of items, preserving the order of first occurrence + """ + seen_ids = set() + seen_content = set() + deduplicated = [] + + for item in items: + # Try multiple ID fields to identify unique items + item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") + + # Extract content from various possible fields + content = ( + item.get("text") or + item.get("content") or + item.get("statement") or + item.get("name") or + "" + ) + + # Normalize content for comparison (strip whitespace and lowercase) + normalized_content = str(content).strip().lower() if content else "" + + # Check if we've seen this ID or content before + is_duplicate = False + + if item_id and item_id in seen_ids: + is_duplicate = True + elif normalized_content and normalized_content in seen_content: + # Only check content duplication if content is not empty + is_duplicate = True + + if not is_duplicate: + # Mark as seen + if item_id: + seen_ids.add(item_id) + if normalized_content: # Only track non-empty content + seen_content.add(normalized_content) + + deduplicated.append(item) + + return deduplicated def rerank_with_activation( @@ -364,7 +245,7 @@ def rerank_with_activation( keyword_items = normalize_scores(keyword_items, "score") embedding_items = normalize_scores(embedding_items, "score") - # 步骤 2: 按 ID 合并结果 + # 步骤 2: 按 ID 合并结果(去重) combined_items: Dict[str, Dict[str, Any]] = {} # 添加关键词结果 @@ -507,6 +388,9 @@ def rerank_with_activation( # 无激活值:使用内容相关性分数 item["final_score"] = item.get("base_score", 0) + # 最终去重确保没有重复项 + sorted_items = _deduplicate_results(sorted_items) + reranked[category] = sorted_items return reranked @@ -1144,96 +1028,3 @@ async def search_chunk_by_chunk_id( ) return {"chunks": chunks} - -# def main(): -# """Main entry point for the hybrid graph search CLI. - -# Parses command line arguments and executes search with specified parameters. -# Supports keyword, embedding, and hybrid search modes. -# """ -# parser = argparse.ArgumentParser(description="Hybrid graph search with keyword and embedding options") -# parser.add_argument( -# "--query", "-q", required=True, help="Free-text query to search" -# ) -# parser.add_argument( -# "--search-type", -# "-t", -# choices=["keyword", "embedding", "hybrid"], -# default="hybrid", -# help="Search type: keyword (text matching), embedding (semantic), or hybrid (both) (default: hybrid)" -# ) -# parser.add_argument( -# "--config-id", -# "-c", -# type=int, -# required=True, -# help="Database configuration ID (required)", -# ) -# parser.add_argument( -# "--group-id", -# "-g", -# default=None, -# help="Optional group_id to filter results (default: None)", -# ) -# parser.add_argument( -# "--limit", -# "-k", -# type=int, -# default=5, -# help="Max number of results per type (default: 5)", -# ) -# parser.add_argument( -# "--include", -# "-i", -# nargs="+", -# default=["statements", "chunks", "entities", "summaries"], -# choices=["statements", "chunks", "entities", "summaries"], -# help="Which targets to search for embedding search (default: statements chunks entities summaries)" -# ) -# parser.add_argument( -# "--output", -# "-o", -# default="search_results.json", -# help="Path to save the search results JSON (default: search_results.json)", -# ) -# parser.add_argument( -# "--rerank-alpha", -# "-a", -# type=float, -# default=0.6, -# help="Weight for BM25 scores in reranking (0.0-1.0, higher values favor keyword search) (default: 0.6)", -# ) -# parser.add_argument( -# "--forgetting-rerank", -# action="store_true", -# help="Apply forgetting curve during reranking for hybrid search.", -# ) -# parser.add_argument( -# "--llm-rerank", -# action="store_true", -# help="Apply LLM-based reranking for hybrid search.", -# ) -# args = parser.parse_args() - -# # Load memory config from database -# from app.services.memory_config_service import MemoryConfigService -# memory_config = MemoryConfigService.load_memory_config(args.config_id) - -# asyncio.run( -# run_hybrid_search( -# query_text=args.query, -# search_type=args.search_type, -# group_id=args.group_id, -# limit=args.limit, -# include=args.include, -# output_path=args.output, -# memory_config=memory_config, -# rerank_alpha=args.rerank_alpha, -# use_forgetting_rerank=args.forgetting_rerank, -# use_llm_rerank=args.llm_rerank, -# ) -# ) - - -# if __name__ == "__main__": -# main() diff --git a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py index bd3a9190..d39c9dbb 100644 --- a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py +++ b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py @@ -18,13 +18,11 @@ from enum import Enum from typing import Any, Dict, List, Optional from app.core.memory.llm_tools.openai_client import OpenAIClient -from app.core.memory.utils.config import definitions as config_defs from app.core.memory.utils.config.get_data import ( extract_and_process_changes, get_data, get_data_statement, ) - from app.core.models.base import RedBearModelConfig from app.repositories.neo4j.cypher_queries import ( neo4j_query_all, diff --git a/api/app/core/memory/utils/config/__init__.py b/api/app/core/memory/utils/config/__init__.py index f69c13a2..9eef888c 100644 --- a/api/app/core/memory/utils/config/__init__.py +++ b/api/app/core/memory/utils/config/__init__.py @@ -14,28 +14,8 @@ from .config_utils import ( get_pruning_config, get_voice_config, ) - -# DEPRECATED: Global configuration variables removed -# Use MemoryConfig objects with dependency injection instead -# from .definitions import ( -# CONFIG, # DEPRECATED - empty dict for backward compatibility -# RUNTIME_CONFIG, # DEPRECATED - minimal for backward compatibility -# PROJECT_ROOT, # Still needed for file paths -# reload_configuration_from_database, # DEPRECATED - returns False -# ) -# DEPRECATED: overrides module removed - use MemoryConfig with dependency injection from .get_data import get_data -# litellm_config 需要时动态导入,避免循环依赖 -# from .litellm_config import ( -# LiteLLMConfig, -# setup_litellm_enhanced, -# get_usage_summary, -# print_usage_summary, -# get_instant_qps, -# print_instant_qps, -# ) - __all__ = [ # config_utils "get_model_config", @@ -45,18 +25,5 @@ __all__ = [ "get_pruning_config", "get_picture_config", "get_voice_config", - # definitions (DEPRECATED - use MemoryConfig objects instead) - # "CONFIG", # DEPRECATED - # "RUNTIME_CONFIG", # DEPRECATED - # "PROJECT_ROOT", - # "reload_configuration_from_database", # DEPRECATED - # get_data "get_data", - # litellm_config - 需要时从 .litellm_config 直接导入 - # "LiteLLMConfig", - # "setup_litellm_enhanced", - # "get_usage_summary", - # "print_usage_summary", - # "get_instant_qps", - # "print_instant_qps", ] diff --git a/api/app/core/memory/utils/config/config_optimization.py b/api/app/core/memory/utils/config/config_optimization.py deleted file mode 100644 index 41848a80..00000000 --- a/api/app/core/memory/utils/config/config_optimization.py +++ /dev/null @@ -1,398 +0,0 @@ -""" -配置管理优化模块 - -提供可选的配置管理优化功能,包括: -- LRU 缓存策略 -- 缓存预热 -- 缓存监控指标 -- 动态 TTL 策略 -- 配置版本控制 - -这些优化是可选的,当前的基础实现已经满足大多数需求。 -""" -import logging -import statistics -import threading -from collections import OrderedDict -from datetime import datetime, timedelta -from typing import Dict, Any, List, Optional, Tuple - -logger = logging.getLogger(__name__) - - -class LRUConfigCache: - """ - LRU(Least Recently Used)配置缓存 - - 当缓存达到最大容量时,自动淘汰最少使用的配置 - """ - - def __init__(self, max_size: int = 100, ttl: timedelta = timedelta(minutes=5)): - """ - 初始化 LRU 缓存 - - Args: - max_size: 最大缓存容量 - ttl: 缓存过期时间 - """ - self.max_size = max_size - self.ttl = ttl - self._cache: OrderedDict[str, Dict[str, Any]] = OrderedDict() - self._timestamps: Dict[str, datetime] = {} - self._lock = threading.RLock() - - # 统计信息 - self._stats = { - 'hits': 0, - 'misses': 0, - 'evictions': 0, - 'load_times': [] - } - - def get(self, config_id: str) -> Optional[Dict[str, Any]]: - """ - 获取配置(如果存在且未过期) - - Args: - config_id: 配置 ID - - Returns: - 配置字典,如果不存在或已过期则返回 None - """ - with self._lock: - if config_id not in self._cache: - self._stats['misses'] += 1 - return None - - # 检查是否过期 - timestamp = self._timestamps.get(config_id) - if timestamp and (datetime.now() - timestamp) >= self.ttl: - # 过期,移除 - self._cache.pop(config_id, None) - self._timestamps.pop(config_id, None) - self._stats['misses'] += 1 - return None - - # 命中,移动到末尾(标记为最近使用) - self._cache.move_to_end(config_id) - self._stats['hits'] += 1 - return self._cache[config_id] - - def put(self, config_id: str, config: Dict[str, Any]) -> None: - """ - 添加或更新配置 - - Args: - config_id: 配置 ID - config: 配置字典 - """ - with self._lock: - if config_id in self._cache: - # 更新现有配置 - self._cache.move_to_end(config_id) - else: - # 添加新配置 - if len(self._cache) >= self.max_size: - # 缓存已满,移除最旧的配置 - oldest_id, _ = self._cache.popitem(last=False) - self._timestamps.pop(oldest_id, None) - self._stats['evictions'] += 1 - logger.debug(f"[LRUCache] 淘汰配置: {oldest_id}") - - self._cache[config_id] = config - self._timestamps[config_id] = datetime.now() - - def clear(self, config_id: Optional[str] = None) -> None: - """ - 清除缓存 - - Args: - config_id: 如果指定,只清除该配置;否则清除所有 - """ - with self._lock: - if config_id: - self._cache.pop(config_id, None) - self._timestamps.pop(config_id, None) - else: - self._cache.clear() - self._timestamps.clear() - - def get_stats(self) -> Dict[str, Any]: - """ - 获取缓存统计信息 - - Returns: - 统计信息字典 - """ - with self._lock: - total = self._stats['hits'] + self._stats['misses'] - hit_rate = (self._stats['hits'] / total * 100) if total > 0 else 0 - - return { - 'cache_size': len(self._cache), - 'max_size': self.max_size, - 'total_requests': total, - 'cache_hits': self._stats['hits'], - 'cache_misses': self._stats['misses'], - 'evictions': self._stats['evictions'], - 'hit_rate': hit_rate, - 'avg_load_time': statistics.mean(self._stats['load_times']) if self._stats['load_times'] else 0 - } - - def record_load_time(self, load_time_ms: float) -> None: - """ - 记录加载时间 - - Args: - load_time_ms: 加载时间(毫秒) - """ - with self._lock: - self._stats['load_times'].append(load_time_ms) - # 只保留最近 1000 次的记录 - if len(self._stats['load_times']) > 1000: - self._stats['load_times'] = self._stats['load_times'][-1000:] - - -class ConfigCacheWarmer: - """ - 配置缓存预热器 - - 在系统启动时预加载常用配置,减少首次请求延迟 - """ - - @staticmethod - def warmup(config_ids: List[str], load_func) -> Dict[str, bool]: - """ - 预热缓存 - - Args: - config_ids: 要预加载的配置 ID 列表 - load_func: 配置加载函数 - - Returns: - 每个配置的加载结果 - """ - results = {} - - logger.info(f"[CacheWarmer] 开始预热 {len(config_ids)} 个配置") - - for config_id in config_ids: - try: - result = load_func(config_id) - results[config_id] = result - if result: - logger.debug(f"[CacheWarmer] 成功预热配置: {config_id}") - else: - logger.warning(f"[CacheWarmer] 预热配置失败: {config_id}") - except Exception as e: - logger.error(f"[CacheWarmer] 预热配置异常: {config_id}, 错误: {e}") - results[config_id] = False - - success_count = sum(1 for r in results.values() if r) - logger.info(f"[CacheWarmer] 预热完成: {success_count}/{len(config_ids)} 成功") - - return results - - -class DynamicTTLStrategy: - """ - 动态 TTL 策略 - - 根据配置类型和更新频率动态调整缓存过期时间 - """ - - # 预定义的 TTL 策略 - TTL_STRATEGIES = { - 'production': timedelta(minutes=30), # 生产配置较稳定 - 'staging': timedelta(minutes=15), # 预发布配置中等稳定 - 'development': timedelta(minutes=5), # 开发配置频繁变化 - 'testing': timedelta(minutes=1), # 测试配置快速过期 - 'default': timedelta(minutes=5) # 默认策略 - } - - @classmethod - def get_ttl(cls, config_id: str, config_type: Optional[str] = None) -> timedelta: - """ - 获取配置的 TTL - - Args: - config_id: 配置 ID - config_type: 配置类型(production/staging/development/testing) - - Returns: - TTL 时间间隔 - """ - if config_type and config_type in cls.TTL_STRATEGIES: - return cls.TTL_STRATEGIES[config_type] - - # 根据 config_id 推断类型 - if 'prod' in config_id.lower(): - return cls.TTL_STRATEGIES['production'] - elif 'stag' in config_id.lower(): - return cls.TTL_STRATEGIES['staging'] - elif 'dev' in config_id.lower(): - return cls.TTL_STRATEGIES['development'] - elif 'test' in config_id.lower(): - return cls.TTL_STRATEGIES['testing'] - - return cls.TTL_STRATEGIES['default'] - - -class ConfigVersionManager: - """ - 配置版本管理器 - - 跟踪配置版本,当配置更新时自动失效旧版本缓存 - """ - - def __init__(self): - self._versions: Dict[str, str] = {} - self._lock = threading.RLock() - - def get_version(self, config_id: str) -> Optional[str]: - """ - 获取配置版本 - - Args: - config_id: 配置 ID - - Returns: - 版本号,如果不存在则返回 None - """ - with self._lock: - return self._versions.get(config_id) - - def set_version(self, config_id: str, version: str) -> None: - """ - 设置配置版本 - - Args: - config_id: 配置 ID - version: 版本号 - """ - with self._lock: - old_version = self._versions.get(config_id) - self._versions[config_id] = version - - if old_version and old_version != version: - logger.info(f"[VersionManager] 配置版本更新: {config_id} {old_version} -> {version}") - - def check_version(self, config_id: str, cached_version: Optional[str]) -> bool: - """ - 检查缓存版本是否有效 - - Args: - config_id: 配置 ID - cached_version: 缓存的版本号 - - Returns: - True 如果版本匹配,False 如果版本不匹配或不存在 - """ - with self._lock: - current_version = self._versions.get(config_id) - - if not current_version or not cached_version: - return False - - return current_version == cached_version - - def invalidate(self, config_id: str) -> None: - """ - 使配置版本失效 - - Args: - config_id: 配置 ID - """ - with self._lock: - if config_id in self._versions: - # 生成新版本号 - import uuid - new_version = str(uuid.uuid4()) - self._versions[config_id] = new_version - logger.info(f"[VersionManager] 配置版本失效: {config_id} -> {new_version}") - - -class CacheMonitor: - """ - 缓存监控器 - - 提供缓存性能监控和报告功能 - """ - - def __init__(self, cache: LRUConfigCache): - self.cache = cache - - def get_report(self) -> str: - """ - 生成缓存性能报告 - - Returns: - 格式化的报告字符串 - """ - stats = self.cache.get_stats() - - report = f""" -配置缓存性能报告 -================ -缓存容量: {stats['cache_size']}/{stats['max_size']} -总请求数: {stats['total_requests']} -缓存命中: {stats['cache_hits']} -缓存未命中: {stats['cache_misses']} -缓存命中率: {stats['hit_rate']:.2f}% -淘汰次数: {stats['evictions']} -平均加载时间: {stats['avg_load_time']:.2f}ms -""" - return report - - def log_stats(self) -> None: - """记录统计信息到日志""" - stats = self.cache.get_stats() - logger.info( - f"[CacheMonitor] 缓存统计 - " - f"容量: {stats['cache_size']}/{stats['max_size']}, " - f"命中率: {stats['hit_rate']:.2f}%, " - f"淘汰: {stats['evictions']}" - ) - - -# 使用示例 -def example_usage(): - """ - 优化功能使用示例 - """ - # 1. 使用 LRU 缓存 - lru_cache = LRUConfigCache(max_size=100, ttl=timedelta(minutes=5)) - - # 获取配置 - config = lru_cache.get("config_001") - if config is None: - # 缓存未命中,从数据库加载 - config = {"llm_name": "openai/gpt-4"} - lru_cache.put("config_001", config) - - # 2. 预热缓存 - def load_config(config_id): - # 实际的配置加载逻辑 - return True - - warmer = ConfigCacheWarmer() - results = warmer.warmup(["config_001", "config_002"], load_config) - - # 3. 动态 TTL - ttl = DynamicTTLStrategy.get_ttl("prod_config_001", "production") - print(f"TTL: {ttl}") - - # 4. 版本管理 - version_manager = ConfigVersionManager() - version_manager.set_version("config_001", "v1.0.0") - - # 检查版本 - is_valid = version_manager.check_version("config_001", "v1.0.0") - - # 5. 监控 - monitor = CacheMonitor(lru_cache) - print(monitor.get_report()) - - -if __name__ == "__main__": - example_usage() diff --git a/api/app/core/memory/utils/config/definitions.py b/api/app/core/memory/utils/config/definitions.py deleted file mode 100644 index fc07c2cc..00000000 --- a/api/app/core/memory/utils/config/definitions.py +++ /dev/null @@ -1,268 +0,0 @@ -# """ -# 配置加载模块 - DEPRECATED - -# ⚠️ DEPRECATION NOTICE ⚠️ -# This module is deprecated and will be removed in a future version. -# Global configuration variables have been eliminated in favor of dependency injection. - -# Use the new MemoryConfig system instead: -# - app.schemas.memory_config_schema.MemoryConfig for configuration objects -# - config_service = MemoryConfigService(db); config_service.load_memory_config(config_id) - -# 阶段 1: 从 runtime.json 加载配置(路径 A)- DEPRECATED -# 阶段 2: 从数据库加载配置(路径 B,基于 dbrun.json 中的 config_id)- DEPRECATED -# 阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点)- DEPRECATED -# """ -# import json -# import os -# import threading -# from datetime import datetime, timedelta -# from typing import Any, Dict, Optional - -# #TODO: Fix this - -# try: -# from dotenv import load_dotenv -# load_dotenv() -# except Exception: -# pass - -# # Import unified configuration system -# try: -# from app.core.config import settings -# USE_UNIFIED_CONFIG = True -# except ImportError: -# USE_UNIFIED_CONFIG = False -# settings = None - -# # PROJECT_ROOT 应该指向 app/core/memory/ 目录 -# # __file__ = app/core/memory/utils/config/definitions.py -# # os.path.dirname(__file__) = app/core/memory/utils/config -# # os.path.dirname(...) = app/core/memory/utils -# # os.path.dirname(...) = app/core/memory -# PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -# # DEPRECATED: Global configuration lock removed -# # Use MemoryConfig objects with dependency injection instead - -# # DEPRECATED: Legacy config.json loading removed -# # Use MemoryConfig objects with dependency injection instead -# CONFIG = {} - -# DEFAULT_VALUES = { -# "llm_name": "openai/qwen-plus", -# "embedding_name": "openai/nomic-embed-text:v1.5", -# "chunker_strategy": "RecursiveChunker", -# "group_id": "group_123", -# "user_id": "default_user", -# "apply_id": "default_apply", -# "llm_agent_name": "openai/qwen-plus", -# "llm_verify_name": "openai/qwen-plus", -# "llm_image_recognition": "openai/qwen-plus", -# "llm_voice_recognition": "openai/qwen-plus", -# "prompt_level": "DEBUG", -# "reflexion_iteration_period": "3", -# "reflexion_range": "retrieval", -# "reflexion_baseline": "TIME", -# } - -# # DEPRECATED: Legacy global variables for backward compatibility only -# # These will be removed in a future version -# # Use MemoryConfig objects with dependency injection instead -# # LANGFUSE_ENABLED = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true" -# # SELECTED_LLM_ID = os.getenv("SELECTED_LLM_ID", DEFAULT_VALUES["llm_name"]) - - -# # 阶段 1: 从 runtime.json 加载配置(路径 A) -# def _load_from_runtime_json() -> Dict[str, Any]: -# """ -# DEPRECATED: Legacy runtime.json loading - -# ⚠️ This function is deprecated and will be removed in a future version. -# Use MemoryConfig objects with dependency injection instead. - -# Returns: -# Dict[str, Any]: Empty configuration (legacy support only) -# """ -# import warnings -# warnings.warn( -# "Runtime JSON loading is deprecated. Use MemoryConfig objects with dependency injection instead.", -# DeprecationWarning, -# stacklevel=2 -# ) -# return {"selections": {}} - - -# # 阶段 2: 从数据库加载配置(路径 B)- 已整合到统一加载器 -# # 注意:此函数已被 _load_from_runtime_json 中的统一配置加载器替代 -# # 保留此函数仅为向后兼容 -# def _load_from_database() -> Optional[Dict[str, Any]]: -# """ -# DEPRECATED: Legacy database configuration loading - -# ⚠️ This function is deprecated and will be removed in a future version. -# Use MemoryConfig objects with dependency injection instead. - -# Returns: -# Optional[Dict[str, Any]]: None (deprecated functionality) -# """ -# import warnings -# warnings.warn( -# "Database configuration loading is deprecated. Use MemoryConfig objects with dependency injection instead.", -# DeprecationWarning, -# stacklevel=2 -# ) -# return None - - -# # 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)- DEPRECATED -# def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None: -# """ -# DEPRECATED: 将运行时配置暴露为全局常量供项目使用 - -# ⚠️ This function is deprecated and will be removed in a future version. -# Global configuration variables have been eliminated in favor of dependency injection. - -# Use the new MemoryConfig system instead: -# - app.core.memory_config.config.MemoryConfig for configuration objects -# - Pass configuration objects as parameters instead of using global variables - -# Args: -# runtime_cfg: 运行时配置字典 -# """ -# import warnings -# warnings.warn( -# "Global configuration variables are deprecated. Use MemoryConfig objects with dependency injection instead.", -# DeprecationWarning, -# stacklevel=2 -# ) - -# # Keep minimal global state for backward compatibility only -# # These will be removed in a future version -# global RUNTIME_CONFIG, SELECTIONS - -# RUNTIME_CONFIG = runtime_cfg -# SELECTIONS = RUNTIME_CONFIG.get("selections", {}) - -# # All other global variables have been removed -# # Use MemoryConfig objects instead - - -# # 初始化:使用统一配置加载器 -# def _initialize_configuration() -> None: -# """ -# DEPRECATED: Legacy configuration initialization - -# ⚠️ This function is deprecated and will be removed in a future version. -# Use MemoryConfig objects with dependency injection instead. -# """ -# import warnings -# warnings.warn( -# "Global configuration initialization is deprecated. Use MemoryConfig objects with dependency injection instead.", -# DeprecationWarning, -# stacklevel=2 -# ) -# # Initialize with empty configuration for backward compatibility -# _expose_runtime_constants({"selections": {}}) - - -# # 模块加载时自动初始化配置 -# _initialize_configuration() - -# # DEPRECATED: Global variables removed -# # These variables have been eliminated in favor of dependency injection -# # Use MemoryConfig objects instead of accessing global variables - - -# # 公共 API:动态重新加载配置 -# def reload_configuration_from_database(config_id, force_reload: bool = False) -> bool: -# """ -# DEPRECATED: Legacy configuration reloading - -# ⚠️ This function is deprecated and will be removed in a future version. -# Use MemoryConfig objects with dependency injection instead. - -# For new code, use: -# - app.services.memory_agent_service.MemoryAgentService.load_memory_config() -# - app.services.memory_storage_service.MemoryStorageService.load_memory_config() - -# Args: -# config_id: Configuration ID (deprecated) -# force_reload: Force reload flag (deprecated) - -# Returns: -# bool: Always returns False (deprecated functionality) -# """ -# import logging -# import warnings - -# logger = logging.getLogger(__name__) - -# warnings.warn( -# "reload_configuration_from_database is deprecated. Use MemoryConfig objects with dependency injection instead.", -# DeprecationWarning, -# stacklevel=2 -# ) - -# logger.warning(f"Deprecated function reload_configuration_from_database called with config_id={config_id}. " -# "Use MemoryConfig objects with dependency injection instead.") - -# return False - - - - - -# def get_current_config_id() -> Optional[str]: -# """ -# DEPRECATED: Legacy config ID retrieval - -# ⚠️ This function is deprecated and will be removed in a future version. -# Use MemoryConfig objects with dependency injection instead. - -# Returns: -# Optional[str]: None (deprecated functionality) -# """ -# import warnings -# warnings.warn( -# "get_current_config_id is deprecated. Use MemoryConfig objects with dependency injection instead.", -# DeprecationWarning, -# stacklevel=2 -# ) -# return None - - -# def ensure_fresh_config(config_id = None) -> bool: -# """ -# DEPRECATED: Legacy configuration freshness check - -# ⚠️ This function is deprecated and will be removed in a future version. -# Use MemoryConfig objects with dependency injection instead. - -# For new code, use: -# - app.services.memory_agent_service.MemoryAgentService.load_memory_config() -# - app.services.memory_storage_service.MemoryStorageService.load_memory_config() - -# Args: -# config_id: Configuration ID (deprecated) - -# Returns: -# bool: Always returns False (deprecated functionality) -# """ -# import logging -# import warnings - -# logger = logging.getLogger(__name__) - -# warnings.warn( -# "ensure_fresh_config is deprecated. Use MemoryConfig objects with dependency injection instead.", -# DeprecationWarning, -# stacklevel=2 -# ) - -# logger.warning(f"Deprecated function ensure_fresh_config called with config_id={config_id}. " -# "Use MemoryConfig objects with dependency injection instead.") - -# return False - - diff --git a/api/app/core/memory/utils/config/get_example_data.py b/api/app/core/memory/utils/config/get_example_data.py deleted file mode 100644 index c466645b..00000000 --- a/api/app/core/memory/utils/config/get_example_data.py +++ /dev/null @@ -1,90 +0,0 @@ -import os -import re -import uuid -import random -import string -from typing import List, Dict, Optional - -# 生成包含字母(大小写)和数字的随机字符串 -def generate_random_string(length=16): - characters = string.ascii_letters + string.digits - return ''.join(random.choice(characters) for _ in range(length)) - -def get_example_data() -> List[Dict[str, Optional[str]]]: - """ - 从句子提取日志中获取数据 - Content: 在苹果公司中国总部,用户和李华偶遇了从美国来的技术专家约翰·史密斯。 - Created At: 2025-11-28 19:28:38.256421 - Expired At: None - Valid At: None - Invalid At: None - 将数据构造成如下形式: - [ - { - "id":id, - "group_id":group_id, - "statement": Content, - "created_at": Created At, - "expired_at": Expired At, - "valid_at": Valid At, - "invalid_at": Invalid At, - "chunk_id": "86da9022710c40eaa5f518a294c398d2", - "entity_ids": [] - }, - ... - ] - """ - # 获取日志文件路径 - log_file_path = os.path.join("logs", "memory-output", "statement_extraction.txt") - - # 检查文件是否存在 - if not os.path.exists(log_file_path): - return [] - - # 读取日志文件 - with open(log_file_path, "r", encoding="utf-8") as f: - content = f.read() - - # 解析数据 - results = [] - - # 使用正则表达式分割每个 Statement - statement_blocks = re.split(r"Statement \d+:", content) - - for block in statement_blocks[1:]: # 跳过第一个空块 - # 提取各个字段 - id_match = re.search(r"Id:\s*(.+?)(?=\n)", block) - group_id_match = re.search(r"Group Id:\s*(.+?)(?=\n)", block) - statement_match = re.search(r"Content:\s*(.+?)(?=\n)", block) - created_at_match = re.search(r"Created At:\s*(.+?)(?=\n)", block) - expired_at_match = re.search(r"Expired At:\s*(.+?)(?=\n)", block) - valid_at_match = re.search(r"Valid At:\s*(.+?)(?=\n)", block) - invalid_at_match = re.search(r"Invalid At:\s*(.+?)(?=\n)", block) - chunk_id_match = re.search(r"Chunk Id:\s*(.+?)(?=\n)", block) - - # 构造字典 - if statement_match: - statement_data = { - "id": id_match.group(1).strip() if id_match else generate_random_string(), - "group_id": group_id_match.group(1).strip() if group_id_match else "group_example", - "statement": statement_match.group(1).strip(), - "created_at": created_at_match.group(1).strip() if created_at_match else None, - "expired_at": expired_at_match.group(1).strip() if expired_at_match else None, - "valid_at": valid_at_match.group(1).strip() if valid_at_match else None, - "invalid_at": invalid_at_match.group(1).strip() if invalid_at_match else None, - "chunk_id": chunk_id_match.group(1).strip() if chunk_id_match else "chunk_example", - "entity_ids": [] - } - - # 将 "None" 字符串转换为 None - for key in ["created_at", "expired_at", "valid_at", "invalid_at"]: - if statement_data[key] == "None": - statement_data[key] = None - - results.append(statement_data) - - return results - - -if __name__ == "__main__": - print(f"获取数据如下:\n {get_example_data()}") \ No newline at end of file diff --git a/api/app/core/memory/utils/config/litellm_config.py b/api/app/core/memory/utils/config/litellm_config.py deleted file mode 100644 index dbf991a8..00000000 --- a/api/app/core/memory/utils/config/litellm_config.py +++ /dev/null @@ -1,516 +0,0 @@ -""" -LiteLLM Configuration for Enhanced Retry Logic and Usage Tracking with Native QPS Monitoring -""" - -import litellm -from typing import Dict, Any, List -import json -from datetime import datetime, timedelta -import os -import time -from collections import defaultdict -import threading -from queue import Queue - -class LiteLLMConfig: - """Configuration class for LiteLLM with enhanced retry and tracking capabilities""" - - def __init__(self): - self.usage_data = [] - self.error_data = [] - self.module_stats = defaultdict(lambda: { - 'requests': 0, - 'tokens_in': 0, - 'tokens_out': 0, - 'cost': 0.0, - 'errors': 0, - 'start_time': None, - 'last_request_time': None, - 'request_timestamps': [], # Store precise timestamps - 'current_qps': 0.0, - 'max_qps': 0.0, - 'qps_history': [] # Store QPS measurements over time - }) - self.start_time = datetime.now() - self.global_request_timestamps = [] - self.global_max_qps = 0.0 - - # Rate limiting for AWS Bedrock (conservative limits) - self.rate_limits = { - 'bedrock': { - 'requests_per_minute': 2, # AWS Bedrock default is very low - 'requests_per_second': 0.033, # 2/60 = 0.033 RPS - 'last_request_time': 0, - 'request_queue': Queue(), - 'lock': threading.Lock() - } - } - self.rate_limiting_enabled = True - - def setup_enhanced_config(self, max_retries: int = 3): - """Configure LiteLLM with retry logic and instant QPS tracking""" - - litellm.num_retries = max_retries - litellm.request_timeout = 300 - - litellm.retry_policy = { - "RateLimitError": { - "max_retries": 5, - "exponential_backoff": True, - "initial_delay": 1, - "max_delay": 60, - "jitter": True - }, - "APIConnectionError": { - "max_retries": 3, - "exponential_backoff": True, - "initial_delay": 2, - "max_delay": 30, - "jitter": True - }, - "InternalServerError": { - "max_retries": 2, - "exponential_backoff": True, - "initial_delay": 5, - "max_delay": 60, - "jitter": True - }, - "BadRequestError": { - "max_retries": 1, - "exponential_backoff": False, - "initial_delay": 1, - "max_delay": 5 - } - } - - litellm.success_callback = [self._success_callback] - litellm.failure_callback = [self._failure_callback] - litellm.completion_cost_tracking = True - litellm.set_verbose = False - litellm.modify_params = True - - print("✅ LiteLLM configured with instant QPS tracking and rate limiting") - - def _success_callback(self, kwargs, completion_response, start_time, end_time): - """Callback for successful requests with module-specific QPS tracking""" - try: - # Extract usage information - usage = completion_response.get('usage', {}) - model = kwargs.get('model', 'unknown') - - # Extract module information from metadata or model name - module = self._extract_module_name(kwargs, model) - - # Calculate cost - cost = 0.0 - try: - cost = litellm.completion_cost(completion_response) - except: - pass - - # Calculate duration - duration_seconds = (end_time - start_time).total_seconds() if hasattr(end_time - start_time, 'total_seconds') else float(end_time - start_time) - - # Record usage data - usage_record = { - "timestamp": datetime.now().isoformat(), - "model": model, - "module": module, - "input_tokens": usage.get('prompt_tokens', 0), - "output_tokens": usage.get('completion_tokens', 0), - "total_tokens": usage.get('total_tokens', 0), - "cost": cost, - "duration_seconds": duration_seconds, - "status": "success" - } - - self.usage_data.append(usage_record) - - # Update module-specific stats for QPS tracking - self._update_module_stats(module, usage_record, success=True) - - # Print real-time feedback - print(f"✓ {model}: {usage_record['input_tokens']}→{usage_record['output_tokens']} tokens, ${cost:.4f}, {usage_record['duration_seconds']:.2f}s") - - except Exception as e: - print(f"Warning: Success callback failed: {e}") - - def _failure_callback(self, kwargs, completion_response, start_time, end_time): - """Callback for failed requests with module-specific error tracking""" - try: - model = kwargs.get('model', 'unknown') - module = self._extract_module_name(kwargs, model) - - duration_seconds = (end_time - start_time).total_seconds() if hasattr(end_time - start_time, 'total_seconds') else float(end_time - start_time) - - # Handle different error response formats - error_message = "Unknown error" - error_type = "UnknownError" - - # According to LiteLLM docs, completion_response contains the exception for failures - if completion_response is not None: - error_message = str(completion_response) - error_type = type(completion_response).__name__ - - # Also check kwargs for exception (LiteLLM passes exception in kwargs for failure events) - elif 'exception' in kwargs: - exception = kwargs['exception'] - error_message = str(exception) - error_type = type(exception).__name__ - - # Check for other error formats in kwargs - elif 'error' in kwargs: - error = kwargs['error'] - error_message = str(error) - error_type = type(error).__name__ - - # Check log_event_type to confirm this is a failure event - log_event_type = kwargs.get('log_event_type', '') - if log_event_type == 'failed_api_call' and 'exception' in kwargs: - exception = kwargs['exception'] - error_message = str(exception) - error_type = type(exception).__name__ - - error_record = { - "timestamp": datetime.now().isoformat(), - "model": model, - "module": module, - "error": error_message, - "error_type": error_type, - "duration_seconds": duration_seconds, - "status": "failed" - } - - self.error_data.append(error_record) - - # Update module-specific stats for error tracking - self._update_module_stats(module, error_record, success=False) - - # Print error feedback - print(f"✗ {model}: {error_type} - {error_message[:100]}") - - except Exception as e: - print(f"Warning: Failure callback failed: {e}") - # Debug: print the actual parameters to understand the structure - print(f"Debug - kwargs keys: {list(kwargs.keys()) if kwargs else 'None'}") - print(f"Debug - completion_response type: {type(completion_response)}") - print(f"Debug - completion_response: {completion_response}") - - def _should_rate_limit(self, model: str) -> bool: - """Check if the model should be rate limited""" - if not self.rate_limiting_enabled: - return False - return model.startswith('bedrock/') or 'bedrock' in model.lower() - - def _enforce_rate_limit(self, model: str): - """Enforce rate limiting for AWS Bedrock models""" - if not self._should_rate_limit(model): - return - - provider = 'bedrock' - if provider not in self.rate_limits: - return - - rate_config = self.rate_limits[provider] - - with rate_config['lock']: - current_time = time.time() - time_since_last = current_time - rate_config['last_request_time'] - min_interval = 1.0 / rate_config['requests_per_second'] - - if time_since_last < min_interval: - sleep_time = min_interval - time_since_last - print(f"⏳ Rate limiting: sleeping {sleep_time:.2f}s for {model}") - time.sleep(sleep_time) - - rate_config['last_request_time'] = time.time() - - def _extract_module_name(self, kwargs: Dict[str, Any], model: str) -> str: - """Extract module name from request context""" - # Try to get module from metadata - metadata = kwargs.get('metadata', {}) - if 'module' in metadata: - return metadata['module'] - - # Try to infer from model name or other context - if 'claude' in model.lower(): - return 'bedrock_client' - elif 'gpt' in model.lower() or 'openai' in model.lower(): - return 'openai_client' - elif 'embed' in model.lower(): - return 'embedder' - else: - return 'unknown' - - def _update_module_stats(self, module: str, record: Dict[str, Any], success: bool): - """Update module-specific statistics with instant QPS tracking""" - current_timestamp = time.time() - current_time = datetime.now() - - # Initialize module stats if first request - if self.module_stats[module]['start_time'] is None: - self.module_stats[module]['start_time'] = current_time - - # Update counters - self.module_stats[module]['requests'] += 1 - self.module_stats[module]['last_request_time'] = current_time - self.module_stats[module]['request_timestamps'].append(current_timestamp) - self.global_request_timestamps.append(current_timestamp) - - # Calculate instant QPS for this module - self._calculate_instant_qps(module, current_timestamp) - - # Calculate global instant QPS - self._calculate_global_instant_qps(current_timestamp) - - if success: - self.module_stats[module]['tokens_in'] += record.get('input_tokens', 0) - self.module_stats[module]['tokens_out'] += record.get('output_tokens', 0) - self.module_stats[module]['cost'] += record.get('cost', 0.0) - else: - self.module_stats[module]['errors'] += 1 - - def _calculate_instant_qps(self, module: str, current_timestamp: float): - """Calculate instant QPS for a specific module using sliding window""" - # Keep only timestamps from last 1 second for instant QPS - cutoff_time = current_timestamp - 1.0 - timestamps = self.module_stats[module]['request_timestamps'] - - # Remove old timestamps - self.module_stats[module]['request_timestamps'] = [ - ts for ts in timestamps if ts >= cutoff_time - ] - - # Calculate current QPS (requests in last second) - current_qps = len(self.module_stats[module]['request_timestamps']) - self.module_stats[module]['current_qps'] = current_qps - - # Update max QPS if current is higher - if current_qps > self.module_stats[module]['max_qps']: - self.module_stats[module]['max_qps'] = current_qps - - # Store QPS history (keep last 60 measurements) - self.module_stats[module]['qps_history'].append(current_qps) - if len(self.module_stats[module]['qps_history']) > 60: - self.module_stats[module]['qps_history'].pop(0) - - def _calculate_global_instant_qps(self, current_timestamp: float): - """Calculate global instant QPS across all modules""" - # Keep only timestamps from last 1 second - cutoff_time = current_timestamp - 1.0 - self.global_request_timestamps = [ - ts for ts in self.global_request_timestamps if ts >= cutoff_time - ] - - # Calculate current global QPS - current_global_qps = len(self.global_request_timestamps) - - # Update max global QPS - if current_global_qps > self.global_max_qps: - self.global_max_qps = current_global_qps - - def get_instant_qps(self, module: str = None) -> Dict[str, Any]: - """Get instant QPS data for modules""" - if module: - if module in self.module_stats: - return { - 'module': module, - 'current_qps': self.module_stats[module]['current_qps'], - 'max_qps': self.module_stats[module]['max_qps'], - 'avg_qps_last_minute': sum(self.module_stats[module]['qps_history'][-60:]) / min(60, len(self.module_stats[module]['qps_history'])) if self.module_stats[module]['qps_history'] else 0 - } - else: - return {'module': module, 'current_qps': 0, 'max_qps': 0, 'avg_qps_last_minute': 0} - else: - # Return data for all modules plus global - result = { - 'global': { - 'current_qps': len([ts for ts in self.global_request_timestamps if ts >= time.time() - 1.0]), - 'max_qps': self.global_max_qps - }, - 'modules': {} - } - - for mod in self.module_stats: - result['modules'][mod] = { - 'current_qps': self.module_stats[mod]['current_qps'], - 'max_qps': self.module_stats[mod]['max_qps'], - 'avg_qps_last_minute': sum(self.module_stats[mod]['qps_history'][-60:]) / min(60, len(self.module_stats[mod]['qps_history'])) if self.module_stats[mod]['qps_history'] else 0 - } - - return result - - def get_usage_summary(self) -> Dict[str, Any]: - """Get essential usage statistics""" - if not self.usage_data: - return { - "total_requests": 0, - "total_cost": 0.0, - "error_rate": 0.0, - "message": "No usage data available" - } - - total_requests = len(self.usage_data) - total_errors = len(self.error_data) - total_cost = sum(record['cost'] for record in self.usage_data) - total_input_tokens = sum(record['input_tokens'] for record in self.usage_data) - total_output_tokens = sum(record['output_tokens'] for record in self.usage_data) - - # Calculate session duration - duration_minutes = (datetime.now() - self.start_time).total_seconds() / 60 - - # Build module statistics - module_stats = {} - for module, stats in self.module_stats.items(): - if stats['requests'] > 0: - module_stats[module] = { - "requests": stats['requests'], - "errors": stats['errors'], - "success_rate": ((stats['requests'] - stats['errors']) / stats['requests'] * 100) if stats['requests'] > 0 else 0, - "tokens_in": stats['tokens_in'], - "tokens_out": stats['tokens_out'], - "cost": stats['cost'], - "current_qps": stats['current_qps'], - "max_qps": stats['max_qps'] - } - - return { - "session_duration_minutes": duration_minutes, - "total_requests": total_requests, - "total_errors": total_errors, - "error_rate": (total_errors / total_requests * 100) if total_requests > 0 else 0, - "total_input_tokens": total_input_tokens, - "total_output_tokens": total_output_tokens, - "total_cost": total_cost, - "module_stats": module_stats, - "global_max_qps": self.global_max_qps - } - - def print_usage_summary(self): - """Print essential usage summary""" - stats = self.get_usage_summary() - - if stats.get('message'): - print(f"📊 {stats['message']}") - return - - print("\n📊 USAGE SUMMARY") - print(f"{'='*50}") - print(f"⏱️ Duration: {stats['session_duration_minutes']:.1f} min") - print(f"📈 Requests: {stats['total_requests']}") - print(f"❌ Errors: {stats['total_errors']}") - print(f"💰 Cost: ${stats['total_cost']:.4f}") - print(f"🏆 Global Max QPS: {stats['global_max_qps']}") - - # Module statistics - if stats.get('module_stats'): - print("\n📦 MODULES:") - for module, mod_stats in stats['module_stats'].items(): - print(f" {module}: {mod_stats['requests']} req, Max QPS: {mod_stats['max_qps']}, Current: {mod_stats['current_qps']}") - - print(f"{'='*50}") - - def save_usage_data(self, filename: str = "litellm_usage.json"): - """Save usage data to JSON file""" - data = { - "summary": self.get_usage_summary(), - "detailed_usage": self.usage_data, - "errors": self.error_data, - "export_timestamp": datetime.now().isoformat() - } - - with open(filename, 'w') as f: - json.dump(data, f, indent=2) - - print(f"📁 Usage data saved to {filename}") - - def reset_tracking(self): - """Reset all tracking data""" - self.usage_data = [] - self.error_data = [] - self.module_stats = defaultdict(lambda: { - 'requests': 0, - 'tokens_in': 0, - 'tokens_out': 0, - 'cost': 0.0, - 'errors': 0, - 'start_time': None, - 'last_request_time': None, - 'request_timestamps': [], - 'current_qps': 0.0, - 'max_qps': 0.0, - 'qps_history': [] - }) - self.global_request_timestamps = [] - self.global_max_qps = 0.0 - self.start_time = datetime.now() - print("🔄 All tracking data reset") - -# Global instance for easy access -litellm_config = LiteLLMConfig() - -def setup_litellm_enhanced(max_retries: int = 3): - """ - Quick setup function for LiteLLM enhanced configuration - - Args: - max_retries: Maximum number of retries for failed requests - """ - litellm_config.setup_enhanced_config(max_retries) - return litellm_config - -def get_usage_summary(): - """Get current usage summary""" - return litellm_config.get_usage_summary() - -def print_usage_summary(): - """Print current usage summary""" - litellm_config.print_usage_summary() - -def save_usage_data(filename: str = "litellm_usage.json"): - """Save usage data to file""" - litellm_config.save_usage_data(filename) - -def get_instant_qps(module: str = None) -> Dict[str, Any]: - """Get instant QPS data for modules""" - return litellm_config.get_instant_qps(module) - -def print_instant_qps(module: str = None): - """Print instant QPS information""" - qps_data = get_instant_qps(module) - - print("\n⚡ INSTANT QPS MONITOR") - print(f"{'='*60}") - - if module: - print(f"Module: {qps_data['module']}") - print(f" Current QPS: {qps_data['current_qps']}") - print(f" Max QPS: {qps_data['max_qps']}") - print(f" Avg (1min): {qps_data['avg_qps_last_minute']:.2f}") - else: - # Global stats - global_data = qps_data.get('global', {}) - print("🌍 GLOBAL:") - print(f" Current QPS: {global_data.get('current_qps', 0)}") - print(f" Max QPS: {global_data.get('max_qps', 0)}") - - # Module stats - modules = qps_data.get('modules', {}) - if modules: - print("\n📦 MODULES:") - for mod, data in modules.items(): - print(f" {mod}:") - print(f" Current: {data['current_qps']} QPS") - print(f" Max: {data['max_qps']} QPS") - print(f" Avg: {data['avg_qps_last_minute']:.2f} QPS") - - print(f"{'='*60}") - -def reset_tracking(): - """Reset all tracking data""" - litellm_config.reset_tracking() - -def get_module_stats() -> Dict[str, Dict[str, Any]]: - """Get detailed module statistics""" - summary = get_usage_summary() - return summary.get('module_stats', {}) diff --git a/api/app/core/memory/utils/self_reflexion_utils/__init__.py b/api/app/core/memory/utils/self_reflexion_utils/__init__.py deleted file mode 100644 index 422a83e3..00000000 --- a/api/app/core/memory/utils/self_reflexion_utils/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# -*- coding: utf-8 -*- -"""自我反思工具模块 - -本模块提供自我反思引擎的核心功能,包括: -- 记忆冲突判定 -- 反思执行 -- 记忆更新 - -从 app.core.memory.src.data_config_api 迁移而来。 -""" - -from app.core.memory.utils.self_reflexion_utils.evaluate import conflict -from app.core.memory.utils.self_reflexion_utils.reflexion import reflexion -from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion - -__all__ = ["conflict", "reflexion", "self_reflexion"] diff --git a/api/app/core/memory/utils/self_reflexion_utils/evaluate.py b/api/app/core/memory/utils/self_reflexion_utils/evaluate.py deleted file mode 100644 index 4d1835cd..00000000 --- a/api/app/core/memory/utils/self_reflexion_utils/evaluate.py +++ /dev/null @@ -1,52 +0,0 @@ -# -*- coding: utf-8 -*- -"""记忆冲突判定模块 - -本模块提供记忆冲突判定功能,使用LLM判断记忆数据中是否存在冲突。 -从 app.core.memory.src.data_config_api.evaluate 迁移而来。 -""" - -import logging -import time -from typing import Any, List - -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.core.memory.utils.prompt.template_render import render_evaluate_prompt -from app.db import get_db_context -from app.schemas.memory_storage_schema import ConflictResultSchema -from pydantic import BaseModel - - -async def conflict(evaluate_data: List[Any]) -> List[Any]: - """ - Evaluates memory conflict using the evaluate.jinja2 template. - - Args: - evaluate_data: 反思数据列表。 - Returns: - 冲突记忆列表(JSON 数组)。 - """ - from app.core.memory.utils.config import definitions as config_defs - with get_db_context() as db: - factory = MemoryClientFactory(db) - client = factory.get_llm_client(config_defs.SELECTED_LLM_ID) - rendered_prompt = await render_evaluate_prompt(evaluate_data, ConflictResultSchema) - messages = [{"role": "user", "content": rendered_prompt}] - print(f"提示词长度: {len(rendered_prompt)}") - print(f"====== 冲突判定开始 ======\n") - start_time = time.time() - response = await client.response_structured(messages, ConflictResultSchema) - end_time = time.time() - print(f"冲突判定耗时: {end_time - start_time} 秒") - print(f"冲突判定原始输出:(type={type(response)})\n{response}") - - if not response: - logging.error("LLM 冲突判定输出解析失败,返回空列表以继续流程。") - return [] - try: - return [response.model_dump()] if isinstance(response, BaseModel) else [response] - except Exception: - try: - return [response.dict()] - except Exception: - logging.warning("无法标准化冲突判定返回类型,尝试直接封装为列表。") - return [response] diff --git a/api/app/core/memory/utils/self_reflexion_utils/reflexion.py b/api/app/core/memory/utils/self_reflexion_utils/reflexion.py deleted file mode 100644 index 1b915118..00000000 --- a/api/app/core/memory/utils/self_reflexion_utils/reflexion.py +++ /dev/null @@ -1,54 +0,0 @@ -# -*- coding: utf-8 -*- -"""反思执行模块 - -本模块提供反思执行功能,使用LLM对冲突记忆进行反思和解决。 -从 app.core.memory.src.data_config_api.reflexion 迁移而来。 -""" - -import logging -import time -from typing import Any, List - -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.core.memory.utils.prompt.template_render import render_reflexion_prompt -from app.db import get_db_context -from app.schemas.memory_storage_schema import ReflexionResultSchema -from pydantic import BaseModel - - -async def reflexion(ref_data: List[Any]) -> List[Any]: - """ - Reflexes on the given reference data using the reflexion.jinja2 template. - - Args: - ref_data: 反思数据列表。 - Returns: - 反思结果列表(JSON 数组)。 - """ - from app.core.memory.utils.config import definitions as config_defs - with get_db_context() as db: - factory = MemoryClientFactory(db) - client = factory.get_llm_client(config_defs.SELECTED_LLM_ID) - rendered_prompt = await render_reflexion_prompt(ref_data, ReflexionResultSchema) - messages = [{"role": "user", "content": rendered_prompt}] - print(f"提示词长度: {len(rendered_prompt)}") - - print(f"====== 反思开始 ======\n") - start_time = time.time() - response = await client.response_structured(messages, ReflexionResultSchema) - end_time = time.time() - print(f"反思耗时: {end_time - start_time} 秒") - print(f"反思原始输出:(type={type(response)})\n{response}") - - if not response: - logging.error("LLM 反思输出解析失败,返回空列表以继续流程。") - return [] - # 统一返回为列表[dict],便于自我反思主流程更新数据库 - try: - return [response.model_dump()] if isinstance(response, BaseModel) else [response] - except Exception: - try: - return [response.dict()] - except Exception: - logging.warning("无法标准化反思返回类型,尝试直接封装为列表。") - return [response] diff --git a/api/app/core/memory/utils/self_reflexion_utils/self_reflexion.py b/api/app/core/memory/utils/self_reflexion_utils/self_reflexion.py deleted file mode 100644 index 934037b0..00000000 --- a/api/app/core/memory/utils/self_reflexion_utils/self_reflexion.py +++ /dev/null @@ -1,254 +0,0 @@ -# -*- coding: utf-8 -*- -"""自我反思主执行模块 - -本模块提供自我反思引擎的主流程,包括: -- 获取反思数据 -- 冲突判断 -- 反思执行 -- 记忆更新 - -从 app.core.memory.src.data_config_api.self_reflexion 迁移而来。 -""" - -import asyncio -import json -import logging -import os -import uuid -from typing import Any, Dict, List - -#TODO: Fix this - -# Default values (previously from definitions.py) -REFLEXION_ENABLED = os.getenv("REFLEXION_ENABLED", "false").lower() == "true" -REFLEXION_ITERATION_PERIOD = os.getenv("REFLEXION_ITERATION_PERIOD", "3") -REFLEXION_RANGE = os.getenv("REFLEXION_RANGE", "retrieval") -REFLEXION_BASELINE = os.getenv("REFLEXION_BASELINE", "TIME") - -from app.core.memory.utils.config.get_data import get_data -from app.core.memory.utils.self_reflexion_utils.evaluate import conflict -from app.core.memory.utils.self_reflexion_utils.reflexion import reflexion -from app.db import get_db -from app.models.retrieval_info import RetrievalInfo -from app.repositories.neo4j.cypher_queries import UPDATE_STATEMENT_INVALID_AT -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from sqlalchemy.orm import Session - -# 并发限制(可通过环境变量覆盖) -CONCURRENCY = int(os.getenv("REFLEXION_CONCURRENCY", "5")) - -# 确保 INFO 级别日志输出到终端 -_root_logger = logging.getLogger() -if not _root_logger.handlers: - logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") -else: - _root_logger.setLevel(logging.INFO) - - -async def get_reflexion_data(host_id: uuid.UUID) -> List[Any]: - """ - 根据反思范围获取判断的记忆数据。 - - Args: - host_id: 主机ID - Returns: - 符合反思范围的记忆数据列表。 - """ - if REFLEXION_RANGE == "partial": - return await get_data(host_id) - elif REFLEXION_RANGE == "all": - return [] - else: - raise ValueError(f"未知的反思范围: {REFLEXION_RANGE}") - - -async def run_conflict(conflict_data: List[Any]) -> List[Any]: - """ - 判断反思数据中是否存在冲突。 - - Args: - conflict_data: 冲突数据列表。 - Returns: - 如果存在冲突则返回冲突记忆列表,否则返回空列表。 - """ - if not conflict_data: - return [] - - conflict_data = await conflict(conflict_data) - # 仅保留存在冲突的条目(conflict == True) - try: - return [c for c in conflict_data if isinstance(c, dict) and c.get("conflict") is True] - except Exception: - return [] - - -async def run_reflexion(reflexion_data: List[Any]) -> Any: - """ - 执行反思,解决冲突。 - - Args: - reflexion_data: 反思数据列表。 - Returns: - 解决冲突后的反思结果(由 LLM 返回)。 - """ - if not reflexion_data: - return [] - # 并行对每个冲突进行反思,整体缩短等待时间 - sem = asyncio.Semaphore(CONCURRENCY) - - async def _reflex_one(item: Any) -> Dict[str, Any] | None: - async with sem: - try: - result_list = await reflexion([item]) - if not result_list: - return None - obj = result_list[0] - if hasattr(obj, "model_dump"): - return obj.model_dump() - elif hasattr(obj, "dict"): - return obj.dict() - elif isinstance(obj, dict): - return obj - except Exception as e: - logging.warning(f"反思失败,跳过一项: {e}") - return None - - tasks = [_reflex_one(item) for item in reflexion_data] - results = await asyncio.gather(*tasks, return_exceptions=False) - return [r for r in results if r] - - -async def update_memory(solved_data: List[Any], host_id: uuid.UUID) -> str: - """ - 更新记忆库,将解决冲突后的记忆更新到记忆库中。 - - Args: - solved_data: 解决冲突后的记忆(由 LLM 返回)。 - host_id: 主机ID - Returns: - 更新结果(成功或失败)。 - """ - flag = False - if not solved_data: - return "数据缺失,更新失败" - if not isinstance(solved_data, list): - return "数据格式错误,更新失败" - neo4j_connector = Neo4jConnector() - try: - print(f"====== 更新记忆开始 ======\n") - - sem = asyncio.Semaphore(CONCURRENCY) - success_count = 0 - - async def _update_one(item: Dict[str, Any]) -> bool: - async with sem: - try: - if not isinstance(item, dict): - return False - if not item: - return False - resolved = item.get("resolved") - if not isinstance(resolved, dict) or not resolved: - logging.warning(f"反思结果无可更新内容,跳过此项: {item}") - return False - resolved_mem = resolved.get("resolved_memory") - if not isinstance(resolved_mem, dict) or not resolved_mem: - logging.warning(f"反思结果缺少 resolved_memory,跳过此项: {item}") - return False - group_id = resolved_mem.get("group_id") - id = resolved_mem.get("id") - # 使用 invalid_at 字段作为新的失效时间 - new_invalid_at = resolved_mem.get("invalid_at") - if not all([group_id, id, new_invalid_at]): - logging.warning(f"记忆更新参数缺失,跳过此项: {item}") - return False - await neo4j_connector.execute_query( - UPDATE_STATEMENT_INVALID_AT, - group_id=group_id, - id=id, - new_invalid_at=new_invalid_at, - ) - return True - except Exception as e: - logging.error(f"更新单条记忆失败: {e}") - return False - - tasks = [_update_one(item) for item in solved_data if isinstance(item, dict)] - results = await asyncio.gather(*tasks, return_exceptions=False) - success_count = sum(1 for r in results if r) - - logging.info(f"成功更新 {success_count} 条记忆") - flag = success_count > 0 - return "更新成功" if flag else "更新失败" - except Exception as e: - logging.error(f"更新记忆库失败: {e}") - return "更新失败" - finally: - if flag: # 删除数据库中的检索数据 - db: Session = next(get_db()) - try: - db.query(RetrievalInfo).filter(RetrievalInfo.host_id == host_id).delete() - db.commit() - logging.info(f"成功删除 {success_count} 条检索数据") - except Exception as e: - logging.error(f"删除数据库中的检索数据失败: {e}") - finally: - db.close() - - - -async def _append_json(label: str, data: Any) -> None: - """记录冲突记忆(后台线程写入,避免阻塞事件循环)""" - def _write(): - with open("reflexion_data.json", "a", encoding="utf-8") as f: - f.write(f"### {label} ###\n") - json.dump(data, f, ensure_ascii=False, indent=4) - f.write("\n\n") - # 正确地在协程内等待后台线程执行,避免未等待的协程警告 - await asyncio.to_thread(_write) - - -async def self_reflexion(host_id: uuid.UUID) -> str: - """ - 自我反思引擎,执行反思流程。 - - Args: - host_id: 主机ID - - Returns: - 反思结果描述字符串 - """ - if not REFLEXION_ENABLED: - return "未开启反思..." - print(f"====== 自我反思流程开始 ======\n") - reflexion_data = await get_reflexion_data(host_id) - if not reflexion_data: - print(f"====== 自我反思流程结束 ======\n") - return "无反思数据,结束反思" - print(f"反思数据获取成功,共 {len(reflexion_data)} 条") - - conflict_data = await run_conflict(reflexion_data) - if not conflict_data: - print(f"====== 自我反思流程结束 ======\n") - return "无冲突,无需反思" - print(f"冲突记忆类型: {type(conflict_data)}") - await _append_json("conflict", conflict_data) - - solved_data = await run_reflexion(conflict_data) - if not solved_data: - print(f"====== 自我反思流程结束 ======\n") - return "反思失败,未解决冲突" - print(f"解决冲突后的记忆类型: {type(solved_data)}") - await _append_json("solved_data", solved_data) - - result = await update_memory(solved_data, host_id) - print(f"更新记忆库结果: {result}") - print(f"====== 自我反思流程结束 ======\n") - return result - - -if __name__ == "__main__": - import asyncio - # host_id = uuid.UUID("3f6ff1eb-50c7-4765-8e89-e4566be33333") - host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122") - asyncio.run(self_reflexion(host_id)) diff --git a/api/app/repositories/neo4j/graph_search.py b/api/app/repositories/neo4j/graph_search.py index 80756793..0b6a27c6 100644 --- a/api/app/repositories/neo4j/graph_search.py +++ b/api/app/repositories/neo4j/graph_search.py @@ -1,29 +1,30 @@ -from typing import Any, Dict, List, Optional import asyncio import logging +from typing import Any, Dict, List, Optional + +from app.repositories.neo4j.cypher_queries import ( + CHUNK_EMBEDDING_SEARCH, + ENTITY_EMBEDDING_SEARCH, + MEMORY_SUMMARY_EMBEDDING_SEARCH, + SEARCH_CHUNK_BY_CHUNK_ID, + SEARCH_CHUNKS_BY_CONTENT, + SEARCH_DIALOGUE_BY_DIALOG_ID, + SEARCH_ENTITIES_BY_NAME, + SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, + SEARCH_STATEMENTS_BY_CREATED_AT, + SEARCH_STATEMENTS_BY_KEYWORD, + SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, + SEARCH_STATEMENTS_BY_TEMPORAL, + SEARCH_STATEMENTS_BY_VALID_AT, + SEARCH_STATEMENTS_G_CREATED_AT, + SEARCH_STATEMENTS_G_VALID_AT, + SEARCH_STATEMENTS_L_CREATED_AT, + SEARCH_STATEMENTS_L_VALID_AT, + STATEMENT_EMBEDDING_SEARCH, +) # 使用新的仓储层 from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.repositories.neo4j.cypher_queries import ( - SEARCH_STATEMENTS_BY_KEYWORD, - SEARCH_ENTITIES_BY_NAME, - SEARCH_CHUNKS_BY_CONTENT, - STATEMENT_EMBEDDING_SEARCH, - CHUNK_EMBEDDING_SEARCH, - ENTITY_EMBEDDING_SEARCH, - SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, - MEMORY_SUMMARY_EMBEDDING_SEARCH, - SEARCH_STATEMENTS_BY_TEMPORAL, - SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, - SEARCH_DIALOGUE_BY_DIALOG_ID, - SEARCH_CHUNK_BY_CHUNK_ID, - SEARCH_STATEMENTS_BY_CREATED_AT, - SEARCH_STATEMENTS_BY_VALID_AT, - SEARCH_STATEMENTS_G_CREATED_AT, - SEARCH_STATEMENTS_L_CREATED_AT, - SEARCH_STATEMENTS_G_VALID_AT, - SEARCH_STATEMENTS_L_VALID_AT, -) logger = logging.getLogger(__name__) @@ -55,8 +56,12 @@ async def _update_activation_values_batch( return [] # 延迟导入以避免循环依赖 - from app.core.memory.storage_services.forgetting_engine.access_history_manager import AccessHistoryManager - from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator + from app.core.memory.storage_services.forgetting_engine.access_history_manager import ( + AccessHistoryManager, + ) + from app.core.memory.storage_services.forgetting_engine.actr_calculator import ( + ACTRCalculator, + ) # 创建计算器和管理器实例 actr_calculator = ACTRCalculator() @@ -292,6 +297,13 @@ async def search_graph( else: results[key] = result + # Deduplicate results before updating activation values + # This prevents duplicates from propagating through the pipeline + from app.core.memory.src.search import _deduplicate_results + for key in results: + if isinstance(results[key], list): + results[key] = _deduplicate_results(results[key]) + # 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary) results = await _update_search_results_activation( connector=connector, @@ -397,6 +409,13 @@ async def search_graph_by_embedding( else: results[key] = result + # Deduplicate results before updating activation values + # This prevents duplicates from propagating through the pipeline + from app.core.memory.src.search import _deduplicate_results + for key in results: + if isinstance(results[key], list): + results[key] = _deduplicate_results(results[key]) + # 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary) update_start = time.time() results = await _update_search_results_activation(