Merge branch 'develop' into feature/codeNode_zy

This commit is contained in:
yingzhao
2026-01-27 11:40:54 +08:00
committed by GitHub
213 changed files with 4506 additions and 9864 deletions

3
.gitignore vendored
View File

@@ -35,3 +35,6 @@ nltk_data/
tika-server*.jar* tika-server*.jar*
cl100k_base.tiktoken cl100k_base.tiktoken
libssl*.deb libssl*.deb
sandbox/lib/seccomp_python/target
sandbox/lib/seccomp_nodejs/target

0
api/app/__init__.py Normal file
View File

View File

@@ -12,6 +12,7 @@ from fastapi import APIRouter, Depends, Query, HTTPException, status
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Optional from typing import Optional
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from uuid import UUID
from app.core.response_utils import success from app.core.response_utils import success
from app.dependencies import get_current_user from app.dependencies import get_current_user
@@ -32,11 +33,11 @@ router = APIRouter(
class EmotionConfigQuery(BaseModel): class EmotionConfigQuery(BaseModel):
"""情绪配置查询请求模型""" """情绪配置查询请求模型"""
config_id: int = Field(..., description="配置ID") config_id: UUID = Field(..., description="配置ID")
class EmotionConfigUpdate(BaseModel): class EmotionConfigUpdate(BaseModel):
"""情绪配置更新请求模型""" """情绪配置更新请求模型"""
config_id: int = Field(..., description="配置ID") config_id: UUID = Field(..., description="配置ID")
emotion_enabled: bool = Field(..., description="是否启用情绪提取") emotion_enabled: bool = Field(..., description="是否启用情绪提取")
emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID") emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID")
emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词") emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词")
@@ -45,7 +46,7 @@ class EmotionConfigUpdate(BaseModel):
@router.get("/read_config", response_model=ApiResponse) @router.get("/read_config", response_model=ApiResponse)
def get_emotion_config( def get_emotion_config(
config_id: int = Query(..., description="配置ID"), config_id: UUID = Query(..., description="配置ID"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):

View File

@@ -53,7 +53,7 @@ async def get_emotion_tags(
api_logger.info( api_logger.info(
f"用户 {current_user.username} 请求获取情绪标签统计", f"用户 {current_user.username} 请求获取情绪标签统计",
extra={ extra={
"group_id": request.group_id, "end_user_id": request.end_user_id,
"emotion_type": request.emotion_type, "emotion_type": request.emotion_type,
"start_date": request.start_date, "start_date": request.start_date,
"end_date": request.end_date, "end_date": request.end_date,
@@ -63,7 +63,7 @@ async def get_emotion_tags(
# 调用服务层 # 调用服务层
data = await emotion_service.get_emotion_tags( data = await emotion_service.get_emotion_tags(
end_user_id=request.group_id, end_user_id=request.end_user_id,
emotion_type=request.emotion_type, emotion_type=request.emotion_type,
start_date=request.start_date, start_date=request.start_date,
end_date=request.end_date, end_date=request.end_date,
@@ -73,7 +73,7 @@ async def get_emotion_tags(
api_logger.info( api_logger.info(
"情绪标签统计获取成功", "情绪标签统计获取成功",
extra={ extra={
"group_id": request.group_id, "end_user_id": request.end_user_id,
"total_count": data.get("total_count", 0), "total_count": data.get("total_count", 0),
"tags_count": len(data.get("tags", [])) "tags_count": len(data.get("tags", []))
} }
@@ -84,7 +84,7 @@ async def get_emotion_tags(
except Exception as e: except Exception as e:
api_logger.error( api_logger.error(
f"获取情绪标签统计失败: {str(e)}", f"获取情绪标签统计失败: {str(e)}",
extra={"group_id": request.group_id}, extra={"end_user_id": request.end_user_id},
exc_info=True exc_info=True
) )
raise HTTPException( raise HTTPException(
@@ -105,7 +105,7 @@ async def get_emotion_wordcloud(
api_logger.info( api_logger.info(
f"用户 {current_user.username} 请求获取情绪词云数据", f"用户 {current_user.username} 请求获取情绪词云数据",
extra={ extra={
"group_id": request.group_id, "end_user_id": request.end_user_id,
"emotion_type": request.emotion_type, "emotion_type": request.emotion_type,
"limit": request.limit "limit": request.limit
} }
@@ -113,7 +113,7 @@ async def get_emotion_wordcloud(
# 调用服务层 # 调用服务层
data = await emotion_service.get_emotion_wordcloud( data = await emotion_service.get_emotion_wordcloud(
end_user_id=request.group_id, end_user_id=request.end_user_id,
emotion_type=request.emotion_type, emotion_type=request.emotion_type,
limit=request.limit limit=request.limit
) )
@@ -121,7 +121,7 @@ async def get_emotion_wordcloud(
api_logger.info( api_logger.info(
"情绪词云数据获取成功", "情绪词云数据获取成功",
extra={ extra={
"group_id": request.group_id, "end_user_id": request.end_user_id,
"total_keywords": data.get("total_keywords", 0) "total_keywords": data.get("total_keywords", 0)
} }
) )
@@ -131,7 +131,7 @@ async def get_emotion_wordcloud(
except Exception as e: except Exception as e:
api_logger.error( api_logger.error(
f"获取情绪词云数据失败: {str(e)}", f"获取情绪词云数据失败: {str(e)}",
extra={"group_id": request.group_id}, extra={"end_user_id": request.end_user_id},
exc_info=True exc_info=True
) )
raise HTTPException( raise HTTPException(
@@ -159,21 +159,21 @@ async def get_emotion_health(
api_logger.info( api_logger.info(
f"用户 {current_user.username} 请求获取情绪健康指数", f"用户 {current_user.username} 请求获取情绪健康指数",
extra={ extra={
"group_id": request.group_id, "end_user_id": request.end_user_id,
"time_range": request.time_range "time_range": request.time_range
} }
) )
# 调用服务层 # 调用服务层
data = await emotion_service.calculate_emotion_health_index( data = await emotion_service.calculate_emotion_health_index(
end_user_id=request.group_id, end_user_id=request.end_user_id,
time_range=request.time_range time_range=request.time_range
) )
api_logger.info( api_logger.info(
"情绪健康指数获取成功", "情绪健康指数获取成功",
extra={ extra={
"group_id": request.group_id, "end_user_id": request.end_user_id,
"health_score": data.get("health_score", 0), "health_score": data.get("health_score", 0),
"level": data.get("level", "未知") "level": data.get("level", "未知")
} }
@@ -186,7 +186,7 @@ async def get_emotion_health(
except Exception as e: except Exception as e:
api_logger.error( api_logger.error(
f"获取情绪健康指数失败: {str(e)}", f"获取情绪健康指数失败: {str(e)}",
extra={"group_id": request.group_id}, extra={"end_user_id": request.end_user_id},
exc_info=True exc_info=True
) )
raise HTTPException( raise HTTPException(
@@ -206,7 +206,7 @@ async def get_emotion_suggestions(
"""获取个性化情绪建议(从缓存读取) """获取个性化情绪建议(从缓存读取)
Args: Args:
request: 包含 group_id 和可选的 config_id request: 包含 end_user_id 和可选的 config_id
db: 数据库会话 db: 数据库会话
current_user: 当前用户 current_user: 当前用户
@@ -217,22 +217,22 @@ async def get_emotion_suggestions(
api_logger.info( api_logger.info(
f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)", f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)",
extra={ extra={
"group_id": request.group_id, "end_user_id": request.end_user_id,
"config_id": request.config_id "config_id": request.config_id
} }
) )
# 从缓存获取建议 # 从缓存获取建议
data = await emotion_service.get_cached_suggestions( data = await emotion_service.get_cached_suggestions(
end_user_id=request.group_id, end_user_id=request.end_user_id,
db=db db=db
) )
if data is None: if data is None:
# 缓存不存在或已过期 # 缓存不存在或已过期
api_logger.info( api_logger.info(
f"用户 {request.group_id} 的建议缓存不存在或已过期", f"用户 {request.end_user_id} 的建议缓存不存在或已过期",
extra={"group_id": request.group_id} extra={"end_user_id": request.end_user_id}
) )
return fail( return fail(
BizCode.NOT_FOUND, BizCode.NOT_FOUND,
@@ -243,7 +243,7 @@ async def get_emotion_suggestions(
api_logger.info( api_logger.info(
"个性化建议获取成功(缓存)", "个性化建议获取成功(缓存)",
extra={ extra={
"group_id": request.group_id, "end_user_id": request.end_user_id,
"suggestions_count": len(data.get("suggestions", [])) "suggestions_count": len(data.get("suggestions", []))
} }
) )
@@ -253,7 +253,7 @@ async def get_emotion_suggestions(
except Exception as e: except Exception as e:
api_logger.error( api_logger.error(
f"获取个性化建议失败: {str(e)}", f"获取个性化建议失败: {str(e)}",
extra={"group_id": request.group_id}, extra={"end_user_id": request.end_user_id},
exc_info=True exc_info=True
) )
raise HTTPException( raise HTTPException(

View File

@@ -122,10 +122,10 @@ def validate_confidence_threshold(threshold: float) -> None:
raise ValueError("confidence_threshold must be between 0.0 and 1.0") raise ValueError("confidence_threshold must be between 0.0 and 1.0")
@router.get("/preferences/{user_id}", response_model=ApiResponse) @router.get("/preferences/{end_user_id}", response_model=ApiResponse)
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def get_preference_tags( async def get_preference_tags(
user_id: str, end_user_id: str,
confidence_threshold: float = Query(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"), confidence_threshold: float = Query(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"),
tag_category: Optional[str] = Query(None, description="Filter by tag category"), tag_category: Optional[str] = Query(None, description="Filter by tag category"),
start_date: Optional[datetime] = Query(None, description="Filter start date"), start_date: Optional[datetime] = Query(None, description="Filter start date"),
@@ -137,7 +137,7 @@ async def get_preference_tags(
Get user preference tags from cache. Get user preference tags from cache.
Args: Args:
user_id: Target user ID end_user_id: Target end user ID
confidence_threshold: Minimum confidence score (0.0-1.0) confidence_threshold: Minimum confidence score (0.0-1.0)
tag_category: Optional category filter tag_category: Optional category filter
start_date: Optional start date filter start_date: Optional start date filter
@@ -146,20 +146,20 @@ async def get_preference_tags(
Returns: Returns:
List of preference tags from cache List of preference tags from cache
""" """
api_logger.info(f"Preference tags requested for user: {user_id} (from cache)") api_logger.info(f"Preference tags requested for user: {end_user_id} (from cache)")
try: try:
# Validate inputs # Validate inputs
validate_user_id(user_id) validate_user_id(end_user_id)
# Create service with user-specific config # Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id) service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
# Get cached profile # Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db) cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None: if cached_profile is None:
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期") api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail( return fail(
BizCode.NOT_FOUND, BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像", "画像缓存不存在或已过期,请右上角刷新生成新画像",
@@ -192,17 +192,17 @@ async def get_preference_tags(
filtered_preferences.append(pref) filtered_preferences.append(pref)
api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {user_id} (from cache)") api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {end_user_id} (from cache)")
return success(data=filtered_preferences, msg="偏好标签获取成功(缓存)") return success(data=filtered_preferences, msg="偏好标签获取成功(缓存)")
except Exception as e: except Exception as e:
return handle_implicit_memory_error(e, "偏好标签获取", user_id) return handle_implicit_memory_error(e, "偏好标签获取", end_user_id)
@router.get("/portrait/{user_id}", response_model=ApiResponse) @router.get("/portrait/{end_user_id}", response_model=ApiResponse)
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def get_dimension_portrait( async def get_dimension_portrait(
user_id: str, end_user_id: str,
include_history: bool = Query(False, description="Include historical trends"), include_history: bool = Query(False, description="Include historical trends"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
@@ -211,26 +211,26 @@ async def get_dimension_portrait(
Get user's four-dimension personality portrait from cache. Get user's four-dimension personality portrait from cache.
Args: Args:
user_id: Target user ID end_user_id: Target end user ID
include_history: Whether to include historical trend data (ignored for cached data) include_history: Whether to include historical trend data (ignored for cached data)
Returns: Returns:
Four-dimension personality portrait from cache Four-dimension personality portrait from cache
""" """
api_logger.info(f"Dimension portrait requested for user: {user_id} (from cache)") api_logger.info(f"Dimension portrait requested for user: {end_user_id} (from cache)")
try: try:
# Validate inputs # Validate inputs
validate_user_id(user_id) validate_user_id(end_user_id)
# Create service with user-specific config # Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id) service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
# Get cached profile # Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db) cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None: if cached_profile is None:
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期") api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail( return fail(
BizCode.NOT_FOUND, BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像", "画像缓存不存在或已过期,请右上角刷新生成新画像",
@@ -240,17 +240,17 @@ async def get_dimension_portrait(
# Extract portrait from cache # Extract portrait from cache
portrait = cached_profile.get("portrait", {}) portrait = cached_profile.get("portrait", {})
api_logger.info(f"Dimension portrait retrieved for user: {user_id} (from cache)") api_logger.info(f"Dimension portrait retrieved for user: {end_user_id} (from cache)")
return success(data=portrait, msg="四维画像获取成功(缓存)") return success(data=portrait, msg="四维画像获取成功(缓存)")
except Exception as e: except Exception as e:
return handle_implicit_memory_error(e, "四维画像获取", user_id) return handle_implicit_memory_error(e, "四维画像获取", end_user_id)
@router.get("/interest-areas/{user_id}", response_model=ApiResponse) @router.get("/interest-areas/{end_user_id}", response_model=ApiResponse)
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def get_interest_area_distribution( async def get_interest_area_distribution(
user_id: str, end_user_id: str,
include_trends: bool = Query(False, description="Include trend analysis"), include_trends: bool = Query(False, description="Include trend analysis"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
@@ -259,26 +259,26 @@ async def get_interest_area_distribution(
Get user's interest area distribution from cache. Get user's interest area distribution from cache.
Args: Args:
user_id: Target user ID end_user_id: Target end user ID
include_trends: Whether to include trend analysis data (ignored for cached data) include_trends: Whether to include trend analysis data (ignored for cached data)
Returns: Returns:
Interest area distribution from cache Interest area distribution from cache
""" """
api_logger.info(f"Interest area distribution requested for user: {user_id} (from cache)") api_logger.info(f"Interest area distribution requested for user: {end_user_id} (from cache)")
try: try:
# Validate inputs # Validate inputs
validate_user_id(user_id) validate_user_id(end_user_id)
# Create service with user-specific config # Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id) service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
# Get cached profile # Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db) cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None: if cached_profile is None:
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期") api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail( return fail(
BizCode.NOT_FOUND, BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像", "画像缓存不存在或已过期,请右上角刷新生成新画像",
@@ -288,17 +288,17 @@ async def get_interest_area_distribution(
# Extract interest areas from cache # Extract interest areas from cache
interest_areas = cached_profile.get("interest_areas", {}) interest_areas = cached_profile.get("interest_areas", {})
api_logger.info(f"Interest area distribution retrieved for user: {user_id} (from cache)") api_logger.info(f"Interest area distribution retrieved for user: {end_user_id} (from cache)")
return success(data=interest_areas, msg="兴趣领域分布获取成功(缓存)") return success(data=interest_areas, msg="兴趣领域分布获取成功(缓存)")
except Exception as e: except Exception as e:
return handle_implicit_memory_error(e, "兴趣领域分布获取", user_id) return handle_implicit_memory_error(e, "兴趣领域分布获取", end_user_id)
@router.get("/habits/{user_id}", response_model=ApiResponse) @router.get("/habits/{end_user_id}", response_model=ApiResponse)
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def get_behavior_habits( async def get_behavior_habits(
user_id: str, end_user_id: str,
confidence_level: Optional[str] = Query(None, regex="^(high|medium|low)$", description="Filter by confidence level"), confidence_level: Optional[str] = Query(None, regex="^(high|medium|low)$", description="Filter by confidence level"),
frequency_pattern: Optional[str] = Query(None, regex="^(daily|weekly|monthly|seasonal|occasional|event_triggered)$", description="Filter by frequency pattern"), frequency_pattern: Optional[str] = Query(None, regex="^(daily|weekly|monthly|seasonal|occasional|event_triggered)$", description="Filter by frequency pattern"),
time_period: Optional[str] = Query(None, regex="^(current|past)$", description="Filter by time period"), time_period: Optional[str] = Query(None, regex="^(current|past)$", description="Filter by time period"),
@@ -309,7 +309,7 @@ async def get_behavior_habits(
Get user's behavioral habits from cache. Get user's behavioral habits from cache.
Args: Args:
user_id: Target user ID end_user_id: Target end user ID
confidence_level: Filter by confidence level (high, medium, low) confidence_level: Filter by confidence level (high, medium, low)
frequency_pattern: Filter by frequency pattern (daily, weekly, monthly, seasonal, occasional, event_triggered) frequency_pattern: Filter by frequency pattern (daily, weekly, monthly, seasonal, occasional, event_triggered)
time_period: Filter by time period (current, past) time_period: Filter by time period (current, past)
@@ -317,20 +317,20 @@ async def get_behavior_habits(
Returns: Returns:
List of behavioral habits from cache List of behavioral habits from cache
""" """
api_logger.info(f"Behavior habits requested for user: {user_id} (from cache)") api_logger.info(f"Behavior habits requested for user: {end_user_id} (from cache)")
try: try:
# Validate inputs # Validate inputs
validate_user_id(user_id) validate_user_id(end_user_id)
# Create service with user-specific config # Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id) service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
# Get cached profile # Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db) cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None: if cached_profile is None:
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期") api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail( return fail(
BizCode.NOT_FOUND, BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像", "画像缓存不存在或已过期,请右上角刷新生成新画像",
@@ -368,11 +368,11 @@ async def get_behavior_habits(
filtered_habits.append(habit) filtered_habits.append(habit)
api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {user_id} (from cache)") api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {end_user_id} (from cache)")
return success(data=filtered_habits, msg="行为习惯获取成功(缓存)") return success(data=filtered_habits, msg="行为习惯获取成功(缓存)")
except Exception as e: except Exception as e:
return handle_implicit_memory_error(e, "行为习惯获取", user_id) return handle_implicit_memory_error(e, "行为习惯获取", end_user_id)

View File

@@ -125,7 +125,7 @@ async def write_server(
Write service endpoint - processes write operations synchronously Write service endpoint - processes write operations synchronously
Args: Args:
user_input: Write request containing message and group_id user_input: Write request containing message and end_user_id
Returns: Returns:
Response with write operation status Response with write operation status
@@ -160,19 +160,18 @@ async def write_server(
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
storage_type = 'neo4j' storage_type = 'neo4j'
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
try: try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input) messages_list = memory_agent_service.get_messages_list(user_input)
result = await memory_agent_service.write_memory( result = await memory_agent_service.write_memory(
user_input.group_id, user_input.end_user_id,
messages_list, # 传递结构化消息列表 messages_list,
config_id, config_id,
db, db,
storage_type, storage_type,
user_rag_memory_id user_rag_memory_id
) )
return success(data=result, msg="写入成功") return success(data=result, msg="写入成功")
except BaseException as e: except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
@@ -196,7 +195,7 @@ async def write_server_async(
Async write service endpoint - enqueues write processing to Celery Async write service endpoint - enqueues write processing to Celery
Args: Args:
user_input: Write request containing message and group_id user_input: Write request containing message and end_user_id
Returns: Returns:
Task ID for tracking async operation Task ID for tracking async operation
@@ -229,7 +228,7 @@ async def write_server_async(
task = celery_app.send_task( task = celery_app.send_task(
"app.core.memory.agent.write_message", "app.core.memory.agent.write_message",
args=[user_input.group_id, messages_list, config_id, storage_type, user_rag_memory_id] args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id]
) )
api_logger.info(f"Write task queued: {task.id}") api_logger.info(f"Write task queued: {task.id}")
@@ -255,16 +254,14 @@ async def read_server(
- "2": Direct answer based on context - "2": Direct answer based on context
Args: Args:
user_input: Read request with message, history, search_switch, and group_id user_input: Read request with message, history, search_switch, and end_user_id
Returns: Returns:
Response with query answer Response with query answer
""" """
config_id = user_input.config_id config_id = user_input.config_id
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
api_logger.info(f"Read service: workspace_id={workspace_id}, config_id={config_id}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type( storage_type = workspace_service.get_workspace_storage_type(
db=db, db=db,
workspace_id=workspace_id, workspace_id=workspace_id,
@@ -279,12 +276,13 @@ async def read_server(
name="USER_RAG_MERORY", name="USER_RAG_MERORY",
workspace_id=workspace_id workspace_id=workspace_id
) )
if knowledge: user_rag_memory_id = str(knowledge.id) if knowledge:
user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Read service: group={user_input.group_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}") api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
try: try:
result = await memory_agent_service.read_memory( result = await memory_agent_service.read_memory(
user_input.group_id, user_input.end_user_id,
user_input.message, user_input.message,
user_input.history, user_input.history,
user_input.search_switch, user_input.search_switch,
@@ -295,17 +293,20 @@ async def read_server(
) )
if str(user_input.search_switch) == "2": if str(user_input.search_switch) == "2":
retrieve_info = result['answer'] retrieve_info = result['answer']
history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id) history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
query = user_input.message query = user_input.message
# 调用 memory_agent_service 的方法生成最终答案 # 调用 memory_agent_service 的方法生成最终答案
result['answer'] = await memory_agent_service.generate_summary_from_retrieve( result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
end_user_id=user_input.end_user_id,
retrieve_info=retrieve_info, retrieve_info=retrieve_info,
history=history, history=history,
query=query, query=query,
config_id=config_id, config_id=config_id,
db=db db=db
) )
if "信息不足,无法回答" in result['answer']:
result['answer']=retrieve_info
return success(data=result, msg="回复对话消息成功") return success(data=result, msg="回复对话消息成功")
except BaseException as e: except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
@@ -403,7 +404,7 @@ async def read_server_async(
try: try:
task = celery_app.send_task( task = celery_app.send_task(
"app.core.memory.agent.read_message", "app.core.memory.agent.read_message",
args=[user_input.group_id, user_input.message, user_input.history, user_input.search_switch, args=[user_input.end_user_id, user_input.message, user_input.history, user_input.search_switch,
config_id, storage_type, user_rag_memory_id] config_id, storage_type, user_rag_memory_id]
) )
api_logger.info(f"Read task queued: {task.id}") api_logger.info(f"Read task queued: {task.id}")
@@ -447,7 +448,7 @@ async def get_read_task_result(
return success( return success(
data={ data={
"result": task_result.get("result"), "result": task_result.get("result"),
"group_id": task_result.get("group_id"), "end_user_id": task_result.get("end_user_id"),
"elapsed_time": task_result.get("elapsed_time"), "elapsed_time": task_result.get("elapsed_time"),
"task_id": task_id "task_id": task_id
}, },
@@ -524,7 +525,7 @@ async def get_write_task_result(
return success( return success(
data={ data={
"result": task_result.get("result"), "result": task_result.get("result"),
"group_id": task_result.get("group_id"), "end_user_id": task_result.get("end_user_id"),
"elapsed_time": task_result.get("elapsed_time"), "elapsed_time": task_result.get("elapsed_time"),
"task_id": task_id "task_id": task_id
}, },
@@ -578,12 +579,12 @@ async def status_type(
Determine the type of user message (read or write) Determine the type of user message (read or write)
Args: Args:
user_input: Request containing user message and group_id user_input: Request containing user message and end_user_id
Returns: Returns:
Type classification result Type classification result
""" """
api_logger.info(f"Status type check requested for group {user_input.group_id}") api_logger.info(f"Status type check requested for group {user_input.end_user_id}")
try: try:
# 获取标准化的消息列表 # 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input) messages_list = memory_agent_service.get_messages_list(user_input)
@@ -624,7 +625,7 @@ async def get_knowledge_type_stats_api(
会对缺失类型补 0返回字典形式。 会对缺失类型补 0返回字典形式。
可选按状态过滤。 可选按状态过滤。
- 知识库类型根据当前用户的 current_workspace_id 过滤 - 知识库类型根据当前用户的 current_workspace_id 过滤
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (group_id) 过滤 - memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤
- 如果用户没有当前工作空间或未提供 end_user_id对应的统计返回 0 - 如果用户没有当前工作空间或未提供 end_user_id对应的统计返回 0
""" """
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}") api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
@@ -697,7 +698,7 @@ async def get_user_profile_api(
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
获取工作空间下Popular Memory Tags,包含: 获取用户详情,包含:
- name: 用户名字(直接使用 end_user_id - name: 用户名字(直接使用 end_user_id
- tags: 3个用户特征标签从语句和实体中LLM总结 - tags: 3个用户特征标签从语句和实体中LLM总结
- hot_tags: 4个热门记忆标签 - hot_tags: 4个热门记忆标签

View File

@@ -49,62 +49,134 @@ async def get_workspace_end_users(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
""" """
获取工作空间的宿主列表 获取工作空间的宿主列表(高性能优化版本 v2
返回格式与原 memory_list 接口中的 end_users 字段相同, 优化策略:
并包含每个用户的记忆配置信息memory_config_id 和 memory_config_name 1. 批量查询 end_users一次查询而非循环
2. 并发查询所有用户的记忆数量Neo4j
3. RAG 模式使用批量查询(一次 SQL
4. 只返回必要字段减少数据传输
5. 添加短期缓存减少重复查询
6. 并发执行配置查询和记忆数量查询
返回格式:
{
"end_user": {"id": "uuid", "other_name": "名称"},
"memory_num": {"total": 数量},
"memory_config": {"memory_config_id": "id", "memory_config_name": "名称"}
}
""" """
import asyncio
import json
from app.aioRedis import aio_redis_get, aio_redis_set
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
# 尝试从缓存获取30秒缓存
cache_key = f"end_users:workspace:{workspace_id}"
try:
cached_data = await aio_redis_get(cache_key)
if cached_data:
api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}")
return success(data=json.loads(cached_data), msg="宿主列表获取成功")
except Exception as e:
api_logger.warning(f"Redis 缓存读取失败: {str(e)}")
# 获取当前空间类型 # 获取当前空间类型
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user) current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表") api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
# 获取 end_users已优化为批量查询
end_users = memory_dashboard_service.get_workspace_end_users( end_users = memory_dashboard_service.get_workspace_end_users(
db=db, db=db,
workspace_id=workspace_id, workspace_id=workspace_id,
current_user=current_user current_user=current_user
) )
# 批量获取所有用户的记忆配置信息(优化:一次查询而非 N 次) if not end_users:
end_user_ids = [str(user.id) for user in end_users] api_logger.info("工作空间下没有宿主")
memory_configs_map = {} # 缓存空结果,避免重复查询
if end_user_ids:
try: try:
memory_configs_map = get_end_users_connected_configs_batch(end_user_ids, db) await aio_redis_set(cache_key, json.dumps([]), expire=30)
except Exception as e:
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
return success(data=[], msg="宿主列表获取成功")
end_user_ids = [str(user.id) for user in end_users]
# 并发执行两个独立的查询任务
async def get_memory_configs():
"""获取记忆配置(在线程池中执行同步查询)"""
try:
return await asyncio.to_thread(
get_end_users_connected_configs_batch,
end_user_ids, db
)
except Exception as e: except Exception as e:
api_logger.error(f"批量获取记忆配置失败: {str(e)}") api_logger.error(f"批量获取记忆配置失败: {str(e)}")
# 失败时使用空字典,不影响其他数据返回 return {}
async def get_memory_nums():
"""获取记忆数量"""
if current_workspace_type == "rag":
# RAG 模式:批量查询
try:
chunk_map = await asyncio.to_thread(
memory_dashboard_service.get_users_total_chunk_batch,
end_user_ids, db, current_user
)
return {uid: {"total": count} for uid, count in chunk_map.items()}
except Exception as e:
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}")
return {uid: {"total": 0} for uid in end_user_ids}
elif current_workspace_type == "neo4j":
# Neo4j 模式:并发查询(带并发限制)
# 使用信号量限制并发数,避免大量用户时压垮 Neo4j
MAX_CONCURRENT_QUERIES = 10
semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES)
async def get_neo4j_memory_num(end_user_id: str):
async with semaphore:
try:
return await memory_storage_service.search_all(end_user_id)
except Exception as e:
api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}")
return {"total": 0}
memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids])
return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))}
return {uid: {"total": 0} for uid in end_user_ids}
# 并发执行配置查询和记忆数量查询
memory_configs_map, memory_nums_map = await asyncio.gather(
get_memory_configs(),
get_memory_nums()
)
# 构建结果(优化:使用列表推导式)
result = [] result = []
for end_user in end_users: for end_user in end_users:
memory_num = {}
if current_workspace_type == "neo4j":
# EndUser 是 Pydantic 模型,直接访问属性而不是使用 .get()
memory_num = await memory_storage_service.search_all(str(end_user.id))
elif current_workspace_type == "rag":
memory_num = {
"total":memory_dashboard_service.get_current_user_total_chunk(str(end_user.id), db, current_user)
}
# 从批量查询结果中获取配置信息
user_id = str(end_user.id) user_id = str(end_user.id)
memory_config_info = memory_configs_map.get(user_id, { config_info = memory_configs_map.get(user_id, {})
"memory_config_id": None, result.append({
"memory_config_name": None 'end_user': {
'id': user_id,
'other_name': end_user.other_name
},
'memory_num': memory_nums_map.get(user_id, {"total": 0}),
'memory_config': {
"memory_config_id": config_info.get("memory_config_id"),
"memory_config_name": config_info.get("memory_config_name")
}
}) })
# 只保留需要的字段,移除 error 字段(如果有 # 写入缓存30秒过期
memory_config = { try:
"memory_config_id": memory_config_info.get("memory_config_id"), await aio_redis_set(cache_key, json.dumps(result), expire=30)
"memory_config_name": memory_config_info.get("memory_config_name") except Exception as e:
} api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
result.append(
{
'end_user': end_user,
'memory_num': memory_num,
'memory_config': memory_config
}
)
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录") api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
return success(data=result, msg="宿主列表获取成功") return success(data=result, msg="宿主列表获取成功")

View File

@@ -11,6 +11,7 @@
""" """
from typing import Optional from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -106,7 +107,7 @@ async def trigger_forgetting_cycle(
# 调用服务层执行遗忘周期 # 调用服务层执行遗忘周期
report = await forget_service.trigger_forgetting_cycle( report = await forget_service.trigger_forgetting_cycle(
db=db, db=db,
group_id=end_user_id, # 服务层方法的参数名是 group_id end_user_id=end_user_id, # 服务层方法的参数名是 end_user_id
max_merge_batch_size=payload.max_merge_batch_size, max_merge_batch_size=payload.max_merge_batch_size,
min_days_since_access=payload.min_days_since_access, min_days_since_access=payload.min_days_since_access,
config_id=config_id config_id=config_id
@@ -128,7 +129,7 @@ async def trigger_forgetting_cycle(
@router.get("/read_config", response_model=ApiResponse) @router.get("/read_config", response_model=ApiResponse)
async def read_forgetting_config( async def read_forgetting_config(
config_id: int, config_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
@@ -236,7 +237,7 @@ async def update_forgetting_config(
@router.get("/stats", response_model=ApiResponse) @router.get("/stats", response_model=ApiResponse)
async def get_forgetting_stats( async def get_forgetting_stats(
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
@@ -246,7 +247,7 @@ async def get_forgetting_stats(
返回知识层节点统计、激活值分布等信息。 返回知识层节点统计、激活值分布等信息。
Args: Args:
group_id: 组ID即 end_user_id可选 end_user_id: 组ID即 end_user_id可选
current_user: 当前用户 current_user: 当前用户
db: 数据库会话 db: 数据库会话
@@ -260,20 +261,20 @@ async def get_forgetting_stats(
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 如果提供了 group_id通过它获取 config_id # 如果提供了 end_user_id通过它获取 config_id
config_id = None config_id = None
if group_id: if end_user_id:
try: try:
from app.services.memory_agent_service import get_end_user_connected_config from app.services.memory_agent_service import get_end_user_connected_config
connected_config = get_end_user_connected_config(group_id, db) connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id") config_id = connected_config.get("memory_config_id")
if config_id is None: if config_id is None:
api_logger.warning(f"终端用户 {group_id} 未关联记忆配置") api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {group_id} 未关联记忆配置", "memory_config_id is None") return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
api_logger.debug(f"通过 group_id={group_id} 获取到 config_id={config_id}") api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
except ValueError as e: except ValueError as e:
api_logger.warning(f"获取终端用户配置失败: {str(e)}") api_logger.warning(f"获取终端用户配置失败: {str(e)}")
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError") return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
@@ -283,14 +284,14 @@ async def get_forgetting_stats(
api_logger.info( api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: " f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: "
f"group_id={group_id}, config_id={config_id}" f"end_user_id={end_user_id}, config_id={config_id}"
) )
try: try:
# 调用服务层获取统计信息 # 调用服务层获取统计信息
stats = await forget_service.get_forgetting_stats( stats = await forget_service.get_forgetting_stats(
db=db, db=db,
group_id=group_id, end_user_id=end_user_id,
config_id=config_id config_id=config_id
) )

View File

@@ -27,27 +27,27 @@ router = APIRouter(
) )
@router.get("/{group_id}/count", response_model=ApiResponse) @router.get("/{end_user_id}/count", response_model=ApiResponse)
def get_memory_count( def get_memory_count(
group_id: uuid.UUID, end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
"""Retrieve perceptual memory statistics for a user group. """Retrieve perceptual memory statistics for a user group.
Args: Args:
group_id: ID of the user group (usually end_user_id in this context) end_user_id: ID of the user group (usually end_user_id in this context)
current_user: Current authenticated user current_user: Current authenticated user
db: Database session db: Database session
Returns: Returns:
ApiResponse: Response containing memory count statistics ApiResponse: Response containing memory count statistics
""" """
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, group_id={group_id}") api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, end_user_id={end_user_id}")
try: try:
service = MemoryPerceptualService(db) service = MemoryPerceptualService(db)
count_stats = service.get_memory_count(group_id) count_stats = service.get_memory_count(end_user_id)
api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}") api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}")
@@ -57,37 +57,37 @@ def get_memory_count(
) )
except Exception as e: except Exception as e:
api_logger.error(f"Failed to fetch memory statistics: group_id={group_id}, error={str(e)}") api_logger.error(f"Failed to fetch memory statistics: end_user_id={end_user_id}, error={str(e)}")
return fail( return fail(
code=BizCode.INTERNAL_ERROR, code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch memory statistics", msg="Failed to fetch memory statistics",
) )
@router.get("/{group_id}/last_visual", response_model=ApiResponse) @router.get("/{end_user_id}/last_visual", response_model=ApiResponse)
def get_last_visual_memory( def get_last_visual_memory(
group_id: uuid.UUID, end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
"""Retrieve the most recent VISION-type memory for a user. """Retrieve the most recent VISION-type memory for a user.
Args: Args:
group_id: ID of the user group end_user_id: ID of the user group
current_user: Current authenticated user current_user: Current authenticated user
db: Database session db: Database session
Returns: Returns:
ApiResponse: Metadata of the latest visual memory ApiResponse: Metadata of the latest visual memory
""" """
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, group_id={group_id}") api_logger.info(f"Fetching latest visual memory: user={current_user.username}, end_user_id={end_user_id}")
try: try:
service = MemoryPerceptualService(db) service = MemoryPerceptualService(db)
visual_memory = service.get_latest_visual_memory(group_id) visual_memory = service.get_latest_visual_memory(end_user_id)
if visual_memory is None: if visual_memory is None:
api_logger.info(f"No visual memory found: group_id={group_id}") api_logger.info(f"No visual memory found: end_user_id={end_user_id}")
return success( return success(
data=None, data=None,
msg="No visual memory available" msg="No visual memory available"
@@ -101,37 +101,37 @@ def get_last_visual_memory(
) )
except Exception as e: except Exception as e:
api_logger.error(f"Failed to fetch latest visual memory: group_id={group_id}, error={str(e)}") api_logger.error(f"Failed to fetch latest visual memory: end_user_id={end_user_id}, error={str(e)}")
return fail( return fail(
code=BizCode.INTERNAL_ERROR, code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest visual memory", msg="Failed to fetch latest visual memory",
) )
@router.get("/{group_id}/last_listen", response_model=ApiResponse) @router.get("/{end_user_id}/last_listen", response_model=ApiResponse)
def get_last_memory_listen( def get_last_memory_listen(
group_id: uuid.UUID, end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
"""Retrieve the most recent AUDIO-type memory for a user. """Retrieve the most recent AUDIO-type memory for a user.
Args: Args:
group_id: ID of the user group end_user_id: ID of the user group
current_user: Current authenticated user current_user: Current authenticated user
db: Database session db: Database session
Returns: Returns:
ApiResponse: Metadata of the latest audio memory ApiResponse: Metadata of the latest audio memory
""" """
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, group_id={group_id}") api_logger.info(f"Fetching latest audio memory: user={current_user.username}, end_user_id={end_user_id}")
try: try:
service = MemoryPerceptualService(db) service = MemoryPerceptualService(db)
audio_memory = service.get_latest_audio_memory(group_id) audio_memory = service.get_latest_audio_memory(end_user_id)
if audio_memory is None: if audio_memory is None:
api_logger.info(f"No audio memory found: group_id={group_id}") api_logger.info(f"No audio memory found: end_user_id={end_user_id}")
return success( return success(
data=None, data=None,
msg="No audio memory available" msg="No audio memory available"
@@ -145,38 +145,38 @@ def get_last_memory_listen(
) )
except Exception as e: except Exception as e:
api_logger.error(f"Failed to fetch latest audio memory: group_id={group_id}, error={str(e)}") api_logger.error(f"Failed to fetch latest audio memory: end_user_id={end_user_id}, error={str(e)}")
return fail( return fail(
code=BizCode.INTERNAL_ERROR, code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest audio memory", msg="Failed to fetch latest audio memory",
) )
@router.get("/{group_id}/last_text", response_model=ApiResponse) @router.get("/{end_user_id}/last_text", response_model=ApiResponse)
def get_last_text_memory( def get_last_text_memory(
group_id: uuid.UUID, end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
"""Retrieve the most recent TEXT-type memory for a user. """Retrieve the most recent TEXT-type memory for a user.
Args: Args:
group_id: ID of the user group end_user_id: ID of the user group
current_user: Current authenticated user current_user: Current authenticated user
db: Database session db: Database session
Returns: Returns:
ApiResponse: Metadata of the latest text memory ApiResponse: Metadata of the latest text memory
""" """
api_logger.info(f"Fetching latest text memory: user={current_user.username}, group_id={group_id}") api_logger.info(f"Fetching latest text memory: user={current_user.username}, end_user_id={end_user_id}")
try: try:
# 调用服务层获取最近的文本记忆 # 调用服务层获取最近的文本记忆
service = MemoryPerceptualService(db) service = MemoryPerceptualService(db)
text_memory = service.get_latest_text_memory(group_id) text_memory = service.get_latest_text_memory(end_user_id)
if text_memory is None: if text_memory is None:
api_logger.info(f"No text memory found: group_id={group_id}") api_logger.info(f"No text memory found: end_user_id={end_user_id}")
return success( return success(
data=None, data=None,
msg="No text memory available" msg="No text memory available"
@@ -190,16 +190,16 @@ def get_last_text_memory(
) )
except Exception as e: except Exception as e:
api_logger.error(f"Failed to fetch latest text memory: group_id={group_id}, error={str(e)}") api_logger.error(f"Failed to fetch latest text memory: end_user_id={end_user_id}, error={str(e)}")
return fail( return fail(
code=BizCode.INTERNAL_ERROR, code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest text memory", msg="Failed to fetch latest text memory",
) )
@router.get("/{group_id}/timeline", response_model=ApiResponse) @router.get("/{end_user_id}/timeline", response_model=ApiResponse)
def get_memory_time_line( def get_memory_time_line(
group_id: uuid.UUID, end_user_id: uuid.UUID,
perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"), perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"),
page: int = Query(1, ge=1, description="页码"), page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(10, ge=1, le=100, description="每页大小"), page_size: int = Query(10, ge=1, le=100, description="每页大小"),
@@ -209,7 +209,7 @@ def get_memory_time_line(
"""Retrieve a timeline of perceptual memories for a user group. """Retrieve a timeline of perceptual memories for a user group.
Args: Args:
group_id: ID of the user group end_user_id: ID of the user group
perceptual_type: Optional filter for perceptual type perceptual_type: Optional filter for perceptual type
page: Page number for pagination page: Page number for pagination
page_size: Number of items per page page_size: Number of items per page
@@ -221,7 +221,7 @@ def get_memory_time_line(
""" """
api_logger.info( api_logger.info(
f"Fetching perceptual memory timeline: user={current_user.username}, " f"Fetching perceptual memory timeline: user={current_user.username}, "
f"group_id={group_id}, type={perceptual_type}, page={page}" f"end_user_id={end_user_id}, type={perceptual_type}, page={page}"
) )
try: try:
@@ -232,7 +232,7 @@ def get_memory_time_line(
) )
service = MemoryPerceptualService(db) service = MemoryPerceptualService(db)
timeline_data = service.get_time_line(group_id, query) timeline_data = service.get_time_line(end_user_id, query)
api_logger.info( api_logger.info(
f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, " f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, "
@@ -246,7 +246,7 @@ def get_memory_time_line(
except Exception as e: except Exception as e:
api_logger.error( api_logger.error(
f"Failed to fetch perceptual memory timeline: group_id={group_id}, " f"Failed to fetch perceptual memory timeline: end_user_id={end_user_id}, "
f"error={str(e)}" f"error={str(e)}"
) )
return fail( return fail(

View File

@@ -1,6 +1,7 @@
import asyncio import asyncio
import time import time
import uuid import uuid
from uuid import UUID
from app.core.logging_config import get_api_logger from app.core.logging_config import get_api_logger
from app.core.memory.storage_services.reflection_engine.self_reflexion import ( from app.core.memory.storage_services.reflection_engine.self_reflexion import (
@@ -11,7 +12,7 @@ from app.core.response_utils import success
from app.db import get_db from app.db import get_db
from app.dependencies import get_current_user from app.dependencies import get_current_user
from app.models.user_model import User from app.models.user_model import User
from app.repositories.data_config_repository import DataConfigRepository from app.repositories.memory_config_repository import MemoryConfigRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_reflection_schemas import Memory_Reflection from app.schemas.memory_reflection_schemas import Memory_Reflection
from app.services.memory_reflection_service import ( from app.services.memory_reflection_service import (
@@ -50,7 +51,7 @@ async def save_reflection_config(
api_logger.info(f"用户 {current_user.username} 保存反思配置config_id: {config_id}") api_logger.info(f"用户 {current_user.username} 保存反思配置config_id: {config_id}")
data_config = DataConfigRepository.update_reflection_config( memory_config = MemoryConfigRepository.update_reflection_config(
db, db,
config_id=config_id, config_id=config_id,
enable_self_reflexion=request.reflection_enabled, enable_self_reflexion=request.reflection_enabled,
@@ -63,17 +64,17 @@ async def save_reflection_config(
) )
db.commit() db.commit()
db.refresh(data_config) db.refresh(memory_config)
reflection_result={ reflection_result={
"config_id": data_config.config_id, "config_id": memory_config.config_id,
"enable_self_reflexion": data_config.enable_self_reflexion, "enable_self_reflexion": memory_config.enable_self_reflexion,
"iteration_period": data_config.iteration_period, "iteration_period": memory_config.iteration_period,
"reflexion_range": data_config.reflexion_range, "reflexion_range": memory_config.reflexion_range,
"baseline": data_config.baseline, "baseline": memory_config.baseline,
"reflection_model_id": data_config.reflection_model_id, "reflection_model_id": memory_config.reflection_model_id,
"memory_verify": data_config.memory_verify, "memory_verify": memory_config.memory_verify,
"quality_assessment": data_config.quality_assessment} "quality_assessment": memory_config.quality_assessment}
return success(data=reflection_result, msg="反思配置成功") return success(data=reflection_result, msg="反思配置成功")
@@ -111,14 +112,14 @@ async def start_workspace_reflection(
reflection_results = [] reflection_results = []
for data in result['apps_detailed_info']: for data in result['apps_detailed_info']:
if data['data_configs'] == []: if data['memory_configs'] == []:
continue continue
releases = data['releases'] releases = data['releases']
data_configs = data['data_configs'] memory_configs = data['memory_configs']
end_users = data['end_users'] end_users = data['end_users']
for base, config, user in zip(releases, data_configs, end_users): for base, config, user in zip(releases, memory_configs, end_users):
# 安全地转换为整数处理空字符串和None的情况 # 安全地转换为整数处理空字符串和None的情况
print(base['config']) print(base['config'])
try: try:
@@ -156,14 +157,14 @@ async def start_workspace_reflection(
@router.get("/reflection/configs") @router.get("/reflection/configs")
async def start_reflection_configs( async def start_reflection_configs(
config_id: int, config_id: uuid.UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
"""通过config_id查询data_config表中的反思配置信息""" """通过config_id查询memory_config表中的反思配置信息"""
try: try:
api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}") api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}")
result = DataConfigRepository.query_reflection_config_by_id(db, config_id) result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
# 构建返回数据 # 构建返回数据
reflection_config = { reflection_config = {
"config_id": result.config_id, "config_id": result.config_id,
@@ -191,7 +192,7 @@ async def start_reflection_configs(
@router.get("/reflection/run") @router.get("/reflection/run")
async def reflection_run( async def reflection_run(
config_id: int, config_id: UUID,
language_type: str = Header(default="zh", alias="X-Language-Type"), language_type: str = Header(default="zh", alias="X-Language-Type"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
@@ -200,8 +201,8 @@ async def reflection_run(
api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}") api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}")
# 使用DataConfigRepository查询反思配置 # 使用MemoryConfigRepository查询反思配置
result = DataConfigRepository.query_reflection_config_by_id(db, config_id) result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
if not result: if not result:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,

View File

@@ -1,5 +1,6 @@
import os import os
from typing import Optional from typing import Optional
from uuid import UUID
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger from app.core.logging_config import get_api_logger
@@ -160,7 +161,7 @@ def create_config(
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称) @router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
def delete_config( def delete_config(
config_id: str, config_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
@@ -232,7 +233,7 @@ def update_config_extracted(
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除 @router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
def read_config_extracted( def read_config_extracted(
config_id: str, config_id: UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
@@ -420,15 +421,95 @@ async def get_hot_memory_tags_api(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Hot memory tags requested for current_user: {current_user.id}") """
获取热门记忆标签带Redis缓存
缓存策略:
- 缓存键workspace_id + limit
- 过期时间5分钟300秒
- 缓存命中:~50ms
- 缓存未命中:~600-800ms取决于LLM速度
"""
workspace_id = current_user.current_workspace_id
# 构建缓存键
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}")
try: try:
# 尝试从Redis缓存获取
from app.aioRedis import aio_redis_get, aio_redis_set
import json
cached_result = await aio_redis_get(cache_key)
if cached_result:
api_logger.info(f"Cache hit for key: {cache_key}")
try:
data = json.loads(cached_result)
return success(data=data, msg="查询成功(缓存)")
except json.JSONDecodeError:
api_logger.warning(f"Failed to parse cached data, will refresh")
# 缓存未命中,执行查询
api_logger.info(f"Cache miss for key: {cache_key}, executing query")
result = await analytics_hot_memory_tags(db, current_user, limit) result = await analytics_hot_memory_tags(db, current_user, limit)
# 写入缓存过期时间5分钟
# 注意result是列表需要转换为JSON字符串
try:
cache_data = json.dumps(result, ensure_ascii=False)
await aio_redis_set(cache_key, cache_data, expire=300)
api_logger.info(f"Cached result for key: {cache_key}")
except Exception as cache_error:
# 缓存写入失败不影响主流程
api_logger.warning(f"Failed to cache result: {str(cache_error)}")
return success(data=result, msg="查询成功") return success(data=result, msg="查询成功")
except Exception as e: except Exception as e:
api_logger.error(f"Hot memory tags failed: {str(e)}") api_logger.error(f"Hot memory tags failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e))
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
async def clear_hot_memory_tags_cache(
current_user: User = Depends(get_current_user),
) -> dict:
"""
清除热门标签缓存
用于:
- 手动刷新数据
- 调试和测试
- 数据更新后立即生效
"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}")
try:
from app.aioRedis import aio_redis_delete
# 清除所有limit的缓存常见的limit值
cleared_count = 0
for limit in [5, 10, 15, 20, 30, 50]:
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
result = await aio_redis_delete(cache_key)
if result:
cleared_count += 1
api_logger.info(f"Cleared cache for key: {cache_key}")
return success(
data={"cleared_count": cleared_count},
msg=f"成功清除 {cleared_count} 个缓存"
)
except Exception as e:
api_logger.error(f"Clear cache failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e))
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse) @router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
async def get_recent_activity_stats_api( async def get_recent_activity_stats_api(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),

View File

@@ -20,18 +20,18 @@ router = APIRouter(
) )
@router.get("/{group_id}/count", response_model=ApiResponse) @router.get("/{end_user_id}/count", response_model=ApiResponse)
def get_memory_count( def get_memory_count(
group_id: uuid.UUID, end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
pass pass
@router.get("/{group_id}/conversations", response_model=ApiResponse) @router.get("/{end_user_id}/conversations", response_model=ApiResponse)
def get_conversations( def get_conversations(
group_id: uuid.UUID, end_user_id: uuid.UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
@@ -39,7 +39,7 @@ def get_conversations(
Retrieve all conversations for the current user in a specific group. Retrieve all conversations for the current user in a specific group.
Args: Args:
group_id (UUID): The group identifier. end_user_id (UUID): The group identifier.
current_user (User, optional): The authenticated user. current_user (User, optional): The authenticated user.
db (Session, optional): SQLAlchemy session. db (Session, optional): SQLAlchemy session.
@@ -53,7 +53,7 @@ def get_conversations(
""" """
conversation_service = ConversationService(db) conversation_service = ConversationService(db)
conversations = conversation_service.get_user_conversations( conversations = conversation_service.get_user_conversations(
group_id end_user_id
) )
return success(data=[ return success(data=[
{ {
@@ -63,7 +63,7 @@ def get_conversations(
], msg="get conversations success") ], msg="get conversations success")
@router.get("/{group_id}/messages", response_model=ApiResponse) @router.get("/{end_user_id}/messages", response_model=ApiResponse)
def get_messages( def get_messages(
conversation_id: uuid.UUID, conversation_id: uuid.UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
@@ -100,7 +100,7 @@ def get_messages(
return success(data=messages, msg="get conversation history success") return success(data=messages, msg="get conversation history success")
@router.get("/{group_id}/detail", response_model=ApiResponse) @router.get("/{end_user_id}/detail", response_model=ApiResponse)
async def get_conversation_detail( async def get_conversation_detail(
conversation_id: uuid.UUID, conversation_id: uuid.UUID,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),

View File

@@ -317,9 +317,12 @@ async def chat(
appid = share.app_id appid = share.app_id
"""获取存储类型和工作空间的ID""" """获取存储类型和工作空间的ID"""
# 直接通过 SQLAlchemy 查询 app # 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
from app.models.app_model import App from app.models.app_model import App
app = db.query(App).filter(App.id == appid).first() app = db.query(App).filter(
App.id == appid,
App.is_active.is_(True)
).first()
if not app: if not app:
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND) raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)

View File

@@ -39,7 +39,7 @@ async def write_memory_api_service(
Stores memory content for the specified end user using the Memory API Service. Stores memory content for the specified end user using the Memory API Service.
""" """
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}") logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
memory_api_service = MemoryAPIService(db) memory_api_service = MemoryAPIService(db)

View File

@@ -135,27 +135,27 @@ async def generate_cache_api(
api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
group_id = request.end_user_id end_user_id = request.end_user_id
api_logger.info( api_logger.info(
f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, " f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, "
f"end_user_id={group_id if group_id else '全部用户'}" f"end_user_id={end_user_id if end_user_id else '全部用户'}"
) )
try: try:
if group_id: if end_user_id:
# 为单个用户生成 # 为单个用户生成
api_logger.info(f"开始为单个用户生成缓存: end_user_id={group_id}") api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
# 生成记忆洞察 # 生成记忆洞察
insight_result = await user_memory_service.generate_and_cache_insight(db, group_id, workspace_id) insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id)
# 生成用户摘要 # 生成用户摘要
summary_result = await user_memory_service.generate_and_cache_summary(db, group_id, workspace_id) summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id)
# 构建响应 # 构建响应
result = { result = {
"end_user_id": group_id, "end_user_id": end_user_id,
"insight_success": insight_result["success"], "insight_success": insight_result["success"],
"summary_success": summary_result["success"], "summary_success": summary_result["success"],
"errors": [] "errors": []
@@ -175,9 +175,9 @@ async def generate_cache_api(
# 记录结果 # 记录结果
if result["insight_success"] and result["summary_success"]: if result["insight_success"] and result["summary_success"]:
api_logger.info(f"成功为用户 {group_id} 生成缓存") api_logger.info(f"成功为用户 {end_user_id} 生成缓存")
else: else:
api_logger.warning(f"用户 {group_id} 的缓存生成部分失败: {result['errors']}") api_logger.warning(f"用户 {end_user_id} 的缓存生成部分失败: {result['errors']}")
return success(data=result, msg="生成完成") return success(data=result, msg="生成完成")

View File

@@ -54,7 +54,7 @@ async def create_workflow_config(
app = db.query(App).filter( app = db.query(App).filter(
App.id == app_id, App.id == app_id,
App.workspace_id == current_user.current_workspace_id, App.workspace_id == current_user.current_workspace_id,
App.is_active == True App.is_active.is_(True)
).first() ).first()
if not app: if not app:
@@ -214,7 +214,7 @@ async def delete_workflow_config(
app = db.query(App).filter( app = db.query(App).filter(
App.id == app_id, App.id == app_id,
App.workspace_id == current_user.current_workspace_id, App.workspace_id == current_user.current_workspace_id,
App.is_active == True App.is_active.is_(True)
).first() ).first()
if not app: if not app:
@@ -259,7 +259,7 @@ async def validate_workflow_config(
app = db.query(App).filter( app = db.query(App).filter(
App.id == app_id, App.id == app_id,
App.workspace_id == current_user.current_workspace_id, App.workspace_id == current_user.current_workspace_id,
App.is_active == True App.is_active.is_(True)
).first() ).first()
if not app: if not app:
@@ -329,7 +329,7 @@ async def get_workflow_executions(
app = db.query(App).filter( app = db.query(App).filter(
App.id == app_id, App.id == app_id,
App.workspace_id == current_user.current_workspace_id, App.workspace_id == current_user.current_workspace_id,
App.is_active == True App.is_active.is_(True)
).first() ).first()
if not app: if not app:
@@ -389,7 +389,7 @@ async def get_workflow_execution(
app = db.query(App).filter( app = db.query(App).filter(
App.id == execution.app_id, App.id == execution.app_id,
App.workspace_id == current_user.current_workspace_id, App.workspace_id == current_user.current_workspace_id,
App.is_active == True App.is_active.is_(True)
).first() ).first()
if not app: if not app:
@@ -440,7 +440,7 @@ async def run_workflow(
app = db.query(App).filter( app = db.query(App).filter(
App.id == app_id, App.id == app_id,
App.workspace_id == current_user.current_workspace_id, App.workspace_id == current_user.current_workspace_id,
App.is_active == True App.is_active.is_(True)
).first() ).first()
if not app: if not app:
@@ -578,7 +578,7 @@ async def cancel_workflow_execution(
app = db.query(App).filter( app = db.query(App).filter(
App.id == execution.app_id, App.id == execution.app_id,
App.workspace_id == current_user.current_workspace_id, App.workspace_id == current_user.current_workspace_id,
App.is_active == True App.is_active.is_(True)
).first() ).first()
if not app: if not app:

View File

@@ -155,7 +155,7 @@ class LangChainAgent:
# userid=end_user_end, # userid=end_user_end,
# messages=messages, # messages=messages,
# apply_id=end_user_end, # apply_id=end_user_end,
# group_id=end_user_end, # end_user_id=end_user_end,
# aimessages=aimessages # aimessages=aimessages
# ) # )
# store.delete_duplicate_sessions() # store.delete_duplicate_sessions()
@@ -228,7 +228,7 @@ class LangChainAgent:
# 6. 每个 Chunk 保存到 Neo4j包含 speaker 字段 # 6. 每个 Chunk 保存到 Neo4j包含 speaker 字段
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay( write_id = write_message_task.delay(
actual_end_user_id, # group_id: 用户ID actual_end_user_id, # end_user_id: 用户ID
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
actual_config_id, # config_id: 配置ID actual_config_id, # config_id: 配置ID
storage_type, # storage_type: "neo4j" storage_type, # storage_type: "neo4j"

View File

@@ -184,7 +184,7 @@ class Settings:
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true" ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
# official environment system version # official environment system version
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0") SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1")
# workflow config # workflow config
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600)) WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))

View File

@@ -14,7 +14,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
db_session = next(get_db()) db_session = next(get_db())
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
@@ -35,10 +35,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
"""问题分解节点""" """问题分解节点"""
# 从状态中获取数据 # 从状态中获取数据
content = state.get('data', '') content = state.get('data', '')
group_id = state.get('group_id', '') end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
history = await SessionService(store).get_history(group_id, group_id, group_id) history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
# 生成 JSON schema 以指导 LLM 输出正确格式 # 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema() json_schema = ProblemExtensionResponse.model_json_schema()
@@ -140,7 +140,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
start = time.time() start = time.time()
content = state.get('data', '') content = state.get('data', '')
data = state.get('spit_data', '')['context'] data = state.get('spit_data', '')['context']
group_id = state.get('group_id', '') end_user_id = state.get('end_user_id', '')
storage_type = state.get('storage_type', '') storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '') user_rag_memory_id = state.get('user_rag_memory_id', '')
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
@@ -156,7 +156,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
databasets = {} databasets = {}
data = [] data = []
history = await SessionService(store).get_history(group_id, group_id, group_id) history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
# 生成 JSON schema 以指导 LLM 输出正确格式 # 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema() json_schema = ProblemExtensionResponse.model_json_schema()

View File

@@ -52,9 +52,9 @@ async def rag_config(state):
return kb_config return kb_config
async def rag_knowledge(state,question): async def rag_knowledge(state,question):
kb_config = await rag_config(state) kb_config = await rag_config(state)
group_id = state.get('group_id', '') end_user_id = state.get('end_user_id', '')
user_rag_memory_id=state.get("user_rag_memory_id",'') user_rag_memory_id=state.get("user_rag_memory_id",'')
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(group_id)]) retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
try: try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
clean_content = '\n\n'.join(retrieval_knowledge) clean_content = '\n\n'.join(retrieval_knowledge)
@@ -159,7 +159,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
problem_extension=state.get('problem_extension', '')['context'] problem_extension=state.get('problem_extension', '')['context']
storage_type=state.get('storage_type', '') storage_type=state.get('storage_type', '')
user_rag_memory_id=state.get('user_rag_memory_id', '') user_rag_memory_id=state.get('user_rag_memory_id', '')
group_id=state.get('group_id', '') end_user_id=state.get('end_user_id', '')
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
original=state.get('data', '') original=state.get('data', '')
problem_list=[] problem_list=[]
@@ -172,7 +172,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
try: try:
# Prepare search parameters based on storage type # Prepare search parameters based on storage type
search_params = { search_params = {
"group_id": group_id, "end_user_id": end_user_id,
"question": question, "question": question,
"return_raw_results": True "return_raw_results": True
} }
@@ -263,13 +263,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
async def retrieve(state: ReadState) -> ReadState: async def retrieve(state: ReadState) -> ReadState:
# 从state中获取group_id # 从state中获取end_user_id
import time import time
start=time.time() start=time.time()
problem_extension = state.get('problem_extension', '')['context'] problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '') storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '') user_rag_memory_id = state.get('user_rag_memory_id', '')
group_id = state.get('group_id', '') end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
original = state.get('data', '') original = state.get('data', '')
problem_list = [] problem_list = []
@@ -295,13 +295,13 @@ async def retrieve(state: ReadState) -> ReadState:
temperature=0.2, temperature=0.2,
) )
time_retrieval_tool = create_time_retrieval_tool(group_id) time_retrieval_tool = create_time_retrieval_tool(end_user_id)
search_params = { "group_id": group_id, "return_raw_results": True } search_params = { "end_user_id": end_user_id, "return_raw_results": True }
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params) hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
agent = create_agent( agent = create_agent(
llm, llm,
tools=[time_retrieval_tool,hybrid_retrieval], tools=[time_retrieval_tool,hybrid_retrieval],
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的group_id是: {group_id}" system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
) )
# 创建异步任务处理单个问题 # 创建异步任务处理单个问题

View File

@@ -19,7 +19,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.utils.template_tools import TemplateService
from app.db import get_db from app.db import get_db
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
db_session = next(get_db()) db_session = next(get_db())
@@ -34,8 +34,8 @@ class SummaryNodeService(LLMServiceMixin):
summary_service = SummaryNodeService() summary_service = SummaryNodeService()
async def summary_history(state: ReadState) -> ReadState: async def summary_history(state: ReadState) -> ReadState:
group_id = state.get("group_id", '') end_user_id = state.get("end_user_id", '')
history = await SessionService(store).get_history(group_id, group_id, group_id) history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
return history return history
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str: async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
@@ -122,12 +122,12 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
async def summary_redis_save(state: ReadState,aimessages) -> ReadState: async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
data = state.get("data", '') data = state.get("data", '')
group_id = state.get("group_id", '') end_user_id = state.get("end_user_id", '')
await SessionService(store).save_session( await SessionService(store).save_session(
user_id=group_id, user_id=end_user_id,
query=data, query=data,
apply_id=group_id, apply_id=end_user_id,
group_id=group_id, end_user_id=end_user_id,
ai_response=aimessages ai_response=aimessages
) )
await SessionService(store).cleanup_duplicates() await SessionService(store).cleanup_duplicates()
@@ -175,11 +175,11 @@ async def Input_Summary(state: ReadState) -> ReadState:
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
user_rag_memory_id=state.get("user_rag_memory_id",'') user_rag_memory_id=state.get("user_rag_memory_id",'')
data=state.get("data", '') data=state.get("data", '')
group_id=state.get("group_id", '') end_user_id=state.get("end_user_id", '')
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
history = await summary_history( state) history = await summary_history( state)
search_params = { search_params = {
"group_id": group_id, "end_user_id": end_user_id,
"question": data, "question": data,
"return_raw_results": True, "return_raw_results": True,
"include": ["summaries"] # Only search summary nodes for faster performance "include": ["summaries"] # Only search summary nodes for faster performance
@@ -236,7 +236,7 @@ async def Retrieve_Summary(state: ReadState)-> ReadState:
retrieve_info_str='\n'.join(retrieve_info_str) retrieve_info_str='\n'.join(retrieve_info_str)
aimessages=await summary_llm(state,history,retrieve_info_str, aimessages=await summary_llm(state,history,retrieve_info_str,
'Retrieve_Summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1") 'direct_summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1")
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages) await summary_redis_save(state, aimessages)
if aimessages == '': if aimessages == '':
@@ -276,7 +276,6 @@ async def Summary(state: ReadState)-> ReadState:
aimessages=await summary_llm(state,history,data, aimessages=await summary_llm(state,history,data,
'summary_prompt.jinja2','summary',SummaryResponse,0) 'summary_prompt.jinja2','summary',SummaryResponse,0)
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages) await summary_redis_save(state, aimessages)
if aimessages == '': if aimessages == '':
@@ -295,9 +294,26 @@ async def Summary(state: ReadState)-> ReadState:
async def Summary_fails(state: ReadState)-> ReadState: async def Summary_fails(state: ReadState)-> ReadState:
storage_type=state.get("storage_type", '') storage_type=state.get("storage_type", '')
user_rag_memory_id=state.get("user_rag_memory_id", '') user_rag_memory_id=state.get("user_rag_memory_id", '')
history = await summary_history(state)
query = state.get("data", '')
verify = state.get("verify", '')
verify_expansion_issue = verify.get("verified_data", '')
retrieve_info_str = ''
for data in verify_expansion_issue:
for key, value in data.items():
if key == 'answer_small':
for i in value:
retrieve_info_str += i + '\n'
data = {
"query": query,
"history": history,
"retrieve_info": retrieve_info_str
}
aimessages = await summary_llm(state, history, data,
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
result= { result= {
"status": "success", "status": "success",
"summary_result": "没有相关数据", "summary_result": aimessages,
"storage_type": storage_type, "storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id "user_rag_memory_id": user_rag_memory_id
} }

View File

@@ -12,7 +12,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
db_session = next(get_db()) db_session = next(get_db())
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
@@ -62,12 +62,12 @@ async def Verify(state: ReadState):
logger.info("=== Verify 节点开始执行 ===") logger.info("=== Verify 节点开始执行 ===")
try: try:
content = state.get('data', '') content = state.get('data', '')
group_id = state.get('group_id', '') end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None) memory_config = state.get('memory_config', None)
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., group_id={group_id}") logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")
history = await SessionService(store).get_history(group_id, group_id, group_id) history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
logger.info(f"Verify: 获取历史记录完成history length={len(history)}") logger.info(f"Verify: 获取历史记录完成history length={len(history)}")
retrieve = state.get("retrieve", {}) retrieve = state.get("retrieve", {})

View File

@@ -1,21 +1,22 @@
from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.utils.write_tools import write from app.core.memory.agent.utils.write_tools import write
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
async def write_node(state: WriteState) -> WriteState: async def write_node(state: WriteState) -> WriteState:
""" """
Write data to the database/file system. Write data to the database/file system.
Args: Args:
state: WriteState containing messages, group_id, and memory_config state: WriteState containing messages, end_user_id, and memory_config
Returns: Returns:
dict: Contains 'write_result' with status and data fields dict: Contains 'write_result' with status and data fields
""" """
messages = state.get('messages', []) messages = state.get('messages', [])
group_id = state.get('group_id', '') end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', '') memory_config = state.get('memory_config', '')
# Convert LangChain messages to structured format expected by write() # Convert LangChain messages to structured format expected by write()
@@ -32,9 +33,7 @@ async def write_node(state: WriteState) -> WriteState:
try: try:
result = await write( result = await write(
messages=structured_messages, messages=structured_messages,
user_id=group_id, end_user_id=end_user_id,
apply_id=group_id,
group_id=group_id,
memory_config=memory_config, memory_config=memory_config,
) )
logger.info(f"Write completed successfully! Config: {memory_config.config_name}") logger.info(f"Write completed successfully! Config: {memory_config.config_name}")

View File

@@ -79,7 +79,7 @@ async def make_read_graph():
async def main(): async def main():
"""主函数 - 运行工作流""" """主函数 - 运行工作流"""
message = "昨天有什么好看的电影" message = "昨天有什么好看的电影"
group_id = '88a459f5_text09' # 组ID end_user_id = '88a459f5_text09' # 组ID
storage_type = 'neo4j' # 存储类型 storage_type = 'neo4j' # 存储类型
search_switch = '1' # 搜索开关 search_switch = '1' # 搜索开关
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
@@ -95,9 +95,9 @@ async def main():
start=time.time() start=time.time()
try: try:
async with make_read_graph() as graph: async with make_read_graph() as graph:
config = {"configurable": {"thread_id": group_id}} config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段 # 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"group_id":group_id initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config} ,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
# 获取节点更新信息 # 获取节点更新信息
_intermediate_outputs = [] _intermediate_outputs = []

View File

@@ -48,11 +48,11 @@ def extract_tool_message_content(response):
class TimeRetrievalInput(BaseModel): class TimeRetrievalInput(BaseModel):
"""时间检索工具的输入模式""" """时间检索工具的输入模式"""
context: str = Field(description="用户输入的查询内容") context: str = Field(description="用户输入的查询内容")
group_id: str = Field(default="88a459f5_text09", description="组ID用于过滤搜索结果") end_user_id: str = Field(default="88a459f5_text09", description="组ID用于过滤搜索结果")
def create_time_retrieval_tool(group_id: str): def create_time_retrieval_tool(end_user_id: str):
""" """
创建一个带有特定group_id的TimeRetrieval工具同步版本用于按时间范围搜索语句(Statements) 创建一个带有特定end_user_id的TimeRetrieval工具同步版本用于按时间范围搜索语句(Statements)
""" """
def clean_temporal_result_fields(data): def clean_temporal_result_fields(data):
@@ -93,26 +93,26 @@ def create_time_retrieval_tool(group_id: str):
return data return data
@tool @tool
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, group_id_param: str = None, clean_output: bool = True) -> str: def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
""" """
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段 优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
显式接收参数: 显式接收参数:
- context: 查询上下文内容 - context: 查询上下文内容
- start_date: 开始时间可选格式YYYY-MM-DD - start_date: 开始时间可选格式YYYY-MM-DD
- end_date: 结束时间可选格式YYYY-MM-DD - end_date: 结束时间可选格式YYYY-MM-DD
- group_id_param: 组ID可选用于覆盖默认组ID - end_user_id_param: 组ID可选用于覆盖默认组ID
- clean_output: 是否清理输出中的元数据字段 - clean_output: 是否清理输出中的元数据字段
-end_date 需要根据用户的描述获取结束的时间输出格式用strftime("%Y-%m-%d") -end_date 需要根据用户的描述获取结束的时间输出格式用strftime("%Y-%m-%d")
""" """
async def _async_search(): async def _async_search():
# 使用传入的参数或默认值 # 使用传入的参数或默认值
actual_group_id = group_id_param or group_id actual_end_user_id = end_user_id_param or end_user_id
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d") actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d") actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
# 基本时间搜索 # 基本时间搜索
results = await search_by_temporal( results = await search_by_temporal(
group_id=actual_group_id, end_user_id=actual_end_user_id,
start_date=actual_start_date, start_date=actual_start_date,
end_date=actual_end_date, end_date=actual_end_date,
limit=10 limit=10
@@ -147,7 +147,7 @@ def create_time_retrieval_tool(group_id: str):
# 关键词时间搜索 # 关键词时间搜索
results = await search_by_keyword_temporal( results = await search_by_keyword_temporal(
query_text=context, query_text=context,
group_id=group_id, end_user_id=end_user_id,
start_date=actual_start_date, start_date=actual_start_date,
end_date=actual_end_date, end_date=actual_end_date,
limit=15 limit=15
@@ -172,7 +172,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
Args: Args:
memory_config: 内存配置对象 memory_config: 内存配置对象
**search_params: 搜索参数,包含group_id, limit, include等 **search_params: 搜索参数,包含end_user_id, limit, include等
""" """
def clean_result_fields(data): def clean_result_fields(data):
@@ -211,7 +211,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
context: str, context: str,
search_type: str = "hybrid", search_type: str = "hybrid",
limit: int = 10, limit: int = 10,
group_id: str = None, end_user_id: str = None,
rerank_alpha: float = 0.6, rerank_alpha: float = 0.6,
use_forgetting_rerank: bool = False, use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False, use_llm_rerank: bool = False,
@@ -224,7 +224,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
context: 查询内容 context: 查询内容
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid') search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
limit: 结果数量限制 limit: 结果数量限制
group_id: 组ID用于过滤搜索结果 end_user_id: 组ID用于过滤搜索结果
rerank_alpha: 重排序权重参数 rerank_alpha: 重排序权重参数
use_forgetting_rerank: 是否使用遗忘重排序 use_forgetting_rerank: 是否使用遗忘重排序
use_llm_rerank: 是否使用LLM重排序 use_llm_rerank: 是否使用LLM重排序
@@ -238,7 +238,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
final_params = { final_params = {
"query_text": context, "query_text": context,
"search_type": search_type, "search_type": search_type,
"group_id": group_id or search_params.get("group_id"), "end_user_id": end_user_id or search_params.get("end_user_id"),
"limit": limit or search_params.get("limit", 10), "limit": limit or search_params.get("limit", 10),
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]), "include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
"output_path": None, # 不保存到文件 "output_path": None, # 不保存到文件
@@ -291,7 +291,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
context: str, context: str,
search_type: str = "hybrid", search_type: str = "hybrid",
limit: int = 10, limit: int = 10,
group_id: str = None, end_user_id: str = None,
clean_output: bool = True clean_output: bool = True
) -> str: ) -> str:
""" """
@@ -301,7 +301,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
context: 查询内容 context: 查询内容
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid') search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
limit: 结果数量限制 limit: 结果数量限制
group_id: 组ID用于过滤搜索结果 end_user_id: 组ID用于过滤搜索结果
clean_output: 是否清理输出中的元数据字段 clean_output: 是否清理输出中的元数据字段
""" """
async def _async_search(): async def _async_search():
@@ -311,7 +311,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
"context": context, "context": context,
"search_type": search_type, "search_type": search_type,
"limit": limit, "limit": limit,
"group_id": group_id, "end_user_id": end_user_id,
"clean_output": clean_output "clean_output": clean_output
}) })

View File

@@ -14,6 +14,7 @@ from app.db import get_db
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning) warnings.filterwarnings("ignore", category=RuntimeWarning)
@@ -26,9 +27,21 @@ async def make_write_graph():
""" """
Create a write graph workflow for memory operations. Create a write graph workflow for memory operations.
The workflow directly processes messages from the initial state Args:
and saves them to Neo4j storage. user_id: User identifier
tools: MCP tools loaded from session
apply_id: Application identifier
end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration
""" """
# workflow = StateGraph(WriteState)
# workflow.add_node("content_input", content_input_write)
# workflow.add_node("save_neo4j", write_node)
# workflow.add_edge(START, "content_input")
# workflow.add_edge("content_input", "save_neo4j")
# workflow.add_edge("save_neo4j", END)
#
# graph = workflow.compile()
workflow = StateGraph(WriteState) workflow = StateGraph(WriteState)
workflow.add_node("save_neo4j", write_node) workflow.add_node("save_neo4j", write_node)
workflow.add_edge(START, "save_neo4j") workflow.add_edge(START, "save_neo4j")
@@ -42,7 +55,7 @@ async def make_write_graph():
async def main(): async def main():
"""主函数 - 运行工作流""" """主函数 - 运行工作流"""
message = "今天周一" message = "今天周一"
group_id = 'new_2025test1103' # 组ID end_user_id = 'new_2025test1103' # 组ID
# 获取数据库会话 # 获取数据库会话
@@ -54,9 +67,9 @@ async def main():
) )
try: try:
async with make_write_graph() as graph: async with make_write_graph() as graph:
config = {"configurable": {"thread_id": group_id}} config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段 # 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, "memory_config": memory_config} initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config}
# 获取节点更新信息 # 获取节点更新信息
async for update_event in graph.astream( async for update_event in graph.astream(

View File

@@ -24,7 +24,7 @@ class ParameterBuilder:
tool_call_id: str, tool_call_id: str,
search_switch: str, search_switch: str,
apply_id: str, apply_id: str,
group_id: str, end_user_id: str,
storage_type: Optional[str] = None, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None user_rag_memory_id: Optional[str] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
@@ -44,7 +44,7 @@ class ParameterBuilder:
tool_call_id: Extracted tool call identifier tool_call_id: Extracted tool call identifier
search_switch: Search routing parameter search_switch: Search routing parameter
apply_id: Application identifier apply_id: Application identifier
group_id: Group identifier end_user_id: Group identifier
storage_type: Storage type for the workspace (optional) storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional) user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
@@ -55,7 +55,7 @@ class ParameterBuilder:
base_args = { base_args = {
"usermessages": tool_call_id, "usermessages": tool_call_id,
"apply_id": apply_id, "apply_id": apply_id,
"group_id": group_id "end_user_id": end_user_id
} }
# Always add storage_type and user_rag_memory_id (with defaults if None) # Always add storage_type and user_rag_memory_id (with defaults if None)

View File

@@ -91,7 +91,7 @@ class SearchService:
async def execute_hybrid_search( async def execute_hybrid_search(
self, self,
group_id: str, end_user_id: str,
question: str, question: str,
limit: int = 5, limit: int = 5,
search_type: str = "hybrid", search_type: str = "hybrid",
@@ -105,7 +105,7 @@ class SearchService:
Execute hybrid search and return clean content. Execute hybrid search and return clean content.
Args: Args:
group_id: Group identifier for filtering results end_user_id: Group identifier for filtering results
question: Search query text question: Search query text
limit: Maximum number of results to return (default: 5) limit: Maximum number of results to return (default: 5)
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid") search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
@@ -130,7 +130,7 @@ class SearchService:
answer = await run_hybrid_search( answer = await run_hybrid_search(
query_text=cleaned_query, query_text=cleaned_query,
search_type=search_type, search_type=search_type,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include, include=include,
output_path=output_path, output_path=output_path,
@@ -186,7 +186,7 @@ class SearchService:
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Search failed for query '{question}' in group '{group_id}': {e}", f"Search failed for query '{question}' in group '{end_user_id}': {e}",
exc_info=True exc_info=True
) )
# Return empty results on failure # Return empty results on failure

View File

@@ -59,7 +59,7 @@ class SessionService:
self, self,
user_id: str, user_id: str,
apply_id: str, apply_id: str,
group_id: str end_user_id: str
) -> List[dict]: ) -> List[dict]:
""" """
Retrieve conversation history from Redis. Retrieve conversation history from Redis.
@@ -67,20 +67,20 @@ class SessionService:
Args: Args:
user_id: User identifier user_id: User identifier
apply_id: Application identifier apply_id: Application identifier
group_id: Group identifier end_user_id: Group identifier
Returns: Returns:
List of conversation history items with Query and Answer keys List of conversation history items with Query and Answer keys
Returns empty list if no history found or on error Returns empty list if no history found or on error
""" """
try: try:
history = self.store.find_user_apply_group(user_id, apply_id, group_id) history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
# Validate history structure # Validate history structure
if not isinstance(history, list): if not isinstance(history, list):
logger.warning( logger.warning(
f"Invalid history format for user {user_id}, " f"Invalid history format for user {user_id}, "
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}" f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
) )
return [] return []
@@ -89,7 +89,7 @@ class SessionService:
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Failed to retrieve history for user {user_id}, " f"Failed to retrieve history for user {user_id}, "
f"apply {apply_id}, group {group_id}: {e}", f"apply {apply_id}, group {end_user_id}: {e}",
exc_info=True exc_info=True
) )
# Return empty list on error to allow execution to continue # Return empty list on error to allow execution to continue
@@ -100,7 +100,7 @@ class SessionService:
user_id: str, user_id: str,
query: str, query: str,
apply_id: str, apply_id: str,
group_id: str, end_user_id: str,
ai_response: str ai_response: str
) -> Optional[str]: ) -> Optional[str]:
""" """
@@ -110,7 +110,7 @@ class SessionService:
user_id: User identifier user_id: User identifier
query: User query/message query: User query/message
apply_id: Application identifier apply_id: Application identifier
group_id: Group identifier end_user_id: Group identifier
ai_response: AI response/answer ai_response: AI response/answer
Returns: Returns:
@@ -131,7 +131,7 @@ class SessionService:
userid=user_id, userid=user_id,
messages=query, messages=query,
apply_id=apply_id, apply_id=apply_id,
group_id=group_id, end_user_id=end_user_id,
aimessages=ai_response aimessages=ai_response
) )
@@ -152,7 +152,7 @@ class SessionService:
Duplicates are identified by matching: Duplicates are identified by matching:
- sessionid - sessionid
- user_id (id field) - user_id (id field)
- group_id - end_user_id
- messages - messages
- aimessages - aimessages

View File

@@ -9,9 +9,7 @@ from app.core.memory.models.message_models import DialogData, ConversationContex
async def get_chunked_dialogs( async def get_chunked_dialogs(
chunker_strategy: str = "RecursiveChunker", chunker_strategy: str = "RecursiveChunker",
group_id: str = "group_1", end_user_id: str = "group_1",
user_id: str = "user1",
apply_id: str = "applyid",
messages: list = None, messages: list = None,
ref_id: str = "wyl_20251027", ref_id: str = "wyl_20251027",
config_id: str = None config_id: str = None
@@ -20,9 +18,7 @@ async def get_chunked_dialogs(
Args: Args:
chunker_strategy: The chunking strategy to use (default: RecursiveChunker) chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
group_id: Group identifier end_user_id: Group identifier
user_id: User identifier
apply_id: Application identifier
messages: Structured message list [{"role": "user", "content": "..."}, ...] messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference identifier ref_id: Reference identifier
config_id: Configuration ID for processing config_id: Configuration ID for processing
@@ -58,9 +54,7 @@ async def get_chunked_dialogs(
dialog_data = DialogData( dialog_data = DialogData(
context=conversation_context, context=conversation_context,
ref_id=ref_id, ref_id=ref_id,
group_id=group_id, end_user_id=end_user_id,
user_id=user_id,
apply_id=apply_id,
config_id=config_id config_id=config_id
) )

View File

@@ -1,24 +1,23 @@
import os import os
from collections import defaultdict from collections import defaultdict
from pathlib import Path
from typing import Annotated, TypedDict from typing import Annotated, TypedDict
from langchain_core.messages import AnyMessage from langchain_core.messages import AnyMessage
from langgraph.graph import add_messages from langgraph.graph import add_messages
PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
class WriteState(TypedDict): class WriteState(TypedDict):
''' '''
Langgrapg Writing TypedDict Langgrapg Writing TypedDict
''' '''
messages: Annotated[list[AnyMessage], add_messages] messages: Annotated[list[AnyMessage], add_messages]
user_id:str end_user_id: str
apply_id:str
group_id:str
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}] errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
memory_config: object memory_config: object
write_result: dict write_result: dict
data:str data: str
class ReadState(TypedDict): class ReadState(TypedDict):
""" """
@@ -28,7 +27,7 @@ class ReadState(TypedDict):
messages: 消息列表,支持自动追加 messages: 消息列表,支持自动追加
loop_count: 遍历次数 loop_count: 遍历次数
search_switch: 搜索类型开关 search_switch: 搜索类型开关
group_id: 组标识 end_user_id: 组标识
config_id: 配置ID用于过滤结果 config_id: 配置ID用于过滤结果
data: 从content_input_node传递的内容数据 data: 从content_input_node传递的内容数据
spit_data: 从Split_The_Problem传递的分解结果 spit_data: 从Split_The_Problem传递的分解结果
@@ -39,7 +38,7 @@ class ReadState(TypedDict):
messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式 messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式
loop_count: int loop_count: int
search_switch: str search_switch: str
group_id: str end_user_id: str
config_id: str config_id: str
data: str # 新增字段用于传递内容 data: str # 新增字段用于传递内容
spit_data: dict # 新增字段用于传递问题分解结果 spit_data: dict # 新增字段用于传递问题分解结果

View File

@@ -0,0 +1,61 @@
# 角色
你是一个智能问答助手,基于检索信息和历史对话回答用户问题。
# 任务
根据提供的上下文信息回答用户的问题。
# 输入信息
- 历史对话:{{history}}
- 检索信息:{{retrieve_info}}
# 用户问题
{{query}}
# 回答指南
## 1. 仔细阅读检索信息
- 答案可能直接或间接地出现在检索信息中
- 如果检索信息中提到"小曼会使用Python",说明用户名是"小曼"
- 第三人称描述的偏好、行为通常指用户本人
## 2. 判断信息相关性
**情况A信息匹配问题**
- 直接回答,像自然对话一样
- 例:检索到"小曼会使用Python" → 问"我叫什么" → 答"你叫小曼"
**情况B信息部分相关**
- 先回答已知部分,再自然地询问更多信息
- 例:检索到"用户去过上海的面包店" → 问"我吃过哪家面包" → 答"我记得你去过上海的面包店,但具体是哪家我不太清楚,是哪家呢?"
**情况C信息完全不相关**
- 自然地表达不知道,但可以提及检索到的相关信息,让对话更连贯
- 使用友好的表达:
- "你好像没和我说过...,但是我知道你[检索到的相关信息]"
- "关于这个我不太清楚,不过我记得你[检索到的相关信息],能告诉我更多吗?"
- "我不记得你提到过...,但你[检索到的相关信息]"
- 即使检索信息不直接回答问题,也可以自然地融入对话中
- 避免僵硬的"信息不足,无法回答"
## 3. 回答要求
- 像人类对话一样自然流畅
- 不要提及"检索信息"、"搜索结果"、"根据资料"等技术术语
- 不要解释推理过程或引用信息来源
- 保持友好、乐于助人的语气
- 使用与问题相同的语言回答
# 关键示例
**示例1 - 直接匹配:**
- 检索信息:"小曼会使用Python..."
- 问题:"我叫什么"
- ✓ 正确:"你叫小曼"
- ✗ 错误:"你没有告诉我你的名字"
**示例2 - 间接匹配:**
- 检索信息:"用户很喜欢吃星巴克的甜品"
- 问题:"我喜欢什么"
- ✓ 正确:"你很喜欢吃星巴克的甜品"
- ✗ 错误:"信息不足"
**示例3 - 信息不匹配(推荐做法):**
- 检索信息:"用户只喝拿铁咖啡,认为美式咖啡太苦"
- 问题:"我吃过哪家面包"
- ✓ 最佳:"你好像没和我说过吃过哪家面包,但是我知道你喜欢喝拿铁,能跟我分享一下吗?"
- ✓ 可以:"你好像没和我说过吃过哪家面包,能跟我分享一下吗?"
- ✗ 错误:"用户只喝拿铁咖啡,认为美式咖啡太苦。"(答非所问)
- ✗ 错误:"信息不足,无法回答。"(太僵硬)
# 重要提醒
- 检索信息中描述用户行为/偏好时提到的名字,就是用户的名字
- 信息不匹配时,不要强行回答无关内容,但可以自然地提及检索到的信息,让对话更有温度
- 用对话式语言表达"不知道",而非机械模板
- 检索信息代表你对用户的了解,即使不直接回答问题,也能体现你对用户的记忆

View File

@@ -0,0 +1,43 @@
{# 角色定义 #}
你是专业的问题解答专家+引导学者
{# 输入数据展示 #}
{% if data %}
## 输入数据
上下文信息:
{% for item in data.history %}
- {{ item }}
{% endfor %}
检索到的所有信息:
{% for item in data.retrieve_info %}
- {{ item }}
{% endfor %}
{% endif %}
## User Query
{{ query }}
{# 问题回答标准 #}
## 问题回答核心标准
根据上下文信息(history)和检索到的所有信息(retrieve_info)准确回答用户的问题(query)。
注意,仔细阅读检索信息,答案可能直接或间接地出现在检索信息中或者历史上下文消息中,同时需要 判断信息相关性
**情况A信息匹配问题**
- 直接回答,像自然对话一样
- 例:检索到"小曼会使用Python" → 问"我叫什么" → 答"你叫小曼"
**情况B信息部分相关**
- 先回答已知部分,再自然地询问更多信息
- 例:检索到"用户去过上海的面包店" → 问"我吃过哪家面包" → 答"我记得你去过上海的面包店,但具体是哪家我不太清楚,是哪家呢?"
**情况C信息完全不相关**
- 自然地表达不知道,但可以提及检索到的相关信息,让对话更连贯
- 使用友好的表达:
- "你好像没和我说过...,但是我知道你[检索到的相关信息]"
- "关于这个我不太清楚,不过我记得你[检索到的相关信息],能告诉我更多吗?"
- "我不记得你提到过...,但你[检索到的相关信息]"
- 即使检索信息不直接回答问题,也可以自然地融入对话中
- 避免僵硬的"信息不足,无法回答"
{# 重要提醒 #}
当检索以及上下文的历史信息都无法回答的时候,可引导对方进行提问/回答,或者进行其他引导
当检索或者上下文中出现了,相似的问题,可以委婉,提醒对方,我记得刚刚提过这个问题,但是我自己不记得了,能在描述一次吗~以此为例

View File

@@ -28,7 +28,7 @@ class RedisSessionStore:
return text return text
# 修改后的 save_session 方法 # 修改后的 save_session 方法
def save_session(self, userid, messages, aimessages, apply_id, group_id): def save_session(self, userid, messages, aimessages, apply_id, end_user_id):
""" """
写入一条会话数据,返回 session_id 写入一条会话数据,返回 session_id
优化版本确保写入时间不超过1秒 优化版本确保写入时间不超过1秒
@@ -46,7 +46,7 @@ class RedisSessionStore:
"id": self.uudi, "id": self.uudi,
"sessionid": userid, "sessionid": userid,
"apply_id": apply_id, "apply_id": apply_id,
"group_id": group_id, "end_user_id": end_user_id,
"messages": messages, "messages": messages,
"aimessages": aimessages, "aimessages": aimessages,
"starttime": starttime "starttime": starttime
@@ -67,7 +67,7 @@ class RedisSessionStore:
def save_sessions_batch(self, sessions_data): def save_sessions_batch(self, sessions_data):
""" """
批量写入多条会话数据,返回 session_id 列表 批量写入多条会话数据,返回 session_id 列表
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, group_id sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id
优化版本:批量操作,大幅提升性能 优化版本:批量操作,大幅提升性能
""" """
try: try:
@@ -83,7 +83,7 @@ class RedisSessionStore:
"id": self.uudi, "id": self.uudi,
"sessionid": session.get('userid'), "sessionid": session.get('userid'),
"apply_id": session.get('apply_id'), "apply_id": session.get('apply_id'),
"group_id": session.get('group_id'), "end_user_id": session.get('end_user_id'),
"messages": session.get('messages'), "messages": session.get('messages'),
"aimessages": session.get('aimessages'), "aimessages": session.get('aimessages'),
"starttime": starttime "starttime": starttime
@@ -108,9 +108,9 @@ class RedisSessionStore:
data = self.r.hgetall(key) data = self.r.hgetall(key)
return data if data else None return data if data else None
def get_session_apply_group(self, sessionid, apply_id, group_id): def get_session_apply_group(self, sessionid, apply_id, end_user_id):
""" """
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据 根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据
""" """
result_items = [] result_items = []
@@ -124,7 +124,7 @@ class RedisSessionStore:
# 检查三个条件是否都匹配 # 检查三个条件是否都匹配
if (data.get('sessionid') == sessionid and if (data.get('sessionid') == sessionid and
data.get('apply_id') == apply_id and data.get('apply_id') == apply_id and
data.get('group_id') == group_id): data.get('end_user_id') == end_user_id):
result_items.append(data) result_items.append(data)
return result_items return result_items
@@ -172,7 +172,7 @@ class RedisSessionStore:
def delete_duplicate_sessions(self): def delete_duplicate_sessions(self):
""" """
删除重复会话数据,条件: 删除重复会话数据,条件:
"sessionid""user_id""group_id""messages""aimessages" 五个字段都相同的只保留一个,其他删除 "sessionid""user_id""end_user_id""messages""aimessages" 五个字段都相同的只保留一个,其他删除
优化版本:使用 pipeline 批量操作确保在1秒内完成 优化版本:使用 pipeline 批量操作确保在1秒内完成
""" """
import time import time
@@ -202,12 +202,12 @@ class RedisSessionStore:
# 获取五个字段的值 # 获取五个字段的值
sessionid = data.get('sessionid', '') sessionid = data.get('sessionid', '')
user_id = data.get('id', '') user_id = data.get('id', '')
group_id = data.get('group_id', '') end_user_id = data.get('end_user_id', '')
messages = data.get('messages', '') messages = data.get('messages', '')
aimessages = data.get('aimessages', '') aimessages = data.get('aimessages', '')
# 用五元组作为唯一标识 # 用五元组作为唯一标识
identifier = (sessionid, user_id, group_id, messages, aimessages) identifier = (sessionid, user_id, end_user_id, messages, aimessages)
if identifier in seen: if identifier in seen:
# 重复,标记为待删除 # 重复,标记为待删除
@@ -248,9 +248,9 @@ class RedisSessionStore:
result_items = [] result_items = []
return (result_items) return (result_items)
def find_user_apply_group(self, sessionid, apply_id, group_id): def find_user_apply_group(self, sessionid, apply_id, end_user_id):
""" """
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据返回最新的6条 根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据返回最新的6条
""" """
import time import time
start_time = time.time() start_time = time.time()
@@ -276,7 +276,7 @@ class RedisSessionStore:
# 检查是否符合三个条件 # 检查是否符合三个条件
if (data.get('apply_id') == apply_id and if (data.get('apply_id') == apply_id and
data.get('group_id') == group_id): data.get('end_user_id') == end_user_id):
# 支持模糊匹配 sessionid 或者完全匹配 # 支持模糊匹配 sessionid 或者完全匹配
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid: if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
matched_items.append({ matched_items.append({

View File

@@ -59,7 +59,7 @@ class SessionService:
self, self,
user_id: str, user_id: str,
apply_id: str, apply_id: str,
group_id: str end_user_id: str
) -> List[dict]: ) -> List[dict]:
""" """
Retrieve conversation history from Redis. Retrieve conversation history from Redis.
@@ -67,20 +67,20 @@ class SessionService:
Args: Args:
user_id: User identifier user_id: User identifier
apply_id: Application identifier apply_id: Application identifier
group_id: Group identifier end_user_id: Group identifier
Returns: Returns:
List of conversation history items with Query and Answer keys List of conversation history items with Query and Answer keys
Returns empty list if no history found or on error Returns empty list if no history found or on error
""" """
try: try:
history = self.store.find_user_apply_group(user_id, apply_id, group_id) history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
# Validate history structure # Validate history structure
if not isinstance(history, list): if not isinstance(history, list):
logger.warning( logger.warning(
f"Invalid history format for user {user_id}, " f"Invalid history format for user {user_id}, "
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}" f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
) )
return [] return []
@@ -89,7 +89,7 @@ class SessionService:
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Failed to retrieve history for user {user_id}, " f"Failed to retrieve history for user {user_id}, "
f"apply {apply_id}, group {group_id}: {e}", f"apply {apply_id}, group {end_user_id}: {e}",
exc_info=True exc_info=True
) )
# Return empty list on error to allow execution to continue # Return empty list on error to allow execution to continue
@@ -100,7 +100,7 @@ class SessionService:
user_id: str, user_id: str,
query: str, query: str,
apply_id: str, apply_id: str,
group_id: str, end_user_id: str,
ai_response: str ai_response: str
) -> Optional[str]: ) -> Optional[str]:
""" """
@@ -110,7 +110,7 @@ class SessionService:
user_id: User identifier user_id: User identifier
query: User query/message query: User query/message
apply_id: Application identifier apply_id: Application identifier
group_id: Group identifier end_user_id: Group identifier
ai_response: AI response/answer ai_response: AI response/answer
Returns: Returns:
@@ -131,7 +131,7 @@ class SessionService:
userid=user_id, userid=user_id,
messages=query, messages=query,
apply_id=apply_id, apply_id=apply_id,
group_id=group_id, end_user_id=end_user_id,
aimessages=ai_response aimessages=ai_response
) )
@@ -152,7 +152,7 @@ class SessionService:
Duplicates are identified by matching: Duplicates are identified by matching:
- sessionid - sessionid
- user_id (id field) - user_id (id field)
- group_id - end_user_id
- messages - messages
- aimessages - aimessages

View File

@@ -29,9 +29,7 @@ logger = get_agent_logger(__name__)
async def write( async def write(
user_id: str, end_user_id: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig, memory_config: MemoryConfig,
messages: list, messages: list,
ref_id: str = "wyl20251027", ref_id: str = "wyl20251027",
@@ -42,7 +40,7 @@ async def write(
Args: Args:
user_id: User identifier user_id: User identifier
apply_id: Application identifier apply_id: Application identifier
group_id: Group identifier end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration memory_config: MemoryConfig object containing all configuration
messages: Structured message list [{"role": "user", "content": "..."}, ...] messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference ID, defaults to "wyl20251027" ref_id: Reference ID, defaults to "wyl20251027"
@@ -58,7 +56,7 @@ async def write(
logger.info(f"LLM model: {memory_config.llm_model_name}") logger.info(f"LLM model: {memory_config.llm_model_name}")
logger.info(f"Embedding model: {memory_config.embedding_model_name}") logger.info(f"Embedding model: {memory_config.embedding_model_name}")
logger.info(f"Chunker strategy: {chunker_strategy}") logger.info(f"Chunker strategy: {chunker_strategy}")
logger.info(f"Group ID: {group_id}") logger.info(f"end_user_id ID: {end_user_id}")
# Construct clients from memory_config using factory pattern with db session # Construct clients from memory_config using factory pattern with db session
with get_db_context() as db: with get_db_context() as db:
@@ -83,9 +81,7 @@ async def write(
step_start = time.time() step_start = time.time()
chunked_dialogs = await get_chunked_dialogs( chunked_dialogs = await get_chunked_dialogs(
chunker_strategy=chunker_strategy, chunker_strategy=chunker_strategy,
group_id=group_id, end_user_id=end_user_id,
user_id=user_id,
apply_id=apply_id,
messages=messages, messages=messages,
ref_id=ref_id, ref_id=ref_id,
config_id=config_id, config_id=config_id,

View File

@@ -139,7 +139,8 @@ def parse_api_docs(file_path: str) -> Dict[str, Any]:
def get_default_docs_path() -> str: def get_default_docs_path() -> str:
project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) from pathlib import Path
project_root = str(Path(__file__).resolve().parents[2])
return os.path.join(project_root, "src", "analytics", "API接口.md") return os.path.join(project_root, "src", "analytics", "API接口.md")

View File

@@ -16,13 +16,13 @@ class FilteredTags(BaseModel):
"""用于接收LLM筛选后的核心标签列表的模型。""" """用于接收LLM筛选后的核心标签列表的模型。"""
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。") meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]: async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
""" """
使用LLM筛选标签列表仅保留具有代表性的核心名词。 使用LLM筛选标签列表仅保留具有代表性的核心名词。
Args: Args:
tags: 原始标签列表 tags: 原始标签列表
group_id: 用户组ID用于获取配置 end_user_id: 用户组ID用于获取配置
Returns: Returns:
筛选后的标签列表 筛选后的标签列表
@@ -37,12 +37,12 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
get_end_user_connected_config, get_end_user_connected_config,
) )
connected_config = get_end_user_connected_config(group_id, db) connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id") config_id = connected_config.get("memory_config_id")
if not config_id: if not config_id:
raise ValueError( raise ValueError(
f"No memory_config_id found for group_id: {group_id}. " f"No memory_config_id found for end_user_id: {end_user_id}. "
"Please ensure the user has a valid memory configuration." "Please ensure the user has a valid memory configuration."
) )
@@ -87,7 +87,7 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
async def get_raw_tags_from_db( async def get_raw_tags_from_db(
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: str, end_user_id: str,
limit: int, limit: int,
by_user: bool = False by_user: bool = False
) -> List[Tuple[str, int]]: ) -> List[Tuple[str, int]]:
@@ -99,9 +99,9 @@ async def get_raw_tags_from_db(
Args: Args:
connector: Neo4j连接器实例 connector: Neo4j连接器实例
group_id: 如果by_user=False则为group_id如果by_user=True则为user_id end_user_id: 如果by_user=False则为end_user_id如果by_user=True则为user_id
limit: 返回的标签数量限制 limit: 返回的标签数量限制
by_user: 是否按user_id查询默认Falsegroup_id查询 by_user: 是否按user_id查询默认Falseend_user_id查询
Returns: Returns:
List[Tuple[str, int]]: 标签名称和频率的元组列表 List[Tuple[str, int]]: 标签名称和频率的元组列表
@@ -119,7 +119,7 @@ async def get_raw_tags_from_db(
else: else:
query = ( query = (
"MATCH (e:ExtractedEntity) " "MATCH (e:ExtractedEntity) "
"WHERE e.group_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude " "WHERE e.end_user_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
"RETURN e.name AS name, count(e) AS frequency " "RETURN e.name AS name, count(e) AS frequency "
"ORDER BY frequency DESC " "ORDER BY frequency DESC "
"LIMIT $limit" "LIMIT $limit"
@@ -128,44 +128,44 @@ async def get_raw_tags_from_db(
# 使用项目的Neo4jConnector执行查询 # 使用项目的Neo4jConnector执行查询
results = await connector.execute_query( results = await connector.execute_query(
query, query,
id=group_id, id=end_user_id,
limit=limit, limit=limit,
names_to_exclude=names_to_exclude names_to_exclude=names_to_exclude
) )
return [(record["name"], record["frequency"]) for record in results] return [(record["name"], record["frequency"]) for record in results]
async def get_hot_memory_tags(group_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]: async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
""" """
获取原始标签然后使用LLM进行筛选返回最终的热门标签列表。 获取原始标签然后使用LLM进行筛选返回最终的热门标签列表。
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。 查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
Args: Args:
group_id: 必需参数。如果by_user=False则为group_id如果by_user=True则为user_id end_user_id: 必需参数。如果by_user=False则为end_user_id如果by_user=True则为user_id
limit: 返回的标签数量限制 limit: 返回的标签数量限制
by_user: 是否按user_id查询默认Falsegroup_id查询 by_user: 是否按user_id查询默认Falseend_user_id查询
Raises: Raises:
ValueError: 如果group_id未提供或为空 ValueError: 如果end_user_id未提供或为空
""" """
# 验证group_id必须提供且不为空 # 验证end_user_id必须提供且不为空
if not group_id or not group_id.strip(): if not end_user_id or not end_user_id.strip():
raise ValueError( raise ValueError(
"group_id is required. Please provide a valid group_id or user_id." "end_user_id is required. Please provide a valid end_user_id or user_id."
) )
# 使用项目的Neo4jConnector # 使用项目的Neo4jConnector
connector = Neo4jConnector() connector = Neo4jConnector()
try: try:
# 1. 从数据库获取原始排名靠前的标签 # 1. 从数据库获取原始排名靠前的标签
raw_tags_with_freq = await get_raw_tags_from_db(connector, group_id, limit, by_user=by_user) raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user)
if not raw_tags_with_freq: if not raw_tags_with_freq:
return [] return []
raw_tag_names = [tag for tag, freq in raw_tags_with_freq] raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签 # 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id) meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, end_user_id)
# 3. 根据LLM的筛选结果构建最终的标签列表保留原始频率和顺序 # 3. 根据LLM的筛选结果构建最终的标签列表保留原始频率和顺序
final_tags = [] final_tags = []

View File

@@ -75,8 +75,8 @@ class MemoryDataSource:
start_date = time_range.start_date if time_range else None start_date = time_range.start_date if time_range else None
end_date = time_range.end_date if time_range else None end_date = time_range.end_date if time_range else None
summary_dicts = await self.memory_summary_repo.find_by_group_id( summary_dicts = await self.memory_summary_repo.find_by_end_user_id(
group_id=user_id, end_user_id=user_id,
limit=limit, limit=limit,
start_date=start_date, start_date=start_date,
end_date=end_date end_date=end_date

View File

@@ -2,13 +2,16 @@ import os
import re import re
import glob import glob
import json import json
from pathlib import Path
from typing import Tuple from typing import Tuple
try: try:
from app.core.memory.utils.config.definitions import PROJECT_ROOT from app.core.memory.utils.config.definitions import PROJECT_ROOT
except Exception: except Exception:
# Fallback: derive project root from this file location # Fallback: derive project root from this file location
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # 当前文件在 api/app/core/memory/analytics/recent_activity_stats.py
# 需要向上 5 级到达 api/ 目录
PROJECT_ROOT = str(Path(__file__).resolve().parents[4])
def _get_latest_prompt_log_path() -> str | None: def _get_latest_prompt_log_path() -> str | None:
@@ -67,44 +70,43 @@ def parse_stats_from_log(log_path: str) -> dict:
triplet_relations_count = 0 triplet_relations_count = 0
temporal_count = 0 temporal_count = 0
# Patterns # 正则表达式模式 - 匹配当前日志格式
pat_chunk_render = re.compile(r"===\s*RENDERED\s*STATEMENT\s*EXTRACTION\s*PROMPT\s*===") pat_chunk_render = re.compile(r"===\s*RENDERED\s*STATEMENT\s*EXTRACTION\s*PROMPT\s*===")
pat_triplet_start = re.compile(r"\[Triplet\].*statements_to_process\s*=\s*(\d+)") pat_triplet_started = re.compile(r"\[Triplet\]\s+Started\s+-\s+statement_id=")
pat_triplet_done = re.compile( pat_triplet_completed = re.compile(
r"\[Triplet\].*completed,\s*total_triplets\s*=\s*(\d+),\s*total_entities\s*=\s*(\d+)" r"\[Triplet\]\s+Completed\s+-\s+statement_id=[^,]+,\s+triplets=(\d+),\s+entities=(\d+)"
) )
pat_temporal_done = re.compile( pat_temporal_completed = re.compile(
r"\[Temporal\].*completed,\s*extracted_valid_ranges\s*=\s*(\d+)" r"\[Temporal\]\s+Completed\s+-\s+statement_id=[^,]+,\s+valid_ranges=(\d+)"
) )
with open(log_path, "r", encoding="utf-8", errors="ignore") as f: with open(log_path, "r", encoding="utf-8", errors="ignore") as f:
for line in f: for line in f:
# Chunk prompts count (each chunk triggers one statement-extraction prompt render) # 文本块数量(每个块触发一次陈述提取提示)
if pat_chunk_render.search(line): if pat_chunk_render.search(line):
chunk_count += 1 chunk_count += 1
continue continue
m1 = pat_triplet_start.search(line) # 陈述数量(每个 Triplet Started 代表一个陈述被处理)
if m1: if pat_triplet_started.search(line):
statements_count += 1
continue
# 三元组完成:[Triplet] Completed - statement_id=xxx, triplets=X, entities=Y
m_triplet = pat_triplet_completed.search(line)
if m_triplet:
try: try:
statements_count += int(m1.group(1)) triplet_relations_count += int(m_triplet.group(1))
triplet_entities_count += int(m_triplet.group(2))
except Exception: except Exception:
pass pass
continue continue
m2 = pat_triplet_done.search(line) # 时间信息完成:[Temporal] Completed - statement_id=xxx, valid_ranges=X
if m2: m_temporal = pat_temporal_completed.search(line)
if m_temporal:
try: try:
triplet_relations_count += int(m2.group(1)) temporal_count += int(m_temporal.group(1))
triplet_entities_count += int(m2.group(2))
except Exception:
pass
continue
m3 = pat_temporal_done.search(line)
if m3:
try:
temporal_count += int(m3.group(1))
except Exception: except Exception:
pass pass
continue continue
@@ -120,15 +122,20 @@ def parse_stats_from_log(log_path: str) -> dict:
def get_recent_activity_stats() -> Tuple[dict, str]: def get_recent_activity_stats() -> Tuple[dict, str]:
"""Get aggregated stats from all prompt logs in logs/. """Get stats from the latest prompt log file only.
Returns (stats_dict, message). Returns (stats_dict, message).
""" """
all_logs = _get_all_prompt_logs() # 获取最新的日志文件
# Fallback to recursive search if none found in logs/ latest_log = _get_latest_prompt_log_path()
if not all_logs:
# 如果没有找到,尝试递归搜索
if not latest_log:
all_logs = _get_any_logs_recursive() all_logs = _get_any_logs_recursive()
if not all_logs: if all_logs:
latest_log = all_logs[-1] # 取最新的
if not latest_log:
return ( return (
{ {
"chunk_count": 0, "chunk_count": 0,
@@ -141,24 +148,13 @@ def get_recent_activity_stats() -> Tuple[dict, str]:
"未找到日志文件,请确认已运行过提取流程。", "未找到日志文件,请确认已运行过提取流程。",
) )
agg = { # 只解析最新的日志文件
"chunk_count": 0, stats = parse_stats_from_log(latest_log)
"statements_count": 0,
"triplet_entities_count": 0,
"triplet_relations_count": 0,
"temporal_count": 0,
}
for path in all_logs:
s = parse_stats_from_log(path)
agg["chunk_count"] += s.get("chunk_count", 0)
agg["statements_count"] += s.get("statements_count", 0)
agg["triplet_entities_count"] += s.get("triplet_entities_count", 0)
agg["triplet_relations_count"] += s.get("triplet_relations_count", 0)
agg["temporal_count"] += s.get("temporal_count", 0)
# Attach a summary of files combined # 添加日志文件路径信息
agg["log_path"] = f"{len(all_logs)} 个日志文件,最新:{all_logs[-1]}" stats["log_path"] = f"最新:{latest_log}"
return agg, "成功汇总 logs 目录中所有提示日志。"
return stats, "成功读取最近一次记忆活动统计。"
def _format_summary(stats: dict) -> str: def _format_summary(stats: dict) -> str:

View File

@@ -1 +0,0 @@
"""Evaluation package with dataset-specific pipelines and a unified runner."""

View File

@@ -1,30 +0,0 @@
⏬数据集下载地址:
Locomo10.jsonhttps://github.com/snap-research/locomo/tree/main/data
LongMemEval_oracle.jsonhttps://huggingface.co/datasets/xiaowu0162/longmemeval-cleaned
msc_self_instruct.jsonl:https://huggingface.co/datasets/MemGPT/MSC-Self-Instruct
上方数据集下载好后全部放入app/core/memory/data文件夹中
全流程基准测试运行:
locomo
python -m app.core.memory.evaluation.run_eval --dataset locomo --sample-size 1 --reset-group --group-id yyw1 --search-type hybrid --search-limit 8 --context-char-budget 12000 --llm-max-tokens 32
LongMemEval
python -m app.core.memory.evaluation.run_eval --dataset longmemeval --sample-size 10 --start-index 0 --group-id longmemeval_zh_bak_2 --search-limit 8 --context-char-budget 4000 --search-type hybrid --max-contexts-per-item 2 --reset-group
memsciqa
python -m app.core.memory.evaluation.run_eval --dataset memsciqa --sample-size 10 --reset-group --group-id group_memsci
单独检索评估运行命令:
python -m app.core.memory.evaluation.locomo.locomo_test
python -m app.core.memory.evaluation.longmemeval.test_eval
python -m app.core.memory.evaluation.memsciqa.memsciqa-test
需要先在项目中修改需要检测评估的group_id。
参数及解释:
● --dataset longmemeval - 指定数据集
● --sample-size 10 - 评估10个样本
● --start-index 0 - 从第0个样本开始
● --group-id longmemeval_zh_bak_2 - 使用指定的组ID
● --search-limit 8 - 检索限制8条
● --context-char-budget 4000 - 上下文字符预算4000
● --search-type hybrid - 使用混合检索
● --max-contexts-per-item 2 - 每个样本最多摄入2个上下文
● --reset-group - 运行前清空组数据

View File

@@ -1,100 +0,0 @@
import math
import re
from typing import List, Dict
def _normalize(text: str) -> List[str]:
"""Lowercase, strip punctuation, and split into tokens."""
text = text.lower().strip()
# Python's re doesn't support \p classes; use a simple non-word filter
text = re.sub(r"[^\w\s]", " ", text)
tokens = [t for t in text.split() if t]
return tokens
def exact_match(pred: str, ref: str) -> float:
return float(_normalize(pred) == _normalize(ref))
def jaccard(pred: str, ref: str) -> float:
p = set(_normalize(pred))
r = set(_normalize(ref))
if not p and not r:
return 1.0
if not p or not r:
return 0.0
return len(p & r) / len(p | r)
def f1_score(pred: str, ref: str) -> float:
p_tokens = _normalize(pred)
r_tokens = _normalize(ref)
if not p_tokens and not r_tokens:
return 1.0
if not p_tokens or not r_tokens:
return 0.0
p_set = set(p_tokens)
r_set = set(r_tokens)
tp = len(p_set & r_set)
precision = tp / len(p_set) if p_set else 0.0
recall = tp / len(r_set) if r_set else 0.0
if precision + recall == 0:
return 0.0
return 2 * precision * recall / (precision + recall)
def bleu1(pred: str, ref: str) -> float:
"""Unigram BLEU (BLEU-1) with clipping and brevity penalty."""
p_tokens = _normalize(pred)
r_tokens = _normalize(ref)
if not p_tokens:
return 0.0
# Clipped count
r_counts: Dict[str, int] = {}
for t in r_tokens:
r_counts[t] = r_counts.get(t, 0) + 1
clipped = 0
p_counts: Dict[str, int] = {}
for t in p_tokens:
p_counts[t] = p_counts.get(t, 0) + 1
for t, c in p_counts.items():
clipped += min(c, r_counts.get(t, 0))
precision = clipped / max(len(p_tokens), 1)
# Brevity penalty
ref_len = len(r_tokens)
pred_len = len(p_tokens)
if pred_len > ref_len or pred_len == 0:
bp = 1.0
else:
bp = math.exp(1 - ref_len / max(pred_len, 1))
return bp * precision
def percentile(values: List[float], p: float) -> float:
if not values:
return 0.0
vals = sorted(values)
k = (len(vals) - 1) * p
f = math.floor(k)
c = math.ceil(k)
if f == c:
return vals[int(k)]
return vals[f] + (k - f) * (vals[c] - vals[f])
def latency_stats(latencies_ms: List[float]) -> Dict[str, float]:
"""Return basic latency stats: mean, p50, p95, iqr (p75-p25)."""
if not latencies_ms:
return {"mean": 0.0, "p50": 0.0, "p95": 0.0, "iqr": 0.0}
p25 = percentile(latencies_ms, 0.25)
p50 = percentile(latencies_ms, 0.50)
p75 = percentile(latencies_ms, 0.75)
p95 = percentile(latencies_ms, 0.95)
mean = sum(latencies_ms) / max(len(latencies_ms), 1)
return {"mean": mean, "p50": p50, "p95": p95, "iqr": p75 - p25}
def avg_context_tokens(contexts: List[str]) -> float:
if not contexts:
return 0.0
return sum(len(_normalize(c)) for c in contexts) / len(contexts)

View File

@@ -1,60 +0,0 @@
"""
Dialogue search queries for evaluation purposes.
This file contains Cypher queries for searching dialogues, entities, and chunks.
Placed in evaluation directory to avoid circular imports with src modules.
"""
# Entity search queries
SEARCH_ENTITIES_BY_NAME = """
MATCH (e:Entity)
WHERE e.name = $name
RETURN e
"""
SEARCH_ENTITIES_BY_NAME_FALLBACK = """
MATCH (e:Entity)
WHERE e.name CONTAINS $name
RETURN e
"""
# Chunk search queries
SEARCH_CHUNKS_BY_CONTENT = """
MATCH (c:Chunk)
WHERE c.content CONTAINS $content
RETURN c
"""
# Dialogue search queries
SEARCH_DIALOGUE_BY_DIALOG_ID = """
MATCH (d:Dialogue)
WHERE d.dialog_id = $dialog_id
RETURN d
"""
SEARCH_DIALOGUES_BY_CONTENT = """
MATCH (d:Dialogue)
WHERE d.content CONTAINS $q
RETURN d
"""
DIALOGUE_EMBEDDING_SEARCH = """
WITH $embedding AS q
MATCH (d:Dialogue)
WHERE d.dialog_embedding IS NOT NULL
AND ($group_id IS NULL OR d.group_id = $group_id)
WITH d, q, d.dialog_embedding AS v
WITH d,
reduce(dot = 0.0, i IN range(0, size(q)-1) | dot + toFloat(q[i]) * toFloat(v[i])) AS dot,
sqrt(reduce(qs = 0.0, i IN range(0, size(q)-1) | qs + toFloat(q[i]) * toFloat(q[i]))) AS qnorm,
sqrt(reduce(vs = 0.0, i IN range(0, size(v)-1) | vs + toFloat(v[i]) * toFloat(v[i]))) AS vnorm
WITH d, CASE WHEN qnorm = 0 OR vnorm = 0 THEN 0.0 ELSE dot / (qnorm * vnorm) END AS score
WHERE score > $threshold
RETURN d.id AS dialog_id,
d.group_id AS group_id,
d.content AS content,
d.created_at AS created_at,
d.expired_at AS expired_at,
score
ORDER BY score DESC
LIMIT $limit
"""

View File

@@ -1,341 +0,0 @@
import asyncio
import json
import os
import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.core.memory.llm_tools.openai_client import LLMClient
from app.core.memory.models.message_models import (
ConversationContext,
ConversationMessage,
DialogData,
)
# 使用新的模块化架构
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
ExtractionOrchestrator,
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
DialogueChunker,
)
from app.core.memory.utils.config.definitions import (
SELECTED_CHUNKER_STRATEGY,
SELECTED_EMBEDDING_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
# Import from database module
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# Cypher queries for evaluation
# Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py
async def ingest_contexts_via_full_pipeline(
contexts: List[str],
group_id: str,
chunker_strategy: str | None = None,
embedding_name: str | None = None,
save_chunk_output: bool = False,
save_chunk_output_path: str | None = None,
) -> bool:
"""DEPRECATED: 此函数使用旧的流水线架构,建议使用新的 ExtractionOrchestrator
Run the full extraction pipeline on provided dialogue contexts and save to Neo4j.
This function mirrors the steps in main(), but starts from raw text contexts.
Args:
contexts: List of dialogue texts, each containing lines like "role: message".
group_id: Group ID to assign to generated DialogData and graph nodes.
chunker_strategy: Optional chunker strategy; defaults to SELECTED_CHUNKER_STRATEGY.
embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID.
save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging.
save_chunk_output_path: Optional output path; defaults to src/chunker_test_output.txt.
Returns:
True if data saved successfully, False otherwise.
"""
chunker_strategy = chunker_strategy or SELECTED_CHUNKER_STRATEGY
embedding_name = embedding_name or SELECTED_EMBEDDING_ID
# Initialize llm client with graceful fallback
llm_client = None
llm_available = True
try:
from app.core.memory.utils.config import definitions as config_defs
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
except Exception as e:
print(f"[Ingestion] LLM client unavailable, will skip LLM-dependent steps: {e}")
llm_available = False
# Step A: Build DialogData list from contexts with robust parsing
chunker = DialogueChunker(chunker_strategy)
dialog_data_list: List[DialogData] = []
for idx, ctx in enumerate(contexts):
messages: List[ConversationMessage] = []
# Improved parsing: capture multi-line message blocks, normalize roles
pattern = r"^\s*(用户|AI|assistant|user)\s*[:]\s*(.+?)(?=\n\s*(?:用户|AI|assistant|user)\s*[:]|\Z)"
matches = list(re.finditer(pattern, ctx, flags=re.MULTILINE | re.DOTALL))
if matches:
for m in matches:
raw_role = m.group(1).strip()
content = m.group(2).strip()
norm_role = "AI" if raw_role.lower() in ("ai", "assistant") else "用户"
messages.append(ConversationMessage(role=norm_role, msg=content))
else:
# Fallback: line-by-line parsing
for raw in ctx.split("\n"):
line = raw.strip()
if not line:
continue
m = re.match(r'^\s*([^:]+)\s*[:]\s*(.+)$', line)
if m:
role = m.group(1).strip()
msg = m.group(2).strip()
norm_role = "AI" if role.lower() in ("ai", "assistant") else "用户"
messages.append(ConversationMessage(role=norm_role, msg=msg))
else:
# Final fallback: treat as user message
default_role = "AI" if re.match(r'^\s*(assistant|AI)\b', line, flags=re.IGNORECASE) else "用户"
messages.append(ConversationMessage(role=default_role, msg=line))
context_model = ConversationContext(msgs=messages)
dialog = DialogData(
context=context_model,
ref_id=f"pipeline_item_{idx}",
group_id=group_id,
user_id="default_user",
apply_id="default_application",
)
# Generate chunks
dialog.chunks = await chunker.process_dialogue(dialog)
dialog_data_list.append(dialog)
if not dialog_data_list:
print("No dialogs to process for ingestion.")
return False
# Optionally save chunking outputs for debugging
if save_chunk_output:
try:
def _serialize_datetime(obj):
if isinstance(obj, datetime):
return obj.isoformat()
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
from app.core.config import settings
settings.ensure_memory_output_dir()
default_path = settings.get_memory_output_path("chunker_test_output.txt")
out_path = save_chunk_output_path or default_path
combined_output = [dd.model_dump() for dd in dialog_data_list]
with open(out_path, "w", encoding="utf-8") as f:
json.dump(combined_output, f, ensure_ascii=False, indent=4, default=_serialize_datetime)
print(f"Saved chunking results to: {out_path}")
except Exception as e:
print(f"Failed to save chunking results: {e}")
# Step B-G: 使用新的 ExtractionOrchestrator 执行完整的提取流水线
if not llm_available:
print("[Ingestion] Skipping extraction pipeline (no LLM).")
return False
# 初始化 embedder 客户端
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.services.memory_config_service import MemoryConfigService
try:
with get_db_context() as db:
embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID)
embedder_config = RedBearModelConfig(**embedder_config_dict)
embedder_client = OpenAIEmbedderClient(embedder_config)
except Exception as e:
print(f"[Ingestion] Failed to initialize embedder client: {e}")
print("[Ingestion] Skipping extraction pipeline (embedder initialization failed).")
return False
connector = Neo4jConnector()
# 初始化并运行 ExtractionOrchestrator
from app.core.memory.utils.config.config_utils import get_pipeline_config
config = get_pipeline_config()
orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
connector=connector,
config=config,
)
# 创建一个包装的 orchestrator 来修复时间提取器的输出
# 保存原始的 _assign_extracted_data 方法
original_assign = orchestrator._assign_extracted_data
def clean_temporal_value(value):
"""清理 temporal_validity 字段的值,将无效值转换为 None"""
if value is None:
return None
if isinstance(value, str):
# 处理字符串形式的 'null', 'None', 空字符串等
if value.lower() in ('null', 'none', '') or value.strip() == '':
return None
return value
async def patched_assign_extracted_data(*args, **kwargs):
"""包装方法:在赋值后清理 temporal_validity 中的无效字符串"""
result = await original_assign(*args, **kwargs)
# 清理返回的 dialog_data_list 中的 temporal_validity
for dialog in result:
if hasattr(dialog, 'chunks') and dialog.chunks:
for chunk in dialog.chunks:
if hasattr(chunk, 'statements') and chunk.statements:
for statement in chunk.statements:
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
tv = statement.temporal_validity
# 清理 valid_at 和 invalid_at
if hasattr(tv, 'valid_at'):
tv.valid_at = clean_temporal_value(tv.valid_at)
if hasattr(tv, 'invalid_at'):
tv.invalid_at = clean_temporal_value(tv.invalid_at)
return result
# 替换方法
orchestrator._assign_extracted_data = patched_assign_extracted_data
# 同时包装 _create_nodes_and_edges 方法,在创建节点前再次清理
original_create = orchestrator._create_nodes_and_edges
async def patched_create_nodes_and_edges(dialog_data_list_arg):
"""包装方法:在创建节点前再次清理 temporal_validity"""
# 最后一次清理,确保万无一失
for dialog in dialog_data_list_arg:
if hasattr(dialog, 'chunks') and dialog.chunks:
for chunk in dialog.chunks:
if hasattr(chunk, 'statements') and chunk.statements:
for statement in chunk.statements:
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
tv = statement.temporal_validity
if hasattr(tv, 'valid_at'):
tv.valid_at = clean_temporal_value(tv.valid_at)
if hasattr(tv, 'invalid_at'):
tv.invalid_at = clean_temporal_value(tv.invalid_at)
return await original_create(dialog_data_list_arg)
orchestrator._create_nodes_and_edges = patched_create_nodes_and_edges
# 运行完整的提取流水线
# orchestrator.run 返回 7 个元素的元组
result = await orchestrator.run(dialog_data_list, is_pilot_run=False)
(
dialogue_nodes,
chunk_nodes,
statement_nodes,
entity_nodes,
statement_chunk_edges,
statement_entity_edges,
entity_entity_edges,
) = result
# statement_chunk_edges 已经由 orchestrator 创建,无需重复创建
# Step G: 生成记忆摘要
print("[Ingestion] Generating memory summaries...")
try:
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
memory_summary_generation,
)
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
summaries = await memory_summary_generation(
chunked_dialogs=dialog_data_list,
llm_client=llm_client,
embedder_client=embedder_client
)
print(f"[Ingestion] Generated {len(summaries)} memory summaries")
except Exception as e:
print(f"[Ingestion] Warning: Failed to generate memory summaries: {e}")
summaries = []
# Step H: Save to Neo4j
try:
success = await save_dialog_and_statements_to_neo4j(
dialogue_nodes=dialogue_nodes,
chunk_nodes=chunk_nodes,
statement_nodes=statement_nodes,
entity_nodes=entity_nodes,
entity_edges=entity_entity_edges,
statement_chunk_edges=statement_chunk_edges,
statement_entity_edges=statement_entity_edges,
connector=connector
)
# Save memory summaries separately
if summaries:
try:
await add_memory_summary_nodes(summaries, connector)
await add_memory_summary_statement_edges(summaries, connector)
print(f"Successfully saved {len(summaries)} memory summary nodes to Neo4j")
except Exception as e:
print(f"Warning: Failed to save summary nodes: {e}")
await connector.close()
if success:
print("Successfully saved extracted data to Neo4j!")
else:
print("Failed to save data to Neo4j")
return success
except Exception as e:
print(f"Failed to save data to Neo4j: {e}")
return False
async def handle_context_processing(args):
"""Handle context-based processing from command line arguments."""
contexts = []
if args.contexts:
contexts.extend(args.contexts)
if args.context_file:
try:
with open(args.context_file, 'r', encoding='utf-8') as f:
contexts.extend(line.strip() for line in f if line.strip())
except Exception as e:
print(f"Error reading context file: {e}")
return False
if not contexts:
print("No contexts provided for processing.")
return False
return await main_from_contexts(contexts, args.context_group_id)
async def main_from_contexts(contexts: List[str], group_id: str):
"""Run the pipeline from provided dialogue contexts instead of test data."""
print("=== Running pipeline from provided contexts ===")
success = await ingest_contexts_via_full_pipeline(
contexts=contexts,
group_id=group_id,
chunker_strategy=SELECTED_CHUNKER_STRATEGY,
embedding_name=SELECTED_EMBEDDING_ID,
save_chunk_output=True
)
if success:
print("Successfully processed and saved contexts to Neo4j!")
else:
print("Failed to process contexts.")
return success

View File

@@ -1,575 +0,0 @@
"""
LoCoMo Benchmark Script
This module provides the main entry point for running LoCoMo benchmark evaluations.
It orchestrates data loading, ingestion, retrieval, LLM inference, and metric calculation
in a clean, maintainable way.
Usage:
python locomo_benchmark.py --sample_size 20 --search_type hybrid
"""
import argparse
import asyncio
import json
import os
import time
from datetime import datetime
from typing import Any, Dict, List, Optional
try:
from dotenv import load_dotenv
except ImportError:
def load_dotenv():
pass
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
bleu1,
f1_score,
jaccard,
latency_stats,
)
from app.core.memory.evaluation.locomo.locomo_metrics import (
get_category_name,
locomo_f1_score,
locomo_multi_f1,
)
from app.core.memory.evaluation.locomo.locomo_utils import (
extract_conversations,
ingest_conversations_if_needed,
load_locomo_data,
resolve_temporal_references,
retrieve_relevant_information,
select_and_format_information,
)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
async def run_locomo_benchmark(
sample_size: int = 20,
group_id: Optional[str] = None,
search_type: str = "hybrid",
search_limit: int = 12,
context_char_budget: int = 8000,
reset_group: bool = False,
skip_ingest: bool = False,
output_dir: Optional[str] = None
) -> Dict[str, Any]:
"""
Run LoCoMo benchmark evaluation.
This function orchestrates the complete evaluation pipeline:
1. Load LoCoMo dataset (only QA pairs from first conversation)
2. Check/ingest conversations into database (only first conversation, unless skip_ingest=True)
3. For each question:
- Retrieve relevant information
- Generate answer using LLM
- Calculate metrics
4. Aggregate results and save to file
Note: By default, only the first conversation is ingested into the database,
and only QA pairs from that conversation are evaluated. This ensures that
all questions have corresponding memory in the database for retrieval.
Args:
sample_size: Number of QA pairs to evaluate (from first conversation)
group_id: Database group ID for retrieval (uses default if None)
search_type: "keyword", "embedding", or "hybrid"
search_limit: Max documents to retrieve per query
context_char_budget: Max characters for context
reset_group: Whether to clear and re-ingest data (not implemented)
skip_ingest: If True, skip data ingestion and use existing data in Neo4j
output_dir: Directory to save results (uses default if None)
Returns:
Dictionary with evaluation results including metrics, timing, and samples
"""
# Use default group_id if not provided
group_id = group_id or SELECTED_GROUP_ID
# Determine data path
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
if not os.path.exists(data_path):
# Fallback to current directory
data_path = os.path.join(os.getcwd(), "data", "locomo10.json")
print(f"\n{'='*60}")
print("🚀 Starting LoCoMo Benchmark Evaluation")
print(f"{'='*60}")
print("📊 Configuration:")
print(f" Sample size: {sample_size}")
print(f" Group ID: {group_id}")
print(f" Search type: {search_type}")
print(f" Search limit: {search_limit}")
print(f" Context budget: {context_char_budget} chars")
print(f" Data path: {data_path}")
print(f"{'='*60}\n")
# Step 1: Load LoCoMo data
print("📂 Loading LoCoMo dataset...")
try:
# Only load QA pairs from the first conversation (index 0)
# since we only ingest the first conversation into the database
qa_items = load_locomo_data(data_path, sample_size, conversation_index=0)
print(f"✅ Loaded {len(qa_items)} QA pairs from conversation 0\n")
except Exception as e:
print(f"❌ Failed to load data: {e}")
return {
"error": f"Data loading failed: {e}",
"timestamp": datetime.now().isoformat()
}
# Step 2: Extract conversations and ingest if needed
if skip_ingest:
print("⏭️ Skipping data ingestion (using existing data in Neo4j)")
print(f" Group ID: {group_id}\n")
else:
print("💾 Checking database ingestion...")
try:
conversations = extract_conversations(data_path, max_dialogues=1)
print(f"📝 Extracted {len(conversations)} conversations")
# Always ingest for now (ingestion check not implemented)
print(f"🔄 Ingesting conversations into group '{group_id}'...")
success = await ingest_conversations_if_needed(
conversations=conversations,
group_id=group_id,
reset=reset_group
)
if success:
print("✅ Ingestion completed successfully\n")
else:
print("⚠️ Ingestion may have failed, continuing anyway\n")
except Exception as e:
print(f"❌ Ingestion failed: {e}")
print("⚠️ Continuing with evaluation (database may be empty)\n")
# Step 3: Initialize clients
print("🔧 Initializing clients...")
connector = Neo4jConnector()
# Initialize LLM client with database context
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
# Initialize embedder
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)
print("✅ Clients initialized\n")
# Step 4: Process questions
print(f"🔍 Processing {len(qa_items)} questions...")
print(f"{'='*60}\n")
# Tracking variables
latencies_search: List[float] = []
latencies_llm: List[float] = []
context_counts: List[int] = []
context_chars: List[int] = []
context_tokens: List[int] = []
# Metric lists
f1_scores: List[float] = []
bleu1_scores: List[float] = []
jaccard_scores: List[float] = []
locomo_f1_scores: List[float] = []
# Per-category tracking
category_counts: Dict[str, int] = {}
category_f1: Dict[str, List[float]] = {}
category_bleu1: Dict[str, List[float]] = {}
category_jaccard: Dict[str, List[float]] = {}
category_locomo_f1: Dict[str, List[float]] = {}
# Detailed samples
samples: List[Dict[str, Any]] = []
# Fixed anchor date for temporal resolution
anchor_date = datetime(2023, 5, 8)
try:
for idx, item in enumerate(qa_items, 1):
question = item.get("question", "")
ground_truth = item.get("answer", "")
category = get_category_name(item)
# Ensure ground truth is a string
ground_truth_str = str(ground_truth) if ground_truth is not None else ""
print(f"[{idx}/{len(qa_items)}] Category: {category}")
print(f"❓ Question: {question}")
print(f"✅ Ground Truth: {ground_truth_str}")
# Step 4a: Retrieve relevant information
t_search_start = time.time()
try:
retrieved_info = await retrieve_relevant_information(
question=question,
group_id=group_id,
search_type=search_type,
search_limit=search_limit,
connector=connector,
embedder=embedder
)
t_search_end = time.time()
search_latency = (t_search_end - t_search_start) * 1000
latencies_search.append(search_latency)
print(f"🔍 Retrieved {len(retrieved_info)} documents ({search_latency:.1f}ms)")
except Exception as e:
print(f"❌ Retrieval failed: {e}")
retrieved_info = []
search_latency = 0.0
latencies_search.append(search_latency)
# Step 4b: Select and format context
context_text = select_and_format_information(
retrieved_info=retrieved_info,
question=question,
max_chars=context_char_budget
)
# Resolve temporal references
context_text = resolve_temporal_references(context_text, anchor_date)
# Add reference date to context
if context_text:
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n{context_text}"
else:
context_text = "No relevant context found."
# Track context statistics
context_counts.append(len(retrieved_info))
context_chars.append(len(context_text))
context_tokens.append(len(context_text.split()))
print(f"📝 Context: {len(context_text)} chars, {len(retrieved_info)} docs")
# Step 4c: Generate answer with LLM
messages = [
{
"role": "system",
"content": (
"You are a precise QA assistant. Answer following these rules:\n"
"1) Extract the EXACT information mentioned in the context\n"
"2) For time questions: calculate actual dates from relative times\n"
"3) Return ONLY the answer text in simplest form\n"
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
"5) If no clear answer found, respond with 'Unknown'"
)
},
{
"role": "user",
"content": f"Question: {question}\n\nContext:\n{context_text}"
}
]
t_llm_start = time.time()
try:
response = await llm_client.chat(messages=messages)
t_llm_end = time.time()
llm_latency = (t_llm_end - t_llm_start) * 1000
latencies_llm.append(llm_latency)
# Extract prediction from response
if hasattr(response, 'content'):
prediction = response.content.strip()
elif isinstance(response, dict):
prediction = response["choices"][0]["message"]["content"].strip()
else:
prediction = "Unknown"
print(f"🤖 Prediction: {prediction} ({llm_latency:.1f}ms)")
except Exception as e:
print(f"❌ LLM failed: {e}")
prediction = "Unknown"
llm_latency = 0.0
latencies_llm.append(llm_latency)
# Step 4d: Calculate metrics
f1_val = f1_score(prediction, ground_truth_str)
bleu1_val = bleu1(prediction, ground_truth_str)
jaccard_val = jaccard(prediction, ground_truth_str)
# LoCoMo-specific F1: use multi-answer for category 1 (Multi-Hop)
if item.get("category") == 1:
locomo_f1_val = locomo_multi_f1(prediction, ground_truth_str)
else:
locomo_f1_val = locomo_f1_score(prediction, ground_truth_str)
# Accumulate metrics
f1_scores.append(f1_val)
bleu1_scores.append(bleu1_val)
jaccard_scores.append(jaccard_val)
locomo_f1_scores.append(locomo_f1_val)
# Track by category
category_counts[category] = category_counts.get(category, 0) + 1
category_f1.setdefault(category, []).append(f1_val)
category_bleu1.setdefault(category, []).append(bleu1_val)
category_jaccard.setdefault(category, []).append(jaccard_val)
category_locomo_f1.setdefault(category, []).append(locomo_f1_val)
print(f"📊 Metrics - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, "
f"Jaccard: {jaccard_val:.3f}, LoCoMo F1: {locomo_f1_val:.3f}")
print()
# Save sample details
samples.append({
"question": question,
"ground_truth": ground_truth_str,
"prediction": prediction,
"category": category,
"metrics": {
"f1": f1_val,
"bleu1": bleu1_val,
"jaccard": jaccard_val,
"locomo_f1": locomo_f1_val
},
"retrieval": {
"num_docs": len(retrieved_info),
"context_length": len(context_text)
},
"timing": {
"search_ms": search_latency,
"llm_ms": llm_latency
}
})
finally:
# Close connector
await connector.close()
# Step 5: Aggregate results
print(f"\n{'='*60}")
print("📊 Aggregating Results")
print(f"{'='*60}\n")
# Overall metrics
overall_metrics = {
"f1": sum(f1_scores) / max(len(f1_scores), 1) if f1_scores else 0.0,
"bleu1": sum(bleu1_scores) / max(len(bleu1_scores), 1) if bleu1_scores else 0.0,
"jaccard": sum(jaccard_scores) / max(len(jaccard_scores), 1) if jaccard_scores else 0.0,
"locomo_f1": sum(locomo_f1_scores) / max(len(locomo_f1_scores), 1) if locomo_f1_scores else 0.0
}
# Per-category metrics
by_category: Dict[str, Dict[str, Any]] = {}
for cat in category_counts:
f1_list = category_f1.get(cat, [])
b1_list = category_bleu1.get(cat, [])
j_list = category_jaccard.get(cat, [])
lf_list = category_locomo_f1.get(cat, [])
by_category[cat] = {
"count": category_counts[cat],
"f1": sum(f1_list) / max(len(f1_list), 1) if f1_list else 0.0,
"bleu1": sum(b1_list) / max(len(b1_list), 1) if b1_list else 0.0,
"jaccard": sum(j_list) / max(len(j_list), 1) if j_list else 0.0,
"locomo_f1": sum(lf_list) / max(len(lf_list), 1) if lf_list else 0.0
}
# Latency statistics
latency = {
"search": latency_stats(latencies_search),
"llm": latency_stats(latencies_llm)
}
# Context statistics
context_stats = {
"avg_retrieved_docs": sum(context_counts) / max(len(context_counts), 1) if context_counts else 0.0,
"avg_context_chars": sum(context_chars) / max(len(context_chars), 1) if context_chars else 0.0,
"avg_context_tokens": sum(context_tokens) / max(len(context_tokens), 1) if context_tokens else 0.0
}
# Build result dictionary
result = {
"dataset": "locomo",
"sample_size": len(qa_items),
"timestamp": datetime.now().isoformat(),
"params": {
"group_id": group_id,
"search_type": search_type,
"search_limit": search_limit,
"context_char_budget": context_char_budget,
"llm_id": SELECTED_LLM_ID,
"embedding_id": SELECTED_EMBEDDING_ID
},
"overall_metrics": overall_metrics,
"by_category": by_category,
"latency": latency,
"context_stats": context_stats,
"samples": samples
}
# Step 6: Save results
if output_dir is None:
output_dir = os.path.join(
os.path.dirname(__file__),
"results"
)
os.makedirs(output_dir, exist_ok=True)
# Generate timestamped filename
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = os.path.join(output_dir, f"locomo_{timestamp_str}.json")
try:
with open(output_path, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"✅ Results saved to: {output_path}\n")
except Exception as e:
print(f"❌ Failed to save results: {e}")
print("📊 Printing results to console instead:\n")
print(json.dumps(result, ensure_ascii=False, indent=2))
return result
def main():
"""
Parse command-line arguments and run benchmark.
This function provides a CLI interface for running LoCoMo benchmarks
with configurable parameters.
"""
parser = argparse.ArgumentParser(
description="Run LoCoMo benchmark evaluation",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--sample_size",
type=int,
default=20,
help="Number of QA pairs to evaluate"
)
parser.add_argument(
"--group_id",
type=str,
default=None,
help="Database group ID for retrieval (uses default if not specified)"
)
parser.add_argument(
"--search_type",
type=str,
default="hybrid",
choices=["keyword", "embedding", "hybrid"],
help="Search strategy to use"
)
parser.add_argument(
"--search_limit",
type=int,
default=12,
help="Maximum number of documents to retrieve per query"
)
parser.add_argument(
"--context_char_budget",
type=int,
default=8000,
help="Maximum characters for context"
)
parser.add_argument(
"--reset_group",
action="store_true",
help="Clear and re-ingest data (not implemented)"
)
parser.add_argument(
"--skip_ingest",
action="store_true",
help="Skip data ingestion and use existing data in Neo4j"
)
parser.add_argument(
"--output_dir",
type=str,
default=None,
help="Directory to save results (uses default if not specified)"
)
args = parser.parse_args()
# Load environment variables
load_dotenv()
# Run benchmark
result = asyncio.run(run_locomo_benchmark(
sample_size=args.sample_size,
group_id=args.group_id,
search_type=args.search_type,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
reset_group=args.reset_group,
skip_ingest=args.skip_ingest,
output_dir=args.output_dir
))
# Print summary
print(f"\n{'='*60}")
# Check if there was an error
if 'error' in result:
print("❌ Benchmark Failed!")
print(f"{'='*60}")
print(f"Error: {result['error']}")
return
print("🎉 Benchmark Complete!")
print(f"{'='*60}")
print("📊 Final Results:")
print(f" Sample size: {result.get('sample_size', 0)}")
print(f" F1: {result['overall_metrics']['f1']:.3f}")
print(f" BLEU-1: {result['overall_metrics']['bleu1']:.3f}")
print(f" Jaccard: {result['overall_metrics']['jaccard']:.3f}")
print(f" LoCoMo F1: {result['overall_metrics']['locomo_f1']:.3f}")
if result.get('context_stats'):
print("\n📈 Context Statistics:")
print(f" Avg retrieved docs: {result['context_stats']['avg_retrieved_docs']:.1f}")
print(f" Avg context chars: {result['context_stats']['avg_context_chars']:.0f}")
print(f" Avg context tokens: {result['context_stats']['avg_context_tokens']:.0f}")
if result.get('latency'):
print("\n⏱️ Latency Statistics:")
print(f" Search - Mean: {result['latency']['search']['mean']:.1f}ms, "
f"P50: {result['latency']['search']['p50']:.1f}ms, "
f"P95: {result['latency']['search']['p95']:.1f}ms")
print(f" LLM - Mean: {result['latency']['llm']['mean']:.1f}ms, "
f"P50: {result['latency']['llm']['p50']:.1f}ms, "
f"P95: {result['latency']['llm']['p95']:.1f}ms")
if result.get('by_category'):
print("\n📂 Results by Category:")
for cat, metrics in result['by_category'].items():
print(f" {cat}:")
print(f" Count: {metrics['count']}")
print(f" F1: {metrics['f1']:.3f}")
print(f" LoCoMo F1: {metrics['locomo_f1']:.3f}")
print(f" Jaccard: {metrics['jaccard']:.3f}")
print(f"\n{'='*60}\n")
if __name__ == "__main__":
main()

View File

@@ -1,225 +0,0 @@
"""
LoCoMo-specific metric calculations.
This module provides clean, simplified implementations of metrics used for
LoCoMo benchmark evaluation, including text normalization and F1 score variants.
"""
import re
from typing import Dict, Any
def normalize_text(text: str) -> str:
"""
Normalize text for LoCoMo evaluation.
Normalization steps:
- Convert to lowercase
- Remove commas
- Remove stop words (a, an, the, and)
- Remove punctuation
- Normalize whitespace
Args:
text: Input text to normalize
Returns:
Normalized text string with consistent formatting
Examples:
>>> normalize_text("The cat, and the dog")
'cat dog'
>>> normalize_text("Hello, World!")
'hello world'
"""
# Ensure input is a string
text = str(text) if text is not None else ""
# Convert to lowercase
text = text.lower()
# Remove commas
text = re.sub(r"[\,]", " ", text)
# Remove stop words
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
# Remove punctuation (keep only word characters and whitespace)
text = re.sub(r"[^\w\s]", " ", text)
# Normalize whitespace (collapse multiple spaces to single space)
text = " ".join(text.split())
return text
def locomo_f1_score(prediction: str, ground_truth: str) -> float:
"""
Calculate LoCoMo F1 score for single-answer questions.
Uses token-level precision and recall based on normalized text.
Treats tokens as sets (no duplicate counting).
Args:
prediction: Model's predicted answer
ground_truth: Correct answer
Returns:
F1 score between 0.0 and 1.0
Examples:
>>> locomo_f1_score("Paris", "Paris")
1.0
>>> locomo_f1_score("The cat", "cat")
1.0
>>> locomo_f1_score("dog", "cat")
0.0
"""
# Ensure inputs are strings
pred_str = str(prediction) if prediction is not None else ""
truth_str = str(ground_truth) if ground_truth is not None else ""
# Normalize and tokenize
pred_tokens = normalize_text(pred_str).split()
truth_tokens = normalize_text(truth_str).split()
# Handle empty cases
if not pred_tokens or not truth_tokens:
return 0.0
# Convert to sets for comparison
pred_set = set(pred_tokens)
truth_set = set(truth_tokens)
# Calculate true positives (intersection)
true_positives = len(pred_set & truth_set)
# Calculate precision and recall
precision = true_positives / len(pred_set) if pred_set else 0.0
recall = true_positives / len(truth_set) if truth_set else 0.0
# Calculate F1 score
if precision + recall == 0:
return 0.0
f1 = 2 * precision * recall / (precision + recall)
return f1
def locomo_multi_f1(prediction: str, ground_truth: str) -> float:
"""
Calculate LoCoMo F1 score for multi-answer questions.
Handles comma-separated answers by:
1. Splitting both prediction and ground truth by commas
2. For each ground truth answer, finding the best matching prediction
3. Averaging the F1 scores across all ground truth answers
Args:
prediction: Model's predicted answer (may contain multiple comma-separated answers)
ground_truth: Correct answer (may contain multiple comma-separated answers)
Returns:
Average F1 score across all ground truth answers (0.0 to 1.0)
Examples:
>>> locomo_multi_f1("Paris, London", "Paris, London")
1.0
>>> locomo_multi_f1("Paris", "Paris, London")
0.5
>>> locomo_multi_f1("Paris, Berlin", "Paris, London")
0.5
"""
# Ensure inputs are strings
pred_str = str(prediction) if prediction is not None else ""
truth_str = str(ground_truth) if ground_truth is not None else ""
# Split by commas and strip whitespace
predictions = [p.strip() for p in pred_str.split(',') if p.strip()]
ground_truths = [g.strip() for g in truth_str.split(',') if g.strip()]
# Handle empty cases
if not predictions or not ground_truths:
return 0.0
# For each ground truth, find the best matching prediction
f1_scores = []
for gt in ground_truths:
# Calculate F1 with each prediction and take the maximum
best_f1 = max(locomo_f1_score(pred, gt) for pred in predictions)
f1_scores.append(best_f1)
# Return average F1 across all ground truths
return sum(f1_scores) / len(f1_scores)
def get_category_name(item: Dict[str, Any]) -> str:
"""
Extract and normalize category name from QA item.
Handles both numeric categories (1-4) and string categories with various formats.
Supports multiple field names: "cat", "category", "type".
Category mapping:
- 1 or "multi-hop" -> "Multi-Hop"
- 2 or "temporal" -> "Temporal"
- 3 or "open domain" -> "Open Domain"
- 4 or "single-hop" -> "Single-Hop"
Args:
item: QA item dictionary containing category information
Returns:
Standardized category name or "unknown" if not found
Examples:
>>> get_category_name({"category": 1})
'Multi-Hop'
>>> get_category_name({"cat": "temporal"})
'Temporal'
>>> get_category_name({"type": "Single-Hop"})
'Single-Hop'
"""
# Numeric category mapping
CATEGORY_MAP = {
1: "Multi-Hop",
2: "Temporal",
3: "Open Domain",
4: "Single-Hop",
}
# String category aliases (case-insensitive)
TYPE_ALIASES = {
"single-hop": "Single-Hop",
"singlehop": "Single-Hop",
"single hop": "Single-Hop",
"multi-hop": "Multi-Hop",
"multihop": "Multi-Hop",
"multi hop": "Multi-Hop",
"open domain": "Open Domain",
"opendomain": "Open Domain",
"temporal": "Temporal",
}
# Try "cat" field first (string category)
cat = item.get("cat")
if isinstance(cat, str) and cat.strip():
name = cat.strip()
lower = name.lower()
return TYPE_ALIASES.get(lower, name)
# Try "category" field (can be int or string)
cat_num = item.get("category")
if isinstance(cat_num, int):
return CATEGORY_MAP.get(cat_num, "unknown")
elif isinstance(cat_num, str) and cat_num.strip():
lower = cat_num.strip().lower()
return TYPE_ALIASES.get(lower, cat_num.strip())
# Try "type" field as fallback
cat_type = item.get("type")
if isinstance(cat_type, str) and cat_type.strip():
lower = cat_type.strip().lower()
return TYPE_ALIASES.get(lower, cat_type.strip())
return "unknown"

View File

@@ -1,810 +0,0 @@
# file name: check_neo4j_connection_fixed.py
import asyncio
import json
import math
import os
import re
import sys
import time
from datetime import datetime, timedelta
from typing import Any, Dict, List
from dotenv import load_dotenv
# 1
# 添加项目根目录到路径
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
if project_root not in sys.path:
sys.path.insert(0, project_root)
# 关键:将 src 目录置于最前,确保从当前仓库加载模块
src_dir = os.path.join(project_root, "src")
if src_dir not in sys.path:
sys.path.insert(0, src_dir)
load_dotenv()
# 首先定义 _loc_normalize 函数,因为其他函数依赖它
def _loc_normalize(text: str) -> str:
text = str(text) if text is not None else ""
text = text.lower()
text = re.sub(r"[\,]", " ", text)
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
text = re.sub(r"[^\w\s]", " ", text)
text = " ".join(text.split())
return text
# 尝试从 metrics.py 导入基础指标
try:
from common.metrics import bleu1, f1_score, jaccard
print("✅ 从 metrics.py 导入基础指标成功")
except ImportError as e:
print(f"❌ 从 metrics.py 导入失败: {e}")
# 回退到本地实现
def f1_score(pred: str, ref: str) -> float:
pred_str = str(pred) if pred is not None else ""
ref_str = str(ref) if ref is not None else ""
p_tokens = _loc_normalize(pred_str).split()
r_tokens = _loc_normalize(ref_str).split()
if not p_tokens and not r_tokens:
return 1.0
if not p_tokens or not r_tokens:
return 0.0
p_set = set(p_tokens)
r_set = set(r_tokens)
tp = len(p_set & r_set)
precision = tp / len(p_set) if p_set else 0.0
recall = tp / len(r_set) if r_set else 0.0
if precision + recall == 0:
return 0.0
return 2 * precision * recall / (precision + recall)
def bleu1(pred: str, ref: str) -> float:
pred_str = str(pred) if pred is not None else ""
ref_str = str(ref) if ref is not None else ""
p_tokens = _loc_normalize(pred_str).split()
r_tokens = _loc_normalize(ref_str).split()
if not p_tokens:
return 0.0
r_counts = {}
for t in r_tokens:
r_counts[t] = r_counts.get(t, 0) + 1
clipped = 0
p_counts = {}
for t in p_tokens:
p_counts[t] = p_counts.get(t, 0) + 1
for t, c in p_counts.items():
clipped += min(c, r_counts.get(t, 0))
precision = clipped / max(len(p_tokens), 1)
ref_len = len(r_tokens)
pred_len = len(p_tokens)
if pred_len > ref_len or pred_len == 0:
bp = 1.0
else:
bp = math.exp(1 - ref_len / max(pred_len, 1))
return bp * precision
def jaccard(pred: str, ref: str) -> float:
pred_str = str(pred) if pred is not None else ""
ref_str = str(ref) if ref is not None else ""
p = set(_loc_normalize(pred_str).split())
r = set(_loc_normalize(ref_str).split())
if not p and not r:
return 1.0
if not p or not r:
return 0.0
return len(p & r) / len(p | r)
# 尝试从 qwen_search_eval.py 导入 LoCoMo 特定指标
try:
# 添加 evaluation 目录路径
evaluation_dir = os.path.join(project_root, "evaluation")
if evaluation_dir not in sys.path:
sys.path.insert(0, evaluation_dir)
# 尝试从不同位置导入
try:
from locomo.qwen_search_eval import (
_resolve_relative_times,
loc_f1_score,
loc_multi_f1,
)
print("✅ 从 locomo.qwen_search_eval 导入 LoCoMo 特定指标成功")
except ImportError:
from qwen_search_eval import _resolve_relative_times, loc_f1_score, loc_multi_f1
print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功")
except ImportError as e:
print(f"❌ 从 qwen_search_eval.py 导入失败: {e}")
# 回退到本地实现 LoCoMo 特定函数
def _resolve_relative_times(text: str, anchor: datetime) -> str:
t = str(text) if text is not None else ""
t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
def _ago_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor - timedelta(days=n)).date().isoformat()
def _in_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor + timedelta(days=n)).date().isoformat()
t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE)
t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE)
t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
return t
def loc_f1_score(prediction: str, ground_truth: str) -> float:
p_tokens = _loc_normalize(prediction).split()
g_tokens = _loc_normalize(ground_truth).split()
if not p_tokens or not g_tokens:
return 0.0
p = set(p_tokens)
g = set(g_tokens)
tp = len(p & g)
precision = tp / len(p) if p else 0.0
recall = tp / len(g) if g else 0.0
return (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
def loc_multi_f1(prediction: str, ground_truth: str) -> float:
predictions = [p.strip() for p in str(prediction).split(',') if p.strip()]
ground_truths = [g.strip() for g in str(ground_truth).split(',') if g.strip()]
if not predictions or not ground_truths:
return 0.0
def _f1(a: str, b: str) -> float:
return loc_f1_score(a, b)
vals = []
for gt in ground_truths:
vals.append(max(_f1(pred, gt) for pred in predictions))
return sum(vals) / len(vals)
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 8000) -> str:
"""基于问题关键词智能选择上下文"""
if not contexts:
return ""
# 提取问题关键词(只保留有意义的词)
question_lower = question.lower()
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
question_words = set(re.findall(r'\b\w+\b', question_lower))
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
print(f"🔍 问题关键词: {question_words}")
# 给每个上下文打分
scored_contexts = []
for i, context in enumerate(contexts):
context_lower = context.lower()
score = 0
# 关键词匹配得分
keyword_matches = 0
for word in question_words:
if word in context_lower:
keyword_matches += 1
# 关键词出现次数越多,得分越高
score += context_lower.count(word) * 2
# 上下文长度得分(适中的长度更好)
context_len = len(context)
if 100 < context_len < 2000: # 理想长度范围
score += 5
elif context_len >= 2000: # 太长可能包含无关信息
score += 2
# 如果是前几个上下文,给予额外分数(通常相关性更高)
if i < 3:
score += 3
scored_contexts.append((score, context, keyword_matches))
# 按得分排序
scored_contexts.sort(key=lambda x: x[0], reverse=True)
# 选择高得分的上下文,直到达到字符限制
selected = []
total_chars = 0
selected_count = 0
print("📊 上下文相关性分析:")
for score, context, matches in scored_contexts[:5]: # 只显示前5个
print(f" - 得分: {score}, 关键词匹配: {matches}, 长度: {len(context)}")
for score, context, matches in scored_contexts:
if total_chars + len(context) <= max_chars:
selected.append(context)
total_chars += len(context)
selected_count += 1
else:
# 如果这个上下文得分很高但放不下,尝试截取
if score > 10 and total_chars < max_chars - 500:
remaining = max_chars - total_chars
# 找到包含关键词的部分
lines = context.split('\n')
relevant_lines = []
current_chars = 0
for line in lines:
line_lower = line.lower()
line_relevance = any(word in line_lower for word in question_words)
if line_relevance and current_chars < remaining - 100:
relevant_lines.append(line)
current_chars += len(line)
if relevant_lines:
truncated = '\n'.join(relevant_lines)
if len(truncated) > 100: # 确保有足够内容
selected.append(truncated + "\n[相关内容截断...]")
total_chars += len(truncated)
selected_count += 1
break # 不再尝试添加更多上下文
result = "\n\n".join(selected)
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {total_chars}字符")
return result
def get_dynamic_search_params(question: str, question_index: int, total_questions: int):
"""根据问题复杂度和进度动态调整检索参数"""
# 分析问题复杂度
word_count = len(question.split())
has_temporal = any(word in question.lower() for word in ['when', 'date', 'time', 'ago'])
has_multi_hop = any(word in question.lower() for word in ['and', 'both', 'also', 'while'])
# 根据进度调整 - 后期问题可能需要更精确的检索
progress_factor = question_index / total_questions
base_limit = 12
if has_temporal and has_multi_hop:
base_limit = 20
elif word_count > 8:
base_limit = 16
# 随着测试进行,逐渐收紧检索范围
adjusted_limit = max(8, int(base_limit * (1 - progress_factor * 0.3)))
# 动态调整最大字符数
max_chars = 8000 + 4000 * (1 - progress_factor)
return {
"limit": adjusted_limit,
"max_chars": int(max_chars)
}
class EnhancedEvaluationMonitor:
def __init__(self, reset_interval=5, performance_threshold=0.6):
self.question_count = 0
self.reset_interval = reset_interval
self.performance_threshold = performance_threshold
self.consecutive_low_scores = 0
self.performance_history = []
self.recent_f1_scores = []
def should_reset_connections(self, current_f1=None):
"""基于计数和性能双重判断"""
# 定期重置
if self.question_count % self.reset_interval == 0:
return True
# 性能驱动的重置
if current_f1 is not None and current_f1 < self.performance_threshold:
self.consecutive_low_scores += 1
if self.consecutive_low_scores >= 2: # 连续2个低分就重置
print("🚨 连续低分,触发紧急重置")
self.consecutive_low_scores = 0
return True
else:
self.consecutive_low_scores = 0
return False
def record_performance(self, question_index, metrics, context_length, retrieved_docs):
"""记录性能指标,检测衰减"""
self.performance_history.append({
'index': question_index,
'metrics': metrics,
'context_length': context_length,
'retrieved_docs': retrieved_docs,
'timestamp': time.time()
})
# 记录最近的F1分数
self.recent_f1_scores.append(metrics['f1'])
if len(self.recent_f1_scores) > 5:
self.recent_f1_scores.pop(0)
def get_recent_performance(self):
"""获取近期平均性能"""
if not self.recent_f1_scores:
return 0.5
return sum(self.recent_f1_scores) / len(self.recent_f1_scores)
def get_performance_trend(self):
"""分析性能趋势"""
if len(self.performance_history) < 2:
return "stable"
recent_metrics = [item['metrics']['f1'] for item in self.performance_history[-5:]]
earlier_metrics = [item['metrics']['f1'] for item in self.performance_history[-10:-5]]
if len(recent_metrics) < 2 or len(earlier_metrics) < 2:
return "stable"
recent_avg = sum(recent_metrics) / len(recent_metrics)
earlier_avg = sum(earlier_metrics) / len(earlier_metrics)
if recent_avg < earlier_avg * 0.8:
return "degrading"
elif recent_avg > earlier_avg * 1.1:
return "improving"
else:
return "stable"
def get_enhanced_search_params(question: str, question_index: int, total_questions: int, recent_performance: float):
"""基于问题复杂度和近期性能动态调整检索参数"""
# 基础参数
base_params = get_dynamic_search_params(question, question_index, total_questions)
# 性能自适应调整
if recent_performance < 0.5: # 近期表现差
# 增加检索范围,尝试获取更多上下文
base_params["limit"] = min(base_params["limit"] + 5, 25)
base_params["max_chars"] = min(base_params["max_chars"] + 2000, 12000)
print(f"📈 性能自适应:增加检索范围 (limit={base_params['limit']}, max_chars={base_params['max_chars']})")
elif recent_performance > 0.8: # 近期表现好
# 收紧检索,提高精度
base_params["limit"] = max(base_params["limit"] - 2, 8)
base_params["max_chars"] = max(base_params["max_chars"] - 1000, 6000)
print(f"🎯 性能自适应:提高检索精度 (limit={base_params['limit']}, max_chars={base_params['max_chars']})")
# 中间阶段特殊处理
mid_sequence_factor = abs(question_index / total_questions - 0.5)
if mid_sequence_factor < 0.2: # 在中间30%的问题
print("🎯 中间阶段:使用更精确的检索策略")
base_params["limit"] = max(base_params["limit"] - 2, 10) # 减少数量,提高质量
base_params["max_chars"] = max(base_params["max_chars"] - 1000, 7000)
return base_params
def enhanced_context_selection(contexts: List[str], question: str, question_index: int, total_questions: int, max_chars: int = 8000) -> str:
"""考虑问题序列位置的智能选择"""
if not contexts:
return ""
# 在序列中间阶段使用更严格的筛选
mid_sequence_factor = abs(question_index / total_questions - 0.5) # 距离中心的距离
if mid_sequence_factor < 0.2: # 在中间30%的问题
print("🎯 中间阶段:使用严格上下文筛选")
# 提取问题关键词
question_lower = question.lower()
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
question_words = set(re.findall(r'\b\w+\b', question_lower))
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
# 只保留高度相关的上下文
filtered_contexts = []
for context in contexts:
context_lower = context.lower()
relevance_score = sum(3 if word in context_lower else 0 for word in question_words)
# 额外加分给包含数字、日期的上下文(对事实性问题更重要)
if any(char.isdigit() for char in context):
relevance_score += 2
# 提高阈值:只有得分>=3的上下文才保留
if relevance_score >= 3:
filtered_contexts.append(context)
else:
print(f" - 过滤低分上下文: 得分={relevance_score}")
contexts = filtered_contexts
print(f"🔍 严格筛选后保留 {len(contexts)} 个上下文")
# 使用原有的智能选择逻辑
return smart_context_selection(contexts, question, max_chars)
async def run_enhanced_evaluation():
"""使用增强方法进行完整评估 - 解决中间性能衰减问题"""
try:
from dotenv import load_dotenv
except Exception:
def load_dotenv():
return None
# 修正导入路径:使用 app.core.memory.src 前缀
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.definitions import (
SELECTED_EMBEDDING_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
# 加载数据
# 获取项目根目录
current_file = os.path.abspath(__file__)
evaluation_dir = os.path.dirname(os.path.dirname(current_file)) # evaluation目录
memory_dir = os.path.dirname(evaluation_dir) # memory目录
data_path = os.path.join(memory_dir, "data", "locomo10.json")
with open(data_path, "r", encoding="utf-8") as f:
raw = json.load(f)
qa_items = []
if isinstance(raw, list):
for entry in raw:
qa_items.extend(entry.get("qa", []))
else:
qa_items.extend(raw.get("qa", []))
items = qa_items[:20] # 测试多少个问题
# 初始化增强监控器
monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm = factory.get_llm_client(SELECTED_LLM_ID)
# 初始化embedder
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)
# 初始化连接器
connector = Neo4jConnector()
# 初始化结果字典
results = {
"questions": [],
"overall_metrics": {"f1": 0.0, "b1": 0.0, "j": 0.0, "loc_f1": 0.0},
"category_metrics": {},
"retrieval_stats": {"total_questions": len(items), "avg_context_length": 0, "avg_retrieved_docs": 0},
"performance_trend": "stable",
"timestamp": datetime.now().isoformat(),
"enhanced_strategy": True
}
total_f1 = 0.0
total_bleu1 = 0.0
total_jaccard = 0.0
total_loc_f1 = 0.0
total_context_length = 0
total_retrieved_docs = 0
category_stats = {}
try:
for i, item in enumerate(items):
monitor.question_count += 1
# 获取近期性能用于重置判断
recent_performance = monitor.get_recent_performance()
# 增强的重置判断
should_reset = monitor.should_reset_connections(current_f1=recent_performance)
if should_reset and i > 0:
print(f"🔄 重置Neo4j连接 (问题 {i+1}/{len(items)}, 近期性能: {recent_performance:.3f})...")
await connector.close()
connector = Neo4jConnector() # 创建新连接
print("✅ 连接重置完成")
q = item.get("question", "")
ref = item.get("answer", "")
ref_str = str(ref) if ref is not None else ""
print(f"\n🔍 [{i+1}/{len(items)}] 问题: {q}")
print(f"✅ 真实答案: {ref_str}")
# 分类别统计
category = "Unknown"
if item.get("category") == 1:
category = "Multi-Hop"
elif item.get("category") == 2:
category = "Temporal"
elif item.get("category") == 3:
category = "Open Domain"
elif item.get("category") == 4:
category = "Single-Hop"
# 增强的检索参数
search_params = get_enhanced_search_params(q, i, len(items), recent_performance)
search_limit = search_params["limit"]
max_chars = search_params["max_chars"]
print(f"🏷️ 类别: {category}, 检索参数: limit={search_limit}, max_chars={max_chars}")
# 使用项目标准的混合检索方法
t0 = time.time()
contexts_all = []
try:
# 使用统一的搜索服务
from app.core.memory.storage_services.search import run_hybrid_search
print("🔀 使用混合搜索服务...")
search_results = await run_hybrid_search(
query_text=q,
search_type="hybrid",
group_id="locomo_sk",
limit=20,
include=["statements", "chunks", "entities", "summaries"],
alpha=0.6, # BM25权重
embedding_id=SELECTED_EMBEDDING_ID
)
# 处理搜索结果 - 新的搜索服务返回统一的结构
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要")
# 构建上下文:优先使用 chunks、statements 和 summaries
for c in chunks:
content = str(c.get("content", "")).strip()
if content:
contexts_all.append(content)
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
for sm in summaries:
summary_text = str(sm.get("summary", "")).strip()
if summary_text:
contexts_all.append(summary_text)
# 实体摘要最多加入前3个高分实体避免噪声
scored = [e for e in entities if e.get("score") is not None]
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(f"EntitySummary: {name}{(' [' + ' '.join(meta) + ']') if meta else ''}")
if summary_lines:
contexts_all.append("\n".join(summary_lines))
print(f"📊 有效上下文数量: {len(contexts_all)}")
except Exception as e:
print(f"❌ 检索失败: {e}")
contexts_all = []
t1 = time.time()
search_time = (t1 - t0) * 1000
# 增强的上下文选择
context_text = ""
if contexts_all:
# 使用增强的上下文选择
context_text = enhanced_context_selection(contexts_all, q, i, len(items), max_chars=max_chars)
# 如果智能选择后仍然过长,进行最终保护性截断
if len(context_text) > max_chars:
print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断")
context_text = context_text[:max_chars] + "\n\n[最终截断...]"
# 时间解析
anchor_date = datetime(2023, 5, 8) # 使用固定日期确保一致性
context_text = _resolve_relative_times(context_text, anchor_date)
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text
print(f"📝 最终上下文长度: {len(context_text)} 字符")
# 显示不同上下文的预览(不只是第一条)
print("🔍 上下文预览:")
for j, context in enumerate(contexts_all[:3]): # 显示前3个上下文
preview = context[:150].replace('\n', ' ')
print(f" 上下文{j+1}: {preview}...")
# 🔍 调试:检查答案是否在上下文中
if ref_str and ref_str.strip():
answer_found = any(ref_str.lower() in ctx.lower() for ctx in contexts_all)
print(f"🔍 调试:答案 '{ref_str}' 是否在检索到的上下文中? {'✅ 是' if answer_found else '❌ 否'}")
else:
print("❌ 没有检索到有效上下文")
context_text = "No relevant context found."
# LLM 回答
messages = [
{"role": "system", "content": (
"You are a precise QA assistant. Answer following these rules:\n"
"1) Extract the EXACT information mentioned in the context\n"
"2) For time questions: calculate actual dates from relative times\n"
"3) Return ONLY the answer text in simplest form\n"
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
"5) If no clear answer found, respond with 'Unknown'"
)},
{"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"},
]
t2 = time.time()
try:
# 使用异步调用
resp = await llm.chat(messages=messages)
# 兼容不同的响应格式
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown")
except Exception as e:
print(f"❌ LLM 生成失败: {e}")
pred = "Unknown"
t3 = time.time()
llm_time = (t3 - t2) * 1000
# 计算指标 - 使用导入的指标函数
f1_val = f1_score(pred, ref_str)
bleu1_val = bleu1(pred, ref_str)
jaccard_val = jaccard(pred, ref_str)
loc_f1_val = loc_f1_score(pred, ref_str)
print(f"🤖 LLM 回答: {pred}")
print(f"📈 指标 - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, Jaccard: {jaccard_val:.3f}, LoCoMo F1: {loc_f1_val:.3f}")
print(f"⏱️ 时间 - 检索: {search_time:.1f}ms, LLM: {llm_time:.1f}ms")
# 更新统计
total_f1 += f1_val
total_bleu1 += bleu1_val
total_jaccard += jaccard_val
total_loc_f1 += loc_f1_val
total_context_length += len(context_text)
total_retrieved_docs += len(contexts_all)
if category not in category_stats:
category_stats[category] = {"count": 0, "f1_sum": 0.0, "b1_sum": 0.0, "j_sum": 0.0, "loc_f1_sum": 0.0}
category_stats[category]["count"] += 1
category_stats[category]["f1_sum"] += f1_val
category_stats[category]["b1_sum"] += bleu1_val
category_stats[category]["j_sum"] += jaccard_val
category_stats[category]["loc_f1_sum"] += loc_f1_val
# 记录性能指标
metrics = {"f1": f1_val, "bleu1": bleu1_val, "jaccard": jaccard_val, "loc_f1": loc_f1_val}
monitor.record_performance(i, metrics, len(context_text), len(contexts_all))
# 保存结果
question_result = {
"question": q,
"ground_truth": ref_str,
"prediction": pred,
"category": category,
"metrics": metrics,
"retrieval": {
"retrieved_documents": len(contexts_all),
"context_length": len(context_text),
"search_limit": search_limit,
"max_chars": max_chars,
"recent_performance": recent_performance
},
"timing": {
"search_ms": search_time,
"llm_ms": llm_time
}
}
results["questions"].append(question_result)
print("="*60)
except Exception as e:
print(f"❌ 评估过程中发生错误: {e}")
# 即使出错,也返回已有的结果
import traceback
traceback.print_exc()
finally:
await connector.close()
# 计算总体指标
n = len(items)
if n > 0:
results["overall_metrics"] = {
"f1": total_f1 / n,
"b1": total_bleu1 / n,
"j": total_jaccard / n,
"loc_f1": total_loc_f1 / n
}
for category, stats in category_stats.items():
count = stats["count"]
results["category_metrics"][category] = {
"count": count,
"f1": stats["f1_sum"] / count,
"bleu1": stats["b1_sum"] / count,
"jaccard": stats["j_sum"] / count,
"loc_f1": stats["loc_f1_sum"] / count
}
results["retrieval_stats"]["avg_context_length"] = total_context_length / n
results["retrieval_stats"]["avg_retrieved_docs"] = total_retrieved_docs / n
# 分析性能趋势
results["performance_trend"] = monitor.get_performance_trend()
results["reset_interval"] = monitor.reset_interval
results["total_questions_processed"] = monitor.question_count
return results
if __name__ == "__main__":
print("🚀 运行增强版完整评估(解决中间性能衰减问题)...")
print("📋 增强特性:")
print(" - 双重重置策略:定期重置 + 性能驱动重置")
print(" - 动态检索参数:基于近期性能自适应调整")
print(" - 中间阶段严格筛选:提高上下文质量要求")
print(" - 连续性能监控:实时检测性能衰减")
result = asyncio.run(run_enhanced_evaluation())
print("\n📊 最终评估结果:")
print("总体指标:")
print(f" F1: {result['overall_metrics']['f1']:.4f}")
print(f" BLEU-1: {result['overall_metrics']['b1']:.4f}")
print(f" Jaccard: {result['overall_metrics']['j']:.4f}")
print(f" LoCoMo F1: {result['overall_metrics']['loc_f1']:.4f}")
print("\n分类别指标:")
for category, metrics in result['category_metrics'].items():
print(f" {category}: F1={metrics['f1']:.4f}, BLEU-1={metrics['bleu1']:.4f}, Jaccard={metrics['jaccard']:.4f}, LoCoMo F1={metrics['loc_f1']:.4f} (样本数: {metrics['count']})")
print("\n检索统计:")
stats = result['retrieval_stats']
print(f" 平均上下文长度: {stats['avg_context_length']:.0f} 字符")
print(f" 平均检索文档数: {stats['avg_retrieved_docs']:.1f}")
print(f"\n性能趋势: {result['performance_trend']}")
print(f"重置间隔: 每{result['reset_interval']}个问题")
print(f"处理问题总数: {result['total_questions_processed']}")
print(f"增强策略: {'启用' if result.get('enhanced_strategy', False) else '未启用'}")
# 保存结果到指定目录
# 使用代码文件所在目录的绝对路径
current_file_dir = os.path.dirname(os.path.abspath(__file__))
output_dir = os.path.join(current_file_dir, "results")
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(output_dir, "enhanced_evaluation_results.json")
with open(output_file, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"\n详细结果已保存到: {output_file}")

View File

@@ -1,626 +0,0 @@
"""
LoCoMo Utilities Module
This module provides helper functions for the LoCoMo benchmark evaluation:
- Data loading from JSON files
- Conversation extraction for ingestion
- Temporal reference resolution
- Context selection and formatting
- Retrieval wrapper functions
- Ingestion wrapper functions
"""
import os
import json
import re
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional
from app.core.memory.utils.definitions import PROJECT_ROOT
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
def load_locomo_data(
data_path: str,
sample_size: int,
conversation_index: int = 0
) -> List[Dict[str, Any]]:
"""
Load LoCoMo dataset from JSON file.
The LoCoMo dataset structure is a list of conversation objects, where each
object contains a "qa" list of question-answer pairs.
Args:
data_path: Path to locomo10.json file
sample_size: Number of QA pairs to load (limits total QA items returned)
conversation_index: Which conversation to load QA pairs from (default: 0 for first)
Returns:
List of QA item dictionaries, each containing:
- question: str
- answer: str
- category: int (1-4)
- evidence: List[str]
Raises:
FileNotFoundError: If data_path does not exist
json.JSONDecodeError: If file is not valid JSON
IndexError: If conversation_index is out of range
"""
if not os.path.exists(data_path):
raise FileNotFoundError(f"LoCoMo data file not found: {data_path}")
with open(data_path, "r", encoding="utf-8") as f:
raw = json.load(f)
# LoCoMo data structure: list of objects, each with a "qa" list
qa_items: List[Dict[str, Any]] = []
if isinstance(raw, list):
# Only load QA pairs from the specified conversation
if conversation_index < len(raw):
entry = raw[conversation_index]
if isinstance(entry, dict) and "qa" in entry:
qa_items.extend(entry.get("qa", []))
else:
raise IndexError(
f"Conversation index {conversation_index} out of range. "
f"Dataset has {len(raw)} conversations."
)
else:
# Fallback: single object with qa list
if conversation_index == 0:
qa_items.extend(raw.get("qa", []))
else:
raise IndexError(
f"Conversation index {conversation_index} out of range. "
f"Dataset has only 1 conversation."
)
# Return only the requested sample size
return qa_items[:sample_size]
def extract_conversations(data_path: str, max_dialogues: int = 1) -> List[str]:
"""
Extract conversation texts from LoCoMo data for ingestion.
This function extracts the raw conversation dialogues from the LoCoMo dataset
so they can be ingested into the memory system. Each conversation is formatted
as a multi-line string with "role: message" format.
Args:
data_path: Path to locomo10.json file
max_dialogues: Maximum number of dialogues to extract (default: 1)
Returns:
List of conversation strings formatted for ingestion.
Each string contains multiple lines in format "role: message"
Example output:
[
"User: I went to the store yesterday.\\nAI: What did you buy?\\n...",
"User: I love hiking.\\nAI: Where do you like to hike?\\n..."
]
"""
if not os.path.exists(data_path):
raise FileNotFoundError(f"LoCoMo data file not found: {data_path}")
with open(data_path, "r", encoding="utf-8") as f:
raw = json.load(f)
# Ensure we have a list of entries
entries = raw if isinstance(raw, list) else [raw]
contents: List[str] = []
for i, entry in enumerate(entries[:max_dialogues]):
if not isinstance(entry, dict):
continue
conv = entry.get("conversation", {})
if not isinstance(conv, dict):
continue
lines: List[str] = []
# Collect all session_* messages
for key, val in sorted(conv.items()):
if isinstance(val, list) and key.startswith("session_"):
for msg in val:
if not isinstance(msg, dict):
continue
role = msg.get("speaker") or "User"
text = msg.get("text") or ""
text = str(text).strip()
if not text:
continue
lines.append(f"{role}: {text}")
if lines:
contents.append("\n".join(lines))
return contents
def resolve_temporal_references(text: str, anchor_date: datetime) -> str:
"""
Resolve relative temporal references to absolute dates.
This function converts relative time expressions (like "today", "yesterday",
"3 days ago") into absolute ISO date strings based on an anchor date.
Supported patterns:
- today, yesterday, tomorrow
- X days ago, in X days
- last week, next week
Args:
text: Text containing temporal references
anchor_date: Reference date for resolution (datetime object)
Returns:
Text with temporal references replaced by ISO dates (YYYY-MM-DD format)
Example:
>>> anchor = datetime(2023, 5, 8)
>>> resolve_temporal_references("I saw him yesterday", anchor)
"I saw him 2023-05-07"
"""
# Ensure input is a string
t = str(text) if text is not None else ""
# today / yesterday / tomorrow
t = re.sub(
r"\btoday\b",
anchor_date.date().isoformat(),
t,
flags=re.IGNORECASE
)
t = re.sub(
r"\byesterday\b",
(anchor_date - timedelta(days=1)).date().isoformat(),
t,
flags=re.IGNORECASE
)
t = re.sub(
r"\btomorrow\b",
(anchor_date + timedelta(days=1)).date().isoformat(),
t,
flags=re.IGNORECASE
)
# X days ago
def _ago_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor_date - timedelta(days=n)).date().isoformat()
# in X days
def _in_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor_date + timedelta(days=n)).date().isoformat()
t = re.sub(
r"\b(\d+)\s+days?\s+ago\b",
_ago_repl,
t,
flags=re.IGNORECASE
)
t = re.sub(
r"\bin\s+(\d+)\s+days?\b",
_in_repl,
t,
flags=re.IGNORECASE
)
# last week / next week (approximate as 7 days)
t = re.sub(
r"\blast\s+week\b",
(anchor_date - timedelta(days=7)).date().isoformat(),
t,
flags=re.IGNORECASE
)
t = re.sub(
r"\bnext\s+week\b",
(anchor_date + timedelta(days=7)).date().isoformat(),
t,
flags=re.IGNORECASE
)
return t
def select_and_format_information(
retrieved_info: List[str],
question: str,
max_chars: int = 8000
) -> str:
"""
Intelligently select and format most relevant retrieved information for LLM prompt.
This function scores each piece of retrieved information based on keyword matching
with the question, then selects the highest-scoring pieces up to the character limit.
Scoring criteria:
- Keyword matches (higher weight for multiple occurrences)
- Context length (moderate length preferred)
- Position (earlier contexts get bonus points)
Args:
retrieved_info: List of retrieved information strings (chunks, statements, entities)
question: Question being answered
max_chars: Maximum total characters to include in final prompt
Returns:
Formatted string combining the most relevant information for LLM prompt.
Contexts are separated by double newlines.
Example:
>>> contexts = ["Alice went to Paris", "Bob likes pizza", "Alice visited the Eiffel Tower"]
>>> question = "Where did Alice go?"
>>> select_and_format_information(contexts, question, max_chars=100)
"Alice went to Paris\\n\\nAlice visited the Eiffel Tower"
"""
if not retrieved_info:
return ""
# Extract question keywords (filter out stop words and short words)
question_lower = question.lower()
stop_words = {
'what', 'when', 'where', 'who', 'why', 'how',
'did', 'do', 'does', 'is', 'are', 'was', 'were',
'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at'
}
question_words = set(re.findall(r'\b\w+\b', question_lower))
question_words = {
word for word in question_words
if word not in stop_words and len(word) > 2
}
# Score each context
scored_contexts = []
for i, context in enumerate(retrieved_info):
context_lower = context.lower()
score = 0
# Keyword matching score
keyword_matches = 0
for word in question_words:
if word in context_lower:
keyword_matches += 1
# Multiple occurrences increase score
score += context_lower.count(word) * 2
# Length score (prefer moderate length)
context_len = len(context)
if 100 < context_len < 2000:
score += 5
elif context_len >= 2000:
score += 2
# Position bonus (earlier contexts often more relevant)
if i < 3:
score += 3
scored_contexts.append((score, context, keyword_matches))
# Sort by score (descending)
scored_contexts.sort(key=lambda x: x[0], reverse=True)
# Select contexts up to character limit
selected = []
total_chars = 0
for score, context, matches in scored_contexts:
if total_chars + len(context) <= max_chars:
selected.append(context)
total_chars += len(context)
else:
# Try to include high-scoring context by truncating
if score > 10 and total_chars < max_chars - 500:
remaining = max_chars - total_chars
# Find lines with keywords
lines = context.split('\n')
relevant_lines = []
current_chars = 0
for line in lines:
line_lower = line.lower()
line_relevance = any(word in line_lower for word in question_words)
if line_relevance and current_chars < remaining - 100:
relevant_lines.append(line)
current_chars += len(line)
if relevant_lines and len('\n'.join(relevant_lines)) > 100:
truncated = '\n'.join(relevant_lines)
selected.append(truncated + "\n[Content truncated...]")
total_chars += len(truncated)
break
return "\n\n".join(selected)
async def retrieve_relevant_information(
question: str,
group_id: str,
search_type: str,
search_limit: int,
connector: Any,
embedder: Any
) -> List[str]:
"""
Retrieve relevant information from memory graph for a question.
This function searches the Neo4j memory graph (populated during ingestion) and
returns relevant chunks, statements, and entity information that might help
answer the question.
The function supports three search types:
- "keyword": Full-text search using Cypher queries
- "embedding": Vector similarity search using embeddings
- "hybrid": Combination of keyword and embedding search with reranking
Args:
question: Question to search for
group_id: Database group ID (identifies which conversation memory to search)
search_type: "keyword", "embedding", or "hybrid"
search_limit: Max memory pieces to retrieve
connector: Neo4j connector instance
embedder: Embedder client instance
Returns:
List of text strings (chunks, statements, entity summaries) from memory graph.
Each string represents a piece of retrieved information.
Raises:
Exception: If search fails (caught and returns empty list)
"""
from app.repositories.neo4j.graph_search import (
search_graph,
search_graph_by_embedding
)
from app.core.memory.storage_services.search import run_hybrid_search
contexts_all: List[str] = []
try:
if search_type == "embedding":
# Embedding-based search
search_results = await search_graph_by_embedding(
connector=connector,
embedder_client=embedder,
query_text=question,
group_id=group_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"],
)
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
# Build context from chunks
for c in chunks:
content = str(c.get("content", "")).strip()
if content:
contexts_all.append(content)
# Add statements
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
# Add summaries
for sm in summaries:
summary_text = str(sm.get("summary", "")).strip()
if summary_text:
contexts_all.append(summary_text)
# Add top entities (limit to 3 to avoid noise)
if entities:
scored = [e for e in entities if e.get("score") is not None]
top_entities = (
sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3]
if scored else entities[:3]
)
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(
f"EntitySummary: {name}"
f"{(' [' + '; '.join(meta) + ']') if meta else ''}"
)
if summary_lines:
contexts_all.append("\n".join(summary_lines))
elif search_type == "keyword":
# Keyword-based search
search_results = await search_graph(
connector=connector,
q=question,
group_id=group_id,
limit=search_limit
)
dialogs = search_results.get("dialogues", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
# Build context from dialogues
for d in dialogs:
content = str(d.get("content", "")).strip()
if content:
contexts_all.append(content)
# Add statements
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
# Add entity names
if entities:
entity_names = [
str(e.get("name", "")).strip()
for e in entities[:5]
if e.get("name")
]
if entity_names:
contexts_all.append(f"EntitySummary: {', '.join(entity_names)}")
else: # hybrid
# Hybrid search with fallback to embedding
try:
search_results = await run_hybrid_search(
query_text=question,
search_type=search_type,
group_id=group_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"],
output_path=None,
)
# Handle flat structure (new API format)
if search_results and isinstance(search_results, dict):
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
# Check if we got results
if not (chunks or statements or entities or summaries):
# Try nested structure (backward compatibility)
reranked = search_results.get("reranked_results", {})
if reranked and isinstance(reranked, dict):
chunks = reranked.get("chunks", [])
statements = reranked.get("statements", [])
entities = reranked.get("entities", [])
summaries = reranked.get("summaries", [])
else:
raise ValueError("Hybrid search returned empty results")
else:
raise ValueError("Hybrid search returned empty results")
except Exception as e:
# Fallback to embedding search
search_results = await search_graph_by_embedding(
connector=connector,
embedder_client=embedder,
query_text=question,
group_id=group_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"],
)
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
# Build context (same for both hybrid and fallback)
for c in chunks:
content = str(c.get("content", "")).strip()
if content:
contexts_all.append(content)
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
for sm in summaries:
summary_text = str(sm.get("summary", "")).strip()
if summary_text:
contexts_all.append(summary_text)
# Add top entities
if entities:
scored = [e for e in entities if e.get("score") is not None]
top_entities = (
sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3]
if scored else entities[:3]
)
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(
f"EntitySummary: {name}"
f"{(' [' + '; '.join(meta) + ']') if meta else ''}"
)
if summary_lines:
contexts_all.append("\n".join(summary_lines))
except Exception as e:
# Return empty list on error
contexts_all = []
return contexts_all
async def ingest_conversations_if_needed(
conversations: List[str],
group_id: str,
reset: bool = False
) -> bool:
"""
Wrapper for conversation ingestion using external extraction pipeline.
This function populates the Neo4j database with processed conversation data
(chunks, statements, entities) so that the retrieval system has memory to search.
The ingestion process:
1. Parses conversation text into dialogue messages
2. Chunks the dialogues into semantic units
3. Extracts statements and entities using LLM
4. Generates embeddings for all content
5. Stores everything in Neo4j graph database
Args:
conversations: List of raw conversation texts from LoCoMo dataset
Example: ["User: I went to Paris. AI: When was that?", ...]
group_id: Target group ID for database storage
reset: Whether to clear existing data first (not implemented in wrapper)
Returns:
True if successful, False otherwise
Note:
The external function uses "contexts" to mean "conversation texts".
This runs the full extraction pipeline: chunking → entity extraction →
statement extraction → embedding → Neo4j storage.
"""
try:
success = await ingest_contexts_via_full_pipeline(
contexts=conversations,
group_id=group_id,
save_chunk_output=True
)
return success
except Exception as e:
print(f"[Ingestion] Failed to ingest conversations: {e}")
return False

View File

@@ -1,878 +0,0 @@
import argparse
import asyncio
import json
import os
import statistics
import time
from datetime import datetime, timedelta
from typing import Any, Dict, List
try:
from dotenv import load_dotenv
except Exception:
def load_dotenv():
return None
import re
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
bleu1,
jaccard,
latency_stats,
)
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
from app.core.memory.evaluation.extraction_utils import (
ingest_contexts_via_full_pipeline,
)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.storage_services.search import run_hybrid_search
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
# 参考 evaluation/locomo/evaluation.py 的 F1 计算逻辑(移除外部依赖,内联实现)
def _loc_normalize(text: str) -> str:
import re
# 确保输入是字符串
text = str(text) if text is not None else ""
text = text.lower()
text = re.sub(r"[\,]", " ", text) # 去掉逗号
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
text = re.sub(r"[^\w\s]", " ", text)
text = " ".join(text.split())
return text
# 追加相对时间归一化为绝对日期有限支持today/yesterday/tomorrow/X days ago/in X days/last week/next week
def _resolve_relative_times(text: str, anchor: datetime) -> str:
import re
# 确保输入是字符串
t = str(text) if text is not None else ""
# today / yesterday / tomorrow
t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
# X days ago / in X days
def _ago_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor - timedelta(days=n)).date().isoformat()
def _in_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor + timedelta(days=n)).date().isoformat()
t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE)
t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE)
# last week / next week以7天近似
t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
return t
def loc_f1_score(prediction: str, ground_truth: str) -> float:
# 单答案 F1按词集合计算近似原始实现去除词干依赖
# 确保输入是字符串
pred_str = str(prediction) if prediction is not None else ""
truth_str = str(ground_truth) if ground_truth is not None else ""
p_tokens = _loc_normalize(pred_str).split()
g_tokens = _loc_normalize(truth_str).split()
if not p_tokens or not g_tokens:
return 0.0
p = set(p_tokens)
g = set(g_tokens)
tp = len(p & g)
precision = tp / len(p) if p else 0.0
recall = tp / len(g) if g else 0.0
return (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
def loc_multi_f1(prediction: str, ground_truth: str) -> float:
# 多答案 F1prediction 与 ground_truth 以逗号分隔,逐一匹配取最大,再对多个 GT 取平均
# 确保输入是字符串
pred_str = str(prediction) if prediction is not None else ""
truth_str = str(ground_truth) if ground_truth is not None else ""
predictions = [p.strip() for p in str(pred_str).split(',') if p.strip()]
ground_truths = [g.strip() for g in str(truth_str).split(',') if g.strip()]
if not predictions or not ground_truths:
return 0.0
def _f1(a: str, b: str) -> float:
return loc_f1_score(a, b)
vals = []
for gt in ground_truths:
vals.append(max(_f1(pred, gt) for pred in predictions))
return sum(vals) / len(vals)
# 标准化 LoCoMo 类别名:支持数字 category 与字符串 cat/type
CATEGORY_MAP_NUM_TO_NAME = {
4: "Single-Hop",
1: "Multi-Hop",
3: "Open Domain",
2: "Temporal",
}
_TYPE_ALIASES = {
"single-hop": "Single-Hop",
"singlehop": "Single-Hop",
"single hop": "Single-Hop",
"multi-hop": "Multi-Hop",
"multihop": "Multi-Hop",
"multi hop": "Multi-Hop",
"open domain": "Open Domain",
"opendomain": "Open Domain",
"temporal": "Temporal",
}
def get_category_label(item: Dict[str, Any]) -> str:
# 1) 直接用字符串 cat
cat = item.get("cat")
if isinstance(cat, str) and cat.strip():
name = cat.strip()
lower = name.lower()
return _TYPE_ALIASES.get(lower, name)
# 2) 数字 category 转名称
cat_num = item.get("category")
if isinstance(cat_num, int):
return CATEGORY_MAP_NUM_TO_NAME.get(cat_num, "unknown")
# 3) 备用 type 字段
t = item.get("type")
if isinstance(t, str) and t.strip():
lower = t.strip().lower()
return _TYPE_ALIASES.get(lower, t.strip())
return "unknown"
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 12000) -> str:
"""基于问题关键词智能选择上下文"""
if not contexts:
return ""
# 提取问题关键词(只保留有意义的词)
question_lower = question.lower()
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
question_words = set(re.findall(r'\b\w+\b', question_lower))
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
print(f"🔍 问题关键词: {question_words}")
# 给每个上下文打分
scored_contexts = []
for i, context in enumerate(contexts):
context_lower = context.lower()
score = 0
# 关键词匹配得分
keyword_matches = 0
for word in question_words:
if word in context_lower:
keyword_matches += 1
# 关键词出现次数越多,得分越高
score += context_lower.count(word) * 2
# 上下文长度得分(适中的长度更好)
context_len = len(context)
if 100 < context_len < 2000: # 理想长度范围
score += 5
elif context_len >= 2000: # 太长可能包含无关信息
score += 2
# 如果是前几个上下文,给予额外分数(通常相关性更高)
if i < 3:
score += 3
scored_contexts.append((score, context, keyword_matches))
# 按得分排序
scored_contexts.sort(key=lambda x: x[0], reverse=True)
# 选择高得分的上下文,直到达到字符限制
selected = []
total_chars = 0
selected_count = 0
print("📊 上下文相关性分析:")
for score, context, matches in scored_contexts[:5]: # 只显示前5个
print(f" - 得分: {score}, 关键词匹配: {matches}, 长度: {len(context)}")
for score, context, matches in scored_contexts:
if total_chars + len(context) <= max_chars:
selected.append(context)
total_chars += len(context)
selected_count += 1
else:
# 如果这个上下文得分很高但放不下,尝试截取
if score > 10 and total_chars < max_chars - 500:
remaining = max_chars - total_chars
# 找到包含关键词的部分
lines = context.split('\n')
relevant_lines = []
current_chars = 0
for line in lines:
line_lower = line.lower()
line_relevance = any(word in line_lower for word in question_words)
if line_relevance and current_chars < remaining - 100:
relevant_lines.append(line)
current_chars += len(line)
if relevant_lines:
truncated = '\n'.join(relevant_lines)
if len(truncated) > 100: # 确保有足够内容
selected.append(truncated + "\n[相关内容截断...]")
total_chars += len(truncated)
selected_count += 1
break # 不再尝试添加更多上下文
result = "\n\n".join(selected)
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {total_chars}字符")
return result
def get_search_params_by_category(category: str):
"""根据问题类别调整检索参数"""
params_map = {
"Multi-Hop": {"limit": 20, "max_chars": 15000},
"Temporal": {"limit": 16, "max_chars": 10000},
"Open Domain": {"limit": 24, "max_chars": 18000},
"Single-Hop": {"limit": 12, "max_chars": 8000},
}
return params_map.get(category, {"limit": 16, "max_chars": 12000})
async def run_locomo_eval(
sample_size: int = 1,
group_id: str | None = None,
search_limit: int = 8,
context_char_budget: int = 4000, # 保持默认值不变
llm_temperature: float = 0.0,
llm_max_tokens: int = 32,
search_type: str = "hybrid", # 保持默认值不变
output_path: str | None = None,
skip_ingest_if_exists: bool = True,
llm_timeout: float = 10.0,
llm_max_retries: int = 1
) -> Dict[str, Any]:
# 函数内部使用三路检索逻辑,但保持参数签名不变
group_id = group_id or SELECTED_GROUP_ID
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
if not os.path.exists(data_path):
data_path = os.path.join(os.getcwd(), "data", "locomo10.json")
with open(data_path, "r", encoding="utf-8") as f:
raw = json.load(f)
# LoCoMo 数据结构:顶层为若干对象,每个对象下有 qa 列表
qa_items: List[Dict[str, Any]] = []
if isinstance(raw, list):
for entry in raw:
qa_items.extend(entry.get("qa", []))
else:
qa_items.extend(raw.get("qa", []))
items: List[Dict[str, Any]] = qa_items[:sample_size]
# === 保持原来的数据摄入逻辑 ===
entries = raw if isinstance(raw, list) else [raw]
# 只摄入前1条对话保持原样
max_dialogues_to_ingest = 1
contents: List[str] = []
print(f"📊 找到 {len(entries)} 个对话对象,只摄入前 {max_dialogues_to_ingest}")
for i, entry in enumerate(entries[:max_dialogues_to_ingest]):
if not isinstance(entry, dict):
continue
conv = entry.get("conversation", {})
sample_id = entry.get("sample_id", f"unknown_{i}")
print(f"🔍 处理对话 {i+1}: {sample_id}")
lines: List[str] = []
if isinstance(conv, dict):
# 收集所有 session_* 的消息
session_count = 0
for key, val in conv.items():
if isinstance(val, list) and key.startswith("session_"):
session_count += 1
for msg in val:
role = msg.get("speaker") or "用户"
text = msg.get("text") or ""
text = str(text).strip()
if not text:
continue
lines.append(f"{role}: {text}")
print(f" - 包含 {session_count} 个session, {len(lines)} 条消息")
if not lines:
print(f"⚠️ 警告: 对话 {sample_id} 没有对话内容,跳过摄入")
continue
contents.append("\n".join(lines))
print(f"📥 总共摄入 {len(contents)} 个对话的conversation内容")
# 选择要评测的QA对从所有对话中选取
indexed_items: List[tuple[int, Dict[str, Any]]] = []
if isinstance(raw, list):
for e_idx, entry in enumerate(raw):
for qa in entry.get("qa", []):
indexed_items.append((e_idx, qa))
else:
for qa in raw.get("qa", []):
indexed_items.append((0, qa))
# 这里使用sample_size来限制评测的QA数量
selected = indexed_items[:sample_size]
items: List[Dict[str, Any]] = [qa for _, qa in selected]
print(f"🎯 将评测 {len(items)} 个QA对数据库中只包含 {len(contents)} 个对话")
# === 修改结束 ===
connector = Neo4jConnector()
# 关键修复:强制重新摄入纯净的对话数据
print("🔄 强制重新摄入纯净的对话数据...")
await ingest_contexts_via_full_pipeline(contents, group_id, save_chunk_output=True)
# 使用异步LLM客户端
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
# 初始化embedder用于直接调用
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)
# connector initialized above
latencies_llm: List[float] = []
latencies_search: List[float] = []
# 上下文诊断收集
per_query_context_counts: List[int] = []
per_query_context_avg_tokens: List[float] = []
per_query_context_chars: List[int] = []
per_query_context_tokens_total: List[int] = []
# 详细样本调试信息
samples: List[Dict[str, Any]] = []
# 通用指标
f1s: List[float] = []
b1s: List[float] = []
jss: List[float] = []
# 参考 LoCoMo 评测的类别专用 F1multi-hop 使用多答案 F1
loc_f1s: List[float] = []
# Per-category aggregation
cat_counts: Dict[str, int] = {}
cat_f1s: Dict[str, List[float]] = {}
cat_b1s: Dict[str, List[float]] = {}
cat_jss: Dict[str, List[float]] = {}
cat_loc_f1s: Dict[str, List[float]] = {}
try:
for item in items:
q = item.get("question", "")
ref = item.get("answer", "")
# 确保答案是字符串
ref_str = str(ref) if ref is not None else ""
cat = get_category_label(item)
print(f"\n=== 处理问题: {q} ===")
# 根据类别调整检索参数
search_params = get_search_params_by_category(cat)
adjusted_limit = search_params["limit"]
max_chars = search_params["max_chars"]
print(f"🏷️ 类别: {cat}, 检索参数: limit={adjusted_limit}, max_chars={max_chars}")
# 改进的检索逻辑使用三路检索statements, dialogues, entities
t0 = time.time()
contexts_all: List[str] = []
search_results = None # 保存完整的检索结果
try:
if search_type == "embedding":
# 直接调用嵌入检索,包含三路数据
search_results = await search_graph_by_embedding(
connector=connector,
embedder_client=embedder,
query_text=q,
group_id=group_id,
limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"], # 修复:使用正确的类型
)
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
print(f"✅ 嵌入检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要")
# 构建上下文:优先使用 chunks、statements 和 summaries
for c in chunks:
content = str(c.get("content", "")).strip()
if content:
contexts_all.append(content)
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
for sm in summaries:
summary_text = str(sm.get("summary", "")).strip()
if summary_text:
contexts_all.append(summary_text)
# 实体摘要最多加入前3个高分实体避免噪声
scored = [e for e in entities if e.get("score") is not None]
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
if summary_lines:
contexts_all.append("\n".join(summary_lines))
elif search_type == "keyword":
# 直接调用关键词检索
search_results = await search_graph(
connector=connector,
q=q,
group_id=group_id,
limit=adjusted_limit
)
dialogs = search_results.get("dialogues", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
print(f"🔤 关键词检索找到 {len(dialogs)} 条对话, {len(statements)} 条陈述, {len(entities)} 个实体")
# 构建上下文
for d in dialogs:
content = str(d.get("content", "")).strip()
if content:
contexts_all.append(content)
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
# 实体处理(关键词检索的实体可能没有分数)
if entities:
entity_names = [str(e.get("name", "")).strip() for e in entities[:5] if e.get("name")]
if entity_names:
contexts_all.append(f"EntitySummary: {', '.join(entity_names)}")
else: # hybrid
# 🎯 关键修复:混合检索使用更严格的回退机制
print("🔀 使用混合检索(带回退机制)...")
try:
search_results = await run_hybrid_search(
query_text=q,
search_type=search_type,
group_id=group_id,
limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"],
output_path=None,
)
# 🎯 关键修复:正确处理混合检索的扁平结构
# 新的API返回扁平结构直接从顶层获取结果
if search_results and isinstance(search_results, dict):
# 新API返回扁平结构直接从顶层获取
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
# 检查是否有有效结果
if chunks or statements or entities or summaries:
print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 陈述, {len(entities)} 实体, {len(summaries)} 摘要")
else:
# 如果顶层没有结果,尝试旧的嵌套结构(向后兼容)
reranked = search_results.get("reranked_results", {})
if reranked and isinstance(reranked, dict):
chunks = reranked.get("chunks", [])
statements = reranked.get("statements", [])
entities = reranked.get("entities", [])
summaries = reranked.get("summaries", [])
print(f"✅ 混合检索成功使用旧格式reranked结果: {len(chunks)} chunks, {len(statements)} 陈述")
else:
raise ValueError("混合检索返回空结果")
else:
raise ValueError("混合检索返回空结果")
except Exception as e:
print(f"❌ 混合检索失败: {e},回退到嵌入检索")
search_results = await search_graph_by_embedding(
connector=connector,
embedder_client=embedder,
query_text=q,
group_id=group_id,
limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"],
)
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
print(f"✅ 回退嵌入检索成功: {len(chunks)} chunks, {len(statements)} 陈述")
# 🎯 统一处理:构建上下文(所有检索类型共用)
for c in chunks:
content = str(c.get("content", "")).strip()
if content:
contexts_all.append(content)
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
for sm in summaries:
summary_text = str(sm.get("summary", "")).strip()
if summary_text:
contexts_all.append(summary_text)
# 实体摘要最多加入前3个高分实体
if entities:
scored = [e for e in entities if e.get("score") is not None]
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
if summary_lines:
contexts_all.append("\n".join(summary_lines))
# 关键修复:过滤掉包含当前问题答案的上下文
filtered_contexts = []
for context in contexts_all:
content = str(context)
# 排除包含当前问题标准答案的上下文
if ref_str and ref_str.strip() and ref_str.strip() in content:
print("🚫 过滤掉包含标准答案的上下文")
continue
filtered_contexts.append(context)
print(f"📊 过滤后保留 {len(filtered_contexts)} 个上下文 (原 {len(contexts_all)} 个)")
contexts_all = filtered_contexts
# 输出完整的检索结果信息
print("🔍 检索结果详情:")
if search_results:
output_data = {
"statements": [
{
"statement": s.get("statement", "")[:200] + "..." if len(s.get("statement", "")) > 200 else s.get("statement", ""),
"score": s.get("score", 0.0)
}
for s in (statements[:2] if 'statements' in locals() else [])
],
"dialogues": [
{
"uuid": d.get("uuid", ""),
"group_id": d.get("group_id", ""),
"content": d.get("content", "")[:200] + "..." if len(d.get("content", "")) > 200 else d.get("content", ""),
"score": d.get("score", 0.0)
}
for d in (dialogs[:2] if 'dialogs' in locals() else [])
],
"entities": [
{
"name": e.get("name", ""),
"entity_type": e.get("entity_type", ""),
"score": e.get("score", 0.0)
}
for e in (entities[:2] if 'entities' in locals() else [])
]
}
print(json.dumps(output_data, ensure_ascii=False, indent=2))
else:
print(" 无检索结果")
except Exception as e:
print(f"{search_type}检索失败: {e}")
contexts_all = []
search_results = None
t1 = time.time()
latencies_search.append((t1 - t0) * 1000)
# 使用智能上下文选择
context_text = ""
if contexts_all:
context_text = smart_context_selection(contexts_all, q, max_chars=max_chars)
# 如果智能选择后仍然过长,进行最终保护性截断
if len(context_text) > max_chars:
print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断")
context_text = context_text[:max_chars] + "\n\n[最终截断...]"
# 时间解析
anchor_date = datetime(2023, 5, 8) # 使用固定日期确保一致性
context_text = _resolve_relative_times(context_text, anchor_date)
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text
print(f"📝 最终上下文长度: {len(context_text)} 字符")
# 显示不同上下文的预览
print("🔍 上下文预览:")
for j, context in enumerate(contexts_all[:3]): # 显示前3个上下文
preview = context[:150].replace('\n', ' ')
print(f" 上下文{j+1}: {preview}...")
else:
print("❌ 没有检索到有效上下文")
context_text = "No relevant context found."
# 记录上下文诊断信息
per_query_context_counts.append(len(contexts_all))
per_query_context_avg_tokens.append(avg_context_tokens([context_text]))
per_query_context_chars.append(len(context_text))
per_query_context_tokens_total.append(len(_loc_normalize(context_text).split()))
# LLM 提示词
messages = [
{"role": "system", "content": (
"You are a precise QA assistant. Answer following these rules:\n"
"1) Extract the EXACT information mentioned in the context\n"
"2) For time questions: calculate actual dates from relative times\n"
"3) Return ONLY the answer text in simplest form\n"
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
"5) If no clear answer found, respond with 'Unknown'"
)},
{"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"},
]
t2 = time.time()
# 使用异步调用
resp = await llm_client.chat(messages=messages)
t3 = time.time()
latencies_llm.append((t3 - t2) * 1000)
# 兼容不同的响应格式
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown")
# 计算指标(确保使用字符串)
f1_val = common_f1(str(pred), ref_str)
b1_val = bleu1(str(pred), ref_str)
j_val = jaccard(str(pred), ref_str)
f1s.append(f1_val)
b1s.append(b1_val)
jss.append(j_val)
# Accumulate by category
cat_counts[cat] = cat_counts.get(cat, 0) + 1
cat_f1s.setdefault(cat, []).append(f1_val)
cat_b1s.setdefault(cat, []).append(b1_val)
cat_jss.setdefault(cat, []).append(j_val)
# LoCoMo 专用 F1multi-hop(1) 使用多答案 F1其它(2/3/4)使用单答案 F1
if item.get("category") in [2, 3, 4]:
loc_val = loc_f1_score(str(pred), ref_str)
elif item.get("category") in [1]:
loc_val = loc_multi_f1(str(pred), ref_str)
else:
loc_val = loc_f1_score(str(pred), ref_str)
loc_f1s.append(loc_val)
cat_loc_f1s.setdefault(cat, []).append(loc_val)
# 保存完整的检索结果信息
samples.append({
"question": q,
"answer": ref_str,
"category": cat,
"prediction": pred,
"metrics": {
"f1": f1_val,
"b1": b1_val,
"j": j_val,
"loc_f1": loc_val
},
"retrieval": {
"retrieved_documents": len(contexts_all),
"context_length": len(context_text),
"search_limit": adjusted_limit,
"max_chars": max_chars
},
"timing": {
"search_ms": (t1 - t0) * 1000,
"llm_ms": (t3 - t2) * 1000
}
})
print(f"🤖 LLM 回答: {pred}")
print(f"✅ 正确答案: {ref_str}")
print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: {j_val:.3f}, LoCoMo F1: {loc_val:.3f}")
# Compute per-category averages and dispersion (std, iqr)
def _percentile(sorted_vals: List[float], p: float) -> float:
if not sorted_vals:
return 0.0
if len(sorted_vals) == 1:
return sorted_vals[0]
k = (len(sorted_vals) - 1) * p
f = int(k)
c = f + 1 if f + 1 < len(sorted_vals) else f
if f == c:
return sorted_vals[f]
return sorted_vals[f] + (sorted_vals[c] - sorted_vals[f]) * (k - f)
by_category: Dict[str, Dict[str, float | int]] = {}
for c in cat_counts:
f_list = cat_f1s.get(c, [])
b_list = cat_b1s.get(c, [])
j_list = cat_jss.get(c, [])
lf_list = cat_loc_f1s.get(c, [])
j_sorted = sorted(j_list)
j_std = statistics.stdev(j_list) if len(j_list) > 1 else 0.0
j_q75 = _percentile(j_sorted, 0.75)
j_q25 = _percentile(j_sorted, 0.25)
by_category[c] = {
"count": cat_counts[c],
"f1": (sum(f_list) / max(len(f_list), 1)) if f_list else 0.0,
"b1": (sum(b_list) / max(len(b_list), 1)) if b_list else 0.0,
"j": (sum(j_list) / max(len(j_list), 1)) if j_list else 0.0,
"j_std": j_std,
"j_iqr": (j_q75 - j_q25) if j_list else 0.0,
# 参考 LoCoMo 评测的类别专用 F1
"loc_f1": (sum(lf_list) / max(len(lf_list), 1)) if lf_list else 0.0,
}
# 累加命中cum accuracy by category与 evaluation_stats.py 输出形式相仿
cum_accuracy_by_category = {c: sum(cat_loc_f1s.get(c, [])) for c in cat_counts}
result = {
"dataset": "locomo",
"items": len(items),
"metrics": {
"f1": sum(f1s) / max(len(f1s), 1),
"b1": sum(b1s) / max(len(b1s), 1),
"j": sum(jss) / max(len(jss), 1),
# LoCoMo 类别专用 F1 的总体
"loc_f1": sum(loc_f1s) / max(len(loc_f1s), 1),
},
"by_category": by_category,
"category_counts": cat_counts,
"cum_accuracy_by_category": cum_accuracy_by_category,
"context": {
"avg_tokens": (sum(per_query_context_avg_tokens) / max(len(per_query_context_avg_tokens), 1)) if per_query_context_avg_tokens else 0.0,
"avg_chars": (sum(per_query_context_chars) / max(len(per_query_context_chars), 1)) if per_query_context_chars else 0.0,
"count_avg": (sum(per_query_context_counts) / max(len(per_query_context_counts), 1)) if per_query_context_counts else 0.0,
"avg_memory_tokens": (sum(per_query_context_tokens_total) / max(len(per_query_context_tokens_total), 1)) if per_query_context_tokens_total else 0.0,
},
"latency": {
"search": latency_stats(latencies_search),
"llm": latency_stats(latencies_llm),
},
"samples": samples,
"params": {
"group_id": group_id,
"search_limit": search_limit,
"context_char_budget": context_char_budget,
"search_type": search_type,
"llm_id": SELECTED_LLM_ID,
"retrieval_embedding_id": SELECTED_EMBEDDING_ID,
"skip_ingest_if_exists": skip_ingest_if_exists,
"llm_timeout": llm_timeout,
"llm_max_retries": llm_max_retries,
"llm_temperature": llm_temperature,
"llm_max_tokens": llm_max_tokens
},
"timestamp": datetime.now().isoformat()
}
if output_path:
try:
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"✅ 结果已保存到: {output_path}")
except Exception as e:
print(f"❌ 保存结果失败: {e}")
return result
finally:
await connector.close()
def main():
parser = argparse.ArgumentParser(description="Run LoCoMo evaluation with Qwen search")
parser.add_argument("--sample_size", type=int, default=1, help="Number of samples to evaluate")
parser.add_argument("--group_id", type=str, default=None, help="Group ID for retrieval")
parser.add_argument("--search_limit", type=int, default=8, help="Search limit per query")
parser.add_argument("--context_char_budget", type=int, default=12000, help="Max characters for context")
parser.add_argument("--llm_temperature", type=float, default=0.0, help="LLM temperature")
parser.add_argument("--llm_max_tokens", type=int, default=32, help="LLM max tokens")
parser.add_argument("--search_type", type=str, default="embedding", choices=["keyword", "embedding", "hybrid"], help="Search type")
parser.add_argument("--output_path", type=str, default=None, help="Output path for results")
parser.add_argument("--skip_ingest_if_exists", action="store_true", help="Skip ingest if group exists")
parser.add_argument("--llm_timeout", type=float, default=10.0, help="LLM timeout in seconds")
parser.add_argument("--llm_max_retries", type=int, default=1, help="LLM max retries")
args = parser.parse_args()
load_dotenv()
result = asyncio.run(run_locomo_eval(
sample_size=args.sample_size,
group_id=args.group_id,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature,
llm_max_tokens=args.llm_max_tokens,
search_type=args.search_type,
output_path=args.output_path,
skip_ingest_if_exists=args.skip_ingest_if_exists,
llm_timeout=args.llm_timeout,
llm_max_retries=args.llm_max_retries
))
print("\n" + "="*50)
print("📊 最终评测结果:")
print(f" 样本数量: {result['items']}")
print(f" F1: {result['metrics']['f1']:.3f}")
print(f" BLEU-1: {result['metrics']['b1']:.3f}")
print(f" Jaccard: {result['metrics']['j']:.3f}")
print(f" LoCoMo F1: {result['metrics']['loc_f1']:.3f}")
print(f" 平均上下文长度: {result['context']['avg_chars']:.0f} 字符")
print(f" 平均检索延迟: {result['latency']['search']['mean']:.1f}ms")
print(f" 平均LLM延迟: {result['latency']['llm']['mean']:.1f}ms")
if result['by_category']:
print("\n📈 按类别细分:")
for cat, metrics in result['by_category'].items():
print(f" {cat}:")
print(f" 样本数: {metrics['count']}")
print(f" F1: {metrics['f1']:.3f}")
print(f" LoCoMo F1: {metrics['loc_f1']:.3f}")
print(f" Jaccard: {metrics['j']:.3f}{metrics['j_std']:.3f}, IQR={metrics['j_iqr']:.3f})")
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,324 +0,0 @@
import argparse
import asyncio
import json
import os
import time
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
try:
from dotenv import load_dotenv
except Exception:
def load_dotenv():
return None
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
exact_match,
latency_stats,
)
from app.core.memory.evaluation.extraction_utils import (
ingest_contexts_via_full_pipeline,
)
from app.core.memory.storage_services.search import run_hybrid_search
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
"""基于问题关键词对上下文进行评分选择,并在预算内拼接文本。"""
if not contexts:
return ""
import re
# 提取问题关键词(移除停用词)
question_lower = (question or "").lower()
stop_words = {
'what','when','where','who','why','how','did','do','does','is','are','was','were',
'the','a','an','and','or','but'
}
question_words = set(re.findall(r"\b\w+\b", question_lower))
question_words = {w for w in question_words if w not in stop_words and len(w) > 2}
# 评分
scored = []
for i, ctx in enumerate(contexts):
ctx_lower = (ctx or "").lower()
score = 0
matches = 0
for w in question_words:
if w in ctx_lower:
matches += 1
score += ctx_lower.count(w) * 2
length = len(ctx)
if 100 < length < 2000:
score += 5
elif length >= 2000:
score += 2
if i < 3:
score += 3
scored.append((score, ctx, matches))
scored.sort(key=lambda x: x[0], reverse=True)
# 选择直到达到字符限制,必要时截断包含关键词的段落
selected: List[str] = []
total = 0
for score, ctx, _ in scored:
if total + len(ctx) <= max_chars:
selected.append(ctx)
total += len(ctx)
else:
if score > 10 and total < max_chars - 200:
remaining = max_chars - total
lines = ctx.split('\n')
rel_lines: List[str] = []
cur = 0
for line in lines:
l = line.lower()
if any(w in l for w in question_words) and cur < remaining - 50:
rel_lines.append(line)
cur += len(line)
if rel_lines:
truncated = '\n'.join(rel_lines)
if len(truncated) > 50:
selected.append(truncated + "\n[相关内容截断...]")
total += len(truncated)
break
return "\n\n".join(selected)
def build_context_from_dialog(dialog_obj: Dict[str, Any]) -> str:
"""Compose a text context from `dialog` list in msc_self_instruct item."""
parts: List[str] = []
for turn in dialog_obj.get("dialog", []):
speaker = turn.get("speaker", "")
text = turn.get("text", "")
if text:
parts.append(f"{speaker}: {text}")
return "\n".join(parts)
def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Combine dialogues from embedding and keyword searches (embedding first)."""
if results is None:
return []
emb = []
kw = []
if isinstance(results.get("embedding_search"), dict):
emb = results.get("embedding_search", {}).get("dialogues", []) or []
elif isinstance(results.get("dialogues"), list):
emb = results.get("dialogues", []) or []
if isinstance(results.get("keyword_search"), dict):
kw = results.get("keyword_search", {}).get("dialogues", []) or []
seen = set()
merged: List[Dict[str, Any]] = []
for d in emb:
k = (str(d.get("uuid", "")), str(d.get("content", "")))
if k not in seen:
merged.append(d)
seen.add(k)
for d in kw:
k = (str(d.get("uuid", "")), str(d.get("content", "")))
if k not in seen:
merged.append(d)
seen.add(k)
return merged
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
group_id = group_id or SELECTED_GROUP_ID
# Load data
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
if not os.path.exists(data_path):
data_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl")
with open(data_path, "r", encoding="utf-8") as f:
lines = f.readlines()
items: List[Dict[str, Any]] = [json.loads(l) for l in lines[:sample_size]]
# 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入
# 说明memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略
contexts: List[str] = [build_context_from_dialog(item) for item in items]
await ingest_contexts_via_full_pipeline(contexts, group_id)
# LLM client (使用异步调用)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
# Evaluate each item
connector = Neo4jConnector()
latencies_llm: List[float] = []
latencies_search: List[float] = []
contexts_used: List[str] = []
correct_flags: List[float] = []
f1s: List[float] = []
b1s: List[float] = []
jss: List[float] = []
try:
for item in items:
question = item.get("self_instruct", {}).get("B", "") or item.get("question", "")
reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "")
# 检索:对齐 locomo 的三路检索dialogues/statements/entities
t0 = time.time()
try:
results = await run_hybrid_search(
query_text=question,
search_type=search_type,
group_id=group_id,
limit=search_limit,
include=["dialogues", "statements", "entities"],
output_path=None,
memory_config=memory_config,
)
except Exception:
results = None
t1 = time.time()
latencies_search.append((t1 - t0) * 1000)
# 构建上下文:包含对话、陈述和实体摘要,并智能选择
contexts_all: List[str] = []
if results:
if search_type == "hybrid":
emb = results.get("embedding_search", {}) if isinstance(results.get("embedding_search"), dict) else {}
kw = results.get("keyword_search", {}) if isinstance(results.get("keyword_search"), dict) else {}
emb_dialogs = emb.get("dialogues", [])
emb_statements = emb.get("statements", [])
emb_entities = emb.get("entities", [])
kw_dialogs = kw.get("dialogues", [])
kw_statements = kw.get("statements", [])
kw_entities = kw.get("entities", [])
all_dialogs = emb_dialogs + kw_dialogs
all_statements = emb_statements + kw_statements
all_entities = emb_entities + kw_entities
# 简单去重与限制
seen_texts = set()
for d in all_dialogs:
text = str(d.get("content", "")).strip()
if text and text not in seen_texts:
contexts_all.append(text)
seen_texts.add(text)
if len(contexts_all) >= search_limit:
break
for s in all_statements:
text = str(s.get("statement", "")).strip()
if text and text not in seen_texts:
contexts_all.append(text)
seen_texts.add(text)
if len(contexts_all) >= search_limit:
break
# 实体摘要最多3个
names = []
merged_entities = all_entities[:]
for e in merged_entities:
name = str(e.get("name", "")).strip()
if name and name not in names:
names.append(name)
if len(names) >= 3:
break
if names:
contexts_all.append("EntitySummary: " + ", ".join(names))
else:
dialogs = results.get("dialogues", [])
statements = results.get("statements", [])
entities = results.get("entities", [])
for d in dialogs:
text = str(d.get("content", "")).strip()
if text:
contexts_all.append(text)
for s in statements:
text = str(s.get("statement", "")).strip()
if text:
contexts_all.append(text)
names = [str(e.get("name", "")).strip() for e in entities[:3] if e.get("name")]
if names:
contexts_all.append("EntitySummary: " + ", ".join(names))
# 智能选择并截断到预算
context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else ""
if not context_text:
context_text = "No relevant context found."
contexts_used.append(context_text[:200])
# Call LLM (使用异步调用)
messages = [
{"role": "system", "content": "You are a QA assistant. Answer in English. Strictly follow: 1) If the context contains the answer, copy the shortest exact span from the context as the answer; 2) If the answer cannot be determined from the context, respond with 'Unknown'; 3) Return ONLY the answer text, no explanations."},
{"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"},
]
t2 = time.time()
resp = await llm_client.chat(messages=messages)
t3 = time.time()
latencies_llm.append((t3 - t2) * 1000)
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else str(resp).strip())
# Metrics: F1, BLEU-1, Jaccard; keep exact match for reference
correct_flags.append(exact_match(pred, reference))
from app.core.memory.evaluation.common.metrics import (
bleu1,
f1_score,
jaccard,
)
f1s.append(f1_score(str(pred), str(reference)))
b1s.append(bleu1(str(pred), str(reference)))
jss.append(jaccard(str(pred), str(reference)))
# Aggregate metrics
acc = sum(correct_flags) / max(len(correct_flags), 1)
ctx_avg_tokens = avg_context_tokens(contexts_used)
result = {
"dataset": "memsciqa",
"items": len(items),
"metrics": {
"accuracy": acc,
# Placeholders for extensibility
"f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0,
"bleu1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0,
"jaccard": (sum(jss) / max(len(jss), 1)) if jss else 0.0,
},
"latency": {
"search": latency_stats(latencies_search),
"llm": latency_stats(latencies_llm),
},
"avg_context_tokens": ctx_avg_tokens,
}
return result
finally:
await connector.close()
def main():
load_dotenv()
parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen")
parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量")
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id默认取 runtime.json")
parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数")
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
parser.add_argument("--llm-max-tokens", type=int, default=64, help="LLM 最大生成长度")
parser.add_argument("--search-type", type=str, choices=["keyword","embedding","hybrid"], default="hybrid", help="检索类型")
args = parser.parse_args()
result = asyncio.run(
run_memsciqa_eval(
sample_size=args.sample_size,
group_id=args.group_id,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature,
llm_max_tokens=args.llm_max_tokens,
search_type=args.search_type,
)
)
print(json.dumps(result, ensure_ascii=False, indent=2))
if __name__ == "__main__":
main()

View File

@@ -1,576 +0,0 @@
import argparse
import asyncio
import json
import os
import re
import time
from datetime import datetime
from typing import Any, Dict, List
try:
from dotenv import load_dotenv
except Exception:
def load_dotenv():
return None
# 路径与模块导入保持与现有评估脚本一致
import sys
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.dirname(os.path.dirname(_THIS_DIR))
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
for _p in (_SRC_DIR, _PROJECT_ROOT):
if _p not in sys.path:
sys.path.insert(0, _p)
# 对齐 locomo_test 的检索逻辑:直接使用 graph_search 与 Neo4jConnector/Embedder1
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
exact_match,
latency_stats,
)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
try:
from app.core.memory.evaluation.common.metrics import bleu1, f1_score, jaccard
except Exception:
# 兜底:简单实现(必要时)
def f1_score(pred: str, ref: str) -> float:
ps = pred.lower().split()
rs = ref.lower().split()
if not ps or not rs:
return 0.0
tp = len(set(ps) & set(rs))
if tp == 0:
return 0.0
precision = tp / len(ps)
recall = tp / len(rs)
if precision + recall == 0:
return 0.0
return 2 * precision * recall / (precision + recall)
def bleu1(pred: str, ref: str) -> float:
ps = pred.lower().split()
rs = ref.lower().split()
if not ps or not rs:
return 0.0
overlap = len([w for w in ps if w in rs])
return overlap / max(len(ps), 1)
def jaccard(pred: str, ref: str) -> float:
ps = set(pred.lower().split())
rs = set(ref.lower().split())
union = len(ps | rs)
if union == 0:
return 0.0
return len(ps & rs) / union
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
"""基于问题关键词对上下文进行评分选择,并在预算内拼接文本。
参考 evaluation/memsciqa/evaluate_qa.py 的实现,避免路径导入带来的不稳定。
"""
if not contexts:
return ""
question_lower = (question or "").lower()
stop_words = {
'what','when','where','who','why','how','did','do','does','is','are','was','were',
'the','a','an','and','or','but'
}
question_words = set(re.findall(r"\b\w+\b", question_lower))
question_words = {w for w in question_words if w not in stop_words and len(w) > 2}
scored = []
for i, ctx in enumerate(contexts):
ctx_lower = (ctx or "").lower()
score = 0
matches = 0
for w in question_words:
if w in ctx_lower:
matches += 1
score += ctx_lower.count(w) * 2
length = len(ctx)
if 100 < length < 2000:
score += 5
elif length >= 2000:
score += 2
if i < 3:
score += 3
scored.append((score, ctx, matches))
scored.sort(key=lambda x: x[0], reverse=True)
selected: List[str] = []
total = 0
for score, ctx, _ in scored:
if total + len(ctx) <= max_chars:
selected.append(ctx)
total += len(ctx)
else:
if score > 10 and total < max_chars - 200:
remaining = max_chars - total
lines = ctx.split('\n')
rel_lines: List[str] = []
cur = 0
for line in lines:
l = line.lower()
if any(w in l for w in question_words) and cur < remaining - 50:
rel_lines.append(line)
cur += len(line)
if rel_lines:
truncated = '\n'.join(rel_lines)
if len(truncated) > 50:
selected.append(truncated + "\n[相关内容截断...]")
total += len(truncated)
break
return "\n\n".join(selected)
def extract_question_keywords(question: str, max_keywords: int = 8) -> List[str]:
"""提取问题中的关键词(简单英文分词,去停用词,长度>=3"""
ql = (question or "").lower()
stop_words = {
'what','when','where','who','why','how','did','do','does','is','are','was','were',
'the','a','an','and','or','but','of','to','in','on','for','with','from','that','this'
}
words = re.findall(r"\b[\w-]+\b", ql)
kws = [w for w in words if w not in stop_words and len(w) >= 3]
# 去重保序
seen = set()
uniq = []
for w in kws:
if w not in seen:
uniq.append(w)
seen.add(w)
if len(uniq) >= max_keywords:
break
return uniq
def analyze_contexts_simple(contexts: List[str], keywords: List[str], top_n: int = 5) -> List[Dict[str, int | float]]:
"""对上下文进行简单相关性打分,仅用于控制台可视化。
评分: score = match_count*200 + min(len(text), 100000)/100
"""
results = []
for ctx in contexts:
tl = (ctx or "").lower()
match_count = sum(1 for k in keywords if k in tl)
length = len(ctx)
score = match_count * 200 + min(length, 100000) / 100.0
results.append({"score": float(f"{score:.0f}"), "match": match_count, "length": length})
results.sort(key=lambda x: (x["score"], x["match"], x["length"]), reverse=True)
return results[:max(top_n, 0)]
# 纯测试脚本不进行摄入;若需摄入请使用 evaluate_qa.py
def load_dataset_memsciqa(data_path: str) -> List[Dict[str, Any]]:
if not os.path.exists(data_path):
raise FileNotFoundError(f"未找到数据集: {data_path}")
items: List[Dict[str, Any]] = []
with open(data_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
items.append(json.loads(line))
except Exception:
# 跳过坏行但不中断
continue
return items
async def run_memsciqa_test(
sample_size: int = 3,
group_id: str | None = None,
search_limit: int = 8,
context_char_budget: int = 4000,
llm_temperature: float = 0.0,
llm_max_tokens: int = 64,
search_type: str = "embedding",
data_path: str | None = None,
start_index: int = 0,
verbose: bool = True,
) -> Dict[str, Any]:
"""memsciqa 增强测试脚本:结合 evaluate_qa 的三路检索与智能上下文选择。
- 支持从指定索引开始与评估全部样本sample_size<=0
- 支持在摄入前重置组(清空图)与跳过摄入
- 支持 keyword / embedding / hybrid 三种检索
"""
# 默认使用指定的 memsci 组 ID
group_id = group_id or "group_memsci"
# 数据路径解析(项目根与当前工作目录兜底)
if not data_path:
proj_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
cwd_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl")
if os.path.exists(proj_path):
data_path = proj_path
elif os.path.exists(cwd_path):
data_path = cwd_path
else:
raise FileNotFoundError("未找到数据集: data/msc_self_instruct.jsonl请确保其存在于项目根目录或当前工作目录的 data 目录下。")
# 加载数据
all_items = load_dataset_memsciqa(data_path)
if sample_size is None or sample_size <= 0:
items = all_items[start_index:]
else:
items = all_items[start_index:start_index + sample_size]
# 初始化 LLM纯测试不进行摄入
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm = factory.get_llm_client(SELECTED_LLM_ID)
# 初始化 Neo4j 连接与向量检索 Embedder对齐 locomo_test
connector = Neo4jConnector()
embedder = None
if search_type in ("embedding", "hybrid"):
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)
# 评估循环
latencies_llm: List[float] = []
latencies_search: List[float] = []
# 存储完整上下文文本用于统计
contexts_used: List[str] = []
per_query_context_chars: List[int] = []
per_query_context_counts: List[int] = []
correct_flags: List[float] = []
f1s: List[float] = []
b1s: List[float] = []
jss: List[float] = []
samples: List[Dict[str, Any]] = []
total_items = len(items)
for idx, item in enumerate(items):
if verbose:
print(f"\n🧪 评估样本: {idx+1}/{total_items}")
question = item.get("self_instruct", {}).get("B", "") or item.get("question", "")
reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "")
# 三路检索chunks/statements/entities/summaries对齐 qwen_search_eval.py
t0 = time.time()
results = None
try:
if search_type in ("embedding", "hybrid"):
# 使用嵌入检索(与 qwen_search_eval 对齐)
results = await search_graph_by_embedding(
connector=connector,
embedder_client=embedder,
query_text=question,
group_id=group_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
)
elif search_type == "keyword":
# 关键词检索(直接调用 graph_search
results = await search_graph(
connector=connector,
q=question,
group_id=group_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
)
except Exception:
results = None
t1 = time.time()
search_ms = (t1 - t0) * 1000
latencies_search.append(search_ms)
# 构建上下文:包含 chunks、陈述、摘要和实体对齐 qwen_search_eval.py
contexts_all: List[str] = []
retrieved_counts: Dict[str, int] = {}
if results:
chunks = results.get("chunks", [])
statements = results.get("statements", [])
entities = results.get("entities", [])
summaries = results.get("summaries", [])
retrieved_counts = {
"chunks": len(chunks),
"statements": len(statements),
"entities": len(entities),
"summaries": len(summaries),
}
# 优先使用 chunks
for c in chunks:
text = str(c.get("content", "")).strip()
if text:
contexts_all.append(text)
# 然后是 statements
for s in statements:
text = str(s.get("statement", "")).strip()
if text:
contexts_all.append(text)
# 然后是 summaries
for sm in summaries:
text = str(sm.get("summary", "")).strip()
if text:
contexts_all.append(text)
# 实体摘要最多加入前3个高分实体对齐 qwen_search_eval.py
scored = [e for e in entities if e.get("score") is not None]
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
if summary_lines:
contexts_all.append("\n".join(summary_lines))
if verbose:
if retrieved_counts:
print(f"✅ 检索成功: {retrieved_counts.get('chunks',0)} chunks, {retrieved_counts.get('statements',0)} 条陈述, {retrieved_counts.get('entities',0)} 个实体, {retrieved_counts.get('summaries',0)} 个摘要")
print(f"📊 有效上下文数量: {len(contexts_all)}")
q_keywords = extract_question_keywords(question, max_keywords=8)
if q_keywords:
print(f"🔍 问题关键词: {set(q_keywords)}")
if contexts_all:
analysis = analyze_contexts_simple(contexts_all, q_keywords, top_n=5)
if analysis:
print("📊 上下文相关性分析:")
for a in analysis:
print(f" - 得分: {int(a['score'])}, 关键词匹配: {a['match']}, 长度: {a['length']}")
# 打印检索到的上下文预览,便于定位为何为 Unknown
print("🔎 上下文预览最多前10条每条截断展示:")
for i, ctx in enumerate(contexts_all[:10]):
preview = str(ctx).replace("\n", " ")
if len(preview) > 300:
preview = preview[:300] + "..."
print(f" [{i+1}] 长度: {len(ctx)} | 片段: {preview}")
# 标注参考答案是否出现在任一上下文中
ref_lower = (str(reference) or "").lower()
if ref_lower:
hits = []
for i, ctx in enumerate(contexts_all):
if ref_lower in str(ctx).lower():
hits.append(i+1)
print(f"🔗 参考答案命中上下文条数: {len(hits)}" + (f" | 命中索引: {hits}" if hits else ""))
context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else ""
if not context_text:
context_text = "No relevant context found."
contexts_used.append(context_text)
per_query_context_chars.append(len(context_text))
per_query_context_counts.append(len(contexts_all))
if verbose:
selected_count = (context_text.count("\n\n") + 1) if context_text else 0
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {len(context_text)}字符")
# 展示拼接后的上下文片段,便于核查是否包含答案
concat_preview = context_text.replace("\n", " ")
if len(concat_preview) > 600:
concat_preview = concat_preview[:600] + "..."
print(f"🧵 拼接上下文预览: {concat_preview}")
messages = [
{
"role": "system",
"content": (
"You are a QA assistant. Answer in English. Follow these guidelines:\n"
"1) If the context contains information to answer the question, provide a concise answer based on the context;\n"
"2) If the context does not contain enough information to answer the question, respond with 'Unknown';\n"
"3) Keep your answer brief and to the point;\n"
"4) Do not add explanations or additional text beyond the answer."
),
},
{"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"},
]
t2 = time.time()
try:
# 使用异步调用
resp = await llm.chat(messages=messages)
# 更健壮的响应解析处理不同的LLM响应格式
if hasattr(resp, 'content'):
pred = resp.content.strip()
elif isinstance(resp, dict) and "choices" in resp and len(resp["choices"]) > 0:
pred = resp["choices"][0]["message"]["content"].strip()
elif isinstance(resp, dict) and "content" in resp:
pred = resp["content"].strip()
elif isinstance(resp, str):
pred = resp.strip()
else:
pred = "Unknown"
print(f"⚠️ LLM响应格式异常: {type(resp)} - {resp}")
# 检查预测是否为"Unknown"或空,如果是则检查上下文是否真的没有答案
if pred.lower() in ["unknown", ""]:
# 如果参考答案在上下文中存在但LLM返回Unknown可能是提示词问题
ref_lower = (str(reference) or "").lower()
if ref_lower and any(ref_lower in ctx.lower() for ctx in contexts_all):
print("⚠️ 参考答案在上下文中存在但LLM返回Unknown检查提示词")
except Exception as e:
# 更详细的错误处理
pred = "Unknown"
print(f"⚠️ LLM调用异常: {e}")
t3 = time.time()
llm_ms = (t3 - t2) * 1000
latencies_llm.append(llm_ms)
exact = exact_match(pred, reference)
correct_flags.append(exact)
f1_val = f1_score(str(pred), str(reference))
b1_val = bleu1(str(pred), str(reference))
j_val = jaccard(str(pred), str(reference))
f1s.append(f1_val)
b1s.append(b1_val)
jss.append(j_val)
if verbose:
print(f"🤖 LLM 回答: {pred}")
print(f"✅ 正确答案: {reference}")
print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: {j_val:.3f}")
print(f"⏱️ 延迟 - 检索: {search_ms:.0f}ms, LLM: {llm_ms:.0f}ms")
# 对齐 locomo/qwen_search_eval.py 的样本输出结构
samples.append({
"question": str(question),
"answer": str(reference),
"prediction": str(pred),
"metrics": {
"f1": f1_val,
"b1": b1_val,
"j": j_val
},
"retrieval": {
"retrieved_documents": len(contexts_all),
"context_length": len(context_text),
"search_limit": search_limit,
"max_chars": context_char_budget
},
"timing": {
"search_ms": search_ms,
"llm_ms": llm_ms
}
})
# 计算总体指标与聚合
acc = sum(correct_flags) / max(len(correct_flags), 1)
ctx_avg_tokens = avg_context_tokens(contexts_used)
result = {
"dataset": "memsciqa",
"items": len(items),
"metrics": {
"f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0,
"b1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0,
"j": (sum(jss) / max(len(jss), 1)) if jss else 0.0,
},
"context": {
"avg_tokens": ctx_avg_tokens,
"avg_chars": (sum(per_query_context_chars) / max(len(per_query_context_chars), 1)) if per_query_context_chars else 0.0,
"count_avg": (sum(per_query_context_counts) / max(len(per_query_context_counts), 1)) if per_query_context_counts else 0.0,
"avg_memory_tokens": 0.0
},
"latency": {
"search": latency_stats(latencies_search),
"llm": latency_stats(latencies_llm),
},
"samples": samples,
"params": {
"group_id": group_id,
"search_limit": search_limit,
"context_char_budget": context_char_budget,
"llm_temperature": llm_temperature,
"llm_max_tokens": llm_max_tokens,
"search_type": search_type,
"start_index": start_index,
"llm_id": SELECTED_LLM_ID,
"retrieval_embedding_id": SELECTED_EMBEDDING_ID
},
"timestamp": datetime.now().isoformat(),
}
try:
await connector.close()
except Exception:
pass
return result
def main():
load_dotenv()
parser = argparse.ArgumentParser(description="memsciqa 测试脚本(三路检索 + 智能上下文选择)")
parser.add_argument("--sample-size", type=int, default=30, help="样本数量(<=0 表示全部)")
parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size")
parser.add_argument("--start-index", type=int, default=0, help="起始样本索引")
parser.add_argument("--group-id", type=str, default="group_memsci", help="图数据库 Group ID默认 group_memsci")
parser.add_argument("--search-limit", type=int, default=8, help="检索条数上限")
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
parser.add_argument("--llm-max-tokens", type=int, default=64, help="LLM 最大输出 token")
parser.add_argument("--search-type", type=str, default="embedding", choices=["embedding","keyword","hybrid"], help="检索类型hybrid 等同于 embedding")
parser.add_argument("--data-path", type=str, default=None, help="数据集路径(默认 data/msc_self_instruct.jsonl")
parser.add_argument("--output", type=str, default=None, help="将评估结果保存到指定文件路径JSON")
parser.add_argument("--verbose", action="store_true", default=True, help="打印过程日志(默认开启)")
parser.add_argument("--quiet", action="store_true", help="关闭过程日志")
args = parser.parse_args()
sample_size = 0 if args.all else args.sample_size
verbose_flag = False if args.quiet else args.verbose
result = asyncio.run(
run_memsciqa_test(
sample_size=sample_size,
group_id=args.group_id,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature,
llm_max_tokens=args.llm_max_tokens,
search_type=args.search_type,
data_path=args.data_path,
start_index=args.start_index,
verbose=verbose_flag,
)
)
print(json.dumps(result, ensure_ascii=False, indent=2))
# 结果保存
out_path = args.output
if not out_path:
eval_dir = os.path.dirname(os.path.abspath(__file__))
dataset_results_dir = os.path.join(eval_dir, "results")
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
out_path = os.path.join(dataset_results_dir, f"memsciqa_{result['params']['search_type']}_{ts}.json")
try:
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"\n💾 结果已保存: {out_path}")
except Exception as e:
print(f"⚠️ 结果保存失败: {e}")
if __name__ == "__main__":
main()

View File

@@ -1,150 +0,0 @@
import argparse
import asyncio
import json
import os
import sys
from typing import Any, Dict
# Add src directory to Python path for proper imports when running from evaluation directory
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'src'))
try:
from dotenv import load_dotenv
except Exception:
def load_dotenv():
return None
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, PROJECT_ROOT
from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval
from app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test
from app.core.memory.evaluation.locomo.qwen_search_eval import run_locomo_eval
async def run(
dataset: str,
sample_size: int,
reset_group: bool,
group_id: str | None,
judge_model: str | None = None,
search_limit: int | None = None,
context_char_budget: int | None = None,
llm_temperature: float | None = None,
llm_max_tokens: int | None = None,
search_type: str | None = None,
start_index: int | None = None,
max_contexts_per_item: int | None = None,
) -> Dict[str, Any]:
# 恢复原始风格:统一入口做路由,并沿用各数据集既有默认
group_id = group_id or SELECTED_GROUP_ID
if reset_group:
connector = Neo4jConnector()
try:
await connector.delete_group(group_id)
finally:
await connector.close()
if dataset == "locomo":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
if search_limit is not None:
kwargs["search_limit"] = search_limit
if context_char_budget is not None:
kwargs["context_char_budget"] = context_char_budget
if llm_temperature is not None:
kwargs["llm_temperature"] = llm_temperature
if llm_max_tokens is not None:
kwargs["llm_max_tokens"] = llm_max_tokens
if search_type is not None:
kwargs["search_type"] = search_type
return await run_locomo_eval(**kwargs)
if dataset == "memsciqa":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
if search_limit is not None:
kwargs["search_limit"] = search_limit
if context_char_budget is not None:
kwargs["context_char_budget"] = context_char_budget
if llm_temperature is not None:
kwargs["llm_temperature"] = llm_temperature
if llm_max_tokens is not None:
kwargs["llm_max_tokens"] = llm_max_tokens
if search_type is not None:
kwargs["search_type"] = search_type
return await run_memsciqa_eval(**kwargs)
if dataset == "longmemeval":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
if search_limit is not None:
kwargs["search_limit"] = search_limit
if context_char_budget is not None:
kwargs["context_char_budget"] = context_char_budget
if llm_temperature is not None:
kwargs["llm_temperature"] = llm_temperature
if llm_max_tokens is not None:
kwargs["llm_max_tokens"] = llm_max_tokens
if search_type is not None:
kwargs["search_type"] = search_type
if start_index is not None:
kwargs["start_index"] = start_index
if max_contexts_per_item is not None:
kwargs["max_contexts_per_item"] = max_contexts_per_item
return await run_longmemeval_test(**kwargs)
raise ValueError(f"未知数据集: {dataset}")
def main():
load_dotenv()
parser = argparse.ArgumentParser(description="统一评估入口memsciqa / longmemeval / locomo")
parser.add_argument("--dataset", choices=["memsciqa", "longmemeval", "locomo"], required=True)
parser.add_argument("--sample-size", type=int, default=1, help="先用一条数据跑通")
parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 group_id 的图数据")
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id默认取 runtime.json")
parser.add_argument("--judge-model", type=str, default=None, help="可选longmemeval 判别式评测模型名")
parser.add_argument("--search-limit", type=int, default=None, help="检索返回的对话节点数量上限(不提供则使用各脚本默认)")
parser.add_argument("--context-char-budget", type=int, default=None, help="上下文字符预算(不提供则使用各脚本默认)")
parser.add_argument("--llm-temperature", type=float, default=None, help="生成温度(不提供则使用各脚本默认)")
parser.add_argument("--llm-max-tokens", type=int, default=None, help="最大生成 tokens不提供则使用各脚本默认")
parser.add_argument("--search-type", type=str, default=None, choices=["keyword", "embedding", "hybrid"], help="检索类型(不提供则使用各脚本默认)")
# 仅透传到 longmemeval其他数据集忽略
parser.add_argument("--start-index", type=int, default=None, help="仅 longmemeval起始样本索引不提供则用脚本默认")
parser.add_argument("--max-contexts-per-item", type=int, default=None, help="仅 longmemeval每条样本摄入的上下文数量上限不提供则用脚本默认")
parser.add_argument("--output", type=str, default=None, help="可选将评估结果保存到指定文件路径JSON不提供时默认保存到 evaluation/<dataset>/results 目录")
args = parser.parse_args()
result = asyncio.run(run(
args.dataset,
args.sample_size,
args.reset_group,
args.group_id,
args.judge_model,
args.search_limit,
args.context_char_budget,
args.llm_temperature,
args.llm_max_tokens,
args.search_type,
args.start_index,
args.max_contexts_per_item,
))
print(json.dumps(result, ensure_ascii=False, indent=2))
# 结果输出逻辑保持不变
if args.output:
out_path = args.output
else:
eval_dir = os.path.dirname(os.path.abspath(__file__))
dataset_results_dir = os.path.join(eval_dir, args.dataset, "results")
out_filename = f"{args.dataset}_{args.sample_size}.json"
out_path = os.path.join(dataset_results_dir, out_filename)
out_dir = os.path.dirname(out_path)
if out_dir and not os.path.exists(out_dir):
os.makedirs(out_dir, exist_ok=True)
with open(out_path, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"\n结果已保存到: {out_path}")
if __name__ == "__main__":
main()

View File

@@ -72,7 +72,7 @@ class TemporalSearchParams(BaseModel):
"""Parameters for temporal search queries in the knowledge graph. """Parameters for temporal search queries in the knowledge graph.
Attributes: Attributes:
group_id: Group ID to filter search results (default: 'test') end_user_id: Group ID to filter search results (default: 'test')
apply_id: Application ID to filter search results apply_id: Application ID to filter search results
user_id: User ID to filter search results user_id: User ID to filter search results
start_date: Start date for temporal filtering (format: 'YYYY-MM-DD') start_date: Start date for temporal filtering (format: 'YYYY-MM-DD')
@@ -81,7 +81,7 @@ class TemporalSearchParams(BaseModel):
invalid_date: Date when memory should be invalid (format: 'YYYY-MM-DD') invalid_date: Date when memory should be invalid (format: 'YYYY-MM-DD')
limit: Maximum number of results to return (default: 3) limit: Maximum number of results to return (default: 3)
""" """
group_id: Optional[str] = Field("test", description="The group ID to filter the search.") end_user_id: Optional[str] = Field("test", description="The group ID to filter the search.")
apply_id: Optional[str] = Field(None, description="The apply ID to filter the search.") apply_id: Optional[str] = Field(None, description="The apply ID to filter the search.")
user_id: Optional[str] = Field(None, description="The user ID to filter the search.") user_id: Optional[str] = Field(None, description="The user ID to filter the search.")
start_date: Optional[str] = Field(None, description="The start date for the search.") start_date: Optional[str] = Field(None, description="The start date for the search.")

View File

@@ -103,9 +103,7 @@ class Edge(BaseModel):
id: Unique identifier for the edge id: Unique identifier for the edge
source: ID of the source node source: ID of the source node
target: ID of the target node target: ID of the target node
group_id: Group ID for multi-tenancy end_user_id: End user ID for multi-tenancy
user_id: User ID for user-specific data
apply_id: Application ID for application-specific data
run_id: Unique identifier for the pipeline run that created this edge run_id: Unique identifier for the pipeline run that created this edge
created_at: Timestamp when the edge was created (system perspective) created_at: Timestamp when the edge was created (system perspective)
expired_at: Optional timestamp when the edge expires (system perspective) expired_at: Optional timestamp when the edge expires (system perspective)
@@ -113,9 +111,7 @@ class Edge(BaseModel):
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.") id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.")
source: str = Field(..., description="The ID of the source node.") source: str = Field(..., description="The ID of the source node.")
target: str = Field(..., description="The ID of the target node.") target: str = Field(..., description="The ID of the target node.")
group_id: str = Field(..., description="The group ID of the edge.") end_user_id: str = Field(..., description="The end user ID of the edge.")
user_id: str = Field(..., description="The user ID of the edge.")
apply_id: str = Field(..., description="The apply ID of the edge.")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.") created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.") expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
@@ -185,18 +181,14 @@ class Node(BaseModel):
Attributes: Attributes:
id: Unique identifier for the node id: Unique identifier for the node
name: Name of the node name: Name of the node
group_id: Group ID for multi-tenancy end_user_id: End user ID for multi-tenancy
user_id: User ID for user-specific data
apply_id: Application ID for application-specific data
run_id: Unique identifier for the pipeline run that created this node run_id: Unique identifier for the pipeline run that created this node
created_at: Timestamp when the node was created (system perspective) created_at: Timestamp when the node was created (system perspective)
expired_at: Optional timestamp when the node expires (system perspective) expired_at: Optional timestamp when the node expires (system perspective)
""" """
id: str = Field(..., description="The unique identifier for the node.") id: str = Field(..., description="The unique identifier for the node.")
name: str = Field(..., description="The name of the node.") name: str = Field(..., description="The name of the node.")
group_id: str = Field(..., description="The group ID of the node.") end_user_id: str = Field(..., description="The end user ID of the node.")
user_id: str = Field(..., description="The user ID of the edge.")
apply_id: str = Field(..., description="The apply ID of the edge.")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(..., description="The valid time of the node from system perspective.") created_at: datetime = Field(..., description="The valid time of the node from system perspective.")
expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.") expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.")

View File

@@ -55,7 +55,7 @@ class Statement(BaseModel):
Attributes: Attributes:
id: Unique identifier for the statement id: Unique identifier for the statement
chunk_id: ID of the parent chunk this statement belongs to chunk_id: ID of the parent chunk this statement belongs to
group_id: Optional group ID for multi-tenancy end_user_id: Optional group ID for multi-tenancy
statement: The actual statement text content statement: The actual statement text content
speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses) speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses)
statement_embedding: Optional embedding vector for the statement statement_embedding: Optional embedding vector for the statement
@@ -73,7 +73,7 @@ class Statement(BaseModel):
""" """
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.") id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.")
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.") chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.") end_user_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
statement: str = Field(..., description="The text content of the statement.") statement: str = Field(..., description="The text content of the statement.")
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses") speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.") statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.")
@@ -159,9 +159,7 @@ class DialogData(BaseModel):
context: Full conversation context context: Full conversation context
dialog_embedding: Optional embedding vector for the entire dialog dialog_embedding: Optional embedding vector for the entire dialog
ref_id: Reference ID linking to external dialog system ref_id: Reference ID linking to external dialog system
group_id: Group ID for multi-tenancy end_user_id: End user ID for multi-tenancy
user_id: User ID for user-specific data
apply_id: Application ID for application-specific data
created_at: Timestamp when the dialog was created created_at: Timestamp when the dialog was created
expired_at: Timestamp when the dialog expires (default: far future) expired_at: Timestamp when the dialog expires (default: far future)
metadata: Additional metadata as key-value pairs metadata: Additional metadata as key-value pairs
@@ -175,9 +173,7 @@ class DialogData(BaseModel):
context: ConversationContext = Field(..., description="The full conversation context as a single string.") context: ConversationContext = Field(..., description="The full conversation context as a single string.")
dialog_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the dialog.") dialog_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the dialog.")
ref_id: str = Field(..., description="Refer to external dialog id. This is used to link to the original dialog.") ref_id: str = Field(..., description="Refer to external dialog id. This is used to link to the original dialog.")
group_id: str = Field(default=..., description="Group ID of dialogue data") end_user_id: str = Field(default=..., description="End user ID of dialogue data")
user_id: str = Field(..., description="USER ID of dialogue data")
apply_id: str = Field(..., description="APPLY ID of dialogue data")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.") created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.")
expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.") expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.")
@@ -250,11 +246,11 @@ class DialogData(BaseModel):
return [] return []
def assign_group_id_to_statements(self) -> None: def assign_group_id_to_statements(self) -> None:
"""Assign this dialog's group_id to all statements in all chunks. """Assign this dialog's end_user_id to all statements in all chunks.
This method updates statements that don't have a group_id set. This method updates statements that don't have a end_user_id set.
""" """
for chunk in self.chunks: for chunk in self.chunks:
for statement in chunk.statements: for statement in chunk.statements:
if statement.group_id is None: if statement.end_user_id is None:
statement.group_id = self.group_id statement.end_user_id = self.end_user_id

View File

@@ -6,6 +6,7 @@ import os
import time import time
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
from uuid import UUID
if TYPE_CHECKING: if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig from app.schemas.memory_config_schema import MemoryConfig
@@ -396,13 +397,13 @@ def rerank_with_activation(
return reranked return reranked
def log_search_query(query_text: str, search_type: str, group_id: str | None, limit: int, include: List[str], log_file: str = None): def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None):
"""Log search query information using the logger. """Log search query information using the logger.
Args: Args:
query_text: The search query text query_text: The search query text
search_type: Type of search (keyword, embedding, hybrid) search_type: Type of search (keyword, embedding, hybrid)
group_id: Group identifier for filtering end_user_id: Group identifier for filtering
limit: Maximum number of results limit: Maximum number of results
include: List of result types to include include: List of result types to include
log_file: Deprecated parameter, kept for backward compatibility log_file: Deprecated parameter, kept for backward compatibility
@@ -413,7 +414,7 @@ def log_search_query(query_text: str, search_type: str, group_id: str | None, li
# Log using the standard logger # Log using the standard logger
logger.info( logger.info(
f"Search query: query='{cleaned_query}', type={search_type}, " f"Search query: query='{cleaned_query}', type={search_type}, "
f"group_id={group_id}, limit={limit}, include={include}" f"end_user_id={end_user_id}, limit={limit}, include={include}"
) )
@@ -672,7 +673,7 @@ def apply_reranker_placeholder(
async def run_hybrid_search( async def run_hybrid_search(
query_text: str, query_text: str,
search_type: str, search_type: str,
group_id: str | None, end_user_id: str | None,
limit: int, limit: int,
include: List[str], include: List[str],
output_path: str | None, output_path: str | None,
@@ -715,7 +716,7 @@ async def run_hybrid_search(
} }
# Log the search query # Log the search query
log_search_query(query_text, search_type, group_id, limit, include) log_search_query(query_text, search_type, end_user_id, limit, include)
connector = Neo4jConnector() connector = Neo4jConnector()
results = {} results = {}
@@ -732,7 +733,7 @@ async def run_hybrid_search(
search_graph( search_graph(
connector=connector, connector=connector,
q=query_text, q=query_text,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include include=include
) )
@@ -769,7 +770,7 @@ async def run_hybrid_search(
connector=connector, connector=connector,
embedder_client=embedder, embedder_client=embedder,
query_text=query_text, query_text=query_text,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include, include=include,
) )
@@ -916,9 +917,7 @@ async def run_hybrid_search(
async def search_by_temporal( async def search_by_temporal(
group_id: Optional[str] = "test", end_user_id: Optional[str] = "test",
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
valid_date: Optional[str] = None, valid_date: Optional[str] = None,
@@ -929,7 +928,7 @@ async def search_by_temporal(
Temporal search across Statements. Temporal search across Statements.
- Matches statements created between start_date and end_date - Matches statements created between start_date and end_date
- Optionally filters by group_id - Optionally filters by end_user_id
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
connector = Neo4jConnector() connector = Neo4jConnector()
@@ -939,9 +938,7 @@ async def search_by_temporal(
end_date = normalize_date_safe(end_date) end_date = normalize_date_safe(end_date)
params = TemporalSearchParams.model_validate({ params = TemporalSearchParams.model_validate({
"group_id": group_id, "end_user_id": end_user_id,
"apply_id": apply_id,
"user_id": user_id,
"start_date": start_date, "start_date": start_date,
"end_date": end_date, "end_date": end_date,
"valid_date": valid_date, "valid_date": valid_date,
@@ -950,9 +947,7 @@ async def search_by_temporal(
}) })
statements = await search_graph_by_temporal( statements = await search_graph_by_temporal(
connector=connector, connector=connector,
group_id=params.group_id, end_user_id=params.end_user_id,
apply_id=params.apply_id,
user_id=params.user_id,
start_date=params.start_date, start_date=params.start_date,
end_date=params.end_date, end_date=params.end_date,
valid_date=params.valid_date, valid_date=params.valid_date,
@@ -964,9 +959,7 @@ async def search_by_temporal(
async def search_by_keyword_temporal( async def search_by_keyword_temporal(
query_text: str, query_text: str,
group_id: Optional[str] = "test", end_user_id: Optional[str] = "test",
apply_id: Optional[str] = None,
user_id: Optional[str] = None,
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
valid_date: Optional[str] = None, valid_date: Optional[str] = None,
@@ -987,9 +980,7 @@ async def search_by_keyword_temporal(
invalid_date = normalize_date_safe(invalid_date) invalid_date = normalize_date_safe(invalid_date)
params = TemporalSearchParams.model_validate({ params = TemporalSearchParams.model_validate({
"group_id": group_id, "end_user_id": end_user_id,
"apply_id": apply_id,
"user_id": user_id,
"start_date": start_date, "start_date": start_date,
"end_date": end_date, "end_date": end_date,
"valid_date": valid_date, "valid_date": valid_date,
@@ -999,9 +990,7 @@ async def search_by_keyword_temporal(
statements = await search_graph_by_keyword_temporal( statements = await search_graph_by_keyword_temporal(
connector=connector, connector=connector,
query_text=query_text, query_text=query_text,
group_id=params.group_id, end_user_id=params.end_user_id,
apply_id=params.apply_id,
user_id=params.user_id,
start_date=params.start_date, start_date=params.start_date,
end_date=params.end_date, end_date=params.end_date,
valid_date=params.valid_date, valid_date=params.valid_date,
@@ -1013,7 +1002,7 @@ async def search_by_keyword_temporal(
async def search_chunk_by_chunk_id( async def search_chunk_by_chunk_id(
chunk_id: str, chunk_id: str,
group_id: Optional[str] = "test", end_user_id: Optional[str] = "test",
limit: int = 1, limit: int = 1,
): ):
""" """
@@ -1023,7 +1012,7 @@ async def search_chunk_by_chunk_id(
chunks = await search_graph_by_chunk_id( chunks = await search_graph_by_chunk_id(
connector=connector, connector=connector,
chunk_id=chunk_id, chunk_id=chunk_id,
group_id=group_id, end_user_id=end_user_id,
limit=limit limit=limit
) )
return {"chunks": chunks} return {"chunks": chunks}

View File

@@ -555,8 +555,8 @@ class DataPreprocessor:
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}'))) dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
# 获取group_id如果不存在则生成默认值 # 获取end_user_id如果不存在则生成默认值
group_id = item.get('group_id', f'group_default_{i}') end_user_id = item.get('end_user_id', f'group_default_{i}')
user_id = item.get('user_id', f'user_default_{i}') user_id = item.get('user_id', f'user_default_{i}')
apply_id = item.get('apply_id', f'apply_default_{i}') apply_id = item.get('apply_id', f'apply_default_{i}')
@@ -574,7 +574,7 @@ class DataPreprocessor:
dialog_data = DialogData( dialog_data = DialogData(
context=context, context=context,
ref_id=dialog_id, ref_id=dialog_id,
group_id=group_id, end_user_id=end_user_id,
user_id=user_id, user_id=user_id,
apply_id=apply_id, apply_id=apply_id,
metadata=metadata metadata=metadata
@@ -644,7 +644,7 @@ class DataPreprocessor:
context = ConversationContext(msgs=messages) context = ConversationContext(msgs=messages)
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}'))) dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
group_id = item.get('group_id', f'group_default_{i}') end_user_id = item.get('end_user_id', f'group_default_{i}')
user_id = item.get('user_id', f'user_default_{i}') user_id = item.get('user_id', f'user_default_{i}')
apply_id = item.get('apply_id', f'apply_default_{i}') apply_id = item.get('apply_id', f'apply_default_{i}')
@@ -657,7 +657,7 @@ class DataPreprocessor:
dialog_data = DialogData( dialog_data = DialogData(
context=context, context=context,
ref_id=dialog_id, ref_id=dialog_id,
group_id=group_id, end_user_id=end_user_id,
user_id=user_id, user_id=user_id,
apply_id=apply_id, apply_id=apply_id,
metadata=metadata metadata=metadata

View File

@@ -199,7 +199,7 @@ def accurate_match(
entity_nodes: List[ExtractedEntityNode] entity_nodes: List[ExtractedEntityNode]
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]: ) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
""" """
精确匹配:按 (group_id, name, entity_type) 合并实体并建立重定向与合并记录。 精确匹配:按 (end_user_id, name, entity_type) 合并实体并建立重定向与合并记录。
返回: (deduped_entities, id_redirect, exact_merge_map) 返回: (deduped_entities, id_redirect, exact_merge_map)
""" """
exact_merge_map: Dict[str, Dict] = {} exact_merge_map: Dict[str, Dict] = {}
@@ -210,8 +210,8 @@ def accurate_match(
for ent in entity_nodes: for ent in entity_nodes:
name_norm = (getattr(ent, "name", "") or "").strip() name_norm = (getattr(ent, "name", "") or "").strip()
type_norm = (getattr(ent, "entity_type", "") or "").strip() type_norm = (getattr(ent, "entity_type", "") or "").strip()
key = f"{getattr(ent, 'group_id', None)}|{name_norm}|{type_norm}" key = f"{getattr(ent, 'end_user_id', None)}|{name_norm}|{type_norm}"
# 为避免跨业务组误并,明确以 group_id 为范围边界 # 为避免跨业务组误并,明确以 end_user_id 为范围边界
if key not in canonical_map: if key not in canonical_map:
canonical_map[key] = ent canonical_map[key] = ent
id_redirect[ent.id] = ent.id id_redirect[ent.id] = ent.id
@@ -223,11 +223,11 @@ def accurate_match(
id_redirect[ent.id] = canonical.id id_redirect[ent.id] = canonical.id
# 记录精确匹配的合并项(使用规范化键,避免外层变量误用) # 记录精确匹配的合并项(使用规范化键,避免外层变量误用)
try: try:
k = f"{canonical.group_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}" k = f"{canonical.end_user_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
if k not in exact_merge_map: if k not in exact_merge_map:
exact_merge_map[k] = { exact_merge_map[k] = {
"canonical_id": canonical.id, "canonical_id": canonical.id,
"group_id": canonical.group_id, "end_user_id": canonical.end_user_id,
"name": canonical.name, "name": canonical.name,
"entity_type": canonical.entity_type, "entity_type": canonical.entity_type,
"merged_ids": set(), "merged_ids": set(),
@@ -596,7 +596,7 @@ def fuzzy_match(
b = deduped_entities[j] b = deduped_entities[j]
# 跳过不同业务组的实体 # 跳过不同业务组的实体
if getattr(a, "group_id", None) != getattr(b, "group_id", None): if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
j += 1 j += 1
continue continue
@@ -671,7 +671,7 @@ def fuzzy_match(
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]" merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]" merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
fuzzy_merge_records.append( fuzzy_merge_records.append(
f"{merge_reason} 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type}) | " f"{merge_reason} 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type}) | "
f"s_name={s_name:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, exact_alias={has_exact_match}" f"s_name={s_name:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, exact_alias={has_exact_match}"
) )
except Exception: except Exception:
@@ -779,7 +779,7 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
# 记录 LLM 融合日志 # 记录 LLM 融合日志
try: try:
llm_records.append( llm_records.append(
f"[LLM融合] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})" f"[LLM融合] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})"
) )
# 详细的“同类名称相似”记录改由 LLM 去重模块统一生成以携带 conf/reason # 详细的“同类名称相似”记录改由 LLM 去重模块统一生成以携带 conf/reason
except Exception: except Exception:
@@ -847,7 +847,7 @@ async def LLM_disamb_decision(
id_redirect[k] = a.id id_redirect[k] = a.id
try: try:
disamb_records.append( disamb_records.append(
f"[DISAMB合并应用] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})" f"[DISAMB合并应用] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})"
) )
except Exception: except Exception:
pass pass

View File

@@ -174,7 +174,7 @@ async def _judge_pair(
pass pass
# 3. 构建LLM判断的“上下文信息”规则层计算的所有特征 判断上下文特征有助于实体消歧首先判断的类型关系 # 3. 构建LLM判断的“上下文信息”规则层计算的所有特征 判断上下文特征有助于实体消歧首先判断的类型关系
ctx = { ctx = {
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None), "same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None),
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)), "type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
"type_similarity": _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None)), "type_similarity": _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
"name_text_sim": name_text_sim, "name_text_sim": name_text_sim,
@@ -235,7 +235,7 @@ async def _judge_pair_disamb(
except Exception: except Exception:
pass pass
ctx = { ctx = {
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None), "same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None),
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)), "type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
"name_text_sim": name_text_sim, "name_text_sim": name_text_sim,
"name_embed_sim": name_embed_sim, "name_embed_sim": name_embed_sim,
@@ -317,8 +317,8 @@ async def llm_dedup_entities( # 保留对偶判断作为子流程,是为了
a = entity_nodes[i] a = entity_nodes[i]
for j in range(i + 1, len(entity_nodes)): for j in range(i + 1, len(entity_nodes)):
b = entity_nodes[j] b = entity_nodes[j]
# 规则1必须属于同一组group_id相同不同组的实体不重复 # 规则1必须属于同一组end_user_id相同不同组的实体不重复
if getattr(a, "group_id", None) != getattr(b, "group_id", None): if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
continue continue
# 规则2类型必须兼容调用_simple_type_ok判断 # 规则2类型必须兼容调用_simple_type_ok判断
if not _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)): if not _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)):
@@ -474,7 +474,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
- max_rounds: upper bound for iterative passes (default 3) - max_rounds: upper bound for iterative passes (default 3)
- auto_merge_threshold: decision confidence for auto-merge when no co-occurrence (default 0.90) - auto_merge_threshold: decision confidence for auto-merge when no co-occurrence (default 0.90)
- co_ctx_threshold: lower threshold when co-occurrence is detected (default 0.83) - co_ctx_threshold: lower threshold when co-occurrence is detected (default 0.83)
- shuffle_each_round: whether to shuffle entities within group_id each round to vary block composition - shuffle_each_round: whether to shuffle entities within end_user_id each round to vary block composition
Returns: Returns:
- global_redirect: dict losing_id -> canonical_id accumulated across rounds - global_redirect: dict losing_id -> canonical_id accumulated across rounds
@@ -509,7 +509,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
def _partition_blocks(nodes: List[ExtractedEntityNode]) -> List[List[ExtractedEntityNode]]: def _partition_blocks(nodes: List[ExtractedEntityNode]) -> List[List[ExtractedEntityNode]]:
""" """
group_id 分块,避免跨组实体在同一块,减少无效候选对 end_user_id 分块,避免跨组实体在同一块,减少无效候选对
Args: Args:
nodes: 实体节点列表 nodes: 实体节点列表
@@ -519,7 +519,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
""" """
groups: Dict[str, List[ExtractedEntityNode]] = {} groups: Dict[str, List[ExtractedEntityNode]] = {}
for e in nodes: for e in nodes:
gid = getattr(e, "group_id", None) gid = getattr(e, "end_user_id", None)
groups.setdefault(str(gid), []).append(e) groups.setdefault(str(gid), []).append(e)
blocks: List[List[ExtractedEntityNode]] = [] blocks: List[List[ExtractedEntityNode]] = []
for gid, arr in groups.items(): for gid, arr in groups.items():
@@ -559,7 +559,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
# Collapse nodes to canonical reps before each round to avoid redundant comparisons # Collapse nodes to canonical reps before each round to avoid redundant comparisons
# 步骤1折叠实体合并已确定的重复实体减少后续计算量 # 步骤1折叠实体合并已确定的重复实体减少后续计算量
current_nodes = _collapse_nodes(current_nodes) current_nodes = _collapse_nodes(current_nodes)
# 步骤2分块group_id分块避免跨组处理 # 步骤2分块end_user_id分块避免跨组处理
blocks = _partition_blocks(current_nodes) blocks = _partition_blocks(current_nodes)
if not blocks: # 无块可处理(实体已全部折叠),退出循环 if not blocks: # 无块可处理(实体已全部折叠),退出循环
break break
@@ -645,7 +645,7 @@ async def llm_disambiguate_pairs_iterative(
a = entity_nodes[i] a = entity_nodes[i]
b = entity_nodes[j] b = entity_nodes[j]
# 必须同组 # 必须同组
if getattr(a, "group_id", None) != getattr(b, "group_id", None): if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
continue continue
ta = getattr(a, "entity_type", None) ta = getattr(a, "entity_type", None)
tb = getattr(b, "entity_type", None) tb = getattr(b, "entity_type", None)

View File

@@ -61,7 +61,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
return ExtractedEntityNode( return ExtractedEntityNode(
id=row.get("id"), id=row.get("id"),
name=row.get("name") or "", name=row.get("name") or "",
group_id=row.get("group_id") or "", end_user_id=row.get("end_user_id") or "",
user_id=row.get("user_id") or "", user_id=row.get("user_id") or "",
apply_id=row.get("apply_id") or "", apply_id=row.get("apply_id") or "",
created_at=_parse_dt(row.get("created_at")), created_at=_parse_dt(row.get("created_at")),
@@ -79,7 +79,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑,与 Neo4j 中同组实体联合去重 async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑,与 Neo4j 中同组实体联合去重
connector: Neo4jConnector, connector: Neo4jConnector,
group_id: str, # 用于定位neo4j中同一组的实体确保只在同组内去重 end_user_id: str, # 用于定位neo4j中同一组的实体确保只在同组内去重
entity_nodes: List[ExtractedEntityNode], # 输入的实体节点列表,包含待去重的实体 entity_nodes: List[ExtractedEntityNode], # 输入的实体节点列表,包含待去重的实体
statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系 statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系
entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系 entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系
@@ -88,7 +88,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]: ) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
""" """
第二层去重消歧: 第二层去重消歧:
- 以第一层结果为索引,检索相同 group_id 下的 DB 候选实体 - 以第一层结果为索引,检索相同 end_user_id 下的 DB 候选实体
- 将 DB 候选与当前实体集合联合,按既有精确/模糊/LLM 决策进行融合 - 将 DB 候选与当前实体集合联合,按既有精确/模糊/LLM 决策进行融合
- 返回融合后的实体与重定向后的边(边已指向规范 ID优先 DB ID - 返回融合后的实体与重定向后的边(边已指向规范 ID优先 DB ID
""" """
@@ -102,7 +102,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
] ]
candidates_map = await get_dedup_candidates_for_entities( # 从 Neo4j 中查询候选实体并将结果赋值给candidates_map等待异步操作完成 candidates_map = await get_dedup_candidates_for_entities( # 从 Neo4j 中查询候选实体并将结果赋值给candidates_map等待异步操作完成
connector=connector, group_id=group_id, connector=connector, end_user_id=end_user_id,
entities=incoming_rows, # 传入参数:第一层实体的核心信息(作为查询索引) entities=incoming_rows, # 传入参数:第一层实体的核心信息(作为查询索引)
use_contains_fallback=True # 传入参数:启用 “包含关系” 作为匹配失败的降级策略若精确匹配无结果用包含关系召回候选与src\database\cypher_queries.py的307产生联动 use_contains_fallback=True # 传入参数:启用 “包含关系” 作为匹配失败的降级策略若精确匹配无结果用包含关系召回候选与src\database\cypher_queries.py的307产生联动
) )

View File

@@ -57,11 +57,11 @@ async def dedup_layers_and_merge_and_return(
if pipeline_config is None: if pipeline_config is None:
raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return") raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return")
# 先探测 group_id决定报告写入策略 # 先探测 end_user_id决定报告写入策略
group_id: Optional[str] = None end_user_id: Optional[str] = None
for dd in dialog_data_list: for dd in dialog_data_list:
group_id = getattr(dd, "group_id", None) end_user_id = getattr(dd, "end_user_id", None)
if group_id: if end_user_id:
break break
# 第一层去重消歧 # 第一层去重消歧
@@ -82,11 +82,11 @@ async def dedup_layers_and_merge_and_return(
# 第二层去重消歧:与 Neo4j 中同组实体联合融合 # 第二层去重消歧:与 Neo4j 中同组实体联合融合
try: try:
if group_id: if end_user_id:
if connector: if connector:
fused_entity_nodes, fused_statement_entity_edges, fused_entity_entity_edges = await second_layer_dedup_and_merge_with_neo4j( fused_entity_nodes, fused_statement_entity_edges, fused_entity_entity_edges = await second_layer_dedup_and_merge_with_neo4j(
connector=connector, connector=connector,
group_id=group_id, end_user_id=end_user_id,
entity_nodes=dedup_entity_nodes, entity_nodes=dedup_entity_nodes,
statement_entity_edges=dedup_statement_entity_edges, statement_entity_edges=dedup_statement_entity_edges,
entity_entity_edges=dedup_entity_entity_edges, entity_entity_edges=dedup_entity_entity_edges,
@@ -96,7 +96,7 @@ async def dedup_layers_and_merge_and_return(
else: else:
print("Skip second-layer dedup: missing connector") print("Skip second-layer dedup: missing connector")
else: else:
print("Skip second-layer dedup: missing group_id") print("Skip second-layer dedup: missing end_user_id")
except Exception as e: except Exception as e:
print(f"Second-layer dedup failed: {e}") print(f"Second-layer dedup failed: {e}")

View File

@@ -287,7 +287,7 @@ class ExtractionOrchestrator:
for d_idx, dialog in enumerate(dialog_data_list): for d_idx, dialog in enumerate(dialog_data_list):
dialogue_content = dialog.content if self.config.statement_extraction.include_dialogue_context else None dialogue_content = dialog.content if self.config.statement_extraction.include_dialogue_context else None
for c_idx, chunk in enumerate(dialog.chunks): for c_idx, chunk in enumerate(dialog.chunks):
all_chunks.append((chunk, dialog.group_id, dialogue_content)) all_chunks.append((chunk, dialog.end_user_id, dialogue_content))
chunk_metadata.append((d_idx, c_idx)) chunk_metadata.append((d_idx, c_idx))
logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取") logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取")
@@ -299,9 +299,9 @@ class ExtractionOrchestrator:
# 全局并行处理所有分块 # 全局并行处理所有分块
async def extract_for_chunk(chunk_data, chunk_index): async def extract_for_chunk(chunk_data, chunk_index):
nonlocal completed_chunks nonlocal completed_chunks
chunk, group_id, dialogue_content = chunk_data chunk, end_user_id, dialogue_content = chunk_data
try: try:
statements = await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content) statements = await self.statement_extractor._extract_statements(chunk, end_user_id, dialogue_content)
# 流式输出:每提取完一个分块的陈述句,立即发送进度 # 流式输出:每提取完一个分块的陈述句,立即发送进度
# 注意:只在试运行模式下发送陈述句详情,正式模式不发送 # 注意:只在试运行模式下发送陈述句详情,正式模式不发送
@@ -569,32 +569,32 @@ class ExtractionOrchestrator:
if dialog_data_list and hasattr(dialog_data_list[0], 'config_id'): if dialog_data_list and hasattr(dialog_data_list[0], 'config_id'):
config_id = dialog_data_list[0].config_id config_id = dialog_data_list[0].config_id
# 加载DataConfig # 加载MemoryConfig
data_config = None memory_config = None
if config_id: if config_id:
try: try:
from app.db import SessionLocal from app.db import SessionLocal
from app.repositories.data_config_repository import DataConfigRepository from app.repositories.memory_config_repository import MemoryConfigRepository
db = SessionLocal() db = SessionLocal()
try: try:
data_config = DataConfigRepository.get_by_id(db, config_id) memory_config = MemoryConfigRepository.get_by_id(db, config_id)
finally: finally:
db.close() db.close()
if data_config and not data_config.emotion_enabled: if memory_config and not memory_config.emotion_enabled:
logger.info("情绪提取已在配置中禁用,跳过情绪提取") logger.info("情绪提取已在配置中禁用,跳过情绪提取")
return [{} for _ in dialog_data_list] return [{} for _ in dialog_data_list]
except Exception as e: except Exception as e:
logger.warning(f"加载DataConfig失败: {e},将跳过情绪提取") logger.warning(f"加载MemoryConfig失败: {e},将跳过情绪提取")
return [{} for _ in dialog_data_list] return [{} for _ in dialog_data_list]
else: else:
logger.info("未找到config_id跳过情绪提取") logger.info("未找到config_id跳过情绪提取")
return [{} for _ in dialog_data_list] return [{} for _ in dialog_data_list]
# 如果配置未启用情绪提取,直接返回空映射 # 如果配置未启用情绪提取,直接返回空映射
if not data_config or not data_config.emotion_enabled: if not memory_config or not memory_config.emotion_enabled:
logger.info("情绪提取未启用,跳过") logger.info("情绪提取未启用,跳过")
return [{} for _ in dialog_data_list] return [{} for _ in dialog_data_list]
@@ -608,7 +608,7 @@ class ExtractionOrchestrator:
total_statements += 1 total_statements += 1
# 只处理用户的陈述句 (role 为 "user") # 只处理用户的陈述句 (role 为 "user")
if hasattr(statement, 'speaker') and statement.speaker == "user": if hasattr(statement, 'speaker') and statement.speaker == "user":
all_statements.append((statement, data_config)) all_statements.append((statement, memory_config))
statement_metadata.append((d_idx, statement.id)) statement_metadata.append((d_idx, statement.id))
filtered_statements += 1 filtered_statements += 1
@@ -617,7 +617,7 @@ class ExtractionOrchestrator:
# 初始化情绪提取服务 # 初始化情绪提取服务
from app.services.emotion_extraction_service import EmotionExtractionService from app.services.emotion_extraction_service import EmotionExtractionService
emotion_service = EmotionExtractionService( emotion_service = EmotionExtractionService(
llm_id=data_config.emotion_model_id if data_config.emotion_model_id else None llm_id=memory_config.emotion_model_id if memory_config.emotion_model_id else None
) )
# 全局并行处理所有陈述句 # 全局并行处理所有陈述句
@@ -992,9 +992,7 @@ class ExtractionOrchestrator:
id=dialog_data.id, id=dialog_data.id,
name=f"Dialog_{dialog_data.id}", # 添加必需的 name 字段 name=f"Dialog_{dialog_data.id}", # 添加必需的 name 字段
ref_id=dialog_data.ref_id, ref_id=dialog_data.ref_id,
group_id=dialog_data.group_id, end_user_id=dialog_data.end_user_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
content=dialog_data.context.content if dialog_data.context else "", content=dialog_data.context.content if dialog_data.context else "",
dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None, dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None,
@@ -1012,9 +1010,7 @@ class ExtractionOrchestrator:
id=chunk.id, id=chunk.id,
name=f"Chunk_{chunk.id}", # 添加必需的 name 字段 name=f"Chunk_{chunk.id}", # 添加必需的 name 字段
dialog_id=dialog_data.id, dialog_id=dialog_data.id,
group_id=dialog_data.group_id, end_user_id=dialog_data.end_user_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
content=chunk.content, content=chunk.content,
chunk_embedding=chunk.chunk_embedding, chunk_embedding=chunk.chunk_embedding,
@@ -1035,9 +1031,7 @@ class ExtractionOrchestrator:
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段 stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段 temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
group_id=dialog_data.group_id, end_user_id=dialog_data.end_user_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
statement=statement.statement, statement=statement.statement,
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段 speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
@@ -1060,9 +1054,7 @@ class ExtractionOrchestrator:
statement_chunk_edge = StatementChunkEdge( statement_chunk_edge = StatementChunkEdge(
source=statement.id, source=statement.id,
target=chunk.id, target=chunk.id,
group_id=dialog_data.group_id, end_user_id=dialog_data.end_user_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
) )
@@ -1072,13 +1064,16 @@ class ExtractionOrchestrator:
if statement.triplet_extraction_info: if statement.triplet_extraction_info:
triplet_info = statement.triplet_extraction_info triplet_info = statement.triplet_extraction_info
# 创建实体索引到ID的映射 # 创建实体索引到ID的映射(支持多种索引方式)
entity_idx_to_id = {} entity_idx_to_id = {}
# 创建实体节点 # 创建实体节点
for entity_idx, entity in enumerate(triplet_info.entities): for entity_idx, entity in enumerate(triplet_info.entities):
# 映射实体索引到实体ID # 映射实体索引到实体ID(使用多个键以提高容错性)
# 1. 使用实体自己的 entity_idx
entity_idx_to_id[entity.entity_idx] = entity.id entity_idx_to_id[entity.entity_idx] = entity.id
# 2. 使用枚举索引从0开始
entity_idx_to_id[entity_idx] = entity.id
if entity.id not in entity_id_set: if entity.id not in entity_id_set:
entity_connect_strength = getattr(entity, 'connect_strength', 'Strong') entity_connect_strength = getattr(entity, 'connect_strength', 'Strong')
@@ -1095,9 +1090,7 @@ class ExtractionOrchestrator:
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
name_embedding=getattr(entity, 'name_embedding', None), name_embedding=getattr(entity, 'name_embedding', None),
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记 is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
group_id=dialog_data.group_id, end_user_id=dialog_data.end_user_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at, expired_at=dialog_data.expired_at,
@@ -1112,9 +1105,7 @@ class ExtractionOrchestrator:
source=statement.id, source=statement.id,
target=entity.id, target=entity.id,
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong',
group_id=dialog_data.group_id, end_user_id=dialog_data.end_user_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
) )
@@ -1134,9 +1125,7 @@ class ExtractionOrchestrator:
relation_type=triplet.predicate, relation_type=triplet.predicate,
statement=statement.statement, statement=statement.statement,
source_statement_id=statement.id, source_statement_id=statement.id,
group_id=dialog_data.group_id, end_user_id=dialog_data.end_user_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
created_at=dialog_data.created_at, created_at=dialog_data.created_at,
expired_at=dialog_data.expired_at, expired_at=dialog_data.expired_at,
@@ -1163,9 +1152,18 @@ class ExtractionOrchestrator:
relationship_result relationship_result
) )
else: else:
logger.warning( # 改进的警告信息,包含更多调试信息
f"跳过三元组 - 无法找到实体ID: subject_id={triplet.subject_id}, " missing_subject = "subject" if not subject_entity_id else ""
f"object_id={triplet.object_id}, statement_id={statement.id}" missing_object = "object" if not object_entity_id else ""
missing_both = " and " if (not subject_entity_id and not object_entity_id) else ""
logger.debug(
f"跳过三元组 - 无法找到{missing_subject}{missing_both}{missing_object}实体ID: "
f"subject_id={triplet.subject_id} ({triplet.subject_name}), "
f"object_id={triplet.object_id} ({triplet.object_name}), "
f"predicate={triplet.predicate}, "
f"statement_id={statement.id}, "
f"available_indices={sorted(entity_idx_to_id.keys())}"
) )
logger.info( logger.info(
@@ -1763,14 +1761,14 @@ class ExtractionOrchestrator:
async def get_chunked_dialogs( async def get_chunked_dialogs(
chunker_strategy: str = "RecursiveChunker", chunker_strategy: str = "RecursiveChunker",
group_id: str = "group_1", end_user_id: str = "group_1",
indices: Optional[List[int]] = None, indices: Optional[List[int]] = None,
) -> List[DialogData]: ) -> List[DialogData]:
"""从测试数据生成分块对话 """从测试数据生成分块对话
Args: Args:
chunker_strategy: 分块策略(默认: RecursiveChunker chunker_strategy: 分块策略(默认: RecursiveChunker
group_id: 组ID end_user_id: 组ID
indices: 要处理的数据索引列表(可选) indices: 要处理的数据索引列表(可选)
Returns: Returns:
@@ -1834,7 +1832,7 @@ async def get_chunked_dialogs(
dialog_data = DialogData( dialog_data = DialogData(
context=conversation_context, context=conversation_context,
ref_id=data['id'], ref_id=data['id'],
group_id=group_id, end_user_id=end_user_id,
metadata=dialog_metadata, metadata=dialog_metadata,
) )
@@ -1936,7 +1934,7 @@ async def get_chunked_dialogs_from_preprocessed(
async def get_chunked_dialogs_with_preprocessing( async def get_chunked_dialogs_with_preprocessing(
chunker_strategy: str = "RecursiveChunker", chunker_strategy: str = "RecursiveChunker",
group_id: str = "default", end_user_id: str = "default",
user_id: str = "default", user_id: str = "default",
apply_id: str = "default", apply_id: str = "default",
indices: Optional[List[int]] = None, indices: Optional[List[int]] = None,
@@ -1948,7 +1946,7 @@ async def get_chunked_dialogs_with_preprocessing(
Args: Args:
chunker_strategy: 分块策略 chunker_strategy: 分块策略
group_id: 组ID end_user_id: 组ID
user_id: 用户ID user_id: 用户ID
apply_id: 应用ID apply_id: 应用ID
indices: 要处理的数据索引列表 indices: 要处理的数据索引列表
@@ -1976,11 +1974,9 @@ async def get_chunked_dialogs_with_preprocessing(
indices=indices, indices=indices,
) )
# 设置 group_id, user_id, apply_id # 设置 end_user_id
for dd in preprocessed_data: for dd in preprocessed_data:
dd.group_id = group_id dd.end_user_id = end_user_id
dd.user_id = user_id
dd.apply_id = apply_id
# 步骤2: 语义剪枝 # 步骤2: 语义剪枝
try: try:

View File

@@ -193,9 +193,9 @@ async def _process_chunk_summary(
node = MemorySummaryNode( node = MemorySummaryNode(
id=uuid4().hex, id=uuid4().hex,
name=title if title else f"MemorySummaryChunk_{chunk.id}", name=title if title else f"MemorySummaryChunk_{chunk.id}",
group_id=dialog.group_id, end_user_id=dialog.end_user_id,
user_id=dialog.user_id, user_id=dialog.end_user_id,
apply_id=dialog.apply_id, apply_id=dialog.end_user_id,
run_id=dialog.run_id, # 使用 dialog 的 run_id run_id=dialog.run_id, # 使用 dialog 的 run_id
created_at=datetime.now(), created_at=datetime.now(),
expired_at=datetime(9999, 12, 31), expired_at=datetime(9999, 12, 31),

View File

@@ -82,12 +82,12 @@ class StatementExtractor:
logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty") logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty")
return None return None
async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]: async def _extract_statements(self, chunk, end_user_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
"""Process a single chunk and return extracted statements """Process a single chunk and return extracted statements
Args: Args:
chunk: Chunk object to process chunk: Chunk object to process
group_id: Group ID to assign to all statements in this chunk end_user_id: Group ID to assign to all statements in this chunk
dialogue_content: Full dialogue content to provide as context dialogue_content: Full dialogue content to provide as context
Returns: Returns:
@@ -158,7 +158,7 @@ class StatementExtractor:
temporal_info=temporal_type, temporal_info=temporal_type,
relevence_info=relevence_info, relevence_info=relevence_info,
chunk_id=chunk.id, chunk_id=chunk.id,
group_id=group_id, end_user_id=end_user_id,
speaker=chunk_speaker, speaker=chunk_speaker,
) )
@@ -184,10 +184,10 @@ class StatementExtractor:
logger.info(f"Processing {len(chunks_to_process)} chunks for statement extraction") logger.info(f"Processing {len(chunks_to_process)} chunks for statement extraction")
# Process all chunks concurrently, passing the group_id and dialogue content from dialog_data # Process all chunks concurrently, passing the end_user_id and dialogue content from dialog_data
dialogue_content = dialog_data.content if self.config.include_dialogue_context else None dialogue_content = dialog_data.content if self.config.include_dialogue_context else None
results = await asyncio.gather( results = await asyncio.gather(
*[self._extract_statements(chunk, dialog_data.group_id, dialogue_content) for chunk in chunks_to_process], *[self._extract_statements(chunk, dialog_data.end_user_id, dialogue_content) for chunk in chunks_to_process],
return_exceptions=True return_exceptions=True
) )
@@ -225,7 +225,7 @@ class StatementExtractor:
for i, statement in enumerate(statements, 1): for i, statement in enumerate(statements, 1):
f.write(f"Statement {i}:\n") f.write(f"Statement {i}:\n")
f.write(f"Id: {statement.id}\n") f.write(f"Id: {statement.id}\n")
f.write(f"Group Id: {statement.group_id}\n") f.write(f"Group Id: {statement.end_user_id}\n")
f.write(f"Content: {statement.statement}\n") f.write(f"Content: {statement.statement}\n")
f.write(f"Type: {statement.stmt_type.value}\n") f.write(f"Type: {statement.stmt_type.value}\n")
f.write(f"Temporal Info: {statement.temporal_info.value}\n") f.write(f"Temporal Info: {statement.temporal_info.value}\n")
@@ -298,7 +298,7 @@ class StatementExtractor:
dialog_sections.append({ dialog_sections.append({
"dialog_id": dialog.ref_id, "dialog_id": dialog.ref_id,
"group_id": dialog.group_id, "end_user_id": dialog.end_user_id,
"content": dialog.content if getattr(dialog, "content", None) else "", "content": dialog.content if getattr(dialog, "content", None) else "",
"strong": strong_relations, "strong": strong_relations,
"weak": weak_relations, "weak": weak_relations,
@@ -312,7 +312,7 @@ class StatementExtractor:
for idx, section in enumerate(dialog_sections, 1): for idx, section in enumerate(dialog_sections, 1):
f.write(f"Dialog {idx}:\n") f.write(f"Dialog {idx}:\n")
f.write(f"Dialog ID: {section.get('dialog_id', '')}\n") f.write(f"Dialog ID: {section.get('dialog_id', '')}\n")
f.write(f"Group ID: {section.get('group_id', '')}\n") f.write(f"Group ID: {section.get('end_user_id', '')}\n")
f.write("Content:\n") f.write("Content:\n")
f.write(f"{section.get('content', '')}\n") f.write(f"{section.get('content', '')}\n")
f.write("-" * 40 + "\n\n") f.write("-" * 40 + "\n\n")

View File

@@ -132,7 +132,7 @@ class TemporalExtractor:
prompt_logger.info("") prompt_logger.info("")
prompt_logger.info("=== TEMPORAL EXTRACTION RESULTS ===") prompt_logger.info("=== TEMPORAL EXTRACTION RESULTS ===")
prompt_logger.info( prompt_logger.info(
f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}" f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}"
) )
except Exception: except Exception:
pass pass

View File

@@ -116,7 +116,7 @@ class TripletExtractor:
logger.info(f"Processing {len(all_statements)} statements for triplet extraction...") logger.info(f"Processing {len(all_statements)} statements for triplet extraction...")
try: try:
prompt_logger.info( prompt_logger.info(
f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}, statements_to_process={len(all_statements)}" f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}, statements_to_process={len(all_statements)}"
) )
except Exception: except Exception:
pass pass

View File

@@ -75,7 +75,7 @@ class AccessHistoryManager:
self, self,
node_id: str, node_id: str,
node_label: str, node_label: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_time: Optional[datetime] = None current_time: Optional[datetime] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
@@ -91,7 +91,7 @@ class AccessHistoryManager:
Args: Args:
node_id: 节点ID node_id: 节点ID
node_label: 节点标签Statement, ExtractedEntity, MemorySummary node_label: 节点标签Statement, ExtractedEntity, MemorySummary
group_id: 组ID可选用于过滤 end_user_id: 组ID可选用于过滤
current_time: 当前时间(可选,默认使用系统时间) current_time: 当前时间(可选,默认使用系统时间)
Returns: Returns:
@@ -123,7 +123,7 @@ class AccessHistoryManager:
for attempt in range(self.max_retries): for attempt in range(self.max_retries):
try: try:
# 步骤1读取当前节点状态 # 步骤1读取当前节点状态
node_data = await self._fetch_node(node_id, node_label, group_id) node_data = await self._fetch_node(node_id, node_label, end_user_id)
if not node_data: if not node_data:
raise ValueError( raise ValueError(
@@ -142,7 +142,7 @@ class AccessHistoryManager:
node_id=node_id, node_id=node_id,
node_label=node_label, node_label=node_label,
update_data=update_data, update_data=update_data,
group_id=group_id end_user_id=end_user_id
) )
logger.info( logger.info(
@@ -172,7 +172,7 @@ class AccessHistoryManager:
self, self,
node_ids: List[str], node_ids: List[str],
node_label: str, node_label: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_time: Optional[datetime] = None current_time: Optional[datetime] = None
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
@@ -184,7 +184,7 @@ class AccessHistoryManager:
Args: Args:
node_ids: 节点ID列表 node_ids: 节点ID列表
node_label: 节点标签(所有节点必须是同一类型) node_label: 节点标签(所有节点必须是同一类型)
group_id: 组ID可选 end_user_id: 组ID可选
current_time: 当前时间(可选) current_time: 当前时间(可选)
Returns: Returns:
@@ -202,7 +202,7 @@ class AccessHistoryManager:
task = self.record_access( task = self.record_access(
node_id=node_id, node_id=node_id,
node_label=node_label, node_label=node_label,
group_id=group_id, end_user_id=end_user_id,
current_time=current_time current_time=current_time
) )
tasks.append(task) tasks.append(task)
@@ -235,7 +235,7 @@ class AccessHistoryManager:
self, self,
node_id: str, node_id: str,
node_label: str, node_label: str,
group_id: Optional[str] = None end_user_id: Optional[str] = None
) -> Tuple[ConsistencyCheckResult, Optional[str]]: ) -> Tuple[ConsistencyCheckResult, Optional[str]]:
""" """
检查节点数据的一致性 检查节点数据的一致性
@@ -249,14 +249,14 @@ class AccessHistoryManager:
Args: Args:
node_id: 节点ID node_id: 节点ID
node_label: 节点标签 node_label: 节点标签
group_id: 组ID可选 end_user_id: 组ID可选
Returns: Returns:
Tuple[ConsistencyCheckResult, Optional[str]]: Tuple[ConsistencyCheckResult, Optional[str]]:
- 一致性检查结果枚举 - 一致性检查结果枚举
- 错误描述(如果不一致) - 错误描述(如果不一致)
""" """
node_data = await self._fetch_node(node_id, node_label, group_id) node_data = await self._fetch_node(node_id, node_label, end_user_id)
if not node_data: if not node_data:
return ConsistencyCheckResult.CONSISTENT, None return ConsistencyCheckResult.CONSISTENT, None
@@ -305,7 +305,7 @@ class AccessHistoryManager:
async def check_batch_consistency( async def check_batch_consistency(
self, self,
node_label: str, node_label: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 1000 limit: int = 1000
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
@@ -313,7 +313,7 @@ class AccessHistoryManager:
Args: Args:
node_label: 节点标签 node_label: 节点标签
group_id: 组ID可选 end_user_id: 组ID可选
limit: 检查的最大节点数 limit: 检查的最大节点数
Returns: Returns:
@@ -329,16 +329,16 @@ class AccessHistoryManager:
MATCH (n:{node_label}) MATCH (n:{node_label})
WHERE n.access_history IS NOT NULL WHERE n.access_history IS NOT NULL
""" """
if group_id: if end_user_id:
query += " AND n.group_id = $group_id" query += " AND n.end_user_id = $end_user_id"
query += """ query += """
RETURN n.id as id RETURN n.id as id
LIMIT $limit LIMIT $limit
""" """
params = {"limit": limit} params = {"limit": limit}
if group_id: if end_user_id:
params["group_id"] = group_id params["end_user_id"] = end_user_id
results = await self.connector.execute_query(query, **params) results = await self.connector.execute_query(query, **params)
node_ids = [r['id'] for r in results] node_ids = [r['id'] for r in results]
@@ -351,7 +351,7 @@ class AccessHistoryManager:
result, message = await self.check_consistency( result, message = await self.check_consistency(
node_id=node_id, node_id=node_id,
node_label=node_label, node_label=node_label,
group_id=group_id end_user_id=end_user_id
) )
if result == ConsistencyCheckResult.CONSISTENT: if result == ConsistencyCheckResult.CONSISTENT:
@@ -387,7 +387,7 @@ class AccessHistoryManager:
self, self,
node_id: str, node_id: str,
node_label: str, node_label: str,
group_id: Optional[str] = None end_user_id: Optional[str] = None
) -> bool: ) -> bool:
""" """
自动修复节点的数据不一致问题 自动修复节点的数据不一致问题
@@ -401,7 +401,7 @@ class AccessHistoryManager:
Args: Args:
node_id: 节点ID node_id: 节点ID
node_label: 节点标签 node_label: 节点标签
group_id: 组ID可选 end_user_id: 组ID可选
Returns: Returns:
bool: 修复成功返回True否则返回False bool: 修复成功返回True否则返回False
@@ -411,7 +411,7 @@ class AccessHistoryManager:
result, message = await self.check_consistency( result, message = await self.check_consistency(
node_id=node_id, node_id=node_id,
node_label=node_label, node_label=node_label,
group_id=group_id end_user_id=end_user_id
) )
if result == ConsistencyCheckResult.CONSISTENT: if result == ConsistencyCheckResult.CONSISTENT:
@@ -419,7 +419,7 @@ class AccessHistoryManager:
return True return True
# 获取节点数据 # 获取节点数据
node_data = await self._fetch_node(node_id, node_label, group_id) node_data = await self._fetch_node(node_id, node_label, end_user_id)
if not node_data: if not node_data:
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]") logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
return False return False
@@ -457,8 +457,8 @@ class AccessHistoryManager:
query = f""" query = f"""
MATCH (n:{node_label} {{id: $node_id}}) MATCH (n:{node_label} {{id: $node_id}})
""" """
if group_id: if end_user_id:
query += " WHERE n.group_id = $group_id" query += " WHERE n.end_user_id = $end_user_id"
query += """ query += """
SET n += $repair_data SET n += $repair_data
RETURN n RETURN n
@@ -468,8 +468,8 @@ class AccessHistoryManager:
'node_id': node_id, 'node_id': node_id,
'repair_data': repair_data 'repair_data': repair_data
} }
if group_id: if end_user_id:
params['group_id'] = group_id params['end_user_id'] = end_user_id
await self.connector.execute_query(query, **params) await self.connector.execute_query(query, **params)
@@ -491,7 +491,7 @@ class AccessHistoryManager:
self, self,
node_id: str, node_id: str,
node_label: str, node_label: str,
group_id: Optional[str] = None end_user_id: Optional[str] = None
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
""" """
获取节点数据 获取节点数据
@@ -499,7 +499,7 @@ class AccessHistoryManager:
Args: Args:
node_id: 节点ID node_id: 节点ID
node_label: 节点标签 node_label: 节点标签
group_id: 组ID可选 end_user_id: 组ID可选
Returns: Returns:
Optional[Dict[str, Any]]: 节点数据如果不存在返回None Optional[Dict[str, Any]]: 节点数据如果不存在返回None
@@ -507,8 +507,8 @@ class AccessHistoryManager:
query = f""" query = f"""
MATCH (n:{node_label} {{id: $node_id}}) MATCH (n:{node_label} {{id: $node_id}})
""" """
if group_id: if end_user_id:
query += " WHERE n.group_id = $group_id" query += " WHERE n.end_user_id = $end_user_id"
query += """ query += """
RETURN n.id as id, RETURN n.id as id,
n.importance_score as importance_score, n.importance_score as importance_score,
@@ -519,8 +519,8 @@ class AccessHistoryManager:
""" """
params = {'node_id': node_id} params = {'node_id': node_id}
if group_id: if end_user_id:
params['group_id'] = group_id params['end_user_id'] = end_user_id
results = await self.connector.execute_query(query, **params) results = await self.connector.execute_query(query, **params)
@@ -585,7 +585,7 @@ class AccessHistoryManager:
node_id: str, node_id: str,
node_label: str, node_label: str,
update_data: Dict[str, Any], update_data: Dict[str, Any],
group_id: Optional[str] = None end_user_id: Optional[str] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
原子性更新节点(使用乐观锁) 原子性更新节点(使用乐观锁)
@@ -597,7 +597,7 @@ class AccessHistoryManager:
node_id: 节点ID node_id: 节点ID
node_label: 节点标签 node_label: 节点标签
update_data: 更新数据 update_data: 更新数据
group_id: 组ID可选 end_user_id: 组ID可选
Returns: Returns:
Dict[str, Any]: 更新后的节点数据 Dict[str, Any]: 更新后的节点数据
@@ -606,13 +606,13 @@ class AccessHistoryManager:
RuntimeError: 如果更新失败或发生版本冲突 RuntimeError: 如果更新失败或发生版本冲突
""" """
# 定义事务函数 # 定义事务函数
async def update_transaction(tx, node_id, node_label, update_data, group_id): async def update_transaction(tx, node_id, node_label, update_data, end_user_id):
# 步骤1读取当前节点并获取版本号 # 步骤1读取当前节点并获取版本号
read_query = f""" read_query = f"""
MATCH (n:{node_label} {{id: $node_id}}) MATCH (n:{node_label} {{id: $node_id}})
""" """
if group_id: if end_user_id:
read_query += " WHERE n.group_id = $group_id" read_query += " WHERE n.end_user_id = $end_user_id"
read_query += """ read_query += """
RETURN n.id as id, RETURN n.id as id,
n.version as version, n.version as version,
@@ -624,8 +624,8 @@ class AccessHistoryManager:
""" """
read_params = {'node_id': node_id} read_params = {'node_id': node_id}
if group_id: if end_user_id:
read_params['group_id'] = group_id read_params['end_user_id'] = end_user_id
read_result = await tx.run(read_query, **read_params) read_result = await tx.run(read_query, **read_params)
current_node = await read_result.single() current_node = await read_result.single()
@@ -656,8 +656,8 @@ class AccessHistoryManager:
# 构建 WHERE 子句 # 构建 WHERE 子句
where_conditions = [] where_conditions = []
if group_id: if end_user_id:
where_conditions.append("n.group_id = $group_id") where_conditions.append("n.end_user_id = $end_user_id")
# 添加版本检查 # 添加版本检查
if current_version > 0: if current_version > 0:
@@ -695,8 +695,8 @@ class AccessHistoryManager:
'last_access_time': update_data['last_access_time'], 'last_access_time': update_data['last_access_time'],
'access_count': update_data['access_count'] 'access_count': update_data['access_count']
} }
if group_id: if end_user_id:
update_params['group_id'] = group_id update_params['end_user_id'] = end_user_id
update_result = await tx.run(update_query, **update_params) update_result = await tx.run(update_query, **update_params)
updated_node = await update_result.single() updated_node = await update_result.single()
@@ -720,7 +720,7 @@ class AccessHistoryManager:
node_id=node_id, node_id=node_id,
node_label=node_label, node_label=node_label,
update_data=update_data, update_data=update_data,
group_id=group_id end_user_id=end_user_id
) )
return result return result
except Exception as e: except Exception as e:

View File

@@ -11,9 +11,10 @@ Functions:
import logging import logging
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
from uuid import UUID
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.repositories.data_config_repository import DataConfigRepository from app.repositories.memory_config_repository import MemoryConfigRepository
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
@@ -61,12 +62,12 @@ def calculate_forgetting_rate(lambda_time: float, lambda_mem: float) -> float:
def load_actr_config_from_db( def load_actr_config_from_db(
db: Session, db: Session,
config_id: Optional[int] = None config_id: Optional[UUID] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
从数据库加载 ACT-R 配置参数 从数据库加载 ACT-R 配置参数
从 PostgreSQL 的 data_config 表读取配置参数, 从 PostgreSQL 的 memory_config 表读取配置参数,
并计算派生参数(如 forgetting_rate 并计算派生参数(如 forgetting_rate
Args: Args:
@@ -99,7 +100,7 @@ def load_actr_config_from_db(
# 从数据库加载配置 # 从数据库加载配置
try: try:
repository = DataConfigRepository() repository = MemoryConfigRepository()
db_config = repository.get_by_id(db, config_id) db_config = repository.get_by_id(db, config_id)
if db_config is None: if db_config is None:
@@ -150,7 +151,7 @@ def load_actr_config_from_db(
def create_actr_calculator_from_config( def create_actr_calculator_from_config(
db: Session, db: Session,
config_id: Optional[int] = None config_id: Optional[UUID] = None
) -> ACTRCalculator: ) -> ACTRCalculator:
""" """
从数据库配置创建 ACTRCalculator 实例 从数据库配置创建 ACTRCalculator 实例
@@ -168,11 +169,6 @@ def create_actr_calculator_from_config(
ValueError: 如果指定的 config_id 不存在 ValueError: 如果指定的 config_id 不存在
Examples: Examples:
>>> from sqlalchemy.orm import Session
>>> db = Session()
>>> calculator = create_actr_calculator_from_config(db, config_id=1)
>>> # 使用计算器
>>> activation = calculator.calculate_memory_activation(...)
""" """
# 加载配置 # 加载配置
config = load_actr_config_from_db(db, config_id) config = load_actr_config_from_db(db, config_id)

View File

@@ -16,6 +16,7 @@ Classes:
import logging import logging
from typing import Dict, Any, Optional from typing import Dict, Any, Optional
from uuid import UUID
from datetime import datetime from datetime import datetime
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy
@@ -66,10 +67,10 @@ class ForgettingScheduler:
async def run_forgetting_cycle( async def run_forgetting_cycle(
self, self,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
max_merge_batch_size: int = 100, max_merge_batch_size: int = 100,
min_days_since_access: int = 30, min_days_since_access: int = 30,
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
db = None db = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
@@ -77,7 +78,7 @@ class ForgettingScheduler:
Args: Args:
group_id: 组 ID可选用于过滤特定组的节点 end_user_id: 组 ID可选用于过滤特定组的节点
max_merge_batch_size: 单次最大融合节点对数(默认 100 max_merge_batch_size: 单次最大融合节点对数(默认 100
min_days_since_access: 最小未访问天数(默认 30 天) min_days_since_access: 最小未访问天数(默认 30 天)
config_id: 配置ID可选用于获取 llm_id config_id: 配置ID可选用于获取 llm_id
@@ -107,19 +108,19 @@ class ForgettingScheduler:
start_time_iso = start_time.isoformat() start_time_iso = start_time.isoformat()
logger.info( logger.info(
f"开始遗忘周期: group_id={group_id}, " f"开始遗忘周期: end_user_id={end_user_id}, "
f"max_batch={max_merge_batch_size}, " f"max_batch={max_merge_batch_size}, "
f"min_days={min_days_since_access}" f"min_days={min_days_since_access}"
) )
try: try:
# 步骤1统计遗忘前的节点数量 # 步骤1统计遗忘前的节点数量
nodes_before = await self._count_knowledge_nodes(group_id) nodes_before = await self._count_knowledge_nodes(end_user_id)
logger.info(f"遗忘前节点总数: {nodes_before}") logger.info(f"遗忘前节点总数: {nodes_before}")
# 步骤2识别可遗忘的节点对 # 步骤2识别可遗忘的节点对
forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes( forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes(
group_id=group_id, end_user_id=end_user_id,
min_days_since_access=min_days_since_access min_days_since_access=min_days_since_access
) )
@@ -213,7 +214,7 @@ class ForgettingScheduler:
'statement_text': pair['statement_text'], 'statement_text': pair['statement_text'],
'statement_activation': pair['statement_activation'], 'statement_activation': pair['statement_activation'],
'statement_importance': pair['statement_importance'], 'statement_importance': pair['statement_importance'],
'group_id': group_id 'end_user_id': end_user_id
} }
entity_node = { entity_node = {
@@ -222,7 +223,7 @@ class ForgettingScheduler:
'entity_type': pair['entity_type'], 'entity_type': pair['entity_type'],
'entity_activation': pair['entity_activation'], 'entity_activation': pair['entity_activation'],
'entity_importance': pair['entity_importance'], 'entity_importance': pair['entity_importance'],
'group_id': group_id 'end_user_id': end_user_id
} }
# 融合节点 # 融合节点
@@ -262,7 +263,7 @@ class ForgettingScheduler:
continue continue
# 步骤6统计遗忘后的节点数量 # 步骤6统计遗忘后的节点数量
nodes_after = await self._count_knowledge_nodes(group_id) nodes_after = await self._count_knowledge_nodes(end_user_id)
logger.info(f"遗忘后节点总数: {nodes_after}") logger.info(f"遗忘后节点总数: {nodes_after}")
# 步骤7生成遗忘报告 # 步骤7生成遗忘报告
@@ -315,7 +316,7 @@ class ForgettingScheduler:
async def _count_knowledge_nodes( async def _count_knowledge_nodes(
self, self,
group_id: Optional[str] = None end_user_id: Optional[str] = None
) -> int: ) -> int:
""" """
统计知识层节点总数 统计知识层节点总数
@@ -323,7 +324,7 @@ class ForgettingScheduler:
统计 Statement、ExtractedEntity 和 MemorySummary 节点的总数。 统计 Statement、ExtractedEntity 和 MemorySummary 节点的总数。
Args: Args:
group_id: 组 ID可选用于过滤特定组的节点 end_user_id: 组 ID可选用于过滤特定组的节点
Returns: Returns:
int: 知识层节点总数 int: 知识层节点总数
@@ -333,16 +334,16 @@ class ForgettingScheduler:
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary) WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
""" """
if group_id: if end_user_id:
query += " AND n.group_id = $group_id" query += " AND n.end_user_id = $end_user_id"
query += """ query += """
RETURN count(n) as total RETURN count(n) as total
""" """
params = {} params = {}
if group_id: if end_user_id:
params['group_id'] = group_id end_user_id['end_user_id'] = end_user_id
results = await self.connector.execute_query(query, **params) results = await self.connector.execute_query(query, **params)

View File

@@ -13,6 +13,7 @@ Classes:
import logging import logging
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from uuid import UUID
from datetime import datetime, timedelta from datetime import datetime, timedelta
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
@@ -90,7 +91,7 @@ class ForgettingStrategy:
async def find_forgettable_nodes( async def find_forgettable_nodes(
self, self,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
min_days_since_access: int = 30 min_days_since_access: int = 30
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
@@ -102,7 +103,7 @@ class ForgettingStrategy:
3. Statement 和 Entity 之间存在关系边 3. Statement 和 Entity 之间存在关系边
Args: Args:
group_id: 组 ID可选用于过滤特定组的节点 end_user_id: 组 ID可选用于过滤特定组的节点
min_days_since_access: 最小未访问天数(默认 30 天) min_days_since_access: 最小未访问天数(默认 30 天)
Returns: Returns:
@@ -136,8 +137,8 @@ class ForgettingStrategy:
AND (e.entity_type IS NULL OR e.entity_type <> 'Person') AND (e.entity_type IS NULL OR e.entity_type <> 'Person')
""" """
if group_id: if end_user_id:
query += " AND s.group_id = $group_id AND e.group_id = $group_id" query += " AND s.end_user_id = $end_user_id AND e.end_user_id = $end_user_id"
query += """ query += """
RETURN s.id as statement_id, RETURN s.id as statement_id,
@@ -159,8 +160,8 @@ class ForgettingStrategy:
'threshold': self.forgetting_threshold, 'threshold': self.forgetting_threshold,
'cutoff_time': cutoff_time_iso 'cutoff_time': cutoff_time_iso
} }
if group_id: if end_user_id:
params['group_id'] = group_id params['end_user_id'] = end_user_id
results = await self.connector.execute_query(query, **params) results = await self.connector.execute_query(query, **params)
@@ -176,7 +177,7 @@ class ForgettingStrategy:
self, self,
statement_node: Dict[str, Any], statement_node: Dict[str, Any],
entity_node: Dict[str, Any], entity_node: Dict[str, Any],
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
db = None db = None
) -> str: ) -> str:
""" """
@@ -247,8 +248,8 @@ class ForgettingStrategy:
entity_activation = entity_node['entity_activation'] entity_activation = entity_node['entity_activation']
entity_importance = entity_node['entity_importance'] entity_importance = entity_node['entity_importance']
# 获取 group_id从 statement 或 entity 节点) # 获取 end_user_id从 statement 或 entity 节点)
group_id = statement_node.get('group_id') or entity_node.get('group_id') end_user_id = statement_node.get('end_user_id') or entity_node.get('end_user_id')
# 生成摘要内容 # 生成摘要内容
summary_text = await self._generate_summary( summary_text = await self._generate_summary(
@@ -325,7 +326,7 @@ class ForgettingStrategy:
last_access_time: $current_time, last_access_time: $current_time,
access_count: 1, access_count: 1,
version: 1, version: 1,
group_id: $group_id, end_user_id: $end_user_id,
created_at: datetime($current_time), created_at: datetime($current_time),
merged_at: datetime($current_time) merged_at: datetime($current_time)
}) })
@@ -423,7 +424,7 @@ class ForgettingStrategy:
'inherited_activation': inherited_activation, 'inherited_activation': inherited_activation,
'inherited_importance': inherited_importance, 'inherited_importance': inherited_importance,
'current_time': current_time_iso, 'current_time': current_time_iso,
'group_id': group_id 'end_user_id': end_user_id
} }
try: try:
@@ -462,7 +463,7 @@ class ForgettingStrategy:
statement_text: str, statement_text: str,
entity_name: str, entity_name: str,
entity_type: str, entity_type: str,
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
db = None db = None
) -> str: ) -> str:
""" """
@@ -527,7 +528,7 @@ class ForgettingStrategy:
statement_text, entity_name, entity_type statement_text, entity_name, entity_type
) )
async def _get_llm_client(self, db, config_id: int): async def _get_llm_client(self, db, config_id: UUID):
""" """
从数据库获取 LLM 客户端 从数据库获取 LLM 客户端
@@ -539,11 +540,11 @@ class ForgettingStrategy:
LLM 客户端实例,如果无法获取则返回 None LLM 客户端实例,如果无法获取则返回 None
""" """
try: try:
from app.repositories.data_config_repository import DataConfigRepository from app.repositories.memory_config_repository import MemoryConfigRepository
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
# 从数据库读取配置 # 从数据库读取配置
repository = DataConfigRepository() repository = MemoryConfigRepository()
db_config = repository.get_by_id(db, config_id) db_config = repository.get_by_id(db, config_id)
if db_config is None or db_config.llm_id is None: if db_config is None or db_config.llm_id is None:

View File

@@ -37,7 +37,7 @@ __all__ = [
async def run_hybrid_search( async def run_hybrid_search(
query_text: str, query_text: str,
search_type: str = "hybrid", search_type: str = "hybrid",
group_id: str | None = None, end_user_id: str | None = None,
apply_id: str | None = None, apply_id: str | None = None,
user_id: str | None = None, user_id: str | None = None,
limit: int = 50, limit: int = 50,
@@ -54,7 +54,7 @@ async def run_hybrid_search(
Args: Args:
query_text: 查询文本 query_text: 查询文本
search_type: 搜索类型("hybrid", "keyword", "semantic" search_type: 搜索类型("hybrid", "keyword", "semantic"
group_id: 组ID过滤 end_user_id: 组ID过滤
apply_id: 应用ID过滤 apply_id: 应用ID过滤
user_id: 用户ID过滤 user_id: 用户ID过滤
limit: 每个类别的最大结果数 limit: 每个类别的最大结果数
@@ -104,7 +104,7 @@ async def run_hybrid_search(
# 执行搜索 # 执行搜索
result = await strategy.search( result = await strategy.search(
query_text=query_text, query_text=query_text,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include, include=include,
alpha=alpha, alpha=alpha,

View File

@@ -77,7 +77,7 @@
# async def search( # async def search(
# self, # self,
# query_text: str, # query_text: str,
# group_id: Optional[str] = None, # end_user_id: Optional[str] = None,
# limit: int = 50, # limit: int = 50,
# include: Optional[List[str]] = None, # include: Optional[List[str]] = None,
# **kwargs # **kwargs
@@ -86,7 +86,7 @@
# Args: # Args:
# query_text: 查询文本 # query_text: 查询文本
# group_id: 可选的组ID过滤 # end_user_id: 可选的组ID过滤
# limit: 每个类别的最大结果数 # limit: 每个类别的最大结果数
# include: 要包含的搜索类别列表 # include: 要包含的搜索类别列表
# **kwargs: 其他搜索参数如alpha, use_forgetting_curve # **kwargs: 其他搜索参数如alpha, use_forgetting_curve
@@ -94,7 +94,7 @@
# Returns: # Returns:
# SearchResult: 搜索结果对象 # SearchResult: 搜索结果对象
# """ # """
# logger.info(f"执行混合搜索: query='{query_text}', group_id={group_id}, limit={limit}") # logger.info(f"执行混合搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
# # 从kwargs中获取参数 # # 从kwargs中获取参数
# alpha = kwargs.get("alpha", self.alpha) # alpha = kwargs.get("alpha", self.alpha)
@@ -107,14 +107,14 @@
# # 并行执行关键词搜索和语义搜索 # # 并行执行关键词搜索和语义搜索
# keyword_result = await self.keyword_strategy.search( # keyword_result = await self.keyword_strategy.search(
# query_text=query_text, # query_text=query_text,
# group_id=group_id, # end_user_id=end_user_id,
# limit=limit, # limit=limit,
# include=include_list # include=include_list
# ) # )
# semantic_result = await self.semantic_strategy.search( # semantic_result = await self.semantic_strategy.search(
# query_text=query_text, # query_text=query_text,
# group_id=group_id, # end_user_id=end_user_id,
# limit=limit, # limit=limit,
# include=include_list # include=include_list
# ) # )
@@ -139,7 +139,7 @@
# metadata = self._create_metadata( # metadata = self._create_metadata(
# query_text=query_text, # query_text=query_text,
# search_type="hybrid", # search_type="hybrid",
# group_id=group_id, # end_user_id=end_user_id,
# limit=limit, # limit=limit,
# include=include_list, # include=include_list,
# alpha=alpha, # alpha=alpha,
@@ -165,7 +165,7 @@
# metadata=self._create_metadata( # metadata=self._create_metadata(
# query_text=query_text, # query_text=query_text,
# search_type="hybrid", # search_type="hybrid",
# group_id=group_id, # end_user_id=end_user_id,
# limit=limit, # limit=limit,
# error=str(e) # error=str(e)
# ) # )

View File

@@ -44,7 +44,7 @@ class KeywordSearchStrategy(SearchStrategy):
async def search( async def search(
self, self,
query_text: str, query_text: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 50, limit: int = 50,
include: Optional[List[str]] = None, include: Optional[List[str]] = None,
**kwargs **kwargs
@@ -53,7 +53,7 @@ class KeywordSearchStrategy(SearchStrategy):
Args: Args:
query_text: 查询文本 query_text: 查询文本
group_id: 可选的组ID过滤 end_user_id: 可选的组ID过滤
limit: 每个类别的最大结果数 limit: 每个类别的最大结果数
include: 要包含的搜索类别列表 include: 要包含的搜索类别列表
**kwargs: 其他搜索参数 **kwargs: 其他搜索参数
@@ -61,7 +61,7 @@ class KeywordSearchStrategy(SearchStrategy):
Returns: Returns:
SearchResult: 搜索结果对象 SearchResult: 搜索结果对象
""" """
logger.info(f"执行关键词搜索: query='{query_text}', group_id={group_id}, limit={limit}") logger.info(f"执行关键词搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
# 获取有效的搜索类别 # 获取有效的搜索类别
include_list = self._get_include_list(include) include_list = self._get_include_list(include)
@@ -75,7 +75,7 @@ class KeywordSearchStrategy(SearchStrategy):
results_dict = await search_graph( results_dict = await search_graph(
connector=self.connector, connector=self.connector,
q=query_text, q=query_text,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include_list include=include_list
) )
@@ -84,7 +84,7 @@ class KeywordSearchStrategy(SearchStrategy):
metadata = self._create_metadata( metadata = self._create_metadata(
query_text=query_text, query_text=query_text,
search_type="keyword", search_type="keyword",
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include_list include=include_list
) )
@@ -115,7 +115,7 @@ class KeywordSearchStrategy(SearchStrategy):
metadata=self._create_metadata( metadata=self._create_metadata(
query_text=query_text, query_text=query_text,
search_type="keyword", search_type="keyword",
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
error=str(e) error=str(e)
) )

View File

@@ -58,7 +58,7 @@ class SearchStrategy(ABC):
async def search( async def search(
self, self,
query_text: str, query_text: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 50, limit: int = 50,
include: Optional[List[str]] = None, include: Optional[List[str]] = None,
**kwargs **kwargs
@@ -67,7 +67,7 @@ class SearchStrategy(ABC):
Args: Args:
query_text: 查询文本 query_text: 查询文本
group_id: 可选的组ID过滤 end_user_id: 可选的组ID过滤
limit: 每个类别的最大结果数 limit: 每个类别的最大结果数
include: 要包含的搜索类别列表statements, chunks, entities, summaries include: 要包含的搜索类别列表statements, chunks, entities, summaries
**kwargs: 其他搜索参数 **kwargs: 其他搜索参数
@@ -81,7 +81,7 @@ class SearchStrategy(ABC):
self, self,
query_text: str, query_text: str,
search_type: str, search_type: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 50, limit: int = 50,
**kwargs **kwargs
) -> Dict[str, Any]: ) -> Dict[str, Any]:
@@ -90,7 +90,7 @@ class SearchStrategy(ABC):
Args: Args:
query_text: 查询文本 query_text: 查询文本
search_type: 搜索类型 search_type: 搜索类型
group_id: 组ID end_user_id: 组ID
limit: 结果限制 limit: 结果限制
**kwargs: 其他元数据 **kwargs: 其他元数据
@@ -100,7 +100,7 @@ class SearchStrategy(ABC):
metadata = { metadata = {
"query": query_text, "query": query_text,
"search_type": search_type, "search_type": search_type,
"group_id": group_id, "end_user_id": end_user_id,
"limit": limit, "limit": limit,
"timestamp": datetime.now().isoformat() "timestamp": datetime.now().isoformat()
} }

View File

@@ -85,7 +85,7 @@ class SemanticSearchStrategy(SearchStrategy):
async def search( async def search(
self, self,
query_text: str, query_text: str,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 50, limit: int = 50,
include: Optional[List[str]] = None, include: Optional[List[str]] = None,
**kwargs **kwargs
@@ -94,7 +94,7 @@ class SemanticSearchStrategy(SearchStrategy):
Args: Args:
query_text: 查询文本 query_text: 查询文本
group_id: 可选的组ID过滤 end_user_id: 可选的组ID过滤
limit: 每个类别的最大结果数 limit: 每个类别的最大结果数
include: 要包含的搜索类别列表 include: 要包含的搜索类别列表
**kwargs: 其他搜索参数 **kwargs: 其他搜索参数
@@ -102,7 +102,7 @@ class SemanticSearchStrategy(SearchStrategy):
Returns: Returns:
SearchResult: 搜索结果对象 SearchResult: 搜索结果对象
""" """
logger.info(f"执行语义搜索: query='{query_text}', group_id={group_id}, limit={limit}") logger.info(f"执行语义搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
# 获取有效的搜索类别 # 获取有效的搜索类别
include_list = self._get_include_list(include) include_list = self._get_include_list(include)
@@ -119,7 +119,7 @@ class SemanticSearchStrategy(SearchStrategy):
connector=self.connector, connector=self.connector,
embedder_client=self.embedder_client, embedder_client=self.embedder_client,
query_text=query_text, query_text=query_text,
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include_list include=include_list
) )
@@ -128,7 +128,7 @@ class SemanticSearchStrategy(SearchStrategy):
metadata = self._create_metadata( metadata = self._create_metadata(
query_text=query_text, query_text=query_text,
search_type="semantic", search_type="semantic",
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include_list include=include_list
) )
@@ -159,7 +159,7 @@ class SemanticSearchStrategy(SearchStrategy):
metadata=self._create_metadata( metadata=self._create_metadata(
query_text=query_text, query_text=query_text,
search_type="semantic", search_type="semantic",
group_id=group_id, end_user_id=end_user_id,
limit=limit, limit=limit,
error=str(e) error=str(e)
) )

View File

@@ -23,7 +23,7 @@ async def _load_(data: List[Any]) -> List[Dict]:
target_keys = [ target_keys = [
"id", "id",
"statement", "statement",
"group_id", "end_user_id",
"chunk_id", "chunk_id",
"created_at", "created_at",
"expired_at", "expired_at",
@@ -75,7 +75,7 @@ async def get_data(result):
""" """
EXCLUDE_FIELDS = { EXCLUDE_FIELDS = {
"user_id", "user_id",
"group_id", "end_user_id",
"entity_type", "entity_type",
"connect_strength", "connect_strength",
"relationship_type", "relationship_type",

View File

@@ -62,7 +62,7 @@ class ConfigAuditLogger:
self, self,
config_id: str, config_id: str,
user_id: Optional[str] = None, user_id: Optional[str] = None,
group_id: Optional[str] = None, end_user_id: Optional[str] = None,
success: bool = True, success: bool = True,
details: Optional[Dict[str, Any]] = None details: Optional[Dict[str, Any]] = None
): ):
@@ -72,14 +72,14 @@ class ConfigAuditLogger:
Args: Args:
config_id: 配置 ID config_id: 配置 ID
user_id: 用户 ID可选 user_id: 用户 ID可选
group_id: 组 ID可选 end_user_id: 组 ID可选
success: 是否成功 success: 是否成功
details: 详细信息(可选) details: 详细信息(可选)
""" """
result = "SUCCESS" if success else "FAILED" result = "SUCCESS" if success else "FAILED"
msg = ( msg = (
f"CONFIG_LOAD config_id={config_id} " f"CONFIG_LOAD config_id={config_id} "
f"user={user_id or 'N/A'} group={group_id or 'N/A'} " f"user={user_id or 'N/A'} group={end_user_id or 'N/A'} "
f"result={result}" f"result={result}"
) )
if details: if details:
@@ -121,7 +121,7 @@ class ConfigAuditLogger:
self, self,
operation: str, operation: str,
config_id: str, config_id: str,
group_id: str, end_user_id: str,
success: bool = True, success: bool = True,
duration: Optional[float] = None, duration: Optional[float] = None,
error: Optional[str] = None, error: Optional[str] = None,
@@ -133,7 +133,7 @@ class ConfigAuditLogger:
Args: Args:
operation: 操作类型WRITE, READ 等) operation: 操作类型WRITE, READ 等)
config_id: 配置 ID config_id: 配置 ID
group_id: 组 ID end_user_id: 组 ID
success: 是否成功 success: 是否成功
duration: 操作耗时(秒) duration: 操作耗时(秒)
error: 错误信息(可选) error: 错误信息(可选)
@@ -142,7 +142,7 @@ class ConfigAuditLogger:
result = "SUCCESS" if success else "FAILED" result = "SUCCESS" if success else "FAILED"
msg = ( msg = (
f"{operation.upper()} config_id={config_id} " f"{operation.upper()} config_id={config_id} "
f"group={group_id} result={result}" f"group={end_user_id} result={result}"
) )
if duration is not None: if duration is not None:
msg += f" duration={duration:.2f}s" msg += f" duration={duration:.2f}s"

View File

@@ -4,7 +4,7 @@ from enum import StrEnum, auto
class Field(StrEnum): class Field(StrEnum):
CONTENT_KEY = "page_content" CONTENT_KEY = "page_content"
METADATA_KEY = "metadata" METADATA_KEY = "metadata"
GROUP_KEY = "group_id" GROUP_KEY = "end_user_id"
VECTOR = auto() VECTOR = auto()
# Sparse Vector aims to support full text search # Sparse Vector aims to support full text search
SPARSE_VECTOR = auto() SPARSE_VECTOR = auto()

View File

@@ -16,7 +16,7 @@ class BaiduSearchTool(BuiltinTool):
@property @property
def description(self) -> str: def description(self) -> str:
return "百度搜索 - 搜索引擎服务:网页搜索、新闻搜索、图片搜索、实时结果" return "百度搜索 - 搜索引擎服务:网页搜索、新闻搜索、图片搜索、视频搜索"
def get_required_config_parameters(self) -> List[str]: def get_required_config_parameters(self) -> List[str]:
return ["api_key"] return ["api_key"]
@@ -33,7 +33,7 @@ class BaiduSearchTool(BuiltinTool):
ToolParameter( ToolParameter(
name="search_type", name="search_type",
type=ParameterType.STRING, type=ParameterType.STRING,
description="搜索类型", description="搜索类型, web: 网页搜索news新闻搜索image图片搜索video视频搜索",
required=False, required=False,
default="web", default="web",
enum=["web", "news", "image", "video"] enum=["web", "news", "image", "video"]

View File

@@ -26,7 +26,7 @@ logger = get_config_logger()
def _parse_model_id(model_id: Union[str, UUID, None], model_type: str, def _parse_model_id(model_id: Union[str, UUID, None], model_type: str,
config_id: Optional[int] = None, workspace_id: Optional[UUID] = None) -> Optional[UUID]: config_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None) -> Optional[UUID]:
"""Parse model ID from string or UUID.""" """Parse model ID from string or UUID."""
if model_id is None: if model_id is None:
return None return None
@@ -59,7 +59,7 @@ def validate_model_exists_and_active(
model_type: str, model_type: str,
db: Session, db: Session,
tenant_id: Optional[UUID] = None, tenant_id: Optional[UUID] = None,
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None workspace_id: Optional[UUID] = None
) -> tuple[str, bool]: ) -> tuple[str, bool]:
"""Validate that a model exists and is active. """Validate that a model exists and is active.
@@ -166,7 +166,7 @@ def validate_and_resolve_model_id(
db: Session, db: Session,
tenant_id: Optional[UUID] = None, tenant_id: Optional[UUID] = None,
required: bool = False, required: bool = False,
config_id: Optional[int] = None, config_id: Optional[UUID] = None,
workspace_id: Optional[UUID] = None workspace_id: Optional[UUID] = None
) -> tuple[Optional[UUID], Optional[str]]: ) -> tuple[Optional[UUID], Optional[str]]:
"""Validate and resolve a model ID, checking existence and active status. """Validate and resolve a model ID, checking existence and active status.
@@ -204,7 +204,7 @@ def validate_and_resolve_model_id(
def validate_embedding_model( def validate_embedding_model(
config_id: int, config_id: UUID,
embedding_id: Union[str, UUID, None], embedding_id: Union[str, UUID, None],
db: Session, db: Session,
tenant_id: Optional[UUID] = None, tenant_id: Optional[UUID] = None,
@@ -256,7 +256,7 @@ def validate_embedding_model(
def validate_llm_model( def validate_llm_model(
config_id: int, config_id: UUID,
llm_id: Union[str, UUID, None], llm_id: Union[str, UUID, None],
db: Session, db: Session,
tenant_id: Optional[UUID] = None, tenant_id: Optional[UUID] = None,

View File

@@ -261,7 +261,7 @@ class WorkflowExecutor:
"data": { "data": {
"execution_id": self.execution_id, "execution_id": self.execution_id,
"workspace_id": self.workspace_id, "workspace_id": self.workspace_id,
"timestamp": start_time.isoformat() "timestamp": int(start_time.timestamp() * 1000)
} }
} }
@@ -293,6 +293,7 @@ class WorkflowExecutor:
# Handle custom streaming events (chunks from nodes via stream writer) # Handle custom streaming events (chunks from nodes via stream writer)
chunk_count += 1 chunk_count += 1
event_type = data.get("type", "node_chunk") # "message" or "node_chunk" event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
if event_type in ("message", "node_chunk"):
logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}" logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}"
f"- execution_id: {self.execution_id}") f"- execution_id: {self.execution_id}")
yield { yield {
@@ -307,6 +308,18 @@ class WorkflowExecutor:
"conversation_id": input_data.get("conversation_id"), "conversation_id": input_data.get("conversation_id"),
} }
} }
elif event_type == "node_error":
yield {
"event": event_type, # "message" or "node_chunk"
"data": {
"node_id": data.get("node_id"),
"status": "failed",
"input": data.get("input_data"),
"elapsed_time": data.get("elapsed_time"),
"output": None,
"error": data.get("error")
}
}
elif mode == "debug": elif mode == "debug":
# Handle debug information (node execution status) # Handle debug information (node execution status)
@@ -325,14 +338,15 @@ class WorkflowExecutor:
conversation_id = input_data.get("conversation_id") conversation_id = input_data.get("conversation_id")
logger.info(f"[NODE-START] Node starts execution: {node_name} " logger.info(f"[NODE-START] Node starts execution: {node_name} "
f"- execution_id: {self.execution_id}") f"- execution_id: {self.execution_id}")
yield { yield {
"event": "node_start", "event": "node_start",
"data": { "data": {
"node_id": node_name, "node_id": node_name,
"conversation_id": conversation_id, "conversation_id": conversation_id,
"execution_id": self.execution_id, "execution_id": self.execution_id,
"timestamp": data.get("timestamp"), "timestamp": int(datetime.datetime.fromisoformat(
data.get("timestamp")
).timestamp() * 1000),
} }
} }
elif event_type == "task_result": elif event_type == "task_result":
@@ -351,8 +365,12 @@ class WorkflowExecutor:
"node_id": node_name, "node_id": node_name,
"conversation_id": conversation_id, "conversation_id": conversation_id,
"execution_id": self.execution_id, "execution_id": self.execution_id,
"timestamp": data.get("timestamp"), "timestamp": int(datetime.datetime.fromisoformat(
"state": result.get("node_outputs", {}).get(node_name), data.get("timestamp")
).timestamp() * 1000),
"input": result.get("node_outputs", {}).get(node_name, {}).get("input"),
"output": result.get("node_outputs", {}).get(node_name, {}).get("output"),
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
} }
} }

View File

@@ -544,6 +544,11 @@ class BaseNode(ABC):
"error_node": self.node_id "error_node": self.node_id
} }
else: else:
writer = get_stream_writer()
writer({
"type": "node_error",
**node_output
})
# 无错误边:抛出异常停止工作流 # 无错误边:抛出异常停止工作流
logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}") logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}")
raise Exception(f"节点 {self.node_id} 执行失败: {error_message}") raise Exception(f"节点 {self.node_id} 执行失败: {error_message}")

View File

@@ -0,0 +1,3 @@
from app.core.workflow.nodes.code.node import CodeNode
__all__ = ["CodeNode"]

View File

@@ -0,0 +1,50 @@
from typing import Literal
from pydantic import Field, BaseModel
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
class InputVariable(BaseModel):
name: str = Field(
...,
description="variable name"
)
variable: str = Field(
...,
description="variable selector"
)
class OutputVariable(BaseModel):
name: str = Field(
...,
description="variable name"
)
type: VariableType = Field(
...,
description="variable selector"
)
class CodeNodeConfig(BaseNodeConfig):
input_variables: list[InputVariable] = Field(
default_factory=list,
description="input variables"
)
output_variables: list[OutputVariable] = Field(
default_factory=list,
description="output variables"
)
code: str = Field(
default="",
description="code content"
)
language: Literal['python3', 'nodejs'] = Field(
...,
description="language"
)

View File

@@ -0,0 +1,122 @@
import base64
import json
import logging
import re
from string import Template
from textwrap import dedent
from typing import Any
import httpx
from sympy.physics.vector import vlatex
from app.core.workflow.nodes import BaseNode, WorkflowState
from app.core.workflow.nodes.base_config import VariableType
from app.core.workflow.nodes.code.config import CodeNodeConfig
logger = logging.getLogger(__name__)
SCRIPT_TEMPLATE = Template(dedent("""
$code
import json
from base64 import b64decode
# decode and prepare input dict
inputs_obj = json.loads(b64decode('$inputs_variable').decode('utf-8'))
# execute main function
output_obj = main(**inputs_obj)
# convert output to json and print
output_json = json.dumps(output_obj, indent=4)
result = "<<RESULT>>" + output_json + "<<RESULT>>"
print(result)
"""))
class CodeNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config: CodeNodeConfig | None = None
def extract_result(self, content: str):
match = re.search(r'<<RESULT>>(.*?)<<RESULT>>', content, re.DOTALL)
if match:
extracted = match.group(1)
exec_result = json.loads(extracted)
result = {}
for output in self.typed_config.output_variables:
value = exec_result.get(output.name)
if value is None:
raise RuntimeError(f"Return value {output.name} does not exist")
match output.type:
case VariableType.STRING:
if not isinstance(value, str):
raise RuntimeError(f"Return value {output.name} should be a string")
case VariableType.BOOLEAN:
if not isinstance(value, bool):
raise RuntimeError(f"Return value {output.name} should be a boolean")
case VariableType.NUMBER:
if not isinstance(value, (int, float)):
raise RuntimeError(f"Return value {output.name} should be a number")
case VariableType.OBJECT:
if not isinstance(value, dict):
raise RuntimeError(f"Return value {output.name} should be a dictionary")
case VariableType.ARRAY_STRING:
if not isinstance(value, list) or not all(isinstance(v, str) for v in value):
raise RuntimeError(f"Return value {output.name} should be a list of strings")
case VariableType.ARRAY_NUMBER:
if not isinstance(value, list) or not all(isinstance(v, (int, float)) for v in value):
raise RuntimeError(f"Return value {output.name} should be a list of numbers")
case VariableType.ARRAY_OBJECT:
if not isinstance(value, list) or not all(isinstance(v, dict) for v in value):
raise RuntimeError(f"Return value {output.name} should be a list of dictionaries")
case VariableType.ARRAY_BOOLEAN:
if not isinstance(value, list) or not all(isinstance(v, bool) for v in value):
raise RuntimeError(f"Return value {output.name} should be a list of booleans")
result[output.name] = value
return result
else:
raise RuntimeError("The output of main must be a dictionary")
async def execute(self, state: WorkflowState) -> Any:
self.typed_config = CodeNodeConfig(**self.config)
input_variable_dict = {}
for input_variable in self.typed_config.input_variables:
input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, state)
code = base64.b64decode(
self.typed_config.code
).decode("utf-8")
input_variable_dict = base64.b64encode(
json.dumps(input_variable_dict).encode("utf-8")
).decode("utf-8")
final_script = SCRIPT_TEMPLATE.substitute(
code=code,
inputs_variable=input_variable_dict,
)
async with httpx.AsyncClient() as client:
response = await client.post(
"http://sandbox:8194/v1/sandbox/run",
headers={
"x-api-key": 'redbear-sandbox'
},
json={
"language": self.typed_config.language,
"code": base64.b64encode(final_script.encode("utf-8")).decode("utf-8"),
"options": {
"enable_network": True
}
}
)
resp = response.json()
match resp['code']:
case 31:
raise RuntimeError("Operation not permitted")
case 0:
return self.extract_result(resp["data"]["stdout"])
case _:
raise Exception(resp["message"])

View File

@@ -10,21 +10,22 @@ from app.core.workflow.nodes.base_config import (
VariableDefinition, VariableDefinition,
VariableType, VariableType,
) )
from app.core.workflow.nodes.code.config import CodeNodeConfig
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
from app.core.workflow.nodes.end.config import EndNodeConfig from app.core.workflow.nodes.end.config import EndNodeConfig
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig
from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
from app.core.workflow.nodes.start.config import StartNodeConfig from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
from app.core.workflow.nodes.transform.config import TransformNodeConfig
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
from app.core.workflow.nodes.start.config import StartNodeConfig
from app.core.workflow.nodes.tool.config import ToolNodeConfig from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig from app.core.workflow.nodes.transform.config import TransformNodeConfig
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
__all__ = [ __all__ = [
# 基础类 # 基础类
"BaseNodeConfig", "BaseNodeConfig",
@@ -49,5 +50,6 @@ __all__ = [
"QuestionClassifierNodeConfig", "QuestionClassifierNodeConfig",
"ToolNodeConfig", "ToolNodeConfig",
"MemoryReadNodeConfig", "MemoryReadNodeConfig",
"MemoryWriteNodeConfig" "MemoryWriteNodeConfig",
"CodeNodeConfig"
] ]

View File

@@ -1,4 +1,5 @@
import uuid import uuid
from uuid import UUID
from pydantic import Field from pydantic import Field
from typing import Literal from typing import Literal
@@ -11,7 +12,7 @@ class MemoryReadNodeConfig(BaseNodeConfig):
... ...
) )
config_id: int = Field( config_id: UUID = Field(
... ...
) )
@@ -26,6 +27,6 @@ class MemoryWriteNodeConfig(BaseNodeConfig):
... ...
) )
config_id: int = Field( config_id: UUID = Field(
... ...
) )

View File

@@ -22,7 +22,7 @@ class MemoryReadNode(BaseNode):
raise RuntimeError("End user id is required") raise RuntimeError("End user id is required")
return await MemoryAgentService().read_memory( return await MemoryAgentService().read_memory(
group_id=end_user_id, end_user_id=end_user_id,
message=self._render_template(self.typed_config.message, state), message=self._render_template(self.typed_config.message, state),
config_id=str(self.typed_config.config_id), config_id=str(self.typed_config.config_id),
search_switch=self.typed_config.search_switch, search_switch=self.typed_config.search_switch,

View File

@@ -10,6 +10,7 @@ from typing import Any, Union
from app.core.workflow.nodes.agent import AgentNode from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.assigner import AssignerNode from app.core.workflow.nodes.assigner import AssignerNode
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.code import CodeNode
from app.core.workflow.nodes.cycle_graph.node import CycleGraphNode from app.core.workflow.nodes.cycle_graph.node import CycleGraphNode
from app.core.workflow.nodes.end import EndNode from app.core.workflow.nodes.end import EndNode
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType
@@ -49,7 +50,8 @@ WorkflowNode = Union[
QuestionClassifierNode, QuestionClassifierNode,
ToolNode, ToolNode,
MemoryReadNode, MemoryReadNode,
MemoryWriteNode MemoryWriteNode,
CodeNode
] ]
@@ -81,6 +83,7 @@ class NodeFactory:
NodeType.TOOL: ToolNode, NodeType.TOOL: ToolNode,
NodeType.MEMORY_READ: MemoryReadNode, NodeType.MEMORY_READ: MemoryReadNode,
NodeType.MEMORY_WRITE: MemoryWriteNode, NodeType.MEMORY_WRITE: MemoryWriteNode,
NodeType.CODE: CodeNode,
} }
@classmethod @classmethod

View File

@@ -18,7 +18,7 @@ from .appshare_model import AppShare
from .release_share_model import ReleaseShare from .release_share_model import ReleaseShare
from .conversation_model import Conversation, Message from .conversation_model import Conversation, Message
from .api_key_model import ApiKey, ApiKeyLog, ApiKeyType from .api_key_model import ApiKey, ApiKeyLog, ApiKeyType
from .data_config_model import DataConfig from .memory_config_model import MemoryConfig
from .multi_agent_model import MultiAgentConfig, AgentInvocation from .multi_agent_model import MultiAgentConfig, AgentInvocation
from .workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution from .workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution
from .retrieval_info import RetrievalInfo from .retrieval_info import RetrievalInfo
@@ -57,7 +57,7 @@ __all__ = [
"ApiKey", "ApiKey",
"ApiKeyLog", "ApiKeyLog",
"ApiKeyType", "ApiKeyType",
"DataConfig", "MemoryConfig",
"MultiAgentConfig", "MultiAgentConfig",
"AgentInvocation", "AgentInvocation",
"WorkflowConfig", "WorkflowConfig",

View File

@@ -6,7 +6,7 @@ from sqlalchemy.orm import relationship
from app.base.type import PydanticType from app.base.type import PydanticType
from app.db import Base from app.db import Base
from app.schemas import ModelParameters from app.schemas.app_schema import ModelParameters
class AgentConfig(Base): class AgentConfig(Base):

View File

@@ -1,88 +0,0 @@
import datetime
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float
from sqlalchemy.dialects.postgresql import UUID
from app.db import Base
class DataConfig(Base):
"""数据配置表 - 用于存储记忆系统的配置参数"""
__tablename__ = "data_config"
# 主键
config_id = Column(Integer, primary_key=True, autoincrement=True, comment="配置ID")
# 基本信息
config_name = Column(String, nullable=False, comment="配置名称")
config_desc = Column(String, nullable=True, comment="配置描述")
# 组织信息
workspace_id = Column(UUID(as_uuid=True), nullable=True, comment="工作空间ID")
group_id = Column(String, nullable=True, comment="组ID")
user_id = Column(String, nullable=True, comment="用户ID")
apply_id = Column(String, nullable=True, comment="应用ID")
# 模型选择从workspace继承
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
# 记忆萃取引擎配置
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
enable_llm_disambiguation = Column(Boolean, default=True, comment="启用LLM决策消歧")
deep_retrieval = Column(Boolean, default=True, comment="深度检索开关")
# 阈值配置 (0-1 之间的浮点数)
t_type_strict = Column(Float, default=0.8, comment="类型严格阈值")
t_name_strict = Column(Float, default=0.8, comment="名称严格阈值")
t_overall = Column(Float, default=0.8, comment="综合阈值")
# 状态配置
state = Column(Boolean, default=False, comment="配置使用状态")
# 分块策略
chunker_strategy = Column(String, default="RecursiveChunker", comment="分块策略")
# 剪枝配置
pruning_enabled = Column(Boolean, default=False, comment="是否启动智能语义剪枝")
pruning_scene = Column(String, nullable=True, comment="智能剪枝场景education/online_service/outbound")
pruning_threshold = Column(Float, nullable=True, comment="智能语义剪枝阈值0-0.9")
# 自我反思配置
enable_self_reflexion = Column(Boolean, default=False, comment="是否启用自我反思")
iteration_period = Column(String, default="3", comment="反思迭代周期")
reflexion_range = Column(String, default="partial", comment="反思范围:部分/全部")
baseline = Column(String, default="TIME", comment="基线:时间/事实/时间和事实")
reflection_model_id = Column(String, nullable=True, comment="反思模型ID")
memory_verify = Column(Boolean, default=True, comment="记忆验证")
quality_assessment = Column(Boolean, default=True, comment="质量评估")
# 遗忘引擎配置
statement_granularity = Column(Integer, default=2, comment="陈述提取颗粒度,挡位 1/2/3")
include_dialogue_context = Column(Boolean, default=False, comment="是否包含对话上下文")
max_context = Column(Integer, default=1000, comment="对话语境中包含字符的最大数量")
lambda_time = Column("lambda_time", Float, default=0.5, comment="最低保持度0-1 小数")
lambda_mem = Column("lambda_mem", Float, default=0.5, comment="遗忘率0-1 小数")
offset = Column("offset", Float, default=0.0, comment="偏移度0-1 小数")
# ACT-R 遗忘引擎配置
decay_constant = Column(Float, default=0.5, comment="ACT-R衰减常数d默认0.5")
forgetting_threshold = Column(Float, default=0.3, comment="遗忘阈值默认0.3")
forgetting_interval_hours = Column(Integer, default=24, comment="遗忘周期间隔小时默认24")
enable_llm_summary = Column(Boolean, default=True, comment="是否使用LLM生成摘要默认True")
max_merge_batch_size = Column(Integer, default=100, comment="单次最大融合节点对数默认100")
max_history_length = Column(Integer, default=100, comment="访问历史最大长度默认100")
min_days_since_access = Column(Integer, default=30, comment="最小未访问天数默认30")
# 情绪引擎配置
emotion_enabled = Column(Boolean, default=True, comment="是否启用情绪提取")
emotion_model_id = Column(String, nullable=True, comment="情绪分析专用模型ID")
emotion_extract_keywords = Column(Boolean, default=True, comment="是否提取情绪关键词")
emotion_min_intensity = Column(Float, default=0.1, comment="最小情绪强度阈值")
emotion_enable_subject = Column(Boolean, default=True, comment="是否启用主体分类")
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
def __repr__(self):
return f"<DataConfig(config_id={self.config_id}, config_name={self.config_name})>"

View File

@@ -1,39 +1,88 @@
# -*- coding: utf-8 -*- import datetime
"""Memory Configuration Model - Backward Compatibility from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float
from sqlalchemy.dialects.postgresql import UUID
from app.db import Base
This module provides backward compatibility for imports.
All classes have been moved to app.schemas.memory_config_schema.
DEPRECATED: Import from app.schemas.memory_config_schema instead. class MemoryConfig(Base):
""" """记忆配置表 - 用于存储记忆系统的配置参数"""
__tablename__ = "memory_config"
# Re-export for backward compatibility # 主键
from app.schemas.memory_config_schema import ( config_id = Column(UUID(as_uuid=True), primary_key=True, comment="配置ID")
ConfigurationError,
InvalidConfigError,
MemoryConfig,
MemoryConfigValidation,
ModelInactiveError,
ModelNotFoundError,
ModelValidation,
WorkspaceNotFoundError,
WorkspaceValidation,
validate_memory_config_data,
validate_model_data,
validate_workspace_data,
)
__all__ = [ # 基本信息
"ConfigurationError", config_name = Column(String, nullable=False, comment="配置名称")
"InvalidConfigError", config_desc = Column(String, nullable=True, comment="配置描述")
"MemoryConfig",
"MemoryConfigValidation", # 组织信息
"ModelInactiveError", workspace_id = Column(UUID(as_uuid=True), nullable=True, comment="工作空间ID")
"ModelNotFoundError", end_user_id = Column(String, nullable=True, comment="组ID")
"ModelValidation", user_id = Column(String, nullable=True, comment="用户ID")
"WorkspaceNotFoundError", apply_id = Column(String, nullable=True, comment="应用ID")
"WorkspaceValidation",
"validate_memory_config_data", # 模型选择从workspace继承
"validate_model_data", llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
"validate_workspace_data", embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
] rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
# 记忆萃取引擎配置
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
enable_llm_disambiguation = Column(Boolean, default=True, comment="启用LLM决策消歧")
deep_retrieval = Column(Boolean, default=True, comment="深度检索开关")
# 阈值配置 (0-1 之间的浮点数)
t_type_strict = Column(Float, default=0.8, comment="类型严格阈值")
t_name_strict = Column(Float, default=0.8, comment="名称严格阈值")
t_overall = Column(Float, default=0.8, comment="综合阈值")
# 状态配置
state = Column(Boolean, default=False, comment="配置使用状态")
# 分块策略
chunker_strategy = Column(String, default="RecursiveChunker", comment="分块策略")
# 剪枝配置
pruning_enabled = Column(Boolean, default=False, comment="是否启动智能语义剪枝")
pruning_scene = Column(String, nullable=True, comment="智能剪枝场景education/online_service/outbound")
pruning_threshold = Column(Float, nullable=True, comment="智能语义剪枝阈值0-0.9")
# 自我反思配置
enable_self_reflexion = Column(Boolean, default=False, comment="是否启用自我反思")
iteration_period = Column(String, default="3", comment="反思迭代周期")
reflexion_range = Column(String, default="partial", comment="反思范围:部分/全部")
baseline = Column(String, default="TIME", comment="基线:时间/事实/时间和事实")
reflection_model_id = Column(String, nullable=True, comment="反思模型ID")
memory_verify = Column(Boolean, default=True, comment="记忆验证")
quality_assessment = Column(Boolean, default=True, comment="质量评估")
# 遗忘引擎配置
statement_granularity = Column(Integer, default=2, comment="陈述提取颗粒度,挡位 1/2/3")
include_dialogue_context = Column(Boolean, default=False, comment="是否包含对话上下文")
max_context = Column(Integer, default=1000, comment="对话语境中包含字符的最大数量")
lambda_time = Column("lambda_time", Float, default=0.5, comment="最低保持度0-1 小数")
lambda_mem = Column("lambda_mem", Float, default=0.5, comment="遗忘率0-1 小数")
offset = Column("offset", Float, default=0.0, comment="偏移度0-1 小数")
# ACT-R 遗忘引擎配置
decay_constant = Column(Float, default=0.5, comment="ACT-R衰减常数d默认0.5")
forgetting_threshold = Column(Float, default=0.3, comment="遗忘阈值默认0.3")
forgetting_interval_hours = Column(Integer, default=24, comment="遗忘周期间隔小时默认24")
enable_llm_summary = Column(Boolean, default=True, comment="是否使用LLM生成摘要默认True")
max_merge_batch_size = Column(Integer, default=100, comment="单次最大融合节点对数默认100")
max_history_length = Column(Integer, default=100, comment="访问历史最大长度默认100")
min_days_since_access = Column(Integer, default=30, comment="最小未访问天数默认30")
# 情绪引擎配置
emotion_enabled = Column(Boolean, default=True, comment="是否启用情绪提取")
emotion_model_id = Column(String, nullable=True, comment="情绪分析专用模型ID")
emotion_extract_keywords = Column(Boolean, default=True, comment="是否提取情绪关键词")
emotion_min_intensity = Column(Float, default=0.1, comment="最小情绪强度阈值")
emotion_enable_subject = Column(Boolean, default=True, comment="是否启用主体分类")
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
def __repr__(self):
return f"<MemoryConfig(config_id={self.config_id}, config_name={self.config_name})>"

View File

@@ -16,7 +16,7 @@ class PerceptualType(IntEnum):
CONVERSATION = 4 CONVERSATION = 4
class FileStorageType(IntEnum): class FileStorageService(IntEnum):
LOCAL = 1 LOCAL = 1
REMOTE = 2 REMOTE = 2

View File

@@ -10,7 +10,7 @@ from sqlalchemy.orm import relationship
from app.base.type import PydanticType from app.base.type import PydanticType
from app.db import Base from app.db import Base
from app.schemas import ModelParameters from app.schemas.app_schema import ModelParameters
class OrchestrationMode(StrEnum): class OrchestrationMode(StrEnum):

View File

@@ -15,9 +15,13 @@ class AppRepository:
self.db = db self.db = db
def get_apps_by_workspace_id(self, workspace_id: uuid.UUID) -> list[App]: def get_apps_by_workspace_id(self, workspace_id: uuid.UUID) -> list[App]:
"""根据工作空间ID查询应用""" """根据工作空间ID查询应用(仅返回未删除的应用)"""
try: try:
apps = self.db.query(App).filter(App.workspace_id == workspace_id).all() apps = (
self.db.query(App)
.filter(App.workspace_id == workspace_id, App.is_active.is_(True))
.all()
)
db_logger.info(f"成功查询工作空间 {workspace_id} 下的 {len(apps)} 个应用") db_logger.info(f"成功查询工作空间 {workspace_id} 下的 {len(apps)} 个应用")
return apps return apps
except Exception as e: except Exception as e:
@@ -26,7 +30,7 @@ class AppRepository:
def get_apps_by_id(self, app_id: uuid.UUID) -> App: def get_apps_by_id(self, app_id: uuid.UUID) -> App:
try: try:
app = self.db.query(App).filter(App.id == app_id, App.is_active == True).first() app = self.db.query(App).filter(App.id == app_id, App.is_active.is_(True)).first()
return app return app
except Exception as e: except Exception as e:
raise raise

View File

@@ -17,24 +17,24 @@ class HomePageRepository:
"""获取模型统计数据""" """获取模型统计数据"""
total_models = db.query(ModelConfig).filter( total_models = db.query(ModelConfig).filter(
ModelConfig.tenant_id == tenant_id, ModelConfig.tenant_id == tenant_id,
ModelConfig.is_active == True ModelConfig.is_active.is_(True)
).count() ).count()
total_llm = db.query(ModelConfig).filter( total_llm = db.query(ModelConfig).filter(
ModelConfig.tenant_id == tenant_id, ModelConfig.tenant_id == tenant_id,
ModelConfig.is_active == True, ModelConfig.is_active.is_(True),
ModelConfig.type == "llm" ModelConfig.type == "llm"
).count() ).count()
total_embedding = db.query(ModelConfig).filter( total_embedding = db.query(ModelConfig).filter(
ModelConfig.tenant_id == tenant_id, ModelConfig.tenant_id == tenant_id,
ModelConfig.is_active == True, ModelConfig.is_active.is_(True),
ModelConfig.type == "embedding" ModelConfig.type == "embedding"
).count() ).count()
new_models_this_week = db.query(ModelConfig).filter( new_models_this_week = db.query(ModelConfig).filter(
ModelConfig.tenant_id == tenant_id, ModelConfig.tenant_id == tenant_id,
ModelConfig.is_active == True, ModelConfig.is_active.is_(True),
ModelConfig.created_at >= week_start ModelConfig.created_at >= week_start
).count() ).count()
@@ -56,12 +56,12 @@ class HomePageRepository:
"""获取工作空间统计数据""" """获取工作空间统计数据"""
active_workspaces = db.query(Workspace).filter( active_workspaces = db.query(Workspace).filter(
Workspace.tenant_id == tenant_id, Workspace.tenant_id == tenant_id,
Workspace.is_active == True Workspace.is_active.is_(True)
).count() ).count()
new_workspaces_this_week = db.query(Workspace).filter( new_workspaces_this_week = db.query(Workspace).filter(
Workspace.tenant_id == tenant_id, Workspace.tenant_id == tenant_id,
Workspace.is_active == True, Workspace.is_active.is_(True),
Workspace.created_at >= week_start Workspace.created_at >= week_start
).count() ).count()
@@ -83,7 +83,7 @@ class HomePageRepository:
"""获取用户统计数据""" """获取用户统计数据"""
workspace_ids = db.query(Workspace.id).filter( workspace_ids = db.query(Workspace.id).filter(
Workspace.tenant_id == tenant_id, Workspace.tenant_id == tenant_id,
Workspace.is_active == True Workspace.is_active.is_(True)
).subquery() ).subquery()
total_users = db.query(EndUser).join( total_users = db.query(EndUser).join(
@@ -91,7 +91,7 @@ class HomePageRepository:
EndUser.app_id == App.id EndUser.app_id == App.id
).filter( ).filter(
App.workspace_id.in_(workspace_ids), App.workspace_id.in_(workspace_ids),
App.is_active == True, App.is_active.is_(True),
App.status == "active" App.status == "active"
).count() ).count()
@@ -100,7 +100,7 @@ class HomePageRepository:
EndUser.app_id == App.id EndUser.app_id == App.id
).filter( ).filter(
App.workspace_id.in_(workspace_ids), App.workspace_id.in_(workspace_ids),
App.is_active == True, App.is_active.is_(True),
App.status == "active", App.status == "active",
EndUser.created_at >= week_start EndUser.created_at >= week_start
).count() ).count()
@@ -123,18 +123,18 @@ class HomePageRepository:
"""获取应用统计数据""" """获取应用统计数据"""
workspace_ids = db.query(Workspace.id).filter( workspace_ids = db.query(Workspace.id).filter(
Workspace.tenant_id == tenant_id, Workspace.tenant_id == tenant_id,
Workspace.is_active == True Workspace.is_active.is_(True)
).subquery() ).subquery()
running_apps = db.query(App).filter( running_apps = db.query(App).filter(
App.workspace_id.in_(workspace_ids), App.workspace_id.in_(workspace_ids),
App.is_active == True, App.is_active.is_(True),
App.status == "active" App.status == "active"
).count() ).count()
new_apps_this_week = db.query(App).filter( new_apps_this_week = db.query(App).filter(
App.workspace_id.in_(workspace_ids), App.workspace_id.in_(workspace_ids),
App.is_active == True, App.is_active.is_(True),
App.status == "active", App.status == "active",
App.created_at >= week_start App.created_at >= week_start
).count() ).count()
@@ -158,7 +158,7 @@ class HomePageRepository:
# 获取工作空间列表 # 获取工作空间列表
workspaces = db.query(Workspace).filter( workspaces = db.query(Workspace).filter(
Workspace.tenant_id == tenant_id, Workspace.tenant_id == tenant_id,
Workspace.is_active == True Workspace.is_active.is_(True)
).all() ).all()
workspace_ids = [ws.id for ws in workspaces] workspace_ids = [ws.id for ws in workspaces]
@@ -169,7 +169,7 @@ class HomePageRepository:
func.count(App.id).label('count') func.count(App.id).label('count')
).filter( ).filter(
App.workspace_id.in_(workspace_ids), App.workspace_id.in_(workspace_ids),
App.is_active, App.is_active.is_(True),
App.status == "active" App.status == "active"
).group_by(App.workspace_id).all() ).group_by(App.workspace_id).all()
@@ -184,7 +184,7 @@ class HomePageRepository:
EndUser.app_id == App.id EndUser.app_id == App.id
).filter( ).filter(
App.workspace_id.in_(workspace_ids), App.workspace_id.in_(workspace_ids),
App.is_active, App.is_active.is_(True),
App.status == "active" App.status == "active"
).group_by(App.workspace_id).all() ).group_by(App.workspace_id).all()

Some files were not shown because too many files have changed in this diff Show More