Merge remote-tracking branch 'origin/develop' into develop
# Conflicts: # api/app/services/memory_reflection_service.py
This commit is contained in:
@@ -1,10 +1,8 @@
|
||||
import os
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.utils.self_reflexion_utils import self_reflexion
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
@@ -458,18 +456,3 @@ async def get_recent_activity_stats_api(
|
||||
api_logger.error(f"Recent activity stats failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get("/self_reflexion")
|
||||
async def self_reflexion_endpoint(host_id: uuid.UUID) -> str:
|
||||
"""
|
||||
自我反思接口,自动对检索出的信息进行自我反思并返回自我反思结果。
|
||||
|
||||
Args:
|
||||
None
|
||||
Returns:
|
||||
自我反思结果。
|
||||
"""
|
||||
return await self_reflexion(host_id)
|
||||
|
||||
|
||||
@@ -1,48 +1,15 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Tuple
|
||||
|
||||
from neo4j import GraphDatabase
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# ------------------- 自包含路径解析 -------------------
|
||||
# 这个代码块确保脚本可以从任何地方运行,并且仍然可以在项目结构中找到它需要的模块。
|
||||
try:
|
||||
# 假设脚本在 /path/to/project/src/analytics/
|
||||
# 上升3个级别以到达项目根目录。
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
|
||||
src_path = os.path.join(project_root, 'src')
|
||||
|
||||
# 将 'src' 和 'project_root' 都添加到路径中。
|
||||
# 'src' 目录对于像 'from utils.config_utils import ...' 这样的导入是必需的。
|
||||
# 'project_root' 目录对于像 'from variate_config import ...' 这样的导入是必需的。
|
||||
if src_path not in sys.path:
|
||||
sys.path.insert(0, src_path)
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
except NameError:
|
||||
# 为 __file__ 未定义的环境(例如某些交互式解释器)提供回退方案
|
||||
project_root = os.path.abspath(os.path.join(os.getcwd()))
|
||||
src_path = os.path.join(project_root, 'src')
|
||||
if src_path not in sys.path:
|
||||
sys.path.insert(0, src_path)
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
# 现在路径已经配置好,我们可以使用绝对导入
|
||||
import json
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
#TODO: Fix this
|
||||
# Default values (previously from definitions.py)
|
||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
|
||||
|
||||
# 定义用于LLM结构化输出的Pydantic模型
|
||||
class FilteredTags(BaseModel):
|
||||
@@ -52,34 +19,45 @@ class FilteredTags(BaseModel):
|
||||
async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
||||
"""
|
||||
使用LLM筛选标签列表,仅保留具有代表性的核心名词。
|
||||
|
||||
Args:
|
||||
tags: 原始标签列表
|
||||
group_id: 用户组ID,用于获取配置
|
||||
|
||||
Returns:
|
||||
筛选后的标签列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果无法获取有效的LLM配置
|
||||
"""
|
||||
try:
|
||||
# Get config_id using get_end_user_connected_config
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if not config_id:
|
||||
raise ValueError(
|
||||
f"No memory_config_id found for group_id: {group_id}. "
|
||||
"Please ensure the user has a valid memory configuration."
|
||||
)
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id:
|
||||
# Use the config_id to get the proper LLM client
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||
else:
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM if no config found
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"Failed to get user connected config, using default LLM: {e}")
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
# Use the config_id to get the proper LLM client
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
|
||||
if not memory_config.llm_model_id:
|
||||
raise ValueError(
|
||||
f"No llm_model_id found in memory config {config_id}. "
|
||||
"Please configure a valid LLM model."
|
||||
)
|
||||
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||
|
||||
# 3. 构建Prompt
|
||||
tag_list_str = ", ".join(tags)
|
||||
@@ -107,33 +85,26 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
||||
# 在LLM失败时返回原始标签,确保流程继续
|
||||
return tags
|
||||
|
||||
def get_db_connection():
|
||||
"""
|
||||
使用项目的标准配置方法建立与Neo4j数据库的连接。
|
||||
"""
|
||||
# 从全局配置获取 Neo4j 连接信息
|
||||
uri = settings.NEO4J_URI
|
||||
user = settings.NEO4J_USERNAME
|
||||
|
||||
# 密码必须为了安全从环境变量加载
|
||||
password = os.getenv("NEO4J_PASSWORD")
|
||||
|
||||
if not uri or not user:
|
||||
raise ValueError("在 config.json 中未找到 Neo4j 的 'uri' 或 'username'。")
|
||||
if not password:
|
||||
raise ValueError("NEO4J_PASSWORD 环境变量未设置。")
|
||||
|
||||
# 为此脚本使用同步驱动
|
||||
return GraphDatabase.driver(uri, auth=(user, password))
|
||||
|
||||
def get_raw_tags_from_db(group_id: str, limit: int, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||
async def get_raw_tags_from_db(
|
||||
connector: Neo4jConnector,
|
||||
group_id: str,
|
||||
limit: int,
|
||||
by_user: bool = False
|
||||
) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
TODO: not accurate tag extraction
|
||||
从数据库查询原始的、未经过滤的实体标签及其频率。
|
||||
|
||||
使用项目的Neo4jConnector进行查询,遵循仓储模式。
|
||||
|
||||
Args:
|
||||
connector: Neo4j连接器实例
|
||||
group_id: 如果by_user=False,则为group_id;如果by_user=True,则为user_id
|
||||
limit: 返回的标签数量限制
|
||||
by_user: 是否按user_id查询(默认False,按group_id查询)
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, int]]: 标签名称和频率的元组列表
|
||||
"""
|
||||
names_to_exclude = ['AI', 'Caroline', 'Melanie', 'Jon', 'Gina', '用户', 'AI助手', 'John', 'Maria']
|
||||
|
||||
@@ -154,83 +125,55 @@ def get_raw_tags_from_db(group_id: str, limit: int, by_user: bool = False) -> Li
|
||||
"LIMIT $limit"
|
||||
)
|
||||
|
||||
driver = None
|
||||
try:
|
||||
driver = get_db_connection()
|
||||
with driver.session() as session:
|
||||
result = session.run(query, id=group_id, limit=limit, names_to_exclude=names_to_exclude)
|
||||
return [(record["name"], record["frequency"]) for record in result]
|
||||
finally:
|
||||
if driver:
|
||||
driver.close()
|
||||
# 使用项目的Neo4jConnector执行查询
|
||||
results = await connector.execute_query(
|
||||
query,
|
||||
id=group_id,
|
||||
limit=limit,
|
||||
names_to_exclude=names_to_exclude
|
||||
)
|
||||
|
||||
return [(record["name"], record["frequency"]) for record in results]
|
||||
|
||||
async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||
async def get_hot_memory_tags(group_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。
|
||||
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
|
||||
|
||||
Args:
|
||||
group_id: 如果by_user=False,则为group_id;如果by_user=True,则为user_id
|
||||
group_id: 必需参数。如果by_user=False,则为group_id;如果by_user=True,则为user_id
|
||||
limit: 返回的标签数量限制
|
||||
by_user: 是否按user_id查询(默认False,按group_id查询)
|
||||
|
||||
Raises:
|
||||
ValueError: 如果group_id未提供或为空
|
||||
"""
|
||||
# 默认从环境变量读取
|
||||
group_id = group_id or DEFAULT_GROUP_ID
|
||||
# 1. 从数据库获取原始排名靠前的标签
|
||||
raw_tags_with_freq = get_raw_tags_from_db(group_id, limit, by_user=by_user)
|
||||
if not raw_tags_with_freq:
|
||||
return []
|
||||
|
||||
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
|
||||
|
||||
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
|
||||
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id)
|
||||
|
||||
# 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序)
|
||||
final_tags = []
|
||||
for tag, freq in raw_tags_with_freq:
|
||||
if tag in meaningful_tag_names:
|
||||
final_tags.append((tag, freq))
|
||||
|
||||
return final_tags
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("开始获取热门记忆标签...")
|
||||
# 验证group_id必须提供且不为空
|
||||
if not group_id or not group_id.strip():
|
||||
raise ValueError(
|
||||
"group_id is required. Please provide a valid group_id or user_id."
|
||||
)
|
||||
|
||||
# 使用项目的Neo4jConnector
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
# 直接使用环境变量中的 group_id
|
||||
group_id_to_query = DEFAULT_GROUP_ID
|
||||
# 使用 asyncio.run 来执行异步主函数
|
||||
top_tags = asyncio.run(get_hot_memory_tags(group_id=group_id_to_query))
|
||||
# 1. 从数据库获取原始排名靠前的标签
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, group_id, limit, by_user=by_user)
|
||||
if not raw_tags_with_freq:
|
||||
return []
|
||||
|
||||
if top_tags:
|
||||
print(f"热门记忆标签 (Group ID: {group_id_to_query}, 经LLM筛选):")
|
||||
for tag, frequency in top_tags:
|
||||
print(f"- {tag} (数量: {frequency})")
|
||||
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
|
||||
|
||||
# --- 将结果写入统一的 Signboard.json 到 logs/memory-output ---
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
signboard_path = settings.get_memory_output_path("Signboard.json")
|
||||
payload = {
|
||||
"group_id": group_id_to_query,
|
||||
"hot_tags": [{"name": t, "frequency": f} for t, f in top_tags]
|
||||
}
|
||||
try:
|
||||
existing = {}
|
||||
if os.path.exists(signboard_path):
|
||||
with open(signboard_path, "r", encoding="utf-8") as rf:
|
||||
existing = json.load(rf)
|
||||
existing["hot_memory_tags"] = payload
|
||||
with open(signboard_path, "w", encoding="utf-8") as wf:
|
||||
json.dump(existing, wf, ensure_ascii=False, indent=2)
|
||||
print(f"已写入 {signboard_path} -> hot_memory_tags")
|
||||
except Exception as e:
|
||||
print(f"写入 Signboard.json 失败: {e}")
|
||||
else:
|
||||
print(f"在 Group ID '{group_id_to_query}' 中没有找到符合条件的实体标签。")
|
||||
except Exception as e:
|
||||
print(f"执行过程中发生严重错误: {e}")
|
||||
print("请检查:")
|
||||
print("1. Neo4j数据库服务是否正在运行。")
|
||||
print("2. 'config.json'中的配置是否正确。")
|
||||
print("3. 相关的环境变量 (如 NEO4J_PASSWORD, DASHSCOPE_API_KEY) 是否已正确设置。")
|
||||
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
|
||||
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id)
|
||||
|
||||
# 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序)
|
||||
final_tags = []
|
||||
for tag, freq in raw_tags_with_freq:
|
||||
if tag in meaningful_tag_names:
|
||||
final_tags.append((tag, freq))
|
||||
|
||||
return final_tags
|
||||
finally:
|
||||
# 确保关闭连接
|
||||
await connector.close()
|
||||
|
||||
@@ -131,179 +131,60 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
||||
return results
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 以下函数已被 rerank_with_activation 替代,暂时保留以供参考
|
||||
# ============================================================================
|
||||
|
||||
# def rerank_hybrid_results(
|
||||
# keyword_results: Dict[str, List[Dict[str, Any]]],
|
||||
# embedding_results: Dict[str, List[Dict[str, Any]]],
|
||||
# alpha: float = 0.6,
|
||||
# limit: int = 10
|
||||
# ) -> Dict[str, List[Dict[str, Any]]]:
|
||||
# """
|
||||
# Rerank hybrid search results by combining BM25 and embedding scores.
|
||||
#
|
||||
# 已废弃:此函数功能已被 rerank_with_activation 完全替代
|
||||
#
|
||||
# Args:
|
||||
# keyword_results: Results from keyword/BM25 search
|
||||
# embedding_results: Results from embedding search
|
||||
# alpha: Weight for BM25 scores (1-alpha for embedding scores)
|
||||
# limit: Maximum number of results to return per category
|
||||
#
|
||||
# Returns:
|
||||
# Reranked results with combined scores
|
||||
# """
|
||||
# reranked = {}
|
||||
#
|
||||
# for category in ["statements", "chunks", "entities","summaries"]:
|
||||
# keyword_items = keyword_results.get(category, [])
|
||||
# embedding_items = embedding_results.get(category, [])
|
||||
#
|
||||
# # Normalize scores within each search type
|
||||
# keyword_items = normalize_scores(keyword_items, "score")
|
||||
# embedding_items = normalize_scores(embedding_items, "score")
|
||||
#
|
||||
# # Create a combined pool of unique items
|
||||
# combined_items = {}
|
||||
#
|
||||
# # Add keyword results with BM25 scores
|
||||
# for item in keyword_items:
|
||||
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
# if item_id:
|
||||
# combined_items[item_id] = item.copy()
|
||||
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
# combined_items[item_id]["embedding_score"] = 0 # Default
|
||||
#
|
||||
# # Add or update with embedding results
|
||||
# for item in embedding_items:
|
||||
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
# if item_id:
|
||||
# if item_id in combined_items:
|
||||
# # Update existing item with embedding score
|
||||
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
# else:
|
||||
# # New item from embedding search only
|
||||
# combined_items[item_id] = item.copy()
|
||||
# combined_items[item_id]["bm25_score"] = 0 # Default
|
||||
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
#
|
||||
# # Calculate combined scores and rank
|
||||
# for item_id, item in combined_items.items():
|
||||
# bm25_score = item.get("bm25_score", 0)
|
||||
# embedding_score = item.get("embedding_score", 0)
|
||||
#
|
||||
# # Combined score: weighted average of normalized scores
|
||||
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
|
||||
# item["combined_score"] = combined_score
|
||||
#
|
||||
# # Keep original score for reference
|
||||
# if "score" not in item and bm25_score > 0:
|
||||
# item["score"] = bm25_score
|
||||
# elif "score" not in item and embedding_score > 0:
|
||||
# item["score"] = embedding_score
|
||||
#
|
||||
# # Sort by combined score and limit results
|
||||
# sorted_items = sorted(
|
||||
# combined_items.values(),
|
||||
# key=lambda x: x.get("combined_score", 0),
|
||||
# reverse=True
|
||||
# )[:limit]
|
||||
#
|
||||
# reranked[category] = sorted_items
|
||||
#
|
||||
# return reranked
|
||||
|
||||
# def rerank_with_forgetting_curve(
|
||||
# keyword_results: Dict[str, List[Dict[str, Any]]],
|
||||
# embedding_results: Dict[str, List[Dict[str, Any]]],
|
||||
# alpha: float = 0.6,
|
||||
# limit: int = 10,
|
||||
# forgetting_config: ForgettingEngineConfig | None = None,
|
||||
# now: datetime | None = None,
|
||||
# ) -> Dict[str, List[Dict[str, Any]]]:
|
||||
# """
|
||||
# Rerank hybrid results with a forgetting curve applied to combined scores.
|
||||
#
|
||||
# 已废弃:此函数功能已被 rerank_with_activation 完全替代
|
||||
# rerank_with_activation 提供了更完整的遗忘曲线支持(结合激活度)
|
||||
#
|
||||
# The forgetting curve reduces scores for older memories or weaker connections.
|
||||
#
|
||||
# Args:
|
||||
# keyword_results: Results from keyword/BM25 search
|
||||
# embedding_results: Results from embedding search
|
||||
# alpha: Weight for BM25 scores (1-alpha for embedding scores)
|
||||
# limit: Maximum number of results to return per category
|
||||
# forgetting_config: Configuration for the forgetting engine
|
||||
# now: Optional current time override for testing
|
||||
#
|
||||
# Returns:
|
||||
# Reranked results with combined and final scores (after forgetting)
|
||||
# """
|
||||
# engine = ForgettingEngine(forgetting_config or ForgettingEngineConfig())
|
||||
# now_dt = now or datetime.now()
|
||||
#
|
||||
# reranked: Dict[str, List[Dict[str, Any]]] = {}
|
||||
#
|
||||
# for category in ["statements", "chunks", "entities","summaries"]:
|
||||
# keyword_items = keyword_results.get(category, [])
|
||||
# embedding_items = embedding_results.get(category, [])
|
||||
#
|
||||
# # Normalize scores within each search type
|
||||
# keyword_items = normalize_scores(keyword_items, "score")
|
||||
# embedding_items = normalize_scores(embedding_items, "score")
|
||||
#
|
||||
# combined_items: Dict[str, Dict[str, Any]] = {}
|
||||
#
|
||||
# # Combine two result sets by ID
|
||||
# for src_items, is_embedding in (
|
||||
# (keyword_items, False), (embedding_items, True)
|
||||
# ):
|
||||
# for item in src_items:
|
||||
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
# if not item_id:
|
||||
# continue
|
||||
# existing = combined_items.get(item_id)
|
||||
# if not existing:
|
||||
# combined_items[item_id] = item.copy()
|
||||
# combined_items[item_id]["bm25_score"] = 0
|
||||
# combined_items[item_id]["embedding_score"] = 0
|
||||
# # Update normalized score from the right source
|
||||
# if is_embedding:
|
||||
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
# else:
|
||||
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
#
|
||||
# # Calculate scores and apply forgetting weights
|
||||
# for item_id, item in combined_items.items():
|
||||
# bm25_score = float(item.get("bm25_score", 0) or 0)
|
||||
# embedding_score = float(item.get("embedding_score", 0) or 0)
|
||||
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
|
||||
#
|
||||
# # Estimate time elapsed in days
|
||||
# dt = _parse_datetime(item.get("created_at"))
|
||||
# if dt is None:
|
||||
# time_elapsed_days = 0.0
|
||||
# else:
|
||||
# time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
|
||||
#
|
||||
# # Memory strength (currently set to default value)
|
||||
# memory_strength = 1.0
|
||||
# forgetting_weight = engine.calculate_weight(
|
||||
# time_elapsed=time_elapsed_days, memory_strength=memory_strength
|
||||
# )
|
||||
# final_score = combined_score * forgetting_weight
|
||||
# item["combined_score"] = final_score
|
||||
#
|
||||
# sorted_items = sorted(
|
||||
# combined_items.values(), key=lambda x: x.get("combined_score", 0), reverse=True
|
||||
# )[:limit]
|
||||
#
|
||||
# reranked[category] = sorted_items
|
||||
#
|
||||
# return reranked
|
||||
def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Remove duplicate items from search results based on content.
|
||||
|
||||
Deduplication strategy:
|
||||
1. First try to deduplicate by ID (id, uuid, or chunk_id)
|
||||
2. Then deduplicate by content hash (text, content, statement, or name fields)
|
||||
|
||||
Args:
|
||||
items: List of search result items
|
||||
|
||||
Returns:
|
||||
Deduplicated list of items, preserving the order of first occurrence
|
||||
"""
|
||||
seen_ids = set()
|
||||
seen_content = set()
|
||||
deduplicated = []
|
||||
|
||||
for item in items:
|
||||
# Try multiple ID fields to identify unique items
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
|
||||
# Extract content from various possible fields
|
||||
content = (
|
||||
item.get("text") or
|
||||
item.get("content") or
|
||||
item.get("statement") or
|
||||
item.get("name") or
|
||||
""
|
||||
)
|
||||
|
||||
# Normalize content for comparison (strip whitespace and lowercase)
|
||||
normalized_content = str(content).strip().lower() if content else ""
|
||||
|
||||
# Check if we've seen this ID or content before
|
||||
is_duplicate = False
|
||||
|
||||
if item_id and item_id in seen_ids:
|
||||
is_duplicate = True
|
||||
elif normalized_content and normalized_content in seen_content:
|
||||
# Only check content duplication if content is not empty
|
||||
is_duplicate = True
|
||||
|
||||
if not is_duplicate:
|
||||
# Mark as seen
|
||||
if item_id:
|
||||
seen_ids.add(item_id)
|
||||
if normalized_content: # Only track non-empty content
|
||||
seen_content.add(normalized_content)
|
||||
|
||||
deduplicated.append(item)
|
||||
|
||||
return deduplicated
|
||||
|
||||
|
||||
def rerank_with_activation(
|
||||
@@ -364,7 +245,7 @@ def rerank_with_activation(
|
||||
keyword_items = normalize_scores(keyword_items, "score")
|
||||
embedding_items = normalize_scores(embedding_items, "score")
|
||||
|
||||
# 步骤 2: 按 ID 合并结果
|
||||
# 步骤 2: 按 ID 合并结果(去重)
|
||||
combined_items: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# 添加关键词结果
|
||||
@@ -507,6 +388,9 @@ def rerank_with_activation(
|
||||
# 无激活值:使用内容相关性分数
|
||||
item["final_score"] = item.get("base_score", 0)
|
||||
|
||||
# 最终去重确保没有重复项
|
||||
sorted_items = _deduplicate_results(sorted_items)
|
||||
|
||||
reranked[category] = sorted_items
|
||||
|
||||
return reranked
|
||||
@@ -1144,96 +1028,3 @@ async def search_chunk_by_chunk_id(
|
||||
)
|
||||
return {"chunks": chunks}
|
||||
|
||||
|
||||
# def main():
|
||||
# """Main entry point for the hybrid graph search CLI.
|
||||
|
||||
# Parses command line arguments and executes search with specified parameters.
|
||||
# Supports keyword, embedding, and hybrid search modes.
|
||||
# """
|
||||
# parser = argparse.ArgumentParser(description="Hybrid graph search with keyword and embedding options")
|
||||
# parser.add_argument(
|
||||
# "--query", "-q", required=True, help="Free-text query to search"
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--search-type",
|
||||
# "-t",
|
||||
# choices=["keyword", "embedding", "hybrid"],
|
||||
# default="hybrid",
|
||||
# help="Search type: keyword (text matching), embedding (semantic), or hybrid (both) (default: hybrid)"
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--config-id",
|
||||
# "-c",
|
||||
# type=int,
|
||||
# required=True,
|
||||
# help="Database configuration ID (required)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--group-id",
|
||||
# "-g",
|
||||
# default=None,
|
||||
# help="Optional group_id to filter results (default: None)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--limit",
|
||||
# "-k",
|
||||
# type=int,
|
||||
# default=5,
|
||||
# help="Max number of results per type (default: 5)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--include",
|
||||
# "-i",
|
||||
# nargs="+",
|
||||
# default=["statements", "chunks", "entities", "summaries"],
|
||||
# choices=["statements", "chunks", "entities", "summaries"],
|
||||
# help="Which targets to search for embedding search (default: statements chunks entities summaries)"
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--output",
|
||||
# "-o",
|
||||
# default="search_results.json",
|
||||
# help="Path to save the search results JSON (default: search_results.json)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--rerank-alpha",
|
||||
# "-a",
|
||||
# type=float,
|
||||
# default=0.6,
|
||||
# help="Weight for BM25 scores in reranking (0.0-1.0, higher values favor keyword search) (default: 0.6)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--forgetting-rerank",
|
||||
# action="store_true",
|
||||
# help="Apply forgetting curve during reranking for hybrid search.",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--llm-rerank",
|
||||
# action="store_true",
|
||||
# help="Apply LLM-based reranking for hybrid search.",
|
||||
# )
|
||||
# args = parser.parse_args()
|
||||
|
||||
# # Load memory config from database
|
||||
# from app.services.memory_config_service import MemoryConfigService
|
||||
# memory_config = MemoryConfigService.load_memory_config(args.config_id)
|
||||
|
||||
# asyncio.run(
|
||||
# run_hybrid_search(
|
||||
# query_text=args.query,
|
||||
# search_type=args.search_type,
|
||||
# group_id=args.group_id,
|
||||
# limit=args.limit,
|
||||
# include=args.include,
|
||||
# output_path=args.output,
|
||||
# memory_config=memory_config,
|
||||
# rerank_alpha=args.rerank_alpha,
|
||||
# use_forgetting_rerank=args.forgetting_rerank,
|
||||
# use_llm_rerank=args.llm_rerank,
|
||||
# )
|
||||
# )
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# main()
|
||||
|
||||
@@ -18,13 +18,11 @@ from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.memory.utils.config.get_data import (
|
||||
extract_and_process_changes,
|
||||
get_data,
|
||||
get_data_statement,
|
||||
)
|
||||
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.repositories.neo4j.cypher_queries import (
|
||||
neo4j_query_all,
|
||||
|
||||
@@ -14,28 +14,8 @@ from .config_utils import (
|
||||
get_pruning_config,
|
||||
get_voice_config,
|
||||
)
|
||||
|
||||
# DEPRECATED: Global configuration variables removed
|
||||
# Use MemoryConfig objects with dependency injection instead
|
||||
# from .definitions import (
|
||||
# CONFIG, # DEPRECATED - empty dict for backward compatibility
|
||||
# RUNTIME_CONFIG, # DEPRECATED - minimal for backward compatibility
|
||||
# PROJECT_ROOT, # Still needed for file paths
|
||||
# reload_configuration_from_database, # DEPRECATED - returns False
|
||||
# )
|
||||
# DEPRECATED: overrides module removed - use MemoryConfig with dependency injection
|
||||
from .get_data import get_data
|
||||
|
||||
# litellm_config 需要时动态导入,避免循环依赖
|
||||
# from .litellm_config import (
|
||||
# LiteLLMConfig,
|
||||
# setup_litellm_enhanced,
|
||||
# get_usage_summary,
|
||||
# print_usage_summary,
|
||||
# get_instant_qps,
|
||||
# print_instant_qps,
|
||||
# )
|
||||
|
||||
__all__ = [
|
||||
# config_utils
|
||||
"get_model_config",
|
||||
@@ -45,18 +25,5 @@ __all__ = [
|
||||
"get_pruning_config",
|
||||
"get_picture_config",
|
||||
"get_voice_config",
|
||||
# definitions (DEPRECATED - use MemoryConfig objects instead)
|
||||
# "CONFIG", # DEPRECATED
|
||||
# "RUNTIME_CONFIG", # DEPRECATED
|
||||
# "PROJECT_ROOT",
|
||||
# "reload_configuration_from_database", # DEPRECATED
|
||||
# get_data
|
||||
"get_data",
|
||||
# litellm_config - 需要时从 .litellm_config 直接导入
|
||||
# "LiteLLMConfig",
|
||||
# "setup_litellm_enhanced",
|
||||
# "get_usage_summary",
|
||||
# "print_usage_summary",
|
||||
# "get_instant_qps",
|
||||
# "print_instant_qps",
|
||||
]
|
||||
|
||||
@@ -1,398 +0,0 @@
|
||||
"""
|
||||
配置管理优化模块
|
||||
|
||||
提供可选的配置管理优化功能,包括:
|
||||
- LRU 缓存策略
|
||||
- 缓存预热
|
||||
- 缓存监控指标
|
||||
- 动态 TTL 策略
|
||||
- 配置版本控制
|
||||
|
||||
这些优化是可选的,当前的基础实现已经满足大多数需求。
|
||||
"""
|
||||
import logging
|
||||
import statistics
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LRUConfigCache:
|
||||
"""
|
||||
LRU(Least Recently Used)配置缓存
|
||||
|
||||
当缓存达到最大容量时,自动淘汰最少使用的配置
|
||||
"""
|
||||
|
||||
def __init__(self, max_size: int = 100, ttl: timedelta = timedelta(minutes=5)):
|
||||
"""
|
||||
初始化 LRU 缓存
|
||||
|
||||
Args:
|
||||
max_size: 最大缓存容量
|
||||
ttl: 缓存过期时间
|
||||
"""
|
||||
self.max_size = max_size
|
||||
self.ttl = ttl
|
||||
self._cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
|
||||
self._timestamps: Dict[str, datetime] = {}
|
||||
self._lock = threading.RLock()
|
||||
|
||||
# 统计信息
|
||||
self._stats = {
|
||||
'hits': 0,
|
||||
'misses': 0,
|
||||
'evictions': 0,
|
||||
'load_times': []
|
||||
}
|
||||
|
||||
def get(self, config_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取配置(如果存在且未过期)
|
||||
|
||||
Args:
|
||||
config_id: 配置 ID
|
||||
|
||||
Returns:
|
||||
配置字典,如果不存在或已过期则返回 None
|
||||
"""
|
||||
with self._lock:
|
||||
if config_id not in self._cache:
|
||||
self._stats['misses'] += 1
|
||||
return None
|
||||
|
||||
# 检查是否过期
|
||||
timestamp = self._timestamps.get(config_id)
|
||||
if timestamp and (datetime.now() - timestamp) >= self.ttl:
|
||||
# 过期,移除
|
||||
self._cache.pop(config_id, None)
|
||||
self._timestamps.pop(config_id, None)
|
||||
self._stats['misses'] += 1
|
||||
return None
|
||||
|
||||
# 命中,移动到末尾(标记为最近使用)
|
||||
self._cache.move_to_end(config_id)
|
||||
self._stats['hits'] += 1
|
||||
return self._cache[config_id]
|
||||
|
||||
def put(self, config_id: str, config: Dict[str, Any]) -> None:
|
||||
"""
|
||||
添加或更新配置
|
||||
|
||||
Args:
|
||||
config_id: 配置 ID
|
||||
config: 配置字典
|
||||
"""
|
||||
with self._lock:
|
||||
if config_id in self._cache:
|
||||
# 更新现有配置
|
||||
self._cache.move_to_end(config_id)
|
||||
else:
|
||||
# 添加新配置
|
||||
if len(self._cache) >= self.max_size:
|
||||
# 缓存已满,移除最旧的配置
|
||||
oldest_id, _ = self._cache.popitem(last=False)
|
||||
self._timestamps.pop(oldest_id, None)
|
||||
self._stats['evictions'] += 1
|
||||
logger.debug(f"[LRUCache] 淘汰配置: {oldest_id}")
|
||||
|
||||
self._cache[config_id] = config
|
||||
self._timestamps[config_id] = datetime.now()
|
||||
|
||||
def clear(self, config_id: Optional[str] = None) -> None:
|
||||
"""
|
||||
清除缓存
|
||||
|
||||
Args:
|
||||
config_id: 如果指定,只清除该配置;否则清除所有
|
||||
"""
|
||||
with self._lock:
|
||||
if config_id:
|
||||
self._cache.pop(config_id, None)
|
||||
self._timestamps.pop(config_id, None)
|
||||
else:
|
||||
self._cache.clear()
|
||||
self._timestamps.clear()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取缓存统计信息
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
with self._lock:
|
||||
total = self._stats['hits'] + self._stats['misses']
|
||||
hit_rate = (self._stats['hits'] / total * 100) if total > 0 else 0
|
||||
|
||||
return {
|
||||
'cache_size': len(self._cache),
|
||||
'max_size': self.max_size,
|
||||
'total_requests': total,
|
||||
'cache_hits': self._stats['hits'],
|
||||
'cache_misses': self._stats['misses'],
|
||||
'evictions': self._stats['evictions'],
|
||||
'hit_rate': hit_rate,
|
||||
'avg_load_time': statistics.mean(self._stats['load_times']) if self._stats['load_times'] else 0
|
||||
}
|
||||
|
||||
def record_load_time(self, load_time_ms: float) -> None:
|
||||
"""
|
||||
记录加载时间
|
||||
|
||||
Args:
|
||||
load_time_ms: 加载时间(毫秒)
|
||||
"""
|
||||
with self._lock:
|
||||
self._stats['load_times'].append(load_time_ms)
|
||||
# 只保留最近 1000 次的记录
|
||||
if len(self._stats['load_times']) > 1000:
|
||||
self._stats['load_times'] = self._stats['load_times'][-1000:]
|
||||
|
||||
|
||||
class ConfigCacheWarmer:
|
||||
"""
|
||||
配置缓存预热器
|
||||
|
||||
在系统启动时预加载常用配置,减少首次请求延迟
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def warmup(config_ids: List[str], load_func) -> Dict[str, bool]:
|
||||
"""
|
||||
预热缓存
|
||||
|
||||
Args:
|
||||
config_ids: 要预加载的配置 ID 列表
|
||||
load_func: 配置加载函数
|
||||
|
||||
Returns:
|
||||
每个配置的加载结果
|
||||
"""
|
||||
results = {}
|
||||
|
||||
logger.info(f"[CacheWarmer] 开始预热 {len(config_ids)} 个配置")
|
||||
|
||||
for config_id in config_ids:
|
||||
try:
|
||||
result = load_func(config_id)
|
||||
results[config_id] = result
|
||||
if result:
|
||||
logger.debug(f"[CacheWarmer] 成功预热配置: {config_id}")
|
||||
else:
|
||||
logger.warning(f"[CacheWarmer] 预热配置失败: {config_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"[CacheWarmer] 预热配置异常: {config_id}, 错误: {e}")
|
||||
results[config_id] = False
|
||||
|
||||
success_count = sum(1 for r in results.values() if r)
|
||||
logger.info(f"[CacheWarmer] 预热完成: {success_count}/{len(config_ids)} 成功")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class DynamicTTLStrategy:
    """Dynamic TTL strategy.

    Chooses a cache expiry interval based on the configuration type, or —
    when no (known) type is supplied — on keywords found in the
    configuration ID.
    """

    # Predefined TTL per configuration type.
    TTL_STRATEGIES = {
        'production': timedelta(minutes=30),   # production configs are stable
        'staging': timedelta(minutes=15),      # staging configs are moderately stable
        'development': timedelta(minutes=5),   # development configs change often
        'testing': timedelta(minutes=1),       # test configs expire quickly
        'default': timedelta(minutes=5)        # fallback strategy
    }

    # Keyword -> type table used to infer the type from the config ID.
    # Order matters and mirrors the original elif chain: earlier entries win
    # (e.g. 'prod' is checked before 'dev').
    _ID_KEYWORDS = (
        ('prod', 'production'),
        ('stag', 'staging'),
        ('dev', 'development'),
        ('test', 'testing'),
    )

    @classmethod
    def get_ttl(cls, config_id: str, config_type: Optional[str] = None) -> timedelta:
        """Return the cache TTL for a configuration.

        Args:
            config_id: configuration ID; scanned for type keywords when
                *config_type* is absent or unknown.
            config_type: explicit type (production/staging/development/testing).

        Returns:
            The timedelta to use as the cache TTL.
        """
        if config_type and config_type in cls.TTL_STRATEGIES:
            return cls.TTL_STRATEGIES[config_type]

        # Infer the type from the ID. Lower-case once instead of once per
        # comparison (the original called config_id.lower() four times).
        lowered = config_id.lower()
        for keyword, type_name in cls._ID_KEYWORDS:
            if keyword in lowered:
                return cls.TTL_STRATEGIES[type_name]

        return cls.TTL_STRATEGIES['default']
|
||||
|
||||
|
||||
class ConfigVersionManager:
    """Configuration version manager.

    Tracks a version string per configuration ID so that stale cache
    entries can be detected and invalidated when a configuration changes.
    """

    def __init__(self):
        # config_id -> current version string
        self._versions: Dict[str, str] = {}
        self._lock = threading.RLock()

    def get_version(self, config_id: str) -> Optional[str]:
        """Return the tracked version for *config_id*, or None if untracked."""
        with self._lock:
            return self._versions.get(config_id)

    def set_version(self, config_id: str, version: str) -> None:
        """Record *version* as the current version of *config_id*.

        Logs a transition message when an existing, different version is
        replaced.
        """
        with self._lock:
            previous = self._versions.get(config_id)
            self._versions[config_id] = version

            if previous and previous != version:
                logger.info(f"[VersionManager] 配置版本更新: {config_id} {previous} -> {version}")

    def check_version(self, config_id: str, cached_version: Optional[str]) -> bool:
        """Return True iff *cached_version* matches the tracked version.

        A missing tracked version or a missing cached version both count
        as stale (False).
        """
        with self._lock:
            current = self._versions.get(config_id)
            return bool(current) and bool(cached_version) and current == cached_version

    def invalidate(self, config_id: str) -> None:
        """Replace the tracked version with a fresh random one.

        Any cached entries holding the old version will then fail
        check_version() and be reloaded.
        """
        with self._lock:
            if config_id in self._versions:
                import uuid
                fresh = str(uuid.uuid4())
                self._versions[config_id] = fresh
                logger.info(f"[VersionManager] 配置版本失效: {config_id} -> {fresh}")
|
||||
|
||||
|
||||
class CacheMonitor:
    """Cache monitor.

    Wraps an LRUConfigCache and renders its ``get_stats()`` output either
    as a human-readable multi-line report or as a single INFO log line.
    """

    def __init__(self, cache: LRUConfigCache):
        # The cache whose statistics are reported; must expose get_stats().
        self.cache = cache

    def get_report(self) -> str:
        """Build a formatted (Chinese-language) cache performance report.

        Returns:
            Multi-line string with capacity, request counts, hit rate,
            eviction count and average load time, all taken from
            ``self.cache.get_stats()``.
        """
        stats = self.cache.get_stats()

        report = f"""
配置缓存性能报告
================
缓存容量: {stats['cache_size']}/{stats['max_size']}
总请求数: {stats['total_requests']}
缓存命中: {stats['cache_hits']}
缓存未命中: {stats['cache_misses']}
缓存命中率: {stats['hit_rate']:.2f}%
淘汰次数: {stats['evictions']}
平均加载时间: {stats['avg_load_time']:.2f}ms
"""
        return report

    def log_stats(self) -> None:
        """Log a one-line cache summary (capacity, hit rate, evictions) at INFO level."""
        stats = self.cache.get_stats()
        logger.info(
            f"[CacheMonitor] 缓存统计 - "
            f"容量: {stats['cache_size']}/{stats['max_size']}, "
            f"命中率: {stats['hit_rate']:.2f}%, "
            f"淘汰: {stats['evictions']}"
        )
|
||||
|
||||
|
||||
# 使用示例
|
||||
def example_usage():
    """Demonstrate the cache optimisation helpers end to end."""
    # 1. LRU cache: look up, and populate on a miss.
    demo_cache = LRUConfigCache(max_size=100, ttl=timedelta(minutes=5))

    entry = demo_cache.get("config_001")
    if entry is None:
        # Cache miss — load (here: fabricate) the config and store it.
        entry = {"llm_name": "openai/gpt-4"}
        demo_cache.put("config_001", entry)

    # 2. Warm the cache up front.
    def fake_loader(cfg_id):
        # Stand-in for the real configuration loading logic.
        return True

    warm_results = ConfigCacheWarmer().warmup(["config_001", "config_002"], fake_loader)

    # 3. Pick a TTL dynamically from the config type.
    chosen_ttl = DynamicTTLStrategy.get_ttl("prod_config_001", "production")
    print(f"TTL: {chosen_ttl}")

    # 4. Track a version and validate a cached one against it.
    versions = ConfigVersionManager()
    versions.set_version("config_001", "v1.0.0")
    version_ok = versions.check_version("config_001", "v1.0.0")

    # 5. Inspect cache performance.
    print(CacheMonitor(demo_cache).get_report())
|
||||
|
||||
|
||||
# Run the demo only when executed directly (not on import).
if __name__ == "__main__":
    example_usage()
|
||||
@@ -1,268 +0,0 @@
|
||||
# """
|
||||
# 配置加载模块 - DEPRECATED
|
||||
|
||||
# ⚠️ DEPRECATION NOTICE ⚠️
|
||||
# This module is deprecated and will be removed in a future version.
|
||||
# Global configuration variables have been eliminated in favor of dependency injection.
|
||||
|
||||
# Use the new MemoryConfig system instead:
|
||||
# - app.schemas.memory_config_schema.MemoryConfig for configuration objects
|
||||
# - config_service = MemoryConfigService(db); config_service.load_memory_config(config_id)
|
||||
|
||||
# 阶段 1: 从 runtime.json 加载配置(路径 A)- DEPRECATED
|
||||
# 阶段 2: 从数据库加载配置(路径 B,基于 dbrun.json 中的 config_id)- DEPRECATED
|
||||
# 阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点)- DEPRECATED
|
||||
# """
|
||||
# import json
|
||||
# import os
|
||||
# import threading
|
||||
# from datetime import datetime, timedelta
|
||||
# from typing import Any, Dict, Optional
|
||||
|
||||
# TODO: remove this deprecated compatibility shim once all callers use MemoryConfig dependency injection.
|
||||
|
||||
# try:
|
||||
# from dotenv import load_dotenv
|
||||
# load_dotenv()
|
||||
# except Exception:
|
||||
# pass
|
||||
|
||||
# # Import unified configuration system
|
||||
# try:
|
||||
# from app.core.config import settings
|
||||
# USE_UNIFIED_CONFIG = True
|
||||
# except ImportError:
|
||||
# USE_UNIFIED_CONFIG = False
|
||||
# settings = None
|
||||
|
||||
# # PROJECT_ROOT 应该指向 app/core/memory/ 目录
|
||||
# # __file__ = app/core/memory/utils/config/definitions.py
|
||||
# # os.path.dirname(__file__) = app/core/memory/utils/config
|
||||
# # os.path.dirname(...) = app/core/memory/utils
|
||||
# # os.path.dirname(...) = app/core/memory
|
||||
# PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
# # DEPRECATED: Global configuration lock removed
|
||||
# # Use MemoryConfig objects with dependency injection instead
|
||||
|
||||
# # DEPRECATED: Legacy config.json loading removed
|
||||
# # Use MemoryConfig objects with dependency injection instead
|
||||
# CONFIG = {}
|
||||
|
||||
# DEFAULT_VALUES = {
|
||||
# "llm_name": "openai/qwen-plus",
|
||||
# "embedding_name": "openai/nomic-embed-text:v1.5",
|
||||
# "chunker_strategy": "RecursiveChunker",
|
||||
# "group_id": "group_123",
|
||||
# "user_id": "default_user",
|
||||
# "apply_id": "default_apply",
|
||||
# "llm_agent_name": "openai/qwen-plus",
|
||||
# "llm_verify_name": "openai/qwen-plus",
|
||||
# "llm_image_recognition": "openai/qwen-plus",
|
||||
# "llm_voice_recognition": "openai/qwen-plus",
|
||||
# "prompt_level": "DEBUG",
|
||||
# "reflexion_iteration_period": "3",
|
||||
# "reflexion_range": "retrieval",
|
||||
# "reflexion_baseline": "TIME",
|
||||
# }
|
||||
|
||||
# # DEPRECATED: Legacy global variables for backward compatibility only
|
||||
# # These will be removed in a future version
|
||||
# # Use MemoryConfig objects with dependency injection instead
|
||||
# # LANGFUSE_ENABLED = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true"
|
||||
# # SELECTED_LLM_ID = os.getenv("SELECTED_LLM_ID", DEFAULT_VALUES["llm_name"])
|
||||
|
||||
|
||||
# # 阶段 1: 从 runtime.json 加载配置(路径 A)
|
||||
# def _load_from_runtime_json() -> Dict[str, Any]:
|
||||
# """
|
||||
# DEPRECATED: Legacy runtime.json loading
|
||||
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
# Returns:
|
||||
# Dict[str, Any]: Empty configuration (legacy support only)
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "Runtime JSON loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
# return {"selections": {}}
|
||||
|
||||
|
||||
# # 阶段 2: 从数据库加载配置(路径 B)- 已整合到统一加载器
|
||||
# # 注意:此函数已被 _load_from_runtime_json 中的统一配置加载器替代
|
||||
# # 保留此函数仅为向后兼容
|
||||
# def _load_from_database() -> Optional[Dict[str, Any]]:
|
||||
# """
|
||||
# DEPRECATED: Legacy database configuration loading
|
||||
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
# Returns:
|
||||
# Optional[Dict[str, Any]]: None (deprecated functionality)
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "Database configuration loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
# return None
|
||||
|
||||
|
||||
# # 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)- DEPRECATED
|
||||
# def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None:
|
||||
# """
|
||||
# DEPRECATED: 将运行时配置暴露为全局常量供项目使用
|
||||
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Global configuration variables have been eliminated in favor of dependency injection.
|
||||
|
||||
# Use the new MemoryConfig system instead:
|
||||
# - app.core.memory_config.config.MemoryConfig for configuration objects
|
||||
# - Pass configuration objects as parameters instead of using global variables
|
||||
|
||||
# Args:
|
||||
# runtime_cfg: 运行时配置字典
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "Global configuration variables are deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
|
||||
# # Keep minimal global state for backward compatibility only
|
||||
# # These will be removed in a future version
|
||||
# global RUNTIME_CONFIG, SELECTIONS
|
||||
|
||||
# RUNTIME_CONFIG = runtime_cfg
|
||||
# SELECTIONS = RUNTIME_CONFIG.get("selections", {})
|
||||
|
||||
# # All other global variables have been removed
|
||||
# # Use MemoryConfig objects instead
|
||||
|
||||
|
||||
# # 初始化:使用统一配置加载器
|
||||
# def _initialize_configuration() -> None:
|
||||
# """
|
||||
# DEPRECATED: Legacy configuration initialization
|
||||
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "Global configuration initialization is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
# # Initialize with empty configuration for backward compatibility
|
||||
# _expose_runtime_constants({"selections": {}})
|
||||
|
||||
|
||||
# # 模块加载时自动初始化配置
|
||||
# _initialize_configuration()
|
||||
|
||||
# # DEPRECATED: Global variables removed
|
||||
# # These variables have been eliminated in favor of dependency injection
|
||||
# # Use MemoryConfig objects instead of accessing global variables
|
||||
|
||||
|
||||
# # 公共 API:动态重新加载配置
|
||||
# def reload_configuration_from_database(config_id, force_reload: bool = False) -> bool:
|
||||
# """
|
||||
# DEPRECATED: Legacy configuration reloading
|
||||
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
# For new code, use:
|
||||
# - app.services.memory_agent_service.MemoryAgentService.load_memory_config()
|
||||
# - app.services.memory_storage_service.MemoryStorageService.load_memory_config()
|
||||
|
||||
# Args:
|
||||
# config_id: Configuration ID (deprecated)
|
||||
# force_reload: Force reload flag (deprecated)
|
||||
|
||||
# Returns:
|
||||
# bool: Always returns False (deprecated functionality)
|
||||
# """
|
||||
# import logging
|
||||
# import warnings
|
||||
|
||||
# logger = logging.getLogger(__name__)
|
||||
|
||||
# warnings.warn(
|
||||
# "reload_configuration_from_database is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
|
||||
# logger.warning(f"Deprecated function reload_configuration_from_database called with config_id={config_id}. "
|
||||
# "Use MemoryConfig objects with dependency injection instead.")
|
||||
|
||||
# return False
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# def get_current_config_id() -> Optional[str]:
|
||||
# """
|
||||
# DEPRECATED: Legacy config ID retrieval
|
||||
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
# Returns:
|
||||
# Optional[str]: None (deprecated functionality)
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "get_current_config_id is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
# return None
|
||||
|
||||
|
||||
# def ensure_fresh_config(config_id = None) -> bool:
|
||||
# """
|
||||
# DEPRECATED: Legacy configuration freshness check
|
||||
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
# For new code, use:
|
||||
# - app.services.memory_agent_service.MemoryAgentService.load_memory_config()
|
||||
# - app.services.memory_storage_service.MemoryStorageService.load_memory_config()
|
||||
|
||||
# Args:
|
||||
# config_id: Configuration ID (deprecated)
|
||||
|
||||
# Returns:
|
||||
# bool: Always returns False (deprecated functionality)
|
||||
# """
|
||||
# import logging
|
||||
# import warnings
|
||||
|
||||
# logger = logging.getLogger(__name__)
|
||||
|
||||
# warnings.warn(
|
||||
# "ensure_fresh_config is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
|
||||
# logger.warning(f"Deprecated function ensure_fresh_config called with config_id={config_id}. "
|
||||
# "Use MemoryConfig objects with dependency injection instead.")
|
||||
|
||||
# return False
|
||||
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
import random
|
||||
import string
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
# 生成包含字母(大小写)和数字的随机字符串
|
||||
def generate_random_string(length=16):
    """Return a random string of ASCII letters (both cases) and digits.

    NOTE: uses `random`, which is not cryptographically secure — fine for
    fallback IDs, but use `secrets` for anything security-sensitive.

    Args:
        length: number of characters to generate (default 16).

    Returns:
        A string of *length* randomly chosen alphanumeric characters.
    """
    alphabet = string.ascii_letters + string.digits
    # random.choices samples with replacement in a single call, replacing
    # the per-character random.choice generator loop.
    return ''.join(random.choices(alphabet, k=length))
|
||||
|
||||
def _parse_statement_block(block: str) -> Optional[Dict[str, object]]:
    """Parse one 'Statement N:' log block into a record dict.

    Returns None when the block has no Content field (record is skipped).
    """
    def field(label: str) -> Optional[str]:
        # `.` does not match newlines, so `(.+)` naturally stops at the end
        # of the line. BUGFIX: the previous `(.+?)(?=\n)` lookahead silently
        # dropped the final field when the file did not end with a newline.
        # NOTE(review): searching for "Id:" would also match inside
        # "Group Id:"/"Chunk Id:" if they preceded it; this relies on
        # "Id:" appearing first in each block, as in the original code.
        m = re.search(rf"{label}:\s*(.+)", block)
        return m.group(1).strip() if m else None

    statement = field("Content")
    if statement is None:
        return None

    record = {
        "id": field("Id") or generate_random_string(),
        "group_id": field("Group Id") or "group_example",
        "statement": statement,
        "created_at": field("Created At"),
        "expired_at": field("Expired At"),
        "valid_at": field("Valid At"),
        "invalid_at": field("Invalid At"),
        "chunk_id": field("Chunk Id") or "chunk_example",
        "entity_ids": []
    }

    # The log writes the literal string "None" for absent timestamps;
    # normalise those to real None values.
    for key in ("created_at", "expired_at", "valid_at", "invalid_at"):
        if record[key] == "None":
            record[key] = None

    return record


def get_example_data() -> List[Dict[str, Optional[str]]]:
    """Load example statement records from the extraction log.

    Reads logs/memory-output/statement_extraction.txt, splits it on
    "Statement N:" headers, and parses each block into a dict of the form:

        {"id": ..., "group_id": ..., "statement": ..., "created_at": ...,
         "expired_at": ..., "valid_at": ..., "invalid_at": ...,
         "chunk_id": ..., "entity_ids": []}

    Missing id/group_id/chunk_id fields are filled with generated or
    example placeholders; "None" timestamp strings become None.

    Returns:
        List of parsed records; empty list when the log file is absent.
    """
    log_file_path = os.path.join("logs", "memory-output", "statement_extraction.txt")

    if not os.path.exists(log_file_path):
        return []

    with open(log_file_path, "r", encoding="utf-8") as f:
        content = f.read()

    results = []
    # Index 0 is the preamble before the first "Statement N:" header.
    for block in re.split(r"Statement \d+:", content)[1:]:
        record = _parse_statement_block(block)
        if record is not None:
            results.append(record)

    return results
|
||||
|
||||
|
||||
# Manual smoke test: print whatever the extraction log currently yields.
if __name__ == "__main__":
    print(f"获取数据如下:\n {get_example_data()}")
|
||||
@@ -1,516 +0,0 @@
|
||||
"""
|
||||
LiteLLM Configuration for Enhanced Retry Logic and Usage Tracking with Native QPS Monitoring
|
||||
"""
|
||||
|
||||
import litellm
|
||||
from typing import Dict, Any, List
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict
|
||||
import threading
|
||||
from queue import Queue
|
||||
|
||||
class LiteLLMConfig:
|
||||
"""Configuration class for LiteLLM with enhanced retry and tracking capabilities"""
|
||||
|
||||
def __init__(self):
|
||||
self.usage_data = []
|
||||
self.error_data = []
|
||||
self.module_stats = defaultdict(lambda: {
|
||||
'requests': 0,
|
||||
'tokens_in': 0,
|
||||
'tokens_out': 0,
|
||||
'cost': 0.0,
|
||||
'errors': 0,
|
||||
'start_time': None,
|
||||
'last_request_time': None,
|
||||
'request_timestamps': [], # Store precise timestamps
|
||||
'current_qps': 0.0,
|
||||
'max_qps': 0.0,
|
||||
'qps_history': [] # Store QPS measurements over time
|
||||
})
|
||||
self.start_time = datetime.now()
|
||||
self.global_request_timestamps = []
|
||||
self.global_max_qps = 0.0
|
||||
|
||||
# Rate limiting for AWS Bedrock (conservative limits)
|
||||
self.rate_limits = {
|
||||
'bedrock': {
|
||||
'requests_per_minute': 2, # AWS Bedrock default is very low
|
||||
'requests_per_second': 0.033, # 2/60 = 0.033 RPS
|
||||
'last_request_time': 0,
|
||||
'request_queue': Queue(),
|
||||
'lock': threading.Lock()
|
||||
}
|
||||
}
|
||||
self.rate_limiting_enabled = True
|
||||
|
||||
def setup_enhanced_config(self, max_retries: int = 3):
    """Configure LiteLLM's global retry, timeout and callback settings.

    Mutates module-level ``litellm`` state (global side effects): retry
    counts, a per-exception-class retry policy, success/failure callbacks
    bound to this instance, cost tracking, and verbosity flags.

    Args:
        max_retries: default retry count applied to all requests.
    """

    litellm.num_retries = max_retries
    litellm.request_timeout = 300  # seconds

    # Per-exception retry policy; transient errors back off exponentially.
    litellm.retry_policy = {
        "RateLimitError": {
            "max_retries": 5,
            "exponential_backoff": True,
            "initial_delay": 1,
            "max_delay": 60,
            "jitter": True
        },
        "APIConnectionError": {
            "max_retries": 3,
            "exponential_backoff": True,
            "initial_delay": 2,
            "max_delay": 30,
            "jitter": True
        },
        "InternalServerError": {
            "max_retries": 2,
            "exponential_backoff": True,
            "initial_delay": 5,
            "max_delay": 60,
            "jitter": True
        },
        "BadRequestError": {
            # Client errors are unlikely to succeed on retry; retry once only.
            "max_retries": 1,
            "exponential_backoff": False,
            "initial_delay": 1,
            "max_delay": 5
        }
    }

    # Route request outcomes into this instance's usage/error tracking.
    litellm.success_callback = [self._success_callback]
    litellm.failure_callback = [self._failure_callback]
    litellm.completion_cost_tracking = True
    litellm.set_verbose = False
    litellm.modify_params = True

    print("✅ LiteLLM configured with instant QPS tracking and rate limiting")
|
||||
|
||||
def _success_callback(self, kwargs, completion_response, start_time, end_time):
|
||||
"""Callback for successful requests with module-specific QPS tracking"""
|
||||
try:
|
||||
# Extract usage information
|
||||
usage = completion_response.get('usage', {})
|
||||
model = kwargs.get('model', 'unknown')
|
||||
|
||||
# Extract module information from metadata or model name
|
||||
module = self._extract_module_name(kwargs, model)
|
||||
|
||||
# Calculate cost
|
||||
cost = 0.0
|
||||
try:
|
||||
cost = litellm.completion_cost(completion_response)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Calculate duration
|
||||
duration_seconds = (end_time - start_time).total_seconds() if hasattr(end_time - start_time, 'total_seconds') else float(end_time - start_time)
|
||||
|
||||
# Record usage data
|
||||
usage_record = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"model": model,
|
||||
"module": module,
|
||||
"input_tokens": usage.get('prompt_tokens', 0),
|
||||
"output_tokens": usage.get('completion_tokens', 0),
|
||||
"total_tokens": usage.get('total_tokens', 0),
|
||||
"cost": cost,
|
||||
"duration_seconds": duration_seconds,
|
||||
"status": "success"
|
||||
}
|
||||
|
||||
self.usage_data.append(usage_record)
|
||||
|
||||
# Update module-specific stats for QPS tracking
|
||||
self._update_module_stats(module, usage_record, success=True)
|
||||
|
||||
# Print real-time feedback
|
||||
print(f"✓ {model}: {usage_record['input_tokens']}→{usage_record['output_tokens']} tokens, ${cost:.4f}, {usage_record['duration_seconds']:.2f}s")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Success callback failed: {e}")
|
||||
|
||||
def _failure_callback(self, kwargs, completion_response, start_time, end_time):
    """Callback for failed requests; records the error and updates module stats.

    The error message/type are resolved from several possible sources, in
    priority order: the response object itself, then kwargs['exception'],
    then kwargs['error'], with a final override when the event is a
    confirmed 'failed_api_call'.

    Args:
        kwargs: request arguments plus failure metadata (exception, error,
            log_event_type, ...).
        completion_response: per LiteLLM docs, holds the exception on failure.
        start_time / end_time: datetimes (or floats) bracketing the call.
    """
    try:
        model = kwargs.get('model', 'unknown')
        module = self._extract_module_name(kwargs, model)

        # start/end may be datetimes (difference is a timedelta) or floats.
        duration_seconds = (end_time - start_time).total_seconds() if hasattr(end_time - start_time, 'total_seconds') else float(end_time - start_time)

        # Handle different error response formats
        error_message = "Unknown error"
        error_type = "UnknownError"

        # According to LiteLLM docs, completion_response contains the exception for failures
        if completion_response is not None:
            error_message = str(completion_response)
            error_type = type(completion_response).__name__

        # Also check kwargs for exception (LiteLLM passes exception in kwargs for failure events)
        elif 'exception' in kwargs:
            exception = kwargs['exception']
            error_message = str(exception)
            error_type = type(exception).__name__

        # Check for other error formats in kwargs
        elif 'error' in kwargs:
            error = kwargs['error']
            error_message = str(error)
            error_type = type(error).__name__

        # Check log_event_type to confirm this is a failure event.
        # NOTE(review): this can override the message already resolved above;
        # it looks like a deliberate "authoritative source" for confirmed
        # failed_api_call events — confirm against LiteLLM's callback contract.
        log_event_type = kwargs.get('log_event_type', '')
        if log_event_type == 'failed_api_call' and 'exception' in kwargs:
            exception = kwargs['exception']
            error_message = str(exception)
            error_type = type(exception).__name__

        error_record = {
            "timestamp": datetime.now().isoformat(),
            "model": model,
            "module": module,
            "error": error_message,
            "error_type": error_type,
            "duration_seconds": duration_seconds,
            "status": "failed"
        }

        self.error_data.append(error_record)

        # Update module-specific stats for error tracking
        self._update_module_stats(module, error_record, success=False)

        # Print error feedback (message truncated to 100 chars)
        print(f"✗ {model}: {error_type} - {error_message[:100]}")

    except Exception as e:
        print(f"Warning: Failure callback failed: {e}")
        # Debug: print the actual parameters to understand the structure
        print(f"Debug - kwargs keys: {list(kwargs.keys()) if kwargs else 'None'}")
        print(f"Debug - completion_response type: {type(completion_response)}")
        print(f"Debug - completion_response: {completion_response}")
|
||||
|
||||
def _should_rate_limit(self, model: str) -> bool:
|
||||
"""Check if the model should be rate limited"""
|
||||
if not self.rate_limiting_enabled:
|
||||
return False
|
||||
return model.startswith('bedrock/') or 'bedrock' in model.lower()
|
||||
|
||||
def _enforce_rate_limit(self, model: str):
|
||||
"""Enforce rate limiting for AWS Bedrock models"""
|
||||
if not self._should_rate_limit(model):
|
||||
return
|
||||
|
||||
provider = 'bedrock'
|
||||
if provider not in self.rate_limits:
|
||||
return
|
||||
|
||||
rate_config = self.rate_limits[provider]
|
||||
|
||||
with rate_config['lock']:
|
||||
current_time = time.time()
|
||||
time_since_last = current_time - rate_config['last_request_time']
|
||||
min_interval = 1.0 / rate_config['requests_per_second']
|
||||
|
||||
if time_since_last < min_interval:
|
||||
sleep_time = min_interval - time_since_last
|
||||
print(f"⏳ Rate limiting: sleeping {sleep_time:.2f}s for {model}")
|
||||
time.sleep(sleep_time)
|
||||
|
||||
rate_config['last_request_time'] = time.time()
|
||||
|
||||
def _extract_module_name(self, kwargs: Dict[str, Any], model: str) -> str:
|
||||
"""Extract module name from request context"""
|
||||
# Try to get module from metadata
|
||||
metadata = kwargs.get('metadata', {})
|
||||
if 'module' in metadata:
|
||||
return metadata['module']
|
||||
|
||||
# Try to infer from model name or other context
|
||||
if 'claude' in model.lower():
|
||||
return 'bedrock_client'
|
||||
elif 'gpt' in model.lower() or 'openai' in model.lower():
|
||||
return 'openai_client'
|
||||
elif 'embed' in model.lower():
|
||||
return 'embedder'
|
||||
else:
|
||||
return 'unknown'
|
||||
|
||||
def _update_module_stats(self, module: str, record: Dict[str, Any], success: bool):
|
||||
"""Update module-specific statistics with instant QPS tracking"""
|
||||
current_timestamp = time.time()
|
||||
current_time = datetime.now()
|
||||
|
||||
# Initialize module stats if first request
|
||||
if self.module_stats[module]['start_time'] is None:
|
||||
self.module_stats[module]['start_time'] = current_time
|
||||
|
||||
# Update counters
|
||||
self.module_stats[module]['requests'] += 1
|
||||
self.module_stats[module]['last_request_time'] = current_time
|
||||
self.module_stats[module]['request_timestamps'].append(current_timestamp)
|
||||
self.global_request_timestamps.append(current_timestamp)
|
||||
|
||||
# Calculate instant QPS for this module
|
||||
self._calculate_instant_qps(module, current_timestamp)
|
||||
|
||||
# Calculate global instant QPS
|
||||
self._calculate_global_instant_qps(current_timestamp)
|
||||
|
||||
if success:
|
||||
self.module_stats[module]['tokens_in'] += record.get('input_tokens', 0)
|
||||
self.module_stats[module]['tokens_out'] += record.get('output_tokens', 0)
|
||||
self.module_stats[module]['cost'] += record.get('cost', 0.0)
|
||||
else:
|
||||
self.module_stats[module]['errors'] += 1
|
||||
|
||||
def _calculate_instant_qps(self, module: str, current_timestamp: float):
|
||||
"""Calculate instant QPS for a specific module using sliding window"""
|
||||
# Keep only timestamps from last 1 second for instant QPS
|
||||
cutoff_time = current_timestamp - 1.0
|
||||
timestamps = self.module_stats[module]['request_timestamps']
|
||||
|
||||
# Remove old timestamps
|
||||
self.module_stats[module]['request_timestamps'] = [
|
||||
ts for ts in timestamps if ts >= cutoff_time
|
||||
]
|
||||
|
||||
# Calculate current QPS (requests in last second)
|
||||
current_qps = len(self.module_stats[module]['request_timestamps'])
|
||||
self.module_stats[module]['current_qps'] = current_qps
|
||||
|
||||
# Update max QPS if current is higher
|
||||
if current_qps > self.module_stats[module]['max_qps']:
|
||||
self.module_stats[module]['max_qps'] = current_qps
|
||||
|
||||
# Store QPS history (keep last 60 measurements)
|
||||
self.module_stats[module]['qps_history'].append(current_qps)
|
||||
if len(self.module_stats[module]['qps_history']) > 60:
|
||||
self.module_stats[module]['qps_history'].pop(0)
|
||||
|
||||
def _calculate_global_instant_qps(self, current_timestamp: float):
|
||||
"""Calculate global instant QPS across all modules"""
|
||||
# Keep only timestamps from last 1 second
|
||||
cutoff_time = current_timestamp - 1.0
|
||||
self.global_request_timestamps = [
|
||||
ts for ts in self.global_request_timestamps if ts >= cutoff_time
|
||||
]
|
||||
|
||||
# Calculate current global QPS
|
||||
current_global_qps = len(self.global_request_timestamps)
|
||||
|
||||
# Update max global QPS
|
||||
if current_global_qps > self.global_max_qps:
|
||||
self.global_max_qps = current_global_qps
|
||||
|
||||
def get_instant_qps(self, module: str = None) -> Dict[str, Any]:
|
||||
"""Get instant QPS data for modules"""
|
||||
if module:
|
||||
if module in self.module_stats:
|
||||
return {
|
||||
'module': module,
|
||||
'current_qps': self.module_stats[module]['current_qps'],
|
||||
'max_qps': self.module_stats[module]['max_qps'],
|
||||
'avg_qps_last_minute': sum(self.module_stats[module]['qps_history'][-60:]) / min(60, len(self.module_stats[module]['qps_history'])) if self.module_stats[module]['qps_history'] else 0
|
||||
}
|
||||
else:
|
||||
return {'module': module, 'current_qps': 0, 'max_qps': 0, 'avg_qps_last_minute': 0}
|
||||
else:
|
||||
# Return data for all modules plus global
|
||||
result = {
|
||||
'global': {
|
||||
'current_qps': len([ts for ts in self.global_request_timestamps if ts >= time.time() - 1.0]),
|
||||
'max_qps': self.global_max_qps
|
||||
},
|
||||
'modules': {}
|
||||
}
|
||||
|
||||
for mod in self.module_stats:
|
||||
result['modules'][mod] = {
|
||||
'current_qps': self.module_stats[mod]['current_qps'],
|
||||
'max_qps': self.module_stats[mod]['max_qps'],
|
||||
'avg_qps_last_minute': sum(self.module_stats[mod]['qps_history'][-60:]) / min(60, len(self.module_stats[mod]['qps_history'])) if self.module_stats[mod]['qps_history'] else 0
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def get_usage_summary(self) -> Dict[str, Any]:
    """Build a condensed snapshot of session usage statistics.

    Returns:
        Dict with request/error counts, token and cost totals, session
        duration, a per-module breakdown, and the global peak QPS.
        When no usage has been recorded yet, a placeholder dict with a
        'message' key is returned instead.
    """
    if not self.usage_data:
        return {
            "total_requests": 0,
            "total_cost": 0.0,
            "error_rate": 0.0,
            "message": "No usage data available"
        }

    request_count = len(self.usage_data)
    error_count = len(self.error_data)
    cost_total = sum(rec['cost'] for rec in self.usage_data)
    tokens_in_total = sum(rec['input_tokens'] for rec in self.usage_data)
    tokens_out_total = sum(rec['output_tokens'] for rec in self.usage_data)

    # Wall-clock minutes since tracking began.
    elapsed_minutes = (datetime.now() - self.start_time).total_seconds() / 60

    # Per-module breakdown, skipping modules that never served a request.
    per_module: Dict[str, Any] = {}
    for name, stats in self.module_stats.items():
        if stats['requests'] <= 0:
            continue
        succeeded = stats['requests'] - stats['errors']
        per_module[name] = {
            "requests": stats['requests'],
            "errors": stats['errors'],
            "success_rate": succeeded / stats['requests'] * 100,
            "tokens_in": stats['tokens_in'],
            "tokens_out": stats['tokens_out'],
            "cost": stats['cost'],
            "current_qps": stats['current_qps'],
            "max_qps": stats['max_qps'],
        }

    return {
        "session_duration_minutes": elapsed_minutes,
        "total_requests": request_count,
        "total_errors": error_count,
        # request_count >= 1 here (usage_data is non-empty).
        "error_rate": error_count / request_count * 100,
        "total_input_tokens": tokens_in_total,
        "total_output_tokens": tokens_out_total,
        "total_cost": cost_total,
        "module_stats": per_module,
        "global_max_qps": self.global_max_qps,
    }
|
||||
|
||||
def print_usage_summary(self):
    """Print a human-readable usage summary to stdout.

    Delegates to ``get_usage_summary``; when no usage has been recorded
    yet, only the placeholder message is printed.
    """
    stats = self.get_usage_summary()

    # The no-data placeholder dict carries a 'message' key; print it and stop.
    if stats.get('message'):
        print(f"📊 {stats['message']}")
        return

    print("\n📊 USAGE SUMMARY")
    print(f"{'='*50}")
    print(f"⏱️ Duration: {stats['session_duration_minutes']:.1f} min")
    print(f"📈 Requests: {stats['total_requests']}")
    print(f"❌ Errors: {stats['total_errors']}")
    print(f"💰 Cost: ${stats['total_cost']:.4f}")
    print(f"🏆 Global Max QPS: {stats['global_max_qps']}")

    # Per-module breakdown (only modules that actually served requests).
    if stats.get('module_stats'):
        print("\n📦 MODULES:")
        for module, mod_stats in stats['module_stats'].items():
            print(f" {module}: {mod_stats['requests']} req, Max QPS: {mod_stats['max_qps']}, Current: {mod_stats['current_qps']}")

    print(f"{'='*50}")
|
||||
|
||||
def save_usage_data(self, filename: str = "litellm_usage.json"):
    """Dump all tracked usage data to a JSON file.

    Args:
        filename: Destination path for the JSON dump.

    The file contains the current summary, the per-request usage
    records, the error records, and an ISO-format export timestamp.
    """
    data = {
        "summary": self.get_usage_summary(),
        "detailed_usage": self.usage_data,
        "errors": self.error_data,
        "export_timestamp": datetime.now().isoformat()
    }

    with open(filename, 'w') as f:
        json.dump(data, f, indent=2)

    # Bug fix: the confirmation message previously never interpolated the
    # destination path (f-string had no placeholder).
    print(f"📁 Usage data saved to {filename}")
|
||||
|
||||
def reset_tracking(self):
    """Discard all accumulated tracking state and restart the session clock."""
    def _fresh_module_entry():
        # Zeroed per-module counters, mirroring the initial state.
        return {
            'requests': 0,
            'tokens_in': 0,
            'tokens_out': 0,
            'cost': 0.0,
            'errors': 0,
            'start_time': None,
            'last_request_time': None,
            'request_timestamps': [],
            'current_qps': 0.0,
            'max_qps': 0.0,
            'qps_history': []
        }

    self.usage_data = []
    self.error_data = []
    # defaultdict so modules seen for the first time lazily get a zeroed entry.
    self.module_stats = defaultdict(_fresh_module_entry)
    self.global_request_timestamps = []
    self.global_max_qps = 0.0
    self.start_time = datetime.now()
    print("🔄 All tracking data reset")
|
||||
|
||||
# Module-level singleton; the wrapper functions below delegate to it so
# callers can use plain functions instead of managing an instance.
litellm_config = LiteLLMConfig()
|
||||
|
||||
def setup_litellm_enhanced(max_retries: int = 3):
    """
    Quick setup function for LiteLLM enhanced configuration.

    Args:
        max_retries: Maximum number of retries for failed requests.
    Returns:
        The shared module-level LiteLLMConfig instance, now configured.
    """
    litellm_config.setup_enhanced_config(max_retries)
    return litellm_config
|
||||
|
||||
def get_usage_summary():
    """Get the current usage summary from the shared config instance."""
    return litellm_config.get_usage_summary()
|
||||
|
||||
def print_usage_summary():
    """Print the current usage summary via the shared config instance."""
    litellm_config.print_usage_summary()
|
||||
|
||||
def save_usage_data(filename: str = "litellm_usage.json"):
    """Save tracked usage data to *filename* via the shared config instance."""
    litellm_config.save_usage_data(filename)
|
||||
|
||||
def get_instant_qps(module: str = None) -> Dict[str, Any]:
    """Get instant QPS data for one module, or for all modules when None."""
    return litellm_config.get_instant_qps(module)
|
||||
|
||||
def print_instant_qps(module: str = None):
    """Print instant QPS information to stdout.

    Args:
        module: When given, print stats for that single module only;
            otherwise print the global figures plus every module.
    """
    qps_data = get_instant_qps(module)

    print("\n⚡ INSTANT QPS MONITOR")
    print(f"{'='*60}")

    if module:
        # Single-module view.
        print(f"Module: {qps_data['module']}")
        print(f" Current QPS: {qps_data['current_qps']}")
        print(f" Max QPS: {qps_data['max_qps']}")
        print(f" Avg (1min): {qps_data['avg_qps_last_minute']:.2f}")
    else:
        # Global stats.
        global_data = qps_data.get('global', {})
        print("🌍 GLOBAL:")
        print(f" Current QPS: {global_data.get('current_qps', 0)}")
        print(f" Max QPS: {global_data.get('max_qps', 0)}")

        # Per-module stats.
        modules = qps_data.get('modules', {})
        if modules:
            print("\n📦 MODULES:")
            for mod, data in modules.items():
                print(f" {mod}:")
                print(f" Current: {data['current_qps']} QPS")
                print(f" Max: {data['max_qps']} QPS")
                print(f" Avg: {data['avg_qps_last_minute']:.2f} QPS")

    print(f"{'='*60}")
|
||||
|
||||
def reset_tracking():
    """Reset all tracking data on the shared config instance."""
    litellm_config.reset_tracking()
|
||||
|
||||
def get_module_stats() -> Dict[str, Dict[str, Any]]:
    """Get detailed per-module statistics from the current usage summary.

    Returns:
        Mapping of module name to its stats dict; empty when no usage
        has been recorded yet (the summary then has no 'module_stats').
    """
    summary = get_usage_summary()
    return summary.get('module_stats', {})
|
||||
@@ -1,16 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""自我反思工具模块
|
||||
|
||||
本模块提供自我反思引擎的核心功能,包括:
|
||||
- 记忆冲突判定
|
||||
- 反思执行
|
||||
- 记忆更新
|
||||
|
||||
从 app.core.memory.src.data_config_api 迁移而来。
|
||||
"""
|
||||
|
||||
from app.core.memory.utils.self_reflexion_utils.evaluate import conflict
|
||||
from app.core.memory.utils.self_reflexion_utils.reflexion import reflexion
|
||||
from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion
|
||||
|
||||
__all__ = ["conflict", "reflexion", "self_reflexion"]
|
||||
@@ -1,52 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""记忆冲突判定模块
|
||||
|
||||
本模块提供记忆冲突判定功能,使用LLM判断记忆数据中是否存在冲突。
|
||||
从 app.core.memory.src.data_config_api.evaluate 迁移而来。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, List
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_storage_schema import ConflictResultSchema
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
async def conflict(evaluate_data: List[Any]) -> List[Any]:
    """
    Evaluates memory conflict using the evaluate.jinja2 template.

    Args:
        evaluate_data: List of memory records to check for conflicts.
    Returns:
        List of conflict-judgement results (normalized to dicts where
        possible); empty list when the LLM output cannot be parsed.
    """
    # Imported lazily to avoid a circular import at module load time.
    from app.core.memory.utils.config import definitions as config_defs
    with get_db_context() as db:
        factory = MemoryClientFactory(db)
        client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
        rendered_prompt = await render_evaluate_prompt(evaluate_data, ConflictResultSchema)
        messages = [{"role": "user", "content": rendered_prompt}]
        print(f"提示词长度: {len(rendered_prompt)}")
        print(f"====== 冲突判定开始 ======\n")
        start_time = time.time()
        # Structured LLM call: the response is parsed into ConflictResultSchema.
        response = await client.response_structured(messages, ConflictResultSchema)
        end_time = time.time()
        print(f"冲突判定耗时: {end_time - start_time} 秒")
        print(f"冲突判定原始输出:(type={type(response)})\n{response}")

        if not response:
            # Parsing failed; return empty so the caller's pipeline can continue.
            logging.error("LLM 冲突判定输出解析失败,返回空列表以继续流程。")
            return []
        try:
            # Normalize pydantic v2 models / plain values to a list of dicts.
            return [response.model_dump()] if isinstance(response, BaseModel) else [response]
        except Exception:
            try:
                # Older pydantic (<2) exposes .dict() instead of .model_dump().
                return [response.dict()]
            except Exception:
                logging.warning("无法标准化冲突判定返回类型,尝试直接封装为列表。")
                return [response]
|
||||
@@ -1,54 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""反思执行模块
|
||||
|
||||
本模块提供反思执行功能,使用LLM对冲突记忆进行反思和解决。
|
||||
从 app.core.memory.src.data_config_api.reflexion 迁移而来。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, List
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.prompt.template_render import render_reflexion_prompt
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_storage_schema import ReflexionResultSchema
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
async def reflexion(ref_data: List[Any]) -> List[Any]:
    """
    Reflexes on the given reference data using the reflexion.jinja2 template.

    Args:
        ref_data: List of conflicting memory records to resolve.
    Returns:
        List of reflexion results (normalized to dicts where possible);
        empty list when the LLM output cannot be parsed.
    """
    # Imported lazily to avoid a circular import at module load time.
    from app.core.memory.utils.config import definitions as config_defs
    with get_db_context() as db:
        factory = MemoryClientFactory(db)
        client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
        rendered_prompt = await render_reflexion_prompt(ref_data, ReflexionResultSchema)
        messages = [{"role": "user", "content": rendered_prompt}]
        print(f"提示词长度: {len(rendered_prompt)}")

        print(f"====== 反思开始 ======\n")
        start_time = time.time()
        # Structured LLM call: the response is parsed into ReflexionResultSchema.
        response = await client.response_structured(messages, ReflexionResultSchema)
        end_time = time.time()
        print(f"反思耗时: {end_time - start_time} 秒")
        print(f"反思原始输出:(type={type(response)})\n{response}")

        if not response:
            # Parsing failed; return empty so the caller's pipeline can continue.
            logging.error("LLM 反思输出解析失败,返回空列表以继续流程。")
            return []
        # Normalize to a list of dicts so the self-reflexion main flow can
        # update the database uniformly.
        try:
            return [response.model_dump()] if isinstance(response, BaseModel) else [response]
        except Exception:
            try:
                # Older pydantic (<2) exposes .dict() instead of .model_dump().
                return [response.dict()]
            except Exception:
                logging.warning("无法标准化反思返回类型,尝试直接封装为列表。")
                return [response]
|
||||
@@ -1,254 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""自我反思主执行模块
|
||||
|
||||
本模块提供自我反思引擎的主流程,包括:
|
||||
- 获取反思数据
|
||||
- 冲突判断
|
||||
- 反思执行
|
||||
- 记忆更新
|
||||
|
||||
从 app.core.memory.src.data_config_api.self_reflexion 迁移而来。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Dict, List
|
||||
|
||||
#TODO: Fix this
|
||||
|
||||
# Default values (previously from definitions.py)
|
||||
REFLEXION_ENABLED = os.getenv("REFLEXION_ENABLED", "false").lower() == "true"
|
||||
REFLEXION_ITERATION_PERIOD = os.getenv("REFLEXION_ITERATION_PERIOD", "3")
|
||||
REFLEXION_RANGE = os.getenv("REFLEXION_RANGE", "retrieval")
|
||||
REFLEXION_BASELINE = os.getenv("REFLEXION_BASELINE", "TIME")
|
||||
|
||||
from app.core.memory.utils.config.get_data import get_data
|
||||
from app.core.memory.utils.self_reflexion_utils.evaluate import conflict
|
||||
from app.core.memory.utils.self_reflexion_utils.reflexion import reflexion
|
||||
from app.db import get_db
|
||||
from app.models.retrieval_info import RetrievalInfo
|
||||
from app.repositories.neo4j.cypher_queries import UPDATE_STATEMENT_INVALID_AT
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# 并发限制(可通过环境变量覆盖)
|
||||
CONCURRENCY = int(os.getenv("REFLEXION_CONCURRENCY", "5"))
|
||||
|
||||
# 确保 INFO 级别日志输出到终端
|
||||
_root_logger = logging.getLogger()
|
||||
if not _root_logger.handlers:
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
else:
|
||||
_root_logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
async def get_reflexion_data(host_id: uuid.UUID) -> List[Any]:
    """
    Fetch the memory data to be judged, according to the reflexion scope.

    Args:
        host_id: Host ID whose retrieval data should be examined.
    Returns:
        List of memory records within the configured reflexion scope.
    Raises:
        ValueError: When REFLEXION_RANGE is neither "partial" nor "all".
            NOTE(review): the env default for REFLEXION_RANGE is
            "retrieval", which this function does NOT handle — with the
            default config this raises. Confirm the intended default.
    """
    if REFLEXION_RANGE == "partial":
        return await get_data(host_id)
    elif REFLEXION_RANGE == "all":
        # Full-scope reflexion is not implemented yet; treated as no data.
        return []
    else:
        raise ValueError(f"未知的反思范围: {REFLEXION_RANGE}")
|
||||
|
||||
|
||||
async def run_conflict(conflict_data: List[Any]) -> List[Any]:
    """
    Determine whether the given reflexion data contains conflicts.

    Args:
        conflict_data: Candidate memory records to judge.
    Returns:
        The entries flagged as conflicting (dicts with conflict == True),
        or an empty list when there are none or filtering fails.
    """
    if not conflict_data:
        return []

    conflict_data = await conflict(conflict_data)
    # Keep only entries that the LLM explicitly flagged (conflict is True).
    try:
        return [c for c in conflict_data if isinstance(c, dict) and c.get("conflict") is True]
    except Exception:
        # Unexpected shape from the LLM: treat as "no conflicts".
        return []
|
||||
|
||||
|
||||
async def run_reflexion(reflexion_data: List[Any]) -> Any:
    """
    Run reflexion over each conflicting memory and resolve the conflicts.

    Items are processed concurrently (bounded by CONCURRENCY) to shorten
    overall wall-clock time.

    Args:
        reflexion_data: Conflicting memory records.
    Returns:
        List of resolved results as dicts (LLM output); items that fail
        or return nothing are dropped.
    """
    if not reflexion_data:
        return []
    # Bound the number of concurrent LLM calls.
    sem = asyncio.Semaphore(CONCURRENCY)

    async def _reflex_one(item: Any) -> Dict[str, Any] | None:
        # Reflect on a single conflict; returns a dict or None on failure.
        async with sem:
            try:
                result_list = await reflexion([item])
                if not result_list:
                    return None
                obj = result_list[0]
                # Normalize pydantic v2 / v1 / plain-dict results to dict.
                if hasattr(obj, "model_dump"):
                    return obj.model_dump()
                elif hasattr(obj, "dict"):
                    return obj.dict()
                elif isinstance(obj, dict):
                    return obj
            except Exception as e:
                logging.warning(f"反思失败,跳过一项: {e}")
                return None

    tasks = [_reflex_one(item) for item in reflexion_data]
    results = await asyncio.gather(*tasks, return_exceptions=False)
    # Drop None/empty results.
    return [r for r in results if r]
|
||||
|
||||
|
||||
async def update_memory(solved_data: List[Any], host_id: uuid.UUID) -> str:
    """
    Apply resolved memories to the memory store (Neo4j).

    For each resolved item the statement's invalid_at timestamp is
    updated in Neo4j; when at least one update succeeds, the host's
    retrieval data is purged from the relational DB in the finally block.

    Args:
        solved_data: Resolved memories returned by the LLM (list of dicts).
        host_id: Host ID whose retrieval data is purged on success.
    Returns:
        A user-facing result string (Chinese): update success / failure.
    """
    flag = False
    if not solved_data:
        return "数据缺失,更新失败"
    if not isinstance(solved_data, list):
        return "数据格式错误,更新失败"
    neo4j_connector = Neo4jConnector()
    try:
        print(f"====== 更新记忆开始 ======\n")

        # Bound the number of concurrent Neo4j writes.
        sem = asyncio.Semaphore(CONCURRENCY)
        success_count = 0

        async def _update_one(item: Dict[str, Any]) -> bool:
            # Update one resolved memory; returns True on success.
            async with sem:
                try:
                    if not isinstance(item, dict):
                        return False
                    if not item:
                        return False
                    resolved = item.get("resolved")
                    if not isinstance(resolved, dict) or not resolved:
                        logging.warning(f"反思结果无可更新内容,跳过此项: {item}")
                        return False
                    resolved_mem = resolved.get("resolved_memory")
                    if not isinstance(resolved_mem, dict) or not resolved_mem:
                        logging.warning(f"反思结果缺少 resolved_memory,跳过此项: {item}")
                        return False
                    group_id = resolved_mem.get("group_id")
                    id = resolved_mem.get("id")  # NOTE: shadows the builtin id()
                    # Use the invalid_at field as the new expiry timestamp.
                    new_invalid_at = resolved_mem.get("invalid_at")
                    if not all([group_id, id, new_invalid_at]):
                        logging.warning(f"记忆更新参数缺失,跳过此项: {item}")
                        return False
                    await neo4j_connector.execute_query(
                        UPDATE_STATEMENT_INVALID_AT,
                        group_id=group_id,
                        id=id,
                        new_invalid_at=new_invalid_at,
                    )
                    return True
                except Exception as e:
                    logging.error(f"更新单条记忆失败: {e}")
                    return False

        tasks = [_update_one(item) for item in solved_data if isinstance(item, dict)]
        results = await asyncio.gather(*tasks, return_exceptions=False)
        success_count = sum(1 for r in results if r)

        logging.info(f"成功更新 {success_count} 条记忆")
        flag = success_count > 0
        return "更新成功" if flag else "更新失败"
    except Exception as e:
        logging.error(f"更新记忆库失败: {e}")
        return "更新失败"
    finally:
        if flag:  # On success, purge this host's retrieval data from the DB
            db: Session = next(get_db())
            try:
                db.query(RetrievalInfo).filter(RetrievalInfo.host_id == host_id).delete()
                db.commit()
                logging.info(f"成功删除 {success_count} 条检索数据")
            except Exception as e:
                logging.error(f"删除数据库中的检索数据失败: {e}")
            finally:
                db.close()
|
||||
|
||||
|
||||
|
||||
async def _append_json(label: str, data: Any) -> None:
    """Append labelled JSON to reflexion_data.json for offline inspection.

    The write runs in a background thread so the event loop is not
    blocked by file I/O.
    """
    def _write():
        with open("reflexion_data.json", "a", encoding="utf-8") as f:
            f.write(f"### {label} ###\n")
            json.dump(data, f, ensure_ascii=False, indent=4)
            f.write("\n\n")
    # Await the background-thread write so no coroutine is left unawaited.
    await asyncio.to_thread(_write)
|
||||
|
||||
|
||||
async def self_reflexion(host_id: uuid.UUID) -> str:
    """
    Self-reflexion engine entry point: run the full reflexion pipeline.

    Pipeline: fetch reflexion data -> conflict detection -> reflexion
    (conflict resolution) -> memory update. Each stage short-circuits
    with a descriptive message when it yields no data. A no-op when
    REFLEXION_ENABLED is false.

    Args:
        host_id: Host ID to reflect on.
    Returns:
        A user-facing result string (Chinese) describing the outcome.
    """
    if not REFLEXION_ENABLED:
        return "未开启反思..."
    print(f"====== 自我反思流程开始 ======\n")
    reflexion_data = await get_reflexion_data(host_id)
    if not reflexion_data:
        print(f"====== 自我反思流程结束 ======\n")
        return "无反思数据,结束反思"
    print(f"反思数据获取成功,共 {len(reflexion_data)} 条")

    conflict_data = await run_conflict(reflexion_data)
    if not conflict_data:
        print(f"====== 自我反思流程结束 ======\n")
        return "无冲突,无需反思"
    print(f"冲突记忆类型: {type(conflict_data)}")
    # Persist the detected conflicts for offline inspection.
    await _append_json("conflict", conflict_data)

    solved_data = await run_reflexion(conflict_data)
    if not solved_data:
        print(f"====== 自我反思流程结束 ======\n")
        return "反思失败,未解决冲突"
    print(f"解决冲突后的记忆类型: {type(solved_data)}")
    # Persist the resolved memories as well.
    await _append_json("solved_data", solved_data)

    result = await update_memory(solved_data, host_id)
    print(f"更新记忆库结果: {result}")
    print(f"====== 自我反思流程结束 ======\n")
    return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
# host_id = uuid.UUID("3f6ff1eb-50c7-4765-8e89-e4566be33333")
|
||||
host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122")
|
||||
asyncio.run(self_reflexion(host_id))
|
||||
@@ -1,29 +1,30 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.repositories.neo4j.cypher_queries import (
|
||||
CHUNK_EMBEDDING_SEARCH,
|
||||
ENTITY_EMBEDDING_SEARCH,
|
||||
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
||||
SEARCH_CHUNK_BY_CHUNK_ID,
|
||||
SEARCH_CHUNKS_BY_CONTENT,
|
||||
SEARCH_DIALOGUE_BY_DIALOG_ID,
|
||||
SEARCH_ENTITIES_BY_NAME,
|
||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
||||
SEARCH_STATEMENTS_BY_CREATED_AT,
|
||||
SEARCH_STATEMENTS_BY_KEYWORD,
|
||||
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
||||
SEARCH_STATEMENTS_BY_TEMPORAL,
|
||||
SEARCH_STATEMENTS_BY_VALID_AT,
|
||||
SEARCH_STATEMENTS_G_CREATED_AT,
|
||||
SEARCH_STATEMENTS_G_VALID_AT,
|
||||
SEARCH_STATEMENTS_L_CREATED_AT,
|
||||
SEARCH_STATEMENTS_L_VALID_AT,
|
||||
STATEMENT_EMBEDDING_SEARCH,
|
||||
)
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.cypher_queries import (
|
||||
SEARCH_STATEMENTS_BY_KEYWORD,
|
||||
SEARCH_ENTITIES_BY_NAME,
|
||||
SEARCH_CHUNKS_BY_CONTENT,
|
||||
STATEMENT_EMBEDDING_SEARCH,
|
||||
CHUNK_EMBEDDING_SEARCH,
|
||||
ENTITY_EMBEDDING_SEARCH,
|
||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
||||
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
||||
SEARCH_STATEMENTS_BY_TEMPORAL,
|
||||
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
||||
SEARCH_DIALOGUE_BY_DIALOG_ID,
|
||||
SEARCH_CHUNK_BY_CHUNK_ID,
|
||||
SEARCH_STATEMENTS_BY_CREATED_AT,
|
||||
SEARCH_STATEMENTS_BY_VALID_AT,
|
||||
SEARCH_STATEMENTS_G_CREATED_AT,
|
||||
SEARCH_STATEMENTS_L_CREATED_AT,
|
||||
SEARCH_STATEMENTS_G_VALID_AT,
|
||||
SEARCH_STATEMENTS_L_VALID_AT,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -55,8 +56,12 @@ async def _update_activation_values_batch(
|
||||
return []
|
||||
|
||||
# 延迟导入以避免循环依赖
|
||||
from app.core.memory.storage_services.forgetting_engine.access_history_manager import AccessHistoryManager
|
||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
|
||||
from app.core.memory.storage_services.forgetting_engine.access_history_manager import (
|
||||
AccessHistoryManager,
|
||||
)
|
||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import (
|
||||
ACTRCalculator,
|
||||
)
|
||||
|
||||
# 创建计算器和管理器实例
|
||||
actr_calculator = ACTRCalculator()
|
||||
@@ -292,6 +297,13 @@ async def search_graph(
|
||||
else:
|
||||
results[key] = result
|
||||
|
||||
# Deduplicate results before updating activation values
|
||||
# This prevents duplicates from propagating through the pipeline
|
||||
from app.core.memory.src.search import _deduplicate_results
|
||||
for key in results:
|
||||
if isinstance(results[key], list):
|
||||
results[key] = _deduplicate_results(results[key])
|
||||
|
||||
# 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary)
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
@@ -397,6 +409,13 @@ async def search_graph_by_embedding(
|
||||
else:
|
||||
results[key] = result
|
||||
|
||||
# Deduplicate results before updating activation values
|
||||
# This prevents duplicates from propagating through the pipeline
|
||||
from app.core.memory.src.search import _deduplicate_results
|
||||
for key in results:
|
||||
if isinstance(results[key], list):
|
||||
results[key] = _deduplicate_results(results[key])
|
||||
|
||||
# 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary)
|
||||
update_start = time.time()
|
||||
results = await _update_search_results_activation(
|
||||
|
||||
13
web/src/utils/yamlExport.ts
Normal file
13
web/src/utils/yamlExport.ts
Normal file
@@ -0,0 +1,13 @@
|
||||
import yaml from 'js-yaml';
|
||||
|
||||
|
||||
/**
 * Serialize `data` to YAML and trigger a browser download.
 *
 * Dumps the value with js-yaml, wraps it in a Blob, clicks a synthetic
 * `<a>` element to start the download, then revokes the object URL to
 * release the reference.
 *
 * @param data - Any js-yaml-serializable value to export.
 * @param filename - Suggested download filename (defaults to 'export.yaml').
 */
export const exportToYaml = (data: unknown, filename: string = 'export.yaml') => {
  const yamlStr = yaml.dump(data);
  const blob = new Blob([yamlStr], { type: 'text/yaml' });
  const url = URL.createObjectURL(blob);
  const a = document.createElement('a');
  a.href = url;
  a.download = filename;
  a.click();
  URL.revokeObjectURL(url);
};
|
||||
@@ -221,12 +221,12 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
saveAgentConfig(data.app_id, params)
|
||||
.then(() => {
|
||||
.then((res) => {
|
||||
if (flag) {
|
||||
message.success(t('common.saveSuccess'))
|
||||
}
|
||||
setIsSave(false)
|
||||
resolve(true)
|
||||
resolve(res)
|
||||
}).catch(error => {
|
||||
reject(error)
|
||||
})
|
||||
|
||||
@@ -58,16 +58,14 @@ const Cluster = forwardRef<ClusterRef>((_props, ref) => {
|
||||
}))
|
||||
}
|
||||
|
||||
console.log('params', params)
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
form.validateFields().then(() => {
|
||||
saveMultiAgentConfig(id as string, params)
|
||||
.then(() => {
|
||||
.then((res) => {
|
||||
if (flag) {
|
||||
message.success(t('common.saveSuccess'))
|
||||
}
|
||||
resolve(true)
|
||||
resolve(res)
|
||||
})
|
||||
.catch(error => {
|
||||
reject(error)
|
||||
|
||||
@@ -11,7 +11,7 @@ import exportIcon from '@/assets/images/export_hover.svg'
|
||||
import deleteIcon from '@/assets/images/delete_hover.svg'
|
||||
import type { Application, ApplicationModalRef } from '@/views/ApplicationManagement/types';
|
||||
import ApplicationModal from '@/views/ApplicationManagement/components/ApplicationModal'
|
||||
import type { CopyModalRef, WorkflowRef } from '../types'
|
||||
import type { CopyModalRef, AgentRef, ClusterRef, WorkflowRef } from '../types'
|
||||
import { deleteApplication } from '@/api/application'
|
||||
import CopyModal from './CopyModal'
|
||||
|
||||
@@ -30,10 +30,11 @@ interface ConfigHeaderProps {
|
||||
handleChangeTab: (key: string) => void;
|
||||
refresh: () => void;
|
||||
workflowRef: React.RefObject<WorkflowRef>
|
||||
appRef?: React.RefObject<AgentRef | ClusterRef | WorkflowRef>
|
||||
}
|
||||
const ConfigHeader: FC<ConfigHeaderProps> = ({
|
||||
application, activeTab, handleChangeTab, refresh,
|
||||
workflowRef
|
||||
workflowRef,
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
const navigate = useNavigate();
|
||||
@@ -48,7 +49,7 @@ const ConfigHeader: FC<ConfigHeaderProps> = ({
|
||||
}))
|
||||
}
|
||||
const formatMenuItems = () => {
|
||||
const items = ['edit', 'copy', 'delete'].map(key => ({
|
||||
const items = ['edit', 'copy', 'export', 'delete'].map(key => ({
|
||||
key,
|
||||
icon: <img src={menuIcons[key]} className="rb:w-4 rb:h-4 rb:mr-2" />,
|
||||
label: t(`common.${key}`),
|
||||
@@ -59,7 +60,6 @@ const ConfigHeader: FC<ConfigHeaderProps> = ({
|
||||
}
|
||||
}
|
||||
const handleClick: MenuProps['onClick'] = ({ key }) => {
|
||||
console.log('key', key)
|
||||
switch (key) {
|
||||
case 'edit':
|
||||
applicationModalRef.current?.handleOpen(application as Application)
|
||||
|
||||
@@ -60,10 +60,11 @@ const ApplicationConfig: React.FC = () => {
|
||||
handleChangeTab={handleChangeTab}
|
||||
application={application as Application}
|
||||
refresh={getApplicationInfo}
|
||||
appRef={application?.type === 'agent' ? agentRef : application?.type === 'multi_agent' ? clusterRef : application?.type === 'workflow' ? workflowRef : undefined}
|
||||
workflowRef={workflowRef}
|
||||
/>
|
||||
{activeTab === 'arrangement' && application?.type === 'agent' && <Agent ref={agentRef} />}
|
||||
{activeTab === 'arrangement' && application?.type === 'multi_agent' && <Cluster ref={clusterRef} application={application as Application} />}
|
||||
{activeTab === 'arrangement' && application?.type === 'multi_agent' && <Cluster ref={clusterRef} />}
|
||||
{activeTab === 'arrangement' && application?.type === 'workflow' && <Workflow ref={workflowRef} />}
|
||||
{activeTab === 'api' && <Api application={application} />}
|
||||
{activeTab === 'release' && <ReleasePage data={application as Application} refresh={getApplicationInfo} />}
|
||||
|
||||
@@ -2,7 +2,7 @@ import { useState } from 'react';
|
||||
import { Popover } from 'antd';
|
||||
import clsx from 'clsx';
|
||||
import type { ReactShapeConfig } from '@antv/x6-react-shape';
|
||||
import { nodeLibrary, graphNodeLibrary } from '../../constant';
|
||||
import { nodeLibrary, graphNodeLibrary, edgeAttrs } from '../../constant';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
||||
@@ -47,7 +47,7 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
||||
graph.addEdge({
|
||||
source: { cell: edge.getSourceCellId(), port: edge.getSourcePortId() },
|
||||
target: { cell: newNode.id, port: newNode.getPorts().find((port: any) => port.group === 'left')?.id || 'left' },
|
||||
attrs: edge.getAttrs(),
|
||||
...edgeAttrs
|
||||
});
|
||||
});
|
||||
|
||||
@@ -57,7 +57,7 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
||||
graph.addEdge({
|
||||
source: { cell: newNode.id, port: newNode.getPorts().find((port: any) => port.group === 'right')?.id || 'right' },
|
||||
target: { cell: edge.getTargetCellId(), port: targetPortId },
|
||||
attrs: edge.getAttrs(),
|
||||
...edgeAttrs
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -2,9 +2,7 @@ import { useEffect } from 'react';
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import clsx from 'clsx';
|
||||
import type { ReactShapeConfig } from '@antv/x6-react-shape';
|
||||
import { graphNodeLibrary } from '../../constant';
|
||||
|
||||
import { edge_color } from '../../hooks/useWorkflowGraph'
|
||||
import { graphNodeLibrary, edgeAttrs } from '../../constant';
|
||||
|
||||
const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
||||
const data = node.getData() || {};
|
||||
@@ -56,16 +54,7 @@ const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
||||
graph.addEdge({
|
||||
source: { cell: cycleStartNode.id, port: sourcePort },
|
||||
target: { cell: addNode.id, port: targetPort },
|
||||
attrs: {
|
||||
line: {
|
||||
stroke: edge_color,
|
||||
strokeWidth: 1,
|
||||
targetMarker: {
|
||||
name: 'block',
|
||||
size: 8,
|
||||
},
|
||||
},
|
||||
},
|
||||
...edgeAttrs,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -122,16 +111,7 @@ const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
||||
cell: addNode.id,
|
||||
port: targetPorts.find((port: any) => port.group === 'left')?.id || 'left'
|
||||
},
|
||||
attrs: {
|
||||
line: {
|
||||
stroke: edge_color,
|
||||
strokeWidth: 1,
|
||||
targetMarker: {
|
||||
name: 'block',
|
||||
size: 2,
|
||||
},
|
||||
},
|
||||
},
|
||||
...edgeAttrs
|
||||
}
|
||||
graph.addEdge(edgeConfig)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { useEffect, useState } from 'react';
|
||||
import { Popover } from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { nodeLibrary, graphNodeLibrary } from '../constant';
|
||||
import { nodeLibrary, graphNodeLibrary, edgeAttrs } from '../constant';
|
||||
|
||||
interface PortClickHandlerProps {
|
||||
graph: any;
|
||||
@@ -149,16 +149,7 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
|
||||
graph.addEdge({
|
||||
source: { cell: sourceNode.id, port: sourcePort },
|
||||
target: { cell: newNode.id, port: targetPort },
|
||||
attrs: {
|
||||
line: {
|
||||
stroke: '#155EEF',
|
||||
strokeWidth: 1,
|
||||
targetMarker: {
|
||||
name: 'block',
|
||||
size: 8,
|
||||
},
|
||||
},
|
||||
},
|
||||
...edgeAttrs
|
||||
// zIndex: sourceNodeData.cycle && sourceNodeType == 'cycle-start' ? 1 : sourceNodeData.cycle ? 2 : 0
|
||||
});
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import { Form, Button, Select, Space, Divider, InputNumber, Radio, type SelectPr
|
||||
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
|
||||
import VariableSelect from '../VariableSelect'
|
||||
import Editor from '../../Editor'
|
||||
import { edgeAttrs } from '../../../constant'
|
||||
|
||||
interface CaseListProps {
|
||||
value?: Array<{ logical_operator: 'and' | 'or'; expressions: { left: string; operator: string; right: string; input_type?: string; }[] }>;
|
||||
@@ -120,16 +121,7 @@ const CaseList: FC<CaseListProps> = ({
|
||||
graphRef.current?.addEdge({
|
||||
source: { cell: sourceCellId, port: sourcePortId },
|
||||
target: { cell: selectedNode.id, port: targetPortId },
|
||||
attrs: {
|
||||
line: {
|
||||
stroke: '#155EEF',
|
||||
strokeWidth: 1,
|
||||
targetMarker: {
|
||||
name: 'block',
|
||||
size: 8,
|
||||
},
|
||||
},
|
||||
},
|
||||
...edgeAttrs,
|
||||
});
|
||||
}
|
||||
graphRef.current?.removeCell(edge);
|
||||
@@ -174,16 +166,7 @@ const CaseList: FC<CaseListProps> = ({
|
||||
graphRef.current?.addEdge({
|
||||
source: { cell: selectedNode.id, port: newPortId },
|
||||
target: { cell: targetCellId, port: targetPortId },
|
||||
attrs: {
|
||||
line: {
|
||||
stroke: '#155EEF',
|
||||
strokeWidth: 1,
|
||||
targetMarker: {
|
||||
name: 'block',
|
||||
size: 8,
|
||||
},
|
||||
},
|
||||
},
|
||||
...edgeAttrs
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import { Graph, Node } from '@antv/x6';
|
||||
|
||||
import Editor from '../../Editor';
|
||||
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
|
||||
import { edgeAttrs } from '../../../constant'
|
||||
|
||||
interface CategoryListProps {
|
||||
parentName: string;
|
||||
@@ -70,16 +71,7 @@ const CategoryList: FC<CategoryListProps> = ({ parentName, selectedNode, graphRe
|
||||
graphRef.current?.addEdge({
|
||||
source: { cell: sourceCellId, port: sourcePortId },
|
||||
target: { cell: selectedNode.id, port: targetPortId },
|
||||
attrs: {
|
||||
line: {
|
||||
stroke: '#155EEF',
|
||||
strokeWidth: 1,
|
||||
targetMarker: {
|
||||
name: 'block',
|
||||
size: 8,
|
||||
},
|
||||
},
|
||||
},
|
||||
...edgeAttrs
|
||||
});
|
||||
}
|
||||
return;
|
||||
@@ -110,16 +102,7 @@ const CategoryList: FC<CategoryListProps> = ({ parentName, selectedNode, graphRe
|
||||
graphRef.current?.addEdge({
|
||||
source: { cell: selectedNode.id, port: newPortId },
|
||||
target: { cell: targetCellId, port: targetPortId },
|
||||
attrs: {
|
||||
line: {
|
||||
stroke: '#155EEF',
|
||||
strokeWidth: 1,
|
||||
targetMarker: {
|
||||
name: 'block',
|
||||
size: 8,
|
||||
},
|
||||
},
|
||||
},
|
||||
...edgeAttrs
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { type FC } from 'react'
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { Form, Select, Input, Button } from 'antd'
|
||||
import { Form, Select, Input, Button, InputNumber, Radio } from 'antd'
|
||||
import VariableSelect from '../VariableSelect'
|
||||
|
||||
import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin'
|
||||
@@ -93,6 +93,7 @@ const CycleVarsList: FC<CycleVarsListProps> = ({
|
||||
</Button>
|
||||
</div>
|
||||
{fields.map(({ key, name }, index) => {
|
||||
const currentType = value?.[index]?.type;
|
||||
const currentInputType = value?.[index]?.input_type;
|
||||
|
||||
return (
|
||||
@@ -131,7 +132,8 @@ const CycleVarsList: FC<CycleVarsListProps> = ({
|
||||
</div>
|
||||
|
||||
<Form.Item name={[name, 'value']} noStyle>
|
||||
{currentInputType === 'variable' ? (
|
||||
{currentInputType === 'variable'
|
||||
? (
|
||||
<VariableSelect
|
||||
placeholder={t('common.pleaseSelect')}
|
||||
options={availableOptions.filter(option => {
|
||||
@@ -143,7 +145,20 @@ const CycleVarsList: FC<CycleVarsListProps> = ({
|
||||
variant="borderless"
|
||||
size="small"
|
||||
/>
|
||||
) : (
|
||||
)
|
||||
: currentType === 'number'
|
||||
? <InputNumber
|
||||
placeholder={t('common.pleaseEnter')}
|
||||
variant="borderless"
|
||||
className="rb:w-full! rb:my-1!"
|
||||
onChange={(value) => form.setFieldValue([name, 'value'], value)}
|
||||
/>
|
||||
: currentType === 'boolean'
|
||||
? <Radio.Group block>
|
||||
<Radio.Button value={true}>True</Radio.Button>
|
||||
<Radio.Button value={false}>False</Radio.Button>
|
||||
</Radio.Group>
|
||||
: (
|
||||
<Input.TextArea
|
||||
placeholder={t('common.pleaseEnter')}
|
||||
rows={3}
|
||||
|
||||
@@ -517,6 +517,8 @@ interface NodeConfig {
|
||||
ports?: PortsConfig;
|
||||
}
|
||||
|
||||
export const edge_color = '#155EEF';
|
||||
export const edge_selected_color = '#4DA8FF'
|
||||
// 统一的端口 markup 配置
|
||||
export const portMarkup = [
|
||||
{
|
||||
@@ -534,9 +536,9 @@ export const portAttrs = {
|
||||
body: {
|
||||
r: 6,
|
||||
magnet: true,
|
||||
stroke: '#155EEF',
|
||||
stroke: edge_color,
|
||||
strokeWidth: 2,
|
||||
fill: '#155EEF',
|
||||
fill: edge_color,
|
||||
},
|
||||
label: {
|
||||
text: '+',
|
||||
@@ -776,4 +778,18 @@ export const outputVariable: { [key: string]: OutputVariable } = {
|
||||
{ name: "output", type: "string" },
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
export const edgeAttrs = {
|
||||
attrs: {
|
||||
line: {
|
||||
stroke: edge_color,
|
||||
strokeWidth: 1,
|
||||
targetMarker: {
|
||||
name: 'block',
|
||||
width: 4,
|
||||
height: 4,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -5,7 +5,7 @@ import { App } from 'antd'
|
||||
import { Graph, Node, MiniMap, Snapline, Clipboard, Keyboard, type Edge } from '@antv/x6';
|
||||
import { register } from '@antv/x6-react-shape';
|
||||
|
||||
import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs } from '../constant';
|
||||
import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs, edgeAttrs, edge_color, edge_selected_color } from '../constant';
|
||||
import type { WorkflowConfig, NodeProperties, ChatVariable } from '../types';
|
||||
import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application'
|
||||
import type { PortMetadata } from '@antv/x6/lib/model/port';
|
||||
@@ -23,12 +23,8 @@ export interface UseWorkflowGraphReturn {
|
||||
setSelectedNode: React.Dispatch<React.SetStateAction<Node | null>>;
|
||||
zoomLevel: number;
|
||||
setZoomLevel: React.Dispatch<React.SetStateAction<number>>;
|
||||
canUndo: boolean;
|
||||
canRedo: boolean;
|
||||
isHandMode: boolean;
|
||||
setIsHandMode: React.Dispatch<React.SetStateAction<boolean>>;
|
||||
onUndo: () => void;
|
||||
onRedo: () => void;
|
||||
onDrop: (event: React.DragEvent) => void;
|
||||
blankClick: () => void;
|
||||
deleteEvent: () => boolean | void;
|
||||
@@ -39,8 +35,6 @@ export interface UseWorkflowGraphReturn {
|
||||
setChatVariables: React.Dispatch<React.SetStateAction<ChatVariable[]>>;
|
||||
}
|
||||
|
||||
export const edge_color = '#155EEF';
|
||||
const edge_selected_color = '#4DA8FF'
|
||||
export const useWorkflowGraph = ({
|
||||
containerRef,
|
||||
miniMapRef,
|
||||
@@ -51,9 +45,6 @@ export const useWorkflowGraph = ({
|
||||
const graphRef = useRef<Graph>();
|
||||
const [selectedNode, setSelectedNode] = useState<Node | null>(null);
|
||||
const [zoomLevel, setZoomLevel] = useState(1);
|
||||
const historyRef = useRef<{ undoStack: string[], redoStack: string[] }>({ undoStack: [], redoStack: [] });
|
||||
const [canUndo, setCanUndo] = useState(false);
|
||||
const [canRedo, setCanRedo] = useState(false);
|
||||
const [isHandMode, setIsHandMode] = useState(true);
|
||||
const [config, setConfig] = useState<WorkflowConfig | null>(null);
|
||||
const [chatVariables, setChatVariables] = useState<ChatVariable[]>([])
|
||||
@@ -338,17 +329,7 @@ export const useWorkflowGraph = ({
|
||||
port: targetPorts.find((port: any) => port.group === 'left')?.id || 'left'
|
||||
},
|
||||
connector: { name: 'smooth' },
|
||||
attrs: {
|
||||
line: {
|
||||
stroke: edge_color,
|
||||
strokeWidth: 1,
|
||||
targetMarker: {
|
||||
name: 'diamond',
|
||||
width: 4,
|
||||
height: 4,
|
||||
},
|
||||
},
|
||||
},
|
||||
...edgeAttrs
|
||||
// zIndex: loopIterationCount
|
||||
}
|
||||
|
||||
@@ -368,48 +349,6 @@ export const useWorkflowGraph = ({
|
||||
}, 200)
|
||||
}
|
||||
}
|
||||
|
||||
const saveState = () => {
|
||||
if (!graphRef.current) return;
|
||||
const state = JSON.stringify(graphRef.current.toJSON());
|
||||
historyRef.current.undoStack.push(state);
|
||||
historyRef.current.redoStack = [];
|
||||
if (historyRef.current.undoStack.length > 50) {
|
||||
historyRef.current.undoStack.shift();
|
||||
}
|
||||
updateHistoryState();
|
||||
};
|
||||
|
||||
const updateHistoryState = () => {
|
||||
setCanUndo(historyRef.current.undoStack.length > 1);
|
||||
setCanRedo(historyRef.current.redoStack.length > 0);
|
||||
};
|
||||
|
||||
// 撤销
|
||||
const onUndo = () => {
|
||||
if (!graphRef.current || historyRef.current.undoStack.length === 0) return;
|
||||
const { undoStack = [], redoStack = [] } = historyRef.current
|
||||
|
||||
const currentState = JSON.stringify(graphRef.current.toJSON());
|
||||
const prevState = undoStack[undoStack.length - 2];
|
||||
|
||||
historyRef.current.redoStack = [...redoStack, currentState]
|
||||
historyRef.current.undoStack = undoStack.slice(0, undoStack.length - 1)
|
||||
graphRef.current.fromJSON(JSON.parse(prevState));
|
||||
updateHistoryState();
|
||||
};
|
||||
// 重做
|
||||
const onRedo = () => {
|
||||
if (!graphRef.current || historyRef.current.redoStack.length === 0) return;
|
||||
const { undoStack = [], redoStack = [] } = historyRef.current
|
||||
|
||||
const nextState = redoStack[redoStack.length - 1];
|
||||
|
||||
historyRef.current.undoStack = [...undoStack, nextState]
|
||||
historyRef.current.redoStack = redoStack.slice(0, redoStack.length - 1)
|
||||
graphRef.current.fromJSON(JSON.parse(nextState));
|
||||
updateHistoryState();
|
||||
};
|
||||
// 使用插件
|
||||
const setupPlugins = () => {
|
||||
if (!graphRef.current || !miniMapRef.current) return;
|
||||
@@ -563,20 +502,6 @@ export const useWorkflowGraph = ({
|
||||
}
|
||||
return false;
|
||||
};
|
||||
// 撤销快捷键事件
|
||||
const undoEvent = () => {
|
||||
if (canUndo) {
|
||||
onUndo();
|
||||
}
|
||||
return false;
|
||||
};
|
||||
// 重做快捷键事件
|
||||
const redoEvent = () => {
|
||||
if (canRedo) {
|
||||
onRedo();
|
||||
}
|
||||
return false;
|
||||
};
|
||||
// 删除选中的节点和连线事件
|
||||
const deleteEvent = () => {
|
||||
if (!graphRef.current) return;
|
||||
@@ -748,8 +673,6 @@ export const useWorkflowGraph = ({
|
||||
background: {
|
||||
color: '#F0F3F8',
|
||||
},
|
||||
// width: container.clientWidth || 800,
|
||||
// height: container.clientHeight || 600,
|
||||
autoResize: true,
|
||||
grid: {
|
||||
visible: true,
|
||||
@@ -765,37 +688,26 @@ export const useWorkflowGraph = ({
|
||||
enabled: true,
|
||||
},
|
||||
connecting: {
|
||||
// router: 'orth',
|
||||
// router: 'manhattan',
|
||||
connector: {
|
||||
name: 'smooth',
|
||||
args: {
|
||||
radius: 8,
|
||||
},
|
||||
},
|
||||
anchor: 'center',
|
||||
anchor: 'midSide',
|
||||
connectionPoint: 'anchor',
|
||||
allowBlank: false,
|
||||
allowLoop: false,
|
||||
allowNode: false,
|
||||
allowEdge: false,
|
||||
allowPort: true,
|
||||
allowMulti: true,
|
||||
highlight: true,
|
||||
snap: {
|
||||
radius: 20,
|
||||
},
|
||||
createEdge() {
|
||||
return graphRef.current?.createEdge({
|
||||
attrs: {
|
||||
line: {
|
||||
stroke: edge_color,
|
||||
strokeWidth: 1,
|
||||
targetMarker: {
|
||||
name: 'diamond',
|
||||
width: 4,
|
||||
height: 4,
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
return graphRef.current?.createEdge(edgeAttrs);
|
||||
},
|
||||
validateConnection({ sourceCell, targetCell, targetMagnet }) {
|
||||
if (!targetMagnet) return false;
|
||||
@@ -901,27 +813,7 @@ export const useWorkflowGraph = ({
|
||||
// 监听缩放事件
|
||||
graphRef.current.on('scale', scaleEvent);
|
||||
// 监听节点移动事件
|
||||
// graphRef.current.on('node:moved', nodeMoved);
|
||||
graphRef.current.on('node:change:position', nodeChangePosition);
|
||||
|
||||
// 监听画布变化事件
|
||||
const events = [
|
||||
'node:added',
|
||||
'node:removed',
|
||||
'edge:added',
|
||||
'edge:removed',
|
||||
];
|
||||
events.forEach(event => {
|
||||
graphRef.current!.on(event, () => {
|
||||
console.log('event', event);
|
||||
setTimeout(() => saveState(), 50);
|
||||
});
|
||||
});
|
||||
|
||||
// 监听撤销键盘事件
|
||||
graphRef.current.bindKey(['ctrl+z', 'cmd+z'], undoEvent);
|
||||
// 监听重做键盘事件
|
||||
graphRef.current.bindKey(['ctrl+shift+z', 'cmd+shift+z', 'ctrl+y', 'cmd+y'], redoEvent);
|
||||
graphRef.current.on('node:moved', nodeMoved);
|
||||
// 监听复制键盘事件
|
||||
graphRef.current.bindKey(['ctrl+c', 'cmd+c'], copyEvent);
|
||||
// 监听粘贴键盘事件
|
||||
@@ -929,11 +821,6 @@ export const useWorkflowGraph = ({
|
||||
// 删除选中的节点和连线
|
||||
graphRef.current.bindKey(['ctrl+d', 'cmd+d', 'delete', 'backspace'], deleteEvent);
|
||||
|
||||
// 保存初始状态
|
||||
setTimeout(() => saveState(), 100);
|
||||
// init window hook
|
||||
(window as Window & { __x6_instances__?: Graph[] }).__x6_instances__ = [];
|
||||
(window as Window & { __x6_instances__?: Graph[] }).__x6_instances__?.push(graphRef.current);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
@@ -1146,11 +1033,11 @@ export const useWorkflowGraph = ({
|
||||
}),
|
||||
}
|
||||
saveWorkflowConfig(config.app_id, params as WorkflowConfig)
|
||||
.then(() => {
|
||||
.then((res) => {
|
||||
if (flag) {
|
||||
message.success(t('common.saveSuccess'))
|
||||
}
|
||||
resolve(true)
|
||||
resolve(res)
|
||||
}).catch(error => {
|
||||
reject(error)
|
||||
})
|
||||
@@ -1165,12 +1052,8 @@ export const useWorkflowGraph = ({
|
||||
setSelectedNode,
|
||||
zoomLevel,
|
||||
setZoomLevel,
|
||||
canUndo,
|
||||
canRedo,
|
||||
isHandMode,
|
||||
setIsHandMode,
|
||||
onUndo,
|
||||
onRedo,
|
||||
onDrop,
|
||||
blankClick,
|
||||
deleteEvent,
|
||||
|
||||
Reference in New Issue
Block a user