Merge remote-tracking branch 'origin/develop' into develop

# Conflicts:
#	api/app/services/memory_reflection_service.py
This commit is contained in:
lixinyue
2026-01-20 16:32:27 +08:00
27 changed files with 273 additions and 2357 deletions

View File

@@ -1,10 +1,8 @@
import os
import uuid
from typing import Optional
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.memory.utils.self_reflexion_utils import self_reflexion
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import get_current_user
@@ -458,18 +456,3 @@ async def get_recent_activity_stats_api(
api_logger.error(f"Recent activity stats failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "Recent activity stats failed", str(e))
@router.get("/self_reflexion")
async def self_reflexion_endpoint(host_id: uuid.UUID) -> str:
"""
自我反思接口,自动对检索出的信息进行自我反思并返回自我反思结果。
Args:
None
Returns:
自我反思结果。
"""
return await self_reflexion(host_id)
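For orientation, a minimal client-side sketch of calling this endpoint; the base URL is an assumption, not part of this diff:
import uuid
import httpx

async def call_self_reflexion(host_id: uuid.UUID) -> str:
    # Assumed local deployment; adjust the base URL to your environment.
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        resp = await client.get("/self_reflexion", params={"host_id": str(host_id)})
        resp.raise_for_status()
        return resp.text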

View File

@@ -1,48 +1,15 @@
import asyncio
import os
import sys
from typing import List, Tuple
from neo4j import GraphDatabase
from pydantic import BaseModel, Field
# ------------------- Self-contained path resolution -------------------
# This block ensures the script can run from anywhere and still find the
# modules it needs within the project structure.
try:
# Assume the script lives in /path/to/project/src/analytics/
# and go up two levels to reach the project root.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
src_path = os.path.join(project_root, 'src')
# Add both 'src' and 'project_root' to the path.
# The 'src' directory is required for imports like 'from utils.config_utils import ...'.
# The 'project_root' directory is required for imports like 'from variate_config import ...'.
if src_path not in sys.path:
sys.path.insert(0, src_path)
if project_root not in sys.path:
sys.path.insert(0, project_root)
except NameError:
# Fallback for environments where __file__ is undefined (e.g. some interactive interpreters)
project_root = os.path.abspath(os.path.join(os.getcwd()))
src_path = os.path.join(project_root, 'src')
if src_path not in sys.path:
sys.path.insert(0, src_path)
if project_root not in sys.path:
sys.path.insert(0, project_root)
# ---------------------------------------------------------------------
# Now that the path is configured, we can use absolute imports
import json
import os
from typing import List, Tuple
from app.core.config import settings
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
from pydantic import BaseModel, Field
#TODO: Fix this
# Default values (previously from definitions.py)
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
# Pydantic model for structured LLM output
class FilteredTags(BaseModel):
@@ -52,34 +19,45 @@ class FilteredTags(BaseModel):
async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
"""
使用LLM筛选标签列表仅保留具有代表性的核心名词。
Args:
tags: 原始标签列表
group_id: 用户组ID用于获取配置
Returns:
筛选后的标签列表
Raises:
ValueError: 如果无法获取有效的LLM配置
"""
try:
# Get config_id using get_end_user_connected_config
with get_db_context() as db:
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(group_id, db)
config_id = connected_config.get("memory_config_id")
if not config_id:
raise ValueError(
f"No memory_config_id found for group_id: {group_id}. "
"Please ensure the user has a valid memory configuration."
)
connected_config = get_end_user_connected_config(group_id, db)
config_id = connected_config.get("memory_config_id")
if config_id:
# Use the config_id to get the proper LLM client
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(memory_config.llm_model_id)
else:
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM if no config found
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
except Exception as e:
print(f"Failed to get user connected config, using default LLM: {e}")
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
# Use the config_id to get the proper LLM client
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
if not memory_config.llm_model_id:
raise ValueError(
f"No llm_model_id found in memory config {config_id}. "
"Please configure a valid LLM model."
)
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(memory_config.llm_model_id)
# 3. Build the prompt
tag_list_str = ", ".join(tags)
@@ -107,33 +85,26 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
# On LLM failure, return the original tags so the pipeline can continue
return tags
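The FilteredTags fields themselves are elided by the hunk above; a plausible shape for such a structured-output schema — the field name is a hypothetical stand-in, not the project's actual definition:
from typing import List
from pydantic import BaseModel, Field

class FilteredTags(BaseModel):
    # Hypothetical field: the subset of input tags the LLM judged to be
    # representative core nouns.
    tags: List[str] = Field(default_factory=list, description="Representative core-noun tags")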
def get_db_connection():
"""
使用项目的标准配置方法建立与Neo4j数据库的连接。
"""
# 从全局配置获取 Neo4j 连接信息
uri = settings.NEO4J_URI
user = settings.NEO4J_USERNAME
# 密码必须为了安全从环境变量加载
password = os.getenv("NEO4J_PASSWORD")
if not uri or not user:
raise ValueError("在 config.json 中未找到 Neo4j 的 'uri''username'")
if not password:
raise ValueError("NEO4J_PASSWORD 环境变量未设置。")
# 为此脚本使用同步驱动
return GraphDatabase.driver(uri, auth=(user, password))
def get_raw_tags_from_db(group_id: str, limit: int, by_user: bool = False) -> List[Tuple[str, int]]:
async def get_raw_tags_from_db(
connector: Neo4jConnector,
group_id: str,
limit: int,
by_user: bool = False
) -> List[Tuple[str, int]]:
"""
TODO: not accurate tag extraction
从数据库查询原始的、未经过滤的实体标签及其频率。
使用项目的Neo4jConnector进行查询遵循仓储模式。
Args:
connector: Neo4j连接器实例
group_id: 如果by_user=False则为group_id如果by_user=True则为user_id
limit: 返回的标签数量限制
by_user: 是否按user_id查询默认False按group_id查询
Returns:
List[Tuple[str, int]]: 标签名称和频率的元组列表
"""
names_to_exclude = ['AI', 'Caroline', 'Melanie', 'Jon', 'Gina', '用户', 'AI助手', 'John', 'Maria']
@@ -154,83 +125,55 @@ def get_raw_tags_from_db(group_id: str, limit: int, by_user: bool = False) -> Li
"LIMIT $limit"
)
driver = None
try:
driver = get_db_connection()
with driver.session() as session:
result = session.run(query, id=group_id, limit=limit, names_to_exclude=names_to_exclude)
return [(record["name"], record["frequency"]) for record in result]
finally:
if driver:
driver.close()
# Execute the query via the project's Neo4jConnector
results = await connector.execute_query(
query,
id=group_id,
limit=limit,
names_to_exclude=names_to_exclude
)
return [(record["name"], record["frequency"]) for record in results]
async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
async def get_hot_memory_tags(group_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
"""
获取原始标签然后使用LLM进行筛选返回最终的热门标签列表。
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
Args:
group_id: 如果by_user=False则为group_id如果by_user=True则为user_id
group_id: 必需参数。如果by_user=False则为group_id如果by_user=True则为user_id
limit: 返回的标签数量限制
by_user: 是否按user_id查询默认False按group_id查询
Raises:
ValueError: 如果group_id未提供或为空
"""
# 默认从环境变量读取
group_id = group_id or DEFAULT_GROUP_ID
# 1. 从数据库获取原始排名靠前的标签
raw_tags_with_freq = get_raw_tags_from_db(group_id, limit, by_user=by_user)
if not raw_tags_with_freq:
return []
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
# 2. Initialize the LLM client and use it to filter out the meaningful tags
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id)
# 3. Build the final tag list from the LLM's selection, preserving original frequency and order
final_tags = []
for tag, freq in raw_tags_with_freq:
if tag in meaningful_tag_names:
final_tags.append((tag, freq))
return final_tags
if __name__ == "__main__":
print("开始获取热门记忆标签...")
# 验证group_id必须提供且不为空
if not group_id or not group_id.strip():
raise ValueError(
"group_id is required. Please provide a valid group_id or user_id."
)
# Use the project's Neo4jConnector
connector = Neo4jConnector()
try:
# Use the group_id from the environment variable directly
group_id_to_query = DEFAULT_GROUP_ID
# Execute the async main function via asyncio.run
top_tags = asyncio.run(get_hot_memory_tags(group_id=group_id_to_query))
# 1. Fetch the top-ranked raw tags from the database
raw_tags_with_freq = await get_raw_tags_from_db(connector, group_id, limit, by_user=by_user)
if not raw_tags_with_freq:
return []
if top_tags:
print(f"热门记忆标签 (Group ID: {group_id_to_query}, 经LLM筛选):")
for tag, frequency in top_tags:
print(f"- {tag} (数量: {frequency})")
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
# --- Write the results into the unified Signboard.json under logs/memory-output ---
from app.core.config import settings
settings.ensure_memory_output_dir()
signboard_path = settings.get_memory_output_path("Signboard.json")
payload = {
"group_id": group_id_to_query,
"hot_tags": [{"name": t, "frequency": f} for t, f in top_tags]
}
try:
existing = {}
if os.path.exists(signboard_path):
with open(signboard_path, "r", encoding="utf-8") as rf:
existing = json.load(rf)
existing["hot_memory_tags"] = payload
with open(signboard_path, "w", encoding="utf-8") as wf:
json.dump(existing, wf, ensure_ascii=False, indent=2)
print(f"已写入 {signboard_path} -> hot_memory_tags")
except Exception as e:
print(f"写入 Signboard.json 失败: {e}")
else:
print(f"在 Group ID '{group_id_to_query}' 中没有找到符合条件的实体标签。")
except Exception as e:
print(f"执行过程中发生严重错误: {e}")
print("请检查:")
print("1. Neo4j数据库服务是否正在运行。")
print("2. 'config.json'中的配置是否正确。")
print("3. 相关的环境变量 (如 NEO4J_PASSWORD, DASHSCOPE_API_KEY) 是否已正确设置。")
# 2. Initialize the LLM client and use it to filter out the meaningful tags
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id)
# 3. Build the final tag list from the LLM's selection, preserving original frequency and order
final_tags = []
for tag, freq in raw_tags_with_freq:
if tag in meaningful_tag_names:
final_tags.append((tag, freq))
return final_tags
finally:
# Ensure the connection is closed
await connector.close()
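With the old __main__ block removed, callers now drive the coroutine themselves; a minimal sketch, using a placeholder group id:
import asyncio

async def main() -> None:
    top_tags = await get_hot_memory_tags(group_id="group_123", limit=40)
    for tag, frequency in top_tags:
        print(f"- {tag} (count: {frequency})")

asyncio.run(main())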

View File

@@ -131,179 +131,60 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
return results
# ============================================================================
# The functions below have been superseded by rerank_with_activation and are kept temporarily for reference
# ============================================================================
# def rerank_hybrid_results(
# keyword_results: Dict[str, List[Dict[str, Any]]],
# embedding_results: Dict[str, List[Dict[str, Any]]],
# alpha: float = 0.6,
# limit: int = 10
# ) -> Dict[str, List[Dict[str, Any]]]:
# """
# Rerank hybrid search results by combining BM25 and embedding scores.
#
# Deprecated: this function has been fully superseded by rerank_with_activation
#
# Args:
# keyword_results: Results from keyword/BM25 search
# embedding_results: Results from embedding search
# alpha: Weight for BM25 scores (1-alpha for embedding scores)
# limit: Maximum number of results to return per category
#
# Returns:
# Reranked results with combined scores
# """
# reranked = {}
#
# for category in ["statements", "chunks", "entities","summaries"]:
# keyword_items = keyword_results.get(category, [])
# embedding_items = embedding_results.get(category, [])
#
# # Normalize scores within each search type
# keyword_items = normalize_scores(keyword_items, "score")
# embedding_items = normalize_scores(embedding_items, "score")
#
# # Create a combined pool of unique items
# combined_items = {}
#
# # Add keyword results with BM25 scores
# for item in keyword_items:
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
# if item_id:
# combined_items[item_id] = item.copy()
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
# combined_items[item_id]["embedding_score"] = 0 # Default
#
# # Add or update with embedding results
# for item in embedding_items:
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
# if item_id:
# if item_id in combined_items:
# # Update existing item with embedding score
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# else:
# # New item from embedding search only
# combined_items[item_id] = item.copy()
# combined_items[item_id]["bm25_score"] = 0 # Default
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
#
# # Calculate combined scores and rank
# for item_id, item in combined_items.items():
# bm25_score = item.get("bm25_score", 0)
# embedding_score = item.get("embedding_score", 0)
#
# # Combined score: weighted average of normalized scores
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
# item["combined_score"] = combined_score
#
# # Keep original score for reference
# if "score" not in item and bm25_score > 0:
# item["score"] = bm25_score
# elif "score" not in item and embedding_score > 0:
# item["score"] = embedding_score
#
# # Sort by combined score and limit results
# sorted_items = sorted(
# combined_items.values(),
# key=lambda x: x.get("combined_score", 0),
# reverse=True
# )[:limit]
#
# reranked[category] = sorted_items
#
# return reranked
# def rerank_with_forgetting_curve(
# keyword_results: Dict[str, List[Dict[str, Any]]],
# embedding_results: Dict[str, List[Dict[str, Any]]],
# alpha: float = 0.6,
# limit: int = 10,
# forgetting_config: ForgettingEngineConfig | None = None,
# now: datetime | None = None,
# ) -> Dict[str, List[Dict[str, Any]]]:
# """
# Rerank hybrid results with a forgetting curve applied to combined scores.
#
# Deprecated: this function has been fully superseded by rerank_with_activation
# rerank_with_activation provides more complete forgetting-curve support (combined with activation)
#
# The forgetting curve reduces scores for older memories or weaker connections.
#
# Args:
# keyword_results: Results from keyword/BM25 search
# embedding_results: Results from embedding search
# alpha: Weight for BM25 scores (1-alpha for embedding scores)
# limit: Maximum number of results to return per category
# forgetting_config: Configuration for the forgetting engine
# now: Optional current time override for testing
#
# Returns:
# Reranked results with combined and final scores (after forgetting)
# """
# engine = ForgettingEngine(forgetting_config or ForgettingEngineConfig())
# now_dt = now or datetime.now()
#
# reranked: Dict[str, List[Dict[str, Any]]] = {}
#
# for category in ["statements", "chunks", "entities","summaries"]:
# keyword_items = keyword_results.get(category, [])
# embedding_items = embedding_results.get(category, [])
#
# # Normalize scores within each search type
# keyword_items = normalize_scores(keyword_items, "score")
# embedding_items = normalize_scores(embedding_items, "score")
#
# combined_items: Dict[str, Dict[str, Any]] = {}
#
# # Combine two result sets by ID
# for src_items, is_embedding in (
# (keyword_items, False), (embedding_items, True)
# ):
# for item in src_items:
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
# if not item_id:
# continue
# existing = combined_items.get(item_id)
# if not existing:
# combined_items[item_id] = item.copy()
# combined_items[item_id]["bm25_score"] = 0
# combined_items[item_id]["embedding_score"] = 0
# # Update normalized score from the right source
# if is_embedding:
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# else:
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
#
# # Calculate scores and apply forgetting weights
# for item_id, item in combined_items.items():
# bm25_score = float(item.get("bm25_score", 0) or 0)
# embedding_score = float(item.get("embedding_score", 0) or 0)
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
#
# # Estimate time elapsed in days
# dt = _parse_datetime(item.get("created_at"))
# if dt is None:
# time_elapsed_days = 0.0
# else:
# time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
#
# # Memory strength (currently set to default value)
# memory_strength = 1.0
# forgetting_weight = engine.calculate_weight(
# time_elapsed=time_elapsed_days, memory_strength=memory_strength
# )
# final_score = combined_score * forgetting_weight
# item["combined_score"] = final_score
#
# sorted_items = sorted(
# combined_items.values(), key=lambda x: x.get("combined_score", 0), reverse=True
# )[:limit]
#
# reranked[category] = sorted_items
#
# return reranked
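The ForgettingEngine referenced above is not shown in this diff; as a rough illustration of the weight it computes, an Ebbinghaus-style exponential decay is the classic choice (the formula below is an assumption, not the project's implementation):
import math

def forgetting_weight(time_elapsed_days: float, memory_strength: float = 1.0) -> float:
    # Classic Ebbinghaus retention curve R = exp(-t / S): larger
    # memory_strength means slower decay; time is measured in days.
    return math.exp(-time_elapsed_days / max(memory_strength, 1e-9))

# final_score = combined_score * forgetting_weight(age_in_days)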
def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Remove duplicate items from search results based on content.
Deduplication strategy:
1. First try to deduplicate by ID (id, uuid, or chunk_id)
2. Then deduplicate by content hash (text, content, statement, or name fields)
Args:
items: List of search result items
Returns:
Deduplicated list of items, preserving the order of first occurrence
"""
seen_ids = set()
seen_content = set()
deduplicated = []
for item in items:
# Try multiple ID fields to identify unique items
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
# Extract content from various possible fields
content = (
item.get("text") or
item.get("content") or
item.get("statement") or
item.get("name") or
""
)
# Normalize content for comparison (strip whitespace and lowercase)
normalized_content = str(content).strip().lower() if content else ""
# Check if we've seen this ID or content before
is_duplicate = False
if item_id and item_id in seen_ids:
is_duplicate = True
elif normalized_content and normalized_content in seen_content:
# Only check content duplication if content is not empty
is_duplicate = True
if not is_duplicate:
# Mark as seen
if item_id:
seen_ids.add(item_id)
if normalized_content: # Only track non-empty content
seen_content.add(normalized_content)
deduplicated.append(item)
return deduplicated
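A quick illustration of the two-stage deduplication (ID first, then normalized content):
items = [
    {"id": "a1", "text": "Paris is in France."},
    {"id": "a1", "text": "duplicate by id"},
    {"uuid": "b2", "content": "  PARIS IS IN FRANCE. "},
    {"chunk_id": "c3", "statement": "Berlin is in Germany."},
]
print(_deduplicate_results(items))
# Keeps only the first and last items: the second repeats the id "a1",
# and the third normalizes to the same content as the first.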
def rerank_with_activation(
@@ -364,7 +245,7 @@ def rerank_with_activation(
keyword_items = normalize_scores(keyword_items, "score")
embedding_items = normalize_scores(embedding_items, "score")
# Step 2: merge results by ID
# Step 2: merge results by ID (with deduplication)
combined_items: Dict[str, Dict[str, Any]] = {}
# Add keyword results
@@ -507,6 +388,9 @@ def rerank_with_activation(
# No activation value: use the content-relevance score
item["final_score"] = item.get("base_score", 0)
# Final deduplication to ensure there are no duplicates
sorted_items = _deduplicate_results(sorted_items)
reranked[category] = sorted_items
return reranked
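Only fragments of rerank_with_activation appear in these hunks; pieced together from the legacy functions kept above, the scoring idea looks roughly like this sketch — how activation is folded into the final score is an assumption:
def combined_base_score(bm25: float, embed: float, alpha: float = 0.6) -> float:
    # Weighted blend of the normalized BM25 and embedding scores.
    return alpha * bm25 + (1 - alpha) * embed

def final_score(base_score: float, activation: float | None) -> float:
    # With an activation value, scale content relevance by it; otherwise
    # fall back to the pure content-relevance score, as the hunk above does.
    return base_score * activation if activation is not None else base_score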
@@ -1144,96 +1028,3 @@ async def search_chunk_by_chunk_id(
)
return {"chunks": chunks}
# def main():
# """Main entry point for the hybrid graph search CLI.
# Parses command line arguments and executes search with specified parameters.
# Supports keyword, embedding, and hybrid search modes.
# """
# parser = argparse.ArgumentParser(description="Hybrid graph search with keyword and embedding options")
# parser.add_argument(
# "--query", "-q", required=True, help="Free-text query to search"
# )
# parser.add_argument(
# "--search-type",
# "-t",
# choices=["keyword", "embedding", "hybrid"],
# default="hybrid",
# help="Search type: keyword (text matching), embedding (semantic), or hybrid (both) (default: hybrid)"
# )
# parser.add_argument(
# "--config-id",
# "-c",
# type=int,
# required=True,
# help="Database configuration ID (required)",
# )
# parser.add_argument(
# "--group-id",
# "-g",
# default=None,
# help="Optional group_id to filter results (default: None)",
# )
# parser.add_argument(
# "--limit",
# "-k",
# type=int,
# default=5,
# help="Max number of results per type (default: 5)",
# )
# parser.add_argument(
# "--include",
# "-i",
# nargs="+",
# default=["statements", "chunks", "entities", "summaries"],
# choices=["statements", "chunks", "entities", "summaries"],
# help="Which targets to search for embedding search (default: statements chunks entities summaries)"
# )
# parser.add_argument(
# "--output",
# "-o",
# default="search_results.json",
# help="Path to save the search results JSON (default: search_results.json)",
# )
# parser.add_argument(
# "--rerank-alpha",
# "-a",
# type=float,
# default=0.6,
# help="Weight for BM25 scores in reranking (0.0-1.0, higher values favor keyword search) (default: 0.6)",
# )
# parser.add_argument(
# "--forgetting-rerank",
# action="store_true",
# help="Apply forgetting curve during reranking for hybrid search.",
# )
# parser.add_argument(
# "--llm-rerank",
# action="store_true",
# help="Apply LLM-based reranking for hybrid search.",
# )
# args = parser.parse_args()
# # Load memory config from database
# from app.services.memory_config_service import MemoryConfigService
# memory_config = MemoryConfigService.load_memory_config(args.config_id)
# asyncio.run(
# run_hybrid_search(
# query_text=args.query,
# search_type=args.search_type,
# group_id=args.group_id,
# limit=args.limit,
# include=args.include,
# output_path=args.output,
# memory_config=memory_config,
# rerank_alpha=args.rerank_alpha,
# use_forgetting_rerank=args.forgetting_rerank,
# use_llm_rerank=args.llm_rerank,
# )
# )
# if __name__ == "__main__":
# main()

View File

@@ -18,13 +18,11 @@ from enum import Enum
from typing import Any, Dict, List, Optional
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.memory.utils.config import definitions as config_defs
from app.core.memory.utils.config.get_data import (
extract_and_process_changes,
get_data,
get_data_statement,
)
from app.core.models.base import RedBearModelConfig
from app.repositories.neo4j.cypher_queries import (
neo4j_query_all,

View File

@@ -14,28 +14,8 @@ from .config_utils import (
get_pruning_config,
get_voice_config,
)
# DEPRECATED: Global configuration variables removed
# Use MemoryConfig objects with dependency injection instead
# from .definitions import (
# CONFIG, # DEPRECATED - empty dict for backward compatibility
# RUNTIME_CONFIG, # DEPRECATED - minimal for backward compatibility
# PROJECT_ROOT, # Still needed for file paths
# reload_configuration_from_database, # DEPRECATED - returns False
# )
# DEPRECATED: overrides module removed - use MemoryConfig with dependency injection
from .get_data import get_data
# litellm_config is imported dynamically when needed, to avoid circular dependencies
# from .litellm_config import (
# LiteLLMConfig,
# setup_litellm_enhanced,
# get_usage_summary,
# print_usage_summary,
# get_instant_qps,
# print_instant_qps,
# )
__all__ = [
# config_utils
"get_model_config",
@@ -45,18 +25,5 @@ __all__ = [
"get_pruning_config",
"get_picture_config",
"get_voice_config",
# definitions (DEPRECATED - use MemoryConfig objects instead)
# "CONFIG", # DEPRECATED
# "RUNTIME_CONFIG", # DEPRECATED
# "PROJECT_ROOT",
# "reload_configuration_from_database", # DEPRECATED
# get_data
"get_data",
# litellm_config - import directly from .litellm_config when needed
# "LiteLLMConfig",
# "setup_litellm_enhanced",
# "get_usage_summary",
# "print_usage_summary",
# "get_instant_qps",
# "print_instant_qps",
]

View File

@@ -1,398 +0,0 @@
"""
配置管理优化模块
提供可选的配置管理优化功能,包括:
- LRU 缓存策略
- 缓存预热
- 缓存监控指标
- 动态 TTL 策略
- 配置版本控制
这些优化是可选的,当前的基础实现已经满足大多数需求。
"""
import logging
import statistics
import threading
from collections import OrderedDict
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional, Tuple
logger = logging.getLogger(__name__)
class LRUConfigCache:
"""
LRULeast Recently Used配置缓存
当缓存达到最大容量时,自动淘汰最少使用的配置
"""
def __init__(self, max_size: int = 100, ttl: timedelta = timedelta(minutes=5)):
"""
初始化 LRU 缓存
Args:
max_size: 最大缓存容量
ttl: 缓存过期时间
"""
self.max_size = max_size
self.ttl = ttl
self._cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
self._timestamps: Dict[str, datetime] = {}
self._lock = threading.RLock()
# Statistics
self._stats = {
'hits': 0,
'misses': 0,
'evictions': 0,
'load_times': []
}
def get(self, config_id: str) -> Optional[Dict[str, Any]]:
"""
获取配置(如果存在且未过期)
Args:
config_id: 配置 ID
Returns:
配置字典,如果不存在或已过期则返回 None
"""
with self._lock:
if config_id not in self._cache:
self._stats['misses'] += 1
return None
# 检查是否过期
timestamp = self._timestamps.get(config_id)
if timestamp and (datetime.now() - timestamp) >= self.ttl:
# 过期,移除
self._cache.pop(config_id, None)
self._timestamps.pop(config_id, None)
self._stats['misses'] += 1
return None
# 命中,移动到末尾(标记为最近使用)
self._cache.move_to_end(config_id)
self._stats['hits'] += 1
return self._cache[config_id]
def put(self, config_id: str, config: Dict[str, Any]) -> None:
"""
添加或更新配置
Args:
config_id: 配置 ID
config: 配置字典
"""
with self._lock:
if config_id in self._cache:
# 更新现有配置
self._cache.move_to_end(config_id)
else:
# 添加新配置
if len(self._cache) >= self.max_size:
# 缓存已满,移除最旧的配置
oldest_id, _ = self._cache.popitem(last=False)
self._timestamps.pop(oldest_id, None)
self._stats['evictions'] += 1
logger.debug(f"[LRUCache] 淘汰配置: {oldest_id}")
self._cache[config_id] = config
self._timestamps[config_id] = datetime.now()
def clear(self, config_id: Optional[str] = None) -> None:
"""
清除缓存
Args:
config_id: 如果指定,只清除该配置;否则清除所有
"""
with self._lock:
if config_id:
self._cache.pop(config_id, None)
self._timestamps.pop(config_id, None)
else:
self._cache.clear()
self._timestamps.clear()
def get_stats(self) -> Dict[str, Any]:
"""
获取缓存统计信息
Returns:
统计信息字典
"""
with self._lock:
total = self._stats['hits'] + self._stats['misses']
hit_rate = (self._stats['hits'] / total * 100) if total > 0 else 0
return {
'cache_size': len(self._cache),
'max_size': self.max_size,
'total_requests': total,
'cache_hits': self._stats['hits'],
'cache_misses': self._stats['misses'],
'evictions': self._stats['evictions'],
'hit_rate': hit_rate,
'avg_load_time': statistics.mean(self._stats['load_times']) if self._stats['load_times'] else 0
}
def record_load_time(self, load_time_ms: float) -> None:
"""
记录加载时间
Args:
load_time_ms: 加载时间(毫秒)
"""
with self._lock:
self._stats['load_times'].append(load_time_ms)
# 只保留最近 1000 次的记录
if len(self._stats['load_times']) > 1000:
self._stats['load_times'] = self._stats['load_times'][-1000:]
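The LRU mechanics above lean on OrderedDict; the core pattern in isolation:
from collections import OrderedDict

cache: OrderedDict[str, str] = OrderedDict()
cache["a"] = "1"
cache["b"] = "2"
cache.move_to_end("a")  # mark "a" as most recently used
oldest_key, _ = cache.popitem(last=False)  # evicts "b", the least recently used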
class ConfigCacheWarmer:
"""
配置缓存预热器
在系统启动时预加载常用配置,减少首次请求延迟
"""
@staticmethod
def warmup(config_ids: List[str], load_func) -> Dict[str, bool]:
"""
预热缓存
Args:
config_ids: 要预加载的配置 ID 列表
load_func: 配置加载函数
Returns:
每个配置的加载结果
"""
results = {}
logger.info(f"[CacheWarmer] 开始预热 {len(config_ids)} 个配置")
for config_id in config_ids:
try:
result = load_func(config_id)
results[config_id] = result
if result:
logger.debug(f"[CacheWarmer] 成功预热配置: {config_id}")
else:
logger.warning(f"[CacheWarmer] 预热配置失败: {config_id}")
except Exception as e:
logger.error(f"[CacheWarmer] 预热配置异常: {config_id}, 错误: {e}")
results[config_id] = False
success_count = sum(1 for r in results.values() if r)
logger.info(f"[CacheWarmer] 预热完成: {success_count}/{len(config_ids)} 成功")
return results
class DynamicTTLStrategy:
"""
动态 TTL 策略
根据配置类型和更新频率动态调整缓存过期时间
"""
# 预定义的 TTL 策略
TTL_STRATEGIES = {
'production': timedelta(minutes=30), # 生产配置较稳定
'staging': timedelta(minutes=15), # 预发布配置中等稳定
'development': timedelta(minutes=5), # 开发配置频繁变化
'testing': timedelta(minutes=1), # 测试配置快速过期
'default': timedelta(minutes=5) # 默认策略
}
@classmethod
def get_ttl(cls, config_id: str, config_type: Optional[str] = None) -> timedelta:
"""
获取配置的 TTL
Args:
config_id: 配置 ID
config_type: 配置类型production/staging/development/testing
Returns:
TTL 时间间隔
"""
if config_type and config_type in cls.TTL_STRATEGIES:
return cls.TTL_STRATEGIES[config_type]
# 根据 config_id 推断类型
if 'prod' in config_id.lower():
return cls.TTL_STRATEGIES['production']
elif 'stag' in config_id.lower():
return cls.TTL_STRATEGIES['staging']
elif 'dev' in config_id.lower():
return cls.TTL_STRATEGIES['development']
elif 'test' in config_id.lower():
return cls.TTL_STRATEGIES['testing']
return cls.TTL_STRATEGIES['default']
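A few illustrative lookups against the strategy table:
from datetime import timedelta

assert DynamicTTLStrategy.get_ttl("prod_config_001") == timedelta(minutes=30)  # inferred from the id
assert DynamicTTLStrategy.get_ttl("cfg_42", config_type="testing") == timedelta(minutes=1)
assert DynamicTTLStrategy.get_ttl("misc_cfg") == timedelta(minutes=5)  # falls through to the default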
class ConfigVersionManager:
"""
配置版本管理器
跟踪配置版本,当配置更新时自动失效旧版本缓存
"""
def __init__(self):
self._versions: Dict[str, str] = {}
self._lock = threading.RLock()
def get_version(self, config_id: str) -> Optional[str]:
"""
获取配置版本
Args:
config_id: 配置 ID
Returns:
版本号,如果不存在则返回 None
"""
with self._lock:
return self._versions.get(config_id)
def set_version(self, config_id: str, version: str) -> None:
"""
设置配置版本
Args:
config_id: 配置 ID
version: 版本号
"""
with self._lock:
old_version = self._versions.get(config_id)
self._versions[config_id] = version
if old_version and old_version != version:
logger.info(f"[VersionManager] 配置版本更新: {config_id} {old_version} -> {version}")
def check_version(self, config_id: str, cached_version: Optional[str]) -> bool:
"""
检查缓存版本是否有效
Args:
config_id: 配置 ID
cached_version: 缓存的版本号
Returns:
True 如果版本匹配False 如果版本不匹配或不存在
"""
with self._lock:
current_version = self._versions.get(config_id)
if not current_version or not cached_version:
return False
return current_version == cached_version
def invalidate(self, config_id: str) -> None:
"""
使配置版本失效
Args:
config_id: 配置 ID
"""
with self._lock:
if config_id in self._versions:
# 生成新版本号
import uuid
new_version = str(uuid.uuid4())
self._versions[config_id] = new_version
logger.info(f"[VersionManager] 配置版本失效: {config_id} -> {new_version}")
class CacheMonitor:
"""
缓存监控器
提供缓存性能监控和报告功能
"""
def __init__(self, cache: LRUConfigCache):
self.cache = cache
def get_report(self) -> str:
"""
生成缓存性能报告
Returns:
格式化的报告字符串
"""
stats = self.cache.get_stats()
report = f"""
配置缓存性能报告
================
缓存容量: {stats['cache_size']}/{stats['max_size']}
总请求数: {stats['total_requests']}
缓存命中: {stats['cache_hits']}
缓存未命中: {stats['cache_misses']}
缓存命中率: {stats['hit_rate']:.2f}%
淘汰次数: {stats['evictions']}
平均加载时间: {stats['avg_load_time']:.2f}ms
"""
return report
def log_stats(self) -> None:
"""记录统计信息到日志"""
stats = self.cache.get_stats()
logger.info(
f"[CacheMonitor] 缓存统计 - "
f"容量: {stats['cache_size']}/{stats['max_size']}, "
f"命中率: {stats['hit_rate']:.2f}%, "
f"淘汰: {stats['evictions']}"
)
# Usage example
def example_usage():
"""
Usage examples for the optimization features.
"""
# 1. Use the LRU cache
lru_cache = LRUConfigCache(max_size=100, ttl=timedelta(minutes=5))
# Fetch a configuration
config = lru_cache.get("config_001")
if config is None:
# Cache miss: load from the database
config = {"llm_name": "openai/gpt-4"}
lru_cache.put("config_001", config)
# 2. Warm up the cache
def load_config(config_id):
# The actual configuration-loading logic
return True
warmer = ConfigCacheWarmer()
results = warmer.warmup(["config_001", "config_002"], load_config)
# 3. Dynamic TTL
ttl = DynamicTTLStrategy.get_ttl("prod_config_001", "production")
print(f"TTL: {ttl}")
# 4. Version management
version_manager = ConfigVersionManager()
version_manager.set_version("config_001", "v1.0.0")
# Check the version
is_valid = version_manager.check_version("config_001", "v1.0.0")
# 5. Monitoring
monitor = CacheMonitor(lru_cache)
print(monitor.get_report())
if __name__ == "__main__":
example_usage()

View File

@@ -1,268 +0,0 @@
# """
# 配置加载模块 - DEPRECATED
# ⚠️ DEPRECATION NOTICE ⚠️
# This module is deprecated and will be removed in a future version.
# Global configuration variables have been eliminated in favor of dependency injection.
# Use the new MemoryConfig system instead:
# - app.schemas.memory_config_schema.MemoryConfig for configuration objects
# - config_service = MemoryConfigService(db); config_service.load_memory_config(config_id)
# Stage 1: load configuration from runtime.json (path A) - DEPRECATED
# Stage 2: load configuration from the database (path B, based on the config_id in dbrun.json) - DEPRECATED
# Stage 3: expose configuration constants for project use (where paths A and B converge) - DEPRECATED
# """
# import json
# import os
# import threading
# from datetime import datetime, timedelta
# from typing import Any, Dict, Optional
# #TODO: Fix this
# try:
# from dotenv import load_dotenv
# load_dotenv()
# except Exception:
# pass
# # Import unified configuration system
# try:
# from app.core.config import settings
# USE_UNIFIED_CONFIG = True
# except ImportError:
# USE_UNIFIED_CONFIG = False
# settings = None
# # PROJECT_ROOT should point at the app/core/memory/ directory
# # __file__ = app/core/memory/utils/config/definitions.py
# # os.path.dirname(__file__) = app/core/memory/utils/config
# # os.path.dirname(...) = app/core/memory/utils
# # os.path.dirname(...) = app/core/memory
# PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# # DEPRECATED: Global configuration lock removed
# # Use MemoryConfig objects with dependency injection instead
# # DEPRECATED: Legacy config.json loading removed
# # Use MemoryConfig objects with dependency injection instead
# CONFIG = {}
# DEFAULT_VALUES = {
# "llm_name": "openai/qwen-plus",
# "embedding_name": "openai/nomic-embed-text:v1.5",
# "chunker_strategy": "RecursiveChunker",
# "group_id": "group_123",
# "user_id": "default_user",
# "apply_id": "default_apply",
# "llm_agent_name": "openai/qwen-plus",
# "llm_verify_name": "openai/qwen-plus",
# "llm_image_recognition": "openai/qwen-plus",
# "llm_voice_recognition": "openai/qwen-plus",
# "prompt_level": "DEBUG",
# "reflexion_iteration_period": "3",
# "reflexion_range": "retrieval",
# "reflexion_baseline": "TIME",
# }
# # DEPRECATED: Legacy global variables for backward compatibility only
# # These will be removed in a future version
# # Use MemoryConfig objects with dependency injection instead
# # LANGFUSE_ENABLED = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true"
# # SELECTED_LLM_ID = os.getenv("SELECTED_LLM_ID", DEFAULT_VALUES["llm_name"])
# # Stage 1: load configuration from runtime.json (path A)
# def _load_from_runtime_json() -> Dict[str, Any]:
# """
# DEPRECATED: Legacy runtime.json loading
# ⚠️ This function is deprecated and will be removed in a future version.
# Use MemoryConfig objects with dependency injection instead.
# Returns:
# Dict[str, Any]: Empty configuration (legacy support only)
# """
# import warnings
# warnings.warn(
# "Runtime JSON loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
# DeprecationWarning,
# stacklevel=2
# )
# return {"selections": {}}
# # Stage 2: load configuration from the database (path B) - merged into the unified loader
# # Note: this function has been superseded by the unified configuration loader in _load_from_runtime_json
# # It is kept only for backward compatibility
# def _load_from_database() -> Optional[Dict[str, Any]]:
# """
# DEPRECATED: Legacy database configuration loading
# ⚠️ This function is deprecated and will be removed in a future version.
# Use MemoryConfig objects with dependency injection instead.
# Returns:
# Optional[Dict[str, Any]]: None (deprecated functionality)
# """
# import warnings
# warnings.warn(
# "Database configuration loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
# DeprecationWarning,
# stacklevel=2
# )
# return None
# # Stage 3: expose configuration constants (where paths A and B converge) - DEPRECATED
# def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None:
# """
# DEPRECATED: expose the runtime configuration as global constants for project use
# ⚠️ This function is deprecated and will be removed in a future version.
# Global configuration variables have been eliminated in favor of dependency injection.
# Use the new MemoryConfig system instead:
# - app.core.memory_config.config.MemoryConfig for configuration objects
# - Pass configuration objects as parameters instead of using global variables
# Args:
# runtime_cfg: the runtime configuration dict
# """
# import warnings
# warnings.warn(
# "Global configuration variables are deprecated. Use MemoryConfig objects with dependency injection instead.",
# DeprecationWarning,
# stacklevel=2
# )
# # Keep minimal global state for backward compatibility only
# # These will be removed in a future version
# global RUNTIME_CONFIG, SELECTIONS
# RUNTIME_CONFIG = runtime_cfg
# SELECTIONS = RUNTIME_CONFIG.get("selections", {})
# # All other global variables have been removed
# # Use MemoryConfig objects instead
# # Initialization: use the unified configuration loader
# def _initialize_configuration() -> None:
# """
# DEPRECATED: Legacy configuration initialization
# ⚠️ This function is deprecated and will be removed in a future version.
# Use MemoryConfig objects with dependency injection instead.
# """
# import warnings
# warnings.warn(
# "Global configuration initialization is deprecated. Use MemoryConfig objects with dependency injection instead.",
# DeprecationWarning,
# stacklevel=2
# )
# # Initialize with empty configuration for backward compatibility
# _expose_runtime_constants({"selections": {}})
# # Automatically initialize configuration at module load time
# _initialize_configuration()
# # DEPRECATED: Global variables removed
# # These variables have been eliminated in favor of dependency injection
# # Use MemoryConfig objects instead of accessing global variables
# # Public API: dynamically reload configuration
# def reload_configuration_from_database(config_id, force_reload: bool = False) -> bool:
# """
# DEPRECATED: Legacy configuration reloading
# ⚠️ This function is deprecated and will be removed in a future version.
# Use MemoryConfig objects with dependency injection instead.
# For new code, use:
# - app.services.memory_agent_service.MemoryAgentService.load_memory_config()
# - app.services.memory_storage_service.MemoryStorageService.load_memory_config()
# Args:
# config_id: Configuration ID (deprecated)
# force_reload: Force reload flag (deprecated)
# Returns:
# bool: Always returns False (deprecated functionality)
# """
# import logging
# import warnings
# logger = logging.getLogger(__name__)
# warnings.warn(
# "reload_configuration_from_database is deprecated. Use MemoryConfig objects with dependency injection instead.",
# DeprecationWarning,
# stacklevel=2
# )
# logger.warning(f"Deprecated function reload_configuration_from_database called with config_id={config_id}. "
# "Use MemoryConfig objects with dependency injection instead.")
# return False
# def get_current_config_id() -> Optional[str]:
# """
# DEPRECATED: Legacy config ID retrieval
# ⚠️ This function is deprecated and will be removed in a future version.
# Use MemoryConfig objects with dependency injection instead.
# Returns:
# Optional[str]: None (deprecated functionality)
# """
# import warnings
# warnings.warn(
# "get_current_config_id is deprecated. Use MemoryConfig objects with dependency injection instead.",
# DeprecationWarning,
# stacklevel=2
# )
# return None
# def ensure_fresh_config(config_id = None) -> bool:
# """
# DEPRECATED: Legacy configuration freshness check
# ⚠️ This function is deprecated and will be removed in a future version.
# Use MemoryConfig objects with dependency injection instead.
# For new code, use:
# - app.services.memory_agent_service.MemoryAgentService.load_memory_config()
# - app.services.memory_storage_service.MemoryStorageService.load_memory_config()
# Args:
# config_id: Configuration ID (deprecated)
# Returns:
# bool: Always returns False (deprecated functionality)
# """
# import logging
# import warnings
# logger = logging.getLogger(__name__)
# warnings.warn(
# "ensure_fresh_config is deprecated. Use MemoryConfig objects with dependency injection instead.",
# DeprecationWarning,
# stacklevel=2
# )
# logger.warning(f"Deprecated function ensure_fresh_config called with config_id={config_id}. "
# "Use MemoryConfig objects with dependency injection instead.")
# return False

View File

@@ -1,90 +0,0 @@
import os
import re
import uuid
import random
import string
from typing import List, Dict, Optional
# Generate a random string of letters (upper/lower case) and digits
def generate_random_string(length=16):
characters = string.ascii_letters + string.digits
return ''.join(random.choice(characters) for _ in range(length))
def get_example_data() -> List[Dict[str, Optional[str]]]:
"""
从句子提取日志中获取数据
Content: 在苹果公司中国总部,用户和李华偶遇了从美国来的技术专家约翰·史密斯。
Created At: 2025-11-28 19:28:38.256421
Expired At: None
Valid At: None
Invalid At: None
将数据构造成如下形式:
[
{
"id":id,
"group_id":group_id,
"statement": Content,
"created_at": Created At,
"expired_at": Expired At,
"valid_at": Valid At,
"invalid_at": Invalid At,
"chunk_id": "86da9022710c40eaa5f518a294c398d2",
"entity_ids": []
},
...
]
"""
# Build the log file path
log_file_path = os.path.join("logs", "memory-output", "statement_extraction.txt")
# Check that the file exists
if not os.path.exists(log_file_path):
return []
# Read the log file
with open(log_file_path, "r", encoding="utf-8") as f:
content = f.read()
# Parse the data
results = []
# Split out each Statement with a regular expression
statement_blocks = re.split(r"Statement \d+:", content)
for block in statement_blocks[1:]: # skip the first, empty block
# Extract the individual fields
id_match = re.search(r"Id:\s*(.+?)(?=\n)", block)
group_id_match = re.search(r"Group Id:\s*(.+?)(?=\n)", block)
statement_match = re.search(r"Content:\s*(.+?)(?=\n)", block)
created_at_match = re.search(r"Created At:\s*(.+?)(?=\n)", block)
expired_at_match = re.search(r"Expired At:\s*(.+?)(?=\n)", block)
valid_at_match = re.search(r"Valid At:\s*(.+?)(?=\n)", block)
invalid_at_match = re.search(r"Invalid At:\s*(.+?)(?=\n)", block)
chunk_id_match = re.search(r"Chunk Id:\s*(.+?)(?=\n)", block)
# Build the dict
if statement_match:
statement_data = {
"id": id_match.group(1).strip() if id_match else generate_random_string(),
"group_id": group_id_match.group(1).strip() if group_id_match else "group_example",
"statement": statement_match.group(1).strip(),
"created_at": created_at_match.group(1).strip() if created_at_match else None,
"expired_at": expired_at_match.group(1).strip() if expired_at_match else None,
"valid_at": valid_at_match.group(1).strip() if valid_at_match else None,
"invalid_at": invalid_at_match.group(1).strip() if invalid_at_match else None,
"chunk_id": chunk_id_match.group(1).strip() if chunk_id_match else "chunk_example",
"entity_ids": []
}
# 将 "None" 字符串转换为 None
for key in ["created_at", "expired_at", "valid_at", "invalid_at"]:
if statement_data[key] == "None":
statement_data[key] = None
results.append(statement_data)
return results
if __name__ == "__main__":
print(f"获取数据如下:\n {get_example_data()}")

View File

@@ -1,516 +0,0 @@
"""
LiteLLM Configuration for Enhanced Retry Logic and Usage Tracking with Native QPS Monitoring
"""
import litellm
from typing import Dict, Any, List
import json
from datetime import datetime, timedelta
import os
import time
from collections import defaultdict
import threading
from queue import Queue
class LiteLLMConfig:
"""Configuration class for LiteLLM with enhanced retry and tracking capabilities"""
def __init__(self):
self.usage_data = []
self.error_data = []
self.module_stats = defaultdict(lambda: {
'requests': 0,
'tokens_in': 0,
'tokens_out': 0,
'cost': 0.0,
'errors': 0,
'start_time': None,
'last_request_time': None,
'request_timestamps': [], # Store precise timestamps
'current_qps': 0.0,
'max_qps': 0.0,
'qps_history': [] # Store QPS measurements over time
})
self.start_time = datetime.now()
self.global_request_timestamps = []
self.global_max_qps = 0.0
# Rate limiting for AWS Bedrock (conservative limits)
self.rate_limits = {
'bedrock': {
'requests_per_minute': 2, # AWS Bedrock default is very low
'requests_per_second': 0.033, # 2/60 = 0.033 RPS
'last_request_time': 0,
'request_queue': Queue(),
'lock': threading.Lock()
}
}
self.rate_limiting_enabled = True
def setup_enhanced_config(self, max_retries: int = 3):
"""Configure LiteLLM with retry logic and instant QPS tracking"""
litellm.num_retries = max_retries
litellm.request_timeout = 300
litellm.retry_policy = {
"RateLimitError": {
"max_retries": 5,
"exponential_backoff": True,
"initial_delay": 1,
"max_delay": 60,
"jitter": True
},
"APIConnectionError": {
"max_retries": 3,
"exponential_backoff": True,
"initial_delay": 2,
"max_delay": 30,
"jitter": True
},
"InternalServerError": {
"max_retries": 2,
"exponential_backoff": True,
"initial_delay": 5,
"max_delay": 60,
"jitter": True
},
"BadRequestError": {
"max_retries": 1,
"exponential_backoff": False,
"initial_delay": 1,
"max_delay": 5
}
}
litellm.success_callback = [self._success_callback]
litellm.failure_callback = [self._failure_callback]
litellm.completion_cost_tracking = True
litellm.set_verbose = False
litellm.modify_params = True
print("✅ LiteLLM configured with instant QPS tracking and rate limiting")
def _success_callback(self, kwargs, completion_response, start_time, end_time):
"""Callback for successful requests with module-specific QPS tracking"""
try:
# Extract usage information
usage = completion_response.get('usage', {})
model = kwargs.get('model', 'unknown')
# Extract module information from metadata or model name
module = self._extract_module_name(kwargs, model)
# Calculate cost
cost = 0.0
try:
cost = litellm.completion_cost(completion_response)
except:
pass
# Calculate duration
duration_seconds = (end_time - start_time).total_seconds() if hasattr(end_time - start_time, 'total_seconds') else float(end_time - start_time)
# Record usage data
usage_record = {
"timestamp": datetime.now().isoformat(),
"model": model,
"module": module,
"input_tokens": usage.get('prompt_tokens', 0),
"output_tokens": usage.get('completion_tokens', 0),
"total_tokens": usage.get('total_tokens', 0),
"cost": cost,
"duration_seconds": duration_seconds,
"status": "success"
}
self.usage_data.append(usage_record)
# Update module-specific stats for QPS tracking
self._update_module_stats(module, usage_record, success=True)
# Print real-time feedback
print(f"{model}: {usage_record['input_tokens']}{usage_record['output_tokens']} tokens, ${cost:.4f}, {usage_record['duration_seconds']:.2f}s")
except Exception as e:
print(f"Warning: Success callback failed: {e}")
def _failure_callback(self, kwargs, completion_response, start_time, end_time):
"""Callback for failed requests with module-specific error tracking"""
try:
model = kwargs.get('model', 'unknown')
module = self._extract_module_name(kwargs, model)
duration_seconds = (end_time - start_time).total_seconds() if hasattr(end_time - start_time, 'total_seconds') else float(end_time - start_time)
# Handle different error response formats
error_message = "Unknown error"
error_type = "UnknownError"
# According to LiteLLM docs, completion_response contains the exception for failures
if completion_response is not None:
error_message = str(completion_response)
error_type = type(completion_response).__name__
# Also check kwargs for exception (LiteLLM passes exception in kwargs for failure events)
elif 'exception' in kwargs:
exception = kwargs['exception']
error_message = str(exception)
error_type = type(exception).__name__
# Check for other error formats in kwargs
elif 'error' in kwargs:
error = kwargs['error']
error_message = str(error)
error_type = type(error).__name__
# Check log_event_type to confirm this is a failure event
log_event_type = kwargs.get('log_event_type', '')
if log_event_type == 'failed_api_call' and 'exception' in kwargs:
exception = kwargs['exception']
error_message = str(exception)
error_type = type(exception).__name__
error_record = {
"timestamp": datetime.now().isoformat(),
"model": model,
"module": module,
"error": error_message,
"error_type": error_type,
"duration_seconds": duration_seconds,
"status": "failed"
}
self.error_data.append(error_record)
# Update module-specific stats for error tracking
self._update_module_stats(module, error_record, success=False)
# Print error feedback
print(f"{model}: {error_type} - {error_message[:100]}")
except Exception as e:
print(f"Warning: Failure callback failed: {e}")
# Debug: print the actual parameters to understand the structure
print(f"Debug - kwargs keys: {list(kwargs.keys()) if kwargs else 'None'}")
print(f"Debug - completion_response type: {type(completion_response)}")
print(f"Debug - completion_response: {completion_response}")
def _should_rate_limit(self, model: str) -> bool:
"""Check if the model should be rate limited"""
if not self.rate_limiting_enabled:
return False
return model.startswith('bedrock/') or 'bedrock' in model.lower()
def _enforce_rate_limit(self, model: str):
"""Enforce rate limiting for AWS Bedrock models"""
if not self._should_rate_limit(model):
return
provider = 'bedrock'
if provider not in self.rate_limits:
return
rate_config = self.rate_limits[provider]
with rate_config['lock']:
current_time = time.time()
time_since_last = current_time - rate_config['last_request_time']
min_interval = 1.0 / rate_config['requests_per_second']
if time_since_last < min_interval:
sleep_time = min_interval - time_since_last
print(f"⏳ Rate limiting: sleeping {sleep_time:.2f}s for {model}")
time.sleep(sleep_time)
rate_config['last_request_time'] = time.time()
def _extract_module_name(self, kwargs: Dict[str, Any], model: str) -> str:
"""Extract module name from request context"""
# Try to get module from metadata
metadata = kwargs.get('metadata', {})
if 'module' in metadata:
return metadata['module']
# Try to infer from model name or other context
if 'claude' in model.lower():
return 'bedrock_client'
elif 'gpt' in model.lower() or 'openai' in model.lower():
return 'openai_client'
elif 'embed' in model.lower():
return 'embedder'
else:
return 'unknown'
def _update_module_stats(self, module: str, record: Dict[str, Any], success: bool):
"""Update module-specific statistics with instant QPS tracking"""
current_timestamp = time.time()
current_time = datetime.now()
# Initialize module stats if first request
if self.module_stats[module]['start_time'] is None:
self.module_stats[module]['start_time'] = current_time
# Update counters
self.module_stats[module]['requests'] += 1
self.module_stats[module]['last_request_time'] = current_time
self.module_stats[module]['request_timestamps'].append(current_timestamp)
self.global_request_timestamps.append(current_timestamp)
# Calculate instant QPS for this module
self._calculate_instant_qps(module, current_timestamp)
# Calculate global instant QPS
self._calculate_global_instant_qps(current_timestamp)
if success:
self.module_stats[module]['tokens_in'] += record.get('input_tokens', 0)
self.module_stats[module]['tokens_out'] += record.get('output_tokens', 0)
self.module_stats[module]['cost'] += record.get('cost', 0.0)
else:
self.module_stats[module]['errors'] += 1
def _calculate_instant_qps(self, module: str, current_timestamp: float):
"""Calculate instant QPS for a specific module using sliding window"""
# Keep only timestamps from last 1 second for instant QPS
cutoff_time = current_timestamp - 1.0
timestamps = self.module_stats[module]['request_timestamps']
# Remove old timestamps
self.module_stats[module]['request_timestamps'] = [
ts for ts in timestamps if ts >= cutoff_time
]
# Calculate current QPS (requests in last second)
current_qps = len(self.module_stats[module]['request_timestamps'])
self.module_stats[module]['current_qps'] = current_qps
# Update max QPS if current is higher
if current_qps > self.module_stats[module]['max_qps']:
self.module_stats[module]['max_qps'] = current_qps
# Store QPS history (keep last 60 measurements)
self.module_stats[module]['qps_history'].append(current_qps)
if len(self.module_stats[module]['qps_history']) > 60:
self.module_stats[module]['qps_history'].pop(0)
def _calculate_global_instant_qps(self, current_timestamp: float):
"""Calculate global instant QPS across all modules"""
# Keep only timestamps from last 1 second
cutoff_time = current_timestamp - 1.0
self.global_request_timestamps = [
ts for ts in self.global_request_timestamps if ts >= cutoff_time
]
# Calculate current global QPS
current_global_qps = len(self.global_request_timestamps)
# Update max global QPS
if current_global_qps > self.global_max_qps:
self.global_max_qps = current_global_qps
def get_instant_qps(self, module: str = None) -> Dict[str, Any]:
"""Get instant QPS data for modules"""
if module:
if module in self.module_stats:
return {
'module': module,
'current_qps': self.module_stats[module]['current_qps'],
'max_qps': self.module_stats[module]['max_qps'],
'avg_qps_last_minute': sum(self.module_stats[module]['qps_history'][-60:]) / min(60, len(self.module_stats[module]['qps_history'])) if self.module_stats[module]['qps_history'] else 0
}
else:
return {'module': module, 'current_qps': 0, 'max_qps': 0, 'avg_qps_last_minute': 0}
else:
# Return data for all modules plus global
result = {
'global': {
'current_qps': len([ts for ts in self.global_request_timestamps if ts >= time.time() - 1.0]),
'max_qps': self.global_max_qps
},
'modules': {}
}
for mod in self.module_stats:
result['modules'][mod] = {
'current_qps': self.module_stats[mod]['current_qps'],
'max_qps': self.module_stats[mod]['max_qps'],
'avg_qps_last_minute': sum(self.module_stats[mod]['qps_history'][-60:]) / min(60, len(self.module_stats[mod]['qps_history'])) if self.module_stats[mod]['qps_history'] else 0
}
return result
def get_usage_summary(self) -> Dict[str, Any]:
"""Get essential usage statistics"""
if not self.usage_data:
return {
"total_requests": 0,
"total_cost": 0.0,
"error_rate": 0.0,
"message": "No usage data available"
}
total_requests = len(self.usage_data)
total_errors = len(self.error_data)
total_cost = sum(record['cost'] for record in self.usage_data)
total_input_tokens = sum(record['input_tokens'] for record in self.usage_data)
total_output_tokens = sum(record['output_tokens'] for record in self.usage_data)
# Calculate session duration
duration_minutes = (datetime.now() - self.start_time).total_seconds() / 60
# Build module statistics
module_stats = {}
for module, stats in self.module_stats.items():
if stats['requests'] > 0:
module_stats[module] = {
"requests": stats['requests'],
"errors": stats['errors'],
"success_rate": ((stats['requests'] - stats['errors']) / stats['requests'] * 100) if stats['requests'] > 0 else 0,
"tokens_in": stats['tokens_in'],
"tokens_out": stats['tokens_out'],
"cost": stats['cost'],
"current_qps": stats['current_qps'],
"max_qps": stats['max_qps']
}
return {
"session_duration_minutes": duration_minutes,
"total_requests": total_requests,
"total_errors": total_errors,
"error_rate": (total_errors / total_requests * 100) if total_requests > 0 else 0,
"total_input_tokens": total_input_tokens,
"total_output_tokens": total_output_tokens,
"total_cost": total_cost,
"module_stats": module_stats,
"global_max_qps": self.global_max_qps
}
def print_usage_summary(self):
"""Print essential usage summary"""
stats = self.get_usage_summary()
if stats.get('message'):
print(f"📊 {stats['message']}")
return
print("\n📊 USAGE SUMMARY")
print(f"{'='*50}")
print(f"⏱️ Duration: {stats['session_duration_minutes']:.1f} min")
print(f"📈 Requests: {stats['total_requests']}")
print(f"❌ Errors: {stats['total_errors']}")
print(f"💰 Cost: ${stats['total_cost']:.4f}")
print(f"🏆 Global Max QPS: {stats['global_max_qps']}")
# Module statistics
if stats.get('module_stats'):
print("\n📦 MODULES:")
for module, mod_stats in stats['module_stats'].items():
print(f" {module}: {mod_stats['requests']} req, Max QPS: {mod_stats['max_qps']}, Current: {mod_stats['current_qps']}")
print(f"{'='*50}")
def save_usage_data(self, filename: str = "litellm_usage.json"):
"""Save usage data to JSON file"""
data = {
"summary": self.get_usage_summary(),
"detailed_usage": self.usage_data,
"errors": self.error_data,
"export_timestamp": datetime.now().isoformat()
}
with open(filename, 'w') as f:
json.dump(data, f, indent=2)
print(f"📁 Usage data saved to {filename}")
def reset_tracking(self):
"""Reset all tracking data"""
self.usage_data = []
self.error_data = []
self.module_stats = defaultdict(lambda: {
'requests': 0,
'tokens_in': 0,
'tokens_out': 0,
'cost': 0.0,
'errors': 0,
'start_time': None,
'last_request_time': None,
'request_timestamps': [],
'current_qps': 0.0,
'max_qps': 0.0,
'qps_history': []
})
self.global_request_timestamps = []
self.global_max_qps = 0.0
self.start_time = datetime.now()
print("🔄 All tracking data reset")
# Global instance for easy access
litellm_config = LiteLLMConfig()
def setup_litellm_enhanced(max_retries: int = 3):
"""
Quick setup function for LiteLLM enhanced configuration
Args:
max_retries: Maximum number of retries for failed requests
"""
litellm_config.setup_enhanced_config(max_retries)
return litellm_config
def get_usage_summary():
"""Get current usage summary"""
return litellm_config.get_usage_summary()
def print_usage_summary():
"""Print current usage summary"""
litellm_config.print_usage_summary()
def save_usage_data(filename: str = "litellm_usage.json"):
"""Save usage data to file"""
litellm_config.save_usage_data(filename)
def get_instant_qps(module: str = None) -> Dict[str, Any]:
"""Get instant QPS data for modules"""
return litellm_config.get_instant_qps(module)
def print_instant_qps(module: str = None):
"""Print instant QPS information"""
qps_data = get_instant_qps(module)
print("\n⚡ INSTANT QPS MONITOR")
print(f"{'='*60}")
if module:
print(f"Module: {qps_data['module']}")
print(f" Current QPS: {qps_data['current_qps']}")
print(f" Max QPS: {qps_data['max_qps']}")
print(f" Avg (1min): {qps_data['avg_qps_last_minute']:.2f}")
else:
# Global stats
global_data = qps_data.get('global', {})
print("🌍 GLOBAL:")
print(f" Current QPS: {global_data.get('current_qps', 0)}")
print(f" Max QPS: {global_data.get('max_qps', 0)}")
# Module stats
modules = qps_data.get('modules', {})
if modules:
print("\n📦 MODULES:")
for mod, data in modules.items():
print(f" {mod}:")
print(f" Current: {data['current_qps']} QPS")
print(f" Max: {data['max_qps']} QPS")
print(f" Avg: {data['avg_qps_last_minute']:.2f} QPS")
print(f"{'='*60}")
def reset_tracking():
"""Reset all tracking data"""
litellm_config.reset_tracking()
def get_module_stats() -> Dict[str, Dict[str, Any]]:
"""Get detailed module statistics"""
summary = get_usage_summary()
return summary.get('module_stats', {})
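A minimal sketch of wiring this up from a caller; the model id is a placeholder, and the metadata key follows _extract_module_name above:
import litellm

config = setup_litellm_enhanced(max_retries=3)
response = litellm.completion(
    model="openai/qwen-plus",  # placeholder model id
    messages=[{"role": "user", "content": "ping"}],
    metadata={"module": "openai_client"},  # lets the callbacks attribute the call
)
print_instant_qps()
config.print_usage_summary()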

View File

@@ -1,16 +0,0 @@
# -*- coding: utf-8 -*-
"""自我反思工具模块
本模块提供自我反思引擎的核心功能,包括:
- 记忆冲突判定
- 反思执行
- 记忆更新
从 app.core.memory.src.data_config_api 迁移而来。
"""
from app.core.memory.utils.self_reflexion_utils.evaluate import conflict
from app.core.memory.utils.self_reflexion_utils.reflexion import reflexion
from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion
__all__ = ["conflict", "reflexion", "self_reflexion"]

View File

@@ -1,52 +0,0 @@
# -*- coding: utf-8 -*-
"""记忆冲突判定模块
本模块提供记忆冲突判定功能使用LLM判断记忆数据中是否存在冲突。
从 app.core.memory.src.data_config_api.evaluate 迁移而来。
"""
import logging
import time
from typing import Any, List
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
from app.db import get_db_context
from app.schemas.memory_storage_schema import ConflictResultSchema
from pydantic import BaseModel
async def conflict(evaluate_data: List[Any]) -> List[Any]:
"""
Evaluates memory conflict using the evaluate.jinja2 template.
Args:
evaluate_data: List of memory records to evaluate.
Returns:
List of conflicting memories (a JSON array).
"""
from app.core.memory.utils.config import definitions as config_defs
with get_db_context() as db:
factory = MemoryClientFactory(db)
client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
rendered_prompt = await render_evaluate_prompt(evaluate_data, ConflictResultSchema)
messages = [{"role": "user", "content": rendered_prompt}]
print(f"提示词长度: {len(rendered_prompt)}")
print(f"====== 冲突判定开始 ======\n")
start_time = time.time()
response = await client.response_structured(messages, ConflictResultSchema)
end_time = time.time()
print(f"冲突判定耗时: {end_time - start_time}")
print(f"冲突判定原始输出:(type={type(response)})\n{response}")
if not response:
logging.error("LLM 冲突判定输出解析失败,返回空列表以继续流程。")
return []
try:
return [response.model_dump()] if isinstance(response, BaseModel) else [response]
except Exception:
try:
return [response.dict()]
except Exception:
logging.warning("无法标准化冲突判定返回类型,尝试直接封装为列表。")
return [response]

View File

@@ -1,54 +0,0 @@
# -*- coding: utf-8 -*-
"""反思执行模块
本模块提供反思执行功能使用LLM对冲突记忆进行反思和解决。
从 app.core.memory.src.data_config_api.reflexion 迁移而来。
"""
import logging
import time
from typing import Any, List
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.prompt.template_render import render_reflexion_prompt
from app.db import get_db_context
from app.schemas.memory_storage_schema import ReflexionResultSchema
from pydantic import BaseModel
async def reflexion(ref_data: List[Any]) -> List[Any]:
"""
Reflects on the given reference data using the reflexion.jinja2 template.
Args:
ref_data: List of memory records to reflect on.
Returns:
List of reflexion results (a JSON array).
"""
from app.core.memory.utils.config import definitions as config_defs
with get_db_context() as db:
factory = MemoryClientFactory(db)
client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
rendered_prompt = await render_reflexion_prompt(ref_data, ReflexionResultSchema)
messages = [{"role": "user", "content": rendered_prompt}]
print(f"提示词长度: {len(rendered_prompt)}")
print(f"====== 反思开始 ======\n")
start_time = time.time()
response = await client.response_structured(messages, ReflexionResultSchema)
end_time = time.time()
print(f"反思耗时: {end_time - start_time}")
print(f"反思原始输出:(type={type(response)})\n{response}")
if not response:
logging.error("LLM 反思输出解析失败,返回空列表以继续流程。")
return []
# 统一返回为列表[dict],便于自我反思主流程更新数据库
try:
return [response.model_dump()] if isinstance(response, BaseModel) else [response]
except Exception:
try:
return [response.dict()]
except Exception:
logging.warning("无法标准化反思返回类型,尝试直接封装为列表。")
return [response]

View File

@@ -1,254 +0,0 @@
# -*- coding: utf-8 -*-
"""自我反思主执行模块
本模块提供自我反思引擎的主流程,包括:
- 获取反思数据
- 冲突判断
- 反思执行
- 记忆更新
从 app.core.memory.src.data_config_api.self_reflexion 迁移而来。
"""
import asyncio
import json
import logging
import os
import uuid
from typing import Any, Dict, List
#TODO: Fix this
# Default values (previously from definitions.py)
REFLEXION_ENABLED = os.getenv("REFLEXION_ENABLED", "false").lower() == "true"
REFLEXION_ITERATION_PERIOD = os.getenv("REFLEXION_ITERATION_PERIOD", "3")
REFLEXION_RANGE = os.getenv("REFLEXION_RANGE", "retrieval")
REFLEXION_BASELINE = os.getenv("REFLEXION_BASELINE", "TIME")
from app.core.memory.utils.config.get_data import get_data
from app.core.memory.utils.self_reflexion_utils.evaluate import conflict
from app.core.memory.utils.self_reflexion_utils.reflexion import reflexion
from app.db import get_db
from app.models.retrieval_info import RetrievalInfo
from app.repositories.neo4j.cypher_queries import UPDATE_STATEMENT_INVALID_AT
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from sqlalchemy.orm import Session
# Concurrency limit (overridable via environment variable)
CONCURRENCY = int(os.getenv("REFLEXION_CONCURRENCY", "5"))
# Ensure INFO-level logs reach the terminal
_root_logger = logging.getLogger()
if not _root_logger.handlers:
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
else:
_root_logger.setLevel(logging.INFO)
async def get_reflexion_data(host_id: uuid.UUID) -> List[Any]:
"""
Fetch the memory data to evaluate, according to the configured reflexion range.
Args:
host_id: Host ID.
Returns:
List of memory records within the reflexion range.
"""
if REFLEXION_RANGE == "partial":
return await get_data(host_id)
elif REFLEXION_RANGE == "all":
return []
else:
raise ValueError(f"未知的反思范围: {REFLEXION_RANGE}")
async def run_conflict(conflict_data: List[Any]) -> List[Any]:
"""
Check whether the reflexion data contains conflicts.
Args:
conflict_data: List of memory records to check.
Returns:
The list of conflicting memories if any exist, otherwise an empty list.
"""
if not conflict_data:
return []
conflict_data = await conflict(conflict_data)
# Keep only entries that actually conflict (conflict == True)
try:
return [c for c in conflict_data if isinstance(c, dict) and c.get("conflict") is True]
except Exception:
return []
async def run_reflexion(reflexion_data: List[Any]) -> Any:
"""
Run reflexion to resolve conflicts.
Args:
reflexion_data: List of conflicting memory records.
Returns:
The reflexion results (as returned by the LLM) with conflicts resolved.
"""
if not reflexion_data:
return []
# Reflect on each conflict in parallel to shorten the overall wait
sem = asyncio.Semaphore(CONCURRENCY)
async def _reflex_one(item: Any) -> Dict[str, Any] | None:
async with sem:
try:
result_list = await reflexion([item])
if not result_list:
return None
obj = result_list[0]
if hasattr(obj, "model_dump"):
return obj.model_dump()
elif hasattr(obj, "dict"):
return obj.dict()
elif isinstance(obj, dict):
return obj
except Exception as e:
logging.warning(f"反思失败,跳过一项: {e}")
return None
tasks = [_reflex_one(item) for item in reflexion_data]
results = await asyncio.gather(*tasks, return_exceptions=False)
return [r for r in results if r]
async def update_memory(solved_data: List[Any], host_id: uuid.UUID) -> str:
"""
Write the conflict-resolved memories back to the memory store.
Args:
solved_data: Resolved memories (as returned by the LLM).
host_id: Host ID.
Returns:
A result string indicating success or failure.
"""
flag = False
if not solved_data:
return "数据缺失,更新失败"
if not isinstance(solved_data, list):
return "数据格式错误,更新失败"
neo4j_connector = Neo4jConnector()
try:
print(f"====== 更新记忆开始 ======\n")
sem = asyncio.Semaphore(CONCURRENCY)
success_count = 0
async def _update_one(item: Dict[str, Any]) -> bool:
async with sem:
try:
if not isinstance(item, dict):
return False
if not item:
return False
resolved = item.get("resolved")
if not isinstance(resolved, dict) or not resolved:
logging.warning(f"反思结果无可更新内容,跳过此项: {item}")
return False
resolved_mem = resolved.get("resolved_memory")
if not isinstance(resolved_mem, dict) or not resolved_mem:
logging.warning(f"反思结果缺少 resolved_memory跳过此项: {item}")
return False
group_id = resolved_mem.get("group_id")
node_id = resolved_mem.get("id")  # renamed locally to avoid shadowing the builtin id()
# Use the invalid_at field as the new expiry time
new_invalid_at = resolved_mem.get("invalid_at")
if not all([group_id, node_id, new_invalid_at]):
logging.warning(f"Memory update parameters missing; skipping: {item}")
return False
await neo4j_connector.execute_query(
UPDATE_STATEMENT_INVALID_AT,
group_id=group_id,
id=node_id,
new_invalid_at=new_invalid_at,
)
return True
except Exception as e:
logging.error(f"更新单条记忆失败: {e}")
return False
tasks = [_update_one(item) for item in solved_data if isinstance(item, dict)]
results = await asyncio.gather(*tasks, return_exceptions=False)
success_count = sum(1 for r in results if r)
logging.info(f"成功更新 {success_count} 条记忆")
flag = success_count > 0
return "更新成功" if flag else "更新失败"
except Exception as e:
logging.error(f"更新记忆库失败: {e}")
return "更新失败"
finally:
if flag:  # remove the retrieval data from the database
db: Session = next(get_db())
try:
deleted_count = db.query(RetrievalInfo).filter(RetrievalInfo.host_id == host_id).delete()
db.commit()
logging.info(f"Deleted {deleted_count} retrieval records")
except Exception as e:
logging.error(f"Failed to delete retrieval data from the database: {e}")
finally:
db.close()
async def _append_json(label: str, data: Any) -> None:
"""记录冲突记忆(后台线程写入,避免阻塞事件循环)"""
def _write():
with open("reflexion_data.json", "a", encoding="utf-8") as f:
f.write(f"### {label} ###\n")
json.dump(data, f, ensure_ascii=False, indent=4)
f.write("\n\n")
# Await the background thread from inside the coroutine to avoid un-awaited coroutine warnings
await asyncio.to_thread(_write)
async def self_reflexion(host_id: uuid.UUID) -> str:
"""
Self-reflexion engine entry point; runs the full reflexion pipeline.
Args:
host_id: Host ID.
Returns:
A string describing the reflexion outcome.
"""
if not REFLEXION_ENABLED:
return "未开启反思..."
print(f"====== 自我反思流程开始 ======\n")
reflexion_data = await get_reflexion_data(host_id)
if not reflexion_data:
print(f"====== 自我反思流程结束 ======\n")
return "无反思数据,结束反思"
print(f"反思数据获取成功,共 {len(reflexion_data)}")
conflict_data = await run_conflict(reflexion_data)
if not conflict_data:
print(f"====== 自我反思流程结束 ======\n")
return "无冲突,无需反思"
print(f"冲突记忆类型: {type(conflict_data)}")
await _append_json("conflict", conflict_data)
solved_data = await run_reflexion(conflict_data)
if not solved_data:
print(f"====== 自我反思流程结束 ======\n")
return "反思失败,未解决冲突"
print(f"解决冲突后的记忆类型: {type(solved_data)}")
await _append_json("solved_data", solved_data)
result = await update_memory(solved_data, host_id)
print(f"更新记忆库结果: {result}")
print(f"====== 自我反思流程结束 ======\n")
return result
if __name__ == "__main__":
import asyncio
# host_id = uuid.UUID("3f6ff1eb-50c7-4765-8e89-e4566be33333")
host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122")
asyncio.run(self_reflexion(host_id))
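For context, the engine deleted above was gated entirely by environment variables read at import time; a minimal sketch of how it was previously driven (the host_id below is a placeholder):
import asyncio
import os
import uuid

# Flags are read at import time, so set them before importing the module.
os.environ["REFLEXION_ENABLED"] = "true"    # otherwise self_reflexion() returns immediately
os.environ["REFLEXION_RANGE"] = "partial"   # "partial" pulls data via get_data(); "all" yields no data
os.environ["REFLEXION_CONCURRENCY"] = "10"  # caps concurrent LLM calls and Neo4j updates

from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion

host_id = uuid.UUID("00000000-0000-0000-0000-000000000000")  # placeholder host ID
print(asyncio.run(self_reflexion(host_id)))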

View File

@@ -1,29 +1,30 @@
from typing import Any, Dict, List, Optional
import asyncio
import logging
from typing import Any, Dict, List, Optional
from app.repositories.neo4j.cypher_queries import (
CHUNK_EMBEDDING_SEARCH,
ENTITY_EMBEDDING_SEARCH,
MEMORY_SUMMARY_EMBEDDING_SEARCH,
SEARCH_CHUNK_BY_CHUNK_ID,
SEARCH_CHUNKS_BY_CONTENT,
SEARCH_DIALOGUE_BY_DIALOG_ID,
SEARCH_ENTITIES_BY_NAME,
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
SEARCH_STATEMENTS_BY_CREATED_AT,
SEARCH_STATEMENTS_BY_KEYWORD,
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
SEARCH_STATEMENTS_BY_TEMPORAL,
SEARCH_STATEMENTS_BY_VALID_AT,
SEARCH_STATEMENTS_G_CREATED_AT,
SEARCH_STATEMENTS_G_VALID_AT,
SEARCH_STATEMENTS_L_CREATED_AT,
SEARCH_STATEMENTS_L_VALID_AT,
STATEMENT_EMBEDDING_SEARCH,
)
# Use the new repository layer
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.cypher_queries import (
SEARCH_STATEMENTS_BY_KEYWORD,
SEARCH_ENTITIES_BY_NAME,
SEARCH_CHUNKS_BY_CONTENT,
STATEMENT_EMBEDDING_SEARCH,
CHUNK_EMBEDDING_SEARCH,
ENTITY_EMBEDDING_SEARCH,
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
MEMORY_SUMMARY_EMBEDDING_SEARCH,
SEARCH_STATEMENTS_BY_TEMPORAL,
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
SEARCH_DIALOGUE_BY_DIALOG_ID,
SEARCH_CHUNK_BY_CHUNK_ID,
SEARCH_STATEMENTS_BY_CREATED_AT,
SEARCH_STATEMENTS_BY_VALID_AT,
SEARCH_STATEMENTS_G_CREATED_AT,
SEARCH_STATEMENTS_L_CREATED_AT,
SEARCH_STATEMENTS_G_VALID_AT,
SEARCH_STATEMENTS_L_VALID_AT,
)
logger = logging.getLogger(__name__)
@@ -55,8 +56,12 @@ async def _update_activation_values_batch(
return []
# Import lazily to avoid circular dependencies
from app.core.memory.storage_services.forgetting_engine.access_history_manager import AccessHistoryManager
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
from app.core.memory.storage_services.forgetting_engine.access_history_manager import (
AccessHistoryManager,
)
from app.core.memory.storage_services.forgetting_engine.actr_calculator import (
ACTRCalculator,
)
# Create the calculator and manager instances
actr_calculator = ACTRCalculator()
@@ -292,6 +297,13 @@ async def search_graph(
else:
results[key] = result
# Deduplicate results before updating activation values
# This prevents duplicates from propagating through the pipeline
from app.core.memory.src.search import _deduplicate_results
for key in results:
if isinstance(results[key], list):
results[key] = _deduplicate_results(results[key])
# Update activation values for knowledge nodes (Statement, ExtractedEntity, MemorySummary)
results = await _update_search_results_activation(
connector=connector,
@@ -397,6 +409,13 @@ async def search_graph_by_embedding(
else:
results[key] = result
# Deduplicate results before updating activation values
# This prevents duplicates from propagating through the pipeline
from app.core.memory.src.search import _deduplicate_results
for key in results:
if isinstance(results[key], list):
results[key] = _deduplicate_results(results[key])
# Update activation values for knowledge nodes (Statement, ExtractedEntity, MemorySummary)
update_start = time.time()
results = await _update_search_results_activation(

View File

@@ -0,0 +1,13 @@
import yaml from 'js-yaml';
export const exportToYaml = (data: unknown, filename: string = 'export.yaml') => {
const yamlStr = yaml.dump(data);
const blob = new Blob([yamlStr], { type: 'text/yaml' });
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = filename;
a.click();
URL.revokeObjectURL(url);
};

View File

@@ -221,12 +221,12 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
return new Promise((resolve, reject) => {
saveAgentConfig(data.app_id, params)
.then(() => {
.then((res) => {
if (flag) {
message.success(t('common.saveSuccess'))
}
setIsSave(false)
resolve(true)
resolve(res)
}).catch(error => {
reject(error)
})

View File

@@ -58,16 +58,14 @@ const Cluster = forwardRef<ClusterRef>((_props, ref) => {
}))
}
console.log('params', params)
return new Promise((resolve, reject) => {
form.validateFields().then(() => {
saveMultiAgentConfig(id as string, params)
.then(() => {
.then((res) => {
if (flag) {
message.success(t('common.saveSuccess'))
}
resolve(true)
resolve(res)
})
.catch(error => {
reject(error)

View File

@@ -11,7 +11,7 @@ import exportIcon from '@/assets/images/export_hover.svg'
import deleteIcon from '@/assets/images/delete_hover.svg'
import type { Application, ApplicationModalRef } from '@/views/ApplicationManagement/types';
import ApplicationModal from '@/views/ApplicationManagement/components/ApplicationModal'
import type { CopyModalRef, WorkflowRef } from '../types'
import type { CopyModalRef, AgentRef, ClusterRef, WorkflowRef } from '../types'
import { deleteApplication } from '@/api/application'
import CopyModal from './CopyModal'
@@ -30,10 +30,11 @@ interface ConfigHeaderProps {
handleChangeTab: (key: string) => void;
refresh: () => void;
workflowRef: React.RefObject<WorkflowRef>
appRef?: React.RefObject<AgentRef | ClusterRef | WorkflowRef>
}
const ConfigHeader: FC<ConfigHeaderProps> = ({
application, activeTab, handleChangeTab, refresh,
workflowRef
workflowRef,
}) => {
const { t } = useTranslation();
const navigate = useNavigate();
@@ -48,7 +49,7 @@ const ConfigHeader: FC<ConfigHeaderProps> = ({
}))
}
const formatMenuItems = () => {
const items = ['edit', 'copy', 'delete'].map(key => ({
const items = ['edit', 'copy', 'export', 'delete'].map(key => ({
key,
icon: <img src={menuIcons[key]} className="rb:w-4 rb:h-4 rb:mr-2" />,
label: t(`common.${key}`),
@@ -59,7 +60,6 @@ const ConfigHeader: FC<ConfigHeaderProps> = ({
}
}
const handleClick: MenuProps['onClick'] = ({ key }) => {
console.log('key', key)
switch (key) {
case 'edit':
applicationModalRef.current?.handleOpen(application as Application)

View File

@@ -60,10 +60,11 @@ const ApplicationConfig: React.FC = () => {
handleChangeTab={handleChangeTab}
application={application as Application}
refresh={getApplicationInfo}
appRef={application?.type === 'agent' ? agentRef : application?.type === 'multi_agent' ? clusterRef : application?.type === 'workflow' ? workflowRef : undefined}
workflowRef={workflowRef}
/>
{activeTab === 'arrangement' && application?.type === 'agent' && <Agent ref={agentRef} />}
{activeTab === 'arrangement' && application?.type === 'multi_agent' && <Cluster ref={clusterRef} application={application as Application} />}
{activeTab === 'arrangement' && application?.type === 'multi_agent' && <Cluster ref={clusterRef} />}
{activeTab === 'arrangement' && application?.type === 'workflow' && <Workflow ref={workflowRef} />}
{activeTab === 'api' && <Api application={application} />}
{activeTab === 'release' && <ReleasePage data={application as Application} refresh={getApplicationInfo} />}

View File

@@ -2,7 +2,7 @@ import { useState } from 'react';
import { Popover } from 'antd';
import clsx from 'clsx';
import type { ReactShapeConfig } from '@antv/x6-react-shape';
import { nodeLibrary, graphNodeLibrary } from '../../constant';
import { nodeLibrary, graphNodeLibrary, edgeAttrs } from '../../constant';
import { useTranslation } from 'react-i18next';
const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
@@ -47,7 +47,7 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
graph.addEdge({
source: { cell: edge.getSourceCellId(), port: edge.getSourcePortId() },
target: { cell: newNode.id, port: newNode.getPorts().find((port: any) => port.group === 'left')?.id || 'left' },
attrs: edge.getAttrs(),
...edgeAttrs
});
});
@@ -57,7 +57,7 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
graph.addEdge({
source: { cell: newNode.id, port: newNode.getPorts().find((port: any) => port.group === 'right')?.id || 'right' },
target: { cell: edge.getTargetCellId(), port: targetPortId },
attrs: edge.getAttrs(),
...edgeAttrs
});
});

View File

@@ -2,9 +2,7 @@ import { useEffect } from 'react';
import { useTranslation } from 'react-i18next'
import clsx from 'clsx';
import type { ReactShapeConfig } from '@antv/x6-react-shape';
import { graphNodeLibrary } from '../../constant';
import { edge_color } from '../../hooks/useWorkflowGraph'
import { graphNodeLibrary, edgeAttrs } from '../../constant';
const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
const data = node.getData() || {};
@@ -56,16 +54,7 @@ const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
graph.addEdge({
source: { cell: cycleStartNode.id, port: sourcePort },
target: { cell: addNode.id, port: targetPort },
attrs: {
line: {
stroke: edge_color,
strokeWidth: 1,
targetMarker: {
name: 'block',
size: 8,
},
},
},
...edgeAttrs,
});
}
}
@@ -122,16 +111,7 @@ const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
cell: addNode.id,
port: targetPorts.find((port: any) => port.group === 'left')?.id || 'left'
},
attrs: {
line: {
stroke: edge_color,
strokeWidth: 1,
targetMarker: {
name: 'block',
size: 2,
},
},
},
...edgeAttrs
}
graph.addEdge(edgeConfig)
}

View File

@@ -1,7 +1,7 @@
import { useEffect, useState } from 'react';
import { Popover } from 'antd';
import { useTranslation } from 'react-i18next';
import { nodeLibrary, graphNodeLibrary } from '../constant';
import { nodeLibrary, graphNodeLibrary, edgeAttrs } from '../constant';
interface PortClickHandlerProps {
graph: any;
@@ -149,16 +149,7 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
graph.addEdge({
source: { cell: sourceNode.id, port: sourcePort },
target: { cell: newNode.id, port: targetPort },
attrs: {
line: {
stroke: '#155EEF',
strokeWidth: 1,
targetMarker: {
name: 'block',
size: 8,
},
},
},
...edgeAttrs
// zIndex: sourceNodeData.cycle && sourceNodeType == 'cycle-start' ? 1 : sourceNodeData.cycle ? 2 : 0
});

View File

@@ -6,6 +6,7 @@ import { Form, Button, Select, Space, Divider, InputNumber, Radio, type SelectPr
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
import VariableSelect from '../VariableSelect'
import Editor from '../../Editor'
import { edgeAttrs } from '../../../constant'
interface CaseListProps {
value?: Array<{ logical_operator: 'and' | 'or'; expressions: { left: string; operator: string; right: string; input_type?: string; }[] }>;
@@ -120,16 +121,7 @@ const CaseList: FC<CaseListProps> = ({
graphRef.current?.addEdge({
source: { cell: sourceCellId, port: sourcePortId },
target: { cell: selectedNode.id, port: targetPortId },
attrs: {
line: {
stroke: '#155EEF',
strokeWidth: 1,
targetMarker: {
name: 'block',
size: 8,
},
},
},
...edgeAttrs,
});
}
graphRef.current?.removeCell(edge);
@@ -174,16 +166,7 @@ const CaseList: FC<CaseListProps> = ({
graphRef.current?.addEdge({
source: { cell: selectedNode.id, port: newPortId },
target: { cell: targetCellId, port: targetPortId },
attrs: {
line: {
stroke: '#155EEF',
strokeWidth: 1,
targetMarker: {
name: 'block',
size: 8,
},
},
},
...edgeAttrs
});
}
}

View File

@@ -5,6 +5,7 @@ import { Graph, Node } from '@antv/x6';
import Editor from '../../Editor';
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
import { edgeAttrs } from '../../../constant'
interface CategoryListProps {
parentName: string;
@@ -70,16 +71,7 @@ const CategoryList: FC<CategoryListProps> = ({ parentName, selectedNode, graphRe
graphRef.current?.addEdge({
source: { cell: sourceCellId, port: sourcePortId },
target: { cell: selectedNode.id, port: targetPortId },
attrs: {
line: {
stroke: '#155EEF',
strokeWidth: 1,
targetMarker: {
name: 'block',
size: 8,
},
},
},
...edgeAttrs
});
}
return;
@@ -110,16 +102,7 @@ const CategoryList: FC<CategoryListProps> = ({ parentName, selectedNode, graphRe
graphRef.current?.addEdge({
source: { cell: selectedNode.id, port: newPortId },
target: { cell: targetCellId, port: targetPortId },
attrs: {
line: {
stroke: '#155EEF',
strokeWidth: 1,
targetMarker: {
name: 'block',
size: 8,
},
},
},
...edgeAttrs
});
}
}

View File

@@ -1,6 +1,6 @@
import { type FC } from 'react'
import { useTranslation } from 'react-i18next';
import { Form, Select, Input, Button } from 'antd'
import { Form, Select, Input, Button, InputNumber, Radio } from 'antd'
import VariableSelect from '../VariableSelect'
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
@@ -93,6 +93,7 @@ const CycleVarsList: FC<CycleVarsListProps> = ({
</Button>
</div>
{fields.map(({ key, name }, index) => {
const currentType = value?.[index]?.type;
const currentInputType = value?.[index]?.input_type;
return (
@@ -131,7 +132,8 @@ const CycleVarsList: FC<CycleVarsListProps> = ({
</div>
<Form.Item name={[name, 'value']} noStyle>
{currentInputType === 'variable' ? (
{currentInputType === 'variable'
? (
<VariableSelect
placeholder={t('common.pleaseSelect')}
options={availableOptions.filter(option => {
@@ -143,7 +145,20 @@ const CycleVarsList: FC<CycleVarsListProps> = ({
variant="borderless"
size="small"
/>
) : (
)
: currentType === 'number'
? <InputNumber
placeholder={t('common.pleaseEnter')}
variant="borderless"
className="rb:w-full! rb:my-1!"
onChange={(value) => form.setFieldValue([name, 'value'], value)}
/>
: currentType === 'boolean'
? <Radio.Group block>
<Radio.Button value={true}>True</Radio.Button>
<Radio.Button value={false}>False</Radio.Button>
</Radio.Group>
: (
<Input.TextArea
placeholder={t('common.pleaseEnter')}
rows={3}

View File

@@ -517,6 +517,8 @@ interface NodeConfig {
ports?: PortsConfig;
}
export const edge_color = '#155EEF';
export const edge_selected_color = '#4DA8FF'
// Shared port markup configuration
export const portMarkup = [
{
@@ -534,9 +536,9 @@ export const portAttrs = {
body: {
r: 6,
magnet: true,
stroke: '#155EEF',
stroke: edge_color,
strokeWidth: 2,
fill: '#155EEF',
fill: edge_color,
},
label: {
text: '+',
@@ -776,4 +778,18 @@ export const outputVariable: { [key: string]: OutputVariable } = {
{ name: "output", type: "string" },
],
},
}
export const edgeAttrs = {
attrs: {
line: {
stroke: edge_color,
strokeWidth: 1,
targetMarker: {
name: 'block',
width: 4,
height: 4,
},
},
},
}

View File

@@ -5,7 +5,7 @@ import { App } from 'antd'
import { Graph, Node, MiniMap, Snapline, Clipboard, Keyboard, type Edge } from '@antv/x6';
import { register } from '@antv/x6-react-shape';
import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs } from '../constant';
import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs, edgeAttrs, edge_color, edge_selected_color } from '../constant';
import type { WorkflowConfig, NodeProperties, ChatVariable } from '../types';
import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application'
import type { PortMetadata } from '@antv/x6/lib/model/port';
@@ -23,12 +23,8 @@ export interface UseWorkflowGraphReturn {
setSelectedNode: React.Dispatch<React.SetStateAction<Node | null>>;
zoomLevel: number;
setZoomLevel: React.Dispatch<React.SetStateAction<number>>;
canUndo: boolean;
canRedo: boolean;
isHandMode: boolean;
setIsHandMode: React.Dispatch<React.SetStateAction<boolean>>;
onUndo: () => void;
onRedo: () => void;
onDrop: (event: React.DragEvent) => void;
blankClick: () => void;
deleteEvent: () => boolean | void;
@@ -39,8 +35,6 @@ export interface UseWorkflowGraphReturn {
setChatVariables: React.Dispatch<React.SetStateAction<ChatVariable[]>>;
}
export const edge_color = '#155EEF';
const edge_selected_color = '#4DA8FF'
export const useWorkflowGraph = ({
containerRef,
miniMapRef,
@@ -51,9 +45,6 @@ export const useWorkflowGraph = ({
const graphRef = useRef<Graph>();
const [selectedNode, setSelectedNode] = useState<Node | null>(null);
const [zoomLevel, setZoomLevel] = useState(1);
const historyRef = useRef<{ undoStack: string[], redoStack: string[] }>({ undoStack: [], redoStack: [] });
const [canUndo, setCanUndo] = useState(false);
const [canRedo, setCanRedo] = useState(false);
const [isHandMode, setIsHandMode] = useState(true);
const [config, setConfig] = useState<WorkflowConfig | null>(null);
const [chatVariables, setChatVariables] = useState<ChatVariable[]>([])
@@ -338,17 +329,7 @@ export const useWorkflowGraph = ({
port: targetPorts.find((port: any) => port.group === 'left')?.id || 'left'
},
connector: { name: 'smooth' },
attrs: {
line: {
stroke: edge_color,
strokeWidth: 1,
targetMarker: {
name: 'diamond',
width: 4,
height: 4,
},
},
},
...edgeAttrs
// zIndex: loopIterationCount
}
@@ -368,48 +349,6 @@ export const useWorkflowGraph = ({
}, 200)
}
}
const saveState = () => {
if (!graphRef.current) return;
const state = JSON.stringify(graphRef.current.toJSON());
historyRef.current.undoStack.push(state);
historyRef.current.redoStack = [];
if (historyRef.current.undoStack.length > 50) {
historyRef.current.undoStack.shift();
}
updateHistoryState();
};
const updateHistoryState = () => {
setCanUndo(historyRef.current.undoStack.length > 1);
setCanRedo(historyRef.current.redoStack.length > 0);
};
// Undo
const onUndo = () => {
if (!graphRef.current || historyRef.current.undoStack.length === 0) return;
const { undoStack = [], redoStack = [] } = historyRef.current
const currentState = JSON.stringify(graphRef.current.toJSON());
const prevState = undoStack[undoStack.length - 2];
historyRef.current.redoStack = [...redoStack, currentState]
historyRef.current.undoStack = undoStack.slice(0, undoStack.length - 1)
graphRef.current.fromJSON(JSON.parse(prevState));
updateHistoryState();
};
// Redo
const onRedo = () => {
if (!graphRef.current || historyRef.current.redoStack.length === 0) return;
const { undoStack = [], redoStack = [] } = historyRef.current
const nextState = redoStack[redoStack.length - 1];
historyRef.current.undoStack = [...undoStack, nextState]
historyRef.current.redoStack = redoStack.slice(0, redoStack.length - 1)
graphRef.current.fromJSON(JSON.parse(nextState));
updateHistoryState();
};
// Set up plugins
const setupPlugins = () => {
if (!graphRef.current || !miniMapRef.current) return;
@@ -563,20 +502,6 @@ export const useWorkflowGraph = ({
}
return false;
};
// Undo keyboard shortcut handler
const undoEvent = () => {
if (canUndo) {
onUndo();
}
return false;
};
// Redo keyboard shortcut handler
const redoEvent = () => {
if (canRedo) {
onRedo();
}
return false;
};
// Handler for deleting the selected nodes and edges
const deleteEvent = () => {
if (!graphRef.current) return;
@@ -748,8 +673,6 @@ export const useWorkflowGraph = ({
background: {
color: '#F0F3F8',
},
// width: container.clientWidth || 800,
// height: container.clientHeight || 600,
autoResize: true,
grid: {
visible: true,
@@ -765,37 +688,26 @@ export const useWorkflowGraph = ({
enabled: true,
},
connecting: {
// router: 'orth',
// router: 'manhattan',
connector: {
name: 'smooth',
args: {
radius: 8,
},
},
anchor: 'center',
anchor: 'midSide',
connectionPoint: 'anchor',
allowBlank: false,
allowLoop: false,
allowNode: false,
allowEdge: false,
allowPort: true,
allowMulti: true,
highlight: true,
snap: {
radius: 20,
},
createEdge() {
return graphRef.current?.createEdge({
attrs: {
line: {
stroke: edge_color,
strokeWidth: 1,
targetMarker: {
name: 'diamond',
width: 4,
height: 4,
},
},
},
});
return graphRef.current?.createEdge(edgeAttrs);
},
validateConnection({ sourceCell, targetCell, targetMagnet }) {
if (!targetMagnet) return false;
@@ -901,27 +813,7 @@ export const useWorkflowGraph = ({
// Listen for zoom events
graphRef.current.on('scale', scaleEvent);
// Listen for node move events
// graphRef.current.on('node:moved', nodeMoved);
graphRef.current.on('node:change:position', nodeChangePosition);
// Listen for canvas change events
const events = [
'node:added',
'node:removed',
'edge:added',
'edge:removed',
];
events.forEach(event => {
graphRef.current!.on(event, () => {
console.log('event', event);
setTimeout(() => saveState(), 50);
});
});
// Listen for the undo keyboard shortcut
graphRef.current.bindKey(['ctrl+z', 'cmd+z'], undoEvent);
// Listen for the redo keyboard shortcut
graphRef.current.bindKey(['ctrl+shift+z', 'cmd+shift+z', 'ctrl+y', 'cmd+y'], redoEvent);
graphRef.current.on('node:moved', nodeMoved);
// Listen for the copy keyboard shortcut
graphRef.current.bindKey(['ctrl+c', 'cmd+c'], copyEvent);
// Listen for the paste keyboard shortcut
@@ -929,11 +821,6 @@ export const useWorkflowGraph = ({
// Delete the selected nodes and edges
graphRef.current.bindKey(['ctrl+d', 'cmd+d', 'delete', 'backspace'], deleteEvent);
// Save the initial state
setTimeout(() => saveState(), 100);
// init window hook
(window as Window & { __x6_instances__?: Graph[] }).__x6_instances__ = [];
(window as Window & { __x6_instances__?: Graph[] }).__x6_instances__?.push(graphRef.current);
};
useEffect(() => {
@@ -1146,11 +1033,11 @@ export const useWorkflowGraph = ({
}),
}
saveWorkflowConfig(config.app_id, params as WorkflowConfig)
.then(() => {
.then((res) => {
if (flag) {
message.success(t('common.saveSuccess'))
}
resolve(true)
resolve(res)
}).catch(error => {
reject(error)
})
@@ -1165,12 +1052,8 @@ export const useWorkflowGraph = ({
setSelectedNode,
zoomLevel,
setZoomLevel,
canUndo,
canRedo,
isHandMode,
setIsHandMode,
onUndo,
onRedo,
onDrop,
blankClick,
deleteEvent,