From 9d36ec70bc107a74d1da958a05f3453a7bdecd80 Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Tue, 20 Jan 2026 11:24:33 +0800 Subject: [PATCH 1/7] Fix/memory bug fix (#157) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 图谱数据量限制数量去掉 * 图谱数据量限制数量去掉 * 图谱数据量限制数量去掉 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 用户详情优化 * 读取的接口,去掉全局锁 * 输出数组 * 反思优化1.0(优化隐私输出、时间检索) * 反思优化1.0(优化隐私输出、时间检索) * 反思优化1.0(优化隐私输出、时间检索) * 反思优化测试接口 * 反思优化测试接口 --- .../memory_reflection_controller.py | 13 ++- api/app/services/memory_reflection_service.py | 88 ++++++++++++++++--- 2 files changed, 89 insertions(+), 12 deletions(-) diff --git a/api/app/controllers/memory_reflection_controller.py b/api/app/controllers/memory_reflection_controller.py index 24c143b9..9be6e035 100644 --- a/api/app/controllers/memory_reflection_controller.py +++ b/api/app/controllers/memory_reflection_controller.py @@ -119,11 +119,20 @@ async def start_workspace_reflection( end_users = data['end_users'] for base, config, user in zip(releases, data_configs, end_users): - if int(base['config']) == int(config['config_id']) and base['app_id'] == user['app_id']: + # 安全地转换为整数,处理空字符串和None的情况 + print(base['config']) + try: + base_config = int(base['config']) if base['config'] else 0 + config_id = int(config['config_id']) if config['config_id'] else 0 + except (ValueError, TypeError): + api_logger.warning(f"无效的配置ID: base['config']={base.get('config')}, config['config_id']={config.get('config_id')}") + continue + + if base_config == config_id and base['app_id'] == user['app_id']: # 调用反思服务 api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}") - reflection_result = await reflection_service.start_reflection_from_data( + reflection_result = await reflection_service.start_text_reflection( config_data=config, end_user_id=user['id'] ) diff --git a/api/app/services/memory_reflection_service.py b/api/app/services/memory_reflection_service.py index 015cc08a..46e42b46 100644 --- a/api/app/services/memory_reflection_service.py +++ b/api/app/services/memory_reflection_service.py @@ -208,6 +208,47 @@ class MemoryReflectionService: def __init__(self,db: Session = Depends(get_db)): self.db=db + async def start_text_reflection(self, config_data: Dict[str, Any], end_user_id: str) -> Dict[str, Any]: + try: + config_id = config_data.get("config_id") + api_logger.info(f"从配置数据启动反思,config_id: {config_id}, end_user_id: {end_user_id}") + + if not config_data.get("enable_self_reflexion", False): + return { + "status": "跳过", + "message": "反思引擎未启用", + "config_id": config_id, + "end_user_id": end_user_id, + "config_data": config_data + } + + config_data_id = config_data['config_id'] + reflection_config = WorkspaceAppService(self.db)._get_data_config(config_data_id) + if reflection_config is not None and reflection_config['enable_self_reflexion']: + reflection_config = self._create_reflection_config_from_data(reflection_config) + # 3. 执行反思引擎 + reflection_results = await self._execute_reflection_engine( + reflection_config, end_user_id + ) + return { + "status": "完成", + "message": "反思引擎执行完成", + "config_id": config_id, + "end_user_id": end_user_id, + "config_data": config_data, + "reflection_results": reflection_results + } + + except Exception as e: + config_id = config_data.get("config_id", "unknown") + api_logger.error(f"启动反思失败,config_id: {config_id}, end_user_id: {end_user_id}, 错误: {str(e)}") + return { + "status": "错误", + "message": f"启动反思失败: {str(e)}", + "config_id": config_id, + "end_user_id": end_user_id, + "config_data": config_data + } async def start_reflection_from_data(self, config_data: Dict[str, Any], end_user_id: str) -> Dict[str, Any]: """ @@ -239,16 +280,41 @@ class MemoryReflectionService: reflection_config=WorkspaceAppService(self.db)._get_data_config(config_data_id) if reflection_config is not None and reflection_config['enable_self_reflexion']: reflection_config= self._create_reflection_config_from_data(reflection_config) - iteration_period=reflection_config.iteration_period + iteration_period = int(reflection_config.iteration_period) workspace_service = WorkspaceAppService(self.db) current_reflection_time = workspace_service.get_end_user_reflection_time(end_user_id) - reflection_time = datetime.fromisoformat(str(current_reflection_time)) - - current_time = datetime.now() - time_diff = current_time - reflection_time - hours_diff = int(time_diff.total_seconds() / 3600) - if iteration_period==hours_diff or current_reflection_time is None: + # 检查是否需要执行反思 + should_execute = False + hours_diff = 0 + + if current_reflection_time is None: + # 首次执行反思 + should_execute = True + api_logger.info(f"首次执行反思,end_user_id: {end_user_id}") + else: + # 计算时间差 + try: + if isinstance(current_reflection_time, str): + reflection_time = datetime.fromisoformat(current_reflection_time) + else: + reflection_time = current_reflection_time + + current_time = datetime.now() + time_diff = current_time - reflection_time + hours_diff = int(time_diff.total_seconds() / 3600) + + # 检查是否达到反思周期 + if hours_diff >= iteration_period: + should_execute = True + api_logger.info(f"与上次的反思时间间隔为: {hours_diff} 小时,达到周期 {iteration_period} 小时") + else: + api_logger.info(f"与上次的反思时间间隔为: {hours_diff} 小时,未达到周期 {iteration_period} 小时") + except (ValueError, TypeError) as e: + api_logger.warning(f"解析反思时间失败: {e},将执行反思") + should_execute = True + + if should_execute: api_logger.info(f"与上次的反思时间间隔为: {hours_diff} 小时") # 3. 执行反思引擎 reflection_results = await self._execute_reflection_engine( @@ -271,13 +337,15 @@ class MemoryReflectionService: } else: return { - "status": "等待中..", - "message": "反思引擎未开始执行执", + "status": "等待中", + "message": f"反思引擎未开始执行,距离下次执行还需 {iteration_period - hours_diff} 小时", "config_id": config_id, "end_user_id": end_user_id, "config_data": config_data, - "reflection_results": '' + "hours_since_last_reflection": hours_diff, + "next_reflection_in_hours": iteration_period - hours_diff } + except Exception as e: config_id = config_data.get("config_id", "unknown") From 83fe793e72f42b2b37ab13d05a35d7b2a28d294e Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Tue, 20 Jan 2026 15:03:29 +0800 Subject: [PATCH 2/7] refactor(memory): clean up deprecated config and self-reflexion utilities - Remove deprecated self_reflexion endpoint from memory_storage_controller - Delete obsolete config modules (config_optimization, definitions, get_example_data, litellm_config) - Remove self_reflexion_utils package and related evaluation/reflexion modules - Refactor hot_memory_tags to use Neo4jConnector instead of direct GraphDatabase connection - Simplify LLM client initialization by removing DEFAULT_LLM_ID fallback logic - Remove unnecessary sys.path manipulation and project root resolution code - Update filter_tags_with_llm to properly handle missing config with clear error messages - Migrate get_raw_tags_from_db to async function using Neo4jConnector - Consolidate imports and remove unused dependencies (uuid, sys) - Improve error handling with explicit ValueError messages for missing configuration --- .../controllers/memory_storage_controller.py | 17 - .../core/memory/analytics/hot_memory_tags.py | 239 +++----- api/app/core/memory/src/search.py | 323 ++--------- .../reflection_engine/self_reflexion.py | 2 - api/app/core/memory/utils/config/__init__.py | 33 -- .../utils/config/config_optimization.py | 398 -------------- .../core/memory/utils/config/definitions.py | 268 --------- .../memory/utils/config/get_example_data.py | 90 --- .../memory/utils/config/litellm_config.py | 516 ------------------ .../utils/self_reflexion_utils/__init__.py | 16 - .../utils/self_reflexion_utils/evaluate.py | 52 -- .../utils/self_reflexion_utils/reflexion.py | 54 -- .../self_reflexion_utils/self_reflexion.py | 254 --------- api/app/repositories/neo4j/graph_search.py | 65 ++- 14 files changed, 190 insertions(+), 2137 deletions(-) delete mode 100644 api/app/core/memory/utils/config/config_optimization.py delete mode 100644 api/app/core/memory/utils/config/definitions.py delete mode 100644 api/app/core/memory/utils/config/get_example_data.py delete mode 100644 api/app/core/memory/utils/config/litellm_config.py delete mode 100644 api/app/core/memory/utils/self_reflexion_utils/__init__.py delete mode 100644 api/app/core/memory/utils/self_reflexion_utils/evaluate.py delete mode 100644 api/app/core/memory/utils/self_reflexion_utils/reflexion.py delete mode 100644 api/app/core/memory/utils/self_reflexion_utils/self_reflexion.py diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index c58ecd6d..63d9078a 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -1,10 +1,8 @@ import os -import uuid from typing import Optional from app.core.error_codes import BizCode from app.core.logging_config import get_api_logger -from app.core.memory.utils.self_reflexion_utils import self_reflexion from app.core.response_utils import fail, success from app.db import get_db from app.dependencies import get_current_user @@ -458,18 +456,3 @@ async def get_recent_activity_stats_api( api_logger.error(f"Recent activity stats failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e)) - - - -@router.get("/self_reflexion") -async def self_reflexion_endpoint(host_id: uuid.UUID) -> str: - """ - 自我反思接口,自动对检索出的信息进行自我反思并返回自我反思结果。 - - Args: - None - Returns: - 自我反思结果。 - """ - return await self_reflexion(host_id) - diff --git a/api/app/core/memory/analytics/hot_memory_tags.py b/api/app/core/memory/analytics/hot_memory_tags.py index 2aa286ba..cab6cacd 100644 --- a/api/app/core/memory/analytics/hot_memory_tags.py +++ b/api/app/core/memory/analytics/hot_memory_tags.py @@ -1,48 +1,15 @@ import asyncio -import os -import sys -from typing import List, Tuple - -from neo4j import GraphDatabase -from pydantic import BaseModel, Field - -# ------------------- 自包含路径解析 ------------------- -# 这个代码块确保脚本可以从任何地方运行,并且仍然可以在项目结构中找到它需要的模块。 -try: - # 假设脚本在 /path/to/project/src/analytics/ - # 上升3个级别以到达项目根目录。 - project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) - src_path = os.path.join(project_root, 'src') - - # 将 'src' 和 'project_root' 都添加到路径中。 - # 'src' 目录对于像 'from utils.config_utils import ...' 这样的导入是必需的。 - # 'project_root' 目录对于像 'from variate_config import ...' 这样的导入是必需的。 - if src_path not in sys.path: - sys.path.insert(0, src_path) - if project_root not in sys.path: - sys.path.insert(0, project_root) -except NameError: - # 为 __file__ 未定义的环境(例如某些交互式解释器)提供回退方案 - project_root = os.path.abspath(os.path.join(os.getcwd())) - src_path = os.path.join(project_root, 'src') - if src_path not in sys.path: - sys.path.insert(0, src_path) - if project_root not in sys.path: - sys.path.insert(0, project_root) -# --------------------------------------------------------------------- - -# 现在路径已经配置好,我们可以使用绝对导入 import json +import os +from typing import List, Tuple from app.core.config import settings from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context +from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.services.memory_config_service import MemoryConfigService +from pydantic import BaseModel, Field -#TODO: Fix this -# Default values (previously from definitions.py) -DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus") -DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123") # 定义用于LLM结构化输出的Pydantic模型 class FilteredTags(BaseModel): @@ -52,34 +19,45 @@ class FilteredTags(BaseModel): async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]: """ 使用LLM筛选标签列表,仅保留具有代表性的核心名词。 + + Args: + tags: 原始标签列表 + group_id: 用户组ID,用于获取配置 + + Returns: + 筛选后的标签列表 + + Raises: + ValueError: 如果无法获取有效的LLM配置 """ try: # Get config_id using get_end_user_connected_config with get_db_context() as db: - try: - from app.services.memory_agent_service import ( - get_end_user_connected_config, + from app.services.memory_agent_service import ( + get_end_user_connected_config, + ) + + connected_config = get_end_user_connected_config(group_id, db) + config_id = connected_config.get("memory_config_id") + + if not config_id: + raise ValueError( + f"No memory_config_id found for group_id: {group_id}. " + "Please ensure the user has a valid memory configuration." ) - connected_config = get_end_user_connected_config(group_id, db) - config_id = connected_config.get("memory_config_id") - - if config_id: - # Use the config_id to get the proper LLM client - config_service = MemoryConfigService(db) - memory_config = config_service.load_memory_config(config_id) - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client(memory_config.llm_model_id) - else: - # TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config - # Fallback to default LLM if no config found - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client(DEFAULT_LLM_ID) - except Exception as e: - print(f"Failed to get user connected config, using default LLM: {e}") - # TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config - # Fallback to default LLM - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client(DEFAULT_LLM_ID) + + # Use the config_id to get the proper LLM client + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config(config_id) + + if not memory_config.llm_model_id: + raise ValueError( + f"No llm_model_id found in memory config {config_id}. " + "Please configure a valid LLM model." + ) + + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(memory_config.llm_model_id) # 3. 构建Prompt tag_list_str = ", ".join(tags) @@ -107,33 +85,26 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]: # 在LLM失败时返回原始标签,确保流程继续 return tags -def get_db_connection(): - """ - 使用项目的标准配置方法建立与Neo4j数据库的连接。 - """ - # 从全局配置获取 Neo4j 连接信息 - uri = settings.NEO4J_URI - user = settings.NEO4J_USERNAME - - # 密码必须为了安全从环境变量加载 - password = os.getenv("NEO4J_PASSWORD") - - if not uri or not user: - raise ValueError("在 config.json 中未找到 Neo4j 的 'uri' 或 'username'。") - if not password: - raise ValueError("NEO4J_PASSWORD 环境变量未设置。") - - # 为此脚本使用同步驱动 - return GraphDatabase.driver(uri, auth=(user, password)) - -def get_raw_tags_from_db(group_id: str, limit: int, by_user: bool = False) -> List[Tuple[str, int]]: +async def get_raw_tags_from_db( + connector: Neo4jConnector, + group_id: str, + limit: int, + by_user: bool = False +) -> List[Tuple[str, int]]: """ + TODO: not accurate tag extraction 从数据库查询原始的、未经过滤的实体标签及其频率。 + + 使用项目的Neo4jConnector进行查询,遵循仓储模式。 Args: + connector: Neo4j连接器实例 group_id: 如果by_user=False,则为group_id;如果by_user=True,则为user_id limit: 返回的标签数量限制 by_user: 是否按user_id查询(默认False,按group_id查询) + + Returns: + List[Tuple[str, int]]: 标签名称和频率的元组列表 """ names_to_exclude = ['AI', 'Caroline', 'Melanie', 'Jon', 'Gina', '用户', 'AI助手', 'John', 'Maria'] @@ -154,83 +125,55 @@ def get_raw_tags_from_db(group_id: str, limit: int, by_user: bool = False) -> Li "LIMIT $limit" ) - driver = None - try: - driver = get_db_connection() - with driver.session() as session: - result = session.run(query, id=group_id, limit=limit, names_to_exclude=names_to_exclude) - return [(record["name"], record["frequency"]) for record in result] - finally: - if driver: - driver.close() + # 使用项目的Neo4jConnector执行查询 + results = await connector.execute_query( + query, + id=group_id, + limit=limit, + names_to_exclude=names_to_exclude + ) + + return [(record["name"], record["frequency"]) for record in results] -async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]: +async def get_hot_memory_tags(group_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]: """ 获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。 查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。 Args: - group_id: 如果by_user=False,则为group_id;如果by_user=True,则为user_id + group_id: 必需参数。如果by_user=False,则为group_id;如果by_user=True,则为user_id limit: 返回的标签数量限制 by_user: 是否按user_id查询(默认False,按group_id查询) + + Raises: + ValueError: 如果group_id未提供或为空 """ - # 默认从环境变量读取 - group_id = group_id or DEFAULT_GROUP_ID - # 1. 从数据库获取原始排名靠前的标签 - raw_tags_with_freq = get_raw_tags_from_db(group_id, limit, by_user=by_user) - if not raw_tags_with_freq: - return [] - - raw_tag_names = [tag for tag, freq in raw_tags_with_freq] - - # 2. 初始化LLM客户端并使用LLM筛选出有意义的标签 - meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id) - - # 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序) - final_tags = [] - for tag, freq in raw_tags_with_freq: - if tag in meaningful_tag_names: - final_tags.append((tag, freq)) - - return final_tags - -if __name__ == "__main__": - print("开始获取热门记忆标签...") + # 验证group_id必须提供且不为空 + if not group_id or not group_id.strip(): + raise ValueError( + "group_id is required. Please provide a valid group_id or user_id." + ) + + # 使用项目的Neo4jConnector + connector = Neo4jConnector() try: - # 直接使用环境变量中的 group_id - group_id_to_query = DEFAULT_GROUP_ID - # 使用 asyncio.run 来执行异步主函数 - top_tags = asyncio.run(get_hot_memory_tags(group_id=group_id_to_query)) + # 1. 从数据库获取原始排名靠前的标签 + raw_tags_with_freq = await get_raw_tags_from_db(connector, group_id, limit, by_user=by_user) + if not raw_tags_with_freq: + return [] - if top_tags: - print(f"热门记忆标签 (Group ID: {group_id_to_query}, 经LLM筛选):") - for tag, frequency in top_tags: - print(f"- {tag} (数量: {frequency})") + raw_tag_names = [tag for tag, freq in raw_tags_with_freq] - # --- 将结果写入统一的 Signboard.json 到 logs/memory-output --- - from app.core.config import settings - settings.ensure_memory_output_dir() - signboard_path = settings.get_memory_output_path("Signboard.json") - payload = { - "group_id": group_id_to_query, - "hot_tags": [{"name": t, "frequency": f} for t, f in top_tags] - } - try: - existing = {} - if os.path.exists(signboard_path): - with open(signboard_path, "r", encoding="utf-8") as rf: - existing = json.load(rf) - existing["hot_memory_tags"] = payload - with open(signboard_path, "w", encoding="utf-8") as wf: - json.dump(existing, wf, ensure_ascii=False, indent=2) - print(f"已写入 {signboard_path} -> hot_memory_tags") - except Exception as e: - print(f"写入 Signboard.json 失败: {e}") - else: - print(f"在 Group ID '{group_id_to_query}' 中没有找到符合条件的实体标签。") - except Exception as e: - print(f"执行过程中发生严重错误: {e}") - print("请检查:") - print("1. Neo4j数据库服务是否正在运行。") - print("2. 'config.json'中的配置是否正确。") - print("3. 相关的环境变量 (如 NEO4J_PASSWORD, DASHSCOPE_API_KEY) 是否已正确设置。") + # 2. 初始化LLM客户端并使用LLM筛选出有意义的标签 + meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id) + + # 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序) + final_tags = [] + for tag, freq in raw_tags_with_freq: + if tag in meaningful_tag_names: + final_tags.append((tag, freq)) + + return final_tags + finally: + # 确保关闭连接 + await connector.close() diff --git a/api/app/core/memory/src/search.py b/api/app/core/memory/src/search.py index ae2b9cfa..91e47eae 100644 --- a/api/app/core/memory/src/search.py +++ b/api/app/core/memory/src/search.py @@ -131,179 +131,60 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score") return results -# ============================================================================ -# 以下函数已被 rerank_with_activation 替代,暂时保留以供参考 -# ============================================================================ -# def rerank_hybrid_results( -# keyword_results: Dict[str, List[Dict[str, Any]]], -# embedding_results: Dict[str, List[Dict[str, Any]]], -# alpha: float = 0.6, -# limit: int = 10 -# ) -> Dict[str, List[Dict[str, Any]]]: -# """ -# Rerank hybrid search results by combining BM25 and embedding scores. -# -# 已废弃:此函数功能已被 rerank_with_activation 完全替代 -# -# Args: -# keyword_results: Results from keyword/BM25 search -# embedding_results: Results from embedding search -# alpha: Weight for BM25 scores (1-alpha for embedding scores) -# limit: Maximum number of results to return per category -# -# Returns: -# Reranked results with combined scores -# """ -# reranked = {} -# -# for category in ["statements", "chunks", "entities","summaries"]: -# keyword_items = keyword_results.get(category, []) -# embedding_items = embedding_results.get(category, []) -# -# # Normalize scores within each search type -# keyword_items = normalize_scores(keyword_items, "score") -# embedding_items = normalize_scores(embedding_items, "score") -# -# # Create a combined pool of unique items -# combined_items = {} -# -# # Add keyword results with BM25 scores -# for item in keyword_items: -# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") -# if item_id: -# combined_items[item_id] = item.copy() -# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0) -# combined_items[item_id]["embedding_score"] = 0 # Default -# -# # Add or update with embedding results -# for item in embedding_items: -# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") -# if item_id: -# if item_id in combined_items: -# # Update existing item with embedding score -# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0) -# else: -# # New item from embedding search only -# combined_items[item_id] = item.copy() -# combined_items[item_id]["bm25_score"] = 0 # Default -# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0) -# -# # Calculate combined scores and rank -# for item_id, item in combined_items.items(): -# bm25_score = item.get("bm25_score", 0) -# embedding_score = item.get("embedding_score", 0) -# -# # Combined score: weighted average of normalized scores -# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score -# item["combined_score"] = combined_score -# -# # Keep original score for reference -# if "score" not in item and bm25_score > 0: -# item["score"] = bm25_score -# elif "score" not in item and embedding_score > 0: -# item["score"] = embedding_score -# -# # Sort by combined score and limit results -# sorted_items = sorted( -# combined_items.values(), -# key=lambda x: x.get("combined_score", 0), -# reverse=True -# )[:limit] -# -# reranked[category] = sorted_items -# -# return reranked - -# def rerank_with_forgetting_curve( -# keyword_results: Dict[str, List[Dict[str, Any]]], -# embedding_results: Dict[str, List[Dict[str, Any]]], -# alpha: float = 0.6, -# limit: int = 10, -# forgetting_config: ForgettingEngineConfig | None = None, -# now: datetime | None = None, -# ) -> Dict[str, List[Dict[str, Any]]]: -# """ -# Rerank hybrid results with a forgetting curve applied to combined scores. -# -# 已废弃:此函数功能已被 rerank_with_activation 完全替代 -# rerank_with_activation 提供了更完整的遗忘曲线支持(结合激活度) -# -# The forgetting curve reduces scores for older memories or weaker connections. -# -# Args: -# keyword_results: Results from keyword/BM25 search -# embedding_results: Results from embedding search -# alpha: Weight for BM25 scores (1-alpha for embedding scores) -# limit: Maximum number of results to return per category -# forgetting_config: Configuration for the forgetting engine -# now: Optional current time override for testing -# -# Returns: -# Reranked results with combined and final scores (after forgetting) -# """ -# engine = ForgettingEngine(forgetting_config or ForgettingEngineConfig()) -# now_dt = now or datetime.now() -# -# reranked: Dict[str, List[Dict[str, Any]]] = {} -# -# for category in ["statements", "chunks", "entities","summaries"]: -# keyword_items = keyword_results.get(category, []) -# embedding_items = embedding_results.get(category, []) -# -# # Normalize scores within each search type -# keyword_items = normalize_scores(keyword_items, "score") -# embedding_items = normalize_scores(embedding_items, "score") -# -# combined_items: Dict[str, Dict[str, Any]] = {} -# -# # Combine two result sets by ID -# for src_items, is_embedding in ( -# (keyword_items, False), (embedding_items, True) -# ): -# for item in src_items: -# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") -# if not item_id: -# continue -# existing = combined_items.get(item_id) -# if not existing: -# combined_items[item_id] = item.copy() -# combined_items[item_id]["bm25_score"] = 0 -# combined_items[item_id]["embedding_score"] = 0 -# # Update normalized score from the right source -# if is_embedding: -# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0) -# else: -# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0) -# -# # Calculate scores and apply forgetting weights -# for item_id, item in combined_items.items(): -# bm25_score = float(item.get("bm25_score", 0) or 0) -# embedding_score = float(item.get("embedding_score", 0) or 0) -# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score -# -# # Estimate time elapsed in days -# dt = _parse_datetime(item.get("created_at")) -# if dt is None: -# time_elapsed_days = 0.0 -# else: -# time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0) -# -# # Memory strength (currently set to default value) -# memory_strength = 1.0 -# forgetting_weight = engine.calculate_weight( -# time_elapsed=time_elapsed_days, memory_strength=memory_strength -# ) -# final_score = combined_score * forgetting_weight -# item["combined_score"] = final_score -# -# sorted_items = sorted( -# combined_items.values(), key=lambda x: x.get("combined_score", 0), reverse=True -# )[:limit] -# -# reranked[category] = sorted_items -# -# return reranked +def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Remove duplicate items from search results based on content. + + Deduplication strategy: + 1. First try to deduplicate by ID (id, uuid, or chunk_id) + 2. Then deduplicate by content hash (text, content, statement, or name fields) + + Args: + items: List of search result items + + Returns: + Deduplicated list of items, preserving the order of first occurrence + """ + seen_ids = set() + seen_content = set() + deduplicated = [] + + for item in items: + # Try multiple ID fields to identify unique items + item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") + + # Extract content from various possible fields + content = ( + item.get("text") or + item.get("content") or + item.get("statement") or + item.get("name") or + "" + ) + + # Normalize content for comparison (strip whitespace and lowercase) + normalized_content = str(content).strip().lower() if content else "" + + # Check if we've seen this ID or content before + is_duplicate = False + + if item_id and item_id in seen_ids: + is_duplicate = True + elif normalized_content and normalized_content in seen_content: + # Only check content duplication if content is not empty + is_duplicate = True + + if not is_duplicate: + # Mark as seen + if item_id: + seen_ids.add(item_id) + if normalized_content: # Only track non-empty content + seen_content.add(normalized_content) + + deduplicated.append(item) + + return deduplicated def rerank_with_activation( @@ -364,7 +245,7 @@ def rerank_with_activation( keyword_items = normalize_scores(keyword_items, "score") embedding_items = normalize_scores(embedding_items, "score") - # 步骤 2: 按 ID 合并结果 + # 步骤 2: 按 ID 合并结果(去重) combined_items: Dict[str, Dict[str, Any]] = {} # 添加关键词结果 @@ -507,6 +388,9 @@ def rerank_with_activation( # 无激活值:使用内容相关性分数 item["final_score"] = item.get("base_score", 0) + # 最终去重确保没有重复项 + sorted_items = _deduplicate_results(sorted_items) + reranked[category] = sorted_items return reranked @@ -1144,96 +1028,3 @@ async def search_chunk_by_chunk_id( ) return {"chunks": chunks} - -# def main(): -# """Main entry point for the hybrid graph search CLI. - -# Parses command line arguments and executes search with specified parameters. -# Supports keyword, embedding, and hybrid search modes. -# """ -# parser = argparse.ArgumentParser(description="Hybrid graph search with keyword and embedding options") -# parser.add_argument( -# "--query", "-q", required=True, help="Free-text query to search" -# ) -# parser.add_argument( -# "--search-type", -# "-t", -# choices=["keyword", "embedding", "hybrid"], -# default="hybrid", -# help="Search type: keyword (text matching), embedding (semantic), or hybrid (both) (default: hybrid)" -# ) -# parser.add_argument( -# "--config-id", -# "-c", -# type=int, -# required=True, -# help="Database configuration ID (required)", -# ) -# parser.add_argument( -# "--group-id", -# "-g", -# default=None, -# help="Optional group_id to filter results (default: None)", -# ) -# parser.add_argument( -# "--limit", -# "-k", -# type=int, -# default=5, -# help="Max number of results per type (default: 5)", -# ) -# parser.add_argument( -# "--include", -# "-i", -# nargs="+", -# default=["statements", "chunks", "entities", "summaries"], -# choices=["statements", "chunks", "entities", "summaries"], -# help="Which targets to search for embedding search (default: statements chunks entities summaries)" -# ) -# parser.add_argument( -# "--output", -# "-o", -# default="search_results.json", -# help="Path to save the search results JSON (default: search_results.json)", -# ) -# parser.add_argument( -# "--rerank-alpha", -# "-a", -# type=float, -# default=0.6, -# help="Weight for BM25 scores in reranking (0.0-1.0, higher values favor keyword search) (default: 0.6)", -# ) -# parser.add_argument( -# "--forgetting-rerank", -# action="store_true", -# help="Apply forgetting curve during reranking for hybrid search.", -# ) -# parser.add_argument( -# "--llm-rerank", -# action="store_true", -# help="Apply LLM-based reranking for hybrid search.", -# ) -# args = parser.parse_args() - -# # Load memory config from database -# from app.services.memory_config_service import MemoryConfigService -# memory_config = MemoryConfigService.load_memory_config(args.config_id) - -# asyncio.run( -# run_hybrid_search( -# query_text=args.query, -# search_type=args.search_type, -# group_id=args.group_id, -# limit=args.limit, -# include=args.include, -# output_path=args.output, -# memory_config=memory_config, -# rerank_alpha=args.rerank_alpha, -# use_forgetting_rerank=args.forgetting_rerank, -# use_llm_rerank=args.llm_rerank, -# ) -# ) - - -# if __name__ == "__main__": -# main() diff --git a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py index bd3a9190..d39c9dbb 100644 --- a/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py +++ b/api/app/core/memory/storage_services/reflection_engine/self_reflexion.py @@ -18,13 +18,11 @@ from enum import Enum from typing import Any, Dict, List, Optional from app.core.memory.llm_tools.openai_client import OpenAIClient -from app.core.memory.utils.config import definitions as config_defs from app.core.memory.utils.config.get_data import ( extract_and_process_changes, get_data, get_data_statement, ) - from app.core.models.base import RedBearModelConfig from app.repositories.neo4j.cypher_queries import ( neo4j_query_all, diff --git a/api/app/core/memory/utils/config/__init__.py b/api/app/core/memory/utils/config/__init__.py index f69c13a2..9eef888c 100644 --- a/api/app/core/memory/utils/config/__init__.py +++ b/api/app/core/memory/utils/config/__init__.py @@ -14,28 +14,8 @@ from .config_utils import ( get_pruning_config, get_voice_config, ) - -# DEPRECATED: Global configuration variables removed -# Use MemoryConfig objects with dependency injection instead -# from .definitions import ( -# CONFIG, # DEPRECATED - empty dict for backward compatibility -# RUNTIME_CONFIG, # DEPRECATED - minimal for backward compatibility -# PROJECT_ROOT, # Still needed for file paths -# reload_configuration_from_database, # DEPRECATED - returns False -# ) -# DEPRECATED: overrides module removed - use MemoryConfig with dependency injection from .get_data import get_data -# litellm_config 需要时动态导入,避免循环依赖 -# from .litellm_config import ( -# LiteLLMConfig, -# setup_litellm_enhanced, -# get_usage_summary, -# print_usage_summary, -# get_instant_qps, -# print_instant_qps, -# ) - __all__ = [ # config_utils "get_model_config", @@ -45,18 +25,5 @@ __all__ = [ "get_pruning_config", "get_picture_config", "get_voice_config", - # definitions (DEPRECATED - use MemoryConfig objects instead) - # "CONFIG", # DEPRECATED - # "RUNTIME_CONFIG", # DEPRECATED - # "PROJECT_ROOT", - # "reload_configuration_from_database", # DEPRECATED - # get_data "get_data", - # litellm_config - 需要时从 .litellm_config 直接导入 - # "LiteLLMConfig", - # "setup_litellm_enhanced", - # "get_usage_summary", - # "print_usage_summary", - # "get_instant_qps", - # "print_instant_qps", ] diff --git a/api/app/core/memory/utils/config/config_optimization.py b/api/app/core/memory/utils/config/config_optimization.py deleted file mode 100644 index 41848a80..00000000 --- a/api/app/core/memory/utils/config/config_optimization.py +++ /dev/null @@ -1,398 +0,0 @@ -""" -配置管理优化模块 - -提供可选的配置管理优化功能,包括: -- LRU 缓存策略 -- 缓存预热 -- 缓存监控指标 -- 动态 TTL 策略 -- 配置版本控制 - -这些优化是可选的,当前的基础实现已经满足大多数需求。 -""" -import logging -import statistics -import threading -from collections import OrderedDict -from datetime import datetime, timedelta -from typing import Dict, Any, List, Optional, Tuple - -logger = logging.getLogger(__name__) - - -class LRUConfigCache: - """ - LRU(Least Recently Used)配置缓存 - - 当缓存达到最大容量时,自动淘汰最少使用的配置 - """ - - def __init__(self, max_size: int = 100, ttl: timedelta = timedelta(minutes=5)): - """ - 初始化 LRU 缓存 - - Args: - max_size: 最大缓存容量 - ttl: 缓存过期时间 - """ - self.max_size = max_size - self.ttl = ttl - self._cache: OrderedDict[str, Dict[str, Any]] = OrderedDict() - self._timestamps: Dict[str, datetime] = {} - self._lock = threading.RLock() - - # 统计信息 - self._stats = { - 'hits': 0, - 'misses': 0, - 'evictions': 0, - 'load_times': [] - } - - def get(self, config_id: str) -> Optional[Dict[str, Any]]: - """ - 获取配置(如果存在且未过期) - - Args: - config_id: 配置 ID - - Returns: - 配置字典,如果不存在或已过期则返回 None - """ - with self._lock: - if config_id not in self._cache: - self._stats['misses'] += 1 - return None - - # 检查是否过期 - timestamp = self._timestamps.get(config_id) - if timestamp and (datetime.now() - timestamp) >= self.ttl: - # 过期,移除 - self._cache.pop(config_id, None) - self._timestamps.pop(config_id, None) - self._stats['misses'] += 1 - return None - - # 命中,移动到末尾(标记为最近使用) - self._cache.move_to_end(config_id) - self._stats['hits'] += 1 - return self._cache[config_id] - - def put(self, config_id: str, config: Dict[str, Any]) -> None: - """ - 添加或更新配置 - - Args: - config_id: 配置 ID - config: 配置字典 - """ - with self._lock: - if config_id in self._cache: - # 更新现有配置 - self._cache.move_to_end(config_id) - else: - # 添加新配置 - if len(self._cache) >= self.max_size: - # 缓存已满,移除最旧的配置 - oldest_id, _ = self._cache.popitem(last=False) - self._timestamps.pop(oldest_id, None) - self._stats['evictions'] += 1 - logger.debug(f"[LRUCache] 淘汰配置: {oldest_id}") - - self._cache[config_id] = config - self._timestamps[config_id] = datetime.now() - - def clear(self, config_id: Optional[str] = None) -> None: - """ - 清除缓存 - - Args: - config_id: 如果指定,只清除该配置;否则清除所有 - """ - with self._lock: - if config_id: - self._cache.pop(config_id, None) - self._timestamps.pop(config_id, None) - else: - self._cache.clear() - self._timestamps.clear() - - def get_stats(self) -> Dict[str, Any]: - """ - 获取缓存统计信息 - - Returns: - 统计信息字典 - """ - with self._lock: - total = self._stats['hits'] + self._stats['misses'] - hit_rate = (self._stats['hits'] / total * 100) if total > 0 else 0 - - return { - 'cache_size': len(self._cache), - 'max_size': self.max_size, - 'total_requests': total, - 'cache_hits': self._stats['hits'], - 'cache_misses': self._stats['misses'], - 'evictions': self._stats['evictions'], - 'hit_rate': hit_rate, - 'avg_load_time': statistics.mean(self._stats['load_times']) if self._stats['load_times'] else 0 - } - - def record_load_time(self, load_time_ms: float) -> None: - """ - 记录加载时间 - - Args: - load_time_ms: 加载时间(毫秒) - """ - with self._lock: - self._stats['load_times'].append(load_time_ms) - # 只保留最近 1000 次的记录 - if len(self._stats['load_times']) > 1000: - self._stats['load_times'] = self._stats['load_times'][-1000:] - - -class ConfigCacheWarmer: - """ - 配置缓存预热器 - - 在系统启动时预加载常用配置,减少首次请求延迟 - """ - - @staticmethod - def warmup(config_ids: List[str], load_func) -> Dict[str, bool]: - """ - 预热缓存 - - Args: - config_ids: 要预加载的配置 ID 列表 - load_func: 配置加载函数 - - Returns: - 每个配置的加载结果 - """ - results = {} - - logger.info(f"[CacheWarmer] 开始预热 {len(config_ids)} 个配置") - - for config_id in config_ids: - try: - result = load_func(config_id) - results[config_id] = result - if result: - logger.debug(f"[CacheWarmer] 成功预热配置: {config_id}") - else: - logger.warning(f"[CacheWarmer] 预热配置失败: {config_id}") - except Exception as e: - logger.error(f"[CacheWarmer] 预热配置异常: {config_id}, 错误: {e}") - results[config_id] = False - - success_count = sum(1 for r in results.values() if r) - logger.info(f"[CacheWarmer] 预热完成: {success_count}/{len(config_ids)} 成功") - - return results - - -class DynamicTTLStrategy: - """ - 动态 TTL 策略 - - 根据配置类型和更新频率动态调整缓存过期时间 - """ - - # 预定义的 TTL 策略 - TTL_STRATEGIES = { - 'production': timedelta(minutes=30), # 生产配置较稳定 - 'staging': timedelta(minutes=15), # 预发布配置中等稳定 - 'development': timedelta(minutes=5), # 开发配置频繁变化 - 'testing': timedelta(minutes=1), # 测试配置快速过期 - 'default': timedelta(minutes=5) # 默认策略 - } - - @classmethod - def get_ttl(cls, config_id: str, config_type: Optional[str] = None) -> timedelta: - """ - 获取配置的 TTL - - Args: - config_id: 配置 ID - config_type: 配置类型(production/staging/development/testing) - - Returns: - TTL 时间间隔 - """ - if config_type and config_type in cls.TTL_STRATEGIES: - return cls.TTL_STRATEGIES[config_type] - - # 根据 config_id 推断类型 - if 'prod' in config_id.lower(): - return cls.TTL_STRATEGIES['production'] - elif 'stag' in config_id.lower(): - return cls.TTL_STRATEGIES['staging'] - elif 'dev' in config_id.lower(): - return cls.TTL_STRATEGIES['development'] - elif 'test' in config_id.lower(): - return cls.TTL_STRATEGIES['testing'] - - return cls.TTL_STRATEGIES['default'] - - -class ConfigVersionManager: - """ - 配置版本管理器 - - 跟踪配置版本,当配置更新时自动失效旧版本缓存 - """ - - def __init__(self): - self._versions: Dict[str, str] = {} - self._lock = threading.RLock() - - def get_version(self, config_id: str) -> Optional[str]: - """ - 获取配置版本 - - Args: - config_id: 配置 ID - - Returns: - 版本号,如果不存在则返回 None - """ - with self._lock: - return self._versions.get(config_id) - - def set_version(self, config_id: str, version: str) -> None: - """ - 设置配置版本 - - Args: - config_id: 配置 ID - version: 版本号 - """ - with self._lock: - old_version = self._versions.get(config_id) - self._versions[config_id] = version - - if old_version and old_version != version: - logger.info(f"[VersionManager] 配置版本更新: {config_id} {old_version} -> {version}") - - def check_version(self, config_id: str, cached_version: Optional[str]) -> bool: - """ - 检查缓存版本是否有效 - - Args: - config_id: 配置 ID - cached_version: 缓存的版本号 - - Returns: - True 如果版本匹配,False 如果版本不匹配或不存在 - """ - with self._lock: - current_version = self._versions.get(config_id) - - if not current_version or not cached_version: - return False - - return current_version == cached_version - - def invalidate(self, config_id: str) -> None: - """ - 使配置版本失效 - - Args: - config_id: 配置 ID - """ - with self._lock: - if config_id in self._versions: - # 生成新版本号 - import uuid - new_version = str(uuid.uuid4()) - self._versions[config_id] = new_version - logger.info(f"[VersionManager] 配置版本失效: {config_id} -> {new_version}") - - -class CacheMonitor: - """ - 缓存监控器 - - 提供缓存性能监控和报告功能 - """ - - def __init__(self, cache: LRUConfigCache): - self.cache = cache - - def get_report(self) -> str: - """ - 生成缓存性能报告 - - Returns: - 格式化的报告字符串 - """ - stats = self.cache.get_stats() - - report = f""" -配置缓存性能报告 -================ -缓存容量: {stats['cache_size']}/{stats['max_size']} -总请求数: {stats['total_requests']} -缓存命中: {stats['cache_hits']} -缓存未命中: {stats['cache_misses']} -缓存命中率: {stats['hit_rate']:.2f}% -淘汰次数: {stats['evictions']} -平均加载时间: {stats['avg_load_time']:.2f}ms -""" - return report - - def log_stats(self) -> None: - """记录统计信息到日志""" - stats = self.cache.get_stats() - logger.info( - f"[CacheMonitor] 缓存统计 - " - f"容量: {stats['cache_size']}/{stats['max_size']}, " - f"命中率: {stats['hit_rate']:.2f}%, " - f"淘汰: {stats['evictions']}" - ) - - -# 使用示例 -def example_usage(): - """ - 优化功能使用示例 - """ - # 1. 使用 LRU 缓存 - lru_cache = LRUConfigCache(max_size=100, ttl=timedelta(minutes=5)) - - # 获取配置 - config = lru_cache.get("config_001") - if config is None: - # 缓存未命中,从数据库加载 - config = {"llm_name": "openai/gpt-4"} - lru_cache.put("config_001", config) - - # 2. 预热缓存 - def load_config(config_id): - # 实际的配置加载逻辑 - return True - - warmer = ConfigCacheWarmer() - results = warmer.warmup(["config_001", "config_002"], load_config) - - # 3. 动态 TTL - ttl = DynamicTTLStrategy.get_ttl("prod_config_001", "production") - print(f"TTL: {ttl}") - - # 4. 版本管理 - version_manager = ConfigVersionManager() - version_manager.set_version("config_001", "v1.0.0") - - # 检查版本 - is_valid = version_manager.check_version("config_001", "v1.0.0") - - # 5. 监控 - monitor = CacheMonitor(lru_cache) - print(monitor.get_report()) - - -if __name__ == "__main__": - example_usage() diff --git a/api/app/core/memory/utils/config/definitions.py b/api/app/core/memory/utils/config/definitions.py deleted file mode 100644 index fc07c2cc..00000000 --- a/api/app/core/memory/utils/config/definitions.py +++ /dev/null @@ -1,268 +0,0 @@ -# """ -# 配置加载模块 - DEPRECATED - -# ⚠️ DEPRECATION NOTICE ⚠️ -# This module is deprecated and will be removed in a future version. -# Global configuration variables have been eliminated in favor of dependency injection. - -# Use the new MemoryConfig system instead: -# - app.schemas.memory_config_schema.MemoryConfig for configuration objects -# - config_service = MemoryConfigService(db); config_service.load_memory_config(config_id) - -# 阶段 1: 从 runtime.json 加载配置(路径 A)- DEPRECATED -# 阶段 2: 从数据库加载配置(路径 B,基于 dbrun.json 中的 config_id)- DEPRECATED -# 阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点)- DEPRECATED -# """ -# import json -# import os -# import threading -# from datetime import datetime, timedelta -# from typing import Any, Dict, Optional - -# #TODO: Fix this - -# try: -# from dotenv import load_dotenv -# load_dotenv() -# except Exception: -# pass - -# # Import unified configuration system -# try: -# from app.core.config import settings -# USE_UNIFIED_CONFIG = True -# except ImportError: -# USE_UNIFIED_CONFIG = False -# settings = None - -# # PROJECT_ROOT 应该指向 app/core/memory/ 目录 -# # __file__ = app/core/memory/utils/config/definitions.py -# # os.path.dirname(__file__) = app/core/memory/utils/config -# # os.path.dirname(...) = app/core/memory/utils -# # os.path.dirname(...) = app/core/memory -# PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -# # DEPRECATED: Global configuration lock removed -# # Use MemoryConfig objects with dependency injection instead - -# # DEPRECATED: Legacy config.json loading removed -# # Use MemoryConfig objects with dependency injection instead -# CONFIG = {} - -# DEFAULT_VALUES = { -# "llm_name": "openai/qwen-plus", -# "embedding_name": "openai/nomic-embed-text:v1.5", -# "chunker_strategy": "RecursiveChunker", -# "group_id": "group_123", -# "user_id": "default_user", -# "apply_id": "default_apply", -# "llm_agent_name": "openai/qwen-plus", -# "llm_verify_name": "openai/qwen-plus", -# "llm_image_recognition": "openai/qwen-plus", -# "llm_voice_recognition": "openai/qwen-plus", -# "prompt_level": "DEBUG", -# "reflexion_iteration_period": "3", -# "reflexion_range": "retrieval", -# "reflexion_baseline": "TIME", -# } - -# # DEPRECATED: Legacy global variables for backward compatibility only -# # These will be removed in a future version -# # Use MemoryConfig objects with dependency injection instead -# # LANGFUSE_ENABLED = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true" -# # SELECTED_LLM_ID = os.getenv("SELECTED_LLM_ID", DEFAULT_VALUES["llm_name"]) - - -# # 阶段 1: 从 runtime.json 加载配置(路径 A) -# def _load_from_runtime_json() -> Dict[str, Any]: -# """ -# DEPRECATED: Legacy runtime.json loading - -# ⚠️ This function is deprecated and will be removed in a future version. -# Use MemoryConfig objects with dependency injection instead. - -# Returns: -# Dict[str, Any]: Empty configuration (legacy support only) -# """ -# import warnings -# warnings.warn( -# "Runtime JSON loading is deprecated. Use MemoryConfig objects with dependency injection instead.", -# DeprecationWarning, -# stacklevel=2 -# ) -# return {"selections": {}} - - -# # 阶段 2: 从数据库加载配置(路径 B)- 已整合到统一加载器 -# # 注意:此函数已被 _load_from_runtime_json 中的统一配置加载器替代 -# # 保留此函数仅为向后兼容 -# def _load_from_database() -> Optional[Dict[str, Any]]: -# """ -# DEPRECATED: Legacy database configuration loading - -# ⚠️ This function is deprecated and will be removed in a future version. -# Use MemoryConfig objects with dependency injection instead. - -# Returns: -# Optional[Dict[str, Any]]: None (deprecated functionality) -# """ -# import warnings -# warnings.warn( -# "Database configuration loading is deprecated. Use MemoryConfig objects with dependency injection instead.", -# DeprecationWarning, -# stacklevel=2 -# ) -# return None - - -# # 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)- DEPRECATED -# def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None: -# """ -# DEPRECATED: 将运行时配置暴露为全局常量供项目使用 - -# ⚠️ This function is deprecated and will be removed in a future version. -# Global configuration variables have been eliminated in favor of dependency injection. - -# Use the new MemoryConfig system instead: -# - app.core.memory_config.config.MemoryConfig for configuration objects -# - Pass configuration objects as parameters instead of using global variables - -# Args: -# runtime_cfg: 运行时配置字典 -# """ -# import warnings -# warnings.warn( -# "Global configuration variables are deprecated. Use MemoryConfig objects with dependency injection instead.", -# DeprecationWarning, -# stacklevel=2 -# ) - -# # Keep minimal global state for backward compatibility only -# # These will be removed in a future version -# global RUNTIME_CONFIG, SELECTIONS - -# RUNTIME_CONFIG = runtime_cfg -# SELECTIONS = RUNTIME_CONFIG.get("selections", {}) - -# # All other global variables have been removed -# # Use MemoryConfig objects instead - - -# # 初始化:使用统一配置加载器 -# def _initialize_configuration() -> None: -# """ -# DEPRECATED: Legacy configuration initialization - -# ⚠️ This function is deprecated and will be removed in a future version. -# Use MemoryConfig objects with dependency injection instead. -# """ -# import warnings -# warnings.warn( -# "Global configuration initialization is deprecated. Use MemoryConfig objects with dependency injection instead.", -# DeprecationWarning, -# stacklevel=2 -# ) -# # Initialize with empty configuration for backward compatibility -# _expose_runtime_constants({"selections": {}}) - - -# # 模块加载时自动初始化配置 -# _initialize_configuration() - -# # DEPRECATED: Global variables removed -# # These variables have been eliminated in favor of dependency injection -# # Use MemoryConfig objects instead of accessing global variables - - -# # 公共 API:动态重新加载配置 -# def reload_configuration_from_database(config_id, force_reload: bool = False) -> bool: -# """ -# DEPRECATED: Legacy configuration reloading - -# ⚠️ This function is deprecated and will be removed in a future version. -# Use MemoryConfig objects with dependency injection instead. - -# For new code, use: -# - app.services.memory_agent_service.MemoryAgentService.load_memory_config() -# - app.services.memory_storage_service.MemoryStorageService.load_memory_config() - -# Args: -# config_id: Configuration ID (deprecated) -# force_reload: Force reload flag (deprecated) - -# Returns: -# bool: Always returns False (deprecated functionality) -# """ -# import logging -# import warnings - -# logger = logging.getLogger(__name__) - -# warnings.warn( -# "reload_configuration_from_database is deprecated. Use MemoryConfig objects with dependency injection instead.", -# DeprecationWarning, -# stacklevel=2 -# ) - -# logger.warning(f"Deprecated function reload_configuration_from_database called with config_id={config_id}. " -# "Use MemoryConfig objects with dependency injection instead.") - -# return False - - - - - -# def get_current_config_id() -> Optional[str]: -# """ -# DEPRECATED: Legacy config ID retrieval - -# ⚠️ This function is deprecated and will be removed in a future version. -# Use MemoryConfig objects with dependency injection instead. - -# Returns: -# Optional[str]: None (deprecated functionality) -# """ -# import warnings -# warnings.warn( -# "get_current_config_id is deprecated. Use MemoryConfig objects with dependency injection instead.", -# DeprecationWarning, -# stacklevel=2 -# ) -# return None - - -# def ensure_fresh_config(config_id = None) -> bool: -# """ -# DEPRECATED: Legacy configuration freshness check - -# ⚠️ This function is deprecated and will be removed in a future version. -# Use MemoryConfig objects with dependency injection instead. - -# For new code, use: -# - app.services.memory_agent_service.MemoryAgentService.load_memory_config() -# - app.services.memory_storage_service.MemoryStorageService.load_memory_config() - -# Args: -# config_id: Configuration ID (deprecated) - -# Returns: -# bool: Always returns False (deprecated functionality) -# """ -# import logging -# import warnings - -# logger = logging.getLogger(__name__) - -# warnings.warn( -# "ensure_fresh_config is deprecated. Use MemoryConfig objects with dependency injection instead.", -# DeprecationWarning, -# stacklevel=2 -# ) - -# logger.warning(f"Deprecated function ensure_fresh_config called with config_id={config_id}. " -# "Use MemoryConfig objects with dependency injection instead.") - -# return False - - diff --git a/api/app/core/memory/utils/config/get_example_data.py b/api/app/core/memory/utils/config/get_example_data.py deleted file mode 100644 index c466645b..00000000 --- a/api/app/core/memory/utils/config/get_example_data.py +++ /dev/null @@ -1,90 +0,0 @@ -import os -import re -import uuid -import random -import string -from typing import List, Dict, Optional - -# 生成包含字母(大小写)和数字的随机字符串 -def generate_random_string(length=16): - characters = string.ascii_letters + string.digits - return ''.join(random.choice(characters) for _ in range(length)) - -def get_example_data() -> List[Dict[str, Optional[str]]]: - """ - 从句子提取日志中获取数据 - Content: 在苹果公司中国总部,用户和李华偶遇了从美国来的技术专家约翰·史密斯。 - Created At: 2025-11-28 19:28:38.256421 - Expired At: None - Valid At: None - Invalid At: None - 将数据构造成如下形式: - [ - { - "id":id, - "group_id":group_id, - "statement": Content, - "created_at": Created At, - "expired_at": Expired At, - "valid_at": Valid At, - "invalid_at": Invalid At, - "chunk_id": "86da9022710c40eaa5f518a294c398d2", - "entity_ids": [] - }, - ... - ] - """ - # 获取日志文件路径 - log_file_path = os.path.join("logs", "memory-output", "statement_extraction.txt") - - # 检查文件是否存在 - if not os.path.exists(log_file_path): - return [] - - # 读取日志文件 - with open(log_file_path, "r", encoding="utf-8") as f: - content = f.read() - - # 解析数据 - results = [] - - # 使用正则表达式分割每个 Statement - statement_blocks = re.split(r"Statement \d+:", content) - - for block in statement_blocks[1:]: # 跳过第一个空块 - # 提取各个字段 - id_match = re.search(r"Id:\s*(.+?)(?=\n)", block) - group_id_match = re.search(r"Group Id:\s*(.+?)(?=\n)", block) - statement_match = re.search(r"Content:\s*(.+?)(?=\n)", block) - created_at_match = re.search(r"Created At:\s*(.+?)(?=\n)", block) - expired_at_match = re.search(r"Expired At:\s*(.+?)(?=\n)", block) - valid_at_match = re.search(r"Valid At:\s*(.+?)(?=\n)", block) - invalid_at_match = re.search(r"Invalid At:\s*(.+?)(?=\n)", block) - chunk_id_match = re.search(r"Chunk Id:\s*(.+?)(?=\n)", block) - - # 构造字典 - if statement_match: - statement_data = { - "id": id_match.group(1).strip() if id_match else generate_random_string(), - "group_id": group_id_match.group(1).strip() if group_id_match else "group_example", - "statement": statement_match.group(1).strip(), - "created_at": created_at_match.group(1).strip() if created_at_match else None, - "expired_at": expired_at_match.group(1).strip() if expired_at_match else None, - "valid_at": valid_at_match.group(1).strip() if valid_at_match else None, - "invalid_at": invalid_at_match.group(1).strip() if invalid_at_match else None, - "chunk_id": chunk_id_match.group(1).strip() if chunk_id_match else "chunk_example", - "entity_ids": [] - } - - # 将 "None" 字符串转换为 None - for key in ["created_at", "expired_at", "valid_at", "invalid_at"]: - if statement_data[key] == "None": - statement_data[key] = None - - results.append(statement_data) - - return results - - -if __name__ == "__main__": - print(f"获取数据如下:\n {get_example_data()}") \ No newline at end of file diff --git a/api/app/core/memory/utils/config/litellm_config.py b/api/app/core/memory/utils/config/litellm_config.py deleted file mode 100644 index dbf991a8..00000000 --- a/api/app/core/memory/utils/config/litellm_config.py +++ /dev/null @@ -1,516 +0,0 @@ -""" -LiteLLM Configuration for Enhanced Retry Logic and Usage Tracking with Native QPS Monitoring -""" - -import litellm -from typing import Dict, Any, List -import json -from datetime import datetime, timedelta -import os -import time -from collections import defaultdict -import threading -from queue import Queue - -class LiteLLMConfig: - """Configuration class for LiteLLM with enhanced retry and tracking capabilities""" - - def __init__(self): - self.usage_data = [] - self.error_data = [] - self.module_stats = defaultdict(lambda: { - 'requests': 0, - 'tokens_in': 0, - 'tokens_out': 0, - 'cost': 0.0, - 'errors': 0, - 'start_time': None, - 'last_request_time': None, - 'request_timestamps': [], # Store precise timestamps - 'current_qps': 0.0, - 'max_qps': 0.0, - 'qps_history': [] # Store QPS measurements over time - }) - self.start_time = datetime.now() - self.global_request_timestamps = [] - self.global_max_qps = 0.0 - - # Rate limiting for AWS Bedrock (conservative limits) - self.rate_limits = { - 'bedrock': { - 'requests_per_minute': 2, # AWS Bedrock default is very low - 'requests_per_second': 0.033, # 2/60 = 0.033 RPS - 'last_request_time': 0, - 'request_queue': Queue(), - 'lock': threading.Lock() - } - } - self.rate_limiting_enabled = True - - def setup_enhanced_config(self, max_retries: int = 3): - """Configure LiteLLM with retry logic and instant QPS tracking""" - - litellm.num_retries = max_retries - litellm.request_timeout = 300 - - litellm.retry_policy = { - "RateLimitError": { - "max_retries": 5, - "exponential_backoff": True, - "initial_delay": 1, - "max_delay": 60, - "jitter": True - }, - "APIConnectionError": { - "max_retries": 3, - "exponential_backoff": True, - "initial_delay": 2, - "max_delay": 30, - "jitter": True - }, - "InternalServerError": { - "max_retries": 2, - "exponential_backoff": True, - "initial_delay": 5, - "max_delay": 60, - "jitter": True - }, - "BadRequestError": { - "max_retries": 1, - "exponential_backoff": False, - "initial_delay": 1, - "max_delay": 5 - } - } - - litellm.success_callback = [self._success_callback] - litellm.failure_callback = [self._failure_callback] - litellm.completion_cost_tracking = True - litellm.set_verbose = False - litellm.modify_params = True - - print("✅ LiteLLM configured with instant QPS tracking and rate limiting") - - def _success_callback(self, kwargs, completion_response, start_time, end_time): - """Callback for successful requests with module-specific QPS tracking""" - try: - # Extract usage information - usage = completion_response.get('usage', {}) - model = kwargs.get('model', 'unknown') - - # Extract module information from metadata or model name - module = self._extract_module_name(kwargs, model) - - # Calculate cost - cost = 0.0 - try: - cost = litellm.completion_cost(completion_response) - except: - pass - - # Calculate duration - duration_seconds = (end_time - start_time).total_seconds() if hasattr(end_time - start_time, 'total_seconds') else float(end_time - start_time) - - # Record usage data - usage_record = { - "timestamp": datetime.now().isoformat(), - "model": model, - "module": module, - "input_tokens": usage.get('prompt_tokens', 0), - "output_tokens": usage.get('completion_tokens', 0), - "total_tokens": usage.get('total_tokens', 0), - "cost": cost, - "duration_seconds": duration_seconds, - "status": "success" - } - - self.usage_data.append(usage_record) - - # Update module-specific stats for QPS tracking - self._update_module_stats(module, usage_record, success=True) - - # Print real-time feedback - print(f"✓ {model}: {usage_record['input_tokens']}→{usage_record['output_tokens']} tokens, ${cost:.4f}, {usage_record['duration_seconds']:.2f}s") - - except Exception as e: - print(f"Warning: Success callback failed: {e}") - - def _failure_callback(self, kwargs, completion_response, start_time, end_time): - """Callback for failed requests with module-specific error tracking""" - try: - model = kwargs.get('model', 'unknown') - module = self._extract_module_name(kwargs, model) - - duration_seconds = (end_time - start_time).total_seconds() if hasattr(end_time - start_time, 'total_seconds') else float(end_time - start_time) - - # Handle different error response formats - error_message = "Unknown error" - error_type = "UnknownError" - - # According to LiteLLM docs, completion_response contains the exception for failures - if completion_response is not None: - error_message = str(completion_response) - error_type = type(completion_response).__name__ - - # Also check kwargs for exception (LiteLLM passes exception in kwargs for failure events) - elif 'exception' in kwargs: - exception = kwargs['exception'] - error_message = str(exception) - error_type = type(exception).__name__ - - # Check for other error formats in kwargs - elif 'error' in kwargs: - error = kwargs['error'] - error_message = str(error) - error_type = type(error).__name__ - - # Check log_event_type to confirm this is a failure event - log_event_type = kwargs.get('log_event_type', '') - if log_event_type == 'failed_api_call' and 'exception' in kwargs: - exception = kwargs['exception'] - error_message = str(exception) - error_type = type(exception).__name__ - - error_record = { - "timestamp": datetime.now().isoformat(), - "model": model, - "module": module, - "error": error_message, - "error_type": error_type, - "duration_seconds": duration_seconds, - "status": "failed" - } - - self.error_data.append(error_record) - - # Update module-specific stats for error tracking - self._update_module_stats(module, error_record, success=False) - - # Print error feedback - print(f"✗ {model}: {error_type} - {error_message[:100]}") - - except Exception as e: - print(f"Warning: Failure callback failed: {e}") - # Debug: print the actual parameters to understand the structure - print(f"Debug - kwargs keys: {list(kwargs.keys()) if kwargs else 'None'}") - print(f"Debug - completion_response type: {type(completion_response)}") - print(f"Debug - completion_response: {completion_response}") - - def _should_rate_limit(self, model: str) -> bool: - """Check if the model should be rate limited""" - if not self.rate_limiting_enabled: - return False - return model.startswith('bedrock/') or 'bedrock' in model.lower() - - def _enforce_rate_limit(self, model: str): - """Enforce rate limiting for AWS Bedrock models""" - if not self._should_rate_limit(model): - return - - provider = 'bedrock' - if provider not in self.rate_limits: - return - - rate_config = self.rate_limits[provider] - - with rate_config['lock']: - current_time = time.time() - time_since_last = current_time - rate_config['last_request_time'] - min_interval = 1.0 / rate_config['requests_per_second'] - - if time_since_last < min_interval: - sleep_time = min_interval - time_since_last - print(f"⏳ Rate limiting: sleeping {sleep_time:.2f}s for {model}") - time.sleep(sleep_time) - - rate_config['last_request_time'] = time.time() - - def _extract_module_name(self, kwargs: Dict[str, Any], model: str) -> str: - """Extract module name from request context""" - # Try to get module from metadata - metadata = kwargs.get('metadata', {}) - if 'module' in metadata: - return metadata['module'] - - # Try to infer from model name or other context - if 'claude' in model.lower(): - return 'bedrock_client' - elif 'gpt' in model.lower() or 'openai' in model.lower(): - return 'openai_client' - elif 'embed' in model.lower(): - return 'embedder' - else: - return 'unknown' - - def _update_module_stats(self, module: str, record: Dict[str, Any], success: bool): - """Update module-specific statistics with instant QPS tracking""" - current_timestamp = time.time() - current_time = datetime.now() - - # Initialize module stats if first request - if self.module_stats[module]['start_time'] is None: - self.module_stats[module]['start_time'] = current_time - - # Update counters - self.module_stats[module]['requests'] += 1 - self.module_stats[module]['last_request_time'] = current_time - self.module_stats[module]['request_timestamps'].append(current_timestamp) - self.global_request_timestamps.append(current_timestamp) - - # Calculate instant QPS for this module - self._calculate_instant_qps(module, current_timestamp) - - # Calculate global instant QPS - self._calculate_global_instant_qps(current_timestamp) - - if success: - self.module_stats[module]['tokens_in'] += record.get('input_tokens', 0) - self.module_stats[module]['tokens_out'] += record.get('output_tokens', 0) - self.module_stats[module]['cost'] += record.get('cost', 0.0) - else: - self.module_stats[module]['errors'] += 1 - - def _calculate_instant_qps(self, module: str, current_timestamp: float): - """Calculate instant QPS for a specific module using sliding window""" - # Keep only timestamps from last 1 second for instant QPS - cutoff_time = current_timestamp - 1.0 - timestamps = self.module_stats[module]['request_timestamps'] - - # Remove old timestamps - self.module_stats[module]['request_timestamps'] = [ - ts for ts in timestamps if ts >= cutoff_time - ] - - # Calculate current QPS (requests in last second) - current_qps = len(self.module_stats[module]['request_timestamps']) - self.module_stats[module]['current_qps'] = current_qps - - # Update max QPS if current is higher - if current_qps > self.module_stats[module]['max_qps']: - self.module_stats[module]['max_qps'] = current_qps - - # Store QPS history (keep last 60 measurements) - self.module_stats[module]['qps_history'].append(current_qps) - if len(self.module_stats[module]['qps_history']) > 60: - self.module_stats[module]['qps_history'].pop(0) - - def _calculate_global_instant_qps(self, current_timestamp: float): - """Calculate global instant QPS across all modules""" - # Keep only timestamps from last 1 second - cutoff_time = current_timestamp - 1.0 - self.global_request_timestamps = [ - ts for ts in self.global_request_timestamps if ts >= cutoff_time - ] - - # Calculate current global QPS - current_global_qps = len(self.global_request_timestamps) - - # Update max global QPS - if current_global_qps > self.global_max_qps: - self.global_max_qps = current_global_qps - - def get_instant_qps(self, module: str = None) -> Dict[str, Any]: - """Get instant QPS data for modules""" - if module: - if module in self.module_stats: - return { - 'module': module, - 'current_qps': self.module_stats[module]['current_qps'], - 'max_qps': self.module_stats[module]['max_qps'], - 'avg_qps_last_minute': sum(self.module_stats[module]['qps_history'][-60:]) / min(60, len(self.module_stats[module]['qps_history'])) if self.module_stats[module]['qps_history'] else 0 - } - else: - return {'module': module, 'current_qps': 0, 'max_qps': 0, 'avg_qps_last_minute': 0} - else: - # Return data for all modules plus global - result = { - 'global': { - 'current_qps': len([ts for ts in self.global_request_timestamps if ts >= time.time() - 1.0]), - 'max_qps': self.global_max_qps - }, - 'modules': {} - } - - for mod in self.module_stats: - result['modules'][mod] = { - 'current_qps': self.module_stats[mod]['current_qps'], - 'max_qps': self.module_stats[mod]['max_qps'], - 'avg_qps_last_minute': sum(self.module_stats[mod]['qps_history'][-60:]) / min(60, len(self.module_stats[mod]['qps_history'])) if self.module_stats[mod]['qps_history'] else 0 - } - - return result - - def get_usage_summary(self) -> Dict[str, Any]: - """Get essential usage statistics""" - if not self.usage_data: - return { - "total_requests": 0, - "total_cost": 0.0, - "error_rate": 0.0, - "message": "No usage data available" - } - - total_requests = len(self.usage_data) - total_errors = len(self.error_data) - total_cost = sum(record['cost'] for record in self.usage_data) - total_input_tokens = sum(record['input_tokens'] for record in self.usage_data) - total_output_tokens = sum(record['output_tokens'] for record in self.usage_data) - - # Calculate session duration - duration_minutes = (datetime.now() - self.start_time).total_seconds() / 60 - - # Build module statistics - module_stats = {} - for module, stats in self.module_stats.items(): - if stats['requests'] > 0: - module_stats[module] = { - "requests": stats['requests'], - "errors": stats['errors'], - "success_rate": ((stats['requests'] - stats['errors']) / stats['requests'] * 100) if stats['requests'] > 0 else 0, - "tokens_in": stats['tokens_in'], - "tokens_out": stats['tokens_out'], - "cost": stats['cost'], - "current_qps": stats['current_qps'], - "max_qps": stats['max_qps'] - } - - return { - "session_duration_minutes": duration_minutes, - "total_requests": total_requests, - "total_errors": total_errors, - "error_rate": (total_errors / total_requests * 100) if total_requests > 0 else 0, - "total_input_tokens": total_input_tokens, - "total_output_tokens": total_output_tokens, - "total_cost": total_cost, - "module_stats": module_stats, - "global_max_qps": self.global_max_qps - } - - def print_usage_summary(self): - """Print essential usage summary""" - stats = self.get_usage_summary() - - if stats.get('message'): - print(f"📊 {stats['message']}") - return - - print("\n📊 USAGE SUMMARY") - print(f"{'='*50}") - print(f"⏱️ Duration: {stats['session_duration_minutes']:.1f} min") - print(f"📈 Requests: {stats['total_requests']}") - print(f"❌ Errors: {stats['total_errors']}") - print(f"💰 Cost: ${stats['total_cost']:.4f}") - print(f"🏆 Global Max QPS: {stats['global_max_qps']}") - - # Module statistics - if stats.get('module_stats'): - print("\n📦 MODULES:") - for module, mod_stats in stats['module_stats'].items(): - print(f" {module}: {mod_stats['requests']} req, Max QPS: {mod_stats['max_qps']}, Current: {mod_stats['current_qps']}") - - print(f"{'='*50}") - - def save_usage_data(self, filename: str = "litellm_usage.json"): - """Save usage data to JSON file""" - data = { - "summary": self.get_usage_summary(), - "detailed_usage": self.usage_data, - "errors": self.error_data, - "export_timestamp": datetime.now().isoformat() - } - - with open(filename, 'w') as f: - json.dump(data, f, indent=2) - - print(f"📁 Usage data saved to {filename}") - - def reset_tracking(self): - """Reset all tracking data""" - self.usage_data = [] - self.error_data = [] - self.module_stats = defaultdict(lambda: { - 'requests': 0, - 'tokens_in': 0, - 'tokens_out': 0, - 'cost': 0.0, - 'errors': 0, - 'start_time': None, - 'last_request_time': None, - 'request_timestamps': [], - 'current_qps': 0.0, - 'max_qps': 0.0, - 'qps_history': [] - }) - self.global_request_timestamps = [] - self.global_max_qps = 0.0 - self.start_time = datetime.now() - print("🔄 All tracking data reset") - -# Global instance for easy access -litellm_config = LiteLLMConfig() - -def setup_litellm_enhanced(max_retries: int = 3): - """ - Quick setup function for LiteLLM enhanced configuration - - Args: - max_retries: Maximum number of retries for failed requests - """ - litellm_config.setup_enhanced_config(max_retries) - return litellm_config - -def get_usage_summary(): - """Get current usage summary""" - return litellm_config.get_usage_summary() - -def print_usage_summary(): - """Print current usage summary""" - litellm_config.print_usage_summary() - -def save_usage_data(filename: str = "litellm_usage.json"): - """Save usage data to file""" - litellm_config.save_usage_data(filename) - -def get_instant_qps(module: str = None) -> Dict[str, Any]: - """Get instant QPS data for modules""" - return litellm_config.get_instant_qps(module) - -def print_instant_qps(module: str = None): - """Print instant QPS information""" - qps_data = get_instant_qps(module) - - print("\n⚡ INSTANT QPS MONITOR") - print(f"{'='*60}") - - if module: - print(f"Module: {qps_data['module']}") - print(f" Current QPS: {qps_data['current_qps']}") - print(f" Max QPS: {qps_data['max_qps']}") - print(f" Avg (1min): {qps_data['avg_qps_last_minute']:.2f}") - else: - # Global stats - global_data = qps_data.get('global', {}) - print("🌍 GLOBAL:") - print(f" Current QPS: {global_data.get('current_qps', 0)}") - print(f" Max QPS: {global_data.get('max_qps', 0)}") - - # Module stats - modules = qps_data.get('modules', {}) - if modules: - print("\n📦 MODULES:") - for mod, data in modules.items(): - print(f" {mod}:") - print(f" Current: {data['current_qps']} QPS") - print(f" Max: {data['max_qps']} QPS") - print(f" Avg: {data['avg_qps_last_minute']:.2f} QPS") - - print(f"{'='*60}") - -def reset_tracking(): - """Reset all tracking data""" - litellm_config.reset_tracking() - -def get_module_stats() -> Dict[str, Dict[str, Any]]: - """Get detailed module statistics""" - summary = get_usage_summary() - return summary.get('module_stats', {}) diff --git a/api/app/core/memory/utils/self_reflexion_utils/__init__.py b/api/app/core/memory/utils/self_reflexion_utils/__init__.py deleted file mode 100644 index 422a83e3..00000000 --- a/api/app/core/memory/utils/self_reflexion_utils/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# -*- coding: utf-8 -*- -"""自我反思工具模块 - -本模块提供自我反思引擎的核心功能,包括: -- 记忆冲突判定 -- 反思执行 -- 记忆更新 - -从 app.core.memory.src.data_config_api 迁移而来。 -""" - -from app.core.memory.utils.self_reflexion_utils.evaluate import conflict -from app.core.memory.utils.self_reflexion_utils.reflexion import reflexion -from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion - -__all__ = ["conflict", "reflexion", "self_reflexion"] diff --git a/api/app/core/memory/utils/self_reflexion_utils/evaluate.py b/api/app/core/memory/utils/self_reflexion_utils/evaluate.py deleted file mode 100644 index 4d1835cd..00000000 --- a/api/app/core/memory/utils/self_reflexion_utils/evaluate.py +++ /dev/null @@ -1,52 +0,0 @@ -# -*- coding: utf-8 -*- -"""记忆冲突判定模块 - -本模块提供记忆冲突判定功能,使用LLM判断记忆数据中是否存在冲突。 -从 app.core.memory.src.data_config_api.evaluate 迁移而来。 -""" - -import logging -import time -from typing import Any, List - -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.core.memory.utils.prompt.template_render import render_evaluate_prompt -from app.db import get_db_context -from app.schemas.memory_storage_schema import ConflictResultSchema -from pydantic import BaseModel - - -async def conflict(evaluate_data: List[Any]) -> List[Any]: - """ - Evaluates memory conflict using the evaluate.jinja2 template. - - Args: - evaluate_data: 反思数据列表。 - Returns: - 冲突记忆列表(JSON 数组)。 - """ - from app.core.memory.utils.config import definitions as config_defs - with get_db_context() as db: - factory = MemoryClientFactory(db) - client = factory.get_llm_client(config_defs.SELECTED_LLM_ID) - rendered_prompt = await render_evaluate_prompt(evaluate_data, ConflictResultSchema) - messages = [{"role": "user", "content": rendered_prompt}] - print(f"提示词长度: {len(rendered_prompt)}") - print(f"====== 冲突判定开始 ======\n") - start_time = time.time() - response = await client.response_structured(messages, ConflictResultSchema) - end_time = time.time() - print(f"冲突判定耗时: {end_time - start_time} 秒") - print(f"冲突判定原始输出:(type={type(response)})\n{response}") - - if not response: - logging.error("LLM 冲突判定输出解析失败,返回空列表以继续流程。") - return [] - try: - return [response.model_dump()] if isinstance(response, BaseModel) else [response] - except Exception: - try: - return [response.dict()] - except Exception: - logging.warning("无法标准化冲突判定返回类型,尝试直接封装为列表。") - return [response] diff --git a/api/app/core/memory/utils/self_reflexion_utils/reflexion.py b/api/app/core/memory/utils/self_reflexion_utils/reflexion.py deleted file mode 100644 index 1b915118..00000000 --- a/api/app/core/memory/utils/self_reflexion_utils/reflexion.py +++ /dev/null @@ -1,54 +0,0 @@ -# -*- coding: utf-8 -*- -"""反思执行模块 - -本模块提供反思执行功能,使用LLM对冲突记忆进行反思和解决。 -从 app.core.memory.src.data_config_api.reflexion 迁移而来。 -""" - -import logging -import time -from typing import Any, List - -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.core.memory.utils.prompt.template_render import render_reflexion_prompt -from app.db import get_db_context -from app.schemas.memory_storage_schema import ReflexionResultSchema -from pydantic import BaseModel - - -async def reflexion(ref_data: List[Any]) -> List[Any]: - """ - Reflexes on the given reference data using the reflexion.jinja2 template. - - Args: - ref_data: 反思数据列表。 - Returns: - 反思结果列表(JSON 数组)。 - """ - from app.core.memory.utils.config import definitions as config_defs - with get_db_context() as db: - factory = MemoryClientFactory(db) - client = factory.get_llm_client(config_defs.SELECTED_LLM_ID) - rendered_prompt = await render_reflexion_prompt(ref_data, ReflexionResultSchema) - messages = [{"role": "user", "content": rendered_prompt}] - print(f"提示词长度: {len(rendered_prompt)}") - - print(f"====== 反思开始 ======\n") - start_time = time.time() - response = await client.response_structured(messages, ReflexionResultSchema) - end_time = time.time() - print(f"反思耗时: {end_time - start_time} 秒") - print(f"反思原始输出:(type={type(response)})\n{response}") - - if not response: - logging.error("LLM 反思输出解析失败,返回空列表以继续流程。") - return [] - # 统一返回为列表[dict],便于自我反思主流程更新数据库 - try: - return [response.model_dump()] if isinstance(response, BaseModel) else [response] - except Exception: - try: - return [response.dict()] - except Exception: - logging.warning("无法标准化反思返回类型,尝试直接封装为列表。") - return [response] diff --git a/api/app/core/memory/utils/self_reflexion_utils/self_reflexion.py b/api/app/core/memory/utils/self_reflexion_utils/self_reflexion.py deleted file mode 100644 index 934037b0..00000000 --- a/api/app/core/memory/utils/self_reflexion_utils/self_reflexion.py +++ /dev/null @@ -1,254 +0,0 @@ -# -*- coding: utf-8 -*- -"""自我反思主执行模块 - -本模块提供自我反思引擎的主流程,包括: -- 获取反思数据 -- 冲突判断 -- 反思执行 -- 记忆更新 - -从 app.core.memory.src.data_config_api.self_reflexion 迁移而来。 -""" - -import asyncio -import json -import logging -import os -import uuid -from typing import Any, Dict, List - -#TODO: Fix this - -# Default values (previously from definitions.py) -REFLEXION_ENABLED = os.getenv("REFLEXION_ENABLED", "false").lower() == "true" -REFLEXION_ITERATION_PERIOD = os.getenv("REFLEXION_ITERATION_PERIOD", "3") -REFLEXION_RANGE = os.getenv("REFLEXION_RANGE", "retrieval") -REFLEXION_BASELINE = os.getenv("REFLEXION_BASELINE", "TIME") - -from app.core.memory.utils.config.get_data import get_data -from app.core.memory.utils.self_reflexion_utils.evaluate import conflict -from app.core.memory.utils.self_reflexion_utils.reflexion import reflexion -from app.db import get_db -from app.models.retrieval_info import RetrievalInfo -from app.repositories.neo4j.cypher_queries import UPDATE_STATEMENT_INVALID_AT -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from sqlalchemy.orm import Session - -# 并发限制(可通过环境变量覆盖) -CONCURRENCY = int(os.getenv("REFLEXION_CONCURRENCY", "5")) - -# 确保 INFO 级别日志输出到终端 -_root_logger = logging.getLogger() -if not _root_logger.handlers: - logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") -else: - _root_logger.setLevel(logging.INFO) - - -async def get_reflexion_data(host_id: uuid.UUID) -> List[Any]: - """ - 根据反思范围获取判断的记忆数据。 - - Args: - host_id: 主机ID - Returns: - 符合反思范围的记忆数据列表。 - """ - if REFLEXION_RANGE == "partial": - return await get_data(host_id) - elif REFLEXION_RANGE == "all": - return [] - else: - raise ValueError(f"未知的反思范围: {REFLEXION_RANGE}") - - -async def run_conflict(conflict_data: List[Any]) -> List[Any]: - """ - 判断反思数据中是否存在冲突。 - - Args: - conflict_data: 冲突数据列表。 - Returns: - 如果存在冲突则返回冲突记忆列表,否则返回空列表。 - """ - if not conflict_data: - return [] - - conflict_data = await conflict(conflict_data) - # 仅保留存在冲突的条目(conflict == True) - try: - return [c for c in conflict_data if isinstance(c, dict) and c.get("conflict") is True] - except Exception: - return [] - - -async def run_reflexion(reflexion_data: List[Any]) -> Any: - """ - 执行反思,解决冲突。 - - Args: - reflexion_data: 反思数据列表。 - Returns: - 解决冲突后的反思结果(由 LLM 返回)。 - """ - if not reflexion_data: - return [] - # 并行对每个冲突进行反思,整体缩短等待时间 - sem = asyncio.Semaphore(CONCURRENCY) - - async def _reflex_one(item: Any) -> Dict[str, Any] | None: - async with sem: - try: - result_list = await reflexion([item]) - if not result_list: - return None - obj = result_list[0] - if hasattr(obj, "model_dump"): - return obj.model_dump() - elif hasattr(obj, "dict"): - return obj.dict() - elif isinstance(obj, dict): - return obj - except Exception as e: - logging.warning(f"反思失败,跳过一项: {e}") - return None - - tasks = [_reflex_one(item) for item in reflexion_data] - results = await asyncio.gather(*tasks, return_exceptions=False) - return [r for r in results if r] - - -async def update_memory(solved_data: List[Any], host_id: uuid.UUID) -> str: - """ - 更新记忆库,将解决冲突后的记忆更新到记忆库中。 - - Args: - solved_data: 解决冲突后的记忆(由 LLM 返回)。 - host_id: 主机ID - Returns: - 更新结果(成功或失败)。 - """ - flag = False - if not solved_data: - return "数据缺失,更新失败" - if not isinstance(solved_data, list): - return "数据格式错误,更新失败" - neo4j_connector = Neo4jConnector() - try: - print(f"====== 更新记忆开始 ======\n") - - sem = asyncio.Semaphore(CONCURRENCY) - success_count = 0 - - async def _update_one(item: Dict[str, Any]) -> bool: - async with sem: - try: - if not isinstance(item, dict): - return False - if not item: - return False - resolved = item.get("resolved") - if not isinstance(resolved, dict) or not resolved: - logging.warning(f"反思结果无可更新内容,跳过此项: {item}") - return False - resolved_mem = resolved.get("resolved_memory") - if not isinstance(resolved_mem, dict) or not resolved_mem: - logging.warning(f"反思结果缺少 resolved_memory,跳过此项: {item}") - return False - group_id = resolved_mem.get("group_id") - id = resolved_mem.get("id") - # 使用 invalid_at 字段作为新的失效时间 - new_invalid_at = resolved_mem.get("invalid_at") - if not all([group_id, id, new_invalid_at]): - logging.warning(f"记忆更新参数缺失,跳过此项: {item}") - return False - await neo4j_connector.execute_query( - UPDATE_STATEMENT_INVALID_AT, - group_id=group_id, - id=id, - new_invalid_at=new_invalid_at, - ) - return True - except Exception as e: - logging.error(f"更新单条记忆失败: {e}") - return False - - tasks = [_update_one(item) for item in solved_data if isinstance(item, dict)] - results = await asyncio.gather(*tasks, return_exceptions=False) - success_count = sum(1 for r in results if r) - - logging.info(f"成功更新 {success_count} 条记忆") - flag = success_count > 0 - return "更新成功" if flag else "更新失败" - except Exception as e: - logging.error(f"更新记忆库失败: {e}") - return "更新失败" - finally: - if flag: # 删除数据库中的检索数据 - db: Session = next(get_db()) - try: - db.query(RetrievalInfo).filter(RetrievalInfo.host_id == host_id).delete() - db.commit() - logging.info(f"成功删除 {success_count} 条检索数据") - except Exception as e: - logging.error(f"删除数据库中的检索数据失败: {e}") - finally: - db.close() - - - -async def _append_json(label: str, data: Any) -> None: - """记录冲突记忆(后台线程写入,避免阻塞事件循环)""" - def _write(): - with open("reflexion_data.json", "a", encoding="utf-8") as f: - f.write(f"### {label} ###\n") - json.dump(data, f, ensure_ascii=False, indent=4) - f.write("\n\n") - # 正确地在协程内等待后台线程执行,避免未等待的协程警告 - await asyncio.to_thread(_write) - - -async def self_reflexion(host_id: uuid.UUID) -> str: - """ - 自我反思引擎,执行反思流程。 - - Args: - host_id: 主机ID - - Returns: - 反思结果描述字符串 - """ - if not REFLEXION_ENABLED: - return "未开启反思..." - print(f"====== 自我反思流程开始 ======\n") - reflexion_data = await get_reflexion_data(host_id) - if not reflexion_data: - print(f"====== 自我反思流程结束 ======\n") - return "无反思数据,结束反思" - print(f"反思数据获取成功,共 {len(reflexion_data)} 条") - - conflict_data = await run_conflict(reflexion_data) - if not conflict_data: - print(f"====== 自我反思流程结束 ======\n") - return "无冲突,无需反思" - print(f"冲突记忆类型: {type(conflict_data)}") - await _append_json("conflict", conflict_data) - - solved_data = await run_reflexion(conflict_data) - if not solved_data: - print(f"====== 自我反思流程结束 ======\n") - return "反思失败,未解决冲突" - print(f"解决冲突后的记忆类型: {type(solved_data)}") - await _append_json("solved_data", solved_data) - - result = await update_memory(solved_data, host_id) - print(f"更新记忆库结果: {result}") - print(f"====== 自我反思流程结束 ======\n") - return result - - -if __name__ == "__main__": - import asyncio - # host_id = uuid.UUID("3f6ff1eb-50c7-4765-8e89-e4566be33333") - host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122") - asyncio.run(self_reflexion(host_id)) diff --git a/api/app/repositories/neo4j/graph_search.py b/api/app/repositories/neo4j/graph_search.py index 80756793..0b6a27c6 100644 --- a/api/app/repositories/neo4j/graph_search.py +++ b/api/app/repositories/neo4j/graph_search.py @@ -1,29 +1,30 @@ -from typing import Any, Dict, List, Optional import asyncio import logging +from typing import Any, Dict, List, Optional + +from app.repositories.neo4j.cypher_queries import ( + CHUNK_EMBEDDING_SEARCH, + ENTITY_EMBEDDING_SEARCH, + MEMORY_SUMMARY_EMBEDDING_SEARCH, + SEARCH_CHUNK_BY_CHUNK_ID, + SEARCH_CHUNKS_BY_CONTENT, + SEARCH_DIALOGUE_BY_DIALOG_ID, + SEARCH_ENTITIES_BY_NAME, + SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, + SEARCH_STATEMENTS_BY_CREATED_AT, + SEARCH_STATEMENTS_BY_KEYWORD, + SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, + SEARCH_STATEMENTS_BY_TEMPORAL, + SEARCH_STATEMENTS_BY_VALID_AT, + SEARCH_STATEMENTS_G_CREATED_AT, + SEARCH_STATEMENTS_G_VALID_AT, + SEARCH_STATEMENTS_L_CREATED_AT, + SEARCH_STATEMENTS_L_VALID_AT, + STATEMENT_EMBEDDING_SEARCH, +) # 使用新的仓储层 from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.repositories.neo4j.cypher_queries import ( - SEARCH_STATEMENTS_BY_KEYWORD, - SEARCH_ENTITIES_BY_NAME, - SEARCH_CHUNKS_BY_CONTENT, - STATEMENT_EMBEDDING_SEARCH, - CHUNK_EMBEDDING_SEARCH, - ENTITY_EMBEDDING_SEARCH, - SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, - MEMORY_SUMMARY_EMBEDDING_SEARCH, - SEARCH_STATEMENTS_BY_TEMPORAL, - SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, - SEARCH_DIALOGUE_BY_DIALOG_ID, - SEARCH_CHUNK_BY_CHUNK_ID, - SEARCH_STATEMENTS_BY_CREATED_AT, - SEARCH_STATEMENTS_BY_VALID_AT, - SEARCH_STATEMENTS_G_CREATED_AT, - SEARCH_STATEMENTS_L_CREATED_AT, - SEARCH_STATEMENTS_G_VALID_AT, - SEARCH_STATEMENTS_L_VALID_AT, -) logger = logging.getLogger(__name__) @@ -55,8 +56,12 @@ async def _update_activation_values_batch( return [] # 延迟导入以避免循环依赖 - from app.core.memory.storage_services.forgetting_engine.access_history_manager import AccessHistoryManager - from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator + from app.core.memory.storage_services.forgetting_engine.access_history_manager import ( + AccessHistoryManager, + ) + from app.core.memory.storage_services.forgetting_engine.actr_calculator import ( + ACTRCalculator, + ) # 创建计算器和管理器实例 actr_calculator = ACTRCalculator() @@ -292,6 +297,13 @@ async def search_graph( else: results[key] = result + # Deduplicate results before updating activation values + # This prevents duplicates from propagating through the pipeline + from app.core.memory.src.search import _deduplicate_results + for key in results: + if isinstance(results[key], list): + results[key] = _deduplicate_results(results[key]) + # 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary) results = await _update_search_results_activation( connector=connector, @@ -397,6 +409,13 @@ async def search_graph_by_embedding( else: results[key] = result + # Deduplicate results before updating activation values + # This prevents duplicates from propagating through the pipeline + from app.core.memory.src.search import _deduplicate_results + for key in results: + if isinstance(results[key], list): + results[key] = _deduplicate_results(results[key]) + # 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary) update_start = time.time() results = await _update_search_results_activation( From cb62608dbd13b8144287e935458372677bda7366 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Tue, 20 Jan 2026 15:56:21 +0800 Subject: [PATCH 3/7] refactor: extract edge's attrs config --- .../Workflow/components/Nodes/AddNode.tsx | 6 +- .../Workflow/components/Nodes/LoopNode.tsx | 26 +--- .../Workflow/components/PortClickHandler.tsx | 13 +- .../components/Properties/CaseList/index.tsx | 23 +-- .../Properties/CategoryList/index.tsx | 23 +-- web/src/views/Workflow/constant.ts | 20 ++- .../views/Workflow/hooks/useWorkflowGraph.ts | 134 ++---------------- 7 files changed, 41 insertions(+), 204 deletions(-) diff --git a/web/src/views/Workflow/components/Nodes/AddNode.tsx b/web/src/views/Workflow/components/Nodes/AddNode.tsx index d2d1d15c..f9561a44 100644 --- a/web/src/views/Workflow/components/Nodes/AddNode.tsx +++ b/web/src/views/Workflow/components/Nodes/AddNode.tsx @@ -2,7 +2,7 @@ import { useState } from 'react'; import { Popover } from 'antd'; import clsx from 'clsx'; import type { ReactShapeConfig } from '@antv/x6-react-shape'; -import { nodeLibrary, graphNodeLibrary } from '../../constant'; +import { nodeLibrary, graphNodeLibrary, edgeAttrs } from '../../constant'; import { useTranslation } from 'react-i18next'; const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => { @@ -47,7 +47,7 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => { graph.addEdge({ source: { cell: edge.getSourceCellId(), port: edge.getSourcePortId() }, target: { cell: newNode.id, port: newNode.getPorts().find((port: any) => port.group === 'left')?.id || 'left' }, - attrs: edge.getAttrs(), + ...edgeAttrs }); }); @@ -57,7 +57,7 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => { graph.addEdge({ source: { cell: newNode.id, port: newNode.getPorts().find((port: any) => port.group === 'right')?.id || 'right' }, target: { cell: edge.getTargetCellId(), port: targetPortId }, - attrs: edge.getAttrs(), + ...edgeAttrs }); }); diff --git a/web/src/views/Workflow/components/Nodes/LoopNode.tsx b/web/src/views/Workflow/components/Nodes/LoopNode.tsx index 26109d58..97136746 100644 --- a/web/src/views/Workflow/components/Nodes/LoopNode.tsx +++ b/web/src/views/Workflow/components/Nodes/LoopNode.tsx @@ -2,9 +2,7 @@ import { useEffect } from 'react'; import { useTranslation } from 'react-i18next' import clsx from 'clsx'; import type { ReactShapeConfig } from '@antv/x6-react-shape'; -import { graphNodeLibrary } from '../../constant'; - -import { edge_color } from '../../hooks/useWorkflowGraph' +import { graphNodeLibrary, edgeAttrs } from '../../constant'; const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => { const data = node.getData() || {}; @@ -56,16 +54,7 @@ const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => { graph.addEdge({ source: { cell: cycleStartNode.id, port: sourcePort }, target: { cell: addNode.id, port: targetPort }, - attrs: { - line: { - stroke: edge_color, - strokeWidth: 1, - targetMarker: { - name: 'block', - size: 8, - }, - }, - }, + ...edgeAttrs, }); } } @@ -122,16 +111,7 @@ const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => { cell: addNode.id, port: targetPorts.find((port: any) => port.group === 'left')?.id || 'left' }, - attrs: { - line: { - stroke: edge_color, - strokeWidth: 1, - targetMarker: { - name: 'block', - size: 2, - }, - }, - }, + ...edgeAttrs } graph.addEdge(edgeConfig) } diff --git a/web/src/views/Workflow/components/PortClickHandler.tsx b/web/src/views/Workflow/components/PortClickHandler.tsx index 8d95431c..0a1e3906 100644 --- a/web/src/views/Workflow/components/PortClickHandler.tsx +++ b/web/src/views/Workflow/components/PortClickHandler.tsx @@ -1,7 +1,7 @@ import { useEffect, useState } from 'react'; import { Popover } from 'antd'; import { useTranslation } from 'react-i18next'; -import { nodeLibrary, graphNodeLibrary } from '../constant'; +import { nodeLibrary, graphNodeLibrary, edgeAttrs } from '../constant'; interface PortClickHandlerProps { graph: any; @@ -149,16 +149,7 @@ const PortClickHandler: React.FC = ({ graph }) => { graph.addEdge({ source: { cell: sourceNode.id, port: sourcePort }, target: { cell: newNode.id, port: targetPort }, - attrs: { - line: { - stroke: '#155EEF', - strokeWidth: 1, - targetMarker: { - name: 'block', - size: 8, - }, - }, - }, + ...edgeAttrs // zIndex: sourceNodeData.cycle && sourceNodeType == 'cycle-start' ? 1 : sourceNodeData.cycle ? 2 : 0 }); diff --git a/web/src/views/Workflow/components/Properties/CaseList/index.tsx b/web/src/views/Workflow/components/Properties/CaseList/index.tsx index 8caa601e..d9521059 100644 --- a/web/src/views/Workflow/components/Properties/CaseList/index.tsx +++ b/web/src/views/Workflow/components/Properties/CaseList/index.tsx @@ -6,6 +6,7 @@ import { Form, Button, Select, Space, Divider, InputNumber, Radio, type SelectPr import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin' import VariableSelect from '../VariableSelect' import Editor from '../../Editor' +import { edgeAttrs } from '../../../constant' interface CaseListProps { value?: Array<{ logical_operator: 'and' | 'or'; expressions: { left: string; operator: string; right: string; input_type?: string; }[] }>; @@ -120,16 +121,7 @@ const CaseList: FC = ({ graphRef.current?.addEdge({ source: { cell: sourceCellId, port: sourcePortId }, target: { cell: selectedNode.id, port: targetPortId }, - attrs: { - line: { - stroke: '#155EEF', - strokeWidth: 1, - targetMarker: { - name: 'block', - size: 8, - }, - }, - }, + ...edgeAttrs, }); } graphRef.current?.removeCell(edge); @@ -174,16 +166,7 @@ const CaseList: FC = ({ graphRef.current?.addEdge({ source: { cell: selectedNode.id, port: newPortId }, target: { cell: targetCellId, port: targetPortId }, - attrs: { - line: { - stroke: '#155EEF', - strokeWidth: 1, - targetMarker: { - name: 'block', - size: 8, - }, - }, - }, + ...edgeAttrs }); } } diff --git a/web/src/views/Workflow/components/Properties/CategoryList/index.tsx b/web/src/views/Workflow/components/Properties/CategoryList/index.tsx index ca7f8115..aabc3ad3 100644 --- a/web/src/views/Workflow/components/Properties/CategoryList/index.tsx +++ b/web/src/views/Workflow/components/Properties/CategoryList/index.tsx @@ -5,6 +5,7 @@ import { Graph, Node } from '@antv/x6'; import Editor from '../../Editor'; import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin' +import { edgeAttrs } from '../../../constant' interface CategoryListProps { parentName: string; @@ -70,16 +71,7 @@ const CategoryList: FC = ({ parentName, selectedNode, graphRe graphRef.current?.addEdge({ source: { cell: sourceCellId, port: sourcePortId }, target: { cell: selectedNode.id, port: targetPortId }, - attrs: { - line: { - stroke: '#155EEF', - strokeWidth: 1, - targetMarker: { - name: 'block', - size: 8, - }, - }, - }, + ...edgeAttrs }); } return; @@ -110,16 +102,7 @@ const CategoryList: FC = ({ parentName, selectedNode, graphRe graphRef.current?.addEdge({ source: { cell: selectedNode.id, port: newPortId }, target: { cell: targetCellId, port: targetPortId }, - attrs: { - line: { - stroke: '#155EEF', - strokeWidth: 1, - targetMarker: { - name: 'block', - size: 8, - }, - }, - }, + ...edgeAttrs }); } } diff --git a/web/src/views/Workflow/constant.ts b/web/src/views/Workflow/constant.ts index bea518bf..17a5fd2b 100644 --- a/web/src/views/Workflow/constant.ts +++ b/web/src/views/Workflow/constant.ts @@ -517,6 +517,8 @@ interface NodeConfig { ports?: PortsConfig; } +export const edge_color = '#155EEF'; +export const edge_selected_color = '#4DA8FF' // 统一的端口 markup 配置 export const portMarkup = [ { @@ -534,9 +536,9 @@ export const portAttrs = { body: { r: 6, magnet: true, - stroke: '#155EEF', + stroke: edge_color, strokeWidth: 2, - fill: '#155EEF', + fill: edge_color, }, label: { text: '+', @@ -776,4 +778,18 @@ export const outputVariable: { [key: string]: OutputVariable } = { { name: "output", type: "string" }, ], }, +} + +export const edgeAttrs = { + attrs: { + line: { + stroke: edge_color, + strokeWidth: 1, + targetMarker: { + name: 'block', + width: 4, + height: 4, + }, + }, + }, } \ No newline at end of file diff --git a/web/src/views/Workflow/hooks/useWorkflowGraph.ts b/web/src/views/Workflow/hooks/useWorkflowGraph.ts index 615cd3e5..a9dc39c3 100644 --- a/web/src/views/Workflow/hooks/useWorkflowGraph.ts +++ b/web/src/views/Workflow/hooks/useWorkflowGraph.ts @@ -5,7 +5,7 @@ import { App } from 'antd' import { Graph, Node, MiniMap, Snapline, Clipboard, Keyboard, type Edge } from '@antv/x6'; import { register } from '@antv/x6-react-shape'; -import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs } from '../constant'; +import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs, edgeAttrs, edge_color, edge_selected_color } from '../constant'; import type { WorkflowConfig, NodeProperties, ChatVariable } from '../types'; import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application' import type { PortMetadata } from '@antv/x6/lib/model/port'; @@ -23,12 +23,8 @@ export interface UseWorkflowGraphReturn { setSelectedNode: React.Dispatch>; zoomLevel: number; setZoomLevel: React.Dispatch>; - canUndo: boolean; - canRedo: boolean; isHandMode: boolean; setIsHandMode: React.Dispatch>; - onUndo: () => void; - onRedo: () => void; onDrop: (event: React.DragEvent) => void; blankClick: () => void; deleteEvent: () => boolean | void; @@ -39,8 +35,6 @@ export interface UseWorkflowGraphReturn { setChatVariables: React.Dispatch>; } -export const edge_color = '#155EEF'; -const edge_selected_color = '#4DA8FF' export const useWorkflowGraph = ({ containerRef, miniMapRef, @@ -51,9 +45,6 @@ export const useWorkflowGraph = ({ const graphRef = useRef(); const [selectedNode, setSelectedNode] = useState(null); const [zoomLevel, setZoomLevel] = useState(1); - const historyRef = useRef<{ undoStack: string[], redoStack: string[] }>({ undoStack: [], redoStack: [] }); - const [canUndo, setCanUndo] = useState(false); - const [canRedo, setCanRedo] = useState(false); const [isHandMode, setIsHandMode] = useState(true); const [config, setConfig] = useState(null); const [chatVariables, setChatVariables] = useState([]) @@ -338,17 +329,7 @@ export const useWorkflowGraph = ({ port: targetPorts.find((port: any) => port.group === 'left')?.id || 'left' }, connector: { name: 'smooth' }, - attrs: { - line: { - stroke: edge_color, - strokeWidth: 1, - targetMarker: { - name: 'diamond', - width: 4, - height: 4, - }, - }, - }, + ...edgeAttrs // zIndex: loopIterationCount } @@ -368,48 +349,6 @@ export const useWorkflowGraph = ({ }, 200) } } - - const saveState = () => { - if (!graphRef.current) return; - const state = JSON.stringify(graphRef.current.toJSON()); - historyRef.current.undoStack.push(state); - historyRef.current.redoStack = []; - if (historyRef.current.undoStack.length > 50) { - historyRef.current.undoStack.shift(); - } - updateHistoryState(); - }; - - const updateHistoryState = () => { - setCanUndo(historyRef.current.undoStack.length > 1); - setCanRedo(historyRef.current.redoStack.length > 0); - }; - - // 撤销 - const onUndo = () => { - if (!graphRef.current || historyRef.current.undoStack.length === 0) return; - const { undoStack = [], redoStack = [] } = historyRef.current - - const currentState = JSON.stringify(graphRef.current.toJSON()); - const prevState = undoStack[undoStack.length - 2]; - - historyRef.current.redoStack = [...redoStack, currentState] - historyRef.current.undoStack = undoStack.slice(0, undoStack.length - 1) - graphRef.current.fromJSON(JSON.parse(prevState)); - updateHistoryState(); - }; - // 重做 - const onRedo = () => { - if (!graphRef.current || historyRef.current.redoStack.length === 0) return; - const { undoStack = [], redoStack = [] } = historyRef.current - - const nextState = redoStack[redoStack.length - 1]; - - historyRef.current.undoStack = [...undoStack, nextState] - historyRef.current.redoStack = redoStack.slice(0, redoStack.length - 1) - graphRef.current.fromJSON(JSON.parse(nextState)); - updateHistoryState(); - }; // 使用插件 const setupPlugins = () => { if (!graphRef.current || !miniMapRef.current) return; @@ -563,20 +502,6 @@ export const useWorkflowGraph = ({ } return false; }; - // 撤销快捷键事件 - const undoEvent = () => { - if (canUndo) { - onUndo(); - } - return false; - }; - // 重做快捷键事件 - const redoEvent = () => { - if (canRedo) { - onRedo(); - } - return false; - }; // 删除选中的节点和连线事件 const deleteEvent = () => { if (!graphRef.current) return; @@ -677,8 +602,6 @@ export const useWorkflowGraph = ({ background: { color: '#F0F3F8', }, - // width: container.clientWidth || 800, - // height: container.clientHeight || 600, autoResize: true, grid: { visible: true, @@ -694,37 +617,26 @@ export const useWorkflowGraph = ({ enabled: true, }, connecting: { - // router: 'orth', - // router: 'manhattan', connector: { name: 'smooth', args: { radius: 8, }, }, - anchor: 'center', + anchor: 'midSide', connectionPoint: 'anchor', allowBlank: false, + allowLoop: false, allowNode: false, allowEdge: false, + allowPort: true, + allowMulti: true, highlight: true, snap: { radius: 20, }, createEdge() { - return graphRef.current?.createEdge({ - attrs: { - line: { - stroke: edge_color, - strokeWidth: 1, - targetMarker: { - name: 'diamond', - width: 4, - height: 4, - }, - }, - }, - }); + return graphRef.current?.createEdge(edgeAttrs); }, validateConnection({ sourceCell, targetCell, targetMagnet }) { if (!targetMagnet) return false; @@ -823,25 +735,6 @@ export const useWorkflowGraph = ({ graphRef.current.on('scale', scaleEvent); // 监听节点移动事件 graphRef.current.on('node:moved', nodeMoved); - - // 监听画布变化事件 - const events = [ - 'node:added', - 'node:removed', - 'edge:added', - 'edge:removed', - ]; - events.forEach(event => { - graphRef.current!.on(event, () => { - console.log('event', event); - setTimeout(() => saveState(), 50); - }); - }); - - // 监听撤销键盘事件 - graphRef.current.bindKey(['ctrl+z', 'cmd+z'], undoEvent); - // 监听重做键盘事件 - graphRef.current.bindKey(['ctrl+shift+z', 'cmd+shift+z', 'ctrl+y', 'cmd+y'], redoEvent); // 监听复制键盘事件 graphRef.current.bindKey(['ctrl+c', 'cmd+c'], copyEvent); // 监听粘贴键盘事件 @@ -849,11 +742,6 @@ export const useWorkflowGraph = ({ // 删除选中的节点和连线 graphRef.current.bindKey(['ctrl+d', 'cmd+d', 'delete', 'backspace'], deleteEvent); - // 保存初始状态 - setTimeout(() => saveState(), 100); - // init window hook - (window as Window & { __x6_instances__?: Graph[] }).__x6_instances__ = []; - (window as Window & { __x6_instances__?: Graph[] }).__x6_instances__?.push(graphRef.current); }; useEffect(() => { @@ -1066,11 +954,11 @@ export const useWorkflowGraph = ({ }), } saveWorkflowConfig(config.app_id, params as WorkflowConfig) - .then(() => { + .then((res) => { if (flag) { message.success(t('common.saveSuccess')) } - resolve(true) + resolve(res) }).catch(error => { reject(error) }) @@ -1085,12 +973,8 @@ export const useWorkflowGraph = ({ setSelectedNode, zoomLevel, setZoomLevel, - canUndo, - canRedo, isHandMode, setIsHandMode, - onUndo, - onRedo, onDrop, blankClick, deleteEvent, From c6030bbec89b80a7573b6f29a2a0baf8e120f07d Mon Sep 17 00:00:00 2001 From: zhaoying Date: Tue, 20 Jan 2026 15:57:44 +0800 Subject: [PATCH 4/7] feat(web): add yamlExport function --- web/src/utils/yamlExport.ts | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 web/src/utils/yamlExport.ts diff --git a/web/src/utils/yamlExport.ts b/web/src/utils/yamlExport.ts new file mode 100644 index 00000000..9a8c0c41 --- /dev/null +++ b/web/src/utils/yamlExport.ts @@ -0,0 +1,13 @@ +import yaml from 'js-yaml'; + + +export const exportToYaml = (data: unknown, filename: string = 'export.yaml') => { + const yamlStr = yaml.dump(data); + const blob = new Blob([yamlStr], { type: 'text/yaml' }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = filename; + a.click(); + URL.revokeObjectURL(url); +}; From 91d3758691fe711009ded9d38936efec7773b6b4 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Tue, 20 Jan 2026 15:59:55 +0800 Subject: [PATCH 5/7] feat(web): agent and multi_agent handleSave function add promise resolve result --- web/src/views/ApplicationConfig/Agent.tsx | 6 ++---- web/src/views/ApplicationConfig/Cluster.tsx | 6 ++---- .../views/ApplicationConfig/components/ConfigHeader.tsx | 8 ++++---- web/src/views/ApplicationConfig/index.tsx | 3 ++- 4 files changed, 10 insertions(+), 13 deletions(-) diff --git a/web/src/views/ApplicationConfig/Agent.tsx b/web/src/views/ApplicationConfig/Agent.tsx index 92170d55..4d410338 100644 --- a/web/src/views/ApplicationConfig/Agent.tsx +++ b/web/src/views/ApplicationConfig/Agent.tsx @@ -311,17 +311,15 @@ const Agent = forwardRef((_props, ref) => { enabled: vo.enabled })) } - - console.log('params', rest, params) return new Promise((resolve, reject) => { saveAgentConfig(data.app_id, params) - .then(() => { + .then((res) => { if (flag) { message.success(t('common.saveSuccess')) } setIsSave(false) - resolve(true) + resolve(res) }).catch(error => { reject(error) }) diff --git a/web/src/views/ApplicationConfig/Cluster.tsx b/web/src/views/ApplicationConfig/Cluster.tsx index 714c1c49..3081aa04 100644 --- a/web/src/views/ApplicationConfig/Cluster.tsx +++ b/web/src/views/ApplicationConfig/Cluster.tsx @@ -58,16 +58,14 @@ const Cluster = forwardRef((_props, ref) => { })) } - console.log('params', params) - return new Promise((resolve, reject) => { form.validateFields().then(() => { saveMultiAgentConfig(id as string, params) - .then(() => { + .then((res) => { if (flag) { message.success(t('common.saveSuccess')) } - resolve(true) + resolve(res) }) .catch(error => { reject(error) diff --git a/web/src/views/ApplicationConfig/components/ConfigHeader.tsx b/web/src/views/ApplicationConfig/components/ConfigHeader.tsx index 1837b696..94ef0ef7 100644 --- a/web/src/views/ApplicationConfig/components/ConfigHeader.tsx +++ b/web/src/views/ApplicationConfig/components/ConfigHeader.tsx @@ -11,7 +11,7 @@ import exportIcon from '@/assets/images/export_hover.svg' import deleteIcon from '@/assets/images/delete_hover.svg' import type { Application, ApplicationModalRef } from '@/views/ApplicationManagement/types'; import ApplicationModal from '@/views/ApplicationManagement/components/ApplicationModal' -import type { CopyModalRef, WorkflowRef } from '../types' +import type { CopyModalRef, AgentRef, ClusterRef, WorkflowRef } from '../types' import { deleteApplication } from '@/api/application' import CopyModal from './CopyModal' @@ -30,10 +30,11 @@ interface ConfigHeaderProps { handleChangeTab: (key: string) => void; refresh: () => void; workflowRef: React.RefObject + appRef?: React.RefObject } const ConfigHeader: FC = ({ application, activeTab, handleChangeTab, refresh, - workflowRef + workflowRef, }) => { const { t } = useTranslation(); const navigate = useNavigate(); @@ -48,7 +49,7 @@ const ConfigHeader: FC = ({ })) } const formatMenuItems = () => { - const items = ['edit', 'copy', 'delete'].map(key => ({ + const items = ['edit', 'copy', 'export', 'delete'].map(key => ({ key, icon: , label: t(`common.${key}`), @@ -59,7 +60,6 @@ const ConfigHeader: FC = ({ } } const handleClick: MenuProps['onClick'] = ({ key }) => { - console.log('key', key) switch (key) { case 'edit': applicationModalRef.current?.handleOpen(application as Application) diff --git a/web/src/views/ApplicationConfig/index.tsx b/web/src/views/ApplicationConfig/index.tsx index 23f3e434..7d5d5950 100644 --- a/web/src/views/ApplicationConfig/index.tsx +++ b/web/src/views/ApplicationConfig/index.tsx @@ -60,10 +60,11 @@ const ApplicationConfig: React.FC = () => { handleChangeTab={handleChangeTab} application={application as Application} refresh={getApplicationInfo} + appRef={application?.type === 'agent' ? agentRef : application?.type === 'multi_agent' ? clusterRef : application?.type === 'workflow' ? workflowRef : undefined} workflowRef={workflowRef} /> {activeTab === 'arrangement' && application?.type === 'agent' && } - {activeTab === 'arrangement' && application?.type === 'multi_agent' && } + {activeTab === 'arrangement' && application?.type === 'multi_agent' && } {activeTab === 'arrangement' && application?.type === 'workflow' && } {activeTab === 'api' && } {activeTab === 'release' && } From b5b1a98bc4c2a23040fb1def53462f1e47809307 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Tue, 20 Jan 2026 16:10:49 +0800 Subject: [PATCH 6/7] fix(web): when the type of the loop variable is number, value uses InputNumber --- .../Properties/CycleVarsList/index.tsx | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/web/src/views/Workflow/components/Properties/CycleVarsList/index.tsx b/web/src/views/Workflow/components/Properties/CycleVarsList/index.tsx index 13e00ea2..9e28c958 100644 --- a/web/src/views/Workflow/components/Properties/CycleVarsList/index.tsx +++ b/web/src/views/Workflow/components/Properties/CycleVarsList/index.tsx @@ -1,6 +1,6 @@ import { type FC } from 'react' import { useTranslation } from 'react-i18next'; -import { Form, Select, Input, Button } from 'antd' +import { Form, Select, Input, Button, InputNumber } from 'antd' import VariableSelect from '../VariableSelect' import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin' @@ -93,6 +93,7 @@ const CycleVarsList: FC = ({ {fields.map(({ key, name }, index) => { + const currentType = value?.[index]?.type; const currentInputType = value?.[index]?.input_type; return ( @@ -131,7 +132,8 @@ const CycleVarsList: FC = ({ - {currentInputType === 'variable' ? ( + {currentInputType === 'variable' + ? ( { @@ -143,7 +145,15 @@ const CycleVarsList: FC = ({ variant="borderless" size="small" /> - ) : ( + ) + : currentType === 'number' + ? form.setFieldValue([name, 'value'], value)} + /> + : ( Date: Tue, 20 Jan 2026 16:19:02 +0800 Subject: [PATCH 7/7] fix(web): when the type of the loop variable is boolean, value uses Radio --- .../Workflow/components/Properties/CycleVarsList/index.tsx | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/web/src/views/Workflow/components/Properties/CycleVarsList/index.tsx b/web/src/views/Workflow/components/Properties/CycleVarsList/index.tsx index 9e28c958..dfa82f0a 100644 --- a/web/src/views/Workflow/components/Properties/CycleVarsList/index.tsx +++ b/web/src/views/Workflow/components/Properties/CycleVarsList/index.tsx @@ -1,6 +1,6 @@ import { type FC } from 'react' import { useTranslation } from 'react-i18next'; -import { Form, Select, Input, Button, InputNumber } from 'antd' +import { Form, Select, Input, Button, InputNumber, Radio } from 'antd' import VariableSelect from '../VariableSelect' import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin' @@ -153,6 +153,11 @@ const CycleVarsList: FC = ({ className="rb:w-full! rb:my-1!" onChange={(value) => form.setFieldValue([name, 'value'], value)} /> + : currentType === 'boolean' + ? + True + False + : (